Support TitleGeneration and Reasoning in HuggingChat

Improve model lists in HuggingSpace and PollinationsAI
Fix Image Generation in PollinationsAI
Add Image Upload in PollinationsAI
Support Usage, FinishReason and jsonMode in PollinationsAI
Add Reasoning to the Web UI
Fix use of provider api_keys in the Web UI
hlohaus
2025-01-23 23:16:12 +01:00
parent 78fa745698
commit cad308108c
15 changed files with 303 additions and 181 deletions
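
Reviewer note: the provider changes below can be smoke-tested together. A minimal sketch, assuming the public g4f import paths and keyword names taken from the diffs in this commit (not a stable API):

    # Sketch: exercising jsonMode, Usage and FinishReason in PollinationsAI.
    import asyncio
    from g4f.Provider import PollinationsAI
    from g4f.providers.response import Usage, FinishReason

    async def main():
        async for chunk in PollinationsAI.create_async_generator(
            model="openai",
            messages=[{"role": "user", "content": 'Answer as JSON: {"ok": true}'}],
            response_format={"type": "json_object"},  # maps to jsonMode=True below
        ):
            if isinstance(chunk, Usage):
                print("usage:", chunk.get_dict())   # token counts from the API
            elif isinstance(chunk, FinishReason):
                print("finish:", chunk.reason)      # e.g. "stop"
            else:
                print("text:", chunk)               # plain content string

    asyncio.run(main())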

View File

@@ -3,42 +3,45 @@ from __future__ import annotations

 import json
 import random
 import requests
-from urllib.parse import quote
+from urllib.parse import quote_plus
 from typing import Optional
 from aiohttp import ClientSession

+from .helper import filter_none
 from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
+from ..typing import AsyncResult, Messages, ImagesType
+from ..image import to_data_uri
 from ..requests.raise_for_status import raise_for_status
-from ..typing import AsyncResult, Messages
-from ..image import ImageResponse
+from ..requests.aiohttp import get_connector
+from ..providers.response import ImageResponse, FinishReason, Usage

+DEFAULT_HEADERS = {
+    'Accept': '*/*',
+    'Accept-Language': 'en-US,en;q=0.9',
+    'User-Agent': 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/130.0.0.0 Safari/537.36',
+}
+
 class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
     label = "Pollinations AI"
     url = "https://pollinations.ai"

     working = True
     supports_stream = False
     supports_system_message = True
     supports_message_history = True

+    # API endpoints base
+    api_base = "https://text.pollinations.ai/openai"
+
     # API endpoints
-    text_api_endpoint = "https://text.pollinations.ai/"
+    text_api_endpoint = "https://text.pollinations.ai/openai"
     image_api_endpoint = "https://image.pollinations.ai/"

     # Models configuration
     default_model = "openai"
     default_image_model = "flux"
-    image_models = []
-    models = []
-    additional_models_image = ["midjourney", "dall-e-3"]
-    additional_models_text = ["claude", "karma", "command-r", "llamalight", "mistral-large", "sur", "sur-mistral"]
+    default_vision_model = "gpt-4o"
+    extra_image_models = ["midjourney", "dall-e-3"]
+    vision_models = [default_vision_model, "gpt-4o-mini"]
+    extra_text_models = [*vision_models, "claude", "karma", "command-r", "llamalight", "mistral-large", "sur", "sur-mistral"]
     model_aliases = {
-        "gpt-4o": default_model,
         "qwen-2-72b": "qwen",
         "qwen-2.5-coder-32b": "qwen-coder",
        "llama-3.3-70b": "llama",
@@ -50,22 +53,17 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
         "deepseek-chat": "deepseek",
         "llama-3.2-3b": "llamalight",
     }
+    text_models = []

     @classmethod
     def get_models(cls, **kwargs):
-        # Initialize model lists if not exists
-        if not hasattr(cls, 'image_models'):
-            cls.image_models = []
-        if not hasattr(cls, 'text_models'):
-            cls.text_models = []
-
         # Fetch image models if not cached
         if not cls.image_models:
             url = "https://image.pollinations.ai/models"
             response = requests.get(url)
             raise_for_status(response)
             cls.image_models = response.json()
-            cls.image_models.extend(cls.additional_models_image)
+            cls.image_models.extend(cls.extra_image_models)

         # Fetch text models if not cached
         if not cls.text_models:
@@ -73,7 +71,7 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
             response = requests.get(url)
             raise_for_status(response)
             cls.text_models = [model.get("name") for model in response.json()]
-            cls.text_models.extend(cls.additional_models_text)
+            cls.text_models.extend(cls.extra_text_models)

         # Return combined models
         return cls.text_models + cls.image_models
@@ -94,22 +92,27 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
         enhance: bool = False,
         safe: bool = False,
         # Text specific parameters
-        temperature: float = 0.5,
-        presence_penalty: float = 0,
+        images: ImagesType = None,
+        temperature: float = None,
+        presence_penalty: float = None,
         top_p: float = 1,
-        frequency_penalty: float = 0,
-        stream: bool = False,
+        frequency_penalty: float = None,
+        response_format: Optional[dict] = None,
+        cache: bool = False,
         **kwargs
     ) -> AsyncResult:
+        if images is not None and not model:
+            model = cls.default_vision_model
         model = cls.get_model(model)
+        if not cache and seed is None:
+            seed = random.randint(0, 100000)

-        # Check if models
         # Image generation
         if model in cls.image_models:
-            async for result in cls._generate_image(
+            yield await cls._generate_image(
                 model=model,
-                messages=messages,
-                prompt=prompt,
+                prompt=messages[-1]["content"] if prompt is None else prompt,
                 proxy=proxy,
                 width=width,
                 height=height,
@@ -118,19 +121,21 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
                 private=private,
                 enhance=enhance,
                 safe=safe
-            ):
-                yield result
+            )
         else:
             # Text generation
             async for result in cls._generate_text(
                 model=model,
                 messages=messages,
+                images=images,
                 proxy=proxy,
                 temperature=temperature,
                 presence_penalty=presence_penalty,
                 top_p=top_p,
                 frequency_penalty=frequency_penalty,
-                stream=stream
+                response_format=response_format,
+                seed=seed,
+                cache=cache,
             ):
                 yield result
@@ -138,7 +143,6 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
     async def _generate_image(
         cls,
         model: str,
-        messages: Messages,
         prompt: str,
         proxy: str,
         width: int,
@@ -148,16 +152,7 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
         private: bool,
         enhance: bool,
         safe: bool
-    ) -> AsyncResult:
-        if seed is None:
-            seed = random.randint(0, 10000)
-
-        headers = {
-            'Accept': '*/*',
-            'Accept-Language': 'en-US,en;q=0.9',
-            'User-Agent': 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/130.0.0.0 Safari/537.36',
-        }
-
+    ) -> ImageResponse:
         params = {
             "seed": seed,
             "width": width,
@@ -168,42 +163,47 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
             "enhance": enhance,
             "safe": safe
         }
-        params = {k: v for k, v in params.items() if v is not None}
-        async with ClientSession(headers=headers) as session:
-            prompt = messages[-1]["content"] if prompt is None else prompt
-            param_string = "&".join(f"{k}={v}" for k, v in params.items())
-            url = f"{cls.image_api_endpoint}/prompt/{quote(prompt)}?{param_string}"
-
-            async with session.head(url, proxy=proxy) as response:
-                if response.status == 200:
-                    image_response = ImageResponse(images=url, alt=prompt)
-                    yield image_response
+        params = {k: json.dumps(v) if isinstance(v, bool) else v for k, v in params.items() if v is not None}
+        async with ClientSession(headers=DEFAULT_HEADERS, connector=get_connector(proxy=proxy)) as session:
+            async with session.head(f"{cls.image_api_endpoint}prompt/{quote_plus(prompt)}", params=params) as response:
+                await raise_for_status(response)
+                return ImageResponse(str(response.url), prompt)

     @classmethod
     async def _generate_text(
         cls,
         model: str,
         messages: Messages,
+        images: Optional[ImagesType],
         proxy: str,
         temperature: float,
         presence_penalty: float,
         top_p: float,
         frequency_penalty: float,
-        stream: bool,
-        seed: Optional[int] = None
+        response_format: Optional[dict],
+        seed: Optional[int],
+        cache: bool
     ) -> AsyncResult:
-        headers = {
-            "accept": "*/*",
-            "accept-language": "en-US,en;q=0.9",
-            "content-type": "application/json",
-            "user-agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36"
-        }
-
-        if seed is None:
-            seed = random.randint(0, 10000)
-
-        async with ClientSession(headers=headers) as session:
+        jsonMode = False
+        if response_format is not None and "type" in response_format:
+            if response_format["type"] == "json_object":
+                jsonMode = True
+
+        if images is not None and messages:
+            last_message = messages[-1].copy()
+            last_message["content"] = [
+                *[{
+                    "type": "image_url",
+                    "image_url": {"url": to_data_uri(image)}
+                } for image, _ in images],
+                {
+                    "type": "text",
+                    "text": messages[-1]["content"]
+                }
+            ]
+            messages[-1] = last_message
+
+        async with ClientSession(headers=DEFAULT_HEADERS, connector=get_connector(proxy=proxy)) as session:
             data = {
                 "messages": messages,
                 "model": model,
@@ -211,42 +211,33 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
                 "presence_penalty": presence_penalty,
                 "top_p": top_p,
                 "frequency_penalty": frequency_penalty,
-                "jsonMode": False,
-                "stream": stream,
+                "jsonMode": jsonMode,
+                "stream": False, # To get more information like Usage and FinishReason
                 "seed": seed,
-                "cache": False
+                "cache": cache
             }
-            async with session.post(cls.text_api_endpoint, json=data, proxy=proxy) as response:
-                response.raise_for_status()
-                async for chunk in response.content:
-                    if chunk:
-                        decoded_chunk = chunk.decode()
-
-                        # Skip [DONE].
-                        if "data: [DONE]" in decoded_chunk:
-                            continue
-
-                        # Processing plain text
-                        if not decoded_chunk.startswith("data:"):
-                            clean_text = decoded_chunk.strip()
-                            if clean_text:
-                                yield clean_text
-                            continue
-
-                        # Processing JSON format
-                        try:
-                            # Remove the prefix "data: " and parse JSON
-                            json_str = decoded_chunk.replace("data:", "").strip()
-                            json_response = json.loads(json_str)
-                            if "choices" in json_response and json_response["choices"]:
-                                if "delta" in json_response["choices"][0]:
-                                    content = json_response["choices"][0]["delta"].get("content")
-                                    if content:
-                                        # Remove escaped slashes before parentheses
-                                        clean_content = content.replace("\\(", "(").replace("\\)", ")")
-                                        yield clean_content
-                        except json.JSONDecodeError:
-                            # If JSON could not be parsed, skip
-                            continue
+            async with session.post(cls.text_api_endpoint, json=filter_none(**data)) as response:
+                await raise_for_status(response)
+                async for line in response.content:
+                    decoded_chunk = line.decode(errors="replace")
+                    # If [DONE].
+                    if "data: [DONE]" in decoded_chunk:
+                        break
+                    # Processing JSON format
+                    try:
+                        # Remove the prefix "data: " and parse JSON
+                        json_str = decoded_chunk.replace("data:", "").strip()
+                        data = json.loads(json_str)
+                        choice = data["choices"][0]
+                        if "usage" in data:
+                            yield Usage(**data["usage"])
+                        if "message" in choice and "content" in choice["message"] and choice["message"]["content"]:
+                            yield choice["message"]["content"].replace("\\(", "(").replace("\\)", ")")
+                        elif "delta" in choice and "content" in choice["delta"] and choice["delta"]["content"]:
+                            yield choice["delta"]["content"].replace("\\(", "(").replace("\\)", ")")
                        if "finish_reason" in choice and choice["finish_reason"] is not None:
+                            yield FinishReason(choice["finish_reason"])
+                            break
                    except json.JSONDecodeError:
+                        yield decoded_chunk.strip()
+                        continue
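
Reviewer note on the two rewritten paths above: `_generate_image` now issues a single HEAD request and returns one `ImageResponse`, while passing `images` reroutes to `default_vision_model` and rewrites the last message into `image_url` + `text` parts. A rough usage sketch (keyword names taken from the diff; the `(data, name)` tuple layout of `ImagesType` is an assumption):

    # Sketch: image generation and image upload against the patched provider.
    import asyncio
    from g4f.Provider import PollinationsAI

    async def main():
        # Image generation: "flux" is in image_models, so this yields a single
        # ImageResponse whose URL carries the prompt and query parameters.
        async for chunk in PollinationsAI.create_async_generator(
            model="flux",
            messages=[{"role": "user", "content": "a lighthouse at dawn"}],
            width=1024,
            height=1024,
        ):
            print(type(chunk).__name__, chunk)

        # Image upload: an empty model plus images selects default_vision_model
        # ("gpt-4o"); the last message becomes image_url + text parts.
        with open("photo.jpg", "rb") as f:
            images = [(f.read(), "photo.jpg")]
        async for chunk in PollinationsAI.create_async_generator(
            model="",
            messages=[{"role": "user", "content": "Describe this photo."}],
            images=images,
        ):
            print(chunk)

    asyncio.run(main())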

View File

@@ -18,6 +18,7 @@ class Qwen_QVQ_72B(AsyncGeneratorProvider, ProviderModelMixin):
     default_model = "qwen-qvq-72b-preview"
     models = [default_model]
+    vision_models = models
     model_aliases = {"qwq-32b": default_model}

     @classmethod

View File

@@ -33,12 +33,18 @@ class HuggingSpace(AsyncGeneratorProvider, ProviderModelMixin):
     def get_models(cls, **kwargs) -> list[str]:
         if not cls.models:
             models = []
+            image_models = []
+            vision_models = []
             for provider in cls.providers:
                 models.extend(provider.get_models(**kwargs))
                 models.extend(provider.model_aliases.keys())
+                image_models.extend(provider.image_models)
+                vision_models.extend(provider.vision_models)
             models = list(set(models))
             models.sort()
             cls.models = models
+            cls.image_models = list(set(image_models))
+            cls.vision_models = list(set(vision_models))
         return cls.models

     @classmethod
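
After this change the aggregate space provider exposes deduplicated image and vision model lists as a side effect of the first `get_models()` call. A sketch (import path assumed; the call fetches each space's model list over the network):

    from g4f.Provider import HuggingSpace

    models = HuggingSpace.get_models()   # populates the class attributes
    print(len(models), "models")
    print("image models:", sorted(HuggingSpace.image_models))
    print("vision models:", sorted(HuggingSpace.vision_models))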

View File

@@ -1,6 +1,8 @@
 from __future__ import annotations

 import json
+import re
+import requests

 try:
     from curl_cffi.requests import Session, CurlMime
@@ -13,14 +15,13 @@ from ..helper import format_prompt
 from ...typing import CreateResult, Messages, Cookies
 from ...errors import MissingRequirementsError
 from ...requests.raise_for_status import raise_for_status
-from ...providers.response import JsonConversation, ImageResponse, Sources
+from ...providers.response import JsonConversation, ImageResponse, Sources, TitleGeneration, Reasoning
 from ...cookies import get_cookies
 from ... import debug

 class Conversation(JsonConversation):
-    def __init__(self, conversation_id: str, message_id: str):
-        self.conversation_id: str = conversation_id
-        self.message_id: str = message_id
+    def __init__(self, models: dict):
+        self.models: dict = models

 class HuggingChat(AbstractProvider, ProviderModelMixin):
     url = "https://huggingface.co/chat"
@@ -32,11 +33,11 @@ class HuggingChat(AbstractProvider, ProviderModelMixin):
     default_model = "Qwen/Qwen2.5-72B-Instruct"
     default_image_model = "black-forest-labs/FLUX.1-dev"
     image_models = [
-        "black-forest-labs/FLUX.1-dev",
+        default_image_model,
         "black-forest-labs/FLUX.1-schnell",
     ]
-    models = [
-        'Qwen/Qwen2.5-Coder-32B-Instruct',
+    fallback_models = [
+        default_model,
         'meta-llama/Llama-3.3-70B-Instruct',
         'CohereForAI/c4ai-command-r-plus-08-2024',
         'Qwen/QwQ-32B-Preview',
@@ -63,12 +64,33 @@ class HuggingChat(AbstractProvider, ProviderModelMixin):
         "flux-schnell": "black-forest-labs/FLUX.1-schnell",
     }

+    @classmethod
+    def get_models(cls):
+        if not cls.models:
+            try:
+                text = requests.get(cls.url).text
+                text = re.sub(r',parameters:{[^}]+?}', '', text)
+                text = re.search(r'models:(\[.+?\]),oldModels:', text).group(1)
+                text = text.replace('void 0', 'null')
+                def add_quotation_mark(match):
+                    return f'{match.group(1)}"{match.group(2)}":'
+                text = re.sub(r'([{,])([A-Za-z0-9_]+?):', add_quotation_mark, text)
+                models = json.loads(text)
+                cls.text_models = [model["id"] for model in models]
+                cls.models = cls.text_models + cls.image_models
+                cls.vision_models = [model["id"] for model in models if model["multimodal"]]
+            except Exception as e:
+                debug.log(f"HuggingChat: Error reading models: {type(e).__name__}: {e}")
+                cls.models = [*cls.fallback_models]
+        return cls.models
+
     @classmethod
     def create_completion(
         cls,
         model: str,
         messages: Messages,
         stream: bool,
+        prompt: str = None,
         return_conversation: bool = False,
         conversation: Conversation = None,
         web_search: bool = False,
@@ -99,22 +121,26 @@ class HuggingChat(AbstractProvider, ProviderModelMixin):
             'user-agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/127.0.0.0 Safari/537.36',
         }

-        if conversation is None:
+        if conversation is None or not hasattr(conversation, "models"):
+            conversation = Conversation({})
+        if model not in conversation.models:
             conversationId = cls.create_conversation(session, model)
             messageId = cls.fetch_message_id(session, conversationId)
-            conversation = Conversation(conversationId, messageId)
+            conversation.models[model] = {"conversationId": conversationId, "messageId": messageId}
             if return_conversation:
                 yield conversation
             inputs = format_prompt(messages)
         else:
-            conversation.message_id = cls.fetch_message_id(session, conversation.conversation_id)
+            conversationId = conversation.models[model]["conversationId"]
+            conversation.models[model]["message_id"] = cls.fetch_message_id(session, conversationId)
             inputs = messages[-1]["content"]

-        debug.log(f"Use conversation: {conversation.conversation_id} Use message: {conversation.message_id}")
+        debug.log(f"Use model {model}: {json.dumps(conversation.models[model])}")

         settings = {
             "inputs": inputs,
-            "id": conversation.message_id,
+            "id": conversation.models[model]["message_id"],
             "is_retry": False,
             "is_continue": False,
             "web_search": web_search,
@@ -128,7 +154,7 @@ class HuggingChat(AbstractProvider, ProviderModelMixin):
             'origin': 'https://huggingface.co',
             'pragma': 'no-cache',
             'priority': 'u=1, i',
-            'referer': f'https://huggingface.co/chat/conversation/{conversation.conversation_id}',
+            'referer': f'https://huggingface.co/chat/conversation/{conversationId}',
             'sec-ch-ua': '"Not)A;Brand";v="99", "Google Chrome";v="127", "Chromium";v="127"',
             'sec-ch-ua-mobile': '?0',
             'sec-ch-ua-platform': '"macOS"',
@@ -142,7 +168,7 @@ class HuggingChat(AbstractProvider, ProviderModelMixin):
         data.addpart('data', data=json.dumps(settings, separators=(',', ':')))

         response = session.post(
-            f'https://huggingface.co/chat/conversation/{conversation.conversation_id}',
+            f'https://huggingface.co/chat/conversation/{conversationId}',
             cookies=session.cookies,
             headers=headers,
             multipart=data,
@@ -170,10 +196,17 @@ class HuggingChat(AbstractProvider, ProviderModelMixin):
             elif line["type"] == "finalAnswer":
                 break
             elif line["type"] == "file":
-                url = f"https://huggingface.co/chat/conversation/{conversation.conversation_id}/output/{line['sha']}"
-                yield ImageResponse(url, alt=messages[-1]["content"], options={"cookies": cookies})
+                url = f"https://huggingface.co/chat/conversation/{conversationId}/output/{line['sha']}"
+                prompt = messages[-1]["content"] if prompt is None else prompt
+                yield ImageResponse(url, alt=prompt, options={"cookies": cookies})
             elif line["type"] == "webSearch" and "sources" in line:
                 sources = Sources(line["sources"])
+            elif line["type"] == "title":
+                yield TitleGeneration(line["title"])
+            elif line["type"] == "reasoning":
+                yield Reasoning(line.get("token"), line.get("status"))
+            else:
+                pass #print(line)

         full_response = full_response.replace('<|im_end|', '').strip()
         if not stream:
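
The new `title` and `reasoning` event types surface as `TitleGeneration` and `Reasoning` chunks between the text tokens. A consumer can filter on chunk type; a sketch below (requires curl_cffi and valid cookies; `TitleGeneration.title` is assumed from the `TitleGeneration(line["title"])` call above):

    # Sketch: separating the new chunk types from plain text output.
    from g4f.Provider import HuggingChat
    from g4f.providers.response import TitleGeneration, Reasoning

    for chunk in HuggingChat.create_completion(
        model="Qwen/QwQ-32B-Preview",  # a reasoning-style model from the list
        messages=[{"role": "user", "content": "Why is the sky blue?"}],
        stream=True,
    ):
        if isinstance(chunk, TitleGeneration):
            print("\n[title]", chunk.title)
        elif isinstance(chunk, Reasoning):
            # Either a token of the reasoning trace or a status update.
            print("\n[reasoning]", chunk.token or chunk.status)
        else:
            print(chunk, end="")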

View File

@@ -143,7 +143,7 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
                     else:
                         is_special = True
                     debug.log(f"Special token: {is_special}")
-                    yield FinishReason("stop" if is_special else "length", actions=["variant"] if is_special else ["continue", "variant"])
+                    yield FinishReason("stop" if is_special else "length")
                 else:
                     if response.headers["content-type"].startswith("image/"):
                         base64_data = base64.b64encode(b"".join([chunk async for chunk in response.iter_content()]))

View File

@@ -2,6 +2,7 @@ from __future__ import annotations

 from .OpenaiAPI import OpenaiAPI
 from .HuggingChat import HuggingChat
+from ...providers.types import Messages

 class HuggingFaceAPI(OpenaiAPI):
     label = "HuggingFace (Inference API)"
@@ -11,6 +12,23 @@ class HuggingFaceAPI(OpenaiAPI):
     working = True
     default_model = "meta-llama/Llama-3.2-11B-Vision-Instruct"
     default_vision_model = default_model
-    models = [
-        *HuggingChat.models
-    ]
+
+    @classmethod
+    def get_models(cls, **kwargs):
+        HuggingChat.get_models()
+        cls.models = HuggingChat.text_models
+        cls.vision_models = HuggingChat.vision_models
+        return cls.models
+
+    @classmethod
+    async def create_async_generator(
+        cls,
+        model: str,
+        messages: Messages,
+        api_base: str = None,
+        **kwargs
+    ):
+        if api_base is None:
+            api_base = f"https://api-inference.huggingface.co/models/{model}/v1"
+        async for chunk in super().create_async_generator(model, messages, api_base=api_base, **kwargs):
+            yield chunk
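
Model discovery for the Inference API is now delegated to HuggingChat's scraped lists, and `api_base` falls back to a per-model endpoint. Sketch (import path assumed; the first call scrapes huggingface.co):

    from g4f.Provider import HuggingFaceAPI

    # First call runs HuggingChat.get_models() and copies its lists over.
    print(HuggingFaceAPI.get_models()[:3])
    print(HuggingFaceAPI.vision_models[:3])

    # Unless api_base is passed explicitly, requests for model M now go to:
    #   https://api-inference.huggingface.co/models/M/v1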

View File

@@ -73,10 +73,11 @@ class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin):
             raise MissingAuthError('Add a "api_key"')
         if api_base is None:
             api_base = cls.api_base
-        if images is not None:
+        if images is not None and messages:
             if not model and hasattr(cls, "default_vision_model"):
                 model = cls.default_vision_model
-            messages[-1]["content"] = [
+            last_message = messages[-1].copy()
+            last_message["content"] = [
                 *[{
                     "type": "image_url",
                     "image_url": {"url": to_data_uri(image)}
@@ -86,6 +87,7 @@ class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin):
                     "text": messages[-1]["content"]
                 }
             ]
+            messages[-1] = last_message
         async with StreamSession(
             proxy=proxy,
             headers=cls.get_headers(stream, api_key, headers),
@@ -117,9 +119,9 @@ class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin):
                     yield ToolCalls(choice["message"]["tool_calls"])
                 if "usage" in data:
                     yield Usage(**data["usage"])
-                finish = cls.read_finish_reason(choice)
-                if finish is not None:
-                    yield finish
+                if "finish_reason" in choice and choice["finish_reason"] is not None:
+                    yield FinishReason(choice["finish_reason"])
+                return
             else:
                 first = True
                 async for line in response.iter_lines():
@@ -137,16 +139,10 @@ class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin):
                         if delta:
                             first = False
                             yield delta
-                    finish = cls.read_finish_reason(choice)
-                    if finish is not None:
-                        yield finish
+                    if "finish_reason" in choice and choice["finish_reason"] is not None:
+                        yield FinishReason(choice["finish_reason"])
                         break

-    @staticmethod
-    def read_finish_reason(choice: dict) -> Optional[FinishReason]:
-        if "finish_reason" in choice and choice["finish_reason"] is not None:
-            return FinishReason(choice["finish_reason"])
-
     @classmethod
     def get_headers(cls, stream: bool, api_key: str = None, headers: dict = None) -> dict:
         return {
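
The copy-then-replace pattern above (mirrored in PollinationsAI) leaves the caller's original message dict untouched and only rebinds the list slot, which matters when the same `messages` list is retried across several providers. A self-contained illustration:

    # Demonstrates why messages[-1] is copied before the vision rewrite.
    def attach_images(messages: list, data_uris: list) -> None:
        last_message = messages[-1].copy()
        last_message["content"] = [
            *[{"type": "image_url", "image_url": {"url": uri}} for uri in data_uris],
            {"type": "text", "text": messages[-1]["content"]},
        ]
        messages[-1] = last_message  # rebind the slot, keep the original dict

    msgs = [{"role": "user", "content": "What is in this image?"}]
    original = msgs[-1]
    attach_images(msgs, ["data:image/png;base64,..."])
    print(original["content"])  # still the plain string
    print(msgs[-1]["content"])  # the multi-part vision payload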

View File

@@ -495,8 +495,7 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
                 "headers": cls._headers,
                 "web_search": web_search,
             })
-            actions = ["variant", "continue"] if conversation.finish_reason == "max_tokens" else ["variant"]
-            yield FinishReason(conversation.finish_reason, actions=actions)
+            yield FinishReason(conversation.finish_reason)

     @classmethod
     async def iter_messages_line(cls, session: StreamSession, line: bytes, fields: Conversation, sources: Sources) -> AsyncIterator:

View File

@@ -376,6 +376,29 @@ body:not(.white) a:visited{
     display: flex;
 }

+.message .reasoning_text.final:not(.hidden), .message .reasoning_title {
+    margin-bottom: var(--inner-gap);
+    padding-bottom: var(--inner-gap);
+    border-bottom: 1px solid var(--colour-3);
+    overflow: hidden;
+}
+
+.message .reasoning_text.final {
+    max-height: 1000px;
+    transition: max-height 0.25s ease-in;
+}
+
+.message .reasoning_text.final.hidden {
+    transition: max-height 0.15s ease-out;
+    max-height: 0;
+    display: block;
+    overflow: hidden;
+}
+
+.message .reasoning_title {
+    cursor: pointer;
+}
+
 .message .user i {
     position: absolute;
     bottom: -6px;

View File

@@ -35,6 +35,7 @@ let title_storage = {};
 let parameters_storage = {};
 let finish_storage = {};
 let usage_storage = {};
+let reasoning_storage = {}

 messageInput.addEventListener("blur", () => {
     window.scrollTo(0, 0);
@@ -70,6 +71,17 @@ if (window.markdownit) {
     }
 }

+function render_reasoning(reasoning, final = false) {
+    return `<div class="reasoning_body">
+        <div class="reasoning_title">
+            <strong>Reasoning <i class="fa-solid fa-brain"></i>:</strong> ${escapeHtml(reasoning.status)}
+        </div>
+        <div class="reasoning_text${final ? " final hidden" : ""}">
+            ${markdown_render(reasoning.text)}
+        </div>
+    </div>`;
+}
+
 function filter_message(text) {
     return text.replaceAll(
         /<!-- generated images start -->[\s\S]+<!-- generated images end -->/gm, ""
@@ -169,7 +181,7 @@ const get_message_el = (el) => {
 }

 const register_message_buttons = async () => {
-    document.querySelectorAll(".message .content .provider").forEach(async (el) => {
+    message_box.querySelectorAll(".message .content .provider").forEach(async (el) => {
         if (!("click" in el.dataset)) {
             el.dataset.click = "true";
             const provider_forms = document.querySelector(".provider_forms");
@@ -192,7 +204,7 @@ const register_message_buttons = async () => {
         }
     });
-    document.querySelectorAll(".message .fa-xmark").forEach(async (el) => {
+    message_box.querySelectorAll(".message .fa-xmark").forEach(async (el) => {
         if (!("click" in el.dataset)) {
             el.dataset.click = "true";
             el.addEventListener("click", async () => {
@@ -203,7 +215,7 @@ const register_message_buttons = async () => {
         }
     });
-    document.querySelectorAll(".message .fa-clipboard").forEach(async (el) => {
+    message_box.querySelectorAll(".message .fa-clipboard").forEach(async (el) => {
         if (!("click" in el.dataset)) {
             el.dataset.click = "true";
             el.addEventListener("click", async () => {
@@ -226,7 +238,7 @@ const register_message_buttons = async () => {
         }
     });
-    document.querySelectorAll(".message .fa-file-export").forEach(async (el) => {
+    message_box.querySelectorAll(".message .fa-file-export").forEach(async (el) => {
         if (!("click" in el.dataset)) {
             el.dataset.click = "true";
             el.addEventListener("click", async () => {
@@ -244,7 +256,7 @@ const register_message_buttons = async () => {
         }
     });
-    document.querySelectorAll(".message .fa-volume-high").forEach(async (el) => {
+    message_box.querySelectorAll(".message .fa-volume-high").forEach(async (el) => {
         if (!("click" in el.dataset)) {
             el.dataset.click = "true";
             el.addEventListener("click", async () => {
@@ -270,7 +282,7 @@ const register_message_buttons = async () => {
         }
     });
-    document.querySelectorAll(".message .regenerate_button").forEach(async (el) => {
+    message_box.querySelectorAll(".message .regenerate_button").forEach(async (el) => {
         if (!("click" in el.dataset)) {
             el.dataset.click = "true";
             el.addEventListener("click", async () => {
@@ -282,7 +294,7 @@ const register_message_buttons = async () => {
         }
     });
-    document.querySelectorAll(".message .continue_button").forEach(async (el) => {
+    message_box.querySelectorAll(".message .continue_button").forEach(async (el) => {
         if (!("click" in el.dataset)) {
             el.dataset.click = "true";
             el.addEventListener("click", async () => {
@@ -297,7 +309,7 @@ const register_message_buttons = async () => {
         }
     });
-    document.querySelectorAll(".message .fa-whatsapp").forEach(async (el) => {
+    message_box.querySelectorAll(".message .fa-whatsapp").forEach(async (el) => {
         if (!("click" in el.dataset)) {
             el.dataset.click = "true";
             el.addEventListener("click", async () => {
@@ -307,7 +319,7 @@ const register_message_buttons = async () => {
         }
     });
-    document.querySelectorAll(".message .fa-print").forEach(async (el) => {
+    message_box.querySelectorAll(".message .fa-print").forEach(async (el) => {
         if (!("click" in el.dataset)) {
             el.dataset.click = "true";
             el.addEventListener("click", async () => {
@@ -323,6 +335,16 @@ const register_message_buttons = async () => {
             })
         }
     });
+    message_box.querySelectorAll(".message .reasoning_title").forEach(async (el) => {
+        if (!("click" in el.dataset)) {
+            el.dataset.click = "true";
+            el.addEventListener("click", async () => {
+                let text_el = el.parentElement.querySelector(".reasoning_text");
+                text_el.classList[text_el.classList.contains("hidden") ? "remove" : "add"]("hidden");
+            })
+        }
+    });
 }

 const delete_conversations = async () => {
@@ -469,7 +491,7 @@ const prepare_messages = (messages, message_index = -1, do_continue = false, do_
     messages.forEach((message) => {
         message_copy = { ...message };
         if (last_message) {
-            if (last_message["role"] == message["role"]) {
+            if (last_message["role"] == message["role"] && message["role"] == "assistant") {
                 message_copy["content"] = last_message["content"] + message_copy["content"];
                 new_messages.pop();
             }
@@ -515,6 +537,7 @@ const prepare_messages = (messages, message_index = -1, do_continue = false, do_
             delete new_message.synthesize;
             delete new_message.finish;
             delete new_message.usage;
+            delete new_message.reasoning;
             delete new_message.conversation;
             delete new_message.continue;
             // Append message to new messages
@@ -711,11 +734,21 @@ async function add_message_chunk(message, message_id, provider, scroll) {
     } else if (message.type == "title") {
         title_storage[message_id] = message.title;
     } else if (message.type == "login") {
-        update_message(content_map, message_id, message.login, scroll);
+        update_message(content_map, message_id, markdown_render(message.login), scroll);
     } else if (message.type == "finish") {
         finish_storage[message_id] = message.finish;
     } else if (message.type == "usage") {
         usage_storage[message_id] = message.usage;
+    } else if (message.type == "reasoning") {
+        if (!reasoning_storage[message_id]) {
+            reasoning_storage[message_id] = message;
+            reasoning_storage[message_id].text = "";
+        } else if (message.status) {
+            reasoning_storage[message_id].status = message.status;
+        } else if (message.token) {
+            reasoning_storage[message_id].text += message.token;
+        }
+        update_message(content_map, message_id, render_reasoning(reasoning_storage[message_id]), scroll);
     } else if (message.type == "parameters") {
         if (!parameters_storage[provider]) {
             parameters_storage[provider] = {};
@@ -846,6 +879,7 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi
             title_storage[message_id],
             finish_storage[message_id],
             usage_storage[message_id],
+            reasoning_storage[message_id],
             action=="continue"
         );
         delete controller_storage[message_id];
@@ -1042,6 +1076,7 @@ function merge_messages(message1, message2) {
 const load_conversation = async (conversation_id, scroll=true) => {
     let conversation = await get_conversation(conversation_id);
     let messages = conversation?.items || [];
+    console.debug("Conversation:", conversation)

     if (!conversation) {
         return;
@@ -1098,11 +1133,8 @@ const load_conversation = async (conversation_id, scroll=true) => {
         let add_buttons = [];
         // Find buttons to add
         actions = ["variant"]
-        if (item.finish && item.finish.actions) {
-            actions = item.finish.actions
-        }

         // Add continue button if possible
-        if (item.role == "assistant" && !actions.includes("continue")) {
+        if (item.role == "assistant") {
             let reason = "stop";
             // Read finish reason from conversation
             if (item.finish && item.finish.reason) {
@@ -1167,7 +1199,10 @@ const load_conversation = async (conversation_id, scroll=true) => {
                 </div>
                 <div class="content">
                     ${provider}
-                    <div class="content_inner">${markdown_render(buffer)}</div>
+                    <div class="content_inner">
+                        ${item.reasoning ? render_reasoning(item.reasoning, true): ""}
+                        ${markdown_render(buffer)}
+                    </div>
                     <div class="count">
                         ${count_words_and_tokens(buffer, next_provider?.model, completion_tokens, prompt_tokens)}
                         ${add_buttons.join("")}
@@ -1298,6 +1333,7 @@ const add_message = async (
     title = null,
     finish = null,
     usage = null,
+    reasoning = null,
     do_continue = false
 ) => {
     const conversation = await get_conversation(conversation_id);
@@ -1329,6 +1365,9 @@ const add_message = async (
     if (usage) {
         new_message.usage = usage;
     }
+    if (reasoning) {
+        new_message.reasoning = reasoning;
+    }
     if (do_continue) {
         new_message.continue = true;
     }
@@ -1604,23 +1643,24 @@ function count_words_and_tokens(text, model, completion_tokens, prompt_tokens) {

 function update_message(content_map, message_id, content = null, scroll = true) {
     content_map.update_timeouts.push(setTimeout(() => {
-        if (!content) content = message_storage[message_id];
-        html = markdown_render(content);
-        let lastElement, lastIndex = null;
-        for (element of ['</p>', '</code></pre>', '</p>\n</li>\n</ol>', '</li>\n</ol>', '</li>\n</ul>']) {
-            const index = html.lastIndexOf(element)
-            if (index - element.length > lastIndex) {
-                lastElement = element;
-                lastIndex = index;
+        if (!content) {
+            content = markdown_render(message_storage[message_id]);
+            let lastElement, lastIndex = null;
+            for (element of ['</p>', '</code></pre>', '</p>\n</li>\n</ol>', '</li>\n</ol>', '</li>\n</ul>']) {
+                const index = content.lastIndexOf(element)
+                if (index - element.length > lastIndex) {
+                    lastElement = element;
+                    lastIndex = index;
+                }
+            }
+            if (lastIndex) {
+                content = content.substring(0, lastIndex) + '<span class="cursor"></span>' + lastElement;
             }
         }
-        if (lastIndex) {
-            html = html.substring(0, lastIndex) + '<span class="cursor"></span>' + lastElement;
-        }
+        content_map.inner.innerHTML = content;
         if (error_storage[message_id]) {
             content_map.inner.innerHTML += markdown_render(`**An error occurred:** ${error_storage[message_id]}`);
         }
-        content_map.inner.innerHTML = html;
         content_map.count.innerText = count_words_and_tokens(message_storage[message_id], provider_storage[message_id]?.model);
         highlight(content_map.inner);
         if (scroll) {
@@ -2132,9 +2172,9 @@ async function read_response(response, message_id, provider, scroll) {
 function get_api_key_by_provider(provider) {
     let api_key = null;
     if (provider) {
-        api_key = document.getElementById(`${provider}-api_key`)?.id || null;
+        api_key = document.querySelector(`.${provider}-api_key`)?.id || null;
         if (api_key == null) {
-            api_key = document.querySelector(`.${provider}-api_key`)?.id || null;
+            api_key = document.getElementById(`${provider}-api_key`)?.id || null;
         }
         if (api_key) {
             api_key = appStorage.getItem(api_key);

View File

@@ -13,7 +13,7 @@ from ...tools.run_tools import iter_run_tools
 from ...Provider import ProviderUtils, __providers__
 from ...providers.base_provider import ProviderModelMixin
 from ...providers.retry_provider import IterListProvider
-from ...providers.response import BaseConversation, JsonConversation, FinishReason, Usage
+from ...providers.response import BaseConversation, JsonConversation, FinishReason, Usage, Reasoning
 from ...providers.response import SynthesizeData, TitleGeneration, RequestLogin, Parameters
 from ... import version, models
 from ... import ChatCompletion, get_model_and_provider
@@ -207,6 +207,8 @@ class Api:
                     yield self._format_json("finish", chunk.get_dict())
                 elif isinstance(chunk, Usage):
                     yield self._format_json("usage", chunk.get_dict())
+                elif isinstance(chunk, Reasoning):
+                    yield self._format_json("reasoning", token=chunk.token, status=chunk.status)
                 else:
                     yield self._format_json("content", str(chunk))
                 if debug.logs:
@@ -219,10 +221,15 @@ class Api:
         if first:
             yield self.handle_provider(provider_handler, model)

-    def _format_json(self, response_type: str, content):
+    def _format_json(self, response_type: str, content = None, **kwargs):
+        if content is not None:
+            return {
+                'type': response_type,
+                response_type: content,
+            }
         return {
             'type': response_type,
-            response_type: content
+            **kwargs
         }

     def handle_provider(self, provider_handler, model):
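
`_format_json` now emits two shapes: the legacy `{type, <type>: content}` form and a keyword form used by the new reasoning event (serialized to NDJSON lines for the Web UI by `Backend_Api` below). Mirrored standalone for clarity (the real method lives on `Api`):

    # Standalone mirror of the patched _format_json logic.
    def _format_json(response_type: str, content=None, **kwargs) -> dict:
        if content is not None:
            return {"type": response_type, response_type: content}
        return {"type": response_type, **kwargs}

    print(_format_json("content", "Hello"))
    # {'type': 'content', 'content': 'Hello'}
    print(_format_json("reasoning", token="First, ", status=None))
    # {'type': 'reasoning', 'token': 'First, ', 'status': None}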

View File

@@ -309,7 +309,7 @@ class Backend_Api(Api):
             return "Provider not found", 404
         return models

-    def _format_json(self, response_type: str, content) -> str:
+    def _format_json(self, response_type: str, content = None, **kwargs) -> str:
         """
         Formats and returns a JSON response.

@@ -320,4 +320,4 @@
         Returns:
             str: A JSON formatted string.
         """
-        return json.dumps(super()._format_json(response_type, content)) + "\n"
+        return json.dumps(super()._format_json(response_type, content, **kwargs)) + "\n"

View File

@@ -340,7 +340,8 @@ class ProviderModelMixin:
     default_model: str = None
     models: list[str] = []
     model_aliases: dict[str, str] = {}
-    image_models: list = None
+    image_models: list = []
+    vision_models: list = []
     last_model: str = None

     @classmethod

View File

@@ -89,9 +89,8 @@ class JsonMixin:
         self.__dict__ = {}

 class FinishReason(ResponseType, JsonMixin):
-    def __init__(self, reason: str, actions: list[str] = None) -> None:
+    def __init__(self, reason: str) -> None:
         self.reason = reason
-        self.actions = actions

     def __str__(self) -> str:
         return ""
@@ -121,6 +120,14 @@ class TitleGeneration(ResponseType):
     def __str__(self) -> str:
         return ""

+class Reasoning(ResponseType):
+    def __init__(self, token: str = None, status: str = None) -> None:
+        self.token = token
+        self.status = status
+
+    def __str__(self) -> str:
+        return "" if self.token is None else self.token
+
 class Sources(ResponseType):
     def __init__(self, sources: list[dict[str, str]]) -> None:
         self.list = []
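
Because `Reasoning.__str__` returns the token (or an empty string for status-only chunks), code that concatenates chunks as strings keeps working unchanged. For instance (import path assumed; the class is defined exactly as above):

    from g4f.providers.response import Reasoning

    chunks = [Reasoning(status="Thinking..."), Reasoning(token="First, "), Reasoning(token="expand.")]
    print("".join(str(c) for c in chunks))  # -> "First, expand."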

View File

@@ -25,7 +25,7 @@ async def raise_for_status_async(response: Union[StreamResponse, ClientResponse]
         return
     text = await response.text()
     if message is None:
-        message = "HTML content" if response.headers.get("content-type").startswith("text/html") else text
+        message = "HTML content" if response.headers.get("content-type", "").startswith("text/html") else text
     if message == "HTML content":
         if response.status == 520:
             message = "Unknown error (Cloudflare)"
@@ -46,7 +46,7 @@ def raise_for_status(response: Union[Response, StreamResponse, ClientResponse, R
     if response.ok:
         return
     if message is None:
-        message = "HTML content" if response.headers.get("content-type", "").startswith("text/html") else response.text
+        message = "HTML content" if response.headers.get("content-type", "").startswith("text/html") else response.text
     if message == "HTML content":
         if response.status_code == 520:
             message = "Unknown error (Cloudflare)"
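
The `.get("content-type", "")` default guards responses that carry no Content-Type header at all; previously the call chained `.startswith()` onto `None` and raised. A minimal repro:

    headers = {}  # e.g. a bare gateway error without a Content-Type header

    try:
        headers.get("content-type").startswith("text/html")
    except AttributeError as e:
        print("old behaviour:", e)  # 'NoneType' object has no attribute 'startswith'

    print("new behaviour:", headers.get("content-type", "").startswith("text/html"))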