From 1f06faf0f00d1db95d988d552e64ebb7b190450b Mon Sep 17 00:00:00 2001 From: langhuihui <178529795@qq.com> Date: Wed, 19 Nov 2025 11:04:43 +0800 Subject: [PATCH] opt: idr ring lock-free --- pkg/ring-writer.go | 83 +++++++++-- plugin/snap/api.go | 2 - publisher.go | 23 ++- server.go | 341 --------------------------------------------- server_grpc.go | 186 +++++++++++++++++++++++++ server_http.go | 186 +++++++++++++++++++++++++ 6 files changed, 459 insertions(+), 362 deletions(-) create mode 100644 server_grpc.go create mode 100644 server_http.go diff --git a/pkg/ring-writer.go b/pkg/ring-writer.go index 4a78fd4..5b53494 100644 --- a/pkg/ring-writer.go +++ b/pkg/ring-writer.go @@ -2,18 +2,83 @@ package pkg import ( "log/slog" - "sync" "sync/atomic" "time" - "github.com/langhuihui/gotask" + task "github.com/langhuihui/gotask" "m7s.live/v5/pkg/util" ) +type IDRNode struct { + Value *util.Ring[AVFrame] + next atomic.Pointer[IDRNode] + prev atomic.Pointer[IDRNode] +} + +func (n *IDRNode) Next() *IDRNode { + return n.next.Load() +} + +func (n *IDRNode) Prev() *IDRNode { + return n.prev.Load() +} + +type IDRList struct { + head atomic.Pointer[IDRNode] + tail atomic.Pointer[IDRNode] + len atomic.Int32 +} + +func (l *IDRList) Init() { + l.head.Store(nil) + l.tail.Store(nil) + l.len.Store(0) +} + +func (l *IDRList) Len() int { + return int(l.len.Load()) +} + +func (l *IDRList) Front() *IDRNode { + return l.head.Load() +} + +func (l *IDRList) Back() *IDRNode { + return l.tail.Load() +} + +func (l *IDRList) PushBack(v *util.Ring[AVFrame]) { + node := &IDRNode{Value: v} + l.len.Add(1) + tail := l.tail.Load() + node.prev.Store(tail) + if tail != nil { + tail.next.Store(node) + } + l.tail.Store(node) + if l.head.Load() == nil { + l.head.Store(node) + } +} + +func (l *IDRList) Remove(node *IDRNode) { + head := l.head.Load() + if head != node { + return + } + next := head.next.Load() + l.head.Store(next) + if next != nil { + next.prev.Store(nil) + } else { + l.tail.Store(nil) + } + l.len.Add(-1) +} + type RingWriter struct { *util.Ring[AVFrame] - sync.RWMutex - IDRingList util.List[*util.Ring[AVFrame]] // 关键帧链表 + IDRingList IDRList // 关键帧链表 BufferRange util.Range[time.Duration] SizeRange util.Range[int] pool *util.Ring[AVFrame] @@ -98,8 +163,6 @@ func (rb *RingWriter) Dispose() { } func (rb *RingWriter) GetIDR() *util.Ring[AVFrame] { - rb.RLock() - defer rb.RUnlock() if latest := rb.IDRingList.Back(); latest != nil { return latest.Value } @@ -107,8 +170,6 @@ func (rb *RingWriter) GetIDR() *util.Ring[AVFrame] { } func (rb *RingWriter) GetOldestIDR() *util.Ring[AVFrame] { - rb.RLock() - defer rb.RUnlock() if latest := rb.IDRingList.Front(); latest != nil { return latest.Value } @@ -116,8 +177,6 @@ func (rb *RingWriter) GetOldestIDR() *util.Ring[AVFrame] { } func (rb *RingWriter) GetHistoryIDR(bufTime time.Duration) *util.Ring[AVFrame] { - rb.RLock() - defer rb.RUnlock() for item := rb.IDRingList.Back(); item != nil; item = item.Prev() { if rb.LastValue.Timestamp-item.Value.Value.Timestamp >= bufTime { return item.Value @@ -135,9 +194,7 @@ func (rb *RingWriter) CurrentBufferTime() time.Duration { } func (rb *RingWriter) PushIDR() { - rb.Lock() rb.IDRingList.PushBack(rb.Ring) - rb.Unlock() } func (rb *RingWriter) Step() (normal bool) { @@ -159,9 +216,7 @@ func (rb *RingWriter) Step() (normal bool) { } else if next == oldIDR.Value { if nextOld := oldIDR.Next(); nextOld != nil && rb.durationFrom(nextOld.Value) > rb.BufferRange[0] { rb.SLogger.Log(nil, task.TraceLevel, "remove old idr") - rb.Lock() rb.IDRingList.Remove(oldIDR) - rb.Unlock() } else { rb.glow(5, "not enough buffer") next = rb.Next() diff --git a/plugin/snap/api.go b/plugin/snap/api.go index 60ee110..95ce180 100755 --- a/plugin/snap/api.go +++ b/plugin/snap/api.go @@ -450,7 +450,6 @@ func (p *SnapPlugin) calculateSnapTimes(publisher *m7s.Publisher, startTime, end } // 遍历IDRingList获取关键帧 - videoTrack.RLock() if videoTrack.IDRingList.Len() > 0 { // 从头开始遍历所有关键帧 for idrElem := videoTrack.IDRingList.Front(); idrElem != nil; idrElem = idrElem.Next() { @@ -467,7 +466,6 @@ func (p *SnapPlugin) calculateSnapTimes(publisher *m7s.Publisher, startTime, end } } } - videoTrack.RUnlock() // 如果没有找到关键帧,但有GOP信息,则使用估算的GOP间隔生成时间点 if len(snapTimes) == 0 && gopDuration > 0 { diff --git a/publisher.go b/publisher.go index 8ac1de9..eb574d3 100644 --- a/publisher.go +++ b/publisher.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "reflect" - "slices" "sync" "time" @@ -112,6 +111,7 @@ type Publisher struct { OnGetPosition func() time.Time PullProxyConfig *PullProxyConfig dropAfterTs time.Duration + bufferTimeCounts map[time.Duration]int } type PublishParam struct { @@ -134,6 +134,7 @@ func (p *Publisher) Start() (err error) { if p.MaxCount > 0 && s.Streams.Length >= p.MaxCount { return ErrPublishMaxCount } + p.bufferTimeCounts = make(map[time.Duration]int) p.Info("publish") p.processPullProxyOnStart() p.audioReady = util.NewPromiseWithTimeout(p, p.PublishTimeout) @@ -200,14 +201,25 @@ func (p *Publisher) Go() error { func (p *Publisher) RemoveSubscriber(subscriber *Subscriber) { p.Subscribers.Remove(subscriber) + p.bufferTimeCounts[subscriber.BufferTime]-- + if p.bufferTimeCounts[subscriber.BufferTime] == 0 { + delete(p.bufferTimeCounts, subscriber.BufferTime) + } p.Info("subscriber -1", "count", p.Subscribers.Length) if p.Plugin == nil { return } - if subscriber.BufferTime == p.BufferTime && p.Subscribers.Length > 0 { - p.BufferTime = slices.MaxFunc(p.Subscribers.Items, func(a, b *Subscriber) int { - return int(a.BufferTime - b.BufferTime) - }).BufferTime + if p.Subscribers.Length > 0 { + var maxBuf time.Duration + for k := range p.bufferTimeCounts { + if k > maxBuf { + maxBuf = k + } + } + if defaultBuf := p.Plugin.GetCommonConf().Publish.BufferTime; maxBuf < defaultBuf { + maxBuf = defaultBuf + } + p.BufferTime = maxBuf } else { p.BufferTime = p.Plugin.GetCommonConf().Publish.BufferTime } @@ -235,6 +247,7 @@ func (p *Publisher) AddSubscriber(subscriber *Subscriber) { } subscriber.waitStartTime = time.Time{} if p.Subscribers.AddUnique(subscriber) { + p.bufferTimeCounts[subscriber.BufferTime]++ p.Info("subscriber +1", "count", p.Subscribers.Length) if subscriber.BufferTime > p.BufferTime { p.BufferTime = subscriber.BufferTime diff --git a/server.go b/server.go index 90be64d..64caf09 100644 --- a/server.go +++ b/server.go @@ -5,7 +5,6 @@ import ( "context" "errors" "fmt" - "io" "log/slog" "net/http" "net/url" @@ -32,13 +31,10 @@ import ( "github.com/prometheus/client_golang/prometheus/promhttp" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" - "google.golang.org/grpc/metadata" "gorm.io/gorm" "m7s.live/v5/pb" . "m7s.live/v5/pkg" - "m7s.live/v5/pkg/auth" "m7s.live/v5/pkg/db" - "m7s.live/v5/pkg/format" "m7s.live/v5/pkg/util" ) @@ -58,9 +54,6 @@ var ( ) type ( - RedirectAdvisor interface { - GetRedirectTarget(protocol, streamPath, currentHost string) (targetHost string, statusCode int, ok bool) - } ServerConfig struct { FatalDir string `default:"fatal" desc:""` PulseInterval time.Duration `default:"5s" desc:"心跳事件间隔"` //心跳事件间隔 @@ -134,21 +127,9 @@ type ( task.TickTask s *Server } - GRPCServer struct { - task.Task - s *Server - tcpTask *config.ListenTCPWork - } RawConfig = map[string]map[string]any ) -// context key type & keys -type ctxKey int - -const ( - ctxKeyClaims ctxKey = iota -) - func (w *WaitStream) GetKey() string { return w.StreamPath } @@ -172,77 +153,6 @@ func NewServer(conf any) (s *Server) { return } -func (s *Server) GetRedirectAdvisor() RedirectAdvisor { - if s == nil { - return nil - } - s.redirectOnce.Do(func() { - s.Plugins.Range(func(plugin *Plugin) bool { - if advisor, ok := plugin.GetHandler().(RedirectAdvisor); ok { - s.redirectAdvisor = advisor - return false - } - return true - }) - }) - return s.redirectAdvisor -} - -// RedirectIfNeeded evaluates redirect advice for HTTP-based protocols and issues redirects when appropriate. -func (s *Server) RedirectIfNeeded(w http.ResponseWriter, r *http.Request, protocol, redirectPath string) bool { - if s == nil { - return false - } - advisor := s.GetRedirectAdvisor() - if advisor == nil { - return false - } - targetHost, statusCode, ok := advisor.GetRedirectTarget(protocol, redirectPath, r.Host) - if !ok || targetHost == "" { - return false - } - if statusCode == 0 { - statusCode = http.StatusFound - } - redirectURL := buildRedirectURL(r, targetHost) - http.Redirect(w, r, redirectURL, statusCode) - return true -} - -func buildRedirectURL(r *http.Request, host string) string { - scheme := requestScheme(r) - if isWebSocketRequest(r) { - switch scheme { - case "https": - scheme = "wss" - case "http": - scheme = "ws" - } - } - target := &url.URL{ - Scheme: scheme, - Host: host, - Path: r.URL.Path, - RawQuery: r.URL.RawQuery, - } - return target.String() -} - -func requestScheme(r *http.Request) string { - if proto := r.Header.Get("X-Forwarded-Proto"); proto != "" { - return proto - } - if r.TLS != nil { - return "https" - } - return "http" -} - -func isWebSocketRequest(r *http.Request) bool { - connection := strings.ToLower(r.Header.Get("Connection")) - return strings.Contains(connection, "upgrade") && strings.EqualFold(r.Header.Get("Upgrade"), "websocket") -} - func Run(ctx context.Context, conf any) (err error) { for err = ErrRestart; errors.Is(err, ErrRestart); err = Servers.AddTask(NewServer(conf), ctx).WaitStopped() { } @@ -664,14 +574,6 @@ func (c *CheckSubWaitTimeout) Tick(any) { c.s.Waiting.checkTimeout() } -func (gRPC *GRPCServer) Dispose() { - gRPC.s.Stop(gRPC.StopReason()) -} - -func (gRPC *GRPCServer) Go() (err error) { - return gRPC.s.grpcServer.Serve(gRPC.tcpTask.Listener) -} - func (s *Server) CallOnStreamTask(callback func()) { s.Streams.Call(callback) } @@ -723,246 +625,3 @@ func (s *Server) OnSubscribe(streamPath string, args url.Values) { } } } - -func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { - // Check for location-based forwarding first - if s.Location != nil { - for pattern, target := range s.Location { - if pattern.MatchString(r.URL.Path) { - // Rewrite the URL path and handle locally - r.URL.Path = pattern.ReplaceAllString(r.URL.Path, target) - // Forward to local handler - s.config.HTTP.GetHandler(s.Logger).ServeHTTP(w, r) - return - } - } - } - - // 检查 admin.zip 是否需要重新加载 - now := time.Now() - if now.Sub(s.Admin.lastCheckTime) > checkInterval { - if info, err := os.Stat(s.Admin.FilePath); err == nil && info.ModTime() != s.Admin.zipLastModTime { - s.Info("admin.zip changed, reloading...") - s.loadAdminZip() - } - s.Admin.lastCheckTime = now - } - - if s.Admin.zipReader != nil { - // Handle root path redirect to HomePage - if r.URL.Path == "/" { - http.Redirect(w, r, "/admin/#/"+s.Admin.HomePage, http.StatusFound) - return - } - - // For .map files, set correct content-type before serving - if strings.HasSuffix(r.URL.Path, ".map") { - filePath := strings.TrimPrefix(r.URL.Path, "/admin/") - file, err := s.Admin.zipReader.Open(filePath) - if err != nil { - http.NotFound(w, r) - return - } - defer file.Close() - w.Header().Set("Content-Type", "application/json") - io.Copy(w, file) - return - } - - http.ServeFileFS(w, r, s.Admin.zipReader, strings.TrimPrefix(r.URL.Path, "/admin")) - return - } - if r.URL.Path == "/favicon.ico" { - http.ServeFile(w, r, "favicon.ico") - return - } - _, _ = fmt.Fprintf(w, "visit:%s\nMonibuca Engine %s StartTime:%s\n", r.URL.Path, Version, s.StartTime) - for plugin := range s.Plugins.Range { - _, _ = fmt.Fprintf(w, "Plugin %s Version:%s\n", plugin.Meta.Name, plugin.Meta.Version) - } - for _, api := range s.apiList { - _, _ = fmt.Fprintf(w, "%s\n", api) - } -} - -// ValidateToken implements auth.TokenValidator -func (s *Server) ValidateToken(tokenString string) (*auth.JWTClaims, error) { - if !s.ServerConfig.Admin.EnableLogin { - return &auth.JWTClaims{Username: "anonymous"}, nil - } - return auth.ValidateJWT(tokenString) -} - -// Login implements the Login RPC method -func (s *Server) Login(ctx context.Context, req *pb.LoginRequest) (res *pb.LoginResponse, err error) { - res = &pb.LoginResponse{} - if !s.ServerConfig.Admin.EnableLogin { - res.Data = &pb.LoginSuccess{ - Token: "monibuca", - UserInfo: &pb.UserInfo{ - Username: "anonymous", - ExpiresAt: time.Now().Add(24 * time.Hour).Unix(), - }, - } - return - } - if s.DB == nil { - err = ErrNoDB - return - } - var user db.User - if err = s.DB.Where("username = ?", req.Username).First(&user).Error; err != nil { - return - } - - if !user.CheckPassword(req.Password) { - err = ErrInvalidCredentials - return - } - - // Generate JWT token - var tokenString string - tokenString, err = auth.GenerateToken(user.Username) - if err != nil { - return - } - - // Update last login time - s.DB.Model(&user).Update("last_login", time.Now()) - res.Data = &pb.LoginSuccess{ - Token: tokenString, - UserInfo: &pb.UserInfo{ - Username: user.Username, - ExpiresAt: time.Now().Add(24 * time.Hour).Unix(), - }, - } - return -} - -// Logout implements the Logout RPC method -func (s *Server) Logout(ctx context.Context, req *pb.LogoutRequest) (res *pb.LogoutResponse, err error) { - // In a more complex system, you might want to maintain a blacklist of logged-out tokens - // For now, we'll just return success as JWT tokens are stateless - res = &pb.LogoutResponse{Code: 0, Message: "success"} - return -} - -// GetUserInfo implements the GetUserInfo RPC method -func (s *Server) GetUserInfo(ctx context.Context, req *pb.UserInfoRequest) (res *pb.UserInfoResponse, err error) { - if !s.ServerConfig.Admin.EnableLogin { - res = &pb.UserInfoResponse{ - Code: 0, - Message: "success", - Data: &pb.UserInfo{ - Username: "anonymous", - ExpiresAt: time.Now().Add(24 * time.Hour).Unix(), - }, - } - return - } - res = &pb.UserInfoResponse{} - claims, err := s.ValidateToken(req.Token) - if err != nil { - err = ErrInvalidCredentials - return - } - - var user db.User - if err = s.DB.Where("username = ?", claims.Username).First(&user).Error; err != nil { - return - } - - // Token is valid for 24 hours from now - expiresAt := time.Now().Add(24 * time.Hour).Unix() - - return &pb.UserInfoResponse{ - Code: 0, - Message: "success", - Data: &pb.UserInfo{ - Username: user.Username, - ExpiresAt: expiresAt, - }, - }, nil -} - -// AuthInterceptor creates a new unary interceptor for authentication -func (s *Server) AuthInterceptor() grpc.UnaryServerInterceptor { - return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { - if !s.ServerConfig.Admin.EnableLogin { - return handler(ctx, req) - } - - // Skip auth for login endpoint - if info.FullMethod == "/pb.Auth/Login" { - return handler(ctx, req) - } - - md, ok := metadata.FromIncomingContext(ctx) - if !ok { - return nil, errors.New("missing metadata") - } - - authHeader := md.Get("authorization") - if len(authHeader) == 0 { - return nil, errors.New("missing authorization header") - } - - tokenString := strings.TrimPrefix(authHeader[0], "Bearer ") - claims, err := s.ValidateToken(tokenString) - if err != nil { - return nil, errors.New("invalid token") - } - - // Check if token needs refresh - shouldRefresh, err := auth.ShouldRefreshToken(tokenString) - if err == nil && shouldRefresh { - newToken, err := auth.RefreshToken(tokenString) - if err == nil { - // Add new token to response headers - header := metadata.New(map[string]string{ - "new-token": newToken, - }) - grpc.SetHeader(ctx, header) - } - } - - // Add claims to context - newCtx := context.WithValue(ctx, ctxKeyClaims, claims) - return handler(newCtx, req) - } -} - -func (s *Server) annexB(w http.ResponseWriter, r *http.Request) { - streamPath := r.PathValue("streamPath") - - if r.URL.RawQuery != "" { - streamPath += "?" + r.URL.RawQuery - } - var conf = s.config.Subscribe - conf.SubType = SubscribeTypeServer - conf.SubAudio = false - suber, err := s.SubscribeWithConfig(r.Context(), streamPath, conf) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - - var ctx util.HTTP_WS_Writer - ctx.Conn, err = suber.CheckWebSocket(w, r) - if err != nil { - return - } - ctx.WriteTimeout = s.GetCommonConf().WriteTimeout - ctx.ContentType = "application/octet-stream" - ctx.ServeHTTP(w, r) - - PlayBlock(suber, func(frame *format.RawAudio) (err error) { - return nil - }, func(frame *format.AnnexB) (err error) { - _, err = frame.WriteTo(&ctx) - if err != nil { - return - } - return ctx.Flush() - }) -} diff --git a/server_grpc.go b/server_grpc.go new file mode 100644 index 0000000..89bbf86 --- /dev/null +++ b/server_grpc.go @@ -0,0 +1,186 @@ +package m7s + +import ( + "context" + "errors" + "strings" + "time" + + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" + task "github.com/langhuihui/gotask" + "m7s.live/v5/pb" + . "m7s.live/v5/pkg" + "m7s.live/v5/pkg/auth" + "m7s.live/v5/pkg/config" + "m7s.live/v5/pkg/db" +) + +// context key type & keys +type ctxKey int + +const ( + ctxKeyClaims ctxKey = iota +) + +type GRPCServer struct { + task.Task + s *Server + tcpTask *config.ListenTCPWork +} + +func (gRPC *GRPCServer) Dispose() { + gRPC.s.Stop(gRPC.StopReason()) +} + +func (gRPC *GRPCServer) Go() (err error) { + return gRPC.s.grpcServer.Serve(gRPC.tcpTask.Listener) +} + +// ValidateToken implements auth.TokenValidator +func (s *Server) ValidateToken(tokenString string) (*auth.JWTClaims, error) { + if !s.ServerConfig.Admin.EnableLogin { + return &auth.JWTClaims{Username: "anonymous"}, nil + } + return auth.ValidateJWT(tokenString) +} + +// Login implements the Login RPC method +func (s *Server) Login(ctx context.Context, req *pb.LoginRequest) (res *pb.LoginResponse, err error) { + res = &pb.LoginResponse{} + if !s.ServerConfig.Admin.EnableLogin { + res.Data = &pb.LoginSuccess{ + Token: "monibuca", + UserInfo: &pb.UserInfo{ + Username: "anonymous", + ExpiresAt: time.Now().Add(24 * time.Hour).Unix(), + }, + } + return + } + if s.DB == nil { + err = ErrNoDB + return + } + var user db.User + if err = s.DB.Where("username = ?", req.Username).First(&user).Error; err != nil { + return + } + + if !user.CheckPassword(req.Password) { + err = ErrInvalidCredentials + return + } + + // Generate JWT token + var tokenString string + tokenString, err = auth.GenerateToken(user.Username) + if err != nil { + return + } + + // Update last login time + s.DB.Model(&user).Update("last_login", time.Now()) + res.Data = &pb.LoginSuccess{ + Token: tokenString, + UserInfo: &pb.UserInfo{ + Username: user.Username, + ExpiresAt: time.Now().Add(24 * time.Hour).Unix(), + }, + } + return +} + +// Logout implements the Logout RPC method +func (s *Server) Logout(ctx context.Context, req *pb.LogoutRequest) (res *pb.LogoutResponse, err error) { + // In a more complex system, you might want to maintain a blacklist of logged-out tokens + // For now, we'll just return success as JWT tokens are stateless + res = &pb.LogoutResponse{Code: 0, Message: "success"} + return +} + +// GetUserInfo implements the GetUserInfo RPC method +func (s *Server) GetUserInfo(ctx context.Context, req *pb.UserInfoRequest) (res *pb.UserInfoResponse, err error) { + if !s.ServerConfig.Admin.EnableLogin { + res = &pb.UserInfoResponse{ + Code: 0, + Message: "success", + Data: &pb.UserInfo{ + Username: "anonymous", + ExpiresAt: time.Now().Add(24 * time.Hour).Unix(), + }, + } + return + } + res = &pb.UserInfoResponse{} + claims, err := s.ValidateToken(req.Token) + if err != nil { + err = ErrInvalidCredentials + return + } + + var user db.User + if err = s.DB.Where("username = ?", claims.Username).First(&user).Error; err != nil { + return + } + + // Token is valid for 24 hours from now + expiresAt := time.Now().Add(24 * time.Hour).Unix() + + return &pb.UserInfoResponse{ + Code: 0, + Message: "success", + Data: &pb.UserInfo{ + Username: user.Username, + ExpiresAt: expiresAt, + }, + }, nil +} + +// AuthInterceptor creates a new unary interceptor for authentication +func (s *Server) AuthInterceptor() grpc.UnaryServerInterceptor { + return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + if !s.ServerConfig.Admin.EnableLogin { + return handler(ctx, req) + } + + // Skip auth for login endpoint + if info.FullMethod == "/pb.Auth/Login" { + return handler(ctx, req) + } + + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return nil, errors.New("missing metadata") + } + + authHeader := md.Get("authorization") + if len(authHeader) == 0 { + return nil, errors.New("missing authorization header") + } + + tokenString := strings.TrimPrefix(authHeader[0], "Bearer ") + claims, err := s.ValidateToken(tokenString) + if err != nil { + return nil, errors.New("invalid token") + } + + // Check if token needs refresh + shouldRefresh, err := auth.ShouldRefreshToken(tokenString) + if err == nil && shouldRefresh { + newToken, err := auth.RefreshToken(tokenString) + if err == nil { + // Add new token to response headers + header := metadata.New(map[string]string{ + "new-token": newToken, + }) + grpc.SetHeader(ctx, header) + } + } + + // Add claims to context + newCtx := context.WithValue(ctx, ctxKeyClaims, claims) + return handler(newCtx, req) + } +} + diff --git a/server_http.go b/server_http.go new file mode 100644 index 0000000..ad64278 --- /dev/null +++ b/server_http.go @@ -0,0 +1,186 @@ +package m7s + +import ( + "fmt" + "io" + "net/http" + "net/url" + "os" + "strings" + "time" + + "m7s.live/v5/pkg/format" + "m7s.live/v5/pkg/util" +) + +type RedirectAdvisor interface { + GetRedirectTarget(protocol, streamPath, currentHost string) (targetHost string, statusCode int, ok bool) +} + +func (s *Server) GetRedirectAdvisor() RedirectAdvisor { + if s == nil { + return nil + } + s.redirectOnce.Do(func() { + s.Plugins.Range(func(plugin *Plugin) bool { + if advisor, ok := plugin.GetHandler().(RedirectAdvisor); ok { + s.redirectAdvisor = advisor + return false + } + return true + }) + }) + return s.redirectAdvisor +} + +// RedirectIfNeeded evaluates redirect advice for HTTP-based protocols and issues redirects when appropriate. +func (s *Server) RedirectIfNeeded(w http.ResponseWriter, r *http.Request, protocol, redirectPath string) bool { + if s == nil { + return false + } + advisor := s.GetRedirectAdvisor() + if advisor == nil { + return false + } + targetHost, statusCode, ok := advisor.GetRedirectTarget(protocol, redirectPath, r.Host) + if !ok || targetHost == "" { + return false + } + if statusCode == 0 { + statusCode = http.StatusFound + } + redirectURL := buildRedirectURL(r, targetHost) + http.Redirect(w, r, redirectURL, statusCode) + return true +} + +func buildRedirectURL(r *http.Request, host string) string { + scheme := requestScheme(r) + if isWebSocketRequest(r) { + switch scheme { + case "https": + scheme = "wss" + case "http": + scheme = "ws" + } + } + target := &url.URL{ + Scheme: scheme, + Host: host, + Path: r.URL.Path, + RawQuery: r.URL.RawQuery, + } + return target.String() +} + +func requestScheme(r *http.Request) string { + if proto := r.Header.Get("X-Forwarded-Proto"); proto != "" { + return proto + } + if r.TLS != nil { + return "https" + } + return "http" +} + +func isWebSocketRequest(r *http.Request) bool { + connection := strings.ToLower(r.Header.Get("Connection")) + return strings.Contains(connection, "upgrade") && strings.EqualFold(r.Header.Get("Upgrade"), "websocket") +} + +func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // Check for location-based forwarding first + if s.Location != nil { + for pattern, target := range s.Location { + if pattern.MatchString(r.URL.Path) { + // Rewrite the URL path and handle locally + r.URL.Path = pattern.ReplaceAllString(r.URL.Path, target) + // Forward to local handler + s.config.HTTP.GetHandler(s.Logger).ServeHTTP(w, r) + return + } + } + } + + // 检查 admin.zip 是否需要重新加载 + now := time.Now() + if now.Sub(s.Admin.lastCheckTime) > checkInterval { + if info, err := os.Stat(s.Admin.FilePath); err == nil && info.ModTime() != s.Admin.zipLastModTime { + s.Info("admin.zip changed, reloading...") + s.loadAdminZip() + } + s.Admin.lastCheckTime = now + } + + if s.Admin.zipReader != nil { + // Handle root path redirect to HomePage + if r.URL.Path == "/" { + http.Redirect(w, r, "/admin/#/"+s.Admin.HomePage, http.StatusFound) + return + } + + // For .map files, set correct content-type before serving + if strings.HasSuffix(r.URL.Path, ".map") { + filePath := strings.TrimPrefix(r.URL.Path, "/admin/") + file, err := s.Admin.zipReader.Open(filePath) + if err != nil { + http.NotFound(w, r) + return + } + defer file.Close() + w.Header().Set("Content-Type", "application/json") + io.Copy(w, file) + return + } + + http.ServeFileFS(w, r, s.Admin.zipReader, strings.TrimPrefix(r.URL.Path, "/admin")) + return + } + if r.URL.Path == "/favicon.ico" { + http.ServeFile(w, r, "favicon.ico") + return + } + _, _ = fmt.Fprintf(w, "visit:%s\nMonibuca Engine %s StartTime:%s\n", r.URL.Path, Version, s.StartTime) + for plugin := range s.Plugins.Range { + _, _ = fmt.Fprintf(w, "Plugin %s Version:%s\n", plugin.Meta.Name, plugin.Meta.Version) + } + for _, api := range s.apiList { + _, _ = fmt.Fprintf(w, "%s\n", api) + } +} + +func (s *Server) annexB(w http.ResponseWriter, r *http.Request) { + streamPath := r.PathValue("streamPath") + + if r.URL.RawQuery != "" { + streamPath += "?" + r.URL.RawQuery + } + var conf = s.config.Subscribe + conf.SubType = SubscribeTypeServer + conf.SubAudio = false + suber, err := s.SubscribeWithConfig(r.Context(), streamPath, conf) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + var ctx util.HTTP_WS_Writer + ctx.Conn, err = suber.CheckWebSocket(w, r) + if err != nil { + return + } + ctx.WriteTimeout = s.GetCommonConf().WriteTimeout + ctx.ContentType = "application/octet-stream" + ctx.ServeHTTP(w, r) + + PlayBlock(suber, func(frame *format.RawAudio) (err error) { + return nil + }, func(frame *format.AnnexB) (err error) { + _, err = frame.WriteTo(&ctx) + if err != nil { + return + } + return ctx.Flush() + }) +} +