feat: Publisher Elegant Exit

This commit is contained in:
langhuihui
2024-03-28 20:39:50 +08:00
parent ad13914b44
commit d286f1b6e1
14 changed files with 293 additions and 161 deletions

View File

@@ -1,5 +1,5 @@
global:
loglevel: debug
loglevel: info
rtmp:
publish:
# pubvideo: false

View File

@@ -3,7 +3,11 @@ package pkg
import "errors"
var (
ErrStreamExist = errors.New("stream exist")
ErrKick = errors.New("kick")
ErrDiscard = errors.New("discard")
ErrStreamExist = errors.New("stream exist")
ErrKick = errors.New("kick")
ErrDiscard = errors.New("discard")
ErrPublishTimeout = errors.New("publish timeout")
ErrPublishIdleTimeout = errors.New("publish idle timeout")
ErrPublishDelayCloseTimeout = errors.New("publish delay close timeout")
ErrSubscribeTimeout = errors.New("subscribe timeout")
)

View File

@@ -1,17 +0,0 @@
package pkg
// EventBus is a simple event bus
type EventBus chan any
// NewEventBus creates a new EventBus
func NewEventBus(size int) EventBus {
return make(chan any, size)
}
// // Publish publishes an event
// func (e *EventBus) Publish(event any) {
// }
// // Subscribe subscribes to an event
// func (e *EventBus) Subscribe(event any, handler func(event any)) {
// }

7
pkg/event.go Normal file
View File

@@ -0,0 +1,7 @@
package pkg
type Event[T any] struct {
Type string
Data T
}

26
pkg/unit.go Normal file
View File

@@ -0,0 +1,26 @@
package pkg
import (
"context"
"log/slog"
)
type Unit struct {
*slog.Logger `json:"-" yaml:"-"`
context.Context `json:"-" yaml:"-"`
context.CancelCauseFunc `json:"-" yaml:"-"`
}
func (unit *Unit) IsStopped() bool {
select {
case <-unit.Done():
return true
default:
}
return false
}
func (unit *Unit) Stop(err error) {
unit.Info("stop", "reason", err.Error())
unit.CancelCauseFunc(err)
}

View File

@@ -3,12 +3,12 @@ package m7s
import (
"context"
"crypto/tls"
"log/slog"
"net"
"os"
"path/filepath"
"reflect"
"runtime"
"slices"
"strings"
"sync"
@@ -84,11 +84,6 @@ type IPlugin interface {
OnInit()
OnEvent(any)
}
type IPublishPlugin interface {
OnStopPublish(*Publisher, error)
}
type ITCPPlugin interface {
OnTCPConnect(*net.TCPConn)
}
@@ -132,17 +127,15 @@ func sendPromiseToServer[T any](server *Server, value T) (err error) {
}
type Plugin struct {
Disabled bool
Meta *PluginMeta
context.Context `json:"-" yaml:"-"`
context.CancelCauseFunc `json:"-" yaml:"-"`
eventChan chan any
config config.Common
Unit
Disabled bool
Meta *PluginMeta
eventChan chan any
config config.Common
config.Config
Publishers []*Publisher
*slog.Logger `json:"-" yaml:"-"`
handler IPlugin
server *Server
Publishers []*Publisher
handler IPlugin
server *Server
sync.RWMutex
}
@@ -230,18 +223,44 @@ func (p *Plugin) Start() {
go tcpConf.Listen(l, tcphandler.OnTCPConnect)
}
}
select {
case event := <-p.eventChan:
p.handler.OnEvent(event)
case <-p.Done():
return
for {
select {
case event := <-p.eventChan:
// switch event.(type) {
// case *Subscriber:
// }
p.handler.OnEvent(event)
case <-p.Done():
return
default:
for i := 0; i < len(p.Publishers); i++ {
publisher := p.Publishers[i]
select {
case <-publisher.Done():
if publisher.Closer != nil {
publisher.Closer.Close()
}
p.Publishers = slices.Delete(p.Publishers, i, i+1)
i--
p.server.eventChan <- UnpublishEvent{Publisher: publisher}
case <-publisher.TimeoutTimer.C:
if err := publisher.timeout(); err != nil {
publisher.Stop(err)
}
default:
for subscriber := range publisher.Subscribers {
select {
case <-subscriber.Done():
subscriber.Publisher.RemoveSubscriber(subscriber)
default:
}
}
}
}
}
}
}
func (p *Plugin) Stop(reason error) {
p.CancelCauseFunc(reason)
}
func (p *Plugin) OnEvent(event any) {
}
@@ -252,30 +271,14 @@ func (p *Plugin) OnTCPConnect(conn *net.TCPConn) {
func (p *Plugin) Publish(streamPath string, options ...any) (publisher *Publisher, err error) {
publisher = &Publisher{Publish: p.config.Publish}
ctx := p.Context
for _, option := range options {
switch v := option.(type) {
case context.Context:
ctx = v
}
}
publisher.Init(ctx, p, streamPath)
publisher.Subscribers = make(map[*Subscriber]struct{})
publisher.TransTrack = make(map[reflect.Type]*AVTrack)
publisher.Init(p, streamPath, options...)
err = sendPromiseToServer(p.server, publisher)
return
}
func (p *Plugin) Subscribe(streamPath string, options ...any) (subscriber *Subscriber, err error) {
subscriber = &Subscriber{Subscribe: p.config.Subscribe}
ctx := p.Context
for _, option := range options {
switch v := option.(type) {
case context.Context:
ctx = v
}
}
subscriber.Init(ctx, p, streamPath)
subscriber.Init(p, streamPath, options...)
err = sendPromiseToServer(p.server, subscriber)
return
}

View File

@@ -19,8 +19,4 @@ func (p *DemoPlugin) OnInit() {
// })
}
func (p *DemoPlugin) OnStopPublish(puber *m7s.Publisher, err error) {
}
var _ = m7s.InstallPlugin[*DemoPlugin]()

View File

@@ -1,6 +1,7 @@
package rtmp
import (
"context"
"io"
"net"
@@ -18,10 +19,6 @@ func (p *RTMPPlugin) OnInit() {
}
func (p *RTMPPlugin) OnStopPublish(puber *m7s.Publisher, err error) {
}
var _ = m7s.InstallPlugin[*RTMPPlugin](m7s.DefaultYaml(`tcp:
listenaddr: :1935`))
@@ -32,18 +29,13 @@ func (p *RTMPPlugin) OnTCPConnect(conn *net.TCPConn) {
receivers := make(map[uint32]*RTMPReceiver)
var err error
logger.Info("conn")
defer func() {
p.Info("conn close")
for _, sender := range senders {
sender.Stop(err)
}
for _, receiver := range receivers {
receiver.Stop(err)
}
}()
nc := NewNetConnection(conn)
// ctx, cancel := context.WithCancel(p)
// defer cancel()
nc.Logger = logger
ctx, cancel := context.WithCancelCause(p)
defer func() {
logger.Info("conn close")
cancel(err)
}()
/* Handshake */
if err = nc.Handshake(); err != nil {
logger.Error("handshake", "error", err)
@@ -148,11 +140,7 @@ func (p *RTMPPlugin) OnTCPConnect(conn *net.TCPConn) {
StreamID: cmd.StreamId,
},
}
// receiver.SetParentCtx(ctx)
if !p.KeepAlive {
// receiver.SetIO(conn)
}
receiver.Publisher, err = p.Publish(nc.AppName + "/" + cmd.PublishingName)
receiver.Publisher, err = p.Publish(nc.AppName+"/"+cmd.PublishingName, ctx, conn)
if err != nil {
delete(receivers, cmd.StreamId)
err = receiver.Response(cmd.TransactionId, NetStream_Publish_BadName, Level_Error)
@@ -161,17 +149,17 @@ func (p *RTMPPlugin) OnTCPConnect(conn *net.TCPConn) {
receiver.Begin()
err = receiver.Response(cmd.TransactionId, NetStream_Publish_Start, Level_Status)
}
if err != nil {
logger.Error("sendMessage publish", "error", err)
return
}
case *PlayMessage:
streamPath := nc.AppName + "/" + cmd.StreamName
sender := &RTMPSender{}
sender.NetConnection = nc
sender.StreamID = cmd.StreamId
// sender.SetParentCtx(ctx)
if !p.KeepAlive {
// sender.SetIO(conn)
}
// sender.ID = fmt.Sprintf("%s|%d", conn.RemoteAddr().String(), sender.StreamID)
sender.Subscriber, err = p.Subscribe(streamPath)
sender.Subscriber, err = p.Subscribe(streamPath, ctx, conn)
if err != nil {
err = sender.Response(cmd.TransactionId, NetStream_Play_Failed, Level_Error)
} else {
@@ -182,15 +170,10 @@ func (p *RTMPPlugin) OnTCPConnect(conn *net.TCPConn) {
sender.Init()
go sender.Handle(sender.SendAudio, sender.SendVideo)
}
// if RTMPPlugin.Subscribe(streamPath, sender) != nil {
// sender.Response(cmd.TransactionId, NetStream_Play_Failed, Level_Error)
// } else {
// senders[sender.StreamID] = sender
// sender.Begin()
// sender.Response(cmd.TransactionId, NetStream_Play_Reset, Level_Status)
// sender.Response(cmd.TransactionId, NetStream_Play_Start, Level_Status)
// go sender.PlayRaw()
// }
if err != nil {
logger.Error("sendMessage play", "error", err)
return
}
}
case RTMP_MSG_AUDIO:
if r, ok := receivers[msg.MessageStreamID]; ok {

View File

@@ -53,7 +53,7 @@ func (av *AVSender) sendFrame(frame *RTMPData) (err error) {
chunk := net.Buffers{av.chunkHeader}
av.writeSeqNum += uint32(av.chunkHeader.Len() + r.WriteNTo(av.WriteChunkSize, &chunk))
for r.Length > 0 {
item := util.Buffer(av.byte16Pool.Get(16))
item := util.Buffer(av.byte16Pool.GetN(16))
defer av.byte16Pool.Put(item)
av.WriteTo(RTMP_CHUNK_HEAD_1, &item)
// 如果在音视频数据太大,一次发送不完,那么这里进行分割(data + Chunk Basic Header(1))

View File

@@ -50,7 +50,7 @@ type BytesPool struct {
ItemSize int
}
func (bp *BytesPool) Get(size int) []byte {
func (bp *BytesPool) GetN(size int) []byte {
if size != bp.ItemSize {
return make([]byte, size)
}
@@ -162,6 +162,7 @@ func (conn *NetConnection) readChunk() (msg *Chunk, err error) {
if !ok {
chunk = &Chunk{}
conn.incommingChunks[ChunkStreamID] = chunk
chunk.AVData.IPool = &conn.byteChunkPool
}
if err = conn.readChunkType(&chunk.ChunkHeader, ChunkType); err != nil {
@@ -173,7 +174,7 @@ func (conn *NetConnection) readChunk() (msg *Chunk, err error) {
if unRead := msgLen - chunk.AVData.Length; unRead < needRead {
needRead = unRead
}
mem := conn.byteChunkPool.Get(needRead)
mem := conn.byteChunkPool.GetN(needRead)
if n, err := conn.ReadFull(mem); err != nil {
conn.byteChunkPool.Put(mem)
return nil, err

View File

@@ -3,21 +3,79 @@ package m7s
import (
"reflect"
"sync"
"time"
. "m7s.live/m7s/v5/pkg"
"m7s.live/m7s/v5/pkg/config"
)
type PublisherState int
const (
PublisherStateInit PublisherState = iota
PublisherStateTrackAdded
PublisherStateSubscribed
PublisherStateWaitSubscriber
)
type UnpublishEvent struct {
*Publisher
}
type Publisher struct {
PubSubBase
sync.RWMutex
config.Publish
State PublisherState
VideoTrack *AVTrack
AudioTrack *AVTrack
DataTrack *DataTrack
TransTrack map[reflect.Type]*AVTrack
Subscribers map[*Subscriber]struct{}
GOP int
sync.RWMutex
}
func (p *Publisher) timeout() (err error) {
switch p.State {
case PublisherStateInit:
err = ErrPublishTimeout
case PublisherStateTrackAdded:
if p.Publish.IdleTimeout > 0 {
err = ErrPublishIdleTimeout
}
case PublisherStateSubscribed:
case PublisherStateWaitSubscriber:
if p.Publish.DelayCloseTimeout > 0 {
err = ErrPublishDelayCloseTimeout
}
}
return
}
func (p *Publisher) checkTimeout() (err error) {
if p.VideoTrack != nil && !p.VideoTrack.LastValue.WriteTime.IsZero() && time.Since(p.VideoTrack.LastValue.WriteTime) > p.PublishTimeout {
err = ErrPublishTimeout
}
if p.AudioTrack != nil && !p.AudioTrack.LastValue.WriteTime.IsZero() && time.Since(p.AudioTrack.LastValue.WriteTime) > p.PublishTimeout {
err = ErrPublishTimeout
}
return
}
func (p *Publisher) RemoveSubscriber(subscriber *Subscriber) (err error) {
p.Lock()
defer p.Unlock()
delete(p.Subscribers, subscriber)
if subscriber.Closer != nil {
err = subscriber.Closer.Close()
}
if p.State == PublisherStateSubscribed && len(p.Subscribers) == 0 {
p.State = PublisherStateWaitSubscriber
if p.DelayCloseTimeout > 0 {
p.TimeoutTimer.Reset(p.DelayCloseTimeout)
}
}
return
}
func (p *Publisher) AddSubscriber(subscriber *Subscriber) (err error) {
@@ -25,6 +83,11 @@ func (p *Publisher) AddSubscriber(subscriber *Subscriber) (err error) {
defer p.Unlock()
p.Subscribers[subscriber] = struct{}{}
subscriber.Publisher = p
switch p.State {
case PublisherStateTrackAdded, PublisherStateWaitSubscriber:
p.State = PublisherStateSubscribed
p.TimeoutTimer.Reset(p.PublishTimeout)
}
return
}
@@ -35,7 +98,7 @@ func (p *Publisher) writeAV(t *AVTrack, data IAVFrame) {
}
func (p *Publisher) WriteVideo(data IAVFrame) (err error) {
if !p.PubVideo {
if !p.PubVideo || p.IsStopped() {
return
}
t := p.VideoTrack
@@ -46,6 +109,11 @@ func (p *Publisher) WriteVideo(data IAVFrame) (err error) {
p.Lock()
p.VideoTrack = t
p.TransTrack[reflect.TypeOf(data)] = t
if len(p.Subscribers) > 0 {
p.State = PublisherStateSubscribed
} else {
p.State = PublisherStateTrackAdded
}
p.Unlock()
}
if t.ICodecCtx == nil {
@@ -75,7 +143,7 @@ func (p *Publisher) WriteVideo(data IAVFrame) (err error) {
}
func (p *Publisher) WriteAudio(data IAVFrame) (err error) {
if !p.PubAudio {
if !p.PubAudio || p.IsStopped() {
return
}
t := p.AudioTrack

View File

@@ -5,7 +5,7 @@ import (
"log/slog"
"os"
"path/filepath"
"slices"
"reflect"
"strings"
"sync/atomic"
"time"
@@ -31,7 +31,9 @@ type Server struct {
StartTime time.Time
Plugins []*Plugin
Publishers map[string]*Publisher
Waiting map[string]*Subscriber
Waiting map[string][]*Subscriber
pidG int
sidG int
}
var DefaultServer = NewServer()
@@ -39,7 +41,7 @@ var DefaultServer = NewServer()
func NewServer() *Server {
return &Server{
Publishers: make(map[string]*Publisher),
Waiting: make(map[string]*Subscriber),
Waiting: make(map[string][]*Subscriber),
}
}
@@ -48,7 +50,7 @@ func Run(ctx context.Context, conf any) error {
}
func (s *Server) Run(ctx context.Context, conf any) (err error) {
s.Logger = slog.With("server", serverIndexG.Add(1))
s.Logger = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{})).With("server", serverIndexG.Add(1))
s.Context, s.CancelCauseFunc = context.WithCancelCause(ctx)
s.config.HTTP.ListenAddrTLS = ":8443"
s.config.HTTP.ListenAddr = ":8080"
@@ -94,12 +96,33 @@ func (s *Server) Run(ctx context.Context, conf any) (err error) {
pulse.Stop()
return
case <-pulse.C:
for _, publisher := range s.Publishers {
publisher.checkTimeout()
}
for subscriber := range s.Waiting {
for _, sub := range s.Waiting[subscriber] {
select {
case <-sub.TimeoutTimer.C:
sub.Stop(ErrSubscribeTimeout)
default:
}
}
}
case event := <-s.eventChan:
switch v := event.(type) {
case *util.Promise[*Publisher]:
v.Fulfill(s.OnPublish(v.Value))
event = v.Value
case *util.Promise[*Subscriber]:
v.Fulfill(s.OnSubscribe(v.Value))
if !s.EnableSubEvent {
continue
}
event = v.Value
case UnpublishEvent:
s.onUnpublish(v.Publisher)
case UnsubscribeEvent:
}
for _, plugin := range s.Plugins {
if plugin.Disabled {
@@ -117,40 +140,59 @@ func (s *Server) initPlugins(cg map[string]map[string]any) {
}
}
func (s *Server) onUnpublish(publisher *Publisher) {
delete(s.Publishers, publisher.StreamPath)
for subscriber := range publisher.Subscribers {
s.Waiting[publisher.StreamPath] = append(s.Waiting[publisher.StreamPath], subscriber)
subscriber.TimeoutTimer.Reset(publisher.WaitCloseTimeout)
}
}
func (s *Server) OnPublish(publisher *Publisher) error {
if oldPublisher, ok := s.Publishers[publisher.StreamPath]; ok {
if publisher.KickExist {
oldPlugin := oldPublisher.Plugin
publisher.Warn("kick")
oldPlugin.handler.(IPublishPlugin).OnStopPublish(oldPublisher, ErrKick)
if index := slices.Index(oldPlugin.Publishers, oldPublisher); index != -1 {
oldPlugin.Publishers = slices.Delete(oldPlugin.Publishers, index, index+1)
}
oldPublisher.Stop(ErrKick)
publisher.VideoTrack = oldPublisher.VideoTrack
publisher.AudioTrack = oldPublisher.AudioTrack
publisher.DataTrack = oldPublisher.DataTrack
publisher.Subscribers = oldPublisher.Subscribers
publisher.TransTrack = oldPublisher.TransTrack
oldPublisher.Subscribers = nil
} else {
return ErrStreamExist
}
} else {
s.Publishers[publisher.StreamPath] = publisher
publisher.Plugin.Info("publish", "streamPath", publisher.StreamPath)
publisher.Plugin.Publishers = append(publisher.Plugin.Publishers, publisher)
publisher.Subscribers = make(map[*Subscriber]struct{})
publisher.TransTrack = make(map[reflect.Type]*AVTrack)
}
if subscriber, ok := s.Waiting[publisher.StreamPath]; ok {
s.Publishers[publisher.StreamPath] = publisher
s.pidG++
p := publisher.Plugin
publisher.ID = s.pidG
publisher.Logger = p.With("streamPath", publisher.StreamPath, "puber", publisher.ID)
publisher.TimeoutTimer = time.NewTimer(p.config.PublishTimeout)
p.Publishers = append(p.Publishers, publisher)
publisher.Info("publish")
if subscribers, ok := s.Waiting[publisher.StreamPath]; ok {
for _, subscriber := range subscribers {
publisher.AddSubscriber(subscriber)
}
delete(s.Waiting, publisher.StreamPath)
publisher.AddSubscriber(subscriber)
}
return nil
}
func (s *Server) OnSubscribe(subscriber *Subscriber) error {
s.sidG++
subscriber.ID = s.sidG
subscriber.Logger = subscriber.Plugin.With("streamPath", subscriber.StreamPath, "suber", subscriber.ID)
subscriber.TimeoutTimer = time.NewTimer(subscriber.Plugin.config.Subscribe.WaitTimeout)
subscriber.Info("subscribe")
if publisher, ok := s.Publishers[subscriber.StreamPath]; ok {
return publisher.AddSubscriber(subscriber)
} else {
s.Waiting[subscriber.StreamPath] = subscriber
s.Waiting[subscriber.StreamPath] = append(s.Waiting[subscriber.StreamPath], subscriber)
}
return nil
}

View File

@@ -2,7 +2,7 @@ package m7s
import (
"context"
"log/slog"
"io"
"net/url"
"reflect"
"strconv"
@@ -14,31 +14,38 @@ import (
)
type PubSubBase struct {
ID string
*slog.Logger `json:"-" yaml:"-"`
context.Context `json:"-" yaml:"-"`
context.CancelCauseFunc `json:"-" yaml:"-"`
Plugin *Plugin
StartTime time.Time
StreamPath string
Args url.Values
Unit
ID int
Plugin *Plugin
StartTime time.Time
StreamPath string
Args url.Values
TimeoutTimer *time.Timer
io.Closer
}
func (ps *PubSubBase) Stop(err error) {
ps.Error(err.Error())
ps.CancelCauseFunc(err)
}
func (ps *PubSubBase) Init(ctx context.Context, p *Plugin, streamPath string) {
func (ps *PubSubBase) Init(p *Plugin, streamPath string, options ...any) {
ps.Plugin = p
ps.Context, ps.CancelCauseFunc = context.WithCancelCause(p.Context)
ctx := p.Context
for _, option := range options {
switch v := option.(type) {
case context.Context:
ctx = v
case io.Closer:
ps.Closer = v
}
}
ps.Context, ps.CancelCauseFunc = context.WithCancelCause(ctx)
if u, err := url.Parse(streamPath); err == nil {
ps.StreamPath, ps.Args = u.Path, u.Query()
}
ps.Logger = p.With("streamPath", ps.StreamPath)
ps.StartTime = time.Now()
}
type UnsubscribeEvent struct {
*Subscriber
}
type Subscriber struct {
PubSubBase
config.Subscribe
@@ -50,32 +57,41 @@ type ISubscriberHandler[T IAVFrame] func(data T)
func (s *Subscriber) Handle(audioHandler, videoHandler any) {
var ar, vr *AVRingReader
var ah, vh reflect.Value
if audioHandler != nil {
a1 := reflect.TypeOf(audioHandler).In(0)
at := s.Publisher.GetAudioTrack(a1)
if at != nil {
ar = NewAVRingReader(at)
ar.Logger = s.Logger.With("reader", a1.Name())
ah = reflect.ValueOf(audioHandler)
}
}
if videoHandler != nil {
v1 := reflect.TypeOf(videoHandler).In(0)
vt := s.Publisher.GetVideoTrack(v1)
if vt != nil {
vr = NewAVRingReader(vt)
vr.Logger = s.Logger.With("reader", v1.Name())
vh = reflect.ValueOf(videoHandler)
}
}
var a1, v1 reflect.Type
var initState = 0
var subMode = s.SubMode //订阅模式
if s.Args.Has(s.SubModeArgName) {
subMode, _ = strconv.Atoi(s.Args.Get(s.SubModeArgName))
}
var audioFrame, videoFrame, lastSentAF, lastSentVF *AVFrame
if audioHandler != nil {
a1 = reflect.TypeOf(audioHandler).In(0)
}
if videoHandler != nil {
v1 = reflect.TypeOf(videoHandler).In(0)
}
createAudioReader := func() {
if s.Publisher == nil || a1 == nil {
return
}
if at := s.Publisher.GetAudioTrack(a1); at != nil {
ar = NewAVRingReader(at)
ar.Logger = s.Logger.With("reader", a1.Name())
ah = reflect.ValueOf(audioHandler)
}
}
createVideoReader := func() {
if s.Publisher == nil || v1 == nil {
return
}
if vt := s.Publisher.GetVideoTrack(v1); vt != nil {
vr = NewAVRingReader(vt)
vr.Logger = s.Logger.With("reader", v1.Name())
vh = reflect.ValueOf(videoHandler)
}
}
createAudioReader()
createVideoReader()
defer func() {
if lastSentVF != nil {
lastSentVF.ReaderLeave()
@@ -83,7 +99,6 @@ func (s *Subscriber) Handle(audioHandler, videoHandler any) {
if lastSentAF != nil {
lastSentAF.ReaderLeave()
}
s.Info("subscriber stopped", "reason", context.Cause(s.Context))
}()
sendAudioFrame := func() {
lastSentAF = audioFrame
@@ -131,6 +146,8 @@ func (s *Subscriber) Handle(audioHandler, videoHandler any) {
sendVideoFrame()
}
}
} else {
createVideoReader()
}
// 正常模式下或者纯音频模式下,音频开始播放
if ar != nil {
@@ -176,6 +193,8 @@ func (s *Subscriber) Handle(audioHandler, videoHandler any) {
s.Debug("skip audio", "frame.AbsTime", audioFrame.Timestamp, "s.AudioReader.SkipTs", ar.SkipTs)
}
}
} else {
createAudioReader()
}
}
}