diff --git a/api.go b/api.go index e53801f..fd4f976 100644 --- a/api.go +++ b/api.go @@ -28,7 +28,7 @@ type InviteParams struct { EndTime string `json:"end_time"` Setup string `json:"setup"` Speed string `json:"speed"` - streamId string + streamId StreamID } var apiServer *ApiServer @@ -45,16 +45,16 @@ func init() { } } -func withHookParams(f func(streamId, protocol string, w http.ResponseWriter, req *http.Request)) func(http.ResponseWriter, *http.Request) { +func withHookParams(f func(streamId StreamID, protocol string, w http.ResponseWriter, req *http.Request)) func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, req *http.Request) { if "" != req.URL.RawQuery { Sugar.Infof("on request %s?%s", req.URL.Path, req.URL.RawQuery) } v := struct { - Stream string `json:"stream"` //Stream id - Protocol string `json:"protocol"` //推拉流协议 - RemoteAddr string `json:"remote_addr"` //peer地址 + Stream StreamID `json:"stream"` //Stream id + Protocol string `json:"protocol"` //推拉流协议 + RemoteAddr string `json:"remote_addr"` //peer地址 }{} err := HttpDecodeJSONBody(w, req, &v) @@ -77,20 +77,26 @@ func startApiServer(addr string) { apiServer.router.HandleFunc("/api/v1/hook/on_record", withHookParams(apiServer.OnReceiveTimeout)) apiServer.router.HandleFunc("/api/v1/hook/on_started", apiServer.OnStarted) - //统一处理live/playback/download请求 + // 统一处理live/playback/download请求 apiServer.router.HandleFunc("/api/v1/{action}/start", apiServer.OnInvite) - apiServer.router.HandleFunc("/api/v1/stream/close", apiServer.OnCloseStream) //释放流(实时/回放/下载), 以拉流计数为准, 如果没有客户端拉流, 不等流媒体服务通知空闲超时,立即释放流,否则(还有客户端拉流)不会释放。 + apiServer.router.HandleFunc("/api/v1/stream/close", apiServer.OnCloseStream) // 释放流(实时/回放/下载), 实际以拉流计数为准, 如果没有客户端拉流, 不等流媒体服务通知空闲超时,立即释放流,否则(还有客户端拉流)不会释放。 - apiServer.router.HandleFunc("/api/v1/device/list", apiServer.OnDeviceList) //查询在线设备 - apiServer.router.HandleFunc("/api/v1/record/list", apiServer.OnRecordList) //查询录像列表 - apiServer.router.HandleFunc("/api/v1/position/sub", apiServer.OnSubscribePosition) //订阅移动位置 - apiServer.router.HandleFunc("/api/v1/playback/seek", apiServer.OnSeekPlayback) //回放seek - apiServer.router.HandleFunc("/api/v1/ptz/control", apiServer.OnPTZControl) //云台控制 + apiServer.router.HandleFunc("/api/v1/device/list", apiServer.OnDeviceList) // 查询在线设备 + apiServer.router.HandleFunc("/api/v1/record/list", apiServer.OnRecordList) // 查询录像列表 + apiServer.router.HandleFunc("/api/v1/position/sub", apiServer.OnSubscribePosition) // 订阅移动位置 + apiServer.router.HandleFunc("/api/v1/playback/seek", apiServer.OnSeekPlayback) // 回放seek + apiServer.router.HandleFunc("/api/v1/ptz/control", apiServer.OnPTZControl) // 云台控制 - apiServer.router.HandleFunc("/ws/v1/talk", apiServer.OnWSTalk) //语音广播/对讲, 音频传输链路 - apiServer.router.HandleFunc("/api/v1/broadcast/invite", apiServer.OnBroadcast) //语音广播 - apiServer.router.HandleFunc("/api/v1/broadcast/hangup", apiServer.OnHangup) //挂断广播会话 - apiServer.router.HandleFunc("/api/v1/talk", apiServer.OnTalk) //语音对讲 + apiServer.router.HandleFunc("/api/v1/platform/add", apiServer.OnPlatformAdd) // 添加上级平台 + apiServer.router.HandleFunc("/api/v1/platform/remove", apiServer.OnPlatformRemove) // 删除上级平台 + apiServer.router.HandleFunc("/api/v1/platform/list", apiServer.OnPlatformList) // 上级平台列表 + apiServer.router.HandleFunc("/api/v1/platform/channel/bind", apiServer.OnPlatformChannelBind) // 级联绑定通道 + apiServer.router.HandleFunc("/api/v1/platform/channel/unbind", apiServer.OnPlatformChannelUnbind) // 级联取消绑定通道 + + apiServer.router.HandleFunc("/ws/v1/talk", apiServer.OnWSTalk) // 语音广播/对讲, 主讲音频传输链路 + apiServer.router.HandleFunc("/api/v1/broadcast/invite", apiServer.OnBroadcast) // 发起语音广播 + apiServer.router.HandleFunc("/api/v1/broadcast/hangup", apiServer.OnHangup) // 挂断广播会话 + apiServer.router.HandleFunc("/api/v1/talk", apiServer.OnTalk) // 语音对讲 apiServer.router.HandleFunc("/broadcast.html", func(writer http.ResponseWriter, request *http.Request) { http.ServeFile(writer, request, "./broadcast.html") }) @@ -115,197 +121,28 @@ func startApiServer(addr string) { } } -func generateStreamId(inviteType InviteType, deviceId, channelId string, startTime, endTime string) string { - utils.Assert(channelId != "") - - var streamId []string - if deviceId != "" { - streamId = append(streamId, deviceId) - } - - streamId = append(streamId, channelId) - if InviteTypePlayback == inviteType { - return strings.Join(streamId, "/") + ".playback" + "." + startTime + "." + endTime - } else if InviteTypeDownload == inviteType { - return strings.Join(streamId, "/") + ".download" + "." + startTime + "." + endTime - } - - return strings.Join(streamId, "/") -} - -func (api *ApiServer) OnInvite(w http.ResponseWriter, r *http.Request) { - v := InviteParams{} - err := HttpDecodeJSONBody(w, r, &v) - if err != nil { - w.WriteHeader(http.StatusBadRequest) - return - } - - vars := mux.Vars(r) - action := strings.ToLower(vars["action"]) - if "playback" == action { - apiServer.DoInvite(InviteTypePlayback, v, true, w, r) - } else if "download" == action { - apiServer.DoInvite(InviteTypeDownload, v, true, w, r) - } else if "live" == action { - apiServer.DoInvite(InviteTypeLive, v, true, w, r) - } else { - w.WriteHeader(http.StatusNotFound) - } -} - -func (api *ApiServer) OnLiveStart(device *DBDevice, params InviteParams, streamId string, w http.ResponseWriter, r *http.Request) (sip.Request, bool, error) { - dialog, b := device.Live(streamId, params.ChannelID, params.Setup) - return dialog, b, nil -} - -func (api *ApiServer) OnPlaybackStart(device *DBDevice, params InviteParams, streamId string, w http.ResponseWriter, r *http.Request) (sip.Request, bool, error) { - startTime, err := time.ParseInLocation("2006-01-02t15:04:05", params.StartTime, time.Local) - if err != nil { - Sugar.Errorf("解析开始时间失败 err:%s start_time:%s", err.Error(), params.StartTime) - return nil, false, err - } - - endTime, err := time.ParseInLocation("2006-01-02t15:04:05", params.EndTime, time.Local) - if err != nil { - Sugar.Errorf("解析开始时间失败 err:%s start_time:%s", err.Error(), params.EndTime) - return nil, false, err - } - - startTimeSeconds := strconv.FormatInt(startTime.Unix(), 10) - endTimeSeconds := strconv.FormatInt(endTime.Unix(), 10) - dialog, b := device.Playback(streamId, params.ChannelID, startTimeSeconds, endTimeSeconds, params.Setup) - return dialog, b, nil -} - -func (api *ApiServer) OnDownloadStart(device *DBDevice, params InviteParams, streamId string, w http.ResponseWriter, r *http.Request) (sip.Request, bool, error) { - startTime, err := time.ParseInLocation("2006-01-02t15:04:05", params.StartTime, time.Local) - if err != nil { - Sugar.Errorf("解析开始时间失败 err:%s start_time:%s", err.Error(), params.StartTime) - return nil, false, err - } - - endTime, err := time.ParseInLocation("2006-01-02t15:04:05", params.EndTime, time.Local) - if err != nil { - Sugar.Errorf("解析开始时间失败 err:%s start_time:%s", err.Error(), params.EndTime) - return nil, false, err - } - - startTimeSeconds := strconv.FormatInt(startTime.Unix(), 10) - endTimeSeconds := strconv.FormatInt(endTime.Unix(), 10) - speed, _ := strconv.Atoi(params.Speed) - speed = int(math.Min(4, float64(speed))) - - dialog, b := device.Download(streamId, params.ChannelID, params.StartTime, startTimeSeconds, endTimeSeconds, speed) - return dialog, b, nil -} - -// DoInvite 处理Invite请求 -// @params sync 是否异步等待流媒体的publish事件(确认收到流), 目前请求流分两种方式,流媒体hook和http接口, hook方式同步等待确认收到流再应答, http接口直接应答成功。 -func (api *ApiServer) DoInvite(inviteType InviteType, params InviteParams, sync bool, w http.ResponseWriter, r *http.Request) (*Stream, bool) { - device := DeviceManager.Find(params.DeviceID) - if device == nil { - Sugar.Warnf("设备离线 id:%s", params.DeviceID) - return nil, false - } - - streamId := params.streamId - if streamId == "" { - streamId = generateStreamId(inviteType, device.Id, params.ChannelID, params.StartTime, params.EndTime) - } - stream := &Stream{ - Id: streamId, - Protocol: "28181", - StreamType: inviteType, - } - - var inviteOK bool - var publishOK bool - defer func() { - if !inviteOK { - StreamManager.Remove(streamId) - w.WriteHeader(http.StatusInternalServerError) - } else if !publishOK { - CloseStream(streamId) - w.WriteHeader(http.StatusInternalServerError) - } else { - response := map[string]string{ - "stream_id": streamId, - } - - httpResponseOK(w, response) - } - }() - - //如果添加Stream失败, 说明Stream已经存在 - if stream, ok := StreamManager.Add(stream); !ok { - Sugar.Infof("stream %s 已经存在", streamId) - inviteOK = true - publishOK = true - return stream, true - } - - var dialog sip.Request - var err error - if InviteTypePlayback == inviteType { - dialog, inviteOK, err = api.OnPlaybackStart(device, params, streamId, w, r) - } else if InviteTypeDownload == inviteType { - dialog, inviteOK, err = api.OnDownloadStart(device, params, streamId, w, r) - } else { - dialog, inviteOK, err = api.OnLiveStart(device, params, streamId, w, r) - } - - if !inviteOK || err != nil { - StreamManager.Remove(streamId) - w.WriteHeader(http.StatusInternalServerError) - return nil, false - } - - stream.DialogRequest = dialog - StreamManager.AddWithCallId(stream) - - //启动收流超时计时器 - wait := func() bool { - ok := stream.WaitForPublishEvent(10) - if !ok { - Sugar.Infof("收流超时 发送bye请求...") - CloseStream(streamId) - } - return ok - } - - if sync { - publishOK = true - go wait() - } else { - publishOK = wait() - } - - return stream, publishOK -} - -func (api *ApiServer) OnPlay(streamId, protocol string, w http.ResponseWriter, r *http.Request) { +func (api *ApiServer) OnPlay(streamId StreamID, protocol string, w http.ResponseWriter, r *http.Request) { Sugar.Infof("play. protocol:%s stream id:%s", protocol, streamId) - //[注意]: windows上使用cmd/power shell推拉流如果要携带多个参数, 请用双引号将与号引起来("&") - //session_id是为了同一个录像文件, 允许同时点播多个.当然如果实时流支持多路预览, 也是可以的. + // [注意]: windows上使用cmd/power shell推拉流如果要携带多个参数, 请用双引号将与号引起来("&") + // session_id是为了同一个录像文件, 允许同时点播多个.当然如果实时流支持多路预览, 也是可以的. //ffplay -i rtmp://127.0.0.1/34020000001320000001/34020000001310000001 //ffplay -i http://127.0.0.1:8080/34020000001320000001/34020000001310000001.flv?setup=passive //ffplay -i http://127.0.0.1:8080/34020000001320000001/34020000001310000001.m3u8?setup=passive //ffplay -i rtsp://test:123456@127.0.0.1/34020000001320000001/34020000001310000001?setup=passive - //回放示例 + // 回放示例 //ffplay -i rtmp://127.0.0.1/34020000001320000001/34020000001310000001.session_id_0?setup=passive"&"stream_type=playback"&"start_time=2024-06-18T15:20:56"&"end_time=2024-06-18T15:25:56 //ffplay -i rtmp://127.0.0.1/34020000001320000001/34020000001310000001.session_id_0?setup=passive&stream_type=playback&start_time=2024-06-18T15:20:56&end_time=2024-06-18T15:25:56 - //跳过非国标拉流 - split := strings.Split(streamId, "/") + // 跳过非国标拉流 + split := strings.Split(string(streamId), "/") if len(split) != 2 || len(split[0]) != 20 || len(split[1]) < 20 { w.WriteHeader(http.StatusOK) return } - //已经存在,累加计数 + // 已经存在,累加计数 if stream := StreamManager.Find(streamId); stream != nil { stream.IncreaseSinkCount() w.WriteHeader(http.StatusOK) @@ -345,9 +182,81 @@ func (api *ApiServer) OnPlay(streamId, protocol string, w http.ResponseWriter, r } } +func (api *ApiServer) OnInvite(w http.ResponseWriter, r *http.Request) { + v := InviteParams{} + if err := HttpDecodeJSONBody(w, r, &v); err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + + vars := mux.Vars(r) + action := strings.ToLower(vars["action"]) + if "playback" == action { + apiServer.DoInvite(InviteTypePlayback, v, true, w, r) + } else if "download" == action { + apiServer.DoInvite(InviteTypeDownload, v, true, w, r) + } else if "live" == action { + apiServer.DoInvite(InviteTypeLive, v, true, w, r) + } else { + w.WriteHeader(http.StatusNotFound) + } +} + +// DoInvite 处理Invite请求 +// @params sync 是否异步等待流媒体的publish事件(确认收到流), 目前请求流分两种方式,流媒体hook和http接口, hook方式同步等待确认收到流再应答, http接口直接应答成功。 +func (api *ApiServer) DoInvite(inviteType InviteType, params InviteParams, sync bool, w http.ResponseWriter, r *http.Request) (*Stream, bool) { + device := DeviceManager.Find(params.DeviceID) + if device == nil { + Sugar.Warnf("设备离线 id:%s", params.DeviceID) + w.WriteHeader(http.StatusNotFound) + return nil, false + } + + // 解析时间范围参数 + var startTimeSeconds string + var endTimeSeconds string + if InviteTypeLive != inviteType { + startTime, err := time.ParseInLocation("2006-01-02t15:04:05", params.StartTime, time.Local) + if err != nil { + Sugar.Errorf("解析开始时间失败 err:%s start_time:%s", err.Error(), params.StartTime) + w.WriteHeader(http.StatusBadRequest) + return nil, false + } + + endTime, err := time.ParseInLocation("2006-01-02t15:04:05", params.EndTime, time.Local) + if err != nil { + Sugar.Errorf("解析开始时间失败 err:%s start_time:%s", err.Error(), params.EndTime) + w.WriteHeader(http.StatusBadRequest) + return nil, false + } + + startTimeSeconds = strconv.FormatInt(startTime.Unix(), 10) + endTimeSeconds = strconv.FormatInt(endTime.Unix(), 10) + } + + streamId := params.streamId + if streamId == "" { + streamId = GenerateStreamId(inviteType, device.GetID(), params.ChannelID, params.StartTime, params.EndTime) + } + + // 解析回放或下载速度参数 + speed, _ := strconv.Atoi(params.Speed) + speed = int(math.Min(4, float64(speed))) + stream, ok := device.(*Device).StartStream(inviteType, streamId, params.ChannelID, startTimeSeconds, endTimeSeconds, params.Setup, speed, sync) + if !ok { + w.WriteHeader(http.StatusInternalServerError) + return nil, false + } + + // 返回stream id + response := map[string]string{"stream_id": string(streamId)} + httpResponseOK(w, response) + return stream, true +} + func (api *ApiServer) OnCloseStream(w http.ResponseWriter, r *http.Request) { v := struct { - StreamID string `json:"stream_id"` + StreamID StreamID `json:"stream_id"` }{} err := HttpDecodeJSONBody(w, r, &v) @@ -369,22 +278,28 @@ func (api *ApiServer) OnCloseStream(w http.ResponseWriter, r *http.Request) { CloseStream(v.StreamID) } -func CloseStream(streamId string) { - stream, _ := StreamManager.Remove(streamId) +func CloseStream(streamId StreamID) { + stream := StreamManager.Remove(streamId) if stream != nil { stream.Close(true) } } -func (api *ApiServer) OnPlayDone(streamId, protocol string, w http.ResponseWriter, r *http.Request) { +func (api *ApiServer) OnPlayDone(streamId StreamID, protocol string, w http.ResponseWriter, r *http.Request) { Sugar.Infof("play done. protocol:%s stream id:%s", protocol, streamId) if stream := StreamManager.Find(streamId); stream != nil { stream.DecreaseSinkCount() } + + // 与上级级联断开连接 + if protocol == "gb_stream_forward" { + + } + w.WriteHeader(http.StatusOK) } -func (api *ApiServer) OnPublish(streamId, protocol string, w http.ResponseWriter, r *http.Request) { +func (api *ApiServer) OnPublish(streamId StreamID, protocol string, w http.ResponseWriter, r *http.Request) { Sugar.Infof("publish. protocol:%s stream id:%s", protocol, streamId) w.WriteHeader(http.StatusOK) @@ -394,14 +309,13 @@ func (api *ApiServer) OnPublish(streamId, protocol string, w http.ResponseWriter } } -func (api *ApiServer) OnPublishDone(streamId, protocol string, w http.ResponseWriter, r *http.Request) { +func (api *ApiServer) OnPublishDone(streamId StreamID, protocol string, w http.ResponseWriter, r *http.Request) { Sugar.Infof("publish done. protocol:%s stream id:%s", protocol, streamId) - w.WriteHeader(http.StatusOK) CloseStream(streamId) } -func (api *ApiServer) OnIdleTimeout(streamId string, protocol string, w http.ResponseWriter, req *http.Request) { +func (api *ApiServer) OnIdleTimeout(streamId StreamID, protocol string, w http.ResponseWriter, req *http.Request) { Sugar.Infof("publish timeout. protocol:%s stream id:%s", protocol, streamId) if protocol != "rtmp" { @@ -412,7 +326,7 @@ func (api *ApiServer) OnIdleTimeout(streamId string, protocol string, w http.Res } } -func (api *ApiServer) OnReceiveTimeout(streamId string, protocol string, w http.ResponseWriter, req *http.Request) { +func (api *ApiServer) OnReceiveTimeout(streamId StreamID, protocol string, w http.ResponseWriter, req *http.Request) { Sugar.Infof("receive timeout. protocol:%s stream id:%s", protocol, streamId) if protocol != "rtmp" { @@ -456,7 +370,7 @@ func (api *ApiServer) OnRecordList(w http.ResponseWriter, r *http.Request) { } sn := GetSN() - err = device.DoQueryRecordList(v.ChannelId, v.StartTime, v.EndTime, sn, v.Type_) + err = device.QueryRecord(v.ChannelId, v.StartTime, v.EndTime, sn, v.Type_) if err != nil { httpResponseOK(w, fmt.Sprintf("发送查询录像记录失败 err:%s", err.Error())) return @@ -504,7 +418,7 @@ func (api *ApiServer) OnSubscribePosition(w http.ResponseWriter, r *http.Request return } - if err := device.DoSubscribePosition(v.ChannelID); err != nil { + if err := device.SubscribePosition(v.ChannelID); err != nil { } @@ -513,8 +427,8 @@ func (api *ApiServer) OnSubscribePosition(w http.ResponseWriter, r *http.Request func (api *ApiServer) OnSeekPlayback(w http.ResponseWriter, r *http.Request) { v := struct { - StreamId string `json:"stream_id"` - Seconds int `json:"seconds"` + StreamId StreamID `json:"stream_id"` + Seconds int `json:"seconds"` }{} if err := HttpDecodeJSONBody(w, r, &v); err != nil { @@ -539,6 +453,7 @@ func (api *ApiServer) OnSeekPlayback(w http.ResponseWriter, r *http.Request) { } func (api *ApiServer) OnPTZControl(w http.ResponseWriter, r *http.Request) { + } func (api *ApiServer) OnWSTalk(w http.ResponseWriter, r *http.Request) { @@ -658,7 +573,7 @@ func (api *ApiServer) OnBroadcast(w http.ResponseWriter, r *http.Request) { } if BroadcastManager.AddSession(v.RoomID, session) { - device.DoBroadcast(sourceId, v.ChannelID) + device.Broadcast(sourceId, v.ChannelID) httpResponseOK(w, nil) } else { w.WriteHeader(http.StatusForbidden) @@ -690,3 +605,97 @@ func (api *ApiServer) OnStarted(w http.ResponseWriter, req *http.Request) { stream.Close(true) } } + +func (api *ApiServer) OnPlatformAdd(w http.ResponseWriter, r *http.Request) { + v := GBPlatformRecord{} + if err := HttpDecodeJSONBody(w, r, &v); err != nil { + httpResponse2(w, err) + return + } + + if PlatformManager.ExistPlatform(v.SeverID) || PlatformManager.ExistPlatformWithServerAddr(v.ServerAddr) { + return + } + + platform, err := NewGBPlatform(&v, SipUA) + if err != nil { + return + } else if !PlatformManager.AddPlatform(platform) { + return + } + + platform.Start() +} + +func (api *ApiServer) OnPlatformRemove(w http.ResponseWriter, r *http.Request) { + v := GBPlatformRecord{} + if err := HttpDecodeJSONBody(w, r, &v); err != nil { + httpResponse2(w, err) + return + } + + platform := PlatformManager.RemovePlatform(v.SeverID) + if platform != nil { + platform.Stop() + } +} + +func (api *ApiServer) OnPlatformList(w http.ResponseWriter, r *http.Request) { + platforms := PlatformManager.Platforms() + httpResponseOK(w, platforms) +} + +func (api *ApiServer) OnPlatformChannelBind(w http.ResponseWriter, r *http.Request) { + v := struct { + ServerID string `json:"server_id"` + Channels [][2]string `json:"channels"` //二维数组, 索引0-设备ID/索引1-通道ID + }{} + + if err := HttpDecodeJSONBody(w, r, &v); err != nil { + httpResponse2(w, err) + return + } + + platform := PlatformManager.FindPlatform(v.ServerID) + if platform == nil { + return + } + + var channels []*Channel + for _, pair := range v.Channels { + device := DeviceManager.Find(pair[0]) + if device == nil { + continue + } + + channel := device.FindChannel(pair[1]) + if channel == nil { + continue + } + + channels = append(channels, channel) + } + + platform.AddChannels(channels) +} + +func (api *ApiServer) OnPlatformChannelUnbind(w http.ResponseWriter, r *http.Request) { + v := struct { + ServerID string `json:"server_id"` + Channels [][2]string `json:"channels"` //二维数组, 索引0-设备ID/索引1-通道ID + }{} + + if err := HttpDecodeJSONBody(w, r, &v); err != nil { + httpResponse2(w, err) + return + } + + platform := PlatformManager.FindPlatform(v.ServerID) + if platform == nil { + return + } + + for _, pair := range v.Channels { + platform.RemoveChannel(pair[1]) + } +} diff --git a/broadcast.go b/broadcast.go index fbc05b6..ea8ba60 100644 --- a/broadcast.go +++ b/broadcast.go @@ -12,7 +12,7 @@ import ( ) const ( - BroadcastFormat = "\r\n" + + BroadcastFormat = "\r\n" + "\r\n" + "Broadcast\r\n" + "%d\r\n" + @@ -30,34 +30,37 @@ const ( "a=rtpmap:8 PCMA/8000\r\n" ) -func (d *DBDevice) DoBroadcast(sourceId, channelId string) error { +func (d *Device) DoBroadcast(sourceId, channelId string) error { body := fmt.Sprintf(BroadcastFormat, 1, sourceId, channelId) - request, err := d.BuildMessageRequest(channelId, body) - if err != nil { - return err - } + request := d.BuildMessageRequest(channelId, body) SipUA.SendRequest(request) return nil } -func (d *DBDevice) OnInviteBroadcast(request sip.Request, session *BroadcastSession) (int, string) { +func (d *Device) OnInvite(request sip.Request, user string) sip.Response { + session := FindBroadcastSessionWithSourceID(user) + if session == nil { + return CreateResponseWithStatusCode(request, http.StatusBadRequest) + } + body := request.Body() if body == "" { - return http.StatusBadRequest, "" + return CreateResponseWithStatusCode(request, http.StatusBadRequest) } sdp, err := sdp.Parse(body) if err != nil { Sugar.Infof("解析sdp失败 err:%s sdp:%s", err.Error(), body) - return http.StatusBadRequest, "" + return CreateResponseWithStatusCode(request, http.StatusBadRequest) } if sdp.Audio == nil { Sugar.Infof("处理sdp失败 缺少audio字段 sdp:%s", body) - return http.StatusBadRequest, "" + return CreateResponseWithStatusCode(request, http.StatusBadRequest) } + var answerSDP string isTcp := strings.Contains(sdp.Audio.Proto, "TCP") if !isTcp && BroadcastTypeUDP == session.Type { var client *transport.UDPClient @@ -70,25 +73,41 @@ func (d *DBDevice) OnInviteBroadcast(request sip.Request, session *BroadcastSess if err == nil { Sugar.Errorf("创建UDP广播端口失败 err:%s", err.Error()) - return http.StatusInternalServerError, "" + return CreateResponseWithStatusCode(request, http.StatusInternalServerError) } session.RemoteIP = sdp.Addr session.RemotePort = int(sdp.Audio.Port) session.Transport = client session.Transport.SetHandler(session) - return http.StatusOK, fmt.Sprintf(AnswerFormat, Config.SipId, Config.PublicIP, Config.PublicIP, client.ListenPort(), "RTP/AVP") + answerSDP = fmt.Sprintf(AnswerFormat, Config.SipId, Config.PublicIP, Config.PublicIP, client.ListenPort(), "RTP/AVP") } else { server, err := TransportManager.NewTCPServer(Config.ListenIP) if err != nil { Sugar.Errorf("创建TCP广播端口失败 err:%s", err.Error()) - return http.StatusInternalServerError, "" + return CreateResponseWithStatusCode(request, http.StatusInternalServerError) } go server.Accept() session.Transport = server session.Transport.SetHandler(session) - return http.StatusOK, fmt.Sprintf(AnswerFormat, Config.SipId, Config.PublicIP, Config.PublicIP, server.ListenPort(), "TCP/RTP/AVP") + answerSDP = fmt.Sprintf(AnswerFormat, Config.SipId, Config.PublicIP, Config.PublicIP, server.ListenPort(), "TCP/RTP/AVP") } + response := CreateResponseWithStatusCode(request, http.StatusOK) + + setToTag(response) + + session.Successful = true + session.ByeRequest = d.CreateDialogRequestFromAnswer(response, true) + + id, _ := request.CallID() + BroadcastManager.AddSessionWithCallId(id.Value(), session) + + response.SetBody(answerSDP, true) + response.AppendHeader(&SDPMessageType) + response.AppendHeader(GlobalContactAddress.AsContactHeader()) + + session.Answer <- 0 + return response } diff --git a/broadcast_manager.go b/broadcast_manager.go index d64e2eb..65486a5 100644 --- a/broadcast_manager.go +++ b/broadcast_manager.go @@ -18,11 +18,21 @@ func init() { type broadcastManager struct { rooms map[string]*BroadcastRoom //主讲人关联房间 - sessions map[string]*BroadcastSession //sessionId关联全部广播会话 - callIds map[string]*BroadcastSession //callId关联全部广播会话 + sessions map[string]*BroadcastSession //sessionId关联广播会话 + callIds map[string]*BroadcastSession //callId关联广播会话 lock sync.RWMutex } +func FindBroadcastSessionWithSourceID(user string) *BroadcastSession { + roomId := user[:10] + room := BroadcastManager.FindRoom(roomId) + if room != nil { + return room.Find(user) + } + + return nil +} + func (b *broadcastManager) CreateRoom(id string) *BroadcastRoom { b.lock.Lock() defer b.lock.Unlock() diff --git a/broadcast_room.go b/broadcast_room.go index abe0d70..b7be623 100644 --- a/broadcast_room.go +++ b/broadcast_room.go @@ -31,19 +31,19 @@ func (r *BroadcastRoom) Remove(sourceId string) { delete(r.members, sourceId) } -func (r *BroadcastRoom) Exist(sessionId string) bool { +func (r *BroadcastRoom) Exist(sourceId string) bool { r.lock.RLock() defer r.lock.RUnlock() - _, ok := r.members[sessionId] + _, ok := r.members[sourceId] return ok } -func (r *BroadcastRoom) Find(sessionId string) *BroadcastSession { +func (r *BroadcastRoom) Find(sourceId string) *BroadcastSession { r.lock.RLock() defer r.lock.RUnlock() - session, _ := r.members[sessionId] + session, _ := r.members[sourceId] return session } diff --git a/catalog.go b/catalog.go deleted file mode 100644 index 9254e83..0000000 --- a/catalog.go +++ /dev/null @@ -1,12 +0,0 @@ -package main - -func (d *DBDevice) OnCatalog(response *QueryCatalogResponse) { - if d.Channels == nil { - d.Channels = make(map[string]Channel, 5) - } - - for index := range response.DeviceList.Devices { - device := response.DeviceList.Devices[index] - d.Channels[device.DeviceID] = device - } -} diff --git a/client.go b/client.go new file mode 100644 index 0000000..49d238c --- /dev/null +++ b/client.go @@ -0,0 +1,153 @@ +package main + +import ( + "encoding/xml" + "gb-cms/sdp" + "github.com/ghettovoice/gosip/sip" + "strconv" + "strings" +) + +type GBClient interface { + SipClient + + GBDevice + + SetDeviceInfo(name, manufacturer, model, firmware string) + + // OnQueryCatalog 被查询目录 + OnQueryCatalog(sn int) + + // OnQueryDeviceInfo 被查询设备信息 + OnQueryDeviceInfo(sn int) + + OnSubscribeCatalog(sn int) + + // AddChannels 重写添加通道函数, 增加SIP通知通道变化 + AddChannels(channels []*Channel) +} + +type Client struct { + *sipClient + Device + deviceInfo *DeviceInfoResponse +} + +func (g *Client) OnQueryCatalog(sn int) { + channels := g.GetChannels() + if len(channels) == 0 { + return + } + + response := CatalogResponse{} + response.SN = sn + response.CmdType = CmdCatalog + response.DeviceID = g.sipClient.Username + response.SumNum = len(channels) + + for i, _ := range channels { + channel := *channels[i] + + response.DeviceList.Devices = nil + response.DeviceList.Num = 1 // 一次发一个通道 + response.DeviceList.Devices = append(response.DeviceList.Devices, &channel) + response.DeviceList.Devices[0].ParentID = g.sipClient.Username + + g.SendMessage(&response) + } +} + +func (g *Client) SendMessage(msg interface{}) { + marshal, err := xml.MarshalIndent(msg, "", " ") + if err != nil { + panic(err) + } + + request, err := BuildMessageRequest(g.sipClient.Username, g.sipClient.ListenAddr, g.sipClient.SeverId, g.sipClient.Domain, g.sipClient.Transport, string(marshal)) + if err != nil { + panic(err) + } + + g.sipClient.ua.SendRequest(request) +} + +func (g *Client) OnQueryDeviceInfo(sn int) { + g.deviceInfo.SN = sn + g.SendMessage(&g.deviceInfo) +} + +func (g *Client) OnInvite(request sip.Request, user string) sip.Response { + return nil +} + +func (g *Client) SetDeviceInfo(name, manufacturer, model, firmware string) { + g.deviceInfo.DeviceName = name + g.deviceInfo.Manufacturer = manufacturer + g.deviceInfo.Model = model + g.deviceInfo.Firmware = firmware +} + +func (g *Client) OnSubscribeCatalog(sn int) { + +} + +func ParseGBSDP(body string) (offer *sdp.SDP, ssrc string, speed int, media *sdp.Media, offerSetup, answerSetup string, err error) { + offer, err = sdp.Parse(body) + if err != nil { + return nil, "", 0, nil, "", "", err + } + + // 解析设置下载速度 + var setup string + for _, attr := range offer.Attrs { + if "downloadspeed" == attr[0] { + speed, err = strconv.Atoi(attr[1]) + if err != nil { + return nil, "", 0, nil, "", "", err + } + } else if "setup" == attr[0] { + setup = attr[1] + } + } + + // 解析ssrc + for _, attr := range offer.Other { + if "y" == attr[0] { + ssrc = attr[1] + } + } + + if offer.Video != nil { + media = offer.Video + } else if offer.Audio != nil { + media = offer.Audio + } + + tcp := strings.HasPrefix(media.Proto, "TCP") + if "passive" == setup && tcp { + offerSetup = "passive" + answerSetup = "active" + } else if "active" == setup && tcp { + offerSetup = "active" + answerSetup = "passive" + } + + return +} + +func NewGBClient(username, serverId, serverAddr, transport, password string, registerExpires, keepalive int, ua SipServer) GBClient { + sip := &sipClient{ + Username: username, + Domain: serverAddr, + Transport: transport, + Password: password, + RegisterExpires: registerExpires, + KeeAliveInterval: keepalive, + SeverId: serverId, + ListenAddr: ua.ListenAddr(), + ua: ua, + } + + client := &Client{sip, Device{ID: username, Channels: map[string]*Channel{}}, &DeviceInfoResponse{BaseResponse: BaseResponse{BaseMessage: BaseMessage{DeviceID: username, CmdType: CmdDeviceInfo}, Result: "OK"}}} + return client +} diff --git a/client_benchmark_test.go b/client_benchmark_test.go new file mode 100644 index 0000000..868628c --- /dev/null +++ b/client_benchmark_test.go @@ -0,0 +1,335 @@ +package main + +import ( + "context" + "encoding/binary" + "encoding/json" + "fmt" + "github.com/ghettovoice/gosip/sip" + "github.com/lkmio/avformat/librtp" + "github.com/lkmio/avformat/transport" + "net" + "net/http" + "os" + "strconv" + "strings" + "sync" + "testing" + "time" +) + +var ( + rtpPackets [][]byte + locks map[uint32]*sync.RWMutex +) + +type MediaStream struct { + ssrc uint32 + tcp bool + conn net.Conn + transport transport.ITransport + cancel context.CancelFunc + dialog sip.Request + ctx context.Context + + closedCB func(sendBye bool) +} + +func (m *MediaStream) write() { + var index int + length := len(rtpPackets) + for m.ctx.Err() == nil && index < length { + time.Sleep(time.Millisecond * 30) + + //一次发某个时间范围内的所有rtp包 + ts := binary.BigEndian.Uint32(rtpPackets[index][2+4:]) + mutex := locks[ts] + { + mutex.Lock() + + for ; m.ctx.Err() == nil && index < length; index++ { + bytes := rtpPackets[index] + nextTS := binary.BigEndian.Uint32(bytes[2+4:]) + if nextTS != ts { + break + } + + librtp.ModifySSRC(bytes[2:], m.ssrc) + + if m.tcp { + m.conn.Write(bytes) + } else { + m.transport.(*transport.UDPClient).Write(bytes[2:]) + } + } + + mutex.Unlock() + } + } + + println("推流结束") + m.Close(true) +} + +func (m *MediaStream) Start() { + m.ctx, m.cancel = context.WithCancel(context.Background()) + go m.write() +} + +func (m *MediaStream) Close(sendBye bool) { + m.cancel() + + if m.closedCB != nil { + m.closedCB(sendBye) + } +} + +func (m *MediaStream) OnConnected(conn net.Conn) []byte { + m.conn = conn + fmt.Printf("tcp连接:%s", conn.RemoteAddr()) + return nil +} + +func (m *MediaStream) OnPacket(conn net.Conn, data []byte) []byte { + return nil +} + +func (m *MediaStream) OnDisConnected(conn net.Conn, err error) { + fmt.Printf("tcp断开连接:%s", conn.RemoteAddr()) + m.Close(true) +} + +type VirtualDevice struct { + *Client + streams map[string]*MediaStream + lock sync.Locker +} + +func CreateTransport(ip string, port int, setup string, handler transport.Handler) (transport.ITransport, bool, error) { + if "passive" == setup { + tcpClient := &transport.TCPClient{} + tcpClient.SetHandler(handler) + + err := tcpClient.Connect(nil, &net.TCPAddr{IP: net.ParseIP(ip), Port: port}) + return tcpClient, true, err + } else if "active" == setup { + tcpServer := &transport.TCPServer{} + tcpServer.SetHandler(handler) + err := tcpServer.Bind(nil) + + return tcpServer, true, err + } else { + udp := &transport.UDPClient{} + err := udp.Connect(nil, &net.UDPAddr{IP: net.ParseIP(ip), Port: port}) + return udp, false, err + } +} + +func (v VirtualDevice) OnInvite(request sip.Request, user string) sip.Response { + if len(rtpPackets) < 1 { + return CreateResponseWithStatusCode(request, http.StatusInternalServerError) + } + + offer, ssrc, speed, media, offerSetup, answerSetup, err := ParseGBSDP(request.Body()) + if err != nil { + return CreateResponseWithStatusCode(request, http.StatusBadRequest) + } + + stream := &MediaStream{} + socket, tcp, err := CreateTransport(offer.Addr, int(media.Port), offerSetup, stream) + if err != nil { + return CreateResponseWithStatusCode(request, http.StatusBadRequest) + } + + time := strings.Split(offer.Time, " ") + if len(time) < 2 { + return CreateResponseWithStatusCode(request, http.StatusBadRequest) + } + + var ip string + var port sip.Port + var contactAddr string + if v.sipClient.NatAddr != "" { + contactAddr = v.sipClient.NatAddr + } else { + contactAddr = v.sipClient.ListenAddr + } + + host, p, _ := net.SplitHostPort(contactAddr) + ip = host + atoi, _ := strconv.Atoi(p) + port = sip.Port(atoi) + + contactAddress := &sip.Address{ + Uri: &sip.SipUri{ + FUser: sip.String{Str: user}, + FHost: ip, + FPort: &port, + }, + } + + answer := BuildSDP(user, offer.Session, ip, uint16(socket.ListenPort()), time[0], time[1], answerSetup, speed, ssrc) + response := CreateResponseWithStatusCode(request, http.StatusOK) + response.RemoveHeader("Contact") + response.AppendHeader(contactAddress.AsContactHeader()) + response.AppendHeader(&SDPMessageType) + response.SetBody(answer, true) + setToTag(response) + + i, _ := strconv.Atoi(ssrc) + stream.ssrc = uint32(i) + stream.tcp = tcp + stream.dialog = CreateDialogRequestFromAnswer(response, true, v.sipClient.Domain) + callId, _ := response.CallID() + + { + v.lock.Lock() + defer v.lock.Unlock() + v.streams[callId.Value()] = stream + } + + // 设置网络断开回调 + stream.closedCB = func(sendBye bool) { + if stream.dialog != nil { + id, _ := stream.dialog.CallID() + StreamManager.RemoveWithCallId(id.Value()) + + { + v.lock.Lock() + delete(v.streams, id.Value()) + v.lock.Unlock() + } + + if sendBye { + bye := CreateRequestFromDialog(stream.dialog, sip.BYE) + v.sipClient.ua.SendRequest(bye) + } + + stream.dialog = nil + } + + if stream.transport != nil { + stream.transport.Close() + stream.transport = nil + } + } + + stream.transport = socket + stream.Start() + + // 绑定到StreamManager, bye请求才会找到设备回调 + streamId := GenerateStreamId(InviteTypeLive, v.sipClient.Username, user, "", "") + s := Stream{ID: streamId, DialogRequest: stream.dialog} + StreamManager.Add(&s) + + callID, _ := request.CallID() + StreamManager.AddWithCallId(callID.Value(), &s) + return response +} + +func (v VirtualDevice) OnBye(request sip.Request) { + id, _ := request.CallID() + stream, ok := v.streams[id.Value()] + if !ok { + return + } + + { + // 此作用域内defer不会生效 + v.lock.Lock() + delete(v.streams, id.Value()) + v.lock.Unlock() + } + + stream.Close(false) +} + +type ClientConfig struct { + DeviceIDPrefix string `json:"device_id_prefix"` + ChannelIDPrefix string `json:"channel_id_prefix"` + ServerID string `json:"server_id"` + Domain string `json:"domain"` + Password string `json:"password"` + ListenAddr string `json:"listenAddr"` + Count int `json:"count"` + RawFilePath string `json:"rtp_over_tcp_raw_file_path"` // rtp over tcp源文件 +} + +func TestGBClient(t *testing.T) { + configData, err := os.ReadFile("./client_benchmark_test_config.json") + if err != nil { + panic(err) + } + + clientConfig := &ClientConfig{} + if err = json.Unmarshal(configData, clientConfig); err != nil { + panic(err) + } + + rtpData, err := os.ReadFile(clientConfig.RawFilePath) + if err != nil { + println("读取rtp源文件 不能推流") + } else { + // 分割rtp包 + offset := 2 + length := len(rtpData) + locks = make(map[uint32]*sync.RWMutex, 128) + for rtpSize := 0; offset < length; offset += rtpSize + 2 { + rtpSize = int(binary.BigEndian.Uint16(rtpData[offset-2:])) + if length-offset < rtpSize { + break + } + + bytes := rtpData[offset : offset+rtpSize] + ts := binary.BigEndian.Uint32(bytes[4:]) + // 每个相同时间戳共用一把互斥锁, 只允许同时一路流发送该时间戳内的rtp包, 保护ssrc被不同的流修改 + if _, ok := locks[ts]; !ok { + locks[ts] = &sync.RWMutex{} + } + + rtpPackets = append(rtpPackets, rtpData[offset-2:offset+rtpSize]) + } + } + + // 初始化UA配置, 防止SipServer使用时空指针 + Config = &Config_{} + + listenIP, listenPort, err := net.SplitHostPort(clientConfig.ListenAddr) + if err != nil { + panic(err) + } + + atoi, err := strconv.Atoi(listenPort) + if err != nil { + panic(err) + } + + server, err := StartSipServer("", listenIP, listenIP, atoi) + if err != nil { + panic(err) + } + + for i := 0; i < clientConfig.Count; i++ { + deviceId := clientConfig.DeviceIDPrefix + fmt.Sprintf("%07d", i+1) + channelId := clientConfig.ChannelIDPrefix + fmt.Sprintf("%07d", i+1) + client := NewGBClient(deviceId, clientConfig.ServerID, clientConfig.Domain, "UDP", clientConfig.Password, 500, 40, server) + + device := VirtualDevice{client.(*Client), map[string]*MediaStream{}, &sync.Mutex{}} + device.SetDeviceInfo(fmt.Sprintf("测试设备%d", i+1), "lkmio", "lkmio_gb", "dev-0.0.1") + + var channels []*Channel + channels = append(channels, &Channel{ + DeviceID: channelId, + Name: "1", + ParentID: deviceId, + }) + + DeviceManager.Add(device) + device.AddChannels(channels) + device.Start() + } + + for { + time.Sleep(time.Second * 3) + } +} diff --git a/client_benchmark_test_config.json b/client_benchmark_test_config.json new file mode 100644 index 0000000..66c80d5 --- /dev/null +++ b/client_benchmark_test_config.json @@ -0,0 +1,10 @@ +{ + "device_id_prefix": "3402000000132", + "channel_id_prefix": "3402000000131", + "server_id": "34020000002000000001", + "domain": "192.168.2.148:5060", + "password": "12345678", + "listenAddr": "192.168.2.148:15062", + "count": 1, + "rtp_over_tcp_raw_file_path": "./rtp.raw" +} \ No newline at end of file diff --git a/db_local.go b/db_local.go index 8ec8401..e7d17b3 100644 --- a/db_local.go +++ b/db_local.go @@ -4,25 +4,26 @@ package main type LocalDB struct { } -func (m LocalDB) LoadDevices() []*DBDevice { +func (m LocalDB) LoadDevices() []*Device { return nil } -func (m LocalDB) RegisterDevice(device *DBDevice) (error, bool) { +func (m LocalDB) RegisterDevice(device *Device) (error, bool) { //持久化... device.Status = "ON" - d := DeviceManager.Find(device.Id) - if d != nil { - d.Status = "ON" - d.RemoteAddr = device.RemoteAddr - d.Name = device.Name - d.Transport = device.Transport + oldDevice := DeviceManager.Find(device.ID) + if oldDevice != nil { + oldDevice.(*Device).Status = "ON" + oldDevice.(*Device).RemoteAddr = device.RemoteAddr + oldDevice.(*Device).Name = device.Name + oldDevice.(*Device).Transport = device.Transport + device = oldDevice.(*Device) } else if err := DeviceManager.Add(device); err != nil { return err, false } - return nil, d == nil || len(d.Channels) == 0 + return nil, oldDevice == nil || len(device.Channels) == 0 } func (m LocalDB) UnRegisterDevice(id string) { @@ -31,9 +32,25 @@ func (m LocalDB) UnRegisterDevice(id string) { return } - device.Status = "OFF" + device.(*Device).Status = "OFF" } -func (m LocalDB) KeepAliveDevice(device *DBDevice) { +func (m LocalDB) KeepAliveDevice(device *Device) { } + +func (m LocalDB) AddPlatform(record GBPlatformRecord) error { + //if ExistPlatform(record.SeverID) { + // return + //} + + return nil +} + +func (m LocalDB) LoadPlatforms() []GBPlatformRecord { + //if ExistPlatform(record.SeverID) { + // return + //} + + return nil +} diff --git a/device.go b/device.go index 7ca211b..d72389f 100644 --- a/device.go +++ b/device.go @@ -1,11 +1,11 @@ package main import ( - "encoding/xml" "fmt" "github.com/ghettovoice/gosip/sip" "net" "strconv" + "sync" ) const ( @@ -27,68 +27,231 @@ var ( SDPMessageType sip.ContentType = "application/sdp" ) -type DBDevice struct { - Id string `json:"id"` - Name string `json:"name"` - RemoteAddr string `json:"remote_addr"` - Transport string `json:"transport"` - Status string `xml:"Status,omitempty"` //在线状态 ON-在线/OFF-离线 - Channels map[string]Channel `json:"channels"` +type GBDevice interface { + GetID() string + + QueryCatalog() + + QueryRecord(channelId, startTime, endTime string, sn int, type_ string) error + + //Invite(channel string, setup string) + + OnCatalog(response *CatalogResponse) + + OnRecord(response *QueryRecordInfoResponse) + + OnDeviceInfo(response *DeviceInfoResponse) + + // OnInvite 语音广播 + OnInvite(request sip.Request, user string) sip.Response + + // OnBye 设备侧主动挂断 + OnBye(request sip.Request) + + OnNotifyPosition(notify *MobilePositionNotify) + + // + //OnNotifyCatalog() + // + //OnNotifyAlarm() + + SubscribePosition(channelId string) error + + //SubscribeCatalog() + // + //SubscribeAlarm() + + Broadcast(sourceId, channelId string) error + + OnKeepalive() + + // AddChannels 批量添加通道 + AddChannels(channels []*Channel) + + // GetChannels 获取所有通道 + GetChannels() []*Channel + + // FindChannel 根据通道ID查找通道 + FindChannel(id string) *Channel + + // RemoveChannel 根据通道ID删除通道 + RemoveChannel(id string) *Channel + + // UpdateChannel 订阅目录,通道发生改变 + // 附录P.4.2.2 + // @Params event ON-上线/OFF-离线/VLOST-视频丢失/DEFECT-故障/ADD-增加/DEL-删除/UPDATE-更新 + UpdateChannel(id string, event string) } -type Channel struct { - DeviceID string `xml:"DeviceID"` - Name string `xml:"Name,omitempty"` - Manufacturer string `xml:"Manufacturer,omitempty"` - Model string `xml:"Model,omitempty"` - Owner string `xml:"Owner,omitempty"` - CivilCode string `xml:"CivilCode,omitempty"` - Block string `xml:"Block,omitempty"` - Address string `xml:"Address,omitempty"` - Parental int `xml:"Parental,omitempty"` - ParentID string `xml:"ParentID,omitempty"` - SafetyWay int `xml:"SafetyWay,omitempty"` - RegisterWay int `xml:"RegisterWay,omitempty"` - CertNum string `xml:"CertNum,omitempty"` - Certifiable int `xml:"Certifiable,omitempty"` - ErrCode int `xml:"ErrCode,omitempty"` - EndTime string `xml:"EndTime,omitempty"` - Secrecy string `xml:"Secrecy,omitempty"` - IPAddress string `xml:"IPAddress,omitempty"` - Port int `xml:"Port,omitempty"` - Password string `xml:"Password,omitempty"` - Status string `xml:"Status,omitempty"` - Longitude string `xml:"Longitude,omitempty"` - Latitude string `xml:"Latitude,omitempty"` +type Device struct { + ID string `json:"id"` + Name string `json:"name"` + RemoteAddr string `json:"remote_addr"` + Transport string `json:"transport"` + Status string `json:"Status,omitempty"` //在线状态 ON-在线/OFF-离线 + Channels map[string]*Channel `json:"channels"` + lock sync.RWMutex } -type DeviceList struct { - Num int `xml:"Num,attr"` - Devices []Channel `xml:"Item"` +func (d *Device) GetID() string { + return d.ID } -type QueryCatalogResponse struct { - XMLName xml.Name `xml:"Response"` - CmdType string `xml:"CmdType"` - SN int `xml:"SN"` - DeviceID string `xml:"DeviceID"` - SumNum int `xml:"SumNum"` - DeviceList DeviceList `xml:"DeviceList"` +func (d *Device) BuildMessageRequest(to, body string) sip.Request { + request, err := BuildMessageRequest(Config.SipId, net.JoinHostPort(GlobalContactAddress.Uri.Host(), GlobalContactAddress.Uri.Port().String()), to, d.RemoteAddr, d.Transport, body) + if err != nil { + panic(err) + } + return request } -func (d *DBDevice) BuildCatalogRequest() (sip.Request, error) { - body := fmt.Sprintf(CatalogFormat, "1", d.Id) - return d.BuildMessageRequest(d.Id, body) +func (d *Device) QueryCatalog() { + body := fmt.Sprintf(CatalogFormat, "1", d.ID) + request := d.BuildMessageRequest(d.ID, body) + SipUA.SendRequest(request) } -func (d *DBDevice) BuildMessageRequest(to, body string) (sip.Request, error) { - builder := d.NewRequestBuilder(sip.MESSAGE, Config.SipId, Config.SipContactAddr, to) +func (d *Device) QueryRecord(channelId, startTime, endTime string, sn int, type_ string) error { + body := fmt.Sprintf(QueryRecordFormat, sn, channelId, startTime, endTime, type_) + request := d.BuildMessageRequest(channelId, body) + SipUA.SendRequest(request) + return nil +} + +func (d *Device) OnBye(request sip.Request) { + +} + +func (d *Device) OnCatalog(response *CatalogResponse) { + for _, device := range response.DeviceList.Devices { + device.ParentID = d.ID + } + + d.AddChannels(response.DeviceList.Devices) +} + +func (d *Device) OnRecord(response *QueryRecordInfoResponse) { + event := SNManager.FindEvent(response.SN) + if event == nil { + Sugar.Errorf("处理录像查询响应失败 SN:%d", response.SN) + return + } + + event(response) +} + +func (d *Device) OnDeviceInfo(response *DeviceInfoResponse) { + +} + +func (d *Device) OnNotifyPosition(notify *MobilePositionNotify) { + +} + +func (d *Device) SubscribePosition(channelId string) error { + if channelId == "" { + channelId = d.ID + } + + //暂时不考虑级联 + builder := d.NewRequestBuilder(sip.SUBSCRIBE, Config.SipId, Config.SipContactAddr, channelId) + body := fmt.Sprintf(MobilePositionMessageFormat, "1", channelId, Config.MobilePositionInterval) + + expiresHeader := sip.Expires(Config.MobilePositionExpires) + builder.SetExpires(&expiresHeader) builder.SetContentType(&XmlMessageType) + builder.SetContact(GlobalContactAddress) builder.SetBody(body) - return builder.Build() + + request, err := builder.Build() + if err != nil { + return err + } + + event := Event("Catalog;id=2") + request.AppendHeader(&event) + response, err := SipUA.SendRequestWithTimeout(5, request) + if err != nil { + return err + } + + if response.StatusCode() != 200 { + return fmt.Errorf("err code %d", response.StatusCode()) + } + + return nil } -func (d *DBDevice) NewSIPRequestBuilderWithTransport() *sip.RequestBuilder { +func (d *Device) Broadcast(sourceId, channelId string) error { + body := fmt.Sprintf(BroadcastFormat, 1, sourceId, channelId) + request := d.BuildMessageRequest(channelId, body) + SipUA.SendRequest(request) + return nil +} + +func (d *Device) OnKeepalive() { + +} + +func (d *Device) AddChannels(channels []*Channel) { + d.lock.Lock() + defer d.lock.Unlock() + + if d.Channels == nil { + d.Channels = make(map[string]*Channel, 5) + } + + for i, _ := range channels { + d.Channels[channels[i].DeviceID] = channels[i] + } +} + +func (d *Device) GetChannels() []*Channel { + d.lock.RLock() + defer d.lock.RUnlock() + + var channels []*Channel + for _, channel := range d.Channels { + channels = append(channels, channel) + } + + return channels +} + +func (d *Device) RemoveChannel(id string) *Channel { + d.lock.Lock() + defer d.lock.Unlock() + + if channel, ok := d.Channels[id]; ok { + delete(d.Channels, id) + return channel + } + + return nil +} + +func (d *Device) FindChannel(id string) *Channel { + d.lock.RLock() + defer d.lock.RUnlock() + + if channel, ok := d.Channels[id]; ok { + return channel + } + return nil +} + +func (d *Device) UpdateChannel(id string, event string) { + d.lock.RLock() + defer d.lock.RUnlock() +} + +func (d *Device) BuildCatalogRequest() (sip.Request, error) { + body := fmt.Sprintf(CatalogFormat, "1", d.ID) + request := d.BuildMessageRequest(d.ID, body) + return request, nil +} + +func (d *Device) NewSIPRequestBuilderWithTransport() *sip.RequestBuilder { builder := sip.NewRequestBuilder() hop := sip.ViaHop{ Transport: d.Transport, @@ -98,7 +261,7 @@ func (d *DBDevice) NewSIPRequestBuilderWithTransport() *sip.RequestBuilder { return builder } -func (d *DBDevice) NewRequestBuilder(method sip.RequestMethod, from, realm, to string) *sip.RequestBuilder { +func (d *Device) NewRequestBuilder(method sip.RequestMethod, fromUser, realm, toUser string) *sip.RequestBuilder { builder := d.NewSIPRequestBuilderWithTransport() builder.SetMethod(method) @@ -107,7 +270,7 @@ func (d *DBDevice) NewRequestBuilder(method sip.RequestMethod, from, realm, to s sipPort := sip.Port(port) requestUri := &sip.SipUri{ - FUser: sip.String{Str: to}, + FUser: sip.String{Str: toUser}, FHost: host, FPort: &sipPort, } @@ -116,7 +279,7 @@ func (d *DBDevice) NewRequestBuilder(method sip.RequestMethod, from, realm, to s fromAddress := &sip.Address{ Uri: &sip.SipUri{ - FUser: sip.String{Str: from}, + FUser: sip.String{Str: fromUser}, FHost: realm, }, } @@ -130,80 +293,46 @@ func (d *DBDevice) NewRequestBuilder(method sip.RequestMethod, from, realm, to s return builder } -func (d *DBDevice) BuildSDP(userName, sessionName, ip string, port uint16, startTime, stopTime, setup string, speed int, ssrc string) string { - format := "v=0\r\n" + - "o=%s 0 0 IN IP4 %s\r\n" + - "s=%s\r\n" + - "c=IN IP4 %s\r\n" + - "t=%s %s\r\n" + - "m=video %d %s 96\r\n" + - "a=%s\r\n" + - "a=rtpmap:96 PS/90000\r\n" - - tcpFormat := "a=setup:%s\r\n" + - "a=connection:new\r\n" - - var tcp bool - var mediaProtocol string - if "active" == setup || "passive" == setup { - mediaProtocol = "TCP/RTP/AVP" - tcp = true - } else { - mediaProtocol = "RTP/AVP" - } - - sdp := fmt.Sprintf(format, userName, ip, sessionName, ip, startTime, stopTime, port, mediaProtocol, "recvonly") - if tcp { - sdp += fmt.Sprintf(tcpFormat, setup) - } - - if speed > 0 { - sdp += fmt.Sprintf("a=downloadspeed:%d\r\n", speed) - } - - sdp += fmt.Sprintf("y=%s\r\n", ssrc) - return sdp -} - -func (d *DBDevice) BuildInviteRequest(sessionName, channelId, ip string, port uint16, startTime, stopTime, setup string, speed int, ssrc string) (sip.Request, error) { +func (d *Device) BuildInviteRequest(sessionName, channelId, ip string, port uint16, startTime, stopTime, setup string, speed int, ssrc string) (sip.Request, error) { builder := d.NewRequestBuilder(sip.INVITE, Config.SipId, Config.SipContactAddr, channelId) - sdp := d.BuildSDP(Config.SipId, sessionName, ip, port, startTime, stopTime, setup, speed, ssrc) + sdp := BuildSDP(Config.SipId, sessionName, ip, port, startTime, stopTime, setup, speed, ssrc) builder.SetContentType(&SDPMessageType) - builder.SetContact(globalContactAddress) + builder.SetContact(GlobalContactAddress) builder.SetBody(sdp) request, err := builder.Build() if err != nil { return nil, err } - var subjectHeader = Subject(channelId + ":" + channelId + "," + Config.SipId + ":" + ssrc) + var subjectHeader = Subject(channelId + ":" + d.ID + "," + Config.SipId + ":" + ssrc) request.AppendHeader(subjectHeader) return request, err } -func (d *DBDevice) BuildLiveRequest(channelId, ip string, port uint16, setup string, ssrc string) (sip.Request, error) { +func (d *Device) BuildLiveRequest(channelId, ip string, port uint16, setup string, ssrc string) (sip.Request, error) { return d.BuildInviteRequest("Play", channelId, ip, port, "0", "0", setup, 0, ssrc) } -func (d *DBDevice) BuildPlaybackRequest(channelId, ip string, port uint16, startTime, stopTime, setup string, ssrc string) (sip.Request, error) { +func (d *Device) BuildPlaybackRequest(channelId, ip string, port uint16, startTime, stopTime, setup string, ssrc string) (sip.Request, error) { return d.BuildInviteRequest("Playback", channelId, ip, port, startTime, stopTime, setup, 0, ssrc) } -func (d *DBDevice) BuildDownloadRequest(channelId, ip string, port uint16, startTime, stopTime, setup string, speed int, ssrc string) (sip.Request, error) { +func (d *Device) BuildDownloadRequest(channelId, ip string, port uint16, startTime, stopTime, setup string, speed int, ssrc string) (sip.Request, error) { return d.BuildInviteRequest("Download", channelId, ip, port, startTime, stopTime, setup, speed, ssrc) } // CreateDialogRequestFromAnswer 根据invite的应答创建Dialog请求 // 应答的to头域需携带tag -func (d *DBDevice) CreateDialogRequestFromAnswer(message sip.Response, uas bool) sip.Request { + +func CreateDialogRequestFromAnswer(message sip.Response, uas bool, remoteAddr string) sip.Request { from, _ := message.From() to, _ := message.To() id, _ := message.CallID() requestLine := &sip.SipUri{} requestLine.SetUser(from.Address.User()) - host, port, _ := net.SplitHostPort(d.RemoteAddr) + host, port, _ := net.SplitHostPort(remoteAddr) portInt, _ := strconv.Atoi(port) sipPort := sip.Port(portInt) requestLine.SetHost(host) @@ -211,7 +340,7 @@ func (d *DBDevice) CreateDialogRequestFromAnswer(message sip.Response, uas bool) seq, _ := message.CSeq() - builder := d.NewSIPRequestBuilderWithTransport() + builder := NewSIPRequestBuilderWithTransport(message.Transport()) if uas { builder.SetFrom(sip.NewAddressFromToHeader(to)) builder.SetTo(sip.NewAddressFromFromHeader(from)) @@ -231,3 +360,7 @@ func (d *DBDevice) CreateDialogRequestFromAnswer(message sip.Response, uas bool) return request } + +func (d *Device) CreateDialogRequestFromAnswer(message sip.Response, uas bool) sip.Request { + return CreateDialogRequestFromAnswer(message, uas, d.RemoteAddr) +} diff --git a/device_db.go b/device_db.go index 099cd6c..cc181a7 100644 --- a/device_db.go +++ b/device_db.go @@ -1,11 +1,23 @@ package main type DeviceDB interface { - LoadDevices() []*DBDevice + LoadDevices() []*Device - RegisterDevice(device *DBDevice) (error, bool) + RegisterDevice(device *Device) (error, bool) UnRegisterDevice(id string) - KeepAliveDevice(device *DBDevice) + KeepAliveDevice(device *Device) + + LoadPlatforms() []GBPlatformRecord + + AddPlatform(record GBPlatformRecord) error + + //RemovePlatform(record GBPlatformRecord) (GBPlatformRecord, bool) + // + //PlatformList() []GBPlatformRecord + // + //BindPlatformChannel() bool + // + //UnbindPlatformChannel() bool } diff --git a/device_manager.go b/device_manager.go index df648d7..c796e13 100644 --- a/device_manager.go +++ b/device_manager.go @@ -15,37 +15,37 @@ type deviceManager struct { m sync.Map } -func (s *deviceManager) Add(device *DBDevice) error { - _, ok := s.m.LoadOrStore(device.Id, device) +func (s *deviceManager) Add(device GBDevice) error { + _, ok := s.m.LoadOrStore(device.GetID(), device) if ok { - return fmt.Errorf("the device %s has been exist", device.Id) + return fmt.Errorf("the device %s has been exist", device.GetID()) } return nil } -func (s *deviceManager) Find(id string) *DBDevice { +func (s *deviceManager) Find(id string) GBDevice { value, ok := s.m.Load(id) if ok { - return value.(*DBDevice) + return value.(GBDevice) } return nil } -func (s *deviceManager) Remove(id string) (*DBDevice, error) { +func (s *deviceManager) Remove(id string) (GBDevice, error) { value, loaded := s.m.LoadAndDelete(id) if loaded { - return value.(*DBDevice), nil + return value.(GBDevice), nil } return nil, fmt.Errorf("device with id %s was not find", id) } -func (s *deviceManager) AllDevices() []DBDevice { - var devices []DBDevice +func (s *deviceManager) AllDevices() []GBDevice { + var devices []GBDevice s.m.Range(func(key, value any) bool { - devices = append(devices, *value.(*DBDevice)) + devices = append(devices, value.(GBDevice)) return true }) diff --git a/live.go b/live.go index 8f360b8..08acc9e 100644 --- a/live.go +++ b/live.go @@ -20,13 +20,67 @@ const ( InviteTypeDownload = InviteType(2) ) -func (d *DBDevice) Invite(inviteType InviteType, streamId, channelId, ip string, port uint16, startTime, stopTime, setup string, speed int) (sip.Request, bool) { +func (i *InviteType) SessionName2Type(name string) { + switch name { + case "download": + *i = InviteTypeDownload + break + case "playback": + *i = InviteTypePlayback + break + //case "play": + default: + *i = InviteTypeLive + break + } +} + +func (d *Device) StartStream(inviteType InviteType, streamId StreamID, channelId, startTime, stopTime, setup string, speed int, sync bool) (*Stream, bool) { + stream := &Stream{ + ID: streamId, + forwardSinks: map[string]*Sink{}, + } + + // 先添加占位置, 防止重复请求 + if oldStream, b := StreamManager.Add(stream); !b { + return oldStream, true + } + + if dialog, ok := d.Invite(inviteType, streamId, channelId, startTime, stopTime, setup, speed); ok { + stream.DialogRequest = dialog + callID, _ := dialog.CallID() + StreamManager.AddWithCallId(callID.Value(), stream) + } else { + StreamManager.Remove(streamId) + return nil, false + } + + //开启收流超时 + wait := func() bool { + ok := stream.WaitForPublishEvent(10) + if !ok { + Sugar.Infof("收流超时 发送bye请求...") + CloseStream(streamId) + } + return ok + } + + if sync { + go wait() + } else if !sync && !wait() { + return nil, false + } + + return stream, true +} + +func (d *Device) Invite(inviteType InviteType, streamId StreamID, channelId, startTime, stopTime, setup string, speed int) (sip.Request, bool) { var ok bool var ssrc string defer func() { if !ok { - go CloseGBSource(streamId) + go CloseGBSource(string(streamId)) } }() @@ -37,7 +91,7 @@ func (d *DBDevice) Invite(inviteType InviteType, streamId, channelId, ip string, } ssrcValue, _ := strconv.Atoi(ssrc) - ip, port, err := CreateGBSource(streamId, setup, uint32(ssrcValue)) + ip, port, err := CreateGBSource(string(streamId), setup, uint32(ssrcValue)) if err != nil { Sugar.Errorf("创建GBSource失败 err:%s", err.Error()) return nil, false @@ -67,7 +121,7 @@ func (d *DBDevice) Invite(inviteType InviteType, streamId, channelId, ip string, } else if res.StatusCode() == 200 { answer = res.Body() ackRequest := sip.NewAckRequest("", inviteRequest, res, "", nil) - ackRequest.AppendHeader(globalContactAddress.AsContactHeader()) + ackRequest.AppendHeader(GlobalContactAddress.AsContactHeader()) //手动替换ack请求目标地址, answer的contact可能不对. recipient := ackRequest.Recipient() remoteIP, remotePortStr, _ := net.SplitHostPort(d.RemoteAddr) @@ -98,17 +152,14 @@ func (d *DBDevice) Invite(inviteType InviteType, streamId, channelId, ip string, if "active" == setup { parse, err := sdp.Parse(answer) - if err != nil { - ok = false - Sugar.Errorf("解析应答sdp失败 err:%s sdp:%s", err.Error(), answer) + ok = err == nil && parse.Video != nil && parse.Video.Port != 0 + if !ok { + Sugar.Errorf("解析应答sdp失败 err:%v sdp:%s", err, answer) return nil, false - } else if parse.Video == nil || parse.Video.Port == 0 { - ok = false - Sugar.Errorf("answer中没有视频连接地址 sdp:%s", answer) } addr := fmt.Sprintf("%s:%d", parse.Addr, parse.Video.Port) - if err = ConnectGBSource(streamId, addr); err != nil { + if err = ConnectGBSource(string(streamId), addr); err != nil { ok = false Sugar.Errorf("设置GB28181连接地址失败 err:%s addr:%s", err.Error(), addr) } @@ -117,15 +168,15 @@ func (d *DBDevice) Invite(inviteType InviteType, streamId, channelId, ip string, return dialogRequest, ok } -func (d *DBDevice) Live(streamId, channelId, setup string) (sip.Request, bool) { - return d.Invite(InviteTypeLive, streamId, channelId, "", 0, "", "", setup, 0) +func (d *Device) Live(streamId StreamID, channelId, setup string) (sip.Request, bool) { + return d.Invite(InviteTypeLive, streamId, channelId, "", "", setup, 0) } -func (d *DBDevice) Playback(streamId, channelId, startTime, stopTime, setup string) (sip.Request, bool) { - return d.Invite(InviteTypePlayback, streamId, channelId, "", 0, startTime, stopTime, setup, 0) +func (d *Device) Playback(streamId StreamID, channelId, startTime, stopTime, setup string) (sip.Request, bool) { + return d.Invite(InviteTypePlayback, streamId, channelId, startTime, stopTime, setup, 0) } -func (d *DBDevice) Download(streamId, channelId, startTime, stopTime, setup string, speed int) (sip.Request, bool) { - return d.Invite(InviteTypePlayback, streamId, channelId, "", 0, startTime, stopTime, setup, speed) +func (d *Device) Download(streamId StreamID, channelId, startTime, stopTime, setup string, speed int) (sip.Request, bool) { + return d.Invite(InviteTypePlayback, streamId, channelId, startTime, stopTime, setup, speed) } diff --git a/live_benchmark_test.go b/live_benchmark_test.go new file mode 100644 index 0000000..90ea45a --- /dev/null +++ b/live_benchmark_test.go @@ -0,0 +1,127 @@ +package main + +import ( + "bytes" + "encoding/json" + "io" + "net/http" + "testing" + "time" +) + +func request(url string, body []byte) (*http.Response, error) { + client := &http.Client{} + request, err := http.NewRequest("post", url, bytes.NewBuffer(body)) + if err != nil { + return nil, err + } + + request.Header.Set("Content-Type", "application/json") + return client.Do(request) +} + +func queryAllDevices() []Device { + response, err := request("http://localhost:9000/api/v1/device/list", nil) + if err != nil { + panic(err) + } + + all, err := io.ReadAll(response.Body) + if err != nil { + panic(err) + } + + v := struct { + Code int `json:"code"` + Msg string `json:"msg"` + Data []Device `json:"data,omitempty"` + }{} + + err = json.Unmarshal(all, &v) + if err != nil { + panic(err) + } + + return v.Data +} + +func startLive(deviceId, channelId, setup string) (bool, string) { + params := map[string]string{ + "device_id": deviceId, + "channel_id": channelId, + "setup": setup, + } + + requestBody, err := json.Marshal(params) + if err != nil { + panic(err) + } + + response, err := request("http://localhost:9000/api/v1/live/start", requestBody) + if err != nil { + panic(err) + } + + if response.StatusCode != 200 { + return false, "" + } + + all, err := io.ReadAll(response.Body) + if len(all) == 0 { + return true, "" + } + + v := struct { + Code int `json:"code"` + Msg string `json:"msg"` + Data map[string]string `json:"data,omitempty"` + }{} + + err = json.Unmarshal(all, &v) + if err != nil { + panic(err) + } + + return true, v.Data["stream_id"] +} + +func startLiveAll(setup string) { + devices := queryAllDevices() + if len(devices) == 0 { + return + } + + max := 50 + for _, device := range devices { + for _, channel := range device.Channels { + go startLive(device.ID, channel.DeviceID, setup) + max-- + if max < 1 { + return + } + } + } +} + +func TestLiveAll(t *testing.T) { + index := 0 + + for { + index++ + var setup string + + if index%1 == 0 { + setup = "udp" + } else if index%2 == 0 { + setup = "passive" + } else if index%3 == 0 { + setup = "active" + } else if index%4 == 0 { + //关闭所有流,再请求 + } + + go startLiveAll(setup) + + time.Sleep(60 * time.Second) + } +} diff --git a/main.go b/main.go index e8e95ee..3f4a28c 100644 --- a/main.go +++ b/main.go @@ -46,7 +46,7 @@ func main() { DeviceManager.Add(device) } - server, err := StartSipServer(config) + server, err := StartSipServer(config.SipId, config.ListenIP, config.PublicIP, config.SipPort) if err != nil { panic(err) } diff --git a/media_server.go b/media_server.go index 0c805d2..8de9045 100644 --- a/media_server.go +++ b/media_server.go @@ -85,3 +85,45 @@ func CloseGBSource(id string) error { _, err := Send("api/v1/gb28181/source/close", v) return err } + +func AddForwardStreamSink(id, serverAddr, setup string, ssrc uint32) (ip string, port uint16, sinkId string, err error) { + v := struct { + Source string `json:"source"` + Addr string `json:"addr"` + Setup string `json:"setup"` + SSRC uint32 `json:"ssrc"` + }{ + Source: id, + Addr: serverAddr, + Setup: setup, + SSRC: ssrc, + } + + response, err := Send("api/v1/gb28181/forward", v) + if err != nil { + return "", 0, "", err + } + + r := struct { + ID string `json:"id"` //sink id + IP string `json:"ip"` + Port uint16 `json:"port"` + }{} + + if err = DecodeJSONBody(response.Body, &r); err != nil { + return "", 0, "", err + } + + return r.IP, r.Port, r.ID, nil +} + +func CloseSink(sourceId string, sinkId string) { + v := struct { + SourceID string `json:"source"` + SinkID string `json:"sink"` // sink id + }{ + sourceId, sinkId, + } + + _, _ = Send("api/v1/sink/close", v) +} diff --git a/message_factory.go b/message_factory.go new file mode 100644 index 0000000..f5e9ba9 --- /dev/null +++ b/message_factory.go @@ -0,0 +1,108 @@ +package main + +import ( + "fmt" + "github.com/ghettovoice/gosip/sip" + "golang.org/x/text/encoding/simplifiedchinese" + "golang.org/x/text/transform" + "net" + "strconv" + "strings" +) + +const ( + XmlHeaderGBK = `` + "\r\n" +) + +func BuildSDP(userName, sessionName, ip string, port uint16, startTime, stopTime, setup string, speed int, ssrc string) string { + format := "v=0\r\n" + + "o=%s 0 0 IN IP4 %s\r\n" + + "s=%s\r\n" + + "c=IN IP4 %s\r\n" + + "t=%s %s\r\n" + + "m=video %d %s 96\r\n" + + "a=%s\r\n" + + "a=rtpmap:96 PS/90000\r\n" + + tcpFormat := "a=setup:%s\r\n" + + "a=connection:new\r\n" + + var tcp bool + var mediaProtocol string + if "active" == setup || "passive" == setup { + mediaProtocol = "TCP/RTP/AVP" + tcp = true + } else { + mediaProtocol = "RTP/AVP" + } + + sdp := fmt.Sprintf(format, userName, ip, sessionName, ip, startTime, stopTime, port, mediaProtocol, "recvonly") + if tcp { + sdp += fmt.Sprintf(tcpFormat, setup) + } + + if speed > 0 { + sdp += fmt.Sprintf("a=downloadspeed:%d\r\n", speed) + } + + sdp += fmt.Sprintf("y=%s\r\n", ssrc) + return sdp +} + +func NewSIPRequestBuilderWithTransport(transport string) *sip.RequestBuilder { + builder := sip.NewRequestBuilder() + hop := sip.ViaHop{ + Transport: transport, + } + + builder.AddVia(&hop) + return builder +} + +func BuildMessageRequest(from, fromRealm, to, toAddr, transport, body string) (sip.Request, error) { + if !strings.HasPrefix(body, "平台 + addrMap map[string]interface{} //上级地址->平台 + lock sync.RWMutex +} + +func (p *platformManager) AddPlatform(platform *GBPlatform) bool { + p.lock.Lock() + defer p.lock.Unlock() + + // 以上级平台ID作为主键 + if _, ok := p.addrMap[platform.sipClient.SeverId]; ok { + return false + } + + p.platforms[platform.sipClient.SeverId] = platform + p.addrMap[platform.sipClient.Domain] = platform + return true +} + +func (p *platformManager) ExistPlatform(id string) bool { + p.lock.RLock() + defer p.lock.RUnlock() + _, ok := p.platforms[id] + return ok +} + +func (p *platformManager) ExistPlatformWithServerAddr(addr string) bool { + p.lock.RLock() + defer p.lock.RUnlock() + _, ok := p.addrMap[addr] + return ok +} + +func (p *platformManager) FindPlatform(id string) *GBPlatform { + p.lock.RLock() + defer p.lock.RUnlock() + if platform, ok := p.platforms[id]; ok { + return platform.(*GBPlatform) + } + return nil +} + +func (p *platformManager) RemovePlatform(id string) *GBPlatform { + p.lock.Lock() + defer p.lock.Unlock() + + platform, ok := p.platforms[id] + if !ok { + return nil + } + + delete(p.platforms, id) + delete(p.addrMap, platform.(*GBPlatform).sipClient.Domain) + return platform.(*GBPlatform) +} + +func (p *platformManager) FindPlatformWithServerAddr(addr string) *GBPlatform { + p.lock.RLock() + defer p.lock.RUnlock() + if platform, ok := p.addrMap[addr]; ok { + return platform.(*GBPlatform) + } + return nil +} + +func (p *platformManager) Platforms() []*GBPlatform { + p.lock.RLock() + defer p.lock.RUnlock() + + var platforms []*GBPlatform + for _, platform := range p.platforms { + platforms = append(platforms, platform.(*GBPlatform)) + } + return platforms +} diff --git a/position.go b/position.go index 8abda23..197d6d8 100644 --- a/position.go +++ b/position.go @@ -25,9 +25,9 @@ type MobilePositionNotify struct { Altitude string `xml:"Altitude"` } -func (d *DBDevice) DoSubscribePosition(channelId string) error { +func (d *Device) DoSubscribePosition(channelId string) error { if channelId == "" { - channelId = d.Id + channelId = d.ID } //暂时不考虑级联 @@ -37,7 +37,7 @@ func (d *DBDevice) DoSubscribePosition(channelId string) error { expiresHeader := sip.Expires(Config.MobilePositionExpires) builder.SetExpires(&expiresHeader) builder.SetContentType(&XmlMessageType) - builder.SetContact(globalContactAddress) + builder.SetContact(GlobalContactAddress) builder.SetBody(body) request, err := builder.Build() @@ -59,6 +59,6 @@ func (d *DBDevice) DoSubscribePosition(channelId string) error { return nil } -func (d *DBDevice) OnMobilePositionNotify(notify *MobilePositionNotify) { - Sugar.Infof("收到位置信息 device:%s data:%v", d.Id, notify) +func (d *Device) OnMobilePositionNotify(notify *MobilePositionNotify) { + Sugar.Infof("收到位置信息 device:%s data:%v", d.ID, notify) } diff --git a/reasons.go b/reasons.go new file mode 100644 index 0000000..791672c --- /dev/null +++ b/reasons.go @@ -0,0 +1,68 @@ +package main + +var reasons map[int]string + +func init() { + reasons = map[int]string{ + 100: "Trying", + 180: "Ringing", + 181: "Call Is Being Forwarded", + 182: "Queued", + 183: "Session Progress", + 200: "OK", + 202: "Accepted", + 300: "Multiple Choices", + 301: "Moved Permanently", + 302: "Moved Temporarily", + 305: "Use Proxy", + 380: "Alternative Service", + 400: "Bad Request", + 401: "Unauthorized", + 402: "Payment Required", + 403: "Forbidden", + 404: "Not Found", + 405: "Method Not Allowed", + 406: "Not Acceptable", + 407: "Proxy Authentication Required", + 408: "Request Timeout", + 410: "Gone", + 413: "Request Entity Too Large", + 414: "Request-URI Too Long", + 415: "Unsupported Media Type", + 416: "Unsupported URI Scheme", + 420: "Bad Extension", + 421: "Extension Required", + 423: "Interval Too Brief", + 480: "Temporarily Unavailable", + 481: "Call transaction Does Not Exist", + 482: "Loop Detected", + 483: "Too Many Hops", + 484: "Address Incomplete", + 485: "Ambiguous", + 486: "Busy Here", + 487: "Request Terminated", + 488: "Not Acceptable Here", + 489: "Bad Event", + 491: "Request Pending", + 493: "Undecipherable", + 500: "Server Internal Error", + 501: "Not Implemented", + 502: "Bad Gateway", + 503: "Service Unavailable", + 504: "Server Tim", + 505: "Version Not Supported", + 513: "message Too Large", + 600: "Busy Everywhere", + 603: "Decline", + 604: "Does Not Exist Anywhere", + 606: "SESSION NOT ACCEPTABLE", + } +} + +func StatusCode2Reason(code int) string { + if s, ok := reasons[code]; ok { + return s + } + + return "unknown reason" +} diff --git a/sdp/sdp_test.go b/sdp/sdp_test.go index 8202c95..eac5818 100644 --- a/sdp/sdp_test.go +++ b/sdp/sdp_test.go @@ -383,7 +383,7 @@ func TestParse(t *testing.T) { "o=34020099991320000015 2950 2950 IN IP4 192.168.1.64\r\n" + "s=Play\r\n" + "c=IN IP4 192.168.1.64\r\n" + - "t=0 0\r\n" + + "t=2 3\r\n" + "m=audio 15066 RTP/AVP 8 96\r\n" + "a=recvonly\r\n" + "a=rtpmap:8 PCMA/8000\r\n" + @@ -410,7 +410,7 @@ func TestParse(t *testing.T) { sdp.Origin.User) } if test.sdp.Origin.ID != sdp.Origin.ID { - t.Error(test.name, "Origin.ID doesn't match") + t.Error(test.name, "Origin.Username doesn't match") } if test.sdp.Origin.Version != sdp.Origin.Version { t.Error(test.name, "Origin.Version doesn't match") diff --git a/sdp/util.go b/sdp/util.go index c0118a1..08961bc 100644 --- a/sdp/util.go +++ b/sdp/util.go @@ -49,7 +49,7 @@ func GenerateTag() string { return hex.EncodeToString(randomBytes(6)) } -// Generate a SIP 2.0 Via branch ID. This is probably not suitable for use by +// Generate a SIP 2.0 Via branch Username. This is probably not suitable for use by // stateless proxies. func GenerateBranch() string { return "z9hG4bK-" + GenerateTag() @@ -67,7 +67,7 @@ func GenerateCallID() string { return uuid4 } -// Generate a random ID for an SDP. +// Generate a random Username for an SDP. func GenerateOriginID() string { return strconv.FormatUint(uint64(rand.Uint32()), 10) } diff --git a/sip_client.go b/sip_client.go new file mode 100644 index 0000000..83195ca --- /dev/null +++ b/sip_client.go @@ -0,0 +1,294 @@ +package main + +import ( + "context" + "fmt" + "github.com/ghettovoice/gosip/sip" + "github.com/lkmio/avformat/utils" + "math" + "net" + "strconv" + "time" +) + +const ( + KeepAliveBody = "\r\n" + + "\r\n" + + "Keepalive\r\n" + + "%d\r\n" + + "%s\r\n" + + "OK\r\n" + + "\r\n" +) + +var ( + UnregisterExpiresHeader = sip.Expires(0) +) + +type SipClient interface { + doRegister(request sip.Request) bool + + doUnregister() + + doKeepalive() bool + + Start() + + Stop() + + SetOnRegisterHandler(online, offline func()) +} + +type sipClient struct { + Username string + Domain string //注册域 + Transport string + Password string //密码 + RegisterExpires int //注册有效期 + KeeAliveInterval int //心跳间隔 + SeverId string //上级ID + + ListenAddr string //UA的监听地址 + NatAddr string //Nat地址 + + ua SipServer + existed bool + ctx context.Context + cancel context.CancelFunc + keepaliveFailedCount int + registerOK bool + registerOKTime time.Time //注册成功时间 + registerOKRequest sip.Request + + onlineCB func() + offlineCB func() +} + +func (g *sipClient) doRegister(request sip.Request) bool { + hop, _ := request.ViaHop() + empty := sip.String{} + hop.Params.Add("rport", &empty) + hop.Params.Add("received", &empty) + + for i := 0; i < 2; i++ { + //发起注册, 第一次未携带授权头, 第二次携带授权头 + clientTransaction := g.ua.SendRequest(request) + + //等待响应 + responses := clientTransaction.Responses() + var response sip.Response + select { + case response = <-responses: + break + case <-g.ctx.Done(): + break + } + + if response == nil { + break + } else if response.StatusCode() == 200 { + g.registerOKRequest = request.Clone().(sip.Request) + viaHop, _ := response.ViaHop() + rport, _ := viaHop.Params.Get("rport") + received, _ := viaHop.Params.Get("received") + if rport != nil && received != nil { + g.NatAddr = net.JoinHostPort(received.String(), rport.String()) + } + return true + } else if response.StatusCode() == 401 || response.StatusCode() == 407 { + authorizer := sip.DefaultAuthorizer{Password: sip.String{Str: g.Password}, User: sip.String{Str: g.Username}} + if err := authorizer.AuthorizeRequest(request, response); err != nil { + break + } + } else { + break + } + } + + return false +} + +func (g *sipClient) startNewRegister() bool { + builder := NewRequestBuilder(sip.REGISTER, g.Username, g.ListenAddr, g.SeverId, g.Domain, g.Transport) + expires := sip.Expires(g.RegisterExpires) + builder.SetExpires(&expires) + + host, p, _ := net.SplitHostPort(g.ListenAddr) + port, _ := strconv.Atoi(p) + sipPort := sip.Port(port) + builder.SetTo(&sip.Address{ + Uri: &sip.SipUri{ + FUser: sip.String{Str: g.Username}, + FHost: host, + FPort: &sipPort, + }, + }) + + request, err := builder.Build() + if err != nil { + panic(err) + } + + if ok := g.doRegister(request); ok { + g.registerOKRequest = request + return ok + } + + return false +} + +func CopySipRequest(old sip.Request) sip.Request { + //累加cseq number + cseq, _ := old.CSeq() + cseq.SeqNo++ + + request := old.Clone().(sip.Request) + //清空事务标记 + hop, _ := request.ViaHop() + hop.Params.Remove("branch") + return request +} + +func (g *sipClient) refreshRegister() bool { + request := CopySipRequest(g.registerOKRequest) + return g.doRegister(request) +} + +func (g *sipClient) doUnregister() { + request := CopySipRequest(g.registerOKRequest) + request.RemoveHeader("Expires") + request.AppendHeader(&UnregisterExpiresHeader) + g.ua.SendRequest(request) + + if g.offlineCB != nil { + go g.offlineCB() + } +} + +func (g *sipClient) doKeepalive() bool { + body := fmt.Sprintf(KeepAliveBody, time.Now().UnixMilli()/1000, g.Username) + request, err := BuildMessageRequest(g.Username, g.ListenAddr, g.SeverId, g.Domain, g.Transport, body) + if err != nil { + panic(err) + } + + transaction := g.ua.SendRequest(request) + responses := transaction.Responses() + + var response sip.Response + select { + case response = <-responses: + break + case <-g.ctx.Done(): + break + } + + return response != nil && response.StatusCode() == 200 +} + +// IsExpires 是否临近注册有效期 +func (g *sipClient) IsExpires() (bool, int) { + if !g.registerOK { + return false, 0 + } + + dis := g.RegisterExpires - int(time.Now().Sub(g.registerOKTime).Seconds()) + return dis <= 10, dis - 10 +} + +// Refresh 处理Client的生命周期任务, 发起注册, 发送心跳,断开重连等, 并返回下次刷新任务时间 +func (g *sipClient) Refresh() time.Duration { + expires, _ := g.IsExpires() + + if !g.registerOK || expires { + if expires { + g.registerOK = g.refreshRegister() + } else { + g.registerOK = g.startNewRegister() + } + + if g.registerOK { + g.registerOKTime = time.Now() + if g.onlineCB != nil { + go g.onlineCB() + } + } + } + + // 注册失败后, 等待10秒钟再发起注册 + if !g.registerOK { + return 10 * time.Second + } + + // 发送心跳 + if !g.doKeepalive() { + g.keepaliveFailedCount++ + } else { + g.keepaliveFailedCount = 0 + } + + // 心跳失败超过三次, 重新发起注册 + if g.keepaliveFailedCount > 2 { + g.keepaliveFailedCount = 0 + g.registerOK = false + g.registerOKRequest = nil + g.NatAddr = "" + + if g.offlineCB != nil { + go g.offlineCB() + } + + // 立马发起注册 + return 0 + } + + // 信令正常, 休眠心跳间隔时长 + return time.Duration(g.KeeAliveInterval) * time.Second +} + +func (g *sipClient) Start() { + utils.Assert(!g.existed) + g.ctx, g.cancel = context.WithCancel(context.Background()) + + go func() { + for !g.existed { + duration := g.Refresh() + expires, dis := g.IsExpires() + if duration < time.Second || expires { + continue + } else if g.registerOK { + duration = time.Duration(int(math.Min(duration.Seconds(), float64(dis)))) * time.Second + } + + if g.existed { + return + } + + select { + case <-time.After(duration): + break + case <-g.ctx.Done(): + break + } + } + }() +} + +func (g *sipClient) Stop() { + utils.Assert(!g.existed) + + g.existed = true + g.cancel() + g.registerOK = false + g.onlineCB = nil + g.offlineCB = nil + + if g.registerOK { + g.doUnregister() + } +} + +func (g *sipClient) SetOnRegisterHandler(online, offline func()) { + g.onlineCB = online + g.offlineCB = offline +} diff --git a/sip_server.go b/sip_server.go index dc3bb67..fc8ec5c 100644 --- a/sip_server.go +++ b/sip_server.go @@ -2,12 +2,15 @@ package main import ( "context" + "fmt" "github.com/ghettovoice/gosip" "github.com/ghettovoice/gosip/log" "github.com/ghettovoice/gosip/sip" "github.com/ghettovoice/gosip/util" + "github.com/lkmio/avformat/utils" "net" "net/http" + "reflect" "strconv" "strings" "time" @@ -15,12 +18,24 @@ import ( var ( logger log.Logger - globalContactAddress *sip.Address + GlobalContactAddress *sip.Address ) const ( CmdTagStart = "" CmdTagEnd = "" + + XmlNameControl = "Control" + XmlNameQuery = "Query" //主动查询消息 + XmlNameNotify = "Notify" //订阅产生的通知消息 + XmlNameResponse = "Response" //响应消息 + + CmdDeviceInfo = "DeviceInfo" + CmdDeviceStatus = "DeviceStatus" + CmdCatalog = "Catalog" + CmdRecordInfo = "RecordInfo" + CmdMobilePosition = "MobilePosition" + CmdKeepalive = "Keepalive" ) func init() { @@ -28,44 +43,37 @@ func init() { } type SipServer interface { - OnRegister(req sip.Request, tx sip.ServerTransaction) - - OnInvite(req sip.Request, tx sip.ServerTransaction) - - OnAck(req sip.Request, tx sip.ServerTransaction) - - OnBye(req sip.Request, tx sip.ServerTransaction) - - OnNotify(req sip.Request, tx sip.ServerTransaction) - SendRequestWithContext(ctx context.Context, request sip.Request, options ...gosip.RequestWithContextOption) - SendRequest(request sip.Request) + SendRequest(request sip.Request) sip.ClientTransaction SendRequestWithTimeout(seconds int, request sip.Request, options ...gosip.RequestWithContextOption) (sip.Response, error) Send(msg sip.Message) error + + ListenAddr() string } type sipServer struct { - sip gosip.Server - config *Config_ + sip gosip.Server + listenAddr string + xmlReflectTypes map[string]reflect.Type } func (s *sipServer) Send(msg sip.Message) error { return s.sip.Send(msg) } -func setToTag(response sip.Message, toTag string) { +func setToTag(response sip.Message) { toHeader := response.GetHeaders("To") to := toHeader[0].(*sip.ToHeader) - to.Params = sip.NewParams().Add("tag", sip.String{Str: toTag}) + to.Params = sip.NewParams().Add("tag", sip.String{Str: util.RandString(10)}) } -func (s *sipServer) OnRegister(req sip.Request, tx sip.ServerTransaction) { - var device *DBDevice +func (s *sipServer) OnRegister(req sip.Request, tx sip.ServerTransaction, parent bool) { + var device *Device var query bool _ = req.GetHeaders("Authorization") fromHeader := req.GetHeaders("From")[0].(*sip.FromHeader) @@ -82,8 +90,8 @@ func (s *sipServer) OnRegister(req sip.Request, tx sip.ServerTransaction) { //sip.NewResponseFromRequest("", req, 401, "Unauthorized", "") - device = &DBDevice{ - Id: fromHeader.Address.User().String(), + device = &Device{ + ID: fromHeader.Address.User().String(), Transport: req.Transport(), RemoteAddr: req.Source(), } @@ -92,96 +100,85 @@ func (s *sipServer) OnRegister(req sip.Request, tx sip.ServerTransaction) { query = err != nil || b } - sendResponse(tx, response) + SendResponse(tx, response) if device != nil && query { - catalog, err := device.BuildCatalogRequest() - if err != nil { - panic(err) - } - - s.SendRequest(catalog) + device.QueryCatalog() } } -func (s *sipServer) OnInvite(req sip.Request, tx sip.ServerTransaction) { - sendResponse(tx, sip.NewResponseFromRequest("", req, 100, "Trying", "")) - - var response sip.Response - var session *BroadcastSession +// OnInvite 上级预览/下级广播 +func (s *sipServer) OnInvite(req sip.Request, tx sip.ServerTransaction, parent bool) { + SendResponse(tx, sip.NewResponseFromRequest("", req, 100, "Trying", "")) user := req.Recipient().User().String() - exist := false - defer func() { - if !exist { - response = sip.NewResponseFromRequest("", req, 404, http.StatusText(404), "") - } - - sendResponse(tx, response) - if session != nil { - session.Answer <- 0 - } - }() if len(user) != 20 { + SendResponseWithStatusCode(req, tx, http.StatusNotFound) return } - roomId := user[:10] - room := BroadcastManager.FindRoom(roomId) - if room == nil { - return + // 查找对应的设备 + var device GBDevice + if parent { + // 级联设备 + device = PlatformManager.FindPlatformWithServerAddr(req.Source()) + } else if session := FindBroadcastSessionWithSourceID(user); session != nil { + // 语音广播设备 + device = DeviceManager.Find(session.DeviceID) + } else { + // 根据Subject头域查找设备 + headers := req.GetHeaders("Subject") + if len(headers) > 0 { + subject := headers[0].(*sip.GenericHeader) + split := strings.Split(strings.Split(subject.Value(), ",")[0], ":") + if len(split) > 1 { + device = DeviceManager.Find(split[1]) + } + } } - session = room.Find(user) - if session == nil { - return - } - - device := DeviceManager.Find(session.DeviceID) if device == nil { - return - } - - exist = true - code, sdp := device.OnInviteBroadcast(req, session) - response = sip.NewResponseFromRequest("", req, sip.StatusCode(code), http.StatusText(code), "") - - if code >= 200 && code < 300 { - toTag := util.RandString(10) - setToTag(response, toTag) - - session.Successful = true - session.ByeRequest = device.CreateDialogRequestFromAnswer(response, true) - - id, _ := req.CallID() - BroadcastManager.AddSessionWithCallId(id.Value(), session) - - response.SetBody(sdp, true) - response.AppendHeader(&SDPMessageType) - response.AppendHeader(globalContactAddress.AsContactHeader()) + SendResponseWithStatusCode(req, tx, http.StatusNotFound) + } else { + response := device.OnInvite(req, user) + SendResponse(tx, response) } } -func (s *sipServer) OnAck(req sip.Request, tx sip.ServerTransaction) { +func (s *sipServer) OnAck(req sip.Request, tx sip.ServerTransaction, parent bool) { } -func (s *sipServer) OnBye(req sip.Request, tx sip.ServerTransaction) { +func (s *sipServer) OnBye(req sip.Request, tx sip.ServerTransaction, parent bool) { response := sip.NewResponseFromRequest("", req, 200, "OK", "") - sendResponse(tx, response) + SendResponse(tx, response) id, _ := req.CallID() + var deviceId string - if stream, err := StreamManager.RemoveWithCallId(id.Value()); err == nil { + if stream := StreamManager.RemoveWithCallId(id.Value()); stream != nil { + // 下级设备挂断, 关闭流 + deviceId = stream.ID.DeviceID() stream.Close(false) } else if session := BroadcastManager.RemoveWithCallId(id.Value()); session != nil { + // 广播挂断 + deviceId = session.DeviceID session.Close(false) } + + if parent { + // 上级设备挂断 + if platform := PlatformManager.FindPlatformWithServerAddr(req.Source()); platform != nil { + platform.OnBye(req) + } + } else if device := DeviceManager.Find(deviceId); device != nil { + device.OnBye(req) + } } -func (s *sipServer) OnNotify(req sip.Request, tx sip.ServerTransaction) { +func (s *sipServer) OnNotify(req sip.Request, tx sip.ServerTransaction, parent bool) { response := sip.NewResponseFromRequest("", req, 200, "OK", "") - sendResponse(tx, response) + SendResponse(tx, response) mobilePosition := MobilePositionNotify{} if err := DecodeXML([]byte(req.Body()), &mobilePosition); err != nil { @@ -190,13 +187,94 @@ func (s *sipServer) OnNotify(req sip.Request, tx sip.ServerTransaction) { } if device := DeviceManager.Find(mobilePosition.DeviceID); device != nil { - device.OnMobilePositionNotify(&mobilePosition) + device.OnNotifyPosition(&mobilePosition) } } -func sendResponse(tx sip.ServerTransaction, response sip.Response) bool { - sendError := tx.Respond(response) +func (s *sipServer) OnMessage(req sip.Request, tx sip.ServerTransaction, parent bool) { + var online bool + defer func() { + var response sip.Response + if online { + response = CreateResponseWithStatusCode(req, http.StatusOK) + } else { + response = CreateResponseWithStatusCode(req, http.StatusForbidden) + } + + SendResponse(tx, response) + }() + + body := req.Body() + xmlName := GetRootElementName(body) + cmd := GetCmdType(body) + src, ok := s.xmlReflectTypes[xmlName+"."+cmd] + if !ok { + return + } + + message := reflect.New(src).Interface() + if err := DecodeXML([]byte(body), message); err != nil { + Sugar.Errorf("解析xml异常 >>> %s %s", err.Error(), body) + return + } + + // 查找设备 + var device GBDevice + deviceId := message.(BaseMessageGetter).GetDeviceID() + if parent { + device = PlatformManager.FindPlatformWithServerAddr(req.Source()) + } else { + device = DeviceManager.Find(deviceId) + } + + if online = device != nil; !online { + Sugar.Errorf("处理Msg失败 设备离线: %s Msg: %s", deviceId, body) + return + } + + switch xmlName { + case XmlNameControl: + break + case XmlNameQuery: + client, ok := device.(GBClient) + if !ok { + online = false + return + } + + if CmdDeviceInfo == cmd { + client.OnQueryDeviceInfo(message.(*BaseMessage).SN) + } else if CmdCatalog == cmd { + client.OnQueryCatalog(message.(*BaseMessage).SN) + } + break + case XmlNameNotify: + if CmdKeepalive == cmd { + device.OnKeepalive() + } + break + case XmlNameResponse: + if CmdCatalog == cmd { + device.OnCatalog(message.(*CatalogResponse)) + } else if CmdRecordInfo == cmd { + device.OnRecord(message.(*QueryRecordInfoResponse)) + } + break + } +} + +func CreateResponseWithStatusCode(request sip.Request, code int) sip.Response { + return sip.NewResponseFromRequest("", request, sip.StatusCode(code), StatusCode2Reason(code), "") +} + +func SendResponseWithStatusCode(request sip.Request, tx sip.ServerTransaction, code int) { + SendResponse(tx, CreateResponseWithStatusCode(request, code)) +} + +func SendResponse(tx sip.ServerTransaction, response sip.Response) bool { Sugar.Infof("send response >>> %s", response.String()) + sendError := tx.Respond(response) + if sendError != nil { Sugar.Infof("send response error %s %s", sendError.Error(), response.String()) } @@ -205,127 +283,105 @@ func sendResponse(tx sip.ServerTransaction, response sip.Response) bool { } func (s *sipServer) SendRequestWithContext(ctx context.Context, request sip.Request, options ...gosip.RequestWithContextOption) { - Sugar.Infof("send reqeust:%s", request.String()) + Sugar.Infof("send reqeust: %s", request.String()) s.sip.RequestWithContext(ctx, request, options...) } func (s *sipServer) SendRequestWithTimeout(seconds int, request sip.Request, options ...gosip.RequestWithContextOption) (sip.Response, error) { - Sugar.Infof("send reqeust:%s", request.String()) + Sugar.Infof("send reqeust: %s", request.String()) reqCtx, _ := context.WithTimeout(context.Background(), time.Duration(seconds)*time.Second) return s.sip.RequestWithContext(reqCtx, request, options...) } -func (s *sipServer) SendRequest(request sip.Request) { - Sugar.Infof("send reqeust:%s", request.String()) - - clientTransaction, err := s.sip.Request(request) +func (s *sipServer) SendRequest(request sip.Request) sip.ClientTransaction { + Sugar.Infof("send reqeust: %s", request.String()) + transaction, err := s.sip.Request(request) if err != nil { panic(err) } - defer func() { - if clientTransaction != nil { - err = clientTransaction.Cancel() - } - }() + return transaction } -func StartSipServer(config *Config_) (SipServer, error) { - server := gosip.NewServer(gosip.ServerConfig{ - Host: config.PublicIP, +func (s *sipServer) ListenAddr() string { + return s.listenAddr +} + +// 过滤SIP消息、超找消息来源 +func filterRequest(f func(req sip.Request, tx sip.ServerTransaction, parent bool)) gosip.RequestHandler { + return func(req sip.Request, tx sip.ServerTransaction) { + Sugar.Infof("process request: %s", req.String()) + + source := req.Source() + platform := PlatformManager.FindPlatformWithServerAddr(source) + switch req.Method() { + case sip.SUBSCRIBE, sip.INFO: + if platform == nil { + // SUBSCRIBE/INFO只能上级发起 + SendResponseWithStatusCode(req, tx, http.StatusBadRequest) + return + } + break + case sip.NOTIFY, sip.REGISTER: + if platform != nil { + // NOTIFY和REGISTER只能下级发起 + SendResponseWithStatusCode(req, tx, http.StatusBadRequest) + return + } + break + } + + f(req, tx, platform != nil) + } +} + +func StartSipServer(id, listenIP, publicIP string, listenPort int) (SipServer, error) { + ua := gosip.NewServer(gosip.ServerConfig{ + Host: publicIP, }, nil, nil, logger) - addr := net.JoinHostPort(config.ListenIP, strconv.Itoa(config.SipPort)) - if err := server.Listen("udp", addr); err != nil { + addr := net.JoinHostPort(listenIP, strconv.Itoa(listenPort)) + if err := ua.Listen("udp", addr); err != nil { return nil, err - } else if err := server.Listen("tcp", addr); err != nil { + } else if err := ua.Listen("tcp", addr); err != nil { return nil, err } - s := &sipServer{sip: server} - server.OnRequest(sip.REGISTER, s.OnRegister) - server.OnRequest(sip.INVITE, s.OnInvite) - server.OnRequest(sip.BYE, s.OnBye) - server.OnRequest(sip.ACK, s.OnAck) - server.OnRequest(sip.NOTIFY, s.OnNotify) - server.OnRequest(sip.MESSAGE, func(req sip.Request, tx sip.ServerTransaction) { - online := true - defer func() { - var response sip.Response - if online { - response = sip.NewResponseFromRequest("", req, 200, "OK", "") - } else { - response = sip.NewResponseFromRequest("", req, 403, "OK", "") - } + server := &sipServer{sip: ua, xmlReflectTypes: map[string]reflect.Type{ + fmt.Sprintf("%s.%s", XmlNameQuery, CmdCatalog): reflect.TypeOf(BaseMessage{}), + fmt.Sprintf("%s.%s", XmlNameQuery, CmdDeviceInfo): reflect.TypeOf(BaseMessage{}), + fmt.Sprintf("%s.%s", XmlNameQuery, CmdDeviceStatus): reflect.TypeOf(BaseMessage{}), + fmt.Sprintf("%s.%s", XmlNameResponse, CmdCatalog): reflect.TypeOf(CatalogResponse{}), + fmt.Sprintf("%s.%s", XmlNameResponse, CmdDeviceInfo): reflect.TypeOf(DeviceInfoResponse{}), + fmt.Sprintf("%s.%s", XmlNameResponse, CmdDeviceStatus): reflect.TypeOf(DeviceStatusResponse{}), + fmt.Sprintf("%s.%s", XmlNameResponse, CmdRecordInfo): reflect.TypeOf(QueryRecordInfoResponse{}), + fmt.Sprintf("%s.%s", XmlNameNotify, CmdKeepalive): reflect.TypeOf(BaseMessage{}), + fmt.Sprintf("%s.%s", XmlNameNotify, CmdMobilePosition): reflect.TypeOf(BaseMessage{}), + }} - sendResponse(tx, response) - }() + utils.Assert(ua.OnRequest(sip.REGISTER, filterRequest(server.OnRegister)) == nil) + utils.Assert(ua.OnRequest(sip.INVITE, filterRequest(server.OnInvite)) == nil) + utils.Assert(ua.OnRequest(sip.BYE, filterRequest(server.OnBye)) == nil) + utils.Assert(ua.OnRequest(sip.ACK, filterRequest(server.OnAck)) == nil) + utils.Assert(ua.OnRequest(sip.NOTIFY, filterRequest(server.OnNotify)) == nil) + utils.Assert(ua.OnRequest(sip.MESSAGE, filterRequest(server.OnMessage)) == nil) - body := req.Body() - startIndex := strings.Index(body, CmdTagStart) - endIndex := strings.Index(body, CmdTagEnd) - if startIndex <= 0 || endIndex <= 0 || endIndex+len(CmdTagStart) <= startIndex { - Sugar.Warnf("未知消息 %s", body) - return - } + utils.Assert(ua.OnRequest(sip.INFO, filterRequest(func(req sip.Request, tx sip.ServerTransaction, parent bool) { + })) == nil) + utils.Assert(ua.OnRequest(sip.CANCEL, filterRequest(func(req sip.Request, tx sip.ServerTransaction, parent bool) { + })) == nil) + utils.Assert(ua.OnRequest(sip.SUBSCRIBE, filterRequest(func(req sip.Request, tx sip.ServerTransaction, parent bool) { + })) == nil) - cmd := strings.ToLower(body[startIndex+len(CmdTagStart) : endIndex]) - var message interface{} - if "keepalive" == cmd { - return - } else if "catalog" == cmd { - message = &QueryCatalogResponse{} - } else if "recordinfo" == cmd { - message = &QueryRecordInfoResponse{} - } else if "mediastatus" == cmd { - return - } - - if err := DecodeXML([]byte(body), message); err != nil { - Sugar.Errorf("解析xml异常 >>> %s %s", err.Error(), body) - return - } - - switch cmd { - case "catalog": - { - if device := DeviceManager.Find(message.(*QueryCatalogResponse).DeviceID); device != nil { - device.OnCatalog(message.(*QueryCatalogResponse)) - } - } - break - - case "recordinfo": - { - if device := DeviceManager.Find(message.(*QueryRecordInfoResponse).DeviceID); device != nil { - device.OnRecord(message.(*QueryRecordInfoResponse)) - } - } - break - - case "keepalive": - { - device := DeviceManager.Find(message.(*QueryCatalogResponse).DeviceID) - if device != nil { - DB.KeepAliveDevice(device) - } - - online = device != nil - } - break - } - }) - - s.config = config - port := sip.Port(Config.SipPort) - - globalContactAddress = &sip.Address{ + server.listenAddr = addr + port := sip.Port(listenPort) + GlobalContactAddress = &sip.Address{ Uri: &sip.SipUri{ - FUser: sip.String{Str: config.SipId}, - FHost: config.PublicIP, + FUser: sip.String{Str: id}, + FHost: publicIP, FPort: &port, }, } - return s, nil + return server, nil } diff --git a/stream.go b/stream.go index e7d3fdd..bdb0926 100644 --- a/stream.go +++ b/stream.go @@ -3,19 +3,59 @@ package main import ( "context" "github.com/ghettovoice/gosip/sip" + "sync" "sync/atomic" "time" ) -type Stream struct { - Id string //推流ID - Protocol string //推流协议 - DialogRequest sip.Request - StreamType InviteType +// Sink 级联转发 +type Sink struct { + id string + deviceID string + dialog sip.Request +} - sinkCount int32 +// Stream 国标推流源 +type Stream struct { + ID StreamID // 推流ID + DialogRequest sip.Request + + sinkCount int32 // 拉流数量+级联转发数量 publishEvent chan byte cancelFunc func() + + forwardSinks map[string]*Sink // 级联转发Sink, Key为与上级的CallID + lock sync.RWMutex +} + +func (s *Stream) AddForwardSink(id string, sink *Sink) { + s.lock.Lock() + defer s.lock.Unlock() + s.forwardSinks[id] = sink +} + +func (s *Stream) RemoveForwardSink(id string) *Sink { + s.lock.Lock() + defer s.lock.Unlock() + + sink, ok := s.forwardSinks[id] + if ok { + delete(s.forwardSinks, id) + } + + return sink +} + +func (s *Stream) AllForwardSink() []*Sink { + s.lock.Lock() + defer s.lock.Unlock() + + var sinks []*Sink + for _, sink := range s.forwardSinks { + sinks = append(sinks, sink) + } + + return sinks } func (s *Stream) WaitForPublishEvent(seconds int) bool { @@ -49,23 +89,36 @@ func (s *Stream) Close(sendBye bool) { s.cancelFunc() } + // 断开与下级的会话 if sendBye && s.DialogRequest != nil { SipUA.SendRequest(s.CreateRequestFromDialog(sip.BYE)) s.DialogRequest = nil } - go CloseGBSource(s.Id) + go CloseGBSource(string(s.ID)) + + // 关闭所有级联会话 + sinks := s.AllForwardSink() + for _, sink := range sinks { + platform := PlatformManager.FindPlatform(sink.deviceID) + id, _ := sink.dialog.CallID() + platform.CloseStream(id.Value(), true, true) + } } -func (s *Stream) CreateRequestFromDialog(method sip.RequestMethod) sip.Request { +func CreateRequestFromDialog(dialog sip.Request, method sip.RequestMethod) sip.Request { { - seq, _ := s.DialogRequest.CSeq() + seq, _ := dialog.CSeq() seq.SeqNo++ seq.MethodName = method } - request := s.DialogRequest.Clone().(sip.Request) + request := dialog.Clone().(sip.Request) request.SetMethod(method) - + request.SetSource("") return request } + +func (s *Stream) CreateRequestFromDialog(method sip.RequestMethod) sip.Request { + return CreateRequestFromDialog(s.DialogRequest, method) +} diff --git a/stream_id.go b/stream_id.go new file mode 100644 index 0000000..e9ba910 --- /dev/null +++ b/stream_id.go @@ -0,0 +1,34 @@ +package main + +import ( + "github.com/lkmio/avformat/utils" + "strings" +) + +type StreamID string + +func (s StreamID) DeviceID() string { + return strings.Split(string(s), "/")[0] +} + +func (s StreamID) ChannelID() string { + return strings.Split(strings.Split(string(s), "/")[1], ".")[0] +} + +func GenerateStreamId(inviteType InviteType, deviceId, channelId string, startTime, endTime string) StreamID { + utils.Assert(channelId != "") + + var streamId []string + if deviceId != "" { + streamId = append(streamId, deviceId) + } + + streamId = append(streamId, channelId) + if InviteTypePlayback == inviteType { + return StreamID(strings.Join(streamId, "/") + ".playback" + "." + startTime + "." + endTime) + } else if InviteTypeDownload == inviteType { + return StreamID(strings.Join(streamId, "/") + ".download" + "." + startTime + "." + endTime) + } + + return StreamID(strings.Join(streamId, "/")) +} diff --git a/stream_manager.go b/stream_manager.go index e3ea56f..1de4820 100644 --- a/stream_manager.go +++ b/stream_manager.go @@ -1,22 +1,18 @@ package main import ( - "fmt" "sync" ) var StreamManager *streamManager func init() { - StreamManager = &streamManager{ - streams: make(map[string]*Stream, 64), - callIds: make(map[string]*Stream, 64), - } + StreamManager = NewStreamManager() } type streamManager struct { - streams map[string]*Stream - callIds map[string]*Stream + streams map[StreamID]*Stream + callIds map[string]*Stream // 本SipUA的CallID与Stream的关系 lock sync.RWMutex } @@ -26,29 +22,28 @@ func (s *streamManager) Add(stream *Stream) (*Stream, bool) { s.lock.Lock() defer s.lock.Unlock() - old, ok := s.streams[stream.Id] + old, ok := s.streams[stream.ID] if ok { return old, false } - s.streams[stream.Id] = stream + s.streams[stream.ID] = stream return nil, true } -func (s *streamManager) AddWithCallId(stream *Stream) error { +func (s *streamManager) AddWithCallId(id string, stream *Stream) bool { s.lock.Lock() defer s.lock.Unlock() - id, _ := stream.DialogRequest.CallID() - if _, ok := s.callIds[id.Value()]; ok { - return fmt.Errorf("the stream %s has been exist", id.Value()) + if _, ok := s.callIds[id]; ok { + return false } - s.callIds[id.Value()] = stream - return nil + s.callIds[id] = stream + return true } -func (s *streamManager) Find(id string) *Stream { +func (s *streamManager) Find(id StreamID) *Stream { s.lock.RLock() defer s.lock.RUnlock() @@ -68,7 +63,7 @@ func (s *streamManager) FindWithCallId(id string) *Stream { return nil } -func (s *streamManager) Remove(id string) (*Stream, error) { +func (s *streamManager) Remove(id StreamID) *Stream { s.lock.Lock() defer s.lock.Unlock() @@ -77,24 +72,24 @@ func (s *streamManager) Remove(id string) (*Stream, error) { if ok && stream.DialogRequest != nil { callID, _ := stream.DialogRequest.CallID() delete(s.callIds, callID.Value()) - return stream, nil + return stream } - return nil, fmt.Errorf("stream with id %s was not find", id) + return nil } -func (s *streamManager) RemoveWithCallId(id string) (*Stream, error) { +func (s *streamManager) RemoveWithCallId(id string) *Stream { s.lock.Lock() defer s.lock.Unlock() stream, ok := s.callIds[id] if ok { delete(s.callIds, id) - delete(s.streams, stream.Id) - return stream, nil + delete(s.streams, stream.ID) + return stream } - return nil, fmt.Errorf("stream with id %s was not find", id) + return nil } func (s *streamManager) PopAll() []*Stream { @@ -106,7 +101,14 @@ func (s *streamManager) PopAll() []*Stream { streams = append(streams, stream) } - s.streams = make(map[string]*Stream) + s.streams = make(map[StreamID]*Stream) s.callIds = make(map[string]*Stream) return streams } + +func NewStreamManager() *streamManager { + return &streamManager{ + streams: make(map[StreamID]*Stream, 64), + callIds: make(map[string]*Stream, 64), + } +} diff --git a/subscribe.go b/subscribe.go new file mode 100644 index 0000000..d1624cf --- /dev/null +++ b/subscribe.go @@ -0,0 +1,13 @@ +package main + +import "github.com/ghettovoice/gosip/sip" + +type GBSubscribe struct { + PositionDialog sip.Request + CatalogDialog sip.Request + AlarmDialog sip.Request +} + +func RefreshSubscribe(expires int) { + +} diff --git a/util.go b/util.go index 308d790..9f702de 100644 --- a/util.go +++ b/util.go @@ -49,7 +49,7 @@ func GenerateTag() string { return hex.EncodeToString(randomBytes(6)) } -// Generate a SIP 2.0 Via branch ID. This is probably not suitable for use by +// Generate a SIP 2.0 Via branch Username. This is probably not suitable for use by // stateless proxies. func GenerateBranch() string { return "z9hG4bK-" + GenerateTag() @@ -67,7 +67,7 @@ func GenerateCallID() string { return uuid4 } -// Generate a random ID for an SDP. +// Generate a random Username for an SDP. func GenerateOriginID() string { return strconv.FormatUint(uint64(rand.Uint32()), 10) } diff --git a/xml.go b/xml.go index ba131a0..6f0621e 100644 --- a/xml.go +++ b/xml.go @@ -1,40 +1,94 @@ package main -import ( - "bytes" - "encoding/xml" - "golang.org/x/net/html/charset" - "golang.org/x/text/encoding/simplifiedchinese" - "golang.org/x/text/transform" - "io" -) +import "encoding/xml" -func GbkToUtf8(s []byte) ([]byte, error) { - reader := transform.NewReader(bytes.NewReader(s), simplifiedchinese.GBK.NewDecoder()) - return io.ReadAll(reader) +type Channel struct { + DeviceID string `xml:"DeviceID"` + Name string `xml:"Name,omitempty"` + Manufacturer string `xml:"Manufacturer,omitempty"` + Model string `xml:"Model,omitempty"` + Owner string `xml:"Owner,omitempty"` + CivilCode string `xml:"CivilCode,omitempty"` + Block string `xml:"Block,omitempty"` + Address string `xml:"Address,omitempty"` + Parental string `xml:"Parental,omitempty"` + ParentID string `xml:"ParentID,omitempty"` + SafetyWay string `xml:"SafetyWay,omitempty"` + RegisterWay string `xml:"RegisterWay,omitempty"` + CertNum string `xml:"CertNum,omitempty"` + Certifiable string `xml:"Certifiable,omitempty"` + ErrCode string `xml:"ErrCode,omitempty"` + EndTime string `xml:"EndTime,omitempty"` + Secrecy string `xml:"Secrecy,omitempty"` + IPAddress string `xml:"IPAddress,omitempty"` + Port string `xml:"Port,omitempty"` + Password string `xml:"Password,omitempty"` + Status string `xml:"Status,omitempty"` + Longitude string `xml:"Longitude,omitempty"` + Latitude string `xml:"Latitude,omitempty"` } -func DoDecodeXML(data []byte, message interface{}) error { - decoder := xml.NewDecoder(bytes.NewReader(data)) - decoder.CharsetReader = func(c string, i io.Reader) (io.Reader, error) { - return charset.NewReaderLabel(c, i) - } - - return decoder.Decode(message) +type BaseMessageGetter interface { + GetDeviceID() string + GetCmdType() string + GetSN() int } -func DecodeXML(data []byte, message interface{}) error { - //uft8Data := []byte(strings.Replace(string(data), "GB2312", "UTF-8", 1)) - uft8Data := data - err := DoDecodeXML(uft8Data, message) - if err != nil { - uft8Data, err = GbkToUtf8(uft8Data) - if err != nil { - return err - } - - err = DoDecodeXML(uft8Data, message) - } - - return err +type BaseMessage struct { + CmdType string `xml:"CmdType"` + SN int `xml:"SN"` + DeviceID string `xml:"DeviceID"` +} + +func (b BaseMessage) GetDeviceID() string { + return b.DeviceID +} + +func (b BaseMessage) GetCmdType() string { + return b.CmdType +} + +func (b BaseMessage) GetSN() int { + return b.SN +} + +type DeviceList struct { + Num int `xml:"Num,attr"` + Devices []*Channel `xml:"Item"` +} + +type ExtendedInfo struct { + Info string `xml:"Info,omitempty"` +} + +type BaseResponse struct { + XMLName xml.Name `xml:"Response"` + BaseMessage + Result string `xml:"Result,omitempty"` + ExtendedInfo +} + +type CatalogResponse struct { + BaseResponse + SumNum int `xml:"SumNum"` + DeviceList DeviceList `xml:"DeviceList"` +} + +type DeviceInfoResponse struct { + BaseResponse + DeviceName string `xml:"DeviceName,omitempty"` + Manufacturer string `xml:"Manufacturer,omitempty"` + Model string `xml:"Model,omitempty"` + Firmware string `xml:"Firmware,omitempty"` + Channel string `xml:"Channel,omitempty"` //通道数 +} + +type DeviceStatusResponse struct { + BaseResponse + Online string `xml:"Online"` //ONLINE/OFFLINE + Status string `xml:"Status"` //OK/ERROR + Reason string `xml:"Reason"` //OK/ERROR + Encode string `xml:"Encode"` //ON/OFF + Record string `xml:"Record"` //ON/OFF + DeviceTime string `xml:"DeviceTime"` } diff --git a/record.go b/xml_record.go similarity index 82% rename from record.go rename to xml_record.go index 67fc3d8..15ef228 100644 --- a/record.go +++ b/xml_record.go @@ -60,23 +60,9 @@ type RecordInfo struct { ShutdownTime string `xml:"ShutdownTime" json:"shutdownTime"` } -func (d *DBDevice) DoQueryRecordList(channelId, startTime, endTime string, sn int, type_ string) error { +func (d *Device) DoQueryRecordList(channelId, startTime, endTime string, sn int, type_ string) error { body := fmt.Sprintf(QueryRecordFormat, sn, channelId, startTime, endTime, type_) - msg, err := d.BuildMessageRequest(channelId, body) - if err != nil { - return err - } - - SipUA.SendRequest(msg) + request := d.BuildMessageRequest(channelId, body) + SipUA.SendRequest(request) return nil } - -func (d *DBDevice) OnRecord(response *QueryRecordInfoResponse) { - event := SNManager.FindEvent(response.SN) - if event == nil { - Sugar.Errorf("处理录像查询响应失败 SN:%d", response.SN) - return - } - - event(response) -} diff --git a/xml_test.go b/xml_test.go new file mode 100644 index 0000000..4d0a1a6 --- /dev/null +++ b/xml_test.go @@ -0,0 +1,17 @@ +package main + +import ( + "encoding/hex" + "testing" +) + +func TestDecodeXML(t *testing.T) { + //str := "3c3f786d6c2076657273696f6e3d22312e30223f3e0d0a3c51756572793e0d0a3c436d64547970653e446576696365496e666f3c2f436d64547970653e0d0a3c534e3e323c2f534e3e0d0a3c44657669636549443e33343032303030303030313332303030303030313c2f44657669636549443e0d0a3c2f51756572793e0d0a" + str := "3c3f786d6c2076657273696f6e3d22312e302220656e636f64696e673d22474232333132223f3e0d0a3c526573706f6e73653e0d0a3c436d64547970653e436174616c6f673c2f436d64547970653e0d0a3c534e3e313c2f534e3e0d0a3c44657669636549443e33343032303030303030313332303030303030313c2f44657669636549443e0d0a3c53756d4e756d3e313c2f53756d4e756d3e0d0a3c4465766963654c697374204e756d3d2231223e0d0a3c4974656d3e0d0a3c44657669636549443e33343032303030303030313331303030303030313c2f44657669636549443e0d0a3c4e616d653e47423238313831436c69656e743c2f4e616d653e0d0a3c4d616e7566616374757265723e48616958696e3c2f4d616e7566616374757265723e0d0a3c4d6f64656c3e474232383138315f416e64726f69643c2f4d6f64656c3e0d0a3c4f776e65723e4f776e65723c2f4f776e65723e0d0a3c416464726573733e416464726573733c2f416464726573733e0d0a3c506172656e74616c3e303c2f506172656e74616c3e0d0a3c506172656e7449443e33343032303030303030313332303030303030313c2f506172656e7449443e0d0a3c5361666574795761793e303c2f5361666574795761793e0d0a3c52656769737465725761793e313c2f52656769737465725761793e0d0a3c536563726563793e303c2f536563726563793e0d0a3c5374617475733e4f4e3c2f5374617475733e0d0a3c2f4974656d3e0d0a3c2f4465766963654c6973743e0d0a3c2f526573706f6e73653e0d0a" + data, err := hex.DecodeString(str) + + response := CatalogResponse{} + if err = DecodeXML(data, &response); err != nil { + panic(err) + } +} diff --git a/xml_util.go b/xml_util.go new file mode 100644 index 0000000..bd21485 --- /dev/null +++ b/xml_util.go @@ -0,0 +1,69 @@ +package main + +import ( + "bufio" + "bytes" + "encoding/xml" + "golang.org/x/net/html/charset" + "golang.org/x/text/encoding/simplifiedchinese" + "golang.org/x/text/transform" + "io" + "strings" +) + +func GbkToUtf8(s []byte) ([]byte, error) { + reader := transform.NewReader(bytes.NewReader(s), simplifiedchinese.GBK.NewDecoder()) + return io.ReadAll(reader) +} + +func DoDecodeXML(data []byte, message interface{}) error { + decoder := xml.NewDecoder(bytes.NewReader(data)) + decoder.CharsetReader = func(c string, i io.Reader) (io.Reader, error) { + return charset.NewReaderLabel(c, i) + } + + return decoder.Decode(message) +} + +func DecodeXML(data []byte, message interface{}) error { + //uft8Data := []byte(strings.Replace(string(data), "GB2312", "UTF-8", 1)) + uft8Data := data + err := DoDecodeXML(uft8Data, message) + if err != nil { + uft8Data, err = GbkToUtf8(uft8Data) + if err != nil { + return err + } + + err = DoDecodeXML(uft8Data, message) + } + + return err +} + +func GetRootElementName(data string) string { + reader := strings.NewReader(data) + scanner := bufio.NewScanner(reader) + scanner.Split(bufio.ScanLines) + + for scanner.Scan() && scanner.Scan() { + line := scanner.Text() + if len(line) == 0 { + continue + } + + return line[1 : len(line)-1] + } + + return "" +} + +func GetCmdType(data string) string { + startIndex := strings.Index(data, CmdTagStart) + endIndex := strings.Index(data, CmdTagEnd) + if startIndex <= 0 || endIndex <= 0 || endIndex+len(CmdTagStart) <= startIndex { + return "" + } + + return data[startIndex+len(CmdTagStart) : endIndex] +}