AI: Add TensorFlow utility package and improve model loading #127 #1090

Signed-off-by: Michael Mayer <michael@photoprism.app>
This commit is contained in:
Michael Mayer
2025-04-07 05:26:45 +02:00
parent 35e9294d87
commit bfdb839d01
17 changed files with 421 additions and 261 deletions

1
.gitignore vendored
View File

@@ -50,6 +50,7 @@ frontend/coverage/
/assets/nsfw
/assets/static/build/
/assets/*net
/assets/vision
/pro
/plus

Binary file not shown.

After

Width:  |  Height:  |  Size: 16 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 98 KiB

BIN
assets/examples/green.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 22 KiB

View File

@@ -1,14 +1,11 @@
package classify
import (
"bufio"
"bytes"
"fmt"
"image"
"math"
"os"
"path"
"path/filepath"
"runtime/debug"
"sort"
"strings"
@@ -17,7 +14,7 @@ import (
"github.com/disintegration/imaging"
tf "github.com/wamuir/graft/tensorflow"
"github.com/photoprism/photoprism/pkg/clean"
"github.com/photoprism/photoprism/internal/ai/tensorflow"
)
// Model represents a TensorFlow classification model.
@@ -82,7 +79,7 @@ func (m *Model) Labels(img []byte, confidenceThreshold int) (result Labels, err
return nil, loadErr
}
// Create tensor from image.
// Create input tensor from image.
tensor, err := m.createTensor(img)
if err != nil {
@@ -112,37 +109,16 @@ func (m *Model) Labels(img []byte, confidenceThreshold int) (result Labels, err
if len(result) > 0 {
log.Tracef("classify: image classified as %+v", result)
} else {
result = Labels{}
}
return result, nil
}
func (m *Model) loadLabels(path string) error {
modelLabels := path + "/labels.txt"
log.Infof("classify: loading labels from labels.txt")
// Load labels
f, err := os.Open(modelLabels)
if err != nil {
return err
}
defer f.Close()
scanner := bufio.NewScanner(f)
// Labels are separated by newlines
for scanner.Scan() {
m.labels = append(m.labels, scanner.Text())
}
if err := scanner.Err(); err != nil {
return err
}
return nil
func (m *Model) loadLabels(modelPath string) (err error) {
m.labels, err = tensorflow.LoadLabels(modelPath)
return err
}
// ModelLoaded tests if the TensorFlow model is loaded.
@@ -150,7 +126,9 @@ func (m *Model) ModelLoaded() bool {
return m.model != nil
}
func (m *Model) loadModel() error {
func (m *Model) loadModel() (err error) {
// Use mutex to prevent the model from being loaded and
// initialized twice by different indexing workers.
m.mutex.Lock()
defer m.mutex.Unlock()
@@ -160,16 +138,7 @@ func (m *Model) loadModel() error {
modelPath := path.Join(m.assetsPath, m.modelPath)
log.Infof("classify: loading %s", clean.Log(filepath.Base(modelPath)))
// Load model
model, err := tf.LoadSavedModel(modelPath, m.modelTags, nil)
if err != nil {
return err
}
m.model = model
m.model, err = tensorflow.SavedModel(modelPath, m.modelTags)
return m.loadLabels(modelPath)
}
@@ -184,8 +153,10 @@ func (m *Model) bestLabels(probabilities []float32, confidenceThreshold int) Lab
break
}
confidence := int(math.Round(float64(p * 100)))
// discard labels with low probabilities
if p < 0.1 {
if confidence < confidenceThreshold {
continue
}
@@ -204,12 +175,7 @@ func (m *Model) bestLabels(probabilities []float32, confidenceThreshold int) Lab
}
labelText = strings.TrimSpace(labelText)
confidence := int(math.Round(float64(p * 100)))
if confidence >= confidenceThreshold {
result = append(result, Label{Name: labelText, Source: SrcImage, Uncertainty: 100 - confidence, Priority: rule.Priority, Categories: rule.Categories})
}
result = append(result, Label{Name: labelText, Source: SrcImage, Uncertainty: 100 - confidence, Priority: rule.Priority, Categories: rule.Categories})
}
// Sort by probability
@@ -231,42 +197,7 @@ func (m *Model) createTensor(image []byte) (*tf.Tensor, error) {
return nil, err
}
width, height := m.resolution, m.resolution
img = imaging.Fill(img, m.resolution, m.resolution, imaging.Center, imaging.Lanczos)
img = imaging.Fill(img, width, height, imaging.Center, imaging.Lanczos)
return imageToTensor(img, width, height)
}
func imageToTensor(img image.Image, imageHeight, imageWidth int) (tfTensor *tf.Tensor, err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("classify: %s (panic)\nstack: %s", r, debug.Stack())
}
}()
if imageHeight <= 0 || imageWidth <= 0 {
return tfTensor, fmt.Errorf("classify: image width and height must be > 0")
}
var tfImage [1][][][3]float32
for j := 0; j < imageHeight; j++ {
tfImage[0] = append(tfImage[0], make([][3]float32, imageWidth))
}
for i := 0; i < imageWidth; i++ {
for j := 0; j < imageHeight; j++ {
r, g, b, _ := img.At(i, j).RGBA()
tfImage[0][j][i][0] = convertValue(r)
tfImage[0][j][i][1] = convertValue(g)
tfImage[0][j][i][2] = convertValue(b)
}
}
return tf.NewTensor(tfImage)
}
func convertValue(value uint32) float32 {
return (float32(value>>8) - float32(127.5)) / float32(127.5)
return tensorflow.Image(img, m.resolution)
}

View File

@@ -6,7 +6,6 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/wamuir/graft/tensorflow"
"github.com/photoprism/photoprism/pkg/fs"
)
@@ -31,24 +30,66 @@ func NewModelTest(t *testing.T) *Model {
func TestModel_LabelsFromFile(t *testing.T) {
t.Run("chameleon_lime.jpg", func(t *testing.T) {
tensorFlow := NewModelTest(t)
result, err := tensorFlow.File(examplesPath+"/chameleon_lime.jpg", 10)
assert.Nil(t, err)
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
assert.NotNil(t, result)
assert.IsType(t, Labels{}, result)
assert.Equal(t, 1, len(result))
t.Log(result)
if len(result) > 0 {
t.Logf("result: %#v", result[0])
assert.Equal(t, "chameleon", result[0].Name)
assert.Equal(t, "chameleon", result[0].Name)
assert.Equal(t, 7, result[0].Uncertainty)
}
})
t.Run("cat_224.jpeg", func(t *testing.T) {
tensorFlow := NewModelTest(t)
result, err := tensorFlow.File(examplesPath+"/cat_224.jpeg", 10)
assert.Equal(t, 7, result[0].Uncertainty)
assert.NoError(t, err)
assert.NotNil(t, result)
assert.IsType(t, Labels{}, result)
assert.Equal(t, 1, len(result))
if len(result) > 0 {
assert.Equal(t, "cat", result[0].Name)
assert.Equal(t, 59, result[0].Uncertainty)
}
})
t.Run("cat_720.jpeg", func(t *testing.T) {
tensorFlow := NewModelTest(t)
result, err := tensorFlow.File(examplesPath+"/cat_720.jpeg", 10)
assert.NoError(t, err)
assert.NotNil(t, result)
assert.IsType(t, Labels{}, result)
assert.Equal(t, 3, len(result))
// t.Logf("labels: %#v", result)
if len(result) > 0 {
assert.Equal(t, "cat", result[0].Name)
assert.Equal(t, 60, result[0].Uncertainty)
}
})
t.Run("green.jpg", func(t *testing.T) {
tensorFlow := NewModelTest(t)
result, err := tensorFlow.File(examplesPath+"/green.jpg", 10)
t.Logf("labels: %#v", result)
assert.NoError(t, err)
assert.NotNil(t, result)
assert.IsType(t, Labels{}, result)
assert.Equal(t, 1, len(result))
if len(result) > 0 {
assert.Equal(t, "outdoor", result[0].Name)
assert.Equal(t, 70, result[0].Uncertainty)
}
})
t.Run("not existing file", func(t *testing.T) {
tensorFlow := NewModelTest(t)
@@ -180,11 +221,13 @@ func TestModel_LoadModel(t *testing.T) {
})
t.Run("model path does not exist", func(t *testing.T) {
tensorFlow := NewNasnet(assetsPath+"foo", false)
if err := tensorFlow.loadModel(); err != nil {
assert.Contains(t, err.Error(), "Could not find SavedModel")
} else {
t.Fatal("err should NOT be nil")
err := tensorFlow.loadModel()
if err != nil {
assert.Contains(t, err.Error(), "no such file or directory")
}
assert.Error(t, err)
})
}
@@ -218,35 +261,3 @@ func TestModel_BestLabels(t *testing.T) {
t.Log(result)
})
}
func TestModel_MakeTensor(t *testing.T) {
t.Run("cat_brown.jpg", func(t *testing.T) {
tensorFlow := NewModelTest(t)
imageBuffer, err := os.ReadFile(examplesPath + "/cat_brown.jpg")
if err != nil {
t.Fatal(err)
}
result, err := tensorFlow.createTensor(imageBuffer)
assert.Equal(t, tensorflow.DataType(0x1), result.DataType())
assert.Equal(t, int64(1), result.Shape()[0])
assert.Equal(t, int64(224), result.Shape()[2])
})
t.Run("Random.docx", func(t *testing.T) {
tensorFlow := NewModelTest(t)
imageBuffer, err := os.ReadFile(examplesPath + "/Random.docx")
assert.Nil(t, err)
result, err := tensorFlow.createTensor(imageBuffer)
assert.Empty(t, result)
assert.EqualError(t, err, "image: unknown format")
})
}
func Test_convertValue(t *testing.T) {
result := convertValue(uint32(98765432))
assert.Equal(t, float32(3024.898), result)
}

View File

@@ -32,7 +32,7 @@ func NewModel(modelPath, cachePath string, disabled bool) *Model {
}
// Detect runs the detection and facenet algorithms over the provided source image.
func (t *Model) Detect(fileName string, minSize int, cacheCrop bool, expected int) (faces Faces, err error) {
func (m *Model) Detect(fileName string, minSize int, cacheCrop bool, expected int) (faces Faces, err error) {
faces, err = Detect(fileName, false, minSize)
if err != nil {
@@ -40,13 +40,13 @@ func (t *Model) Detect(fileName string, minSize int, cacheCrop bool, expected in
}
// Skip FaceNet?
if t.disabled {
if m.disabled {
return faces, nil
} else if c := len(faces); c == 0 || expected > 0 && c == expected {
return faces, nil
}
err = t.loadModel()
err = m.loadModel()
if err != nil {
return faces, err
@@ -59,7 +59,7 @@ func (t *Model) Detect(fileName string, minSize int, cacheCrop bool, expected in
if img, imgErr := crop.ImageFromThumb(fileName, f.CropArea(), CropSize, cacheCrop); imgErr != nil {
log.Errorf("faces: failed to decode image: %s", imgErr)
} else if embeddings := t.getEmbeddings(img); !embeddings.Empty() {
} else if embeddings := m.getEmbeddings(img); !embeddings.Empty() {
faces[i].Embeddings = embeddings
}
}
@@ -68,38 +68,40 @@ func (t *Model) Detect(fileName string, minSize int, cacheCrop bool, expected in
}
// ModelLoaded tests if the TensorFlow model is loaded.
func (t *Model) ModelLoaded() bool {
return t.model != nil
func (m *Model) ModelLoaded() bool {
return m.model != nil
}
// loadModel loads the TensorFlow model.
func (t *Model) loadModel() error {
t.mutex.Lock()
defer t.mutex.Unlock()
func (m *Model) loadModel() error {
// Use mutex to prevent the model from being loaded and
// initialized twice by different indexing workers.
m.mutex.Lock()
defer m.mutex.Unlock()
if t.ModelLoaded() {
if m.ModelLoaded() {
return nil
}
modelPath := path.Join(t.modelPath)
modelPath := path.Join(m.modelPath)
log.Infof("faces: loading %s", clean.Log(filepath.Base(modelPath)))
// Load model
model, err := tf.LoadSavedModel(modelPath, t.modelTags, nil)
model, err := tf.LoadSavedModel(modelPath, m.modelTags, nil)
if err != nil {
return err
}
t.model = model
m.model = model
return nil
}
// getEmbeddings returns the face embeddings for an image.
func (t *Model) getEmbeddings(img image.Image) Embeddings {
tensor, err := imageToTensor(img, t.resolution)
func (m *Model) getEmbeddings(img image.Image) Embeddings {
tensor, err := imageToTensor(img, m.resolution)
if err != nil {
log.Errorf("faces: failed to convert image to tensor: %s", err)
@@ -109,13 +111,13 @@ func (t *Model) getEmbeddings(img image.Image) Embeddings {
trainPhaseBoolTensor, err := tf.NewTensor(false)
output, err := t.model.Session.Run(
output, err := m.model.Session.Run(
map[tf.Output]*tf.Tensor{
t.model.Graph.Operation("input").Output(0): tensor,
t.model.Graph.Operation("phase_train").Output(0): trainPhaseBoolTensor,
m.model.Graph.Operation("input").Output(0): tensor,
m.model.Graph.Operation("phase_train").Output(0): trainPhaseBoolTensor,
},
[]tf.Output{
t.model.Graph.Operation("embeddings").Output(0),
m.model.Graph.Operation("embeddings").Output(0),
},
nil)

View File

@@ -1,25 +1,19 @@
package nsfw
import (
"bufio"
"fmt"
"os"
"path/filepath"
"sync"
tf "github.com/wamuir/graft/tensorflow"
"github.com/wamuir/graft/tensorflow/op"
"github.com/photoprism/photoprism/internal/ai/tensorflow"
"github.com/photoprism/photoprism/pkg/clean"
"github.com/photoprism/photoprism/pkg/fs"
"github.com/photoprism/photoprism/pkg/media/http/header"
)
const (
Mean = float32(117)
Scale = float32(1)
)
// Model uses TensorFlow to label drawing, hentai, neutral, porn and sexy images.
type Model struct {
model *tf.SavedModel
@@ -36,7 +30,7 @@ func NewModel(modelPath string) *Model {
}
// File returns matching labels for a jpeg media file.
func (t *Model) File(filename string) (result Labels, err error) {
func (m *Model) File(filename string) (result Labels, err error) {
if fs.MimeType(filename) != header.ContentTypeJpeg {
return result, fmt.Errorf("nsfw: %s is not a jpeg file", clean.Log(filepath.Base(filename)))
}
@@ -47,29 +41,29 @@ func (t *Model) File(filename string) (result Labels, err error) {
return result, err
}
return t.Labels(imageBuffer)
return m.Labels(imageBuffer)
}
// Labels returns matching labels for a jpeg media string.
func (t *Model) Labels(img []byte) (result Labels, err error) {
if loadErr := t.loadModel(); loadErr != nil {
func (m *Model) Labels(img []byte) (result Labels, err error) {
if loadErr := m.loadModel(); loadErr != nil {
return result, loadErr
}
// Make tensor
tensor, err := createTensorFromImage(img, "jpeg", t.resolution)
// Create input tensor from image.
input, err := tensorflow.ImageTransform(img, fs.ImageJpeg, m.resolution)
if err != nil {
return result, fmt.Errorf("nsfw: %s", err)
}
// Run inference
output, err := t.model.Session.Run(
// Run inference.
output, err := m.model.Session.Run(
map[tf.Output]*tf.Tensor{
t.model.Graph.Operation("input_tensor").Output(0): tensor,
m.model.Graph.Operation("input_tensor").Output(0): input,
},
[]tf.Output{
t.model.Graph.Operation("nsfw_cls_model/final_prediction").Output(0),
m.model.Graph.Operation("nsfw_cls_model/final_prediction").Output(0),
},
nil)
@@ -81,66 +75,45 @@ func (t *Model) Labels(img []byte) (result Labels, err error) {
return result, fmt.Errorf("nsfw: inference failed, no output")
}
// Return best labels
result = t.getLabels(output[0].Value().([][]float32)[0])
// Return best labels.
result = m.getLabels(output[0].Value().([][]float32)[0])
log.Tracef("nsfw: image classified as %+v", result)
return result, nil
}
func (t *Model) loadLabels(path string) error {
modelLabels := path + "/labels.txt"
log.Infof("nsfw: loading labels from labels.txt")
// Load labels
f, err := os.Open(modelLabels)
if err != nil {
return err
}
defer f.Close()
scanner := bufio.NewScanner(f)
// Labels are separated by newlines
for scanner.Scan() {
t.labels = append(t.labels, scanner.Text())
}
if err := scanner.Err(); err != nil {
return err
}
func (m *Model) loadLabels(modelPath string) (err error) {
m.labels, err = tensorflow.LoadLabels(modelPath)
return nil
}
func (t *Model) loadModel() error {
t.mutex.Lock()
defer t.mutex.Unlock()
func (m *Model) loadModel() error {
// Use mutex to prevent the model from being loaded and
// initialized twice by different indexing workers.
m.mutex.Lock()
defer m.mutex.Unlock()
if t.model != nil {
if m.model != nil {
// Already loaded
return nil
}
log.Infof("nsfw: loading %s", clean.Log(filepath.Base(t.modelPath)))
log.Infof("nsfw: loading %s", clean.Log(filepath.Base(m.modelPath)))
// Load model
model, err := tf.LoadSavedModel(t.modelPath, t.modelTags, nil)
// Load saved TensorFlow model from the specified path.
model, err := tensorflow.SavedModel(m.modelPath, m.modelTags)
if err != nil {
return err
}
t.model = model
m.model = model
return t.loadLabels(t.modelPath)
return m.loadLabels(m.modelPath)
}
func (t *Model) getLabels(p []float32) Labels {
func (m *Model) getLabels(p []float32) Labels {
return Labels{
Drawing: p[0],
Hentai: p[1],
@@ -149,56 +122,3 @@ func (t *Model) getLabels(p []float32) Labels {
Sexy: p[4],
}
}
func transformImageGraph(imageFormat string, resolution int) (graph *tf.Graph, input, output tf.Output, err error) {
var H, W = int32(resolution), int32(resolution)
s := op.NewScope()
input = op.Placeholder(s, tf.String)
// Decode PNG or JPEG
var decode tf.Output
if imageFormat == "png" {
decode = op.DecodePng(s, input, op.DecodePngChannels(3))
} else {
decode = op.DecodeJpeg(s, input, op.DecodeJpegChannels(3))
}
// Div and Sub perform (value-Mean)/Scale for each pixel
output = op.Div(s,
op.Sub(s,
// Resize to 224x224 with bilinear interpolation
op.ResizeBilinear(s,
// Create a batch containing a single image
op.ExpandDims(s,
// Use decoded pixel values
op.Cast(s, decode, tf.Float),
op.Const(s.SubScope("make_batch"), int32(0))),
op.Const(s.SubScope("size"), []int32{H, W})),
op.Const(s.SubScope("mean"), Mean)),
op.Const(s.SubScope("scale"), Scale))
graph, err = s.Finalize()
return graph, input, output, err
}
func createTensorFromImage(image []byte, imageFormat string, resolution int) (*tf.Tensor, error) {
tensor, err := tf.NewTensor(string(image))
if err != nil {
return nil, err
}
graph, input, output, err := transformImageGraph(imageFormat, resolution)
if err != nil {
return nil, err
}
session, err := tf.NewSession(graph, nil)
if err != nil {
return nil, err
}
defer session.Close()
normalized, err := session.Run(
map[tf.Output]*tf.Tensor{input: tensor},
[]tf.Output{output},
nil)
if err != nil {
return nil, err
}
return normalized[0], nil
}

View File

@@ -0,0 +1,145 @@
package tensorflow
import (
"bytes"
"fmt"
"image"
_ "image/jpeg"
_ "image/png"
"os"
"runtime/debug"
tf "github.com/wamuir/graft/tensorflow"
"github.com/wamuir/graft/tensorflow/op"
"github.com/photoprism/photoprism/pkg/fs"
)
const (
Mean = float32(117)
Scale = float32(1)
)
func ImageFromFile(fileName string, resolution int) (*tf.Tensor, error) {
if img, err := OpenImage(fileName); err != nil {
return nil, err
} else {
return Image(img, resolution)
}
}
func OpenImage(fileName string) (image.Image, error) {
f, err := os.Open(fileName)
if err != nil {
return nil, err
}
defer f.Close()
img, _, err := image.Decode(f)
return img, err
}
func ImageFromBytes(b []byte, resolution int) (*tf.Tensor, error) {
img, _, imgErr := image.Decode(bytes.NewReader(b))
if imgErr != nil {
return nil, imgErr
}
return Image(img, resolution)
}
func Image(img image.Image, resolution int) (tfTensor *tf.Tensor, err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("tensorflow: %s (panic)\nstack: %s", r, debug.Stack())
}
}()
if resolution <= 0 {
return tfTensor, fmt.Errorf("tensorflow: resolution must be larger 0")
}
var tfImage [1][][][3]float32
for j := 0; j < resolution; j++ {
tfImage[0] = append(tfImage[0], make([][3]float32, resolution))
}
for i := 0; i < resolution; i++ {
for j := 0; j < resolution; j++ {
r, g, b, _ := img.At(i, j).RGBA()
tfImage[0][j][i][0] = convertValue(r, 127.5)
tfImage[0][j][i][1] = convertValue(g, 127.5)
tfImage[0][j][i][2] = convertValue(b, 127.5)
}
}
return tf.NewTensor(tfImage)
}
// ImageTransform transforms the given image into a *tf.Tensor and returns it.
func ImageTransform(image []byte, imageFormat fs.Type, resolution int) (*tf.Tensor, error) {
tensor, err := tf.NewTensor(string(image))
if err != nil {
return nil, err
}
graph, input, output, err := transformImageGraph(imageFormat, resolution)
if err != nil {
return nil, err
}
session, err := tf.NewSession(graph, nil)
if err != nil {
return nil, err
}
defer session.Close()
normalized, err := session.Run(
map[tf.Output]*tf.Tensor{input: tensor},
[]tf.Output{output},
nil)
if err != nil {
return nil, err
}
return normalized[0], nil
}
func transformImageGraph(imageFormat fs.Type, resolution int) (graph *tf.Graph, input, output tf.Output, err error) {
s := op.NewScope()
input = op.Placeholder(s, tf.String)
// Assume the image is a JPEG, or a PNG if explicitly specified.
var decodedImage tf.Output
switch imageFormat {
case fs.ImagePng:
decodedImage = op.DecodePng(s, input, op.DecodePngChannels(3))
default:
decodedImage = op.DecodeJpeg(s, input, op.DecodeJpegChannels(3))
}
output = op.Div(s,
op.Sub(s,
op.ResizeBilinear(s,
op.ExpandDims(s,
op.Cast(s, decodedImage, tf.Float),
op.Const(s.SubScope("make_batch"), int32(0))),
op.Const(s.SubScope("size"), []int32{int32(resolution), int32(resolution)})),
op.Const(s.SubScope("mean"), Mean)),
op.Const(s.SubScope("scale"), Scale))
graph, err = s.Finalize()
return graph, input, output, err
}
func convertValue(value uint32, mean float32) float32 {
if mean == 0 {
mean = 127.5
}
return (float32(value>>8) - mean) / mean
}

View File

@@ -0,0 +1,42 @@
package tensorflow
import (
"os"
"testing"
"github.com/stretchr/testify/assert"
"github.com/wamuir/graft/tensorflow"
"github.com/photoprism/photoprism/pkg/fs"
)
func TestConvertValue(t *testing.T) {
result := convertValue(uint32(98765432), 127.5)
assert.Equal(t, float32(3024.898), result)
}
func TestImageFromBytes(t *testing.T) {
var assetsPath = fs.Abs("../../../assets")
var examplesPath = assetsPath + "/examples"
t.Run("CatJpeg", func(t *testing.T) {
imageBuffer, err := os.ReadFile(examplesPath + "/cat_brown.jpg")
if err != nil {
t.Fatal(err)
}
result, err := ImageFromBytes(imageBuffer, 224)
assert.Equal(t, tensorflow.DataType(0x1), result.DataType())
assert.Equal(t, int64(1), result.Shape()[0])
assert.Equal(t, int64(224), result.Shape()[2])
})
t.Run("Document", func(t *testing.T) {
imageBuffer, err := os.ReadFile(examplesPath + "/Random.docx")
assert.Nil(t, err)
result, err := ImageFromBytes(imageBuffer, 224)
assert.Empty(t, result)
assert.EqualError(t, err, "image: unknown format")
})
}

View File

@@ -0,0 +1,32 @@
package tensorflow
import (
"bufio"
"os"
)
// LoadLabels loads the labels of classification models from the specified path and returns them.
func LoadLabels(modelPath string) (labels []string, err error) {
modelLabels := modelPath + "/labels.txt"
log.Infof("tensorflow: loading model labels from labels.txt")
f, err := os.Open(modelLabels)
if err != nil {
return labels, err
}
defer f.Close()
scanner := bufio.NewScanner(f)
// Labels are separated by newlines
for scanner.Scan() {
labels = append(labels, scanner.Text())
}
err = scanner.Err()
return labels, err
}

View File

@@ -0,0 +1,20 @@
package tensorflow
import (
"path/filepath"
tf "github.com/wamuir/graft/tensorflow"
"github.com/photoprism/photoprism/pkg/clean"
)
// SavedModel loads a saved TensorFlow model from the specified path.
func SavedModel(modelPath string, tags []string) (model *tf.SavedModel, err error) {
log.Infof("tensorflow: loading %s", clean.Log(filepath.Base(modelPath)))
if len(tags) == 0 {
tags = []string{"serve"}
}
return tf.LoadSavedModel(modelPath, tags, nil)
}

View File

@@ -0,0 +1,31 @@
/*
Package tensorflow provides TensorFlow utility functions.
Copyright (c) 2018 - 2025 PhotoPrism UG. All rights reserved.
This program is free software: you can redistribute it and/or modify
it under Version 3 of the GNU Affero General Public License (the "AGPL"):
<https://docs.photoprism.app/license/agpl>
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
The AGPL is supplemented by our Trademark and Brand Guidelines,
which describe how our Brand Assets may be used:
<https://www.photoprism.app/trademark>
Feel free to send an email to hello@photoprism.app if you have questions,
want to support our work, or just want to say hello.
Additional information can be found in our Developer Guide:
<https://docs.photoprism.app/developer-guide/>
*/
package tensorflow
import (
"github.com/photoprism/photoprism/internal/event"
)
var log = event.Log

