AI: Generate captions using the Ollama API #5011 #5123

Signed-off-by: Michael Mayer <michael@photoprism.app>
This commit is contained in:
Michael Mayer
2025-07-21 10:43:49 +02:00
parent f67ba0e634
commit ae42af54d8
16 changed files with 313 additions and 41 deletions

View File

@@ -8,6 +8,7 @@ import (
"io"
"net/http"
"github.com/photoprism/photoprism/internal/entity"
"github.com/photoprism/photoprism/pkg/clean"
"github.com/photoprism/photoprism/pkg/media/http/header"
)
@@ -47,10 +48,11 @@ func PerformApiRequest(apiRequest *ApiRequest, uri, method, key string) (apiResp
return apiResponse, clientErr
}
apiResponse = &ApiResponse{}
// Parse and return response, or an error if the request failed.
switch apiRequest.GetResponseFormat() {
case ApiFormatVision:
apiResponse = &ApiResponse{}
if apiJson, apiErr := io.ReadAll(clientResp.Body); apiErr != nil {
return apiResponse, apiErr
} else if apiErr = json.Unmarshal(apiJson, apiResponse); apiErr != nil {
@@ -58,6 +60,27 @@ func PerformApiRequest(apiRequest *ApiRequest, uri, method, key string) (apiResp
} else if clientResp.StatusCode >= 300 {
log.Debugf("vision: %s (status code %d)", apiJson, clientResp.StatusCode)
}
case ApiFormatOllama:
ollamaResponse := &ApiResponseOllama{}
if apiJson, apiErr := io.ReadAll(clientResp.Body); apiErr != nil {
return apiResponse, apiErr
} else if apiErr = json.Unmarshal(apiJson, ollamaResponse); apiErr != nil {
return apiResponse, apiErr
} else if clientResp.StatusCode >= 300 {
log.Debugf("vision: %s (status code %d)", apiJson, clientResp.StatusCode)
}
apiResponse.Id = apiRequest.Id
apiResponse.Code = clientResp.StatusCode
apiResponse.Model = &Model{
Name: ollamaResponse.Model,
}
apiResponse.Result.Caption = &CaptionResult{
Text: ollamaResponse.Response,
Source: entity.SrcImage,
}
default:
return apiResponse, fmt.Errorf("unsupported response format %s", clean.Log(apiRequest.responseFormat))
}

View File

@@ -6,4 +6,5 @@ const (
ApiFormatUrl ApiFormat = "url"
ApiFormatImages ApiFormat = "images"
ApiFormatVision ApiFormat = "vision"
ApiFormatOllama ApiFormat = "ollama"
)

View File

@@ -0,0 +1,85 @@
package vision
import (
"errors"
"fmt"
"os"
"time"
"github.com/photoprism/photoprism/pkg/clean"
"github.com/photoprism/photoprism/pkg/media"
"github.com/photoprism/photoprism/pkg/media/http/scheme"
"github.com/photoprism/photoprism/pkg/rnd"
)
// ApiResponseOllama represents an Ollama API service response.
//
// Each field is decoded from the JSON key named in its json tag. Err and
// HasResult only inspect Code, Error, and Result; the remaining fields are
// plain data carried over from the response body.
type ApiResponseOllama struct {
Id string `yaml:"Id,omitempty" json:"id,omitempty"`
Code int `yaml:"Code,omitempty" json:"code,omitempty"`
Error string `yaml:"Error,omitempty" json:"error,omitempty"`
Model string `yaml:"Model,omitempty" json:"model,omitempty"`
CreatedAt time.Time `yaml:"CreatedAt,omitempty" json:"created_at,omitempty"`
// Response holds the generated text returned under the "response" key.
Response string `yaml:"Response,omitempty" json:"response,omitempty"`
Done bool `yaml:"Done,omitempty" json:"done,omitempty"`
Context []int `yaml:"Context,omitempty" json:"context,omitempty"`
// NOTE(review): TotalDuration and EvalDuration are int64, while LoadDuration
// and PromptEvalDuration are int. If all four share the same unit (presumably
// nanoseconds, confirm against the Ollama API documentation), the int fields
// could overflow on 32-bit builds.
TotalDuration int64 `yaml:"TotalDuration,omitempty" json:"total_duration,omitempty"`
LoadDuration int `yaml:"LoadDuration,omitempty" json:"load_duration,omitempty"`
PromptEvalCount int `yaml:"PromptEvalCount,omitempty" json:"prompt_eval_count,omitempty"`
PromptEvalDuration int `yaml:"PromptEvalDuration,omitempty" json:"prompt_eval_duration,omitempty"`
EvalCount int `yaml:"EvalCount,omitempty" json:"eval_count,omitempty"`
EvalDuration int64 `yaml:"EvalDuration,omitempty" json:"eval_duration,omitempty"`
// Result is decoded from the "result" key and is the value that Err and
// HasResult test for emptiness.
Result ApiResult `yaml:"Result,omitempty" json:"result,omitempty"`
}
// Err reports why the request is considered failed, or nil if it is not.
//
// The checks run in order: a nil receiver, a status code of 400 or higher
// (preferring the Error message over the bare code), and finally an empty
// Result.
func (r *ApiResponseOllama) Err() error {
	switch {
	case r == nil:
		return errors.New("response is nil")
	case r.Code >= 400 && r.Error != "":
		return errors.New(r.Error)
	case r.Code >= 400:
		return fmt.Errorf("error %d", r.Code)
	case r.Result.IsEmpty():
		return errors.New("no result")
	default:
		return nil
	}
}
// HasResult reports whether the response is non-nil and its Result is not empty.
func (r *ApiResponseOllama) HasResult() bool {
	return r != nil && !r.Result.IsEmpty()
}
// NewApiRequestOllama returns a new Ollama API request with the specified images as payload.
//
// For the scheme.Data and scheme.Base64 file schemes, each entry of images is
// opened as a file path and replaced by the string media.DataBase64 returns for
// it. Any other file scheme results in an error as soon as the first image is
// processed. The returned request has a new Id from rnd.UUID(), an empty Model,
// and its response format set to ApiFormatOllama.
func NewApiRequestOllama(images Files, fileScheme scheme.Type) (*ApiRequest, error) {
	imagesData := make(Files, len(images))

	for i := range images {
		switch fileScheme {
		case scheme.Data, scheme.Base64:
			file, err := os.Open(images[i])

			if err != nil {
				return nil, fmt.Errorf("%w (create data url)", err)
			}

			imagesData[i] = media.DataBase64(file)

			// Close the file right away rather than deferring: a defer inside the
			// loop would keep every file open until the function returns. The
			// file was only read, so a close error carries no data-loss risk and
			// is deliberately ignored.
			_ = file.Close()
		default:
			return nil, fmt.Errorf("unsupported file scheme %s", clean.Log(fileScheme))
		}
	}

	return &ApiRequest{
		Id:             rnd.UUID(),
		Model:          "",
		Images:         imagesData,
		responseFormat: ApiFormatOllama,
	}, nil
}

View File

@@ -21,15 +21,56 @@ import (
type Files = []string
// ApiRequestOptions represents additional model parameters listed in the documentation.
//
// Every field is tagged with omitempty, so options left at their zero value
// (0, false, or an empty slice) are omitted from the encoded JSON instead of
// being sent.
//
// NOTE(review): because of omitempty, a boolean option cannot be sent as an
// explicit false. Confirm that none of these options defaults to true on the
// service side.
type ApiRequestOptions struct {
NumKeep int `json:"num_keep,omitempty"`
Seed int `json:"seed,omitempty"`
NumPredict int `json:"num_predict,omitempty"`
TopK int `json:"top_k,omitempty"`
TopP float64 `json:"top_p,omitempty"`
MinP float64 `json:"min_p,omitempty"`
TfsZ float64 `json:"tfs_z,omitempty"`
TypicalP float64 `json:"typical_p,omitempty"`
RepeatLastN int `json:"repeat_last_n,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
RepeatPenalty float64 `json:"repeat_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
Mirostat int `json:"mirostat,omitempty"`
MirostatTau float64 `json:"mirostat_tau,omitempty"`
MirostatEta float64 `json:"mirostat_eta,omitempty"`
PenalizeNewline bool `json:"penalize_newline,omitempty"`
Stop []string `json:"stop,omitempty"`
Numa bool `json:"numa,omitempty"`
NumCtx int `json:"num_ctx,omitempty"`
NumBatch int `json:"num_batch,omitempty"`
NumGpu int `json:"num_gpu,omitempty"`
MainGpu int `json:"main_gpu,omitempty"`
LowVram bool `json:"low_vram,omitempty"`
VocabOnly bool `json:"vocab_only,omitempty"`
UseMmap bool `json:"use_mmap,omitempty"`
UseMlock bool `json:"use_mlock,omitempty"`
NumThread int `json:"num_thread,omitempty"`
}
// ApiRequestContext represents a context parameter returned from a previous request.
// It is a type alias for []int, so values of either type are interchangeable.
type ApiRequestContext = []int
// ApiRequest represents a Vision API service request.
//
// Most fields are tagged with omitempty and are therefore left out of the
// encoded JSON when they hold their zero value. Suffix and Stream are the
// exceptions: their json tags have no omitempty, so they are always
// serialized, including as "" and false. The unexported responseFormat field
// is not part of the encoded JSON and is excluded from form binding.
type ApiRequest struct {
Id string `form:"id" yaml:"Id,omitempty" json:"id,omitempty"`
Model string `form:"model" yaml:"Model,omitempty" json:"model,omitempty"`
Version string `form:"version" yaml:"Version,omitempty" json:"version,omitempty"`
Prompt string `form:"prompt" yaml:"Prompt,omitempty" json:"prompt,omitempty"`
Url string `form:"url" yaml:"Url,omitempty" json:"url,omitempty"`
Images Files `form:"images" yaml:"Images,omitempty" json:"images,omitempty"`
responseFormat ApiFormat `form:"-"`
Id string `form:"id" yaml:"Id,omitempty" json:"id,omitempty"`
Model string `form:"model" yaml:"Model,omitempty" json:"model,omitempty"`
Version string `form:"version" yaml:"Version,omitempty" json:"version,omitempty"`
System string `form:"system" yaml:"System,omitempty" json:"system,omitempty"`
Prompt string `form:"prompt" yaml:"Prompt,omitempty" json:"prompt,omitempty"`
Suffix string `form:"suffix" yaml:"Suffix,omitempty" json:"suffix"`
Format string `form:"format" yaml:"Format,omitempty" json:"format,omitempty"`
Url string `form:"url" yaml:"Url,omitempty" json:"url,omitempty"`
Options *ApiRequestOptions `form:"options" yaml:"Options,omitempty" json:"options,omitempty"`
Context *ApiRequestContext `form:"context" yaml:"Context,omitempty" json:"context,omitempty"`
Stream bool `form:"stream" yaml:"Stream,omitempty" json:"stream"`
Images Files `form:"images" yaml:"Images,omitempty" json:"images,omitempty"`
responseFormat ApiFormat `form:"-"`
}
// NewApiRequest returns a new service API request with the specified format and payload.
@@ -43,6 +84,8 @@ func NewApiRequest(requestFormat ApiFormat, files Files, fileScheme scheme.Type)
return NewApiRequestUrl(files[0], fileScheme)
case ApiFormatImages, ApiFormatVision:
return NewApiRequestImages(files, fileScheme)
case ApiFormatOllama:
return NewApiRequestOllama(files, fileScheme)
default:
return result, errors.New("invalid request format")
}

View File

@@ -33,14 +33,15 @@ func Caption(images Files, src media.Src) (result *CaptionResult, model *Model,
return result, model, err
}
if model.Name != "" {
apiRequest.Model = model.Name
switch model.Service.RequestFormat {
case ApiFormatOllama:
apiRequest.Model, _, _ = model.Model()
default:
_, apiRequest.Model, apiRequest.Version = model.Model()
}
if model.Version != "" {
apiRequest.Version = model.Version
} else {
apiRequest.Version = "latest"
if model.System != "" {
apiRequest.System = model.System
}
if model.Prompt != "" {
@@ -52,6 +53,8 @@ func Caption(images Files, src media.Src) (result *CaptionResult, model *Model,
// Log JSON request data in trace mode.
apiRequest.WriteLog()
// Todo: Refactor response handling to support different API response formats,
// including those used by Ollama and OpenAI.
if apiResponse, err = PerformApiRequest(apiRequest, uri, method, model.EndpointKey()); err != nil {
return result, model, err
} else if apiResponse.Result.Caption == nil {

View File

@@ -53,12 +53,10 @@ func Faces(fileName string, minSize int, cacheCrop bool, expected int) (result f
return result, err
}
if model.Name != "" {
apiRequest.Model = model.Name
}
_, apiRequest.Model, apiRequest.Version = model.Model()
if model.Version != "" {
apiRequest.Version = model.Version
if model.System != "" {
apiRequest.System = model.System
}
if model.Prompt != "" {

View File

@@ -30,12 +30,15 @@ func Labels(images Files, src media.Src) (result classify.Labels, err error) {
return result, err
}
if model.Name != "" {
apiRequest.Model = model.Name
switch model.Service.RequestFormat {
case ApiFormatOllama:
apiRequest.Model, _, _ = model.Model()
default:
_, apiRequest.Model, apiRequest.Version = model.Model()
}
if model.Version != "" {
apiRequest.Version = model.Version
if model.System != "" {
apiRequest.System = model.System
}
if model.Prompt != "" {

View File

@@ -3,6 +3,7 @@ package vision
import (
"fmt"
"path/filepath"
"strings"
"sync"
"github.com/photoprism/photoprism/internal/ai/classify"
@@ -16,9 +17,8 @@ var modelMutex = sync.Mutex{}
// Default model version strings.
var (
ModelVersionNone = ""
ModelVersionLatest = "latest"
ModelVersionMobile = "Mobile"
VersionLatest = "latest"
VersionMobile = "mobile"
)
// Model represents a computer vision model configuration.
@@ -26,6 +26,7 @@ type Model struct {
Type ModelType `yaml:"Type,omitempty" json:"type,omitempty"`
Name string `yaml:"Name,omitempty" json:"name,omitempty"`
Version string `yaml:"Version,omitempty" json:"version,omitempty"`
System string `yaml:"System,omitempty" json:"system,omitempty"`
Prompt string `yaml:"Prompt,omitempty" json:"prompt,omitempty"`
Resolution int `yaml:"Resolution,omitempty" json:"resolution,omitempty"`
Service Service `yaml:"Service,omitempty" json:"Service,omitempty"`
@@ -40,6 +41,39 @@ type Model struct {
// Models represents a set of computer vision models.
type Models []*Model
// Model returns the parsed and normalized model identifier, name, and version strings.
//
// The name and version are normalized with clean.TypeLowerDash. If the name
// has the form "name:version", the version is taken from the name and the
// Version field is ignored. Otherwise the Version field is used, defaulting to
// VersionLatest when it is empty. The identifier is always "name:version".
//
// If the receiver is nil, all results are empty. If there is no usable model
// name (unset, empty after normalization, or nothing before the colon), the
// identifier and name are empty and only the normalized Version is returned.
func (m *Model) Model() (model, name, version string) {
	if m == nil {
		return "", "", ""
	}

	// Normalize the configured version once; every path that does not take
	// the version from the name needs it.
	version = clean.TypeLowerDash(m.Version)

	// Split the normalized name at the first colon to check if it contains the version.
	s := strings.SplitN(clean.TypeLowerDash(m.Name), ":", 2)

	// Use only the part before the colon as the name, so that a stray
	// separator such as in "foo:" never ends up in the returned name.
	name = s[0]

	// Return an empty identifier if no usable name was set. Checking after
	// normalization prevents identifiers like ":latest" for blank names.
	if name == "" {
		return "", "", version
	}

	// Return if the name contains both the model name and the version.
	if len(s) == 2 && s[1] != "" {
		return name + ":" + s[1], name, s[1]
	}

	// Default to "latest" if no specific version was set.
	if version == "" {
		version = VersionLatest
	}

	// Return the normalized model identifier, name, and version.
	return name + ":" + version, name, version
}
// Endpoint returns the remote service request method and endpoint URL, if any.
func (m *Model) Endpoint() (uri, method string) {
if uri, method = m.Service.Endpoint(); uri != "" && method != "" {

View File

@@ -5,6 +5,8 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/photoprism/photoprism/pkg/media/http/scheme"
)
func TestModel(t *testing.T) {
@@ -14,10 +16,45 @@ func TestModel(t *testing.T) {
ServiceUri = ""
assert.Equal(t, "https://app.localssl.dev/api/v1/vision/labels", uri)
assert.Equal(t, http.MethodPost, method)
uri, method = NasnetModel.Endpoint()
assert.Equal(t, "", uri)
assert.Equal(t, "", method)
})
t.Run("Caption", func(t *testing.T) {
uri, method := CaptionModel.Endpoint()
assert.Equal(t, "", uri)
assert.Equal(t, "", method)
model, name, version := CaptionModel.Model()
assert.Equal(t, "qwen2.5vl:latest", model)
assert.Equal(t, "qwen2.5vl", name)
assert.Equal(t, "latest", version)
})
t.Run("ParseName", func(t *testing.T) {
m := &Model{
Type: ModelTypeCaption,
Name: "deepseek-r1:1.5b",
Version: "",
Resolution: 720,
Prompt: CaptionPromptDefault,
Service: Service{
Uri: "http://foo:bar@photoprism-vision:5000/api/v1/vision/caption",
FileScheme: scheme.Data,
RequestFormat: ApiFormatVision,
ResponseFormat: ApiFormatVision,
},
}
uri, method := m.Endpoint()
assert.Equal(t, "http://foo:bar@photoprism-vision:5000/api/v1/vision/caption", uri)
assert.Equal(t, "POST", method)
model, name, version := m.Model()
assert.Equal(t, "deepseek-r1:1.5b", model)
assert.Equal(t, "deepseek-r1", name)
assert.Equal(t, "1.5b", version)
})
}
func TestParseTypes(t *testing.T) {

View File

@@ -8,29 +8,29 @@ import (
var (
NasnetModel = &Model{
Type: ModelTypeLabels,
Name: "NASNet",
Version: ModelVersionMobile,
Name: "nasnet",
Version: VersionMobile,
Resolution: 224, // Cropped image tile with 224x224 pixels.
Tags: []string{"photoprism"},
}
NsfwModel = &Model{
Type: ModelTypeNsfw,
Name: "Nsfw",
Version: ModelVersionNone,
Name: "nsfw",
Version: VersionLatest,
Resolution: 224, // Cropped image tile with 224x224 pixels.
Tags: []string{"serve"},
}
FacenetModel = &Model{
Type: ModelTypeFace,
Name: "FaceNet",
Version: ModelVersionNone,
Name: "facenet",
Version: VersionLatest,
Resolution: 160, // Cropped image tile with 160x160 pixels.
Tags: []string{"serve"},
}
CaptionModel = &Model{
Type: ModelTypeCaption,
Name: CaptionModelDefault,
Version: ModelVersionLatest,
Version: VersionLatest,
Resolution: 720, // Original aspect ratio, with a max size of 720 x 720 pixels.
Prompt: CaptionPromptDefault,
Service: Service{

View File

@@ -31,12 +31,15 @@ func Nsfw(images Files, src media.Src) (result []nsfw.Result, err error) {
return result, err
}
if model.Name != "" {
apiRequest.Model = model.Name
switch model.Service.RequestFormat {
case ApiFormatOllama:
apiRequest.Model, _, _ = model.Model()
default:
_, apiRequest.Model, apiRequest.Version = model.Model()
}
if model.Version != "" {
apiRequest.Version = model.Version
if model.System != "" {
apiRequest.System = model.System
}
if model.Prompt != "" {

View File

@@ -1,17 +1,19 @@
Models:
- Type: labels
Name: NASNet
Version: Mobile
Name: nasnet
Version: mobile
Resolution: 224
Tags:
- photoprism
- Type: nsfw
Name: Nsfw
Name: nsfw
Version: latest
Resolution: 224
Tags:
- serve
- Type: face
Name: FaceNet
Name: facenet
Version: latest
Resolution: 160
Tags:
- serve

View File

@@ -33,6 +33,15 @@ func TypeLowerUnderscore(s string) string {
return strings.ReplaceAll(TypeLower(s), " ", "_")
}
// TypeLowerDash passes the string through TypeLower and then replaces every
// space in the result with a dash. An empty input is returned unchanged.
func TypeLowerDash(s string) string {
	if len(s) == 0 {
		return ""
	}

	lower := TypeLower(s)

	return strings.ReplaceAll(lower, " ", "-")
}
// ShortType omits invalid runes, ensures a maximum length of 8 characters, and returns the result.
func ShortType(s string) string {
if s == "" {

View File

@@ -82,6 +82,24 @@ func TestTypeLowerUnderscore(t *testing.T) {
})
}
// TestTypeLowerDash verifies the TypeLowerDash output for a fixed set of inputs,
// running one subtest per case.
func TestTypeLowerDash(t *testing.T) {
	cases := []struct {
		name     string
		input    string
		expected string
	}{
		{name: "Undefined", input: " \t ", expected: ""},
		{name: "ClientCredentials", input: " Client Credentials幸", expected: "client-credentials"},
		{name: "OllamaModel", input: "Ollama Model:7b", expected: "ollama-model:7b"},
		{name: "Empty", input: "", expected: ""},
	}

	for i := range cases {
		tc := cases[i]

		t.Run(tc.name, func(t *testing.T) {
			assert.Equal(t, tc.expected, TypeLowerDash(tc.input))
		})
	}
}
func TestShortType(t *testing.T) {
t.Run("Clip", func(t *testing.T) {
result := ShortType(" 幸福 Hanzi are logograms developed for the writing of Chinese! Expressions in an index may not ...!")

View File

@@ -45,6 +45,18 @@ func DataUrl(r io.Reader) string {
return fmt.Sprintf("data:%s;base64,%s", mimeType, EncodeBase64String(data))
}
// DataBase64 reads all data from the specified io.Reader and returns it encoded
// with EncodeBase64String. It returns an empty string if reading fails or if
// no data was read.
func DataBase64(r io.Reader) string {
	raw, readErr := io.ReadAll(r)

	switch {
	case readErr != nil, len(raw) == 0:
		// Nothing usable to encode.
		return ""
	default:
		return EncodeBase64String(raw)
	}
}
// ReadUrl reads binary data from a regular file path,
// fetches its data from a remote http or https URL,
// or decodes a base64 data URL as created by DataUrl.

View File

@@ -6,6 +6,7 @@ type Type = string
const (
File Type = "file"
Data Type = "data"
Base64 Type = "base64"
Http Type = "http"
Https Type = "https"
Websocket Type = "wss"