mirror of
https://github.com/photoprism/photoprism.git
synced 2025-09-26 12:51:31 +08:00
Config: Change default vision model assets path to assets/models/ #127
Signed-off-by: Michael Mayer <michael@photoprism.app>
This commit is contained in:
@@ -11,6 +11,7 @@
|
||||
/assets/imagenet
|
||||
/assets/resnet
|
||||
/assets/vision
|
||||
/assets/models
|
||||
/storage
|
||||
/build
|
||||
/photoprism
|
||||
|
4
Makefile
4
Makefile
@@ -73,6 +73,7 @@ pull: docker-pull
|
||||
test: test-js test-go
|
||||
test-go: reset-sqlite run-test-go
|
||||
test-pkg: reset-sqlite run-test-pkg
|
||||
test-ai: reset-sqlite run-test-ai
|
||||
test-api: reset-sqlite run-test-api
|
||||
test-video: reset-sqlite run-test-video
|
||||
test-entity: reset-sqlite run-test-entity
|
||||
@@ -395,6 +396,9 @@ run-test-mariadb:
|
||||
run-test-pkg:
|
||||
$(info Running all Go tests in "/pkg"...)
|
||||
$(GOTEST) -parallel 2 -count 1 -cpu 2 -tags="slow,develop" -timeout 20m ./pkg/...
|
||||
run-test-ai:
|
||||
$(info Running all AI tests...)
|
||||
$(GOTEST) -parallel 2 -count 1 -cpu 2 -tags="slow,develop" -timeout 20m ./internal/ai/...
|
||||
run-test-api:
|
||||
$(info Running all API tests...)
|
||||
$(GOTEST) -parallel 2 -count 1 -cpu 2 -tags="slow,develop" -timeout 20m ./internal/api/...
|
||||
|
@@ -1,8 +1,8 @@
|
||||
examples
|
||||
efficientnet
|
||||
imagenet
|
||||
resnet
|
||||
vision
|
||||
models/_*
|
||||
models/dev*
|
||||
models/local*
|
||||
models/test*
|
||||
README.md
|
||||
docs
|
||||
.*
|
2
assets/models/.gitignore
vendored
Normal file
2
assets/models/.gitignore
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
*
|
||||
!.gitignore
|
@@ -24,8 +24,8 @@ import (
|
||||
// Model represents a TensorFlow classification model.
|
||||
type Model struct {
|
||||
model *tf.SavedModel
|
||||
modelPath string
|
||||
assetsPath string
|
||||
name string
|
||||
modelsPath string
|
||||
defaultLabelsPath string
|
||||
labels []string
|
||||
disabled bool
|
||||
@@ -34,14 +34,14 @@ type Model struct {
|
||||
}
|
||||
|
||||
// NewModel returns new TensorFlow classification model instance.
|
||||
func NewModel(assetsPath, modelPath, defaultLabelsPath string, meta *tensorflow.ModelInfo, disabled bool) *Model {
|
||||
func NewModel(modelsPath, name, defaultLabelsPath string, meta *tensorflow.ModelInfo, disabled bool) *Model {
|
||||
if meta == nil {
|
||||
meta = new(tensorflow.ModelInfo)
|
||||
}
|
||||
|
||||
return &Model{
|
||||
modelPath: modelPath,
|
||||
assetsPath: assetsPath,
|
||||
name: name,
|
||||
modelsPath: modelsPath,
|
||||
defaultLabelsPath: defaultLabelsPath,
|
||||
meta: meta,
|
||||
disabled: disabled,
|
||||
@@ -49,8 +49,8 @@ func NewModel(assetsPath, modelPath, defaultLabelsPath string, meta *tensorflow.
|
||||
}
|
||||
|
||||
// NewNasnet returns new Nasnet TensorFlow classification model instance.
|
||||
func NewNasnet(assetsPath string, disabled bool) *Model {
|
||||
return NewModel(assetsPath, "nasnet", "", &tensorflow.ModelInfo{
|
||||
func NewNasnet(modelsPath string, disabled bool) *Model {
|
||||
return NewModel(modelsPath, "nasnet", "", &tensorflow.ModelInfo{
|
||||
TFVersion: "1.12.0",
|
||||
Tags: []string{"photoprism"},
|
||||
Input: &tensorflow.PhotoInput{
|
||||
@@ -194,7 +194,7 @@ func (m *Model) loadModel() (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
modelPath := path.Join(m.assetsPath, m.modelPath)
|
||||
modelPath := path.Join(m.modelsPath, m.name)
|
||||
|
||||
if len(m.meta.Tags) == 0 {
|
||||
infos, modelErr := tensorflow.GetModelInfo(modelPath)
|
||||
|
@@ -259,7 +259,7 @@ func assertContainsAny(t *testing.T, s string, substrings []string) {
|
||||
|
||||
func testModel_LabelsFromFile(t *testing.T, tensorFlow *Model) {
|
||||
testName := func(name string) string {
|
||||
return fmt.Sprintf("%s/%s", tensorFlow.modelPath, name)
|
||||
return fmt.Sprintf("%s/%s", tensorFlow.name, name)
|
||||
}
|
||||
|
||||
t.Run(testName("chameleon_lime.jpg"), func(t *testing.T) {
|
||||
@@ -364,7 +364,7 @@ func testModel_Run(t *testing.T, tensorFlow *Model) {
|
||||
}
|
||||
|
||||
testName := func(name string) string {
|
||||
return fmt.Sprintf("%s/%s", tensorFlow.modelPath, name)
|
||||
return fmt.Sprintf("%s/%s", tensorFlow.name, name)
|
||||
}
|
||||
|
||||
t.Run(testName("chameleon_lime.jpg"), func(t *testing.T) {
|
||||
|
@@ -2,6 +2,7 @@ package classify
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
@@ -12,14 +13,15 @@ import (
|
||||
)
|
||||
|
||||
var assetsPath = fs.Abs("../../../assets")
|
||||
var modelPath = assetsPath + "/nasnet"
|
||||
var examplesPath = assetsPath + "/examples"
|
||||
var examplesPath = filepath.Join(assetsPath, "examples")
|
||||
var modelsPath = filepath.Join(assetsPath, "models")
|
||||
var modelPath = modelsPath + "/nasnet"
|
||||
var once sync.Once
|
||||
var testInstance *Model
|
||||
|
||||
func NewModelTest(t *testing.T) *Model {
|
||||
once.Do(func() {
|
||||
testInstance = NewNasnet(assetsPath, false)
|
||||
testInstance = NewNasnet(modelsPath, false)
|
||||
if err := testInstance.loadModel(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -29,7 +31,7 @@ func NewModelTest(t *testing.T) *Model {
|
||||
}
|
||||
|
||||
func TestModel_CenterCrop(t *testing.T) {
|
||||
model := NewNasnet(assetsPath, false)
|
||||
model := NewNasnet(modelsPath, false)
|
||||
if err := model.loadModel(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -42,7 +44,7 @@ func TestModel_CenterCrop(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestModel_Padding(t *testing.T) {
|
||||
model := NewNasnet(assetsPath, false)
|
||||
model := NewNasnet(modelsPath, false)
|
||||
if err := model.loadModel(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -55,7 +57,7 @@ func TestModel_Padding(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestModel_ResizeBreakAspectRatio(t *testing.T) {
|
||||
model := NewNasnet(assetsPath, false)
|
||||
model := NewNasnet(modelsPath, false)
|
||||
if err := model.loadModel(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -154,7 +156,7 @@ func TestModel_LabelsFromFile(t *testing.T) {
|
||||
assert.Empty(t, result)
|
||||
})
|
||||
t.Run("Disabled", func(t *testing.T) {
|
||||
tensorFlow := NewNasnet(assetsPath, true)
|
||||
tensorFlow := NewNasnet(modelsPath, true)
|
||||
|
||||
result, err := tensorFlow.File(examplesPath+"/chameleon_lime.jpg", 10)
|
||||
assert.Nil(t, err)
|
||||
@@ -253,7 +255,7 @@ func TestModel_Run(t *testing.T) {
|
||||
}
|
||||
})
|
||||
t.Run("Disabled", func(t *testing.T) {
|
||||
tensorFlow := NewNasnet(assetsPath, true)
|
||||
tensorFlow := NewNasnet(modelsPath, true)
|
||||
|
||||
if imageBuffer, err := os.ReadFile(examplesPath + "/dog_orange.jpg"); err != nil {
|
||||
t.Error(err)
|
||||
@@ -277,7 +279,7 @@ func TestModel_LoadModel(t *testing.T) {
|
||||
assert.True(t, tf.ModelLoaded())
|
||||
})
|
||||
t.Run("NotFound", func(t *testing.T) {
|
||||
tensorFlow := NewNasnet(assetsPath+"foo", false)
|
||||
tensorFlow := NewNasnet(modelsPath+"foo", false)
|
||||
err := tensorFlow.loadModel()
|
||||
|
||||
if err != nil {
|
||||
@@ -290,7 +292,7 @@ func TestModel_LoadModel(t *testing.T) {
|
||||
|
||||
func TestModel_BestLabels(t *testing.T) {
|
||||
t.Run("Success", func(t *testing.T) {
|
||||
tensorFlow := NewNasnet(assetsPath, false)
|
||||
tensorFlow := NewNasnet(modelsPath, false)
|
||||
|
||||
if err := tensorFlow.loadLabels(modelPath); err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -308,7 +310,7 @@ func TestModel_BestLabels(t *testing.T) {
|
||||
t.Log(result)
|
||||
})
|
||||
t.Run("NotLoaded", func(t *testing.T) {
|
||||
tensorFlow := NewNasnet(assetsPath, false)
|
||||
tensorFlow := NewNasnet(modelsPath, false)
|
||||
|
||||
p := make([]float32, 1000)
|
||||
|
||||
|
@@ -19,7 +19,6 @@ import (
|
||||
// Model is a wrapper for the TensorFlow Facenet model.
|
||||
type Model struct {
|
||||
model *tf.SavedModel
|
||||
modelName string
|
||||
modelPath string
|
||||
cachePath string
|
||||
resolution int
|
||||
@@ -41,6 +40,7 @@ func NewModel(modelPath, cachePath string, resolution int, meta *tensorflow.Mode
|
||||
if len(meta.Tags) == 0 {
|
||||
meta.Tags = []string{"serve"}
|
||||
}
|
||||
|
||||
return &Model{
|
||||
modelPath: modelPath,
|
||||
cachePath: cachePath,
|
||||
|
@@ -10,7 +10,7 @@ import (
|
||||
"github.com/photoprism/photoprism/pkg/fs/fastwalk"
|
||||
)
|
||||
|
||||
var modelPath, _ = filepath.Abs("../../../assets/facenet")
|
||||
var modelPath, _ = filepath.Abs("../../../assets/models/facenet")
|
||||
|
||||
func TestNet(t *testing.T) {
|
||||
expected := map[string]int{
|
||||
|
@@ -11,7 +11,7 @@ import (
|
||||
"github.com/photoprism/photoprism/pkg/fs/fastwalk"
|
||||
)
|
||||
|
||||
var modelPath, _ = filepath.Abs("../../../assets/nsfw")
|
||||
var modelPath, _ = filepath.Abs("../../../assets/models/nsfw")
|
||||
|
||||
var detector = NewModel(modelPath, nil, false)
|
||||
|
||||
|
@@ -2,6 +2,7 @@ package tensorflow
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -15,6 +16,9 @@ var defaultImageInput = &PhotoInput{
|
||||
Width: 224,
|
||||
}
|
||||
|
||||
var assetsPath = fs.Abs("../../../assets")
|
||||
var examplesPath = filepath.Join(assetsPath, "examples")
|
||||
|
||||
func TestConvertValue(t *testing.T) {
|
||||
result := convertValue(uint32(98765432), &Interval{Start: -1, End: 1})
|
||||
assert.Equal(t, float32(3024.8982), result)
|
||||
@@ -29,9 +33,6 @@ func TestConvertStdMean(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestImageFromBytes(t *testing.T) {
|
||||
var assetsPath = fs.Abs("../../../assets")
|
||||
var examplesPath = assetsPath + "/examples"
|
||||
|
||||
t.Run("CatJpeg", func(t *testing.T) {
|
||||
imageBuffer, err := os.ReadFile(examplesPath + "/cat_brown.jpg")
|
||||
|
||||
|
@@ -5,14 +5,10 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/photoprism/photoprism/pkg/fs"
|
||||
"github.com/photoprism/photoprism/pkg/media/http/scheme"
|
||||
)
|
||||
|
||||
func TestNewApiRequest(t *testing.T) {
|
||||
var assetsPath = fs.Abs("../../../assets")
|
||||
var examplesPath = assetsPath + "/examples"
|
||||
|
||||
t.Run("Data", func(t *testing.T) {
|
||||
thumbnails := Files{examplesPath + "/chameleon_lime.jpg"}
|
||||
result, err := NewApiRequestImages(thumbnails, scheme.Data)
|
||||
|
@@ -2,17 +2,17 @@ package vision
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/photoprism/photoprism/pkg/clean"
|
||||
"github.com/photoprism/photoprism/pkg/fs"
|
||||
"github.com/photoprism/photoprism/pkg/media/http/scheme"
|
||||
)
|
||||
|
||||
var (
|
||||
AssetsPath = fs.Abs("../../../assets")
|
||||
FaceNetModelPath = fs.Abs("../../../assets/facenet")
|
||||
NsfwModelPath = fs.Abs("../../../assets/nsfw")
|
||||
CachePath = fs.Abs("../../../storage/cache")
|
||||
CachePath = ""
|
||||
ModelsPath = ""
|
||||
DownloadUrl = ""
|
||||
ServiceUri = ""
|
||||
ServiceKey = ""
|
||||
@@ -23,3 +23,65 @@ var (
|
||||
ServiceResponseFormat = ApiFormatVision
|
||||
DefaultResolution = 224
|
||||
)
|
||||
|
||||
// SetCachePath updates the cache path.
|
||||
func SetCachePath(dir string) {
|
||||
if dir = fs.Abs(dir); dir == "" {
|
||||
return
|
||||
}
|
||||
|
||||
CachePath = dir
|
||||
}
|
||||
|
||||
// GetCachePath returns the cache path.
|
||||
func GetCachePath() string {
|
||||
if CachePath != "" {
|
||||
return CachePath
|
||||
}
|
||||
|
||||
CachePath = fs.Abs("../../../storage/cache")
|
||||
|
||||
return CachePath
|
||||
}
|
||||
|
||||
// SetModelsPath updates the model assets path.
|
||||
func SetModelsPath(dir string) {
|
||||
if dir = fs.Abs(dir); dir == "" {
|
||||
return
|
||||
}
|
||||
|
||||
ModelsPath = dir
|
||||
}
|
||||
|
||||
// GetModelsPath returns the model assets path, or an empty string if not configured or found.
|
||||
func GetModelsPath() string {
|
||||
if ModelsPath != "" {
|
||||
return ModelsPath
|
||||
}
|
||||
|
||||
assetsPath := fs.Abs("../../../assets")
|
||||
|
||||
if dir := filepath.Join(assetsPath, "models"); fs.PathExists(dir) {
|
||||
ModelsPath = dir
|
||||
} else if fs.PathExists(assetsPath) {
|
||||
ModelsPath = assetsPath
|
||||
}
|
||||
|
||||
return ModelsPath
|
||||
}
|
||||
|
||||
func GetModelPath(name string) string {
|
||||
return filepath.Join(GetModelsPath(), clean.Path(clean.TypeLowerUnderscore(name)))
|
||||
}
|
||||
|
||||
func GetNasnetModelPath() string {
|
||||
return GetModelPath(NasnetModel.Name)
|
||||
}
|
||||
|
||||
func GetFacenetModelPath() string {
|
||||
return GetModelPath(FacenetModel.Name)
|
||||
}
|
||||
|
||||
func GetNsfwModelPath() string {
|
||||
return GetModelPath(NsfwModel.Name)
|
||||
}
|
||||
|
@@ -7,14 +7,10 @@ import (
|
||||
|
||||
"github.com/photoprism/photoprism/internal/ai/classify"
|
||||
"github.com/photoprism/photoprism/internal/entity"
|
||||
"github.com/photoprism/photoprism/pkg/fs"
|
||||
"github.com/photoprism/photoprism/pkg/media"
|
||||
)
|
||||
|
||||
func TestLabels(t *testing.T) {
|
||||
var assetsPath = fs.Abs("../../../assets")
|
||||
var examplesPath = assetsPath + "/examples"
|
||||
|
||||
t.Run("Success", func(t *testing.T) {
|
||||
result, err := Labels(Files{examplesPath + "/chameleon_lime.jpg"}, media.SrcLocal, entity.SrcAuto)
|
||||
|
||||
|
@@ -2,7 +2,6 @@ package vision
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
@@ -143,7 +142,7 @@ func (m *Model) ClassifyModel() *classify.Model {
|
||||
return nil
|
||||
case NasnetModel.Name, "nasnet":
|
||||
// Load and initialize the Nasnet image classification model.
|
||||
if model := classify.NewNasnet(AssetsPath, m.Disabled); model == nil {
|
||||
if model := classify.NewNasnet(GetModelsPath(), m.Disabled); model == nil {
|
||||
return nil
|
||||
} else if err := model.Init(); err != nil {
|
||||
log.Errorf("vision: %s (init nasnet model)", err)
|
||||
@@ -154,7 +153,7 @@ func (m *Model) ClassifyModel() *classify.Model {
|
||||
default:
|
||||
// Set model path from model name if no path is configured.
|
||||
if m.Path == "" {
|
||||
m.Path = clean.TypeLowerUnderscore(m.Name)
|
||||
m.Path = clean.Path(clean.TypeLowerUnderscore(m.Name))
|
||||
}
|
||||
|
||||
if m.Meta == nil {
|
||||
@@ -173,8 +172,7 @@ func (m *Model) ClassifyModel() *classify.Model {
|
||||
m.Meta.Input.SetResolution(m.Resolution)
|
||||
|
||||
// Try to load custom model based on the configuration values.
|
||||
defaultPath := filepath.Join(AssetsPath, "nasnet")
|
||||
if model := classify.NewModel(AssetsPath, m.Path, defaultPath, m.Meta, m.Disabled); model == nil {
|
||||
if model := classify.NewModel(GetModelsPath(), m.Path, GetNasnetModelPath(), m.Meta, m.Disabled); model == nil {
|
||||
return nil
|
||||
} else if err := model.Init(); err != nil {
|
||||
log.Errorf("vision: %s (init %s)", err, m.Path)
|
||||
@@ -205,7 +203,7 @@ func (m *Model) FaceModel() *face.Model {
|
||||
return nil
|
||||
case FacenetModel.Name, "facenet":
|
||||
// Load and initialize the Nasnet image classification model.
|
||||
if model := face.NewModel(FaceNetModelPath, CachePath, m.Resolution, m.Meta, m.Disabled); model == nil {
|
||||
if model := face.NewModel(GetFacenetModelPath(), GetCachePath(), m.Resolution, m.Meta, m.Disabled); model == nil {
|
||||
return nil
|
||||
} else if err := model.Init(); err != nil {
|
||||
log.Errorf("vision: %s (init %s)", err, m.Path)
|
||||
@@ -216,7 +214,7 @@ func (m *Model) FaceModel() *face.Model {
|
||||
default:
|
||||
// Set model path from model name if no path is configured.
|
||||
if m.Path == "" {
|
||||
m.Path = clean.TypeLowerUnderscore(m.Name)
|
||||
m.Path = clean.Path(clean.TypeLowerUnderscore(m.Name))
|
||||
}
|
||||
|
||||
// Set default thumbnail resolution if no tags are configured.
|
||||
@@ -229,7 +227,7 @@ func (m *Model) FaceModel() *face.Model {
|
||||
}
|
||||
|
||||
// Try to load custom model based on the configuration values.
|
||||
if model := face.NewModel(filepath.Join(AssetsPath, m.Path), CachePath, m.Resolution, m.Meta, m.Disabled); model == nil {
|
||||
if model := face.NewModel(GetModelPath(m.Path), GetCachePath(), m.Resolution, m.Meta, m.Disabled); model == nil {
|
||||
return nil
|
||||
} else if err := model.Init(); err != nil {
|
||||
log.Errorf("vision: %s (init %s)", err, m.Path)
|
||||
@@ -260,7 +258,7 @@ func (m *Model) NsfwModel() *nsfw.Model {
|
||||
return nil
|
||||
case NsfwModel.Name, "nsfw":
|
||||
// Load and initialize the Nasnet image classification model.
|
||||
if model := nsfw.NewModel(NsfwModelPath, NsfwModel.Meta, m.Disabled); model == nil {
|
||||
if model := nsfw.NewModel(GetNsfwModelPath(), NsfwModel.Meta, m.Disabled); model == nil {
|
||||
return nil
|
||||
} else if err := model.Init(); err != nil {
|
||||
log.Errorf("vision: %s (init %s)", err, m.Path)
|
||||
@@ -271,7 +269,7 @@ func (m *Model) NsfwModel() *nsfw.Model {
|
||||
default:
|
||||
// Set model path from model name if no path is configured.
|
||||
if m.Path == "" {
|
||||
m.Path = clean.TypeLowerUnderscore(m.Name)
|
||||
m.Path = clean.Path(clean.TypeLowerUnderscore(m.Name))
|
||||
}
|
||||
|
||||
// Set default thumbnail resolution if no tags are configured.
|
||||
@@ -290,7 +288,7 @@ func (m *Model) NsfwModel() *nsfw.Model {
|
||||
}
|
||||
|
||||
// Try to load custom model based on the configuration values.
|
||||
if model := nsfw.NewModel(filepath.Join(AssetsPath, m.Path), m.Meta, m.Disabled); model == nil {
|
||||
if model := nsfw.NewModel(GetModelPath(m.Path), m.Meta, m.Disabled); model == nil {
|
||||
return nil
|
||||
} else if err := model.Init(); err != nil {
|
||||
log.Errorf("vision: %s (init %s)", err, m.Path)
|
||||
|
@@ -2,20 +2,25 @@ package vision
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/api/download"
|
||||
"github.com/photoprism/photoprism/internal/event"
|
||||
"github.com/photoprism/photoprism/pkg/fs"
|
||||
)
|
||||
|
||||
var assetsPath = fs.Abs("../../../assets")
|
||||
var examplesPath = filepath.Join(assetsPath, "examples")
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
// Init test logger.
|
||||
log = logrus.StandardLogger()
|
||||
log.SetLevel(logrus.TraceLevel)
|
||||
event.AuditLog = log
|
||||
download.AllowedPaths = append(download.AllowedPaths, AssetsPath)
|
||||
download.AllowedPaths = append(download.AllowedPaths, assetsPath)
|
||||
|
||||
// Set test config values.
|
||||
DownloadUrl = "https://app.localssl.dev/api/v1/dl"
|
||||
|
@@ -291,10 +291,8 @@ func (c *Config) Propagate() {
|
||||
dl.FFprobeBin = c.FFprobeBin()
|
||||
|
||||
// Configure computer vision package.
|
||||
vision.AssetsPath = c.AssetsPath()
|
||||
vision.FaceNetModelPath = c.FaceNetModelPath()
|
||||
vision.NsfwModelPath = c.NSFWModelPath()
|
||||
vision.CachePath = c.CachePath()
|
||||
vision.SetCachePath(c.CachePath())
|
||||
vision.SetModelsPath(c.ModelsPath())
|
||||
vision.ServiceUri = c.VisionUri()
|
||||
vision.ServiceKey = c.VisionKey()
|
||||
vision.DownloadUrl = c.DownloadUrl()
|
||||
|
@@ -195,9 +195,9 @@ func (c *Config) CreateDirectories() error {
|
||||
return createError(dir, err)
|
||||
}
|
||||
|
||||
// Create TensorFlow model path if it doesn't exist yet.
|
||||
if dir := c.NasnetModelPath(); dir == "" {
|
||||
return notFoundError("tensorflow model")
|
||||
// Create computer vision models path if it doesn't exist yet.
|
||||
if dir := c.ModelsPath(); dir == "" {
|
||||
return notFoundError("models")
|
||||
} else if err := fs.MkdirAll(dir); err != nil {
|
||||
return createError(dir, err)
|
||||
}
|
||||
|
@@ -43,19 +43,35 @@ func (c *Config) VisionKey() string {
|
||||
}
|
||||
}
|
||||
|
||||
// ModelsPath returns the path where the machine learning models are located.
|
||||
func (c *Config) ModelsPath() string {
|
||||
if c.options.ModelsPath != "" {
|
||||
return fs.Abs(c.options.ModelsPath)
|
||||
}
|
||||
|
||||
if dir := filepath.Join(c.AssetsPath(), "models"); fs.PathExists(dir) {
|
||||
c.options.ModelsPath = dir
|
||||
return c.options.ModelsPath
|
||||
}
|
||||
|
||||
c.options.ModelsPath = fs.FindDir(fs.ModelsPaths)
|
||||
|
||||
return c.options.ModelsPath
|
||||
}
|
||||
|
||||
// NasnetModelPath returns the TensorFlow model path.
|
||||
func (c *Config) NasnetModelPath() string {
|
||||
return filepath.Join(c.AssetsPath(), "nasnet")
|
||||
return filepath.Join(c.ModelsPath(), "nasnet")
|
||||
}
|
||||
|
||||
// FaceNetModelPath returns the FaceNet model path.
|
||||
func (c *Config) FaceNetModelPath() string {
|
||||
return filepath.Join(c.AssetsPath(), "facenet")
|
||||
// FacenetModelPath returns the FaceNet model path.
|
||||
func (c *Config) FacenetModelPath() string {
|
||||
return filepath.Join(c.ModelsPath(), "facenet")
|
||||
}
|
||||
|
||||
// NSFWModelPath returns the "not safe for work" TensorFlow model path.
|
||||
func (c *Config) NSFWModelPath() string {
|
||||
return filepath.Join(c.AssetsPath(), "nsfw")
|
||||
// NsfwModelPath returns the "not safe for work" TensorFlow model path.
|
||||
func (c *Config) NsfwModelPath() string {
|
||||
return filepath.Join(c.ModelsPath(), "nsfw")
|
||||
}
|
||||
|
||||
// DetectNSFW checks if NSFW photos should be detected and flagged.
|
||||
|
@@ -1,6 +1,7 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -34,11 +35,12 @@ func TestConfig_VisionKey(t *testing.T) {
|
||||
assert.Equal(t, "", c.VisionKey())
|
||||
}
|
||||
|
||||
func TestConfig_TensorFlowModelPath(t *testing.T) {
|
||||
func TestConfig_ModelsPath(t *testing.T) {
|
||||
c := NewConfig(CliTestContext())
|
||||
|
||||
path := c.NasnetModelPath()
|
||||
assert.Equal(t, "/go/src/github.com/photoprism/photoprism/assets/nasnet", path)
|
||||
assert.True(t, strings.HasPrefix(path, c.ModelsPath()))
|
||||
assert.Equal(t, "/go/src/github.com/photoprism/photoprism/assets/models/nasnet", path)
|
||||
}
|
||||
|
||||
func TestConfig_TensorFlowDisabled(t *testing.T) {
|
||||
@@ -51,13 +53,13 @@ func TestConfig_TensorFlowDisabled(t *testing.T) {
|
||||
func TestConfig_NSFWModelPath(t *testing.T) {
|
||||
c := NewConfig(CliTestContext())
|
||||
|
||||
assert.Contains(t, c.NSFWModelPath(), "/assets/nsfw")
|
||||
assert.Contains(t, c.NsfwModelPath(), "/assets/models/nsfw")
|
||||
}
|
||||
|
||||
func TestConfig_FaceNetModelPath(t *testing.T) {
|
||||
c := NewConfig(CliTestContext())
|
||||
|
||||
assert.Contains(t, c.FaceNetModelPath(), "/assets/facenet")
|
||||
assert.Contains(t, c.FacenetModelPath(), "/assets/models/facenet")
|
||||
}
|
||||
|
||||
func TestConfig_DetectNSFW(t *testing.T) {
|
||||
|
@@ -298,6 +298,12 @@ var Flags = CliFlags{
|
||||
EnvVars: EnvVars("ASSETS_PATH"),
|
||||
TakesFile: true,
|
||||
}}, {
|
||||
Flag: &cli.PathFlag{
|
||||
Name: "models-path",
|
||||
Usage: "custom model assets `PATH` where computer vision models are located",
|
||||
EnvVars: EnvVars("MODELS_PATH"),
|
||||
TakesFile: true,
|
||||
}}, {
|
||||
Flag: &cli.PathFlag{
|
||||
Name: "sidecar-path",
|
||||
Aliases: []string{"sc"},
|
||||
|
@@ -77,6 +77,7 @@ type Options struct {
|
||||
TempPath string `yaml:"TempPath" json:"-" flag:"temp-path"`
|
||||
AssetsPath string `yaml:"AssetsPath" json:"-" flag:"assets-path"`
|
||||
CustomAssetsPath string `yaml:"-" json:"-" flag:"custom-assets-path" tags:"plus,pro"`
|
||||
ModelsPath string `yaml:"ModelsPath" json:"-" flag:"models-path"`
|
||||
SidecarPath string `yaml:"SidecarPath" json:"-" flag:"sidecar-path"`
|
||||
SidecarYaml bool `yaml:"SidecarYaml" json:"SidecarYaml" flag:"sidecar-yaml" default:"true"`
|
||||
UsageInfo bool `yaml:"UsageInfo" json:"UsageInfo" flag:"usage-info"`
|
||||
|
@@ -76,6 +76,7 @@ func (c *Config) Report() (rows [][]string, cols []string) {
|
||||
{"thumb-cache-path", c.ThumbCachePath()},
|
||||
{"temp-path", c.TempPath()},
|
||||
{"assets-path", c.AssetsPath()},
|
||||
{"models-path", c.ModelsPath()},
|
||||
{"static-path", c.StaticPath()},
|
||||
{"build-path", c.BuildPath()},
|
||||
{"img-path", c.ImgPath()},
|
||||
@@ -256,8 +257,8 @@ func (c *Config) Report() (rows [][]string, cols []string) {
|
||||
{"vision-uri", c.VisionUri()},
|
||||
{"vision-key", strings.Repeat("*", utf8.RuneCountInString(c.VisionKey()))},
|
||||
{"nasnet-model-path", c.NasnetModelPath()},
|
||||
{"facenet-model-path", c.FaceNetModelPath()},
|
||||
{"nsfw-model-path", c.NSFWModelPath()},
|
||||
{"facenet-model-path", c.FacenetModelPath()},
|
||||
{"nsfw-model-path", c.NsfwModelPath()},
|
||||
{"detect-nsfw", fmt.Sprintf("%t", c.DetectNSFW())},
|
||||
|
||||
// Facial Recognition.
|
||||
|
@@ -120,6 +120,16 @@ var AssetPaths = []string{
|
||||
"/var/lib/photoprism/assets",
|
||||
}
|
||||
|
||||
var ModelsPaths = []string{
|
||||
"/opt/photoprism/assets/models",
|
||||
"/photoprism/assets/models",
|
||||
"~/.photoprism/assets/models",
|
||||
"~/photoprism/assets/models",
|
||||
"photoprism/assets/models",
|
||||
"assets/models",
|
||||
"/var/lib/photoprism/assets/models",
|
||||
}
|
||||
|
||||
// Dirs returns a slice of directories in a path, optional recursively and with symlinks.
|
||||
//
|
||||
// Warning: Following symlinks can make the result non-deterministic and hard to test!
|
||||
|
@@ -4,7 +4,8 @@ TODAY=$(date -u +%Y%m%d)
|
||||
|
||||
MODEL_NAME="Facenet"
|
||||
MODEL_URL="https://dl.photoprism.app/tensorflow/facenet.zip?$TODAY"
|
||||
MODEL_PATH="assets/facenet"
|
||||
MODELS_PATH="assets/models"
|
||||
MODEL_PATH="$MODELS_PATH/facenet"
|
||||
MODEL_ZIP="/tmp/photoprism/facenet.zip"
|
||||
MODEL_HASH="0492eb1d67789108b7eefb274e26633504b059be $MODEL_ZIP"
|
||||
MODEL_VERSION="$MODEL_PATH/version.txt"
|
||||
@@ -17,7 +18,7 @@ mkdir -p /tmp/photoprism
|
||||
mkdir -p storage/backup
|
||||
|
||||
# Check for update
|
||||
if [[ -f ${MODEL_ZIP} ]] && [[ $(sha1sum ${MODEL_ZIP}) == ${MODEL_HASH} ]]; then
|
||||
if [[ -f ${MODEL_ZIP} ]] && [[ $(sha1sum ${MODEL_ZIP}) == "${MODEL_HASH}" ]]; then
|
||||
if [[ -f ${MODEL_VERSION} ]]; then
|
||||
echo "Already up to date."
|
||||
exit
|
||||
@@ -40,7 +41,7 @@ if [[ -e ${MODEL_PATH} ]]; then
|
||||
fi
|
||||
|
||||
# Unzip model
|
||||
unzip ${MODEL_ZIP} -d assets
|
||||
unzip ${MODEL_ZIP} -d "$MODELS_PATH"
|
||||
echo "$MODEL_NAME $TODAY $MODEL_HASH" > ${MODEL_VERSION}
|
||||
|
||||
echo "Latest $MODEL_NAME installed."
|
||||
|
@@ -4,7 +4,8 @@ TODAY=$(date -u +%Y%m%d)
|
||||
|
||||
MODEL_NAME="NASNet Mobile"
|
||||
MODEL_URL="https://dl.photoprism.app/tensorflow/nasnet.zip?$TODAY"
|
||||
MODEL_PATH="assets/nasnet"
|
||||
MODELS_PATH="assets/models"
|
||||
MODEL_PATH="$MODELS_PATH/nasnet"
|
||||
MODEL_ZIP="/tmp/photoprism/nasnet.zip"
|
||||
MODEL_HASH="f18b801354e95cade497b4f12e8d2537d04c04f6 $MODEL_ZIP"
|
||||
MODEL_VERSION="$MODEL_PATH/version.txt"
|
||||
@@ -41,7 +42,7 @@ if [[ -e ${MODEL_PATH} ]]; then
|
||||
fi
|
||||
|
||||
# Unzip model
|
||||
unzip ${MODEL_ZIP} -d assets
|
||||
unzip ${MODEL_ZIP} -d "$MODELS_PATH"
|
||||
echo "$MODEL_NAME $TODAY $MODEL_HASH" > ${MODEL_VERSION}
|
||||
wget --inet4-only -c "${MODEL_21K_LABELS_URL}" -O ${MODEL_PATH}/labels21k.txt
|
||||
|
||||
|
@@ -4,7 +4,8 @@ TODAY=$(date -u +%Y%m%d)
|
||||
|
||||
MODEL_NAME="NSFW"
|
||||
MODEL_URL="https://dl.photoprism.app/tensorflow/nsfw.zip?$TODAY"
|
||||
MODEL_PATH="assets/nsfw"
|
||||
MODELS_PATH="assets/models"
|
||||
MODEL_PATH="$MODELS_PATH/nsfw"
|
||||
MODEL_ZIP="/tmp/photoprism/nsfw.zip"
|
||||
MODEL_HASH="2e03ad3c6aec27c270c650d0574ff2a6291d992b $MODEL_ZIP"
|
||||
MODEL_VERSION="$MODEL_PATH/version.txt"
|
||||
@@ -40,7 +41,7 @@ if [[ -e ${MODEL_PATH} ]]; then
|
||||
fi
|
||||
|
||||
# Unzip model
|
||||
unzip ${MODEL_ZIP} -d assets
|
||||
unzip ${MODEL_ZIP} -d "$MODELS_PATH"
|
||||
echo "$MODEL_NAME $TODAY $MODEL_HASH" > ${MODEL_VERSION}
|
||||
|
||||
echo "Latest $MODEL_NAME installed."
|
||||
|
Reference in New Issue
Block a user