diff options
author | kqlio67 <kqlio67@users.noreply.github.com> | 2024-11-07 21:05:40 +0100 |
---|---|---|
committer | kqlio67 <kqlio67@users.noreply.github.com> | 2024-11-07 21:05:40 +0100 |
commit | 9fe5ac6780713443356f29e63d0a4d211392bc5b (patch) | |
tree | 026e45c6736409f5e990a9861782e8cb27ad6d76 /g4f/Provider | |
parent | refactor(g4f/Provider/AIUncensored.py): Enhance robustness and add features (diff) | |
download | gpt4free-9fe5ac6780713443356f29e63d0a4d211392bc5b.tar gpt4free-9fe5ac6780713443356f29e63d0a4d211392bc5b.tar.gz gpt4free-9fe5ac6780713443356f29e63d0a4d211392bc5b.tar.bz2 gpt4free-9fe5ac6780713443356f29e63d0a4d211392bc5b.tar.lz gpt4free-9fe5ac6780713443356f29e63d0a4d211392bc5b.tar.xz gpt4free-9fe5ac6780713443356f29e63d0a4d211392bc5b.tar.zst gpt4free-9fe5ac6780713443356f29e63d0a4d211392bc5b.zip |
Diffstat (limited to 'g4f/Provider')
-rw-r--r-- | g4f/Provider/AIUncensored.py | 80 |
1 files changed, 42 insertions, 38 deletions
diff --git a/g4f/Provider/AIUncensored.py b/g4f/Provider/AIUncensored.py index db3aa6cd..c2f0f4b3 100644 --- a/g4f/Provider/AIUncensored.py +++ b/g4f/Provider/AIUncensored.py @@ -2,9 +2,9 @@ from __future__ import annotations import json import random -import logging from aiohttp import ClientSession, ClientError -from typing import List +import asyncio +from itertools import cycle from ..typing import AsyncResult, Messages from .base_provider import AsyncGeneratorProvider, ProviderModelMixin @@ -38,27 +38,9 @@ class AIUncensored(AsyncGeneratorProvider, ProviderModelMixin): @staticmethod def generate_cipher() -> str: + """Generate a cipher in format like '3221229284179118'""" return ''.join([str(random.randint(0, 9)) for _ in range(16)]) - @staticmethod - async def try_request(session: ClientSession, endpoints: List[str], data: dict, proxy: str = None): - available_endpoints = endpoints.copy() - random.shuffle(available_endpoints) - - while available_endpoints: - endpoint = available_endpoints.pop() - try: - async with session.post(endpoint, json=data, proxy=proxy) as response: - response.raise_for_status() - return response - except ClientError as e: - logging.warning(f"Failed to connect to {endpoint}: {str(e)}") - if not available_endpoints: - raise - continue - - raise Exception("All endpoints are unavailable") - @classmethod def get_model(cls, model: str) -> str: if model in cls.models: @@ -103,26 +85,48 @@ class AIUncensored(AsyncGeneratorProvider, ProviderModelMixin): "prompt": prompt, "cipher": cls.generate_cipher() } - response = await cls.try_request(session, cls.api_endpoints_image, data, proxy) - response_data = await response.json() - image_url = response_data['image_url'] - image_response = ImageResponse(images=image_url, alt=prompt) - yield image_response + endpoints = cycle(cls.api_endpoints_image) + + while True: + endpoint = next(endpoints) + try: + async with session.post(endpoint, json=data, proxy=proxy, timeout=10) as response: + response.raise_for_status() + response_data = await response.json() + image_url = response_data['image_url'] + image_response = ImageResponse(images=image_url, alt=prompt) + yield image_response + return + except (ClientError, asyncio.TimeoutError): + continue + elif model in cls.text_models: data = { "messages": messages, "cipher": cls.generate_cipher() } - response = await cls.try_request(session, cls.api_endpoints_text, data, proxy) - async for line in response.content: - line = line.decode('utf-8') - if line.startswith("data: "): - try: - json_str = line[6:] - if json_str != "[DONE]": - data = json.loads(json_str) - if "data" in data: - yield data["data"] - except json.JSONDecodeError: - continue + + endpoints = cycle(cls.api_endpoints_text) + + while True: + endpoint = next(endpoints) + try: + async with session.post(endpoint, json=data, proxy=proxy, timeout=10) as response: + response.raise_for_status() + full_response = "" + async for line in response.content: + line = line.decode('utf-8') + if line.startswith("data: "): + try: + json_str = line[6:] + if json_str != "[DONE]": + data = json.loads(json_str) + if "data" in data: + full_response += data["data"] + yield data["data"] + except json.JSONDecodeError: + continue + return + except (ClientError, asyncio.TimeoutError): + continue |