fix: cascade plugin

This commit is contained in:
langhuihui
2024-08-17 20:38:10 +08:00
parent 263c60ad9d
commit 5cc3fdecf9
10 changed files with 129 additions and 126 deletions

View File

@@ -5,5 +5,11 @@ global:
listenaddrtls: :8555
tcp:
listenaddr: :50052
console:
secret: de2c0bb9fd47684adc07a426e139239b
cascadeclient:
server: localhost:44944
pull:
enableregexp: true
pullonsub:
.*: m7s://$0
#console:
# secret: de2c0bb9fd47684adc07a426e139239b

View File

@@ -1,16 +1,13 @@
global:
loglevel: debug
tcp:
listenaddr: :50051
console:
secret: 00aea3af031f134d6307618b05ec4899
rtmp:
enable: false
rtsp:
enable: false
webrtc:
enable: false
flv:
pull:
pullonstart:
live/test: /Users/dexter/Movies/jb-demo.flv
disableall: true
#console:
# secret: 00aea3af031f134d6307618b05ec4899
cascadeserver:
enable: true
quic:
listenaddr: :44944
#flv:
# pull:
# pullonstart:
# live/test: /Users/dexter/Movies/jb-demo.flv

View File

@@ -3,7 +3,7 @@ package main
import (
"context"
"m7s.live/m7s/v5"
_ "m7s.live/m7s/v5/plugin/console"
_ "m7s.live/m7s/v5/plugin/cascade"
_ "m7s.live/m7s/v5/plugin/debug"
_ "m7s.live/m7s/v5/plugin/flv"
_ "m7s.live/m7s/v5/plugin/logrotate"

View File

@@ -2,5 +2,5 @@ package config
type DB struct {
DBType string `default:"sqlite" desc:"数据库类型"`
DSN string `default:"cascade.db" desc:"数据库文件路径"`
DSN string `default:"m7s.db" desc:"数据库文件路径"`
}

View File

@@ -2,10 +2,13 @@
package db
import "github.com/glebarez/sqlite"
import (
"github.com/glebarez/sqlite"
"gorm.io/gorm"
)
func init() {
Factory["sqlite"] = func(dsn string) gorm.Dialector {
return gorm.Open(sqlite.Open(dsn), &gorm.Config{})
return sqlite.Open(dsn)
}
}

View File

@@ -23,13 +23,13 @@ import (
"m7s.live/m7s/v5/pkg/util"
)
type DefaultYaml string
type (
DefaultYaml string
OnExitHandler func()
AuthPublisher = func(*Publisher) *util.Promise
AuthSubscriber = func(*Subscriber) *util.Promise
type OnExitHandler func()
type AuthPublisher = func(*Publisher) *util.Promise
type AuthSubscriber = func(*Subscriber) *util.Promise
type PluginMeta struct {
PluginMeta struct {
Name string
Version string //插件版本
Type reflect.Type
@@ -42,7 +42,41 @@ type PluginMeta struct {
OnExit OnExitHandler
OnAuthPub AuthPublisher
OnAuthSub AuthSubscriber
}
}
iPlugin interface {
nothing()
}
IPlugin interface {
util.ITask
OnInit() error
OnStop()
Pull(path string, url string)
}
IRegisterHandler interface {
RegisterHandler() map[string]http.HandlerFunc
}
IPullerPlugin interface {
GetPullableList() []string
}
ITCPPlugin interface {
OnTCPConnect(*net.TCPConn)
}
IUDPPlugin interface {
OnUDPConnect(*net.UDPConn)
}
IQUICPlugin interface {
OnQUICConnect(quic.Connection)
}
)
var plugins []PluginMeta
func (plugin *PluginMeta) Init(s *Server, userConfig map[string]any) (p *Plugin) {
instance, ok := reflect.New(plugin.Type).Interface().(IPlugin)
@@ -111,39 +145,6 @@ func (plugin *PluginMeta) Init(s *Server, userConfig map[string]any) (p *Plugin)
return
}
type iPlugin interface {
nothing()
}
type IPlugin interface {
util.ITask
OnInit() error
OnStop()
Pull(path string, url string)
}
type IRegisterHandler interface {
RegisterHandler() map[string]http.HandlerFunc
}
type IPullerPlugin interface {
GetPullableList() []string
}
type ITCPPlugin interface {
OnTCPConnect(*net.TCPConn)
}
type IUDPPlugin interface {
OnUDPConnect(*net.UDPConn)
}
type IQUICPlugin interface {
OnQUICConnect(quic.Connection)
}
var plugins []PluginMeta
// InstallPlugin 安装插件
func InstallPlugin[C iPlugin](options ...any) error {
var c *C
@@ -347,11 +348,13 @@ func (p *Plugin) listen() (err error) {
quicConf := &p.config.Quic
if quicConf.ListenAddr != "" && quicConf.AutoListen {
p.Info("listen quic", "addr", quicConf.ListenAddr)
err = quicConf.ListenQuic(p, quicHandler.OnQUICConnect)
if err != nil {
p.Error("listen quic", "addr", quicConf.ListenAddr, "error", err)
return
}
go func() {
p.Stop(quicConf.ListenQuic(p, quicHandler.OnQUICConnect))
}()
//if err != nil {
// p.Error("listen quic", "addr", quicConf.ListenAddr, "error", err)
// return
//}
}
}
return

View File

@@ -11,7 +11,7 @@ import (
"github.com/quic-go/quic-go"
)
type CascadeClientConfig struct {
type CascadeClientPlugin struct {
m7s.Plugin
RelayAPI cascade.RelayAPIConfig `desc:"访问控制"`
AutoPush bool `desc:"自动推流到上级"` //自动推流到上级
@@ -20,16 +20,11 @@ type CascadeClientConfig struct {
conn quic.Connection
}
var _ = m7s.InstallPlugin[CascadeClientConfig](m7s.DefaultYaml(`
cascadeclient:
relayapi:
allow:
- /
`), cascade.NewCascadePuller)
var _ = m7s.InstallPlugin[CascadeClientPlugin](cascade.NewCascadePuller)
type ConnectServerTask struct {
util.Task
cfg *CascadeClientConfig
cfg *CascadeClientPlugin
quic.Connection
}
@@ -43,7 +38,10 @@ func (task *ConnectServerTask) Start() (err error) {
KeepAlivePeriod: time.Second * 10,
EnableDatagrams: true,
})
if stream := quic.Stream(nil); err == nil {
if err != nil {
return
}
var stream quic.Stream
if stream, err = task.OpenStreamSync(task.cfg); err == nil {
res := []byte{0}
fmt.Fprintf(stream, "%s", task.cfg.Secret)
@@ -61,7 +59,6 @@ func (task *ConnectServerTask) Start() (err error) {
return nil
}
}
}
return
}
@@ -80,8 +77,8 @@ func (task *ConnectServerTask) Run() (err error) {
return
}
func (c *CascadeClientConfig) OnInit() (err error) {
if c.Secret == "" || c.Server == "" {
func (c *CascadeClientPlugin) OnInit() (err error) {
if c.Secret == "" && c.Server == "" {
return nil
}
connectTask := ConnectServerTask{
@@ -92,14 +89,14 @@ func (c *CascadeClientConfig) OnInit() (err error) {
return
}
func (c *CascadeClientConfig) Pull(streamPath, url string) {
puller := cascade.NewCascadePuller().(*cascade.Puller)
puller.Connection = c.conn
puller.GetPullContext().Init(puller, &c.Plugin, streamPath, url)
c.Plugin.Server.AddPullTask(puller)
func (c *CascadeClientPlugin) Pull(streamPath, url string) {
puller := &cascade.Puller{
Connection: c.conn,
}
c.Plugin.Server.AddPullTask(puller.GetPullContext().Init(puller, &c.Plugin, streamPath, url))
}
//func (c *CascadeClientConfig) Start() {
//func (c *CascadeClientPlugin) Start() {
// retryDelay := [...]int{2, 3, 5, 8, 13}
// for i := 0; c.Err() == nil; i++ {
// connected, err := c.Remote()
@@ -117,7 +114,7 @@ func (c *CascadeClientConfig) Pull(streamPath, url string) {
// }
//}
//func (c *CascadeClientConfig) Remote() (wasConnected bool, err error) {
//func (c *CascadeClientPlugin) Remote() (wasConnected bool, err error) {
// tlsConf := &tls.Config{
// InsecureSkipVerify: true,
// NextProtos: []string{"monibuca"},

View File

@@ -11,19 +11,18 @@ import (
"m7s.live/m7s/v5/plugin/cascade/pkg"
)
type CascadeServerConfig struct {
type CascadeServerPlugin struct {
m7s.Plugin
AutoRegister bool `default:"true" desc:"下级自动注册"`
RelayAPI cascade.RelayAPIConfig `desc:"访问控制"`
}
var _ = m7s.InstallPlugin[CascadeServerConfig]()
var _ = m7s.InstallPlugin[CascadeServerPlugin]()
func (c *CascadeServerConfig) OnQUICConnect(conn quic.Connection) (err error) {
func (c *CascadeServerPlugin) OnQUICConnect(conn quic.Connection) {
remoteAddr := conn.RemoteAddr().String()
c.Info("client connected:", "remoteAddr", remoteAddr)
var stream quic.Stream
stream, err = conn.AcceptStream(c)
stream, err := conn.AcceptStream(c)
if err != nil {
c.Error("AcceptStream", "err", err)
return
@@ -76,11 +75,10 @@ func (c *CascadeServerConfig) OnQUICConnect(conn quic.Connection) (err error) {
c.AddTask(&receiveRequestTask)
}
}
return
}
// API_relay_ 用于转发请求, api/relay/:instanceId/*
func (c *CascadeServerConfig) API_relay_(w http.ResponseWriter, r *http.Request) {
func (c *CascadeServerPlugin) API_relay_(w http.ResponseWriter, r *http.Request) {
paths := strings.Split(r.URL.Path, "/")
instanceId, err := strconv.ParseUint(paths[3], 10, 32)
instance, ok := cascade.SubordinateMap.Get(uint(instanceId))
@@ -105,6 +103,6 @@ func (c *CascadeServerConfig) API_relay_(w http.ResponseWriter, r *http.Request)
}
// API_list 用于获取所有下级, api/list
func (c *CascadeServerConfig) API_list(w http.ResponseWriter, r *http.Request) {
func (c *CascadeServerPlugin) API_list(w http.ResponseWriter, r *http.Request) {
//util.ReturnFetchList(SubordinateMap.ToList, w, r)
}

View File

@@ -2,6 +2,7 @@ package m7s
import (
"context"
"errors"
"fmt"
"log/slog"
"net"
@@ -90,12 +91,10 @@ func NewServer(conf any) (s *Server) {
return
}
func Run(ctx context.Context, conf any) error {
for {
if err := util.RootTask.AddTaskWithContext(ctx, NewServer(conf)).WaitStopped(); err != ErrRestart {
return err
}
func Run(ctx context.Context, conf any) (err error) {
for err = ErrRestart; errors.Is(err, ErrRestart); err = util.RootTask.AddTaskWithContext(ctx, NewServer(conf)).WaitStopped() {
}
return
}
func AddRootTask[T util.ITask](task T) T {
@@ -108,7 +107,7 @@ func AddRootTaskWithContext[T util.ITask](ctx context.Context, task T) T {
return task
}
type rawconfig = map[string]map[string]any
type RawConfig = map[string]map[string]any
func init() {
signalChan := make(chan os.Signal, 1)
@@ -151,7 +150,7 @@ func (s *Server) Start() (err error) {
httpConf.GetHttpMux().ServeHTTP(w, r)
}))
httpConf.SetMux(mux)
var cg rawconfig
var cg RawConfig
var configYaml []byte
switch v := s.conf.(type) {
case string:
@@ -163,7 +162,7 @@ func (s *Server) Start() (err error) {
}
case []byte:
configYaml = v
case rawconfig:
case RawConfig:
cg = v
}
if configYaml != nil {

View File

@@ -8,7 +8,7 @@ import (
)
func TestRestart(b *testing.T) {
conf := map[string]map[string]any{"global": {"loglevel": "debug"}}
conf := m7s.RawConfig{"global": {"loglevel": "debug"}}
var server *m7s.Server
go func() {
time.Sleep(time.Second * 2)