Config: Change default vision model assets path to assets/models/ #127

Signed-off-by: Michael Mayer <michael@photoprism.app>
This commit is contained in:
Michael Mayer
2025-08-08 19:06:56 +02:00
parent 2b48fe20dd
commit ab0bd1c732
27 changed files with 186 additions and 82 deletions

View File

@@ -11,6 +11,7 @@
/assets/imagenet /assets/imagenet
/assets/resnet /assets/resnet
/assets/vision /assets/vision
/assets/models
/storage /storage
/build /build
/photoprism /photoprism

View File

@@ -73,6 +73,7 @@ pull: docker-pull
test: test-js test-go test: test-js test-go
test-go: reset-sqlite run-test-go test-go: reset-sqlite run-test-go
test-pkg: reset-sqlite run-test-pkg test-pkg: reset-sqlite run-test-pkg
test-ai: reset-sqlite run-test-ai
test-api: reset-sqlite run-test-api test-api: reset-sqlite run-test-api
test-video: reset-sqlite run-test-video test-video: reset-sqlite run-test-video
test-entity: reset-sqlite run-test-entity test-entity: reset-sqlite run-test-entity
@@ -395,6 +396,9 @@ run-test-mariadb:
run-test-pkg: run-test-pkg:
$(info Running all Go tests in "/pkg"...) $(info Running all Go tests in "/pkg"...)
$(GOTEST) -parallel 2 -count 1 -cpu 2 -tags="slow,develop" -timeout 20m ./pkg/... $(GOTEST) -parallel 2 -count 1 -cpu 2 -tags="slow,develop" -timeout 20m ./pkg/...
run-test-ai:
$(info Running all AI tests...)
$(GOTEST) -parallel 2 -count 1 -cpu 2 -tags="slow,develop" -timeout 20m ./internal/ai/...
run-test-api: run-test-api:
$(info Running all API tests...) $(info Running all API tests...)
$(GOTEST) -parallel 2 -count 1 -cpu 2 -tags="slow,develop" -timeout 20m ./internal/api/... $(GOTEST) -parallel 2 -count 1 -cpu 2 -tags="slow,develop" -timeout 20m ./internal/api/...

View File

@@ -1,8 +1,8 @@
examples examples
efficientnet models/_*
imagenet models/dev*
resnet models/local*
vision models/test*
README.md README.md
docs docs
.* .*

2
assets/models/.gitignore vendored Normal file
View File

@@ -0,0 +1,2 @@
*
!.gitignore

View File

