Improve tools support in OpenaiTemplate and GeminiPro

Update models in DDG, PerplexityLabs, Gemini
Fix issues with curl_cffi in new versions
This commit is contained in:
hlohaus
2025-02-21 04:36:54 +01:00
parent c3ed6d0f8f
commit e53483d85b
33 changed files with 300 additions and 172 deletions

1
.gitignore vendored
View File

@@ -66,3 +66,4 @@ bench.py
to-reverse.txt
g4f/Provider/OpenaiChat2.py
generated_images/
projects/windows/

View File

@@ -1,5 +1,9 @@
import unittest
import g4f.debug
g4f.debug.version_check = False
from .asyncio import *
from .backend import *
from .main import *

View File

@@ -6,8 +6,12 @@ from g4f.errors import VersionNotFoundError
DEFAULT_MESSAGES = [{'role': 'user', 'content': 'Hello'}]
class TestGetLastProvider(unittest.TestCase):
def test_get_latest_version(self):
current_version = g4f.version.utils.current_version
if current_version is not None:
self.assertIsInstance(g4f.version.utils.current_version, str)
try:
self.assertIsInstance(g4f.version.utils.latest_version, str)
except VersionNotFoundError:
pass

View File

@@ -42,7 +42,7 @@ class DDG(AsyncGeneratorProvider, ProviderModelMixin):
"gpt-4": "gpt-4o-mini",
"claude-3-haiku": "claude-3-haiku-20240307",
"llama-3.3-70b": "meta-llama/Llama-3.3-70B-Instruct-Turbo",
"mixtral-8x7b": "mistralai/Mixtral-8x7B-Instruct-v0.1",
"mixtral-8x7b": "mistralai/Mistral-Small-24B-Instruct-2501",
}
last_request_time = 0

View File

@@ -5,7 +5,8 @@ import json
from ..typing import AsyncResult, Messages
from ..requests import StreamSession, raise_for_status
from ..providers.response import FinishReason
from ..errors import ResponseError
from ..providers.response import FinishReason, Sources
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
API_URL = "https://www.perplexity.ai/socket.io/"
@@ -15,10 +16,11 @@ class PerplexityLabs(AsyncGeneratorProvider, ProviderModelMixin):
url = "https://labs.perplexity.ai"
working = True
default_model = "sonar-pro"
default_model = "r1-1776"
models = [
"sonar",
default_model,
"sonar-pro",
"sonar",
"sonar-reasoning",
"sonar-reasoning-pro",
]
@@ -32,19 +34,10 @@ class PerplexityLabs(AsyncGeneratorProvider, ProviderModelMixin):
**kwargs
) -> AsyncResult:
headers = {
"User-Agent": "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:121.0) Gecko/20100101 Firefox/121.0",
"Accept": "*/*",
"Accept-Language": "de,en-US;q=0.7,en;q=0.3",
"Accept-Encoding": "gzip, deflate, br",
"Origin": cls.url,
"Connection": "keep-alive",
"Referer": f"{cls.url}/",
"Sec-Fetch-Dest": "empty",
"Sec-Fetch-Mode": "cors",
"Sec-Fetch-Site": "same-site",
"TE": "trailers",
}
async with StreamSession(headers=headers, proxies={"all": proxy}) as session:
async with StreamSession(headers=headers, proxy=proxy, impersonate="chrome") as session:
t = format(random.getrandbits(32), "08x")
async with session.get(
f"{API_URL}?EIO=4&transport=polling&t={t}"
@@ -60,17 +53,22 @@ class PerplexityLabs(AsyncGeneratorProvider, ProviderModelMixin):
) as response:
await raise_for_status(response)
assert await response.text() == "OK"
async with session.get(
f"{API_URL}?EIO=4&transport=polling&t={t}&sid={sid}",
data=post_data
) as response:
await raise_for_status(response)
assert (await response.text()).startswith("40")
async with session.ws_connect(f"{WS_URL}?EIO=4&transport=websocket&sid={sid}", autoping=False) as ws:
await ws.send_str("2probe")
assert(await ws.receive_str() == "3probe")
await ws.send_str("5")
assert(await ws.receive_str())
assert(await ws.receive_str() == "6")
message_data = {
"version": "2.16",
"version": "2.18",
"source": "default",
"model": model,
"messages": messages
"messages": messages,
}
await ws.send_str("42" + json.dumps(["perplexity_labs", message_data]))
last_message = 0
@@ -82,12 +80,15 @@ class PerplexityLabs(AsyncGeneratorProvider, ProviderModelMixin):
await ws.send_str("3")
continue
try:
if last_message == 0 and model == cls.default_model:
yield "<think>"
data = json.loads(message[2:])[1]
yield data["output"][last_message:]
last_message = len(data["output"])
if data["final"]:
if data["citations"]:
yield Sources(data["citations"])
yield FinishReason("stop")
break
except Exception as e:
print(f"Error processing message: {message} - {e}")
raise RuntimeError(f"Message: {message}") from e
raise ResponseError(f"Message: {message}") from e

View File

@@ -124,7 +124,7 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
raise
if not cache and seed is None:
seed = random.randint(0, 10000)
seed = random.randint(1000, 999999)
if model in cls.image_models:
async for chunk in cls._generate_image(
@@ -182,7 +182,8 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
}
params = {k: v for k, v in params.items() if v is not None}
query = "&".join(f"{k}={quote_plus(v)}" for k, v in params.items())
url = f"{cls.image_api_endpoint}prompt/{quote_plus(prompt)}?{query}"
prefix = f"{model}_{seed}" if seed is not None else model
url = f"{cls.image_api_endpoint}prompt/{prefix}_{quote_plus(prompt)}?{query}"
yield ImagePreview(url, prompt)
async with ClientSession(headers=DEFAULT_HEADERS, connector=get_connector(proxy=proxy)) as session:

