fix: improve error handling and add type checks in various providers and API

- Updated error handling in g4f/Provider/DDG.py to raise ResponseError instead of yield error strings
- Replaced yield statements with raises in g4f/Provider/DDG.py for HTTP and response errors
- Added response raising in g4f/Provider/DeepInfraChat.py for image upload responses
- Included model alias validation and error raising in g4f/Provider/hf/HuggingFaceMedia.py
- Corrected model alias dictionary key in g4f/Provider/hf_space/StabilityAI_SD35Large.py
- Ensured referrer parameter default value in g4f/Provider/PollinationsImage.py
- Removed duplicate imports and adjusted get_models method in g4f/Provider/har/__init__.py
- Modified g4f/gui/server/api.py to remove unused conversation parameter in _create_response_stream
- Fixed logic to handle single exception in g4f/providers/retry_provider.py
- Added missing import of JsonConversation in g4f/providers/retry_provider.py
- Corrected stream_read_files to replace extension in return string in g4f/tools/files.py
This commit is contained in:
hlohaus
2025-05-17 10:02:13 +02:00
parent 33b3e1d431
commit 3775c1e06d
17 changed files with 239 additions and 131 deletions

View File

@@ -6,15 +6,14 @@ import time
import random
import hashlib
import asyncio
from datetime import datetime
from aiohttp import ClientSession
from ..typing import AsyncResult, Messages
from ..errors import ResponseError
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
from .helper import format_prompt, get_last_user_message
from ..providers.response import FinishReason, JsonConversation
class Conversation(JsonConversation):
message_history: Messages = []
@@ -22,7 +21,6 @@ class Conversation(JsonConversation):
self.model = model
self.message_history = []
class DDG(AsyncGeneratorProvider, ProviderModelMixin):
label = "DuckDuckGo AI Chat"
url = "https://duckduckgo.com"
@@ -161,8 +159,7 @@ class DDG(AsyncGeneratorProvider, ProviderModelMixin):
if response.status != 200:
error_text = await response.text()
if "ERR_BN_LIMIT" in error_text:
yield "Blocked by DuckDuckGo: Bot limit exceeded (ERR_BN_LIMIT)."
return
raise ResponseError("Blocked by DuckDuckGo: Bot limit exceeded (ERR_BN_LIMIT).")
if "ERR_INVALID_VQD" in error_text and retry_count < 3:
await asyncio.sleep(random.uniform(2.5, 5.5))
async for chunk in cls.create_async_generator(
@@ -170,9 +167,7 @@ class DDG(AsyncGeneratorProvider, ProviderModelMixin):
):
yield chunk
return
yield f"Error: HTTP {response.status} - {error_text}"
return
raise ResponseError(f"HTTP {response.status} - {error_text}")
full_message = ""
async for line in response.content:
line_text = line.decode("utf-8").strip()
@@ -188,8 +183,7 @@ class DDG(AsyncGeneratorProvider, ProviderModelMixin):
try:
msg = json.loads(payload)
if msg.get("action") == "error":
yield f"Error: {msg.get('type', 'unknown')}"
break
raise ResponseError(f"Error: {msg.get('type', 'unknown')}")
if "message" in msg:
content = msg["message"]
yield content
@@ -204,4 +198,4 @@ class DDG(AsyncGeneratorProvider, ProviderModelMixin):
):
yield chunk
else:
yield f"Error: {str(e)}"
raise ResponseError(f"Error: {str(e)}")

View File

