From 9b663cb7b5fc0485af9435d240e6323e206c760a Mon Sep 17 00:00:00 2001 From: ydajiang Date: Wed, 24 Sep 2025 18:22:02 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E9=80=82=E9=85=8Dlivegbs=E4=B8=80?= =?UTF-8?q?=E5=AF=B9=E4=B8=80=E5=AF=B9=E8=AE=B2=E6=8E=A5=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api.go | 264 +++++++++++++------------------------------ common/http_proxy.go | 21 ++++ dao/channel.go | 9 ++ dao/device.go | 13 +++ dao/sink.go | 47 +++++--- go.mod | 1 + main.go | 3 + recover.go | 4 +- stack/broadcast.go | 103 +++++++++++++++-- stack/platform.go | 4 +- stack/sink.go | 9 +- stack/sip_server.go | 12 +- stack/stream.go | 2 +- 13 files changed, 267 insertions(+), 225 deletions(-) diff --git a/api.go b/api.go index 62a5074..0e13d73 100644 --- a/api.go +++ b/api.go @@ -21,7 +21,6 @@ import ( "os" "strconv" "strings" - "sync" "time" ) @@ -261,7 +260,7 @@ func startApiServer(addr string) { apiServer.router.HandleFunc("/api/v1/broadcast/invite", common.WithJsonResponse(apiServer.OnBroadcast, &BroadcastParams{Setup: &common.DefaultSetupType})) // 发起语音广播 apiServer.router.HandleFunc("/api/v1/broadcast/hangup", common.WithJsonResponse(apiServer.OnHangup, &BroadcastParams{})) // 挂断广播会话 - apiServer.router.HandleFunc("/api/v1/control/ws-talk/{device}/{channel}", withVerify(apiServer.OnTalk)) // 语音对讲 + apiServer.router.HandleFunc("/api/v1/control/ws-talk/{device}/{channel}", withVerify(apiServer.OnTalk)) // 一对一语音对讲 apiServer.router.HandleFunc("/api/v1/jt/device/add", common.WithJsonResponse(apiServer.OnVirtualDeviceAdd, &dao.JTDeviceModel{})) apiServer.router.HandleFunc("/api/v1/jt/device/edit", common.WithJsonResponse(apiServer.OnVirtualDeviceEdit, &dao.JTDeviceModel{})) @@ -379,7 +378,7 @@ func (api *ApiServer) OnPlay(params *PlayDoneParams, w http.ResponseWriter, r *h } else if stream, _ := dao.Stream.QueryStream(params.Stream); stream == nil { w.WriteHeader(http.StatusNotFound) } else { - _ = dao.Sink.SaveForwardSink(&dao.SinkModel{ + _ = dao.Sink.CreateSink(&dao.SinkModel{ SinkID: params.Sink, StreamID: params.Stream, Protocol: params.Protocol, @@ -387,9 +386,13 @@ func (api *ApiServer) OnPlay(params *PlayDoneParams, w http.ResponseWriter, r *h }) } return + } else if stack.TransStreamGBTalk == params.Protocol { + // 对讲/广播 + w.WriteHeader(http.StatusOK) + return } - // 对讲/级联, 在此处请求流 + // 级联, 在此处请求流 inviteParams := &InviteParams{ DeviceID: deviceId, ChannelID: channelId, @@ -417,7 +420,7 @@ func (api *ApiServer) OnPlay(params *PlayDoneParams, w http.ResponseWriter, r *h } else if http.StatusOK == code { _ = stream.ID - _ = dao.Sink.SaveForwardSink(&dao.SinkModel{ + _ = dao.Sink.CreateSink(&dao.SinkModel{ SinkID: params.Sink, StreamID: params.Stream, Protocol: params.Protocol, @@ -432,7 +435,7 @@ func (api *ApiServer) OnPlay(params *PlayDoneParams, w http.ResponseWriter, r *h func (api *ApiServer) OnPlayDone(params *PlayDoneParams, _ http.ResponseWriter, _ *http.Request) { log.Sugar.Debugf("播放结束事件. protocol: %s stream: %s", params.Protocol, params.Stream) - sink, _ := dao.Sink.DeleteForwardSink(params.Sink) + sink, _ := dao.Sink.DeleteSink(params.Sink) if sink == nil { return } @@ -561,38 +564,38 @@ func (api *ApiServer) DoStreamStart(v *InviteParams, w http.ResponseWriter, r *h var urls map[string]string urls = make(map[string]string, 10) - for _, url := range stream.Urls { + for _, streamUrl := range stream.Urls { var streamName string - if strings.HasPrefix(url, "ws") { + if strings.HasPrefix(streamUrl, "ws") { streamName = "WS_FLV" - } else if strings.HasSuffix(url, ".flv") { + } else if strings.HasSuffix(streamUrl, ".flv") { streamName = "FLV" - } else if strings.HasSuffix(url, ".m3u8") { + } else if strings.HasSuffix(streamUrl, ".m3u8") { streamName = "HLS" - } else if strings.HasSuffix(url, ".rtc") { + } else if strings.HasSuffix(streamUrl, ".rtc") { streamName = "WEBRTC" - } else if strings.HasPrefix(url, "rtmp") { + } else if strings.HasPrefix(streamUrl, "rtmp") { streamName = "RTMP" - } else if strings.HasPrefix(url, "rtsp") { + } else if strings.HasPrefix(streamUrl, "rtsp") { streamName = "RTSP" } // 加上登录的token, 播放授权 - url += "?stream_token=" + v.Token + streamUrl += "?stream_token=" + v.Token // 兼容livegbs前端播放webrtc if streamName == "WEBRTC" { - if strings.HasPrefix(url, "http") { - url = strings.Replace(url, "http", "webrtc", 1) - } else if strings.HasPrefix(url, "https") { - url = strings.Replace(url, "https", "webrtcs", 1) + if strings.HasPrefix(streamUrl, "http") { + streamUrl = strings.Replace(streamUrl, "http", "webrtc", 1) + } else if strings.HasPrefix(streamUrl, "https") { + streamUrl = strings.Replace(streamUrl, "https", "webrtcs", 1) } - url += "&wf=livegbs" + streamUrl += "&wf=livegbs" } - urls[streamName] = url + urls[streamName] = streamUrl } response := LiveGBSStream{ @@ -674,7 +677,7 @@ func (api *ApiServer) DoInvite(inviteType common.InviteType, params *InviteParam if speed < 1 { speed = 4 } - d := stack.Device{device} + d := &stack.Device{DeviceModel: device} stream, err := d.StartStream(inviteType, params.streamId, params.ChannelID, startTimeSeconds, endTimeSeconds, params.Setup, speed, sync) if err != nil { return http.StatusInternalServerError, nil, err @@ -844,7 +847,7 @@ func (api *ApiServer) OnRecordList(v *QueryRecordParams, _ http.ResponseWriter, return nil, fmt.Errorf("设备离线") } - device := &stack.Device{model} + device := &stack.Device{DeviceModel: model} sn := stack.GetSN() err := device.QueryRecord(v.ChannelID, v.StartTime, v.EndTime, sn, "all") if err != nil { @@ -927,7 +930,7 @@ func (api *ApiServer) OnSubscribePosition(v *DeviceChannelID, _ http.ResponseWri return nil, fmt.Errorf("设备离线") } - device := &stack.Device{model} + device := &stack.Device{DeviceModel: model} if err := device.SubscribePosition(v.ChannelID); err != nil { log.Sugar.Errorf("订阅位置失败 err: %s", err.Error()) return nil, err @@ -945,7 +948,7 @@ func (api *ApiServer) OnSeekPlayback(v *SeekParams, _ http.ResponseWriter, _ *ht return nil, fmt.Errorf("stream不存在") } - stream := &stack.Stream{model} + stream := &stack.Stream{StreamModel: model} seekRequest := stream.CreateRequestFromDialog(sip.INFO) seq, _ := seekRequest.CSeq() body := fmt.Sprintf(stack.SeekBodyFormat, seq.SeqNo, v.Seconds) @@ -966,7 +969,7 @@ func (api *ApiServer) OnPTZControl(v *QueryRecordParams, _ http.ResponseWriter, return nil, fmt.Errorf("设备离线") } - device := &stack.Device{model} + device := &stack.Device{DeviceModel: model} device.ControlPTZ(v.Command, v.ChannelID) return "OK", nil @@ -976,8 +979,8 @@ func (api *ApiServer) OnHangup(v *BroadcastParams, _ http.ResponseWriter, _ *htt log.Sugar.Debugf("广播挂断 %v", *v) id := common.GenerateStreamID(common.InviteTypeBroadcast, v.DeviceID, v.ChannelID, "", "") - if sink, _ := dao.Sink.DeleteForwardSinkBySinkStreamID(id); sink != nil { - (&stack.Sink{sink}).Close(true, true) + if sink, _ := dao.Sink.DeleteSinkBySinkStreamID(id); sink != nil { + (&stack.Sink{SinkModel: sink}).Close(true, true) } return nil, nil @@ -986,171 +989,62 @@ func (api *ApiServer) OnHangup(v *BroadcastParams, _ http.ResponseWriter, _ *htt func (api *ApiServer) OnBroadcast(v *BroadcastParams, _ http.ResponseWriter, r *http.Request) (interface{}, error) { log.Sugar.Debugf("广播邀请 %v", *v) - var sinkStreamId common.StreamID - var InviteSourceId string - var ok bool - // 响应错误消息 - defer func() { - if !ok { - if InviteSourceId != "" { - stack.EarlyDialogs.Remove(InviteSourceId) - } - - if sinkStreamId != "" { - _, _ = dao.Sink.DeleteForwardSinkBySinkStreamID(sinkStreamId) - } - } - }() - model, _ := dao.Device.QueryDevice(v.DeviceID) if model == nil || !model.Online() { - log.Sugar.Errorf("广播失败, 设备离线, DeviceID: %s", v.DeviceID) return nil, fmt.Errorf("设备离线") } // 主讲人id - stream, _ := dao.Stream.QueryStream(v.StreamId) - if stream == nil { - log.Sugar.Errorf("广播失败, 找不到主讲人, stream: %s", v.StreamId) - return nil, fmt.Errorf("找不到主讲人") - } + //stream, _ := dao.Stream.QueryStream(v.StreamId) + //if stream == nil { + // return nil, fmt.Errorf("找不到主讲人") + //} - // 生成下级设备Invite请求携带的user - // server用于区分是哪个设备的广播 - - InviteSourceId = string(v.StreamId) + utils.RandStringBytes(10) - // 每个设备的广播唯一ID - sinkStreamId = common.GenerateStreamID(common.InviteTypeBroadcast, v.DeviceID, v.ChannelID, "", "") - - setupType := common.SetupTypePassive - if v.Setup != nil && *v.Setup >= common.SetupTypeUDP && *v.Setup <= common.SetupTypeActive { - setupType = *v.Setup - } - - sink := &dao.SinkModel{ - StreamID: v.StreamId, - SinkStreamID: sinkStreamId, - Protocol: stack.SourceTypeGBTalk, - CreateTime: time.Now().Unix(), - SetupType: setupType, - } - - streamWaiting := &stack.StreamWaiting{Data: sink} - if err := dao.Sink.SaveForwardSink(sink); err != nil { - log.Sugar.Errorf("广播失败, 设备正在广播中. stream: %s", sinkStreamId) - return nil, fmt.Errorf("设备正在广播中") - } else if _, ok = stack.EarlyDialogs.Add(InviteSourceId, streamWaiting); !ok { - log.Sugar.Errorf("广播失败, id冲突. id: %s", InviteSourceId) - return nil, fmt.Errorf("id冲突") - } - - ok = false - cancel := r.Context() - device := stack.Device{model} - transaction := device.Broadcast(InviteSourceId, v.ChannelID) - responses := transaction.Responses() - select { - // 等待message broadcast的应答 - case response := <-responses: - if response == nil { - log.Sugar.Errorf("广播失败, 信令超时. stream: %s", sinkStreamId) - return nil, fmt.Errorf("信令超时") - } - - if response.StatusCode() != http.StatusOK { - log.Sugar.Errorf("广播失败, 错误响应, status code: %d", response.StatusCode()) - return nil, fmt.Errorf("错误响应 code: %d", response.StatusCode()) - } - - // 等待下级设备的Invite请求 - code := streamWaiting.Receive(10) - if code == -1 { - log.Sugar.Errorf("广播失败, 等待invite超时. stream: %s", sinkStreamId) - return nil, fmt.Errorf("等待invite超时") - } else if http.StatusOK != code { - log.Sugar.Errorf("广播失败, 下级设备invite失败. stream: %s", sinkStreamId) - return nil, fmt.Errorf("错误应答 code: %d", code) - } else { - //ok = AddForwardSink(v.StreamId, sink) - ok = true - } - break - case <-cancel.Done(): - // http请求取消 - log.Sugar.Warnf("广播失败, http请求取消. session: %s", sinkStreamId) - break - } - - return nil, nil + device := &stack.Device{DeviceModel: model} + _, err := device.StartBroadcast(v.StreamId, v.DeviceID, v.ChannelID, r.Context()) + return nil, err } func (api *ApiServer) OnTalk(w http.ResponseWriter, r *http.Request) { - //vars := mux.Vars(r) - //device := vars["device"] - //channel := vars["channel"] - format := r.URL.Query().Get("format") + vars := mux.Vars(r) + deviceId := vars["device"] + channelId := vars["channel"] - // 升级HTTP连接到WebSocket - conn, err := api.upgrader.Upgrade(w, r, nil) - if err != nil { - log.Sugar.Errorf("WebSocket升级失败: %v", err) - return - } - defer conn.Close() - - parse, err := url.Parse(common.Config.MediaServer) - if err != nil { + _, online := stack.OnlineDeviceManager.Find(deviceId) + if !online { + w.WriteHeader(http.StatusBadRequest) + _ = common.HttpResponseJson(w, "设备离线") return } - // 目标WebSocket服务地址 - targetURL := fmt.Sprintf("ws://%s%s?format=%s", parse.Host, r.URL.Path, format) - - // 连接到目标WebSocket服务 - targetConn, _, err := websocket.DefaultDialer.Dial(targetURL, nil) + model, err := dao.Device.QueryDevice(deviceId) if err != nil { - log.Sugar.Errorf("连接目标WebSocket失败: %v", err) + w.WriteHeader(http.StatusBadRequest) + _ = common.HttpResponseJson(w, "设备不存在") return } - defer targetConn.Close() - group := sync.WaitGroup{} - group.Add(2) + // 目前只实现livegbs的一对一的对讲, stream id就是通道的广播id + streamid := common.GenerateStreamID(common.InviteTypeBroadcast, deviceId, channelId, "", "") + device := &stack.Device{DeviceModel: model} + ctx, _ := context.WithTimeout(context.Background(), time.Second*10) + sinkModel, err := device.StartBroadcast(streamid, deviceId, channelId, ctx) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + _ = common.HttpResponseJson(w, "广播失败") + return + } - // 启动两个goroutine双向转发数据 - // 从客户端转发到目标服务 - go func() { - defer group.Done() - for { - messageType, p, err := conn.ReadMessage() - if err != nil { - log.Sugar.Debugf("读取客户端消息错误: %v", err) - return - } - if err := targetConn.WriteMessage(messageType, p); err != nil { - log.Sugar.Debugf("写入目标服务消息错误: %v", err) - return - } - } - }() + err = common.WSForwardTo(r.URL.Path, w, r) + if err != nil { + log.Sugar.Errorf("广播失败 err: %s", err.Error()) + } - // 从目标服务转发到客户端 - go func() { - defer group.Done() - for { - messageType, p, err := targetConn.ReadMessage() - if err != nil { - log.Sugar.Debugf("读取目标服务消息错误: %v", err) - return - } - if err := conn.WriteMessage(messageType, p); err != nil { - log.Sugar.Debugf("写入客户端消息错误: %v", err) - return - } - } - }() + log.Sugar.Infof("广播结束 device: %s/%s", deviceId, channelId) - group.Wait() + // 对讲结束, 关闭sink + sink := &stack.Sink{SinkModel: sinkModel} + sink.Close(true, true) } func (api *ApiServer) OnStarted(_ http.ResponseWriter, _ *http.Request) { @@ -1158,12 +1052,12 @@ func (api *ApiServer) OnStarted(_ http.ResponseWriter, _ *http.Request) { streams, _ := dao.Stream.DeleteStreams() for _, stream := range streams { - (&stack.Stream{stream}).Close(true, false) + (&stack.Stream{StreamModel: stream}).Close(true, false) } - sinks, _ := dao.Sink.DeleteForwardSinks() + sinks, _ := dao.Sink.DeleteSinks() for _, sink := range sinks { - (&stack.Sink{sink}).Close(true, false) + (&stack.Sink{SinkModel: sink}).Close(true, false) } } @@ -1372,7 +1266,7 @@ func (api *ApiServer) OnCatalogQuery(params *QueryDeviceChannel, _ http.Response return nil, fmt.Errorf("not found device") } - list, err := (&stack.Device{deviceModel}).QueryCatalog(15) + list, err := (&stack.Device{DeviceModel: deviceModel}).QueryCatalog(15) if err != nil { return nil, err } @@ -1437,7 +1331,7 @@ func (api *ApiServer) OnSessionList(q *QueryDeviceChannel, _ http.ResponseWriter var n int n, err = resp.Body.Read(bytes) - resp.Body.Close() + _ = resp.Body.Close() if n < 1 { break } @@ -1456,7 +1350,7 @@ func (api *ApiServer) OnSessionList(q *QueryDeviceChannel, _ http.ResponseWriter return &response, nil } -func (api *ApiServer) OnSessionStop(params *StreamIDParams, w http.ResponseWriter, req *http.Request) (interface{}, error) { +func (api *ApiServer) OnSessionStop(params *StreamIDParams, _ http.ResponseWriter, _ *http.Request) (interface{}, error) { err := stack.MSCloseSource(string(params.StreamID)) if err != nil { return nil, err @@ -1465,7 +1359,7 @@ func (api *ApiServer) OnSessionStop(params *StreamIDParams, w http.ResponseWrite return "OK", nil } -func (api *ApiServer) OnDeviceTree(q *QueryDeviceChannel, w http.ResponseWriter, req *http.Request) (interface{}, error) { +func (api *ApiServer) OnDeviceTree(q *QueryDeviceChannel, _ http.ResponseWriter, _ *http.Request) (interface{}, error) { var response []*LiveGBSDeviceTree // 查询所有设备 @@ -1506,7 +1400,7 @@ func (api *ApiServer) OnDeviceTree(q *QueryDeviceChannel, w http.ResponseWriter, return &response, nil } -func (api *ApiServer) OnDeviceRemove(q *DeleteDevice, w http.ResponseWriter, req *http.Request) (interface{}, error) { +func (api *ApiServer) OnDeviceRemove(q *DeleteDevice, _ http.ResponseWriter, _ *http.Request) (interface{}, error) { var err error if q.IP != "" { // 删除IP下的所有设备 @@ -1534,7 +1428,7 @@ func (api *ApiServer) OnDeviceRemove(q *DeleteDevice, w http.ResponseWriter, req return "OK", nil } -func (api *ApiServer) OnEnableSet(params *SetEnable, w http.ResponseWriter, req *http.Request) (interface{}, error) { +func (api *ApiServer) OnEnableSet(params *SetEnable, _ http.ResponseWriter, _ *http.Request) (interface{}, error) { model, err := dao.Platform.QueryPlatformByID(params.ID) if err != nil { return nil, err @@ -1637,7 +1531,7 @@ func (api *ApiServer) OnPlatformChannelList(q *QueryCascadeChannelList, w http.R return &response, nil } -func (api *ApiServer) OnShareAllChannel(q *SetEnable, w http.ResponseWriter, req *http.Request) (interface{}, error) { +func (api *ApiServer) OnShareAllChannel(q *SetEnable, _ http.ResponseWriter, _ *http.Request) (interface{}, error) { var err error if q.ShareAllChannel { // 删除所有已经绑定的通道, 设置级联所有通道为true @@ -1656,7 +1550,7 @@ func (api *ApiServer) OnShareAllChannel(q *SetEnable, w http.ResponseWriter, req return "OK", nil } -func (api *ApiServer) OnCustomChannelSet(q *CustomChannel, w http.ResponseWriter, req *http.Request) (interface{}, error) { +func (api *ApiServer) OnCustomChannelSet(q *CustomChannel, _ http.ResponseWriter, _ *http.Request) (interface{}, error) { if len(q.CustomID) != 20 { return nil, fmt.Errorf("20位国标ID") } @@ -1668,7 +1562,7 @@ func (api *ApiServer) OnCustomChannelSet(q *CustomChannel, w http.ResponseWriter return "OK", nil } -func (api *ApiServer) OnCatalogPush(q *SetEnable, w http.ResponseWriter, req *http.Request) (interface{}, error) { +func (api *ApiServer) OnCatalogPush(_ *SetEnable, _ http.ResponseWriter, _ *http.Request) (interface{}, error) { return "OK", nil } @@ -1680,7 +1574,7 @@ func (api *ApiServer) OnRecordStop(writer http.ResponseWriter, request *http.Req common.HttpForwardTo("/api/v1/record/stop", writer, request) } -func (api *ApiServer) OnPlaybackControl(params *StreamIDParams, w http.ResponseWriter, req *http.Request) (interface{}, error) { +func (api *ApiServer) OnPlaybackControl(params *StreamIDParams, _ http.ResponseWriter, _ *http.Request) (interface{}, error) { if "scale" != params.Command || params.Scale <= 0 || params.Scale > 4 { return nil, errors.New("scale error") } @@ -1698,7 +1592,7 @@ func (api *ApiServer) OnPlaybackControl(params *StreamIDParams, w http.ResponseW return nil, err } - s := stack.Device{device} + s := &stack.Device{DeviceModel: device} s.ScalePlayback(stream.Dialog, params.Scale) err = stack.MSSpeedSet(string(params.StreamID), params.Scale) if err != nil { diff --git a/common/http_proxy.go b/common/http_proxy.go index e63ff57..e2c88f8 100644 --- a/common/http_proxy.go +++ b/common/http_proxy.go @@ -3,10 +3,12 @@ package common import ( "bytes" "fmt" + "github.com/pretty66/websocketproxy" "io" "net/http" "net/http/httputil" "net/url" + "strings" ) func HttpForwardTo(path string, w http.ResponseWriter, r *http.Request) { @@ -52,3 +54,22 @@ func HttpForwardTo(path string, w http.ResponseWriter, r *http.Request) { proxy.ServeHTTP(w, r) } + +func WSForwardTo(path string, w http.ResponseWriter, r *http.Request) error { + hostport := Config.MediaServer + if strings.HasPrefix(Config.MediaServer, "https") { + hostport = "wss" + Config.MediaServer[5:] + } else if strings.HasPrefix(Config.MediaServer, "http") { + hostport = "ws" + Config.MediaServer[4:] + } + + wp, err := websocketproxy.NewProxy(fmt.Sprintf("%s%s", hostport, path), func(r *http.Request) error { + return nil + }) + + if err == nil { + wp.Proxy(w, r) + } + + return err +} diff --git a/dao/channel.go b/dao/channel.go index f9b6598..29b7bd5 100644 --- a/dao/channel.go +++ b/dao/channel.go @@ -309,3 +309,12 @@ func (d *daoChannel) QueryOnlineSubChannelCount(rootId string, groupId string, h func (d *daoChannel) UpdateCustomID(rootId, channelId string, customID string) error { return db.Model(&ChannelModel{}).Where("root_id =? and device_id =?", rootId, channelId).Update("custom_id", customID).Error } + +func (d *daoChannel) QueryChannelsByParentID(rootId string, parentId string) ([]*ChannelModel, error) { + var channels []*ChannelModel + tx := db.Where("root_id =? and parent_id =?", rootId, parentId).Find(&channels) + if tx.Error != nil { + return nil, tx.Error + } + return channels, nil +} diff --git a/dao/device.go b/dao/device.go index d1e51a6..884d260 100644 --- a/dao/device.go +++ b/dao/device.go @@ -144,6 +144,19 @@ func (d *daoDevice) QueryDevice(id string) (*DeviceModel, error) { return &device, nil } +// QueryDeviceByAddr 根据地址查询设备 +func (d *daoDevice) QueryDeviceByAddr(addr string) (*DeviceModel, error) { + host, p, _ := net.SplitHostPort(addr) + port, _ := strconv.Atoi(p) + var device DeviceModel + tx := db.Where("remote_ip = ? and remote_port = ?", host, port).Take(&device) + if tx.Error != nil { + return nil, tx.Error + } + + return &device, nil +} + func (d *daoDevice) QueryDevices(page int, size int, status string, keyword string, order string) ([]*DeviceModel, int, error) { var cond = make(map[string]interface{}) if status != "" { diff --git a/dao/sink.go b/dao/sink.go index 27bd1c0..d0a1a28 100644 --- a/dao/sink.go +++ b/dao/sink.go @@ -8,10 +8,10 @@ import ( // SinkModel 级联/对讲/网关转发流Sink type SinkModel struct { GBModel - SinkID string `json:"sink_id"` // 流媒体服务器中的sink id - StreamID common.StreamID `json:"stream_id"` // 所属的推流ID - SinkStreamID common.StreamID `json:"sink_stream_id"` // 广播使用, 每个广播设备的唯一ID - Protocol int `json:"protocol,omitempty"` // 拉流协议, @See stack.TransStreamRtmp + SinkID string `json:"sink_id"` // 流媒体服务器中的sink id + StreamID common.StreamID `json:"stream_id"` // 拉取流的id, 目前和source id一致 + SinkStreamID common.StreamID `json:"sink_stream_id" gorm:"unique"` // 广播使用, 每个广播设备的唯一ID + Protocol int `json:"protocol,omitempty"` // 拉流协议, @See stack.TransStreamRtmp Dialog *common.RequestWrapper `json:"dialog,omitempty"` CallID string `json:"call_id,omitempty"` ServerAddr string `json:"server_addr,omitempty"` // 级联上级地址 @@ -27,7 +27,7 @@ func (d *SinkModel) TableName() string { type daoSink struct { } -func (d *daoSink) LoadForwardSinks() (map[string]*SinkModel, error) { +func (d *daoSink) LoadSinks() (map[string]*SinkModel, error) { var sinks []*SinkModel tx := db.Find(&sinks) if tx.Error != nil { @@ -41,13 +41,13 @@ func (d *daoSink) LoadForwardSinks() (map[string]*SinkModel, error) { return sinkMap, nil } -func (d *daoSink) QueryForwardSink(stream common.StreamID, sinkId string) (*SinkModel, error) { +func (d *daoSink) QuerySink(stream common.StreamID, sinkId string) (*SinkModel, error) { var sink SinkModel db.Where("stream_id =? and sink_id =?", stream, sinkId).Take(&sink) return &sink, db.Error } -func (d *daoSink) QueryForwardSinks(stream common.StreamID) (map[string]*SinkModel, error) { +func (d *daoSink) QuerySinks(stream common.StreamID) (map[string]*SinkModel, error) { var sinks []*SinkModel tx := db.Where("stream_id =?", stream).Find(&sinks) if tx.Error != nil { @@ -61,13 +61,19 @@ func (d *daoSink) QueryForwardSinks(stream common.StreamID) (map[string]*SinkMod return sinkMap, nil } -func (d *daoSink) SaveForwardSink(sink *SinkModel) error { +func (d *daoSink) CreateSink(sink *SinkModel) error { return DBTransaction(func(tx *gorm.DB) error { return tx.Create(sink).Error }) } -func (d *daoSink) DeleteForwardSink(sinkId string) (*SinkModel, error) { +func (d *daoSink) SaveSink(sink *SinkModel) error { + return DBTransaction(func(tx *gorm.DB) error { + return tx.Save(sink).Error + }) +} + +func (d *daoSink) DeleteSink(sinkId string) (*SinkModel, error) { var sink SinkModel tx := db.Where("sink_id =?", sinkId).Take(&sink) if tx.Error != nil { @@ -79,7 +85,7 @@ func (d *daoSink) DeleteForwardSink(sinkId string) (*SinkModel, error) { }) } -func (d *daoSink) DeleteForwardSinksByStreamID(stream common.StreamID) ([]*SinkModel, error) { +func (d *daoSink) DeleteSinksByStreamID(stream common.StreamID) ([]*SinkModel, error) { var sinks []*SinkModel tx := db.Where("stream_id =?", stream).Find(&sinks) if tx.Error != nil { @@ -91,7 +97,7 @@ func (d *daoSink) DeleteForwardSinksByStreamID(stream common.StreamID) ([]*SinkM }) } -func (d *daoSink) QueryForwardSinkByCallID(callID string) (*SinkModel, error) { +func (d *daoSink) QuerySinkByCallID(callID string) (*SinkModel, error) { var sinks SinkModel tx := db.Where("call_id =?", callID).Find(&sinks) if tx.Error != nil { @@ -101,7 +107,7 @@ func (d *daoSink) QueryForwardSinkByCallID(callID string) (*SinkModel, error) { return &sinks, nil } -func (d *daoSink) DeleteForwardSinkByCallID(callID string) (*SinkModel, error) { +func (d *daoSink) DeleteSinkByCallID(callID string) (*SinkModel, error) { var sink SinkModel tx := db.Where("call_id =?", callID).First(&sink) if tx.Error != nil { @@ -113,7 +119,7 @@ func (d *daoSink) DeleteForwardSinkByCallID(callID string) (*SinkModel, error) { }) } -func (d *daoSink) DeleteForwardSinkBySinkStreamID(sinkStreamId common.StreamID) (*SinkModel, error) { +func (d *daoSink) DeleteSinkBySinkStreamID(sinkStreamId common.StreamID) (*SinkModel, error) { var sink SinkModel tx := db.Where("sink_stream_id =?", sinkStreamId).First(&sink) if tx.Error != nil { @@ -125,7 +131,7 @@ func (d *daoSink) DeleteForwardSinkBySinkStreamID(sinkStreamId common.StreamID) }) } -func (d *daoSink) DeleteForwardSinks() ([]*SinkModel, error) { +func (d *daoSink) DeleteSinks() ([]*SinkModel, error) { var sinks []*SinkModel tx := db.Find(&sinks) if tx.Error != nil { @@ -137,13 +143,13 @@ func (d *daoSink) DeleteForwardSinks() ([]*SinkModel, error) { }) } -func (d *daoSink) DeleteForwardSinksByIds(ids []uint) error { +func (d *daoSink) DeleteSinksByIds(ids []uint) error { return DBTransaction(func(tx *gorm.DB) error { return tx.Where("id in?", ids).Unscoped().Delete(&SinkModel{}).Error }) } -func (d *daoSink) DeleteForwardSinksByServerAddr(addr string) ([]*SinkModel, error) { +func (d *daoSink) DeleteSinksByServerAddr(addr string) ([]*SinkModel, error) { var sinks []*SinkModel tx := db.Where("server_addr =?", addr).Find(&sinks) if tx.Error != nil { @@ -190,3 +196,12 @@ func (d *daoSink) QueryStreamIds(protocols []int, page, size int) ([]string, int return streamIds, int(total), nil } + +func (d *daoSink) QuerySinkBySinkStreamID(sinkStreamId common.StreamID) (*SinkModel, error) { + var sink SinkModel + tx := db.Where("sink_stream_id =?", sinkStreamId).First(&sink) + if tx.Error != nil { + return nil, tx.Error + } + return &sink, nil +} diff --git a/go.mod b/go.mod index 9c41070..86dc6be 100644 --- a/go.mod +++ b/go.mod @@ -52,6 +52,7 @@ require ( github.com/gorilla/mux v1.8.1 github.com/gorilla/websocket v1.5.3 github.com/lkmio/avformat v0.0.1 + github.com/pretty66/websocketproxy v0.0.0-20220507015215-930b3a686308 github.com/shirou/gopsutil/v3 v3.24.5 gorm.io/gorm v1.26.1 ) diff --git a/main.go b/main.go index 5aa168a..e0d6ea2 100644 --- a/main.go +++ b/main.go @@ -9,7 +9,9 @@ import ( "gb-cms/hook" "gb-cms/log" "gb-cms/stack" + "github.com/pretty66/websocketproxy" "github.com/shirou/gopsutil/v3/host" + "go.uber.org/zap" "go.uber.org/zap/zapcore" "net" "net/http" @@ -38,6 +40,7 @@ func init() { } log.InitLogger(zapcore.Level(logConfig.Level), logConfig.Name, logConfig.MaxSize, logConfig.MaxBackup, logConfig.MaxAge, logConfig.Compress) + websocketproxy.SetLogger(zap.NewStdLog(log.Sugar.Desugar())) } func main() { diff --git a/recover.go b/recover.go index d0a6ef9..f3730f6 100644 --- a/recover.go +++ b/recover.go @@ -66,7 +66,7 @@ func recoverStreams() (map[string]*dao.StreamModel, map[string]*dao.SinkModel) { return nil, nil } - dbSinks, _ := dao.Sink.LoadForwardSinks() + dbSinks, _ := dao.Sink.LoadSinks() // 查询流媒体服务器中的推流源列表 msSources, err := stack.MSQuerySourceList() @@ -118,7 +118,7 @@ func recoverStreams() (map[string]*dao.StreamModel, map[string]*dao.SinkModel) { } _ = dao.Stream.DeleteStreamsByIds(invalidStreamIds) - _ = dao.Sink.DeleteForwardSinksByIds(invalidSinkIds) + _ = dao.Sink.DeleteSinksByIds(invalidSinkIds) return dbStreams, dbSinks } diff --git a/stack/broadcast.go b/stack/broadcast.go index 95c13d7..6c35332 100644 --- a/stack/broadcast.go +++ b/stack/broadcast.go @@ -1,11 +1,15 @@ package stack import ( + "context" "fmt" "gb-cms/common" + "gb-cms/dao" "gb-cms/log" "github.com/ghettovoice/gosip/sip" + "github.com/lkmio/avformat/utils" "net/http" + "time" ) const ( @@ -18,24 +22,100 @@ const ( "\r\n" ) -func (d *Device) DoBroadcast(sourceId, channelId string) error { - body := fmt.Sprintf(BroadcastFormat, 1, sourceId, channelId) - request := d.BuildMessageRequest(channelId, body) +func (d *Device) StartBroadcast(streamId common.StreamID, deviceId, channelId string, timeoutCtx context.Context) (*dao.SinkModel, error) { + // 生成sinkstreamid, 该通道的唯一广播id + sinkStreamId := common.GenerateStreamID(common.InviteTypeBroadcast, deviceId, channelId, "", "") + // 生成source id, 关联会话. 下发broadcast message告知设备, 设备的invite请求行将携带 + inviteSourceId := utils.RandStringBytes(20) - common.SipStack.SendRequest(request) - return nil + var ok bool + defer func() { + EarlyDialogs.Remove(inviteSourceId) + EarlyDialogs.Remove(d.DeviceID) + if !ok { + _, _ = dao.Sink.DeleteSinkBySinkStreamID(sinkStreamId) + } + }() + + sink := &dao.SinkModel{ + SinkStreamID: sinkStreamId, + StreamID: streamId, + Protocol: SourceTypeGBTalk, + CreateTime: time.Now().Unix(), + SetupType: common.SetupTypePassive, + } + + // 保存sink, 保存失败认为该设备正在广播 + if err := dao.Sink.CreateSink(sink); err != nil { + return nil, err + } + + // 查找音频输出通道 + var audioChannelId = channelId + if subChannels, err := dao.Channel.QueryChannelsByParentID(deviceId, channelId); err == nil { + for _, channel := range subChannels { + if 137 != channel.TypeCode { + continue + } + + audioChannelId = channel.DeviceID + break + } + } + + // 关联会话 + streamWaiting := &StreamWaiting{Data: sink} + if _, ok = EarlyDialogs.Add(inviteSourceId, streamWaiting); !ok { + return nil, fmt.Errorf("id冲突") + } else if _, ok = EarlyDialogs.Add(d.DeviceID, streamWaiting); !ok { + // 使用设备ID关联下会话, 兼容不标准的下级设备. 如果下级设备都不标准,意味着同时只能对一个通道发起对讲. + } + + // 信令交互 + transaction := d.Broadcast(inviteSourceId, audioChannelId) + responses := transaction.Responses() + select { + // 等待message broadcast的应答 + case response := <-responses: + if response == nil { + return nil, fmt.Errorf("信令超时") + } + + if response.StatusCode() != http.StatusOK { + return nil, fmt.Errorf("错误响应 code: %d", response.StatusCode()) + } + + // 等待下级设备的Invite请求 + code := streamWaiting.Receive(10) + if code == -1 { + return nil, fmt.Errorf("等待invite超时") + } else if http.StatusOK != code { + return nil, fmt.Errorf("错误应答 code: %d", code) + } else { + ok = true + return sink, nil + } + case <-timeoutCtx.Done(): + // 外部调用超时 + streamWaiting.Put(-1) + break + } + + return nil, fmt.Errorf("广播失败") } -// OnInvite 语音广播 +// OnInvite 收到设备的语音广播offer func (d *Device) OnInvite(request sip.Request, user string) sip.Response { - // 会话是否存在 + // 查找会话, 先用source id查找, 找不到再根据设备id查找 streamWaiting := EarlyDialogs.Find(user) - if streamWaiting == nil { - return CreateResponseWithStatusCode(request, http.StatusBadRequest) + if streamWaiting != nil { + if streamWaiting = EarlyDialogs.Find(d.DeviceID); streamWaiting == nil { + return CreateResponseWithStatusCode(request, http.StatusBadRequest) + } } // 解析offer - sink := streamWaiting.Data.(*Sink) + sink := streamWaiting.Data.(*dao.SinkModel) body := request.Body() offer, err := ParseGBSDP(body) if err != nil { @@ -53,7 +133,8 @@ func (d *Device) OnInvite(request sip.Request, user string) sip.Response { offer.AnswerSetup = sink.SetupType } - response, err := AddForwardSink(TransStreamGBTalk, request, user, sink, sink.StreamID, offer, common.InviteTypeBroadcast, "8 PCMA/8000") + // 添加sink到流媒体服务器 + response, err := AddForwardSink(TransStreamGBTalk, request, user, &Sink{sink}, sink.StreamID, offer, common.InviteTypeBroadcast, "8 PCMA/8000") if err != nil { log.Sugar.Errorf("广播失败, 流媒体创建answer发生err: %s sink: %s ", err.Error(), sink.SinkID) streamWaiting.Put(http.StatusInternalServerError) diff --git a/stack/platform.go b/stack/platform.go index 58bf087..9fea91d 100644 --- a/stack/platform.go +++ b/stack/platform.go @@ -39,7 +39,7 @@ func (g *Platform) OnQueryCatalog(sn int, channels []*dao.ChannelModel) { // CloseStream 关闭级联会话 func (g *Platform) CloseStream(callId string, bye, ms bool) { - sink, _ := dao.Sink.DeleteForwardSinkByCallID(callId) + sink, _ := dao.Sink.DeleteSinkByCallID(callId) if sink != nil { (&Sink{sink}).Close(bye, ms) } @@ -47,7 +47,7 @@ func (g *Platform) CloseStream(callId string, bye, ms bool) { // CloseStreams 关闭所有级联会话 func (g *Platform) CloseStreams(bye, ms bool) { - sinks, _ := dao.Sink.DeleteForwardSinksByServerAddr(g.ServerAddr) + sinks, _ := dao.Sink.DeleteSinksByServerAddr(g.ServerAddr) for _, sink := range sinks { (&Sink{sink}).Close(bye, ms) } diff --git a/stack/sink.go b/stack/sink.go index 4dce8f3..faf88f2 100644 --- a/stack/sink.go +++ b/stack/sink.go @@ -26,6 +26,13 @@ func (s *Sink) Close(bye, ms bool) { if ms { go MSCloseSink(string(s.StreamID), s.SinkID) } + + // 目前只有一对一对讲, 断开就删除整个websocket对讲流 + if s.Protocol == TransStreamGBTalk { + _, _ = dao.Stream.DeleteStream(s.StreamID) + // 删除流媒体source + _ = MSCloseSource(string(s.StreamID)) + } } func (s *Sink) MarshalJSON() ([]byte, error) { @@ -121,7 +128,7 @@ func AddForwardSink(forwardType int, request sip.Request, user string, sink *Sin sink.SetDialog(CreateDialogRequestFromAnswer(response, true, request.Source())) - if err = dao.Sink.SaveForwardSink(sink.SinkModel); err != nil { + if err = dao.Sink.SaveSink(sink.SinkModel); err != nil { log.Sugar.Errorf("保存sink到数据库失败, stream: %s sink: %s err: %s", streamId, sink.SinkID, err.Error()) } diff --git a/stack/sip_server.go b/stack/sip_server.go index 9323c3e..d6bc92b 100644 --- a/stack/sip_server.go +++ b/stack/sip_server.go @@ -131,12 +131,9 @@ func (s *sipServer) OnInvite(wrapper *SipRequestSource) { device = JTDeviceManager.Find(channels[0].RootID) } } else { - if session := EarlyDialogs.Find(user); session != nil { + if model, _ := dao.Device.QueryDeviceByAddr(wrapper.req.Source()); model != nil { // 语音广播设备 - model, _ := dao.Device.QueryDevice(session.Data.(*Sink).SinkStreamID.DeviceID()) - if model != nil { - device = &Device{model} - } + device = &Device{model} } else { // 根据Subject头域查找设备 headers := wrapper.req.GetHeaders("Subject") @@ -179,7 +176,7 @@ func (s *sipServer) OnBye(wrapper *SipRequestSource) { // 下级设备挂断, 关闭流 deviceId = stream.StreamID.DeviceID() (&Stream{stream}).Close(false, true) - } else if sink, _ := dao.Sink.DeleteForwardSinkByCallID(id.Value()); sink != nil { + } else if sink, _ := dao.Sink.DeleteSinkByCallID(id.Value()); sink != nil { (&Sink{sink}).Close(false, true) } @@ -360,9 +357,9 @@ func (s *sipServer) ListenAddr() string { // 过滤SIP消息、超找消息来源 func filterRequest(f func(wrapper *SipRequestSource)) gosip.RequestHandler { return func(req sip.Request, tx sip.ServerTransaction) { - userAgent := req.GetHeaders("User-Agent") // 过滤黑名单 + userAgent := req.GetHeaders("User-Agent") if model, _ := dao.Blacklist.QueryIP(req.Source()); model != nil { SendResponseWithStatusCode(req, tx, http.StatusForbidden) log2.Sugar.Errorf("处理%s请求失败, IP被黑名单过滤: %s request: %s ", req.Method(), req.Source(), req.String()) @@ -375,6 +372,7 @@ func filterRequest(f func(wrapper *SipRequestSource)) gosip.RequestHandler { } } + // 查找请求来源: 下级设备/级联上级/1078转GB28181的上级 source := req.Source() // 是否是级联上级下发的请求 platform := PlatformManager.Find(source) diff --git a/stack/stream.go b/stack/stream.go index 89eaeca..c1fca29 100644 --- a/stack/stream.go +++ b/stack/stream.go @@ -116,7 +116,7 @@ func CloseStreamByCallID(callId string) { // CloseStreamSinks 关闭某个流的所有sink func CloseStreamSinks(StreamID common.StreamID, bye, ms bool) []*dao.SinkModel { - sinks, _ := dao.Sink.DeleteForwardSinksByStreamID(StreamID) + sinks, _ := dao.Sink.DeleteSinksByStreamID(StreamID) for _, sink := range sinks { (&Sink{sink}).Close(bye, ms) }