opt: idr ring lock-free

This commit is contained in:
langhuihui
2025-11-19 11:04:43 +08:00
parent ac430ffd0d
commit 1f06faf0f0
6 changed files with 459 additions and 362 deletions

View File

@@ -2,18 +2,83 @@ package pkg
import (
"log/slog"
"sync"
"sync/atomic"
"time"
"github.com/langhuihui/gotask"
task "github.com/langhuihui/gotask"
"m7s.live/v5/pkg/util"
)
type IDRNode struct {
Value *util.Ring[AVFrame]
next atomic.Pointer[IDRNode]
prev atomic.Pointer[IDRNode]
}
func (n *IDRNode) Next() *IDRNode {
return n.next.Load()
}
func (n *IDRNode) Prev() *IDRNode {
return n.prev.Load()
}
type IDRList struct {
head atomic.Pointer[IDRNode]
tail atomic.Pointer[IDRNode]
len atomic.Int32
}
func (l *IDRList) Init() {
l.head.Store(nil)
l.tail.Store(nil)
l.len.Store(0)
}
func (l *IDRList) Len() int {
return int(l.len.Load())
}
func (l *IDRList) Front() *IDRNode {
return l.head.Load()
}
func (l *IDRList) Back() *IDRNode {
return l.tail.Load()
}
func (l *IDRList) PushBack(v *util.Ring[AVFrame]) {
node := &IDRNode{Value: v}
l.len.Add(1)
tail := l.tail.Load()
node.prev.Store(tail)
if tail != nil {
tail.next.Store(node)
}
l.tail.Store(node)
if l.head.Load() == nil {
l.head.Store(node)
}
}
func (l *IDRList) Remove(node *IDRNode) {
head := l.head.Load()
if head != node {
return
}
next := head.next.Load()
l.head.Store(next)
if next != nil {
next.prev.Store(nil)
} else {
l.tail.Store(nil)
}
l.len.Add(-1)
}
type RingWriter struct {
*util.Ring[AVFrame]
sync.RWMutex
IDRingList util.List[*util.Ring[AVFrame]] // 关键帧链表
IDRingList IDRList // 关键帧链表
BufferRange util.Range[time.Duration]
SizeRange util.Range[int]
pool *util.Ring[AVFrame]
@@ -98,8 +163,6 @@ func (rb *RingWriter) Dispose() {
}
func (rb *RingWriter) GetIDR() *util.Ring[AVFrame] {
rb.RLock()
defer rb.RUnlock()
if latest := rb.IDRingList.Back(); latest != nil {
return latest.Value
}
@@ -107,8 +170,6 @@ func (rb *RingWriter) GetIDR() *util.Ring[AVFrame] {
}
func (rb *RingWriter) GetOldestIDR() *util.Ring[AVFrame] {
rb.RLock()
defer rb.RUnlock()
if latest := rb.IDRingList.Front(); latest != nil {
return latest.Value
}
@@ -116,8 +177,6 @@ func (rb *RingWriter) GetOldestIDR() *util.Ring[AVFrame] {
}
func (rb *RingWriter) GetHistoryIDR(bufTime time.Duration) *util.Ring[AVFrame] {
rb.RLock()
defer rb.RUnlock()
for item := rb.IDRingList.Back(); item != nil; item = item.Prev() {
if rb.LastValue.Timestamp-item.Value.Value.Timestamp >= bufTime {
return item.Value
@@ -135,9 +194,7 @@ func (rb *RingWriter) CurrentBufferTime() time.Duration {
}
func (rb *RingWriter) PushIDR() {
rb.Lock()
rb.IDRingList.PushBack(rb.Ring)
rb.Unlock()
}
func (rb *RingWriter) Step() (normal bool) {
@@ -159,9 +216,7 @@ func (rb *RingWriter) Step() (normal bool) {
} else if next == oldIDR.Value {
if nextOld := oldIDR.Next(); nextOld != nil && rb.durationFrom(nextOld.Value) > rb.BufferRange[0] {
rb.SLogger.Log(nil, task.TraceLevel, "remove old idr")
rb.Lock()
rb.IDRingList.Remove(oldIDR)
rb.Unlock()
} else {
rb.glow(5, "not enough buffer")
next = rb.Next()

View File

@@ -450,7 +450,6 @@ func (p *SnapPlugin) calculateSnapTimes(publisher *m7s.Publisher, startTime, end
}
// 遍历IDRingList获取关键帧
videoTrack.RLock()
if videoTrack.IDRingList.Len() > 0 {
// 从头开始遍历所有关键帧
for idrElem := videoTrack.IDRingList.Front(); idrElem != nil; idrElem = idrElem.Next() {
@@ -467,7 +466,6 @@ func (p *SnapPlugin) calculateSnapTimes(publisher *m7s.Publisher, startTime, end
}
}
}
videoTrack.RUnlock()
// 如果没有找到关键帧但有GOP信息则使用估算的GOP间隔生成时间点
if len(snapTimes) == 0 && gopDuration > 0 {

View File

@@ -5,7 +5,6 @@ import (
"errors"
"fmt"
"reflect"
"slices"
"sync"
"time"
@@ -112,6 +111,7 @@ type Publisher struct {
OnGetPosition func() time.Time
PullProxyConfig *PullProxyConfig
dropAfterTs time.Duration
bufferTimeCounts map[time.Duration]int
}
type PublishParam struct {
@@ -134,6 +134,7 @@ func (p *Publisher) Start() (err error) {
if p.MaxCount > 0 && s.Streams.Length >= p.MaxCount {
return ErrPublishMaxCount
}
p.bufferTimeCounts = make(map[time.Duration]int)
p.Info("publish")
p.processPullProxyOnStart()
p.audioReady = util.NewPromiseWithTimeout(p, p.PublishTimeout)
@@ -200,14 +201,25 @@ func (p *Publisher) Go() error {
func (p *Publisher) RemoveSubscriber(subscriber *Subscriber) {
p.Subscribers.Remove(subscriber)
p.bufferTimeCounts[subscriber.BufferTime]--
if p.bufferTimeCounts[subscriber.BufferTime] == 0 {
delete(p.bufferTimeCounts, subscriber.BufferTime)
}
p.Info("subscriber -1", "count", p.Subscribers.Length)
if p.Plugin == nil {
return
}
if subscriber.BufferTime == p.BufferTime && p.Subscribers.Length > 0 {
p.BufferTime = slices.MaxFunc(p.Subscribers.Items, func(a, b *Subscriber) int {
return int(a.BufferTime - b.BufferTime)
}).BufferTime
if p.Subscribers.Length > 0 {
var maxBuf time.Duration
for k := range p.bufferTimeCounts {
if k > maxBuf {
maxBuf = k
}
}
if defaultBuf := p.Plugin.GetCommonConf().Publish.BufferTime; maxBuf < defaultBuf {
maxBuf = defaultBuf
}
p.BufferTime = maxBuf
} else {
p.BufferTime = p.Plugin.GetCommonConf().Publish.BufferTime
}
@@ -235,6 +247,7 @@ func (p *Publisher) AddSubscriber(subscriber *Subscriber) {
}
subscriber.waitStartTime = time.Time{}
if p.Subscribers.AddUnique(subscriber) {
p.bufferTimeCounts[subscriber.BufferTime]++
p.Info("subscriber +1", "count", p.Subscribers.Length)
if subscriber.BufferTime > p.BufferTime {
p.BufferTime = subscriber.BufferTime

341
server.go
View File

@@ -5,7 +5,6 @@ import (
"context"
"errors"
"fmt"
"io"
"log/slog"
"net/http"
"net/url"
@@ -32,13 +31,10 @@ import (
"github.com/prometheus/client_golang/prometheus/promhttp"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/metadata"
"gorm.io/gorm"
"m7s.live/v5/pb"
. "m7s.live/v5/pkg"
"m7s.live/v5/pkg/auth"
"m7s.live/v5/pkg/db"
"m7s.live/v5/pkg/format"
"m7s.live/v5/pkg/util"
)
@@ -58,9 +54,6 @@ var (
)
type (
RedirectAdvisor interface {
GetRedirectTarget(protocol, streamPath, currentHost string) (targetHost string, statusCode int, ok bool)
}
ServerConfig struct {
FatalDir string `default:"fatal" desc:""`
PulseInterval time.Duration `default:"5s" desc:"心跳事件间隔"` //心跳事件间隔
@@ -134,21 +127,9 @@ type (
task.TickTask
s *Server
}
GRPCServer struct {
task.Task
s *Server
tcpTask *config.ListenTCPWork
}
RawConfig = map[string]map[string]any
)
// context key type & keys
type ctxKey int
const (
ctxKeyClaims ctxKey = iota
)
func (w *WaitStream) GetKey() string {
return w.StreamPath
}
@@ -172,77 +153,6 @@ func NewServer(conf any) (s *Server) {
return
}
func (s *Server) GetRedirectAdvisor() RedirectAdvisor {
if s == nil {
return nil
}
s.redirectOnce.Do(func() {
s.Plugins.Range(func(plugin *Plugin) bool {
if advisor, ok := plugin.GetHandler().(RedirectAdvisor); ok {
s.redirectAdvisor = advisor
return false
}
return true
})
})
return s.redirectAdvisor
}
// RedirectIfNeeded evaluates redirect advice for HTTP-based protocols and issues redirects when appropriate.
func (s *Server) RedirectIfNeeded(w http.ResponseWriter, r *http.Request, protocol, redirectPath string) bool {
if s == nil {
return false
}
advisor := s.GetRedirectAdvisor()
if advisor == nil {
return false
}
targetHost, statusCode, ok := advisor.GetRedirectTarget(protocol, redirectPath, r.Host)
if !ok || targetHost == "" {
return false
}
if statusCode == 0 {
statusCode = http.StatusFound
}
redirectURL := buildRedirectURL(r, targetHost)
http.Redirect(w, r, redirectURL, statusCode)
return true
}
func buildRedirectURL(r *http.Request, host string) string {
scheme := requestScheme(r)
if isWebSocketRequest(r) {
switch scheme {
case "https":
scheme = "wss"
case "http":
scheme = "ws"
}
}
target := &url.URL{
Scheme: scheme,
Host: host,
Path: r.URL.Path,
RawQuery: r.URL.RawQuery,
}
return target.String()
}
func requestScheme(r *http.Request) string {
if proto := r.Header.Get("X-Forwarded-Proto"); proto != "" {
return proto
}
if r.TLS != nil {
return "https"
}
return "http"
}
func isWebSocketRequest(r *http.Request) bool {
connection := strings.ToLower(r.Header.Get("Connection"))
return strings.Contains(connection, "upgrade") && strings.EqualFold(r.Header.Get("Upgrade"), "websocket")
}
func Run(ctx context.Context, conf any) (err error) {
for err = ErrRestart; errors.Is(err, ErrRestart); err = Servers.AddTask(NewServer(conf), ctx).WaitStopped() {
}
@@ -664,14 +574,6 @@ func (c *CheckSubWaitTimeout) Tick(any) {
c.s.Waiting.checkTimeout()
}
func (gRPC *GRPCServer) Dispose() {
gRPC.s.Stop(gRPC.StopReason())
}
func (gRPC *GRPCServer) Go() (err error) {
return gRPC.s.grpcServer.Serve(gRPC.tcpTask.Listener)
}
func (s *Server) CallOnStreamTask(callback func()) {
s.Streams.Call(callback)
}
@@ -723,246 +625,3 @@ func (s *Server) OnSubscribe(streamPath string, args url.Values) {
}
}
}
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Check for location-based forwarding first
if s.Location != nil {
for pattern, target := range s.Location {
if pattern.MatchString(r.URL.Path) {
// Rewrite the URL path and handle locally
r.URL.Path = pattern.ReplaceAllString(r.URL.Path, target)
// Forward to local handler
s.config.HTTP.GetHandler(s.Logger).ServeHTTP(w, r)
return
}
}
}
// 检查 admin.zip 是否需要重新加载
now := time.Now()
if now.Sub(s.Admin.lastCheckTime) > checkInterval {
if info, err := os.Stat(s.Admin.FilePath); err == nil && info.ModTime() != s.Admin.zipLastModTime {
s.Info("admin.zip changed, reloading...")
s.loadAdminZip()
}
s.Admin.lastCheckTime = now
}
if s.Admin.zipReader != nil {
// Handle root path redirect to HomePage
if r.URL.Path == "/" {
http.Redirect(w, r, "/admin/#/"+s.Admin.HomePage, http.StatusFound)
return
}
// For .map files, set correct content-type before serving
if strings.HasSuffix(r.URL.Path, ".map") {
filePath := strings.TrimPrefix(r.URL.Path, "/admin/")
file, err := s.Admin.zipReader.Open(filePath)
if err != nil {
http.NotFound(w, r)
return
}
defer file.Close()
w.Header().Set("Content-Type", "application/json")
io.Copy(w, file)
return
}
http.ServeFileFS(w, r, s.Admin.zipReader, strings.TrimPrefix(r.URL.Path, "/admin"))
return
}
if r.URL.Path == "/favicon.ico" {
http.ServeFile(w, r, "favicon.ico")
return
}
_, _ = fmt.Fprintf(w, "visit:%s\nMonibuca Engine %s StartTime:%s\n", r.URL.Path, Version, s.StartTime)
for plugin := range s.Plugins.Range {
_, _ = fmt.Fprintf(w, "Plugin %s Version:%s\n", plugin.Meta.Name, plugin.Meta.Version)
}
for _, api := range s.apiList {
_, _ = fmt.Fprintf(w, "%s\n", api)
}
}
// ValidateToken implements auth.TokenValidator
func (s *Server) ValidateToken(tokenString string) (*auth.JWTClaims, error) {
if !s.ServerConfig.Admin.EnableLogin {
return &auth.JWTClaims{Username: "anonymous"}, nil
}
return auth.ValidateJWT(tokenString)
}
// Login implements the Login RPC method
func (s *Server) Login(ctx context.Context, req *pb.LoginRequest) (res *pb.LoginResponse, err error) {
res = &pb.LoginResponse{}
if !s.ServerConfig.Admin.EnableLogin {
res.Data = &pb.LoginSuccess{
Token: "monibuca",
UserInfo: &pb.UserInfo{
Username: "anonymous",
ExpiresAt: time.Now().Add(24 * time.Hour).Unix(),
},
}
return
}
if s.DB == nil {
err = ErrNoDB
return
}
var user db.User
if err = s.DB.Where("username = ?", req.Username).First(&user).Error; err != nil {
return
}
if !user.CheckPassword(req.Password) {
err = ErrInvalidCredentials
return
}
// Generate JWT token
var tokenString string
tokenString, err = auth.GenerateToken(user.Username)
if err != nil {
return
}
// Update last login time
s.DB.Model(&user).Update("last_login", time.Now())
res.Data = &pb.LoginSuccess{
Token: tokenString,
UserInfo: &pb.UserInfo{
Username: user.Username,
ExpiresAt: time.Now().Add(24 * time.Hour).Unix(),
},
}
return
}
// Logout implements the Logout RPC method
func (s *Server) Logout(ctx context.Context, req *pb.LogoutRequest) (res *pb.LogoutResponse, err error) {
// In a more complex system, you might want to maintain a blacklist of logged-out tokens
// For now, we'll just return success as JWT tokens are stateless
res = &pb.LogoutResponse{Code: 0, Message: "success"}
return
}
// GetUserInfo implements the GetUserInfo RPC method
func (s *Server) GetUserInfo(ctx context.Context, req *pb.UserInfoRequest) (res *pb.UserInfoResponse, err error) {
if !s.ServerConfig.Admin.EnableLogin {
res = &pb.UserInfoResponse{
Code: 0,
Message: "success",
Data: &pb.UserInfo{
Username: "anonymous",
ExpiresAt: time.Now().Add(24 * time.Hour).Unix(),
},
}
return
}
res = &pb.UserInfoResponse{}
claims, err := s.ValidateToken(req.Token)
if err != nil {
err = ErrInvalidCredentials
return
}
var user db.User
if err = s.DB.Where("username = ?", claims.Username).First(&user).Error; err != nil {
return
}
// Token is valid for 24 hours from now
expiresAt := time.Now().Add(24 * time.Hour).Unix()
return &pb.UserInfoResponse{
Code: 0,
Message: "success",
Data: &pb.UserInfo{
Username: user.Username,
ExpiresAt: expiresAt,
},
}, nil
}
// AuthInterceptor creates a new unary interceptor for authentication
func (s *Server) AuthInterceptor() grpc.UnaryServerInterceptor {
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
if !s.ServerConfig.Admin.EnableLogin {
return handler(ctx, req)
}
// Skip auth for login endpoint
if info.FullMethod == "/pb.Auth/Login" {
return handler(ctx, req)
}
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return nil, errors.New("missing metadata")
}
authHeader := md.Get("authorization")
if len(authHeader) == 0 {
return nil, errors.New("missing authorization header")
}
tokenString := strings.TrimPrefix(authHeader[0], "Bearer ")
claims, err := s.ValidateToken(tokenString)
if err != nil {
return nil, errors.New("invalid token")
}
// Check if token needs refresh
shouldRefresh, err := auth.ShouldRefreshToken(tokenString)
if err == nil && shouldRefresh {
newToken, err := auth.RefreshToken(tokenString)
if err == nil {
// Add new token to response headers
header := metadata.New(map[string]string{
"new-token": newToken,
})
grpc.SetHeader(ctx, header)
}
}
// Add claims to context
newCtx := context.WithValue(ctx, ctxKeyClaims, claims)
return handler(newCtx, req)
}
}
func (s *Server) annexB(w http.ResponseWriter, r *http.Request) {
streamPath := r.PathValue("streamPath")
if r.URL.RawQuery != "" {
streamPath += "?" + r.URL.RawQuery
}
var conf = s.config.Subscribe
conf.SubType = SubscribeTypeServer
conf.SubAudio = false
suber, err := s.SubscribeWithConfig(r.Context(), streamPath, conf)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
var ctx util.HTTP_WS_Writer
ctx.Conn, err = suber.CheckWebSocket(w, r)
if err != nil {
return
}
ctx.WriteTimeout = s.GetCommonConf().WriteTimeout
ctx.ContentType = "application/octet-stream"
ctx.ServeHTTP(w, r)
PlayBlock(suber, func(frame *format.RawAudio) (err error) {
return nil
}, func(frame *format.AnnexB) (err error) {
_, err = frame.WriteTo(&ctx)
if err != nil {
return
}
return ctx.Flush()
})
}

186
server_grpc.go Normal file
View File

@@ -0,0 +1,186 @@
package m7s
import (
"context"
"errors"
"strings"
"time"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
task "github.com/langhuihui/gotask"
"m7s.live/v5/pb"
. "m7s.live/v5/pkg"
"m7s.live/v5/pkg/auth"
"m7s.live/v5/pkg/config"
"m7s.live/v5/pkg/db"
)
// context key type & keys
type ctxKey int
const (
ctxKeyClaims ctxKey = iota
)
type GRPCServer struct {
task.Task
s *Server
tcpTask *config.ListenTCPWork
}
func (gRPC *GRPCServer) Dispose() {
gRPC.s.Stop(gRPC.StopReason())
}
func (gRPC *GRPCServer) Go() (err error) {
return gRPC.s.grpcServer.Serve(gRPC.tcpTask.Listener)
}
// ValidateToken implements auth.TokenValidator
func (s *Server) ValidateToken(tokenString string) (*auth.JWTClaims, error) {
if !s.ServerConfig.Admin.EnableLogin {
return &auth.JWTClaims{Username: "anonymous"}, nil
}
return auth.ValidateJWT(tokenString)
}
// Login implements the Login RPC method
func (s *Server) Login(ctx context.Context, req *pb.LoginRequest) (res *pb.LoginResponse, err error) {
res = &pb.LoginResponse{}
if !s.ServerConfig.Admin.EnableLogin {
res.Data = &pb.LoginSuccess{
Token: "monibuca",
UserInfo: &pb.UserInfo{
Username: "anonymous",
ExpiresAt: time.Now().Add(24 * time.Hour).Unix(),
},
}
return
}
if s.DB == nil {
err = ErrNoDB
return
}
var user db.User
if err = s.DB.Where("username = ?", req.Username).First(&user).Error; err != nil {
return
}
if !user.CheckPassword(req.Password) {
err = ErrInvalidCredentials
return
}
// Generate JWT token
var tokenString string
tokenString, err = auth.GenerateToken(user.Username)
if err != nil {
return
}
// Update last login time
s.DB.Model(&user).Update("last_login", time.Now())
res.Data = &pb.LoginSuccess{
Token: tokenString,
UserInfo: &pb.UserInfo{
Username: user.Username,
ExpiresAt: time.Now().Add(24 * time.Hour).Unix(),
},
}
return
}
// Logout implements the Logout RPC method
func (s *Server) Logout(ctx context.Context, req *pb.LogoutRequest) (res *pb.LogoutResponse, err error) {
// In a more complex system, you might want to maintain a blacklist of logged-out tokens
// For now, we'll just return success as JWT tokens are stateless
res = &pb.LogoutResponse{Code: 0, Message: "success"}
return
}
// GetUserInfo implements the GetUserInfo RPC method
func (s *Server) GetUserInfo(ctx context.Context, req *pb.UserInfoRequest) (res *pb.UserInfoResponse, err error) {
if !s.ServerConfig.Admin.EnableLogin {
res = &pb.UserInfoResponse{
Code: 0,
Message: "success",
Data: &pb.UserInfo{
Username: "anonymous",
ExpiresAt: time.Now().Add(24 * time.Hour).Unix(),
},
}
return
}
res = &pb.UserInfoResponse{}
claims, err := s.ValidateToken(req.Token)
if err != nil {
err = ErrInvalidCredentials
return
}
var user db.User
if err = s.DB.Where("username = ?", claims.Username).First(&user).Error; err != nil {
return
}
// Token is valid for 24 hours from now
expiresAt := time.Now().Add(24 * time.Hour).Unix()
return &pb.UserInfoResponse{
Code: 0,
Message: "success",
Data: &pb.UserInfo{
Username: user.Username,
ExpiresAt: expiresAt,
},
}, nil
}
// AuthInterceptor creates a new unary interceptor for authentication
func (s *Server) AuthInterceptor() grpc.UnaryServerInterceptor {
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
if !s.ServerConfig.Admin.EnableLogin {
return handler(ctx, req)
}
// Skip auth for login endpoint
if info.FullMethod == "/pb.Auth/Login" {
return handler(ctx, req)
}
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return nil, errors.New("missing metadata")
}
authHeader := md.Get("authorization")
if len(authHeader) == 0 {
return nil, errors.New("missing authorization header")
}
tokenString := strings.TrimPrefix(authHeader[0], "Bearer ")
claims, err := s.ValidateToken(tokenString)
if err != nil {
return nil, errors.New("invalid token")
}
// Check if token needs refresh
shouldRefresh, err := auth.ShouldRefreshToken(tokenString)
if err == nil && shouldRefresh {
newToken, err := auth.RefreshToken(tokenString)
if err == nil {
// Add new token to response headers
header := metadata.New(map[string]string{
"new-token": newToken,
})
grpc.SetHeader(ctx, header)
}
}
// Add claims to context
newCtx := context.WithValue(ctx, ctxKeyClaims, claims)
return handler(newCtx, req)
}
}

186
server_http.go Normal file
View File

@@ -0,0 +1,186 @@
package m7s
import (
"fmt"
"io"
"net/http"
"net/url"
"os"
"strings"
"time"
"m7s.live/v5/pkg/format"
"m7s.live/v5/pkg/util"
)
type RedirectAdvisor interface {
GetRedirectTarget(protocol, streamPath, currentHost string) (targetHost string, statusCode int, ok bool)
}
func (s *Server) GetRedirectAdvisor() RedirectAdvisor {
if s == nil {
return nil
}
s.redirectOnce.Do(func() {
s.Plugins.Range(func(plugin *Plugin) bool {
if advisor, ok := plugin.GetHandler().(RedirectAdvisor); ok {
s.redirectAdvisor = advisor
return false
}
return true
})
})
return s.redirectAdvisor
}
// RedirectIfNeeded evaluates redirect advice for HTTP-based protocols and issues redirects when appropriate.
func (s *Server) RedirectIfNeeded(w http.ResponseWriter, r *http.Request, protocol, redirectPath string) bool {
if s == nil {
return false
}
advisor := s.GetRedirectAdvisor()
if advisor == nil {
return false
}
targetHost, statusCode, ok := advisor.GetRedirectTarget(protocol, redirectPath, r.Host)
if !ok || targetHost == "" {
return false
}
if statusCode == 0 {
statusCode = http.StatusFound
}
redirectURL := buildRedirectURL(r, targetHost)
http.Redirect(w, r, redirectURL, statusCode)
return true
}
func buildRedirectURL(r *http.Request, host string) string {
scheme := requestScheme(r)
if isWebSocketRequest(r) {
switch scheme {
case "https":
scheme = "wss"
case "http":
scheme = "ws"
}
}
target := &url.URL{
Scheme: scheme,
Host: host,
Path: r.URL.Path,
RawQuery: r.URL.RawQuery,
}
return target.String()
}
func requestScheme(r *http.Request) string {
if proto := r.Header.Get("X-Forwarded-Proto"); proto != "" {
return proto
}
if r.TLS != nil {
return "https"
}
return "http"
}
func isWebSocketRequest(r *http.Request) bool {
connection := strings.ToLower(r.Header.Get("Connection"))
return strings.Contains(connection, "upgrade") && strings.EqualFold(r.Header.Get("Upgrade"), "websocket")
}
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Check for location-based forwarding first
if s.Location != nil {
for pattern, target := range s.Location {
if pattern.MatchString(r.URL.Path) {
// Rewrite the URL path and handle locally
r.URL.Path = pattern.ReplaceAllString(r.URL.Path, target)
// Forward to local handler
s.config.HTTP.GetHandler(s.Logger).ServeHTTP(w, r)
return
}
}
}
// 检查 admin.zip 是否需要重新加载
now := time.Now()
if now.Sub(s.Admin.lastCheckTime) > checkInterval {
if info, err := os.Stat(s.Admin.FilePath); err == nil && info.ModTime() != s.Admin.zipLastModTime {
s.Info("admin.zip changed, reloading...")
s.loadAdminZip()
}
s.Admin.lastCheckTime = now
}
if s.Admin.zipReader != nil {
// Handle root path redirect to HomePage
if r.URL.Path == "/" {
http.Redirect(w, r, "/admin/#/"+s.Admin.HomePage, http.StatusFound)
return
}
// For .map files, set correct content-type before serving
if strings.HasSuffix(r.URL.Path, ".map") {
filePath := strings.TrimPrefix(r.URL.Path, "/admin/")
file, err := s.Admin.zipReader.Open(filePath)
if err != nil {
http.NotFound(w, r)
return
}
defer file.Close()
w.Header().Set("Content-Type", "application/json")
io.Copy(w, file)
return
}
http.ServeFileFS(w, r, s.Admin.zipReader, strings.TrimPrefix(r.URL.Path, "/admin"))
return
}
if r.URL.Path == "/favicon.ico" {
http.ServeFile(w, r, "favicon.ico")
return
}
_, _ = fmt.Fprintf(w, "visit:%s\nMonibuca Engine %s StartTime:%s\n", r.URL.Path, Version, s.StartTime)
for plugin := range s.Plugins.Range {
_, _ = fmt.Fprintf(w, "Plugin %s Version:%s\n", plugin.Meta.Name, plugin.Meta.Version)
}
for _, api := range s.apiList {
_, _ = fmt.Fprintf(w, "%s\n", api)
}
}
func (s *Server) annexB(w http.ResponseWriter, r *http.Request) {
streamPath := r.PathValue("streamPath")
if r.URL.RawQuery != "" {
streamPath += "?" + r.URL.RawQuery
}
var conf = s.config.Subscribe
conf.SubType = SubscribeTypeServer
conf.SubAudio = false
suber, err := s.SubscribeWithConfig(r.Context(), streamPath, conf)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
var ctx util.HTTP_WS_Writer
ctx.Conn, err = suber.CheckWebSocket(w, r)
if err != nil {
return
}
ctx.WriteTimeout = s.GetCommonConf().WriteTimeout
ctx.ContentType = "application/octet-stream"
ctx.ServeHTTP(w, r)
PlayBlock(suber, func(frame *format.RawAudio) (err error) {
return nil
}, func(frame *format.AnnexB) (err error) {
_, err = frame.WriteTo(&ctx)
if err != nil {
return
}
return ctx.Flush()
})
}