From dba41cda5647dc912e581d4bf81c09bb25257aab Mon Sep 17 00:00:00 2001 From: H Lohaus Date: Wed, 20 Nov 2024 09:52:38 +0100 Subject: Fix image generation in OpenaiChat (#2390) * Fix image generation in OpenaiChat * Add PollinationsAI provider with image and text generation --- g4f/Provider/PollinationsAI.py | 69 +++++++++++++++++++++++++++++++++++ g4f/Provider/__init__.py | 1 + g4f/Provider/needs_auth/OpenaiChat.py | 65 +++++++-------------------------- g4f/providers/base_provider.py | 1 + 4 files changed, 84 insertions(+), 52 deletions(-) create mode 100644 g4f/Provider/PollinationsAI.py diff --git a/g4f/Provider/PollinationsAI.py b/g4f/Provider/PollinationsAI.py new file mode 100644 index 00000000..57597bf1 --- /dev/null +++ b/g4f/Provider/PollinationsAI.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +from urllib.parse import quote +import random +import requests +from sys import maxsize +from aiohttp import ClientSession + +from ..typing import AsyncResult, Messages +from ..image import ImageResponse +from ..requests.raise_for_status import raise_for_status +from ..requests.aiohttp import get_connector +from .needs_auth.OpenaiAPI import OpenaiAPI +from .helper import format_prompt + +class PollinationsAI(OpenaiAPI): + label = "Pollinations.AI" + url = "https://pollinations.ai" + working = True + supports_stream = True + default_model = "openai" + + @classmethod + def get_models(cls): + if not cls.image_models: + url = "https://image.pollinations.ai/models" + response = requests.get(url) + raise_for_status(response) + cls.image_models = response.json() + if not cls.models: + url = "https://text.pollinations.ai/models" + response = requests.get(url) + raise_for_status(response) + cls.models = [model.get("name") for model in response.json()] + cls.models.extend(cls.image_models) + return cls.models + + @classmethod + async def create_async_generator( + cls, + model: str, + messages: Messages, + api_base: str = "https://text.pollinations.ai/openai", + api_key: str = None, + proxy: str = None, + seed: str = None, + **kwargs + ) -> AsyncResult: + if model: + model = cls.get_model(model) + if model in cls.image_models: + prompt = messages[-1]["content"] + if seed is None: + seed = random.randint(0, maxsize) + image = f"https://image.pollinations.ai/prompt/{quote(prompt)}?width=1024&height=1024&seed={int(seed)}&nofeed=true&nologo=true&model={quote(model)}" + yield ImageResponse(image, prompt) + return + if api_key is None: + async with ClientSession(connector=get_connector(proxy=proxy)) as session: + prompt = format_prompt(messages) + async with session.get(f"https://text.pollinations.ai/{quote(prompt)}?model={quote(model)}") as response: + await raise_for_status(response) + async for line in response.content.iter_any(): + yield line.decode(errors="ignore") + else: + async for chunk in super().create_async_generator( + model, messages, api_base=api_base, proxy=proxy, **kwargs + ): + yield chunk \ No newline at end of file diff --git a/g4f/Provider/__init__.py b/g4f/Provider/__init__.py index 2083c0ff..378f09c8 100644 --- a/g4f/Provider/__init__.py +++ b/g4f/Provider/__init__.py @@ -32,6 +32,7 @@ from .MagickPen import MagickPen from .PerplexityLabs import PerplexityLabs from .Pi import Pi from .Pizzagpt import Pizzagpt +from .PollinationsAI import PollinationsAI from .Prodia import Prodia from .Reka import Reka from .ReplicateHome import ReplicateHome diff --git a/g4f/Provider/needs_auth/OpenaiChat.py b/g4f/Provider/needs_auth/OpenaiChat.py index 15a87f38..97515ec4 100644 --- a/g4f/Provider/needs_auth/OpenaiChat.py +++ b/g4f/Provider/needs_auth/OpenaiChat.py @@ -65,6 +65,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): default_vision_model = "gpt-4o" fallback_models = ["auto", "gpt-4", "gpt-4o", "gpt-4o-mini", "gpt-4o-canmore", "o1-preview", "o1-mini"] vision_models = fallback_models + image_models = fallback_models _api_key: str = None _headers: dict = None @@ -330,7 +331,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): api_key: str = None, cookies: Cookies = None, auto_continue: bool = False, - history_disabled: bool = True, + history_disabled: bool = False, action: str = "next", conversation_id: str = None, conversation: Conversation = None, @@ -425,12 +426,6 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): f"Arkose: {'False' if not need_arkose else RequestConfig.arkose_token[:12]+'...'}", f"Proofofwork: {'False' if proofofwork is None else proofofwork[:12]+'...'}", )] - ws = None - if need_arkose: - async with session.post(f"{cls.url}/backend-api/register-websocket", headers=cls._headers) as response: - wss_url = (await response.json()).get("wss_url") - if wss_url: - ws = await session.ws_connect(wss_url) data = { "action": action, "messages": None, @@ -474,7 +469,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): await asyncio.sleep(5) continue await raise_for_status(response) - async for chunk in cls.iter_messages_chunk(response.iter_lines(), session, conversation, ws): + async for chunk in cls.iter_messages_chunk(response.iter_lines(), session, conversation): if return_conversation: history_disabled = False return_conversation = False @@ -489,44 +484,16 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): if history_disabled and auto_continue: await cls.delete_conversation(session, cls._headers, conversation.conversation_id) - @staticmethod - async def iter_messages_ws(ws: ClientWebSocketResponse, conversation_id: str, is_curl: bool) -> AsyncIterator: - while True: - if is_curl: - message = json.loads(ws.recv()[0]) - else: - message = await ws.receive_json() - if message["conversation_id"] == conversation_id: - yield base64.b64decode(message["body"]) - @classmethod async def iter_messages_chunk( cls, messages: AsyncIterator, session: StreamSession, fields: Conversation, - ws = None ) -> AsyncIterator: async for message in messages: - if message.startswith(b'{"wss_url":'): - message = json.loads(message) - ws = await session.ws_connect(message["wss_url"]) if ws is None else ws - try: - async for chunk in cls.iter_messages_chunk( - cls.iter_messages_ws(ws, message["conversation_id"], hasattr(ws, "recv")), - session, fields - ): - yield chunk - finally: - await ws.aclose() if hasattr(ws, "aclose") else await ws.close() - break async for chunk in cls.iter_messages_line(session, message, fields): - if fields.finish_reason is not None: - break - else: - yield chunk - if fields.finish_reason is not None: - break + yield chunk @classmethod async def iter_messages_line(cls, session: StreamSession, line: bytes, fields: Conversation) -> AsyncIterator: @@ -542,9 +509,9 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): return if isinstance(line, dict) and "v" in line: v = line.get("v") - if isinstance(v, str): + if isinstance(v, str) and fields.is_recipient: yield v - elif isinstance(v, list): + elif isinstance(v, list) and fields.is_recipient: for m in v: if m.get("p") == "/message/content/parts/0": yield m.get("v") @@ -556,25 +523,20 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): fields.conversation_id = v.get("conversation_id") debug.log(f"OpenaiChat: New conversation: {fields.conversation_id}") m = v.get("message", {}) - if m.get("author", {}).get("role") == "assistant": - fields.message_id = v.get("message", {}).get("id") + fields.is_recipient = m.get("recipient") == "all" + if fields.is_recipient: c = m.get("content", {}) if c.get("content_type") == "multimodal_text": generated_images = [] for element in c.get("parts"): - if isinstance(element, str): - debug.log(f"No image or text: {line}") - elif element.get("content_type") == "image_asset_pointer": + if isinstance(element, dict) and element.get("content_type") == "image_asset_pointer": generated_images.append( cls.get_generated_image(session, cls._headers, element) ) - elif element.get("content_type") == "text": - for part in element.get("parts", []): - yield part for image_response in await asyncio.gather(*generated_images): yield image_response - else: - debug.log(f"OpenaiChat: {line}") + if m.get("author", {}).get("role") == "assistant": + fields.message_id = v.get("message", {}).get("id") return if "error" in line and line.get("error"): raise RuntimeError(line.get("error")) @@ -652,7 +614,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): cls._headers = cls.get_default_headers() if headers is None else headers if user_agent is not None: cls._headers["user-agent"] = user_agent - cls._cookies = {} if cookies is None else {k: v for k, v in cookies.items() if k != "access_token"} + cls._cookies = {} if cookies is None else cookies cls._update_cookie_header() @classmethod @@ -671,8 +633,6 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): @classmethod def _update_cookie_header(cls): cls._headers["cookie"] = format_cookies(cls._cookies) - if "oai-did" in cls._cookies: - cls._headers["oai-device-id"] = cls._cookies["oai-did"] class Conversation(BaseConversation): """ @@ -682,6 +642,7 @@ class Conversation(BaseConversation): self.conversation_id = conversation_id self.message_id = message_id self.finish_reason = finish_reason + self.is_recipient = False class Response(): """ diff --git a/g4f/providers/base_provider.py b/g4f/providers/base_provider.py index 128fb5a0..9fa17fc3 100644 --- a/g4f/providers/base_provider.py +++ b/g4f/providers/base_provider.py @@ -290,6 +290,7 @@ class ProviderModelMixin: default_model: str = None models: list[str] = [] model_aliases: dict[str, str] = {} + image_models: list = None @classmethod def get_models(cls) -> list[str]: -- cgit v1.2.3