fix: cascade plugin

This commit is contained in:
langhuihui
2024-08-17 20:38:10 +08:00
parent 263c60ad9d
commit 5cc3fdecf9
10 changed files with 129 additions and 126 deletions

View File

@@ -5,5 +5,11 @@ global:
listenaddrtls: :8555 listenaddrtls: :8555
tcp: tcp:
listenaddr: :50052 listenaddr: :50052
console: cascadeclient:
secret: de2c0bb9fd47684adc07a426e139239b server: localhost:44944
pull:
enableregexp: true
pullonsub:
.*: m7s://$0
#console:
# secret: de2c0bb9fd47684adc07a426e139239b

View File

@@ -1,16 +1,13 @@
global: global:
loglevel: debug loglevel: debug
tcp: disableall: true
listenaddr: :50051 #console:
console: # secret: 00aea3af031f134d6307618b05ec4899
secret: 00aea3af031f134d6307618b05ec4899 cascadeserver:
rtmp: enable: true
enable: false quic:
rtsp: listenaddr: :44944
enable: false #flv:
webrtc: # pull:
enable: false # pullonstart:
flv: # live/test: /Users/dexter/Movies/jb-demo.flv
pull:
pullonstart:
live/test: /Users/dexter/Movies/jb-demo.flv

View File

