refactor: marotask

This commit is contained in:
langhuihui
2024-08-10 08:25:27 +08:00
parent 14a5184477
commit da0066fcc8
21 changed files with 583 additions and 577 deletions

5
.gitignore vendored
View File

@@ -3,3 +3,8 @@
logs logs
fatal fatal
.idea .idea
victoria-logs-data
dump
record
bin
.DS_Store

30
api.go
View File

@@ -90,7 +90,7 @@ func (s *Server) api_Stream_AnnexB_(rw http.ResponseWriter, r *http.Request) {
} }
func (s *Server) getStreamInfo(pub *Publisher) (res *pb.StreamInfoResponse, err error) { func (s *Server) getStreamInfo(pub *Publisher) (res *pb.StreamInfoResponse, err error) {
tmp, _ := json.Marshal(pub.MetaData) tmp, _ := json.Marshal(pub.Description)
res = &pb.StreamInfoResponse{ res = &pb.StreamInfoResponse{
Meta: string(tmp), Meta: string(tmp),
Path: pub.StreamPath, Path: pub.StreamPath,
@@ -131,20 +131,21 @@ func (s *Server) getStreamInfo(pub *Publisher) (res *pb.StreamInfoResponse, err
} }
func (s *Server) StreamInfo(ctx context.Context, req *pb.StreamSnapRequest) (res *pb.StreamInfoResponse, err error) { func (s *Server) StreamInfo(ctx context.Context, req *pb.StreamSnapRequest) (res *pb.StreamInfoResponse, err error) {
s.streamTM.Call(func() { s.streamTask.Call(func(*pkg.Task) error {
if pub, ok := s.Streams.Get(req.StreamPath); ok { if pub, ok := s.Streams.Get(req.StreamPath); ok {
res, err = s.getStreamInfo(pub) res, err = s.getStreamInfo(pub)
} else { } else {
err = pkg.ErrNotFound err = pkg.ErrNotFound
} }
return nil
}) })
return return
} }
func (s *Server) GetSubscribers(ctx context.Context, req *pb.SubscribersRequest) (res *pb.SubscribersResponse, err error) { func (s *Server) GetSubscribers(ctx context.Context, req *pb.SubscribersRequest) (res *pb.SubscribersResponse, err error) {
s.streamTM.Call(func() { s.streamTask.Call(func(*pkg.Task) error {
var subscribers []*pb.SubscriberSnapShot var subscribers []*pb.SubscriberSnapShot
for subscriber := range s.Subscribers.Range { for subscriber := range s.Subscribers.Range {
meta, _ := json.Marshal(subscriber.MetaData) meta, _ := json.Marshal(subscriber.Description)
snap := &pb.SubscriberSnapShot{ snap := &pb.SubscriberSnapShot{
Id: uint32(subscriber.ID), Id: uint32(subscriber.ID),
StartTime: timestamppb.New(subscriber.StartTime), StartTime: timestamppb.New(subscriber.StartTime),
@@ -172,11 +173,12 @@ func (s *Server) GetSubscribers(ctx context.Context, req *pb.SubscribersRequest)
List: subscribers, List: subscribers,
Total: int32(s.Subscribers.Length), Total: int32(s.Subscribers.Length),
} }
return nil
}) })
return return
} }
func (s *Server) AudioTrackSnap(ctx context.Context, req *pb.StreamSnapRequest) (res *pb.TrackSnapShotResponse, err error) { func (s *Server) AudioTrackSnap(ctx context.Context, req *pb.StreamSnapRequest) (res *pb.TrackSnapShotResponse, err error) {
s.streamTM.Call(func() { s.streamTask.Call(func(*pkg.Task) error {
if pub, ok := s.Streams.Get(req.StreamPath); ok && pub.HasAudioTrack() { if pub, ok := s.Streams.Get(req.StreamPath); ok && pub.HasAudioTrack() {
res = &pb.TrackSnapShotResponse{} res = &pb.TrackSnapShotResponse{}
for _, memlist := range pub.AudioTrack.Allocator.GetChildren() { for _, memlist := range pub.AudioTrack.Allocator.GetChildren() {
@@ -221,6 +223,7 @@ func (s *Server) AudioTrackSnap(ctx context.Context, req *pb.StreamSnapRequest)
} else { } else {
err = pkg.ErrNotFound err = pkg.ErrNotFound
} }
return nil
}) })
return return
} }
@@ -254,7 +257,7 @@ func (s *Server) api_VideoTrack_SSE(rw http.ResponseWriter, r *http.Request) {
} }
func (s *Server) VideoTrackSnap(ctx context.Context, req *pb.StreamSnapRequest) (res *pb.TrackSnapShotResponse, err error) { func (s *Server) VideoTrackSnap(ctx context.Context, req *pb.StreamSnapRequest) (res *pb.TrackSnapShotResponse, err error) {
s.streamTM.Call(func() { s.streamTask.Call(func(*pkg.Task) error {
if pub, ok := s.Streams.Get(req.StreamPath); ok && pub.HasVideoTrack() { if pub, ok := s.Streams.Get(req.StreamPath); ok && pub.HasVideoTrack() {
res = &pb.TrackSnapShotResponse{} res = &pb.TrackSnapShotResponse{}
for _, memlist := range pub.VideoTrack.Allocator.GetChildren() { for _, memlist := range pub.VideoTrack.Allocator.GetChildren() {
@@ -299,6 +302,7 @@ func (s *Server) VideoTrackSnap(ctx context.Context, req *pb.StreamSnapRequest)
} else { } else {
err = pkg.ErrNotFound err = pkg.ErrNotFound
} }
return nil
}) })
return return
} }
@@ -320,34 +324,36 @@ func (s *Server) Shutdown(ctx context.Context, req *pb.RequestWithId) (res *empt
} }
func (s *Server) ChangeSubscribe(ctx context.Context, req *pb.ChangeSubscribeRequest) (res *pb.SuccessResponse, err error) { func (s *Server) ChangeSubscribe(ctx context.Context, req *pb.ChangeSubscribeRequest) (res *pb.SuccessResponse, err error) {
s.streamTM.Call(func() { s.streamTask.Call(func(*pkg.Task) error {
if subscriber, ok := s.Subscribers.Get(req.Id); ok { if subscriber, ok := s.Subscribers.Get(req.Id); ok {
if pub, ok := s.Streams.Get(req.StreamPath); ok { if pub, ok := s.Streams.Get(req.StreamPath); ok {
subscriber.Publisher.RemoveSubscriber(subscriber) subscriber.Publisher.RemoveSubscriber(subscriber)
subscriber.StreamPath = req.StreamPath subscriber.StreamPath = req.StreamPath
pub.AddSubscriber(subscriber) pub.AddSubscriber(subscriber)
return return nil
} }
} }
err = pkg.ErrNotFound err = pkg.ErrNotFound
return nil
}) })
return &pb.SuccessResponse{}, err return &pb.SuccessResponse{}, err
} }
func (s *Server) StopSubscribe(ctx context.Context, req *pb.RequestWithId) (res *pb.SuccessResponse, err error) { func (s *Server) StopSubscribe(ctx context.Context, req *pb.RequestWithId) (res *pb.SuccessResponse, err error) {
s.streamTM.Call(func() { s.streamTask.Call(func(*pkg.Task) error {
if subscriber, ok := s.Subscribers.Get(req.Id); ok { if subscriber, ok := s.Subscribers.Get(req.Id); ok {
subscriber.Stop(errors.New("stop by api")) subscriber.Stop(errors.New("stop by api"))
} else { } else {
err = pkg.ErrNotFound err = pkg.ErrNotFound
} }
return nil
}) })
return &pb.SuccessResponse{}, err return &pb.SuccessResponse{}, err
} }
// /api/stream/list // /api/stream/list
func (s *Server) StreamList(_ context.Context, req *pb.StreamListRequest) (res *pb.StreamListResponse, err error) { func (s *Server) StreamList(_ context.Context, req *pb.StreamListRequest) (res *pb.StreamListResponse, err error) {
s.streamTM.Call(func() { s.streamTask.Call(func(*pkg.Task) error {
var streams []*pb.StreamInfoResponse var streams []*pb.StreamInfoResponse
for publisher := range s.Streams.Range { for publisher := range s.Streams.Range {
info, err := s.getStreamInfo(publisher) info, err := s.getStreamInfo(publisher)
@@ -357,18 +363,20 @@ func (s *Server) StreamList(_ context.Context, req *pb.StreamListRequest) (res *
streams = append(streams, info) streams = append(streams, info)
} }
res = &pb.StreamListResponse{List: streams, Total: int32(s.Streams.Length), PageNum: req.PageNum, PageSize: req.PageSize} res = &pb.StreamListResponse{List: streams, Total: int32(s.Streams.Length), PageNum: req.PageNum, PageSize: req.PageSize}
return nil
}) })
return return
} }
func (s *Server) WaitList(context.Context, *emptypb.Empty) (res *pb.StreamWaitListResponse, err error) { func (s *Server) WaitList(context.Context, *emptypb.Empty) (res *pb.StreamWaitListResponse, err error) {
s.streamTM.Call(func() { s.streamTask.Call(func(*pkg.Task) error {
res = &pb.StreamWaitListResponse{ res = &pb.StreamWaitListResponse{
List: make(map[string]int32), List: make(map[string]int32),
} }
for subs := range s.Waiting.Range { for subs := range s.Waiting.Range {
res.List[subs.StreamPath] = int32(subs.Subscribers.Length) res.List[subs.StreamPath] = int32(subs.Subscribers.Length)
} }
return nil
}) })
return return
} }

View File

@@ -2,17 +2,23 @@ package pkg
import ( import (
"context" "context"
"io" "errors"
"log/slog" "log/slog"
"m7s.live/m7s/v5/pkg/util"
"reflect" "reflect"
"slices" "slices"
"sync/atomic" "sync/atomic"
"time" "time"
"m7s.live/m7s/v5/pkg/util"
) )
const TraceLevel = slog.Level(-8) const TraceLevel = slog.Level(-8)
var (
ErrAutoStop = errors.New("auto stop")
ErrCallbackTask = errors.New("callback task")
)
type getTask interface{ GetTask() *Task } type getTask interface{ GetTask() *Task }
type TaskExecutor interface { type TaskExecutor interface {
Start() error Start() error
@@ -25,11 +31,16 @@ type TempTaskExecutor struct {
} }
func (t TempTaskExecutor) Start() error { func (t TempTaskExecutor) Start() error {
if t.StartFunc == nil {
return nil
}
return t.StartFunc() return t.StartFunc()
} }
func (t TempTaskExecutor) Dispose() { func (t TempTaskExecutor) Dispose() {
t.DisposeFunc() if t.DisposeFunc != nil {
t.DisposeFunc()
}
} }
type Task struct { type Task struct {
@@ -38,9 +49,10 @@ type Task struct {
*slog.Logger *slog.Logger
context.Context context.Context
context.CancelCauseFunc context.CancelCauseFunc
exeStack []TaskExecutor exe TaskExecutor
Description map[string]any Description map[string]any
started *util.Promise startup, shutdown *util.Promise
parent *MarcoTask
} }
func (task *Task) GetTask() *Task { func (task *Task) GetTask() *Task {
@@ -51,29 +63,12 @@ func (task *Task) GetKey() uint32 {
return task.ID return task.ID
} }
func (task *Task) Begin() (err error) {
task.StartTime = time.Now()
for _, executor := range task.exeStack {
err = executor.Start()
if err != nil {
break
}
}
task.started.Fulfill(err)
return
}
func (task *Task) dispose(reason error) {
if task.Logger != nil {
task.Debug("stop", "reason", reason)
}
for _, executor := range slices.Backward(task.exeStack) {
executor.Dispose()
}
}
func (task *Task) WaitStarted() error { func (task *Task) WaitStarted() error {
return task.started.Await() return task.startup.Await()
}
func (task *Task) WaitStopped() error {
return task.shutdown.Await()
} }
func (task *Task) Trace(msg string, fields ...any) { func (task *Task) Trace(msg string, fields ...any) {
@@ -95,88 +90,154 @@ func (task *Task) Stop(err error) {
} }
} }
func (task *Task) With(child getTask, args ...any) { func (task *Task) Init(ctx context.Context, logger *slog.Logger, executor TaskExecutor) {
child.GetTask().Init(task.Context, task.Logger.With(args...))
}
func (task *Task) Init(ctx context.Context, logger *slog.Logger, executor ...TaskExecutor) {
task.Logger = logger task.Logger = logger
task.exeStack = executor task.exe = executor
task.Context, task.CancelCauseFunc = context.WithCancelCause(ctx) task.Context, task.CancelCauseFunc = context.WithCancelCause(ctx)
task.started = util.NewPromise(task.Context) task.startup = util.NewPromise(task.Context)
task.shutdown = util.NewPromise(task.Context)
} }
type CallBackTaskExecutor func() type CallBack func(*Task) error
func (call CallBackTaskExecutor) Start() error { // MarcoTask include sub tasks
call() type MarcoTask struct {
return io.EOF Task
KeepAlive bool
exe TaskExecutor
addSub chan *Task
subTasks []*Task
extraCases []reflect.SelectCase
extraCallbacks []reflect.Value
idG atomic.Uint32
} }
func (call CallBackTaskExecutor) Dispose() { func (mt *MarcoTask) Init(ctx context.Context, logger *slog.Logger, executor TaskExecutor, extra ...any) {
// nothing to do, never called mt.Task.Init(ctx, logger, mt)
} mt.exe = executor
for i := range len(extra) / 2 {
type TaskManager struct { mt.extraCases = append(mt.extraCases, reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(extra[i*2])})
shutdown *util.Promise mt.extraCallbacks = append(mt.extraCallbacks, reflect.ValueOf(extra[i*2+1]))
stopReason error
start chan *Task
Tasks []*Task
idG atomic.Uint32
}
func NewTaskManager() *TaskManager {
return &TaskManager{
shutdown: util.NewPromise(context.TODO()),
start: make(chan *Task, 10),
} }
} }
func StartTaskManager(extra ...any) *TaskManager { func (mt *MarcoTask) InitKeepAlive(ctx context.Context, logger *slog.Logger, executor TaskExecutor, extra ...any) {
tm := NewTaskManager() mt.Init(ctx, logger, executor, extra...)
go tm.Run(extra...) mt.KeepAlive = true
return tm
} }
func (t *TaskManager) Add(getTask getTask) { func (mt *MarcoTask) Start() error {
if mt.exe != nil {
return mt.exe.Start()
}
return nil
}
func (mt *MarcoTask) AddTasks(getTasks ...getTask) {
for _, getTask := range getTasks {
mt.AddTask(getTask)
}
}
func (mt *MarcoTask) AddTask(getTask getTask) {
if mt.IsStopped() {
getTask.GetTask().startup.Reject(mt.StopReason())
return
}
if mt.addSub == nil {
mt.addSub = make(chan *Task, 10)
go mt.run()
}
task := getTask.GetTask() task := getTask.GetTask()
if v, ok := getTask.(TaskExecutor); len(task.exeStack) == 0 && ok { if task.ID == 0 {
task.exeStack = append(task.exeStack, v) task.ID = mt.GetID()
} }
t.start <- task if task.parent == nil {
task.parent = mt
if v, ok := getTask.(TaskExecutor); ok {
task.exe = v
}
}
mt.addSub <- task
} }
func (t *TaskManager) Call(callback CallBackTaskExecutor) error { func (mt *MarcoTask) Call(callback CallBack) {
task := mt.AddCall(callback, nil)
_ = task.WaitStarted()
}
func (mt *MarcoTask) AddCall(start CallBack, dispose func(*Task)) *Task {
var tmpTask Task var tmpTask Task
tmpTask.Init(context.TODO(), nil, callback) var tmpExe TempTaskExecutor
return t.Start(&tmpTask) if start != nil {
tmpExe.StartFunc = func() error {
err := start(&tmpTask)
if err == nil && dispose == nil {
err = ErrCallbackTask
}
return err
}
}
if dispose != nil {
tmpExe.DisposeFunc = func() {
dispose(&tmpTask)
}
}
tmpTask.Init(mt.Context, nil, tmpExe)
mt.AddTask(&tmpTask)
return &tmpTask
} }
func (t *TaskManager) Start(getTask getTask) error { func (mt *MarcoTask) WaitTaskAdded(getTask getTask) error {
t.Add(getTask) mt.AddTask(getTask)
return getTask.GetTask().WaitStarted() return getTask.GetTask().WaitStarted()
} }
func (t *TaskManager) GetID() uint32 { func (mt *MarcoTask) GetID() uint32 {
return t.idG.Add(1) return mt.idG.Add(1)
} }
// Run task Start and Dispose in this goroutine func (mt *MarcoTask) startSubTask(task *Task) (err error) {
func (t *TaskManager) Run(extra ...any) { if task.startup.IsPending() {
cases := []reflect.SelectCase{{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(t.start)}} task.StartTime = time.Now()
extraLen := len(extra) / 2 err = task.exe.Start()
var callbacks []reflect.Value if task.Logger != nil {
for i := range extraLen { task.Debug("start")
cases = append(cases, reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(extra[i*2])})
callbacks = append(callbacks, reflect.ValueOf(extra[i*2+1]))
}
defer func() {
for _, task := range t.Tasks {
task.dispose(t.stopReason)
} }
t.Tasks = nil task.startup.Fulfill(err)
cases = nil }
t.shutdown.Fulfill(t.stopReason) return
}
func (mt *MarcoTask) disposeSubTask(task *Task, reason error) {
if task.parent != mt {
return
}
if task.Logger != nil {
task.Debug("dispose", "reason", reason)
}
task.exe.Dispose()
if m, ok := task.exe.(*MarcoTask); ok {
m.WaitStopped()
} else {
task.shutdown.Fulfill(reason)
}
}
func (mt *MarcoTask) run() {
extraLen := len(mt.extraCases)
cases := append([]reflect.SelectCase{{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(mt.addSub)}}, mt.extraCases...)
defer func() {
stopReason := mt.StopReason()
for _, task := range mt.subTasks {
task.Stop(stopReason)
mt.disposeSubTask(task, stopReason)
}
mt.subTasks = nil
mt.addSub = nil
mt.extraCases = nil
mt.extraCallbacks = nil
mt.shutdown.Fulfill(stopReason)
}() }()
for { for {
if chosen, rev, ok := reflect.Select(cases); chosen == 0 { if chosen, rev, ok := reflect.Select(cases); chosen == 0 {
@@ -184,11 +245,8 @@ func (t *TaskManager) Run(extra ...any) {
return return
} }
task := rev.Interface().(*Task) task := rev.Interface().(*Task)
if err := task.Begin(); err == nil { if err := mt.startSubTask(task); err == nil {
if task.Logger != nil { mt.subTasks = append(mt.subTasks, task)
task.Debug("start")
}
t.Tasks = append(t.Tasks, task)
cases = append(cases, reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(task.Done())}) cases = append(cases, reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(task.Done())})
} else { } else {
if task.Logger != nil { if task.Logger != nil {
@@ -197,53 +255,29 @@ func (t *TaskManager) Run(extra ...any) {
task.Stop(err) task.Stop(err)
} }
} else if chosen <= extraLen { } else if chosen <= extraLen {
callbacks[chosen-1].Call([]reflect.Value{rev}) mt.extraCallbacks[chosen-1].Call([]reflect.Value{rev})
} else { } else {
taskIndex := chosen - extraLen - 1 taskIndex := chosen - extraLen - 1
task := t.Tasks[taskIndex] task := mt.subTasks[taskIndex]
task.dispose(task.StopReason()) mt.disposeSubTask(task, task.StopReason())
t.Tasks = slices.Delete(t.Tasks, taskIndex, taskIndex+1) mt.subTasks = slices.Delete(mt.subTasks, taskIndex, taskIndex+1)
cases = slices.Delete(cases, chosen, chosen+1) cases = slices.Delete(cases, chosen, chosen+1)
if !mt.KeepAlive && len(mt.subTasks) == 0 {
mt.Stop(ErrAutoStop)
}
} }
} }
} }
// Run task Start and Dispose in another goroutine
//func (t *TaskManager) Run() {
// cases := []reflect.SelectCase{{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(t.Start)}}
// defer func() {
// cases = slices.Delete(cases, 0, 1)
// for len(cases) > 0 {
// chosen, _, _ := reflect.Select(cases)
// t.Done <- t.Tasks[chosen]
// t.Tasks = slices.Delete(t.Tasks, chosen, chosen+1)
// cases = slices.Delete(cases, chosen, chosen+1)
// }
// close(t.Done)
// }()
// for {
// if chosen, rev, ok := reflect.Select(cases); chosen == 0 {
// if !ok {
// return
// }
// task := rev.Interface().(*Task)
// t.Tasks = append(t.Tasks, task)
// cases = append(cases, reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(task.Done())})
// } else {
// t.Done <- t.Tasks[chosen-1]
// t.Tasks = slices.Delete(t.Tasks, chosen-1, chosen)
// cases = slices.Delete(cases, chosen, chosen+1)
// }
// }
//}
// ShutDown wait all task dispose // ShutDown wait all task dispose
func (t *TaskManager) ShutDown(err error) { func (mt *MarcoTask) ShutDown(err error) {
t.Stop(err) mt.Stop(err)
_ = t.shutdown.Await() _ = mt.shutdown.Await()
} }
func (t *TaskManager) Stop(err error) { func (mt *MarcoTask) Dispose() {
t.stopReason = err if mt.exe != nil {
close(t.start) mt.exe.Dispose()
}
close(mt.addSub)
} }

View File

@@ -2,16 +2,7 @@ package m7s
import ( import (
"context" "context"
gatewayRuntime "github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
myip "github.com/husanpao/ip"
"google.golang.org/grpc"
"gopkg.in/yaml.v3"
"gorm.io/gorm"
"log/slog" "log/slog"
. "m7s.live/m7s/v5/pkg"
"m7s.live/m7s/v5/pkg/config"
"m7s.live/m7s/v5/pkg/db"
"m7s.live/m7s/v5/pkg/util"
"net" "net"
"net/http" "net/http"
"os" "os"
@@ -19,6 +10,16 @@ import (
"reflect" "reflect"
"runtime" "runtime"
"strings" "strings"
gatewayRuntime "github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
myip "github.com/husanpao/ip"
"google.golang.org/grpc"
"gopkg.in/yaml.v3"
"gorm.io/gorm"
. "m7s.live/m7s/v5/pkg"
"m7s.live/m7s/v5/pkg/config"
"m7s.live/m7s/v5/pkg/db"
"m7s.live/m7s/v5/pkg/util"
) )
type DefaultYaml string type DefaultYaml string
@@ -51,7 +52,7 @@ func (plugin *PluginMeta) Init(s *Server, userConfig map[string]any) (p *Plugin)
p.handler = instance p.handler = instance
p.Meta = plugin p.Meta = plugin
p.Server = s p.Server = s
p.Task.Init(s.Context, s.Logger.With("plugin", plugin.Name), instance) p.MarcoTask.Init(s.Context, s.Logger.With("plugin", plugin.Name), instance)
upperName := strings.ToUpper(plugin.Name) upperName := strings.ToUpper(plugin.Name)
if os.Getenv(upperName+"_ENABLE") == "false" { if os.Getenv(upperName+"_ENABLE") == "false" {
p.Disabled = true p.Disabled = true
@@ -180,7 +181,7 @@ func InstallPlugin[C iPlugin](options ...any) error {
} }
type Plugin struct { type Plugin struct {
Task MarcoTask
Disabled bool Disabled bool
Meta *PluginMeta Meta *PluginMeta
config config.Common config config.Common
@@ -352,7 +353,7 @@ func (p *Plugin) Publish(streamPath string, options ...any) (publisher *Publishe
} }
} }
} }
err = p.Server.streamTM.Start(publisher) err = p.Server.streamTask.WaitTaskAdded(publisher)
return return
} }
@@ -370,48 +371,20 @@ func (p *Plugin) Subscribe(streamPath string, options ...any) (subscriber *Subsc
} }
} }
} }
err = p.Server.streamTM.Start(subscriber) err = p.Server.streamTask.WaitTaskAdded(subscriber)
err = subscriber.Publisher.WaitTrack() err = subscriber.Publisher.WaitTrack()
return return
} }
func (p *Plugin) pull(streamPath string, url string, options ...any) (ctx *PullContext, err error) {
ctx = createPullContext(p, streamPath, url, options...)
err = p.Server.pullTM.Start(ctx)
return
}
func (p *Plugin) Pull(streamPath string, url string, options ...any) (ctx *PullContext, err error) { func (p *Plugin) Pull(streamPath string, url string, options ...any) (ctx *PullContext, err error) {
if ctx, err = p.pull(streamPath, url, options...); err == nil && p.Meta.Puller != nil { ctx = createPullContext(p, streamPath, url, options...)
go p.Meta.Puller(ctx) err = p.Server.pullTask.WaitTaskAdded(ctx);
}
return
}
func (p *Plugin) PullBlock(streamPath string, url string, options ...any) (ctx *PullContext, err error) {
if ctx, err = p.pull(streamPath, url, options...); err == nil && p.Meta.Puller != nil {
err = p.Meta.Puller(ctx)
}
return
}
func (p *Plugin) push(streamPath string, url string, options ...any) (ctx *PushContext, err error) {
ctx = createPushContext(p, streamPath, url, options...)
err = p.Server.pushTM.Start(ctx)
return return
} }
func (p *Plugin) Push(streamPath string, url string, options ...any) (ctx *PushContext, err error) { func (p *Plugin) Push(streamPath string, url string, options ...any) (ctx *PushContext, err error) {
if ctx, err = p.push(streamPath, url, options...); err == nil && p.Meta.Pusher != nil { ctx = createPushContext(p, streamPath, url, options...)
go p.Meta.Pusher(ctx) err = p.Server.pushTask.WaitTaskAdded(ctx)
}
return
}
func (p *Plugin) PushBlock(streamPath string, url string, options ...any) (ctx *PushContext, err error) {
if ctx, err = p.push(streamPath, url, options...); err == nil && p.Meta.Pusher != nil {
err = p.Meta.Pusher(ctx)
}
return return
} }
@@ -428,7 +401,7 @@ func (p *Plugin) record(streamPath string, filePath string, options ...any) (ctx
if err != nil { if err != nil {
return return
} }
err = p.Server.recordTM.Start(ctx) err = p.Server.recordTask.WaitTaskAdded(ctx)
return return
} }
@@ -501,7 +474,7 @@ func (p *Plugin) AddLogHandler(handler slog.Handler) {
} }
func (p *Plugin) SaveConfig() (err error) { func (p *Plugin) SaveConfig() (err error) {
p.Server.pluginTM.Call(func() { p.Server.Call(func(*Task) (err error) {
if p.Modify == nil { if p.Modify == nil {
os.Remove(p.settingPath()) os.Remove(p.settingPath())
return return
@@ -512,6 +485,7 @@ func (p *Plugin) SaveConfig() (err error) {
} }
defer file.Close() defer file.Close()
err = yaml.NewEncoder(file).Encode(p.Modify) err = yaml.NewEncoder(file).Encode(p.Modify)
return
}) })
if err == nil { if err == nil {
p.Info("config saved") p.Info("config saved")

View File

@@ -22,7 +22,7 @@ const defaultConfig m7s.DefaultYaml = `publish:
func (p *FLVPlugin) OnInit() error { func (p *FLVPlugin) OnInit() error {
for streamPath, url := range p.GetCommonConf().PullOnStart { for streamPath, url := range p.GetCommonConf().PullOnStart {
go p.PullBlock(streamPath, url, PullFLV) p.Pull(streamPath, url, PullFLV)
} }
return nil return nil
} }

View File

@@ -1,20 +1,21 @@
package plugin_mp4 package plugin_mp4
import ( import (
"github.com/Eyevinn/mp4ff/mp4"
"io" "io"
"m7s.live/m7s/v5"
v5 "m7s.live/m7s/v5/pkg"
"m7s.live/m7s/v5/pkg/codec"
"m7s.live/m7s/v5/pkg/util"
pkg "m7s.live/m7s/v5/plugin/mp4/pkg"
rtmp "m7s.live/m7s/v5/plugin/rtmp/pkg"
"maps" "maps"
"net" "net"
"net/http" "net/http"
"slices" "slices"
"strings" "strings"
"time" "time"
"github.com/Eyevinn/mp4ff/mp4"
"m7s.live/m7s/v5"
v5 "m7s.live/m7s/v5/pkg"
"m7s.live/m7s/v5/pkg/codec"
"m7s.live/m7s/v5/pkg/util"
pkg "m7s.live/m7s/v5/plugin/mp4/pkg"
rtmp "m7s.live/m7s/v5/plugin/rtmp/pkg"
) )
type MediaContext struct { type MediaContext struct {
@@ -75,7 +76,7 @@ const defaultConfig m7s.DefaultYaml = `publish:
func (p *MP4Plugin) OnInit() error { func (p *MP4Plugin) OnInit() error {
for streamPath, url := range p.GetCommonConf().PullOnStart { for streamPath, url := range p.GetCommonConf().PullOnStart {
go p.PullBlock(streamPath, url) p.Pull(streamPath, url)
} }
return nil return nil
} }

View File

@@ -5,6 +5,8 @@ import (
"io" "io"
"os" "os"
"testing" "testing"
"m7s.live/m7s/v5/pkg/util"
) )
func TestCreateMovDemuxer(t *testing.T) { func TestCreateMovDemuxer(t *testing.T) {
@@ -27,7 +29,7 @@ func TestCreateMovDemuxer(t *testing.T) {
mp4info := demuxer.GetMp4Info() mp4info := demuxer.GetMp4Info()
fmt.Printf("%+v\n", mp4info) fmt.Printf("%+v\n", mp4info)
for { for {
pkg, err := demuxer.ReadPacket() pkg, err := demuxer.ReadPacket(util.NewScalableMemoryAllocator(1 << 10))
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
break break

View File

@@ -1,224 +1,211 @@
package box package box
import ( // func TestCreateMp4Reader(t *testing.T) {
"encoding/binary" // f, err := os.Open("jellyfish-3-mbps-hd.h264.mp4")
"fmt" // if err != nil {
"io" // fmt.Println(err)
"io/ioutil" // return
"os" // }
"strconv" // defer f.Close()
"testing" // for err == nil {
// nn := int64(0)
// size := make([]byte, 4)
// _, err = io.ReadFull(f, size)
// if err != nil {
// break
// }
// nn += 4
// boxtype := make([]byte, 4)
// _, err = io.ReadFull(f, boxtype)
// if err != nil {
// break
// }
// nn += 4
// var isize uint64 = uint64(binary.BigEndian.Uint32(size))
// if isize == 1 {
// size := make([]byte, 8)
// _, err = io.ReadFull(f, size)
// if err != nil {
// break
// }
// isize = binary.BigEndian.Uint64(size)
// nn += 8
// }
// fmt.Printf("Read Box(%s) size:%d\n", boxtype, isize)
// f.Seek(int64(isize)-nn, 1)
// }
// }
"github.com/yapingcat/gomedia/go-codec" // func TestCreateMp4Muxer(t *testing.T) {
"github.com/yapingcat/gomedia/go-mpeg2"
)
func TestCreateMp4Reader(t *testing.T) { // f, err := os.Open("jellyfish-3-mbps-hd.h265")
f, err := os.Open("jellyfish-3-mbps-hd.h264.mp4") // if err != nil {
if err != nil { // fmt.Println(err)
fmt.Println(err) // return
return // }
} // defer f.Close()
defer f.Close()
for err == nil {
nn := int64(0)
size := make([]byte, 4)
_, err = io.ReadFull(f, size)
if err != nil {
break
}
nn += 4
boxtype := make([]byte, 4)
_, err = io.ReadFull(f, boxtype)
if err != nil {
break
}
nn += 4
var isize uint64 = uint64(binary.BigEndian.Uint32(size))
if isize == 1 {
size := make([]byte, 8)
_, err = io.ReadFull(f, size)
if err != nil {
break
}
isize = binary.BigEndian.Uint64(size)
nn += 8
}
fmt.Printf("Read Box(%s) size:%d\n", boxtype, isize)
f.Seek(int64(isize)-nn, 1)
}
}
func TestCreateMp4Muxer(t *testing.T) { // mp4filename := "jellyfish-3-mbps-hd.h265.mp4"
// mp4file, err := os.OpenFile(mp4filename, os.O_CREATE|os.O_RDWR, 0666)
// if err != nil {
// fmt.Println(err)
// return
// }
// defer mp4file.Close()
f, err := os.Open("jellyfish-3-mbps-hd.h265") // buf, _ := ioutil.ReadAll(f)
if err != nil { // pts := uint64(0)
fmt.Println(err) // dts := uint64(0)
return // ii := [3]uint64{33, 33, 34}
} // idx := 0
defer f.Close()
mp4filename := "jellyfish-3-mbps-hd.h265.mp4" // type args struct {
mp4file, err := os.OpenFile(mp4filename, os.O_CREATE|os.O_RDWR, 0666) // wh io.WriteSeeker
if err != nil { // }
fmt.Println(err) // tests := []struct {
return // name string
} // args args
defer mp4file.Close() // want *Movmuxer
// }{
// {name: "muxer h264", args: args{wh: mp4file}, want: nil},
// }
// for _, tt := range tests {
// t.Run(tt.name, func(t *testing.T) {
// muxer, err := CreateMp4Muxer(tt.args.wh)
// if err != nil {
// fmt.Println(err)
// return
// }
// tid := muxer.AddVideoTrack(MP4_CODEC_H265)
// cache := make([]byte, 0)
// codec.SplitFrameWithStartCode(buf, func(nalu []byte) bool {
// ntype := codec.H265NaluType(nalu)
// if !codec.IsH265VCLNaluType(ntype) {
// cache = append(cache, nalu...)
// return true
// }
// if len(cache) > 0 {
// cache = append(cache, nalu...)
// muxer.Write(tid, cache, pts, dts)
// cache = cache[:0]
// } else {
// muxer.Write(tid, nalu, pts, dts)
// }
// pts += ii[idx]
// dts += ii[idx]
// idx++
// idx = idx % 3
// return true
// })
// fmt.Printf("last dts %d\n", dts)
// muxer.WriteTrailer()
// })
// }
// }
buf, _ := ioutil.ReadAll(f) // func TestMuxAAC(t *testing.T) {
pts := uint64(0) // f, err := os.Open("test.aac")
dts := uint64(0) // if err != nil {
ii := [3]uint64{33, 33, 34} // fmt.Println(err)
idx := 0 // return
// }
// defer f.Close()
type args struct { // mp4filename := "aac.mp4"
wh io.WriteSeeker // mp4file, err := os.OpenFile(mp4filename, os.O_CREATE|os.O_RDWR, 0666)
} // if err != nil {
tests := []struct { // fmt.Println(err)
name string // return
args args // }
want *Movmuxer // defer mp4file.Close()
}{
{name: "muxer h264", args: args{wh: mp4file}, want: nil},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
muxer, err := CreateMp4Muxer(tt.args.wh)
if err != nil {
fmt.Println(err)
return
}
tid := muxer.AddVideoTrack(MP4_CODEC_H265)
cache := make([]byte, 0)
codec.SplitFrameWithStartCode(buf, func(nalu []byte) bool {
ntype := codec.H265NaluType(nalu)
if !codec.IsH265VCLNaluType(ntype) {
cache = append(cache, nalu...)
return true
}
if len(cache) > 0 {
cache = append(cache, nalu...)
muxer.Write(tid, cache, pts, dts)
cache = cache[:0]
} else {
muxer.Write(tid, nalu, pts, dts)
}
pts += ii[idx]
dts += ii[idx]
idx++
idx = idx % 3
return true
})
fmt.Printf("last dts %d\n", dts)
muxer.WriteTrailer()
})
}
}
func TestMuxAAC(t *testing.T) { // aac, _ := ioutil.ReadAll(f)
f, err := os.Open("test.aac") // var pts uint64 = 0
if err != nil { // //var dts uint64 = 0
fmt.Println(err) // //var i int = 0
return // samples := uint64(0)
} // muxer, err := CreateMp4Muxer(mp4file)
defer f.Close() // if err != nil {
// fmt.Println(err)
// return
// }
mp4filename := "aac.mp4" // tid := muxer.AddAudioTrack(MP4_CODEC_AAC)
mp4file, err := os.OpenFile(mp4filename, os.O_CREATE|os.O_RDWR, 0666) // codec.SplitAACFrame(aac, func(aac []byte) {
if err != nil { // samples += 1024
fmt.Println(err) // pts = samples * 1000 / 44100
return // // if i < 3 {
} // // pts += 23
defer mp4file.Close() // // dts += 23
// // i++
// // } else {
// // pts += 24
// // dts += 24
// // i = 0
// // }
// muxer.Write(tid, aac, pts, pts)
// //fmt.Println(pts)
// })
// muxer.WriteTrailer()
// }
aac, _ := ioutil.ReadAll(f) // func TestMuxMp4(t *testing.T) {
var pts uint64 = 0 // tsfilename := `demo.ts` // input
//var dts uint64 = 0 // tsfile, err := os.Open(tsfilename)
//var i int = 0 // if err != nil {
samples := uint64(0) // fmt.Println(err)
muxer, err := CreateMp4Muxer(mp4file) // return
if err != nil { // }
fmt.Println(err) // defer tsfile.Close()
return
}
tid := muxer.AddAudioTrack(MP4_CODEC_AAC) // mp4filename := "test14.mp4" // output
codec.SplitAACFrame(aac, func(aac []byte) { // mp4file, err := os.OpenFile(mp4filename, os.O_CREATE|os.O_RDWR, 0666)
samples += 1024 // if err != nil {
pts = samples * 1000 / 44100 // fmt.Println(err)
// if i < 3 { // return
// pts += 23 // }
// dts += 23 // defer mp4file.Close()
// i++
// } else {
// pts += 24
// dts += 24
// i = 0
// }
muxer.Write(tid, aac, pts, pts)
//fmt.Println(pts)
})
muxer.WriteTrailer()
}
func TestMuxMp4(t *testing.T) { // muxer, err := CreateMp4Muxer(mp4file)
tsfilename := `demo.ts` // input // if err != nil {
tsfile, err := os.Open(tsfilename) // fmt.Println(err)
if err != nil { // return
fmt.Println(err) // }
return // vtid := muxer.AddVideoTrack(MP4_CODEC_H264)
} // atid := muxer.AddAudioTrack(MP4_CODEC_AAC)
defer tsfile.Close()
mp4filename := "test14.mp4" // output // afile, err := os.OpenFile("r.aac", os.O_CREATE|os.O_RDWR, 0666)
mp4file, err := os.OpenFile(mp4filename, os.O_CREATE|os.O_RDWR, 0666) // if err != nil {
if err != nil { // fmt.Println(err)
fmt.Println(err) // return
return // }
} // defer afile.Close()
defer mp4file.Close() // demuxer := mpeg2.NewTSDemuxer()
// demuxer.OnFrame = func(cid mpeg2.TS_STREAM_TYPE, frame []byte, pts uint64, dts uint64) {
muxer, err := CreateMp4Muxer(mp4file) // if cid == mpeg2.TS_STREAM_AAC {
if err != nil { // err = muxer.Write(atid, frame, uint64(pts), uint64(dts))
fmt.Println(err) // if err != nil {
return // panic(err)
} // }
vtid := muxer.AddVideoTrack(MP4_CODEC_H264) // } else if cid == mpeg2.TS_STREAM_H264 {
atid := muxer.AddAudioTrack(MP4_CODEC_AAC) // fmt.Println("pts,dts,len", pts, dts, len(frame))
// err = muxer.Write(vtid, frame, uint64(pts), uint64(dts))
// if err != nil {
// panic(err)
// }
// } else {
// panic("unkwon cid " + strconv.Itoa(int(cid)))
// }
// }
afile, err := os.OpenFile("r.aac", os.O_CREATE|os.O_RDWR, 0666) // err = demuxer.Input(tsfile)
if err != nil { // if err != nil {
fmt.Println(err) // panic(err)
return // }
}
defer afile.Close()
demuxer := mpeg2.NewTSDemuxer()
demuxer.OnFrame = func(cid mpeg2.TS_STREAM_TYPE, frame []byte, pts uint64, dts uint64) {
if cid == mpeg2.TS_STREAM_AAC { // err = muxer.WriteTrailer()
err = muxer.Write(atid, frame, uint64(pts), uint64(dts)) // if err != nil {
if err != nil { // panic(err)
panic(err) // }
} // }
} else if cid == mpeg2.TS_STREAM_H264 {
fmt.Println("pts,dts,len", pts, dts, len(frame))
err = muxer.Write(vtid, frame, uint64(pts), uint64(dts))
if err != nil {
panic(err)
}
} else {
panic("unkwon cid " + strconv.Itoa(int(cid)))
}
}
err = demuxer.Input(tsfile)
if err != nil {
panic(err)
}
err = muxer.WriteTrailer()
if err != nil {
panic(err)
}
}

View File

@@ -3,11 +3,13 @@ package plugin_preview
import ( import (
"embed" "embed"
"fmt" "fmt"
"m7s.live/m7s/v5"
"mime" "mime"
"net/http" "net/http"
"path/filepath" "path/filepath"
"strings" "strings"
"m7s.live/m7s/v5"
"m7s.live/m7s/v5/pkg"
) )
//go:embed ui //go:embed ui
@@ -22,11 +24,14 @@ var _ = m7s.InstallPlugin[PreviewPlugin]()
func (p *PreviewPlugin) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (p *PreviewPlugin) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/" { if r.URL.Path == "/" {
s := "<h1><h1><h2>Live Streams 引擎中正在发布的流</h2>" s := "<h1><h1><h2>Live Streams 引擎中正在发布的流</h2>"
p.Server.Call(func() { p.Server.CallOnStreamTask(func(*pkg.Task) error {
for publisher := range p.Server.Streams.Range { for publisher := range p.Server.Streams.Range {
s += fmt.Sprintf("<a href='%s'>%s</a> [ %s ]<br>", publisher.StreamPath, publisher.StreamPath, publisher.Plugin.Meta.Name) s += fmt.Sprintf("<a href='%s'>%s</a> [ %s ]<br>", publisher.StreamPath, publisher.StreamPath, publisher.Plugin.Meta.Name)
} }
s += "<h2>pull stream on subscribe 订阅时才会触发拉流的流</h2>" s += "<h2>pull stream on subscribe 订阅时才会触发拉流的流</h2>"
return nil
})
p.Server.Call(func(*pkg.Task) error {
for plugin := range p.Server.Plugins.Range { for plugin := range p.Server.Plugins.Range {
if pullPlugin, ok := plugin.GetHandler().(m7s.IPullerPlugin); ok { if pullPlugin, ok := plugin.GetHandler().(m7s.IPullerPlugin); ok {
s += fmt.Sprintf("<h3>%s</h3>", plugin.Meta.Name) s += fmt.Sprintf("<h3>%s</h3>", plugin.Meta.Name)
@@ -35,6 +40,7 @@ func (p *PreviewPlugin) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
} }
} }
return nil
}) })
w.Write([]byte(s)) w.Write([]byte(s))
return return

View File

@@ -3,13 +3,14 @@ package plugin_rtmp
import ( import (
"errors" "errors"
"io" "io"
"maps"
"net"
"slices"
"m7s.live/m7s/v5" "m7s.live/m7s/v5"
"m7s.live/m7s/v5/pkg" "m7s.live/m7s/v5/pkg"
"m7s.live/m7s/v5/plugin/rtmp/pb" "m7s.live/m7s/v5/plugin/rtmp/pb"
. "m7s.live/m7s/v5/plugin/rtmp/pkg" . "m7s.live/m7s/v5/plugin/rtmp/pkg"
"maps"
"net"
"slices"
) )
type RTMPPlugin struct { type RTMPPlugin struct {
@@ -18,24 +19,18 @@ type RTMPPlugin struct {
ChunkSize int `default:"1024"` ChunkSize int `default:"1024"`
KeepAlive bool KeepAlive bool
C2 bool C2 bool
connTM *pkg.TaskManager
} }
var _ = m7s.InstallPlugin[RTMPPlugin](m7s.DefaultYaml(`tcp: var _ = m7s.InstallPlugin[RTMPPlugin](m7s.DefaultYaml(`tcp:
listenaddr: :1935`), &pb.Rtmp_ServiceDesc, pb.RegisterRtmpHandler, Pull, Push) listenaddr: :1935`), &pb.Rtmp_ServiceDesc, pb.RegisterRtmpHandler, Pull, Push)
func (p *RTMPPlugin) OnInit() error { func (p *RTMPPlugin) OnInit() error {
p.connTM = pkg.StartTaskManager()
for streamPath, url := range p.GetCommonConf().PullOnStart { for streamPath, url := range p.GetCommonConf().PullOnStart {
go p.PullBlock(streamPath, url) p.Pull(streamPath, url)
} }
return nil return nil
} }
func (p *RTMPPlugin) Dispose() {
p.connTM.ShutDown(p.StopReason())
}
func (p *RTMPPlugin) GetPullableList() []string { func (p *RTMPPlugin) GetPullableList() []string {
return slices.Collect(maps.Keys(p.GetCommonConf().PullOnSub)) return slices.Collect(maps.Keys(p.GetCommonConf().PullOnSub))
} }
@@ -44,14 +39,15 @@ func (p *RTMPPlugin) OnTCPConnect(conn *net.TCPConn) {
receivers := make(map[uint32]*Receiver) receivers := make(map[uint32]*Receiver)
var err error var err error
nc := NewNetConnection(conn) nc := NewNetConnection(conn)
p.With(nc, "remote", conn.RemoteAddr().String()) var connTask pkg.MarcoTask
p.connTM.Add(nc) connTask.Init(p.Context, p.With("remote", conn.RemoteAddr().String()), nc)
p.AddTask(&connTask)
defer func() { defer func() {
nc.Stop(err) connTask.Stop(err)
}() }()
/* Handshake */ /* Handshake */
if err = nc.Handshake(p.C2); err != nil { if err = nc.Handshake(p.C2); err != nil {
nc.Error("handshake", "error", err) connTask.Error("handshake", "error", err)
return return
} }
var msg *Chunk var msg *Chunk
@@ -63,13 +59,15 @@ func (p *RTMPPlugin) OnTCPConnect(conn *net.TCPConn) {
continue continue
} }
switch msg.MessageTypeID { switch msg.MessageTypeID {
case RTMP_MSG_CHUNK_SIZE:
connTask.Info("msg read chunk size", "readChunkSize", nc.ReadChunkSize)
case RTMP_MSG_AMF0_COMMAND: case RTMP_MSG_AMF0_COMMAND:
if msg.MsgData == nil { if msg.MsgData == nil {
err = errors.New("msg.MsgData is nil") err = errors.New("msg.MsgData is nil")
break break
} }
cmd := msg.MsgData.(Commander).GetCommand() cmd := msg.MsgData.(Commander).GetCommand()
nc.Debug("recv cmd", "commandName", cmd.CommandName, "streamID", msg.MessageStreamID) connTask.Debug("recv cmd", "commandName", cmd.CommandName, "streamID", msg.MessageStreamID)
switch cmd := msg.MsgData.(type) { switch cmd := msg.MsgData.(type) {
case *CallMessage: //connect case *CallMessage: //connect
connectInfo = cmd.Object connectInfo = cmd.Object
@@ -82,16 +80,16 @@ func (p *RTMPPlugin) OnTCPConnect(conn *net.TCPConn) {
nc.ObjectEncoding = 0 nc.ObjectEncoding = 0
} }
nc.AppName = app.(string) nc.AppName = app.(string)
nc.Info("connect", "appName", nc.AppName, "objectEncoding", nc.ObjectEncoding) connTask.Info("connect", "appName", nc.AppName, "objectEncoding", nc.ObjectEncoding)
err = nc.SendMessage(RTMP_MSG_ACK_SIZE, Uint32Message(512<<10)) err = nc.SendMessage(RTMP_MSG_ACK_SIZE, Uint32Message(512<<10))
if err != nil { if err != nil {
nc.Error("sendMessage ack size", "error", err) connTask.Error("sendMessage ack size", "error", err)
return return
} }
nc.WriteChunkSize = p.ChunkSize nc.WriteChunkSize = p.ChunkSize
err = nc.SendMessage(RTMP_MSG_CHUNK_SIZE, Uint32Message(p.ChunkSize)) err = nc.SendMessage(RTMP_MSG_CHUNK_SIZE, Uint32Message(p.ChunkSize))
if err != nil { if err != nil {
nc.Error("sendMessage chunk size", "error", err) connTask.Error("sendMessage chunk size", "error", err)
return return
} }
err = nc.SendMessage(RTMP_MSG_BANDWIDTH, &SetPeerBandwidthMessage{ err = nc.SendMessage(RTMP_MSG_BANDWIDTH, &SetPeerBandwidthMessage{
@@ -99,12 +97,12 @@ func (p *RTMPPlugin) OnTCPConnect(conn *net.TCPConn) {
LimitType: byte(2), LimitType: byte(2),
}) })
if err != nil { if err != nil {
nc.Error("sendMessage bandwidth", "error", err) connTask.Error("sendMessage bandwidth", "error", err)
return return
} }
err = nc.SendStreamID(RTMP_USER_STREAM_BEGIN, 0) err = nc.SendStreamID(RTMP_USER_STREAM_BEGIN, 0)
if err != nil { if err != nil {
nc.Error("sendMessage stream begin", "error", err) connTask.Error("sendMessage stream begin", "error", err)
return return
} }
m := new(ResponseConnectMessage) m := new(ResponseConnectMessage)
@@ -123,11 +121,11 @@ func (p *RTMPPlugin) OnTCPConnect(conn *net.TCPConn) {
} }
err = nc.SendMessage(RTMP_MSG_AMF0_COMMAND, m) err = nc.SendMessage(RTMP_MSG_AMF0_COMMAND, m)
if err != nil { if err != nil {
nc.Error("sendMessage connect", "error", err) connTask.Error("sendMessage connect", "error", err)
} }
case *CommandMessage: // "createStream" case *CommandMessage: // "createStream"
gstreamid++ gstreamid++
nc.Info("createStream:", "streamId", gstreamid) connTask.Info("createStream:", "streamId", gstreamid)
nc.ResponseCreateStream(cmd.TransactionId, gstreamid) nc.ResponseCreateStream(cmd.TransactionId, gstreamid)
case *CURDStreamMessage: case *CURDStreamMessage:
// if stream, ok := receivers[cmd.StreamId]; ok { // if stream, ok := receivers[cmd.StreamId]; ok {
@@ -155,7 +153,7 @@ func (p *RTMPPlugin) OnTCPConnect(conn *net.TCPConn) {
StreamID: cmd.StreamId, StreamID: cmd.StreamId,
}, },
} }
receiver.Publisher, err = p.Publish(nc.AppName+"/"+cmd.PublishingName, nc.Context, connectInfo) receiver.Publisher, err = p.Publish(nc.AppName+"/"+cmd.PublishingName, connTask.Context, connectInfo)
if err != nil { if err != nil {
delete(receivers, cmd.StreamId) delete(receivers, cmd.StreamId)
err = receiver.Response(cmd.TransactionId, NetStream_Publish_BadName, Level_Error) err = receiver.Response(cmd.TransactionId, NetStream_Publish_BadName, Level_Error)
@@ -164,7 +162,9 @@ func (p *RTMPPlugin) OnTCPConnect(conn *net.TCPConn) {
err = receiver.BeginPublish(cmd.TransactionId) err = receiver.BeginPublish(cmd.TransactionId)
} }
if err != nil { if err != nil {
nc.Error("sendMessage publish", "error", err) connTask.Error("sendMessage publish", "error", err)
} else {
connTask.AddTask(receiver.Publisher)
} }
case *PlayMessage: case *PlayMessage:
streamPath := nc.AppName + "/" + cmd.StreamName streamPath := nc.AppName + "/" + cmd.StreamName
@@ -174,7 +174,7 @@ func (p *RTMPPlugin) OnTCPConnect(conn *net.TCPConn) {
} }
var suber *m7s.Subscriber var suber *m7s.Subscriber
// sender.ID = fmt.Sprintf("%s|%d", conn.RemoteAddr().String(), sender.StreamID) // sender.ID = fmt.Sprintf("%s|%d", conn.RemoteAddr().String(), sender.StreamID)
suber, err = p.Subscribe(streamPath, nc.Context, connectInfo) suber, err = p.Subscribe(streamPath, connTask.Context, connectInfo)
if err != nil { if err != nil {
err = ns.Response(cmd.TransactionId, NetStream_Play_Failed, Level_Error) err = ns.Response(cmd.TransactionId, NetStream_Play_Failed, Level_Error)
} else { } else {
@@ -183,7 +183,9 @@ func (p *RTMPPlugin) OnTCPConnect(conn *net.TCPConn) {
go m7s.PlayBlock(suber, audio.HandleAudio, video.HandleVideo) go m7s.PlayBlock(suber, audio.HandleAudio, video.HandleVideo)
} }
if err != nil { if err != nil {
nc.Error("sendMessage play", "error", err) connTask.Error("sendMessage play", "error", err)
} else {
connTask.AddTask(suber)
} }
} }
case RTMP_MSG_AUDIO: case RTMP_MSG_AUDIO:
@@ -191,20 +193,20 @@ func (p *RTMPPlugin) OnTCPConnect(conn *net.TCPConn) {
err = r.WriteAudio(msg.AVData.WrapAudio()) err = r.WriteAudio(msg.AVData.WrapAudio())
} else { } else {
msg.AVData.Recycle() msg.AVData.Recycle()
nc.Warn("ReceiveAudio", "MessageStreamID", msg.MessageStreamID) connTask.Warn("ReceiveAudio", "MessageStreamID", msg.MessageStreamID)
} }
case RTMP_MSG_VIDEO: case RTMP_MSG_VIDEO:
if r, ok := receivers[msg.MessageStreamID]; ok && r.PubVideo { if r, ok := receivers[msg.MessageStreamID]; ok && r.PubVideo {
err = r.WriteVideo(msg.AVData.WrapVideo()) err = r.WriteVideo(msg.AVData.WrapVideo())
} else { } else {
msg.AVData.Recycle() msg.AVData.Recycle()
nc.Warn("ReceiveVideo", "MessageStreamID", msg.MessageStreamID) connTask.Warn("ReceiveVideo", "MessageStreamID", msg.MessageStreamID)
} }
} }
} else if err == io.EOF || errors.Is(err, io.ErrUnexpectedEOF) { } else if err == io.EOF || errors.Is(err, io.ErrUnexpectedEOF) {
nc.Info("rtmp client closed") connTask.Info("rtmp client closed")
} else { } else {
nc.Warn("ReadMessage", "error", err) connTask.Warn("ReadMessage", "error", err)
} }
} }
} }

