mirror of
https://github.com/xtekky/gpt4free.git
synced 2025-09-27 04:36:17 +08:00
253 lines
9.1 KiB
Python
253 lines
9.1 KiB
Python
import os
|
|
import json
|
|
import time
|
|
import asyncio
|
|
import uuid
|
|
from typing import Optional, Dict, Union
|
|
from .stubs import IQwenOAuth2Client, ErrorDataDict
|
|
from pathlib import Path
|
|
import threading
|
|
|
|
from ..base_provider import AuthFileMixin
|
|
from ... import debug
|
|
|
|
# Directory (relative to the user's home) and file names for Qwen OAuth state.
QWEN_DIR = ".qwen"
QWEN_CREDENTIAL_FILENAME = "oauth_creds.json"
QWEN_LOCK_FILENAME = "oauth_creds.lock"

# Timing knobs, all expressed in milliseconds.
TOKEN_REFRESH_BUFFER_MS = 30_000  # a token counts as expired this long before expiry_date
LOCK_TIMEOUT_MS = 10_000  # a lock file older than this is treated as stale
CACHE_CHECK_INTERVAL_MS = 1_000  # minimum gap between credential-file mtime checks
|
|
|
|
|
|
def isErrorResponse(
    response: Union[Dict, ErrorDataDict]
) -> bool:
    """Return True when ``response`` contains an ``"error"`` key.

    Applied to the result of ``refreshAccessToken()`` to detect error payloads.
    """
    has_error = "error" in response
    return has_error
|
|
|
|
|
|
class TokenError:
    """String codes describing why a token operation failed.

    The chosen code is stored as the ``type`` of a ``TokenManagerError``.
    """

    # The refresh request failed or returned an error payload.
    REFRESH_FAILED = "REFRESH_FAILED"
    # The credentials carry no refresh_token, so no refresh is possible.
    NO_REFRESH_TOKEN = "NO_REFRESH_TOKEN"
    # The refresh lock file could not be acquired in time.
    LOCK_TIMEOUT = "LOCK_TIMEOUT"
    # The credential file is missing, unreadable or invalid.
    FILE_ACCESS_ERROR = "FILE_ACCESS_ERROR"
    # Network failure code (not raised anywhere in this module as written).
    NETWORK_ERROR = "NETWORK_ERROR"
|
|
|
|
|
|
class TokenManagerError(Exception):
    """Error raised by the shared token manager.

    Attributes:
        type: failure code, normally one of the ``TokenError`` constants.
        original_error: the underlying exception that caused this one, or None.
    """

    def __init__(self, type_: str, message: str, original_error: Optional[Exception] = None):
        super().__init__(message)
        self.type, self.original_error = type_, original_error
