mirror of
https://github.com/photoprism/photoprism.git
synced 2025-09-26 21:01:58 +08:00
Signed-off-by: Michael Mayer <michael@photoprism.app>
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -50,6 +50,7 @@ frontend/coverage/
|
||||
/assets/nsfw
|
||||
/assets/static/build/
|
||||
/assets/*net
|
||||
/assets/vision
|
||||
/pro
|
||||
/plus
|
||||
|
||||
|
BIN
assets/examples/cat_224.jpeg
Normal file
BIN
assets/examples/cat_224.jpeg
Normal file
Binary file not shown.
After Width: | Height: | Size: 16 KiB |
BIN
assets/examples/cat_720.jpeg
Normal file
BIN
assets/examples/cat_720.jpeg
Normal file
Binary file not shown.
After Width: | Height: | Size: 98 KiB |
BIN
assets/examples/green.jpg
Normal file
BIN
assets/examples/green.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 22 KiB |
@@ -1,14 +1,11 @@
|
||||
package classify
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"image"
|
||||
"math"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"runtime/debug"
|
||||
"sort"
|
||||
"strings"
|
||||
@@ -17,7 +14,7 @@ import (
|
||||
"github.com/disintegration/imaging"
|
||||
tf "github.com/wamuir/graft/tensorflow"
|
||||
|
||||
"github.com/photoprism/photoprism/pkg/clean"
|
||||
"github.com/photoprism/photoprism/internal/ai/tensorflow"
|
||||
)
|
||||
|
||||
// Model represents a TensorFlow classification model.
|
||||
@@ -82,7 +79,7 @@ func (m *Model) Labels(img []byte, confidenceThreshold int) (result Labels, err
|
||||
return nil, loadErr
|
||||
}
|
||||
|
||||
// Create tensor from image.
|
||||
// Create input tensor from image.
|
||||
tensor, err := m.createTensor(img)
|
||||
|
||||
if err != nil {
|
||||
@@ -112,37 +109,16 @@ func (m *Model) Labels(img []byte, confidenceThreshold int) (result Labels, err
|
||||
|
||||
if len(result) > 0 {
|
||||
log.Tracef("classify: image classified as %+v", result)
|
||||
} else {
|
||||
result = Labels{}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *Model) loadLabels(path string) error {
|
||||
modelLabels := path + "/labels.txt"
|
||||
|
||||
log.Infof("classify: loading labels from labels.txt")
|
||||
|
||||
// Load labels
|
||||
f, err := os.Open(modelLabels)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer f.Close()
|
||||
|
||||
scanner := bufio.NewScanner(f)
|
||||
|
||||
// Labels are separated by newlines
|
||||
for scanner.Scan() {
|
||||
m.labels = append(m.labels, scanner.Text())
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
func (m *Model) loadLabels(modelPath string) (err error) {
|
||||
m.labels, err = tensorflow.LoadLabels(modelPath)
|
||||
return err
|
||||
}
|
||||
|
||||
// ModelLoaded tests if the TensorFlow model is loaded.
|
||||
@@ -150,7 +126,9 @@ func (m *Model) ModelLoaded() bool {
|
||||
return m.model != nil
|
||||
}
|
||||
|
||||
func (m *Model) loadModel() error {
|
||||
func (m *Model) loadModel() (err error) {
|
||||
// Use mutex to prevent the model from being loaded and
|
||||
// initialized twice by different indexing workers.
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
@@ -160,16 +138,7 @@ func (m *Model) loadModel() error {
|
||||
|
||||
modelPath := path.Join(m.assetsPath, m.modelPath)
|
||||
|
||||
log.Infof("classify: loading %s", clean.Log(filepath.Base(modelPath)))
|
||||
|
||||
// Load model
|
||||
model, err := tf.LoadSavedModel(modelPath, m.modelTags, nil)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.model = model
|
||||
m.model, err = tensorflow.SavedModel(modelPath, m.modelTags)
|
||||
|
||||
return m.loadLabels(modelPath)
|
||||
}
|
||||
@@ -184,8 +153,10 @@ func (m *Model) bestLabels(probabilities []float32, confidenceThreshold int) Lab
|
||||
break
|
||||
}
|
||||
|
||||
confidence := int(math.Round(float64(p * 100)))
|
||||
|
||||
// discard labels with low probabilities
|
||||
if p < 0.1 {
|
||||
if confidence < confidenceThreshold {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -204,12 +175,7 @@ func (m *Model) bestLabels(probabilities []float32, confidenceThreshold int) Lab
|
||||
}
|
||||
|
||||
labelText = strings.TrimSpace(labelText)
|
||||
|
||||
confidence := int(math.Round(float64(p * 100)))
|
||||
|
||||
if confidence >= confidenceThreshold {
|
||||
result = append(result, Label{Name: labelText, Source: SrcImage, Uncertainty: 100 - confidence, Priority: rule.Priority, Categories: rule.Categories})
|
||||
}
|
||||
result = append(result, Label{Name: labelText, Source: SrcImage, Uncertainty: 100 - confidence, Priority: rule.Priority, Categories: rule.Categories})
|
||||
}
|
||||
|
||||
// Sort by probability
|
||||
@@ -231,42 +197,7 @@ func (m *Model) createTensor(image []byte) (*tf.Tensor, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
width, height := m.resolution, m.resolution
|
||||
img = imaging.Fill(img, m.resolution, m.resolution, imaging.Center, imaging.Lanczos)
|
||||
|
||||
img = imaging.Fill(img, width, height, imaging.Center, imaging.Lanczos)
|
||||
|
||||
return imageToTensor(img, width, height)
|
||||
}
|
||||
|
||||
func imageToTensor(img image.Image, imageHeight, imageWidth int) (tfTensor *tf.Tensor, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = fmt.Errorf("classify: %s (panic)\nstack: %s", r, debug.Stack())
|
||||
}
|
||||
}()
|
||||
|
||||
if imageHeight <= 0 || imageWidth <= 0 {
|
||||
return tfTensor, fmt.Errorf("classify: image width and height must be > 0")
|
||||
}
|
||||
|
||||
var tfImage [1][][][3]float32
|
||||
|
||||
for j := 0; j < imageHeight; j++ {
|
||||
tfImage[0] = append(tfImage[0], make([][3]float32, imageWidth))
|
||||
}
|
||||
|
||||
for i := 0; i < imageWidth; i++ {
|
||||
for j := 0; j < imageHeight; j++ {
|
||||
r, g, b, _ := img.At(i, j).RGBA()
|
||||
tfImage[0][j][i][0] = convertValue(r)
|
||||
tfImage[0][j][i][1] = convertValue(g)
|
||||
tfImage[0][j][i][2] = convertValue(b)
|
||||
}
|
||||
}
|
||||
|
||||
return tf.NewTensor(tfImage)
|
||||
}
|
||||
|
||||
func convertValue(value uint32) float32 {
|
||||
return (float32(value>>8) - float32(127.5)) / float32(127.5)
|
||||
return tensorflow.Image(img, m.resolution)
|
||||
}
|
||||
|
@@ -6,7 +6,6 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/wamuir/graft/tensorflow"
|
||||
|
||||
"github.com/photoprism/photoprism/pkg/fs"
|
||||
)
|
||||
@@ -31,24 +30,66 @@ func NewModelTest(t *testing.T) *Model {
|
||||
func TestModel_LabelsFromFile(t *testing.T) {
|
||||
t.Run("chameleon_lime.jpg", func(t *testing.T) {
|
||||
tensorFlow := NewModelTest(t)
|
||||
|
||||
result, err := tensorFlow.File(examplesPath+"/chameleon_lime.jpg", 10)
|
||||
|
||||
assert.Nil(t, err)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.IsType(t, Labels{}, result)
|
||||
assert.Equal(t, 1, len(result))
|
||||
|
||||
t.Log(result)
|
||||
if len(result) > 0 {
|
||||
t.Logf("result: %#v", result[0])
|
||||
assert.Equal(t, "chameleon", result[0].Name)
|
||||
|
||||
assert.Equal(t, "chameleon", result[0].Name)
|
||||
assert.Equal(t, 7, result[0].Uncertainty)
|
||||
}
|
||||
})
|
||||
t.Run("cat_224.jpeg", func(t *testing.T) {
|
||||
tensorFlow := NewModelTest(t)
|
||||
result, err := tensorFlow.File(examplesPath+"/cat_224.jpeg", 10)
|
||||
|
||||
assert.Equal(t, 7, result[0].Uncertainty)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.IsType(t, Labels{}, result)
|
||||
assert.Equal(t, 1, len(result))
|
||||
|
||||
if len(result) > 0 {
|
||||
assert.Equal(t, "cat", result[0].Name)
|
||||
|
||||
assert.Equal(t, 59, result[0].Uncertainty)
|
||||
}
|
||||
})
|
||||
t.Run("cat_720.jpeg", func(t *testing.T) {
|
||||
tensorFlow := NewModelTest(t)
|
||||
result, err := tensorFlow.File(examplesPath+"/cat_720.jpeg", 10)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.IsType(t, Labels{}, result)
|
||||
assert.Equal(t, 3, len(result))
|
||||
|
||||
// t.Logf("labels: %#v", result)
|
||||
|
||||
if len(result) > 0 {
|
||||
assert.Equal(t, "cat", result[0].Name)
|
||||
assert.Equal(t, 60, result[0].Uncertainty)
|
||||
}
|
||||
})
|
||||
t.Run("green.jpg", func(t *testing.T) {
|
||||
tensorFlow := NewModelTest(t)
|
||||
result, err := tensorFlow.File(examplesPath+"/green.jpg", 10)
|
||||
|
||||
t.Logf("labels: %#v", result)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.IsType(t, Labels{}, result)
|
||||
assert.Equal(t, 1, len(result))
|
||||
|
||||
if len(result) > 0 {
|
||||
assert.Equal(t, "outdoor", result[0].Name)
|
||||
|
||||
assert.Equal(t, 70, result[0].Uncertainty)
|
||||
}
|
||||
})
|
||||
t.Run("not existing file", func(t *testing.T) {
|
||||
tensorFlow := NewModelTest(t)
|
||||
@@ -180,11 +221,13 @@ func TestModel_LoadModel(t *testing.T) {
|
||||
})
|
||||
t.Run("model path does not exist", func(t *testing.T) {
|
||||
tensorFlow := NewNasnet(assetsPath+"foo", false)
|
||||
if err := tensorFlow.loadModel(); err != nil {
|
||||
assert.Contains(t, err.Error(), "Could not find SavedModel")
|
||||
} else {
|
||||
t.Fatal("err should NOT be nil")
|
||||
err := tensorFlow.loadModel()
|
||||
|
||||
if err != nil {
|
||||
assert.Contains(t, err.Error(), "no such file or directory")
|
||||
}
|
||||
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -218,35 +261,3 @@ func TestModel_BestLabels(t *testing.T) {
|
||||
t.Log(result)
|
||||
})
|
||||
}
|
||||
|
||||
func TestModel_MakeTensor(t *testing.T) {
|
||||
t.Run("cat_brown.jpg", func(t *testing.T) {
|
||||
tensorFlow := NewModelTest(t)
|
||||
|
||||
imageBuffer, err := os.ReadFile(examplesPath + "/cat_brown.jpg")
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
result, err := tensorFlow.createTensor(imageBuffer)
|
||||
assert.Equal(t, tensorflow.DataType(0x1), result.DataType())
|
||||
assert.Equal(t, int64(1), result.Shape()[0])
|
||||
assert.Equal(t, int64(224), result.Shape()[2])
|
||||
})
|
||||
t.Run("Random.docx", func(t *testing.T) {
|
||||
tensorFlow := NewModelTest(t)
|
||||
|
||||
imageBuffer, err := os.ReadFile(examplesPath + "/Random.docx")
|
||||
assert.Nil(t, err)
|
||||
result, err := tensorFlow.createTensor(imageBuffer)
|
||||
|
||||
assert.Empty(t, result)
|
||||
assert.EqualError(t, err, "image: unknown format")
|
||||
})
|
||||
}
|
||||
|
||||
func Test_convertValue(t *testing.T) {
|
||||
result := convertValue(uint32(98765432))
|
||||
assert.Equal(t, float32(3024.898), result)
|
||||
}
|
||||
|
@@ -32,7 +32,7 @@ func NewModel(modelPath, cachePath string, disabled bool) *Model {
|
||||
}
|
||||
|
||||
// Detect runs the detection and facenet algorithms over the provided source image.
|
||||
func (t *Model) Detect(fileName string, minSize int, cacheCrop bool, expected int) (faces Faces, err error) {
|
||||
func (m *Model) Detect(fileName string, minSize int, cacheCrop bool, expected int) (faces Faces, err error) {
|
||||
faces, err = Detect(fileName, false, minSize)
|
||||
|
||||
if err != nil {
|
||||
@@ -40,13 +40,13 @@ func (t *Model) Detect(fileName string, minSize int, cacheCrop bool, expected in
|
||||
}
|
||||
|
||||
// Skip FaceNet?
|
||||
if t.disabled {
|
||||
if m.disabled {
|
||||
return faces, nil
|
||||
} else if c := len(faces); c == 0 || expected > 0 && c == expected {
|
||||
return faces, nil
|
||||
}
|
||||
|
||||
err = t.loadModel()
|
||||
err = m.loadModel()
|
||||
|
||||
if err != nil {
|
||||
return faces, err
|
||||
@@ -59,7 +59,7 @@ func (t *Model) Detect(fileName string, minSize int, cacheCrop bool, expected in
|
||||
|
||||
if img, imgErr := crop.ImageFromThumb(fileName, f.CropArea(), CropSize, cacheCrop); imgErr != nil {
|
||||
log.Errorf("faces: failed to decode image: %s", imgErr)
|
||||
} else if embeddings := t.getEmbeddings(img); !embeddings.Empty() {
|
||||
} else if embeddings := m.getEmbeddings(img); !embeddings.Empty() {
|
||||
faces[i].Embeddings = embeddings
|
||||
}
|
||||
}
|
||||
@@ -68,38 +68,40 @@ func (t *Model) Detect(fileName string, minSize int, cacheCrop bool, expected in
|
||||
}
|
||||
|
||||
// ModelLoaded tests if the TensorFlow model is loaded.
|
||||
func (t *Model) ModelLoaded() bool {
|
||||
return t.model != nil
|
||||
func (m *Model) ModelLoaded() bool {
|
||||
return m.model != nil
|
||||
}
|
||||
|
||||
// loadModel loads the TensorFlow model.
|
||||
func (t *Model) loadModel() error {
|
||||
t.mutex.Lock()
|
||||
defer t.mutex.Unlock()
|
||||
func (m *Model) loadModel() error {
|
||||
// Use mutex to prevent the model from being loaded and
|
||||
// initialized twice by different indexing workers.
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if t.ModelLoaded() {
|
||||
if m.ModelLoaded() {
|
||||
return nil
|
||||
}
|
||||
|
||||
modelPath := path.Join(t.modelPath)
|
||||
modelPath := path.Join(m.modelPath)
|
||||
|
||||
log.Infof("faces: loading %s", clean.Log(filepath.Base(modelPath)))
|
||||
|
||||
// Load model
|
||||
model, err := tf.LoadSavedModel(modelPath, t.modelTags, nil)
|
||||
model, err := tf.LoadSavedModel(modelPath, m.modelTags, nil)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
t.model = model
|
||||
m.model = model
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getEmbeddings returns the face embeddings for an image.
|
||||
func (t *Model) getEmbeddings(img image.Image) Embeddings {
|
||||
tensor, err := imageToTensor(img, t.resolution)
|
||||
func (m *Model) getEmbeddings(img image.Image) Embeddings {
|
||||
tensor, err := imageToTensor(img, m.resolution)
|
||||
|
||||
if err != nil {
|
||||
log.Errorf("faces: failed to convert image to tensor: %s", err)
|
||||
@@ -109,13 +111,13 @@ func (t *Model) getEmbeddings(img image.Image) Embeddings {
|
||||
|
||||
trainPhaseBoolTensor, err := tf.NewTensor(false)
|
||||
|
||||
output, err := t.model.Session.Run(
|
||||
output, err := m.model.Session.Run(
|
||||
map[tf.Output]*tf.Tensor{
|
||||
t.model.Graph.Operation("input").Output(0): tensor,
|
||||
t.model.Graph.Operation("phase_train").Output(0): trainPhaseBoolTensor,
|
||||
m.model.Graph.Operation("input").Output(0): tensor,
|
||||
m.model.Graph.Operation("phase_train").Output(0): trainPhaseBoolTensor,
|
||||
},
|
||||
[]tf.Output{
|
||||
t.model.Graph.Operation("embeddings").Output(0),
|
||||
m.model.Graph.Operation("embeddings").Output(0),
|
||||
},
|
||||
nil)
|
||||
|
||||
|
@@ -1,25 +1,19 @@
|
||||
package nsfw
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
|
||||
tf "github.com/wamuir/graft/tensorflow"
|
||||
"github.com/wamuir/graft/tensorflow/op"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/ai/tensorflow"
|
||||
"github.com/photoprism/photoprism/pkg/clean"
|
||||
"github.com/photoprism/photoprism/pkg/fs"
|
||||
"github.com/photoprism/photoprism/pkg/media/http/header"
|
||||
)
|
||||
|
||||
const (
|
||||
Mean = float32(117)
|
||||
Scale = float32(1)
|
||||
)
|
||||
|
||||
// Model uses TensorFlow to label drawing, hentai, neutral, porn and sexy images.
|
||||
type Model struct {
|
||||
model *tf.SavedModel
|
||||
@@ -36,7 +30,7 @@ func NewModel(modelPath string) *Model {
|
||||
}
|
||||
|
||||
// File returns matching labels for a jpeg media file.
|
||||
func (t *Model) File(filename string) (result Labels, err error) {
|
||||
func (m *Model) File(filename string) (result Labels, err error) {
|
||||
if fs.MimeType(filename) != header.ContentTypeJpeg {
|
||||
return result, fmt.Errorf("nsfw: %s is not a jpeg file", clean.Log(filepath.Base(filename)))
|
||||
}
|
||||
@@ -47,29 +41,29 @@ func (t *Model) File(filename string) (result Labels, err error) {
|
||||
return result, err
|
||||
}
|
||||
|
||||
return t.Labels(imageBuffer)
|
||||
return m.Labels(imageBuffer)
|
||||
}
|
||||
|
||||
// Labels returns matching labels for a jpeg media string.
|
||||
func (t *Model) Labels(img []byte) (result Labels, err error) {
|
||||
if loadErr := t.loadModel(); loadErr != nil {
|
||||
func (m *Model) Labels(img []byte) (result Labels, err error) {
|
||||
if loadErr := m.loadModel(); loadErr != nil {
|
||||
return result, loadErr
|
||||
}
|
||||
|
||||
// Make tensor
|
||||
tensor, err := createTensorFromImage(img, "jpeg", t.resolution)
|
||||
// Create input tensor from image.
|
||||
input, err := tensorflow.ImageTransform(img, fs.ImageJpeg, m.resolution)
|
||||
|
||||
if err != nil {
|
||||
return result, fmt.Errorf("nsfw: %s", err)
|
||||
}
|
||||
|
||||
// Run inference
|
||||
output, err := t.model.Session.Run(
|
||||
// Run inference.
|
||||
output, err := m.model.Session.Run(
|
||||
map[tf.Output]*tf.Tensor{
|
||||
t.model.Graph.Operation("input_tensor").Output(0): tensor,
|
||||
m.model.Graph.Operation("input_tensor").Output(0): input,
|
||||
},
|
||||
[]tf.Output{
|
||||
t.model.Graph.Operation("nsfw_cls_model/final_prediction").Output(0),
|
||||
m.model.Graph.Operation("nsfw_cls_model/final_prediction").Output(0),
|
||||
},
|
||||
nil)
|
||||
|
||||
@@ -81,66 +75,45 @@ func (t *Model) Labels(img []byte) (result Labels, err error) {
|
||||
return result, fmt.Errorf("nsfw: inference failed, no output")
|
||||
}
|
||||
|
||||
// Return best labels
|
||||
result = t.getLabels(output[0].Value().([][]float32)[0])
|
||||
// Return best labels.
|
||||
result = m.getLabels(output[0].Value().([][]float32)[0])
|
||||
|
||||
log.Tracef("nsfw: image classified as %+v", result)
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (t *Model) loadLabels(path string) error {
|
||||
modelLabels := path + "/labels.txt"
|
||||
|
||||
log.Infof("nsfw: loading labels from labels.txt")
|
||||
|
||||
// Load labels
|
||||
f, err := os.Open(modelLabels)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer f.Close()
|
||||
|
||||
scanner := bufio.NewScanner(f)
|
||||
|
||||
// Labels are separated by newlines
|
||||
for scanner.Scan() {
|
||||
t.labels = append(t.labels, scanner.Text())
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
func (m *Model) loadLabels(modelPath string) (err error) {
|
||||
m.labels, err = tensorflow.LoadLabels(modelPath)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *Model) loadModel() error {
|
||||
t.mutex.Lock()
|
||||
defer t.mutex.Unlock()
|
||||
func (m *Model) loadModel() error {
|
||||
// Use mutex to prevent the model from being loaded and
|
||||
// initialized twice by different indexing workers.
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if t.model != nil {
|
||||
if m.model != nil {
|
||||
// Already loaded
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Infof("nsfw: loading %s", clean.Log(filepath.Base(t.modelPath)))
|
||||
log.Infof("nsfw: loading %s", clean.Log(filepath.Base(m.modelPath)))
|
||||
|
||||
// Load model
|
||||
model, err := tf.LoadSavedModel(t.modelPath, t.modelTags, nil)
|
||||
// Load saved TensorFlow model from the specified path.
|
||||
model, err := tensorflow.SavedModel(m.modelPath, m.modelTags)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
t.model = model
|
||||
m.model = model
|
||||
|
||||
return t.loadLabels(t.modelPath)
|
||||
return m.loadLabels(m.modelPath)
|
||||
}
|
||||
|
||||
func (t *Model) getLabels(p []float32) Labels {
|
||||
func (m *Model) getLabels(p []float32) Labels {
|
||||
return Labels{
|
||||
Drawing: p[0],
|
||||
Hentai: p[1],
|
||||
@@ -149,56 +122,3 @@ func (t *Model) getLabels(p []float32) Labels {
|
||||
Sexy: p[4],
|
||||
}
|
||||
}
|
||||
|
||||
func transformImageGraph(imageFormat string, resolution int) (graph *tf.Graph, input, output tf.Output, err error) {
|
||||
var H, W = int32(resolution), int32(resolution)
|
||||
|
||||
s := op.NewScope()
|
||||
input = op.Placeholder(s, tf.String)
|
||||
// Decode PNG or JPEG
|
||||
var decode tf.Output
|
||||
if imageFormat == "png" {
|
||||
decode = op.DecodePng(s, input, op.DecodePngChannels(3))
|
||||
} else {
|
||||
decode = op.DecodeJpeg(s, input, op.DecodeJpegChannels(3))
|
||||
}
|
||||
// Div and Sub perform (value-Mean)/Scale for each pixel
|
||||
output = op.Div(s,
|
||||
op.Sub(s,
|
||||
// Resize to 224x224 with bilinear interpolation
|
||||
op.ResizeBilinear(s,
|
||||
// Create a batch containing a single image
|
||||
op.ExpandDims(s,
|
||||
// Use decoded pixel values
|
||||
op.Cast(s, decode, tf.Float),
|
||||
op.Const(s.SubScope("make_batch"), int32(0))),
|
||||
op.Const(s.SubScope("size"), []int32{H, W})),
|
||||
op.Const(s.SubScope("mean"), Mean)),
|
||||
op.Const(s.SubScope("scale"), Scale))
|
||||
graph, err = s.Finalize()
|
||||
return graph, input, output, err
|
||||
}
|
||||
|
||||
func createTensorFromImage(image []byte, imageFormat string, resolution int) (*tf.Tensor, error) {
|
||||
tensor, err := tf.NewTensor(string(image))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
graph, input, output, err := transformImageGraph(imageFormat, resolution)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
session, err := tf.NewSession(graph, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer session.Close()
|
||||
normalized, err := session.Run(
|
||||
map[tf.Output]*tf.Tensor{input: tensor},
|
||||
[]tf.Output{output},
|
||||
nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return normalized[0], nil
|
||||
}
|
||||
|
145
internal/ai/tensorflow/image.go
Normal file
145
internal/ai/tensorflow/image.go
Normal file
@@ -0,0 +1,145 @@
|
||||
package tensorflow
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"image"
|
||||
_ "image/jpeg"
|
||||
_ "image/png"
|
||||
"os"
|
||||
"runtime/debug"
|
||||
|
||||
tf "github.com/wamuir/graft/tensorflow"
|
||||
"github.com/wamuir/graft/tensorflow/op"
|
||||
|
||||
"github.com/photoprism/photoprism/pkg/fs"
|
||||
)
|
||||
|
||||
const (
|
||||
Mean = float32(117)
|
||||
Scale = float32(1)
|
||||
)
|
||||
|
||||
func ImageFromFile(fileName string, resolution int) (*tf.Tensor, error) {
|
||||
if img, err := OpenImage(fileName); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
return Image(img, resolution)
|
||||
}
|
||||
}
|
||||
|
||||
func OpenImage(fileName string) (image.Image, error) {
|
||||
f, err := os.Open(fileName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer f.Close()
|
||||
img, _, err := image.Decode(f)
|
||||
|
||||
return img, err
|
||||
}
|
||||
|
||||
func ImageFromBytes(b []byte, resolution int) (*tf.Tensor, error) {
|
||||
img, _, imgErr := image.Decode(bytes.NewReader(b))
|
||||
|
||||
if imgErr != nil {
|
||||
return nil, imgErr
|
||||
}
|
||||
|
||||
return Image(img, resolution)
|
||||
}
|
||||
|
||||
func Image(img image.Image, resolution int) (tfTensor *tf.Tensor, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = fmt.Errorf("tensorflow: %s (panic)\nstack: %s", r, debug.Stack())
|
||||
}
|
||||
}()
|
||||
|
||||
if resolution <= 0 {
|
||||
return tfTensor, fmt.Errorf("tensorflow: resolution must be larger 0")
|
||||
}
|
||||
|
||||
var tfImage [1][][][3]float32
|
||||
|
||||
for j := 0; j < resolution; j++ {
|
||||
tfImage[0] = append(tfImage[0], make([][3]float32, resolution))
|
||||
}
|
||||
|
||||
for i := 0; i < resolution; i++ {
|
||||
for j := 0; j < resolution; j++ {
|
||||
r, g, b, _ := img.At(i, j).RGBA()
|
||||
tfImage[0][j][i][0] = convertValue(r, 127.5)
|
||||
tfImage[0][j][i][1] = convertValue(g, 127.5)
|
||||
tfImage[0][j][i][2] = convertValue(b, 127.5)
|
||||
}
|
||||
}
|
||||
|
||||
return tf.NewTensor(tfImage)
|
||||
}
|
||||
|
||||
// ImageTransform transforms the given image into a *tf.Tensor and returns it.
|
||||
func ImageTransform(image []byte, imageFormat fs.Type, resolution int) (*tf.Tensor, error) {
|
||||
tensor, err := tf.NewTensor(string(image))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
graph, input, output, err := transformImageGraph(imageFormat, resolution)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
session, err := tf.NewSession(graph, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer session.Close()
|
||||
|
||||
normalized, err := session.Run(
|
||||
map[tf.Output]*tf.Tensor{input: tensor},
|
||||
[]tf.Output{output},
|
||||
nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return normalized[0], nil
|
||||
}
|
||||
|
||||
func transformImageGraph(imageFormat fs.Type, resolution int) (graph *tf.Graph, input, output tf.Output, err error) {
|
||||
s := op.NewScope()
|
||||
input = op.Placeholder(s, tf.String)
|
||||
|
||||
// Assume the image is a JPEG, or a PNG if explicitly specified.
|
||||
var decodedImage tf.Output
|
||||
switch imageFormat {
|
||||
case fs.ImagePng:
|
||||
decodedImage = op.DecodePng(s, input, op.DecodePngChannels(3))
|
||||
default:
|
||||
decodedImage = op.DecodeJpeg(s, input, op.DecodeJpegChannels(3))
|
||||
}
|
||||
|
||||
output = op.Div(s,
|
||||
op.Sub(s,
|
||||
op.ResizeBilinear(s,
|
||||
op.ExpandDims(s,
|
||||
op.Cast(s, decodedImage, tf.Float),
|
||||
op.Const(s.SubScope("make_batch"), int32(0))),
|
||||
op.Const(s.SubScope("size"), []int32{int32(resolution), int32(resolution)})),
|
||||
op.Const(s.SubScope("mean"), Mean)),
|
||||
op.Const(s.SubScope("scale"), Scale))
|
||||
|
||||
graph, err = s.Finalize()
|
||||
|
||||
return graph, input, output, err
|
||||
}
|
||||
|
||||
func convertValue(value uint32, mean float32) float32 {
|
||||
if mean == 0 {
|
||||
mean = 127.5
|
||||
}
|
||||
|
||||
return (float32(value>>8) - mean) / mean
|
||||
}
|
42
internal/ai/tensorflow/image_test.go
Normal file
42
internal/ai/tensorflow/image_test.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package tensorflow
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/wamuir/graft/tensorflow"
|
||||
|
||||
"github.com/photoprism/photoprism/pkg/fs"
|
||||
)
|
||||
|
||||
func TestConvertValue(t *testing.T) {
|
||||
result := convertValue(uint32(98765432), 127.5)
|
||||
assert.Equal(t, float32(3024.898), result)
|
||||
}
|
||||
|
||||
func TestImageFromBytes(t *testing.T) {
|
||||
var assetsPath = fs.Abs("../../../assets")
|
||||
var examplesPath = assetsPath + "/examples"
|
||||
|
||||
t.Run("CatJpeg", func(t *testing.T) {
|
||||
imageBuffer, err := os.ReadFile(examplesPath + "/cat_brown.jpg")
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
result, err := ImageFromBytes(imageBuffer, 224)
|
||||
assert.Equal(t, tensorflow.DataType(0x1), result.DataType())
|
||||
assert.Equal(t, int64(1), result.Shape()[0])
|
||||
assert.Equal(t, int64(224), result.Shape()[2])
|
||||
})
|
||||
t.Run("Document", func(t *testing.T) {
|
||||
imageBuffer, err := os.ReadFile(examplesPath + "/Random.docx")
|
||||
assert.Nil(t, err)
|
||||
result, err := ImageFromBytes(imageBuffer, 224)
|
||||
|
||||
assert.Empty(t, result)
|
||||
assert.EqualError(t, err, "image: unknown format")
|
||||
})
|
||||
}
|
32
internal/ai/tensorflow/labels.go
Normal file
32
internal/ai/tensorflow/labels.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package tensorflow
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"os"
|
||||
)
|
||||
|
||||
// LoadLabels loads the labels of classification models from the specified path and returns them.
|
||||
func LoadLabels(modelPath string) (labels []string, err error) {
|
||||
modelLabels := modelPath + "/labels.txt"
|
||||
|
||||
log.Infof("tensorflow: loading model labels from labels.txt")
|
||||
|
||||
f, err := os.Open(modelLabels)
|
||||
|
||||
if err != nil {
|
||||
return labels, err
|
||||
}
|
||||
|
||||
defer f.Close()
|
||||
|
||||
scanner := bufio.NewScanner(f)
|
||||
|
||||
// Labels are separated by newlines
|
||||
for scanner.Scan() {
|
||||
labels = append(labels, scanner.Text())
|
||||
}
|
||||
|
||||
err = scanner.Err()
|
||||
|
||||
return labels, err
|
||||
}
|
20
internal/ai/tensorflow/model.go
Normal file
20
internal/ai/tensorflow/model.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package tensorflow
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
|
||||
tf "github.com/wamuir/graft/tensorflow"
|
||||
|
||||
"github.com/photoprism/photoprism/pkg/clean"
|
||||
)
|
||||
|
||||
// SavedModel loads a saved TensorFlow model from the specified path.
|
||||
func SavedModel(modelPath string, tags []string) (model *tf.SavedModel, err error) {
|
||||
log.Infof("tensorflow: loading %s", clean.Log(filepath.Base(modelPath)))
|
||||
|
||||
if len(tags) == 0 {
|
||||
tags = []string{"serve"}
|
||||
}
|
||||
|
||||
return tf.LoadSavedModel(modelPath, tags, nil)
|
||||
}
|
31
internal/ai/tensorflow/tensorflow.go
Normal file
31
internal/ai/tensorflow/tensorflow.go
Normal file
@@ -0,0 +1,31 @@
|
||||
/*
|
||||
Package tensorflow provides TensorFlow utility functions.
|
||||
|
||||
Copyright (c) 2018 - 2025 PhotoPrism UG. All rights reserved.
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under Version 3 of the GNU Affero General Public License (the "AGPL"):
|
||||
<https://docs.photoprism.app/license/agpl>
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
The AGPL is supplemented by our Trademark and Brand Guidelines,
|
||||
which describe how our Brand Assets may be used:
|
||||
<https://www.photoprism.app/trademark>
|
||||
|
||||
Feel free to send an email to hello@photoprism.app if you have questions,
|
||||
want to support our work, or just want to say hello.
|
||||
|
||||
Additional information can be found in our Developer Guide:
|
||||
<https://docs.photoprism.app/developer-guide/>
|
||||
*/
|
||||
package tensorflow
|
||||
|
||||
import (
|
||||
"github.com/photoprism/photoprism/internal/event"
|
||||
)
|
||||
|
||||
var log = event.Log
|
@@ -51,7 +51,7 @@ func Labels(thumbnails []string) (result classify.Labels, err error) {
|
||||
}
|
||||
|
||||
if !found {
|
||||
result = append(result, labels...)
|
||||
result = append(result, labels[j])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -25,6 +25,19 @@ func TestLabels(t *testing.T) {
|
||||
assert.Equal(t, "chameleon", result[0].Name)
|
||||
assert.Equal(t, 7, result[0].Uncertainty)
|
||||
})
|
||||
t.Run("Cats", func(t *testing.T) {
|
||||
result, err := Labels([]string{examplesPath + "/cat_720.jpeg"})
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.IsType(t, classify.Labels{}, result)
|
||||
assert.Equal(t, 1, len(result))
|
||||
|
||||
t.Log(result)
|
||||
|
||||
assert.Equal(t, "cat", result[0].Name)
|
||||
assert.Equal(t, 60, result[0].Uncertainty)
|
||||
assert.Equal(t, 40, result[0].Confidence())
|
||||
})
|
||||
t.Run("InvalidFile", func(t *testing.T) {
|
||||
_, err := Labels([]string{examplesPath + "/notexisting.jpg"})
|
||||
assert.Error(t, err)
|
||||
|
@@ -28,9 +28,12 @@ type Models []*Model
|
||||
|
||||
// ClassifyModel returns the matching classify model instance, if any.
|
||||
func (m *Model) ClassifyModel() *classify.Model {
|
||||
// Use mutex to prevent models from being loaded and
|
||||
// initialized twice by different indexing workers.
|
||||
modelMutex.Lock()
|
||||
defer modelMutex.Unlock()
|
||||
|
||||
// Return the existing model instance if it has already been created.
|
||||
if m.classifyModel != nil {
|
||||
return m.classifyModel
|
||||
}
|
||||
@@ -40,6 +43,7 @@ func (m *Model) ClassifyModel() *classify.Model {
|
||||
log.Warnf("vision: missing name, model instance cannot be created")
|
||||
return nil
|
||||
case NasnetModel.Name, "nasnet":
|
||||
// Load and initialize the Nasnet image classification model.
|
||||
if model := classify.NewNasnet(AssetsPath, m.Disabled); model == nil {
|
||||
return nil
|
||||
} else if err := model.Init(); err != nil {
|
||||
@@ -49,14 +53,22 @@ func (m *Model) ClassifyModel() *classify.Model {
|
||||
m.classifyModel = model
|
||||
}
|
||||
default:
|
||||
// Set model path from model name if no path is configured.
|
||||
if m.Path == "" {
|
||||
m.Path = clean.TypeLowerUnderscore(m.Name)
|
||||
}
|
||||
|
||||
// Set default thumbnail resolution if no tags are configured.
|
||||
if m.Resolution <= 0 {
|
||||
m.Resolution = DefaultResolution
|
||||
}
|
||||
|
||||
// Set default tag if no tags are configured.
|
||||
if len(m.Tags) == 0 {
|
||||
m.Tags = []string{"serve"}
|
||||
}
|
||||
|
||||
// Try to load custom model based on the configuration values.
|
||||
if model := classify.NewModel(AssetsPath, m.Path, m.Resolution, m.Tags, m.Disabled); model == nil {
|
||||
return nil
|
||||
} else if err := model.Init(); err != nil {
|
||||
|
@@ -12,5 +12,5 @@ var DefaultResolution = 224
|
||||
|
||||
// NasnetModel is a standard TensorFlow model used for label generation.
|
||||
var (
|
||||
NasnetModel = &Model{Name: "Nasnet", Resolution: 224, Tags: []string{"photoprism"}}
|
||||
NasnetModel = &Model{Name: "Nasnet", Version: "Mobile", Resolution: 224, Tags: []string{"photoprism"}}
|
||||
)
|
||||
|
Reference in New Issue
Block a user