feat: 增加自定义中间件功能,增加设置http的几个超时设定

This commit is contained in:
dexter
2022-12-19 13:48:38 +08:00
parent 1e1a86cd4b
commit 75b150a3cc
5 changed files with 81 additions and 17 deletions

View File

@@ -37,6 +37,9 @@ global:
cors: true # 是否自动添加cors头 cors: true # 是否自动添加cors头
username: "" # 用户名和密码用于API访问时的基本身份认证 username: "" # 用户名和密码用于API访问时的基本身份认证
password: "" password: ""
readtimeout: 0 # 读取超时时间单位秒0为不限制
writetimeout: 0 # 写入超时时间单位秒0为不限制
idletimeout: 0 # 空闲超时时间单位秒0为不限制
publish: publish:
pubaudio: true # 是否发布音频流 pubaudio: true # 是否发布音频流
pubvideo: true # 是否发布视频流 pubvideo: true # 是否发布视频流
@@ -112,4 +115,33 @@ var OnAuthSub func(p *util.Promise[ISubscriber]) error
var OnAuthPub func(p *util.Promise[IPublisher]) error var OnAuthPub func(p *util.Promise[IPublisher]) error
``` ```
** 注意:如果单独鉴权和全局鉴权同时存在,优先使用单独鉴权 ** ** 注意:如果单独鉴权和全局鉴权同时存在,优先使用单独鉴权 **
** 全局鉴权函数可以被多次覆盖,所以需要自己实现鉴权逻辑的合并 ** ** 全局鉴权函数可以被多次覆盖,所以需要自己实现鉴权逻辑的合并 **
# Http中间件
在HTTPConfig接口中增加了AddMiddleware方法可以通过该方法添加中间件中间件的定义如下
```go
type Middleware func(string, http.Handler) http.Handler
type HTTPConfig interface {
GetHTTPConfig() *HTTP
Listen(ctx context.Context) error
Handle(string, http.Handler)
AddMiddleware(Middleware)
}
```
中间件的添加必须在FirstConfig之前也就是在Listen之前
例如:
```go
type MyMiddlewareConfig struct {
config.HTTP
}
var myMiddlewareConfig = &MyMiddlewareConfig{}
func init(){
myMiddlewareConfig.AddMiddleware(func(pattern string, handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// do something
handler.ServeHTTP(w, r)
})
})
}
```

View File

