Support reasoning tokens by default

Add new default HuggingFace provider
Add format_image_prompt and get_last_user_message helper
Add stop_browser callable to get_nodriver function
Fix content type response in images route
This commit is contained in:
hlohaus
2025-01-31 17:36:48 +01:00
parent ce7e9b03a5
commit 89e096334d
41 changed files with 377 additions and 240 deletions

View File

@@ -13,7 +13,7 @@ from ..requests.raise_for_status import raise_for_status
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..image import ImageResponse, to_data_uri
from ..cookies import get_cookies_dir
from .helper import format_prompt
from .helper import format_prompt, format_image_prompt
from ..providers.response import FinishReason, JsonConversation, Reasoning
class Conversation(JsonConversation):
@@ -216,9 +216,8 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
async with ClientSession(headers=headers) as session:
if model == "ImageGeneration2":
prompt = messages[-1]["content"]
data = {
"query": prompt,
"query": format_image_prompt(messages, prompt),
"agentMode": True
}
headers['content-type'] = 'text/plain;charset=UTF-8'
@@ -307,8 +306,7 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
image_url_match = re.search(r'!\[.*?\]\((.*?)\)', text_to_yield)
if image_url_match:
image_url = image_url_match.group(1)
prompt = messages[-1]["content"]
yield ImageResponse(images=[image_url], alt=prompt)
yield ImageResponse(image_url, format_image_prompt(messages, prompt))
else:
if "<think>" in text_to_yield and "</think>" in chunk_text :
chunk_text = text_to_yield.split('<think>', 1)

View File

