diff --git a/api.go b/api.go index 352e7f1..141542b 100644 --- a/api.go +++ b/api.go @@ -40,8 +40,34 @@ func init() { } } +func withCheckParams(f func(sourceId string, w http.ResponseWriter, req *http.Request), suffix string) func(http.ResponseWriter, *http.Request) { + return func(w http.ResponseWriter, req *http.Request) { + source, err := stream.Path2SourceId(req.URL.Path, suffix) + if err != nil { + httpResponse(w, http.StatusBadRequest, err.Error()) + return + } + + f(source, w, req) + } +} + func startApiServer(addr string) { - apiServer.router.HandleFunc("/live/{source}", apiServer.filterLive) + /** + http://host:port/xxx.flv + http://host:port/xxx.rtc + http://host:port/xxx.m3u8 + http://host:port/xxx_0.ts + ws://host:port/xxx.flv + */ + apiServer.router.HandleFunc("/{source}.flv", withCheckParams(apiServer.onFlv, ".flv")) + apiServer.router.HandleFunc("/{source}/{stream}.flv", withCheckParams(apiServer.onFlv, ".flv")) + apiServer.router.HandleFunc("/{source}.m3u8", withCheckParams(apiServer.onHLS, ".m3u8")) + apiServer.router.HandleFunc("/{source}/{stream}.m3u8", withCheckParams(apiServer.onHLS, ".m3u8")) + apiServer.router.HandleFunc("/{source}.ts", withCheckParams(apiServer.onTS, ".ts")) + apiServer.router.HandleFunc("/{source}/{stream}.ts", withCheckParams(apiServer.onTS, ".ts")) + apiServer.router.HandleFunc("/{source}.rtc", withCheckParams(apiServer.onRtc, ".rtc")) + apiServer.router.HandleFunc("/{source}/{stream}.rtc", withCheckParams(apiServer.onRtc, ".rtc")) apiServer.router.HandleFunc("/v1/gb28181/source/create", apiServer.createGBSource) //TCP主动,设置连接地址 @@ -218,6 +244,15 @@ func (api *ApiServer) generateSinkId(remoteAddr string) stream.SinkId { return stream.GenerateSinkId(tcpAddr) } +func (api *ApiServer) generateSourceId(remoteAddr string) stream.SinkId { + tcpAddr, err := net.ResolveTCPAddr("tcp", remoteAddr) + if err != nil { + panic(err) + } + + return stream.GenerateSinkId(tcpAddr) +} + func (api *ApiServer) doPlay(sink stream.ISink) utils.HookState { ok := utils.HookStateOK stream.HookPlaying(sink, func() { @@ -229,52 +264,22 @@ func (api *ApiServer) doPlay(sink stream.ISink) utils.HookState { return ok } -func (api *ApiServer) filterLive(w http.ResponseWriter, r *http.Request) { - vars := mux.Vars(r) - source := vars["source"] - index := strings.LastIndex(source, ".") - if index < 0 || index == len(source)-1 { - log.Sugar.Errorf("bad request:%s. stream format must be passed at the end of the URL", r.URL.Path) - w.WriteHeader(http.StatusBadRequest) - return +func (api *ApiServer) onFlv(sourceId string, w http.ResponseWriter, r *http.Request) { + ws := true + if !("upgrade" == strings.ToLower(r.Header.Get("Connection"))) { + ws = false + } else if !("websocket" == strings.ToLower(r.Header.Get("Upgrade"))) { + ws = false + } else if !("13" == r.Header.Get("Sec-Websocket-Version")) { + ws = false } - sourceId := source[:index] - format := source[index+1:] - - /** - http://host:port/xxx.flv - http://host:port/xxx.rtc - http://host:port/xxx.m3u8 - http://host:port/xxx_0.ts - ws://host:port/xxx.flv - */ - if "flv" == format { - //判断是否是websocket请求 - ws := true - if !("upgrade" == strings.ToLower(r.Header.Get("Connection"))) { - ws = false - } else if !("websocket" == strings.ToLower(r.Header.Get("Upgrade"))) { - ws = false - } else if !("13" == r.Header.Get("Sec-Websocket-Version")) { - ws = false - } - - if ws { - api.onWSFlv(sourceId, w, r) - } else { - api.onFLV(sourceId, w, r) - } - - } else if "m3u8" == format { - api.onHLS(sourceId, w, r) - } else if "ts" == format { - api.onTS(sourceId, w, r) - } else if "rtc" == format { - api.onRtc(sourceId, w, r) + if ws { + apiServer.onWSFlv(sourceId, w, r) + } else { + apiServer.onHttpFLV(sourceId, w, r) } } - func (api *ApiServer) onWSFlv(sourceId string, w http.ResponseWriter, r *http.Request) { conn, err := api.upgrader.Upgrade(w, r, nil) if err != nil { @@ -304,7 +309,7 @@ func (api *ApiServer) onWSFlv(sourceId string, w http.ResponseWriter, r *http.Re } } -func (api *ApiServer) onFLV(sourceId string, w http.ResponseWriter, r *http.Request) { +func (api *ApiServer) onHttpFLV(sourceId string, w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "video/x-flv") w.Header().Set("Connection", "Keep-Alive") w.Header().Set("Transfer-Encoding", "chunked") diff --git a/rtmp/rtmp_session.go b/rtmp/rtmp_session.go index 842dd0a..9450708 100644 --- a/rtmp/rtmp_session.go +++ b/rtmp/rtmp_session.go @@ -34,10 +34,20 @@ type sessionImpl struct { conn net.Conn } +func (s *sessionImpl) generateSourceId(app, stream_ string) string { + if len(app) == 0 { + return stream_ + } else if len(stream_) == 0 { + return app + } else { + return app + "/" + stream_ + } +} + func (s *sessionImpl) OnPublish(app, stream_ string, response chan utils.HookState) { log.Sugar.Infof("rtmp onpublish app:%s stream:%s conn:%s", app, stream_, s.conn.RemoteAddr().String()) - sourceId := app + "_" + stream_ + sourceId := s.generateSourceId(app, stream_) source := NewPublisher(sourceId, s.stack, s.conn) //设置推流的音视频回调 s.stack.SetOnPublishHandler(source) @@ -57,7 +67,7 @@ func (s *sessionImpl) OnPublish(app, stream_ string, response chan utils.HookSta } func (s *sessionImpl) OnPlay(app, stream_ string, response chan utils.HookState) { - sourceId := app + "_" + stream_ + sourceId := s.generateSourceId(app, stream_) //拉流事件Sink统一处理 sink := NewSink(stream.GenerateSinkId(s.conn.RemoteAddr()), sourceId, s.conn) diff --git a/rtsp/rtsp_handler.go b/rtsp/rtsp_handler.go index ff0e77e..bfc68b7 100644 --- a/rtsp/rtsp_handler.go +++ b/rtsp/rtsp_handler.go @@ -70,9 +70,14 @@ func (h handler) Process(session *session, method string, url_ *url.URL, headers return fmt.Errorf("please establish a session first") } - var err error - split := strings.Split(url_.Path, "/") - source := split[len(split)-1] + source := strings.TrimSpace(url_.Path) + if strings.HasPrefix(source, "/") { + source = source[1:] + } + + if len(strings.TrimSpace(source)) == 0 { + return fmt.Errorf("the request source cannot be empty") + } //反射调用各个处理函数 results := m.Call([]reflect.Value{ @@ -80,7 +85,7 @@ func (h handler) Process(session *session, method string, url_ *url.URL, headers reflect.ValueOf(Request{session, source, method, url_, headers}), }) - err, _ = results[2].Interface().(error) + err, _ := results[2].Interface().(error) if err != nil { return err } diff --git a/rtsp/rtsp_session.go b/rtsp/rtsp_session.go index 8f2fb14..0da2246 100644 --- a/rtsp/rtsp_session.go +++ b/rtsp/rtsp_session.go @@ -97,7 +97,7 @@ func parseMessage(data []byte) (string, *url.URL, textproto.MIMEHeader, error) { line, err := tp.ReadLine() split := strings.Split(line, " ") if len(split) < 3 { - panic(fmt.Errorf("unknow response line of response:%s", line)) + panic(fmt.Errorf("wrong request line %s", line)) } method := strings.ToUpper(split[0]) @@ -109,6 +109,15 @@ func parseMessage(data []byte) (string, *url.URL, textproto.MIMEHeader, error) { return "", nil, nil, err } + path := strings.TrimSpace(url_.Path) + if strings.HasPrefix(path, "/") { + path = path[1:] + } + + if len(strings.TrimSpace(path)) == 0 { + return "", nil, nil, fmt.Errorf("the request source cannot be empty") + } + header, err := tp.ReadMIMEHeader() if err != nil { return "", nil, nil, err diff --git a/stream/source.go b/stream/source.go index 7adee40..4670cf1 100644 --- a/stream/source.go +++ b/stream/source.go @@ -52,34 +52,6 @@ const ( SessionStateClose = SessionState(7) //关闭状态 ) -func (s SourceType) ToString() string { - if SourceTypeRtmp == s { - return "rtmp" - } else if SourceType28181 == s { - return "28181" - } else if SourceType1078 == s { - return "jt1078" - } - - panic(fmt.Sprintf("unknown source type %d", s)) -} - -func (p Protocol) ToString() string { - if ProtocolRtmp == p { - return "rtmp" - } else if ProtocolFlv == p { - return "flv" - } else if ProtocolRtsp == p { - return "rtsp" - } else if ProtocolHls == p { - return "hls" - } else if ProtocolRtc == p { - return "rtc" - } - - panic(fmt.Sprintf("unknown stream protocol %d", p)) -} - // ISource 父类Source负责, 除解析流以外的所有事情 type ISource interface { // Id Source的唯一ID/** @@ -140,8 +112,6 @@ type ISource interface { Init(input func(data []byte) error) } -var TranscoderFactory func(src utils.AVStream, dst utils.AVStream) transcode.ITranscoder - type SourceImpl struct { hookSessionImpl diff --git a/stream/source_utils.go b/stream/source_utils.go new file mode 100644 index 0000000..cf531b1 --- /dev/null +++ b/stream/source_utils.go @@ -0,0 +1,53 @@ +package stream + +import ( + "fmt" + "strings" +) + +func (s SourceType) ToString() string { + if SourceTypeRtmp == s { + return "rtmp" + } else if SourceType28181 == s { + return "28181" + } else if SourceType1078 == s { + return "jt1078" + } + + panic(fmt.Sprintf("unknown source type %d", s)) +} + +func (p Protocol) ToString() string { + if ProtocolRtmp == p { + return "rtmp" + } else if ProtocolFlv == p { + return "flv" + } else if ProtocolRtsp == p { + return "rtsp" + } else if ProtocolHls == p { + return "hls" + } else if ProtocolRtc == p { + return "rtc" + } + + panic(fmt.Sprintf("unknown stream protocol %d", p)) +} + +func Path2SourceId(path string, suffix string) (string, error) { + source := strings.TrimSpace(path) + if strings.HasPrefix(source, "/") { + source = source[1:] + } + + if len(suffix) > 0 && strings.HasSuffix(source, suffix) { + source = source[:len(source)-len(suffix)] + } + + source = strings.TrimSpace(source) + + if len(strings.TrimSpace(source)) == 0 { + return "", fmt.Errorf("the request source cannot be empty") + } + + return source, nil +}