summaryrefslogtreecommitdiffstats
path: root/g4f/Provider/nexra/NexraGeminiPro.py
diff options
context:
space:
mode:
Diffstat (limited to 'g4f/Provider/nexra/NexraGeminiPro.py')
-rw-r--r--g4f/Provider/nexra/NexraGeminiPro.py70
1 files changed, 43 insertions, 27 deletions
diff --git a/g4f/Provider/nexra/NexraGeminiPro.py b/g4f/Provider/nexra/NexraGeminiPro.py
index a57daed4..651f7cb4 100644
--- a/g4f/Provider/nexra/NexraGeminiPro.py
+++ b/g4f/Provider/nexra/NexraGeminiPro.py
@@ -1,17 +1,25 @@
from __future__ import annotations
-import json
from aiohttp import ClientSession
-
-from ...typing import AsyncResult, Messages
+import json
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..helper import format_prompt
+from ...typing import AsyncResult, Messages
class NexraGeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
label = "Nexra Gemini PRO"
+ url = "https://nexra.aryahcr.cc/documentation/gemini-pro/en"
api_endpoint = "https://nexra.aryahcr.cc/api/chat/complements"
- models = ['gemini-pro']
+ working = True
+ supports_stream = True
+
+ default_model = 'gemini-pro'
+ models = [default_model]
+
+ @classmethod
+ def get_model(cls, model: str) -> str:
+ return cls.default_model
@classmethod
async def create_async_generator(
@@ -19,34 +27,42 @@ class NexraGeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
model: str,
messages: Messages,
proxy: str = None,
+ stream: bool = False,
+ markdown: bool = False,
**kwargs
) -> AsyncResult:
+ model = cls.get_model(model)
+
headers = {
"Content-Type": "application/json"
}
+
+ data = {
+ "messages": [
+ {
+ "role": "user",
+ "content": format_prompt(messages)
+ }
+ ],
+ "markdown": markdown,
+ "stream": stream,
+ "model": model
+ }
+
async with ClientSession(headers=headers) as session:
- data = {
- "messages": [
- {'role': 'assistant', 'content': ''},
- {'role': 'user', 'content': format_prompt(messages)}
- ],
- "markdown": False,
- "stream": True,
- "model": model
- }
async with session.post(cls.api_endpoint, json=data, proxy=proxy) as response:
response.raise_for_status()
- full_response = ''
- async for line in response.content:
- if line:
- messages = line.decode('utf-8').split('\x1e')
- for message_str in messages:
- try:
- message = json.loads(message_str)
- if message.get('message'):
- full_response = message['message']
- if message.get('finish'):
- yield full_response.strip()
- return
- except json.JSONDecodeError:
- pass
+ buffer = ""
+ async for chunk in response.content.iter_any():
+ if chunk.strip(): # Check if chunk is not empty
+ buffer += chunk.decode()
+ while '\x1e' in buffer:
+ part, buffer = buffer.split('\x1e', 1)
+ if part.strip():
+ try:
+ response_json = json.loads(part)
+ message = response_json.get("message", "")
+ if message:
+ yield message
+ except json.JSONDecodeError as e:
+ print(f"JSONDecodeError: {e}")