Files
monibuca/alias.go
2025-12-04 15:27:12 +08:00

314 lines
8.0 KiB
Go

package m7s
import (
"context"
"net/url"
"strings"
"time"
"google.golang.org/protobuf/types/known/emptypb"
"gorm.io/gorm"
"m7s.live/v5/pb"
)
// AliasStream maps an alias stream path (Alias) onto a target stream path
// (StreamPath). StreamAliasDB embeds it, so these are also the persisted
// alias fields.
type AliasStream struct {
	// Publisher the alias is currently bound to; nil while unbound.
	// Excluded from GORM entirely.
	*Publisher `gorm:"-:all"`
	// AutoRemove drops the alias once the publisher of StreamPath is disposed.
	AutoRemove bool
	// StreamPath is the target stream the alias resolves to.
	StreamPath string
	// Alias is the alias stream path and the table's primary key.
	Alias string `gorm:"primarykey"`
	// Args are query parameters kept with the alias and passed along when the
	// target is requested. GORM ignores this field; StreamAliasDB stores it
	// as ArgsString.
	Args url.Values `gorm:"-"`
}
// GetKey returns the alias name. Lookups in the server's alias collection
// are done by alias name, so this is presumably its key.
func (a *AliasStream) GetKey() string {
	key := a.Alias
	return key
}
// StreamAliasDB is the database model used to persist stream aliases.
type StreamAliasDB struct {
	AliasStream
	// ArgsString is the URL-encoded form of AliasStream.Args, stored in the
	// "args" column. It is filled by BeforeSave and decoded by AfterFind.
	ArgsString string    `gorm:"column:args;type:text"`
	CreatedAt  time.Time `yaml:"-"`
	UpdatedAt  time.Time `yaml:"-"`
}
// TableName tells GORM to store StreamAliasDB rows in the "stream_alias" table.
func (StreamAliasDB) TableName() string {
	const table = "stream_alias"
	return table
}
// BeforeSave is the GORM save hook. It serializes Args into ArgsString
// before the row is written.
func (db *StreamAliasDB) BeforeSave(tx *gorm.DB) error {
	// url.Values.Encode returns "" for a nil or empty map, so no branch is
	// needed to clear ArgsString when there are no args.
	db.ArgsString = db.Args.Encode()
	return nil
}
// BeforeCreate is the GORM create hook. It serializes Args by delegating to
// BeforeSave. GORM's documented hook order already runs BeforeSave before
// BeforeCreate, so this repeats an idempotent step.
func (db *StreamAliasDB) BeforeCreate(tx *gorm.DB) (err error) {
	err = db.BeforeSave(tx)
	return
}
// BeforeUpdate is the GORM update hook. It serializes Args by delegating to
// BeforeSave. GORM's documented hook order already runs BeforeSave before
// BeforeUpdate, so this repeats an idempotent step.
func (db *StreamAliasDB) BeforeUpdate(tx *gorm.DB) (err error) {
	err = db.BeforeSave(tx)
	return
}
// AfterFind is the GORM query hook. It rebuilds Args from the stored
// ArgsString. An empty or unparsable ArgsString leaves Args nil. Parse errors
// are not returned, so a malformed row does not fail the query.
func (db *StreamAliasDB) AfterFind(tx *gorm.DB) error {
	db.Args = nil
	if db.ArgsString == "" {
		return nil
	}
	if parsed, parseErr := url.ParseQuery(db.ArgsString); parseErr == nil {
		db.Args = parsed
	}
	return nil
}
// initStreamAlias loads the aliases persisted in the stream_alias table into
// s.AliasStreams. For every alias whose target stream is already published,
// it binds the alias to that publisher. Nothing happens when no database is
// configured.
//
// A failed query is logged and leaves s.AliasStreams untouched. Previously
// the error was silently dropped.
func (s *Server) initStreamAlias() {
	if s.DB == nil {
		return
	}
	var aliases []StreamAliasDB
	if err := s.DB.Find(&aliases).Error; err != nil {
		s.Error("load stream alias", "error", err)
		return
	}
	for i := range aliases {
		// Register a pointer into the slice element itself, not into a
		// per-iteration copy, so the stored alias is unambiguous whatever the
		// loop-variable semantics of the module's Go version.
		alias := &aliases[i].AliasStream
		s.AliasStreams.Add(alias)
		if publisher, ok := s.Streams.Get(alias.StreamPath); ok {
			alias.Publisher = publisher
		}
	}
}
// GetStreamAlias lists every registered alias with its target, AutoRemove
// flag and a status code:
//   - 2: a stream is published under the alias name itself.
//   - 1: the alias is bound to a publisher.
//   - 0: neither.
//
// The alias collection is read inside CallOnStreamTask.
func (s *Server) GetStreamAlias(ctx context.Context, req *emptypb.Empty) (res *pb.StreamAliasListResponse, err error) {
	res = &pb.StreamAliasListResponse{}
	s.CallOnStreamTask(func() {
		for entry := range s.AliasStreams.Range {
			item := &pb.StreamAlias{
				StreamPath: entry.StreamPath,
				Alias:      entry.Alias,
				AutoRemove: entry.AutoRemove,
			}
			switch {
			case s.Streams.Has(entry.Alias):
				item.Status = 2
			case entry.Publisher != nil:
				item.Status = 1
			}
			res.Data = append(res.Data, item)
		}
	})
	return res, err
}
// SetStreamAlias creates, updates or removes the alias req.Alias.
//
// Non-empty req.StreamPath:
//   - It is parsed as a URL. The path (leading "/" trimmed) becomes the alias
//     target and the query string becomes the alias's Args.
//   - An existing alias is updated in place. If its target changed and the new
//     target is already published, it is rebound to that publisher.
//   - Otherwise a new alias is registered, bound to the target publisher or,
//     failing that, to a stream already published under the alias name.
//   - If the target is not published yet, OnSubscribe is issued for it after
//     the alias bookkeeping.
//   - Changes are persisted when s.DB is configured.
//
// Empty req.StreamPath: the alias is removed from memory and the database. If
// the alias was bound to a publisher, one of two things happens:
//   - A stream published under the alias name receives that publisher's
//     subscribers via TransferSubscribers.
//   - Otherwise every subscriber that joined through the alias name is
//     detached, parked in s.Waiting, and OnSubscribe(req.Alias, args) is
//     issued when args are known.
//
// err is non-nil when req.StreamPath cannot be parsed as a URL.
func (s *Server) SetStreamAlias(ctx context.Context, req *pb.SetStreamAliasRequest) (res *pb.SuccessResponse, err error) {
	res = &pb.SuccessResponse{}
	s.CallOnStreamTask(func() {
		if req.StreamPath != "" {
			u, parseErr := url.Parse(req.StreamPath)
			if parseErr != nil {
				// Report the bad input. The old code shadowed err here and
				// answered success.
				err = parseErr
				return
			}
			req.StreamPath = strings.TrimPrefix(u.Path, "/")
			queryParams := u.Query()
			publisher, canReplace := s.Streams.Get(req.StreamPath)
			if !canReplace {
				// Target not published yet: request it once the alias
				// bookkeeping below is done (deferred to closure exit).
				defer s.OnSubscribe(req.StreamPath, queryParams)
			}
			if aliasInfo, ok := s.AliasStreams.Get(req.Alias); ok { // modify alias
				oldStreamPath := aliasInfo.StreamPath
				aliasInfo.AutoRemove = req.AutoRemove
				aliasInfo.Args = queryParams
				if aliasInfo.StreamPath != req.StreamPath {
					aliasInfo.StreamPath = req.StreamPath
					if canReplace {
						if aliasInfo.Publisher != nil {
							aliasInfo.TransferSubscribers(publisher) // replace stream
							aliasInfo.Publisher = publisher
						} else {
							aliasInfo.Publisher = publisher
							s.Waiting.WakeUp(req.Alias, publisher)
						}
					}
				}
				// Update the alias in the database.
				if s.DB != nil {
					dbAlias := &StreamAliasDB{
						AliasStream: *aliasInfo,
					}
					s.DB.Where("alias = ?", req.Alias).Save(dbAlias)
				}
				s.Info("modify alias", "alias", req.Alias, "oldStreamPath", oldStreamPath, "streamPath", req.StreamPath, "replace", ok && canReplace)
			} else { // create alias
				aliasInfo := AliasStream{
					AutoRemove: req.AutoRemove,
					StreamPath: req.StreamPath,
					Alias:      req.Alias,
					Args:       queryParams,
				}
				var pubID uint32
				s.AliasStreams.Add(&aliasInfo)
				aliasStream, ok := s.Streams.Get(aliasInfo.Alias)
				if canReplace {
					aliasInfo.Publisher = publisher
					if ok {
						aliasStream.TransferSubscribers(publisher) // replace stream
					} else {
						s.Waiting.WakeUp(req.Alias, publisher)
					}
				} else if ok {
					aliasInfo.Publisher = aliasStream
				}
				if aliasInfo.Publisher != nil {
					pubID = aliasInfo.Publisher.ID
				}
				// Persist the new alias.
				if s.DB != nil {
					dbAlias := &StreamAliasDB{
						AliasStream: aliasInfo,
					}
					s.DB.Create(dbAlias)
				}
				s.Info("add alias", "alias", req.Alias, "streamPath", req.StreamPath, "replace", ok && canReplace, "pub", pubID)
			}
		} else {
			s.Info("remove alias", "alias", req.Alias)
			if aliasStream, ok := s.AliasStreams.Get(req.Alias); ok {
				s.AliasStreams.Remove(aliasStream)
				// Delete the alias from the database.
				if s.DB != nil {
					s.DB.Where("alias = ?", req.Alias).Delete(&StreamAliasDB{})
				}
				if aliasStream.Publisher != nil {
					if publisher, hasTarget := s.Streams.Get(req.Alias); hasTarget { // restore stream
						aliasStream.TransferSubscribers(publisher)
					} else {
						// Collect first, so the publisher's subscriber set is not
						// modified while it is being iterated.
						var detached []*Subscriber
						for sub := range aliasStream.Publisher.SubscriberRange {
							if sub.StreamPath == req.Alias {
								detached = append(detached, sub)
							}
						}
						// Prefer the query args saved with the alias. Otherwise
						// fall back to the first alias subscriber's args.
						args := aliasStream.Args
						for _, sub := range detached {
							// Every alias subscriber must leave the target stream.
							// Previously this happened only when the alias had no
							// saved args, and only for the first subscriber.
							aliasStream.Publisher.RemoveSubscriber(sub)
							s.Waiting.Wait(sub)
							if len(args) == 0 {
								args = sub.Args
							}
						}
						if len(args) > 0 {
							s.OnSubscribe(req.Alias, args)
						}
					}
				}
			}
		}
	})
	return
}
// processAliasOnStart binds the newly started publisher p to every alias
// whose target is p.StreamPath:
//   - An unbound alias gets p, and s.Waiting is woken for the alias name.
//   - An alias still bound to a publisher of a different stream has
//     TransferSubscribers(p) called on that publisher, then is rebound to p.
func (p *Publisher) processAliasOnStart() {
	server := p.Plugin.Server
	for entry := range server.AliasStreams.Range {
		if entry.StreamPath != p.StreamPath {
			continue
		}
		switch current := entry.Publisher; {
		case current == nil:
			entry.Publisher = p
			server.Waiting.WakeUp(entry.Alias, p)
		case current.StreamPath != entry.StreamPath:
			current.TransferSubscribers(p)
			entry.Publisher = p
		}
	}
}
// processAliasOnDispose unbinds the disposing publisher p from every alias
// whose target is p.StreamPath and re-homes p's subscribers.
//
// Aliases with AutoRemove set are removed from s.AliasStreams, and deleted
// from the database when one is configured. These removals are deferred: they
// run when this function returns, after the subscriber loop below, in reverse
// order of registration.
//
// Each subscriber of p that joined through one of those aliases is attached
// to a stream published under the alias name, if one exists. Every other
// subscriber is handed to s.Waiting. p's subscriber set is then cleared.
func (p *Publisher) processAliasOnDispose() {
	s := p.Plugin.Server
	// Aliases that targeted p; used below to match subscribers by alias name.
	var relatedAlias []*AliasStream
	for alias := range s.AliasStreams.Range {
		if alias.StreamPath == p.StreamPath {
			if alias.AutoRemove {
				// Deferred to function return, presumably to avoid mutating
				// s.AliasStreams while ranging over it.
				defer s.AliasStreams.Remove(alias)
				if s.DB != nil {
					// The Where(...) receiver is evaluated now; only Delete runs
					// at function return.
					defer s.DB.Where("alias = ?", alias.Alias).Delete(&StreamAliasDB{})
				}
			}
			alias.Publisher = nil
			relatedAlias = append(relatedAlias, alias)
		}
	}
	if p.Subscribers.Length > 0 {
	SUBSCRIBER:
		for subscriber := range p.SubscriberRange {
			for _, alias := range relatedAlias {
				if subscriber.StreamPath == alias.Alias {
					// A stream published under the alias name takes this
					// subscriber back.
					if originStream, ok := s.Streams.Get(alias.Alias); ok {
						originStream.AddSubscriber(subscriber)
						continue SUBSCRIBER
					}
				}
			}
			// No live stream for it: park the subscriber (presumably until a
			// matching publisher wakes it).
			s.Waiting.Wait(subscriber)
		}
		p.Subscribers.Clear()
	}
}
// processAliasOnStart resolves s.StreamPath through the server's aliases when
// the subscriber starts.
//
// Registered alias:
//   - If it is bound to a publisher, s is attached to that publisher and
//     done is true.
//   - If it is unbound, OnSubscribe is issued for its target, using the
//     alias's saved args overlaid by the subscriber's own args, and
//     hasInvited is true.
//
// Otherwise the configured regex aliases (server.StreamAlias) are tried. The
// first one whose Replace yields a non-empty target registers a new alias
// s.StreamPath -> target, persisted when a database is configured.
//   - If the target is already published, the alias is bound to it and s is
//     attached (done).
//   - Otherwise OnSubscribe is issued for the target (hasInvited).
//
// Both results are false when no alias applies.
func (s *Subscriber) processAliasOnStart() (hasInvited bool, done bool) {
	server := s.Plugin.Server
	if alias, ok := server.AliasStreams.Get(s.StreamPath); ok {
		if alias.Publisher != nil {
			alias.Publisher.AddSubscriber(s)
			return false, true
		}
		// Merge args: start from the alias's saved args, then let the
		// subscriber's args override same-named keys. Values are copied so
		// neither source map is shared with the request.
		args := make(url.Values)
		for k, v := range alias.Args {
			args[k] = append([]string(nil), v...)
		}
		for k, v := range s.Args {
			args[k] = append([]string(nil), v...)
		}
		server.OnSubscribe(alias.StreamPath, args)
		return true, false
	}
	for reg, alias := range server.StreamAlias {
		streamPath := reg.Replace(s.StreamPath, alias)
		if streamPath == "" {
			continue
		}
		as := AliasStream{
			StreamPath: streamPath,
			Alias:      s.StreamPath,
			Args:       s.Args,
		}
		publisher, published := server.Streams.Get(streamPath)
		if published {
			// Bind the new alias to the live target. Without this, later
			// subscribers of the same alias see an unbound alias and wait for
			// a publish event that has already happened.
			as.Publisher = publisher
		}
		server.AliasStreams.Set(&as)
		if server.DB != nil {
			dbAlias := &StreamAliasDB{
				AliasStream: as,
			}
			server.DB.Where("alias = ?", s.StreamPath).Save(dbAlias)
		}
		if published {
			publisher.AddSubscriber(s)
			return false, true
		}
		server.OnSubscribe(streamPath, s.Args)
		return true, false
	}
	return false, false
}