refactor: 优化流订阅处理逻辑

This commit is contained in:
ydajiang
2025-05-14 19:50:12 +08:00
parent 24fc44f9c7
commit 7486fc1491
6 changed files with 58 additions and 59 deletions

75
api.go
View File

@@ -163,27 +163,25 @@ func (api *ApiServer) onFlv(sourceId string, w http.ResponseWriter, r *http.Requ
func (api *ApiServer) onWSFlv(sourceId string, w http.ResponseWriter, r *http.Request) { func (api *ApiServer) onWSFlv(sourceId string, w http.ResponseWriter, r *http.Request) {
conn, err := api.upgrader.Upgrade(w, r, nil) conn, err := api.upgrader.Upgrade(w, r, nil)
if err != nil { if err != nil {
log.Sugar.Errorf("websocket头检查失败 err:%s", err.Error()) log.Sugar.Errorf("ws拉流失败 source: %s err: %s", sourceId, err.Error())
w.WriteHeader(http.StatusBadRequest) w.WriteHeader(http.StatusBadRequest)
return return
} }
sink := flv.NewFLVSink(api.generateSinkID(r.RemoteAddr), sourceId, flv.NewWSConn(conn)) sink := flv.NewFLVSink(api.generateSinkID(r.RemoteAddr), sourceId, flv.NewWSConn(conn))
sink.SetUrlValues(r.URL.Query()) ok := stream.SubscribeStream(sink, r.URL.Query())
log.Sugar.Infof("ws-flv 连接 sink:%s", sink.String()) if utils.HookStateOK != ok {
log.Sugar.Warnf("ws-flv 拉流失败 source: %s sink: %s", sourceId, sink.String())
_, state := stream.PreparePlaySink(sink) _ = conn.Close()
if utils.HookStateOK != state { } else {
log.Sugar.Warnf("ws-flv 播放失败 sink:%s", sink.String()) log.Sugar.Infof("ws-flv 拉流成功 source: %s sink: %s", sourceId, sink.String())
w.WriteHeader(http.StatusForbidden)
return
} }
netConn := conn.NetConn() netConn := conn.NetConn()
bytes := make([]byte, 64) bytes := make([]byte, 64)
for { for {
if _, err := netConn.Read(bytes); err != nil { if _, err := netConn.Read(bytes); err != nil {
log.Sugar.Infof("ws-flv 断开连接 sink:%s", sink.String()) log.Sugar.Infof("ws-flv 断开连接 source: %s sink:%s", sourceId, sink.String())
sink.Close() sink.Close()
break break
} }
@@ -195,28 +193,28 @@ func (api *ApiServer) onHttpFLV(sourceId string, w http.ResponseWriter, r *http.
w.Header().Set("Connection", "Keep-Alive") w.Header().Set("Connection", "Keep-Alive")
w.Header().Set("Transfer-Encoding", "chunked") w.Header().Set("Transfer-Encoding", "chunked")
hj, ok := w.(http.Hijacker) var conn net.Conn
if !ok { if hj, ok := w.(http.Hijacker); !ok {
log.Sugar.Errorf("http-flv 拉流失败 不支持hijacking. source: %s remote: %s", sourceId, r.RemoteAddr)
http.Error(w, "webserver doesn't support hijacking", http.StatusInternalServerError) http.Error(w, "webserver doesn't support hijacking", http.StatusInternalServerError)
return return
} } else {
w.WriteHeader(http.StatusOK)
w.WriteHeader(http.StatusOK) var err error
conn, _, err := hj.Hijack() if conn, _, err = hj.Hijack(); err != nil {
if err != nil { log.Sugar.Errorf("http-flv 拉流失败 source: %s remote: %s err: %s", sourceId, r.RemoteAddr, err.Error())
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
return return
}
} }
sink := flv.NewFLVSink(api.generateSinkID(r.RemoteAddr), sourceId, conn) sink := flv.NewFLVSink(api.generateSinkID(r.RemoteAddr), sourceId, conn)
sink.SetUrlValues(r.URL.Query()) ok := stream.SubscribeStream(sink, r.URL.Query())
log.Sugar.Infof("http-flv 连接 sink:%s", sink.String()) if utils.HookStateOK != ok {
log.Sugar.Warnf("http-flv 拉流失败 source: %s sink: %s", sourceId, sink.String())
_, state := stream.PreparePlaySink(sink) sink.Close()
if utils.HookStateOK != state { } else {
log.Sugar.Warnf("http-flv 播放失败 sink:%s", sink.String()) log.Sugar.Infof("http-flv 拉流成功 source: %s sink: %s", sourceId, sink.String())
w.WriteHeader(http.StatusForbidden)
return return
} }
@@ -294,10 +292,9 @@ func (api *ApiServer) onHLS(source string, w http.ResponseWriter, r *http.Reques
m3u8Pipe <- m3u8 m3u8Pipe <- m3u8
}, sid) }, sid)
sink.SetUrlValues(r.URL.Query()) ok := stream.SubscribeStream(sink, r.URL.Query())
if _, state := stream.PreparePlaySink(sink); utils.HookStateOK != state { if utils.HookStateOK != ok {
log.Sugar.Warnf("m3u8拉流失败 sink: %s", sink.String()) log.Sugar.Warnf("m3u8拉流失败 source: %s sink: %s", source, sink.String())
w.WriteHeader(http.StatusForbidden) w.WriteHeader(http.StatusForbidden)
return return
} }
@@ -332,13 +329,11 @@ func (api *ApiServer) onRtc(sourceId string, w http.ResponseWriter, r *http.Requ
data, err := io.ReadAll(r.Body) data, err := io.ReadAll(r.Body)
if err != nil { if err != nil {
log.Sugar.Errorf("rtc请求错误 err:%s remote:%s", err.Error(), r.RemoteAddr) log.Sugar.Errorf("rtc拉流失败 err: %s remote: %s", err.Error(), r.RemoteAddr)
http.Error(w, err.Error(), http.StatusBadRequest) http.Error(w, err.Error(), http.StatusBadRequest)
return return
} else if err := json.Unmarshal(data, &v); err != nil { } else if err := json.Unmarshal(data, &v); err != nil {
log.Sugar.Errorf("rtc请求错误 err:%s remote:%s", err.Error(), r.RemoteAddr) log.Sugar.Errorf("rtc拉流失败 err: %s remote: %s", err.Error(), r.RemoteAddr)
http.Error(w, err.Error(), http.StatusBadRequest) http.Error(w, err.Error(), http.StatusBadRequest)
return return
} }
@@ -365,13 +360,11 @@ func (api *ApiServer) onRtc(sourceId string, w http.ResponseWriter, r *http.Requ
group.Done() group.Done()
}) })
sink.SetUrlValues(r.URL.Query()) log.Sugar.Infof("rtc拉流请求 source: %s sink: %s sdp:%v", sourceId, sink.String(), v.SDP)
log.Sugar.Infof("rtc 请求 sink:%s sdp:%v", sink.String(), v.SDP)
_, state := stream.PreparePlaySink(sink)
if utils.HookStateOK != state {
log.Sugar.Warnf("rtc 播放失败 sink:%s", sink.String())
ok := stream.SubscribeStream(sink, r.URL.Query())
if utils.HookStateOK != ok {
log.Sugar.Warnf("rtc拉流失败 source: %s sink: %s", sourceId, sink.String())
w.WriteHeader(http.StatusForbidden) w.WriteHeader(http.StatusForbidden)
group.Done() group.Done()
} }

View File

@@ -207,8 +207,8 @@ func (api *ApiServer) OnGBAnswerCreate(v *GBOffer, w http.ResponseWriter, r *htt
} }
log.Sugar.Infof("创建转发sink成功, sink: %s port: %d transport: %s", sink.GetID(), port, setup.TransportType()) log.Sugar.Infof("创建转发sink成功, sink: %s port: %d transport: %s", sink.GetID(), port, setup.TransportType())
_, state := stream.PreparePlaySink(sink) ok := stream.SubscribeStream(sink, r.URL.Query())
if utils.HookStateOK != state { if utils.HookStateOK != ok {
err = fmt.Errorf("failed to prepare play sink") err = fmt.Errorf("failed to prepare play sink")
return return
} }