View File

@@ -39,14 +39,15 @@ class HuggingChat(AsyncAuthedProvider, ProviderModelMixin):
default_model = default_model
model_aliases = model_aliases
image_models = image_models
text_models = fallback_models
@classmethod
def get_models(cls):
if not cls.models:
try:
text = requests.get(cls.url).text
text = re.sub(r',parameters:{[^}]+?}', '', text)
text = re.search(r'models:(\[.+?\]),oldModels:', text).group(1)
text = re.sub(r',parameters:{[^}]+?}', '', text)
text = text.replace('void 0', 'null')
def add_quotation_mark(match):
return f'{match.group(1)}"{match.group(2)}":'
@@ -56,7 +57,7 @@ class HuggingChat(AsyncAuthedProvider, ProviderModelMixin):
cls.models = cls.text_models + cls.image_models
cls.vision_models = [model["id"] for model in models if model["multimodal"]]
except Exception as e:
debug.log(f"HuggingChat: Error reading models: {type(e).__name__}: {e}")
debug.error(f"{cls.__name__}: Error reading models: {type(e).__name__}: {e}")
cls.models = [*fallback_models]
return cls.models

View File

@@ -10,8 +10,8 @@ from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, format_p
from ...errors import ModelNotFoundError, ModelNotSupportedError, ResponseError
from ...requests import StreamSession, raise_for_status
from ...providers.response import FinishReason, ImageResponse
from ..helper import format_image_prompt
from .models import default_model, default_image_model, model_aliases, fallback_models
from ..helper import format_image_prompt, get_last_user_message
from .models import default_model, default_image_model, model_aliases, fallback_models, image_models
from ... import debug
class HuggingFaceInference(AsyncGeneratorProvider, ProviderModelMixin):
@@ -22,6 +22,7 @@ class HuggingFaceInference(AsyncGeneratorProvider, ProviderModelMixin):
default_model = default_model
default_image_model = default_image_model
model_aliases = model_aliases
image_models = image_models
@classmethod
def get_models(cls) -> list[str]:
@@ -29,14 +30,13 @@ class HuggingFaceInference(AsyncGeneratorProvider, ProviderModelMixin):
models = fallback_models.copy()
url = "https://huggingface.co/api/models?inference=warm&pipeline_tag=text-generation"
response = requests.get(url)
response.raise_for_status()
if response.ok:
extra_models = [model["id"] for model in response.json()]
extra_models.sort()
models.extend([model for model in extra_models if model not in models])
if not cls.image_models:
url = "https://huggingface.co/api/models?pipeline_tag=text-to-image"
response = requests.get(url)
response.raise_for_status()
if response.ok:
cls.image_models = [model["id"] for model in response.json() if model.get("trendingScore", 0) >= 20]
cls.image_models.sort()
models.extend([model for model in cls.image_models if model not in models])
@@ -57,6 +57,7 @@ class HuggingFaceInference(AsyncGeneratorProvider, ProviderModelMixin):
prompt: str = None,
action: str = None,
extra_data: dict = {},
seed: int = None,
**kwargs
) -> AsyncResult:
try:
@@ -104,7 +105,7 @@ class HuggingFaceInference(AsyncGeneratorProvider, ProviderModelMixin):
if pipeline_tag == "text-to-image":
stream = False
inputs = format_image_prompt(messages, prompt)
payload = {"inputs": inputs, "parameters": {"seed": random.randint(0, 2**32), **extra_data}}
payload = {"inputs": inputs, "parameters": {"seed": random.randint(0, 2**32) if seed is None else seed, **extra_data}}
elif pipeline_tag in ("text-generation", "image-text-to-text"):
model_type = None
if "config" in model_data and "model_type" in model_data["config"]:
@@ -116,11 +117,13 @@ class HuggingFaceInference(AsyncGeneratorProvider, ProviderModelMixin):
if len(messages) > 6:
messages = messages[:3] + messages[-3:]
else:
messages = [m for m in messages if m["role"] == "system"] + [messages[-1]]
messages = [m for m in messages if m["role"] == "system"] + [get_last_user_message(messages)]
inputs = get_inputs(messages, model_data, model_type, do_continue)
debug.log(f"New len: {len(inputs)}")
if model_type == "gpt2" and max_tokens >= 1024:
params["max_new_tokens"] = 512
if seed is not None:
params["seed"] = seed
payload = {"inputs": inputs, "parameters": params, "stream": stream}
else:
raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__} pipeline_tag: {pipeline_tag}")

View File

@@ -48,7 +48,7 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
except Exception as e:
if is_started:
raise e
debug.log(f"Inference failed: {e.__class__.__name__}: {e}")
debug.error(f"{cls.__name__} {type(e).__name__}; {e}")
if not cls.image_models:
cls.get_models()
if model in cls.image_models:

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
import json
import uuid
import re
import time
import random
from datetime import datetime, timezone, timedelta
import urllib.parse
@@ -88,7 +88,7 @@ class Janus_Pro_7B(AsyncGeneratorProvider, ProviderModelMixin):
prompt = format_prompt(messages) if prompt is None and conversation is None else prompt
prompt = format_image_prompt(messages, prompt)
if seed is None:
seed = int(time.time())
seed = random.randint(1000, 999999)
session_hash = generate_session_hash() if conversation is None else getattr(conversation, "session_hash")
async with StreamSession(proxy=proxy, impersonate="chrome") as session:

