From c617b18d12c2f9d82ce7c73aae46d353b83f625a Mon Sep 17 00:00:00 2001 From: Heiner Lohaus Date: Mon, 1 Jan 2024 17:48:57 +0100 Subject: Add support for all models Add AbstractProvider class Add ProviderType type Add get_last_provider function Add version module and VersionUtils Display used provider in gui Fix error response in api --- g4f/Provider/base_provider.py | 48 +++++++++++++++++++------------------------ 1 file changed, 21 insertions(+), 27 deletions(-) (limited to 'g4f/Provider/base_provider.py') diff --git a/g4f/Provider/base_provider.py b/g4f/Provider/base_provider.py index 62029f5d..6da7f6c6 100644 --- a/g4f/Provider/base_provider.py +++ b/g4f/Provider/base_provider.py @@ -1,12 +1,14 @@ from __future__ import annotations import sys +import asyncio from asyncio import AbstractEventLoop from concurrent.futures import ThreadPoolExecutor -from abc import ABC, abstractmethod +from abc import abstractmethod from inspect import signature, Parameter from .helper import get_event_loop, get_cookies, format_prompt -from ..typing import CreateResult, AsyncResult, Messages +from ..typing import CreateResult, AsyncResult, Messages, Union +from ..base_provider import BaseProvider if sys.version_info < (3, 10): NoneType = type(None) @@ -20,25 +22,7 @@ if sys.platform == 'win32': ): asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) -class BaseProvider(ABC): - url: str - working: bool = False - needs_auth: bool = False - supports_stream: bool = False - supports_gpt_35_turbo: bool = False - supports_gpt_4: bool = False - supports_message_history: bool = False - - @staticmethod - @abstractmethod - def create_completion( - model: str, - messages: Messages, - stream: bool, - **kwargs - ) -> CreateResult: - raise NotImplementedError() - +class AbstractProvider(BaseProvider): @classmethod async def create_async( cls, @@ -60,9 +44,12 @@ class BaseProvider(ABC): **kwargs )) - return await loop.run_in_executor( - executor, - create_func + return await asyncio.wait_for( + loop.run_in_executor( + executor, + create_func + ), + timeout=kwargs.get("timeout", 0) ) @classmethod @@ -102,16 +89,19 @@ class BaseProvider(ABC): return f"g4f.Provider.{cls.__name__} supports: ({args}\n)" -class AsyncProvider(BaseProvider): +class AsyncProvider(AbstractProvider): @classmethod def create_completion( cls, model: str, messages: Messages, stream: bool = False, + *, + loop: AbstractEventLoop = None, **kwargs ) -> CreateResult: - loop = get_event_loop() + if not loop: + loop = get_event_loop() coro = cls.create_async(model, messages, **kwargs) yield loop.run_until_complete(coro) @@ -134,9 +124,12 @@ class AsyncGeneratorProvider(AsyncProvider): model: str, messages: Messages, stream: bool = True, + *, + loop: AbstractEventLoop = None, **kwargs ) -> CreateResult: - loop = get_event_loop() + if not loop: + loop = get_event_loop() generator = cls.create_async_generator( model, messages, @@ -171,6 +164,7 @@ class AsyncGeneratorProvider(AsyncProvider): def create_async_generator( model: str, messages: Messages, + stream: bool = True, **kwargs ) -> AsyncResult: raise NotImplementedError() -- cgit v1.2.3