完善http接口封装

This commit is contained in:
yangjiechina
2024-04-10 10:34:11 +08:00
parent 8dc824494e
commit 598311f21f
5 changed files with 140 additions and 120 deletions

226
api.go
View File

@@ -19,71 +19,34 @@ import (
"time" "time"
) )
var upgrader *websocket.Upgrader type ApiServer struct {
upgrader *websocket.Upgrader
router *mux.Router
}
var apiServer *ApiServer
func init() { func init() {
upgrader = &websocket.Upgrader{ apiServer = &ApiServer{
upgrader: &websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool { CheckOrigin: func(r *http.Request) bool {
return true return true
}, },
},
router: mux.NewRouter(),
} }
} }
func startApiServer(addr string) { func startApiServer(addr string) {
r := mux.NewRouter() apiServer.router.HandleFunc("/live/{source}", apiServer.filterLive)
/** apiServer.router.HandleFunc("/rtc.html", func(writer http.ResponseWriter, request *http.Request) {
http://host:port/xxx.flv
http://host:port/xxx.rtc
http://host:port/xxx.m3u8
http://host:port/xxx_0.ts
ws://host:port/xxx.flv
*/
r.HandleFunc("/live/{source}", func(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
source := vars["source"]
index := strings.LastIndex(source, ".")
if index < 0 || index == len(source)-1 {
log.Sugar.Errorf("bad request:%s. stream format must be passed at the end of the URL", r.URL.Path)
w.WriteHeader(http.StatusBadRequest)
return
}
sourceId := source[:index]
format := source[index+1:]
if "flv" == format {
//判断是否是websocket请求
ws := true
if !("upgrade" == strings.ToLower(r.Header.Get("Connection"))) {
ws = false
} else if !("websocket" == strings.ToLower(r.Header.Get("Upgrade"))) {
ws = false
} else if !("13" == r.Header.Get("Sec-Websocket-Version")) {
ws = false
}
if ws {
onWSFlv(sourceId, w, r)
} else {
onFLV(sourceId, w, r)
}
} else if "m3u8" == format {
onHLS(sourceId, w, r)
} else if "ts" == format {
onTS(sourceId, w, r)
} else if "rtc" == format {
onRtc(sourceId, w, r)
}
})
r.HandleFunc("/rtc.html", func(writer http.ResponseWriter, request *http.Request) {
http.ServeFile(writer, request, "./rtc.html") http.ServeFile(writer, request, "./rtc.html")
}) })
http.Handle("/", r) http.Handle("/", apiServer.router)
srv := &http.Server{ srv := &http.Server{
Handler: r, Handler: apiServer.router,
Addr: addr, Addr: addr,
// Good practice: enforce timeouts for servers you create! // Good practice: enforce timeouts for servers you create!
WriteTimeout: 30 * time.Second, WriteTimeout: 30 * time.Second,
@@ -97,27 +60,89 @@ func startApiServer(addr string) {
} }
} }
func onWSFlv(sourceId string, w http.ResponseWriter, r *http.Request) { func (api *ApiServer) generateSinkId(remoteAddr string) stream.SinkId {
conn, err := upgrader.Upgrade(w, r, nil) tcpAddr, err := net.ResolveTCPAddr("tcp", remoteAddr)
if err != nil {
panic(err)
}
return stream.GenerateSinkId(tcpAddr)
}
func (api *ApiServer) doPlay(sink stream.ISink) utils.HookState {
ok := utils.HookStateOK
sink.Play(func() {
}, func(state utils.HookState) {
ok = state
})
return ok
}
func (api *ApiServer) filterLive(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
source := vars["source"]
index := strings.LastIndex(source, ".")
if index < 0 || index == len(source)-1 {
log.Sugar.Errorf("bad request:%s. stream format must be passed at the end of the URL", r.URL.Path)
w.WriteHeader(http.StatusBadRequest)
return
}
sourceId := source[:index]
format := source[index+1:]
/**
http://host:port/xxx.flv
http://host:port/xxx.rtc
http://host:port/xxx.m3u8
http://host:port/xxx_0.ts
ws://host:port/xxx.flv
*/
if "flv" == format {
//判断是否是websocket请求
ws := true
if !("upgrade" == strings.ToLower(r.Header.Get("Connection"))) {
ws = false
} else if !("websocket" == strings.ToLower(r.Header.Get("Upgrade"))) {
ws = false
} else if !("13" == r.Header.Get("Sec-Websocket-Version")) {
ws = false
}
if ws {
api.onWSFlv(sourceId, w, r)
} else {
api.onFLV(sourceId, w, r)
}
} else if "m3u8" == format {
api.onHLS(sourceId, w, r)
} else if "ts" == format {
api.onTS(sourceId, w, r)
} else if "rtc" == format {
api.onRtc(sourceId, w, r)
}
}
func (api *ApiServer) onWSFlv(sourceId string, w http.ResponseWriter, r *http.Request) {
conn, err := api.upgrader.Upgrade(w, r, nil)
if err != nil { if err != nil {
log.Sugar.Errorf("websocket头检查失败 err:%s", err.Error()) log.Sugar.Errorf("websocket头检查失败 err:%s", err.Error())
w.WriteHeader(http.StatusBadRequest) w.WriteHeader(http.StatusBadRequest)
return return
} }
tcpAddr, _ := net.ResolveTCPAddr("tcp", r.RemoteAddr) sink := flv.NewFLVSink(api.generateSinkId(r.RemoteAddr), sourceId, flv.NewWSConn(conn))
sinkId := stream.GenerateSinkId(tcpAddr)
sink := flv.NewFLVSink(sinkId, sourceId, flv.NewWSConn(conn))
log.Sugar.Infof("ws-flv 连接 sink:%s", sink.PrintInfo()) log.Sugar.Infof("ws-flv 连接 sink:%s", sink.PrintInfo())
sink.(*stream.SinkImpl).Play(sink, func() { state := api.doPlay(sink)
if utils.HookStateOK != state {
}, func(state utils.HookState) { log.Sugar.Warnf("ws-flv 播放失败 state:%d sink:%s", state, sink.PrintInfo())
w.WriteHeader(http.StatusForbidden) w.WriteHeader(http.StatusForbidden)
return
conn.Close() }
})
netConn := conn.NetConn() netConn := conn.NetConn()
bytes := make([]byte, 64) bytes := make([]byte, 64)
@@ -130,7 +155,7 @@ func onWSFlv(sourceId string, w http.ResponseWriter, r *http.Request) {
} }
} }
func onFLV(sourceId string, w http.ResponseWriter, r *http.Request) { func (api *ApiServer) onFLV(sourceId string, w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "video/x-flv") w.Header().Set("Content-Type", "video/x-flv")
w.Header().Set("Connection", "Keep-Alive") w.Header().Set("Connection", "Keep-Alive")
w.Header().Set("Transfer-Encoding", "chunked") w.Header().Set("Transfer-Encoding", "chunked")
@@ -148,18 +173,16 @@ func onFLV(sourceId string, w http.ResponseWriter, r *http.Request) {
return return
} }
tcpAddr, _ := net.ResolveTCPAddr("tcp", r.RemoteAddr) sink := flv.NewFLVSink(api.generateSinkId(r.RemoteAddr), sourceId, conn)
sinkId := stream.GenerateSinkId(tcpAddr)
sink := flv.NewFLVSink(sinkId, sourceId, conn)
log.Sugar.Infof("http-flv 连接 sink:%s", sink.PrintInfo()) log.Sugar.Infof("http-flv 连接 sink:%s", sink.PrintInfo())
sink.(*stream.SinkImpl).Play(sink, func() {
}, func(state utils.HookState) { state := api.doPlay(sink)
if utils.HookStateOK != state {
log.Sugar.Warnf("http-flv 播放失败 state:%d sink:%s", state, sink.PrintInfo())
w.WriteHeader(http.StatusForbidden) w.WriteHeader(http.StatusForbidden)
return
conn.Close() }
})
bytes := make([]byte, 64) bytes := make([]byte, 64)
for { for {
@@ -170,7 +193,7 @@ func onFLV(sourceId string, w http.ResponseWriter, r *http.Request) {
} }
} }
func onTS(source string, w http.ResponseWriter, r *http.Request) { func (api *ApiServer) onTS(source string, w http.ResponseWriter, r *http.Request) {
if !stream.AppConfig.Hls.Enable { if !stream.AppConfig.Hls.Enable {
log.Sugar.Warnf("处理m3u8请求失败 server未开启hls request:%s", r.URL.Path) log.Sugar.Warnf("处理m3u8请求失败 server未开启hls request:%s", r.URL.Path)
http.Error(w, "hls disable", http.StatusInternalServerError) http.Error(w, "hls disable", http.StatusInternalServerError)
@@ -196,17 +219,16 @@ func onTS(source string, w http.ResponseWriter, r *http.Request) {
http.ServeFile(w, r, tsPath) http.ServeFile(w, r, tsPath)
} }
func onHLS(sourceId string, w http.ResponseWriter, r *http.Request) { func (api *ApiServer) onHLS(sourceId string, w http.ResponseWriter, r *http.Request) {
if !stream.AppConfig.Hls.Enable { if !stream.AppConfig.Hls.Enable {
log.Sugar.Warnf("处理hls请求失败 server未开启hls request:%s", r.URL.Path) log.Sugar.Warnf("处理hls请求失败 server未开启hls")
http.Error(w, "hls disable", http.StatusInternalServerError) http.Error(w, "hls disable", http.StatusInternalServerError)
return return
} }
w.Header().Set("Content-Type", "application/vnd.apple.mpegurl") w.Header().Set("Content-Type", "application/vnd.apple.mpegurl")
//m3u8和ts会一直刷新, 每个请求只hook一次. //m3u8和ts会一直刷新, 每个请求只hook一次.
tcpAddr, _ := net.ResolveTCPAddr("tcp", r.RemoteAddr) sinkId := api.generateSinkId(r.RemoteAddr)
sinkId := stream.GenerateSinkId(tcpAddr)
//hook成功后, 如果还没有m3u8文件等生成m3u8文件 //hook成功后, 如果还没有m3u8文件等生成m3u8文件
//后续直接返回当前m3u8文件 //后续直接返回当前m3u8文件
@@ -221,31 +243,26 @@ func onHLS(sourceId string, w http.ResponseWriter, r *http.Request) {
done <- 0 done <- 0
}) })
hookState := utils.HookStateOK state := api.doPlay(sink)
sink.Play(sink, func() { if utils.HookStateOK != state {
err := stream.SinkManager.Add(sink) log.Sugar.Warnf("m3u8 请求失败 state:%d sink:%s", state, sink.PrintInfo())
utils.Assert(err == nil)
}, func(state utils.HookState) {
log.Sugar.Warnf("hook播放事件失败 request:%s", r.URL.Path)
hookState = state
w.WriteHeader(http.StatusForbidden) w.WriteHeader(http.StatusForbidden)
})
if utils.HookStateOK != hookState {
return return
} else {
err := stream.SinkManager.Add(sink)
utils.Assert(err == nil)
} }
select { select {
case <-done: case <-done:
case <-context.Done(): case <-context.Done():
log.Sugar.Infof("http m3u8连接断开")
break break
} }
} }
} }
func onRtc(sourceId string, w http.ResponseWriter, r *http.Request) { func (api *ApiServer) onRtc(sourceId string, w http.ResponseWriter, r *http.Request) {
v := struct { v := struct {
Type string `json:"type"` Type string `json:"type"`
SDP string `json:"sdp"` SDP string `json:"sdp"`
@@ -253,19 +270,20 @@ func onRtc(sourceId string, w http.ResponseWriter, r *http.Request) {
data, err := io.ReadAll(r.Body) data, err := io.ReadAll(r.Body)
if err != nil { if err != nil {
panic(err) log.Sugar.Errorf("rtc请求错误 err:%s remote:%s", err.Error(), r.RemoteAddr)
}
if err := json.Unmarshal(data, &v); err != nil { http.Error(w, err.Error(), http.StatusBadRequest)
panic(err) return
} } else if err := json.Unmarshal(data, &v); err != nil {
log.Sugar.Errorf("rtc请求错误 err:%s remote:%s", err.Error(), r.RemoteAddr)
tcpAddr, _ := net.ResolveTCPAddr("tcp", r.RemoteAddr) http.Error(w, err.Error(), http.StatusBadRequest)
sinkId := stream.GenerateSinkId(tcpAddr) return
}
group := sync.WaitGroup{} group := sync.WaitGroup{}
group.Add(1) group.Add(1)
sink := rtc.NewSink(sinkId, sourceId, v.SDP, func(sdp string) { sink := rtc.NewSink(api.generateSinkId(r.RemoteAddr), sourceId, v.SDP, func(sdp string) {
response := struct { response := struct {
Type string `json:"type"` Type string `json:"type"`
SDP string `json:"sdp"` SDP string `json:"sdp"`
@@ -285,13 +303,15 @@ func onRtc(sourceId string, w http.ResponseWriter, r *http.Request) {
group.Done() group.Done()
}) })
sink.Play(sink, func() { log.Sugar.Infof("rtc 请求 sink:%s sdp:%v", sink.PrintInfo(), v.SDP)
state := api.doPlay(sink)
if utils.HookStateOK != state {
log.Sugar.Warnf("rtc 播放失败 state:%d sink:%s", state, sink.PrintInfo())
}, func(state utils.HookState) {
w.WriteHeader(http.StatusForbidden) w.WriteHeader(http.StatusForbidden)
group.Done() group.Done()
}) }
group.Wait() group.Wait()
} }

View File

@@ -61,7 +61,7 @@ func (s *sessionImpl) OnPlay(app, stream_ string, response chan utils.HookState)
log.Sugar.Infof("rtmp onplay app:%s stream:%s sink:%v conn:%s", app, stream_, sink.Id(), s.conn.RemoteAddr().String()) log.Sugar.Infof("rtmp onplay app:%s stream:%s sink:%v conn:%s", app, stream_, sink.Id(), s.conn.RemoteAddr().String())
sink.(*stream.SinkImpl).Play(sink, func() { sink.Play(func() {
s.handle = sink s.handle = sink
response <- utils.HookStateOK response <- utils.HookStateOK
}, func(state utils.HookState) { }, func(state utils.HookState) {

View File

@@ -163,7 +163,7 @@ func (s *session) onDescribe(source string, headers textproto.MIMEHeader) error
code := utils.HookStateOK code := utils.HookStateOK
s.sink_ = sink_.(*sink) s.sink_ = sink_.(*sink)
sink_.(*sink).Play(sink_, func() { sink_.Play(func() {
}, func(state utils.HookState) { }, func(state utils.HookState) {
code = state code = state

View File

@@ -38,9 +38,9 @@ func NewPublishHookEventInfo(stream, remoteAddr string, protocol SourceType) eve
} }
type HookHandler interface { type HookHandler interface {
Play(sink ISink, success func(), failure func(state utils.HookState)) Play(success func(), failure func(state utils.HookState))
PlayDone(sink ISink, success func(), failure func(state utils.HookState)) PlayDone(success func(), failure func(state utils.HookState))
} }
type HookSession interface { type HookSession interface {

View File

@@ -192,18 +192,18 @@ func (s *SinkImpl) PrintInfo() string {
return fmt.Sprintf("%s-%v source:%s", s.ProtocolStr(), s.Id_, s.SourceId_) return fmt.Sprintf("%s-%v source:%s", s.ProtocolStr(), s.Id_, s.SourceId_)
} }
func (s *SinkImpl) Play(sink ISink, success func(), failure func(state utils.HookState)) { func (s *SinkImpl) Play(success func(), failure func(state utils.HookState)) {
f := func() { f := func() {
source := SourceManager.Find(sink.SourceId()) source := SourceManager.Find(s.SourceId())
if source == nil { if source == nil {
log.Sugar.Infof("添加sink到等待队列 sink:%s-%v source:%s", sink.ProtocolStr(), sink.Id(), sink.SourceId()) log.Sugar.Infof("添加sink到等待队列 sink:%s-%v source:%s", s.ProtocolStr(), s.Id(), s.SourceId())
sink.SetState(SessionStateWait) s.SetState(SessionStateWait)
AddSinkToWaitingQueue(sink.SourceId(), sink) AddSinkToWaitingQueue(s.SourceId(), s)
} else { } else {
log.Sugar.Debugf("发送播放事件 sink:%s-%v source:%s", sink.ProtocolStr(), sink.Id(), sink.SourceId()) log.Sugar.Debugf("发送播放事件 sink:%s-%v source:%s", s.ProtocolStr(), s.Id(), s.SourceId())
source.AddEvent(SourceEventPlay, sink) source.AddEvent(SourceEventPlay, s)
} }
} }
@@ -213,23 +213,23 @@ func (s *SinkImpl) Play(sink ISink, success func(), failure func(state utils.Hoo
return return
} }
err := s.Hook(HookEventPlay, NewPlayHookEventInfo(sink.SourceId(), "", sink.Protocol()), func(response *http.Response) { err := s.Hook(HookEventPlay, NewPlayHookEventInfo(s.SourceId(), "", s.Protocol()), func(response *http.Response) {
f() f()
success() success()
}, func(response *http.Response, err error) { }, func(response *http.Response, err error) {
log.Sugar.Errorf("Hook播放事件响应失败 err:%s sink:%s-%v source:%s", err.Error(), sink.ProtocolStr(), sink.Id(), sink.SourceId()) log.Sugar.Errorf("Hook播放事件响应失败 err:%s sink:%s-%v source:%s", err.Error(), s.ProtocolStr(), s.Id(), s.SourceId())
failure(utils.HookStateFailure) failure(utils.HookStateFailure)
}) })
if err != nil { if err != nil {
log.Sugar.Errorf("Hook播放事件发送失败 err:%s sink:%s-%v source:%s", err.Error(), sink.ProtocolStr(), sink.Id(), sink.SourceId()) log.Sugar.Errorf("Hook播放事件发送失败 err:%s sink:%s-%v source:%s", err.Error(), s.ProtocolStr(), s.Id(), s.SourceId())
failure(utils.HookStateFailure) failure(utils.HookStateFailure)
return return
} }
} }
func (s *SinkImpl) PlayDone(source ISink, success func(), failure func(state utils.HookState)) { func (s *SinkImpl) PlayDone(success func(), failure func(state utils.HookState)) {
} }