@@ -28,6 +28,7 @@ from ..providers.response import BaseConversation, JsonConversation, RequestLogi
from ..providers.asyncio import get_running_loop
from ..requests import get_nodriver
from ..image import ImageResponse, to_bytes, is_accepted_format
from .helper import get_last_user_message
from .. import debug
class Conversation(JsonConversation):
@@ -139,7 +140,7 @@ class Copilot(AbstractProvider, ProviderModelMixin):
else:
conversation_id = conversation.conversation_id
if prompt is None:
prompt = messages[-1]["content"]
prompt = get_last_user_message(messages)
debug.log(f"Copilot: Use conversation: {conversation_id}")
uploaded_images = []
@@ -206,7 +207,7 @@ class Copilot(AbstractProvider, ProviderModelMixin):
yield Parameters(**{"cookies": {c.name: c.value for c in session.cookies.jar}})
async def get_access_token_and_cookies(url: str, proxy: str = None, target: str = "ChatAI",):
browser = await get_nodriver(proxy=proxy, user_data_dir="copilot")
browser, stop_browser = await get_nodriver(proxy=proxy, user_data_dir="copilot")
try:
page = await browser.get(url)
access_token = None
@@ -233,7 +234,7 @@ async def get_access_token_and_cookies(url: str, proxy: str = None, target: str
await page.close()
return access_token, cookies
finally:
browser.stop()
stop_browser()
def readHAR(url: str):
api_key = None

View File

@@ -9,7 +9,6 @@ class OIVSCode(OpenaiTemplate):
working = True
needs_auth = False
default_model = "gpt-4o-mini-2024-07-18"
default_model = "gpt-4o-mini"
default_vision_model = default_model
vision_models = [default_model, "gpt-4o-mini"]
model_aliases = {"gpt-4o-mini": "gpt-4o-mini-2024-07-18"}

View File

@@ -5,7 +5,7 @@ import json
from ..typing import AsyncResult, Messages
from ..requests import StreamSession, raise_for_status
from ..providers.response import Reasoning, FinishReason
from ..providers.response import FinishReason
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
API_URL = "https://www.perplexity.ai/socket.io/"
@@ -87,22 +87,7 @@ class PerplexityLabs(AsyncGeneratorProvider, ProviderModelMixin):
continue
try:
data = json.loads(message[2:])[1]
new_content = data["output"][last_message:]
if "<think>" in new_content:
yield Reasoning(None, "thinking")
is_thinking = True
if "</think>" in new_content:
new_content = new_content.split("</think>", 1)
yield Reasoning(f"{new_content[0]}</think>")
yield Reasoning(None, "finished")
yield new_content[1]
is_thinking = False
elif is_thinking:
yield Reasoning(new_content)
else:
yield new_content
yield data["output"][last_message:]
last_message = len(data["output"])
if data["final"]:
yield FinishReason("stop")

View File

@@ -7,7 +7,7 @@ from urllib.parse import quote_plus
from typing import Optional
from aiohttp import ClientSession
from .helper import filter_none
from .helper import filter_none, format_image_prompt
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..typing import AsyncResult, Messages, ImagesType
from ..image import to_data_uri
@@ -127,7 +127,7 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
if model in cls.image_models:
yield await cls._generate_image(
model=model,
prompt=messages[-1]["content"] if prompt is None else prompt,
prompt=format_image_prompt(messages, prompt),
proxy=proxy,
width=width,
height=height,

View File

@@ -77,7 +77,7 @@ class You(AsyncGeneratorProvider, ProviderModelMixin):
except MissingRequirementsError:
pass
if not cookies or "afUserId" not in cookies:
browser = await get_nodriver(proxy=proxy)
browser, stop_browser = await get_nodriver(proxy=proxy)
try:
page = await browser.get(cls.url)
await page.wait_for('[data-testid="user-profile-button"]', timeout=900)
@@ -86,7 +86,7 @@ class You(AsyncGeneratorProvider, ProviderModelMixin):
cookies[c.name] = c.value
await page.close()
finally:
browser.stop()
stop_browser()
async with StreamSession(
proxy=proxy,
impersonate="chrome",

View File

@@ -9,6 +9,7 @@ from .deprecated import *
from .needs_auth import *
from .not_working import *
from .local import *
from .hf import HuggingFace, HuggingChat, HuggingFaceAPI, HuggingFaceInference
from .hf_space import HuggingSpace
from .mini_max import HailuoAI, MiniMax
from .template import OpenaiTemplate, BackendApi

View File

@@ -14,7 +14,7 @@ except ImportError:
has_curl_cffi = False
from ..base_provider import ProviderModelMixin, AsyncAuthedProvider, AuthResult
from ..helper import format_prompt
from ..helper import format_prompt, format_image_prompt, get_last_user_message
from ...typing import AsyncResult, Messages, Cookies, ImagesType
from ...errors import MissingRequirementsError, MissingAuthError, ResponseError
from ...image import to_bytes
@@ -22,6 +22,7 @@ from ...requests import get_args_from_nodriver, DEFAULT_HEADERS
from ...requests.raise_for_status import raise_for_status
from ...providers.response import JsonConversation, ImageResponse, Sources, TitleGeneration, Reasoning, RequestLogin
from ...cookies import get_cookies
from .models import default_model, fallback_models, image_models, model_aliases
from ... import debug
class Conversation(JsonConversation):
@@ -35,47 +36,9 @@ class HuggingChat(AsyncAuthedProvider, ProviderModelMixin):
use_nodriver = True
supports_stream = True
needs_auth = True
default_model = "Qwen/Qwen2.5-72B-Instruct"
default_image_model = "black-forest-labs/FLUX.1-dev"
image_models = [
default_image_model,
"black-forest-labs/FLUX.1-schnell",
]
fallback_models = [
default_model,
'meta-llama/Llama-3.3-70B-Instruct',
'CohereForAI/c4ai-command-r-plus-08-2024',
'deepseek-ai/DeepSeek-R1-Distill-Qwen-32B',
'Qwen/QwQ-32B-Preview',
'nvidia/Llama-3.1-Nemotron-70B-Instruct-HF',
'Qwen/Qwen2.5-Coder-32B-Instruct',
'meta-llama/Llama-3.2-11B-Vision-Instruct',
'mistralai/Mistral-Nemo-Instruct-2407',
'microsoft/Phi-3.5-mini-instruct',
] + image_models
model_aliases = {
### Chat ###
"qwen-2.5-72b": "Qwen/Qwen2.5-Coder-32B-Instruct",
"llama-3.3-70b": "meta-llama/Llama-3.3-70B-Instruct",
"command-r-plus": "CohereForAI/c4ai-command-r-plus-08-2024",
"deepseek-r1": "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
"qwq-32b": "Qwen/QwQ-32B-Preview",
"nemotron-70b": "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF",
"qwen-2.5-coder-32b": "Qwen/Qwen2.5-Coder-32B-Instruct",
"llama-3.2-11b": "meta-llama/Llama-3.2-11B-Vision-Instruct",
"mistral-nemo": "mistralai/Mistral-Nemo-Instruct-2407",
"phi-3.5-mini": "microsoft/Phi-3.5-mini-instruct",
### Image ###
"flux-dev": "black-forest-labs/FLUX.1-dev",
"flux-schnell": "black-forest-labs/FLUX.1-schnell",
### Used in other providers ###
"qwen-2-vl-7b": "Qwen/Qwen2-VL-7B-Instruct",
"gemma-2-27b": "google/gemma-2-27b-it",
"qwen-2-72b": "Qwen/Qwen2-72B-Instruct",
"qvq-72b": "Qwen/QVQ-72B-Preview",
"sd-3.5": "stabilityai/stable-diffusion-3.5-large",
}
default_model = default_model
model_aliases = model_aliases
image_models = image_models
@classmethod
def get_models(cls):
@@ -94,7 +57,7 @@ class HuggingChat(AsyncAuthedProvider, ProviderModelMixin):
cls.vision_models = [model["id"] for model in models if model["multimodal"]]
except Exception as e:
debug.log(f"HuggingChat: Error reading models: {type(e).__name__}: {e}")
cls.models = [*cls.fallback_models]
cls.models = [*fallback_models]
return cls.models
@classmethod
@@ -108,9 +71,7 @@ class HuggingChat(AsyncAuthedProvider, ProviderModelMixin):
headers=DEFAULT_HEADERS
)
return
login_url = os.environ.get("G4F_LOGIN_URL")
if login_url:
yield RequestLogin(cls.__name__, login_url)
yield RequestLogin(cls.__name__, os.environ.get("G4F_LOGIN_URL") or "")
yield AuthResult(
**await get_args_from_nodriver(
cls.url,
@@ -143,6 +104,7 @@ class HuggingChat(AsyncAuthedProvider, ProviderModelMixin):
if model not in conversation.models:
conversationId = cls.create_conversation(session, model)
debug.log(f"Conversation created: {json.dumps(conversationId[8:] + '...')}")
messageId = cls.fetch_message_id(session, conversationId)
conversation.models[model] = {"conversationId": conversationId, "messageId": messageId}
if return_conversation:
@@ -151,9 +113,7 @@ class HuggingChat(AsyncAuthedProvider, ProviderModelMixin):
else:
conversationId = conversation.models[model]["conversationId"]
conversation.models[model]["messageId"] = cls.fetch_message_id(session, conversationId)
inputs = messages[-1]["content"]
debug.log(f"Use: {json.dumps(conversation.models[model])}")
inputs = get_last_user_message(messages)
settings = {
"inputs": inputs,
@@ -204,8 +164,7 @@ class HuggingChat(AsyncAuthedProvider, ProviderModelMixin):
break
elif line["type"] == "file":
url = f"https://huggingface.co/chat/conversation/{conversationId}/output/{line['sha']}"
prompt = messages[-1]["content"] if prompt is None else prompt
yield ImageResponse(url, alt=prompt, options={"cookies": auth_result.cookies})
yield ImageResponse(url, format_image_prompt(messages, prompt), options={"cookies": auth_result.cookies})
elif line["type"] == "webSearch" and "sources" in line:
sources = Sources(line["sources"])
elif line["type"] == "title":
@@ -226,6 +185,8 @@ class HuggingChat(AsyncAuthedProvider, ProviderModelMixin):
response = session.post('https://huggingface.co/chat/conversation', json=json_data)
if response.status_code == 401:
raise MissingAuthError(response.text)
if response.status_code == 400:
raise ResponseError(f"{response.text}: Model: {model}")
raise_for_status(response)
return response.json().get('conversationId')

View File

@@ -1,8 +1,9 @@
from __future__ import annotations
from ..template import OpenaiTemplate
from .HuggingChat import HuggingChat
from ..template.OpenaiTemplate import OpenaiTemplate
from .models import model_aliases
from ...providers.types import Messages
from .HuggingChat import HuggingChat
from ... import debug
class HuggingFaceAPI(OpenaiTemplate):
@@ -16,13 +17,16 @@ class HuggingFaceAPI(OpenaiTemplate):
default_model = "meta-llama/Llama-3.2-11B-Vision-Instruct"
default_vision_model = default_model
vision_models = [default_vision_model, "Qwen/Qwen2-VL-7B-Instruct"]
model_aliases = HuggingChat.model_aliases
model_aliases = model_aliases
@classmethod
def get_models(cls, **kwargs):
if not cls.models:
HuggingChat.get_models()
cls.models = list(set(HuggingChat.text_models + cls.vision_models))
cls.models = HuggingChat.text_models.copy()
for model in cls.vision_models:
if model not in cls.models:
cls.models.append(model)
return cls.models
@classmethod

View File

@@ -11,37 +11,32 @@ from ...errors import ModelNotFoundError, ModelNotSupportedError, ResponseError
from ...requests import StreamSession, raise_for_status
from ...providers.response import FinishReason
from ...image import ImageResponse
from ..helper import format_image_prompt
from .models import default_model, default_image_model, model_aliases, fallback_models
from ... import debug
from .HuggingChat import HuggingChat
class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
class HuggingFaceInference(AsyncGeneratorProvider, ProviderModelMixin):
url = "https://huggingface.co"
login_url = "https://huggingface.co/settings/tokens"
working = True
supports_message_history = True
default_model = HuggingChat.default_model
default_image_model = HuggingChat.default_image_model
model_aliases = HuggingChat.model_aliases
extra_models = [
"meta-llama/Llama-3.2-11B-Vision-Instruct",
"nvidia/Llama-3.1-Nemotron-70B-Instruct-HF",
"NousResearch/Hermes-3-Llama-3.1-8B",
]
default_model = default_model
default_image_model = default_image_model
model_aliases = model_aliases
@classmethod
def get_models(cls) -> list[str]:
if not cls.models:
models = fallback_models.copy()
url = "https://huggingface.co/api/models?inference=warm&pipeline_tag=text-generation"
models = [model["id"] for model in requests.get(url).json()]
models.extend(cls.extra_models)
models.sort()
extra_models = [model["id"] for model in requests.get(url).json()]
extra_models.sort()
models.extend([model for model in extra_models if model not in models])
if not cls.image_models:
url = "https://huggingface.co/api/models?pipeline_tag=text-to-image"
cls.image_models = [model["id"] for model in requests.get(url).json() if model["trendingScore"] >= 20]
cls.image_models.sort()
models.extend(cls.image_models)
cls.models = list(set(models))
models.extend([model for model in cls.image_models if model not in models])
cls.models = models
return cls.models
@classmethod
@@ -85,7 +80,7 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
payload = None
if cls.get_models() and model in cls.image_models:
stream = False
prompt = messages[-1]["content"] if prompt is None else prompt
prompt = format_image_prompt(messages, prompt)
payload = {"inputs": prompt, "parameters": {"seed": random.randint(0, 2**32), **extra_data}}
else:
params = {

View File

@@ -0,0 +1,62 @@
from __future__ import annotations
import random
from ...typing import AsyncResult, Messages
from ...providers.response import ImageResponse
from ...errors import ModelNotSupportedError
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from .HuggingChat import HuggingChat
from .HuggingFaceAPI import HuggingFaceAPI
from .HuggingFaceInference import HuggingFaceInference
from ... import debug
class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
url = "https://huggingface.co"
login_url = "https://huggingface.co/settings/tokens"
working = True
supports_message_history = True
@classmethod
def get_models(cls) -> list[str]:
if not cls.models:
cls.models = HuggingFaceInference.get_models()
cls.image_models = HuggingFaceInference.image_models
return cls.models
@classmethod
async def create_async_generator(
cls,
model: str,
messages: Messages,
**kwargs
) -> AsyncResult:
if "api_key" not in kwargs and "images" not in kwargs and random.random() >= 0.5:
try:
is_started = False
async for chunk in HuggingFaceInference.create_async_generator(model, messages, **kwargs):
if isinstance(chunk, (str, ImageResponse)):
is_started = True
yield chunk
if is_started:
return
except Exception as e:
if is_started:
raise e
debug.log(f"Inference failed: {e.__class__.__name__}: {e}")
if not cls.image_models:
cls.get_models()
if model in cls.image_models:
if "api_key" not in kwargs:
async for chunk in HuggingChat.create_async_generator(model, messages, **kwargs):
yield chunk
else:
async for chunk in HuggingFaceInference.create_async_generator(model, messages, **kwargs):
yield chunk
return
try:
async for chunk in HuggingFaceAPI.create_async_generator(model, messages, **kwargs):
yield chunk
except ModelNotSupportedError:
async for chunk in HuggingFaceInference.create_async_generator(model, messages, **kwargs):
yield chunk

46
g4f/Provider/hf/models.py Normal file
View File

@@ -0,0 +1,46 @@
default_model = "Qwen/Qwen2.5-72B-Instruct"
default_image_model = "black-forest-labs/FLUX.1-dev"
image_models = [
default_image_model,
"black-forest-labs/FLUX.1-schnell",
]
fallback_models = [
default_model,
'meta-llama/Llama-3.3-70B-Instruct',
'CohereForAI/c4ai-command-r-plus-08-2024',
'deepseek-ai/DeepSeek-R1-Distill-Qwen-32B',
'Qwen/QwQ-32B-Preview',
'nvidia/Llama-3.1-Nemotron-70B-Instruct-HF',
'Qwen/Qwen2.5-Coder-32B-Instruct',
'meta-llama/Llama-3.2-11B-Vision-Instruct',
'mistralai/Mistral-Nemo-Instruct-2407',
'microsoft/Phi-3.5-mini-instruct',
] + image_models
model_aliases = {
### Chat ###
"qwen-2.5-72b": "Qwen/Qwen2.5-Coder-32B-Instruct",
"llama-3.3-70b": "meta-llama/Llama-3.3-70B-Instruct",
"command-r-plus": "CohereForAI/c4ai-command-r-plus-08-2024",
"deepseek-r1": "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
"qwq-32b": "Qwen/QwQ-32B-Preview",
"nemotron-70b": "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF",
"qwen-2.5-coder-32b": "Qwen/Qwen2.5-Coder-32B-Instruct",
"llama-3.2-11b": "meta-llama/Llama-3.2-11B-Vision-Instruct",
"mistral-nemo": "mistralai/Mistral-Nemo-Instruct-2407",
"phi-3.5-mini": "microsoft/Phi-3.5-mini-instruct",
### Image ###
"flux": "black-forest-labs/FLUX.1-dev",
"flux-dev": "black-forest-labs/FLUX.1-dev",
"flux-schnell": "black-forest-labs/FLUX.1-schnell",
### Used in other providers ###
"qwen-2-vl-7b": "Qwen/Qwen2-VL-7B-Instruct",
"gemma-2-27b": "google/gemma-2-27b-it",
"qwen-2-72b": "Qwen/Qwen2-72B-Instruct",
"qvq-72b": "Qwen/QVQ-72B-Preview",
"sd-3.5": "stabilityai/stable-diffusion-3.5-large",
}
extra_models = [
"meta-llama/Llama-3.2-11B-Vision-Instruct",
"nvidia/Llama-3.1-Nemotron-70B-Instruct-HF",
"NousResearch/Hermes-3-Llama-3.1-8B",
]

View File

@@ -7,6 +7,7 @@ from ...typing import AsyncResult, Messages
from ...image import ImageResponse, ImagePreview
from ...errors import ResponseError
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..helper import format_image_prompt
class BlackForestLabsFlux1Dev(AsyncGeneratorProvider, ProviderModelMixin):
url = "https://black-forest-labs-flux-1-dev.hf.space"
@@ -44,7 +45,7 @@ class BlackForestLabsFlux1Dev(AsyncGeneratorProvider, ProviderModelMixin):
if api_key is not None:
headers["Authorization"] = f"Bearer {api_key}"
async with ClientSession(headers=headers) as session:
prompt = messages[-1]["content"] if prompt is None else prompt
prompt = format_image_prompt(messages, prompt)
data = {
"data": [prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps]
}

View File

@@ -8,6 +8,7 @@ from ...image import ImageResponse
from ...errors import ResponseError
from ...requests.raise_for_status import raise_for_status
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..helper import format_image_prompt
class BlackForestLabsFlux1Schnell(AsyncGeneratorProvider, ProviderModelMixin):
url = "https://black-forest-labs-flux-1-schnell.hf.space"
@@ -42,7 +43,7 @@ class BlackForestLabsFlux1Schnell(AsyncGeneratorProvider, ProviderModelMixin):
height = max(32, height - (height % 8))
if prompt is None:
prompt = messages[-1]["content"]
prompt = format_image_prompt(messages)
payload = {
"data": [

View File

@@ -6,7 +6,7 @@ from aiohttp import ClientSession, FormData
from ...typing import AsyncResult, Messages
from ...requests import raise_for_status
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..helper import format_prompt
from ..helper import format_prompt, get_last_user_message
from ...providers.response import JsonConversation, TitleGeneration
class CohereForAI(AsyncGeneratorProvider, ProviderModelMixin):
@@ -58,7 +58,7 @@ class CohereForAI(AsyncGeneratorProvider, ProviderModelMixin):
) as session:
system_prompt = "\n".join([message["content"] for message in messages if message["role"] == "system"])
messages = [message for message in messages if message["role"] != "system"]
inputs = format_prompt(messages) if conversation is None else messages[-1]["content"]
inputs = format_prompt(messages) if conversation is None else get_last_user_message(messages)
if conversation is None or conversation.model != model or conversation.preprompt != system_prompt:
data = {"model": model, "preprompt": system_prompt}
async with session.post(cls.conversation_url, json=data, proxy=proxy) as response:
@@ -78,7 +78,6 @@ class CohereForAI(AsyncGeneratorProvider, ProviderModelMixin):
data = node["data"]
message_id = data[data[data[data[0]["messages"]][-1]]["id"]]
data = FormData()
inputs = messages[-1]["content"]
data.add_field(
"data",
json.dumps({"inputs": inputs, "id": message_id, "is_retry": False, "is_continue": False, "web_search": False, "tools": []}),

View File

@@ -8,8 +8,8 @@ import urllib.parse
from ...typing import AsyncResult, Messages, Cookies
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..helper import format_prompt
from ...providers.response import JsonConversation, ImageResponse
from ..helper import format_prompt, format_image_prompt
from ...providers.response import JsonConversation, ImageResponse, Notification
from ...requests.aiohttp import StreamSession, StreamResponse
from ...requests.raise_for_status import raise_for_status
from ...cookies import get_cookies
@@ -38,7 +38,7 @@ class Janus_Pro_7B(AsyncGeneratorProvider, ProviderModelMixin):
"headers": {
"content-type": "application/json",
"x-zerogpu-token": conversation.zerogpu_token,
"x-zerogpu-uuid": conversation.uuid,
"x-zerogpu-uuid": conversation.zerogpu_uuid,
"referer": cls.referer,
},
"json": {"data":[None,prompt,42,0.95,0.1],"event_data":None,"fn_index":2,"trigger_id":10,"session_hash":conversation.session_hash},
@@ -48,7 +48,7 @@ class Janus_Pro_7B(AsyncGeneratorProvider, ProviderModelMixin):
"headers": {
"content-type": "application/json",
"x-zerogpu-token": conversation.zerogpu_token,
"x-zerogpu-uuid": conversation.uuid,
"x-zerogpu-uuid": conversation.zerogpu_uuid,
"referer": cls.referer,
},
"json": {"data":[prompt,1234,5,1],"event_data":None,"fn_index":3,"trigger_id":20,"session_hash":conversation.session_hash},
@@ -82,33 +82,14 @@ class Janus_Pro_7B(AsyncGeneratorProvider, ProviderModelMixin):
method = "image"
prompt = format_prompt(messages) if prompt is None and conversation is None else prompt
prompt = messages[-1]["content"] if prompt is None else prompt
prompt = format_image_prompt(messages, prompt)
session_hash = generate_session_hash() if conversation is None else getattr(conversation, "session_hash")
async with StreamSession(proxy=proxy, impersonate="chrome") as session:
session_hash = generate_session_hash() if conversation is None else getattr(conversation, "session_hash")
user_uuid = None if conversation is None else getattr(conversation, "user_uuid", None)
zerogpu_token = "[object Object]"
cookies = get_cookies("huggingface.co", raise_requirements_error=False) if cookies is None else cookies
if cookies:
# Get current UTC time + 10 minutes
dt = (datetime.now(timezone.utc) + timedelta(minutes=10)).isoformat(timespec='milliseconds')
encoded_dt = urllib.parse.quote(dt)
async with session.get(f"https://huggingface.co/api/spaces/deepseek-ai/Janus-Pro-7B/jwt?expiration={encoded_dt}&include_pro_status=true", cookies=cookies) as response:
zerogpu_token = (await response.json())
zerogpu_token = zerogpu_token["token"]
if user_uuid is None:
async with session.get(cls.url, cookies=cookies) as response:
match = re.search(r"&quot;token&quot;:&quot;([^&]+?)&quot;", await response.text())
if match:
zerogpu_token = match.group(1)
match = re.search(r"&quot;sessionUuid&quot;:&quot;([^&]+?)&quot;", await response.text())
if match:
user_uuid = match.group(1)
zerogpu_uuid, zerogpu_token = await get_zerogpu_token(session, conversation, cookies)
if conversation is None or not hasattr(conversation, "session_hash"):
conversation = JsonConversation(session_hash=session_hash, zerogpu_token=zerogpu_token, uuid=user_uuid)
conversation = JsonConversation(session_hash=session_hash, zerogpu_token=zerogpu_token, zerogpu_uuid=zerogpu_uuid)
conversation.zerogpu_token = zerogpu_token
if return_conversation:
yield conversation
@@ -124,7 +105,7 @@ class Janus_Pro_7B(AsyncGeneratorProvider, ProviderModelMixin):
try:
json_data = json.loads(decoded_line[6:])
if json_data.get('msg') == 'log':
debug.log(json_data["log"])
yield Notification(json_data["log"])
if json_data.get('msg') == 'process_generating':
if 'output' in json_data and 'data' in json_data['output']:
@@ -142,3 +123,26 @@ class Janus_Pro_7B(AsyncGeneratorProvider, ProviderModelMixin):
except json.JSONDecodeError:
debug.log("Could not parse JSON:", decoded_line)
async def get_zerogpu_token(session: StreamSession, conversation: JsonConversation, cookies: Cookies = None):
zerogpu_uuid = None if conversation is None else getattr(conversation, "zerogpu_uuid", None)
zerogpu_token = "[object Object]"
cookies = get_cookies("huggingface.co", raise_requirements_error=False) if cookies is None else cookies
if zerogpu_uuid is None:
async with session.get(Janus_Pro_7B.url, cookies=cookies) as response:
match = re.search(r"&quot;token&quot;:&quot;([^&]+?)&quot;", await response.text())
if match:
zerogpu_token = match.group(1)
match = re.search(r"&quot;sessionUuid&quot;:&quot;([^&]+?)&quot;", await response.text())
if match:
zerogpu_uuid = match.group(1)
if cookies:
# Get current UTC time + 10 minutes
dt = (datetime.now(timezone.utc) + timedelta(minutes=10)).isoformat(timespec='milliseconds')
encoded_dt = urllib.parse.quote(dt)
async with session.get(f"https://huggingface.co/api/spaces/deepseek-ai/Janus-Pro-7B/jwt?expiration={encoded_dt}&include_pro_status=true", cookies=cookies) as response:
zerogpu_token = (await response.json())
zerogpu_token = zerogpu_token["token"]
return zerogpu_uuid, zerogpu_token

View File

@@ -8,6 +8,7 @@ from ...typing import AsyncResult, Messages
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..helper import format_prompt
from ...providers.response import JsonConversation, Reasoning
from ..helper import get_last_user_message
from ... import debug
class Qwen_Qwen_2_5M_Demo(AsyncGeneratorProvider, ProviderModelMixin):
@@ -41,7 +42,7 @@ class Qwen_Qwen_2_5M_Demo(AsyncGeneratorProvider, ProviderModelMixin):
if return_conversation:
yield JsonConversation(session_hash=session_hash)
prompt = format_prompt(messages) if conversation is None else messages[-1]["content"]
prompt = format_prompt(messages) if conversation is None else get_last_user_message(messages)
headers = {
'accept': '*/*',

View File

@@ -7,6 +7,7 @@ from ...typing import AsyncResult, Messages
from ...image import ImageResponse, ImagePreview
from ...errors import ResponseError
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..helper import format_image_prompt
class StableDiffusion35Large(AsyncGeneratorProvider, ProviderModelMixin):
url = "https://stabilityai-stable-diffusion-3-5-large.hf.space"
@@ -42,7 +43,7 @@ class StableDiffusion35Large(AsyncGeneratorProvider, ProviderModelMixin):
if api_key is not None:
headers["Authorization"] = f"Bearer {api_key}"
async with ClientSession(headers=headers) as session:
prompt = messages[-1]["content"] if prompt is None else prompt
prompt = format_image_prompt(messages, prompt)
data = {
"data": [prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps]
}

View File

@@ -7,6 +7,7 @@ from ...typing import AsyncResult, Messages
from ...image import ImageResponse
from ...errors import ResponseError
from ...requests.raise_for_status import raise_for_status
from ..helper import format_image_prompt
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
class VoodoohopFlux1Schnell(AsyncGeneratorProvider, ProviderModelMixin):
@@ -37,10 +38,7 @@ class VoodoohopFlux1Schnell(AsyncGeneratorProvider, ProviderModelMixin):
) -> AsyncResult:
width = max(32, width - (width % 8))
height = max(32, height - (height % 8))
if prompt is None:
prompt = messages[-1]["content"]
prompt = format_image_prompt(messages, prompt)
payload = {
"data": [
prompt,

View File

@@ -10,6 +10,7 @@ from ..base_provider import AsyncAuthedProvider, ProviderModelMixin, format_prom
from ..mini_max.crypt import CallbackResults, get_browser_callback, generate_yy_header, get_body_to_yy
from ...requests import get_args_from_nodriver, raise_for_status
from ...providers.response import AuthResult, JsonConversation, RequestLogin, TitleGeneration
from ..helper import get_last_user_message
from ... import debug
class Conversation(JsonConversation):
@@ -62,7 +63,7 @@ class HailuoAI(AsyncAuthedProvider, ProviderModelMixin):
conversation = None
form_data = {
"characterID": 1 if conversation is None else getattr(conversation, "characterID", 1),
"msgContent": format_prompt(messages) if conversation is None else messages[-1]["content"],
"msgContent": format_prompt(messages) if conversation is None else get_last_user_message(messages),
"chatID": 0 if conversation is None else getattr(conversation, "chatID", 0),
"searchMode": 0
}

View File

@@ -98,7 +98,7 @@ class Anthropic(OpenaiAPI):
"text": messages[-1]["content"]
}
]
system = "\n".join([message for message in messages if message.get("role") == "system"])
system = "\n".join([message["content"] for message in messages if message.get("role") == "system"])
if system:
messages = [message for message in messages if message.get("role") != "system"]
else:

View File

@@ -6,6 +6,7 @@ from ...errors import MissingAuthError
from ...typing import AsyncResult, Messages, Cookies
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..bing.create_images import create_images, create_session
from ..helper import format_image_prompt
class BingCreateImages(AsyncGeneratorProvider, ProviderModelMixin):
label = "Microsoft Designer in Bing"
@@ -35,7 +36,7 @@ class BingCreateImages(AsyncGeneratorProvider, ProviderModelMixin):
**kwargs
) -> AsyncResult:
session = BingCreateImages(cookies, proxy, api_key)
yield await session.generate(messages[-1]["content"] if prompt is None else prompt)
yield await session.generate(format_image_prompt(messages, prompt))
async def generate(self, prompt: str) -> ImageResponse:
"""

View File

@@ -5,6 +5,7 @@ from ...typing import AsyncResult, Messages
from ...requests import StreamSession, raise_for_status
from ...image import ImageResponse
from ..template import OpenaiTemplate
from ..helper import format_image_prompt
class DeepInfra(OpenaiTemplate):
url = "https://deepinfra.com"
@@ -55,7 +56,7 @@ class DeepInfra(OpenaiTemplate):
) -> AsyncResult:
if model in cls.get_image_models():
yield cls.create_async_image(
messages[-1]["content"] if prompt is None else prompt,
format_image_prompt(messages, prompt),
model,
**kwargs
)

View File

@@ -25,6 +25,7 @@ from ...requests.aiohttp import get_connector
from ...requests import get_nodriver
from ...errors import MissingAuthError
from ...image import ImageResponse, to_bytes
from ..helper import get_last_user_message
from ... import debug
REQUEST_HEADERS = {
@@ -78,7 +79,7 @@ class Gemini(AsyncGeneratorProvider, ProviderModelMixin):
if debug.logging:
print("Skip nodriver login in Gemini provider")
return
browser = await get_nodriver(proxy=proxy, user_data_dir="gemini")
browser, stop_browser = await get_nodriver(proxy=proxy, user_data_dir="gemini")
try:
login_url = os.environ.get("G4F_LOGIN_URL")
if login_url:
@@ -91,7 +92,7 @@ class Gemini(AsyncGeneratorProvider, ProviderModelMixin):
await page.close()
cls._cookies = cookies
finally:
browser.stop()
stop_browser()
@classmethod
async def create_async_generator(
@@ -107,7 +108,7 @@ class Gemini(AsyncGeneratorProvider, ProviderModelMixin):
language: str = "en",
**kwargs
) -> AsyncResult:
prompt = format_prompt(messages) if conversation is None else messages[-1]["content"]
prompt = format_prompt(messages) if conversation is None else get_last_user_message(messages)
cls._cookies = cookies or cls._cookies or get_cookies(".google.com", False, True)
base_connector = get_connector(connector, proxy)

View File

@@ -7,7 +7,7 @@ from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, BaseConv
from ...typing import AsyncResult, Messages, Cookies
from ...requests.raise_for_status import raise_for_status
from ...requests.aiohttp import get_connector
from ...providers.helper import format_prompt
from ...providers.helper import format_prompt, get_last_user_message
from ...cookies import get_cookies
class Conversation(BaseConversation):
@@ -78,7 +78,7 @@ class GithubCopilot(AsyncGeneratorProvider, ProviderModelMixin):
conversation_id = (await response.json()).get("thread_id")
if return_conversation:
yield Conversation(conversation_id)
content = messages[-1]["content"]
content = get_last_user_message(messages)
else:
content = format_prompt(messages)
json_data = {

View File

@@ -14,7 +14,7 @@ from ...requests.aiohttp import get_connector
from ...requests import get_nodriver
from ..Copilot import get_headers, get_har_files
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..helper import get_random_hex
from ..helper import get_random_hex, format_image_prompt
from ... import debug
class MicrosoftDesigner(AsyncGeneratorProvider, ProviderModelMixin):
@@ -39,7 +39,7 @@ class MicrosoftDesigner(AsyncGeneratorProvider, ProviderModelMixin):
image_size = "1024x1024"
if model != cls.default_image_model and model in cls.image_models:
image_size = model
yield await cls.generate(messages[-1]["content"] if prompt is None else prompt, image_size, proxy)
yield await cls.generate(format_image_prompt(messages, prompt), image_size, proxy)
@classmethod
async def generate(cls, prompt: str, image_size: str, proxy: str = None) -> ImageResponse:
@@ -143,7 +143,7 @@ def readHAR(url: str) -> tuple[str, str]:
return api_key, user_agent
async def get_access_token_and_user_agent(url: str, proxy: str = None):
browser = await get_nodriver(proxy=proxy, user_data_dir="designer")
browser, stop_browser = await get_nodriver(proxy=proxy, user_data_dir="designer")
try:
page = await browser.get(url)
user_agent = await page.evaluate("navigator.userAgent")
@@ -168,4 +168,4 @@ async def get_access_token_and_user_agent(url: str, proxy: str = None):
await page.close()
return access_token, user_agent
finally:
browser.stop()
stop_browser()

View File

@@ -98,7 +98,7 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
default_model = "auto"
default_image_model = "dall-e-3"
image_models = [default_image_model]
text_models = [default_model, "gpt-4", "gpt-4o", "gpt-4o-mini", "gpt-4o-canmore", "o1", "o1-preview", "o1-mini"]
text_models = [default_model, "gpt-4", "gpt-4o", "gpt-4o-mini", "o1", "o1-preview", "o1-mini"]
vision_models = text_models
models = text_models + image_models
synthesize_content_type = "audio/mpeg"
@@ -598,7 +598,7 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
@classmethod
async def nodriver_auth(cls, proxy: str = None):
browser = await get_nodriver(proxy=proxy)
browser, stop_browser = await get_nodriver(proxy=proxy)
try:
page = browser.main_tab
def on_request(event: nodriver.cdp.network.RequestWillBeSent):
@@ -648,7 +648,7 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
cls._create_request_args(RequestConfig.cookies, RequestConfig.headers, user_agent=user_agent)
cls._set_api_key(cls._api_key)
finally:
browser.stop()
stop_browser()
@staticmethod
def get_default_headers() -> Dict[str, str]:

View File

@@ -13,9 +13,6 @@ from .GigaChat import GigaChat
from .GithubCopilot import GithubCopilot
from .GlhfChat import GlhfChat
from .Groq import Groq
from .HuggingChat import HuggingChat
from .HuggingFace import HuggingFace
from .HuggingFaceAPI import HuggingFaceAPI
from .MetaAI import MetaAI
from .MetaAIAccount import MetaAIAccount
from .MicrosoftDesigner import MicrosoftDesigner

View File

@@ -8,11 +8,11 @@ from urllib.parse import quote_plus
from ...typing import Messages, AsyncResult
from ...requests import StreamSession
from ...providers.base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ...providers.response import ProviderInfo, JsonConversation, PreviewResponse, SynthesizeData, TitleGeneration, RequestLogin
from ...providers.response import Parameters, FinishReason, Usage, Reasoning
from ...providers.response import *
from ...image import get_image_extension
from ...errors import ModelNotSupportedError
from ..needs_auth.OpenaiAccount import OpenaiAccount
from ..needs_auth.HuggingChat import HuggingChat
from ..hf.HuggingChat import HuggingChat
from ... import debug
class BackendApi(AsyncGeneratorProvider, ProviderModelMixin):
@@ -98,8 +98,7 @@ class BackendApi(AsyncGeneratorProvider, ProviderModelMixin):
yield PreviewResponse(data[data_type])
elif data_type == "content":
def on_image(match):
extension = match.group(3).split(".")[-1].split("?")[0]
extension = "" if not extension or len(extension) > 4 else f".{extension}"
extension = get_image_extension(match.group(3))
filename = f"{int(time.time())}_{quote_plus(match.group(1)[:100], '')}{extension}"
download_url = f"/download/{filename}?url={cls.url}{match.group(3)}"
return f"[![{match.group(1)}]({download_url})](/images/{filename})"
@@ -119,6 +118,6 @@ class BackendApi(AsyncGeneratorProvider, ProviderModelMixin):
elif data_type == "finish":
yield FinishReason(data[data_type]["reason"])
elif data_type == "log":
debug.log(data[data_type])
yield DebugResponse.from_dict(data[data_type])
else:
debug.log(f"Unknown data: ({data_type}) {data}")
yield DebugResponse.from_dict(data)

View File

@@ -4,11 +4,11 @@ import json
import time
import requests
from ..helper import filter_none
from ..helper import filter_none, format_image_prompt
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
from ...typing import Union, Optional, AsyncResult, Messages, ImagesType
from ...requests import StreamSession, raise_for_status
from ...providers.response import FinishReason, ToolCalls, Usage, Reasoning, ImageResponse
from ...providers.response import FinishReason, ToolCalls, Usage, ImageResponse
from ...errors import MissingAuthError, ResponseError
from ...image import to_data_uri
from ... import debug
@@ -82,7 +82,7 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
# Proxy for image generation feature
if model and model in cls.image_models:
data = {
"prompt": messages[-1]["content"] if prompt is None else prompt,
"prompt": format_image_prompt(messages, prompt),
"model": model,
}
async with session.post(f"{api_base.rstrip('/')}/images/generations", json=data, ssl=cls.ssl) as response:
@@ -154,17 +154,7 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
delta = delta.lstrip()
if delta:
first = False
if is_thinking:
if "</think>" in delta:
yield Reasoning(None, f"Finished in {round(time.time()-is_thinking, 2)} seconds")
is_thinking = 0
else:
yield Reasoning(delta)
elif "<think>" in delta:
is_thinking = time.time()
yield Reasoning(None, "Is thinking...")
else:
yield delta
yield delta
if "usage" in data and data["usage"]:
yield Usage(**data["usage"])
if "finish_reason" in choice and choice["finish_reason"] is not None:

View File

@@ -6,6 +6,7 @@ import uvicorn
import secrets
import os
import shutil
import time
from email.utils import formatdate
import os.path
import hashlib
@@ -539,6 +540,10 @@ class Api:
content_type = getattr(provider_handler, "synthesize_content_type", "application/octet-stream")
return StreamingResponse(response_data, media_type=content_type)
@self.app.get("/json/{filename}")
async def get_json(filename, request: Request):
return ""
@self.app.get("/images/{filename}", response_class=FileResponse, responses={
HTTP_200_OK: {"content": {"image/*": {}}},
HTTP_404_NOT_FOUND: {}
@@ -550,15 +555,18 @@ class Api:
stat_result.st_size = 0
if os.path.isfile(target):
stat_result.st_size = os.stat(target).st_size
stat_result.st_mtime = int(f"{filename.split('_')[0]}")
stat_result.st_mtime = int(f"{filename.split('_')[0]}") if filename.startswith("1") else 0
headers = {
"cache-control": "public, max-age=31536000",
"content-type": f"image/{ext.replace('jpg', 'jepg')}",
"content-length": str(stat_result.st_size),
"last-modified": formatdate(stat_result.st_mtime, usegmt=True),
"etag": f'"{hashlib.md5(filename.encode()).hexdigest()}"',
}
response = FileResponse(
target,
media_type=f"image/{ext.replace('jpg', 'jepg')}",
headers={
"content-length": str(stat_result.st_size),
"last-modified": formatdate(stat_result.st_mtime, usegmt=True),
"etag": f'"{hashlib.md5(filename.encode()).hexdigest()}"'
},
headers=headers,
filename=filename,
)
try:
if_none_match = request.headers["if-none-match"]

View File

@@ -8,11 +8,12 @@ try:
except ImportError as e:
import_error = e
def get_gui_app(demo: bool = False):
def get_gui_app(demo: bool = False, api: bool = False):
if import_error is not None:
raise MissingRequirementsError(f'Install "gui" requirements | pip install -U g4f[gui]\n{import_error}')
app = create_app()
app.demo = demo
app.api = api
site = Website(app)
for route in site.routes:

View File

@@ -13,8 +13,8 @@ from ...tools.run_tools import iter_run_tools
from ...Provider import ProviderUtils, __providers__
from ...providers.base_provider import ProviderModelMixin
from ...providers.retry_provider import BaseRetryProvider
from ...providers.response import BaseConversation, JsonConversation, FinishReason, Usage, Reasoning, PreviewResponse
from ...providers.response import SynthesizeData, TitleGeneration, RequestLogin, Parameters, ProviderInfo
from ...providers.helper import format_image_prompt
from ...providers.response import *
from ... import version, models
from ... import ChatCompletion, get_model_and_provider
from ... import debug
@@ -183,13 +183,14 @@ class Api:
logger.exception(chunk)
yield self._format_json('message', get_error_message(chunk), error=type(chunk).__name__)
elif isinstance(chunk, (PreviewResponse, ImagePreview)):
yield self._format_json("preview", chunk.to_string())
yield self._format_json("preview", chunk.to_string(), images=chunk.images, alt=chunk.alt)
elif isinstance(chunk, ImageResponse):
images = chunk
if download_images:
images = asyncio.run(copy_images(chunk.get_list(), chunk.get("cookies"), proxy))
if download_images or chunk.get("cookies"):
alt = format_image_prompt(kwargs.get("messages"))
images = asyncio.run(copy_images(chunk.get_list(), chunk.get("cookies"), proxy, alt))
images = ImageResponse(images, chunk.alt)
yield self._format_json("content", str(images))
yield self._format_json("content", str(images), images=chunk.get_list(), alt=chunk.alt)
elif isinstance(chunk, SynthesizeData):
yield self._format_json("synthesize", chunk.get_dict())
elif isinstance(chunk, TitleGeneration):
@@ -203,7 +204,11 @@ class Api:
elif isinstance(chunk, Usage):
yield self._format_json("usage", chunk.get_dict())
elif isinstance(chunk, Reasoning):
yield self._format_json("reasoning", token=chunk.token, status=chunk.status)
yield self._format_json("reasoning", token=chunk.token, status=chunk.status, is_thinking=chunk.is_thinking)
elif isinstance(chunk, DebugResponse):
yield self._format_json("log", chunk.get_dict())
elif isinstance(chunk, Notification):
yield self._format_json("notification", chunk.message)
else:
yield self._format_json("content", str(chunk))
if debug.logs:
@@ -219,6 +224,15 @@ class Api:
yield self._format_json('error', type(e).__name__, message=get_error_message(e))
def _format_json(self, response_type: str, content = None, **kwargs):
# Make sure it get be formated as JSON
if content is not None and not isinstance(content, (str, dict)):
content = str(content)
kwargs = {
key: value
if value is isinstance(value, (str, dict))
else str(value)
for key, value in kwargs.items()
if isinstance(key, str)}
if content is not None:
return {
'type': response_type,

View File

@@ -7,6 +7,8 @@ import time
import uuid
import base64
import asyncio
import hashlib
from urllib.parse import quote_plus
from io import BytesIO
from pathlib import Path
from aiohttp import ClientSession, ClientError
@@ -239,10 +241,16 @@ def to_data_uri(image: ImageType) -> str:
def ensure_images_dir():
os.makedirs(images_dir, exist_ok=True)
def get_image_extension(image: str) -> str:
if match := re.search(r"(\.(?:jpe?g|png|webp))[$?&]", image):
return match.group(1)
return ".jpg"
async def copy_images(
images: list[str],
cookies: Optional[Cookies] = None,
proxy: Optional[str] = None,
alt: str = None,
add_url: bool = True,
target: str = None,
ssl: bool = None
@@ -256,7 +264,10 @@ async def copy_images(
) as session:
async def copy_image(image: str, target: str = None) -> str:
if target is None or len(images) > 1:
target = os.path.join(images_dir, f"{int(time.time())}_{str(uuid.uuid4())}")
hash = hashlib.sha256(image.encode()).hexdigest()
target = f"{quote_plus('+'.join(alt.split()[:10])[:100], '')}_{hash}" if alt else str(uuid.uuid4())
target = f"{int(time.time())}_{target}{get_image_extension(image)}"
target = os.path.join(images_dir, target)
try:
if image.startswith("data:"):
with open(target, "wb") as f:

View File

@@ -69,9 +69,13 @@ def to_sync_generator(generator: AsyncIterator, stream: bool = True) -> Iterator
loop.close()
# Helper function to convert a synchronous iterator to an async iterator
async def to_async_iterator(iterator: Iterator) -> AsyncIterator:
try:
async def to_async_iterator(iterator) -> AsyncIterator:
if hasattr(iterator, '__aiter__'):
async for item in iterator:
yield item
return
try:
for item in iterator:
yield item
except TypeError:
yield await iterator

View File

@@ -27,6 +27,23 @@ def format_prompt(messages: Messages, add_special_tokens: bool = False, do_conti
return formatted
return f"{formatted}\nAssistant:"
def get_last_user_message(messages: Messages) -> str:
user_messages = []
last_message = None if len(messages) == 0 else messages[-1]
while last_message is not None and messages:
last_message = messages.pop()
if last_message["role"] == "user":
if isinstance(last_message["content"], str):
user_messages.append(last_message["content"].strip())
else:
return "\n".join(user_messages[::-1])
return "\n".join(user_messages[::-1])
def format_image_prompt(messages, prompt: str = None) -> str:
if prompt is None:
return get_last_user_message(messages)
return prompt
def format_prompt_max_length(messages: Messages, max_lenght: int) -> str:
prompt = format_prompt(messages)
start = len(prompt)

View File

@@ -88,44 +88,61 @@ class JsonMixin:
def reset(self):
self.__dict__ = {}
class FinishReason(ResponseType, JsonMixin):
class HiddenResponse(ResponseType):
def __str__(self) -> str:
return ""
class FinishReason(JsonMixin, HiddenResponse):
def __init__(self, reason: str) -> None:
self.reason = reason
def __str__(self) -> str:
return ""
class ToolCalls(ResponseType):
class ToolCalls(HiddenResponse):
def __init__(self, list: list):
self.list = list
def __str__(self) -> str:
return ""
def get_list(self) -> list:
return self.list
class Usage(ResponseType, JsonMixin):
def __str__(self) -> str:
return ""
class Usage(JsonMixin, HiddenResponse):
pass
class AuthResult(JsonMixin):
def __str__(self) -> str:
return ""
class AuthResult(JsonMixin, HiddenResponse):
pass
class TitleGeneration(ResponseType):
class TitleGeneration(HiddenResponse):
def __init__(self, title: str) -> None:
self.title = title
class DebugResponse(JsonMixin, HiddenResponse):
@classmethod
def from_dict(cls, data: dict) -> None:
return cls(**data)
@classmethod
def from_str(cls, data: str) -> None:
return cls(error=data)
class Notification(ResponseType):
def __init__(self, message: str) -> None:
self.message = message
def __str__(self) -> str:
return ""
return f"{self.message}\n"
class Reasoning(ResponseType):
def __init__(self, token: str = None, status: str = None) -> None:
def __init__(
self,
token: str = None,
status: str = None,
is_thinking: str = None
) -> None:
self.token = token
self.status = status
self.is_thinking = is_thinking
def __str__(self) -> str:
if self.is_thinking is not None:
return self.is_thinking
return f"{self.status}\n" if self.token is None else self.token
class Sources(ResponseType):
@@ -154,14 +171,11 @@ class BaseConversation(ResponseType):
class JsonConversation(BaseConversation, JsonMixin):
pass
class SynthesizeData(ResponseType, JsonMixin):
class SynthesizeData(HiddenResponse, JsonMixin):
def __init__(self, provider: str, data: dict):
self.provider = provider
self.data = data
def __str__(self) -> str:
return ""
class RequestLogin(ResponseType):
def __init__(self, label: str, login_url: str) -> None:
self.label = label
@@ -197,13 +211,10 @@ class ImagePreview(ImageResponse):
def to_string(self):
return super().__str__()
class PreviewResponse(ResponseType):
class PreviewResponse(HiddenResponse):
def __init__(self, data: str):
self.data = data
def __str__(self):
return ""
def to_string(self):
return self.data
@@ -211,6 +222,5 @@ class Parameters(ResponseType, JsonMixin):
def __str__(self):
return ""
class ProviderInfo(ResponseType, JsonMixin):
def __str__(self):
return ""
class ProviderInfo(JsonMixin, HiddenResponse):
pass

View File

@@ -87,7 +87,7 @@ async def get_args_from_nodriver(
callback: callable = None,
cookies: Cookies = None
) -> dict:
browser = await get_nodriver(proxy=proxy, timeout=timeout)
browser, stop_browser = await get_nodriver(proxy=proxy, timeout=timeout)
try:
if debug.logging:
print(f"Open nodriver with url: {url}")
@@ -117,7 +117,7 @@ async def get_args_from_nodriver(
"proxy": proxy,
}
finally:
browser.stop()
stop_browser()
def merge_cookies(cookies: Iterator[Morsel], response: Response) -> Cookies:
if cookies is None:
@@ -170,11 +170,10 @@ async def get_nodriver(
browser = util.get_registered_instances().pop()
else:
raise
stop = browser.stop
def on_stop():
try:
stop()
if browser.connection:
browser.stop()
finally:
lock_file.unlink(missing_ok=True)
browser.stop = on_stop
return browser
return browser, on_stop

View File

@@ -3,12 +3,14 @@ from __future__ import annotations
import re
import json
import asyncio
import time
from pathlib import Path
from typing import Optional, Callable, AsyncIterator
from ..typing import Messages
from ..providers.helper import filter_none
from ..providers.asyncio import to_async_iterator
from ..providers.response import Reasoning
from ..providers.types import ProviderType
from ..cookies import get_cookies_dir
from .web_search import do_search, get_search_message
@@ -147,4 +149,26 @@ def iter_run_tools(
if has_bucket and isinstance(messages[-1]["content"], str):
messages[-1]["content"] += BUCKET_INSTRUCTIONS
return iter_callback(model=model, messages=messages, provider=provider, **kwargs)
is_thinking = 0
for chunk in iter_callback(model=model, messages=messages, provider=provider, **kwargs):
if not isinstance(chunk, str):
yield chunk
continue
if "<think>" in chunk:
chunk = chunk.split("<think>", 1)
yield chunk[0]
yield Reasoning(is_thinking="<think>")
yield Reasoning(chunk[1])
yield Reasoning(None, "Is thinking...")
is_thinking = time.time()
if "</think>" in chunk:
chunk = chunk.split("</think>", 1)
yield Reasoning(chunk[0])
yield Reasoning(is_thinking="</think>")
yield Reasoning(None, f"Finished in {round(time.time()-is_thinking, 2)} seconds")
yield chunk[1]
is_thinking = 0
elif is_thinking:
yield Reasoning(chunk)
else:
yield chunk

View File

@@ -192,6 +192,8 @@ async def search(query: str, max_results: int = 5, max_words: int = 2500, backen
return SearchResults(formatted_results, used_words)
async def do_search(prompt: str, query: str = None, instructions: str = DEFAULT_INSTRUCTIONS, **kwargs) -> str:
if instructions and instructions in prompt:
return prompt # We have already added search results
if query is None:
query = spacy_get_keywords(prompt)
json_bytes = json.dumps({"query": query, **kwargs}, sort_keys=True).encode(errors="ignore")