mirror of
https://github.com/langhuihui/monibuca.git
synced 2025-09-27 09:52:06 +08:00
301 lines
8.1 KiB
Go
301 lines
8.1 KiB
Go
package plugin_test
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/mcuadros/go-defaults"
|
|
"golang.org/x/exp/slices"
|
|
"google.golang.org/protobuf/types/known/durationpb"
|
|
"google.golang.org/protobuf/types/known/emptypb"
|
|
"google.golang.org/protobuf/types/known/timestamppb"
|
|
|
|
"m7s.live/v5"
|
|
pb "m7s.live/v5/pb"
|
|
"m7s.live/v5/pkg/config"
|
|
"m7s.live/v5/pkg/task"
|
|
"m7s.live/v5/pkg/util"
|
|
flv "m7s.live/v5/plugin/flv/pkg"
|
|
hls "m7s.live/v5/plugin/hls/pkg"
|
|
mp4 "m7s.live/v5/plugin/mp4/pkg"
|
|
rtmp "m7s.live/v5/plugin/rtmp/pkg"
|
|
rtsp "m7s.live/v5/plugin/rtsp/pkg"
|
|
srt "m7s.live/v5/plugin/srt/pkg"
|
|
testpb "m7s.live/v5/plugin/test/pb"
|
|
webrtc "m7s.live/v5/plugin/webrtc/pkg"
|
|
)
|
|
|
|
// ========== Protobuf 转换函数 ========== //
|
|
|
|
// ToPBTestCase 转换为 protobuf TestCase
|
|
func ToPBTestCase(tc *TestCase) *testpb.TestCase {
|
|
if tc == nil {
|
|
return nil
|
|
}
|
|
return &testpb.TestCase{
|
|
Name: tc.Name,
|
|
Description: tc.Description,
|
|
Timeout: durationpb.New(tc.Timeout),
|
|
Tasks: ToPBTestTasks(tc.Tasks),
|
|
Status: string(tc.Status),
|
|
StartTime: timestamppb.New(time.Unix(tc.StartTime, 0)),
|
|
EndTime: timestamppb.New(time.Unix(tc.EndTime, 0)),
|
|
Duration: tc.Duration,
|
|
VideoCodec: tc.VideoCodec,
|
|
AudioCodec: tc.AudioCodec,
|
|
VideoOnly: tc.VideoOnly,
|
|
AudioOnly: tc.AudioOnly,
|
|
ErrorMsg: tc.ErrorMsg,
|
|
Logs: tc.Logs,
|
|
Tags: tc.Tags,
|
|
}
|
|
}
|
|
|
|
func ToPBTestTasks(tasks []TestTaskConfig) []*testpb.TestTask {
|
|
pbTasks := make([]*testpb.TestTask, 0, len(tasks))
|
|
for _, task := range tasks {
|
|
pbTasks = append(pbTasks, &testpb.TestTask{
|
|
Action: task.Action,
|
|
Delay: durationpb.New(task.Delay),
|
|
Format: task.Format,
|
|
})
|
|
}
|
|
return pbTasks
|
|
}
|
|
|
|
// ========== Protobuf Gateway API 实现 ========== //
|
|
|
|
// ListTestCases 获取测试用例列表
|
|
func (p *TestPlugin) ListTestCases(ctx context.Context, req *testpb.ListTestCasesRequest) (*testpb.ListTestCasesResponse, error) {
|
|
// 构建过滤器
|
|
filter := TestCaseFilter{
|
|
Tags: req.Tags,
|
|
Status: TestCaseStatus(req.Status),
|
|
}
|
|
// 从缓存获取测试用例
|
|
allCases := p.GetTestCasesFromCache(filter)
|
|
|
|
// 转换为 protobuf 格式
|
|
pbCases := make([]*testpb.TestCase, 0, len(allCases))
|
|
for _, tc := range allCases {
|
|
pbCases = append(pbCases, ToPBTestCase(tc))
|
|
}
|
|
|
|
return &testpb.ListTestCasesResponse{
|
|
Code: 0, Message: "success", Data: pbCases,
|
|
}, nil
|
|
}
|
|
|
|
func (p *TestPlugin) ExecuteTestCase(ctx context.Context, req *testpb.ExecuteTestCaseRequest) (*pb.SuccessResponse, error) {
|
|
for _, name := range req.Names {
|
|
tc, exists := p.GetTestCaseFromCache(name)
|
|
if !exists || tc.Status == TestCaseStatusRunning || tc.Status == TestCaseStatusStarting {
|
|
continue
|
|
}
|
|
tc.Job = &task.Job{}
|
|
tc.ErrorMsg = ""
|
|
tc.Logs = ""
|
|
p.AddTask(tc)
|
|
}
|
|
return &pb.SuccessResponse{Code: 0, Message: "success"}, nil
|
|
}
|
|
|
|
func (p *TestPlugin) GetTestCaseSSE(w http.ResponseWriter, r *http.Request) {
|
|
query := r.URL.Query()
|
|
var filter TestCaseFilter
|
|
tags := query.Get("tags")
|
|
if tags != "" {
|
|
filter.Tags = strings.Split(tags, ",")
|
|
}
|
|
filter.Status = TestCaseStatus(query.Get("status"))
|
|
util.NewSSE(w, r.Context(), func(sse *util.SSE) {
|
|
flush := func() error {
|
|
return sse.WriteJSON(p.GetTestCasesFromCache(filter))
|
|
}
|
|
if err := flush(); err != nil {
|
|
return
|
|
}
|
|
// 创建定时器,定期发送状态更新
|
|
ticker := time.NewTicker(5 * time.Second)
|
|
defer ticker.Stop()
|
|
for {
|
|
select {
|
|
case <-sse.Context.Done():
|
|
return
|
|
case <-p.flushSSE:
|
|
if err := flush(); err != nil {
|
|
return
|
|
}
|
|
case <-ticker.C:
|
|
if err := flush(); err != nil {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
})
|
|
}
|
|
|
|
// ========== Stress 测试相关 API 实现 ========== //
|
|
|
|
func (p *TestPlugin) pull(count int, url string, testMode int32, puller m7s.PullerFactory) (err error) {
|
|
hasPlaceholder := strings.Contains(url, "%d")
|
|
if i := p.pullers.Length; count > i {
|
|
for j := i; j < count; j++ {
|
|
conf := config.Pull{}
|
|
defaults.SetDefaults(&conf)
|
|
conf.TestMode = int(testMode)
|
|
if hasPlaceholder {
|
|
conf.URL = fmt.Sprintf(url, j)
|
|
} else {
|
|
conf.URL = url
|
|
}
|
|
puller := puller(conf)
|
|
ctx := puller.GetPullJob().Init(puller, &p.Plugin, fmt.Sprintf("stress/%d", j), conf, nil)
|
|
if err = ctx.WaitStarted(); err != nil {
|
|
return
|
|
}
|
|
if p.pullers.AddUnique(ctx) {
|
|
ctx.OnDispose(func() {
|
|
p.pullers.Remove(ctx)
|
|
})
|
|
} else {
|
|
ctx.Stop(task.ErrExist)
|
|
}
|
|
}
|
|
} else if count < i {
|
|
clone := slices.Clone(p.pullers.Items)
|
|
for j := i; j > count; j-- {
|
|
clone[j-1].Stop(task.ErrStopByUser)
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
func (p *TestPlugin) push(count int, streamPath, url string, pusher m7s.PusherFactory) (err error) {
|
|
if i := p.pushers.Length; count > i {
|
|
for j := i; j < count; j++ {
|
|
pusher := pusher()
|
|
conf := config.Push{URL: fmt.Sprintf(url, j)}
|
|
defaults.SetDefaults(&conf)
|
|
ctx := pusher.GetPushJob().Init(pusher, &p.Plugin, streamPath, conf, nil)
|
|
if err = ctx.WaitStarted(); err != nil {
|
|
return
|
|
}
|
|
if p.pushers.AddUnique(ctx) {
|
|
ctx.OnDispose(func() {
|
|
p.pushers.Remove(ctx)
|
|
})
|
|
} else {
|
|
ctx.Stop(task.ErrExist)
|
|
}
|
|
}
|
|
} else if count < i {
|
|
clone := slices.Clone(p.pushers.Items)
|
|
for j := i; j > count; j-- {
|
|
clone[j-1].Stop(task.ErrStopByUser)
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
func (p *TestPlugin) StartPush(ctx context.Context, req *testpb.PushRequest) (res *pb.SuccessResponse, err error) {
|
|
var pusher m7s.PusherFactory
|
|
if req.Protocol == "" {
|
|
if strings.HasPrefix(req.RemoteURL, "http") {
|
|
req.Protocol = "webrtc"
|
|
} else if strings.HasPrefix(req.RemoteURL, "srt") {
|
|
req.Protocol = "srt"
|
|
} else if strings.HasPrefix(req.RemoteURL, "rtsp") {
|
|
req.Protocol = "rtsp"
|
|
} else if strings.HasPrefix(req.RemoteURL, "rtmp") {
|
|
req.Protocol = "rtmp"
|
|
} else {
|
|
return nil, fmt.Errorf("unsupport protocol %s", req.RemoteURL)
|
|
}
|
|
}
|
|
switch req.Protocol {
|
|
case "rtmp":
|
|
pusher = rtmp.NewPusher
|
|
case "rtsp":
|
|
pusher = rtsp.NewPusher
|
|
case "srt":
|
|
pusher = srt.NewPusher
|
|
case "webrtc":
|
|
pusher = webrtc.NewPusher
|
|
default:
|
|
return nil, fmt.Errorf("unsupport protocol %s", req.Protocol)
|
|
}
|
|
return &pb.SuccessResponse{}, p.push(int(req.PushCount), req.StreamPath, req.RemoteURL, pusher)
|
|
}
|
|
|
|
func (p *TestPlugin) StartPull(ctx context.Context, req *testpb.PullRequest) (res *pb.SuccessResponse, err error) {
|
|
var puller m7s.PullerFactory
|
|
if req.Protocol == "" {
|
|
u, err := url.Parse(req.RemoteURL)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("parse remote url failed: %w", err)
|
|
}
|
|
if strings.HasSuffix(u.Path, ".m3u8") {
|
|
req.Protocol = "hls"
|
|
} else if strings.HasSuffix(u.Path, ".flv") {
|
|
req.Protocol = "flv"
|
|
} else if strings.HasSuffix(u.Path, ".mp4") {
|
|
req.Protocol = "mp4"
|
|
} else if strings.HasPrefix(req.RemoteURL, "srt") {
|
|
req.Protocol = "srt"
|
|
} else if strings.HasPrefix(req.RemoteURL, "rtsp") {
|
|
req.Protocol = "rtsp"
|
|
} else if strings.HasPrefix(req.RemoteURL, "rtmp") {
|
|
req.Protocol = "rtmp"
|
|
} else {
|
|
req.Protocol = "webrtc"
|
|
}
|
|
}
|
|
switch req.Protocol {
|
|
case "rtmp":
|
|
puller = rtmp.NewPuller
|
|
case "rtsp":
|
|
puller = rtsp.NewPuller
|
|
case "srt":
|
|
puller = srt.NewPuller
|
|
case "flv":
|
|
puller = flv.NewPuller
|
|
case "mp4":
|
|
puller = mp4.NewPuller
|
|
case "webrtc":
|
|
puller = webrtc.NewPuller
|
|
case "hls":
|
|
puller = hls.NewPuller
|
|
default:
|
|
return nil, fmt.Errorf("unsupport protocol %s", req.Protocol)
|
|
}
|
|
return &pb.SuccessResponse{}, p.pull(int(req.PullCount), req.RemoteURL, req.TestMode, puller)
|
|
}
|
|
|
|
func (p *TestPlugin) StopPush(ctx context.Context, req *emptypb.Empty) (res *pb.SuccessResponse, err error) {
|
|
for _, pusher := range slices.Clone(p.pushers.Items) {
|
|
pusher.Stop(task.ErrStopByUser)
|
|
}
|
|
return &pb.SuccessResponse{}, nil
|
|
}
|
|
|
|
func (p *TestPlugin) StopPull(ctx context.Context, req *emptypb.Empty) (res *pb.SuccessResponse, err error) {
|
|
for _, puller := range slices.Clone(p.pullers.Items) {
|
|
puller.Stop(task.ErrStopByUser)
|
|
}
|
|
return &pb.SuccessResponse{}, nil
|
|
}
|
|
|
|
func (p *TestPlugin) GetCount(ctx context.Context, req *emptypb.Empty) (res *testpb.CountResponse, err error) {
|
|
return &testpb.CountResponse{
|
|
Data: &testpb.CountResponseData{
|
|
PullCount: uint32(p.pullers.Length),
|
|
PushCount: uint32(p.pushers.Length),
|
|
},
|
|
}, nil
|
|
}
|