View File

@@ -32,9 +32,7 @@ class CopilotAccount(AsyncAuthedProvider, Copilot):
except NoValidHarFileError as h:
debug.log(f"Copilot: {h}")
if has_nodriver:
login_url = os.environ.get("G4F_LOGIN_URL")
if login_url:
yield RequestLogin(cls.label, login_url)
yield RequestLogin(cls.label, os.environ.get("G4F_LOGIN_URL", ""))
Copilot._access_token, Copilot._cookies = await get_access_token_and_cookies(cls.url, proxy)
else:
raise h

View File

@@ -65,7 +65,7 @@ class Gemini(AsyncGeneratorProvider, ProviderModelMixin):
default_image_model = default_model
default_vision_model = default_model
image_models = [default_image_model]
models = [default_model, "gemini-1.5-flash", "gemini-1.5-pro"]
models = [default_model, "gemini-2.0"]
synthesize_content_type = "audio/vnd.wav"
@@ -179,7 +179,7 @@ class Gemini(AsyncGeneratorProvider, ProviderModelMixin):
yield Conversation(response_part[1][0], response_part[1][1], response_part[4][0][0])
content = response_part[4][0][1][0]
except (ValueError, KeyError, TypeError, IndexError) as e:
debug.log(f"{cls.__name__}:{e.__class__.__name__}:{e}")
debug.error(f"{cls.__name__} {type(e).__name__}: {e}")
continue
match = re.search(r'\[Imagen of (.*?)\]', content)
if match:

View File

@@ -51,7 +51,9 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
]
cls.models.sort()
except Exception as e:
debug.log(e)
debug.error(e)
if api_key is not None:
raise MissingAuthError("Invalid API key")
return cls.fallback_models
return cls.models
@@ -111,7 +113,17 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
"topK": kwargs.get("top_k"),
},
"tools": [{
"functionDeclarations": tools
"function_declarations": [{
"name": tool["function"]["name"],
"description": tool["function"]["description"],
"parameters": {
"type": "object",
"properties": {key: {
"type": value["type"],
"description": value["title"]
} for key, value in tool["function"]["parameters"]["properties"].items()}
},
} for tool in tools]
}] if tools else None
}
system_prompt = "\n".join(

View File

@@ -322,8 +322,8 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
try:
image_requests = await cls.upload_images(session, auth_result, images) if images else None
except Exception as e:
debug.log("OpenaiChat: Upload image failed")
debug.log(f"{e.__class__.__name__}: {e}")
debug.error("OpenaiChat: Upload image failed")
debug.error(e)
model = cls.get_model(model)
if conversation is None:
conversation = Conversation(conversation_id, str(uuid.uuid4()), getattr(auth_result, "cookies", {}).get("oai-did"))
@@ -360,12 +360,14 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
# if auth_result.arkose_token is None:
# raise MissingAuthError("No arkose token found in .har file")
if "proofofwork" in chat_requirements:
if getattr(auth_result, "proof_token", None) is None:
auth_result.proof_token = get_config(auth_result.headers.get("user-agent"))
user_agent = getattr(auth_result, "headers", {}).get("user-agent")
proof_token = getattr(auth_result, "proof_token", None)
if proof_token is None:
auth_result.proof_token = get_config(user_agent)
proofofwork = generate_proof_token(
**chat_requirements["proofofwork"],
user_agent=getattr(auth_result, "headers", {}).get("user-agent"),
proof_token=getattr(auth_result, "proof_token", None)
user_agent=user_agent,
proof_token=proof_token
)
[debug.log(text) for text in (
#f"Arkose: {'False' if not need_arkose else auth_result.arkose_token[:12]+'...'}",
@@ -425,8 +427,8 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
) as response:
cls._update_request_args(auth_result, session)
if response.status == 403:
auth_result.proof_token = None
cls.request_config.proof_token = None
raise MissingAuthError("Access token is not valid")
await raise_for_status(response)
buffer = u""
async for line in response.iter_lines():
@@ -469,14 +471,6 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
await asyncio.sleep(5)
else:
break
yield Parameters(**{
"action": "continue" if conversation.finish_reason == "max_tokens" else "variant",
"conversation": conversation.get_dict(),
"proof_token": cls.request_config.proof_token,
"cookies": cls._cookies,
"headers": cls._headers,
"web_search": web_search,
})
yield FinishReason(conversation.finish_reason)
@classmethod

View File

@@ -42,7 +42,7 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
if cls.sort_models:
cls.models.sort()
except Exception as e:
debug.log(e)
debug.error(e)
return cls.fallback_models
return cls.models
@@ -65,7 +65,7 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
prompt: str = None,
headers: dict = None,
impersonate: str = None,
tools: Optional[list] = None,
extra_parameters: list[str] = ["tools", "parallel_tool_calls", "", "reasoning_effort", "logit_bias"],
extra_data: dict = {},
**kwargs
) -> AsyncResult:
@@ -112,6 +112,7 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
}
]
messages[-1] = last_message
extra_parameters = {key: kwargs[key] for key in extra_parameters if key in kwargs}
data = filter_none(
messages=messages,
model=model,
@@ -120,7 +121,7 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
top_p=top_p,
stop=stop,
stream=stream,
tools=tools,
**extra_parameters,
**extra_data
)
if api_endpoint is None:

