Show only free providers by default

This commit is contained in:
hlohaus
2025-02-21 06:52:04 +01:00
parent e53483d85b
commit 470b795418
14 changed files with 84 additions and 59 deletions

View File

@@ -123,9 +123,6 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
if model not in cls.image_models: if model not in cls.image_models:
raise raise
if not cache and seed is None:
seed = random.randint(1000, 999999)
if model in cls.image_models: if model in cls.image_models:
async for chunk in cls._generate_image( async for chunk in cls._generate_image(
model=model, model=model,
@@ -134,6 +131,7 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
width=width, width=width,
height=height, height=height,
seed=seed, seed=seed,
cache=cache,
nologo=nologo, nologo=nologo,
private=private, private=private,
enhance=enhance, enhance=enhance,
@@ -165,11 +163,14 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
width: int, width: int,
height: int, height: int,
seed: Optional[int], seed: Optional[int],
cache: bool,
nologo: bool, nologo: bool,
private: bool, private: bool,
enhance: bool, enhance: bool,
safe: bool safe: bool
) -> AsyncResult: ) -> AsyncResult:
if not cache and seed is None:
seed = random.randint(9999, 99999999)
params = { params = {
"seed": str(seed) if seed is not None else None, "seed": str(seed) if seed is not None else None,
"width": str(width), "width": str(width),
@@ -207,6 +208,8 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
seed: Optional[int], seed: Optional[int],
cache: bool cache: bool
) -> AsyncResult: ) -> AsyncResult:
if not cache and seed is None:
seed = random.randint(9999, 99999999)
json_mode = False json_mode = False
if response_format and response_format.get("type") == "json_object": if response_format and response_format.get("type") == "json_object":
json_mode = True json_mode = True

View File

@@ -28,6 +28,7 @@ class PollinationsImage(PollinationsAI):
width: int = 1024, width: int = 1024,
height: int = 1024, height: int = 1024,
seed: Optional[int] = None, seed: Optional[int] = None,
cache: bool = False,
nologo: bool = True, nologo: bool = True,
private: bool = False, private: bool = False,
enhance: bool = False, enhance: bool = False,
@@ -41,6 +42,7 @@ class PollinationsImage(PollinationsAI):
width=width, width=width,
height=height, height=height,
seed=seed, seed=seed,
cache=cache,
nologo=nologo, nologo=nologo,
private=private, private=private,
enhance=enhance, enhance=enhance,

View File

@@ -8,7 +8,8 @@ import base64
from typing import AsyncIterator from typing import AsyncIterator
try: try:
from curl_cffi.requests import Session, CurlMime from curl_cffi.requests import Session
from curl_cffi import CurlMime
has_curl_cffi = True has_curl_cffi = True
except ImportError: except ImportError:
has_curl_cffi = False has_curl_cffi = False

View File

