Mirror of https://github.com/photoprism/photoprism.git, synced 2025-09-26 21:01:58 +08:00.
* AI: Added support for non-BHWC models. TensorFlow models use BHWC by default; however, converted models may expect BCHW input instead. The input layout is now configurable (although it is still restricted to four dimensions) via the Shape parameter on the input definition. Model introspection also tries to deduce the input shape from the model signature.
* AI: Added more tests for enum parsing. ShapeComponent was missing from the tests.
* AI: Modified the external tests to use the new URL. The path has moved from tensorflow/vision to tensorflow/models.
* AI: Moved the builder into the model so it can be reused. This should reduce the number of allocations.
* AI: Fixed errors after the merge, mainly incorrect paths and duplicated variables.
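For context, a minimal sketch (not part of this commit) of what the new Shape parameter makes possible: describing a converted model whose input layer expects BCHW rather than the TensorFlow default BHWC. All identifiers come from the PhotoInput and ShapeComponent definitions in the diff below; the concrete values and the input tensor name are made up.

input := &tensorflow.PhotoInput{
	Name:              "serving_default_input", // hypothetical input tensor name
	Height:            224,
	Width:             224,
	ColorChannelOrder: tensorflow.RGB,
	// BCHW layout for a converted model; BHWC models can keep using
	// tensorflow.DefaultPhotoInputShape().
	Shape: []tensorflow.ShapeComponent{
		tensorflow.ShapeBatch,
		tensorflow.ShapeColor,
		tensorflow.ShapeHeight,
		tensorflow.ShapeWidth,
	},
}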
@@ -30,6 +30,7 @@ type Model struct {
labels []string
disabled bool
meta *tensorflow.ModelInfo
builder *tensorflow.ImageTensorBuilder
mutex sync.Mutex
}

@@ -59,6 +60,7 @@ func NewNasnet(modelsPath string, disabled bool) *Model {
Width: 224,
ResizeOperation: tensorflow.CenterCrop,
ColorChannelOrder: tensorflow.RGB,
Shape: tensorflow.DefaultPhotoInputShape(),
Intervals: []tensorflow.Interval{
{
Start: -1,
@@ -176,7 +178,10 @@ func (m *Model) loadLabels(modelPath string) (err error) {
log.Infof("vision: model does not seem to have tags at %s, trying %s", clean.Log(modelPath), clean.Log(m.defaultLabelsPath))
m.labels, err = tensorflow.LoadLabels(m.defaultLabelsPath, numLabels)
}
return err
if err != nil {
return fmt.Errorf("classify: could not load tags: %v", err)
}
return nil
}

// ModelLoaded tests if the TensorFlow model is loaded.
@@ -197,7 +202,7 @@ func (m *Model) loadModel() (err error) {
modelPath := path.Join(m.modelsPath, m.name)

if len(m.meta.Tags) == 0 {
infos, modelErr := tensorflow.GetModelInfo(modelPath)
infos, modelErr := tensorflow.GetModelTagsInfo(modelPath)
if modelErr != nil {
log.Errorf("classify: could not get info from model in %s (%s)", clean.Log(modelPath), clean.Error(modelErr))
} else if len(infos) == 1 {
@@ -209,9 +214,8 @@ func (m *Model) loadModel() (err error) {
}

m.model, err = tensorflow.SavedModel(modelPath, m.meta.Tags)

if err != nil {
return err
return fmt.Errorf("classify: %s. Path: %s", clean.Error(err), modelPath)
}

if !m.meta.IsComplete() {
@@ -237,6 +241,11 @@ func (m *Model) loadModel() (err error) {
}
}

m.builder, err = tensorflow.NewImageTensorBuilder(m.meta.Input)
if err != nil {
return fmt.Errorf("classify: could not create the tensor builder (%s)", clean.Error(err))
}

return m.loadLabels(modelPath)
}

@@ -310,5 +319,5 @@ func (m *Model) createTensor(image []byte) (*tf.Tensor, error) {
}
}

return tensorflow.Image(img, m.meta.Input)
return tensorflow.Image(img, m.meta.Input, m.builder)
}

@@ -22,9 +22,9 @@ const (
ExternalModelsTestLabel = "PHOTOPRISM_TEST_EXTERNAL_MODELS"
)

var baseUrl = "https://dl.photoprism.app/tensorflow/vision"
var baseUrl = "https://dl.photoprism.app/tensorflow/models"

//To avoid downloading everything again and again...
// To avoid downloading everything again and again...
//var baseUrl = "http://host.docker.internal:8000"

type ModelTestCase struct {
@@ -100,6 +100,15 @@ var modelsInfo = map[string]*ModelTestCase{
},
},
},
/* Not correctly uploaded
"vit-base-patch16-google-250811.tar.gz": {
Info: &tensorflow.ModelInfo{
Output: &tensorflow.ModelOutput{
OutputsLogits: true,
},
},
},
*/
}

func isSafePath(target, baseDir string) bool {

@@ -320,3 +320,44 @@ func TestModel_BestLabels(t *testing.T) {
assert.Empty(t, result)
})
}

func BenchmarkModel_BestLabelWithOptimization(b *testing.B) {
model := NewNasnet(assetsPath, false)
err := model.loadModel()
if err != nil {
b.Fatal(err)
}

imageBuffer, err := os.ReadFile(examplesPath + "/dog_orange.jpg")
if err != nil {
b.Fatal(err)
}

for b.Loop() {
_, err := model.Run(imageBuffer, 10)
if err != nil {
b.Fatal(err)
}
}
}

func BenchmarkModel_BestLabelsNoOptimization(b *testing.B) {
model := NewNasnet(assetsPath, false)
err := model.loadModel()
if err != nil {
b.Fatal(err)
}
model.builder = nil

imageBuffer, err := os.ReadFile(examplesPath + "/dog_orange.jpg")
if err != nil {
b.Fatal(err)
}

for b.Loop() {
_, err := model.Run(imageBuffer, 10)
if err != nil {
b.Fatal(err)
}
}
}

@@ -132,7 +132,7 @@ func (m *Model) loadModel() error {
log.Infof("nsfw: loading %s", clean.Log(filepath.Base(m.modelPath)))

if len(m.meta.Tags) == 0 {
infos, err := tensorflow.GetModelInfo(m.modelPath)
infos, err := tensorflow.GetModelTagsInfo(m.modelPath)
if err != nil {
log.Errorf("nsfw: could not get the model info at %s: %v", clean.Log(m.modelPath))
} else if len(infos) == 1 {
@@ -150,10 +150,10 @@ func (m *Model) loadModel() error {
}

if !m.meta.IsComplete() {
input, output, err := tensorflow.GetInputAndOutputFromSavedModel(m.model)
input, output, err := tensorflow.GetInputAndOutputFromSavedModel(model)
if err != nil {
log.Errorf("nsfw: could not get info from signatures: %v", err)
input, output, err = tensorflow.GuessInputAndOutput(m.model)
input, output, err = tensorflow.GuessInputAndOutput(model)
if err != nil {
return fmt.Errorf("nsfw: %w", err)
}

@@ -24,7 +24,7 @@ func ImageFromFile(fileName string, input *PhotoInput) (*tf.Tensor, error) {
if img, err := OpenImage(fileName); err != nil {
return nil, err
} else {
return Image(img, input)
return Image(img, input, nil)
}
}

@@ -39,17 +39,17 @@ func OpenImage(fileName string) (image.Image, error) {
return img, err
}

func ImageFromBytes(b []byte, input *PhotoInput) (*tf.Tensor, error) {
func ImageFromBytes(b []byte, input *PhotoInput, builder *ImageTensorBuilder) (*tf.Tensor, error) {
img, _, imgErr := image.Decode(bytes.NewReader(b))

if imgErr != nil {
return nil, imgErr
}

return Image(img, input)
return Image(img, input, builder)
}

func Image(img image.Image, input *PhotoInput) (tfTensor *tf.Tensor, err error) {
func Image(img image.Image, input *PhotoInput, builder *ImageTensorBuilder) (tfTensor *tf.Tensor, err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("tensorflow: %s (panic)\nstack: %s", r, debug.Stack())
@@ -57,14 +57,14 @@ func Image(img image.Image, input *PhotoInput) (tfTensor *tf.Tensor, err error)
}()

if input.Resolution() <= 0 {
return tfTensor, fmt.Errorf("tensorflow: resolution must be larger 0")
return tfTensor, fmt.Errorf("tensorflow: resolution must be larger than 0")
}

var tfImage [1][][][3]float32
rIndex, gIndex, bIndex := input.ColorChannelOrder.Indices()

for j := 0; j < input.Resolution(); j++ {
tfImage[0] = append(tfImage[0], make([][3]float32, input.Resolution()))
if builder == nil {
builder, err = NewImageTensorBuilder(input)
if err != nil {
return nil, err
}
}

for i := 0; i < input.Resolution(); i++ {
@@ -72,13 +72,14 @@ func Image(img image.Image, input *PhotoInput) (tfTensor *tf.Tensor, err error)
r, g, b, _ := img.At(i, j).RGBA()
//Although RGB can be disordered, we assume the input intervals are
//given in RGB order.
tfImage[0][j][i][rIndex] = convertValue(r, input.GetInterval(0))
tfImage[0][j][i][gIndex] = convertValue(g, input.GetInterval(1))
tfImage[0][j][i][bIndex] = convertValue(b, input.GetInterval(2))
builder.Set(i, j,
convertValue(r, input.GetInterval(0)),
convertValue(g, input.GetInterval(1)),
convertValue(b, input.GetInterval(2)))
}
}

return tf.NewTensor(tfImage)
return builder.BuildTensor()
}

// ImageTransform transforms the given image into a *tf.Tensor and returns it.

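As a usage note, here is a minimal sketch of how a caller could take advantage of the new builder argument to Image: create one ImageTensorBuilder per model input and reuse it for every photo, instead of passing nil, which falls back to allocating a fresh builder on each call. The helper name and loop are illustrative only and assume the code lives in the same tensorflow package; sequential use is assumed, since the builder's flat buffer is shared between calls (the classify model guards it with a mutex).

func tensorsFor(images []image.Image, input *PhotoInput) ([]*tf.Tensor, error) {
	builder, err := NewImageTensorBuilder(input) // allocated once
	if err != nil {
		return nil, err
	}

	tensors := make([]*tf.Tensor, 0, len(images))
	for _, img := range images {
		// Image overwrites the builder's buffer, so these calls must not
		// run concurrently.
		t, err := Image(img, input, builder)
		if err != nil {
			return nil, err
		}
		tensors = append(tensors, t)
	}

	return tensors, nil
}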
@@ -7,16 +7,14 @@ import (

"github.com/stretchr/testify/assert"
"github.com/wamuir/graft/tensorflow"

"github.com/photoprism/photoprism/pkg/fs"
)

var defaultImageInput = &PhotoInput{
Height: 224,
Width: 224,
Shape: DefaultPhotoInputShape(),
}

var assetsPath = fs.Abs("../../../assets")
var examplesPath = filepath.Join(assetsPath, "examples")

func TestConvertValue(t *testing.T) {
@@ -40,7 +38,11 @@ func TestImageFromBytes(t *testing.T) {
t.Fatal(err)
}

result, err := ImageFromBytes(imageBuffer, defaultImageInput)
result, err := ImageFromBytes(imageBuffer, defaultImageInput, nil)
if err != nil {
t.Fatal(err)
}

assert.Equal(t, tensorflow.DataType(0x1), result.DataType())
assert.Equal(t, int64(1), result.Shape()[0])
assert.Equal(t, int64(224), result.Shape()[2])
@@ -48,7 +50,7 @@ func TestImageFromBytes(t *testing.T) {
t.Run("Document", func(t *testing.T) {
imageBuffer, err := os.ReadFile(examplesPath + "/Random.docx")
assert.Nil(t, err)
result, err := ImageFromBytes(imageBuffer, defaultImageInput)
result, err := ImageFromBytes(imageBuffer, defaultImageInput, nil)

assert.Empty(t, result)
assert.EqualError(t, err, "image: unknown format")

@@ -5,8 +5,6 @@ import (
"fmt"
"os"
"path/filepath"
"strconv"
"strings"

pb "github.com/wamuir/graft/tensorflow/core/protobuf/for_core_protos_go_proto"
"google.golang.org/protobuf/proto"
@@ -263,6 +261,26 @@ func (o *ColorChannelOrder) UnmarshalYAML(unmarshal func(interface{}) error) err
return nil
}

// The expected shape for the input layer of a mode. Usually this shape is
// (batch, resolution, resolution, channels) but sometimes it is not.
type ShapeComponent string

const (
ShapeBatch ShapeComponent = "Batch"
ShapeWidth = "Width"
ShapeHeight = "Height"
ShapeColor = "Color"
)

func DefaultPhotoInputShape() []ShapeComponent {
return []ShapeComponent{
ShapeBatch,
ShapeHeight,
ShapeWidth,
ShapeColor,
}
}

// PhotoInput represents an input description for a photo input for a model.
type PhotoInput struct {
Name string `yaml:"Name,omitempty" json:"name,omitempty"`
@@ -272,6 +290,7 @@ type PhotoInput struct {
OutputIndex int `yaml:"Index,omitempty" json:"index,omitempty"`
Height int64 `yaml:"Height,omitempty" json:"height,omitempty"`
Width int64 `yaml:"Width,omitempty" json:"width,omitempty"`
Shape []ShapeComponent `yaml:"Shape,omitempty" json:"shape,omitempty"`
}

// IsDynamic checks if image dimensions are not defined, so the model accepts any size.
@@ -331,6 +350,10 @@ func (p *PhotoInput) Merge(other *PhotoInput) {
p.Width = other.Width
}

if p.Shape == nil && other.Shape != nil {
p.Shape = other.Shape
}

if p.ResizeOperation == UndefinedResizeOperation {
p.ResizeOperation = other.ResizeOperation
}
@@ -401,83 +424,10 @@ func (m *ModelInfo) Merge(other *ModelInfo) {

// IsComplete checks if the model input and output are defined.
func (m ModelInfo) IsComplete() bool {
return m.Input != nil && m.Output != nil
return m.Input != nil && m.Output != nil && m.Input.Shape != nil
}

// GetInputAndOutputFromMetaSignature returns the signatures from a MetaGraphDef
// and uses them to build PhotoInput and ModelOutput structs, that will complete
// ModelInfo struct.
func GetInputAndOutputFromMetaSignature(meta *pb.MetaGraphDef) (*PhotoInput, *ModelOutput, error) {
if meta == nil {
return nil, nil, fmt.Errorf("GetInputAndOutputFromSignature: nil input")
}

sig := meta.GetSignatureDef()
for k, v := range sig {
inputs := v.GetInputs()
outputs := v.GetOutputs()

if len(inputs) == 1 && len(outputs) == 1 {
_, inputTensor := GetOne(inputs)
outputVarName, outputTensor := GetOne(outputs)

if inputTensor != nil && (*inputTensor).GetTensorShape() != nil &&
outputTensor != nil && (*outputTensor).GetTensorShape() != nil {
inputDims := (*inputTensor).GetTensorShape().Dim
outputDims := (*outputTensor).GetTensorShape().Dim

if inputDims[3].GetSize() != ExpectedChannels {
log.Warnf("tensorflow: skipping signature %v because channels are expected to be %d, have %d",
k, ExpectedChannels, inputDims[3].GetSize())
}

if len(inputDims) == 4 &&
inputDims[3].GetSize() == ExpectedChannels &&
len(outputDims) == 2 {
var err error
var inputIdx, outputIdx = 0, 0

inputName, inputIndex, found := strings.Cut((*inputTensor).GetName(), ":")
if found {

inputIdx, err = strconv.Atoi(inputIndex)
if err != nil {
return nil, nil, fmt.Errorf("could not parse index %s (%s)", inputIndex, clean.Error(err))
}
}

outputName, outputIndex, found := strings.Cut((*outputTensor).GetName(), ":")
if found {

outputIdx, err = strconv.Atoi(outputIndex)
if err != nil {
return nil, nil, fmt.Errorf("could not parse index: %s (%s)", outputIndex, clean.Error(err))
}
}

return &PhotoInput{
Name: inputName,
OutputIndex: inputIdx,
Height: inputDims[1].GetSize(),
Width: inputDims[2].GetSize(),
}, &ModelOutput{
Name: outputName,
OutputIndex: outputIdx,
NumOutputs: outputDims[1].GetSize(),
OutputsLogits: strings.Contains(Deref(outputVarName, ""), "logits"),
}, nil

}
}

}

}

return nil, nil, fmt.Errorf("GetInputAndOutputFromMetaSignature: Could not find a valid signature")
}

func GetModelInfo(savedModelPath string) ([]ModelInfo, error) {
func GetModelTagsInfo(savedModelPath string) ([]ModelInfo, error) {
savedModel := filepath.Join(savedModelPath, "saved_model.pb")

data, err := os.ReadFile(savedModel)
@@ -499,20 +449,10 @@ func GetModelInfo(savedModelPath string) ([]ModelInfo, error) {

for i := range metas {
def := metas[i].GetMetaInfoDef()
input, output, modelErr := GetInputAndOutputFromMetaSignature(metas[i])

newModel := ModelInfo{
models = append(models, ModelInfo{
TFVersion: def.GetTensorflowVersion(),
Tags: def.GetTags(),
Input: input,
Output: output,
}

if modelErr != nil {
log.Errorf("vision: could not determine model inputs and outputs from TensorFlow %s signatures (%s)", newModel.TFVersion, clean.Error(modelErr))
}

models = append(models, newModel)
})
}

return models, nil

@@ -2,6 +2,7 @@ package tensorflow

import (
"encoding/json"
"path/filepath"
"testing"

"github.com/stretchr/testify/assert"
@@ -16,6 +17,22 @@ var allOperations = []ResizeOperation{
Padding,
}

func TestGetModelTagsInfo(t *testing.T) {
info, err := GetModelTagsInfo(
filepath.Join(assetsPath, "models", "nasnet"))
if err != nil {
t.Fatal(err)
}

if len(info) != 1 {
t.Fatalf("Expected 1 info but got %d", len(info))
} else if len(info[0].Tags) != 1 {
t.Fatalf("Expected 1 tag, but got %d", len(info[0].Tags))
} else if info[0].Tags[0] != "photoprism" {
t.Fatalf("Expected tag photoprism, but have %s", info[0].Tags[0])
}
}

func TestResizeOperations(t *testing.T) {
for i := range allOperations {
text := allOperations[i].String()
@@ -119,7 +136,7 @@ func TestColorChannelOrderJSON(t *testing.T) {
[]byte(exampleOrderJSON), &order)

if err != nil {
t.Fatal("could not unmarshal the example operation")
t.Fatal("could not unmarshal the example color order")
}

for i := range allColorChannelOrders {
@@ -148,7 +165,7 @@ func TestColorChannelOrderYAML(t *testing.T) {
[]byte(exampleOrderYAML), &order)

if err != nil {
t.Fatal("could not unmarshal the example operation")
t.Fatal("could not unmarshal the example color order")
}

for i := range allColorChannelOrders {
@@ -193,3 +210,68 @@ func TestOrderIndices(t *testing.T) {
assert.Equal(t, powerFx(r)+2*powerFx(g)+3*powerFx(b), int(allColorChannelOrders[i]))
}
}

var allShapeComponents = []ShapeComponent{
ShapeBatch,
ShapeWidth,
ShapeHeight,
ShapeColor,
}

const exampleShapeComponentJSON = `"Batch"`

func TestShapeComponentJSON(t *testing.T) {
var comp ShapeComponent

err := json.Unmarshal(
[]byte(exampleShapeComponentJSON), &comp)

if err != nil {
t.Fatal("could not unmarshal the example shape component")
}

for i := range allShapeComponents {
serialized, err := json.Marshal(allShapeComponents[i])
if err != nil {
t.Fatalf("could not marshal %v: %v",
allShapeComponents[i], err)
}

err = json.Unmarshal(serialized, &comp)
if err != nil {
t.Fatalf("could not unmarshal %s: %v",
string(serialized), err)
}

assert.Equal(t, comp, allShapeComponents[i])
}
}

const exampleShapeComponentYAML = "Batch"

func TestShapeComponentYAML(t *testing.T) {
var comp ShapeComponent

err := yaml.Unmarshal(
[]byte(exampleShapeComponentYAML), &comp)

if err != nil {
t.Fatal("could not unmarshal the example operation")
}

for i := range allShapeComponents {
serialized, err := yaml.Marshal(allShapeComponents[i])
if err != nil {
t.Fatalf("could not marshal %v: %v",
allShapeComponents[i], err)
}

err = yaml.Unmarshal(serialized, &comp)
if err != nil {
t.Fatalf("could not unmarshal %s: %v",
string(serialized), err)
}

assert.Equal(t, comp, allShapeComponents[i])
}
}

@@ -25,18 +25,33 @@ func SavedModel(modelPath string, tags []string) (model *tf.SavedModel, err erro
// GuessInputAndOutput tries to inspect a loaded saved model to build the
// ModelInfo struct
func GuessInputAndOutput(model *tf.SavedModel) (input *PhotoInput, output *ModelOutput, err error) {
if model == nil {
return nil, nil, fmt.Errorf("tensorflow: GuessInputAndOutput received a nil input")
}

modelOps := model.Graph.Operations()

for i := range modelOps {
if strings.HasPrefix(modelOps[i].Type(), "Placeholder") &&
modelOps[i].NumOutputs() == 1 &&
modelOps[i].Output(0).Shape().NumDimensions() == 4 &&
modelOps[i].Output(0).Shape().Size(3) == ExpectedChannels { // check the channels are 3
modelOps[i].Output(0).Shape().NumDimensions() == 4 {

shape := modelOps[i].Output(0).Shape()
input = &PhotoInput{
Name: modelOps[i].Name(),
Height: shape.Size(1),
Width: shape.Size(2),

var comps []ShapeComponent
if shape.Size(3) == ExpectedChannels {
comps = []ShapeComponent{ShapeBatch, ShapeHeight, ShapeWidth, ShapeColor}
} else if shape.Size(1) == ExpectedChannels { // check the channels are 3
comps = []ShapeComponent{ShapeBatch, ShapeColor, ShapeHeight, ShapeWidth, ShapeColor}
}

if comps != nil {
input = &PhotoInput{
Name: modelOps[i].Name(),
Height: shape.Size(1),
Width: shape.Size(2),
Shape: comps,
}
}
} else if (modelOps[i].Type() == "Softmax" || strings.HasPrefix(modelOps[i].Type(), "StatefulPartitionedCall")) &&
modelOps[i].NumOutputs() == 1 && modelOps[i].Output(0).Shape().NumDimensions() == 2 {
@@ -59,34 +74,57 @@ func GetInputAndOutputFromSavedModel(model *tf.SavedModel) (*PhotoInput, *ModelO
return nil, nil, fmt.Errorf("GetInputAndOutputFromSavedModel: nil input")
}

log.Debugf("tensorflow: found %d signatures", len(model.Signatures))
for k, v := range model.Signatures {
var photoInput *PhotoInput
var modelOutput *ModelOutput

inputs := v.Inputs
outputs := v.Outputs

if len(inputs) == 1 && len(outputs) == 1 {
_, inputTensor := GetOne(inputs)
outputVarName, outputTensor := GetOne(outputs)
if len(inputs) >= 1 && len(outputs) >= 1 {
for _, inputTensor := range inputs {
if inputTensor.Shape.NumDimensions() == 4 {
var comps []ShapeComponent
if inputTensor.Shape.Size(3) == ExpectedChannels {
comps = []ShapeComponent{ShapeBatch, ShapeHeight, ShapeWidth, ShapeColor}
} else if inputTensor.Shape.Size(1) == ExpectedChannels { // check the channels are 3
comps = []ShapeComponent{ShapeBatch, ShapeColor, ShapeHeight, ShapeWidth}
} else {
log.Debugf("tensorflow: shape %d", inputTensor.Shape.Size(1))
}

if inputTensor != nil && outputTensor != nil {
if inputTensor.Shape.Size(3) != ExpectedChannels {
log.Warnf("tensorflow: skipping signature %v because channels are expected to be %d, have %d",
k, ExpectedChannels, inputTensor.Shape.Size(3))
}
if comps == nil {
log.Warnf("tensorflow: skipping signature %v because we could not find the color component", k)
} else {
var inputIdx = 0
var err error

if inputTensor.Shape.NumDimensions() == 4 &&
inputTensor.Shape.Size(3) == ExpectedChannels &&
outputTensor.Shape.NumDimensions() == 2 {
var inputIdx, outputIdx = 0, 0
var err error
inputName, inputIndex, found := strings.Cut(inputTensor.Name, ":")
if found {
inputIdx, err = strconv.Atoi(inputIndex)
if err != nil {
return nil, nil, fmt.Errorf("could not parse index %s (%s)", inputIndex, clean.Error(err))
}
}

inputName, inputIndex, found := strings.Cut(inputTensor.Name, ":")
if found {
inputIdx, err = strconv.Atoi(inputIndex)
if err != nil {
return nil, nil, fmt.Errorf("could not parse index %s (%s)", inputIndex, clean.Error(err))
photoInput = &PhotoInput{
Name: inputName,
OutputIndex: inputIdx,
Height: inputTensor.Shape.Size(1),
Width: inputTensor.Shape.Size(2),
Shape: comps,
}
}

break
}
}

for outputVarName, outputTensor := range outputs {
var err error
var outputIdx int
if outputTensor.Shape.NumDimensions() == 2 {
outputName, outputIndex, found := strings.Cut(outputTensor.Name, ":")
if found {
outputIdx, err = strconv.Atoi(outputIndex)
@@ -95,23 +133,20 @@ func GetInputAndOutputFromSavedModel(model *tf.SavedModel) (*PhotoInput, *ModelO
}
}

return &PhotoInput{
Name: inputName,
OutputIndex: inputIdx,
Height: inputTensor.Shape.Size(1),
Width: inputTensor.Shape.Size(2),
}, &ModelOutput{
Name: outputName,
OutputIndex: outputIdx,
NumOutputs: outputTensor.Shape.Size(1),
OutputsLogits: strings.Contains(Deref(outputVarName, ""), "logits"),
}, nil

modelOutput = &ModelOutput{
Name: outputName,
OutputIndex: outputIdx,
NumOutputs: outputTensor.Shape.Size(1),
OutputsLogits: strings.Contains(outputVarName, "logits"),
}
break
}
}
}

if photoInput != nil && modelOutput != nil {
return photoInput, modelOutput, nil
}
}

return nil, nil, fmt.Errorf("GetInputAndOutputFromSignature: could not find valid signatures")
}

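To summarize the deduction rule that both GuessInputAndOutput and GetInputAndOutputFromSavedModel now apply, here is a condensed, hypothetical helper (not part of the commit): a four-dimensional input whose last dimension equals ExpectedChannels is treated as BHWC, one whose second dimension does is treated as BCHW, and anything else is left undecided so the caller can skip that signature.

func deduceShape(dims []int64) []ShapeComponent {
	if len(dims) != 4 {
		return nil
	}

	switch {
	case dims[3] == ExpectedChannels: // e.g. [1, 224, 224, 3]
		return []ShapeComponent{ShapeBatch, ShapeHeight, ShapeWidth, ShapeColor}
	case dims[1] == ExpectedChannels: // e.g. [1, 3, 224, 224]
		return []ShapeComponent{ShapeBatch, ShapeColor, ShapeHeight, ShapeWidth}
	default:
		return nil // channel dimension not found; the signature is skipped
	}
}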
internal/ai/tensorflow/model_test.go (new file, 96 lines)
@@ -0,0 +1,96 @@
package tensorflow

import (
"path/filepath"
"slices"
"testing"

"github.com/photoprism/photoprism/pkg/fs"
)

var assetsPath = fs.Abs("../../../assets")
var testDataPath = fs.Abs("testdata")

func TestTF1ModelLoad(t *testing.T) {
model, err := SavedModel(
filepath.Join(assetsPath, "models", "nasnet"),
[]string{"photoprism"})

if err != nil {
t.Fatal(err)
}

input, output, err := GetInputAndOutputFromSavedModel(model)
if err == nil {
t.Fatalf("TF1 does not have signatures, but GetInput worked")
}

input, output, err = GuessInputAndOutput(model)
if err != nil {
t.Fatal(err)
}

if input == nil {
t.Fatal("Could not get the input")
} else if output == nil {
t.Fatal("Could not get the output")
} else if input.Shape == nil {
t.Fatal("Could not get the shape")
} else {
t.Logf("Shape: %v", input.Shape)
}
}

func TestTF2ModelLoad(t *testing.T) {
model, err := SavedModel(
filepath.Join(testDataPath, "tf2"),
[]string{"serve"})

if err != nil {
t.Fatal(err)
}

input, output, err := GetInputAndOutputFromSavedModel(model)
if err != nil {
t.Fatal(err)
}

if input == nil {
t.Fatal("Could not get the input")
} else if output == nil {
t.Fatal("Could not get the output")
} else if input.Shape == nil {
t.Fatal("Could not get the shape")
} else if !slices.Equal(input.Shape, DefaultPhotoInputShape()) {
t.Fatalf("Invalid shape calculated. Expected BHWC, got %v",
input.Shape)
}
}

func TestTF2ModelBCHWLoad(t *testing.T) {
model, err := SavedModel(
filepath.Join(testDataPath, "tf2_bchw"),
[]string{"serve"})

if err != nil {
t.Fatal(err)
}

input, output, err := GetInputAndOutputFromSavedModel(model)
if err != nil {
t.Fatal(err)
}

if input == nil {
t.Fatal("Could not get the input")
} else if output == nil {
t.Fatal("Could not get the output")
} else if input.Shape == nil {
t.Fatal("Could not get the shape")
} else if !slices.Equal(input.Shape, []ShapeComponent{
ShapeBatch, ShapeColor, ShapeHeight, ShapeWidth,
}) {
t.Fatalf("Invalid shape calculated. Expected BCHW, got %v",
input.Shape)
}
}

internal/ai/tensorflow/tensor_builder.go (new file, 112 lines)
@@ -0,0 +1,112 @@
package tensorflow

import (
"errors"
"fmt"

tf "github.com/wamuir/graft/tensorflow"
)

type ImageTensorBuilder struct {
data []float32
shape []ShapeComponent
resolution int
rIndex int
gIndex int
bIndex int
}

func shapeLen(c ShapeComponent, res int) int {
switch c {
case ShapeBatch:
return 1
case ShapeHeight, ShapeWidth:
return res
case ShapeColor:
return 3
default:
return -1
}
}

func NewImageTensorBuilder(input *PhotoInput) (*ImageTensorBuilder, error) {

if len(input.Shape) != 4 {
return nil, fmt.Errorf("tensorflow: the shape length is %d and should be 4", len(input.Shape))
}

if input.Shape[0] != ShapeBatch {
return nil, errors.New("tensorflow: the first shape component must be Batch")
}

if input.Shape[1] != ShapeColor && input.Shape[3] != ShapeColor {
return nil, fmt.Errorf("tensorflow: unsupported shape %v", input.Shape)
}

totalSize := 1
for i := range input.Shape {
totalSize *= shapeLen(input.Shape[i], input.Resolution())
}

// Allocate just one big chunk
flatBuffer := make([]float32, totalSize)

rIndex, gIndex, bIndex := input.ColorChannelOrder.Indices()
return &ImageTensorBuilder{
data: flatBuffer,
shape: input.Shape,
resolution: input.Resolution(),
rIndex: rIndex,
gIndex: gIndex,
bIndex: bIndex,
}, nil
}

func (t *ImageTensorBuilder) Set(x, y int, r, g, b float32) {
t.data[t.flatIndex(x, y, t.rIndex)] = r
t.data[t.flatIndex(x, y, t.gIndex)] = g
t.data[t.flatIndex(x, y, t.bIndex)] = b
}

func (t *ImageTensorBuilder) flatIndex(x, y, c int) int {

shapeVal := func(s ShapeComponent) int {
switch s {
case ShapeBatch:
return 0
case ShapeColor:
return c
case ShapeWidth:
return x
case ShapeHeight:
return y
default:
return 0
}
}

idx := 0
for _, s := range t.shape {
idx = idx*shapeLen(s, t.resolution) + shapeVal(s)
}

return idx
}

func (t *ImageTensorBuilder) BuildTensor() (*tf.Tensor, error) {

arr := make([][][][]float32, shapeLen(t.shape[0], t.resolution))
offset := 0
for i := 0; i < shapeLen(t.shape[0], t.resolution); i++ {
arr[i] = make([][][]float32, shapeLen(t.shape[1], t.resolution))
for j := 0; j < shapeLen(t.shape[1], t.resolution); j++ {
arr[i][j] = make([][]float32, shapeLen(t.shape[2], t.resolution))
for k := 0; k < shapeLen(t.shape[2], t.resolution); k++ {
arr[i][j][k] = t.data[offset : offset+shapeLen(t.shape[3], t.resolution)]
offset += shapeLen(t.shape[3], t.resolution)
}
}
}

return tf.NewTensor(arr)
}

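A short usage sketch for the builder above, assuming a 224x224 RGB input and that the code sits in the same package; the function name is made up. For the default BHWC shape the flat index of pixel (x, y) and channel c works out to (y*224+x)*3+c, while a BCHW shape puts the same value at (c*224+y)*224+x, so callers only ever deal with Set and BuildTensor and never with the layout itself.

func buildExampleTensor() (*tf.Tensor, error) {
	input := &PhotoInput{
		Height:            224,
		Width:             224,
		ColorChannelOrder: RGB,
		// BHWC by default; a converted BCHW model would list
		// Batch, Color, Height, Width instead.
		Shape: DefaultPhotoInputShape(),
	}

	builder, err := NewImageTensorBuilder(input)
	if err != nil {
		return nil, err
	}

	// Write a single red pixel at (10, 20), then materialize the tensor.
	builder.Set(10, 20, 1.0, 0.0, 0.0)
	return builder.BuildTensor()
}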
internal/ai/tensorflow/testdata/tf2/saved_model.pb (binary, vendored, executable file; binary file not shown)
internal/ai/tensorflow/testdata/tf2_bchw/saved_model.pb (binary, vendored, new file; binary file not shown)
@@ -22,6 +22,7 @@ var (
Width: 224,
ResizeOperation: tensorflow.CenterCrop,
ColorChannelOrder: tensorflow.RGB,
Shape: tensorflow.DefaultPhotoInputShape(),
Intervals: []tensorflow.Interval{
{
Start: -1.0,
@@ -52,6 +53,7 @@ var (
Height: 224,
Width: 224,
OutputIndex: 0,
Shape: tensorflow.DefaultPhotoInputShape(),
},
Output: &tensorflow.ModelOutput{
Name: "nsfw_cls_model/final_prediction",
@@ -74,6 +76,7 @@ var (
Name: "input",
Height: 160,
Width: 160,
Shape: tensorflow.DefaultPhotoInputShape(),
OutputIndex: 0,
},
Output: &tensorflow.ModelOutput{

internal/ai/vision/testdata/vision.yml (vendored, 15 lines added)
@@ -16,6 +16,11 @@ Models:
ColorChannelOrder: RGB
Height: 224
Width: 224
Shape:
- Batch
- Height
- Width
- Color
Output:
Name: predictions/Softmax
Outputs: 1000
@@ -31,6 +36,11 @@ Models:
Name: input_tensor
Height: 224
Width: 224
Shape:
- Batch
- Height
- Width
- Color
Output:
Name: nsfw_cls_model/final_prediction
Outputs: 5
@@ -46,6 +56,11 @@ Models:
Name: input
Height: 160
Width: 160
Shape:
- Batch
- Height
- Width
- Color
Output:
Name: embeddings
Outputs: 512

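A minimal sketch of how these new Shape lists are consumed, assuming it sits next to the package's own tests (which already pull in a YAML library under the name yaml): the list unmarshals straight into the PhotoInput.Shape field added in this commit. The YAML fragment and the function name are illustrative only.

const inputYAML = `
Name: input_tensor
Height: 224
Width: 224
Shape:
  - Batch
  - Height
  - Width
  - Color
`

func parseExampleInput() (*PhotoInput, error) {
	var input PhotoInput
	if err := yaml.Unmarshal([]byte(inputYAML), &input); err != nil {
		return nil, err
	}
	// input.Shape now holds Batch, Height, Width, Color, i.e. the BHWC default.
	return &input, nil
}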