diff --git a/api.go b/api.go index 5f3d328..dce8aaa 100644 --- a/api.go +++ b/api.go @@ -6,9 +6,8 @@ import ( "errors" "net" "net/http" - "regexp" + "net/url" "runtime" - "slices" "strings" "time" @@ -608,48 +607,51 @@ func (s *Server) RemoveDevice(ctx context.Context, req *pb.RequestWithId) (res * } func (s *Server) SetStreamAlias(ctx context.Context, req *pb.SetStreamAliasRequest) (res *pb.SuccessResponse, err error) { + res = &pb.SuccessResponse{} s.Streams.Call(func() error { - reg := config.Regexp{ - Regexp: regexp.MustCompile(req.Alias), - } if req.StreamPath != "" { - s.StreamAlias = append(s.StreamAlias, StreamAlias{ - Alias: reg, - Path: req.StreamPath, - AutoRemove: req.AutoRemove, - }) - for publisher := range s.Streams.Range { - if streamPath := reg.Replace(publisher.StreamPath, req.StreamPath); streamPath != "" { - if publisher2, ok := s.Streams.Get(streamPath); ok { - for subscriber := range publisher.Subscribers.Range { - publisher.RemoveSubscriber(subscriber) - subscriber.setAlias(reg, streamPath) - publisher2.AddSubscriber(subscriber) + u, err := url.Parse(req.StreamPath) + if err != nil { + return err + } + req.StreamPath = strings.TrimPrefix(u.Path, "/") + publisher, canReplace := s.Streams.Get(req.StreamPath) + if !canReplace { + defer s.OnSubscribe(req.StreamPath, u.Query()) + } + if aliasStream, ok := s.AliasStreams.Get(req.Alias); ok { //modify alias + aliasStream.AutoRemove = req.AutoRemove + if aliasStream.StreamPath != req.StreamPath { + aliasStream.StreamPath = req.StreamPath + if canReplace { + if aliasStream.Publisher != nil { + aliasStream.TransferSubscribers(publisher) // replace stream + } else { + s.Waiting.WakeUp(req.Alias, publisher) } } } - } - for waitStream := range s.Waiting.Range { - if streamPath := reg.Replace(waitStream.StreamPath, req.StreamPath); streamPath != "" { - if publisher2, ok := s.Streams.Get(streamPath); ok { - for subscriber := range waitStream.Range { - waitStream.Remove(subscriber) - subscriber.setAlias(reg, streamPath) - publisher2.AddSubscriber(subscriber) - } + } else { // create alias + s.AliasStreams.Add(&AliasStream{ + AutoRemove: req.AutoRemove, + StreamPath: req.StreamPath, + Alias: req.Alias, + }) + if canReplace { + if aliasStream, ok := s.Streams.Get(req.Alias); ok { + aliasStream.TransferSubscribers(publisher) // replace stream + } else { + s.Waiting.WakeUp(req.Alias, publisher) } } } } else { - for i, alias := range s.StreamAlias { - if alias.Alias.String() == req.Alias { - for subscriber := range s.Subscribers.Range { - if subscriber.AliasKey == alias.Alias { - subscriber.removeAlias() - } + if aliasStream, ok := s.AliasStreams.Get(req.Alias); ok { + s.AliasStreams.Remove(aliasStream) + if aliasStream.Publisher != nil { + if publisher, hasTarget := s.Streams.Get(req.Alias); hasTarget { // restore stream + aliasStream.TransferSubscribers(publisher) } - s.StreamAlias = slices.Delete(s.StreamAlias, i, i+1) - break } } } diff --git a/plugin.go b/plugin.go index 9d68d58..d49fb95 100644 --- a/plugin.go +++ b/plugin.go @@ -352,6 +352,8 @@ func (p *Plugin) OnInit() error { func (p *Plugin) OnStop() { } + +// TODO: use alias stream func (p *Plugin) OnPublish(pub *Publisher) { onPublish := p.config.OnPub if p.Meta.Pusher != nil { @@ -391,7 +393,7 @@ func (p *Plugin) OnPublish(pub *Publisher) { } } } -func (p *Plugin) OnSubscribe(sub *Subscriber) { +func (p *Plugin) OnSubscribe(streamPath string, args url.Values) { // var avoidTrans bool //AVOID: // for trans := range server.Transforms.Range { @@ -404,13 +406,13 @@ func (p *Plugin) OnSubscribe(sub *Subscriber) { // } for reg, conf := range p.config.OnSub.Pull { if p.Meta.Puller != nil { - conf.Args = sub.Args - conf.URL = reg.Replace(sub.StreamPath, conf.URL) - p.handler.Pull(sub.StreamPath, conf) + conf.Args = args + conf.URL = reg.Replace(streamPath, conf.URL) + p.handler.Pull(streamPath, conf) } } for device := range p.Server.Devices.Range { - if device.Status == DeviceStatusOnline && device.GetStreamPath() == sub.StreamPath { + if device.Status == DeviceStatusOnline && device.GetStreamPath() == streamPath { device.Handler.Pull() } } diff --git a/publisher.go b/publisher.go index e16953f..a178791 100644 --- a/publisher.go +++ b/publisher.go @@ -135,6 +135,17 @@ type Publisher struct { dumpFile *os.File } +type AliasStream struct { + *Publisher + AutoRemove bool + StreamPath string + Alias string +} + +func (a *AliasStream) GetKey() string { + return a.Alias +} + func (p *Publisher) SubscriberRange(yield func(sub *Subscriber) bool) { p.Subscribers.Range(yield) } @@ -180,12 +191,19 @@ func (p *Publisher) Start() (err error) { os.MkdirAll(filepath.Dir(f), 0666) p.dumpFile, _ = os.OpenFile(filepath.Join("./dump", p.StreamPath), os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0666) } - if waiting, ok := s.Waiting.Get(p.StreamPath); ok { - for subscriber := range waiting.Range { - p.AddSubscriber(subscriber) + + s.Waiting.WakeUp(p.StreamPath, p) + + for alias := range s.AliasStreams.Range { + if alias.StreamPath == p.StreamPath && alias.Publisher == nil { + alias.Publisher = p + s.Waiting.WakeUp(alias.Alias, p) + } else if alias.Publisher.StreamPath != alias.StreamPath { + alias.Publisher.TransferSubscribers(p) + alias.Publisher = p } - s.Waiting.Remove(waiting) } + for plugin := range s.Plugins.Range { plugin.OnPublish(p) } @@ -540,43 +558,52 @@ func (p *Publisher) Dispose() { if !p.StopReasonIs(ErrKick) { s.Streams.Remove(p) } - if p.Subscribers.Length > 0 { - w := s.createWait(p.StreamPath) - if p.HasAudioTrack() { - w.baseTsAudio = p.AudioTrack.LastTs - } - if p.HasVideoTrack() { - w.baseTsVideo = p.VideoTrack.LastTs - } - w.Info("takeOver", "pId", p.ID) - for subscriber := range p.SubscriberRange { - if subscriber.AliasStreamPath != "" { - subscriber.removeAlias() - } else { - subscriber.Publisher = nil - w.Add(subscriber) + for alias := range s.AliasStreams.Range { + if alias.Alias == p.StreamPath { + if alias.AutoRemove { + s.AliasStreams.Remove(alias) + } + for subscriber := range p.SubscriberRange { + if subscriber.StreamPath == alias.StreamPath { + if originStream, ok := s.Streams.Get(alias.StreamPath); ok { + p.Subscribers.Remove(subscriber) + originStream.AddSubscriber(subscriber) + } + } } } - if w.Length == 0 { - s.Waiting.Remove(w) + } + + if p.Subscribers.Length > 0 { + for subscriber := range p.SubscriberRange { + s.Waiting.Wait(subscriber) } - p.AudioTrack.Dispose() - p.VideoTrack.Dispose() p.Subscribers.Clear() } + p.AudioTrack.Dispose() + p.VideoTrack.Dispose() p.Info("unpublish", "remain", s.Streams.Length, "reason", p.StopReason()) if p.dumpFile != nil { p.dumpFile.Close() } p.State = PublisherStateDisposed - var remainAlias []StreamAlias - for _, alias := range s.StreamAlias { - if alias.Path == p.StreamPath && alias.AutoRemove { - continue - } - remainAlias = append(remainAlias, alias) + +} + +func (p *Publisher) TransferSubscribers(newPublisher *Publisher) { + for subscriber := range p.SubscriberRange { + newPublisher.AddSubscriber(subscriber) + } + p.Subscribers.Clear() + p.BufferTime = p.Plugin.GetCommonConf().Publish.BufferTime + p.AudioTrack.SetMinBuffer(p.BufferTime) + p.VideoTrack.SetMinBuffer(p.BufferTime) + if p.State == PublisherStateSubscribed { + p.State = PublisherStateWaitSubscriber + if p.DelayCloseTimeout > 0 { + p.TimeoutTimer.Reset(p.DelayCloseTimeout) + } } - s.StreamAlias = remainAlias } func (p *Publisher) takeOver(old *Publisher) { diff --git a/server.go b/server.go index ec32403..0b7e801 100644 --- a/server.go +++ b/server.go @@ -73,10 +73,8 @@ type ( } } WaitStream struct { - *slog.Logger StreamPath string SubscriberCollection - baseTsAudio, baseTsVideo time.Duration } Server struct { pb.UnimplementedApiServer @@ -84,7 +82,8 @@ type ( ServerConfig Plugins util.Collection[string, *Plugin] Streams task.Manager[string, *Publisher] - Waiting util.Collection[string, *WaitStream] + AliasStreams util.Collection[string, *AliasStream] + Waiting WaitManager Pulls task.Manager[string, *PullJob] Pushs task.Manager[string, *PushJob] Records task.Manager[string, *RecordJob] @@ -185,6 +184,7 @@ func (s *Server) Start() (err error) { s.LogHandler.SetLevel(slog.LevelDebug) s.LogHandler.Add(defaultLogHandler) s.Logger = slog.New(&s.LogHandler).With("server", s.ID) + s.Waiting.Logger = s.Logger mux := runtime.NewServeMux(runtime.WithMarshalerOption("text/plain", &pb.TextPlain{}), runtime.WithForwardResponseOption(func(ctx context.Context, w http.ResponseWriter, m proto.Message) error { header := w.Header() header.Set("Access-Control-Allow-Credentials", "true") @@ -378,16 +378,7 @@ func (c *CheckSubWaitTimeout) Tick(any) { c.Info("tick", "cpu", cpu, "streams", c.s.Streams.Length, "subscribers", c.s.Subscribers.Length, "waits", c.s.Waiting.Length) } } - - for waits := range c.s.Waiting.Range { - for sub := range waits.Range { - select { - case <-sub.TimeoutTimer.C: - sub.Stop(ErrSubscribeTimeout) - default: - } - } - } + c.s.Waiting.checkTimeout() } func (gRPC *GRPCServer) Dispose() { @@ -412,14 +403,10 @@ func (s *Server) Dispose() { } } -func (s *Server) createWait(streamPath string) *WaitStream { - newPublisher := &WaitStream{ - StreamPath: streamPath, - Logger: s.Logger.With("streamPath", streamPath), +func (s *Server) OnSubscribe(streamPath string, args url.Values) { + for plugin := range s.Plugins.Range { + plugin.OnSubscribe(streamPath, args) } - s.Info("createWait") - s.Waiting.Set(newPublisher) - return newPublisher } func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { diff --git a/subscriber.go b/subscriber.go index 1beb111..5cf677d 100644 --- a/subscriber.go +++ b/subscriber.go @@ -62,8 +62,6 @@ type SubscriberCollection = util.Collection[uint32, *Subscriber] type Subscriber struct { PubSubBase config.Subscribe - AliasStreamPath string - AliasKey config.Regexp Publisher *Publisher waitPublishDone *util.Promise AudioReader, VideoReader *AVRingReader @@ -83,67 +81,41 @@ func createSubscriber(p *Plugin, streamPath string, conf config.Subscribe) *Subs return subscriber } -func (s *Subscriber) setAlias(key config.Regexp, streamPath string) { - s.AliasKey = key - s.AliasStreamPath = s.StreamPath - s.StreamPath = streamPath - s.SetDescription("streamPath", streamPath) - s.SetDescription("alias", s.AliasStreamPath) -} - -func (s *Subscriber) removeAlias() { - server := s.Plugin.Server - if s.Publisher != nil { - s.Publisher.RemoveSubscriber(s) - } else { - if waitStream, ok := server.Waiting.Get(s.StreamPath); ok { - waitStream.Remove(s) - } - } - s.StreamPath = s.AliasStreamPath - s.AliasStreamPath = "" - s.AliasKey = config.Regexp{} - s.RemoveDescription("alias") - s.SetDescription("streamPath", s.StreamPath) - if publisher, ok := server.Streams.Get(s.StreamPath); ok { - publisher.AddSubscriber(s) - return - } else { - if waitStream, ok := server.Waiting.Get(s.StreamPath); ok { - waitStream.Add(s) - } else { - server.createWait(s.StreamPath).Add(s) - } - for plugin := range server.Plugins.Range { - plugin.OnSubscribe(s) - } - } -} - func (s *Subscriber) Start() (err error) { server := s.Plugin.Server server.Subscribers.Add(s) s.Info("subscribe") - for _, alias := range server.StreamAlias { - if streamPath := alias.Alias.Replace(s.StreamPath, alias.Path); streamPath != "" { - s.setAlias(alias.Alias, streamPath) - break + if alias, ok := server.AliasStreams.Get(s.StreamPath); ok { + if alias.Publisher != nil { + alias.Publisher.AddSubscriber(s) + return + } else { + server.OnSubscribe(alias.StreamPath, s.Args) + } + } else { + for _, alias := range server.StreamAlias { + if streamPath := alias.Alias.Replace(s.StreamPath, alias.Path); streamPath != "" { + server.AliasStreams.Set(&AliasStream{ + StreamPath: streamPath, + Alias: s.StreamPath, + }) + if publisher, ok := server.Streams.Get(streamPath); ok { + publisher.AddSubscriber(s) + return + } else { + server.OnSubscribe(streamPath, s.Args) + } + break + } } } - if publisher, ok := server.Streams.Get(s.StreamPath); ok { publisher.AddSubscriber(s) return } else { - if waitStream, ok := server.Waiting.Get(s.StreamPath); ok { - waitStream.Add(s) - } else { - server.createWait(s.StreamPath).Add(s) - } - for plugin := range server.Plugins.Range { - plugin.OnSubscribe(s) - } + server.Waiting.Wait(s) + server.OnSubscribe(s.StreamPath, s.Args) } return } @@ -153,8 +125,8 @@ func (s *Subscriber) Dispose() { s.Info("unsubscribe", "reason", s.StopReason()) if s.Publisher != nil { s.Publisher.RemoveSubscriber(s) - } else if waitStream, ok := s.Plugin.Server.Waiting.Get(s.StreamPath); ok { - waitStream.Remove(s) + } else { + s.Plugin.Server.Waiting.Leave(s) } } diff --git a/wait-stream.go b/wait-stream.go new file mode 100644 index 0000000..309dde5 --- /dev/null +++ b/wait-stream.go @@ -0,0 +1,55 @@ +package m7s + +import ( + "log/slog" + + . "m7s.live/m7s/v5/pkg" + "m7s.live/m7s/v5/pkg/util" +) + +type WaitManager struct { + *slog.Logger + util.Collection[string, *WaitStream] +} + +func (w *WaitManager) Wait(subscriber *Subscriber) *WaitStream { + subscriber.Publisher = nil + if waiting, ok := w.Get(subscriber.StreamPath); ok { + waiting.Add(subscriber) + return waiting + } else { + waiting := &WaitStream{ + StreamPath: subscriber.StreamPath, + } + w.Set(waiting) + waiting.Add(subscriber) + return waiting + } +} + +func (w *WaitManager) WakeUp(streamPath string, publisher *Publisher) { + if waiting, ok := w.Get(streamPath); ok { + for subscriber := range waiting.Range { + publisher.AddSubscriber(subscriber) + } + w.Remove(waiting) + } +} + +func (w *WaitManager) checkTimeout() { + for waits := range w.Range { + for sub := range waits.Range { + select { + case <-sub.TimeoutTimer.C: + sub.Stop(ErrSubscribeTimeout) + default: + } + } + } +} + +func (w *WaitManager) Leave(s *Subscriber) { + if waitStream, ok := w.Get(s.StreamPath); ok { + waitStream.Remove(s) + } +}