# Mirror of https://github.com/xtekky/gpt4free.git
# Synced 2025-12-24 13:07:53 +08:00
import os
|
|
import json
|
|
import base64
|
|
import time
|
|
from pathlib import Path
|
|
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
|
|
|
|
import aiohttp
|
|
from aiohttp import ClientSession, ClientTimeout
|
|
|
|
from ...typing import AsyncResult, Messages, MediaListType
|
|
from ...errors import MissingAuthError
|
|
from ...image.copy_images import save_response_media
|
|
from ...image import to_bytes, is_data_an_media
|
|
from ...providers.response import Usage, ImageResponse, ToolCalls, Reasoning
|
|
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, AuthFileMixin
|
|
from ..helper import get_connector, get_system_prompt, format_media_prompt
|
|
from ... import debug
|
|
|
|
def get_oauth_creds_path():
    """Location of the Gemini CLI's own OAuth credential file (~/.gemini/oauth_creds.json)."""
    return Path.home().joinpath(".gemini", "oauth_creds.json")
|
|
|
|
class AuthManager(AuthFileMixin):
    """
    Handles OAuth2 authentication and Google Code Assist API communication.

    Manages token caching, refresh, and API calls.

    Requires environment dict-like object with keys:
    - GCP_SERVICE_ACCOUNT: JSON string with OAuth2 credentials, containing:
        access_token, expiry_date (ms timestamp), refresh_token
    - Optionally supports cache storage via a KV storage interface implementing:
        get(key) -> value or None,
        put(key, value, expiration_seconds),
        delete(key)
    """
    # Provider name used by AuthFileMixin to locate the cache file
    # (presumably — AuthFileMixin is defined elsewhere; confirm against its implementation).
    parent = "GeminiCLI"

    OAUTH_REFRESH_URL = "https://oauth2.googleapis.com/token"
    OAUTH_CLIENT_ID = "681255809395" + "-oo8ft2oprdrnp9e3aqf6av3hmdib135j" + ".apps.googleusercontent.com"
    OAUTH_CLIENT_SECRET = "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"
    TOKEN_BUFFER_TIME = 5 * 60  # seconds, 5 minutes: tokens expiring sooner than this are treated as stale
    KV_TOKEN_KEY = "oauth_token_cache"

    def __init__(self, env: Dict[str, Any]):
        self.env = env
        self._access_token: Optional[str] = None
        self._expiry: Optional[float] = None  # Unix timestamp in seconds
        self._token_cache: Dict[str, Any] = {}  # Example in-memory cache; replace with KV store for production

    async def initialize_auth(self) -> None:
        """
        Initialize authentication by using cached token, or refreshing if needed.

        Resolution order: in-memory/KV cache -> provider cache file -> the Gemini
        CLI's ~/.gemini/oauth_creds.json -> GCP_SERVICE_ACCOUNT env var. If the
        stored access token is stale, it is refreshed via the refresh token.

        Raises RuntimeError if no valid token can be obtained.
        """
        # Try cached token from KV store or in-memory cache
        cached = await self._get_cached_token()
        now = time.time()
        if cached:
            expires_at = cached["expiry_date"] / 1000  # ms to seconds
            if expires_at - now > self.TOKEN_BUFFER_TIME:
                self._access_token = cached["access_token"]
                self._expiry = expires_at
                return  # Use cached token if valid

        # Load credentials from disk, falling back to the environment.
        path = AuthManager.get_cache_file()
        if not path.exists():
            path = get_oauth_creds_path()
        if path.exists():
            try:
                with path.open("r") as f:
                    creds = json.load(f)
            except Exception as e:
                raise RuntimeError(f"Failed to read OAuth credentials from {path}: {e}")
        else:
            # Parse credentials from environment
            if "GCP_SERVICE_ACCOUNT" not in self.env:
                raise RuntimeError("GCP_SERVICE_ACCOUNT environment variable not set.")
            creds = json.loads(self.env["GCP_SERVICE_ACCOUNT"])

        refresh_token = creds.get("refresh_token")
        access_token = creds.get("access_token")
        expiry_date = creds.get("expiry_date")  # milliseconds since epoch

        # Use original access token if still valid
        if access_token and expiry_date:
            expires_at = expiry_date / 1000
            if expires_at - now > self.TOKEN_BUFFER_TIME:
                self._access_token = access_token
                self._expiry = expires_at
                await self._cache_token(access_token, expiry_date)
                return

        # Otherwise, refresh token
        if not refresh_token:
            raise RuntimeError("No refresh token found in GCP_SERVICE_ACCOUNT.")

        await self._refresh_and_cache_token(refresh_token)

    async def _refresh_and_cache_token(self, refresh_token: str) -> None:
        """Exchange *refresh_token* for a fresh access token and cache it.

        Raises RuntimeError on a non-200 response or a response without an
        access_token.
        """
        headers = {"Content-Type": "application/x-www-form-urlencoded"}
        data = {
            "client_id": self.OAUTH_CLIENT_ID,
            "client_secret": self.OAUTH_CLIENT_SECRET,
            "refresh_token": refresh_token,
            "grant_type": "refresh_token",
        }

        async with aiohttp.ClientSession() as session:
            async with session.post(self.OAUTH_REFRESH_URL, data=data, headers=headers) as resp:
                if resp.status != 200:
                    text = await resp.text()
                    raise RuntimeError(f"Token refresh failed: {text}")
                resp_data = await resp.json()

        access_token = resp_data.get("access_token")
        expires_in = resp_data.get("expires_in", 3600)  # seconds

        if not access_token:
            raise RuntimeError("No access_token in refresh response.")

        self._access_token = access_token
        self._expiry = time.time() + expires_in

        expiry_date_ms = int(self._expiry * 1000)  # milliseconds

        await self._cache_token(access_token, expiry_date_ms)

    async def _cache_token(self, access_token: str, expiry_date: int) -> None:
        """Store *access_token* (expiry in ms) in the KV cache (in-memory fallback)."""
        token_data = {
            "access_token": access_token,
            "expiry_date": expiry_date,
            "cached_at": int(time.time() * 1000),  # ms
        }
        self._token_cache[self.KV_TOKEN_KEY] = token_data

    async def _get_cached_token(self) -> Optional[Dict[str, Any]]:
        """Return the in-memory cached token dict if present and not stale, else None."""
        cached = self._token_cache.get(self.KV_TOKEN_KEY)
        if cached:
            expires_at = cached["expiry_date"] / 1000
            if expires_at - time.time() > self.TOKEN_BUFFER_TIME:
                return cached
        return None

    async def clear_token_cache(self) -> None:
        """Drop the current token AND its cached copy so the next call re-authenticates.

        Fix: the in-memory KV entry must be purged as well; otherwise the 401
        retry path (clear -> initialize_auth) would immediately restore the same
        server-rejected token from the cache and the retry could never succeed.
        """
        self._access_token = None
        self._expiry = None
        self._token_cache.pop(self.KV_TOKEN_KEY, None)

    def get_access_token(self) -> Optional[str]:
        """Return the current access token if it is valid past the buffer window, else None."""
        if (
            self._access_token is not None
            and self._expiry is not None
            and self._expiry - time.time() > self.TOKEN_BUFFER_TIME
        ):
            return self._access_token
        return None

    async def call_endpoint(self, method: str, body: Dict[str, Any], is_retry=False) -> Any:
        """
        Call Google Code Assist API endpoint with JSON body.

        Automatically retries once on 401 Unauthorized by refreshing auth.

        Raises RuntimeError on any other non-OK status.
        """
        if not self.get_access_token():
            await self.initialize_auth()

        url = f"https://cloudcode-pa.googleapis.com/v1internal:{method}"
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.get_access_token()}",
        }

        async with aiohttp.ClientSession() as session:
            async with session.post(url, headers=headers, json=body) as resp:
                if resp.status == 401 and not is_retry:
                    # Token likely expired, clear and retry once
                    await self.clear_token_cache()
                    await self.initialize_auth()
                    return await self.call_endpoint(method, body, is_retry=True)
                elif not resp.ok:
                    text = await resp.text()
                    raise RuntimeError(f"API call failed with status {resp.status}: {text}")

                return await resp.json()
|
|
|
|
class GeminiCLIProvider():
    """Low-level client for the Google Code Assist streaming API (Gemini CLI backend).

    Converts OpenAI-style messages to Gemini "contents", discovers the GCP
    project ID, and streams SSE responses from streamGenerateContent.
    """
    url = "https://cloud.google.com/code-assist"
    api_base = "https://cloudcode-pa.googleapis.com/v1internal"

    # Required for authentication and token management; Expects a compatible AuthManager instance
    auth_manager: AuthManager
    env: dict

    def __init__(self, env: dict, auth_manager: AuthManager):
        self.env = env
        self.auth_manager = auth_manager

        # Cache for discovered project ID
        self._project_id: Optional[str] = None

    async def discover_project_id(self) -> str:
        """Resolve the project ID: GEMINI_PROJECT_ID env var, then the per-instance
        cache, then the loadCodeAssist endpoint.

        Raises RuntimeError when no project ID can be determined.
        """
        if self.env.get("GEMINI_PROJECT_ID"):
            return self.env["GEMINI_PROJECT_ID"]
        if self._project_id:
            return self._project_id

        try:
            load_response = await self.auth_manager.call_endpoint(
                "loadCodeAssist",
                {
                    "cloudaicompanionProject": "default-project",
                    "metadata": {"duetProject": "default-project"},
                },
            )
            project = load_response.get("cloudaicompanionProject")
            if project:
                self._project_id = project
                return project
            raise RuntimeError(
                "Project ID discovery failed - set GEMINI_PROJECT_ID in environment."
            )
        except Exception as e:
            debug.error(f"Failed to discover project ID: {e}")
            raise RuntimeError(
                "Could not discover project ID. Ensure authentication or set GEMINI_PROJECT_ID."
            )

    @staticmethod
    def _messages_to_gemini_format(messages: list, media: MediaListType) -> List[Dict[str, Any]]:
        """Convert OpenAI-style chat messages (plus optional media) to Gemini contents.

        Returns a list of {"role": ..., "parts": [...]} dicts. Media attachments
        are appended to the last message's parts; a synthetic user message is
        created when the conversation is empty.
        """
        format_messages = []
        for msg in messages:
            # Gemini only distinguishes "user" and "model" roles.
            role = "model" if msg["role"] == "assistant" else "user"

            # Handle tool role (OpenAI style): becomes a functionResponse part.
            if msg["role"] == "tool":
                # Fix: the original had a trailing comma after this list, which
                # silently turned `parts` into a one-element tuple.
                parts = [
                    {
                        "functionResponse": {
                            "name": msg.get("tool_call_id", "unknown_function"),
                            "response": {
                                "result": (
                                    msg["content"]
                                    if isinstance(msg["content"], str)
                                    else json.dumps(msg["content"])
                                )
                            },
                        }
                    }
                ]

            # Handle assistant messages with tool calls
            elif msg["role"] == "assistant" and msg.get("tool_calls"):
                parts = []
                if isinstance(msg["content"], str) and msg["content"].strip():
                    parts.append({"text": msg["content"]})
                for tool_call in msg["tool_calls"]:
                    if tool_call.get("type") == "function":
                        parts.append(
                            {
                                "functionCall": {
                                    "name": tool_call["function"]["name"],
                                    "args": json.loads(tool_call["function"]["arguments"]),
                                }
                            }
                        )

            # Handle string content
            elif isinstance(msg["content"], str):
                parts = [{"text": msg["content"]}]

            # Handle array content (possibly multimodal)
            elif isinstance(msg["content"], list):
                # Fix: `parts` was used without initialization in this branch.
                parts = []
                for content in msg["content"]:
                    ctype = content.get("type")
                    if ctype == "text":
                        parts.append({"text": content["text"]})
                    elif ctype == "image_url":
                        image_url = content.get("image_url", {}).get("url")
                        if not image_url:
                            continue
                        if image_url.startswith("data:"):
                            # Inline base64 data image: data:<mime>[;base64],<data>
                            prefix, b64data = image_url.split(",", 1)
                            mime_type = prefix.split(":")[1].split(";")[0]
                            parts.append({"inlineData": {"mimeType": mime_type, "data": b64data}})
                        else:
                            parts.append(
                                {
                                    "fileData": {
                                        "mimeType": "image/jpeg",  # Could improve by validation
                                        "fileUri": image_url,
                                    }
                                }
                            )
            else:
                # Fallback: stringify anything else (None, numbers, ...)
                parts = [{"text": str(msg["content"])}]
            format_messages.append({"role": role, "parts": parts})
        if media:
            if not format_messages:
                format_messages.append({"role": "user", "parts": []})
            for media_data, filename in media:
                if isinstance(media_data, str):
                    # String entries are treated as file URIs/paths.
                    if not filename:
                        filename = media_data
                    extension = filename.split(".")[-1].replace("jpg", "jpeg")
                    format_messages[-1]["parts"].append(
                        {
                            "fileData": {
                                "mimeType": f"image/{extension}",
                                # Fix: original referenced the undefined name `image_url`
                                "fileUri": media_data,
                            }
                        }
                    )
                else:
                    media_data = to_bytes(media_data)
                    format_messages[-1]["parts"].append({
                        "inlineData": {
                            "mimeType": is_data_an_media(media_data, filename),
                            "data": base64.b64encode(media_data).decode()
                        }
                    })
        return format_messages

    async def stream_content(
        self,
        model: str,
        messages: Messages,
        *,
        proxy: Optional[str] = None,
        thinking_budget: Optional[int] = None,
        tools: Optional[List[dict]] = None,
        tool_choice: Optional[str] = None,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        stop: Optional[Union[str, List[str]]] = None,
        presence_penalty: Optional[float] = None,
        frequency_penalty: Optional[float] = None,
        seed: Optional[int] = None,
        response_format: Optional[Dict[str, Any]] = None,
        **kwargs
    ) -> AsyncGenerator:
        """Stream a generation from the Code Assist API.

        Yields, in order of arrival: Reasoning chunks (thoughts), text strings,
        saved media, ImageResponse for file parts, then ToolCalls and Usage.

        Raises MissingAuthError on 401 and RuntimeError on other HTTP errors.
        """
        await self.auth_manager.initialize_auth()

        project_id = await self.discover_project_id()

        # Convert messages to Gemini format; system/developer messages are
        # extracted separately into system_instruction below.
        contents = self._messages_to_gemini_format([m for m in messages if m["role"] not in ["developer", "system"]], media=kwargs.get("media", None))
        system_prompt = get_system_prompt(messages)
        requestData = {}
        if system_prompt:
            requestData["system_instruction"] = {"parts": {"text": system_prompt}}

        # Compose request body
        req_body = {
            "model": model,
            "project": project_id,
            "request": {
                "contents": contents,
                "generationConfig": {
                    "maxOutputTokens": max_tokens,
                    "temperature": temperature,
                    "topP": top_p,
                    "stop": stop,
                    "presencePenalty": presence_penalty,
                    "frequencyPenalty": frequency_penalty,
                    "seed": seed,
                    "responseMimeType": None if response_format is None else ("application/json" if response_format.get("type") == "json_object" else None),
                    "thinkingConfig": {
                        "thinkingBudget": thinking_budget,
                        "includeThoughts": True
                    } if thinking_budget else None,
                },
                "tools": tools or [],
                "toolConfig": {
                    "functionCallingConfig": {
                        "mode": tool_choice.upper(),
                        "allowedFunctionNames": [tool["function"]["name"] for tool in tools]
                    }
                } if tool_choice else None,
                **requestData
            },
        }

        # Remove None values recursively so unset options are omitted from JSON.
        def clean_none(d):
            if isinstance(d, dict):
                return {k: clean_none(v) for k, v in d.items() if v is not None}
            if isinstance(d, list):
                return [clean_none(x) for x in d if x is not None]
            return d

        req_body = clean_none(req_body)

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.auth_manager.get_access_token()}",
        }

        url = f"{self.api_base}:streamGenerateContent?alt=sse"

        # Streaming SSE parsing helper
        async def parse_sse_stream(stream: aiohttp.StreamReader) -> AsyncGenerator[Dict[str, Any], None]:
            """Parse SSE stream yielding parsed JSON objects"""
            buffer = ""
            object_buffer = ""

            async for chunk_bytes in stream.iter_any():
                chunk = chunk_bytes.decode()
                buffer += chunk
                lines = buffer.split("\n")
                buffer = lines.pop()  # Save last incomplete line back

                for line in lines:
                    line = line.strip()
                    if line == "":
                        # Empty line indicates end of SSE message -> parse object buffer
                        if object_buffer:
                            try:
                                yield json.loads(object_buffer)
                            except Exception as e:
                                debug.error(f"Error parsing SSE JSON: {e}")
                            object_buffer = ""
                    elif line.startswith("data: "):
                        object_buffer += line[6:]

            # Final parse when stream ends
            if object_buffer:
                try:
                    yield json.loads(object_buffer)
                except Exception as e:
                    debug.error(f"Error parsing final SSE JSON: {e}")

        timeout = ClientTimeout(total=None)  # No total timeout: streams can be long-lived
        connector = get_connector(None, proxy)  # Customize connector as needed (supports proxy)

        async with ClientSession(headers=headers, timeout=timeout, connector=connector) as session:
            async with session.post(url, json=req_body) as resp:
                if not resp.ok:
                    if resp.status == 401:
                        # Possibly token expired: try login retry logic, omitted here for brevity
                        raise MissingAuthError(f"Unauthorized (401) from Gemini API")
                    error_body = await resp.text()
                    raise RuntimeError(f"Gemini API error {resp.status}: {error_body}")

                async for json_data in parse_sse_stream(resp.content):
                    # Process JSON data according to Gemini API structure
                    candidates = json_data.get("response", {}).get("candidates", [])
                    usage_metadata = json_data.get("response", {}).get("usageMetadata", {})

                    if not candidates:
                        continue

                    candidate = candidates[0]
                    content = candidate.get("content", {})
                    parts = content.get("parts", [])

                    tool_calls = []

                    for part in parts:
                        # Real thinking chunks
                        if part.get("thought") is True and "text" in part:
                            yield Reasoning(part["text"])

                        # Function calls from Gemini
                        elif "functionCall" in part:
                            tool_calls.append(part["functionCall"])

                        # Text content
                        elif "text" in part:
                            yield part["text"]

                        # Inline media data
                        elif "inlineData" in part:
                            # Media chunk - yield media asynchronously
                            async for media in save_response_media(part["inlineData"], format_media_prompt(messages)):
                                yield media

                        # File data (e.g. external image)
                        elif "fileData" in part:
                            # Just yield the file URI for now
                            file_data = part["fileData"]
                            yield ImageResponse(file_data.get("fileUri"))

                    if tool_calls:
                        yield ToolCalls(tool_calls)
                    if usage_metadata:
                        yield Usage(
                            promptTokens=usage_metadata.get("promptTokenCount", 0),
                            completionTokens=usage_metadata.get("candidatesTokenCount", 0),
                        )
|
|
|
class GeminiCLI(AsyncGeneratorProvider, ProviderModelMixin):
    """Provider entry point for the Google Gemini CLI (Code Assist) backend."""

    label = "Google Gemini CLI"
    login_url = "https://github.com/GewoonJaap/gemini-cli-openai"

    default_model = "gemini-2.5-pro"
    models = [
        "gemini-2.5-pro",
        "gemini-2.5-flash",
        "gemini-3-pro-preview"
    ]

    working = True
    supports_message_history = True
    supports_system_message = True
    needs_auth = True
    active_by_default = True

    # Shared token manager, created lazily on first use.
    auth_manager: AuthManager = None

    @classmethod
    def get_models(cls, **kwargs):
        """Return the static model list; on first call, probe auth and mark the provider live."""
        if cls.live != 0:
            return cls.models
        if cls.auth_manager is None:
            cls.auth_manager = AuthManager(env=os.environ)
        if cls.auth_manager.get_access_token() is not None:
            cls.live += 1
        return cls.models

    @classmethod
    async def create_async_generator(
        cls,
        model: str,
        messages: Messages,
        stream: bool = False,
        media: MediaListType = None,
        tools: Optional[list] = None,
        **kwargs
    ) -> AsyncResult:
        """Delegate generation to a GeminiCLIProvider backed by the shared AuthManager."""
        if cls.auth_manager is None:
            cls.auth_manager = AuthManager(env=os.environ)

        backend = GeminiCLIProvider(env=os.environ, auth_manager=cls.auth_manager)
        stream_args = dict(
            model=model,
            messages=messages,
            stream=stream,
            media=media,
            tools=tools,
            **kwargs
        )
        async for chunk in backend.stream_content(**stream_args):
            yield chunk