From 4440a735d7116637e8d81f4898a047608073bc29 Mon Sep 17 00:00:00 2001 From: ydajiang Date: Sat, 17 May 2025 22:56:50 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E4=BD=BF=E7=94=A8sqlite=E6=9B=BF?= =?UTF-8?q?=E6=8D=A2redis=E4=BD=9C=E4=B8=BA=E6=8C=81=E4=B9=85=E5=B1=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 3 + api.go | 102 +++--- broadcast.go | 25 +- broadcast_dialogs.go | 39 --- broadcast_manager.go | 77 ----- client.go | 2 +- client_benchmark_test.go | 2 +- config.json | 5 - dao_channel.go | 77 +++++ dao_device.go | 129 +++++++ dao_platform.go | 150 ++++++++ dao_sink.go | 157 +++++++++ dao_stream.go | 135 ++++++++ db.go | 68 ---- db_redis.go | 722 --------------------------------------- db_sqlite.go | 140 ++++++++ device.go | 65 ++-- device_manager.go | 54 --- dialogs.go | 77 ++--- go.mod | 15 +- live.go | 27 +- live_benchmark_test.go | 2 +- main.go | 32 +- online_devices.go | 60 ++++ platform.go | 23 +- platform_manager.go | 28 +- position.go | 4 +- recover.go | 155 +++------ sink.go | 42 ++- sink_manager.go | 221 +----------- sip_client.go | 24 +- sip_handler.go | 120 ++----- sip_server.go | 60 ++-- stream.go | 104 +++--- stream_manager.go | 125 ------- xml.go | 60 ++-- 36 files changed, 1304 insertions(+), 1827 deletions(-) delete mode 100644 broadcast_dialogs.go delete mode 100644 broadcast_manager.go create mode 100644 dao_channel.go create mode 100644 dao_device.go create mode 100644 dao_platform.go create mode 100644 dao_sink.go create mode 100644 dao_stream.go delete mode 100644 db.go delete mode 100644 db_redis.go create mode 100644 db_sqlite.go delete mode 100644 device_manager.go create mode 100644 online_devices.go delete mode 100644 stream_manager.go diff --git a/.gitignore b/.gitignore index 8afcb3c..6f18fc6 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ vendor/ logs/ .idea/ +*.db +*.db-shm +*.db-wal diff --git a/api.go b/api.go index 97ac3da..0e1b261 100644 --- a/api.go +++ b/api.go @@ -208,9 +208,8 @@ func (api *ApiServer) OnPlay(params *StreamParams, w http.ResponseWriter, r *htt } // 已经存在,累加计数 - if stream := StreamManager.Find(params.Stream); stream != nil { + if stream, _ := StreamDao.QueryStream(params.Stream); stream != nil { stream.IncreaseSinkCount() - return } deviceId := sourceStream[0] @@ -256,14 +255,14 @@ func (api *ApiServer) OnPlay(params *StreamParams, w http.ResponseWriter, r *htt func (api *ApiServer) OnPlayDone(params *PlayDoneParams, w http.ResponseWriter, r *http.Request) { Sugar.Infof("播放结束事件. protocol: %s stream: %s", params.Protocol, params.Stream) - //stream := StreamManager.Find(params.Stream) + //stream := StreamManager.Find(params.StreamID) //if stream == nil { - // Sugar.Errorf("处理播放结束事件失败, stream不存在. id: %s", params.Stream) + // Sugar.Errorf("处理播放结束事件失败, stream不存在. id: %s", params.StreamID) // return //} //if 0 == stream.DecreaseSinkCount() && Config.AutoCloseOnIdle { - // CloseStream(params.Stream, true) + // CloseStream(params.StreamID, true) //} if !strings.HasPrefix(params.Protocol, "gb") { @@ -280,7 +279,7 @@ func (api *ApiServer) OnPlayDone(params *PlayDoneParams, w http.ResponseWriter, if params.Protocol == "gb_cascaded_forward" { if platform := PlatformManager.Find(sink.ServerAddr); platform != nil { callID, _ := sink.Dialog.CallID() - platform.CloseStream(callID.String(), true, false) + platform.CloseStream(callID.Value(), true, false) } } else if params.Protocol == "gb_talk_forward" { // 对讲设备断开连接 @@ -292,9 +291,9 @@ func (api *ApiServer) OnPlayDone(params *PlayDoneParams, w http.ResponseWriter, func (api *ApiServer) OnPublish(params *StreamParams, w http.ResponseWriter, r *http.Request) { Sugar.Infof("推流事件. protocol: %s stream: %s", params.Protocol, params.Stream) - stream := StreamManager.Find(params.Stream) + stream := Dialogs.Find(string(params.Stream)) if stream != nil { - stream.onPublishCb <- 200 + stream.Put(200) } // 对讲websocket已连接 @@ -303,21 +302,16 @@ func (api *ApiServer) OnPublish(params *StreamParams, w http.ResponseWriter, r * Sugar.Infof("对讲websocket已连接, stream: %s", params.Stream) s := &Stream{ - ID: params.Stream, - Protocol: params.Protocol, - CreateTime: time.Now().Unix(), + StreamID: params.Stream, + Protocol: params.Protocol, } - _, ok := StreamManager.Add(s) + _, ok := StreamDao.SaveStream(s) if !ok { Sugar.Errorf("处理推流事件失败, stream已存在. id: %s", params.Stream) w.WriteHeader(http.StatusBadRequest) return } - - if DB != nil { - go DB.SaveStream(s) - } } } @@ -382,7 +376,7 @@ func (api *ApiServer) OnInvite(v *InviteParams, w http.ResponseWriter, r *http.R Stream string `json:"stream_id"` Urls []string `json:"urls"` }{ - string(stream.ID), + string(stream.StreamID), stream.urls, } httpResponseOK(w, response) @@ -392,8 +386,8 @@ func (api *ApiServer) OnInvite(v *InviteParams, w http.ResponseWriter, r *http.R // DoInvite 发起Invite请求 // @params sync 是否异步等待流媒体的publish事件(确认收到流), 目前请求流分两种方式,流媒体hook和http接口, hook方式同步等待确认收到流再应答, http接口直接应答成功。 func (api *ApiServer) DoInvite(inviteType InviteType, params *InviteParams, sync bool, w http.ResponseWriter, r *http.Request) (int, *Stream, error) { - device := DeviceManager.Find(params.DeviceID) - if device == nil { + device, _ := DeviceDao.QueryDevice(params.DeviceID) + if device == nil || !device.Online() { return http.StatusNotFound, nil, fmt.Errorf("设备离线 id: %s", params.DeviceID) } @@ -422,7 +416,7 @@ func (api *ApiServer) DoInvite(inviteType InviteType, params *InviteParams, sync // 解析回放或下载速度参数 speed, _ := strconv.Atoi(params.Speed) speed = int(math.Min(4, float64(speed))) - stream, err := device.(*Device).StartStream(inviteType, params.streamId, params.ChannelID, startTimeSeconds, endTimeSeconds, params.Setup, speed, sync) + stream, err := device.StartStream(inviteType, params.streamId, params.ChannelID, startTimeSeconds, endTimeSeconds, params.Setup, speed, sync) if err != nil { return http.StatusInternalServerError, nil, err } @@ -431,12 +425,12 @@ func (api *ApiServer) DoInvite(inviteType InviteType, params *InviteParams, sync } func (api *ApiServer) OnCloseStream(v *StreamIDParams, w http.ResponseWriter, r *http.Request) { - stream := StreamManager.Find(v.StreamID) - - // 等空闲或收流超时会自动关闭 - if stream != nil && stream.GetSinkCount() < 1 { - CloseStream(v.StreamID, true) - } + //stream := StreamManager.Find(v.StreamID) + // + //// 等空闲或收流超时会自动关闭 + //if stream != nil && stream.GetSinkCount() < 1 { + // CloseStream(v.StreamID, true) + //} httpResponseOK(w, nil) } @@ -459,7 +453,7 @@ func (api *ApiServer) OnDeviceList(v *PageQuery, w http.ResponseWriter, r *http. v.PageSize = &defaultPageSize } - devices, total, err := DB.QueryDevices(*v.PageNumber, *v.PageSize) + devices, total, err := DeviceDao.QueryDevices(*v.PageNumber, *v.PageSize) if err != nil { Sugar.Errorf("查询设备列表失败 err: %s", err.Error()) return nil, err @@ -489,7 +483,7 @@ func (api *ApiServer) OnChannelList(v *PageQueryChannel, w http.ResponseWriter, v.PageSize = &defaultPageSize } - channels, total, err := DB.QueryChannels(v.DeviceID, *v.PageNumber, *v.PageSize) + channels, total, err := ChannelDao.QueryChannels(v.DeviceID, *v.PageNumber, *v.PageSize) if err != nil { Sugar.Errorf("查询通道列表失败 err: %s", err.Error()) return nil, err @@ -509,8 +503,8 @@ func (api *ApiServer) OnChannelList(v *PageQueryChannel, w http.ResponseWriter, func (api *ApiServer) OnRecordList(v *QueryRecordParams, w http.ResponseWriter, r *http.Request) (interface{}, error) { Sugar.Infof("查询录像列表 %v", *v) - device := DeviceManager.Find(v.DeviceID) - if device == nil { + device, _ := DeviceDao.QueryDevice(v.DeviceID) + if device == nil || !device.Online() { Sugar.Errorf("查询录像列表失败, 设备离线 device: %s", v.DeviceID) return nil, fmt.Errorf("设备离线") } @@ -551,8 +545,8 @@ func (api *ApiServer) OnRecordList(v *QueryRecordParams, w http.ResponseWriter, func (api *ApiServer) OnSubscribePosition(v *DeviceChannelID, w http.ResponseWriter, r *http.Request) (interface{}, error) { Sugar.Infof("订阅位置 %v", *v) - device := DeviceManager.Find(v.DeviceID) - if device == nil { + device, _ := DeviceDao.QueryDevice(v.DeviceID) + if device == nil || !device.Online() { Sugar.Errorf("订阅位置失败, 设备离线 device: %s", v.DeviceID) return nil, fmt.Errorf("设备离线") } @@ -568,7 +562,7 @@ func (api *ApiServer) OnSubscribePosition(v *DeviceChannelID, w http.ResponseWri func (api *ApiServer) OnSeekPlayback(v *SeekParams, w http.ResponseWriter, r *http.Request) (interface{}, error) { Sugar.Infof("快进回放 %v", *v) - stream := StreamManager.Find(v.StreamId) + stream, _ := StreamDao.QueryStream(v.StreamId) if stream == nil || stream.Dialog == nil { Sugar.Infof("快进回放失败 stream不存在 %s", v.StreamId) return nil, fmt.Errorf("stream不存在") @@ -593,7 +587,7 @@ func (api *ApiServer) OnHangup(v *BroadcastParams, w http.ResponseWriter, r *htt Sugar.Infof("广播挂断 %v", *v) id := GenerateStreamID(InviteTypeBroadcast, v.DeviceID, v.ChannelID, "", "") - if sink := RemoveForwardSinkWithSinkStreamId(id); sink != nil { + if sink := RemoveForwardSinkWithSinkStreamID(id); sink != nil { sink.Close(true, true) } @@ -610,23 +604,23 @@ func (api *ApiServer) OnBroadcast(v *BroadcastParams, w http.ResponseWriter, r * defer func() { if !ok { if InviteSourceId != "" { - BroadcastDialogs.Remove(InviteSourceId) + Dialogs.Remove(InviteSourceId) } if sinkStreamId != "" { - SinkManager.RemoveWithSinkStreamId(sinkStreamId) + _, _ = SinkDao.DeleteForwardSinkBySinkStreamID(sinkStreamId) } } }() - device := DeviceManager.Find(v.DeviceID) - if device == nil { + device, _ := DeviceDao.QueryDevice(v.DeviceID) + if device == nil || !device.Online() { Sugar.Errorf("广播失败, 设备离线, DeviceID: %s", v.DeviceID) return nil, fmt.Errorf("设备离线") } // 主讲人id - source := StreamManager.Find(v.StreamId) + source, _ := StreamDao.QueryStream(v.StreamId) if source == nil { Sugar.Errorf("广播失败, 找不到主讲人, stream: %s", v.StreamId) return nil, fmt.Errorf("找不到主讲人") @@ -645,17 +639,18 @@ func (api *ApiServer) OnBroadcast(v *BroadcastParams, w http.ResponseWriter, r * } sink := &Sink{ - Stream: v.StreamId, - SinkStream: sinkStreamId, - Protocol: "gb_talk_forward", - CreateTime: time.Now().Unix(), - SetupType: setupType, + StreamID: v.StreamId, + SinkStreamID: sinkStreamId, + Protocol: "gb_talk_forward", + CreateTime: time.Now().Unix(), + SetupType: setupType, } - if ok = SinkManager.AddWithSinkStreamId(sink); !ok { + streamWaiting := &StreamWaiting{data: sink} + if err := SinkDao.SaveForwardSink(v.StreamId, sink); err != nil { Sugar.Errorf("广播失败, 设备正在广播中. stream: %s", sinkStreamId) return nil, fmt.Errorf("设备正在广播中") - } else if _, ok = BroadcastDialogs.Add(InviteSourceId, sink); !ok { + } else if _, ok = Dialogs.Add(InviteSourceId, streamWaiting); !ok { Sugar.Errorf("广播失败, id冲突. id: %s", InviteSourceId) return nil, fmt.Errorf("id冲突") } @@ -678,7 +673,7 @@ func (api *ApiServer) OnBroadcast(v *BroadcastParams, w http.ResponseWriter, r * } // 等待下级设备的Invite请求 - code := sink.WaitForPublishEvent(10) + code := streamWaiting.Receive(10) if code == -1 { Sugar.Errorf("广播失败, 等待invite超时. stream: %s", sinkStreamId) return nil, fmt.Errorf("等待invite超时") @@ -699,15 +694,21 @@ func (api *ApiServer) OnBroadcast(v *BroadcastParams, w http.ResponseWriter, r * } func (api *ApiServer) OnTalk(w http.ResponseWriter, r *http.Request) { + } func (api *ApiServer) OnStarted(w http.ResponseWriter, req *http.Request) { Sugar.Infof("lkm启动") - streams := StreamManager.PopAll() + streams, _ := StreamDao.DeleteStreams() for _, stream := range streams { stream.Close(true, false) } + + sinks, _ := SinkDao.DeleteForwardSinks() + for _, sink := range sinks { + sink.Close(true, false) + } } func (api *ApiServer) OnPlatformAdd(v *SIPUAParams, w http.ResponseWriter, r *http.Request) (interface{}, error) { @@ -728,7 +729,6 @@ func (api *ApiServer) OnPlatformAdd(v *SIPUAParams, w http.ResponseWriter, r *ht return nil, err } - v.CreateTime = strconv.FormatInt(time.Now().UnixMilli(), 10) v.Status = "OFF" platform, err := NewGBPlatform(v, SipUA) @@ -772,7 +772,7 @@ func (api *ApiServer) OnPlatformChannelBind(v *PlatformChannel, w http.ResponseW } // 级联功能,通道号必须唯一 - channels, err := DB.BindChannels(v.ServerAddr, v.Channels) + channels, err := PlatformDao.BindChannels(v.ServerAddr, v.Channels) if err != nil { Sugar.Errorf("绑定通道失败 err: %s", err.Error()) return nil, err @@ -790,7 +790,7 @@ func (api *ApiServer) OnPlatformChannelUnbind(v *PlatformChannel, w http.Respons return nil, fmt.Errorf("not found platform") } - channels, err := DB.UnbindChannels(v.ServerAddr, v.Channels) + channels, err := PlatformDao.UnbindChannels(v.ServerAddr, v.Channels) if err != nil { Sugar.Errorf("解绑通道失败 err: %s", err.Error()) return nil, err diff --git a/broadcast.go b/broadcast.go index 9bf85a8..b9c0114 100644 --- a/broadcast.go +++ b/broadcast.go @@ -65,20 +65,21 @@ func (d *Device) DoBroadcast(sourceId, channelId string) error { // OnInvite 语音广播 func (d *Device) OnInvite(request sip.Request, user string) sip.Response { - sink := BroadcastDialogs.Find(user) - if sink == nil { + streamWaiting := Dialogs.Find(user) + if streamWaiting == nil { return CreateResponseWithStatusCode(request, http.StatusBadRequest) } + sink := streamWaiting.data.(*Sink) body := request.Body() offer, err := sdp.Parse(body) if err != nil { - Sugar.Infof("广播失败, 解析sdp发生err: %s sink: %s sdp: %s", err.Error(), sink.ID, body) - sink.onPublishCb <- http.StatusBadRequest + Sugar.Infof("广播失败, 解析sdp发生err: %s sink: %s sdp: %s", err.Error(), sink.SinkID, body) + streamWaiting.Put(http.StatusBadRequest) return CreateResponseWithStatusCode(request, http.StatusBadRequest) } else if offer.Audio == nil { - Sugar.Infof("广播失败, offer中缺少audio字段. sink: %s sdp: %s", sink.ID, body) - sink.onPublishCb <- http.StatusBadRequest + Sugar.Infof("广播失败, offer中缺少audio字段. sink: %s sdp: %s", sink.SinkID, body) + streamWaiting.Put(http.StatusBadRequest) return CreateResponseWithStatusCode(request, http.StatusBadRequest) } @@ -91,10 +92,10 @@ func (d *Device) OnInvite(request sip.Request, user string) sip.Response { } addr := net.JoinHostPort(offer.Addr, strconv.Itoa(int(offer.Audio.Port))) - host, port, sinkId, err := CreateAnswer(string(sink.Stream), addr, offerSetup.String(), answerSetup.String(), "", string(InviteTypeBroadcast)) + host, port, sinkId, err := CreateAnswer(string(sink.StreamID), addr, offerSetup.String(), answerSetup.String(), "", string(InviteTypeBroadcast)) if err != nil { - Sugar.Errorf("广播失败, 流媒体创建answer发生err: %s sink: %s ", err.Error(), sink.ID) - sink.onPublishCb <- http.StatusInternalServerError + Sugar.Errorf("广播失败, 流媒体创建answer发生err: %s sink: %s ", err.Error(), sink.SinkID) + streamWaiting.Put(http.StatusInternalServerError) return CreateResponseWithStatusCode(request, http.StatusInternalServerError) } @@ -111,13 +112,13 @@ func (d *Device) OnInvite(request sip.Request, user string) sip.Response { response := CreateResponseWithStatusCode(request, http.StatusOK) setToTag(response) - sink.ID = sinkId - sink.Dialog = d.CreateDialogRequestFromAnswer(response, true) + sink.SinkID = sinkId + sink.SetDialog(d.CreateDialogRequestFromAnswer(response, true)) response.SetBody(answerSDP, true) response.AppendHeader(&SDPMessageType) response.AppendHeader(GlobalContactAddress.AsContactHeader()) - sink.onPublishCb <- http.StatusOK + streamWaiting.Put(http.StatusOK) return response } diff --git a/broadcast_dialogs.go b/broadcast_dialogs.go deleted file mode 100644 index 7ea7e2b..0000000 --- a/broadcast_dialogs.go +++ /dev/null @@ -1,39 +0,0 @@ -package main - -import "sync" - -// BroadcastDialogs 临时保存广播会话 -var BroadcastDialogs = &broadcastDialogs{ - dialogs: make(map[string]*Sink), -} - -type broadcastDialogs struct { - lock sync.RWMutex - dialogs map[string]*Sink -} - -func (b *broadcastDialogs) Add(id string, dialog *Sink) (old *Sink, ok bool) { - b.lock.Lock() - defer b.lock.Unlock() - - if old, ok = b.dialogs[id]; ok { - return old, false - } - - b.dialogs[id] = dialog - return nil, true -} - -func (b *broadcastDialogs) Find(id string) *Sink { - b.lock.RLock() - defer b.lock.RUnlock() - return b.dialogs[id] -} - -func (b *broadcastDialogs) Remove(id string) *Sink { - b.lock.Lock() - defer b.lock.Unlock() - dialog := b.dialogs[id] - delete(b.dialogs, id) - return dialog -} diff --git a/broadcast_manager.go b/broadcast_manager.go deleted file mode 100644 index 2f26b82..0000000 --- a/broadcast_manager.go +++ /dev/null @@ -1,77 +0,0 @@ -package main - -import "sync" - -//var BroadcastManager = &broadcastManager{ -// streams: make(map[StreamID]*Sink), -// callIds: make(map[string]*Sink), -//} - -type broadcastManager struct { - streams map[StreamID]*Sink // device stream id ->sink - callIds map[string]*Sink // invite call id->sink - lock sync.RWMutex -} - -func (b *broadcastManager) Add(id StreamID, sink *Sink) (old *Sink, ok bool) { - b.lock.Lock() - defer b.lock.Unlock() - old, ok = b.streams[id] - if ok { - return old, false - } - b.streams[id] = sink - return nil, true -} - -func (b *broadcastManager) AddWithCallId(id string, sink *Sink) bool { - b.lock.Lock() - defer b.lock.Unlock() - if _, ok := b.callIds[id]; ok { - return false - } - b.callIds[id] = sink - return true -} - -func (b *broadcastManager) Find(id StreamID) *Sink { - b.lock.RLock() - defer b.lock.RUnlock() - return b.streams[id] -} - -func (b *broadcastManager) FindWithCallId(id string) *Sink { - b.lock.RLock() - defer b.lock.RUnlock() - return b.callIds[id] -} - -func (b *broadcastManager) Remove(id StreamID) *Sink { - b.lock.Lock() - defer b.lock.Unlock() - sink, ok := b.streams[id] - if !ok { - return nil - } - - if sink.Dialog != nil { - callID, _ := sink.Dialog.CallID() - delete(b.callIds, callID.String()) - } - - delete(b.streams, id) - return sink -} - -func (b *broadcastManager) RemoveWithCallId(id string) *Sink { - b.lock.Lock() - defer b.lock.Unlock() - sink, ok := b.callIds[id] - if !ok { - return nil - } - - delete(b.callIds, id) - delete(b.streams, sink.Stream) - return sink -} diff --git a/client.go b/client.go index 855c561..5646fb5 100644 --- a/client.go +++ b/client.go @@ -139,6 +139,6 @@ func NewGBClient(params *SIPUAParams, ua SipServer) GBClient { ua: ua, } - client := &Client{sip, Device{ID: params.Username}, &DeviceInfoResponse{BaseResponse: BaseResponse{BaseMessage: BaseMessage{DeviceID: params.Username, CmdType: CmdDeviceInfo}, Result: "OK"}}} + client := &Client{sip, Device{DeviceID: params.Username}, &DeviceInfoResponse{BaseResponse: BaseResponse{BaseMessage: BaseMessage{DeviceID: params.Username, CmdType: CmdDeviceInfo}, Result: "OK"}}} return client } diff --git a/client_benchmark_test.go b/client_benchmark_test.go index 157f978..d0abe5f 100644 --- a/client_benchmark_test.go +++ b/client_benchmark_test.go @@ -220,7 +220,7 @@ package main // // // 绑定到StreamManager, bye请求才会找到设备回调 // streamId := GenerateStreamID(InviteTypePlay, v.sipClient.Username, user, "", "") -// s := Stream{ID: streamId, Dialog: stream.dialog} +// s := StreamID{StreamID: streamId, Dialog: stream.dialog} // StreamManager.Add(&s) // // callID, _ := request.CallID() diff --git a/config.json b/config.json index 3ddb812..fa4c2f0 100644 --- a/config.json +++ b/config.json @@ -14,11 +14,6 @@ "?auto_close_on_idle": "拉流空闲时, 立即关闭流", "auto_close_on_idle": true, - "redis": { - "addr": "0.0.0.0:6379", - "password": "" - }, - "hooks": { "?online": "设备上线通知", "online": "", diff --git a/dao_channel.go b/dao_channel.go new file mode 100644 index 0000000..3a9f9d3 --- /dev/null +++ b/dao_channel.go @@ -0,0 +1,77 @@ +package main + +import "gorm.io/gorm" + +type DaoChannel interface { + SaveChannel(deviceId string, channel *Channel) error + + UpdateChannelStatus(deviceId, channelId, status string) error + + QueryChannel(deviceId string, channelId string) (*Channel, error) + + QueryChannels(deviceId string, page, size int) ([]*Channel, int, error) + + QueryChanelCount(deviceId string) (int, error) + + QueryOnlineChanelCount(deviceId string) (int, error) +} + +type daoChannel struct { +} + +func (d *daoChannel) SaveChannel(deviceId string, channel *Channel) error { + return DBTransaction(func(tx *gorm.DB) error { + var old Channel + if db.Select("id").Where("parent_id =? and device_id =?", deviceId, channel.DeviceID).Take(&old).Error == nil { + channel.ID = old.ID + } + return tx.Save(channel).Error + }) +} + +func (d *daoChannel) UpdateChannelStatus(deviceId, channelId, status string) error { + return db.Model(&Channel{}).Where("parent_id =? and device_id =?", deviceId, channelId).Update("status", status).Error +} + +func (d *daoChannel) QueryChannel(deviceId string, channelId string) (*Channel, error) { + var channel Channel + tx := db.Where("parent_id =? and device_id =?", deviceId, channelId).Take(&channel) + if tx.Error != nil { + return nil, tx.Error + } + return &channel, nil +} + +func (d *daoChannel) QueryChannels(deviceId string, page, size int) ([]*Channel, int, error) { + var channels []*Channel + tx := db.Limit(size).Offset((page-1)*size).Where("parent_id =?", deviceId).Find(&channels) + if tx.Error != nil { + return nil, 0, tx.Error + } + + var total int64 + tx = db.Model(&Channel{}).Where("parent_id =?", deviceId).Count(&total) + if tx.Error != nil { + return nil, 0, tx.Error + } + + return channels, int(total), nil +} + +func (d *daoChannel) QueryChanelCount(deviceId string) (int, error) { + var total int64 + tx := db.Model(&Channel{}).Where("parent_id =?", deviceId).Count(&total) + if tx.Error != nil { + return 0, tx.Error + } + return int(total), nil +} + +func (d *daoChannel) QueryOnlineChanelCount(deviceId string) (int, error) { + var total int64 + tx := db.Model(&Channel{}).Where("parent_id =? and status =?", deviceId, "ON").Count(&total) + if tx.Error != nil { + return 0, tx.Error + } + return int(total), nil +} diff --git a/dao_device.go b/dao_device.go new file mode 100644 index 0000000..376b093 --- /dev/null +++ b/dao_device.go @@ -0,0 +1,129 @@ +package main + +import ( + "gorm.io/gorm" + "time" +) + +type DaoDevice interface { + LoadOnlineDevices() (map[string]*Device, error) + + LoadDevices() (map[string]*Device, error) + + SaveDevice(device *Device) error + + RefreshHeartbeat(deviceId string, now time.Time, addr string) error + + QueryDevice(id string) (*Device, error) + + QueryDevices(page int, size int) ([]*Device, int, error) + + UpdateDeviceStatus(deviceId string, status OnlineStatus) error + + UpdateDeviceInfo(deviceId string, device *Device) error + + UpdateOfflineDevices(deviceIds []string) error +} + +type daoDevice struct { +} + +func (d *daoDevice) LoadOnlineDevices() (map[string]*Device, error) { + //TODO implement me + panic("implement me") +} + +func (d *daoDevice) LoadDevices() (map[string]*Device, error) { + var devices []*Device + tx := db.Find(&devices) + if tx.Error != nil { + return nil, tx.Error + } + + deviceMap := make(map[string]*Device) + for _, device := range devices { + deviceMap[device.DeviceID] = device + } + + return deviceMap, nil +} + +func (d *daoDevice) SaveDevice(device *Device) error { + return DBTransaction(func(tx *gorm.DB) error { + old := Device{} + if db.Select("id").Where("device_id =?", device.DeviceID).Take(&old).Error == nil { + device.ID = old.ID + } + + if device.ID == 0 { + //return tx.Create(&old).Error + return tx.Save(device).Error + } else { + return tx.Model(device).Select("Transport", "RemoteAddr", "Status", "RegisterTime", "LastHeartbeat").Updates(*device).Error + } + }) +} + +func (d *daoDevice) UpdateDeviceInfo(deviceId string, device *Device) error { + return DBTransaction(func(tx *gorm.DB) error { + return tx.Model(&Device{}).Select("Manufacturer", "Model", "Firmware", "Name").Where("device_id =?", deviceId).Updates(*device).Error + }) +} + +func (d *daoDevice) UpdateDeviceStatus(deviceId string, status OnlineStatus) error { + return DBTransaction(func(tx *gorm.DB) error { + return tx.Model(&Device{}).Where("device_id =?", deviceId).Update("status", status).Error + }) +} + +func (d *daoDevice) RefreshHeartbeat(deviceId string, now time.Time, addr string) error { + if tx := db.Select("id").Take(&Device{}, "device_id =?", deviceId); tx.Error != nil { + return tx.Error + } + return DBTransaction(func(tx *gorm.DB) error { + return tx.Model(&Device{}).Select("LastHeartbeat", "Status", "RemoteAddr").Where("device_id =?", deviceId).Updates(&Device{ + LastHeartbeat: now, + Status: ON, + RemoteAddr: addr, + }).Error + }) +} + +func (d *daoDevice) QueryDevice(id string) (*Device, error) { + var device Device + tx := db.Where("device_id =?", id).Take(&device) + if tx.Error != nil { + return nil, tx.Error + } + + return &device, nil +} + +func (d *daoDevice) QueryDevices(page int, size int) ([]*Device, int, error) { + var devices []*Device + tx := db.Limit(size).Offset((page - 1) * size).Find(&devices) + if tx.Error != nil { + return nil, 0, tx.Error + } + + var total int64 + tx = db.Model(&Device{}).Count(&total) + if tx.Error != nil { + return nil, 0, tx.Error + } + + for _, device := range devices { + count, _ := ChannelDao.QueryChanelCount(device.DeviceID) + online, _ := ChannelDao.QueryOnlineChanelCount(device.DeviceID) + device.ChannelsOnline = online + device.ChannelsTotal = count + } + + return devices, int(total), nil +} + +func (d *daoDevice) UpdateOfflineDevices(deviceIds []string) error { + return DBTransaction(func(tx *gorm.DB) error { + return tx.Model(&Device{}).Where("device_id in ?", deviceIds).Update("status", OFF).Error + }) +} diff --git a/dao_platform.go b/dao_platform.go new file mode 100644 index 0000000..adef939 --- /dev/null +++ b/dao_platform.go @@ -0,0 +1,150 @@ +package main + +type DaoPlatform interface { + LoadPlatforms() ([]*SIPUAParams, error) + + QueryPlatform(addr string) (*SIPUAParams, error) + + SavePlatform(platform *SIPUAParams) error + + DeletePlatform(addr string) error + + UpdatePlatform(platform *SIPUAParams) error + + UpdatePlatformStatus(addr string, status OnlineStatus) error + + BindChannels(addr string, channels [][2]string) ([][2]string, error) + + UnbindChannels(addr string, channels [][2]string) ([][2]string, error) + + // QueryPlatformChannel 查询级联设备的某个通道, 返回通道所属设备ID、通道. + QueryPlatformChannel(addr string, channelId string) (string, *Channel, error) + + QueryPlatformChannels(addr string) ([]*Channel, error) +} + +type daoPlatform struct { +} + +func (d *daoPlatform) LoadPlatforms() ([]*SIPUAParams, error) { + var platforms []*SIPUAParams + tx := db.Find(&platforms) + if tx.Error != nil { + return nil, tx.Error + } + + return platforms, nil +} + +func (d *daoPlatform) QueryPlatform(addr string) (*SIPUAParams, error) { + var platform SIPUAParams + tx := db.Where("server_addr =?", addr).First(&platform) + if tx.Error != nil { + return nil, tx.Error + } + + return &platform, nil +} + +func (d *daoPlatform) SavePlatform(platform *SIPUAParams) error { + var old SIPUAParams + tx := db.Where("server_addr =?", platform.ServerAddr).First(&old) + if tx.Error == nil { + platform.ID = old.ID + } + return db.Save(platform).Error +} + +func (d *daoPlatform) DeletePlatform(addr string) error { + return db.Where("server_addr =?", addr).Unscoped().Delete(&SIPUAParams{}).Error +} + +func (d *daoPlatform) UpdatePlatform(platform *SIPUAParams) error { + //TODO implement me + panic("implement me") +} + +func (d *daoPlatform) UpdatePlatformStatus(addr string, status OnlineStatus) error { + return db.Model(&SIPUAParams{}).Where("server_addr =?", addr).Update("status", status).Error +} + +type DBPlatformChannel struct { + GBModel + DeviceID string `json:"device_id"` + Channel string `json:"channel_id"` + ServerAddr string `json:"server_addr"` +} + +func (d *DBPlatformChannel) TableName() string { + return "lkm_platform_channel" +} + +func (d *daoPlatform) BindChannels(addr string, channels [][2]string) ([][2]string, error) { + var res [][2]string + for _, channel := range channels { + + var old DBPlatformChannel + _ = db.Where("device_id =? and channel_id =? and server_addr =?", channel[0], channel[1], addr).First(&old) + if old.ID == 0 { + _ = db.Create(&DBPlatformChannel{ + DeviceID: channel[0], + Channel: channel[1], + }) + } + res = append(res, channel) + } + + return res, nil +} + +func (d *daoPlatform) UnbindChannels(addr string, channels [][2]string) ([][2]string, error) { + var res [][2]string + for _, channel := range channels { + tx := db.Unscoped().Delete(&DBPlatformChannel{}, "device_id =? and channel_id =? and server_addr =?", channel[0], channel[1], addr) + if tx.Error == nil { + res = append(res, channel) + } else { + Sugar.Errorf("解绑级联设备通道失败. device_id: %s, channel_id: %s err: %s", channel[0], channel[1], tx.Error) + } + } + + return res, nil +} + +func (d *daoPlatform) QueryPlatformChannel(addr string, channelId string) (string, *Channel, error) { + var platformChannel DBPlatformChannel + tx := db.Model(&DBPlatformChannel{}).Where("channel_id =? and server_addr =?", channelId, addr).First(&platformChannel) + if tx.Error != nil { + return "", nil, tx.Error + } + + var channel Channel + tx = db.Where("device_id =? and channel_id =?", platformChannel.DeviceID, platformChannel.Channel).First(&channel) + if tx.Error != nil { + return "", nil, tx.Error + } + + return platformChannel.DeviceID, &channel, nil +} + +func (d *daoPlatform) QueryPlatformChannels(addr string) ([]*Channel, error) { + var platformChannels []*DBPlatformChannel + tx := db.Where("server_addr =?", addr).Find(&platformChannels) + if tx.Error != nil { + return nil, tx.Error + } + + var channels []*Channel + for _, platformChannel := range platformChannels { + var channel Channel + tx = db.Where("device_id =? and channel_id =?", platformChannel.DeviceID, platformChannel.Channel).First(&channel) + if tx.Error == nil { + channels = append(channels, &channel) + } else { + Sugar.Errorf("查询级联设备通道失败. device_id: %s, channel_id: %s err: %s", platformChannel.DeviceID, platformChannel.Channel, tx.Error) + } + + } + + return channels, nil +} diff --git a/dao_sink.go b/dao_sink.go new file mode 100644 index 0000000..9f66d88 --- /dev/null +++ b/dao_sink.go @@ -0,0 +1,157 @@ +package main + +import ( + "fmt" + "gorm.io/gorm" +) + +type DaoSink interface { + LoadForwardSinks() (map[string]*Sink, error) + + // QueryForwardSink 查询转发流Sink + QueryForwardSink(stream StreamID, sink string) (*Sink, error) + + QueryForwardSinks(stream StreamID) (map[string]*Sink, error) + + // SaveForwardSink 保存转发流Sink + SaveForwardSink(stream StreamID, sink *Sink) error + + DeleteForwardSink(stream StreamID, sink string) (*Sink, error) + + DeleteForwardSinksByStreamID(stream StreamID) ([]*Sink, error) + + DeleteForwardSinks() ([]*Sink, error) + + DeleteForwardSinksByIds(ids []uint) error + + QueryForwardSinkByCallID(callID string) (*Sink, error) + + DeleteForwardSinkByCallID(callID string) (*Sink, error) + + DeleteForwardSinkBySinkStreamID(sinkStreamID StreamID) (*Sink, error) +} + +type daoSink struct { +} + +func (d *daoSink) LoadForwardSinks() (map[string]*Sink, error) { + var sinks []*Sink + tx := db.Find(&sinks) + if tx.Error != nil { + return nil, tx.Error + } + + sinkMap := make(map[string]*Sink) + for _, sink := range sinks { + sinkMap[sink.SinkID] = sink + } + return sinkMap, nil +} + +func (d *daoSink) QueryForwardSink(stream StreamID, sinkId string) (*Sink, error) { + var sink Sink + db.Where("stream_id =? and sink_id =?", stream, sinkId).Take(&sink) + return &sink, db.Error +} + +func (d *daoSink) QueryForwardSinks(stream StreamID) (map[string]*Sink, error) { + var sinks []*Sink + tx := db.Where("stream_id =?", stream).Find(&sinks) + if tx.Error != nil { + return nil, tx.Error + } + + sinkMap := make(map[string]*Sink) + for _, sink := range sinks { + sinkMap[sink.SinkID] = sink + } + return sinkMap, nil +} + +func (d *daoSink) SaveForwardSink(stream StreamID, sink *Sink) error { + var old Sink + tx := db.Select("id").Where("sink_id =?", sink.SinkID).Take(&old) + if tx.Error == nil { + return fmt.Errorf("sink already exists") + } + + return DBTransaction(func(tx *gorm.DB) error { + return tx.Save(sink).Error + }) +} + +func (d *daoSink) DeleteForwardSink(stream StreamID, sinkId string) (*Sink, error) { + var sink Sink + tx := db.Where("sink_id =?", sinkId).Take(&sink) + if tx.Error != nil { + return nil, tx.Error + } + + return &sink, DBTransaction(func(tx *gorm.DB) error { + return tx.Where("sink_id =?", sinkId).Unscoped().Delete(&Sink{}).Error + }) +} + +func (d *daoSink) DeleteForwardSinksByStreamID(stream StreamID) ([]*Sink, error) { + var sinks []*Sink + tx := db.Where("stream_id =?", stream).Find(&sinks) + if tx.Error != nil { + return nil, tx.Error + } + + return sinks, DBTransaction(func(tx *gorm.DB) error { + return tx.Where("stream_id =?", stream).Unscoped().Delete(&Sink{}).Error + }) +} + +func (d *daoSink) QueryForwardSinkByCallID(callID string) (*Sink, error) { + var sinks Sink + tx := db.Where("call_id =?", callID).Find(&sinks) + if tx.Error != nil { + return nil, tx.Error + } + + return &sinks, nil +} + +func (d *daoSink) DeleteForwardSinkByCallID(callID string) (*Sink, error) { + var sink Sink + tx := db.Where("call_id =?", callID).First(&sink) + if tx.Error != nil { + return nil, tx.Error + } + + return &sink, DBTransaction(func(tx *gorm.DB) error { + return tx.Where("call_id =?", callID).Unscoped().Delete(&Sink{}).Error + }) +} + +func (d *daoSink) DeleteForwardSinkBySinkStreamID(sinkStreamId StreamID) (*Sink, error) { + var sink Sink + tx := db.Where("sink_stream_id =?", sinkStreamId).First(&sink) + if tx.Error != nil { + return nil, tx.Error + } + + return &sink, DBTransaction(func(tx *gorm.DB) error { + return tx.Where("sink_stream_id =?", sinkStreamId).Unscoped().Delete(&Sink{}).Error + }) +} + +func (d *daoSink) DeleteForwardSinks() ([]*Sink, error) { + var sinks []*Sink + tx := db.Find(&sinks) + if tx.Error != nil { + return nil, tx.Error + } + + return sinks, DBTransaction(func(tx *gorm.DB) error { + return tx.Unscoped().Delete(&Sink{}).Error + }) +} + +func (d *daoSink) DeleteForwardSinksByIds(ids []uint) error { + return DBTransaction(func(tx *gorm.DB) error { + return tx.Where("id in?", ids).Unscoped().Delete(&Sink{}).Error + }) +} diff --git a/dao_stream.go b/dao_stream.go new file mode 100644 index 0000000..5ec8e51 --- /dev/null +++ b/dao_stream.go @@ -0,0 +1,135 @@ +package main + +import ( + "github.com/lkmio/avformat/utils" + "gorm.io/gorm" +) + +type DaoStream interface { + LoadStreams() (map[string]*Stream, error) + + SaveStream(stream *Stream) (*Stream, bool) + + UpdateStream(stream *Stream) error + + DeleteStream(streamId StreamID) (*Stream, error) + + DeleteStreams() ([]*Stream, error) + + DeleteStreamsByIds(ids []uint) error + + QueryStream(streamId StreamID) (*Stream, error) + + QueryStreamByCallID(callID string) (*Stream, error) + + DeleteStreamByCallID(callID string) (*Stream, error) +} + +type daoStream struct { +} + +func (d *daoStream) LoadStreams() (map[string]*Stream, error) { + var streams []*Stream + tx := db.Find(&streams) + if tx.Error != nil { + return nil, tx.Error + } + + streamMap := make(map[string]*Stream) + for _, stream := range streams { + streamMap[string(stream.StreamID)] = stream + } + + return streamMap, nil +} + +func (d *daoStream) SaveStream(stream *Stream) (*Stream, bool) { + var old Stream + tx := db.Select("id").Where("stream_id =?", stream.StreamID).Take(&old) + if old.ID != 0 { + return &old, false + } + // stream唯一必须不存在 + utils.Assert(tx.Error != nil) + return nil, DBTransaction(func(tx *gorm.DB) error { + return tx.Save(stream).Error + }) == nil +} + +func (d *daoStream) UpdateStream(stream *Stream) error { + var old Stream + tx := db.Where("stream_id =?", stream.StreamID).Take(&old) + if tx.Error != nil { + return tx.Error + } + + stream.ID = old.ID + return DBTransaction(func(tx *gorm.DB) error { + return tx.Save(stream).Error + }) +} + +func (d *daoStream) DeleteStream(streamId StreamID) (*Stream, error) { + var stream Stream + tx := db.Where("stream_id =?", streamId).Take(&stream) + if tx.Error != nil { + return nil, tx.Error + } + + return &stream, DBTransaction(func(tx *gorm.DB) error { + return tx.Where("stream_id =?", streamId).Unscoped().Delete(&Stream{}).Error + }) +} + +func (d *daoStream) DeleteStreamsByIds(ids []uint) error { + return DBTransaction(func(tx *gorm.DB) error { + return tx.Where("id in ?", ids).Unscoped().Delete(&Stream{}).Error + }) +} + +func (d *daoStream) DeleteStreams() ([]*Stream, error) { + var streams []*Stream + tx := db.Find(&streams) + if tx.Error != nil { + return nil, tx.Error + } + + DBTransaction(func(tx *gorm.DB) error { + for _, stream := range streams { + _ = tx.Where("stream_id =?", stream.StreamID).Unscoped().Delete(&Stream{}) + } + return nil + }) + + return streams, nil +} + +func (d *daoStream) QueryStream(streamId StreamID) (*Stream, error) { + var stream Stream + tx := db.Where("stream_id =?", streamId).Take(&stream) + if tx.Error != nil { + return nil, tx.Error + } + return &stream, nil +} + +func (d *daoStream) QueryStreamByCallID(callID string) (*Stream, error) { + var stream Stream + tx := db.Where("call_id =?", callID).Take(&stream) + if tx.Error != nil { + return nil, tx.Error + } + return &stream, nil +} + +func (d *daoStream) DeleteStreamByCallID(callID string) (*Stream, error) { + var stream Stream + tx := db.Where("call_id =?", callID).Take(&stream) + if tx.Error != nil { + return nil, tx.Error + } + + return &stream, DBTransaction(func(tx *gorm.DB) error { + return tx.Where("call_id =?", callID).Unscoped().Delete(&Stream{}).Error + }) +} diff --git a/db.go b/db.go deleted file mode 100644 index 36cbc4e..0000000 --- a/db.go +++ /dev/null @@ -1,68 +0,0 @@ -package main - -type GB28181DB interface { - LoadOnlineDevices() (map[string]*Device, error) - - LoadDevices() (map[string]*Device, error) - - SaveDevice(device *Device) error - - SaveChannel(deviceId string, channel *Channel) error - - UpdateDeviceStatus(deviceId string, status OnlineStatus) error - - UpdateChannelStatus(channelId, status string) error - - RefreshHeartbeat(deviceId string) error - - QueryDevice(id string) (*Device, error) - - QueryDevices(page int, size int) ([]*Device, int, error) - - QueryChannel(deviceId string, channelId string) (*Channel, error) - - QueryChannels(deviceId string, page, size int) ([]*Channel, int, error) - - LoadPlatforms() ([]*SIPUAParams, error) - - QueryPlatform(addr string) (*SIPUAParams, error) - - SavePlatform(platform *SIPUAParams) error - - DeletePlatform(addr string) error - - UpdatePlatform(platform *SIPUAParams) error - - UpdatePlatformStatus(addr string, status OnlineStatus) error - - BindChannels(addr string, channels [][2]string) ([][2]string, error) - - UnbindChannels(addr string, channels [][2]string) ([][2]string, error) - - // QueryPlatformChannel 查询级联设备的某个通道, 返回通道所属设备ID、通道. - QueryPlatformChannel(addr string, channelId string) (string, *Channel, error) - - QueryPlatformChannels(addr string) ([]*Channel, error) - - LoadStreams() (map[string]*Stream, error) - - SaveStream(stream *Stream) error - - DeleteStream(time int64) error - - //QueryStream(pate int, size int) - - // QueryForwardSink 查询转发流Sink - QueryForwardSink(stream StreamID, sink string) (*Sink, error) - - QueryForwardSinks(stream StreamID) (map[string]*Sink, error) - - // SaveForwardSink 保存转发流Sink - SaveForwardSink(stream StreamID, sink *Sink) error - - DeleteForwardSink(stream StreamID, sink string) error - - DeleteForwardSinks(stream StreamID) error - - Del(key string) error -} diff --git a/db_redis.go b/db_redis.go deleted file mode 100644 index 8844565..0000000 --- a/db_redis.go +++ /dev/null @@ -1,722 +0,0 @@ -package main - -import ( - "encoding/hex" - "encoding/json" - "fmt" - "strconv" - "strings" - "sync" - "time" -) - -const ( - RedisKeyDevices = "devices" // 使用map保存所有设备信息(不包含通道信息) - RedisKeyDevicesSort = "devices_sort" // 使用zset有序保存所有设备ID(按照入库时间) - RedisKeyChannels = "channels" // 使用map保存所有通道信息 - RedisKeyDeviceChannels = "%s_channels" // 使用zset保存设备下的所有通道ID - RedisKeyPlatforms = "platforms" // 使用zset有序保存所有级联设备 - RedisUniqueChannelID = "%s_%s" // 通道号的唯一ID, 设备_通道号 - - // RedisKeyStreams 保存推拉流信息, 主要目的是程序崩溃重启后,恢复国标流的invite会话. 如果需要统计所有详细的推拉流信息,需要自行实现. - RedisKeyStreams = "streams" //// 保存所有推流端信息 - RedisKeySinks = "sinks" //// 保存所有拉流端信息 - RedisKeyStreamSinks = "%s_sinks" //// 某路流下所有的拉流端 - - RedisKeyDialogs = "streams" - RedisKeyForwardSinks = "forward_%s" -) - -type RedisDB struct { - utils *RedisUtils - platformsLock sync.Mutex -} - -type ChannelKey string - -func (c ChannelKey) Device() string { - return strings.Split(string(c), "_")[0] -} - -func (c ChannelKey) Channel() string { - return strings.Split(string(c), "_")[1] -} - -func (c ChannelKey) String() string { - return string(c) -} - -// DeviceChannelsKey 返回设备通道列表的主键 -func DeviceChannelsKey(id string) string { - return fmt.Sprintf(RedisKeyDeviceChannels, id) -} - -func ForwardSinksKey(id string) string { - return fmt.Sprintf(RedisKeyForwardSinks, id) -} - -// UniqueChannelKey 使用设备号+通道号作为通道的主键,兼容通道号可能重复的情况 -func UniqueChannelKey(device, channel string) ChannelKey { - return ChannelKey(fmt.Sprintf(RedisUniqueChannelID, device, channel)) -} - -func (r *RedisDB) LoadOnlineDevices() (map[string]*Device, error) { - executor, err := r.utils.CreateExecutor() - if err != nil { - return nil, err - } - - keys, err := executor.Keys() - if err != nil { - return nil, err - } - - devices := make(map[string]*Device, len(keys)) - for _, key := range keys { - device, err := r.findDevice(key, executor) - if err != nil || device == nil { - continue - } - - devices[key] = device - } - - return devices, nil -} - -func (r *RedisDB) findDevice(id string, executor Executor) (*Device, error) { - value, err := executor.Key(RedisKeyDevices).HGet(id) - if err != nil { - return nil, err - } else if value == nil { - return nil, nil - } - - device := &Device{} - err = json.Unmarshal(value, device) - if err != nil { - return nil, err - } - - return device, nil -} - -func (r *RedisDB) findChannel(id ChannelKey, executor Executor) (*Channel, error) { - value, err := executor.HGet(id.String()) - if err != nil { - return nil, err - } else if value == nil { - return nil, nil - } - - channel := &Channel{} - err = json.Unmarshal(value, channel) - if err != nil { - return nil, err - } - - return channel, nil -} - -func (r *RedisDB) LoadDevices() (map[string]*Device, error) { - executor, err := r.utils.CreateExecutor() - if err != nil { - return nil, err - } - - entries, err := executor.Key(RedisKeyDevices).HGetAll() - - devices := make(map[string]*Device, len(entries)) - for k, v := range entries { - device := &Device{} - if err = json.Unmarshal(v, device); err != nil { - continue - } - - devices[k] = device - } - - return devices, err -} - -func (r *RedisDB) SaveDevice(device *Device) error { - data, err := json.Marshal(device) - if err != nil { - return err - } - - executor, err := r.utils.CreateExecutor() - if err != nil { - return err - // 保存设备信息 - } else if err = executor.Key(RedisKeyDevices).HSet(device.ID, string(data)); err != nil { - return err - } - - return r.UpdateDeviceStatus(device.ID, device.Status) -} - -func (r *RedisDB) SaveChannel(deviceId string, channel *Channel) error { - setup := SetupTypePassive - oldChannel, err := r.QueryChannel(deviceId, channel.DeviceID) - if err != nil { - return err - } else if oldChannel != nil { - setup = oldChannel.SetupType - } - - channel.SetupType = setup - data, err := json.Marshal(channel) - if err != nil { - return err - } - - channelKey := UniqueChannelKey(deviceId, channel.DeviceID) - executor, err := r.utils.CreateExecutor() - if err != nil { - return err - // 保存通道信息 - } else if err = executor.Key(RedisKeyChannels).HSet(channelKey.String(), string(data)); err != nil { - return err - // 通道关联到Device - } else if err = executor.Key(DeviceChannelsKey(deviceId)).ZAddWithNotExists(float64(time.Now().UnixMilli()), channelKey); err != nil { - return err - } - - return nil -} - -func (r *RedisDB) UpdateDeviceStatus(deviceId string, status OnlineStatus) error { - executor, err := r.utils.CreateExecutor() - if err != nil { - return err - } - - // 如果在线, 设置有效期key, 添加到设备排序表 - if ON == status { - // 设置有效期key - if err = executor.Key(deviceId).Set(nil); err != nil { - return err - } else if err = executor.SetExpires(Config.AliveExpires); err != nil { - return err - // 排序Device,根据入库时间 - } else if err = executor.Key(RedisKeyDevicesSort).ZAddWithNotExists(float64(time.Now().UnixMilli()), deviceId); err != nil { - return err - } - } else { - // 删除有效key - return executor.Key(deviceId).Del() - } - - return nil -} - -func (r *RedisDB) UpdateChannelStatus(channelId, status string) error { - //TODO implement me - panic("implement me") -} - -func (r *RedisDB) RefreshHeartbeat(deviceId string) error { - executor, err := r.utils.CreateExecutor() - if err != nil { - return err - } else if err = executor.Key(deviceId).Set(strconv.FormatInt(time.Now().UnixMilli(), 10)); err != nil { - return err - } - - return executor.SetExpires(Config.AliveExpires) -} - -func (r *RedisDB) QueryDevice(id string) (*Device, error) { - executor, err := r.utils.CreateExecutor() - if err != nil { - return nil, err - } - - return r.findDevice(id, executor) -} - -func (r *RedisDB) QueryDevices(page int, size int) ([]*Device, int, error) { - executor, err := r.utils.CreateExecutor() - if err != nil { - return nil, 0, err - } - - keys, err := executor.Key(RedisKeyDevicesSort).ZRangeWithAsc(page, size) - if err != nil { - return nil, 0, err - } - - var devices []*Device - for _, key := range keys { - device, err := r.findDevice(key, executor) - if err != nil { - continue - } - - devices = append(devices, device) - } - - // 查询总记录数 - total, err := executor.Key(RedisKeyDevicesSort).CountZSet() - if err != nil { - return nil, 0, err - } - - return devices, total, nil -} - -func (r *RedisDB) QueryChannel(deviceId string, channelId string) (*Channel, error) { - executor, err := r.utils.CreateExecutor() - if err != nil { - return nil, err - } - - executor.Key(RedisKeyChannels) - return r.findChannel(UniqueChannelKey(deviceId, channelId), executor) -} - -func (r *RedisDB) QueryChannels(deviceId string, page, size int) ([]*Channel, int, error) { - executor, err := r.utils.CreateExecutor() - if err != nil { - return nil, 0, err - } - - id := fmt.Sprintf(RedisKeyDeviceChannels, deviceId) - keys, err := executor.Key(id).ZRangeWithAsc(page, size) - if err != nil { - return nil, 0, err - } - - executor.Key(RedisKeyChannels) - var channels []*Channel - for _, key := range keys { - channel, err := r.findChannel(ChannelKey(key), executor) - if err != nil { - continue - } else if channel == nil { - continue - } - - channels = append(channels, channel) - } - - // 查询总记录数 - total, err := executor.Key(id).CountZSet() - if err != nil { - return nil, 0, err - } - - return channels, total, nil -} - -func (r *RedisDB) LoadPlatforms() ([]*SIPUAParams, error) { - executor, err := r.utils.CreateExecutor() - if err != nil { - return nil, err - } - - var platforms []*SIPUAParams - pairs, err := executor.Key(RedisKeyPlatforms).ZRange() - if err == nil { - for _, pair := range pairs { - platform := &SIPUAParams{} - if err := json.Unmarshal([]byte(pair[0]), platform); err != nil { - continue - } - - platform.CreateTime = pair[1] - platforms = append(platforms, platform) - } - } - - return platforms, err -} - -func (r *RedisDB) findPlatformWithServerAddr(addr string) (*SIPUAParams, error) { - platforms, err := r.LoadPlatforms() - if err != nil { - return nil, err - } - - for _, platform := range platforms { - if platform.ServerAddr == addr { - return platform, nil - } - } - - return nil, err -} - -func (r *RedisDB) QueryPlatform(addr string) (*SIPUAParams, error) { - return r.findPlatformWithServerAddr(addr) -} - -func (r *RedisDB) SavePlatform(platform *SIPUAParams) error { - r.platformsLock.Lock() - defer r.platformsLock.Unlock() - - executor, err := r.utils.CreateExecutor() - if err != nil { - return err - } - - platforms, _ := r.LoadPlatforms() - for _, old := range platforms { - if old.ServerAddr == platform.ServerAddr { - return fmt.Errorf("地址冲突") - } - } - - data, err := json.Marshal(platform) - if err != nil { - return err - } - - return executor.Key(RedisKeyPlatforms).ZAddWithNotExists(platform.CreateTime, data) -} - -func (r *RedisDB) DeletePlatform(addr string) error { - r.platformsLock.Lock() - defer r.platformsLock.Unlock() - - executor, err := r.utils.CreateExecutor() - if err != nil { - return err - } - - platform, err := r.findPlatformWithServerAddr(addr) - if err != nil { - return err - } else if platform == nil { - return fmt.Errorf("platform with addr %s not find", addr) - } - - // 删除所有通道, 没有事务 - if err = executor.Key(DeviceChannelsKey(addr)).Del(); err != nil { - return err - } - - return executor.Key(RedisKeyPlatforms).ZDelWithScore(platform.CreateTime) -} - -func (r *RedisDB) UpdatePlatform(platform *SIPUAParams) error { - r.platformsLock.Lock() - defer r.platformsLock.Unlock() - - executor, err := r.utils.CreateExecutor() - if err != nil { - return err - } - - oldPlatform, _ := r.findPlatformWithServerAddr(platform.SeverID) - if oldPlatform == nil { - return fmt.Errorf("platform with ID %s not find", platform.SeverID) - } - - data, err := json.Marshal(platform) - if err != nil { - return err - } - - return executor.ZAdd(oldPlatform.CreateTime, data) -} - -func (r *RedisDB) UpdatePlatformStatus(serverId string, status OnlineStatus) error { - r.platformsLock.Lock() - defer r.platformsLock.Unlock() - - executor, err := r.utils.CreateExecutor() - if err != nil { - return err - } - - oldPlatform, _ := r.findPlatformWithServerAddr(serverId) - if oldPlatform == nil { - return fmt.Errorf("platform with ID %s not find", serverId) - } - - oldPlatform.Status = status - data, err := json.Marshal(oldPlatform) - if err != nil { - return err - } - - return executor.ZAdd(oldPlatform.CreateTime, data) -} - -func (r *RedisDB) BindChannels(addr string, channels [][2]string) ([][2]string, error) { - r.platformsLock.Lock() - defer r.platformsLock.Unlock() - - platform, err := r.QueryPlatform(addr) - if err != nil { - return nil, err - } else if platform == nil { - return nil, fmt.Errorf("platform with addr %s not find", addr) - } - - executor, err := r.utils.CreateExecutor() - if err != nil { - return nil, err - } - - // 返回成功的设备通道号 - var result [][2]string - for _, v := range channels { - deviceId := v[0] - channelId := v[1] - - channelKey := UniqueChannelKey(deviceId, channelId) - // 检查通道是否存在, 以及通道是否冲突 - channel, err := r.findChannel(channelKey, executor.Key(RedisKeyChannels)) - if err != nil { - Sugar.Errorf("添加通道失败, err: %s device: %s channel: %s", err.Error(), deviceId, channelId) - } else if channel == nil { - Sugar.Errorf("添加通道失败, 通道不存在. device: %s channel: %s", deviceId, channelId) - } else if device, err := executor.Key(DeviceChannelsKey(addr)).HGet(channelId); err != nil || device != nil { - Sugar.Errorf("添加通道失败, 通道冲突. device: %s channel: %s", deviceId, channelId) - } else if err = executor.Key(DeviceChannelsKey(addr)).HSet(channelId, deviceId); err != nil { - Sugar.Errorf("添加通道失败, err: %s device: %s channel: %s", err.Error(), deviceId, channelId) - } else { - result = append(result, v) - } - } - - return result, nil -} - -func (r *RedisDB) UnbindChannels(id string, channels [][2]string) ([][2]string, error) { - r.platformsLock.Lock() - defer r.platformsLock.Unlock() - - platform, err := r.QueryPlatform(id) - if err != nil { - return nil, err - } else if platform == nil { - return nil, fmt.Errorf("platform with ID %s not find", platform.SeverID) - } - - executor, err := r.utils.CreateExecutor() - if err != nil { - return nil, err - } - - // 返回成功的设备通道号 - var result [][2]string - for _, v := range channels { - if err := executor.Key(DeviceChannelsKey(id)).HDel(v[1]); err != nil { - continue - } - - result = append(result, v) - } - - return result, nil -} - -func (r *RedisDB) QueryPlatformChannel(platformId string, channelId string) (string, *Channel, error) { - executor, err := r.utils.CreateExecutor() - if err != nil { - return "", nil, err - } - - deviceId, err := executor.Key(DeviceChannelsKey(platformId)).HGet(channelId) - if err != nil { - return "", nil, err - } - - channel, err := r.findChannel(UniqueChannelKey(string(deviceId), channelId), executor.Key(RedisKeyChannels)) - if err != nil { - return "", nil, err - } - - return string(deviceId), channel, nil -} - -func (r *RedisDB) QueryPlatformChannels(serverAddr string) ([]*Channel, error) { - executor, err := r.utils.CreateExecutor() - if err != nil { - return nil, err - } - - keys, err := executor.Key(DeviceChannelsKey(serverAddr)).HGetAll() - if err != nil { - return nil, err - } - - var channels []*Channel - for channelId, deviceId := range keys { - channel, err := r.findChannel(UniqueChannelKey(string(deviceId), channelId), executor.Key(RedisKeyChannels)) - if err != nil { - continue - } - - channels = append(channels, channel) - } - - return channels, nil -} - -func (r *RedisDB) LoadStreams() (map[string]*Stream, error) { - executor, err := r.utils.CreateExecutor() - if err != nil { - return nil, err - } - - all, err := executor.Key(RedisKeyStreams).ZRange() - if err != nil { - return nil, err - } - - streams := make(map[string]*Stream, len(all)) - for _, v := range all { - stream := &Stream{} - if err := json.Unmarshal([]byte(v[0]), stream); err != nil { - Sugar.Errorf("解析stream失败, err: %s value: %s", err.Error(), hex.EncodeToString([]byte(v[0]))) - continue - } - - streams[string(stream.ID)] = stream - } - - return streams, nil -} - -func (r *RedisDB) SaveStream(stream *Stream) error { - data, err := json.Marshal(stream) - if err != nil { - return err - } - - executor, err := r.utils.CreateExecutor() - if err != nil { - return err - } - - // return executor.Key(RedisKeyStreams).ZAddWithNotExists(stream.CreateTime, data) - return executor.Key(RedisKeyStreams).ZAdd(stream.CreateTime, data) -} - -func (r *RedisDB) DeleteStream(time int64) error { - executor, err := r.utils.CreateExecutor() - if err != nil { - return err - } - - return executor.Key(RedisKeyStreams).ZDelWithScore(time) -} - -func (r *RedisDB) QueryForwardSink(stream StreamID, sinkId string) (*Sink, error) { - executor, err := r.utils.CreateExecutor() - if err != nil { - return nil, err - } - - data, err := executor.Key(ForwardSinksKey(string(stream))).HGet(sinkId) - if err != nil { - return nil, err - } - - sink := &Sink{} - if err = json.Unmarshal(data, sink); err != nil { - return nil, err - } - - return sink, nil -} - -func (r *RedisDB) QueryForwardSinks(stream StreamID) (map[string]*Sink, error) { - executor, err := r.utils.CreateExecutor() - if err != nil { - return nil, err - } - - entries, err := executor.Key(ForwardSinksKey(string(stream))).HGetAll() - if err != nil { - return nil, err - } - - var sinks map[string]*Sink - if len(entries) > 0 { - sinks = make(map[string]*Sink, len(entries)) - } - - for _, entry := range entries { - sink := &Sink{} - if err = json.Unmarshal(entry, sink); err != nil { - return nil, err - } - - sinks[sink.ID] = sink - } - - return sinks, nil -} - -func (r *RedisDB) SaveForwardSink(stream StreamID, sink *Sink) error { - executor, err := r.utils.CreateExecutor() - if err != nil { - return err - } - - data, err := json.Marshal(sink) - if err != nil { - return err - } - - return executor.Key(ForwardSinksKey(string(stream))).HSet(sink.ID, data) -} - -func (r *RedisDB) DeleteForwardSink(stream StreamID, sinkId string) error { - executor, err := r.utils.CreateExecutor() - if err != nil { - return err - } - - return executor.Key(ForwardSinksKey(string(stream))).HDel(sinkId) -} - -func (r *RedisDB) Del(key string) error { - executor, err := r.utils.CreateExecutor() - if err != nil { - return err - } - - return executor.Key(key).Del() -} - -func (r *RedisDB) DeleteForwardSinks(stream StreamID) error { - return r.Del(ForwardSinksKey(string(stream))) -} - -// OnExpires Redis设备ID到期回调 -func (r *RedisDB) OnExpires(db int, id string) { - Sugar.Infof("设备心跳过期 device: %s", id) - - device := DeviceManager.Find(id) - if device == nil { - Sugar.Errorf("设备不存在 device: %s", id) - return - } - - device.Close() -} - -func NewRedisDB(addr, password string) *RedisDB { - db := &RedisDB{ - utils: NewRedisUtils(addr, password), - } - - for { - err := StartExpiredKeysSubscription(db.utils, 0, db.OnExpires) - if err == nil { - break - } - - Sugar.Errorf("监听redis过期key失败, err: %s", err.Error()) - time.Sleep(3 * time.Second) - } - - return db -} diff --git a/db_sqlite.go b/db_sqlite.go new file mode 100644 index 0000000..61a676c --- /dev/null +++ b/db_sqlite.go @@ -0,0 +1,140 @@ +package main + +import ( + "context" + "github.com/glebarez/sqlite" + "gorm.io/gorm" + "gorm.io/gorm/schema" + "time" +) + +const ( + DBNAME = "lkm_gb.db" + //DBNAME = ":memory:" +) + +var ( + db *gorm.DB + TaskQueue = make(chan *SaveTask, 1024) + DeviceDao = &daoDevice{} + ChannelDao = &daoChannel{} + PlatformDao = &daoPlatform{} + StreamDao = &daoStream{} + SinkDao = &daoSink{} +) + +func init() { + db_, err := gorm.Open(sqlite.Open(DBNAME), &gorm.Config{ + NamingStrategy: schema.NamingStrategy{ + SingularTable: true, + TablePrefix: "lkm_", + }, + }) + + if err != nil { + panic(err) + } + + db = db_ + + tx := db.Exec("PRAGMA journal_mode=WAL;") + if tx.Error != nil { + panic(tx.Error) + } + + // 每次启动释放空间 + tx = db.Exec("VACUUM;") + if tx.Error != nil { + panic(tx.Error) + } + + s, err := db.DB() + s.SetMaxOpenConns(40) + s.SetMaxIdleConns(10) + + // devices + // channels + // platforms + // streams + // sinks + if err = db.AutoMigrate(&Device{}); err != nil { + panic(err) + } else if err = db.AutoMigrate(&Channel{}); err != nil { + panic(err) + } else if err = db.AutoMigrate(&SIPUAParams{}); err != nil { + panic(err) + } else if err = db.AutoMigrate(&Stream{}); err != nil { + panic(err) + } else if err = db.AutoMigrate(&Sink{}); err != nil { + panic(err) + } else if err = db.AutoMigrate(&DBPlatformChannel{}); err != nil { + panic(err) + } + + StartSaveTask() +} + +type SaveTask struct { + cb func(tx *gorm.DB) error + err error + cancel context.CancelFunc +} + +func StartSaveTask() { + go func() { + for { + var tasks []*SaveTask + for len(TaskQueue) > 0 { + select { + case task := <-TaskQueue: + tasks = append(tasks, task) + } + } + + if len(tasks) == 0 { + time.Sleep(50 * time.Millisecond) + continue + } + + err := db.Transaction(func(tx *gorm.DB) error { + for _, task := range tasks { + task.err = task.cb(tx) + } + return nil + }) + + if err != nil { + Sugar.Errorf("DBTransaction error: %s", err) + } + + for _, task := range tasks { + task.cancel() + } + } + }() +} + +func DBTransaction(cb func(tx *gorm.DB) error) error { + ctx, cancel := context.WithCancel(context.Background()) + task := &SaveTask{ + cb: cb, + cancel: cancel, + } + + TaskQueue <- task + <-ctx.Done() + return task.err +} + +// OnExpires Redis设备ID到期回调 +func OnExpires(db int, id string) { + Sugar.Infof("设备心跳过期 device: %s", id) + + device, _ := DeviceDao.QueryDevice(id) + if device == nil { + Sugar.Errorf("设备不存在 device: %s", id) + return + } + + device.Close() +} diff --git a/device.go b/device.go index 9736487..5cca9bd 100644 --- a/device.go +++ b/device.go @@ -5,6 +5,7 @@ import ( "github.com/ghettovoice/gosip/sip" "net" "strconv" + "time" ) const ( @@ -90,22 +91,24 @@ type GBDevice interface { } type Device struct { - ID string `json:"id"` - Name string `json:"name"` - RemoteAddr string `json:"remote_addr"` - Transport string `json:"transport"` - Status OnlineStatus `json:"status"` //在线状态 ON-在线/OFF-离线 - Manufacturer string `json:"manufacturer"` - Model string `json:"model"` - Firmware string `json:"firmware"` - RegisterTime int64 `json:"register_time"` + GBModel + DeviceID string `json:"device_id" gorm:"uniqueIndex"` + Name string `json:"name"` + RemoteAddr string `json:"remote_addr"` + Transport string `json:"transport"` + Status OnlineStatus `json:"status"` //在线状态 ON-在线/OFF-离线 + Manufacturer string `json:"manufacturer"` + Model string `json:"model"` + Firmware string `json:"firmware"` + RegisterTime time.Time `json:"register_time"` + LastHeartbeat time.Time `json:"last_heartbeat"` ChannelsTotal int `json:"total_channels"` // 通道总数 ChannelsOnline int `json:"online_channels"` // 通道在线数量 } func (d *Device) GetID() string { - return d.ID + return d.DeviceID } func (d *Device) Online() bool { @@ -122,14 +125,14 @@ func (d *Device) BuildMessageRequest(to, body string) sip.Request { } func (d *Device) QueryDeviceInfo() { - body := fmt.Sprintf(DeviceInfoFormat, "1", d.ID) - request := d.BuildMessageRequest(d.ID, body) + body := fmt.Sprintf(DeviceInfoFormat, "1", d.DeviceID) + request := d.BuildMessageRequest(d.DeviceID, body) SipUA.SendRequest(request) } func (d *Device) QueryCatalog() { - body := fmt.Sprintf(CatalogFormat, "1", d.ID) - request := d.BuildMessageRequest(d.ID, body) + body := fmt.Sprintf(CatalogFormat, "1", d.DeviceID) + request := d.BuildMessageRequest(d.DeviceID, body) SipUA.SendRequest(request) } @@ -146,7 +149,7 @@ func (d *Device) OnBye(request sip.Request) { func (d *Device) SubscribePosition(channelId string) error { if channelId == "" { - channelId = d.ID + channelId = d.DeviceID } //暂时不考虑级联 @@ -189,8 +192,8 @@ func (d *Device) UpdateChannel(id string, event string) { } func (d *Device) BuildCatalogRequest() (sip.Request, error) { - body := fmt.Sprintf(CatalogFormat, "1", d.ID) - request := d.BuildMessageRequest(d.ID, body) + body := fmt.Sprintf(CatalogFormat, "1", d.DeviceID) + request := d.BuildMessageRequest(d.DeviceID, body) return request, nil } @@ -247,7 +250,7 @@ func (d *Device) BuildInviteRequest(sessionName, channelId, ip string, port uint return nil, err } - var subjectHeader = Subject(channelId + ":" + d.ID + "," + Config.SipID + ":" + ssrc) + var subjectHeader = Subject(channelId + ":" + d.DeviceID + "," + Config.SipID + ":" + ssrc) request.AppendHeader(subjectHeader) return request, err @@ -268,22 +271,20 @@ func (d *Device) BuildDownloadRequest(channelId, ip string, port uint16, startTi func (d *Device) Close() { // 更新在数据库中的状态 d.Status = OFF - if err := DB.SaveDevice(d); err != nil { - Sugar.Errorf("更新设备在线状态失败 err: %s device: %s ", err.Error(), d.ID) - } + _ = DeviceDao.UpdateDeviceStatus(d.DeviceID, OFF) // 释放所有推流 - all := StreamManager.All() - var streams []*Stream - for _, stream := range all { - if d.ID == stream.ID.DeviceID() { - streams = append(streams, stream) - } - } - - for _, stream := range streams { - stream.Close(true, true) - } + //all := StreamManager.All() + //var streams []*Stream + //for _, stream := range all { + // if d.DeviceID == stream.StreamID.DeviceID() { + // streams = append(streams, stream) + // } + //} + // + //for _, stream := range streams { + // stream.Close(true, true) + //} } // CreateDialogRequestFromAnswer 根据invite的应答创建Dialog请求 diff --git a/device_manager.go b/device_manager.go deleted file mode 100644 index 7eb3902..0000000 --- a/device_manager.go +++ /dev/null @@ -1,54 +0,0 @@ -package main - -import ( - "fmt" - "sync" -) - -// DeviceManager 位于内存中的所有设备和通道 -var DeviceManager *deviceManager - -func init() { - DeviceManager = &deviceManager{} -} - -type deviceManager struct { - m sync.Map -} - -func (s *deviceManager) Add(device GBDevice) error { - _, ok := s.m.LoadOrStore(device.GetID(), device) - if ok { - return fmt.Errorf("the device %s has been exist", device.GetID()) - } - - return nil -} - -func (s *deviceManager) Find(id string) GBDevice { - value, ok := s.m.Load(id) - if ok { - return value.(GBDevice) - } - - return nil -} - -func (s *deviceManager) Remove(id string) GBDevice { - value, loaded := s.m.LoadAndDelete(id) - if loaded { - return value.(GBDevice) - } - - return nil -} - -func (s *deviceManager) All() []GBDevice { - var devices []GBDevice - s.m.Range(func(key, value any) bool { - devices = append(devices, value.(GBDevice)) - return true - }) - - return devices -} diff --git a/dialogs.go b/dialogs.go index dfc25e8..44e299c 100644 --- a/dialogs.go +++ b/dialogs.go @@ -1,16 +1,43 @@ package main import ( + "context" "fmt" "github.com/ghettovoice/gosip/sip" "github.com/ghettovoice/gosip/sip/parser" "sync" + "time" ) +var ( + Dialogs = NewDialogManager[*StreamWaiting]() +) + +type StreamWaiting struct { + onPublishCb chan int // 等待推流hook的管道 + cancelFunc func() // 取消等待推流hook的ctx + data interface{} +} + +func (s *StreamWaiting) Receive(seconds int) int { + s.onPublishCb = make(chan int, 0) + timeout, cancelFunc := context.WithTimeout(context.Background(), time.Duration(seconds)*time.Second) + s.cancelFunc = cancelFunc + select { + case code := <-s.onPublishCb: + return code + case <-timeout.Done(): + s.cancelFunc = nil + return -1 + } +} +func (s *StreamWaiting) Put(code int) { + s.onPublishCb <- code +} + type DialogManager[T any] struct { lock sync.RWMutex dialogs map[string]T - callIds map[string]T } func (d *DialogManager[T]) Add(id string, dialog T) (T, bool) { @@ -27,29 +54,12 @@ func (d *DialogManager[T]) Add(id string, dialog T) (T, bool) { return old, true } -func (d *DialogManager[T]) AddWithCallId(id string, dialog T) bool { - d.lock.Lock() - defer d.lock.Unlock() - if _, ok := d.callIds[id]; ok { - return false - } - - d.callIds[id] = dialog - return true -} - func (d *DialogManager[T]) Find(id string) T { d.lock.RLock() defer d.lock.RUnlock() return d.dialogs[id] } -func (d *DialogManager[T]) FindWithCallId(id string) T { - d.lock.RLock() - defer d.lock.RUnlock() - return d.callIds[id] -} - func (d *DialogManager[T]) Remove(id string) T { d.lock.Lock() defer d.lock.Unlock() @@ -58,36 +68,6 @@ func (d *DialogManager[T]) Remove(id string) T { return dialog } -func (d *DialogManager[T]) RemoveWithCallId(id string) T { - d.lock.Lock() - defer d.lock.Unlock() - dialog := d.callIds[id] - delete(d.callIds, id) - return dialog -} - -func (d *DialogManager[T]) All() []T { - d.lock.RLock() - defer d.lock.RUnlock() - var result []T - for _, v := range d.dialogs { - result = append(result, v) - } - return result -} - -func (d *DialogManager[T]) PopAll() []T { - d.lock.Lock() - defer d.lock.Unlock() - var result []T - for _, v := range d.dialogs { - result = append(result, v) - } - - d.dialogs = make(map[string]T) - return result -} - func UnmarshalDialog(dialog string) (sip.Request, error) { packetParser := parser.NewPacketParser(logger) message, err := packetParser.ParseMessage([]byte(dialog)) @@ -103,6 +83,5 @@ func UnmarshalDialog(dialog string) (sip.Request, error) { func NewDialogManager[T any]() *DialogManager[T] { return &DialogManager[T]{ dialogs: make(map[string]T), - callIds: make(map[string]T), } } diff --git a/go.mod b/go.mod index d7b8039..62687f0 100644 --- a/go.mod +++ b/go.mod @@ -9,19 +9,24 @@ require ( github.com/x-cray/logrus-prefixed-formatter v0.5.2 // indirect go.uber.org/zap v1.27.0 golang.org/x/net v0.21.0 - golang.org/x/text v0.16.0 + golang.org/x/text v0.20.0 ) require ( github.com/BurntSushi/toml v1.4.0 // indirect github.com/discoviking/fsm v0.0.0-20150126104936-f4a273feecca // indirect + github.com/dustin/go-humanize v1.0.1 // indirect + github.com/glebarez/go-sqlite v1.21.2 // indirect github.com/gobwas/httphead v0.1.0 // indirect github.com/gobwas/pool v0.2.1 // indirect github.com/gobwas/ws v1.4.0 // indirect - github.com/lkmio/transport v0.0.0-20250417030743-a4180637cd01 // indirect + github.com/google/uuid v1.3.0 // indirect + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/jinzhu/now v1.1.5 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d // indirect + github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/satori/go.uuid v1.2.1-0.20181028125025-b2ce2384e17b // indirect github.com/tevino/abool v1.2.0 // indirect go.uber.org/multierr v1.10.0 // indirect @@ -29,13 +34,19 @@ require ( golang.org/x/sys v0.21.0 // indirect golang.org/x/term v0.21.0 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect + modernc.org/libc v1.22.5 // indirect + modernc.org/mathutil v1.5.0 // indirect + modernc.org/memory v1.5.0 // indirect + modernc.org/sqlite v1.23.1 // indirect ) require ( + github.com/glebarez/sqlite v1.11.0 github.com/gomodule/redigo v1.9.2 github.com/gorilla/mux v1.8.1 github.com/gorilla/websocket v1.5.3 github.com/lkmio/avformat v0.0.0 + gorm.io/gorm v1.26.1 ) replace github.com/lkmio/avformat => ../avformat diff --git a/live.go b/live.go index 2d90958..100ca1e 100644 --- a/live.go +++ b/live.go @@ -40,28 +40,34 @@ func (i *InviteType) SessionName2Type(name string) { func (d *Device) StartStream(inviteType InviteType, streamId StreamID, channelId, startTime, stopTime, setup string, speed int, sync bool) (*Stream, error) { stream := &Stream{ - ID: streamId, - CreateTime: time.Now().UnixMilli(), + StreamID: streamId, + Protocol: "28181", } // 先添加占位置, 防止重复请求 - if oldStream, b := StreamManager.Add(stream); !b { + oldStream, b := StreamDao.SaveStream(stream) + if !b { + if oldStream == nil { + return nil, fmt.Errorf("stream already exists") + } return oldStream, nil } dialog, urls, err := d.Invite(inviteType, streamId, channelId, startTime, stopTime, setup, speed) if err != nil { - StreamManager.Remove(streamId) + _, _ = StreamDao.DeleteStream(streamId) return nil, err } - stream.Dialog = dialog - callID, _ := dialog.CallID() - StreamManager.AddWithCallId(callID.Value(), stream) + stream.SetDialog(dialog) // 等待流媒体服务发送推流通知 wait := func() bool { - ok := http.StatusOK == stream.WaitForPublishEvent(10) + waiting := StreamWaiting{} + _, _ = Dialogs.Add(string(streamId), &waiting) + defer Dialogs.Remove(string(streamId)) + + ok := http.StatusOK == waiting.Receive(10) if !ok { Sugar.Infof("收流超时 发送bye请求...") CloseStream(streamId, true) @@ -78,10 +84,7 @@ func (d *Device) StartStream(inviteType InviteType, streamId StreamID, channelId stream.urls = urls // 保存到数据库 - if DB != nil { - go DB.SaveStream(stream) - } - + _ = StreamDao.UpdateStream(stream) return stream, nil } diff --git a/live_benchmark_test.go b/live_benchmark_test.go index 6271909..a32f748 100644 --- a/live_benchmark_test.go +++ b/live_benchmark_test.go @@ -94,7 +94,7 @@ func startLiveAll(setup string) { //max := 50 //for _, device := range devices { // for _, channel := range device.Channels { - // go startLive(device.ID, channel.DeviceID, setup) + // go startLive(device.DeviceID, channel.DeviceID, setup) // max-- // if max < 1 { // return diff --git a/main.go b/main.go index 5ca6fe8..9390b0c 100644 --- a/main.go +++ b/main.go @@ -4,13 +4,15 @@ import ( "encoding/json" "go.uber.org/zap/zapcore" "net" + "net/http" + _ "net/http/pprof" "strconv" + "time" ) var ( Config *Config_ SipUA SipServer - DB GB28181DB ) func init() { @@ -36,18 +38,17 @@ func main() { indent, _ := json.MarshalIndent(Config, "", "\t") Sugar.Infof("server config:\r\n%s", indent) - DB = NewRedisDB(Config.Redis.Addr, Config.Redis.Password) + OnlineDeviceManager.Start(time.Duration(Config.AliveExpires)*time.Second/4, time.Duration(Config.AliveExpires)*time.Second, OnExpires) // 从数据库中恢复会话 - var streams []*Stream - var sinks []*Sink - if DB != nil { - // 查询在线设备, 更新设备在线状态 - updateDevicesStatus() + var streams map[string]*Stream + var sinks map[string]*Sink - // 恢复国标推流会话 - streams, sinks = recoverStreams() - } + // 查询在线设备, 更新设备在线状态 + updateDevicesStatus() + + // 恢复国标推流会话 + streams, sinks = recoverStreams() // 启动sip server server, err := StartSipServer(config.SipID, config.ListenIP, config.PublicIP, config.SipPort) @@ -61,11 +62,11 @@ func main() { // 在sip启动后, 关闭无效的流 for _, stream := range streams { - stream.Close(true, false) + stream.Bye() } for _, sink := range sinks { - sink.Close(true, false) + sink.Bye() } // 启动级联设备 @@ -73,5 +74,10 @@ func main() { httpAddr := net.JoinHostPort(config.ListenIP, strconv.Itoa(config.HttpPort)) Sugar.Infof("启动http server. addr: %s", httpAddr) - startApiServer(httpAddr) + go startApiServer(httpAddr) + + err = http.ListenAndServe(":19000", nil) + if err != nil { + println(err) + } } diff --git a/online_devices.go b/online_devices.go new file mode 100644 index 0000000..b0e4114 --- /dev/null +++ b/online_devices.go @@ -0,0 +1,60 @@ +package main + +import ( + "sync" + "time" +) + +var ( + OnlineDeviceManager = NewOnlineDeviceManager() +) + +type onlineDeviceManager struct { + lock sync.RWMutex + devices map[string]time.Time +} + +func (m *onlineDeviceManager) Add(deviceId string, t time.Time) { + m.lock.Lock() + defer m.lock.Unlock() + m.devices[deviceId] = t +} + +func (m *onlineDeviceManager) Remove(deviceId string) { + m.lock.Lock() + defer m.lock.Unlock() + delete(m.devices, deviceId) +} + +func (m *onlineDeviceManager) Find(deviceId string) (time.Time, bool) { + m.lock.RLock() + defer m.lock.RUnlock() + t, ok := m.devices[deviceId] + return t, ok +} + +func (m *onlineDeviceManager) Start(interval time.Duration, keepalive time.Duration, OnExpires func(platformId int, deviceId string)) { + // 精度有偏差 + var timer *time.Timer + timer = time.AfterFunc(interval, func() { + now := time.Now() + m.lock.Lock() + defer m.lock.Unlock() + for deviceId, t := range m.devices { + if now.Sub(t) < keepalive { + continue + } + + delete(m.devices, deviceId) + go OnExpires(0, deviceId) + } + + timer.Reset(interval) + }) +} + +func NewOnlineDeviceManager() *onlineDeviceManager { + return &onlineDeviceManager{ + devices: make(map[string]time.Time), + } +} diff --git a/platform.go b/platform.go index da0a648..a995a03 100644 --- a/platform.go +++ b/platform.go @@ -73,14 +73,14 @@ func (g *GBPlatform) OnInvite(request sip.Request, user string) sip.Response { platform := PlatformManager.Find(source) utils.Assert(platform != nil) - deviceId, channel, err := DB.QueryPlatformChannel(g.ServerAddr, user) + deviceId, channel, err := PlatformDao.QueryPlatformChannel(g.ServerAddr, user) if err != nil { Sugar.Errorf("级联转发失败, 查询数据库失败 err: %s platform: %s channel: %s", err.Error(), g.SeverID, user) return CreateResponseWithStatusCode(request, http.StatusInternalServerError) } // 查找通道对应的设备 - device := DeviceManager.Find(deviceId) + device, _ := DeviceDao.QueryDevice(deviceId) if device == nil { Sugar.Errorf("级联转发失败, 设备不存在 device: %s channel: %s", device, user) return CreateResponseWithStatusCode(request, http.StatusNotFound) @@ -115,12 +115,12 @@ func (g *GBPlatform) OnInvite(request sip.Request, user string) sip.Response { break } - stream := StreamManager.Find(streamId) + stream, _ := StreamDao.QueryStream(streamId) addr := fmt.Sprintf("%s:%d", parse.Addr, media.Port) if stream == nil { s := channel.SetupType.String() println(s) - stream, err = device.(*Device).StartStream(inviteType, streamId, user, time[0], time[1], channel.SetupType.String(), 0, true) + stream, err = device.StartStream(inviteType, streamId, user, time[0], time[1], channel.SetupType.String(), 0, true) if err != nil { Sugar.Errorf("级联转发失败 err: %s stream: %s", err.Error(), streamId) return CreateResponseWithStatusCode(request, http.StatusBadRequest) @@ -148,13 +148,14 @@ func (g *GBPlatform) OnInvite(request sip.Request, user string) sip.Response { setToTag(response) - AddForwardSink(streamId, &Sink{ - ID: sinkID, - Stream: streamId, + sink := &Sink{ + SinkID: sinkID, + StreamID: streamId, ServerAddr: g.ServerAddr, - Protocol: "gb_cascaded_forward", - Dialog: g.CreateDialogRequestFromAnswer(response, true)}) + Protocol: "gb_cascaded_forward"} + sink.SetDialog(g.CreateDialogRequestFromAnswer(response, true)) + AddForwardSink(streamId, sink) return response } @@ -175,7 +176,7 @@ func (g *GBPlatform) Stop() { func (g *GBPlatform) Online() { Sugar.Infof("级联设备上线 device: %s", g.SeverID) - if err := DB.UpdatePlatformStatus(g.SeverID, ON); err != nil { + if err := PlatformDao.UpdatePlatformStatus(g.SeverID, ON); err != nil { Sugar.Infof("更新级联设备状态失败 err: %s device: %s", err.Error(), g.SeverID) } } @@ -183,7 +184,7 @@ func (g *GBPlatform) Online() { func (g *GBPlatform) Offline() { Sugar.Infof("级联设备离线 device: %s", g.SeverID) - if err := DB.UpdatePlatformStatus(g.SeverID, OFF); err != nil { + if err := PlatformDao.UpdatePlatformStatus(g.SeverID, OFF); err != nil { Sugar.Infof("更新级联设备状态失败 err: %s device: %s", err.Error(), g.SeverID) } diff --git a/platform_manager.go b/platform_manager.go index 74db4ad..d8b4a21 100644 --- a/platform_manager.go +++ b/platform_manager.go @@ -67,23 +67,19 @@ func AddPlatform(platform *GBPlatform) error { return fmt.Errorf("平台添加失败, 地址冲突. addr: %s", platform.sipClient.ServerAddr) } - if DB != nil { - err := DB.SavePlatform(&platform.SIPUAParams) - if err != nil { - PlatformManager.Remove(platform.sipClient.ServerAddr) - return fmt.Errorf("平台保存到数据库失败, err: %s", err.Error()) - } + err := PlatformDao.SavePlatform(&platform.SIPUAParams) + if err != nil { + PlatformManager.Remove(platform.sipClient.ServerAddr) + return fmt.Errorf("平台保存到数据库失败, err: %s", err.Error()) } return nil } func RemovePlatform(addr string) (*GBPlatform, error) { - if DB != nil { - err := DB.DeletePlatform(addr) - if err != nil { - return nil, err - } + err := PlatformDao.DeletePlatform(addr) + if err != nil { + return nil, err } platform := PlatformManager.Remove(addr) @@ -113,12 +109,10 @@ func UpdatePlatformStatus(addr string, status OnlineStatus) error { //old := platform.Device.Status platform.Device.Status = status - if DB != nil { - err := DB.UpdatePlatformStatus(addr, status) - // platform.Device.Status = old - if err != nil { - return err - } + err := PlatformDao.UpdatePlatformStatus(addr, status) + // platform.Device.Status = old + if err != nil { + return err } return nil diff --git a/position.go b/position.go index 2a3989e..73c46c5 100644 --- a/position.go +++ b/position.go @@ -27,7 +27,7 @@ type MobilePositionNotify struct { func (d *Device) DoSubscribePosition(channelId string) error { if channelId == "" { - channelId = d.ID + channelId = d.DeviceID } //暂时不考虑级联 @@ -60,5 +60,5 @@ func (d *Device) DoSubscribePosition(channelId string) error { } func (d *Device) OnMobilePositionNotify(notify *MobilePositionNotify) { - Sugar.Infof("收到位置信息 device:%s data:%v", d.ID, notify) + Sugar.Infof("收到位置信息 device:%s data:%v", d.DeviceID, notify) } diff --git a/recover.go b/recover.go index 61b95d7..ee5167c 100644 --- a/recover.go +++ b/recover.go @@ -1,23 +1,25 @@ package main -import "github.com/lkmio/avformat/utils" +import ( + "github.com/lkmio/avformat/utils" + "time" +) // 启动级联设备 func startPlatformDevices() { - platforms, err := DB.LoadPlatforms() + platforms, err := PlatformDao.LoadPlatforms() if err != nil { Sugar.Errorf("查询级联设备失败 err: %s", err.Error()) return } - //streams := StreamManager.All() for _, record := range platforms { platform, err := NewGBPlatform(record, SipUA) // 都入库了不允许失败, 程序有BUG, 及时修复 utils.Assert(err == nil) utils.Assert(PlatformManager.Add(platform)) - if err := DB.UpdatePlatformStatus(record.ServerAddr, OFF); err != nil { + if err := PlatformDao.UpdatePlatformStatus(record.ServerAddr, OFF); err != nil { Sugar.Infof("更新级联设备状态失败 err: %s device: %s", err.Error(), record.SeverID) } @@ -26,7 +28,7 @@ func startPlatformDevices() { //for _, stream := range streams { // sinks := stream.GetForwardStreamSinks() // for _, sink := range sinks { - // if sink.ID != record.SeverID { + // if sink.DeviceID != record.SeverID { // continue // } // @@ -40,42 +42,28 @@ func startPlatformDevices() { } } -func closeStream(stream *Stream) { - DB.DeleteStream(stream.CreateTime) - // 删除转发sink - DB.DeleteForwardSinks(stream.ID) -} - // 返回需要关闭的推流源和转流Sink -func recoverStreams() ([]*Stream, []*Sink) { +func recoverStreams() (map[string]*Stream, map[string]*Sink) { // 比较数据库和流媒体服务器中的流会话, 以流媒体服务器中的为准, 释放过期的会话 // source id和stream id目前都是同一个id - dbStreams, err := DB.LoadStreams() + dbStreams, err := StreamDao.LoadStreams() if err != nil { Sugar.Errorf("恢复推流失败, 查询数据库发生错误. err: %s", err.Error()) return nil, nil - } else if len(dbStreams) < 1 { - return nil, nil } - var closedStreams []*Stream - var closedSinks []*Sink + dbSinks, _ := SinkDao.LoadForwardSinks() // 查询流媒体服务器中的推流源列表 - sources, err := QuerySourceList() + msSources, err := QuerySourceList() if err != nil { // 流媒体服务器崩了, 存在的所有记录都无效, 全部删除 Sugar.Warnf("恢复推流失败, 查询推流源列表发生错误, 删除数据库中的所有记录. err: %s", err.Error()) - - for _, stream := range dbStreams { - closedStreams = append(closedStreams, stream) - } - return closedStreams, nil } // 查询推流源下所有的转发sink列表 - msStreamSinks := make(map[string]map[string]string, len(sources)) - for _, source := range sources { + msStreamSinks := make(map[string]string, len(msSources)) + for _, source := range msSources { // 跳过非国标流 if "28181" != source.Protocol && "gb_talk" != source.Protocol { continue @@ -88,108 +76,61 @@ func recoverStreams() ([]*Stream, []*Sink) { continue } - stream, ok := dbStreams[source.ID] - if !ok { - Sugar.Warnf("流媒体中的流不存在于数据库中 source: %s", source.ID) - continue - } - - stream.SinkCount = int32(len(sinks)) - forwardSinks := make(map[string]string, len(sinks)) for _, sink := range sinks { if "gb_cascaded_forward" == sink.Protocol || "gb_talk_forward" == sink.Protocol { - forwardSinks[sink.ID] = "" + msStreamSinks[sink.ID] = source.ID } } - - msStreamSinks[source.ID] = forwardSinks } - // 遍历数据库中的流会话, 比较是否存在于流媒体服务器中, 不存在则删除 + for _, source := range msSources { + delete(dbStreams, source.ID) + } + + for key, _ := range msStreamSinks { + if dbSinks != nil { + delete(dbSinks, key) + } + } + + var invalidStreamIds []uint for _, stream := range dbStreams { - // 如果stream不存在于流媒体服务器中, 则删除 - msSinks, ok := msStreamSinks[string(stream.ID)] - if !ok { - Sugar.Infof("删除过期的推流会话 stream: %s", stream.ID) - closedStreams = append(closedStreams, stream) - continue - } - - // 查询stream下的转发sink列表 - dbSinks, err := DB.QueryForwardSinks(stream.ID) - if err != nil { - Sugar.Errorf("查询级联转发sink列表失败 err: %s", err.Error()) - } - - // 遍历数据库中的sink, 如果不存在于流媒体服务器中, 则删除 - for _, sink := range dbSinks { - _, ok := msSinks[sink.ID] - if ok { - // 恢复转发sink - AddForwardSink(sink.Stream, sink) - if sink.Protocol == "gb_talk_forward" { - SinkManager.AddWithSinkStreamId(sink) - } - } else { - Sugar.Infof("删除过期的级联转发会话 stream: %s sink: %s", stream.ID, sink.ID) - closedSinks = append(closedSinks, sink) - } - } - - Sugar.Infof("恢复推流会话 stream: %s", stream.ID) - - StreamManager.Add(stream) - if stream.Dialog != nil { - callId, _ := stream.Dialog.CallID() - StreamManager.AddWithCallId(callId.Value(), stream) - } + invalidStreamIds = append(invalidStreamIds, stream.ID) } - return closedStreams, closedSinks + var invalidSinkIds []uint + for _, sink := range dbSinks { + invalidSinkIds = append(invalidSinkIds, sink.ID) + } + + _ = StreamDao.DeleteStreamsByIds(invalidStreamIds) + _ = SinkDao.DeleteForwardSinksByIds(invalidSinkIds) + return dbStreams, dbSinks } // 更新设备的在线状态 func updateDevicesStatus() { - onlineDevices, err := DB.LoadOnlineDevices() - if err != nil { - panic(err) - } - - devices, err := DB.LoadDevices() + devices, err := DeviceDao.LoadDevices() if err != nil { panic(err) } else if len(devices) > 0 { - + now := time.Now() + var offlineDevices []string for key, device := range devices { - status := OFF - if _, ok := onlineDevices[key]; ok { - status = ON - } - - // 根据通道在线状态,统计通道总数和离线数量 - var total int - var online int - channels, _, err := DB.QueryChannels(key, 1, 0xFFFFFFFF) - if err != nil { - Sugar.Errorf("查询通道列表失败 err: %s device: %s", err.Error(), key) - } else { - total = len(channels) - for _, channel := range channels { - if channel.Online() { - online++ - } - } - } - - device.ChannelsTotal = total - device.ChannelsOnline = online - device.Status = status - if err = DB.SaveDevice(device); err != nil { - Sugar.Errorf("更新设备状态失败 device: %s status: %s", key, status) + if device.Status == OFF { + continue + } else if now.Sub(device.LastHeartbeat) < time.Duration(Config.AliveExpires)*time.Second { + OnlineDeviceManager.Add(key, device.LastHeartbeat) continue } - DeviceManager.Add(device) + offlineDevices = append(offlineDevices, key) + } + + if len(offlineDevices) > 0 { + if err = DeviceDao.UpdateOfflineDevices(offlineDevices); err != nil { + Sugar.Errorf("更新设备状态失败 device: %s", offlineDevices) + } } } } diff --git a/sink.go b/sink.go index c9468c2..bc02c73 100644 --- a/sink.go +++ b/sink.go @@ -8,28 +8,27 @@ import ( // Sink 国标级联转发流 type Sink struct { - ID string `json:"id"` // 流媒体服务器中的sink id - Stream StreamID `json:"stream"` // 推流ID - SinkStream StreamID `json:"sink_stream"` // 广播使用, 每个广播设备的唯一ID - Protocol string `json:"protocol,omitempty"` // 转发流协议, gb_cascaded_forward/gb_talk_forward - Dialog sip.Request `json:"dialog,omitempty"` - ServerAddr string `json:"server_addr,omitempty"` // 级联上级地址 - CreateTime int64 `json:"create_time"` - SetupType SetupType // 转发类型 - - StreamWaiting + GBModel + SinkID string `json:"sink_id"` // 流媒体服务器中的sink id + StreamID StreamID `json:"stream_id"` // 推流ID + SinkStreamID StreamID `json:"sink_stream_id"` // 广播使用, 每个广播设备的唯一ID + Protocol string `json:"protocol,omitempty"` // 转发流协议, gb_cascaded_forward/gb_talk_forward + Dialog *RequestWrapper `json:"dialog,omitempty"` + CallID string `json:"call_id,omitempty"` + ServerAddr string `json:"server_addr,omitempty"` // 级联上级地址 + CreateTime int64 `json:"create_time"` + SetupType SetupType // 转发类型 } // Close 关闭级联会话. 是否向上级发送bye请求, 是否通知流媒体服务器发送删除sink func (s *Sink) Close(bye, ms bool) { // 挂断与上级的sip会话 - if bye && s.Dialog != nil { - byeRequest := CreateRequestFromDialog(s.Dialog, sip.BYE) - go SipUA.SendRequest(byeRequest) + if bye { + s.Bye() } if ms { - go CloseSink(string(s.Stream), s.ID) + go CloseSink(string(s.StreamID), s.SinkID) } } @@ -49,6 +48,13 @@ func (s *Sink) MarshalJSON() ([]byte, error) { return json.Marshal(v) } +func (s *Sink) Bye() { + if s.Dialog != nil && s.Dialog.Request != nil { + byeRequest := CreateRequestFromDialog(s.Dialog.Request, sip.BYE) + go SipUA.SendRequest(byeRequest) + } +} + func (s *Sink) UnmarshalJSON(data []byte) error { type Alias Sink // 定义别名以避免递归调用 v := &struct { @@ -71,9 +77,15 @@ func (s *Sink) UnmarshalJSON(data []byte) error { Sugar.Errorf("json解析dialog失败, err: %s value: %s", err.Error(), v.Dialog) } else { request := message.(sip.Request) - s.Dialog = request + s.SetDialog(request) } } return nil } + +func (s *Sink) SetDialog(dialog sip.Request) { + s.Dialog = &RequestWrapper{dialog} + id, _ := dialog.CallID() + s.CallID = id.Value() +} diff --git a/sink_manager.go b/sink_manager.go index 5f81215..56c5f00 100644 --- a/sink_manager.go +++ b/sink_manager.go @@ -1,180 +1,16 @@ package main -import "sync" - -var ( - SinkManager = NewSinkManager() -) - -type sinkManager struct { - lock sync.RWMutex - streamSinks map[StreamID]map[string]*Sink // 推流id->sinks(sinkId->sink) - callIds map[string]*Sink // callId->sink - sinkStreamIds map[StreamID]*Sink // sinkStreamId->sink, 关联广播sink -} - -func (s *sinkManager) Add(sink *Sink) bool { - s.lock.Lock() - defer s.lock.Unlock() - - streamSinks, ok := s.streamSinks[sink.Stream] - if !ok { - streamSinks = make(map[string]*Sink) - s.streamSinks[sink.Stream] = streamSinks - } - - if sink.Dialog == nil { - return false - } - - callId, _ := sink.Dialog.CallID() - id := callId.Value() - if _, ok := s.callIds[id]; ok { - return false - } else if _, ok := streamSinks[sink.ID]; ok { - return false - } - - s.callIds[id] = sink - s.streamSinks[sink.Stream][sink.ID] = sink - return true -} - -func (s *sinkManager) AddWithSinkStreamId(sink *Sink) bool { - s.lock.Lock() - defer s.lock.Unlock() - if _, ok := s.sinkStreamIds[sink.SinkStream]; ok { - return false - } - s.sinkStreamIds[sink.SinkStream] = sink - return true -} - -func (s *sinkManager) Remove(stream StreamID, sinkID string) *Sink { - s.lock.Lock() - defer s.lock.Unlock() - if _, ok := s.streamSinks[stream]; !ok { - return nil - } - - sink, ok := s.streamSinks[stream][sinkID] - if !ok { - return nil - } - - s.removeSink(sink) - return sink -} - -func (s *sinkManager) RemoveWithCallId(callId string) *Sink { - s.lock.Lock() - defer s.lock.Unlock() - - if sink, ok := s.callIds[callId]; ok { - s.removeSink(sink) - return sink - } - - return nil -} - -func (s *sinkManager) removeSink(sink *Sink) { - delete(s.streamSinks[sink.Stream], sink.ID) - - if sink.Dialog != nil { - callID, _ := sink.Dialog.CallID() - delete(s.callIds, callID.Value()) - } - - if sink.SinkStream != "" { - delete(s.sinkStreamIds, sink.SinkStream) - } -} - -func (s *sinkManager) RemoveWithSinkStreamId(sinkStreamId StreamID) *Sink { - s.lock.Lock() - defer s.lock.Unlock() - if sink, ok := s.sinkStreamIds[sinkStreamId]; ok { - s.removeSink(sink) - return sink - } - - return nil -} - -func (s *sinkManager) Find(stream StreamID, sinkID string) *Sink { - s.lock.RLock() - defer s.lock.RUnlock() - if _, ok := s.streamSinks[stream]; !ok { - return nil - } - - sink, ok := s.streamSinks[stream][sinkID] - if !ok { - return nil - } - - return sink -} - -func (s *sinkManager) FindWithCallId(callId string) *Sink { - s.lock.RLock() - defer s.lock.RUnlock() - if sink, ok := s.callIds[callId]; ok { - return sink - } - - return nil -} - -func (s *sinkManager) FindWithSinkStreamId(sinkStreamId StreamID) *Sink { - s.lock.RLock() - defer s.lock.RUnlock() - if sink, ok := s.sinkStreamIds[sinkStreamId]; ok { - return sink - } - - return nil -} - -func (s *sinkManager) PopSinks(stream StreamID) []*Sink { - s.lock.Lock() - defer s.lock.Unlock() - if _, ok := s.streamSinks[stream]; !ok { - return nil - } - - var sinkList []*Sink - for _, sink := range s.streamSinks[stream] { - sinkList = append(sinkList, sink) - } - - for _, sink := range sinkList { - s.removeSink(sink) - } - - delete(s.streamSinks, stream) - return sinkList -} - func AddForwardSink(StreamID StreamID, sink *Sink) bool { - if !SinkManager.Add(sink) { - Sugar.Errorf("转发Sink添加失败, StreamID: %s SinkID: %s", StreamID, sink.ID) + if err := SinkDao.SaveForwardSink(StreamID, sink); err != nil { + Sugar.Errorf("保存sink到数据库失败, stream: %s sink: %s err: %s", StreamID, sink.SinkID, err.Error()) return false } - if DB != nil { - err := DB.SaveForwardSink(StreamID, sink) - if err != nil { - Sugar.Errorf("转发Sink保存到数据库失败, err: %s", err.Error()) - } - } - return true } func RemoveForwardSink(StreamID StreamID, sinkID string) *Sink { - sink := SinkManager.Remove(StreamID, sinkID) + sink, _ := SinkDao.DeleteForwardSink(StreamID, sinkID) if sink == nil { return nil } @@ -184,7 +20,7 @@ func RemoveForwardSink(StreamID StreamID, sinkID string) *Sink { } func RemoveForwardSinkWithCallId(callId string) *Sink { - sink := SinkManager.RemoveWithCallId(callId) + sink, _ := SinkDao.DeleteForwardSinkByCallID(callId) if sink == nil { return nil } @@ -193,8 +29,8 @@ func RemoveForwardSinkWithCallId(callId string) *Sink { return sink } -func RemoveForwardSinkWithSinkStreamId(sinkStreamId StreamID) *Sink { - sink := SinkManager.RemoveWithSinkStreamId(sinkStreamId) +func RemoveForwardSinkWithSinkStreamID(sinkStreamId StreamID) *Sink { + sink, _ := SinkDao.DeleteForwardSinkBySinkStreamID(sinkStreamId) if sink == nil { return nil } @@ -204,17 +40,10 @@ func RemoveForwardSinkWithSinkStreamId(sinkStreamId StreamID) *Sink { } func releaseSink(sink *Sink) { - if DB != nil { - err := DB.DeleteForwardSink(sink.Stream, sink.ID) - if err != nil { - Sugar.Errorf("删除转发Sink失败, err: %s", err.Error()) - } - } - // 减少拉流计数 - if stream := StreamManager.Find(sink.Stream); stream != nil { - stream.DecreaseSinkCount() - } + //if stream := StreamManager.Find(sink.StreamID); stream != nil { + // stream.DecreaseSinkCount() + //} } func closeSink(sink *Sink, bye, ms bool) { @@ -235,40 +64,10 @@ func closeSink(sink *Sink, bye, ms bool) { } func CloseStreamSinks(StreamID StreamID, bye, ms bool) []*Sink { - sinks := SinkManager.PopSinks(StreamID) - + sinks, _ := SinkDao.DeleteForwardSinksByStreamID(StreamID) for _, sink := range sinks { closeSink(sink, bye, ms) } - // 查询数据库中的残余sink - if DB != nil { - // 恢复级联转发sink - forwardSinks, _ := DB.QueryForwardSinks(StreamID) - for _, sink := range forwardSinks { - closeSink(sink, bye, ms) - } - } - - // 删除整个转发流 - if DB != nil { - err := DB.Del(ForwardSinksKey(string(StreamID))) - if err != nil { - Sugar.Errorf("删除转发Sink失败, err: %s", err.Error()) - } - } - return sinks } - -func FindSink(StreamID StreamID, sinkID string) *Sink { - return SinkManager.Find(StreamID, sinkID) -} - -func NewSinkManager() *sinkManager { - return &sinkManager{ - streamSinks: make(map[StreamID]map[string]*Sink), - callIds: make(map[string]*Sink), - sinkStreamIds: make(map[StreamID]*Sink), - } -} diff --git a/sip_client.go b/sip_client.go index 5f1cd5b..cf8de60 100644 --- a/sip_client.go +++ b/sip_client.go @@ -40,15 +40,19 @@ type SipClient interface { } type SIPUAParams struct { - Username string `json:"username"` // 用户名 - SeverID string `json:"server_id"` // 上级ID, 必选. 作为主键, 不能重复. - ServerAddr string `json:"server_addr"` // 上级地址, 必选 - Transport string `json:"transport"` // 上级通信方式, UDP/TCP - Password string `json:"password"` // 密码 - RegisterExpires int `json:"register_expires"` // 注册有效期 - KeepAliveInterval int `json:"keep_alive_interval"` // 心跳间隔 - CreateTime string `json:"create_time"` // 入库时间 - Status OnlineStatus `json:"status"` // 在线状态 + GBModel + Username string `json:"username"` // 用户名 + SeverID string `json:"server_id"` // 上级ID, 必选. 作为主键, 不能重复. + ServerAddr string `json:"server_addr"` // 上级地址, 必选 + Transport string `json:"transport"` // 上级通信方式, UDP/TCP + Password string `json:"password"` // 密码 + RegisterExpires int `json:"register_expires"` // 注册有效期 + KeepaliveInterval int `json:"keepalive_interval"` // 心跳间隔 + Status OnlineStatus `json:"status"` // 在线状态 +} + +func (g *SIPUAParams) TableName() string { + return "lkm_virtual_device" } type sipClient struct { @@ -249,7 +253,7 @@ func (g *sipClient) Refresh() time.Duration { } // 信令正常, 休眠心跳间隔时长 - return time.Duration(g.KeepAliveInterval) * time.Second + return time.Duration(g.KeepaliveInterval) * time.Second } func (g *sipClient) Start() { diff --git a/sip_handler.go b/sip_handler.go index 2160fff..f76c5a3 100644 --- a/sip_handler.go +++ b/sip_handler.go @@ -13,11 +13,11 @@ type Handler interface { OnKeepAlive(id string) bool - OnCatalog(device GBDevice, response *CatalogResponse) + OnCatalog(device string, response *CatalogResponse) - OnRecord(device GBDevice, response *QueryRecordInfoResponse) + OnRecord(device string, response *QueryRecordInfoResponse) - OnDeviceInfo(device GBDevice, response *DeviceInfoResponse) + OnDeviceInfo(device string, response *DeviceInfoResponse) OnNotifyPosition(notify *MobilePositionNotify) } @@ -26,73 +26,40 @@ type EventHandler struct { } func (e *EventHandler) OnUnregister(id string) { - device := DeviceManager.Find(id) - if device != nil { - device.(*Device).Status = OFF - } - - if DB != nil { - _ = DB.SaveDevice(device.(*Device)) - } + _ = DeviceDao.UpdateDeviceStatus(id, OFF) } func (e *EventHandler) OnRegister(id, transport, addr string) (int, GBDevice, bool) { - var device *Device - old := DeviceManager.Find(id) - - if old != nil { - old.(*Device).ID = id - old.(*Device).Transport = transport - old.(*Device).RemoteAddr = addr - - device = old.(*Device) - } else { - device = &Device{ - ID: id, - Transport: transport, - RemoteAddr: addr, - } - - DeviceManager.Add(device) + now := time.Now() + device := &Device{ + DeviceID: id, + Transport: transport, + RemoteAddr: addr, + Status: ON, + RegisterTime: now, + LastHeartbeat: now, } - device.Status = ON - device.RegisterTime = time.Now().UnixMilli() - if DB != nil { - if err := DB.SaveDevice(device); err != nil { - Sugar.Errorf("保存设备信息到数据库失败 device: %s err: %s", id, err.Error()) - } + if err := DeviceDao.SaveDevice(device); err != nil { + Sugar.Errorf("保存设备信息到数据库失败 device: %s err: %s", id, err.Error()) } - return 3600, device, device.ChannelsTotal < 1 + count, _ := ChannelDao.QueryChanelCount(id) + return 3600, device, count < 1 } -func (e *EventHandler) OnKeepAlive(id string) bool { - device := DeviceManager.Find(id) - if device == nil { - Sugar.Errorf("更新心跳失败, 设备不存在. device: %s", id) +func (e *EventHandler) OnKeepAlive(id string, addr string) bool { + now := time.Now() + if err := DeviceDao.RefreshHeartbeat(id, now, addr); err != nil { + Sugar.Errorf("更新有效期失败. device: %s err: %s", id, err.Error()) return false } - if !device.(*Device).Online() { - Sugar.Errorf("更新心跳失败, 设备离线. device: %s", id) - } - - if DB != nil { - if err := DB.RefreshHeartbeat(id); err != nil { - Sugar.Errorf("更新有效期失败. device: %s err: %s", id, err.Error()) - } - } - + OnlineDeviceManager.Add(id, now) return true } -func (e *EventHandler) OnCatalog(device GBDevice, response *CatalogResponse) { - if DB == nil { - return - } - - id := device.GetID() +func (e *EventHandler) OnCatalog(device string, response *CatalogResponse) { for _, channel := range response.DeviceList.Devices { // 状态转为大写 channel.Status = OnlineStatus(strings.ToUpper(channel.Status.String())) @@ -102,34 +69,13 @@ func (e *EventHandler) OnCatalog(device GBDevice, response *CatalogResponse) { channel.Status = ON } - // 判断之前是否已经存在通道, 如果不存在累加总数 - old, _ := DB.QueryChannel(id, channel.DeviceID) - - if err := DB.SaveChannel(id, channel); err != nil { + if err := ChannelDao.SaveChannel(device, channel); err != nil { Sugar.Infof("保存通道到数据库失败 err: %s", err.Error()) } - - if old == nil { - device.(*Device).ChannelsTotal++ - device.(*Device).ChannelsOnline++ - } else if old.Status != channel.Status { - // 保留处理其他状态 - if ON == channel.Status { - device.(*Device).ChannelsOnline++ - } else if OFF == channel.Status { - device.(*Device).ChannelsOnline-- - } else { - return - } - } - - if err := DB.SaveDevice(device.(*Device)); err != nil { - Sugar.Errorf("更新设备在线数失败 err: %s", err.Error()) - } } } -func (e *EventHandler) OnRecord(device GBDevice, response *QueryRecordInfoResponse) { +func (e *EventHandler) OnRecord(device string, response *QueryRecordInfoResponse) { event := SNManager.FindEvent(response.SN) if event == nil { Sugar.Errorf("处理录像查询响应失败 SN: %d", response.SN) @@ -139,16 +85,14 @@ func (e *EventHandler) OnRecord(device GBDevice, response *QueryRecordInfoRespon event(response) } -func (e *EventHandler) OnDeviceInfo(device GBDevice, response *DeviceInfoResponse) { - device.(*Device).Manufacturer = response.Manufacturer - device.(*Device).Model = response.Model - device.(*Device).Firmware = response.Firmware - device.(*Device).Name = response.DeviceName - - if DB != nil { - if err := DB.SaveDevice(device.(*Device)); err != nil { - Sugar.Errorf("保存设备信息到数据库失败 device: %s err: %s", device.GetID(), err.Error()) - } +func (e *EventHandler) OnDeviceInfo(device string, response *DeviceInfoResponse) { + if err := DeviceDao.UpdateDeviceInfo(device, &Device{ + Manufacturer: response.Manufacturer, + Model: response.Model, + Firmware: response.Firmware, + Name: response.DeviceName, + }); err != nil { + Sugar.Errorf("保存设备信息到数据库失败 device: %s err: %s", device, err.Error()) } } diff --git a/sip_server.go b/sip_server.go index c87c149..6fefdc4 100644 --- a/sip_server.go +++ b/sip_server.go @@ -133,9 +133,9 @@ func (s *sipServer) OnInvite(req sip.Request, tx sip.ServerTransaction, parent b if parent { // 级联设备 device = PlatformManager.Find(req.Source()) - } else if session := BroadcastDialogs.Find(user); session != nil { + } else if session := Dialogs.Find(user); session != nil { // 语音广播设备 - device = DeviceManager.Find(session.SinkStream.DeviceID()) + device, _ = DeviceDao.QueryDevice(session.data.(*Sink).SinkStreamID.DeviceID()) } else { // 根据Subject头域查找设备 headers := req.GetHeaders("Subject") @@ -143,7 +143,7 @@ func (s *sipServer) OnInvite(req sip.Request, tx sip.ServerTransaction, parent b subject := headers[0].(*sip.GenericHeader) split := strings.Split(strings.Split(subject.Value(), ",")[0], ":") if len(split) > 1 { - device = DeviceManager.Find(split[1]) + device, _ = DeviceDao.QueryDevice(split[1]) } } } @@ -169,14 +169,12 @@ func (s *sipServer) OnBye(req sip.Request, tx sip.ServerTransaction, parent bool id, _ := req.CallID() var deviceId string - if stream := StreamManager.RemoveWithCallId(id.Value()); stream != nil { + if stream, _ := StreamDao.DeleteStreamByCallID(id.Value()); stream != nil { // 下级设备挂断, 关闭流 - deviceId = stream.ID.DeviceID() + deviceId = stream.StreamID.DeviceID() stream.Close(false, true) - } else if session := StreamManager.RemoveWithCallId(id.Value()); session != nil { - // 广播挂断 - deviceId = session.ID.DeviceID() - session.Close(false, true) + } else if sink, _ := SinkDao.DeleteForwardSinkByCallID(id.Value()); sink != nil { + sink.Close(false, true) } if parent { @@ -184,7 +182,7 @@ func (s *sipServer) OnBye(req sip.Request, tx sip.ServerTransaction, parent bool if platform := PlatformManager.Find(req.Source()); platform != nil { platform.OnBye(req) } - } else if device := DeviceManager.Find(deviceId); device != nil { + } else if device, _ := DeviceDao.QueryDevice(deviceId); device != nil { device.OnBye(req) } } @@ -199,9 +197,7 @@ func (s *sipServer) OnNotify(req sip.Request, tx sip.ServerTransaction, parent b return } - if device := DeviceManager.Find(mobilePosition.DeviceID); device != nil { - s.handler.OnNotifyPosition(&mobilePosition) - } + s.handler.OnNotifyPosition(&mobilePosition) } func (s *sipServer) OnMessage(req sip.Request, tx sip.ServerTransaction, parent bool) { @@ -233,62 +229,50 @@ func (s *sipServer) OnMessage(req sip.Request, tx sip.ServerTransaction, parent } // 查找设备 - var device GBDevice deviceId := message.(BaseMessageGetter).GetDeviceID() if CmdBroadcast == cmd { // 广播消息 from, _ := req.From() deviceId = from.Address.User().String() } - if parent { - device = PlatformManager.Find(req.Source()) - } else { - device = DeviceManager.Find(deviceId) - } - - if ok = device != nil; !ok { - Sugar.Errorf("处理XML消息失败, 设备离线: %s request: %s", deviceId, req.String()) - return - } switch xmlName { case XmlNameControl: break case XmlNameQuery: // 被上级查询 - var client GBClient - client, ok = device.(GBClient) - if !ok { - Sugar.Errorf("处理XML消息失败, 类型转换失败. request: %s", req.String()) + device := PlatformManager.Find(req.Source()) + if ok = device != nil; !ok { + Sugar.Errorf("处理上级请求消息失败, 找不到级联设备 addr: %s request: %s", req.Source(), req.String()) return } if CmdDeviceInfo == cmd { - client.OnQueryDeviceInfo(message.(*BaseMessage).SN) + device.OnQueryDeviceInfo(message.(*BaseMessage).SN) } else if CmdCatalog == cmd { var channels []*Channel // 查询出所有通道 - if DB != nil { - result, err := DB.QueryPlatformChannels(client.(*GBPlatform).ServerAddr) + if PlatformDao != nil { + result, err := PlatformDao.QueryPlatformChannels(device.ServerAddr) if err != nil { - Sugar.Errorf("查询设备通道列表失败 err: %s device: %s", err.Error(), client.GetID()) + Sugar.Errorf("查询设备通道列表失败 err: %s device: %s", err.Error(), device.GetID()) } channels = result } else { // 从模拟多个国标客户端中查找 - channels = DeviceChannelsManager.FindChannels(client.GetID()) + channels = DeviceChannelsManager.FindChannels(device.GetID()) } - client.OnQueryCatalog(message.(*BaseMessage).SN, channels) + device.OnQueryCatalog(message.(*BaseMessage).SN, channels) } break case XmlNameNotify: if CmdKeepalive == cmd { // 下级设备心跳通知 - ok = s.handler.OnKeepAlive(deviceId) + ok = s.handler.OnKeepAlive(deviceId, req.Source()) } break @@ -296,11 +280,11 @@ func (s *sipServer) OnMessage(req sip.Request, tx sip.ServerTransaction, parent case XmlNameResponse: // 查询下级的应答 if CmdCatalog == cmd { - go s.handler.OnCatalog(device, message.(*CatalogResponse)) + s.handler.OnCatalog(deviceId, message.(*CatalogResponse)) } else if CmdRecordInfo == cmd { - go s.handler.OnRecord(device, message.(*QueryRecordInfoResponse)) + s.handler.OnRecord(deviceId, message.(*QueryRecordInfoResponse)) } else if CmdDeviceInfo == cmd { - go s.handler.OnDeviceInfo(device, message.(*DeviceInfoResponse)) + s.handler.OnDeviceInfo(deviceId, message.(*DeviceInfoResponse)) } break diff --git a/stream.go b/stream.go index a564b57..7244418 100644 --- a/stream.go +++ b/stream.go @@ -1,12 +1,12 @@ package main import ( - "context" + "database/sql/driver" "encoding/json" + "errors" "github.com/ghettovoice/gosip/sip" "github.com/ghettovoice/gosip/sip/parser" "sync/atomic" - "time" ) type SetupType int @@ -34,34 +34,50 @@ func (s SetupType) String() string { panic("invalid setup type") } -type StreamWaiting struct { - onPublishCb chan int // 等待推流hook的管道 - cancelFunc func() // 取消等待推流hook的ctx +// RequestWrapper sql序列化 +type RequestWrapper struct { + sip.Request } -func (s *StreamWaiting) WaitForPublishEvent(seconds int) int { - s.onPublishCb = make(chan int, 0) - timeout, cancelFunc := context.WithTimeout(context.Background(), time.Duration(seconds)*time.Second) - s.cancelFunc = cancelFunc - select { - case code := <-s.onPublishCb: - return code - case <-timeout.Done(): - s.cancelFunc = nil - return -1 +func (r *RequestWrapper) Value() (driver.Value, error) { + if r == nil || r.Request == nil { + return "", nil } + + return r.Request.String(), nil +} + +func (r *RequestWrapper) Scan(value interface{}) error { + if value == nil { + return nil + } + + data, ok := value.(string) + if !ok { + return errors.New("invalid type for RequestWrapper") + } else if data == "" { + return nil + } + + dialog, err := UnmarshalDialog(data) + if err != nil { + return err + } + + *r = RequestWrapper{dialog} + return nil } type Stream struct { - ID StreamID `json:"id"` // 流ID - Protocol string `json:"protocol,omitempty"` // 推流协议, rtmp/28181/1078/gb_talk - Dialog sip.Request `json:"dialog,omitempty"` // 国标流的SipCall会话 - CreateTime int64 `json:"create_time"` // 推流时间 - SinkCount int32 `json:"sink_count"` // 拉流端计数(包含级联转发) - SetupType SetupType + GBModel + StreamID StreamID `json:"stream_id"` // 流ID + Protocol string `json:"protocol,omitempty"` // 推流协议, rtmp/28181/1078/gb_talk + Dialog *RequestWrapper `json:"dialog,omitempty"` // 国标流的SipCall会话 + SinkCount int32 `json:"sink_count"` // 拉流端计数(包含级联转发) + SetupType SetupType + CallID string `json:"call_id"` urls []string // 从流媒体服务器返回的拉流地址 - StreamWaiting } func (s *Stream) MarshalJSON() ([]byte, error) { @@ -102,53 +118,61 @@ func (s *Stream) UnmarshalJSON(data []byte) error { Sugar.Errorf("json解析dialog失败, err: %s value: %s", err.Error(), v.Dialog) } else { request := message.(sip.Request) - s.Dialog = request + s.SetDialog(request) } } return nil } +func (s *Stream) SetDialog(dialog sip.Request) { + s.Dialog = &RequestWrapper{dialog} + id, _ := dialog.CallID() + s.CallID = id.Value() +} + func (s *Stream) GetSinkCount() int32 { return atomic.LoadInt32(&s.SinkCount) } func (s *Stream) IncreaseSinkCount() int32 { value := atomic.AddInt32(&s.SinkCount, 1) - Sugar.Infof("拉流计数: %d stream: %s ", value, s.ID) + //Sugar.Infof("拉流计数: %d stream: %s ", value, s.StreamID) // 启动协程去更新拉流计数, 可能会不一致 - go DB.SaveStream(s) + //go StreamDao.SaveStream(s) return value } func (s *Stream) DecreaseSinkCount() int32 { value := atomic.AddInt32(&s.SinkCount, -1) - Sugar.Infof("拉流计数: %d stream: %s ", value, s.ID) - go DB.SaveStream(s) + //Sugar.Infof("拉流计数: %d stream: %s ", value, s.StreamID) + //go StreamDao.SaveStream(s) return value } func (s *Stream) Close(bye, ms bool) { - if s.cancelFunc != nil { - s.cancelFunc() - } - // 断开与推流通道的sip会话 - if bye && s.Dialog != nil { - go SipUA.SendRequest(s.CreateRequestFromDialog(sip.BYE)) - s.Dialog = nil + if bye { + s.Bye() } if ms { // 告知媒体服务释放source - go CloseSource(string(s.ID)) + go CloseSource(string(s.StreamID)) } // 关闭所转发会话 - CloseStreamSinks(s.ID, bye, ms) + CloseStreamSinks(s.StreamID, bye, ms) // 从数据库中删除流记录 - DB.DeleteStream(s.CreateTime) + _, _ = StreamDao.DeleteStream(s.StreamID) +} + +func (s *Stream) Bye() { + if s.Dialog != nil && s.Dialog.Request != nil { + go SipUA.SendRequest(s.CreateRequestFromDialog(sip.BYE)) + s.Dialog = nil + } } func CreateRequestFromDialog(dialog sip.Request, method sip.RequestMethod) sip.Request { @@ -169,8 +193,8 @@ func (s *Stream) CreateRequestFromDialog(method sip.RequestMethod) sip.Request { } func CloseStream(streamId StreamID, ms bool) { - stream := StreamManager.Remove(streamId) - if stream != nil { - stream.Close(true, ms) + deleteStream, err := StreamDao.DeleteStream(streamId) + if err == nil { + deleteStream.Close(true, ms) } } diff --git a/stream_manager.go b/stream_manager.go deleted file mode 100644 index 55d0b51..0000000 --- a/stream_manager.go +++ /dev/null @@ -1,125 +0,0 @@ -package main - -import ( - "sync" -) - -var StreamManager *streamManager - -func init() { - StreamManager = NewStreamManager() -} - -type streamManager struct { - streams map[StreamID]*Stream - callIds map[string]*Stream // CallID关联Stream, 实际推流通道的会话callid和级联转发的callid都会指向Stream - lock sync.RWMutex -} - -// Add 添加Stream -// 如果Stream已经存在, 返回oldStream与false -func (s *streamManager) Add(stream *Stream) (*Stream, bool) { - s.lock.Lock() - defer s.lock.Unlock() - - old, ok := s.streams[stream.ID] - if ok { - return old, false - } - - s.streams[stream.ID] = stream - return nil, true -} - -func (s *streamManager) AddWithCallId(id string, stream *Stream) bool { - s.lock.Lock() - defer s.lock.Unlock() - - if _, ok := s.callIds[id]; ok { - return false - } - - s.callIds[id] = stream - return true -} - -func (s *streamManager) Find(id StreamID) *Stream { - s.lock.RLock() - defer s.lock.RUnlock() - - if value, ok := s.streams[id]; ok { - return value - } - return nil -} - -func (s *streamManager) FindWithCallId(id string) *Stream { - s.lock.RLock() - defer s.lock.RUnlock() - - if value, ok := s.callIds[id]; ok { - return value - } - return nil -} - -func (s *streamManager) Remove(id StreamID) *Stream { - s.lock.Lock() - defer s.lock.Unlock() - - stream, ok := s.streams[id] - delete(s.streams, id) - if ok && stream.Dialog != nil { - callID, _ := stream.Dialog.CallID() - delete(s.callIds, callID.Value()) - } - - return stream -} - -func (s *streamManager) RemoveWithCallId(id string) *Stream { - s.lock.Lock() - defer s.lock.Unlock() - - stream, ok := s.callIds[id] - if ok { - delete(s.callIds, id) - delete(s.streams, stream.ID) - return stream - } - - return nil -} - -func (s *streamManager) All() []*Stream { - s.lock.Lock() - defer s.lock.Unlock() - var streams []*Stream - - for _, stream := range s.streams { - streams = append(streams, stream) - } - - return streams -} - -func (s *streamManager) PopAll() []*Stream { - s.lock.Lock() - defer s.lock.Unlock() - var streams []*Stream - - for _, stream := range s.streams { - streams = append(streams, stream) - } - - s.streams = make(map[StreamID]*Stream) - s.callIds = make(map[string]*Stream) - return streams -} - -func NewStreamManager() *streamManager { - return &streamManager{ - streams: make(map[StreamID]*Stream, 64), - callIds: make(map[string]*Stream, 64), - } -} diff --git a/xml.go b/xml.go index cc494fd..2dbf5b9 100644 --- a/xml.go +++ b/xml.go @@ -1,31 +1,43 @@ package main -import "encoding/xml" +import ( + "encoding/xml" + "time" +) + +// GBModel 解决与Device和Channel的Model变量名冲突 +type GBModel struct { + //gorm.Model + ID uint `gorm:"primarykey"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"-"` +} type Channel struct { - DeviceID string `xml:"DeviceID"` - Name string `xml:"Name,omitempty"` - Manufacturer string `xml:"Manufacturer,omitempty"` - Model string `xml:"Model,omitempty"` - Owner string `xml:"Owner,omitempty"` - CivilCode string `xml:"CivilCode,omitempty"` - Block string `xml:"Block,omitempty"` - Address string `xml:"Address,omitempty"` - Parental string `xml:"Parental,omitempty"` - ParentID string `xml:"ParentID,omitempty"` - SafetyWay string `xml:"SafetyWay,omitempty"` - RegisterWay string `xml:"RegisterWay,omitempty"` - CertNum string `xml:"CertNum,omitempty"` - Certifiable string `xml:"Certifiable,omitempty"` - ErrCode string `xml:"ErrCode,omitempty"` - EndTime string `xml:"EndTime,omitempty"` - Secrecy string `xml:"Secrecy,omitempty"` - IPAddress string `xml:"IPAddress,omitempty"` - Port string `xml:"Port,omitempty"` - Password string `xml:"Password,omitempty"` - Status OnlineStatus `xml:"Status,omitempty"` - Longitude string `xml:"Longitude,omitempty"` - Latitude string `xml:"Latitude,omitempty"` + GBModel + DeviceID string `json:"device_id" xml:"DeviceID" gorm:"index"` + Name string `json:"name" xml:"Name,omitempty"` + Manufacturer string `json:"manufacturer" xml:"Manufacturer,omitempty"` + Model string `json:"model" xml:"Model,omitempty"` + Owner string `json:"owner" xml:"Owner,omitempty"` + CivilCode string `json:"civil_code" xml:"CivilCode,omitempty"` + Block string `json:"block" xml:"Block,omitempty"` + Address string `json:"address" xml:"Address,omitempty"` + Parental string `json:"parental" xml:"Parental,omitempty"` + ParentID string `json:"parent_id" xml:"ParentID,omitempty" gorm:"index"` + SafetyWay string `json:"safety_way" xml:"SafetyWay,omitempty"` + RegisterWay string `json:"register_way" xml:"RegisterWay,omitempty"` + CertNum string `json:"cert_num" xml:"CertNum,omitempty"` + Certifiable string `json:"certifiable" xml:"Certifiable,omitempty"` + ErrCode string `json:"err_code" xml:"ErrCode,omitempty"` + EndTime string `json:"end_time" xml:"EndTime,omitempty"` + Secrecy string `json:"secrecy" xml:"Secrecy,omitempty"` + IPAddress string `json:"ip_address" xml:"IPAddress,omitempty"` + Port string `json:"port" xml:"Port,omitempty"` + Password string `json:"password" xml:"Password,omitempty"` + Status OnlineStatus `json:"status" xml:"Status,omitempty"` + Longitude string `json:"longitude" xml:"Longitude,omitempty"` + Latitude string `json:"latitude" xml:"Latitude,omitempty"` SetupType SetupType `json:"setup_type,omitempty"` }