diff --git a/api.go b/api.go index b5467c6..a20f610 100644 --- a/api.go +++ b/api.go @@ -19,71 +19,34 @@ import ( "time" ) -var upgrader *websocket.Upgrader +type ApiServer struct { + upgrader *websocket.Upgrader + router *mux.Router +} + +var apiServer *ApiServer func init() { - upgrader = &websocket.Upgrader{ - CheckOrigin: func(r *http.Request) bool { - return true + apiServer = &ApiServer{ + upgrader: &websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { + return true + }, }, + + router: mux.NewRouter(), } } func startApiServer(addr string) { - r := mux.NewRouter() - /** - http://host:port/xxx.flv - http://host:port/xxx.rtc - http://host:port/xxx.m3u8 - http://host:port/xxx_0.ts - ws://host:port/xxx.flv - */ - r.HandleFunc("/live/{source}", func(w http.ResponseWriter, r *http.Request) { - vars := mux.Vars(r) - source := vars["source"] - index := strings.LastIndex(source, ".") - if index < 0 || index == len(source)-1 { - log.Sugar.Errorf("bad request:%s. stream format must be passed at the end of the URL", r.URL.Path) - w.WriteHeader(http.StatusBadRequest) - return - } - - sourceId := source[:index] - format := source[index+1:] - - if "flv" == format { - //判断是否是websocket请求 - ws := true - if !("upgrade" == strings.ToLower(r.Header.Get("Connection"))) { - ws = false - } else if !("websocket" == strings.ToLower(r.Header.Get("Upgrade"))) { - ws = false - } else if !("13" == r.Header.Get("Sec-Websocket-Version")) { - ws = false - } - - if ws { - onWSFlv(sourceId, w, r) - } else { - onFLV(sourceId, w, r) - } - - } else if "m3u8" == format { - onHLS(sourceId, w, r) - } else if "ts" == format { - onTS(sourceId, w, r) - } else if "rtc" == format { - onRtc(sourceId, w, r) - } - }) - - r.HandleFunc("/rtc.html", func(writer http.ResponseWriter, request *http.Request) { + apiServer.router.HandleFunc("/live/{source}", apiServer.filterLive) + apiServer.router.HandleFunc("/rtc.html", func(writer http.ResponseWriter, request *http.Request) { http.ServeFile(writer, request, "./rtc.html") }) - http.Handle("/", r) + http.Handle("/", apiServer.router) srv := &http.Server{ - Handler: r, + Handler: apiServer.router, Addr: addr, // Good practice: enforce timeouts for servers you create! WriteTimeout: 30 * time.Second, @@ -97,27 +60,89 @@ func startApiServer(addr string) { } } -func onWSFlv(sourceId string, w http.ResponseWriter, r *http.Request) { - conn, err := upgrader.Upgrade(w, r, nil) +func (api *ApiServer) generateSinkId(remoteAddr string) stream.SinkId { + tcpAddr, err := net.ResolveTCPAddr("tcp", remoteAddr) + if err != nil { + panic(err) + } + + return stream.GenerateSinkId(tcpAddr) +} + +func (api *ApiServer) doPlay(sink stream.ISink) utils.HookState { + ok := utils.HookStateOK + sink.Play(func() { + + }, func(state utils.HookState) { + ok = state + }) + + return ok +} + +func (api *ApiServer) filterLive(w http.ResponseWriter, r *http.Request) { + vars := mux.Vars(r) + source := vars["source"] + index := strings.LastIndex(source, ".") + if index < 0 || index == len(source)-1 { + log.Sugar.Errorf("bad request:%s. stream format must be passed at the end of the URL", r.URL.Path) + w.WriteHeader(http.StatusBadRequest) + return + } + + sourceId := source[:index] + format := source[index+1:] + + /** + http://host:port/xxx.flv + http://host:port/xxx.rtc + http://host:port/xxx.m3u8 + http://host:port/xxx_0.ts + ws://host:port/xxx.flv + */ + if "flv" == format { + //判断是否是websocket请求 + ws := true + if !("upgrade" == strings.ToLower(r.Header.Get("Connection"))) { + ws = false + } else if !("websocket" == strings.ToLower(r.Header.Get("Upgrade"))) { + ws = false + } else if !("13" == r.Header.Get("Sec-Websocket-Version")) { + ws = false + } + + if ws { + api.onWSFlv(sourceId, w, r) + } else { + api.onFLV(sourceId, w, r) + } + + } else if "m3u8" == format { + api.onHLS(sourceId, w, r) + } else if "ts" == format { + api.onTS(sourceId, w, r) + } else if "rtc" == format { + api.onRtc(sourceId, w, r) + } +} + +func (api *ApiServer) onWSFlv(sourceId string, w http.ResponseWriter, r *http.Request) { + conn, err := api.upgrader.Upgrade(w, r, nil) if err != nil { log.Sugar.Errorf("websocket头检查失败 err:%s", err.Error()) w.WriteHeader(http.StatusBadRequest) return } - tcpAddr, _ := net.ResolveTCPAddr("tcp", r.RemoteAddr) - sinkId := stream.GenerateSinkId(tcpAddr) - sink := flv.NewFLVSink(sinkId, sourceId, flv.NewWSConn(conn)) - + sink := flv.NewFLVSink(api.generateSinkId(r.RemoteAddr), sourceId, flv.NewWSConn(conn)) log.Sugar.Infof("ws-flv 连接 sink:%s", sink.PrintInfo()) - sink.(*stream.SinkImpl).Play(sink, func() { - - }, func(state utils.HookState) { + state := api.doPlay(sink) + if utils.HookStateOK != state { + log.Sugar.Warnf("ws-flv 播放失败 state:%d sink:%s", state, sink.PrintInfo()) w.WriteHeader(http.StatusForbidden) - - conn.Close() - }) + return + } netConn := conn.NetConn() bytes := make([]byte, 64) @@ -130,7 +155,7 @@ func onWSFlv(sourceId string, w http.ResponseWriter, r *http.Request) { } } -func onFLV(sourceId string, w http.ResponseWriter, r *http.Request) { +func (api *ApiServer) onFLV(sourceId string, w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "video/x-flv") w.Header().Set("Connection", "Keep-Alive") w.Header().Set("Transfer-Encoding", "chunked") @@ -148,18 +173,16 @@ func onFLV(sourceId string, w http.ResponseWriter, r *http.Request) { return } - tcpAddr, _ := net.ResolveTCPAddr("tcp", r.RemoteAddr) - sinkId := stream.GenerateSinkId(tcpAddr) - sink := flv.NewFLVSink(sinkId, sourceId, conn) - + sink := flv.NewFLVSink(api.generateSinkId(r.RemoteAddr), sourceId, conn) log.Sugar.Infof("http-flv 连接 sink:%s", sink.PrintInfo()) - sink.(*stream.SinkImpl).Play(sink, func() { - }, func(state utils.HookState) { + state := api.doPlay(sink) + if utils.HookStateOK != state { + log.Sugar.Warnf("http-flv 播放失败 state:%d sink:%s", state, sink.PrintInfo()) + w.WriteHeader(http.StatusForbidden) - - conn.Close() - }) + return + } bytes := make([]byte, 64) for { @@ -170,7 +193,7 @@ func onFLV(sourceId string, w http.ResponseWriter, r *http.Request) { } } -func onTS(source string, w http.ResponseWriter, r *http.Request) { +func (api *ApiServer) onTS(source string, w http.ResponseWriter, r *http.Request) { if !stream.AppConfig.Hls.Enable { log.Sugar.Warnf("处理m3u8请求失败 server未开启hls request:%s", r.URL.Path) http.Error(w, "hls disable", http.StatusInternalServerError) @@ -196,17 +219,16 @@ func onTS(source string, w http.ResponseWriter, r *http.Request) { http.ServeFile(w, r, tsPath) } -func onHLS(sourceId string, w http.ResponseWriter, r *http.Request) { +func (api *ApiServer) onHLS(sourceId string, w http.ResponseWriter, r *http.Request) { if !stream.AppConfig.Hls.Enable { - log.Sugar.Warnf("处理hls请求失败 server未开启hls request:%s", r.URL.Path) + log.Sugar.Warnf("处理hls请求失败 server未开启hls") http.Error(w, "hls disable", http.StatusInternalServerError) return } w.Header().Set("Content-Type", "application/vnd.apple.mpegurl") //m3u8和ts会一直刷新, 每个请求只hook一次. - tcpAddr, _ := net.ResolveTCPAddr("tcp", r.RemoteAddr) - sinkId := stream.GenerateSinkId(tcpAddr) + sinkId := api.generateSinkId(r.RemoteAddr) //hook成功后, 如果还没有m3u8文件,等生成m3u8文件 //后续直接返回当前m3u8文件 @@ -221,31 +243,26 @@ func onHLS(sourceId string, w http.ResponseWriter, r *http.Request) { done <- 0 }) - hookState := utils.HookStateOK - sink.Play(sink, func() { - err := stream.SinkManager.Add(sink) + state := api.doPlay(sink) + if utils.HookStateOK != state { + log.Sugar.Warnf("m3u8 请求失败 state:%d sink:%s", state, sink.PrintInfo()) - utils.Assert(err == nil) - }, func(state utils.HookState) { - log.Sugar.Warnf("hook播放事件失败 request:%s", r.URL.Path) - hookState = state w.WriteHeader(http.StatusForbidden) - }) - - if utils.HookStateOK != hookState { return + } else { + err := stream.SinkManager.Add(sink) + utils.Assert(err == nil) } select { case <-done: case <-context.Done(): - log.Sugar.Infof("http m3u8连接断开") break } } } -func onRtc(sourceId string, w http.ResponseWriter, r *http.Request) { +func (api *ApiServer) onRtc(sourceId string, w http.ResponseWriter, r *http.Request) { v := struct { Type string `json:"type"` SDP string `json:"sdp"` @@ -253,19 +270,20 @@ func onRtc(sourceId string, w http.ResponseWriter, r *http.Request) { data, err := io.ReadAll(r.Body) if err != nil { - panic(err) - } + log.Sugar.Errorf("rtc请求错误 err:%s remote:%s", err.Error(), r.RemoteAddr) - if err := json.Unmarshal(data, &v); err != nil { - panic(err) - } + http.Error(w, err.Error(), http.StatusBadRequest) + return + } else if err := json.Unmarshal(data, &v); err != nil { + log.Sugar.Errorf("rtc请求错误 err:%s remote:%s", err.Error(), r.RemoteAddr) - tcpAddr, _ := net.ResolveTCPAddr("tcp", r.RemoteAddr) - sinkId := stream.GenerateSinkId(tcpAddr) + http.Error(w, err.Error(), http.StatusBadRequest) + return + } group := sync.WaitGroup{} group.Add(1) - sink := rtc.NewSink(sinkId, sourceId, v.SDP, func(sdp string) { + sink := rtc.NewSink(api.generateSinkId(r.RemoteAddr), sourceId, v.SDP, func(sdp string) { response := struct { Type string `json:"type"` SDP string `json:"sdp"` @@ -285,13 +303,15 @@ func onRtc(sourceId string, w http.ResponseWriter, r *http.Request) { group.Done() }) - sink.Play(sink, func() { + log.Sugar.Infof("rtc 请求 sink:%s sdp:%v", sink.PrintInfo(), v.SDP) + + state := api.doPlay(sink) + if utils.HookStateOK != state { + log.Sugar.Warnf("rtc 播放失败 state:%d sink:%s", state, sink.PrintInfo()) - }, func(state utils.HookState) { w.WriteHeader(http.StatusForbidden) - group.Done() - }) + } group.Wait() } diff --git a/rtmp/rtmp_session.go b/rtmp/rtmp_session.go index 5c3d689..005ff06 100644 --- a/rtmp/rtmp_session.go +++ b/rtmp/rtmp_session.go @@ -61,7 +61,7 @@ func (s *sessionImpl) OnPlay(app, stream_ string, response chan utils.HookState) log.Sugar.Infof("rtmp onplay app:%s stream:%s sink:%v conn:%s", app, stream_, sink.Id(), s.conn.RemoteAddr().String()) - sink.(*stream.SinkImpl).Play(sink, func() { + sink.Play(func() { s.handle = sink response <- utils.HookStateOK }, func(state utils.HookState) { diff --git a/rtsp/rtsp_session.go b/rtsp/rtsp_session.go index fe0da0d..8319218 100644 --- a/rtsp/rtsp_session.go +++ b/rtsp/rtsp_session.go @@ -163,7 +163,7 @@ func (s *session) onDescribe(source string, headers textproto.MIMEHeader) error code := utils.HookStateOK s.sink_ = sink_.(*sink) - sink_.(*sink).Play(sink_, func() { + sink_.Play(func() { }, func(state utils.HookState) { code = state diff --git a/stream/hook.go b/stream/hook.go index 62916ee..d551e04 100644 --- a/stream/hook.go +++ b/stream/hook.go @@ -38,9 +38,9 @@ func NewPublishHookEventInfo(stream, remoteAddr string, protocol SourceType) eve } type HookHandler interface { - Play(sink ISink, success func(), failure func(state utils.HookState)) + Play(success func(), failure func(state utils.HookState)) - PlayDone(sink ISink, success func(), failure func(state utils.HookState)) + PlayDone(success func(), failure func(state utils.HookState)) } type HookSession interface { diff --git a/stream/sink.go b/stream/sink.go index 28065e0..54c16f5 100644 --- a/stream/sink.go +++ b/stream/sink.go @@ -192,18 +192,18 @@ func (s *SinkImpl) PrintInfo() string { return fmt.Sprintf("%s-%v source:%s", s.ProtocolStr(), s.Id_, s.SourceId_) } -func (s *SinkImpl) Play(sink ISink, success func(), failure func(state utils.HookState)) { +func (s *SinkImpl) Play(success func(), failure func(state utils.HookState)) { f := func() { - source := SourceManager.Find(sink.SourceId()) + source := SourceManager.Find(s.SourceId()) if source == nil { - log.Sugar.Infof("添加sink到等待队列 sink:%s-%v source:%s", sink.ProtocolStr(), sink.Id(), sink.SourceId()) + log.Sugar.Infof("添加sink到等待队列 sink:%s-%v source:%s", s.ProtocolStr(), s.Id(), s.SourceId()) - sink.SetState(SessionStateWait) - AddSinkToWaitingQueue(sink.SourceId(), sink) + s.SetState(SessionStateWait) + AddSinkToWaitingQueue(s.SourceId(), s) } else { - log.Sugar.Debugf("发送播放事件 sink:%s-%v source:%s", sink.ProtocolStr(), sink.Id(), sink.SourceId()) + log.Sugar.Debugf("发送播放事件 sink:%s-%v source:%s", s.ProtocolStr(), s.Id(), s.SourceId()) - source.AddEvent(SourceEventPlay, sink) + source.AddEvent(SourceEventPlay, s) } } @@ -213,23 +213,23 @@ func (s *SinkImpl) Play(sink ISink, success func(), failure func(state utils.Hoo return } - err := s.Hook(HookEventPlay, NewPlayHookEventInfo(sink.SourceId(), "", sink.Protocol()), func(response *http.Response) { + err := s.Hook(HookEventPlay, NewPlayHookEventInfo(s.SourceId(), "", s.Protocol()), func(response *http.Response) { f() success() }, func(response *http.Response, err error) { - log.Sugar.Errorf("Hook播放事件响应失败 err:%s sink:%s-%v source:%s", err.Error(), sink.ProtocolStr(), sink.Id(), sink.SourceId()) + log.Sugar.Errorf("Hook播放事件响应失败 err:%s sink:%s-%v source:%s", err.Error(), s.ProtocolStr(), s.Id(), s.SourceId()) failure(utils.HookStateFailure) }) if err != nil { - log.Sugar.Errorf("Hook播放事件发送失败 err:%s sink:%s-%v source:%s", err.Error(), sink.ProtocolStr(), sink.Id(), sink.SourceId()) + log.Sugar.Errorf("Hook播放事件发送失败 err:%s sink:%s-%v source:%s", err.Error(), s.ProtocolStr(), s.Id(), s.SourceId()) failure(utils.HookStateFailure) return } } -func (s *SinkImpl) PlayDone(source ISink, success func(), failure func(state utils.HookState)) { +func (s *SinkImpl) PlayDone(success func(), failure func(state utils.HookState)) { }