diff --git a/http/middleware/session/HLS.go b/http/middleware/session/HLS.go index fbb714b6..c76dd4d2 100644 --- a/http/middleware/session/HLS.go +++ b/http/middleware/session/HLS.go @@ -37,8 +37,9 @@ func (h *handler) handleHLSIngress(c echo.Context, _ string, data map[string]int // Read out the path of the .ts files and look them up in the ts-map. // Add it as ingress for the respective "sessionId". The "sessionId" is the .m3u8 file name. reader := req.Body - r := &bodyReader{ + r := &segmentReader{ reader: req.Body, + buffer: h.bufferPool.Get(), } req.Body = r @@ -46,6 +47,7 @@ func (h *handler) handleHLSIngress(c echo.Context, _ string, data map[string]int req.Body = reader if r.size == 0 { + h.bufferPool.Put(r.buffer) return } @@ -58,8 +60,10 @@ func (h *handler) handleHLSIngress(c echo.Context, _ string, data map[string]int h.hlsIngressCollector.Extra(path, data) } - h.hlsIngressCollector.Ingress(path, headerSize(req.Header)) + buffer := h.bufferPool.Get() + h.hlsIngressCollector.Ingress(path, headerSize(req.Header, buffer)) h.hlsIngressCollector.Ingress(path, r.size) + h.bufferPool.Put(buffer) segments := r.getSegments(urlpath.Dir(path)) @@ -74,6 +78,8 @@ func (h *handler) handleHLSIngress(c echo.Context, _ string, data map[string]int } h.lock.Unlock() } + + h.bufferPool.Put(r.buffer) }() } else if strings.HasSuffix(path, ".ts") { // Get the size of the .ts file and store it in the ts-map for later use. @@ -87,9 +93,11 @@ func (h *handler) handleHLSIngress(c echo.Context, _ string, data map[string]int req.Body = reader if r.size != 0 { + buffer := h.bufferPool.Get() h.lock.Lock() - h.rxsegments[path] = r.size + headerSize(req.Header) + h.rxsegments[path] = r.size + headerSize(req.Header, buffer) h.lock.Unlock() + h.bufferPool.Put(buffer) } }() } @@ -171,6 +179,7 @@ func (h *handler) handleHLSEgress(c echo.Context, _ string, data map[string]inte // the data that we need to rewrite. rewriter = &sessionRewriter{ ResponseWriter: res.Writer, + buffer: h.bufferPool.Get(), } res.Writer = rewriter @@ -188,21 +197,29 @@ func (h *handler) handleHLSEgress(c echo.Context, _ string, data map[string]inte if rewrite { if res.Status < 200 || res.Status >= 300 { res.Write(rewriter.buffer.Bytes()) + h.bufferPool.Put(rewriter.buffer) return nil } + buffer := h.bufferPool.Get() + // Rewrite the data befor sending it to the client - rewriter.rewriteHLS(sessionID, c.Request().URL) + rewriter.rewriteHLS(sessionID, c.Request().URL, buffer) res.Header().Set("Cache-Control", "private") - res.Write(rewriter.buffer.Bytes()) + res.Write(buffer.Bytes()) + + h.bufferPool.Put(buffer) + h.bufferPool.Put(rewriter.buffer) } if isM3U8 || isTS { if res.Status >= 200 && res.Status < 300 { // Collect how many bytes we've written in this session - h.hlsEgressCollector.Egress(sessionID, headerSize(res.Header())) + buffer := h.bufferPool.Get() + h.hlsEgressCollector.Egress(sessionID, headerSize(res.Header(), buffer)) h.hlsEgressCollector.Egress(sessionID, res.Size) + h.bufferPool.Put(buffer) if isTS { // Activate the session. If the session is already active, this is a noop @@ -214,13 +231,13 @@ func (h *handler) handleHLSEgress(c echo.Context, _ string, data map[string]inte return nil } -type bodyReader struct { +type segmentReader struct { reader io.ReadCloser - buffer bytes.Buffer + buffer *bytes.Buffer size int64 } -func (r *bodyReader) Read(b []byte) (int, error) { +func (r *segmentReader) Read(b []byte) (int, error) { n, err := r.reader.Read(b) if n > 0 { r.buffer.Write(b[:n]) @@ -230,15 +247,15 @@ func (r *bodyReader) Read(b []byte) (int, error) { return n, err } -func (r *bodyReader) Close() error { +func (r *segmentReader) Close() error { return r.reader.Close() } -func (r *bodyReader) getSegments(dir string) []string { +func (r *segmentReader) getSegments(dir string) []string { segments := []string{} // Find all segment URLs in the .m3u8 - scanner := bufio.NewScanner(&r.buffer) + scanner := bufio.NewScanner(r.buffer) for scanner.Scan() { line := scanner.Text() @@ -280,65 +297,49 @@ func (r *bodyReader) getSegments(dir string) []string { return segments } -type bodysizeReader struct { - reader io.ReadCloser - size int64 -} - -func (r *bodysizeReader) Read(b []byte) (int, error) { - n, err := r.reader.Read(b) - r.size += int64(n) - - return n, err -} - -func (r *bodysizeReader) Close() error { - return r.reader.Close() -} - type sessionRewriter struct { http.ResponseWriter - buffer bytes.Buffer + buffer *bytes.Buffer } func (g *sessionRewriter) Write(data []byte) (int, error) { // Write the data into internal buffer for later rewrite - w, err := g.buffer.Write(data) - - return w, err + return g.buffer.Write(data) } -func (g *sessionRewriter) rewriteHLS(sessionID string, requestURL *url.URL) { - var buffer bytes.Buffer - +func (g *sessionRewriter) rewriteHLS(sessionID string, requestURL *url.URL, buffer *bytes.Buffer) { isMaster := false // Find all URLS in the .m3u8 and add the session ID to the query string - scanner := bufio.NewScanner(&g.buffer) + scanner := bufio.NewScanner(g.buffer) for scanner.Scan() { - line := scanner.Text() + byteline := scanner.Bytes() // Write empty lines unmodified - if len(line) == 0 { - buffer.WriteString(line + "\n") + if len(byteline) == 0 { + buffer.Write(byteline) + buffer.WriteByte('\n') continue } // Write comments unmodified - if strings.HasPrefix(line, "#") { - buffer.WriteString(line + "\n") + if byteline[0] == '#' { + buffer.Write(byteline) + buffer.WriteByte('\n') continue } - u, err := url.Parse(line) + u, err := url.Parse(string(byteline)) if err != nil { - buffer.WriteString(line + "\n") + buffer.Write(byteline) + buffer.WriteByte('\n') continue } // Write anything that doesn't end in .m3u8 or .ts unmodified if !strings.HasSuffix(u.Path, ".m3u8") && !strings.HasSuffix(u.Path, ".ts") { - buffer.WriteString(line + "\n") + buffer.Write(byteline) + buffer.WriteByte('\n') continue } @@ -407,6 +408,4 @@ func (g *sessionRewriter) rewriteHLS(sessionID string, requestURL *url.URL) { buffer.WriteString(urlpath.Base(requestURL.Path) + "?" + q.Encode()) } - - g.buffer = buffer } diff --git a/http/middleware/session/HLS_test.go b/http/middleware/session/HLS_test.go new file mode 100644 index 00000000..e5cede1a --- /dev/null +++ b/http/middleware/session/HLS_test.go @@ -0,0 +1,112 @@ +package session + +import ( + "bytes" + "io" + "net/url" + "os" + "testing" + + "github.com/datarhei/core/v16/mem" + "github.com/stretchr/testify/require" +) + +func TestHLSSegmentReader(t *testing.T) { + data, err := os.ReadFile("./fixtures/segments.txt") + require.NoError(t, err) + + r := bytes.NewReader(data) + + br := &segmentReader{ + reader: io.NopCloser(r), + buffer: &bytes.Buffer{}, + } + + _, err = io.ReadAll(br) + require.NoError(t, err) + + segments := br.getSegments("/foobar") + require.Equal(t, []string{ + "/foobar/test_0_0_0303.ts", + "/foobar/test_0_0_0304.ts", + "/foobar/test_0_0_0305.ts", + "/foobar/test_0_0_0306.ts", + "/foobar/test_0_0_0307.ts", + "/foobar/test_0_0_0308.ts", + "/foobar/test_0_0_0309.ts", + "/foobar/test_0_0_0310.ts", + }, segments) +} + +func BenchmarkHLSSegmentReader(b *testing.B) { + pool := mem.NewBufferPool() + + data, err := os.ReadFile("./fixtures/segments.txt") + require.NoError(b, err) + + rd := bytes.NewReader(data) + r := io.NopCloser(rd) + + for i := 0; i < b.N; i++ { + rd.Reset(data) + br := &segmentReader{ + reader: io.NopCloser(r), + buffer: pool.Get(), + } + + _, err := io.ReadAll(br) + require.NoError(b, err) + + pool.Put(br.buffer) + } +} + +func TestHLSRewrite(t *testing.T) { + data, err := os.ReadFile("./fixtures/segments.txt") + require.NoError(t, err) + + br := &sessionRewriter{ + buffer: &bytes.Buffer{}, + } + + _, err = br.Write(data) + require.NoError(t, err) + + u, err := url.Parse("http://example.com/test.m3u8") + require.NoError(t, err) + + buffer := &bytes.Buffer{} + + br.rewriteHLS("oT5GV8eWBbRAh4aib5egoK", u, buffer) + + data, err = os.ReadFile("./fixtures/segments_with_session.txt") + require.NoError(t, err) + + require.Equal(t, data, buffer.Bytes()) +} + +func BenchmarkHLSRewrite(b *testing.B) { + pool := mem.NewBufferPool() + + data, err := os.ReadFile("./fixtures/segments.txt") + require.NoError(b, err) + + u, err := url.Parse("http://example.com/test.m3u8") + require.NoError(b, err) + + for i := 0; i < b.N; i++ { + br := &sessionRewriter{ + buffer: pool.Get(), + } + + _, err = br.Write(data) + require.NoError(b, err) + + buffer := pool.Get() + + br.rewriteHLS("oT5GV8eWBbRAh4aib5egoK", u, buffer) + + pool.Put(br.buffer) + pool.Put(buffer) + } +} diff --git a/http/middleware/session/HTTP.go b/http/middleware/session/HTTP.go index 615b2058..4913171f 100644 --- a/http/middleware/session/HTTP.go +++ b/http/middleware/session/HTTP.go @@ -7,7 +7,7 @@ import ( "github.com/lithammer/shortuuid/v4" ) -func (h *handler) handleHTTP(c echo.Context, ctxuser string, data map[string]interface{}, next echo.HandlerFunc) error { +func (h *handler) handleHTTP(c echo.Context, _ string, data map[string]interface{}, next echo.HandlerFunc) error { req := c.Request() res := c.Response() @@ -30,13 +30,13 @@ func (h *handler) handleHTTP(c echo.Context, ctxuser string, data map[string]int id := shortuuid.New() reader := req.Body - r := &fakeReader{ + r := &bodysizeReader{ reader: req.Body, } req.Body = r writer := res.Writer - w := &fakeWriter{ + w := &bodysizeWriter{ ResponseWriter: res.Writer, } res.Writer = w @@ -44,19 +44,21 @@ func (h *handler) handleHTTP(c echo.Context, ctxuser string, data map[string]int h.httpCollector.RegisterAndActivate(id, "", location, referrer) h.httpCollector.Extra(id, data) - defer h.httpCollector.Close(id) - defer func() { + buffer := h.bufferPool.Get() + req.Body = reader - h.httpCollector.Ingress(id, r.size+headerSize(req.Header)) - }() + h.httpCollector.Ingress(id, r.size+headerSize(req.Header, buffer)) - defer func() { res.Writer = writer - h.httpCollector.Egress(id, w.size+headerSize(res.Header())) + h.httpCollector.Egress(id, w.size+headerSize(res.Header(), buffer)) data["code"] = res.Status h.httpCollector.Extra(id, data) + + h.httpCollector.Close(id) + + h.bufferPool.Put(buffer) }() return next(c) diff --git a/http/middleware/session/fixtures/segments.txt b/http/middleware/session/fixtures/segments.txt new file mode 100644 index 00000000..a4e2348c --- /dev/null +++ b/http/middleware/session/fixtures/segments.txt @@ -0,0 +1,29 @@ +#EXTM3U +#EXT-X-VERSION:6 +#EXT-X-TARGETDURATION:2 +#EXT-X-MEDIA-SEQUENCE:303 +#EXT-X-INDEPENDENT-SEGMENTS +#EXTINF:2.000000, +#EXT-X-PROGRAM-DATE-TIME:2024-10-09T12:56:35.019+0200 +test_0_0_0303.ts +#EXTINF:2.000000, +#EXT-X-PROGRAM-DATE-TIME:2024-10-09T12:56:37.019+0200 +test_0_0_0304.ts +#EXTINF:2.000000, +#EXT-X-PROGRAM-DATE-TIME:2024-10-09T12:56:39.019+0200 +test_0_0_0305.ts +#EXTINF:2.000000, +#EXT-X-PROGRAM-DATE-TIME:2024-10-09T12:56:41.019+0200 +test_0_0_0306.ts +#EXTINF:2.000000, +#EXT-X-PROGRAM-DATE-TIME:2024-10-09T12:56:43.019+0200 +test_0_0_0307.ts +#EXTINF:2.000000, +#EXT-X-PROGRAM-DATE-TIME:2024-10-09T12:56:45.019+0200 +test_0_0_0308.ts +#EXTINF:2.000000, +#EXT-X-PROGRAM-DATE-TIME:2024-10-09T12:56:47.019+0200 +test_0_0_0309.ts +#EXTINF:2.000000, +#EXT-X-PROGRAM-DATE-TIME:2024-10-09T12:56:49.019+0200 +test_0_0_0310.ts diff --git a/http/middleware/session/fixtures/segments_with_session.txt b/http/middleware/session/fixtures/segments_with_session.txt new file mode 100644 index 00000000..f59ed305 --- /dev/null +++ b/http/middleware/session/fixtures/segments_with_session.txt @@ -0,0 +1,29 @@ +#EXTM3U +#EXT-X-VERSION:6 +#EXT-X-TARGETDURATION:2 +#EXT-X-MEDIA-SEQUENCE:303 +#EXT-X-INDEPENDENT-SEGMENTS +#EXTINF:2.000000, +#EXT-X-PROGRAM-DATE-TIME:2024-10-09T12:56:35.019+0200 +test_0_0_0303.ts?session=oT5GV8eWBbRAh4aib5egoK +#EXTINF:2.000000, +#EXT-X-PROGRAM-DATE-TIME:2024-10-09T12:56:37.019+0200 +test_0_0_0304.ts?session=oT5GV8eWBbRAh4aib5egoK +#EXTINF:2.000000, +#EXT-X-PROGRAM-DATE-TIME:2024-10-09T12:56:39.019+0200 +test_0_0_0305.ts?session=oT5GV8eWBbRAh4aib5egoK +#EXTINF:2.000000, +#EXT-X-PROGRAM-DATE-TIME:2024-10-09T12:56:41.019+0200 +test_0_0_0306.ts?session=oT5GV8eWBbRAh4aib5egoK +#EXTINF:2.000000, +#EXT-X-PROGRAM-DATE-TIME:2024-10-09T12:56:43.019+0200 +test_0_0_0307.ts?session=oT5GV8eWBbRAh4aib5egoK +#EXTINF:2.000000, +#EXT-X-PROGRAM-DATE-TIME:2024-10-09T12:56:45.019+0200 +test_0_0_0308.ts?session=oT5GV8eWBbRAh4aib5egoK +#EXTINF:2.000000, +#EXT-X-PROGRAM-DATE-TIME:2024-10-09T12:56:47.019+0200 +test_0_0_0309.ts?session=oT5GV8eWBbRAh4aib5egoK +#EXTINF:2.000000, +#EXT-X-PROGRAM-DATE-TIME:2024-10-09T12:56:49.019+0200 +test_0_0_0310.ts?session=oT5GV8eWBbRAh4aib5egoK diff --git a/http/middleware/session/session.go b/http/middleware/session/session.go index 684aa4c6..3f12f405 100644 --- a/http/middleware/session/session.go +++ b/http/middleware/session/session.go @@ -13,6 +13,7 @@ import ( "github.com/datarhei/core/v16/glob" "github.com/datarhei/core/v16/http/api" "github.com/datarhei/core/v16/http/handler/util" + "github.com/datarhei/core/v16/mem" "github.com/datarhei/core/v16/net" "github.com/datarhei/core/v16/session" "github.com/lithammer/shortuuid/v4" @@ -44,6 +45,8 @@ type handler struct { rxsegments map[string]int64 lock sync.Mutex + + bufferPool *mem.BufferPool } // New returns a new session middleware with default config @@ -75,6 +78,7 @@ func NewWithConfig(config Config) echo.MiddlewareFunc { hlsIngressCollector: config.HLSIngressCollector, reSessionID: regexp.MustCompile(`^[` + regexp.QuoteMeta(shortuuid.DefaultAlphabet) + `]{22}$`), rxsegments: make(map[string]int64), + bufferPool: mem.NewBufferPool(), } return func(next echo.HandlerFunc) echo.HandlerFunc { @@ -173,43 +177,42 @@ func verifySession(raw interface{}, path, referrer string) (map[string]interface return data, nil } -func headerSize(header http.Header) int64 { - var buffer bytes.Buffer - - header.Write(&buffer) +func headerSize(header http.Header, buffer *bytes.Buffer) int64 { + buffer.Reset() + header.Write(buffer) return int64(buffer.Len()) } -type fakeReader struct { +type bodysizeReader struct { reader io.ReadCloser size int64 } -func (r *fakeReader) Read(b []byte) (int, error) { +func (r *bodysizeReader) Read(b []byte) (int, error) { n, err := r.reader.Read(b) r.size += int64(n) return n, err } -func (r *fakeReader) Close() error { +func (r *bodysizeReader) Close() error { return r.reader.Close() } -type fakeWriter struct { +type bodysizeWriter struct { http.ResponseWriter size int64 code int } -func (w *fakeWriter) WriteHeader(statusCode int) { +func (w *bodysizeWriter) WriteHeader(statusCode int) { w.ResponseWriter.WriteHeader(statusCode) w.code = statusCode } -func (w *fakeWriter) Write(body []byte) (int, error) { +func (w *bodysizeWriter) Write(body []byte) (int, error) { n, err := w.ResponseWriter.Write(body) w.size += int64(n) @@ -217,7 +220,7 @@ func (w *fakeWriter) Write(body []byte) (int, error) { return n, err } -func (w *fakeWriter) Flush() { +func (w *bodysizeWriter) Flush() { flusher, ok := w.ResponseWriter.(http.Flusher) if ok { flusher.Flush() diff --git a/http/middleware/session/session_test.go b/http/middleware/session/session_test.go index 77ef3839..b8633b19 100644 --- a/http/middleware/session/session_test.go +++ b/http/middleware/session/session_test.go @@ -1,6 +1,8 @@ package session import ( + "bytes" + "net/http" "testing" "github.com/datarhei/core/v16/encoding/json" @@ -134,3 +136,29 @@ func TestVerifySessionMultipleRemote(t *testing.T) { _, err = verifySession(rawdata, "/memfs/6faad99a-c440-4df1-9344-963869718d8d/main.m3u8", "http://bar.example.com") require.Error(t, err) } + +func TestHeaderSize(t *testing.T) { + header := http.Header{} + + header.Add("Content-Type", "application/json") + header.Add("Content-Encoding", "gzip") + + buffer := &bytes.Buffer{} + size := headerSize(header, buffer) + + require.Equal(t, "Content-Encoding: gzip\r\nContent-Type: application/json\r\n", buffer.String()) + require.Equal(t, int64(56), size) +} + +func BenchmarkHeaderSize(b *testing.B) { + header := http.Header{} + + header.Add("Content-Type", "application/json") + header.Add("Content-Encoding", "gzip") + + buffer := &bytes.Buffer{} + + for i := 0; i < b.N; i++ { + headerSize(header, buffer) + } +}