Files
gpt4free/g4f/Provider/needs_auth/HuggingFace.py
Heiner Lohaus 12c413fd2e Add Edge as Browser for nodriver
Fix for RetryProviders doesn't retry
Add retry and continue for DuckDuckGo provider
Add cache for Cloudflare provider
Add cache for prompts on gui home
Add scroll to bottom checkbox in gui
Improve prompts on home gui
Fix response content type in api for files
2025-01-05 17:02:15 +01:00

202 lines
9.3 KiB
Python

from __future__ import annotations
import json
import base64
import random
import requests
from ...typing import AsyncResult, Messages
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, format_prompt
from ...errors import ModelNotFoundError, ModelNotSupportedError, ResponseError
from ...requests import StreamSession, raise_for_status
from ...providers.response import FinishReason
from ...image import ImageResponse
from ... import debug
from .HuggingChat import HuggingChat
class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
url = "https://huggingface.co"
working = True
supports_message_history = True
default_model = HuggingChat.default_model
default_image_model = HuggingChat.default_image_model
model_aliases = HuggingChat.model_aliases
@classmethod
def get_models(cls) -> list[str]:
if not cls.models:
url = "https://huggingface.co/api/models?inference=warm&pipeline_tag=text-generation"
cls.models = [model["id"] for model in requests.get(url).json()]
cls.models.append("meta-llama/Llama-3.2-11B-Vision-Instruct")
cls.models.append("nvidia/Llama-3.1-Nemotron-70B-Instruct-HF")
cls.models.sort()
if not cls.image_models:
url = "https://huggingface.co/api/models?pipeline_tag=text-to-image"
cls.image_models = [model["id"] for model in requests.get(url).json() if model["trendingScore"] >= 20]
cls.image_models.sort()
cls.models.extend(cls.image_models)
return cls.models
@classmethod
async def create_async_generator(
cls,
model: str,
messages: Messages,
stream: bool = True,
proxy: str = None,
api_base: str = "https://api-inference.huggingface.co",
api_key: str = None,
max_tokens: int = 1024,
temperature: float = None,
prompt: str = None,
action: str = None,
extra_data: dict = {},
**kwargs
) -> AsyncResult:
try:
model = cls.get_model(model)
except ModelNotSupportedError:
pass
headers = {
'accept': '*/*',
'accept-language': 'en',
'cache-control': 'no-cache',
'origin': 'https://huggingface.co',
'pragma': 'no-cache',
'priority': 'u=1, i',
'referer': 'https://huggingface.co/chat/',
'sec-ch-ua': '"Not)A;Brand";v="99", "Google Chrome";v="127", "Chromium";v="127"',
'sec-ch-ua-mobile': '?0',
'sec-ch-ua-platform': '"macOS"',
'sec-fetch-dest': 'empty',
'sec-fetch-mode': 'cors',
'sec-fetch-site': 'same-origin',
'user-agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/127.0.0.0 Safari/537.36',
}
if api_key is not None:
headers["Authorization"] = f"Bearer {api_key}"
payload = None
if cls.get_models() and model in cls.image_models:
stream = False
prompt = messages[-1]["content"] if prompt is None else prompt
payload = {"inputs": prompt, "parameters": {"seed": random.randint(0, 2**32), **extra_data}}
else:
params = {
"return_full_text": False,
"max_new_tokens": max_tokens,
"temperature": temperature,
**extra_data
}
do_continue = action == "continue"
async with StreamSession(
headers=headers,
proxy=proxy,
timeout=600
) as session:
if payload is None:
async with session.get(f"https://huggingface.co/api/models/{model}") as response:
await raise_for_status(response)
model_data = await response.json()
model_type = None
if "config" in model_data and "model_type" in model_data["config"]:
model_type = model_data["config"]["model_type"]
debug.log(f"Model type: {model_type}")
inputs = get_inputs(messages, model_data, model_type, do_continue)
debug.log(f"Inputs len: {len(inputs)}")
if len(inputs) > 4096:
if len(messages) > 6:
messages = messages[:3] + messages[-3:]
else:
messages = [m for m in messages if m["role"] == "system"] + [messages[-1]]
inputs = get_inputs(messages, model_data, model_type, do_continue)
debug.log(f"New len: {len(inputs)}")
if model_type == "gpt2" and max_tokens >= 1024:
params["max_new_tokens"] = 512
payload = {"inputs": inputs, "parameters": params, "stream": stream}
async with session.post(f"{api_base.rstrip('/')}/models/{model}", json=payload) as response:
if response.status == 404:
raise ModelNotFoundError(f"Model is not supported: {model}")
await raise_for_status(response)
if stream:
first = True
is_special = False
async for line in response.iter_lines():
if line.startswith(b"data:"):
data = json.loads(line[5:])
if "error" in data:
raise ResponseError(data["error"])
if not data["token"]["special"]:
chunk = data["token"]["text"]
if first and not do_continue:
first = False
chunk = chunk.lstrip()
if chunk:
yield chunk
else:
is_special = True
debug.log(f"Special token: {is_special}")
yield FinishReason("stop" if is_special else "length", actions=["variant"] if is_special else ["continue", "variant"])
else:
if response.headers["content-type"].startswith("image/"):
base64_data = base64.b64encode(b"".join([chunk async for chunk in response.iter_content()]))
url = f"data:{response.headers['content-type']};base64,{base64_data.decode()}"
yield ImageResponse(url, prompt)
else:
yield (await response.json())[0]["generated_text"].strip()
def format_prompt_mistral(messages: Messages, do_continue: bool = False) -> str:
system_messages = [message["content"] for message in messages if message["role"] == "system"]
question = " ".join([messages[-1]["content"], *system_messages])
history = "".join([
f"<s>[INST]{messages[idx-1]['content']} [/INST] {message['content']}</s>"
for idx, message in enumerate(messages)
if message["role"] == "assistant"
])
if do_continue:
return history[:-len('</s>')]
return f"{history}<s>[INST] {question} [/INST]"
def format_prompt_qwen(messages: Messages, do_continue: bool = False) -> str:
prompt = "".join([
f"<|im_start|>{message['role']}\n{message['content']}\n<|im_end|>\n" for message in messages
]) + ("" if do_continue else "<|im_start|>assistant\n")
if do_continue:
return prompt[:-len("\n<|im_end|>\n")]
return prompt
def format_prompt_llama(messages: Messages, do_continue: bool = False) -> str:
prompt = "<|begin_of_text|>" + "".join([
f"<|start_header_id|>{message['role']}<|end_header_id|>\n\n{message['content']}\n<|eot_id|>\n" for message in messages
]) + ("" if do_continue else "<|start_header_id|>assistant<|end_header_id|>\n\n")
if do_continue:
return prompt[:-len("\n<|eot_id|>\n")]
return prompt
def format_prompt_custom(messages: Messages, end_token: str = "</s>", do_continue: bool = False) -> str:
prompt = "".join([
f"<|{message['role']}|>\n{message['content']}{end_token}\n" for message in messages
]) + ("" if do_continue else "<|assistant|>\n")
if do_continue:
return prompt[:-len(end_token + "\n")]
return prompt
def get_inputs(messages: Messages, model_data: dict, model_type: str, do_continue: bool = False) -> str:
if model_type in ("gpt2", "gpt_neo", "gemma", "gemma2"):
inputs = format_prompt(messages, do_continue=do_continue)
elif model_type in ("mistral"):
inputs = format_prompt_mistral(messages, do_continue)
elif "config" in model_data and "tokenizer_config" in model_data["config"] and "eos_token" in model_data["config"]["tokenizer_config"]:
eos_token = model_data["config"]["tokenizer_config"]["eos_token"]
if eos_token in ("<|endoftext|>", "<eos>", "</s>"):
inputs = format_prompt_custom(messages, eos_token, do_continue)
elif eos_token == "<|im_end|>":
inputs = format_prompt_qwen(messages, do_continue)
elif eos_token == "<|eot_id|>":
inputs = format_prompt_llama(messages, do_continue)
else:
inputs = format_prompt(messages, do_continue=do_continue)
else:
inputs = format_prompt(messages, do_continue=do_continue)
return inputs