// Mirror of https://github.com/photoprism/photoprism.git
// (synced 2025-09-26 21:01:58 +08:00; 464 lines, 11 KiB, Go).
package classify
|
|
|
|
import (
|
|
"archive/tar"
|
|
"compress/gzip"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"os"
|
|
"path"
|
|
"path/filepath"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
|
|
"github.com/photoprism/photoprism/internal/ai/tensorflow"
|
|
)
|
|
|
|
// DefaultResolution is the input resolution applied to a loaded model whose
// input size is reported as dynamic.
const DefaultResolution = 224

// ExternalModelsTestLabel names the environment variable that enables the
// external model tests; they are skipped while it is unset or empty.
const ExternalModelsTestLabel = "PHOTOPRISM_TEST_EXTERNAL_MODELS"
|
|
|
|
// baseUrl is the location the external model archives and extra labels
// files are downloaded from; entries of modelsInfo are resolved against it.
//
// To avoid downloading everything again and again, point it at a local mirror:
//
//	var baseUrl = "http://host.docker.internal:8000"
var baseUrl = "https://dl.photoprism.app/tensorflow/vision"
|
|
|
|
type ModelTestCase struct {
|
|
Info *tensorflow.ModelInfo
|
|
Labels string
|
|
}
|
|
|
|
var modelsInfo = map[string]*ModelTestCase{
|
|
"efficientnet-v2-tensorflow2-imagenet1k-b0-classification-v2.tar.gz": {
|
|
Info: &tensorflow.ModelInfo{
|
|
Output: &tensorflow.ModelOutput{
|
|
OutputsLogits: true,
|
|
},
|
|
},
|
|
},
|
|
"efficientnet-v2-tensorflow2-imagenet1k-m-classification-v2.tar.gz": {
|
|
Info: &tensorflow.ModelInfo{
|
|
|
|
Input: &tensorflow.PhotoInput{
|
|
Height: 480,
|
|
Width: 480,
|
|
},
|
|
Output: &tensorflow.ModelOutput{
|
|
OutputsLogits: true,
|
|
},
|
|
},
|
|
},
|
|
"efficientnet-v2-tensorflow2-imagenet21k-b0-classification-v1.tar.gz": {
|
|
Info: &tensorflow.ModelInfo{
|
|
Output: &tensorflow.ModelOutput{
|
|
OutputsLogits: true,
|
|
},
|
|
},
|
|
Labels: "labels-imagenet21k.txt",
|
|
},
|
|
"inception-v3-tensorflow2-classification-v2.tar.gz": {
|
|
Info: &tensorflow.ModelInfo{
|
|
Input: &tensorflow.PhotoInput{
|
|
Height: 299,
|
|
Width: 299,
|
|
},
|
|
Output: &tensorflow.ModelOutput{
|
|
OutputsLogits: true,
|
|
},
|
|
},
|
|
},
|
|
"resnet-v2-tensorflow2-101-classification-v2.tar.gz": {
|
|
Info: &tensorflow.ModelInfo{
|
|
Output: &tensorflow.ModelOutput{
|
|
OutputsLogits: true,
|
|
},
|
|
},
|
|
},
|
|
"resnet-v2-tensorflow2-152-classification-v2.tar.gz": {
|
|
Info: &tensorflow.ModelInfo{
|
|
Output: &tensorflow.ModelOutput{
|
|
OutputsLogits: true,
|
|
},
|
|
},
|
|
},
|
|
"vision-transformer-tensorflow2-vit-b16-classification-v1.tar.gz": {
|
|
Info: &tensorflow.ModelInfo{
|
|
Input: &tensorflow.PhotoInput{
|
|
Intervals: []tensorflow.Interval{
|
|
{
|
|
Start: -1.0,
|
|
End: 1.0,
|
|
},
|
|
},
|
|
},
|
|
Output: &tensorflow.ModelOutput{
|
|
OutputsLogits: true,
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
// isSafePath reports whether target, once resolved, lies inside baseDir (or is
// baseDir itself). A relative target is resolved against baseDir, while an
// absolute target is checked as-is. It is used to reject archive entries such
// as "../../etc/passwd" that would otherwise be extracted outside baseDir.
//
// The check is purely lexical: symbolic links are not resolved.
func isSafePath(target, baseDir string) bool {
	absBase, err := filepath.Abs(baseDir)
	if err != nil {
		return false
	}

	// Joining an absolute target with baseDir again would always yield a path
	// below baseDir and defeat the check, so only relative targets are joined.
	if !filepath.IsAbs(target) {
		target = filepath.Join(absBase, target)
	}

	absTarget, err := filepath.Abs(target)
	if err != nil {
		return false
	}

	// Rel (rather than a plain string prefix test) also rejects siblings
	// such as "/base-evil" that merely share a prefix with "/base".
	rel, err := filepath.Rel(absBase, absTarget)
	if err != nil {
		return false
	}

	return rel != ".." && !strings.HasPrefix(rel, ".."+string(filepath.Separator))
}
|
|
|
|
// TestExternalModel_AllModels downloads every model listed in modelsInfo from
// baseUrl into a temporary directory, loads it, and runs the shared
// label-from-file and raw-inference checks against it, one subtest per model.
//
// The test is opt-in: it is skipped unless the environment variable named by
// ExternalModelsTestLabel is set to a non-empty value, because it downloads
// model archives over the network.
func TestExternalModel_AllModels(t *testing.T) {
	if os.Getenv(ExternalModelsTestLabel) == "" {
		t.Skipf("Skipping external model tests. To test them add set env var %s=true",
			ExternalModelsTestLabel)
	}

	// Removed automatically when the test and all its subtests complete.
	tmpPath := t.TempDir()

	for k, v := range modelsInfo {
		// The subtest parameter must be named t: with an unnamed parameter the
		// closure would capture the parent test, and calling Fatal/Skip on a
		// parent test from a subtest goroutine is not allowed.
		t.Run(k, func(t *testing.T) {
			log.Infof("vision: testing model %s", k)

			downloadedModel := downloadRemoteModel(t, fmt.Sprintf("%s/%s", baseUrl, k), tmpPath)
			log.Infof("vision: model downloaded to %s", downloadedModel)

			if v.Labels != "" {
				// Extra labels are stored inside the downloaded model directory.
				labelsDir := filepath.Join(tmpPath, downloadedModel)

				t.Logf("vision: model path is %s", labelsDir)
				downloadLabels(t, fmt.Sprintf("%s/%s", baseUrl, v.Labels), labelsDir)
			}

			// NOTE(review): modelPath is declared outside this view (it is not
			// the labels directory above); presumably a package-level default
			// labels location that NewModel falls back to — confirm.
			model := NewModel(tmpPath, downloadedModel, modelPath, v.Info, false)
			if err := model.loadModel(); err != nil {
				t.Fatal(err)
			}

			if model.meta.Input.IsDynamic() {
				model.meta.Input.SetResolution(DefaultResolution)
			}

			testModel_LabelsFromFile(t, model)
			testModel_Run(t, model)
		})
	}
}
|
|
|
|
// downloadLabels fetches the labels file at url and saves it as "labels.txt"
// inside the existing directory dst. The test fails immediately if the request
// fails, the server responds with a non-2xx status, or the file cannot be
// written completely.
func downloadLabels(t *testing.T, url, dst string) {
	t.Helper()

	resp, err := http.Get(url)
	if err != nil {
		t.Fatal(err)
	}
	defer resp.Body.Close()

	// Without this check an error page (e.g. a 404) would be stored as labels.
	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		t.Fatalf("Invalid status code for url %s: %d", url, resp.StatusCode)
	}

	output, err := os.Create(filepath.Join(dst, "labels.txt"))
	if err != nil {
		t.Fatal(err)
	}

	_, copyErr := io.Copy(output, resp.Body)
	// Close explicitly: write errors can surface only when the file is closed.
	closeErr := output.Close()

	if copyErr != nil {
		t.Fatal(copyErr)
	}
	if closeErr != nil {
		t.Fatal(closeErr)
	}
}
|
|
|
|
// downloadRemoteModel downloads the gzip-compressed tar archive at url and
// extracts it into a subdirectory of tmpPath named after the archive (its file
// name without the ".tar.gz" suffix).
//
// It returns the directory that contains "saved_model.pb", relative to
// tmpPath. The test fails immediately if the download or extraction fails, if
// an entry would be written outside the target directory, if an entry is
// neither a regular file nor a directory, or if no saved_model.pb was found.
func downloadRemoteModel(t *testing.T, url, tmpPath string) (model string) {
	t.Helper()
	t.Logf("Downloading %s to %s", url, tmpPath)

	modelPath := strings.TrimSuffix(path.Base(url), ".tar.gz")
	tmpPath = filepath.Join(tmpPath, modelPath)

	if err := os.MkdirAll(tmpPath, 0755); err != nil {
		t.Fatalf("could not make the dir %s: %v", tmpPath, err)
	}

	resp, err := http.Get(url)
	if err != nil {
		t.Fatal(err)
	}
	defer resp.Body.Close()

	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		t.Fatalf("Invalid status code for url %s: %d", url, resp.StatusCode)
	}

	uncompressedBody, err := gzip.NewReader(resp.Body)
	if err != nil {
		t.Fatal(err)
	}
	defer uncompressedBody.Close()

	tarReader := tar.NewReader(uncompressedBody)
	for {
		header, err := tarReader.Next()
		if err == io.EOF {
			break
		}
		if err != nil {
			t.Fatalf("could not extract the file: %v", err)
		}

		// Validate the raw entry name against the extraction directory so that
		// entries like "../x" cannot escape it.
		if !isSafePath(header.Name, tmpPath) {
			t.Fatalf("The model file contains an invalid path: %s", header.Name)
		}
		target := filepath.Join(tmpPath, header.Name)

		switch header.Typeflag {
		case tar.TypeDir:
			// MkdirAll tolerates directories that already exist, e.g. "./".
			if err := os.MkdirAll(target, 0755); err != nil {
				t.Fatalf("could not make the dir %s: %v", header.Name, err)
			}
		case tar.TypeReg:
			// Archives are not required to list parent directories first.
			if err := os.MkdirAll(filepath.Dir(target), 0755); err != nil {
				t.Fatalf("could not make the dir for %s: %v", header.Name, err)
			}

			outFile, err := os.Create(target)
			if err != nil {
				t.Fatalf("could not create file %s: %v", header.Name, err)
			}

			_, copyErr := io.Copy(outFile, tarReader)
			// Close before failing so the handle is released on every path.
			closeErr := outFile.Close()
			if copyErr != nil {
				t.Fatalf("could not copy file %s: %v", header.Name, copyErr)
			}
			if closeErr != nil {
				t.Fatalf("could not close file %s: %v", header.Name, closeErr)
			}

			rootPath, fileName := filepath.Split(header.Name)
			if fileName == "saved_model.pb" {
				model = filepath.Join(modelPath, rootPath)
			}
		default:
			t.Fatalf("could not extract file. Unknown type %v in %s",
				header.Typeflag,
				header.Name)
		}
	}

	// An empty result would otherwise surface later as a confusing load error.
	if model == "" {
		t.Fatalf("the archive %s does not contain a saved_model.pb file", url)
	}

	return model
}
|
|
|
|
// containsAny reports whether s contains at least one element of substrings
// as a substring. It is false for an empty or nil substrings slice.
func containsAny(s string, substrings []string) bool {
	for _, candidate := range substrings {
		if !strings.Contains(s, candidate) {
			continue
		}
		return true
	}
	return false
}
|
|
|
|
// assertContainsAny records a non-fatal test failure unless s contains at
// least one of substrings. The failure message shows s and all candidates.
//
// It is marked as a test helper so failures are reported at the caller's line.
func assertContainsAny(t *testing.T, s string, substrings []string) {
	t.Helper()

	assert.Truef(t, containsAny(s, substrings),
		"The result [%s] does not contain any of %v",
		s, substrings)
}
|
|
|
|
// testModel_LabelsFromFile classifies example images from disk via
// tensorFlow.File (up to 10 labels each) and checks the top label of every
// result. It also checks that a missing file yields an error and that a
// disabled model returns no labels. Subtest names are prefixed with the
// model name.
func testModel_LabelsFromFile(t *testing.T, tensorFlow *Model) {
	testName := func(name string) string {
		return fmt.Sprintf("%s/%s", tensorFlow.name, name)
	}

	t.Run(testName("chameleon_lime.jpg"), func(t *testing.T) {
		result, err := tensorFlow.File(examplesPath+"/chameleon_lime.jpg", 10)

		assert.NoError(t, err)
		assert.NotNil(t, result)
		assert.IsType(t, Labels{}, result)
		assert.GreaterOrEqual(t, len(result), 1)

		if len(result) != 1 {
			t.Logf("Expected 1 result, but found %d", len(result))
			t.Logf("Results: %#v", result)
		}

		if len(result) > 0 {
			assert.Contains(t, result[0].Name, "chameleon")
			//assert.Equal(t, 7, result[0].Uncertainty)
		}
	})

	t.Run(testName("cat_224.jpeg"), func(t *testing.T) {
		result, err := tensorFlow.File(examplesPath+"/cat_224.jpeg", 10)

		assert.NoError(t, err)
		assert.NotNil(t, result)
		assert.IsType(t, Labels{}, result)
		assert.GreaterOrEqual(t, len(result), 1)

		if len(result) != 1 {
			t.Logf("Expected 1 result, but found %d", len(result))
			t.Logf("Results: %#v", result)
		}

		if len(result) > 0 {
			assertContainsAny(t, result[0].Name, []string{"cat", "kitty"})
		}
	})

	t.Run(testName("cat_720.jpeg"), func(t *testing.T) {
		result, err := tensorFlow.File(examplesPath+"/cat_720.jpeg", 10)

		assert.NoError(t, err)
		assert.NotNil(t, result)
		assert.IsType(t, Labels{}, result)
		//assert.Equal(t, 3, len(result))
		assert.GreaterOrEqual(t, len(result), 1)

		// The exact label count varies by model, so a mismatch is only logged.
		if len(result) != 3 {
			t.Logf("Expected 3 result, but found %d", len(result))
			t.Logf("Results: %#v", result)
		}

		if len(result) > 0 {
			assertContainsAny(t, result[0].Name, []string{"cat", "kitty"})
		}
	})

	t.Run(testName("green.jpg"), func(t *testing.T) {
		result, err := tensorFlow.File(examplesPath+"/green.jpg", 10)

		t.Logf("labels: %#v", result)
		assert.NoError(t, err)
		assert.NotNil(t, result)
		assert.IsType(t, Labels{}, result)
		assert.GreaterOrEqual(t, len(result), 1)

		if len(result) != 1 {
			t.Logf("Expected 1 result, but found %d", len(result))
			t.Logf("Results: %#v", result)
		}

		if len(result) > 0 {
			assert.Equal(t, "outdoor", result[0].Name)
		}
	})

	t.Run(testName("not existing file"), func(t *testing.T) {
		result, err := tensorFlow.File(examplesPath+"/notexisting.jpg", 10)

		// Only inspect the message if there is an error: calling Error() on a
		// nil error would panic and abort the whole test binary.
		if assert.Error(t, err) {
			assert.Contains(t, err.Error(), "no such file or directory")
		}
		assert.Empty(t, result)
	})

	t.Run(testName("disabled true"), func(t *testing.T) {
		// Restore the flag afterwards; the model is reused by later checks.
		tensorFlow.disabled = true
		defer func() { tensorFlow.disabled = false }()

		result, err := tensorFlow.File(examplesPath+"/chameleon_lime.jpg", 10)
		assert.Nil(t, err)

		if err != nil {
			t.Fatal(err)
		}

		assert.Nil(t, result)
		assert.IsType(t, Labels{}, result)
		assert.Equal(t, 0, len(result))

		t.Log(result)
	})
}
|
|
|
|
// testModel_Run reads example files from disk and feeds their raw bytes to
// tensorFlow.Run (up to 10 labels each). Known images must yield a matching
// top label, a .docx file must produce an error, a plain white image must
// produce no labels, and a disabled model must return nil labels without an
// error. The whole check is skipped when tests run with -short.
func testModel_Run(t *testing.T, tensorFlow *Model) {
	if testing.Short() {
		t.Skip("skipping test in short mode.")
	}

	subtestName := func(file string) string {
		return fmt.Sprintf("%s/%s", tensorFlow.name, file)
	}

	t.Run(subtestName("chameleon_lime.jpg"), func(t *testing.T) {
		data, readErr := os.ReadFile(examplesPath + "/chameleon_lime.jpg")
		if readErr != nil {
			t.Error(readErr)
			return
		}

		labels, err := tensorFlow.Run(data, 10)
		t.Log(labels)
		assert.NotNil(t, labels)
		if err != nil {
			t.Fatal(err)
		}

		assert.IsType(t, Labels{}, labels)
		assert.GreaterOrEqual(t, len(labels), 1)

		if len(labels) != 1 {
			t.Logf("Expected 1 result, but found %d", len(labels))
			t.Logf("Results: %#v", labels)
		}
		if len(labels) > 0 {
			assert.Contains(t, labels[0].Name, "chameleon")
		}
	})

	t.Run(subtestName("dog_orange.jpg"), func(t *testing.T) {
		data, readErr := os.ReadFile(examplesPath + "/dog_orange.jpg")
		if readErr != nil {
			t.Error(readErr)
			return
		}

		labels, err := tensorFlow.Run(data, 10)
		t.Log(labels)
		assert.NotNil(t, labels)
		if err != nil {
			t.Fatal(err)
		}

		assert.IsType(t, Labels{}, labels)
		assert.GreaterOrEqual(t, len(labels), 1)

		if len(labels) != 1 {
			t.Logf("Expected 1 result, but found %d", len(labels))
			t.Logf("Results: %#v", labels)
		}
		if len(labels) > 0 {
			assertContainsAny(t, labels[0].Name, []string{"dog", "corgi"})
		}
	})

	t.Run(subtestName("Random.docx"), func(t *testing.T) {
		data, readErr := os.ReadFile(examplesPath + "/Random.docx")
		if readErr != nil {
			t.Error(readErr)
			return
		}

		// Not an image: Run must refuse it.
		labels, err := tensorFlow.Run(data, 10)
		assert.Empty(t, labels)
		assert.Error(t, err)
	})

	t.Run(subtestName("6720px_white.jpg"), func(t *testing.T) {
		data, readErr := os.ReadFile(examplesPath + "/6720px_white.jpg")
		if readErr != nil {
			t.Error(readErr)
			return
		}

		labels, err := tensorFlow.Run(data, 10)
		if err != nil {
			t.Fatal(err)
		}

		assert.Empty(t, labels)
	})

	t.Run(subtestName("disabled true"), func(t *testing.T) {
		tensorFlow.disabled = true
		defer func() { tensorFlow.disabled = false }()

		data, readErr := os.ReadFile(examplesPath + "/dog_orange.jpg")
		if readErr != nil {
			t.Error(readErr)
			return
		}

		labels, err := tensorFlow.Run(data, 10)
		t.Log(labels)
		assert.Nil(t, labels)
		assert.Nil(t, err)
		assert.IsType(t, Labels{}, labels)
		assert.Equal(t, 0, len(labels))
	})
}
|