Fix response type of reasoning in UI

This commit is contained in:
hlohaus
2025-02-01 12:15:46 +01:00
parent 6b8e6adc9d
commit 797b17833a
11 changed files with 63 additions and 151 deletions

View File

@@ -9,7 +9,7 @@ import urllib.parse
from ...typing import AsyncResult, Messages, Cookies from ...typing import AsyncResult, Messages, Cookies
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..helper import format_prompt, format_image_prompt from ..helper import format_prompt, format_image_prompt
from ...providers.response import JsonConversation, ImageResponse, Notification from ...providers.response import JsonConversation, ImageResponse, DebugResponse
from ...requests.aiohttp import StreamSession, StreamResponse from ...requests.aiohttp import StreamSession, StreamResponse
from ...requests.raise_for_status import raise_for_status from ...requests.raise_for_status import raise_for_status
from ...cookies import get_cookies from ...cookies import get_cookies
@@ -105,7 +105,7 @@ class Janus_Pro_7B(AsyncGeneratorProvider, ProviderModelMixin):
try: try:
json_data = json.loads(decoded_line[6:]) json_data = json.loads(decoded_line[6:])
if json_data.get('msg') == 'log': if json_data.get('msg') == 'log':
yield Notification(json_data["log"]) yield DebugResponse(log=json_data["log"])
if json_data.get('msg') == 'process_generating': if json_data.get('msg') == 'process_generating':
if 'output' in json_data and 'data' in json_data['output']: if 'output' in json_data and 'data' in json_data['output']:

View File

