~ | code styling

This commit is contained in:
abc
2023-08-27 17:37:44 +02:00
parent 5d08c7201f
commit efd75a11b8
33 changed files with 842 additions and 967 deletions

View File

@@ -9,20 +9,19 @@ import math
class BaseProvider(ABC):
url: str
working = False
needs_auth = False
supports_stream = False
working = False
needs_auth = False
supports_stream = False
supports_gpt_35_turbo = False
supports_gpt_4 = False
supports_gpt_4 = False
@staticmethod
@abstractmethod
def create_completion(
model: str,
messages: list[dict[str, str]],
stream: bool,
**kwargs: Any,
) -> CreateResult:
stream: bool, **kwargs: Any) -> CreateResult:
raise NotImplementedError()
@classmethod
@@ -42,8 +41,10 @@ _cookies = {}
def get_cookies(cookie_domain: str) -> dict:
if cookie_domain not in _cookies:
_cookies[cookie_domain] = {}
for cookie in browser_cookie3.load(cookie_domain):
_cookies[cookie_domain][cookie.name] = cookie.value
return _cookies[cookie_domain]
@@ -53,18 +54,15 @@ class AsyncProvider(BaseProvider):
cls,
model: str,
messages: list[dict[str, str]],
stream: bool = False,
**kwargs: Any
) -> CreateResult:
stream: bool = False, **kwargs: Any) -> CreateResult:
yield asyncio.run(cls.create_async(model, messages, **kwargs))
@staticmethod
@abstractmethod
async def create_async(
model: str,
messages: list[dict[str, str]],
**kwargs: Any,
) -> str:
messages: list[dict[str, str]], **kwargs: Any) -> str:
raise NotImplementedError()
@@ -74,9 +72,8 @@ class AsyncGeneratorProvider(AsyncProvider):
cls,
model: str,
messages: list[dict[str, str]],
stream: bool = True,
**kwargs: Any
) -> CreateResult:
stream: bool = True, **kwargs: Any) -> CreateResult:
if stream:
yield from run_generator(cls.create_async_generator(model, messages, **kwargs))
else:
@@ -86,9 +83,8 @@ class AsyncGeneratorProvider(AsyncProvider):
async def create_async(
cls,
model: str,
messages: list[dict[str, str]],
**kwargs: Any,
) -> str:
messages: list[dict[str, str]], **kwargs: Any) -> str:
chunks = [chunk async for chunk in cls.create_async_generator(model, messages, **kwargs)]
if chunks:
return "".join(chunks)
@@ -97,14 +93,14 @@ class AsyncGeneratorProvider(AsyncProvider):
@abstractmethod
def create_async_generator(
model: str,
messages: list[dict[str, str]],
) -> AsyncGenerator:
messages: list[dict[str, str]]) -> AsyncGenerator:
raise NotImplementedError()
def run_generator(generator: AsyncGenerator[Union[Any, str], Any]):
loop = asyncio.new_event_loop()
gen = generator.__aiter__()
gen = generator.__aiter__()
while True:
try: