refactor: retry system

This commit is contained in:
langhuihui
2024-08-15 09:13:13 +08:00
parent 4c42db525d
commit d79076416a
24 changed files with 983 additions and 647 deletions

23
api.go
View File

@@ -4,6 +4,8 @@ import (
"context"
"encoding/json"
"errors"
"fmt"
"maps"
"net"
"net/http"
"runtime"
@@ -141,6 +143,27 @@ func (s *Server) StreamInfo(ctx context.Context, req *pb.StreamSnapRequest) (res
})
return
}
func (s *Server) TaskTree(context.Context, *emptypb.Empty) (res *pb.TaskTreeResponse, err error) {
res = &pb.TaskTreeResponse{}
var fillData func(m *util.MarcoTask, res *pb.TaskTreeResponse)
fillData = func(m *util.MarcoTask, res *pb.TaskTreeResponse) {
for task, marcoTask := range m.Range {
child := &pb.TaskTreeResponse{Id: task.ID, Type: task.GetTaskType(), Owner: task.GetOwnerType(), StartTime: timestamppb.New(task.StartTime), Description: maps.Collect(func(yield func(key, value string) bool) {
for k, v := range task.Description {
yield(k, fmt.Sprintf("%v", v))
}
})}
if marcoTask != nil {
fillData(marcoTask, child)
}
res.Children = append(res.Children, child)
}
}
fillData(&s.MarcoTask, res)
return
}
func (s *Server) GetSubscribers(ctx context.Context, req *pb.SubscribersRequest) (res *pb.SubscribersResponse, err error) {
s.streamTask.Call(func() error {
var subscribers []*pb.SubscriberSnapShot

File diff suppressed because it is too large Load Diff

View File

@@ -172,6 +172,24 @@ func local_request_Global_Restart_0(ctx context.Context, marshaler runtime.Marsh
}
func request_Global_TaskTree_0(ctx context.Context, marshaler runtime.Marshaler, client GlobalClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) {
var protoReq emptypb.Empty
var metadata runtime.ServerMetadata
msg, err := client.TaskTree(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD))
return msg, metadata, err
}
func local_request_Global_TaskTree_0(ctx context.Context, marshaler runtime.Marshaler, server GlobalServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) {
var protoReq emptypb.Empty
var metadata runtime.ServerMetadata
msg, err := server.TaskTree(ctx, &protoReq)
return msg, metadata, err
}
var (
filter_Global_StreamList_0 = &utilities.DoubleArray{Encoding: map[string]int{}, Base: []int(nil), Check: []int(nil)}
)
@@ -862,6 +880,31 @@ func RegisterGlobalHandlerServer(ctx context.Context, mux *runtime.ServeMux, ser
})
mux.Handle("GET", pattern_Global_TaskTree_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) {
ctx, cancel := context.WithCancel(req.Context())
defer cancel()
var stream runtime.ServerTransportStream
ctx = grpc.NewContextWithServerTransportStream(ctx, &stream)
inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req)
var err error
var annotatedContext context.Context
annotatedContext, err = runtime.AnnotateIncomingContext(ctx, mux, req, "/m7s.Global/TaskTree", runtime.WithHTTPPathPattern("/api/task/tree"))
if err != nil {
runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err)
return
}
resp, md, err := local_request_Global_TaskTree_0(annotatedContext, inboundMarshaler, server, req, pathParams)
md.HeaderMD, md.TrailerMD = metadata.Join(md.HeaderMD, stream.Header()), metadata.Join(md.TrailerMD, stream.Trailer())
annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md)
if err != nil {
runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err)
return
}
forward_Global_TaskTree_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...)
})
mux.Handle("GET", pattern_Global_StreamList_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) {
ctx, cancel := context.WithCancel(req.Context())
defer cancel()
@@ -1266,6 +1309,28 @@ func RegisterGlobalHandlerClient(ctx context.Context, mux *runtime.ServeMux, cli
})
mux.Handle("GET", pattern_Global_TaskTree_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) {
ctx, cancel := context.WithCancel(req.Context())
defer cancel()
inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req)
var err error
var annotatedContext context.Context
annotatedContext, err = runtime.AnnotateContext(ctx, mux, req, "/m7s.Global/TaskTree", runtime.WithHTTPPathPattern("/api/task/tree"))
if err != nil {
runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err)
return
}
resp, md, err := request_Global_TaskTree_0(annotatedContext, inboundMarshaler, client, req, pathParams)
annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md)
if err != nil {
runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err)
return
}
forward_Global_TaskTree_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...)
})
mux.Handle("GET", pattern_Global_StreamList_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) {
ctx, cancel := context.WithCancel(req.Context())
defer cancel()
@@ -1520,6 +1585,8 @@ var (
pattern_Global_Restart_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 1, 0, 4, 1, 5, 2}, []string{"api", "restart", "id"}, ""))
pattern_Global_TaskTree_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"api", "task", "tree"}, ""))
pattern_Global_StreamList_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"api", "stream", "list"}, ""))
pattern_Global_WaitList_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"api", "stream", "waitlist"}, ""))
@@ -1552,6 +1619,8 @@ var (
forward_Global_Restart_0 = runtime.ForwardResponseMessage
forward_Global_TaskTree_0 = runtime.ForwardResponseMessage
forward_Global_StreamList_0 = runtime.ForwardResponseMessage
forward_Global_WaitList_0 = runtime.ForwardResponseMessage

View File

@@ -27,6 +27,11 @@ service Global {
post: "/api/restart/{id}"
};
}
rpc TaskTree (google.protobuf.Empty) returns (TaskTreeResponse) {
option (google.api.http) = {
get: "/api/task/tree"
};
}
rpc StreamList (StreamListRequest) returns (StreamListResponse) {
option (google.api.http) = {
get: "/api/stream/list"
@@ -159,6 +164,15 @@ message SysInfoResponse {
repeated PluginInfo plugins = 8;
}
message TaskTreeResponse {
uint32 id = 1;
string type = 2;
string owner = 3;
google.protobuf.Timestamp startTime = 4;
map<string, string> description = 5;
repeated TaskTreeResponse children = 6;
}
message StreamListRequest {
int32 pageNum = 1;
int32 pageSize = 2;

View File

@@ -27,6 +27,7 @@ type GlobalClient interface {
Summary(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*SummaryResponse, error)
Shutdown(ctx context.Context, in *RequestWithId, opts ...grpc.CallOption) (*emptypb.Empty, error)
Restart(ctx context.Context, in *RequestWithId, opts ...grpc.CallOption) (*emptypb.Empty, error)
TaskTree(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*TaskTreeResponse, error)
StreamList(ctx context.Context, in *StreamListRequest, opts ...grpc.CallOption) (*StreamListResponse, error)
WaitList(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*StreamWaitListResponse, error)
StreamInfo(ctx context.Context, in *StreamSnapRequest, opts ...grpc.CallOption) (*StreamInfoResponse, error)
@@ -84,6 +85,15 @@ func (c *globalClient) Restart(ctx context.Context, in *RequestWithId, opts ...g
return out, nil
}
func (c *globalClient) TaskTree(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*TaskTreeResponse, error) {
out := new(TaskTreeResponse)
err := c.cc.Invoke(ctx, "/m7s.Global/TaskTree", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *globalClient) StreamList(ctx context.Context, in *StreamListRequest, opts ...grpc.CallOption) (*StreamListResponse, error) {
out := new(StreamListResponse)
err := c.cc.Invoke(ctx, "/m7s.Global/StreamList", in, out, opts...)
@@ -191,6 +201,7 @@ type GlobalServer interface {
Summary(context.Context, *emptypb.Empty) (*SummaryResponse, error)
Shutdown(context.Context, *RequestWithId) (*emptypb.Empty, error)
Restart(context.Context, *RequestWithId) (*emptypb.Empty, error)
TaskTree(context.Context, *emptypb.Empty) (*TaskTreeResponse, error)
StreamList(context.Context, *StreamListRequest) (*StreamListResponse, error)
WaitList(context.Context, *emptypb.Empty) (*StreamWaitListResponse, error)
StreamInfo(context.Context, *StreamSnapRequest) (*StreamInfoResponse, error)
@@ -221,6 +232,9 @@ func (UnimplementedGlobalServer) Shutdown(context.Context, *RequestWithId) (*emp
func (UnimplementedGlobalServer) Restart(context.Context, *RequestWithId) (*emptypb.Empty, error) {
return nil, status.Errorf(codes.Unimplemented, "method Restart not implemented")
}
func (UnimplementedGlobalServer) TaskTree(context.Context, *emptypb.Empty) (*TaskTreeResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method TaskTree not implemented")
}
func (UnimplementedGlobalServer) StreamList(context.Context, *StreamListRequest) (*StreamListResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method StreamList not implemented")
}
@@ -339,6 +353,24 @@ func _Global_Restart_Handler(srv interface{}, ctx context.Context, dec func(inte
return interceptor(ctx, in, info, handler)
}
func _Global_TaskTree_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(emptypb.Empty)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(GlobalServer).TaskTree(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/m7s.Global/TaskTree",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(GlobalServer).TaskTree(ctx, req.(*emptypb.Empty))
}
return interceptor(ctx, in, info, handler)
}
func _Global_StreamList_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(StreamListRequest)
if err := dec(in); err != nil {
@@ -560,6 +592,10 @@ var Global_ServiceDesc = grpc.ServiceDesc{
MethodName: "Restart",
Handler: _Global_Restart_Handler,
},
{
MethodName: "TaskTree",
Handler: _Global_TaskTree_Handler,
},
{
MethodName: "StreamList",
Handler: _Global_StreamList_Handler,

View File

@@ -8,8 +8,12 @@ type ChannelTask struct {
callback reflect.Value
}
func (t *ChannelTask) start() (reflect.Value, error) {
return t.channel, nil
func (t *ChannelTask) GetTaskType() string {
return "channel"
}
func (t *ChannelTask) getSignal() reflect.Value {
return t.channel
}
func (t *ChannelTask) tick(signal reflect.Value) {

View File

@@ -18,8 +18,7 @@ func GetNextTaskID() uint32 {
}
func init() {
RootTask.Name = "root"
RootTask.init(context.Background())
RootTask.initTask(context.Background(), &RootTask)
RootTask.Logger = slog.New(slog.NewTextHandler(os.Stdout, nil))
}
@@ -27,9 +26,13 @@ type MarcoLongTask struct {
MarcoTask
}
func (task *MarcoLongTask) start() (signal reflect.Value, err error) {
task.keepAlive = true
return task.MarcoTask.start()
func (m *MarcoLongTask) initTask(ctx context.Context, task ITask) {
m.MarcoTask.initTask(ctx, task)
m.keepAlive = true
}
func (m *MarcoLongTask) GetTaskType() string {
return "long"
}
// MarcoTask include sub tasks
@@ -41,8 +44,16 @@ type MarcoTask struct {
keepAlive bool
}
func (mt *MarcoTask) init(ctx context.Context) {
mt.Task.init(ctx)
func (m *MarcoTask) GetTaskType() string {
return "marco"
}
func (mt *MarcoTask) getMaroTask() *MarcoTask {
return mt
}
func (mt *MarcoTask) initTask(ctx context.Context, task ITask) {
mt.Task.initTask(ctx, task)
mt.shutdown = nil
mt.addSub = make(chan ITask, 10)
}
@@ -50,13 +61,13 @@ func (mt *MarcoTask) init(ctx context.Context) {
func (mt *MarcoTask) dispose() {
reason := mt.StopReason()
if mt.Logger != nil {
mt.Debug("task dispose", "reason", reason, "taskId", mt.ID, "taskName", mt.Name)
mt.Debug("task dispose", "reason", reason, "taskId", mt.ID, "taskType", mt.GetTaskType(), "ownerType", mt.GetOwnerType())
}
mt.disposeHandler()
close(mt.addSub)
_ = mt.WaitStopped()
if mt.Logger != nil {
mt.Debug("task disposed", "reason", reason, "taskId", mt.ID, "taskName", mt.Name)
mt.Debug("task disposed", "reason", reason, "taskId", mt.ID, "taskType", mt.GetTaskType(), "ownerType", mt.GetOwnerType())
}
for _, listener := range mt.afterDisposeListeners {
listener()
@@ -91,19 +102,31 @@ func (mt *MarcoTask) lazyStart(t ITask) {
mt.addSub <- t
}
func (mt *MarcoTask) Range(callback func(task *Task, m *MarcoTask) bool) {
for _, task := range mt.children {
var m *MarcoTask
if v, ok := task.(interface{ getMaroTask() *MarcoTask }); ok {
m = v.getMaroTask()
}
callback(task.getTask(), m)
}
}
func (mt *MarcoTask) AddTask(task ITask) *Task {
t := task.getTask()
if t.parentCtx != nil && task.IsStopped() { //reuse task
t.parent = nil
return mt.AddTaskWithContext(t.parentCtx, task)
}
return mt.AddTaskWithContext(mt.Context, task)
}
func (mt *MarcoTask) AddTaskWithContext(ctx context.Context, t ITask) (task *Task) {
if ctx == nil && mt.Context == nil {
panic("context is nil")
}
if task = t.getTask(); task.parent == nil {
t.init(ctx)
if v, ok := t.(TaskStarter); ok {
task.startHandler = v.Start
}
if v, ok := t.(TaskDisposal); ok {
task.disposeHandler = v.Dispose
}
t.initTask(ctx, t)
}
mt.lazyStart(t)
return
@@ -116,7 +139,6 @@ func (mt *MarcoTask) Call(callback func() error) {
func CreateTaskByCallBack(start func() error, dispose func()) *Task {
var task Task
task.Name = "call"
task.startHandler = func() error {
err := start()
if err == nil && dispose == nil {
@@ -130,8 +152,7 @@ func CreateTaskByCallBack(start func() error, dispose func()) *Task {
func (mt *MarcoTask) AddChan(channel any, callback any) *ChannelTask {
var chanTask ChannelTask
chanTask.Name = "channel"
chanTask.init(mt.Context)
chanTask.initTask(mt.Context, &chanTask)
chanTask.channel = reflect.ValueOf(channel)
chanTask.callback = reflect.ValueOf(callback)
mt.lazyStart(&chanTask)
@@ -158,9 +179,9 @@ func (mt *MarcoTask) run() {
return
}
if task := rev.Interface().(ITask); task.getParent() == mt {
if signal, err := task.start(); err == nil {
if err := task.start(); err == nil {
mt.children = append(mt.children, task)
cases = append(cases, reflect.SelectCase{Dir: reflect.SelectRecv, Chan: signal})
cases = append(cases, reflect.SelectCase{Dir: reflect.SelectRecv, Chan: task.getSignal()})
} else {
task.Stop(err)
}

View File

@@ -1,48 +0,0 @@
package util
import (
"errors"
"fmt"
"reflect"
"time"
)
var ErrRetryRunOut = errors.New("retry run out")
type RetryTask struct {
Task
MaxRetry int
RetryCount int
RetryInterval time.Duration
}
func (task *RetryTask) start() (signal reflect.Value, err error) {
task.StartTime = time.Now()
for !task.parent.IsStopped() {
err = task.startHandler()
if task.Logger != nil {
task.Debug("task start", "taskId", task.ID)
}
task.startup.Fulfill(err)
if err == nil {
return
}
if task.MaxRetry < 0 || task.RetryCount < task.MaxRetry {
task.RetryCount++
if delta := time.Since(task.StartTime); delta < task.RetryInterval {
time.Sleep(task.RetryInterval - delta)
}
if task.Logger != nil {
task.Warn(fmt.Sprintf("retry %d/%d", task.RetryCount, task.MaxRetry))
}
task.init(task.parentCtx)
} else {
if task.Logger != nil {
task.Warn(fmt.Sprintf("max retry %d failed", task.MaxRetry))
}
task.Stop(ErrRetryRunOut)
return reflect.ValueOf(task.Done()), nil
}
}
return
}

View File

@@ -3,6 +3,7 @@ package util
import (
"context"
"errors"
"fmt"
"log/slog"
"reflect"
"time"
@@ -12,21 +13,29 @@ const TraceLevel = slog.Level(-8)
var (
ErrAutoStop = errors.New("auto stop")
ErrCallbackTask = errors.New("callback task")
ErrCallbackTask = errors.New("callback")
ErrRetryRunOut = errors.New("retry out")
ErrTaskComplete = errors.New("complete")
EmptyStart = func() error { return nil }
EmptyDispose = func() {}
)
type (
ITask interface {
init(ctx context.Context)
initTask(context.Context, ITask)
getParent() *MarcoTask
getTask() *Task
getSignal() reflect.Value
Stop(error)
StopReason() error
start() (reflect.Value, error)
start() error
dispose()
IsStopped() bool
GetTaskType() string
GetOwnerType() string
}
IMarcoTask interface {
Range(func(yield ITask) bool)
}
IChannelTask interface {
tick(reflect.Value)
@@ -37,14 +46,23 @@ type (
TaskDisposal interface {
Dispose()
}
TaskBlock interface {
Run() error
}
RetryConfig struct {
MaxRetry int
RetryCount int
RetryInterval time.Duration
}
Task struct {
ID uint32
Name string
StartTime time.Time
*slog.Logger
context.Context
context.CancelCauseFunc
startHandler func() error
retry RetryConfig
owner reflect.Type
startHandler, runHandler func() error
afterStartListeners, afterDisposeListeners []func()
disposeHandler func()
Description map[string]any
@@ -54,6 +72,19 @@ type (
}
)
func (task *Task) SetRetry(maxRetry int, retryInterval time.Duration) {
task.retry.MaxRetry = maxRetry
task.retry.RetryInterval = retryInterval
}
func (task *Task) GetOwnerType() string {
return task.owner.Name()
}
func (task *Task) GetTaskType() string {
return "base"
}
func (task *Task) getTask() *Task {
return task
}
@@ -71,6 +102,7 @@ func (task *Task) WaitStarted() error {
}
func (task *Task) WaitStopped() error {
_ = task.WaitStarted()
if task.shutdown == nil {
return task.StopReason()
}
@@ -92,7 +124,7 @@ func (task *Task) StopReason() error {
func (task *Task) Stop(err error) {
if task.CancelCauseFunc != nil {
if task.Logger != nil {
task.Debug("task stop", "reason", err.Error(), "elapsed", time.Since(task.StartTime), "taskId", task.ID, "taskName", task.Name)
task.Debug("task stop", "reason", err.Error(), "elapsed", time.Since(task.StartTime), "taskId", task.ID, "taskType", task.GetTaskType(), "ownerType", task.GetOwnerType())
}
task.CancelCauseFunc(err)
}
@@ -106,24 +138,44 @@ func (task *Task) OnDispose(listener func()) {
task.afterDisposeListeners = append(task.afterDisposeListeners, listener)
}
func (task *Task) start() (signal reflect.Value, err error) {
func (task *Task) getSignal() reflect.Value {
return reflect.ValueOf(task.Done())
}
func (task *Task) start() (err error) {
task.StartTime = time.Now()
err = task.startHandler()
if task.Logger != nil {
task.Debug("task start", "taskId", task.ID, "taskName", task.Name)
task.Debug("task start", "taskId", task.ID, "taskType", task.GetTaskType(), "ownerType", task.GetOwnerType())
}
for task.retry.MaxRetry < 0 || task.retry.RetryCount <= task.retry.MaxRetry {
err = task.startHandler()
if err == nil {
break
} else if task.IsStopped() {
return task.StopReason()
}
task.retry.RetryCount++
if task.Logger != nil {
task.Warn(fmt.Sprintf("retry %d/%d", task.retry.RetryCount, task.retry.MaxRetry))
}
if delta := time.Since(task.StartTime); delta < task.retry.RetryInterval {
time.Sleep(task.retry.RetryInterval - delta)
}
}
task.startup.Fulfill(err)
signal = reflect.ValueOf(task.Done())
for _, listener := range task.afterStartListeners {
listener()
}
if task.runHandler != nil {
go task.run()
}
return
}
func (task *Task) dispose() {
reason := task.StopReason()
if task.Logger != nil {
task.Debug("task dispose", "reason", reason, "taskId", task.ID, "taskName", task.Name)
task.Debug("task dispose", "reason", reason, "taskId", task.ID, "taskType", task.GetTaskType(), "ownerType", task.GetOwnerType())
}
task.disposeHandler()
task.shutdown.Fulfill(reason)
@@ -132,9 +184,49 @@ func (task *Task) dispose() {
}
}
func (task *Task) init(ctx context.Context) {
func (task *Task) initTask(ctx context.Context, iTask ITask) {
task.parentCtx = ctx
task.Context, task.CancelCauseFunc = context.WithCancelCause(ctx)
task.startup = NewPromise(task.Context)
task.shutdown = NewPromise(context.Background())
task.owner = reflect.TypeOf(iTask)
if v, ok := iTask.(TaskStarter); ok {
task.startHandler = v.Start
}
if v, ok := iTask.(TaskDisposal); ok {
task.disposeHandler = v.Dispose
}
if v, ok := iTask.(TaskBlock); ok {
task.runHandler = v.Run
}
}
func (task *Task) ResetRetryCount() {
task.retry.RetryCount = 0
}
func (task *Task) run() {
var err error
retry := task.retry
for !task.IsStopped() {
if retry.MaxRetry < 0 || retry.RetryCount <= retry.MaxRetry {
err = task.runHandler()
if err == nil {
task.Stop(ErrTaskComplete)
} else {
retry.RetryCount++
if task.Logger != nil {
task.Warn(fmt.Sprintf("retry %d/%d", retry.RetryCount, retry.MaxRetry))
}
if delta := time.Since(task.StartTime); delta < retry.RetryInterval {
time.Sleep(retry.RetryInterval - delta)
}
}
} else {
if task.Logger != nil {
task.Warn(fmt.Sprintf("max retry %d failed", retry.MaxRetry))
}
task.Stop(errors.Join(err, ErrRetryRunOut))
}
}
}

View File

@@ -11,7 +11,7 @@ import (
func createMarcoTask() *MarcoTask {
var mt MarcoTask
mt.init(context.Background())
mt.initTask(context.Background())
mt.Logger = slog.New(slog.NewTextHandler(os.Stdout, nil))
return &mt
}
@@ -54,7 +54,7 @@ func Test_RetryTask(t *testing.T) {
func Test_Call_ExecutesCallback(t *testing.T) {
mt := createMarcoTask()
called := false
mt.Call(func(*Task) error {
mt.Call(func() error {
called = true
return nil
})
@@ -63,21 +63,6 @@ func Test_Call_ExecutesCallback(t *testing.T) {
}
}
func Test_AddCall_AddsCallbackTask(t *testing.T) {
mt := createMarcoTask()
called := false
task := mt.AddCall(func(*Task) error {
return nil
}, func() {
called = true
})
task.Stop(ErrCallbackTask)
mt.WaitStopped()
if !called {
t.Errorf("expected callback to be called")
}
}
func Test_AddChan_AddsChannelTask(t *testing.T) {
mt := createMarcoTask()
channel := time.NewTimer(time.Millisecond * 100)

View File

@@ -53,7 +53,6 @@ func (plugin *PluginMeta) Init(s *Server, userConfig map[string]any) (p *Plugin)
p.Meta = plugin
p.Server = s
p.Logger = s.Logger.With("plugin", plugin.Name)
p.Name = plugin.Name
upperName := strings.ToUpper(plugin.Name)
if os.Getenv(upperName+"_ENABLE") == "false" {
p.Disabled = true
@@ -107,7 +106,7 @@ func (plugin *PluginMeta) Init(s *Server, userConfig map[string]any) (p *Plugin)
}
}
}
p.Description = map[string]any{"name": plugin.Name, "version": plugin.Version}
p.Description = map[string]any{"version": plugin.Version}
return
}

View File

@@ -9,6 +9,7 @@ import (
"github.com/quic-go/quic-go"
"io"
"m7s.live/m7s/v5"
"m7s.live/m7s/v5/pkg/util"
"net"
"net/http"
"strings"
@@ -58,17 +59,24 @@ type ConsolePlugin struct {
var _ = m7s.InstallPlugin[ConsolePlugin]()
func (cfg *ConsolePlugin) connect() (conn quic.Connection, err error) {
type ConnectServerTask struct {
util.Task
cfg *ConsolePlugin
quic.Connection
}
func (task *ConnectServerTask) Start() (err error) {
tlsConf := &tls.Config{
InsecureSkipVerify: true,
NextProtos: []string{"monibuca"},
}
conn, err = quic.DialAddr(cfg.Context, cfg.Server, tlsConf, &quic.Config{
cfg := task.cfg
task.Connection, err = quic.DialAddr(cfg.Context, cfg.Server, tlsConf, &quic.Config{
KeepAlivePeriod: time.Second * 10,
EnableDatagrams: true,
})
if stream := quic.Stream(nil); err == nil {
if stream, err = conn.OpenStreamSync(cfg.Context); err == nil {
if stream, err = task.OpenStreamSync(cfg.Context); err == nil {
_, err = stream.Write(append([]byte{1}, (cfg.Secret + "\n")...))
if msg := []byte(nil); err == nil {
if msg, err = bufio.NewReader(stream).ReadSlice(0); err == nil {
@@ -76,7 +84,7 @@ func (cfg *ConsolePlugin) connect() (conn quic.Connection, err error) {
if err = json.Unmarshal(msg[:len(msg)-1], &rMessage); err == nil {
if rMessage["code"].(float64) != 0 {
// cfg.Error("response from console server ", cfg.Server, rMessage["msg"])
return nil, fmt.Errorf("response from console server %s %s", cfg.Server, rMessage["msg"])
return fmt.Errorf("response from console server %s %s", cfg.Server, rMessage["msg"])
} else {
// cfg.reportStream = stream
cfg.Info("response from console server ", cfg.Server, rMessage)
@@ -95,81 +103,89 @@ func (cfg *ConsolePlugin) connect() (conn quic.Connection, err error) {
return
}
func (cfg *ConsolePlugin) OnInit() error {
if cfg.Secret == "" {
return nil
}
conn, err := cfg.connect()
if err != nil {
return err
}
go func() {
for !cfg.IsStopped() {
for err == nil {
var s quic.Stream
if s, err = conn.AcceptStream(cfg.Context); err == nil {
go cfg.ReceiveRequest(s, conn)
}
}
time.Sleep(time.Second)
conn, err = cfg.connect()
if err != nil {
break
}
func (task *ConnectServerTask) Run() (err error) {
for err == nil {
var s quic.Stream
if s, err = task.AcceptStream(task.Task.Context); err == nil {
task.cfg.AddTask(&ReceiveRequestTask{
stream: s,
handler: task.cfg.GetGlobalCommonConf().GetHandler(),
conn: task.Connection,
})
}
}()
return err
}
return
}
func (cfg *ConsolePlugin) ReceiveRequest(s quic.Stream, conn quic.Connection) error {
defer s.Close()
wr := &myResponseWriter2{Stream: s}
reader := bufio.NewReader(s)
var req *http.Request
type ReceiveRequestTask struct {
util.Task
stream quic.Stream
handler http.Handler
conn quic.Connection
req *http.Request
}
func (task *ReceiveRequestTask) Start() (err error) {
reader := bufio.NewReader(task.stream)
url, _, err := reader.ReadLine()
if err == nil {
ctx, cancel := context.WithCancel(s.Context())
ctx, cancel := context.WithCancel(task.stream.Context())
defer cancel()
req, err = http.NewRequestWithContext(ctx, "GET", string(url), reader)
task.req, err = http.NewRequestWithContext(ctx, "GET", string(url), reader)
for err == nil {
var h []byte
if h, _, err = reader.ReadLine(); len(h) > 0 {
if b, a, f := strings.Cut(string(h), ": "); f {
req.Header.Set(b, a)
task.req.Header.Set(b, a)
}
} else {
break
}
}
if err == nil {
h := cfg.GetGlobalCommonConf().GetHandler()
if req.Header.Get("Accept") == "text/event-stream" {
go h.ServeHTTP(wr, req)
} else if req.Header.Get("Upgrade") == "websocket" {
var writer myResponseWriter3
writer.Stream = s
writer.Connection = conn
req.Host = req.Header.Get("Host")
if req.Host == "" {
req.Host = req.URL.Host
}
if req.Host == "" {
req.Host = "m7s.live"
}
h.ServeHTTP(&writer, req) //建立websocket连接,握手
} else {
method := req.Header.Get("M7s-Method")
if method == "POST" {
req.Method = "POST"
}
h.ServeHTTP(wr, req)
}
}
io.ReadAll(s)
}
if err != nil {
cfg.Error("read console server", "err", err)
}
return err
return
}
func (task *ReceiveRequestTask) Run() (err error) {
wr := &myResponseWriter2{Stream: task.stream}
req := task.req
if req.Header.Get("Accept") == "text/event-stream" {
go task.handler.ServeHTTP(wr, req)
} else if req.Header.Get("Upgrade") == "websocket" {
var writer myResponseWriter3
writer.Stream = task.stream
writer.Connection = task.conn
req.Host = req.Header.Get("Host")
if req.Host == "" {
req.Host = req.URL.Host
}
if req.Host == "" {
req.Host = "m7s.live"
}
task.handler.ServeHTTP(&writer, req) //建立websocket连接,握手
} else {
method := req.Header.Get("M7s-Method")
if method == "POST" {
req.Method = "POST"
}
task.handler.ServeHTTP(wr, req)
}
_, err = io.ReadAll(task.stream)
return
}
func (task *ReceiveRequestTask) Dispose() {
task.stream.Close()
}
func (cfg *ConsolePlugin) OnInit() error {
if cfg.Secret == "" || cfg.Server == "" {
return nil
}
connectTask := ConnectServerTask{
cfg: cfg,
}
connectTask.SetRetry(-1, time.Second)
cfg.AddTask(&connectTask)
return nil
}

View File

@@ -16,7 +16,6 @@ import (
var writeMetaTagQueueTask util.MarcoLongTask
func init() {
writeMetaTagQueueTask.Name = "writeMetaTagQueue"
util.RootTask.AddTask(&writeMetaTagQueueTask)
}

11
plugin/flv/pkg/vod.go Normal file
View File

@@ -0,0 +1,11 @@
package flv
import "m7s.live/m7s/v5/pkg/util"
type Vod struct {
util.Task
}
func (v *Vod) Start() error {
return nil
}

View File

@@ -3,7 +3,6 @@ package plugin_gb28181
import (
"github.com/emiago/sipgo"
"github.com/emiago/sipgo/sip"
"log/slog"
"m7s.live/m7s/v5"
"m7s.live/m7s/v5/pkg/util"
gb28181 "m7s.live/m7s/v5/plugin/gb28181/pkg"
@@ -44,7 +43,7 @@ type Device struct {
dialogClient *sipgo.DialogClient
contactHDR sip.ContactHeader
fromHDR sip.FromHeader
*slog.Logger
devices *util.Collection[string, *Device]
}
func (d *Device) GetKey() string {
@@ -94,12 +93,7 @@ func (d *Device) send(req *sip.Request) (*sip.Response, error) {
return d.client.Do(d, req)
}
func (d *Device) eventLoop(gb *GB28181Plugin) {
defer func() {
d.Status = DeviceOfflineStatus
if gb.devices.RemoveByKey(d.ID) {
}
}()
func (d *Device) Run() {
response, err := d.catalog()
if err != nil {
d.Error("catalog", "err", err)
@@ -126,7 +120,7 @@ func (d *Device) eventLoop(gb *GB28181Plugin) {
} else {
d.Debug("subCatalog", "response", response.String())
}
response, err = d.subscribePosition(int(gb.Position.Interval / time.Second))
response, err = d.subscribePosition(int(6))
if err != nil {
d.Error("subPosition", "err", err)
} else {
@@ -154,7 +148,7 @@ func (d *Device) eventLoop(gb *GB28181Plugin) {
//如果父ID并非本身所属设备一般情况下这是因为下级设备上传了目录信息该信息通常不需要处理。
// 暂时不考虑级联目录的实现
if d.ID != parentId {
if parent, ok := gb.devices.Get(parentId); ok {
if parent, ok := d.devices.Get(parentId); ok {
parent.addOrUpdateChannel(c)
continue
} else {

View File

@@ -267,18 +267,27 @@ func (gb *GB28181Plugin) StoreDevice(id string, req *sip.Request) (d *Device) {
},
Params: sip.NewParams(),
},
devices: &gb.devices,
}
gb.With(d, "id", id)
d.Logger = gb.With("id", id)
d.fromHDR.Params.Add("tag", sip.GenerateTagN(16))
d.client, _ = sipgo.NewClient(gb.ua, sipgo.WithClientLogger(zerolog.New(os.Stdout)), sipgo.WithClientHostname(publicIP))
d.dialogClient = sipgo.NewDialogClient(d.client, d.contactHDR)
d.channels.L = new(sync.RWMutex)
d.Info("StoreDevice", "source", source, "desc", desc, "servIp", servIp, "publicIP", publicIP, "recipient", req.Recipient)
gb.devices.Add(d)
if gb.DB != nil {
//TODO
}
go d.eventLoop(gb)
task := gb.AddTask(d)
task.OnStart(func() {
gb.devices.Add(d)
})
task.OnDispose(func() {
d.Status = DeviceOfflineStatus
if gb.devices.RemoveByKey(d.ID) {
}
})
return
}

View File

@@ -147,7 +147,6 @@ func (p *RTMPPlugin) OnTCPConnect(conn *net.TCPConn) {
}
var publisher *m7s.Publisher
publisher, err = p.Publish(nc.Context, nc.AppName+"/"+cmd.PublishingName)
publisher.Description = nc.Description
if err != nil {
err = ns.Response(cmd.TransactionId, NetStream_Publish_BadName, Level_Error)
} else {
@@ -170,7 +169,6 @@ func (p *RTMPPlugin) OnTCPConnect(conn *net.TCPConn) {
var suber *m7s.Subscriber
// sender.ID = fmt.Sprintf("%s|%d", conn.RemoteAddr().String(), sender.StreamID)
suber, err = p.Subscribe(nc.Context, streamPath)
suber.Description = nc.Description
if err != nil {
err = ns.Response(cmd.TransactionId, NetStream_Play_Failed, Level_Error)
} else {

View File

@@ -72,7 +72,6 @@ func NewNetConnection(conn net.Conn) (ret *NetConnection) {
chunkHeaderBuf: make(util.Buffer, 0, 20),
Receivers: make(map[uint32]*m7s.Publisher),
}
ret.Name = "NetConnection"
ret.mediaDataPool.SetAllocator(util.NewScalableMemoryAllocator(1 << util.MinPowerOf2))
return
}

View File

@@ -135,7 +135,6 @@ func (p *Publisher) GetKey() string {
func createPublisher(p *Plugin, streamPath string, conf config.Publish) (publisher *Publisher) {
publisher = &Publisher{Publish: conf}
publisher.ID = util.GetNextTaskID()
publisher.Name = "publisher"
publisher.Plugin = p
publisher.TimeoutTimer = time.NewTimer(p.config.PublishTimeout)
publisher.Logger = p.Logger.With("streamPath", streamPath, "pId", publisher.ID)

View File

@@ -4,6 +4,7 @@ import (
"m7s.live/m7s/v5/pkg"
"m7s.live/m7s/v5/pkg/config"
"m7s.live/m7s/v5/pkg/util"
"time"
)
type Connection struct {
@@ -23,7 +24,6 @@ func createPullContext(p *Plugin, streamPath string, url string) (pullCtx *PullC
Pull: p.config.Pull,
publishConfig: &publishConfig,
}
pullCtx.Name = "pull"
pullCtx.Plugin = p
pullCtx.ConnectProxy = p.config.Pull.Proxy
pullCtx.RemoteURL = url
@@ -44,25 +44,23 @@ func (p *PullContext) GetKey() string {
}
type PullSubTask struct {
util.RetryTask
util.Task
ctx *PullContext
Puller
}
func (p *PullSubTask) Start() (err error) {
p.MaxRetry = p.ctx.RePull
if p.ctx.Publisher, err = p.ctx.Plugin.PublishWithConfig(p.Context, p.ctx.StreamPath, *p.ctx.publishConfig); err != nil {
p.Error("pull publish failed", "error", err)
return
}
p.ctx.Publisher.OnDispose(func() {
p.Stop(p.ctx.Publisher.StopReason())
})
return p.Puller(p.ctx)
}
func (p *PullContext) Do(puller Puller) {
p.AddTask(&PullSubTask{ctx: p, Puller: puller})
task := &PullSubTask{ctx: p, Puller: puller}
task.SetRetry(p.RePull, time.Second*5)
p.AddTask(task)
}
func (p *PullContext) Start() (err error) {

View File

@@ -3,6 +3,7 @@ package m7s
import (
"m7s.live/m7s/v5/pkg"
"m7s.live/m7s/v5/pkg/util"
"time"
"m7s.live/m7s/v5/pkg/config"
)
@@ -11,7 +12,6 @@ type Pusher = func(*PushContext) error
func createPushContext(p *Plugin, streamPath string, url string) (pushCtx *PushContext) {
pushCtx = &PushContext{Push: p.config.Push}
pushCtx.Name = "push"
pushCtx.Plugin = p
pushCtx.RemoteURL = url
pushCtx.StreamPath = streamPath
@@ -31,13 +31,12 @@ func (p *PushContext) GetKey() string {
}
type PushSubTask struct {
util.RetryTask
util.Task
ctx *PushContext
Pusher
}
func (p *PushSubTask) Start() (err error) {
p.MaxRetry = p.ctx.RePush
if p.ctx.Subscriber, err = p.ctx.Plugin.Subscribe(p.Context, p.ctx.StreamPath); err != nil {
p.Error("push subscribe failed", "error", err)
return
@@ -46,7 +45,9 @@ func (p *PushSubTask) Start() (err error) {
}
func (p *PushContext) Do(pusher Pusher) {
p.AddTask(&PushSubTask{ctx: p, Pusher: pusher})
task := &PushSubTask{ctx: p, Pusher: pusher}
task.SetRetry(p.RePush, time.Second*5)
p.AddTask(task)
}
func (p *PushContext) Start() (err error) {

View File

@@ -1,7 +1,6 @@
package m7s
import (
"fmt"
"m7s.live/m7s/v5/pkg/util"
"os"
"path/filepath"
@@ -20,7 +19,6 @@ func createRecoder(p *Plugin, streamPath string, filePath string) (recorder *Rec
FilePath: filePath,
StreamPath: streamPath,
}
recorder.Name = "record"
recorder.Logger = p.Logger.With("filePath", filePath, "streamPath", streamPath)
return
}
@@ -54,7 +52,9 @@ func (r *recordSubTask) Start() (err error) {
}
dir = filepath.Dir(p.FilePath)
}
r.Name = fmt.Sprintf("record:%s", p.FilePath)
r.Description = map[string]any{
"filePath": p.FilePath,
}
if err = os.MkdirAll(dir, 0755); err != nil {
return
}

View File

@@ -27,12 +27,11 @@ import (
)
var (
Version = "v5.0.0"
MergeConfigs = []string{"Publish", "Subscribe", "HTTP", "PublicIP", "LogLevel", "EnableAuth", "DB"}
ExecPath = os.Args[0]
ExecDir = filepath.Dir(ExecPath)
DefaultServer = NewServer()
serverMeta = PluginMeta{
Version = "v5.0.0"
MergeConfigs = []string{"Publish", "Subscribe", "HTTP", "PublicIP", "LogLevel", "EnableAuth", "DB"}
ExecPath = os.Args[0]
ExecDir = filepath.Dir(ExecPath)
serverMeta = PluginMeta{
Name: "Global",
Version: Version,
}
@@ -82,15 +81,17 @@ type Server struct {
conf any
}
func NewServer() (s *Server) {
s = &Server{}
func NewServer(conf any) (s *Server) {
s = &Server{
conf: conf,
}
s.ID = util.GetNextTaskID()
s.Meta = &serverMeta
return
}
func Run(ctx context.Context, conf any) error {
return DefaultServer.Run(ctx, conf)
return util.RootTask.AddTaskWithContext(ctx, NewServer(conf)).WaitStopped()
}
type rawconfig = map[string]map[string]any
@@ -122,9 +123,7 @@ func (s *Server) GetKey() uint32 {
return s.ID
}
func (s *Server) Init(conf any) {
s.Name = "server"
s.conf = conf
func (s *Server) Start() (err error) {
s.Server = s
s.handler = s
s.config.HTTP.ListenAddrTLS = ":8443"
@@ -133,13 +132,6 @@ func (s *Server) Init(conf any) {
s.LogHandler.SetLevel(slog.LevelDebug)
s.LogHandler.Add(defaultLogHandler)
s.Logger = slog.New(&s.LogHandler).With("server", s.ID)
s.streamTask.Name = "stream"
s.pullTask.Name = "pull"
s.pushTask.Name = "push"
s.recordTask.Name = "record"
}
func (s *Server) Start() (err error) {
httpConf, tcpConf := &s.config.HTTP, &s.config.TCP
mux := runtime.NewServeMux(runtime.WithMarshalerOption("text/plain", &pb.TextPlain{}), runtime.WithRoutingErrorHandler(func(_ context.Context, _ *runtime.ServeMux, _ runtime.Marshaler, w http.ResponseWriter, r *http.Request, _ int) {
httpConf.GetHttpMux().ServeHTTP(w, r)
@@ -267,24 +259,15 @@ func (s *Server) Dispose() {
_ = s.tcplis.Close()
_ = s.grpcClientConn.Close()
s.config.HTTP.StopListen()
}
func (s *Server) Run(ctx context.Context, conf any) (err error) {
for {
s.Init(conf)
util.RootTask.AddTaskWithContext(ctx, s)
if err = s.WaitStarted(); err != nil {
return
}
if err = s.WaitStopped(); err != ErrRestart {
s.Info("server stopped", "error", err)
return
}
if err := s.StopReason(); err == ErrRestart {
var server Server
server.ID = s.ID
server.Meta = s.Meta
server.DB = s.DB
*s = server
util.RootTask.AddTask(s)
} else {
s.Info("server stopped", "err", err)
}
}

View File

@@ -2,7 +2,6 @@ package m7s
import (
"errors"
"fmt"
"net/url"
"reflect"
"runtime"
@@ -28,6 +27,11 @@ func (ps *PubSubBase) Init(streamPath string, conf any) {
if u, err := url.Parse(streamPath); err == nil {
ps.StreamPath, ps.Args = u.Path, u.Query()
}
ps.Description = map[string]any{
"streamPath": ps.StreamPath,
"args": ps.Args,
"plugin": ps.Plugin.Meta.Name,
}
// args to config
if len(ps.Args) != 0 {
var c config.Config
@@ -60,7 +64,6 @@ type Subscriber struct {
func createSubscriber(p *Plugin, streamPath string, conf config.Subscribe) *Subscriber {
subscriber := &Subscriber{Subscribe: conf}
subscriber.ID = util.GetNextTaskID()
subscriber.Name = "subscriber"
subscriber.Plugin = p
subscriber.TimeoutTimer = time.NewTimer(subscriber.WaitTimeout)
subscriber.Logger = p.Logger.With("streamPath", streamPath, "sId", subscriber.ID)
@@ -166,7 +169,6 @@ type SubscribeHandler[A any, V any] struct {
func CreatePlayTask[A any, V any](s *Subscriber, onAudio func(A) error, onVideo func(V) error) util.ITask {
var handler SubscribeHandler[A, V]
handler.Name = fmt.Sprintf("play:%s", s.StreamPath)
handler.s = s
handler.OnAudio = onAudio
handler.OnVideo = onVideo