Files
monibuca/plugin/test/api.go
2025-09-24 12:14:49 +08:00

301 lines
8.1 KiB
Go

package plugin_test
import (
"context"
"fmt"
"net/http"
"net/url"
"strings"
"time"
"github.com/mcuadros/go-defaults"
"golang.org/x/exp/slices"
"google.golang.org/protobuf/types/known/durationpb"
"google.golang.org/protobuf/types/known/emptypb"
"google.golang.org/protobuf/types/known/timestamppb"
"m7s.live/v5"
pb "m7s.live/v5/pb"
"m7s.live/v5/pkg/config"
"m7s.live/v5/pkg/task"
"m7s.live/v5/pkg/util"
flv "m7s.live/v5/plugin/flv/pkg"
hls "m7s.live/v5/plugin/hls/pkg"
mp4 "m7s.live/v5/plugin/mp4/pkg"
rtmp "m7s.live/v5/plugin/rtmp/pkg"
rtsp "m7s.live/v5/plugin/rtsp/pkg"
srt "m7s.live/v5/plugin/srt/pkg"
testpb "m7s.live/v5/plugin/test/pb"
webrtc "m7s.live/v5/plugin/webrtc/pkg"
)
// ========== Protobuf 转换函数 ========== //
// ToPBTestCase converts an internal TestCase into its protobuf representation.
// A nil input yields nil. StartTime and EndTime hold Unix seconds and are always
// emitted as timestamps (a zero value therefore becomes the Unix epoch).
func ToPBTestCase(tc *TestCase) *testpb.TestCase {
	if tc == nil {
		return nil
	}
	started := timestamppb.New(time.Unix(tc.StartTime, 0))
	ended := timestamppb.New(time.Unix(tc.EndTime, 0))
	out := &testpb.TestCase{
		// Identity and configuration.
		Name:        tc.Name,
		Description: tc.Description,
		Tags:        tc.Tags,
		Timeout:     durationpb.New(tc.Timeout),
		Tasks:       ToPBTestTasks(tc.Tasks),
		VideoCodec:  tc.VideoCodec,
		AudioCodec:  tc.AudioCodec,
		VideoOnly:   tc.VideoOnly,
		AudioOnly:   tc.AudioOnly,
		// Execution state.
		Status:    string(tc.Status),
		StartTime: started,
		EndTime:   ended,
		Duration:  tc.Duration,
		ErrorMsg:  tc.ErrorMsg,
		Logs:      tc.Logs,
	}
	return out
}
// ToPBTestTasks converts task configurations into protobuf TestTask messages,
// preserving their order. The returned slice is never nil, even for empty input.
func ToPBTestTasks(tasks []TestTaskConfig) []*testpb.TestTask {
	out := make([]*testpb.TestTask, len(tasks))
	for i := range tasks {
		cfg := &tasks[i]
		out[i] = &testpb.TestTask{
			Action: cfg.Action,
			Delay:  durationpb.New(cfg.Delay),
			Format: cfg.Format,
		}
	}
	return out
}
// ========== Protobuf Gateway API 实现 ========== //
// ListTestCases returns the cached test cases matching the requested tags and
// status, converted to protobuf form, wrapped in a success response.
func (p *TestPlugin) ListTestCases(ctx context.Context, req *testpb.ListTestCasesRequest) (*testpb.ListTestCasesResponse, error) {
	cached := p.GetTestCasesFromCache(TestCaseFilter{
		Tags:   req.Tags,
		Status: TestCaseStatus(req.Status),
	})
	data := make([]*testpb.TestCase, 0, len(cached))
	for _, item := range cached {
		data = append(data, ToPBTestCase(item))
	}
	resp := &testpb.ListTestCasesResponse{
		Code:    0,
		Message: "success",
		Data:    data,
	}
	return resp, nil
}
// ExecuteTestCase (re)starts each named test case from the cache. Names that
// are not cached, or whose case is already starting or running, are skipped.
// Every case that is started gets a fresh task.Job and has its previous error
// message and logs cleared before it is handed to p.AddTask.
func (p *TestPlugin) ExecuteTestCase(ctx context.Context, req *testpb.ExecuteTestCaseRequest) (*pb.SuccessResponse, error) {
	for _, name := range req.Names {
		tc, found := p.GetTestCaseFromCache(name)
		if !found {
			continue
		}
		switch tc.Status {
		case TestCaseStatusRunning, TestCaseStatusStarting:
			// Already in flight; do not start it a second time.
			continue
		}
		tc.Job = &task.Job{}
		tc.ErrorMsg = ""
		tc.Logs = ""
		p.AddTask(tc)
	}
	return &pb.SuccessResponse{Code: 0, Message: "success"}, nil
}
// GetTestCaseSSE streams cached test cases to the client through util.NewSSE.
// It accepts two optional query parameters: "tags" (comma-separated) and "status".
// A snapshot is written immediately. Another snapshot is written whenever
// p.flushSSE delivers a signal, and also every 5 seconds. Streaming stops when
// sse.Context is done or a write fails.
func (p *TestPlugin) GetTestCaseSSE(w http.ResponseWriter, r *http.Request) {
	q := r.URL.Query()
	filter := TestCaseFilter{Status: TestCaseStatus(q.Get("status"))}
	if rawTags := q.Get("tags"); rawTags != "" {
		filter.Tags = strings.Split(rawTags, ",")
	}
	util.NewSSE(w, r.Context(), func(sse *util.SSE) {
		send := func() error {
			return sse.WriteJSON(p.GetTestCasesFromCache(filter))
		}
		if err := send(); err != nil {
			return
		}
		// Periodic refresh so clients see progress even without explicit flush signals.
		ticker := time.NewTicker(5 * time.Second)
		defer ticker.Stop()
		for {
			select {
			case <-sse.Context.Done():
				return
			case <-p.flushSSE:
			case <-ticker.C:
			}
			if err := send(); err != nil {
				return
			}
		}
	})
}
// ========== Stress 测试相关 API 实现 ========== //
// pull resizes the set of stress-test pullers to exactly count.
//
// Growing: new pullers are created with the factory `puller` and named
// "stress/<index>". If url contains a "%d" placeholder, it is expanded with
// the puller index; otherwise every puller uses url verbatim. Each new puller
// is waited on until started. The first start error is returned immediately,
// and pullers that already started are kept.
//
// Shrinking: the surplus pullers at the end of the list are stopped with
// task.ErrStopByUser.
//
// A negative count is rejected with an error. Without this check, the shrink
// loop would index clone[-1] and panic.
func (p *TestPlugin) pull(count int, url string, testMode int32, puller m7s.PullerFactory) (err error) {
	if count < 0 {
		return fmt.Errorf("invalid pull count %d", count)
	}
	hasPlaceholder := strings.Contains(url, "%d")
	current := p.pullers.Length
	switch {
	case count > current:
		for j := current; j < count; j++ {
			conf := config.Pull{}
			defaults.SetDefaults(&conf)
			conf.TestMode = int(testMode)
			if hasPlaceholder {
				conf.URL = fmt.Sprintf(url, j)
			} else {
				conf.URL = url
			}
			instance := puller(conf)
			job := instance.GetPullJob().Init(instance, &p.Plugin, fmt.Sprintf("stress/%d", j), conf, nil)
			if err = job.WaitStarted(); err != nil {
				return err
			}
			if p.pullers.AddUnique(job) {
				// Keep the collection in sync when the puller goes away on its own.
				job.OnDispose(func() {
					p.pullers.Remove(job)
				})
			} else {
				job.Stop(task.ErrExist)
			}
		}
	case count < current:
		// Snapshot first: stopping a puller removes it from p.pullers via OnDispose.
		snapshot := slices.Clone(p.pullers.Items)
		for j := current; j > count; j-- {
			snapshot[j-1].Stop(task.ErrStopByUser)
		}
	}
	return nil
}
// push resizes the set of stress-test pushers of streamPath to exactly count.
//
// Growing: new pushers are created with the factory `pusher`. If url contains
// a "%d" placeholder, it is expanded with the pusher index; otherwise every
// pusher targets url verbatim. Previously the URL was always passed through
// fmt.Sprintf, which appended "%!(EXTRA int=N)" to URLs without a placeholder.
// Each new pusher is waited on until started. The first start error is returned
// immediately, and pushers that already started are kept.
//
// Shrinking: the surplus pushers at the end of the list are stopped with
// task.ErrStopByUser.
//
// A negative count is rejected with an error. Without this check, the shrink
// loop would index clone[-1] and panic.
func (p *TestPlugin) push(count int, streamPath, url string, pusher m7s.PusherFactory) (err error) {
	if count < 0 {
		return fmt.Errorf("invalid push count %d", count)
	}
	hasPlaceholder := strings.Contains(url, "%d")
	current := p.pushers.Length
	switch {
	case count > current:
		for j := current; j < count; j++ {
			target := url
			if hasPlaceholder {
				target = fmt.Sprintf(url, j)
			}
			instance := pusher()
			conf := config.Push{URL: target}
			defaults.SetDefaults(&conf)
			job := instance.GetPushJob().Init(instance, &p.Plugin, streamPath, conf, nil)
			if err = job.WaitStarted(); err != nil {
				return err
			}
			if p.pushers.AddUnique(job) {
				// Keep the collection in sync when the pusher goes away on its own.
				job.OnDispose(func() {
					p.pushers.Remove(job)
				})
			} else {
				job.Stop(task.ErrExist)
			}
		}
	case count < current:
		// Snapshot first: stopping a pusher removes it from p.pushers via OnDispose.
		snapshot := slices.Clone(p.pushers.Items)
		for j := current; j > count; j-- {
			snapshot[j-1].Stop(task.ErrStopByUser)
		}
	}
	return nil
}
// StartPush adjusts the number of stress-test pushers of req.StreamPath to
// req.PushCount, targeting req.RemoteURL.
//
// When req.Protocol is empty, it is inferred from the URL prefix and written
// back into req: "http" maps to webrtc, and "srt", "rtsp" and "rtmp" map to
// themselves. An unrecognised prefix or protocol yields an error and a nil
// response.
func (p *TestPlugin) StartPush(ctx context.Context, req *testpb.PushRequest) (res *pb.SuccessResponse, err error) {
	if req.Protocol == "" {
		switch remote := req.RemoteURL; {
		case strings.HasPrefix(remote, "http"):
			req.Protocol = "webrtc"
		case strings.HasPrefix(remote, "srt"):
			req.Protocol = "srt"
		case strings.HasPrefix(remote, "rtsp"):
			req.Protocol = "rtsp"
		case strings.HasPrefix(remote, "rtmp"):
			req.Protocol = "rtmp"
		default:
			return nil, fmt.Errorf("unsupport protocol %s", remote)
		}
	}
	pusher, known := map[string]m7s.PusherFactory{
		"rtmp":   rtmp.NewPusher,
		"rtsp":   rtsp.NewPusher,
		"srt":    srt.NewPusher,
		"webrtc": webrtc.NewPusher,
	}[req.Protocol]
	if !known {
		return nil, fmt.Errorf("unsupport protocol %s", req.Protocol)
	}
	err = p.push(int(req.PushCount), req.StreamPath, req.RemoteURL, pusher)
	return &pb.SuccessResponse{}, err
}
// StartPull adjusts the number of stress-test pullers to req.PullCount,
// pulling from req.RemoteURL.
//
// When req.Protocol is empty, it is inferred and written back into req. The
// URL path suffix is checked first (.m3u8 → hls, .flv → flv, .mp4 → mp4), then
// the URL prefix (srt, rtsp, rtmp), with webrtc as the fallback.
// req.RemoteURL may contain a "%d" placeholder, which pull expands per puller
// index. An unknown protocol yields an error and a nil response.
func (p *TestPlugin) StartPull(ctx context.Context, req *testpb.PullRequest) (res *pb.SuccessResponse, err error) {
	if req.Protocol == "" {
		// url.Parse rejects "%d" as an invalid percent-escape. Substitute a
		// concrete index (as pull does) so placeholder URLs can still be
		// inspected, instead of always failing here.
		u, parseErr := url.Parse(strings.ReplaceAll(req.RemoteURL, "%d", "0"))
		if parseErr != nil {
			return nil, fmt.Errorf("parse remote url failed: %w", parseErr)
		}
		switch {
		case strings.HasSuffix(u.Path, ".m3u8"):
			req.Protocol = "hls"
		case strings.HasSuffix(u.Path, ".flv"):
			req.Protocol = "flv"
		case strings.HasSuffix(u.Path, ".mp4"):
			req.Protocol = "mp4"
		case strings.HasPrefix(req.RemoteURL, "srt"):
			req.Protocol = "srt"
		case strings.HasPrefix(req.RemoteURL, "rtsp"):
			req.Protocol = "rtsp"
		case strings.HasPrefix(req.RemoteURL, "rtmp"):
			req.Protocol = "rtmp"
		default:
			req.Protocol = "webrtc"
		}
	}
	var puller m7s.PullerFactory
	switch req.Protocol {
	case "rtmp":
		puller = rtmp.NewPuller
	case "rtsp":
		puller = rtsp.NewPuller
	case "srt":
		puller = srt.NewPuller
	case "flv":
		puller = flv.NewPuller
	case "mp4":
		puller = mp4.NewPuller
	case "webrtc":
		puller = webrtc.NewPuller
	case "hls":
		puller = hls.NewPuller
	default:
		return nil, fmt.Errorf("unsupport protocol %s", req.Protocol)
	}
	return &pb.SuccessResponse{}, p.pull(int(req.PullCount), req.RemoteURL, req.TestMode, puller)
}
// StopPush stops every active stress-test pusher with task.ErrStopByUser.
// It iterates over a copy of p.pushers.Items, because a stopped pusher may
// remove itself from the collection.
func (p *TestPlugin) StopPush(ctx context.Context, req *emptypb.Empty) (res *pb.SuccessResponse, err error) {
	snapshot := slices.Clone(p.pushers.Items)
	for i := range snapshot {
		snapshot[i].Stop(task.ErrStopByUser)
	}
	return &pb.SuccessResponse{}, nil
}
// StopPull stops every active stress-test puller with task.ErrStopByUser.
// It iterates over a copy of p.pullers.Items, because a stopped puller may
// remove itself from the collection.
func (p *TestPlugin) StopPull(ctx context.Context, req *emptypb.Empty) (res *pb.SuccessResponse, err error) {
	snapshot := slices.Clone(p.pullers.Items)
	for i := range snapshot {
		snapshot[i].Stop(task.ErrStopByUser)
	}
	return &pb.SuccessResponse{}, nil
}
// GetCount reports how many stress-test pullers and pushers are currently
// tracked by the plugin.
func (p *TestPlugin) GetCount(ctx context.Context, req *emptypb.Empty) (res *testpb.CountResponse, err error) {
	data := &testpb.CountResponseData{
		PushCount: uint32(p.pushers.Length),
		PullCount: uint32(p.pullers.Length),
	}
	return &testpb.CountResponse{Data: data}, nil
}