mirror of
https://github.com/langhuihui/monibuca.git
synced 2025-12-24 13:48:04 +08:00
opt: idr ring lock-free
This commit is contained in:
@@ -2,18 +2,83 @@ package pkg
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/langhuihui/gotask"
|
||||
task "github.com/langhuihui/gotask"
|
||||
"m7s.live/v5/pkg/util"
|
||||
)
|
||||
|
||||
type IDRNode struct {
|
||||
Value *util.Ring[AVFrame]
|
||||
next atomic.Pointer[IDRNode]
|
||||
prev atomic.Pointer[IDRNode]
|
||||
}
|
||||
|
||||
func (n *IDRNode) Next() *IDRNode {
|
||||
return n.next.Load()
|
||||
}
|
||||
|
||||
func (n *IDRNode) Prev() *IDRNode {
|
||||
return n.prev.Load()
|
||||
}
|
||||
|
||||
type IDRList struct {
|
||||
head atomic.Pointer[IDRNode]
|
||||
tail atomic.Pointer[IDRNode]
|
||||
len atomic.Int32
|
||||
}
|
||||
|
||||
func (l *IDRList) Init() {
|
||||
l.head.Store(nil)
|
||||
l.tail.Store(nil)
|
||||
l.len.Store(0)
|
||||
}
|
||||
|
||||
func (l *IDRList) Len() int {
|
||||
return int(l.len.Load())
|
||||
}
|
||||
|
||||
func (l *IDRList) Front() *IDRNode {
|
||||
return l.head.Load()
|
||||
}
|
||||
|
||||
func (l *IDRList) Back() *IDRNode {
|
||||
return l.tail.Load()
|
||||
}
|
||||
|
||||
func (l *IDRList) PushBack(v *util.Ring[AVFrame]) {
|
||||
node := &IDRNode{Value: v}
|
||||
l.len.Add(1)
|
||||
tail := l.tail.Load()
|
||||
node.prev.Store(tail)
|
||||
if tail != nil {
|
||||
tail.next.Store(node)
|
||||
}
|
||||
l.tail.Store(node)
|
||||
if l.head.Load() == nil {
|
||||
l.head.Store(node)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *IDRList) Remove(node *IDRNode) {
|
||||
head := l.head.Load()
|
||||
if head != node {
|
||||
return
|
||||
}
|
||||
next := head.next.Load()
|
||||
l.head.Store(next)
|
||||
if next != nil {
|
||||
next.prev.Store(nil)
|
||||
} else {
|
||||
l.tail.Store(nil)
|
||||
}
|
||||
l.len.Add(-1)
|
||||
}
|
||||
|
||||
type RingWriter struct {
|
||||
*util.Ring[AVFrame]
|
||||
sync.RWMutex
|
||||
IDRingList util.List[*util.Ring[AVFrame]] // 关键帧链表
|
||||
IDRingList IDRList // 关键帧链表
|
||||
BufferRange util.Range[time.Duration]
|
||||
SizeRange util.Range[int]
|
||||
pool *util.Ring[AVFrame]
|
||||
@@ -98,8 +163,6 @@ func (rb *RingWriter) Dispose() {
|
||||
}
|
||||
|
||||
func (rb *RingWriter) GetIDR() *util.Ring[AVFrame] {
|
||||
rb.RLock()
|
||||
defer rb.RUnlock()
|
||||
if latest := rb.IDRingList.Back(); latest != nil {
|
||||
return latest.Value
|
||||
}
|
||||
@@ -107,8 +170,6 @@ func (rb *RingWriter) GetIDR() *util.Ring[AVFrame] {
|
||||
}
|
||||
|
||||
func (rb *RingWriter) GetOldestIDR() *util.Ring[AVFrame] {
|
||||
rb.RLock()
|
||||
defer rb.RUnlock()
|
||||
if latest := rb.IDRingList.Front(); latest != nil {
|
||||
return latest.Value
|
||||
}
|
||||
@@ -116,8 +177,6 @@ func (rb *RingWriter) GetOldestIDR() *util.Ring[AVFrame] {
|
||||
}
|
||||
|
||||
func (rb *RingWriter) GetHistoryIDR(bufTime time.Duration) *util.Ring[AVFrame] {
|
||||
rb.RLock()
|
||||
defer rb.RUnlock()
|
||||
for item := rb.IDRingList.Back(); item != nil; item = item.Prev() {
|
||||
if rb.LastValue.Timestamp-item.Value.Value.Timestamp >= bufTime {
|
||||
return item.Value
|
||||
@@ -135,9 +194,7 @@ func (rb *RingWriter) CurrentBufferTime() time.Duration {
|
||||
}
|
||||
|
||||
func (rb *RingWriter) PushIDR() {
|
||||
rb.Lock()
|
||||
rb.IDRingList.PushBack(rb.Ring)
|
||||
rb.Unlock()
|
||||
}
|
||||
|
||||
func (rb *RingWriter) Step() (normal bool) {
|
||||
@@ -159,9 +216,7 @@ func (rb *RingWriter) Step() (normal bool) {
|
||||
} else if next == oldIDR.Value {
|
||||
if nextOld := oldIDR.Next(); nextOld != nil && rb.durationFrom(nextOld.Value) > rb.BufferRange[0] {
|
||||
rb.SLogger.Log(nil, task.TraceLevel, "remove old idr")
|
||||
rb.Lock()
|
||||
rb.IDRingList.Remove(oldIDR)
|
||||
rb.Unlock()
|
||||
} else {
|
||||
rb.glow(5, "not enough buffer")
|
||||
next = rb.Next()
|
||||
|
||||
@@ -450,7 +450,6 @@ func (p *SnapPlugin) calculateSnapTimes(publisher *m7s.Publisher, startTime, end
|
||||
}
|
||||
|
||||
// 遍历IDRingList获取关键帧
|
||||
videoTrack.RLock()
|
||||
if videoTrack.IDRingList.Len() > 0 {
|
||||
// 从头开始遍历所有关键帧
|
||||
for idrElem := videoTrack.IDRingList.Front(); idrElem != nil; idrElem = idrElem.Next() {
|
||||
@@ -467,7 +466,6 @@ func (p *SnapPlugin) calculateSnapTimes(publisher *m7s.Publisher, startTime, end
|
||||
}
|
||||
}
|
||||
}
|
||||
videoTrack.RUnlock()
|
||||
|
||||
// 如果没有找到关键帧,但有GOP信息,则使用估算的GOP间隔生成时间点
|
||||
if len(snapTimes) == 0 && gopDuration > 0 {
|
||||
|
||||
23
publisher.go
23
publisher.go
@@ -5,7 +5,6 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -112,6 +111,7 @@ type Publisher struct {
|
||||
OnGetPosition func() time.Time
|
||||
PullProxyConfig *PullProxyConfig
|
||||
dropAfterTs time.Duration
|
||||
bufferTimeCounts map[time.Duration]int
|
||||
}
|
||||
|
||||
type PublishParam struct {
|
||||
@@ -134,6 +134,7 @@ func (p *Publisher) Start() (err error) {
|
||||
if p.MaxCount > 0 && s.Streams.Length >= p.MaxCount {
|
||||
return ErrPublishMaxCount
|
||||
}
|
||||
p.bufferTimeCounts = make(map[time.Duration]int)
|
||||
p.Info("publish")
|
||||
p.processPullProxyOnStart()
|
||||
p.audioReady = util.NewPromiseWithTimeout(p, p.PublishTimeout)
|
||||
@@ -200,14 +201,25 @@ func (p *Publisher) Go() error {
|
||||
|
||||
func (p *Publisher) RemoveSubscriber(subscriber *Subscriber) {
|
||||
p.Subscribers.Remove(subscriber)
|
||||
p.bufferTimeCounts[subscriber.BufferTime]--
|
||||
if p.bufferTimeCounts[subscriber.BufferTime] == 0 {
|
||||
delete(p.bufferTimeCounts, subscriber.BufferTime)
|
||||
}
|
||||
p.Info("subscriber -1", "count", p.Subscribers.Length)
|
||||
if p.Plugin == nil {
|
||||
return
|
||||
}
|
||||
if subscriber.BufferTime == p.BufferTime && p.Subscribers.Length > 0 {
|
||||
p.BufferTime = slices.MaxFunc(p.Subscribers.Items, func(a, b *Subscriber) int {
|
||||
return int(a.BufferTime - b.BufferTime)
|
||||
}).BufferTime
|
||||
if p.Subscribers.Length > 0 {
|
||||
var maxBuf time.Duration
|
||||
for k := range p.bufferTimeCounts {
|
||||
if k > maxBuf {
|
||||
maxBuf = k
|
||||
}
|
||||
}
|
||||
if defaultBuf := p.Plugin.GetCommonConf().Publish.BufferTime; maxBuf < defaultBuf {
|
||||
maxBuf = defaultBuf
|
||||
}
|
||||
p.BufferTime = maxBuf
|
||||
} else {
|
||||
p.BufferTime = p.Plugin.GetCommonConf().Publish.BufferTime
|
||||
}
|
||||
@@ -235,6 +247,7 @@ func (p *Publisher) AddSubscriber(subscriber *Subscriber) {
|
||||
}
|
||||
subscriber.waitStartTime = time.Time{}
|
||||
if p.Subscribers.AddUnique(subscriber) {
|
||||
p.bufferTimeCounts[subscriber.BufferTime]++
|
||||
p.Info("subscriber +1", "count", p.Subscribers.Length)
|
||||
if subscriber.BufferTime > p.BufferTime {
|
||||
p.BufferTime = subscriber.BufferTime
|
||||
|
||||
341
server.go
341
server.go
@@ -5,7 +5,6 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
@@ -32,13 +31,10 @@ import (
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"gorm.io/gorm"
|
||||
"m7s.live/v5/pb"
|
||||
. "m7s.live/v5/pkg"
|
||||
"m7s.live/v5/pkg/auth"
|
||||
"m7s.live/v5/pkg/db"
|
||||
"m7s.live/v5/pkg/format"
|
||||
"m7s.live/v5/pkg/util"
|
||||
)
|
||||
|
||||
@@ -58,9 +54,6 @@ var (
|
||||
)
|
||||
|
||||
type (
|
||||
RedirectAdvisor interface {
|
||||
GetRedirectTarget(protocol, streamPath, currentHost string) (targetHost string, statusCode int, ok bool)
|
||||
}
|
||||
ServerConfig struct {
|
||||
FatalDir string `default:"fatal" desc:""`
|
||||
PulseInterval time.Duration `default:"5s" desc:"心跳事件间隔"` //心跳事件间隔
|
||||
@@ -134,21 +127,9 @@ type (
|
||||
task.TickTask
|
||||
s *Server
|
||||
}
|
||||
GRPCServer struct {
|
||||
task.Task
|
||||
s *Server
|
||||
tcpTask *config.ListenTCPWork
|
||||
}
|
||||
RawConfig = map[string]map[string]any
|
||||
)
|
||||
|
||||
// context key type & keys
|
||||
type ctxKey int
|
||||
|
||||
const (
|
||||
ctxKeyClaims ctxKey = iota
|
||||
)
|
||||
|
||||
func (w *WaitStream) GetKey() string {
|
||||
return w.StreamPath
|
||||
}
|
||||
@@ -172,77 +153,6 @@ func NewServer(conf any) (s *Server) {
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Server) GetRedirectAdvisor() RedirectAdvisor {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
s.redirectOnce.Do(func() {
|
||||
s.Plugins.Range(func(plugin *Plugin) bool {
|
||||
if advisor, ok := plugin.GetHandler().(RedirectAdvisor); ok {
|
||||
s.redirectAdvisor = advisor
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
})
|
||||
return s.redirectAdvisor
|
||||
}
|
||||
|
||||
// RedirectIfNeeded evaluates redirect advice for HTTP-based protocols and issues redirects when appropriate.
|
||||
func (s *Server) RedirectIfNeeded(w http.ResponseWriter, r *http.Request, protocol, redirectPath string) bool {
|
||||
if s == nil {
|
||||
return false
|
||||
}
|
||||
advisor := s.GetRedirectAdvisor()
|
||||
if advisor == nil {
|
||||
return false
|
||||
}
|
||||
targetHost, statusCode, ok := advisor.GetRedirectTarget(protocol, redirectPath, r.Host)
|
||||
if !ok || targetHost == "" {
|
||||
return false
|
||||
}
|
||||
if statusCode == 0 {
|
||||
statusCode = http.StatusFound
|
||||
}
|
||||
redirectURL := buildRedirectURL(r, targetHost)
|
||||
http.Redirect(w, r, redirectURL, statusCode)
|
||||
return true
|
||||
}
|
||||
|
||||
func buildRedirectURL(r *http.Request, host string) string {
|
||||
scheme := requestScheme(r)
|
||||
if isWebSocketRequest(r) {
|
||||
switch scheme {
|
||||
case "https":
|
||||
scheme = "wss"
|
||||
case "http":
|
||||
scheme = "ws"
|
||||
}
|
||||
}
|
||||
target := &url.URL{
|
||||
Scheme: scheme,
|
||||
Host: host,
|
||||
Path: r.URL.Path,
|
||||
RawQuery: r.URL.RawQuery,
|
||||
}
|
||||
return target.String()
|
||||
}
|
||||
|
||||
func requestScheme(r *http.Request) string {
|
||||
if proto := r.Header.Get("X-Forwarded-Proto"); proto != "" {
|
||||
return proto
|
||||
}
|
||||
if r.TLS != nil {
|
||||
return "https"
|
||||
}
|
||||
return "http"
|
||||
}
|
||||
|
||||
func isWebSocketRequest(r *http.Request) bool {
|
||||
connection := strings.ToLower(r.Header.Get("Connection"))
|
||||
return strings.Contains(connection, "upgrade") && strings.EqualFold(r.Header.Get("Upgrade"), "websocket")
|
||||
}
|
||||
|
||||
func Run(ctx context.Context, conf any) (err error) {
|
||||
for err = ErrRestart; errors.Is(err, ErrRestart); err = Servers.AddTask(NewServer(conf), ctx).WaitStopped() {
|
||||
}
|
||||
@@ -664,14 +574,6 @@ func (c *CheckSubWaitTimeout) Tick(any) {
|
||||
c.s.Waiting.checkTimeout()
|
||||
}
|
||||
|
||||
func (gRPC *GRPCServer) Dispose() {
|
||||
gRPC.s.Stop(gRPC.StopReason())
|
||||
}
|
||||
|
||||
func (gRPC *GRPCServer) Go() (err error) {
|
||||
return gRPC.s.grpcServer.Serve(gRPC.tcpTask.Listener)
|
||||
}
|
||||
|
||||
func (s *Server) CallOnStreamTask(callback func()) {
|
||||
s.Streams.Call(callback)
|
||||
}
|
||||
@@ -723,246 +625,3 @@ func (s *Server) OnSubscribe(streamPath string, args url.Values) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// Check for location-based forwarding first
|
||||
if s.Location != nil {
|
||||
for pattern, target := range s.Location {
|
||||
if pattern.MatchString(r.URL.Path) {
|
||||
// Rewrite the URL path and handle locally
|
||||
r.URL.Path = pattern.ReplaceAllString(r.URL.Path, target)
|
||||
// Forward to local handler
|
||||
s.config.HTTP.GetHandler(s.Logger).ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 检查 admin.zip 是否需要重新加载
|
||||
now := time.Now()
|
||||
if now.Sub(s.Admin.lastCheckTime) > checkInterval {
|
||||
if info, err := os.Stat(s.Admin.FilePath); err == nil && info.ModTime() != s.Admin.zipLastModTime {
|
||||
s.Info("admin.zip changed, reloading...")
|
||||
s.loadAdminZip()
|
||||
}
|
||||
s.Admin.lastCheckTime = now
|
||||
}
|
||||
|
||||
if s.Admin.zipReader != nil {
|
||||
// Handle root path redirect to HomePage
|
||||
if r.URL.Path == "/" {
|
||||
http.Redirect(w, r, "/admin/#/"+s.Admin.HomePage, http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
// For .map files, set correct content-type before serving
|
||||
if strings.HasSuffix(r.URL.Path, ".map") {
|
||||
filePath := strings.TrimPrefix(r.URL.Path, "/admin/")
|
||||
file, err := s.Admin.zipReader.Open(filePath)
|
||||
if err != nil {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
io.Copy(w, file)
|
||||
return
|
||||
}
|
||||
|
||||
http.ServeFileFS(w, r, s.Admin.zipReader, strings.TrimPrefix(r.URL.Path, "/admin"))
|
||||
return
|
||||
}
|
||||
if r.URL.Path == "/favicon.ico" {
|
||||
http.ServeFile(w, r, "favicon.ico")
|
||||
return
|
||||
}
|
||||
_, _ = fmt.Fprintf(w, "visit:%s\nMonibuca Engine %s StartTime:%s\n", r.URL.Path, Version, s.StartTime)
|
||||
for plugin := range s.Plugins.Range {
|
||||
_, _ = fmt.Fprintf(w, "Plugin %s Version:%s\n", plugin.Meta.Name, plugin.Meta.Version)
|
||||
}
|
||||
for _, api := range s.apiList {
|
||||
_, _ = fmt.Fprintf(w, "%s\n", api)
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateToken implements auth.TokenValidator
|
||||
func (s *Server) ValidateToken(tokenString string) (*auth.JWTClaims, error) {
|
||||
if !s.ServerConfig.Admin.EnableLogin {
|
||||
return &auth.JWTClaims{Username: "anonymous"}, nil
|
||||
}
|
||||
return auth.ValidateJWT(tokenString)
|
||||
}
|
||||
|
||||
// Login implements the Login RPC method
|
||||
func (s *Server) Login(ctx context.Context, req *pb.LoginRequest) (res *pb.LoginResponse, err error) {
|
||||
res = &pb.LoginResponse{}
|
||||
if !s.ServerConfig.Admin.EnableLogin {
|
||||
res.Data = &pb.LoginSuccess{
|
||||
Token: "monibuca",
|
||||
UserInfo: &pb.UserInfo{
|
||||
Username: "anonymous",
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour).Unix(),
|
||||
},
|
||||
}
|
||||
return
|
||||
}
|
||||
if s.DB == nil {
|
||||
err = ErrNoDB
|
||||
return
|
||||
}
|
||||
var user db.User
|
||||
if err = s.DB.Where("username = ?", req.Username).First(&user).Error; err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if !user.CheckPassword(req.Password) {
|
||||
err = ErrInvalidCredentials
|
||||
return
|
||||
}
|
||||
|
||||
// Generate JWT token
|
||||
var tokenString string
|
||||
tokenString, err = auth.GenerateToken(user.Username)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Update last login time
|
||||
s.DB.Model(&user).Update("last_login", time.Now())
|
||||
res.Data = &pb.LoginSuccess{
|
||||
Token: tokenString,
|
||||
UserInfo: &pb.UserInfo{
|
||||
Username: user.Username,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour).Unix(),
|
||||
},
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Logout implements the Logout RPC method
|
||||
func (s *Server) Logout(ctx context.Context, req *pb.LogoutRequest) (res *pb.LogoutResponse, err error) {
|
||||
// In a more complex system, you might want to maintain a blacklist of logged-out tokens
|
||||
// For now, we'll just return success as JWT tokens are stateless
|
||||
res = &pb.LogoutResponse{Code: 0, Message: "success"}
|
||||
return
|
||||
}
|
||||
|
||||
// GetUserInfo implements the GetUserInfo RPC method
|
||||
func (s *Server) GetUserInfo(ctx context.Context, req *pb.UserInfoRequest) (res *pb.UserInfoResponse, err error) {
|
||||
if !s.ServerConfig.Admin.EnableLogin {
|
||||
res = &pb.UserInfoResponse{
|
||||
Code: 0,
|
||||
Message: "success",
|
||||
Data: &pb.UserInfo{
|
||||
Username: "anonymous",
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour).Unix(),
|
||||
},
|
||||
}
|
||||
return
|
||||
}
|
||||
res = &pb.UserInfoResponse{}
|
||||
claims, err := s.ValidateToken(req.Token)
|
||||
if err != nil {
|
||||
err = ErrInvalidCredentials
|
||||
return
|
||||
}
|
||||
|
||||
var user db.User
|
||||
if err = s.DB.Where("username = ?", claims.Username).First(&user).Error; err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Token is valid for 24 hours from now
|
||||
expiresAt := time.Now().Add(24 * time.Hour).Unix()
|
||||
|
||||
return &pb.UserInfoResponse{
|
||||
Code: 0,
|
||||
Message: "success",
|
||||
Data: &pb.UserInfo{
|
||||
Username: user.Username,
|
||||
ExpiresAt: expiresAt,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// AuthInterceptor creates a new unary interceptor for authentication
|
||||
func (s *Server) AuthInterceptor() grpc.UnaryServerInterceptor {
|
||||
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
|
||||
if !s.ServerConfig.Admin.EnableLogin {
|
||||
return handler(ctx, req)
|
||||
}
|
||||
|
||||
// Skip auth for login endpoint
|
||||
if info.FullMethod == "/pb.Auth/Login" {
|
||||
return handler(ctx, req)
|
||||
}
|
||||
|
||||
md, ok := metadata.FromIncomingContext(ctx)
|
||||
if !ok {
|
||||
return nil, errors.New("missing metadata")
|
||||
}
|
||||
|
||||
authHeader := md.Get("authorization")
|
||||
if len(authHeader) == 0 {
|
||||
return nil, errors.New("missing authorization header")
|
||||
}
|
||||
|
||||
tokenString := strings.TrimPrefix(authHeader[0], "Bearer ")
|
||||
claims, err := s.ValidateToken(tokenString)
|
||||
if err != nil {
|
||||
return nil, errors.New("invalid token")
|
||||
}
|
||||
|
||||
// Check if token needs refresh
|
||||
shouldRefresh, err := auth.ShouldRefreshToken(tokenString)
|
||||
if err == nil && shouldRefresh {
|
||||
newToken, err := auth.RefreshToken(tokenString)
|
||||
if err == nil {
|
||||
// Add new token to response headers
|
||||
header := metadata.New(map[string]string{
|
||||
"new-token": newToken,
|
||||
})
|
||||
grpc.SetHeader(ctx, header)
|
||||
}
|
||||
}
|
||||
|
||||
// Add claims to context
|
||||
newCtx := context.WithValue(ctx, ctxKeyClaims, claims)
|
||||
return handler(newCtx, req)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) annexB(w http.ResponseWriter, r *http.Request) {
|
||||
streamPath := r.PathValue("streamPath")
|
||||
|
||||
if r.URL.RawQuery != "" {
|
||||
streamPath += "?" + r.URL.RawQuery
|
||||
}
|
||||
var conf = s.config.Subscribe
|
||||
conf.SubType = SubscribeTypeServer
|
||||
conf.SubAudio = false
|
||||
suber, err := s.SubscribeWithConfig(r.Context(), streamPath, conf)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var ctx util.HTTP_WS_Writer
|
||||
ctx.Conn, err = suber.CheckWebSocket(w, r)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
ctx.WriteTimeout = s.GetCommonConf().WriteTimeout
|
||||
ctx.ContentType = "application/octet-stream"
|
||||
ctx.ServeHTTP(w, r)
|
||||
|
||||
PlayBlock(suber, func(frame *format.RawAudio) (err error) {
|
||||
return nil
|
||||
}, func(frame *format.AnnexB) (err error) {
|
||||
_, err = frame.WriteTo(&ctx)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return ctx.Flush()
|
||||
})
|
||||
}
|
||||
|
||||
186
server_grpc.go
Normal file
186
server_grpc.go
Normal file
@@ -0,0 +1,186 @@
|
||||
package m7s
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/metadata"
|
||||
task "github.com/langhuihui/gotask"
|
||||
"m7s.live/v5/pb"
|
||||
. "m7s.live/v5/pkg"
|
||||
"m7s.live/v5/pkg/auth"
|
||||
"m7s.live/v5/pkg/config"
|
||||
"m7s.live/v5/pkg/db"
|
||||
)
|
||||
|
||||
// context key type & keys
|
||||
type ctxKey int
|
||||
|
||||
const (
|
||||
ctxKeyClaims ctxKey = iota
|
||||
)
|
||||
|
||||
type GRPCServer struct {
|
||||
task.Task
|
||||
s *Server
|
||||
tcpTask *config.ListenTCPWork
|
||||
}
|
||||
|
||||
func (gRPC *GRPCServer) Dispose() {
|
||||
gRPC.s.Stop(gRPC.StopReason())
|
||||
}
|
||||
|
||||
func (gRPC *GRPCServer) Go() (err error) {
|
||||
return gRPC.s.grpcServer.Serve(gRPC.tcpTask.Listener)
|
||||
}
|
||||
|
||||
// ValidateToken implements auth.TokenValidator
|
||||
func (s *Server) ValidateToken(tokenString string) (*auth.JWTClaims, error) {
|
||||
if !s.ServerConfig.Admin.EnableLogin {
|
||||
return &auth.JWTClaims{Username: "anonymous"}, nil
|
||||
}
|
||||
return auth.ValidateJWT(tokenString)
|
||||
}
|
||||
|
||||
// Login implements the Login RPC method
|
||||
func (s *Server) Login(ctx context.Context, req *pb.LoginRequest) (res *pb.LoginResponse, err error) {
|
||||
res = &pb.LoginResponse{}
|
||||
if !s.ServerConfig.Admin.EnableLogin {
|
||||
res.Data = &pb.LoginSuccess{
|
||||
Token: "monibuca",
|
||||
UserInfo: &pb.UserInfo{
|
||||
Username: "anonymous",
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour).Unix(),
|
||||
},
|
||||
}
|
||||
return
|
||||
}
|
||||
if s.DB == nil {
|
||||
err = ErrNoDB
|
||||
return
|
||||
}
|
||||
var user db.User
|
||||
if err = s.DB.Where("username = ?", req.Username).First(&user).Error; err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if !user.CheckPassword(req.Password) {
|
||||
err = ErrInvalidCredentials
|
||||
return
|
||||
}
|
||||
|
||||
// Generate JWT token
|
||||
var tokenString string
|
||||
tokenString, err = auth.GenerateToken(user.Username)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Update last login time
|
||||
s.DB.Model(&user).Update("last_login", time.Now())
|
||||
res.Data = &pb.LoginSuccess{
|
||||
Token: tokenString,
|
||||
UserInfo: &pb.UserInfo{
|
||||
Username: user.Username,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour).Unix(),
|
||||
},
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Logout implements the Logout RPC method
|
||||
func (s *Server) Logout(ctx context.Context, req *pb.LogoutRequest) (res *pb.LogoutResponse, err error) {
|
||||
// In a more complex system, you might want to maintain a blacklist of logged-out tokens
|
||||
// For now, we'll just return success as JWT tokens are stateless
|
||||
res = &pb.LogoutResponse{Code: 0, Message: "success"}
|
||||
return
|
||||
}
|
||||
|
||||
// GetUserInfo implements the GetUserInfo RPC method
|
||||
func (s *Server) GetUserInfo(ctx context.Context, req *pb.UserInfoRequest) (res *pb.UserInfoResponse, err error) {
|
||||
if !s.ServerConfig.Admin.EnableLogin {
|
||||
res = &pb.UserInfoResponse{
|
||||
Code: 0,
|
||||
Message: "success",
|
||||
Data: &pb.UserInfo{
|
||||
Username: "anonymous",
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour).Unix(),
|
||||
},
|
||||
}
|
||||
return
|
||||
}
|
||||
res = &pb.UserInfoResponse{}
|
||||
claims, err := s.ValidateToken(req.Token)
|
||||
if err != nil {
|
||||
err = ErrInvalidCredentials
|
||||
return
|
||||
}
|
||||
|
||||
var user db.User
|
||||
if err = s.DB.Where("username = ?", claims.Username).First(&user).Error; err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Token is valid for 24 hours from now
|
||||
expiresAt := time.Now().Add(24 * time.Hour).Unix()
|
||||
|
||||
return &pb.UserInfoResponse{
|
||||
Code: 0,
|
||||
Message: "success",
|
||||
Data: &pb.UserInfo{
|
||||
Username: user.Username,
|
||||
ExpiresAt: expiresAt,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// AuthInterceptor creates a new unary interceptor for authentication
|
||||
func (s *Server) AuthInterceptor() grpc.UnaryServerInterceptor {
|
||||
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
|
||||
if !s.ServerConfig.Admin.EnableLogin {
|
||||
return handler(ctx, req)
|
||||
}
|
||||
|
||||
// Skip auth for login endpoint
|
||||
if info.FullMethod == "/pb.Auth/Login" {
|
||||
return handler(ctx, req)
|
||||
}
|
||||
|
||||
md, ok := metadata.FromIncomingContext(ctx)
|
||||
if !ok {
|
||||
return nil, errors.New("missing metadata")
|
||||
}
|
||||
|
||||
authHeader := md.Get("authorization")
|
||||
if len(authHeader) == 0 {
|
||||
return nil, errors.New("missing authorization header")
|
||||
}
|
||||
|
||||
tokenString := strings.TrimPrefix(authHeader[0], "Bearer ")
|
||||
claims, err := s.ValidateToken(tokenString)
|
||||
if err != nil {
|
||||
return nil, errors.New("invalid token")
|
||||
}
|
||||
|
||||
// Check if token needs refresh
|
||||
shouldRefresh, err := auth.ShouldRefreshToken(tokenString)
|
||||
if err == nil && shouldRefresh {
|
||||
newToken, err := auth.RefreshToken(tokenString)
|
||||
if err == nil {
|
||||
// Add new token to response headers
|
||||
header := metadata.New(map[string]string{
|
||||
"new-token": newToken,
|
||||
})
|
||||
grpc.SetHeader(ctx, header)
|
||||
}
|
||||
}
|
||||
|
||||
// Add claims to context
|
||||
newCtx := context.WithValue(ctx, ctxKeyClaims, claims)
|
||||
return handler(newCtx, req)
|
||||
}
|
||||
}
|
||||
|
||||
186
server_http.go
Normal file
186
server_http.go
Normal file
@@ -0,0 +1,186 @@
|
||||
package m7s
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"m7s.live/v5/pkg/format"
|
||||
"m7s.live/v5/pkg/util"
|
||||
)
|
||||
|
||||
type RedirectAdvisor interface {
|
||||
GetRedirectTarget(protocol, streamPath, currentHost string) (targetHost string, statusCode int, ok bool)
|
||||
}
|
||||
|
||||
func (s *Server) GetRedirectAdvisor() RedirectAdvisor {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
s.redirectOnce.Do(func() {
|
||||
s.Plugins.Range(func(plugin *Plugin) bool {
|
||||
if advisor, ok := plugin.GetHandler().(RedirectAdvisor); ok {
|
||||
s.redirectAdvisor = advisor
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
})
|
||||
return s.redirectAdvisor
|
||||
}
|
||||
|
||||
// RedirectIfNeeded evaluates redirect advice for HTTP-based protocols and issues redirects when appropriate.
|
||||
func (s *Server) RedirectIfNeeded(w http.ResponseWriter, r *http.Request, protocol, redirectPath string) bool {
|
||||
if s == nil {
|
||||
return false
|
||||
}
|
||||
advisor := s.GetRedirectAdvisor()
|
||||
if advisor == nil {
|
||||
return false
|
||||
}
|
||||
targetHost, statusCode, ok := advisor.GetRedirectTarget(protocol, redirectPath, r.Host)
|
||||
if !ok || targetHost == "" {
|
||||
return false
|
||||
}
|
||||
if statusCode == 0 {
|
||||
statusCode = http.StatusFound
|
||||
}
|
||||
redirectURL := buildRedirectURL(r, targetHost)
|
||||
http.Redirect(w, r, redirectURL, statusCode)
|
||||
return true
|
||||
}
|
||||
|
||||
func buildRedirectURL(r *http.Request, host string) string {
|
||||
scheme := requestScheme(r)
|
||||
if isWebSocketRequest(r) {
|
||||
switch scheme {
|
||||
case "https":
|
||||
scheme = "wss"
|
||||
case "http":
|
||||
scheme = "ws"
|
||||
}
|
||||
}
|
||||
target := &url.URL{
|
||||
Scheme: scheme,
|
||||
Host: host,
|
||||
Path: r.URL.Path,
|
||||
RawQuery: r.URL.RawQuery,
|
||||
}
|
||||
return target.String()
|
||||
}
|
||||
|
||||
func requestScheme(r *http.Request) string {
|
||||
if proto := r.Header.Get("X-Forwarded-Proto"); proto != "" {
|
||||
return proto
|
||||
}
|
||||
if r.TLS != nil {
|
||||
return "https"
|
||||
}
|
||||
return "http"
|
||||
}
|
||||
|
||||
func isWebSocketRequest(r *http.Request) bool {
|
||||
connection := strings.ToLower(r.Header.Get("Connection"))
|
||||
return strings.Contains(connection, "upgrade") && strings.EqualFold(r.Header.Get("Upgrade"), "websocket")
|
||||
}
|
||||
|
||||
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// Check for location-based forwarding first
|
||||
if s.Location != nil {
|
||||
for pattern, target := range s.Location {
|
||||
if pattern.MatchString(r.URL.Path) {
|
||||
// Rewrite the URL path and handle locally
|
||||
r.URL.Path = pattern.ReplaceAllString(r.URL.Path, target)
|
||||
// Forward to local handler
|
||||
s.config.HTTP.GetHandler(s.Logger).ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 检查 admin.zip 是否需要重新加载
|
||||
now := time.Now()
|
||||
if now.Sub(s.Admin.lastCheckTime) > checkInterval {
|
||||
if info, err := os.Stat(s.Admin.FilePath); err == nil && info.ModTime() != s.Admin.zipLastModTime {
|
||||
s.Info("admin.zip changed, reloading...")
|
||||
s.loadAdminZip()
|
||||
}
|
||||
s.Admin.lastCheckTime = now
|
||||
}
|
||||
|
||||
if s.Admin.zipReader != nil {
|
||||
// Handle root path redirect to HomePage
|
||||
if r.URL.Path == "/" {
|
||||
http.Redirect(w, r, "/admin/#/"+s.Admin.HomePage, http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
// For .map files, set correct content-type before serving
|
||||
if strings.HasSuffix(r.URL.Path, ".map") {
|
||||
filePath := strings.TrimPrefix(r.URL.Path, "/admin/")
|
||||
file, err := s.Admin.zipReader.Open(filePath)
|
||||
if err != nil {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
io.Copy(w, file)
|
||||
return
|
||||
}
|
||||
|
||||
http.ServeFileFS(w, r, s.Admin.zipReader, strings.TrimPrefix(r.URL.Path, "/admin"))
|
||||
return
|
||||
}
|
||||
if r.URL.Path == "/favicon.ico" {
|
||||
http.ServeFile(w, r, "favicon.ico")
|
||||
return
|
||||
}
|
||||
_, _ = fmt.Fprintf(w, "visit:%s\nMonibuca Engine %s StartTime:%s\n", r.URL.Path, Version, s.StartTime)
|
||||
for plugin := range s.Plugins.Range {
|
||||
_, _ = fmt.Fprintf(w, "Plugin %s Version:%s\n", plugin.Meta.Name, plugin.Meta.Version)
|
||||
}
|
||||
for _, api := range s.apiList {
|
||||
_, _ = fmt.Fprintf(w, "%s\n", api)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) annexB(w http.ResponseWriter, r *http.Request) {
|
||||
streamPath := r.PathValue("streamPath")
|
||||
|
||||
if r.URL.RawQuery != "" {
|
||||
streamPath += "?" + r.URL.RawQuery
|
||||
}
|
||||
var conf = s.config.Subscribe
|
||||
conf.SubType = SubscribeTypeServer
|
||||
conf.SubAudio = false
|
||||
suber, err := s.SubscribeWithConfig(r.Context(), streamPath, conf)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var ctx util.HTTP_WS_Writer
|
||||
ctx.Conn, err = suber.CheckWebSocket(w, r)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
ctx.WriteTimeout = s.GetCommonConf().WriteTimeout
|
||||
ctx.ContentType = "application/octet-stream"
|
||||
ctx.ServeHTTP(w, r)
|
||||
|
||||
PlayBlock(suber, func(frame *format.RawAudio) (err error) {
|
||||
return nil
|
||||
}, func(frame *format.AnnexB) (err error) {
|
||||
_, err = frame.WriteTo(&ctx)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return ctx.Flush()
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user