Files
gpt4free/g4f/Provider/hf/HuggingFaceInference.py
hlohaus 93986d15f6 fix: resolve model duplication and improve provider handling
- Fixed duplicate model entries in Blackbox provider model_aliases
- Added meta-llama- to llama- name cleaning in the Cloudflare provider
- Enhanced PollinationsAI provider with improved vision model detection
- Added reasoning support to PollinationsAI provider
- Fixed HuggingChat authentication to include headers and impersonate
- Removed unused max_inputs_length parameter from HuggingFaceAPI
- Renamed extra_data to extra_body for consistency across providers
- Added Puter provider with grouped model support
- Enhanced AnyProvider with grouped model display and better model organization
- Fixed model cleaning in AnyProvider to handle more model name variations
- Added api_key handling for HuggingFace providers in AnyProvider
- Added see_stream helper function to parse event streams
- Updated GUI server to handle JsonConversation properly
- Fixed aspect ratio handling in image generation functions
- Added ResponsesConfig and ClientResponse for new API endpoint
- Updated requirements to include markitdown
2025-05-16 00:18:12 +02:00


from __future__ import annotations

import json
import base64
import random
import requests

from ...typing import AsyncResult, Messages
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, format_prompt
from ...errors import ModelNotSupportedError, ResponseError
from ...requests import StreamSession, raise_for_status
from ...providers.response import FinishReason, ImageResponse
from ...image.copy_images import save_response_media
from ...image import use_aspect_ratio
from ..helper import format_image_prompt, get_last_user_message
from .models import default_model, default_image_model, model_aliases, text_models, image_models, vision_models
from ... import debug

provider_together_urls = {
    "black-forest-labs/FLUX.1-dev": "https://router.huggingface.co/together/v1/images/generations",
    "black-forest-labs/FLUX.1-schnell": "https://router.huggingface.co/together/v1/images/generations",
}
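
# Request-body sketch (hedged) for the Together image router above; this is the
# shape built in create_async_generator below when `model` matches one of these
# keys. The prompt and width/height values are illustrative (width/height are
# filled in by use_aspect_ratio):
#
#     {
#         "response_format": "url",
#         "prompt": "a red fox, watercolor",
#         "model": "black-forest-labs/FLUX.1-dev",
#         "width": 1024,
#         "height": 1024,
#     }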

class HuggingFaceInference(AsyncGeneratorProvider, ProviderModelMixin):
    url = "https://huggingface.co"
    parent = "HuggingFace"
    working = True

    default_model = default_model
    default_image_model = default_image_model
    model_aliases = model_aliases
    image_models = image_models
    model_data: dict[str, dict] = {}

    @classmethod
    def get_models(cls) -> list[str]:
        if not cls.models:
            models = text_models.copy()
            url = "https://huggingface.co/api/models?inference=warm&pipeline_tag=text-generation"
            response = requests.get(url)
            if response.ok:
                extra_models = [model["id"] for model in response.json() if model.get("trendingScore", 0) >= 10]
                models = extra_models + vision_models + [model for model in models if model not in extra_models]
            url = "https://huggingface.co/api/models?pipeline_tag=text-to-image"
            response = requests.get(url)
            cls.image_models = image_models.copy()
            if response.ok:
                extra_models = [model["id"] for model in response.json() if model.get("trendingScore", 0) >= 20]
                cls.image_models.extend([model for model in extra_models if model not in cls.image_models])
            models.extend([model for model in cls.image_models if model not in models])
            cls.models = models
        return cls.models
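
    # Usage sketch (hedged): get_models caches its result on the class, so only
    # the first call hits the HuggingFace API; if a request fails, the static
    # lists imported from .models are used as-is.
    #
    #     models = HuggingFaceInference.get_models()
    #     image_only = [m for m in models if m in HuggingFaceInference.image_models]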

    @classmethod
    async def get_model_data(cls, session: StreamSession, model: str) -> dict:
        if model in cls.model_data:
            return cls.model_data[model]
        async with session.get(f"https://huggingface.co/api/models/{model}") as response:
            if response.status == 404:
                raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__}")
            await raise_for_status(response)
            cls.model_data[model] = await response.json()
        return cls.model_data[model]
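
    # Usage sketch (hedged, assumes StreamSession's default constructor works
    # without arguments): results are memoized in `model_data`, so repeated
    # lookups for the same model are free.
    #
    #     async with StreamSession() as session:
    #         data = await HuggingFaceInference.get_model_data(session, "gpt2")
    #         print(data.get("pipeline_tag"))  # e.g. "text-generation"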

    @classmethod
    async def create_async_generator(
        cls,
        model: str,
        messages: Messages,
        stream: bool = True,
        proxy: str = None,
        timeout: int = 600,
        api_base: str = "https://api-inference.huggingface.co",
        api_key: str = None,
        max_tokens: int = 1024,
        temperature: float = None,
        prompt: str = None,
        action: str = None,
        extra_body: dict = {},
        seed: int = None,
        aspect_ratio: str = None,
        width: int = None,
        height: int = None,
        **kwargs
    ) -> AsyncResult:
        try:
            model = cls.get_model(model)
        except ModelNotSupportedError:
            pass
        headers = {
            'Accept-Encoding': 'gzip, deflate',
            'Content-Type': 'application/json',
        }
        if api_key is not None:
            headers["Authorization"] = f"Bearer {api_key}"
        image_extra_body = use_aspect_ratio({
            "width": width,
            "height": height,
            **extra_body
        }, aspect_ratio)
        async with StreamSession(
            headers=headers,
            proxy=proxy,
            timeout=timeout
        ) as session:
            try:
                if model in provider_together_urls:
                    data = {
                        "response_format": "url",
                        "prompt": format_image_prompt(messages, prompt),
                        "model": model,
                        **image_extra_body
                    }
                    async with session.post(provider_together_urls[model], json=data) as response:
                        if response.status == 404:
                            raise ModelNotSupportedError(f"Model is not supported: {model}")
                        await raise_for_status(response)
                        result = await response.json()
                        yield ImageResponse([item["url"] for item in result["data"]], data["prompt"])
                    return
            except ModelNotSupportedError:
                pass
            payload = None
            params = {
                "return_full_text": False,
                "max_new_tokens": max_tokens,
                "temperature": temperature,
                **extra_body
            }
            do_continue = action == "continue"
            if payload is None:
                model_data = await cls.get_model_data(session, model)
                pipeline_tag = model_data.get("pipeline_tag")
                if pipeline_tag == "text-to-image":
                    stream = False
                    inputs = format_image_prompt(messages, prompt)
                    payload = {"inputs": inputs, "parameters": {"seed": random.randint(0, 2**32) if seed is None else seed, **image_extra_body}}
                elif pipeline_tag in ("text-generation", "image-text-to-text"):
                    model_type = None
                    if "config" in model_data and "model_type" in model_data["config"]:
                        model_type = model_data["config"]["model_type"]
                    debug.log(f"Model type: {model_type}")
                    inputs = get_inputs(messages, model_data, model_type, do_continue)
                    debug.log(f"Inputs len: {len(inputs)}")
                    if len(inputs) > 4096:
                        # Trim long conversations: keep the head and tail, or fall
                        # back to system messages plus the last user message.
                        if len(messages) > 6:
                            messages = messages[:3] + messages[-3:]
                        else:
                            messages = [m for m in messages if m["role"] == "system"] + [{"role": "user", "content": get_last_user_message(messages)}]
                        inputs = get_inputs(messages, model_data, model_type, do_continue)
                        debug.log(f"New len: {len(inputs)}")
                    if model_type == "gpt2" and max_tokens >= 1024:
                        params["max_new_tokens"] = 512
                    if seed is not None:
                        params["seed"] = seed
                    payload = {"inputs": inputs, "parameters": params, "stream": stream}
                else:
                    raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__} pipeline_tag: {pipeline_tag}")
            async with session.post(f"{api_base.rstrip('/')}/models/{model}", json=payload) as response:
                if response.status == 404:
                    raise ModelNotSupportedError(f"Model is not supported: {model}")
                await raise_for_status(response)
                if stream:
                    first = True
                    is_special = False
                    async for line in response.iter_lines():
                        if line.startswith(b"data:"):
                            data = json.loads(line[5:])
                            if "error" in data:
                                raise ResponseError(data["error"])
                            if not data["token"]["special"]:
                                chunk = data["token"]["text"]
                                if first and not do_continue:
                                    first = False
                                    chunk = chunk.lstrip()
                                if chunk:
                                    yield chunk
                            else:
                                is_special = True
                    debug.log(f"Special token: {is_special}")
                    yield FinishReason("stop" if is_special else "length")
                else:
                    # save_response_media yields only for media responses; for a
                    # JSON body it yields nothing and we fall through below.
                    async for chunk in save_response_media(response, inputs, [aspect_ratio, model]):
                        yield chunk
                        return
                    yield (await response.json())[0]["generated_text"].strip()
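
# Usage sketch (hedged): consuming the generator directly. In normal use this is
# driven by g4f's client layer; the parameter names match the signature above,
# while the model id and token are illustrative placeholders.
#
#     import asyncio
#
#     async def main():
#         async for chunk in HuggingFaceInference.create_async_generator(
#             model="meta-llama/Llama-3.2-3B-Instruct",  # hypothetical model id
#             messages=[{"role": "user", "content": "Say hi"}],
#             api_key="hf_...",                          # optional HF token
#         ):
#             if isinstance(chunk, str):
#                 print(chunk, end="")
#
#     asyncio.run(main())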

def format_prompt_mistral(messages: Messages, do_continue: bool = False) -> str:
    system_messages = [message["content"] for message in messages if message["role"] == "system"]
    question = " ".join([messages[-1]["content"], *system_messages])
    history = "\n".join([
        f"<s>[INST]{messages[idx-1]['content']} [/INST] {message['content']}</s>"
        for idx, message in enumerate(messages)
        if message["role"] == "assistant"
    ])
    if do_continue:
        return history[:-len('</s>')]
    return f"{history}\n<s>[INST] {question} [/INST]"

def format_prompt_qwen(messages: Messages, do_continue: bool = False) -> str:
    prompt = "".join([
        f"<|im_start|>{message['role']}\n{message['content']}\n<|im_end|>\n" for message in messages
    ]) + ("" if do_continue else "<|im_start|>assistant\n")
    if do_continue:
        return prompt[:-len("\n<|im_end|>\n")]
    return prompt
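
# Example (traced from format_prompt_qwen above):
#
#     format_prompt_qwen([{"role": "user", "content": "Hi"}])
#     # -> "<|im_start|>user\nHi\n<|im_end|>\n<|im_start|>assistant\n"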

def format_prompt_qwen2(messages: Messages, do_continue: bool = False) -> str:
    prompt = "".join([
        f"\u003C{message['role'].capitalize()}\u003E{message['content']}\u003Cend▁of▁sentence\u003E" for message in messages
    ]) + ("" if do_continue else "\u003CAssistant\u003E")
    if do_continue:
        return prompt[:-len("\u003CAssistant\u003E")]
    return prompt
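
# Example (traced from format_prompt_qwen2 above; \u003C and \u003E are plain
# "<" and ">", written as escapes to keep the special tokens visually distinct):
#
#     format_prompt_qwen2([{"role": "user", "content": "Hi"}])
#     # -> "<User>Hi<end▁of▁sentence><Assistant>"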

def format_prompt_llama(messages: Messages, do_continue: bool = False) -> str:
    prompt = "<|begin_of_text|>" + "".join([
        f"<|start_header_id|>{message['role']}<|end_header_id|>\n\n{message['content']}\n<|eot_id|>\n" for message in messages
    ]) + ("" if do_continue else "<|start_header_id|>assistant<|end_header_id|>\n\n")
    if do_continue:
        return prompt[:-len("\n<|eot_id|>\n")]
    return prompt
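
# Example (traced from format_prompt_llama above):
#
#     format_prompt_llama([{"role": "user", "content": "Hi"}])
#     # -> "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHi\n<|eot_id|>\n"
#     #    "<|start_header_id|>assistant<|end_header_id|>\n\n"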

def format_prompt_custom(messages: Messages, end_token: str = "</s>", do_continue: bool = False) -> str:
    prompt = "".join([
        f"<|{message['role']}|>\n{message['content']}{end_token}\n" for message in messages
    ]) + ("" if do_continue else "<|assistant|>\n")
    if do_continue:
        return prompt[:-len(end_token + "\n")]
    return prompt
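
# Example (traced from format_prompt_custom above, default end_token "</s>"):
#
#     format_prompt_custom([{"role": "user", "content": "Hi"}])
#     # -> "<|user|>\nHi</s>\n<|assistant|>\n"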

def get_inputs(messages: Messages, model_data: dict, model_type: str, do_continue: bool = False) -> str:
    if model_type in ("gpt2", "gpt_neo", "gemma", "gemma2"):
        inputs = format_prompt(messages, do_continue=do_continue)
    elif model_type == "mistral" and model_data.get("author") == "mistralai":
        inputs = format_prompt_mistral(messages, do_continue)
    elif "config" in model_data and "tokenizer_config" in model_data["config"] and "eos_token" in model_data["config"]["tokenizer_config"]:
        eos_token = model_data["config"]["tokenizer_config"]["eos_token"]
        if eos_token in ("<|endoftext|>", "<eos>", "</s>"):
            inputs = format_prompt_custom(messages, eos_token, do_continue)
        elif eos_token == "<|im_end|>":
            inputs = format_prompt_qwen(messages, do_continue)
        elif "content" in eos_token and eos_token["content"] == "\u003Cend▁of▁sentence\u003E":
            inputs = format_prompt_qwen2(messages, do_continue)
        elif eos_token == "<|eot_id|>":
            inputs = format_prompt_llama(messages, do_continue)
        else:
            inputs = format_prompt(messages, do_continue=do_continue)
    else:
        inputs = format_prompt(messages, do_continue=do_continue)
    return inputs
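
# Dispatch sketch (hedged): get_inputs picks a formatter from the model's
# tokenizer metadata. For example, an eos_token of "<|im_end|>" selects the
# Qwen/ChatML template:
#
#     model_data = {"config": {"tokenizer_config": {"eos_token": "<|im_end|>"}}}
#     get_inputs([{"role": "user", "content": "Hi"}], model_data, model_type=None)
#     # -> "<|im_start|>user\nHi\n<|im_end|>\n<|im_start|>assistant\n"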