mirror of
https://github.com/xtekky/gpt4free.git
synced 2025-12-24 13:07:53 +08:00
Add websocket media streaming for OpenaiChat
Introduces the wss_media method to stream media updates via websocket in OpenaiChat, and updates logic to yield media as it becomes available. Also adds wait_media as a fallback polling method, tracks image generation tasks in Conversation, and fixes a bug in curl_cffi.py when deleting the 'autoping' key from kwargs.
This commit is contained in:
@@ -10,7 +10,9 @@ import re
|
||||
import time
|
||||
import uuid
|
||||
from copy import copy
|
||||
from typing import AsyncIterator, Iterator, Optional, Generator, Dict, Union, List, Any
|
||||
from typing import AsyncIterator, Iterator, Optional, Generator, Dict, Union, List, Any, AsyncGenerator, Set
|
||||
|
||||
from curl_cffi import AsyncSession
|
||||
|
||||
try:
|
||||
import nodriver
|
||||
@@ -341,7 +343,7 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
|
||||
debug.error(e)
|
||||
if download_urls:
|
||||
# status = None, finished_successfully
|
||||
if is_sediment and status is None:
|
||||
if is_sediment and status != "finished_successfully":
|
||||
return ImagePreview(download_urls, prompt, {"status": status, "headers": auth_result.headers})
|
||||
else:
|
||||
return ImageResponse(download_urls, prompt, {"status": status, "headers": auth_result.headers})
|
||||
@@ -703,8 +705,157 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
|
||||
await asyncio.sleep(5)
|
||||
else:
|
||||
break
|
||||
|
||||
if conversation.task and kwargs.get("wait_media", True):
|
||||
async for _m in cls.wss_media(session, conversation, auth_result.headers, auth_result):
|
||||
yield _m
|
||||
# if kwargs.get("wait_media"):
|
||||
# async for _m in cls.wait_media(session, conversation, headers, auth_result):
|
||||
# yield _m
|
||||
|
||||
yield FinishReason(conversation.finish_reason)
|
||||
|
||||
@classmethod
|
||||
async def wss_media(
|
||||
cls,
|
||||
_session,
|
||||
conversation: Conversation,
|
||||
headers: Dict[str, str],
|
||||
auth_result: AuthResult,
|
||||
timeout: Optional[int] = 20,
|
||||
):
|
||||
seen_assets: Set[str] = set()
|
||||
async with AsyncSession(
|
||||
timeout=timeout,
|
||||
impersonate="chrome",
|
||||
headers=headers,
|
||||
cookies=auth_result.cookies
|
||||
) as session:
|
||||
response = await session.get(
|
||||
"https://chatgpt.com/backend-api/celsius/ws/user",
|
||||
headers=headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
websocket_url = response.json().get("websocket_url")
|
||||
started = False
|
||||
wss = await session.ws_connect(websocket_url, timeout=3)
|
||||
while not wss.closed:
|
||||
try:
|
||||
last_msg = await wss.recv_json(timeout=60 if not started else timeout)
|
||||
except:
|
||||
break
|
||||
conversation_id = conversation.task.get("conversation_id")
|
||||
message_id = conversation.task.get("message", {}).get("id")
|
||||
if isinstance(last_msg, dict) and last_msg.get("type") == "conversation-update":
|
||||
if last_msg.get("payload", {}).get("conversation_id") != conversation_id:
|
||||
continue
|
||||
|
||||
message = last_msg.get("payload", {}).get("update_content", {}).get("message", {})
|
||||
if message.get("id") != message_id:
|
||||
continue
|
||||
|
||||
# if last_msg.get("payload", {}).get("update_type") == 'async-task-start':
|
||||
# started = True
|
||||
started = True
|
||||
if last_msg.get("payload", {}).get("update_type") == 'async-task-update-message':
|
||||
|
||||
status = message.get("status")
|
||||
parts = message.get("content").get("parts") or []
|
||||
for part in parts:
|
||||
if part.get("content_type") != "image_asset_pointer":
|
||||
continue
|
||||
asset = part.get("asset_pointer")
|
||||
if not asset or asset in seen_assets:
|
||||
continue
|
||||
seen_assets.add(asset)
|
||||
generated_images = await cls.get_generated_image(
|
||||
_session,
|
||||
auth_result,
|
||||
asset,
|
||||
conversation.prompt or "",
|
||||
conversation.conversation_id,
|
||||
status,
|
||||
)
|
||||
if generated_images is not None:
|
||||
yield generated_images
|
||||
if message.get("status") == "finished_successfully":
|
||||
await wss.close()
|
||||
return
|
||||
|
||||
@classmethod
|
||||
async def wait_media(
|
||||
cls,
|
||||
session,
|
||||
conversation,
|
||||
headers: Dict[str, str],
|
||||
auth_result: AuthResult,
|
||||
poll_interval: int = 10,
|
||||
timeout: Optional[int] = None,
|
||||
) -> AsyncGenerator[Any, None]:
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
seen_assets: Set[str] = set()
|
||||
running = True
|
||||
has_image_task = False
|
||||
generation_started = False
|
||||
|
||||
while running:
|
||||
if timeout is not None:
|
||||
elapsed = asyncio.get_event_loop().time() - start_time
|
||||
if elapsed > timeout:
|
||||
return
|
||||
# https://chatgpt.com/backend-api/tasks
|
||||
async with session.get(
|
||||
f"https://chatgpt.com/backend-api/conversation/{conversation.conversation_id}",
|
||||
headers=headers,
|
||||
) as response:
|
||||
await raise_for_status(response)
|
||||
data = await response.json()
|
||||
|
||||
mapping = data.get("mapping") or {}
|
||||
if not mapping:
|
||||
return
|
||||
|
||||
last_node = list(mapping.values())[-1] or {}
|
||||
last_message = last_node.get("message") or {}
|
||||
metadata = last_message.get("metadata") or {}
|
||||
status = last_message.get("status")
|
||||
image_task_id = metadata.get("image_gen_task_id")
|
||||
if not has_image_task and not image_task_id:
|
||||
return
|
||||
|
||||
if image_task_id and not has_image_task:
|
||||
debug.log(f"OpenaiChat: Wait Task: {image_task_id}")
|
||||
has_image_task = True
|
||||
if status == "in_progress":
|
||||
generation_started = True
|
||||
elif generation_started and status == "finished_successfully":
|
||||
running = False
|
||||
if generation_started:
|
||||
content = last_message.get("content") or {}
|
||||
parts = content.get("parts") or []
|
||||
for part in parts:
|
||||
if part.get("content_type") != "image_asset_pointer":
|
||||
continue
|
||||
asset = part.get("asset_pointer")
|
||||
if not asset or asset in seen_assets:
|
||||
continue
|
||||
seen_assets.add(asset)
|
||||
generated_images = await cls.get_generated_image(
|
||||
session,
|
||||
auth_result,
|
||||
asset,
|
||||
conversation.prompt
|
||||
or metadata.get("async_task_title")
|
||||
or "",
|
||||
conversation.conversation_id,
|
||||
status,
|
||||
)
|
||||
if generated_images is not None:
|
||||
yield generated_images
|
||||
if generation_started and status == "finished_successfully":
|
||||
return
|
||||
await asyncio.sleep(poll_interval)
|
||||
|
||||
@classmethod
|
||||
async def iter_messages_line(cls, session: StreamSession, auth_result: AuthResult, line: bytes,
|
||||
fields: Conversation, sources: OpenAISources,
|
||||
@@ -850,6 +1001,8 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
|
||||
if fields.parent_message_id is None:
|
||||
fields.parent_message_id = v.get("message", {}).get("id")
|
||||
fields.message_id = v.get("message", {}).get("id")
|
||||
if m.get("status") == "finished_successfully" and m.get("metadata", {}).get("image_gen_task_id"):
|
||||
fields.task = v
|
||||
return
|
||||
if "error" in line and line.get("error"):
|
||||
raise RuntimeError(line.get("error"))
|
||||
@@ -1046,6 +1199,7 @@ class Conversation(JsonConversation):
|
||||
self.thoughts_summary = ""
|
||||
self.prompt = None
|
||||
self.generated_images: ImagePreview = None
|
||||
self.task: dict = None
|
||||
|
||||
|
||||
def get_cookies(
|
||||
|
||||
@@ -148,7 +148,8 @@ if has_curl_cffi and has_curl_ws:
|
||||
def __init__(self, session, url, **kwargs) -> None:
|
||||
self.session: StreamSession = session
|
||||
self.url: str = url
|
||||
del kwargs["autoping"]
|
||||
if "autoping" in kwargs:
|
||||
del kwargs["autoping"]
|
||||
self.options: dict = kwargs
|
||||
|
||||
async def __aenter__(self):
|
||||
|
||||
Reference in New Issue
Block a user