Mirror of https://github.com/xtekky/gpt4free.git, synced 2025-10-20 15:06:00 +08:00
Show only free providers by default
@@ -4,6 +4,7 @@ from ...providers.types import Messages
 from ...typing import ImagesType
 from ...requests import StreamSession, raise_for_status
 from ...errors import ModelNotSupportedError
+from ...providers.helper import get_last_user_message
 from ..template.OpenaiTemplate import OpenaiTemplate
 from .models import model_aliases, vision_models, default_vision_model
 from .HuggingChat import HuggingChat
@@ -22,7 +23,7 @@ class HuggingFaceAPI(OpenaiTemplate):
     vision_models = vision_models
     model_aliases = model_aliases
 
-    pipeline_tag: dict[str, str] = {}
+    pipeline_tags: dict[str, str] = {}
 
     @classmethod
     def get_models(cls, **kwargs):
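The renamed attribute is a class-level dict, so every lookup through the class shares one in-process cache. A minimal sketch of that semantics (class and method names below are illustrative, not from the repo):

class TagCache:
    # Class attribute: a single dict shared by all callers of the class.
    pipeline_tags: dict[str, str] = {}

    @classmethod
    def remember(cls, model: str, tag: str) -> None:
        cls.pipeline_tags[model] = tag

# A write through one call site is visible everywhere the class is used.
TagCache.remember("gpt2", "text-generation")
assert TagCache.pipeline_tags["gpt2"] == "text-generation"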
@@ -36,8 +37,8 @@ class HuggingFaceAPI(OpenaiTemplate):
 
     @classmethod
     async def get_pipline_tag(cls, model: str, api_key: str = None):
-        if model in cls.pipeline_tag:
-            return cls.pipeline_tag[model]
+        if model in cls.pipeline_tags:
+            return cls.pipeline_tags[model]
         async with StreamSession(
             timeout=30,
             headers=cls.get_headers(False, api_key),
@@ -45,8 +46,8 @@ class HuggingFaceAPI(OpenaiTemplate):
         async with session.get(f"https://huggingface.co/api/models/{model}") as response:
             await raise_for_status(response)
             model_data = await response.json()
-            cls.pipeline_tag[model] = model_data.get("pipeline_tag")
-            return cls.pipeline_tag[model]
+            cls.pipeline_tags[model] = model_data.get("pipeline_tag")
+            return cls.pipeline_tags[model]
 
     @classmethod
     async def create_async_generator(
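The method body above is a check-cache-then-fetch lookup against the Hugging Face model API. A standalone sketch of the same pattern, assuming aiohttp in place of the project's StreamSession (the module-level dict and function name are illustrative):

import aiohttp

_pipeline_tags: dict[str, str] = {}

async def fetch_pipeline_tag(model: str) -> str:
    # Cache hit: skip the network round trip entirely.
    if model in _pipeline_tags:
        return _pipeline_tags[model]
    async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=30)) as session:
        async with session.get(f"https://huggingface.co/api/models/{model}") as response:
            response.raise_for_status()
            model_data = await response.json()
    # Cache miss: store the tag so later calls for this model are free.
    _pipeline_tags[model] = model_data.get("pipeline_tag")
    return _pipeline_tags[model]

# Usage, e.g.: asyncio.run(fetch_pipeline_tag("gpt2"))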
@@ -73,10 +74,11 @@ class HuggingFaceAPI(OpenaiTemplate):
         if len(messages) > 6:
             messages = messages[:3] + messages[-3:]
         if calculate_lenght(messages) > max_inputs_lenght:
+            last_user_message = [{"role": "user", "content": get_last_user_message(messages)}]
             if len(messages) > 2:
-                messages = [m for m in messages if m["role"] == "system"] + messages[-1:]
+                messages = [m for m in messages if m["role"] == "system"] + last_user_message
             if len(messages) > 1 and calculate_lenght(messages) > max_inputs_lenght:
-                messages = [messages[-1]]
+                messages = last_user_message
             debug.log(f"Messages trimmed from: {start} to: {calculate_lenght(messages)}")
         async for chunk in super().create_async_generator(model, messages, api_base=api_base, api_key=api_key, max_tokens=max_tokens, images=images, **kwargs):
             yield chunk
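The new logic trims an over-long conversation in stages: cut the middle turns, then keep only system messages plus the last user turn, then fall back to the last user message alone. A rough standalone version of the heuristic (the length metric, threshold value, and the two helper stand-ins are assumptions; identifier spellings mirror the diff):

max_inputs_lenght = 10000  # assumed threshold; spelling mirrors the repo

def calculate_lenght(messages: list[dict]) -> int:
    # Assumed metric: total characters across all message contents.
    return sum(len(m["content"]) for m in messages)

def get_last_user_message(messages: list[dict]) -> str:
    # Stand-in for the repo helper: content of the most recent user turn.
    return next((m["content"] for m in reversed(messages) if m["role"] == "user"), "")

def trim(messages: list[dict]) -> list[dict]:
    if len(messages) > 6:
        # Keep the first three and last three turns, drop the middle.
        messages = messages[:3] + messages[-3:]
    if calculate_lenght(messages) > max_inputs_lenght:
        last_user_message = [{"role": "user", "content": get_last_user_message(messages)}]
        if len(messages) > 2:
            # Keep system prompts plus only the last user turn.
            messages = [m for m in messages if m["role"] == "system"] + last_user_message
        if len(messages) > 1 and calculate_lenght(messages) > max_inputs_lenght:
            # Still too long: the last user message alone has to do.
            messages = last_user_message
    return messages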