@@ -4,6 +4,7 @@ from ...providers.types import Messages
from ...typing import ImagesType from ...typing import ImagesType
from ...requests import StreamSession, raise_for_status from ...requests import StreamSession, raise_for_status
from ...errors import ModelNotSupportedError from ...errors import ModelNotSupportedError
from ...providers.helper import get_last_user_message
from ..template.OpenaiTemplate import OpenaiTemplate from ..template.OpenaiTemplate import OpenaiTemplate
from .models import model_aliases, vision_models, default_vision_model from .models import model_aliases, vision_models, default_vision_model
from .HuggingChat import HuggingChat from .HuggingChat import HuggingChat
@@ -22,7 +23,7 @@ class HuggingFaceAPI(OpenaiTemplate):
vision_models = vision_models vision_models = vision_models
model_aliases = model_aliases model_aliases = model_aliases
pipeline_tag: dict[str, str] = {} pipeline_tags: dict[str, str] = {}
@classmethod @classmethod
def get_models(cls, **kwargs): def get_models(cls, **kwargs):
@@ -36,8 +37,8 @@ class HuggingFaceAPI(OpenaiTemplate):
@classmethod @classmethod
async def get_pipline_tag(cls, model: str, api_key: str = None): async def get_pipline_tag(cls, model: str, api_key: str = None):
if model in cls.pipeline_tag: if model in cls.pipeline_tags:
return cls.pipeline_tag[model] return cls.pipeline_tags[model]
async with StreamSession( async with StreamSession(
timeout=30, timeout=30,
headers=cls.get_headers(False, api_key), headers=cls.get_headers(False, api_key),
@@ -45,8 +46,8 @@ class HuggingFaceAPI(OpenaiTemplate):
async with session.get(f"https://huggingface.co/api/models/{model}") as response: async with session.get(f"https://huggingface.co/api/models/{model}") as response:
await raise_for_status(response) await raise_for_status(response)
model_data = await response.json() model_data = await response.json()
cls.pipeline_tag[model] = model_data.get("pipeline_tag") cls.pipeline_tags[model] = model_data.get("pipeline_tag")
return cls.pipeline_tag[model] return cls.pipeline_tags[model]
@classmethod @classmethod
async def create_async_generator( async def create_async_generator(
@@ -73,10 +74,11 @@ class HuggingFaceAPI(OpenaiTemplate):
if len(messages) > 6: if len(messages) > 6:
messages = messages[:3] + messages[-3:] messages = messages[:3] + messages[-3:]
if calculate_lenght(messages) > max_inputs_lenght: if calculate_lenght(messages) > max_inputs_lenght:
last_user_message = [{"role": "user", "content": get_last_user_message(messages)}]
if len(messages) > 2: if len(messages) > 2:
messages = [m for m in messages if m["role"] == "system"] + messages[-1:] messages = [m for m in messages if m["role"] == "system"] + last_user_message
if len(messages) > 1 and calculate_lenght(messages) > max_inputs_lenght: if len(messages) > 1 and calculate_lenght(messages) > max_inputs_lenght:
messages = [messages[-1]] messages = last_user_message
debug.log(f"Messages trimmed from: {start} to: {calculate_lenght(messages)}") debug.log(f"Messages trimmed from: {start} to: {calculate_lenght(messages)}")
async for chunk in super().create_async_generator(model, messages, api_base=api_base, api_key=api_key, max_tokens=max_tokens, images=images, **kwargs): async for chunk in super().create_async_generator(model, messages, api_base=api_base, api_key=api_key, max_tokens=max_tokens, images=images, **kwargs):
yield chunk yield chunk

View File

@@ -7,7 +7,7 @@ import requests
from ...typing import AsyncResult, Messages from ...typing import AsyncResult, Messages
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, format_prompt from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, format_prompt
from ...errors import ModelNotFoundError, ModelNotSupportedError, ResponseError from ...errors import ModelNotSupportedError, ResponseError
from ...requests import StreamSession, raise_for_status from ...requests import StreamSession, raise_for_status
from ...providers.response import FinishReason, ImageResponse from ...providers.response import FinishReason, ImageResponse
from ..helper import format_image_prompt, get_last_user_message from ..helper import format_image_prompt, get_last_user_message
@@ -24,6 +24,8 @@ class HuggingFaceInference(AsyncGeneratorProvider, ProviderModelMixin):
model_aliases = model_aliases model_aliases = model_aliases
image_models = image_models image_models = image_models
model_data: dict[str, dict] = {}
@classmethod @classmethod
def get_models(cls) -> list[str]: def get_models(cls) -> list[str]:
if not cls.models: if not cls.models:
@@ -43,6 +45,17 @@ class HuggingFaceInference(AsyncGeneratorProvider, ProviderModelMixin):
cls.models = models cls.models = models
return cls.models return cls.models
@classmethod
async def get_model_data(cls, session: StreamSession, model: str) -> str:
if model in cls.model_data:
return cls.model_data[model]
async with session.get(f"https://huggingface.co/api/models/{model}") as response:
if response.status == 404:
raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__}")
await raise_for_status(response)
cls.model_data[model] = await response.json()
return cls.model_data[model]
@classmethod @classmethod
async def create_async_generator( async def create_async_generator(
cls, cls,
@@ -96,11 +109,7 @@ class HuggingFaceInference(AsyncGeneratorProvider, ProviderModelMixin):
timeout=600 timeout=600
) as session: ) as session:
if payload is None: if payload is None:
async with session.get(f"https://huggingface.co/api/models/{model}") as response: model_data = await cls.get_model_data(session, model)
if response.status == 404:
raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__}")
await raise_for_status(response)
model_data = await response.json()
pipeline_tag = model_data.get("pipeline_tag") pipeline_tag = model_data.get("pipeline_tag")
if pipeline_tag == "text-to-image": if pipeline_tag == "text-to-image":
stream = False stream = False
@@ -117,7 +126,7 @@ class HuggingFaceInference(AsyncGeneratorProvider, ProviderModelMixin):
if len(messages) > 6: if len(messages) > 6:
messages = messages[:3] + messages[-3:] messages = messages[:3] + messages[-3:]
else: else:
messages = [m for m in messages if m["role"] == "system"] + [get_last_user_message(messages)] messages = [m for m in messages if m["role"] == "system"] + [{"role": "user", "content": get_last_user_message(messages)}]
inputs = get_inputs(messages, model_data, model_type, do_continue) inputs = get_inputs(messages, model_data, model_type, do_continue)
debug.log(f"New len: {len(inputs)}") debug.log(f"New len: {len(inputs)}")
if model_type == "gpt2" and max_tokens >= 1024: if model_type == "gpt2" and max_tokens >= 1024:
@@ -130,7 +139,7 @@ class HuggingFaceInference(AsyncGeneratorProvider, ProviderModelMixin):
async with session.post(f"{api_base.rstrip('/')}/models/{model}", json=payload) as response: async with session.post(f"{api_base.rstrip('/')}/models/{model}", json=payload) as response:
if response.status == 404: if response.status == 404:
raise ModelNotFoundError(f"Model is not supported: {model}") raise ModelNotSupportedError(f"Model is not supported: {model}")
await raise_for_status(response) await raise_for_status(response)
if stream: if stream:
first = True first = True

View File

@@ -36,7 +36,7 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
messages: Messages, messages: Messages,
**kwargs **kwargs
) -> AsyncResult: ) -> AsyncResult:
if "api_key" not in kwargs and "images" not in kwargs and random.random() >= 0.5: if "images" not in kwargs and "deepseek" in model or random.random() >= 0.5:
try: try:
is_started = False is_started = False
async for chunk in HuggingFaceInference.create_async_generator(model, messages, **kwargs): async for chunk in HuggingFaceInference.create_async_generator(model, messages, **kwargs):

