Enhance LMArena provider to handle image responses and raise MissingRequirementsError for missing auth files

This commit is contained in:
hlohaus
2025-09-06 19:02:11 +02:00
parent 933fff985b
commit 2bb58a18be
2 changed files with 16 additions and 11 deletions

View File

@@ -22,7 +22,7 @@ except ImportError:
from ...typing import AsyncResult, Messages, MediaListType
from ...requests import StreamSession, get_args_from_nodriver, raise_for_status, merge_cookies
from ...errors import ModelNotFoundError, CloudflareError, MissingAuthError
from ...errors import ModelNotFoundError, CloudflareError, MissingAuthError, MissingRequirementsError
from ...providers.response import FinishReason, Usage, JsonConversation, ImageResponse, Reasoning
from ...tools.media import merge_media
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin,AuthFileMixin
@@ -553,15 +553,20 @@ class LMArena(AsyncGeneratorProvider, ProviderModelMixin, AuthFileMixin):
"provider": cls.__name__
})
raise_for_status(response)
text, *args = response.text.split("\n" * 10 + "<!--", 1)
if args:
debug.log("Save args to cache file:", str(cache_file))
with cache_file.open("w") as f:
f.write(args[0].strip())
yield text
if response.headers.get("Content-Type", "").startswith("image/"):
yield ImageResponse(str(response.url), prompt)
else:
text, *args = response.text.split("\n" * 10 + "<!--", 1)
if args:
debug.log("Save args to cache file:", str(cache_file))
with cache_file.open("w") as f:
f.write(args[0].strip())
yield text
finally:
cls.looked = False
return
else:
raise MissingRequirementsError("No auth file found and nodriver is not available.")
if not cls._models_loaded:
cls.get_models()

View File

@@ -37,8 +37,8 @@ async def raise_for_status_async(response: Union[StreamResponse, ClientResponse]
if message is None or is_html:
if response.status == 520:
message = "Unknown error (Cloudflare)"
elif response.status in (429, 402):
message = "Rate limit"
if response.status in (429, 402):
raise RateLimitError(f"Response {response.status}: {message}")
if response.status == 401:
raise MissingAuthError(f"Response {response.status}: {message}")
if response.status == 403 and is_cloudflare(message):
@@ -64,8 +64,8 @@ def raise_for_status(response: Union[Response, StreamResponse, ClientResponse, R
if message is None or is_html:
if response.status_code == 520:
message = "Unknown error (Cloudflare)"
elif response.status_code in (429, 402):
raise RateLimitError(f"Response {response.status_code}: Rate Limit")
if response.status_code in (429, 402):
raise RateLimitError(f"Response {response.status_code}: {message}")
if response.status_code == 401:
raise MissingAuthError(f"Response {response.status_code}: {message}")
if response.status_code == 403 and is_cloudflare(response.text):