View File

@@ -43,7 +43,6 @@ func createClient(c *m7s.Connection) (*NetStream, error) {
} }
ns := &NetStream{} ns := &NetStream{}
ns.NetConnection = NewNetConnection(conn) ns.NetConnection = NewNetConnection(conn)
c.With(ns)
defer func() { defer func() {
if err != nil { if err != nil {
ns.Dispose() ns.Dispose()

View File

@@ -2,7 +2,6 @@ package rtmp
import ( import (
"errors" "errors"
"m7s.live/m7s/v5/pkg"
"net" "net"
"runtime" "runtime"
"sync/atomic" "sync/atomic"
@@ -43,23 +42,19 @@ const (
) )
type NetConnection struct { type NetConnection struct {
pkg.Task
*util.BufReader *util.BufReader
net.Conn net.Conn
bandwidth uint32 bandwidth uint32
readSeqNum uint32 // 当前读的字节 readSeqNum, writeSeqNum uint32 // 当前读的字节
writeSeqNum uint32 // 当前写的字节 totalRead, totalWrite uint32 // 总共读写了多少字节
totalWrite uint32 // 总共写了多少字节 ReadChunkSize, WriteChunkSize int
totalRead uint32 // 总共读了多少字节 incommingChunks map[uint32]*Chunk
WriteChunkSize int ObjectEncoding float64
readChunkSize int AppName string
incommingChunks map[uint32]*Chunk tmpBuf util.Buffer //用来接收/发送小数据,复用内存
ObjectEncoding float64 chunkHeaderBuf util.Buffer
AppName string mediaDataPool util.RecyclableMemory
tmpBuf util.Buffer //用来接收/发送小数据,复用内存 writing atomic.Bool // false 可写true 不可写
chunkHeaderBuf util.Buffer
mediaDataPool util.RecyclableMemory
writing atomic.Bool // false 可写true 不可写
} }
func NewNetConnection(conn net.Conn) (ret *NetConnection) { func NewNetConnection(conn net.Conn) (ret *NetConnection) {
@@ -67,7 +62,7 @@ func NewNetConnection(conn net.Conn) (ret *NetConnection) {
Conn: conn, Conn: conn,
BufReader: util.NewBufReader(conn), BufReader: util.NewBufReader(conn),
WriteChunkSize: RTMP_DEFAULT_CHUNK_SIZE, WriteChunkSize: RTMP_DEFAULT_CHUNK_SIZE,
readChunkSize: RTMP_DEFAULT_CHUNK_SIZE, ReadChunkSize: RTMP_DEFAULT_CHUNK_SIZE,
incommingChunks: make(map[uint32]*Chunk), incommingChunks: make(map[uint32]*Chunk),
bandwidth: RTMP_MAX_CHUNK_SIZE << 3, bandwidth: RTMP_MAX_CHUNK_SIZE << 3,
tmpBuf: make(util.Buffer, 4), tmpBuf: make(util.Buffer, 4),
@@ -150,10 +145,10 @@ func (conn *NetConnection) readChunk() (msg *Chunk, err error) {
return nil, nil return nil, nil
} }
var bufSize = 0 var bufSize = 0
if unRead := msgLen - chunk.bufLen; unRead < conn.readChunkSize { if unRead := msgLen - chunk.bufLen; unRead < conn.ReadChunkSize {
bufSize = unRead bufSize = unRead
} else { } else {
bufSize = conn.readChunkSize bufSize = conn.ReadChunkSize
} }
conn.readSeqNum += uint32(bufSize) conn.readSeqNum += uint32(bufSize)
if chunk.bufLen == 0 { if chunk.bufLen == 0 {
@@ -275,8 +270,7 @@ func (conn *NetConnection) RecvMessage() (msg *Chunk, err error) {
if msg, err = conn.readChunk(); msg != nil && err == nil { if msg, err = conn.readChunk(); msg != nil && err == nil {
switch msg.MessageTypeID { switch msg.MessageTypeID {
case RTMP_MSG_CHUNK_SIZE: case RTMP_MSG_CHUNK_SIZE:
conn.readChunkSize = int(msg.MsgData.(Uint32Message)) conn.ReadChunkSize = int(msg.MsgData.(Uint32Message))
conn.Info("msg read chunk size", "readChunkSize", conn.readChunkSize)
case RTMP_MSG_ABORT: case RTMP_MSG_ABORT:
delete(conn.incommingChunks, uint32(msg.MsgData.(Uint32Message))) delete(conn.incommingChunks, uint32(msg.MsgData.(Uint32Message)))
case RTMP_MSG_ACK, RTMP_MSG_EDGE: case RTMP_MSG_ACK, RTMP_MSG_EDGE:

View File

@@ -1,23 +1,21 @@
package rtp package rtp
import ( import (
"testing"
"github.com/pion/webrtc/v3" "github.com/pion/webrtc/v3"
"m7s.live/m7s/v5/pkg" "m7s.live/m7s/v5/pkg"
"m7s.live/m7s/v5/pkg/util" "m7s.live/m7s/v5/pkg/util"
"testing"
) )
func TestRTPH264Ctx_CreateFrame(t *testing.T) { func TestRTPH264Ctx_CreateFrame(t *testing.T) {
var ctx = &H264Ctx{ var ctx = &H264Ctx{}
RTPCtx: RTPCtx{ ctx.RTPCodecParameters = webrtc.RTPCodecParameters{
RTPCodecParameters: webrtc.RTPCodecParameters{ PayloadType: 96,
PayloadType: 96, RTPCodecCapability: webrtc.RTPCodecCapability{
RTPCodecCapability: webrtc.RTPCodecCapability{ MimeType: webrtc.MimeTypeH264,
MimeType: webrtc.MimeTypeH264, ClockRate: 90000,
ClockRate: 90000, SDPFmtpLine: "packetization-mode=1; sprop-parameter-sets=J2QAKaxWgHgCJ+WagICAgQ==,KO48sA==; profile-level-id=640029",
SDPFmtpLine: "packetization-mode=1; sprop-parameter-sets=J2QAKaxWgHgCJ+WagICAgQ==,KO48sA==; profile-level-id=640029",
},
},
}, },
} }
var randStr = util.RandomString(1500) var randStr = util.RandomString(1500)

View File

@@ -3,9 +3,6 @@ package plugin_rtsp
import ( import (
"errors" "errors"
"fmt" "fmt"
"m7s.live/m7s/v5"
"m7s.live/m7s/v5/pkg/util"
. "m7s.live/m7s/v5/plugin/rtsp/pkg"
"maps" "maps"
"net" "net"
"net/http" "net/http"
@@ -13,6 +10,10 @@ import (
"slices" "slices"
"strconv" "strconv"
"strings" "strings"
"m7s.live/m7s/v5"
"m7s.live/m7s/v5/pkg/util"
. "m7s.live/m7s/v5/plugin/rtsp/pkg"
) )
const defaultConfig = m7s.DefaultYaml(`tcp: const defaultConfig = m7s.DefaultYaml(`tcp:
@@ -30,7 +31,7 @@ func (p *RTSPPlugin) GetPullableList() []string {
func (p *RTSPPlugin) OnInit() error { func (p *RTSPPlugin) OnInit() error {
for streamPath, url := range p.GetCommonConf().PullOnStart { for streamPath, url := range p.GetCommonConf().PullOnStart {
go p.PullBlock(streamPath, url) p.Pull(streamPath, url)
} }
return nil return nil
} }

View File

@@ -3,6 +3,7 @@ package plugin_stress
import ( import (
"context" "context"
"fmt" "fmt"
"google.golang.org/protobuf/types/known/emptypb" "google.golang.org/protobuf/types/known/emptypb"
"m7s.live/m7s/v5" "m7s.live/m7s/v5"
gpb "m7s.live/m7s/v5/pb" gpb "m7s.live/m7s/v5/pb"
@@ -20,7 +21,13 @@ func (r *StressPlugin) pull(count int, format, url string, puller m7s.Puller) er
if err != nil { if err != nil {
return err return err
} }
go r.startPull(ctx, puller) ctx.AddCall(func(*pkg.Task) error {
r.pullers.AddUnique(ctx)
ctx.Do(puller)
return nil
}, func(*pkg.Task) {
r.pullers.Remove(ctx)
})
} }
} else if count < i { } else if count < i {
for j := i; j > count; j-- { for j := i; j > count; j-- {
@@ -38,7 +45,13 @@ func (r *StressPlugin) push(count int, streamPath, format, remoteHost string, pu
if err != nil { if err != nil {
return err return err
} }
go r.startPush(ctx, pusher) ctx.AddCall(func(*pkg.Task) error {
r.pushers.AddUnique(ctx)
ctx.Do(pusher)
return nil
}, func(*pkg.Task) {
r.pushers.Remove(ctx)
})
} }
} else if count < i { } else if count < i {
for j := i; j > count; j-- { for j := i; j > count; j-- {
@@ -69,18 +82,6 @@ func (r *StressPlugin) PullHDL(ctx context.Context, req *pb.PullRequest) (res *g
return &gpb.SuccessResponse{}, r.pull(int(req.PullCount), "http://%s", req.RemoteURL, hdl.PullFLV) return &gpb.SuccessResponse{}, r.pull(int(req.PullCount), "http://%s", req.RemoteURL, hdl.PullFLV)
} }
func (r *StressPlugin) startPush(pusher *m7s.PushContext, handler m7s.Pusher) {
r.pushers.AddUnique(pusher)
pusher.Run(handler)
r.pushers.Remove(pusher)
}
func (r *StressPlugin) startPull(puller *m7s.PullContext, handler m7s.Puller) {
r.pullers.AddUnique(puller)
puller.Run(handler)
r.pullers.Remove(puller)
}
func (r *StressPlugin) StopPush(ctx context.Context, req *emptypb.Empty) (res *gpb.SuccessResponse, err error) { func (r *StressPlugin) StopPush(ctx context.Context, req *emptypb.Empty) (res *gpb.SuccessResponse, err error) {
for pusher := range r.pushers.Range { for pusher := range r.pushers.Range {
pusher.Stop(pkg.ErrStopFromAPI) pusher.Stop(pkg.ErrStopFromAPI)

View File

@@ -91,7 +91,7 @@ func (p *Publisher) GetKey() string {
func createPublisher(p *Plugin, streamPath string, options ...any) (publisher *Publisher) { func createPublisher(p *Plugin, streamPath string, options ...any) (publisher *Publisher) {
publisher = &Publisher{Publish: p.config.Publish} publisher = &Publisher{Publish: p.config.Publish}
publisher.ID = p.Server.streamTM.GetID() publisher.ID = p.Server.streamTask.GetID()
publisher.Plugin = p publisher.Plugin = p
publisher.TimeoutTimer = time.NewTimer(p.config.PublishTimeout) publisher.TimeoutTimer = time.NewTimer(p.config.PublishTimeout)
var opt = []any{publisher, p.Logger.With("streamPath", streamPath, "pId", publisher.ID)} var opt = []any{publisher, p.Logger.With("streamPath", streamPath, "pId", publisher.ID)}
@@ -137,7 +137,7 @@ func (p *Publisher) Start() (err error) {
} }
if remoteURL := plugin.GetCommonConf().CheckPush(p.StreamPath); remoteURL != "" { if remoteURL := plugin.GetCommonConf().CheckPush(p.StreamPath); remoteURL != "" {
if plugin.Meta.Pusher != nil { if plugin.Meta.Pusher != nil {
go plugin.PushBlock(p.StreamPath, remoteURL, plugin.Meta.Pusher) plugin.Push(p.StreamPath, remoteURL, plugin.Meta.Pusher)
} }
} }
if filePath := plugin.GetCommonConf().CheckRecord(p.StreamPath); filePath != "" { if filePath := plugin.GetCommonConf().CheckRecord(p.StreamPath); filePath != "" {

View File

@@ -2,13 +2,14 @@ package m7s
import ( import (
"context" "context"
"time"
"m7s.live/m7s/v5/pkg" "m7s.live/m7s/v5/pkg"
"m7s.live/m7s/v5/pkg/config" "m7s.live/m7s/v5/pkg/config"
"time"
) )
type Connection struct { type Connection struct {
pkg.Task pkg.MarcoTask
Plugin *Plugin Plugin *Plugin
StreamPath string // 对应本地流 StreamPath string // 对应本地流
RemoteURL string // 远程服务器地址(用于推拉) RemoteURL string // 远程服务器地址(用于推拉)
@@ -27,7 +28,6 @@ type Puller = func(*PullContext) error
func createPullContext(p *Plugin, streamPath string, url string, options ...any) (pullCtx *PullContext) { func createPullContext(p *Plugin, streamPath string, url string, options ...any) (pullCtx *PullContext) {
pullCtx = &PullContext{Pull: p.config.Pull} pullCtx = &PullContext{Pull: p.config.Pull}
pullCtx.ID = p.Server.pullTM.GetID()
pullCtx.Plugin = p pullCtx.Plugin = p
pullCtx.ConnectProxy = p.config.Pull.Proxy pullCtx.ConnectProxy = p.config.Pull.Proxy
pullCtx.RemoteURL = url pullCtx.RemoteURL = url
@@ -44,8 +44,9 @@ func createPullContext(p *Plugin, streamPath string, url string, options ...any)
pullCtx.PublishOptions = append(pullCtx.PublishOptions, option) pullCtx.PublishOptions = append(pullCtx.PublishOptions, option)
} }
} }
p.Init(ctx, p.Logger.With("pullURL", url, "streamPath", streamPath)) p.InitKeepAlive(ctx, p.Logger.With("pullURL", url, "streamPath", streamPath), pullCtx)
pullCtx.PublishOptions = append(pullCtx.PublishOptions, pullCtx.Context) pullCtx.PublishOptions = append(pullCtx.PublishOptions, pullCtx.Context)
p.Server.pullTask.AddTask(pullCtx)
return return
} }
@@ -60,30 +61,25 @@ func (p *PullContext) GetKey() string {
return p.StreamPath return p.StreamPath
} }
func (p *PullContext) Run(puller Puller) { func (p *PullContext) Do(puller Puller) {
var err error p.AddCall(func(tmpTask *pkg.Task) (err error) {
for p.reconnect(p.RePull) { publishOptions := append([]any{tmpTask.Context}, p.PublishOptions...)
if p.Publisher != nil { if p.Publisher, err = p.Plugin.Publish(p.StreamPath, publishOptions...); err != nil {
if time.Since(p.Publisher.StartTime) < 5*time.Second { p.Error("pull publish failed", "error", err)
return
}
err = puller(p)
if p.reconnect(p.RePull) {
if time.Since(tmpTask.StartTime) < 5*time.Second {
time.Sleep(5 * time.Second) time.Sleep(5 * time.Second)
} }
p.Warn("retry", "count", p.ReConnectCount, "total", p.RePull) p.Warn("retry", "count", p.ReConnectCount, "total", p.RePull)
p.Do(puller)
} else {
p.Stop(pkg.ErrRetryRunOut)
} }
if p.Publisher, err = p.Plugin.Publish(p.StreamPath, p.PublishOptions...); err != nil { return
p.Error("pull publish failed", "error", err) }, nil)
break
}
err = puller(p)
p.Publisher.Stop(err)
if p.IsStopped() {
return
}
p.Error("pull interrupt", "error", err)
}
if err == nil {
err = pkg.ErrRetryRunOut
}
p.Stop(err)
} }
func (p *PullContext) Start() (err error) { func (p *PullContext) Start() (err error) {
@@ -92,6 +88,9 @@ func (p *PullContext) Start() (err error) {
return pkg.ErrStreamExist return pkg.ErrStreamExist
} }
s.Pulls.Add(p) s.Pulls.Add(p)
if p.Plugin.Meta.Puller != nil {
p.Do(p.Plugin.Meta.Puller)
}
return return
} }

View File

@@ -2,9 +2,10 @@ package m7s
import ( import (
"context" "context"
"m7s.live/m7s/v5/pkg"
"time" "time"
"m7s.live/m7s/v5/pkg"
"m7s.live/m7s/v5/pkg/config" "m7s.live/m7s/v5/pkg/config"
) )
@@ -12,7 +13,7 @@ type Pusher = func(*PushContext) error
func createPushContext(p *Plugin, streamPath string, url string, options ...any) (pushCtx *PushContext) { func createPushContext(p *Plugin, streamPath string, url string, options ...any) (pushCtx *PushContext) {
pushCtx = &PushContext{Push: p.config.Push} pushCtx = &PushContext{Push: p.config.Push}
pushCtx.ID = p.Server.pushTM.GetID() pushCtx.ID = p.Server.pushTask.GetID()
pushCtx.Plugin = p pushCtx.Plugin = p
pushCtx.RemoteURL = url pushCtx.RemoteURL = url
pushCtx.StreamPath = streamPath pushCtx.StreamPath = streamPath
@@ -43,31 +44,27 @@ func (p *PushContext) GetKey() string {
return p.RemoteURL return p.RemoteURL
} }
func (p *PushContext) Run(pusher Pusher) { func (p *PushContext) Do(pusher Pusher) {
p.StartTime = time.Now() p.AddCall(func(tmpTask *pkg.Task) (err error) {
defer p.Info("stop push")
var err error
for p.Info("start push", "url", p.Connection.RemoteURL); p.Connection.reconnect(p.RePush); p.Warn("restart push") {
if p.Subscriber != nil && time.Since(p.Subscriber.StartTime) < 5*time.Second { if p.Subscriber != nil && time.Since(p.Subscriber.StartTime) < 5*time.Second {
time.Sleep(5 * time.Second) time.Sleep(5 * time.Second)
} }
if p.Subscriber, err = p.Plugin.Subscribe(p.StreamPath, p.SubscribeOptions...); err != nil { if p.Subscriber, err = p.Plugin.Subscribe(p.StreamPath, p.SubscribeOptions...); err != nil {
p.Error("push subscribe failed", "error", err) p.Error("push subscribe failed", "error", err)
break return
} }
err = pusher(p) err = pusher(p)
p.Subscriber.Stop(err) if p.Connection.reconnect(p.RePush) {
if p.IsStopped() { if time.Since(tmpTask.StartTime) < 5*time.Second {
return time.Sleep(5 * time.Second)
}
p.Warn("retry", "count", p.ReConnectCount, "total", p.RePush)
p.Do(pusher)
} else { } else {
p.Error("push interrupt", "error", err) p.Stop(pkg.ErrRetryRunOut)
} }
} return
if err == nil { }, nil)
err = pkg.ErrRetryRunOut
}
p.Stop(err)
return
} }
func (p *PushContext) Start() (err error) { func (p *PushContext) Start() (err error) {
@@ -76,6 +73,9 @@ func (p *PushContext) Start() (err error) {
return pkg.ErrPushRemoteURLExist return pkg.ErrPushRemoteURLExist
} }
s.Pushs.Add(p) s.Pushs.Add(p)
if p.Plugin.Meta.Pusher != nil {
p.Do(p.Plugin.Meta.Pusher)
}
return return
} }

View File

@@ -15,7 +15,7 @@ func createRecoder(p *Plugin, streamPath string, filePath string, options ...any
Append: p.config.Record.Append, Append: p.config.Record.Append,
FilePath: filePath, FilePath: filePath,
} }
recorder.ID = p.Server.recordTM.GetID() recorder.ID = p.Server.recordTask.GetID()
recorder.FilePath = filePath recorder.FilePath = filePath
recorder.SubscribeOptions = []any{p.config.Subscribe} recorder.SubscribeOptions = []any{p.config.Subscribe}
var ctx = p.Context var ctx = p.Context

View File

@@ -36,7 +36,7 @@ var (
Version: Version, Version: Version,
} }
Servers util.Collection[uint32, *Server] Servers util.Collection[uint32, *Server]
serverTM = NewTaskManager() globalTask MarcoTask
Routes = map[string]string{} Routes = map[string]string{}
defaultLogHandler = console.NewHandler(os.Stdout, &console.HandlerOptions{TimeFormat: "15:04:05.000000"}) defaultLogHandler = console.NewHandler(os.Stdout, &console.HandlerOptions{TimeFormat: "15:04:05.000000"})
) )
@@ -53,27 +53,26 @@ type Server struct {
pb.UnimplementedGlobalServer pb.UnimplementedGlobalServer
Plugin Plugin
ServerConfig ServerConfig
//eventChan chan any Plugins util.Collection[string, *Plugin]
Plugins util.Collection[string, *Plugin] Streams, Waiting util.Collection[string, *Publisher]
Streams, Waiting util.Collection[string, *Publisher] Pulls util.Collection[string, *PullContext]
Pulls util.Collection[string, *PullContext] Pushs util.Collection[string, *PushContext]
Pushs util.Collection[string, *PushContext] Records util.Collection[string, *RecordContext]
Records util.Collection[string, *RecordContext] Subscribers SubscriberCollection
Subscribers SubscriberCollection LogHandler MultiLogHandler
LogHandler MultiLogHandler apiList []string
apiList []string grpcServer *grpc.Server
grpcServer *grpc.Server grpcClientConn *grpc.ClientConn
grpcClientConn *grpc.ClientConn tcplis net.Listener
tcplis net.Listener lastSummaryTime time.Time
lastSummaryTime time.Time lastSummary *pb.SummaryResponse
lastSummary *pb.SummaryResponse streamTask, pullTask, pushTask, recordTask MarcoTask
pluginTM, streamTM, pullTM, pushTM, recordTM *TaskManager conf any
conf any
} }
func NewServer() (s *Server) { func NewServer() (s *Server) {
s = &Server{} s = &Server{}
s.ID = serverTM.GetID() s.ID = globalTask.GetID()
s.Meta = &serverMeta s.Meta = &serverMeta
return return
} }
@@ -87,12 +86,16 @@ type rawconfig = map[string]map[string]any
func init() { func init() {
signalChan := make(chan os.Signal, 1) signalChan := make(chan os.Signal, 1)
signal.Notify(signalChan, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT) signal.Notify(signalChan, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT)
go serverTM.Run(signalChan, func(os.Signal) { globalTask.InitKeepAlive(context.Background(), nil, nil, signalChan, func(os.Signal) {
for _, meta := range plugins { for _, meta := range plugins {
if meta.OnExit != nil { if meta.OnExit != nil {
meta.OnExit() meta.OnExit()
} }
} }
if serverMeta.OnExit != nil {
serverMeta.OnExit()
}
os.Exit(0)
}) })
for k, v := range myip.LocalAndInternalIPs() { for k, v := range myip.LocalAndInternalIPs() {
Routes[k] = v Routes[k] = v
@@ -116,7 +119,7 @@ func (s *Server) Init(ctx context.Context, conf any) {
s.config.TCP.ListenAddr = ":50051" s.config.TCP.ListenAddr = ":50051"
s.LogHandler.SetLevel(slog.LevelInfo) s.LogHandler.SetLevel(slog.LevelInfo)
s.LogHandler.Add(defaultLogHandler) s.LogHandler.Add(defaultLogHandler)
s.Task.Init(ctx, slog.New(&s.LogHandler).With("Server", s.ID)) s.MarcoTask.Init(ctx, slog.New(&s.LogHandler).With("Server", s.ID), s)
} }
func (s *Server) Start() (err error) { func (s *Server) Start() (err error) {
@@ -204,10 +207,9 @@ func (s *Server) Start() (err error) {
return return
} }
} }
s.pluginTM = StartTaskManager()
for _, plugin := range plugins { for _, plugin := range plugins {
if p := plugin.Init(s, cg[strings.ToLower(plugin.Name)]); !p.Disabled { if p := plugin.Init(s, cg[strings.ToLower(plugin.Name)]); !p.Disabled {
s.pluginTM.Start(&p.Task) s.WaitTaskAdded(p)
} }
} }
if s.tcplis != nil { if s.tcplis != nil {
@@ -219,7 +221,7 @@ func (s *Server) Start() (err error) {
} }
}(tcpConf.ListenAddr) }(tcpConf.ListenAddr)
} }
s.streamTM = StartTaskManager(time.NewTicker(s.PulseInterval).C, func(time.Time) { s.streamTask.InitKeepAlive(s.Context, nil, nil, time.NewTicker(s.PulseInterval).C, func(time.Time) {
for publisher := range s.Streams.Range { for publisher := range s.Streams.Range {
if err := publisher.checkTimeout(); err != nil { if err := publisher.checkTimeout(); err != nil {
publisher.Stop(err) publisher.Stop(err)
@@ -242,15 +244,16 @@ func (s *Server) Start() (err error) {
} }
} }
}) })
s.pullTM = StartTaskManager() s.pullTask.InitKeepAlive(s.Context, nil, nil)
s.pushTM = StartTaskManager() s.pushTask.InitKeepAlive(s.Context, nil, nil)
s.recordTM = StartTaskManager() s.recordTask.InitKeepAlive(s.Context, nil, nil)
s.AddTasks(&s.streamTask, &s.pullTask, &s.pushTask, &s.recordTask)
Servers.Add(s) Servers.Add(s)
return return
} }
func (s *Server) Call(callback func()) { func (s *Server) CallOnStreamTask(callback func(*Task) error) {
s.streamTM.Call(callback) s.streamTask.Call(callback)
} }
func (s *Server) Dispose() { func (s *Server) Dispose() {
@@ -258,22 +261,15 @@ func (s *Server) Dispose() {
_ = s.tcplis.Close() _ = s.tcplis.Close()
_ = s.grpcClientConn.Close() _ = s.grpcClientConn.Close()
s.config.HTTP.StopListen() s.config.HTTP.StopListen()
err := context.Cause(s)
s.streamTM.ShutDown(err)
s.pullTM.ShutDown(err)
s.pushTM.ShutDown(err)
s.recordTM.ShutDown(err)
s.pluginTM.ShutDown(err)
} }
func (s *Server) Run(ctx context.Context, conf any) (err error) { func (s *Server) Run(ctx context.Context, conf any) (err error) {
for { for {
s.Init(ctx, conf) s.Init(ctx, conf)
if err = serverTM.Start(s); err != nil { if err = globalTask.WaitTaskAdded(s); err != nil {
return return
} }
<-s.Done() if err = s.WaitStopped(); err != ErrRestart {
if err = context.Cause(s); err != ErrRestart {
return return
} }
var server Server var server Server

View File

@@ -23,26 +23,25 @@ type PubSubBase struct {
StreamPath string StreamPath string
Args url.Values Args url.Values
TimeoutTimer *time.Timer TimeoutTimer *time.Timer
MetaData any
} }
func (ps *PubSubBase) Init(streamPath string, conf any, options ...any) { func (ps *PubSubBase) Init(streamPath string, conf any, options ...any) {
ctx := ps.Plugin.Context ctx := ps.Plugin.Context
var logger *slog.Logger var logger *slog.Logger
var executor []TaskExecutor var executor TaskExecutor
for _, option := range options { for _, option := range options {
switch v := option.(type) { switch v := option.(type) {
case TaskExecutor: case TaskExecutor:
executor = append(executor, v) executor = v
case *slog.Logger: case *slog.Logger:
logger = v logger = v
case context.Context: case context.Context:
ctx = v ctx = v
default: case map[string]any:
ps.MetaData = v ps.Description = v
} }
} }
ps.Task.Init(ctx, logger, executor...) ps.Task.Init(ctx, logger, executor)
if u, err := url.Parse(streamPath); err == nil { if u, err := url.Parse(streamPath); err == nil {
ps.StreamPath, ps.Args = u.Path, u.Query() ps.StreamPath, ps.Args = u.Path, u.Query()
} }
@@ -77,7 +76,7 @@ type Subscriber struct {
func createSubscriber(p *Plugin, streamPath string, options ...any) *Subscriber { func createSubscriber(p *Plugin, streamPath string, options ...any) *Subscriber {
subscriber := &Subscriber{Subscribe: p.config.Subscribe} subscriber := &Subscriber{Subscribe: p.config.Subscribe}
subscriber.ID = p.Server.streamTM.GetID() subscriber.ID = p.Server.streamTask.GetID()
subscriber.Plugin = p subscriber.Plugin = p
subscriber.TimeoutTimer = time.NewTimer(subscriber.WaitTimeout) subscriber.TimeoutTimer = time.NewTimer(subscriber.WaitTimeout)
var opt = []any{subscriber, p.Logger.With("streamPath", streamPath, "sId", subscriber.ID)} var opt = []any{subscriber, p.Logger.With("streamPath", streamPath, "sId", subscriber.ID)}
@@ -109,7 +108,7 @@ func (s *Subscriber) Start() (err error) {
for plugin := range server.Plugins.Range { for plugin := range server.Plugins.Range {
if remoteURL := plugin.GetCommonConf().Pull.CheckPullOnSub(s.StreamPath); remoteURL != "" { if remoteURL := plugin.GetCommonConf().Pull.CheckPullOnSub(s.StreamPath); remoteURL != "" {
if plugin.Meta.Puller != nil { if plugin.Meta.Puller != nil {
go plugin.PullBlock(s.StreamPath, remoteURL, plugin.Meta.Puller) plugin.Pull(s.StreamPath, remoteURL, plugin.Meta.Puller)
} }
} }
} }