Add Feature provider in demo

Support default provider in DDG
Read api_key from config file
This commit is contained in:
hlohaus
2025-01-27 23:33:21 +01:00
parent 17bd3b3ac6
commit 16e5d9ee86
7 changed files with 48 additions and 12 deletions

View File

@@ -239,7 +239,7 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
yield ImageResponse(images=[image_url], alt=prompt) yield ImageResponse(images=[image_url], alt=prompt)
return return
if conversation is None: if conversation is None or not hasattr(conversation, "chat_id"):
conversation = Conversation(model) conversation = Conversation(model)
conversation.validated_value = await cls.fetch_validated() conversation.validated_value = await cls.fetch_validated()
conversation.chat_id = cls.generate_chat_id() conversation.chat_id = cls.generate_chat_id()

View File

@@ -50,6 +50,8 @@ class DDG(AsyncGeneratorProvider, ProviderModelMixin):
@classmethod @classmethod
def validate_model(cls, model: str) -> str: def validate_model(cls, model: str) -> str:
"""Validates and returns the correct model name""" """Validates and returns the correct model name"""
if not model:
return cls.default_model
if model in cls.model_aliases: if model in cls.model_aliases:
model = cls.model_aliases[model] model = cls.model_aliases[model]
if model not in cls.models: if model not in cls.models:

View File

@@ -7,4 +7,8 @@ class Custom(OpenaiTemplate):
working = True working = True
needs_auth = False needs_auth = False
api_base = "http://localhost:8080/v1" api_base = "http://localhost:8080/v1"
sort_models = False sort_models = False
class Feature(Custom):
label = "Feature Provider"
working = False

View File

@@ -3,6 +3,7 @@ from .BingCreateImages import BingCreateImages
from .Cerebras import Cerebras from .Cerebras import Cerebras
from .CopilotAccount import CopilotAccount from .CopilotAccount import CopilotAccount
from .Custom import Custom from .Custom import Custom
from .Custom import Feature
from .DeepInfra import DeepInfra from .DeepInfra import DeepInfra
from .DeepSeek import DeepSeek from .DeepSeek import DeepSeek
from .Gemini import Gemini from .Gemini import Gemini

View File