@@ -3,7 +3,7 @@ package main
import ( import (
"context" "context"
"m7s.live/m7s/v5" "m7s.live/m7s/v5"
_ "m7s.live/m7s/v5/plugin/console" _ "m7s.live/m7s/v5/plugin/cascade"
_ "m7s.live/m7s/v5/plugin/debug" _ "m7s.live/m7s/v5/plugin/debug"
_ "m7s.live/m7s/v5/plugin/flv" _ "m7s.live/m7s/v5/plugin/flv"
_ "m7s.live/m7s/v5/plugin/logrotate" _ "m7s.live/m7s/v5/plugin/logrotate"

View File

@@ -2,5 +2,5 @@ package config
type DB struct { type DB struct {
DBType string `default:"sqlite" desc:"数据库类型"` DBType string `default:"sqlite" desc:"数据库类型"`
DSN string `default:"cascade.db" desc:"数据库文件路径"` DSN string `default:"m7s.db" desc:"数据库文件路径"`
} }

View File

@@ -2,10 +2,13 @@
package db package db
import "github.com/glebarez/sqlite" import (
"github.com/glebarez/sqlite"
"gorm.io/gorm"
)
func init() { func init() {
Factory["sqlite"] = func(dsn string) gorm.Dialector { Factory["sqlite"] = func(dsn string) gorm.Dialector {
return gorm.Open(sqlite.Open(dsn), &gorm.Config{}) return sqlite.Open(dsn)
} }
} }

View File

@@ -23,13 +23,13 @@ import (
"m7s.live/m7s/v5/pkg/util" "m7s.live/m7s/v5/pkg/util"
) )
type DefaultYaml string type (
DefaultYaml string
OnExitHandler func()
AuthPublisher = func(*Publisher) *util.Promise
AuthSubscriber = func(*Subscriber) *util.Promise
type OnExitHandler func() PluginMeta struct {
type AuthPublisher = func(*Publisher) *util.Promise
type AuthSubscriber = func(*Subscriber) *util.Promise
type PluginMeta struct {
Name string Name string
Version string //插件版本 Version string //插件版本
Type reflect.Type Type reflect.Type
@@ -42,7 +42,41 @@ type PluginMeta struct {
OnExit OnExitHandler OnExit OnExitHandler
OnAuthPub AuthPublisher OnAuthPub AuthPublisher
OnAuthSub AuthSubscriber OnAuthSub AuthSubscriber
} }
iPlugin interface {
nothing()
}
IPlugin interface {
util.ITask
OnInit() error
OnStop()
Pull(path string, url string)
}
IRegisterHandler interface {
RegisterHandler() map[string]http.HandlerFunc
}
IPullerPlugin interface {
GetPullableList() []string
}
ITCPPlugin interface {
OnTCPConnect(*net.TCPConn)
}
IUDPPlugin interface {
OnUDPConnect(*net.UDPConn)
}
IQUICPlugin interface {
OnQUICConnect(quic.Connection)
}
)
var plugins []PluginMeta
func (plugin *PluginMeta) Init(s *Server, userConfig map[string]any) (p *Plugin) { func (plugin *PluginMeta) Init(s *Server, userConfig map[string]any) (p *Plugin) {
instance, ok := reflect.New(plugin.Type).Interface().(IPlugin) instance, ok := reflect.New(plugin.Type).Interface().(IPlugin)
@@ -111,39 +145,6 @@ func (plugin *PluginMeta) Init(s *Server, userConfig map[string]any) (p *Plugin)
return return
} }
type iPlugin interface {
nothing()
}
type IPlugin interface {
util.ITask
OnInit() error
OnStop()
Pull(path string, url string)
}
type IRegisterHandler interface {
RegisterHandler() map[string]http.HandlerFunc
}
type IPullerPlugin interface {
GetPullableList() []string
}
type ITCPPlugin interface {
OnTCPConnect(*net.TCPConn)
}
type IUDPPlugin interface {
OnUDPConnect(*net.UDPConn)
}
type IQUICPlugin interface {
OnQUICConnect(quic.Connection)
}
var plugins []PluginMeta
// InstallPlugin 安装插件 // InstallPlugin 安装插件
func InstallPlugin[C iPlugin](options ...any) error { func InstallPlugin[C iPlugin](options ...any) error {
var c *C var c *C
@@ -347,11 +348,13 @@ func (p *Plugin) listen() (err error) {
quicConf := &p.config.Quic quicConf := &p.config.Quic
if quicConf.ListenAddr != "" && quicConf.AutoListen { if quicConf.ListenAddr != "" && quicConf.AutoListen {
p.Info("listen quic", "addr", quicConf.ListenAddr) p.Info("listen quic", "addr", quicConf.ListenAddr)
err = quicConf.ListenQuic(p, quicHandler.OnQUICConnect) go func() {
if err != nil { p.Stop(quicConf.ListenQuic(p, quicHandler.OnQUICConnect))
p.Error("listen quic", "addr", quicConf.ListenAddr, "error", err) }()
return //if err != nil {
} // p.Error("listen quic", "addr", quicConf.ListenAddr, "error", err)
// return
//}
} }
} }
return return

View File

@@ -11,7 +11,7 @@ import (
"github.com/quic-go/quic-go" "github.com/quic-go/quic-go"
) )
type CascadeClientConfig struct { type CascadeClientPlugin struct {
m7s.Plugin m7s.Plugin
RelayAPI cascade.RelayAPIConfig `desc:"访问控制"` RelayAPI cascade.RelayAPIConfig `desc:"访问控制"`
AutoPush bool `desc:"自动推流到上级"` //自动推流到上级 AutoPush bool `desc:"自动推流到上级"` //自动推流到上级
@@ -20,16 +20,11 @@ type CascadeClientConfig struct {
conn quic.Connection conn quic.Connection
} }
var _ = m7s.InstallPlugin[CascadeClientConfig](m7s.DefaultYaml(` var _ = m7s.InstallPlugin[CascadeClientPlugin](cascade.NewCascadePuller)
cascadeclient:
relayapi:
allow:
- /
`), cascade.NewCascadePuller)
type ConnectServerTask struct { type ConnectServerTask struct {
util.Task util.Task
cfg *CascadeClientConfig cfg *CascadeClientPlugin
quic.Connection quic.Connection
} }
@@ -43,7 +38,10 @@ func (task *ConnectServerTask) Start() (err error) {
KeepAlivePeriod: time.Second * 10, KeepAlivePeriod: time.Second * 10,
EnableDatagrams: true, EnableDatagrams: true,
}) })
if stream := quic.Stream(nil); err == nil { if err != nil {
return
}
var stream quic.Stream
if stream, err = task.OpenStreamSync(task.cfg); err == nil { if stream, err = task.OpenStreamSync(task.cfg); err == nil {
res := []byte{0} res := []byte{0}
fmt.Fprintf(stream, "%s", task.cfg.Secret) fmt.Fprintf(stream, "%s", task.cfg.Secret)
@@ -61,7 +59,6 @@ func (task *ConnectServerTask) Start() (err error) {
return nil return nil
} }
} }
}
return return
} }
@@ -80,8 +77,8 @@ func (task *ConnectServerTask) Run() (err error) {
return return
} }
func (c *CascadeClientConfig) OnInit() (err error) { func (c *CascadeClientPlugin) OnInit() (err error) {
if c.Secret == "" || c.Server == "" { if c.Secret == "" && c.Server == "" {
return nil return nil
} }
connectTask := ConnectServerTask{ connectTask := ConnectServerTask{
@@ -92,14 +89,14 @@ func (c *CascadeClientConfig) OnInit() (err error) {
return return
} }
func (c *CascadeClientConfig) Pull(streamPath, url string) { func (c *CascadeClientPlugin) Pull(streamPath, url string) {
puller := cascade.NewCascadePuller().(*cascade.Puller) puller := &cascade.Puller{
puller.Connection = c.conn Connection: c.conn,
puller.GetPullContext().Init(puller, &c.Plugin, streamPath, url) }
c.Plugin.Server.AddPullTask(puller) c.Plugin.Server.AddPullTask(puller.GetPullContext().Init(puller, &c.Plugin, streamPath, url))
} }
//func (c *CascadeClientConfig) Start() { //func (c *CascadeClientPlugin) Start() {
// retryDelay := [...]int{2, 3, 5, 8, 13} // retryDelay := [...]int{2, 3, 5, 8, 13}
// for i := 0; c.Err() == nil; i++ { // for i := 0; c.Err() == nil; i++ {
// connected, err := c.Remote() // connected, err := c.Remote()
@@ -117,7 +114,7 @@ func (c *CascadeClientConfig) Pull(streamPath, url string) {
// } // }
//} //}
//func (c *CascadeClientConfig) Remote() (wasConnected bool, err error) { //func (c *CascadeClientPlugin) Remote() (wasConnected bool, err error) {
// tlsConf := &tls.Config{ // tlsConf := &tls.Config{
// InsecureSkipVerify: true, // InsecureSkipVerify: true,
// NextProtos: []string{"monibuca"}, // NextProtos: []string{"monibuca"},

View File

@@ -11,19 +11,18 @@ import (
"m7s.live/m7s/v5/plugin/cascade/pkg" "m7s.live/m7s/v5/plugin/cascade/pkg"
) )
type CascadeServerConfig struct { type CascadeServerPlugin struct {
m7s.Plugin m7s.Plugin
AutoRegister bool `default:"true" desc:"下级自动注册"` AutoRegister bool `default:"true" desc:"下级自动注册"`
RelayAPI cascade.RelayAPIConfig `desc:"访问控制"` RelayAPI cascade.RelayAPIConfig `desc:"访问控制"`
} }
var _ = m7s.InstallPlugin[CascadeServerConfig]() var _ = m7s.InstallPlugin[CascadeServerPlugin]()
func (c *CascadeServerConfig) OnQUICConnect(conn quic.Connection) (err error) { func (c *CascadeServerPlugin) OnQUICConnect(conn quic.Connection) {
remoteAddr := conn.RemoteAddr().String() remoteAddr := conn.RemoteAddr().String()
c.Info("client connected:", "remoteAddr", remoteAddr) c.Info("client connected:", "remoteAddr", remoteAddr)
var stream quic.Stream stream, err := conn.AcceptStream(c)
stream, err = conn.AcceptStream(c)
if err != nil { if err != nil {
c.Error("AcceptStream", "err", err) c.Error("AcceptStream", "err", err)
return return
@@ -76,11 +75,10 @@ func (c *CascadeServerConfig) OnQUICConnect(conn quic.Connection) (err error) {
c.AddTask(&receiveRequestTask) c.AddTask(&receiveRequestTask)
} }
} }
return
} }
// API_relay_ 用于转发请求, api/relay/:instanceId/* // API_relay_ 用于转发请求, api/relay/:instanceId/*
func (c *CascadeServerConfig) API_relay_(w http.ResponseWriter, r *http.Request) { func (c *CascadeServerPlugin) API_relay_(w http.ResponseWriter, r *http.Request) {
paths := strings.Split(r.URL.Path, "/") paths := strings.Split(r.URL.Path, "/")
instanceId, err := strconv.ParseUint(paths[3], 10, 32) instanceId, err := strconv.ParseUint(paths[3], 10, 32)
instance, ok := cascade.SubordinateMap.Get(uint(instanceId)) instance, ok := cascade.SubordinateMap.Get(uint(instanceId))
@@ -105,6 +103,6 @@ func (c *CascadeServerConfig) API_relay_(w http.ResponseWriter, r *http.Request)
} }
// API_list 用于获取所有下级, api/list // API_list 用于获取所有下级, api/list
func (c *CascadeServerConfig) API_list(w http.ResponseWriter, r *http.Request) { func (c *CascadeServerPlugin) API_list(w http.ResponseWriter, r *http.Request) {
//util.ReturnFetchList(SubordinateMap.ToList, w, r) //util.ReturnFetchList(SubordinateMap.ToList, w, r)
} }

View File

@@ -2,6 +2,7 @@ package m7s
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"log/slog" "log/slog"
"net" "net"
@@ -90,12 +91,10 @@ func NewServer(conf any) (s *Server) {
return return
} }
func Run(ctx context.Context, conf any) error { func Run(ctx context.Context, conf any) (err error) {
for { for err = ErrRestart; errors.Is(err, ErrRestart); err = util.RootTask.AddTaskWithContext(ctx, NewServer(conf)).WaitStopped() {
if err := util.RootTask.AddTaskWithContext(ctx, NewServer(conf)).WaitStopped(); err != ErrRestart {
return err
}
} }
return
} }
func AddRootTask[T util.ITask](task T) T { func AddRootTask[T util.ITask](task T) T {
@@ -108,7 +107,7 @@ func AddRootTaskWithContext[T util.ITask](ctx context.Context, task T) T {
return task return task
} }
type rawconfig = map[string]map[string]any type RawConfig = map[string]map[string]any
func init() { func init() {
signalChan := make(chan os.Signal, 1) signalChan := make(chan os.Signal, 1)
@@ -151,7 +150,7 @@ func (s *Server) Start() (err error) {
httpConf.GetHttpMux().ServeHTTP(w, r) httpConf.GetHttpMux().ServeHTTP(w, r)
})) }))
httpConf.SetMux(mux) httpConf.SetMux(mux)
var cg rawconfig var cg RawConfig
var configYaml []byte var configYaml []byte
switch v := s.conf.(type) { switch v := s.conf.(type) {
case string: case string:
@@ -163,7 +162,7 @@ func (s *Server) Start() (err error) {
} }
case []byte: case []byte:
configYaml = v configYaml = v
case rawconfig: case RawConfig:
cg = v cg = v
} }
if configYaml != nil { if configYaml != nil {

View File

@@ -8,7 +8,7 @@ import (
) )
func TestRestart(b *testing.T) { func TestRestart(b *testing.T) {
conf := map[string]map[string]any{"global": {"loglevel": "debug"}} conf := m7s.RawConfig{"global": {"loglevel": "debug"}}
var server *m7s.Server var server *m7s.Server
go func() { go func() {
time.Sleep(time.Second * 2) time.Sleep(time.Second * 2)