View File

@@ -56,19 +56,18 @@ func (s *Session) OnPlay(app, stream_ string) utils.HookState {
streamName, values := stream.ParseUrl(stream_) streamName, values := stream.ParseUrl(stream_)
sourceId := s.generateSourceID(app, streamName) sourceId := s.generateSourceID(app, streamName)
sink := NewSink(stream.NetAddr2SinkId(s.conn.RemoteAddr()), sourceId, s.conn, s.stack) sinkId := stream.NetAddr2SinkId(s.conn.RemoteAddr())
sink.SetUrlValues(values) log.Sugar.Infof("rtmp onplay app: %s stream: %s sink: %v conn: %s", app, stream_, sinkId, s.conn.RemoteAddr().String())
log.Sugar.Infof("rtmp onplay app: %s stream: %s sink: %v conn: %s", app, stream_, sink.GetID(), s.conn.RemoteAddr().String()) sink := NewSink(sinkId, sourceId, s.conn, s.stack)
ok := stream.SubscribeStream(sink, values)
_, state := stream.PreparePlaySink(sink) if utils.HookStateOK != ok {
if utils.HookStateOK != state {
log.Sugar.Errorf("rtmp拉流失败 source: %s sink: %s", sourceId, sink.GetID()) log.Sugar.Errorf("rtmp拉流失败 source: %s sink: %s", sourceId, sink.GetID())
} else { } else {
s.handle = sink s.handle = sink
} }
return state return ok
} }
func (s *Session) Input(data []byte) error { func (s *Session) Input(data []byte) error {

View File

@@ -132,10 +132,9 @@ func (h handler) OnDescribe(request Request) (*http.Response, []byte, error) {
request.session.response(response, []byte(sdp)) request.session.response(response, []byte(sdp))
}) })
sink.SetUrlValues(request.url.Query()) ok := stream.SubscribeStreamWithRead(sink, request.url.Query(), false)
_, code := stream.PreparePlaySinkWithReady(sink, false) if utils.HookStateOK != ok {
if utils.HookStateOK != code { return nil, nil, fmt.Errorf("hook failed. code: %d", ok)
return nil, nil, fmt.Errorf("hook failed. code: %d", code)
} }
request.session.sink = sink.(*Sink) request.session.sink = sink.(*Sink)

