From 5b287c92026688f1d84f09afdf476e28347867f2 Mon Sep 17 00:00:00 2001 From: langhuihui <178529795@qq.com> Date: Thu, 4 Dec 2025 15:27:12 +0800 Subject: [PATCH] fix: alias with args --- alias.go | 99 +++++++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 81 insertions(+), 18 deletions(-) diff --git a/alias.go b/alias.go index 93201ed..03cd547 100644 --- a/alias.go +++ b/alias.go @@ -7,6 +7,7 @@ import ( "time" "google.golang.org/protobuf/types/known/emptypb" + "gorm.io/gorm" "m7s.live/v5/pb" ) @@ -14,7 +15,8 @@ type AliasStream struct { *Publisher `gorm:"-:all"` AutoRemove bool StreamPath string - Alias string `gorm:"primarykey"` + Alias string `gorm:"primarykey"` + Args url.Values `gorm:"-"` } func (a *AliasStream) GetKey() string { @@ -24,14 +26,49 @@ func (a *AliasStream) GetKey() string { // StreamAliasDB 用于存储流别名的数据库模型 type StreamAliasDB struct { AliasStream - CreatedAt time.Time `yaml:"-"` - UpdatedAt time.Time `yaml:"-"` + ArgsString string `gorm:"column:args;type:text"` + CreatedAt time.Time `yaml:"-"` + UpdatedAt time.Time `yaml:"-"` } func (StreamAliasDB) TableName() string { return "stream_alias" } +// BeforeSave 保存前序列化查询参数 +func (db *StreamAliasDB) BeforeSave(tx *gorm.DB) error { + if len(db.Args) > 0 { + db.ArgsString = db.Args.Encode() + } else { + db.ArgsString = "" + } + return nil +} + +// BeforeCreate 创建前序列化查询参数 +func (db *StreamAliasDB) BeforeCreate(tx *gorm.DB) error { + return db.BeforeSave(tx) +} + +// BeforeUpdate 更新前序列化查询参数 +func (db *StreamAliasDB) BeforeUpdate(tx *gorm.DB) error { + return db.BeforeSave(tx) +} + +// AfterFind 查询后反序列化查询参数 +func (db *StreamAliasDB) AfterFind(tx *gorm.DB) error { + if db.ArgsString != "" { + var err error + db.Args, err = url.ParseQuery(db.ArgsString) + if err != nil { + db.Args = nil + } + } else { + db.Args = nil + } + return nil +} + func (s *Server) initStreamAlias() { if s.DB == nil { return @@ -75,13 +112,15 @@ func (s *Server) SetStreamAlias(ctx context.Context, req *pb.SetStreamAliasReque return } req.StreamPath = strings.TrimPrefix(u.Path, "/") + queryParams := u.Query() publisher, canReplace := s.Streams.Get(req.StreamPath) if !canReplace { - defer s.OnSubscribe(req.StreamPath, u.Query()) + defer s.OnSubscribe(req.StreamPath, queryParams) } if aliasInfo, ok := s.AliasStreams.Get(req.Alias); ok { //modify alias oldStreamPath := aliasInfo.StreamPath aliasInfo.AutoRemove = req.AutoRemove + aliasInfo.Args = queryParams if aliasInfo.StreamPath != req.StreamPath { aliasInfo.StreamPath = req.StreamPath if canReplace { @@ -96,9 +135,10 @@ func (s *Server) SetStreamAlias(ctx context.Context, req *pb.SetStreamAliasReque } // 更新数据库中的别名 if s.DB != nil { - s.DB.Where("alias = ?", req.Alias).Assign(aliasInfo).FirstOrCreate(&StreamAliasDB{ + dbAlias := &StreamAliasDB{ AliasStream: *aliasInfo, - }) + } + s.DB.Where("alias = ?", req.Alias).Save(dbAlias) } s.Info("modify alias", "alias", req.Alias, "oldStreamPath", oldStreamPath, "streamPath", req.StreamPath, "replace", ok && canReplace) } else { // create alias @@ -106,6 +146,7 @@ func (s *Server) SetStreamAlias(ctx context.Context, req *pb.SetStreamAliasReque AutoRemove: req.AutoRemove, StreamPath: req.StreamPath, Alias: req.Alias, + Args: queryParams, } var pubId uint32 s.AliasStreams.Add(&aliasInfo) @@ -125,9 +166,10 @@ func (s *Server) SetStreamAlias(ctx context.Context, req *pb.SetStreamAliasReque } // 保存到数据库 if s.DB != nil { - s.DB.Create(&StreamAliasDB{ + dbAlias := &StreamAliasDB{ AliasStream: aliasInfo, - }) + } + s.DB.Create(dbAlias) } s.Info("add alias", "alias", req.Alias, "streamPath", req.StreamPath, "replace", ok && canReplace, "pub", pubId) } @@ -143,15 +185,20 @@ func (s *Server) SetStreamAlias(ctx context.Context, req *pb.SetStreamAliasReque if publisher, hasTarget := s.Streams.Get(req.Alias); hasTarget { // restore stream aliasStream.TransferSubscribers(publisher) } else { - var args url.Values - for sub := range aliasStream.Publisher.SubscriberRange { - if sub.StreamPath == req.Alias { - aliasStream.Publisher.RemoveSubscriber(sub) - s.Waiting.Wait(sub) - args = sub.Args + // 优先使用别名保存的查询参数 + args := aliasStream.Args + if len(args) == 0 { + // 如果没有保存的查询参数,则从订阅者中获取 + for sub := range aliasStream.Publisher.SubscriberRange { + if sub.StreamPath == req.Alias { + aliasStream.Publisher.RemoveSubscriber(sub) + s.Waiting.Wait(sub) + args = sub.Args + break + } } } - if args != nil { + if len(args) > 0 { s.OnSubscribe(req.Alias, args) } } @@ -218,7 +265,21 @@ func (s *Subscriber) processAliasOnStart() (hasInvited bool, done bool) { done = true return } else { - server.OnSubscribe(alias.StreamPath, s.Args) + // 合并参数:先使用别名保存的参数,然后用订阅者传入的参数覆盖同名参数 + args := make(url.Values) + // 先复制别名保存的参数 + if alias.Args != nil { + for k, v := range alias.Args { + args[k] = append([]string(nil), v...) + } + } + // 用订阅者传入的参数覆盖同名参数 + if s.Args != nil { + for k, v := range s.Args { + args[k] = append([]string(nil), v...) + } + } + server.OnSubscribe(alias.StreamPath, args) hasInvited = true } } else { @@ -227,12 +288,14 @@ func (s *Subscriber) processAliasOnStart() (hasInvited bool, done bool) { as := AliasStream{ StreamPath: streamPath, Alias: s.StreamPath, + Args: s.Args, } server.AliasStreams.Set(&as) if server.DB != nil { - server.DB.Where("alias = ?", s.StreamPath).Assign(as).FirstOrCreate(&StreamAliasDB{ + dbAlias := &StreamAliasDB{ AliasStream: as, - }) + } + server.DB.Where("alias = ?", s.StreamPath).Save(dbAlias) } if publisher, ok := server.Streams.Get(streamPath); ok { publisher.AddSubscriber(s)