|
|
|
|
|
|
class SharedTokenManager(AuthFileMixin):
    """Shared manager for Qwen OAuth credentials.

    Keeps an in-memory copy of the credentials and reloads it when the
    credential file's modification time changes. Token refreshes are
    coordinated at two levels:

    - Concurrent callers in this process share one refresh task.
    - The refresh itself runs while holding an exclusive lock file, so
      cooperating processes do not refresh at the same time.
    """

    parent = "QwenCode"

    # Lazily created singleton; creation is guarded by ``_lock``.
    _instance: Optional["SharedTokenManager"] = None
    _lock = threading.Lock()

    def __init__(self):
        # credentials:    last validated credential dict, or None
        # file_mod_time:  mtime (ms) of the credential file when last loaded/saved
        # last_check:     time (ms) of the last mtime check, used for throttling
        self.memory_cache = {
            "credentials": None,
            "file_mod_time": 0,
            "last_check": 0,
        }
        # In-flight refresh task shared by concurrent callers, or None.
        self.refresh_promise = None
        # Id written into the lock file while this instance holds the lock.
        self._lock_id: Optional[str] = None

    @classmethod
    def getInstance(cls):
        """Return the shared instance, creating it on first use under a thread lock."""
        with cls._lock:
            if cls._instance is None:
                cls._instance = cls()
            return cls._instance

    def getCredentialFilePath(self):
        """Return ``~/.qwen/oauth_creds.json`` if it exists, else ``get_cache_file()``.

        ``get_cache_file`` is not defined in this class; presumably it comes
        from ``AuthFileMixin``.
        """
        path = Path(os.path.expanduser(f"~/{QWEN_DIR}/{QWEN_CREDENTIAL_FILENAME}"))
        if path.is_file():
            return path
        return SharedTokenManager.get_cache_file()

    def getLockFilePath(self):
        """Return the path of the cross-process refresh lock file."""
        return Path(os.path.expanduser(f"~/{QWEN_DIR}/{QWEN_LOCK_FILENAME}"))

    def setLockConfig(self, config: dict):
        """Accept a lock configuration override; currently a no-op."""
        pass

    def registerCleanupHandlers(self):
        """Register an ``atexit`` hook that releases the refresh lock if this instance still holds it."""
        import atexit

        def cleanup():
            # Only remove a lock this instance owns. Unlinking unconditionally
            # could break a refresh in progress in another process.
            self._remove_lock_if_owned(self.getLockFilePath())

        atexit.register(cleanup)

    async def getValidCredentials(self, qwen_client: IQwenOAuth2Client, force_refresh: bool = False):
        """Return valid credentials, refreshing them when needed.

        Args:
            qwen_client: OAuth client used to perform a refresh.
            force_refresh: refresh even if the cached token is still valid.

        Returns:
            The credentials dict.

        Raises:
            TokenManagerError: on any failure. Unexpected exceptions are
                wrapped with type ``TokenError.REFRESH_FAILED``.
        """
        try:
            self.checkAndReloadIfNeeded()

            cached = self.memory_cache["credentials"]
            if cached and not force_refresh and self.isTokenValid(cached):
                return cached

            # Concurrent callers share a single refresh task.
            if self.refresh_promise is None:
                task = asyncio.create_task(self.performTokenRefresh(qwen_client, force_refresh))
                # Clear the slot once the task finishes (success OR failure) so
                # a failed refresh is not replayed to every later caller.
                task.add_done_callback(self._clear_refresh_promise)
                self.refresh_promise = task
            return await self.refresh_promise
        except TokenManagerError:
            raise
        except Exception as e:
            raise TokenManagerError(TokenError.REFRESH_FAILED, str(e), e) from e

    def _clear_refresh_promise(self, task) -> None:
        """Done-callback: forget ``task`` if it is still the shared refresh task."""
        if self.refresh_promise is task:
            self.refresh_promise = None

    def checkAndReloadIfNeeded(self, force: bool = False):
        """Reload credentials from disk if the credential file's mtime changed.

        Checks are throttled to one per ``CACHE_CHECK_INTERVAL_MS`` unless
        ``force`` is true.

        Raises:
            TokenManagerError: ``FILE_ACCESS_ERROR`` if the file cannot be
                stat'ed or its contents cannot be loaded.
        """
        now = int(time.time() * 1000)
        if not force and now - self.memory_cache["last_check"] < CACHE_CHECK_INTERVAL_MS:
            return
        self.memory_cache["last_check"] = now

        try:
            file_mod_time = int(self.getCredentialFilePath().stat().st_mtime * 1000)
        except FileNotFoundError:
            # No file yet: keep what is in memory; reload once the file appears.
            self.memory_cache["file_mod_time"] = 0
            return
        except Exception as e:
            self.memory_cache["credentials"] = None
            raise TokenManagerError(TokenError.FILE_ACCESS_ERROR, str(e), e) from e

        # "!=" rather than ">" so a switch to a different (possibly older)
        # credential file is also picked up.
        if file_mod_time != self.memory_cache["file_mod_time"]:
            # On failure this clears the cache and raises TokenManagerError.
            self.reloadCredentialsFromFile()
            self.memory_cache["file_mod_time"] = file_mod_time

    def reloadCredentialsFromFile(self):
        """Load and validate the credential file into the memory cache.

        Raises:
            TokenManagerError: ``FILE_ACCESS_ERROR`` if the file is missing,
                is not valid JSON, or fails validation. The cached
                credentials are cleared before raising.
        """
        file_path = self.getCredentialFilePath()
        debug.log(f"Reloading credentials from {file_path}")
        try:
            with open(file_path, "r", encoding="utf-8") as fs:
                data = json.load(fs)
            self.memory_cache["credentials"] = self.validateCredentials(data)
        except FileNotFoundError as e:
            self.memory_cache["credentials"] = None
            raise TokenManagerError(TokenError.FILE_ACCESS_ERROR, "Credentials file not found", e) from e
        except json.JSONDecodeError as e:
            self.memory_cache["credentials"] = None
            raise TokenManagerError(TokenError.FILE_ACCESS_ERROR, "Invalid JSON format", e) from e
        except Exception as e:
            self.memory_cache["credentials"] = None
            raise TokenManagerError(TokenError.FILE_ACCESS_ERROR, str(e), e) from e

    def validateCredentials(self, data):
        """Return ``data`` unchanged if it looks like a credential dict.

        Requires string ``access_token``, ``refresh_token`` and ``token_type``
        values and a numeric ``expiry_date``.

        Raises:
            ValueError: if any requirement is not met.
        """
        if not data or not isinstance(data, dict):
            raise ValueError("Invalid credentials format")
        for field in ["access_token", "refresh_token", "token_type"]:
            if field not in data or not isinstance(data[field], str):
                raise ValueError(f"Invalid credentials: missing {field}")
        if "expiry_date" not in data or not isinstance(data["expiry_date"], (int, float)):
            raise ValueError("Invalid credentials: missing expiry_date")
        return data

    async def performTokenRefresh(self, qwen_client: IQwenOAuth2Client, force_refresh: bool):
        """Refresh the access token while holding the cross-process lock.

        Steps:

        1. Load credentials from disk if none are cached.
        2. Take the lock.
        3. Re-read the credential file, bypassing the throttle, since another
           process may have refreshed while we waited.
        4. Call ``qwen_client.refreshAccessToken()`` only if the token is
           still invalid or ``force_refresh`` is set.

        New credentials are stored in memory, on the client, and on disk.

        Returns:
            The valid credentials dict.

        Raises:
            TokenManagerError: ``NO_REFRESH_TOKEN``, ``LOCK_TIMEOUT``,
                ``REFRESH_FAILED`` or ``FILE_ACCESS_ERROR``. Other exceptions
                from the client propagate unchanged.
        """
        lock_path = self.getLockFilePath()

        if self.memory_cache["credentials"] is None:
            self.reloadCredentialsFromFile()
        qwen_client.setCredentials(self.memory_cache["credentials"])
        current_credentials = qwen_client.getCredentials()
        if not current_credentials or not current_credentials.get("refresh_token"):
            raise TokenManagerError(TokenError.NO_REFRESH_TOKEN, "No refresh token")

        # Acquire OUTSIDE the try: if acquisition fails we must not "release"
        # (i.e. delete) a lock that belongs to someone else.
        await self.acquireLock(lock_path)
        try:
            self.checkAndReloadIfNeeded(force=True)
            cached = self.memory_cache["credentials"]
            if not force_refresh and cached and self.isTokenValid(cached):
                qwen_client.setCredentials(cached)
                return cached
            if cached:
                # Refresh with the newest on-disk credentials: a previous
                # refresh may have stored a new refresh_token.
                qwen_client.setCredentials(cached)
                current_credentials = cached

            response = await qwen_client.refreshAccessToken()
            if not response or isErrorResponse(response):
                raise TokenManagerError(TokenError.REFRESH_FAILED, str(response))
            if "access_token" not in response:
                raise TokenManagerError(TokenError.REFRESH_FAILED, "No access_token returned")

            credentials = {
                "access_token": response["access_token"],
                # Keep the previous values when the response omits or blanks them.
                "token_type": response.get("token_type") or current_credentials.get("token_type"),
                "refresh_token": response.get("refresh_token") or current_credentials.get("refresh_token"),
                "resource_url": response.get("resource_url"),
                "expiry_date": int(time.time() * 1000) + (response.get("expires_in") or 0) * 1000,
            }
            self.memory_cache["credentials"] = credentials
            qwen_client.setCredentials(credentials)

            await self.saveCredentialsToFile(credentials)
            return credentials
        finally:
            await self.releaseLock(lock_path)

    async def acquireLock(self, lock_path: Path):
        """Acquire the cross-process lock file, polling for up to ~10 s.

        The lock is taken by creating ``lock_path`` exclusively. A lock file
        older than ``LOCK_TIMEOUT_MS`` is treated as stale and removed.

        Raises:
            TokenManagerError: ``LOCK_TIMEOUT`` after 50 failed attempts, or
                ``FILE_ACCESS_ERROR`` if the lock file cannot be created for
                another reason.
        """
        max_attempts = 50
        attempt_interval = 200  # ms
        lock_id = str(uuid.uuid4())
        os.makedirs(lock_path.parent, exist_ok=True)

        for _ in range(max_attempts):
            try:
                # O_EXCL makes creation fail when the file already exists, so
                # only one holder can succeed. A plain open(..., "w") would
                # silently truncate someone else's lock.
                fd = os.open(lock_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY, 0o600)
            except (FileExistsError, PermissionError):
                # Lock is held elsewhere. PermissionError may also signal a
                # file that is being deleted on Windows.
                self._remove_stale_lock(lock_path)
            except OSError as e:
                raise TokenManagerError(TokenError.FILE_ACCESS_ERROR, str(e), e) from e
            else:
                with os.fdopen(fd, "w", encoding="utf-8") as f:
                    f.write(lock_id)
                self._lock_id = lock_id
                return
            await asyncio.sleep(attempt_interval / 1000)

        raise TokenManagerError(TokenError.LOCK_TIMEOUT, "Failed to acquire lock")

    def _remove_stale_lock(self, lock_path: Path) -> None:
        """Delete ``lock_path`` if its mtime is older than ``LOCK_TIMEOUT_MS`` (best effort)."""
        try:
            lock_age = int(time.time() * 1000) - int(lock_path.stat().st_mtime * 1000)
            if lock_age > LOCK_TIMEOUT_MS:
                lock_path.unlink()
        except OSError:
            # The lock vanished or is inaccessible; the next attempt decides.
            pass

    async def releaseLock(self, lock_path: Path):
        """Release the lock file if this instance holds it (best effort)."""
        self._remove_lock_if_owned(lock_path)

    def _remove_lock_if_owned(self, lock_path: Path) -> None:
        """Delete ``lock_path`` only if it still contains this instance's lock id."""
        lock_id, self._lock_id = self._lock_id, None
        if lock_id is None:
            return  # nothing acquired, nothing to release
        try:
            with open(lock_path, "r", encoding="utf-8") as f:
                if f.read() != lock_id:
                    # Our lock was broken as stale and re-acquired by someone else.
                    return
            os.unlink(lock_path)
        except FileNotFoundError:
            pass
        except OSError as e:
            debug.log(f"Failed to release lock {lock_path}: {e}")

    async def saveCredentialsToFile(self, credentials: dict):
        """Write ``credentials`` as JSON and record the file's new mtime.

        The JSON is written to a temporary file (mode 0o600) in the same
        directory and then moved over the credential file with
        ``os.replace``. On POSIX this rename is atomic, so concurrent readers
        never see a half-written file.
        """
        # Resolve symlinks so the replace updates the link target, not the link.
        file_path = Path(self.getCredentialFilePath()).resolve()
        os.makedirs(file_path.parent, exist_ok=True)
        tmp_path = file_path.with_name(f"{file_path.name}.{uuid.uuid4().hex}.tmp")
        try:
            fd = os.open(tmp_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY, 0o600)
            with os.fdopen(fd, "w", encoding="utf-8") as f:
                f.write(json.dumps(credentials, indent=2))
            os.replace(tmp_path, file_path)
        except BaseException:
            try:
                tmp_path.unlink()
            except OSError:
                pass
            raise
        self.memory_cache["file_mod_time"] = int(file_path.stat().st_mtime * 1000)

    def isTokenValid(self, credentials: dict) -> bool:
        """Return True if ``credentials`` expire more than ``TOKEN_REFRESH_BUFFER_MS`` from now."""
        expiry_date = credentials.get("expiry_date")
        if not expiry_date:
            return False
        return time.time() * 1000 < expiry_date - TOKEN_REFRESH_BUFFER_MS

    def getCurrentCredentials(self):
        """Return the cached credentials (may be None) without touching disk."""
        return self.memory_cache["credentials"]
|