Files
photoprism/internal/ai/classify/model_external_test.go
raystlin 519a6ab34a AI: Add TensorFlow model shape detection #127 #5164
* AI: Added support for non BHWC models

TensorFlow models use BHWC by default; however, when using
converted models, the expected input may be BCHW instead. The
input layout is now configurable (although the restriction to 4 dimensions
still applies) via the Shape parameter on the input definition. In addition, the
model introspection will try to deduce the input shape from the model
signature.

* AI: Added more tests for enum parsing

ShapeComponent was missing from the tests

* AI: Modified external tests to the new url

The path has been moved from tensorflow/vision to tensorflow/models

* AI: Moved the builder to the model to reuse it

It should reduce the amount of allocations done

* AI: fixed errors after merge

Mainly incorrect paths and duplicated variables
2025-08-16 15:55:59 +02:00

473 lines
11 KiB
Go

package classify
import (
"archive/tar"
"compress/gzip"
"fmt"
"io"
"net/http"
"os"
"path"
"path/filepath"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/photoprism/photoprism/internal/ai/tensorflow"
)
const (
// DefaultResolution is the resolution applied via SetResolution to models
// whose input reports itself as dynamic.
DefaultResolution = 224
// ExternalModelsTestLabel is the name of the environment variable that must
// be set to a non-empty value for the external model tests to run.
ExternalModelsTestLabel = "PHOTOPRISM_TEST_EXTERNAL_MODELS"
)
// baseUrl is the remote location that model archives and label files are
// downloaded from; file names are appended to it with a "/" separator.
var baseUrl = "https://dl.photoprism.app/tensorflow/models"
// To avoid downloading everything again and again, the line below can be
// used instead to point the tests at a locally served copy of the files.
//var baseUrl = "http://host.docker.internal:8000"
// ModelTestCase describes one external model exercised by
// TestExternalModel_AllModels.
type ModelTestCase struct {
// Info is the model metadata passed to NewModel when the model is created.
Info *tensorflow.ModelInfo
// Labels optionally names a label file below baseUrl. When non-empty, the
// file is downloaded and stored as labels.txt in the extracted model
// directory.
Labels string
}
// modelsInfo lists the external models to test. Each key is the file name of
// a .tar.gz archive below baseUrl; the value carries the metadata handed to
// NewModel and, optionally, a separate label file to download.
var modelsInfo = map[string]*ModelTestCase{
"efficientnet-v2-tensorflow2-imagenet1k-b0-classification-v2.tar.gz": {
Info: &tensorflow.ModelInfo{
Output: &tensorflow.ModelOutput{
OutputsLogits: true,
},
},
},
// This entry sets an explicit 480x480 input size.
"efficientnet-v2-tensorflow2-imagenet1k-m-classification-v2.tar.gz": {
Info: &tensorflow.ModelInfo{
Input: &tensorflow.PhotoInput{
Height: 480,
Width: 480,
},
Output: &tensorflow.ModelOutput{
OutputsLogits: true,
},
},
},
// The only entry with a separate label file, which the test downloads in
// addition to the archive.
"efficientnet-v2-tensorflow2-imagenet21k-b0-classification-v1.tar.gz": {
Info: &tensorflow.ModelInfo{
Output: &tensorflow.ModelOutput{
OutputsLogits: true,
},
},
Labels: "labels-imagenet21k.txt",
},
// This entry sets an explicit 299x299 input size.
"inception-v3-tensorflow2-classification-v2.tar.gz": {
Info: &tensorflow.ModelInfo{
Input: &tensorflow.PhotoInput{
Height: 299,
Width: 299,
},
Output: &tensorflow.ModelOutput{
OutputsLogits: true,
},
},
},
"resnet-v2-tensorflow2-101-classification-v2.tar.gz": {
Info: &tensorflow.ModelInfo{
Output: &tensorflow.ModelOutput{
OutputsLogits: true,
},
},
},
"resnet-v2-tensorflow2-152-classification-v2.tar.gz": {
Info: &tensorflow.ModelInfo{
Output: &tensorflow.ModelOutput{
OutputsLogits: true,
},
},
},
// NOTE(review): the single interval from -1.0 to 1.0 is presumably the value
// range the model expects for its input — confirm against tensorflow.Interval.
"vision-transformer-tensorflow2-vit-b16-classification-v1.tar.gz": {
Info: &tensorflow.ModelInfo{
Input: &tensorflow.PhotoInput{
Intervals: []tensorflow.Interval{
{
Start: -1.0,
End: 1.0,
},
},
},
Output: &tensorflow.ModelOutput{
OutputsLogits: true,
},
},
},
/* Not correctly uploaded
"vit-base-patch16-google-250811.tar.gz": {
Info: &tensorflow.ModelInfo{
Output: &tensorflow.ModelOutput{
OutputsLogits: true,
},
},
},
*/
}
// isSafePath reports whether target resolves to baseDir itself or to a
// location below it, so that extracting an archive entry cannot escape the
// destination directory.
//
// target may be relative, in which case it is resolved against baseDir, or
// already absolute (for example the result of filepath.Join(baseDir, name)),
// in which case it is checked as is. It returns false if baseDir cannot be
// made absolute or if target lies outside of baseDir.
func isSafePath(target, baseDir string) bool {
	absBase, err := filepath.Abs(baseDir)
	if err != nil {
		return false
	}

	// Do not re-join an absolute target under the base directory: that would
	// hide the fact that it points somewhere else entirely.
	absTarget := filepath.Clean(target)
	if !filepath.IsAbs(absTarget) {
		absTarget = filepath.Join(absBase, absTarget)
	}

	// Compare path elements rather than raw string prefixes, so that a
	// sibling such as "/tmp/foobar" is not mistaken for a child of "/tmp/foo".
	rel, err := filepath.Rel(absBase, absTarget)
	if err != nil {
		return false
	}

	return rel != ".." && !strings.HasPrefix(rel, ".."+string(filepath.Separator))
}
// TestExternalModel_AllModels downloads every model listed in modelsInfo into
// a temporary directory, loads it, and runs the file- and buffer-based
// classification checks against it. The test is skipped unless the
// environment variable named by ExternalModelsTestLabel is non-empty.
func TestExternalModel_AllModels(t *testing.T) {
	if os.Getenv(ExternalModelsTestLabel) == "" {
		t.Skipf("Skipping external model tests. To test them add set env var %s=true",
			ExternalModelsTestLabel)
	}

	tmpPath, err := os.MkdirTemp("", "*-photoprism")
	if err != nil {
		t.Fatal(err)
	}
	defer os.RemoveAll(tmpPath)

	for k, v := range modelsInfo {
		// The subtest must use its own *testing.T: failing or skipping through
		// the parent's t from inside the subtest goroutine is not supported
		// and would attribute failures to the wrong test.
		t.Run(k, func(t *testing.T) {
			log.Infof("vision: testing model %s", k)

			downloadedModel := downloadRemoteModel(t, fmt.Sprintf("%s/%s", baseUrl, k), tmpPath)
			log.Infof("vision: model downloaded to %s", downloadedModel)

			if v.Labels != "" {
				labelsDir := filepath.Join(tmpPath, downloadedModel)
				t.Logf("vision: model path is %s", labelsDir)
				downloadLabels(t, fmt.Sprintf("%s/%s", baseUrl, v.Labels), labelsDir)
			}

			// NOTE(review): modelPath is not declared in this function; it is
			// presumably a package-level variable defined elsewhere — confirm.
			model := NewModel(tmpPath, downloadedModel, modelPath, v.Info, false)
			if err := model.loadModel(); err != nil {
				t.Fatal(err)
			}

			// Models with a dynamic input get a fixed default resolution.
			if model.meta.Input.IsDynamic() {
				model.meta.Input.SetResolution(DefaultResolution)
			}

			testModel_LabelsFromFile(t, model)
			testModel_Run(t, model)
		})
	}
}
// downloadLabels fetches the label file at url and stores it as "labels.txt"
// inside the directory dst, which must already exist.
//
// Any failure (request error, non-2xx response status, file creation, copy or
// close error) aborts the calling test via t.Fatal.
func downloadLabels(t *testing.T, url, dst string) {
	t.Helper()

	resp, err := http.Get(url)
	if err != nil {
		t.Fatal(err)
	}
	defer resp.Body.Close()

	// Without this check, an error page would silently be saved as labels.
	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		t.Fatalf("Invalid status code for url %s: %d", url, resp.StatusCode)
	}

	output, err := os.Create(filepath.Join(dst, "labels.txt"))
	if err != nil {
		t.Fatal(err)
	}

	// Close explicitly so that write errors surfacing on close are reported
	// and the handle is released even when the copy fails.
	_, copyErr := io.Copy(output, resp.Body)
	closeErr := output.Close()

	if copyErr != nil {
		t.Fatal(copyErr)
	}
	if closeErr != nil {
		t.Fatal(closeErr)
	}
}
// downloadRemoteModel downloads the gzip-compressed tar archive at url and
// extracts it into a sub-directory of tmpPath named after the archive (the
// last URL path element without its ".tar.gz" suffix).
//
// It returns the directory that contains "saved_model.pb", relative to the
// original tmpPath, or an empty string if the archive holds no such file.
// Only directories and regular files are supported; any other entry type, an
// entry that isSafePath rejects, or any I/O error aborts the calling test
// via t.Fatal.
func downloadRemoteModel(t *testing.T, url, tmpPath string) (model string) {
	t.Logf("Downloading %s to %s", url, tmpPath)

	modelPath := strings.TrimSuffix(path.Base(url), ".tar.gz")
	tmpPath = filepath.Join(tmpPath, modelPath)

	if err := os.MkdirAll(tmpPath, 0755); err != nil {
		t.Fatalf("could not make the dir %s: %v", tmpPath, err)
	}

	resp, err := http.Get(url)
	if err != nil {
		t.Fatal(err)
	}
	defer resp.Body.Close()

	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		t.Fatalf("Invalid status code for url %s: %d", url, resp.StatusCode)
	}

	uncompressedBody, err := gzip.NewReader(resp.Body)
	if err != nil {
		t.Fatal(err)
	}
	defer uncompressedBody.Close()

	tarReader := tar.NewReader(uncompressedBody)

	for {
		header, err := tarReader.Next()
		if err == io.EOF {
			break
		}
		if err != nil {
			t.Fatalf("could not extract the file: %v", err)
		}

		// Validate the entry name as stored in the archive. Joining it with
		// tmpPath first would already resolve ".." elements and hide an
		// attempt to escape the destination directory.
		if !isSafePath(header.Name, tmpPath) {
			t.Fatalf("The model file contains an invalid path: %s", header.Name)
		}

		target := filepath.Join(tmpPath, header.Name)

		switch header.Typeflag {
		case tar.TypeDir:
			// MkdirAll tolerates directories that already exist, e.g. when
			// a file below them was extracted first.
			if err := os.MkdirAll(target, 0755); err != nil {
				t.Fatalf("could not make the dir %s: %v", header.Name, err)
			}
		case tar.TypeReg:
			// Archives are not required to contain explicit directory
			// entries, so make sure the parent directory exists.
			if err := os.MkdirAll(filepath.Dir(target), 0755); err != nil {
				t.Fatalf("could not make the dir for %s: %v", header.Name, err)
			}

			outFile, err := os.Create(target)
			if err != nil {
				t.Fatalf("could not create file %s: %v", header.Name, err)
			}

			// Close before reporting errors so the handle is never leaked.
			_, copyErr := io.Copy(outFile, tarReader)
			closeErr := outFile.Close()

			if copyErr != nil {
				t.Fatalf("could not copy file %s: %v", header.Name, copyErr)
			}
			if closeErr != nil {
				t.Fatalf("could not close file %s: %v", header.Name, closeErr)
			}

			// The directory holding saved_model.pb is the model directory.
			rootPath, fileName := filepath.Split(header.Name)
			if fileName == "saved_model.pb" {
				model = filepath.Join(modelPath, rootPath)
			}
		default:
			t.Fatalf("could not extract file. Unknown type %v in %s",
				header.Typeflag,
				header.Name)
		}
	}

	return model
}
// containsAny reports whether s contains at least one of the given
// substrings. It returns false when substrings is empty.
func containsAny(s string, substrings []string) bool {
	found := false

	for _, candidate := range substrings {
		if found = strings.Contains(s, candidate); found {
			break
		}
	}

	return found
}
func assertContainsAny(t *testing.T, s string, substrings []string) {
assert.Truef(t, containsAny(s, substrings),
"The result [%s] does not contain any of %v",
s, substrings)
}
// testModel_LabelsFromFile checks the file-based classification of the given
// model against the example images in examplesPath.
//
// For each image it requires at least one label and checks the name of the
// first one; a label count that differs from the expected one is only
// logged. It also covers a missing file and a disabled model. Subtest names
// are prefixed with the model name.
func testModel_LabelsFromFile(t *testing.T, tensorFlow *Model) {
	testName := func(name string) string {
		return fmt.Sprintf("%s/%s", tensorFlow.name, name)
	}

	images := []struct {
		file     string                           // image file name below examplesPath
		expected int                              // label count that is not worth logging
		check    func(t *testing.T, first string) // assertion on the first label name
	}{
		{"chameleon_lime.jpg", 1, func(t *testing.T, first string) {
			assert.Contains(t, first, "chameleon")
		}},
		{"cat_224.jpeg", 1, func(t *testing.T, first string) {
			assertContainsAny(t, first, []string{"cat", "kitty"})
		}},
		{"cat_720.jpeg", 3, func(t *testing.T, first string) {
			assertContainsAny(t, first, []string{"cat", "kitty"})
		}},
		{"green.jpg", 1, func(t *testing.T, first string) {
			assert.Equal(t, "outdoor", first)
		}},
	}

	for _, img := range images {
		t.Run(testName(img.file), func(t *testing.T) {
			result, err := tensorFlow.File(examplesPath+"/"+img.file, 10)

			assert.NoError(t, err)
			assert.NotNil(t, result)
			assert.IsType(t, Labels{}, result)
			assert.GreaterOrEqual(t, len(result), 1)

			// Different models may return a different number of labels, so a
			// mismatch is reported for information only.
			if len(result) != img.expected {
				t.Logf("Expected %d result, but found %d", img.expected, len(result))
				t.Logf("Results: %#v", result)
			}

			if len(result) > 0 {
				img.check(t, result[0].Name)
			}
		})
	}

	t.Run(testName("not existing file"), func(t *testing.T) {
		result, err := tensorFlow.File(examplesPath+"/notexisting.jpg", 10)

		// Guard the message check: calling err.Error() on a nil error would
		// panic instead of producing a regular test failure.
		if assert.Error(t, err) {
			assert.Contains(t, err.Error(), "no such file or directory")
		}
		assert.Empty(t, result)
	})

	t.Run(testName("disabled true"), func(t *testing.T) {
		// Disable the model for this subtest only.
		tensorFlow.disabled = true
		defer func() { tensorFlow.disabled = false }()

		result, err := tensorFlow.File(examplesPath+"/chameleon_lime.jpg", 10)
		if err != nil {
			t.Fatal(err)
		}

		assert.Nil(t, result)
		assert.IsType(t, Labels{}, result)
		assert.Equal(t, 0, len(result))
		t.Log(result)
	})
}
// testModel_Run checks the buffer-based classification of the given model
// using example files read from examplesPath: two photos with an expected
// first label, a non-image document, a large white image, and a disabled
// model. The calling test is skipped in short mode. Subtest names are
// prefixed with the model name.
func testModel_Run(t *testing.T, tensorFlow *Model) {
	if testing.Short() {
		t.Skip("skipping test in short mode.")
	}

	subtest := func(name string) string {
		return fmt.Sprintf("%s/%s", tensorFlow.name, name)
	}

	t.Run(subtest("chameleon_lime.jpg"), func(t *testing.T) {
		data, readErr := os.ReadFile(examplesPath + "/chameleon_lime.jpg")
		if readErr != nil {
			t.Error(readErr)
			return
		}

		labels, err := tensorFlow.Run(data, 10)
		t.Log(labels)
		assert.NotNil(t, labels)

		if err != nil {
			t.Fatal(err)
		}

		assert.IsType(t, Labels{}, labels)
		assert.GreaterOrEqual(t, len(labels), 1)

		if count := len(labels); count != 1 {
			t.Logf("Expected 1 result, but found %d", count)
			t.Logf("Results: %#v", labels)
		}

		if len(labels) > 0 {
			assert.Contains(t, labels[0].Name, "chameleon")
		}
	})

	t.Run(subtest("dog_orange.jpg"), func(t *testing.T) {
		data, readErr := os.ReadFile(examplesPath + "/dog_orange.jpg")
		if readErr != nil {
			t.Error(readErr)
			return
		}

		labels, err := tensorFlow.Run(data, 10)
		t.Log(labels)
		assert.NotNil(t, labels)

		if err != nil {
			t.Fatal(err)
		}

		assert.IsType(t, Labels{}, labels)
		assert.GreaterOrEqual(t, len(labels), 1)

		if count := len(labels); count != 1 {
			t.Logf("Expected 1 result, but found %d", count)
			t.Logf("Results: %#v", labels)
		}

		if len(labels) > 0 {
			assertContainsAny(t, labels[0].Name, []string{"dog", "corgi"})
		}
	})

	t.Run(subtest("Random.docx"), func(t *testing.T) {
		data, readErr := os.ReadFile(examplesPath + "/Random.docx")
		if readErr != nil {
			t.Error(readErr)
			return
		}

		// A document is not an image, so classification must fail.
		labels, err := tensorFlow.Run(data, 10)
		assert.Empty(t, labels)
		assert.Error(t, err)
	})

	t.Run(subtest("6720px_white.jpg"), func(t *testing.T) {
		data, readErr := os.ReadFile(examplesPath + "/6720px_white.jpg")
		if readErr != nil {
			t.Error(readErr)
			return
		}

		labels, err := tensorFlow.Run(data, 10)
		if err != nil {
			t.Fatal(err)
		}

		assert.Empty(t, labels)
	})

	t.Run(subtest("disabled true"), func(t *testing.T) {
		// Disable the model for this subtest only.
		tensorFlow.disabled = true
		defer func() { tensorFlow.disabled = false }()

		data, readErr := os.ReadFile(examplesPath + "/dog_orange.jpg")
		if readErr != nil {
			t.Error(readErr)
			return
		}

		labels, err := tensorFlow.Run(data, 10)
		t.Log(labels)
		assert.Nil(t, labels)
		assert.Nil(t, err)
		assert.IsType(t, Labels{}, labels)
		assert.Equal(t, 0, len(labels))
	})
}