@@ -45,10 +45,10 @@ class DeepInfraChat(OpenaiTemplate):
] + vision_models
model_aliases = {
"deepseek-prover-v2-671b": "deepseek-ai/DeepSeek-Prover-V2-671B",
"qwen-3-235b": "Qwen/Qwen3-235B-A22B",
"qwen-3-30b": "Qwen/Qwen3-30B-A3B",
"qwen-3-32b": "Qwen/Qwen3-32B",
"qwen-3-14b": "Qwen/Qwen3-14B",
"qwen3-235b": "Qwen/Qwen3-235B-A22B",
"qwen3-30b": "Qwen/Qwen3-30B-A3B",
"qwen3-32b": "Qwen/Qwen3-32B",
"qwen3-14b": "Qwen/Qwen3-14B",
"llama-4-maverick": "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
"llama-4-maverick-17b": "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
"llama-4-scout": "meta-llama/Llama-4-Scout-17B-16E-Instruct",

View File

@@ -254,8 +254,7 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
enhance=enhance,
safe=safe,
n=n,
referrer=referrer,
extra_body=extra_body
referrer=referrer
):
yield chunk
else:
@@ -305,10 +304,9 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
enhance: bool,
safe: bool,
n: int,
referrer: str,
extra_body: dict
referrer: str
) -> AsyncResult:
extra_body = use_aspect_ratio({
params = use_aspect_ratio({
"width": width,
"height": height,
"model": model,
@@ -316,9 +314,8 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
"private": str(private).lower(),
"enhance": str(enhance).lower(),
"safe": str(safe).lower(),
**extra_body
}, aspect_ratio)
query = "&".join(f"{k}={quote_plus(str(v))}" for k, v in extra_body.items() if v is not None)
query = "&".join(f"{k}={quote_plus(str(v))}" for k, v in params.items() if v is not None)
prompt = quote_plus(prompt)[:2048-len(cls.image_api_endpoint)-len(query)-8]
url = f"{cls.image_api_endpoint}prompt/{prompt}?{query}"
def get_image_url(i: int, seed: Optional[int] = None):

View File

@@ -37,6 +37,7 @@ class PollinationsImage(PollinationsAI):
model: str,
messages: Messages,
proxy: str = None,
referrer: str = "https://gpt4free.github.io/",
prompt: str = None,
aspect_ratio: str = "1:1",
width: int = None,
@@ -65,6 +66,7 @@ class PollinationsImage(PollinationsAI):
private=private,
enhance=enhance,
safe=safe,
n=n
n=n,
referrer=referrer
):
yield chunk

View File

@@ -5,8 +5,11 @@ import json
import uuid
from urllib.parse import urlparse
from ...typing import AsyncResult, Messages
from ...requests import StreamSession, raise_for_status
from ...typing import AsyncResult, Messages, MediaListType
from ...requests import StreamSession, StreamResponse, FormData, raise_for_status
from ...providers.response import JsonConversation
from ...tools.media import merge_media
from ...image import to_bytes, is_accepted_format
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..helper import get_last_user_message
from ..openai.har_file import get_headers
@@ -14,8 +17,54 @@ from ..openai.har_file import get_headers
class HarProvider(AsyncGeneratorProvider, ProviderModelMixin):
label = "LM Arena"
url = "https://lmarena.ai"
api_endpoint = "/queue/join?"
working = True
default_model = "chatgpt-4o-latest-20250326"
model_aliases = {
"claude-3.7-sonnet": "claude-3-7-sonnet-20250219",
}
vision_models = [
"o3-2025-04-16",
"o4-mini-2025-04-16",
"gpt-4.1-2025-04-14",
"gemini-2.5-pro-exp-03-25",
"claude-3-7-sonnet-20250219",
"claude-3-7-sonnet-20250219-thinking-32k",
"llama-4-maverick-17b-128e-instruct",
"gpt-4.1-mini-2025-04-14",
"gpt-4.1-nano-2025-04-14",
"gemini-2.0-flash-thinking-exp-01-21",
"gemini-2.0-flash-001",
"gemini-2.0-flash-lite-preview-02-05",
"claude-3-5-sonnet-20241022",
"gpt-4o-mini-2024-07-18",
"gpt-4o-2024-11-20",
"gpt-4o-2024-08-06",
"gpt-4o-2024-05-13",
"claude-3-5-sonnet-20240620",
"doubao-1.5-vision-pro-32k-250115",
"amazon-nova-pro-v1.0",
"amazon-nova-lite-v1.0",
"qwen2.5-vl-32b-instruct",
"qwen2.5-vl-72b-instruct",
"gemini-1.5-pro-002",
"gemini-1.5-flash-002",
"gemini-1.5-flash-8b-001",
"gemini-1.5-pro-001",
"gemini-1.5-flash-001",
"hunyuan-standard-vision-2024-12-31",
"pixtral-large-2411",
"step-1o-vision-32k-highres",
"claude-3-haiku-20240307",
"claude-3-sonnet-20240229",
"claude-3-opus-20240229",
"qwen-vl-max-1119",
"qwen-vl-max-0809",
"reka-core-20240904",
"reka-flash-20240904",
"c4ai-aya-vision-32b",
"pixtral-12b-2409"
]
@classmethod
def get_models(cls):
@@ -33,35 +82,58 @@ class HarProvider(AsyncGeneratorProvider, ProviderModelMixin):
break
return cls.models
@classmethod
def _build_second_payloads(cls, model_id: str, session_hash: str, text: str, max_tokens: int, temperature: float, top_p: float):
first_payload = {
"data":[None,model_id,text,{
"text_models":[model_id],
"all_text_models":[model_id],
"vision_models":[],
"image_gen_models":[],
"all_image_gen_models":[],
"search_models":[],
"all_search_models":[],
"models":[model_id],
"all_models":[model_id],
"arena_type":"text-arena"}],
"event_data": None,
"fn_index": 122,
"trigger_id": 157,
"session_hash": session_hash
}
second_payload = {
"data": [],
"event_data": None,
"fn_index": 123,
"trigger_id": 157,
"session_hash": session_hash
}
third_payload = {
"data": [None, temperature, top_p, max_tokens],
"event_data": None,
"fn_index": 124,
"trigger_id": 157,
"session_hash": session_hash
}
return first_payload, second_payload, third_payload
@classmethod
async def create_async_generator(
cls, model: str, messages: Messages,
cls,
model: str,
messages: Messages,
proxy: str = None,
media: MediaListType = None,
max_tokens: int = 2048,
temperature: float = 0.7,
top_p: float = 1,
conversation: JsonConversation = None,
**kwargs
) -> AsyncResult:
if model in cls.model_aliases:
model = cls.model_aliases[model]
session_hash = str(uuid.uuid4()).replace("-", "")
prompt = get_last_user_message(messages)
for domain, harFile in read_har_files():
async with StreamSession(impersonate="chrome") as session:
for v in harFile['log']['entries']:
request_url = v['request']['url']
if domain not in request_url or "." in urlparse(request_url).path or "heartbeat" in request_url:
continue
postData = None
if "postData" in v['request']:
postData = v['request']['postData']['text']
postData = postData.replace('"hello"', json.dumps(prompt))
postData = postData.replace("__SESSION__", session_hash)
if model:
postData = postData.replace("__MODEL__", model)
request_url = request_url.replace("__SESSION__", session_hash)
method = v['request']['method'].lower()
async with getattr(session, method)(request_url, data=postData, headers=get_headers(v), proxy=proxy) as response:
await raise_for_status(response)
async def read_response(response: StreamResponse):
returned_data = ""
async for line in response.iter_lines():
if not line.startswith(b"data: "):
@@ -78,6 +150,73 @@ class HarProvider(AsyncGeneratorProvider, ProviderModelMixin):
continue
returned_data += new_content
yield new_content
if model in cls.model_aliases:
model = cls.model_aliases[model]
prompt = get_last_user_message(messages)
async with StreamSession(impersonate="chrome") as session:
if conversation is None:
conversation = JsonConversation(session_hash=str(uuid.uuid4()).replace("-", ""))
media = list(merge_media(media, messages))
if media:
data = FormData()
for i in range(len(media)):
media[i] = (to_bytes(media[i][0]), media[i][1])
for image, image_name in media:
data.add_field(f"files", image, filename=image_name)
async with session.post(f"{cls.url}/upload", params={"upload_id": conversation.session_hash}, data=data) as response:
await raise_for_status(response)
image_files = await response.json()
media = [{
"path": image_file,
"url": f"{cls.url}/file={image_file}",
"orig_name": media[i][1],
"size": len(media[i][0]),
"mime_type": is_accepted_format(media[i][0]),
"meta": {
"_type": "gradio.FileData"
}
} for i, image_file in enumerate(image_files)]
for domain, harFile in read_har_files():
for v in harFile['log']['entries']:
request_url = v['request']['url']
if domain not in request_url or "." in urlparse(request_url).path or "heartbeat" in request_url:
continue
postData = None
if "postData" in v['request']:
postData = v['request']['postData']['text']
postData = postData.replace('"hello"', json.dumps(prompt))
postData = postData.replace('[null,0.7,1,2048]', json.dumps([None, temperature, top_p, max_tokens]))
postData = postData.replace('"files":[]', f'"files":{json.dumps(media)}')
postData = postData.replace("__SESSION__", conversation.session_hash)
if model:
postData = postData.replace("__MODEL__", model)
request_url = request_url.replace("__SESSION__", conversation.session_hash)
method = v['request']['method'].lower()
async with getattr(session, method)(request_url, data=postData, headers=get_headers(v), proxy=proxy) as response:
await raise_for_status(response)
async for chunk in read_response(response):
yield chunk
yield conversation
else:
first_payload, second_payload, third_payload = cls._build_second_payloads(model, conversation.session_hash, prompt, max_tokens, temperature, top_p)
headers = {
"Content-Type": "application/json",
"Accept": "application/json",
}
# POST 1
async with session.post(f"{cls.url}{cls.api_endpoint}", json=first_payload, proxy=proxy, headers=headers) as response:
await raise_for_status(response)
# POST 2
async with session.post(f"{cls.url}{cls.api_endpoint}", json=second_payload, proxy=proxy, headers=headers) as response:
await raise_for_status(response)
# POST 3
async with session.post(f"{cls.url}{cls.api_endpoint}", json=third_payload, proxy=proxy, headers=headers) as response:
await raise_for_status(response)
stream_url = f"{cls.url}/queue/data?session_hash={conversation.session_hash}"
async with session.get(stream_url, headers={"Accept": "text/event-stream"}, proxy=proxy) as response:
await raise_for_status(response)
async for chunk in read_response(response):
yield chunk
def read_har_files():
for root, _, files in os.walk(os.path.dirname(__file__)):

View File

@@ -129,6 +129,8 @@ class HuggingFaceMedia(AsyncGeneratorProvider, ProviderModelMixin):
if key in ["replicate", "together", "hf-inference"]
}
provider_mapping = {**new_mapping, **provider_mapping}
if not provider_mapping:
raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__}")
async def generate(extra_body: dict, aspect_ratio: str = None):
last_response = None
for provider_key, provider in provider_mapping.items():

