diff --git a/pkg/error.go b/pkg/error.go index 3d63102..23c2ccf 100644 --- a/pkg/error.go +++ b/pkg/error.go @@ -9,10 +9,12 @@ var ( ErrRecordExists = errors.New("record exists") ErrKick = errors.New("kick") ErrDiscard = errors.New("discard") + ErrPublishMaxCount = errors.New("publish max count exceeded") ErrPublishTimeout = errors.New("publish timeout") ErrPublishIdleTimeout = errors.New("publish idle timeout") ErrPublishDelayCloseTimeout = errors.New("publish delay close timeout") ErrPushRemoteURLExist = errors.New("push remote url exist") + ErrSubscribeMaxCount = errors.New("subscribe max count exceeded") ErrSubscribeTimeout = errors.New("subscribe timeout") ErrRestart = errors.New("restart") ErrInterrupt = errors.New("interrupt") diff --git a/publisher.go b/publisher.go index 6cda76f..c6ea8fd 100644 --- a/publisher.go +++ b/publisher.go @@ -145,6 +145,9 @@ func (p *Publisher) Start() (err error) { return ErrStreamExist } } + if p.MaxCount > 0 && s.Streams.Length >= p.MaxCount { + return ErrPublishMaxCount + } s.Streams.Set(p) p.Info("publish") p.processPullProxyOnStart() diff --git a/subscriber.go b/subscriber.go index 45a66f1..4905397 100644 --- a/subscriber.go +++ b/subscriber.go @@ -103,7 +103,11 @@ func (s *Subscriber) waitingPublish() bool { func (s *Subscriber) Start() (err error) { server := s.Plugin.Server - server.Subscribers.Add(s) + defer func() { + if err == nil { + server.Subscribers.Add(s) + } + }() s.Info("subscribe") hasInvited, done := s.processAliasOnStart() if done { @@ -111,6 +115,9 @@ func (s *Subscriber) Start() (err error) { } if publisher, ok := server.Streams.Get(s.StreamPath); ok { + if s.MaxCount > 0 && publisher.Subscribers.Length >= s.MaxCount { + return ErrSubscribeMaxCount + } publisher.AddSubscriber(s) } else { server.Waiting.Wait(s)