@@ -103,7 +103,11 @@ func (config Config) Unmarshal(s any) {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
fv.SetInt(value.Int()) fv.SetInt(value.Int())
case reflect.Float32, reflect.Float64: case reflect.Float32, reflect.Float64:
fv.SetFloat(value.Float()) if value.CanFloat() {
fv.SetFloat(value.Float())
} else {
fv.SetFloat(float64(value.Int()))
}
case reflect.Slice: case reflect.Slice:
var s reflect.Value var s reflect.Value
if value.Kind() == reflect.Slice { if value.Kind() == reflect.Slice {

View File

@@ -3,6 +3,7 @@ package config
import ( import (
"context" "context"
"net/http" "net/http"
"time"
. "github.com/logrusorgru/aurora" . "github.com/logrusorgru/aurora"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
@@ -12,6 +13,7 @@ import (
var _ HTTPConfig = (*HTTP)(nil) var _ HTTPConfig = (*HTTP)(nil)
type Middleware func(string, http.Handler) http.Handler
type HTTP struct { type HTTP struct {
ListenAddr string ListenAddr string
ListenAddrTLS string ListenAddrTLS string
@@ -20,15 +22,24 @@ type HTTP struct {
CORS bool //是否自动添加CORS头 CORS bool //是否自动添加CORS头
UserName string UserName string
Password string Password string
ReadTimeout float64
WriteTimeout float64
IdleTimeout float64
mux *http.ServeMux mux *http.ServeMux
middlewares []Middleware
} }
type HTTPConfig interface { type HTTPConfig interface {
GetHTTPConfig() *HTTP GetHTTPConfig() *HTTP
Listen(ctx context.Context) error Listen(ctx context.Context) error
HandleFunc(string, func(http.ResponseWriter, *http.Request)) Handle(string, http.Handler)
AddMiddleware(Middleware)
} }
func (config *HTTP) HandleFunc(path string, f func(http.ResponseWriter, *http.Request)) { func (config *HTTP) AddMiddleware(middleware Middleware) {
config.middlewares = append(config.middlewares, middleware)
}
func (config *HTTP) Handle(path string, f http.Handler) {
if config.mux == nil { if config.mux == nil {
config.mux = http.NewServeMux() config.mux = http.NewServeMux()
} }
@@ -38,7 +49,10 @@ func (config *HTTP) HandleFunc(path string, f func(http.ResponseWriter, *http.Re
if config.UserName != "" && config.Password != "" { if config.UserName != "" && config.Password != "" {
f = util.BasicAuth(config.UserName, config.Password, f) f = util.BasicAuth(config.UserName, config.Password, f)
} }
config.mux.HandleFunc(path, f) for _, middleware := range config.middlewares {
f = middleware(path, f)
}
config.mux.Handle(path, f)
} }
func (config *HTTP) GetHTTPConfig() *HTTP { func (config *HTTP) GetHTTPConfig() *HTTP {
@@ -54,13 +68,27 @@ func (config *HTTP) Listen(ctx context.Context) error {
if config.ListenAddrTLS != "" && (config == &Global.HTTP || config.ListenAddrTLS != Global.ListenAddrTLS) { if config.ListenAddrTLS != "" && (config == &Global.HTTP || config.ListenAddrTLS != Global.ListenAddrTLS) {
g.Go(func() error { g.Go(func() error {
log.Info("🌐 https listen at ", Blink(config.ListenAddrTLS)) log.Info("🌐 https listen at ", Blink(config.ListenAddrTLS))
return http.ListenAndServeTLS(config.ListenAddrTLS, config.CertFile, config.KeyFile, config.mux) var server = http.Server{
Addr: config.ListenAddrTLS,
ReadTimeout: time.Duration(config.ReadTimeout) * time.Second,
WriteTimeout: time.Duration(config.WriteTimeout) * time.Second,
IdleTimeout: time.Duration(config.IdleTimeout) * time.Second,
Handler: config.mux,
}
return server.ListenAndServeTLS(config.CertFile, config.KeyFile)
}) })
} }
if config.ListenAddr != "" && (config == &Global.HTTP || config.ListenAddr != Global.ListenAddr) { if config.ListenAddr != "" && (config == &Global.HTTP || config.ListenAddr != Global.ListenAddr) {
g.Go(func() error { g.Go(func() error {
log.Info("🌐 http listen at ", Blink(config.ListenAddr)) log.Info("🌐 http listen at ", Blink(config.ListenAddr))
return http.ListenAndServe(config.ListenAddr, config.mux) var server = http.Server{
Addr: config.ListenAddr,
ReadTimeout: time.Duration(config.ReadTimeout) * time.Second,
WriteTimeout: time.Duration(config.WriteTimeout) * time.Second,
IdleTimeout: time.Duration(config.IdleTimeout) * time.Second,
Handler: config.mux,
}
return server.ListenAndServe()
}) })
} }
g.Go(func() error { g.Go(func() error {

View File

@@ -64,13 +64,13 @@ type Plugin struct {
saveTimer *time.Timer //用于保存的时候的延迟,防抖 saveTimer *time.Timer //用于保存的时候的延迟,防抖
} }
func (opt *Plugin) logHandler(pattern string, handler func(http.ResponseWriter, *http.Request)) http.HandlerFunc { func (opt *Plugin) logHandler(pattern string, handler http.Handler) http.Handler {
return func(rw http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
opt.Debug("visit", zap.String("path", r.URL.String()), zap.String("remote", r.RemoteAddr)) opt.Debug("visit", zap.String("path", r.URL.String()), zap.String("remote", r.RemoteAddr))
handler(rw, r) handler.ServeHTTP(rw, r)
} })
} }
func (opt *Plugin) handleFunc(pattern string, handler func(http.ResponseWriter, *http.Request)) { func (opt *Plugin) handle(pattern string, handler http.Handler) {
if opt == nil { if opt == nil {
return return
} }
@@ -80,12 +80,12 @@ func (opt *Plugin) handleFunc(pattern string, handler func(http.ResponseWriter,
} }
if ok { if ok {
opt.Debug("http handle added:" + pattern) opt.Debug("http handle added:" + pattern)
conf.HandleFunc(pattern, opt.logHandler(pattern, handler)) conf.Handle(pattern, opt.logHandler(pattern, handler))
} }
if opt != Engine { if opt != Engine {
pattern = "/" + strings.ToLower(opt.Name) + pattern pattern = "/" + strings.ToLower(opt.Name) + pattern
opt.Debug("http handle added to engine:" + pattern) opt.Debug("http handle added to engine:" + pattern)
EngineConfig.HandleFunc(pattern, opt.logHandler(pattern, handler)) EngineConfig.Handle(pattern, opt.logHandler(pattern, handler))
} }
apiList = append(apiList, pattern) apiList = append(apiList, pattern)
} }
@@ -174,7 +174,7 @@ func (opt *Plugin) registerHandler() {
if name != "ServeHTTP" { if name != "ServeHTTP" {
patten = strings.ToLower(strings.ReplaceAll(name, "_", "/")) patten = strings.ToLower(strings.ReplaceAll(name, "_", "/"))
} }
opt.handleFunc(patten, handler) opt.handle(patten, http.HandlerFunc(handler))
} }
} }
} }

View File

@@ -62,7 +62,7 @@ func ListenUDP(address string, networkBuffer int) (*net.UDPConn, error) {
} }
// CORS 加入跨域策略头包含CORP // CORS 加入跨域策略头包含CORP
func CORS(next http.HandlerFunc) http.HandlerFunc { func CORS(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
header := w.Header() header := w.Header()
header.Set("Access-Control-Allow-Credentials", "true") header.Set("Access-Control-Allow-Credentials", "true")
@@ -80,7 +80,7 @@ func CORS(next http.HandlerFunc) http.HandlerFunc {
}) })
} }
func BasicAuth(u, p string, next http.HandlerFunc) http.HandlerFunc { func BasicAuth(u, p string, next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Extract the username and password from the request // Extract the username and password from the request
// Authorization header. If no Authentication header is present // Authorization header. If no Authentication header is present