View File

@@ -51,7 +51,7 @@ func Labels(thumbnails []string) (result classify.Labels, err error) {
}
if !found {
result = append(result, labels...)
result = append(result, labels[j])
}
}
}

View File

@@ -25,6 +25,19 @@ func TestLabels(t *testing.T) {
assert.Equal(t, "chameleon", result[0].Name)
assert.Equal(t, 7, result[0].Uncertainty)
})
t.Run("Cats", func(t *testing.T) {
result, err := Labels([]string{examplesPath + "/cat_720.jpeg"})
assert.NoError(t, err)
assert.IsType(t, classify.Labels{}, result)
assert.Equal(t, 1, len(result))
t.Log(result)
assert.Equal(t, "cat", result[0].Name)
assert.Equal(t, 60, result[0].Uncertainty)
assert.Equal(t, 40, result[0].Confidence())
})
t.Run("InvalidFile", func(t *testing.T) {
_, err := Labels([]string{examplesPath + "/notexisting.jpg"})
assert.Error(t, err)

View File

@@ -28,9 +28,12 @@ type Models []*Model
// ClassifyModel returns the matching classify model instance, if any.
func (m *Model) ClassifyModel() *classify.Model {
// Use mutex to prevent models from being loaded and
// initialized twice by different indexing workers.
modelMutex.Lock()
defer modelMutex.Unlock()
// Return the existing model instance if it has already been created.
if m.classifyModel != nil {
return m.classifyModel
}
@@ -40,6 +43,7 @@ func (m *Model) ClassifyModel() *classify.Model {
log.Warnf("vision: missing name, model instance cannot be created")
return nil
case NasnetModel.Name, "nasnet":
// Load and initialize the Nasnet image classification model.
if model := classify.NewNasnet(AssetsPath, m.Disabled); model == nil {
return nil
} else if err := model.Init(); err != nil {
@@ -49,14 +53,22 @@ func (m *Model) ClassifyModel() *classify.Model {
m.classifyModel = model
}
default:
// Set model path from model name if no path is configured.
if m.Path == "" {
m.Path = clean.TypeLowerUnderscore(m.Name)
}
// Set default thumbnail resolution if no tags are configured.
if m.Resolution <= 0 {
m.Resolution = DefaultResolution
}
// Set default tag if no tags are configured.
if len(m.Tags) == 0 {
m.Tags = []string{"serve"}
}
// Try to load custom model based on the configuration values.
if model := classify.NewModel(AssetsPath, m.Path, m.Resolution, m.Tags, m.Disabled); model == nil {
return nil
} else if err := model.Init(); err != nil {

View File

@@ -12,5 +12,5 @@ var DefaultResolution = 224
// NasnetModel is a standard TensorFlow model used for label generation.
var (
NasnetModel = &Model{Name: "Nasnet", Resolution: 224, Tags: []string{"photoprism"}}
NasnetModel = &Model{Name: "Nasnet", Version: "Mobile", Resolution: 224, Tags: []string{"photoprism"}}
)