Mirror of https://github.com/xtekky/gpt4free.git, synced 2025-10-17 22:00:46 +08:00
Improve tools support in OpenaiTemplate and GeminiPro
Update models in DDG, PerplexityLabs, Gemini
Fix issues with curl_cffi in new versions
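For context, "tools support" here means OpenAI-style function calling passed through the g4f client and forwarded by the providers touched in this commit. A minimal usage sketch only (the get_weather tool and the chosen model are illustrative assumptions, not part of this diff):

from g4f.client import Client

client = Client()
# Hypothetical tool definition in the OpenAI function-calling format.
# OpenaiTemplate forwards it via the "tools" request field and GeminiPro
# translates it into Gemini "function_declarations" (see the hunks below).
tools = [{
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Look up the current weather for a city",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string", "title": "City name"}},
        },
    },
}]
response = client.chat.completions.create(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "What is the weather in Berlin?"}],
    tools=tools,
)
print(response.choices[0].message.tool_calls)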
.gitignore (vendored)
@@ -66,3 +66,4 @@ bench.py
to-reverse.txt
g4f/Provider/OpenaiChat2.py
generated_images/
projects/windows/
@@ -1,5 +1,9 @@
import unittest

import g4f.debug

g4f.debug.version_check = False

from .asyncio import *
from .backend import *
from .main import *
@@ -6,8 +6,12 @@ from g4f.errors import VersionNotFoundError
DEFAULT_MESSAGES = [{'role': 'user', 'content': 'Hello'}]

class TestGetLastProvider(unittest.TestCase):

def test_get_latest_version(self):
current_version = g4f.version.utils.current_version
if current_version is not None:
self.assertIsInstance(g4f.version.utils.current_version, str)
try:
self.assertIsInstance(g4f.version.utils.latest_version, str)
except VersionNotFoundError:
pass
@@ -42,7 +42,7 @@ class DDG(AsyncGeneratorProvider, ProviderModelMixin):
"gpt-4": "gpt-4o-mini",
"claude-3-haiku": "claude-3-haiku-20240307",
"llama-3.3-70b": "meta-llama/Llama-3.3-70B-Instruct-Turbo",
"mixtral-8x7b": "mistralai/Mixtral-8x7B-Instruct-v0.1",
"mixtral-8x7b": "mistralai/Mistral-Small-24B-Instruct-2501",
}

last_request_time = 0
@@ -5,7 +5,8 @@ import json

from ..typing import AsyncResult, Messages
from ..requests import StreamSession, raise_for_status
from ..providers.response import FinishReason
from ..errors import ResponseError
from ..providers.response import FinishReason, Sources
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin

API_URL = "https://www.perplexity.ai/socket.io/"
@@ -15,10 +16,11 @@ class PerplexityLabs(AsyncGeneratorProvider, ProviderModelMixin):
url = "https://labs.perplexity.ai"
working = True

default_model = "sonar-pro"
default_model = "r1-1776"
models = [
"sonar",
default_model,
"sonar-pro",
"sonar",
"sonar-reasoning",
"sonar-reasoning-pro",
]
@@ -32,19 +34,10 @@ class PerplexityLabs(AsyncGeneratorProvider, ProviderModelMixin):
**kwargs
) -> AsyncResult:
headers = {
"User-Agent": "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:121.0) Gecko/20100101 Firefox/121.0",
"Accept": "*/*",
"Accept-Language": "de,en-US;q=0.7,en;q=0.3",
"Accept-Encoding": "gzip, deflate, br",
"Origin": cls.url,
"Connection": "keep-alive",
"Referer": f"{cls.url}/",
"Sec-Fetch-Dest": "empty",
"Sec-Fetch-Mode": "cors",
"Sec-Fetch-Site": "same-site",
"TE": "trailers",
}
async with StreamSession(headers=headers, proxies={"all": proxy}) as session:
async with StreamSession(headers=headers, proxy=proxy, impersonate="chrome") as session:
t = format(random.getrandbits(32), "08x")
async with session.get(
f"{API_URL}?EIO=4&transport=polling&t={t}"
@@ -60,17 +53,22 @@ class PerplexityLabs(AsyncGeneratorProvider, ProviderModelMixin):
) as response:
await raise_for_status(response)
assert await response.text() == "OK"
async with session.get(
f"{API_URL}?EIO=4&transport=polling&t={t}&sid={sid}",
data=post_data
) as response:
await raise_for_status(response)
assert (await response.text()).startswith("40")
async with session.ws_connect(f"{WS_URL}?EIO=4&transport=websocket&sid={sid}", autoping=False) as ws:
await ws.send_str("2probe")
assert(await ws.receive_str() == "3probe")
await ws.send_str("5")
assert(await ws.receive_str())
assert(await ws.receive_str() == "6")
message_data = {
"version": "2.16",
"version": "2.18",
"source": "default",
"model": model,
"messages": messages
"messages": messages,
}
await ws.send_str("42" + json.dumps(["perplexity_labs", message_data]))
last_message = 0
@@ -82,12 +80,15 @@ class PerplexityLabs(AsyncGeneratorProvider, ProviderModelMixin):
await ws.send_str("3")
continue
try:
if last_message == 0 and model == cls.default_model:
yield "<think>"
data = json.loads(message[2:])[1]
yield data["output"][last_message:]
last_message = len(data["output"])
if data["final"]:
if data["citations"]:
yield Sources(data["citations"])
yield FinishReason("stop")
break
except Exception as e:
print(f"Error processing message: {message} - {e}")
raise RuntimeError(f"Message: {message}") from e
raise ResponseError(f"Message: {message}") from e
@@ -124,7 +124,7 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
raise

if not cache and seed is None:
seed = random.randint(0, 10000)
seed = random.randint(1000, 999999)

if model in cls.image_models:
async for chunk in cls._generate_image(
@@ -182,7 +182,8 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
}
params = {k: v for k, v in params.items() if v is not None}
query = "&".join(f"{k}={quote_plus(v)}" for k, v in params.items())
url = f"{cls.image_api_endpoint}prompt/{quote_plus(prompt)}?{query}"
prefix = f"{model}_{seed}" if seed is not None else model
url = f"{cls.image_api_endpoint}prompt/{prefix}_{quote_plus(prompt)}?{query}"
yield ImagePreview(url, prompt)

async with ClientSession(headers=DEFAULT_HEADERS, connector=get_connector(proxy=proxy)) as session:
@@ -39,14 +39,15 @@ class HuggingChat(AsyncAuthedProvider, ProviderModelMixin):
default_model = default_model
model_aliases = model_aliases
image_models = image_models
text_models = fallback_models

@classmethod
def get_models(cls):
if not cls.models:
try:
text = requests.get(cls.url).text
text = re.sub(r',parameters:{[^}]+?}', '', text)
text = re.search(r'models:(\[.+?\]),oldModels:', text).group(1)
text = re.sub(r',parameters:{[^}]+?}', '', text)
text = text.replace('void 0', 'null')
def add_quotation_mark(match):
return f'{match.group(1)}"{match.group(2)}":'
@@ -56,7 +57,7 @@ class HuggingChat(AsyncAuthedProvider, ProviderModelMixin):
cls.models = cls.text_models + cls.image_models
cls.vision_models = [model["id"] for model in models if model["multimodal"]]
except Exception as e:
debug.log(f"HuggingChat: Error reading models: {type(e).__name__}: {e}")
debug.error(f"{cls.__name__}: Error reading models: {type(e).__name__}: {e}")
cls.models = [*fallback_models]
return cls.models
@@ -10,8 +10,8 @@ from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, format_p
from ...errors import ModelNotFoundError, ModelNotSupportedError, ResponseError
from ...requests import StreamSession, raise_for_status
from ...providers.response import FinishReason, ImageResponse
from ..helper import format_image_prompt
from .models import default_model, default_image_model, model_aliases, fallback_models
from ..helper import format_image_prompt, get_last_user_message
from .models import default_model, default_image_model, model_aliases, fallback_models, image_models
from ... import debug

class HuggingFaceInference(AsyncGeneratorProvider, ProviderModelMixin):
@@ -22,6 +22,7 @@ class HuggingFaceInference(AsyncGeneratorProvider, ProviderModelMixin):
default_model = default_model
default_image_model = default_image_model
model_aliases = model_aliases
image_models = image_models

@classmethod
def get_models(cls) -> list[str]:
@@ -29,14 +30,13 @@ class HuggingFaceInference(AsyncGeneratorProvider, ProviderModelMixin):
models = fallback_models.copy()
url = "https://huggingface.co/api/models?inference=warm&pipeline_tag=text-generation"
response = requests.get(url)
response.raise_for_status()
if response.ok:
extra_models = [model["id"] for model in response.json()]
extra_models.sort()
models.extend([model for model in extra_models if model not in models])
if not cls.image_models:
url = "https://huggingface.co/api/models?pipeline_tag=text-to-image"
response = requests.get(url)
response.raise_for_status()
if response.ok:
cls.image_models = [model["id"] for model in response.json() if model.get("trendingScore", 0) >= 20]
cls.image_models.sort()
models.extend([model for model in cls.image_models if model not in models])
@@ -57,6 +57,7 @@ class HuggingFaceInference(AsyncGeneratorProvider, ProviderModelMixin):
prompt: str = None,
action: str = None,
extra_data: dict = {},
seed: int = None,
**kwargs
) -> AsyncResult:
try:
@@ -104,7 +105,7 @@ class HuggingFaceInference(AsyncGeneratorProvider, ProviderModelMixin):
if pipeline_tag == "text-to-image":
stream = False
inputs = format_image_prompt(messages, prompt)
payload = {"inputs": inputs, "parameters": {"seed": random.randint(0, 2**32), **extra_data}}
payload = {"inputs": inputs, "parameters": {"seed": random.randint(0, 2**32) if seed is None else seed, **extra_data}}
elif pipeline_tag in ("text-generation", "image-text-to-text"):
model_type = None
if "config" in model_data and "model_type" in model_data["config"]:
@@ -116,11 +117,13 @@ class HuggingFaceInference(AsyncGeneratorProvider, ProviderModelMixin):
if len(messages) > 6:
messages = messages[:3] + messages[-3:]
else:
messages = [m for m in messages if m["role"] == "system"] + [messages[-1]]
messages = [m for m in messages if m["role"] == "system"] + [get_last_user_message(messages)]
inputs = get_inputs(messages, model_data, model_type, do_continue)
debug.log(f"New len: {len(inputs)}")
if model_type == "gpt2" and max_tokens >= 1024:
params["max_new_tokens"] = 512
if seed is not None:
params["seed"] = seed
payload = {"inputs": inputs, "parameters": params, "stream": stream}
else:
raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__} pipeline_tag: {pipeline_tag}")
@@ -48,7 +48,7 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
except Exception as e:
if is_started:
raise e
debug.log(f"Inference failed: {e.__class__.__name__}: {e}")
debug.error(f"{cls.__name__} {type(e).__name__}; {e}")
if not cls.image_models:
cls.get_models()
if model in cls.image_models:
@@ -3,7 +3,7 @@ from __future__ import annotations
import json
import uuid
import re
import time
import random
from datetime import datetime, timezone, timedelta
import urllib.parse

@@ -88,7 +88,7 @@ class Janus_Pro_7B(AsyncGeneratorProvider, ProviderModelMixin):
prompt = format_prompt(messages) if prompt is None and conversation is None else prompt
prompt = format_image_prompt(messages, prompt)
if seed is None:
seed = int(time.time())
seed = random.randint(1000, 999999)

session_hash = generate_session_hash() if conversation is None else getattr(conversation, "session_hash")
async with StreamSession(proxy=proxy, impersonate="chrome") as session:
@@ -32,9 +32,7 @@ class CopilotAccount(AsyncAuthedProvider, Copilot):
except NoValidHarFileError as h:
debug.log(f"Copilot: {h}")
if has_nodriver:
login_url = os.environ.get("G4F_LOGIN_URL")
if login_url:
yield RequestLogin(cls.label, login_url)
yield RequestLogin(cls.label, os.environ.get("G4F_LOGIN_URL", ""))
Copilot._access_token, Copilot._cookies = await get_access_token_and_cookies(cls.url, proxy)
else:
raise h
@@ -65,7 +65,7 @@ class Gemini(AsyncGeneratorProvider, ProviderModelMixin):
default_image_model = default_model
default_vision_model = default_model
image_models = [default_image_model]
models = [default_model, "gemini-1.5-flash", "gemini-1.5-pro"]
models = [default_model, "gemini-2.0"]

synthesize_content_type = "audio/vnd.wav"

@@ -179,7 +179,7 @@ class Gemini(AsyncGeneratorProvider, ProviderModelMixin):
yield Conversation(response_part[1][0], response_part[1][1], response_part[4][0][0])
content = response_part[4][0][1][0]
except (ValueError, KeyError, TypeError, IndexError) as e:
debug.log(f"{cls.__name__}:{e.__class__.__name__}:{e}")
debug.error(f"{cls.__name__} {type(e).__name__}: {e}")
continue
match = re.search(r'\[Imagen of (.*?)\]', content)
if match:
@@ -51,7 +51,9 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
]
cls.models.sort()
except Exception as e:
debug.log(e)
debug.error(e)
if api_key is not None:
raise MissingAuthError("Invalid API key")
return cls.fallback_models
return cls.models

@@ -111,7 +113,17 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
"topK": kwargs.get("top_k"),
},
"tools": [{
"functionDeclarations": tools
"function_declarations": [{
"name": tool["function"]["name"],
"description": tool["function"]["description"],
"parameters": {
"type": "object",
"properties": {key: {
"type": value["type"],
"description": value["title"]
} for key, value in tool["function"]["parameters"]["properties"].items()}
},
} for tool in tools]
}] if tools else None
}
system_prompt = "\n".join(
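For illustration, with the mapping above a single OpenAI-style tool (such as the hypothetical get_weather example near the top) would land in the Gemini request body roughly as follows. This is a sketch derived from the lines shown, not captured API output; note that each property description is taken from the JSON-schema "title" field:

"tools": [{
    "function_declarations": [{
        "name": "get_weather",
        "description": "Look up the current weather for a city",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string", "description": "City name"}}
        }
    }]
}]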
@@ -322,8 +322,8 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
try:
image_requests = await cls.upload_images(session, auth_result, images) if images else None
except Exception as e:
debug.log("OpenaiChat: Upload image failed")
debug.log(f"{e.__class__.__name__}: {e}")
debug.error("OpenaiChat: Upload image failed")
debug.error(e)
model = cls.get_model(model)
if conversation is None:
conversation = Conversation(conversation_id, str(uuid.uuid4()), getattr(auth_result, "cookies", {}).get("oai-did"))
@@ -360,12 +360,14 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
# if auth_result.arkose_token is None:
# raise MissingAuthError("No arkose token found in .har file")
if "proofofwork" in chat_requirements:
if getattr(auth_result, "proof_token", None) is None:
auth_result.proof_token = get_config(auth_result.headers.get("user-agent"))
user_agent = getattr(auth_result, "headers", {}).get("user-agent")
proof_token = getattr(auth_result, "proof_token", None)
if proof_token is None:
auth_result.proof_token = get_config(user_agent)
proofofwork = generate_proof_token(
**chat_requirements["proofofwork"],
user_agent=getattr(auth_result, "headers", {}).get("user-agent"),
proof_token=getattr(auth_result, "proof_token", None)
user_agent=user_agent,
proof_token=proof_token
)
[debug.log(text) for text in (
#f"Arkose: {'False' if not need_arkose else auth_result.arkose_token[:12]+'...'}",
@@ -425,8 +427,8 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
) as response:
cls._update_request_args(auth_result, session)
if response.status == 403:
auth_result.proof_token = None
cls.request_config.proof_token = None
raise MissingAuthError("Access token is not valid")
await raise_for_status(response)
buffer = u""
async for line in response.iter_lines():
@@ -469,14 +471,6 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
await asyncio.sleep(5)
else:
break
yield Parameters(**{
"action": "continue" if conversation.finish_reason == "max_tokens" else "variant",
"conversation": conversation.get_dict(),
"proof_token": cls.request_config.proof_token,
"cookies": cls._cookies,
"headers": cls._headers,
"web_search": web_search,
})
yield FinishReason(conversation.finish_reason)

@classmethod
@@ -42,7 +42,7 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
if cls.sort_models:
cls.models.sort()
except Exception as e:
debug.log(e)
debug.error(e)
return cls.fallback_models
return cls.models

@@ -65,7 +65,7 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
prompt: str = None,
headers: dict = None,
impersonate: str = None,
tools: Optional[list] = None,
extra_parameters: list[str] = ["tools", "parallel_tool_calls", "", "reasoning_effort", "logit_bias"],
extra_data: dict = {},
**kwargs
) -> AsyncResult:
@@ -112,6 +112,7 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
}
]
messages[-1] = last_message
extra_parameters = {key: kwargs[key] for key in extra_parameters if key in kwargs}
data = filter_none(
messages=messages,
model=model,
@@ -120,7 +121,7 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
top_p=top_p,
stop=stop,
stream=stream,
tools=tools,
**extra_parameters,
**extra_data
)
if api_endpoint is None:
@@ -588,7 +588,7 @@ class Api:
target=target)
debug.log(f"Image copied from {source_url}")
except Exception as e:
debug.log(f"{type(e).__name__}: Download failed: {source_url}\n{e}")
debug.error(f"Download failed: {source_url}\n{type(e).__name__}: {e}")
return RedirectResponse(url=source_url)
if not os.path.isfile(target):
return ErrorResponse.from_message("File not found", HTTP_404_NOT_FOUND)
@@ -18,7 +18,7 @@ from ..providers.retry_provider import IterListProvider
from ..providers.asyncio import to_sync_generator
from ..Provider.needs_auth import BingCreateImages, OpenaiAccount
from ..tools.run_tools import async_iter_run_tools, iter_run_tools
from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse
from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse, UsageModel, ToolCallModel
from .image_models import ImageModels
from .types import IterResponse, ImageProvider, Client as BaseClient
from .service import get_model_and_provider, convert_to_provider
@@ -103,14 +103,14 @@ def iter_response(

idx += 1
if usage is None:
usage = Usage(prompt_tokens=0, completion_tokens=idx, total_tokens=idx)
usage = Usage(completion_tokens=idx, total_tokens=idx)

finish_reason = "stop" if finish_reason is None else finish_reason

if stream:
yield ChatCompletionChunk.model_construct(
None, finish_reason, completion_id, int(time.time()),
usage=usage.get_dict()
usage=usage
)
else:
if response_format is not None and "type" in response_format:
@@ -118,7 +118,8 @@ def iter_response(
content = filter_json(content)
yield ChatCompletion.model_construct(
content, finish_reason, completion_id, int(time.time()),
usage=usage.get_dict(), **filter_none(tool_calls=tool_calls)
usage=UsageModel.model_construct(**usage.get_dict()),
**filter_none(tool_calls=[ToolCallModel.model_construct(**tool_call) for tool_call in tool_calls]) if tool_calls is not None else {}
)

# Synchronous iter_append_model_and_provider function
@@ -186,7 +187,7 @@ async def async_iter_response(
finish_reason = "stop" if finish_reason is None else finish_reason

if usage is None:
usage = Usage(prompt_tokens=0, completion_tokens=idx, total_tokens=idx)
usage = Usage(completion_tokens=idx, total_tokens=idx)

if stream:
yield ChatCompletionChunk.model_construct(
@@ -199,7 +200,8 @@ async def async_iter_response(
content = filter_json(content)
yield ChatCompletion.model_construct(
content, finish_reason, completion_id, int(time.time()),
usage=usage.get_dict(), **filter_none(tool_calls=tool_calls)
usage=UsageModel.model_construct(**usage.get_dict()),
**filter_none(tool_calls=[ToolCallModel.model_construct(**tool_call) for tool_call in tool_calls]) if tool_calls is not None else {}
)
finally:
await safe_aclose(response)
@@ -363,7 +365,7 @@ class Images:
break
except Exception as e:
error = e
debug.log(f"Image provider {provider.__name__}: {e}")
debug.error(e, name=f"{provider.__name__} {type(e).__name__}")
else:
response = await self._generate_image_response(provider_handler, provider_name, model, prompt, **kwargs)

@@ -458,7 +460,7 @@ class Images:
break
except Exception as e:
error = e
debug.log(f"Image provider {provider.__name__}: {e}")
debug.error(e, name=f"{provider.__name__} {type(e).__name__}")
else:
response = await self._generate_image_response(provider_handler, provider_name, model, prompt, **kwargs)

@@ -583,7 +585,7 @@ class AsyncCompletions:
messages: Messages,
model: str,
**kwargs
) -> AsyncIterator[ChatCompletionChunk, BaseConversation]:
) -> AsyncIterator[ChatCompletionChunk]:
return self.create(messages, model, stream=True, **kwargs)

class AsyncImages(Images):
@@ -5,9 +5,6 @@ from time import time

from .helper import filter_none

ToolCalls = Optional[List[Dict[str, Any]]]
Usage = Optional[Dict[str, int]]

try:
from pydantic import BaseModel, Field
except ImportError:
@@ -29,6 +26,40 @@ class BaseModel(BaseModel):
return super().model_construct(**data)
return cls.construct(**data)

class UsageModel(BaseModel):
prompt_tokens: int
completion_tokens: int
total_tokens: int
prompt_tokens_details: Optional[Dict[str, Any]]
completion_tokens_details: Optional[Dict[str, Any]]

@classmethod
def model_construct(cls, prompt_tokens=0, completion_tokens=0, total_tokens=0, prompt_tokens_details=None, completion_tokens_details=None, **kwargs):
return super().model_construct(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
prompt_tokens_details=prompt_tokens_details,
completion_tokens_details=completion_tokens_details,
**kwargs
)

class ToolFunctionModel(BaseModel):
name: str
arguments: str

class ToolCallModel(BaseModel):
id: str
type: str
function: ToolFunctionModel

@classmethod
def model_construct(cls, function=None, **kwargs):
return super().model_construct(
**kwargs,
function=ToolFunctionModel.model_construct(**function),
)

class ChatCompletionChunk(BaseModel):
id: str
object: str
@@ -36,7 +67,7 @@ class ChatCompletionChunk(BaseModel):
model: str
provider: Optional[str]
choices: List[ChatCompletionDeltaChoice]
usage: Usage
usage: UsageModel

@classmethod
def model_construct(
@@ -45,7 +76,7 @@ class ChatCompletionChunk(BaseModel):
finish_reason: str,
completion_id: str = None,
created: int = None,
usage: Usage = None
usage: UsageModel = None
):
return super().model_construct(
id=f"chatcmpl-{completion_id}" if completion_id else None,
@@ -63,10 +94,10 @@ class ChatCompletionChunk(BaseModel):
class ChatCompletionMessage(BaseModel):
role: str
content: str
tool_calls: ToolCalls
tool_calls: list[ToolCallModel] = None

@classmethod
def model_construct(cls, content: str, tool_calls: ToolCalls = None):
def model_construct(cls, content: str, tool_calls: list = None):
return super().model_construct(role="assistant", content=content, **filter_none(tool_calls=tool_calls))

class ChatCompletionChoice(BaseModel):
@@ -85,11 +116,7 @@ class ChatCompletion(BaseModel):
model: str
provider: Optional[str]
choices: List[ChatCompletionChoice]
usage: Usage = Field(default={
"prompt_tokens": 0, #prompt_tokens,
"completion_tokens": 0, #completion_tokens,
"total_tokens": 0, #prompt_tokens + completion_tokens,
})
usage: UsageModel

@classmethod
def model_construct(
@@ -98,8 +125,8 @@ class ChatCompletion(BaseModel):
finish_reason: str,
completion_id: str = None,
created: int = None,
tool_calls: ToolCalls = None,
usage: Usage = None
tool_calls: list[ToolCallModel] = None,
usage: UsageModel = None
):
return super().model_construct(
id=f"chatcmpl-{completion_id}" if completion_id else None,
@@ -1,3 +1,4 @@
import sys
from .providers.types import ProviderType

logging: bool = False
@@ -8,6 +9,9 @@ version: str = None
log_handler: callable = print
logs: list = []

def log(text):
def log(text, file = None):
if logging:
log_handler(text)
log_handler(text, file=file)

def error(error, name: str = None):
log(error if isinstance(error, str) else f"{type(error).__name__ if name is None else name}: {error}", file=sys.stderr)
@@ -181,7 +181,12 @@
<script>
(async () => {
const isIframe = window.self !== window.top;
const backendUrl = "{{backend_url}}";
let url = new URL(window.location.href)
if (isIframe && backendUrl) {
window.location.replace(url.search ? `${backendUrl}?${url.search}` : backendUrl);
return;
}
let params = new URLSearchParams(url.search);
if (params.get("__sign")) {
localStorage.setItem("zerogpu_token", params.get("__sign"));
@@ -232,12 +237,11 @@
import * as hub from "@huggingface/hub";
import { init } from "@huggingface/space-header";

const isIframe = window.self !== window.top;
const button = document.querySelector('form a.button');
if (isIframe) {
button.classList.remove('hidden');
} else {
init("roxky/g4f-space");
init("roxky/g4f-space-new");
}

const form = document.querySelector("form");
@@ -282,13 +286,11 @@
const cache_id = Math.floor(Math.random() * max);
let prompt;
if (cache_id % 2 == 0) {
prompt = `
Today is ${new Date().toJSON().slice(0, 10)}.
prompt = `Today is ${new Date().toJSON().slice(0, 10)}.
Create a single-page HTML screensaver reflecting the current season (based on the date).
Avoid using any text.`;
} else {
prompt = `Create a single-page HTML screensaver. Avoid using any text.`;
const response = await fetch(`/backend-api/v2/create?prompt=${prompt}&filter_markdown=html&cache=${cache_id}`);
}
const response = await fetch(`/backend-api/v2/create?prompt=${prompt}&filter_markdown=html&cache=${cache_id}`);
const text = await response.text()
@@ -293,7 +293,7 @@
</div>
<div class="field">
<select name="model" id="model">
<option value="">Model: Default</option>
<option value="" selected="selected">Model: Default</option>
<option value="gpt-4">gpt-4</option>
<option value="gpt-4o">gpt-4o</option>
<option value="gpt-4o-mini">gpt-4o-mini</option>
@@ -72,13 +72,16 @@ class Api:

@staticmethod
def get_version() -> dict:
current_version = None
latest_version = None
try:
current_version = version.utils.current_version
latest_version = version.utils.latest_version
except VersionNotFoundError:
current_version = None
pass
return {
"version": current_version,
"latest_version": version.utils.latest_version,
"latest_version": latest_version,
}

def serve_images(self, name):
@@ -137,10 +140,10 @@ class Api:
}

def _create_response_stream(self, kwargs: dict, conversation_id: str, provider: str, download_images: bool = True) -> Iterator:
def decorated_log(text: str):
def decorated_log(text: str, file = None):
debug.logs.append(text)
if debug.logging:
debug.log_handler(text)
debug.log_handler(text, file)
debug.log = decorated_log
proxy = os.environ.get("G4F_PROXY")
provider = kwargs.get("provider")
@@ -154,6 +157,7 @@ class Api:
)
except Exception as e:
logger.exception(e)
debug.error(e)
yield self._format_json('error', type(e).__name__, message=get_error_message(e))
return
if not isinstance(provider_handler, BaseRetryProvider):
@@ -183,6 +187,7 @@ class Api:
yield self._format_json("conversation_id", conversation_id)
elif isinstance(chunk, Exception):
logger.exception(chunk)
debug.error(e)
yield self._format_json('message', get_error_message(chunk), error=type(chunk).__name__)
elif isinstance(chunk, PreviewResponse):
yield self._format_json("preview", chunk.to_string())
@@ -215,19 +220,18 @@ class Api:
yield self._format_json(chunk.type, **chunk.get_dict())
else:
yield self._format_json("content", str(chunk))
if debug.logs:
for log in debug.logs:
yield self._format_json("log", str(log))
debug.logs = []
yield from self._yield_logs()
except Exception as e:
logger.exception(e)
if debug.logging:
debug.log_handler(get_error_message(e))
debug.error(e)
yield from self._yield_logs()
yield self._format_json('error', type(e).__name__, message=get_error_message(e))

def _yield_logs(self):
if debug.logs:
for log in debug.logs:
yield self._format_json("log", str(log))
yield self._format_json("log", log)
debug.logs = []
yield self._format_json('error', type(e).__name__, message=get_error_message(e))

def _format_json(self, response_type: str, content = None, **kwargs):
if content is not None and isinstance(response_type, str):
@@ -76,7 +76,7 @@ class Backend_Api(Api):
@app.route('/', methods=['GET'])
@limiter.exempt
def home():
return render_template('demo.html')
return render_template('demo.html', backend_url=os.environ.get("G4F_BACKEND_URL", ""))
else:
@app.route('/', methods=['GET'])
def home():
@@ -123,7 +123,7 @@ async def copy_images(
return f"/images/{url_filename}{'?url=' + quote(image) if add_url and not image.startswith('data:') else ''}"

except (ClientError, IOError, OSError) as e:
debug.log(f"Image processing failed: {e.__class__.__name__}: {e}")
debug.error(f"Image processing failed: {type(e).__name__}: {e}")
if target_path and os.path.exists(target_path):
os.unlink(target_path)
return get_source_url(image, image)
@@ -243,7 +243,7 @@ llama_3_3_70b = Model(
mixtral_8x7b = Model(
name = "mixtral-8x7b",
base_provider = "Mistral",
best_provider = IterListProvider([DDG, Jmuz])
best_provider = Jmuz
)
mixtral_8x22b = Model(
name = "mixtral-8x22b",
@@ -300,7 +300,7 @@ wizardlm_2_8x22b = Model(
### Google DeepMind ###
# gemini
gemini = Model(
name = 'gemini',
name = 'gemini-2.0',
base_provider = 'Google',
best_provider = Gemini
)
@@ -316,13 +316,13 @@ gemini_exp = Model(
gemini_1_5_flash = Model(
name = 'gemini-1.5-flash',
base_provider = 'Google DeepMind',
best_provider = IterListProvider([Blackbox, Jmuz, Gemini, GeminiPro, Liaobots])
best_provider = IterListProvider([Blackbox, Jmuz, GeminiPro, Liaobots])
)

gemini_1_5_pro = Model(
name = 'gemini-1.5-pro',
base_provider = 'Google DeepMind',
best_provider = IterListProvider([Blackbox, Jmuz, Gemini, GeminiPro, Liaobots])
best_provider = IterListProvider([Blackbox, Jmuz, GeminiPro, Liaobots])
)

# gemini-2.0
@@ -713,6 +713,7 @@ class ModelUtils:

### Google ###
### Gemini
"gemini": gemini,
gemini.name: gemini,
gemini_exp.name: gemini_exp,
gemini_1_5_pro.name: gemini_1_5_pro,
@@ -812,7 +813,6 @@ class ModelUtils:


demo_models = {
gpt_4o.name: [gpt_4o, [PollinationsAI, Blackbox]],
"default": [llama_3_2_11b, [HuggingFace]],
qwen_2_vl_7b.name: [qwen_2_vl_7b, [HuggingFaceAPI]],
qvq_72b.name: [qvq_72b, [HuggingSpace]],
@@ -409,6 +409,14 @@ class AsyncAuthedProvider(AsyncGeneratorProvider):
def get_cache_file(cls) -> Path:
return Path(get_cookies_dir()) / f"auth_{cls.parent if hasattr(cls, 'parent') else cls.__name__}.json"

@classmethod
def write_cache_file(cls, cache_file: Path, auth_result: AuthResult = None):
if auth_result is not None:
cache_file.parent.mkdir(parents=True, exist_ok=True)
cache_file.write_text(json.dumps(auth_result.get_dict()))
elif cache_file.exists():
cache_file.unlink()

@classmethod
def create_completion(
cls,
@@ -416,35 +424,25 @@ class AsyncAuthedProvider(AsyncGeneratorProvider):
messages: Messages,
**kwargs
) -> CreateResult:
auth_result = AuthResult()
auth_result: AuthResult = None
cache_file = cls.get_cache_file()
try:
if cache_file.exists():
with cache_file.open("r") as f:
auth_result = AuthResult(**json.load(f))
else:
auth_result = cls.on_auth(**kwargs)
for chunk in auth_result:
if hasattr(chunk, "get_dict"):
auth_result = chunk
else:
yield chunk
raise MissingAuthError
yield from to_sync_generator(cls.create_authed(model, messages, auth_result, **kwargs))
except (MissingAuthError, NoValidHarFileError):
auth_result = cls.on_auth(**kwargs)
for chunk in auth_result:
if hasattr(chunk, "get_dict"):
if isinstance(chunk, AuthResult):
auth_result = chunk
else:
yield chunk
yield from to_sync_generator(cls.create_authed(model, messages, auth_result, **kwargs))
finally:
if hasattr(auth_result, "get_dict"):
data = auth_result.get_dict()
cache_file.parent.mkdir(parents=True, exist_ok=True)
cache_file.write_text(json.dumps(data))
elif cache_file.exists():
cache_file.unlink()
cls.write_cache_file(cache_file, auth_result)

@classmethod
async def create_async_generator(
@@ -453,19 +451,14 @@ class AsyncAuthedProvider(AsyncGeneratorProvider):
messages: Messages,
**kwargs
) -> AsyncResult:
auth_result: AuthResult = None
cache_file = cls.get_cache_file()
try:
auth_result = AuthResult()
cache_file = Path(get_cookies_dir()) / f"auth_{cls.parent if hasattr(cls, 'parent') else cls.__name__}.json"
if cache_file.exists():
with cache_file.open("r") as f:
auth_result = AuthResult(**json.load(f))
else:
auth_result = cls.on_auth_async(**kwargs)
async for chunk in auth_result:
if hasattr(chunk, "get_dict"):
auth_result = chunk
else:
yield chunk
raise MissingAuthError
response = to_async_iterator(cls.create_authed(model, messages, **kwargs, auth_result=auth_result))
async for chunk in response:
yield chunk
@@ -474,16 +467,16 @@ class AsyncAuthedProvider(AsyncGeneratorProvider):
cache_file.unlink()
auth_result = cls.on_auth_async(**kwargs)
async for chunk in auth_result:
if hasattr(chunk, "get_dict"):
if isinstance(chunk, AuthResult):
auth_result = chunk
else:
yield chunk
response = to_async_iterator(cls.create_authed(model, messages, **kwargs, auth_result=auth_result))
async for chunk in response:
if cache_file is not None:
cls.write_cache_file(cache_file, auth_result)
cache_file = None
yield chunk
finally:
if hasattr(auth_result, "get_dict"):
cache_file.parent.mkdir(parents=True, exist_ok=True)
cache_file.write_text(json.dumps(auth_result.get_dict()))
elif cache_file.exists():
cache_file.unlink()
if cache_file is not None:
cls.write_cache_file(cache_file, auth_result)
@@ -19,7 +19,6 @@ def quote_url(url: str) -> str:

def quote_title(title: str) -> str:
if title:
title = title.strip()
title = " ".join(title.split())
return title.replace('[', '').replace(']', '')
return ""
@@ -154,6 +153,7 @@ class Sources(ResponseType):
self.add_source(source)

def add_source(self, source: dict[str, str]):
source = source if isinstance(source, dict) else {"url": source}
url = source.get("url", source.get("link", None))
if url is not None:
url = re.sub(r"[&?]utm_source=.+", "", url)
@@ -65,7 +65,7 @@ class IterListProvider(BaseRetryProvider):
return
except Exception as e:
exceptions[provider.__name__] = e
debug.log(f"{provider.__name__}: {e.__class__.__name__}: {e}")
debug.error(f"{provider.__name__} {type(e).__name__}: {e}")
if started:
raise e
yield e
@@ -105,7 +105,7 @@ class IterListProvider(BaseRetryProvider):
return
except Exception as e:
exceptions[provider.__name__] = e
debug.log(f"{provider.__name__}: {e.__class__.__name__}: {e}")
debug.error(name=f"{provider.__name__} {type(e).__name__}: {e}")
if started:
raise e
yield e
@@ -2,12 +2,12 @@ from __future__ import annotations

from curl_cffi.requests import AsyncSession, Response
try:
from curl_cffi.requests import CurlMime
from curl_cffi import CurlMime
has_curl_mime = True
except ImportError:
has_curl_mime = False
try:
from curl_cffi.requests import CurlWsFlag
from curl_cffi import CurlWsFlag
has_curl_ws = True
except ImportError:
has_curl_ws = False
@@ -73,7 +73,7 @@ class StreamSession(AsyncSession):
def request(
self, method: str, url: str, ssl = None, **kwargs
) -> StreamResponse:
if isinstance(kwargs.get("data"), CurlMime):
if kwargs.get("data") and isinstance(kwargs.get("data"), CurlMime):
kwargs["multipart"] = kwargs.pop("data")
"""Create and return a StreamResponse object for the given HTTP request."""
return StreamResponse(super().request(method, url, stream=True, verify=ssl, **kwargs))
@@ -100,12 +100,12 @@ if has_curl_mime:
else:
class FormData():
def __init__(self) -> None:
raise RuntimeError("CurlMimi in curl_cffi is missing | pip install -U g4f[curl_cffi]")
raise RuntimeError("CurlMimi in curl_cffi is missing | pip install -U curl_cffi")

class WebSocket():
def __init__(self, session, url, **kwargs) -> None:
if not has_curl_ws:
raise RuntimeError("CurlWsFlag in curl_cffi is missing | pip install -U g4f[curl_cffi]")
raise RuntimeError("CurlWsFlag in curl_cffi is missing | pip install -U curl_cffi")
self.session: StreamSession = session
self.url: str = url
del kwargs["autoping"]
@@ -116,11 +116,13 @@ class WebSocket():
return self

async def __aexit__(self, *args):
await self.inner.aclose()
await self.inner.aclose() if hasattr(self.inner, "aclose") else await self.inner.close()

async def receive_str(self, **kwargs) -> str:
bytes, _ = await self.inner.arecv()
method = self.inner.arecv if hasattr(self.inner, "arecv") else self.inner.recv
bytes, _ = await method()
return bytes.decode(errors="ignore")

async def send_str(self, data: str):
await self.inner.asend(data.encode(), CurlWsFlag.TEXT)
method = self.inner.asend if hasattr(self.inner, "asend") else self.inner.send
await method(data.encode(), CurlWsFlag.TEXT)
g4f/tools/pydantic_ai.py (new file, 71 lines)
@@ -0,0 +1,71 @@
from __future__ import annotations

from typing import Optional
from functools import partial
from dataclasses import dataclass, field

from pydantic_ai.models import Model, KnownModelName, infer_model
from pydantic_ai.models.openai import OpenAIModel, OpenAISystemPromptRole

from ..client import AsyncClient

@dataclass(init=False)
class AIModel(OpenAIModel):
"""A model that uses the G4F API."""

client: AsyncClient = field(repr=False)
system_prompt_role: OpenAISystemPromptRole | None = field(default=None)

_model_name: str = field(repr=False)
_provider: str = field(repr=False)
_system: Optional[str] = field(repr=False)

def __init__(
self,
model_name: str,
provider: str | None = None,
*,
system_prompt_role: OpenAISystemPromptRole | None = None,
system: str | None = 'openai',
**kwargs
):
"""Initialize an AI model.

Args:
model_name: The name of the AI model to use. List of model names available
[here](https://github.com/openai/openai-python/blob/v1.54.3/src/openai/types/chat_model.py#L7)
(Unfortunately, despite being ask to do so, OpenAI do not provide `.inv` files for their API).
system_prompt_role: The role to use for the system prompt message. If not provided, defaults to `'system'`.
In the future, this may be inferred from the model name.
system: The model provider used, defaults to `openai`. This is for observability purposes, you must
customize the `base_url` and `api_key` to use a different provider.
"""
self._model_name = model_name
self._provider = provider
self.client = AsyncClient(provider=provider, **kwargs)
self.system_prompt_role = system_prompt_role
self._system = system

def name(self) -> str:
if self._provider:
return f'g4f:{self._provider}:{self._model_name}'
return f'g4f:{self._model_name}'

def new_infer_model(model: Model | KnownModelName, api_key: str = None) -> Model:
if isinstance(model, Model):
return model
if model.startswith("g4f:"):
model = model[4:]
if ":" in model:
provider, model = model.split(":", 1)
return AIModel(model, provider=provider, api_key=api_key)
return AIModel(model)
return infer_model(model)

def apply_patch(api_key: str | None = None):
import pydantic_ai.models
import pydantic_ai.models.openai

pydantic_ai.models.infer_model = partial(new_infer_model, api_key=api_key)
pydantic_ai.models.AIModel = AIModel
pydantic_ai.models.openai.NOT_GIVEN = None
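A possible usage sketch for this new pydantic_ai integration, under stated assumptions: pydantic-ai is installed, run_sync and .data follow the pydantic-ai API of this period, and the prompt plus the "g4f:OpenaiChat:gpt-4o" model string are illustrative only:

from pydantic_ai import Agent
from g4f.tools.pydantic_ai import apply_patch

# Patch pydantic_ai.models.infer_model so "g4f:" model names resolve to AIModel.
apply_patch()

# "g4f:<Provider>:<model>" is split by new_infer_model into provider and model name.
agent = Agent("g4f:OpenaiChat:gpt-4o")
result = agent.run_sync("Say hello")
print(result.data)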
@@ -44,7 +44,7 @@ async def async_iter_run_tools(provider: ProviderType, model: str, messages, too
web_search = web_search if isinstance(web_search, str) and web_search != "true" else None
messages[-1]["content"] = await do_search(messages[-1]["content"], web_search)
except Exception as e:
debug.log(f"Couldn't do web search: {e.__class__.__name__}: {e}")
debug.error(f"Couldn't do web search: {e.__class__.__name__}: {e}")
# Keep web_search in kwargs for provider native support
pass

@@ -82,6 +82,7 @@ async def async_iter_run_tools(provider: ProviderType, model: str, messages, too
has_bucket = True
message["content"] = new_message_content
if has_bucket and isinstance(messages[-1]["content"], str):
if "\nSource: " in messages[-1]["content"]:
messages[-1]["content"] += BUCKET_INSTRUCTIONS
create_function = provider.get_async_create_function()
response = to_async_iterator(create_function(model=model, messages=messages, **kwargs))
@@ -149,7 +150,7 @@ def iter_run_tools(
web_search = web_search if isinstance(web_search, str) and web_search != "true" else None
messages[-1]["content"] = asyncio.run(do_search(messages[-1]["content"], web_search))
except Exception as e:
debug.log(f"Couldn't do web search: {e.__class__.__name__}: {e}")
debug.error(f"Couldn't do web search: {e.__class__.__name__}: {e}")
# Keep web_search in kwargs for provider native support
pass

@@ -192,7 +193,8 @@ def iter_run_tools(
has_bucket = True
message["content"] = new_message_content
if has_bucket and isinstance(messages[-1]["content"], str):
messages[-1]["content"] += BUCKET_INSTRUCTIONS
if "\nSource: " in messages[-1]["content"]:
messages[-1]["content"] = messages[-1]["content"]["content"] + BUCKET_INSTRUCTIONS

thinking_start_time = 0
for chunk in iter_callback(model=model, messages=messages, provider=provider, **kwargs):
@@ -237,7 +237,7 @@ def get_search_message(prompt: str, raise_search_exceptions=False, **kwargs) ->
except (DuckDuckGoSearchException, MissingRequirementsError) as e:
if raise_search_exceptions:
raise e
debug.log(f"Couldn't do web search: {e.__class__.__name__}: {e}")
debug.error(f"Couldn't do web search: {e.__class__.__name__}: {e}")
return prompt

def spacy_get_keywords(text: str):
@@ -44,8 +44,9 @@ def get_github_version(repo: str) -> str:
VersionNotFoundError: If there is an error in fetching the version from GitHub.
"""
try:
response = requests.get(f"https://api.github.com/repos/{repo}/releases/latest").json()
return response["tag_name"]
response = requests.get(f"https://api.github.com/repos/{repo}/releases/latest")
response.raise_for_status()
return response.json()["tag_name"]
except requests.RequestException as e:
raise VersionNotFoundError(f"Failed to get GitHub release version: {e}")