refactor: task manager

This commit is contained in:
langhuihui
2024-08-07 18:13:45 +08:00
parent 63b2864e5f
commit 2c17d1da60
37 changed files with 1613 additions and 1471 deletions

30
api.go
View File

@@ -131,7 +131,7 @@ func (s *Server) getStreamInfo(pub *Publisher) (res *pb.StreamInfoResponse, err
}
func (s *Server) StreamInfo(ctx context.Context, req *pb.StreamSnapRequest) (res *pb.StreamInfoResponse, err error) {
s.Call(func() {
s.streamTM.Call(func() {
if pub, ok := s.Streams.Get(req.StreamPath); ok {
res, err = s.getStreamInfo(pub)
} else {
@@ -141,7 +141,7 @@ func (s *Server) StreamInfo(ctx context.Context, req *pb.StreamSnapRequest) (res
return
}
func (s *Server) GetSubscribers(ctx context.Context, req *pb.SubscribersRequest) (res *pb.SubscribersResponse, err error) {
s.Call(func() {
s.streamTM.Call(func() {
var subscribers []*pb.SubscriberSnapShot
for subscriber := range s.Subscribers.Range {
meta, _ := json.Marshal(subscriber.MetaData)
@@ -176,7 +176,7 @@ func (s *Server) GetSubscribers(ctx context.Context, req *pb.SubscribersRequest)
return
}
func (s *Server) AudioTrackSnap(ctx context.Context, req *pb.StreamSnapRequest) (res *pb.TrackSnapShotResponse, err error) {
s.Call(func() {
s.streamTM.Call(func() {
if pub, ok := s.Streams.Get(req.StreamPath); ok && pub.HasAudioTrack() {
res = &pb.TrackSnapShotResponse{}
for _, memlist := range pub.AudioTrack.Allocator.GetChildren() {
@@ -254,7 +254,7 @@ func (s *Server) api_VideoTrack_SSE(rw http.ResponseWriter, r *http.Request) {
}
func (s *Server) VideoTrackSnap(ctx context.Context, req *pb.StreamSnapRequest) (res *pb.TrackSnapShotResponse, err error) {
s.Call(func() {
s.streamTM.Call(func() {
if pub, ok := s.Streams.Get(req.StreamPath); ok && pub.HasVideoTrack() {
res = &pb.TrackSnapShotResponse{}
for _, memlist := range pub.VideoTrack.Allocator.GetChildren() {
@@ -304,15 +304,15 @@ func (s *Server) VideoTrackSnap(ctx context.Context, req *pb.StreamSnapRequest)
}
func (s *Server) Restart(ctx context.Context, req *pb.RequestWithId) (res *emptypb.Empty, err error) {
if Servers[req.Id] != nil {
Servers[req.Id].Stop(pkg.ErrRestart)
if s, ok := Servers.Get(req.Id); ok {
s.Stop(pkg.ErrRestart)
}
return empty, err
}
func (s *Server) Shutdown(ctx context.Context, req *pb.RequestWithId) (res *emptypb.Empty, err error) {
if Servers[req.Id] != nil {
Servers[req.Id].Stop(pkg.ErrStopFromAPI)
if s, ok := Servers.Get(req.Id); ok {
s.Stop(pkg.ErrStopFromAPI)
} else {
return nil, pkg.ErrNotFound
}
@@ -320,8 +320,8 @@ func (s *Server) Shutdown(ctx context.Context, req *pb.RequestWithId) (res *empt
}
func (s *Server) ChangeSubscribe(ctx context.Context, req *pb.ChangeSubscribeRequest) (res *pb.SuccessResponse, err error) {
s.Call(func() {
if subscriber, ok := s.Subscribers.Get(int(req.Id)); ok {
s.streamTM.Call(func() {
if subscriber, ok := s.Subscribers.Get(req.Id); ok {
if pub, ok := s.Streams.Get(req.StreamPath); ok {
subscriber.Publisher.RemoveSubscriber(subscriber)
subscriber.StreamPath = req.StreamPath
@@ -335,8 +335,8 @@ func (s *Server) ChangeSubscribe(ctx context.Context, req *pb.ChangeSubscribeReq
}
func (s *Server) StopSubscribe(ctx context.Context, req *pb.RequestWithId) (res *pb.SuccessResponse, err error) {
s.Call(func() {
if subscriber, ok := s.Subscribers.Get(int(req.Id)); ok {
s.streamTM.Call(func() {
if subscriber, ok := s.Subscribers.Get(req.Id); ok {
subscriber.Stop(errors.New("stop by api"))
} else {
err = pkg.ErrNotFound
@@ -347,7 +347,7 @@ func (s *Server) StopSubscribe(ctx context.Context, req *pb.RequestWithId) (res
// /api/stream/list
func (s *Server) StreamList(_ context.Context, req *pb.StreamListRequest) (res *pb.StreamListResponse, err error) {
s.Call(func() {
s.streamTM.Call(func() {
var streams []*pb.StreamInfoResponse
for publisher := range s.Streams.Range {
info, err := s.getStreamInfo(publisher)
@@ -362,7 +362,7 @@ func (s *Server) StreamList(_ context.Context, req *pb.StreamListRequest) (res *
}
func (s *Server) WaitList(context.Context, *emptypb.Empty) (res *pb.StreamWaitListResponse, err error) {
s.Call(func() {
s.streamTM.Call(func() {
res = &pb.StreamWaitListResponse{
List: make(map[string]int32),
}
@@ -381,7 +381,6 @@ func (s *Server) Api_Summary_SSE(rw http.ResponseWriter, r *http.Request) {
}
func (s *Server) Summary(context.Context, *emptypb.Empty) (res *pb.SummaryResponse, err error) {
s.Call(func() {
dur := time.Since(s.lastSummaryTime)
if dur < time.Second {
res = s.lastSummary
@@ -427,7 +426,6 @@ func (s *Server) Summary(context.Context, *emptypb.Empty) (res *pb.SummaryRespon
res.NetWork = netWorks
s.lastSummary = res
s.lastSummaryTime = time.Now()
})
return
}

View File

@@ -1,5 +1,7 @@
global:
loglevel: trace
http:
listenaddr: :8082
# enableauth: true
# tcp:
# listenaddr: :50051

View File

@@ -160,6 +160,7 @@ type Record struct {
EnableRegexp bool `desc:"是否启用正则表达式"` // 是否启用正则表达式
RecordList map[string]string
Fragment time.Duration `desc:"分片时长"` // 分片时长
Append bool `desc:"是否追加录制"` // 是否追加录制
}
func (p *Record) GetRecordConfig() *Record {

View File

@@ -11,10 +11,13 @@ var (
ErrPublishTimeout = errors.New("publish timeout")
ErrPublishIdleTimeout = errors.New("publish idle timeout")
ErrPublishDelayCloseTimeout = errors.New("publish delay close timeout")
ErrPushRemoteURLExist = errors.New("push remote url exist")
ErrSubscribeTimeout = errors.New("subscribe timeout")
ErrRestart = errors.New("restart")
ErrInterrupt = errors.New("interrupt")
ErrUnsupportCodec = errors.New("unsupport codec")
ErrMuted = errors.New("muted")
ErrorLost = errors.New("lost")
ErrLost = errors.New("lost")
ErrRetryRunOut = errors.New("retry run out")
ErrRecordSamePath = errors.New("record same path")
)

203
pkg/task.go Normal file
View File

@@ -0,0 +1,203 @@
package pkg
import (
"context"
"io"
"log/slog"
"m7s.live/m7s/v5/pkg/util"
"reflect"
"slices"
"sync/atomic"
"time"
)
const TraceLevel = slog.Level(-8)
type TaskExecutor interface {
Start() error
Dispose()
}
type Task struct {
ID uint32
StartTime time.Time
*slog.Logger
context.Context
context.CancelCauseFunc
Executor TaskExecutor
started *util.Promise
}
func (task *Task) GetTask() *Task {
return task
}
func (task *Task) GetKey() uint32 {
return task.ID
}
func (task *Task) Begin() (err error) {
task.StartTime = time.Now()
err = task.Executor.Start()
task.started.Fulfill(err)
return
}
func (task *Task) WaitStarted() error {
return task.started.Await()
}
func (task *Task) Trace(msg string, fields ...any) {
task.Log(task.Context, TraceLevel, msg, fields...)
}
func (task *Task) IsStopped() bool {
return task.Err() != nil
}
func (task *Task) StopReason() error {
return context.Cause(task.Context)
}
func (task *Task) Stop(err error) {
if task.CancelCauseFunc != nil && !task.IsStopped() {
task.Info("stop", "reason", err.Error())
task.CancelCauseFunc(err)
}
}
func (task *Task) Init(ctx context.Context, logger *slog.Logger) {
task.Logger = logger
task.Context, task.CancelCauseFunc = context.WithCancelCause(ctx)
task.started = util.NewPromise(task.Context)
}
type CallBackTaskExecutor func()
func (call CallBackTaskExecutor) Start() error {
call()
return io.EOF
}
func (call CallBackTaskExecutor) Dispose() {
// nothing to do, never called
}
type TaskManager struct {
shutdown *util.Promise
stopReason error
start chan *Task
Tasks []*Task
idG atomic.Uint32
}
func NewTaskManager() *TaskManager {
return &TaskManager{
shutdown: util.NewPromise(context.TODO()),
start: make(chan *Task, 10),
}
}
func (t *TaskManager) Add(task *Task) {
t.start <- task
}
func (t *TaskManager) Call(callback CallBackTaskExecutor) {
var tmpTask Task
tmpTask.Init(context.TODO(), nil)
tmpTask.Executor = callback
_ = t.Start(&tmpTask)
}
func (t *TaskManager) Start(task *Task) error {
t.start <- task
return task.WaitStarted()
}
func (t *TaskManager) GetID() uint32 {
return t.idG.Add(1)
}
// Run task Start and Dispose in this goroutine
func (t *TaskManager) Run(extra ...any) {
cases := []reflect.SelectCase{{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(t.start)}}
extraLen := len(extra) / 2
var callbacks []reflect.Value
for i := range extraLen {
cases = append(cases, reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(extra[i*2])})
callbacks = append(callbacks, reflect.ValueOf(extra[i*2+1]))
}
defer func() {
cases = slices.Delete(cases, 0, 1+extraLen)
for len(cases) > 0 {
chosen, _, _ := reflect.Select(cases)
task := t.Tasks[chosen]
task.Executor.Dispose()
t.Tasks = slices.Delete(t.Tasks, chosen, chosen+1)
cases = slices.Delete(cases, chosen, chosen+1)
}
t.shutdown.Fulfill(t.stopReason)
}()
for {
if chosen, rev, ok := reflect.Select(cases); chosen == 0 {
if !ok {
return
}
task := rev.Interface().(*Task)
if err := task.Begin(); err == nil {
t.Tasks = append(t.Tasks, task)
cases = append(cases, reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(task.Done())})
} else {
task.Stop(err)
}
} else if chosen <= extraLen {
callbacks[chosen-1].Call([]reflect.Value{rev})
} else {
taskIndex := chosen - 1 - extraLen
task := t.Tasks[taskIndex]
task.Executor.Dispose()
t.Tasks = slices.Delete(t.Tasks, taskIndex, taskIndex+1)
cases = slices.Delete(cases, chosen, chosen+1)
}
}
}
// Run task Start and Dispose in another goroutine
//func (t *TaskManager) Run() {
// cases := []reflect.SelectCase{{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(t.Start)}}
// defer func() {
// cases = slices.Delete(cases, 0, 1)
// for len(cases) > 0 {
// chosen, _, _ := reflect.Select(cases)
// t.Done <- t.Tasks[chosen]
// t.Tasks = slices.Delete(t.Tasks, chosen, chosen+1)
// cases = slices.Delete(cases, chosen, chosen+1)
// }
// close(t.Done)
// }()
// for {
// if chosen, rev, ok := reflect.Select(cases); chosen == 0 {
// if !ok {
// return
// }
// task := rev.Interface().(*Task)
// t.Tasks = append(t.Tasks, task)
// cases = append(cases, reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(task.Done())})
// } else {
// t.Done <- t.Tasks[chosen-1]
// t.Tasks = slices.Delete(t.Tasks, chosen-1, chosen)
// cases = slices.Delete(cases, chosen, chosen+1)
// }
// }
//}
// ShutDown wait all task dispose
func (t *TaskManager) ShutDown(err error) {
t.Stop(err)
_ = t.shutdown.Await()
}
func (t *TaskManager) Stop(err error) {
t.stopReason = err
close(t.start)
}

View File

@@ -14,7 +14,7 @@ import (
type (
Track struct {
*slog.Logger
ready *util.Promise[struct{}]
ready *util.Promise
FrameType reflect.Type
bytesIn int
frameCount int
@@ -55,7 +55,7 @@ func NewAVTrack(args ...any) (t *AVTrack) {
t.RingWriter = NewRingWriter(v.RingSize)
t.BufferRange[0] = v.BufferTime
t.RingWriter.SLogger = t.Logger
case *util.Promise[struct{}]:
case *util.Promise:
t.ready = v
}
}
@@ -112,8 +112,7 @@ func (t *Track) IsReady() bool {
}
func (t *Track) WaitReady() error {
_, err := t.ready.Await()
return err
return t.ready.Await()
}
func (t *Track) Trace(msg string, fields ...any) {

View File

@@ -1,34 +0,0 @@
package pkg
import (
"context"
"log/slog"
"time"
)
const TraceLevel = slog.Level(-8)
type Unit[T any] struct {
ID T
StartTime time.Time
*slog.Logger
context.Context
context.CancelCauseFunc
}
func (unit *Unit[T]) Trace(msg string, fields ...any) {
unit.Log(unit.Context, TraceLevel, msg, fields...)
}
func (unit *Unit[T]) IsStopped() bool {
return unit.StopReason() != nil
}
func (unit *Unit[T]) StopReason() error {
return context.Cause(unit.Context)
}
func (unit *Unit[T]) Stop(err error) {
unit.Info("stop", "reason", err.Error())
unit.CancelCauseFunc(err)
}

View File

@@ -6,21 +6,21 @@ import (
"time"
)
type Promise[T any] struct {
type Promise struct {
context.Context
context.CancelCauseFunc
Value T
timer *time.Timer
}
func NewPromise[T any](v T) *Promise[T] {
p := &Promise[T]{Value: v}
p.Context, p.CancelCauseFunc = context.WithCancelCause(context.Background())
func NewPromise(ctx context.Context) *Promise {
p := &Promise{}
p.Context, p.CancelCauseFunc = context.WithCancelCause(ctx)
return p
}
func NewPromiseWithTimeout[T any](v T, timeout time.Duration) *Promise[T] {
p := &Promise[T]{Value: v}
p.Context, p.CancelCauseFunc = context.WithCancelCause(context.Background())
func NewPromiseWithTimeout(ctx context.Context, timeout time.Duration) *Promise {
p := &Promise{}
p.Context, p.CancelCauseFunc = context.WithCancelCause(ctx)
p.timer = time.AfterFunc(timeout, func() {
p.CancelCauseFunc(ErrTimeout)
})
@@ -30,27 +30,30 @@ func NewPromiseWithTimeout[T any](v T, timeout time.Duration) *Promise[T] {
var ErrResolve = errors.New("promise resolved")
var ErrTimeout = errors.New("promise timeout")
func (p *Promise[T]) Resolve(v T) {
p.Value = v
p.CancelCauseFunc(ErrResolve)
func (p *Promise) Resolve() {
p.Fulfill(nil)
}
func (p *Promise[T]) Await() (T, error) {
func (p *Promise) Reject(err error) {
p.Fulfill(err)
}
func (p *Promise) Await() (err error) {
<-p.Done()
err := context.Cause(p.Context)
err = context.Cause(p.Context)
if errors.Is(err, ErrResolve) {
err = nil
}
return p.Value, err
return
}
func (p *Promise[T]) Fulfill(err error) {
func (p *Promise) Fulfill(err error) {
if p.timer != nil {
p.timer.Stop()
}
p.CancelCauseFunc(Conditoinal(err == nil, ErrResolve, err))
}
func (p *Promise[T]) IsPending() bool {
func (p *Promise) IsPending() bool {
return context.Cause(p.Context) == nil
}

303
plugin.go
View File

@@ -2,9 +2,14 @@ package m7s
import (
"context"
gatewayRuntime "github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
myip "github.com/husanpao/ip"
"google.golang.org/grpc"
"gopkg.in/yaml.v3"
"gorm.io/gorm"
"log/slog"
. "m7s.live/m7s/v5/pkg"
"m7s.live/m7s/v5/pkg/config"
"m7s.live/m7s/v5/pkg/db"
"net"
"net/http"
@@ -13,13 +18,6 @@ import (
"reflect"
"runtime"
"strings"
gatewayRuntime "github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
"google.golang.org/grpc"
"gopkg.in/yaml.v3"
. "m7s.live/m7s/v5/pkg"
"m7s.live/m7s/v5/pkg/config"
"m7s.live/m7s/v5/pkg/util"
)
type DefaultYaml string
@@ -33,17 +31,17 @@ type PluginMeta struct {
RegisterGRPCHandler func(context.Context, *gatewayRuntime.ServeMux, *grpc.ClientConn) error
}
func (plugin *PluginMeta) Init(s *Server, userConfig map[string]any) {
func (plugin *PluginMeta) Init(s *Server, userConfig map[string]any) (p *Plugin) {
instance, ok := reflect.New(plugin.Type).Interface().(IPlugin)
if !ok {
panic("plugin must implement IPlugin")
}
p := reflect.ValueOf(instance).Elem().FieldByName("Plugin").Addr().Interface().(*Plugin)
p = reflect.ValueOf(instance).Elem().FieldByName("Plugin").Addr().Interface().(*Plugin)
p.handler = instance
p.Meta = plugin
p.Executor = instance
p.Server = s
p.Logger = s.Logger.With("plugin", plugin.Name)
p.Context, p.CancelCauseFunc = context.WithCancelCause(s.Context)
p.Task.Init(s.Context, s.Logger.With("plugin", plugin.Name))
upperName := strings.ToUpper(plugin.Name)
if os.Getenv(upperName+"_ENABLE") == "false" {
p.Disabled = true
@@ -93,29 +91,12 @@ func (plugin *PluginMeta) Init(s *Server, userConfig map[string]any) {
if err != nil {
s.Error("failed to connect database", "error", err, "dsn", s.config.DSN, "type", s.config.DBType)
p.Disabled = true
return
}
}
}
err = instance.OnInit()
if err != nil {
p.Error("init", "error", err)
p.Stop(err)
return
}
if plugin.ServiceDesc != nil && s.grpcServer != nil {
s.grpcServer.RegisterService(plugin.ServiceDesc, instance)
if plugin.RegisterGRPCHandler != nil {
if err = plugin.RegisterGRPCHandler(p.Context, s.config.HTTP.GetGRPCMux(), s.grpcClientConn); err != nil {
p.Error("init", "error", err)
p.Stop(err)
} else {
p.Info("grpc handler registered")
}
}
}
s.Plugins.Add(p)
p.Start()
return
}
type iPlugin interface {
@@ -123,8 +104,8 @@ type iPlugin interface {
}
type IPlugin interface {
TaskExecutor
OnInit() error
OnEvent(any)
OnExit()
}
@@ -133,16 +114,16 @@ type IRegisterHandler interface {
}
type IPullerPlugin interface {
NewPullHandler() PullHandler
DoPull(*PullContext) error
GetPullableList() []string
}
type IPusherPlugin interface {
NewPushHandler() PushHandler
DoPush(*PushContext) error
}
type IRecorderPlugin interface {
NewRecordHandler() RecordHandler
DoRecord(*RecordContext) error
}
type ITCPPlugin interface {
@@ -186,7 +167,7 @@ func InstallPlugin[C iPlugin](options ...any) error {
}
type Plugin struct {
Unit[int]
Task
Disabled bool
Meta *PluginMeta
config config.Common
@@ -252,13 +233,36 @@ func (p *Plugin) assign() {
p.registerHandler(handlerMap)
}
func (p *Plugin) Stop(err error) {
p.Unit.Stop(err)
func (p *Plugin) Start() (err error) {
s := p.Server
err = p.handler.OnInit()
if err != nil {
p.Error("init", "error", err)
return
}
if p.Meta.ServiceDesc != nil && s.grpcServer != nil {
s.grpcServer.RegisterService(p.Meta.ServiceDesc, p.handler)
if p.Meta.RegisterGRPCHandler != nil {
if err = p.Meta.RegisterGRPCHandler(p.Context, s.config.HTTP.GetGRPCMux(), s.grpcClientConn); err != nil {
p.Error("init", "error", err)
return
} else {
p.Info("grpc handler registered")
}
}
}
s.Plugins.Add(p)
p.listen()
return
}
func (p *Plugin) Dispose() {
p.Server.Plugins.Remove(p)
p.config.HTTP.StopListen()
p.config.TCP.StopListen()
}
func (p *Plugin) Start() {
func (p *Plugin) listen() {
httpConf := &p.config.HTTP
if httpConf.ListenAddrTLS != "" && (httpConf.ListenAddrTLS != p.Server.config.HTTP.ListenAddrTLS) {
p.Info("https listen at ", "addr", httpConf.ListenAddrTLS)
@@ -272,13 +276,9 @@ func (p *Plugin) Start() {
p.Stop(httpConf.Listen())
}()
}
if tcphandler, ok := p.handler.(ITCPPlugin); ok {
tcpConf := &p.config.TCP
tcphandler, ok := p.handler.(ITCPPlugin)
if !ok {
tcphandler = p
}
if tcpConf.ListenAddr != "" && tcpConf.AutoListen {
p.Info("listen tcp", "addr", tcpConf.ListenAddr)
go func() {
@@ -299,13 +299,10 @@ func (p *Plugin) Start() {
}
}()
}
udpConf := &p.config.UDP
udpHandler, ok := p.handler.(IUDPPlugin)
if !ok {
udpHandler = p
}
if udpHandler, ok := p.handler.(IUDPPlugin); ok {
udpConf := &p.config.UDP
if udpConf.ListenAddr != "" && udpConf.AutoListen {
p.Info("listen udp", "addr", udpConf.ListenAddr)
go func() {
@@ -317,6 +314,7 @@ func (p *Plugin) Start() {
}()
}
}
}
func (p *Plugin) OnInit() error {
@@ -327,140 +325,76 @@ func (p *Plugin) OnExit() {
}
func (p *Plugin) onEvent(event any) {
switch v := event.(type) {
case *Publisher:
if h, ok := p.handler.(interface{ OnPublish(*Publisher) }); ok {
h.OnPublish(v)
}
case *Puller:
if h, ok := p.handler.(interface{ OnPull(*Puller) }); ok {
h.OnPull(v)
}
}
p.handler.OnEvent(event)
}
func (p *Plugin) OnEvent(event any) {
}
func (p *Plugin) OnTCPConnect(conn *net.TCPConn) {
p.handler.OnEvent(conn)
}
func (p *Plugin) OnUDPConnect(conn *net.UDPConn) {
p.handler.OnEvent(conn)
}
func (p *Plugin) Publish(streamPath string, options ...any) (publisher *Publisher, err error) {
publisher = &Publisher{Publish: p.config.Publish}
publisher = createPublisher(p, streamPath, options...)
if p.config.EnableAuth {
if onAuthPub, ok := p.Server.OnAuthPubs[p.Meta.Name]; ok {
authPromise := util.NewPromise(publisher)
onAuthPub(authPromise)
if _, err = authPromise.Await(); err != nil {
if err = onAuthPub(publisher).Await(); err != nil {
p.Warn("auth failed", "error", err)
return
}
}
}
for _, option := range options {
switch v := option.(type) {
case func(*config.Publish):
v(&publisher.Publish)
}
}
publisher.Init(p, streamPath, &publisher.Publish, options...)
_, err = p.Server.Call(publisher)
return
}
func (p *Plugin) Pull(streamPath string, url string, options ...any) (puller *Puller, err error) {
puller = &Puller{Pull: p.config.Pull}
puller.Client.Proxy = p.config.Pull.Proxy
puller.Client.RemoteURL = url
puller.Client.PubSubBase = &puller.PubSubBase
puller.Publish = p.config.Publish
puller.PublishTimeout = 0
puller.StreamPath = streamPath
var pullHandler PullHandler
for _, option := range options {
switch v := option.(type) {
case PullHandler:
pullHandler = v
}
}
puller.Init(p, streamPath, &puller.Publish, options...)
if _, err = p.Server.Call(puller); err != nil {
return
}
if v, ok := p.handler.(IPullerPlugin); pullHandler == nil && ok {
pullHandler = v.NewPullHandler()
}
if pullHandler != nil {
err = puller.Start(pullHandler)
}
return
}
func (p *Plugin) Record(streamPath string, filePath string, options ...any) (recorder *Recorder, err error) {
recorder = &Recorder{
Record: p.config.Record,
}
if err = os.MkdirAll(filepath.Dir(filePath), 0755); err != nil {
return
}
recorder.StreamPath = streamPath
recorder.Subscribe = p.config.Subscribe
if recorder.File, err = os.OpenFile(filePath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0666); err != nil {
return
}
defer func() {
err = recorder.File.Close()
if info, err := recorder.File.Stat(); err == nil && info.Size() == 0 {
os.Remove(recorder.File.Name())
}
}()
recorder.Init(p, streamPath, &recorder.Subscribe, options...)
if _, err = p.Server.Call(recorder); err != nil {
return
}
recorder.Publisher.WaitTrack()
var recordHandler RecordHandler
if v, ok := p.handler.(IRecorderPlugin); recordHandler == nil && ok {
recordHandler = v.NewRecordHandler()
}
if recordHandler != nil {
err = recorder.Start(recordHandler)
}
err = p.Server.streamTM.Start(&publisher.Task)
return
}
func (p *Plugin) Subscribe(streamPath string, options ...any) (subscriber *Subscriber, err error) {
subscriber = &Subscriber{Subscribe: p.config.Subscribe}
subscriber = createSubscriber(p, streamPath, options...)
if p.config.EnableAuth {
if onAuthSub, ok := p.Server.OnAuthSubs[p.Meta.Name]; ok {
authPromise := util.NewPromise(subscriber)
onAuthSub(authPromise)
if _, err = authPromise.Await(); err != nil {
if err = onAuthSub(subscriber).Await(); err != nil {
p.Warn("auth failed", "error", err)
return
}
}
}
for _, option := range options {
switch v := option.(type) {
case func(*config.Subscribe):
v(&subscriber.Subscribe)
err = p.Server.streamTM.Start(&subscriber.Task)
err = subscriber.Publisher.WaitTrack()
return
}
func (p *Plugin) Pull(streamPath string, url string, options ...any) (puller *PullContext, err error) {
puller = createPullContext(p, streamPath, url, options...)
if err = p.Server.pullTM.Start(&puller.Task); err != nil {
return
}
if pullPlugin, ok := p.handler.(IPullerPlugin); ok {
puller.Run(pullPlugin.DoPull)
}
subscriber.Init(p, streamPath, &subscriber.Subscribe, options...)
if subscriber.Subscribe.BufferTime > 0 {
subscriber.Subscribe.SubMode = SUBMODE_BUFFER
return
}
func (p *Plugin) Push(streamPath string, url string, options ...any) (pusher *PushContext, err error) {
pusher = createPushContext(p, streamPath, url, options...)
if err = p.Server.pushTM.Start(&pusher.Task); err != nil {
return
}
if pushPlugin, ok := p.handler.(IPusherPlugin); ok {
pusher.Run(pushPlugin.DoPush)
}
return
}
func (p *Plugin) Record(streamPath string, filePath string, options ...any) (recorder *RecordContext, err error) {
recorder = createRecoder(p, streamPath, filePath, options...)
dir := filePath
if filepath.Ext(filePath) != "" {
dir = filepath.Dir(filePath)
}
if err = os.MkdirAll(dir, 0755); err != nil {
return
}
recorder.Subscriber, err = p.Subscribe(streamPath, p.config.Subscribe)
if err != nil {
return
}
if err = p.Server.recordTM.Start(&recorder.Task); err != nil {
return
}
if recordPlugin, ok := p.handler.(IRecorderPlugin); ok {
recorder.Run(recordPlugin.DoRecord)
}
_, err = p.Server.Call(subscriber)
subscriber.Publisher.WaitTrack()
return
}
@@ -487,34 +421,6 @@ func (p *Plugin) registerHandler(handlers map[string]http.HandlerFunc) {
}
}
func (p *Plugin) Push(streamPath string, url string, options ...any) (pusher *Pusher, err error) {
pusher = &Pusher{Push: p.config.Push}
pusher.Client.PubSubBase = &pusher.PubSubBase
pusher.Client.Proxy = p.config.Push.Proxy
pusher.Client.RemoteURL = url
pusher.Subscribe = p.config.Subscribe
pusher.StreamPath = streamPath
var pushHandler PushHandler
for _, option := range options {
switch v := option.(type) {
case PushHandler:
pushHandler = v
}
}
pusher.Init(p, streamPath, &pusher.Subscribe, options...)
if _, err = p.Server.Call(pusher); err != nil {
return
}
pusher.Publisher.WaitTrack()
if v, ok := p.handler.(IPusherPlugin); pushHandler == nil && ok {
pushHandler = v.NewPushHandler()
}
if pushHandler != nil {
err = pusher.Start(pushHandler)
}
return
}
func (p *Plugin) logHandler(handler http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
p.Debug("visit", "path", r.URL.String(), "remote", r.RemoteAddr)
@@ -546,28 +452,21 @@ func (p *Plugin) AddLogHandler(handler slog.Handler) {
p.Server.LogHandler.Add(handler)
}
func (p *Plugin) PostToServer(event any) {
if p.Server.eventChan == nil {
panic("eventChan is nil")
}
p.Server.PostMessage(event)
}
func (p *Plugin) SaveConfig() (err error) {
_, err = p.Server.Call(func() error {
p.Server.pluginTM.Call(func() {
if p.Modify == nil {
os.Remove(p.settingPath())
return nil
return
}
var file *os.File
if file, err = os.OpenFile(p.settingPath(), os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0666); err != nil {
return
}
file, err := os.OpenFile(p.settingPath(), os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0666)
if err == nil {
defer file.Close()
err = yaml.NewEncoder(file).Encode(p.Modify)
}
})
if err == nil {
p.Info("config saved")
}
return err
})
return
}

View File

@@ -27,10 +27,14 @@ func (p *FLVPlugin) OnInit() error {
return nil
}
var _ = m7s.InstallPlugin[FLVPlugin](defaultConfig, NewPullHandler)
var _ = m7s.InstallPlugin[FLVPlugin](defaultConfig)
func (p *FLVPlugin) NewRecordHandler() m7s.RecordHandler {
return &Recorder{}
func (p *FLVPlugin) DoPull(pull *m7s.PullContext) error {
return PullFLV(pull)
}
func (p *FLVPlugin) DoRecord(ctx *m7s.RecordContext) error {
return RecordFlv(ctx)
}
func (p *FLVPlugin) WriteFlvHeader(sub *m7s.Subscriber) (flv net.Buffers) {
@@ -95,7 +99,6 @@ func (p *FLVPlugin) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if hijacker, ok := w.(http.Hijacker); ok && wto > 0 {
conn, _, _ := hijacker.Hijack()
conn.SetWriteDeadline(time.Now().Add(wto))
sub.Closer = conn
gotFlvTag = func(flv net.Buffers) (err error) {
conn.SetWriteDeadline(time.Now().Add(wto))
_, err = flv.WriteTo(conn)

View File

@@ -1,8 +1,10 @@
package flv
import (
"bufio"
"io"
"m7s.live/m7s/v5/pkg/util"
rtmp "m7s.live/m7s/v5/plugin/rtmp/pkg"
"net"
)
@@ -13,6 +15,8 @@ const (
FLV_TAG_TYPE_SCRIPT = 0x12
)
var FLVHead = []byte{'F', 'L', 'V', 0x01, 0x05, 0, 0, 0, 9, 0, 0, 0, 0}
func AVCC2FLV(t byte, ts uint32, avcc ...[]byte) (flv net.Buffers) {
b := util.Buffer(make([]byte, 0, 15))
b.WriteByte(t)
@@ -30,8 +34,33 @@ func WriteFLVTagHead(t uint8, ts, dataSize uint32, b []byte) {
b[4], b[5], b[6], b[7] = byte(ts>>16), byte(ts>>8), byte(ts), byte(ts>>24)
}
func WriteFLVTag(w io.Writer, t byte, timestamp uint32, payload []byte) (err error) {
buffers := AVCC2FLV(t, timestamp, payload)
func WriteFLVTag(w io.Writer, t byte, timestamp uint32, payload ...[]byte) (err error) {
buffers := AVCC2FLV(t, timestamp, payload...)
_, err = buffers.WriteTo(w)
return
}
func ReadMetaData(reader io.Reader) (metaData rtmp.EcmaArray, err error) {
r := bufio.NewReader(reader)
_, err = r.Discard(13)
tagHead := make(util.Buffer, 11)
_, err = io.ReadFull(r, tagHead)
if err != nil {
return
}
tmp := tagHead
t := tmp.ReadByte()
dataLen := tmp.ReadUint24()
_, err = r.Discard(4)
if t == FLV_TAG_TYPE_SCRIPT {
data := make([]byte, dataLen+4)
_, err = io.ReadFull(reader, data)
amf := &rtmp.AMF{
Buffer: util.Buffer(data[1+2+len("onMetaData") : len(data)-4]),
}
var obj any
obj, err = amf.Unmarshal()
metaData = obj.(rtmp.EcmaArray)
}
return
}

View File

@@ -13,29 +13,14 @@ import (
rtmp "m7s.live/m7s/v5/plugin/rtmp/pkg"
)
type FLVPuller struct {
*util.BufReader
*util.ScalableMemoryAllocator
hasAudio bool
hasVideo bool
absTS uint32 //绝对时间戳
}
func NewFLVPuller() *FLVPuller {
return &FLVPuller{
ScalableMemoryAllocator: util.NewScalableMemoryAllocator(1 << 10),
}
}
func NewPullHandler() m7s.PullHandler {
return NewFLVPuller()
}
func (puller *FLVPuller) Connect(p *m7s.Client) (err error) {
func PullFLV(p *m7s.PullContext) (err error) {
var reader *util.BufReader
var hasAudio, hasVideo bool
var absTS uint32
if strings.HasPrefix(p.RemoteURL, "http") {
var res *http.Response
client := http.DefaultClient
if proxyConf := p.Proxy; proxyConf != "" {
if proxyConf := p.ConnectProxy; proxyConf != "" {
proxy, err := url.Parse(proxyConf)
if err != nil {
return err
@@ -47,19 +32,19 @@ func (puller *FLVPuller) Connect(p *m7s.Client) (err error) {
if res.StatusCode != http.StatusOK {
return io.EOF
}
p.Closer = res.Body
puller.BufReader = util.NewBufReader(res.Body)
defer res.Body.Close()
reader = util.NewBufReader(res.Body)
}
} else {
var res *os.File
if res, err = os.Open(p.RemoteURL); err == nil {
p.Closer = res
puller.BufReader = util.NewBufReader(res)
defer res.Close()
reader = util.NewBufReader(res)
}
}
if err == nil {
var head util.Memory
head, err = puller.BufReader.ReadBytes(13)
head, err = reader.ReadBytes(13)
if err == nil {
var flvHead [3]byte
var version, flag byte
@@ -68,37 +53,35 @@ func (puller *FLVPuller) Connect(p *m7s.Client) (err error) {
if flvHead != [...]byte{'F', 'L', 'V'} {
err = errors.New("not flv file")
} else {
puller.hasAudio = flag&0x04 != 0
puller.hasVideo = flag&0x01 != 0
hasAudio = flag&0x04 != 0
hasVideo = flag&0x01 != 0
}
}
}
return
}
func (puller *FLVPuller) Pull(p *m7s.Puller) (err error) {
var startTs uint32
pubConf := p.GetPublishConfig()
if !puller.hasAudio {
pubConf := p.Publisher.GetPublishConfig()
if !hasAudio {
pubConf.PubAudio = false
}
if !puller.hasVideo {
if !hasVideo {
pubConf.PubVideo = false
}
for offsetTs := puller.absTS; err == nil; _, err = puller.ReadBE(4) {
t, err := puller.ReadByte()
allocator := util.NewScalableMemoryAllocator(1 << 10)
for offsetTs := absTS; err == nil; _, err = reader.ReadBE(4) {
t, err := reader.ReadByte()
if err != nil {
return err
}
dataSize, err := puller.ReadBE32(3)
dataSize, err := reader.ReadBE32(3)
if err != nil {
return err
}
timestamp, err := puller.ReadBE32(3)
timestamp, err := reader.ReadBE32(3)
if err != nil {
return err
}
h, err := puller.ReadByte()
h, err := reader.ReadByte()
if err != nil {
return err
}
@@ -106,28 +89,28 @@ func (puller *FLVPuller) Pull(p *m7s.Puller) (err error) {
if startTs == 0 {
startTs = timestamp
}
if _, err = puller.ReadBE(3); err != nil { // stream id always 0
if _, err = reader.ReadBE(3); err != nil { // stream id always 0
return err
}
var frame rtmp.RTMPData
switch ds := int(dataSize); t {
case FLV_TAG_TYPE_AUDIO, FLV_TAG_TYPE_VIDEO:
frame.SetAllocator(puller.ScalableMemoryAllocator)
err = puller.ReadNto(ds, frame.NextN(ds))
frame.SetAllocator(allocator)
err = reader.ReadNto(ds, frame.NextN(ds))
default:
err = puller.Skip(ds)
err = reader.Skip(ds)
}
if err != nil {
return err
}
puller.absTS = offsetTs + (timestamp - startTs)
frame.Timestamp = puller.absTS
absTS = offsetTs + (timestamp - startTs)
frame.Timestamp = absTS
//fmt.Println(t, offsetTs, timestamp, startTs, puller.absTS)
switch t {
case FLV_TAG_TYPE_AUDIO:
p.WriteAudio(frame.WrapAudio())
err = p.Publisher.WriteAudio(frame.WrapAudio())
case FLV_TAG_TYPE_VIDEO:
p.WriteVideo(frame.WrapVideo())
err = p.Publisher.WriteVideo(frame.WrapVideo())
case FLV_TAG_TYPE_SCRIPT:
p.Info("script")
}

View File

@@ -7,28 +7,29 @@ import (
"m7s.live/m7s/v5/pkg/util"
rtmp "m7s.live/m7s/v5/plugin/rtmp/pkg"
"os"
"slices"
"time"
)
type Recorder struct {
*m7s.Subscriber
filepositions []uint64
times []float64
Offset int64
duration int64
}
func (r *Recorder) Record(recorder *m7s.Recorder) (err error) {
func RecordFlv(ctx *m7s.RecordContext) (err error) {
var file *os.File
var filepositions []uint64
var times []float64
var offset int64
var duration int64
if file, err = os.OpenFile(ctx.FilePath, os.O_CREATE|os.O_RDWR|util.Conditoinal(ctx.Append, os.O_APPEND, os.O_TRUNC), 0666); err != nil {
return
}
func (r *Recorder) Close() {
}
func (r *Recorder) writeMetaData(file util.ReadWriteSeekCloser, duration int64) {
defer file.Close()
at, vt := r.AudioReader, r.VideoReader
hasAudio, hasVideo := at != nil, vt != nil
}
suber := ctx.Subscriber
ar, vr := suber.AudioReader, suber.VideoReader
hasAudio, hasVideo := ar != nil, vr != nil
writeMetaTag := func() {
defer func() {
err = file.Close()
if info, err := file.Stat(); err == nil && info.Size() == 0 {
os.Remove(file.Name())
}
}()
var amf rtmp.AMF
metaData := rtmp.EcmaArray{
"MetaDataCreator": "m7s/" + m7s.Version,
@@ -37,12 +38,12 @@ func (r *Recorder) writeMetaData(file util.ReadWriteSeekCloser, duration int64)
"hasMatadata": true,
"canSeekToEnd": true,
"duration": float64(duration) / 1000,
"hasKeyFrames": len(r.filepositions) > 0,
"hasKeyFrames": len(filepositions) > 0,
"filesize": 0,
}
var flags byte
if hasAudio {
ctx := at.Track.ICodecCtx.GetBase().(pkg.IAudioCodecCtx)
ctx := ar.Track.ICodecCtx.GetBase().(pkg.IAudioCodecCtx)
flags |= (1 << 2)
metaData["audiocodecid"] = int(rtmp.ParseAudioCodec(ctx.FourCC()))
metaData["audiosamplerate"] = ctx.GetSampleRate()
@@ -50,47 +51,47 @@ func (r *Recorder) writeMetaData(file util.ReadWriteSeekCloser, duration int64)
metaData["stereo"] = ctx.GetChannels() == 2
}
if hasVideo {
ctx := vt.Track.ICodecCtx.GetBase().(pkg.IVideoCodecCtx)
ctx := vr.Track.ICodecCtx.GetBase().(pkg.IVideoCodecCtx)
flags |= 1
metaData["videocodecid"] = int(rtmp.ParseVideoCodec(ctx.FourCC()))
metaData["width"] = ctx.Width()
metaData["height"] = ctx.Height()
metaData["framerate"] = vt.Track.FPS
metaData["videodatarate"] = vt.Track.BPS
metaData["framerate"] = vr.Track.FPS
metaData["videodatarate"] = vr.Track.BPS
metaData["keyframes"] = map[string]any{
"filepositions": r.filepositions,
"times": r.times,
"filepositions": filepositions,
"times": times,
}
defer func() {
r.filepositions = []uint64{0}
r.times = []float64{0}
filepositions = []uint64{0}
times = []float64{0}
}()
}
amf.Marshals("onMetaData", metaData)
offset := amf.Len() + 13 + 15
if keyframesCount := len(r.filepositions); keyframesCount > 0 {
metaData["filesize"] = uint64(offset) + r.filepositions[keyframesCount-1]
for i := range r.filepositions {
r.filepositions[i] += uint64(offset)
if keyframesCount := len(filepositions); keyframesCount > 0 {
metaData["filesize"] = uint64(offset) + filepositions[keyframesCount-1]
for i := range filepositions {
filepositions[i] += uint64(offset)
}
metaData["keyframes"] = map[string]any{
"filepositions": r.filepositions,
"times": r.times,
"filepositions": filepositions,
"times": times,
}
}
if tempFile, err := os.CreateTemp("", "*.flv"); err != nil {
r.Error("create temp file failed", "err", err)
ctx.Error("create temp file failed", "err", err)
return
} else {
defer func() {
tempFile.Close()
os.Remove(tempFile.Name())
r.Info("writeMetaData success")
ctx.Info("writeMetaData success")
}()
_, err := tempFile.Write([]byte{'F', 'L', 'V', 0x01, flags, 0, 0, 0, 9, 0, 0, 0, 0})
if err != nil {
r.Error(err.Error())
ctx.Error(err.Error())
return
}
amf.Reset()
@@ -98,20 +99,87 @@ func (r *Recorder) writeMetaData(file util.ReadWriteSeekCloser, duration int64)
WriteFLVTag(tempFile, FLV_TAG_TYPE_SCRIPT, 0, marshals)
_, err = file.Seek(13, io.SeekStart)
if err != nil {
r.Error("writeMetaData Seek failed", "err", err)
ctx.Error("writeMetaData Seek failed", "err", err)
return
}
_, err = io.Copy(tempFile, file)
if err != nil {
r.Error("writeMetaData Copy failed", "err", err)
ctx.Error("writeMetaData Copy failed", "err", err)
return
}
_, err = tempFile.Seek(0, io.SeekStart)
_, err = file.Seek(0, io.SeekStart)
_, err = io.Copy(file, tempFile)
if err != nil {
r.Error("writeMetaData Copy failed", "err", err)
ctx.Error("writeMetaData Copy failed", "err", err)
return
}
}
}
if ctx.Append {
var metaData rtmp.EcmaArray
metaData, err = ReadMetaData(file)
keyframes := metaData["keyframes"].(map[string]any)
filepositions = slices.Collect(func(yield func(uint64) bool) {
for _, v := range keyframes["filepositions"].([]float64) {
yield(uint64(v))
}
})
times = keyframes["times"].([]float64)
if _, err = file.Seek(-4, io.SeekEnd); err != nil {
ctx.Error("seek file failed", "err", err)
file.Write(FLVHead)
} else {
tmp := make(util.Buffer, 4)
tmp2 := tmp
file.Read(tmp)
tagSize := tmp.ReadUint32()
tmp = tmp2
file.Seek(int64(tagSize), io.SeekEnd)
file.Read(tmp2)
ts := tmp2.ReadUint24() | (uint32(tmp[3]) << 24)
ctx.Info("append flv", "last tagSize", tagSize, "last ts", ts)
if hasVideo {
vr.StartTs = time.Duration(ts) * time.Millisecond
}
if hasAudio {
ar.StartTs = time.Duration(ts) * time.Millisecond
}
file.Seek(0, io.SeekEnd)
}
} else {
file.Write(FLVHead)
}
if ctx.Fragment == 0 {
defer writeMetaTag()
}
checkFragment := func(absTime uint32) {
if ctx.Fragment == 0 {
return
}
if duration = int64(absTime); time.Duration(duration)*time.Millisecond >= ctx.Fragment {
writeMetaTag()
offset = 0
if file, err = os.OpenFile(ctx.FilePath, os.O_CREATE|os.O_RDWR, 0666); err != nil {
return
}
file.Write(FLVHead)
if vr != nil {
vr.ResetAbsTime()
err = WriteFLVTag(file, FLV_TAG_TYPE_VIDEO, 0, vr.Track.SequenceFrame.(*rtmp.RTMPVideo).Buffers...)
}
}
}
return m7s.PlayBlock(ctx.Subscriber, func(audio *rtmp.RTMPAudio) (err error) {
if !hasVideo {
checkFragment(ar.AbsTime)
}
return WriteFLVTag(file, FLV_TAG_TYPE_AUDIO, vr.AbsTime, audio.Buffers...)
}, func(video *rtmp.RTMPVideo) (err error) {
if vr.Value.IDR {
filepositions = append(filepositions, uint64(offset))
times = append(times, float64(vr.AbsTime)/1000)
}
return WriteFLVTag(file, FLV_TAG_TYPE_VIDEO, vr.AbsTime, video.Buffers...)
})
}

View File

@@ -10,7 +10,8 @@ import (
type RecordRequest struct {
SN, SumNum int
*util.Promise[[]gb28181.Record]
Response []gb28181.Record
*util.Promise
}
func (r *RecordRequest) GetKey() int {

View File

@@ -3,6 +3,7 @@ package plugin_gb28181
import (
"github.com/emiago/sipgo"
"github.com/emiago/sipgo/sip"
"log/slog"
"m7s.live/m7s/v5"
"m7s.live/m7s/v5/pkg"
"m7s.live/m7s/v5/pkg/util"
@@ -23,7 +24,8 @@ const (
)
type Device struct {
pkg.Unit[string]
pkg.Task
ID string
Name string
Manufacturer string
Model string
@@ -43,6 +45,7 @@ type Device struct {
dialogClient *sipgo.DialogClient
contactHDR sip.ContactHeader
fromHDR sip.FromHeader
*slog.Logger
}
func (d *Device) GetKey() string {
@@ -63,7 +66,8 @@ func (d *Device) onMessage(req *sip.Request, tx sip.ServerTransaction, msg *gb28
case "RecordInfo":
if channel, ok := d.channels.Get(msg.DeviceID); ok {
if req, ok := channel.RecordReqs.Get(msg.SN); ok {
req.Resolve(msg.RecordList)
req.Response = msg.RecordList
req.Resolve()
}
}
case "DeviceInfo":

View File

@@ -24,7 +24,8 @@ func (d *Dialog) GetCallID() string {
return d.session.InviteRequest.CallID().Value()
}
func (d *Dialog) Connect(p *m7s.Client) (err error) {
func (d *Dialog) Pull(p *m7s.PullContext) (err error) {
sss := strings.Split(p.RemoteURL, "/")
deviceId, channelId := sss[0], sss[1]
if len(sss) == 2 {
@@ -41,11 +42,8 @@ func (d *Dialog) Connect(p *m7s.Client) (err error) {
var recordRange util.Range[int]
err = recordRange.Resolve(sss[2])
}
return
}
func (d *Dialog) Pull(p *m7s.Puller) (err error) {
d.Receiver = gb28181.NewReceiver(&p.Publisher)
d.Receiver = gb28181.NewReceiver(p.Publisher)
ssrc := d.CreateSSRC(d.gb.Serial)
d.gb.dialogs.Set(d)
defer d.gb.dialogs.Remove(d)

View File

@@ -1,7 +1,6 @@
package plugin_gb28181
import (
"context"
"errors"
"fmt"
"github.com/emiago/sipgo"
@@ -11,7 +10,6 @@ import (
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"m7s.live/m7s/v5"
"m7s.live/m7s/v5/pkg"
"m7s.live/m7s/v5/pkg/config"
"m7s.live/m7s/v5/pkg/util"
"m7s.live/m7s/v5/plugin/gb28181/pb"
@@ -243,11 +241,7 @@ func (gb *GB28181Plugin) StoreDevice(id string, req *sip.Request) (d *Device) {
port, _ := strconv.Atoi(portStr)
serverPort, _ := strconv.Atoi(sPortStr)
d = &Device{
Unit: pkg.Unit[string]{
ID: id,
StartTime: time.Now(),
Logger: gb.Logger.With("id", id),
},
UpdateTime: time.Now(),
Status: DeviceRegisterStatus,
Recipient: sip.Uri{
@@ -274,7 +268,7 @@ func (gb *GB28181Plugin) StoreDevice(id string, req *sip.Request) (d *Device) {
Params: sip.NewParams(),
},
}
d.Context, d.CancelCauseFunc = context.WithCancelCause(gb.Context)
d.Init(gb.Context, gb.Logger.With("id", id))
d.fromHDR.Params.Add("tag", sip.GenerateTagN(16))
d.client, _ = sipgo.NewClient(gb.ua, sipgo.WithClientLogger(zerolog.New(os.Stdout)), sipgo.WithClientHostname(publicIP))
d.dialogClient = sipgo.NewDialogClient(d.client, d.contactHDR)
@@ -288,10 +282,11 @@ func (gb *GB28181Plugin) StoreDevice(id string, req *sip.Request) (d *Device) {
return
}
func (gb *GB28181Plugin) NewPullHandler() m7s.PullHandler {
return &Dialog{
func (gb *GB28181Plugin) DoPull(ctx *m7s.PullContext) error {
dialog := Dialog{
gb: gb,
}
return dialog.Pull(ctx)
}
func (gb *GB28181Plugin) GetPullableList() []string {

View File

@@ -54,6 +54,6 @@ func (h *LogRotatePlugin) ServeHTTP(w http.ResponseWriter, r *http.Request) {
func (l *LogRotatePlugin) API_tail(w http.ResponseWriter, r *http.Request) {
writer := util.NewSSE(w, r.Context())
h := console.NewHandler(writer, &console.HandlerOptions{NoColor: true})
l.PostToServer(h)
l.Server.AddLogHandler(h)
<-r.Context().Done()
}

View File

@@ -83,16 +83,16 @@ func (p *MP4Plugin) OnInit() error {
var _ = m7s.InstallPlugin[MP4Plugin](defaultConfig)
func (p *MP4Plugin) NewPullHandler() m7s.PullHandler {
return pkg.NewMP4Puller()
func (p *MP4Plugin) DoPull(ctx *m7s.PullContext) error {
return pkg.PullMP4(ctx)
}
func (p *MP4Plugin) GetPullableList() []string {
return slices.Collect(maps.Keys(p.GetCommonConf().PullOnSub))
}
func (p *MP4Plugin) NewRecordHandler() m7s.RecordHandler {
return &pkg.Recorder{}
func (p *MP4Plugin) DoRecord(ctx *m7s.RecordContext) error {
return pkg.RecordMP4(ctx)
}
func (p *MP4Plugin) ServeHTTP(w http.ResponseWriter, r *http.Request) {
@@ -179,10 +179,8 @@ func (p *MP4Plugin) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
}
if hijacker, ok := w.(http.Hijacker); ok && ctx.wto > 0 {
sub.Conn, _, _ = hijacker.Hijack()
sub.Closer = sub.Conn
ctx.Writer = sub.Conn
ctx.conn = sub.Conn
ctx.conn, _, _ = hijacker.Hijack()
ctx.Writer = ctx.conn
} else {
ctx.Writer = w
w.(http.Flusher).Flush()

View File

@@ -14,22 +14,12 @@ import (
"strings"
)
type MP4Puller struct {
*util.ScalableMemoryAllocator
*box.MovDemuxer
}
func NewMP4Puller() *MP4Puller {
return &MP4Puller{
ScalableMemoryAllocator: util.NewScalableMemoryAllocator(1 << 10),
}
}
func (puller *MP4Puller) Connect(p *m7s.Client) (err error) {
if strings.HasPrefix(p.RemoteURL, "http") {
func PullMP4(ctx *m7s.PullContext) (err error) {
var demuxer *box.MovDemuxer
if strings.HasPrefix(ctx.RemoteURL, "http") {
var res *http.Response
client := http.DefaultClient
if proxyConf := p.Proxy; proxyConf != "" {
if proxyConf := ctx.ConnectProxy; proxyConf != "" {
proxy, err := url.Parse(proxyConf)
if err != nil {
return err
@@ -37,68 +27,66 @@ func (puller *MP4Puller) Connect(p *m7s.Client) (err error) {
transport := &http.Transport{Proxy: http.ProxyURL(proxy)}
client = &http.Client{Transport: transport}
}
if res, err = client.Get(p.RemoteURL); err == nil {
if res, err = client.Get(ctx.RemoteURL); err == nil {
if res.StatusCode != http.StatusOK {
return io.EOF
}
p.Closer = res.Body
defer res.Body.Close()
content, err := io.ReadAll(res.Body)
if err != nil {
return err
}
puller.MovDemuxer = box.CreateMp4Demuxer(strings.NewReader(string(content)))
demuxer = box.CreateMp4Demuxer(strings.NewReader(string(content)))
}
} else {
var res *os.File
if res, err = os.Open(p.RemoteURL); err == nil {
p.Closer = res
if res, err = os.Open(ctx.RemoteURL); err == nil {
defer res.Close()
}
puller.MovDemuxer = box.CreateMp4Demuxer(res)
demuxer = box.CreateMp4Demuxer(res)
}
return
}
func (puller *MP4Puller) Pull(p *m7s.Puller) (err error) {
var tracks []box.TrackInfo
if tracks, err = puller.ReadHead(); err != nil {
if tracks, err = demuxer.ReadHead(); err != nil {
return
}
publisher := ctx.Publisher
for _, track := range tracks {
switch track.Cid {
case box.MP4_CODEC_H264:
var sequece rtmp.RTMPVideo
sequece.Append([]byte{0x17, 0x00, 0x00, 0x00, 0x00}, track.ExtraData)
p.WriteVideo(&sequece)
err = publisher.WriteVideo(&sequece)
case box.MP4_CODEC_H265:
var sequece rtmp.RTMPVideo
sequece.Append([]byte{0b1001_0000 | rtmp.PacketTypeSequenceStart}, codec.FourCC_H265[:], track.ExtraData)
p.WriteVideo(&sequece)
err = publisher.WriteVideo(&sequece)
case box.MP4_CODEC_AAC:
var sequence rtmp.RTMPAudio
sequence.Append([]byte{0xaf, 0x00}, track.ExtraData)
p.WriteAudio(&sequence)
err = publisher.WriteAudio(&sequence)
}
}
allocator := util.NewScalableMemoryAllocator(1 << 10)
for {
pkg, err := puller.ReadPacket(puller.ScalableMemoryAllocator)
pkg, err := demuxer.ReadPacket(allocator)
if err != nil {
p.Error("Error reading MP4 packet", "err", err)
ctx.Error("Error reading MP4 packet", "err", err)
return err
}
switch track := tracks[pkg.TrackId-1]; track.Cid {
case box.MP4_CODEC_H264:
var videoFrame rtmp.RTMPVideo
videoFrame.SetAllocator(puller.ScalableMemoryAllocator)
videoFrame.SetAllocator(allocator)
videoFrame.CTS = uint32(pkg.Pts - pkg.Dts)
videoFrame.Timestamp = uint32(pkg.Dts)
keyFrame := codec.H264NALUType(pkg.Data[5]&0x1F) == codec.NALU_IDR_Picture
videoFrame.AppendOne([]byte{util.Conditoinal[byte](keyFrame, 0x17, 0x27), 0x01, byte(videoFrame.CTS >> 24), byte(videoFrame.CTS >> 8), byte(videoFrame.CTS)})
videoFrame.AddRecycleBytes(pkg.Data)
p.WriteVideo(&videoFrame)
err = publisher.WriteVideo(&videoFrame)
case box.MP4_CODEC_H265:
var videoFrame rtmp.RTMPVideo
videoFrame.SetAllocator(puller.ScalableMemoryAllocator)
videoFrame.SetAllocator(allocator)
videoFrame.CTS = uint32(pkg.Pts - pkg.Dts)
videoFrame.Timestamp = uint32(pkg.Dts)
var head []byte
@@ -122,14 +110,14 @@ func (puller *MP4Puller) Pull(p *m7s.Puller) (err error) {
}
copy(head[1:], codec.FourCC_H265[:])
videoFrame.AddRecycleBytes(pkg.Data)
p.WriteVideo(&videoFrame)
err = publisher.WriteVideo(&videoFrame)
case box.MP4_CODEC_AAC:
var audioFrame rtmp.RTMPAudio
audioFrame.SetAllocator(puller.ScalableMemoryAllocator)
audioFrame.SetAllocator(allocator)
audioFrame.Timestamp = uint32(pkg.Dts)
audioFrame.AppendOne([]byte{0xaf, 0x01})
audioFrame.AddRecycleBytes(pkg.Data)
p.WriteAudio(&audioFrame)
err = publisher.WriteAudio(&audioFrame)
}
}
}

View File

@@ -5,62 +5,56 @@ import (
"m7s.live/m7s/v5/pkg"
"m7s.live/m7s/v5/pkg/codec"
"m7s.live/m7s/v5/plugin/mp4/pkg/box"
"os"
"time"
)
type Recorder struct {
*m7s.Subscriber
*box.Movmuxer
videoId uint32
audioId uint32
}
func (r *Recorder) Record(recorder *m7s.Recorder) (err error) {
r.Movmuxer, err = box.CreateMp4Muxer(recorder.File)
if recorder.Publisher.HasAudioTrack() {
audioTrack := recorder.Publisher.AudioTrack
func RecordMP4(ctx *m7s.RecordContext) (err error) {
var file *os.File
var muxer *box.Movmuxer
var audioId, videoId uint32
// TODO: fragment
if file, err = os.OpenFile(ctx.FilePath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0666); err != nil {
return
}
defer func() {
err = muxer.WriteTrailer()
if err != nil {
ctx.Error("write trailer", "err", err)
} else {
ctx.Info("write trailer")
}
err = file.Close()
}()
muxer, err = box.CreateMp4Muxer(file)
ar, vr := ctx.Subscriber.AudioReader, ctx.Subscriber.VideoReader
if ar != nil {
audioTrack := ar.Track
switch ctx := audioTrack.ICodecCtx.GetBase().(type) {
case *codec.AACCtx:
r.audioId = r.AddAudioTrack(box.MP4_CODEC_AAC, box.WithExtraData(ctx.ConfigBytes))
audioId = muxer.AddAudioTrack(box.MP4_CODEC_AAC, box.WithExtraData(ctx.ConfigBytes))
case *codec.PCMACtx:
r.audioId = r.AddAudioTrack(box.MP4_CODEC_G711A, box.WithAudioSampleRate(uint32(ctx.SampleRate)), box.WithAudioChannelCount(uint8(ctx.Channels)), box.WithAudioSampleBits(uint8(ctx.SampleSize)))
audioId = muxer.AddAudioTrack(box.MP4_CODEC_G711A, box.WithAudioSampleRate(uint32(ctx.SampleRate)), box.WithAudioChannelCount(uint8(ctx.Channels)), box.WithAudioSampleBits(uint8(ctx.SampleSize)))
case *codec.PCMUCtx:
r.audioId = r.AddAudioTrack(box.MP4_CODEC_G711U, box.WithAudioSampleRate(uint32(ctx.SampleRate)), box.WithAudioChannelCount(uint8(ctx.Channels)), box.WithAudioSampleBits(uint8(ctx.SampleSize)))
audioId = muxer.AddAudioTrack(box.MP4_CODEC_G711U, box.WithAudioSampleRate(uint32(ctx.SampleRate)), box.WithAudioChannelCount(uint8(ctx.Channels)), box.WithAudioSampleBits(uint8(ctx.SampleSize)))
}
}
if recorder.Publisher.HasVideoTrack() {
videoTrack := recorder.Publisher.VideoTrack
if vr != nil {
videoTrack := vr.Track
switch ctx := videoTrack.ICodecCtx.GetBase().(type) {
case *codec.H264Ctx:
r.videoId = r.AddVideoTrack(box.MP4_CODEC_H264, box.WithExtraData(ctx.Record))
videoId = muxer.AddVideoTrack(box.MP4_CODEC_H264, box.WithExtraData(ctx.Record))
case *codec.H265Ctx:
r.videoId = r.AddVideoTrack(box.MP4_CODEC_H265, box.WithExtraData(ctx.Record))
videoId = muxer.AddVideoTrack(box.MP4_CODEC_H265, box.WithExtraData(ctx.Record))
}
}
r.Subscriber = &recorder.Subscriber
return m7s.PlayBlock(&recorder.Subscriber, func(audio *pkg.RawAudio) error {
return r.WriteAudio(r.audioId, audio.ToBytes(), uint64(audio.Timestamp/time.Millisecond))
return m7s.PlayBlock(ctx.Subscriber, func(audio *pkg.RawAudio) error {
return muxer.WriteAudio(audioId, audio.ToBytes(), uint64(audio.Timestamp/time.Millisecond))
}, func(video *pkg.H26xFrame) error {
var nalus [][]byte
for _, nalu := range video.Nalus {
nalus = append(nalus, nalu.ToBytes())
}
return r.WriteVideo(r.videoId, nalus, uint64(video.Timestamp/time.Millisecond), uint64(video.CTS/time.Millisecond))
return muxer.WriteVideo(videoId, nalus, uint64(video.Timestamp/time.Millisecond), uint64(video.CTS/time.Millisecond))
})
}
func (r *Recorder) Close() {
//defer func() {
// if err := recover(); err != nil {
// r.Error("close", "err", err)
// } else {
// r.Info("close")
// }
//}()
err := r.WriteTrailer()
if err != nil {
r.Error("write trailer", "err", err)
} else {
r.Info("write trailer")
}
}

View File

@@ -4,10 +4,13 @@ import (
"context"
gpb "m7s.live/m7s/v5/pb"
"m7s.live/m7s/v5/plugin/rtmp/pb"
rtmp "m7s.live/m7s/v5/plugin/rtmp/pkg"
)
func (r *RTMPPlugin) PushOut(ctx context.Context, req *pb.PushRequest) (res *gpb.SuccessResponse, err error) {
go r.Push(req.StreamPath, req.RemoteURL, &rtmp.Client{})
return &gpb.SuccessResponse{}, nil
if pushContext, err := r.Push(req.StreamPath, req.RemoteURL); err != nil {
return nil, err
} else {
go pushContext.Run(r.DoPush)
}
return &gpb.SuccessResponse{}, err
}

View File

@@ -14,6 +14,7 @@ import (
type RTMPPlugin struct {
pb.UnimplementedRtmpServer
Client
m7s.Plugin
ChunkSize int `default:"1024"`
KeepAlive bool
@@ -30,18 +31,10 @@ func (p *RTMPPlugin) OnInit() error {
return nil
}
func (p *RTMPPlugin) NewPullHandler() m7s.PullHandler {
return &Client{}
}
func (p *RTMPPlugin) GetPullableList() []string {
return slices.Collect(maps.Keys(p.GetCommonConf().PullOnSub))
}
func (p *RTMPPlugin) NewPushHandler() m7s.PushHandler {
return &Client{}
}
func (p *RTMPPlugin) OnTCPConnect(conn *net.TCPConn) {
logger := p.Logger.With("remote", conn.RemoteAddr().String())
receivers := make(map[uint32]*Receiver)
@@ -55,7 +48,7 @@ func (p *RTMPPlugin) OnTCPConnect(conn *net.TCPConn) {
}
if len(receivers) > 0 {
for _, receiver := range receivers {
receiver.Dispose(err)
receiver.Stop(err)
}
}
}()
@@ -165,7 +158,7 @@ func (p *RTMPPlugin) OnTCPConnect(conn *net.TCPConn) {
StreamID: cmd.StreamId,
},
}
receiver.Publisher, err = p.Publish(nc.AppName+"/"+cmd.PublishingName, conn, connectInfo)
receiver.Publisher, err = p.Publish(nc.AppName+"/"+cmd.PublishingName, receiver, connectInfo)
if err != nil {
delete(receivers, cmd.StreamId)
err = receiver.Response(cmd.TransactionId, NetStream_Publish_BadName, Level_Error)
@@ -185,7 +178,7 @@ func (p *RTMPPlugin) OnTCPConnect(conn *net.TCPConn) {
}
var suber *m7s.Subscriber
// sender.ID = fmt.Sprintf("%s|%d", conn.RemoteAddr().String(), sender.StreamID)
suber, err = p.Subscribe(streamPath, conn, connectInfo)
suber, err = p.Subscribe(streamPath, &ns, connectInfo)
if err != nil {
err = ns.Response(cmd.TransactionId, NetStream_Play_Failed, Level_Error)
} else {

View File

@@ -11,29 +11,18 @@ import (
"m7s.live/m7s/v5"
)
type Client struct {
NetStream
ServerInfo map[string]any
}
type Client struct{}
func NewPushHandler() m7s.PushHandler {
return &Client{}
}
func NewPullHandler() m7s.PullHandler {
return &Client{}
}
func (client *Client) Connect(p *m7s.Client) (err error) {
func createClient(c *m7s.Connection) (*NetStream, error) {
chunkSize := 4096
addr := p.RemoteURL
addr := c.RemoteURL
u, err := url.Parse(addr)
if err != nil {
return err
return nil, err
}
ps := strings.Split(u.Path, "/")
if len(ps) < 3 {
return errors.New("illegal rtmp url")
return nil, errors.New("illegal rtmp url")
}
isRtmps := u.Scheme == "rtmps"
if strings.Count(u.Host, ":") == 0 {
@@ -52,101 +41,108 @@ func (client *Client) Connect(p *m7s.Client) (err error) {
conn, err = net.Dial("tcp", u.Host)
}
if err != nil {
return err
return nil, err
}
ns := &NetStream{}
ns.NetConnection = NewNetConnection(conn, c.Logger)
defer func() {
if err != nil {
conn.Close()
ns.disconnect()
}
}()
client.NetConnection = NewNetConnection(conn, p.Logger)
if err = client.ClientHandshake(); err != nil {
return err
if err = ns.ClientHandshake(); err != nil {
return ns, err
}
client.AppName = strings.Join(ps[1:len(ps)-1], "/")
err = client.SendMessage(RTMP_MSG_CHUNK_SIZE, Uint32Message(chunkSize))
ns.AppName = strings.Join(ps[1:len(ps)-1], "/")
err = ns.SendMessage(RTMP_MSG_CHUNK_SIZE, Uint32Message(chunkSize))
if err != nil {
return
return ns, err
}
client.WriteChunkSize = chunkSize
ns.WriteChunkSize = chunkSize
path := u.Path
if len(u.Query()) != 0 {
path += "?" + u.RawQuery
}
err = client.SendMessage(RTMP_MSG_AMF0_COMMAND, &CallMessage{
err = ns.SendMessage(RTMP_MSG_AMF0_COMMAND, &CallMessage{
CommandMessage{"connect", 1},
map[string]any{
"app": client.AppName,
"app": ns.AppName,
"flashVer": "monibuca/" + m7s.Version,
"swfUrl": addr,
"tcUrl": strings.TrimSuffix(addr, path) + "/" + client.AppName,
"tcUrl": strings.TrimSuffix(addr, path) + "/" + ns.AppName,
},
nil,
})
var msg *Chunk
for err != nil {
msg, err := client.RecvMessage()
msg, err = ns.RecvMessage()
if err != nil {
return err
return ns, err
}
switch msg.MessageTypeID {
case RTMP_MSG_AMF0_COMMAND:
cmd := msg.MsgData.(Commander).GetCommand()
switch cmd.CommandName {
case "_result":
client.ServerInfo = msg.MsgData.(*ResponseMessage).Properties
c.MetaData = msg.MsgData.(*ResponseMessage).Properties
response := msg.MsgData.(*ResponseMessage)
if response.Infomation["code"] == NetConnection_Connect_Success {
} else {
return err
return ns, err
}
default:
fmt.Println(cmd.CommandName)
}
}
}
client.Info("connect", "remoteURL", p.RemoteURL)
return
c.Info("connect", "remoteURL", c.RemoteURL)
return ns, nil
}
func (puller *Client) Pull(p *m7s.Puller) (err error) {
p.MetaData = puller.ServerInfo
func (Client) DoPull(p *m7s.PullContext) (err error) {
var connection *NetStream
if connection, err = createClient(&p.Connection); err != nil {
return
}
defer func() {
puller.Close()
connection.disconnect()
if p := recover(); p != nil {
err = p.(error)
}
p.Dispose(err)
}()
err = puller.SendMessage(RTMP_MSG_AMF0_COMMAND, &CommandMessage{"createStream", 2})
err = connection.SendMessage(RTMP_MSG_AMF0_COMMAND, &CommandMessage{"createStream", 2})
var msg *Chunk
for err == nil {
msg, err := puller.RecvMessage()
if err != nil {
if err = p.Publisher.Err(); err != nil {
return
}
if msg, err = connection.RecvMessage(); err != nil {
return err
}
switch msg.MessageTypeID {
case RTMP_MSG_AUDIO:
p.WriteAudio(msg.AVData.WrapAudio())
err = p.Publisher.WriteAudio(msg.AVData.WrapAudio())
case RTMP_MSG_VIDEO:
p.WriteVideo(msg.AVData.WrapVideo())
err = p.Publisher.WriteVideo(msg.AVData.WrapVideo())
case RTMP_MSG_AMF0_COMMAND:
cmd := msg.MsgData.(Commander).GetCommand()
switch cmd.CommandName {
case "_result":
if response, ok := msg.MsgData.(*ResponseCreateStreamMessage); ok {
puller.StreamID = response.StreamId
connection.StreamID = response.StreamId
m := &PlayMessage{}
m.StreamId = response.StreamId
m.TransactionId = 4
m.CommandMessage.CommandName = "play"
URL, _ := url.Parse(p.Client.RemoteURL)
URL, _ := url.Parse(p.Connection.RemoteURL)
ps := strings.Split(URL.Path, "/")
p.Args = URL.Query()
args := URL.Query()
m.StreamName = ps[len(ps)-1]
if len(p.Args) > 0 {
m.StreamName += "?" + p.Args.Encode()
if len(args) > 0 {
m.StreamName += "?" + args.Encode()
}
puller.SendMessage(RTMP_MSG_AMF0_COMMAND, m)
connection.SendMessage(RTMP_MSG_AMF0_COMMAND, m)
// if response, ok := msg.MsgData.(*ResponsePlayMessage); ok {
// if response.Object["code"] == "NetStream.Play.Start" {
@@ -163,11 +159,16 @@ func (puller *Client) Pull(p *m7s.Puller) (err error) {
return
}
func (pusher *Client) Push(p *m7s.Pusher) (err error) {
p.MetaData = pusher.ServerInfo
pusher.SendMessage(RTMP_MSG_AMF0_COMMAND, &CommandMessage{"createStream", 2})
for {
msg, err := pusher.RecvMessage()
func (Client) DoPush(p *m7s.PushContext) (err error) {
var connection *NetStream
if connection, err = createClient(&p.Connection); err != nil {
return
}
defer connection.disconnect()
err = connection.SendMessage(RTMP_MSG_AMF0_COMMAND, &CommandMessage{"createStream", 2})
var msg *Chunk
for err == nil {
msg, err = connection.RecvMessage()
if err != nil {
return err
}
@@ -177,15 +178,15 @@ func (pusher *Client) Push(p *m7s.Pusher) (err error) {
switch cmd.CommandName {
case Response_Result, Response_OnStatus:
if response, ok := msg.MsgData.(*ResponseCreateStreamMessage); ok {
pusher.StreamID = response.StreamId
URL, _ := url.Parse(p.Client.RemoteURL)
connection.StreamID = response.StreamId
URL, _ := url.Parse(p.Connection.RemoteURL)
_, streamPath, _ := strings.Cut(URL.Path, "/")
_, streamPath, _ = strings.Cut(streamPath, "/")
p.Args = URL.Query()
if len(p.Args) > 0 {
streamPath += "?" + p.Args.Encode()
args := URL.Query()
if len(args) > 0 {
streamPath += "?" + args.Encode()
}
pusher.SendMessage(RTMP_MSG_AMF0_COMMAND, &PublishMessage{
err = connection.SendMessage(RTMP_MSG_AMF0_COMMAND, &PublishMessage{
CURDStreamMessage{
CommandMessage{
"publish",
@@ -198,8 +199,15 @@ func (pusher *Client) Push(p *m7s.Pusher) (err error) {
})
} else if response, ok := msg.MsgData.(*ResponsePublishMessage); ok {
if response.Infomation["code"] == NetStream_Publish_Start {
audio, video := pusher.CreateSender(true)
go m7s.PlayBlock(&p.Subscriber, audio.HandleAudio, video.HandleVideo)
p.Connection.ReConnectCount = 0
audio, video := connection.CreateSender(true)
go func() {
for err == nil {
msg, err = connection.RecvMessage()
}
p.Subscriber.Stop(err)
}()
return m7s.PlayBlock(p.Subscriber, audio.HandleAudio, video.HandleVideo)
} else {
return errors.New(response.Infomation["code"].(string))
}
@@ -207,4 +215,5 @@ func (pusher *Client) Push(p *m7s.Pusher) (err error) {
}
}
}
return
}

View File

@@ -63,9 +63,9 @@ func (ns *NetStream) BeginPlay(tid uint64) (err error) {
return
}
func (ns *NetStream) Close() error {
if ns.NetConnection != nil {
func (ns *NetStream) disconnect() {
if ns != nil && ns.NetConnection != nil {
ns.NetConnection.Destroy()
}
return nil
return
}

View File

@@ -22,20 +22,13 @@ var _ = m7s.InstallPlugin[RTSPPlugin](defaultConfig)
type RTSPPlugin struct {
m7s.Plugin
}
func (p *RTSPPlugin) NewPullHandler() m7s.PullHandler {
return &Client{}
Client
}
func (p *RTSPPlugin) GetPullableList() []string {
return slices.Collect(maps.Keys(p.GetCommonConf().PullOnSub))
}
func (p *RTSPPlugin) NewPushHandler() m7s.PushHandler {
return &Client{}
}
func (p *RTSPPlugin) OnInit() error {
for streamPath, url := range p.GetCommonConf().PullOnStart {
go p.Pull(streamPath, url)
@@ -56,7 +49,7 @@ func (p *RTSPPlugin) OnTCPConnect(conn *net.TCPConn) {
logger.Error(err.Error(), "stack", string(debug.Stack()))
}
if receiver != nil {
receiver.Dispose(err)
receiver.Stop(err)
}
}()
var req *util.Request
@@ -114,7 +107,7 @@ func (p *RTSPPlugin) OnTCPConnect(conn *net.TCPConn) {
receiver = &Receiver{}
receiver.NetConnection = nc
if receiver.Publisher, err = p.Publish(strings.TrimPrefix(nc.URL.Path, "/")); err != nil {
if receiver.Publisher, err = p.Publish(strings.TrimPrefix(nc.URL.Path, "/"), receiver); err != nil {
receiver = nil
err = nc.WriteResponse(&util.Response{
StatusCode: 500, Status: err.Error(),
@@ -131,9 +124,9 @@ func (p *RTSPPlugin) OnTCPConnect(conn *net.TCPConn) {
case MethodDescribe:
sendMode = true
var subscriber *m7s.Subscriber
subscriber, err = p.Subscribe(strings.TrimPrefix(nc.URL.Path, "/"), conn)
sender = &Sender{}
sender.NetConnection = nc
sender.Subscriber, err = p.Subscribe(strings.TrimPrefix(nc.URL.Path, "/"), sender)
if err != nil {
res := &util.Response{
StatusCode: http.StatusBadRequest,
@@ -149,10 +142,6 @@ func (p *RTSPPlugin) OnTCPConnect(conn *net.TCPConn) {
},
Request: req,
}
sender = &Sender{
Subscriber: subscriber,
}
sender.NetConnection = nc
// convert tracks to real output medias
var medias []*Media
if medias, err = sender.GetMedia(); err != nil {

View File

@@ -2,35 +2,21 @@ package rtsp
import (
"crypto/tls"
"errors"
"fmt"
"m7s.live/m7s/v5"
"m7s.live/m7s/v5/pkg/util"
"net"
"net/http"
"net/url"
"strconv"
"strings"
)
type Client struct {
Stream
}
type Client struct{}
func NewPushHandler() m7s.PushHandler {
return &Client{}
}
func NewPullHandler() m7s.PullHandler {
return &Client{}
}
func (c *Client) Connect(p *m7s.Client) (err error) {
func createClient(p *m7s.Connection) (s *Stream, err error) {
addr := p.RemoteURL
var rtspURL *url.URL
rtspURL, err = url.Parse(addr)
if err != nil {
return err
return
}
//ps := strings.Split(u.Path, "/")
//if len(ps) < 3 {
@@ -53,290 +39,73 @@ func (c *Client) Connect(p *m7s.Client) (err error) {
conn, err = net.Dial("tcp", rtspURL.Host)
}
if err != nil {
return err
return
}
defer func() {
s = &Stream{NetConnection: NewNetConnection(conn, p.Logger)}
s.URL = rtspURL
s.auth = util.NewAuth(s.URL.User)
s.Backchannel = true
err = s.Options()
if err != nil {
conn.Close()
s.disconnect()
return
}
}()
c.NetConnection = NewNetConnection(conn, p.Logger)
c.URL = rtspURL
c.auth = util.NewAuth(c.URL.User)
c.Backchannel = true
return c.Options()
return
}
func (c *Client) Pull(p *m7s.Puller) (err error) {
func (Client) DoPull(p *m7s.PullContext) (err error) {
var s *Stream
if s, err = createClient(&p.Connection); err != nil {
return
}
defer func() {
c.Close()
s.disconnect()
if p := recover(); p != nil {
err = p.(error)
}
p.Dispose(err)
}()
var media []*Media
if media, err = c.Describe(); err != nil {
if media, err = s.Describe(); err != nil {
return
}
receiver := &Receiver{Publisher: &p.Publisher, Stream: c.Stream}
receiver := &Receiver{Publisher: p.Publisher, Stream: s}
if err = receiver.SetMedia(media); err != nil {
return
}
if err = c.Play(); err != nil {
if err = s.Play(); err != nil {
return
}
p.Connection.ReConnectCount = 0
return receiver.Receive()
}
func (c *Client) Push(p *m7s.Pusher) (err error) {
defer c.Close()
sender := &Sender{Subscriber: &p.Subscriber, Stream: c.Stream}
func (Client) DoPush(ctx *m7s.PushContext) (err error) {
var s *Stream
if s, err = createClient(&ctx.Connection); err != nil {
return
}
defer s.disconnect()
sender := &Sender{Subscriber: ctx.Subscriber, Stream: s}
var medias []*Media
medias, err = sender.GetMedia()
err = c.Announce(medias)
err = s.Announce(medias)
if err != nil {
return
}
for i, media := range medias {
switch media.Kind {
case "audio", "video":
_, err = c.SetupMedia(media, i)
_, err = s.SetupMedia(media, i)
if err != nil {
return
}
default:
c.Warn("media kind not support", "kind", media.Kind)
ctx.Warn("media kind not support", "kind", media.Kind)
}
}
if err = c.Record(); err != nil {
if err = s.Record(); err != nil {
return
}
ctx.Connection.ReConnectCount = 0
return sender.Send()
}
func (c *Client) Do(req *util.Request) (*util.Response, error) {
if err := c.WriteRequest(req); err != nil {
return nil, err
}
res, err := c.ReadResponse()
if err != nil {
return nil, err
}
if res.StatusCode == http.StatusUnauthorized {
switch c.auth.Method {
case util.AuthNone:
if c.auth.ReadNone(res) {
return c.Do(req)
}
return nil, errors.New("user/pass not provided")
case util.AuthUnknown:
if c.auth.Read(res) {
return c.Do(req)
}
default:
return nil, errors.New("wrong user/pass")
}
}
if res.StatusCode != http.StatusOK {
return res, fmt.Errorf("wrong response on %s", req.Method)
}
return res, nil
}
func (c *Client) Options() error {
req := &util.Request{Method: MethodOptions, URL: c.URL}
res, err := c.Do(req)
if err != nil {
return err
}
if val := res.Header.Get("Content-Base"); val != "" {
c.URL, err = urlParse(val)
if err != nil {
return err
}
}
return nil
}
func (c *Client) Describe() (medias []*Media, err error) {
// 5.3 Back channel connection
// https://www.onvif.org/specs/stream/ONVIF-Streaming-Spec.pdf
req := &util.Request{
Method: MethodDescribe,
URL: c.URL,
Header: map[string][]string{
"Accept": {"application/sdp"},
},
}
if c.Backchannel {
req.Header.Set("Require", "www.onvif.org/ver20/backchannel")
}
if c.UserAgent != "" {
// this camera will answer with 401 on DESCRIBE without User-Agent
// https://github.com/AlexxIT/go2rtc/issues/235
req.Header.Set("User-Agent", c.UserAgent)
}
var res *util.Response
res, err = c.Do(req)
if err != nil {
return
}
if val := res.Header.Get("Content-Base"); val != "" {
c.URL, err = urlParse(val)
if err != nil {
return
}
}
c.sdp = string(res.Body) // for info
medias, err = UnmarshalSDP(res.Body)
if err != nil {
return
}
if c.Media != "" {
clone := make([]*Media, 0, len(medias))
for _, media := range medias {
if strings.Contains(c.Media, media.Kind) {
clone = append(clone, media)
}
}
medias = clone
}
return
}
func (c *Client) Announce(medias []*Media) (err error) {
req := &util.Request{
Method: MethodAnnounce,
URL: c.URL,
Header: map[string][]string{
"Content-Type": {"application/sdp"},
},
}
req.Body, err = MarshalSDP(c.SessionName, medias)
if err != nil {
return err
}
_, err = c.Do(req)
return
}
func (c *Client) SetupMedia(media *Media, index int) (byte, error) {
var transport string
transport = fmt.Sprintf(
// i - RTP (data channel)
// i+1 - RTCP (control channel)
"RTP/AVP/TCP;unicast;interleaved=%d-%d", index*2, index*2+1,
)
if transport == "" {
return 0, fmt.Errorf("wrong media: %v", media)
}
rawURL := media.ID // control
if !strings.Contains(rawURL, "://") {
rawURL = c.URL.String()
if !strings.HasSuffix(rawURL, "/") {
rawURL += "/"
}
rawURL += media.ID
}
trackURL, err := urlParse(rawURL)
if err != nil {
return 0, err
}
req := &util.Request{
Method: MethodSetup,
URL: trackURL,
Header: map[string][]string{
"Transport": {transport},
},
}
res, err := c.Do(req)
if err != nil {
// some Dahua/Amcrest cameras fail here because two simultaneous
// backchannel connections
//if c.Backchannel {
// c.Backchannel = false
// if err = c.Connect(); err != nil {
// return 0, err
// }
// return c.SetupMedia(media)
//}
return 0, err
}
if c.Session == "" {
// Session: 7116520596809429228
// Session: 216525287999;timeout=60
if s := res.Header.Get("Session"); s != "" {
if i := strings.IndexByte(s, ';'); i > 0 {
c.Session = s[:i]
if i = strings.Index(s, "timeout="); i > 0 {
c.keepalive, _ = strconv.Atoi(s[i+8:])
}
} else {
c.Session = s
}
}
}
// we send our `interleaved`, but camera can answer with another
// Transport: RTP/AVP/TCP;unicast;interleaved=10-11;ssrc=10117CB7
// Transport: RTP/AVP/TCP;unicast;destination=192.168.1.111;source=192.168.1.222;interleaved=0
// Transport: RTP/AVP/TCP;ssrc=22345682;interleaved=0-1
transport = res.Header.Get("Transport")
if !strings.HasPrefix(transport, "RTP/AVP/TCP;") {
// Escam Q6 has a bug:
// Transport: RTP/AVP;unicast;destination=192.168.1.111;source=192.168.1.222;interleaved=0-1
if !strings.Contains(transport, ";interleaved=") {
return 0, fmt.Errorf("wrong transport: %s", transport)
}
}
channel := Between(transport, "interleaved=", "-")
i, err := strconv.Atoi(channel)
if err != nil {
return 0, err
}
return byte(i), nil
}
func (c *Client) Play() (err error) {
return c.WriteRequest(&util.Request{Method: MethodPlay, URL: c.URL})
}
func (c *Client) Record() (err error) {
return c.WriteRequest(&util.Request{Method: MethodRecord, URL: c.URL})
}
func (c *Client) Teardown() (err error) {
// allow TEARDOWN from any state (ex. ANNOUNCE > SETUP)
return c.WriteRequest(&util.Request{Method: MethodTeardown, URL: c.URL})
}
func (c *Client) Destroy() {
_ = c.Teardown()
c.NetConnection.Destroy()
}

View File

@@ -0,0 +1,243 @@
package rtsp
import (
"errors"
"fmt"
"m7s.live/m7s/v5/pkg/util"
"net/http"
"strconv"
"strings"
)
type Stream struct {
*NetConnection
AudioChannelID int
VideoChannelID int
}
func (c *Stream) Do(req *util.Request) (*util.Response, error) {
if err := c.WriteRequest(req); err != nil {
return nil, err
}
res, err := c.ReadResponse()
if err != nil {
return nil, err
}
if res.StatusCode == http.StatusUnauthorized {
switch c.auth.Method {
case util.AuthNone:
if c.auth.ReadNone(res) {
return c.Do(req)
}
return nil, errors.New("user/pass not provided")
case util.AuthUnknown:
if c.auth.Read(res) {
return c.Do(req)
}
default:
return nil, errors.New("wrong user/pass")
}
}
if res.StatusCode != http.StatusOK {
return res, fmt.Errorf("wrong response on %s", req.Method)
}
return res, nil
}
func (c *Stream) Options() error {
req := &util.Request{Method: MethodOptions, URL: c.URL}
res, err := c.Do(req)
if err != nil {
return err
}
if val := res.Header.Get("Content-Base"); val != "" {
c.URL, err = urlParse(val)
if err != nil {
return err
}
}
return nil
}
func (c *Stream) Describe() (medias []*Media, err error) {
// 5.3 Back channel connection
// https://www.onvif.org/specs/stream/ONVIF-Streaming-Spec.pdf
req := &util.Request{
Method: MethodDescribe,
URL: c.URL,
Header: map[string][]string{
"Accept": {"application/sdp"},
},
}
if c.Backchannel {
req.Header.Set("Require", "www.onvif.org/ver20/backchannel")
}
if c.UserAgent != "" {
// this camera will answer with 401 on DESCRIBE without User-Agent
// https://github.com/AlexxIT/go2rtc/issues/235
req.Header.Set("User-Agent", c.UserAgent)
}
var res *util.Response
res, err = c.Do(req)
if err != nil {
return
}
if val := res.Header.Get("Content-Base"); val != "" {
c.URL, err = urlParse(val)
if err != nil {
return
}
}
c.sdp = string(res.Body) // for info
medias, err = UnmarshalSDP(res.Body)
if err != nil {
return
}
if c.Media != "" {
clone := make([]*Media, 0, len(medias))
for _, media := range medias {
if strings.Contains(c.Media, media.Kind) {
clone = append(clone, media)
}
}
medias = clone
}
return
}
func (c *Stream) Announce(medias []*Media) (err error) {
req := &util.Request{
Method: MethodAnnounce,
URL: c.URL,
Header: map[string][]string{
"Content-Type": {"application/sdp"},
},
}
req.Body, err = MarshalSDP(c.SessionName, medias)
if err != nil {
return err
}
_, err = c.Do(req)
return
}
func (c *Stream) SetupMedia(media *Media, index int) (byte, error) {
var transport string
transport = fmt.Sprintf(
// i - RTP (data channel)
// i+1 - RTCP (control channel)
"RTP/AVP/TCP;unicast;interleaved=%d-%d", index*2, index*2+1,
)
if transport == "" {
return 0, fmt.Errorf("wrong media: %v", media)
}
rawURL := media.ID // control
if !strings.Contains(rawURL, "://") {
rawURL = c.URL.String()
if !strings.HasSuffix(rawURL, "/") {
rawURL += "/"
}
rawURL += media.ID
}
trackURL, err := urlParse(rawURL)
if err != nil {
return 0, err
}
req := &util.Request{
Method: MethodSetup,
URL: trackURL,
Header: map[string][]string{
"Transport": {transport},
},
}
res, err := c.Do(req)
if err != nil {
// some Dahua/Amcrest cameras fail here because two simultaneous
// backchannel connections
//if c.Backchannel {
// c.Backchannel = false
// if err = c.Connect(); err != nil {
// return 0, err
// }
// return c.SetupMedia(media)
//}
return 0, err
}
if c.Session == "" {
// Session: 7116520596809429228
// Session: 216525287999;timeout=60
if s := res.Header.Get("Session"); s != "" {
if i := strings.IndexByte(s, ';'); i > 0 {
c.Session = s[:i]
if i = strings.Index(s, "timeout="); i > 0 {
c.keepalive, _ = strconv.Atoi(s[i+8:])
}
} else {
c.Session = s
}
}
}
// we send our `interleaved`, but camera can answer with another
// Transport: RTP/AVP/TCP;unicast;interleaved=10-11;ssrc=10117CB7
// Transport: RTP/AVP/TCP;unicast;destination=192.168.1.111;source=192.168.1.222;interleaved=0
// Transport: RTP/AVP/TCP;ssrc=22345682;interleaved=0-1
transport = res.Header.Get("Transport")
if !strings.HasPrefix(transport, "RTP/AVP/TCP;") {
// Escam Q6 has a bug:
// Transport: RTP/AVP;unicast;destination=192.168.1.111;source=192.168.1.222;interleaved=0-1
if !strings.Contains(transport, ";interleaved=") {
return 0, fmt.Errorf("wrong transport: %s", transport)
}
}
channel := Between(transport, "interleaved=", "-")
i, err := strconv.Atoi(channel)
if err != nil {
return 0, err
}
return byte(i), nil
}
func (c *Stream) Play() (err error) {
return c.WriteRequest(&util.Request{Method: MethodPlay, URL: c.URL})
}
func (c *Stream) Record() (err error) {
return c.WriteRequest(&util.Request{Method: MethodRecord, URL: c.URL})
}
func (c *Stream) Teardown() (err error) {
// allow TEARDOWN from any state (ex. ANNOUNCE > SETUP)
return c.WriteRequest(&util.Request{Method: MethodTeardown, URL: c.URL})
}
func (ns *Stream) disconnect() {
if ns != nil && ns.NetConnection != nil {
_ = ns.Teardown()
ns.NetConnection.Destroy()
}
}

View File

@@ -10,30 +10,18 @@ import (
"reflect"
)
type Stream struct {
*NetConnection
AudioChannelID int
VideoChannelID int
}
type Sender struct {
*m7s.Subscriber
Stream
*Stream
}
type Receiver struct {
*m7s.Publisher
Stream
*Stream
AudioCodecParameters *webrtc.RTPCodecParameters
VideoCodecParameters *webrtc.RTPCodecParameters
}
func (ns *Stream) Close() error {
if ns.NetConnection != nil {
ns.NetConnection.Destroy()
}
return nil
}
func (s *Sender) GetMedia() (medias []*Media, err error) {
if s.SubAudio && s.Publisher.PubAudio && s.Publisher.HasAudioTrack() {
audioTrack := s.Publisher.GetAudioTrack(reflect.TypeOf((*mrtp.RTPAudio)(nil)))
@@ -163,6 +151,7 @@ func (r *Receiver) Receive() (err error) {
var channelID byte
var buf []byte
for err == nil {
channelID, buf, err = r.NetConnection.Receive(false)
if err != nil {
return
@@ -184,7 +173,9 @@ func (r *Receiver) Receive() (err error) {
audioFrame.AddRecycleBytes(buf)
audioFrame.Packets = append(audioFrame.Packets, packet)
} else {
err = r.WriteAudio(audioFrame)
if err = r.WriteAudio(audioFrame); err != nil {
return
}
audioFrame = &mrtp.RTPAudio{}
audioFrame.AddRecycleBytes(buf)
audioFrame.Packets = []*rtp.Packet{packet}
@@ -204,7 +195,9 @@ func (r *Receiver) Receive() (err error) {
videoFrame.Packets = append(videoFrame.Packets, packet)
} else {
// t := time.Now()
err = r.WriteVideo(videoFrame)
if err = r.WriteVideo(videoFrame); err != nil {
return
}
// fmt.Println("write video", time.Since(t))
videoFrame = &mrtp.Video{}
videoFrame.AddRecycleBytes(buf)

View File

@@ -13,14 +13,14 @@ import (
"m7s.live/m7s/v5/plugin/stress/pb"
)
func (r *StressPlugin) pull(count int, format, url string, newFunc func() m7s.PullHandler) error {
func (r *StressPlugin) pull(count int, format, url string, puller m7s.Puller) error {
if i := r.pullers.Length; count > i {
for j := i; j < count; j++ {
puller, err := r.Pull(fmt.Sprintf("stress/%d", j), fmt.Sprintf(format, url))
ctx, err := r.Pull(fmt.Sprintf("stress/%d", j), fmt.Sprintf(format, url))
if err != nil {
return err
}
go r.startPull(puller, newFunc())
go r.startPull(ctx, puller)
}
} else if count < i {
for j := i; j > count; j-- {
@@ -31,14 +31,14 @@ func (r *StressPlugin) pull(count int, format, url string, newFunc func() m7s.Pu
return nil
}
func (r *StressPlugin) push(count int, streamPath, format, remoteHost string, newFunc func() m7s.PushHandler) (err error) {
func (r *StressPlugin) push(count int, streamPath, format, remoteHost string, pusher m7s.Pusher) (err error) {
if i := r.pushers.Length; count > i {
for j := i; j < count; j++ {
pusher, err := r.Push(streamPath, fmt.Sprintf(format, remoteHost, j))
ctx, err := r.Push(streamPath, fmt.Sprintf(format, remoteHost, j))
if err != nil {
return err
}
go r.startPush(pusher, newFunc())
go r.startPush(ctx, pusher)
}
} else if count < i {
for j := i; j > count; j-- {
@@ -50,34 +50,34 @@ func (r *StressPlugin) push(count int, streamPath, format, remoteHost string, ne
}
func (r *StressPlugin) PushRTMP(ctx context.Context, req *pb.PushRequest) (res *gpb.SuccessResponse, err error) {
return &gpb.SuccessResponse{}, r.push(int(req.PushCount), req.StreamPath, "rtmp://%s/stress/%d", req.RemoteHost, rtmp.NewPushHandler)
return &gpb.SuccessResponse{}, r.push(int(req.PushCount), req.StreamPath, "rtmp://%s/stress/%d", req.RemoteHost, rtmp.Client{}.DoPush)
}
func (r *StressPlugin) PushRTSP(ctx context.Context, req *pb.PushRequest) (res *gpb.SuccessResponse, err error) {
return &gpb.SuccessResponse{}, r.push(int(req.PushCount), req.StreamPath, "rtsp://%s/stress/%d", req.RemoteHost, rtsp.NewPushHandler)
return &gpb.SuccessResponse{}, r.push(int(req.PushCount), req.StreamPath, "rtsp://%s/stress/%d", req.RemoteHost, rtsp.Client{}.DoPush)
}
func (r *StressPlugin) PullRTMP(ctx context.Context, req *pb.PullRequest) (res *gpb.SuccessResponse, err error) {
return &gpb.SuccessResponse{}, r.pull(int(req.PullCount), "rtmp://%s", req.RemoteURL, rtmp.NewPullHandler)
return &gpb.SuccessResponse{}, r.pull(int(req.PullCount), "rtmp://%s", req.RemoteURL, rtmp.Client{}.DoPull)
}
func (r *StressPlugin) PullRTSP(ctx context.Context, req *pb.PullRequest) (res *gpb.SuccessResponse, err error) {
return &gpb.SuccessResponse{}, r.pull(int(req.PullCount), "rtsp://%s", req.RemoteURL, rtsp.NewPullHandler)
return &gpb.SuccessResponse{}, r.pull(int(req.PullCount), "rtsp://%s", req.RemoteURL, rtsp.Client{}.DoPull)
}
func (r *StressPlugin) PullHDL(ctx context.Context, req *pb.PullRequest) (res *gpb.SuccessResponse, err error) {
return &gpb.SuccessResponse{}, r.pull(int(req.PullCount), "http://%s", req.RemoteURL, hdl.NewPullHandler)
return &gpb.SuccessResponse{}, r.pull(int(req.PullCount), "http://%s", req.RemoteURL, hdl.PullFLV)
}
func (r *StressPlugin) startPush(pusher *m7s.Pusher, handler m7s.PushHandler) {
func (r *StressPlugin) startPush(pusher *m7s.PushContext, handler m7s.Pusher) {
r.pushers.AddUnique(pusher)
pusher.Start(handler)
pusher.Run(handler)
r.pushers.Remove(pusher)
}
func (r *StressPlugin) startPull(puller *m7s.Puller, handler m7s.PullHandler) {
func (r *StressPlugin) startPull(puller *m7s.PullContext, handler m7s.Puller) {
r.pullers.AddUnique(puller)
puller.Start(handler)
puller.Run(handler)
r.pullers.Remove(puller)
}

View File

@@ -10,8 +10,8 @@ import (
type StressPlugin struct {
pb.UnimplementedApiServer
m7s.Plugin
pushers util.Collection[string, *m7s.Pusher]
pullers util.Collection[string, *m7s.Puller]
pushers util.Collection[string, *m7s.PushContext]
pullers util.Collection[string, *m7s.PullContext]
}
var _ = m7s.InstallPlugin[StressPlugin](&pb.Api_ServiceDesc, pb.RegisterApiHandler)

View File

@@ -1,6 +1,7 @@
package m7s
import (
"context"
"math"
"os"
"path/filepath"
@@ -59,21 +60,22 @@ type AVTracks struct {
}
func (t *AVTracks) CreateSubTrack(dataType reflect.Type) (track *AVTrack) {
track = NewAVTrack(dataType, t.AVTrack, util.NewPromise(struct{}{}))
track = NewAVTrack(dataType, t.AVTrack, util.NewPromise(context.TODO()))
track.WrapIndex = t.Length
t.Add(track)
return
}
// createPublisher -> Start -> WriteAudio/WriteVideo -> Dispose
type Publisher struct {
PubSubBase
sync.RWMutex `json:"-" yaml:"-"`
sync.RWMutex
config.Publish
State PublisherState
AudioTrack, VideoTrack AVTracks
audioReady, videoReady *util.Promise[struct{}]
audioReady, videoReady *util.Promise
DataTrack *DataTrack
Subscribers util.Collection[int, *Subscriber] `json:"-" yaml:"-"`
Subscribers SubscriberCollection
GOP int
baseTs, lastTs time.Duration
dumpFile *os.File
@@ -87,6 +89,70 @@ func (p *Publisher) GetKey() string {
return p.StreamPath
}
func createPublisher(p *Plugin, streamPath string, options ...any) (publisher *Publisher) {
publisher = &Publisher{Publish: p.config.Publish}
publisher.ID = p.Server.streamTM.GetID()
publisher.Executor = publisher
publisher.Plugin = p
publisher.TimeoutTimer = time.NewTimer(p.config.PublishTimeout)
var opt = []any{p.Logger.With("streamPath", streamPath, "pId", publisher.ID)}
for _, option := range options {
switch v := option.(type) {
case func(*config.Publish):
v(&publisher.Publish)
default:
opt = append(opt, option)
}
}
publisher.Init(streamPath, &publisher.Publish, opt...)
return
}
func (p *Publisher) Start() (err error) {
s := p.Plugin.Server
if oldPublisher, ok := s.Streams.Get(p.StreamPath); ok {
if p.KickExist {
p.Warn("kick")
oldPublisher.Stop(ErrKick)
p.TakeOver(oldPublisher)
} else {
return ErrStreamExist
}
}
s.Streams.Set(p)
p.Info("publish")
p.audioReady = util.NewPromiseWithTimeout(p, time.Second*5)
p.videoReady = util.NewPromiseWithTimeout(p, time.Second*5)
if p.Dump {
f := filepath.Join("./dump", p.StreamPath)
os.MkdirAll(filepath.Dir(f), 0666)
p.dumpFile, _ = os.OpenFile(filepath.Join("./dump", p.StreamPath), os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0666)
}
if waiting, ok := s.Waiting.Get(p.StreamPath); ok {
p.TakeOver(waiting)
s.Waiting.Remove(waiting)
}
for plugin := range s.Plugins.Range {
if plugin.Disabled {
continue
}
if remoteURL := plugin.GetCommonConf().CheckPush(p.StreamPath); remoteURL != "" {
if _, ok := plugin.handler.(IPusherPlugin); ok {
go plugin.Push(p.StreamPath, remoteURL)
}
}
if filePath := plugin.GetCommonConf().CheckRecord(p.StreamPath); filePath != "" {
if _, ok := plugin.handler.(IRecorderPlugin); ok {
go plugin.Record(p.StreamPath, filePath)
}
}
//if h, ok := plugin.handler.(IOnPublishPlugin); ok {
// h.OnPublish(publisher)
//}
}
return
}
func (p *Publisher) timeout() (err error) {
switch p.State {
case PublisherStateInit:
@@ -179,17 +245,6 @@ func (p *Publisher) AddSubscriber(subscriber *Subscriber) {
}
}
func (p *Publisher) Start() {
p.Info("publish")
p.audioReady = util.NewPromiseWithTimeout(struct{}{}, time.Second*5)
p.videoReady = util.NewPromiseWithTimeout(struct{}{}, time.Second*5)
if p.Dump {
f := filepath.Join("./dump", p.StreamPath)
os.MkdirAll(filepath.Dir(f), 0666)
p.dumpFile, _ = os.OpenFile(filepath.Join("./dump", p.StreamPath), os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0666)
}
}
func (p *Publisher) writeAV(t *AVTrack, data IAVFrame) {
frame := &t.Value
frame.Wraps = append(frame.Wraps, data)
@@ -222,6 +277,9 @@ func (p *Publisher) WriteVideo(data IAVFrame) (err error) {
data.Recycle()
}
}()
if err = p.Err(); err != nil {
return
}
if p.dumpFile != nil {
data.Dump(1, p.dumpFile)
}
@@ -320,6 +378,9 @@ func (p *Publisher) WriteAudio(data IAVFrame) (err error) {
data.Recycle()
}
}()
if err = p.Err(); err != nil {
return
}
if p.dumpFile != nil {
data.Dump(0, p.dumpFile)
}
@@ -394,6 +455,9 @@ func (p *Publisher) WriteAudio(data IAVFrame) (err error) {
}
func (p *Publisher) WriteData(data IDataFrame) (err error) {
if err = p.Err(); err != nil {
return
}
if p.DataTrack == nil {
p.DataTrack = &DataTrack{}
p.DataTrack.Logger = p.Logger.With("track", "data")
@@ -441,16 +505,28 @@ func (p *Publisher) HasVideoTrack() bool {
return p.VideoTrack.Length > 0
}
func (p *Publisher) Dispose(err error) {
func (p *Publisher) Dispose() {
s := p.Plugin.Server
s.Streams.Remove(p)
if p.Subscribers.Length > 0 {
s.Waiting.Add(p)
}
p.Info("unpublish", "remain", s.Streams.Length, "reason", p.StopReason())
for subscriber := range p.SubscriberRange {
waitCloseTimeout := p.WaitCloseTimeout
if waitCloseTimeout == 0 {
waitCloseTimeout = subscriber.WaitTimeout
}
subscriber.TimeoutTimer.Reset(waitCloseTimeout)
}
p.Lock()
defer p.Unlock()
if p.dumpFile != nil {
p.dumpFile.Close()
}
if p.State == PublisherStateDisposed {
return
panic("disposed")
}
if p.IsStopped() {
if p.HasAudioTrack() {
p.AudioTrack.Dispose()
}
@@ -458,9 +534,6 @@ func (p *Publisher) Dispose(err error) {
p.VideoTrack.Dispose()
}
p.State = PublisherStateDisposed
return
}
p.Stop(err)
}
func (p *Publisher) TakeOver(old *Publisher) {
@@ -469,18 +542,16 @@ func (p *Publisher) TakeOver(old *Publisher) {
for subscriber := range old.SubscriberRange {
p.AddSubscriber(subscriber)
}
if old.Plugin != nil {
old.Dispose(nil)
}
old.Subscribers = util.Collection[int, *Subscriber]{}
old.Stop(ErrKick)
old.Subscribers = SubscriberCollection{}
}
func (p *Publisher) WaitTrack() (err error) {
if p.PubVideo {
_, err = p.videoReady.Await()
err = p.videoReady.Await()
}
if p.PubAudio {
_, err = p.audioReady.Await()
err = p.audioReady.Await()
}
return
}

107
puller.go
View File

@@ -1,66 +1,101 @@
package m7s
import (
"io"
"time"
"context"
"m7s.live/m7s/v5/pkg"
"m7s.live/m7s/v5/pkg/config"
"time"
)
type Client struct {
*PubSubBase
type Connection struct {
pkg.Task
Plugin *Plugin
StreamPath string // 对应本地流
RemoteURL string // 远程服务器地址(用于推拉)
ReConnectCount int //重连次数
Proxy string // 代理地址
ConnectProxy string // 连接代理
MetaData any
}
func (client *Client) reconnect(count int) (ok bool) {
func (client *Connection) reconnect(count int) (ok bool) {
ok = count == -1 || client.ReConnectCount <= count
client.ReConnectCount++
return
}
type PullHandler interface {
Connect(*Client) error
// Disconnect()
Pull(*Puller) error
type Puller = func(*PullContext) error
func createPullContext(p *Plugin, streamPath string, url string, options ...any) (pullCtx *PullContext) {
pullCtx = &PullContext{Pull: p.config.Pull}
pullCtx.ID = p.Server.pullTM.GetID()
pullCtx.Plugin = p
pullCtx.Executor = pullCtx
pullCtx.ConnectProxy = p.config.Pull.Proxy
pullCtx.RemoteURL = url
publishConfig := p.config.Publish
publishConfig.PublishTimeout = 0
pullCtx.StreamPath = streamPath
pullCtx.PublishOptions = []any{publishConfig}
var ctx = p.Context
for _, option := range options {
switch v := option.(type) {
case context.Context:
ctx = v
default:
pullCtx.PublishOptions = append(pullCtx.PublishOptions, option)
}
}
p.Init(ctx, p.Logger.With("pullURL", url, "streamPath", streamPath))
pullCtx.PublishOptions = append(pullCtx.PublishOptions, pullCtx.Context)
return
}
type Puller struct {
Client Client
Publisher
type PullContext struct {
Connection
Publisher *Publisher
PublishOptions []any
config.Pull
}
func (p *Puller) Start(handler PullHandler) (err error) {
badPuller := true
var startTime time.Time
for p.Info("start pull", "url", p.Client.RemoteURL); p.Client.reconnect(p.RePull); p.Warn("restart pull") {
if time.Since(startTime) < 5*time.Second {
func (p *PullContext) GetKey() string {
return p.StreamPath
}
func (p *PullContext) Run(puller Puller) {
var err error
defer p.Info("stop pull")
for p.Info("start pull", "url", p.Connection.RemoteURL); p.Connection.reconnect(p.RePull); p.Warn("restart pull") {
if p.Publisher != nil && time.Since(p.Publisher.StartTime) < 5*time.Second {
time.Sleep(5 * time.Second)
}
startTime = time.Now()
if err = handler.Connect(&p.Client); err != nil {
if err == io.EOF {
p.Info("pull complete")
return
if p.Publisher, err = p.Plugin.Publish(p.StreamPath, p.PublishOptions...); err != nil {
p.Error("pull publish failed", "error", err)
break
}
p.Error("pull connect", "error", err)
if badPuller {
err = puller(p)
p.Publisher.Stop(err)
if p.IsStopped() {
return
}
} else {
badPuller = false
p.Client.ReConnectCount = 0
if err = handler.Pull(p); err != nil && !p.IsStopped() {
p.Error("pull interrupt", "error", err)
}
}
if p.IsStopped() {
p.Info("stop pull")
return
if err == nil {
err = pkg.ErrRetryRunOut
}
// handler.Disconnect()
}
return nil
p.Stop(err)
}
func (p *PullContext) Start() (err error) {
s := p.Plugin.Server
if _, ok := s.Pulls.Get(p.GetKey()); ok {
return pkg.ErrStreamExist
}
s.Pulls.Add(p)
return
}
func (p *PullContext) Dispose() {
p.Plugin.Server.Pulls.Remove(p)
}

View File

@@ -1,57 +1,85 @@
package m7s
import (
"io"
"context"
"m7s.live/m7s/v5/pkg"
"time"
"m7s.live/m7s/v5/pkg/config"
)
type PushHandler interface {
Connect(*Client) error
// Disconnect()
Push(*Pusher) error
type Pusher = func(*PushContext) error
func createPushContext(p *Plugin, streamPath string, url string, options ...any) (pushCtx *PushContext) {
pushCtx = &PushContext{Push: p.config.Push}
pushCtx.ID = p.Server.pushTM.GetID()
pushCtx.Plugin = p
pushCtx.Executor = pushCtx
pushCtx.RemoteURL = url
pushCtx.StreamPath = streamPath
pushCtx.ConnectProxy = p.config.Push.Proxy
pushCtx.SubscribeOptions = []any{p.config.Subscribe}
var ctx = p.Context
for _, option := range options {
switch v := option.(type) {
case context.Context:
ctx = v
default:
pushCtx.SubscribeOptions = append(pushCtx.SubscribeOptions, option)
}
}
pushCtx.Init(ctx, p.Logger.With("pushURL", url, "streamPath", streamPath))
pushCtx.SubscribeOptions = append(pushCtx.SubscribeOptions, pushCtx.Context)
return
}
type Pusher struct {
Client Client
Subscriber
type PushContext struct {
Connection
Subscriber *Subscriber
SubscribeOptions []any
config.Push
}
func (p *Pusher) GetKey() string {
return p.Client.RemoteURL
func (p *PushContext) GetKey() string {
return p.RemoteURL
}
func (p *Pusher) Start(handler PushHandler) (err error) {
badPuller := true
var startTime time.Time
for p.Info("start push", "url", p.Client.RemoteURL); p.Client.reconnect(p.RePush); p.Warn("restart push") {
if time.Since(startTime) < 5*time.Second {
func (p *PushContext) Run(pusher Pusher) {
p.StartTime = time.Now()
defer p.Info("stop push")
var err error
for p.Info("start push", "url", p.Connection.RemoteURL); p.Connection.reconnect(p.RePush); p.Warn("restart push") {
if p.Subscriber != nil && time.Since(p.Subscriber.StartTime) < 5*time.Second {
time.Sleep(5 * time.Second)
}
startTime = time.Now()
if err = handler.Connect(&p.Client); err != nil {
if err == io.EOF {
p.Info("push complete")
return
if p.Subscriber, err = p.Plugin.Subscribe(p.StreamPath, p.SubscribeOptions...); err != nil {
p.Error("push subscribe failed", "error", err)
break
}
p.Error("push connect", "error", err)
if badPuller {
err = pusher(p)
p.Subscriber.Stop(err)
if p.IsStopped() {
return
}
} else {
badPuller = false
p.Client.ReConnectCount = 0
if err = handler.Push(p); err != nil && !p.IsStopped() {
p.Error("push interrupt", "error", err)
}
}
if p.IsStopped() {
p.Info("stop push")
if err == nil {
err = pkg.ErrRetryRunOut
}
p.Stop(err)
return
}
// handler.Disconnect()
}
return nil
}
func (p *PushContext) Start() (err error) {
s := p.Plugin.Server
if _, ok := s.Pushs.Get(p.GetKey()); ok {
return pkg.ErrPushRemoteURLExist
}
s.Pushs.Add(p)
return
}
func (p *PushContext) Dispose() {
p.Plugin.Server.Pushs.Remove(p)
}

View File

@@ -1,26 +1,66 @@
package m7s
import (
"m7s.live/m7s/v5/pkg/config"
"os"
"context"
"m7s.live/m7s/v5/pkg"
"time"
)
type RecordHandler interface {
Close()
Record(*Recorder) error
type Recorder = func(*RecordContext) error
func createRecoder(p *Plugin, streamPath string, filePath string, options ...any) (recorder *RecordContext) {
recorder = &RecordContext{
Plugin: p,
Fragment: p.config.Record.Fragment,
Append: p.config.Record.Append,
FilePath: filePath,
}
recorder.ID = p.Server.recordTM.GetID()
recorder.Executor = recorder
recorder.FilePath = filePath
recorder.SubscribeOptions = []any{p.config.Subscribe}
var ctx = p.Context
for _, option := range options {
switch v := option.(type) {
case context.Context:
ctx = v
default:
recorder.SubscribeOptions = append(recorder.SubscribeOptions, option)
}
}
recorder.Init(ctx, p.Logger.With("filePath", filePath, "streamPath", streamPath))
recorder.SubscribeOptions = append(recorder.SubscribeOptions, recorder.Context)
return
}
type Recorder struct {
File *os.File
Subscriber
config.Record
type RecordContext struct {
pkg.Task
Plugin *Plugin
Subscriber *Subscriber
SubscribeOptions []any
Fragment time.Duration
Append bool
FilePath string
}
func (p *Recorder) GetKey() string {
return p.File.Name()
func (p *RecordContext) GetKey() string {
return p.FilePath
}
func (p *Recorder) Start(handler RecordHandler) (err error) {
defer handler.Close()
return handler.Record(p)
func (p *RecordContext) Run(recorder Recorder) {
err := recorder(p)
p.Stop(err)
}
func (p *RecordContext) Start() (err error) {
s := p.Plugin.Server
if _, ok := s.Records.Get(p.GetKey()); ok {
return pkg.ErrRecordSamePath
}
s.Records.Add(p)
return
}
func (p *RecordContext) Dispose() {
p.Plugin.Server.Records.Remove(p)
}

412
server.go
View File

@@ -20,8 +20,6 @@ import (
"os"
"os/signal"
"path/filepath"
"reflect"
"slices"
"strings"
"sync/atomic"
"syscall"
@@ -33,13 +31,13 @@ var (
MergeConfigs = []string{"Publish", "Subscribe", "HTTP", "PublicIP", "LogLevel", "EnableAuth", "DB"}
ExecPath = os.Args[0]
ExecDir = filepath.Dir(ExecPath)
serverIndexG atomic.Uint32
DefaultServer = NewServer()
serverMeta = PluginMeta{
Name: "Global",
Version: Version,
}
Servers = make([]*Server, 10)
Servers util.Collection[uint32, *Server]
serverIdG atomic.Uint32
Routes = map[string]string{}
defaultLogHandler = console.NewHandler(os.Stdout, &console.HandlerOptions{TimeFormat: "15:04:05.000000"})
)
@@ -56,31 +54,34 @@ type Server struct {
pb.UnimplementedGlobalServer
Plugin
ServerConfig
eventChan chan any
//eventChan chan any
Plugins util.Collection[string, *Plugin]
Streams, Waiting util.Collection[string, *Publisher]
Pulls util.Collection[string, *Puller]
Pushs util.Collection[string, *Pusher]
Records util.Collection[string, *Recorder]
Subscribers util.Collection[int, *Subscriber]
Pulls util.Collection[string, *PullContext]
Pushs util.Collection[string, *PushContext]
Records util.Collection[string, *RecordContext]
Subscribers SubscriberCollection
LogHandler MultiLogHandler
pidG, sidG int
apiList []string
grpcServer *grpc.Server
grpcClientConn *grpc.ClientConn
lastSummaryTime time.Time
lastSummary *pb.SummaryResponse
OnAuthPubs map[string]func(p *util.Promise[*Publisher])
OnAuthSubs map[string]func(p *util.Promise[*Subscriber])
OnAuthPubs map[string]func(*Publisher) *util.Promise
OnAuthSubs map[string]func(*Subscriber) *util.Promise
pluginTM, streamTM, pullTM, pushTM, recordTM *TaskManager
runOption struct {
ctx context.Context
conf any
}
}
func NewServer() (s *Server) {
s = &Server{}
s.ID = int(serverIndexG.Add(1))
s.ID = serverIdG.Add(1)
s.Meta = &serverMeta
s.OnAuthPubs = make(map[string]func(p *util.Promise[*Publisher]))
s.OnAuthSubs = make(map[string]func(p *util.Promise[*Subscriber]))
Servers[s.ID] = s
s.OnAuthPubs = make(map[string]func(*Publisher) *util.Promise)
s.OnAuthSubs = make(map[string]func(*Subscriber) *util.Promise)
return
}
@@ -100,21 +101,11 @@ func init() {
}
}
func (s *Server) Run(ctx context.Context, conf any) (err error) {
s.StartTime = time.Now()
for err = s.run(ctx, conf); err == ErrRestart; err = s.run(ctx, conf) {
var server Server
server.ID = s.ID
server.Meta = s.Meta
server.OnAuthPubs = s.OnAuthPubs
server.OnAuthSubs = s.OnAuthSubs
server.DB = s.DB
*s = server
}
return
func (s *Server) GetKey() uint32 {
return s.ID
}
func (s *Server) run(ctx context.Context, conf any) (err error) {
func (s *Server) Start() (err error) {
s.Server = s
s.handler = s
s.config.HTTP.ListenAddrTLS = ":8443"
@@ -122,18 +113,16 @@ func (s *Server) run(ctx context.Context, conf any) (err error) {
s.config.TCP.ListenAddr = ":50051"
s.LogHandler.SetLevel(slog.LevelInfo)
s.LogHandler.Add(defaultLogHandler)
s.Logger = slog.New(&s.LogHandler).With("Server", s.ID)
s.Task.Init(s.runOption.ctx, slog.New(&s.LogHandler).With("Server", s.ID))
s.Info("start")
httpConf, tcpConf := &s.config.HTTP, &s.config.TCP
mux := runtime.NewServeMux(runtime.WithMarshalerOption("text/plain", &pb.TextPlain{}), runtime.WithRoutingErrorHandler(func(_ context.Context, _ *runtime.ServeMux, _ runtime.Marshaler, w http.ResponseWriter, r *http.Request, _ int) {
httpConf.GetHttpMux().ServeHTTP(w, r)
}))
httpConf.SetMux(mux)
s.Context, s.CancelCauseFunc = context.WithCancelCause(ctx)
s.Info("start")
var cg rawconfig
var configYaml []byte
switch v := conf.(type) {
switch v := s.runOption.conf.(type) {
case string:
if _, err = os.Stat(v); err != nil {
v = filepath.Join(ExecDir, v)
@@ -156,7 +145,7 @@ func (s *Server) run(ctx context.Context, conf any) (err error) {
if cg != nil {
s.Config.ParseUserFile(cg["global"])
}
s.eventChan = make(chan any, s.EventBusSize)
//s.eventChan = make(chan any, s.EventBusSize)
s.LogHandler.SetLevel(ParseLevel(s.config.LogLevel))
s.registerHandler(map[string]http.HandlerFunc{
"/api/config/json/{name}": s.api_Config_JSON_,
@@ -175,7 +164,7 @@ func (s *Server) run(ctx context.Context, conf any) (err error) {
if httpConf.ListenAddrTLS != "" {
s.Info("https listen at ", "addr", httpConf.ListenAddrTLS)
go func(addr string) {
if err := httpConf.ListenTLS(); err != http.ErrServerClosed {
if err = httpConf.ListenTLS(); err != http.ErrServerClosed {
s.Stop(err)
}
s.Info("https stop listen at ", "addr", addr)
@@ -184,7 +173,7 @@ func (s *Server) run(ctx context.Context, conf any) (err error) {
if httpConf.ListenAddr != "" {
s.Info("http listen at ", "addr", httpConf.ListenAddr)
go func(addr string) {
if err := httpConf.Listen(); err != http.ErrServerClosed {
if err = httpConf.Listen(); err != http.ErrServerClosed {
s.Stop(err)
}
s.Info("http stop listen at ", "addr", addr)
@@ -196,107 +185,62 @@ func (s *Server) run(ctx context.Context, conf any) (err error) {
s.grpcServer = grpc.NewServer(opts...)
pb.RegisterGlobalServer(s.grpcServer, s)
s.grpcClientConn, err = grpc.DialContext(ctx, tcpConf.ListenAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
s.grpcClientConn, err = grpc.DialContext(s.Context, tcpConf.ListenAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
s.Error("failed to dial", "error", err)
return err
return
}
defer s.grpcClientConn.Close()
if err = pb.RegisterGlobalHandler(ctx, mux, s.grpcClientConn); err != nil {
if err = pb.RegisterGlobalHandler(s.Context, mux, s.grpcClientConn); err != nil {
s.Error("register handler faild", "error", err)
return err
return
}
tcplis, err = net.Listen("tcp", tcpConf.ListenAddr)
if err != nil {
s.Error("failed to listen", "error", err)
return err
return
}
defer tcplis.Close()
}
signalChan := make(chan os.Signal, 1)
signal.Notify(signalChan, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT)
s.pluginTM = NewTaskManager()
go s.pluginTM.Run(signalChan, func() {
for plugin := range s.Plugins.Range {
plugin.handler.OnExit()
}
})
for _, plugin := range plugins {
plugin.Init(s, cg[strings.ToLower(plugin.Name)])
if p := plugin.Init(s, cg[strings.ToLower(plugin.Name)]); !p.Disabled {
s.pluginTM.Add(&p.Task)
}
}
if tcplis != nil {
go func(addr string) {
if err := s.grpcServer.Serve(tcplis); err != nil {
if err = s.grpcServer.Serve(tcplis); err != nil {
s.Stop(err)
}
s.Info("grpc stop listen at ", "addr", addr)
}(tcpConf.ListenAddr)
}
s.eventLoop()
err = context.Cause(s)
s.Warn("Server is done", "reason", err)
for publisher := range s.Streams.Range {
publisher.Stop(err)
}
for subscriber := range s.Subscribers.Range {
subscriber.Stop(err)
}
for p := range s.Plugins.Range {
p.Stop(err)
}
httpConf.StopListen()
return
}
type DoneChan = <-chan struct{}
func (s *Server) doneEventLoop(input chan DoneChan, output chan int) {
cases := []reflect.SelectCase{{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(input)}}
for {
switch chosen, rev, ok := reflect.Select(cases); chosen {
case 0:
if !ok {
return
}
cases = append(cases, reflect.SelectCase{Dir: reflect.SelectRecv, Chan: rev})
default:
output <- chosen - 1
cases = slices.Delete(cases, chosen, chosen+1)
}
}
}
// eventLoop powerful grateful graceful beautiful
func (s *Server) eventLoop() {
pulse := time.NewTicker(s.PulseInterval)
defer pulse.Stop()
pubChan := make(chan DoneChan, 10)
pubDoneChan := make(chan int, 10)
subChan := make(chan DoneChan, 10)
subDoneChan := make(chan int, 10)
defer close(pubChan)
defer close(subChan)
go s.doneEventLoop(pubChan, pubDoneChan)
go s.doneEventLoop(subChan, subDoneChan)
signalChan := make(chan os.Signal, 1)
signal.Notify(signalChan, os.Interrupt, syscall.SIGTERM)
defer signal.Stop(signalChan)
for {
select {
case <-signalChan:
for plugin := range s.Plugins.Range {
if plugin.Disabled {
continue
}
plugin.handler.OnExit()
}
case <-s.Done():
return
case <-pulse.C:
s.streamTM = NewTaskManager()
s.pullTM = NewTaskManager()
s.pushTM = NewTaskManager()
s.recordTM = NewTaskManager()
go s.streamTM.Run(time.NewTicker(s.PulseInterval).C, func(time.Time) {
for publisher := range s.Streams.Range {
if err := publisher.checkTimeout(); err != nil {
publisher.Stop(err)
}
}
for publisher := range s.Waiting.Range {
if publisher.Plugin != nil {
if err := publisher.checkTimeout(); err != nil {
publisher.Dispose(err)
s.createWait(publisher.StreamPath)
}
}
// TODO: ?
//if publisher.Plugin != nil {
// if err := publisher.checkTimeout(); err != nil {
// publisher.Stop(err)
// s.createWait(publisher.StreamPath)
// }
//}
for sub := range publisher.SubscriberRange {
select {
case <-sub.TimeoutTimer.C:
@@ -305,207 +249,61 @@ func (s *Server) eventLoop() {
}
}
}
case pubDone := <-pubDoneChan:
s.onUnpublish(s.Streams.Items[pubDone])
case subDone := <-subDoneChan:
s.onUnsubscribe(s.Subscribers.Items[subDone])
case event := <-s.eventChan:
switch v := event.(type) {
case *util.Promise[any]:
switch vv := v.Value.(type) {
case func():
vv()
v.Fulfill(nil)
continue
case func() error:
v.Fulfill(vv())
continue
case *Publisher:
err := s.OnPublish(vv)
if v.Fulfill(err); err != nil {
continue
}
event = vv
pubChan <- vv.Done()
case *Subscriber:
err := s.OnSubscribe(vv)
if v.Fulfill(err); err != nil {
continue
}
subChan <- vv.Done()
if !s.EnableSubEvent {
continue
}
event = v.Value
case *Puller:
if _, ok := s.Pulls.Get(vv.GetKey()); ok {
v.Fulfill(ErrStreamExist)
continue
} else {
err := s.OnPublish(&vv.Publisher)
v.Fulfill(err)
if err != nil {
continue
}
s.Pulls.Add(vv)
pubChan <- vv.Done()
event = v.Value
}
case *Pusher:
if _, ok := s.Pushs.Get(vv.GetKey()); ok {
v.Fulfill(ErrStreamExist)
continue
} else {
err := s.OnSubscribe(&vv.Subscriber)
v.Fulfill(err)
if err != nil {
continue
}
subChan <- vv.Done()
s.Pushs.Add(vv)
event = v.Value
}
case *Recorder:
if _, ok := s.Records.Get(vv.GetKey()); ok {
v.Fulfill(ErrStreamExist)
continue
} else {
err := s.OnSubscribe(&vv.Subscriber)
v.Fulfill(err)
if err != nil {
continue
}
subChan <- vv.Done()
s.Records.Add(vv)
event = v.Value
}
}
case slog.Handler:
s.LogHandler.Add(v)
}
for plugin := range s.Plugins.Range {
if plugin.Disabled {
continue
}
plugin.onEvent(event)
}
}
}
})
go s.pullTM.Run()
go s.pushTM.Run()
go s.recordTM.Run()
Servers.Add(s)
return
}
func (s *Server) onUnsubscribe(subscriber *Subscriber) {
s.Subscribers.Remove(subscriber)
s.Info("unsubscribe", "streamPath", subscriber.StreamPath, "reason", subscriber.StopReason())
if subscriber.Closer != nil {
subscriber.Close()
}
for pusher := range s.Pushs.Range {
if &pusher.Subscriber == subscriber {
s.Pushs.Remove(pusher)
break
}
}
if subscriber.Publisher != nil {
subscriber.Publisher.RemoveSubscriber(subscriber)
}
func (s *Server) Call(callback func()) {
s.streamTM.Call(callback)
}
func (s *Server) onUnpublish(publisher *Publisher) {
s.Streams.Remove(publisher)
if publisher.Subscribers.Length > 0 {
s.Waiting.Add(publisher)
}
s.Info("unpublish", "streamPath", publisher.StreamPath, "count", s.Streams.Length, "reason", publisher.StopReason())
for subscriber := range publisher.SubscriberRange {
waitCloseTimeout := publisher.WaitCloseTimeout
if waitCloseTimeout == 0 {
waitCloseTimeout = subscriber.WaitTimeout
}
subscriber.TimeoutTimer.Reset(waitCloseTimeout)
}
if publisher.Closer != nil {
_ = publisher.Close()
}
s.Pulls.RemoveByKey(publisher.StreamPath)
func (s *Server) Dispose() {
Servers.Remove(s)
s.config.HTTP.StopListen()
err := context.Cause(s)
s.streamTM.ShutDown(err)
s.pullTM.ShutDown(err)
s.pushTM.ShutDown(err)
s.recordTM.ShutDown(err)
s.pluginTM.ShutDown(err)
s.Warn("Server is done", "reason", err)
}
func (s *Server) OnPublish(publisher *Publisher) error {
if oldPublisher, ok := s.Streams.Get(publisher.StreamPath); ok {
if publisher.KickExist {
publisher.Warn("kick")
oldPublisher.Stop(ErrKick)
publisher.TakeOver(oldPublisher)
} else {
return ErrStreamExist
func (s *Server) Run(ctx context.Context, conf any) (err error) {
for {
s.runOption.ctx = ctx
s.runOption.conf = conf
if err = s.Start(); err != nil {
return
}
<-s.Done()
s.Dispose()
if err = context.Cause(s); err != ErrRestart {
return
}
s.Streams.Set(publisher)
s.pidG++
p := publisher.Plugin
publisher.ID = s.pidG
publisher.Logger = p.With("streamPath", publisher.StreamPath, "pubID", publisher.ID)
publisher.TimeoutTimer = time.NewTimer(p.config.PublishTimeout)
publisher.Start()
if waiting, ok := s.Waiting.Get(publisher.StreamPath); ok {
publisher.TakeOver(waiting)
s.Waiting.Remove(waiting)
var server Server
server.ID = s.ID
server.Meta = s.Meta
server.OnAuthPubs = s.OnAuthPubs
server.OnAuthSubs = s.OnAuthSubs
server.DB = s.DB
*s = server
}
for plugin := range s.Plugins.Range {
if plugin.Disabled {
continue
}
if remoteURL := plugin.GetCommonConf().CheckPush(publisher.StreamPath); remoteURL != "" {
if _, ok := plugin.handler.(IPusherPlugin); ok {
go plugin.Push(publisher.StreamPath, remoteURL)
}
}
if filePath := plugin.GetCommonConf().CheckRecord(publisher.StreamPath); filePath != "" {
if _, ok := plugin.handler.(IRecorderPlugin); ok {
go plugin.Record(publisher.StreamPath, filePath)
}
}
}
return nil
}
func (s *Server) createWait(streamPath string) *Publisher {
newPublisher := &Publisher{}
s.pidG++
newPublisher.ID = s.pidG
newPublisher.Logger = s.Logger.With("pubID", newPublisher.ID, "streamPath", streamPath)
newPublisher.Logger = s.Logger.With("streamPath", streamPath)
s.Info("createWait")
newPublisher.StreamPath = streamPath
s.Waiting.Set(newPublisher)
return newPublisher
}
func (s *Server) OnSubscribe(subscriber *Subscriber) error {
s.sidG++
subscriber.ID = s.sidG
subscriber.Logger = subscriber.Plugin.With("streamPath", subscriber.StreamPath, "subID", subscriber.ID)
subscriber.TimeoutTimer = time.NewTimer(subscriber.Plugin.config.Subscribe.WaitTimeout)
s.Subscribers.Add(subscriber)
subscriber.Info("subscribe")
if publisher, ok := s.Streams.Get(subscriber.StreamPath); ok {
publisher.AddSubscriber(subscriber)
} else if publisher, ok = s.Waiting.Get(subscriber.StreamPath); ok {
publisher.AddSubscriber(subscriber)
} else {
s.createWait(subscriber.StreamPath).AddSubscriber(subscriber)
for plugin := range s.Plugins.Range {
if plugin.Disabled {
continue
}
if remoteURL := plugin.GetCommonConf().Pull.CheckPullOnSub(subscriber.StreamPath); remoteURL != "" {
if _, ok := plugin.handler.(IPullerPlugin); ok {
go plugin.Pull(subscriber.StreamPath, remoteURL)
}
}
}
}
return nil
}
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/favicon.ico" {
http.ServeFile(w, r, "favicon.ico")
@@ -520,17 +318,17 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
}
func (s *Server) Call(arg any) (result any, err error) {
promise := util.NewPromise(arg)
s.eventChan <- promise
<-promise.Done()
result = promise.Value
if err = context.Cause(promise.Context); err == util.ErrResolve {
err = nil
}
return
}
func (s *Server) PostMessage(msg any) {
s.eventChan <- msg
}
//func (s *Server) Call(arg any) (result any, err error) {
// promise := util.NewPromise(arg)
// s.eventChan <- promise
// <-promise.Done()
// result = promise.Value
// if err = context.Cause(promise.Context); err == util.ErrResolve {
// err = nil
// }
// return
//}
//
//func (s *Server) PostMessage(msg any) {
// s.eventChan <- msg
//}

View File

@@ -3,10 +3,8 @@ package m7s
import (
"context"
"errors"
"io"
"net"
"log/slog"
"net/url"
"os"
"reflect"
"runtime"
"strings"
@@ -19,46 +17,29 @@ import (
var AVFrameType = reflect.TypeOf((*AVFrame)(nil))
type Owner struct {
Conn net.Conn
File *os.File
MetaData any
io.Closer
}
type PubSubBase struct {
Unit[int]
Owner
Task
Plugin *Plugin
StreamPath string
Args url.Values
TimeoutTimer *time.Timer
MetaData any
}
func (p *PubSubBase) GetKey() int {
return p.ID
}
func (ps *PubSubBase) Init(p *Plugin, streamPath string, conf any, options ...any) {
ps.Plugin = p
ctx := p.Context
func (ps *PubSubBase) Init(streamPath string, conf any, options ...any) {
ctx := ps.Plugin.Context
var logger *slog.Logger
for _, option := range options {
switch v := option.(type) {
case *slog.Logger:
logger = v
case context.Context:
ctx = v
case net.Conn:
ps.Conn = v
ps.Closer = v
case *os.File:
ps.File = v
ps.Closer = v
case io.Closer:
ps.Closer = v
default:
ps.MetaData = v
}
}
ps.Context, ps.CancelCauseFunc = context.WithCancelCause(ctx)
ps.Task.Init(ctx, logger)
if u, err := url.Parse(streamPath); err == nil {
ps.StreamPath, ps.Args = u.Path, u.Query()
}
@@ -80,8 +61,11 @@ func (ps *PubSubBase) Init(p *Plugin, streamPath string, conf any, options ...an
c.ParseModifyFile(cc)
}
ps.StartTime = time.Now()
}
type SubscriberCollection = util.Collection[uint32, *Subscriber]
type Subscriber struct {
PubSubBase
config.Subscribe
@@ -90,6 +74,57 @@ type Subscriber struct {
VideoReader *AVRingReader
}
func createSubscriber(p *Plugin, streamPath string, options ...any) *Subscriber {
subscriber := &Subscriber{Subscribe: p.config.Subscribe}
subscriber.ID = p.Server.streamTM.GetID()
subscriber.Plugin = p
subscriber.Executor = subscriber
subscriber.TimeoutTimer = time.NewTimer(subscriber.WaitTimeout)
var opt = []any{p.Logger.With("streamPath", streamPath, "sId", subscriber.ID)}
for _, option := range options {
switch v := option.(type) {
case func(*config.Subscribe):
v(&subscriber.Subscribe)
default:
opt = append(opt, option)
}
}
subscriber.Init(streamPath, &subscriber.Subscribe, opt...)
if subscriber.Subscribe.BufferTime > 0 {
subscriber.Subscribe.SubMode = SUBMODE_BUFFER
}
return subscriber
}
func (s *Subscriber) Start() (err error) {
server := s.Plugin.Server
server.Subscribers.Add(s)
s.Info("subscribe")
if publisher, ok := server.Streams.Get(s.StreamPath); ok {
publisher.AddSubscriber(s)
} else if publisher, ok = server.Waiting.Get(s.StreamPath); ok {
publisher.AddSubscriber(s)
} else {
server.createWait(s.StreamPath).AddSubscriber(s)
for plugin := range server.Plugins.Range {
if remoteURL := plugin.GetCommonConf().Pull.CheckPullOnSub(s.StreamPath); remoteURL != "" {
if _, ok := plugin.handler.(IPullerPlugin); ok {
go plugin.Pull(s.StreamPath, remoteURL)
}
}
}
}
return
}
func (s *Subscriber) Dispose() {
s.Plugin.Server.Subscribers.Remove(s)
s.Info("unsubscribe", "reason", s.StopReason())
if s.Publisher != nil {
s.Publisher.RemoveSubscriber(s)
}
}
func (s *Subscriber) createAudioReader(dataType reflect.Type, startAudioTs time.Duration) (awi int) {
if s.Publisher == nil || dataType == nil {
return
@@ -173,6 +208,7 @@ func PlayBlock0[A any, V any](s *Subscriber, handler SubscribeHandler[A, V]) (er
awi := s.createAudioReader(a1, startAudioTs)
vwi := s.createVideoReader(v1, startVideoTs)
defer func() {
s.Stop(err)
if s.AudioReader != nil {
s.AudioReader.StopRead()
}