Support reasoning tokens by default

* Add new default HuggingFace provider
* Add format_image_prompt and get_last_user_message helpers
* Add stop_browser callable to get_nodriver function (see the sketch below)
* Fix content type response in images route
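
The most mechanical change in this diff is the new return contract of get_nodriver: every call site now unpacks a (browser, stop_browser) pair and calls stop_browser() in its finally block instead of browser.stop(). A minimal sketch of the call-site pattern the hunks below all follow (the helper itself is not part of this diff, so its exact signature is assumed from these call sites):

    # Hedged sketch: call-site pattern for the updated get_nodriver helper.
    browser, stop_browser = await get_nodriver(proxy=proxy, user_data_dir="copilot")
    try:
        page = await browser.get(url)
        # ... interact with the page ...
        await page.close()
    finally:
        stop_browser()  # replaces browser.stop(), so the helper owns cleanup
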
@@ -13,7 +13,7 @@ from ..requests.raise_for_status import raise_for_status
 from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
 from ..image import ImageResponse, to_data_uri
 from ..cookies import get_cookies_dir
-from .helper import format_prompt
+from .helper import format_prompt, format_image_prompt
 from ..providers.response import FinishReason, JsonConversation, Reasoning
 
 class Conversation(JsonConversation):
@@ -216,9 +216,8 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
 
         async with ClientSession(headers=headers) as session:
             if model == "ImageGeneration2":
-                prompt = messages[-1]["content"]
                 data = {
-                    "query": prompt,
+                    "query": format_image_prompt(messages, prompt),
                     "agentMode": True
                 }
                 headers['content-type'] = 'text/plain;charset=UTF-8'
@@ -307,8 +306,7 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
                     image_url_match = re.search(r'!\[.*?\]\((.*?)\)', text_to_yield)
                     if image_url_match:
                         image_url = image_url_match.group(1)
-                        prompt = messages[-1]["content"]
-                        yield ImageResponse(images=[image_url], alt=prompt)
+                        yield ImageResponse(image_url, format_image_prompt(messages, prompt))
                     else:
                         if "<think>" in text_to_yield and "</think>" in chunk_text :
                             chunk_text = text_to_yield.split('<think>', 1)
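
Across the image providers in this commit, the inline expression messages[-1]["content"] if prompt is None else prompt is replaced by the new format_image_prompt helper. Its body is not shown in this diff; judging from the call sites it replaces (some pass an explicit prompt, one calls it with messages only), a plausible minimal implementation would be:

    # Hypothetical sketch of the helper added to g4f/Provider/helper.py (not shown in this diff).
    def format_image_prompt(messages, prompt: str = None) -> str:
        # Fall back to the latest user message when no explicit prompt is given.
        if prompt is None:
            return get_last_user_message(messages)
        return prompt
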
@@ -28,6 +28,7 @@ from ..providers.response import BaseConversation, JsonConversation, RequestLogin
 from ..providers.asyncio import get_running_loop
 from ..requests import get_nodriver
 from ..image import ImageResponse, to_bytes, is_accepted_format
+from .helper import get_last_user_message
 from .. import debug
 
 class Conversation(JsonConversation):
@@ -139,7 +140,7 @@ class Copilot(AbstractProvider, ProviderModelMixin):
         else:
             conversation_id = conversation.conversation_id
             if prompt is None:
-                prompt = messages[-1]["content"]
+                prompt = get_last_user_message(messages)
             debug.log(f"Copilot: Use conversation: {conversation_id}")
 
         uploaded_images = []
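
get_last_user_message likewise replaces direct messages[-1]["content"] lookups when a conversation is resumed. The point of the helper is robustness: the last list entry is not guaranteed to be a user turn. It is not defined in this diff; a hedged sketch of what it presumably does:

    # Hypothetical sketch; the real helper lives in g4f/Provider/helper.py and is not shown here.
    def get_last_user_message(messages) -> str:
        # Walk backwards to the most recent message with role "user".
        for message in reversed(messages):
            if message["role"] == "user":
                return message["content"]
        return messages[-1]["content"]  # assumed fallback
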
@@ -206,7 +207,7 @@ class Copilot(AbstractProvider, ProviderModelMixin):
             yield Parameters(**{"cookies": {c.name: c.value for c in session.cookies.jar}})
 
 async def get_access_token_and_cookies(url: str, proxy: str = None, target: str = "ChatAI",):
-    browser = await get_nodriver(proxy=proxy, user_data_dir="copilot")
+    browser, stop_browser = await get_nodriver(proxy=proxy, user_data_dir="copilot")
     try:
         page = await browser.get(url)
         access_token = None
@@ -233,7 +234,7 @@ async def get_access_token_and_cookies(url: str, proxy: str = None, target: str
         await page.close()
         return access_token, cookies
     finally:
-        browser.stop()
+        stop_browser()
 
 def readHAR(url: str):
     api_key = None
@@ -9,7 +9,6 @@ class OIVSCode(OpenaiTemplate):
     working = True
     needs_auth = False
 
-    default_model = "gpt-4o-mini-2024-07-18"
+    default_model = "gpt-4o-mini"
     default_vision_model = default_model
     vision_models = [default_model, "gpt-4o-mini"]
-    model_aliases = {"gpt-4o-mini": "gpt-4o-mini-2024-07-18"}
@@ -5,7 +5,7 @@ import json
 
 from ..typing import AsyncResult, Messages
 from ..requests import StreamSession, raise_for_status
-from ..providers.response import Reasoning, FinishReason
+from ..providers.response import FinishReason
 from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
 
 API_URL = "https://www.perplexity.ai/socket.io/"
@@ -87,22 +87,7 @@ class PerplexityLabs(AsyncGeneratorProvider, ProviderModelMixin):
                     continue
                 try:
                     data = json.loads(message[2:])[1]
-                    new_content = data["output"][last_message:]
-
-                    if "<think>" in new_content:
-                        yield Reasoning(None, "thinking")
-                        is_thinking = True
-                    if "</think>" in new_content:
-                        new_content = new_content.split("</think>", 1)
-                        yield Reasoning(f"{new_content[0]}</think>")
-                        yield Reasoning(None, "finished")
-                        yield new_content[1]
-                        is_thinking = False
-                    elif is_thinking:
-                        yield Reasoning(new_content)
-                    else:
-                        yield new_content
-
+                    yield data["output"][last_message:]
                     last_message = len(data["output"])
                     if data["final"]:
                         yield FinishReason("stop")
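
This deletion is what the commit title means by supporting reasoning tokens by default: PerplexityLabs no longer parses <think> markup into Reasoning chunks itself; it yields the raw output and leaves reasoning extraction to shared handling. A hedged sketch of the kind of splitter such central handling needs (illustrative only, not code from this commit):

    # Illustrative splitter for <think> spans in a streamed buffer (not part of this diff).
    def split_reasoning(buffer: str):
        if "<think>" in buffer and "</think>" in buffer:
            before, rest = buffer.split("<think>", 1)
            thought, after = rest.split("</think>", 1)
            return before, thought, after
        return buffer, None, None
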
@@ -7,7 +7,7 @@ from urllib.parse import quote_plus
 from typing import Optional
 from aiohttp import ClientSession
 
-from .helper import filter_none
+from .helper import filter_none, format_image_prompt
 from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
 from ..typing import AsyncResult, Messages, ImagesType
 from ..image import to_data_uri
@@ -127,7 +127,7 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
         if model in cls.image_models:
             yield await cls._generate_image(
                 model=model,
-                prompt=messages[-1]["content"] if prompt is None else prompt,
+                prompt=format_image_prompt(messages, prompt),
                 proxy=proxy,
                 width=width,
                 height=height,
@@ -77,7 +77,7 @@ class You(AsyncGeneratorProvider, ProviderModelMixin):
         except MissingRequirementsError:
             pass
         if not cookies or "afUserId" not in cookies:
-            browser = await get_nodriver(proxy=proxy)
+            browser, stop_browser = await get_nodriver(proxy=proxy)
             try:
                 page = await browser.get(cls.url)
                 await page.wait_for('[data-testid="user-profile-button"]', timeout=900)
@@ -86,7 +86,7 @@ class You(AsyncGeneratorProvider, ProviderModelMixin):
                     cookies[c.name] = c.value
                 await page.close()
             finally:
-                browser.stop()
+                stop_browser()
         async with StreamSession(
             proxy=proxy,
             impersonate="chrome",
@@ -9,6 +9,7 @@ from .deprecated import *
 from .needs_auth import *
 from .not_working import *
 from .local import *
+from .hf import HuggingFace, HuggingChat, HuggingFaceAPI, HuggingFaceInference
 from .hf_space import HuggingSpace
 from .mini_max import HailuoAI, MiniMax
 from .template import OpenaiTemplate, BackendApi
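
With the re-export above, the HuggingFace classes keep their old import path while moving into the new hf package. After this change both of these imports resolve to the same classes:

    from g4f.Provider import HuggingFace, HuggingChat        # re-exported from .hf
    from g4f.Provider.hf import HuggingFaceAPI, HuggingFaceInference
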
@@ -14,7 +14,7 @@ except ImportError:
     has_curl_cffi = False
 
 from ..base_provider import ProviderModelMixin, AsyncAuthedProvider, AuthResult
-from ..helper import format_prompt
+from ..helper import format_prompt, format_image_prompt, get_last_user_message
 from ...typing import AsyncResult, Messages, Cookies, ImagesType
 from ...errors import MissingRequirementsError, MissingAuthError, ResponseError
 from ...image import to_bytes
@@ -22,6 +22,7 @@ from ...requests import get_args_from_nodriver, DEFAULT_HEADERS
 from ...requests.raise_for_status import raise_for_status
 from ...providers.response import JsonConversation, ImageResponse, Sources, TitleGeneration, Reasoning, RequestLogin
 from ...cookies import get_cookies
+from .models import default_model, fallback_models, image_models, model_aliases
 from ... import debug
 
 class Conversation(JsonConversation):
@@ -30,52 +31,14 @@ class Conversation(JsonConversation):
 
 class HuggingChat(AsyncAuthedProvider, ProviderModelMixin):
     url = "https://huggingface.co/chat"
 
     working = True
     use_nodriver = True
     supports_stream = True
     needs_auth = True
-
-    default_model = "Qwen/Qwen2.5-72B-Instruct"
-    default_image_model = "black-forest-labs/FLUX.1-dev"
-    image_models = [
-        default_image_model,
-        "black-forest-labs/FLUX.1-schnell",
-    ]
-    fallback_models = [
-        default_model,
-        'meta-llama/Llama-3.3-70B-Instruct',
-        'CohereForAI/c4ai-command-r-plus-08-2024',
-        'deepseek-ai/DeepSeek-R1-Distill-Qwen-32B',
-        'Qwen/QwQ-32B-Preview',
-        'nvidia/Llama-3.1-Nemotron-70B-Instruct-HF',
-        'Qwen/Qwen2.5-Coder-32B-Instruct',
-        'meta-llama/Llama-3.2-11B-Vision-Instruct',
-        'mistralai/Mistral-Nemo-Instruct-2407',
-        'microsoft/Phi-3.5-mini-instruct',
-    ] + image_models
-    model_aliases = {
-        ### Chat ###
-        "qwen-2.5-72b": "Qwen/Qwen2.5-Coder-32B-Instruct",
-        "llama-3.3-70b": "meta-llama/Llama-3.3-70B-Instruct",
-        "command-r-plus": "CohereForAI/c4ai-command-r-plus-08-2024",
-        "deepseek-r1": "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
-        "qwq-32b": "Qwen/QwQ-32B-Preview",
-        "nemotron-70b": "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF",
-        "qwen-2.5-coder-32b": "Qwen/Qwen2.5-Coder-32B-Instruct",
-        "llama-3.2-11b": "meta-llama/Llama-3.2-11B-Vision-Instruct",
-        "mistral-nemo": "mistralai/Mistral-Nemo-Instruct-2407",
-        "phi-3.5-mini": "microsoft/Phi-3.5-mini-instruct",
-        ### Image ###
-        "flux-dev": "black-forest-labs/FLUX.1-dev",
-        "flux-schnell": "black-forest-labs/FLUX.1-schnell",
-        ### Used in other providers ###
-        "qwen-2-vl-7b": "Qwen/Qwen2-VL-7B-Instruct",
-        "gemma-2-27b": "google/gemma-2-27b-it",
-        "qwen-2-72b": "Qwen/Qwen2-72B-Instruct",
-        "qvq-72b": "Qwen/QVQ-72B-Preview",
-        "sd-3.5": "stabilityai/stable-diffusion-3.5-large",
-    }
+    default_model = default_model
+    model_aliases = model_aliases
+    image_models = image_models
 
     @classmethod
     def get_models(cls):
@@ -94,7 +57,7 @@ class HuggingChat(AsyncAuthedProvider, ProviderModelMixin):
                 cls.vision_models = [model["id"] for model in models if model["multimodal"]]
             except Exception as e:
                 debug.log(f"HuggingChat: Error reading models: {type(e).__name__}: {e}")
-                cls.models = [*cls.fallback_models]
+                cls.models = [*fallback_models]
         return cls.models
 
     @classmethod
@@ -108,9 +71,7 @@ class HuggingChat(AsyncAuthedProvider, ProviderModelMixin):
                 headers=DEFAULT_HEADERS
             )
             return
-        login_url = os.environ.get("G4F_LOGIN_URL")
-        if login_url:
-            yield RequestLogin(cls.__name__, login_url)
+        yield RequestLogin(cls.__name__, os.environ.get("G4F_LOGIN_URL") or "")
         yield AuthResult(
             **await get_args_from_nodriver(
                 cls.url,
@@ -143,6 +104,7 @@ class HuggingChat(AsyncAuthedProvider, ProviderModelMixin):
 
             if model not in conversation.models:
                 conversationId = cls.create_conversation(session, model)
+                debug.log(f"Conversation created: {json.dumps(conversationId[8:] + '...')}")
                 messageId = cls.fetch_message_id(session, conversationId)
                 conversation.models[model] = {"conversationId": conversationId, "messageId": messageId}
                 if return_conversation:
@@ -151,9 +113,7 @@ class HuggingChat(AsyncAuthedProvider, ProviderModelMixin):
             else:
                 conversationId = conversation.models[model]["conversationId"]
                 conversation.models[model]["messageId"] = cls.fetch_message_id(session, conversationId)
-            inputs = messages[-1]["content"]
+            inputs = get_last_user_message(messages)
 
-            debug.log(f"Use: {json.dumps(conversation.models[model])}")
-
             settings = {
                 "inputs": inputs,
@@ -204,8 +164,7 @@ class HuggingChat(AsyncAuthedProvider, ProviderModelMixin):
                         break
                     elif line["type"] == "file":
                         url = f"https://huggingface.co/chat/conversation/{conversationId}/output/{line['sha']}"
-                        prompt = messages[-1]["content"] if prompt is None else prompt
-                        yield ImageResponse(url, alt=prompt, options={"cookies": auth_result.cookies})
+                        yield ImageResponse(url, format_image_prompt(messages, prompt), options={"cookies": auth_result.cookies})
                     elif line["type"] == "webSearch" and "sources" in line:
                         sources = Sources(line["sources"])
                     elif line["type"] == "title":
@@ -226,6 +185,8 @@ class HuggingChat(AsyncAuthedProvider, ProviderModelMixin):
         response = session.post('https://huggingface.co/chat/conversation', json=json_data)
         if response.status_code == 401:
             raise MissingAuthError(response.text)
+        if response.status_code == 400:
+            raise ResponseError(f"{response.text}: Model: {model}")
         raise_for_status(response)
         return response.json().get('conversationId')
 
@@ -1,8 +1,9 @@
 from __future__ import annotations
 
-from ..template import OpenaiTemplate
-from .HuggingChat import HuggingChat
+from ..template.OpenaiTemplate import OpenaiTemplate
+from .models import model_aliases
 from ...providers.types import Messages
+from .HuggingChat import HuggingChat
 from ... import debug
 
 class HuggingFaceAPI(OpenaiTemplate):
@@ -16,13 +17,16 @@ class HuggingFaceAPI(OpenaiTemplate):
     default_model = "meta-llama/Llama-3.2-11B-Vision-Instruct"
     default_vision_model = default_model
     vision_models = [default_vision_model, "Qwen/Qwen2-VL-7B-Instruct"]
-    model_aliases = HuggingChat.model_aliases
+    model_aliases = model_aliases
 
     @classmethod
     def get_models(cls, **kwargs):
         if not cls.models:
             HuggingChat.get_models()
-            cls.models = list(set(HuggingChat.text_models + cls.vision_models))
+            cls.models = HuggingChat.text_models.copy()
+            for model in cls.vision_models:
+                if model not in cls.models:
+                    cls.models.append(model)
         return cls.models
 
     @classmethod
@@ -11,37 +11,32 @@ from ...errors import ModelNotFoundError, ModelNotSupportedError, ResponseError
 from ...requests import StreamSession, raise_for_status
 from ...providers.response import FinishReason
 from ...image import ImageResponse
+from ..helper import format_image_prompt
+from .models import default_model, default_image_model, model_aliases, fallback_models
 from ... import debug
 
-from .HuggingChat import HuggingChat
-
-class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
+class HuggingFaceInference(AsyncGeneratorProvider, ProviderModelMixin):
     url = "https://huggingface.co"
-    login_url = "https://huggingface.co/settings/tokens"
     working = True
-    supports_message_history = True
-    default_model = HuggingChat.default_model
-    default_image_model = HuggingChat.default_image_model
-    model_aliases = HuggingChat.model_aliases
-    extra_models = [
-        "meta-llama/Llama-3.2-11B-Vision-Instruct",
-        "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF",
-        "NousResearch/Hermes-3-Llama-3.1-8B",
-    ]
+    default_model = default_model
+    default_image_model = default_image_model
+    model_aliases = model_aliases
 
     @classmethod
     def get_models(cls) -> list[str]:
         if not cls.models:
+            models = fallback_models.copy()
             url = "https://huggingface.co/api/models?inference=warm&pipeline_tag=text-generation"
-            models = [model["id"] for model in requests.get(url).json()]
-            models.extend(cls.extra_models)
-            models.sort()
+            extra_models = [model["id"] for model in requests.get(url).json()]
+            extra_models.sort()
+            models.extend([model for model in extra_models if model not in models])
             if not cls.image_models:
                 url = "https://huggingface.co/api/models?pipeline_tag=text-to-image"
                 cls.image_models = [model["id"] for model in requests.get(url).json() if model["trendingScore"] >= 20]
                 cls.image_models.sort()
-            models.extend(cls.image_models)
-            cls.models = list(set(models))
+            models.extend([model for model in cls.image_models if model not in models])
+            cls.models = models
         return cls.models
 
     @classmethod
@@ -85,7 +80,7 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
         payload = None
         if cls.get_models() and model in cls.image_models:
             stream = False
-            prompt = messages[-1]["content"] if prompt is None else prompt
+            prompt = format_image_prompt(messages, prompt)
             payload = {"inputs": prompt, "parameters": {"seed": random.randint(0, 2**32), **extra_data}}
         else:
             params = {
g4f/Provider/hf/__init__.py (new file, 62 lines)
@@ -0,0 +1,62 @@
+from __future__ import annotations
+
+import random
+
+from ...typing import AsyncResult, Messages
+from ...providers.response import ImageResponse
+from ...errors import ModelNotSupportedError
+from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
+from .HuggingChat import HuggingChat
+from .HuggingFaceAPI import HuggingFaceAPI
+from .HuggingFaceInference import HuggingFaceInference
+from ... import debug
+
+class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
+    url = "https://huggingface.co"
+    login_url = "https://huggingface.co/settings/tokens"
+    working = True
+    supports_message_history = True
+
+    @classmethod
+    def get_models(cls) -> list[str]:
+        if not cls.models:
+            cls.models = HuggingFaceInference.get_models()
+            cls.image_models = HuggingFaceInference.image_models
+        return cls.models
+
+    @classmethod
+    async def create_async_generator(
+        cls,
+        model: str,
+        messages: Messages,
+        **kwargs
+    ) -> AsyncResult:
+        if "api_key" not in kwargs and "images" not in kwargs and random.random() >= 0.5:
+            try:
+                is_started = False
+                async for chunk in HuggingFaceInference.create_async_generator(model, messages, **kwargs):
+                    if isinstance(chunk, (str, ImageResponse)):
+                        is_started = True
+                    yield chunk
+                if is_started:
+                    return
+            except Exception as e:
+                if is_started:
+                    raise e
+                debug.log(f"Inference failed: {e.__class__.__name__}: {e}")
+        if not cls.image_models:
+            cls.get_models()
+        if model in cls.image_models:
+            if "api_key" not in kwargs:
+                async for chunk in HuggingChat.create_async_generator(model, messages, **kwargs):
+                    yield chunk
+            else:
+                async for chunk in HuggingFaceInference.create_async_generator(model, messages, **kwargs):
+                    yield chunk
+            return
+        try:
+            async for chunk in HuggingFaceAPI.create_async_generator(model, messages, **kwargs):
+                yield chunk
+        except ModelNotSupportedError:
+            async for chunk in HuggingFaceInference.create_async_generator(model, messages, **kwargs):
+                yield chunk
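
The new HuggingFace class is a façade: it tries the free inference backend roughly half the time (when no api_key or images are passed), routes image models to HuggingChat, and otherwise falls back to HuggingFaceAPI. The coin flip (random.random() >= 0.5) appears to be a deliberate load-spreading choice between the inference endpoint and the API backend. A hedged usage sketch, assuming the exports wired up in g4f/Provider/__init__.py above:

    # Hypothetical driver for the new facade provider.
    import asyncio
    from g4f.Provider import HuggingFace

    async def main():
        async for chunk in HuggingFace.create_async_generator(
            "Qwen/Qwen2.5-72B-Instruct",
            [{"role": "user", "content": "Hello"}],
        ):
            print(chunk, end="")

    asyncio.run(main())
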
g4f/Provider/hf/models.py (new file, 46 lines)
@@ -0,0 +1,46 @@
+default_model = "Qwen/Qwen2.5-72B-Instruct"
+default_image_model = "black-forest-labs/FLUX.1-dev"
+image_models = [
+    default_image_model,
+    "black-forest-labs/FLUX.1-schnell",
+]
+fallback_models = [
+    default_model,
+    'meta-llama/Llama-3.3-70B-Instruct',
+    'CohereForAI/c4ai-command-r-plus-08-2024',
+    'deepseek-ai/DeepSeek-R1-Distill-Qwen-32B',
+    'Qwen/QwQ-32B-Preview',
+    'nvidia/Llama-3.1-Nemotron-70B-Instruct-HF',
+    'Qwen/Qwen2.5-Coder-32B-Instruct',
+    'meta-llama/Llama-3.2-11B-Vision-Instruct',
+    'mistralai/Mistral-Nemo-Instruct-2407',
+    'microsoft/Phi-3.5-mini-instruct',
+] + image_models
+model_aliases = {
+    ### Chat ###
+    "qwen-2.5-72b": "Qwen/Qwen2.5-Coder-32B-Instruct",
+    "llama-3.3-70b": "meta-llama/Llama-3.3-70B-Instruct",
+    "command-r-plus": "CohereForAI/c4ai-command-r-plus-08-2024",
+    "deepseek-r1": "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
+    "qwq-32b": "Qwen/QwQ-32B-Preview",
+    "nemotron-70b": "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF",
+    "qwen-2.5-coder-32b": "Qwen/Qwen2.5-Coder-32B-Instruct",
+    "llama-3.2-11b": "meta-llama/Llama-3.2-11B-Vision-Instruct",
+    "mistral-nemo": "mistralai/Mistral-Nemo-Instruct-2407",
+    "phi-3.5-mini": "microsoft/Phi-3.5-mini-instruct",
+    ### Image ###
+    "flux": "black-forest-labs/FLUX.1-dev",
+    "flux-dev": "black-forest-labs/FLUX.1-dev",
+    "flux-schnell": "black-forest-labs/FLUX.1-schnell",
+    ### Used in other providers ###
+    "qwen-2-vl-7b": "Qwen/Qwen2-VL-7B-Instruct",
+    "gemma-2-27b": "google/gemma-2-27b-it",
+    "qwen-2-72b": "Qwen/Qwen2-72B-Instruct",
+    "qvq-72b": "Qwen/QVQ-72B-Preview",
+    "sd-3.5": "stabilityai/stable-diffusion-3.5-large",
+}
+extra_models = [
+    "meta-llama/Llama-3.2-11B-Vision-Instruct",
+    "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF",
+    "NousResearch/Hermes-3-Llama-3.1-8B",
+]
@@ -7,6 +7,7 @@ from ...typing import AsyncResult, Messages
 from ...image import ImageResponse, ImagePreview
 from ...errors import ResponseError
 from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
+from ..helper import format_image_prompt
 
 class BlackForestLabsFlux1Dev(AsyncGeneratorProvider, ProviderModelMixin):
     url = "https://black-forest-labs-flux-1-dev.hf.space"
@@ -44,7 +45,7 @@ class BlackForestLabsFlux1Dev(AsyncGeneratorProvider, ProviderModelMixin):
         if api_key is not None:
             headers["Authorization"] = f"Bearer {api_key}"
         async with ClientSession(headers=headers) as session:
-            prompt = messages[-1]["content"] if prompt is None else prompt
+            prompt = format_image_prompt(messages, prompt)
             data = {
                 "data": [prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps]
             }
@@ -8,6 +8,7 @@ from ...image import ImageResponse
 from ...errors import ResponseError
 from ...requests.raise_for_status import raise_for_status
 from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
+from ..helper import format_image_prompt
 
 class BlackForestLabsFlux1Schnell(AsyncGeneratorProvider, ProviderModelMixin):
     url = "https://black-forest-labs-flux-1-schnell.hf.space"
@@ -42,7 +43,7 @@ class BlackForestLabsFlux1Schnell(AsyncGeneratorProvider, ProviderModelMixin):
         height = max(32, height - (height % 8))
 
         if prompt is None:
-            prompt = messages[-1]["content"]
+            prompt = format_image_prompt(messages)
 
         payload = {
             "data": [
@@ -6,7 +6,7 @@ from aiohttp import ClientSession, FormData
 from ...typing import AsyncResult, Messages
 from ...requests import raise_for_status
 from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
-from ..helper import format_prompt
+from ..helper import format_prompt, get_last_user_message
 from ...providers.response import JsonConversation, TitleGeneration
 
 class CohereForAI(AsyncGeneratorProvider, ProviderModelMixin):
@@ -58,7 +58,7 @@ class CohereForAI(AsyncGeneratorProvider, ProviderModelMixin):
         ) as session:
             system_prompt = "\n".join([message["content"] for message in messages if message["role"] == "system"])
             messages = [message for message in messages if message["role"] != "system"]
-            inputs = format_prompt(messages) if conversation is None else messages[-1]["content"]
+            inputs = format_prompt(messages) if conversation is None else get_last_user_message(messages)
             if conversation is None or conversation.model != model or conversation.preprompt != system_prompt:
                 data = {"model": model, "preprompt": system_prompt}
                 async with session.post(cls.conversation_url, json=data, proxy=proxy) as response:
@@ -78,7 +78,6 @@ class CohereForAI(AsyncGeneratorProvider, ProviderModelMixin):
             data = node["data"]
             message_id = data[data[data[data[0]["messages"]][-1]]["id"]]
             data = FormData()
-            inputs = messages[-1]["content"]
             data.add_field(
                 "data",
                 json.dumps({"inputs": inputs, "id": message_id, "is_retry": False, "is_continue": False, "web_search": False, "tools": []}),
@@ -8,8 +8,8 @@ import urllib.parse
 
 from ...typing import AsyncResult, Messages, Cookies
 from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
-from ..helper import format_prompt
-from ...providers.response import JsonConversation, ImageResponse
+from ..helper import format_prompt, format_image_prompt
+from ...providers.response import JsonConversation, ImageResponse, Notification
 from ...requests.aiohttp import StreamSession, StreamResponse
 from ...requests.raise_for_status import raise_for_status
 from ...cookies import get_cookies
@@ -38,7 +38,7 @@ class Janus_Pro_7B(AsyncGeneratorProvider, ProviderModelMixin):
             "headers": {
                 "content-type": "application/json",
                 "x-zerogpu-token": conversation.zerogpu_token,
-                "x-zerogpu-uuid": conversation.uuid,
+                "x-zerogpu-uuid": conversation.zerogpu_uuid,
                 "referer": cls.referer,
             },
             "json": {"data":[None,prompt,42,0.95,0.1],"event_data":None,"fn_index":2,"trigger_id":10,"session_hash":conversation.session_hash},
@@ -48,7 +48,7 @@ class Janus_Pro_7B(AsyncGeneratorProvider, ProviderModelMixin):
             "headers": {
                 "content-type": "application/json",
                 "x-zerogpu-token": conversation.zerogpu_token,
-                "x-zerogpu-uuid": conversation.uuid,
+                "x-zerogpu-uuid": conversation.zerogpu_uuid,
                 "referer": cls.referer,
             },
             "json": {"data":[prompt,1234,5,1],"event_data":None,"fn_index":3,"trigger_id":20,"session_hash":conversation.session_hash},
@@ -82,33 +82,14 @@ class Janus_Pro_7B(AsyncGeneratorProvider, ProviderModelMixin):
             method = "image"
 
         prompt = format_prompt(messages) if prompt is None and conversation is None else prompt
-        prompt = messages[-1]["content"] if prompt is None else prompt
+        prompt = format_image_prompt(messages, prompt)
 
         session_hash = generate_session_hash() if conversation is None else getattr(conversation, "session_hash")
         async with StreamSession(proxy=proxy, impersonate="chrome") as session:
             session_hash = generate_session_hash() if conversation is None else getattr(conversation, "session_hash")
-            user_uuid = None if conversation is None else getattr(conversation, "user_uuid", None)
-            zerogpu_token = "[object Object]"
-
-            cookies = get_cookies("huggingface.co", raise_requirements_error=False) if cookies is None else cookies
-            if cookies:
-                # Get current UTC time + 10 minutes
-                dt = (datetime.now(timezone.utc) + timedelta(minutes=10)).isoformat(timespec='milliseconds')
-                encoded_dt = urllib.parse.quote(dt)
-                async with session.get(f"https://huggingface.co/api/spaces/deepseek-ai/Janus-Pro-7B/jwt?expiration={encoded_dt}&include_pro_status=true", cookies=cookies) as response:
-                    zerogpu_token = (await response.json())
-                    zerogpu_token = zerogpu_token["token"]
-            if user_uuid is None:
-                async with session.get(cls.url, cookies=cookies) as response:
-                    match = re.search(r'"token":"([^&]+?)"', await response.text())
-                    if match:
-                        zerogpu_token = match.group(1)
-                    match = re.search(r'"sessionUuid":"([^&]+?)"', await response.text())
-                    if match:
-                        user_uuid = match.group(1)
+            zerogpu_uuid, zerogpu_token = await get_zerogpu_token(session, conversation, cookies)
 
             if conversation is None or not hasattr(conversation, "session_hash"):
-                conversation = JsonConversation(session_hash=session_hash, zerogpu_token=zerogpu_token, uuid=user_uuid)
+                conversation = JsonConversation(session_hash=session_hash, zerogpu_token=zerogpu_token, zerogpu_uuid=zerogpu_uuid)
             conversation.zerogpu_token = zerogpu_token
             if return_conversation:
                 yield conversation
@@ -124,7 +105,7 @@ class Janus_Pro_7B(AsyncGeneratorProvider, ProviderModelMixin):
                     try:
                         json_data = json.loads(decoded_line[6:])
                         if json_data.get('msg') == 'log':
-                            debug.log(json_data["log"])
+                            yield Notification(json_data["log"])
 
                         if json_data.get('msg') == 'process_generating':
                             if 'output' in json_data and 'data' in json_data['output']:
@@ -141,4 +122,27 @@ class Janus_Pro_7B(AsyncGeneratorProvider, ProviderModelMixin):
                             break
 
                     except json.JSONDecodeError:
                         debug.log("Could not parse JSON:", decoded_line)
+
+async def get_zerogpu_token(session: StreamSession, conversation: JsonConversation, cookies: Cookies = None):
+    zerogpu_uuid = None if conversation is None else getattr(conversation, "zerogpu_uuid", None)
+    zerogpu_token = "[object Object]"
+
+    cookies = get_cookies("huggingface.co", raise_requirements_error=False) if cookies is None else cookies
+    if zerogpu_uuid is None:
+        async with session.get(Janus_Pro_7B.url, cookies=cookies) as response:
+            match = re.search(r'"token":"([^&]+?)"', await response.text())
+            if match:
+                zerogpu_token = match.group(1)
+            match = re.search(r'"sessionUuid":"([^&]+?)"', await response.text())
+            if match:
+                zerogpu_uuid = match.group(1)
+    if cookies:
+        # Get current UTC time + 10 minutes
+        dt = (datetime.now(timezone.utc) + timedelta(minutes=10)).isoformat(timespec='milliseconds')
+        encoded_dt = urllib.parse.quote(dt)
+        async with session.get(f"https://huggingface.co/api/spaces/deepseek-ai/Janus-Pro-7B/jwt?expiration={encoded_dt}&include_pro_status=true", cookies=cookies) as response:
+            zerogpu_token = (await response.json())
+            zerogpu_token = zerogpu_token["token"]
+
+    return zerogpu_uuid, zerogpu_token
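
The extracted get_zerogpu_token helper also inverts the old lookup order: it now scrapes the anonymous token and session UUID from the space page first, and only then upgrades to a JWT when Hugging Face cookies are available, caching the UUID on the conversation as zerogpu_uuid. Call sites reduce to a single line:

    # Per the hunk above: the provider body now just awaits the helper.
    zerogpu_uuid, zerogpu_token = await get_zerogpu_token(session, conversation, cookies)
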
@@ -8,6 +8,7 @@ from ...typing import AsyncResult, Messages
 from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
 from ..helper import format_prompt
 from ...providers.response import JsonConversation, Reasoning
+from ..helper import get_last_user_message
 from ... import debug
 
 class Qwen_Qwen_2_5M_Demo(AsyncGeneratorProvider, ProviderModelMixin):
@@ -41,7 +42,7 @@ class Qwen_Qwen_2_5M_Demo(AsyncGeneratorProvider, ProviderModelMixin):
         if return_conversation:
             yield JsonConversation(session_hash=session_hash)
 
-        prompt = format_prompt(messages) if conversation is None else messages[-1]["content"]
+        prompt = format_prompt(messages) if conversation is None else get_last_user_message(messages)
 
         headers = {
             'accept': '*/*',
@@ -7,6 +7,7 @@ from ...typing import AsyncResult, Messages
 from ...image import ImageResponse, ImagePreview
 from ...errors import ResponseError
 from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
+from ..helper import format_image_prompt
 
 class StableDiffusion35Large(AsyncGeneratorProvider, ProviderModelMixin):
     url = "https://stabilityai-stable-diffusion-3-5-large.hf.space"
@@ -42,7 +43,7 @@ class StableDiffusion35Large(AsyncGeneratorProvider, ProviderModelMixin):
         if api_key is not None:
             headers["Authorization"] = f"Bearer {api_key}"
         async with ClientSession(headers=headers) as session:
-            prompt = messages[-1]["content"] if prompt is None else prompt
+            prompt = format_image_prompt(messages, prompt)
             data = {
                 "data": [prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps]
             }
@@ -7,6 +7,7 @@ from ...typing import AsyncResult, Messages
 from ...image import ImageResponse
 from ...errors import ResponseError
 from ...requests.raise_for_status import raise_for_status
+from ..helper import format_image_prompt
 from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
 
 class VoodoohopFlux1Schnell(AsyncGeneratorProvider, ProviderModelMixin):
@@ -37,10 +38,7 @@ class VoodoohopFlux1Schnell(AsyncGeneratorProvider, ProviderModelMixin):
     ) -> AsyncResult:
         width = max(32, width - (width % 8))
         height = max(32, height - (height % 8))
-
-        if prompt is None:
-            prompt = messages[-1]["content"]
-
+        prompt = format_image_prompt(messages, prompt)
         payload = {
             "data": [
                 prompt,
@@ -10,6 +10,7 @@ from ..base_provider import AsyncAuthedProvider, ProviderModelMixin, format_prompt
 from ..mini_max.crypt import CallbackResults, get_browser_callback, generate_yy_header, get_body_to_yy
 from ...requests import get_args_from_nodriver, raise_for_status
 from ...providers.response import AuthResult, JsonConversation, RequestLogin, TitleGeneration
+from ..helper import get_last_user_message
 from ... import debug
 
 class Conversation(JsonConversation):
@@ -62,7 +63,7 @@ class HailuoAI(AsyncAuthedProvider, ProviderModelMixin):
             conversation = None
         form_data = {
             "characterID": 1 if conversation is None else getattr(conversation, "characterID", 1),
-            "msgContent": format_prompt(messages) if conversation is None else messages[-1]["content"],
+            "msgContent": format_prompt(messages) if conversation is None else get_last_user_message(messages),
             "chatID": 0 if conversation is None else getattr(conversation, "chatID", 0),
             "searchMode": 0
         }
@@ -98,7 +98,7 @@ class Anthropic(OpenaiAPI):
                     "text": messages[-1]["content"]
                 }
             ]
-        system = "\n".join([message for message in messages if message.get("role") == "system"])
+        system = "\n".join([message["content"] for message in messages if message.get("role") == "system"])
         if system:
             messages = [message for message in messages if message.get("role") != "system"]
         else:
@@ -6,6 +6,7 @@ from ...errors import MissingAuthError
 from ...typing import AsyncResult, Messages, Cookies
 from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
 from ..bing.create_images import create_images, create_session
+from ..helper import format_image_prompt
 
 class BingCreateImages(AsyncGeneratorProvider, ProviderModelMixin):
     label = "Microsoft Designer in Bing"
@@ -35,7 +36,7 @@ class BingCreateImages(AsyncGeneratorProvider, ProviderModelMixin):
         **kwargs
     ) -> AsyncResult:
         session = BingCreateImages(cookies, proxy, api_key)
-        yield await session.generate(messages[-1]["content"] if prompt is None else prompt)
+        yield await session.generate(format_image_prompt(messages, prompt))
 
     async def generate(self, prompt: str) -> ImageResponse:
         """
@@ -5,6 +5,7 @@ from ...typing import AsyncResult, Messages
 from ...requests import StreamSession, raise_for_status
 from ...image import ImageResponse
 from ..template import OpenaiTemplate
+from ..helper import format_image_prompt
 
 class DeepInfra(OpenaiTemplate):
     url = "https://deepinfra.com"
@@ -55,7 +56,7 @@ class DeepInfra(OpenaiTemplate):
     ) -> AsyncResult:
         if model in cls.get_image_models():
             yield cls.create_async_image(
-                messages[-1]["content"] if prompt is None else prompt,
+                format_image_prompt(messages, prompt),
                 model,
                 **kwargs
            )
@@ -31,7 +31,7 @@ try:
         challenge = self._get_pow_challenge()
         pow_response = self.pow_solver.solve_challenge(challenge)
         headers = self._get_headers(pow_response)
 
         response = requests.request(
             method=method,
             url=url,
@@ -25,6 +25,7 @@ from ...requests.aiohttp import get_connector
 from ...requests import get_nodriver
 from ...errors import MissingAuthError
 from ...image import ImageResponse, to_bytes
+from ..helper import get_last_user_message
 from ... import debug
 
 REQUEST_HEADERS = {
@@ -78,7 +79,7 @@ class Gemini(AsyncGeneratorProvider, ProviderModelMixin):
             if debug.logging:
                 print("Skip nodriver login in Gemini provider")
             return
-        browser = await get_nodriver(proxy=proxy, user_data_dir="gemini")
+        browser, stop_browser = await get_nodriver(proxy=proxy, user_data_dir="gemini")
         try:
             login_url = os.environ.get("G4F_LOGIN_URL")
             if login_url:
@@ -91,7 +92,7 @@ class Gemini(AsyncGeneratorProvider, ProviderModelMixin):
             await page.close()
             cls._cookies = cookies
         finally:
-            browser.stop()
+            stop_browser()
 
     @classmethod
     async def create_async_generator(
@@ -107,7 +108,7 @@ class Gemini(AsyncGeneratorProvider, ProviderModelMixin):
         language: str = "en",
         **kwargs
     ) -> AsyncResult:
-        prompt = format_prompt(messages) if conversation is None else messages[-1]["content"]
+        prompt = format_prompt(messages) if conversation is None else get_last_user_message(messages)
         cls._cookies = cookies or cls._cookies or get_cookies(".google.com", False, True)
         base_connector = get_connector(connector, proxy)
 
@@ -7,7 +7,7 @@ from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, BaseConversation
 from ...typing import AsyncResult, Messages, Cookies
 from ...requests.raise_for_status import raise_for_status
 from ...requests.aiohttp import get_connector
-from ...providers.helper import format_prompt
+from ...providers.helper import format_prompt, get_last_user_message
 from ...cookies import get_cookies
 
 class Conversation(BaseConversation):
@@ -78,7 +78,7 @@ class GithubCopilot(AsyncGeneratorProvider, ProviderModelMixin):
             conversation_id = (await response.json()).get("thread_id")
             if return_conversation:
                 yield Conversation(conversation_id)
-            content = messages[-1]["content"]
+            content = get_last_user_message(messages)
         else:
             content = format_prompt(messages)
         json_data = {
@@ -14,7 +14,7 @@ from ...requests.aiohttp import get_connector
 from ...requests import get_nodriver
 from ..Copilot import get_headers, get_har_files
 from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
-from ..helper import get_random_hex
+from ..helper import get_random_hex, format_image_prompt
 from ... import debug
 
 class MicrosoftDesigner(AsyncGeneratorProvider, ProviderModelMixin):
@@ -39,7 +39,7 @@ class MicrosoftDesigner(AsyncGeneratorProvider, ProviderModelMixin):
         image_size = "1024x1024"
         if model != cls.default_image_model and model in cls.image_models:
             image_size = model
-        yield await cls.generate(messages[-1]["content"] if prompt is None else prompt, image_size, proxy)
+        yield await cls.generate(format_image_prompt(messages, prompt), image_size, proxy)
 
     @classmethod
     async def generate(cls, prompt: str, image_size: str, proxy: str = None) -> ImageResponse:
@@ -143,7 +143,7 @@ def readHAR(url: str) -> tuple[str, str]:
     return api_key, user_agent
 
 async def get_access_token_and_user_agent(url: str, proxy: str = None):
-    browser = await get_nodriver(proxy=proxy, user_data_dir="designer")
+    browser, stop_browser = await get_nodriver(proxy=proxy, user_data_dir="designer")
     try:
         page = await browser.get(url)
         user_agent = await page.evaluate("navigator.userAgent")
@@ -168,4 +168,4 @@ async def get_access_token_and_user_agent(url: str, proxy: str = None):
|
|||||||
await page.close()
|
await page.close()
|
||||||
return access_token, user_agent
|
return access_token, user_agent
|
||||||
finally:
|
finally:
|
||||||
browser.stop()
|
stop_browser()
|
||||||
@@ -98,7 +98,7 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
    default_model = "auto"
    default_image_model = "dall-e-3"
    image_models = [default_image_model]
-    text_models = [default_model, "gpt-4", "gpt-4o", "gpt-4o-mini", "gpt-4o-canmore", "o1", "o1-preview", "o1-mini"]
+    text_models = [default_model, "gpt-4", "gpt-4o", "gpt-4o-mini", "o1", "o1-preview", "o1-mini"]
    vision_models = text_models
    models = text_models + image_models
    synthesize_content_type = "audio/mpeg"
@@ -598,7 +598,7 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
 
    @classmethod
    async def nodriver_auth(cls, proxy: str = None):
-        browser = await get_nodriver(proxy=proxy)
+        browser, stop_browser = await get_nodriver(proxy=proxy)
        try:
            page = browser.main_tab
            def on_request(event: nodriver.cdp.network.RequestWillBeSent):
@@ -648,7 +648,7 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
            cls._create_request_args(RequestConfig.cookies, RequestConfig.headers, user_agent=user_agent)
            cls._set_api_key(cls._api_key)
        finally:
-            browser.stop()
+            stop_browser()
 
    @staticmethod
    def get_default_headers() -> Dict[str, str]:
@@ -13,9 +13,6 @@ from .GigaChat import GigaChat
 from .GithubCopilot import GithubCopilot
 from .GlhfChat import GlhfChat
 from .Groq import Groq
-from .HuggingChat import HuggingChat
-from .HuggingFace import HuggingFace
-from .HuggingFaceAPI import HuggingFaceAPI
 from .MetaAI import MetaAI
 from .MetaAIAccount import MetaAIAccount
 from .MicrosoftDesigner import MicrosoftDesigner
@@ -8,11 +8,11 @@ from urllib.parse import quote_plus
 from ...typing import Messages, AsyncResult
 from ...requests import StreamSession
 from ...providers.base_provider import AsyncGeneratorProvider, ProviderModelMixin
-from ...providers.response import ProviderInfo, JsonConversation, PreviewResponse, SynthesizeData, TitleGeneration, RequestLogin
-from ...providers.response import Parameters, FinishReason, Usage, Reasoning
+from ...providers.response import *
+from ...image import get_image_extension
 from ...errors import ModelNotSupportedError
 from ..needs_auth.OpenaiAccount import OpenaiAccount
-from ..needs_auth.HuggingChat import HuggingChat
+from ..hf.HuggingChat import HuggingChat
 from ... import debug
 
 class BackendApi(AsyncGeneratorProvider, ProviderModelMixin):
@@ -98,8 +98,7 @@ class BackendApi(AsyncGeneratorProvider, ProviderModelMixin):
                     yield PreviewResponse(data[data_type])
                 elif data_type == "content":
                     def on_image(match):
-                        extension = match.group(3).split(".")[-1].split("?")[0]
-                        extension = "" if not extension or len(extension) > 4 else f".{extension}"
+                        extension = get_image_extension(match.group(3))
                         filename = f"{int(time.time())}_{quote_plus(match.group(1)[:100], '')}{extension}"
                         download_url = f"/download/{filename}?url={cls.url}{match.group(3)}"
                         return f"[![{match.group(1)}]({download_url})](/images/{filename})"
@@ -119,6 +118,6 @@ class BackendApi(AsyncGeneratorProvider, ProviderModelMixin):
                 elif data_type == "finish":
                     yield FinishReason(data[data_type]["reason"])
                 elif data_type == "log":
-                    debug.log(data[data_type])
+                    yield DebugResponse.from_dict(data[data_type])
                 else:
-                    debug.log(f"Unknown data: ({data_type}) {data}")
+                    yield DebugResponse.from_dict(data)
@@ -4,11 +4,11 @@ import json
 import time
 import requests
 
-from ..helper import filter_none
+from ..helper import filter_none, format_image_prompt
 from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
 from ...typing import Union, Optional, AsyncResult, Messages, ImagesType
 from ...requests import StreamSession, raise_for_status
-from ...providers.response import FinishReason, ToolCalls, Usage, Reasoning, ImageResponse
+from ...providers.response import FinishReason, ToolCalls, Usage, ImageResponse
 from ...errors import MissingAuthError, ResponseError
 from ...image import to_data_uri
 from ... import debug
@@ -82,7 +82,7 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
         # Proxy for image generation feature
         if model and model in cls.image_models:
             data = {
-                "prompt": messages[-1]["content"] if prompt is None else prompt,
+                "prompt": format_image_prompt(messages, prompt),
                 "model": model,
             }
             async with session.post(f"{api_base.rstrip('/')}/images/generations", json=data, ssl=cls.ssl) as response:
@@ -154,17 +154,7 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
                             delta = delta.lstrip()
                         if delta:
                             first = False
-                            if is_thinking:
-                                if "</think>" in delta:
-                                    yield Reasoning(None, f"Finished in {round(time.time()-is_thinking, 2)} seconds")
-                                    is_thinking = 0
-                                else:
-                                    yield Reasoning(delta)
-                            elif "<think>" in delta:
-                                is_thinking = time.time()
-                                yield Reasoning(None, "Is thinking...")
-                            else:
-                                yield delta
+                            yield delta
                 if "usage" in data and data["usage"]:
                     yield Usage(**data["usage"])
             if "finish_reason" in choice and choice["finish_reason"] is not None:
@@ -6,6 +6,7 @@ import uvicorn
 import secrets
 import os
 import shutil
+import time
 from email.utils import formatdate
 import os.path
 import hashlib
@@ -538,6 +539,10 @@ class Api:
            response_data = provider_handler.synthesize({**request.query_params})
            content_type = getattr(provider_handler, "synthesize_content_type", "application/octet-stream")
            return StreamingResponse(response_data, media_type=content_type)
 
+        @self.app.get("/json/{filename}")
+        async def get_json(filename, request: Request):
+            return ""
+
        @self.app.get("/images/{filename}", response_class=FileResponse, responses={
            HTTP_200_OK: {"content": {"image/*": {}}},
@@ -550,15 +555,18 @@ class Api:
                stat_result.st_size = 0
                if os.path.isfile(target):
                    stat_result.st_size = os.stat(target).st_size
-                stat_result.st_mtime = int(f"{filename.split('_')[0]}")
+                stat_result.st_mtime = int(f"{filename.split('_')[0]}") if filename.startswith("1") else 0
+                headers = {
+                    "cache-control": "public, max-age=31536000",
+                    "content-type": f"image/{ext.replace('jpg', 'jepg')}",
+                    "content-length": str(stat_result.st_size),
+                    "last-modified": formatdate(stat_result.st_mtime, usegmt=True),
+                    "etag": f'"{hashlib.md5(filename.encode()).hexdigest()}"',
+                }
                response = FileResponse(
                    target,
-                    media_type=f"image/{ext.replace('jpg', 'jepg')}",
-                    headers={
-                        "content-length": str(stat_result.st_size),
-                        "last-modified": formatdate(stat_result.st_mtime, usegmt=True),
-                        "etag": f'"{hashlib.md5(filename.encode()).hexdigest()}"'
-                    },
+                    headers=headers,
+                    filename=filename,
                )
                try:
                    if_none_match = request.headers["if-none-match"]
@@ -8,11 +8,12 @@ try:
 except ImportError as e:
     import_error = e
 
-def get_gui_app(demo: bool = False):
+def get_gui_app(demo: bool = False, api: bool = False):
     if import_error is not None:
         raise MissingRequirementsError(f'Install "gui" requirements | pip install -U g4f[gui]\n{import_error}')
     app = create_app()
     app.demo = demo
+    app.api = api
 
     site = Website(app)
     for route in site.routes:
@@ -13,8 +13,8 @@ from ...tools.run_tools import iter_run_tools
 from ...Provider import ProviderUtils, __providers__
 from ...providers.base_provider import ProviderModelMixin
 from ...providers.retry_provider import BaseRetryProvider
-from ...providers.response import BaseConversation, JsonConversation, FinishReason, Usage, Reasoning, PreviewResponse
-from ...providers.response import SynthesizeData, TitleGeneration, RequestLogin, Parameters, ProviderInfo
+from ...providers.helper import format_image_prompt
+from ...providers.response import *
 from ... import version, models
 from ... import ChatCompletion, get_model_and_provider
 from ... import debug
@@ -183,13 +183,14 @@ class Api:
                    logger.exception(chunk)
                    yield self._format_json('message', get_error_message(chunk), error=type(chunk).__name__)
                elif isinstance(chunk, (PreviewResponse, ImagePreview)):
-                    yield self._format_json("preview", chunk.to_string())
+                    yield self._format_json("preview", chunk.to_string(), images=chunk.images, alt=chunk.alt)
                elif isinstance(chunk, ImageResponse):
                    images = chunk
-                    if download_images:
-                        images = asyncio.run(copy_images(chunk.get_list(), chunk.get("cookies"), proxy))
+                    if download_images or chunk.get("cookies"):
+                        alt = format_image_prompt(kwargs.get("messages"))
+                        images = asyncio.run(copy_images(chunk.get_list(), chunk.get("cookies"), proxy, alt))
                        images = ImageResponse(images, chunk.alt)
-                    yield self._format_json("content", str(images))
+                    yield self._format_json("content", str(images), images=chunk.get_list(), alt=chunk.alt)
                elif isinstance(chunk, SynthesizeData):
                    yield self._format_json("synthesize", chunk.get_dict())
                elif isinstance(chunk, TitleGeneration):
@@ -203,7 +204,11 @@ class Api:
                elif isinstance(chunk, Usage):
                    yield self._format_json("usage", chunk.get_dict())
                elif isinstance(chunk, Reasoning):
-                    yield self._format_json("reasoning", token=chunk.token, status=chunk.status)
+                    yield self._format_json("reasoning", token=chunk.token, status=chunk.status, is_thinking=chunk.is_thinking)
+                elif isinstance(chunk, DebugResponse):
+                    yield self._format_json("log", chunk.get_dict())
+                elif isinstance(chunk, Notification):
+                    yield self._format_json("notification", chunk.message)
                else:
                    yield self._format_json("content", str(chunk))
            if debug.logs:
@@ -219,6 +224,15 @@ class Api:
            yield self._format_json('error', type(e).__name__, message=get_error_message(e))
 
    def _format_json(self, response_type: str, content = None, **kwargs):
+        # Make sure it can be formatted as JSON
+        if content is not None and not isinstance(content, (str, dict)):
+            content = str(content)
+        kwargs = {
+            key: value
+            if isinstance(value, (str, dict))
+            else str(value)
+            for key, value in kwargs.items()
+            if isinstance(key, str)}
        if content is not None:
            return {
                'type': response_type,
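The new guard makes every payload JSON-safe before it is written to the event stream: content and keyword values that are not already a str or dict are stringified, and non-string keys are dropped. A minimal standalone sketch of the same logic (not the g4f class itself):

    import json

    def format_json(response_type: str, content=None, **kwargs):
        # Stringify anything json.dumps could choke on; keep str/dict as-is.
        if content is not None and not isinstance(content, (str, dict)):
            content = str(content)
        kwargs = {
            key: value if isinstance(value, (str, dict)) else str(value)
            for key, value in kwargs.items()
            if isinstance(key, str)
        }
        result = {"type": response_type, **kwargs}
        if content is not None:
            result["content"] = content
        return result

    print(json.dumps(format_json("usage", prompt_tokens=12)))
    # {"type": "usage", "prompt_tokens": "12"}

Note that non-string scalars such as ints come back as strings; clients parse them if they need numbers.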
g4f/image.py
@@ -7,6 +7,8 @@ import time
 import uuid
 import base64
 import asyncio
+import hashlib
+from urllib.parse import quote_plus
 from io import BytesIO
 from pathlib import Path
 from aiohttp import ClientSession, ClientError
@@ -239,10 +241,16 @@ def to_data_uri(image: ImageType) -> str:
 def ensure_images_dir():
     os.makedirs(images_dir, exist_ok=True)
 
+def get_image_extension(image: str) -> str:
+    if match := re.search(r"(\.(?:jpe?g|png|webp))[$?&]", image):
+        return match.group(1)
+    return ".jpg"
+
 async def copy_images(
     images: list[str],
     cookies: Optional[Cookies] = None,
     proxy: Optional[str] = None,
+    alt: str = None,
     add_url: bool = True,
     target: str = None,
     ssl: bool = None
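A detail worth noting in get_image_extension: inside the character class [$?&] the dollar sign is a literal, not an end-of-string anchor, so the extension is only recognized when a ?, & or literal $ follows it; bare paths fall back to ".jpg". A quick check with the function as committed (the example URLs are illustrative):

    import re

    def get_image_extension(image: str) -> str:
        # [$?&] matches a literal terminator character after the extension.
        if match := re.search(r"(\.(?:jpe?g|png|webp))[$?&]", image):
            return match.group(1)
        return ".jpg"

    print(get_image_extension("https://cdn.example.com/cat.webp?width=512"))  # .webp
    print(get_image_extension("https://cdn.example.com/cat.png"))             # .jpg (no terminator)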
@@ -256,7 +264,10 @@ async def copy_images(
    ) as session:
        async def copy_image(image: str, target: str = None) -> str:
            if target is None or len(images) > 1:
-                target = os.path.join(images_dir, f"{int(time.time())}_{str(uuid.uuid4())}")
+                hash = hashlib.sha256(image.encode()).hexdigest()
+                target = f"{quote_plus('+'.join(alt.split()[:10])[:100], '')}_{hash}" if alt else str(uuid.uuid4())
+                target = f"{int(time.time())}_{target}{get_image_extension(image)}"
+                target = os.path.join(images_dir, target)
            try:
                if image.startswith("data:"):
                    with open(target, "wb") as f:
@@ -69,9 +69,13 @@ def to_sync_generator(generator: AsyncIterator, stream: bool = True) -> Iterator
        loop.close()
 
 # Helper function to convert a synchronous iterator to an async iterator
-async def to_async_iterator(iterator: Iterator) -> AsyncIterator:
-    try:
+async def to_async_iterator(iterator) -> AsyncIterator:
+    if hasattr(iterator, '__aiter__'):
        async for item in iterator:
            yield item
+        return
+    try:
+        for item in iterator:
+            yield item
    except TypeError:
        yield await iterator
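The reworked to_async_iterator covers three cases in order: objects that already implement __aiter__ are passed through, plain sync iterables are wrapped, and anything else is treated as a single awaitable. A self-contained demo of all three paths:

    import asyncio

    async def to_async_iterator(iterator):
        if hasattr(iterator, '__aiter__'):
            async for item in iterator:
                yield item
            return
        try:
            for item in iterator:
                yield item
        except TypeError:
            # Not iterable at all: assume it is awaitable (e.g. a coroutine).
            yield await iterator

    async def agen():
        yield "async"

    async def coro():
        return "awaitable"

    async def main():
        for source in (agen(), ["sync", "list"], coro()):
            print([item async for item in to_async_iterator(source)])

    asyncio.run(main())  # ['async'] then ['sync', 'list'] then ['awaitable']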
@@ -27,6 +27,23 @@ def format_prompt(messages: Messages, add_special_tokens: bool = False, do_conti
         return formatted
     return f"{formatted}\nAssistant:"
 
+def get_last_user_message(messages: Messages) -> str:
+    user_messages = []
+    last_message = None if len(messages) == 0 else messages[-1]
+    while last_message is not None and messages:
+        last_message = messages.pop()
+        if last_message["role"] == "user":
+            if isinstance(last_message["content"], str):
+                user_messages.append(last_message["content"].strip())
+        else:
+            return "\n".join(user_messages[::-1])
+    return "\n".join(user_messages[::-1])
+
+def format_image_prompt(messages, prompt: str = None) -> str:
+    if prompt is None:
+        return get_last_user_message(messages)
+    return prompt
+
 def format_prompt_max_length(messages: Messages, max_lenght: int) -> str:
     prompt = format_prompt(messages)
     start = len(prompt)
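Callers should know that get_last_user_message consumes the list it is given: it pops the trailing run of user messages (joined oldest-first) and stops at the first non-user role, while format_image_prompt simply prefers an explicit prompt over that helper. A short illustration, assuming the helpers are imported from g4f.providers.helper:

    from g4f.providers.helper import get_last_user_message, format_image_prompt

    messages = [
        {"role": "system", "content": "Be brief."},
        {"role": "user", "content": "Draw a cat"},
        {"role": "user", "content": "make it orange"},
    ]
    print(get_last_user_message(list(messages)))   # "Draw a cat\nmake it orange"
    print(format_image_prompt(list(messages)))     # same: falls back to the helper
    print(format_image_prompt(messages, "a dog"))  # "a dog": explicit prompt wins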
@@ -88,44 +88,61 @@ class JsonMixin:
     def reset(self):
         self.__dict__ = {}
 
-class FinishReason(ResponseType, JsonMixin):
+class HiddenResponse(ResponseType):
+    def __str__(self) -> str:
+        return ""
+
+class FinishReason(JsonMixin, HiddenResponse):
     def __init__(self, reason: str) -> None:
         self.reason = reason
 
-    def __str__(self) -> str:
-        return ""
-
-class ToolCalls(ResponseType):
+class ToolCalls(HiddenResponse):
     def __init__(self, list: list):
         self.list = list
 
-    def __str__(self) -> str:
-        return ""
-
     def get_list(self) -> list:
         return self.list
 
-class Usage(ResponseType, JsonMixin):
-    def __str__(self) -> str:
-        return ""
+class Usage(JsonMixin, HiddenResponse):
+    pass
 
-class AuthResult(JsonMixin):
-    def __str__(self) -> str:
-        return ""
+class AuthResult(JsonMixin, HiddenResponse):
+    pass
 
-class TitleGeneration(ResponseType):
+class TitleGeneration(HiddenResponse):
     def __init__(self, title: str) -> None:
         self.title = title
 
+class DebugResponse(JsonMixin, HiddenResponse):
+    @classmethod
+    def from_dict(cls, data: dict) -> None:
+        return cls(**data)
+
+    @classmethod
+    def from_str(cls, data: str) -> None:
+        return cls(error=data)
+
+class Notification(ResponseType):
+    def __init__(self, message: str) -> None:
+        self.message = message
+
     def __str__(self) -> str:
-        return ""
+        return f"{self.message}\n"
 
 class Reasoning(ResponseType):
-    def __init__(self, token: str = None, status: str = None) -> None:
+    def __init__(
+            self,
+            token: str = None,
+            status: str = None,
+            is_thinking: str = None
+        ) -> None:
         self.token = token
         self.status = status
+        self.is_thinking = is_thinking
 
     def __str__(self) -> str:
+        if self.is_thinking is not None:
+            return self.is_thinking
         return f"{self.status}\n" if self.token is None else self.token
 
 class Sources(ResponseType):
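The practical effect of HiddenResponse: when a response stream is flattened to text, bookkeeping chunks vanish, while Reasoning renders its raw is_thinking tag, its token, or its status line, in that order of precedence. For example, with the classes as defined above:

    from g4f.providers.response import FinishReason, Reasoning

    chunks = [
        Reasoning(is_thinking="<think>"),
        Reasoning("checking the docs..."),
        Reasoning(None, "Finished in 0.52 seconds"),
        FinishReason("stop"),
        "The answer is 42.",
    ]
    print("".join(str(chunk) for chunk in chunks))
    # <think>checking the docs...Finished in 0.52 seconds
    # The answer is 42.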
@@ -154,14 +171,11 @@ class BaseConversation(ResponseType):
 class JsonConversation(BaseConversation, JsonMixin):
     pass
 
-class SynthesizeData(ResponseType, JsonMixin):
+class SynthesizeData(HiddenResponse, JsonMixin):
     def __init__(self, provider: str, data: dict):
         self.provider = provider
         self.data = data
 
-    def __str__(self) -> str:
-        return ""
-
 class RequestLogin(ResponseType):
     def __init__(self, label: str, login_url: str) -> None:
         self.label = label
@@ -197,13 +211,10 @@ class ImagePreview(ImageResponse):
     def to_string(self):
         return super().__str__()
 
-class PreviewResponse(ResponseType):
+class PreviewResponse(HiddenResponse):
     def __init__(self, data: str):
         self.data = data
 
-    def __str__(self):
-        return ""
-
     def to_string(self):
         return self.data
 
@@ -211,6 +222,5 @@ class Parameters(ResponseType, JsonMixin):
     def __str__(self):
         return ""
 
-class ProviderInfo(ResponseType, JsonMixin):
-    def __str__(self):
-        return ""
+class ProviderInfo(JsonMixin, HiddenResponse):
+    pass
@@ -87,7 +87,7 @@ async def get_args_from_nodriver(
     callback: callable = None,
     cookies: Cookies = None
 ) -> dict:
-    browser = await get_nodriver(proxy=proxy, timeout=timeout)
+    browser, stop_browser = await get_nodriver(proxy=proxy, timeout=timeout)
     try:
         if debug.logging:
             print(f"Open nodriver with url: {url}")
@@ -117,7 +117,7 @@ async def get_args_from_nodriver(
             "proxy": proxy,
         }
     finally:
-        browser.stop()
+        stop_browser()
 
 def merge_cookies(cookies: Iterator[Morsel], response: Response) -> Cookies:
     if cookies is None:
@@ -170,11 +170,10 @@ async def get_nodriver(
             browser = util.get_registered_instances().pop()
         else:
             raise
-    stop = browser.stop
     def on_stop():
         try:
-            stop()
+            if browser.connection:
+                browser.stop()
         finally:
             lock_file.unlink(missing_ok=True)
-    browser.stop = on_stop
-    return browser
+    return browser, on_stop
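get_nodriver now hands back the cleanup callable instead of monkey-patching browser.stop, and the callable is safe to invoke after the connection has already closed since it checks browser.connection and always releases the lock file. The providers above all follow the same acquire/try/finally shape; a minimal sketch, assuming the default arguments suit your environment:

    import asyncio
    from g4f.requests import get_nodriver

    async def fetch_user_agent(url: str) -> str:
        browser, stop_browser = await get_nodriver()
        try:
            page = await browser.get(url)
            return await page.evaluate("navigator.userAgent")
        finally:
            stop_browser()  # stops the browser if still connected, then unlinks the lock file

    print(asyncio.run(fetch_user_agent("https://example.com")))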
@@ -3,12 +3,14 @@ from __future__ import annotations
 import re
 import json
 import asyncio
+import time
 from pathlib import Path
 from typing import Optional, Callable, AsyncIterator
 
 from ..typing import Messages
 from ..providers.helper import filter_none
 from ..providers.asyncio import to_async_iterator
+from ..providers.response import Reasoning
 from ..providers.types import ProviderType
 from ..cookies import get_cookies_dir
 from .web_search import do_search, get_search_message
@@ -147,4 +149,26 @@
     if has_bucket and isinstance(messages[-1]["content"], str):
         messages[-1]["content"] += BUCKET_INSTRUCTIONS
 
-    return iter_callback(model=model, messages=messages, provider=provider, **kwargs)
+    is_thinking = 0
+    for chunk in iter_callback(model=model, messages=messages, provider=provider, **kwargs):
+        if not isinstance(chunk, str):
+            yield chunk
+            continue
+        if "<think>" in chunk:
+            chunk = chunk.split("<think>", 1)
+            yield chunk[0]
+            yield Reasoning(is_thinking="<think>")
+            yield Reasoning(chunk[1])
+            yield Reasoning(None, "Is thinking...")
+            is_thinking = time.time()
+        if "</think>" in chunk:
+            chunk = chunk.split("</think>", 1)
+            yield Reasoning(chunk[0])
+            yield Reasoning(is_thinking="</think>")
+            yield Reasoning(None, f"Finished in {round(time.time()-is_thinking, 2)} seconds")
+            yield chunk[1]
+            is_thinking = 0
+        elif is_thinking:
+            yield Reasoning(chunk)
+        else:
+            yield chunk
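With this parsing in iter_run_tools (replacing the per-provider version removed from OpenaiTemplate above), providers can emit plain text containing <think>/</think> markers and consumers receive typed Reasoning chunks interleaved with the answer, so no tag scanning is needed downstream. A consumer-side sketch; it assumes a stream whose opening and closing tags arrive in separate chunks, which is how token streaming normally delivers them:

    from g4f.providers.response import Reasoning

    def split_stream(chunks):
        # Separate the visible answer from the model's reasoning tokens.
        answer, thoughts = [], []
        for chunk in chunks:
            if isinstance(chunk, Reasoning):
                if chunk.token:
                    thoughts.append(str(chunk))
            elif isinstance(chunk, str):
                answer.append(chunk)
        return "".join(answer), "".join(thoughts)

    chunks = [Reasoning(is_thinking="<think>"), Reasoning("checking the docs"),
              Reasoning(is_thinking="</think>"), "The answer is 42."]
    print(split_stream(chunks))  # ('The answer is 42.', 'checking the docs')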
@@ -192,6 +192,8 @@ async def search(query: str, max_results: int = 5, max_words: int = 2500, backen
     return SearchResults(formatted_results, used_words)
 
 async def do_search(prompt: str, query: str = None, instructions: str = DEFAULT_INSTRUCTIONS, **kwargs) -> str:
+    if instructions and instructions in prompt:
+        return prompt # We have already added search results
     if query is None:
         query = spacy_get_keywords(prompt)
     json_bytes = json.dumps({"query": query, **kwargs}, sort_keys=True).encode(errors="ignore")