Add many parameters to API endpoints

Support conversational HuggingFace providers
Fix streaming in PollinationsAI provider
hlohaus
2025-03-11 22:16:03 +01:00
parent 3e7af90949
commit 713ad2c83c
7 changed files with 211 additions and 158 deletions

View File

@@ -131,7 +131,7 @@ class ARTA(AsyncGeneratorProvider, ProviderModelMixin):
proxy: str = None,
prompt: str = None,
negative_prompt: str = "blurry, deformed hands, ugly",
images_num: int = 1,
n: int = 1,
guidance_scale: int = 7,
num_inference_steps: int = 30,
aspect_ratio: str = "1:1",
@@ -149,7 +149,7 @@ class ARTA(AsyncGeneratorProvider, ProviderModelMixin):
"prompt": prompt,
"negative_prompt": negative_prompt,
"style": model,
"images_num": str(images_num),
"images_num": str(n),
"cfg_scale": str(guidance_scale),
"steps": str(num_inference_steps),
"aspect_ratio": aspect_ratio,
@@ -181,7 +181,7 @@ class ARTA(AsyncGeneratorProvider, ProviderModelMixin):
return
elif status in ("IN_QUEUE", "IN_PROGRESS"):
yield Reasoning(status=("Waiting" if status == "IN_QUEUE" else "Generating") + "." * counter)
await asyncio.sleep(5) # Poll every 5 seconds
await asyncio.sleep(2) # Poll every 2 seconds
counter += 1
if counter > 3:
counter = 0
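
Not part of the diff, just a hedged sketch of a call site against the renamed keyword: the image count is now passed as n (OpenAI-style) and serialized back to "images_num" internally. The style name and prompt below are illustrative.

import asyncio
from g4f.Provider import ARTA

async def main():
    # Parameters mirror the signature in the hunk above; "n" replaces "images_num".
    async for chunk in ARTA.create_async_generator(
        model="flux",                     # illustrative style name
        messages=[{"role": "user", "content": "a lighthouse at dusk"}],
        n=2,                              # previously images_num=2
        guidance_scale=7,
        num_inference_steps=30,
        aspect_ratio="1:1",
    ):
        print(chunk)                      # Reasoning status updates, then the image result

asyncio.run(main())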

View File

@@ -49,7 +49,7 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
# Models configuration
default_model = "openai"
default_image_model = "flux"
default_vision_model = "gpt-4o"
default_vision_model = default_model
text_models = [default_model]
image_models = [default_image_model]
extra_image_models = ["flux-pro", "flux-dev", "flux-schnell", "midjourney", "dall-e-3"]
@@ -141,6 +141,8 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
messages: Messages,
stream: bool = False,
proxy: str = None,
cache: bool = False,
# Image generation parameters
prompt: str = None,
width: int = 1024,
height: int = 1024,
@@ -149,19 +151,18 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
private: bool = False,
enhance: bool = False,
safe: bool = False,
# Text generation parameters
images: ImagesType = None,
temperature: float = None,
presence_penalty: float = None,
top_p: float = 1,
frequency_penalty: float = None,
response_format: Optional[dict] = None,
cache: bool = False,
extra_parameters: list[str] = ["tools", "parallel_tool_calls", "tool_choice", "reasoning_effort", "logit_bias"],
extra_parameters: list[str] = ["tools", "parallel_tool_calls", "tool_choice", "reasoning_effort", "logit_bias", "voice"],
**kwargs
) -> AsyncResult:
# Load model list
cls.get_models()
if images is not None and not model:
model = cls.default_vision_model
try:
model = cls.get_model(model)
except ModelNotFoundError:
@@ -231,7 +232,7 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
}
query = "&".join(f"{k}={quote_plus(v)}" for k, v in params.items() if v is not None)
url = f"{cls.image_api_endpoint}prompt/{quote_plus(prompt)}?{query}"
yield ImagePreview(url, prompt)
#yield ImagePreview(url, prompt)
async with ClientSession(headers=DEFAULT_HEADERS, connector=get_connector(proxy=proxy)) as session:
async with session.get(url, allow_redirects=True) as response:
@@ -276,7 +277,7 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
messages[-1] = last_message
async with ClientSession(headers=DEFAULT_HEADERS, connector=get_connector(proxy=proxy)) as session:
if model in cls.audio_models or stream:
if model in cls.audio_models:
#data["voice"] = random.choice(cls.audio_models[model])
url = cls.text_api_endpoint
stream = False
@@ -328,12 +329,8 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
if "tool_calls" in message:
yield ToolCalls(message["tool_calls"])
if content is not None:
if "</think>" in content and "<think>" not in content:
yield "<think>"
if content:
yield content.replace("\\(", "(").replace("\\)", ")")
yield content
if "usage" in result:
yield Usage(**result["usage"])
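
For orientation, a hedged sketch of how the reworked keywords might be supplied: cache now sits next to stream/proxy in the signature, the image count arrives as n, and "voice" is forwarded through extra_parameters so audio models can be asked for a specific voice. The model id and voice name are illustrative, not values guaranteed by the Pollinations API.

import asyncio
from g4f.Provider import PollinationsAI

async def main():
    async for chunk in PollinationsAI.create_async_generator(
        model="openai-audio",             # illustrative audio model id
        messages=[{"role": "user", "content": "Read this sentence aloud."}],
        cache=False,                      # now a top-level keyword next to stream/proxy
        voice="alloy",                    # passed through extra_parameters (illustrative)
    ):
        print(type(chunk).__name__)

asyncio.run(main())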

View File

@@ -1,13 +1,15 @@
from __future__ import annotations
import requests
from ...providers.types import Messages
from ...typing import ImagesType
from ...requests import StreamSession, raise_for_status
from ...errors import ModelNotSupportedError
from ...providers.helper import get_last_user_message
from ...providers.response import ProviderInfo
from ..template.OpenaiTemplate import OpenaiTemplate
from .models import model_aliases, vision_models, default_vision_model, llama_models
from .HuggingChat import HuggingChat
from .models import model_aliases, vision_models, default_vision_model, llama_models, text_models
from ... import debug
class HuggingFaceAPI(OpenaiTemplate):
@@ -22,32 +24,47 @@ class HuggingFaceAPI(OpenaiTemplate):
default_vision_model = default_vision_model
vision_models = vision_models
model_aliases = model_aliases
fallback_models = text_models + vision_models
pipeline_tags: dict[str, str] = {}
provider_mapping: dict[str, dict] = {}
@classmethod
def get_models(cls, **kwargs):
def get_model(cls, model: str, **kwargs) -> str:
try:
return super().get_model(model, **kwargs)
except ModelNotSupportedError:
return model
@classmethod
def get_models(cls, **kwargs) -> list[str]:
if not cls.models:
HuggingChat.get_models()
cls.models = HuggingChat.text_models.copy()
for model in cls.vision_models:
if model not in cls.models:
cls.models.append(model)
url = "https://huggingface.co/api/models?inference=warm&&expand[]=inferenceProviderMapping"
response = requests.get(url)
if response.ok:
cls.models = [
model["id"]
for model in response.json()
if [
provider
for provider in model.get("inferenceProviderMapping")
if provider.get("task") == "conversational"]]
else:
cls.models = cls.fallback_models
return cls.models
@classmethod
async def get_pipline_tag(cls, model: str, api_key: str = None):
if model in cls.pipeline_tags:
return cls.pipeline_tags[model]
async def get_mapping(cls, model: str, api_key: str = None):
if model in cls.provider_mapping:
return cls.provider_mapping[model]
async with StreamSession(
timeout=30,
headers=cls.get_headers(False, api_key),
) as session:
async with session.get(f"https://huggingface.co/api/models/{model}") as response:
async with session.get(f"https://huggingface.co/api/models/{model}?expand[]=inferenceProviderMapping") as response:
await raise_for_status(response)
model_data = await response.json()
cls.pipeline_tags[model] = model_data.get("pipeline_tag")
return cls.pipeline_tags[model]
cls.provider_mapping[model] = model_data.get("inferenceProviderMapping")
return cls.provider_mapping[model]
@classmethod
async def create_async_generator(
@@ -65,12 +82,16 @@ class HuggingFaceAPI(OpenaiTemplate):
model = llama_models["text"] if images is None else llama_models["vision"]
if model in cls.model_aliases:
model = cls.model_aliases[model]
api_base = f"https://api-inference.huggingface.co/models/{model}/v1"
pipeline_tag = await cls.get_pipline_tag(model, api_key)
if pipeline_tag not in ("text-generation", "image-text-to-text"):
raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__} pipeline_tag: {pipeline_tag}")
elif images and pipeline_tag != "image-text-to-text":
raise ModelNotSupportedError(f"Model does not support images: {model} in: {cls.__name__} pipeline_tag: {pipeline_tag}")
provider_mapping = await cls.get_mapping(model, api_key)
for provider_key in provider_mapping:
api_path = provider_key if provider_key == "novita" else f"{provider_key}/v1"
api_base = f"https://router.huggingface.co/{api_path}"
task = provider_mapping[provider_key]["task"]
if task != "conversational":
raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__} task: {task}")
model = provider_mapping[provider_key]["providerId"]
yield ProviderInfo(**{**cls.get_dict(), "label": f"HuggingFace ({provider_key})"})
break
start = calculate_lenght(messages)
if start > max_inputs_lenght:
if len(messages) > 6:
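
The rewritten get_models keeps only warm models whose inferenceProviderMapping lists a "conversational" task, and create_async_generator now derives its base URL from the provider key returned by get_mapping, with "novita" as the only key that skips the "/v1" suffix. A small standalone sketch of both rules (the provider keys "together" and "novita" are only examples):

import requests

def router_base(provider_key: str) -> str:
    # Mirrors the rule in the hunk above: every provider except "novita" gets "/v1".
    api_path = provider_key if provider_key == "novita" else f"{provider_key}/v1"
    return f"https://router.huggingface.co/{api_path}"

print(router_base("together"))   # https://router.huggingface.co/together/v1
print(router_base("novita"))     # https://router.huggingface.co/novita

# Same filter as get_models: warm models that expose a conversational task.
url = "https://huggingface.co/api/models?inference=warm&expand[]=inferenceProviderMapping"
response = requests.get(url, timeout=30)
if response.ok:
    conversational = [
        model["id"]
        for model in response.json()
        if any(
            provider.get("task") == "conversational"
            for provider in model.get("inferenceProviderMapping", [])
        )
    ]
    print(len(conversational), "conversational models")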

View File

@@ -50,7 +50,7 @@ class ChatCompletion:
result = provider.get_create_function()(model, messages, stream=stream, **kwargs)
return result if stream else concat_chunks(result)
return result if stream or ignore_stream else concat_chunks(result)
@staticmethod
def create_async(model : Union[Model, str],
@@ -74,7 +74,7 @@ class ChatCompletion:
result = provider.get_async_create_function()(model, messages, stream=stream, **kwargs)
if not stream:
if not stream and not ignore_stream:
if hasattr(result, "__aiter__"):
result = async_concat_chunks(result)
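
The two client hunks make ignore_stream skip the chunk concatenation, so a caller can receive the raw generator even with stream=False. A hedged usage sketch, assuming ignore_stream reaches ChatCompletion.create as a keyword the way the hunk suggests:

import g4f

result = g4f.ChatCompletion.create(
    model=g4f.models.default,
    messages=[{"role": "user", "content": "Hello"}],
    stream=False,
    ignore_stream=True,   # chunks are no longer concatenated into one string
)
for chunk in result:      # iterate the raw chunks yourself
    print(chunk, end="", flush=True)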

View File

@@ -18,6 +18,9 @@ class ChatCompletionsConfig(BaseModel):
image_name: Optional[str] = None
images: Optional[list[tuple[str, str]]] = None
temperature: Optional[float] = None
presence_penalty: Optional[float] = None
frequency_penalty: Optional[float] = None
top_p: Optional[float] = None
max_tokens: Optional[int] = None
stop: Union[list[str], str, None] = None
api_key: Optional[str] = None
@@ -27,7 +30,6 @@ class ChatCompletionsConfig(BaseModel):
conversation_id: Optional[str] = None
conversation: Optional[dict] = None
history_disabled: Optional[bool] = None
auto_continue: Optional[bool] = None
timeout: Optional[int] = None
tool_calls: list = Field(default=[], examples=[[
{
@@ -48,6 +50,14 @@ class ImageGenerationConfig(BaseModel):
response_format: Optional[str] = None
api_key: Optional[str] = None
proxy: Optional[str] = None
width: Optional[int] = None
height: Optional[int] = None
num_inference_steps: Optional[int] = None
seed: Optional[int] = None
guidance_scale: Optional[int] = None
aspect_ratio: Optional[str] = None
n: Optional[int] = None
negative_prompt: Optional[str] = None
class ProviderResponseModel(BaseModel):
id: str
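
Taken together, the stub changes expose the new knobs over the HTTP API: chat completions gain presence_penalty, frequency_penalty and top_p, while image generation accepts width, height, num_inference_steps, seed, guidance_scale, aspect_ratio, n and negative_prompt. A hedged example of request payloads; the endpoint paths and the local port are assumptions based on the project's usual API layout, and all values are illustrative.

import requests

BASE = "http://localhost:1337"   # assumed default address of the local API server

chat_payload = {
    "model": "openai",
    "messages": [{"role": "user", "content": "Write a haiku about rain."}],
    "temperature": 0.7,
    "top_p": 0.9,                 # new field
    "presence_penalty": 0.1,      # new field
    "frequency_penalty": 0.2,     # new field
}
print(requests.post(f"{BASE}/v1/chat/completions", json=chat_payload).json())

image_payload = {
    "prompt": "a lighthouse at dusk",
    "model": "flux",
    "width": 1024,                # new field
    "height": 1024,               # new field
    "n": 2,                       # new field
    "guidance_scale": 7,          # new field
    "num_inference_steps": 30,    # new field
    "negative_prompt": "blurry",  # new field
}
print(requests.post(f"{BASE}/v1/images/generate", json=image_payload).json())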

View File

@@ -40,12 +40,12 @@ let parameters_storage = {};
let finish_storage = {};
let usage_storage = {};
let reasoning_storage = {};
let generate_storage = {};
let title_ids_storage = {};
let image_storage = {};
let is_demo = false;
let wakeLock = null;
let countTokensEnabled = true;
let reloadConversation = true;
messageInput.addEventListener("blur", () => {
document.documentElement.scrollTop = 0;
@@ -203,10 +203,12 @@ const highlight = (container) => {
const get_message_el = (el) => {
let message_el = el;
while(!("index" in message_el.dataset) && message_el.parentElement) {
while(!(message_el.classList.contains('message')) && message_el.parentElement) {
message_el = message_el.parentElement;
}
if (message_el.classList.contains('message')) {
return message_el;
}
}
function register_message_images() {
@@ -220,7 +222,7 @@ function register_message_images() {
el.onerror = () => {
let indexCommand;
if ((indexCommand = el.src.indexOf("/generate/")) >= 0) {
generate_storage[window.conversation_id] = true;
reloadConversation = false;
indexCommand = indexCommand + "/generate/".length + 1;
let newPath = el.src.substring(indexCommand)
let filename = newPath.replace(/(?:\?.+?|$)/, "");
@@ -282,22 +284,30 @@ const register_message_buttons = async () => {
});
});
message_box.querySelectorAll(".message .fa-xmark").forEach(async (el) => el.addEventListener("click", async () => {
message_box.querySelectorAll(".message .fa-xmark").forEach(async (el) => {
if (el.dataset.click) {
return
}
el.dataset.click = true;
el.addEventListener("click", async () => {
const message_el = get_message_el(el);
if (message_el) {
if ("index" in message_el.dataset) {
await remove_message(window.conversation_id, message_el.dataset.index);
}
message_el.remove();
}
reloadConversation = true;
await safe_load_conversation(window.conversation_id, false);
}));
});
});
message_box.querySelectorAll(".message .fa-clipboard").forEach(async (el) => el.addEventListener("click", async () => {
message_box.querySelectorAll(".message .fa-clipboard").forEach(async (el) => {
if (el.dataset.click) {
return
}
el.dataset.click = true;
el.addEventListener("click", async () => {
let message_el = get_message_el(el);
let response = await fetch(message_el.dataset.object_url);
let copyText = await response.text();
@@ -313,13 +323,15 @@ const register_message_buttons = async () => {
}
el.classList.add("clicked");
setTimeout(() => el.classList.remove("clicked"), 1000);
}))
});
})
message_box.querySelectorAll(".message .fa-file-export").forEach(async (el) => el.addEventListener("click", async () => {
message_box.querySelectorAll(".message .fa-file-export").forEach(async (el) => {
if (el.dataset.click) {
return
}
el.dataset.click = true;
el.addEventListener("click", async () => {
const elem = window.document.createElement('a');
let filename = `chat ${new Date().toLocaleString()}.txt`.replaceAll(":", "-");
const conversation = await get_conversation(window.conversation_id);
@@ -336,13 +348,15 @@ const register_message_buttons = async () => {
download.click();
el.classList.add("clicked");
setTimeout(() => el.classList.remove("clicked"), 1000);
}))
});
})
message_box.querySelectorAll(".message .fa-volume-high").forEach(async (el) => el.addEventListener("click", async () => {
message_box.querySelectorAll(".message .fa-volume-high").forEach(async (el) => {
if (el.dataset.click) {
return
}
el.dataset.click = true;
el.addEventListener("click", async () => {
const message_el = get_message_el(el);
let audio;
if (message_el.dataset.synthesize_url) {
@@ -361,47 +375,55 @@ const register_message_buttons = async () => {
audio.play();
return;
}
}));
});
});
message_box.querySelectorAll(".message .regenerate_button").forEach(async (el) => el.addEventListener("click", async () => {
message_box.querySelectorAll(".message .regenerate_button").forEach(async (el) => {
if (el.dataset.click) {
return
}
el.dataset.click = true;
el.addEventListener("click", async () => {
const message_el = get_message_el(el);
el.classList.add("clicked");
setTimeout(() => el.classList.remove("clicked"), 1000);
await ask_gpt(get_message_id(), message_el.dataset.index);
}));
});
});
message_box.querySelectorAll(".message .continue_button").forEach(async (el) => el.addEventListener("click", async () => {
message_box.querySelectorAll(".message .continue_button").forEach(async (el) => {
if (el.dataset.click) {
return
}
el.dataset.click = true;
el.addEventListener("click", async () => {
if (!el.disabled) {
el.disabled = true;
const message_el = get_message_el(el);
el.classList.add("clicked");
setTimeout(() => {el.classList.remove("clicked"); el.disabled = false}, 1000);
await ask_gpt(get_message_id(), message_el.dataset.index, false, null, null, "continue");
}}
));
}
});
});
message_box.querySelectorAll(".message .fa-whatsapp").forEach(async (el) => el.addEventListener("click", async () => {
message_box.querySelectorAll(".message .fa-whatsapp").forEach(async (el) => {
if (el.dataset.click) {
return
}
el.dataset.click = true;
el.addEventListener("click", async () => {
const text = get_message_el(el).innerText;
window.open(`https://wa.me/?text=${encodeURIComponent(text)}`, '_blank');
}));
});
});
message_box.querySelectorAll(".message .fa-print").forEach(async (el) => el.addEventListener("click", async () => {
message_box.querySelectorAll(".message .fa-print").forEach(async (el) => {
if (el.dataset.click) {
return
}
el.dataset.click = true;
el.addEventListener("click", async () => {
const message_el = get_message_el(el);
el.classList.add("clicked");
message_box.scrollTop = 0;
@@ -411,18 +433,21 @@ const register_message_buttons = async () => {
message_el.classList.remove("print");
}, 1000);
window.print()
}));
});
});
message_box.querySelectorAll(".message .reasoning_title").forEach(async (el) => el.addEventListener("click", async () => {
message_box.querySelectorAll(".message .reasoning_title").forEach(async (el) => {
if (el.dataset.click) {
return
}
el.dataset.click = true;
el.addEventListener("click", async () => {
let text_el = el.parentElement.querySelector(".reasoning_text");
if (text_el) {
text_el.classList[text_el.classList.contains("hidden") ? "remove" : "add"]("hidden");
text_el.classList.toggle("hidden");
}
}));
});
});
}
const delete_conversations = async () => {
@@ -842,7 +867,7 @@ async function add_message_chunk(message, message_id, provider, scroll, finish_m
audio.controls = true;
content_map.inner.appendChild(audio);
audio.play();
generate_storage[window.conversation_id] = true;
reloadConversation = false;
} else if (message.type == "content") {
message_storage[message_id] += message.content;
update_message(content_map, message_id, null, scroll);
@@ -866,10 +891,11 @@ async function add_message_chunk(message, message_id, provider, scroll, finish_m
} else if (message.type == "reasoning") {
if (!reasoning_storage[message_id]) {
reasoning_storage[message_id] = message;
reasoning_storage[message_id].text = message.token || "";
reasoning_storage[message_id].text = message_storage[message_id];
message_storage[message_id] = "";
} else if (message.status) {
reasoning_storage[message_id].status = message.status;
} else if (message.token) {
} if (message.token) {
reasoning_storage[message_id].text += message.token;
}
update_message(content_map, message_id, render_reasoning(reasoning_storage[message_id]), scroll);
@@ -1039,7 +1065,7 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi
delete controller_storage[message_id];
}
// Reload conversation if no error
if (!error_storage[message_id] && !generate_storage[window.conversation_id]) {
if (!error_storage[message_id] && reloadConversation) {
await safe_load_conversation(window.conversation_id, scroll);
}
let cursorDiv = message_el.querySelector(".cursor");
@@ -1077,18 +1103,17 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi
}
}
const ignored = Array.from(settings.querySelectorAll("input.provider:not(:checked)")).map((el)=>el.value);
let extra_parameters = {};
document.getElementById(`${provider}-form`)?.querySelectorAll(".saved input, .saved textarea").forEach(async (el) => {
let extra_parameters = [];
for (el of document.getElementById(`${provider}-form`)?.querySelectorAll(".saved input, .saved textarea") || []) {
let value = el.type == "checkbox" ? el.checked : el.value;
extra_parameters[el.name] = value;
if (el.type == "textarea") {
try {
extra_parameters[el.name] = await JSON.parse(value);
value = await JSON.parse(value);
} catch (e) {
}
}
});
console.log(extra_parameters);
extra_parameters[el.name] = value;
};
await api("conversation", {
id: message_id,
conversation_id: window.conversation_id,

View File

@@ -5,7 +5,7 @@ import json
import asyncio
import time
from pathlib import Path
from typing import Optional, Callable, AsyncIterator, Dict, Any, Tuple, List, Union
from typing import Optional, Callable, AsyncIterator, Iterator, Dict, Any, Tuple, List, Union
from ..typing import Messages
from ..providers.helper import filter_none
@@ -154,7 +154,7 @@ class ThinkingProcessor:
results = []
# Handle non-thinking chunk
if not start_time and "<think>" not in chunk:
if not start_time and "<think>" not in chunk and "</think>" not in chunk:
return 0, [chunk]
# Handle thinking start
@@ -255,7 +255,7 @@ def iter_run_tools(
provider: Optional[str] = None,
tool_calls: Optional[List[dict]] = None,
**kwargs
) -> AsyncIterator:
) -> Iterator:
"""Run tools synchronously and yield results"""
# Process web search
web_search = kwargs.get('web_search')
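
The ThinkingProcessor change above widens the fast-path check: a chunk that carries a closing </think> without an opening tag is no longer passed through untouched but falls into the thinking handling, which is why PollinationsAI could drop its own yield "<think>" workaround earlier in this commit. A toy illustration of the condition (start_time and chunk are stand-ins for the method's arguments):

def is_plain_chunk(start_time: float, chunk: str) -> bool:
    # New check: a lone closing tag also routes the chunk into thinking handling.
    return not start_time and "<think>" not in chunk and "</think>" not in chunk

print(is_plain_chunk(0, "plain answer"))             # True  -> yielded unchanged
print(is_plain_chunk(0, "reasoning</think>answer"))  # False -> handled as thinking output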