View File

@@ -39,7 +39,7 @@ model_aliases = {
"gemma-2-27b": "google/gemma-2-27b-it",
"qwen-2-72b": "Qwen/Qwen2-72B-Instruct",
"qvq-72b": "Qwen/QVQ-72B-Preview",
"sd-3.5": "stabilityai/stable-diffusion-3.5-large",
"stable-diffusion-3.5-large": "stabilityai/stable-diffusion-3.5-large",
}
extra_models = [
"meta-llama/Llama-3.2-11B-Vision-Instruct",

View File

@@ -121,7 +121,7 @@ class DeepseekAI_JanusPro7b(AsyncGeneratorProvider, ProviderModelMixin):
}
} for i, image_file in enumerate(image_files)]
async with cls.run(method, session, prompt, conversation, None if media is None else media.pop(), seed) as response:
async with cls.run(method, session, prompt, conversation, None if not media else media.pop(), seed) as response:
await raise_for_status(response)
async with cls.run("get", session, prompt, conversation, None, seed) as response:

View File

@@ -32,15 +32,6 @@ class Qwen_Qwen_3(AsyncGeneratorProvider, ProviderModelMixin):
"qwen3-1.7b",
"qwen3-0.6b",
}
model_aliases = {
"qwen-3-235b": default_model,
"qwen-3-32b": "qwen3-32b",
"qwen-3-30b": "qwen3-30b-a3b",
"qwen-3-14b": "qwen3-14b",
"qwen-3-4b": "qwen3-4b",
"qwen-3-1.7b": "qwen3-1.7b",
"qwen-3-0.6b": "qwen3-0.6b",
}
@classmethod
async def create_async_generator(

View File

@@ -19,7 +19,7 @@ class StabilityAI_SD35Large(AsyncGeneratorProvider, ProviderModelMixin):
default_model = 'stabilityai-stable-diffusion-3-5-large'
default_image_model = default_model
model_aliases = {"sd-3.5": default_model}
model_aliases = {"stable-diffusion-3.5-large": default_model}
image_models = list(model_aliases.keys())
models = image_models

View File

@@ -143,11 +143,10 @@ class Api:
"messages": messages,
"stream": True,
"ignore_stream": True,
"return_conversation": True,
**kwargs
}
def _create_response_stream(self, kwargs: dict, conversation_id: str, provider: str, download_media: bool = True) -> Iterator:
def _create_response_stream(self, kwargs: dict, provider: str, download_media: bool = True) -> Iterator:
def decorated_log(text: str, file = None):
debug.logs.append(text)
if debug.logging:
@@ -178,12 +177,9 @@ class Api:
for chunk in result:
if isinstance(chunk, ProviderInfo):
yield self.handle_provider(chunk, model)
provider = chunk.name
elif isinstance(chunk, JsonConversation):
if provider is not None:
if hasattr(provider, "__name__"):
provider = provider.__name__
yield self._format_json("conversation", {
yield self._format_json("conversation", chunk.get_dict() if provider == "AnyProvider" else {
provider: chunk.get_dict()
})
elif isinstance(chunk, Exception):

View File

@@ -128,7 +128,6 @@ class Backend_Api(Api):
return self.app.response_class(
self._create_response_stream(
kwargs,
json_data.get("conversation_id"),
json_data.get("provider"),
json_data.get("download_media", True),
),

View File

@@ -3,8 +3,9 @@ from __future__ import annotations
import os
import requests
from datetime import datetime
from flask import send_from_directory, redirect, request
from ...image.copy_images import secure_filename, get_media_dir, ensure_media_dir
from flask import send_from_directory, redirect
from ...image.copy_images import secure_filename
from ...cookies import get_cookies_dir
from ...errors import VersionNotFoundError
from ... import version
@@ -15,27 +16,26 @@ def redirect_home():
return redirect('/chat')
def render(filename = "chat", add_origion = True):
if request.args.get("live"):
add_origion = False
if os.path.exists(DIST_DIR):
add_origion = False
path = os.path.abspath(os.path.join(os.path.dirname(DIST_DIR), (filename + ("" if "." in filename else ".html"))))
print( f"Debug mode: {path}")
return send_from_directory(os.path.dirname(path), os.path.basename(path))
try:
latest_version = version.utils.latest_version
except VersionNotFoundError:
latest_version = version.utils.current_version
today = datetime.today().strftime('%Y-%m-%d')
cache_file = os.path.join(get_media_dir(), f"{today}.{secure_filename(filename)}.{version.utils.current_version}-{latest_version}{'.live' if add_origion else ''}.html")
cache_dir = os.path.join(get_cookies_dir(), ".gui_cache")
cache_file = os.path.join(cache_dir, f"{today}.{secure_filename(filename)}.{version.utils.current_version}-{latest_version}{'.live' if add_origion else ''}.html")
if not os.path.exists(cache_file):
ensure_media_dir()
os.makedirs(cache_dir, exist_ok=True)
html = requests.get(f"{GPT4FREE_URL}/{filename}.html").text
if add_origion:
html = html.replace("../dist/", f"dist/")
html = html.replace("\"dist/", f"\"{GPT4FREE_URL}/dist/")
with open(cache_file, 'w', encoding='utf-8') as f:
f.write(html)
return send_from_directory(os.path.abspath(get_media_dir()), os.path.basename(cache_file))
return send_from_directory(os.path.abspath(cache_dir), os.path.basename(cache_file))
class Website:
def __init__(self, app) -> None:

View File

@@ -949,7 +949,7 @@ sdxl_turbo = ImageModel(
)
sd_3_5 = ImageModel(
name = 'sd-3.5',
name = 'stable-diffusion-3.5-large',
base_provider = 'Stability AI',
best_provider = HuggingSpace
)

View File

@@ -5,7 +5,6 @@ from ..errors import ModelNotFoundError
from ..image import is_data_an_audio
from ..providers.retry_provider import IterListProvider
from ..providers.types import ProviderType
from ..providers.response import JsonConversation, ProviderInfo
from ..Provider.needs_auth import OpenaiChat, CopilotAccount
from ..Provider.hf_space import HuggingSpace
from ..Provider import Cloudflare, Gemini, Grok, DeepSeekAPI, PerplexityLabs, LambdaChat, PollinationsAI, FreeRouter
@@ -14,9 +13,9 @@ from ..Provider import HarProvider, DDG, HuggingFace, HuggingFaceMedia
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
from .. import Provider
from .. import models
from .. import debug
LABELS = {
"default": "Default",
"openai": "OpenAI: ChatGPT",
"llama": "Meta: LLaMA",
"deepseek": "DeepSeek",
@@ -26,6 +25,7 @@ LABELS = {
"claude": "Anthropic: Claude",
"command": "Cohere: Command",
"phi": "Microsoft: Phi",
"mistral": "Mistral",
"PollinationsAI": "Pollinations AI",
"perplexity": "Perplexity Labs",
"video": "Video Generation",
@@ -45,7 +45,12 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
for model in unsorted_models:
added = False
for group in groups:
if group == "qwen":
if group == "mistral":
if model.split("-")[0] in ("mistral", "mixtral", "mistralai", "pixtral", "ministral", "codestral"):
groups[group].append(model)
added = True
break
elif group == "qwen":
if model.startswith("qwen") or model.startswith("qwq") or model.startswith("qvq"):
groups[group].append(model)
added = True
@@ -198,8 +203,6 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
stream: bool = True,
media: MediaListType = None,
ignored: list[str] = [],
conversation: JsonConversation = None,
api_key: str = None,
**kwargs
) -> AsyncResult:
cls.get_models(ignored=ignored)
@@ -246,33 +249,7 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
providers = list({provider.__name__: provider for provider in providers}.values())
if len(providers) == 0:
raise ModelNotFoundError(f"Model {model} not found in any provider.")
if len(providers) == 1:
provider = providers[0]
if conversation is not None:
child_conversation = getattr(conversation, provider.__name__, None)
if child_conversation is not None:
kwargs["conversation"] = JsonConversation(**child_conversation)
debug.log(f"Using {provider.__name__} provider" + f" and {model} model" if model else "")
yield ProviderInfo(**provider.get_dict(), model=model)
if provider in (HuggingFace, HuggingFaceMedia):
kwargs["api_key"] = api_key
async for chunk in provider.get_async_create_function()(
model,
messages,
stream=stream,
media=media,
**kwargs
):
if isinstance(chunk, JsonConversation):
if conversation is None:
conversation = JsonConversation()
setattr(conversation, provider.__name__, chunk.get_dict())
yield conversation
else:
yield chunk
return
kwargs["api_key"] = api_key
async for chunk in IterListProvider(providers).get_async_create_function()(
async for chunk in IterListProvider(providers).create_async_generator(
model,
messages,
stream=stream,

View File

@@ -4,7 +4,7 @@ import random
from ..typing import Type, List, CreateResult, Messages, AsyncResult
from .types import BaseProvider, BaseRetryProvider, ProviderType
from .response import MediaResponse, AudioResponse, ProviderInfo, Reasoning
from .response import MediaResponse, AudioResponse, ProviderInfo, Reasoning, JsonConversation
from .. import debug
from ..errors import RetryProviderError, RetryNoProviderError, MissingAuthError, NoValidHarFileError
@@ -38,7 +38,6 @@ class IterListProvider(BaseRetryProvider):
stream: bool = False,
ignore_stream: bool = False,
ignored: list[str] = [],
api_key: str = None,
**kwargs,
) -> CreateResult:
"""
@@ -59,8 +58,6 @@ class IterListProvider(BaseRetryProvider):
self.last_provider = provider
debug.log(f"Using {provider.__name__} provider")
yield ProviderInfo(**provider.get_dict(), model=model if model else getattr(provider, "default_model"))
if self.add_api_key or provider.__name__ in ["HuggingFace", "HuggingFaceMedia"]:
kwargs["api_key"] = api_key
try:
response = provider.get_create_function()(model, messages, stream=stream, **kwargs)
for chunk in response:
@@ -86,6 +83,8 @@ class IterListProvider(BaseRetryProvider):
stream: bool = True,
ignore_stream: bool = False,
ignored: list[str] = [],
api_key: str = None,
conversation: JsonConversation = None,
**kwargs
) -> AsyncResult:
exceptions = {}
@@ -93,13 +92,23 @@ class IterListProvider(BaseRetryProvider):
for provider in self.get_providers(stream and not ignore_stream, ignored):
self.last_provider = provider
debug.log(f"Using {provider.__name__} provider")
debug.log(f"Using {provider.__name__} provider and {model} model")
yield ProviderInfo(**provider.get_dict(), model=model if model else getattr(provider, "default_model"))
extra_body = kwargs.copy()
if self.add_api_key or provider.__name__ in ["HuggingFace", "HuggingFaceMedia"]:
extra_body["api_key"] = api_key
if conversation is not None and hasattr(conversation, provider.__name__):
extra_body["conversation"] = JsonConversation(**getattr(conversation, provider.__name__))
try:
response = provider.get_async_create_function()(model, messages, stream=stream, **kwargs)
response = provider.get_async_create_function()(model, messages, stream=stream, **extra_body)
if hasattr(response, "__aiter__"):
async for chunk in response:
if chunk:
if isinstance(chunk, JsonConversation):
if conversation is None:
conversation = JsonConversation()
setattr(conversation, provider.__name__, chunk.get_dict())
yield conversation
elif chunk:
yield chunk
if is_content(chunk):
started = True
@@ -246,6 +255,8 @@ def raise_exceptions(exceptions: dict) -> None:
for provider_name, e in exceptions.items():
if isinstance(e, (MissingAuthError, NoValidHarFileError)):
raise e
if len(exceptions) == 1:
raise list(exceptions.values())[0]
raise RetryProviderError("RetryProvider failed:\n" + "\n".join([
f"{p}: {type(exception).__name__}: {exception}" for p, exception in exceptions.items()
])) from list(exceptions.values())[0]

View File

@@ -189,7 +189,7 @@ def stream_read_files(bucket_dir: Path, filenames: list, delete_files: bool = Fa
else:
os.unlink(filepath)
continue
yield f"```{filename}\n"
yield f"```{filename.replace('.md', '')}\n"
if has_pypdf2 and filename.endswith(".pdf"):
try:
reader = PyPDF2.PdfReader(file_path)