View File

@@ -7,10 +7,6 @@ import (
) )
func PreparePlaySink(sink Sink) (*http.Response, utils.HookState) { func PreparePlaySink(sink Sink) (*http.Response, utils.HookState) {
return PreparePlaySinkWithReady(sink, true)
}
func PreparePlaySinkWithReady(sink Sink, ok bool) (*http.Response, utils.HookState) {
var response *http.Response var response *http.Response
if AppConfig.Hooks.IsEnableOnPlay() { if AppConfig.Hooks.IsEnableOnPlay() {
@@ -24,7 +20,6 @@ func PreparePlaySinkWithReady(sink Sink, ok bool) (*http.Response, utils.HookSta
response = hook response = hook
} }
sink.SetReady(ok)
source := SourceManager.Find(sink.GetSourceID()) source := SourceManager.Find(sink.GetSourceID())
if source == nil { if source == nil {
log.Sugar.Infof("添加%s sink到等待队列 id: %v source: %s", sink.GetProtocol().String(), sink.GetID(), sink.GetSourceID()) log.Sugar.Infof("添加%s sink到等待队列 id: %v source: %s", sink.GetProtocol().String(), sink.GetID(), sink.GetSourceID())

View File

@@ -3,7 +3,9 @@ package stream
import ( import (
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"github.com/lkmio/avformat/utils"
"net" "net"
"net/url"
"strconv" "strconv"
) )
@@ -63,3 +65,14 @@ func ExecuteSyncEventOnTransStreamPublisher(sourceId string, event func()) bool
return false return false
} }
func SubscribeStream(sink Sink, values url.Values) utils.HookState {
return SubscribeStreamWithRead(sink, values, true)
}
func SubscribeStreamWithRead(sink Sink, values url.Values, ready bool) utils.HookState {
sink.SetReady(ready)
sink.SetUrlValues(values)
_, state := PreparePlaySink(sink)
return state
}