Support reasoning tokens by default

Add new default HuggingFace provider
Add format_image_prompt and get_last_user_message helper
Add stop_browser callable to get_nodriver function
Fix content type response in images route
This commit is contained in:
hlohaus
2025-01-31 17:36:48 +01:00
parent ce7e9b03a5
commit 89e096334d
41 changed files with 377 additions and 240 deletions

View File

@@ -13,7 +13,7 @@ from ..requests.raise_for_status import raise_for_status
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..image import ImageResponse, to_data_uri from ..image import ImageResponse, to_data_uri
from ..cookies import get_cookies_dir from ..cookies import get_cookies_dir
from .helper import format_prompt from .helper import format_prompt, format_image_prompt
from ..providers.response import FinishReason, JsonConversation, Reasoning from ..providers.response import FinishReason, JsonConversation, Reasoning
class Conversation(JsonConversation): class Conversation(JsonConversation):
@@ -216,9 +216,8 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
async with ClientSession(headers=headers) as session: async with ClientSession(headers=headers) as session:
if model == "ImageGeneration2": if model == "ImageGeneration2":
prompt = messages[-1]["content"]
data = { data = {
"query": prompt, "query": format_image_prompt(messages, prompt),
"agentMode": True "agentMode": True
} }
headers['content-type'] = 'text/plain;charset=UTF-8' headers['content-type'] = 'text/plain;charset=UTF-8'
@@ -307,8 +306,7 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
image_url_match = re.search(r'!\[.*?\]\((.*?)\)', text_to_yield) image_url_match = re.search(r'!\[.*?\]\((.*?)\)', text_to_yield)
if image_url_match: if image_url_match:
image_url = image_url_match.group(1) image_url = image_url_match.group(1)
prompt = messages[-1]["content"] yield ImageResponse(image_url, format_image_prompt(messages, prompt))
yield ImageResponse(images=[image_url], alt=prompt)
else: else:
if "<think>" in text_to_yield and "</think>" in chunk_text : if "<think>" in text_to_yield and "</think>" in chunk_text :
chunk_text = text_to_yield.split('<think>', 1) chunk_text = text_to_yield.split('<think>', 1)

View File