@@ -937,15 +937,17 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi
} }
try { try {
let api_key; let api_key;
if (is_demo && provider != "Custom") { if (is_demo && provider == "Feature") {
api_key = localStorage.getItem("user");
} else if (is_demo && provider != "Custom") {
api_key = localStorage.getItem("HuggingFace-api_key"); api_key = localStorage.getItem("HuggingFace-api_key");
if (!api_key) {
location.href = "/";
return;
}
} else { } else {
api_key = get_api_key_by_provider(provider); api_key = get_api_key_by_provider(provider);
} }
if (is_demo && !api_key && provider != "Custom") {
location.href = "/";
return;
}
const input = imageInput && imageInput.files.length > 0 ? imageInput : cameraInput; const input = imageInput && imageInput.files.length > 0 ? imageInput : cameraInput;
const files = input && input.files.length > 0 ? input.files : null; const files = input && input.files.length > 0 ? input.files : null;
const download_images = document.getElementById("download_images")?.checked; const download_images = document.getElementById("download_images")?.checked;
@@ -1897,7 +1899,10 @@ async function on_api() {
location.href = "/"; location.href = "/";
return; return;
} }
providerSelect.innerHTML = '<option value="">Demo Mode</option><option value="Custom">Custom Provider</option>'; providerSelect.innerHTML = `
<option value="">Demo Mode</option>
<option value="Feature">Feature Provider</option>
<option value="Custom">Custom Provider</option>`;
providerSelect.selectedIndex = 0; providerSelect.selectedIndex = 0;
document.getElementById("pin").disabled = true; document.getElementById("pin").disabled = true;
document.getElementById("refine")?.parentElement.classList.add("hidden") document.getElementById("refine")?.parentElement.classList.add("hidden")

View File

@@ -134,7 +134,7 @@ class Backend_Api(Api):
else: else:
json_data = request.json json_data = request.json
if app.demo and json_data.get("provider") != "Custom": if app.demo and json_data.get("provider") not in ["Custom", "Feature"]:
model = json_data.get("model") model = json_data.get("model")
if model != "default" and model in models.demo_models: if model != "default" and model in models.demo_models:
json_data["provider"] = random.choice(models.demo_models[model][1]) json_data["provider"] = random.choice(models.demo_models[model][1])

View File

@@ -3,11 +3,14 @@ from __future__ import annotations
import re import re
import json import json
import asyncio import asyncio
from pathlib import Path
from typing import Optional, Callable, AsyncIterator from typing import Optional, Callable, AsyncIterator
from ..typing import Messages from ..typing import Messages
from ..providers.helper import filter_none from ..providers.helper import filter_none
from ..providers.asyncio import to_async_iterator from ..providers.asyncio import to_async_iterator
from ..providers.types import ProviderType
from ..cookies import get_cookies_dir
from .web_search import do_search, get_search_message from .web_search import do_search, get_search_message
from .files import read_bucket, get_bucket_dir from .files import read_bucket, get_bucket_dir
from .. import debug from .. import debug
@@ -27,7 +30,10 @@ def validate_arguments(data: dict) -> dict:
else: else:
return {} return {}
async def async_iter_run_tools(async_iter_callback, model, messages, tool_calls: Optional[list] = None, **kwargs): def get_api_key_file(cls) -> Path:
return Path(get_cookies_dir()) / f"api_key_{cls.parent if hasattr(cls, 'parent') else cls.__name__}.json"
async def async_iter_run_tools(provider: ProviderType, model: str, messages, tool_calls: Optional[list] = None, **kwargs):
# Handle web_search from kwargs # Handle web_search from kwargs
web_search = kwargs.get('web_search') web_search = kwargs.get('web_search')
if web_search: if web_search:
@@ -40,6 +46,15 @@ async def async_iter_run_tools(async_iter_callback, model, messages, tool_calls:
# Keep web_search in kwargs for provider native support # Keep web_search in kwargs for provider native support
pass pass
# Read api_key from config file
if provider.needs_auth and "api_key" not in kwargs:
auth_file = get_api_key_file(provider)
if auth_file.exists():
with auth_file.open("r") as f:
auth_result = json.load(f)
if "api_key" in auth_result:
kwargs["api_key"] = auth_result["api_key"]
if tool_calls is not None: if tool_calls is not None:
for tool in tool_calls: for tool in tool_calls:
if tool.get("type") == "function": if tool.get("type") == "function":
@@ -66,8 +81,8 @@ async def async_iter_run_tools(async_iter_callback, model, messages, tool_calls:
message["content"] = new_message_content message["content"] = new_message_content
if has_bucket and isinstance(messages[-1]["content"], str): if has_bucket and isinstance(messages[-1]["content"], str):
messages[-1]["content"] += BUCKET_INSTRUCTIONS messages[-1]["content"] += BUCKET_INSTRUCTIONS
create_function = provider.get_async_create_function()
response = to_async_iterator(async_iter_callback(model=model, messages=messages, **kwargs)) response = to_async_iterator(create_function(model=model, messages=messages, **kwargs))
async for chunk in response: async for chunk in response:
yield chunk yield chunk
@@ -91,6 +106,15 @@ def iter_run_tools(
# Keep web_search in kwargs for provider native support # Keep web_search in kwargs for provider native support
pass pass
# Read api_key from config file
if provider is not None and provider.needs_auth and "api_key" not in kwargs:
auth_file = get_api_key_file(provider)
if auth_file.exists():
with auth_file.open("r") as f:
auth_result = json.load(f)
if "api_key" in auth_result:
kwargs["api_key"] = auth_result["api_key"]
if tool_calls is not None: if tool_calls is not None:
for tool in tool_calls: for tool in tool_calls:
if tool.get("type") == "function": if tool.get("type") == "function":