Add Edge as browser for nodriver

Fix RetryProvider not retrying
Add retry and continue for DuckDuckGo provider
Add cache for Cloudflare provider
Add cache for prompts on GUI home
Add scroll-to-bottom checkbox in GUI
Improve prompts on home GUI
Fix response content type in API for files
Heiner Lohaus
2025-01-05 17:02:15 +01:00
parent 9fd4e3c755
commit 12c413fd2e
16 changed files with 364 additions and 139 deletions

View File

@@ -2,12 +2,14 @@ from __future__ import annotations
 import asyncio
 import json
+from pathlib import Path
 from ..typing import AsyncResult, Messages, Cookies
 from .base_provider import AsyncGeneratorProvider, ProviderModelMixin, get_running_loop
 from ..requests import Session, StreamSession, get_args_from_nodriver, raise_for_status, merge_cookies
 from ..requests import DEFAULT_HEADERS, has_nodriver, has_curl_cffi
 from ..providers.response import FinishReason
+from ..cookies import get_cookies_dir
 from ..errors import ResponseStatusError, ModelNotFoundError

 class Cloudflare(AsyncGeneratorProvider, ProviderModelMixin):
@@ -19,7 +21,7 @@ class Cloudflare(AsyncGeneratorProvider, ProviderModelMixin):
     supports_stream = True
     supports_system_message = True
     supports_message_history = True
-    default_model = "@cf/meta/llama-3.1-8b-instruct"
+    default_model = "@cf/meta/llama-3.3-70b-instruct-fp8-fast"
     model_aliases = {
         "llama-2-7b": "@cf/meta/llama-2-7b-chat-fp16",
         "llama-2-7b": "@cf/meta/llama-2-7b-chat-int8",
@@ -33,6 +35,10 @@ class Cloudflare(AsyncGeneratorProvider, ProviderModelMixin):
     }
     _args: dict = None

+    @classmethod
+    def get_cache_file(cls) -> Path:
+        return Path(get_cookies_dir()) / f"auth_{cls.parent if hasattr(cls, 'parent') else cls.__name__}.json"
+
     @classmethod
     def get_models(cls) -> str:
         if not cls.models:
@@ -67,7 +73,11 @@ class Cloudflare(AsyncGeneratorProvider, ProviderModelMixin):
         timeout: int = 300,
         **kwargs
     ) -> AsyncResult:
+        cache_file = cls.get_cache_file()
         if cls._args is None:
+            if cache_file.exists():
+                with cache_file.open("r") as f:
+                    cls._args = json.load(f)
             if has_nodriver:
                 cls._args = await get_args_from_nodriver(cls.url, proxy, timeout, cookies)
             else:
@@ -93,6 +103,8 @@ class Cloudflare(AsyncGeneratorProvider, ProviderModelMixin):
                 await raise_for_status(response)
             except ResponseStatusError:
                 cls._args = None
+                if cache_file.exists():
+                    cache_file.unlink()
                 raise
             reason = None
             async for line in response.iter_lines():
@@ -110,3 +122,6 @@ class Cloudflare(AsyncGeneratorProvider, ProviderModelMixin):
                         continue
             if reason is not None:
                 yield FinishReason(reason)
+
+        with cache_file.open("w") as f:
+            json.dump(cls._args, f)
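
The provider now persists the nodriver-harvested request arguments (headers, cookies, user agent) across restarts. A minimal sketch of the same read-through/write-back pattern, with illustrative names:

```python
import json
from pathlib import Path

def load_cached_args(cache_file: Path) -> dict | None:
    """Return previously harvested request args, or None on a cache miss."""
    if cache_file.exists():
        with cache_file.open("r") as f:
            return json.load(f)
    return None

def save_cached_args(cache_file: Path, args: dict) -> None:
    """Persist args after a successful request so the next run can skip the browser."""
    cache_file.parent.mkdir(parents=True, exist_ok=True)
    with cache_file.open("w") as f:
        json.dump(args, f)
```

On an HTTP error the provider drops both the in-memory `_args` and the cache file, forcing a fresh browser handshake on the next call.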

View File

@@ -1,14 +1,18 @@
 from __future__ import annotations

-from aiohttp import ClientSession, ClientTimeout, ClientError
+import asyncio
+from aiohttp import ClientSession, ClientTimeout, ClientError, ClientResponseError
 import json

 from ..typing import AsyncResult, Messages
 from .base_provider import AsyncGeneratorProvider, ProviderModelMixin, BaseConversation
-from .helper import format_prompt
+from ..providers.response import FinishReason
+from .. import debug

 class Conversation(BaseConversation):
     vqd: str = None
     message_history: Messages = []
+    cookies: dict = {}

     def __init__(self, model: str):
         self.model = model
@@ -65,20 +69,24 @@ class DDG(AsyncGeneratorProvider, ProviderModelMixin):
         conversation: Conversation = None,
         return_conversation: bool = False,
         proxy: str = None,
+        headers: dict = {
+            "Content-Type": "application/json",
+        },
+        cookies: dict = None,
+        max_retries: int = 3,
         **kwargs
     ) -> AsyncResult:
-        headers = {
-            "Content-Type": "application/json",
-        }
-        async with ClientSession(headers=headers, timeout=ClientTimeout(total=30)) as session:
+        if cookies is None and conversation is not None:
+            cookies = conversation.cookies
+        async with ClientSession(headers=headers, cookies=cookies, timeout=ClientTimeout(total=30)) as session:
             # Fetch VQD token
             if conversation is None:
                 conversation = Conversation(model)
+            conversation.cookies = session.cookie_jar
+            if conversation.vqd is None:
                 conversation.vqd = await cls.fetch_vqd(session)
-            headers["x-vqd-4"] = conversation.vqd
+            if conversation.vqd is not None:
+                headers["x-vqd-4"] = conversation.vqd

             if return_conversation:
                 yield conversation
@@ -97,15 +105,33 @@ class DDG(AsyncGeneratorProvider, ProviderModelMixin):
                 async with session.post(cls.api_endpoint, headers=headers, json=payload, proxy=proxy) as response:
                     conversation.vqd = response.headers.get("x-vqd-4")
                     response.raise_for_status()
+                    reason = None
                     async for line in response.content:
                         line = line.decode("utf-8").strip()
                         if line.startswith("data:"):
                             try:
                                 message = json.loads(line[5:].strip())
-                                if "message" in message:
+                                if "message" in message and message["message"]:
                                     yield message["message"]
+                                    reason = "max_tokens"
+                                elif message.get("message") == '':
+                                    reason = "stop"
                             except json.JSONDecodeError:
                                 continue
+                    if reason is not None:
+                        yield FinishReason(reason)
+            except ClientResponseError as e:
+                if e.code in (400, 429) and max_retries > 0:
+                    debug.log(f"Retry: max_retries={max_retries}, wait={512 - max_retries * 48}: {e}")
+                    await asyncio.sleep(512 - max_retries * 48)
+                    is_started = False
+                    async for chunk in cls.create_async_generator(model, messages, conversation, return_conversation, max_retries=max_retries-1, **kwargs):
+                        if chunk:
+                            yield chunk
+                            is_started = True
+                    if is_started:
+                        return
+                raise e
             except ClientError as e:
                 raise Exception(f"HTTP ClientError occurred: {e}")
             except asyncio.TimeoutError:
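
The retry path only triggers on HTTP 400/429, and the wait `512 - max_retries * 48` grows as retries are consumed (368s, 416s, 464s for the default `max_retries=3`). A condensed sketch of that schedule, with an illustrative `attempt` callable in place of the real request:

```python
import asyncio

async def with_retries(attempt, max_retries: int = 3):
    """Retry on rate limiting, waiting longer on each successive attempt."""
    while True:
        try:
            return await attempt()
        except Exception:  # the provider narrows this to ClientResponseError 400/429
            if max_retries <= 0:
                raise
            await asyncio.sleep(512 - max_retries * 48)
            max_retries -= 1
```

The actual implementation recurses into `create_async_generator` instead of looping, and only re-raises when the retried call produced no chunks at all.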

View File

@@ -137,7 +137,7 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
                             else:
                                 is_special = True
                             debug.log(f"Special token: {is_special}")
-                            yield FinishReason("stop" if is_special else "max_tokens", actions=["variant"] if is_special else ["continue", "variant"])
+                            yield FinishReason("stop" if is_special else "length", actions=["variant"] if is_special else ["continue", "variant"])
                         else:
                             if response.headers["content-type"].startswith("image/"):
                                 base64_data = base64.b64encode(b"".join([chunk async for chunk in response.iter_content()]))

View File

@@ -105,11 +105,11 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
     _expires: int = None

     @classmethod
-    async def on_auth_async(cls, **kwargs) -> AuthResult:
+    async def on_auth_async(cls, **kwargs) -> AsyncIterator:
         if cls.needs_auth:
-            async for _ in cls.login():
-                pass
-        return AuthResult(
+            async for chunk in cls.login():
+                yield chunk
+        yield AuthResult(
             api_key=cls._api_key,
             cookies=cls._cookies or RequestConfig.cookies or {},
             headers=cls._headers or RequestConfig.headers or cls.get_default_headers(),
@@ -174,7 +174,8 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
                 "use_case": "multimodal"
             }
             # Post the image data to the service and get the image data
-            async with session.post(f"{cls.url}/backend-api/files", json=data, headers=auth_result.headers) as response:
+            headers = auth_result.headers if hasattr(auth_result, "headers") else None
+            async with session.post(f"{cls.url}/backend-api/files", json=data, headers=headers) as response:
                 cls._update_request_args(auth_result, session)
                 await raise_for_status(response, "Create file failed")
                 image_data = {
@@ -360,7 +361,7 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
                     f"{cls.url}/backend-anon/sentinel/chat-requirements"
                     if cls._api_key is None else
                     f"{cls.url}/backend-api/sentinel/chat-requirements",
-                    json={"p": None if auth_result.proof_token is None else get_requirements_token(auth_result.proof_token)},
+                    json={"p": None if not getattr(auth_result, "proof_token") else get_requirements_token(auth_result.proof_token)},
                     headers=cls._headers
                 ) as response:
                     if response.status == 401:
@@ -386,7 +387,7 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
                 proofofwork = generate_proof_token(
                     **chat_requirements["proofofwork"],
                     user_agent=auth_result.headers.get("user-agent"),
-                    proof_token=auth_result.proof_token
+                    proof_token=getattr(auth_result, "proof_token")
                 )
             [debug.log(text) for text in (
                 #f"Arkose: {'False' if not need_arkose else auth_result.arkose_token[:12]+'...'}",
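
`on_auth_async` is now an async generator rather than a coroutine: it can surface login progress (for example a URL the user must open) before emitting the final `AuthResult`. A consumer sketch, with an assumed import path for `AuthResult`:

```python
from g4f.providers.response import AuthResult  # assumed import path

async def drain_auth(provider) -> AuthResult:
    """Forward intermediate login chunks; keep the final AuthResult."""
    auth_result = None
    async for chunk in provider.on_auth_async():
        if isinstance(chunk, AuthResult):
            auth_result = chunk
        else:
            print(chunk)  # e.g. a login URL the user must open
    return auth_result
```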

View File

@@ -41,7 +41,7 @@ from g4f.errors import ProviderNotFoundError, ModelNotFoundError, MissingAuthErr
 from g4f.cookies import read_cookie_files, get_cookies_dir
 from g4f.Provider import ProviderType, ProviderUtils, __providers__
 from g4f.gui import get_gui_app
-from g4f.tools.files import supports_filename, get_streaming
+from g4f.tools.files import supports_filename, get_async_streaming
 from .stubs import (
     ChatCompletionsConfig, ImageGenerationConfig,
     ProviderResponseModel, ModelResponseModel,
@@ -436,7 +436,8 @@ class Api:
         event_stream = "text/event-stream" in request.headers.get("accept", "")
         if not os.path.isdir(bucket_dir):
             return ErrorResponse.from_message("Bucket dir not found", 404)
-        return StreamingResponse(get_streaming(bucket_dir, delete_files, refine_chunks_with_spacy, event_stream), media_type="text/plain")
+        return StreamingResponse(get_async_streaming(bucket_dir, delete_files, refine_chunks_with_spacy, event_stream),
+            media_type="text/event-stream" if event_stream else "text/plain")

     @self.app.post("/v1/files/{bucket_id}", responses={
         HTTP_200_OK: {"model": UploadResponseModel}
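
The files endpoint now declares `text/event-stream` when the client asked for SSE, instead of always `text/plain`. A reduced FastAPI sketch of the same content-type negotiation (route and generator are illustrative):

```python
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse

app = FastAPI()

async def stream_events():
    # One SSE frame: a "data:" line terminated by a blank line.
    yield 'data: {"action": "load", "size": 0}\n\n'

@app.get("/v1/files/{bucket_id}")
async def read_files(request: Request, bucket_id: str):
    # Honor the Accept header: EventSource clients require text/event-stream.
    event_stream = "text/event-stream" in request.headers.get("accept", "")
    return StreamingResponse(
        stream_events(),
        media_type="text/event-stream" if event_stream else "text/plain",
    )
```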

View File

@@ -103,17 +103,29 @@
         z-index: -1;
     }

-    iframe.stream {
+    .stream-widget {
         max-height: 0;
         transition: max-height 0.15s ease-out;
+        color: var(--colour-5);
+        overflow: scroll;
+        text-align: left;
     }

-    iframe.stream.show {
+    .stream-widget.show {
         max-height: 1000px;
         height: 1000px;
         transition: max-height 0.25s ease-in;
         background: rgba(255,255,255,0.7);
         border-top: 2px solid rgba(255,255,255,0.5);
+        padding: 20px;
+    }
+
+    .stream-widget img {
+        max-width: 320px;
+    }
+
+    #stream-container {
+        width: 100%;
     }

     .description {
@@ -207,32 +219,87 @@
             <p>Powered by the G4F framework</p>
         </div>
-        <iframe id="stream-widget" class="stream" frameborder="0"></iframe>
+        <iframe class="stream-widget" frameborder="0"></iframe>
     </div>
     <script>
-        const iframe = document.getElementById('stream-widget');
-        let search = (navigator.language == "de" ? "news in deutschland" : navigator.language == "en" ? "world news" : navigator.language);
-        if (Math.floor(Math.random() * 6) % 2 == 0) {
+        const iframe = document.querySelector('.stream-widget');
+        const rand_idx = Math.floor(Math.random() * 9)
+        if (rand_idx < 3) {
             search = "xtekky/gpt4free releases";
+        } else if (rand_idx < 5) {
+            search = "developer news";
+        } else {
+            search = (navigator.language == "de" ? "news in deutsch" : navigator.language == "en" ? "world news" : `news in ${navigator.language}`);
         }
-        const url = "/backend-api/v2/create?prompt=Create of overview of the news in plain text&stream=1&web_search=" + search;
+        const summary_prompt = "Give a summary of the provided text in ```markdown``` format. Add maybe one or more images.";
+        const url = `/backend-api/v2/create?prompt=${summary_prompt}&stream=1&web_search=${search}`;
         iframe.src = url;
-        setTimeout(()=>iframe.classList.add('show'), 3000);
+        const message = "Loading...";
+        setTimeout(()=>{
+            iframe.classList.add('show');
+            const iframeDocument = iframe.contentDocument || iframe.contentWindow?.document;
+            if (iframeDocument) {
+                const iframeBody = iframeDocument.querySelector("body");
+                if (iframeBody) {
+                    iframeBody.innerHTML = message + iframeBody.innerHTML;
+                }
+            } else {
+                iframe.parentElement.removeChild(iframe);
+            }
+        }, 1000);
+        function filterMarkdown(text, allowedTypes = null, defaultValue = null) {
+            const match = text.match(/```(.+)\n(?<code>[\s\S]+?)(\n```|$)/);
+            if (match) {
+                const [, type, code] = match;
+                if (!allowedTypes || allowedTypes.includes(type)) {
+                    return code;
+                }
+            }
+            return defaultValue;
+        }
+        let scroll_to_bottom_callback = () => {
+            const i = document.querySelector(".stream-widget");
+            if (!i.contentWindow || !i.contentDocument) {
+                return;
+            }
+            clientHeight = i.contentDocument.body.scrollHeight;
+            i.contentWindow.scrollTo(0, clientHeight);
+            if (clientHeight - i.contentWindow.scrollY < 2 * clientHeight) {
+                setTimeout(scroll_to_bottom_callback, 1000);
+            }
+        };
+        setTimeout(scroll_to_bottom_callback, 1000);
         iframe.onload = () => {
             const iframeDocument = iframe.contentDocument || iframe.contentWindow.document;
-            const iframeBody = iframeDocument.querySelector("body");
             const iframeContent = iframeDocument.querySelector("pre");
+            let iframeText = iframeContent.innerHTML;
             const markdown = window.markdownit();
-            iframeBody.innerHTML = markdown.render(iframeContent.innerHTML);
+            const iframeContainer = document.querySelector(".container");
+            iframe.remove()
+            if (iframeText.indexOf('"error"') < 0) {
+                iframeContainer.innerHTML += `<div class="stream-widget show">${markdown.render(filterMarkdown(iframeText, "markdown", iframeText))}</div>`;
+            }
+            scroll_to_bottom_callback = () => null;
         }
         (async () => {
-            const prompt = `
+            const today = new Date().toJSON().slice(0, 10);
+            const max = 100;
+            const cache_id = Math.floor(Math.random() * max);
+            let prompt;
+            if (cache_id % 2 == 0) {
+                prompt = `
             Today is ${new Date().toJSON().slice(0, 10)}.
             Create a single-page HTML screensaver reflecting the current season (based on the date).
-            For example, if it's Spring, it might use floral patterns or pastel colors.
-            Avoid using any text. Consider a subtle animation or transition effect.`;
-            const response = await fetch(`/backend-api/v2/create?prompt=${prompt}&filter_markdown=html`)
+            Avoid using any text.`;
+            } else {
+                prompt = `Create a single-page HTML screensaver. Avoid using any text.`;
+            }
+            const response = await fetch(`/backend-api/v2/create?prompt=${prompt}&filter_markdown=html&cache=${cache_id}`);
             const text = await response.text()
             background.src = `data:text/html;charset=utf-8,${encodeURIComponent(text)}`;

             const gradient = document.querySelector('.gradient');

View File

@@ -239,7 +239,8 @@
                 <button class="hide-input">
                     <i class="fa-solid fa-angles-down"></i>
                 </button>
-                <span class="text"></span>
+                <input type="checkbox" id="agree" name="agree" value="yes" checked>
+                <label for="agree" class="text" onclick="this.innerText='';">Scroll to bottom</label>
             </div>
             <div class="stop_generating stop_generating-hidden">
                 <button id="cancelButton">

View File

@@ -516,7 +516,11 @@ body:not(.white) a:visited{
     padding: 6px 6px;
 }

 #input-count .text {
+    min-width: 12px
+}
+
+#input-count .text, #input-count input {
     padding: 0 4px;
 }
@@ -793,7 +797,7 @@ select {
     appearance: none;
     width: 100%;
     height: 20px;
-    background: var(--accent);
+    background: var(--colour-2);
     outline: none;
     transition: opacity .2s;
     border-radius: 10px;
@@ -859,11 +863,18 @@ select:hover,
     font-size: 15px;
     width: 100%;
     color: var(--colour-3);
-    min-height: 49px;
     height: 59px;
     outline: none;
     padding: var(--inner-gap) var(--section-gap);
     resize: vertical;
+    min-height: 59px;
+    transition: max-height 0.15s ease-out;
+}
+
+#systemPrompt:focus {
+    min-height: 200px;
+    max-height: 1000px;
+    transition: max-height 0.25s ease-in;
 }
@@ -929,6 +940,9 @@ select:hover,
     body:not(.white) .gradient{
         display: block;
     }
+    .settings .label, form .label, .settings label, form label {
+        min-width: 200px;
+    }
 }

 .input-box {
@@ -1354,7 +1368,6 @@ form .field.saved .fa-xmark {
 .settings .label, form .label, .settings label, form label {
     font-size: 15px;
     margin-left: var(--inner-gap);
-    min-width: 200px;
 }

 .settings .label, form .label {

View File

@@ -511,7 +511,9 @@ const prepare_messages = (messages, message_index = -1, do_continue = false) =>
         // Include only not regenerated messages
         if (new_message && !new_message.regenerate) {
             // Remove generated images from history
-            new_message.content = filter_message(new_message.content);
+            if (new_message.content) {
+                new_message.content = filter_message(new_message.content);
+            }
             // Remove internal fields
             delete new_message.provider;
             delete new_message.synthesize;
@@ -658,7 +660,7 @@ async function load_provider_parameters(provider) {
     }
 }

-async function add_message_chunk(message, message_id, provider) {
+async function add_message_chunk(message, message_id, provider, scroll) {
     content_map = content_storage[message_id];
     if (message.type == "conversation") {
         const conversation = await get_conversation(window.conversation_id);
@@ -698,7 +700,7 @@ async function add_message_chunk(message, message_id, provider) {
         content_map.inner.innerHTML = markdown_render(message.preview);
     } else if (message.type == "content") {
         message_storage[message_id] += message.content;
-        update_message(content_map, message_id);
+        update_message(content_map, message_id, null, scroll);
         content_map.inner.style.height = "";
     } else if (message.type == "log") {
         let p = document.createElement("p");
@@ -709,9 +711,7 @@ async function add_message_chunk(message, message_id, provider) {
     } else if (message.type == "title") {
         title_storage[message_id] = message.title;
     } else if (message.type == "login") {
-        update_message(content_map, message_id, message.login);
-    } else if (message.type == "login") {
-        update_message(content_map, message_id, message.login);
+        update_message(content_map, message_id, message.login, scroll);
     } else if (message.type == "finish") {
         finish_storage[message_id] = message.finish;
     } else if (message.type == "parameters") {
@@ -734,8 +734,11 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi
     messages = prepare_messages(conversation.items, message_index, action=="continue");
     message_storage[message_id] = "";
     stop_generating.classList.remove("stop_generating-hidden");
-    if (message_index == -1) {
+    const scroll = true;
+    if (message_index > 0 && message_index + 1 < messages.length) {
+        scroll = false;
+    }
+    if (scroll) {
         await lazy_scroll_to_bottom();
     }
@@ -780,7 +783,7 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi
         update_timeouts: [],
         message_index: message_index,
     }
-    if (message_index == -1) {
+    if (scroll) {
         await lazy_scroll_to_bottom();
     }
     try {
@@ -801,7 +804,7 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi
             download_images: download_images,
             api_key: api_key,
             ignored: ignored,
-        }, files, message_id);
+        }, files, message_id, scroll);
         content_map.update_timeouts.forEach((timeoutId)=>clearTimeout(timeoutId));
         content_map.update_timeouts = [];
         if (!error_storage[message_id]) {
@@ -836,12 +839,12 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi
         );
         delete message_storage[message_id];
         if (!error_storage[message_id]) {
-            await safe_load_conversation(window.conversation_id, message_index == -1);
+            await safe_load_conversation(window.conversation_id, scroll);
         }
     }
     let cursorDiv = message_el.querySelector(".cursor");
     if (cursorDiv) cursorDiv.parentNode.removeChild(cursorDiv);
-    if (message_index == -1) {
+    if (scroll) {
         await lazy_scroll_to_bottom();
     }
     await safe_remove_cancel_button();
@@ -856,7 +859,7 @@ async function scroll_to_bottom() {
 }

 async function lazy_scroll_to_bottom() {
-    if (message_box.scrollHeight - message_box.scrollTop < 2 * message_box.clientHeight) {
+    if (document.querySelector("#input-count input").checked) {
         await scroll_to_bottom();
     }
 }
@@ -1013,6 +1016,8 @@ const load_conversation = async (conversation_id, scroll=true) => {
                 if (newContent.startsWith("```")) {
                     const index = str.indexOf("\n");
                     newContent = newContent.substring(index);
+                } else if (newContent.startsWith("...")) {
+                    newContent = " " + newContent.substring(3);
                 }
                 if (newContent.startsWith(lastLine)) {
                     newContent = newContent.substring(lastLine.length);
@@ -1054,7 +1059,7 @@ const load_conversation = async (conversation_id, scroll=true) => {
             if (item.finish && item.finish.actions) {
                 actions = item.finish.actions
             }
-            if (!("continue" in actions)) {
+            if (item.role == "assistant" && !actions.includes("continue")) {
                 let reason = "stop";
                 // Read finish reason from conversation
                 if (item.finish && item.finish.reason) {
@@ -1067,7 +1072,7 @@ const load_conversation = async (conversation_id, scroll=true) => {
                     reason = "error";
                 // Has an even number of start or end code tags
                 } else if (buffer.split("```").length - 1 % 2 === 1) {
-                    reason = "error";
+                    reason = "length";
                 // Has a end token at the end
                 } else if (lastLine.endsWith("```") || lastLine.endsWith(".") || lastLine.endsWith("?") || lastLine.endsWith("!")
                     || lastLine.endsWith('"') || lastLine.endsWith("'") || lastLine.endsWith(")")
@@ -1152,7 +1157,7 @@ const load_conversation = async (conversation_id, scroll=true) => {
     highlight(message_box);
     regenerate_button.classList.remove("regenerate-hidden");

-    if (scroll) {
+    if (document.querySelector("#input-count input").checked) {
         message_box.scrollTo({ top: message_box.scrollHeight, behavior: "smooth" });
         setTimeout(() => {
@@ -1517,7 +1522,7 @@ function count_words_and_tokens(text, model) {
     return `(${count_words(text)} words, ${count_chars(text)} chars, ${count_tokens(model, text)} tokens)`;
 }

-function update_message(content_map, message_id, content = null) {
+function update_message(content_map, message_id, content = null, scroll = true) {
     content_map.update_timeouts.push(setTimeout(() => {
         if (!content) content = message_storage[message_id];
         html = markdown_render(content);
@@ -1538,7 +1543,7 @@ function update_message(content_map, message_id, content = null) {
         content_map.inner.innerHTML = html;
         content_map.count.innerText = count_words_and_tokens(message_storage[message_id], provider_storage[message_id]?.model);
         highlight(content_map.inner);
-        if (content_map.message_index == -1) {
+        if (scroll) {
             lazy_scroll_to_bottom();
         }
         content_map.update_timeouts.forEach((timeoutId)=>clearTimeout(timeoutId));
@@ -1890,7 +1895,7 @@ fileInput.addEventListener('change', async (event) => {
             fileInput.value = "";
             inputCount.innerText = `${count} Conversations were imported successfully`;
         } else {
-            is_cookie_file = false;
+            is_cookie_file = data.api_key;
            if (Array.isArray(data)) {
                 data.forEach((item) => {
                     if (item.domain && item.name && item.value) {
@@ -1927,7 +1932,7 @@ function get_selected_model() {
     }
 }

-async function api(ressource, args=null, files=null, message_id=null) {
+async function api(ressource, args=null, files=null, message_id=null, scroll=true) {
     let api_key;
     if (ressource == "models" && args) {
         api_key = get_api_key_by_provider(args);
@@ -1957,7 +1962,7 @@ async function api(ressource, args=null, files=null, message_id=null) {
             headers: headers,
             body: body,
         });
-        return read_response(response, message_id, args.provider || null);
+        return read_response(response, message_id, args.provider || null, scroll);
     }
     response = await fetch(url, {headers: headers});
     if (response.status == 200) {
@@ -1966,7 +1971,7 @@ async function api(ressource, args=null, files=null, message_id=null) {
     console.error(response);
 }

-async function read_response(response, message_id, provider) {
+async function read_response(response, message_id, provider, scroll) {
     const reader = response.body.pipeThrough(new TextDecoderStream()).getReader();
     let buffer = ""
     while (true) {
@@ -1979,7 +1984,7 @@ async function read_response(response, message_id, provider) {
             continue;
         }
         try {
-            add_message_chunk(JSON.parse(buffer + line), message_id, provider);
+            add_message_chunk(JSON.parse(buffer + line), message_id, provider, scroll);
             buffer = "";
         } catch {
             buffer += line
@@ -2106,6 +2111,7 @@ if (SpeechRecognition) {
     recognition.maxAlternatives = 1;

     let startValue;
+    let buffer;
     let lastDebounceTranscript;
     recognition.onstart = function() {
         microLabel.classList.add("recognition");
@@ -2114,6 +2120,7 @@ if (SpeechRecognition) {
         messageInput.readOnly = true;
     };
     recognition.onend = function() {
+        messageInput.value = `${startValue ? startValue + "\n" : ""}${buffer}`;
         messageInput.readOnly = false;
         messageInput.focus();
     };
@@ -2131,18 +2138,17 @@ if (SpeechRecognition) {
             lastDebounceTranscript = transcript;
         }
         if (transcript) {
-            messageInput.value = `${startValue ? startValue+"\n" : ""}${transcript.trim()}`;
+            inputCount.innerText = transcript;
             if (isFinal) {
-                startValue = messageInput.value;
+                buffer = `${buffer ? buffer + "\n" : ""}${transcript.trim()}`;
             }
-            messageInput.style.height = messageInput.scrollHeight + "px";
-            messageInput.scrollTop = messageInput.scrollHeight;
         }
     };

     microLabel.addEventListener("click", (e) => {
         if (microLabel.classList.contains("recognition")) {
             recognition.stop();
+            messageInput.value = `${startValue ? startValue + "\n" : ""}${buffer}`;
             microLabel.classList.remove("recognition");
         } else {
             const lang = document.getElementById("recognition-language")?.value;

View File

@@ -9,6 +9,8 @@ import shutil
 from flask import Flask, Response, request, jsonify
 from typing import Generator
 from pathlib import Path
+from urllib.parse import quote_plus
+from hashlib import sha256
 from werkzeug.utils import secure_filename

 from ...image import is_allowed_extension, to_image
@@ -123,15 +125,30 @@ class Backend_Api(Api):
                     "type": "function"
                 })
             do_filter_markdown = request.args.get("filter_markdown")
-            response = iter_run_tools(
-                ChatCompletion.create,
-                model=request.args.get("model"),
-                messages=[{"role": "user", "content": request.args.get("prompt")}],
-                provider=request.args.get("provider", None),
-                stream=not do_filter_markdown,
-                ignore_stream=not request.args.get("stream"),
-                tool_calls=tool_calls,
-            )
+            cache_id = request.args.get('cache')
+            parameters = {
+                "model": request.args.get("model"),
+                "messages": [{"role": "user", "content": request.args.get("prompt")}],
+                "provider": request.args.get("provider", None),
+                "stream": not do_filter_markdown and not cache_id,
+                "ignore_stream": not request.args.get("stream"),
+                "tool_calls": tool_calls,
+            }
+            if cache_id:
+                cache_id = sha256(cache_id.encode() + json.dumps(parameters, sort_keys=True).encode()).hexdigest()
+                cache_dir = Path(get_cookies_dir()) / ".scrape_cache" / "create"
+                cache_file = cache_dir / f"{quote_plus(request.args.get('prompt').strip()[:20])}.{cache_id}.txt"
+                if cache_file.exists():
+                    with cache_file.open("r") as f:
+                        response = f.read()
+                else:
+                    response = iter_run_tools(ChatCompletion.create, **parameters)
+                    cache_dir.mkdir(parents=True, exist_ok=True)
+                    with cache_file.open("w") as f:
+                        f.write(response)
+            else:
+                response = iter_run_tools(ChatCompletion.create, **parameters)

             if do_filter_markdown:
                 return Response(filter_markdown(response, do_filter_markdown), mimetype='text/plain')

             def cast_str():
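
The cache key mixes the client-supplied `cache` id with the canonical JSON of all request parameters, so the same id with a different prompt, model, or provider lands in a different file. A sketch of the derivation (paths follow the diff):

```python
import json
from hashlib import sha256
from urllib.parse import quote_plus

def create_cache_key(cache_id: str, parameters: dict) -> str:
    """sha256 over the raw id concatenated with sort_keys JSON of the parameters."""
    payload = cache_id.encode() + json.dumps(parameters, sort_keys=True).encode()
    return sha256(payload).hexdigest()

parameters = {"model": None, "messages": [{"role": "user", "content": "hello"}]}
digest = create_cache_key("42", parameters)
print(f".scrape_cache/create/{quote_plus('hello'[:20])}.{digest}.txt")
```

Note that caching also disables streaming (`"stream": not do_filter_markdown and not cache_id`), since the full text must be written to disk.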

View File

@@ -269,7 +269,7 @@ class AsyncProvider(AbstractProvider):
     def get_async_create_function(cls) -> callable:
         return cls.create_async

-class AsyncGeneratorProvider(AsyncProvider):
+class AsyncGeneratorProvider(AbstractProvider):
     """
     Provides asynchronous generator functionality for streaming results.
     """
@@ -395,6 +395,10 @@ class AsyncAuthedProvider(AsyncGeneratorProvider):
     def get_async_create_function(cls) -> callable:
         return cls.create_async_generator

+    @classmethod
+    def get_cache_file(cls) -> Path:
+        return Path(get_cookies_dir()) / f"auth_{cls.parent if hasattr(cls, 'parent') else cls.__name__}.json"
+
     @classmethod
     def create_completion(
         cls,
@@ -404,18 +408,24 @@ class AsyncAuthedProvider(AsyncGeneratorProvider):
     ) -> CreateResult:
         try:
             auth_result = AuthResult()
-            cache_file = Path(get_cookies_dir()) / f"auth_{cls.parent if hasattr(cls, 'parent') else cls.__name__}.json"
+            cache_file = cls.get_cache_file()
             if cache_file.exists():
                 with cache_file.open("r") as f:
                     auth_result = AuthResult(**json.load(f))
             else:
                 auth_result = cls.on_auth(**kwargs)
-            return to_sync_generator(cls.create_authed(model, messages, auth_result, **kwargs))
+                if hasattr(auth_result, "_iter__"):
+                    for chunk in auth_result:
+                        if isinstance(chunk, AsyncResult):
+                            auth_result = chunk
+                        else:
+                            yield chunk
+            yield from to_sync_generator(cls.create_authed(model, messages, auth_result, **kwargs))
         except (MissingAuthError, NoValidHarFileError):
             if cache_file.exists():
                 cache_file.unlink()
             auth_result = cls.on_auth(**kwargs)
-            return to_sync_generator(cls.create_authed(model, messages, auth_result, **kwargs))
+            yield from to_sync_generator(cls.create_authed(model, messages, auth_result, **kwargs))
         finally:
             cache_file.parent.mkdir(parents=True, exist_ok=True)
             cache_file.write_text(json.dumps(auth_result.get_dict()))
@@ -434,6 +444,12 @@ class AsyncAuthedProvider(AsyncGeneratorProvider):
                     auth_result = AuthResult(**json.load(f))
             else:
                 auth_result = await cls.on_auth_async(**kwargs)
+                if hasattr(auth_result, "_aiter__"):
+                    async for chunk in auth_result:
+                        if isinstance(chunk, AsyncResult):
+                            auth_result = chunk
+                        else:
+                            yield chunk
             response = to_async_iterator(cls.create_authed(model, messages, **kwargs, auth_result=auth_result))
             async for chunk in response:
                 yield chunk
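
Both completion paths now accept `on_auth`/`on_auth_async` returning either an auth object directly or a generator that yields progress chunks before it. A simplified sketch of the unwrapping, with illustrative names (checking `__iter__` directly):

```python
def unwrap_auth(auth_result, is_result):
    """Separate progress chunks from the final auth object.

    is_result is a stand-in predicate for the isinstance check in the diff;
    a non-generator result is passed through unchanged.
    """
    if not hasattr(auth_result, "__iter__"):
        return auth_result, []
    progress, result = [], None
    for chunk in auth_result:
        if is_result(chunk):
            result = chunk
        else:
            progress.append(chunk)
    return result, progress
```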

View File

@@ -19,7 +19,9 @@ def quote_url(url: str) -> str:

 def quote_title(title: str) -> str:
     if title:
-        return title.replace("\n", "").replace('"', '')
+        title = title.strip()
+        title = " ".join(title.split())
+        return title.replace('[', '').replace(']', '')
     return ""

 def format_link(url: str, title: str = None) -> str:

View File

@@ -58,10 +58,11 @@ class IterListProvider(BaseRetryProvider):
                     for chunk in response:
                         if chunk:
                             yield chunk
-                            started = True
+                            if isinstance(chunk, str):
+                                started = True
                 elif response:
                     yield response
-                    started = True
+                    return
                 if started:
                     return
             except Exception as e:
@@ -93,7 +94,8 @@ class IterListProvider(BaseRetryProvider):
                     async for chunk in response:
                         if chunk:
                             yield chunk
-                            started = True
+                            if isinstance(chunk, str):
+                                started = True
                 elif response:
                     response = await response
                     if response:
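
This is the actual RetryProvider fix: previously any truthy chunk (including control objects such as `FinishReason`) set `started`, so a provider that failed after emitting only metadata was never retried. Only `str` content counts now. A condensed sketch of the loop:

```python
def iter_providers(providers, run):
    """Try providers in order; fall through to the next unless real text was produced.

    run(provider) is an illustrative callable yielding str chunks and control objects.
    """
    exceptions = {}
    for provider in providers:
        started = False
        try:
            for chunk in run(provider):
                if chunk:
                    yield chunk
                    if isinstance(chunk, str):  # only text blocks a retry
                        started = True
            if started:
                return
        except Exception as e:
            exceptions[provider] = e
            if started:  # partial output already reached the caller
                raise
    raise RuntimeError(f"All providers failed: {exceptions}")
```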

View File

@@ -1,5 +1,6 @@
 from __future__ import annotations

+import os
 from urllib.parse import urlparse
 from typing import Iterator
 from http.cookies import Morsel
@@ -20,6 +21,7 @@ except ImportError:
 try:
     import nodriver
     from nodriver.cdp.network import CookieParam
+    from nodriver.core.config import find_chrome_executable
     from nodriver import Browser
     has_nodriver = True
 except ImportError:
@@ -95,6 +97,8 @@ async def get_args_from_nodriver(
             cookies[c.name] = c.value
     user_agent = await page.evaluate("window.navigator.userAgent")
     await page.wait_for("body:not(.no-js)", timeout=timeout)
+    for c in await page.send(nodriver.cdp.network.get_cookies([url])):
+        cookies[c.name] = c.value
     await page.close()
     browser.stop()
     return {
@@ -114,13 +118,21 @@ def merge_cookies(cookies: Iterator[Morsel], response: Response) -> Cookies:
     for cookie in response.cookies.jar:
         cookies[cookie.name] = cookie.value

-async def get_nodriver(proxy: str = None, user_data_dir = "nodriver", **kwargs)-> Browser:
+async def get_nodriver(proxy: str = None, user_data_dir = "nodriver", browser_executable_path=None, **kwargs)-> Browser:
     if not has_nodriver:
         raise MissingRequirementsError('Install "nodriver" package | pip install -U nodriver')
     user_data_dir = user_config_dir(f"g4f-{user_data_dir}") if has_platformdirs else None
+    if browser_executable_path is None:
+        try:
+            browser_executable_path = find_chrome_executable()
+        except FileNotFoundError:
+            # Default to Edge if Chrome is not found
+            if os.path.exists("C:\Program Files (x86)\Microsoft\Edge\Application\msedge.exe"):
+                browser_executable_path = "C:\Program Files (x86)\Microsoft\Edge\Application\msedge.exe"
     debug.log(f"Open nodriver with user_dir: {user_data_dir}")
     return await nodriver.start(
         user_data_dir=user_data_dir,
         browser_args=None if proxy is None else [f"--proxy-server={proxy}"],
+        browser_executable_path=browser_executable_path,
         **kwargs
     )
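
Callers can now pin nodriver to any Chromium-based binary; with no argument, Chrome is probed first and the stock Windows Edge install is the fallback. A usage sketch (the Edge path shown is an assumption for a Linux install):

```python
import asyncio
from g4f.requests import get_nodriver

async def main():
    # Leave browser_executable_path unset to probe Chrome, then fall back to Edge.
    browser = await get_nodriver(browser_executable_path="/usr/bin/microsoft-edge")  # assumed path
    page = await browser.get("https://example.com")
    print(await page.evaluate("window.navigator.userAgent"))
    await page.close()
    browser.stop()

asyncio.run(main())
```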

View File

@@ -471,57 +471,88 @@ async def download_urls(
         await asyncio.sleep(delay)
         new_urls = next_urls

+def get_downloads_urls(bucket_dir: Path, delete_files: bool = False) -> Iterator[str]:
+    download_file = bucket_dir / DOWNLOADS_FILE
+    if download_file.exists():
+        with download_file.open('r') as f:
+            data = json.load(f)
+        if delete_files:
+            download_file.unlink()
+        if isinstance(data, list):
+            for item in data:
+                if "url" in item:
+                    yield item["url"]
+
+def read_and_download_urls(bucket_dir: Path, event_stream: bool = False) -> Iterator[str]:
+    urls = get_downloads_urls(bucket_dir)
+    if urls:
+        count = 0
+        with open(os.path.join(bucket_dir, FILE_LIST), 'w') as f:
+            for filename in to_sync_generator(download_urls(bucket_dir, urls)):
+                f.write(f"{filename}\n")
+                if event_stream:
+                    count += 1
+                    yield f'data: {json.dumps({"action": "download", "count": count})}\n\n'
+
+async def async_read_and_download_urls(bucket_dir: Path, event_stream: bool = False) -> Iterator[str]:
+    urls = get_downloads_urls(bucket_dir)
+    if urls:
+        count = 0
+        with open(os.path.join(bucket_dir, FILE_LIST), 'w') as f:
+            async for filename in download_urls(bucket_dir, urls):
+                f.write(f"{filename}\n")
+                if event_stream:
+                    count += 1
+                    yield f'data: {json.dumps({"action": "download", "count": count})}\n\n'
+
+def stream_chunks(bucket_dir: Path, delete_files: bool = False, refine_chunks_with_spacy: bool = False, event_stream: bool = False) -> Iterator[str]:
+    size = 0
+    if refine_chunks_with_spacy:
+        for chunk in stream_read_parts_and_refine(bucket_dir, delete_files):
+            if event_stream:
+                size += len(chunk)
+                yield f'data: {json.dumps({"action": "refine", "size": size})}\n\n'
+            else:
+                yield chunk
+    else:
+        streaming = stream_read_files(bucket_dir, get_filenames(bucket_dir), delete_files)
+        streaming = cache_stream(streaming, bucket_dir)
+        for chunk in streaming:
+            if event_stream:
+                size += len(chunk)
+                yield f'data: {json.dumps({"action": "load", "size": size})}\n\n'
+            else:
+                yield chunk
+    files_txt = os.path.join(bucket_dir, FILE_LIST)
+    if delete_files and os.path.exists(files_txt):
+        for filename in get_filenames(bucket_dir):
+            if os.path.exists(os.path.join(bucket_dir, filename)):
+                os.remove(os.path.join(bucket_dir, filename))
+        os.remove(files_txt)
+        if event_stream:
+            yield f'data: {json.dumps({"action": "delete_files"})}\n\n'
+    if event_stream:
+        yield f'data: {json.dumps({"action": "done", "size": size})}\n\n'
+
 def get_streaming(bucket_dir: str, delete_files = False, refine_chunks_with_spacy = False, event_stream: bool = False) -> Iterator[str]:
     bucket_dir = Path(bucket_dir)
     bucket_dir.mkdir(parents=True, exist_ok=True)
     try:
-        download_file = bucket_dir / DOWNLOADS_FILE
-        if download_file.exists():
-            urls = []
-            with download_file.open('r') as f:
-                data = json.load(f)
-            download_file.unlink()
-            if isinstance(data, list):
-                for item in data:
-                    if "url" in item:
-                        urls.append(item["url"])
-            if urls:
-                count = 0
-                with open(os.path.join(bucket_dir, FILE_LIST), 'w') as f:
-                    for filename in to_sync_generator(download_urls(bucket_dir, urls)):
-                        f.write(f"{filename}\n")
-                        if event_stream:
-                            count += 1
-                            yield f'data: {json.dumps({"action": "download", "count": count})}\n\n'
-        if refine_chunks_with_spacy:
-            size = 0
-            for chunk in stream_read_parts_and_refine(bucket_dir, delete_files):
-                if event_stream:
-                    size += len(chunk)
-                    yield f'data: {json.dumps({"action": "refine", "size": size})}\n\n'
-                else:
-                    yield chunk
-        else:
-            streaming = stream_read_files(bucket_dir, get_filenames(bucket_dir), delete_files)
-            streaming = cache_stream(streaming, bucket_dir)
-            size = 0
-            for chunk in streaming:
-                if event_stream:
-                    size += len(chunk)
-                    yield f'data: {json.dumps({"action": "load", "size": size})}\n\n'
-                else:
-                    yield chunk
-        files_txt = os.path.join(bucket_dir, FILE_LIST)
-        if delete_files and os.path.exists(files_txt):
-            for filename in get_filenames(bucket_dir):
-                if os.path.exists(os.path.join(bucket_dir, filename)):
-                    os.remove(os.path.join(bucket_dir, filename))
-            os.remove(files_txt)
-            if event_stream:
-                yield f'data: {json.dumps({"action": "delete_files"})}\n\n'
-        if event_stream:
-            yield f'data: {json.dumps({"action": "done", "size": size})}\n\n'
+        yield from read_and_download_urls(bucket_dir, event_stream)
+        yield from stream_chunks(bucket_dir, delete_files, refine_chunks_with_spacy, event_stream)
+    except Exception as e:
+        if event_stream:
+            yield f'data: {json.dumps({"error": {"message": str(e)}})}\n\n'
+        raise e
+
+async def get_async_streaming(bucket_dir: str, delete_files = False, refine_chunks_with_spacy = False, event_stream: bool = False) -> Iterator[str]:
+    bucket_dir = Path(bucket_dir)
+    bucket_dir.mkdir(parents=True, exist_ok=True)
+    try:
+        async for chunk in async_read_and_download_urls(bucket_dir, event_stream):
+            yield chunk
+        for chunk in stream_chunks(bucket_dir, delete_files, refine_chunks_with_spacy, event_stream):
+            yield chunk
     except Exception as e:
         if event_stream:
             yield f'data: {json.dumps({"error": {"message": str(e)}})}\n\n'
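
The refactor extracts the shared pieces (`get_downloads_urls`, `read_and_download_urls`, `stream_chunks`) so the same logic backs both the sync `get_streaming` and the new `get_async_streaming` used by the FastAPI route. A consumption sketch (the bucket path is illustrative):

```python
import asyncio
from g4f.tools.files import get_async_streaming

async def main():
    # event_stream=True wraps progress reports in SSE "data:" frames.
    async for chunk in get_async_streaming("./bucket", event_stream=True):
        print(chunk, end="")

asyncio.run(main())
```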

View File

@@ -24,6 +24,7 @@ except:
 from typing import Iterator

 from ..cookies import get_cookies_dir
+from ..providers.response import format_link
 from ..errors import MissingRequirementsError
 from .. import debug
@@ -66,7 +67,7 @@ class SearchResultEntry():
     def set_text(self, text: str):
         self.text = text

-def scrape_text(html: str, max_words: int = None, add_source=True) -> Iterator[str]:
+def scrape_text(html: str, max_words: int = None, add_source=True, count_images: int = 2) -> Iterator[str]:
     source = BeautifulSoup(html, "html.parser")
     soup = source
     for selector in [
@@ -88,7 +89,20 @@ def scrape_text(html: str, max_words: int = None, add_source=True) -> Iterator[s
         if select:
             select.extract()

-    for paragraph in soup.select("p, table:not(:has(p)), ul:not(:has(p)), h1, h2, h3, h4, h5, h6"):
+    image_select = "img[alt][src^=http]:not([alt=''])"
+    image_link_select = f"a:has({image_select})"
+    for paragraph in soup.select(f"h1, h2, h3, h4, h5, h6, p, table:not(:has(p)), ul:not(:has(p)), {image_link_select}"):
+        image = paragraph.select_one(image_select)
+        if count_images > 0:
+            if image:
+                title = paragraph.get("title") or paragraph.text
+                if title:
+                    yield f"!{format_link(image['src'], title)}\n"
+                    if max_words is not None:
+                        max_words -= 10
+                    count_images -= 1
+                continue
         for line in paragraph.text.splitlines():
             words = [word for word in line.replace("\t", " ").split(" ") if word]
             count = len(words)
@@ -112,7 +126,7 @@ async def fetch_and_scrape(session: ClientSession, url: str, max_words: int = No
     bucket_dir: Path = Path(get_cookies_dir()) / ".scrape_cache" / "fetch_and_scrape"
     bucket_dir.mkdir(parents=True, exist_ok=True)
     md5_hash = hashlib.md5(url.encode()).hexdigest()
-    cache_file = bucket_dir / f"{url.split('?')[0].split('//')[1].replace('/', '+')[:16]}.{datetime.date.today()}.{md5_hash}.txt"
+    cache_file = bucket_dir / f"{url.split('?')[0].split('//')[1].replace('/', '+')[:16]}.{datetime.date.today()}.{md5_hash}.cache"
     if cache_file.exists():
         return cache_file.read_text()
     async with session.get(url) as response:
@@ -179,14 +193,15 @@ async def do_search(prompt: str, query: str = None, instructions: str = DEFAULT_
     md5_hash = hashlib.md5(json_bytes).hexdigest()
     bucket_dir: Path = Path(get_cookies_dir()) / ".scrape_cache" / f"web_search" / f"{datetime.date.today()}"
     bucket_dir.mkdir(parents=True, exist_ok=True)
-    cache_file = bucket_dir / f"{quote_plus(query[:20])}.{md5_hash}.txt"
+    cache_file = bucket_dir / f"{quote_plus(query[:20])}.{md5_hash}.cache"
     if cache_file.exists():
         with cache_file.open("r") as f:
             search_results = f.read()
     else:
         search_results = await search(query, **kwargs)
-        with cache_file.open("w") as f:
-            f.write(str(search_results))
+        if search_results.results:
+            with cache_file.open("w") as f:
+                f.write(str(search_results))

     new_prompt = f"""
 {search_results}
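
With `count_images`, `scrape_text` now emits up to two images as markdown ahead of the page text, charging roughly ten words each against `max_words`. A sketch of the emitted shape, assuming `format_link` renders `[title](url)` so the `!` prefix yields a markdown image:

```python
def image_markdown(src: str, title: str) -> str:
    """Mirror quote_title: collapse whitespace and strip brackets, then prefix '!'."""
    title = " ".join(title.split()).replace("[", "").replace("]", "")
    return f"![{title}]({src})\n"

print(image_markdown("https://example.com/cat.png", "A  cat\nphoto"))
# ![A cat photo](https://example.com/cat.png)
```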