From 3b0a19e18a2f3a31a36b40fe43f6e1fdbc9739ac Mon Sep 17 00:00:00 2001 From: Ingo Oppermann Date: Thu, 8 Sep 2022 19:16:44 +0200 Subject: [PATCH] Allow to only compress responses that have a minimum length --- http/middleware/gzip/gzip.go | 108 ++++++++------ http/middleware/gzip/gzip_test.go | 240 ++++++++++++++++++++++++++++++ http/server.go | 18 +-- 3 files changed, 315 insertions(+), 51 deletions(-) create mode 100644 http/middleware/gzip/gzip_test.go diff --git a/http/middleware/gzip/gzip.go b/http/middleware/gzip/gzip.go index 22c7a8c0..659c73d8 100644 --- a/http/middleware/gzip/gzip.go +++ b/http/middleware/gzip/gzip.go @@ -2,6 +2,7 @@ package gzip import ( "bufio" + "bytes" "compress/gzip" "io" "net" @@ -25,15 +26,16 @@ type Config struct { // Length threshold before gzip compression // is used. Optional. Default value 0 MinLength int - - // Content-Types to compress. Empty for all - // files. Optional. Default value "text/plain" and "text/html" - ContentTypes []string } type gzipResponseWriter struct { io.Writer http.ResponseWriter + wroteBody bool + minLength int + minLengthExceeded bool + buffer bytes.Buffer + code int } const gzipScheme = "gzip" @@ -47,10 +49,32 @@ const ( // DefaultConfig is the default Gzip middleware config. var DefaultConfig = Config{ - Skipper: middleware.DefaultSkipper, - Level: -1, - MinLength: 0, - ContentTypes: []string{"text/plain", "text/html"}, + Skipper: middleware.DefaultSkipper, + Level: DefaultCompression, + MinLength: 0, +} + +// ContentTypesSkipper returns a Skipper based on the list of content types +// that should be compressed. If the list is empty, all responses will be +// compressed. +func ContentTypeSkipper(contentTypes []string) middleware.Skipper { + return func(c echo.Context) bool { + // If no allowed content types are given, compress all + if len(contentTypes) == 0 { + return false + } + + // Iterate through the allowed content types and don't skip if the content type matches + responseContentType := c.Response().Header().Get(echo.HeaderContentType) + + for _, contentType := range contentTypes { + if strings.Contains(responseContentType, contentType) { + return false + } + } + + return true + } } // New returns a middleware which compresses HTTP response using gzip compression @@ -75,10 +99,6 @@ func NewWithConfig(config Config) echo.MiddlewareFunc { config.MinLength = DefaultConfig.MinLength } - if config.ContentTypes == nil { - config.ContentTypes = DefaultConfig.ContentTypes - } - pool := gzipPool(config) return func(next echo.HandlerFunc) echo.HandlerFunc { @@ -89,8 +109,8 @@ func NewWithConfig(config Config) echo.MiddlewareFunc { res := c.Response() res.Header().Add(echo.HeaderVary, echo.HeaderAcceptEncoding) - if shouldCompress(c, config.ContentTypes) { - res.Header().Set(echo.HeaderContentEncoding, gzipScheme) // Issue #806 + + if strings.Contains(c.Request().Header.Get(echo.HeaderAcceptEncoding), gzipScheme) { i := pool.Get() w, ok := i.(*gzip.Writer) if !ok { @@ -98,8 +118,10 @@ func NewWithConfig(config Config) echo.MiddlewareFunc { } rw := res.Writer w.Reset(rw) + grw := &gzipResponseWriter{Writer: w, ResponseWriter: rw, minLength: config.MinLength} + defer func() { - if res.Size == 0 { + if !grw.wroteBody { if res.Header().Get(echo.HeaderContentEncoding) == gzipScheme { res.Header().Del(echo.HeaderContentEncoding) } @@ -108,49 +130,33 @@ func NewWithConfig(config Config) echo.MiddlewareFunc { // See issue #424, #407. res.Writer = rw w.Reset(io.Discard) + } else if !grw.minLengthExceeded { + // Write uncompressed response + res.Writer = rw + grw.ResponseWriter.WriteHeader(grw.code) + grw.buffer.WriteTo(rw) + w.Reset(io.Discard) } w.Close() pool.Put(w) }() - grw := &gzipResponseWriter{Writer: w, ResponseWriter: rw} + res.Writer = grw } + return next(c) } } } -func shouldCompress(c echo.Context, contentTypes []string) bool { - if !strings.Contains(c.Request().Header.Get(echo.HeaderAcceptEncoding), gzipScheme) || - strings.Contains(c.Request().Header.Get("Connection"), "Upgrade") || - strings.Contains(c.Request().Header.Get(echo.HeaderContentType), "text/event-stream") { - - return false - } - - // If no allowed content types are given, compress all - if len(contentTypes) == 0 { - return true - } - - // Iterate through the allowed content types and return true if the content type matches - responseContentType := c.Response().Header().Get(echo.HeaderContentType) - - for _, contentType := range contentTypes { - if strings.Contains(responseContentType, contentType) { - return true - } - } - - return false -} - func (w *gzipResponseWriter) WriteHeader(code int) { if code == http.StatusNoContent { // Issue #489 w.ResponseWriter.Header().Del(echo.HeaderContentEncoding) } w.Header().Del(echo.HeaderContentLength) // Issue #444 - w.ResponseWriter.WriteHeader(code) + + // Delay writing of the header until we know if we'll actually compress the response + w.code = code } func (w *gzipResponseWriter) Write(b []byte) (int, error) { @@ -158,6 +164,24 @@ func (w *gzipResponseWriter) Write(b []byte) (int, error) { w.Header().Set(echo.HeaderContentType, http.DetectContentType(b)) } + w.wroteBody = true + + if !w.minLengthExceeded { + n, err := w.buffer.Write(b) + + if w.buffer.Len() >= w.minLength { + w.minLengthExceeded = true + + // The minimum length is exceeded, add Content-Encoding header and write the header + w.Header().Set(echo.HeaderContentEncoding, gzipScheme) // Issue #806 + w.ResponseWriter.WriteHeader(w.code) + + return w.Writer.Write(w.buffer.Bytes()) + } else { + return n, err + } + } + return w.Writer.Write(b) } diff --git a/http/middleware/gzip/gzip_test.go b/http/middleware/gzip/gzip_test.go new file mode 100644 index 00000000..a0ebc539 --- /dev/null +++ b/http/middleware/gzip/gzip_test.go @@ -0,0 +1,240 @@ +package gzip + +import ( + "bytes" + "compress/gzip" + "io" + "net/http" + "net/http/httptest" + "os" + "testing" + + "github.com/labstack/echo/v4" + "github.com/stretchr/testify/assert" +) + +func TestGzip(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + // Skip if no Accept-Encoding header + h := New()(func(c echo.Context) error { + c.Response().Write([]byte("test")) // For Content-Type sniffing + return nil + }) + h(c) + + assert := assert.New(t) + + assert.Equal("test", rec.Body.String()) + + // Gzip + req = httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) + rec = httptest.NewRecorder() + c = e.NewContext(req, rec) + h(c) + assert.Equal(gzipScheme, rec.Header().Get(echo.HeaderContentEncoding)) + assert.Contains(rec.Header().Get(echo.HeaderContentType), echo.MIMETextPlain) + r, err := gzip.NewReader(rec.Body) + if assert.NoError(err) { + buf := new(bytes.Buffer) + defer r.Close() + buf.ReadFrom(r) + assert.Equal("test", buf.String()) + } + + chunkBuf := make([]byte, 5) + + // Gzip chunked + req = httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) + rec = httptest.NewRecorder() + + c = e.NewContext(req, rec) + New()(func(c echo.Context) error { + c.Response().Header().Set("Content-Type", "text/event-stream") + c.Response().Header().Set("Transfer-Encoding", "chunked") + + // Write and flush the first part of the data + c.Response().Write([]byte("test\n")) + c.Response().Flush() + + // Read the first part of the data + assert.True(rec.Flushed) + assert.Equal(gzipScheme, rec.Header().Get(echo.HeaderContentEncoding)) + r.Reset(rec.Body) + + _, err = io.ReadFull(r, chunkBuf) + assert.NoError(err) + assert.Equal("test\n", string(chunkBuf)) + + // Write and flush the second part of the data + c.Response().Write([]byte("test\n")) + c.Response().Flush() + + _, err = io.ReadFull(r, chunkBuf) + assert.NoError(err) + assert.Equal("test\n", string(chunkBuf)) + + // Write the final part of the data and return + c.Response().Write([]byte("test")) + return nil + })(c) + + buf := new(bytes.Buffer) + defer r.Close() + buf.ReadFrom(r) + assert.Equal("test", buf.String()) +} + +func TestGzipWithMinLength(t *testing.T) { + e := echo.New() + // Invalid level + e.Use(NewWithConfig(Config{MinLength: 5})) + e.GET("/", func(c echo.Context) error { + c.Response().Write([]byte("test")) + return nil + }) + e.GET("/foobar", func(c echo.Context) error { + c.Response().Write([]byte("foobar")) + return nil + }) + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, "", rec.Header().Get(echo.HeaderContentEncoding)) + assert.Contains(t, rec.Body.String(), "test") + + req = httptest.NewRequest(http.MethodGet, "/foobar", nil) + req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) + rec = httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding)) + r, err := gzip.NewReader(rec.Body) + if assert.NoError(t, err) { + buf := new(bytes.Buffer) + defer r.Close() + buf.ReadFrom(r) + assert.Equal(t, "foobar", buf.String()) + } +} + +func TestGzipNoContent(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + h := New()(func(c echo.Context) error { + return c.NoContent(http.StatusNoContent) + }) + if assert.NoError(t, h(c)) { + assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding)) + assert.Empty(t, rec.Header().Get(echo.HeaderContentType)) + assert.Equal(t, 0, len(rec.Body.Bytes())) + } +} + +func TestGzipEmpty(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + h := New()(func(c echo.Context) error { + return c.String(http.StatusOK, "") + }) + if assert.NoError(t, h(c)) { + assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding)) + assert.Equal(t, "text/plain; charset=UTF-8", rec.Header().Get(echo.HeaderContentType)) + r, err := gzip.NewReader(rec.Body) + if assert.NoError(t, err) { + var buf bytes.Buffer + buf.ReadFrom(r) + assert.Equal(t, "", buf.String()) + } + } +} + +func TestGzipErrorReturned(t *testing.T) { + e := echo.New() + e.Use(New()) + e.GET("/", func(c echo.Context) error { + return echo.ErrNotFound + }) + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, http.StatusNotFound, rec.Code) + assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding)) +} + +func TestGzipErrorReturnedInvalidConfig(t *testing.T) { + e := echo.New() + // Invalid level + e.Use(NewWithConfig(Config{Level: 12})) + e.GET("/", func(c echo.Context) error { + c.Response().Write([]byte("test")) + return nil + }) + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, http.StatusInternalServerError, rec.Code) + assert.Contains(t, rec.Body.String(), "gzip") +} + +// Issue #806 +func TestGzipWithStatic(t *testing.T) { + e := echo.New() + e.Use(New()) + e.Static("/test", "./") + req := httptest.NewRequest(http.MethodGet, "/test/gzip.go", nil) + req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code) + // Data is written out in chunks when Content-Length == "", so only + // validate the content length if it's not set. + if cl := rec.Header().Get("Content-Length"); cl != "" { + assert.Equal(t, cl, rec.Body.Len()) + } + r, err := gzip.NewReader(rec.Body) + if assert.NoError(t, err) { + defer r.Close() + want, err := os.ReadFile("./gzip.go") + if assert.NoError(t, err) { + buf := new(bytes.Buffer) + buf.ReadFrom(r) + assert.Equal(t, want, buf.Bytes()) + } + } +} + +func BenchmarkGzip(b *testing.B) { + e := echo.New() + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) + + h := New()(func(c echo.Context) error { + c.Response().Write([]byte("test")) // For Content-Type sniffing + return nil + }) + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + // Gzip + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + h(c) + } +} diff --git a/http/server.go b/http/server.go index 00ce8896..0c84a404 100644 --- a/http/server.go +++ b/http/server.go @@ -409,9 +409,9 @@ func (s *server) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (s *server) setRoutes() { gzipMiddleware := mwgzip.NewWithConfig(mwgzip.Config{ - Level: mwgzip.BestSpeed, - MinLength: 1000, - ContentTypes: []string{""}, + Level: mwgzip.BestSpeed, + MinLength: 1000, + Skipper: mwgzip.ContentTypeSkipper(nil), }) // API router grouo @@ -444,9 +444,9 @@ func (s *server) setRoutes() { DefaultContentType: "text/html", })) fs.Use(mwgzip.NewWithConfig(mwgzip.Config{ - Level: mwgzip.BestSpeed, - MinLength: 1000, - ContentTypes: s.gzip.mimetypes, + Level: mwgzip.BestSpeed, + MinLength: 1000, + Skipper: mwgzip.ContentTypeSkipper(s.gzip.mimetypes), })) if s.middleware.cache != nil { fs.Use(s.middleware.cache) @@ -467,9 +467,9 @@ func (s *server) setRoutes() { DefaultContentType: "application/data", })) memfs.Use(mwgzip.NewWithConfig(mwgzip.Config{ - Level: mwgzip.BestSpeed, - MinLength: 1000, - ContentTypes: s.gzip.mimetypes, + Level: mwgzip.BestSpeed, + MinLength: 1000, + Skipper: mwgzip.ContentTypeSkipper(s.gzip.mimetypes), })) if s.middleware.session != nil { memfs.Use(s.middleware.session)