Enhance LMArena provider to handle image responses and raise MissingRequirementsError for missing auth files

This commit is contained in:
hlohaus
2025-09-06 19:02:11 +02:00
parent 933fff985b
commit 2bb58a18be
2 changed files with 16 additions and 11 deletions

View File

@@ -22,7 +22,7 @@ except ImportError:
from ...typing import AsyncResult, Messages, MediaListType from ...typing import AsyncResult, Messages, MediaListType
from ...requests import StreamSession, get_args_from_nodriver, raise_for_status, merge_cookies from ...requests import StreamSession, get_args_from_nodriver, raise_for_status, merge_cookies
from ...errors import ModelNotFoundError, CloudflareError, MissingAuthError from ...errors import ModelNotFoundError, CloudflareError, MissingAuthError, MissingRequirementsError
from ...providers.response import FinishReason, Usage, JsonConversation, ImageResponse, Reasoning from ...providers.response import FinishReason, Usage, JsonConversation, ImageResponse, Reasoning
from ...tools.media import merge_media from ...tools.media import merge_media
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin,AuthFileMixin from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin,AuthFileMixin
@@ -553,6 +553,9 @@ class LMArena(AsyncGeneratorProvider, ProviderModelMixin, AuthFileMixin):
"provider": cls.__name__ "provider": cls.__name__
}) })
raise_for_status(response) raise_for_status(response)
if response.headers.get("Content-Type", "").startswith("image/"):
yield ImageResponse(str(response.url), prompt)
else:
text, *args = response.text.split("\n" * 10 + "<!--", 1) text, *args = response.text.split("\n" * 10 + "<!--", 1)
if args: if args:
debug.log("Save args to cache file:", str(cache_file)) debug.log("Save args to cache file:", str(cache_file))
@@ -562,6 +565,8 @@ class LMArena(AsyncGeneratorProvider, ProviderModelMixin, AuthFileMixin):
finally: finally:
cls.looked = False cls.looked = False
return return
else:
raise MissingRequirementsError("No auth file found and nodriver is not available.")
if not cls._models_loaded: if not cls._models_loaded:
cls.get_models() cls.get_models()

View File

@@ -37,8 +37,8 @@ async def raise_for_status_async(response: Union[StreamResponse, ClientResponse]
if message is None or is_html: if message is None or is_html:
if response.status == 520: if response.status == 520:
message = "Unknown error (Cloudflare)" message = "Unknown error (Cloudflare)"
elif response.status in (429, 402): if response.status in (429, 402):
message = "Rate limit" raise RateLimitError(f"Response {response.status}: {message}")
if response.status == 401: if response.status == 401:
raise MissingAuthError(f"Response {response.status}: {message}") raise MissingAuthError(f"Response {response.status}: {message}")
if response.status == 403 and is_cloudflare(message): if response.status == 403 and is_cloudflare(message):
@@ -64,8 +64,8 @@ def raise_for_status(response: Union[Response, StreamResponse, ClientResponse, R
if message is None or is_html: if message is None or is_html:
if response.status_code == 520: if response.status_code == 520:
message = "Unknown error (Cloudflare)" message = "Unknown error (Cloudflare)"
elif response.status_code in (429, 402): if response.status_code in (429, 402):
raise RateLimitError(f"Response {response.status_code}: Rate Limit") raise RateLimitError(f"Response {response.status_code}: {message}")
if response.status_code == 401: if response.status_code == 401:
raise MissingAuthError(f"Response {response.status_code}: {message}") raise MissingAuthError(f"Response {response.status_code}: {message}")
if response.status_code == 403 and is_cloudflare(response.text): if response.status_code == 403 and is_cloudflare(response.text):