diff --git a/plugin/flv/index.go b/plugin/flv/index.go index 72c23ec..b4de010 100644 --- a/plugin/flv/index.go +++ b/plugin/flv/index.go @@ -47,7 +47,12 @@ func (plugin *FLVPlugin) Start() (err error) { } func (plugin *FLVPlugin) ServeHTTP(w http.ResponseWriter, r *http.Request) { - streamPath := strings.TrimSuffix(strings.TrimPrefix(r.URL.Path, "/"), ".flv") + rawPath := strings.TrimPrefix(r.URL.Path, "/") + if plugin.Server != nil && plugin.Server.RedirectIfNeeded(w, r, "http", rawPath) { + plugin.Debug("redirect issued", "protocol", "http", "path", rawPath) + return + } + streamPath := strings.TrimSuffix(rawPath, ".flv") var err error defer func() { if err != nil { diff --git a/plugin/hls/index.go b/plugin/hls/index.go index d89e48c..fffdf09 100644 --- a/plugin/hls/index.go +++ b/plugin/hls/index.go @@ -207,6 +207,11 @@ func (config *HLSPlugin) vod(w http.ResponseWriter, r *http.Request) { } func (config *HLSPlugin) ServeHTTP(w http.ResponseWriter, r *http.Request) { + redirectPath := strings.TrimPrefix(r.URL.Path, "/") + if config.Server.RedirectIfNeeded(w, r, "http", redirectPath) { + config.Debug("redirect issued", "protocol", "http", "path", redirectPath) + return + } fileName := strings.TrimPrefix(r.URL.Path, "/") query := r.URL.Query() waitTimeout, err := time.ParseDuration(query.Get("timeout")) diff --git a/plugin/mp4/index.go b/plugin/mp4/index.go index ba4c3f0..6d99445 100644 --- a/plugin/mp4/index.go +++ b/plugin/mp4/index.go @@ -10,8 +10,7 @@ import ( "m7s.live/v5/pkg/codec" "m7s.live/v5/pkg/util" "m7s.live/v5/plugin/mp4/pb" - mp4 "m7s.live/v5/plugin/mp4/pkg" - pkg "m7s.live/v5/plugin/mp4/pkg" + mp4pkg "m7s.live/v5/plugin/mp4/pkg" "m7s.live/v5/plugin/mp4/pkg/box" ) @@ -37,8 +36,8 @@ var _ = m7s.InstallPlugin[MP4Plugin](m7s.PluginMeta{ DefaultYaml: defaultConfig, ServiceDesc: &pb.Api_ServiceDesc, RegisterGRPCHandler: pb.RegisterApiHandler, - NewPuller: pkg.NewPuller, - NewRecorder: pkg.NewRecorder, + NewPuller: mp4pkg.NewPuller, + NewRecorder: mp4pkg.NewRecorder, NewPullProxy: m7s.NewHTTPPullPorxy, }) @@ -57,7 +56,7 @@ func (p *MP4Plugin) Start() (err error) { if err != nil { return } - err = p.DB.AutoMigrate(&mp4.TagModel{}) + err = p.DB.AutoMigrate(&mp4pkg.TagModel{}) if err != nil { return } @@ -99,7 +98,12 @@ func (p *MP4Plugin) Start() (err error) { } func (p *MP4Plugin) ServeHTTP(w http.ResponseWriter, r *http.Request) { - streamPath := strings.TrimSuffix(strings.TrimPrefix(r.URL.Path, "/"), ".mp4") + redirectPath := strings.TrimPrefix(r.URL.Path, "/") + if p.Server != nil && p.Server.RedirectIfNeeded(w, r, "http", redirectPath) { + p.Debug("redirect issued", "protocol", "http", "path", redirectPath) + return + } + streamPath := strings.TrimSuffix(redirectPath, ".mp4") if r.URL.RawQuery != "" { streamPath += "?" + r.URL.RawQuery } @@ -118,13 +122,13 @@ func (p *MP4Plugin) ServeHTTP(w http.ResponseWriter, r *http.Request) { ctx.ContentType = "video/mp4" ctx.ServeHTTP(w, r) - muxer := pkg.NewMuxer(pkg.FLAG_FRAGMENT) + muxer := mp4pkg.NewMuxer(mp4pkg.FLAG_FRAGMENT) err = muxer.WriteInitSegment(&ctx) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } - var audio, video *pkg.Track + var audio, video *mp4pkg.Track var nextFragmentId uint32 if sub.Publisher.HasVideoTrack() && sub.SubVideo { v := sub.Publisher.VideoTrack.AVTrack @@ -181,7 +185,7 @@ func (p *MP4Plugin) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } ctx.Flush() - m7s.PlayBlock(sub, func(frame *mp4.AudioFrame) (err error) { + m7s.PlayBlock(sub, func(frame *mp4pkg.AudioFrame) (err error) { if audio.Samplelist[0].Buffers != nil { audio.Samplelist[0].Duration = sub.AudioReader.AbsTime - audio.Samplelist[0].Timestamp nextFragmentId++ @@ -195,7 +199,7 @@ func (p *MP4Plugin) ServeHTTP(w http.ResponseWriter, r *http.Request) { audio.Samplelist[0].Timestamp = sub.AudioReader.AbsTime audio.Samplelist[0].Memory = frame.Memory return - }, func(frame *mp4.VideoFrame) (err error) { + }, func(frame *mp4pkg.VideoFrame) (err error) { if video.Samplelist[0].Buffers != nil { video.Samplelist[0].Duration = sub.VideoReader.AbsTime - video.Samplelist[0].Timestamp nextFragmentId++ diff --git a/plugin/rtsp/index.go b/plugin/rtsp/index.go index b1af32d..5f71f78 100644 --- a/plugin/rtsp/index.go +++ b/plugin/rtsp/index.go @@ -4,7 +4,6 @@ import ( "fmt" "net" "strings" - "sync" task "github.com/langhuihui/gotask" "m7s.live/v5/pkg/util" @@ -24,16 +23,10 @@ var _ = m7s.InstallPlugin[RTSPPlugin](m7s.PluginMeta{ type RTSPPlugin struct { m7s.Plugin - UserName string `desc:"用户名"` - Password string `desc:"密码"` - UdpPort util.Range[uint16] `default:"20001-30000" desc:"媒体端口范围"` //媒体端口范围 - udpPorts chan uint16 - advisorOnce sync.Once - redirectAdvisor rtspRedirectAdvisor -} - -type rtspRedirectAdvisor interface { - ShouldRedirectRTSP(streamPath, currentHost string) (string, bool) + UserName string `desc:"用户名"` + Password string `desc:"密码"` + UdpPort util.Range[uint16] `default:"20001-30000" desc:"媒体端口范围"` //媒体端口范围 + udpPorts chan uint16 } func (p *RTSPPlugin) OnTCPConnect(conn *net.TCPConn) task.ITask { @@ -59,18 +52,6 @@ func (p *RTSPPlugin) Start() (err error) { return } -func (p *RTSPPlugin) findRedirectAdvisor() rtspRedirectAdvisor { - p.advisorOnce.Do(func() { - for plugin := range p.Server.Plugins.Range { - if advisor, ok := plugin.GetHandler().(rtspRedirectAdvisor); ok { - p.redirectAdvisor = advisor - break - } - } - }) - return p.redirectAdvisor -} - // 初始化UDP端口池 func (p *RTSPPlugin) initUDPPortPool() { if p.UdpPort.Valid() { diff --git a/plugin/rtsp/server.go b/plugin/rtsp/server.go index d3c1e0c..7c9c3ed 100644 --- a/plugin/rtsp/server.go +++ b/plugin/rtsp/server.go @@ -113,11 +113,23 @@ func (task *RTSPServer) Go() (err error) { if rawQuery != "" { streamPath += "?" + rawQuery } - if advisor := task.conf.findRedirectAdvisor(); advisor != nil { - if location, ok := advisor.ShouldRedirectRTSP(streamPath, task.URL.Host); ok { + if advisor := task.conf.Server.GetRedirectAdvisor(); advisor != nil { + if target, statusCode, ok := advisor.GetRedirectTarget("rtsp", streamPath, task.URL.Host); ok && target != "" { + location := "rtsp://" + target + if streamPath != "" { + if !strings.HasPrefix(streamPath, "/") { + location += "/" + } + location += streamPath + } + + if statusCode == 0 { + statusCode = http.StatusFound + } + res := &util.Response{ - StatusCode: http.StatusFound, - Status: "Found", + StatusCode: statusCode, + Status: http.StatusText(statusCode), Header: textproto.MIMEHeader{ "Location": {location}, }, diff --git a/plugin/webrtc/api.go b/plugin/webrtc/api.go index 68d4457..e3742bd 100644 --- a/plugin/webrtc/api.go +++ b/plugin/webrtc/api.go @@ -12,6 +12,11 @@ import ( // https://datatracker.ietf.org/doc/html/draft-ietf-wish-whip func (conf *WebRTCPlugin) servePush(w http.ResponseWriter, r *http.Request) { + redirectPath := strings.TrimPrefix(r.URL.Path, "/") + if conf.Server.RedirectIfNeeded(w, r, "webrtc", redirectPath) { + conf.Debug("redirect issued", "protocol", "webrtc", "path", redirectPath) + return + } streamPath := r.PathValue("streamPath") rawQuery := r.URL.RawQuery auth := r.Header.Get("Authorization") @@ -71,6 +76,11 @@ func (conf *WebRTCPlugin) servePush(w http.ResponseWriter, r *http.Request) { } func (conf *WebRTCPlugin) servePlay(w http.ResponseWriter, r *http.Request) { + redirectPath := strings.TrimPrefix(r.URL.Path, "/") + if conf.Server.RedirectIfNeeded(w, r, "webrtc", redirectPath) { + conf.Debug("redirect issued", "protocol", "webrtc", "path", redirectPath) + return + } w.Header().Set("Content-Type", "application/sdp") streamPath := r.PathValue("streamPath") rawQuery := r.URL.RawQuery diff --git a/server.go b/server.go index 89763e2..55c018e 100644 --- a/server.go +++ b/server.go @@ -13,6 +13,7 @@ import ( "path/filepath" "runtime/debug" "strings" + "sync" "time" "gopkg.in/yaml.v3" @@ -57,6 +58,9 @@ var ( ) type ( + RedirectAdvisor interface { + GetRedirectTarget(protocol, streamPath, currentHost string) (targetHost string, statusCode int, ok bool) + } ServerConfig struct { FatalDir string `default:"fatal" desc:""` PulseInterval time.Duration `default:"5s" desc:"心跳事件间隔"` //心跳事件间隔 @@ -113,6 +117,8 @@ type ( PushProxies PushProxyManager Subscribers SubscriberCollection LogHandler MultiLogHandler + redirectAdvisor RedirectAdvisor + redirectOnce sync.Once apiList []string grpcServer *grpc.Server grpcClientConn *grpc.ClientConn @@ -166,6 +172,77 @@ func NewServer(conf any) (s *Server) { return } +func (s *Server) GetRedirectAdvisor() RedirectAdvisor { + if s == nil { + return nil + } + s.redirectOnce.Do(func() { + s.Plugins.Range(func(plugin *Plugin) bool { + if advisor, ok := plugin.GetHandler().(RedirectAdvisor); ok { + s.redirectAdvisor = advisor + return false + } + return true + }) + }) + return s.redirectAdvisor +} + +// RedirectIfNeeded evaluates redirect advice for HTTP-based protocols and issues redirects when appropriate. +func (s *Server) RedirectIfNeeded(w http.ResponseWriter, r *http.Request, protocol, redirectPath string) bool { + if s == nil { + return false + } + advisor := s.GetRedirectAdvisor() + if advisor == nil { + return false + } + targetHost, statusCode, ok := advisor.GetRedirectTarget(protocol, redirectPath, r.Host) + if !ok || targetHost == "" { + return false + } + if statusCode == 0 { + statusCode = http.StatusFound + } + redirectURL := buildRedirectURL(r, targetHost) + http.Redirect(w, r, redirectURL, statusCode) + return true +} + +func buildRedirectURL(r *http.Request, host string) string { + scheme := requestScheme(r) + if isWebSocketRequest(r) { + switch scheme { + case "https": + scheme = "wss" + case "http": + scheme = "ws" + } + } + target := &url.URL{ + Scheme: scheme, + Host: host, + Path: r.URL.Path, + RawQuery: r.URL.RawQuery, + } + return target.String() +} + +func requestScheme(r *http.Request) string { + if proto := r.Header.Get("X-Forwarded-Proto"); proto != "" { + return proto + } + if r.TLS != nil { + return "https" + } + return "http" +} + +func isWebSocketRequest(r *http.Request) bool { + connection := strings.ToLower(r.Header.Get("Connection")) + return strings.Contains(connection, "upgrade") && strings.EqualFold(r.Header.Get("Upgrade"), "websocket") +} + func Run(ctx context.Context, conf any) (err error) { for err = ErrRestart; errors.Is(err, ErrRestart); err = Servers.AddTask(NewServer(conf), ctx).WaitStopped() { }