diff --git a/http/middleware/cache/cache.go b/http/middleware/cache/cache.go index 25344e19..a32e35e7 100644 --- a/http/middleware/cache/cache.go +++ b/http/middleware/cache/cache.go @@ -57,31 +57,18 @@ func NewWithConfig(config Config) echo.MiddlewareFunc { if req.Method != "GET" { res.Header().Set("X-Cache", "SKIP ONLYGET") - - if err := next(c); err != nil { - c.Error(err) - } - - return nil + return next(c) } - res.Header().Set("Cache-Control", fmt.Sprintf("max-age=%.0f", config.Cache.TTL().Seconds())) - key := strings.TrimPrefix(req.URL.Path, config.Prefix) if !config.Cache.IsExtensionCacheable(path.Ext(req.URL.Path)) { res.Header().Set("X-Cache", "SKIP EXT") - - if err := next(c); err != nil { - c.Error(err) - } - - return nil + return next(c) } if obj, expireIn, _ := config.Cache.Get(key); obj == nil { // cache miss - writer := res.Writer w := &cacheWriter{ @@ -105,6 +92,7 @@ func NewWithConfig(config Config) echo.MiddlewareFunc { if res.Status != 200 { res.Header().Set("X-Cache", "SKIP NOTOK") + res.Writer.WriteHeader(res.Status) return nil } @@ -112,6 +100,7 @@ func NewWithConfig(config Config) echo.MiddlewareFunc { if !config.Cache.IsSizeCacheable(size) { res.Header().Set("X-Cache", "SKIP TOOBIG") + res.Writer.WriteHeader(res.Status) return nil } @@ -123,11 +112,13 @@ func NewWithConfig(config Config) echo.MiddlewareFunc { if err := config.Cache.Put(key, o, size); err != nil { res.Header().Set("X-Cache", "SKIP TOOBIG") + res.Writer.WriteHeader(res.Status) return nil } res.Header().Set("Cache-Control", fmt.Sprintf("max-age=%.0f", expireIn.Seconds())) res.Header().Set("X-Cache", "MISS") + res.Writer.WriteHeader(res.Status) } else { // cache hit o := obj.(*cacheObject) @@ -190,7 +181,5 @@ func (w *cacheWriter) WriteHeader(code int) { } func (w *cacheWriter) Write(body []byte) (int, error) { - n, err := w.body.Write(body) - - return n, err + return w.body.Write(body) } diff --git a/http/middleware/cache/cache_test.go b/http/middleware/cache/cache_test.go new file mode 100644 index 00000000..7748a970 --- /dev/null +++ b/http/middleware/cache/cache_test.go @@ -0,0 +1,100 @@ +package cache + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/datarhei/core/v16/http/cache" + + "github.com/labstack/echo/v4" + "github.com/stretchr/testify/require" +) + +func TestCache(t *testing.T) { + c, err := cache.NewLRUCache(cache.LRUConfig{ + TTL: 300 * time.Second, + MaxSize: 0, + MaxFileSize: 16, + AllowExtensions: []string{".js"}, + BlockExtensions: []string{".ts"}, + Logger: nil, + }) + + require.NoError(t, err) + + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/found.js", nil) + rec := httptest.NewRecorder() + ctx := e.NewContext(req, rec) + + handler := NewWithConfig(Config{ + Cache: c, + })(func(c echo.Context) error { + if c.Request().URL.Path == "/found.js" { + c.Response().Write([]byte("test")) + } else if c.Request().URL.Path == "/toobig.js" { + c.Response().Write([]byte("testtesttesttesttest")) + } else if c.Request().URL.Path == "/blocked.ts" { + c.Response().Write([]byte("blocked")) + } + + c.Response().WriteHeader(http.StatusNotFound) + return nil + }) + + handler(ctx) + + require.Equal(t, "test", rec.Body.String()) + require.Equal(t, 200, rec.Result().StatusCode) + require.Equal(t, "MISS", rec.Result().Header.Get("x-cache")) + + rec = httptest.NewRecorder() + ctx = e.NewContext(req, rec) + + handler(ctx) + + require.Equal(t, "test", rec.Body.String()) + require.Equal(t, 200, rec.Result().StatusCode) + require.Equal(t, "HIT", rec.Result().Header.Get("x-cache")[:3]) + + req = httptest.NewRequest(http.MethodGet, "/notfound.js", nil) + rec = httptest.NewRecorder() + ctx = e.NewContext(req, rec) + + handler(ctx) + + require.Equal(t, 404, rec.Result().StatusCode) + require.Equal(t, "SKIP NOTOK", rec.Result().Header.Get("x-cache")) + + req = httptest.NewRequest(http.MethodGet, "/toobig.js", nil) + rec = httptest.NewRecorder() + ctx = e.NewContext(req, rec) + + handler(ctx) + + require.Equal(t, "testtesttesttesttest", rec.Body.String()) + require.Equal(t, 200, rec.Result().StatusCode) + require.Equal(t, "SKIP TOOBIG", rec.Result().Header.Get("x-cache")) + + req = httptest.NewRequest(http.MethodGet, "/blocked.ts", nil) + rec = httptest.NewRecorder() + ctx = e.NewContext(req, rec) + + handler(ctx) + + require.Equal(t, "blocked", rec.Body.String()) + require.Equal(t, 200, rec.Result().StatusCode) + require.Equal(t, "SKIP EXT", rec.Result().Header.Get("x-cache")) + + req = httptest.NewRequest(http.MethodPost, "/found.js", nil) + rec = httptest.NewRecorder() + ctx = e.NewContext(req, rec) + + handler(ctx) + + require.Equal(t, "test", rec.Body.String()) + require.Equal(t, 200, rec.Result().StatusCode) + require.Equal(t, "SKIP ONLYGET", rec.Result().Header.Get("x-cache")) +}