from __future__ import annotations

from ...providers.types import Messages
from ...typing import ImagesType
from ...requests import StreamSession, raise_for_status
from ...errors import ModelNotSupportedError
from ...providers.helper import get_last_user_message
from ..template.OpenaiTemplate import OpenaiTemplate
from .models import model_aliases, vision_models, default_vision_model
from .HuggingChat import HuggingChat
from ... import debug

class HuggingFaceAPI(OpenaiTemplate):
    label = "HuggingFace (Inference API)"
    parent = "HuggingFace"
    url = "https://api-inference.huggingface.co"
    api_base = "https://api-inference.huggingface.co/v1"
    working = True
    needs_auth = True

    default_model = default_vision_model
    default_vision_model = default_vision_model
    vision_models = vision_models
    model_aliases = model_aliases

    # Cache of model name -> pipeline tag, filled lazily by get_pipeline_tag()
    pipeline_tags: dict[str, str] = {}

    @classmethod
    def get_models(cls, **kwargs):
        if not cls.models:
            # Reuse the HuggingChat text model list and extend it with the vision models
            HuggingChat.get_models()
            cls.models = HuggingChat.text_models.copy()
            for model in cls.vision_models:
                if model not in cls.models:
                    cls.models.append(model)
        return cls.models

    @classmethod
    async def get_pipeline_tag(cls, model: str, api_key: str = None):
        if model in cls.pipeline_tags:
            return cls.pipeline_tags[model]
        # Ask the Hugging Face model API for the model's pipeline tag and cache it
        async with StreamSession(
            timeout=30,
            headers=cls.get_headers(False, api_key),
        ) as session:
            async with session.get(f"https://huggingface.co/api/models/{model}") as response:
                await raise_for_status(response)
                model_data = await response.json()
                cls.pipeline_tags[model] = model_data.get("pipeline_tag")
        return cls.pipeline_tags[model]

    @classmethod
    async def create_async_generator(
        cls,
        model: str,
        messages: Messages,
        api_base: str = None,
        api_key: str = None,
        max_tokens: int = 2048,
        max_inputs_length: int = 10000,
        images: ImagesType = None,
        **kwargs
    ):
        if model in cls.model_aliases:
            model = cls.model_aliases[model]
        if api_base is None:
            api_base = f"https://api-inference.huggingface.co/models/{model}/v1"
        pipeline_tag = await cls.get_pipeline_tag(model, api_key)
        if pipeline_tag not in ("text-generation", "image-text-to-text"):
            raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__} pipeline_tag: {pipeline_tag}")
        elif images and pipeline_tag != "image-text-to-text":
            raise ModelNotSupportedError(f"Model does not support images: {model} in: {cls.__name__} pipeline_tag: {pipeline_tag}")
        # Trim the conversation in stages until it fits within max_inputs_length:
        # first keep only the first and last three messages, then fall back to
        # the system messages plus the last user message, then to the last
        # user message alone.
        start = calculate_length(messages)
        if start > max_inputs_length:
            if len(messages) > 6:
                messages = messages[:3] + messages[-3:]
            if calculate_length(messages) > max_inputs_length:
                last_user_message = [{"role": "user", "content": get_last_user_message(messages)}]
                if len(messages) > 2:
                    messages = [m for m in messages if m["role"] == "system"] + last_user_message
                if len(messages) > 1 and calculate_length(messages) > max_inputs_length:
                    messages = last_user_message
            debug.log(f"Messages trimmed from: {start} to: {calculate_length(messages)}")
        async for chunk in super().create_async_generator(model, messages, api_base=api_base, api_key=api_key, max_tokens=max_tokens, images=images, **kwargs):
            yield chunk

def calculate_length(messages: Messages) -> int:
    # Rough size estimate: content length plus a fixed overhead of 16
    # characters per message for role and formatting.
    return sum(len(message["content"]) + 16 for message in messages)
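
# A minimal usage sketch, not part of the provider itself: it drives the
# streaming generator directly. Run it with `python -m` from the package
# root, since this module uses relative imports. The HF_TOKEN environment
# variable is an assumption for illustration; any valid Hugging Face API
# token would do.
if __name__ == "__main__":
    import asyncio
    import os

    async def main():
        messages = [{"role": "user", "content": "Hello, who are you?"}]
        async for chunk in HuggingFaceAPI.create_async_generator(
            model=HuggingFaceAPI.default_model,
            messages=messages,
            api_key=os.environ.get("HF_TOKEN"),
        ):
            # Print each streamed chunk as it arrives
            print(chunk, end="", flush=True)

    asyncio.run(main())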