feat: improve media handling, file conversion, and error management

- Added UUID-based "x-xai-request-id" header and 403 error handling in Grok.py
- Updated backend_api.py to handle empty media results and unsupported file types with error raising and file cleanup
- Simplified render logic in website.py by removing is_live flag and related code
- Changed "audio/wav" MIME type to "audio/x-wav" in image/__init__.py
- Added is_valid_media and is_valid_audio functions to image/__init__.py for stricter media validation
- Enhanced MarkItDown integration in markitdown/__init__.py with convert_stream method supporting non-seekable streams
- Modified _transcribe_audio.py to use recognize_faster_whisper if available, falling back to recognize_google
- Updated providers/helper.py to prioritize "text" key in to_string function
- Improved stream_read_files in files.py to skip DOWNLOADS_FILE and adjust code block formatting
- Added get_filename_from_url utility in files.py for consistent filename generation from URLs
- Enhanced download_urls in files.py to use MarkItDown for URL conversion and improved error logging
- Improved render_part and related functions in media.py to use new media validation logic and handle more cases
- Adjusted merge_media and render_messages in media.py for stricter part filtering and validation
This commit is contained in:
hlohaus
2025-05-20 22:40:12 +02:00
parent 9936c56644
commit 9461949542
9 changed files with 140 additions and 36 deletions

View File

@@ -4,6 +4,7 @@ import os
import json
import time
import asyncio
import uuid
from typing import Dict, Any, AsyncIterator
try:
@@ -13,8 +14,9 @@ except ImportError:
from ...typing import Messages, AsyncResult
from ...providers.response import JsonConversation, Reasoning, ImagePreview, ImageResponse, TitleGeneration, AuthResult, RequestLogin
from ...requests import StreamSession, get_nodriver, DEFAULT_HEADERS
from ...requests import StreamSession, get_nodriver, DEFAULT_HEADERS, merge_cookies
from ...requests.raise_for_status import raise_for_status
from ...errors import MissingAuthError
from ..base_provider import AsyncAuthedProvider, ProviderModelMixin
from ..helper import format_prompt, get_last_user_message
@@ -112,7 +114,10 @@ class Grok(AsyncAuthedProvider, ProviderModelMixin):
url = f"{cls.conversation_url}/new"
else:
url = f"{cls.conversation_url}/{conversation_id}/responses"
async with session.post(url, json=payload) as response:
async with session.post(url, json=payload, headers={"x-xai-request-id": str(uuid.uuid4())}) as response:
if response.status == 403:
raise MissingAuthError("Invalid secrets")
auth_result.cookies = merge_cookies(auth_result.cookies, response)
await raise_for_status(response)
thinking_duration = None
async for line in response.iter_lines():

View File

@@ -19,7 +19,6 @@ try:
from ...integration.markitdown import MarkItDown, StreamInfo
has_markitdown = True
except ImportError as e:
print(e)
has_markitdown = False
from ...client.service import convert_to_provider
@@ -367,10 +366,16 @@ class Backend_Api(Api):
if is_media:
os.makedirs(media_dir, exist_ok=True)
newfile = os.path.join(media_dir, filename)
media.append({"name": filename, "text": result})
elif not result and is_supported:
if result:
media.append({"name": filename, "text": result})
else:
media.append({"name": filename})
elif is_supported:
newfile = os.path.join(bucket_dir, filename)
filenames.append(filename)
else:
os.remove(copyfile)
raise ValueError(f"Unsupported file type: {filename}")
try:
os.rename(copyfile, newfile)
except OSError:

View File

@@ -14,9 +14,7 @@ def redirect_home():
return redirect('/chat')
def render(filename = "chat"):
is_live = True
if os.path.exists(DIST_DIR):
is_live = False
path = os.path.abspath(os.path.join(os.path.dirname(DIST_DIR), (filename + ("" if "." in filename else ".html"))))
return send_from_directory(os.path.dirname(path), os.path.basename(path))
try:
@@ -25,13 +23,12 @@ def render(filename = "chat"):
latest_version = version.utils.current_version
today = datetime.today().strftime('%Y-%m-%d')
cache_dir = os.path.join(get_cookies_dir(), ".gui_cache")
cache_file = os.path.join(cache_dir, f"{today}.{secure_filename(f'{filename}.{version.utils.current_version}-{latest_version}')}{'.live' if is_live else ''}.html")
cache_file = os.path.join(cache_dir, f"{today}.{secure_filename(f'{filename}.{version.utils.current_version}-{latest_version}')}.html")
if not os.path.exists(cache_file):
os.makedirs(cache_dir, exist_ok=True)
html = requests.get(f"{STATIC_URL}{filename}.html").text
if is_live:
html = html.replace("../dist/", f"dist/")
html = html.replace("\"dist/", f"\"{STATIC_URL}dist/")
html = html.replace("../dist/", f"dist/")
html = html.replace("\"dist/", f"\"{STATIC_URL}dist/")
with open(cache_file, 'w', encoding='utf-8') as f:
f.write(html)
return send_from_directory(os.path.abspath(cache_dir), os.path.basename(cache_file))

