fix: alias with args

This commit is contained in:
langhuihui
2025-12-04 15:27:12 +08:00
parent f16fe2996e
commit 5b287c9202

View File

@@ -7,6 +7,7 @@ import (
"time" "time"
"google.golang.org/protobuf/types/known/emptypb" "google.golang.org/protobuf/types/known/emptypb"
"gorm.io/gorm"
"m7s.live/v5/pb" "m7s.live/v5/pb"
) )
@@ -15,6 +16,7 @@ type AliasStream struct {
AutoRemove bool AutoRemove bool
StreamPath string StreamPath string
Alias string `gorm:"primarykey"` Alias string `gorm:"primarykey"`
Args url.Values `gorm:"-"`
} }
func (a *AliasStream) GetKey() string { func (a *AliasStream) GetKey() string {
@@ -24,6 +26,7 @@ func (a *AliasStream) GetKey() string {
// StreamAliasDB 用于存储流别名的数据库模型 // StreamAliasDB 用于存储流别名的数据库模型
type StreamAliasDB struct { type StreamAliasDB struct {
AliasStream AliasStream
ArgsString string `gorm:"column:args;type:text"`
CreatedAt time.Time `yaml:"-"` CreatedAt time.Time `yaml:"-"`
UpdatedAt time.Time `yaml:"-"` UpdatedAt time.Time `yaml:"-"`
} }
@@ -32,6 +35,40 @@ func (StreamAliasDB) TableName() string {
return "stream_alias" return "stream_alias"
} }
// BeforeSave 保存前序列化查询参数
func (db *StreamAliasDB) BeforeSave(tx *gorm.DB) error {
if len(db.Args) > 0 {
db.ArgsString = db.Args.Encode()
} else {
db.ArgsString = ""
}
return nil
}
// BeforeCreate 创建前序列化查询参数
func (db *StreamAliasDB) BeforeCreate(tx *gorm.DB) error {
return db.BeforeSave(tx)
}
// BeforeUpdate 更新前序列化查询参数
func (db *StreamAliasDB) BeforeUpdate(tx *gorm.DB) error {
return db.BeforeSave(tx)
}
// AfterFind 查询后反序列化查询参数
func (db *StreamAliasDB) AfterFind(tx *gorm.DB) error {
if db.ArgsString != "" {
var err error
db.Args, err = url.ParseQuery(db.ArgsString)
if err != nil {
db.Args = nil
}
} else {
db.Args = nil
}
return nil
}
func (s *Server) initStreamAlias() { func (s *Server) initStreamAlias() {
if s.DB == nil { if s.DB == nil {
return return
@@ -75,13 +112,15 @@ func (s *Server) SetStreamAlias(ctx context.Context, req *pb.SetStreamAliasReque
return return
} }
req.StreamPath = strings.TrimPrefix(u.Path, "/") req.StreamPath = strings.TrimPrefix(u.Path, "/")
queryParams := u.Query()
publisher, canReplace := s.Streams.Get(req.StreamPath) publisher, canReplace := s.Streams.Get(req.StreamPath)
if !canReplace { if !canReplace {
defer s.OnSubscribe(req.StreamPath, u.Query()) defer s.OnSubscribe(req.StreamPath, queryParams)
} }
if aliasInfo, ok := s.AliasStreams.Get(req.Alias); ok { //modify alias if aliasInfo, ok := s.AliasStreams.Get(req.Alias); ok { //modify alias
oldStreamPath := aliasInfo.StreamPath oldStreamPath := aliasInfo.StreamPath
aliasInfo.AutoRemove = req.AutoRemove aliasInfo.AutoRemove = req.AutoRemove
aliasInfo.Args = queryParams
if aliasInfo.StreamPath != req.StreamPath { if aliasInfo.StreamPath != req.StreamPath {
aliasInfo.StreamPath = req.StreamPath aliasInfo.StreamPath = req.StreamPath
if canReplace { if canReplace {
@@ -96,9 +135,10 @@ func (s *Server) SetStreamAlias(ctx context.Context, req *pb.SetStreamAliasReque
} }
// 更新数据库中的别名 // 更新数据库中的别名
if s.DB != nil { if s.DB != nil {
s.DB.Where("alias = ?", req.Alias).Assign(aliasInfo).FirstOrCreate(&StreamAliasDB{ dbAlias := &StreamAliasDB{
AliasStream: *aliasInfo, AliasStream: *aliasInfo,
}) }
s.DB.Where("alias = ?", req.Alias).Save(dbAlias)
} }
s.Info("modify alias", "alias", req.Alias, "oldStreamPath", oldStreamPath, "streamPath", req.StreamPath, "replace", ok && canReplace) s.Info("modify alias", "alias", req.Alias, "oldStreamPath", oldStreamPath, "streamPath", req.StreamPath, "replace", ok && canReplace)
} else { // create alias } else { // create alias
@@ -106,6 +146,7 @@ func (s *Server) SetStreamAlias(ctx context.Context, req *pb.SetStreamAliasReque
AutoRemove: req.AutoRemove, AutoRemove: req.AutoRemove,
StreamPath: req.StreamPath, StreamPath: req.StreamPath,
Alias: req.Alias, Alias: req.Alias,
Args: queryParams,
} }
var pubId uint32 var pubId uint32
s.AliasStreams.Add(&aliasInfo) s.AliasStreams.Add(&aliasInfo)
@@ -125,9 +166,10 @@ func (s *Server) SetStreamAlias(ctx context.Context, req *pb.SetStreamAliasReque
} }
// 保存到数据库 // 保存到数据库
if s.DB != nil { if s.DB != nil {
s.DB.Create(&StreamAliasDB{ dbAlias := &StreamAliasDB{
AliasStream: aliasInfo, AliasStream: aliasInfo,
}) }
s.DB.Create(dbAlias)
} }
s.Info("add alias", "alias", req.Alias, "streamPath", req.StreamPath, "replace", ok && canReplace, "pub", pubId) s.Info("add alias", "alias", req.Alias, "streamPath", req.StreamPath, "replace", ok && canReplace, "pub", pubId)
} }
@@ -143,15 +185,20 @@ func (s *Server) SetStreamAlias(ctx context.Context, req *pb.SetStreamAliasReque
if publisher, hasTarget := s.Streams.Get(req.Alias); hasTarget { // restore stream if publisher, hasTarget := s.Streams.Get(req.Alias); hasTarget { // restore stream
aliasStream.TransferSubscribers(publisher) aliasStream.TransferSubscribers(publisher)
} else { } else {
var args url.Values // 优先使用别名保存的查询参数
args := aliasStream.Args
if len(args) == 0 {
// 如果没有保存的查询参数,则从订阅者中获取
for sub := range aliasStream.Publisher.SubscriberRange { for sub := range aliasStream.Publisher.SubscriberRange {
if sub.StreamPath == req.Alias { if sub.StreamPath == req.Alias {
aliasStream.Publisher.RemoveSubscriber(sub) aliasStream.Publisher.RemoveSubscriber(sub)
s.Waiting.Wait(sub) s.Waiting.Wait(sub)
args = sub.Args args = sub.Args
break
} }
} }
if args != nil { }
if len(args) > 0 {
s.OnSubscribe(req.Alias, args) s.OnSubscribe(req.Alias, args)
} }
} }
@@ -218,7 +265,21 @@ func (s *Subscriber) processAliasOnStart() (hasInvited bool, done bool) {
done = true done = true
return return
} else { } else {
server.OnSubscribe(alias.StreamPath, s.Args) // 合并参数:先使用别名保存的参数,然后用订阅者传入的参数覆盖同名参数
args := make(url.Values)
// 先复制别名保存的参数
if alias.Args != nil {
for k, v := range alias.Args {
args[k] = append([]string(nil), v...)
}
}
// 用订阅者传入的参数覆盖同名参数
if s.Args != nil {
for k, v := range s.Args {
args[k] = append([]string(nil), v...)
}
}
server.OnSubscribe(alias.StreamPath, args)
hasInvited = true hasInvited = true
} }
} else { } else {
@@ -227,12 +288,14 @@ func (s *Subscriber) processAliasOnStart() (hasInvited bool, done bool) {
as := AliasStream{ as := AliasStream{
StreamPath: streamPath, StreamPath: streamPath,
Alias: s.StreamPath, Alias: s.StreamPath,
Args: s.Args,
} }
server.AliasStreams.Set(&as) server.AliasStreams.Set(&as)
if server.DB != nil { if server.DB != nil {
server.DB.Where("alias = ?", s.StreamPath).Assign(as).FirstOrCreate(&StreamAliasDB{ dbAlias := &StreamAliasDB{
AliasStream: as, AliasStream: as,
}) }
server.DB.Where("alias = ?", s.StreamPath).Save(dbAlias)
} }
if publisher, ok := server.Streams.Get(streamPath); ok { if publisher, ok := server.Streams.Get(streamPath); ok {
publisher.AddSubscriber(s) publisher.AddSubscriber(s)