mirror of
				https://github.com/xtekky/gpt4free.git
				synced 2025-10-26 17:50:22 +08:00 
			
		
		
		
	Use dynamtic aspect_ratio for image and video size
This commit is contained in:
		| @@ -15,6 +15,7 @@ from ..errors import ModelNotFoundError | ||||
| from ..requests.raise_for_status import raise_for_status | ||||
| from ..requests.aiohttp import get_connector | ||||
| from ..image.copy_images import save_response_media | ||||
| from ..image import use_aspect_ratio | ||||
| from ..providers.response import FinishReason, Usage, ToolCalls | ||||
| from .. import debug | ||||
|  | ||||
| @@ -139,8 +140,9 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin): | ||||
|         cache: bool = False, | ||||
|         # Image generation parameters | ||||
|         prompt: str = None, | ||||
|         width: int = 1024, | ||||
|         height: int = 1024, | ||||
|         aspect_ratio: str = "1:1", | ||||
|         width: int = None, | ||||
|         height: int = None, | ||||
|         seed: Optional[int] = None, | ||||
|         nologo: bool = True, | ||||
|         private: bool = False, | ||||
| @@ -177,6 +179,7 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin): | ||||
|                 model=model, | ||||
|                 prompt=format_image_prompt(messages, prompt), | ||||
|                 proxy=proxy, | ||||
|                 aspect_ratio=aspect_ratio, | ||||
|                 width=width, | ||||
|                 height=height, | ||||
|                 seed=seed, | ||||
| @@ -212,6 +215,7 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin): | ||||
|         model: str, | ||||
|         prompt: str, | ||||
|         proxy: str, | ||||
|         aspect_ratio: str, | ||||
|         width: int, | ||||
|         height: int, | ||||
|         seed: Optional[int], | ||||
| @@ -223,17 +227,17 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin): | ||||
|     ) -> AsyncResult: | ||||
|         if not cache and seed is None: | ||||
|             seed = random.randint(9999, 99999999) | ||||
|         params = { | ||||
|             "seed": str(seed) if seed is not None else None, | ||||
|             "width": str(width), | ||||
|             "height": str(height), | ||||
|         params = use_aspect_ratio({ | ||||
|             "seed": seed, | ||||
|             "width": width, | ||||
|             "height": height, | ||||
|             "model": model, | ||||
|             "nologo": str(nologo).lower(), | ||||
|             "private": str(private).lower(), | ||||
|             "enhance": str(enhance).lower(), | ||||
|             "safe": str(safe).lower() | ||||
|         } | ||||
|         query = "&".join(f"{k}={quote_plus(v)}" for k, v in params.items() if v is not None) | ||||
|         }, aspect_ratio) | ||||
|         query = "&".join(f"{k}={quote_plus(str(v))}" for k, v in params.items() if v is not None) | ||||
|         url = f"{cls.image_api_endpoint}prompt/{quote_plus(prompt)}?{query}" | ||||
|         #yield ImagePreview(url, prompt) | ||||
|  | ||||
|   | ||||
| @@ -36,8 +36,9 @@ class PollinationsImage(PollinationsAI): | ||||
|         messages: Messages, | ||||
|         proxy: str = None, | ||||
|         prompt: str = None, | ||||
|         width: int = 1024, | ||||
|         height: int = 1024, | ||||
|         aspect_ratio: str = "1:1", | ||||
|         width: int = None, | ||||
|         height: int = None, | ||||
|         seed: Optional[int] = None, | ||||
|         cache: bool = False, | ||||
|         nologo: bool = True, | ||||
| @@ -52,6 +53,7 @@ class PollinationsImage(PollinationsAI): | ||||
|             model=model, | ||||
|             prompt=format_image_prompt(messages, prompt), | ||||
|             proxy=proxy, | ||||
|             aspect_ratio=aspect_ratio, | ||||
|             width=width, | ||||
|             height=height, | ||||
|             seed=seed, | ||||
|   | ||||
| @@ -11,6 +11,7 @@ from ...errors import ModelNotSupportedError, ResponseError | ||||
| from ...requests import StreamSession, raise_for_status | ||||
| from ...providers.response import FinishReason, ImageResponse | ||||
| from ...image.copy_images import save_response_media | ||||
| from ...image import use_aspect_ratio | ||||
| from ..helper import format_image_prompt, get_last_user_message | ||||
| from .models import default_model, default_image_model, model_aliases, text_models, image_models, vision_models | ||||
| from ... import debug | ||||
| @@ -78,8 +79,9 @@ class HuggingFaceInference(AsyncGeneratorProvider, ProviderModelMixin): | ||||
|         action: str = None, | ||||
|         extra_data: dict = {}, | ||||
|         seed: int = None, | ||||
|         width: int = 1024, | ||||
|         height: int = 1024, | ||||
|         aspect_ratio: str = None, | ||||
|         width: int = None, | ||||
|         height: int = None, | ||||
|         **kwargs | ||||
|     ) -> AsyncResult: | ||||
|         try: | ||||
| @@ -99,14 +101,14 @@ class HuggingFaceInference(AsyncGeneratorProvider, ProviderModelMixin): | ||||
|         ) as session: | ||||
|             try: | ||||
|                 if model in provider_together_urls: | ||||
|                     data = { | ||||
|                     data = use_aspect_ratio({ | ||||
|                         "response_format": "url", | ||||
|                         "prompt": format_image_prompt(messages, prompt), | ||||
|                         "model": model, | ||||
|                         "width": width, | ||||
|                         "height": height, | ||||
|                         **extra_data | ||||
|                     } | ||||
|                     }, aspect_ratio) | ||||
|                     async with session.post(provider_together_urls[model], json=data) as response: | ||||
|                         if response.status == 404: | ||||
|                             raise ModelNotSupportedError(f"Model is not supported: {model}") | ||||
|   | ||||
| @@ -1,5 +1,7 @@ | ||||
| from __future__ import annotations | ||||
|  | ||||
| import time | ||||
| import asyncio | ||||
| import random | ||||
| import requests | ||||
|  | ||||
| @@ -8,12 +10,13 @@ from ...requests import StreamSession, raise_for_status | ||||
| from ...errors import ModelNotSupportedError | ||||
| from ...providers.helper import format_image_prompt | ||||
| from ...providers.base_provider import AsyncGeneratorProvider, ProviderModelMixin | ||||
| from ...providers.response import ProviderInfo, ImageResponse, VideoResponse | ||||
| from ...providers.response import ProviderInfo, ImageResponse, VideoResponse, Reasoning | ||||
| from ...image.copy_images import save_response_media | ||||
| from ...image import use_aspect_ratio | ||||
| from ... import debug | ||||
|  | ||||
| class HuggingFaceMedia(AsyncGeneratorProvider, ProviderModelMixin): | ||||
|     label = "HuggingFace (Image / Video Generation)" | ||||
|     label = "HuggingFace (Image/Video Generation)" | ||||
|     parent = "HuggingFace" | ||||
|     url = "https://huggingface.co" | ||||
|     working = True | ||||
| @@ -79,6 +82,7 @@ class HuggingFaceMedia(AsyncGeneratorProvider, ProviderModelMixin): | ||||
|         prompt: str = None, | ||||
|         proxy: str = None, | ||||
|         timeout: int = 0, | ||||
|         aspect_ratio: str = "1:1", | ||||
|         **kwargs | ||||
|     ): | ||||
|         provider_mapping = await cls.get_mapping(model, api_key) | ||||
| @@ -91,85 +95,94 @@ class HuggingFaceMedia(AsyncGeneratorProvider, ProviderModelMixin): | ||||
|             if key in ["replicate", "together", "hf-inference"] | ||||
|         } | ||||
|         provider_mapping = {**new_mapping, **provider_mapping} | ||||
|         last_response = None | ||||
|         for provider_key, provider in provider_mapping.items(): | ||||
|             yield ProviderInfo(**{**cls.get_dict(), "label": f"HuggingFace ({provider_key})", "url": f"{cls.url}/{model}"}) | ||||
|         async def generate(extra_data: dict, prompt: str): | ||||
|             last_response = None | ||||
|             for provider_key, provider in provider_mapping.items(): | ||||
|                 provider_info = ProviderInfo(**{**cls.get_dict(), "label": f"HuggingFace ({provider_key})", "url": f"{cls.url}/{model}"}) | ||||
|  | ||||
|             api_base = f"https://router.huggingface.co/{provider_key}" | ||||
|             task = provider["task"] | ||||
|             provider_id = provider["providerId"] | ||||
|             if task not in cls.tasks: | ||||
|                 raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__} task: {task}") | ||||
|                 api_base = f"https://router.huggingface.co/{provider_key}" | ||||
|                 task = provider["task"] | ||||
|                 provider_id = provider["providerId"] | ||||
|                 if task not in cls.tasks: | ||||
|                     raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__} task: {task}") | ||||
|  | ||||
|             prompt = format_image_prompt(messages, prompt) | ||||
|             if task == "text-to-video": | ||||
|                 extra_data = { | ||||
|                     "num_inference_steps": 20, | ||||
|                     "video_size": "landscape_16_9", | ||||
|                     **extra_data | ||||
|                 } | ||||
|             else: | ||||
|                 extra_data = { | ||||
|                     "width": 1024, | ||||
|                     "height": 1024, | ||||
|                     **extra_data | ||||
|                 } | ||||
|             if provider_key == "fal-ai": | ||||
|                 url = f"{api_base}/{provider_id}" | ||||
|                 data = { | ||||
|                     "prompt": prompt, | ||||
|                     "image_size": "square_hd", | ||||
|                     **extra_data | ||||
|                 } | ||||
|             elif provider_key == "replicate": | ||||
|                 url = f"{api_base}/v1/models/{provider_id}/prediction" | ||||
|                 data = { | ||||
|                     "input": { | ||||
|                 prompt = format_image_prompt(messages, prompt) | ||||
|                 if task == "text-to-video": | ||||
|                     extra_data = { | ||||
|                         "num_inference_steps": 20, | ||||
|                         "resolution": "480p", | ||||
|                         "aspect_ratio": aspect_ratio, | ||||
|                         **extra_data | ||||
|                     } | ||||
|                 else: | ||||
|                     extra_data = use_aspect_ratio(extra_data, aspect_ratio) | ||||
|                 if provider_key == "fal-ai": | ||||
|                     url = f"{api_base}/{provider_id}" | ||||
|                     data = { | ||||
|                         "prompt": prompt, | ||||
|                         "image_size": "square_hd", | ||||
|                         **extra_data | ||||
|                     } | ||||
|                 } | ||||
|             elif provider_key in ("hf-inference", "hf-free"): | ||||
|                 api_base = "https://api-inference.huggingface.co" | ||||
|                 url = f"{api_base}/models/{provider_id}" | ||||
|                 data = { | ||||
|                     "inputs": prompt, | ||||
|                     "parameters": { | ||||
|                         "seed": random.randint(0, 2**32), | ||||
|                 elif provider_key == "replicate": | ||||
|                     url = f"{api_base}/v1/models/{provider_id}/prediction" | ||||
|                     data = { | ||||
|                         "input": { | ||||
|                             "prompt": prompt, | ||||
|                             **extra_data | ||||
|                         } | ||||
|                     } | ||||
|                 elif provider_key in ("hf-inference", "hf-free"): | ||||
|                     api_base = "https://api-inference.huggingface.co" | ||||
|                     url = f"{api_base}/models/{provider_id}" | ||||
|                     data = { | ||||
|                         "inputs": prompt, | ||||
|                         "parameters": { | ||||
|                             "seed": random.randint(0, 2**32), | ||||
|                             **extra_data | ||||
|                         } | ||||
|                     } | ||||
|                 elif task == "text-to-image": | ||||
|                     url = f"{api_base}/v1/images/generations" | ||||
|                     data = { | ||||
|                         "response_format": "url", | ||||
|                         "prompt": prompt, | ||||
|                         "model": provider_id, | ||||
|                         **extra_data | ||||
|                     } | ||||
|                 } | ||||
|             elif task == "text-to-image": | ||||
|                 url = f"{api_base}/v1/images/generations" | ||||
|                 data = { | ||||
|                     "response_format": "url", | ||||
|                     "prompt": prompt, | ||||
|                     "model": provider_id, | ||||
|                     **extra_data | ||||
|                 } | ||||
|  | ||||
|             async with StreamSession( | ||||
|                 headers=headers if provider_key == "free" or api_key is None else {**headers, "Authorization": f"Bearer {api_key}"}, | ||||
|                 proxy=proxy, | ||||
|                 timeout=timeout | ||||
|             ) as session: | ||||
|                 async with session.post(url, json=data) as response: | ||||
|                     if response.status in (400, 401, 402): | ||||
|                         last_response = response | ||||
|                         debug.error(f"{cls.__name__}: Error {response.status} with {provider_key} and {provider_id}") | ||||
|                         continue | ||||
|                     if response.status == 404: | ||||
|                         raise ModelNotSupportedError(f"Model is not supported: {model}") | ||||
|                     await raise_for_status(response) | ||||
|                     async for chunk in save_response_media(response, prompt): | ||||
|                         yield chunk | ||||
|                         return | ||||
|                     result = await response.json() | ||||
|                     if "video" in result: | ||||
|                         yield VideoResponse(result["video"]["url"], prompt) | ||||
|                     elif task == "text-to-image": | ||||
|                         yield ImageResponse([item["url"] for item in result.get("images", result.get("data"))], prompt) | ||||
|                     elif task == "text-to-video": | ||||
|                         yield VideoResponse(result["output"], prompt) | ||||
|                     return | ||||
|         await raise_for_status(last_response) | ||||
|                 async with StreamSession( | ||||
|                     headers=headers if provider_key == "free" or api_key is None else {**headers, "Authorization": f"Bearer {api_key}"}, | ||||
|                     proxy=proxy, | ||||
|                     timeout=timeout | ||||
|                 ) as session: | ||||
|                     async with session.post(url, json=data) as response: | ||||
|                         if response.status in (400, 401, 402): | ||||
|                             last_response = response | ||||
|                             debug.error(f"{cls.__name__}: Error {response.status} with {provider_key} and {provider_id}") | ||||
|                             continue | ||||
|                         if response.status == 404: | ||||
|                             raise ModelNotSupportedError(f"Model is not supported: {model}") | ||||
|                         await raise_for_status(response) | ||||
|                         async for chunk in save_response_media(response, prompt): | ||||
|                             return provider_info, chunk | ||||
|                         result = await response.json() | ||||
|                         if "video" in result: | ||||
|                             return provider_info, VideoResponse(result["video"]["url"], prompt) | ||||
|                         elif task == "text-to-image": | ||||
|                             return provider_info, ImageResponse([item["url"] for item in result.get("images", result.get("data"))], prompt) | ||||
|                         elif task == "text-to-video": | ||||
|                             return provider_info, VideoResponse(result["output"], prompt) | ||||
|             await raise_for_status(last_response) | ||||
|  | ||||
|         background_tasks = set() | ||||
|         started = time.time() | ||||
|         task = asyncio.create_task(generate(extra_data, prompt)) | ||||
|         background_tasks.add(task) | ||||
|         task.add_done_callback(background_tasks.discard) | ||||
|         while background_tasks: | ||||
|             yield Reasoning(label="Generating", status=f"{time.time() - started:.2f}s") | ||||
|             await asyncio.sleep(0.2) | ||||
|         provider_info, media_response = await task | ||||
|         yield Reasoning(label="Finished", status=f"{time.time() - started:.2f}s") | ||||
|         yield provider_info | ||||
|         yield media_response | ||||
|   | ||||
| @@ -6,6 +6,7 @@ import uuid | ||||
| from ...typing import AsyncResult, Messages | ||||
| from ...providers.response import ImageResponse, ImagePreview, JsonConversation, Reasoning | ||||
| from ...requests import StreamSession | ||||
| from ...image import use_aspect_ratio | ||||
| from ...errors import ResponseError | ||||
| from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin | ||||
| from ..helper import format_image_prompt | ||||
| @@ -56,8 +57,9 @@ class BlackForestLabs_Flux1Dev(AsyncGeneratorProvider, ProviderModelMixin): | ||||
|         messages: Messages, | ||||
|         prompt: str = None, | ||||
|         proxy: str = None, | ||||
|         width: int = 1024, | ||||
|         height: int = 1024, | ||||
|         aspect_ratio: str = "1:1", | ||||
|         width: int = None, | ||||
|         height: int = None, | ||||
|         guidance_scale: float = 3.5, | ||||
|         num_inference_steps: int = 28, | ||||
|         seed: int = 0, | ||||
| @@ -69,7 +71,8 @@ class BlackForestLabs_Flux1Dev(AsyncGeneratorProvider, ProviderModelMixin): | ||||
|     ) -> AsyncResult: | ||||
|         async with StreamSession(impersonate="chrome", proxy=proxy) as session: | ||||
|             prompt = format_image_prompt(messages, prompt) | ||||
|             data = [prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps] | ||||
|             data = use_aspect_ratio({"width": width, "height": height}, aspect_ratio) | ||||
|             data = [prompt, seed, randomize_seed, data.get("width"), data.get("height"), guidance_scale, num_inference_steps] | ||||
|             conversation = JsonConversation(zerogpu_token=api_key, zerogpu_uuid=zerogpu_uuid, session_hash=uuid.uuid4().hex) | ||||
|             if conversation.zerogpu_token is None: | ||||
|                 conversation.zerogpu_uuid, conversation.zerogpu_token = await get_zerogpu_token(cls.space, session, conversation, cookies) | ||||
|   | ||||
| @@ -37,8 +37,9 @@ class G4F(DeepseekAI_JanusPro7b): | ||||
|         messages: Messages, | ||||
|         proxy: str = None, | ||||
|         prompt: str = None, | ||||
|         width: int = 1024, | ||||
|         height: int = 1024, | ||||
|         aspect_ratio: str = "1:1", | ||||
|         width: int = None, | ||||
|         height: int = None, | ||||
|         seed: int = None, | ||||
|         cookies: dict = None, | ||||
|         api_key: str = None, | ||||
| @@ -50,6 +51,7 @@ class G4F(DeepseekAI_JanusPro7b): | ||||
|                 model, messages, | ||||
|                 proxy=proxy, | ||||
|                 prompt=prompt, | ||||
|                 aspect_ratio=aspect_ratio, | ||||
|                 width=width, | ||||
|                 height=height, | ||||
|                 seed=seed, | ||||
|   | ||||
| @@ -5,6 +5,7 @@ from aiohttp import ClientSession | ||||
|  | ||||
| from ...typing import AsyncResult, Messages | ||||
| from ...providers.response import ImageResponse, ImagePreview | ||||
| from ...image import use_aspect_ratio | ||||
| from ...errors import ResponseError | ||||
| from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin | ||||
| from ..helper import format_image_prompt | ||||
| @@ -29,8 +30,9 @@ class StabilityAI_SD35Large(AsyncGeneratorProvider, ProviderModelMixin): | ||||
|         negative_prompt: str = None, | ||||
|         api_key: str = None,  | ||||
|         proxy: str = None, | ||||
|         width: int = 1024, | ||||
|         height: int = 1024, | ||||
|         aspect_ratio: str = "1:1", | ||||
|         width: int = None, | ||||
|         height: int = None, | ||||
|         guidance_scale: float = 4.5, | ||||
|         num_inference_steps: int = 50, | ||||
|         seed: int = 0, | ||||
| @@ -45,8 +47,9 @@ class StabilityAI_SD35Large(AsyncGeneratorProvider, ProviderModelMixin): | ||||
|             headers["Authorization"] = f"Bearer {api_key}" | ||||
|         async with ClientSession(headers=headers) as session: | ||||
|             prompt = format_image_prompt(messages, prompt) | ||||
|             data = use_aspect_ratio({"width": width, "height": height}, aspect_ratio) | ||||
|             data = { | ||||
|                 "data": [prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps] | ||||
|                 "data": [prompt, negative_prompt, seed, randomize_seed, data.get("width"), data.get("height"), guidance_scale, num_inference_steps] | ||||
|             } | ||||
|             async with session.post(f"{cls.url}{cls.api_endpoint}", json=data, proxy=proxy) as response: | ||||
|                 response.raise_for_status() | ||||
|   | ||||
| @@ -358,37 +358,48 @@ | ||||
|             gradient.classList.add('hidden'); | ||||
|  | ||||
|             const url = "https://image.pollinations.ai/feed"; | ||||
|             const eventSource = new EventSource(url); | ||||
|             const imageFeed = document.getElementById("image-feed"); | ||||
|             const images = [] | ||||
|             eventSource.onmessage = (event) => { | ||||
|                 const data = JSON.parse(event.data); | ||||
|                 if (data.nsfw || !data.nologo || data.width < 1024 || !data.imageURL || data.isChild) { | ||||
|                     return; | ||||
|                 } | ||||
|                 const lower = data.prompt.toLowerCase(); | ||||
|                 const tags = ["nsfw", "timeline", "feet", "blood", "soap", "orally", "heel", "latex", "bathroom", "boobs", "charts", "gel", "logo", "infographic", "warts", " bra ", "prostitute", "curvy", "breasts", "written", "bodies", "naked", "classroom", "malone", "dirty", "shoes", "shower", "banner", "fat", "nipples", "couple", "sexual", "sandal", "supplier", "overlord", "succubus", "platinum", "cracy", "crazy", "hemale", "oprah", "lamic", "ropes", "cables", "wires", "dirty", "messy", "cluttered", "chaotic", "disorganized", "disorderly", "untidy", "unorganized", "unorderly", "unsystematic", "disarranged", "disarrayed", "disheveled", "disordered", "jumbled", "muddled", "scattered", "shambolic", "sloppy", "unkept", "unruly"]; | ||||
|                 for (i in tags) { | ||||
|                     if (lower.indexOf(tags[i]) != -1) { | ||||
|                         console.log("Skipping image with tag: " + tags[i]); | ||||
|                         console.debug("Skipping image:", data.imageURL); | ||||
|                         return; | ||||
|             let es = null; | ||||
|             function initES() { | ||||
|                 if (es == null || es.readyState == EventSource.CLOSED) { | ||||
|                     const eventSource = new EventSource(url); | ||||
|                     eventSource.onmessage = (event) => { | ||||
|                         const data = JSON.parse(event.data); | ||||
|                         if (data.nsfw || !data.nologo || data.width < 512 || !data.imageURL || data.isChild || data.status != "end_generating") { | ||||
|                             return; | ||||
|                         } | ||||
|                         const lower = data.prompt.toLowerCase(); | ||||
|                         const tags = ["nsfw", "timeline", "feet", "blood", "soap", "orally", "heel", "latex", "bathroom", "boobs", "charts", "gel", "logo", "infographic", "warts", " bra ", "prostitute", "curvy", "breasts", "written", "bodies", "naked", "classroom", "malone", "dirty", "shoes", "shower", "banner", "fat", "nipples", "couple", "sexual", "sandal", "supplier", "overlord", "succubus", "platinum", "cracy", "crazy", "hemale", "oprah", "lamic", "ropes", "cables", "wires", "dirty", "messy", "cluttered", "chaotic", "disorganized", "disorderly", "untidy", "unorganized", "unorderly", "unsystematic", "disarranged", "disarrayed", "disheveled", "disordered", "jumbled", "muddled", "scattered", "shambolic", "sloppy", "unkept", "unruly"]; | ||||
|                         for (i in tags) { | ||||
|                             if (lower.indexOf(tags[i]) != -1) { | ||||
|                                 console.log("Skipping image with tag: " + tags[i]); | ||||
|                                 console.debug("Skipping image:", data.imageURL); | ||||
|                                 return; | ||||
|                             } | ||||
|                         } | ||||
|                         const landscape = window.innerWidth > window.innerHeight; | ||||
|                         if (landscape && data.width > data.height) { | ||||
|                             images.push(data.imageURL); | ||||
|                         } else if (!landscape && data.width < data.height) { | ||||
|                             images.push(data.imageURL); | ||||
|                         } | ||||
|                     }; | ||||
|                     eventSource.onerror = (event) => { | ||||
|                         eventSource.close(); | ||||
|                     } | ||||
|                     imageFeed.onerror = () => { | ||||
|                         imageFeed.classList.add("hidden"); | ||||
|                     }    | ||||
|                 } | ||||
|                 images.push(data.imageURL); | ||||
|             }; | ||||
|             eventSource.onerror = (event) => { | ||||
|                 eventSource.close(); | ||||
|             } | ||||
|             imageFeed.onerror = () => { | ||||
|                 imageFeed.classList.add("hidden"); | ||||
|             }    | ||||
|             initES(); | ||||
|             setInterval(() => { | ||||
|                 if (images.length > 0) { | ||||
|                     imageFeed.classList.remove("hidden"); | ||||
|                     imageFeed.src = images.shift(); | ||||
|                 } else if(imageFeed) { | ||||
|                     imageFeed.remove(); | ||||
|                     initES(); | ||||
|                 } | ||||
|             }, 7000); | ||||
|  | ||||
|   | ||||
| @@ -84,8 +84,8 @@ if (window.markdownit) { | ||||
|             .replaceAll('"></video>', '"></video>') | ||||
|             .replaceAll('<audio controls src="', '<audio controls src="') | ||||
|             .replaceAll('"></audio>', '"></audio>') | ||||
|             .replaceAll('<iframe type="text/html" src="', '<iframe type="text/html" frameborder="0" allow="fullscreen" height="390" width="640" src="') | ||||
|             .replaceAll('"></iframe>', `?enablejsapi=1&origin=${new URL(location.href).origin}"></iframe>`) | ||||
|             .replaceAll('<iframe type="text/html" src="', '<iframe type="text/html" frameborder="0" allow="fullscreen" height="224" width="400" src="') | ||||
|             .replaceAll('"></iframe>', `?enablejsapi=1"></iframe>`) | ||||
|     } | ||||
| } | ||||
|  | ||||
| @@ -95,7 +95,7 @@ function render_reasoning(reasoning, final = false) { | ||||
|     </div>` : ""; | ||||
|     return `<div class="reasoning_body"> | ||||
|         <div class="reasoning_title"> | ||||
|            <strong>Reasoning <i class="brain">🧠</i>:</strong> ${escapeHtml(reasoning.status)} | ||||
|            <strong>${reasoning.label ? reasoning.label :'Reasoning <i class="brain">🧠</i>'}:</strong> ${escapeHtml(reasoning.status)} | ||||
|         </div> | ||||
|         ${inner_text} | ||||
|     </div>`; | ||||
| @@ -893,6 +893,8 @@ async function add_message_chunk(message, message_id, provider, scroll, finish_m | ||||
|             message_storage[message_id] = ""; | ||||
|         } else if (message.status) { | ||||
|             reasoning_storage[message_id].status = message.status; | ||||
|         } if (message.label) { | ||||
|             reasoning_storage[message_id].label = message.label; | ||||
|         } if (message.token) { | ||||
|             reasoning_storage[message_id].text += message.token; | ||||
|         } | ||||
| @@ -999,7 +1001,7 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi | ||||
|             content_map.inner.innerHTML = html; | ||||
|             highlight(content_map.inner); | ||||
|         } | ||||
|         if (message_storage[message_id]) { | ||||
|         if (message_storage[message_id] || reasoning_storage[message_id]) { | ||||
|             const message_provider = message_id in provider_storage ? provider_storage[message_id] : null; | ||||
|             let usage = {}; | ||||
|             if (usage_storage[message_id]) { | ||||
| @@ -1064,7 +1066,35 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi | ||||
|         } | ||||
|         // Reload conversation if no error | ||||
|         if (!error_storage[message_id] && reloadConversation) { | ||||
|             await safe_load_conversation(window.conversation_id, scroll); | ||||
|             if(await safe_load_conversation(window.conversation_id, scroll)) { | ||||
|                 const new_message = Array.from(document.querySelectorAll(".message")).at(-1); | ||||
|                 const new_media = new_message?.querySelector("audio, video, iframe"); | ||||
|                 if (new_media) { | ||||
|                     if (new_media.tagName == "IFRAME") { | ||||
|                         if (YT) { | ||||
|                             async function onPlayerReady(event) { | ||||
|                                 if (scroll) { | ||||
|                                     await lazy_scroll_to_bottom(); | ||||
|                                 } | ||||
|                                 event.target.setVolume(100); | ||||
|                                 event.target.playVideo(); | ||||
|                             } | ||||
|                             player = new YT.Player(new_media, { | ||||
|                                 events: { | ||||
|                                     'onReady': onPlayerReady, | ||||
|                                 } | ||||
|                             }); | ||||
|                         } | ||||
|                     } else { | ||||
|                         setTimeout(async () => { | ||||
|                             if (scroll) { | ||||
|                                 await lazy_scroll_to_bottom(); | ||||
|                             } | ||||
|                             new_media.play(); | ||||
|                         }, 2000); | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|         let cursorDiv = message_el.querySelector(".cursor"); | ||||
|         if (cursorDiv) cursorDiv.parentNode.removeChild(cursorDiv); | ||||
| @@ -1121,6 +1151,7 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi | ||||
|             api_key: api_key, | ||||
|             api_base: api_base, | ||||
|             ignored: ignored, | ||||
|             aspect_ratio: window.innerHeight > window.innerWidth ? "9:16" : "16:9", | ||||
|             ...extra_parameters | ||||
|         }, Object.values(image_storage), message_id, scroll, finish_message); | ||||
|     } catch (e) { | ||||
| @@ -1494,7 +1525,7 @@ const load_conversation = async (conversation_id, scroll=true) => { | ||||
|     [...new Set(providers)].forEach(async (provider) => { | ||||
|         await load_provider_parameters(provider); | ||||
|     }); | ||||
|     register_message_buttons(); | ||||
|     await register_message_buttons(); | ||||
|     highlight(message_box); | ||||
|     regenerate_button.classList.remove("regenerate-hidden"); | ||||
|  | ||||
| @@ -1504,6 +1535,7 @@ const load_conversation = async (conversation_id, scroll=true) => { | ||||
|         setTimeout(() => { | ||||
|             message_box.scrollTop = message_box.scrollHeight; | ||||
|         }, 500); | ||||
|         return true; | ||||
|     } | ||||
| }; | ||||
|  | ||||
| @@ -1516,7 +1548,7 @@ async function safe_load_conversation(conversation_id, scroll=true) { | ||||
|         } | ||||
|     } | ||||
|     if (!is_running) { | ||||
|         load_conversation(conversation_id, scroll); | ||||
|         return await load_conversation(conversation_id, scroll); | ||||
|     } | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -9,7 +9,8 @@ import shutil | ||||
| import random | ||||
| import datetime | ||||
| import tempfile | ||||
| from flask import Flask, Response, request, jsonify, render_template, send_from_directory | ||||
| from flask import Flask, Response, redirect, request, jsonify, render_template, send_from_directory | ||||
| from werkzeug.exceptions import NotFound | ||||
| from typing import Generator | ||||
| from pathlib import Path | ||||
| from urllib.parse import quote_plus | ||||
| @@ -24,7 +25,7 @@ from ...tools.run_tools import iter_run_tools | ||||
| from ...errors import ProviderNotFoundError | ||||
| from ...image import is_allowed_extension | ||||
| from ...cookies import get_cookies_dir | ||||
| from ...image.copy_images import secure_filename | ||||
| from ...image.copy_images import secure_filename, get_source_url | ||||
| from ... import ChatCompletion | ||||
| from ... import models | ||||
| from .api import Api | ||||
| @@ -333,46 +334,18 @@ class Backend_Api(Api): | ||||
|         def get_media(bucket_id, filename, dirname: str = None): | ||||
|             bucket_dir = get_bucket_dir(secure_filename(bucket_id), secure_filename(dirname)) | ||||
|             media_dir = os.path.join(bucket_dir, "media") | ||||
|             if os.path.exists(media_dir): | ||||
|             try: | ||||
|                 return send_from_directory(os.path.abspath(media_dir), filename) | ||||
|             return "Not found", 404 | ||||
|             except NotFound: | ||||
|                 source_url = get_source_url(request.query_string.decode()) | ||||
|                 if source_url is not None: | ||||
|                     return redirect(source_url) | ||||
|                 raise | ||||
|  | ||||
|         @app.route('/files/<dirname>/<bucket_id>/media/<filename>', methods=['GET']) | ||||
|         def get_media_sub(dirname, bucket_id, filename): | ||||
|             return get_media(bucket_id, filename, dirname) | ||||
|  | ||||
|         @app.route('/backend-api/v2/files/<bucket_id>/<filename>', methods=['PUT']) | ||||
|         def upload_file(bucket_id, filename, dirname: str = None): | ||||
|             bucket_dir = secure_filename(bucket_id if dirname is None else dirname) | ||||
|             bucket_dir = get_bucket_dir(bucket_dir) | ||||
|             filename = secure_filename(filename) | ||||
|             bucket_path = Path(bucket_dir) | ||||
|             if dirname is not None: | ||||
|                 bucket_path = bucket_path / secure_filename(bucket_id) | ||||
|  | ||||
|             if not supports_filename(filename): | ||||
|                 return jsonify({"error": {"message": f"File type not allowed"}}), 400 | ||||
|  | ||||
|             if not bucket_path.exists(): | ||||
|                 bucket_path.mkdir(parents=True, exist_ok=True) | ||||
|  | ||||
|             try: | ||||
|                 file_path = bucket_path / filename | ||||
|                 file_data = request.get_data() | ||||
|                 if not file_data: | ||||
|                     return jsonify({"error": {"message": "No file data received"}}), 400 | ||||
|  | ||||
|                 with file_path.open('wb') as f: | ||||
|                     f.write(file_data) | ||||
|  | ||||
|                 return jsonify({"message": f"File '{filename}' uploaded successfully to bucket '{bucket_id}'"}), 201 | ||||
|             except Exception as e: | ||||
|                 return jsonify({"error": {"message": f"Error uploading file: {str(e)}"}}), 500 | ||||
|      | ||||
|         @app.route('/backend-api/v2/files/<bucket_id>/<dirname>/<filename>', methods=['PUT']) | ||||
|         def upload_file_sub(bucket_id, filename, dirname): | ||||
|             return upload_file(bucket_id, filename, dirname) | ||||
|  | ||||
|         @app.route('/backend-api/v2/upload_cookies', methods=['POST']) | ||||
|         def upload_cookies(): | ||||
|             file = None | ||||
|   | ||||
| @@ -16,7 +16,7 @@ except ImportError: | ||||
| from ..typing import ImageType, Union, Image | ||||
| from ..errors import MissingRequirementsError | ||||
|  | ||||
| ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif', 'webp', 'svg'} | ||||
| ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif', 'webp', 'webm', 'svg', 'mp3', 'wav', 'mp4', 'flac', 'opus', 'ogg', 'mkv'} | ||||
|  | ||||
| EXTENSIONS_MAP: dict[str, str] = { | ||||
|     "image/png": "png", | ||||
| @@ -259,6 +259,27 @@ def to_input_audio(audio: ImageType, filename: str = None) -> str: | ||||
|         } | ||||
|     raise ValueError("Invalid input audio") | ||||
|  | ||||
| def use_aspect_ratio(extra_data: dict, aspect_ratio: str) -> Image: | ||||
|     if aspect_ratio == "1:1": | ||||
|         extra_data = { | ||||
|             "width": 1024, | ||||
|             "height": 1024, | ||||
|             **extra_data | ||||
|         } | ||||
|     elif aspect_ratio == "16:9": | ||||
|         extra_data = { | ||||
|             "width": 800, | ||||
|             "height": 512, | ||||
|             **extra_data | ||||
|         } | ||||
|     elif aspect_ratio == "9:16": | ||||
|         extra_data = { | ||||
|             "width": 512, | ||||
|             "height": 800, | ||||
|             **extra_data | ||||
|         } | ||||
|     return extra_data | ||||
|  | ||||
| class ImageDataResponse(): | ||||
|     def __init__( | ||||
|         self, | ||||
|   | ||||
| @@ -11,7 +11,7 @@ from aiohttp import ClientSession, ClientError | ||||
|  | ||||
| from ..typing import Optional, Cookies | ||||
| from ..requests.aiohttp import get_connector, StreamResponse | ||||
| from ..image import EXTENSIONS_MAP | ||||
| from ..image import EXTENSIONS_MAP, ALLOWED_EXTENSIONS | ||||
| from ..tools.files import get_bucket_dir | ||||
| from ..providers.response import ImageResponse, AudioResponse, VideoResponse | ||||
| from ..Provider.template import BackendApi | ||||
| @@ -21,9 +21,9 @@ from .. import debug | ||||
| # Directory for storing generated images | ||||
| images_dir = "./generated_images" | ||||
|  | ||||
| def get_media_extension(image: str) -> str: | ||||
|     """Extract image extension from URL or filename, default to .jpg""" | ||||
|     match = re.search(r"\.(jpe?g|png|webp|mp4|mp3|wav)[?$]", image, re.IGNORECASE) | ||||
| def get_media_extension(media: str) -> str: | ||||
|     """Extract media file extension from URL or filename""" | ||||
|     match = re.search(r"\.(jpe?g|png|gif|svg|webp|webm|mp4|mp3|wav|flac|opus|ogg|mkv)(?:\?|$)", media, re.IGNORECASE) | ||||
|     return f".{match.group(1).lower()}" if match else "" | ||||
|  | ||||
| def ensure_images_dir(): | ||||
| @@ -51,8 +51,10 @@ def secure_filename(filename: str) -> str: | ||||
|  | ||||
| async def save_response_media(response: StreamResponse, prompt: str): | ||||
|     content_type = response.headers["content-type"] | ||||
|     if content_type in EXTENSIONS_MAP or content_type.startswith("audio/"): | ||||
|     if content_type in EXTENSIONS_MAP or content_type.startswith("audio/") or content_type.startswith("video/"): | ||||
|         extension = EXTENSIONS_MAP[content_type] if content_type in EXTENSIONS_MAP else content_type[6:].replace("mpeg", "mp3") | ||||
|         if extension not in ALLOWED_EXTENSIONS: | ||||
|             raise ValueError(f"Unsupported media type: {content_type}") | ||||
|         bucket_id = str(uuid.uuid4()) | ||||
|         dirname = str(int(time.time())) | ||||
|         bucket_dir = get_bucket_dir(bucket_id, dirname) | ||||
| @@ -135,11 +137,14 @@ async def copy_media( | ||||
|                 if target is None and not os.path.splitext(target_path)[1]: | ||||
|                     with open(target_path, "rb") as f: | ||||
|                         file_header = f.read(12) | ||||
|                     detected_type = is_accepted_format(file_header) | ||||
|                     if detected_type: | ||||
|                         new_ext = f".{detected_type.split('/')[-1]}" | ||||
|                         os.rename(target_path, f"{target_path}{new_ext}") | ||||
|                         target_path = f"{target_path}{new_ext}" | ||||
|                     try: | ||||
|                         detected_type = is_accepted_format(file_header) | ||||
|                         if detected_type: | ||||
|                             new_ext = f".{detected_type.split('/')[-1]}" | ||||
|                             os.rename(target_path, f"{target_path}{new_ext}") | ||||
|                             target_path = f"{target_path}{new_ext}" | ||||
|                     except ValueError: | ||||
|                         pass | ||||
|  | ||||
|                 # Build URL with safe encoding | ||||
|                 url_filename = quote(os.path.basename(target_path)) | ||||
|   | ||||
| @@ -178,11 +178,13 @@ class Reasoning(ResponseType): | ||||
|     def __init__( | ||||
|             self, | ||||
|             token: Optional[str] = None, | ||||
|             label: Optional[str] = None, | ||||
|             status: Optional[str] = None, | ||||
|             is_thinking: Optional[str] = None | ||||
|         ) -> None: | ||||
|         """Initialize with token, status, and thinking state.""" | ||||
|         self.token = token | ||||
|         self.label = label | ||||
|         self.status = status | ||||
|         self.is_thinking = is_thinking | ||||
|  | ||||
| @@ -203,6 +205,8 @@ class Reasoning(ResponseType): | ||||
|  | ||||
|     def get_dict(self) -> Dict: | ||||
|         """Return a dictionary representation of the reasoning.""" | ||||
|         if self.label is not None: | ||||
|             return {"label": self.label, "status": self.status} | ||||
|         if self.is_thinking is None: | ||||
|             if self.status is None: | ||||
|                 return {"token": self.token} | ||||
| @@ -248,16 +252,22 @@ class YouTube(HiddenResponse): | ||||
|             for id in self.ids | ||||
|         ])) | ||||
|  | ||||
| class Audio(ResponseType): | ||||
|     def __init__(self, data: bytes) -> None: | ||||
| class AudioResponse(ResponseType): | ||||
|     def __init__(self, data: Union[bytes, str]) -> None: | ||||
|         """Initialize with audio data bytes.""" | ||||
|         self.data = data | ||||
|  | ||||
|     def __str__(self) -> str: | ||||
|     def to_uri(self) -> str: | ||||
|         if isinstance(self.data, str): | ||||
|             return self.data | ||||
|         """Return audio data as a base64-encoded data URI.""" | ||||
|         data_base64 = base64.b64encode(self.data).decode() | ||||
|         return f"data:audio/mpeg;base64,{data_base64}" | ||||
|  | ||||
|     def __str__(self) -> str: | ||||
|         """Return audio as html element.""" | ||||
|         return f'<audio controls src="{self.to_uri()}"></audio>' | ||||
|  | ||||
| class BaseConversation(ResponseType): | ||||
|     def __str__(self) -> str: | ||||
|         """Return an empty string by default.""" | ||||
| @@ -282,7 +292,7 @@ class RequestLogin(HiddenResponse): | ||||
|         """Return formatted login link as a string.""" | ||||
|         return format_link(self.login_url, f"[Login to {self.label}]") + "\n\n" | ||||
|  | ||||
| class ImageResponse(ResponseType): | ||||
| class MediaResponse(ResponseType): | ||||
|     def __init__( | ||||
|         self, | ||||
|         images: Union[str, List[str]], | ||||
| @@ -294,10 +304,6 @@ class ImageResponse(ResponseType): | ||||
|         self.alt = alt | ||||
|         self.options = options | ||||
|  | ||||
|     def __str__(self) -> str: | ||||
|         """Return images as markdown.""" | ||||
|         return format_images_markdown(self.images, self.alt, self.get("preview")) | ||||
|  | ||||
|     def get(self, key: str) -> any: | ||||
|         """Get an option value by key.""" | ||||
|         return self.options.get(key) | ||||
| @@ -306,6 +312,16 @@ class ImageResponse(ResponseType): | ||||
|         """Return images as a list.""" | ||||
|         return [self.images] if isinstance(self.images, str) else self.images | ||||
|  | ||||
| class ImageResponse(MediaResponse): | ||||
|     def __str__(self) -> str: | ||||
|         """Return images as markdown.""" | ||||
|         return format_images_markdown(self.images, self.alt, self.get("preview")) | ||||
|  | ||||
| class VideoResponse(MediaResponse): | ||||
|     def __str__(self) -> str: | ||||
|         """Return videos as html elements.""" | ||||
|         return "\n".join([f'<video controls src="{video}"></video>' for video in self.get_list()]) | ||||
|  | ||||
| class ImagePreview(ImageResponse): | ||||
|     def __str__(self) -> str: | ||||
|         """Return an empty string for preview.""" | ||||
|   | ||||
| @@ -4,7 +4,7 @@ import random | ||||
|  | ||||
| from ..typing import Type, List, CreateResult, Messages, AsyncResult | ||||
| from .types import BaseProvider, BaseRetryProvider, ProviderType | ||||
| from .response import ImageResponse, ProviderInfo | ||||
| from .response import MediaResponse, ProviderInfo | ||||
| from .. import debug | ||||
| from ..errors import RetryProviderError, RetryNoProviderError | ||||
|  | ||||
| @@ -59,7 +59,7 @@ class IterListProvider(BaseRetryProvider): | ||||
|                 for chunk in response: | ||||
|                     if chunk: | ||||
|                         yield chunk | ||||
|                         if isinstance(chunk, (str, ImageResponse)): | ||||
|                         if isinstance(chunk, (str, MediaResponse)): | ||||
|                             started = True | ||||
|                 if started: | ||||
|                     return | ||||
| @@ -94,7 +94,7 @@ class IterListProvider(BaseRetryProvider): | ||||
|                     async for chunk in response: | ||||
|                         if chunk: | ||||
|                             yield chunk | ||||
|                             if isinstance(chunk, (str, ImageResponse)): | ||||
|                             if isinstance(chunk, (str, MediaResponse)): | ||||
|                                 started = True | ||||
|                 elif response: | ||||
|                     response = await response | ||||
|   | ||||
| @@ -64,6 +64,7 @@ class StreamResponse: | ||||
|         inner: Response = await self.inner | ||||
|         self.inner = inner | ||||
|         self.url = inner.url | ||||
|         self.method = inner.request.method | ||||
|         self.request = inner.request | ||||
|         self.status: int = inner.status_code | ||||
|         self.reason: str = inner.reason | ||||
|   | ||||
| @@ -118,9 +118,10 @@ def supports_filename(filename: str): | ||||
|             return True | ||||
|     return False | ||||
|  | ||||
| def get_bucket_dir(bucket_id: str): | ||||
|     bucket_dir = os.path.join(get_cookies_dir(), "buckets", bucket_id) | ||||
|     return bucket_dir | ||||
| def get_bucket_dir(bucket_id: str, dirname: str = None): | ||||
|     if dirname is None: | ||||
|         return os.path.join(get_cookies_dir(), "buckets", bucket_id) | ||||
|     return os.path.join(get_cookies_dir(), "buckets", dirname, bucket_id) | ||||
|  | ||||
| def get_buckets(): | ||||
|     buckets_dir = os.path.join(get_cookies_dir(), "buckets") | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 hlohaus
					hlohaus