feat: dynamic select

This commit is contained in:
langhuihui
2024-03-29 17:15:17 +08:00
parent d286f1b6e1
commit 8c4865b434
9 changed files with 234 additions and 131 deletions

View File

@@ -4,6 +4,7 @@ import (
"crypto/tls" "crypto/tls"
_ "embed" _ "embed"
"net" "net"
"runtime"
"time" "time"
) )
@@ -21,9 +22,56 @@ type TCP struct {
ListenNum int `desc:"同时并行监听数量0为CPU核心数量"` //同时并行监听数量0为CPU核心数量 ListenNum int `desc:"同时并行监听数量0为CPU核心数量"` //同时并行监听数量0为CPU核心数量
NoDelay bool `desc:"是否禁用Nagle算法"` //是否禁用Nagle算法 NoDelay bool `desc:"是否禁用Nagle算法"` //是否禁用Nagle算法
KeepAlive bool `desc:"是否启用KeepAlive"` //是否启用KeepAlive KeepAlive bool `desc:"是否启用KeepAlive"` //是否启用KeepAlive
listener net.Listener
listenerTls net.Listener
} }
func (tcp *TCP) Listen(l net.Listener, handler func(*net.TCPConn)) { func (tcp *TCP) StopListen() {
if tcp.listener != nil {
tcp.listener.Close()
}
if tcp.listenerTls != nil {
tcp.listenerTls.Close()
}
}
func (tcp *TCP) Listen(handler func(*net.TCPConn)) (err error) {
tcp.listener, err = net.Listen("tcp", tcp.ListenAddr)
if err == nil {
count := tcp.ListenNum
if count == 0 {
count = runtime.NumCPU()
}
for range count {
tcp.listen(tcp.listener, handler)
}
}
return
}
func (tcp *TCP) ListenTLS(handler func(*net.TCPConn)) (err error) {
keyPair, _ := tls.X509KeyPair(LocalCert, LocalKey)
if tcp.CertFile != "" || tcp.KeyFile != "" {
keyPair, err = tls.LoadX509KeyPair(tcp.CertFile, tcp.KeyFile)
}
if err == nil {
tcp.listenerTls, err = tls.Listen("tcp", tcp.ListenAddrTLS, &tls.Config{
Certificates: []tls.Certificate{keyPair},
})
if err == nil {
count := tcp.ListenNum
if count == 0 {
count = runtime.NumCPU()
}
for range count {
tcp.listen(tcp.listenerTls, handler)
}
}
}
return
}
func (tcp *TCP) listen(l net.Listener, handler func(*net.TCPConn)) {
var tempDelay time.Duration var tempDelay time.Duration
for { for {
conn, err := l.Accept() conn, err := l.Accept()

View File

@@ -3,9 +3,11 @@ package pkg
import ( import (
"context" "context"
"log/slog" "log/slog"
"time"
) )
type Unit struct { type Unit struct {
StartTime time.Time
*slog.Logger `json:"-" yaml:"-"` *slog.Logger `json:"-" yaml:"-"`
context.Context `json:"-" yaml:"-"` context.Context `json:"-" yaml:"-"`
context.CancelCauseFunc `json:"-" yaml:"-"` context.CancelCauseFunc `json:"-" yaml:"-"`

View File

@@ -31,12 +31,12 @@ type IPool[T any] interface {
Clear() Clear()
} }
type RecyclebleMemory struct { type RecyclableMemory struct {
IPool[[]byte] IPool[[]byte]
Data net.Buffers Data net.Buffers
} }
func (r *RecyclebleMemory) Recycle() { func (r *RecyclableMemory) Recycle() {
if r.IPool != nil { if r.IPool != nil {
for _, b := range r.Data { for _, b := range r.Data {
r.Put(b) r.Put(b)

103
plugin.go
View File

@@ -2,17 +2,14 @@ package m7s
import ( import (
"context" "context"
"crypto/tls"
"net" "net"
"os" "os"
"path/filepath" "path/filepath"
"reflect" "reflect"
"runtime" "runtime"
"slices"
"strings" "strings"
"sync" "sync"
"github.com/logrusorgru/aurora/v4"
"github.com/mcuadros/go-defaults" "github.com/mcuadros/go-defaults"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
. "m7s.live/m7s/v5/pkg" . "m7s.live/m7s/v5/pkg"
@@ -36,7 +33,6 @@ func (plugin *PluginMeta) Init(s *Server, userConfig map[string]any) {
p.handler = instance p.handler = instance
p.Meta = plugin p.Meta = plugin
p.server = s p.server = s
p.eventChan = make(chan any, 10)
p.Logger = s.Logger.With("plugin", plugin.Name) p.Logger = s.Logger.With("plugin", plugin.Name)
p.Context, p.CancelCauseFunc = context.WithCancelCause(s.Context) p.Context, p.CancelCauseFunc = context.WithCancelCause(s.Context)
s.Plugins = append(s.Plugins, p) s.Plugins = append(s.Plugins, p)
@@ -128,10 +124,9 @@ func sendPromiseToServer[T any](server *Server, value T) (err error) {
type Plugin struct { type Plugin struct {
Unit Unit
Disabled bool Disabled bool
Meta *PluginMeta Meta *PluginMeta
eventChan chan any config config.Common
config config.Common
config.Config config.Config
Publishers []*Publisher Publishers []*Publisher
handler IPlugin handler IPlugin
@@ -161,19 +156,23 @@ func (p *Plugin) assign() {
// p.registerHandler() // p.registerHandler()
} }
func (p *Plugin) Stop(err error) {
p.Unit.Stop(err)
p.config.HTTP.StopListen()
p.config.TCP.StopListen()
}
func (p *Plugin) Start() { func (p *Plugin) Start() {
var err error
httpConf := p.config.HTTP httpConf := p.config.HTTP
defer httpConf.StopListen()
if httpConf.ListenAddrTLS != "" && (httpConf.ListenAddrTLS != p.server.config.HTTP.ListenAddrTLS) { if httpConf.ListenAddrTLS != "" && (httpConf.ListenAddrTLS != p.server.config.HTTP.ListenAddrTLS) {
go func() { go func() {
p.Info("https listen at ", "addr", aurora.Blink(httpConf.ListenAddrTLS)) p.Info("https listen at ", "addr", httpConf.ListenAddrTLS)
p.Stop(httpConf.ListenTLS()) p.Stop(httpConf.ListenTLS())
}() }()
} }
if httpConf.ListenAddr != "" && (httpConf.ListenAddr != p.server.config.HTTP.ListenAddr) { if httpConf.ListenAddr != "" && (httpConf.ListenAddr != p.server.config.HTTP.ListenAddr) {
go func() { go func() {
p.Info("http listen at ", "addr", aurora.Blink(httpConf.ListenAddr)) p.Info("http listen at ", "addr", httpConf.ListenAddr)
p.Stop(httpConf.Listen()) p.Stop(httpConf.Listen())
}() }()
} }
@@ -182,83 +181,39 @@ func (p *Plugin) Start() {
if !ok { if !ok {
tcphandler = p tcphandler = p
} }
count := p.config.TCP.ListenNum
if count == 0 {
count = runtime.NumCPU()
}
if p.config.TCP.ListenAddr != "" { if p.config.TCP.ListenAddr != "" {
l, err := net.Listen("tcp", tcpConf.ListenAddr) p.Info("listen tcp", "addr", tcpConf.ListenAddr)
err := tcpConf.Listen(tcphandler.OnTCPConnect)
if err != nil { if err != nil {
p.Error("listen tcp", "addr", tcpConf.ListenAddr, "error", err) p.Error("listen tcp", "addr", tcpConf.ListenAddr, "error", err)
p.Stop(err) p.Stop(err)
return return
} }
defer l.Close()
p.Info("listen tcp", "addr", tcpConf.ListenAddr)
for i := 0; i < count; i++ {
go tcpConf.Listen(l, tcphandler.OnTCPConnect)
}
} }
if tcpConf.ListenAddrTLS != "" { if tcpConf.ListenAddrTLS != "" {
keyPair, _ := tls.X509KeyPair(config.LocalCert, config.LocalKey) p.Info("listen tcp tls", "addr", tcpConf.ListenAddrTLS)
if tcpConf.CertFile != "" || tcpConf.KeyFile != "" { err := tcpConf.ListenTLS(tcphandler.OnTCPConnect)
keyPair, err = tls.LoadX509KeyPair(tcpConf.CertFile, tcpConf.KeyFile)
}
if err != nil {
p.Error("LoadX509KeyPair", "error", err)
p.Stop(err)
return
}
l, err := tls.Listen("tcp", tcpConf.ListenAddrTLS, &tls.Config{
Certificates: []tls.Certificate{keyPair},
})
if err != nil { if err != nil {
p.Error("listen tcp tls", "addr", tcpConf.ListenAddrTLS, "error", err) p.Error("listen tcp tls", "addr", tcpConf.ListenAddrTLS, "error", err)
p.Stop(err) p.Stop(err)
return return
} }
defer l.Close() }
p.Info("listen tcp tls", "addr", tcpConf.ListenAddrTLS) }
for i := 0; i < count; i++ {
go tcpConf.Listen(l, tcphandler.OnTCPConnect) func (p *Plugin) OnInit() {
}
} }
for {
select { func (p *Plugin) onEvent(event any) {
case event := <-p.eventChan: switch v := event.(type) {
// switch event.(type) { case *Publisher:
// case *Subscriber: if h, ok := p.handler.(interface{ OnPublish(*Publisher) }); ok {
// } h.OnPublish(v)
p.handler.OnEvent(event)
case <-p.Done():
return
default:
for i := 0; i < len(p.Publishers); i++ {
publisher := p.Publishers[i]
select {
case <-publisher.Done():
if publisher.Closer != nil {
publisher.Closer.Close()
}
p.Publishers = slices.Delete(p.Publishers, i, i+1)
i--
p.server.eventChan <- UnpublishEvent{Publisher: publisher}
case <-publisher.TimeoutTimer.C:
if err := publisher.timeout(); err != nil {
publisher.Stop(err)
}
default:
for subscriber := range publisher.Subscribers {
select {
case <-subscriber.Done():
subscriber.Publisher.RemoveSubscriber(subscriber)
default:
}
}
}
}
} }
} }
p.handler.OnEvent(event)
} }
func (p *Plugin) OnEvent(event any) { func (p *Plugin) OnEvent(event any) {

View File

@@ -1,22 +1,68 @@
package demo package demo
import ( import (
"time"
"m7s.live/m7s/v5" "m7s.live/m7s/v5"
"m7s.live/m7s/v5/pkg"
"m7s.live/m7s/v5/pkg/util"
rtmp "m7s.live/m7s/v5/plugin/rtmp/pkg"
) )
type AnnexB struct {
PTS time.Duration
DTS time.Duration
util.RecyclableMemory
}
// DecodeConfig implements pkg.IAVFrame.
func (a *AnnexB) DecodeConfig(*pkg.AVTrack) error {
panic("unimplemented")
}
// FromRaw implements pkg.IAVFrame.
func (a *AnnexB) FromRaw(t *pkg.AVTrack, raw any) error {
var nalus = raw.(pkg.Nalus)
a.PTS = nalus.PTS
a.DTS = nalus.DTS
return nil
}
// GetTimestamp implements pkg.IAVFrame.
func (a *AnnexB) GetTimestamp() time.Duration {
return a.DTS / 90
}
// IsIDR implements pkg.IAVFrame.
func (a *AnnexB) IsIDR() bool {
return false
}
// ToRaw implements pkg.IAVFrame.
func (a *AnnexB) ToRaw(*pkg.AVTrack) (any, error) {
return a.Data, nil
}
type DemoPlugin struct { type DemoPlugin struct {
m7s.Plugin m7s.Plugin
} }
func (p *DemoPlugin) OnInit() { func (p *DemoPlugin) OnInit() {
// puber, err := p.Publish("live/demo") publisher, err := p.Publish("live/demo")
// if err != nil { if err == nil {
// return var annexB AnnexB
// } publisher.WriteVideo(&annexB)
// puber.WriteVideo(&rtmp.RTMPVideo{ }
// Timestamp: 0, }
// Buffers: net.Buffers{[]byte{0x17, 0x00, 0x67, 0x42, 0x00, 0x0a, 0x8f, 0x14, 0x01, 0x00, 0x00, 0x03, 0x00, 0x80, 0x00, 0x00, 0x00, 0x01, 0x68, 0xce, 0x3c, 0x80}},
// }) func (p *DemoPlugin) OnPublish(publisher *m7s.Publisher) {
subscriber, err := p.Subscribe(publisher.StreamPath)
if err == nil {
go subscriber.Handle(nil, func(v *rtmp.RTMPVideo) {
})
}
} }
var _ = m7s.InstallPlugin[*DemoPlugin]() var _ = m7s.InstallPlugin[*DemoPlugin]()

View File

@@ -21,7 +21,7 @@ var FourCC_AV1 = [4]byte{'a', 'v', '0', '1'}
type RTMPData struct { type RTMPData struct {
Timestamp uint32 Timestamp uint32
util.Buffers util.Buffers
util.RecyclebleMemory util.RecyclableMemory
} }
func (avcc *RTMPData) GetTimestamp() time.Duration { func (avcc *RTMPData) GetTimestamp() time.Duration {

View File

@@ -18,10 +18,6 @@ const (
PublisherStateWaitSubscriber PublisherStateWaitSubscriber
) )
type UnpublishEvent struct {
*Publisher
}
type Publisher struct { type Publisher struct {
PubSubBase PubSubBase
sync.RWMutex sync.RWMutex
@@ -53,11 +49,16 @@ func (p *Publisher) timeout() (err error) {
} }
func (p *Publisher) checkTimeout() (err error) { func (p *Publisher) checkTimeout() (err error) {
if p.VideoTrack != nil && !p.VideoTrack.LastValue.WriteTime.IsZero() && time.Since(p.VideoTrack.LastValue.WriteTime) > p.PublishTimeout { select {
err = ErrPublishTimeout case <-p.TimeoutTimer.C:
} err = p.timeout()
if p.AudioTrack != nil && !p.AudioTrack.LastValue.WriteTime.IsZero() && time.Since(p.AudioTrack.LastValue.WriteTime) > p.PublishTimeout { default:
err = ErrPublishTimeout if p.VideoTrack != nil && !p.VideoTrack.LastValue.WriteTime.IsZero() && time.Since(p.VideoTrack.LastValue.WriteTime) > p.PublishTimeout {
err = ErrPublishTimeout
}
if p.AudioTrack != nil && !p.AudioTrack.LastValue.WriteTime.IsZero() && time.Since(p.AudioTrack.LastValue.WriteTime) > p.PublishTimeout {
err = ErrPublishTimeout
}
} }
return return
} }
@@ -66,9 +67,6 @@ func (p *Publisher) RemoveSubscriber(subscriber *Subscriber) (err error) {
p.Lock() p.Lock()
defer p.Unlock() defer p.Unlock()
delete(p.Subscribers, subscriber) delete(p.Subscribers, subscriber)
if subscriber.Closer != nil {
err = subscriber.Closer.Close()
}
if p.State == PublisherStateSubscribed && len(p.Subscribers) == 0 { if p.State == PublisherStateSubscribed && len(p.Subscribers) == 0 {
p.State = PublisherStateWaitSubscriber p.State = PublisherStateWaitSubscriber
if p.DelayCloseTimeout > 0 { if p.DelayCloseTimeout > 0 {
@@ -82,7 +80,6 @@ func (p *Publisher) AddSubscriber(subscriber *Subscriber) (err error) {
p.Lock() p.Lock()
defer p.Unlock() defer p.Unlock()
p.Subscribers[subscriber] = struct{}{} p.Subscribers[subscriber] = struct{}{}
subscriber.Publisher = p
switch p.State { switch p.State {
case PublisherStateTrackAdded, PublisherStateWaitSubscriber: case PublisherStateTrackAdded, PublisherStateWaitSubscriber:
p.State = PublisherStateSubscribed p.State = PublisherStateSubscribed
@@ -154,6 +151,11 @@ func (p *Publisher) WriteAudio(data IAVFrame) (err error) {
p.Lock() p.Lock()
p.AudioTrack = t p.AudioTrack = t
p.TransTrack[reflect.TypeOf(data)] = t p.TransTrack[reflect.TypeOf(data)] = t
if len(p.Subscribers) > 0 {
p.State = PublisherStateSubscribed
} else {
p.State = PublisherStateTrackAdded
}
p.Unlock() p.Unlock()
} }
if t.ICodecCtx == nil { if t.ICodecCtx == nil {

107
server.go
View File

@@ -6,6 +6,7 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"reflect" "reflect"
"slices"
"strings" "strings"
"sync/atomic" "sync/atomic"
"time" "time"
@@ -28,20 +29,23 @@ var (
type Server struct { type Server struct {
Plugin Plugin
config.Engine config.Engine
StartTime time.Time eventChan chan any
Plugins []*Plugin Plugins []*Plugin
Publishers map[string]*Publisher Streams map[string]*Publisher
Waiting map[string][]*Subscriber Waiting map[string][]*Subscriber
pidG int Publishers []*Publisher
sidG int Subscribers []*Subscriber
pidG int
sidG int
} }
var DefaultServer = NewServer() var DefaultServer = NewServer()
func NewServer() *Server { func NewServer() *Server {
return &Server{ return &Server{
Publishers: make(map[string]*Publisher), Streams: make(map[string]*Publisher),
Waiting: make(map[string][]*Subscriber), Waiting: make(map[string][]*Subscriber),
eventChan: make(chan any, 10),
} }
} }
@@ -54,7 +58,6 @@ func (s *Server) Run(ctx context.Context, conf any) (err error) {
s.Context, s.CancelCauseFunc = context.WithCancelCause(ctx) s.Context, s.CancelCauseFunc = context.WithCancelCause(ctx)
s.config.HTTP.ListenAddrTLS = ":8443" s.config.HTTP.ListenAddrTLS = ":8443"
s.config.HTTP.ListenAddr = ":8080" s.config.HTTP.ListenAddr = ":8080"
s.eventChan = make(chan any, 10)
s.Info("start") s.Info("start")
var cg map[string]map[string]any var cg map[string]map[string]any
@@ -87,17 +90,37 @@ func (s *Server) Run(ctx context.Context, conf any) (err error) {
var lv slog.LevelVar var lv slog.LevelVar
lv.UnmarshalText([]byte(s.LogLevel)) lv.UnmarshalText([]byte(s.LogLevel))
slog.SetLogLoggerLevel(lv.Level()) slog.SetLogLoggerLevel(lv.Level())
s.initPlugins(cg) for _, plugin := range plugins {
plugin.Init(s, cg[strings.ToLower(plugin.Name)])
}
s.eventLoop()
s.Warn("Server is done", "reason", context.Cause(s))
for _, publisher := range s.Publishers {
publisher.Stop(nil)
}
for _, subscriber := range s.Subscribers {
subscriber.Stop(nil)
}
for _, p := range s.Plugins {
p.Stop(nil)
}
return
}
func (s *Server) eventLoop() {
pulse := time.NewTicker(s.PulseInterval) pulse := time.NewTicker(s.PulseInterval)
defer pulse.Stop()
cases := []reflect.SelectCase{{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(s.Done())}, {Dir: reflect.SelectRecv, Chan: reflect.ValueOf(pulse.C)}, {Dir: reflect.SelectRecv, Chan: reflect.ValueOf(s.eventChan)}}
var pubCount, subCount int
for { for {
select { switch chosen, rev, _ := reflect.Select(cases); chosen {
case <-s.Done(): case 0:
s.Warn("Server is done", "reason", context.Cause(s))
pulse.Stop()
return return
case <-pulse.C: case 1:
for _, publisher := range s.Publishers { for _, publisher := range s.Streams {
publisher.checkTimeout() if err := publisher.checkTimeout(); err != nil {
publisher.Stop(err)
}
} }
for subscriber := range s.Waiting { for subscriber := range s.Waiting {
for _, sub := range s.Waiting[subscriber] { for _, sub := range s.Waiting[subscriber] {
@@ -108,40 +131,59 @@ func (s *Server) Run(ctx context.Context, conf any) (err error) {
} }
} }
} }
case event := <-s.eventChan: case 2:
event := rev.Interface()
switch v := event.(type) { switch v := event.(type) {
case *util.Promise[*Publisher]: case *util.Promise[*Publisher]:
v.Fulfill(s.OnPublish(v.Value)) v.Fulfill(s.OnPublish(v.Value))
event = v.Value event = v.Value
if nl := len(s.Publishers); nl > pubCount {
pubCount = nl
if subCount == 0 {
cases = append(cases, reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(v.Value.Done())})
} else {
cases = slices.Insert(cases, 3+pubCount, reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(v.Value.Done())})
}
}
case *util.Promise[*Subscriber]: case *util.Promise[*Subscriber]:
v.Fulfill(s.OnSubscribe(v.Value)) v.Fulfill(s.OnSubscribe(v.Value))
if nl := len(s.Subscribers); nl > subCount {
subCount = nl
cases = append(cases, reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(v.Value.Done())})
}
if !s.EnableSubEvent { if !s.EnableSubEvent {
continue continue
} }
event = v.Value event = v.Value
case UnpublishEvent:
s.onUnpublish(v.Publisher)
case UnsubscribeEvent:
} }
for _, plugin := range s.Plugins { for _, plugin := range s.Plugins {
if plugin.Disabled { if plugin.Disabled {
continue continue
} }
plugin.handler.OnEvent(event) plugin.onEvent(event)
} }
default:
if subStart := 3 + pubCount; chosen < subStart {
s.onUnpublish(s.Publishers[chosen-3])
pubCount--
s.Publishers = slices.Delete(s.Publishers, chosen-3, chosen-2)
} else {
i := chosen - subStart
s.onUnsubscribe(s.Subscribers[i])
subCount--
s.Subscribers = slices.Delete(s.Subscribers, i, i+1)
}
cases = slices.Delete(cases, chosen, chosen+1)
} }
} }
} }
func (s *Server) initPlugins(cg map[string]map[string]any) { func (s *Server) onUnsubscribe(subscriber *Subscriber) {
for _, plugin := range plugins { subscriber.Publisher.RemoveSubscriber(subscriber)
plugin.Init(s, cg[strings.ToLower(plugin.Name)])
}
} }
func (s *Server) onUnpublish(publisher *Publisher) { func (s *Server) onUnpublish(publisher *Publisher) {
delete(s.Publishers, publisher.StreamPath) delete(s.Streams, publisher.StreamPath)
for subscriber := range publisher.Subscribers { for subscriber := range publisher.Subscribers {
s.Waiting[publisher.StreamPath] = append(s.Waiting[publisher.StreamPath], subscriber) s.Waiting[publisher.StreamPath] = append(s.Waiting[publisher.StreamPath], subscriber)
subscriber.TimeoutTimer.Reset(publisher.WaitCloseTimeout) subscriber.TimeoutTimer.Reset(publisher.WaitCloseTimeout)
@@ -149,7 +191,7 @@ func (s *Server) onUnpublish(publisher *Publisher) {
} }
func (s *Server) OnPublish(publisher *Publisher) error { func (s *Server) OnPublish(publisher *Publisher) error {
if oldPublisher, ok := s.Publishers[publisher.StreamPath]; ok { if oldPublisher, ok := s.Streams[publisher.StreamPath]; ok {
if publisher.KickExist { if publisher.KickExist {
publisher.Warn("kick") publisher.Warn("kick")
oldPublisher.Stop(ErrKick) oldPublisher.Stop(ErrKick)
@@ -166,13 +208,13 @@ func (s *Server) OnPublish(publisher *Publisher) error {
publisher.Subscribers = make(map[*Subscriber]struct{}) publisher.Subscribers = make(map[*Subscriber]struct{})
publisher.TransTrack = make(map[reflect.Type]*AVTrack) publisher.TransTrack = make(map[reflect.Type]*AVTrack)
} }
s.Publishers[publisher.StreamPath] = publisher s.Streams[publisher.StreamPath] = publisher
s.Publishers = append(s.Publishers, publisher)
s.pidG++ s.pidG++
p := publisher.Plugin p := publisher.Plugin
publisher.ID = s.pidG publisher.ID = s.pidG
publisher.Logger = p.With("streamPath", publisher.StreamPath, "puber", publisher.ID) publisher.Logger = p.With("streamPath", publisher.StreamPath, "puber", publisher.ID)
publisher.TimeoutTimer = time.NewTimer(p.config.PublishTimeout) publisher.TimeoutTimer = time.NewTimer(p.config.PublishTimeout)
p.Publishers = append(p.Publishers, publisher)
publisher.Info("publish") publisher.Info("publish")
if subscribers, ok := s.Waiting[publisher.StreamPath]; ok { if subscribers, ok := s.Waiting[publisher.StreamPath]; ok {
for _, subscriber := range subscribers { for _, subscriber := range subscribers {
@@ -188,8 +230,9 @@ func (s *Server) OnSubscribe(subscriber *Subscriber) error {
subscriber.ID = s.sidG subscriber.ID = s.sidG
subscriber.Logger = subscriber.Plugin.With("streamPath", subscriber.StreamPath, "suber", subscriber.ID) subscriber.Logger = subscriber.Plugin.With("streamPath", subscriber.StreamPath, "suber", subscriber.ID)
subscriber.TimeoutTimer = time.NewTimer(subscriber.Plugin.config.Subscribe.WaitTimeout) subscriber.TimeoutTimer = time.NewTimer(subscriber.Plugin.config.Subscribe.WaitTimeout)
s.Subscribers = append(s.Subscribers, subscriber)
subscriber.Info("subscribe") subscriber.Info("subscribe")
if publisher, ok := s.Publishers[subscriber.StreamPath]; ok { if publisher, ok := s.Streams[subscriber.StreamPath]; ok {
return publisher.AddSubscriber(subscriber) return publisher.AddSubscriber(subscriber)
} else { } else {
s.Waiting[subscriber.StreamPath] = append(s.Waiting[subscriber.StreamPath], subscriber) s.Waiting[subscriber.StreamPath] = append(s.Waiting[subscriber.StreamPath], subscriber)

View File

@@ -24,6 +24,13 @@ type PubSubBase struct {
io.Closer io.Closer
} }
func (ps *PubSubBase) Stop(reason error) {
ps.Unit.Stop(reason)
if ps.Closer != nil {
ps.Closer.Close()
}
}
func (ps *PubSubBase) Init(p *Plugin, streamPath string, options ...any) { func (ps *PubSubBase) Init(p *Plugin, streamPath string, options ...any) {
ps.Plugin = p ps.Plugin = p
ctx := p.Context ctx := p.Context