fix: BasicAuth for grpc-gateway

This commit is contained in:
langhuihui
2025-09-16 14:00:12 +08:00
parent 0ae3422759
commit 825328118a
14 changed files with 35 additions and 38 deletions

4
.gitignore vendored
View File

@@ -13,13 +13,15 @@ bin
*.flv *.flv
pullcf.yaml pullcf.yaml
*.zip *.zip
*.mp4
!plugin/hls/hls.js.zip !plugin/hls/hls.js.zip
__debug* __debug*
.cursorrules .cursorrules
example/default/* example/default/*
!example/default/main.go !example/default/main.go
!example/default/config.yaml !example/default/config.yaml
!example/default/test.flv
!example/default/test.mp4
shutdown.sh shutdown.sh
!example/test/test.db !example/test/test.db
*.mp4
shutdown.bat shutdown.bat

View File

@@ -1,7 +1,7 @@
global: global:
location: location:
"^/hdl/(.*)": "/flv/$1" # 兼容 v4 "^/hdl/(.*)": "/flv/$1" # 兼容 v4
"^/stress/(.*)": "/test/$1" # 5.0.x "^/stress/api/(.*)": "/test/api/stress/$1" # 5.0.x
"^/monitor/(.*)": "/debug/$1" # 5.0.x "^/monitor/(.*)": "/debug/$1" # 5.0.x
loglevel: debug loglevel: debug
admin: admin:

View File

@@ -16,6 +16,7 @@ import (
_ "m7s.live/v5/plugin/onvif" _ "m7s.live/v5/plugin/onvif"
_ "m7s.live/v5/plugin/preview" _ "m7s.live/v5/plugin/preview"
_ "m7s.live/v5/plugin/rtmp" _ "m7s.live/v5/plugin/rtmp"
_ "m7s.live/v5/plugin/rtp"
_ "m7s.live/v5/plugin/rtsp" _ "m7s.live/v5/plugin/rtsp"
_ "m7s.live/v5/plugin/sei" _ "m7s.live/v5/plugin/sei"
_ "m7s.live/v5/plugin/snap" _ "m7s.live/v5/plugin/snap"

BIN
example/default/test.flv Normal file

Binary file not shown.

BIN
example/default/test.mp4 Normal file

Binary file not shown.

View File

@@ -1,6 +1,7 @@
package config package config
import ( import (
"log/slog"
"net/http" "net/http"
"m7s.live/v5/pkg/util" "m7s.live/v5/pkg/util"
@@ -10,8 +11,6 @@ import (
"time" "time"
) )
var _ HTTPConfig = (*HTTP)(nil)
type Middleware func(string, http.Handler) http.Handler type Middleware func(string, http.Handler) http.Handler
type HTTP struct { type HTTP struct {
ListenAddr string `desc:"监听地址"` ListenAddr string `desc:"监听地址"`
@@ -28,16 +27,27 @@ type HTTP struct {
grpcMux *runtime.ServeMux grpcMux *runtime.ServeMux
middlewares []Middleware middlewares []Middleware
} }
type HTTPConfig interface {
GetHTTPConfig() *HTTP func (config *HTTP) logHandler(logger *slog.Logger, handler http.Handler) http.Handler {
// Handle(string, http.Handler) return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
// Handler(*http.Request) (http.Handler, string) logger.Debug("visit", "path", r.URL.String(), "remote", r.RemoteAddr)
// AddMiddleware(Middleware) handler.ServeHTTP(rw, r)
})
} }
func (config *HTTP) GetHandler() http.Handler { func (config *HTTP) GetHandler(logger *slog.Logger) (h http.Handler) {
if config.grpcMux != nil { if config.grpcMux != nil {
return config.grpcMux h = config.grpcMux
if logger != nil {
h = config.logHandler(logger, h)
}
if config.CORS {
h = util.CORS(h)
}
if config.UserName != "" && config.Password != "" {
h = util.BasicAuth(config.UserName, config.Password, h)
}
return
} }
return config.mux return config.mux
} }
@@ -79,11 +89,3 @@ func (config *HTTP) Handle(path string, f http.Handler, last bool) {
} }
config.mux.Handle(path, f) config.mux.Handle(path, f)
} }
func (config *HTTP) GetHTTPConfig() *HTTP {
return config
}
// func (config *HTTP) Handler(r *http.Request) (h http.Handler, pattern string) {
// return config.mux.Handler(r)
// }

View File

@@ -35,7 +35,7 @@ func (task *ListenHTTPWork) Start() (err error) {
ReadTimeout: task.HTTP.ReadTimeout, ReadTimeout: task.HTTP.ReadTimeout,
WriteTimeout: task.HTTP.WriteTimeout, WriteTimeout: task.HTTP.WriteTimeout,
IdleTimeout: task.HTTP.IdleTimeout, IdleTimeout: task.HTTP.IdleTimeout,
Handler: task.GetHandler(), Handler: task.GetHandler(task.Logger),
} }
return return
} }
@@ -61,7 +61,7 @@ func (task *ListenHTTPSWork) Start() (err error) {
ReadTimeout: task.HTTP.ReadTimeout, ReadTimeout: task.HTTP.ReadTimeout,
WriteTimeout: task.HTTP.WriteTimeout, WriteTimeout: task.HTTP.WriteTimeout,
IdleTimeout: task.HTTP.IdleTimeout, IdleTimeout: task.HTTP.IdleTimeout,
Handler: task.HTTP.GetHandler(), Handler: task.HTTP.GetHandler(task.Logger),
TLSConfig: &tls.Config{ TLSConfig: &tls.Config{
Certificates: []tls.Certificate{cer}, Certificates: []tls.Certificate{cer},
CipherSuites: []uint16{ CipherSuites: []uint16{

View File

@@ -220,6 +220,7 @@ func CORS(next http.Handler) http.Handler {
header.Set("Access-Control-Allow-Credentials", "true") header.Set("Access-Control-Allow-Credentials", "true")
header.Set("Cross-Origin-Resource-Policy", "cross-origin") header.Set("Cross-Origin-Resource-Policy", "cross-origin")
header.Set("Access-Control-Allow-Headers", "Content-Type,Access-Token,Authorization") header.Set("Access-Control-Allow-Headers", "Content-Type,Access-Token,Authorization")
header.Set("Access-Control-Allow-Methods", "GET,POST,PUT,DELETE,OPTIONS")
header.Set("Access-Control-Allow-Private-Network", "true") header.Set("Access-Control-Allow-Private-Network", "true")
origin := r.Header["Origin"] origin := r.Header["Origin"]
if len(origin) == 0 { if len(origin) == 0 {

View File

@@ -72,7 +72,7 @@ func (task *CascadeClient) Run() (err error) {
if s, err = task.AcceptStream(task.Task.Context); err == nil { if s, err = task.AcceptStream(task.Task.Context); err == nil {
task.AddTask(&cascade.ReceiveRequestTask{ task.AddTask(&cascade.ReceiveRequestTask{
Stream: s, Stream: s,
Handler: task.cfg.GetGlobalCommonConf().GetHandler(), Handler: task.cfg.GetGlobalCommonConf().GetHandler(task.Logger),
Connection: task.Connection, Connection: task.Connection,
Plugin: &task.cfg.Plugin, Plugin: &task.cfg.Plugin,
}) })

View File

@@ -125,7 +125,7 @@ func (task *CascadeServer) Go() (err error) {
var receiveRequestTask cascade.ReceiveRequestTask var receiveRequestTask cascade.ReceiveRequestTask
receiveRequestTask.Connection = task.Connection receiveRequestTask.Connection = task.Connection
receiveRequestTask.Plugin = &task.conf.Plugin receiveRequestTask.Plugin = &task.conf.Plugin
receiveRequestTask.Handler = task.conf.GetGlobalCommonConf().GetHandler() receiveRequestTask.Handler = task.conf.GetGlobalCommonConf().GetHandler(task.Logger)
if receiveRequestTask.Stream, err = task.AcceptStream(task); err == nil { if receiveRequestTask.Stream, err = task.AcceptStream(task); err == nil {
task.AddTask(&receiveRequestTask) task.AddTask(&receiveRequestTask)
} }

View File

@@ -29,7 +29,7 @@ func (r *RTPUDPReader) Read(packet *rtp.Packet) error {
if ordered != nil { if ordered != nil {
break break
} }
var buf [MTUSize]byte var buf [ReceiveMTU]byte
var pack rtp.Packet var pack rtp.Packet
n, err := r.Reader.Read(buf[:]) n, err := r.Reader.Read(buf[:])
if err != nil { if err != nil {

View File

@@ -59,6 +59,7 @@ const (
startBit = 1 << 7 startBit = 1 << 7
endBit = 1 << 6 endBit = 1 << 6
MTUSize = 1460 MTUSize = 1460
ReceiveMTU = 1500
) )
func (r *VideoFrame) Recycle() { func (r *VideoFrame) Recycle() {

View File

@@ -137,7 +137,7 @@ func (IO *MultipleConnection) Receive() {
} }
packet := frame.Packets.GetNextPointer() packet := frame.Packets.GetNextPointer()
for { for {
buf := mem.Malloc(mrtp.MTUSize) buf := mem.Malloc(mrtp.ReceiveMTU)
if n, _, err = track.Read(buf); err == nil { if n, _, err = track.Read(buf); err == nil {
mem.FreeRest(&buf, n) mem.FreeRest(&buf, n)
err = packet.Unmarshal(buf) err = packet.Unmarshal(buf)
@@ -200,7 +200,7 @@ func (IO *MultipleConnection) Receive() {
lastPLISent = time.Now() lastPLISent = time.Now()
} }
buf := mem.Malloc(mrtp.MTUSize) buf := mem.Malloc(mrtp.ReceiveMTU)
if n, _, err = track.Read(buf); err == nil { if n, _, err = track.Read(buf); err == nil {
mem.FreeRest(&buf, n) mem.FreeRest(&buf, n)
err = packet.Unmarshal(buf) err = packet.Unmarshal(buf)
@@ -212,6 +212,7 @@ func (IO *MultipleConnection) Receive() {
mem.Free(buf) mem.Free(buf)
continue continue
} }
if packet.Timestamp == writer.VideoFrame.Packets[0].Timestamp { if packet.Timestamp == writer.VideoFrame.Packets[0].Timestamp {
writer.VideoFrame.AddRecycleBytes(buf) writer.VideoFrame.AddRecycleBytes(buf)
packet = writer.VideoFrame.Packets.GetNextPointer() packet = writer.VideoFrame.Packets.GetNextPointer()

View File

@@ -17,7 +17,6 @@ import (
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
"github.com/shirou/gopsutil/v4/cpu" "github.com/shirou/gopsutil/v4/cpu"
"google.golang.org/protobuf/proto"
"m7s.live/v5/pkg/config" "m7s.live/v5/pkg/config"
"m7s.live/v5/pkg/task" "m7s.live/v5/pkg/task"
@@ -234,16 +233,6 @@ func (s *Server) Start() (err error) {
var httpMux http.Handler = httpConf.CreateHttpMux() var httpMux http.Handler = httpConf.CreateHttpMux()
mux := runtime.NewServeMux( mux := runtime.NewServeMux(
runtime.WithMarshalerOption("text/plain", &pb.TextPlain{}), runtime.WithMarshalerOption("text/plain", &pb.TextPlain{}),
runtime.WithForwardResponseOption(func(ctx context.Context, w http.ResponseWriter, m proto.Message) error {
header := w.Header()
header.Set("Access-Control-Allow-Credentials", "true")
header.Set("Cross-Origin-Resource-Policy", "cross-origin")
header.Set("Access-Control-Allow-Headers", "Content-Type,Access-Token,Authorization")
header.Set("Access-Control-Allow-Methods", "GET,POST,PUT,DELETE,OPTIONS")
header.Set("Access-Control-Allow-Private-Network", "true")
header.Set("Access-Control-Allow-Origin", "*")
return nil
}),
runtime.WithRoutingErrorHandler(func(_ context.Context, _ *runtime.ServeMux, _ runtime.Marshaler, w http.ResponseWriter, r *http.Request, _ int) { runtime.WithRoutingErrorHandler(func(_ context.Context, _ *runtime.ServeMux, _ runtime.Marshaler, w http.ResponseWriter, r *http.Request, _ int) {
httpMux.ServeHTTP(w, r) httpMux.ServeHTTP(w, r)
}), }),
@@ -658,7 +647,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Rewrite the URL path and handle locally // Rewrite the URL path and handle locally
r.URL.Path = pattern.ReplaceAllString(r.URL.Path, target) r.URL.Path = pattern.ReplaceAllString(r.URL.Path, target)
// Forward to local handler // Forward to local handler
s.config.HTTP.GetHandler().ServeHTTP(w, r) s.config.HTTP.GetHandler(s.Logger).ServeHTTP(w, r)
return return
} }
} }