Files
photoprism/internal/ai/classify/model_external_test.go
2025-08-08 19:06:56 +02:00

464 lines
11 KiB
Go

package classify
import (
"archive/tar"
"compress/gzip"
"fmt"
"io"
"net/http"
"os"
"path"
"path/filepath"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/photoprism/photoprism/internal/ai/tensorflow"
)
// DefaultResolution is the resolution passed to SetResolution for models whose
// input is reported as dynamic (see TestExternalModel_AllModels).
const DefaultResolution = 224

// ExternalModelsTestLabel is the name of the environment variable that must be
// set to a non-empty value for the external model tests to run.
const ExternalModelsTestLabel = "PHOTOPRISM_TEST_EXTERNAL_MODELS"
// baseUrl is the location the model archives and label files are downloaded
// from; the file name is appended to it with a "/" separator.
// To avoid downloading everything again and again, it can be pointed at a
// local server instead, e.g. "http://host.docker.internal:8000".
var baseUrl = "https://dl.photoprism.app/tensorflow/vision"
// ModelTestCase describes how a single external model is exercised by
// TestExternalModel_AllModels.
type ModelTestCase struct {
// Info is the model metadata handed to NewModel when the model is created.
Info *tensorflow.ModelInfo
// Labels is the optional name of a labels file below baseUrl. When it is
// non-empty, the file is downloaded into the model directory as "labels.txt".
Labels string
}
// modelsInfo lists the external models under test. Each key is the file name of
// a model archive below baseUrl; the value holds the metadata the model is
// created with and, optionally, the name of a dedicated labels file.
// Every entry declares OutputsLogits: true.
var modelsInfo = map[string]*ModelTestCase{
// No input settings are given for this model.
"efficientnet-v2-tensorflow2-imagenet1k-b0-classification-v2.tar.gz": {
Info: &tensorflow.ModelInfo{
Output: &tensorflow.ModelOutput{
OutputsLogits: true,
},
},
},
// Input height and width are given as 480x480.
"efficientnet-v2-tensorflow2-imagenet1k-m-classification-v2.tar.gz": {
Info: &tensorflow.ModelInfo{
Input: &tensorflow.PhotoInput{
Height: 480,
Width: 480,
},
Output: &tensorflow.ModelOutput{
OutputsLogits: true,
},
},
},
// The only entry with its own labels file, which the test downloads from baseUrl.
"efficientnet-v2-tensorflow2-imagenet21k-b0-classification-v1.tar.gz": {
Info: &tensorflow.ModelInfo{
Output: &tensorflow.ModelOutput{
OutputsLogits: true,
},
},
Labels: "labels-imagenet21k.txt",
},
// Input height and width are given as 299x299.
"inception-v3-tensorflow2-classification-v2.tar.gz": {
Info: &tensorflow.ModelInfo{
Input: &tensorflow.PhotoInput{
Height: 299,
Width: 299,
},
Output: &tensorflow.ModelOutput{
OutputsLogits: true,
},
},
},
// No input settings are given for this model.
"resnet-v2-tensorflow2-101-classification-v2.tar.gz": {
Info: &tensorflow.ModelInfo{
Output: &tensorflow.ModelOutput{
OutputsLogits: true,
},
},
},
// No input settings are given for this model.
"resnet-v2-tensorflow2-152-classification-v2.tar.gz": {
Info: &tensorflow.ModelInfo{
Output: &tensorflow.ModelOutput{
OutputsLogits: true,
},
},
},
// Declares a single input interval from -1.0 to 1.0.
// NOTE(review): presumably the range pixel values are scaled to — confirm
// against tensorflow.PhotoInput.
"vision-transformer-tensorflow2-vit-b16-classification-v1.tar.gz": {
Info: &tensorflow.ModelInfo{
Input: &tensorflow.PhotoInput{
Intervals: []tensorflow.Interval{
{
Start: -1.0,
End: 1.0,
},
},
},
Output: &tensorflow.ModelOutput{
OutputsLogits: true,
},
},
},
}
// isSafePath reports whether target resolves to baseDir itself or to a location
// inside baseDir.
//
// target may be absolute (for example an extraction path that has already been
// joined with baseDir), in which case it is used as is, or relative, in which
// case it is interpreted relative to baseDir. The check is purely lexical: "."
// and ".." elements are resolved, but symbolic links are not followed.
//
// It returns false if baseDir cannot be made absolute or if target cannot be
// expressed relative to it.
func isSafePath(target, baseDir string) bool {
	absBase, err := filepath.Abs(baseDir)
	if err != nil {
		return false
	}

	// Only anchor relative targets at the base directory. Joining an absolute
	// target onto the base again would always yield a path below the base and
	// make the check pass unconditionally.
	absTarget := target
	if !filepath.IsAbs(absTarget) {
		absTarget = filepath.Join(absBase, absTarget)
	}

	// Compare path elements instead of raw string prefixes, so that a sibling
	// such as "/tmp/base-evil" is not mistaken for a child of "/tmp/base".
	rel, err := filepath.Rel(absBase, absTarget)
	if err != nil {
		return false
	}

	return rel != ".." && !strings.HasPrefix(rel, ".."+string(filepath.Separator))
}
// TestExternalModel_AllModels downloads every model listed in modelsInfo into a
// temporary directory, loads it, and runs the shared label and inference checks
// (testModel_LabelsFromFile and testModel_Run) against it, one subtest per model.
//
// The test is skipped unless the environment variable named by
// ExternalModelsTestLabel is set to a non-empty value.
func TestExternalModel_AllModels(t *testing.T) {
	if os.Getenv(ExternalModelsTestLabel) == "" {
		t.Skipf("Skipping external model tests. To test them add set env var %s=true",
			ExternalModelsTestLabel)
	}

	tmpPath, err := os.MkdirTemp("", "*-photoprism")
	if err != nil {
		t.Fatal(err)
	}
	defer os.RemoveAll(tmpPath)

	for k, v := range modelsInfo {
		// The closure must name its own *testing.T. With an unnamed parameter,
		// every t.Fatal, t.Logf and nested t.Run below would act on the parent
		// test from the subtest's goroutine, which the testing package forbids
		// for FailNow and SkipNow.
		t.Run(k, func(t *testing.T) {
			log.Infof("vision: testing model %s", k)

			downloadedModel := downloadRemoteModel(t, fmt.Sprintf("%s/%s", baseUrl, k), tmpPath)
			log.Infof("vision: model downloaded to %s", downloadedModel)

			// Models with a dedicated labels file get it stored next to the
			// extracted model.
			if v.Labels != "" {
				labelsDir := filepath.Join(tmpPath, downloadedModel)
				t.Logf("vision: model path is %s", labelsDir)
				downloadLabels(t, fmt.Sprintf("%s/%s", baseUrl, v.Labels), labelsDir)
			}

			// NOTE(review): modelPath is declared outside this function; it is
			// not the directory the model was just extracted to. Presumably it
			// is a fallback labels location — confirm against NewModel.
			model := NewModel(tmpPath, downloadedModel, modelPath, v.Info, false)

			if err := model.loadModel(); err != nil {
				t.Fatal(err)
			}

			// Models reporting a dynamic input need a concrete resolution
			// before they can be run.
			if model.meta.Input.IsDynamic() {
				model.meta.Input.SetResolution(DefaultResolution)
			}

			testModel_LabelsFromFile(t, model)
			testModel_Run(t, model)
		})
	}
}
// downloadLabels fetches the labels file at url and stores it as "labels.txt"
// in the directory dst, which must already exist.
//
// Any failure — request error, non-2xx response status, or a file error —
// aborts the calling test via t.Fatal.
func downloadLabels(t *testing.T, url, dst string) {
	t.Helper()

	resp, err := http.Get(url)
	if err != nil {
		t.Fatal(err)
	}
	defer resp.Body.Close()

	// Without this check, the body of an error response (e.g. a 404 page)
	// would be written out as the labels file.
	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		t.Fatalf("Invalid status code for url %s: %d", url, resp.StatusCode)
	}

	output, err := os.Create(filepath.Join(dst, "labels.txt"))
	if err != nil {
		t.Fatal(err)
	}

	_, copyErr := io.Copy(output, resp.Body)

	// Close explicitly rather than via defer so that a failed close is reported
	// instead of being silently dropped.
	closeErr := output.Close()

	if copyErr != nil {
		t.Fatal(copyErr)
	}

	if closeErr != nil {
		t.Fatal(closeErr)
	}
}
// downloadRemoteModel downloads the gzip-compressed tar archive at url and
// extracts it into a sub-directory of tmpPath named after the archive, i.e.
// the last URL path element without its ".tar.gz" suffix.
//
// It returns the directory that contains "saved_model.pb", relative to the
// tmpPath that was passed in, or an empty string if the archive holds no such
// file. Only directories and regular files are accepted as archive entries.
//
// Any failure aborts the calling test via t.Fatal.
func downloadRemoteModel(t *testing.T, url, tmpPath string) (model string) {
	t.Helper()
	t.Logf("Downloading %s to %s", url, tmpPath)

	modelPath := strings.TrimSuffix(path.Base(url), ".tar.gz")
	tmpPath = filepath.Join(tmpPath, modelPath)

	if err := os.MkdirAll(tmpPath, 0755); err != nil {
		t.Fatalf("could not make the dir %s: %v", tmpPath, err)
	}

	resp, err := http.Get(url)
	if err != nil {
		t.Fatal(err)
	}
	defer resp.Body.Close()

	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		t.Fatalf("Invalid status code for url %s: %d", url, resp.StatusCode)
	}

	uncompressedBody, err := gzip.NewReader(resp.Body)
	if err != nil {
		t.Fatal(err)
	}
	defer uncompressedBody.Close()

	tarReader := tar.NewReader(uncompressedBody)

	for {
		header, err := tarReader.Next()
		if err == io.EOF {
			break
		}

		if err != nil {
			t.Fatalf("could not extract the file: %v", err)
		}

		// Refuse entries that isSafePath does not accept for the extraction
		// directory.
		target := filepath.Join(tmpPath, header.Name)
		if !isSafePath(target, tmpPath) {
			t.Fatalf("The model file contains an invalid path: %s", header.Name)
		}

		switch header.Typeflag {
		case tar.TypeDir:
			// MkdirAll tolerates directories that already exist (such as a "./"
			// entry for the extraction root) and creates missing parents.
			if err := os.MkdirAll(target, 0755); err != nil {
				t.Fatalf("could not make the dir %s: %v", header.Name, err)
			}
		case tar.TypeReg:
			// Archives are not required to list a directory before its files.
			if err := os.MkdirAll(filepath.Dir(target), 0755); err != nil {
				t.Fatalf("could not make the dir for %s: %v", header.Name, err)
			}

			outFile, err := os.Create(target)
			if err != nil {
				t.Fatalf("could not create file %s: %v", header.Name, err)
			}

			// Close before reporting so the handle is released on every path
			// and a failed close is not lost.
			_, copyErr := io.Copy(outFile, tarReader)
			closeErr := outFile.Close()

			if copyErr != nil {
				t.Fatalf("could not copy file %s: %v", header.Name, copyErr)
			}

			if closeErr != nil {
				t.Fatalf("could not close file %s: %v", header.Name, closeErr)
			}

			// The directory holding saved_model.pb is what gets returned.
			rootPath, fileName := filepath.Split(header.Name)
			if fileName == "saved_model.pb" {
				model = filepath.Join(modelPath, rootPath)
			}
		default:
			t.Fatalf("could not extract file. Unknown type %v in %s",
				header.Typeflag,
				header.Name)
		}
	}

	return model
}
// containsAny reports whether s contains at least one of the given substrings.
// It returns false when substrings is empty.
func containsAny(s string, substrings []string) bool {
	for _, candidate := range substrings {
		if strings.Contains(s, candidate) {
			return true
		}
	}

	return false
}
// assertContainsAny fails the test, without stopping it, unless s contains at
// least one of the given substrings.
func assertContainsAny(t *testing.T, s string, substrings []string) {
	// Mark this function as a helper so that failures are attributed to the
	// calling test line rather than to this wrapper.
	t.Helper()

	assert.Truef(t, containsAny(s, substrings),
		"The result [%s] does not contain any of %v",
		s, substrings)
}
// testModel_LabelsFromFile runs the file-based classification checks against
// the given model: each subtest calls tensorFlow.File on one of the images
// below examplesPath and checks the labels that come back.
//
// Subtest names are prefixed with the model name. The number of labels is only
// logged when it differs from the expectation, because it is not asserted
// beyond "at least one"; the first label's name is what gets verified.
func testModel_LabelsFromFile(t *testing.T, tensorFlow *Model) {
	testName := func(name string) string {
		return fmt.Sprintf("%s/%s", tensorFlow.name, name)
	}

	t.Run(testName("chameleon_lime.jpg"), func(t *testing.T) {
		result, err := tensorFlow.File(examplesPath+"/chameleon_lime.jpg", 10)

		assert.NoError(t, err)
		assert.NotNil(t, result)
		assert.IsType(t, Labels{}, result)
		assert.GreaterOrEqual(t, len(result), 1)

		if len(result) != 1 {
			t.Logf("Expected 1 result, but found %d", len(result))
			t.Logf("Results: %#v", result)
		}

		if len(result) > 0 {
			assert.Contains(t, result[0].Name, "chameleon")
			//assert.Equal(t, 7, result[0].Uncertainty)
		}
	})

	t.Run(testName("cat_224.jpeg"), func(t *testing.T) {
		result, err := tensorFlow.File(examplesPath+"/cat_224.jpeg", 10)

		assert.NoError(t, err)
		assert.NotNil(t, result)
		assert.IsType(t, Labels{}, result)
		assert.GreaterOrEqual(t, len(result), 1)

		if len(result) != 1 {
			t.Logf("Expected 1 result, but found %d", len(result))
			t.Logf("Results: %#v", result)
		}

		if len(result) > 0 {
			assertContainsAny(t, result[0].Name, []string{"cat", "kitty"})
		}
	})

	t.Run(testName("cat_720.jpeg"), func(t *testing.T) {
		result, err := tensorFlow.File(examplesPath+"/cat_720.jpeg", 10)

		assert.NoError(t, err)
		assert.NotNil(t, result)
		assert.IsType(t, Labels{}, result)
		//assert.Equal(t, 3, len(result))
		assert.GreaterOrEqual(t, len(result), 1)

		// t.Logf("labels: %#v", result)

		if len(result) != 3 {
			t.Logf("Expected 3 result, but found %d", len(result))
			t.Logf("Results: %#v", result)
		}

		if len(result) > 0 {
			assertContainsAny(t, result[0].Name, []string{"cat", "kitty"})
		}
	})

	t.Run(testName("green.jpg"), func(t *testing.T) {
		result, err := tensorFlow.File(examplesPath+"/green.jpg", 10)

		t.Logf("labels: %#v", result)

		assert.NoError(t, err)
		assert.NotNil(t, result)
		assert.IsType(t, Labels{}, result)
		assert.GreaterOrEqual(t, len(result), 1)

		if len(result) != 1 {
			t.Logf("Expected 1 result, but found %d", len(result))
			t.Logf("Results: %#v", result)
		}

		if len(result) > 0 {
			assert.Equal(t, "outdoor", result[0].Name)
		}
	})

	t.Run(testName("not existing file"), func(t *testing.T) {
		result, err := tensorFlow.File(examplesPath+"/notexisting.jpg", 10)

		// Only inspect the message once an error is known to exist: calling
		// err.Error() on a nil error would panic instead of failing cleanly.
		if assert.Error(t, err) {
			assert.Contains(t, err.Error(), "no such file or directory")
		}

		assert.Empty(t, result)
	})

	t.Run(testName("disabled true"), func(t *testing.T) {
		// Temporarily disable the model; the flag is restored afterwards so
		// later checks on the same model are unaffected.
		tensorFlow.disabled = true
		defer func() { tensorFlow.disabled = false }()

		result, err := tensorFlow.File(examplesPath+"/chameleon_lime.jpg", 10)

		assert.Nil(t, err)

		if err != nil {
			t.Fatal(err)
		}

		// A disabled model is expected to yield a nil Labels value, no error.
		assert.Nil(t, result)
		assert.IsType(t, Labels{}, result)
		assert.Equal(t, 0, len(result))
		t.Log(result)
	})
}
// testModel_Run runs the in-memory classification checks against the given
// model: each subtest reads one example file from examplesPath into a buffer,
// passes it to tensorFlow.Run, and checks the labels that come back.
//
// The whole function is skipped in short mode. Subtest names are prefixed with
// the model name. A file that cannot be read is reported with t.Error and ends
// that subtest early.
func testModel_Run(t *testing.T, tensorFlow *Model) {
	if testing.Short() {
		t.Skip("skipping test in short mode.")
	}

	subtest := func(name string) string {
		return fmt.Sprintf("%s/%s", tensorFlow.name, name)
	}

	t.Run(subtest("chameleon_lime.jpg"), func(t *testing.T) {
		data, readErr := os.ReadFile(examplesPath + "/chameleon_lime.jpg")
		if readErr != nil {
			t.Error(readErr)
			return
		}

		labels, runErr := tensorFlow.Run(data, 10)
		t.Log(labels)
		assert.NotNil(t, labels)

		if runErr != nil {
			t.Fatal(runErr)
		}

		assert.IsType(t, Labels{}, labels)
		assert.GreaterOrEqual(t, len(labels), 1)

		// The label count is informational only; it is logged, not asserted.
		if count := len(labels); count != 1 {
			t.Logf("Expected 1 result, but found %d", count)
			t.Logf("Results: %#v", labels)
		}

		if len(labels) > 0 {
			assert.Contains(t, labels[0].Name, "chameleon")
		}
	})

	t.Run(subtest("dog_orange.jpg"), func(t *testing.T) {
		data, readErr := os.ReadFile(examplesPath + "/dog_orange.jpg")
		if readErr != nil {
			t.Error(readErr)
			return
		}

		labels, runErr := tensorFlow.Run(data, 10)
		t.Log(labels)
		assert.NotNil(t, labels)

		if runErr != nil {
			t.Fatal(runErr)
		}

		assert.IsType(t, Labels{}, labels)
		assert.GreaterOrEqual(t, len(labels), 1)

		if count := len(labels); count != 1 {
			t.Logf("Expected 1 result, but found %d", count)
			t.Logf("Results: %#v", labels)
		}

		if len(labels) > 0 {
			assertContainsAny(t, labels[0].Name, []string{"dog", "corgi"})
		}
	})

	// A non-image file must be rejected with an error and no labels.
	t.Run(subtest("Random.docx"), func(t *testing.T) {
		data, readErr := os.ReadFile(examplesPath + "/Random.docx")
		if readErr != nil {
			t.Error(readErr)
			return
		}

		labels, runErr := tensorFlow.Run(data, 10)
		assert.Empty(t, labels)
		assert.Error(t, runErr)
	})

	// This image is expected to succeed but produce no labels.
	t.Run(subtest("6720px_white.jpg"), func(t *testing.T) {
		data, readErr := os.ReadFile(examplesPath + "/6720px_white.jpg")
		if readErr != nil {
			t.Error(readErr)
			return
		}

		labels, runErr := tensorFlow.Run(data, 10)
		if runErr != nil {
			t.Fatal(runErr)
		}

		assert.Empty(t, labels)
	})

	t.Run(subtest("disabled true"), func(t *testing.T) {
		// Disable the model for this subtest only and restore it on exit.
		tensorFlow.disabled = true
		defer func() { tensorFlow.disabled = false }()

		data, readErr := os.ReadFile(examplesPath + "/dog_orange.jpg")
		if readErr != nil {
			t.Error(readErr)
			return
		}

		labels, runErr := tensorFlow.Run(data, 10)
		t.Log(labels)

		// A disabled model is expected to return a nil Labels value, no error.
		assert.Nil(t, labels)
		assert.Nil(t, runErr)
		assert.IsType(t, Labels{}, labels)
		assert.Equal(t, 0, len(labels))
	})
}