package plugin_stress

import (
	"context"
	"fmt"
	"slices"
	"strings"

	"github.com/mcuadros/go-defaults"
	"google.golang.org/protobuf/types/known/emptypb"

	"m7s.live/v5"
	gpb "m7s.live/v5/pb"
	"m7s.live/v5/pkg/config"
	"m7s.live/v5/pkg/task"
	flv "m7s.live/v5/plugin/flv/pkg"
	mp4 "m7s.live/v5/plugin/mp4/pkg"
	rtmp "m7s.live/v5/plugin/rtmp/pkg"
	rtsp "m7s.live/v5/plugin/rtsp/pkg"
	srt "m7s.live/v5/plugin/srt/pkg"
	"m7s.live/v5/plugin/stress/pb"
)

// stressJobURL returns the remote URL for the index-th stress job.
//
// If template contains a "%d" placeholder it is expanded with index via
// fmt.Sprintf. Otherwise template is returned unchanged, so every job targets
// the same URL. Passing a placeholder-free template straight to Sprintf would
// append "%!(EXTRA int=N)" and corrupt the URL.
func stressJobURL(template string, index int) string {
	if strings.Contains(template, "%d") {
		return fmt.Sprintf(template, index)
	}
	return template
}

// pull scales the set of stress pullers to exactly count.
//
// If count exceeds the current number of pullers, new pullers are created for
// indices [current, count). Each one is built by the puller factory, named
// "stress/<index>", and pulls from url (with "%d" expanded to the index). The
// call blocks until each new puller has started. It stops at the first puller
// that fails to start and returns that error, wrapped with the index.
//
// If count is below the current number, the trailing entries of
// r.pullers.Items are stopped with task.ErrStopByUser until count remain.
// A negative count is rejected with an error.
func (r *StressPlugin) pull(count int, url string, testMode int32, puller m7s.PullerFactory) (err error) {
	if count < 0 {
		return fmt.Errorf("invalid pull count %d: must not be negative", count)
	}
	current := r.pullers.Length
	switch {
	case count > current:
		for j := current; j < count; j++ {
			conf := config.Pull{}
			defaults.SetDefaults(&conf)
			conf.TestMode = int(testMode)
			conf.URL = stressJobURL(url, j)
			p := puller(conf)
			job := p.GetPullJob().Init(p, &r.Plugin, fmt.Sprintf("stress/%d", j), conf, nil)
			if err = job.WaitStarted(); err != nil {
				return fmt.Errorf("starting stress puller %d: %w", j, err)
			}
			if r.pullers.AddUnique(job) {
				// Keep the collection in sync when the job ends for any reason.
				job.OnDispose(func() {
					r.pullers.Remove(job)
				})
			} else {
				job.Stop(task.ErrExist)
			}
		}
	case count < current:
		// Work on a snapshot: each job's OnDispose hook removes it from
		// r.pullers, so the live slice may change while we stop jobs.
		snapshot := slices.Clone(r.pullers.Items)
		for j := current; j > count; j-- {
			snapshot[j-1].Stop(task.ErrStopByUser)
		}
	}
	return nil
}

// push scales the set of stress pushers to exactly count.
//
// If count exceeds the current number of pushers, new pushers are created for
// indices [current, count). Each one publishes streamPath to url (with "%d"
// expanded to the index). The call blocks until each new pusher has started.
// It stops at the first pusher that fails to start and returns that error,
// wrapped with the index.
//
// If count is below the current number, the trailing entries of
// r.pushers.Items are stopped with task.ErrStopByUser until count remain.
// A negative count is rejected with an error.
func (r *StressPlugin) push(count int, streamPath, url string, pusher m7s.PusherFactory) (err error) {
	if count < 0 {
		return fmt.Errorf("invalid push count %d: must not be negative", count)
	}
	current := r.pushers.Length
	switch {
	case count > current:
		for j := current; j < count; j++ {
			p := pusher()
			conf := config.Push{URL: stressJobURL(url, j)}
			defaults.SetDefaults(&conf)
			job := p.GetPushJob().Init(p, &r.Plugin, streamPath, conf, nil)
			if err = job.WaitStarted(); err != nil {
				return fmt.Errorf("starting stress pusher %d: %w", j, err)
			}
			if r.pushers.AddUnique(job) {
				// Keep the collection in sync when the job ends for any reason.
				job.OnDispose(func() {
					r.pushers.Remove(job)
				})
			} else {
				job.Stop(task.ErrExist)
			}
		}
	case count < current:
		// Work on a snapshot: each job's OnDispose hook removes it from
		// r.pushers, so the live slice may change while we stop jobs.
		snapshot := slices.Clone(r.pushers.Items)
		for j := current; j > count; j-- {
			snapshot[j-1].Stop(task.ErrStopByUser)
		}
	}
	return nil
}