@@ -24,8 +24,8 @@ import (
// Model represents a TensorFlow classification model. // Model represents a TensorFlow classification model.
type Model struct { type Model struct {
model *tf.SavedModel model *tf.SavedModel
modelPath string name string
assetsPath string modelsPath string
defaultLabelsPath string defaultLabelsPath string
labels []string labels []string
disabled bool disabled bool
@@ -34,14 +34,14 @@ type Model struct {
} }
// NewModel returns new TensorFlow classification model instance. // NewModel returns new TensorFlow classification model instance.
func NewModel(assetsPath, modelPath, defaultLabelsPath string, meta *tensorflow.ModelInfo, disabled bool) *Model { func NewModel(modelsPath, name, defaultLabelsPath string, meta *tensorflow.ModelInfo, disabled bool) *Model {
if meta == nil { if meta == nil {
meta = new(tensorflow.ModelInfo) meta = new(tensorflow.ModelInfo)
} }
return &Model{ return &Model{
modelPath: modelPath, name: name,
assetsPath: assetsPath, modelsPath: modelsPath,
defaultLabelsPath: defaultLabelsPath, defaultLabelsPath: defaultLabelsPath,
meta: meta, meta: meta,
disabled: disabled, disabled: disabled,
@@ -49,8 +49,8 @@ func NewModel(assetsPath, modelPath, defaultLabelsPath string, meta *tensorflow.
} }
// NewNasnet returns new Nasnet TensorFlow classification model instance. // NewNasnet returns new Nasnet TensorFlow classification model instance.
func NewNasnet(assetsPath string, disabled bool) *Model { func NewNasnet(modelsPath string, disabled bool) *Model {
return NewModel(assetsPath, "nasnet", "", &tensorflow.ModelInfo{ return NewModel(modelsPath, "nasnet", "", &tensorflow.ModelInfo{
TFVersion: "1.12.0", TFVersion: "1.12.0",
Tags: []string{"photoprism"}, Tags: []string{"photoprism"},
Input: &tensorflow.PhotoInput{ Input: &tensorflow.PhotoInput{
@@ -194,7 +194,7 @@ func (m *Model) loadModel() (err error) {
return nil return nil
} }
modelPath := path.Join(m.assetsPath, m.modelPath) modelPath := path.Join(m.modelsPath, m.name)
if len(m.meta.Tags) == 0 { if len(m.meta.Tags) == 0 {
infos, modelErr := tensorflow.GetModelInfo(modelPath) infos, modelErr := tensorflow.GetModelInfo(modelPath)

View File

@@ -259,7 +259,7 @@ func assertContainsAny(t *testing.T, s string, substrings []string) {
func testModel_LabelsFromFile(t *testing.T, tensorFlow *Model) { func testModel_LabelsFromFile(t *testing.T, tensorFlow *Model) {
testName := func(name string) string { testName := func(name string) string {
return fmt.Sprintf("%s/%s", tensorFlow.modelPath, name) return fmt.Sprintf("%s/%s", tensorFlow.name, name)
} }
t.Run(testName("chameleon_lime.jpg"), func(t *testing.T) { t.Run(testName("chameleon_lime.jpg"), func(t *testing.T) {
@@ -364,7 +364,7 @@ func testModel_Run(t *testing.T, tensorFlow *Model) {
} }
testName := func(name string) string { testName := func(name string) string {
return fmt.Sprintf("%s/%s", tensorFlow.modelPath, name) return fmt.Sprintf("%s/%s", tensorFlow.name, name)
} }
t.Run(testName("chameleon_lime.jpg"), func(t *testing.T) { t.Run(testName("chameleon_lime.jpg"), func(t *testing.T) {

View File

@@ -2,6 +2,7 @@ package classify
import ( import (
"os" "os"
"path/filepath"
"sync" "sync"
"testing" "testing"
@@ -12,14 +13,15 @@ import (
) )
var assetsPath = fs.Abs("../../../assets") var assetsPath = fs.Abs("../../../assets")
var modelPath = assetsPath + "/nasnet" var examplesPath = filepath.Join(assetsPath, "examples")
var examplesPath = assetsPath + "/examples" var modelsPath = filepath.Join(assetsPath, "models")
var modelPath = modelsPath + "/nasnet"
var once sync.Once var once sync.Once
var testInstance *Model var testInstance *Model
func NewModelTest(t *testing.T) *Model { func NewModelTest(t *testing.T) *Model {
once.Do(func() { once.Do(func() {
testInstance = NewNasnet(assetsPath, false) testInstance = NewNasnet(modelsPath, false)
if err := testInstance.loadModel(); err != nil { if err := testInstance.loadModel(); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -29,7 +31,7 @@ func NewModelTest(t *testing.T) *Model {
} }
func TestModel_CenterCrop(t *testing.T) { func TestModel_CenterCrop(t *testing.T) {
model := NewNasnet(assetsPath, false) model := NewNasnet(modelsPath, false)
if err := model.loadModel(); err != nil { if err := model.loadModel(); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -42,7 +44,7 @@ func TestModel_CenterCrop(t *testing.T) {
} }
func TestModel_Padding(t *testing.T) { func TestModel_Padding(t *testing.T) {
model := NewNasnet(assetsPath, false) model := NewNasnet(modelsPath, false)
if err := model.loadModel(); err != nil { if err := model.loadModel(); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -55,7 +57,7 @@ func TestModel_Padding(t *testing.T) {
} }
func TestModel_ResizeBreakAspectRatio(t *testing.T) { func TestModel_ResizeBreakAspectRatio(t *testing.T) {
model := NewNasnet(assetsPath, false) model := NewNasnet(modelsPath, false)
if err := model.loadModel(); err != nil { if err := model.loadModel(); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -154,7 +156,7 @@ func TestModel_LabelsFromFile(t *testing.T) {
assert.Empty(t, result) assert.Empty(t, result)
}) })
t.Run("Disabled", func(t *testing.T) { t.Run("Disabled", func(t *testing.T) {
tensorFlow := NewNasnet(assetsPath, true) tensorFlow := NewNasnet(modelsPath, true)
result, err := tensorFlow.File(examplesPath+"/chameleon_lime.jpg", 10) result, err := tensorFlow.File(examplesPath+"/chameleon_lime.jpg", 10)
assert.Nil(t, err) assert.Nil(t, err)
@@ -253,7 +255,7 @@ func TestModel_Run(t *testing.T) {
} }
}) })
t.Run("Disabled", func(t *testing.T) { t.Run("Disabled", func(t *testing.T) {
tensorFlow := NewNasnet(assetsPath, true) tensorFlow := NewNasnet(modelsPath, true)
if imageBuffer, err := os.ReadFile(examplesPath + "/dog_orange.jpg"); err != nil { if imageBuffer, err := os.ReadFile(examplesPath + "/dog_orange.jpg"); err != nil {
t.Error(err) t.Error(err)
@@ -277,7 +279,7 @@ func TestModel_LoadModel(t *testing.T) {
assert.True(t, tf.ModelLoaded()) assert.True(t, tf.ModelLoaded())
}) })
t.Run("NotFound", func(t *testing.T) { t.Run("NotFound", func(t *testing.T) {
tensorFlow := NewNasnet(assetsPath+"foo", false) tensorFlow := NewNasnet(modelsPath+"foo", false)
err := tensorFlow.loadModel() err := tensorFlow.loadModel()
if err != nil { if err != nil {
@@ -290,7 +292,7 @@ func TestModel_LoadModel(t *testing.T) {
func TestModel_BestLabels(t *testing.T) { func TestModel_BestLabels(t *testing.T) {
t.Run("Success", func(t *testing.T) { t.Run("Success", func(t *testing.T) {
tensorFlow := NewNasnet(assetsPath, false) tensorFlow := NewNasnet(modelsPath, false)
if err := tensorFlow.loadLabels(modelPath); err != nil { if err := tensorFlow.loadLabels(modelPath); err != nil {
t.Fatal(err) t.Fatal(err)
@@ -308,7 +310,7 @@ func TestModel_BestLabels(t *testing.T) {
t.Log(result) t.Log(result)
}) })
t.Run("NotLoaded", func(t *testing.T) { t.Run("NotLoaded", func(t *testing.T) {
tensorFlow := NewNasnet(assetsPath, false) tensorFlow := NewNasnet(modelsPath, false)
p := make([]float32, 1000) p := make([]float32, 1000)

View File

@@ -19,7 +19,6 @@ import (
// Model is a wrapper for the TensorFlow Facenet model. // Model is a wrapper for the TensorFlow Facenet model.
type Model struct { type Model struct {
model *tf.SavedModel model *tf.SavedModel
modelName string
modelPath string modelPath string
cachePath string cachePath string
resolution int resolution int
@@ -41,6 +40,7 @@ func NewModel(modelPath, cachePath string, resolution int, meta *tensorflow.Mode
if len(meta.Tags) == 0 { if len(meta.Tags) == 0 {
meta.Tags = []string{"serve"} meta.Tags = []string{"serve"}
} }
return &Model{ return &Model{
modelPath: modelPath, modelPath: modelPath,
cachePath: cachePath, cachePath: cachePath,

View File

@@ -10,7 +10,7 @@ import (
"github.com/photoprism/photoprism/pkg/fs/fastwalk" "github.com/photoprism/photoprism/pkg/fs/fastwalk"
) )
var modelPath, _ = filepath.Abs("../../../assets/facenet") var modelPath, _ = filepath.Abs("../../../assets/models/facenet")
func TestNet(t *testing.T) { func TestNet(t *testing.T) {
expected := map[string]int{ expected := map[string]int{

View File

@@ -11,7 +11,7 @@ import (
"github.com/photoprism/photoprism/pkg/fs/fastwalk" "github.com/photoprism/photoprism/pkg/fs/fastwalk"
) )
var modelPath, _ = filepath.Abs("../../../assets/nsfw") var modelPath, _ = filepath.Abs("../../../assets/models/nsfw")
var detector = NewModel(modelPath, nil, false) var detector = NewModel(modelPath, nil, false)

View File

@@ -2,6 +2,7 @@ package tensorflow
import ( import (
"os" "os"
"path/filepath"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@@ -15,6 +16,9 @@ var defaultImageInput = &PhotoInput{
Width: 224, Width: 224,
} }
var assetsPath = fs.Abs("../../../assets")
var examplesPath = filepath.Join(assetsPath, "examples")
func TestConvertValue(t *testing.T) { func TestConvertValue(t *testing.T) {
result := convertValue(uint32(98765432), &Interval{Start: -1, End: 1}) result := convertValue(uint32(98765432), &Interval{Start: -1, End: 1})
assert.Equal(t, float32(3024.8982), result) assert.Equal(t, float32(3024.8982), result)
@@ -29,9 +33,6 @@ func TestConvertStdMean(t *testing.T) {
} }
func TestImageFromBytes(t *testing.T) { func TestImageFromBytes(t *testing.T) {
var assetsPath = fs.Abs("../../../assets")
var examplesPath = assetsPath + "/examples"
t.Run("CatJpeg", func(t *testing.T) { t.Run("CatJpeg", func(t *testing.T) {
imageBuffer, err := os.ReadFile(examplesPath + "/cat_brown.jpg") imageBuffer, err := os.ReadFile(examplesPath + "/cat_brown.jpg")

View File

@@ -5,14 +5,10 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/photoprism/photoprism/pkg/fs"
"github.com/photoprism/photoprism/pkg/media/http/scheme" "github.com/photoprism/photoprism/pkg/media/http/scheme"
) )
func TestNewApiRequest(t *testing.T) { func TestNewApiRequest(t *testing.T) {
var assetsPath = fs.Abs("../../../assets")
var examplesPath = assetsPath + "/examples"
t.Run("Data", func(t *testing.T) { t.Run("Data", func(t *testing.T) {
thumbnails := Files{examplesPath + "/chameleon_lime.jpg"} thumbnails := Files{examplesPath + "/chameleon_lime.jpg"}
result, err := NewApiRequestImages(thumbnails, scheme.Data) result, err := NewApiRequestImages(thumbnails, scheme.Data)

View File

@@ -2,17 +2,17 @@ package vision
import ( import (
"net/http" "net/http"
"path/filepath"
"time" "time"
"github.com/photoprism/photoprism/pkg/clean"
"github.com/photoprism/photoprism/pkg/fs" "github.com/photoprism/photoprism/pkg/fs"
"github.com/photoprism/photoprism/pkg/media/http/scheme" "github.com/photoprism/photoprism/pkg/media/http/scheme"
) )
var ( var (
AssetsPath = fs.Abs("../../../assets") CachePath = ""
FaceNetModelPath = fs.Abs("../../../assets/facenet") ModelsPath = ""
NsfwModelPath = fs.Abs("../../../assets/nsfw")
CachePath = fs.Abs("../../../storage/cache")
DownloadUrl = "" DownloadUrl = ""
ServiceUri = "" ServiceUri = ""
ServiceKey = "" ServiceKey = ""
@@ -23,3 +23,65 @@ var (
ServiceResponseFormat = ApiFormatVision ServiceResponseFormat = ApiFormatVision
DefaultResolution = 224 DefaultResolution = 224
) )
// SetCachePath updates the cache path.
func SetCachePath(dir string) {
if dir = fs.Abs(dir); dir == "" {
return
}
CachePath = dir
}
// GetCachePath returns the cache path.
func GetCachePath() string {
if CachePath != "" {
return CachePath
}
CachePath = fs.Abs("../../../storage/cache")
return CachePath
}
// SetModelsPath updates the model assets path.
func SetModelsPath(dir string) {
if dir = fs.Abs(dir); dir == "" {
return
}
ModelsPath = dir
}
// GetModelsPath returns the model assets path, or an empty string if not configured or found.
func GetModelsPath() string {
if ModelsPath != "" {
return ModelsPath
}
assetsPath := fs.Abs("../../../assets")
if dir := filepath.Join(assetsPath, "models"); fs.PathExists(dir) {
ModelsPath = dir
} else if fs.PathExists(assetsPath) {
ModelsPath = assetsPath
}
return ModelsPath
}
func GetModelPath(name string) string {
return filepath.Join(GetModelsPath(), clean.Path(clean.TypeLowerUnderscore(name)))
}
func GetNasnetModelPath() string {
return GetModelPath(NasnetModel.Name)
}
func GetFacenetModelPath() string {
return GetModelPath(FacenetModel.Name)
}
func GetNsfwModelPath() string {
return GetModelPath(NsfwModel.Name)
}

View File

@@ -7,14 +7,10 @@ import (
"github.com/photoprism/photoprism/internal/ai/classify" "github.com/photoprism/photoprism/internal/ai/classify"
"github.com/photoprism/photoprism/internal/entity" "github.com/photoprism/photoprism/internal/entity"
"github.com/photoprism/photoprism/pkg/fs"
"github.com/photoprism/photoprism/pkg/media" "github.com/photoprism/photoprism/pkg/media"
) )
func TestLabels(t *testing.T) { func TestLabels(t *testing.T) {
var assetsPath = fs.Abs("../../../assets")
var examplesPath = assetsPath + "/examples"
t.Run("Success", func(t *testing.T) { t.Run("Success", func(t *testing.T) {
result, err := Labels(Files{examplesPath + "/chameleon_lime.jpg"}, media.SrcLocal, entity.SrcAuto) result, err := Labels(Files{examplesPath + "/chameleon_lime.jpg"}, media.SrcLocal, entity.SrcAuto)

View File

@@ -2,7 +2,6 @@ package vision
import ( import (
"fmt" "fmt"
"path/filepath"
"strings" "strings"
"sync" "sync"
@@ -143,7 +142,7 @@ func (m *Model) ClassifyModel() *classify.Model {
return nil return nil
case NasnetModel.Name, "nasnet": case NasnetModel.Name, "nasnet":
// Load and initialize the Nasnet image classification model. // Load and initialize the Nasnet image classification model.
if model := classify.NewNasnet(AssetsPath, m.Disabled); model == nil { if model := classify.NewNasnet(GetModelsPath(), m.Disabled); model == nil {
return nil return nil
} else if err := model.Init(); err != nil { } else if err := model.Init(); err != nil {
log.Errorf("vision: %s (init nasnet model)", err) log.Errorf("vision: %s (init nasnet model)", err)
@@ -154,7 +153,7 @@ func (m *Model) ClassifyModel() *classify.Model {
default: default:
// Set model path from model name if no path is configured. // Set model path from model name if no path is configured.
if m.Path == "" { if m.Path == "" {
m.Path = clean.TypeLowerUnderscore(m.Name) m.Path = clean.Path(clean.TypeLowerUnderscore(m.Name))
} }
if m.Meta == nil { if m.Meta == nil {
@@ -173,8 +172,7 @@ func (m *Model) ClassifyModel() *classify.Model {
m.Meta.Input.SetResolution(m.Resolution) m.Meta.Input.SetResolution(m.Resolution)
// Try to load custom model based on the configuration values. // Try to load custom model based on the configuration values.
defaultPath := filepath.Join(AssetsPath, "nasnet") if model := classify.NewModel(GetModelsPath(), m.Path, GetNasnetModelPath(), m.Meta, m.Disabled); model == nil {
if model := classify.NewModel(AssetsPath, m.Path, defaultPath, m.Meta, m.Disabled); model == nil {
return nil return nil
} else if err := model.Init(); err != nil { } else if err := model.Init(); err != nil {
log.Errorf("vision: %s (init %s)", err, m.Path) log.Errorf("vision: %s (init %s)", err, m.Path)
@@ -205,7 +203,7 @@ func (m *Model) FaceModel() *face.Model {
return nil return nil
case FacenetModel.Name, "facenet": case FacenetModel.Name, "facenet":
// Load and initialize the Nasnet image classification model. // Load and initialize the Nasnet image classification model.
if model := face.NewModel(FaceNetModelPath, CachePath, m.Resolution, m.Meta, m.Disabled); model == nil { if model := face.NewModel(GetFacenetModelPath(), GetCachePath(), m.Resolution, m.Meta, m.Disabled); model == nil {
return nil return nil
} else if err := model.Init(); err != nil { } else if err := model.Init(); err != nil {
log.Errorf("vision: %s (init %s)", err, m.Path) log.Errorf("vision: %s (init %s)", err, m.Path)
@@ -216,7 +214,7 @@ func (m *Model) FaceModel() *face.Model {
default: default:
// Set model path from model name if no path is configured. // Set model path from model name if no path is configured.
if m.Path == "" { if m.Path == "" {
m.Path = clean.TypeLowerUnderscore(m.Name) m.Path = clean.Path(clean.TypeLowerUnderscore(m.Name))
} }
// Set default thumbnail resolution if no tags are configured. // Set default thumbnail resolution if no tags are configured.
@@ -229,7 +227,7 @@ func (m *Model) FaceModel() *face.Model {
} }
// Try to load custom model based on the configuration values. // Try to load custom model based on the configuration values.
if model := face.NewModel(filepath.Join(AssetsPath, m.Path), CachePath, m.Resolution, m.Meta, m.Disabled); model == nil { if model := face.NewModel(GetModelPath(m.Path), GetCachePath(), m.Resolution, m.Meta, m.Disabled); model == nil {
return nil return nil
} else if err := model.Init(); err != nil { } else if err := model.Init(); err != nil {
log.Errorf("vision: %s (init %s)", err, m.Path) log.Errorf("vision: %s (init %s)", err, m.Path)
@@ -260,7 +258,7 @@ func (m *Model) NsfwModel() *nsfw.Model {
return nil return nil
case NsfwModel.Name, "nsfw": case NsfwModel.Name, "nsfw":
// Load and initialize the Nasnet image classification model. // Load and initialize the Nasnet image classification model.
if model := nsfw.NewModel(NsfwModelPath, NsfwModel.Meta, m.Disabled); model == nil { if model := nsfw.NewModel(GetNsfwModelPath(), NsfwModel.Meta, m.Disabled); model == nil {
return nil return nil
} else if err := model.Init(); err != nil { } else if err := model.Init(); err != nil {
log.Errorf("vision: %s (init %s)", err, m.Path) log.Errorf("vision: %s (init %s)", err, m.Path)
@@ -271,7 +269,7 @@ func (m *Model) NsfwModel() *nsfw.Model {
default: default:
// Set model path from model name if no path is configured. // Set model path from model name if no path is configured.
if m.Path == "" { if m.Path == "" {
m.Path = clean.TypeLowerUnderscore(m.Name) m.Path = clean.Path(clean.TypeLowerUnderscore(m.Name))
} }
// Set default thumbnail resolution if no tags are configured. // Set default thumbnail resolution if no tags are configured.
@@ -290,7 +288,7 @@ func (m *Model) NsfwModel() *nsfw.Model {
} }
// Try to load custom model based on the configuration values. // Try to load custom model based on the configuration values.
if model := nsfw.NewModel(filepath.Join(AssetsPath, m.Path), m.Meta, m.Disabled); model == nil { if model := nsfw.NewModel(GetModelPath(m.Path), m.Meta, m.Disabled); model == nil {
return nil return nil
} else if err := model.Init(); err != nil { } else if err := model.Init(); err != nil {
log.Errorf("vision: %s (init %s)", err, m.Path) log.Errorf("vision: %s (init %s)", err, m.Path)

View File

@@ -2,20 +2,25 @@ package vision
import ( import (
"os" "os"
"path/filepath"
"testing" "testing"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/photoprism/photoprism/internal/api/download" "github.com/photoprism/photoprism/internal/api/download"
"github.com/photoprism/photoprism/internal/event" "github.com/photoprism/photoprism/internal/event"
"github.com/photoprism/photoprism/pkg/fs"
) )
var assetsPath = fs.Abs("../../../assets")
var examplesPath = filepath.Join(assetsPath, "examples")
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
// Init test logger. // Init test logger.
log = logrus.StandardLogger() log = logrus.StandardLogger()
log.SetLevel(logrus.TraceLevel) log.SetLevel(logrus.TraceLevel)
event.AuditLog = log event.AuditLog = log
download.AllowedPaths = append(download.AllowedPaths, AssetsPath) download.AllowedPaths = append(download.AllowedPaths, assetsPath)
// Set test config values. // Set test config values.
DownloadUrl = "https://app.localssl.dev/api/v1/dl" DownloadUrl = "https://app.localssl.dev/api/v1/dl"

View File

@@ -291,10 +291,8 @@ func (c *Config) Propagate() {
dl.FFprobeBin = c.FFprobeBin() dl.FFprobeBin = c.FFprobeBin()
// Configure computer vision package. // Configure computer vision package.
vision.AssetsPath = c.AssetsPath() vision.SetCachePath(c.CachePath())
vision.FaceNetModelPath = c.FaceNetModelPath() vision.SetModelsPath(c.ModelsPath())
vision.NsfwModelPath = c.NSFWModelPath()
vision.CachePath = c.CachePath()
vision.ServiceUri = c.VisionUri() vision.ServiceUri = c.VisionUri()
vision.ServiceKey = c.VisionKey() vision.ServiceKey = c.VisionKey()
vision.DownloadUrl = c.DownloadUrl() vision.DownloadUrl = c.DownloadUrl()

View File

@@ -195,9 +195,9 @@ func (c *Config) CreateDirectories() error {
return createError(dir, err) return createError(dir, err)
} }
// Create TensorFlow model path if it doesn't exist yet. // Create computer vision models path if it doesn't exist yet.
if dir := c.NasnetModelPath(); dir == "" { if dir := c.ModelsPath(); dir == "" {
return notFoundError("tensorflow model") return notFoundError("models")
} else if err := fs.MkdirAll(dir); err != nil { } else if err := fs.MkdirAll(dir); err != nil {
return createError(dir, err) return createError(dir, err)
} }

View File

@@ -43,19 +43,35 @@ func (c *Config) VisionKey() string {
} }
} }
// ModelsPath returns the path where the machine learning models are located.
func (c *Config) ModelsPath() string {
if c.options.ModelsPath != "" {
return fs.Abs(c.options.ModelsPath)
}
if dir := filepath.Join(c.AssetsPath(), "models"); fs.PathExists(dir) {
c.options.ModelsPath = dir
return c.options.ModelsPath
}
c.options.ModelsPath = fs.FindDir(fs.ModelsPaths)
return c.options.ModelsPath
}
// NasnetModelPath returns the TensorFlow model path. // NasnetModelPath returns the TensorFlow model path.
func (c *Config) NasnetModelPath() string { func (c *Config) NasnetModelPath() string {
return filepath.Join(c.AssetsPath(), "nasnet") return filepath.Join(c.ModelsPath(), "nasnet")
} }
// FaceNetModelPath returns the FaceNet model path. // FacenetModelPath returns the FaceNet model path.
func (c *Config) FaceNetModelPath() string { func (c *Config) FacenetModelPath() string {
return filepath.Join(c.AssetsPath(), "facenet") return filepath.Join(c.ModelsPath(), "facenet")
} }
// NSFWModelPath returns the "not safe for work" TensorFlow model path. // NsfwModelPath returns the "not safe for work" TensorFlow model path.
func (c *Config) NSFWModelPath() string { func (c *Config) NsfwModelPath() string {
return filepath.Join(c.AssetsPath(), "nsfw") return filepath.Join(c.ModelsPath(), "nsfw")
} }
// DetectNSFW checks if NSFW photos should be detected and flagged. // DetectNSFW checks if NSFW photos should be detected and flagged.

View File

@@ -1,6 +1,7 @@
package config package config
import ( import (
"strings"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@@ -34,11 +35,12 @@ func TestConfig_VisionKey(t *testing.T) {
assert.Equal(t, "", c.VisionKey()) assert.Equal(t, "", c.VisionKey())
} }
func TestConfig_TensorFlowModelPath(t *testing.T) { func TestConfig_ModelsPath(t *testing.T) {
c := NewConfig(CliTestContext()) c := NewConfig(CliTestContext())
path := c.NasnetModelPath() path := c.NasnetModelPath()
assert.Equal(t, "/go/src/github.com/photoprism/photoprism/assets/nasnet", path) assert.True(t, strings.HasPrefix(path, c.ModelsPath()))
assert.Equal(t, "/go/src/github.com/photoprism/photoprism/assets/models/nasnet", path)
} }
func TestConfig_TensorFlowDisabled(t *testing.T) { func TestConfig_TensorFlowDisabled(t *testing.T) {
@@ -51,13 +53,13 @@ func TestConfig_TensorFlowDisabled(t *testing.T) {
func TestConfig_NSFWModelPath(t *testing.T) { func TestConfig_NSFWModelPath(t *testing.T) {
c := NewConfig(CliTestContext()) c := NewConfig(CliTestContext())
assert.Contains(t, c.NSFWModelPath(), "/assets/nsfw") assert.Contains(t, c.NsfwModelPath(), "/assets/models/nsfw")
} }
func TestConfig_FaceNetModelPath(t *testing.T) { func TestConfig_FaceNetModelPath(t *testing.T) {
c := NewConfig(CliTestContext()) c := NewConfig(CliTestContext())
assert.Contains(t, c.FaceNetModelPath(), "/assets/facenet") assert.Contains(t, c.FacenetModelPath(), "/assets/models/facenet")
} }
func TestConfig_DetectNSFW(t *testing.T) { func TestConfig_DetectNSFW(t *testing.T) {

View File

@@ -298,6 +298,12 @@ var Flags = CliFlags{
EnvVars: EnvVars("ASSETS_PATH"), EnvVars: EnvVars("ASSETS_PATH"),
TakesFile: true, TakesFile: true,
}}, { }}, {
Flag: &cli.PathFlag{
Name: "models-path",
Usage: "custom model assets `PATH` where computer vision models are located",
EnvVars: EnvVars("MODELS_PATH"),
TakesFile: true,
}}, {
Flag: &cli.PathFlag{ Flag: &cli.PathFlag{
Name: "sidecar-path", Name: "sidecar-path",
Aliases: []string{"sc"}, Aliases: []string{"sc"},

View File

@@ -77,6 +77,7 @@ type Options struct {
TempPath string `yaml:"TempPath" json:"-" flag:"temp-path"` TempPath string `yaml:"TempPath" json:"-" flag:"temp-path"`
AssetsPath string `yaml:"AssetsPath" json:"-" flag:"assets-path"` AssetsPath string `yaml:"AssetsPath" json:"-" flag:"assets-path"`
CustomAssetsPath string `yaml:"-" json:"-" flag:"custom-assets-path" tags:"plus,pro"` CustomAssetsPath string `yaml:"-" json:"-" flag:"custom-assets-path" tags:"plus,pro"`
ModelsPath string `yaml:"ModelsPath" json:"-" flag:"models-path"`
SidecarPath string `yaml:"SidecarPath" json:"-" flag:"sidecar-path"` SidecarPath string `yaml:"SidecarPath" json:"-" flag:"sidecar-path"`
SidecarYaml bool `yaml:"SidecarYaml" json:"SidecarYaml" flag:"sidecar-yaml" default:"true"` SidecarYaml bool `yaml:"SidecarYaml" json:"SidecarYaml" flag:"sidecar-yaml" default:"true"`
UsageInfo bool `yaml:"UsageInfo" json:"UsageInfo" flag:"usage-info"` UsageInfo bool `yaml:"UsageInfo" json:"UsageInfo" flag:"usage-info"`

View File

@@ -76,6 +76,7 @@ func (c *Config) Report() (rows [][]string, cols []string) {
{"thumb-cache-path", c.ThumbCachePath()}, {"thumb-cache-path", c.ThumbCachePath()},
{"temp-path", c.TempPath()}, {"temp-path", c.TempPath()},
{"assets-path", c.AssetsPath()}, {"assets-path", c.AssetsPath()},
{"models-path", c.ModelsPath()},
{"static-path", c.StaticPath()}, {"static-path", c.StaticPath()},
{"build-path", c.BuildPath()}, {"build-path", c.BuildPath()},
{"img-path", c.ImgPath()}, {"img-path", c.ImgPath()},
@@ -256,8 +257,8 @@ func (c *Config) Report() (rows [][]string, cols []string) {
{"vision-uri", c.VisionUri()}, {"vision-uri", c.VisionUri()},
{"vision-key", strings.Repeat("*", utf8.RuneCountInString(c.VisionKey()))}, {"vision-key", strings.Repeat("*", utf8.RuneCountInString(c.VisionKey()))},
{"nasnet-model-path", c.NasnetModelPath()}, {"nasnet-model-path", c.NasnetModelPath()},
{"facenet-model-path", c.FaceNetModelPath()}, {"facenet-model-path", c.FacenetModelPath()},
{"nsfw-model-path", c.NSFWModelPath()}, {"nsfw-model-path", c.NsfwModelPath()},
{"detect-nsfw", fmt.Sprintf("%t", c.DetectNSFW())}, {"detect-nsfw", fmt.Sprintf("%t", c.DetectNSFW())},
// Facial Recognition. // Facial Recognition.

View File

@@ -120,6 +120,16 @@ var AssetPaths = []string{
"/var/lib/photoprism/assets", "/var/lib/photoprism/assets",
} }
var ModelsPaths = []string{
"/opt/photoprism/assets/models",
"/photoprism/assets/models",
"~/.photoprism/assets/models",
"~/photoprism/assets/models",
"photoprism/assets/models",
"assets/models",
"/var/lib/photoprism/assets/models",
}
// Dirs returns a slice of directories in a path, optional recursively and with symlinks. // Dirs returns a slice of directories in a path, optional recursively and with symlinks.
// //
// Warning: Following symlinks can make the result non-deterministic and hard to test! // Warning: Following symlinks can make the result non-deterministic and hard to test!

View File

@@ -4,7 +4,8 @@ TODAY=$(date -u +%Y%m%d)
MODEL_NAME="Facenet" MODEL_NAME="Facenet"
MODEL_URL="https://dl.photoprism.app/tensorflow/facenet.zip?$TODAY" MODEL_URL="https://dl.photoprism.app/tensorflow/facenet.zip?$TODAY"
MODEL_PATH="assets/facenet" MODELS_PATH="assets/models"
MODEL_PATH="$MODELS_PATH/facenet"
MODEL_ZIP="/tmp/photoprism/facenet.zip" MODEL_ZIP="/tmp/photoprism/facenet.zip"
MODEL_HASH="0492eb1d67789108b7eefb274e26633504b059be $MODEL_ZIP" MODEL_HASH="0492eb1d67789108b7eefb274e26633504b059be $MODEL_ZIP"
MODEL_VERSION="$MODEL_PATH/version.txt" MODEL_VERSION="$MODEL_PATH/version.txt"
@@ -17,7 +18,7 @@ mkdir -p /tmp/photoprism
mkdir -p storage/backup mkdir -p storage/backup
# Check for update # Check for update
if [[ -f ${MODEL_ZIP} ]] && [[ $(sha1sum ${MODEL_ZIP}) == ${MODEL_HASH} ]]; then if [[ -f ${MODEL_ZIP} ]] && [[ $(sha1sum ${MODEL_ZIP}) == "${MODEL_HASH}" ]]; then
if [[ -f ${MODEL_VERSION} ]]; then if [[ -f ${MODEL_VERSION} ]]; then
echo "Already up to date." echo "Already up to date."
exit exit
@@ -40,7 +41,7 @@ if [[ -e ${MODEL_PATH} ]]; then
fi fi
# Unzip model # Unzip model
unzip ${MODEL_ZIP} -d assets unzip ${MODEL_ZIP} -d "$MODELS_PATH"
echo "$MODEL_NAME $TODAY $MODEL_HASH" > ${MODEL_VERSION} echo "$MODEL_NAME $TODAY $MODEL_HASH" > ${MODEL_VERSION}
echo "Latest $MODEL_NAME installed." echo "Latest $MODEL_NAME installed."

View File

@@ -4,7 +4,8 @@ TODAY=$(date -u +%Y%m%d)
MODEL_NAME="NASNet Mobile" MODEL_NAME="NASNet Mobile"
MODEL_URL="https://dl.photoprism.app/tensorflow/nasnet.zip?$TODAY" MODEL_URL="https://dl.photoprism.app/tensorflow/nasnet.zip?$TODAY"
MODEL_PATH="assets/nasnet" MODELS_PATH="assets/models"
MODEL_PATH="$MODELS_PATH/nasnet"
MODEL_ZIP="/tmp/photoprism/nasnet.zip" MODEL_ZIP="/tmp/photoprism/nasnet.zip"
MODEL_HASH="f18b801354e95cade497b4f12e8d2537d04c04f6 $MODEL_ZIP" MODEL_HASH="f18b801354e95cade497b4f12e8d2537d04c04f6 $MODEL_ZIP"
MODEL_VERSION="$MODEL_PATH/version.txt" MODEL_VERSION="$MODEL_PATH/version.txt"
@@ -41,7 +42,7 @@ if [[ -e ${MODEL_PATH} ]]; then
fi fi
# Unzip model # Unzip model
unzip ${MODEL_ZIP} -d assets unzip ${MODEL_ZIP} -d "$MODELS_PATH"
echo "$MODEL_NAME $TODAY $MODEL_HASH" > ${MODEL_VERSION} echo "$MODEL_NAME $TODAY $MODEL_HASH" > ${MODEL_VERSION}
wget --inet4-only -c "${MODEL_21K_LABELS_URL}" -O ${MODEL_PATH}/labels21k.txt wget --inet4-only -c "${MODEL_21K_LABELS_URL}" -O ${MODEL_PATH}/labels21k.txt

View File

@@ -4,7 +4,8 @@ TODAY=$(date -u +%Y%m%d)
MODEL_NAME="NSFW" MODEL_NAME="NSFW"
MODEL_URL="https://dl.photoprism.app/tensorflow/nsfw.zip?$TODAY" MODEL_URL="https://dl.photoprism.app/tensorflow/nsfw.zip?$TODAY"
MODEL_PATH="assets/nsfw" MODELS_PATH="assets/models"
MODEL_PATH="$MODELS_PATH/nsfw"
MODEL_ZIP="/tmp/photoprism/nsfw.zip" MODEL_ZIP="/tmp/photoprism/nsfw.zip"
MODEL_HASH="2e03ad3c6aec27c270c650d0574ff2a6291d992b $MODEL_ZIP" MODEL_HASH="2e03ad3c6aec27c270c650d0574ff2a6291d992b $MODEL_ZIP"
MODEL_VERSION="$MODEL_PATH/version.txt" MODEL_VERSION="$MODEL_PATH/version.txt"
@@ -40,7 +41,7 @@ if [[ -e ${MODEL_PATH} ]]; then
fi fi
# Unzip model # Unzip model
unzip ${MODEL_ZIP} -d assets unzip ${MODEL_ZIP} -d "$MODELS_PATH"
echo "$MODEL_NAME $TODAY $MODEL_HASH" > ${MODEL_VERSION} echo "$MODEL_NAME $TODAY $MODEL_HASH" > ${MODEL_VERSION}
echo "Latest $MODEL_NAME installed." echo "Latest $MODEL_NAME installed."