mirror of
https://github.com/Monibuca/engine.git
synced 2025-10-28 02:21:56 +08:00
feat: 增加自定义中间件功能,增加设置http的几个超时设定
This commit is contained in:
32
README.md
32
README.md
@@ -37,6 +37,9 @@ global:
|
||||
cors: true # 是否自动添加cors头
|
||||
username: "" # 用户名和密码,用于API访问时的基本身份认证
|
||||
password: ""
|
||||
readtimeout: 0 # 读取超时时间,单位秒,0为不限制
|
||||
writetimeout: 0 # 写入超时时间,单位秒,0为不限制
|
||||
idletimeout: 0 # 空闲超时时间,单位秒,0为不限制
|
||||
publish:
|
||||
pubaudio: true # 是否发布音频流
|
||||
pubvideo: true # 是否发布视频流
|
||||
@@ -113,3 +116,32 @@ var OnAuthPub func(p *util.Promise[IPublisher]) error
|
||||
```
|
||||
** 注意:如果单独鉴权和全局鉴权同时存在,优先使用单独鉴权 **
|
||||
** 全局鉴权函数可以被多次覆盖,所以需要自己实现鉴权逻辑的合并 **
|
||||
|
||||
# Http中间件
|
||||
在HTTPConfig接口中增加了AddMiddleware方法,可以通过该方法添加中间件,中间件的定义如下
|
||||
```go
|
||||
type Middleware func(string, http.Handler) http.Handler
|
||||
type HTTPConfig interface {
|
||||
GetHTTPConfig() *HTTP
|
||||
Listen(ctx context.Context) error
|
||||
Handle(string, http.Handler)
|
||||
AddMiddleware(Middleware)
|
||||
}
|
||||
|
||||
```
|
||||
中间件的添加必须在FirstConfig之前,也就是在Listen之前
|
||||
例如:
|
||||
```go
|
||||
type MyMiddlewareConfig struct {
|
||||
config.HTTP
|
||||
}
|
||||
var myMiddlewareConfig = &MyMiddlewareConfig{}
|
||||
func init(){
|
||||
myMiddlewareConfig.AddMiddleware(func(pattern string, handler http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// do something
|
||||
handler.ServeHTTP(w, r)
|
||||
})
|
||||
})
|
||||
}
|
||||
```
|
||||
@@ -103,7 +103,11 @@ func (config Config) Unmarshal(s any) {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
fv.SetInt(value.Int())
|
||||
case reflect.Float32, reflect.Float64:
|
||||
if value.CanFloat() {
|
||||
fv.SetFloat(value.Float())
|
||||
} else {
|
||||
fv.SetFloat(float64(value.Int()))
|
||||
}
|
||||
case reflect.Slice:
|
||||
var s reflect.Value
|
||||
if value.Kind() == reflect.Slice {
|
||||
|
||||
@@ -3,6 +3,7 @@ package config
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
. "github.com/logrusorgru/aurora"
|
||||
"golang.org/x/sync/errgroup"
|
||||
@@ -12,6 +13,7 @@ import (
|
||||
|
||||
var _ HTTPConfig = (*HTTP)(nil)
|
||||
|
||||
type Middleware func(string, http.Handler) http.Handler
|
||||
type HTTP struct {
|
||||
ListenAddr string
|
||||
ListenAddrTLS string
|
||||
@@ -20,15 +22,24 @@ type HTTP struct {
|
||||
CORS bool //是否自动添加CORS头
|
||||
UserName string
|
||||
Password string
|
||||
ReadTimeout float64
|
||||
WriteTimeout float64
|
||||
IdleTimeout float64
|
||||
mux *http.ServeMux
|
||||
middlewares []Middleware
|
||||
}
|
||||
type HTTPConfig interface {
|
||||
GetHTTPConfig() *HTTP
|
||||
Listen(ctx context.Context) error
|
||||
HandleFunc(string, func(http.ResponseWriter, *http.Request))
|
||||
Handle(string, http.Handler)
|
||||
AddMiddleware(Middleware)
|
||||
}
|
||||
|
||||
func (config *HTTP) HandleFunc(path string, f func(http.ResponseWriter, *http.Request)) {
|
||||
func (config *HTTP) AddMiddleware(middleware Middleware) {
|
||||
config.middlewares = append(config.middlewares, middleware)
|
||||
}
|
||||
|
||||
func (config *HTTP) Handle(path string, f http.Handler) {
|
||||
if config.mux == nil {
|
||||
config.mux = http.NewServeMux()
|
||||
}
|
||||
@@ -38,7 +49,10 @@ func (config *HTTP) HandleFunc(path string, f func(http.ResponseWriter, *http.Re
|
||||
if config.UserName != "" && config.Password != "" {
|
||||
f = util.BasicAuth(config.UserName, config.Password, f)
|
||||
}
|
||||
config.mux.HandleFunc(path, f)
|
||||
for _, middleware := range config.middlewares {
|
||||
f = middleware(path, f)
|
||||
}
|
||||
config.mux.Handle(path, f)
|
||||
}
|
||||
|
||||
func (config *HTTP) GetHTTPConfig() *HTTP {
|
||||
@@ -54,13 +68,27 @@ func (config *HTTP) Listen(ctx context.Context) error {
|
||||
if config.ListenAddrTLS != "" && (config == &Global.HTTP || config.ListenAddrTLS != Global.ListenAddrTLS) {
|
||||
g.Go(func() error {
|
||||
log.Info("🌐 https listen at ", Blink(config.ListenAddrTLS))
|
||||
return http.ListenAndServeTLS(config.ListenAddrTLS, config.CertFile, config.KeyFile, config.mux)
|
||||
var server = http.Server{
|
||||
Addr: config.ListenAddrTLS,
|
||||
ReadTimeout: time.Duration(config.ReadTimeout) * time.Second,
|
||||
WriteTimeout: time.Duration(config.WriteTimeout) * time.Second,
|
||||
IdleTimeout: time.Duration(config.IdleTimeout) * time.Second,
|
||||
Handler: config.mux,
|
||||
}
|
||||
return server.ListenAndServeTLS(config.CertFile, config.KeyFile)
|
||||
})
|
||||
}
|
||||
if config.ListenAddr != "" && (config == &Global.HTTP || config.ListenAddr != Global.ListenAddr) {
|
||||
g.Go(func() error {
|
||||
log.Info("🌐 http listen at ", Blink(config.ListenAddr))
|
||||
return http.ListenAndServe(config.ListenAddr, config.mux)
|
||||
var server = http.Server{
|
||||
Addr: config.ListenAddr,
|
||||
ReadTimeout: time.Duration(config.ReadTimeout) * time.Second,
|
||||
WriteTimeout: time.Duration(config.WriteTimeout) * time.Second,
|
||||
IdleTimeout: time.Duration(config.IdleTimeout) * time.Second,
|
||||
Handler: config.mux,
|
||||
}
|
||||
return server.ListenAndServe()
|
||||
})
|
||||
}
|
||||
g.Go(func() error {
|
||||
|
||||
16
plugin.go
16
plugin.go
@@ -64,13 +64,13 @@ type Plugin struct {
|
||||
saveTimer *time.Timer //用于保存的时候的延迟,防抖
|
||||
}
|
||||
|
||||
func (opt *Plugin) logHandler(pattern string, handler func(http.ResponseWriter, *http.Request)) http.HandlerFunc {
|
||||
return func(rw http.ResponseWriter, r *http.Request) {
|
||||
func (opt *Plugin) logHandler(pattern string, handler http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
opt.Debug("visit", zap.String("path", r.URL.String()), zap.String("remote", r.RemoteAddr))
|
||||
handler(rw, r)
|
||||
handler.ServeHTTP(rw, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
func (opt *Plugin) handleFunc(pattern string, handler func(http.ResponseWriter, *http.Request)) {
|
||||
func (opt *Plugin) handle(pattern string, handler http.Handler) {
|
||||
if opt == nil {
|
||||
return
|
||||
}
|
||||
@@ -80,12 +80,12 @@ func (opt *Plugin) handleFunc(pattern string, handler func(http.ResponseWriter,
|
||||
}
|
||||
if ok {
|
||||
opt.Debug("http handle added:" + pattern)
|
||||
conf.HandleFunc(pattern, opt.logHandler(pattern, handler))
|
||||
conf.Handle(pattern, opt.logHandler(pattern, handler))
|
||||
}
|
||||
if opt != Engine {
|
||||
pattern = "/" + strings.ToLower(opt.Name) + pattern
|
||||
opt.Debug("http handle added to engine:" + pattern)
|
||||
EngineConfig.HandleFunc(pattern, opt.logHandler(pattern, handler))
|
||||
EngineConfig.Handle(pattern, opt.logHandler(pattern, handler))
|
||||
}
|
||||
apiList = append(apiList, pattern)
|
||||
}
|
||||
@@ -174,7 +174,7 @@ func (opt *Plugin) registerHandler() {
|
||||
if name != "ServeHTTP" {
|
||||
patten = strings.ToLower(strings.ReplaceAll(name, "_", "/"))
|
||||
}
|
||||
opt.handleFunc(patten, handler)
|
||||
opt.handle(patten, http.HandlerFunc(handler))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -62,7 +62,7 @@ func ListenUDP(address string, networkBuffer int) (*net.UDPConn, error) {
|
||||
}
|
||||
|
||||
// CORS 加入跨域策略头包含CORP
|
||||
func CORS(next http.HandlerFunc) http.HandlerFunc {
|
||||
func CORS(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
header := w.Header()
|
||||
header.Set("Access-Control-Allow-Credentials", "true")
|
||||
@@ -80,7 +80,7 @@ func CORS(next http.HandlerFunc) http.HandlerFunc {
|
||||
})
|
||||
}
|
||||
|
||||
func BasicAuth(u, p string, next http.HandlerFunc) http.HandlerFunc {
|
||||
func BasicAuth(u, p string, next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Extract the username and password from the request
|
||||
// Authorization header. If no Authentication header is present
|
||||
|
||||
Reference in New Issue
Block a user