// StartPush scales the number of stress pushers for req.StreamPath to
// req.PushCount using the pusher for req.Protocol ("rtmp", "rtsp" or "srt").
// An unknown protocol or a failed scale-up returns a nil response and the error.
func (r *StressPlugin) StartPush(ctx context.Context, req *pb.PushRequest) (res *gpb.SuccessResponse, err error) {
	var pusher m7s.PusherFactory
	switch req.Protocol {
	case "rtmp":
		pusher = rtmp.NewPusher
	case "rtsp":
		pusher = rtsp.NewPusher
	case "srt":
		pusher = srt.NewPusher
	default:
		return nil, fmt.Errorf("unsupported protocol %q", req.Protocol)
	}
	if err = r.push(int(req.PushCount), req.StreamPath, req.RemoteURL, pusher); err != nil {
		return nil, err
	}
	return &gpb.SuccessResponse{}, nil
}

// StartPull scales the number of stress pullers to req.PullCount using the
// puller for req.Protocol ("rtmp", "rtsp", "srt", "flv" or "mp4").
// An unknown protocol or a failed scale-up returns a nil response and the error.
func (r *StressPlugin) StartPull(ctx context.Context, req *pb.PullRequest) (res *gpb.SuccessResponse, err error) {
	var puller m7s.PullerFactory
	switch req.Protocol {
	case "rtmp":
		puller = rtmp.NewPuller
	case "rtsp":
		puller = rtsp.NewPuller
	case "srt":
		puller = srt.NewPuller
	case "flv":
		puller = flv.NewPuller
	case "mp4":
		puller = mp4.NewPuller
	default:
		return nil, fmt.Errorf("unsupported protocol %q", req.Protocol)
	}
	if err = r.pull(int(req.PullCount), req.RemoteURL, req.TestMode, puller); err != nil {
		return nil, err
	}
	return &gpb.SuccessResponse{}, nil
}

// StopPush stops every stress pusher with task.ErrStopByUser.
func (r *StressPlugin) StopPush(ctx context.Context, req *emptypb.Empty) (res *gpb.SuccessResponse, err error) {
	// Iterate a snapshot because OnDispose hooks remove jobs from r.pushers.
	for _, pusher := range slices.Clone(r.pushers.Items) {
		pusher.Stop(task.ErrStopByUser)
	}
	return &gpb.SuccessResponse{}, nil
}

// StopPull stops every stress puller with task.ErrStopByUser.
func (r *StressPlugin) StopPull(ctx context.Context, req *emptypb.Empty) (res *gpb.SuccessResponse, err error) {
	// Iterate a snapshot because OnDispose hooks remove jobs from r.pullers.
	for _, puller := range slices.Clone(r.pullers.Items) {
		puller.Stop(task.ErrStopByUser)
	}
	return &gpb.SuccessResponse{}, nil
}

// GetCount reports the current number of tracked stress pullers and pushers.
func (r *StressPlugin) GetCount(ctx context.Context, req *emptypb.Empty) (res *pb.CountResponse, err error) {
	return &pb.CountResponse{
		Data: &pb.CountResponseData{
			PullCount: uint32(r.pullers.Length),
			PushCount: uint32(r.pushers.Length),
		},
	}, nil
}