@@ -7,24 +7,26 @@ from typing import AsyncIterator
import asyncio import asyncio
from ..base_provider import AsyncAuthedProvider from ..base_provider import AsyncAuthedProvider
from ...requests import get_args_from_nodriver from ...providers.helper import get_last_user_message
from ... import requests
from ...errors import MissingAuthError
from ...requests import get_args_from_nodriver, get_nodriver
from ...providers.response import AuthResult, RequestLogin, Reasoning, JsonConversation, FinishReason from ...providers.response import AuthResult, RequestLogin, Reasoning, JsonConversation, FinishReason
from ...typing import AsyncResult, Messages from ...typing import AsyncResult, Messages
from ... import debug
try: try:
from curl_cffi import requests from curl_cffi import requests
from dsk.api import DeepSeekAPI, AuthenticationError, DeepSeekPOW from dsk.api import DeepSeekAPI, AuthenticationError, DeepSeekPOW
class DeepSeekAPIArgs(DeepSeekAPI): class DeepSeekAPIArgs(DeepSeekAPI):
def __init__(self, args: dict): def __init__(self, args: dict):
args.pop("headers")
self.auth_token = args.pop("api_key") self.auth_token = args.pop("api_key")
if not self.auth_token or not isinstance(self.auth_token, str): if not self.auth_token or not isinstance(self.auth_token, str):
raise AuthenticationError("Invalid auth token provided") raise AuthenticationError("Invalid auth token provided")
self.args = args self.args = args
self.pow_solver = DeepSeekPOW() self.pow_solver = DeepSeekPOW()
def _make_request(self, method: str, endpoint: str, json_data: dict, pow_required: bool = False): def _make_request(self, method: str, endpoint: str, json_data: dict, pow_required: bool = False, **kwargs):
url = f"{self.BASE_URL}{endpoint}" url = f"{self.BASE_URL}{endpoint}"
headers = self._get_headers() headers = self._get_headers()
if pow_required: if pow_required:
@@ -36,12 +38,15 @@ try:
method=method, method=method,
url=url, url=url,
json=json_data, **{ json=json_data, **{
"headers":headers, **self.args,
"impersonate":'chrome', "headers": {**headers, **self.args["headers"]},
"timeout":None, "timeout":None,
**self.args },
} **kwargs
) )
if response.status_code == 403:
raise MissingAuthError()
response.raise_for_status()
return response.json() return response.json()
except ImportError: except ImportError:
pass pass
@@ -55,6 +60,8 @@ class DeepSeekAPI(AsyncAuthedProvider):
@classmethod @classmethod
async def on_auth_async(cls, proxy: str = None, **kwargs) -> AsyncIterator: async def on_auth_async(cls, proxy: str = None, **kwargs) -> AsyncIterator:
if not hasattr(cls, "browser"):
cls.browser, cls.stop_browser = await get_nodriver()
yield RequestLogin(cls.__name__, os.environ.get("G4F_LOGIN_URL") or "") yield RequestLogin(cls.__name__, os.environ.get("G4F_LOGIN_URL") or "")
async def callback(page): async def callback(page):
while True: while True:
@@ -62,7 +69,7 @@ class DeepSeekAPI(AsyncAuthedProvider):
cls._access_token = json.loads(await page.evaluate("localStorage.getItem('userToken')") or "{}").get("value") cls._access_token = json.loads(await page.evaluate("localStorage.getItem('userToken')") or "{}").get("value")
if cls._access_token: if cls._access_token:
break break
args = await get_args_from_nodriver(cls.url, proxy, callback=callback) args = await get_args_from_nodriver(cls.url, proxy, callback=callback, browser=cls.browser)
yield AuthResult( yield AuthResult(
api_key=cls._access_token, api_key=cls._access_token,
**args **args
@@ -88,7 +95,7 @@ class DeepSeekAPI(AsyncAuthedProvider):
is_thinking = 0 is_thinking = 0
for chunk in api.chat_completion( for chunk in api.chat_completion(
conversation.chat_id, conversation.chat_id,
messages[-1]["content"], get_last_user_message(messages),
thinking_enabled=True thinking_enabled=True
): ):
if chunk['type'] == 'thinking': if chunk['type'] == 'thinking':
@@ -100,6 +107,7 @@ class DeepSeekAPI(AsyncAuthedProvider):
if is_thinking: if is_thinking:
yield Reasoning(None, f"Thought for {time.time() - is_thinking:.2f}s") yield Reasoning(None, f"Thought for {time.time() - is_thinking:.2f}s")
is_thinking = 0 is_thinking = 0
if chunk['content']:
yield chunk['content'] yield chunk['content']
if chunk['finish_reason']: if chunk['finish_reason']:
yield FinishReason(chunk['finish_reason']) yield FinishReason(chunk['finish_reason'])

View File

@@ -1,61 +1,15 @@
from __future__ import annotations from __future__ import annotations
import re
import json import json
import time
from urllib.parse import quote_plus
from ...typing import Messages, AsyncResult from ...typing import Messages, AsyncResult
from ...requests import StreamSession from ...requests import StreamSession
from ...providers.base_provider import AsyncGeneratorProvider, ProviderModelMixin from ...providers.base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ...providers.response import * from ...providers.response import RawResponse
from ...image import get_image_extension
from ...errors import ModelNotSupportedError
from ..needs_auth.OpenaiAccount import OpenaiAccount
from ..hf.HuggingChat import HuggingChat
from ... import debug from ... import debug
class BackendApi(AsyncGeneratorProvider, ProviderModelMixin): class BackendApi(AsyncGeneratorProvider, ProviderModelMixin):
ssl = False ssl = None
models = [
*OpenaiAccount.get_models(),
*HuggingChat.get_models(),
"flux",
"flux-pro",
"MiniMax-01",
"Microsoft Copilot",
]
@classmethod
def get_model(cls, model: str):
if "MiniMax" in model:
model = "MiniMax"
elif "Copilot" in model:
model = "Copilot"
elif "FLUX" in model:
model = f"flux-{model.split('-')[-1]}"
elif "flux" in model:
model = model.split(' ')[-1]
elif model in OpenaiAccount.get_models():
pass
elif model in HuggingChat.get_models():
pass
else:
raise ModelNotSupportedError(f"Model: {model}")
return model
@classmethod
def get_provider(cls, model: str):
if model.startswith("MiniMax"):
return "HailuoAI"
elif model == "Copilot":
return "CopilotAccount"
elif model in OpenaiAccount.get_models():
return "OpenaiAccount"
elif model in HuggingChat.get_models():
return "HuggingChat"
return None
@classmethod @classmethod
async def create_async_generator( async def create_async_generator(
@@ -63,61 +17,16 @@ class BackendApi(AsyncGeneratorProvider, ProviderModelMixin):
model: str, model: str,
messages: Messages, messages: Messages,
api_key: str = None, api_key: str = None,
proxy: str = None,
timeout: int = 0,
**kwargs **kwargs
) -> AsyncResult: ) -> AsyncResult:
debug.log(f"{__name__}: {api_key}") debug.log(f"{cls.__name__}: {api_key}")
async with StreamSession( async with StreamSession(
proxy=proxy,
headers={"Accept": "text/event-stream"}, headers={"Accept": "text/event-stream"},
timeout=timeout
) as session: ) as session:
model = cls.get_model(model)
provider = cls.get_provider(model)
async with session.post(f"{cls.url}/backend-api/v2/conversation", json={ async with session.post(f"{cls.url}/backend-api/v2/conversation", json={
"model": model, "model": model,
"messages": messages, "messages": messages,
"provider": provider,
**kwargs **kwargs
}, ssl=cls.ssl) as response: }, ssl=cls.ssl) as response:
async for line in response.iter_lines(): async for line in response.iter_lines():
data = json.loads(line) yield RawResponse(**json.loads(line))
data_type = data.pop("type")
if data_type == "provider":
yield ProviderInfo(**data[data_type])
provider = data[data_type]["name"]
elif data_type == "conversation":
yield JsonConversation(**data[data_type][provider] if provider in data[data_type] else data[data_type][""])
elif data_type == "conversation_id":
pass
elif data_type == "message":
yield Exception(data)
elif data_type == "preview":
yield PreviewResponse(data[data_type])
elif data_type == "content":
def on_image(match):
extension = get_image_extension(match.group(3))
filename = f"{int(time.time())}_{quote_plus(match.group(1)[:100], '')}{extension}"
download_url = f"/download/{filename}?url={cls.url}{match.group(3)}"
return f"[![{match.group(1)}]({download_url})](/images/{filename})"
yield re.sub(r'\[\!\[(.+?)\]\(([^)]+?)\)\]\(([^)]+?)\)', on_image, data["content"])
elif data_type =="synthesize":
yield SynthesizeData(**data[data_type])
elif data_type == "parameters":
yield Parameters(**data[data_type])
elif data_type == "usage":
yield Usage(**data[data_type])
elif data_type == "reasoning":
yield Reasoning(**data)
elif data_type == "login":
pass
elif data_type == "title":
yield TitleGeneration(data[data_type])
elif data_type == "finish":
yield FinishReason(data[data_type]["reason"])
elif data_type == "log":
yield DebugResponse.from_dict(data[data_type])
else:
yield DebugResponse.from_dict(data)

View File

@@ -581,7 +581,7 @@ class Api:
source_url = str(request.query_params).split("url=", 1) source_url = str(request.query_params).split("url=", 1)
if len(source_url) > 1: if len(source_url) > 1:
source_url = source_url[1] source_url = source_url[1]
source_url = source_url.replace("%2F", "/").replace("%3A", ":").replace("%3F", "?") source_url = source_url.replace("%2F", "/").replace("%3A", ":").replace("%3F", "?").replace("%3D", "=")
if source_url.startswith("https://"): if source_url.startswith("https://"):
await copy_images( await copy_images(
[source_url], [source_url],

View File

@@ -779,7 +779,7 @@ async function add_message_chunk(message, message_id, provider, scroll, finish_m
} else if (message.type == "reasoning") { } else if (message.type == "reasoning") {
if (!reasoning_storage[message_id]) { if (!reasoning_storage[message_id]) {
reasoning_storage[message_id] = message; reasoning_storage[message_id] = message;
reasoning_storage[message_id].text = ""; reasoning_storage[message_id].text = message.token || "";
} else if (message.status) { } else if (message.status) {
reasoning_storage[message_id].status = message.status; reasoning_storage[message_id].status = message.status;
} else if (message.token) { } else if (message.token) {

View File

@@ -187,8 +187,8 @@ class Api:
elif isinstance(chunk, ImageResponse): elif isinstance(chunk, ImageResponse):
images = chunk images = chunk
if download_images or chunk.get("cookies"): if download_images or chunk.get("cookies"):
alt = format_image_prompt(kwargs.get("messages")) chunk.alt = chunk.alt or format_image_prompt(kwargs.get("messages"))
images = asyncio.run(copy_images(chunk.get_list(), chunk.get("cookies"), proxy, alt)) images = asyncio.run(copy_images(chunk.get_list(), chunk.get("cookies"), proxy=proxy, alt=chunk.alt))
images = ImageResponse(images, chunk.alt) images = ImageResponse(images, chunk.alt)
yield self._format_json("content", str(images), images=chunk.get_list(), alt=chunk.alt) yield self._format_json("content", str(images), images=chunk.get_list(), alt=chunk.alt)
elif isinstance(chunk, SynthesizeData): elif isinstance(chunk, SynthesizeData):
@@ -204,11 +204,9 @@ class Api:
elif isinstance(chunk, Usage): elif isinstance(chunk, Usage):
yield self._format_json("usage", chunk.get_dict()) yield self._format_json("usage", chunk.get_dict())
elif isinstance(chunk, Reasoning): elif isinstance(chunk, Reasoning):
yield self._format_json("reasoning", token=chunk.token, status=chunk.status, is_thinking=chunk.is_thinking) yield self._format_json("reasoning", chunk.get_dict())
elif isinstance(chunk, DebugResponse): elif isinstance(chunk, DebugResponse):
yield self._format_json("log", chunk.get_dict()) yield self._format_json("log", chunk.log)
elif isinstance(chunk, Notification):
yield self._format_json("notification", chunk.message)
else: else:
yield self._format_json("content", str(chunk)) yield self._format_json("content", str(chunk))
if debug.logs: if debug.logs:
@@ -224,15 +222,6 @@ class Api:
yield self._format_json('error', type(e).__name__, message=get_error_message(e)) yield self._format_json('error', type(e).__name__, message=get_error_message(e))
def _format_json(self, response_type: str, content = None, **kwargs): def _format_json(self, response_type: str, content = None, **kwargs):
# Make sure it get be formated as JSON
if content is not None and not isinstance(content, (str, dict)):
content = str(content)
kwargs = {
key: value
if value is isinstance(value, (str, dict))
else str(value)
for key, value in kwargs.items()
if isinstance(key, str)}
if content is not None: if content is not None:
return { return {
'type': response_type, 'type': response_type,

View File

@@ -156,7 +156,7 @@ class Backend_Api(Api):
if has_flask_limiter and app.demo: if has_flask_limiter and app.demo:
@app.route('/backend-api/v2/conversation', methods=['POST']) @app.route('/backend-api/v2/conversation', methods=['POST'])
@limiter.limit("4 per minute") # 1 request in 15 seconds @limiter.limit("2 per minute")
def _handle_conversation(): def _handle_conversation():
limiter.check() limiter.check()
return handle_conversation() return handle_conversation()
@@ -270,7 +270,8 @@ class Backend_Api(Api):
response = iter_run_tools(ChatCompletion.create, **parameters) response = iter_run_tools(ChatCompletion.create, **parameters)
cache_dir.mkdir(parents=True, exist_ok=True) cache_dir.mkdir(parents=True, exist_ok=True)
with cache_file.open("w") as f: with cache_file.open("w") as f:
f.write(response) for chunk in response:
f.write(str(chunk))
else: else:
response = iter_run_tools(ChatCompletion.create, **parameters) response = iter_run_tools(ChatCompletion.create, **parameters)

View File

@@ -242,13 +242,15 @@ def ensure_images_dir():
os.makedirs(images_dir, exist_ok=True) os.makedirs(images_dir, exist_ok=True)
def get_image_extension(image: str) -> str: def get_image_extension(image: str) -> str:
if match := re.search(r"(\.(?:jpe?g|png|webp))[$?&]", image): match = re.search(r"\.(?:jpe?g|png|webp)", image)
return match.group(1) if match:
return match.group(0)
return ".jpg" return ".jpg"
async def copy_images( async def copy_images(
images: list[str], images: list[str],
cookies: Optional[Cookies] = None, cookies: Optional[Cookies] = None,
headers: Optional[dict] = None,
proxy: Optional[str] = None, proxy: Optional[str] = None,
alt: str = None, alt: str = None,
add_url: bool = True, add_url: bool = True,
@@ -260,7 +262,8 @@ async def copy_images(
ensure_images_dir() ensure_images_dir()
async with ClientSession( async with ClientSession(
connector=get_connector(proxy=proxy), connector=get_connector(proxy=proxy),
cookies=cookies cookies=cookies,
headers=headers,
) as session: ) as session:
async def copy_image(image: str, target: str = None) -> str: async def copy_image(image: str, target: str = None) -> str:
if target is None or len(images) > 1: if target is None or len(images) > 1:

View File

@@ -88,6 +88,9 @@ class JsonMixin:
def reset(self): def reset(self):
self.__dict__ = {} self.__dict__ = {}
class RawResponse(ResponseType, JsonMixin):
pass
class HiddenResponse(ResponseType): class HiddenResponse(ResponseType):
def __str__(self) -> str: def __str__(self) -> str:
return "" return ""
@@ -113,21 +116,9 @@ class TitleGeneration(HiddenResponse):
def __init__(self, title: str) -> None: def __init__(self, title: str) -> None:
self.title = title self.title = title
class DebugResponse(JsonMixin, HiddenResponse): class DebugResponse(HiddenResponse):
@classmethod def __init__(self, log: str) -> None:
def from_dict(cls, data: dict) -> None: self.log = log
return cls(**data)
@classmethod
def from_str(cls, data: str) -> None:
return cls(error=data)
class Notification(ResponseType):
def __init__(self, message: str) -> None:
self.message = message
def __str__(self) -> str:
return f"{self.message}\n"
class Reasoning(ResponseType): class Reasoning(ResponseType):
def __init__( def __init__(
@@ -149,6 +140,13 @@ class Reasoning(ResponseType):
return f"{self.status}\n" return f"{self.status}\n"
return "" return ""
def get_dict(self):
if self.is_thinking is None:
if self.status is None:
return {"token": self.token}
{"token": self.token, "status": self.status}
return {"token": self.token, "status": self.status, "is_thinking": self.is_thinking}
class Sources(ResponseType): class Sources(ResponseType):
def __init__(self, sources: list[dict[str, str]]) -> None: def __init__(self, sources: list[dict[str, str]]) -> None:
self.list = [] self.list = []

View File

@@ -28,6 +28,7 @@ try:
from nodriver import Browser, Tab, util from nodriver import Browser, Tab, util
has_nodriver = True has_nodriver = True
except ImportError: except ImportError:
from typing import Type as Browser
from typing import Type as Tab from typing import Type as Tab
has_nodriver = False has_nodriver = False
try: try:
@@ -85,9 +86,14 @@ async def get_args_from_nodriver(
timeout: int = 120, timeout: int = 120,
wait_for: str = None, wait_for: str = None,
callback: callable = None, callback: callable = None,
cookies: Cookies = None cookies: Cookies = None,
browser: Browser = None
) -> dict: ) -> dict:
if browser is None:
browser, stop_browser = await get_nodriver(proxy=proxy, timeout=timeout) browser, stop_browser = await get_nodriver(proxy=proxy, timeout=timeout)
else:
def stop_browser():
...
try: try:
if debug.logging: if debug.logging:
print(f"Open nodriver with url: {url}") print(f"Open nodriver with url: {url}")

View File

@@ -157,15 +157,13 @@ def iter_run_tools(
if "<think>" in chunk: if "<think>" in chunk:
chunk = chunk.split("<think>", 1) chunk = chunk.split("<think>", 1)
yield chunk[0] yield chunk[0]
yield Reasoning(is_thinking="<think>") yield Reasoning(None, "Is thinking...", is_thinking="<think>")
yield Reasoning(chunk[1]) yield Reasoning(chunk[1])
yield Reasoning(None, "Is thinking...")
is_thinking = time.time() is_thinking = time.time()
if "</think>" in chunk: if "</think>" in chunk:
chunk = chunk.split("</think>", 1) chunk = chunk.split("</think>", 1)
yield Reasoning(chunk[0]) yield Reasoning(chunk[0])
yield Reasoning(is_thinking="</think>") yield Reasoning(None, f"Finished in {round(time.time()-is_thinking, 2)} seconds", is_thinking="</think>")
yield Reasoning(None, f"Finished in {round(time.time()-is_thinking, 2)} seconds")
yield chunk[1] yield chunk[1]
is_thinking = 0 is_thinking = 0
elif is_thinking: elif is_thinking: