From eae317a1665f94ada0c34875a0aec09df89b138b Mon Sep 17 00:00:00 2001
From: H Lohaus
Date: Thu, 21 Nov 2024 05:00:08 +0100
Subject: Support synthesize in Openai generator (#2394)

* Improve download of generated images, serve images in the api
* Add support for conversation handling in the api
* Add original prompt to image response
* Add download images option in gui, fix loading model list in Airforce
* Support speech synthesize in Openai generator
---
 g4f/Provider/needs_auth/OpenaiChat.py |  64 +++++++++----
 g4f/client/__init__.py                |   6 +-
 g4f/gui/client/static/css/style.css   |  24 ++++-
 g4f/gui/client/static/js/chat.v1.js   | 169 ++++++++++++++++++++++++----------
 g4f/gui/server/api.py                 |   4 +-
 g4f/gui/server/backend.py             |  51 +++++++++-
 g4f/providers/base_provider.py        |   4 +-
 g4f/providers/response.py             |  13 +++
 8 files changed, 259 insertions(+), 76 deletions(-)

diff --git a/g4f/Provider/needs_auth/OpenaiChat.py b/g4f/Provider/needs_auth/OpenaiChat.py
index 587c0a23..797455fe 100644
--- a/g4f/Provider/needs_auth/OpenaiChat.py
+++ b/g4f/Provider/needs_auth/OpenaiChat.py
@@ -7,7 +7,6 @@ import json
 import base64
 import time
 import requests
-from aiohttp import ClientWebSocketResponse
 from copy import copy
 
 try:
@@ -28,7 +27,7 @@ from ...requests.raise_for_status import raise_for_status
 from ...requests.aiohttp import StreamSession
 from ...image import ImageResponse, ImageRequest, to_image, to_bytes, is_accepted_format
 from ...errors import MissingAuthError, ResponseError
-from ...providers.response import BaseConversation
+from ...providers.response import BaseConversation, FinishReason, SynthesizeData
 from ..helper import format_cookies
 from ..openai.har_file import get_request_config, NoValidHarFileError
 from ..openai.har_file import RequestConfig, arkReq, arkose_url, start_url, conversation_url, backend_url, backend_anon_url
@@ -367,19 +366,13 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
         Raises:
             RuntimeError: If an error occurs during processing.
""" + await cls.login(proxy) + async with StreamSession( proxy=proxy, impersonate="chrome", timeout=timeout ) as session: - if cls._expires is not None and cls._expires < time.time(): - cls._headers = cls._api_key = None - try: - await get_request_config(proxy) - cls._create_request_args(RequestConfig.cookies, RequestConfig.headers) - cls._set_api_key(RequestConfig.access_token) - except NoValidHarFileError as e: - await cls.nodriver_auth(proxy) try: image_request = await cls.upload_image(session, cls._headers, image, image_name) if image else None except Exception as e: @@ -469,18 +462,25 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): await asyncio.sleep(5) continue await raise_for_status(response) - async for chunk in cls.iter_messages_chunk(response.iter_lines(), session, conversation): - if return_conversation: - history_disabled = False - return_conversation = False - yield conversation - yield chunk + if return_conversation: + history_disabled = False + yield conversation + async for line in response.iter_lines(): + async for chunk in cls.iter_messages_line(session, line, conversation): + yield chunk + if not history_disabled: + yield SynthesizeData(cls.__name__, { + "conversation_id": conversation.conversation_id, + "message_id": conversation.message_id, + "voice": "maple", + }) if auto_continue and conversation.finish_reason == "max_tokens": conversation.finish_reason = None action = "continue" await asyncio.sleep(5) else: break + yield FinishReason(conversation.finish_reason) if history_disabled and auto_continue: await cls.delete_conversation(session, cls._headers, conversation.conversation_id) @@ -541,10 +541,38 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): if "error" in line and line.get("error"): raise RuntimeError(line.get("error")) + @classmethod + async def synthesize(cls, params: dict) -> AsyncIterator[bytes]: + await cls.login() + async with StreamSession( + impersonate="chrome", + timeout=900 + ) as session: + async with session.get( + f"{cls.url}/backend-api/synthesize", + params=params, + headers=cls._headers + ) as response: + await raise_for_status(response) + async for chunk in response.iter_content(): + yield chunk + + @classmethod + async def login(cls, proxy: str = None): + if cls._expires is not None and cls._expires < time.time(): + cls._headers = cls._api_key = None + try: + await get_request_config(proxy) + cls._create_request_args(RequestConfig.cookies, RequestConfig.headers) + cls._set_api_key(RequestConfig.access_token) + except NoValidHarFileError: + if has_nodriver: + await cls.nodriver_auth(proxy) + else: + raise + @classmethod async def nodriver_auth(cls, proxy: str = None): - if not has_nodriver: - return if has_platformdirs: user_data_dir = user_config_dir("g4f-nodriver") else: diff --git a/g4f/client/__init__.py b/g4f/client/__init__.py index dea19a60..38269cf6 100644 --- a/g4f/client/__init__.py +++ b/g4f/client/__init__.py @@ -13,7 +13,7 @@ from ..providers.base_provider import AsyncGeneratorProvider from ..image import ImageResponse, copy_images, images_dir from ..typing import Messages, Image, ImageType from ..providers.types import ProviderType -from ..providers.response import ResponseType, FinishReason, BaseConversation +from ..providers.response import ResponseType, FinishReason, BaseConversation, SynthesizeData from ..errors import NoImageResponseError, ModelNotFoundError from ..providers.retry_provider import IterListProvider from ..providers.base_provider import get_running_loop @@ -60,6 +60,8 @@ def 
iter_response( elif isinstance(chunk, BaseConversation): yield chunk continue + elif isinstance(chunk, SynthesizeData): + continue chunk = str(chunk) content += chunk @@ -121,6 +123,8 @@ async def async_iter_response( elif isinstance(chunk, BaseConversation): yield chunk continue + elif isinstance(chunk, SynthesizeData): + continue chunk = str(chunk) content += chunk diff --git a/g4f/gui/client/static/css/style.css b/g4f/gui/client/static/css/style.css index c4b61d87..b7ec00b9 100644 --- a/g4f/gui/client/static/css/style.css +++ b/g4f/gui/client/static/css/style.css @@ -259,7 +259,6 @@ body { flex-direction: column; gap: var(--section-gap); padding: var(--inner-gap) var(--section-gap); - padding-bottom: 0; } .message.print { @@ -271,7 +270,11 @@ body { } .message.regenerate { - opacity: 0.75; + background-color: var(--colour-6); +} + +.white .message.regenerate { + background-color: var(--colour-4); } .message:last-child { @@ -407,6 +410,7 @@ body { .message .count .fa-clipboard.clicked, .message .count .fa-print.clicked, +.message .count .fa-rotate.clicked, .message .count .fa-volume-high.active { color: var(--accent); } @@ -430,6 +434,15 @@ body { font-size: 12px; } +.message audio { + display: none; + max-width: 400px; +} + +.message audio.show { + display: block; +} + .count_total { font-size: 12px; padding-left: 25px; @@ -1159,7 +1172,10 @@ a:-webkit-any-link { .message .user { display: none; } - .message.regenerate { - opacity: 1; + body { + height: auto; + } + .box { + backdrop-filter: none; } } diff --git a/g4f/gui/client/static/js/chat.v1.js b/g4f/gui/client/static/js/chat.v1.js index 0136f9c4..73c0de0f 100644 --- a/g4f/gui/client/static/js/chat.v1.js +++ b/g4f/gui/client/static/js/chat.v1.js @@ -28,6 +28,7 @@ let message_storage = {}; let controller_storage = {}; let content_storage = {}; let error_storage = {}; +let synthesize_storage = {}; messageInput.addEventListener("blur", () => { window.scrollTo(0, 0); @@ -134,6 +135,13 @@ const register_message_buttons = async () => { if (!("click" in el.dataset)) { el.dataset.click = "true"; el.addEventListener("click", async () => { + const content_el = el.parentElement.parentElement; + const audio = content_el.querySelector("audio"); + if (audio) { + audio.classList.add("show"); + audio.play(); + return; + } let playlist = []; function play_next() { const next = playlist.shift(); @@ -155,7 +163,6 @@ const register_message_buttons = async () => { el.dataset.running = true; el.classList.add("blink") el.classList.add("active") - const content_el = el.parentElement.parentElement; const message_el = content_el.parentElement; let speechText = await get_message(window.conversation_id, message_el.dataset.index); @@ -215,8 +222,8 @@ const register_message_buttons = async () => { const message_el = el.parentElement.parentElement.parentElement; el.classList.add("clicked"); setTimeout(() => el.classList.remove("clicked"), 1000); - await ask_gpt(message_el.dataset.index, get_message_id()); - }) + await ask_gpt(get_message_id(), message_el.dataset.index); + }); } }); document.querySelectorAll(".message .fa-whatsapp").forEach(async (el) => { @@ -301,25 +308,29 @@ const handle_ask = async () => { + `; highlight(message_box); - stop_generating.classList.remove("stop_generating-hidden"); - await ask_gpt(-1, message_id); + await ask_gpt(message_id); }; -async function remove_cancel_button() { +async function safe_remove_cancel_button() { + for (let key in controller_storage) { + if (!controller_storage[key].signal.aborted) { + return; + } + } 
     stop_generating.classList.add("stop_generating-hidden");
 }
 
 regenerate.addEventListener("click", async () => {
     regenerate.classList.add("regenerate-hidden");
     setTimeout(()=>regenerate.classList.remove("regenerate-hidden"), 3000);
-    stop_generating.classList.remove("stop_generating-hidden");
     await hide_message(window.conversation_id);
-    await ask_gpt(-1, get_message_id());
+    await ask_gpt(get_message_id());
 });
 
 stop_generating.addEventListener("click", async () => {
@@ -337,21 +348,21 @@ stop_generating.addEventListener("click", async () => {
             }
         }
     }
-    await load_conversation(window.conversation_id);
+    await load_conversation(window.conversation_id, false);
 });
 
 const prepare_messages = (messages, message_index = -1) => {
+    if (message_index >= 0) {
+        messages = messages.filter((_, index) => message_index >= index);
+    }
+
     // Removes none user messages at end
-    if (message_index == -1) {
-        let last_message;
-        while (last_message = messages.pop()) {
-            if (last_message["role"] == "user") {
-                messages.push(last_message);
-                break;
-            }
+    let last_message;
+    while (last_message = messages.pop()) {
+        if (last_message["role"] == "user") {
+            messages.push(last_message);
+            break;
         }
-    } else if (message_index >= 0) {
-        messages = messages.filter((_, index) => message_index >= index);
     }
 
     let new_messages = [];
@@ -377,9 +388,11 @@ const prepare_messages = (messages, message_index = -1) => {
             // Remove generated images from history
             new_message.content = filter_message(new_message.content);
             delete new_message.provider;
+            delete new_message.synthesize;
             new_messages.push(new_message)
         }
     });
+
     return new_messages;
 }
@@ -427,6 +440,8 @@ async function add_message_chunk(message, message_id) {
         let p = document.createElement("p");
         p.innerText = message.log;
         log_storage.appendChild(p);
+    } else if (message.type == "synthesize") {
+        synthesize_storage[message_id] = message.synthesize;
     }
     let scroll_down = ()=>{
         if (message_box.scrollTop >= message_box.scrollHeight - message_box.clientHeight - 100) {
@@ -434,8 +449,10 @@ async function add_message_chunk(message, message_id) {
             message_box.scrollTo({ top: message_box.scrollHeight, behavior: "auto" });
         }
     }
-    setTimeout(scroll_down, 200);
-    setTimeout(scroll_down, 1000);
+    if (!content_map.container.classList.contains("regenerate")) {
+        scroll_down();
+        setTimeout(scroll_down, 200);
+    }
 }
 
 cameraInput?.addEventListener("click", (e) => {
@@ -452,45 +469,58 @@ imageInput?.addEventListener("click", (e) => {
     }
 });
 
-const ask_gpt = async (message_index = -1, message_id) => {
+const ask_gpt = async (message_id, message_index = -1) => {
     let messages = await get_messages(window.conversation_id);
-    let total_messages = messages.length;
     messages = prepare_messages(messages, message_index);
-    message_index = total_messages
     message_storage[message_id] = "";
-    stop_generating.classList.remove(".stop_generating-hidden");
+    stop_generating.classList.remove("stop_generating-hidden");
 
-    message_box.scrollTop = message_box.scrollHeight;
-    window.scrollTo(0, 0);
+    if (message_index == -1) {
+        await scroll_to_bottom();
+    }
 
     let count_total = message_box.querySelector('.count_total');
     count_total ? count_total.parentElement.removeChild(count_total) : null;
 
-    message_box.innerHTML += `
-        <div class="message" data-index="${message_index}">
-            <div class="assistant">
-                ${gpt_image} <i class="fa-solid fa-xmark"></i>
-                <i class="fa-regular fa-phone-arrow-down-left"></i>
-            </div>
-            <div class="content" id="gpt_${message_id}">
-                <div class="provider"></div>
-                <div class="content_inner"><span class="cursor"></span></div>
-                <div class="count"></div>
-            </div>
-        </div>
+    const message_el = document.createElement("div");
+    message_el.classList.add("message");
+    if (message_index != -1) {
+        message_el.classList.add("regenerate");
+    }
+    message_el.innerHTML += `
+        <div class="assistant">
+            ${gpt_image} <i class="fa-solid fa-xmark"></i>
+            <i class="fa-regular fa-phone-arrow-down-left"></i>
+        </div>
+        <div class="content" id="gpt_${message_id}">
+            <div class="provider"></div>
+            <div class="content_inner"><span class="cursor"></span></div>
+            <div class="count"></div>
+        </div>
     `;
+    if (message_index == -1) {
+        message_box.appendChild(message_el);
+    } else {
+        parent_message = message_box.querySelector(`.message[data-index="${message_index}"]`);
+        if (!parent_message) {
+            return;
+        }
+        parent_message.after(message_el);
+    }
 
     controller_storage[message_id] = new AbortController();
 
     let content_el = document.getElementById(`gpt_${message_id}`)
     let content_map = content_storage[message_id] = {
+        container: message_el,
         content: content_el,
         inner: content_el.querySelector('.content_inner'),
         count: content_el.querySelector('.count'),
     }
-
-    await scroll_to_bottom();
+    if (message_index == -1) {
+        await scroll_to_bottom();
+    }
     try {
         const input = imageInput && imageInput.files.length > 0 ? imageInput : cameraInput;
         const file = input && input.files.length > 0 ? input.files[0] : null;
@@ -527,14 +557,23 @@ const ask_gpt = async (message_index = -1, message_id) => {
     delete controller_storage[message_id];
     if (!error_storage[message_id] && message_storage[message_id]) {
         const message_provider = message_id in provider_storage ? provider_storage[message_id] : null;
-        await add_message(window.conversation_id, "assistant", message_storage[message_id], message_provider);
-        await safe_load_conversation(window.conversation_id);
+        await add_message(
+            window.conversation_id,
+            "assistant",
+            message_storage[message_id],
+            message_provider,
+            message_index,
+            synthesize_storage[message_id]
+        );
+        await safe_load_conversation(window.conversation_id, message_index == -1);
     } else {
-        let cursorDiv = message_box.querySelector(".cursor");
+        let cursorDiv = message_el.querySelector(".cursor");
        if (cursorDiv) cursorDiv.parentNode.removeChild(cursorDiv);
     }
-    await scroll_to_bottom();
-    await remove_cancel_button();
+    if (message_index == -1) {
+        await scroll_to_bottom();
+    }
+    await safe_remove_cancel_button();
     await register_message_buttons();
     await load_conversations();
     regenerate.classList.remove("regenerate-hidden");
@@ -687,6 +726,15 @@ const load_conversation = async (conversation_id, scroll=true) => {
                 ${item.provider.model ? ' with ' + item.provider.model : ''}
             </div>
         ` : "";
+        let audio = "";
+        if (item.synthesize) {
+            const synthesize_params = (new URLSearchParams(item.synthesize.data)).toString();
+            audio = `
+                <audio controls preload="none">
+                    <source src="/backend-api/v2/synthesize/${item.synthesize.provider}?${synthesize_params}" type="audio/mpeg">
+                </audio>
+            `;
+        }
         elements += `
             <div class="message${item.regenerate ? " regenerate": ""}" data-index="${i}">
                 <div class="${item.role}">
@@ -700,12 +748,14 @@ const load_conversation = async (conversation_id, scroll=true) => {
                 <div class="content">
                     ${provider}
                     <div class="content_inner">${markdown_render(item.content)}</div>
+                    ${audio}
                     <div class="count">
                         ${count_words_and_tokens(item.content, next_provider?.model)}
+                        <i class="fa-solid fa-rotate"></i>
                         <i class="fa-solid fa-volume-high"></i>
                         <i class="fa-solid fa-print"></i>
                         <i class="fa-regular fa-clipboard"></i>
                         <a><i class="fa-brands fa-whatsapp"></i></a>
                     </div>
                 </div>
             </div>
@@ -830,14 +880,35 @@ const get_message = async (conversation_id, index) => { return messages[index]["content"]; }; -const add_message = async (conversation_id, role, content, provider) => { +const add_message = async ( + conversation_id, role, content, + provider = null, + message_index = -1, + synthesize_data = null +) => { const conversation = await get_conversation(conversation_id); if (!conversation) return; - conversation.items.push({ + const new_message = { role: role, content: content, - provider: provider - }); + provider: provider, + }; + if (synthesize_data) { + new_message.synthesize = synthesize_data; + } + if (message_index == -1) { + conversation.items.push(new_message); + } else { + const new_messages = []; + conversation.items.forEach((item, index)=>{ + new_messages.push(item); + if (index == message_index) { + new_message.regenerate = true; + new_messages.push(new_message); + } + }); + conversation.items = new_messages; + } await save_conversation(conversation_id, conversation); return conversation.items.length - 1; }; diff --git a/g4f/gui/server/api.py b/g4f/gui/server/api.py index 00eb7182..0c32bea5 100644 --- a/g4f/gui/server/api.py +++ b/g4f/gui/server/api.py @@ -13,7 +13,7 @@ from g4f.errors import VersionNotFoundError from g4f.image import ImagePreview, ImageResponse, copy_images, ensure_images_dir, images_dir from g4f.Provider import ProviderType, __providers__, __map__ from g4f.providers.base_provider import ProviderModelMixin -from g4f.providers.response import BaseConversation, FinishReason +from g4f.providers.response import BaseConversation, FinishReason, SynthesizeData from g4f.client.service import convert_to_provider from g4f import debug @@ -177,6 +177,8 @@ class Api: images = asyncio.run(copy_images(chunk.get_list(), chunk.options.get("cookies"))) images = ImageResponse(images, chunk.alt) yield self._format_json("content", str(images)) + elif isinstance(chunk, SynthesizeData): + yield self._format_json("synthesize", chunk.to_json()) elif not isinstance(chunk, FinishReason): yield self._format_json("content", str(chunk)) if debug.logs: diff --git a/g4f/gui/server/backend.py b/g4f/gui/server/backend.py index 917d779e..87da49e1 100644 --- a/g4f/gui/server/backend.py +++ b/g4f/gui/server/backend.py @@ -1,8 +1,36 @@ import json +import asyncio +import flask from flask import request, Flask +from typing import AsyncGenerator, Generator + from g4f.image import is_allowed_extension, to_image +from g4f.client.service import convert_to_provider +from g4f.errors import ProviderNotFoundError from .api import Api +def safe_iter_generator(generator: Generator) -> Generator: + start = next(generator) + def iter_generator(): + yield start + yield from generator + return iter_generator() + +def to_sync_generator(gen: AsyncGenerator) -> Generator: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + gen = gen.__aiter__() + async def get_next(): + try: + obj = await gen.__anext__() + return False, obj + except StopAsyncIteration: return True, None + while True: + done, obj = loop.run_until_complete(get_next()) + if done: + break + yield obj + class Backend_Api(Api): """ Handles various endpoints in a Flask application for backend operations. 
@@ -47,6 +75,10 @@ class Backend_Api(Api):
                 'function': self.handle_conversation,
                 'methods': ['POST']
             },
+            '/backend-api/v2/synthesize/<provider>': {
+                'function': self.handle_synthesize,
+                'methods': ['GET']
+            },
             '/backend-api/v2/error': {
                 'function': self.handle_error,
                 'methods': ['POST']
@@ -98,11 +130,28 @@ class Backend_Api(Api):
             mimetype='text/event-stream'
         )
 
+    def handle_synthesize(self, provider: str):
+        try:
+            provider_handler = convert_to_provider(provider)
+        except ProviderNotFoundError:
+            return "Provider not found", 404
+        if not hasattr(provider_handler, "synthesize"):
+            return "Provider doesn't support synthesize", 500
+        try:
+            response_generator = provider_handler.synthesize({**request.args})
+            if hasattr(response_generator, "__aiter__"):
+                response_generator = to_sync_generator(response_generator)
+            response = flask.Response(safe_iter_generator(response_generator), content_type="audio/mpeg")
+            response.headers['Cache-Control'] = "max-age=604800"
+            return response
+        except Exception as e:
+            return f"{e.__class__.__name__}: {e}", 500
+
     def get_provider_models(self, provider: str):
         api_key = None if request.authorization is None else request.authorization.token
         models = super().get_provider_models(provider, api_key)
         if models is None:
-            return 404, "Provider not found"
+            return "Provider not found", 404
         return models
 
     def _format_json(self, response_type: str, content) -> str:
diff --git a/g4f/providers/base_provider.py b/g4f/providers/base_provider.py
index b6df48e8..c6d0d950 100644
--- a/g4f/providers/base_provider.py
+++ b/g4f/providers/base_provider.py
@@ -11,7 +11,7 @@ from typing import Callable, Union
 
 from ..typing import CreateResult, AsyncResult, Messages
 from .types import BaseProvider
-from .response import FinishReason, BaseConversation
+from .response import FinishReason, BaseConversation, SynthesizeData
 from ..errors import NestAsyncioError, ModelNotSupportedError
 from .. import debug
 
@@ -259,7 +259,7 @@ class AsyncGeneratorProvider(AsyncProvider):
         """
         return "".join([
             str(chunk) async for chunk in cls.create_async_generator(model, messages, stream=False, **kwargs)
-            if not isinstance(chunk, (Exception, FinishReason, BaseConversation))
+            if not isinstance(chunk, (Exception, FinishReason, BaseConversation, SynthesizeData))
         ])
 
     @staticmethod
diff --git a/g4f/providers/response.py b/g4f/providers/response.py
index a4d1467a..3fddbf4f 100644
--- a/g4f/providers/response.py
+++ b/g4f/providers/response.py
@@ -22,5 +22,18 @@ class Sources(ResponseType):
         return "\n\n" + ("\n".join([f"{idx+1}. [{link['title']}]({link['url']})" for idx, link in enumerate(self.list)]))
 
 class BaseConversation(ResponseType):
+    def __str__(self) -> str:
+        return ""
+
+class SynthesizeData(ResponseType):
+    def __init__(self, provider: str, data: dict):
+        self.provider = provider
+        self.data = data
+
+    def to_json(self) -> dict:
+        return {
+            **self.__dict__
+        }
+
     def __str__(self) -> str:
         return ""
\ No newline at end of file
-- cgit v1.2.3
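
The synthesize flow above, end to end, as a minimal sketch: OpenaiChat.synthesize(params), SynthesizeData(provider, data), and the "maple" voice come straight from the diff; the driver function, file handling, and import paths are illustrative assumptions.

# Illustrative only: stream synthesized speech for an answer into a file.
import asyncio

from g4f.Provider.needs_auth.OpenaiChat import OpenaiChat
from g4f.providers.response import SynthesizeData

async def save_speech(chunk: SynthesizeData, path: str) -> None:
    # chunk.data carries conversation_id, message_id and voice, the same
    # query parameters the GUI forwards to the backend route.
    with open(path, "wb") as audio_file:
        async for part in OpenaiChat.synthesize(chunk.data):
            audio_file.write(part)

# chunk = SynthesizeData("OpenaiChat", {"conversation_id": "...", "message_id": "...", "voice": "maple"})
# asyncio.run(save_speech(chunk, "answer.mp3"))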
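
The new GUI route can also be exercised directly. Below is a hypothetical client for GET /backend-api/v2/synthesize/<provider>; the route shape, parameter names, and the audio/mpeg and Cache-Control behavior come from the diff, while the host, port, and placeholder values are assumptions for a locally running GUI.

import requests

params = {
    "conversation_id": "...",  # normally taken from a SynthesizeData chunk
    "message_id": "...",
    "voice": "maple",
}
with requests.get(
    "http://localhost:8080/backend-api/v2/synthesize/OpenaiChat",
    params=params,
    stream=True,
) as response:
    response.raise_for_status()
    # The backend streams audio/mpeg and sets Cache-Control: max-age=604800.
    with open("speech.mp3", "wb") as f:
        for chunk in response.iter_content(chunk_size=8192):
            f.write(chunk)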
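
One design note on safe_iter_generator: pulling the first item before wrapping the generator in flask.Response means a provider error still raises inside handle_synthesize, where the except clause can return a clean 500, instead of aborting the stream after headers are sent. A small sketch; safe_iter_generator is copied from the patch and the failing generator is invented for demonstration.

def safe_iter_generator(generator):
    start = next(generator)  # raises here, before any byte is streamed
    def iter_generator():
        yield start
        yield from generator
    return iter_generator()

def broken_synthesize():
    raise RuntimeError("no valid HAR file")  # invented error message
    yield b""  # unreachable, but makes this a generator function

try:
    safe_iter_generator(broken_synthesize())
except RuntimeError as error:
    print(f"Caught before streaming: {error}")  # would become the 500 response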