mirror of
https://github.com/xtekky/gpt4free.git
synced 2025-11-02 22:24:03 +08:00
Update demo model list
Disable upload cookies in demo
Track usage in demo mode
Add messages without asking the ai
Add hint for browser usage in provider list
Add qwen2 prompt template to HuggingFace provider
Trim automatic messages in HuggingFaceAPI
57 lines
2.2 KiB
Python
from __future__ import annotations

from .OpenaiTemplate import OpenaiTemplate
from .HuggingChat import HuggingChat
from ...providers.types import Messages
from ... import debug

class HuggingFaceAPI(OpenaiTemplate):
    label = "HuggingFace (Inference API)"
    parent = "HuggingFace"
    url = "https://api-inference.huggingface.co"
    api_base = "https://api-inference.huggingface.co/v1"
    working = True
    needs_auth = True

    default_model = "meta-llama/Llama-3.2-11B-Vision-Instruct"
    default_vision_model = default_model
    vision_models = [default_vision_model, "Qwen/Qwen2-VL-7B-Instruct"]
    model_aliases = HuggingChat.model_aliases

    @classmethod
    def get_models(cls, **kwargs):
        # Build the model list once: HuggingChat's text models plus this provider's vision models.
        if not cls.models:
            HuggingChat.get_models()
            cls.models = list(set(HuggingChat.text_models + cls.vision_models))
        return cls.models

    @classmethod
    async def create_async_generator(
        cls,
        model: str,
        messages: Messages,
        api_base: str = None,
        max_tokens: int = 2048,
        max_inputs_length: int = 10000,
        **kwargs
    ):
        # Resolve the per-model inference endpoint when no explicit api_base is given.
        if api_base is None:
            model_name = model
            if model in cls.model_aliases:
                model_name = cls.model_aliases[model]
            api_base = f"https://api-inference.huggingface.co/models/{model_name}/v1"
        # Trim oversized conversations in stages: first keep only the head and tail,
        # then only system messages plus the last message, and finally just the last message.
        start = calculate_length(messages)
        if start > max_inputs_length:
            if len(messages) > 6:
                messages = messages[:3] + messages[-3:]
            if calculate_length(messages) > max_inputs_length:
                if len(messages) > 2:
                    messages = [m for m in messages if m["role"] == "system"] + messages[-1:]
                if len(messages) > 1 and calculate_length(messages) > max_inputs_length:
                    messages = [messages[-1]]
            debug.log(f"Messages trimmed from: {start} to: {calculate_length(messages)}")
        async for chunk in super().create_async_generator(model, messages, api_base=api_base, max_tokens=max_tokens, **kwargs):
            yield chunk

def calculate_length(messages: Messages) -> int:
    # Approximate request size: content characters plus a flat 16-character overhead per message.
    return sum(len(message["content"]) + 16 for message in messages)
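
# --- Usage sketch (not part of the module above) ---
# A minimal standalone illustration of the trimming heuristic, assuming the
# default max_inputs_length budget of 10000. The message history and the
# estimate_length helper below are made-up stand-ins that mirror
# calculate_length; they are not part of the provider.

def estimate_length(messages) -> int:
    return sum(len(m["content"]) + 16 for m in messages)

history = [{"role": "system", "content": "You are helpful."}] + [
    {"role": "user", "content": "x" * 2000} for _ in range(8)
]

print(estimate_length(history))       # 16160: over the 10000 budget
stage1 = history[:3] + history[-3:]   # stage 1: keep head and tail
print(estimate_length(stage1))        # 10112: still over budget
stage2 = [m for m in stage1 if m["role"] == "system"] + stage1[-1:]
print(estimate_length(stage2))        # 2048: fits, trimming stops here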