diff options
Diffstat (limited to 'g4f/Provider/needs_auth/OpenaiChat.py')
-rw-r--r-- | g4f/Provider/needs_auth/OpenaiChat.py | 26 |
1 files changed, 15 insertions, 11 deletions
diff --git a/g4f/Provider/needs_auth/OpenaiChat.py b/g4f/Provider/needs_auth/OpenaiChat.py index 074c9161..37bdf074 100644 --- a/g4f/Provider/needs_auth/OpenaiChat.py +++ b/g4f/Provider/needs_auth/OpenaiChat.py @@ -111,7 +111,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): # Post the image data to the service and get the image data async with session.post(f"{cls.url}/backend-api/files", json=data, headers=headers) as response: cls._update_request_args(session) - await raise_for_status(response) + await raise_for_status(response, "Create file failed") image_data = { **data, **await response.json(), @@ -129,7 +129,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): "x-ms-blob-type": "BlockBlob" } ) as response: - await raise_for_status(response) + await raise_for_status(response, "Send file failed") # Post the file ID to the service and get the download URL async with session.post( f"{cls.url}/backend-api/files/{image_data['file_id']}/uploaded", @@ -137,12 +137,12 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): headers=headers ) as response: cls._update_request_args(session) - await raise_for_status(response) + await raise_for_status(response, "Get download url failed") image_data["download_url"] = (await response.json())["download_url"] return ImageRequest(image_data) @classmethod - def create_messages(cls, messages: Messages, image_request: ImageRequest = None): + def create_messages(cls, messages: Messages, image_request: ImageRequest = None, system_hints: list = None): """ Create a list of messages for the user input @@ -160,7 +160,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): "id": str(uuid.uuid4()), "create_time": int(time.time()), "id": str(uuid.uuid4()), - "metadata": {"serialization_metadata": {"custom_symbol_offsets": []}} + "metadata": {"serialization_metadata": {"custom_symbol_offsets": []}, "system_hints": system_hints}, } for message in messages] # Check if there is an image response @@ -189,7 +189,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): return messages @classmethod - async def get_generated_image(cls, session: StreamSession, headers: dict, element: dict) -> ImageResponse: + async def get_generated_image(cls, session: StreamSession, headers: dict, element: dict, prompt: str = None) -> ImageResponse: """ Retrieves the image response based on the message content. @@ -211,6 +211,8 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): try: prompt = element["metadata"]["dalle"]["prompt"] file_id = element["asset_pointer"].split("file-service://", 1)[1] + except TypeError: + return except Exception as e: raise RuntimeError(f"No Image: {e.__class__.__name__}: {e}") try: @@ -240,6 +242,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): image_name: str = None, return_conversation: bool = False, max_retries: int = 3, + web_search: bool = False, **kwargs ) -> AsyncResult: """ @@ -331,14 +334,15 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): "conversation_mode": {"kind":"primary_assistant"}, "websocket_request_id": str(uuid.uuid4()), "supported_encodings": ["v1"], - "supports_buffering": True + "supports_buffering": True, + "system_hints": ["search"] if web_search else None } if conversation.conversation_id is not None: data["conversation_id"] = conversation.conversation_id debug.log(f"OpenaiChat: Use conversation: {conversation.conversation_id}") if action != "continue": messages = messages if conversation_id is None else [messages[-1]] - data["messages"] = cls.create_messages(messages, image_request) + data["messages"] = cls.create_messages(messages, image_request, ["search"] if web_search else None) headers = { **cls._headers, "accept": "text/event-stream", @@ -419,9 +423,9 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): generated_images = [] for element in c.get("parts"): if isinstance(element, dict) and element.get("content_type") == "image_asset_pointer": - generated_images.append( - cls.get_generated_image(session, cls._headers, element) - ) + image = cls.get_generated_image(session, cls._headers, element) + if image is not None: + generated_images.append(image) for image_response in await asyncio.gather(*generated_images): yield image_response if m.get("author", {}).get("role") == "assistant": |