Files
monibuca/plugin/cascade/server.go
2024-08-20 12:41:20 +08:00

121 lines
3.3 KiB
Go

package plugin_cascade
import (
"bufio"
"m7s.live/m7s/v5"
"m7s.live/m7s/v5/pkg/util"
"net/http"
"strconv"
"strings"
"github.com/quic-go/quic-go"
"m7s.live/m7s/v5/plugin/cascade/pkg"
)
// CascadeServerPlugin is the upper-level (server) side of a cascade: it
// accepts QUIC connections from subordinate instances (OnQUICConnect) and
// relays HTTP API calls to them (API_relay_).
type CascadeServerPlugin struct {
	m7s.Plugin
	// AutoRegister (desc: "subordinate auto-registration"): when true, a child
	// presenting an unknown secret is registered in the database instead of
	// being rejected during the handshake in CascadeServer.Go.
	AutoRegister bool `default:"true" desc:"下级自动注册"`
	// RelayAPI (desc: "access control") holds access-control settings for
	// relayed API calls. NOTE(review): not read anywhere in this file.
	RelayAPI cascade.RelayAPIConfig `desc:"访问控制"`
}
// Installs the plugin via m7s.InstallPlugin; evaluated at package initialisation.
var _ = m7s.InstallPlugin[CascadeServerPlugin]()
// CascadeServer is the per-connection task created by OnQUICConnect for each
// subordinate instance that connects over QUIC; its Go method performs the
// handshake and then serves streams opened by the child.
type CascadeServer struct {
	util.MarcoLongTask
	quic.Connection
	conf *CascadeServerPlugin // owning plugin: source of DB handle and config
}
// OnQUICConnect builds the task that serves a newly accepted QUIC connection
// from a subordinate instance. The task's logger is derived from the plugin
// logger and tagged with the peer's remote address.
func (c *CascadeServerPlugin) OnQUICConnect(conn quic.Connection) util.ITask {
	addr := conn.RemoteAddr().String()
	server := &CascadeServer{Connection: conn, conf: c}
	server.Logger = c.Logger.With("remoteAddr", addr)
	return server
}
// Go serves one subordinate (child) instance over its QUIC connection.
//
// Handshake: the first stream accepted from the peer carries the child's
// secret terminated by a 0 byte. The secret is looked up in the database;
// unknown secrets are either auto-registered (when AutoRegister is set) or
// rejected. The server answers on that stream with {0, 0} on success or
// {1, 0} on rejection and then closes it. After a successful handshake every
// further stream the peer opens is handed to a cascade.ReceiveRequestTask
// until AcceptStream (or the initial stream Close) returns an error.
func (task *CascadeServer) Go() {
	remoteAddr := task.Connection.RemoteAddr().String()
	stream, err := task.AcceptStream(task)
	if err != nil {
		task.Error("AcceptStream", "err", err)
		return
	}
	var secret string
	r := bufio.NewReader(stream)
	if secret, err = r.ReadString(0); err != nil {
		task.Error("read secret", "err", err)
		return
	}
	// ReadString returned a nil error, so secret ends with the 0 delimiter.
	secret = secret[:len(secret)-1]

	var instance cascade.Instance
	child := &instance
	if err = task.conf.DB.AutoMigrate(child); err != nil {
		// Previously discarded; continue as before but make the failure visible.
		task.Error("AutoMigrate", "err", err)
	}
	// NOTE(review): any lookup error (not only "record not found") falls
	// through to auto-registration / rejection — confirm this is intended.
	if tx := task.conf.DB.First(child, "secret = ?", secret); tx.Error != nil {
		if !task.conf.AutoRegister {
			// NOTE(review): this logs the rejected secret in plain text.
			task.Error("invalid secret:", "secret", secret)
			if _, err = stream.Write([]byte{1, 0}); err != nil {
				task.Error("write reject", "err", err)
			}
			// Close the handshake stream instead of leaking it.
			if err = stream.Close(); err != nil {
				task.Error("close register stream", "err", err)
			}
			return
		}
		child.Secret = secret
		child.IP = remoteAddr
		// Reuse a record previously registered from the same address, if any.
		// NOTE(review): when such a record exists, First overwrites child
		// (including Secret) with the stored values, so the new secret is not
		// persisted by the Updates below — confirm intended.
		if tx = task.conf.DB.First(child, "ip = ?", remoteAddr); tx.Error != nil {
			if tx = task.conf.DB.Create(child); tx.Error != nil {
				task.Error("register child", "err", tx.Error)
			}
		}
	}
	child.IP = remoteAddr
	child.Online = true
	if child.Name == "" {
		child.Name = remoteAddr
	}
	if tx := task.conf.DB.Updates(child); tx.Error != nil {
		task.Error("update child", "err", tx.Error)
	}
	child.Connection = task.Connection
	// Publish the child only once it is fully initialised, so concurrent
	// API_relay_ callers never see an instance without a Connection.
	cascade.SubordinateMap.Set(child)
	if _, err = stream.Write([]byte{0, 0}); err != nil {
		task.Error("write ack", "err", err)
		return
	}
	if err = stream.Close(); err != nil {
		task.Error("close register stream", "err", err)
		return
	}
	task.Info("client register:", "remoteAddr", remoteAddr)
	for err == nil {
		var receiveRequestTask cascade.ReceiveRequestTask
		receiveRequestTask.Connection = task.Connection
		receiveRequestTask.Plugin = &task.conf.Plugin
		receiveRequestTask.Handler = task.conf.GetGlobalCommonConf().GetHandler()
		if receiveRequestTask.Stream, err = task.AcceptStream(task); err == nil {
			task.AddTask(&receiveRequestTask)
		}
	}
}
// API_relay_ forwards an HTTP request to a subordinate instance over QUIC.
//
// Route: api/relay/:instanceId/* — the fourth "/"-separated path segment
// selects the child in cascade.SubordinateMap, and the remaining segments
// become the request path seen by the child. Replies 404 when the instance
// id is missing/invalid/unknown, 503 when the instance has no connection,
// and 500 when a QUIC stream to it cannot be opened.
func (c *CascadeServerPlugin) API_relay_(w http.ResponseWriter, r *http.Request) {
	paths := strings.Split(r.URL.Path, "/")
	// Guard the paths[3] index below; a too-short path used to panic.
	if len(paths) < 4 {
		http.Error(w, "instance not found", http.StatusNotFound)
		return
	}
	instanceID, err := strconv.ParseUint(paths[3], 10, 32)
	if err != nil {
		http.Error(w, "instance not found", http.StatusNotFound)
		return
	}
	instance, ok := cascade.SubordinateMap.Get(uint(instanceID))
	if !ok {
		http.Error(w, "instance not found", http.StatusNotFound)
		return
	}
	// Calling OpenStream on a nil connection would panic.
	if instance.Connection == nil {
		http.Error(w, "instance offline", http.StatusServiceUnavailable)
		return
	}
	relayURL := "/" + strings.Join(paths[4:], "/")
	r.URL.Path = relayURL
	if r.URL.RawQuery != "" {
		relayURL += "?" + r.URL.RawQuery
	}
	c.Debug("relayQuic", "relayURL", relayURL)
	var relayer cascade.Http2Quic
	relayer.Connection = instance.Connection
	if relayer.Stream, err = instance.OpenStream(); err != nil {
		// Without a stream ServeHTTP would operate on a nil stream.
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	c.AddTask(&relayer)
	relayer.ServeHTTP(w, r)
}
// API_list is intended to list all subordinate instances (route: api/list).
//
// It is not implemented yet. Rather than silently answering 200 with an
// empty body, it reports 501 Not Implemented so clients can tell.
// TODO: implement, e.g. util.ReturnFetchList(SubordinateMap.ToList, w, r).
func (c *CascadeServerPlugin) API_list(w http.ResponseWriter, r *http.Request) {
	http.Error(w, "not implemented", http.StatusNotImplemented)
}