Support TitleGeneration and Reasoning in HuggingChat

Improve model lists in HuggingSpace and PollinationsAI
Fix Image Generation in PollinationsAI
Add Image Upload in PollinationsAI
Support Usage, FinishReason, jsonMode in PollinationsAI
Add Reasoning to Web UI
Fix use of provider api_keys in Web UI
hlohaus
2025-01-23 23:16:12 +01:00
parent 78fa745698
commit cad308108c
15 changed files with 303 additions and 181 deletions

View File

@@ -3,14 +3,23 @@ from __future__ import annotations
import json
import random
import requests
from urllib.parse import quote
from urllib.parse import quote_plus
from typing import Optional
from aiohttp import ClientSession
from .helper import filter_none
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..typing import AsyncResult, Messages, ImagesType
from ..image import to_data_uri
from ..requests.raise_for_status import raise_for_status
from ..typing import AsyncResult, Messages
from ..image import ImageResponse
from ..requests.aiohttp import get_connector
from ..providers.response import ImageResponse, FinishReason, Usage
DEFAULT_HEADERS = {
'Accept': '*/*',
'Accept-Language': 'en-US,en;q=0.9',
'User-Agent': 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/130.0.0.0 Safari/537.36',
}
class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
label = "Pollinations AI"
@@ -21,24 +30,18 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
supports_system_message = True
supports_message_history = True
# API endpoints base
api_base = "https://text.pollinations.ai/openai"
# API endpoints
text_api_endpoint = "https://text.pollinations.ai/"
text_api_endpoint = "https://text.pollinations.ai/openai"
image_api_endpoint = "https://image.pollinations.ai/"
# Models configuration
default_model = "openai"
default_image_model = "flux"
image_models = []
models = []
additional_models_image = ["midjourney", "dall-e-3"]
additional_models_text = ["claude", "karma", "command-r", "llamalight", "mistral-large", "sur", "sur-mistral"]
default_vision_model = "gpt-4o"
extra_image_models = ["midjourney", "dall-e-3"]
vision_models = [default_vision_model, "gpt-4o-mini"]
extra_text_models = [*vision_models, "claude", "karma", "command-r", "llamalight", "mistral-large", "sur", "sur-mistral"]
model_aliases = {
"gpt-4o": default_model,
"qwen-2-72b": "qwen",
"qwen-2.5-coder-32b": "qwen-coder",
"llama-3.3-70b": "llama",
@@ -50,22 +53,17 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
"deepseek-chat": "deepseek",
"llama-3.2-3b": "llamalight",
}
text_models = []
@classmethod
def get_models(cls, **kwargs):
# Initialize model lists if not exists
if not hasattr(cls, 'image_models'):
cls.image_models = []
if not hasattr(cls, 'text_models'):
cls.text_models = []
# Fetch image models if not cached
if not cls.image_models:
url = "https://image.pollinations.ai/models"
response = requests.get(url)
raise_for_status(response)
cls.image_models = response.json()
cls.image_models.extend(cls.additional_models_image)
cls.image_models.extend(cls.extra_image_models)
# Fetch text models if not cached
if not cls.text_models:
@@ -73,7 +71,7 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
response = requests.get(url)
raise_for_status(response)
cls.text_models = [model.get("name") for model in response.json()]
cls.text_models.extend(cls.additional_models_text)
cls.text_models.extend(cls.extra_text_models)
# Return combined models
return cls.text_models + cls.image_models
@@ -94,22 +92,27 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
enhance: bool = False,
safe: bool = False,
# Text specific parameters
temperature: float = 0.5,
presence_penalty: float = 0,
images: ImagesType = None,
temperature: float = None,
presence_penalty: float = None,
top_p: float = 1,
frequency_penalty: float = 0,
stream: bool = False,
frequency_penalty: float = None,
response_format: Optional[dict] = None,
cache: bool = False,
**kwargs
) -> AsyncResult:
if images is not None and not model:
model = cls.default_vision_model
model = cls.get_model(model)
if not cache and seed is None:
seed = random.randint(0, 100000)
# Check if models
# Image generation
if model in cls.image_models:
async for result in cls._generate_image(
yield await cls._generate_image(
model=model,
messages=messages,
prompt=prompt,
prompt=messages[-1]["content"] if prompt is None else prompt,
proxy=proxy,
width=width,
height=height,
@@ -118,19 +121,21 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
private=private,
enhance=enhance,
safe=safe
):
yield result
)
else:
# Text generation
async for result in cls._generate_text(
model=model,
messages=messages,
images=images,
proxy=proxy,
temperature=temperature,
presence_penalty=presence_penalty,
top_p=top_p,
frequency_penalty=frequency_penalty,
stream=stream
response_format=response_format,
seed=seed,
cache=cache,
):
yield result
@@ -138,7 +143,6 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
async def _generate_image(
cls,
model: str,
messages: Messages,
prompt: str,
proxy: str,
width: int,
@@ -148,16 +152,7 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
private: bool,
enhance: bool,
safe: bool
) -> AsyncResult:
if seed is None:
seed = random.randint(0, 10000)
headers = {
'Accept': '*/*',
'Accept-Language': 'en-US,en;q=0.9',
'User-Agent': 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/130.0.0.0 Safari/537.36',
}
) -> ImageResponse:
params = {
"seed": seed,
"width": width,
@@ -168,42 +163,47 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
"enhance": enhance,
"safe": safe
}
params = {k: v for k, v in params.items() if v is not None}
async with ClientSession(headers=headers) as session:
prompt = messages[-1]["content"] if prompt is None else prompt
param_string = "&".join(f"{k}={v}" for k, v in params.items())
url = f"{cls.image_api_endpoint}/prompt/{quote(prompt)}?{param_string}"
async with session.head(url, proxy=proxy) as response:
if response.status == 200:
image_response = ImageResponse(images=url, alt=prompt)
yield image_response
params = {k: json.dumps(v) if isinstance(v, bool) else v for k, v in params.items() if v is not None}
async with ClientSession(headers=DEFAULT_HEADERS, connector=get_connector(proxy=proxy)) as session:
async with session.head(f"{cls.image_api_endpoint}prompt/{quote_plus(prompt)}", params=params) as response:
await raise_for_status(response)
return ImageResponse(str(response.url), prompt)
@classmethod
async def _generate_text(
cls,
model: str,
messages: Messages,
images: Optional[ImagesType],
proxy: str,
temperature: float,
presence_penalty: float,
top_p: float,
frequency_penalty: float,
stream: bool,
seed: Optional[int] = None
response_format: Optional[dict],
seed: Optional[int],
cache: bool
) -> AsyncResult:
headers = {
"accept": "*/*",
"accept-language": "en-US,en;q=0.9",
"content-type": "application/json",
"user-agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36"
jsonMode = False
if response_format is not None and "type" in response_format:
if response_format["type"] == "json_object":
jsonMode = True
if images is not None and messages:
last_message = messages[-1].copy()
last_message["content"] = [
*[{
"type": "image_url",
"image_url": {"url": to_data_uri(image)}
} for image, _ in images],
{
"type": "text",
"text": messages[-1]["content"]
}
]
messages[-1] = last_message
if seed is None:
seed = random.randint(0, 10000)
async with ClientSession(headers=headers) as session:
async with ClientSession(headers=DEFAULT_HEADERS, connector=get_connector(proxy=proxy)) as session:
data = {
"messages": messages,
"model": model,
@@ -211,42 +211,33 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
"presence_penalty": presence_penalty,
"top_p": top_p,
"frequency_penalty": frequency_penalty,
"jsonMode": False,
"stream": stream,
"jsonMode": jsonMode,
"stream": False, # To get more informations like Usage and FinishReason
"seed": seed,
"cache": False
"cache": cache
}
async with session.post(cls.text_api_endpoint, json=data, proxy=proxy) as response:
response.raise_for_status()
async for chunk in response.content:
if chunk:
decoded_chunk = chunk.decode()
# Skip [DONE].
async with session.post(cls.text_api_endpoint, json=filter_none(**data)) as response:
await raise_for_status(response)
async for line in response.content:
decoded_chunk = line.decode(errors="replace")
# Stop at [DONE].
if "data: [DONE]" in decoded_chunk:
continue
# Processing plain text
if not decoded_chunk.startswith("data:"):
clean_text = decoded_chunk.strip()
if clean_text:
yield clean_text
continue
break
# Processing JSON format
try:
# Remove the "data: " prefix and parse JSON
json_str = decoded_chunk.replace("data:", "").strip()
json_response = json.loads(json_str)
if "choices" in json_response and json_response["choices"]:
if "delta" in json_response["choices"][0]:
content = json_response["choices"][0]["delta"].get("content")
if content:
# Remove escaped slashes before parentheses
clean_content = content.replace("\\(", "(").replace("\\)", ")")
yield clean_content
data = json.loads(json_str)
choice = data["choices"][0]
if "usage" in data:
yield Usage(**data["usage"])
if "message" in choice and "content" in choice["message"] and choice["message"]["content"]:
yield choice["message"]["content"].replace("\\(", "(").replace("\\)", ")")
elif "delta" in choice and "content" in choice["delta"] and choice["delta"]["content"]:
yield choice["delta"]["content"].replace("\\(", "(").replace("\\)", ")")
if "finish_reason" in choice and choice["finish_reason"] is not None:
yield FinishReason(choice["finish_reason"])
break
except json.JSONDecodeError:
# If JSON could not be parsed, yield the raw text
yield decoded_chunk.strip()
continue
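
Taken together, the PollinationsAI changes let a caller consume structured response objects instead of a bare text stream. A minimal consumer sketch, assuming the usual g4f import paths (the prompt text and model choice are illustrative):

import asyncio

from g4f.Provider import PollinationsAI
from g4f.providers.response import FinishReason, ImageResponse, Usage

async def main():
    messages = [{"role": "user", "content": "Reply with a JSON object naming three colors."}]
    async for chunk in PollinationsAI.create_async_generator(
        model="openai",
        messages=messages,
        response_format={"type": "json_object"},  # mapped to jsonMode=True above
    ):
        if isinstance(chunk, Usage):
            print("usage:", chunk.get_dict())
        elif isinstance(chunk, FinishReason):
            print("finish:", chunk.reason)
        elif isinstance(chunk, ImageResponse):
            print("image:", chunk)
        else:
            print(chunk, end="")

asyncio.run(main())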

View File

@@ -18,6 +18,7 @@ class Qwen_QVQ_72B(AsyncGeneratorProvider, ProviderModelMixin):
default_model = "qwen-qvq-72b-preview"
models = [default_model]
vision_models = models
model_aliases = {"qwq-32b": default_model}
@classmethod

View File

@@ -33,12 +33,18 @@ class HuggingSpace(AsyncGeneratorProvider, ProviderModelMixin):
def get_models(cls, **kwargs) -> list[str]:
if not cls.models:
models = []
image_models = []
vision_models = []
for provider in cls.providers:
models.extend(provider.get_models(**kwargs))
models.extend(provider.model_aliases.keys())
image_models.extend(provider.image_models)
vision_models.extend(provider.vision_models)
models = list(set(models))
models.sort()
cls.models = models
cls.image_models = list(set(image_models))
cls.vision_models = list(set(vision_models))
return cls.models
@classmethod
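
The model aggregation above is a union-and-dedupe over the child providers; the same pattern as a standalone sketch (the provider objects are stand-ins):

def aggregate_models(providers):
    models, image_models, vision_models = [], [], []
    for provider in providers:
        models.extend(provider.get_models())
        models.extend(provider.model_aliases.keys())
        image_models.extend(provider.image_models)
        vision_models.extend(provider.vision_models)
    # De-duplicate; only the combined text list is sorted, as above.
    return sorted(set(models)), list(set(image_models)), list(set(vision_models))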

View File

@@ -1,6 +1,8 @@
from __future__ import annotations
import json
import re
import requests
try:
from curl_cffi.requests import Session, CurlMime
@@ -13,14 +15,13 @@ from ..helper import format_prompt
from ...typing import CreateResult, Messages, Cookies
from ...errors import MissingRequirementsError
from ...requests.raise_for_status import raise_for_status
from ...providers.response import JsonConversation, ImageResponse, Sources
from ...providers.response import JsonConversation, ImageResponse, Sources, TitleGeneration, Reasoning
from ...cookies import get_cookies
from ... import debug
class Conversation(JsonConversation):
def __init__(self, conversation_id: str, message_id: str):
self.conversation_id: str = conversation_id
self.message_id: str = message_id
def __init__(self, models: dict):
self.models: dict = models
class HuggingChat(AbstractProvider, ProviderModelMixin):
url = "https://huggingface.co/chat"
@@ -32,11 +33,11 @@ class HuggingChat(AbstractProvider, ProviderModelMixin):
default_model = "Qwen/Qwen2.5-72B-Instruct"
default_image_model = "black-forest-labs/FLUX.1-dev"
image_models = [
"black-forest-labs/FLUX.1-dev",
default_image_model,
"black-forest-labs/FLUX.1-schnell",
]
models = [
'Qwen/Qwen2.5-Coder-32B-Instruct',
fallback_models = [
default_model,
'meta-llama/Llama-3.3-70B-Instruct',
'CohereForAI/c4ai-command-r-plus-08-2024',
'Qwen/QwQ-32B-Preview',
@@ -63,12 +64,33 @@ class HuggingChat(AbstractProvider, ProviderModelMixin):
"flux-schnell": "black-forest-labs/FLUX.1-schnell",
}
@classmethod
def get_models(cls):
if not cls.models:
try:
text = requests.get(cls.url).text
text = re.sub(r',parameters:{[^}]+?}', '', text)
text = re.search(r'models:(\[.+?\]),oldModels:', text).group(1)
text = text.replace('void 0', 'null')
def add_quotation_mark(match):
return f'{match.group(1)}"{match.group(2)}":'
text = re.sub(r'([{,])([A-Za-z0-9_]+?):', add_quotation_mark, text)
models = json.loads(text)
cls.text_models = [model["id"] for model in models]
cls.models = cls.text_models + cls.image_models
cls.vision_models = [model["id"] for model in models if model["multimodal"]]
except Exception as e:
debug.log(f"HuggingChat: Error reading models: {type(e).__name__}: {e}")
cls.models = [*cls.fallback_models]
return cls.models
@classmethod
def create_completion(
cls,
model: str,
messages: Messages,
stream: bool,
prompt: str = None,
return_conversation: bool = False,
conversation: Conversation = None,
web_search: bool = False,
@@ -99,22 +121,26 @@ class HuggingChat(AbstractProvider, ProviderModelMixin):
'user-agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/127.0.0.0 Safari/537.36',
}
if conversation is None:
if conversation is None or not hasattr(conversation, "models"):
conversation = Conversation({})
if model not in conversation.models:
conversationId = cls.create_conversation(session, model)
messageId = cls.fetch_message_id(session, conversationId)
conversation = Conversation(conversationId, messageId)
conversation.models[model] = {"conversationId": conversationId, "messageId": messageId}
if return_conversation:
yield conversation
inputs = format_prompt(messages)
else:
conversation.message_id = cls.fetch_message_id(session, conversation.conversation_id)
conversationId = conversation.models[model]["conversationId"]
conversation.models[model]["message_id"] = cls.fetch_message_id(session, conversationId)
inputs = messages[-1]["content"]
debug.log(f"Use conversation: {conversation.conversation_id} Use message: {conversation.message_id}")
debug.log(f"Use model {model}: {json.dumps(conversation.models[model])}")
settings = {
"inputs": inputs,
"id": conversation.message_id,
"id": conversation.models[model]["message_id"],
"is_retry": False,
"is_continue": False,
"web_search": web_search,
@@ -128,7 +154,7 @@ class HuggingChat(AbstractProvider, ProviderModelMixin):
'origin': 'https://huggingface.co',
'pragma': 'no-cache',
'priority': 'u=1, i',
'referer': f'https://huggingface.co/chat/conversation/{conversation.conversation_id}',
'referer': f'https://huggingface.co/chat/conversation/{conversationId}',
'sec-ch-ua': '"Not)A;Brand";v="99", "Google Chrome";v="127", "Chromium";v="127"',
'sec-ch-ua-mobile': '?0',
'sec-ch-ua-platform': '"macOS"',
@@ -142,7 +168,7 @@ class HuggingChat(AbstractProvider, ProviderModelMixin):
data.addpart('data', data=json.dumps(settings, separators=(',', ':')))
response = session.post(
f'https://huggingface.co/chat/conversation/{conversation.conversation_id}',
f'https://huggingface.co/chat/conversation/{conversationId}',
cookies=session.cookies,
headers=headers,
multipart=data,
@@ -170,10 +196,17 @@ class HuggingChat(AbstractProvider, ProviderModelMixin):
elif line["type"] == "finalAnswer":
break
elif line["type"] == "file":
url = f"https://huggingface.co/chat/conversation/{conversation.conversation_id}/output/{line['sha']}"
yield ImageResponse(url, alt=messages[-1]["content"], options={"cookies": cookies})
url = f"https://huggingface.co/chat/conversation/{conversationId}/output/{line['sha']}"
prompt = messages[-1]["content"] if prompt is None else prompt
yield ImageResponse(url, alt=prompt, options={"cookies": cookies})
elif line["type"] == "webSearch" and "sources" in line:
sources = Sources(line["sources"])
elif line["type"] == "title":
yield TitleGeneration(line["title"])
elif line["type"] == "reasoning":
yield Reasoning(line.get("token"), line.get("status"))
else:
pass #print(line)
full_response = full_response.replace('<|im_end|', '').strip()
if not stream:
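
The Conversation refactor above replaces the single (conversation_id, message_id) pair with a per-model mapping, so one Conversation object survives a model switch. The resulting state looks roughly like this (the IDs are hypothetical):

conversation = Conversation({})
# After the first request with a given model, create_completion stores:
conversation.models["Qwen/Qwen2.5-72B-Instruct"] = {
    "conversationId": "65f0...",  # from create_conversation (hypothetical value)
    "messageId": "a1b2...",       # from fetch_message_id (hypothetical value)
}
# A later request with another model gets its own entry instead of
# overwriting this one.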

View File

@@ -143,7 +143,7 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
else:
is_special = True
debug.log(f"Special token: {is_special}")
yield FinishReason("stop" if is_special else "length", actions=["variant"] if is_special else ["continue", "variant"])
yield FinishReason("stop" if is_special else "length")
else:
if response.headers["content-type"].startswith("image/"):
base64_data = base64.b64encode(b"".join([chunk async for chunk in response.iter_content()]))

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
from .OpenaiAPI import OpenaiAPI
from .HuggingChat import HuggingChat
from ...providers.types import Messages
class HuggingFaceAPI(OpenaiAPI):
label = "HuggingFace (Inference API)"
@@ -11,6 +12,23 @@ class HuggingFaceAPI(OpenaiAPI):
working = True
default_model = "meta-llama/Llama-3.2-11B-Vision-Instruct"
default_vision_model = default_model
models = [
*HuggingChat.models
]
@classmethod
def get_models(cls, **kwargs):
HuggingChat.get_models()
cls.models = HuggingChat.text_models
cls.vision_models = HuggingChat.vision_models
return cls.models
@classmethod
async def create_async_generator(
cls,
model: str,
messages: Messages,
api_base: str = None,
**kwargs
):
if api_base is None:
api_base = f"https://api-inference.huggingface.co/models/{model}/v1"
async for chunk in super().create_async_generator(model, messages, api_base=api_base, **kwargs):
yield chunk
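
The api_base fallback above routes each model to its own Inference API endpoint, for example (sketch):

model = "meta-llama/Llama-3.2-11B-Vision-Instruct"
api_base = f"https://api-inference.huggingface.co/models/{model}/v1"
# -> https://api-inference.huggingface.co/models/meta-llama/Llama-3.2-11B-Vision-Instruct/v1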

View File

@@ -73,10 +73,11 @@ class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin):
raise MissingAuthError('Add a "api_key"')
if api_base is None:
api_base = cls.api_base
if images is not None:
if images is not None and messages:
if not model and hasattr(cls, "default_vision_model"):
model = cls.default_vision_model
messages[-1]["content"] = [
last_message = messages[-1].copy()
last_message["content"] = [
*[{
"type": "image_url",
"image_url": {"url": to_data_uri(image)}
@@ -86,6 +87,7 @@ class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin):
"text": messages[-1]["content"]
}
]
messages[-1] = last_message
async with StreamSession(
proxy=proxy,
headers=cls.get_headers(stream, api_key, headers),
@@ -117,9 +119,9 @@ class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin):
yield ToolCalls(choice["message"]["tool_calls"])
if "usage" in data:
yield Usage(**data["usage"])
finish = cls.read_finish_reason(choice)
if finish is not None:
yield finish
if "finish_reason" in choice and choice["finish_reason"] is not None:
yield FinishReason(choice["finish_reason"])
return
else:
first = True
async for line in response.iter_lines():
@@ -137,15 +139,9 @@ class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin):
if delta:
first = False
yield delta
finish = cls.read_finish_reason(choice)
if finish is not None:
yield finish
break
@staticmethod
def read_finish_reason(choice: dict) -> Optional[FinishReason]:
if "finish_reason" in choice and choice["finish_reason"] is not None:
return FinishReason(choice["finish_reason"])
yield FinishReason(choice["finish_reason"])
break
@classmethod
def get_headers(cls, stream: bool, api_key: str = None, headers: dict = None) -> dict:

View File

@@ -495,8 +495,7 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
"headers": cls._headers,
"web_search": web_search,
})
actions = ["variant", "continue"] if conversation.finish_reason == "max_tokens" else ["variant"]
yield FinishReason(conversation.finish_reason, actions=actions)
yield FinishReason(conversation.finish_reason)
@classmethod
async def iter_messages_line(cls, session: StreamSession, line: bytes, fields: Conversation, sources: Sources) -> AsyncIterator:

View File

@@ -376,6 +376,29 @@ body:not(.white) a:visited{
display: flex;
}
.message .reasoning_text.final:not(.hidden), .message .reasoning_title {
margin-bottom: var(--inner-gap);
padding-bottom: var(--inner-gap);
border-bottom: 1px solid var(--colour-3);
overflow: hidden;
}
.message .reasoning_text.final {
max-height: 1000px;
transition: max-height 0.25s ease-in;
}
.message .reasoning_text.final.hidden {
transition: max-height 0.15s ease-out;
max-height: 0;
display: block;
overflow: hidden;
}
.message .reasoning_title {
cursor: pointer;
}
.message .user i {
position: absolute;
bottom: -6px;

View File

@@ -35,6 +35,7 @@ let title_storage = {};
let parameters_storage = {};
let finish_storage = {};
let usage_storage = {};
let reasoning_storage = {};
messageInput.addEventListener("blur", () => {
window.scrollTo(0, 0);
@@ -70,6 +71,17 @@ if (window.markdownit) {
}
}
function render_reasoning(reasoning, final = false) {
return `<div class="reasoning_body">
<div class="reasoning_title">
<strong>Reasoning <i class="fa-solid fa-brain"></i>:</strong> ${escapeHtml(reasoning.status)}
</div>
<div class="reasoning_text${final ? " final hidden" : ""}">
${markdown_render(reasoning.text)}
</div>
</div>`;
}
function filter_message(text) {
return text.replaceAll(
/<!-- generated images start -->[\s\S]+<!-- generated images end -->/gm, ""
@@ -169,7 +181,7 @@ const get_message_el = (el) => {
}
const register_message_buttons = async () => {
document.querySelectorAll(".message .content .provider").forEach(async (el) => {
message_box.querySelectorAll(".message .content .provider").forEach(async (el) => {
if (!("click" in el.dataset)) {
el.dataset.click = "true";
const provider_forms = document.querySelector(".provider_forms");
@@ -192,7 +204,7 @@ const register_message_buttons = async () => {
}
});
document.querySelectorAll(".message .fa-xmark").forEach(async (el) => {
message_box.querySelectorAll(".message .fa-xmark").forEach(async (el) => {
if (!("click" in el.dataset)) {
el.dataset.click = "true";
el.addEventListener("click", async () => {
@@ -203,7 +215,7 @@ const register_message_buttons = async () => {
}
});
document.querySelectorAll(".message .fa-clipboard").forEach(async (el) => {
message_box.querySelectorAll(".message .fa-clipboard").forEach(async (el) => {
if (!("click" in el.dataset)) {
el.dataset.click = "true";
el.addEventListener("click", async () => {
@@ -226,7 +238,7 @@ const register_message_buttons = async () => {
}
});
document.querySelectorAll(".message .fa-file-export").forEach(async (el) => {
message_box.querySelectorAll(".message .fa-file-export").forEach(async (el) => {
if (!("click" in el.dataset)) {
el.dataset.click = "true";
el.addEventListener("click", async () => {
@@ -244,7 +256,7 @@ const register_message_buttons = async () => {
}
});
document.querySelectorAll(".message .fa-volume-high").forEach(async (el) => {
message_box.querySelectorAll(".message .fa-volume-high").forEach(async (el) => {
if (!("click" in el.dataset)) {
el.dataset.click = "true";
el.addEventListener("click", async () => {
@@ -270,7 +282,7 @@ const register_message_buttons = async () => {
}
});
document.querySelectorAll(".message .regenerate_button").forEach(async (el) => {
message_box.querySelectorAll(".message .regenerate_button").forEach(async (el) => {
if (!("click" in el.dataset)) {
el.dataset.click = "true";
el.addEventListener("click", async () => {
@@ -282,7 +294,7 @@ const register_message_buttons = async () => {
}
});
document.querySelectorAll(".message .continue_button").forEach(async (el) => {
message_box.querySelectorAll(".message .continue_button").forEach(async (el) => {
if (!("click" in el.dataset)) {
el.dataset.click = "true";
el.addEventListener("click", async () => {
@@ -297,7 +309,7 @@ const register_message_buttons = async () => {
}
});
document.querySelectorAll(".message .fa-whatsapp").forEach(async (el) => {
message_box.querySelectorAll(".message .fa-whatsapp").forEach(async (el) => {
if (!("click" in el.dataset)) {
el.dataset.click = "true";
el.addEventListener("click", async () => {
@@ -307,7 +319,7 @@ const register_message_buttons = async () => {
}
});
document.querySelectorAll(".message .fa-print").forEach(async (el) => {
message_box.querySelectorAll(".message .fa-print").forEach(async (el) => {
if (!("click" in el.dataset)) {
el.dataset.click = "true";
el.addEventListener("click", async () => {
@@ -323,6 +335,16 @@ const register_message_buttons = async () => {
})
}
});
message_box.querySelectorAll(".message .reasoning_title").forEach(async (el) => {
if (!("click" in el.dataset)) {
el.dataset.click = "true";
el.addEventListener("click", async () => {
let text_el = el.parentElement.querySelector(".reasoning_text");
text_el.classList[text_el.classList.contains("hidden") ? "remove" : "add"]("hidden");
})
}
});
}
const delete_conversations = async () => {
@@ -469,7 +491,7 @@ const prepare_messages = (messages, message_index = -1, do_continue = false, do_
messages.forEach((message) => {
message_copy = { ...message };
if (last_message) {
if (last_message["role"] == message["role"]) {
if (last_message["role"] == message["role"] && message["role"] == "assistant") {
message_copy["content"] = last_message["content"] + message_copy["content"];
new_messages.pop();
}
@@ -515,6 +537,7 @@ const prepare_messages = (messages, message_index = -1, do_continue = false, do_
delete new_message.synthesize;
delete new_message.finish;
delete new_message.usage;
delete new_message.reasoning;
delete new_message.conversation;
delete new_message.continue;
// Append message to new messages
@@ -711,11 +734,21 @@ async function add_message_chunk(message, message_id, provider, scroll) {
} else if (message.type == "title") {
title_storage[message_id] = message.title;
} else if (message.type == "login") {
update_message(content_map, message_id, message.login, scroll);
update_message(content_map, message_id, markdown_render(message.login), scroll);
} else if (message.type == "finish") {
finish_storage[message_id] = message.finish;
} else if (message.type == "usage") {
usage_storage[message_id] = message.usage;
} else if (message.type == "reasoning") {
if (!reasoning_storage[message_id]) {
reasoning_storage[message_id] = message;
reasoning_storage[message_id].text = "";
} else if (message.status) {
reasoning_storage[message_id].status = message.status;
} else if (message.token) {
reasoning_storage[message_id].text += message.token;
}
update_message(content_map, message_id, render_reasoning(reasoning_storage[message_id]), scroll);
} else if (message.type == "parameters") {
if (!parameters_storage[provider]) {
parameters_storage[provider] = {};
@@ -846,6 +879,7 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi
title_storage[message_id],
finish_storage[message_id],
usage_storage[message_id],
reasoning_storage[message_id],
action=="continue"
);
delete controller_storage[message_id];
@@ -1042,6 +1076,7 @@ function merge_messages(message1, message2) {
const load_conversation = async (conversation_id, scroll=true) => {
let conversation = await get_conversation(conversation_id);
let messages = conversation?.items || [];
console.debug("Conversation:", conversation)
if (!conversation) {
return;
@@ -1098,11 +1133,8 @@ const load_conversation = async (conversation_id, scroll=true) => {
let add_buttons = [];
// Find buttons to add
actions = ["variant"]
if (item.finish && item.finish.actions) {
actions = item.finish.actions
}
// Add continue button if possible
if (item.role == "assistant" && !actions.includes("continue")) {
if (item.role == "assistant") {
let reason = "stop";
// Read finish reason from conversation
if (item.finish && item.finish.reason) {
@@ -1167,7 +1199,10 @@ const load_conversation = async (conversation_id, scroll=true) => {
</div>
<div class="content">
${provider}
<div class="content_inner">${markdown_render(buffer)}</div>
<div class="content_inner">
${item.reasoning ? render_reasoning(item.reasoning, true): ""}
${markdown_render(buffer)}
</div>
<div class="count">
${count_words_and_tokens(buffer, next_provider?.model, completion_tokens, prompt_tokens)}
${add_buttons.join("")}
@@ -1298,6 +1333,7 @@ const add_message = async (
title = null,
finish = null,
usage = null,
reasoning = null,
do_continue = false
) => {
const conversation = await get_conversation(conversation_id);
@@ -1329,6 +1365,9 @@ const add_message = async (
if (usage) {
new_message.usage = usage;
}
if (reasoning) {
new_message.reasoning = reasoning;
}
if (do_continue) {
new_message.continue = true;
}
@@ -1604,23 +1643,24 @@ function count_words_and_tokens(text, model, completion_tokens, prompt_tokens) {
function update_message(content_map, message_id, content = null, scroll = true) {
content_map.update_timeouts.push(setTimeout(() => {
if (!content) content = message_storage[message_id];
html = markdown_render(content);
if (!content) {
content = markdown_render(message_storage[message_id]);
let lastElement, lastIndex = null;
for (element of ['</p>', '</code></pre>', '</p>\n</li>\n</ol>', '</li>\n</ol>', '</li>\n</ul>']) {
const index = html.lastIndexOf(element)
const index = content.lastIndexOf(element)
if (index - element.length > lastIndex) {
lastElement = element;
lastIndex = index;
}
}
if (lastIndex) {
html = html.substring(0, lastIndex) + '<span class="cursor"></span>' + lastElement;
content = content.substring(0, lastIndex) + '<span class="cursor"></span>' + lastElement;
}
}
content_map.inner.innerHTML = content;
if (error_storage[message_id]) {
content_map.inner.innerHTML += markdown_render(`**An error occurred:** ${error_storage[message_id]}`);
}
content_map.inner.innerHTML = html;
content_map.count.innerText = count_words_and_tokens(message_storage[message_id], provider_storage[message_id]?.model);
highlight(content_map.inner);
if (scroll) {
@@ -2132,9 +2172,9 @@ async function read_response(response, message_id, provider, scroll) {
function get_api_key_by_provider(provider) {
let api_key = null;
if (provider) {
api_key = document.getElementById(`${provider}-api_key`)?.id || null;
if (api_key == null) {
api_key = document.querySelector(`.${provider}-api_key`)?.id || null;
if (api_key == null) {
api_key = document.getElementById(`${provider}-api_key`)?.id || null;
}
if (api_key) {
api_key = appStorage.getItem(api_key);
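
The reasoning handling in add_message_chunk above is a small reducer: the first chunk creates the entry, a status-only chunk updates the header, and a token chunk appends to the text. The same rule as a Python sketch (names are illustrative, not part of the codebase):

def accumulate_reasoning(storage, message_id, message):
    if message_id not in storage:
        # First chunk initialises the entry; text starts empty.
        storage[message_id] = {"status": message.get("status"), "text": ""}
    elif message.get("status"):
        # Status-only chunks update the header line.
        storage[message_id]["status"] = message["status"]
    elif message.get("token"):
        # Token chunks append to the reasoning text.
        storage[message_id]["text"] += message["token"]
    return storage[message_id]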

View File

@@ -13,7 +13,7 @@ from ...tools.run_tools import iter_run_tools
from ...Provider import ProviderUtils, __providers__
from ...providers.base_provider import ProviderModelMixin
from ...providers.retry_provider import IterListProvider
from ...providers.response import BaseConversation, JsonConversation, FinishReason, Usage
from ...providers.response import BaseConversation, JsonConversation, FinishReason, Usage, Reasoning
from ...providers.response import SynthesizeData, TitleGeneration, RequestLogin, Parameters
from ... import version, models
from ... import ChatCompletion, get_model_and_provider
@@ -207,6 +207,8 @@ class Api:
yield self._format_json("finish", chunk.get_dict())
elif isinstance(chunk, Usage):
yield self._format_json("usage", chunk.get_dict())
elif isinstance(chunk, Reasoning):
yield self._format_json("reasoning", token=chunk.token, status=chunk.status)
else:
yield self._format_json("content", str(chunk))
if debug.logs:
@@ -219,10 +221,15 @@ class Api:
if first:
yield self.handle_provider(provider_handler, model)
def _format_json(self, response_type: str, content):
def _format_json(self, response_type: str, content = None, **kwargs):
if content is not None:
return {
'type': response_type,
response_type: content
response_type: content,
}
return {
'type': response_type,
**kwargs
}
def handle_provider(self, provider_handler, model):
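
_format_json now accepts keyword-style events alongside the original content-style ones; the reasoning event is the first user of the new form. A self-contained mirror of the logic, for illustration only:

def format_json(response_type, content=None, **kwargs):
    # Same branching as Api._format_json above.
    if content is not None:
        return {"type": response_type, response_type: content}
    return {"type": response_type, **kwargs}

assert format_json("title", "My chat") == {"type": "title", "title": "My chat"}
assert format_json("reasoning", token="Hmm", status="Thinking") == {
    "type": "reasoning", "token": "Hmm", "status": "Thinking",
}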

View File

@@ -309,7 +309,7 @@ class Backend_Api(Api):
return "Provider not found", 404
return models
def _format_json(self, response_type: str, content) -> str:
def _format_json(self, response_type: str, content = None, **kwargs) -> str:
"""
Formats and returns a JSON response.
@@ -320,4 +320,4 @@ class Backend_Api(Api):
Returns:
str: A JSON formatted string.
"""
return json.dumps(super()._format_json(response_type, content)) + "\n"
return json.dumps(super()._format_json(response_type, content, **kwargs)) + "\n"

View File

@@ -340,7 +340,8 @@ class ProviderModelMixin:
default_model: str = None
models: list[str] = []
model_aliases: dict[str, str] = {}
image_models: list = None
image_models: list = []
vision_models: list = []
last_model: str = None
@classmethod

View File

@@ -89,9 +89,8 @@ class JsonMixin:
self.__dict__ = {}
class FinishReason(ResponseType, JsonMixin):
def __init__(self, reason: str, actions: list[str] = None) -> None:
def __init__(self, reason: str) -> None:
self.reason = reason
self.actions = actions
def __str__(self) -> str:
return ""
@@ -121,6 +120,14 @@ class TitleGeneration(ResponseType):
def __str__(self) -> str:
return ""
class Reasoning(ResponseType):
def __init__(self, token: str = None, status: str = None) -> None:
self.token = token
self.status = status
def __str__(self) -> str:
return "" if self.token is None else self.token
class Sources(ResponseType):
def __init__(self, sources: list[dict[str, str]]) -> None:
self.list = []
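
How the new Reasoning type renders when coerced to text, mirroring the definition above (sketch):

r = Reasoning(token="Checking the docs...", status="Thinking")
assert str(r) == "Checking the docs..."    # token chunks stream as plain text
assert str(Reasoning(status="Done")) == ""  # status-only updates render nothing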

View File

@@ -25,7 +25,7 @@ async def raise_for_status_async(response: Union[StreamResponse, ClientResponse]
return
text = await response.text()
if message is None:
message = "HTML content" if response.headers.get("content-type").startswith("text/html") else text
message = "HTML content" if response.headers.get("content-type", "").startswith("text/html") else text
if message == "HTML content":
if response.status == 520:
message = "Unknown error (Cloudflare)"
@@ -46,7 +46,7 @@ def raise_for_status(response: Union[Response, StreamResponse, ClientResponse, R
if response.ok:
return
if message is None:
message = "HTML content" if response.headers.get("content-type").startswith("text/html") else response.text
message = "HTML content" if response.headers.get("content-type", "").startswith("text/html") else response.text
if message == "HTML content":
if response.status_code == 520:
message = "Unknown error (Cloudflare)"