refactor: marotask

This commit is contained in:
langhuihui
2024-08-10 08:25:27 +08:00
parent 14a5184477
commit da0066fcc8
21 changed files with 583 additions and 577 deletions

5
.gitignore vendored
View File

@@ -3,3 +3,8 @@
logs
fatal
.idea
victoria-logs-data
dump
record
bin
.DS_Store

30
api.go
View File

@@ -90,7 +90,7 @@ func (s *Server) api_Stream_AnnexB_(rw http.ResponseWriter, r *http.Request) {
}
func (s *Server) getStreamInfo(pub *Publisher) (res *pb.StreamInfoResponse, err error) {
tmp, _ := json.Marshal(pub.MetaData)
tmp, _ := json.Marshal(pub.Description)
res = &pb.StreamInfoResponse{
Meta: string(tmp),
Path: pub.StreamPath,
@@ -131,20 +131,21 @@ func (s *Server) getStreamInfo(pub *Publisher) (res *pb.StreamInfoResponse, err
}
func (s *Server) StreamInfo(ctx context.Context, req *pb.StreamSnapRequest) (res *pb.StreamInfoResponse, err error) {
s.streamTM.Call(func() {
s.streamTask.Call(func(*pkg.Task) error {
if pub, ok := s.Streams.Get(req.StreamPath); ok {
res, err = s.getStreamInfo(pub)
} else {
err = pkg.ErrNotFound
}
return nil
})
return
}
func (s *Server) GetSubscribers(ctx context.Context, req *pb.SubscribersRequest) (res *pb.SubscribersResponse, err error) {
s.streamTM.Call(func() {
s.streamTask.Call(func(*pkg.Task) error {
var subscribers []*pb.SubscriberSnapShot
for subscriber := range s.Subscribers.Range {
meta, _ := json.Marshal(subscriber.MetaData)
meta, _ := json.Marshal(subscriber.Description)
snap := &pb.SubscriberSnapShot{
Id: uint32(subscriber.ID),
StartTime: timestamppb.New(subscriber.StartTime),
@@ -172,11 +173,12 @@ func (s *Server) GetSubscribers(ctx context.Context, req *pb.SubscribersRequest)
List: subscribers,
Total: int32(s.Subscribers.Length),
}
return nil
})
return
}
func (s *Server) AudioTrackSnap(ctx context.Context, req *pb.StreamSnapRequest) (res *pb.TrackSnapShotResponse, err error) {
s.streamTM.Call(func() {
s.streamTask.Call(func(*pkg.Task) error {
if pub, ok := s.Streams.Get(req.StreamPath); ok && pub.HasAudioTrack() {
res = &pb.TrackSnapShotResponse{}
for _, memlist := range pub.AudioTrack.Allocator.GetChildren() {
@@ -221,6 +223,7 @@ func (s *Server) AudioTrackSnap(ctx context.Context, req *pb.StreamSnapRequest)
} else {
err = pkg.ErrNotFound
}
return nil
})
return
}
@@ -254,7 +257,7 @@ func (s *Server) api_VideoTrack_SSE(rw http.ResponseWriter, r *http.Request) {
}
func (s *Server) VideoTrackSnap(ctx context.Context, req *pb.StreamSnapRequest) (res *pb.TrackSnapShotResponse, err error) {
s.streamTM.Call(func() {
s.streamTask.Call(func(*pkg.Task) error {
if pub, ok := s.Streams.Get(req.StreamPath); ok && pub.HasVideoTrack() {
res = &pb.TrackSnapShotResponse{}
for _, memlist := range pub.VideoTrack.Allocator.GetChildren() {
@@ -299,6 +302,7 @@ func (s *Server) VideoTrackSnap(ctx context.Context, req *pb.StreamSnapRequest)
} else {
err = pkg.ErrNotFound
}
return nil
})
return
}
@@ -320,34 +324,36 @@ func (s *Server) Shutdown(ctx context.Context, req *pb.RequestWithId) (res *empt
}
func (s *Server) ChangeSubscribe(ctx context.Context, req *pb.ChangeSubscribeRequest) (res *pb.SuccessResponse, err error) {
s.streamTM.Call(func() {
s.streamTask.Call(func(*pkg.Task) error {
if subscriber, ok := s.Subscribers.Get(req.Id); ok {
if pub, ok := s.Streams.Get(req.StreamPath); ok {
subscriber.Publisher.RemoveSubscriber(subscriber)
subscriber.StreamPath = req.StreamPath
pub.AddSubscriber(subscriber)
return
return nil
}
}
err = pkg.ErrNotFound
return nil
})
return &pb.SuccessResponse{}, err
}
func (s *Server) StopSubscribe(ctx context.Context, req *pb.RequestWithId) (res *pb.SuccessResponse, err error) {
s.streamTM.Call(func() {
s.streamTask.Call(func(*pkg.Task) error {
if subscriber, ok := s.Subscribers.Get(req.Id); ok {
subscriber.Stop(errors.New("stop by api"))
} else {
err = pkg.ErrNotFound
}
return nil
})
return &pb.SuccessResponse{}, err
}
// /api/stream/list
func (s *Server) StreamList(_ context.Context, req *pb.StreamListRequest) (res *pb.StreamListResponse, err error) {
s.streamTM.Call(func() {
s.streamTask.Call(func(*pkg.Task) error {
var streams []*pb.StreamInfoResponse
for publisher := range s.Streams.Range {
info, err := s.getStreamInfo(publisher)
@@ -357,18 +363,20 @@ func (s *Server) StreamList(_ context.Context, req *pb.StreamListRequest) (res *
streams = append(streams, info)
}
res = &pb.StreamListResponse{List: streams, Total: int32(s.Streams.Length), PageNum: req.PageNum, PageSize: req.PageSize}
return nil
})
return
}
func (s *Server) WaitList(context.Context, *emptypb.Empty) (res *pb.StreamWaitListResponse, err error) {
s.streamTM.Call(func() {
s.streamTask.Call(func(*pkg.Task) error {
res = &pb.StreamWaitListResponse{
List: make(map[string]int32),
}
for subs := range s.Waiting.Range {
res.List[subs.StreamPath] = int32(subs.Subscribers.Length)
}
return nil
})
return
}

View File

@@ -2,17 +2,23 @@ package pkg
import (
"context"
"io"
"errors"
"log/slog"
"m7s.live/m7s/v5/pkg/util"
"reflect"
"slices"
"sync/atomic"
"time"
"m7s.live/m7s/v5/pkg/util"
)
const TraceLevel = slog.Level(-8)
var (
ErrAutoStop = errors.New("auto stop")
ErrCallbackTask = errors.New("callback task")
)
type getTask interface{ GetTask() *Task }
type TaskExecutor interface {
Start() error
@@ -25,12 +31,17 @@ type TempTaskExecutor struct {
}
func (t TempTaskExecutor) Start() error {
if t.StartFunc == nil {
return nil
}
return t.StartFunc()
}
func (t TempTaskExecutor) Dispose() {
if t.DisposeFunc != nil {
t.DisposeFunc()
}
}
type Task struct {
ID uint32
@@ -38,9 +49,10 @@ type Task struct {
*slog.Logger
context.Context
context.CancelCauseFunc
exeStack []TaskExecutor
exe TaskExecutor
Description map[string]any
started *util.Promise
startup, shutdown *util.Promise
parent *MarcoTask
}
func (task *Task) GetTask() *Task {
@@ -51,29 +63,12 @@ func (task *Task) GetKey() uint32 {
return task.ID
}
func (task *Task) Begin() (err error) {
task.StartTime = time.Now()
for _, executor := range task.exeStack {
err = executor.Start()
if err != nil {
break
}
}
task.started.Fulfill(err)
return
}
func (task *Task) dispose(reason error) {
if task.Logger != nil {
task.Debug("stop", "reason", reason)
}
for _, executor := range slices.Backward(task.exeStack) {
executor.Dispose()
}
}
func (task *Task) WaitStarted() error {
return task.started.Await()
return task.startup.Await()
}
func (task *Task) WaitStopped() error {
return task.shutdown.Await()
}
func (task *Task) Trace(msg string, fields ...any) {
@@ -95,88 +90,154 @@ func (task *Task) Stop(err error) {
}
}
func (task *Task) With(child getTask, args ...any) {
child.GetTask().Init(task.Context, task.Logger.With(args...))
}
func (task *Task) Init(ctx context.Context, logger *slog.Logger, executor ...TaskExecutor) {
func (task *Task) Init(ctx context.Context, logger *slog.Logger, executor TaskExecutor) {
task.Logger = logger
task.exeStack = executor
task.exe = executor
task.Context, task.CancelCauseFunc = context.WithCancelCause(ctx)
task.started = util.NewPromise(task.Context)
task.startup = util.NewPromise(task.Context)
task.shutdown = util.NewPromise(task.Context)
}
type CallBackTaskExecutor func()
type CallBack func(*Task) error
func (call CallBackTaskExecutor) Start() error {
call()
return io.EOF
}
func (call CallBackTaskExecutor) Dispose() {
// nothing to do, never called
}
type TaskManager struct {
shutdown *util.Promise
stopReason error
start chan *Task
Tasks []*Task
// MarcoTask include sub tasks
type MarcoTask struct {
Task
KeepAlive bool
exe TaskExecutor
addSub chan *Task
subTasks []*Task
extraCases []reflect.SelectCase
extraCallbacks []reflect.Value
idG atomic.Uint32
}
func NewTaskManager() *TaskManager {
return &TaskManager{
shutdown: util.NewPromise(context.TODO()),
start: make(chan *Task, 10),
func (mt *MarcoTask) Init(ctx context.Context, logger *slog.Logger, executor TaskExecutor, extra ...any) {
mt.Task.Init(ctx, logger, mt)
mt.exe = executor
for i := range len(extra) / 2 {
mt.extraCases = append(mt.extraCases, reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(extra[i*2])})
mt.extraCallbacks = append(mt.extraCallbacks, reflect.ValueOf(extra[i*2+1]))
}
}
func StartTaskManager(extra ...any) *TaskManager {
tm := NewTaskManager()
go tm.Run(extra...)
return tm
func (mt *MarcoTask) InitKeepAlive(ctx context.Context, logger *slog.Logger, executor TaskExecutor, extra ...any) {
mt.Init(ctx, logger, executor, extra...)
mt.KeepAlive = true
}
func (t *TaskManager) Add(getTask getTask) {
func (mt *MarcoTask) Start() error {
if mt.exe != nil {
return mt.exe.Start()
}
return nil
}
func (mt *MarcoTask) AddTasks(getTasks ...getTask) {
for _, getTask := range getTasks {
mt.AddTask(getTask)
}
}
func (mt *MarcoTask) AddTask(getTask getTask) {
if mt.IsStopped() {
getTask.GetTask().startup.Reject(mt.StopReason())
return
}
if mt.addSub == nil {
mt.addSub = make(chan *Task, 10)
go mt.run()
}
task := getTask.GetTask()
if v, ok := getTask.(TaskExecutor); len(task.exeStack) == 0 && ok {
task.exeStack = append(task.exeStack, v)
if task.ID == 0 {
task.ID = mt.GetID()
}
t.start <- task
if task.parent == nil {
task.parent = mt
if v, ok := getTask.(TaskExecutor); ok {
task.exe = v
}
}
mt.addSub <- task
}
func (t *TaskManager) Call(callback CallBackTaskExecutor) error {
func (mt *MarcoTask) Call(callback CallBack) {
task := mt.AddCall(callback, nil)
_ = task.WaitStarted()
}
func (mt *MarcoTask) AddCall(start CallBack, dispose func(*Task)) *Task {
var tmpTask Task
tmpTask.Init(context.TODO(), nil, callback)
return t.Start(&tmpTask)
var tmpExe TempTaskExecutor
if start != nil {
tmpExe.StartFunc = func() error {
err := start(&tmpTask)
if err == nil && dispose == nil {
err = ErrCallbackTask
}
return err
}
}
if dispose != nil {
tmpExe.DisposeFunc = func() {
dispose(&tmpTask)
}
}
tmpTask.Init(mt.Context, nil, tmpExe)
mt.AddTask(&tmpTask)
return &tmpTask
}
func (t *TaskManager) Start(getTask getTask) error {
t.Add(getTask)
func (mt *MarcoTask) WaitTaskAdded(getTask getTask) error {
mt.AddTask(getTask)
return getTask.GetTask().WaitStarted()
}
func (t *TaskManager) GetID() uint32 {
return t.idG.Add(1)
func (mt *MarcoTask) GetID() uint32 {
return mt.idG.Add(1)
}
// Run task Start and Dispose in this goroutine
func (t *TaskManager) Run(extra ...any) {
cases := []reflect.SelectCase{{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(t.start)}}
extraLen := len(extra) / 2
var callbacks []reflect.Value
for i := range extraLen {
cases = append(cases, reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(extra[i*2])})
callbacks = append(callbacks, reflect.ValueOf(extra[i*2+1]))
func (mt *MarcoTask) startSubTask(task *Task) (err error) {
if task.startup.IsPending() {
task.StartTime = time.Now()
err = task.exe.Start()
if task.Logger != nil {
task.Debug("start")
}
task.startup.Fulfill(err)
}
return
}
func (mt *MarcoTask) disposeSubTask(task *Task, reason error) {
if task.parent != mt {
return
}
if task.Logger != nil {
task.Debug("dispose", "reason", reason)
}
task.exe.Dispose()
if m, ok := task.exe.(*MarcoTask); ok {
m.WaitStopped()
} else {
task.shutdown.Fulfill(reason)
}
}
func (mt *MarcoTask) run() {
extraLen := len(mt.extraCases)
cases := append([]reflect.SelectCase{{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(mt.addSub)}}, mt.extraCases...)
defer func() {
for _, task := range t.Tasks {
task.dispose(t.stopReason)
stopReason := mt.StopReason()
for _, task := range mt.subTasks {
task.Stop(stopReason)
mt.disposeSubTask(task, stopReason)
}
t.Tasks = nil
cases = nil
t.shutdown.Fulfill(t.stopReason)
mt.subTasks = nil
mt.addSub = nil
mt.extraCases = nil
mt.extraCallbacks = nil
mt.shutdown.Fulfill(stopReason)
}()
for {
if chosen, rev, ok := reflect.Select(cases); chosen == 0 {
@@ -184,11 +245,8 @@ func (t *TaskManager) Run(extra ...any) {
return
}
task := rev.Interface().(*Task)
if err := task.Begin(); err == nil {
if task.Logger != nil {
task.Debug("start")
}
t.Tasks = append(t.Tasks, task)
if err := mt.startSubTask(task); err == nil {
mt.subTasks = append(mt.subTasks, task)
cases = append(cases, reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(task.Done())})
} else {
if task.Logger != nil {
@@ -197,53 +255,29 @@ func (t *TaskManager) Run(extra ...any) {
task.Stop(err)
}
} else if chosen <= extraLen {
callbacks[chosen-1].Call([]reflect.Value{rev})
mt.extraCallbacks[chosen-1].Call([]reflect.Value{rev})
} else {
taskIndex := chosen - extraLen - 1
task := t.Tasks[taskIndex]
task.dispose(task.StopReason())
t.Tasks = slices.Delete(t.Tasks, taskIndex, taskIndex+1)
task := mt.subTasks[taskIndex]
mt.disposeSubTask(task, task.StopReason())
mt.subTasks = slices.Delete(mt.subTasks, taskIndex, taskIndex+1)
cases = slices.Delete(cases, chosen, chosen+1)
if !mt.KeepAlive && len(mt.subTasks) == 0 {
mt.Stop(ErrAutoStop)
}
}
}
}
// Run task Start and Dispose in another goroutine
//func (t *TaskManager) Run() {
// cases := []reflect.SelectCase{{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(t.Start)}}
// defer func() {
// cases = slices.Delete(cases, 0, 1)
// for len(cases) > 0 {
// chosen, _, _ := reflect.Select(cases)
// t.Done <- t.Tasks[chosen]
// t.Tasks = slices.Delete(t.Tasks, chosen, chosen+1)
// cases = slices.Delete(cases, chosen, chosen+1)
// }
// close(t.Done)
// }()
// for {
// if chosen, rev, ok := reflect.Select(cases); chosen == 0 {
// if !ok {
// return
// }
// task := rev.Interface().(*Task)
// t.Tasks = append(t.Tasks, task)
// cases = append(cases, reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(task.Done())})
// } else {
// t.Done <- t.Tasks[chosen-1]
// t.Tasks = slices.Delete(t.Tasks, chosen-1, chosen)
// cases = slices.Delete(cases, chosen, chosen+1)
// }
// }
//}
// ShutDown wait all task dispose
func (t *TaskManager) ShutDown(err error) {
t.Stop(err)
_ = t.shutdown.Await()
func (mt *MarcoTask) ShutDown(err error) {
mt.Stop(err)
_ = mt.shutdown.Await()
}
func (t *TaskManager) Stop(err error) {
t.stopReason = err
close(t.start)
func (mt *MarcoTask) Dispose() {
if mt.exe != nil {
mt.exe.Dispose()
}
close(mt.addSub)
}

View File

@@ -2,16 +2,7 @@ package m7s
import (
"context"
gatewayRuntime "github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
myip "github.com/husanpao/ip"
"google.golang.org/grpc"
"gopkg.in/yaml.v3"
"gorm.io/gorm"
"log/slog"
. "m7s.live/m7s/v5/pkg"
"m7s.live/m7s/v5/pkg/config"
"m7s.live/m7s/v5/pkg/db"
"m7s.live/m7s/v5/pkg/util"
"net"
"net/http"
"os"
@@ -19,6 +10,16 @@ import (
"reflect"
"runtime"
"strings"
gatewayRuntime "github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
myip "github.com/husanpao/ip"
"google.golang.org/grpc"
"gopkg.in/yaml.v3"
"gorm.io/gorm"
. "m7s.live/m7s/v5/pkg"
"m7s.live/m7s/v5/pkg/config"
"m7s.live/m7s/v5/pkg/db"
"m7s.live/m7s/v5/pkg/util"
)
type DefaultYaml string
@@ -51,7 +52,7 @@ func (plugin *PluginMeta) Init(s *Server, userConfig map[string]any) (p *Plugin)
p.handler = instance
p.Meta = plugin
p.Server = s
p.Task.Init(s.Context, s.Logger.With("plugin", plugin.Name), instance)
p.MarcoTask.Init(s.Context, s.Logger.With("plugin", plugin.Name), instance)
upperName := strings.ToUpper(plugin.Name)
if os.Getenv(upperName+"_ENABLE") == "false" {
p.Disabled = true
@@ -180,7 +181,7 @@ func InstallPlugin[C iPlugin](options ...any) error {
}
type Plugin struct {
Task
MarcoTask
Disabled bool
Meta *PluginMeta
config config.Common
@@ -352,7 +353,7 @@ func (p *Plugin) Publish(streamPath string, options ...any) (publisher *Publishe
}
}
}
err = p.Server.streamTM.Start(publisher)
err = p.Server.streamTask.WaitTaskAdded(publisher)
return
}
@@ -370,48 +371,20 @@ func (p *Plugin) Subscribe(streamPath string, options ...any) (subscriber *Subsc
}
}
}
err = p.Server.streamTM.Start(subscriber)
err = p.Server.streamTask.WaitTaskAdded(subscriber)
err = subscriber.Publisher.WaitTrack()
return
}
func (p *Plugin) pull(streamPath string, url string, options ...any) (ctx *PullContext, err error) {
ctx = createPullContext(p, streamPath, url, options...)
err = p.Server.pullTM.Start(ctx)
return
}
func (p *Plugin) Pull(streamPath string, url string, options ...any) (ctx *PullContext, err error) {
if ctx, err = p.pull(streamPath, url, options...); err == nil && p.Meta.Puller != nil {
go p.Meta.Puller(ctx)
}
return
}
func (p *Plugin) PullBlock(streamPath string, url string, options ...any) (ctx *PullContext, err error) {
if ctx, err = p.pull(streamPath, url, options...); err == nil && p.Meta.Puller != nil {
err = p.Meta.Puller(ctx)
}
return
}
func (p *Plugin) push(streamPath string, url string, options ...any) (ctx *PushContext, err error) {
ctx = createPushContext(p, streamPath, url, options...)
err = p.Server.pushTM.Start(ctx)
ctx = createPullContext(p, streamPath, url, options...)
err = p.Server.pullTask.WaitTaskAdded(ctx);
return
}
func (p *Plugin) Push(streamPath string, url string, options ...any) (ctx *PushContext, err error) {
if ctx, err = p.push(streamPath, url, options...); err == nil && p.Meta.Pusher != nil {
go p.Meta.Pusher(ctx)
}
return
}
func (p *Plugin) PushBlock(streamPath string, url string, options ...any) (ctx *PushContext, err error) {
if ctx, err = p.push(streamPath, url, options...); err == nil && p.Meta.Pusher != nil {
err = p.Meta.Pusher(ctx)
}
ctx = createPushContext(p, streamPath, url, options...)
err = p.Server.pushTask.WaitTaskAdded(ctx)
return
}
@@ -428,7 +401,7 @@ func (p *Plugin) record(streamPath string, filePath string, options ...any) (ctx
if err != nil {
return
}
err = p.Server.recordTM.Start(ctx)
err = p.Server.recordTask.WaitTaskAdded(ctx)
return
}
@@ -501,7 +474,7 @@ func (p *Plugin) AddLogHandler(handler slog.Handler) {
}
func (p *Plugin) SaveConfig() (err error) {
p.Server.pluginTM.Call(func() {
p.Server.Call(func(*Task) (err error) {
if p.Modify == nil {
os.Remove(p.settingPath())
return
@@ -512,6 +485,7 @@ func (p *Plugin) SaveConfig() (err error) {
}
defer file.Close()
err = yaml.NewEncoder(file).Encode(p.Modify)
return
})
if err == nil {
p.Info("config saved")

View File

@@ -22,7 +22,7 @@ const defaultConfig m7s.DefaultYaml = `publish:
func (p *FLVPlugin) OnInit() error {
for streamPath, url := range p.GetCommonConf().PullOnStart {
go p.PullBlock(streamPath, url, PullFLV)
p.Pull(streamPath, url, PullFLV)
}
return nil
}

View File

@@ -1,20 +1,21 @@
package plugin_mp4
import (
"github.com/Eyevinn/mp4ff/mp4"
"io"
"m7s.live/m7s/v5"
v5 "m7s.live/m7s/v5/pkg"
"m7s.live/m7s/v5/pkg/codec"
"m7s.live/m7s/v5/pkg/util"
pkg "m7s.live/m7s/v5/plugin/mp4/pkg"
rtmp "m7s.live/m7s/v5/plugin/rtmp/pkg"
"maps"
"net"
"net/http"
"slices"
"strings"
"time"
"github.com/Eyevinn/mp4ff/mp4"
"m7s.live/m7s/v5"
v5 "m7s.live/m7s/v5/pkg"
"m7s.live/m7s/v5/pkg/codec"
"m7s.live/m7s/v5/pkg/util"
pkg "m7s.live/m7s/v5/plugin/mp4/pkg"
rtmp "m7s.live/m7s/v5/plugin/rtmp/pkg"
)
type MediaContext struct {
@@ -75,7 +76,7 @@ const defaultConfig m7s.DefaultYaml = `publish:
func (p *MP4Plugin) OnInit() error {
for streamPath, url := range p.GetCommonConf().PullOnStart {
go p.PullBlock(streamPath, url)
p.Pull(streamPath, url)
}
return nil
}

View File

@@ -5,6 +5,8 @@ import (
"io"
"os"
"testing"
"m7s.live/m7s/v5/pkg/util"
)
func TestCreateMovDemuxer(t *testing.T) {
@@ -27,7 +29,7 @@ func TestCreateMovDemuxer(t *testing.T) {
mp4info := demuxer.GetMp4Info()
fmt.Printf("%+v\n", mp4info)
for {
pkg, err := demuxer.ReadPacket()
pkg, err := demuxer.ReadPacket(util.NewScalableMemoryAllocator(1 << 10))
if err != nil {
fmt.Println(err)
break

View File

@@ -1,224 +1,211 @@
package box
import (
"encoding/binary"
"fmt"
"io"
"io/ioutil"
"os"
"strconv"
"testing"
"github.com/yapingcat/gomedia/go-codec"
"github.com/yapingcat/gomedia/go-mpeg2"
)
func TestCreateMp4Reader(t *testing.T) {
f, err := os.Open("jellyfish-3-mbps-hd.h264.mp4")
if err != nil {
fmt.Println(err)
return
}
defer f.Close()
for err == nil {
nn := int64(0)
size := make([]byte, 4)
_, err = io.ReadFull(f, size)
if err != nil {
break
}
nn += 4
boxtype := make([]byte, 4)
_, err = io.ReadFull(f, boxtype)
if err != nil {
break
}
nn += 4
var isize uint64 = uint64(binary.BigEndian.Uint32(size))
if isize == 1 {
size := make([]byte, 8)
_, err = io.ReadFull(f, size)
if err != nil {
break
}
isize = binary.BigEndian.Uint64(size)
nn += 8
}
fmt.Printf("Read Box(%s) size:%d\n", boxtype, isize)
f.Seek(int64(isize)-nn, 1)
}
}
func TestCreateMp4Muxer(t *testing.T) {
f, err := os.Open("jellyfish-3-mbps-hd.h265")
if err != nil {
fmt.Println(err)
return
}
defer f.Close()
mp4filename := "jellyfish-3-mbps-hd.h265.mp4"
mp4file, err := os.OpenFile(mp4filename, os.O_CREATE|os.O_RDWR, 0666)
if err != nil {
fmt.Println(err)
return
}
defer mp4file.Close()
buf, _ := ioutil.ReadAll(f)
pts := uint64(0)
dts := uint64(0)
ii := [3]uint64{33, 33, 34}
idx := 0
type args struct {
wh io.WriteSeeker
}
tests := []struct {
name string
args args
want *Movmuxer
}{
{name: "muxer h264", args: args{wh: mp4file}, want: nil},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
muxer, err := CreateMp4Muxer(tt.args.wh)
if err != nil {
fmt.Println(err)
return
}
tid := muxer.AddVideoTrack(MP4_CODEC_H265)
cache := make([]byte, 0)
codec.SplitFrameWithStartCode(buf, func(nalu []byte) bool {
ntype := codec.H265NaluType(nalu)
if !codec.IsH265VCLNaluType(ntype) {
cache = append(cache, nalu...)
return true
}
if len(cache) > 0 {
cache = append(cache, nalu...)
muxer.Write(tid, cache, pts, dts)
cache = cache[:0]
} else {
muxer.Write(tid, nalu, pts, dts)
}
pts += ii[idx]
dts += ii[idx]
idx++
idx = idx % 3
return true
})
fmt.Printf("last dts %d\n", dts)
muxer.WriteTrailer()
})
}
}
func TestMuxAAC(t *testing.T) {
f, err := os.Open("test.aac")
if err != nil {
fmt.Println(err)
return
}
defer f.Close()
mp4filename := "aac.mp4"
mp4file, err := os.OpenFile(mp4filename, os.O_CREATE|os.O_RDWR, 0666)
if err != nil {
fmt.Println(err)
return
}
defer mp4file.Close()
aac, _ := ioutil.ReadAll(f)
var pts uint64 = 0
//var dts uint64 = 0
//var i int = 0
samples := uint64(0)
muxer, err := CreateMp4Muxer(mp4file)
if err != nil {
fmt.Println(err)
return
}
tid := muxer.AddAudioTrack(MP4_CODEC_AAC)
codec.SplitAACFrame(aac, func(aac []byte) {
samples += 1024
pts = samples * 1000 / 44100
// if i < 3 {
// pts += 23
// dts += 23
// i++
// } else {
// pts += 24
// dts += 24
// i = 0
// func TestCreateMp4Reader(t *testing.T) {
// f, err := os.Open("jellyfish-3-mbps-hd.h264.mp4")
// if err != nil {
// fmt.Println(err)
// return
// }
// defer f.Close()
// for err == nil {
// nn := int64(0)
// size := make([]byte, 4)
// _, err = io.ReadFull(f, size)
// if err != nil {
// break
// }
// nn += 4
// boxtype := make([]byte, 4)
// _, err = io.ReadFull(f, boxtype)
// if err != nil {
// break
// }
// nn += 4
// var isize uint64 = uint64(binary.BigEndian.Uint32(size))
// if isize == 1 {
// size := make([]byte, 8)
// _, err = io.ReadFull(f, size)
// if err != nil {
// break
// }
// isize = binary.BigEndian.Uint64(size)
// nn += 8
// }
// fmt.Printf("Read Box(%s) size:%d\n", boxtype, isize)
// f.Seek(int64(isize)-nn, 1)
// }
// }
muxer.Write(tid, aac, pts, pts)
//fmt.Println(pts)
})
muxer.WriteTrailer()
}
func TestMuxMp4(t *testing.T) {
tsfilename := `demo.ts` // input
tsfile, err := os.Open(tsfilename)
if err != nil {
fmt.Println(err)
return
}
defer tsfile.Close()
// func TestCreateMp4Muxer(t *testing.T) {
mp4filename := "test14.mp4" // output
mp4file, err := os.OpenFile(mp4filename, os.O_CREATE|os.O_RDWR, 0666)
if err != nil {
fmt.Println(err)
return
}
defer mp4file.Close()
// f, err := os.Open("jellyfish-3-mbps-hd.h265")
// if err != nil {
// fmt.Println(err)
// return
// }
// defer f.Close()
muxer, err := CreateMp4Muxer(mp4file)
if err != nil {
fmt.Println(err)
return
}
vtid := muxer.AddVideoTrack(MP4_CODEC_H264)
atid := muxer.AddAudioTrack(MP4_CODEC_AAC)
// mp4filename := "jellyfish-3-mbps-hd.h265.mp4"
// mp4file, err := os.OpenFile(mp4filename, os.O_CREATE|os.O_RDWR, 0666)
// if err != nil {
// fmt.Println(err)
// return
// }
// defer mp4file.Close()
afile, err := os.OpenFile("r.aac", os.O_CREATE|os.O_RDWR, 0666)
if err != nil {
fmt.Println(err)
return
}
defer afile.Close()
demuxer := mpeg2.NewTSDemuxer()
demuxer.OnFrame = func(cid mpeg2.TS_STREAM_TYPE, frame []byte, pts uint64, dts uint64) {
// buf, _ := ioutil.ReadAll(f)
// pts := uint64(0)
// dts := uint64(0)
// ii := [3]uint64{33, 33, 34}
// idx := 0
if cid == mpeg2.TS_STREAM_AAC {
err = muxer.Write(atid, frame, uint64(pts), uint64(dts))
if err != nil {
panic(err)
}
} else if cid == mpeg2.TS_STREAM_H264 {
fmt.Println("pts,dts,len", pts, dts, len(frame))
err = muxer.Write(vtid, frame, uint64(pts), uint64(dts))
if err != nil {
panic(err)
}
} else {
panic("unkwon cid " + strconv.Itoa(int(cid)))
}
}
// type args struct {
// wh io.WriteSeeker
// }
// tests := []struct {
// name string
// args args
// want *Movmuxer
// }{
// {name: "muxer h264", args: args{wh: mp4file}, want: nil},
// }
// for _, tt := range tests {
// t.Run(tt.name, func(t *testing.T) {
// muxer, err := CreateMp4Muxer(tt.args.wh)
// if err != nil {
// fmt.Println(err)
// return
// }
// tid := muxer.AddVideoTrack(MP4_CODEC_H265)
// cache := make([]byte, 0)
// codec.SplitFrameWithStartCode(buf, func(nalu []byte) bool {
// ntype := codec.H265NaluType(nalu)
// if !codec.IsH265VCLNaluType(ntype) {
// cache = append(cache, nalu...)
// return true
// }
// if len(cache) > 0 {
// cache = append(cache, nalu...)
// muxer.Write(tid, cache, pts, dts)
// cache = cache[:0]
// } else {
// muxer.Write(tid, nalu, pts, dts)
// }
// pts += ii[idx]
// dts += ii[idx]
// idx++
// idx = idx % 3
// return true
// })
// fmt.Printf("last dts %d\n", dts)
// muxer.WriteTrailer()
// })
// }
// }
err = demuxer.Input(tsfile)
if err != nil {
panic(err)
}
// func TestMuxAAC(t *testing.T) {
// f, err := os.Open("test.aac")
// if err != nil {
// fmt.Println(err)
// return
// }
// defer f.Close()
err = muxer.WriteTrailer()
if err != nil {
panic(err)
}
}
// mp4filename := "aac.mp4"
// mp4file, err := os.OpenFile(mp4filename, os.O_CREATE|os.O_RDWR, 0666)
// if err != nil {
// fmt.Println(err)
// return
// }
// defer mp4file.Close()
// aac, _ := ioutil.ReadAll(f)
// var pts uint64 = 0
// //var dts uint64 = 0
// //var i int = 0
// samples := uint64(0)
// muxer, err := CreateMp4Muxer(mp4file)
// if err != nil {
// fmt.Println(err)
// return
// }
// tid := muxer.AddAudioTrack(MP4_CODEC_AAC)
// codec.SplitAACFrame(aac, func(aac []byte) {
// samples += 1024
// pts = samples * 1000 / 44100
// // if i < 3 {
// // pts += 23
// // dts += 23
// // i++
// // } else {
// // pts += 24
// // dts += 24
// // i = 0
// // }
// muxer.Write(tid, aac, pts, pts)
// //fmt.Println(pts)
// })
// muxer.WriteTrailer()
// }
// func TestMuxMp4(t *testing.T) {
// tsfilename := `demo.ts` // input
// tsfile, err := os.Open(tsfilename)
// if err != nil {
// fmt.Println(err)
// return
// }
// defer tsfile.Close()
// mp4filename := "test14.mp4" // output
// mp4file, err := os.OpenFile(mp4filename, os.O_CREATE|os.O_RDWR, 0666)
// if err != nil {
// fmt.Println(err)
// return
// }
// defer mp4file.Close()
// muxer, err := CreateMp4Muxer(mp4file)
// if err != nil {
// fmt.Println(err)
// return
// }
// vtid := muxer.AddVideoTrack(MP4_CODEC_H264)
// atid := muxer.AddAudioTrack(MP4_CODEC_AAC)
// afile, err := os.OpenFile("r.aac", os.O_CREATE|os.O_RDWR, 0666)
// if err != nil {
// fmt.Println(err)
// return
// }
// defer afile.Close()
// demuxer := mpeg2.NewTSDemuxer()
// demuxer.OnFrame = func(cid mpeg2.TS_STREAM_TYPE, frame []byte, pts uint64, dts uint64) {
// if cid == mpeg2.TS_STREAM_AAC {
// err = muxer.Write(atid, frame, uint64(pts), uint64(dts))
// if err != nil {
// panic(err)
// }
// } else if cid == mpeg2.TS_STREAM_H264 {
// fmt.Println("pts,dts,len", pts, dts, len(frame))
// err = muxer.Write(vtid, frame, uint64(pts), uint64(dts))
// if err != nil {
// panic(err)
// }
// } else {
// panic("unkwon cid " + strconv.Itoa(int(cid)))
// }
// }
// err = demuxer.Input(tsfile)
// if err != nil {
// panic(err)
// }
// err = muxer.WriteTrailer()
// if err != nil {
// panic(err)
// }
// }

View File

@@ -3,11 +3,13 @@ package plugin_preview
import (
"embed"
"fmt"
"m7s.live/m7s/v5"
"mime"
"net/http"
"path/filepath"
"strings"
"m7s.live/m7s/v5"
"m7s.live/m7s/v5/pkg"
)
//go:embed ui
@@ -22,11 +24,14 @@ var _ = m7s.InstallPlugin[PreviewPlugin]()
func (p *PreviewPlugin) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/" {
s := "<h1><h1><h2>Live Streams 引擎中正在发布的流</h2>"
p.Server.Call(func() {
p.Server.CallOnStreamTask(func(*pkg.Task) error {
for publisher := range p.Server.Streams.Range {
s += fmt.Sprintf("<a href='%s'>%s</a> [ %s ]<br>", publisher.StreamPath, publisher.StreamPath, publisher.Plugin.Meta.Name)
}
s += "<h2>pull stream on subscribe 订阅时才会触发拉流的流</h2>"
return nil
})
p.Server.Call(func(*pkg.Task) error {
for plugin := range p.Server.Plugins.Range {
if pullPlugin, ok := plugin.GetHandler().(m7s.IPullerPlugin); ok {
s += fmt.Sprintf("<h3>%s</h3>", plugin.Meta.Name)
@@ -35,6 +40,7 @@ func (p *PreviewPlugin) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
}
}
return nil
})
w.Write([]byte(s))
return

View File

@@ -3,13 +3,14 @@ package plugin_rtmp
import (
"errors"
"io"
"maps"
"net"
"slices"
"m7s.live/m7s/v5"
"m7s.live/m7s/v5/pkg"
"m7s.live/m7s/v5/plugin/rtmp/pb"
. "m7s.live/m7s/v5/plugin/rtmp/pkg"
"maps"
"net"
"slices"
)
type RTMPPlugin struct {
@@ -18,24 +19,18 @@ type RTMPPlugin struct {
ChunkSize int `default:"1024"`
KeepAlive bool
C2 bool
connTM *pkg.TaskManager
}
var _ = m7s.InstallPlugin[RTMPPlugin](m7s.DefaultYaml(`tcp:
listenaddr: :1935`), &pb.Rtmp_ServiceDesc, pb.RegisterRtmpHandler, Pull, Push)
func (p *RTMPPlugin) OnInit() error {
p.connTM = pkg.StartTaskManager()
for streamPath, url := range p.GetCommonConf().PullOnStart {
go p.PullBlock(streamPath, url)
p.Pull(streamPath, url)
}
return nil
}
func (p *RTMPPlugin) Dispose() {
p.connTM.ShutDown(p.StopReason())
}
func (p *RTMPPlugin) GetPullableList() []string {
return slices.Collect(maps.Keys(p.GetCommonConf().PullOnSub))
}
@@ -44,14 +39,15 @@ func (p *RTMPPlugin) OnTCPConnect(conn *net.TCPConn) {
receivers := make(map[uint32]*Receiver)
var err error
nc := NewNetConnection(conn)
p.With(nc, "remote", conn.RemoteAddr().String())
p.connTM.Add(nc)
var connTask pkg.MarcoTask
connTask.Init(p.Context, p.With("remote", conn.RemoteAddr().String()), nc)
p.AddTask(&connTask)
defer func() {
nc.Stop(err)
connTask.Stop(err)
}()
/* Handshake */
if err = nc.Handshake(p.C2); err != nil {
nc.Error("handshake", "error", err)
connTask.Error("handshake", "error", err)
return
}
var msg *Chunk
@@ -63,13 +59,15 @@ func (p *RTMPPlugin) OnTCPConnect(conn *net.TCPConn) {
continue
}
switch msg.MessageTypeID {
case RTMP_MSG_CHUNK_SIZE:
connTask.Info("msg read chunk size", "readChunkSize", nc.ReadChunkSize)
case RTMP_MSG_AMF0_COMMAND:
if msg.MsgData == nil {
err = errors.New("msg.MsgData is nil")
break
}
cmd := msg.MsgData.(Commander).GetCommand()
nc.Debug("recv cmd", "commandName", cmd.CommandName, "streamID", msg.MessageStreamID)
connTask.Debug("recv cmd", "commandName", cmd.CommandName, "streamID", msg.MessageStreamID)
switch cmd := msg.MsgData.(type) {
case *CallMessage: //connect
connectInfo = cmd.Object
@@ -82,16 +80,16 @@ func (p *RTMPPlugin) OnTCPConnect(conn *net.TCPConn) {
nc.ObjectEncoding = 0
}
nc.AppName = app.(string)
nc.Info("connect", "appName", nc.AppName, "objectEncoding", nc.ObjectEncoding)
connTask.Info("connect", "appName", nc.AppName, "objectEncoding", nc.ObjectEncoding)
err = nc.SendMessage(RTMP_MSG_ACK_SIZE, Uint32Message(512<<10))
if err != nil {
nc.Error("sendMessage ack size", "error", err)
connTask.Error("sendMessage ack size", "error", err)
return
}
nc.WriteChunkSize = p.ChunkSize
err = nc.SendMessage(RTMP_MSG_CHUNK_SIZE, Uint32Message(p.ChunkSize))
if err != nil {
nc.Error("sendMessage chunk size", "error", err)
connTask.Error("sendMessage chunk size", "error", err)
return
}
err = nc.SendMessage(RTMP_MSG_BANDWIDTH, &SetPeerBandwidthMessage{
@@ -99,12 +97,12 @@ func (p *RTMPPlugin) OnTCPConnect(conn *net.TCPConn) {
LimitType: byte(2),
})
if err != nil {
nc.Error("sendMessage bandwidth", "error", err)
connTask.Error("sendMessage bandwidth", "error", err)
return
}
err = nc.SendStreamID(RTMP_USER_STREAM_BEGIN, 0)
if err != nil {
nc.Error("sendMessage stream begin", "error", err)
connTask.Error("sendMessage stream begin", "error", err)
return
}
m := new(ResponseConnectMessage)
@@ -123,11 +121,11 @@ func (p *RTMPPlugin) OnTCPConnect(conn *net.TCPConn) {
}
err = nc.SendMessage(RTMP_MSG_AMF0_COMMAND, m)
if err != nil {
nc.Error("sendMessage connect", "error", err)
connTask.Error("sendMessage connect", "error", err)
}
case *CommandMessage: // "createStream"
gstreamid++
nc.Info("createStream:", "streamId", gstreamid)
connTask.Info("createStream:", "streamId", gstreamid)
nc.ResponseCreateStream(cmd.TransactionId, gstreamid)
case *CURDStreamMessage:
// if stream, ok := receivers[cmd.StreamId]; ok {
@@ -155,7 +153,7 @@ func (p *RTMPPlugin) OnTCPConnect(conn *net.TCPConn) {
StreamID: cmd.StreamId,
},
}
receiver.Publisher, err = p.Publish(nc.AppName+"/"+cmd.PublishingName, nc.Context, connectInfo)
receiver.Publisher, err = p.Publish(nc.AppName+"/"+cmd.PublishingName, connTask.Context, connectInfo)
if err != nil {
delete(receivers, cmd.StreamId)
err = receiver.Response(cmd.TransactionId, NetStream_Publish_BadName, Level_Error)
@@ -164,7 +162,9 @@ func (p *RTMPPlugin) OnTCPConnect(conn *net.TCPConn) {
err = receiver.BeginPublish(cmd.TransactionId)
}
if err != nil {
nc.Error("sendMessage publish", "error", err)
connTask.Error("sendMessage publish", "error", err)
} else {
connTask.AddTask(receiver.Publisher)
}
case *PlayMessage:
streamPath := nc.AppName + "/" + cmd.StreamName
@@ -174,7 +174,7 @@ func (p *RTMPPlugin) OnTCPConnect(conn *net.TCPConn) {
}
var suber *m7s.Subscriber
// sender.ID = fmt.Sprintf("%s|%d", conn.RemoteAddr().String(), sender.StreamID)
suber, err = p.Subscribe(streamPath, nc.Context, connectInfo)
suber, err = p.Subscribe(streamPath, connTask.Context, connectInfo)
if err != nil {
err = ns.Response(cmd.TransactionId, NetStream_Play_Failed, Level_Error)
} else {
@@ -183,7 +183,9 @@ func (p *RTMPPlugin) OnTCPConnect(conn *net.TCPConn) {
go m7s.PlayBlock(suber, audio.HandleAudio, video.HandleVideo)
}
if err != nil {
nc.Error("sendMessage play", "error", err)
connTask.Error("sendMessage play", "error", err)
} else {
connTask.AddTask(suber)
}
}
case RTMP_MSG_AUDIO:
@@ -191,20 +193,20 @@ func (p *RTMPPlugin) OnTCPConnect(conn *net.TCPConn) {
err = r.WriteAudio(msg.AVData.WrapAudio())
} else {
msg.AVData.Recycle()
nc.Warn("ReceiveAudio", "MessageStreamID", msg.MessageStreamID)
connTask.Warn("ReceiveAudio", "MessageStreamID", msg.MessageStreamID)
}
case RTMP_MSG_VIDEO:
if r, ok := receivers[msg.MessageStreamID]; ok && r.PubVideo {
err = r.WriteVideo(msg.AVData.WrapVideo())
} else {
msg.AVData.Recycle()
nc.Warn("ReceiveVideo", "MessageStreamID", msg.MessageStreamID)
connTask.Warn("ReceiveVideo", "MessageStreamID", msg.MessageStreamID)
}
}
} else if err == io.EOF || errors.Is(err, io.ErrUnexpectedEOF) {
nc.Info("rtmp client closed")
connTask.Info("rtmp client closed")
} else {
nc.Warn("ReadMessage", "error", err)
connTask.Warn("ReadMessage", "error", err)
}
}
}

View File

@@ -43,7 +43,6 @@ func createClient(c *m7s.Connection) (*NetStream, error) {
}
ns := &NetStream{}
ns.NetConnection = NewNetConnection(conn)
c.With(ns)
defer func() {
if err != nil {
ns.Dispose()

View File

@@ -2,7 +2,6 @@ package rtmp
import (
"errors"
"m7s.live/m7s/v5/pkg"
"net"
"runtime"
"sync/atomic"
@@ -43,16 +42,12 @@ const (
)
type NetConnection struct {
pkg.Task
*util.BufReader
net.Conn
bandwidth uint32
readSeqNum uint32 // 当前读的字节
writeSeqNum uint32 // 当前写的字节
totalWrite uint32 // 总共写了多少字节
totalRead uint32 // 总共读了多少字节
WriteChunkSize int
readChunkSize int
readSeqNum, writeSeqNum uint32 // 当前读的字节
totalRead, totalWrite uint32 // 总共读写了多少字节
ReadChunkSize, WriteChunkSize int
incommingChunks map[uint32]*Chunk
ObjectEncoding float64
AppName string
@@ -67,7 +62,7 @@ func NewNetConnection(conn net.Conn) (ret *NetConnection) {
Conn: conn,
BufReader: util.NewBufReader(conn),
WriteChunkSize: RTMP_DEFAULT_CHUNK_SIZE,
readChunkSize: RTMP_DEFAULT_CHUNK_SIZE,
ReadChunkSize: RTMP_DEFAULT_CHUNK_SIZE,
incommingChunks: make(map[uint32]*Chunk),
bandwidth: RTMP_MAX_CHUNK_SIZE << 3,
tmpBuf: make(util.Buffer, 4),
@@ -150,10 +145,10 @@ func (conn *NetConnection) readChunk() (msg *Chunk, err error) {
return nil, nil
}
var bufSize = 0
if unRead := msgLen - chunk.bufLen; unRead < conn.readChunkSize {
if unRead := msgLen - chunk.bufLen; unRead < conn.ReadChunkSize {
bufSize = unRead
} else {
bufSize = conn.readChunkSize
bufSize = conn.ReadChunkSize
}
conn.readSeqNum += uint32(bufSize)
if chunk.bufLen == 0 {
@@ -275,8 +270,7 @@ func (conn *NetConnection) RecvMessage() (msg *Chunk, err error) {
if msg, err = conn.readChunk(); msg != nil && err == nil {
switch msg.MessageTypeID {
case RTMP_MSG_CHUNK_SIZE:
conn.readChunkSize = int(msg.MsgData.(Uint32Message))
conn.Info("msg read chunk size", "readChunkSize", conn.readChunkSize)
conn.ReadChunkSize = int(msg.MsgData.(Uint32Message))
case RTMP_MSG_ABORT:
delete(conn.incommingChunks, uint32(msg.MsgData.(Uint32Message)))
case RTMP_MSG_ACK, RTMP_MSG_EDGE:

View File

@@ -1,24 +1,22 @@
package rtp
import (
"testing"
"github.com/pion/webrtc/v3"
"m7s.live/m7s/v5/pkg"
"m7s.live/m7s/v5/pkg/util"
"testing"
)
func TestRTPH264Ctx_CreateFrame(t *testing.T) {
var ctx = &H264Ctx{
RTPCtx: RTPCtx{
RTPCodecParameters: webrtc.RTPCodecParameters{
var ctx = &H264Ctx{}
ctx.RTPCodecParameters = webrtc.RTPCodecParameters{
PayloadType: 96,
RTPCodecCapability: webrtc.RTPCodecCapability{
MimeType: webrtc.MimeTypeH264,
ClockRate: 90000,
SDPFmtpLine: "packetization-mode=1; sprop-parameter-sets=J2QAKaxWgHgCJ+WagICAgQ==,KO48sA==; profile-level-id=640029",
},
},
},
}
var randStr = util.RandomString(1500)
var avFrame = &pkg.AVFrame{}

View File

@@ -3,9 +3,6 @@ package plugin_rtsp
import (
"errors"
"fmt"
"m7s.live/m7s/v5"
"m7s.live/m7s/v5/pkg/util"
. "m7s.live/m7s/v5/plugin/rtsp/pkg"
"maps"
"net"
"net/http"
@@ -13,6 +10,10 @@ import (
"slices"
"strconv"
"strings"
"m7s.live/m7s/v5"
"m7s.live/m7s/v5/pkg/util"
. "m7s.live/m7s/v5/plugin/rtsp/pkg"
)
const defaultConfig = m7s.DefaultYaml(`tcp:
@@ -30,7 +31,7 @@ func (p *RTSPPlugin) GetPullableList() []string {
func (p *RTSPPlugin) OnInit() error {
for streamPath, url := range p.GetCommonConf().PullOnStart {
go p.PullBlock(streamPath, url)
p.Pull(streamPath, url)
}
return nil
}

View File

@@ -3,6 +3,7 @@ package plugin_stress
import (
"context"
"fmt"
"google.golang.org/protobuf/types/known/emptypb"
"m7s.live/m7s/v5"
gpb "m7s.live/m7s/v5/pb"
@@ -20,7 +21,13 @@ func (r *StressPlugin) pull(count int, format, url string, puller m7s.Puller) er
if err != nil {
return err
}
go r.startPull(ctx, puller)
ctx.AddCall(func(*pkg.Task) error {
r.pullers.AddUnique(ctx)
ctx.Do(puller)
return nil
}, func(*pkg.Task) {
r.pullers.Remove(ctx)
})
}
} else if count < i {
for j := i; j > count; j-- {
@@ -38,7 +45,13 @@ func (r *StressPlugin) push(count int, streamPath, format, remoteHost string, pu
if err != nil {
return err
}
go r.startPush(ctx, pusher)
ctx.AddCall(func(*pkg.Task) error {
r.pushers.AddUnique(ctx)
ctx.Do(pusher)
return nil
}, func(*pkg.Task) {
r.pushers.Remove(ctx)
})
}
} else if count < i {
for j := i; j > count; j-- {
@@ -69,18 +82,6 @@ func (r *StressPlugin) PullHDL(ctx context.Context, req *pb.PullRequest) (res *g
return &gpb.SuccessResponse{}, r.pull(int(req.PullCount), "http://%s", req.RemoteURL, hdl.PullFLV)
}
func (r *StressPlugin) startPush(pusher *m7s.PushContext, handler m7s.Pusher) {
r.pushers.AddUnique(pusher)
pusher.Run(handler)
r.pushers.Remove(pusher)
}
func (r *StressPlugin) startPull(puller *m7s.PullContext, handler m7s.Puller) {
r.pullers.AddUnique(puller)
puller.Run(handler)
r.pullers.Remove(puller)
}
func (r *StressPlugin) StopPush(ctx context.Context, req *emptypb.Empty) (res *gpb.SuccessResponse, err error) {
for pusher := range r.pushers.Range {
pusher.Stop(pkg.ErrStopFromAPI)

View File

@@ -91,7 +91,7 @@ func (p *Publisher) GetKey() string {
func createPublisher(p *Plugin, streamPath string, options ...any) (publisher *Publisher) {
publisher = &Publisher{Publish: p.config.Publish}
publisher.ID = p.Server.streamTM.GetID()
publisher.ID = p.Server.streamTask.GetID()
publisher.Plugin = p
publisher.TimeoutTimer = time.NewTimer(p.config.PublishTimeout)
var opt = []any{publisher, p.Logger.With("streamPath", streamPath, "pId", publisher.ID)}
@@ -137,7 +137,7 @@ func (p *Publisher) Start() (err error) {
}
if remoteURL := plugin.GetCommonConf().CheckPush(p.StreamPath); remoteURL != "" {
if plugin.Meta.Pusher != nil {
go plugin.PushBlock(p.StreamPath, remoteURL, plugin.Meta.Pusher)
plugin.Push(p.StreamPath, remoteURL, plugin.Meta.Pusher)
}
}
if filePath := plugin.GetCommonConf().CheckRecord(p.StreamPath); filePath != "" {

View File

@@ -2,13 +2,14 @@ package m7s
import (
"context"
"time"
"m7s.live/m7s/v5/pkg"
"m7s.live/m7s/v5/pkg/config"
"time"
)
type Connection struct {
pkg.Task
pkg.MarcoTask
Plugin *Plugin
StreamPath string // 对应本地流
RemoteURL string // 远程服务器地址(用于推拉)
@@ -27,7 +28,6 @@ type Puller = func(*PullContext) error
func createPullContext(p *Plugin, streamPath string, url string, options ...any) (pullCtx *PullContext) {
pullCtx = &PullContext{Pull: p.config.Pull}
pullCtx.ID = p.Server.pullTM.GetID()
pullCtx.Plugin = p
pullCtx.ConnectProxy = p.config.Pull.Proxy
pullCtx.RemoteURL = url
@@ -44,8 +44,9 @@ func createPullContext(p *Plugin, streamPath string, url string, options ...any)
pullCtx.PublishOptions = append(pullCtx.PublishOptions, option)
}
}
p.Init(ctx, p.Logger.With("pullURL", url, "streamPath", streamPath))
p.InitKeepAlive(ctx, p.Logger.With("pullURL", url, "streamPath", streamPath), pullCtx)
pullCtx.PublishOptions = append(pullCtx.PublishOptions, pullCtx.Context)
p.Server.pullTask.AddTask(pullCtx)
return
}
@@ -60,30 +61,25 @@ func (p *PullContext) GetKey() string {
return p.StreamPath
}
func (p *PullContext) Run(puller Puller) {
var err error
for p.reconnect(p.RePull) {
if p.Publisher != nil {
if time.Since(p.Publisher.StartTime) < 5*time.Second {
func (p *PullContext) Do(puller Puller) {
p.AddCall(func(tmpTask *pkg.Task) (err error) {
publishOptions := append([]any{tmpTask.Context}, p.PublishOptions...)
if p.Publisher, err = p.Plugin.Publish(p.StreamPath, publishOptions...); err != nil {
p.Error("pull publish failed", "error", err)
return
}
err = puller(p)
if p.reconnect(p.RePull) {
if time.Since(tmpTask.StartTime) < 5*time.Second {
time.Sleep(5 * time.Second)
}
p.Warn("retry", "count", p.ReConnectCount, "total", p.RePull)
p.Do(puller)
} else {
p.Stop(pkg.ErrRetryRunOut)
}
if p.Publisher, err = p.Plugin.Publish(p.StreamPath, p.PublishOptions...); err != nil {
p.Error("pull publish failed", "error", err)
break
}
err = puller(p)
p.Publisher.Stop(err)
if p.IsStopped() {
return
}
p.Error("pull interrupt", "error", err)
}
if err == nil {
err = pkg.ErrRetryRunOut
}
p.Stop(err)
}, nil)
}
func (p *PullContext) Start() (err error) {
@@ -92,6 +88,9 @@ func (p *PullContext) Start() (err error) {
return pkg.ErrStreamExist
}
s.Pulls.Add(p)
if p.Plugin.Meta.Puller != nil {
p.Do(p.Plugin.Meta.Puller)
}
return
}

View File

@@ -2,9 +2,10 @@ package m7s
import (
"context"
"m7s.live/m7s/v5/pkg"
"time"
"m7s.live/m7s/v5/pkg"
"m7s.live/m7s/v5/pkg/config"
)
@@ -12,7 +13,7 @@ type Pusher = func(*PushContext) error
func createPushContext(p *Plugin, streamPath string, url string, options ...any) (pushCtx *PushContext) {
pushCtx = &PushContext{Push: p.config.Push}
pushCtx.ID = p.Server.pushTM.GetID()
pushCtx.ID = p.Server.pushTask.GetID()
pushCtx.Plugin = p
pushCtx.RemoteURL = url
pushCtx.StreamPath = streamPath
@@ -43,31 +44,27 @@ func (p *PushContext) GetKey() string {
return p.RemoteURL
}
func (p *PushContext) Run(pusher Pusher) {
p.StartTime = time.Now()
defer p.Info("stop push")
var err error
for p.Info("start push", "url", p.Connection.RemoteURL); p.Connection.reconnect(p.RePush); p.Warn("restart push") {
func (p *PushContext) Do(pusher Pusher) {
p.AddCall(func(tmpTask *pkg.Task) (err error) {
if p.Subscriber != nil && time.Since(p.Subscriber.StartTime) < 5*time.Second {
time.Sleep(5 * time.Second)
}
if p.Subscriber, err = p.Plugin.Subscribe(p.StreamPath, p.SubscribeOptions...); err != nil {
p.Error("push subscribe failed", "error", err)
break
return
}
err = pusher(p)
p.Subscriber.Stop(err)
if p.IsStopped() {
return
if p.Connection.reconnect(p.RePush) {
if time.Since(tmpTask.StartTime) < 5*time.Second {
time.Sleep(5 * time.Second)
}
p.Warn("retry", "count", p.ReConnectCount, "total", p.RePush)
p.Do(pusher)
} else {
p.Error("push interrupt", "error", err)
p.Stop(pkg.ErrRetryRunOut)
}
}
if err == nil {
err = pkg.ErrRetryRunOut
}
p.Stop(err)
return
}, nil)
}
func (p *PushContext) Start() (err error) {
@@ -76,6 +73,9 @@ func (p *PushContext) Start() (err error) {
return pkg.ErrPushRemoteURLExist
}
s.Pushs.Add(p)
if p.Plugin.Meta.Pusher != nil {
p.Do(p.Plugin.Meta.Pusher)
}
return
}

View File

@@ -15,7 +15,7 @@ func createRecoder(p *Plugin, streamPath string, filePath string, options ...any
Append: p.config.Record.Append,
FilePath: filePath,
}
recorder.ID = p.Server.recordTM.GetID()
recorder.ID = p.Server.recordTask.GetID()
recorder.FilePath = filePath
recorder.SubscribeOptions = []any{p.config.Subscribe}
var ctx = p.Context

View File

@@ -36,7 +36,7 @@ var (
Version: Version,
}
Servers util.Collection[uint32, *Server]
serverTM = NewTaskManager()
globalTask MarcoTask
Routes = map[string]string{}
defaultLogHandler = console.NewHandler(os.Stdout, &console.HandlerOptions{TimeFormat: "15:04:05.000000"})
)
@@ -53,7 +53,6 @@ type Server struct {
pb.UnimplementedGlobalServer
Plugin
ServerConfig
//eventChan chan any
Plugins util.Collection[string, *Plugin]
Streams, Waiting util.Collection[string, *Publisher]
Pulls util.Collection[string, *PullContext]
@@ -67,13 +66,13 @@ type Server struct {
tcplis net.Listener
lastSummaryTime time.Time
lastSummary *pb.SummaryResponse
pluginTM, streamTM, pullTM, pushTM, recordTM *TaskManager
streamTask, pullTask, pushTask, recordTask MarcoTask
conf any
}
func NewServer() (s *Server) {
s = &Server{}
s.ID = serverTM.GetID()
s.ID = globalTask.GetID()
s.Meta = &serverMeta
return
}
@@ -87,12 +86,16 @@ type rawconfig = map[string]map[string]any
func init() {
signalChan := make(chan os.Signal, 1)
signal.Notify(signalChan, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT)
go serverTM.Run(signalChan, func(os.Signal) {
globalTask.InitKeepAlive(context.Background(), nil, nil, signalChan, func(os.Signal) {
for _, meta := range plugins {
if meta.OnExit != nil {
meta.OnExit()
}
}
if serverMeta.OnExit != nil {
serverMeta.OnExit()
}
os.Exit(0)
})
for k, v := range myip.LocalAndInternalIPs() {
Routes[k] = v
@@ -116,7 +119,7 @@ func (s *Server) Init(ctx context.Context, conf any) {
s.config.TCP.ListenAddr = ":50051"
s.LogHandler.SetLevel(slog.LevelInfo)
s.LogHandler.Add(defaultLogHandler)
s.Task.Init(ctx, slog.New(&s.LogHandler).With("Server", s.ID))
s.MarcoTask.Init(ctx, slog.New(&s.LogHandler).With("Server", s.ID), s)
}
func (s *Server) Start() (err error) {
@@ -204,10 +207,9 @@ func (s *Server) Start() (err error) {
return
}
}
s.pluginTM = StartTaskManager()
for _, plugin := range plugins {
if p := plugin.Init(s, cg[strings.ToLower(plugin.Name)]); !p.Disabled {
s.pluginTM.Start(&p.Task)
s.WaitTaskAdded(p)
}
}
if s.tcplis != nil {
@@ -219,7 +221,7 @@ func (s *Server) Start() (err error) {
}
}(tcpConf.ListenAddr)
}
s.streamTM = StartTaskManager(time.NewTicker(s.PulseInterval).C, func(time.Time) {
s.streamTask.InitKeepAlive(s.Context, nil, nil, time.NewTicker(s.PulseInterval).C, func(time.Time) {
for publisher := range s.Streams.Range {
if err := publisher.checkTimeout(); err != nil {
publisher.Stop(err)
@@ -242,15 +244,16 @@ func (s *Server) Start() (err error) {
}
}
})
s.pullTM = StartTaskManager()
s.pushTM = StartTaskManager()
s.recordTM = StartTaskManager()
s.pullTask.InitKeepAlive(s.Context, nil, nil)
s.pushTask.InitKeepAlive(s.Context, nil, nil)
s.recordTask.InitKeepAlive(s.Context, nil, nil)
s.AddTasks(&s.streamTask, &s.pullTask, &s.pushTask, &s.recordTask)
Servers.Add(s)
return
}
func (s *Server) Call(callback func()) {
s.streamTM.Call(callback)
func (s *Server) CallOnStreamTask(callback func(*Task) error) {
s.streamTask.Call(callback)
}
func (s *Server) Dispose() {
@@ -258,22 +261,15 @@ func (s *Server) Dispose() {
_ = s.tcplis.Close()
_ = s.grpcClientConn.Close()
s.config.HTTP.StopListen()
err := context.Cause(s)
s.streamTM.ShutDown(err)
s.pullTM.ShutDown(err)
s.pushTM.ShutDown(err)
s.recordTM.ShutDown(err)
s.pluginTM.ShutDown(err)
}
func (s *Server) Run(ctx context.Context, conf any) (err error) {
for {
s.Init(ctx, conf)
if err = serverTM.Start(s); err != nil {
if err = globalTask.WaitTaskAdded(s); err != nil {
return
}
<-s.Done()
if err = context.Cause(s); err != ErrRestart {
if err = s.WaitStopped(); err != ErrRestart {
return
}
var server Server

View File

@@ -23,26 +23,25 @@ type PubSubBase struct {
StreamPath string
Args url.Values
TimeoutTimer *time.Timer
MetaData any
}
func (ps *PubSubBase) Init(streamPath string, conf any, options ...any) {
ctx := ps.Plugin.Context
var logger *slog.Logger
var executor []TaskExecutor
var executor TaskExecutor
for _, option := range options {
switch v := option.(type) {
case TaskExecutor:
executor = append(executor, v)
executor = v
case *slog.Logger:
logger = v
case context.Context:
ctx = v
default:
ps.MetaData = v
case map[string]any:
ps.Description = v
}
}
ps.Task.Init(ctx, logger, executor...)
ps.Task.Init(ctx, logger, executor)
if u, err := url.Parse(streamPath); err == nil {
ps.StreamPath, ps.Args = u.Path, u.Query()
}
@@ -77,7 +76,7 @@ type Subscriber struct {
func createSubscriber(p *Plugin, streamPath string, options ...any) *Subscriber {
subscriber := &Subscriber{Subscribe: p.config.Subscribe}
subscriber.ID = p.Server.streamTM.GetID()
subscriber.ID = p.Server.streamTask.GetID()
subscriber.Plugin = p
subscriber.TimeoutTimer = time.NewTimer(subscriber.WaitTimeout)
var opt = []any{subscriber, p.Logger.With("streamPath", streamPath, "sId", subscriber.ID)}
@@ -109,7 +108,7 @@ func (s *Subscriber) Start() (err error) {
for plugin := range server.Plugins.Range {
if remoteURL := plugin.GetCommonConf().Pull.CheckPullOnSub(s.StreamPath); remoteURL != "" {
if plugin.Meta.Puller != nil {
go plugin.PullBlock(s.StreamPath, remoteURL, plugin.Meta.Puller)
plugin.Pull(s.StreamPath, remoteURL, plugin.Meta.Puller)
}
}
}