@@ -28,6 +28,7 @@ from ..providers.response import BaseConversation, JsonConversation, RequestLogi
from ..providers.asyncio import get_running_loop from ..providers.asyncio import get_running_loop
from ..requests import get_nodriver from ..requests import get_nodriver
from ..image import ImageResponse, to_bytes, is_accepted_format from ..image import ImageResponse, to_bytes, is_accepted_format
from .helper import get_last_user_message
from .. import debug from .. import debug
class Conversation(JsonConversation): class Conversation(JsonConversation):
@@ -139,7 +140,7 @@ class Copilot(AbstractProvider, ProviderModelMixin):
else: else:
conversation_id = conversation.conversation_id conversation_id = conversation.conversation_id
if prompt is None: if prompt is None:
prompt = messages[-1]["content"] prompt = get_last_user_message(messages)
debug.log(f"Copilot: Use conversation: {conversation_id}") debug.log(f"Copilot: Use conversation: {conversation_id}")
uploaded_images = [] uploaded_images = []
@@ -206,7 +207,7 @@ class Copilot(AbstractProvider, ProviderModelMixin):
yield Parameters(**{"cookies": {c.name: c.value for c in session.cookies.jar}}) yield Parameters(**{"cookies": {c.name: c.value for c in session.cookies.jar}})
async def get_access_token_and_cookies(url: str, proxy: str = None, target: str = "ChatAI",): async def get_access_token_and_cookies(url: str, proxy: str = None, target: str = "ChatAI",):
browser = await get_nodriver(proxy=proxy, user_data_dir="copilot") browser, stop_browser = await get_nodriver(proxy=proxy, user_data_dir="copilot")
try: try:
page = await browser.get(url) page = await browser.get(url)
access_token = None access_token = None
@@ -233,7 +234,7 @@ async def get_access_token_and_cookies(url: str, proxy: str = None, target: str
await page.close() await page.close()
return access_token, cookies return access_token, cookies
finally: finally:
browser.stop() stop_browser()
def readHAR(url: str): def readHAR(url: str):
api_key = None api_key = None

View File

@@ -9,7 +9,6 @@ class OIVSCode(OpenaiTemplate):
working = True working = True
needs_auth = False needs_auth = False
default_model = "gpt-4o-mini-2024-07-18" default_model = "gpt-4o-mini"
default_vision_model = default_model default_vision_model = default_model
vision_models = [default_model, "gpt-4o-mini"] vision_models = [default_model, "gpt-4o-mini"]
model_aliases = {"gpt-4o-mini": "gpt-4o-mini-2024-07-18"}

View File

@@ -5,7 +5,7 @@ import json
from ..typing import AsyncResult, Messages from ..typing import AsyncResult, Messages
from ..requests import StreamSession, raise_for_status from ..requests import StreamSession, raise_for_status
from ..providers.response import Reasoning, FinishReason from ..providers.response import FinishReason
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
API_URL = "https://www.perplexity.ai/socket.io/" API_URL = "https://www.perplexity.ai/socket.io/"
@@ -87,22 +87,7 @@ class PerplexityLabs(AsyncGeneratorProvider, ProviderModelMixin):
continue continue
try: try:
data = json.loads(message[2:])[1] data = json.loads(message[2:])[1]
new_content = data["output"][last_message:] yield data["output"][last_message:]
if "<think>" in new_content:
yield Reasoning(None, "thinking")
is_thinking = True
if "</think>" in new_content:
new_content = new_content.split("</think>", 1)
yield Reasoning(f"{new_content[0]}</think>")
yield Reasoning(None, "finished")
yield new_content[1]
is_thinking = False
elif is_thinking:
yield Reasoning(new_content)
else:
yield new_content
last_message = len(data["output"]) last_message = len(data["output"])
if data["final"]: if data["final"]:
yield FinishReason("stop") yield FinishReason("stop")

View File

@@ -7,7 +7,7 @@ from urllib.parse import quote_plus
from typing import Optional from typing import Optional
from aiohttp import ClientSession from aiohttp import ClientSession
from .helper import filter_none from .helper import filter_none, format_image_prompt
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..typing import AsyncResult, Messages, ImagesType from ..typing import AsyncResult, Messages, ImagesType
from ..image import to_data_uri from ..image import to_data_uri
@@ -127,7 +127,7 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
if model in cls.image_models: if model in cls.image_models:
yield await cls._generate_image( yield await cls._generate_image(
model=model, model=model,
prompt=messages[-1]["content"] if prompt is None else prompt, prompt=format_image_prompt(messages, prompt),
proxy=proxy, proxy=proxy,
width=width, width=width,
height=height, height=height,

View File

@@ -77,7 +77,7 @@ class You(AsyncGeneratorProvider, ProviderModelMixin):
except MissingRequirementsError: except MissingRequirementsError:
pass pass
if not cookies or "afUserId" not in cookies: if not cookies or "afUserId" not in cookies:
browser = await get_nodriver(proxy=proxy) browser, stop_browser = await get_nodriver(proxy=proxy)
try: try:
page = await browser.get(cls.url) page = await browser.get(cls.url)
await page.wait_for('[data-testid="user-profile-button"]', timeout=900) await page.wait_for('[data-testid="user-profile-button"]', timeout=900)
@@ -86,7 +86,7 @@ class You(AsyncGeneratorProvider, ProviderModelMixin):
cookies[c.name] = c.value cookies[c.name] = c.value
await page.close() await page.close()
finally: finally:
browser.stop() stop_browser()
async with StreamSession( async with StreamSession(
proxy=proxy, proxy=proxy,
impersonate="chrome", impersonate="chrome",

View File

@@ -9,6 +9,7 @@ from .deprecated import *
from .needs_auth import * from .needs_auth import *
from .not_working import * from .not_working import *
from .local import * from .local import *
from .hf import HuggingFace, HuggingChat, HuggingFaceAPI, HuggingFaceInference
from .hf_space import HuggingSpace from .hf_space import HuggingSpace
from .mini_max import HailuoAI, MiniMax from .mini_max import HailuoAI, MiniMax
from .template import OpenaiTemplate, BackendApi from .template import OpenaiTemplate, BackendApi

View File

@@ -14,7 +14,7 @@ except ImportError:
has_curl_cffi = False has_curl_cffi = False
from ..base_provider import ProviderModelMixin, AsyncAuthedProvider, AuthResult from ..base_provider import ProviderModelMixin, AsyncAuthedProvider, AuthResult
from ..helper import format_prompt from ..helper import format_prompt, format_image_prompt, get_last_user_message
from ...typing import AsyncResult, Messages, Cookies, ImagesType from ...typing import AsyncResult, Messages, Cookies, ImagesType
from ...errors import MissingRequirementsError, MissingAuthError, ResponseError from ...errors import MissingRequirementsError, MissingAuthError, ResponseError
from ...image import to_bytes from ...image import to_bytes
@@ -22,6 +22,7 @@ from ...requests import get_args_from_nodriver, DEFAULT_HEADERS
from ...requests.raise_for_status import raise_for_status from ...requests.raise_for_status import raise_for_status
from ...providers.response import JsonConversation, ImageResponse, Sources, TitleGeneration, Reasoning, RequestLogin from ...providers.response import JsonConversation, ImageResponse, Sources, TitleGeneration, Reasoning, RequestLogin
from ...cookies import get_cookies from ...cookies import get_cookies
from .models import default_model, fallback_models, image_models, model_aliases
from ... import debug from ... import debug
class Conversation(JsonConversation): class Conversation(JsonConversation):
@@ -30,52 +31,14 @@ class Conversation(JsonConversation):
class HuggingChat(AsyncAuthedProvider, ProviderModelMixin): class HuggingChat(AsyncAuthedProvider, ProviderModelMixin):
url = "https://huggingface.co/chat" url = "https://huggingface.co/chat"
working = True working = True
use_nodriver = True use_nodriver = True
supports_stream = True supports_stream = True
needs_auth = True needs_auth = True
default_model = default_model
default_model = "Qwen/Qwen2.5-72B-Instruct" model_aliases = model_aliases
default_image_model = "black-forest-labs/FLUX.1-dev" image_models = image_models
image_models = [
default_image_model,
"black-forest-labs/FLUX.1-schnell",
]
fallback_models = [
default_model,
'meta-llama/Llama-3.3-70B-Instruct',
'CohereForAI/c4ai-command-r-plus-08-2024',
'deepseek-ai/DeepSeek-R1-Distill-Qwen-32B',
'Qwen/QwQ-32B-Preview',
'nvidia/Llama-3.1-Nemotron-70B-Instruct-HF',
'Qwen/Qwen2.5-Coder-32B-Instruct',
'meta-llama/Llama-3.2-11B-Vision-Instruct',
'mistralai/Mistral-Nemo-Instruct-2407',
'microsoft/Phi-3.5-mini-instruct',
] + image_models
model_aliases = {
### Chat ###
"qwen-2.5-72b": "Qwen/Qwen2.5-Coder-32B-Instruct",
"llama-3.3-70b": "meta-llama/Llama-3.3-70B-Instruct",
"command-r-plus": "CohereForAI/c4ai-command-r-plus-08-2024",
"deepseek-r1": "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
"qwq-32b": "Qwen/QwQ-32B-Preview",
"nemotron-70b": "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF",
"qwen-2.5-coder-32b": "Qwen/Qwen2.5-Coder-32B-Instruct",
"llama-3.2-11b": "meta-llama/Llama-3.2-11B-Vision-Instruct",
"mistral-nemo": "mistralai/Mistral-Nemo-Instruct-2407",
"phi-3.5-mini": "microsoft/Phi-3.5-mini-instruct",
### Image ###
"flux-dev": "black-forest-labs/FLUX.1-dev",
"flux-schnell": "black-forest-labs/FLUX.1-schnell",
### Used in other providers ###
"qwen-2-vl-7b": "Qwen/Qwen2-VL-7B-Instruct",
"gemma-2-27b": "google/gemma-2-27b-it",
"qwen-2-72b": "Qwen/Qwen2-72B-Instruct",
"qvq-72b": "Qwen/QVQ-72B-Preview",
"sd-3.5": "stabilityai/stable-diffusion-3.5-large",
}
@classmethod @classmethod
def get_models(cls): def get_models(cls):
@@ -94,7 +57,7 @@ class HuggingChat(AsyncAuthedProvider, ProviderModelMixin):
cls.vision_models = [model["id"] for model in models if model["multimodal"]] cls.vision_models = [model["id"] for model in models if model["multimodal"]]
except Exception as e: except Exception as e:
debug.log(f"HuggingChat: Error reading models: {type(e).__name__}: {e}") debug.log(f"HuggingChat: Error reading models: {type(e).__name__}: {e}")
cls.models = [*cls.fallback_models] cls.models = [*fallback_models]
return cls.models return cls.models
@classmethod @classmethod
@@ -108,9 +71,7 @@ class HuggingChat(AsyncAuthedProvider, ProviderModelMixin):
headers=DEFAULT_HEADERS headers=DEFAULT_HEADERS
) )
return return
login_url = os.environ.get("G4F_LOGIN_URL") yield RequestLogin(cls.__name__, os.environ.get("G4F_LOGIN_URL") or "")
if login_url:
yield RequestLogin(cls.__name__, login_url)
yield AuthResult( yield AuthResult(
**await get_args_from_nodriver( **await get_args_from_nodriver(
cls.url, cls.url,
@@ -143,6 +104,7 @@ class HuggingChat(AsyncAuthedProvider, ProviderModelMixin):
if model not in conversation.models: if model not in conversation.models:
conversationId = cls.create_conversation(session, model) conversationId = cls.create_conversation(session, model)
debug.log(f"Conversation created: {json.dumps(conversationId[8:] + '...')}")
messageId = cls.fetch_message_id(session, conversationId) messageId = cls.fetch_message_id(session, conversationId)
conversation.models[model] = {"conversationId": conversationId, "messageId": messageId} conversation.models[model] = {"conversationId": conversationId, "messageId": messageId}
if return_conversation: if return_conversation:
@@ -151,9 +113,7 @@ class HuggingChat(AsyncAuthedProvider, ProviderModelMixin):
else: else:
conversationId = conversation.models[model]["conversationId"] conversationId = conversation.models[model]["conversationId"]
conversation.models[model]["messageId"] = cls.fetch_message_id(session, conversationId) conversation.models[model]["messageId"] = cls.fetch_message_id(session, conversationId)
inputs = messages[-1]["content"] inputs = get_last_user_message(messages)
debug.log(f"Use: {json.dumps(conversation.models[model])}")
settings = { settings = {
"inputs": inputs, "inputs": inputs,
@@ -204,8 +164,7 @@ class HuggingChat(AsyncAuthedProvider, ProviderModelMixin):
break break
elif line["type"] == "file": elif line["type"] == "file":
url = f"https://huggingface.co/chat/conversation/{conversationId}/output/{line['sha']}" url = f"https://huggingface.co/chat/conversation/{conversationId}/output/{line['sha']}"
prompt = messages[-1]["content"] if prompt is None else prompt yield ImageResponse(url, format_image_prompt(messages, prompt), options={"cookies": auth_result.cookies})
yield ImageResponse(url, alt=prompt, options={"cookies": auth_result.cookies})
elif line["type"] == "webSearch" and "sources" in line: elif line["type"] == "webSearch" and "sources" in line:
sources = Sources(line["sources"]) sources = Sources(line["sources"])
elif line["type"] == "title": elif line["type"] == "title":
@@ -226,6 +185,8 @@ class HuggingChat(AsyncAuthedProvider, ProviderModelMixin):
response = session.post('https://huggingface.co/chat/conversation', json=json_data) response = session.post('https://huggingface.co/chat/conversation', json=json_data)
if response.status_code == 401: if response.status_code == 401:
raise MissingAuthError(response.text) raise MissingAuthError(response.text)
if response.status_code == 400:
raise ResponseError(f"{response.text}: Model: {model}")
raise_for_status(response) raise_for_status(response)
return response.json().get('conversationId') return response.json().get('conversationId')

View File

@@ -1,8 +1,9 @@
from __future__ import annotations from __future__ import annotations
from ..template import OpenaiTemplate from ..template.OpenaiTemplate import OpenaiTemplate
from .HuggingChat import HuggingChat from .models import model_aliases
from ...providers.types import Messages from ...providers.types import Messages
from .HuggingChat import HuggingChat
from ... import debug from ... import debug
class HuggingFaceAPI(OpenaiTemplate): class HuggingFaceAPI(OpenaiTemplate):
@@ -16,13 +17,16 @@ class HuggingFaceAPI(OpenaiTemplate):
default_model = "meta-llama/Llama-3.2-11B-Vision-Instruct" default_model = "meta-llama/Llama-3.2-11B-Vision-Instruct"
default_vision_model = default_model default_vision_model = default_model
vision_models = [default_vision_model, "Qwen/Qwen2-VL-7B-Instruct"] vision_models = [default_vision_model, "Qwen/Qwen2-VL-7B-Instruct"]
model_aliases = HuggingChat.model_aliases model_aliases = model_aliases
@classmethod @classmethod
def get_models(cls, **kwargs): def get_models(cls, **kwargs):
if not cls.models: if not cls.models:
HuggingChat.get_models() HuggingChat.get_models()
cls.models = list(set(HuggingChat.text_models + cls.vision_models)) cls.models = HuggingChat.text_models.copy()
for model in cls.vision_models:
if model not in cls.models:
cls.models.append(model)
return cls.models return cls.models
@classmethod @classmethod

View File

@@ -11,37 +11,32 @@ from ...errors import ModelNotFoundError, ModelNotSupportedError, ResponseError
from ...requests import StreamSession, raise_for_status from ...requests import StreamSession, raise_for_status
from ...providers.response import FinishReason from ...providers.response import FinishReason
from ...image import ImageResponse from ...image import ImageResponse
from ..helper import format_image_prompt
from .models import default_model, default_image_model, model_aliases, fallback_models
from ... import debug from ... import debug
from .HuggingChat import HuggingChat class HuggingFaceInference(AsyncGeneratorProvider, ProviderModelMixin):
class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
url = "https://huggingface.co" url = "https://huggingface.co"
login_url = "https://huggingface.co/settings/tokens"
working = True working = True
supports_message_history = True
default_model = HuggingChat.default_model default_model = default_model
default_image_model = HuggingChat.default_image_model default_image_model = default_image_model
model_aliases = HuggingChat.model_aliases model_aliases = model_aliases
extra_models = [
"meta-llama/Llama-3.2-11B-Vision-Instruct",
"nvidia/Llama-3.1-Nemotron-70B-Instruct-HF",
"NousResearch/Hermes-3-Llama-3.1-8B",
]
@classmethod @classmethod
def get_models(cls) -> list[str]: def get_models(cls) -> list[str]:
if not cls.models: if not cls.models:
models = fallback_models.copy()
url = "https://huggingface.co/api/models?inference=warm&pipeline_tag=text-generation" url = "https://huggingface.co/api/models?inference=warm&pipeline_tag=text-generation"
models = [model["id"] for model in requests.get(url).json()] extra_models = [model["id"] for model in requests.get(url).json()]
models.extend(cls.extra_models) extra_models.sort()
models.sort() models.extend([model for model in extra_models if model not in models])
if not cls.image_models: if not cls.image_models:
url = "https://huggingface.co/api/models?pipeline_tag=text-to-image" url = "https://huggingface.co/api/models?pipeline_tag=text-to-image"
cls.image_models = [model["id"] for model in requests.get(url).json() if model["trendingScore"] >= 20] cls.image_models = [model["id"] for model in requests.get(url).json() if model["trendingScore"] >= 20]
cls.image_models.sort() cls.image_models.sort()
models.extend(cls.image_models) models.extend([model for model in cls.image_models if model not in models])
cls.models = list(set(models)) cls.models = models
return cls.models return cls.models
@classmethod @classmethod
@@ -85,7 +80,7 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
payload = None payload = None
if cls.get_models() and model in cls.image_models: if cls.get_models() and model in cls.image_models:
stream = False stream = False
prompt = messages[-1]["content"] if prompt is None else prompt prompt = format_image_prompt(messages, prompt)
payload = {"inputs": prompt, "parameters": {"seed": random.randint(0, 2**32), **extra_data}} payload = {"inputs": prompt, "parameters": {"seed": random.randint(0, 2**32), **extra_data}}
else: else:
params = { params = {

View File

@@ -0,0 +1,62 @@
from __future__ import annotations
import random
from ...typing import AsyncResult, Messages
from ...providers.response import ImageResponse
from ...errors import ModelNotSupportedError
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from .HuggingChat import HuggingChat
from .HuggingFaceAPI import HuggingFaceAPI
from .HuggingFaceInference import HuggingFaceInference
from ... import debug
class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
url = "https://huggingface.co"
login_url = "https://huggingface.co/settings/tokens"
working = True
supports_message_history = True
@classmethod
def get_models(cls) -> list[str]:
if not cls.models:
cls.models = HuggingFaceInference.get_models()
cls.image_models = HuggingFaceInference.image_models
return cls.models
@classmethod
async def create_async_generator(
cls,
model: str,
messages: Messages,
**kwargs
) -> AsyncResult:
if "api_key" not in kwargs and "images" not in kwargs and random.random() >= 0.5:
try:
is_started = False
async for chunk in HuggingFaceInference.create_async_generator(model, messages, **kwargs):
if isinstance(chunk, (str, ImageResponse)):
is_started = True
yield chunk
if is_started:
return
except Exception as e:
if is_started:
raise e
debug.log(f"Inference failed: {e.__class__.__name__}: {e}")
if not cls.image_models:
cls.get_models()
if model in cls.image_models:
if "api_key" not in kwargs:
async for chunk in HuggingChat.create_async_generator(model, messages, **kwargs):
yield chunk
else:
async for chunk in HuggingFaceInference.create_async_generator(model, messages, **kwargs):
yield chunk
return
try:
async for chunk in HuggingFaceAPI.create_async_generator(model, messages, **kwargs):
yield chunk
except ModelNotSupportedError:
async for chunk in HuggingFaceInference.create_async_generator(model, messages, **kwargs):
yield chunk

46
g4f/Provider/hf/models.py Normal file
View File

@@ -0,0 +1,46 @@
default_model = "Qwen/Qwen2.5-72B-Instruct"
default_image_model = "black-forest-labs/FLUX.1-dev"
image_models = [
default_image_model,
"black-forest-labs/FLUX.1-schnell",
]
fallback_models = [
default_model,
'meta-llama/Llama-3.3-70B-Instruct',
'CohereForAI/c4ai-command-r-plus-08-2024',
'deepseek-ai/DeepSeek-R1-Distill-Qwen-32B',
'Qwen/QwQ-32B-Preview',
'nvidia/Llama-3.1-Nemotron-70B-Instruct-HF',
'Qwen/Qwen2.5-Coder-32B-Instruct',
'meta-llama/Llama-3.2-11B-Vision-Instruct',
'mistralai/Mistral-Nemo-Instruct-2407',
'microsoft/Phi-3.5-mini-instruct',
] + image_models
model_aliases = {
### Chat ###
"qwen-2.5-72b": "Qwen/Qwen2.5-Coder-32B-Instruct",
"llama-3.3-70b": "meta-llama/Llama-3.3-70B-Instruct",
"command-r-plus": "CohereForAI/c4ai-command-r-plus-08-2024",
"deepseek-r1": "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
"qwq-32b": "Qwen/QwQ-32B-Preview",
"nemotron-70b": "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF",
"qwen-2.5-coder-32b": "Qwen/Qwen2.5-Coder-32B-Instruct",
"llama-3.2-11b": "meta-llama/Llama-3.2-11B-Vision-Instruct",
"mistral-nemo": "mistralai/Mistral-Nemo-Instruct-2407",
"phi-3.5-mini": "microsoft/Phi-3.5-mini-instruct",
### Image ###
"flux": "black-forest-labs/FLUX.1-dev",
"flux-dev": "black-forest-labs/FLUX.1-dev",
"flux-schnell": "black-forest-labs/FLUX.1-schnell",
### Used in other providers ###
"qwen-2-vl-7b": "Qwen/Qwen2-VL-7B-Instruct",
"gemma-2-27b": "google/gemma-2-27b-it",
"qwen-2-72b": "Qwen/Qwen2-72B-Instruct",
"qvq-72b": "Qwen/QVQ-72B-Preview",
"sd-3.5": "stabilityai/stable-diffusion-3.5-large",
}
extra_models = [
"meta-llama/Llama-3.2-11B-Vision-Instruct",
"nvidia/Llama-3.1-Nemotron-70B-Instruct-HF",
"NousResearch/Hermes-3-Llama-3.1-8B",
]

View File

@@ -7,6 +7,7 @@ from ...typing import AsyncResult, Messages
from ...image import ImageResponse, ImagePreview from ...image import ImageResponse, ImagePreview
from ...errors import ResponseError from ...errors import ResponseError
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..helper import format_image_prompt
class BlackForestLabsFlux1Dev(AsyncGeneratorProvider, ProviderModelMixin): class BlackForestLabsFlux1Dev(AsyncGeneratorProvider, ProviderModelMixin):
url = "https://black-forest-labs-flux-1-dev.hf.space" url = "https://black-forest-labs-flux-1-dev.hf.space"
@@ -44,7 +45,7 @@ class BlackForestLabsFlux1Dev(AsyncGeneratorProvider, ProviderModelMixin):
if api_key is not None: if api_key is not None:
headers["Authorization"] = f"Bearer {api_key}" headers["Authorization"] = f"Bearer {api_key}"
async with ClientSession(headers=headers) as session: async with ClientSession(headers=headers) as session:
prompt = messages[-1]["content"] if prompt is None else prompt prompt = format_image_prompt(messages, prompt)
data = { data = {
"data": [prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps] "data": [prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps]
} }

View File

@@ -8,6 +8,7 @@ from ...image import ImageResponse
from ...errors import ResponseError from ...errors import ResponseError
from ...requests.raise_for_status import raise_for_status from ...requests.raise_for_status import raise_for_status
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..helper import format_image_prompt
class BlackForestLabsFlux1Schnell(AsyncGeneratorProvider, ProviderModelMixin): class BlackForestLabsFlux1Schnell(AsyncGeneratorProvider, ProviderModelMixin):
url = "https://black-forest-labs-flux-1-schnell.hf.space" url = "https://black-forest-labs-flux-1-schnell.hf.space"
@@ -42,7 +43,7 @@ class BlackForestLabsFlux1Schnell(AsyncGeneratorProvider, ProviderModelMixin):
height = max(32, height - (height % 8)) height = max(32, height - (height % 8))
if prompt is None: if prompt is None:
prompt = messages[-1]["content"] prompt = format_image_prompt(messages)
payload = { payload = {
"data": [ "data": [

View File

@@ -6,7 +6,7 @@ from aiohttp import ClientSession, FormData
from ...typing import AsyncResult, Messages from ...typing import AsyncResult, Messages
from ...requests import raise_for_status from ...requests import raise_for_status
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..helper import format_prompt from ..helper import format_prompt, get_last_user_message
from ...providers.response import JsonConversation, TitleGeneration from ...providers.response import JsonConversation, TitleGeneration
class CohereForAI(AsyncGeneratorProvider, ProviderModelMixin): class CohereForAI(AsyncGeneratorProvider, ProviderModelMixin):
@@ -58,7 +58,7 @@ class CohereForAI(AsyncGeneratorProvider, ProviderModelMixin):
) as session: ) as session:
system_prompt = "\n".join([message["content"] for message in messages if message["role"] == "system"]) system_prompt = "\n".join([message["content"] for message in messages if message["role"] == "system"])
messages = [message for message in messages if message["role"] != "system"] messages = [message for message in messages if message["role"] != "system"]
inputs = format_prompt(messages) if conversation is None else messages[-1]["content"] inputs = format_prompt(messages) if conversation is None else get_last_user_message(messages)
if conversation is None or conversation.model != model or conversation.preprompt != system_prompt: if conversation is None or conversation.model != model or conversation.preprompt != system_prompt:
data = {"model": model, "preprompt": system_prompt} data = {"model": model, "preprompt": system_prompt}
async with session.post(cls.conversation_url, json=data, proxy=proxy) as response: async with session.post(cls.conversation_url, json=data, proxy=proxy) as response:
@@ -78,7 +78,6 @@ class CohereForAI(AsyncGeneratorProvider, ProviderModelMixin):
data = node["data"] data = node["data"]
message_id = data[data[data[data[0]["messages"]][-1]]["id"]] message_id = data[data[data[data[0]["messages"]][-1]]["id"]]
data = FormData() data = FormData()
inputs = messages[-1]["content"]
data.add_field( data.add_field(
"data", "data",
json.dumps({"inputs": inputs, "id": message_id, "is_retry": False, "is_continue": False, "web_search": False, "tools": []}), json.dumps({"inputs": inputs, "id": message_id, "is_retry": False, "is_continue": False, "web_search": False, "tools": []}),

View File

@@ -8,8 +8,8 @@ import urllib.parse
from ...typing import AsyncResult, Messages, Cookies from ...typing import AsyncResult, Messages, Cookies
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..helper import format_prompt from ..helper import format_prompt, format_image_prompt
from ...providers.response import JsonConversation, ImageResponse from ...providers.response import JsonConversation, ImageResponse, Notification
from ...requests.aiohttp import StreamSession, StreamResponse from ...requests.aiohttp import StreamSession, StreamResponse
from ...requests.raise_for_status import raise_for_status from ...requests.raise_for_status import raise_for_status
from ...cookies import get_cookies from ...cookies import get_cookies
@@ -38,7 +38,7 @@ class Janus_Pro_7B(AsyncGeneratorProvider, ProviderModelMixin):
"headers": { "headers": {
"content-type": "application/json", "content-type": "application/json",
"x-zerogpu-token": conversation.zerogpu_token, "x-zerogpu-token": conversation.zerogpu_token,
"x-zerogpu-uuid": conversation.uuid, "x-zerogpu-uuid": conversation.zerogpu_uuid,
"referer": cls.referer, "referer": cls.referer,
}, },
"json": {"data":[None,prompt,42,0.95,0.1],"event_data":None,"fn_index":2,"trigger_id":10,"session_hash":conversation.session_hash}, "json": {"data":[None,prompt,42,0.95,0.1],"event_data":None,"fn_index":2,"trigger_id":10,"session_hash":conversation.session_hash},
@@ -48,7 +48,7 @@ class Janus_Pro_7B(AsyncGeneratorProvider, ProviderModelMixin):
"headers": { "headers": {
"content-type": "application/json", "content-type": "application/json",
"x-zerogpu-token": conversation.zerogpu_token, "x-zerogpu-token": conversation.zerogpu_token,
"x-zerogpu-uuid": conversation.uuid, "x-zerogpu-uuid": conversation.zerogpu_uuid,
"referer": cls.referer, "referer": cls.referer,
}, },
"json": {"data":[prompt,1234,5,1],"event_data":None,"fn_index":3,"trigger_id":20,"session_hash":conversation.session_hash}, "json": {"data":[prompt,1234,5,1],"event_data":None,"fn_index":3,"trigger_id":20,"session_hash":conversation.session_hash},
@@ -82,33 +82,14 @@ class Janus_Pro_7B(AsyncGeneratorProvider, ProviderModelMixin):
method = "image" method = "image"
prompt = format_prompt(messages) if prompt is None and conversation is None else prompt prompt = format_prompt(messages) if prompt is None and conversation is None else prompt
prompt = messages[-1]["content"] if prompt is None else prompt prompt = format_image_prompt(messages, prompt)
session_hash = generate_session_hash() if conversation is None else getattr(conversation, "session_hash") session_hash = generate_session_hash() if conversation is None else getattr(conversation, "session_hash")
async with StreamSession(proxy=proxy, impersonate="chrome") as session: async with StreamSession(proxy=proxy, impersonate="chrome") as session:
session_hash = generate_session_hash() if conversation is None else getattr(conversation, "session_hash") session_hash = generate_session_hash() if conversation is None else getattr(conversation, "session_hash")
user_uuid = None if conversation is None else getattr(conversation, "user_uuid", None) zerogpu_uuid, zerogpu_token = await get_zerogpu_token(session, conversation, cookies)
zerogpu_token = "[object Object]"
cookies = get_cookies("huggingface.co", raise_requirements_error=False) if cookies is None else cookies
if cookies:
# Get current UTC time + 10 minutes
dt = (datetime.now(timezone.utc) + timedelta(minutes=10)).isoformat(timespec='milliseconds')
encoded_dt = urllib.parse.quote(dt)
async with session.get(f"https://huggingface.co/api/spaces/deepseek-ai/Janus-Pro-7B/jwt?expiration={encoded_dt}&include_pro_status=true", cookies=cookies) as response:
zerogpu_token = (await response.json())
zerogpu_token = zerogpu_token["token"]
if user_uuid is None:
async with session.get(cls.url, cookies=cookies) as response:
match = re.search(r"&quot;token&quot;:&quot;([^&]+?)&quot;", await response.text())
if match:
zerogpu_token = match.group(1)
match = re.search(r"&quot;sessionUuid&quot;:&quot;([^&]+?)&quot;", await response.text())
if match:
user_uuid = match.group(1)
if conversation is None or not hasattr(conversation, "session_hash"): if conversation is None or not hasattr(conversation, "session_hash"):
conversation = JsonConversation(session_hash=session_hash, zerogpu_token=zerogpu_token, uuid=user_uuid) conversation = JsonConversation(session_hash=session_hash, zerogpu_token=zerogpu_token, zerogpu_uuid=zerogpu_uuid)
conversation.zerogpu_token = zerogpu_token conversation.zerogpu_token = zerogpu_token
if return_conversation: if return_conversation:
yield conversation yield conversation
@@ -124,7 +105,7 @@ class Janus_Pro_7B(AsyncGeneratorProvider, ProviderModelMixin):
try: try:
json_data = json.loads(decoded_line[6:]) json_data = json.loads(decoded_line[6:])
if json_data.get('msg') == 'log': if json_data.get('msg') == 'log':
debug.log(json_data["log"]) yield Notification(json_data["log"])
if json_data.get('msg') == 'process_generating': if json_data.get('msg') == 'process_generating':
if 'output' in json_data and 'data' in json_data['output']: if 'output' in json_data and 'data' in json_data['output']:
@@ -141,4 +122,27 @@ class Janus_Pro_7B(AsyncGeneratorProvider, ProviderModelMixin):
break break
except json.JSONDecodeError: except json.JSONDecodeError:
debug.log("Could not parse JSON:", decoded_line) debug.log("Could not parse JSON:", decoded_line)
async def get_zerogpu_token(session: StreamSession, conversation: JsonConversation, cookies: Cookies = None):
zerogpu_uuid = None if conversation is None else getattr(conversation, "zerogpu_uuid", None)
zerogpu_token = "[object Object]"
cookies = get_cookies("huggingface.co", raise_requirements_error=False) if cookies is None else cookies
if zerogpu_uuid is None:
async with session.get(Janus_Pro_7B.url, cookies=cookies) as response:
match = re.search(r"&quot;token&quot;:&quot;([^&]+?)&quot;", await response.text())
if match:
zerogpu_token = match.group(1)
match = re.search(r"&quot;sessionUuid&quot;:&quot;([^&]+?)&quot;", await response.text())
if match:
zerogpu_uuid = match.group(1)
if cookies:
# Get current UTC time + 10 minutes
dt = (datetime.now(timezone.utc) + timedelta(minutes=10)).isoformat(timespec='milliseconds')
encoded_dt = urllib.parse.quote(dt)
async with session.get(f"https://huggingface.co/api/spaces/deepseek-ai/Janus-Pro-7B/jwt?expiration={encoded_dt}&include_pro_status=true", cookies=cookies) as response:
zerogpu_token = (await response.json())
zerogpu_token = zerogpu_token["token"]
return zerogpu_uuid, zerogpu_token

View File

@@ -8,6 +8,7 @@ from ...typing import AsyncResult, Messages
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..helper import format_prompt from ..helper import format_prompt
from ...providers.response import JsonConversation, Reasoning from ...providers.response import JsonConversation, Reasoning
from ..helper import get_last_user_message
from ... import debug from ... import debug
class Qwen_Qwen_2_5M_Demo(AsyncGeneratorProvider, ProviderModelMixin): class Qwen_Qwen_2_5M_Demo(AsyncGeneratorProvider, ProviderModelMixin):
@@ -41,7 +42,7 @@ class Qwen_Qwen_2_5M_Demo(AsyncGeneratorProvider, ProviderModelMixin):
if return_conversation: if return_conversation:
yield JsonConversation(session_hash=session_hash) yield JsonConversation(session_hash=session_hash)
prompt = format_prompt(messages) if conversation is None else messages[-1]["content"] prompt = format_prompt(messages) if conversation is None else get_last_user_message(messages)
headers = { headers = {
'accept': '*/*', 'accept': '*/*',

View File

@@ -7,6 +7,7 @@ from ...typing import AsyncResult, Messages
from ...image import ImageResponse, ImagePreview from ...image import ImageResponse, ImagePreview
from ...errors import ResponseError from ...errors import ResponseError
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..helper import format_image_prompt
class StableDiffusion35Large(AsyncGeneratorProvider, ProviderModelMixin): class StableDiffusion35Large(AsyncGeneratorProvider, ProviderModelMixin):
url = "https://stabilityai-stable-diffusion-3-5-large.hf.space" url = "https://stabilityai-stable-diffusion-3-5-large.hf.space"
@@ -42,7 +43,7 @@ class StableDiffusion35Large(AsyncGeneratorProvider, ProviderModelMixin):
if api_key is not None: if api_key is not None:
headers["Authorization"] = f"Bearer {api_key}" headers["Authorization"] = f"Bearer {api_key}"
async with ClientSession(headers=headers) as session: async with ClientSession(headers=headers) as session:
prompt = messages[-1]["content"] if prompt is None else prompt prompt = format_image_prompt(messages, prompt)
data = { data = {
"data": [prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps] "data": [prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps]
} }

View File

@@ -7,6 +7,7 @@ from ...typing import AsyncResult, Messages
from ...image import ImageResponse from ...image import ImageResponse
from ...errors import ResponseError from ...errors import ResponseError
from ...requests.raise_for_status import raise_for_status from ...requests.raise_for_status import raise_for_status
from ..helper import format_image_prompt
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
class VoodoohopFlux1Schnell(AsyncGeneratorProvider, ProviderModelMixin): class VoodoohopFlux1Schnell(AsyncGeneratorProvider, ProviderModelMixin):
@@ -37,10 +38,7 @@ class VoodoohopFlux1Schnell(AsyncGeneratorProvider, ProviderModelMixin):
) -> AsyncResult: ) -> AsyncResult:
width = max(32, width - (width % 8)) width = max(32, width - (width % 8))
height = max(32, height - (height % 8)) height = max(32, height - (height % 8))
prompt = format_image_prompt(messages, prompt)
if prompt is None:
prompt = messages[-1]["content"]
payload = { payload = {
"data": [ "data": [
prompt, prompt,

View File

@@ -10,6 +10,7 @@ from ..base_provider import AsyncAuthedProvider, ProviderModelMixin, format_prom
from ..mini_max.crypt import CallbackResults, get_browser_callback, generate_yy_header, get_body_to_yy from ..mini_max.crypt import CallbackResults, get_browser_callback, generate_yy_header, get_body_to_yy
from ...requests import get_args_from_nodriver, raise_for_status from ...requests import get_args_from_nodriver, raise_for_status
from ...providers.response import AuthResult, JsonConversation, RequestLogin, TitleGeneration from ...providers.response import AuthResult, JsonConversation, RequestLogin, TitleGeneration
from ..helper import get_last_user_message
from ... import debug from ... import debug
class Conversation(JsonConversation): class Conversation(JsonConversation):
@@ -62,7 +63,7 @@ class HailuoAI(AsyncAuthedProvider, ProviderModelMixin):
conversation = None conversation = None
form_data = { form_data = {
"characterID": 1 if conversation is None else getattr(conversation, "characterID", 1), "characterID": 1 if conversation is None else getattr(conversation, "characterID", 1),
"msgContent": format_prompt(messages) if conversation is None else messages[-1]["content"], "msgContent": format_prompt(messages) if conversation is None else get_last_user_message(messages),
"chatID": 0 if conversation is None else getattr(conversation, "chatID", 0), "chatID": 0 if conversation is None else getattr(conversation, "chatID", 0),
"searchMode": 0 "searchMode": 0
} }

View File

@@ -98,7 +98,7 @@ class Anthropic(OpenaiAPI):
"text": messages[-1]["content"] "text": messages[-1]["content"]
} }
] ]
system = "\n".join([message for message in messages if message.get("role") == "system"]) system = "\n".join([message["content"] for message in messages if message.get("role") == "system"])
if system: if system:
messages = [message for message in messages if message.get("role") != "system"] messages = [message for message in messages if message.get("role") != "system"]
else: else:

View File

@@ -6,6 +6,7 @@ from ...errors import MissingAuthError
from ...typing import AsyncResult, Messages, Cookies from ...typing import AsyncResult, Messages, Cookies
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..bing.create_images import create_images, create_session from ..bing.create_images import create_images, create_session
from ..helper import format_image_prompt
class BingCreateImages(AsyncGeneratorProvider, ProviderModelMixin): class BingCreateImages(AsyncGeneratorProvider, ProviderModelMixin):
label = "Microsoft Designer in Bing" label = "Microsoft Designer in Bing"
@@ -35,7 +36,7 @@ class BingCreateImages(AsyncGeneratorProvider, ProviderModelMixin):
**kwargs **kwargs
) -> AsyncResult: ) -> AsyncResult:
session = BingCreateImages(cookies, proxy, api_key) session = BingCreateImages(cookies, proxy, api_key)
yield await session.generate(messages[-1]["content"] if prompt is None else prompt) yield await session.generate(format_image_prompt(messages, prompt))
async def generate(self, prompt: str) -> ImageResponse: async def generate(self, prompt: str) -> ImageResponse:
""" """

View File

@@ -5,6 +5,7 @@ from ...typing import AsyncResult, Messages
from ...requests import StreamSession, raise_for_status from ...requests import StreamSession, raise_for_status
from ...image import ImageResponse from ...image import ImageResponse
from ..template import OpenaiTemplate from ..template import OpenaiTemplate
from ..helper import format_image_prompt
class DeepInfra(OpenaiTemplate): class DeepInfra(OpenaiTemplate):
url = "https://deepinfra.com" url = "https://deepinfra.com"
@@ -55,7 +56,7 @@ class DeepInfra(OpenaiTemplate):
) -> AsyncResult: ) -> AsyncResult:
if model in cls.get_image_models(): if model in cls.get_image_models():
yield cls.create_async_image( yield cls.create_async_image(
messages[-1]["content"] if prompt is None else prompt, format_image_prompt(messages, prompt),
model, model,
**kwargs **kwargs
) )

View File

@@ -31,7 +31,7 @@ try:
challenge = self._get_pow_challenge() challenge = self._get_pow_challenge()
pow_response = self.pow_solver.solve_challenge(challenge) pow_response = self.pow_solver.solve_challenge(challenge)
headers = self._get_headers(pow_response) headers = self._get_headers(pow_response)
response = requests.request( response = requests.request(
method=method, method=method,
url=url, url=url,

View File

@@ -25,6 +25,7 @@ from ...requests.aiohttp import get_connector
from ...requests import get_nodriver from ...requests import get_nodriver
from ...errors import MissingAuthError from ...errors import MissingAuthError
from ...image import ImageResponse, to_bytes from ...image import ImageResponse, to_bytes
from ..helper import get_last_user_message
from ... import debug from ... import debug
REQUEST_HEADERS = { REQUEST_HEADERS = {
@@ -78,7 +79,7 @@ class Gemini(AsyncGeneratorProvider, ProviderModelMixin):
if debug.logging: if debug.logging:
print("Skip nodriver login in Gemini provider") print("Skip nodriver login in Gemini provider")
return return
browser = await get_nodriver(proxy=proxy, user_data_dir="gemini") browser, stop_browser = await get_nodriver(proxy=proxy, user_data_dir="gemini")
try: try:
login_url = os.environ.get("G4F_LOGIN_URL") login_url = os.environ.get("G4F_LOGIN_URL")
if login_url: if login_url:
@@ -91,7 +92,7 @@ class Gemini(AsyncGeneratorProvider, ProviderModelMixin):
await page.close() await page.close()
cls._cookies = cookies cls._cookies = cookies
finally: finally:
browser.stop() stop_browser()
@classmethod @classmethod
async def create_async_generator( async def create_async_generator(
@@ -107,7 +108,7 @@ class Gemini(AsyncGeneratorProvider, ProviderModelMixin):
language: str = "en", language: str = "en",
**kwargs **kwargs
) -> AsyncResult: ) -> AsyncResult:
prompt = format_prompt(messages) if conversation is None else messages[-1]["content"] prompt = format_prompt(messages) if conversation is None else get_last_user_message(messages)
cls._cookies = cookies or cls._cookies or get_cookies(".google.com", False, True) cls._cookies = cookies or cls._cookies or get_cookies(".google.com", False, True)
base_connector = get_connector(connector, proxy) base_connector = get_connector(connector, proxy)

View File

@@ -7,7 +7,7 @@ from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, BaseConv
from ...typing import AsyncResult, Messages, Cookies from ...typing import AsyncResult, Messages, Cookies
from ...requests.raise_for_status import raise_for_status from ...requests.raise_for_status import raise_for_status
from ...requests.aiohttp import get_connector from ...requests.aiohttp import get_connector
from ...providers.helper import format_prompt from ...providers.helper import format_prompt, get_last_user_message
from ...cookies import get_cookies from ...cookies import get_cookies
class Conversation(BaseConversation): class Conversation(BaseConversation):
@@ -78,7 +78,7 @@ class GithubCopilot(AsyncGeneratorProvider, ProviderModelMixin):
conversation_id = (await response.json()).get("thread_id") conversation_id = (await response.json()).get("thread_id")
if return_conversation: if return_conversation:
yield Conversation(conversation_id) yield Conversation(conversation_id)
content = messages[-1]["content"] content = get_last_user_message(messages)
else: else:
content = format_prompt(messages) content = format_prompt(messages)
json_data = { json_data = {

View File

@@ -14,7 +14,7 @@ from ...requests.aiohttp import get_connector
from ...requests import get_nodriver from ...requests import get_nodriver
from ..Copilot import get_headers, get_har_files from ..Copilot import get_headers, get_har_files
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..helper import get_random_hex from ..helper import get_random_hex, format_image_prompt
from ... import debug from ... import debug
class MicrosoftDesigner(AsyncGeneratorProvider, ProviderModelMixin): class MicrosoftDesigner(AsyncGeneratorProvider, ProviderModelMixin):
@@ -39,7 +39,7 @@ class MicrosoftDesigner(AsyncGeneratorProvider, ProviderModelMixin):
image_size = "1024x1024" image_size = "1024x1024"
if model != cls.default_image_model and model in cls.image_models: if model != cls.default_image_model and model in cls.image_models:
image_size = model image_size = model
yield await cls.generate(messages[-1]["content"] if prompt is None else prompt, image_size, proxy) yield await cls.generate(format_image_prompt(messages, prompt), image_size, proxy)
@classmethod @classmethod
async def generate(cls, prompt: str, image_size: str, proxy: str = None) -> ImageResponse: async def generate(cls, prompt: str, image_size: str, proxy: str = None) -> ImageResponse:
@@ -143,7 +143,7 @@ def readHAR(url: str) -> tuple[str, str]:
return api_key, user_agent return api_key, user_agent
async def get_access_token_and_user_agent(url: str, proxy: str = None): async def get_access_token_and_user_agent(url: str, proxy: str = None):
browser = await get_nodriver(proxy=proxy, user_data_dir="designer") browser, stop_browser = await get_nodriver(proxy=proxy, user_data_dir="designer")
try: try:
page = await browser.get(url) page = await browser.get(url)
user_agent = await page.evaluate("navigator.userAgent") user_agent = await page.evaluate("navigator.userAgent")
@@ -168,4 +168,4 @@ async def get_access_token_and_user_agent(url: str, proxy: str = None):
await page.close() await page.close()
return access_token, user_agent return access_token, user_agent
finally: finally:
browser.stop() stop_browser()

View File

@@ -98,7 +98,7 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
default_model = "auto" default_model = "auto"
default_image_model = "dall-e-3" default_image_model = "dall-e-3"
image_models = [default_image_model] image_models = [default_image_model]
text_models = [default_model, "gpt-4", "gpt-4o", "gpt-4o-mini", "gpt-4o-canmore", "o1", "o1-preview", "o1-mini"] text_models = [default_model, "gpt-4", "gpt-4o", "gpt-4o-mini", "o1", "o1-preview", "o1-mini"]
vision_models = text_models vision_models = text_models
models = text_models + image_models models = text_models + image_models
synthesize_content_type = "audio/mpeg" synthesize_content_type = "audio/mpeg"
@@ -598,7 +598,7 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
@classmethod @classmethod
async def nodriver_auth(cls, proxy: str = None): async def nodriver_auth(cls, proxy: str = None):
browser = await get_nodriver(proxy=proxy) browser, stop_browser = await get_nodriver(proxy=proxy)
try: try:
page = browser.main_tab page = browser.main_tab
def on_request(event: nodriver.cdp.network.RequestWillBeSent): def on_request(event: nodriver.cdp.network.RequestWillBeSent):
@@ -648,7 +648,7 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
cls._create_request_args(RequestConfig.cookies, RequestConfig.headers, user_agent=user_agent) cls._create_request_args(RequestConfig.cookies, RequestConfig.headers, user_agent=user_agent)
cls._set_api_key(cls._api_key) cls._set_api_key(cls._api_key)
finally: finally:
browser.stop() stop_browser()
@staticmethod @staticmethod
def get_default_headers() -> Dict[str, str]: def get_default_headers() -> Dict[str, str]:

View File

@@ -13,9 +13,6 @@ from .GigaChat import GigaChat
from .GithubCopilot import GithubCopilot from .GithubCopilot import GithubCopilot
from .GlhfChat import GlhfChat from .GlhfChat import GlhfChat
from .Groq import Groq from .Groq import Groq
from .HuggingChat import HuggingChat
from .HuggingFace import HuggingFace
from .HuggingFaceAPI import HuggingFaceAPI
from .MetaAI import MetaAI from .MetaAI import MetaAI
from .MetaAIAccount import MetaAIAccount from .MetaAIAccount import MetaAIAccount
from .MicrosoftDesigner import MicrosoftDesigner from .MicrosoftDesigner import MicrosoftDesigner

View File

@@ -8,11 +8,11 @@ from urllib.parse import quote_plus
from ...typing import Messages, AsyncResult from ...typing import Messages, AsyncResult
from ...requests import StreamSession from ...requests import StreamSession
from ...providers.base_provider import AsyncGeneratorProvider, ProviderModelMixin from ...providers.base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ...providers.response import ProviderInfo, JsonConversation, PreviewResponse, SynthesizeData, TitleGeneration, RequestLogin from ...providers.response import *
from ...providers.response import Parameters, FinishReason, Usage, Reasoning from ...image import get_image_extension
from ...errors import ModelNotSupportedError from ...errors import ModelNotSupportedError
from ..needs_auth.OpenaiAccount import OpenaiAccount from ..needs_auth.OpenaiAccount import OpenaiAccount
from ..needs_auth.HuggingChat import HuggingChat from ..hf.HuggingChat import HuggingChat
from ... import debug from ... import debug
class BackendApi(AsyncGeneratorProvider, ProviderModelMixin): class BackendApi(AsyncGeneratorProvider, ProviderModelMixin):
@@ -98,8 +98,7 @@ class BackendApi(AsyncGeneratorProvider, ProviderModelMixin):
yield PreviewResponse(data[data_type]) yield PreviewResponse(data[data_type])
elif data_type == "content": elif data_type == "content":
def on_image(match): def on_image(match):
extension = match.group(3).split(".")[-1].split("?")[0] extension = get_image_extension(match.group(3))
extension = "" if not extension or len(extension) > 4 else f".{extension}"
filename = f"{int(time.time())}_{quote_plus(match.group(1)[:100], '')}{extension}" filename = f"{int(time.time())}_{quote_plus(match.group(1)[:100], '')}{extension}"
download_url = f"/download/{filename}?url={cls.url}{match.group(3)}" download_url = f"/download/{filename}?url={cls.url}{match.group(3)}"
return f"[![{match.group(1)}]({download_url})](/images/{filename})" return f"[![{match.group(1)}]({download_url})](/images/{filename})"
@@ -119,6 +118,6 @@ class BackendApi(AsyncGeneratorProvider, ProviderModelMixin):
elif data_type == "finish": elif data_type == "finish":
yield FinishReason(data[data_type]["reason"]) yield FinishReason(data[data_type]["reason"])
elif data_type == "log": elif data_type == "log":
debug.log(data[data_type]) yield DebugResponse.from_dict(data[data_type])
else: else:
debug.log(f"Unknown data: ({data_type}) {data}") yield DebugResponse.from_dict(data)

View File

@@ -4,11 +4,11 @@ import json
import time import time
import requests import requests
from ..helper import filter_none from ..helper import filter_none, format_image_prompt
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
from ...typing import Union, Optional, AsyncResult, Messages, ImagesType from ...typing import Union, Optional, AsyncResult, Messages, ImagesType
from ...requests import StreamSession, raise_for_status from ...requests import StreamSession, raise_for_status
from ...providers.response import FinishReason, ToolCalls, Usage, Reasoning, ImageResponse from ...providers.response import FinishReason, ToolCalls, Usage, ImageResponse
from ...errors import MissingAuthError, ResponseError from ...errors import MissingAuthError, ResponseError
from ...image import to_data_uri from ...image import to_data_uri
from ... import debug from ... import debug
@@ -82,7 +82,7 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
# Proxy for image generation feature # Proxy for image generation feature
if model and model in cls.image_models: if model and model in cls.image_models:
data = { data = {
"prompt": messages[-1]["content"] if prompt is None else prompt, "prompt": format_image_prompt(messages, prompt),
"model": model, "model": model,
} }
async with session.post(f"{api_base.rstrip('/')}/images/generations", json=data, ssl=cls.ssl) as response: async with session.post(f"{api_base.rstrip('/')}/images/generations", json=data, ssl=cls.ssl) as response:
@@ -154,17 +154,7 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
delta = delta.lstrip() delta = delta.lstrip()
if delta: if delta:
first = False first = False
if is_thinking: yield delta
if "</think>" in delta:
yield Reasoning(None, f"Finished in {round(time.time()-is_thinking, 2)} seconds")
is_thinking = 0
else:
yield Reasoning(delta)
elif "<think>" in delta:
is_thinking = time.time()
yield Reasoning(None, "Is thinking...")
else:
yield delta
if "usage" in data and data["usage"]: if "usage" in data and data["usage"]:
yield Usage(**data["usage"]) yield Usage(**data["usage"])
if "finish_reason" in choice and choice["finish_reason"] is not None: if "finish_reason" in choice and choice["finish_reason"] is not None:

View File

@@ -6,6 +6,7 @@ import uvicorn
import secrets import secrets
import os import os
import shutil import shutil
import time
from email.utils import formatdate from email.utils import formatdate
import os.path import os.path
import hashlib import hashlib
@@ -538,6 +539,10 @@ class Api:
response_data = provider_handler.synthesize({**request.query_params}) response_data = provider_handler.synthesize({**request.query_params})
content_type = getattr(provider_handler, "synthesize_content_type", "application/octet-stream") content_type = getattr(provider_handler, "synthesize_content_type", "application/octet-stream")
return StreamingResponse(response_data, media_type=content_type) return StreamingResponse(response_data, media_type=content_type)
@self.app.get("/json/{filename}")
async def get_json(filename, request: Request):
return ""
@self.app.get("/images/{filename}", response_class=FileResponse, responses={ @self.app.get("/images/{filename}", response_class=FileResponse, responses={
HTTP_200_OK: {"content": {"image/*": {}}}, HTTP_200_OK: {"content": {"image/*": {}}},
@@ -550,15 +555,18 @@ class Api:
stat_result.st_size = 0 stat_result.st_size = 0
if os.path.isfile(target): if os.path.isfile(target):
stat_result.st_size = os.stat(target).st_size stat_result.st_size = os.stat(target).st_size
stat_result.st_mtime = int(f"{filename.split('_')[0]}") stat_result.st_mtime = int(f"{filename.split('_')[0]}") if filename.startswith("1") else 0
headers = {
"cache-control": "public, max-age=31536000",
"content-type": f"image/{ext.replace('jpg', 'jepg')}",
"content-length": str(stat_result.st_size),
"last-modified": formatdate(stat_result.st_mtime, usegmt=True),
"etag": f'"{hashlib.md5(filename.encode()).hexdigest()}"',
}
response = FileResponse( response = FileResponse(
target, target,
media_type=f"image/{ext.replace('jpg', 'jepg')}", headers=headers,
headers={ filename=filename,
"content-length": str(stat_result.st_size),
"last-modified": formatdate(stat_result.st_mtime, usegmt=True),
"etag": f'"{hashlib.md5(filename.encode()).hexdigest()}"'
},
) )
try: try:
if_none_match = request.headers["if-none-match"] if_none_match = request.headers["if-none-match"]

View File

@@ -8,11 +8,12 @@ try:
except ImportError as e: except ImportError as e:
import_error = e import_error = e
def get_gui_app(demo: bool = False): def get_gui_app(demo: bool = False, api: bool = False):
if import_error is not None: if import_error is not None:
raise MissingRequirementsError(f'Install "gui" requirements | pip install -U g4f[gui]\n{import_error}') raise MissingRequirementsError(f'Install "gui" requirements | pip install -U g4f[gui]\n{import_error}')
app = create_app() app = create_app()
app.demo = demo app.demo = demo
app.api = api
site = Website(app) site = Website(app)
for route in site.routes: for route in site.routes:

View File

@@ -13,8 +13,8 @@ from ...tools.run_tools import iter_run_tools
from ...Provider import ProviderUtils, __providers__ from ...Provider import ProviderUtils, __providers__
from ...providers.base_provider import ProviderModelMixin from ...providers.base_provider import ProviderModelMixin
from ...providers.retry_provider import BaseRetryProvider from ...providers.retry_provider import BaseRetryProvider
from ...providers.response import BaseConversation, JsonConversation, FinishReason, Usage, Reasoning, PreviewResponse from ...providers.helper import format_image_prompt
from ...providers.response import SynthesizeData, TitleGeneration, RequestLogin, Parameters, ProviderInfo from ...providers.response import *
from ... import version, models from ... import version, models
from ... import ChatCompletion, get_model_and_provider from ... import ChatCompletion, get_model_and_provider
from ... import debug from ... import debug
@@ -183,13 +183,14 @@ class Api:
logger.exception(chunk) logger.exception(chunk)
yield self._format_json('message', get_error_message(chunk), error=type(chunk).__name__) yield self._format_json('message', get_error_message(chunk), error=type(chunk).__name__)
elif isinstance(chunk, (PreviewResponse, ImagePreview)): elif isinstance(chunk, (PreviewResponse, ImagePreview)):
yield self._format_json("preview", chunk.to_string()) yield self._format_json("preview", chunk.to_string(), images=chunk.images, alt=chunk.alt)
elif isinstance(chunk, ImageResponse): elif isinstance(chunk, ImageResponse):
images = chunk images = chunk
if download_images: if download_images or chunk.get("cookies"):
images = asyncio.run(copy_images(chunk.get_list(), chunk.get("cookies"), proxy)) alt = format_image_prompt(kwargs.get("messages"))
images = asyncio.run(copy_images(chunk.get_list(), chunk.get("cookies"), proxy, alt))
images = ImageResponse(images, chunk.alt) images = ImageResponse(images, chunk.alt)
yield self._format_json("content", str(images)) yield self._format_json("content", str(images), images=chunk.get_list(), alt=chunk.alt)
elif isinstance(chunk, SynthesizeData): elif isinstance(chunk, SynthesizeData):
yield self._format_json("synthesize", chunk.get_dict()) yield self._format_json("synthesize", chunk.get_dict())
elif isinstance(chunk, TitleGeneration): elif isinstance(chunk, TitleGeneration):
@@ -203,7 +204,11 @@ class Api:
elif isinstance(chunk, Usage): elif isinstance(chunk, Usage):
yield self._format_json("usage", chunk.get_dict()) yield self._format_json("usage", chunk.get_dict())
elif isinstance(chunk, Reasoning): elif isinstance(chunk, Reasoning):
yield self._format_json("reasoning", token=chunk.token, status=chunk.status) yield self._format_json("reasoning", token=chunk.token, status=chunk.status, is_thinking=chunk.is_thinking)
elif isinstance(chunk, DebugResponse):
yield self._format_json("log", chunk.get_dict())
elif isinstance(chunk, Notification):
yield self._format_json("notification", chunk.message)
else: else:
yield self._format_json("content", str(chunk)) yield self._format_json("content", str(chunk))
if debug.logs: if debug.logs:
@@ -219,6 +224,15 @@ class Api:
yield self._format_json('error', type(e).__name__, message=get_error_message(e)) yield self._format_json('error', type(e).__name__, message=get_error_message(e))
def _format_json(self, response_type: str, content = None, **kwargs): def _format_json(self, response_type: str, content = None, **kwargs):
# Make sure it get be formated as JSON
if content is not None and not isinstance(content, (str, dict)):
content = str(content)
kwargs = {
key: value
if value is isinstance(value, (str, dict))
else str(value)
for key, value in kwargs.items()
if isinstance(key, str)}
if content is not None: if content is not None:
return { return {
'type': response_type, 'type': response_type,

View File

@@ -7,6 +7,8 @@ import time
import uuid import uuid
import base64 import base64
import asyncio import asyncio
import hashlib
from urllib.parse import quote_plus
from io import BytesIO from io import BytesIO
from pathlib import Path from pathlib import Path
from aiohttp import ClientSession, ClientError from aiohttp import ClientSession, ClientError
@@ -239,10 +241,16 @@ def to_data_uri(image: ImageType) -> str:
def ensure_images_dir(): def ensure_images_dir():
os.makedirs(images_dir, exist_ok=True) os.makedirs(images_dir, exist_ok=True)
def get_image_extension(image: str) -> str:
if match := re.search(r"(\.(?:jpe?g|png|webp))[$?&]", image):
return match.group(1)
return ".jpg"
async def copy_images( async def copy_images(
images: list[str], images: list[str],
cookies: Optional[Cookies] = None, cookies: Optional[Cookies] = None,
proxy: Optional[str] = None, proxy: Optional[str] = None,
alt: str = None,
add_url: bool = True, add_url: bool = True,
target: str = None, target: str = None,
ssl: bool = None ssl: bool = None
@@ -256,7 +264,10 @@ async def copy_images(
) as session: ) as session:
async def copy_image(image: str, target: str = None) -> str: async def copy_image(image: str, target: str = None) -> str:
if target is None or len(images) > 1: if target is None or len(images) > 1:
target = os.path.join(images_dir, f"{int(time.time())}_{str(uuid.uuid4())}") hash = hashlib.sha256(image.encode()).hexdigest()
target = f"{quote_plus('+'.join(alt.split()[:10])[:100], '')}_{hash}" if alt else str(uuid.uuid4())
target = f"{int(time.time())}_{target}{get_image_extension(image)}"
target = os.path.join(images_dir, target)
try: try:
if image.startswith("data:"): if image.startswith("data:"):
with open(target, "wb") as f: with open(target, "wb") as f:

View File

@@ -69,9 +69,13 @@ def to_sync_generator(generator: AsyncIterator, stream: bool = True) -> Iterator
loop.close() loop.close()
# Helper function to convert a synchronous iterator to an async iterator # Helper function to convert a synchronous iterator to an async iterator
async def to_async_iterator(iterator: Iterator) -> AsyncIterator: async def to_async_iterator(iterator) -> AsyncIterator:
try: if hasattr(iterator, '__aiter__'):
async for item in iterator: async for item in iterator:
yield item yield item
return
try:
for item in iterator:
yield item
except TypeError: except TypeError:
yield await iterator yield await iterator

View File

@@ -27,6 +27,23 @@ def format_prompt(messages: Messages, add_special_tokens: bool = False, do_conti
return formatted return formatted
return f"{formatted}\nAssistant:" return f"{formatted}\nAssistant:"
def get_last_user_message(messages: Messages) -> str:
user_messages = []
last_message = None if len(messages) == 0 else messages[-1]
while last_message is not None and messages:
last_message = messages.pop()
if last_message["role"] == "user":
if isinstance(last_message["content"], str):
user_messages.append(last_message["content"].strip())
else:
return "\n".join(user_messages[::-1])
return "\n".join(user_messages[::-1])
def format_image_prompt(messages, prompt: str = None) -> str:
if prompt is None:
return get_last_user_message(messages)
return prompt
def format_prompt_max_length(messages: Messages, max_lenght: int) -> str: def format_prompt_max_length(messages: Messages, max_lenght: int) -> str:
prompt = format_prompt(messages) prompt = format_prompt(messages)
start = len(prompt) start = len(prompt)

View File

@@ -88,44 +88,61 @@ class JsonMixin:
def reset(self): def reset(self):
self.__dict__ = {} self.__dict__ = {}
class FinishReason(ResponseType, JsonMixin): class HiddenResponse(ResponseType):
def __str__(self) -> str:
return ""
class FinishReason(JsonMixin, HiddenResponse):
def __init__(self, reason: str) -> None: def __init__(self, reason: str) -> None:
self.reason = reason self.reason = reason
def __str__(self) -> str: class ToolCalls(HiddenResponse):
return ""
class ToolCalls(ResponseType):
def __init__(self, list: list): def __init__(self, list: list):
self.list = list self.list = list
def __str__(self) -> str:
return ""
def get_list(self) -> list: def get_list(self) -> list:
return self.list return self.list
class Usage(ResponseType, JsonMixin): class Usage(JsonMixin, HiddenResponse):
def __str__(self) -> str: pass
return ""
class AuthResult(JsonMixin): class AuthResult(JsonMixin, HiddenResponse):
def __str__(self) -> str: pass
return ""
class TitleGeneration(ResponseType): class TitleGeneration(HiddenResponse):
def __init__(self, title: str) -> None: def __init__(self, title: str) -> None:
self.title = title self.title = title
class DebugResponse(JsonMixin, HiddenResponse):
@classmethod
def from_dict(cls, data: dict) -> None:
return cls(**data)
@classmethod
def from_str(cls, data: str) -> None:
return cls(error=data)
class Notification(ResponseType):
def __init__(self, message: str) -> None:
self.message = message
def __str__(self) -> str: def __str__(self) -> str:
return "" return f"{self.message}\n"
class Reasoning(ResponseType): class Reasoning(ResponseType):
def __init__(self, token: str = None, status: str = None) -> None: def __init__(
self,
token: str = None,
status: str = None,
is_thinking: str = None
) -> None:
self.token = token self.token = token
self.status = status self.status = status
self.is_thinking = is_thinking
def __str__(self) -> str: def __str__(self) -> str:
if self.is_thinking is not None:
return self.is_thinking
return f"{self.status}\n" if self.token is None else self.token return f"{self.status}\n" if self.token is None else self.token
class Sources(ResponseType): class Sources(ResponseType):
@@ -154,14 +171,11 @@ class BaseConversation(ResponseType):
class JsonConversation(BaseConversation, JsonMixin): class JsonConversation(BaseConversation, JsonMixin):
pass pass
class SynthesizeData(ResponseType, JsonMixin): class SynthesizeData(HiddenResponse, JsonMixin):
def __init__(self, provider: str, data: dict): def __init__(self, provider: str, data: dict):
self.provider = provider self.provider = provider
self.data = data self.data = data
def __str__(self) -> str:
return ""
class RequestLogin(ResponseType): class RequestLogin(ResponseType):
def __init__(self, label: str, login_url: str) -> None: def __init__(self, label: str, login_url: str) -> None:
self.label = label self.label = label
@@ -197,13 +211,10 @@ class ImagePreview(ImageResponse):
def to_string(self): def to_string(self):
return super().__str__() return super().__str__()
class PreviewResponse(ResponseType): class PreviewResponse(HiddenResponse):
def __init__(self, data: str): def __init__(self, data: str):
self.data = data self.data = data
def __str__(self):
return ""
def to_string(self): def to_string(self):
return self.data return self.data
@@ -211,6 +222,5 @@ class Parameters(ResponseType, JsonMixin):
def __str__(self): def __str__(self):
return "" return ""
class ProviderInfo(ResponseType, JsonMixin): class ProviderInfo(JsonMixin, HiddenResponse):
def __str__(self): pass
return ""

View File

@@ -87,7 +87,7 @@ async def get_args_from_nodriver(
callback: callable = None, callback: callable = None,
cookies: Cookies = None cookies: Cookies = None
) -> dict: ) -> dict:
browser = await get_nodriver(proxy=proxy, timeout=timeout) browser, stop_browser = await get_nodriver(proxy=proxy, timeout=timeout)
try: try:
if debug.logging: if debug.logging:
print(f"Open nodriver with url: {url}") print(f"Open nodriver with url: {url}")
@@ -117,7 +117,7 @@ async def get_args_from_nodriver(
"proxy": proxy, "proxy": proxy,
} }
finally: finally:
browser.stop() stop_browser()
def merge_cookies(cookies: Iterator[Morsel], response: Response) -> Cookies: def merge_cookies(cookies: Iterator[Morsel], response: Response) -> Cookies:
if cookies is None: if cookies is None:
@@ -170,11 +170,10 @@ async def get_nodriver(
browser = util.get_registered_instances().pop() browser = util.get_registered_instances().pop()
else: else:
raise raise
stop = browser.stop
def on_stop(): def on_stop():
try: try:
stop() if browser.connection:
browser.stop()
finally: finally:
lock_file.unlink(missing_ok=True) lock_file.unlink(missing_ok=True)
browser.stop = on_stop return browser, on_stop
return browser

View File

@@ -3,12 +3,14 @@ from __future__ import annotations
import re import re
import json import json
import asyncio import asyncio
import time
from pathlib import Path from pathlib import Path
from typing import Optional, Callable, AsyncIterator from typing import Optional, Callable, AsyncIterator
from ..typing import Messages from ..typing import Messages
from ..providers.helper import filter_none from ..providers.helper import filter_none
from ..providers.asyncio import to_async_iterator from ..providers.asyncio import to_async_iterator
from ..providers.response import Reasoning
from ..providers.types import ProviderType from ..providers.types import ProviderType
from ..cookies import get_cookies_dir from ..cookies import get_cookies_dir
from .web_search import do_search, get_search_message from .web_search import do_search, get_search_message
@@ -147,4 +149,26 @@ def iter_run_tools(
if has_bucket and isinstance(messages[-1]["content"], str): if has_bucket and isinstance(messages[-1]["content"], str):
messages[-1]["content"] += BUCKET_INSTRUCTIONS messages[-1]["content"] += BUCKET_INSTRUCTIONS
return iter_callback(model=model, messages=messages, provider=provider, **kwargs) is_thinking = 0
for chunk in iter_callback(model=model, messages=messages, provider=provider, **kwargs):
if not isinstance(chunk, str):
yield chunk
continue
if "<think>" in chunk:
chunk = chunk.split("<think>", 1)
yield chunk[0]
yield Reasoning(is_thinking="<think>")
yield Reasoning(chunk[1])
yield Reasoning(None, "Is thinking...")
is_thinking = time.time()
if "</think>" in chunk:
chunk = chunk.split("</think>", 1)
yield Reasoning(chunk[0])
yield Reasoning(is_thinking="</think>")
yield Reasoning(None, f"Finished in {round(time.time()-is_thinking, 2)} seconds")
yield chunk[1]
is_thinking = 0
elif is_thinking:
yield Reasoning(chunk)
else:
yield chunk

View File

@@ -192,6 +192,8 @@ async def search(query: str, max_results: int = 5, max_words: int = 2500, backen
return SearchResults(formatted_results, used_words) return SearchResults(formatted_results, used_words)
async def do_search(prompt: str, query: str = None, instructions: str = DEFAULT_INSTRUCTIONS, **kwargs) -> str: async def do_search(prompt: str, query: str = None, instructions: str = DEFAULT_INSTRUCTIONS, **kwargs) -> str:
if instructions and instructions in prompt:
return prompt # We have already added search results
if query is None: if query is None:
query = spacy_get_keywords(prompt) query = spacy_get_keywords(prompt)
json_bytes = json.dumps({"query": query, **kwargs}, sort_keys=True).encode(errors="ignore") json_bytes = json.dumps({"query": query, **kwargs}, sort_keys=True).encode(errors="ignore")