mirror of
https://github.com/xtekky/gpt4free.git
synced 2025-09-26 20:31:14 +08:00
feat: improve media handling, file conversion, and error management
- Added UUID-based "x-xai-request-id" header and 403 error handling in Grok.py - Updated backend_api.py to handle empty media results and unsupported file types with error raising and file cleanup - Simplified render logic in website.py by removing is_live flag and related code - Changed "audio/wav" MIME type to "audio/x-wav" in image/__init__.py - Added is_valid_media and is_valid_audio functions to image/__init__.py for stricter media validation - Enhanced MarkItDown integration in markitdown/__init__.py with convert_stream method supporting non-seekable streams - Modified _transcribe_audio.py to use recognize_faster_whisper if available, fallback to recognize_google - Updated providers/helper.py to prioritize "text" key in to_string function - Improved stream_read_files in files.py to skip DOWNLOADS_FILE and adjust code block formatting - Added get_filename_from_url utility in files.py for consistent filename generation from URLs - Enhanced download_urls in files.py to use MarkItDown for URL conversion and improved error logging - Improved render_part and related functions in media.py to use new media validation logic and handle more cases - Adjusted merge_media and render_messages in media.py for stricter part filtering and validation
This commit is contained in:
@@ -4,6 +4,7 @@ import os
|
||||
import json
|
||||
import time
|
||||
import asyncio
|
||||
import uuid
|
||||
from typing import Dict, Any, AsyncIterator
|
||||
|
||||
try:
|
||||
@@ -13,8 +14,9 @@ except ImportError:
|
||||
|
||||
from ...typing import Messages, AsyncResult
|
||||
from ...providers.response import JsonConversation, Reasoning, ImagePreview, ImageResponse, TitleGeneration, AuthResult, RequestLogin
|
||||
from ...requests import StreamSession, get_nodriver, DEFAULT_HEADERS
|
||||
from ...requests import StreamSession, get_nodriver, DEFAULT_HEADERS, merge_cookies
|
||||
from ...requests.raise_for_status import raise_for_status
|
||||
from ...errors import MissingAuthError
|
||||
from ..base_provider import AsyncAuthedProvider, ProviderModelMixin
|
||||
from ..helper import format_prompt, get_last_user_message
|
||||
|
||||
@@ -112,7 +114,10 @@ class Grok(AsyncAuthedProvider, ProviderModelMixin):
|
||||
url = f"{cls.conversation_url}/new"
|
||||
else:
|
||||
url = f"{cls.conversation_url}/{conversation_id}/responses"
|
||||
async with session.post(url, json=payload) as response:
|
||||
async with session.post(url, json=payload, headers={"x-xai-request-id": str(uuid.uuid4())}) as response:
|
||||
if response.status == 403:
|
||||
raise MissingAuthError("Invalid secrets")
|
||||
auth_result.cookies = merge_cookies(auth_result.cookies, response)
|
||||
await raise_for_status(response)
|
||||
thinking_duration = None
|
||||
async for line in response.iter_lines():
|
||||
|
@@ -19,7 +19,6 @@ try:
|
||||
from ...integration.markitdown import MarkItDown, StreamInfo
|
||||
has_markitdown = True
|
||||
except ImportError as e:
|
||||
print(e)
|
||||
has_markitdown = False
|
||||
|
||||
from ...client.service import convert_to_provider
|
||||
@@ -367,10 +366,16 @@ class Backend_Api(Api):
|
||||
if is_media:
|
||||
os.makedirs(media_dir, exist_ok=True)
|
||||
newfile = os.path.join(media_dir, filename)
|
||||
media.append({"name": filename, "text": result})
|
||||
elif not result and is_supported:
|
||||
if result:
|
||||
media.append({"name": filename, "text": result})
|
||||
else:
|
||||
media.append({"name": filename})
|
||||
elif is_supported:
|
||||
newfile = os.path.join(bucket_dir, filename)
|
||||
filenames.append(filename)
|
||||
else:
|
||||
os.remove(copyfile)
|
||||
raise ValueError(f"Unsupported file type: {filename}")
|
||||
try:
|
||||
os.rename(copyfile, newfile)
|
||||
except OSError:
|
||||
|
@@ -14,9 +14,7 @@ def redirect_home():
|
||||
return redirect('/chat')
|
||||
|
||||
def render(filename = "chat"):
|
||||
is_live = True
|
||||
if os.path.exists(DIST_DIR):
|
||||
is_live = False
|
||||
path = os.path.abspath(os.path.join(os.path.dirname(DIST_DIR), (filename + ("" if "." in filename else ".html"))))
|
||||
return send_from_directory(os.path.dirname(path), os.path.basename(path))
|
||||
try:
|
||||
@@ -25,13 +23,12 @@ def render(filename = "chat"):
|
||||
latest_version = version.utils.current_version
|
||||
today = datetime.today().strftime('%Y-%m-%d')
|
||||
cache_dir = os.path.join(get_cookies_dir(), ".gui_cache")
|
||||
cache_file = os.path.join(cache_dir, f"{today}.{secure_filename(f'{filename}.{version.utils.current_version}-{latest_version}')}{'.live' if is_live else ''}.html")
|
||||
cache_file = os.path.join(cache_dir, f"{today}.{secure_filename(f'{filename}.{version.utils.current_version}-{latest_version}')}.html")
|
||||
if not os.path.exists(cache_file):
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
html = requests.get(f"{STATIC_URL}{filename}.html").text
|
||||
if is_live:
|
||||
html = html.replace("../dist/", f"dist/")
|
||||
html = html.replace("\"dist/", f"\"{STATIC_URL}dist/")
|
||||
html = html.replace("../dist/", f"dist/")
|
||||
html = html.replace("\"dist/", f"\"{STATIC_URL}dist/")
|
||||
with open(cache_file, 'w', encoding='utf-8') as f:
|
||||
f.write(html)
|
||||
return send_from_directory(os.path.abspath(cache_dir), os.path.basename(cache_file))
|
||||
|
@@ -25,7 +25,7 @@ EXTENSIONS_MAP: dict[str, str] = {
|
||||
"gif": "image/gif",
|
||||
"webp": "image/webp",
|
||||
# Audio
|
||||
"wav": "audio/wav",
|
||||
"wav": "audio/x-wav",
|
||||
"mp3": "audio/mpeg",
|
||||
"flac": "audio/flac",
|
||||
"opus": "audio/opus",
|
||||
@@ -107,18 +107,39 @@ def is_data_an_media(data, filename: str = None) -> str:
|
||||
return is_accepted_format(data)
|
||||
return is_data_uri_an_image(data)
|
||||
|
||||
def is_valid_media(data, filename: str = None) -> str:
|
||||
if is_valid_audio(data, filename):
|
||||
return True
|
||||
if filename:
|
||||
extension = get_extension(filename)
|
||||
if extension is not None:
|
||||
media_type = EXTENSIONS_MAP[extension]
|
||||
if media_type.startswith("image/"):
|
||||
return media_type
|
||||
if isinstance(data, bytes):
|
||||
return is_accepted_format(data)
|
||||
return is_data_uri_an_image(data)
|
||||
|
||||
def is_data_an_audio(data_uri: str = None, filename: str = None) -> str:
|
||||
if filename:
|
||||
extension = get_extension(filename)
|
||||
if extension is not None:
|
||||
media_type = EXTENSIONS_MAP[extension]
|
||||
if media_type.startswith("audio/") or media_type == "video/webm":
|
||||
if media_type.startswith("audio/"):
|
||||
return media_type
|
||||
if isinstance(data_uri, str):
|
||||
audio_format = re.match(r'^data:(audio/\w+);base64,', data_uri)
|
||||
if audio_format:
|
||||
return audio_format.group(1)
|
||||
|
||||
def is_valid_audio(data_uri: str = None, filename: str = None) -> bool:
|
||||
mimetype = is_data_an_audio(data_uri, filename)
|
||||
if mimetype is None:
|
||||
return False
|
||||
if MEDIA_TYPE_MAP.get(mimetype) not in ("wav", "mp3"):
|
||||
return False
|
||||
return True
|
||||
|
||||
def is_data_uri_an_image(data_uri: str) -> bool:
|
||||
"""
|
||||
Checks if the given data URI represents an image.
|
||||
|
@@ -1,6 +1,7 @@
|
||||
import re
|
||||
import sys
|
||||
from typing import List, Union, BinaryIO
|
||||
import io
|
||||
from typing import List, Union, BinaryIO, Optional, Any
|
||||
from markitdown import MarkItDown as BaseMarkItDown
|
||||
from markitdown._stream_info import StreamInfo
|
||||
from markitdown._base_converter import DocumentConverterResult
|
||||
@@ -117,4 +118,51 @@ class MarkItDown(BaseMarkItDown):
|
||||
# Nothing can handle it!
|
||||
raise UnsupportedFormatException(
|
||||
f"Could not convert stream to Markdown. No converter attempted a conversion, suggesting that the filetype is simply not supported."
|
||||
)
|
||||
)
|
||||
|
||||
def convert_stream(
|
||||
self,
|
||||
stream: BinaryIO,
|
||||
*,
|
||||
stream_info: Optional[StreamInfo] = None,
|
||||
file_extension: Optional[str] = None, # Deprecated -- use stream_info
|
||||
url: Optional[str] = None, # Deprecated -- use stream_info
|
||||
**kwargs: Any,
|
||||
) -> DocumentConverterResult:
|
||||
guesses: List[StreamInfo] = []
|
||||
|
||||
# Do we have anything on which to base a guess?
|
||||
base_guess = None
|
||||
if stream_info is not None or file_extension is not None or url is not None:
|
||||
# Start with a non-Null base guess
|
||||
if stream_info is None:
|
||||
base_guess = StreamInfo()
|
||||
else:
|
||||
base_guess = stream_info
|
||||
|
||||
if file_extension is not None:
|
||||
# Deprecated -- use stream_info
|
||||
assert base_guess is not None # for mypy
|
||||
base_guess = base_guess.copy_and_update(extension=file_extension)
|
||||
|
||||
if url is not None:
|
||||
# Deprecated -- use stream_info
|
||||
assert base_guess is not None # for mypy
|
||||
base_guess = base_guess.copy_and_update(url=url)
|
||||
|
||||
# Check if we have a seekable stream. If not, load the entire stream into memory.
|
||||
if not hasattr(stream, "seekable") or not stream.seekable():
|
||||
buffer = io.BytesIO()
|
||||
while True:
|
||||
chunk = stream.read(4096)
|
||||
if not chunk:
|
||||
break
|
||||
buffer.write(chunk)
|
||||
buffer.seek(0)
|
||||
stream = buffer
|
||||
|
||||
# Add guesses based on stream content
|
||||
guesses = self._get_stream_info_guesses(
|
||||
file_stream=stream, base_guess=base_guess or StreamInfo()
|
||||
)
|
||||
return self._convert(file_stream=stream, stream_info_guesses=guesses, **kwargs)
|
||||
|
@@ -47,5 +47,8 @@ def transcribe_audio(file_stream: BinaryIO, *, audio_format: str = "wav", langua
|
||||
audio = recognizer.record(source)
|
||||
if language is None:
|
||||
language = "en-US"
|
||||
transcript = recognizer.recognize_google(audio, language=language).strip()
|
||||
try:
|
||||
transcript = recognizer.recognize_faster_whisper(audio, language=language.split("-")[0]).strip()
|
||||
except ImportError:
|
||||
transcript = recognizer.recognize_google(audio, language=language).strip()
|
||||
return "[No speech detected]" if transcript == "" else transcript.strip()
|
||||
|
@@ -12,13 +12,13 @@ def to_string(value) -> str:
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
elif isinstance(value, dict):
|
||||
if "name" in value:
|
||||
if "text" in value:
|
||||
return value["text"]
|
||||
elif "name" in value:
|
||||
return ""
|
||||
elif "bucket_id" in value:
|
||||
bucket_dir = Path(get_bucket_dir(value.get("bucket_id")))
|
||||
return "".join(read_bucket(bucket_dir))
|
||||
elif value.get("type") == "text":
|
||||
return value.get("text")
|
||||
return ""
|
||||
elif isinstance(value, list):
|
||||
return "".join([to_string(v) for v in value if v.get("type", "text") == "text"])
|
||||
|
@@ -69,6 +69,11 @@ try:
|
||||
has_beautifulsoup4 = True
|
||||
except ImportError:
|
||||
has_beautifulsoup4 = False
|
||||
try:
|
||||
from markitdown import MarkItDown
|
||||
has_markitdown = True
|
||||
except ImportError:
|
||||
has_markitdown = False
|
||||
|
||||
from .web_search import scrape_text
|
||||
from ..cookies import get_cookies_dir
|
||||
@@ -169,8 +174,10 @@ def get_filenames(bucket_dir: Path):
|
||||
return [filename.strip() for filename in f.readlines()]
|
||||
return []
|
||||
|
||||
def stream_read_files(bucket_dir: Path, filenames: list, delete_files: bool = False) -> Iterator[str]:
|
||||
def stream_read_files(bucket_dir: Path, filenames: list[str], delete_files: bool = False) -> Iterator[str]:
|
||||
for filename in filenames:
|
||||
if filename.startswith(DOWNLOADS_FILE):
|
||||
continue
|
||||
file_path: Path = bucket_dir / filename
|
||||
if not file_path.exists() or file_path.lstat().st_size <= 0:
|
||||
continue
|
||||
@@ -192,7 +199,7 @@ def stream_read_files(bucket_dir: Path, filenames: list, delete_files: bool = Fa
|
||||
else:
|
||||
os.unlink(filepath)
|
||||
continue
|
||||
yield f"```{filename.replace('.md', '')}\n"
|
||||
yield f"```{filename}\n"
|
||||
if has_pypdf2 and filename.endswith(".pdf"):
|
||||
try:
|
||||
reader = PyPDF2.PdfReader(file_path)
|
||||
@@ -339,6 +346,13 @@ def split_file_by_size_and_newline(input_filename, output_dir, chunk_size_bytes=
|
||||
with open(output_filename, 'w', encoding='utf-8') as outfile:
|
||||
outfile.write(current_chunk)
|
||||
|
||||
def get_filename_from_url(url: str) -> str:
|
||||
parsed_url = urllib.parse.urlparse(url)
|
||||
sha256_hash = hashlib.sha256(url.encode()).digest()
|
||||
base32_encoded = base64.b32encode(sha256_hash).decode()
|
||||
url_hash = base32_encoded[:24].lower()
|
||||
return f"{parsed_url.netloc}+{parsed_url.path[1:].replace('/', '_')}+{url_hash}.md"
|
||||
|
||||
async def get_filename(response: ClientResponse) -> str:
|
||||
"""
|
||||
Attempts to extract a filename from an aiohttp response. Prioritizes Content-Disposition, then URL.
|
||||
@@ -364,11 +378,7 @@ async def get_filename(response: ClientResponse) -> str:
|
||||
if content_type and url:
|
||||
extension = await get_file_extension(response)
|
||||
if extension:
|
||||
parsed_url = urllib.parse.urlparse(url)
|
||||
sha256_hash = hashlib.sha256(url.encode()).digest()
|
||||
base32_encoded = base64.b32encode(sha256_hash).decode()
|
||||
url_hash = base32_encoded[:24].lower()
|
||||
return f"{parsed_url.netloc}+{parsed_url.path[1:].replace('/', '_')}+{url_hash}{extension}"
|
||||
return get_filename_from_url(url)
|
||||
|
||||
return None
|
||||
|
||||
@@ -442,17 +452,29 @@ async def download_urls(
|
||||
) -> AsyncIterator[str]:
|
||||
if lock is None:
|
||||
lock = asyncio.Lock()
|
||||
md = MarkItDown()
|
||||
async with ClientSession(
|
||||
connector=get_connector(proxy=proxy),
|
||||
timeout=ClientTimeout(timeout)
|
||||
) as session:
|
||||
async def download_url(url: str, max_depth: int) -> str:
|
||||
text_content = None
|
||||
if has_markitdown:
|
||||
try:
|
||||
text_content = md.convert(url).text_content
|
||||
if text_content:
|
||||
filename = get_filename_from_url(url)
|
||||
target = bucket_dir / filename
|
||||
target.write_text(text_content, errors="replace")
|
||||
return filename
|
||||
except Exception as e:
|
||||
debug.log(f"Failed to convert URL to text: {type(e).__name__}: {e}")
|
||||
try:
|
||||
async with session.get(url) as response:
|
||||
response.raise_for_status()
|
||||
filename = await get_filename(response)
|
||||
if not filename:
|
||||
print(f"Failed to get filename for {url}")
|
||||
debug.log(f"Failed to get filename for {url}")
|
||||
return None
|
||||
if not is_allowed_extension(filename) and not supports_filename(filename) or filename == DOWNLOADS_FILE:
|
||||
return None
|
||||
|
@@ -6,7 +6,7 @@ from typing import Iterator, Union
|
||||
from pathlib import Path
|
||||
|
||||
from ..typing import Messages
|
||||
from ..image import is_data_an_media, is_data_an_audio, to_input_audio, to_data_uri
|
||||
from ..image import is_data_an_media, to_input_audio, is_valid_media, is_valid_audio, to_data_uri
|
||||
from .files import get_bucket_dir, read_bucket
|
||||
|
||||
def render_media(bucket_id: str, name: str, url: str, as_path: bool = False, as_base64: bool = False) -> Union[str, Path]:
|
||||
@@ -37,7 +37,7 @@ def render_part(part: dict) -> dict:
|
||||
"type": "text",
|
||||
"text": "".join(read_bucket(bucket_dir))
|
||||
}
|
||||
if is_data_an_audio(filename=filename):
|
||||
if is_valid_audio(filename=filename):
|
||||
return {
|
||||
"type": "input_audio",
|
||||
"input_audio": {
|
||||
@@ -45,10 +45,11 @@ def render_part(part: dict) -> dict:
|
||||
"format": os.path.splitext(filename)[1][1:]
|
||||
}
|
||||
}
|
||||
return {
|
||||
"type": "image_url",
|
||||
"image_url": {"url": render_media(**part)}
|
||||
}
|
||||
if is_valid_media(filename=filename):
|
||||
return {
|
||||
"type": "image_url",
|
||||
"image_url": {"url": render_media(**part)}
|
||||
}
|
||||
|
||||
def merge_media(media: list, messages: list) -> Iterator:
|
||||
buffer = []
|
||||
@@ -57,7 +58,7 @@ def merge_media(media: list, messages: list) -> Iterator:
|
||||
content = message.get("content")
|
||||
if isinstance(content, list):
|
||||
for part in content:
|
||||
if "type" not in part and "name" in part:
|
||||
if "type" not in part and "name" in part and "text" not in part:
|
||||
path = render_media(**part, as_path=True)
|
||||
buffer.append((path, os.path.basename(path)))
|
||||
elif part.get("type") == "image_url":
|
||||
@@ -71,9 +72,10 @@ def merge_media(media: list, messages: list) -> Iterator:
|
||||
def render_messages(messages: Messages, media: list = None) -> Iterator:
|
||||
for idx, message in enumerate(messages):
|
||||
if isinstance(message["content"], list):
|
||||
parts = [render_part(part) for part in message["content"] if part]
|
||||
yield {
|
||||
**message,
|
||||
"content": [render_part(part) for part in message["content"] if part]
|
||||
"content": [part for part in parts if part]
|
||||
}
|
||||
else:
|
||||
if media is not None and idx == len(messages) - 1:
|
||||
@@ -84,11 +86,12 @@ def render_messages(messages: Messages, media: list = None) -> Iterator:
|
||||
"type": "input_audio",
|
||||
"input_audio": to_input_audio(media_data, filename)
|
||||
}
|
||||
if is_data_an_audio(media_data, filename) else {
|
||||
if is_valid_audio(media_data, filename) else {
|
||||
"type": "image_url",
|
||||
"image_url": {"url": to_data_uri(media_data)}
|
||||
}
|
||||
for media_data, filename in media
|
||||
if is_valid_media(media_data, filename)
|
||||
] + ([{"type": "text", "text": message["content"]}] if isinstance(message["content"], str) else message["content"])
|
||||
}
|
||||
else:
|
||||
|
Reference in New Issue
Block a user