Show only free providers by default

hlohaus
2025-02-21 06:52:04 +01:00
parent e53483d85b
commit 470b795418
14 changed files with 84 additions and 59 deletions

g4f/Provider/hf/HuggingFaceAPI.py

@@ -4,6 +4,7 @@ from ...providers.types import Messages
 from ...typing import ImagesType
 from ...requests import StreamSession, raise_for_status
 from ...errors import ModelNotSupportedError
+from ...providers.helper import get_last_user_message
 from ..template.OpenaiTemplate import OpenaiTemplate
 from .models import model_aliases, vision_models, default_vision_model
 from .HuggingChat import HuggingChat
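
For context, the helper imported above can be sketched as follows. This is a minimal sketch, assuming get_last_user_message simply returns the content of the most recent user turn; the real implementation in g4f's providers.helper module may differ, for example by joining consecutive trailing user messages.

def get_last_user_message(messages: list) -> str:
    # Walk the conversation backwards and return the first user turn found.
    for message in reversed(messages):
        if message["role"] == "user":
            return message["content"]
    return ""

messages = [
    {"role": "system", "content": "You are helpful."},
    {"role": "user", "content": "First question"},
    {"role": "assistant", "content": "First answer"},
    {"role": "user", "content": "Follow-up question"},
]
print(get_last_user_message(messages))  # -> Follow-up question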
@@ -22,7 +23,7 @@ class HuggingFaceAPI(OpenaiTemplate):
     vision_models = vision_models
     model_aliases = model_aliases
-    pipeline_tag: dict[str, str] = {}
+    pipeline_tags: dict[str, str] = {}
 
     @classmethod
     def get_models(cls, **kwargs):
@@ -36,8 +37,8 @@ class HuggingFaceAPI(OpenaiTemplate):
 
     @classmethod
     async def get_pipline_tag(cls, model: str, api_key: str = None):
-        if model in cls.pipeline_tag:
-            return cls.pipeline_tag[model]
+        if model in cls.pipeline_tags:
+            return cls.pipeline_tags[model]
         async with StreamSession(
             timeout=30,
             headers=cls.get_headers(False, api_key),
@@ -45,8 +46,8 @@ class HuggingFaceAPI(OpenaiTemplate):
         async with session.get(f"https://huggingface.co/api/models/{model}") as response:
             await raise_for_status(response)
             model_data = await response.json()
-            cls.pipeline_tag[model] = model_data.get("pipeline_tag")
-            return cls.pipeline_tag[model]
+            cls.pipeline_tags[model] = model_data.get("pipeline_tag")
+            return cls.pipeline_tags[model]
 
     @classmethod
     async def create_async_generator(
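
The renamed class-level pipeline_tags dict works as a process-wide memo cache: each model's pipeline tag is fetched from the Hub API at most once and reused afterwards. A standalone sketch of the same pattern, assuming a plain aiohttp session in place of g4f's StreamSession and get_headers:

import asyncio
import aiohttp

pipeline_tags: dict[str, str] = {}

async def get_pipeline_tag(model: str) -> str:
    # Serve repeated lookups from the in-memory cache.
    if model in pipeline_tags:
        return pipeline_tags[model]
    async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=30)) as session:
        async with session.get(f"https://huggingface.co/api/models/{model}") as response:
            response.raise_for_status()
            model_data = await response.json()
            pipeline_tags[model] = model_data.get("pipeline_tag")
            return pipeline_tags[model]

print(asyncio.run(get_pipeline_tag("gpt2")))  # e.g. text-generation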
@@ -73,10 +74,11 @@ class HuggingFaceAPI(OpenaiTemplate):
         if len(messages) > 6:
             messages = messages[:3] + messages[-3:]
         if calculate_lenght(messages) > max_inputs_lenght:
+            last_user_message = [{"role": "user", "content": get_last_user_message(messages)}]
             if len(messages) > 2:
-                messages = [m for m in messages if m["role"] == "system"] + messages[-1:]
+                messages = [m for m in messages if m["role"] == "system"] + last_user_message
             if len(messages) > 1 and calculate_lenght(messages) > max_inputs_lenght:
-                messages = [messages[-1]]
+                messages = last_user_message
             debug.log(f"Messages trimmed from: {start} to: {calculate_lenght(messages)}")
         async for chunk in super().create_async_generator(model, messages, api_base=api_base, api_key=api_key, max_tokens=max_tokens, images=images, **kwargs):
             yield chunk
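
The behavioral change in this hunk: when the conversation is still too long after dropping the middle turns, the old code fell back to messages[-1:], which could be an assistant turn; the new code always falls back to a synthetic user message built from the last real user turn. A self-contained sketch of the revised trimming, assuming calculate_lenght sums content lengths and using an illustrative max_inputs_lenght value (the misspelled names mirror the identifiers in the file):

def get_last_user_message(messages: list) -> str:
    # Assumed behavior of the imported helper: content of the latest user turn.
    return next((m["content"] for m in reversed(messages) if m["role"] == "user"), "")

def calculate_lenght(messages: list) -> int:
    # Assumed behavior: total character count across all message contents.
    return sum(len(m["content"]) for m in messages)

def trim(messages: list, max_inputs_lenght: int = 1000) -> list:
    # Keep the first and last three turns of an overly long conversation.
    if len(messages) > 6:
        messages = messages[:3] + messages[-3:]
    if calculate_lenght(messages) > max_inputs_lenght:
        last_user_message = [{"role": "user", "content": get_last_user_message(messages)}]
        if len(messages) > 2:
            # Keep system messages plus one user message with the last user content.
            messages = [m for m in messages if m["role"] == "system"] + last_user_message
        if len(messages) > 1 and calculate_lenght(messages) > max_inputs_lenght:
            messages = last_user_message
    return messages

With this, a conversation whose final entry is an assistant message still reaches the upstream API as system context plus one user message, rather than as a lone non-user turn.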