mirror of
https://github.com/photoprism/photoprism.git
synced 2025-09-26 12:51:31 +08:00
Backend: Add security-focused tests, harden WebDAV and use safe.Download
Signed-off-by: Michael Mayer <michael@photoprism.app>
This commit is contained in:
16
AGENTS.md
16
AGENTS.md
@@ -22,6 +22,9 @@ Learn more: https://agents.md/
|
||||
- Whenever the Change Management instructions for a document require it, publish changes as a new file with an incremented version suffix (e.g., `*-v3.md`) rather than overwriting the original file.
|
||||
- Older spec versions remain in the repo for historical reference but are not linked from the main TOC. Do not base new work on superseded files (e.g., `*-v1.md` when `*-v2.md` exists).
|
||||
|
||||
Note on specs repository availability
|
||||
- The `specs/` repository may be private and is not guaranteed to be present in every clone or environment. Do not add Makefile targets in the main project that depend on `specs/` paths. When `specs/` is available, run its tools directly (e.g., `bash specs/scripts/lint-status.sh`).
|
||||
|
||||
## Project Structure & Languages
|
||||
|
||||
- Backend: Go (`internal/`, `pkg/`, `cmd/`) + MariaDB/SQLite
|
||||
@@ -191,6 +194,19 @@ Note: Across our public documentation, official images, and in production, the c
|
||||
- Examples assume a Linux/Unix shell. For Windows specifics, see the Developer Guide FAQ:
|
||||
https://docs.photoprism.app/developer-guide/faq/#can-your-development-environment-be-used-under-windows
|
||||
|
||||
### HTTP Download — Security Checklist
|
||||
|
||||
- Use the shared safe HTTP helper instead of ad‑hoc `net/http` code:
|
||||
- Package: `pkg/service/http/safe` → `safe.Download(destPath, url, *safe.Options)`.
|
||||
- Default policy in this repo: allow only `http/https`, enforce timeouts and max size, write to a `0600` temp file then rename.
|
||||
- SSRF protection (mandatory unless explicitly needed for tests):
|
||||
- Set `AllowPrivate=false` to block private/loopback/multicast/link‑local ranges.
|
||||
- All redirect targets are validated; the final connected peer IP is also checked.
|
||||
- Prefer an image‑focused `Accept` header for image downloads: `"image/jpeg, image/png, */*;q=0.1"`.
|
||||
- Avatars and small images: use the thin wrapper in `internal/thumb/avatar.SafeDownload` which applies stricter defaults (15s timeout, 10 MiB, `AllowPrivate=false`).
|
||||
- Tests using `httptest.Server` on 127.0.0.1 must pass `AllowPrivate=true` explicitly to succeed.
|
||||
- Keep per‑resource size budgets small; rely on `io.LimitReader` + `Content-Length` prechecks.
|
||||
|
||||
If anything in this file conflicts with the `Makefile` or the Developer Guide, the `Makefile` and the documentation win. When unsure, **ask** for clarification before proceeding.
|
||||
|
||||
## Agent Quick Tips (Do This)
|
||||
|
@@ -153,6 +153,12 @@ Security & Hot Spots (Where to Look)
|
||||
- Pipeline: `internal/thumb/vips.go` (VipsInit, VipsRotate, export params).
|
||||
- Sizes & names: `internal/thumb/sizes.go`, `internal/thumb/names.go`, `internal/thumb/filter.go`.
|
||||
|
||||
- Safe HTTP downloader:
|
||||
- Shared utility: `pkg/service/http/safe` (`Download`, `Options`).
|
||||
- Protections: scheme allow‑list (http/https), pre‑DNS + per‑redirect hostname/IP validation, final peer IP check, size and timeout enforcement, temp file `0600` + rename.
|
||||
- Avatars: wrapper `internal/thumb/avatar.SafeDownload` applies stricter defaults (15s, 10 MiB, `AllowPrivate=false`, image‑focused `Accept`).
|
||||
- Tests: `go test ./pkg/service/http/safe -count=1` (includes redirect SSRF cases); avatars: `go test ./internal/thumb/avatar -count=1`.
|
||||
|
||||
Performance & Limits
|
||||
- Prefer existing caches/workers/batching as per Makefile and code.
|
||||
- When adding list endpoints, default `count=100` (max `1000`); set `Cache-Control: no-store` for secrets.
|
||||
|
2
Makefile
2
Makefile
@@ -115,6 +115,8 @@ swag: swag-json
|
||||
swag-json:
|
||||
@echo "Generating ./internal/api/swagger.json..."
|
||||
swag init --ot json --parseDependency --parseDepth 1 --dir internal/api -g api.go -o ./internal/api
|
||||
@echo "Fixing unstable time.Duration enums in swagger.json..."
|
||||
@GO111MODULE=on go run scripts/tools/swaggerfix/main.go internal/api/swagger.json || { echo "swaggerfix failed"; exit 1; }
|
||||
swag-yaml:
|
||||
@echo "Generating ./internal/api/swagger.yaml..."
|
||||
swag init --ot yaml --parseDependency --parseDepth 1 --dir internal/api -g api.go -o ./internal/api
|
||||
|
@@ -16,7 +16,13 @@ func UpdateClientConfig() {
|
||||
|
||||
// GetClientConfig returns the client configuration values as JSON.
|
||||
//
|
||||
// GET /api/v1/config
|
||||
// @Summary get client configuration
|
||||
// @Id GetClientConfig
|
||||
// @Tags Config
|
||||
// @Produce json
|
||||
// @Success 200 {object} gin.H
|
||||
// @Failure 401 {object} i18n.Response
|
||||
// @Router /api/v1/config [get]
|
||||
func GetClientConfig(router *gin.RouterGroup) {
|
||||
router.GET("/config", func(c *gin.Context) {
|
||||
sess := Session(ClientIP(c), AuthToken(c))
|
||||
|
@@ -176,7 +176,7 @@ func BatchPhotosRestore(router *gin.RouterGroup) {
|
||||
// @Param photos body form.Selection true "Photo Selection"
|
||||
// @Router /api/v1/batch/photos/approve [post]
|
||||
func BatchPhotosApprove(router *gin.RouterGroup) {
|
||||
router.POST("batch/photos/approve", func(c *gin.Context) {
|
||||
router.POST("/batch/photos/approve", func(c *gin.Context) {
|
||||
s := Auth(c, acl.ResourcePhotos, acl.ActionUpdate)
|
||||
|
||||
if s.Abort(c) {
|
||||
|
@@ -15,7 +15,16 @@ import (
|
||||
|
||||
// Connect confirms external service accounts using a token.
|
||||
//
|
||||
// PUT /api/v1/connect/:name
|
||||
// @Summary confirm external service accounts using a token
|
||||
// @Id ConnectService
|
||||
// @Tags Config
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param name path string true "service name (e.g., hub)"
|
||||
// @Param connect body form.Connect true "connection token"
|
||||
// @Success 200 {object} gin.H
|
||||
// @Failure 400,401,403 {object} i18n.Response
|
||||
// @Router /api/v1/connect/{name} [put]
|
||||
func Connect(router *gin.RouterGroup) {
|
||||
router.PUT("/connect/:name", func(c *gin.Context) {
|
||||
name := clean.ID(c.Param("name"))
|
||||
|
13
internal/api/doc_overrides.go
Normal file
13
internal/api/doc_overrides.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package api
|
||||
|
||||
import "time"
|
||||
|
||||
// Schema Overrides for Swagger generation.
|
||||
|
||||
// Override the generated schema for time.Duration to avoid unstable enums
|
||||
// from the standard library constants (Nanosecond, Minute, etc.). Using
|
||||
// a simple integer schema is accurate (nanoseconds) and deterministic.
|
||||
//
|
||||
// @name time.Duration
|
||||
// @description Duration in nanoseconds (int64). Examples: 1000000000 (1s), 60000000000 (1m).
|
||||
type SwaggerTimeDuration = time.Duration
|
@@ -28,7 +28,13 @@ type FoldersResponse struct {
|
||||
|
||||
// SearchFoldersOriginals returns folders in originals as JSON.
|
||||
//
|
||||
// GET /api/v1/folders/originals
|
||||
// @Summary list folders in originals
|
||||
// @Id SearchFoldersOriginals
|
||||
// @Tags Folders
|
||||
// @Produce json
|
||||
// @Success 200 {object} api.FoldersResponse
|
||||
// @Failure 401,403 {object} i18n.Response
|
||||
// @Router /api/v1/folders/originals [get]
|
||||
func SearchFoldersOriginals(router *gin.RouterGroup) {
|
||||
conf := get.Config()
|
||||
SearchFolders(router, "originals", entity.RootOriginals, conf.OriginalsPath())
|
||||
@@ -36,7 +42,13 @@ func SearchFoldersOriginals(router *gin.RouterGroup) {
|
||||
|
||||
// SearchFoldersImport returns import folders as JSON.
|
||||
//
|
||||
// GET /api/v1/folders/import
|
||||
// @Summary list folders in import
|
||||
// @Id SearchFoldersImport
|
||||
// @Tags Folders
|
||||
// @Produce json
|
||||
// @Success 200 {object} api.FoldersResponse
|
||||
// @Failure 401,403 {object} i18n.Response
|
||||
// @Router /api/v1/folders/import [get]
|
||||
func SearchFoldersImport(router *gin.RouterGroup) {
|
||||
conf := get.Config()
|
||||
SearchFolders(router, "import", entity.RootImport, conf.ImportPath())
|
||||
|
@@ -87,10 +87,10 @@ func DeleteLink(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, link)
|
||||
}
|
||||
|
||||
// CreateLink adds a new share link and return it as JSON.
|
||||
//
|
||||
// @Tags Links
|
||||
// @Router /api/v1/{entity}/{uid}/links [post]
|
||||
// CreateLink adds a new share link and returns it as JSON.
|
||||
// Note: Internal helper used by resource-specific endpoints (e.g., albums, photos).
|
||||
// Swagger annotations are defined on those public handlers to avoid generating
|
||||
// undocumented generic paths like "/api/v1/{entity}/{uid}/links".
|
||||
func CreateLink(c *gin.Context) {
|
||||
s := Auth(c, acl.ResourceShares, acl.ActionCreate)
|
||||
|
||||
|
153
internal/api/oauth_token_ratelimit_test.go
Normal file
153
internal/api/oauth_token_ratelimit_test.go
Normal file
@@ -0,0 +1,153 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tidwall/gjson"
|
||||
"golang.org/x/time/rate"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/config"
|
||||
"github.com/photoprism/photoprism/internal/server/limiter"
|
||||
"github.com/photoprism/photoprism/pkg/authn"
|
||||
"github.com/photoprism/photoprism/pkg/service/http/header"
|
||||
)
|
||||
|
||||
func TestOAuthToken_RateLimit_ClientCredentials(t *testing.T) {
|
||||
app, router, conf := NewApiTest()
|
||||
conf.SetAuthMode(config.AuthModePasswd)
|
||||
defer conf.SetAuthMode(config.AuthModePublic)
|
||||
OAuthToken(router)
|
||||
|
||||
// Tighten rate limits
|
||||
oldLogin, oldAuth := limiter.Login, limiter.Auth
|
||||
defer func() { limiter.Login, limiter.Auth = oldLogin, oldAuth }()
|
||||
limiter.Login = limiter.NewLimit(rate.Every(24*time.Hour), 3) // burst 3
|
||||
limiter.Auth = limiter.NewLimit(rate.Every(24*time.Hour), 3)
|
||||
|
||||
// Invalid client secret repeatedly (from UnknownIP: no headers set)
|
||||
path := "/api/v1/oauth/token"
|
||||
for i := 0; i < 3; i++ {
|
||||
data := url.Values{
|
||||
"grant_type": {authn.GrantClientCredentials.String()},
|
||||
"client_id": {"cs5cpu17n6gj2qo5"},
|
||||
"client_secret": {"INVALID"},
|
||||
"scope": {"metrics"},
|
||||
}
|
||||
req, _ := http.NewRequest(http.MethodPost, path, strings.NewReader(data.Encode()))
|
||||
req.Header.Set(header.ContentType, header.ContentTypeForm)
|
||||
w := httptest.NewRecorder()
|
||||
app.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
||||
}
|
||||
// Next call should be rate limited
|
||||
data := url.Values{
|
||||
"grant_type": {authn.GrantClientCredentials.String()},
|
||||
"client_id": {"cs5cpu17n6gj2qo5"},
|
||||
"client_secret": {"INVALID"},
|
||||
"scope": {"metrics"},
|
||||
}
|
||||
req, _ := http.NewRequest(http.MethodPost, path, strings.NewReader(data.Encode()))
|
||||
req.Header.Set(header.ContentType, header.ContentTypeForm)
|
||||
req.Header.Set("X-Forwarded-For", "198.51.100.99")
|
||||
w := httptest.NewRecorder()
|
||||
app.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusTooManyRequests, w.Code)
|
||||
}
|
||||
|
||||
func TestOAuthToken_ResponseFields_ClientSuccess(t *testing.T) {
|
||||
app, router, conf := NewApiTest()
|
||||
conf.SetAuthMode(config.AuthModePasswd)
|
||||
defer conf.SetAuthMode(config.AuthModePublic)
|
||||
OAuthToken(router)
|
||||
|
||||
data := url.Values{
|
||||
"grant_type": {authn.GrantClientCredentials.String()},
|
||||
"client_id": {"cs5cpu17n6gj2qo5"},
|
||||
"client_secret": {"xcCbOrw6I0vcoXzhnOmXhjpVSyFq0l0e"},
|
||||
"scope": {"metrics"},
|
||||
}
|
||||
req, _ := http.NewRequest(http.MethodPost, "/api/v1/oauth/token", strings.NewReader(data.Encode()))
|
||||
req.Header.Set(header.ContentType, header.ContentTypeForm)
|
||||
w := httptest.NewRecorder()
|
||||
app.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
body := w.Body.String()
|
||||
assert.NotEmpty(t, gjson.Get(body, "access_token").String())
|
||||
tokType := gjson.Get(body, "token_type").String()
|
||||
assert.True(t, strings.EqualFold(tokType, "bearer"))
|
||||
assert.GreaterOrEqual(t, gjson.Get(body, "expires_in").Int(), int64(0))
|
||||
assert.Equal(t, "metrics", gjson.Get(body, "scope").String())
|
||||
}
|
||||
|
||||
func TestOAuthToken_ResponseFields_UserSuccess(t *testing.T) {
|
||||
app, router, conf := NewApiTest()
|
||||
conf.SetAuthMode(config.AuthModePasswd)
|
||||
defer conf.SetAuthMode(config.AuthModePublic)
|
||||
sess := AuthenticateUser(app, router, "alice", "Alice123!")
|
||||
OAuthToken(router)
|
||||
|
||||
data := url.Values{
|
||||
"grant_type": {authn.GrantPassword.String()},
|
||||
"client_name": {"TestApp"},
|
||||
"username": {"alice"},
|
||||
"password": {"Alice123!"},
|
||||
"scope": {"*"},
|
||||
}
|
||||
req, _ := http.NewRequest(http.MethodPost, "/api/v1/oauth/token", strings.NewReader(data.Encode()))
|
||||
req.Header.Set(header.ContentType, header.ContentTypeForm)
|
||||
req.Header.Set(header.XAuthToken, sess)
|
||||
w := httptest.NewRecorder()
|
||||
app.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
body := w.Body.String()
|
||||
assert.NotEmpty(t, gjson.Get(body, "access_token").String())
|
||||
tokType := gjson.Get(body, "token_type").String()
|
||||
assert.True(t, strings.EqualFold(tokType, "bearer"))
|
||||
assert.GreaterOrEqual(t, gjson.Get(body, "expires_in").Int(), int64(0))
|
||||
assert.Equal(t, "*", gjson.Get(body, "scope").String())
|
||||
}
|
||||
|
||||
func TestOAuthToken_BadRequestsAndErrors(t *testing.T) {
|
||||
app, router, conf := NewApiTest()
|
||||
conf.SetAuthMode(config.AuthModePasswd)
|
||||
defer conf.SetAuthMode(config.AuthModePublic)
|
||||
OAuthToken(router)
|
||||
|
||||
// Missing grant_type & creds -> invalid credentials
|
||||
req, _ := http.NewRequest(http.MethodPost, "/api/v1/oauth/token", nil)
|
||||
w := httptest.NewRecorder()
|
||||
app.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||
|
||||
// Unknown grant type
|
||||
data := url.Values{
|
||||
"grant_type": {"unknown"},
|
||||
}
|
||||
req, _ = http.NewRequest(http.MethodPost, "/api/v1/oauth/token", strings.NewReader(data.Encode()))
|
||||
req.Header.Set(header.ContentType, header.ContentTypeForm)
|
||||
w = httptest.NewRecorder()
|
||||
app.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
||||
|
||||
// Password grant with wrong password
|
||||
sess := AuthenticateUser(app, router, "alice", "Alice123!")
|
||||
data = url.Values{
|
||||
"grant_type": {authn.GrantPassword.String()},
|
||||
"client_name": {"AppPasswordAlice"},
|
||||
"username": {"alice"},
|
||||
"password": {"WrongPassword!"},
|
||||
"scope": {"*"},
|
||||
}
|
||||
req, _ = http.NewRequest(http.MethodPost, "/api/v1/oauth/token", strings.NewReader(data.Encode()))
|
||||
req.Header.Set(header.ContentType, header.ContentTypeForm)
|
||||
req.Header.Set(header.XAuthToken, sess)
|
||||
w = httptest.NewRecorder()
|
||||
app.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
||||
}
|
44
internal/api/session_ratelimit_test.go
Normal file
44
internal/api/session_ratelimit_test.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/time/rate"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/config"
|
||||
"github.com/photoprism/photoprism/internal/server/limiter"
|
||||
)
|
||||
|
||||
func TestCreateSession_RateLimitExceeded(t *testing.T) {
|
||||
app, router, conf := NewApiTest()
|
||||
conf.SetAuthMode(config.AuthModePasswd)
|
||||
defer conf.SetAuthMode(config.AuthModePublic)
|
||||
CreateSession(router)
|
||||
|
||||
// Tighten rate limits and do repeated bad logins from UnknownIP
|
||||
oldLogin, oldAuth := limiter.Login, limiter.Auth
|
||||
defer func() { limiter.Login, limiter.Auth = oldLogin, oldAuth }()
|
||||
limiter.Login = limiter.NewLimit(rate.Every(24*time.Hour), 3)
|
||||
limiter.Auth = limiter.NewLimit(rate.Every(24*time.Hour), 3)
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
r := PerformRequestWithBody(app, http.MethodPost, "/api/v1/session", `{"username": "admin", "password": "wrong"}`)
|
||||
assert.Equal(t, http.StatusUnauthorized, r.Code)
|
||||
}
|
||||
// Next attempt should be 429
|
||||
r := PerformRequestWithBody(app, http.MethodPost, "/api/v1/session", `{"username": "admin", "password": "wrong"}`)
|
||||
assert.Equal(t, http.StatusTooManyRequests, r.Code)
|
||||
}
|
||||
|
||||
func TestCreateSession_MissingFields(t *testing.T) {
|
||||
app, router, conf := NewApiTest()
|
||||
conf.SetAuthMode(config.AuthModePasswd)
|
||||
defer conf.SetAuthMode(config.AuthModePublic)
|
||||
CreateSession(router)
|
||||
// Empty object -> unauthorized (invalid credentials)
|
||||
r := PerformRequestWithBody(app, http.MethodPost, "/api/v1/session", `{}`)
|
||||
assert.Equal(t, http.StatusUnauthorized, r.Code)
|
||||
}
|
File diff suppressed because it is too large
Load Diff
@@ -37,7 +37,7 @@ import (
|
||||
// @Param files formData file true "one or more files to upload (repeat the field for multiple files)"
|
||||
// @Success 200 {object} i18n.Response
|
||||
// @Failure 400,401,403,413,429,507 {object} i18n.Response
|
||||
// @Router /users/{uid}/upload/{token} [post]
|
||||
// @Router /api/v1/users/{uid}/upload/{token} [post]
|
||||
func UploadUserFiles(router *gin.RouterGroup) {
|
||||
router.POST("/users/:uid/upload/:token", func(c *gin.Context) {
|
||||
conf := get.Config()
|
||||
@@ -273,7 +273,7 @@ func UploadCheckFile(destName string, rejectRaw bool, totalSizeLimit int64) (rem
|
||||
// @Param options body form.UploadOptions true "processing options"
|
||||
// @Success 200 {object} i18n.Response
|
||||
// @Failure 400,401,403,404,409,429 {object} i18n.Response
|
||||
// @Router /users/{uid}/upload/{token} [put]
|
||||
// @Router /api/v1/users/{uid}/upload/{token} [put]
|
||||
func ProcessUserUpload(router *gin.RouterGroup) {
|
||||
router.PUT("/users/:uid/upload/:token", func(c *gin.Context) {
|
||||
s := AuthAny(c, acl.ResourceFiles, acl.Permissions{acl.ActionManage, acl.ActionUpload})
|
||||
|
677
internal/api/users_upload_multipart_test.go
Normal file
677
internal/api/users_upload_multipart_test.go
Normal file
@@ -0,0 +1,677 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/entity"
|
||||
"github.com/photoprism/photoprism/pkg/service/http/header"
|
||||
)
|
||||
|
||||
// buildMultipart builds a multipart form with one field name "files" and provided files.
|
||||
func buildMultipart(files map[string][]byte) (body *bytes.Buffer, contentType string, err error) {
|
||||
body = &bytes.Buffer{}
|
||||
mw := multipart.NewWriter(body)
|
||||
for name, data := range files {
|
||||
fw, cerr := mw.CreateFormFile("files", name)
|
||||
if cerr != nil {
|
||||
return nil, "", cerr
|
||||
}
|
||||
if _, werr := fw.Write(data); werr != nil {
|
||||
return nil, "", werr
|
||||
}
|
||||
}
|
||||
cerr := mw.Close()
|
||||
return body, mw.FormDataContentType(), cerr
|
||||
}
|
||||
|
||||
// buildMultipartTwo builds a multipart form with exactly two files (same field name: "files").
|
||||
func buildMultipartTwo(name1 string, data1 []byte, name2 string, data2 []byte) (body *bytes.Buffer, contentType string, err error) {
|
||||
body = &bytes.Buffer{}
|
||||
mw := multipart.NewWriter(body)
|
||||
for _, it := range [][2]interface{}{{name1, data1}, {name2, data2}} {
|
||||
fw, cerr := mw.CreateFormFile("files", it[0].(string))
|
||||
if cerr != nil {
|
||||
return nil, "", cerr
|
||||
}
|
||||
if _, werr := fw.Write(it[1].([]byte)); werr != nil {
|
||||
return nil, "", werr
|
||||
}
|
||||
}
|
||||
cerr := mw.Close()
|
||||
return body, mw.FormDataContentType(), cerr
|
||||
}
|
||||
|
||||
// buildZipWithDirsAndFiles creates a zip archive bytes with explicit directory entries and files.
|
||||
func buildZipWithDirsAndFiles(dirs []string, files map[string][]byte) []byte {
|
||||
var zbuf bytes.Buffer
|
||||
zw := zip.NewWriter(&zbuf)
|
||||
// Directories (ensure trailing slash)
|
||||
for _, d := range dirs {
|
||||
name := d
|
||||
if !strings.HasSuffix(name, "/") {
|
||||
name += "/"
|
||||
}
|
||||
_, _ = zw.Create(name)
|
||||
}
|
||||
// Files
|
||||
for name, data := range files {
|
||||
f, _ := zw.Create(name)
|
||||
_, _ = f.Write(data)
|
||||
}
|
||||
_ = zw.Close()
|
||||
return zbuf.Bytes()
|
||||
}
|
||||
|
||||
func findUploadedFiles(t *testing.T, base string) []string {
|
||||
t.Helper()
|
||||
var out []string
|
||||
_ = filepath.Walk(base, func(path string, info os.FileInfo, err error) error {
|
||||
if err == nil && !info.IsDir() {
|
||||
out = append(out, path)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return out
|
||||
}
|
||||
|
||||
// findUploadedFilesForToken lists files only under upload subfolders whose name ends with token suffix.
|
||||
func findUploadedFilesForToken(t *testing.T, base string, tokenSuffix string) []string {
|
||||
t.Helper()
|
||||
var out []string
|
||||
entries, _ := os.ReadDir(base)
|
||||
for _, e := range entries {
|
||||
if !e.IsDir() {
|
||||
continue
|
||||
}
|
||||
name := e.Name()
|
||||
if !strings.HasSuffix(name, tokenSuffix) {
|
||||
continue
|
||||
}
|
||||
dir := filepath.Join(base, name)
|
||||
_ = filepath.Walk(dir, func(p string, info os.FileInfo, err error) error {
|
||||
if err == nil && !info.IsDir() {
|
||||
out = append(out, p)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// removeUploadDirsForToken removes upload subdirectories whose name ends with tokenSuffix.
|
||||
func removeUploadDirsForToken(t *testing.T, base string, tokenSuffix string) {
|
||||
t.Helper()
|
||||
entries, _ := os.ReadDir(base)
|
||||
for _, e := range entries {
|
||||
if !e.IsDir() {
|
||||
continue
|
||||
}
|
||||
name := e.Name()
|
||||
if strings.HasSuffix(name, tokenSuffix) {
|
||||
_ = os.RemoveAll(filepath.Join(base, name))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestUploadUserFiles_Multipart_SingleJPEG(t *testing.T) {
|
||||
app, router, conf := NewApiTest()
|
||||
// Limit allowed upload extensions to ensure text files get rejected in tests
|
||||
conf.Options().UploadAllow = "jpg"
|
||||
UploadUserFiles(router)
|
||||
token := AuthenticateAdmin(app, router)
|
||||
|
||||
adminUid := entity.Admin.UserUID
|
||||
// Cleanup: remove token-specific upload dir after test
|
||||
defer removeUploadDirsForToken(t, filepath.Join(conf.UserStoragePath(adminUid), "upload"), "abc123")
|
||||
// Load a real tiny JPEG from testdata
|
||||
jpgPath := filepath.Clean("../../pkg/fs/testdata/directory/example.jpg")
|
||||
data, err := os.ReadFile(jpgPath)
|
||||
if err != nil {
|
||||
t.Skipf("missing example.jpg: %v", err)
|
||||
}
|
||||
|
||||
body, ctype, err := buildMultipart(map[string][]byte{"example.jpg": data})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
reqUrl := "/api/v1/users/" + adminUid + "/upload/abc123"
|
||||
req := httptest.NewRequest(http.MethodPost, reqUrl, body)
|
||||
req.Header.Set("Content-Type", ctype)
|
||||
header.SetAuthorization(req, token)
|
||||
w := httptest.NewRecorder()
|
||||
app.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code, w.Body.String())
|
||||
|
||||
// Verify file written somewhere under users/<uid>/upload/*
|
||||
uploadBase := filepath.Join(conf.UserStoragePath(adminUid), "upload")
|
||||
files := findUploadedFilesForToken(t, uploadBase, "abc123")
|
||||
// At least one file written
|
||||
assert.NotEmpty(t, files)
|
||||
// Expect the filename to appear somewhere
|
||||
var found bool
|
||||
for _, f := range files {
|
||||
if strings.HasSuffix(f, "example.jpg") {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found, "uploaded JPEG not found")
|
||||
}
|
||||
|
||||
func TestUploadUserFiles_Multipart_ZipExtract(t *testing.T) {
|
||||
app, router, conf := NewApiTest()
|
||||
// Allow archives and restrict allowed extensions to images
|
||||
conf.Options().UploadArchives = true
|
||||
conf.Options().UploadAllow = "jpg,png,zip"
|
||||
UploadUserFiles(router)
|
||||
token := AuthenticateAdmin(app, router)
|
||||
|
||||
adminUid := entity.Admin.UserUID
|
||||
|
||||
// Cleanup after test
|
||||
defer removeUploadDirsForToken(t, filepath.Join(conf.UserStoragePath(adminUid), "upload"), "ziptok")
|
||||
// Create an in-memory zip with one JPEG (valid) and one TXT (rejected)
|
||||
jpgPath := filepath.Clean("../../pkg/fs/testdata/directory/example.jpg")
|
||||
jpg, err := os.ReadFile(jpgPath)
|
||||
if err != nil {
|
||||
t.Skip("missing example.jpg")
|
||||
}
|
||||
|
||||
var zbuf bytes.Buffer
|
||||
zw := zip.NewWriter(&zbuf)
|
||||
// add jpeg
|
||||
jf, _ := zw.Create("a.jpg")
|
||||
_, _ = jf.Write(jpg)
|
||||
// add txt
|
||||
tf, _ := zw.Create("note.txt")
|
||||
_, _ = io.WriteString(tf, "hello")
|
||||
_ = zw.Close()
|
||||
|
||||
body, ctype, err := buildMultipart(map[string][]byte{"upload.zip": zbuf.Bytes()})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
reqUrl := "/api/v1/users/" + adminUid + "/upload/zipoff"
|
||||
req := httptest.NewRequest(http.MethodPost, reqUrl, body)
|
||||
req.Header.Set("Content-Type", ctype)
|
||||
header.SetAuthorization(req, token)
|
||||
w := httptest.NewRecorder()
|
||||
app.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code, w.Body.String())
|
||||
|
||||
uploadBase := filepath.Join(conf.UserStoragePath(adminUid), "upload")
|
||||
files := findUploadedFilesForToken(t, uploadBase, "zipoff")
|
||||
// Expect extracted jpeg present and txt absent
|
||||
var jpgFound, txtFound bool
|
||||
for _, f := range files {
|
||||
if strings.HasSuffix(f, "a.jpg") {
|
||||
jpgFound = true
|
||||
}
|
||||
if strings.HasSuffix(f, "note.txt") {
|
||||
txtFound = true
|
||||
}
|
||||
}
|
||||
assert.True(t, jpgFound, "extracted jpeg not found")
|
||||
assert.False(t, txtFound, "text file should be rejected")
|
||||
}
|
||||
|
||||
func TestUploadUserFiles_Multipart_ArchivesDisabled(t *testing.T) {
|
||||
app, router, conf := NewApiTest()
|
||||
// disallow archives while allowing the .zip extension in filter
|
||||
conf.Options().UploadArchives = false
|
||||
conf.Options().UploadAllow = "jpg,zip"
|
||||
UploadUserFiles(router)
|
||||
token := AuthenticateAdmin(app, router)
|
||||
|
||||
adminUid := entity.Admin.UserUID
|
||||
|
||||
// Cleanup after test
|
||||
defer removeUploadDirsForToken(t, filepath.Join(conf.UserStoragePath(adminUid), "upload"), "zipoff")
|
||||
// zip with one jpeg inside
|
||||
jpgPath := filepath.Clean("../../pkg/fs/testdata/directory/example.jpg")
|
||||
jpg, err := os.ReadFile(jpgPath)
|
||||
if err != nil {
|
||||
t.Skip("missing example.jpg")
|
||||
}
|
||||
var zbuf bytes.Buffer
|
||||
zw := zip.NewWriter(&zbuf)
|
||||
jf, _ := zw.Create("a.jpg")
|
||||
_, _ = jf.Write(jpg)
|
||||
_ = zw.Close()
|
||||
|
||||
body, ctype, err := buildMultipart(map[string][]byte{"upload.zip": zbuf.Bytes()})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
reqUrl := "/api/v1/users/" + adminUid + "/upload/ziptok"
|
||||
req := httptest.NewRequest(http.MethodPost, reqUrl, body)
|
||||
req.Header.Set("Content-Type", ctype)
|
||||
header.SetAuthorization(req, token)
|
||||
w := httptest.NewRecorder()
|
||||
app.ServeHTTP(w, req)
|
||||
// server returns 200 even if rejected internally; nothing extracted/saved
|
||||
assert.Equal(t, http.StatusOK, w.Code, w.Body.String())
|
||||
|
||||
uploadBase := filepath.Join(conf.UserStoragePath(adminUid), "upload")
|
||||
files := findUploadedFilesForToken(t, uploadBase, "ziptok")
|
||||
assert.Empty(t, files, "no files should remain when archives disabled")
|
||||
}
|
||||
|
||||
func TestUploadUserFiles_Multipart_PerFileLimitExceeded(t *testing.T) {
|
||||
app, router, conf := NewApiTest()
|
||||
conf.Options().UploadAllow = "jpg"
|
||||
conf.Options().OriginalsLimit = 1 // 1 MiB per-file
|
||||
UploadUserFiles(router)
|
||||
token := AuthenticateAdmin(app, router)
|
||||
|
||||
adminUid := entity.Admin.UserUID
|
||||
defer removeUploadDirsForToken(t, filepath.Join(conf.UserStoragePath(adminUid), "upload"), "size1")
|
||||
|
||||
// Build a 2MiB dummy payload (not a real JPEG; that's fine for pre-save size check)
|
||||
big := bytes.Repeat([]byte("A"), 2*1024*1024)
|
||||
body, ctype, err := buildMultipart(map[string][]byte{"big.jpg": big})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/users/"+adminUid+"/upload/size1", body)
|
||||
req.Header.Set("Content-Type", ctype)
|
||||
header.SetAuthorization(req, token)
|
||||
w := httptest.NewRecorder()
|
||||
app.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
// Ensure nothing saved
|
||||
files := findUploadedFilesForToken(t, filepath.Join(conf.UserStoragePath(adminUid), "upload"), "size1")
|
||||
assert.Empty(t, files)
|
||||
}
|
||||
|
||||
func TestUploadUserFiles_Multipart_TotalLimitExceeded(t *testing.T) {
|
||||
app, router, conf := NewApiTest()
|
||||
conf.Options().UploadAllow = "jpg"
|
||||
conf.Options().UploadLimit = 1 // 1 MiB total
|
||||
UploadUserFiles(router)
|
||||
token := AuthenticateAdmin(app, router)
|
||||
|
||||
adminUid := entity.Admin.UserUID
|
||||
defer removeUploadDirsForToken(t, filepath.Join(conf.UserStoragePath(adminUid), "upload"), "total")
|
||||
data, err := os.ReadFile(filepath.Clean("../../pkg/fs/testdata/directory/example.jpg"))
|
||||
if err != nil {
|
||||
t.Skip("missing example.jpg")
|
||||
}
|
||||
// build multipart with two images so sum > 1 MiB (2*~63KiB = ~126KiB) -> still <1MiB, so use 16 copies
|
||||
// build two bigger bodies by concatenation
|
||||
times := 9
|
||||
big1 := bytes.Repeat(data, times)
|
||||
big2 := bytes.Repeat(data, times)
|
||||
body, ctype, err := buildMultipartTwo("a.jpg", big1, "b.jpg", big2)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/users/"+adminUid+"/upload/total", body)
|
||||
req.Header.Set("Content-Type", ctype)
|
||||
header.SetAuthorization(req, token)
|
||||
w := httptest.NewRecorder()
|
||||
app.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
// Expect at most one file saved (second should be rejected by total limit)
|
||||
files := findUploadedFilesForToken(t, filepath.Join(conf.UserStoragePath(adminUid), "upload"), "total")
|
||||
assert.LessOrEqual(t, len(files), 1)
|
||||
}
|
||||
|
||||
func TestUploadUserFiles_Multipart_ZipPartialExtraction(t *testing.T) {
|
||||
app, router, conf := NewApiTest()
|
||||
conf.Options().UploadArchives = true
|
||||
conf.Options().UploadAllow = "jpg,zip"
|
||||
conf.Options().UploadLimit = 1 // 1 MiB total
|
||||
conf.Options().OriginalsLimit = 50 // 50 MiB per file
|
||||
conf.Options().UploadNSFW = true // skip nsfw scanning to speed up test
|
||||
UploadUserFiles(router)
|
||||
token := AuthenticateAdmin(app, router)
|
||||
|
||||
adminUid := entity.Admin.UserUID
|
||||
defer removeUploadDirsForToken(t, filepath.Join(conf.UserStoragePath(adminUid), "upload"), "partial")
|
||||
|
||||
// Build a zip containing multiple JPEG entries so that total extracted size > 1 MiB
|
||||
data, err := os.ReadFile(filepath.Clean("../../pkg/fs/testdata/directory/example.jpg"))
|
||||
if err != nil {
|
||||
t.Skip("missing example.jpg")
|
||||
}
|
||||
|
||||
var zbuf bytes.Buffer
|
||||
zw := zip.NewWriter(&zbuf)
|
||||
for i := 0; i < 20; i++ { // ~20 * 63 KiB ≈ 1.2 MiB
|
||||
f, _ := zw.Create(fmt.Sprintf("pic%02d.jpg", i+1))
|
||||
_, _ = f.Write(data)
|
||||
}
|
||||
_ = zw.Close()
|
||||
|
||||
body, ctype, err := buildMultipart(map[string][]byte{"multi.zip": zbuf.Bytes()})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/users/"+adminUid+"/upload/partial", body)
|
||||
req.Header.Set("Content-Type", ctype)
|
||||
header.SetAuthorization(req, token)
|
||||
w := httptest.NewRecorder()
|
||||
app.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
files := findUploadedFilesForToken(t, filepath.Join(conf.UserStoragePath(adminUid), "upload"), "partial")
|
||||
// At least one extracted, but not all 20 due to total limit
|
||||
var countJPG int
|
||||
for _, f := range files {
|
||||
if strings.HasSuffix(f, ".jpg") {
|
||||
countJPG++
|
||||
}
|
||||
}
|
||||
assert.GreaterOrEqual(t, countJPG, 1)
|
||||
assert.Less(t, countJPG, 20)
|
||||
}
|
||||
|
||||
func TestUploadUserFiles_Multipart_ZipDeepNestingStress(t *testing.T) {
|
||||
app, router, conf := NewApiTest()
|
||||
conf.Options().UploadArchives = true
|
||||
conf.Options().UploadAllow = "jpg,zip"
|
||||
conf.Options().UploadNSFW = true
|
||||
UploadUserFiles(router)
|
||||
token := AuthenticateAdmin(app, router)
|
||||
|
||||
adminUid := entity.Admin.UserUID
|
||||
defer removeUploadDirsForToken(t, filepath.Join(conf.UserStoragePath(adminUid), "upload"), "zipdeep")
|
||||
|
||||
data, err := os.ReadFile(filepath.Clean("../../pkg/fs/testdata/directory/example.jpg"))
|
||||
if err != nil {
|
||||
t.Skip("missing example.jpg")
|
||||
}
|
||||
|
||||
// Build a deeply nested path (20 levels)
|
||||
deep := ""
|
||||
for i := 0; i < 20; i++ {
|
||||
if i == 0 {
|
||||
deep = "deep"
|
||||
} else {
|
||||
deep = filepath.Join(deep, fmt.Sprintf("lvl%02d", i))
|
||||
}
|
||||
}
|
||||
name := filepath.Join(deep, "deep.jpg")
|
||||
zbytes := buildZipWithDirsAndFiles(nil, map[string][]byte{name: data})
|
||||
|
||||
body, ctype, err := buildMultipart(map[string][]byte{"deepnest.zip": zbytes})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/users/"+adminUid+"/upload/zipdeep", body)
|
||||
req.Header.Set("Content-Type", ctype)
|
||||
header.SetAuthorization(req, token)
|
||||
w := httptest.NewRecorder()
|
||||
app.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
base := filepath.Join(conf.UserStoragePath(adminUid), "upload")
|
||||
files := findUploadedFilesForToken(t, base, "zipdeep")
|
||||
// Only one file expected, deep path created
|
||||
assert.Equal(t, 1, len(files))
|
||||
assert.True(t, strings.Contains(files[0], filepath.Join("deep", "lvl01")))
|
||||
}
|
||||
|
||||
func TestUploadUserFiles_Multipart_ZipRejectsHiddenAndTraversal(t *testing.T) {
|
||||
app, router, conf := NewApiTest()
|
||||
conf.Options().UploadArchives = true
|
||||
conf.Options().UploadAllow = "jpg,zip"
|
||||
conf.Options().UploadNSFW = true // skip scanning
|
||||
UploadUserFiles(router)
|
||||
token := AuthenticateAdmin(app, router)
|
||||
|
||||
adminUid := entity.Admin.UserUID
|
||||
defer removeUploadDirsForToken(t, filepath.Join(conf.UserStoragePath(adminUid), "upload"), "rejects")
|
||||
|
||||
// Prepare a valid jpg payload
|
||||
data, err := os.ReadFile(filepath.Clean("../../pkg/fs/testdata/directory/example.jpg"))
|
||||
if err != nil {
|
||||
t.Skip("missing example.jpg")
|
||||
}
|
||||
|
||||
var zbuf bytes.Buffer
|
||||
zw := zip.NewWriter(&zbuf)
|
||||
// Hidden file
|
||||
f1, _ := zw.Create(".hidden.jpg")
|
||||
_, _ = f1.Write(data)
|
||||
// @ file
|
||||
f2, _ := zw.Create("@meta.jpg")
|
||||
_, _ = f2.Write(data)
|
||||
// Traversal path (will be skipped by safe join in unzip)
|
||||
f3, _ := zw.Create("dir/../traverse.jpg")
|
||||
_, _ = f3.Write(data)
|
||||
// Valid file
|
||||
f4, _ := zw.Create("ok.jpg")
|
||||
_, _ = f4.Write(data)
|
||||
_ = zw.Close()
|
||||
|
||||
body, ctype, err := buildMultipart(map[string][]byte{"test.zip": zbuf.Bytes()})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/users/"+adminUid+"/upload/rejects", body)
|
||||
req.Header.Set("Content-Type", ctype)
|
||||
header.SetAuthorization(req, token)
|
||||
w := httptest.NewRecorder()
|
||||
app.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
files := findUploadedFilesForToken(t, filepath.Join(conf.UserStoragePath(adminUid), "upload"), "rejects")
|
||||
var hasOk, hasHidden, hasAt, hasTraverse bool
|
||||
for _, f := range files {
|
||||
if strings.HasSuffix(f, "ok.jpg") {
|
||||
hasOk = true
|
||||
}
|
||||
if strings.HasSuffix(f, ".hidden.jpg") {
|
||||
hasHidden = true
|
||||
}
|
||||
if strings.HasSuffix(f, "@meta.jpg") {
|
||||
hasAt = true
|
||||
}
|
||||
if strings.HasSuffix(f, "traverse.jpg") {
|
||||
hasTraverse = true
|
||||
}
|
||||
}
|
||||
assert.True(t, hasOk)
|
||||
assert.False(t, hasHidden)
|
||||
assert.False(t, hasAt)
|
||||
assert.False(t, hasTraverse)
|
||||
}
|
||||
|
||||
func TestUploadUserFiles_Multipart_ZipNestedDirectories(t *testing.T) {
|
||||
app, router, conf := NewApiTest()
|
||||
conf.Options().UploadArchives = true
|
||||
conf.Options().UploadAllow = "jpg,zip"
|
||||
conf.Options().UploadNSFW = true
|
||||
UploadUserFiles(router)
|
||||
token := AuthenticateAdmin(app, router)
|
||||
|
||||
adminUid := entity.Admin.UserUID
|
||||
defer removeUploadDirsForToken(t, filepath.Join(conf.UserStoragePath(adminUid), "upload"), "zipnest")
|
||||
|
||||
data, err := os.ReadFile(filepath.Clean("../../pkg/fs/testdata/directory/example.jpg"))
|
||||
if err != nil {
|
||||
t.Skip("missing example.jpg")
|
||||
}
|
||||
|
||||
// Create nested dirs and files
|
||||
dirs := []string{"nested", "nested/sub"}
|
||||
files := map[string][]byte{
|
||||
"nested/a.jpg": data,
|
||||
"nested/sub/b.jpg": data,
|
||||
}
|
||||
zbytes := buildZipWithDirsAndFiles(dirs, files)
|
||||
|
||||
body, ctype, err := buildMultipart(map[string][]byte{"nested.zip": zbytes})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/users/"+adminUid+"/upload/zipnest", body)
|
||||
req.Header.Set("Content-Type", ctype)
|
||||
header.SetAuthorization(req, token)
|
||||
w := httptest.NewRecorder()
|
||||
app.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
base := filepath.Join(conf.UserStoragePath(adminUid), "upload")
|
||||
filesOut := findUploadedFilesForToken(t, base, "zipnest")
|
||||
var haveA, haveB bool
|
||||
for _, f := range filesOut {
|
||||
if strings.HasSuffix(f, filepath.Join("nested", "a.jpg")) {
|
||||
haveA = true
|
||||
}
|
||||
if strings.HasSuffix(f, filepath.Join("nested", "sub", "b.jpg")) {
|
||||
haveB = true
|
||||
}
|
||||
}
|
||||
assert.True(t, haveA)
|
||||
assert.True(t, haveB)
|
||||
// Directories exist
|
||||
// Locate token dir
|
||||
entries, _ := os.ReadDir(base)
|
||||
var tokenDir string
|
||||
for _, e := range entries {
|
||||
if e.IsDir() && strings.HasSuffix(e.Name(), "zipnest") {
|
||||
tokenDir = filepath.Join(base, e.Name())
|
||||
break
|
||||
}
|
||||
}
|
||||
if tokenDir != "" {
|
||||
_, errA := os.Stat(filepath.Join(tokenDir, "nested"))
|
||||
_, errB := os.Stat(filepath.Join(tokenDir, "nested", "sub"))
|
||||
assert.NoError(t, errA)
|
||||
assert.NoError(t, errB)
|
||||
} else {
|
||||
t.Fatalf("token dir not found under %s", base)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUploadUserFiles_Multipart_ZipImplicitDirectories(t *testing.T) {
|
||||
app, router, conf := NewApiTest()
|
||||
conf.Options().UploadArchives = true
|
||||
conf.Options().UploadAllow = "jpg,zip"
|
||||
conf.Options().UploadNSFW = true
|
||||
UploadUserFiles(router)
|
||||
token := AuthenticateAdmin(app, router)
|
||||
|
||||
adminUid := entity.Admin.UserUID
|
||||
defer removeUploadDirsForToken(t, filepath.Join(conf.UserStoragePath(adminUid), "upload"), "zipimpl")
|
||||
|
||||
data, err := os.ReadFile(filepath.Clean("../../pkg/fs/testdata/directory/example.jpg"))
|
||||
if err != nil {
|
||||
t.Skip("missing example.jpg")
|
||||
}
|
||||
|
||||
// Create zip containing only files with nested paths (no explicit directory entries)
|
||||
var zbuf bytes.Buffer
|
||||
zw := zip.NewWriter(&zbuf)
|
||||
f1, _ := zw.Create(filepath.Join("nested", "a.jpg"))
|
||||
_, _ = f1.Write(data)
|
||||
f2, _ := zw.Create(filepath.Join("nested", "sub", "b.jpg"))
|
||||
_, _ = f2.Write(data)
|
||||
_ = zw.Close()
|
||||
|
||||
body, ctype, err := buildMultipart(map[string][]byte{"nested-files-only.zip": zbuf.Bytes()})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/users/"+adminUid+"/upload/zipimpl", body)
|
||||
req.Header.Set("Content-Type", ctype)
|
||||
header.SetAuthorization(req, token)
|
||||
w := httptest.NewRecorder()
|
||||
app.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
base := filepath.Join(conf.UserStoragePath(adminUid), "upload")
|
||||
files := findUploadedFilesForToken(t, base, "zipimpl")
|
||||
var haveA, haveB bool
|
||||
for _, f := range files {
|
||||
if strings.HasSuffix(f, filepath.Join("nested", "a.jpg")) {
|
||||
haveA = true
|
||||
}
|
||||
if strings.HasSuffix(f, filepath.Join("nested", "sub", "b.jpg")) {
|
||||
haveB = true
|
||||
}
|
||||
}
|
||||
assert.True(t, haveA)
|
||||
assert.True(t, haveB)
|
||||
// Confirm directories were implicitly created
|
||||
entries, _ := os.ReadDir(base)
|
||||
var tokenDir string
|
||||
for _, e := range entries {
|
||||
if e.IsDir() && strings.HasSuffix(e.Name(), "zipimpl") {
|
||||
tokenDir = filepath.Join(base, e.Name())
|
||||
break
|
||||
}
|
||||
}
|
||||
if tokenDir == "" {
|
||||
t.Fatalf("token dir not found under %s", base)
|
||||
}
|
||||
_, errA := os.Stat(filepath.Join(tokenDir, "nested"))
|
||||
_, errB := os.Stat(filepath.Join(tokenDir, "nested", "sub"))
|
||||
assert.NoError(t, errA)
|
||||
assert.NoError(t, errB)
|
||||
}
|
||||
|
||||
func TestUploadUserFiles_Multipart_ZipAbsolutePathRejected(t *testing.T) {
|
||||
app, router, conf := NewApiTest()
|
||||
conf.Options().UploadArchives = true
|
||||
conf.Options().UploadAllow = "jpg,zip"
|
||||
conf.Options().UploadNSFW = true
|
||||
UploadUserFiles(router)
|
||||
token := AuthenticateAdmin(app, router)
|
||||
|
||||
adminUid := entity.Admin.UserUID
|
||||
defer removeUploadDirsForToken(t, filepath.Join(conf.UserStoragePath(adminUid), "upload"), "zipabs")
|
||||
|
||||
data, err := os.ReadFile(filepath.Clean("../../pkg/fs/testdata/directory/example.jpg"))
|
||||
if err != nil {
|
||||
t.Skip("missing example.jpg")
|
||||
}
|
||||
|
||||
// Zip with an absolute path entry
|
||||
var zbuf bytes.Buffer
|
||||
zw := zip.NewWriter(&zbuf)
|
||||
f, _ := zw.Create("/abs.jpg")
|
||||
_, _ = f.Write(data)
|
||||
_ = zw.Close()
|
||||
|
||||
body, ctype, err := buildMultipart(map[string][]byte{"abs.zip": zbuf.Bytes()})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/users/"+adminUid+"/upload/zipabs", body)
|
||||
req.Header.Set("Content-Type", ctype)
|
||||
header.SetAuthorization(req, token)
|
||||
w := httptest.NewRecorder()
|
||||
app.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
// No files should be extracted/saved for this token
|
||||
base := filepath.Join(conf.UserStoragePath(adminUid), "upload")
|
||||
files := findUploadedFilesForToken(t, base, "zipabs")
|
||||
assert.Empty(t, files)
|
||||
}
|
@@ -3,6 +3,8 @@ package api
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -43,3 +45,64 @@ func TestUploadUserFiles(t *testing.T) {
|
||||
config.Options().FilesQuota = 0
|
||||
})
|
||||
}
|
||||
|
||||
func TestUploadCheckFile_AcceptsAndReducesLimit(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
// Copy a small known-good JPEG test file from pkg/fs/testdata
|
||||
src := filepath.Clean("../../pkg/fs/testdata/directory/example.jpg")
|
||||
dst := filepath.Join(dir, "example.jpg")
|
||||
b, err := os.ReadFile(src)
|
||||
if err != nil {
|
||||
t.Skipf("skip if test asset not present: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(dst, b, 0o600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
orig := int64(len(b))
|
||||
rem, err := UploadCheckFile(dst, false, orig+100)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(100), rem)
|
||||
// file remains
|
||||
assert.FileExists(t, dst)
|
||||
}
|
||||
|
||||
func TestUploadCheckFile_TotalLimitReachedDeletes(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
// Make a tiny file
|
||||
dst := filepath.Join(dir, "tiny.txt")
|
||||
assert.NoError(t, os.WriteFile(dst, []byte("hello"), 0o600))
|
||||
// Very small total limit (0) → should remove file and error
|
||||
_, err := UploadCheckFile(dst, false, 0)
|
||||
assert.Error(t, err)
|
||||
_, statErr := os.Stat(dst)
|
||||
assert.True(t, os.IsNotExist(statErr), "file should be removed when limit reached")
|
||||
}
|
||||
|
||||
func TestUploadCheckFile_UnsupportedTypeDeletes(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
// Create a file with an unknown extension; should be rejected
|
||||
dst := filepath.Join(dir, "unknown.xyz")
|
||||
assert.NoError(t, os.WriteFile(dst, []byte("not-an-image"), 0o600))
|
||||
_, err := UploadCheckFile(dst, false, 1<<20)
|
||||
assert.Error(t, err)
|
||||
_, statErr := os.Stat(dst)
|
||||
assert.True(t, os.IsNotExist(statErr), "unsupported file should be removed")
|
||||
}
|
||||
|
||||
func TestUploadCheckFile_SizeAccounting(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
// Use known-good JPEG
|
||||
src := filepath.Clean("../../pkg/fs/testdata/directory/example.jpg")
|
||||
data, err := os.ReadFile(src)
|
||||
if err != nil {
|
||||
t.Skip("asset missing; skip")
|
||||
}
|
||||
f := filepath.Join(dir, "a.jpg")
|
||||
assert.NoError(t, os.WriteFile(f, data, 0o600))
|
||||
size := int64(len(data))
|
||||
// Set remaining limit to size+1 so it does not hit the removal branch (which triggers on <=0)
|
||||
rem, err := UploadCheckFile(f, false, size+1)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(1), rem)
|
||||
}
|
||||
|
@@ -152,9 +152,24 @@ func WebDAVFileName(request *http.Request, router *gin.RouterGroup, conf *config
|
||||
// Determine the absolute file path based on the request URL and the configuration.
|
||||
switch basePath {
|
||||
case conf.BaseUri(WebDAVOriginals):
|
||||
fileName = filepath.Join(conf.OriginalsPath(), strings.TrimPrefix(request.URL.Path, basePath))
|
||||
// Resolve the requested path safely under OriginalsPath.
|
||||
rel := strings.TrimPrefix(request.URL.Path, basePath)
|
||||
// Make relative if a leading slash remains after trimming the base.
|
||||
rel = strings.TrimLeft(rel, "/\\")
|
||||
if name, err := joinUnderBase(conf.OriginalsPath(), rel); err == nil {
|
||||
fileName = name
|
||||
} else {
|
||||
return ""
|
||||
}
|
||||
case conf.BaseUri(WebDAVImport):
|
||||
fileName = filepath.Join(conf.ImportPath(), strings.TrimPrefix(request.URL.Path, basePath))
|
||||
// Resolve the requested path safely under ImportPath.
|
||||
rel := strings.TrimPrefix(request.URL.Path, basePath)
|
||||
rel = strings.TrimLeft(rel, "/\\")
|
||||
if name, err := joinUnderBase(conf.ImportPath(), rel); err == nil {
|
||||
fileName = name
|
||||
} else {
|
||||
return ""
|
||||
}
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
@@ -167,6 +182,27 @@ func WebDAVFileName(request *http.Request, router *gin.RouterGroup, conf *config
|
||||
return fileName
|
||||
}
|
||||
|
||||
// joinUnderBase joins a base directory with a relative name and ensures
|
||||
// that the resulting path stays within the base directory. Absolute
|
||||
// paths and Windows-style volume names are rejected.
|
||||
func joinUnderBase(baseDir, rel string) (string, error) {
|
||||
if rel == "" {
|
||||
return "", fmt.Errorf("invalid path")
|
||||
}
|
||||
// Reject absolute or volume paths.
|
||||
if filepath.IsAbs(rel) || filepath.VolumeName(rel) != "" {
|
||||
return "", fmt.Errorf("invalid path: absolute or volume path not allowed")
|
||||
}
|
||||
cleaned := filepath.Clean(rel)
|
||||
// Compose destination and verify it stays inside base.
|
||||
dest := filepath.Join(baseDir, cleaned)
|
||||
base := filepath.Clean(baseDir)
|
||||
if dest != base && !strings.HasPrefix(dest, base+string(os.PathSeparator)) {
|
||||
return "", fmt.Errorf("invalid path: outside base directory")
|
||||
}
|
||||
return dest, nil
|
||||
}
|
||||
|
||||
// WebDAVSetFavoriteFlag adds the favorite flag to files uploaded via WebDAV.
|
||||
func WebDAVSetFavoriteFlag(fileName string) {
|
||||
yamlName := fs.AbsPrefix(fileName, false) + fs.ExtYml
|
||||
|
37
internal/server/webdav_actions_test.go
Normal file
37
internal/server/webdav_actions_test.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestWebDAVSetFavoriteFlag_CreatesYamlOnce(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
file := filepath.Join(dir, "img.jpg")
|
||||
assert.NoError(t, os.WriteFile(file, []byte("x"), 0o600))
|
||||
// First call creates YAML
|
||||
WebDAVSetFavoriteFlag(file)
|
||||
// YAML is written next to file without the media extension (AbsPrefix)
|
||||
yml := filepath.Join(filepath.Dir(file), "img.yml")
|
||||
assert.FileExists(t, yml)
|
||||
// Write a marker and ensure second call doesn't overwrite content
|
||||
orig, _ := os.ReadFile(yml)
|
||||
WebDAVSetFavoriteFlag(file)
|
||||
now, _ := os.ReadFile(yml)
|
||||
assert.Equal(t, string(orig), string(now))
|
||||
}
|
||||
|
||||
func TestWebDAVSetFileMtime_NoFuture(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
file := filepath.Join(dir, "a.txt")
|
||||
assert.NoError(t, os.WriteFile(file, []byte("x"), 0o600))
|
||||
// Set a past mtime
|
||||
WebDAVSetFileMtime(file, 946684800) // 2000-01-01 UTC
|
||||
after, _ := os.Stat(file)
|
||||
// Compare seconds to avoid platform-specific rounding
|
||||
got := after.ModTime().Unix()
|
||||
assert.Equal(t, int64(946684800), got)
|
||||
}
|
90
internal/server/webdav_secure_test.go
Normal file
90
internal/server/webdav_secure_test.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/config"
|
||||
"github.com/photoprism/photoprism/pkg/fs"
|
||||
)
|
||||
|
||||
func TestJoinUnderBase(t *testing.T) {
|
||||
base := t.TempDir()
|
||||
// Normal join
|
||||
out, err := joinUnderBase(base, "a/b/c.txt")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, filepath.Join(base, "a/b/c.txt"), out)
|
||||
// Absolute rejected
|
||||
_, err = joinUnderBase(base, "/etc/passwd")
|
||||
assert.Error(t, err)
|
||||
// Parent traversal rejected
|
||||
_, err = joinUnderBase(base, "../../etc/passwd")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestWebDAVFileName_PathTraversalRejected(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
// Create a legitimate file inside base to ensure happy-path works later.
|
||||
insideFile := filepath.Join(dir, "ok.txt")
|
||||
assert.NoError(t, fs.WriteString(insideFile, "ok"))
|
||||
|
||||
conf := config.NewTestConfig("server-webdav")
|
||||
conf.Options().OriginalsPath = dir
|
||||
|
||||
r := gin.New()
|
||||
grp := r.Group(conf.BaseUri(WebDAVOriginals))
|
||||
|
||||
// Attempt traversal to outside path.
|
||||
req := &http.Request{Method: http.MethodPut}
|
||||
req.URL = &url.URL{Path: conf.BaseUri(WebDAVOriginals) + "/../../etc/passwd"}
|
||||
got := WebDAVFileName(req, grp, conf)
|
||||
assert.Equal(t, "", got, "should reject traversal")
|
||||
|
||||
// Happy path: file under base resolves and exists.
|
||||
req2 := &http.Request{Method: http.MethodPut}
|
||||
req2.URL = &url.URL{Path: conf.BaseUri(WebDAVOriginals) + "/ok.txt"}
|
||||
got = WebDAVFileName(req2, grp, conf)
|
||||
assert.Equal(t, insideFile, got)
|
||||
}
|
||||
|
||||
func TestWebDAVFileName_MethodNotPut(t *testing.T) {
|
||||
conf := config.NewTestConfig("server-webdav")
|
||||
r := gin.New()
|
||||
grp := r.Group(conf.BaseUri(WebDAVOriginals))
|
||||
req := &http.Request{Method: http.MethodGet}
|
||||
req.URL = &url.URL{Path: conf.BaseUri(WebDAVOriginals) + "/anything.jpg"}
|
||||
got := WebDAVFileName(req, grp, conf)
|
||||
assert.Equal(t, "", got)
|
||||
}
|
||||
|
||||
func TestWebDAVFileName_ImportBasePath(t *testing.T) {
|
||||
conf := config.NewTestConfig("server-webdav")
|
||||
r := gin.New()
|
||||
grp := r.Group(conf.BaseUri(WebDAVImport))
|
||||
// create a real file under import
|
||||
file := filepath.Join(conf.ImportPath(), "in.jpg")
|
||||
assert.NoError(t, fs.MkdirAll(filepath.Dir(file)))
|
||||
assert.NoError(t, fs.WriteString(file, "x"))
|
||||
req := &http.Request{Method: http.MethodPut}
|
||||
req.URL = &url.URL{Path: conf.BaseUri(WebDAVImport) + "/in.jpg"}
|
||||
got := WebDAVFileName(req, grp, conf)
|
||||
assert.Equal(t, file, got)
|
||||
}
|
||||
|
||||
func TestWebDAVSetFileMtime_FutureIgnored(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
file := filepath.Join(dir, "a.txt")
|
||||
assert.NoError(t, fs.WriteString(file, "x"))
|
||||
before, _ := os.Stat(file)
|
||||
future := time.Now().Add(2 * time.Hour).Unix()
|
||||
WebDAVSetFileMtime(file, future)
|
||||
after, _ := os.Stat(file)
|
||||
assert.Equal(t, before.ModTime().Unix(), after.ModTime().Unix())
|
||||
}
|
298
internal/server/webdav_write_test.go
Normal file
298
internal/server/webdav_write_test.go
Normal file
@@ -0,0 +1,298 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/config"
|
||||
"github.com/photoprism/photoprism/internal/entity"
|
||||
"github.com/photoprism/photoprism/pkg/service/http/header"
|
||||
)
|
||||
|
||||
func setupWebDAVRouter(conf *config.Config) *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
grp := r.Group(conf.BaseUri(WebDAVOriginals), WebDAVAuth(conf))
|
||||
WebDAV(conf.OriginalsPath(), grp, conf)
|
||||
return r
|
||||
}
|
||||
|
||||
func authBearer(req *http.Request) {
|
||||
sess := entity.SessionFixtures.Get("alice_token_webdav")
|
||||
header.SetAuthorization(req, sess.AuthToken())
|
||||
}
|
||||
|
||||
func authBasic(req *http.Request) {
|
||||
sess := entity.SessionFixtures.Get("alice_token_webdav")
|
||||
basic := []byte(fmt.Sprintf("alice:%s", sess.AuthToken()))
|
||||
req.Header.Set(header.Auth, fmt.Sprintf("%s %s", header.AuthBasic, base64.StdEncoding.EncodeToString(basic)))
|
||||
}
|
||||
|
||||
func TestWebDAVWrite_MKCOL_PUT(t *testing.T) {
|
||||
conf := config.TestConfig()
|
||||
r := setupWebDAVRouter(conf)
|
||||
|
||||
// MKCOL
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(MethodMkcol, conf.BaseUri(WebDAVOriginals)+"/wdvdir", nil)
|
||||
authBearer(req)
|
||||
r.ServeHTTP(w, req)
|
||||
assert.InDelta(t, 201, w.Code, 1) // Created
|
||||
// PUT file
|
||||
w = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(MethodPut, conf.BaseUri(WebDAVOriginals)+"/wdvdir/hello.txt", bytes.NewBufferString("hello"))
|
||||
authBearer(req)
|
||||
r.ServeHTTP(w, req)
|
||||
assert.InDelta(t, 201, w.Code, 1)
|
||||
// file exists
|
||||
path := filepath.Join(conf.OriginalsPath(), "wdvdir", "hello.txt")
|
||||
b, err := os.ReadFile(path)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "hello", string(b))
|
||||
}
|
||||
|
||||
func TestWebDAVWrite_MOVE_COPY(t *testing.T) {
|
||||
conf := config.TestConfig()
|
||||
r := setupWebDAVRouter(conf)
|
||||
|
||||
// Ensure source and destination directories via MKCOL
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(MethodMkcol, conf.BaseUri(WebDAVOriginals)+"/src", nil)
|
||||
authBasic(req)
|
||||
r.ServeHTTP(w, req)
|
||||
assert.InDelta(t, 201, w.Code, 1)
|
||||
w = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(MethodMkcol, conf.BaseUri(WebDAVOriginals)+"/dst", nil)
|
||||
authBasic(req)
|
||||
r.ServeHTTP(w, req)
|
||||
assert.InDelta(t, 201, w.Code, 1)
|
||||
// Create source file via PUT
|
||||
w = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(MethodPut, conf.BaseUri(WebDAVOriginals)+"/src/a.txt", bytes.NewBufferString("A"))
|
||||
authBasic(req)
|
||||
r.ServeHTTP(w, req)
|
||||
assert.InDelta(t, 201, w.Code, 1)
|
||||
|
||||
// MOVE /originals/src/a.txt -> /originals/dst/b.txt
|
||||
w = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(MethodMove, conf.BaseUri(WebDAVOriginals)+"/src/a.txt", nil)
|
||||
req.Header.Set("Destination", conf.BaseUri(WebDAVOriginals)+"/dst/b.txt")
|
||||
authBasic(req)
|
||||
r.ServeHTTP(w, req)
|
||||
assert.InDelta(t, 201, w.Code, 1)
|
||||
// Verify moved
|
||||
assert.NoFileExists(t, filepath.Join(conf.OriginalsPath(), "src", "a.txt"))
|
||||
assert.FileExists(t, filepath.Join(conf.OriginalsPath(), "dst", "b.txt"))
|
||||
|
||||
// COPY /originals/dst/b.txt -> /originals/dst/c.txt
|
||||
w = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(MethodCopy, conf.BaseUri(WebDAVOriginals)+"/dst/b.txt", nil)
|
||||
req.Header.Set("Destination", conf.BaseUri(WebDAVOriginals)+"/dst/c.txt")
|
||||
authBasic(req)
|
||||
r.ServeHTTP(w, req)
|
||||
assert.InDelta(t, 201, w.Code, 1)
|
||||
// Verify copy
|
||||
assert.FileExists(t, filepath.Join(conf.OriginalsPath(), "dst", "b.txt"))
|
||||
assert.FileExists(t, filepath.Join(conf.OriginalsPath(), "dst", "c.txt"))
|
||||
}
|
||||
|
||||
func TestWebDAVWrite_OverwriteSemantics(t *testing.T) {
|
||||
conf := config.TestConfig()
|
||||
r := setupWebDAVRouter(conf)
|
||||
|
||||
// Prepare src and dst
|
||||
_ = os.MkdirAll(filepath.Join(conf.OriginalsPath(), "src"), 0o700)
|
||||
_ = os.MkdirAll(filepath.Join(conf.OriginalsPath(), "dst"), 0o700)
|
||||
_ = os.WriteFile(filepath.Join(conf.OriginalsPath(), "src", "f.txt"), []byte("NEW"), 0o600)
|
||||
_ = os.WriteFile(filepath.Join(conf.OriginalsPath(), "dst", "f.txt"), []byte("OLD"), 0o600)
|
||||
|
||||
// COPY with Overwrite: F -> should not overwrite existing
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(MethodCopy, conf.BaseUri(WebDAVOriginals)+"/src/f.txt", nil)
|
||||
req.Header.Set("Destination", conf.BaseUri(WebDAVOriginals)+"/dst/f.txt")
|
||||
req.Header.Set("Overwrite", "F")
|
||||
authBasic(req)
|
||||
r.ServeHTTP(w, req)
|
||||
// Expect not successful (commonly 412 Precondition Failed)
|
||||
if w.Code == 201 || w.Code == 204 {
|
||||
t.Fatalf("expected failure when Overwrite=F, got %d", w.Code)
|
||||
}
|
||||
// Content remains OLD
|
||||
b, _ := os.ReadFile(filepath.Join(conf.OriginalsPath(), "dst", "f.txt"))
|
||||
assert.Equal(t, "OLD", string(b))
|
||||
|
||||
// COPY with Overwrite: T -> must overwrite
|
||||
w = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(MethodCopy, conf.BaseUri(WebDAVOriginals)+"/src/f.txt", nil)
|
||||
req.Header.Set("Destination", conf.BaseUri(WebDAVOriginals)+"/dst/f.txt")
|
||||
req.Header.Set("Overwrite", "T")
|
||||
authBasic(req)
|
||||
r.ServeHTTP(w, req)
|
||||
// Success (201/204 acceptable)
|
||||
if !(w.Code == 201 || w.Code == 204) {
|
||||
t.Fatalf("expected success for Overwrite=T, got %d", w.Code)
|
||||
}
|
||||
b, _ = os.ReadFile(filepath.Join(conf.OriginalsPath(), "dst", "f.txt"))
|
||||
assert.Equal(t, "NEW", string(b))
|
||||
|
||||
// MOVE with Overwrite: F to existing file -> expect failure
|
||||
_ = os.WriteFile(filepath.Join(conf.OriginalsPath(), "src", "g.txt"), []byte("GNEW"), 0o600)
|
||||
_ = os.WriteFile(filepath.Join(conf.OriginalsPath(), "dst", "g.txt"), []byte("GOLD"), 0o600)
|
||||
w = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(MethodMove, conf.BaseUri(WebDAVOriginals)+"/src/g.txt", nil)
|
||||
req.Header.Set("Destination", conf.BaseUri(WebDAVOriginals)+"/dst/g.txt")
|
||||
req.Header.Set("Overwrite", "F")
|
||||
authBasic(req)
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code == 201 || w.Code == 204 {
|
||||
t.Fatalf("expected failure when Overwrite=F for MOVE, got %d", w.Code)
|
||||
}
|
||||
// MOVE with Overwrite: T -> overwrites and removes source
|
||||
w = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(MethodMove, conf.BaseUri(WebDAVOriginals)+"/src/g.txt", nil)
|
||||
req.Header.Set("Destination", conf.BaseUri(WebDAVOriginals)+"/dst/g.txt")
|
||||
req.Header.Set("Overwrite", "T")
|
||||
authBasic(req)
|
||||
r.ServeHTTP(w, req)
|
||||
if !(w.Code == 201 || w.Code == 204) {
|
||||
t.Fatalf("expected success for MOVE Overwrite=T, got %d", w.Code)
|
||||
}
|
||||
assert.NoFileExists(t, filepath.Join(conf.OriginalsPath(), "src", "g.txt"))
|
||||
gb, _ := os.ReadFile(filepath.Join(conf.OriginalsPath(), "dst", "g.txt"))
|
||||
assert.Equal(t, "GNEW", string(gb))
|
||||
}
|
||||
|
||||
func TestWebDAVWrite_MoveMissingDestination(t *testing.T) {
|
||||
conf := config.TestConfig()
|
||||
r := setupWebDAVRouter(conf)
|
||||
// Ensure src exists
|
||||
_ = os.MkdirAll(filepath.Join(conf.OriginalsPath(), "mv"), 0o700)
|
||||
_ = os.WriteFile(filepath.Join(conf.OriginalsPath(), "mv", "file.txt"), []byte("X"), 0o600)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(MethodMove, conf.BaseUri(WebDAVOriginals)+"/mv/file.txt", nil)
|
||||
// no Destination header
|
||||
authBasic(req)
|
||||
r.ServeHTTP(w, req)
|
||||
// Expect failure (not 201/204)
|
||||
if w.Code == 201 || w.Code == 204 {
|
||||
t.Fatalf("expected failure when Destination header missing, got %d", w.Code)
|
||||
}
|
||||
// Source remains
|
||||
assert.FileExists(t, filepath.Join(conf.OriginalsPath(), "mv", "file.txt"))
|
||||
}
|
||||
|
||||
func TestWebDAVWrite_CopyInvalidDestinationPrefix(t *testing.T) {
|
||||
conf := config.TestConfig()
|
||||
r := setupWebDAVRouter(conf)
|
||||
// Ensure src exists
|
||||
_ = os.MkdirAll(filepath.Join(conf.OriginalsPath(), "cp"), 0o700)
|
||||
_ = os.WriteFile(filepath.Join(conf.OriginalsPath(), "cp", "a.txt"), []byte("A"), 0o600)
|
||||
|
||||
// COPY to a destination outside the handler prefix
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(MethodCopy, conf.BaseUri(WebDAVOriginals)+"/cp/a.txt", nil)
|
||||
req.Header.Set("Destination", "/notwebdav/d.txt")
|
||||
authBasic(req)
|
||||
r.ServeHTTP(w, req)
|
||||
// Expect failure
|
||||
if w.Code == 201 || w.Code == 204 {
|
||||
t.Fatalf("expected failure for invalid Destination prefix, got %d", w.Code)
|
||||
}
|
||||
// Destination not created
|
||||
assert.NoFileExists(t, filepath.Join(conf.OriginalsPath(), "notwebdav", "d.txt"))
|
||||
}
|
||||
|
||||
func TestWebDAVWrite_MoveNonExistentSource(t *testing.T) {
|
||||
conf := config.TestConfig()
|
||||
r := setupWebDAVRouter(conf)
|
||||
// Ensure destination dir exists
|
||||
_ = os.MkdirAll(filepath.Join(conf.OriginalsPath(), "dst2"), 0o700)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(MethodMove, conf.BaseUri(WebDAVOriginals)+"/nosuch/file.txt", nil)
|
||||
req.Header.Set("Destination", conf.BaseUri(WebDAVOriginals)+"/dst2/file.txt")
|
||||
authBasic(req)
|
||||
r.ServeHTTP(w, req)
|
||||
// Expect failure (e.g., 404)
|
||||
if w.Code == 201 || w.Code == 204 {
|
||||
t.Fatalf("expected failure moving non-existent source, got %d", w.Code)
|
||||
}
|
||||
assert.NoFileExists(t, filepath.Join(conf.OriginalsPath(), "dst2", "file.txt"))
|
||||
}
|
||||
|
||||
func TestWebDAVWrite_CopyTraversalDestination(t *testing.T) {
|
||||
conf := config.TestConfig()
|
||||
r := setupWebDAVRouter(conf)
|
||||
|
||||
// Create source file via PUT
|
||||
_ = os.MkdirAll(filepath.Join(conf.OriginalsPath(), "travsrc"), 0o700)
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(MethodPut, conf.BaseUri(WebDAVOriginals)+"/travsrc/a.txt", bytes.NewBufferString("A"))
|
||||
authBasic(req)
|
||||
r.ServeHTTP(w, req)
|
||||
assert.InDelta(t, 201, w.Code, 1)
|
||||
|
||||
// Attempt COPY with traversal in Destination
|
||||
w = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(MethodCopy, conf.BaseUri(WebDAVOriginals)+"/travsrc/a.txt", nil)
|
||||
req.Header.Set("Destination", conf.BaseUri(WebDAVOriginals)+"/../evil.txt")
|
||||
authBasic(req)
|
||||
r.ServeHTTP(w, req)
|
||||
// Expect success with sanitized destination inside base
|
||||
if !(w.Code == 201 || w.Code == 204) {
|
||||
t.Fatalf("expected success (sanitized), got %d", w.Code)
|
||||
}
|
||||
// Not created above originals; created as /originals/evil.txt
|
||||
parent := filepath.Dir(conf.OriginalsPath())
|
||||
assert.NoFileExists(t, filepath.Join(parent, "evil.txt"))
|
||||
assert.FileExists(t, filepath.Join(conf.OriginalsPath(), "evil.txt"))
|
||||
}
|
||||
|
||||
func TestWebDAVWrite_MoveTraversalDestination(t *testing.T) {
|
||||
conf := config.TestConfig()
|
||||
r := setupWebDAVRouter(conf)
|
||||
|
||||
// Create source file via PUT
|
||||
_ = os.MkdirAll(filepath.Join(conf.OriginalsPath(), "travsrc2"), 0o700)
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(MethodPut, conf.BaseUri(WebDAVOriginals)+"/travsrc2/a.txt", bytes.NewBufferString("A"))
|
||||
authBasic(req)
|
||||
r.ServeHTTP(w, req)
|
||||
assert.InDelta(t, 201, w.Code, 1)
|
||||
|
||||
// Attempt MOVE with traversal in Destination
|
||||
w = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(MethodMove, conf.BaseUri(WebDAVOriginals)+"/travsrc2/a.txt", nil)
|
||||
req.Header.Set("Destination", conf.BaseUri(WebDAVOriginals)+"/../evil2.txt")
|
||||
authBasic(req)
|
||||
r.ServeHTTP(w, req)
|
||||
if !(w.Code == 201 || w.Code == 204) {
|
||||
t.Fatalf("expected success (sanitized) for MOVE, got %d", w.Code)
|
||||
}
|
||||
// Source removed; destination created inside base, not outside
|
||||
assert.NoFileExists(t, filepath.Join(conf.OriginalsPath(), "travsrc2", "a.txt"))
|
||||
parent := filepath.Dir(conf.OriginalsPath())
|
||||
assert.NoFileExists(t, filepath.Join(parent, "evil2.txt"))
|
||||
assert.FileExists(t, filepath.Join(conf.OriginalsPath(), "evil2.txt"))
|
||||
}
|
||||
|
||||
func TestWebDAVWrite_ReadOnlyForbidden(t *testing.T) {
|
||||
conf := config.TestConfig()
|
||||
conf.Options().ReadOnly = true
|
||||
r := setupWebDAVRouter(conf)
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(MethodMkcol, conf.BaseUri(WebDAVOriginals)+"/ro", nil)
|
||||
authBearer(req)
|
||||
r.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusForbidden, w.Code)
|
||||
}
|
42
internal/thumb/avatar/download.go
Normal file
42
internal/thumb/avatar/download.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package avatar
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/photoprism/photoprism/pkg/service/http/safe"
|
||||
)
|
||||
|
||||
var (
|
||||
// Stricter defaults for avatar images than the generic HTTP safe defaults.
|
||||
defaultTimeout = 15 * time.Second
|
||||
defaultMaxSize int64 = 10 << 20 // 10 MiB for avatar images
|
||||
)
|
||||
|
||||
// SafeDownload delegates avatar image downloads to the shared HTTP safe downloader
|
||||
// with hardened defaults suitable for small image files.
|
||||
// Callers may pass a partially filled safe.Options to override defaults.
|
||||
func SafeDownload(destPath, rawURL string, opt *safe.Options) error {
|
||||
// Start with strict avatar defaults.
|
||||
o := &safe.Options{
|
||||
Timeout: defaultTimeout,
|
||||
MaxSizeBytes: defaultMaxSize,
|
||||
AllowPrivate: false, // block private/loopback by default
|
||||
// Prefer images but allow others at low priority; MIME is validated later.
|
||||
Accept: "image/jpeg, image/png, */*;q=0.1",
|
||||
}
|
||||
if opt != nil {
|
||||
if opt.Timeout > 0 {
|
||||
o.Timeout = opt.Timeout
|
||||
}
|
||||
if opt.MaxSizeBytes > 0 {
|
||||
o.MaxSizeBytes = opt.MaxSizeBytes
|
||||
}
|
||||
// Bool has no sentinel; just copy the value.
|
||||
o.AllowPrivate = opt.AllowPrivate
|
||||
if strings.TrimSpace(opt.Accept) != "" {
|
||||
o.Accept = opt.Accept
|
||||
}
|
||||
}
|
||||
return safe.Download(destPath, rawURL, o)
|
||||
}
|
73
internal/thumb/avatar/download_test.go
Normal file
73
internal/thumb/avatar/download_test.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package avatar
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/photoprism/photoprism/pkg/service/http/safe"
|
||||
)
|
||||
|
||||
func TestSafeDownload_InvalidScheme(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
dest := filepath.Join(dir, "x")
|
||||
if err := SafeDownload(dest, "file:///etc/passwd", nil); err == nil {
|
||||
t.Fatal("expected error for invalid scheme")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSafeDownload_PrivateIPBlocked(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
dest := filepath.Join(dir, "x")
|
||||
if err := SafeDownload(dest, "http://127.0.0.1/test.png", nil); err == nil {
|
||||
t.Fatal("expected SSRF private IP block")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSafeDownload_MaxSizeExceeded(t *testing.T) {
|
||||
// Local server; allow private for test.
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/octet-stream")
|
||||
// 2MB body
|
||||
w.WriteHeader(http.StatusOK)
|
||||
buf := make([]byte, 2<<20)
|
||||
_, _ = w.Write(buf)
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
dir := t.TempDir()
|
||||
dest := filepath.Join(dir, "big")
|
||||
err := SafeDownload(dest, ts.URL, &safe.Options{Timeout: 5 * time.Second, MaxSizeBytes: 1 << 20, AllowPrivate: true})
|
||||
if err == nil {
|
||||
t.Fatal("expected size exceeded error")
|
||||
}
|
||||
if _, statErr := os.Stat(dest); !os.IsNotExist(statErr) {
|
||||
t.Fatalf("expected no output file on error, got stat err=%v", statErr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSafeDownload_Succeeds(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "image/png")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = io.WriteString(w, "ok")
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
dir := t.TempDir()
|
||||
dest := filepath.Join(dir, "ok")
|
||||
if err := SafeDownload(dest, ts.URL, &safe.Options{Timeout: 5 * time.Second, MaxSizeBytes: 1 << 20, AllowPrivate: true}); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
b, err := os.ReadFile(dest)
|
||||
if err != nil {
|
||||
t.Fatalf("read: %v", err)
|
||||
}
|
||||
if string(b) != "ok" {
|
||||
t.Fatalf("unexpected content: %q", string(b))
|
||||
}
|
||||
}
|
@@ -32,7 +32,8 @@ func SetUserImageURL(m *entity.User, imageUrl, imageSrc, thumbPath string) error
|
||||
|
||||
tmpName := filepath.Join(os.TempDir(), rnd.Base36(64))
|
||||
|
||||
if err = fs.Download(tmpName, u.String()); err != nil {
|
||||
// Hardened remote fetch with SSRF and size limits.
|
||||
if err = SafeDownload(tmpName, u.String(), nil); err != nil {
|
||||
return fmt.Errorf("failed to download avatar image (%w)", err)
|
||||
}
|
||||
|
||||
|
40
pkg/fs/fs.go
40
pkg/fs/fs.go
@@ -27,11 +27,12 @@ package fs
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"syscall"
|
||||
|
||||
"github.com/photoprism/photoprism/pkg/service/http/safe"
|
||||
)
|
||||
|
||||
var ignoreCase bool
|
||||
@@ -206,40 +207,9 @@ func Abs(name string) string {
|
||||
|
||||
// Download downloads a file from a URL.
|
||||
func Download(fileName string, url string) error {
|
||||
if dir := filepath.Dir(fileName); dir == "" || dir == "/" || dir == "." || dir == ".." {
|
||||
return fmt.Errorf("invalid path")
|
||||
} else if err := MkdirAll(dir); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Create the file
|
||||
out, err := os.Create(fileName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer out.Close()
|
||||
|
||||
// Get the data
|
||||
resp, err := http.Get(url)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Check server response
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("bad status: %s", resp.Status)
|
||||
}
|
||||
|
||||
// Writer the body to file
|
||||
_, err = io.Copy(out, resp.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
// Preserve existing semantics but with safer network behavior.
|
||||
// Allow private IPs by default to avoid breaking intended internal downloads.
|
||||
return safe.Download(fileName, url, &safe.Options{AllowPrivate: true})
|
||||
}
|
||||
|
||||
// DirIsEmpty returns true if a directory is empty.
|
||||
|
218
pkg/service/http/safe/download.go
Normal file
218
pkg/service/http/safe/download.go
Normal file
@@ -0,0 +1,218 @@
|
||||
package safe
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptrace"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Download fetches a URL to a destination file with timeouts, size limits, and optional SSRF protection.
|
||||
func Download(destPath, rawURL string, opt *Options) error {
|
||||
if destPath == "" {
|
||||
return errors.New("invalid destination path")
|
||||
}
|
||||
// Prepare destination directory.
|
||||
if dir := filepath.Dir(destPath); dir == "" || dir == "/" || dir == "." || dir == ".." {
|
||||
return errors.New("invalid destination directory")
|
||||
} else if err := os.MkdirAll(dir, 0o700); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
u, err := url.Parse(rawURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !strings.EqualFold(u.Scheme, "http") && !strings.EqualFold(u.Scheme, "https") {
|
||||
return ErrSchemeNotAllowed
|
||||
}
|
||||
|
||||
// Defaults w/ env overrides
|
||||
maxSize := defaultMaxSize
|
||||
if n := envInt64("PHOTOPRISM_HTTP_MAX_DOWNLOAD"); n > 0 {
|
||||
maxSize = n
|
||||
}
|
||||
timeout := defaultTimeout
|
||||
if d := envDuration("PHOTOPRISM_HTTP_TIMEOUT"); d > 0 {
|
||||
timeout = d
|
||||
}
|
||||
|
||||
o := Options{Timeout: timeout, MaxSizeBytes: maxSize, AllowPrivate: true, Accept: "*/*"}
|
||||
if opt != nil {
|
||||
if opt.Timeout > 0 {
|
||||
o.Timeout = opt.Timeout
|
||||
}
|
||||
if opt.MaxSizeBytes > 0 {
|
||||
o.MaxSizeBytes = opt.MaxSizeBytes
|
||||
}
|
||||
o.AllowPrivate = opt.AllowPrivate
|
||||
if strings.TrimSpace(opt.Accept) != "" {
|
||||
o.Accept = opt.Accept
|
||||
}
|
||||
}
|
||||
|
||||
// Optional SSRF block
|
||||
if !o.AllowPrivate {
|
||||
if ip := net.ParseIP(u.Hostname()); ip != nil {
|
||||
if isPrivateOrDisallowedIP(ip) {
|
||||
return ErrPrivateIP
|
||||
}
|
||||
} else {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
addrs, lookErr := net.DefaultResolver.LookupIPAddr(ctx, u.Hostname())
|
||||
if lookErr != nil {
|
||||
return lookErr
|
||||
}
|
||||
for _, a := range addrs {
|
||||
if isPrivateOrDisallowedIP(a.IP) {
|
||||
return ErrPrivateIP
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Enforce redirect validation when private networks are disallowed.
|
||||
client := &http.Client{
|
||||
Timeout: o.Timeout,
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
if !o.AllowPrivate {
|
||||
h := req.URL.Hostname()
|
||||
if ip := net.ParseIP(h); ip != nil {
|
||||
if isPrivateOrDisallowedIP(ip) {
|
||||
return ErrPrivateIP
|
||||
}
|
||||
} else {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
addrs, lookErr := net.DefaultResolver.LookupIPAddr(ctx, h)
|
||||
if lookErr != nil {
|
||||
return lookErr
|
||||
}
|
||||
for _, a := range addrs {
|
||||
if isPrivateOrDisallowedIP(a.IP) {
|
||||
return ErrPrivateIP
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Propagate Accept header from the first request.
|
||||
if len(via) > 0 {
|
||||
if v := via[0].Header.Get("Accept"); v != "" {
|
||||
req.Header.Set("Accept", v)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, u.String(), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if o.Accept != "" {
|
||||
req.Header.Set("Accept", o.Accept)
|
||||
}
|
||||
// Capture the final remote IP used for the connection.
|
||||
var finalIP net.IP
|
||||
trace := &httptrace.ClientTrace{
|
||||
GotConn: func(info httptrace.GotConnInfo) {
|
||||
if addr := info.Conn.RemoteAddr(); addr != nil {
|
||||
host, _, _ := net.SplitHostPort(addr.String())
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
finalIP = ip
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return errors.New(resp.Status)
|
||||
}
|
||||
|
||||
// Validate the connected peer address when private ranges are disallowed.
|
||||
if !o.AllowPrivate && finalIP != nil && isPrivateOrDisallowedIP(finalIP) {
|
||||
return ErrPrivateIP
|
||||
}
|
||||
|
||||
if resp.ContentLength > 0 && o.MaxSizeBytes > 0 && resp.ContentLength > o.MaxSizeBytes {
|
||||
return ErrSizeExceeded
|
||||
}
|
||||
|
||||
tmp := destPath + ".part"
|
||||
f, err := os.OpenFile(tmp, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o600)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
f.Close()
|
||||
if err != nil {
|
||||
_ = os.Remove(tmp)
|
||||
}
|
||||
}()
|
||||
|
||||
var r io.Reader = resp.Body
|
||||
if o.MaxSizeBytes > 0 {
|
||||
r = io.LimitReader(resp.Body, o.MaxSizeBytes+1)
|
||||
}
|
||||
n, copyErr := io.Copy(f, r)
|
||||
if copyErr != nil {
|
||||
err = copyErr
|
||||
return err
|
||||
}
|
||||
if o.MaxSizeBytes > 0 && n > o.MaxSizeBytes {
|
||||
err = ErrSizeExceeded
|
||||
return err
|
||||
}
|
||||
if err = f.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err = os.Rename(tmp, destPath); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func isPrivateOrDisallowedIP(ip net.IP) bool {
|
||||
if ip == nil {
|
||||
return true
|
||||
}
|
||||
if ip.IsLoopback() || ip.IsMulticast() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
|
||||
return true
|
||||
}
|
||||
if v4 := ip.To4(); v4 != nil {
|
||||
if v4[0] == 10 {
|
||||
return true
|
||||
}
|
||||
if v4[0] == 172 && v4[1] >= 16 && v4[1] <= 31 {
|
||||
return true
|
||||
}
|
||||
if v4[0] == 192 && v4[1] == 168 {
|
||||
return true
|
||||
}
|
||||
if v4[0] == 169 && v4[1] == 254 {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
// IPv6 ULA fc00::/7
|
||||
if ip.To16() != nil {
|
||||
if ip[0]&0xFE == 0xFC {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
57
pkg/service/http/safe/download_redirect_test.go
Normal file
57
pkg/service/http/safe/download_redirect_test.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package safe
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Redirect to a private IP must be blocked when AllowPrivate=false.
|
||||
func TestDownload_BlockRedirectToPrivate(t *testing.T) {
|
||||
// Public-looking server that redirects to 127.0.0.1
|
||||
redirectTarget := "http://127.0.0.1:65535/secret"
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Redirect(w, r, redirectTarget, http.StatusFound)
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
dir := t.TempDir()
|
||||
dest := filepath.Join(dir, "out")
|
||||
err := Download(dest, ts.URL, &Options{Timeout: 5 * time.Second, MaxSizeBytes: 1 << 20, AllowPrivate: false})
|
||||
if err == nil {
|
||||
t.Fatalf("expected redirect SSRF to be blocked")
|
||||
}
|
||||
if _, statErr := os.Stat(dest); !os.IsNotExist(statErr) {
|
||||
t.Fatalf("expected no output file on error, got stat err=%v", statErr)
|
||||
}
|
||||
}
|
||||
|
||||
// With AllowPrivate=true, redirects to a local httptest server should succeed.
|
||||
func TestDownload_AllowRedirectToPrivate(t *testing.T) {
|
||||
// Local private target that serves content.
|
||||
target := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = io.WriteString(w, "ok")
|
||||
}))
|
||||
defer target.Close()
|
||||
|
||||
// Public-looking server that redirects to the private target.
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Redirect(w, r, target.URL, http.StatusFound)
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
dir := t.TempDir()
|
||||
dest := filepath.Join(dir, "ok")
|
||||
if err := Download(dest, ts.URL, &Options{Timeout: 5 * time.Second, MaxSizeBytes: 1 << 20, AllowPrivate: true}); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
b, err := os.ReadFile(dest)
|
||||
if err != nil || string(b) != "ok" {
|
||||
t.Fatalf("unexpected content: %v %q", err, string(b))
|
||||
}
|
||||
}
|
42
pkg/service/http/safe/download_test.go
Normal file
42
pkg/service/http/safe/download_test.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package safe
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestSafeDownload_OK(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = io.WriteString(w, "hello")
|
||||
}))
|
||||
defer ts.Close()
|
||||
dir := t.TempDir()
|
||||
dest := filepath.Join(dir, "ok.txt")
|
||||
if err := Download(dest, ts.URL, &Options{Timeout: 5 * time.Second, MaxSizeBytes: 1024, AllowPrivate: true}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
b, err := os.ReadFile(dest)
|
||||
if err != nil || string(b) != "hello" {
|
||||
t.Fatalf("unexpected content: %v %q", err, string(b))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSafeDownload_TooLarge(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
// 2KiB
|
||||
_, _ = w.Write(make([]byte, 2048))
|
||||
}))
|
||||
defer ts.Close()
|
||||
dir := t.TempDir()
|
||||
dest := filepath.Join(dir, "big.bin")
|
||||
if err := Download(dest, ts.URL, &Options{Timeout: 5 * time.Second, MaxSizeBytes: 1024, AllowPrivate: true}); err == nil {
|
||||
t.Fatalf("expected ErrSizeExceeded")
|
||||
}
|
||||
}
|
47
pkg/service/http/safe/options.go
Normal file
47
pkg/service/http/safe/options.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package safe
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Options controls Download behavior.
|
||||
type Options struct {
|
||||
Timeout time.Duration
|
||||
MaxSizeBytes int64
|
||||
AllowPrivate bool
|
||||
Accept string
|
||||
}
|
||||
|
||||
var (
|
||||
// Defaults are tuned for general downloads (not just avatars).
|
||||
defaultTimeout = 30 * time.Second
|
||||
defaultMaxSize = int64(200 * 1024 * 1024) // 200 MiB
|
||||
|
||||
ErrSchemeNotAllowed = errors.New("invalid scheme (only http/https allowed)")
|
||||
ErrSizeExceeded = errors.New("response exceeds maximum allowed size")
|
||||
ErrPrivateIP = errors.New("connection to private or loopback address not allowed")
|
||||
)
|
||||
|
||||
// envInt64 returns an int64 from env or -1 if unset/invalid.
|
||||
func envInt64(key string) int64 {
|
||||
if v := strings.TrimSpace(os.Getenv(key)); v != "" {
|
||||
if n, err := strconv.ParseInt(v, 10, 64); err == nil {
|
||||
return n
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// envDuration returns a duration from env seconds or 0 if unset/invalid.
|
||||
func envDuration(key string) time.Duration {
|
||||
if v := strings.TrimSpace(os.Getenv(key)); v != "" {
|
||||
if n, err := strconv.ParseInt(v, 10, 64); err == nil {
|
||||
return time.Duration(n) * time.Second
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
25
pkg/service/http/safe/safe.go
Normal file
25
pkg/service/http/safe/safe.go
Normal file
@@ -0,0 +1,25 @@
|
||||
/*
|
||||
Package safe provides a secure HTTP downloader with customizable settings.
|
||||
|
||||
Copyright (c) 2018 - 2025 PhotoPrism UG. All rights reserved.
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under Version 3 of the GNU Affero General Public License (the "AGPL"):
|
||||
<https://docs.photoprism.app/license/agpl>
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
The AGPL is supplemented by our Trademark and Brand Guidelines,
|
||||
which describe how our Brand Assets may be used:
|
||||
<https://www.photoprism.app/trademark>
|
||||
|
||||
Feel free to send an email to hello@photoprism.app if you have questions,
|
||||
want to support our work, or just want to say hello.
|
||||
|
||||
Additional information can be found in our Developer Guide:
|
||||
<https://docs.photoprism.app/developer-guide/>
|
||||
*/
|
||||
package safe
|
52
scripts/tools/swaggerfix/main.go
Normal file
52
scripts/tools/swaggerfix/main.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
)
|
||||
|
||||
func main() {
|
||||
if len(os.Args) < 2 {
|
||||
fmt.Fprintln(os.Stderr, "usage: swaggerfix <swagger.json>")
|
||||
os.Exit(2)
|
||||
}
|
||||
path := os.Args[1]
|
||||
b, err := ioutil.ReadFile(path)
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, "read:", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
var doc map[string]interface{}
|
||||
if err := json.Unmarshal(b, &doc); err != nil {
|
||||
fmt.Fprintln(os.Stderr, "parse:", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
// Traverse to definitions.time.Duration
|
||||
defs, _ := doc["definitions"].(map[string]interface{})
|
||||
if defs == nil {
|
||||
fmt.Fprintln(os.Stderr, "no definitions in swagger file")
|
||||
os.Exit(1)
|
||||
}
|
||||
td, _ := defs["time.Duration"].(map[string]interface{})
|
||||
if td == nil {
|
||||
fmt.Fprintln(os.Stderr, "no time.Duration schema found; nothing to do")
|
||||
os.Exit(0)
|
||||
}
|
||||
// Remove unstable enums and varnames to ensure deterministic output.
|
||||
delete(td, "enum")
|
||||
delete(td, "x-enum-varnames")
|
||||
defs["time.Duration"] = td
|
||||
doc["definitions"] = defs
|
||||
out, err := json.MarshalIndent(doc, "", " ")
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, "marshal:", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
if err := ioutil.WriteFile(path, out, 0644); err != nil {
|
||||
fmt.Fprintln(os.Stderr, "write:", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user