Add websocket media streaming for OpenaiChat

Introduces the wss_media method to stream media updates via websocket in OpenaiChat, and updates logic to yield media as it becomes available. Also adds wait_media as a fallback polling method, tracks image generation tasks in Conversation, and fixes a bug in curl_cffi.py when deleting the 'autoping' key from kwargs.
This commit is contained in:
Ammar
2025-12-06 16:34:48 +02:00
parent 027d486b57
commit e5ca022142
2 changed files with 158 additions and 3 deletions

View File

@@ -10,7 +10,9 @@ import re
import time
import uuid
from copy import copy
from typing import AsyncIterator, Iterator, Optional, Generator, Dict, Union, List, Any
from typing import AsyncIterator, Iterator, Optional, Generator, Dict, Union, List, Any, AsyncGenerator, Set
from curl_cffi import AsyncSession
try:
import nodriver
@@ -341,7 +343,7 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
debug.error(e)
if download_urls:
# status = None, finished_successfully
if is_sediment and status is None:
if is_sediment and status != "finished_successfully":
return ImagePreview(download_urls, prompt, {"status": status, "headers": auth_result.headers})
else:
return ImageResponse(download_urls, prompt, {"status": status, "headers": auth_result.headers})
@@ -703,8 +705,157 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
await asyncio.sleep(5)
else:
break
if conversation.task and kwargs.get("wait_media", True):
async for _m in cls.wss_media(session, conversation, auth_result.headers, auth_result):
yield _m
# if kwargs.get("wait_media"):
# async for _m in cls.wait_media(session, conversation, headers, auth_result):
# yield _m
yield FinishReason(conversation.finish_reason)
@classmethod
async def wss_media(
cls,
_session,
conversation: Conversation,
headers: Dict[str, str],
auth_result: AuthResult,
timeout: Optional[int] = 20,
):
seen_assets: Set[str] = set()
async with AsyncSession(
timeout=timeout,
impersonate="chrome",
headers=headers,
cookies=auth_result.cookies
) as session:
response = await session.get(
"https://chatgpt.com/backend-api/celsius/ws/user",
headers=headers,
)
response.raise_for_status()
websocket_url = response.json().get("websocket_url")
started = False
wss = await session.ws_connect(websocket_url, timeout=3)
while not wss.closed:
try:
last_msg = await wss.recv_json(timeout=60 if not started else timeout)
except:
break
conversation_id = conversation.task.get("conversation_id")
message_id = conversation.task.get("message", {}).get("id")
if isinstance(last_msg, dict) and last_msg.get("type") == "conversation-update":
if last_msg.get("payload", {}).get("conversation_id") != conversation_id:
continue
message = last_msg.get("payload", {}).get("update_content", {}).get("message", {})
if message.get("id") != message_id:
continue
# if last_msg.get("payload", {}).get("update_type") == 'async-task-start':
# started = True
started = True
if last_msg.get("payload", {}).get("update_type") == 'async-task-update-message':
status = message.get("status")
parts = message.get("content").get("parts") or []
for part in parts:
if part.get("content_type") != "image_asset_pointer":
continue
asset = part.get("asset_pointer")
if not asset or asset in seen_assets:
continue
seen_assets.add(asset)
generated_images = await cls.get_generated_image(
_session,
auth_result,
asset,
conversation.prompt or "",
conversation.conversation_id,
status,
)
if generated_images is not None:
yield generated_images
if message.get("status") == "finished_successfully":
await wss.close()
return
@classmethod
async def wait_media(
cls,
session,
conversation,
headers: Dict[str, str],
auth_result: AuthResult,
poll_interval: int = 10,
timeout: Optional[int] = None,
) -> AsyncGenerator[Any, None]:
start_time = asyncio.get_event_loop().time()
seen_assets: Set[str] = set()
running = True
has_image_task = False
generation_started = False
while running:
if timeout is not None:
elapsed = asyncio.get_event_loop().time() - start_time
if elapsed > timeout:
return
# https://chatgpt.com/backend-api/tasks
async with session.get(
f"https://chatgpt.com/backend-api/conversation/{conversation.conversation_id}",
headers=headers,
) as response:
await raise_for_status(response)
data = await response.json()
mapping = data.get("mapping") or {}
if not mapping:
return
last_node = list(mapping.values())[-1] or {}
last_message = last_node.get("message") or {}
metadata = last_message.get("metadata") or {}
status = last_message.get("status")
image_task_id = metadata.get("image_gen_task_id")
if not has_image_task and not image_task_id:
return
if image_task_id and not has_image_task:
debug.log(f"OpenaiChat: Wait Task: {image_task_id}")
has_image_task = True
if status == "in_progress":
generation_started = True
elif generation_started and status == "finished_successfully":
running = False
if generation_started:
content = last_message.get("content") or {}
parts = content.get("parts") or []
for part in parts:
if part.get("content_type") != "image_asset_pointer":
continue
asset = part.get("asset_pointer")
if not asset or asset in seen_assets:
continue
seen_assets.add(asset)
generated_images = await cls.get_generated_image(
session,
auth_result,
asset,
conversation.prompt
or metadata.get("async_task_title")
or "",
conversation.conversation_id,
status,
)
if generated_images is not None:
yield generated_images
if generation_started and status == "finished_successfully":
return
await asyncio.sleep(poll_interval)
@classmethod
async def iter_messages_line(cls, session: StreamSession, auth_result: AuthResult, line: bytes,
fields: Conversation, sources: OpenAISources,
@@ -850,6 +1001,8 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
if fields.parent_message_id is None:
fields.parent_message_id = v.get("message", {}).get("id")
fields.message_id = v.get("message", {}).get("id")
if m.get("status") == "finished_successfully" and m.get("metadata", {}).get("image_gen_task_id"):
fields.task = v
return
if "error" in line and line.get("error"):
raise RuntimeError(line.get("error"))
@@ -1046,6 +1199,7 @@ class Conversation(JsonConversation):
self.thoughts_summary = ""
self.prompt = None
self.generated_images: ImagePreview = None
self.task: dict = None
def get_cookies(

View File

@@ -148,7 +148,8 @@ if has_curl_cffi and has_curl_ws:
def __init__(self, session, url, **kwargs) -> None:
self.session: StreamSession = session
self.url: str = url
del kwargs["autoping"]
if "autoping" in kwargs:
del kwargs["autoping"]
self.options: dict = kwargs
async def __aenter__(self):