View File

@@ -13,7 +13,6 @@ from ...errors import MissingAuthError
from ...requests import get_args_from_nodriver, get_nodriver from ...requests import get_args_from_nodriver, get_nodriver
from ...providers.response import AuthResult, RequestLogin, Reasoning, JsonConversation, FinishReason from ...providers.response import AuthResult, RequestLogin, Reasoning, JsonConversation, FinishReason
from ...typing import AsyncResult, Messages from ...typing import AsyncResult, Messages
from ... import debug
try: try:
from curl_cffi import requests from curl_cffi import requests
from dsk.api import DeepSeekAPI, AuthenticationError, DeepSeekPOW from dsk.api import DeepSeekAPI, AuthenticationError, DeepSeekPOW

View File

@@ -5,3 +5,4 @@ from .OpenaiChat import OpenaiChat
class OpenaiAccount(OpenaiChat): class OpenaiAccount(OpenaiChat):
needs_auth = True needs_auth = True
parent = "OpenaiChat" parent = "OpenaiChat"
use_nodriver = False # Show (Auth) in the model name

View File

@@ -65,7 +65,7 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
prompt: str = None, prompt: str = None,
headers: dict = None, headers: dict = None,
impersonate: str = None, impersonate: str = None,
extra_parameters: list[str] = ["tools", "parallel_tool_calls", "", "reasoning_effort", "logit_bias"], extra_parameters: list[str] = ["tools", "parallel_tool_calls", "tool_choice", "reasoning_effort", "logit_bias"],
extra_data: dict = {}, extra_data: dict = {},
**kwargs **kwargs
) -> AsyncResult: ) -> AsyncResult:

View File