View File

@@ -25,7 +25,7 @@ EXTENSIONS_MAP: dict[str, str] = {
"gif": "image/gif",
"webp": "image/webp",
# Audio
"wav": "audio/wav",
"wav": "audio/x-wav",
"mp3": "audio/mpeg",
"flac": "audio/flac",
"opus": "audio/opus",
@@ -107,18 +107,39 @@ def is_data_an_media(data, filename: str = None) -> str:
return is_accepted_format(data)
return is_data_uri_an_image(data)
def is_valid_media(data=None, filename: str = None) -> str:
    """
    Check whether the given data and/or filename refer to accepted media.

    Args:
        data: Optional payload to inspect (bytes or a data URI string).
            May be omitted when only a filename is known.
        filename: Optional filename whose extension is looked up in
            EXTENSIONS_MAP.

    Returns:
        True when is_valid_audio() accepts the input, the image MIME type
        when the filename extension maps to an "image/" type, False when
        nothing matched and there is no data to inspect, otherwise the
        result of is_accepted_format() (bytes) or is_data_uri_an_image().
    """
    if is_valid_audio(data, filename):
        return True
    if filename:
        extension = get_extension(filename)
        if extension is not None:
            media_type = EXTENSIONS_MAP[extension]
            if media_type.startswith("image/"):
                return media_type
    # Filename-only callers supply no payload; there is nothing left to
    # inspect, so report "not valid" rather than probing None.
    if data is None:
        return False
    if isinstance(data, bytes):
        return is_accepted_format(data)
    return is_data_uri_an_image(data)
def is_data_an_audio(data_uri: str = None, filename: str = None) -> str:
if filename:
extension = get_extension(filename)
if extension is not None:
media_type = EXTENSIONS_MAP[extension]
if media_type.startswith("audio/") or media_type == "video/webm":
if media_type.startswith("audio/"):
return media_type
if isinstance(data_uri, str):
audio_format = re.match(r'^data:(audio/\w+);base64,', data_uri)
if audio_format:
return audio_format.group(1)
def is_valid_audio(data_uri: str = None, filename: str = None) -> bool:
    """
    Tell whether the input is audio in one of the accepted formats.

    The MIME type is resolved with is_data_an_audio(); the input is valid
    only when that type maps, via MEDIA_TYPE_MAP, to "wav" or "mp3".

    Args:
        data_uri: Optional data URI to inspect.
        filename: Optional filename to inspect.

    Returns:
        True for wav/mp3 audio, False otherwise.
    """
    detected_type = is_data_an_audio(data_uri, filename)
    return detected_type is not None and MEDIA_TYPE_MAP.get(detected_type) in ("wav", "mp3")
def is_data_uri_an_image(data_uri: str) -> bool:
"""
Checks if the given data URI represents an image.

View File

@@ -1,6 +1,7 @@
import re
import sys
from typing import List, Union, BinaryIO
import io
from typing import List, Union, BinaryIO, Optional, Any
from markitdown import MarkItDown as BaseMarkItDown
from markitdown._stream_info import StreamInfo
from markitdown._base_converter import DocumentConverterResult
@@ -117,4 +118,51 @@ class MarkItDown(BaseMarkItDown):
# Nothing can handle it!
raise UnsupportedFormatException(
f"Could not convert stream to Markdown. No converter attempted a conversion, suggesting that the filetype is simply not supported."
)
)
def convert_stream(
    self,
    stream: BinaryIO,
    *,
    stream_info: Optional[StreamInfo] = None,
    file_extension: Optional[str] = None,  # Deprecated -- use stream_info
    url: Optional[str] = None,  # Deprecated -- use stream_info
    **kwargs: Any,
) -> DocumentConverterResult:
    """
    Convert a binary stream, buffering it first when it cannot seek.

    Args:
        stream: Source stream. If it has no ``seekable`` attribute or
            reports that it is not seekable, its content is read fully
            into an in-memory ``io.BytesIO`` before conversion.
        stream_info: Optional base information used for guessing.
        file_extension: Deprecated; merged into the base guess as
            ``extension``.
        url: Deprecated; merged into the base guess as ``url``.
        **kwargs: Forwarded unchanged to ``self._convert``.

    Returns:
        Whatever ``self._convert`` returns for the stream and the guesses.
    """
    # Only build a base guess when the caller supplied at least one hint.
    base_guess = None
    if not (stream_info is None and file_extension is None and url is None):
        base_guess = StreamInfo() if stream_info is None else stream_info
        if file_extension is not None:
            # Deprecated -- use stream_info
            base_guess = base_guess.copy_and_update(extension=file_extension)
        if url is not None:
            # Deprecated -- use stream_info
            base_guess = base_guess.copy_and_update(url=url)

    # Streams that cannot seek are copied into memory in 4096-byte reads.
    can_seek = hasattr(stream, "seekable") and stream.seekable()
    if not can_seek:
        in_memory = io.BytesIO()
        chunk = stream.read(4096)
        while chunk:
            in_memory.write(chunk)
            chunk = stream.read(4096)
        in_memory.seek(0)
        stream = in_memory

    # Add guesses based on stream content
    stream_info_guesses: List[StreamInfo] = self._get_stream_info_guesses(
        file_stream=stream, base_guess=base_guess or StreamInfo()
    )
    return self._convert(file_stream=stream, stream_info_guesses=stream_info_guesses, **kwargs)

View File

@@ -47,5 +47,8 @@ def transcribe_audio(file_stream: BinaryIO, *, audio_format: str = "wav", langua
audio = recognizer.record(source)
if language is None:
language = "en-US"
transcript = recognizer.recognize_google(audio, language=language).strip()
try:
transcript = recognizer.recognize_faster_whisper(audio, language=language.split("-")[0]).strip()
except ImportError:
transcript = recognizer.recognize_google(audio, language=language).strip()
return "[No speech detected]" if transcript == "" else transcript.strip()

View File

@@ -12,13 +12,13 @@ def to_string(value) -> str:
if isinstance(value, str):
return value
elif isinstance(value, dict):
if "name" in value:
if "text" in value:
return value["text"]
elif "name" in value:
return ""
elif "bucket_id" in value:
bucket_dir = Path(get_bucket_dir(value.get("bucket_id")))
return "".join(read_bucket(bucket_dir))
elif value.get("type") == "text":
return value.get("text")
return ""
elif isinstance(value, list):
return "".join([to_string(v) for v in value if v.get("type", "text") == "text"])

View File

@@ -69,6 +69,11 @@ try:
has_beautifulsoup4 = True
except ImportError:
has_beautifulsoup4 = False
try:
from markitdown import MarkItDown
has_markitdown = True
except ImportError:
has_markitdown = False
from .web_search import scrape_text
from ..cookies import get_cookies_dir
@@ -169,8 +174,10 @@ def get_filenames(bucket_dir: Path):
return [filename.strip() for filename in f.readlines()]
return []
def stream_read_files(bucket_dir: Path, filenames: list, delete_files: bool = False) -> Iterator[str]:
def stream_read_files(bucket_dir: Path, filenames: list[str], delete_files: bool = False) -> Iterator[str]:
for filename in filenames:
if filename.startswith(DOWNLOADS_FILE):
continue
file_path: Path = bucket_dir / filename
if not file_path.exists() or file_path.lstat().st_size <= 0:
continue
@@ -192,7 +199,7 @@ def stream_read_files(bucket_dir: Path, filenames: list, delete_files: bool = Fa
else:
os.unlink(filepath)
continue
yield f"```{filename.replace('.md', '')}\n"
yield f"```{filename}\n"
if has_pypdf2 and filename.endswith(".pdf"):
try:
reader = PyPDF2.PdfReader(file_path)
@@ -339,6 +346,13 @@ def split_file_by_size_and_newline(input_filename, output_dir, chunk_size_bytes=
with open(output_filename, 'w', encoding='utf-8') as outfile:
outfile.write(current_chunk)
def get_filename_from_url(url: str, extension: str = ".md") -> str:
    """
    Build a deterministic filename for a URL.

    The name is "<netloc>+<path with '/' replaced by '_'>+<hash><extension>",
    where <hash> is the first 24 characters of the lowercased base32
    encoding of the URL's SHA-256 digest.

    Args:
        url: The URL to derive the filename from.
        extension: Suffix appended to the name, including the leading dot.
            Defaults to ".md"; pass the detected extension when the content
            is not Markdown.

    Returns:
        The generated filename.
    """
    parsed_url = urllib.parse.urlparse(url)
    digest = hashlib.sha256(url.encode()).digest()
    url_hash = base64.b32encode(digest).decode()[:24].lower()
    # Drop the leading "/" and flatten the remaining path separators.
    flat_path = parsed_url.path[1:].replace("/", "_")
    return f"{parsed_url.netloc}+{flat_path}+{url_hash}{extension}"
async def get_filename(response: ClientResponse) -> str:
"""
Attempts to extract a filename from an aiohttp response. Prioritizes Content-Disposition, then URL.
@@ -364,11 +378,7 @@ async def get_filename(response: ClientResponse) -> str:
if content_type and url:
extension = await get_file_extension(response)
if extension:
parsed_url = urllib.parse.urlparse(url)
sha256_hash = hashlib.sha256(url.encode()).digest()
base32_encoded = base64.b32encode(sha256_hash).decode()
url_hash = base32_encoded[:24].lower()
return f"{parsed_url.netloc}+{parsed_url.path[1:].replace('/', '_')}+{url_hash}{extension}"
return get_filename_from_url(url)
return None
@@ -442,17 +452,29 @@ async def download_urls(
) -> AsyncIterator[str]:
if lock is None:
lock = asyncio.Lock()
md = MarkItDown()
async with ClientSession(
connector=get_connector(proxy=proxy),
timeout=ClientTimeout(timeout)
) as session:
async def download_url(url: str, max_depth: int) -> str:
text_content = None
if has_markitdown:
try:
text_content = md.convert(url).text_content
if text_content:
filename = get_filename_from_url(url)
target = bucket_dir / filename
target.write_text(text_content, errors="replace")
return filename
except Exception as e:
debug.log(f"Failed to convert URL to text: {type(e).__name__}: {e}")
try:
async with session.get(url) as response:
response.raise_for_status()
filename = await get_filename(response)
if not filename:
print(f"Failed to get filename for {url}")
debug.log(f"Failed to get filename for {url}")
return None
if not is_allowed_extension(filename) and not supports_filename(filename) or filename == DOWNLOADS_FILE:
return None

View File

@@ -6,7 +6,7 @@ from typing import Iterator, Union
from pathlib import Path
from ..typing import Messages
from ..image import is_data_an_media, is_data_an_audio, to_input_audio, to_data_uri
from ..image import is_data_an_media, to_input_audio, is_valid_media, is_valid_audio, to_data_uri
from .files import get_bucket_dir, read_bucket
def render_media(bucket_id: str, name: str, url: str, as_path: bool = False, as_base64: bool = False) -> Union[str, Path]:
@@ -37,7 +37,7 @@ def render_part(part: dict) -> dict:
"type": "text",
"text": "".join(read_bucket(bucket_dir))
}
if is_data_an_audio(filename=filename):
if is_valid_audio(filename=filename):
return {
"type": "input_audio",
"input_audio": {
@@ -45,10 +45,11 @@ def render_part(part: dict) -> dict:
"format": os.path.splitext(filename)[1][1:]
}
}
return {
"type": "image_url",
"image_url": {"url": render_media(**part)}
}
if is_valid_media(filename=filename):
return {
"type": "image_url",
"image_url": {"url": render_media(**part)}
}
def merge_media(media: list, messages: list) -> Iterator:
buffer = []
@@ -57,7 +58,7 @@ def merge_media(media: list, messages: list) -> Iterator:
content = message.get("content")
if isinstance(content, list):
for part in content:
if "type" not in part and "name" in part:
if "type" not in part and "name" in part and "text" not in part:
path = render_media(**part, as_path=True)
buffer.append((path, os.path.basename(path)))
elif part.get("type") == "image_url":
@@ -71,9 +72,10 @@ def merge_media(media: list, messages: list) -> Iterator:
def render_messages(messages: Messages, media: list = None) -> Iterator:
for idx, message in enumerate(messages):
if isinstance(message["content"], list):
parts = [render_part(part) for part in message["content"] if part]
yield {
**message,
"content": [render_part(part) for part in message["content"] if part]
"content": [part for part in parts if part]
}
else:
if media is not None and idx == len(messages) - 1:
@@ -84,11 +86,12 @@ def render_messages(messages: Messages, media: list = None) -> Iterator:
"type": "input_audio",
"input_audio": to_input_audio(media_data, filename)
}
if is_data_an_audio(media_data, filename) else {
if is_valid_audio(media_data, filename) else {
"type": "image_url",
"image_url": {"url": to_data_uri(media_data)}
}
for media_data, filename in media
if is_valid_media(media_data, filename)
] + ([{"type": "text", "text": message["content"]}] if isinstance(message["content"], str) else message["content"])
}
else: