Fix load share conversation

This commit is contained in:
hlohaus
2025-06-15 00:06:16 +02:00
parent 28810e4773
commit 74b3137107
5 changed files with 40 additions and 25 deletions

View File

@@ -9,6 +9,7 @@ from inspect import signature
from ...errors import VersionNotFoundError, MissingAuthError from ...errors import VersionNotFoundError, MissingAuthError
from ...image.copy_images import copy_media, ensure_media_dir, get_media_dir from ...image.copy_images import copy_media, ensure_media_dir, get_media_dir
from ...image import get_width_height
from ...tools.run_tools import iter_run_tools from ...tools.run_tools import iter_run_tools
from ... import Provider from ... import Provider
from ...providers.base_provider import ProviderModelMixin from ...providers.base_provider import ProviderModelMixin
@@ -196,8 +197,9 @@ class Api:
media = chunk media = chunk
if download_media or chunk.get("cookies"): if download_media or chunk.get("cookies"):
chunk.alt = format_media_prompt(kwargs.get("messages"), chunk.alt) chunk.alt = format_media_prompt(kwargs.get("messages"), chunk.alt)
tags = [model, kwargs.get("aspect_ratio"), kwargs.get("resolution"), kwargs.get("width"), kwargs.get("height")] width, height = get_width_height(chunk.get("width"), chunk.get("height"))
media = asyncio.run(copy_media(chunk.get_list(), chunk.get("cookies"), chunk.get("headers"), proxy=proxy, alt=chunk.alt, tags=tags)) tags = [model, kwargs.get("aspect_ratio"), kwargs.get("resolution")]
media = asyncio.run(copy_media(chunk.get_list(), chunk.get("cookies"), chunk.get("headers"), proxy=proxy, alt=chunk.alt, tags=tags, add_url=f"width={width}&height={height}&"))
media = ImageResponse(media, chunk.alt) if isinstance(chunk, ImageResponse) else VideoResponse(media, chunk.alt) media = ImageResponse(media, chunk.alt) if isinstance(chunk, ImageResponse) else VideoResponse(media, chunk.alt)
yield self._format_json("content", str(media), urls=media.urls, alt=media.alt) yield self._format_json("content", str(media), urls=media.urls, alt=media.alt)
elif isinstance(chunk, SynthesizeData): elif isinstance(chunk, SynthesizeData):

View File

@@ -442,14 +442,14 @@ class Backend_Api(Api):
@self.app.route('/backend-api/v2/chat/<share_id>', methods=['GET']) @self.app.route('/backend-api/v2/chat/<share_id>', methods=['GET'])
def get_chat(share_id: str) -> str: def get_chat(share_id: str) -> str:
share_id = secure_filename(share_id) share_id = secure_filename(share_id)
if self.chat_cache.get(share_id, 0) == int(request.headers.get("if-none-match", 0)): if self.chat_cache.get(share_id, 0) == int(request.headers.get("if-none-match", -1)):
return jsonify({"error": {"message": "Not modified"}}), 304 return jsonify({"error": {"message": "Not modified"}}), 304
file = get_bucket_dir(share_id, "chat.json") file = get_bucket_dir(share_id, "chat.json")
if not os.path.isfile(file): if not os.path.isfile(file):
return jsonify({"error": {"message": "Not found"}}), 404 return jsonify({"error": {"message": "Not found"}}), 404
with open(file, 'r') as f: with open(file, 'r') as f:
chat_data = json.load(f) chat_data = json.load(f)
if chat_data.get("updated", 0) == int(request.headers.get("if-none-match", 0)): if chat_data.get("updated", 0) == int(request.headers.get("if-none-match", -1)):
return jsonify({"error": {"message": "Not modified"}}), 304 return jsonify({"error": {"message": "Not modified"}}), 304
self.chat_cache[share_id] = chat_data.get("updated", 0) self.chat_cache[share_id] = chat_data.get("updated", 0)
return jsonify(chat_data), 200 return jsonify(chat_data), 200

View File

