mirror of
https://github.com/xtekky/gpt4free.git
synced 2025-09-27 12:42:11 +08:00
Enhance LMArena provider to handle image responses and raise MissingRequirementsError for missing auth files
This commit is contained in:
@@ -22,7 +22,7 @@ except ImportError:
|
|||||||
|
|
||||||
from ...typing import AsyncResult, Messages, MediaListType
|
from ...typing import AsyncResult, Messages, MediaListType
|
||||||
from ...requests import StreamSession, get_args_from_nodriver, raise_for_status, merge_cookies
|
from ...requests import StreamSession, get_args_from_nodriver, raise_for_status, merge_cookies
|
||||||
from ...errors import ModelNotFoundError, CloudflareError, MissingAuthError
|
from ...errors import ModelNotFoundError, CloudflareError, MissingAuthError, MissingRequirementsError
|
||||||
from ...providers.response import FinishReason, Usage, JsonConversation, ImageResponse, Reasoning
|
from ...providers.response import FinishReason, Usage, JsonConversation, ImageResponse, Reasoning
|
||||||
from ...tools.media import merge_media
|
from ...tools.media import merge_media
|
||||||
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin,AuthFileMixin
|
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin,AuthFileMixin
|
||||||
@@ -553,6 +553,9 @@ class LMArena(AsyncGeneratorProvider, ProviderModelMixin, AuthFileMixin):
|
|||||||
"provider": cls.__name__
|
"provider": cls.__name__
|
||||||
})
|
})
|
||||||
raise_for_status(response)
|
raise_for_status(response)
|
||||||
|
if response.headers.get("Content-Type", "").startswith("image/"):
|
||||||
|
yield ImageResponse(str(response.url), prompt)
|
||||||
|
else:
|
||||||
text, *args = response.text.split("\n" * 10 + "<!--", 1)
|
text, *args = response.text.split("\n" * 10 + "<!--", 1)
|
||||||
if args:
|
if args:
|
||||||
debug.log("Save args to cache file:", str(cache_file))
|
debug.log("Save args to cache file:", str(cache_file))
|
||||||
@@ -562,6 +565,8 @@ class LMArena(AsyncGeneratorProvider, ProviderModelMixin, AuthFileMixin):
|
|||||||
finally:
|
finally:
|
||||||
cls.looked = False
|
cls.looked = False
|
||||||
return
|
return
|
||||||
|
else:
|
||||||
|
raise MissingRequirementsError("No auth file found and nodriver is not available.")
|
||||||
|
|
||||||
if not cls._models_loaded:
|
if not cls._models_loaded:
|
||||||
cls.get_models()
|
cls.get_models()
|
||||||
|
@@ -37,8 +37,8 @@ async def raise_for_status_async(response: Union[StreamResponse, ClientResponse]
|
|||||||
if message is None or is_html:
|
if message is None or is_html:
|
||||||
if response.status == 520:
|
if response.status == 520:
|
||||||
message = "Unknown error (Cloudflare)"
|
message = "Unknown error (Cloudflare)"
|
||||||
elif response.status in (429, 402):
|
if response.status in (429, 402):
|
||||||
message = "Rate limit"
|
raise RateLimitError(f"Response {response.status}: {message}")
|
||||||
if response.status == 401:
|
if response.status == 401:
|
||||||
raise MissingAuthError(f"Response {response.status}: {message}")
|
raise MissingAuthError(f"Response {response.status}: {message}")
|
||||||
if response.status == 403 and is_cloudflare(message):
|
if response.status == 403 and is_cloudflare(message):
|
||||||
@@ -64,8 +64,8 @@ def raise_for_status(response: Union[Response, StreamResponse, ClientResponse, R
|
|||||||
if message is None or is_html:
|
if message is None or is_html:
|
||||||
if response.status_code == 520:
|
if response.status_code == 520:
|
||||||
message = "Unknown error (Cloudflare)"
|
message = "Unknown error (Cloudflare)"
|
||||||
elif response.status_code in (429, 402):
|
if response.status_code in (429, 402):
|
||||||
raise RateLimitError(f"Response {response.status_code}: Rate Limit")
|
raise RateLimitError(f"Response {response.status_code}: {message}")
|
||||||
if response.status_code == 401:
|
if response.status_code == 401:
|
||||||
raise MissingAuthError(f"Response {response.status_code}: {message}")
|
raise MissingAuthError(f"Response {response.status_code}: {message}")
|
||||||
if response.status_code == 403 and is_cloudflare(response.text):
|
if response.status_code == 403 and is_cloudflare(response.text):
|
||||||
|
Reference in New Issue
Block a user