From c84ff591457647bdd0f5a5620505580d3b615540 Mon Sep 17 00:00:00 2001 From: hs_junxiang Date: Fri, 13 Oct 2023 13:45:29 +0800 Subject: feat: ignore providers(#1014) --- g4f/__init__.py | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) (limited to 'g4f/__init__.py') diff --git a/g4f/__init__.py b/g4f/__init__.py index 1a696c6c..6f777e4c 100644 --- a/g4f/__init__.py +++ b/g4f/__init__.py @@ -1,13 +1,14 @@ from __future__ import annotations from requests import get from g4f.models import Model, ModelUtils -from .Provider import BaseProvider -from .typing import Messages, CreateResult, Union +from .Provider import BaseProvider, RetryProvider +from .typing import Messages, CreateResult, Union, List from .debug import logging version = '0.1.6.2' version_check = True + def check_pypi_version() -> None: try: response = get("https://pypi.org/pypi/g4f/json").json() @@ -19,9 +20,11 @@ def check_pypi_version() -> None: except Exception as e: print(f'Failed to check g4f pypi version: {e}') + def get_model_and_provider(model : Union[Model, str], provider : Union[type[BaseProvider], None], - stream : bool) -> tuple[Model, type[BaseProvider]]: + stream : bool, + ignored : List[str] = None) -> tuple[Model, type[BaseProvider]]: if isinstance(model, str): if model in ModelUtils.convert: @@ -32,6 +35,9 @@ def get_model_and_provider(model : Union[Model, str], if not provider: provider = model.best_provider + if isinstance(provider, RetryProvider) and ignored: + provider.providers = [p for p in provider.providers if p.__name__ not in ignored] + if not provider: raise RuntimeError(f'No provider found for model: {model}') @@ -46,15 +52,17 @@ def get_model_and_provider(model : Union[Model, str], return model, provider + class ChatCompletion: @staticmethod def create(model: Union[Model, str], messages : Messages, provider : Union[type[BaseProvider], None] = None, stream : bool = False, - auth : Union[str, None] = None, **kwargs) -> Union[CreateResult, str]: + auth : Union[str, None] = None, + ignored : List[str] = None, **kwargs) -> Union[CreateResult, str]: - model, provider = get_model_and_provider(model, provider, stream) + model, provider = get_model_and_provider(model, provider, stream, ignored) if provider.needs_auth and not auth: raise ValueError( @@ -71,15 +79,17 @@ class ChatCompletion: model : Union[Model, str], messages: Messages, provider: Union[type[BaseProvider], None] = None, - stream : bool = False, **kwargs) -> str: + stream : bool = False, + ignored : List[str] = None, **kwargs) -> str: if stream: raise ValueError(f'"create_async" does not support "stream" argument') - model, provider = get_model_and_provider(model, provider, False) + model, provider = get_model_and_provider(model, provider, False, ignored) return await provider.create_async(model.name, messages, **kwargs) + class Completion: @staticmethod def create( @@ -87,6 +97,7 @@ class Completion: prompt: str, provider: Union[type[BaseProvider], None] = None, stream: bool = False, + ignored : List[str] = None, **kwargs ) -> Union[CreateResult, str]: @@ -102,7 +113,7 @@ class Completion: if model not in allowed_models: raise Exception(f'ValueError: Can\'t use {model} with Completion.create()') - model, provider = get_model_and_provider(model, provider, stream) + model, provider = get_model_and_provider(model, provider, stream, ignored) result = provider.create_completion(model.name, [{"role": "user", "content": prompt}], stream, **kwargs) -- cgit v1.2.3