diff --git a/config/remote.go b/config/remote.go index e4de17e..754ab90 100644 --- a/config/remote.go +++ b/config/remote.go @@ -1,10 +1,10 @@ package config import ( + "bufio" "context" "crypto/tls" "encoding/json" - "errors" "io" "net/http" "strings" @@ -18,6 +18,10 @@ type myResponseWriter2 struct { myResponseWriter } +func (w *myResponseWriter2) Flush() { + +} + func (cfg *Engine) Remote(ctx context.Context) error { tlsConf := &tls.Config{ InsecureSkipVerify: true, @@ -65,21 +69,33 @@ func (cfg *Engine) Remote(ctx context.Context) error { func (cfg *Engine) ReceiveRequest(s quic.Stream) error { defer s.Close() wr := &myResponseWriter2{Stream: s} - reqStr, err := io.ReadAll(s) + reader := bufio.NewReader(s) var req *http.Request + url, _, err := reader.ReadLine() if err == nil { - if b, a, f := strings.Cut(string(reqStr), "\n"); f { - if len(a) > 0 { - req, err = http.NewRequest("POST", b, strings.NewReader(a)) + ctx, cancel := context.WithCancel(s.Context()) + defer cancel() + req, err = http.NewRequestWithContext(ctx, "GET", string(url), s) + for err == nil { + var h []byte + h, _, err = reader.ReadLine() + if len(h) > 0 { + b, a, f := strings.Cut(string(h), ": ") + if f { + req.Header.Set(b, a) + } } else { - req, err = http.NewRequest("GET", b, nil) + break } - if err == nil { - h, _ := cfg.mux.Handler(req) + } + if err == nil { + h, _ := cfg.mux.Handler(req) + if req.Header.Get("Accept") == "text/event-stream" { + go h.ServeHTTP(wr, req) + } else { h.ServeHTTP(wr, req) } - } else { - err = errors.New("theres no \\r") + io.ReadAll(s) } } if err != nil { diff --git a/config/types.go b/config/types.go index f5d2274..26a7405 100755 --- a/config/types.go +++ b/config/types.go @@ -79,11 +79,11 @@ func (p *Push) GetPushConfig() *Push { return p } -func (p *Push) AddPush(streamPath string, url string) { +func (p *Push) AddPush(url string, streamPath string) { if p.PushList == nil { p.PushList = make(map[string]string) } - p.PushList[streamPath] = url + p.PushList[url] = streamPath } type Console struct { @@ -106,10 +106,10 @@ type Engine struct { type myResponseWriter struct { } -func (w *myResponseWriter) Header() http.Header { +func (*myResponseWriter) Header() http.Header { return make(http.Header) } -func (w *myResponseWriter) WriteHeader(statusCode int) { +func (*myResponseWriter) WriteHeader(statusCode int) { } type myWsWriter struct { diff --git a/http.go b/http.go index 99520fd..3b6d216 100644 --- a/http.go +++ b/http.go @@ -27,7 +27,20 @@ func (conf *GlobalConfig) ServeHTTP(rw http.ResponseWriter, r *http.Request) { } func (conf *GlobalConfig) API_summary(rw http.ResponseWriter, r *http.Request) { - util.ReturnJson(summary.collect, time.Second, rw, r) + if r.Header.Get("Accept") == "text/event-stream" { + summary.Add() + defer summary.Done() + util.ReturnJson(func() *Summary { + return &summary + }, time.Second, rw, r) + } else { + if !summary.Running() { + summary.collect() + } + if err := json.NewEncoder(rw).Encode(&summary); err != nil { + http.Error(rw, err.Error(), http.StatusInternalServerError) + } + } } func (conf *GlobalConfig) API_plugins(rw http.ResponseWriter, r *http.Request) { @@ -39,9 +52,7 @@ func (conf *GlobalConfig) API_plugins(rw http.ResponseWriter, r *http.Request) { func (conf *GlobalConfig) API_stream(rw http.ResponseWriter, r *http.Request) { if streamPath := r.URL.Query().Get("streamPath"); streamPath != "" { if s := Streams.Get(streamPath); s != nil { - if err := json.NewEncoder(rw).Encode(s); err != nil { - http.Error(rw, err.Error(), http.StatusInternalServerError) - } + util.ReturnJson(func() *Stream { return s }, time.Second, rw, r) } else { http.Error(rw, NO_SUCH_STREAM, http.StatusNotFound) } @@ -145,23 +156,32 @@ func (conf *GlobalConfig) API_updateConfig(w http.ResponseWriter, r *http.Reques } func (conf *GlobalConfig) API_list_pull(w http.ResponseWriter, r *http.Request) { - result := []any{} - Pullers.Range(func(key, value any) bool { - result = append(result, key) - return true - }) - if err := json.NewEncoder(w).Encode(result); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - } + util.ReturnJson(func() (result []any) { + Pullers.Range(func(key, value any) bool { + result = append(result, key) + return true + }) + return + }, time.Second, w, r) } func (conf *GlobalConfig) API_list_push(w http.ResponseWriter, r *http.Request) { - result := []any{} - Pushers.Range(func(key, value any) bool { - result = append(result, key) - return true - }) - if err := json.NewEncoder(w).Encode(result); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) + util.ReturnJson(func() (result []any) { + Pushers.Range(func(key, value any) bool { + result = append(result, value) + return true + }) + return + }, time.Second, w, r) +} + +func (conf *GlobalConfig) API_stopPush(w http.ResponseWriter, r *http.Request) { + q := r.URL.Query() + pusher, ok := Pushers.Load(q.Get("url")) + if ok { + pusher.(IPusher).Stop() + w.Write([]byte("ok")) + } else { + http.Error(w, "no such pusher", http.StatusNotFound) } } diff --git a/plugin.go b/plugin.go index 7396f9e..4189876 100644 --- a/plugin.go +++ b/plugin.go @@ -314,10 +314,10 @@ func (opt *Plugin) Push(streamPath string, url string, pusher IPusher, save bool if err = opt.Subscribe(streamPath, pusher); err != nil { return } - Pushers.Store(pusher, url) + Pushers.Store(url, pusher) go func() { defer opt.Info("push finished", zp, zu) - defer Pushers.Delete(pusher) + defer Pushers.Delete(url) for pusher.Reconnect() { opt.Info("start push", zp, zu) if err = pusher.Push(); !pusher.IsClosed() { @@ -337,7 +337,7 @@ func (opt *Plugin) Push(streamPath string, url string, pusher IPusher, save bool }() if save { - pushConfig.AddPush(streamPath, url) + pushConfig.AddPush(url, streamPath) if opt.Modified == nil { opt.Modified = make(config.Config) } diff --git a/summary.go b/summary.go index 5fcea89..182ef85 100644 --- a/summary.go +++ b/summary.go @@ -70,16 +70,20 @@ func (s *Summary) Running() bool { // Add 增加订阅者 func (s *Summary) Add() { - if atomic.AddInt32(&s.ref, 1) == 1 { + if count := atomic.AddInt32(&s.ref, 1); count == 1 { log.Info("start report summary") + } else { + log.Info("summary count", count) } } // Done 删除订阅者 func (s *Summary) Done() { - if atomic.AddInt32(&s.ref, -1) == 0 { + if count := atomic.AddInt32(&s.ref, -1); count == 0 { log.Info("stop report summary") s.lastNetWork = nil + } else { + log.Info("summary count", count) } } diff --git a/util/socket.go b/util/socket.go index 41cf5a5..cb7dee6 100644 --- a/util/socket.go +++ b/util/socket.go @@ -10,19 +10,17 @@ import ( ) func ReturnJson[T any](fetch func() T, tickDur time.Duration, rw http.ResponseWriter, r *http.Request) { - if r.URL.Query().Get("sse") == "" { - if err := json.NewEncoder(rw).Encode(fetch()); err != nil { - http.Error(rw, err.Error(), http.StatusInternalServerError) - } - return - } - sse := NewSSE(rw, r.Context()) - tick := time.NewTicker(tickDur) - for range tick.C { - if sse.WriteJSON(fetch()) != nil { - tick.Stop() - break + if r.Header.Get("Accept") == "text/event-stream" { + sse := NewSSE(rw, r.Context()) + tick := time.NewTicker(tickDur) + defer tick.Stop() + for range tick.C { + if sse.WriteJSON(fetch()) != nil { + return + } } + } else if err := json.NewEncoder(rw).Encode(fetch()); err != nil { + http.Error(rw, err.Error(), http.StatusInternalServerError) } } @@ -62,6 +60,7 @@ func ListenUDP(address string, networkBuffer int) (*net.UDPConn, error) { } return conn, err } + // CORS 加入跨域策略头包含CORP func CORS(next http.HandlerFunc) http.HandlerFunc { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { diff --git a/util/sse.go b/util/sse.go index 1c2e213..5a29fd7 100644 --- a/util/sse.go +++ b/util/sse.go @@ -3,6 +3,7 @@ package util import ( "context" "encoding/json" + "net" "net/http" "os/exec" ) @@ -22,24 +23,23 @@ func (sse *SSE) Write(data []byte) (n int, err error) { if err = sse.Err(); err != nil { return } - _, err = sse.ResponseWriter.Write(sseBegin) - n, err = sse.ResponseWriter.Write(data) - _, err = sse.ResponseWriter.Write(sseEnd) - if err != nil { - return + buffers := net.Buffers{sseBegin, data, sseEnd} + nn, err := buffers.WriteTo(sse.ResponseWriter) + if err == nil { + sse.ResponseWriter.(http.Flusher).Flush() } - sse.ResponseWriter.(http.Flusher).Flush() - return + return int(nn), err } func (sse *SSE) WriteEvent(event string, data []byte) (err error) { if err = sse.Err(); err != nil { return } - _, err = sse.ResponseWriter.Write(sseEent) - _, err = sse.ResponseWriter.Write([]byte(event)) - _, err = sse.ResponseWriter.Write([]byte("\n")) - _, err = sse.Write(data) + buffers := net.Buffers{sseEent, []byte(event + "\n"), sseBegin, data, sseEnd} + _, err = buffers.WriteTo(sse.ResponseWriter) + if err == nil { + sse.ResponseWriter.(http.Flusher).Flush() + } return } @@ -51,8 +51,8 @@ func NewSSE(w http.ResponseWriter, ctx context.Context) *SSE { header.Set("X-Accel-Buffering", "no") header.Set("Access-Control-Allow-Origin", "*") return &SSE{ - w, - ctx, + ResponseWriter: w, + Context: ctx, } }