@@ -310,26 +310,31 @@ def to_input_audio(audio: ImageType, filename: str = None) -> str:
def use_aspect_ratio(extra_body: dict, aspect_ratio: str) -> Image: def use_aspect_ratio(extra_body: dict, aspect_ratio: str) -> Image:
extra_body = {key: value for key, value in extra_body.items() if value is not None} extra_body = {key: value for key, value in extra_body.items() if value is not None}
if extra_body.get("width") is None or extra_body.get("height") is None: if extra_body.get("width") is None or extra_body.get("height") is None:
if aspect_ratio == "1:1": width, height = get_width_height(
aspect_ratio,
extra_body.get("width"),
extra_body.get("height")
)
extra_body = { extra_body = {
"width": extra_body.get("width", 1024), "width": width,
"height": extra_body.get("height", 1024), "height": height,
**extra_body
}
elif aspect_ratio == "16:9":
extra_body = {
"width": extra_body.get("width", 832),
"height": extra_body.get("height", 480),
**extra_body
}
elif aspect_ratio == "9:16":
extra_body = {
"width": extra_body.get("width", 480),
"height": extra_body.get("height", 832),
**extra_body **extra_body
} }
return extra_body return extra_body
def get_width_height(
aspect_ratio: str,
width: Optional[int] = None,
height: Optional[int] = None
) -> tuple[int, int]:
if aspect_ratio == "1:1":
return width or 1024, height or 1024
elif aspect_ratio == "16:9":
return width or 832, height or 480
elif aspect_ratio == "9:16":
return width or 480, height or 832,
return width, height
class ImageRequest: class ImageRequest:
def __init__( def __init__(
self, self,

View File

@@ -10,7 +10,7 @@ from urllib.parse import quote, unquote
from aiohttp import ClientSession, ClientError from aiohttp import ClientSession, ClientError
from urllib.parse import urlparse from urllib.parse import urlparse
from ..typing import Optional, Cookies from ..typing import Optional, Cookies, Union
from ..requests.aiohttp import get_connector from ..requests.aiohttp import get_connector
from ..image import MEDIA_TYPE_MAP, EXTENSIONS_MAP from ..image import MEDIA_TYPE_MAP, EXTENSIONS_MAP
from ..tools.files import secure_filename from ..tools.files import secure_filename
@@ -108,7 +108,7 @@ async def copy_media(
proxy: Optional[str] = None, proxy: Optional[str] = None,
alt: str = None, alt: str = None,
tags: list[str] = None, tags: list[str] = None,
add_url: bool = True, add_url: Union[bool, str] = True,
target: str = None, target: str = None,
ssl: bool = None ssl: bool = None
) -> list[str]: ) -> list[str]:
@@ -178,7 +178,7 @@ async def copy_media(
pass pass
# Build URL with safe encoding # Build URL with safe encoding
url_filename = quote(os.path.basename(target_path)) url_filename = quote(os.path.basename(target_path))
return f"/media/{url_filename}" + (('?url=' + quote(image)) if add_url and not image.startswith('data:') else '') return f"/media/{url_filename}" + (('?' + add_url if isinstance(add_url, str) else '' + 'url=' + quote(image)) if add_url and not image.startswith('data:') else '')
except (ClientError, IOError, OSError, ValueError) as e: except (ClientError, IOError, OSError, ValueError) as e:
debug.error(f"Image copying failed: {type(e).__name__}: {e}") debug.error(f"Image copying failed: {type(e).__name__}: {e}")

View File

@@ -70,7 +70,15 @@ def merge_media(media: list, messages: list) -> Iterator:
yield from media yield from media
def render_messages(messages: Messages, media: list = None) -> Iterator: def render_messages(messages: Messages, media: list = None) -> Iterator:
last_is_assistant = False
for idx, message in enumerate(messages): for idx, message in enumerate(messages):
# Remove duplicate assistant messages
if message.get("role") == "assistant":
if last_is_assistant:
continue
last_is_assistant = True
else:
last_is_assistant = False
if isinstance(message["content"], list): if isinstance(message["content"], list):
parts = [render_part(part) for part in message["content"] if part] parts = [render_part(part) for part in message["content"] if part]
yield { yield {