Backend: Add security-focused tests, harden WebDAV and use safe.Download

Signed-off-by: Michael Mayer <michael@photoprism.app>
This commit is contained in:
Michael Mayer
2025-09-22 10:42:53 +02:00
parent a22babe3d1
commit 9ea5f0596c
29 changed files with 9905 additions and 7695 deletions

View File

@@ -22,6 +22,9 @@ Learn more: https://agents.md/
- Whenever the Change Management instructions for a document require it, publish changes as a new file with an incremented version suffix (e.g., `*-v3.md`) rather than overwriting the original file.
- Older spec versions remain in the repo for historical reference but are not linked from the main TOC. Do not base new work on superseded files (e.g., `*-v1.md` when `*-v2.md` exists).
Note on specs repository availability
- The `specs/` repository may be private and is not guaranteed to be present in every clone or environment. Do not add Makefile targets in the main project that depend on `specs/` paths. When `specs/` is available, run its tools directly (e.g., `bash specs/scripts/lint-status.sh`).
## Project Structure & Languages
- Backend: Go (`internal/`, `pkg/`, `cmd/`) + MariaDB/SQLite
@@ -191,6 +194,19 @@ Note: Across our public documentation, official images, and in production, the c
- Examples assume a Linux/Unix shell. For Windows specifics, see the Developer Guide FAQ:
https://docs.photoprism.app/developer-guide/faq/#can-your-development-environment-be-used-under-windows
### HTTP Download — Security Checklist
- Use the shared safe HTTP helper instead of adhoc `net/http` code:
- Package: `pkg/service/http/safe``safe.Download(destPath, url, *safe.Options)`.
- Default policy in this repo: allow only `http/https`, enforce timeouts and max size, write to a `0600` temp file then rename.
- SSRF protection (mandatory unless explicitly needed for tests):
- Set `AllowPrivate=false` to block private/loopback/multicast/linklocal ranges.
- All redirect targets are validated; the final connected peer IP is also checked.
- Prefer an imagefocused `Accept` header for image downloads: `"image/jpeg, image/png, */*;q=0.1"`.
- Avatars and small images: use the thin wrapper in `internal/thumb/avatar.SafeDownload` which applies stricter defaults (15s timeout, 10 MiB, `AllowPrivate=false`).
- Tests using `httptest.Server` on 127.0.0.1 must pass `AllowPrivate=true` explicitly to succeed.
- Keep perresource size budgets small; rely on `io.LimitReader` + `Content-Length` prechecks.
If anything in this file conflicts with the `Makefile` or the Developer Guide, the `Makefile` and the documentation win. When unsure, **ask** for clarification before proceeding.
## Agent Quick Tips (Do This)

View File

@@ -153,6 +153,12 @@ Security & Hot Spots (Where to Look)
- Pipeline: `internal/thumb/vips.go` (VipsInit, VipsRotate, export params).
- Sizes & names: `internal/thumb/sizes.go`, `internal/thumb/names.go`, `internal/thumb/filter.go`.
- Safe HTTP downloader:
- Shared utility: `pkg/service/http/safe` (`Download`, `Options`).
- Protections: scheme allowlist (http/https), preDNS + perredirect hostname/IP validation, final peer IP check, size and timeout enforcement, temp file `0600` + rename.
- Avatars: wrapper `internal/thumb/avatar.SafeDownload` applies stricter defaults (15s, 10MiB, `AllowPrivate=false`, imagefocused `Accept`).
- Tests: `go test ./pkg/service/http/safe -count=1` (includes redirect SSRF cases); avatars: `go test ./internal/thumb/avatar -count=1`.
Performance & Limits
- Prefer existing caches/workers/batching as per Makefile and code.
- When adding list endpoints, default `count=100` (max `1000`); set `Cache-Control: no-store` for secrets.

View File

@@ -115,6 +115,8 @@ swag: swag-json
swag-json:
@echo "Generating ./internal/api/swagger.json..."
swag init --ot json --parseDependency --parseDepth 1 --dir internal/api -g api.go -o ./internal/api
@echo "Fixing unstable time.Duration enums in swagger.json..."
@GO111MODULE=on go run scripts/tools/swaggerfix/main.go internal/api/swagger.json || { echo "swaggerfix failed"; exit 1; }
swag-yaml:
@echo "Generating ./internal/api/swagger.yaml..."
swag init --ot yaml --parseDependency --parseDepth 1 --dir internal/api -g api.go -o ./internal/api

View File