View File

@@ -588,7 +588,7 @@ class Api:
target=target)
debug.log(f"Image copied from {source_url}")
except Exception as e:
debug.log(f"{type(e).__name__}: Download failed: {source_url}\n{e}")
debug.error(f"Download failed: {source_url}\n{type(e).__name__}: {e}")
return RedirectResponse(url=source_url)
if not os.path.isfile(target):
return ErrorResponse.from_message("File not found", HTTP_404_NOT_FOUND)

View File

@@ -18,7 +18,7 @@ from ..providers.retry_provider import IterListProvider
from ..providers.asyncio import to_sync_generator
from ..Provider.needs_auth import BingCreateImages, OpenaiAccount
from ..tools.run_tools import async_iter_run_tools, iter_run_tools
from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse
from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse, UsageModel, ToolCallModel
from .image_models import ImageModels
from .types import IterResponse, ImageProvider, Client as BaseClient
from .service import get_model_and_provider, convert_to_provider
@@ -103,14 +103,14 @@ def iter_response(
idx += 1
if usage is None:
usage = Usage(prompt_tokens=0, completion_tokens=idx, total_tokens=idx)
usage = Usage(completion_tokens=idx, total_tokens=idx)
finish_reason = "stop" if finish_reason is None else finish_reason
if stream:
yield ChatCompletionChunk.model_construct(
None, finish_reason, completion_id, int(time.time()),
usage=usage.get_dict()
usage=usage
)
else:
if response_format is not None and "type" in response_format:
@@ -118,7 +118,8 @@ def iter_response(
content = filter_json(content)
yield ChatCompletion.model_construct(
content, finish_reason, completion_id, int(time.time()),
usage=usage.get_dict(), **filter_none(tool_calls=tool_calls)
usage=UsageModel.model_construct(**usage.get_dict()),
**filter_none(tool_calls=[ToolCallModel.model_construct(**tool_call) for tool_call in tool_calls]) if tool_calls is not None else {}
)
# Synchronous iter_append_model_and_provider function
@@ -186,7 +187,7 @@ async def async_iter_response(
finish_reason = "stop" if finish_reason is None else finish_reason
if usage is None:
usage = Usage(prompt_tokens=0, completion_tokens=idx, total_tokens=idx)
usage = Usage(completion_tokens=idx, total_tokens=idx)
if stream:
yield ChatCompletionChunk.model_construct(
@@ -199,7 +200,8 @@ async def async_iter_response(
content = filter_json(content)
yield ChatCompletion.model_construct(
content, finish_reason, completion_id, int(time.time()),
usage=usage.get_dict(), **filter_none(tool_calls=tool_calls)
usage=UsageModel.model_construct(**usage.get_dict()),
**filter_none(tool_calls=[ToolCallModel.model_construct(**tool_call) for tool_call in tool_calls]) if tool_calls is not None else {}
)
finally:
await safe_aclose(response)
@@ -363,7 +365,7 @@ class Images:
break
except Exception as e:
error = e
debug.log(f"Image provider {provider.__name__}: {e}")
debug.error(e, name=f"{provider.__name__} {type(e).__name__}")
else:
response = await self._generate_image_response(provider_handler, provider_name, model, prompt, **kwargs)
@@ -458,7 +460,7 @@ class Images:
break
except Exception as e:
error = e
debug.log(f"Image provider {provider.__name__}: {e}")
debug.error(e, name=f"{provider.__name__} {type(e).__name__}")
else:
response = await self._generate_image_response(provider_handler, provider_name, model, prompt, **kwargs)
@@ -583,7 +585,7 @@ class AsyncCompletions:
messages: Messages,
model: str,
**kwargs
) -> AsyncIterator[ChatCompletionChunk, BaseConversation]:
) -> AsyncIterator[ChatCompletionChunk]:
return self.create(messages, model, stream=True, **kwargs)
class AsyncImages(Images):

View File

@@ -5,9 +5,6 @@ from time import time
from .helper import filter_none
ToolCalls = Optional[List[Dict[str, Any]]]
Usage = Optional[Dict[str, int]]
try:
from pydantic import BaseModel, Field
except ImportError:
@@ -29,6 +26,40 @@ class BaseModel(BaseModel):
return super().model_construct(**data)
return cls.construct(**data)
class UsageModel(BaseModel):
prompt_tokens: int
completion_tokens: int
total_tokens: int
prompt_tokens_details: Optional[Dict[str, Any]]
completion_tokens_details: Optional[Dict[str, Any]]
@classmethod
def model_construct(cls, prompt_tokens=0, completion_tokens=0, total_tokens=0, prompt_tokens_details=None, completion_tokens_details=None, **kwargs):
return super().model_construct(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
prompt_tokens_details=prompt_tokens_details,
completion_tokens_details=completion_tokens_details,
**kwargs
)
class ToolFunctionModel(BaseModel):
name: str
arguments: str
class ToolCallModel(BaseModel):
id: str
type: str
function: ToolFunctionModel
@classmethod
def model_construct(cls, function=None, **kwargs):
return super().model_construct(
**kwargs,
function=ToolFunctionModel.model_construct(**function),
)
class ChatCompletionChunk(BaseModel):
id: str
object: str
@@ -36,7 +67,7 @@ class ChatCompletionChunk(BaseModel):
model: str
provider: Optional[str]
choices: List[ChatCompletionDeltaChoice]
usage: Usage
usage: UsageModel
@classmethod
def model_construct(
@@ -45,7 +76,7 @@ class ChatCompletionChunk(BaseModel):
finish_reason: str,
completion_id: str = None,
created: int = None,
usage: Usage = None
usage: UsageModel = None
):
return super().model_construct(
id=f"chatcmpl-{completion_id}" if completion_id else None,
@@ -63,10 +94,10 @@ class ChatCompletionChunk(BaseModel):
class ChatCompletionMessage(BaseModel):
role: str
content: str
tool_calls: ToolCalls
tool_calls: list[ToolCallModel] = None
@classmethod
def model_construct(cls, content: str, tool_calls: ToolCalls = None):
def model_construct(cls, content: str, tool_calls: list = None):
return super().model_construct(role="assistant", content=content, **filter_none(tool_calls=tool_calls))
class ChatCompletionChoice(BaseModel):
@@ -85,11 +116,7 @@ class ChatCompletion(BaseModel):
model: str
provider: Optional[str]
choices: List[ChatCompletionChoice]
usage: Usage = Field(default={
"prompt_tokens": 0, #prompt_tokens,
"completion_tokens": 0, #completion_tokens,
"total_tokens": 0, #prompt_tokens + completion_tokens,
})
usage: UsageModel
@classmethod
def model_construct(
@@ -98,8 +125,8 @@ class ChatCompletion(BaseModel):
finish_reason: str,
completion_id: str = None,
created: int = None,
tool_calls: ToolCalls = None,
usage: Usage = None
tool_calls: list[ToolCallModel] = None,
usage: UsageModel = None
):
return super().model_construct(
id=f"chatcmpl-{completion_id}" if completion_id else None,

View File

@@ -1,3 +1,4 @@
import sys
from .providers.types import ProviderType
logging: bool = False
@@ -8,6 +9,9 @@ version: str = None
log_handler: callable = print
logs: list = []
def log(text):
def log(text, file = None):
if logging:
log_handler(text)
log_handler(text, file=file)
def error(error, name: str = None):
log(error if isinstance(error, str) else f"{type(error).__name__ if name is None else name}: {error}", file=sys.stderr)

View File

@@ -181,7 +181,12 @@
<script>
(async () => {
const isIframe = window.self !== window.top;
const backendUrl = "{{backend_url}}";
let url = new URL(window.location.href)
if (isIframe && backendUrl) {
window.location.replace(url.search ? `${backendUrl}?${url.search}` : backendUrl);
return;
}
let params = new URLSearchParams(url.search);
if (params.get("__sign")) {
localStorage.setItem("zerogpu_token", params.get("__sign"));
@@ -232,12 +237,11 @@
import * as hub from "@huggingface/hub";
import { init } from "@huggingface/space-header";
const isIframe = window.self !== window.top;
const button = document.querySelector('form a.button');
if (isIframe) {
button.classList.remove('hidden');
} else {
init("roxky/g4f-space");
init("roxky/g4f-space-new");
}
const form = document.querySelector("form");
@@ -282,13 +286,11 @@
const cache_id = Math.floor(Math.random() * max);
let prompt;
if (cache_id % 2 == 0) {
prompt = `
Today is ${new Date().toJSON().slice(0, 10)}.
prompt = `Today is ${new Date().toJSON().slice(0, 10)}.
Create a single-page HTML screensaver reflecting the current season (based on the date).
Avoid using any text.`;
} else {
prompt = `Create a single-page HTML screensaver. Avoid using any text.`;
const response = await fetch(`/backend-api/v2/create?prompt=${prompt}&filter_markdown=html&cache=${cache_id}`);
}
const response = await fetch(`/backend-api/v2/create?prompt=${prompt}&filter_markdown=html&cache=${cache_id}`);
const text = await response.text()

View File

@@ -293,7 +293,7 @@
</div>
<div class="field">
<select name="model" id="model">
<option value="">Model: Default</option>
<option value="" selected="selected">Model: Default</option>
<option value="gpt-4">gpt-4</option>
<option value="gpt-4o">gpt-4o</option>
<option value="gpt-4o-mini">gpt-4o-mini</option>

View File

@@ -72,13 +72,16 @@ class Api:
@staticmethod
def get_version() -> dict:
current_version = None
latest_version = None
try:
current_version = version.utils.current_version
latest_version = version.utils.latest_version
except VersionNotFoundError:
current_version = None
pass
return {
"version": current_version,
"latest_version": version.utils.latest_version,
"latest_version": latest_version,
}
def serve_images(self, name):
@@ -137,10 +140,10 @@ class Api:
}
def _create_response_stream(self, kwargs: dict, conversation_id: str, provider: str, download_images: bool = True) -> Iterator:
def decorated_log(text: str):
def decorated_log(text: str, file = None):
debug.logs.append(text)
if debug.logging:
debug.log_handler(text)
debug.log_handler(text, file)
debug.log = decorated_log
proxy = os.environ.get("G4F_PROXY")
provider = kwargs.get("provider")
@@ -154,6 +157,7 @@ class Api:
)
except Exception as e:
logger.exception(e)
debug.error(e)
yield self._format_json('error', type(e).__name__, message=get_error_message(e))
return
if not isinstance(provider_handler, BaseRetryProvider):
@@ -183,6 +187,7 @@ class Api:
yield self._format_json("conversation_id", conversation_id)
elif isinstance(chunk, Exception):
logger.exception(chunk)
debug.error(e)
yield self._format_json('message', get_error_message(chunk), error=type(chunk).__name__)
elif isinstance(chunk, PreviewResponse):
yield self._format_json("preview", chunk.to_string())
@@ -215,19 +220,18 @@ class Api:
yield self._format_json(chunk.type, **chunk.get_dict())
else:
yield self._format_json("content", str(chunk))
if debug.logs:
for log in debug.logs:
yield self._format_json("log", str(log))
debug.logs = []
yield from self._yield_logs()
except Exception as e:
logger.exception(e)
if debug.logging:
debug.log_handler(get_error_message(e))
debug.error(e)
yield from self._yield_logs()
yield self._format_json('error', type(e).__name__, message=get_error_message(e))
def _yield_logs(self):
if debug.logs:
for log in debug.logs:
yield self._format_json("log", str(log))
yield self._format_json("log", log)
debug.logs = []
yield self._format_json('error', type(e).__name__, message=get_error_message(e))
def _format_json(self, response_type: str, content = None, **kwargs):
if content is not None and isinstance(response_type, str):

View File

@@ -76,7 +76,7 @@ class Backend_Api(Api):
@app.route('/', methods=['GET'])
@limiter.exempt
def home():
return render_template('demo.html')
return render_template('demo.html', backend_url=os.environ.get("G4F_BACKEND_URL", ""))
else:
@app.route('/', methods=['GET'])
def home():

View File

@@ -123,7 +123,7 @@ async def copy_images(
return f"/images/{url_filename}{'?url=' + quote(image) if add_url and not image.startswith('data:') else ''}"
except (ClientError, IOError, OSError) as e:
debug.log(f"Image processing failed: {e.__class__.__name__}: {e}")
debug.error(f"Image processing failed: {type(e).__name__}: {e}")
if target_path and os.path.exists(target_path):
os.unlink(target_path)
return get_source_url(image, image)

View File

@@ -243,7 +243,7 @@ llama_3_3_70b = Model(
mixtral_8x7b = Model(
name = "mixtral-8x7b",
base_provider = "Mistral",
best_provider = IterListProvider([DDG, Jmuz])
best_provider = Jmuz
)
mixtral_8x22b = Model(
name = "mixtral-8x22b",
@@ -300,7 +300,7 @@ wizardlm_2_8x22b = Model(
### Google DeepMind ###
# gemini
gemini = Model(
name = 'gemini',
name = 'gemini-2.0',
base_provider = 'Google',
best_provider = Gemini
)
@@ -316,13 +316,13 @@ gemini_exp = Model(
gemini_1_5_flash = Model(
name = 'gemini-1.5-flash',
base_provider = 'Google DeepMind',
best_provider = IterListProvider([Blackbox, Jmuz, Gemini, GeminiPro, Liaobots])
best_provider = IterListProvider([Blackbox, Jmuz, GeminiPro, Liaobots])
)
gemini_1_5_pro = Model(
name = 'gemini-1.5-pro',
base_provider = 'Google DeepMind',
best_provider = IterListProvider([Blackbox, Jmuz, Gemini, GeminiPro, Liaobots])
best_provider = IterListProvider([Blackbox, Jmuz, GeminiPro, Liaobots])
)
# gemini-2.0
@@ -713,6 +713,7 @@ class ModelUtils:
### Google ###
### Gemini
"gemini": gemini,
gemini.name: gemini,
gemini_exp.name: gemini_exp,
gemini_1_5_pro.name: gemini_1_5_pro,
@@ -812,7 +813,6 @@ class ModelUtils:
demo_models = {
gpt_4o.name: [gpt_4o, [PollinationsAI, Blackbox]],
"default": [llama_3_2_11b, [HuggingFace]],
qwen_2_vl_7b.name: [qwen_2_vl_7b, [HuggingFaceAPI]],
qvq_72b.name: [qvq_72b, [HuggingSpace]],

View File

@@ -409,6 +409,14 @@ class AsyncAuthedProvider(AsyncGeneratorProvider):
def get_cache_file(cls) -> Path:
return Path(get_cookies_dir()) / f"auth_{cls.parent if hasattr(cls, 'parent') else cls.__name__}.json"
@classmethod
def write_cache_file(cls, cache_file: Path, auth_result: AuthResult = None):
if auth_result is not None:
cache_file.parent.mkdir(parents=True, exist_ok=True)
cache_file.write_text(json.dumps(auth_result.get_dict()))
elif cache_file.exists():
cache_file.unlink()
@classmethod
def create_completion(
cls,
@@ -416,35 +424,25 @@ class AsyncAuthedProvider(AsyncGeneratorProvider):
messages: Messages,
**kwargs
) -> CreateResult:
auth_result = AuthResult()
auth_result: AuthResult = None
cache_file = cls.get_cache_file()
try:
if cache_file.exists():
with cache_file.open("r") as f:
auth_result = AuthResult(**json.load(f))
else:
auth_result = cls.on_auth(**kwargs)
for chunk in auth_result:
if hasattr(chunk, "get_dict"):
auth_result = chunk
else:
yield chunk
raise MissingAuthError
yield from to_sync_generator(cls.create_authed(model, messages, auth_result, **kwargs))
except (MissingAuthError, NoValidHarFileError):
auth_result = cls.on_auth(**kwargs)
for chunk in auth_result:
if hasattr(chunk, "get_dict"):
if isinstance(chunk, AuthResult):
auth_result = chunk
else:
yield chunk
yield from to_sync_generator(cls.create_authed(model, messages, auth_result, **kwargs))
finally:
if hasattr(auth_result, "get_dict"):
data = auth_result.get_dict()
cache_file.parent.mkdir(parents=True, exist_ok=True)
cache_file.write_text(json.dumps(data))
elif cache_file.exists():
cache_file.unlink()
cls.write_cache_file(cache_file, auth_result)
@classmethod
async def create_async_generator(
@@ -453,19 +451,14 @@ class AsyncAuthedProvider(AsyncGeneratorProvider):
messages: Messages,
**kwargs
) -> AsyncResult:
auth_result: AuthResult = None
cache_file = cls.get_cache_file()
try:
auth_result = AuthResult()
cache_file = Path(get_cookies_dir()) / f"auth_{cls.parent if hasattr(cls, 'parent') else cls.__name__}.json"
if cache_file.exists():
with cache_file.open("r") as f:
auth_result = AuthResult(**json.load(f))
else:
auth_result = cls.on_auth_async(**kwargs)
async for chunk in auth_result:
if hasattr(chunk, "get_dict"):
auth_result = chunk
else:
yield chunk
raise MissingAuthError
response = to_async_iterator(cls.create_authed(model, messages, **kwargs, auth_result=auth_result))
async for chunk in response:
yield chunk
@@ -474,16 +467,16 @@ class AsyncAuthedProvider(AsyncGeneratorProvider):
cache_file.unlink()
auth_result = cls.on_auth_async(**kwargs)
async for chunk in auth_result:
if hasattr(chunk, "get_dict"):
if isinstance(chunk, AuthResult):
auth_result = chunk
else:
yield chunk
response = to_async_iterator(cls.create_authed(model, messages, **kwargs, auth_result=auth_result))
async for chunk in response:
if cache_file is not None:
cls.write_cache_file(cache_file, auth_result)
cache_file = None
yield chunk
finally:
if hasattr(auth_result, "get_dict"):
cache_file.parent.mkdir(parents=True, exist_ok=True)
cache_file.write_text(json.dumps(auth_result.get_dict()))
elif cache_file.exists():
cache_file.unlink()
if cache_file is not None:
cls.write_cache_file(cache_file, auth_result)

View File

@@ -19,7 +19,6 @@ def quote_url(url: str) -> str:
def quote_title(title: str) -> str:
if title:
title = title.strip()
title = " ".join(title.split())
return title.replace('[', '').replace(']', '')
return ""
@@ -154,6 +153,7 @@ class Sources(ResponseType):
self.add_source(source)
def add_source(self, source: dict[str, str]):
source = source if isinstance(source, dict) else {"url": source}
url = source.get("url", source.get("link", None))
if url is not None:
url = re.sub(r"[&?]utm_source=.+", "", url)

View File

@@ -65,7 +65,7 @@ class IterListProvider(BaseRetryProvider):
return
except Exception as e:
exceptions[provider.__name__] = e
debug.log(f"{provider.__name__}: {e.__class__.__name__}: {e}")
debug.error(f"{provider.__name__} {type(e).__name__}: {e}")
if started:
raise e
yield e
@@ -105,7 +105,7 @@ class IterListProvider(BaseRetryProvider):
return
except Exception as e:
exceptions[provider.__name__] = e
debug.log(f"{provider.__name__}: {e.__class__.__name__}: {e}")
debug.error(name=f"{provider.__name__} {type(e).__name__}: {e}")
if started:
raise e
yield e

View File

@@ -2,12 +2,12 @@ from __future__ import annotations
from curl_cffi.requests import AsyncSession, Response
try:
from curl_cffi.requests import CurlMime
from curl_cffi import CurlMime
has_curl_mime = True
except ImportError:
has_curl_mime = False
try:
from curl_cffi.requests import CurlWsFlag
from curl_cffi import CurlWsFlag
has_curl_ws = True
except ImportError:
has_curl_ws = False
@@ -73,7 +73,7 @@ class StreamSession(AsyncSession):
def request(
self, method: str, url: str, ssl = None, **kwargs
) -> StreamResponse:
if isinstance(kwargs.get("data"), CurlMime):
if kwargs.get("data") and isinstance(kwargs.get("data"), CurlMime):
kwargs["multipart"] = kwargs.pop("data")
"""Create and return a StreamResponse object for the given HTTP request."""
return StreamResponse(super().request(method, url, stream=True, verify=ssl, **kwargs))
@@ -100,12 +100,12 @@ if has_curl_mime:
else:
class FormData():
def __init__(self) -> None:
raise RuntimeError("CurlMimi in curl_cffi is missing | pip install -U g4f[curl_cffi]")
raise RuntimeError("CurlMimi in curl_cffi is missing | pip install -U curl_cffi")
class WebSocket():
def __init__(self, session, url, **kwargs) -> None:
if not has_curl_ws:
raise RuntimeError("CurlWsFlag in curl_cffi is missing | pip install -U g4f[curl_cffi]")
raise RuntimeError("CurlWsFlag in curl_cffi is missing | pip install -U curl_cffi")
self.session: StreamSession = session
self.url: str = url
del kwargs["autoping"]
@@ -116,11 +116,13 @@ class WebSocket():
return self
async def __aexit__(self, *args):
await self.inner.aclose()
await self.inner.aclose() if hasattr(self.inner, "aclose") else await self.inner.close()
async def receive_str(self, **kwargs) -> str:
bytes, _ = await self.inner.arecv()
method = self.inner.arecv if hasattr(self.inner, "arecv") else self.inner.recv
bytes, _ = await method()
return bytes.decode(errors="ignore")
async def send_str(self, data: str):
await self.inner.asend(data.encode(), CurlWsFlag.TEXT)
method = self.inner.asend if hasattr(self.inner, "asend") else self.inner.send
await method(data.encode(), CurlWsFlag.TEXT)

71
g4f/tools/pydantic_ai.py Normal file
View File

@@ -0,0 +1,71 @@
from __future__ import annotations
from typing import Optional
from functools import partial
from dataclasses import dataclass, field
from pydantic_ai.models import Model, KnownModelName, infer_model
from pydantic_ai.models.openai import OpenAIModel, OpenAISystemPromptRole
from ..client import AsyncClient
@dataclass(init=False)
class AIModel(OpenAIModel):
"""A model that uses the G4F API."""
client: AsyncClient = field(repr=False)
system_prompt_role: OpenAISystemPromptRole | None = field(default=None)
_model_name: str = field(repr=False)
_provider: str = field(repr=False)
_system: Optional[str] = field(repr=False)
def __init__(
self,
model_name: str,
provider: str | None = None,
*,
system_prompt_role: OpenAISystemPromptRole | None = None,
system: str | None = 'openai',
**kwargs
):
"""Initialize an AI model.
Args:
model_name: The name of the AI model to use. List of model names available
[here](https://github.com/openai/openai-python/blob/v1.54.3/src/openai/types/chat_model.py#L7)
(Unfortunately, despite being ask to do so, OpenAI do not provide `.inv` files for their API).
system_prompt_role: The role to use for the system prompt message. If not provided, defaults to `'system'`.
In the future, this may be inferred from the model name.
system: The model provider used, defaults to `openai`. This is for observability purposes, you must
customize the `base_url` and `api_key` to use a different provider.
"""
self._model_name = model_name
self._provider = provider
self.client = AsyncClient(provider=provider, **kwargs)
self.system_prompt_role = system_prompt_role
self._system = system
def name(self) -> str:
if self._provider:
return f'g4f:{self._provider}:{self._model_name}'
return f'g4f:{self._model_name}'
def new_infer_model(model: Model | KnownModelName, api_key: str = None) -> Model:
if isinstance(model, Model):
return model
if model.startswith("g4f:"):
model = model[4:]
if ":" in model:
provider, model = model.split(":", 1)
return AIModel(model, provider=provider, api_key=api_key)
return AIModel(model)
return infer_model(model)
def apply_patch(api_key: str | None = None):
import pydantic_ai.models
import pydantic_ai.models.openai
pydantic_ai.models.infer_model = partial(new_infer_model, api_key=api_key)
pydantic_ai.models.AIModel = AIModel
pydantic_ai.models.openai.NOT_GIVEN = None

View File

@@ -44,7 +44,7 @@ async def async_iter_run_tools(provider: ProviderType, model: str, messages, too
web_search = web_search if isinstance(web_search, str) and web_search != "true" else None
messages[-1]["content"] = await do_search(messages[-1]["content"], web_search)
except Exception as e:
debug.log(f"Couldn't do web search: {e.__class__.__name__}: {e}")
debug.error(f"Couldn't do web search: {e.__class__.__name__}: {e}")
# Keep web_search in kwargs for provider native support
pass
@@ -82,6 +82,7 @@ async def async_iter_run_tools(provider: ProviderType, model: str, messages, too
has_bucket = True
message["content"] = new_message_content
if has_bucket and isinstance(messages[-1]["content"], str):
if "\nSource: " in messages[-1]["content"]:
messages[-1]["content"] += BUCKET_INSTRUCTIONS
create_function = provider.get_async_create_function()
response = to_async_iterator(create_function(model=model, messages=messages, **kwargs))
@@ -149,7 +150,7 @@ def iter_run_tools(
web_search = web_search if isinstance(web_search, str) and web_search != "true" else None
messages[-1]["content"] = asyncio.run(do_search(messages[-1]["content"], web_search))
except Exception as e:
debug.log(f"Couldn't do web search: {e.__class__.__name__}: {e}")
debug.error(f"Couldn't do web search: {e.__class__.__name__}: {e}")
# Keep web_search in kwargs for provider native support
pass
@@ -192,7 +193,8 @@ def iter_run_tools(
has_bucket = True
message["content"] = new_message_content
if has_bucket and isinstance(messages[-1]["content"], str):
messages[-1]["content"] += BUCKET_INSTRUCTIONS
if "\nSource: " in messages[-1]["content"]:
messages[-1]["content"] = messages[-1]["content"]["content"] + BUCKET_INSTRUCTIONS
thinking_start_time = 0
for chunk in iter_callback(model=model, messages=messages, provider=provider, **kwargs):

View File

@@ -237,7 +237,7 @@ def get_search_message(prompt: str, raise_search_exceptions=False, **kwargs) ->
except (DuckDuckGoSearchException, MissingRequirementsError) as e:
if raise_search_exceptions:
raise e
debug.log(f"Couldn't do web search: {e.__class__.__name__}: {e}")
debug.error(f"Couldn't do web search: {e.__class__.__name__}: {e}")
return prompt
def spacy_get_keywords(text: str):

View File

@@ -44,8 +44,9 @@ def get_github_version(repo: str) -> str:
VersionNotFoundError: If there is an error in fetching the version from GitHub.
"""
try:
response = requests.get(f"https://api.github.com/repos/{repo}/releases/latest").json()
return response["tag_name"]
response = requests.get(f"https://api.github.com/repos/{repo}/releases/latest")
response.raise_for_status()
return response.json()["tag_name"]
except requests.RequestException as e:
raise VersionNotFoundError(f"Failed to get GitHub release version: {e}")