diff --git a/.dockerignore b/.dockerignore index 04f90c210..561592e57 100644 --- a/.dockerignore +++ b/.dockerignore @@ -11,6 +11,7 @@ /assets/imagenet /assets/resnet /assets/vision +/assets/models /storage /build /photoprism diff --git a/Makefile b/Makefile index ee35950e5..9ddd431cb 100644 --- a/Makefile +++ b/Makefile @@ -73,6 +73,7 @@ pull: docker-pull test: test-js test-go test-go: reset-sqlite run-test-go test-pkg: reset-sqlite run-test-pkg +test-ai: reset-sqlite run-test-ai test-api: reset-sqlite run-test-api test-video: reset-sqlite run-test-video test-entity: reset-sqlite run-test-entity @@ -395,6 +396,9 @@ run-test-mariadb: run-test-pkg: $(info Running all Go tests in "/pkg"...) $(GOTEST) -parallel 2 -count 1 -cpu 2 -tags="slow,develop" -timeout 20m ./pkg/... +run-test-ai: + $(info Running all AI tests...) + $(GOTEST) -parallel 2 -count 1 -cpu 2 -tags="slow,develop" -timeout 20m ./internal/ai/... run-test-api: $(info Running all API tests...) $(GOTEST) -parallel 2 -count 1 -cpu 2 -tags="slow,develop" -timeout 20m ./internal/api/... diff --git a/assets/.buildignore b/assets/.buildignore index ce0f694df..4ba02a2bf 100644 --- a/assets/.buildignore +++ b/assets/.buildignore @@ -1,8 +1,8 @@ examples -efficientnet -imagenet -resnet -vision +models/_* +models/dev* +models/local* +models/test* README.md docs .* \ No newline at end of file diff --git a/assets/models/.gitignore b/assets/models/.gitignore new file mode 100644 index 000000000..c96a04f00 --- /dev/null +++ b/assets/models/.gitignore @@ -0,0 +1,2 @@ +* +!.gitignore \ No newline at end of file diff --git a/internal/ai/classify/model.go b/internal/ai/classify/model.go index 0adc01eaf..37b949bca 100644 --- a/internal/ai/classify/model.go +++ b/internal/ai/classify/model.go @@ -24,8 +24,8 @@ import ( // Model represents a TensorFlow classification model. type Model struct { model *tf.SavedModel - modelPath string - assetsPath string + name string + modelsPath string defaultLabelsPath string labels []string disabled bool @@ -34,14 +34,14 @@ type Model struct { } // NewModel returns new TensorFlow classification model instance. -func NewModel(assetsPath, modelPath, defaultLabelsPath string, meta *tensorflow.ModelInfo, disabled bool) *Model { +func NewModel(modelsPath, name, defaultLabelsPath string, meta *tensorflow.ModelInfo, disabled bool) *Model { if meta == nil { meta = new(tensorflow.ModelInfo) } return &Model{ - modelPath: modelPath, - assetsPath: assetsPath, + name: name, + modelsPath: modelsPath, defaultLabelsPath: defaultLabelsPath, meta: meta, disabled: disabled, @@ -49,8 +49,8 @@ func NewModel(assetsPath, modelPath, defaultLabelsPath string, meta *tensorflow. } // NewNasnet returns new Nasnet TensorFlow classification model instance. -func NewNasnet(assetsPath string, disabled bool) *Model { - return NewModel(assetsPath, "nasnet", "", &tensorflow.ModelInfo{ +func NewNasnet(modelsPath string, disabled bool) *Model { + return NewModel(modelsPath, "nasnet", "", &tensorflow.ModelInfo{ TFVersion: "1.12.0", Tags: []string{"photoprism"}, Input: &tensorflow.PhotoInput{ @@ -194,7 +194,7 @@ func (m *Model) loadModel() (err error) { return nil } - modelPath := path.Join(m.assetsPath, m.modelPath) + modelPath := path.Join(m.modelsPath, m.name) if len(m.meta.Tags) == 0 { infos, modelErr := tensorflow.GetModelInfo(modelPath) diff --git a/internal/ai/classify/model_external_test.go b/internal/ai/classify/model_external_test.go index 8c203116a..e9b75b9d8 100644 --- a/internal/ai/classify/model_external_test.go +++ b/internal/ai/classify/model_external_test.go @@ -259,7 +259,7 @@ func assertContainsAny(t *testing.T, s string, substrings []string) { func testModel_LabelsFromFile(t *testing.T, tensorFlow *Model) { testName := func(name string) string { - return fmt.Sprintf("%s/%s", tensorFlow.modelPath, name) + return fmt.Sprintf("%s/%s", tensorFlow.name, name) } t.Run(testName("chameleon_lime.jpg"), func(t *testing.T) { @@ -364,7 +364,7 @@ func testModel_Run(t *testing.T, tensorFlow *Model) { } testName := func(name string) string { - return fmt.Sprintf("%s/%s", tensorFlow.modelPath, name) + return fmt.Sprintf("%s/%s", tensorFlow.name, name) } t.Run(testName("chameleon_lime.jpg"), func(t *testing.T) { diff --git a/internal/ai/classify/model_test.go b/internal/ai/classify/model_test.go index 092a225e3..93f86520b 100644 --- a/internal/ai/classify/model_test.go +++ b/internal/ai/classify/model_test.go @@ -2,6 +2,7 @@ package classify import ( "os" + "path/filepath" "sync" "testing" @@ -12,14 +13,15 @@ import ( ) var assetsPath = fs.Abs("../../../assets") -var modelPath = assetsPath + "/nasnet" -var examplesPath = assetsPath + "/examples" +var examplesPath = filepath.Join(assetsPath, "examples") +var modelsPath = filepath.Join(assetsPath, "models") +var modelPath = modelsPath + "/nasnet" var once sync.Once var testInstance *Model func NewModelTest(t *testing.T) *Model { once.Do(func() { - testInstance = NewNasnet(assetsPath, false) + testInstance = NewNasnet(modelsPath, false) if err := testInstance.loadModel(); err != nil { t.Fatal(err) } @@ -29,7 +31,7 @@ func NewModelTest(t *testing.T) *Model { } func TestModel_CenterCrop(t *testing.T) { - model := NewNasnet(assetsPath, false) + model := NewNasnet(modelsPath, false) if err := model.loadModel(); err != nil { t.Fatal(err) } @@ -42,7 +44,7 @@ func TestModel_CenterCrop(t *testing.T) { } func TestModel_Padding(t *testing.T) { - model := NewNasnet(assetsPath, false) + model := NewNasnet(modelsPath, false) if err := model.loadModel(); err != nil { t.Fatal(err) } @@ -55,7 +57,7 @@ func TestModel_Padding(t *testing.T) { } func TestModel_ResizeBreakAspectRatio(t *testing.T) { - model := NewNasnet(assetsPath, false) + model := NewNasnet(modelsPath, false) if err := model.loadModel(); err != nil { t.Fatal(err) } @@ -154,7 +156,7 @@ func TestModel_LabelsFromFile(t *testing.T) { assert.Empty(t, result) }) t.Run("Disabled", func(t *testing.T) { - tensorFlow := NewNasnet(assetsPath, true) + tensorFlow := NewNasnet(modelsPath, true) result, err := tensorFlow.File(examplesPath+"/chameleon_lime.jpg", 10) assert.Nil(t, err) @@ -253,7 +255,7 @@ func TestModel_Run(t *testing.T) { } }) t.Run("Disabled", func(t *testing.T) { - tensorFlow := NewNasnet(assetsPath, true) + tensorFlow := NewNasnet(modelsPath, true) if imageBuffer, err := os.ReadFile(examplesPath + "/dog_orange.jpg"); err != nil { t.Error(err) @@ -277,7 +279,7 @@ func TestModel_LoadModel(t *testing.T) { assert.True(t, tf.ModelLoaded()) }) t.Run("NotFound", func(t *testing.T) { - tensorFlow := NewNasnet(assetsPath+"foo", false) + tensorFlow := NewNasnet(modelsPath+"foo", false) err := tensorFlow.loadModel() if err != nil { @@ -290,7 +292,7 @@ func TestModel_LoadModel(t *testing.T) { func TestModel_BestLabels(t *testing.T) { t.Run("Success", func(t *testing.T) { - tensorFlow := NewNasnet(assetsPath, false) + tensorFlow := NewNasnet(modelsPath, false) if err := tensorFlow.loadLabels(modelPath); err != nil { t.Fatal(err) @@ -308,7 +310,7 @@ func TestModel_BestLabels(t *testing.T) { t.Log(result) }) t.Run("NotLoaded", func(t *testing.T) { - tensorFlow := NewNasnet(assetsPath, false) + tensorFlow := NewNasnet(modelsPath, false) p := make([]float32, 1000) diff --git a/internal/ai/face/model.go b/internal/ai/face/model.go index fe8b0d8d3..951c5ecc2 100644 --- a/internal/ai/face/model.go +++ b/internal/ai/face/model.go @@ -19,7 +19,6 @@ import ( // Model is a wrapper for the TensorFlow Facenet model. type Model struct { model *tf.SavedModel - modelName string modelPath string cachePath string resolution int @@ -41,6 +40,7 @@ func NewModel(modelPath, cachePath string, resolution int, meta *tensorflow.Mode if len(meta.Tags) == 0 { meta.Tags = []string{"serve"} } + return &Model{ modelPath: modelPath, cachePath: cachePath, diff --git a/internal/ai/face/model_test.go b/internal/ai/face/model_test.go index dce5ca6cf..d0aab58be 100644 --- a/internal/ai/face/model_test.go +++ b/internal/ai/face/model_test.go @@ -10,7 +10,7 @@ import ( "github.com/photoprism/photoprism/pkg/fs/fastwalk" ) -var modelPath, _ = filepath.Abs("../../../assets/facenet") +var modelPath, _ = filepath.Abs("../../../assets/models/facenet") func TestNet(t *testing.T) { expected := map[string]int{ diff --git a/internal/ai/nsfw/nsfw_test.go b/internal/ai/nsfw/nsfw_test.go index 41959a050..75f61f2f8 100644 --- a/internal/ai/nsfw/nsfw_test.go +++ b/internal/ai/nsfw/nsfw_test.go @@ -11,7 +11,7 @@ import ( "github.com/photoprism/photoprism/pkg/fs/fastwalk" ) -var modelPath, _ = filepath.Abs("../../../assets/nsfw") +var modelPath, _ = filepath.Abs("../../../assets/models/nsfw") var detector = NewModel(modelPath, nil, false) diff --git a/internal/ai/tensorflow/image_test.go b/internal/ai/tensorflow/image_test.go index 20ebff977..026137cd6 100644 --- a/internal/ai/tensorflow/image_test.go +++ b/internal/ai/tensorflow/image_test.go @@ -2,6 +2,7 @@ package tensorflow import ( "os" + "path/filepath" "testing" "github.com/stretchr/testify/assert" @@ -15,6 +16,9 @@ var defaultImageInput = &PhotoInput{ Width: 224, } +var assetsPath = fs.Abs("../../../assets") +var examplesPath = filepath.Join(assetsPath, "examples") + func TestConvertValue(t *testing.T) { result := convertValue(uint32(98765432), &Interval{Start: -1, End: 1}) assert.Equal(t, float32(3024.8982), result) @@ -29,9 +33,6 @@ func TestConvertStdMean(t *testing.T) { } func TestImageFromBytes(t *testing.T) { - var assetsPath = fs.Abs("../../../assets") - var examplesPath = assetsPath + "/examples" - t.Run("CatJpeg", func(t *testing.T) { imageBuffer, err := os.ReadFile(examplesPath + "/cat_brown.jpg") diff --git a/internal/ai/vision/api_client_test.go b/internal/ai/vision/api_client_test.go index f41a7e716..770b8b779 100644 --- a/internal/ai/vision/api_client_test.go +++ b/internal/ai/vision/api_client_test.go @@ -5,14 +5,10 @@ import ( "github.com/stretchr/testify/assert" - "github.com/photoprism/photoprism/pkg/fs" "github.com/photoprism/photoprism/pkg/media/http/scheme" ) func TestNewApiRequest(t *testing.T) { - var assetsPath = fs.Abs("../../../assets") - var examplesPath = assetsPath + "/examples" - t.Run("Data", func(t *testing.T) { thumbnails := Files{examplesPath + "/chameleon_lime.jpg"} result, err := NewApiRequestImages(thumbnails, scheme.Data) diff --git a/internal/ai/vision/config.go b/internal/ai/vision/config.go index ed312ed77..ba25c1d5b 100644 --- a/internal/ai/vision/config.go +++ b/internal/ai/vision/config.go @@ -2,17 +2,17 @@ package vision import ( "net/http" + "path/filepath" "time" + "github.com/photoprism/photoprism/pkg/clean" "github.com/photoprism/photoprism/pkg/fs" "github.com/photoprism/photoprism/pkg/media/http/scheme" ) var ( - AssetsPath = fs.Abs("../../../assets") - FaceNetModelPath = fs.Abs("../../../assets/facenet") - NsfwModelPath = fs.Abs("../../../assets/nsfw") - CachePath = fs.Abs("../../../storage/cache") + CachePath = "" + ModelsPath = "" DownloadUrl = "" ServiceUri = "" ServiceKey = "" @@ -23,3 +23,65 @@ var ( ServiceResponseFormat = ApiFormatVision DefaultResolution = 224 ) + +// SetCachePath updates the cache path. +func SetCachePath(dir string) { + if dir = fs.Abs(dir); dir == "" { + return + } + + CachePath = dir +} + +// GetCachePath returns the cache path. +func GetCachePath() string { + if CachePath != "" { + return CachePath + } + + CachePath = fs.Abs("../../../storage/cache") + + return CachePath +} + +// SetModelsPath updates the model assets path. +func SetModelsPath(dir string) { + if dir = fs.Abs(dir); dir == "" { + return + } + + ModelsPath = dir +} + +// GetModelsPath returns the model assets path, or an empty string if not configured or found. +func GetModelsPath() string { + if ModelsPath != "" { + return ModelsPath + } + + assetsPath := fs.Abs("../../../assets") + + if dir := filepath.Join(assetsPath, "models"); fs.PathExists(dir) { + ModelsPath = dir + } else if fs.PathExists(assetsPath) { + ModelsPath = assetsPath + } + + return ModelsPath +} + +func GetModelPath(name string) string { + return filepath.Join(GetModelsPath(), clean.Path(clean.TypeLowerUnderscore(name))) +} + +func GetNasnetModelPath() string { + return GetModelPath(NasnetModel.Name) +} + +func GetFacenetModelPath() string { + return GetModelPath(FacenetModel.Name) +} + +func GetNsfwModelPath() string { + return GetModelPath(NsfwModel.Name) +} diff --git a/internal/ai/vision/labels_test.go b/internal/ai/vision/labels_test.go index d68418a0b..540018607 100644 --- a/internal/ai/vision/labels_test.go +++ b/internal/ai/vision/labels_test.go @@ -7,14 +7,10 @@ import ( "github.com/photoprism/photoprism/internal/ai/classify" "github.com/photoprism/photoprism/internal/entity" - "github.com/photoprism/photoprism/pkg/fs" "github.com/photoprism/photoprism/pkg/media" ) func TestLabels(t *testing.T) { - var assetsPath = fs.Abs("../../../assets") - var examplesPath = assetsPath + "/examples" - t.Run("Success", func(t *testing.T) { result, err := Labels(Files{examplesPath + "/chameleon_lime.jpg"}, media.SrcLocal, entity.SrcAuto) diff --git a/internal/ai/vision/model.go b/internal/ai/vision/model.go index ee11e9159..9f6a963bc 100644 --- a/internal/ai/vision/model.go +++ b/internal/ai/vision/model.go @@ -2,7 +2,6 @@ package vision import ( "fmt" - "path/filepath" "strings" "sync" @@ -143,7 +142,7 @@ func (m *Model) ClassifyModel() *classify.Model { return nil case NasnetModel.Name, "nasnet": // Load and initialize the Nasnet image classification model. - if model := classify.NewNasnet(AssetsPath, m.Disabled); model == nil { + if model := classify.NewNasnet(GetModelsPath(), m.Disabled); model == nil { return nil } else if err := model.Init(); err != nil { log.Errorf("vision: %s (init nasnet model)", err) @@ -154,7 +153,7 @@ func (m *Model) ClassifyModel() *classify.Model { default: // Set model path from model name if no path is configured. if m.Path == "" { - m.Path = clean.TypeLowerUnderscore(m.Name) + m.Path = clean.Path(clean.TypeLowerUnderscore(m.Name)) } if m.Meta == nil { @@ -173,8 +172,7 @@ func (m *Model) ClassifyModel() *classify.Model { m.Meta.Input.SetResolution(m.Resolution) // Try to load custom model based on the configuration values. - defaultPath := filepath.Join(AssetsPath, "nasnet") - if model := classify.NewModel(AssetsPath, m.Path, defaultPath, m.Meta, m.Disabled); model == nil { + if model := classify.NewModel(GetModelsPath(), m.Path, GetNasnetModelPath(), m.Meta, m.Disabled); model == nil { return nil } else if err := model.Init(); err != nil { log.Errorf("vision: %s (init %s)", err, m.Path) @@ -205,7 +203,7 @@ func (m *Model) FaceModel() *face.Model { return nil case FacenetModel.Name, "facenet": // Load and initialize the Nasnet image classification model. - if model := face.NewModel(FaceNetModelPath, CachePath, m.Resolution, m.Meta, m.Disabled); model == nil { + if model := face.NewModel(GetFacenetModelPath(), GetCachePath(), m.Resolution, m.Meta, m.Disabled); model == nil { return nil } else if err := model.Init(); err != nil { log.Errorf("vision: %s (init %s)", err, m.Path) @@ -216,7 +214,7 @@ func (m *Model) FaceModel() *face.Model { default: // Set model path from model name if no path is configured. if m.Path == "" { - m.Path = clean.TypeLowerUnderscore(m.Name) + m.Path = clean.Path(clean.TypeLowerUnderscore(m.Name)) } // Set default thumbnail resolution if no tags are configured. @@ -229,7 +227,7 @@ func (m *Model) FaceModel() *face.Model { } // Try to load custom model based on the configuration values. - if model := face.NewModel(filepath.Join(AssetsPath, m.Path), CachePath, m.Resolution, m.Meta, m.Disabled); model == nil { + if model := face.NewModel(GetModelPath(m.Path), GetCachePath(), m.Resolution, m.Meta, m.Disabled); model == nil { return nil } else if err := model.Init(); err != nil { log.Errorf("vision: %s (init %s)", err, m.Path) @@ -260,7 +258,7 @@ func (m *Model) NsfwModel() *nsfw.Model { return nil case NsfwModel.Name, "nsfw": // Load and initialize the Nasnet image classification model. - if model := nsfw.NewModel(NsfwModelPath, NsfwModel.Meta, m.Disabled); model == nil { + if model := nsfw.NewModel(GetNsfwModelPath(), NsfwModel.Meta, m.Disabled); model == nil { return nil } else if err := model.Init(); err != nil { log.Errorf("vision: %s (init %s)", err, m.Path) @@ -271,7 +269,7 @@ func (m *Model) NsfwModel() *nsfw.Model { default: // Set model path from model name if no path is configured. if m.Path == "" { - m.Path = clean.TypeLowerUnderscore(m.Name) + m.Path = clean.Path(clean.TypeLowerUnderscore(m.Name)) } // Set default thumbnail resolution if no tags are configured. @@ -290,7 +288,7 @@ func (m *Model) NsfwModel() *nsfw.Model { } // Try to load custom model based on the configuration values. - if model := nsfw.NewModel(filepath.Join(AssetsPath, m.Path), m.Meta, m.Disabled); model == nil { + if model := nsfw.NewModel(GetModelPath(m.Path), m.Meta, m.Disabled); model == nil { return nil } else if err := model.Init(); err != nil { log.Errorf("vision: %s (init %s)", err, m.Path) diff --git a/internal/ai/vision/vision_test.go b/internal/ai/vision/vision_test.go index 717e98e6e..8e84c3b03 100644 --- a/internal/ai/vision/vision_test.go +++ b/internal/ai/vision/vision_test.go @@ -2,20 +2,25 @@ package vision import ( "os" + "path/filepath" "testing" "github.com/sirupsen/logrus" "github.com/photoprism/photoprism/internal/api/download" "github.com/photoprism/photoprism/internal/event" + "github.com/photoprism/photoprism/pkg/fs" ) +var assetsPath = fs.Abs("../../../assets") +var examplesPath = filepath.Join(assetsPath, "examples") + func TestMain(m *testing.M) { // Init test logger. log = logrus.StandardLogger() log.SetLevel(logrus.TraceLevel) event.AuditLog = log - download.AllowedPaths = append(download.AllowedPaths, AssetsPath) + download.AllowedPaths = append(download.AllowedPaths, assetsPath) // Set test config values. DownloadUrl = "https://app.localssl.dev/api/v1/dl" diff --git a/internal/config/config.go b/internal/config/config.go index e37edbd78..2a891797a 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -291,10 +291,8 @@ func (c *Config) Propagate() { dl.FFprobeBin = c.FFprobeBin() // Configure computer vision package. - vision.AssetsPath = c.AssetsPath() - vision.FaceNetModelPath = c.FaceNetModelPath() - vision.NsfwModelPath = c.NSFWModelPath() - vision.CachePath = c.CachePath() + vision.SetCachePath(c.CachePath()) + vision.SetModelsPath(c.ModelsPath()) vision.ServiceUri = c.VisionUri() vision.ServiceKey = c.VisionKey() vision.DownloadUrl = c.DownloadUrl() diff --git a/internal/config/config_storage.go b/internal/config/config_storage.go index 183e26abb..b5b588976 100644 --- a/internal/config/config_storage.go +++ b/internal/config/config_storage.go @@ -195,9 +195,9 @@ func (c *Config) CreateDirectories() error { return createError(dir, err) } - // Create TensorFlow model path if it doesn't exist yet. - if dir := c.NasnetModelPath(); dir == "" { - return notFoundError("tensorflow model") + // Create computer vision models path if it doesn't exist yet. + if dir := c.ModelsPath(); dir == "" { + return notFoundError("models") } else if err := fs.MkdirAll(dir); err != nil { return createError(dir, err) } diff --git a/internal/config/config_vision.go b/internal/config/config_vision.go index fb37b19c8..76f5bf9a6 100644 --- a/internal/config/config_vision.go +++ b/internal/config/config_vision.go @@ -43,19 +43,35 @@ func (c *Config) VisionKey() string { } } +// ModelsPath returns the path where the machine learning models are located. +func (c *Config) ModelsPath() string { + if c.options.ModelsPath != "" { + return fs.Abs(c.options.ModelsPath) + } + + if dir := filepath.Join(c.AssetsPath(), "models"); fs.PathExists(dir) { + c.options.ModelsPath = dir + return c.options.ModelsPath + } + + c.options.ModelsPath = fs.FindDir(fs.ModelsPaths) + + return c.options.ModelsPath +} + // NasnetModelPath returns the TensorFlow model path. func (c *Config) NasnetModelPath() string { - return filepath.Join(c.AssetsPath(), "nasnet") + return filepath.Join(c.ModelsPath(), "nasnet") } -// FaceNetModelPath returns the FaceNet model path. -func (c *Config) FaceNetModelPath() string { - return filepath.Join(c.AssetsPath(), "facenet") +// FacenetModelPath returns the FaceNet model path. +func (c *Config) FacenetModelPath() string { + return filepath.Join(c.ModelsPath(), "facenet") } -// NSFWModelPath returns the "not safe for work" TensorFlow model path. -func (c *Config) NSFWModelPath() string { - return filepath.Join(c.AssetsPath(), "nsfw") +// NsfwModelPath returns the "not safe for work" TensorFlow model path. +func (c *Config) NsfwModelPath() string { + return filepath.Join(c.ModelsPath(), "nsfw") } // DetectNSFW checks if NSFW photos should be detected and flagged. diff --git a/internal/config/config_vision_test.go b/internal/config/config_vision_test.go index e7c70209e..36e10fcdb 100644 --- a/internal/config/config_vision_test.go +++ b/internal/config/config_vision_test.go @@ -1,6 +1,7 @@ package config import ( + "strings" "testing" "github.com/stretchr/testify/assert" @@ -34,11 +35,12 @@ func TestConfig_VisionKey(t *testing.T) { assert.Equal(t, "", c.VisionKey()) } -func TestConfig_TensorFlowModelPath(t *testing.T) { +func TestConfig_ModelsPath(t *testing.T) { c := NewConfig(CliTestContext()) path := c.NasnetModelPath() - assert.Equal(t, "/go/src/github.com/photoprism/photoprism/assets/nasnet", path) + assert.True(t, strings.HasPrefix(path, c.ModelsPath())) + assert.Equal(t, "/go/src/github.com/photoprism/photoprism/assets/models/nasnet", path) } func TestConfig_TensorFlowDisabled(t *testing.T) { @@ -51,13 +53,13 @@ func TestConfig_TensorFlowDisabled(t *testing.T) { func TestConfig_NSFWModelPath(t *testing.T) { c := NewConfig(CliTestContext()) - assert.Contains(t, c.NSFWModelPath(), "/assets/nsfw") + assert.Contains(t, c.NsfwModelPath(), "/assets/models/nsfw") } func TestConfig_FaceNetModelPath(t *testing.T) { c := NewConfig(CliTestContext()) - assert.Contains(t, c.FaceNetModelPath(), "/assets/facenet") + assert.Contains(t, c.FacenetModelPath(), "/assets/models/facenet") } func TestConfig_DetectNSFW(t *testing.T) { diff --git a/internal/config/flags.go b/internal/config/flags.go index 2ce30ba64..d8ed81c51 100644 --- a/internal/config/flags.go +++ b/internal/config/flags.go @@ -298,6 +298,12 @@ var Flags = CliFlags{ EnvVars: EnvVars("ASSETS_PATH"), TakesFile: true, }}, { + Flag: &cli.PathFlag{ + Name: "models-path", + Usage: "custom model assets `PATH` where computer vision models are located", + EnvVars: EnvVars("MODELS_PATH"), + TakesFile: true, + }}, { Flag: &cli.PathFlag{ Name: "sidecar-path", Aliases: []string{"sc"}, diff --git a/internal/config/options.go b/internal/config/options.go index a6793284b..b61c09359 100644 --- a/internal/config/options.go +++ b/internal/config/options.go @@ -77,6 +77,7 @@ type Options struct { TempPath string `yaml:"TempPath" json:"-" flag:"temp-path"` AssetsPath string `yaml:"AssetsPath" json:"-" flag:"assets-path"` CustomAssetsPath string `yaml:"-" json:"-" flag:"custom-assets-path" tags:"plus,pro"` + ModelsPath string `yaml:"ModelsPath" json:"-" flag:"models-path"` SidecarPath string `yaml:"SidecarPath" json:"-" flag:"sidecar-path"` SidecarYaml bool `yaml:"SidecarYaml" json:"SidecarYaml" flag:"sidecar-yaml" default:"true"` UsageInfo bool `yaml:"UsageInfo" json:"UsageInfo" flag:"usage-info"` diff --git a/internal/config/report.go b/internal/config/report.go index a53891834..055d9cbc2 100644 --- a/internal/config/report.go +++ b/internal/config/report.go @@ -76,6 +76,7 @@ func (c *Config) Report() (rows [][]string, cols []string) { {"thumb-cache-path", c.ThumbCachePath()}, {"temp-path", c.TempPath()}, {"assets-path", c.AssetsPath()}, + {"models-path", c.ModelsPath()}, {"static-path", c.StaticPath()}, {"build-path", c.BuildPath()}, {"img-path", c.ImgPath()}, @@ -256,8 +257,8 @@ func (c *Config) Report() (rows [][]string, cols []string) { {"vision-uri", c.VisionUri()}, {"vision-key", strings.Repeat("*", utf8.RuneCountInString(c.VisionKey()))}, {"nasnet-model-path", c.NasnetModelPath()}, - {"facenet-model-path", c.FaceNetModelPath()}, - {"nsfw-model-path", c.NSFWModelPath()}, + {"facenet-model-path", c.FacenetModelPath()}, + {"nsfw-model-path", c.NsfwModelPath()}, {"detect-nsfw", fmt.Sprintf("%t", c.DetectNSFW())}, // Facial Recognition. diff --git a/pkg/fs/directories.go b/pkg/fs/directories.go index d96c83833..d3ac7993f 100644 --- a/pkg/fs/directories.go +++ b/pkg/fs/directories.go @@ -120,6 +120,16 @@ var AssetPaths = []string{ "/var/lib/photoprism/assets", } +var ModelsPaths = []string{ + "/opt/photoprism/assets/models", + "/photoprism/assets/models", + "~/.photoprism/assets/models", + "~/photoprism/assets/models", + "photoprism/assets/models", + "assets/models", + "/var/lib/photoprism/assets/models", +} + // Dirs returns a slice of directories in a path, optional recursively and with symlinks. // // Warning: Following symlinks can make the result non-deterministic and hard to test! diff --git a/scripts/download-facenet.sh b/scripts/download-facenet.sh index cde93b402..93c975226 100755 --- a/scripts/download-facenet.sh +++ b/scripts/download-facenet.sh @@ -4,7 +4,8 @@ TODAY=$(date -u +%Y%m%d) MODEL_NAME="Facenet" MODEL_URL="https://dl.photoprism.app/tensorflow/facenet.zip?$TODAY" -MODEL_PATH="assets/facenet" +MODELS_PATH="assets/models" +MODEL_PATH="$MODELS_PATH/facenet" MODEL_ZIP="/tmp/photoprism/facenet.zip" MODEL_HASH="0492eb1d67789108b7eefb274e26633504b059be $MODEL_ZIP" MODEL_VERSION="$MODEL_PATH/version.txt" @@ -17,7 +18,7 @@ mkdir -p /tmp/photoprism mkdir -p storage/backup # Check for update -if [[ -f ${MODEL_ZIP} ]] && [[ $(sha1sum ${MODEL_ZIP}) == ${MODEL_HASH} ]]; then +if [[ -f ${MODEL_ZIP} ]] && [[ $(sha1sum ${MODEL_ZIP}) == "${MODEL_HASH}" ]]; then if [[ -f ${MODEL_VERSION} ]]; then echo "Already up to date." exit @@ -40,7 +41,7 @@ if [[ -e ${MODEL_PATH} ]]; then fi # Unzip model -unzip ${MODEL_ZIP} -d assets +unzip ${MODEL_ZIP} -d "$MODELS_PATH" echo "$MODEL_NAME $TODAY $MODEL_HASH" > ${MODEL_VERSION} echo "Latest $MODEL_NAME installed." diff --git a/scripts/download-nasnet.sh b/scripts/download-nasnet.sh index ce423cf73..2d8569964 100755 --- a/scripts/download-nasnet.sh +++ b/scripts/download-nasnet.sh @@ -4,7 +4,8 @@ TODAY=$(date -u +%Y%m%d) MODEL_NAME="NASNet Mobile" MODEL_URL="https://dl.photoprism.app/tensorflow/nasnet.zip?$TODAY" -MODEL_PATH="assets/nasnet" +MODELS_PATH="assets/models" +MODEL_PATH="$MODELS_PATH/nasnet" MODEL_ZIP="/tmp/photoprism/nasnet.zip" MODEL_HASH="f18b801354e95cade497b4f12e8d2537d04c04f6 $MODEL_ZIP" MODEL_VERSION="$MODEL_PATH/version.txt" @@ -41,7 +42,7 @@ if [[ -e ${MODEL_PATH} ]]; then fi # Unzip model -unzip ${MODEL_ZIP} -d assets +unzip ${MODEL_ZIP} -d "$MODELS_PATH" echo "$MODEL_NAME $TODAY $MODEL_HASH" > ${MODEL_VERSION} wget --inet4-only -c "${MODEL_21K_LABELS_URL}" -O ${MODEL_PATH}/labels21k.txt diff --git a/scripts/download-nsfw.sh b/scripts/download-nsfw.sh index 9875b4b02..c0f61a2bd 100755 --- a/scripts/download-nsfw.sh +++ b/scripts/download-nsfw.sh @@ -4,7 +4,8 @@ TODAY=$(date -u +%Y%m%d) MODEL_NAME="NSFW" MODEL_URL="https://dl.photoprism.app/tensorflow/nsfw.zip?$TODAY" -MODEL_PATH="assets/nsfw" +MODELS_PATH="assets/models" +MODEL_PATH="$MODELS_PATH/nsfw" MODEL_ZIP="/tmp/photoprism/nsfw.zip" MODEL_HASH="2e03ad3c6aec27c270c650d0574ff2a6291d992b $MODEL_ZIP" MODEL_VERSION="$MODEL_PATH/version.txt" @@ -40,7 +41,7 @@ if [[ -e ${MODEL_PATH} ]]; then fi # Unzip model -unzip ${MODEL_ZIP} -d assets +unzip ${MODEL_ZIP} -d "$MODELS_PATH" echo "$MODEL_NAME $TODAY $MODEL_HASH" > ${MODEL_VERSION} echo "Latest $MODEL_NAME installed."