gpt4free/g4f/Provider/hf/HuggingFaceAPI.py
hlohaus · c97ba0c88e · Add audio transcribing example and support; Add Grok Chat provider; Rename images parameter to media; Update demo homepage · 2025-03-21 03:17:45 +01:00


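HuggingFaceAPI builds on the OpenAI-compatible template to route chat completions through Hugging Face's inference router: it resolves a model id to a concrete serving provider via the Hub's inferenceProviderMapping metadata, rewrites the API base to point at router.huggingface.co, and trims over-long conversations before delegating to the parent template.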
from __future__ import annotations

import requests

from ...providers.types import Messages
from ...typing import MediaListType
from ...requests import StreamSession, raise_for_status
from ...errors import ModelNotSupportedError
from ...providers.helper import get_last_user_message
from ...providers.response import ProviderInfo
from ..template.OpenaiTemplate import OpenaiTemplate
from .models import model_aliases, vision_models, default_vision_model, llama_models, text_models
from ... import debug

class HuggingFaceAPI(OpenaiTemplate):
    label = "HuggingFace (Inference API)"
    parent = "HuggingFace"
    url = "https://api-inference.huggingface.co"
    api_base = "https://api-inference.huggingface.co/v1"
    working = True
    needs_auth = True

    default_model = default_vision_model
    default_vision_model = default_vision_model
    vision_models = vision_models
    model_aliases = model_aliases
    fallback_models = text_models + vision_models

    # Cache of model id -> inference provider mapping fetched from the Hub.
    provider_mapping: dict[str, dict] = {}

    @classmethod
    def get_model(cls, model: str, **kwargs) -> str:
        try:
            return super().get_model(model, **kwargs)
        except ModelNotSupportedError:
            # Pass unknown ids through so the router can still try to resolve them.
            return model

    @classmethod
    def get_models(cls, **kwargs) -> list[str]:
        if not cls.models:
            url = "https://huggingface.co/api/models?inference=warm&expand[]=inferenceProviderMapping"
            response = requests.get(url)
            if response.ok:
                # Keep models that at least one provider serves as a chat ("conversational") task.
                cls.models = [
                    model["id"]
                    for model in response.json()
                    if any(
                        provider.get("task") == "conversational"
                        for provider in model.get("inferenceProviderMapping") or []
                    )
                ]
            else:
                cls.models = cls.fallback_models
        return cls.models

    @classmethod
    async def get_mapping(cls, model: str, api_key: str = None):
        # Serve from the class-level cache when possible.
        if model in cls.provider_mapping:
            return cls.provider_mapping[model]
        async with StreamSession(
            timeout=30,
            headers=cls.get_headers(False, api_key),
        ) as session:
            async with session.get(f"https://huggingface.co/api/models/{model}?expand[]=inferenceProviderMapping") as response:
                await raise_for_status(response)
                model_data = await response.json()
                cls.provider_mapping[model] = model_data.get("inferenceProviderMapping")
        return cls.provider_mapping[model]

    @classmethod
    async def create_async_generator(
        cls,
        model: str,
        messages: Messages,
        api_base: str = None,
        api_key: str = None,
        max_tokens: int = 2048,
        max_inputs_length: int = 10000,
        media: MediaListType = None,
        **kwargs
    ):
        # The generic llama alias maps to the text or vision variant depending on attached media.
        if model == llama_models["name"]:
            model = llama_models["text"] if media is None else llama_models["vision"]
        if model in cls.model_aliases:
            model = cls.model_aliases[model]
        provider_mapping = await cls.get_mapping(model, api_key)
        # Use the first provider in the mapping; it must serve the model as a chat task.
        for provider_key in provider_mapping:
            api_path = provider_key if provider_key == "novita" else f"{provider_key}/v1"
            api_base = f"https://router.huggingface.co/{api_path}"
            task = provider_mapping[provider_key]["task"]
            if task != "conversational":
                raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__} task: {task}")
            model = provider_mapping[provider_key]["providerId"]
            yield ProviderInfo(**{**cls.get_dict(), "label": f"HuggingFace ({provider_key})"})
            break
        # Trim over-long conversations: drop the middle first, then keep only the
        # system messages plus the last user message, then the last user message alone.
        start = calculate_length(messages)
        if start > max_inputs_length:
            if len(messages) > 6:
                messages = messages[:3] + messages[-3:]
            if calculate_length(messages) > max_inputs_length:
                last_user_message = [{"role": "user", "content": get_last_user_message(messages)}]
                if len(messages) > 2:
                    messages = [m for m in messages if m["role"] == "system"] + last_user_message
                if len(messages) > 1 and calculate_length(messages) > max_inputs_length:
                    messages = last_user_message
            debug.log(f"Messages trimmed from: {start} to: {calculate_length(messages)}")
        async for chunk in super().create_async_generator(model, messages, api_base=api_base, api_key=api_key, max_tokens=max_tokens, media=media, **kwargs):
            yield chunk


def calculate_length(messages: Messages) -> int:
    # Rough character-based size: content length plus a small per-message overhead.
    return sum(len(message["content"]) + 16 for message in messages)
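
A minimal usage sketch, assuming a valid Hugging Face API token and a model id that at least one provider serves as a conversational task; the token and model id below are placeholders, and the stream may yield provider metadata objects (ProviderInfo) alongside text chunks:

import asyncio

from g4f.Provider.hf.HuggingFaceAPI import HuggingFaceAPI

async def main():
    # Placeholder credentials and model id; substitute real values.
    async for chunk in HuggingFaceAPI.create_async_generator(
        model="meta-llama/Llama-3.3-70B-Instruct",
        messages=[{"role": "user", "content": "Hello!"}],
        api_key="hf_***",
    ):
        if isinstance(chunk, str):
            print(chunk, end="")

asyncio.run(main())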