@@ -365,7 +365,7 @@ class Images:
break break
except Exception as e: except Exception as e:
error = e error = e
debug.error(e, name=f"{provider.__name__} {type(e).__name__}") debug.error(f"{provider.__name__} {type(e).__name__}: {e}")
else: else:
response = await self._generate_image_response(provider_handler, provider_name, model, prompt, **kwargs) response = await self._generate_image_response(provider_handler, provider_name, model, prompt, **kwargs)
@@ -460,7 +460,7 @@ class Images:
break break
except Exception as e: except Exception as e:
error = e error = e
debug.error(e, name=f"{provider.__name__} {type(e).__name__}") debug.error(f"{provider.__name__} {type(e).__name__}: {e}")
else: else:
response = await self._generate_image_response(provider_handler, provider_name, model, prompt, **kwargs) response = await self._generate_image_response(provider_handler, provider_name, model, prompt, **kwargs)

View File

@@ -1932,7 +1932,7 @@ const load_provider_option = (input, provider_name) => {
providerSelect.querySelectorAll(`option[data-parent="${provider_name}"]`).forEach( providerSelect.querySelectorAll(`option[data-parent="${provider_name}"]`).forEach(
(el) => el.removeAttribute("disabled") (el) => el.removeAttribute("disabled")
); );
settings.querySelector(`.field:has(#${provider_name}-api_key)`)?.classList.remove("hidden"); //settings.querySelector(`.field:has(#${provider_name}-api_key)`)?.classList.remove("hidden");
} else { } else {
modelSelect.querySelectorAll(`option[data-providers*="${provider_name}"]`).forEach( modelSelect.querySelectorAll(`option[data-providers*="${provider_name}"]`).forEach(
(el) => { (el) => {
@@ -1947,7 +1947,7 @@ const load_provider_option = (input, provider_name) => {
providerSelect.querySelectorAll(`option[data-parent="${provider_name}"]`).forEach( providerSelect.querySelectorAll(`option[data-parent="${provider_name}"]`).forEach(
(el) => el.setAttribute("disabled", "disabled") (el) => el.setAttribute("disabled", "disabled")
); );
settings.querySelector(`.field:has(#${provider_name}-api_key)`)?.classList.add("hidden"); //settings.querySelector(`.field:has(#${provider_name}-api_key)`)?.classList.add("hidden");
} }
}; };
@@ -2039,13 +2039,13 @@ async function on_api() {
if (provider.parent) { if (provider.parent) {
if (!login_urls[provider.parent]) { if (!login_urls[provider.parent]) {
login_urls[provider.parent] = [provider.label, provider.login_url, [provider.name]]; login_urls[provider.parent] = [provider.label, provider.login_url, [provider.name], provider.auth];
} else { } else {
login_urls[provider.parent][2].push(provider.name); login_urls[provider.parent][2].push(provider.name);
} }
} else if (provider.login_url) { } else if (provider.login_url) {
if (!login_urls[provider.name]) { if (!login_urls[provider.name]) {
login_urls[provider.name] = [provider.label, provider.login_url, []]; login_urls[provider.name] = [provider.label, provider.login_url, [], provider.auth];
} else { } else {
login_urls[provider.name][0] = provider.label; login_urls[provider.name][0] = provider.label;
login_urls[provider.name][1] = provider.login_url; login_urls[provider.name][1] = provider.login_url;
@@ -2068,9 +2068,10 @@ async function on_api() {
if (!provider.parent) { if (!provider.parent) {
let option = document.createElement("div"); let option = document.createElement("div");
option.classList.add("provider-item"); option.classList.add("provider-item");
let api_key = appStorage.getItem(`${provider.name}-api_key`);
option.innerHTML = ` option.innerHTML = `
<span class="label">Enable ${provider.label}</span> <span class="label">Enable ${provider.label}</span>
<input id="Provider${provider.name}" type="checkbox" name="Provider${provider.name}" value="${provider.name}" class="provider" checked=""> <input id="Provider${provider.name}" type="checkbox" name="Provider${provider.name}" value="${provider.name}" class="provider" ${'checked="checked"' ? !provider.auth || api_key : ''}/>
<label for="Provider${provider.name}" class="toogle" title="Remove provider from dropdown"></label> <label for="Provider${provider.name}" class="toogle" title="Remove provider from dropdown"></label>
`; `;
option.querySelector("input").addEventListener("change", (event) => load_provider_option(event.target, provider.name)); option.querySelector("input").addEventListener("change", (event) => load_provider_option(event.target, provider.name));
@@ -2102,7 +2103,7 @@ async function on_api() {
`; `;
settings.querySelector(".paper").appendChild(providersListContainer); settings.querySelector(".paper").appendChild(providersListContainer);
for (let [name, [label, login_url, childs]] of Object.entries(login_urls)) { for (let [name, [label, login_url, childs, auth]] of Object.entries(login_urls)) {
if (!login_url && !is_demo) { if (!login_url && !is_demo) {
continue; continue;
} }
@@ -2113,6 +2114,13 @@ async function on_api() {
<label for="${name}-api_key" class="label" title="">${label}:</label> <label for="${name}-api_key" class="label" title="">${label}:</label>
<input type="text" id="${name}-api_key" name="${name}[api_key]" class="${childs}" placeholder="api_key" autocomplete="off"/> <input type="text" id="${name}-api_key" name="${name}[api_key]" class="${childs}" placeholder="api_key" autocomplete="off"/>
` + (login_url ? `<a href="${login_url}" target="_blank" title="Login to ${label}">Get API key</a>` : ""); ` + (login_url ? `<a href="${login_url}" target="_blank" title="Login to ${label}">Get API key</a>` : "");
if (auth) {
providerBox.querySelector("input").addEventListener("input", (event) => {
const input = document.getElementById(`Provider${name}`);
input.checked = !!event.target.value;
load_provider_option(input, name);
});
}
providersListContainer.querySelector(".collapsible-content").appendChild(providerBox); providersListContainer.querySelector(".collapsible-content").appendChild(providerBox);
} }

View File

@@ -143,7 +143,7 @@ class Api:
def decorated_log(text: str, file = None): def decorated_log(text: str, file = None):
debug.logs.append(text) debug.logs.append(text)
if debug.logging: if debug.logging:
debug.log_handler(text, file) debug.log_handler(text, file=file)
debug.log = decorated_log debug.log = decorated_log
proxy = os.environ.get("G4F_PROXY") proxy = os.environ.get("G4F_PROXY")
provider = kwargs.get("provider") provider = kwargs.get("provider")
@@ -187,7 +187,7 @@ class Api:
yield self._format_json("conversation_id", conversation_id) yield self._format_json("conversation_id", conversation_id)
elif isinstance(chunk, Exception): elif isinstance(chunk, Exception):
logger.exception(chunk) logger.exception(chunk)
debug.error(e) debug.error(chunk)
yield self._format_json('message', get_error_message(chunk), error=type(chunk).__name__) yield self._format_json('message', get_error_message(chunk), error=type(chunk).__name__)
elif isinstance(chunk, PreviewResponse): elif isinstance(chunk, PreviewResponse):
yield self._format_json("preview", chunk.to_string()) yield self._format_json("preview", chunk.to_string())

View File

@@ -123,7 +123,7 @@ async def copy_images(
return f"/images/{url_filename}{'?url=' + quote(image) if add_url and not image.startswith('data:') else ''}" return f"/images/{url_filename}{'?url=' + quote(image) if add_url and not image.startswith('data:') else ''}"
except (ClientError, IOError, OSError) as e: except (ClientError, IOError, OSError) as e:
debug.error(f"Image processing failed: {type(e).__name__}: {e}") debug.error(f"Image copying failed: {type(e).__name__}: {e}")
if target_path and os.path.exists(target_path): if target_path and os.path.exists(target_path):
os.unlink(target_path) os.unlink(target_path)
return get_source_url(image, image) return get_source_url(image, image)

View File

@@ -105,7 +105,7 @@ class IterListProvider(BaseRetryProvider):
return return
except Exception as e: except Exception as e:
exceptions[provider.__name__] = e exceptions[provider.__name__] = e
debug.error(name=f"{provider.__name__} {type(e).__name__}: {e}") debug.error(f"{provider.__name__} {type(e).__name__}: {e}")
if started: if started:
raise e raise e
yield e yield e