From a2b5446b2e449f3107d9980d8eaac0a3ae07730f Mon Sep 17 00:00:00 2001 From: Heiner Lohaus Date: Wed, 17 Apr 2024 10:33:23 +0200 Subject: Fix DuckDuckGo Provider issues Fix PerplexityLabs, FlowGpt Provider Update Bing, Gemini Provider --- g4f/requests/curl_cffi.py | 34 +++++++++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) (limited to 'g4f/requests/curl_cffi.py') diff --git a/g4f/requests/curl_cffi.py b/g4f/requests/curl_cffi.py index e955d640..d0d44ba7 100644 --- a/g4f/requests/curl_cffi.py +++ b/g4f/requests/curl_cffi.py @@ -6,6 +6,11 @@ try: has_curl_mime = True except ImportError: has_curl_mime = False +try: + from curl_cffi.requests import CurlWsFlag + has_curl_ws = True +except ImportError: + has_curl_ws = False from typing import AsyncGenerator, Any from functools import partialmethod import json @@ -73,6 +78,12 @@ class StreamSession(AsyncSession): """Create and return a StreamResponse object for the given HTTP request.""" return StreamResponse(super().request(method, url, stream=True, **kwargs)) + def ws_connect(self, url, *args, **kwargs): + return WebSocket(self, url) + + def _ws_connect(self, url): + return super().ws_connect(url) + # Defining HTTP methods as partial methods of the request method. head = partialmethod(request, "HEAD") get = partialmethod(request, "GET") @@ -88,4 +99,25 @@ if has_curl_mime: else: class FormData(): def __init__(self) -> None: - raise RuntimeError("CurlMimi in curl_cffi is missing | pip install -U g4f[curl_cffi]") \ No newline at end of file + raise RuntimeError("CurlMimi in curl_cffi is missing | pip install -U g4f[curl_cffi]") + +class WebSocket(): + def __init__(self, session, url) -> None: + if not has_curl_ws: + raise RuntimeError("CurlWsFlag in curl_cffi is missing | pip install -U g4f[curl_cffi]") + self.session: StreamSession = session + self.url: str = url + + async def __aenter__(self): + self.inner = await self.session._ws_connect(self.url) + return self + + async def __aexit__(self, *args): + self.inner.aclose() + + async def receive_str(self) -> str: + bytes, _ = await self.inner.arecv() + return bytes.decode(errors="ignore") + + async def send_str(self, data: str): + await self.inner.asend(data.encode(), CurlWsFlag.TEXT) \ No newline at end of file -- cgit v1.2.3