@@ -16,7 +16,13 @@ func UpdateClientConfig() {
// GetClientConfig returns the client configuration values as JSON.
//
// GET /api/v1/config
// @Summary get client configuration
// @Id GetClientConfig
// @Tags Config
// @Produce json
// @Success 200 {object} gin.H
// @Failure 401 {object} i18n.Response
// @Router /api/v1/config [get]
func GetClientConfig(router *gin.RouterGroup) {
router.GET("/config", func(c *gin.Context) {
sess := Session(ClientIP(c), AuthToken(c))

View File

@@ -176,7 +176,7 @@ func BatchPhotosRestore(router *gin.RouterGroup) {
// @Param photos body form.Selection true "Photo Selection"
// @Router /api/v1/batch/photos/approve [post]
func BatchPhotosApprove(router *gin.RouterGroup) {
router.POST("batch/photos/approve", func(c *gin.Context) {
router.POST("/batch/photos/approve", func(c *gin.Context) {
s := Auth(c, acl.ResourcePhotos, acl.ActionUpdate)
if s.Abort(c) {

View File

@@ -15,7 +15,16 @@ import (
// Connect confirms external service accounts using a token.
//
// PUT /api/v1/connect/:name
// @Summary confirm external service accounts using a token
// @Id ConnectService
// @Tags Config
// @Accept json
// @Produce json
// @Param name path string true "service name (e.g., hub)"
// @Param connect body form.Connect true "connection token"
// @Success 200 {object} gin.H
// @Failure 400,401,403 {object} i18n.Response
// @Router /api/v1/connect/{name} [put]
func Connect(router *gin.RouterGroup) {
router.PUT("/connect/:name", func(c *gin.Context) {
name := clean.ID(c.Param("name"))

View File

@@ -0,0 +1,13 @@
package api
import "time"
// Schema Overrides for Swagger generation.
// Override the generated schema for time.Duration to avoid unstable enums
// from the standard library constants (Nanosecond, Minute, etc.). Using
// a simple integer schema is accurate (nanoseconds) and deterministic.
//
// @name time.Duration
// @description Duration in nanoseconds (int64). Examples: 1000000000 (1s), 60000000000 (1m).
type SwaggerTimeDuration = time.Duration

View File

@@ -28,7 +28,13 @@ type FoldersResponse struct {
// SearchFoldersOriginals returns folders in originals as JSON.
//
// GET /api/v1/folders/originals
// @Summary list folders in originals
// @Id SearchFoldersOriginals
// @Tags Folders
// @Produce json
// @Success 200 {object} api.FoldersResponse
// @Failure 401,403 {object} i18n.Response
// @Router /api/v1/folders/originals [get]
func SearchFoldersOriginals(router *gin.RouterGroup) {
conf := get.Config()
SearchFolders(router, "originals", entity.RootOriginals, conf.OriginalsPath())
@@ -36,7 +42,13 @@ func SearchFoldersOriginals(router *gin.RouterGroup) {
// SearchFoldersImport returns import folders as JSON.
//
// GET /api/v1/folders/import
// @Summary list folders in import
// @Id SearchFoldersImport
// @Tags Folders
// @Produce json
// @Success 200 {object} api.FoldersResponse
// @Failure 401,403 {object} i18n.Response
// @Router /api/v1/folders/import [get]
func SearchFoldersImport(router *gin.RouterGroup) {
conf := get.Config()
SearchFolders(router, "import", entity.RootImport, conf.ImportPath())

View File

@@ -87,10 +87,10 @@ func DeleteLink(c *gin.Context) {
c.JSON(http.StatusOK, link)
}
// CreateLink adds a new share link and return it as JSON.
//
// @Tags Links
// @Router /api/v1/{entity}/{uid}/links [post]
// CreateLink adds a new share link and returns it as JSON.
// Note: Internal helper used by resource-specific endpoints (e.g., albums, photos).
// Swagger annotations are defined on those public handlers to avoid generating
// undocumented generic paths like "/api/v1/{entity}/{uid}/links".
func CreateLink(c *gin.Context) {
s := Auth(c, acl.ResourceShares, acl.ActionCreate)

View File

@@ -0,0 +1,153 @@
package api
import (
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/tidwall/gjson"
"golang.org/x/time/rate"
"github.com/photoprism/photoprism/internal/config"
"github.com/photoprism/photoprism/internal/server/limiter"
"github.com/photoprism/photoprism/pkg/authn"
"github.com/photoprism/photoprism/pkg/service/http/header"
)
func TestOAuthToken_RateLimit_ClientCredentials(t *testing.T) {
app, router, conf := NewApiTest()
conf.SetAuthMode(config.AuthModePasswd)
defer conf.SetAuthMode(config.AuthModePublic)
OAuthToken(router)
// Tighten rate limits
oldLogin, oldAuth := limiter.Login, limiter.Auth
defer func() { limiter.Login, limiter.Auth = oldLogin, oldAuth }()
limiter.Login = limiter.NewLimit(rate.Every(24*time.Hour), 3) // burst 3
limiter.Auth = limiter.NewLimit(rate.Every(24*time.Hour), 3)
// Invalid client secret repeatedly (from UnknownIP: no headers set)
path := "/api/v1/oauth/token"
for i := 0; i < 3; i++ {
data := url.Values{
"grant_type": {authn.GrantClientCredentials.String()},
"client_id": {"cs5cpu17n6gj2qo5"},
"client_secret": {"INVALID"},
"scope": {"metrics"},
}
req, _ := http.NewRequest(http.MethodPost, path, strings.NewReader(data.Encode()))
req.Header.Set(header.ContentType, header.ContentTypeForm)
w := httptest.NewRecorder()
app.ServeHTTP(w, req)
assert.Equal(t, http.StatusUnauthorized, w.Code)
}
// Next call should be rate limited
data := url.Values{
"grant_type": {authn.GrantClientCredentials.String()},
"client_id": {"cs5cpu17n6gj2qo5"},
"client_secret": {"INVALID"},
"scope": {"metrics"},
}
req, _ := http.NewRequest(http.MethodPost, path, strings.NewReader(data.Encode()))
req.Header.Set(header.ContentType, header.ContentTypeForm)
req.Header.Set("X-Forwarded-For", "198.51.100.99")
w := httptest.NewRecorder()
app.ServeHTTP(w, req)
assert.Equal(t, http.StatusTooManyRequests, w.Code)
}
func TestOAuthToken_ResponseFields_ClientSuccess(t *testing.T) {
app, router, conf := NewApiTest()
conf.SetAuthMode(config.AuthModePasswd)
defer conf.SetAuthMode(config.AuthModePublic)
OAuthToken(router)
data := url.Values{
"grant_type": {authn.GrantClientCredentials.String()},
"client_id": {"cs5cpu17n6gj2qo5"},
"client_secret": {"xcCbOrw6I0vcoXzhnOmXhjpVSyFq0l0e"},
"scope": {"metrics"},
}
req, _ := http.NewRequest(http.MethodPost, "/api/v1/oauth/token", strings.NewReader(data.Encode()))
req.Header.Set(header.ContentType, header.ContentTypeForm)
w := httptest.NewRecorder()
app.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
body := w.Body.String()
assert.NotEmpty(t, gjson.Get(body, "access_token").String())
tokType := gjson.Get(body, "token_type").String()
assert.True(t, strings.EqualFold(tokType, "bearer"))
assert.GreaterOrEqual(t, gjson.Get(body, "expires_in").Int(), int64(0))
assert.Equal(t, "metrics", gjson.Get(body, "scope").String())
}
func TestOAuthToken_ResponseFields_UserSuccess(t *testing.T) {
app, router, conf := NewApiTest()
conf.SetAuthMode(config.AuthModePasswd)
defer conf.SetAuthMode(config.AuthModePublic)
sess := AuthenticateUser(app, router, "alice", "Alice123!")
OAuthToken(router)
data := url.Values{
"grant_type": {authn.GrantPassword.String()},
"client_name": {"TestApp"},
"username": {"alice"},
"password": {"Alice123!"},
"scope": {"*"},
}
req, _ := http.NewRequest(http.MethodPost, "/api/v1/oauth/token", strings.NewReader(data.Encode()))
req.Header.Set(header.ContentType, header.ContentTypeForm)
req.Header.Set(header.XAuthToken, sess)
w := httptest.NewRecorder()
app.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
body := w.Body.String()
assert.NotEmpty(t, gjson.Get(body, "access_token").String())
tokType := gjson.Get(body, "token_type").String()
assert.True(t, strings.EqualFold(tokType, "bearer"))
assert.GreaterOrEqual(t, gjson.Get(body, "expires_in").Int(), int64(0))
assert.Equal(t, "*", gjson.Get(body, "scope").String())
}
func TestOAuthToken_BadRequestsAndErrors(t *testing.T) {
app, router, conf := NewApiTest()
conf.SetAuthMode(config.AuthModePasswd)
defer conf.SetAuthMode(config.AuthModePublic)
OAuthToken(router)
// Missing grant_type & creds -> invalid credentials
req, _ := http.NewRequest(http.MethodPost, "/api/v1/oauth/token", nil)
w := httptest.NewRecorder()
app.ServeHTTP(w, req)
assert.Equal(t, http.StatusBadRequest, w.Code)
// Unknown grant type
data := url.Values{
"grant_type": {"unknown"},
}
req, _ = http.NewRequest(http.MethodPost, "/api/v1/oauth/token", strings.NewReader(data.Encode()))
req.Header.Set(header.ContentType, header.ContentTypeForm)
w = httptest.NewRecorder()
app.ServeHTTP(w, req)
assert.Equal(t, http.StatusUnauthorized, w.Code)
// Password grant with wrong password
sess := AuthenticateUser(app, router, "alice", "Alice123!")
data = url.Values{
"grant_type": {authn.GrantPassword.String()},
"client_name": {"AppPasswordAlice"},
"username": {"alice"},
"password": {"WrongPassword!"},
"scope": {"*"},
}
req, _ = http.NewRequest(http.MethodPost, "/api/v1/oauth/token", strings.NewReader(data.Encode()))
req.Header.Set(header.ContentType, header.ContentTypeForm)
req.Header.Set(header.XAuthToken, sess)
w = httptest.NewRecorder()
app.ServeHTTP(w, req)
assert.Equal(t, http.StatusUnauthorized, w.Code)
}

View File

@@ -0,0 +1,44 @@
package api
import (
"net/http"
"testing"
"time"
"github.com/stretchr/testify/assert"
"golang.org/x/time/rate"
"github.com/photoprism/photoprism/internal/config"
"github.com/photoprism/photoprism/internal/server/limiter"
)
func TestCreateSession_RateLimitExceeded(t *testing.T) {
app, router, conf := NewApiTest()
conf.SetAuthMode(config.AuthModePasswd)
defer conf.SetAuthMode(config.AuthModePublic)
CreateSession(router)
// Tighten rate limits and do repeated bad logins from UnknownIP
oldLogin, oldAuth := limiter.Login, limiter.Auth
defer func() { limiter.Login, limiter.Auth = oldLogin, oldAuth }()
limiter.Login = limiter.NewLimit(rate.Every(24*time.Hour), 3)
limiter.Auth = limiter.NewLimit(rate.Every(24*time.Hour), 3)
for i := 0; i < 3; i++ {
r := PerformRequestWithBody(app, http.MethodPost, "/api/v1/session", `{"username": "admin", "password": "wrong"}`)
assert.Equal(t, http.StatusUnauthorized, r.Code)
}
// Next attempt should be 429
r := PerformRequestWithBody(app, http.MethodPost, "/api/v1/session", `{"username": "admin", "password": "wrong"}`)
assert.Equal(t, http.StatusTooManyRequests, r.Code)
}
func TestCreateSession_MissingFields(t *testing.T) {
app, router, conf := NewApiTest()
conf.SetAuthMode(config.AuthModePasswd)
defer conf.SetAuthMode(config.AuthModePublic)
CreateSession(router)
// Empty object -> unauthorized (invalid credentials)
r := PerformRequestWithBody(app, http.MethodPost, "/api/v1/session", `{}`)
assert.Equal(t, http.StatusUnauthorized, r.Code)
}

File diff suppressed because it is too large Load Diff

View File

@@ -37,7 +37,7 @@ import (
// @Param files formData file true "one or more files to upload (repeat the field for multiple files)"
// @Success 200 {object} i18n.Response
// @Failure 400,401,403,413,429,507 {object} i18n.Response
// @Router /users/{uid}/upload/{token} [post]
// @Router /api/v1/users/{uid}/upload/{token} [post]
func UploadUserFiles(router *gin.RouterGroup) {
router.POST("/users/:uid/upload/:token", func(c *gin.Context) {
conf := get.Config()
@@ -273,7 +273,7 @@ func UploadCheckFile(destName string, rejectRaw bool, totalSizeLimit int64) (rem
// @Param options body form.UploadOptions true "processing options"
// @Success 200 {object} i18n.Response
// @Failure 400,401,403,404,409,429 {object} i18n.Response
// @Router /users/{uid}/upload/{token} [put]
// @Router /api/v1/users/{uid}/upload/{token} [put]
func ProcessUserUpload(router *gin.RouterGroup) {
router.PUT("/users/:uid/upload/:token", func(c *gin.Context) {
s := AuthAny(c, acl.ResourceFiles, acl.Permissions{acl.ActionManage, acl.ActionUpload})

View File

@@ -0,0 +1,677 @@
package api
import (
"archive/zip"
"bytes"
"fmt"
"io"
"mime/multipart"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/photoprism/photoprism/internal/entity"
"github.com/photoprism/photoprism/pkg/service/http/header"
)
// buildMultipart builds a multipart form with one field name "files" and provided files.
func buildMultipart(files map[string][]byte) (body *bytes.Buffer, contentType string, err error) {
body = &bytes.Buffer{}
mw := multipart.NewWriter(body)
for name, data := range files {
fw, cerr := mw.CreateFormFile("files", name)
if cerr != nil {
return nil, "", cerr
}
if _, werr := fw.Write(data); werr != nil {
return nil, "", werr
}
}
cerr := mw.Close()
return body, mw.FormDataContentType(), cerr
}
// buildMultipartTwo builds a multipart form with exactly two files (same field name: "files").
func buildMultipartTwo(name1 string, data1 []byte, name2 string, data2 []byte) (body *bytes.Buffer, contentType string, err error) {
body = &bytes.Buffer{}
mw := multipart.NewWriter(body)
for _, it := range [][2]interface{}{{name1, data1}, {name2, data2}} {
fw, cerr := mw.CreateFormFile("files", it[0].(string))
if cerr != nil {
return nil, "", cerr
}
if _, werr := fw.Write(it[1].([]byte)); werr != nil {
return nil, "", werr
}
}
cerr := mw.Close()
return body, mw.FormDataContentType(), cerr
}
// buildZipWithDirsAndFiles creates a zip archive bytes with explicit directory entries and files.
func buildZipWithDirsAndFiles(dirs []string, files map[string][]byte) []byte {
var zbuf bytes.Buffer
zw := zip.NewWriter(&zbuf)
// Directories (ensure trailing slash)
for _, d := range dirs {
name := d
if !strings.HasSuffix(name, "/") {
name += "/"
}
_, _ = zw.Create(name)
}
// Files
for name, data := range files {
f, _ := zw.Create(name)
_, _ = f.Write(data)
}
_ = zw.Close()
return zbuf.Bytes()
}
func findUploadedFiles(t *testing.T, base string) []string {
t.Helper()
var out []string
_ = filepath.Walk(base, func(path string, info os.FileInfo, err error) error {
if err == nil && !info.IsDir() {
out = append(out, path)
}
return nil
})
return out
}
// findUploadedFilesForToken lists files only under upload subfolders whose name ends with token suffix.
func findUploadedFilesForToken(t *testing.T, base string, tokenSuffix string) []string {
t.Helper()
var out []string
entries, _ := os.ReadDir(base)
for _, e := range entries {
if !e.IsDir() {
continue
}
name := e.Name()
if !strings.HasSuffix(name, tokenSuffix) {
continue
}
dir := filepath.Join(base, name)
_ = filepath.Walk(dir, func(p string, info os.FileInfo, err error) error {
if err == nil && !info.IsDir() {
out = append(out, p)
}
return nil
})
}
return out
}
// removeUploadDirsForToken removes upload subdirectories whose name ends with tokenSuffix.
func removeUploadDirsForToken(t *testing.T, base string, tokenSuffix string) {
t.Helper()
entries, _ := os.ReadDir(base)
for _, e := range entries {
if !e.IsDir() {
continue
}
name := e.Name()
if strings.HasSuffix(name, tokenSuffix) {
_ = os.RemoveAll(filepath.Join(base, name))
}
}
}
func TestUploadUserFiles_Multipart_SingleJPEG(t *testing.T) {
app, router, conf := NewApiTest()
// Limit allowed upload extensions to ensure text files get rejected in tests
conf.Options().UploadAllow = "jpg"
UploadUserFiles(router)
token := AuthenticateAdmin(app, router)
adminUid := entity.Admin.UserUID
// Cleanup: remove token-specific upload dir after test
defer removeUploadDirsForToken(t, filepath.Join(conf.UserStoragePath(adminUid), "upload"), "abc123")
// Load a real tiny JPEG from testdata
jpgPath := filepath.Clean("../../pkg/fs/testdata/directory/example.jpg")
data, err := os.ReadFile(jpgPath)
if err != nil {
t.Skipf("missing example.jpg: %v", err)
}
body, ctype, err := buildMultipart(map[string][]byte{"example.jpg": data})
if err != nil {
t.Fatal(err)
}
reqUrl := "/api/v1/users/" + adminUid + "/upload/abc123"
req := httptest.NewRequest(http.MethodPost, reqUrl, body)
req.Header.Set("Content-Type", ctype)
header.SetAuthorization(req, token)
w := httptest.NewRecorder()
app.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code, w.Body.String())
// Verify file written somewhere under users/<uid>/upload/*
uploadBase := filepath.Join(conf.UserStoragePath(adminUid), "upload")
files := findUploadedFilesForToken(t, uploadBase, "abc123")
// At least one file written
assert.NotEmpty(t, files)
// Expect the filename to appear somewhere
var found bool
for _, f := range files {
if strings.HasSuffix(f, "example.jpg") {
found = true
break
}
}
assert.True(t, found, "uploaded JPEG not found")
}
func TestUploadUserFiles_Multipart_ZipExtract(t *testing.T) {
app, router, conf := NewApiTest()
// Allow archives and restrict allowed extensions to images
conf.Options().UploadArchives = true
conf.Options().UploadAllow = "jpg,png,zip"
UploadUserFiles(router)
token := AuthenticateAdmin(app, router)
adminUid := entity.Admin.UserUID
// Cleanup after test
defer removeUploadDirsForToken(t, filepath.Join(conf.UserStoragePath(adminUid), "upload"), "ziptok")
// Create an in-memory zip with one JPEG (valid) and one TXT (rejected)
jpgPath := filepath.Clean("../../pkg/fs/testdata/directory/example.jpg")
jpg, err := os.ReadFile(jpgPath)
if err != nil {
t.Skip("missing example.jpg")
}
var zbuf bytes.Buffer
zw := zip.NewWriter(&zbuf)
// add jpeg
jf, _ := zw.Create("a.jpg")
_, _ = jf.Write(jpg)
// add txt
tf, _ := zw.Create("note.txt")
_, _ = io.WriteString(tf, "hello")
_ = zw.Close()
body, ctype, err := buildMultipart(map[string][]byte{"upload.zip": zbuf.Bytes()})
if err != nil {
t.Fatal(err)
}
reqUrl := "/api/v1/users/" + adminUid + "/upload/zipoff"
req := httptest.NewRequest(http.MethodPost, reqUrl, body)
req.Header.Set("Content-Type", ctype)
header.SetAuthorization(req, token)
w := httptest.NewRecorder()
app.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code, w.Body.String())
uploadBase := filepath.Join(conf.UserStoragePath(adminUid), "upload")
files := findUploadedFilesForToken(t, uploadBase, "zipoff")
// Expect extracted jpeg present and txt absent
var jpgFound, txtFound bool
for _, f := range files {
if strings.HasSuffix(f, "a.jpg") {
jpgFound = true
}
if strings.HasSuffix(f, "note.txt") {
txtFound = true
}
}
assert.True(t, jpgFound, "extracted jpeg not found")
assert.False(t, txtFound, "text file should be rejected")
}
func TestUploadUserFiles_Multipart_ArchivesDisabled(t *testing.T) {
app, router, conf := NewApiTest()
// disallow archives while allowing the .zip extension in filter
conf.Options().UploadArchives = false
conf.Options().UploadAllow = "jpg,zip"
UploadUserFiles(router)
token := AuthenticateAdmin(app, router)
adminUid := entity.Admin.UserUID
// Cleanup after test
defer removeUploadDirsForToken(t, filepath.Join(conf.UserStoragePath(adminUid), "upload"), "zipoff")
// zip with one jpeg inside
jpgPath := filepath.Clean("../../pkg/fs/testdata/directory/example.jpg")
jpg, err := os.ReadFile(jpgPath)
if err != nil {
t.Skip("missing example.jpg")
}
var zbuf bytes.Buffer
zw := zip.NewWriter(&zbuf)
jf, _ := zw.Create("a.jpg")
_, _ = jf.Write(jpg)
_ = zw.Close()
body, ctype, err := buildMultipart(map[string][]byte{"upload.zip": zbuf.Bytes()})
if err != nil {
t.Fatal(err)
}
reqUrl := "/api/v1/users/" + adminUid + "/upload/ziptok"
req := httptest.NewRequest(http.MethodPost, reqUrl, body)
req.Header.Set("Content-Type", ctype)
header.SetAuthorization(req, token)
w := httptest.NewRecorder()
app.ServeHTTP(w, req)
// server returns 200 even if rejected internally; nothing extracted/saved
assert.Equal(t, http.StatusOK, w.Code, w.Body.String())
uploadBase := filepath.Join(conf.UserStoragePath(adminUid), "upload")
files := findUploadedFilesForToken(t, uploadBase, "ziptok")
assert.Empty(t, files, "no files should remain when archives disabled")
}
func TestUploadUserFiles_Multipart_PerFileLimitExceeded(t *testing.T) {
app, router, conf := NewApiTest()
conf.Options().UploadAllow = "jpg"
conf.Options().OriginalsLimit = 1 // 1 MiB per-file
UploadUserFiles(router)
token := AuthenticateAdmin(app, router)
adminUid := entity.Admin.UserUID
defer removeUploadDirsForToken(t, filepath.Join(conf.UserStoragePath(adminUid), "upload"), "size1")
// Build a 2MiB dummy payload (not a real JPEG; that's fine for pre-save size check)
big := bytes.Repeat([]byte("A"), 2*1024*1024)
body, ctype, err := buildMultipart(map[string][]byte{"big.jpg": big})
if err != nil {
t.Fatal(err)
}
req := httptest.NewRequest(http.MethodPost, "/api/v1/users/"+adminUid+"/upload/size1", body)
req.Header.Set("Content-Type", ctype)
header.SetAuthorization(req, token)
w := httptest.NewRecorder()
app.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
// Ensure nothing saved
files := findUploadedFilesForToken(t, filepath.Join(conf.UserStoragePath(adminUid), "upload"), "size1")
assert.Empty(t, files)
}
func TestUploadUserFiles_Multipart_TotalLimitExceeded(t *testing.T) {
app, router, conf := NewApiTest()
conf.Options().UploadAllow = "jpg"
conf.Options().UploadLimit = 1 // 1 MiB total
UploadUserFiles(router)
token := AuthenticateAdmin(app, router)
adminUid := entity.Admin.UserUID
defer removeUploadDirsForToken(t, filepath.Join(conf.UserStoragePath(adminUid), "upload"), "total")
data, err := os.ReadFile(filepath.Clean("../../pkg/fs/testdata/directory/example.jpg"))
if err != nil {
t.Skip("missing example.jpg")
}
// build multipart with two images so sum > 1 MiB (2*~63KiB = ~126KiB) -> still <1MiB, so use 16 copies
// build two bigger bodies by concatenation
times := 9
big1 := bytes.Repeat(data, times)
big2 := bytes.Repeat(data, times)
body, ctype, err := buildMultipartTwo("a.jpg", big1, "b.jpg", big2)
if err != nil {
t.Fatal(err)
}
req := httptest.NewRequest(http.MethodPost, "/api/v1/users/"+adminUid+"/upload/total", body)
req.Header.Set("Content-Type", ctype)
header.SetAuthorization(req, token)
w := httptest.NewRecorder()
app.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
// Expect at most one file saved (second should be rejected by total limit)
files := findUploadedFilesForToken(t, filepath.Join(conf.UserStoragePath(adminUid), "upload"), "total")
assert.LessOrEqual(t, len(files), 1)
}
func TestUploadUserFiles_Multipart_ZipPartialExtraction(t *testing.T) {
app, router, conf := NewApiTest()
conf.Options().UploadArchives = true
conf.Options().UploadAllow = "jpg,zip"
conf.Options().UploadLimit = 1 // 1 MiB total
conf.Options().OriginalsLimit = 50 // 50 MiB per file
conf.Options().UploadNSFW = true // skip nsfw scanning to speed up test
UploadUserFiles(router)
token := AuthenticateAdmin(app, router)
adminUid := entity.Admin.UserUID
defer removeUploadDirsForToken(t, filepath.Join(conf.UserStoragePath(adminUid), "upload"), "partial")
// Build a zip containing multiple JPEG entries so that total extracted size > 1 MiB
data, err := os.ReadFile(filepath.Clean("../../pkg/fs/testdata/directory/example.jpg"))
if err != nil {
t.Skip("missing example.jpg")
}
var zbuf bytes.Buffer
zw := zip.NewWriter(&zbuf)
for i := 0; i < 20; i++ { // ~20 * 63 KiB ≈ 1.2 MiB
f, _ := zw.Create(fmt.Sprintf("pic%02d.jpg", i+1))
_, _ = f.Write(data)
}
_ = zw.Close()
body, ctype, err := buildMultipart(map[string][]byte{"multi.zip": zbuf.Bytes()})
if err != nil {
t.Fatal(err)
}
req := httptest.NewRequest(http.MethodPost, "/api/v1/users/"+adminUid+"/upload/partial", body)
req.Header.Set("Content-Type", ctype)
header.SetAuthorization(req, token)
w := httptest.NewRecorder()
app.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
files := findUploadedFilesForToken(t, filepath.Join(conf.UserStoragePath(adminUid), "upload"), "partial")
// At least one extracted, but not all 20 due to total limit
var countJPG int
for _, f := range files {
if strings.HasSuffix(f, ".jpg") {
countJPG++
}
}
assert.GreaterOrEqual(t, countJPG, 1)
assert.Less(t, countJPG, 20)
}
func TestUploadUserFiles_Multipart_ZipDeepNestingStress(t *testing.T) {
app, router, conf := NewApiTest()
conf.Options().UploadArchives = true
conf.Options().UploadAllow = "jpg,zip"
conf.Options().UploadNSFW = true
UploadUserFiles(router)
token := AuthenticateAdmin(app, router)
adminUid := entity.Admin.UserUID
defer removeUploadDirsForToken(t, filepath.Join(conf.UserStoragePath(adminUid), "upload"), "zipdeep")
data, err := os.ReadFile(filepath.Clean("../../pkg/fs/testdata/directory/example.jpg"))
if err != nil {
t.Skip("missing example.jpg")
}
// Build a deeply nested path (20 levels)
deep := ""
for i := 0; i < 20; i++ {
if i == 0 {
deep = "deep"
} else {
deep = filepath.Join(deep, fmt.Sprintf("lvl%02d", i))
}
}
name := filepath.Join(deep, "deep.jpg")
zbytes := buildZipWithDirsAndFiles(nil, map[string][]byte{name: data})
body, ctype, err := buildMultipart(map[string][]byte{"deepnest.zip": zbytes})
if err != nil {
t.Fatal(err)
}
req := httptest.NewRequest(http.MethodPost, "/api/v1/users/"+adminUid+"/upload/zipdeep", body)
req.Header.Set("Content-Type", ctype)
header.SetAuthorization(req, token)
w := httptest.NewRecorder()
app.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
base := filepath.Join(conf.UserStoragePath(adminUid), "upload")
files := findUploadedFilesForToken(t, base, "zipdeep")
// Only one file expected, deep path created
assert.Equal(t, 1, len(files))
assert.True(t, strings.Contains(files[0], filepath.Join("deep", "lvl01")))
}
func TestUploadUserFiles_Multipart_ZipRejectsHiddenAndTraversal(t *testing.T) {
app, router, conf := NewApiTest()
conf.Options().UploadArchives = true
conf.Options().UploadAllow = "jpg,zip"
conf.Options().UploadNSFW = true // skip scanning
UploadUserFiles(router)
token := AuthenticateAdmin(app, router)
adminUid := entity.Admin.UserUID
defer removeUploadDirsForToken(t, filepath.Join(conf.UserStoragePath(adminUid), "upload"), "rejects")
// Prepare a valid jpg payload
data, err := os.ReadFile(filepath.Clean("../../pkg/fs/testdata/directory/example.jpg"))
if err != nil {
t.Skip("missing example.jpg")
}
var zbuf bytes.Buffer
zw := zip.NewWriter(&zbuf)
// Hidden file
f1, _ := zw.Create(".hidden.jpg")
_, _ = f1.Write(data)
// @ file
f2, _ := zw.Create("@meta.jpg")
_, _ = f2.Write(data)
// Traversal path (will be skipped by safe join in unzip)
f3, _ := zw.Create("dir/../traverse.jpg")
_, _ = f3.Write(data)
// Valid file
f4, _ := zw.Create("ok.jpg")
_, _ = f4.Write(data)
_ = zw.Close()
body, ctype, err := buildMultipart(map[string][]byte{"test.zip": zbuf.Bytes()})
if err != nil {
t.Fatal(err)
}
req := httptest.NewRequest(http.MethodPost, "/api/v1/users/"+adminUid+"/upload/rejects", body)
req.Header.Set("Content-Type", ctype)
header.SetAuthorization(req, token)
w := httptest.NewRecorder()
app.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
files := findUploadedFilesForToken(t, filepath.Join(conf.UserStoragePath(adminUid), "upload"), "rejects")
var hasOk, hasHidden, hasAt, hasTraverse bool
for _, f := range files {
if strings.HasSuffix(f, "ok.jpg") {
hasOk = true
}
if strings.HasSuffix(f, ".hidden.jpg") {
hasHidden = true
}
if strings.HasSuffix(f, "@meta.jpg") {
hasAt = true
}
if strings.HasSuffix(f, "traverse.jpg") {
hasTraverse = true
}
}
assert.True(t, hasOk)
assert.False(t, hasHidden)
assert.False(t, hasAt)
assert.False(t, hasTraverse)
}
func TestUploadUserFiles_Multipart_ZipNestedDirectories(t *testing.T) {
app, router, conf := NewApiTest()
conf.Options().UploadArchives = true
conf.Options().UploadAllow = "jpg,zip"
conf.Options().UploadNSFW = true
UploadUserFiles(router)
token := AuthenticateAdmin(app, router)
adminUid := entity.Admin.UserUID
defer removeUploadDirsForToken(t, filepath.Join(conf.UserStoragePath(adminUid), "upload"), "zipnest")
data, err := os.ReadFile(filepath.Clean("../../pkg/fs/testdata/directory/example.jpg"))
if err != nil {
t.Skip("missing example.jpg")
}
// Create nested dirs and files
dirs := []string{"nested", "nested/sub"}
files := map[string][]byte{
"nested/a.jpg": data,
"nested/sub/b.jpg": data,
}
zbytes := buildZipWithDirsAndFiles(dirs, files)
body, ctype, err := buildMultipart(map[string][]byte{"nested.zip": zbytes})
if err != nil {
t.Fatal(err)
}
req := httptest.NewRequest(http.MethodPost, "/api/v1/users/"+adminUid+"/upload/zipnest", body)
req.Header.Set("Content-Type", ctype)
header.SetAuthorization(req, token)
w := httptest.NewRecorder()
app.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
base := filepath.Join(conf.UserStoragePath(adminUid), "upload")
filesOut := findUploadedFilesForToken(t, base, "zipnest")
var haveA, haveB bool
for _, f := range filesOut {
if strings.HasSuffix(f, filepath.Join("nested", "a.jpg")) {
haveA = true
}
if strings.HasSuffix(f, filepath.Join("nested", "sub", "b.jpg")) {
haveB = true
}
}
assert.True(t, haveA)
assert.True(t, haveB)
// Directories exist
// Locate token dir
entries, _ := os.ReadDir(base)
var tokenDir string
for _, e := range entries {
if e.IsDir() && strings.HasSuffix(e.Name(), "zipnest") {
tokenDir = filepath.Join(base, e.Name())
break
}
}
if tokenDir != "" {
_, errA := os.Stat(filepath.Join(tokenDir, "nested"))
_, errB := os.Stat(filepath.Join(tokenDir, "nested", "sub"))
assert.NoError(t, errA)
assert.NoError(t, errB)
} else {
t.Fatalf("token dir not found under %s", base)
}
}
func TestUploadUserFiles_Multipart_ZipImplicitDirectories(t *testing.T) {
app, router, conf := NewApiTest()
conf.Options().UploadArchives = true
conf.Options().UploadAllow = "jpg,zip"
conf.Options().UploadNSFW = true
UploadUserFiles(router)
token := AuthenticateAdmin(app, router)
adminUid := entity.Admin.UserUID
defer removeUploadDirsForToken(t, filepath.Join(conf.UserStoragePath(adminUid), "upload"), "zipimpl")
data, err := os.ReadFile(filepath.Clean("../../pkg/fs/testdata/directory/example.jpg"))
if err != nil {
t.Skip("missing example.jpg")
}
// Create zip containing only files with nested paths (no explicit directory entries)
var zbuf bytes.Buffer
zw := zip.NewWriter(&zbuf)
f1, _ := zw.Create(filepath.Join("nested", "a.jpg"))
_, _ = f1.Write(data)
f2, _ := zw.Create(filepath.Join("nested", "sub", "b.jpg"))
_, _ = f2.Write(data)
_ = zw.Close()
body, ctype, err := buildMultipart(map[string][]byte{"nested-files-only.zip": zbuf.Bytes()})
if err != nil {
t.Fatal(err)
}
req := httptest.NewRequest(http.MethodPost, "/api/v1/users/"+adminUid+"/upload/zipimpl", body)
req.Header.Set("Content-Type", ctype)
header.SetAuthorization(req, token)
w := httptest.NewRecorder()
app.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
base := filepath.Join(conf.UserStoragePath(adminUid), "upload")
files := findUploadedFilesForToken(t, base, "zipimpl")
var haveA, haveB bool
for _, f := range files {
if strings.HasSuffix(f, filepath.Join("nested", "a.jpg")) {
haveA = true
}
if strings.HasSuffix(f, filepath.Join("nested", "sub", "b.jpg")) {
haveB = true
}
}
assert.True(t, haveA)
assert.True(t, haveB)
// Confirm directories were implicitly created
entries, _ := os.ReadDir(base)
var tokenDir string
for _, e := range entries {
if e.IsDir() && strings.HasSuffix(e.Name(), "zipimpl") {
tokenDir = filepath.Join(base, e.Name())
break
}
}
if tokenDir == "" {
t.Fatalf("token dir not found under %s", base)
}
_, errA := os.Stat(filepath.Join(tokenDir, "nested"))
_, errB := os.Stat(filepath.Join(tokenDir, "nested", "sub"))
assert.NoError(t, errA)
assert.NoError(t, errB)
}
func TestUploadUserFiles_Multipart_ZipAbsolutePathRejected(t *testing.T) {
app, router, conf := NewApiTest()
conf.Options().UploadArchives = true
conf.Options().UploadAllow = "jpg,zip"
conf.Options().UploadNSFW = true
UploadUserFiles(router)
token := AuthenticateAdmin(app, router)
adminUid := entity.Admin.UserUID
defer removeUploadDirsForToken(t, filepath.Join(conf.UserStoragePath(adminUid), "upload"), "zipabs")
data, err := os.ReadFile(filepath.Clean("../../pkg/fs/testdata/directory/example.jpg"))
if err != nil {
t.Skip("missing example.jpg")
}
// Zip with an absolute path entry
var zbuf bytes.Buffer
zw := zip.NewWriter(&zbuf)
f, _ := zw.Create("/abs.jpg")
_, _ = f.Write(data)
_ = zw.Close()
body, ctype, err := buildMultipart(map[string][]byte{"abs.zip": zbuf.Bytes()})
if err != nil {
t.Fatal(err)
}
req := httptest.NewRequest(http.MethodPost, "/api/v1/users/"+adminUid+"/upload/zipabs", body)
req.Header.Set("Content-Type", ctype)
header.SetAuthorization(req, token)
w := httptest.NewRecorder()
app.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
// No files should be extracted/saved for this token
base := filepath.Join(conf.UserStoragePath(adminUid), "upload")
files := findUploadedFilesForToken(t, base, "zipabs")
assert.Empty(t, files)
}

View File

@@ -3,6 +3,8 @@ package api
import (
"fmt"
"net/http"
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
@@ -43,3 +45,64 @@ func TestUploadUserFiles(t *testing.T) {
config.Options().FilesQuota = 0
})
}
func TestUploadCheckFile_AcceptsAndReducesLimit(t *testing.T) {
dir := t.TempDir()
// Copy a small known-good JPEG test file from pkg/fs/testdata
src := filepath.Clean("../../pkg/fs/testdata/directory/example.jpg")
dst := filepath.Join(dir, "example.jpg")
b, err := os.ReadFile(src)
if err != nil {
t.Skipf("skip if test asset not present: %v", err)
}
if err := os.WriteFile(dst, b, 0o600); err != nil {
t.Fatal(err)
}
orig := int64(len(b))
rem, err := UploadCheckFile(dst, false, orig+100)
assert.NoError(t, err)
assert.Equal(t, int64(100), rem)
// file remains
assert.FileExists(t, dst)
}
func TestUploadCheckFile_TotalLimitReachedDeletes(t *testing.T) {
dir := t.TempDir()
// Make a tiny file
dst := filepath.Join(dir, "tiny.txt")
assert.NoError(t, os.WriteFile(dst, []byte("hello"), 0o600))
// Very small total limit (0) → should remove file and error
_, err := UploadCheckFile(dst, false, 0)
assert.Error(t, err)
_, statErr := os.Stat(dst)
assert.True(t, os.IsNotExist(statErr), "file should be removed when limit reached")
}
func TestUploadCheckFile_UnsupportedTypeDeletes(t *testing.T) {
dir := t.TempDir()
// Create a file with an unknown extension; should be rejected
dst := filepath.Join(dir, "unknown.xyz")
assert.NoError(t, os.WriteFile(dst, []byte("not-an-image"), 0o600))
_, err := UploadCheckFile(dst, false, 1<<20)
assert.Error(t, err)
_, statErr := os.Stat(dst)
assert.True(t, os.IsNotExist(statErr), "unsupported file should be removed")
}
func TestUploadCheckFile_SizeAccounting(t *testing.T) {
dir := t.TempDir()
// Use known-good JPEG
src := filepath.Clean("../../pkg/fs/testdata/directory/example.jpg")
data, err := os.ReadFile(src)
if err != nil {
t.Skip("asset missing; skip")
}
f := filepath.Join(dir, "a.jpg")
assert.NoError(t, os.WriteFile(f, data, 0o600))
size := int64(len(data))
// Set remaining limit to size+1 so it does not hit the removal branch (which triggers on <=0)
rem, err := UploadCheckFile(f, false, size+1)
assert.NoError(t, err)
assert.Equal(t, int64(1), rem)
}

View File

@@ -152,9 +152,24 @@ func WebDAVFileName(request *http.Request, router *gin.RouterGroup, conf *config
// Determine the absolute file path based on the request URL and the configuration.
switch basePath {
case conf.BaseUri(WebDAVOriginals):
fileName = filepath.Join(conf.OriginalsPath(), strings.TrimPrefix(request.URL.Path, basePath))
// Resolve the requested path safely under OriginalsPath.
rel := strings.TrimPrefix(request.URL.Path, basePath)
// Make relative if a leading slash remains after trimming the base.
rel = strings.TrimLeft(rel, "/\\")
if name, err := joinUnderBase(conf.OriginalsPath(), rel); err == nil {
fileName = name
} else {
return ""
}
case conf.BaseUri(WebDAVImport):
fileName = filepath.Join(conf.ImportPath(), strings.TrimPrefix(request.URL.Path, basePath))
// Resolve the requested path safely under ImportPath.
rel := strings.TrimPrefix(request.URL.Path, basePath)
rel = strings.TrimLeft(rel, "/\\")
if name, err := joinUnderBase(conf.ImportPath(), rel); err == nil {
fileName = name
} else {
return ""
}
default:
return ""
}
@@ -167,6 +182,27 @@ func WebDAVFileName(request *http.Request, router *gin.RouterGroup, conf *config
return fileName
}
// joinUnderBase joins a base directory with a relative name and ensures
// that the resulting path stays within the base directory. Absolute
// paths and Windows-style volume names are rejected.
func joinUnderBase(baseDir, rel string) (string, error) {
if rel == "" {
return "", fmt.Errorf("invalid path")
}
// Reject absolute or volume paths.
if filepath.IsAbs(rel) || filepath.VolumeName(rel) != "" {
return "", fmt.Errorf("invalid path: absolute or volume path not allowed")
}
cleaned := filepath.Clean(rel)
// Compose destination and verify it stays inside base.
dest := filepath.Join(baseDir, cleaned)
base := filepath.Clean(baseDir)
if dest != base && !strings.HasPrefix(dest, base+string(os.PathSeparator)) {
return "", fmt.Errorf("invalid path: outside base directory")
}
return dest, nil
}
// WebDAVSetFavoriteFlag adds the favorite flag to files uploaded via WebDAV.
func WebDAVSetFavoriteFlag(fileName string) {
yamlName := fs.AbsPrefix(fileName, false) + fs.ExtYml

View File

@@ -0,0 +1,37 @@
package server
import (
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
)
func TestWebDAVSetFavoriteFlag_CreatesYamlOnce(t *testing.T) {
dir := t.TempDir()
file := filepath.Join(dir, "img.jpg")
assert.NoError(t, os.WriteFile(file, []byte("x"), 0o600))
// First call creates YAML
WebDAVSetFavoriteFlag(file)
// YAML is written next to file without the media extension (AbsPrefix)
yml := filepath.Join(filepath.Dir(file), "img.yml")
assert.FileExists(t, yml)
// Write a marker and ensure second call doesn't overwrite content
orig, _ := os.ReadFile(yml)
WebDAVSetFavoriteFlag(file)
now, _ := os.ReadFile(yml)
assert.Equal(t, string(orig), string(now))
}
func TestWebDAVSetFileMtime_NoFuture(t *testing.T) {
dir := t.TempDir()
file := filepath.Join(dir, "a.txt")
assert.NoError(t, os.WriteFile(file, []byte("x"), 0o600))
// Set a past mtime
WebDAVSetFileMtime(file, 946684800) // 2000-01-01 UTC
after, _ := os.Stat(file)
// Compare seconds to avoid platform-specific rounding
got := after.ModTime().Unix()
assert.Equal(t, int64(946684800), got)
}

View File

@@ -0,0 +1,90 @@
package server
import (
"net/http"
"net/url"
"os"
"path/filepath"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/photoprism/photoprism/internal/config"
"github.com/photoprism/photoprism/pkg/fs"
)
func TestJoinUnderBase(t *testing.T) {
base := t.TempDir()
// Normal join
out, err := joinUnderBase(base, "a/b/c.txt")
assert.NoError(t, err)
assert.Equal(t, filepath.Join(base, "a/b/c.txt"), out)
// Absolute rejected
_, err = joinUnderBase(base, "/etc/passwd")
assert.Error(t, err)
// Parent traversal rejected
_, err = joinUnderBase(base, "../../etc/passwd")
assert.Error(t, err)
}
func TestWebDAVFileName_PathTraversalRejected(t *testing.T) {
dir := t.TempDir()
// Create a legitimate file inside base to ensure happy-path works later.
insideFile := filepath.Join(dir, "ok.txt")
assert.NoError(t, fs.WriteString(insideFile, "ok"))
conf := config.NewTestConfig("server-webdav")
conf.Options().OriginalsPath = dir
r := gin.New()
grp := r.Group(conf.BaseUri(WebDAVOriginals))
// Attempt traversal to outside path.
req := &http.Request{Method: http.MethodPut}
req.URL = &url.URL{Path: conf.BaseUri(WebDAVOriginals) + "/../../etc/passwd"}
got := WebDAVFileName(req, grp, conf)
assert.Equal(t, "", got, "should reject traversal")
// Happy path: file under base resolves and exists.
req2 := &http.Request{Method: http.MethodPut}
req2.URL = &url.URL{Path: conf.BaseUri(WebDAVOriginals) + "/ok.txt"}
got = WebDAVFileName(req2, grp, conf)
assert.Equal(t, insideFile, got)
}
func TestWebDAVFileName_MethodNotPut(t *testing.T) {
conf := config.NewTestConfig("server-webdav")
r := gin.New()
grp := r.Group(conf.BaseUri(WebDAVOriginals))
req := &http.Request{Method: http.MethodGet}
req.URL = &url.URL{Path: conf.BaseUri(WebDAVOriginals) + "/anything.jpg"}
got := WebDAVFileName(req, grp, conf)
assert.Equal(t, "", got)
}
func TestWebDAVFileName_ImportBasePath(t *testing.T) {
conf := config.NewTestConfig("server-webdav")
r := gin.New()
grp := r.Group(conf.BaseUri(WebDAVImport))
// create a real file under import
file := filepath.Join(conf.ImportPath(), "in.jpg")
assert.NoError(t, fs.MkdirAll(filepath.Dir(file)))
assert.NoError(t, fs.WriteString(file, "x"))
req := &http.Request{Method: http.MethodPut}
req.URL = &url.URL{Path: conf.BaseUri(WebDAVImport) + "/in.jpg"}
got := WebDAVFileName(req, grp, conf)
assert.Equal(t, file, got)
}
func TestWebDAVSetFileMtime_FutureIgnored(t *testing.T) {
dir := t.TempDir()
file := filepath.Join(dir, "a.txt")
assert.NoError(t, fs.WriteString(file, "x"))
before, _ := os.Stat(file)
future := time.Now().Add(2 * time.Hour).Unix()
WebDAVSetFileMtime(file, future)
after, _ := os.Stat(file)
assert.Equal(t, before.ModTime().Unix(), after.ModTime().Unix())
}

View File

@@ -0,0 +1,298 @@
package server
import (
"bytes"
"encoding/base64"
"fmt"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/photoprism/photoprism/internal/config"
"github.com/photoprism/photoprism/internal/entity"
"github.com/photoprism/photoprism/pkg/service/http/header"
)
func setupWebDAVRouter(conf *config.Config) *gin.Engine {
gin.SetMode(gin.TestMode)
r := gin.New()
grp := r.Group(conf.BaseUri(WebDAVOriginals), WebDAVAuth(conf))
WebDAV(conf.OriginalsPath(), grp, conf)
return r
}
func authBearer(req *http.Request) {
sess := entity.SessionFixtures.Get("alice_token_webdav")
header.SetAuthorization(req, sess.AuthToken())
}
func authBasic(req *http.Request) {
sess := entity.SessionFixtures.Get("alice_token_webdav")
basic := []byte(fmt.Sprintf("alice:%s", sess.AuthToken()))
req.Header.Set(header.Auth, fmt.Sprintf("%s %s", header.AuthBasic, base64.StdEncoding.EncodeToString(basic)))
}
func TestWebDAVWrite_MKCOL_PUT(t *testing.T) {
conf := config.TestConfig()
r := setupWebDAVRouter(conf)
// MKCOL
w := httptest.NewRecorder()
req := httptest.NewRequest(MethodMkcol, conf.BaseUri(WebDAVOriginals)+"/wdvdir", nil)
authBearer(req)
r.ServeHTTP(w, req)
assert.InDelta(t, 201, w.Code, 1) // Created
// PUT file
w = httptest.NewRecorder()
req = httptest.NewRequest(MethodPut, conf.BaseUri(WebDAVOriginals)+"/wdvdir/hello.txt", bytes.NewBufferString("hello"))
authBearer(req)
r.ServeHTTP(w, req)
assert.InDelta(t, 201, w.Code, 1)
// file exists
path := filepath.Join(conf.OriginalsPath(), "wdvdir", "hello.txt")
b, err := os.ReadFile(path)
assert.NoError(t, err)
assert.Equal(t, "hello", string(b))
}
func TestWebDAVWrite_MOVE_COPY(t *testing.T) {
conf := config.TestConfig()
r := setupWebDAVRouter(conf)
// Ensure source and destination directories via MKCOL
w := httptest.NewRecorder()
req := httptest.NewRequest(MethodMkcol, conf.BaseUri(WebDAVOriginals)+"/src", nil)
authBasic(req)
r.ServeHTTP(w, req)
assert.InDelta(t, 201, w.Code, 1)
w = httptest.NewRecorder()
req = httptest.NewRequest(MethodMkcol, conf.BaseUri(WebDAVOriginals)+"/dst", nil)
authBasic(req)
r.ServeHTTP(w, req)
assert.InDelta(t, 201, w.Code, 1)
// Create source file via PUT
w = httptest.NewRecorder()
req = httptest.NewRequest(MethodPut, conf.BaseUri(WebDAVOriginals)+"/src/a.txt", bytes.NewBufferString("A"))
authBasic(req)
r.ServeHTTP(w, req)
assert.InDelta(t, 201, w.Code, 1)
// MOVE /originals/src/a.txt -> /originals/dst/b.txt
w = httptest.NewRecorder()
req = httptest.NewRequest(MethodMove, conf.BaseUri(WebDAVOriginals)+"/src/a.txt", nil)
req.Header.Set("Destination", conf.BaseUri(WebDAVOriginals)+"/dst/b.txt")
authBasic(req)
r.ServeHTTP(w, req)
assert.InDelta(t, 201, w.Code, 1)
// Verify moved
assert.NoFileExists(t, filepath.Join(conf.OriginalsPath(), "src", "a.txt"))
assert.FileExists(t, filepath.Join(conf.OriginalsPath(), "dst", "b.txt"))
// COPY /originals/dst/b.txt -> /originals/dst/c.txt
w = httptest.NewRecorder()
req = httptest.NewRequest(MethodCopy, conf.BaseUri(WebDAVOriginals)+"/dst/b.txt", nil)
req.Header.Set("Destination", conf.BaseUri(WebDAVOriginals)+"/dst/c.txt")
authBasic(req)
r.ServeHTTP(w, req)
assert.InDelta(t, 201, w.Code, 1)
// Verify copy
assert.FileExists(t, filepath.Join(conf.OriginalsPath(), "dst", "b.txt"))
assert.FileExists(t, filepath.Join(conf.OriginalsPath(), "dst", "c.txt"))
}
func TestWebDAVWrite_OverwriteSemantics(t *testing.T) {
conf := config.TestConfig()
r := setupWebDAVRouter(conf)
// Prepare src and dst
_ = os.MkdirAll(filepath.Join(conf.OriginalsPath(), "src"), 0o700)
_ = os.MkdirAll(filepath.Join(conf.OriginalsPath(), "dst"), 0o700)
_ = os.WriteFile(filepath.Join(conf.OriginalsPath(), "src", "f.txt"), []byte("NEW"), 0o600)
_ = os.WriteFile(filepath.Join(conf.OriginalsPath(), "dst", "f.txt"), []byte("OLD"), 0o600)
// COPY with Overwrite: F -> should not overwrite existing
w := httptest.NewRecorder()
req := httptest.NewRequest(MethodCopy, conf.BaseUri(WebDAVOriginals)+"/src/f.txt", nil)
req.Header.Set("Destination", conf.BaseUri(WebDAVOriginals)+"/dst/f.txt")
req.Header.Set("Overwrite", "F")
authBasic(req)
r.ServeHTTP(w, req)
// Expect not successful (commonly 412 Precondition Failed)
if w.Code == 201 || w.Code == 204 {
t.Fatalf("expected failure when Overwrite=F, got %d", w.Code)
}
// Content remains OLD
b, _ := os.ReadFile(filepath.Join(conf.OriginalsPath(), "dst", "f.txt"))
assert.Equal(t, "OLD", string(b))
// COPY with Overwrite: T -> must overwrite
w = httptest.NewRecorder()
req = httptest.NewRequest(MethodCopy, conf.BaseUri(WebDAVOriginals)+"/src/f.txt", nil)
req.Header.Set("Destination", conf.BaseUri(WebDAVOriginals)+"/dst/f.txt")
req.Header.Set("Overwrite", "T")
authBasic(req)
r.ServeHTTP(w, req)
// Success (201/204 acceptable)
if !(w.Code == 201 || w.Code == 204) {
t.Fatalf("expected success for Overwrite=T, got %d", w.Code)
}
b, _ = os.ReadFile(filepath.Join(conf.OriginalsPath(), "dst", "f.txt"))
assert.Equal(t, "NEW", string(b))
// MOVE with Overwrite: F to existing file -> expect failure
_ = os.WriteFile(filepath.Join(conf.OriginalsPath(), "src", "g.txt"), []byte("GNEW"), 0o600)
_ = os.WriteFile(filepath.Join(conf.OriginalsPath(), "dst", "g.txt"), []byte("GOLD"), 0o600)
w = httptest.NewRecorder()
req = httptest.NewRequest(MethodMove, conf.BaseUri(WebDAVOriginals)+"/src/g.txt", nil)
req.Header.Set("Destination", conf.BaseUri(WebDAVOriginals)+"/dst/g.txt")
req.Header.Set("Overwrite", "F")
authBasic(req)
r.ServeHTTP(w, req)
if w.Code == 201 || w.Code == 204 {
t.Fatalf("expected failure when Overwrite=F for MOVE, got %d", w.Code)
}
// MOVE with Overwrite: T -> overwrites and removes source
w = httptest.NewRecorder()
req = httptest.NewRequest(MethodMove, conf.BaseUri(WebDAVOriginals)+"/src/g.txt", nil)
req.Header.Set("Destination", conf.BaseUri(WebDAVOriginals)+"/dst/g.txt")
req.Header.Set("Overwrite", "T")
authBasic(req)
r.ServeHTTP(w, req)
if !(w.Code == 201 || w.Code == 204) {
t.Fatalf("expected success for MOVE Overwrite=T, got %d", w.Code)
}
assert.NoFileExists(t, filepath.Join(conf.OriginalsPath(), "src", "g.txt"))
gb, _ := os.ReadFile(filepath.Join(conf.OriginalsPath(), "dst", "g.txt"))
assert.Equal(t, "GNEW", string(gb))
}
func TestWebDAVWrite_MoveMissingDestination(t *testing.T) {
conf := config.TestConfig()
r := setupWebDAVRouter(conf)
// Ensure src exists
_ = os.MkdirAll(filepath.Join(conf.OriginalsPath(), "mv"), 0o700)
_ = os.WriteFile(filepath.Join(conf.OriginalsPath(), "mv", "file.txt"), []byte("X"), 0o600)
w := httptest.NewRecorder()
req := httptest.NewRequest(MethodMove, conf.BaseUri(WebDAVOriginals)+"/mv/file.txt", nil)
// no Destination header
authBasic(req)
r.ServeHTTP(w, req)
// Expect failure (not 201/204)
if w.Code == 201 || w.Code == 204 {
t.Fatalf("expected failure when Destination header missing, got %d", w.Code)
}
// Source remains
assert.FileExists(t, filepath.Join(conf.OriginalsPath(), "mv", "file.txt"))
}
func TestWebDAVWrite_CopyInvalidDestinationPrefix(t *testing.T) {
conf := config.TestConfig()
r := setupWebDAVRouter(conf)
// Ensure src exists
_ = os.MkdirAll(filepath.Join(conf.OriginalsPath(), "cp"), 0o700)
_ = os.WriteFile(filepath.Join(conf.OriginalsPath(), "cp", "a.txt"), []byte("A"), 0o600)
// COPY to a destination outside the handler prefix
w := httptest.NewRecorder()
req := httptest.NewRequest(MethodCopy, conf.BaseUri(WebDAVOriginals)+"/cp/a.txt", nil)
req.Header.Set("Destination", "/notwebdav/d.txt")
authBasic(req)
r.ServeHTTP(w, req)
// Expect failure
if w.Code == 201 || w.Code == 204 {
t.Fatalf("expected failure for invalid Destination prefix, got %d", w.Code)
}
// Destination not created
assert.NoFileExists(t, filepath.Join(conf.OriginalsPath(), "notwebdav", "d.txt"))
}
func TestWebDAVWrite_MoveNonExistentSource(t *testing.T) {
conf := config.TestConfig()
r := setupWebDAVRouter(conf)
// Ensure destination dir exists
_ = os.MkdirAll(filepath.Join(conf.OriginalsPath(), "dst2"), 0o700)
w := httptest.NewRecorder()
req := httptest.NewRequest(MethodMove, conf.BaseUri(WebDAVOriginals)+"/nosuch/file.txt", nil)
req.Header.Set("Destination", conf.BaseUri(WebDAVOriginals)+"/dst2/file.txt")
authBasic(req)
r.ServeHTTP(w, req)
// Expect failure (e.g., 404)
if w.Code == 201 || w.Code == 204 {
t.Fatalf("expected failure moving non-existent source, got %d", w.Code)
}
assert.NoFileExists(t, filepath.Join(conf.OriginalsPath(), "dst2", "file.txt"))
}
func TestWebDAVWrite_CopyTraversalDestination(t *testing.T) {
conf := config.TestConfig()
r := setupWebDAVRouter(conf)
// Create source file via PUT
_ = os.MkdirAll(filepath.Join(conf.OriginalsPath(), "travsrc"), 0o700)
w := httptest.NewRecorder()
req := httptest.NewRequest(MethodPut, conf.BaseUri(WebDAVOriginals)+"/travsrc/a.txt", bytes.NewBufferString("A"))
authBasic(req)
r.ServeHTTP(w, req)
assert.InDelta(t, 201, w.Code, 1)
// Attempt COPY with traversal in Destination
w = httptest.NewRecorder()
req = httptest.NewRequest(MethodCopy, conf.BaseUri(WebDAVOriginals)+"/travsrc/a.txt", nil)
req.Header.Set("Destination", conf.BaseUri(WebDAVOriginals)+"/../evil.txt")
authBasic(req)
r.ServeHTTP(w, req)
// Expect success with sanitized destination inside base
if !(w.Code == 201 || w.Code == 204) {
t.Fatalf("expected success (sanitized), got %d", w.Code)
}
// Not created above originals; created as /originals/evil.txt
parent := filepath.Dir(conf.OriginalsPath())
assert.NoFileExists(t, filepath.Join(parent, "evil.txt"))
assert.FileExists(t, filepath.Join(conf.OriginalsPath(), "evil.txt"))
}
func TestWebDAVWrite_MoveTraversalDestination(t *testing.T) {
conf := config.TestConfig()
r := setupWebDAVRouter(conf)
// Create source file via PUT
_ = os.MkdirAll(filepath.Join(conf.OriginalsPath(), "travsrc2"), 0o700)
w := httptest.NewRecorder()
req := httptest.NewRequest(MethodPut, conf.BaseUri(WebDAVOriginals)+"/travsrc2/a.txt", bytes.NewBufferString("A"))
authBasic(req)
r.ServeHTTP(w, req)
assert.InDelta(t, 201, w.Code, 1)
// Attempt MOVE with traversal in Destination
w = httptest.NewRecorder()
req = httptest.NewRequest(MethodMove, conf.BaseUri(WebDAVOriginals)+"/travsrc2/a.txt", nil)
req.Header.Set("Destination", conf.BaseUri(WebDAVOriginals)+"/../evil2.txt")
authBasic(req)
r.ServeHTTP(w, req)
if !(w.Code == 201 || w.Code == 204) {
t.Fatalf("expected success (sanitized) for MOVE, got %d", w.Code)
}
// Source removed; destination created inside base, not outside
assert.NoFileExists(t, filepath.Join(conf.OriginalsPath(), "travsrc2", "a.txt"))
parent := filepath.Dir(conf.OriginalsPath())
assert.NoFileExists(t, filepath.Join(parent, "evil2.txt"))
assert.FileExists(t, filepath.Join(conf.OriginalsPath(), "evil2.txt"))
}
func TestWebDAVWrite_ReadOnlyForbidden(t *testing.T) {
conf := config.TestConfig()
conf.Options().ReadOnly = true
r := setupWebDAVRouter(conf)
w := httptest.NewRecorder()
req := httptest.NewRequest(MethodMkcol, conf.BaseUri(WebDAVOriginals)+"/ro", nil)
authBearer(req)
r.ServeHTTP(w, req)
assert.Equal(t, http.StatusForbidden, w.Code)
}

View File

@@ -0,0 +1,42 @@
package avatar
import (
"strings"
"time"
"github.com/photoprism/photoprism/pkg/service/http/safe"
)
var (
// Stricter defaults for avatar images than the generic HTTP safe defaults.
defaultTimeout = 15 * time.Second
defaultMaxSize int64 = 10 << 20 // 10 MiB for avatar images
)
// SafeDownload delegates avatar image downloads to the shared HTTP safe downloader
// with hardened defaults suitable for small image files.
// Callers may pass a partially filled safe.Options to override defaults.
func SafeDownload(destPath, rawURL string, opt *safe.Options) error {
// Start with strict avatar defaults.
o := &safe.Options{
Timeout: defaultTimeout,
MaxSizeBytes: defaultMaxSize,
AllowPrivate: false, // block private/loopback by default
// Prefer images but allow others at low priority; MIME is validated later.
Accept: "image/jpeg, image/png, */*;q=0.1",
}
if opt != nil {
if opt.Timeout > 0 {
o.Timeout = opt.Timeout
}
if opt.MaxSizeBytes > 0 {
o.MaxSizeBytes = opt.MaxSizeBytes
}
// Bool has no sentinel; just copy the value.
o.AllowPrivate = opt.AllowPrivate
if strings.TrimSpace(opt.Accept) != "" {
o.Accept = opt.Accept
}
}
return safe.Download(destPath, rawURL, o)
}

View File

@@ -0,0 +1,73 @@
package avatar
import (
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"time"
"github.com/photoprism/photoprism/pkg/service/http/safe"
)
func TestSafeDownload_InvalidScheme(t *testing.T) {
dir := t.TempDir()
dest := filepath.Join(dir, "x")
if err := SafeDownload(dest, "file:///etc/passwd", nil); err == nil {
t.Fatal("expected error for invalid scheme")
}
}
func TestSafeDownload_PrivateIPBlocked(t *testing.T) {
dir := t.TempDir()
dest := filepath.Join(dir, "x")
if err := SafeDownload(dest, "http://127.0.0.1/test.png", nil); err == nil {
t.Fatal("expected SSRF private IP block")
}
}
func TestSafeDownload_MaxSizeExceeded(t *testing.T) {
// Local server; allow private for test.
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/octet-stream")
// 2MB body
w.WriteHeader(http.StatusOK)
buf := make([]byte, 2<<20)
_, _ = w.Write(buf)
}))
defer ts.Close()
dir := t.TempDir()
dest := filepath.Join(dir, "big")
err := SafeDownload(dest, ts.URL, &safe.Options{Timeout: 5 * time.Second, MaxSizeBytes: 1 << 20, AllowPrivate: true})
if err == nil {
t.Fatal("expected size exceeded error")
}
if _, statErr := os.Stat(dest); !os.IsNotExist(statErr) {
t.Fatalf("expected no output file on error, got stat err=%v", statErr)
}
}
func TestSafeDownload_Succeeds(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "image/png")
w.WriteHeader(http.StatusOK)
_, _ = io.WriteString(w, "ok")
}))
defer ts.Close()
dir := t.TempDir()
dest := filepath.Join(dir, "ok")
if err := SafeDownload(dest, ts.URL, &safe.Options{Timeout: 5 * time.Second, MaxSizeBytes: 1 << 20, AllowPrivate: true}); err != nil {
t.Fatalf("unexpected error: %v", err)
}
b, err := os.ReadFile(dest)
if err != nil {
t.Fatalf("read: %v", err)
}
if string(b) != "ok" {
t.Fatalf("unexpected content: %q", string(b))
}
}

View File

@@ -32,7 +32,8 @@ func SetUserImageURL(m *entity.User, imageUrl, imageSrc, thumbPath string) error
tmpName := filepath.Join(os.TempDir(), rnd.Base36(64))
if err = fs.Download(tmpName, u.String()); err != nil {
// Hardened remote fetch with SSRF and size limits.
if err = SafeDownload(tmpName, u.String(), nil); err != nil {
return fmt.Errorf("failed to download avatar image (%w)", err)
}

View File

@@ -27,11 +27,12 @@ package fs
import (
"fmt"
"io"
"net/http"
"os"
"os/user"
"path/filepath"
"syscall"
"github.com/photoprism/photoprism/pkg/service/http/safe"
)
var ignoreCase bool
@@ -206,40 +207,9 @@ func Abs(name string) string {
// Download downloads a file from a URL.
func Download(fileName string, url string) error {
if dir := filepath.Dir(fileName); dir == "" || dir == "/" || dir == "." || dir == ".." {
return fmt.Errorf("invalid path")
} else if err := MkdirAll(dir); err != nil {
return err
}
// Create the file
out, err := os.Create(fileName)
if err != nil {
return err
}
defer out.Close()
// Get the data
resp, err := http.Get(url)
if err != nil {
return err
}
defer resp.Body.Close()
// Check server response
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("bad status: %s", resp.Status)
}
// Writer the body to file
_, err = io.Copy(out, resp.Body)
if err != nil {
return err
}
return nil
// Preserve existing semantics but with safer network behavior.
// Allow private IPs by default to avoid breaking intended internal downloads.
return safe.Download(fileName, url, &safe.Options{AllowPrivate: true})
}
// DirIsEmpty returns true if a directory is empty.

View File

@@ -0,0 +1,218 @@
package safe
import (
"context"
"errors"
"io"
"net"
"net/http"
"net/http/httptrace"
"net/url"
"os"
"path/filepath"
"strings"
"time"
)
// Download fetches a URL to a destination file with timeouts, size limits, and optional SSRF protection.
func Download(destPath, rawURL string, opt *Options) error {
if destPath == "" {
return errors.New("invalid destination path")
}
// Prepare destination directory.
if dir := filepath.Dir(destPath); dir == "" || dir == "/" || dir == "." || dir == ".." {
return errors.New("invalid destination directory")
} else if err := os.MkdirAll(dir, 0o700); err != nil {
return err
}
u, err := url.Parse(rawURL)
if err != nil {
return err
}
if !strings.EqualFold(u.Scheme, "http") && !strings.EqualFold(u.Scheme, "https") {
return ErrSchemeNotAllowed
}
// Defaults w/ env overrides
maxSize := defaultMaxSize
if n := envInt64("PHOTOPRISM_HTTP_MAX_DOWNLOAD"); n > 0 {
maxSize = n
}
timeout := defaultTimeout
if d := envDuration("PHOTOPRISM_HTTP_TIMEOUT"); d > 0 {
timeout = d
}
o := Options{Timeout: timeout, MaxSizeBytes: maxSize, AllowPrivate: true, Accept: "*/*"}
if opt != nil {
if opt.Timeout > 0 {
o.Timeout = opt.Timeout
}
if opt.MaxSizeBytes > 0 {
o.MaxSizeBytes = opt.MaxSizeBytes
}
o.AllowPrivate = opt.AllowPrivate
if strings.TrimSpace(opt.Accept) != "" {
o.Accept = opt.Accept
}
}
// Optional SSRF block
if !o.AllowPrivate {
if ip := net.ParseIP(u.Hostname()); ip != nil {
if isPrivateOrDisallowedIP(ip) {
return ErrPrivateIP
}
} else {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
addrs, lookErr := net.DefaultResolver.LookupIPAddr(ctx, u.Hostname())
if lookErr != nil {
return lookErr
}
for _, a := range addrs {
if isPrivateOrDisallowedIP(a.IP) {
return ErrPrivateIP
}
}
}
}
// Enforce redirect validation when private networks are disallowed.
client := &http.Client{
Timeout: o.Timeout,
CheckRedirect: func(req *http.Request, via []*http.Request) error {
if !o.AllowPrivate {
h := req.URL.Hostname()
if ip := net.ParseIP(h); ip != nil {
if isPrivateOrDisallowedIP(ip) {
return ErrPrivateIP
}
} else {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
addrs, lookErr := net.DefaultResolver.LookupIPAddr(ctx, h)
if lookErr != nil {
return lookErr
}
for _, a := range addrs {
if isPrivateOrDisallowedIP(a.IP) {
return ErrPrivateIP
}
}
}
}
// Propagate Accept header from the first request.
if len(via) > 0 {
if v := via[0].Header.Get("Accept"); v != "" {
req.Header.Set("Accept", v)
}
}
return nil
},
}
req, err := http.NewRequest(http.MethodGet, u.String(), nil)
if err != nil {
return err
}
if o.Accept != "" {
req.Header.Set("Accept", o.Accept)
}
// Capture the final remote IP used for the connection.
var finalIP net.IP
trace := &httptrace.ClientTrace{
GotConn: func(info httptrace.GotConnInfo) {
if addr := info.Conn.RemoteAddr(); addr != nil {
host, _, _ := net.SplitHostPort(addr.String())
if ip := net.ParseIP(host); ip != nil {
finalIP = ip
}
}
},
}
req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
resp, err := client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return errors.New(resp.Status)
}
// Validate the connected peer address when private ranges are disallowed.
if !o.AllowPrivate && finalIP != nil && isPrivateOrDisallowedIP(finalIP) {
return ErrPrivateIP
}
if resp.ContentLength > 0 && o.MaxSizeBytes > 0 && resp.ContentLength > o.MaxSizeBytes {
return ErrSizeExceeded
}
tmp := destPath + ".part"
f, err := os.OpenFile(tmp, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o600)
if err != nil {
return err
}
defer func() {
f.Close()
if err != nil {
_ = os.Remove(tmp)
}
}()
var r io.Reader = resp.Body
if o.MaxSizeBytes > 0 {
r = io.LimitReader(resp.Body, o.MaxSizeBytes+1)
}
n, copyErr := io.Copy(f, r)
if copyErr != nil {
err = copyErr
return err
}
if o.MaxSizeBytes > 0 && n > o.MaxSizeBytes {
err = ErrSizeExceeded
return err
}
if err = f.Close(); err != nil {
return err
}
if err = os.Rename(tmp, destPath); err != nil {
return err
}
return nil
}
func isPrivateOrDisallowedIP(ip net.IP) bool {
if ip == nil {
return true
}
if ip.IsLoopback() || ip.IsMulticast() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
return true
}
if v4 := ip.To4(); v4 != nil {
if v4[0] == 10 {
return true
}
if v4[0] == 172 && v4[1] >= 16 && v4[1] <= 31 {
return true
}
if v4[0] == 192 && v4[1] == 168 {
return true
}
if v4[0] == 169 && v4[1] == 254 {
return true
}
return false
}
// IPv6 ULA fc00::/7
if ip.To16() != nil {
if ip[0]&0xFE == 0xFC {
return true
}
}
return false
}

View File

@@ -0,0 +1,57 @@
package safe
import (
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"time"
)
// Redirect to a private IP must be blocked when AllowPrivate=false.
func TestDownload_BlockRedirectToPrivate(t *testing.T) {
// Public-looking server that redirects to 127.0.0.1
redirectTarget := "http://127.0.0.1:65535/secret"
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, redirectTarget, http.StatusFound)
}))
defer ts.Close()
dir := t.TempDir()
dest := filepath.Join(dir, "out")
err := Download(dest, ts.URL, &Options{Timeout: 5 * time.Second, MaxSizeBytes: 1 << 20, AllowPrivate: false})
if err == nil {
t.Fatalf("expected redirect SSRF to be blocked")
}
if _, statErr := os.Stat(dest); !os.IsNotExist(statErr) {
t.Fatalf("expected no output file on error, got stat err=%v", statErr)
}
}
// With AllowPrivate=true, redirects to a local httptest server should succeed.
func TestDownload_AllowRedirectToPrivate(t *testing.T) {
// Local private target that serves content.
target := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = io.WriteString(w, "ok")
}))
defer target.Close()
// Public-looking server that redirects to the private target.
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, target.URL, http.StatusFound)
}))
defer ts.Close()
dir := t.TempDir()
dest := filepath.Join(dir, "ok")
if err := Download(dest, ts.URL, &Options{Timeout: 5 * time.Second, MaxSizeBytes: 1 << 20, AllowPrivate: true}); err != nil {
t.Fatalf("unexpected error: %v", err)
}
b, err := os.ReadFile(dest)
if err != nil || string(b) != "ok" {
t.Fatalf("unexpected content: %v %q", err, string(b))
}
}

View File

@@ -0,0 +1,42 @@
package safe
import (
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"time"
)
func TestSafeDownload_OK(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = io.WriteString(w, "hello")
}))
defer ts.Close()
dir := t.TempDir()
dest := filepath.Join(dir, "ok.txt")
if err := Download(dest, ts.URL, &Options{Timeout: 5 * time.Second, MaxSizeBytes: 1024, AllowPrivate: true}); err != nil {
t.Fatal(err)
}
b, err := os.ReadFile(dest)
if err != nil || string(b) != "hello" {
t.Fatalf("unexpected content: %v %q", err, string(b))
}
}
func TestSafeDownload_TooLarge(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
// 2KiB
_, _ = w.Write(make([]byte, 2048))
}))
defer ts.Close()
dir := t.TempDir()
dest := filepath.Join(dir, "big.bin")
if err := Download(dest, ts.URL, &Options{Timeout: 5 * time.Second, MaxSizeBytes: 1024, AllowPrivate: true}); err == nil {
t.Fatalf("expected ErrSizeExceeded")
}
}

View File

@@ -0,0 +1,47 @@
package safe
import (
"errors"
"os"
"strconv"
"strings"
"time"
)
// Options controls Download behavior.
type Options struct {
Timeout time.Duration
MaxSizeBytes int64
AllowPrivate bool
Accept string
}
var (
// Defaults are tuned for general downloads (not just avatars).
defaultTimeout = 30 * time.Second
defaultMaxSize = int64(200 * 1024 * 1024) // 200 MiB
ErrSchemeNotAllowed = errors.New("invalid scheme (only http/https allowed)")
ErrSizeExceeded = errors.New("response exceeds maximum allowed size")
ErrPrivateIP = errors.New("connection to private or loopback address not allowed")
)
// envInt64 returns an int64 from env or -1 if unset/invalid.
func envInt64(key string) int64 {
if v := strings.TrimSpace(os.Getenv(key)); v != "" {
if n, err := strconv.ParseInt(v, 10, 64); err == nil {
return n
}
}
return -1
}
// envDuration returns a duration from env seconds or 0 if unset/invalid.
func envDuration(key string) time.Duration {
if v := strings.TrimSpace(os.Getenv(key)); v != "" {
if n, err := strconv.ParseInt(v, 10, 64); err == nil {
return time.Duration(n) * time.Second
}
}
return 0
}

View File

@@ -0,0 +1,25 @@
/*
Package safe provides a secure HTTP downloader with customizable settings.
Copyright (c) 2018 - 2025 PhotoPrism UG. All rights reserved.
This program is free software: you can redistribute it and/or modify
it under Version 3 of the GNU Affero General Public License (the "AGPL"):
<https://docs.photoprism.app/license/agpl>
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
The AGPL is supplemented by our Trademark and Brand Guidelines,
which describe how our Brand Assets may be used:
<https://www.photoprism.app/trademark>
Feel free to send an email to hello@photoprism.app if you have questions,
want to support our work, or just want to say hello.
Additional information can be found in our Developer Guide:
<https://docs.photoprism.app/developer-guide/>
*/
package safe

View File

@@ -0,0 +1,52 @@
package main
import (
"encoding/json"
"fmt"
"io/ioutil"
"os"
)
func main() {
if len(os.Args) < 2 {
fmt.Fprintln(os.Stderr, "usage: swaggerfix <swagger.json>")
os.Exit(2)
}
path := os.Args[1]
b, err := ioutil.ReadFile(path)
if err != nil {
fmt.Fprintln(os.Stderr, "read:", err)
os.Exit(1)
}
var doc map[string]interface{}
if err := json.Unmarshal(b, &doc); err != nil {
fmt.Fprintln(os.Stderr, "parse:", err)
os.Exit(1)
}
// Traverse to definitions.time.Duration
defs, _ := doc["definitions"].(map[string]interface{})
if defs == nil {
fmt.Fprintln(os.Stderr, "no definitions in swagger file")
os.Exit(1)
}
td, _ := defs["time.Duration"].(map[string]interface{})
if td == nil {
fmt.Fprintln(os.Stderr, "no time.Duration schema found; nothing to do")
os.Exit(0)
}
// Remove unstable enums and varnames to ensure deterministic output.
delete(td, "enum")
delete(td, "x-enum-varnames")
defs["time.Duration"] = td
doc["definitions"] = defs
out, err := json.MarshalIndent(doc, "", " ")
if err != nil {
fmt.Fprintln(os.Stderr, "marshal:", err)
os.Exit(1)
}
if err := ioutil.WriteFile(path, out, 0644); err != nil {
fmt.Fprintln(os.Stderr, "write:", err)
os.Exit(1)
}
}