diff --git a/api.go b/api.go index 1895349..312a667 100644 --- a/api.go +++ b/api.go @@ -131,7 +131,7 @@ func (s *Server) getStreamInfo(pub *Publisher) (res *pb.StreamInfoResponse, err } func (s *Server) StreamInfo(ctx context.Context, req *pb.StreamSnapRequest) (res *pb.StreamInfoResponse, err error) { - s.Call(func() { + s.streamTM.Call(func() { if pub, ok := s.Streams.Get(req.StreamPath); ok { res, err = s.getStreamInfo(pub) } else { @@ -141,7 +141,7 @@ func (s *Server) StreamInfo(ctx context.Context, req *pb.StreamSnapRequest) (res return } func (s *Server) GetSubscribers(ctx context.Context, req *pb.SubscribersRequest) (res *pb.SubscribersResponse, err error) { - s.Call(func() { + s.streamTM.Call(func() { var subscribers []*pb.SubscriberSnapShot for subscriber := range s.Subscribers.Range { meta, _ := json.Marshal(subscriber.MetaData) @@ -176,7 +176,7 @@ func (s *Server) GetSubscribers(ctx context.Context, req *pb.SubscribersRequest) return } func (s *Server) AudioTrackSnap(ctx context.Context, req *pb.StreamSnapRequest) (res *pb.TrackSnapShotResponse, err error) { - s.Call(func() { + s.streamTM.Call(func() { if pub, ok := s.Streams.Get(req.StreamPath); ok && pub.HasAudioTrack() { res = &pb.TrackSnapShotResponse{} for _, memlist := range pub.AudioTrack.Allocator.GetChildren() { @@ -254,7 +254,7 @@ func (s *Server) api_VideoTrack_SSE(rw http.ResponseWriter, r *http.Request) { } func (s *Server) VideoTrackSnap(ctx context.Context, req *pb.StreamSnapRequest) (res *pb.TrackSnapShotResponse, err error) { - s.Call(func() { + s.streamTM.Call(func() { if pub, ok := s.Streams.Get(req.StreamPath); ok && pub.HasVideoTrack() { res = &pb.TrackSnapShotResponse{} for _, memlist := range pub.VideoTrack.Allocator.GetChildren() { @@ -304,15 +304,15 @@ func (s *Server) VideoTrackSnap(ctx context.Context, req *pb.StreamSnapRequest) } func (s *Server) Restart(ctx context.Context, req *pb.RequestWithId) (res *emptypb.Empty, err error) { - if Servers[req.Id] != nil { - Servers[req.Id].Stop(pkg.ErrRestart) + if s, ok := Servers.Get(req.Id); ok { + s.Stop(pkg.ErrRestart) } return empty, err } func (s *Server) Shutdown(ctx context.Context, req *pb.RequestWithId) (res *emptypb.Empty, err error) { - if Servers[req.Id] != nil { - Servers[req.Id].Stop(pkg.ErrStopFromAPI) + if s, ok := Servers.Get(req.Id); ok { + s.Stop(pkg.ErrStopFromAPI) } else { return nil, pkg.ErrNotFound } @@ -320,8 +320,8 @@ func (s *Server) Shutdown(ctx context.Context, req *pb.RequestWithId) (res *empt } func (s *Server) ChangeSubscribe(ctx context.Context, req *pb.ChangeSubscribeRequest) (res *pb.SuccessResponse, err error) { - s.Call(func() { - if subscriber, ok := s.Subscribers.Get(int(req.Id)); ok { + s.streamTM.Call(func() { + if subscriber, ok := s.Subscribers.Get(req.Id); ok { if pub, ok := s.Streams.Get(req.StreamPath); ok { subscriber.Publisher.RemoveSubscriber(subscriber) subscriber.StreamPath = req.StreamPath @@ -335,8 +335,8 @@ func (s *Server) ChangeSubscribe(ctx context.Context, req *pb.ChangeSubscribeReq } func (s *Server) StopSubscribe(ctx context.Context, req *pb.RequestWithId) (res *pb.SuccessResponse, err error) { - s.Call(func() { - if subscriber, ok := s.Subscribers.Get(int(req.Id)); ok { + s.streamTM.Call(func() { + if subscriber, ok := s.Subscribers.Get(req.Id); ok { subscriber.Stop(errors.New("stop by api")) } else { err = pkg.ErrNotFound @@ -347,7 +347,7 @@ func (s *Server) StopSubscribe(ctx context.Context, req *pb.RequestWithId) (res // /api/stream/list func (s *Server) StreamList(_ context.Context, req *pb.StreamListRequest) (res *pb.StreamListResponse, err error) { - s.Call(func() { + s.streamTM.Call(func() { var streams []*pb.StreamInfoResponse for publisher := range s.Streams.Range { info, err := s.getStreamInfo(publisher) @@ -362,7 +362,7 @@ func (s *Server) StreamList(_ context.Context, req *pb.StreamListRequest) (res * } func (s *Server) WaitList(context.Context, *emptypb.Empty) (res *pb.StreamWaitListResponse, err error) { - s.Call(func() { + s.streamTM.Call(func() { res = &pb.StreamWaitListResponse{ List: make(map[string]int32), } @@ -381,53 +381,51 @@ func (s *Server) Api_Summary_SSE(rw http.ResponseWriter, r *http.Request) { } func (s *Server) Summary(context.Context, *emptypb.Empty) (res *pb.SummaryResponse, err error) { - s.Call(func() { - dur := time.Since(s.lastSummaryTime) - if dur < time.Second { - res = s.lastSummary - return + dur := time.Since(s.lastSummaryTime) + if dur < time.Second { + res = s.lastSummary + return + } + v, _ := mem.VirtualMemory() + d, _ := disk.Usage("/") + nv, _ := IOCounters(true) + res = &pb.SummaryResponse{ + Memory: &pb.Usage{ + Total: v.Total >> 20, + Free: v.Available >> 20, + Used: v.Used >> 20, + Usage: float32(v.UsedPercent), + }, + HardDisk: &pb.Usage{ + Total: d.Total >> 30, + Free: d.Free >> 30, + Used: d.Used >> 30, + Usage: float32(d.UsedPercent), + }, + } + if cc, _ := cpu.Percent(time.Second, false); len(cc) > 0 { + res.CpuUsage = float32(cc[0]) + } + netWorks := []*pb.NetWorkInfo{} + for i, n := range nv { + info := &pb.NetWorkInfo{ + Name: n.Name, + Receive: n.BytesRecv, + Sent: n.BytesSent, } - v, _ := mem.VirtualMemory() - d, _ := disk.Usage("/") - nv, _ := IOCounters(true) - res = &pb.SummaryResponse{ - Memory: &pb.Usage{ - Total: v.Total >> 20, - Free: v.Available >> 20, - Used: v.Used >> 20, - Usage: float32(v.UsedPercent), - }, - HardDisk: &pb.Usage{ - Total: d.Total >> 30, - Free: d.Free >> 30, - Used: d.Used >> 30, - Usage: float32(d.UsedPercent), - }, + if s.lastSummary != nil && len(s.lastSummary.NetWork) > i { + info.ReceiveSpeed = (n.BytesRecv - s.lastSummary.NetWork[i].Receive) / uint64(dur.Seconds()) + info.SentSpeed = (n.BytesSent - s.lastSummary.NetWork[i].Sent) / uint64(dur.Seconds()) } - if cc, _ := cpu.Percent(time.Second, false); len(cc) > 0 { - res.CpuUsage = float32(cc[0]) - } - netWorks := []*pb.NetWorkInfo{} - for i, n := range nv { - info := &pb.NetWorkInfo{ - Name: n.Name, - Receive: n.BytesRecv, - Sent: n.BytesSent, - } - if s.lastSummary != nil && len(s.lastSummary.NetWork) > i { - info.ReceiveSpeed = (n.BytesRecv - s.lastSummary.NetWork[i].Receive) / uint64(dur.Seconds()) - info.SentSpeed = (n.BytesSent - s.lastSummary.NetWork[i].Sent) / uint64(dur.Seconds()) - } - netWorks = append(netWorks, info) - } - res.StreamCount = int32(s.Streams.Length) - res.PullCount = int32(s.Pulls.Length) - res.PushCount = int32(s.Pushs.Length) - res.SubscribeCount = int32(s.Subscribers.Length) - res.NetWork = netWorks - s.lastSummary = res - s.lastSummaryTime = time.Now() - }) + netWorks = append(netWorks, info) + } + res.StreamCount = int32(s.Streams.Length) + res.PullCount = int32(s.Pulls.Length) + res.PushCount = int32(s.Pushs.Length) + res.SubscribeCount = int32(s.Subscribers.Length) + res.NetWork = netWorks + s.lastSummary = res + s.lastSummaryTime = time.Now() return } diff --git a/example/default/config.yaml b/example/default/config.yaml index 400ef00..203ba5f 100644 --- a/example/default/config.yaml +++ b/example/default/config.yaml @@ -1,5 +1,7 @@ global: loglevel: trace + http: + listenaddr: :8082 # enableauth: true # tcp: # listenaddr: :50051 diff --git a/pkg/config/types.go b/pkg/config/types.go index bdf05fb..e517640 100755 --- a/pkg/config/types.go +++ b/pkg/config/types.go @@ -159,7 +159,8 @@ func (p *Push) CheckPush(streamPath string) string { type Record struct { EnableRegexp bool `desc:"是否启用正则表达式"` // 是否启用正则表达式 RecordList map[string]string - Fragment time.Duration `desc:"分片时长"` // 分片时长 + Fragment time.Duration `desc:"分片时长"` // 分片时长 + Append bool `desc:"是否追加录制"` // 是否追加录制 } func (p *Record) GetRecordConfig() *Record { diff --git a/pkg/error.go b/pkg/error.go index cfe6efb..debf236 100644 --- a/pkg/error.go +++ b/pkg/error.go @@ -11,10 +11,13 @@ var ( ErrPublishTimeout = errors.New("publish timeout") ErrPublishIdleTimeout = errors.New("publish idle timeout") ErrPublishDelayCloseTimeout = errors.New("publish delay close timeout") + ErrPushRemoteURLExist = errors.New("push remote url exist") ErrSubscribeTimeout = errors.New("subscribe timeout") ErrRestart = errors.New("restart") ErrInterrupt = errors.New("interrupt") ErrUnsupportCodec = errors.New("unsupport codec") ErrMuted = errors.New("muted") - ErrorLost = errors.New("lost") + ErrLost = errors.New("lost") + ErrRetryRunOut = errors.New("retry run out") + ErrRecordSamePath = errors.New("record same path") ) diff --git a/pkg/task.go b/pkg/task.go new file mode 100644 index 0000000..b508e02 --- /dev/null +++ b/pkg/task.go @@ -0,0 +1,203 @@ +package pkg + +import ( + "context" + "io" + "log/slog" + "m7s.live/m7s/v5/pkg/util" + "reflect" + "slices" + "sync/atomic" + "time" +) + +const TraceLevel = slog.Level(-8) + +type TaskExecutor interface { + Start() error + Dispose() +} + +type Task struct { + ID uint32 + StartTime time.Time + *slog.Logger + context.Context + context.CancelCauseFunc + Executor TaskExecutor + started *util.Promise +} + +func (task *Task) GetTask() *Task { + return task +} + +func (task *Task) GetKey() uint32 { + return task.ID +} + +func (task *Task) Begin() (err error) { + task.StartTime = time.Now() + err = task.Executor.Start() + task.started.Fulfill(err) + return +} + +func (task *Task) WaitStarted() error { + return task.started.Await() +} + +func (task *Task) Trace(msg string, fields ...any) { + task.Log(task.Context, TraceLevel, msg, fields...) +} + +func (task *Task) IsStopped() bool { + return task.Err() != nil +} + +func (task *Task) StopReason() error { + return context.Cause(task.Context) +} + +func (task *Task) Stop(err error) { + if task.CancelCauseFunc != nil && !task.IsStopped() { + task.Info("stop", "reason", err.Error()) + task.CancelCauseFunc(err) + } +} + +func (task *Task) Init(ctx context.Context, logger *slog.Logger) { + task.Logger = logger + task.Context, task.CancelCauseFunc = context.WithCancelCause(ctx) + task.started = util.NewPromise(task.Context) +} + +type CallBackTaskExecutor func() + +func (call CallBackTaskExecutor) Start() error { + call() + return io.EOF +} + +func (call CallBackTaskExecutor) Dispose() { + // nothing to do, never called +} + +type TaskManager struct { + shutdown *util.Promise + stopReason error + start chan *Task + Tasks []*Task + idG atomic.Uint32 +} + +func NewTaskManager() *TaskManager { + return &TaskManager{ + shutdown: util.NewPromise(context.TODO()), + start: make(chan *Task, 10), + } +} + +func (t *TaskManager) Add(task *Task) { + t.start <- task +} + +func (t *TaskManager) Call(callback CallBackTaskExecutor) { + var tmpTask Task + tmpTask.Init(context.TODO(), nil) + tmpTask.Executor = callback + _ = t.Start(&tmpTask) +} + +func (t *TaskManager) Start(task *Task) error { + t.start <- task + return task.WaitStarted() +} + +func (t *TaskManager) GetID() uint32 { + return t.idG.Add(1) +} + +// Run task Start and Dispose in this goroutine +func (t *TaskManager) Run(extra ...any) { + cases := []reflect.SelectCase{{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(t.start)}} + extraLen := len(extra) / 2 + var callbacks []reflect.Value + for i := range extraLen { + cases = append(cases, reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(extra[i*2])}) + callbacks = append(callbacks, reflect.ValueOf(extra[i*2+1])) + } + defer func() { + cases = slices.Delete(cases, 0, 1+extraLen) + for len(cases) > 0 { + chosen, _, _ := reflect.Select(cases) + task := t.Tasks[chosen] + task.Executor.Dispose() + t.Tasks = slices.Delete(t.Tasks, chosen, chosen+1) + cases = slices.Delete(cases, chosen, chosen+1) + } + t.shutdown.Fulfill(t.stopReason) + }() + for { + if chosen, rev, ok := reflect.Select(cases); chosen == 0 { + if !ok { + return + } + task := rev.Interface().(*Task) + if err := task.Begin(); err == nil { + t.Tasks = append(t.Tasks, task) + cases = append(cases, reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(task.Done())}) + } else { + task.Stop(err) + } + } else if chosen <= extraLen { + callbacks[chosen-1].Call([]reflect.Value{rev}) + } else { + taskIndex := chosen - 1 - extraLen + task := t.Tasks[taskIndex] + task.Executor.Dispose() + t.Tasks = slices.Delete(t.Tasks, taskIndex, taskIndex+1) + cases = slices.Delete(cases, chosen, chosen+1) + } + } +} + +// Run task Start and Dispose in another goroutine +//func (t *TaskManager) Run() { +// cases := []reflect.SelectCase{{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(t.Start)}} +// defer func() { +// cases = slices.Delete(cases, 0, 1) +// for len(cases) > 0 { +// chosen, _, _ := reflect.Select(cases) +// t.Done <- t.Tasks[chosen] +// t.Tasks = slices.Delete(t.Tasks, chosen, chosen+1) +// cases = slices.Delete(cases, chosen, chosen+1) +// } +// close(t.Done) +// }() +// for { +// if chosen, rev, ok := reflect.Select(cases); chosen == 0 { +// if !ok { +// return +// } +// task := rev.Interface().(*Task) +// t.Tasks = append(t.Tasks, task) +// cases = append(cases, reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(task.Done())}) +// } else { +// t.Done <- t.Tasks[chosen-1] +// t.Tasks = slices.Delete(t.Tasks, chosen-1, chosen) +// cases = slices.Delete(cases, chosen, chosen+1) +// } +// } +//} + +// ShutDown wait all task dispose +func (t *TaskManager) ShutDown(err error) { + t.Stop(err) + _ = t.shutdown.Await() +} + +func (t *TaskManager) Stop(err error) { + t.stopReason = err + close(t.start) +} diff --git a/pkg/track.go b/pkg/track.go index 673af12..3ac7a32 100644 --- a/pkg/track.go +++ b/pkg/track.go @@ -14,7 +14,7 @@ import ( type ( Track struct { *slog.Logger - ready *util.Promise[struct{}] + ready *util.Promise FrameType reflect.Type bytesIn int frameCount int @@ -55,7 +55,7 @@ func NewAVTrack(args ...any) (t *AVTrack) { t.RingWriter = NewRingWriter(v.RingSize) t.BufferRange[0] = v.BufferTime t.RingWriter.SLogger = t.Logger - case *util.Promise[struct{}]: + case *util.Promise: t.ready = v } } @@ -112,8 +112,7 @@ func (t *Track) IsReady() bool { } func (t *Track) WaitReady() error { - _, err := t.ready.Await() - return err + return t.ready.Await() } func (t *Track) Trace(msg string, fields ...any) { diff --git a/pkg/unit.go b/pkg/unit.go deleted file mode 100644 index 8afd124..0000000 --- a/pkg/unit.go +++ /dev/null @@ -1,34 +0,0 @@ -package pkg - -import ( - "context" - "log/slog" - "time" -) - -const TraceLevel = slog.Level(-8) - -type Unit[T any] struct { - ID T - StartTime time.Time - *slog.Logger - context.Context - context.CancelCauseFunc -} - -func (unit *Unit[T]) Trace(msg string, fields ...any) { - unit.Log(unit.Context, TraceLevel, msg, fields...) -} - -func (unit *Unit[T]) IsStopped() bool { - return unit.StopReason() != nil -} - -func (unit *Unit[T]) StopReason() error { - return context.Cause(unit.Context) -} - -func (unit *Unit[T]) Stop(err error) { - unit.Info("stop", "reason", err.Error()) - unit.CancelCauseFunc(err) -} diff --git a/pkg/util/promise.go b/pkg/util/promise.go index 4e2308f..4902792 100644 --- a/pkg/util/promise.go +++ b/pkg/util/promise.go @@ -6,21 +6,21 @@ import ( "time" ) -type Promise[T any] struct { +type Promise struct { context.Context context.CancelCauseFunc - Value T timer *time.Timer } -func NewPromise[T any](v T) *Promise[T] { - p := &Promise[T]{Value: v} - p.Context, p.CancelCauseFunc = context.WithCancelCause(context.Background()) +func NewPromise(ctx context.Context) *Promise { + p := &Promise{} + p.Context, p.CancelCauseFunc = context.WithCancelCause(ctx) return p } -func NewPromiseWithTimeout[T any](v T, timeout time.Duration) *Promise[T] { - p := &Promise[T]{Value: v} - p.Context, p.CancelCauseFunc = context.WithCancelCause(context.Background()) + +func NewPromiseWithTimeout(ctx context.Context, timeout time.Duration) *Promise { + p := &Promise{} + p.Context, p.CancelCauseFunc = context.WithCancelCause(ctx) p.timer = time.AfterFunc(timeout, func() { p.CancelCauseFunc(ErrTimeout) }) @@ -30,27 +30,30 @@ func NewPromiseWithTimeout[T any](v T, timeout time.Duration) *Promise[T] { var ErrResolve = errors.New("promise resolved") var ErrTimeout = errors.New("promise timeout") -func (p *Promise[T]) Resolve(v T) { - p.Value = v - p.CancelCauseFunc(ErrResolve) +func (p *Promise) Resolve() { + p.Fulfill(nil) } -func (p *Promise[T]) Await() (T, error) { +func (p *Promise) Reject(err error) { + p.Fulfill(err) +} + +func (p *Promise) Await() (err error) { <-p.Done() - err := context.Cause(p.Context) + err = context.Cause(p.Context) if errors.Is(err, ErrResolve) { err = nil } - return p.Value, err + return } -func (p *Promise[T]) Fulfill(err error) { +func (p *Promise) Fulfill(err error) { if p.timer != nil { p.timer.Stop() } p.CancelCauseFunc(Conditoinal(err == nil, ErrResolve, err)) } -func (p *Promise[T]) IsPending() bool { +func (p *Promise) IsPending() bool { return context.Cause(p.Context) == nil } diff --git a/plugin.go b/plugin.go index 8e6d2e8..5f473cb 100644 --- a/plugin.go +++ b/plugin.go @@ -2,9 +2,14 @@ package m7s import ( "context" + gatewayRuntime "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" myip "github.com/husanpao/ip" + "google.golang.org/grpc" + "gopkg.in/yaml.v3" "gorm.io/gorm" "log/slog" + . "m7s.live/m7s/v5/pkg" + "m7s.live/m7s/v5/pkg/config" "m7s.live/m7s/v5/pkg/db" "net" "net/http" @@ -13,13 +18,6 @@ import ( "reflect" "runtime" "strings" - - gatewayRuntime "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" - "google.golang.org/grpc" - "gopkg.in/yaml.v3" - . "m7s.live/m7s/v5/pkg" - "m7s.live/m7s/v5/pkg/config" - "m7s.live/m7s/v5/pkg/util" ) type DefaultYaml string @@ -33,17 +31,17 @@ type PluginMeta struct { RegisterGRPCHandler func(context.Context, *gatewayRuntime.ServeMux, *grpc.ClientConn) error } -func (plugin *PluginMeta) Init(s *Server, userConfig map[string]any) { +func (plugin *PluginMeta) Init(s *Server, userConfig map[string]any) (p *Plugin) { instance, ok := reflect.New(plugin.Type).Interface().(IPlugin) if !ok { panic("plugin must implement IPlugin") } - p := reflect.ValueOf(instance).Elem().FieldByName("Plugin").Addr().Interface().(*Plugin) + p = reflect.ValueOf(instance).Elem().FieldByName("Plugin").Addr().Interface().(*Plugin) p.handler = instance p.Meta = plugin + p.Executor = instance p.Server = s - p.Logger = s.Logger.With("plugin", plugin.Name) - p.Context, p.CancelCauseFunc = context.WithCancelCause(s.Context) + p.Task.Init(s.Context, s.Logger.With("plugin", plugin.Name)) upperName := strings.ToUpper(plugin.Name) if os.Getenv(upperName+"_ENABLE") == "false" { p.Disabled = true @@ -93,29 +91,12 @@ func (plugin *PluginMeta) Init(s *Server, userConfig map[string]any) { if err != nil { s.Error("failed to connect database", "error", err, "dsn", s.config.DSN, "type", s.config.DBType) p.Disabled = true + p.Stop(err) return } } } - err = instance.OnInit() - if err != nil { - p.Error("init", "error", err) - p.Stop(err) - return - } - if plugin.ServiceDesc != nil && s.grpcServer != nil { - s.grpcServer.RegisterService(plugin.ServiceDesc, instance) - if plugin.RegisterGRPCHandler != nil { - if err = plugin.RegisterGRPCHandler(p.Context, s.config.HTTP.GetGRPCMux(), s.grpcClientConn); err != nil { - p.Error("init", "error", err) - p.Stop(err) - } else { - p.Info("grpc handler registered") - } - } - } - s.Plugins.Add(p) - p.Start() + return } type iPlugin interface { @@ -123,8 +104,8 @@ type iPlugin interface { } type IPlugin interface { + TaskExecutor OnInit() error - OnEvent(any) OnExit() } @@ -133,16 +114,16 @@ type IRegisterHandler interface { } type IPullerPlugin interface { - NewPullHandler() PullHandler + DoPull(*PullContext) error GetPullableList() []string } type IPusherPlugin interface { - NewPushHandler() PushHandler + DoPush(*PushContext) error } type IRecorderPlugin interface { - NewRecordHandler() RecordHandler + DoRecord(*RecordContext) error } type ITCPPlugin interface { @@ -186,7 +167,7 @@ func InstallPlugin[C iPlugin](options ...any) error { } type Plugin struct { - Unit[int] + Task Disabled bool Meta *PluginMeta config config.Common @@ -252,13 +233,36 @@ func (p *Plugin) assign() { p.registerHandler(handlerMap) } -func (p *Plugin) Stop(err error) { - p.Unit.Stop(err) +func (p *Plugin) Start() (err error) { + s := p.Server + err = p.handler.OnInit() + if err != nil { + p.Error("init", "error", err) + return + } + if p.Meta.ServiceDesc != nil && s.grpcServer != nil { + s.grpcServer.RegisterService(p.Meta.ServiceDesc, p.handler) + if p.Meta.RegisterGRPCHandler != nil { + if err = p.Meta.RegisterGRPCHandler(p.Context, s.config.HTTP.GetGRPCMux(), s.grpcClientConn); err != nil { + p.Error("init", "error", err) + return + } else { + p.Info("grpc handler registered") + } + } + } + s.Plugins.Add(p) + p.listen() + return +} + +func (p *Plugin) Dispose() { + p.Server.Plugins.Remove(p) p.config.HTTP.StopListen() p.config.TCP.StopListen() } -func (p *Plugin) Start() { +func (p *Plugin) listen() { httpConf := &p.config.HTTP if httpConf.ListenAddrTLS != "" && (httpConf.ListenAddrTLS != p.Server.config.HTTP.ListenAddrTLS) { p.Info("https listen at ", "addr", httpConf.ListenAddrTLS) @@ -272,50 +276,44 @@ func (p *Plugin) Start() { p.Stop(httpConf.Listen()) }() } - tcpConf := &p.config.TCP - tcphandler, ok := p.handler.(ITCPPlugin) - if !ok { - tcphandler = p + if tcphandler, ok := p.handler.(ITCPPlugin); ok { + tcpConf := &p.config.TCP + if tcpConf.ListenAddr != "" && tcpConf.AutoListen { + p.Info("listen tcp", "addr", tcpConf.ListenAddr) + go func() { + err := tcpConf.Listen(tcphandler.OnTCPConnect) + if err != nil { + p.Error("listen tcp", "addr", tcpConf.ListenAddr, "error", err) + p.Stop(err) + } + }() + } + if tcpConf.ListenAddrTLS != "" && tcpConf.AutoListen { + p.Info("listen tcp tls", "addr", tcpConf.ListenAddrTLS) + go func() { + err := tcpConf.ListenTLS(tcphandler.OnTCPConnect) + if err != nil { + p.Error("listen tcp tls", "addr", tcpConf.ListenAddrTLS, "error", err) + p.Stop(err) + } + }() + } } - if tcpConf.ListenAddr != "" && tcpConf.AutoListen { - p.Info("listen tcp", "addr", tcpConf.ListenAddr) - go func() { - err := tcpConf.Listen(tcphandler.OnTCPConnect) - if err != nil { - p.Error("listen tcp", "addr", tcpConf.ListenAddr, "error", err) - p.Stop(err) - } - }() - } - if tcpConf.ListenAddrTLS != "" && tcpConf.AutoListen { - p.Info("listen tcp tls", "addr", tcpConf.ListenAddrTLS) - go func() { - err := tcpConf.ListenTLS(tcphandler.OnTCPConnect) - if err != nil { - p.Error("listen tcp tls", "addr", tcpConf.ListenAddrTLS, "error", err) - p.Stop(err) - } - }() - } - udpConf := &p.config.UDP - - udpHandler, ok := p.handler.(IUDPPlugin) - if !ok { - udpHandler = p - } - - if udpConf.ListenAddr != "" && udpConf.AutoListen { - p.Info("listen udp", "addr", udpConf.ListenAddr) - go func() { - err := udpConf.Listen(udpHandler.OnUDPConnect) - if err != nil { - p.Error("listen udp", "addr", udpConf.ListenAddr, "error", err) - p.Stop(err) - } - }() + if udpHandler, ok := p.handler.(IUDPPlugin); ok { + udpConf := &p.config.UDP + if udpConf.ListenAddr != "" && udpConf.AutoListen { + p.Info("listen udp", "addr", udpConf.ListenAddr) + go func() { + err := udpConf.Listen(udpHandler.OnUDPConnect) + if err != nil { + p.Error("listen udp", "addr", udpConf.ListenAddr, "error", err) + p.Stop(err) + } + }() + } } } @@ -327,140 +325,76 @@ func (p *Plugin) OnExit() { } -func (p *Plugin) onEvent(event any) { - switch v := event.(type) { - case *Publisher: - if h, ok := p.handler.(interface{ OnPublish(*Publisher) }); ok { - h.OnPublish(v) - } - case *Puller: - if h, ok := p.handler.(interface{ OnPull(*Puller) }); ok { - h.OnPull(v) - } - } - p.handler.OnEvent(event) -} - -func (p *Plugin) OnEvent(event any) { - -} - -func (p *Plugin) OnTCPConnect(conn *net.TCPConn) { - p.handler.OnEvent(conn) -} - -func (p *Plugin) OnUDPConnect(conn *net.UDPConn) { - p.handler.OnEvent(conn) -} - func (p *Plugin) Publish(streamPath string, options ...any) (publisher *Publisher, err error) { - publisher = &Publisher{Publish: p.config.Publish} + publisher = createPublisher(p, streamPath, options...) if p.config.EnableAuth { if onAuthPub, ok := p.Server.OnAuthPubs[p.Meta.Name]; ok { - authPromise := util.NewPromise(publisher) - onAuthPub(authPromise) - if _, err = authPromise.Await(); err != nil { + if err = onAuthPub(publisher).Await(); err != nil { p.Warn("auth failed", "error", err) return } } } - for _, option := range options { - switch v := option.(type) { - case func(*config.Publish): - v(&publisher.Publish) - } - } - publisher.Init(p, streamPath, &publisher.Publish, options...) - _, err = p.Server.Call(publisher) - return -} - -func (p *Plugin) Pull(streamPath string, url string, options ...any) (puller *Puller, err error) { - puller = &Puller{Pull: p.config.Pull} - puller.Client.Proxy = p.config.Pull.Proxy - puller.Client.RemoteURL = url - puller.Client.PubSubBase = &puller.PubSubBase - puller.Publish = p.config.Publish - puller.PublishTimeout = 0 - puller.StreamPath = streamPath - var pullHandler PullHandler - for _, option := range options { - switch v := option.(type) { - case PullHandler: - pullHandler = v - } - } - puller.Init(p, streamPath, &puller.Publish, options...) - if _, err = p.Server.Call(puller); err != nil { - return - } - if v, ok := p.handler.(IPullerPlugin); pullHandler == nil && ok { - pullHandler = v.NewPullHandler() - } - if pullHandler != nil { - err = puller.Start(pullHandler) - } - return -} - -func (p *Plugin) Record(streamPath string, filePath string, options ...any) (recorder *Recorder, err error) { - recorder = &Recorder{ - Record: p.config.Record, - } - if err = os.MkdirAll(filepath.Dir(filePath), 0755); err != nil { - return - } - recorder.StreamPath = streamPath - recorder.Subscribe = p.config.Subscribe - if recorder.File, err = os.OpenFile(filePath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0666); err != nil { - return - } - defer func() { - err = recorder.File.Close() - if info, err := recorder.File.Stat(); err == nil && info.Size() == 0 { - os.Remove(recorder.File.Name()) - } - }() - recorder.Init(p, streamPath, &recorder.Subscribe, options...) - if _, err = p.Server.Call(recorder); err != nil { - return - } - recorder.Publisher.WaitTrack() - var recordHandler RecordHandler - if v, ok := p.handler.(IRecorderPlugin); recordHandler == nil && ok { - recordHandler = v.NewRecordHandler() - } - if recordHandler != nil { - err = recorder.Start(recordHandler) - } + err = p.Server.streamTM.Start(&publisher.Task) return } func (p *Plugin) Subscribe(streamPath string, options ...any) (subscriber *Subscriber, err error) { - subscriber = &Subscriber{Subscribe: p.config.Subscribe} + subscriber = createSubscriber(p, streamPath, options...) if p.config.EnableAuth { if onAuthSub, ok := p.Server.OnAuthSubs[p.Meta.Name]; ok { - authPromise := util.NewPromise(subscriber) - onAuthSub(authPromise) - if _, err = authPromise.Await(); err != nil { + if err = onAuthSub(subscriber).Await(); err != nil { p.Warn("auth failed", "error", err) return } } } - for _, option := range options { - switch v := option.(type) { - case func(*config.Subscribe): - v(&subscriber.Subscribe) - } + err = p.Server.streamTM.Start(&subscriber.Task) + err = subscriber.Publisher.WaitTrack() + return +} + +func (p *Plugin) Pull(streamPath string, url string, options ...any) (puller *PullContext, err error) { + puller = createPullContext(p, streamPath, url, options...) + if err = p.Server.pullTM.Start(&puller.Task); err != nil { + return } - subscriber.Init(p, streamPath, &subscriber.Subscribe, options...) - if subscriber.Subscribe.BufferTime > 0 { - subscriber.Subscribe.SubMode = SUBMODE_BUFFER + if pullPlugin, ok := p.handler.(IPullerPlugin); ok { + puller.Run(pullPlugin.DoPull) + } + return +} + +func (p *Plugin) Push(streamPath string, url string, options ...any) (pusher *PushContext, err error) { + pusher = createPushContext(p, streamPath, url, options...) + if err = p.Server.pushTM.Start(&pusher.Task); err != nil { + return + } + if pushPlugin, ok := p.handler.(IPusherPlugin); ok { + pusher.Run(pushPlugin.DoPush) + } + return +} + +func (p *Plugin) Record(streamPath string, filePath string, options ...any) (recorder *RecordContext, err error) { + recorder = createRecoder(p, streamPath, filePath, options...) + dir := filePath + if filepath.Ext(filePath) != "" { + dir = filepath.Dir(filePath) + } + if err = os.MkdirAll(dir, 0755); err != nil { + return + } + recorder.Subscriber, err = p.Subscribe(streamPath, p.config.Subscribe) + if err != nil { + return + } + if err = p.Server.recordTM.Start(&recorder.Task); err != nil { + return + } + if recordPlugin, ok := p.handler.(IRecorderPlugin); ok { + recorder.Run(recordPlugin.DoRecord) } - _, err = p.Server.Call(subscriber) - subscriber.Publisher.WaitTrack() return } @@ -487,34 +421,6 @@ func (p *Plugin) registerHandler(handlers map[string]http.HandlerFunc) { } } -func (p *Plugin) Push(streamPath string, url string, options ...any) (pusher *Pusher, err error) { - pusher = &Pusher{Push: p.config.Push} - pusher.Client.PubSubBase = &pusher.PubSubBase - pusher.Client.Proxy = p.config.Push.Proxy - pusher.Client.RemoteURL = url - pusher.Subscribe = p.config.Subscribe - pusher.StreamPath = streamPath - var pushHandler PushHandler - for _, option := range options { - switch v := option.(type) { - case PushHandler: - pushHandler = v - } - } - pusher.Init(p, streamPath, &pusher.Subscribe, options...) - if _, err = p.Server.Call(pusher); err != nil { - return - } - pusher.Publisher.WaitTrack() - if v, ok := p.handler.(IPusherPlugin); pushHandler == nil && ok { - pushHandler = v.NewPushHandler() - } - if pushHandler != nil { - err = pusher.Start(pushHandler) - } - return -} - func (p *Plugin) logHandler(handler http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { p.Debug("visit", "path", r.URL.String(), "remote", r.RemoteAddr) @@ -546,28 +452,21 @@ func (p *Plugin) AddLogHandler(handler slog.Handler) { p.Server.LogHandler.Add(handler) } -func (p *Plugin) PostToServer(event any) { - if p.Server.eventChan == nil { - panic("eventChan is nil") - } - p.Server.PostMessage(event) -} - func (p *Plugin) SaveConfig() (err error) { - _, err = p.Server.Call(func() error { + p.Server.pluginTM.Call(func() { if p.Modify == nil { os.Remove(p.settingPath()) - return nil + return } - file, err := os.OpenFile(p.settingPath(), os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0666) - if err == nil { - defer file.Close() - err = yaml.NewEncoder(file).Encode(p.Modify) + var file *os.File + if file, err = os.OpenFile(p.settingPath(), os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0666); err != nil { + return } - if err == nil { - p.Info("config saved") - } - return err + defer file.Close() + err = yaml.NewEncoder(file).Encode(p.Modify) }) + if err == nil { + p.Info("config saved") + } return } diff --git a/plugin/flv/index.go b/plugin/flv/index.go index 7a311af..269b297 100644 --- a/plugin/flv/index.go +++ b/plugin/flv/index.go @@ -27,10 +27,14 @@ func (p *FLVPlugin) OnInit() error { return nil } -var _ = m7s.InstallPlugin[FLVPlugin](defaultConfig, NewPullHandler) +var _ = m7s.InstallPlugin[FLVPlugin](defaultConfig) -func (p *FLVPlugin) NewRecordHandler() m7s.RecordHandler { - return &Recorder{} +func (p *FLVPlugin) DoPull(pull *m7s.PullContext) error { + return PullFLV(pull) +} + +func (p *FLVPlugin) DoRecord(ctx *m7s.RecordContext) error { + return RecordFlv(ctx) } func (p *FLVPlugin) WriteFlvHeader(sub *m7s.Subscriber) (flv net.Buffers) { @@ -95,7 +99,6 @@ func (p *FLVPlugin) ServeHTTP(w http.ResponseWriter, r *http.Request) { if hijacker, ok := w.(http.Hijacker); ok && wto > 0 { conn, _, _ := hijacker.Hijack() conn.SetWriteDeadline(time.Now().Add(wto)) - sub.Closer = conn gotFlvTag = func(flv net.Buffers) (err error) { conn.SetWriteDeadline(time.Now().Add(wto)) _, err = flv.WriteTo(conn) diff --git a/plugin/flv/pkg/flv.go b/plugin/flv/pkg/flv.go index f3c026a..5a8974f 100644 --- a/plugin/flv/pkg/flv.go +++ b/plugin/flv/pkg/flv.go @@ -1,8 +1,10 @@ package flv import ( + "bufio" "io" "m7s.live/m7s/v5/pkg/util" + rtmp "m7s.live/m7s/v5/plugin/rtmp/pkg" "net" ) @@ -13,6 +15,8 @@ const ( FLV_TAG_TYPE_SCRIPT = 0x12 ) +var FLVHead = []byte{'F', 'L', 'V', 0x01, 0x05, 0, 0, 0, 9, 0, 0, 0, 0} + func AVCC2FLV(t byte, ts uint32, avcc ...[]byte) (flv net.Buffers) { b := util.Buffer(make([]byte, 0, 15)) b.WriteByte(t) @@ -30,8 +34,33 @@ func WriteFLVTagHead(t uint8, ts, dataSize uint32, b []byte) { b[4], b[5], b[6], b[7] = byte(ts>>16), byte(ts>>8), byte(ts), byte(ts>>24) } -func WriteFLVTag(w io.Writer, t byte, timestamp uint32, payload []byte) (err error) { - buffers := AVCC2FLV(t, timestamp, payload) +func WriteFLVTag(w io.Writer, t byte, timestamp uint32, payload ...[]byte) (err error) { + buffers := AVCC2FLV(t, timestamp, payload...) _, err = buffers.WriteTo(w) return } + +func ReadMetaData(reader io.Reader) (metaData rtmp.EcmaArray, err error) { + r := bufio.NewReader(reader) + _, err = r.Discard(13) + tagHead := make(util.Buffer, 11) + _, err = io.ReadFull(r, tagHead) + if err != nil { + return + } + tmp := tagHead + t := tmp.ReadByte() + dataLen := tmp.ReadUint24() + _, err = r.Discard(4) + if t == FLV_TAG_TYPE_SCRIPT { + data := make([]byte, dataLen+4) + _, err = io.ReadFull(reader, data) + amf := &rtmp.AMF{ + Buffer: util.Buffer(data[1+2+len("onMetaData") : len(data)-4]), + } + var obj any + obj, err = amf.Unmarshal() + metaData = obj.(rtmp.EcmaArray) + } + return +} diff --git a/plugin/flv/pkg/pull.go b/plugin/flv/pkg/pull.go index 0524bae..b9f26d6 100644 --- a/plugin/flv/pkg/pull.go +++ b/plugin/flv/pkg/pull.go @@ -13,29 +13,14 @@ import ( rtmp "m7s.live/m7s/v5/plugin/rtmp/pkg" ) -type FLVPuller struct { - *util.BufReader - *util.ScalableMemoryAllocator - hasAudio bool - hasVideo bool - absTS uint32 //绝对时间戳 -} - -func NewFLVPuller() *FLVPuller { - return &FLVPuller{ - ScalableMemoryAllocator: util.NewScalableMemoryAllocator(1 << 10), - } -} - -func NewPullHandler() m7s.PullHandler { - return NewFLVPuller() -} - -func (puller *FLVPuller) Connect(p *m7s.Client) (err error) { +func PullFLV(p *m7s.PullContext) (err error) { + var reader *util.BufReader + var hasAudio, hasVideo bool + var absTS uint32 if strings.HasPrefix(p.RemoteURL, "http") { var res *http.Response client := http.DefaultClient - if proxyConf := p.Proxy; proxyConf != "" { + if proxyConf := p.ConnectProxy; proxyConf != "" { proxy, err := url.Parse(proxyConf) if err != nil { return err @@ -47,19 +32,19 @@ func (puller *FLVPuller) Connect(p *m7s.Client) (err error) { if res.StatusCode != http.StatusOK { return io.EOF } - p.Closer = res.Body - puller.BufReader = util.NewBufReader(res.Body) + defer res.Body.Close() + reader = util.NewBufReader(res.Body) } } else { var res *os.File if res, err = os.Open(p.RemoteURL); err == nil { - p.Closer = res - puller.BufReader = util.NewBufReader(res) + defer res.Close() + reader = util.NewBufReader(res) } } if err == nil { var head util.Memory - head, err = puller.BufReader.ReadBytes(13) + head, err = reader.ReadBytes(13) if err == nil { var flvHead [3]byte var version, flag byte @@ -68,37 +53,35 @@ func (puller *FLVPuller) Connect(p *m7s.Client) (err error) { if flvHead != [...]byte{'F', 'L', 'V'} { err = errors.New("not flv file") } else { - puller.hasAudio = flag&0x04 != 0 - puller.hasVideo = flag&0x01 != 0 + hasAudio = flag&0x04 != 0 + hasVideo = flag&0x01 != 0 } } } - return -} -func (puller *FLVPuller) Pull(p *m7s.Puller) (err error) { var startTs uint32 - pubConf := p.GetPublishConfig() - if !puller.hasAudio { + pubConf := p.Publisher.GetPublishConfig() + if !hasAudio { pubConf.PubAudio = false } - if !puller.hasVideo { + if !hasVideo { pubConf.PubVideo = false } - for offsetTs := puller.absTS; err == nil; _, err = puller.ReadBE(4) { - t, err := puller.ReadByte() + allocator := util.NewScalableMemoryAllocator(1 << 10) + for offsetTs := absTS; err == nil; _, err = reader.ReadBE(4) { + t, err := reader.ReadByte() if err != nil { return err } - dataSize, err := puller.ReadBE32(3) + dataSize, err := reader.ReadBE32(3) if err != nil { return err } - timestamp, err := puller.ReadBE32(3) + timestamp, err := reader.ReadBE32(3) if err != nil { return err } - h, err := puller.ReadByte() + h, err := reader.ReadByte() if err != nil { return err } @@ -106,28 +89,28 @@ func (puller *FLVPuller) Pull(p *m7s.Puller) (err error) { if startTs == 0 { startTs = timestamp } - if _, err = puller.ReadBE(3); err != nil { // stream id always 0 + if _, err = reader.ReadBE(3); err != nil { // stream id always 0 return err } var frame rtmp.RTMPData switch ds := int(dataSize); t { case FLV_TAG_TYPE_AUDIO, FLV_TAG_TYPE_VIDEO: - frame.SetAllocator(puller.ScalableMemoryAllocator) - err = puller.ReadNto(ds, frame.NextN(ds)) + frame.SetAllocator(allocator) + err = reader.ReadNto(ds, frame.NextN(ds)) default: - err = puller.Skip(ds) + err = reader.Skip(ds) } if err != nil { return err } - puller.absTS = offsetTs + (timestamp - startTs) - frame.Timestamp = puller.absTS + absTS = offsetTs + (timestamp - startTs) + frame.Timestamp = absTS //fmt.Println(t, offsetTs, timestamp, startTs, puller.absTS) switch t { case FLV_TAG_TYPE_AUDIO: - p.WriteAudio(frame.WrapAudio()) + err = p.Publisher.WriteAudio(frame.WrapAudio()) case FLV_TAG_TYPE_VIDEO: - p.WriteVideo(frame.WrapVideo()) + err = p.Publisher.WriteVideo(frame.WrapVideo()) case FLV_TAG_TYPE_SCRIPT: p.Info("script") } diff --git a/plugin/flv/pkg/record.go b/plugin/flv/pkg/record.go index 29c3b6d..9de9c55 100644 --- a/plugin/flv/pkg/record.go +++ b/plugin/flv/pkg/record.go @@ -7,111 +7,179 @@ import ( "m7s.live/m7s/v5/pkg/util" rtmp "m7s.live/m7s/v5/plugin/rtmp/pkg" "os" + "slices" + "time" ) -type Recorder struct { - *m7s.Subscriber - filepositions []uint64 - times []float64 - Offset int64 - duration int64 -} - -func (r *Recorder) Record(recorder *m7s.Recorder) (err error) { - return -} - -func (r *Recorder) Close() { - -} - -func (r *Recorder) writeMetaData(file util.ReadWriteSeekCloser, duration int64) { - defer file.Close() - at, vt := r.AudioReader, r.VideoReader - hasAudio, hasVideo := at != nil, vt != nil - var amf rtmp.AMF - metaData := rtmp.EcmaArray{ - "MetaDataCreator": "m7s/" + m7s.Version, - "hasVideo": hasVideo, - "hasAudio": hasAudio, - "hasMatadata": true, - "canSeekToEnd": true, - "duration": float64(duration) / 1000, - "hasKeyFrames": len(r.filepositions) > 0, - "filesize": 0, - } - var flags byte - if hasAudio { - ctx := at.Track.ICodecCtx.GetBase().(pkg.IAudioCodecCtx) - flags |= (1 << 2) - metaData["audiocodecid"] = int(rtmp.ParseAudioCodec(ctx.FourCC())) - metaData["audiosamplerate"] = ctx.GetSampleRate() - metaData["audiosamplesize"] = ctx.GetSampleSize() - metaData["stereo"] = ctx.GetChannels() == 2 - } - if hasVideo { - ctx := vt.Track.ICodecCtx.GetBase().(pkg.IVideoCodecCtx) - flags |= 1 - metaData["videocodecid"] = int(rtmp.ParseVideoCodec(ctx.FourCC())) - metaData["width"] = ctx.Width() - metaData["height"] = ctx.Height() - metaData["framerate"] = vt.Track.FPS - metaData["videodatarate"] = vt.Track.BPS - metaData["keyframes"] = map[string]any{ - "filepositions": r.filepositions, - "times": r.times, - } - defer func() { - r.filepositions = []uint64{0} - r.times = []float64{0} - }() - } - amf.Marshals("onMetaData", metaData) - offset := amf.Len() + 13 + 15 - if keyframesCount := len(r.filepositions); keyframesCount > 0 { - metaData["filesize"] = uint64(offset) + r.filepositions[keyframesCount-1] - for i := range r.filepositions { - r.filepositions[i] += uint64(offset) - } - metaData["keyframes"] = map[string]any{ - "filepositions": r.filepositions, - "times": r.times, - } - } - - if tempFile, err := os.CreateTemp("", "*.flv"); err != nil { - r.Error("create temp file failed", "err", err) +func RecordFlv(ctx *m7s.RecordContext) (err error) { + var file *os.File + var filepositions []uint64 + var times []float64 + var offset int64 + var duration int64 + if file, err = os.OpenFile(ctx.FilePath, os.O_CREATE|os.O_RDWR|util.Conditoinal(ctx.Append, os.O_APPEND, os.O_TRUNC), 0666); err != nil { return - } else { + } + suber := ctx.Subscriber + ar, vr := suber.AudioReader, suber.VideoReader + hasAudio, hasVideo := ar != nil, vr != nil + writeMetaTag := func() { defer func() { - tempFile.Close() - os.Remove(tempFile.Name()) - r.Info("writeMetaData success") + err = file.Close() + if info, err := file.Stat(); err == nil && info.Size() == 0 { + os.Remove(file.Name()) + } }() - _, err := tempFile.Write([]byte{'F', 'L', 'V', 0x01, flags, 0, 0, 0, 9, 0, 0, 0, 0}) - if err != nil { - r.Error(err.Error()) - return + var amf rtmp.AMF + metaData := rtmp.EcmaArray{ + "MetaDataCreator": "m7s/" + m7s.Version, + "hasVideo": hasVideo, + "hasAudio": hasAudio, + "hasMatadata": true, + "canSeekToEnd": true, + "duration": float64(duration) / 1000, + "hasKeyFrames": len(filepositions) > 0, + "filesize": 0, } - amf.Reset() - marshals := amf.Marshals("onMetaData", metaData) - WriteFLVTag(tempFile, FLV_TAG_TYPE_SCRIPT, 0, marshals) - _, err = file.Seek(13, io.SeekStart) - if err != nil { - r.Error("writeMetaData Seek failed", "err", err) - return + var flags byte + if hasAudio { + ctx := ar.Track.ICodecCtx.GetBase().(pkg.IAudioCodecCtx) + flags |= (1 << 2) + metaData["audiocodecid"] = int(rtmp.ParseAudioCodec(ctx.FourCC())) + metaData["audiosamplerate"] = ctx.GetSampleRate() + metaData["audiosamplesize"] = ctx.GetSampleSize() + metaData["stereo"] = ctx.GetChannels() == 2 } - _, err = io.Copy(tempFile, file) - if err != nil { - r.Error("writeMetaData Copy failed", "err", err) - return + if hasVideo { + ctx := vr.Track.ICodecCtx.GetBase().(pkg.IVideoCodecCtx) + flags |= 1 + metaData["videocodecid"] = int(rtmp.ParseVideoCodec(ctx.FourCC())) + metaData["width"] = ctx.Width() + metaData["height"] = ctx.Height() + metaData["framerate"] = vr.Track.FPS + metaData["videodatarate"] = vr.Track.BPS + metaData["keyframes"] = map[string]any{ + "filepositions": filepositions, + "times": times, + } + defer func() { + filepositions = []uint64{0} + times = []float64{0} + }() } - _, err = tempFile.Seek(0, io.SeekStart) - _, err = file.Seek(0, io.SeekStart) - _, err = io.Copy(file, tempFile) - if err != nil { - r.Error("writeMetaData Copy failed", "err", err) + amf.Marshals("onMetaData", metaData) + offset := amf.Len() + 13 + 15 + if keyframesCount := len(filepositions); keyframesCount > 0 { + metaData["filesize"] = uint64(offset) + filepositions[keyframesCount-1] + for i := range filepositions { + filepositions[i] += uint64(offset) + } + metaData["keyframes"] = map[string]any{ + "filepositions": filepositions, + "times": times, + } + } + + if tempFile, err := os.CreateTemp("", "*.flv"); err != nil { + ctx.Error("create temp file failed", "err", err) return + } else { + defer func() { + tempFile.Close() + os.Remove(tempFile.Name()) + ctx.Info("writeMetaData success") + }() + _, err := tempFile.Write([]byte{'F', 'L', 'V', 0x01, flags, 0, 0, 0, 9, 0, 0, 0, 0}) + if err != nil { + ctx.Error(err.Error()) + return + } + amf.Reset() + marshals := amf.Marshals("onMetaData", metaData) + WriteFLVTag(tempFile, FLV_TAG_TYPE_SCRIPT, 0, marshals) + _, err = file.Seek(13, io.SeekStart) + if err != nil { + ctx.Error("writeMetaData Seek failed", "err", err) + return + } + _, err = io.Copy(tempFile, file) + if err != nil { + ctx.Error("writeMetaData Copy failed", "err", err) + return + } + _, err = tempFile.Seek(0, io.SeekStart) + _, err = file.Seek(0, io.SeekStart) + _, err = io.Copy(file, tempFile) + if err != nil { + ctx.Error("writeMetaData Copy failed", "err", err) + return + } } } + if ctx.Append { + var metaData rtmp.EcmaArray + metaData, err = ReadMetaData(file) + keyframes := metaData["keyframes"].(map[string]any) + filepositions = slices.Collect(func(yield func(uint64) bool) { + for _, v := range keyframes["filepositions"].([]float64) { + yield(uint64(v)) + } + }) + times = keyframes["times"].([]float64) + if _, err = file.Seek(-4, io.SeekEnd); err != nil { + ctx.Error("seek file failed", "err", err) + file.Write(FLVHead) + } else { + tmp := make(util.Buffer, 4) + tmp2 := tmp + file.Read(tmp) + tagSize := tmp.ReadUint32() + tmp = tmp2 + file.Seek(int64(tagSize), io.SeekEnd) + file.Read(tmp2) + ts := tmp2.ReadUint24() | (uint32(tmp[3]) << 24) + ctx.Info("append flv", "last tagSize", tagSize, "last ts", ts) + if hasVideo { + vr.StartTs = time.Duration(ts) * time.Millisecond + } + if hasAudio { + ar.StartTs = time.Duration(ts) * time.Millisecond + } + file.Seek(0, io.SeekEnd) + } + } else { + file.Write(FLVHead) + } + if ctx.Fragment == 0 { + defer writeMetaTag() + } + checkFragment := func(absTime uint32) { + if ctx.Fragment == 0 { + return + } + if duration = int64(absTime); time.Duration(duration)*time.Millisecond >= ctx.Fragment { + writeMetaTag() + offset = 0 + if file, err = os.OpenFile(ctx.FilePath, os.O_CREATE|os.O_RDWR, 0666); err != nil { + return + } + file.Write(FLVHead) + if vr != nil { + vr.ResetAbsTime() + err = WriteFLVTag(file, FLV_TAG_TYPE_VIDEO, 0, vr.Track.SequenceFrame.(*rtmp.RTMPVideo).Buffers...) + } + } + } + return m7s.PlayBlock(ctx.Subscriber, func(audio *rtmp.RTMPAudio) (err error) { + if !hasVideo { + checkFragment(ar.AbsTime) + } + return WriteFLVTag(file, FLV_TAG_TYPE_AUDIO, vr.AbsTime, audio.Buffers...) + }, func(video *rtmp.RTMPVideo) (err error) { + if vr.Value.IDR { + filepositions = append(filepositions, uint64(offset)) + times = append(times, float64(vr.AbsTime)/1000) + } + return WriteFLVTag(file, FLV_TAG_TYPE_VIDEO, vr.AbsTime, video.Buffers...) + }) } diff --git a/plugin/gb28181/channel.go b/plugin/gb28181/channel.go index ec80701..d990c4e 100644 --- a/plugin/gb28181/channel.go +++ b/plugin/gb28181/channel.go @@ -10,7 +10,8 @@ import ( type RecordRequest struct { SN, SumNum int - *util.Promise[[]gb28181.Record] + Response []gb28181.Record + *util.Promise } func (r *RecordRequest) GetKey() int { diff --git a/plugin/gb28181/device.go b/plugin/gb28181/device.go index 515a001..7114c67 100644 --- a/plugin/gb28181/device.go +++ b/plugin/gb28181/device.go @@ -3,6 +3,7 @@ package plugin_gb28181 import ( "github.com/emiago/sipgo" "github.com/emiago/sipgo/sip" + "log/slog" "m7s.live/m7s/v5" "m7s.live/m7s/v5/pkg" "m7s.live/m7s/v5/pkg/util" @@ -23,7 +24,8 @@ const ( ) type Device struct { - pkg.Unit[string] + pkg.Task + ID string Name string Manufacturer string Model string @@ -43,6 +45,7 @@ type Device struct { dialogClient *sipgo.DialogClient contactHDR sip.ContactHeader fromHDR sip.FromHeader + *slog.Logger } func (d *Device) GetKey() string { @@ -63,7 +66,8 @@ func (d *Device) onMessage(req *sip.Request, tx sip.ServerTransaction, msg *gb28 case "RecordInfo": if channel, ok := d.channels.Get(msg.DeviceID); ok { if req, ok := channel.RecordReqs.Get(msg.SN); ok { - req.Resolve(msg.RecordList) + req.Response = msg.RecordList + req.Resolve() } } case "DeviceInfo": diff --git a/plugin/gb28181/dialog.go b/plugin/gb28181/dialog.go index c6d2d4a..a123731 100644 --- a/plugin/gb28181/dialog.go +++ b/plugin/gb28181/dialog.go @@ -24,7 +24,8 @@ func (d *Dialog) GetCallID() string { return d.session.InviteRequest.CallID().Value() } -func (d *Dialog) Connect(p *m7s.Client) (err error) { +func (d *Dialog) Pull(p *m7s.PullContext) (err error) { + sss := strings.Split(p.RemoteURL, "/") deviceId, channelId := sss[0], sss[1] if len(sss) == 2 { @@ -41,11 +42,8 @@ func (d *Dialog) Connect(p *m7s.Client) (err error) { var recordRange util.Range[int] err = recordRange.Resolve(sss[2]) } - return -} -func (d *Dialog) Pull(p *m7s.Puller) (err error) { - d.Receiver = gb28181.NewReceiver(&p.Publisher) + d.Receiver = gb28181.NewReceiver(p.Publisher) ssrc := d.CreateSSRC(d.gb.Serial) d.gb.dialogs.Set(d) defer d.gb.dialogs.Remove(d) diff --git a/plugin/gb28181/index.go b/plugin/gb28181/index.go index 5fbb7a7..40b494c 100644 --- a/plugin/gb28181/index.go +++ b/plugin/gb28181/index.go @@ -1,7 +1,6 @@ package plugin_gb28181 import ( - "context" "errors" "fmt" "github.com/emiago/sipgo" @@ -11,7 +10,6 @@ import ( "github.com/rs/zerolog" "github.com/rs/zerolog/log" "m7s.live/m7s/v5" - "m7s.live/m7s/v5/pkg" "m7s.live/m7s/v5/pkg/config" "m7s.live/m7s/v5/pkg/util" "m7s.live/m7s/v5/plugin/gb28181/pb" @@ -243,11 +241,7 @@ func (gb *GB28181Plugin) StoreDevice(id string, req *sip.Request) (d *Device) { port, _ := strconv.Atoi(portStr) serverPort, _ := strconv.Atoi(sPortStr) d = &Device{ - Unit: pkg.Unit[string]{ - ID: id, - StartTime: time.Now(), - Logger: gb.Logger.With("id", id), - }, + ID: id, UpdateTime: time.Now(), Status: DeviceRegisterStatus, Recipient: sip.Uri{ @@ -274,7 +268,7 @@ func (gb *GB28181Plugin) StoreDevice(id string, req *sip.Request) (d *Device) { Params: sip.NewParams(), }, } - d.Context, d.CancelCauseFunc = context.WithCancelCause(gb.Context) + d.Init(gb.Context, gb.Logger.With("id", id)) d.fromHDR.Params.Add("tag", sip.GenerateTagN(16)) d.client, _ = sipgo.NewClient(gb.ua, sipgo.WithClientLogger(zerolog.New(os.Stdout)), sipgo.WithClientHostname(publicIP)) d.dialogClient = sipgo.NewDialogClient(d.client, d.contactHDR) @@ -288,10 +282,11 @@ func (gb *GB28181Plugin) StoreDevice(id string, req *sip.Request) (d *Device) { return } -func (gb *GB28181Plugin) NewPullHandler() m7s.PullHandler { - return &Dialog{ +func (gb *GB28181Plugin) DoPull(ctx *m7s.PullContext) error { + dialog := Dialog{ gb: gb, } + return dialog.Pull(ctx) } func (gb *GB28181Plugin) GetPullableList() []string { diff --git a/plugin/logrotate/api.go b/plugin/logrotate/api.go index 9d4bd19..c079d33 100644 --- a/plugin/logrotate/api.go +++ b/plugin/logrotate/api.go @@ -54,6 +54,6 @@ func (h *LogRotatePlugin) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (l *LogRotatePlugin) API_tail(w http.ResponseWriter, r *http.Request) { writer := util.NewSSE(w, r.Context()) h := console.NewHandler(writer, &console.HandlerOptions{NoColor: true}) - l.PostToServer(h) + l.Server.AddLogHandler(h) <-r.Context().Done() } diff --git a/plugin/mp4/index.go b/plugin/mp4/index.go index 0545dfc..dd01650 100644 --- a/plugin/mp4/index.go +++ b/plugin/mp4/index.go @@ -83,16 +83,16 @@ func (p *MP4Plugin) OnInit() error { var _ = m7s.InstallPlugin[MP4Plugin](defaultConfig) -func (p *MP4Plugin) NewPullHandler() m7s.PullHandler { - return pkg.NewMP4Puller() +func (p *MP4Plugin) DoPull(ctx *m7s.PullContext) error { + return pkg.PullMP4(ctx) } func (p *MP4Plugin) GetPullableList() []string { return slices.Collect(maps.Keys(p.GetCommonConf().PullOnSub)) } -func (p *MP4Plugin) NewRecordHandler() m7s.RecordHandler { - return &pkg.Recorder{} +func (p *MP4Plugin) DoRecord(ctx *m7s.RecordContext) error { + return pkg.RecordMP4(ctx) } func (p *MP4Plugin) ServeHTTP(w http.ResponseWriter, r *http.Request) { @@ -179,10 +179,8 @@ func (p *MP4Plugin) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } if hijacker, ok := w.(http.Hijacker); ok && ctx.wto > 0 { - sub.Conn, _, _ = hijacker.Hijack() - sub.Closer = sub.Conn - ctx.Writer = sub.Conn - ctx.conn = sub.Conn + ctx.conn, _, _ = hijacker.Hijack() + ctx.Writer = ctx.conn } else { ctx.Writer = w w.(http.Flusher).Flush() diff --git a/plugin/mp4/pkg/pull.go b/plugin/mp4/pkg/pull.go index 315154f..9de2ee4 100644 --- a/plugin/mp4/pkg/pull.go +++ b/plugin/mp4/pkg/pull.go @@ -14,22 +14,12 @@ import ( "strings" ) -type MP4Puller struct { - *util.ScalableMemoryAllocator - *box.MovDemuxer -} - -func NewMP4Puller() *MP4Puller { - return &MP4Puller{ - ScalableMemoryAllocator: util.NewScalableMemoryAllocator(1 << 10), - } -} - -func (puller *MP4Puller) Connect(p *m7s.Client) (err error) { - if strings.HasPrefix(p.RemoteURL, "http") { +func PullMP4(ctx *m7s.PullContext) (err error) { + var demuxer *box.MovDemuxer + if strings.HasPrefix(ctx.RemoteURL, "http") { var res *http.Response client := http.DefaultClient - if proxyConf := p.Proxy; proxyConf != "" { + if proxyConf := ctx.ConnectProxy; proxyConf != "" { proxy, err := url.Parse(proxyConf) if err != nil { return err @@ -37,68 +27,66 @@ func (puller *MP4Puller) Connect(p *m7s.Client) (err error) { transport := &http.Transport{Proxy: http.ProxyURL(proxy)} client = &http.Client{Transport: transport} } - if res, err = client.Get(p.RemoteURL); err == nil { + if res, err = client.Get(ctx.RemoteURL); err == nil { if res.StatusCode != http.StatusOK { return io.EOF } - p.Closer = res.Body - + defer res.Body.Close() content, err := io.ReadAll(res.Body) if err != nil { return err } - puller.MovDemuxer = box.CreateMp4Demuxer(strings.NewReader(string(content))) + demuxer = box.CreateMp4Demuxer(strings.NewReader(string(content))) } } else { var res *os.File - if res, err = os.Open(p.RemoteURL); err == nil { - p.Closer = res + if res, err = os.Open(ctx.RemoteURL); err == nil { + defer res.Close() } - puller.MovDemuxer = box.CreateMp4Demuxer(res) + demuxer = box.CreateMp4Demuxer(res) } - return -} -func (puller *MP4Puller) Pull(p *m7s.Puller) (err error) { var tracks []box.TrackInfo - if tracks, err = puller.ReadHead(); err != nil { + if tracks, err = demuxer.ReadHead(); err != nil { return } + publisher := ctx.Publisher for _, track := range tracks { switch track.Cid { case box.MP4_CODEC_H264: var sequece rtmp.RTMPVideo sequece.Append([]byte{0x17, 0x00, 0x00, 0x00, 0x00}, track.ExtraData) - p.WriteVideo(&sequece) + err = publisher.WriteVideo(&sequece) case box.MP4_CODEC_H265: var sequece rtmp.RTMPVideo sequece.Append([]byte{0b1001_0000 | rtmp.PacketTypeSequenceStart}, codec.FourCC_H265[:], track.ExtraData) - p.WriteVideo(&sequece) + err = publisher.WriteVideo(&sequece) case box.MP4_CODEC_AAC: var sequence rtmp.RTMPAudio sequence.Append([]byte{0xaf, 0x00}, track.ExtraData) - p.WriteAudio(&sequence) + err = publisher.WriteAudio(&sequence) } } + allocator := util.NewScalableMemoryAllocator(1 << 10) for { - pkg, err := puller.ReadPacket(puller.ScalableMemoryAllocator) + pkg, err := demuxer.ReadPacket(allocator) if err != nil { - p.Error("Error reading MP4 packet", "err", err) + ctx.Error("Error reading MP4 packet", "err", err) return err } switch track := tracks[pkg.TrackId-1]; track.Cid { case box.MP4_CODEC_H264: var videoFrame rtmp.RTMPVideo - videoFrame.SetAllocator(puller.ScalableMemoryAllocator) + videoFrame.SetAllocator(allocator) videoFrame.CTS = uint32(pkg.Pts - pkg.Dts) videoFrame.Timestamp = uint32(pkg.Dts) keyFrame := codec.H264NALUType(pkg.Data[5]&0x1F) == codec.NALU_IDR_Picture videoFrame.AppendOne([]byte{util.Conditoinal[byte](keyFrame, 0x17, 0x27), 0x01, byte(videoFrame.CTS >> 24), byte(videoFrame.CTS >> 8), byte(videoFrame.CTS)}) videoFrame.AddRecycleBytes(pkg.Data) - p.WriteVideo(&videoFrame) + err = publisher.WriteVideo(&videoFrame) case box.MP4_CODEC_H265: var videoFrame rtmp.RTMPVideo - videoFrame.SetAllocator(puller.ScalableMemoryAllocator) + videoFrame.SetAllocator(allocator) videoFrame.CTS = uint32(pkg.Pts - pkg.Dts) videoFrame.Timestamp = uint32(pkg.Dts) var head []byte @@ -122,14 +110,14 @@ func (puller *MP4Puller) Pull(p *m7s.Puller) (err error) { } copy(head[1:], codec.FourCC_H265[:]) videoFrame.AddRecycleBytes(pkg.Data) - p.WriteVideo(&videoFrame) + err = publisher.WriteVideo(&videoFrame) case box.MP4_CODEC_AAC: var audioFrame rtmp.RTMPAudio - audioFrame.SetAllocator(puller.ScalableMemoryAllocator) + audioFrame.SetAllocator(allocator) audioFrame.Timestamp = uint32(pkg.Dts) audioFrame.AppendOne([]byte{0xaf, 0x01}) audioFrame.AddRecycleBytes(pkg.Data) - p.WriteAudio(&audioFrame) + err = publisher.WriteAudio(&audioFrame) } } } diff --git a/plugin/mp4/pkg/record.go b/plugin/mp4/pkg/record.go index 800c60a..e98d256 100644 --- a/plugin/mp4/pkg/record.go +++ b/plugin/mp4/pkg/record.go @@ -5,62 +5,56 @@ import ( "m7s.live/m7s/v5/pkg" "m7s.live/m7s/v5/pkg/codec" "m7s.live/m7s/v5/plugin/mp4/pkg/box" + "os" "time" ) -type Recorder struct { - *m7s.Subscriber - *box.Movmuxer - videoId uint32 - audioId uint32 -} - -func (r *Recorder) Record(recorder *m7s.Recorder) (err error) { - r.Movmuxer, err = box.CreateMp4Muxer(recorder.File) - if recorder.Publisher.HasAudioTrack() { - audioTrack := recorder.Publisher.AudioTrack +func RecordMP4(ctx *m7s.RecordContext) (err error) { + var file *os.File + var muxer *box.Movmuxer + var audioId, videoId uint32 + // TODO: fragment + if file, err = os.OpenFile(ctx.FilePath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0666); err != nil { + return + } + defer func() { + err = muxer.WriteTrailer() + if err != nil { + ctx.Error("write trailer", "err", err) + } else { + ctx.Info("write trailer") + } + err = file.Close() + }() + muxer, err = box.CreateMp4Muxer(file) + ar, vr := ctx.Subscriber.AudioReader, ctx.Subscriber.VideoReader + if ar != nil { + audioTrack := ar.Track switch ctx := audioTrack.ICodecCtx.GetBase().(type) { case *codec.AACCtx: - r.audioId = r.AddAudioTrack(box.MP4_CODEC_AAC, box.WithExtraData(ctx.ConfigBytes)) + audioId = muxer.AddAudioTrack(box.MP4_CODEC_AAC, box.WithExtraData(ctx.ConfigBytes)) case *codec.PCMACtx: - r.audioId = r.AddAudioTrack(box.MP4_CODEC_G711A, box.WithAudioSampleRate(uint32(ctx.SampleRate)), box.WithAudioChannelCount(uint8(ctx.Channels)), box.WithAudioSampleBits(uint8(ctx.SampleSize))) + audioId = muxer.AddAudioTrack(box.MP4_CODEC_G711A, box.WithAudioSampleRate(uint32(ctx.SampleRate)), box.WithAudioChannelCount(uint8(ctx.Channels)), box.WithAudioSampleBits(uint8(ctx.SampleSize))) case *codec.PCMUCtx: - r.audioId = r.AddAudioTrack(box.MP4_CODEC_G711U, box.WithAudioSampleRate(uint32(ctx.SampleRate)), box.WithAudioChannelCount(uint8(ctx.Channels)), box.WithAudioSampleBits(uint8(ctx.SampleSize))) + audioId = muxer.AddAudioTrack(box.MP4_CODEC_G711U, box.WithAudioSampleRate(uint32(ctx.SampleRate)), box.WithAudioChannelCount(uint8(ctx.Channels)), box.WithAudioSampleBits(uint8(ctx.SampleSize))) } } - if recorder.Publisher.HasVideoTrack() { - videoTrack := recorder.Publisher.VideoTrack + if vr != nil { + videoTrack := vr.Track switch ctx := videoTrack.ICodecCtx.GetBase().(type) { case *codec.H264Ctx: - r.videoId = r.AddVideoTrack(box.MP4_CODEC_H264, box.WithExtraData(ctx.Record)) + videoId = muxer.AddVideoTrack(box.MP4_CODEC_H264, box.WithExtraData(ctx.Record)) case *codec.H265Ctx: - r.videoId = r.AddVideoTrack(box.MP4_CODEC_H265, box.WithExtraData(ctx.Record)) + videoId = muxer.AddVideoTrack(box.MP4_CODEC_H265, box.WithExtraData(ctx.Record)) } } - r.Subscriber = &recorder.Subscriber - return m7s.PlayBlock(&recorder.Subscriber, func(audio *pkg.RawAudio) error { - return r.WriteAudio(r.audioId, audio.ToBytes(), uint64(audio.Timestamp/time.Millisecond)) + return m7s.PlayBlock(ctx.Subscriber, func(audio *pkg.RawAudio) error { + return muxer.WriteAudio(audioId, audio.ToBytes(), uint64(audio.Timestamp/time.Millisecond)) }, func(video *pkg.H26xFrame) error { var nalus [][]byte for _, nalu := range video.Nalus { nalus = append(nalus, nalu.ToBytes()) } - return r.WriteVideo(r.videoId, nalus, uint64(video.Timestamp/time.Millisecond), uint64(video.CTS/time.Millisecond)) + return muxer.WriteVideo(videoId, nalus, uint64(video.Timestamp/time.Millisecond), uint64(video.CTS/time.Millisecond)) }) } - -func (r *Recorder) Close() { - //defer func() { - // if err := recover(); err != nil { - // r.Error("close", "err", err) - // } else { - // r.Info("close") - // } - //}() - err := r.WriteTrailer() - if err != nil { - r.Error("write trailer", "err", err) - } else { - r.Info("write trailer") - } -} diff --git a/plugin/rtmp/api.go b/plugin/rtmp/api.go index 6fc07e7..b61868f 100644 --- a/plugin/rtmp/api.go +++ b/plugin/rtmp/api.go @@ -4,10 +4,13 @@ import ( "context" gpb "m7s.live/m7s/v5/pb" "m7s.live/m7s/v5/plugin/rtmp/pb" - rtmp "m7s.live/m7s/v5/plugin/rtmp/pkg" ) func (r *RTMPPlugin) PushOut(ctx context.Context, req *pb.PushRequest) (res *gpb.SuccessResponse, err error) { - go r.Push(req.StreamPath, req.RemoteURL, &rtmp.Client{}) - return &gpb.SuccessResponse{}, nil + if pushContext, err := r.Push(req.StreamPath, req.RemoteURL); err != nil { + return nil, err + } else { + go pushContext.Run(r.DoPush) + } + return &gpb.SuccessResponse{}, err } diff --git a/plugin/rtmp/index.go b/plugin/rtmp/index.go index 07b60b9..9bcae7c 100644 --- a/plugin/rtmp/index.go +++ b/plugin/rtmp/index.go @@ -14,6 +14,7 @@ import ( type RTMPPlugin struct { pb.UnimplementedRtmpServer + Client m7s.Plugin ChunkSize int `default:"1024"` KeepAlive bool @@ -30,18 +31,10 @@ func (p *RTMPPlugin) OnInit() error { return nil } -func (p *RTMPPlugin) NewPullHandler() m7s.PullHandler { - return &Client{} -} - func (p *RTMPPlugin) GetPullableList() []string { return slices.Collect(maps.Keys(p.GetCommonConf().PullOnSub)) } -func (p *RTMPPlugin) NewPushHandler() m7s.PushHandler { - return &Client{} -} - func (p *RTMPPlugin) OnTCPConnect(conn *net.TCPConn) { logger := p.Logger.With("remote", conn.RemoteAddr().String()) receivers := make(map[uint32]*Receiver) @@ -55,7 +48,7 @@ func (p *RTMPPlugin) OnTCPConnect(conn *net.TCPConn) { } if len(receivers) > 0 { for _, receiver := range receivers { - receiver.Dispose(err) + receiver.Stop(err) } } }() @@ -165,7 +158,7 @@ func (p *RTMPPlugin) OnTCPConnect(conn *net.TCPConn) { StreamID: cmd.StreamId, }, } - receiver.Publisher, err = p.Publish(nc.AppName+"/"+cmd.PublishingName, conn, connectInfo) + receiver.Publisher, err = p.Publish(nc.AppName+"/"+cmd.PublishingName, receiver, connectInfo) if err != nil { delete(receivers, cmd.StreamId) err = receiver.Response(cmd.TransactionId, NetStream_Publish_BadName, Level_Error) @@ -185,7 +178,7 @@ func (p *RTMPPlugin) OnTCPConnect(conn *net.TCPConn) { } var suber *m7s.Subscriber // sender.ID = fmt.Sprintf("%s|%d", conn.RemoteAddr().String(), sender.StreamID) - suber, err = p.Subscribe(streamPath, conn, connectInfo) + suber, err = p.Subscribe(streamPath, &ns, connectInfo) if err != nil { err = ns.Response(cmd.TransactionId, NetStream_Play_Failed, Level_Error) } else { diff --git a/plugin/rtmp/pkg/client.go b/plugin/rtmp/pkg/client.go index 2660e9c..56bfa95 100644 --- a/plugin/rtmp/pkg/client.go +++ b/plugin/rtmp/pkg/client.go @@ -11,29 +11,18 @@ import ( "m7s.live/m7s/v5" ) -type Client struct { - NetStream - ServerInfo map[string]any -} +type Client struct{} -func NewPushHandler() m7s.PushHandler { - return &Client{} -} - -func NewPullHandler() m7s.PullHandler { - return &Client{} -} - -func (client *Client) Connect(p *m7s.Client) (err error) { +func createClient(c *m7s.Connection) (*NetStream, error) { chunkSize := 4096 - addr := p.RemoteURL + addr := c.RemoteURL u, err := url.Parse(addr) if err != nil { - return err + return nil, err } ps := strings.Split(u.Path, "/") if len(ps) < 3 { - return errors.New("illegal rtmp url") + return nil, errors.New("illegal rtmp url") } isRtmps := u.Scheme == "rtmps" if strings.Count(u.Host, ":") == 0 { @@ -52,101 +41,108 @@ func (client *Client) Connect(p *m7s.Client) (err error) { conn, err = net.Dial("tcp", u.Host) } if err != nil { - return err + return nil, err } + ns := &NetStream{} + ns.NetConnection = NewNetConnection(conn, c.Logger) defer func() { if err != nil { - conn.Close() + ns.disconnect() } }() - client.NetConnection = NewNetConnection(conn, p.Logger) - if err = client.ClientHandshake(); err != nil { - return err + if err = ns.ClientHandshake(); err != nil { + return ns, err } - client.AppName = strings.Join(ps[1:len(ps)-1], "/") - err = client.SendMessage(RTMP_MSG_CHUNK_SIZE, Uint32Message(chunkSize)) + ns.AppName = strings.Join(ps[1:len(ps)-1], "/") + err = ns.SendMessage(RTMP_MSG_CHUNK_SIZE, Uint32Message(chunkSize)) if err != nil { - return + return ns, err } - client.WriteChunkSize = chunkSize + ns.WriteChunkSize = chunkSize path := u.Path if len(u.Query()) != 0 { path += "?" + u.RawQuery } - err = client.SendMessage(RTMP_MSG_AMF0_COMMAND, &CallMessage{ + err = ns.SendMessage(RTMP_MSG_AMF0_COMMAND, &CallMessage{ CommandMessage{"connect", 1}, map[string]any{ - "app": client.AppName, + "app": ns.AppName, "flashVer": "monibuca/" + m7s.Version, "swfUrl": addr, - "tcUrl": strings.TrimSuffix(addr, path) + "/" + client.AppName, + "tcUrl": strings.TrimSuffix(addr, path) + "/" + ns.AppName, }, nil, }) + var msg *Chunk for err != nil { - msg, err := client.RecvMessage() + msg, err = ns.RecvMessage() if err != nil { - return err + return ns, err } switch msg.MessageTypeID { case RTMP_MSG_AMF0_COMMAND: cmd := msg.MsgData.(Commander).GetCommand() switch cmd.CommandName { case "_result": - client.ServerInfo = msg.MsgData.(*ResponseMessage).Properties + c.MetaData = msg.MsgData.(*ResponseMessage).Properties response := msg.MsgData.(*ResponseMessage) if response.Infomation["code"] == NetConnection_Connect_Success { } else { - return err + return ns, err } default: fmt.Println(cmd.CommandName) } } } - client.Info("connect", "remoteURL", p.RemoteURL) - return + c.Info("connect", "remoteURL", c.RemoteURL) + return ns, nil } -func (puller *Client) Pull(p *m7s.Puller) (err error) { - p.MetaData = puller.ServerInfo +func (Client) DoPull(p *m7s.PullContext) (err error) { + var connection *NetStream + if connection, err = createClient(&p.Connection); err != nil { + return + } defer func() { - puller.Close() + connection.disconnect() if p := recover(); p != nil { err = p.(error) } - p.Dispose(err) }() - err = puller.SendMessage(RTMP_MSG_AMF0_COMMAND, &CommandMessage{"createStream", 2}) + err = connection.SendMessage(RTMP_MSG_AMF0_COMMAND, &CommandMessage{"createStream", 2}) + var msg *Chunk for err == nil { - msg, err := puller.RecvMessage() - if err != nil { + if err = p.Publisher.Err(); err != nil { + return + } + if msg, err = connection.RecvMessage(); err != nil { return err } switch msg.MessageTypeID { case RTMP_MSG_AUDIO: - p.WriteAudio(msg.AVData.WrapAudio()) + err = p.Publisher.WriteAudio(msg.AVData.WrapAudio()) case RTMP_MSG_VIDEO: - p.WriteVideo(msg.AVData.WrapVideo()) + err = p.Publisher.WriteVideo(msg.AVData.WrapVideo()) case RTMP_MSG_AMF0_COMMAND: cmd := msg.MsgData.(Commander).GetCommand() switch cmd.CommandName { case "_result": if response, ok := msg.MsgData.(*ResponseCreateStreamMessage); ok { - puller.StreamID = response.StreamId + connection.StreamID = response.StreamId m := &PlayMessage{} m.StreamId = response.StreamId m.TransactionId = 4 m.CommandMessage.CommandName = "play" - URL, _ := url.Parse(p.Client.RemoteURL) + URL, _ := url.Parse(p.Connection.RemoteURL) ps := strings.Split(URL.Path, "/") - p.Args = URL.Query() + args := URL.Query() m.StreamName = ps[len(ps)-1] - if len(p.Args) > 0 { - m.StreamName += "?" + p.Args.Encode() + if len(args) > 0 { + m.StreamName += "?" + args.Encode() } - puller.SendMessage(RTMP_MSG_AMF0_COMMAND, m) + connection.SendMessage(RTMP_MSG_AMF0_COMMAND, m) // if response, ok := msg.MsgData.(*ResponsePlayMessage); ok { // if response.Object["code"] == "NetStream.Play.Start" { @@ -163,11 +159,16 @@ func (puller *Client) Pull(p *m7s.Puller) (err error) { return } -func (pusher *Client) Push(p *m7s.Pusher) (err error) { - p.MetaData = pusher.ServerInfo - pusher.SendMessage(RTMP_MSG_AMF0_COMMAND, &CommandMessage{"createStream", 2}) - for { - msg, err := pusher.RecvMessage() +func (Client) DoPush(p *m7s.PushContext) (err error) { + var connection *NetStream + if connection, err = createClient(&p.Connection); err != nil { + return + } + defer connection.disconnect() + err = connection.SendMessage(RTMP_MSG_AMF0_COMMAND, &CommandMessage{"createStream", 2}) + var msg *Chunk + for err == nil { + msg, err = connection.RecvMessage() if err != nil { return err } @@ -177,15 +178,15 @@ func (pusher *Client) Push(p *m7s.Pusher) (err error) { switch cmd.CommandName { case Response_Result, Response_OnStatus: if response, ok := msg.MsgData.(*ResponseCreateStreamMessage); ok { - pusher.StreamID = response.StreamId - URL, _ := url.Parse(p.Client.RemoteURL) + connection.StreamID = response.StreamId + URL, _ := url.Parse(p.Connection.RemoteURL) _, streamPath, _ := strings.Cut(URL.Path, "/") _, streamPath, _ = strings.Cut(streamPath, "/") - p.Args = URL.Query() - if len(p.Args) > 0 { - streamPath += "?" + p.Args.Encode() + args := URL.Query() + if len(args) > 0 { + streamPath += "?" + args.Encode() } - pusher.SendMessage(RTMP_MSG_AMF0_COMMAND, &PublishMessage{ + err = connection.SendMessage(RTMP_MSG_AMF0_COMMAND, &PublishMessage{ CURDStreamMessage{ CommandMessage{ "publish", @@ -198,8 +199,15 @@ func (pusher *Client) Push(p *m7s.Pusher) (err error) { }) } else if response, ok := msg.MsgData.(*ResponsePublishMessage); ok { if response.Infomation["code"] == NetStream_Publish_Start { - audio, video := pusher.CreateSender(true) - go m7s.PlayBlock(&p.Subscriber, audio.HandleAudio, video.HandleVideo) + p.Connection.ReConnectCount = 0 + audio, video := connection.CreateSender(true) + go func() { + for err == nil { + msg, err = connection.RecvMessage() + } + p.Subscriber.Stop(err) + }() + return m7s.PlayBlock(p.Subscriber, audio.HandleAudio, video.HandleVideo) } else { return errors.New(response.Infomation["code"].(string)) } @@ -207,4 +215,5 @@ func (pusher *Client) Push(p *m7s.Pusher) (err error) { } } } + return } diff --git a/plugin/rtmp/pkg/net-stream.go b/plugin/rtmp/pkg/net-stream.go index 2d157fb..53f1ff7 100644 --- a/plugin/rtmp/pkg/net-stream.go +++ b/plugin/rtmp/pkg/net-stream.go @@ -63,9 +63,9 @@ func (ns *NetStream) BeginPlay(tid uint64) (err error) { return } -func (ns *NetStream) Close() error { - if ns.NetConnection != nil { +func (ns *NetStream) disconnect() { + if ns != nil && ns.NetConnection != nil { ns.NetConnection.Destroy() } - return nil + return } diff --git a/plugin/rtsp/index.go b/plugin/rtsp/index.go index a3aa5d6..ef22ce5 100644 --- a/plugin/rtsp/index.go +++ b/plugin/rtsp/index.go @@ -22,20 +22,13 @@ var _ = m7s.InstallPlugin[RTSPPlugin](defaultConfig) type RTSPPlugin struct { m7s.Plugin -} - -func (p *RTSPPlugin) NewPullHandler() m7s.PullHandler { - return &Client{} + Client } func (p *RTSPPlugin) GetPullableList() []string { return slices.Collect(maps.Keys(p.GetCommonConf().PullOnSub)) } -func (p *RTSPPlugin) NewPushHandler() m7s.PushHandler { - return &Client{} -} - func (p *RTSPPlugin) OnInit() error { for streamPath, url := range p.GetCommonConf().PullOnStart { go p.Pull(streamPath, url) @@ -56,7 +49,7 @@ func (p *RTSPPlugin) OnTCPConnect(conn *net.TCPConn) { logger.Error(err.Error(), "stack", string(debug.Stack())) } if receiver != nil { - receiver.Dispose(err) + receiver.Stop(err) } }() var req *util.Request @@ -114,7 +107,7 @@ func (p *RTSPPlugin) OnTCPConnect(conn *net.TCPConn) { receiver = &Receiver{} receiver.NetConnection = nc - if receiver.Publisher, err = p.Publish(strings.TrimPrefix(nc.URL.Path, "/")); err != nil { + if receiver.Publisher, err = p.Publish(strings.TrimPrefix(nc.URL.Path, "/"), receiver); err != nil { receiver = nil err = nc.WriteResponse(&util.Response{ StatusCode: 500, Status: err.Error(), @@ -131,9 +124,9 @@ func (p *RTSPPlugin) OnTCPConnect(conn *net.TCPConn) { case MethodDescribe: sendMode = true - - var subscriber *m7s.Subscriber - subscriber, err = p.Subscribe(strings.TrimPrefix(nc.URL.Path, "/"), conn) + sender = &Sender{} + sender.NetConnection = nc + sender.Subscriber, err = p.Subscribe(strings.TrimPrefix(nc.URL.Path, "/"), sender) if err != nil { res := &util.Response{ StatusCode: http.StatusBadRequest, @@ -149,10 +142,6 @@ func (p *RTSPPlugin) OnTCPConnect(conn *net.TCPConn) { }, Request: req, } - sender = &Sender{ - Subscriber: subscriber, - } - sender.NetConnection = nc // convert tracks to real output medias var medias []*Media if medias, err = sender.GetMedia(); err != nil { diff --git a/plugin/rtsp/pkg/client.go b/plugin/rtsp/pkg/client.go index f0acfc8..f09274c 100644 --- a/plugin/rtsp/pkg/client.go +++ b/plugin/rtsp/pkg/client.go @@ -2,35 +2,21 @@ package rtsp import ( "crypto/tls" - "errors" - "fmt" "m7s.live/m7s/v5" "m7s.live/m7s/v5/pkg/util" "net" - "net/http" "net/url" - "strconv" "strings" ) -type Client struct { - Stream -} +type Client struct{} -func NewPushHandler() m7s.PushHandler { - return &Client{} -} - -func NewPullHandler() m7s.PullHandler { - return &Client{} -} - -func (c *Client) Connect(p *m7s.Client) (err error) { +func createClient(p *m7s.Connection) (s *Stream, err error) { addr := p.RemoteURL var rtspURL *url.URL rtspURL, err = url.Parse(addr) if err != nil { - return err + return } //ps := strings.Split(u.Path, "/") //if len(ps) < 3 { @@ -53,290 +39,73 @@ func (c *Client) Connect(p *m7s.Client) (err error) { conn, err = net.Dial("tcp", rtspURL.Host) } if err != nil { - return err + return } - defer func() { - if err != nil { - conn.Close() - } - }() - c.NetConnection = NewNetConnection(conn, p.Logger) - c.URL = rtspURL - c.auth = util.NewAuth(c.URL.User) - c.Backchannel = true - return c.Options() + s = &Stream{NetConnection: NewNetConnection(conn, p.Logger)} + s.URL = rtspURL + s.auth = util.NewAuth(s.URL.User) + s.Backchannel = true + err = s.Options() + if err != nil { + s.disconnect() + return + } + return } -func (c *Client) Pull(p *m7s.Puller) (err error) { +func (Client) DoPull(p *m7s.PullContext) (err error) { + var s *Stream + if s, err = createClient(&p.Connection); err != nil { + return + } defer func() { - c.Close() + s.disconnect() if p := recover(); p != nil { err = p.(error) } - p.Dispose(err) }() var media []*Media - if media, err = c.Describe(); err != nil { + if media, err = s.Describe(); err != nil { return } - receiver := &Receiver{Publisher: &p.Publisher, Stream: c.Stream} + receiver := &Receiver{Publisher: p.Publisher, Stream: s} if err = receiver.SetMedia(media); err != nil { return } - if err = c.Play(); err != nil { + if err = s.Play(); err != nil { return } + p.Connection.ReConnectCount = 0 return receiver.Receive() } -func (c *Client) Push(p *m7s.Pusher) (err error) { - defer c.Close() - sender := &Sender{Subscriber: &p.Subscriber, Stream: c.Stream} +func (Client) DoPush(ctx *m7s.PushContext) (err error) { + var s *Stream + if s, err = createClient(&ctx.Connection); err != nil { + return + } + defer s.disconnect() + sender := &Sender{Subscriber: ctx.Subscriber, Stream: s} var medias []*Media medias, err = sender.GetMedia() - err = c.Announce(medias) + err = s.Announce(medias) if err != nil { return } for i, media := range medias { switch media.Kind { case "audio", "video": - _, err = c.SetupMedia(media, i) + _, err = s.SetupMedia(media, i) if err != nil { return } default: - c.Warn("media kind not support", "kind", media.Kind) + ctx.Warn("media kind not support", "kind", media.Kind) } } - if err = c.Record(); err != nil { + if err = s.Record(); err != nil { return } - + ctx.Connection.ReConnectCount = 0 return sender.Send() } - -func (c *Client) Do(req *util.Request) (*util.Response, error) { - if err := c.WriteRequest(req); err != nil { - return nil, err - } - - res, err := c.ReadResponse() - if err != nil { - return nil, err - } - - if res.StatusCode == http.StatusUnauthorized { - switch c.auth.Method { - case util.AuthNone: - if c.auth.ReadNone(res) { - return c.Do(req) - } - return nil, errors.New("user/pass not provided") - case util.AuthUnknown: - if c.auth.Read(res) { - return c.Do(req) - } - default: - return nil, errors.New("wrong user/pass") - } - } - - if res.StatusCode != http.StatusOK { - return res, fmt.Errorf("wrong response on %s", req.Method) - } - - return res, nil -} - -func (c *Client) Options() error { - req := &util.Request{Method: MethodOptions, URL: c.URL} - - res, err := c.Do(req) - if err != nil { - return err - } - - if val := res.Header.Get("Content-Base"); val != "" { - c.URL, err = urlParse(val) - if err != nil { - return err - } - } - - return nil -} - -func (c *Client) Describe() (medias []*Media, err error) { - // 5.3 Back channel connection - // https://www.onvif.org/specs/stream/ONVIF-Streaming-Spec.pdf - req := &util.Request{ - Method: MethodDescribe, - URL: c.URL, - Header: map[string][]string{ - "Accept": {"application/sdp"}, - }, - } - - if c.Backchannel { - req.Header.Set("Require", "www.onvif.org/ver20/backchannel") - } - - if c.UserAgent != "" { - // this camera will answer with 401 on DESCRIBE without User-Agent - // https://github.com/AlexxIT/go2rtc/issues/235 - req.Header.Set("User-Agent", c.UserAgent) - } - var res *util.Response - res, err = c.Do(req) - if err != nil { - return - } - - if val := res.Header.Get("Content-Base"); val != "" { - c.URL, err = urlParse(val) - if err != nil { - return - } - } - - c.sdp = string(res.Body) // for info - - medias, err = UnmarshalSDP(res.Body) - if err != nil { - return - } - if c.Media != "" { - clone := make([]*Media, 0, len(medias)) - for _, media := range medias { - if strings.Contains(c.Media, media.Kind) { - clone = append(clone, media) - } - } - medias = clone - } - - return -} - -func (c *Client) Announce(medias []*Media) (err error) { - req := &util.Request{ - Method: MethodAnnounce, - URL: c.URL, - Header: map[string][]string{ - "Content-Type": {"application/sdp"}, - }, - } - - req.Body, err = MarshalSDP(c.SessionName, medias) - if err != nil { - return err - } - - _, err = c.Do(req) - - return -} - -func (c *Client) SetupMedia(media *Media, index int) (byte, error) { - var transport string - transport = fmt.Sprintf( - // i - RTP (data channel) - // i+1 - RTCP (control channel) - "RTP/AVP/TCP;unicast;interleaved=%d-%d", index*2, index*2+1, - ) - if transport == "" { - return 0, fmt.Errorf("wrong media: %v", media) - } - - rawURL := media.ID // control - if !strings.Contains(rawURL, "://") { - rawURL = c.URL.String() - if !strings.HasSuffix(rawURL, "/") { - rawURL += "/" - } - rawURL += media.ID - } - trackURL, err := urlParse(rawURL) - if err != nil { - return 0, err - } - - req := &util.Request{ - Method: MethodSetup, - URL: trackURL, - Header: map[string][]string{ - "Transport": {transport}, - }, - } - - res, err := c.Do(req) - if err != nil { - // some Dahua/Amcrest cameras fail here because two simultaneous - // backchannel connections - //if c.Backchannel { - // c.Backchannel = false - // if err = c.Connect(); err != nil { - // return 0, err - // } - // return c.SetupMedia(media) - //} - - return 0, err - } - - if c.Session == "" { - // Session: 7116520596809429228 - // Session: 216525287999;timeout=60 - if s := res.Header.Get("Session"); s != "" { - if i := strings.IndexByte(s, ';'); i > 0 { - c.Session = s[:i] - if i = strings.Index(s, "timeout="); i > 0 { - c.keepalive, _ = strconv.Atoi(s[i+8:]) - } - } else { - c.Session = s - } - } - } - - // we send our `interleaved`, but camera can answer with another - - // Transport: RTP/AVP/TCP;unicast;interleaved=10-11;ssrc=10117CB7 - // Transport: RTP/AVP/TCP;unicast;destination=192.168.1.111;source=192.168.1.222;interleaved=0 - // Transport: RTP/AVP/TCP;ssrc=22345682;interleaved=0-1 - transport = res.Header.Get("Transport") - if !strings.HasPrefix(transport, "RTP/AVP/TCP;") { - // Escam Q6 has a bug: - // Transport: RTP/AVP;unicast;destination=192.168.1.111;source=192.168.1.222;interleaved=0-1 - if !strings.Contains(transport, ";interleaved=") { - return 0, fmt.Errorf("wrong transport: %s", transport) - } - } - - channel := Between(transport, "interleaved=", "-") - i, err := strconv.Atoi(channel) - if err != nil { - return 0, err - } - - return byte(i), nil -} - -func (c *Client) Play() (err error) { - return c.WriteRequest(&util.Request{Method: MethodPlay, URL: c.URL}) -} - -func (c *Client) Record() (err error) { - return c.WriteRequest(&util.Request{Method: MethodRecord, URL: c.URL}) -} - -func (c *Client) Teardown() (err error) { - // allow TEARDOWN from any state (ex. ANNOUNCE > SETUP) - return c.WriteRequest(&util.Request{Method: MethodTeardown, URL: c.URL}) -} - -func (c *Client) Destroy() { - _ = c.Teardown() - c.NetConnection.Destroy() -} diff --git a/plugin/rtsp/pkg/net-stream.go b/plugin/rtsp/pkg/net-stream.go new file mode 100644 index 0000000..9b20de7 --- /dev/null +++ b/plugin/rtsp/pkg/net-stream.go @@ -0,0 +1,243 @@ +package rtsp + +import ( + "errors" + "fmt" + "m7s.live/m7s/v5/pkg/util" + "net/http" + "strconv" + "strings" +) + +type Stream struct { + *NetConnection + AudioChannelID int + VideoChannelID int +} + +func (c *Stream) Do(req *util.Request) (*util.Response, error) { + if err := c.WriteRequest(req); err != nil { + return nil, err + } + + res, err := c.ReadResponse() + if err != nil { + return nil, err + } + + if res.StatusCode == http.StatusUnauthorized { + switch c.auth.Method { + case util.AuthNone: + if c.auth.ReadNone(res) { + return c.Do(req) + } + return nil, errors.New("user/pass not provided") + case util.AuthUnknown: + if c.auth.Read(res) { + return c.Do(req) + } + default: + return nil, errors.New("wrong user/pass") + } + } + + if res.StatusCode != http.StatusOK { + return res, fmt.Errorf("wrong response on %s", req.Method) + } + + return res, nil +} + +func (c *Stream) Options() error { + req := &util.Request{Method: MethodOptions, URL: c.URL} + + res, err := c.Do(req) + if err != nil { + return err + } + + if val := res.Header.Get("Content-Base"); val != "" { + c.URL, err = urlParse(val) + if err != nil { + return err + } + } + + return nil +} + +func (c *Stream) Describe() (medias []*Media, err error) { + // 5.3 Back channel connection + // https://www.onvif.org/specs/stream/ONVIF-Streaming-Spec.pdf + req := &util.Request{ + Method: MethodDescribe, + URL: c.URL, + Header: map[string][]string{ + "Accept": {"application/sdp"}, + }, + } + + if c.Backchannel { + req.Header.Set("Require", "www.onvif.org/ver20/backchannel") + } + + if c.UserAgent != "" { + // this camera will answer with 401 on DESCRIBE without User-Agent + // https://github.com/AlexxIT/go2rtc/issues/235 + req.Header.Set("User-Agent", c.UserAgent) + } + var res *util.Response + res, err = c.Do(req) + if err != nil { + return + } + + if val := res.Header.Get("Content-Base"); val != "" { + c.URL, err = urlParse(val) + if err != nil { + return + } + } + + c.sdp = string(res.Body) // for info + + medias, err = UnmarshalSDP(res.Body) + if err != nil { + return + } + if c.Media != "" { + clone := make([]*Media, 0, len(medias)) + for _, media := range medias { + if strings.Contains(c.Media, media.Kind) { + clone = append(clone, media) + } + } + medias = clone + } + + return +} + +func (c *Stream) Announce(medias []*Media) (err error) { + req := &util.Request{ + Method: MethodAnnounce, + URL: c.URL, + Header: map[string][]string{ + "Content-Type": {"application/sdp"}, + }, + } + + req.Body, err = MarshalSDP(c.SessionName, medias) + if err != nil { + return err + } + + _, err = c.Do(req) + + return +} + +func (c *Stream) SetupMedia(media *Media, index int) (byte, error) { + var transport string + transport = fmt.Sprintf( + // i - RTP (data channel) + // i+1 - RTCP (control channel) + "RTP/AVP/TCP;unicast;interleaved=%d-%d", index*2, index*2+1, + ) + if transport == "" { + return 0, fmt.Errorf("wrong media: %v", media) + } + + rawURL := media.ID // control + if !strings.Contains(rawURL, "://") { + rawURL = c.URL.String() + if !strings.HasSuffix(rawURL, "/") { + rawURL += "/" + } + rawURL += media.ID + } + trackURL, err := urlParse(rawURL) + if err != nil { + return 0, err + } + + req := &util.Request{ + Method: MethodSetup, + URL: trackURL, + Header: map[string][]string{ + "Transport": {transport}, + }, + } + + res, err := c.Do(req) + if err != nil { + // some Dahua/Amcrest cameras fail here because two simultaneous + // backchannel connections + //if c.Backchannel { + // c.Backchannel = false + // if err = c.Connect(); err != nil { + // return 0, err + // } + // return c.SetupMedia(media) + //} + + return 0, err + } + + if c.Session == "" { + // Session: 7116520596809429228 + // Session: 216525287999;timeout=60 + if s := res.Header.Get("Session"); s != "" { + if i := strings.IndexByte(s, ';'); i > 0 { + c.Session = s[:i] + if i = strings.Index(s, "timeout="); i > 0 { + c.keepalive, _ = strconv.Atoi(s[i+8:]) + } + } else { + c.Session = s + } + } + } + + // we send our `interleaved`, but camera can answer with another + + // Transport: RTP/AVP/TCP;unicast;interleaved=10-11;ssrc=10117CB7 + // Transport: RTP/AVP/TCP;unicast;destination=192.168.1.111;source=192.168.1.222;interleaved=0 + // Transport: RTP/AVP/TCP;ssrc=22345682;interleaved=0-1 + transport = res.Header.Get("Transport") + if !strings.HasPrefix(transport, "RTP/AVP/TCP;") { + // Escam Q6 has a bug: + // Transport: RTP/AVP;unicast;destination=192.168.1.111;source=192.168.1.222;interleaved=0-1 + if !strings.Contains(transport, ";interleaved=") { + return 0, fmt.Errorf("wrong transport: %s", transport) + } + } + + channel := Between(transport, "interleaved=", "-") + i, err := strconv.Atoi(channel) + if err != nil { + return 0, err + } + + return byte(i), nil +} + +func (c *Stream) Play() (err error) { + return c.WriteRequest(&util.Request{Method: MethodPlay, URL: c.URL}) +} + +func (c *Stream) Record() (err error) { + return c.WriteRequest(&util.Request{Method: MethodRecord, URL: c.URL}) +} + +func (c *Stream) Teardown() (err error) { + // allow TEARDOWN from any state (ex. ANNOUNCE > SETUP) + return c.WriteRequest(&util.Request{Method: MethodTeardown, URL: c.URL}) +} + +func (ns *Stream) disconnect() { + if ns != nil && ns.NetConnection != nil { + _ = ns.Teardown() + ns.NetConnection.Destroy() + } +} diff --git a/plugin/rtsp/pkg/transceiver.go b/plugin/rtsp/pkg/transceiver.go index 3b76b4c..53c9863 100644 --- a/plugin/rtsp/pkg/transceiver.go +++ b/plugin/rtsp/pkg/transceiver.go @@ -10,30 +10,18 @@ import ( "reflect" ) -type Stream struct { - *NetConnection - AudioChannelID int - VideoChannelID int -} type Sender struct { *m7s.Subscriber - Stream + *Stream } type Receiver struct { *m7s.Publisher - Stream + *Stream AudioCodecParameters *webrtc.RTPCodecParameters VideoCodecParameters *webrtc.RTPCodecParameters } -func (ns *Stream) Close() error { - if ns.NetConnection != nil { - ns.NetConnection.Destroy() - } - return nil -} - func (s *Sender) GetMedia() (medias []*Media, err error) { if s.SubAudio && s.Publisher.PubAudio && s.Publisher.HasAudioTrack() { audioTrack := s.Publisher.GetAudioTrack(reflect.TypeOf((*mrtp.RTPAudio)(nil))) @@ -163,6 +151,7 @@ func (r *Receiver) Receive() (err error) { var channelID byte var buf []byte for err == nil { + channelID, buf, err = r.NetConnection.Receive(false) if err != nil { return @@ -184,7 +173,9 @@ func (r *Receiver) Receive() (err error) { audioFrame.AddRecycleBytes(buf) audioFrame.Packets = append(audioFrame.Packets, packet) } else { - err = r.WriteAudio(audioFrame) + if err = r.WriteAudio(audioFrame); err != nil { + return + } audioFrame = &mrtp.RTPAudio{} audioFrame.AddRecycleBytes(buf) audioFrame.Packets = []*rtp.Packet{packet} @@ -204,7 +195,9 @@ func (r *Receiver) Receive() (err error) { videoFrame.Packets = append(videoFrame.Packets, packet) } else { // t := time.Now() - err = r.WriteVideo(videoFrame) + if err = r.WriteVideo(videoFrame); err != nil { + return + } // fmt.Println("write video", time.Since(t)) videoFrame = &mrtp.Video{} videoFrame.AddRecycleBytes(buf) diff --git a/plugin/stress/api.go b/plugin/stress/api.go index 5e202b5..b683897 100644 --- a/plugin/stress/api.go +++ b/plugin/stress/api.go @@ -13,14 +13,14 @@ import ( "m7s.live/m7s/v5/plugin/stress/pb" ) -func (r *StressPlugin) pull(count int, format, url string, newFunc func() m7s.PullHandler) error { +func (r *StressPlugin) pull(count int, format, url string, puller m7s.Puller) error { if i := r.pullers.Length; count > i { for j := i; j < count; j++ { - puller, err := r.Pull(fmt.Sprintf("stress/%d", j), fmt.Sprintf(format, url)) + ctx, err := r.Pull(fmt.Sprintf("stress/%d", j), fmt.Sprintf(format, url)) if err != nil { return err } - go r.startPull(puller, newFunc()) + go r.startPull(ctx, puller) } } else if count < i { for j := i; j > count; j-- { @@ -31,14 +31,14 @@ func (r *StressPlugin) pull(count int, format, url string, newFunc func() m7s.Pu return nil } -func (r *StressPlugin) push(count int, streamPath, format, remoteHost string, newFunc func() m7s.PushHandler) (err error) { +func (r *StressPlugin) push(count int, streamPath, format, remoteHost string, pusher m7s.Pusher) (err error) { if i := r.pushers.Length; count > i { for j := i; j < count; j++ { - pusher, err := r.Push(streamPath, fmt.Sprintf(format, remoteHost, j)) + ctx, err := r.Push(streamPath, fmt.Sprintf(format, remoteHost, j)) if err != nil { return err } - go r.startPush(pusher, newFunc()) + go r.startPush(ctx, pusher) } } else if count < i { for j := i; j > count; j-- { @@ -50,34 +50,34 @@ func (r *StressPlugin) push(count int, streamPath, format, remoteHost string, ne } func (r *StressPlugin) PushRTMP(ctx context.Context, req *pb.PushRequest) (res *gpb.SuccessResponse, err error) { - return &gpb.SuccessResponse{}, r.push(int(req.PushCount), req.StreamPath, "rtmp://%s/stress/%d", req.RemoteHost, rtmp.NewPushHandler) + return &gpb.SuccessResponse{}, r.push(int(req.PushCount), req.StreamPath, "rtmp://%s/stress/%d", req.RemoteHost, rtmp.Client{}.DoPush) } func (r *StressPlugin) PushRTSP(ctx context.Context, req *pb.PushRequest) (res *gpb.SuccessResponse, err error) { - return &gpb.SuccessResponse{}, r.push(int(req.PushCount), req.StreamPath, "rtsp://%s/stress/%d", req.RemoteHost, rtsp.NewPushHandler) + return &gpb.SuccessResponse{}, r.push(int(req.PushCount), req.StreamPath, "rtsp://%s/stress/%d", req.RemoteHost, rtsp.Client{}.DoPush) } func (r *StressPlugin) PullRTMP(ctx context.Context, req *pb.PullRequest) (res *gpb.SuccessResponse, err error) { - return &gpb.SuccessResponse{}, r.pull(int(req.PullCount), "rtmp://%s", req.RemoteURL, rtmp.NewPullHandler) + return &gpb.SuccessResponse{}, r.pull(int(req.PullCount), "rtmp://%s", req.RemoteURL, rtmp.Client{}.DoPull) } func (r *StressPlugin) PullRTSP(ctx context.Context, req *pb.PullRequest) (res *gpb.SuccessResponse, err error) { - return &gpb.SuccessResponse{}, r.pull(int(req.PullCount), "rtsp://%s", req.RemoteURL, rtsp.NewPullHandler) + return &gpb.SuccessResponse{}, r.pull(int(req.PullCount), "rtsp://%s", req.RemoteURL, rtsp.Client{}.DoPull) } func (r *StressPlugin) PullHDL(ctx context.Context, req *pb.PullRequest) (res *gpb.SuccessResponse, err error) { - return &gpb.SuccessResponse{}, r.pull(int(req.PullCount), "http://%s", req.RemoteURL, hdl.NewPullHandler) + return &gpb.SuccessResponse{}, r.pull(int(req.PullCount), "http://%s", req.RemoteURL, hdl.PullFLV) } -func (r *StressPlugin) startPush(pusher *m7s.Pusher, handler m7s.PushHandler) { +func (r *StressPlugin) startPush(pusher *m7s.PushContext, handler m7s.Pusher) { r.pushers.AddUnique(pusher) - pusher.Start(handler) + pusher.Run(handler) r.pushers.Remove(pusher) } -func (r *StressPlugin) startPull(puller *m7s.Puller, handler m7s.PullHandler) { +func (r *StressPlugin) startPull(puller *m7s.PullContext, handler m7s.Puller) { r.pullers.AddUnique(puller) - puller.Start(handler) + puller.Run(handler) r.pullers.Remove(puller) } diff --git a/plugin/stress/index.go b/plugin/stress/index.go index 6b5b11f..5e0a4d2 100644 --- a/plugin/stress/index.go +++ b/plugin/stress/index.go @@ -10,8 +10,8 @@ import ( type StressPlugin struct { pb.UnimplementedApiServer m7s.Plugin - pushers util.Collection[string, *m7s.Pusher] - pullers util.Collection[string, *m7s.Puller] + pushers util.Collection[string, *m7s.PushContext] + pullers util.Collection[string, *m7s.PullContext] } var _ = m7s.InstallPlugin[StressPlugin](&pb.Api_ServiceDesc, pb.RegisterApiHandler) diff --git a/publisher.go b/publisher.go index bf5cfee..dd70231 100644 --- a/publisher.go +++ b/publisher.go @@ -1,6 +1,7 @@ package m7s import ( + "context" "math" "os" "path/filepath" @@ -59,21 +60,22 @@ type AVTracks struct { } func (t *AVTracks) CreateSubTrack(dataType reflect.Type) (track *AVTrack) { - track = NewAVTrack(dataType, t.AVTrack, util.NewPromise(struct{}{})) + track = NewAVTrack(dataType, t.AVTrack, util.NewPromise(context.TODO())) track.WrapIndex = t.Length t.Add(track) return } +// createPublisher -> Start -> WriteAudio/WriteVideo -> Dispose type Publisher struct { PubSubBase - sync.RWMutex `json:"-" yaml:"-"` + sync.RWMutex config.Publish State PublisherState AudioTrack, VideoTrack AVTracks - audioReady, videoReady *util.Promise[struct{}] + audioReady, videoReady *util.Promise DataTrack *DataTrack - Subscribers util.Collection[int, *Subscriber] `json:"-" yaml:"-"` + Subscribers SubscriberCollection GOP int baseTs, lastTs time.Duration dumpFile *os.File @@ -87,6 +89,70 @@ func (p *Publisher) GetKey() string { return p.StreamPath } +func createPublisher(p *Plugin, streamPath string, options ...any) (publisher *Publisher) { + publisher = &Publisher{Publish: p.config.Publish} + publisher.ID = p.Server.streamTM.GetID() + publisher.Executor = publisher + publisher.Plugin = p + publisher.TimeoutTimer = time.NewTimer(p.config.PublishTimeout) + var opt = []any{p.Logger.With("streamPath", streamPath, "pId", publisher.ID)} + for _, option := range options { + switch v := option.(type) { + case func(*config.Publish): + v(&publisher.Publish) + default: + opt = append(opt, option) + } + } + publisher.Init(streamPath, &publisher.Publish, opt...) + return +} + +func (p *Publisher) Start() (err error) { + s := p.Plugin.Server + if oldPublisher, ok := s.Streams.Get(p.StreamPath); ok { + if p.KickExist { + p.Warn("kick") + oldPublisher.Stop(ErrKick) + p.TakeOver(oldPublisher) + } else { + return ErrStreamExist + } + } + s.Streams.Set(p) + p.Info("publish") + p.audioReady = util.NewPromiseWithTimeout(p, time.Second*5) + p.videoReady = util.NewPromiseWithTimeout(p, time.Second*5) + if p.Dump { + f := filepath.Join("./dump", p.StreamPath) + os.MkdirAll(filepath.Dir(f), 0666) + p.dumpFile, _ = os.OpenFile(filepath.Join("./dump", p.StreamPath), os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0666) + } + if waiting, ok := s.Waiting.Get(p.StreamPath); ok { + p.TakeOver(waiting) + s.Waiting.Remove(waiting) + } + for plugin := range s.Plugins.Range { + if plugin.Disabled { + continue + } + if remoteURL := plugin.GetCommonConf().CheckPush(p.StreamPath); remoteURL != "" { + if _, ok := plugin.handler.(IPusherPlugin); ok { + go plugin.Push(p.StreamPath, remoteURL) + } + } + if filePath := plugin.GetCommonConf().CheckRecord(p.StreamPath); filePath != "" { + if _, ok := plugin.handler.(IRecorderPlugin); ok { + go plugin.Record(p.StreamPath, filePath) + } + } + //if h, ok := plugin.handler.(IOnPublishPlugin); ok { + // h.OnPublish(publisher) + //} + } + return +} + func (p *Publisher) timeout() (err error) { switch p.State { case PublisherStateInit: @@ -179,17 +245,6 @@ func (p *Publisher) AddSubscriber(subscriber *Subscriber) { } } -func (p *Publisher) Start() { - p.Info("publish") - p.audioReady = util.NewPromiseWithTimeout(struct{}{}, time.Second*5) - p.videoReady = util.NewPromiseWithTimeout(struct{}{}, time.Second*5) - if p.Dump { - f := filepath.Join("./dump", p.StreamPath) - os.MkdirAll(filepath.Dir(f), 0666) - p.dumpFile, _ = os.OpenFile(filepath.Join("./dump", p.StreamPath), os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0666) - } -} - func (p *Publisher) writeAV(t *AVTrack, data IAVFrame) { frame := &t.Value frame.Wraps = append(frame.Wraps, data) @@ -222,6 +277,9 @@ func (p *Publisher) WriteVideo(data IAVFrame) (err error) { data.Recycle() } }() + if err = p.Err(); err != nil { + return + } if p.dumpFile != nil { data.Dump(1, p.dumpFile) } @@ -320,6 +378,9 @@ func (p *Publisher) WriteAudio(data IAVFrame) (err error) { data.Recycle() } }() + if err = p.Err(); err != nil { + return + } if p.dumpFile != nil { data.Dump(0, p.dumpFile) } @@ -394,6 +455,9 @@ func (p *Publisher) WriteAudio(data IAVFrame) (err error) { } func (p *Publisher) WriteData(data IDataFrame) (err error) { + if err = p.Err(); err != nil { + return + } if p.DataTrack == nil { p.DataTrack = &DataTrack{} p.DataTrack.Logger = p.Logger.With("track", "data") @@ -441,26 +505,35 @@ func (p *Publisher) HasVideoTrack() bool { return p.VideoTrack.Length > 0 } -func (p *Publisher) Dispose(err error) { +func (p *Publisher) Dispose() { + s := p.Plugin.Server + s.Streams.Remove(p) + if p.Subscribers.Length > 0 { + s.Waiting.Add(p) + } + p.Info("unpublish", "remain", s.Streams.Length, "reason", p.StopReason()) + for subscriber := range p.SubscriberRange { + waitCloseTimeout := p.WaitCloseTimeout + if waitCloseTimeout == 0 { + waitCloseTimeout = subscriber.WaitTimeout + } + subscriber.TimeoutTimer.Reset(waitCloseTimeout) + } p.Lock() defer p.Unlock() if p.dumpFile != nil { p.dumpFile.Close() } if p.State == PublisherStateDisposed { - return + panic("disposed") } - if p.IsStopped() { - if p.HasAudioTrack() { - p.AudioTrack.Dispose() - } - if p.HasVideoTrack() { - p.VideoTrack.Dispose() - } - p.State = PublisherStateDisposed - return + if p.HasAudioTrack() { + p.AudioTrack.Dispose() } - p.Stop(err) + if p.HasVideoTrack() { + p.VideoTrack.Dispose() + } + p.State = PublisherStateDisposed } func (p *Publisher) TakeOver(old *Publisher) { @@ -469,18 +542,16 @@ func (p *Publisher) TakeOver(old *Publisher) { for subscriber := range old.SubscriberRange { p.AddSubscriber(subscriber) } - if old.Plugin != nil { - old.Dispose(nil) - } - old.Subscribers = util.Collection[int, *Subscriber]{} + old.Stop(ErrKick) + old.Subscribers = SubscriberCollection{} } func (p *Publisher) WaitTrack() (err error) { if p.PubVideo { - _, err = p.videoReady.Await() + err = p.videoReady.Await() } if p.PubAudio { - _, err = p.audioReady.Await() + err = p.audioReady.Await() } return } diff --git a/puller.go b/puller.go index 9f3ae4f..5bfcee1 100644 --- a/puller.go +++ b/puller.go @@ -1,66 +1,101 @@ package m7s import ( - "io" - "time" - + "context" + "m7s.live/m7s/v5/pkg" "m7s.live/m7s/v5/pkg/config" + "time" ) -type Client struct { - *PubSubBase +type Connection struct { + pkg.Task + Plugin *Plugin + StreamPath string // 对应本地流 RemoteURL string // 远程服务器地址(用于推拉) ReConnectCount int //重连次数 - Proxy string // 代理地址 + ConnectProxy string // 连接代理 + MetaData any } -func (client *Client) reconnect(count int) (ok bool) { +func (client *Connection) reconnect(count int) (ok bool) { ok = count == -1 || client.ReConnectCount <= count client.ReConnectCount++ return } -type PullHandler interface { - Connect(*Client) error - // Disconnect() - Pull(*Puller) error +type Puller = func(*PullContext) error + +func createPullContext(p *Plugin, streamPath string, url string, options ...any) (pullCtx *PullContext) { + pullCtx = &PullContext{Pull: p.config.Pull} + pullCtx.ID = p.Server.pullTM.GetID() + pullCtx.Plugin = p + pullCtx.Executor = pullCtx + pullCtx.ConnectProxy = p.config.Pull.Proxy + pullCtx.RemoteURL = url + publishConfig := p.config.Publish + publishConfig.PublishTimeout = 0 + pullCtx.StreamPath = streamPath + pullCtx.PublishOptions = []any{publishConfig} + var ctx = p.Context + for _, option := range options { + switch v := option.(type) { + case context.Context: + ctx = v + default: + pullCtx.PublishOptions = append(pullCtx.PublishOptions, option) + } + + } + p.Init(ctx, p.Logger.With("pullURL", url, "streamPath", streamPath)) + pullCtx.PublishOptions = append(pullCtx.PublishOptions, pullCtx.Context) + return } -type Puller struct { - Client Client - Publisher +type PullContext struct { + Connection + Publisher *Publisher + PublishOptions []any config.Pull } -func (p *Puller) Start(handler PullHandler) (err error) { - badPuller := true - var startTime time.Time - for p.Info("start pull", "url", p.Client.RemoteURL); p.Client.reconnect(p.RePull); p.Warn("restart pull") { - if time.Since(startTime) < 5*time.Second { +func (p *PullContext) GetKey() string { + return p.StreamPath +} + +func (p *PullContext) Run(puller Puller) { + var err error + defer p.Info("stop pull") + for p.Info("start pull", "url", p.Connection.RemoteURL); p.Connection.reconnect(p.RePull); p.Warn("restart pull") { + if p.Publisher != nil && time.Since(p.Publisher.StartTime) < 5*time.Second { time.Sleep(5 * time.Second) } - startTime = time.Now() - if err = handler.Connect(&p.Client); err != nil { - if err == io.EOF { - p.Info("pull complete") - return - } - p.Error("pull connect", "error", err) - if badPuller { - return - } - } else { - badPuller = false - p.Client.ReConnectCount = 0 - if err = handler.Pull(p); err != nil && !p.IsStopped() { - p.Error("pull interrupt", "error", err) - } + if p.Publisher, err = p.Plugin.Publish(p.StreamPath, p.PublishOptions...); err != nil { + p.Error("pull publish failed", "error", err) + break } + err = puller(p) + p.Publisher.Stop(err) if p.IsStopped() { - p.Info("stop pull") return + } else { + p.Error("pull interrupt", "error", err) } - // handler.Disconnect() } - return nil + if err == nil { + err = pkg.ErrRetryRunOut + } + p.Stop(err) +} + +func (p *PullContext) Start() (err error) { + s := p.Plugin.Server + if _, ok := s.Pulls.Get(p.GetKey()); ok { + return pkg.ErrStreamExist + } + s.Pulls.Add(p) + return +} + +func (p *PullContext) Dispose() { + p.Plugin.Server.Pulls.Remove(p) } diff --git a/pusher.go b/pusher.go index 0fbf8f3..52cfd56 100644 --- a/pusher.go +++ b/pusher.go @@ -1,57 +1,85 @@ package m7s import ( - "io" + "context" + "m7s.live/m7s/v5/pkg" "time" "m7s.live/m7s/v5/pkg/config" ) -type PushHandler interface { - Connect(*Client) error - // Disconnect() - Push(*Pusher) error +type Pusher = func(*PushContext) error + +func createPushContext(p *Plugin, streamPath string, url string, options ...any) (pushCtx *PushContext) { + pushCtx = &PushContext{Push: p.config.Push} + pushCtx.ID = p.Server.pushTM.GetID() + pushCtx.Plugin = p + pushCtx.Executor = pushCtx + pushCtx.RemoteURL = url + pushCtx.StreamPath = streamPath + pushCtx.ConnectProxy = p.config.Push.Proxy + pushCtx.SubscribeOptions = []any{p.config.Subscribe} + var ctx = p.Context + for _, option := range options { + switch v := option.(type) { + case context.Context: + ctx = v + default: + pushCtx.SubscribeOptions = append(pushCtx.SubscribeOptions, option) + } + } + pushCtx.Init(ctx, p.Logger.With("pushURL", url, "streamPath", streamPath)) + pushCtx.SubscribeOptions = append(pushCtx.SubscribeOptions, pushCtx.Context) + return } -type Pusher struct { - Client Client - Subscriber +type PushContext struct { + Connection + Subscriber *Subscriber + SubscribeOptions []any config.Push } -func (p *Pusher) GetKey() string { - return p.Client.RemoteURL +func (p *PushContext) GetKey() string { + return p.RemoteURL } -func (p *Pusher) Start(handler PushHandler) (err error) { - badPuller := true - var startTime time.Time - for p.Info("start push", "url", p.Client.RemoteURL); p.Client.reconnect(p.RePush); p.Warn("restart push") { - if time.Since(startTime) < 5*time.Second { +func (p *PushContext) Run(pusher Pusher) { + p.StartTime = time.Now() + defer p.Info("stop push") + var err error + for p.Info("start push", "url", p.Connection.RemoteURL); p.Connection.reconnect(p.RePush); p.Warn("restart push") { + if p.Subscriber != nil && time.Since(p.Subscriber.StartTime) < 5*time.Second { time.Sleep(5 * time.Second) } - startTime = time.Now() - if err = handler.Connect(&p.Client); err != nil { - if err == io.EOF { - p.Info("push complete") - return - } - p.Error("push connect", "error", err) - if badPuller { - return - } - } else { - badPuller = false - p.Client.ReConnectCount = 0 - if err = handler.Push(p); err != nil && !p.IsStopped() { - p.Error("push interrupt", "error", err) - } + if p.Subscriber, err = p.Plugin.Subscribe(p.StreamPath, p.SubscribeOptions...); err != nil { + p.Error("push subscribe failed", "error", err) + break } + err = pusher(p) + p.Subscriber.Stop(err) if p.IsStopped() { - p.Info("stop push") return + } else { + p.Error("push interrupt", "error", err) } - // handler.Disconnect() } - return nil + if err == nil { + err = pkg.ErrRetryRunOut + } + p.Stop(err) + return +} + +func (p *PushContext) Start() (err error) { + s := p.Plugin.Server + if _, ok := s.Pushs.Get(p.GetKey()); ok { + return pkg.ErrPushRemoteURLExist + } + s.Pushs.Add(p) + return +} + +func (p *PushContext) Dispose() { + p.Plugin.Server.Pushs.Remove(p) } diff --git a/recoder.go b/recoder.go index 3d1384b..5adf478 100644 --- a/recoder.go +++ b/recoder.go @@ -1,26 +1,66 @@ package m7s import ( - "m7s.live/m7s/v5/pkg/config" - "os" + "context" + "m7s.live/m7s/v5/pkg" + "time" ) -type RecordHandler interface { - Close() - Record(*Recorder) error +type Recorder = func(*RecordContext) error + +func createRecoder(p *Plugin, streamPath string, filePath string, options ...any) (recorder *RecordContext) { + recorder = &RecordContext{ + Plugin: p, + Fragment: p.config.Record.Fragment, + Append: p.config.Record.Append, + FilePath: filePath, + } + recorder.ID = p.Server.recordTM.GetID() + recorder.Executor = recorder + recorder.FilePath = filePath + recorder.SubscribeOptions = []any{p.config.Subscribe} + var ctx = p.Context + for _, option := range options { + switch v := option.(type) { + case context.Context: + ctx = v + default: + recorder.SubscribeOptions = append(recorder.SubscribeOptions, option) + } + } + recorder.Init(ctx, p.Logger.With("filePath", filePath, "streamPath", streamPath)) + recorder.SubscribeOptions = append(recorder.SubscribeOptions, recorder.Context) + return } -type Recorder struct { - File *os.File - Subscriber - config.Record +type RecordContext struct { + pkg.Task + Plugin *Plugin + Subscriber *Subscriber + SubscribeOptions []any + Fragment time.Duration + Append bool + FilePath string } -func (p *Recorder) GetKey() string { - return p.File.Name() +func (p *RecordContext) GetKey() string { + return p.FilePath } -func (p *Recorder) Start(handler RecordHandler) (err error) { - defer handler.Close() - return handler.Record(p) +func (p *RecordContext) Run(recorder Recorder) { + err := recorder(p) + p.Stop(err) +} + +func (p *RecordContext) Start() (err error) { + s := p.Plugin.Server + if _, ok := s.Records.Get(p.GetKey()); ok { + return pkg.ErrRecordSamePath + } + s.Records.Add(p) + return +} + +func (p *RecordContext) Dispose() { + p.Plugin.Server.Records.Remove(p) } diff --git a/server.go b/server.go index 0dfe434..b2d7117 100644 --- a/server.go +++ b/server.go @@ -20,8 +20,6 @@ import ( "os" "os/signal" "path/filepath" - "reflect" - "slices" "strings" "sync/atomic" "syscall" @@ -33,13 +31,13 @@ var ( MergeConfigs = []string{"Publish", "Subscribe", "HTTP", "PublicIP", "LogLevel", "EnableAuth", "DB"} ExecPath = os.Args[0] ExecDir = filepath.Dir(ExecPath) - serverIndexG atomic.Uint32 DefaultServer = NewServer() serverMeta = PluginMeta{ Name: "Global", Version: Version, } - Servers = make([]*Server, 10) + Servers util.Collection[uint32, *Server] + serverIdG atomic.Uint32 Routes = map[string]string{} defaultLogHandler = console.NewHandler(os.Stdout, &console.HandlerOptions{TimeFormat: "15:04:05.000000"}) ) @@ -56,31 +54,34 @@ type Server struct { pb.UnimplementedGlobalServer Plugin ServerConfig - eventChan chan any - Plugins util.Collection[string, *Plugin] - Streams, Waiting util.Collection[string, *Publisher] - Pulls util.Collection[string, *Puller] - Pushs util.Collection[string, *Pusher] - Records util.Collection[string, *Recorder] - Subscribers util.Collection[int, *Subscriber] - LogHandler MultiLogHandler - pidG, sidG int - apiList []string - grpcServer *grpc.Server - grpcClientConn *grpc.ClientConn - lastSummaryTime time.Time - lastSummary *pb.SummaryResponse - OnAuthPubs map[string]func(p *util.Promise[*Publisher]) - OnAuthSubs map[string]func(p *util.Promise[*Subscriber]) + //eventChan chan any + Plugins util.Collection[string, *Plugin] + Streams, Waiting util.Collection[string, *Publisher] + Pulls util.Collection[string, *PullContext] + Pushs util.Collection[string, *PushContext] + Records util.Collection[string, *RecordContext] + Subscribers SubscriberCollection + LogHandler MultiLogHandler + apiList []string + grpcServer *grpc.Server + grpcClientConn *grpc.ClientConn + lastSummaryTime time.Time + lastSummary *pb.SummaryResponse + OnAuthPubs map[string]func(*Publisher) *util.Promise + OnAuthSubs map[string]func(*Subscriber) *util.Promise + pluginTM, streamTM, pullTM, pushTM, recordTM *TaskManager + runOption struct { + ctx context.Context + conf any + } } func NewServer() (s *Server) { s = &Server{} - s.ID = int(serverIndexG.Add(1)) + s.ID = serverIdG.Add(1) s.Meta = &serverMeta - s.OnAuthPubs = make(map[string]func(p *util.Promise[*Publisher])) - s.OnAuthSubs = make(map[string]func(p *util.Promise[*Subscriber])) - Servers[s.ID] = s + s.OnAuthPubs = make(map[string]func(*Publisher) *util.Promise) + s.OnAuthSubs = make(map[string]func(*Subscriber) *util.Promise) return } @@ -100,21 +101,11 @@ func init() { } } -func (s *Server) Run(ctx context.Context, conf any) (err error) { - s.StartTime = time.Now() - for err = s.run(ctx, conf); err == ErrRestart; err = s.run(ctx, conf) { - var server Server - server.ID = s.ID - server.Meta = s.Meta - server.OnAuthPubs = s.OnAuthPubs - server.OnAuthSubs = s.OnAuthSubs - server.DB = s.DB - *s = server - } - return +func (s *Server) GetKey() uint32 { + return s.ID } -func (s *Server) run(ctx context.Context, conf any) (err error) { +func (s *Server) Start() (err error) { s.Server = s s.handler = s s.config.HTTP.ListenAddrTLS = ":8443" @@ -122,18 +113,16 @@ func (s *Server) run(ctx context.Context, conf any) (err error) { s.config.TCP.ListenAddr = ":50051" s.LogHandler.SetLevel(slog.LevelInfo) s.LogHandler.Add(defaultLogHandler) - s.Logger = slog.New(&s.LogHandler).With("Server", s.ID) - + s.Task.Init(s.runOption.ctx, slog.New(&s.LogHandler).With("Server", s.ID)) + s.Info("start") httpConf, tcpConf := &s.config.HTTP, &s.config.TCP mux := runtime.NewServeMux(runtime.WithMarshalerOption("text/plain", &pb.TextPlain{}), runtime.WithRoutingErrorHandler(func(_ context.Context, _ *runtime.ServeMux, _ runtime.Marshaler, w http.ResponseWriter, r *http.Request, _ int) { httpConf.GetHttpMux().ServeHTTP(w, r) })) httpConf.SetMux(mux) - s.Context, s.CancelCauseFunc = context.WithCancelCause(ctx) - s.Info("start") var cg rawconfig var configYaml []byte - switch v := conf.(type) { + switch v := s.runOption.conf.(type) { case string: if _, err = os.Stat(v); err != nil { v = filepath.Join(ExecDir, v) @@ -156,7 +145,7 @@ func (s *Server) run(ctx context.Context, conf any) (err error) { if cg != nil { s.Config.ParseUserFile(cg["global"]) } - s.eventChan = make(chan any, s.EventBusSize) + //s.eventChan = make(chan any, s.EventBusSize) s.LogHandler.SetLevel(ParseLevel(s.config.LogLevel)) s.registerHandler(map[string]http.HandlerFunc{ "/api/config/json/{name}": s.api_Config_JSON_, @@ -175,7 +164,7 @@ func (s *Server) run(ctx context.Context, conf any) (err error) { if httpConf.ListenAddrTLS != "" { s.Info("https listen at ", "addr", httpConf.ListenAddrTLS) go func(addr string) { - if err := httpConf.ListenTLS(); err != http.ErrServerClosed { + if err = httpConf.ListenTLS(); err != http.ErrServerClosed { s.Stop(err) } s.Info("https stop listen at ", "addr", addr) @@ -184,7 +173,7 @@ func (s *Server) run(ctx context.Context, conf any) (err error) { if httpConf.ListenAddr != "" { s.Info("http listen at ", "addr", httpConf.ListenAddr) go func(addr string) { - if err := httpConf.Listen(); err != http.ErrServerClosed { + if err = httpConf.Listen(); err != http.ErrServerClosed { s.Stop(err) } s.Info("http stop listen at ", "addr", addr) @@ -196,316 +185,125 @@ func (s *Server) run(ctx context.Context, conf any) (err error) { s.grpcServer = grpc.NewServer(opts...) pb.RegisterGlobalServer(s.grpcServer, s) - s.grpcClientConn, err = grpc.DialContext(ctx, tcpConf.ListenAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) + s.grpcClientConn, err = grpc.DialContext(s.Context, tcpConf.ListenAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { s.Error("failed to dial", "error", err) - return err + return } defer s.grpcClientConn.Close() - if err = pb.RegisterGlobalHandler(ctx, mux, s.grpcClientConn); err != nil { + if err = pb.RegisterGlobalHandler(s.Context, mux, s.grpcClientConn); err != nil { s.Error("register handler faild", "error", err) - return err + return } tcplis, err = net.Listen("tcp", tcpConf.ListenAddr) if err != nil { s.Error("failed to listen", "error", err) - return err + return } defer tcplis.Close() } + signalChan := make(chan os.Signal, 1) + signal.Notify(signalChan, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT) + s.pluginTM = NewTaskManager() + go s.pluginTM.Run(signalChan, func() { + for plugin := range s.Plugins.Range { + plugin.handler.OnExit() + } + }) for _, plugin := range plugins { - plugin.Init(s, cg[strings.ToLower(plugin.Name)]) + if p := plugin.Init(s, cg[strings.ToLower(plugin.Name)]); !p.Disabled { + s.pluginTM.Add(&p.Task) + } } if tcplis != nil { go func(addr string) { - if err := s.grpcServer.Serve(tcplis); err != nil { + if err = s.grpcServer.Serve(tcplis); err != nil { s.Stop(err) } s.Info("grpc stop listen at ", "addr", addr) }(tcpConf.ListenAddr) } - s.eventLoop() - err = context.Cause(s) - s.Warn("Server is done", "reason", err) - for publisher := range s.Streams.Range { - publisher.Stop(err) - } - for subscriber := range s.Subscribers.Range { - subscriber.Stop(err) - } - for p := range s.Plugins.Range { - p.Stop(err) - } - httpConf.StopListen() + s.streamTM = NewTaskManager() + s.pullTM = NewTaskManager() + s.pushTM = NewTaskManager() + s.recordTM = NewTaskManager() + go s.streamTM.Run(time.NewTicker(s.PulseInterval).C, func(time.Time) { + for publisher := range s.Streams.Range { + if err := publisher.checkTimeout(); err != nil { + publisher.Stop(err) + } + } + for publisher := range s.Waiting.Range { + // TODO: ? + //if publisher.Plugin != nil { + // if err := publisher.checkTimeout(); err != nil { + // publisher.Stop(err) + // s.createWait(publisher.StreamPath) + // } + //} + for sub := range publisher.SubscriberRange { + select { + case <-sub.TimeoutTimer.C: + sub.Stop(ErrSubscribeTimeout) + default: + } + } + } + }) + go s.pullTM.Run() + go s.pushTM.Run() + go s.recordTM.Run() + Servers.Add(s) return } -type DoneChan = <-chan struct{} - -func (s *Server) doneEventLoop(input chan DoneChan, output chan int) { - cases := []reflect.SelectCase{{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(input)}} - for { - switch chosen, rev, ok := reflect.Select(cases); chosen { - case 0: - if !ok { - return - } - cases = append(cases, reflect.SelectCase{Dir: reflect.SelectRecv, Chan: rev}) - default: - output <- chosen - 1 - cases = slices.Delete(cases, chosen, chosen+1) - } - } +func (s *Server) Call(callback func()) { + s.streamTM.Call(callback) } -// eventLoop powerful grateful graceful beautiful -func (s *Server) eventLoop() { - pulse := time.NewTicker(s.PulseInterval) - defer pulse.Stop() - pubChan := make(chan DoneChan, 10) - pubDoneChan := make(chan int, 10) - subChan := make(chan DoneChan, 10) - subDoneChan := make(chan int, 10) - defer close(pubChan) - defer close(subChan) - go s.doneEventLoop(pubChan, pubDoneChan) - go s.doneEventLoop(subChan, subDoneChan) - signalChan := make(chan os.Signal, 1) - signal.Notify(signalChan, os.Interrupt, syscall.SIGTERM) - defer signal.Stop(signalChan) +func (s *Server) Dispose() { + Servers.Remove(s) + s.config.HTTP.StopListen() + err := context.Cause(s) + s.streamTM.ShutDown(err) + s.pullTM.ShutDown(err) + s.pushTM.ShutDown(err) + s.recordTM.ShutDown(err) + s.pluginTM.ShutDown(err) + s.Warn("Server is done", "reason", err) +} + +func (s *Server) Run(ctx context.Context, conf any) (err error) { for { - select { - case <-signalChan: - for plugin := range s.Plugins.Range { - if plugin.Disabled { - continue - } - plugin.handler.OnExit() - } - case <-s.Done(): + s.runOption.ctx = ctx + s.runOption.conf = conf + if err = s.Start(); err != nil { return - case <-pulse.C: - for publisher := range s.Streams.Range { - if err := publisher.checkTimeout(); err != nil { - publisher.Stop(err) - } - } - for publisher := range s.Waiting.Range { - if publisher.Plugin != nil { - if err := publisher.checkTimeout(); err != nil { - publisher.Dispose(err) - s.createWait(publisher.StreamPath) - } - } - for sub := range publisher.SubscriberRange { - select { - case <-sub.TimeoutTimer.C: - sub.Stop(ErrSubscribeTimeout) - default: - } - } - } - case pubDone := <-pubDoneChan: - s.onUnpublish(s.Streams.Items[pubDone]) - case subDone := <-subDoneChan: - s.onUnsubscribe(s.Subscribers.Items[subDone]) - case event := <-s.eventChan: - switch v := event.(type) { - case *util.Promise[any]: - switch vv := v.Value.(type) { - case func(): - vv() - v.Fulfill(nil) - continue - case func() error: - v.Fulfill(vv()) - continue - case *Publisher: - err := s.OnPublish(vv) - if v.Fulfill(err); err != nil { - continue - } - event = vv - pubChan <- vv.Done() - case *Subscriber: - err := s.OnSubscribe(vv) - if v.Fulfill(err); err != nil { - continue - } - subChan <- vv.Done() - if !s.EnableSubEvent { - continue - } - event = v.Value - case *Puller: - if _, ok := s.Pulls.Get(vv.GetKey()); ok { - v.Fulfill(ErrStreamExist) - continue - } else { - err := s.OnPublish(&vv.Publisher) - v.Fulfill(err) - if err != nil { - continue - } - s.Pulls.Add(vv) - pubChan <- vv.Done() - event = v.Value - } - case *Pusher: - if _, ok := s.Pushs.Get(vv.GetKey()); ok { - v.Fulfill(ErrStreamExist) - continue - } else { - err := s.OnSubscribe(&vv.Subscriber) - v.Fulfill(err) - if err != nil { - continue - } - subChan <- vv.Done() - s.Pushs.Add(vv) - event = v.Value - } - case *Recorder: - if _, ok := s.Records.Get(vv.GetKey()); ok { - v.Fulfill(ErrStreamExist) - continue - } else { - err := s.OnSubscribe(&vv.Subscriber) - v.Fulfill(err) - if err != nil { - continue - } - subChan <- vv.Done() - s.Records.Add(vv) - event = v.Value - } - } - case slog.Handler: - s.LogHandler.Add(v) - } - for plugin := range s.Plugins.Range { - if plugin.Disabled { - continue - } - plugin.onEvent(event) - } } - } -} - -func (s *Server) onUnsubscribe(subscriber *Subscriber) { - s.Subscribers.Remove(subscriber) - s.Info("unsubscribe", "streamPath", subscriber.StreamPath, "reason", subscriber.StopReason()) - if subscriber.Closer != nil { - subscriber.Close() - } - for pusher := range s.Pushs.Range { - if &pusher.Subscriber == subscriber { - s.Pushs.Remove(pusher) - break + <-s.Done() + s.Dispose() + if err = context.Cause(s); err != ErrRestart { + return } + var server Server + server.ID = s.ID + server.Meta = s.Meta + server.OnAuthPubs = s.OnAuthPubs + server.OnAuthSubs = s.OnAuthSubs + server.DB = s.DB + *s = server } - if subscriber.Publisher != nil { - subscriber.Publisher.RemoveSubscriber(subscriber) - } -} - -func (s *Server) onUnpublish(publisher *Publisher) { - s.Streams.Remove(publisher) - if publisher.Subscribers.Length > 0 { - s.Waiting.Add(publisher) - } - s.Info("unpublish", "streamPath", publisher.StreamPath, "count", s.Streams.Length, "reason", publisher.StopReason()) - for subscriber := range publisher.SubscriberRange { - waitCloseTimeout := publisher.WaitCloseTimeout - if waitCloseTimeout == 0 { - waitCloseTimeout = subscriber.WaitTimeout - } - subscriber.TimeoutTimer.Reset(waitCloseTimeout) - } - if publisher.Closer != nil { - _ = publisher.Close() - } - s.Pulls.RemoveByKey(publisher.StreamPath) -} - -func (s *Server) OnPublish(publisher *Publisher) error { - if oldPublisher, ok := s.Streams.Get(publisher.StreamPath); ok { - if publisher.KickExist { - publisher.Warn("kick") - oldPublisher.Stop(ErrKick) - publisher.TakeOver(oldPublisher) - } else { - return ErrStreamExist - } - } - s.Streams.Set(publisher) - s.pidG++ - p := publisher.Plugin - publisher.ID = s.pidG - publisher.Logger = p.With("streamPath", publisher.StreamPath, "pubID", publisher.ID) - publisher.TimeoutTimer = time.NewTimer(p.config.PublishTimeout) - publisher.Start() - if waiting, ok := s.Waiting.Get(publisher.StreamPath); ok { - publisher.TakeOver(waiting) - s.Waiting.Remove(waiting) - } - for plugin := range s.Plugins.Range { - if plugin.Disabled { - continue - } - if remoteURL := plugin.GetCommonConf().CheckPush(publisher.StreamPath); remoteURL != "" { - if _, ok := plugin.handler.(IPusherPlugin); ok { - go plugin.Push(publisher.StreamPath, remoteURL) - } - } - if filePath := plugin.GetCommonConf().CheckRecord(publisher.StreamPath); filePath != "" { - if _, ok := plugin.handler.(IRecorderPlugin); ok { - go plugin.Record(publisher.StreamPath, filePath) - } - } - } - return nil } func (s *Server) createWait(streamPath string) *Publisher { newPublisher := &Publisher{} - s.pidG++ - newPublisher.ID = s.pidG - newPublisher.Logger = s.Logger.With("pubID", newPublisher.ID, "streamPath", streamPath) + newPublisher.Logger = s.Logger.With("streamPath", streamPath) s.Info("createWait") newPublisher.StreamPath = streamPath s.Waiting.Set(newPublisher) return newPublisher } -func (s *Server) OnSubscribe(subscriber *Subscriber) error { - s.sidG++ - subscriber.ID = s.sidG - subscriber.Logger = subscriber.Plugin.With("streamPath", subscriber.StreamPath, "subID", subscriber.ID) - subscriber.TimeoutTimer = time.NewTimer(subscriber.Plugin.config.Subscribe.WaitTimeout) - s.Subscribers.Add(subscriber) - subscriber.Info("subscribe") - if publisher, ok := s.Streams.Get(subscriber.StreamPath); ok { - publisher.AddSubscriber(subscriber) - } else if publisher, ok = s.Waiting.Get(subscriber.StreamPath); ok { - publisher.AddSubscriber(subscriber) - } else { - s.createWait(subscriber.StreamPath).AddSubscriber(subscriber) - for plugin := range s.Plugins.Range { - if plugin.Disabled { - continue - } - if remoteURL := plugin.GetCommonConf().Pull.CheckPullOnSub(subscriber.StreamPath); remoteURL != "" { - if _, ok := plugin.handler.(IPullerPlugin); ok { - go plugin.Pull(subscriber.StreamPath, remoteURL) - } - } - } - } - return nil -} - func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { if r.URL.Path == "/favicon.ico" { http.ServeFile(w, r, "favicon.ico") @@ -520,17 +318,17 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } -func (s *Server) Call(arg any) (result any, err error) { - promise := util.NewPromise(arg) - s.eventChan <- promise - <-promise.Done() - result = promise.Value - if err = context.Cause(promise.Context); err == util.ErrResolve { - err = nil - } - return -} - -func (s *Server) PostMessage(msg any) { - s.eventChan <- msg -} +//func (s *Server) Call(arg any) (result any, err error) { +// promise := util.NewPromise(arg) +// s.eventChan <- promise +// <-promise.Done() +// result = promise.Value +// if err = context.Cause(promise.Context); err == util.ErrResolve { +// err = nil +// } +// return +//} +// +//func (s *Server) PostMessage(msg any) { +// s.eventChan <- msg +//} diff --git a/subscriber.go b/subscriber.go index 9d16ed7..00bbcfc 100644 --- a/subscriber.go +++ b/subscriber.go @@ -3,10 +3,8 @@ package m7s import ( "context" "errors" - "io" - "net" + "log/slog" "net/url" - "os" "reflect" "runtime" "strings" @@ -19,46 +17,29 @@ import ( var AVFrameType = reflect.TypeOf((*AVFrame)(nil)) -type Owner struct { - Conn net.Conn - File *os.File - MetaData any - io.Closer -} - type PubSubBase struct { - Unit[int] - Owner + Task Plugin *Plugin StreamPath string Args url.Values TimeoutTimer *time.Timer + MetaData any } -func (p *PubSubBase) GetKey() int { - return p.ID -} - -func (ps *PubSubBase) Init(p *Plugin, streamPath string, conf any, options ...any) { - ps.Plugin = p - ctx := p.Context +func (ps *PubSubBase) Init(streamPath string, conf any, options ...any) { + ctx := ps.Plugin.Context + var logger *slog.Logger for _, option := range options { switch v := option.(type) { + case *slog.Logger: + logger = v case context.Context: ctx = v - case net.Conn: - ps.Conn = v - ps.Closer = v - case *os.File: - ps.File = v - ps.Closer = v - case io.Closer: - ps.Closer = v default: ps.MetaData = v } } - ps.Context, ps.CancelCauseFunc = context.WithCancelCause(ctx) + ps.Task.Init(ctx, logger) if u, err := url.Parse(streamPath); err == nil { ps.StreamPath, ps.Args = u.Path, u.Query() } @@ -80,8 +61,11 @@ func (ps *PubSubBase) Init(p *Plugin, streamPath string, conf any, options ...an c.ParseModifyFile(cc) } ps.StartTime = time.Now() + } +type SubscriberCollection = util.Collection[uint32, *Subscriber] + type Subscriber struct { PubSubBase config.Subscribe @@ -90,6 +74,57 @@ type Subscriber struct { VideoReader *AVRingReader } +func createSubscriber(p *Plugin, streamPath string, options ...any) *Subscriber { + subscriber := &Subscriber{Subscribe: p.config.Subscribe} + subscriber.ID = p.Server.streamTM.GetID() + subscriber.Plugin = p + subscriber.Executor = subscriber + subscriber.TimeoutTimer = time.NewTimer(subscriber.WaitTimeout) + var opt = []any{p.Logger.With("streamPath", streamPath, "sId", subscriber.ID)} + for _, option := range options { + switch v := option.(type) { + case func(*config.Subscribe): + v(&subscriber.Subscribe) + default: + opt = append(opt, option) + } + } + subscriber.Init(streamPath, &subscriber.Subscribe, opt...) + if subscriber.Subscribe.BufferTime > 0 { + subscriber.Subscribe.SubMode = SUBMODE_BUFFER + } + return subscriber +} + +func (s *Subscriber) Start() (err error) { + server := s.Plugin.Server + server.Subscribers.Add(s) + s.Info("subscribe") + if publisher, ok := server.Streams.Get(s.StreamPath); ok { + publisher.AddSubscriber(s) + } else if publisher, ok = server.Waiting.Get(s.StreamPath); ok { + publisher.AddSubscriber(s) + } else { + server.createWait(s.StreamPath).AddSubscriber(s) + for plugin := range server.Plugins.Range { + if remoteURL := plugin.GetCommonConf().Pull.CheckPullOnSub(s.StreamPath); remoteURL != "" { + if _, ok := plugin.handler.(IPullerPlugin); ok { + go plugin.Pull(s.StreamPath, remoteURL) + } + } + } + } + return +} + +func (s *Subscriber) Dispose() { + s.Plugin.Server.Subscribers.Remove(s) + s.Info("unsubscribe", "reason", s.StopReason()) + if s.Publisher != nil { + s.Publisher.RemoveSubscriber(s) + } +} + func (s *Subscriber) createAudioReader(dataType reflect.Type, startAudioTs time.Duration) (awi int) { if s.Publisher == nil || dataType == nil { return @@ -173,6 +208,7 @@ func PlayBlock0[A any, V any](s *Subscriber, handler SubscribeHandler[A, V]) (er awi := s.createAudioReader(a1, startAudioTs) vwi := s.createVideoReader(v1, startVideoTs) defer func() { + s.Stop(err) if s.AudioReader != nil { s.AudioReader.StopRead() }