Files
blender-portable-repo/scripts/addons/Rokoko Libraries/python311/gql/transport/aiohttp.py
T
2026-03-17 14:58:51 -06:00

387 lines
14 KiB
Python

import asyncio
import functools
import io
import json
import logging
import warnings
from ssl import SSLContext
from typing import (
Any,
AsyncGenerator,
Callable,
Dict,
Optional,
Tuple,
Type,
Union,
cast,
)
import aiohttp
from aiohttp.client_exceptions import ClientResponseError
from aiohttp.client_reqrep import Fingerprint
from aiohttp.helpers import BasicAuth
from aiohttp.typedefs import LooseCookies, LooseHeaders
from graphql import DocumentNode, ExecutionResult, print_ast
from multidict import CIMultiDictProxy
from ..utils import extract_files
from .appsync_auth import AppSyncAuthentication
from .async_transport import AsyncTransport
from .exceptions import (
TransportAlreadyConnected,
TransportClosed,
TransportProtocolError,
TransportServerError,
)
# Module-level logger for this transport: connect/close events and multipart
# upload details are logged at DEBUG, request/response payloads at INFO.
log = logging.getLogger(__name__)
class AIOHTTPTransport(AsyncTransport):
    """:ref:`Async Transport <async_transports>` to execute GraphQL queries
    on remote servers with an HTTP connection.

    This transport uses the aiohttp library with asyncio.
    """

    # Types treated as file uploads when ``upload_files=True`` is passed to
    # :meth:`execute`; handed to ``extract_files`` to locate them inside the
    # variable values.
    file_classes: Tuple[Type[Any], ...] = (
        io.IOBase,
        aiohttp.StreamReader,
        AsyncGenerator,
    )

    def __init__(
        self,
        url: str,
        headers: Optional[LooseHeaders] = None,
        cookies: Optional[LooseCookies] = None,
        auth: Optional[Union[BasicAuth, "AppSyncAuthentication"]] = None,
        ssl: Union[SSLContext, bool, Fingerprint, str] = "ssl_warning",
        timeout: Optional[int] = None,
        ssl_close_timeout: Optional[Union[int, float]] = 10,
        json_serialize: Callable = json.dumps,
        client_session_args: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Initialize the transport with the given aiohttp parameters.

        :param url: The GraphQL server URL. Example: 'https://server.com:PORT/path'.
        :param headers: Dict of HTTP Headers.
        :param cookies: Dict of HTTP cookies.
        :param auth: BasicAuth object to enable Basic HTTP auth if needed
                     Or Appsync Authentication class
        :param ssl: ssl_context of the connection. Use ssl=False to disable
            encryption. The default sentinel ``"ssl_warning"`` behaves like
            ``ssl=False`` but emits a warning for https URLs.
        :param timeout: Optional total timeout (seconds) applied to each request
            through ``aiohttp.ClientTimeout(total=timeout)``.
        :param ssl_close_timeout: Timeout in seconds to wait for the ssl connection
                                  to close properly
        :param json_serialize: Json serializer callable.
                By default json.dumps() function
        :param client_session_args: Dict of extra args passed to
                `aiohttp.ClientSession`_

        .. _aiohttp.ClientSession:
          https://docs.aiohttp.org/en/stable/client_reference.html#aiohttp.ClientSession
        """
        self.url: str = url
        self.headers: Optional[LooseHeaders] = headers
        self.cookies: Optional[LooseCookies] = cookies
        self.auth: Optional[Union[BasicAuth, "AppSyncAuthentication"]] = auth

        # "ssl_warning" is a sentinel default: certificates are not verified
        # (same as ssl=False), but the user is warned when using https.
        if ssl == "ssl_warning":
            ssl = False
            if str(url).startswith("https"):
                warnings.warn(
                    "WARNING: By default, AIOHTTPTransport does not verify"
                    " ssl certificates. This will be fixed in the next major version."
                    " You can set ssl=True to force the ssl certificate verification"
                    " or ssl=False to disable this warning"
                )

        self.ssl: Union[SSLContext, bool, Fingerprint] = cast(
            Union[SSLContext, bool, Fingerprint], ssl
        )
        self.timeout: Optional[int] = timeout
        self.ssl_close_timeout: Optional[Union[int, float]] = ssl_close_timeout
        self.client_session_args = client_session_args
        self.session: Optional[aiohttp.ClientSession] = None
        # Headers of the most recent response; None until a request completes.
        # (Previously only annotated, so reading it early raised AttributeError.)
        self.response_headers: Optional[CIMultiDictProxy[str]] = None
        self.json_serialize: Callable = json_serialize

    async def connect(self) -> None:
        """Coroutine which will create an aiohttp ClientSession() as self.session.

        Don't call this coroutine directly on the transport, instead use
        :code:`async with` on the client and this coroutine will be executed
        to create the session.

        Should be cleaned with a call to the close coroutine.

        :raises TransportAlreadyConnected: if a session already exists.
        """
        if self.session is not None:
            raise TransportAlreadyConnected("Transport is already connected")

        client_session_args: Dict[str, Any] = {
            "cookies": self.cookies,
            "headers": self.headers,
            # AppSync auth is applied per request as headers in execute(),
            # so it must not be passed to aiohttp as session auth.
            "auth": None
            if isinstance(self.auth, AppSyncAuthentication)
            else self.auth,
            "json_serialize": self.json_serialize,
        }

        if self.timeout is not None:
            client_session_args["timeout"] = aiohttp.ClientTimeout(
                total=self.timeout
            )

        # Adding custom parameters passed from init
        if self.client_session_args:
            client_session_args.update(self.client_session_args)  # type: ignore

        log.debug("Connecting transport")

        self.session = aiohttp.ClientSession(**client_session_args)

    @staticmethod
    def create_aiohttp_closed_event(session) -> asyncio.Event:
        """Work around aiohttp issue that doesn't properly close transports on exit.

        See https://github.com/aio-libs/aiohttp/issues/1925#issuecomment-639080209

        The ``connection_lost`` / ``eof_received`` callbacks of every pooled SSL
        protocol are wrapped so that the returned event is set once all of them
        have reported ``connection_lost``.

        Returns:
            An event that will be set once all transports have been properly closed.
        """
        ssl_transports = 0
        all_is_lost = asyncio.Event()

        def connection_lost(exc, orig_lost):
            nonlocal ssl_transports

            try:
                orig_lost(exc)
            finally:
                ssl_transports -= 1
                if ssl_transports == 0:
                    all_is_lost.set()

        def eof_received(orig_eof_received):
            try:
                orig_eof_received()
            except AttributeError:  # pragma: no cover
                # It may happen that eof_received() is called after
                # _app_protocol and _transport are set to None.
                pass

        # The session's connector may already be detached (None), and
        # ``_conns`` is a private aiohttp attribute: treat either missing as
        # "no pooled connections" instead of crashing.
        connector = session.connector
        pooled_connections = (
            getattr(connector, "_conns", None) if connector is not None else None
        ) or {}

        for conn in pooled_connections.values():
            for handler, _ in conn:
                proto = getattr(handler.transport, "_ssl_protocol", None)
                if proto is None:
                    continue

                ssl_transports += 1
                orig_lost = proto.connection_lost
                orig_eof_received = proto.eof_received

                proto.connection_lost = functools.partial(
                    connection_lost, orig_lost=orig_lost
                )
                proto.eof_received = functools.partial(
                    eof_received, orig_eof_received=orig_eof_received
                )

        if ssl_transports == 0:
            all_is_lost.set()

        return all_is_lost

    async def close(self) -> None:
        """Coroutine which will close the aiohttp session.

        Don't call this coroutine directly on the transport, instead use
        :code:`async with` on the client and this coroutine will be executed
        when you exit the async context manager.

        When ``client_session_args`` sets ``connector_owner=False``, the session
        is still closed, but aiohttp leaves the shared connector open.
        """
        if self.session is None:
            return

        log.debug("Closing transport")

        session = self.session
        try:
            if (
                self.client_session_args
                and self.client_session_args.get("connector_owner") is False
            ):
                log.debug("connector_owner is False -> not closing connector")
                # With connector_owner=False, ClientSession.close() only
                # detaches the connector; skipping it entirely would leave an
                # unclosed ClientSession behind.
                await session.close()
            else:
                closed_event = self.create_aiohttp_closed_event(session)
                await session.close()
                try:
                    await asyncio.wait_for(closed_event.wait(), self.ssl_close_timeout)
                except asyncio.TimeoutError:
                    pass
        finally:
            # Always forget the session so a later connect() can start fresh,
            # even if closing raised.
            self.session = None

    async def execute(
        self,
        document: DocumentNode,
        variable_values: Optional[Dict[str, Any]] = None,
        operation_name: Optional[str] = None,
        extra_args: Optional[Dict[str, Any]] = None,
        upload_files: bool = False,
    ) -> ExecutionResult:
        """Execute the provided document AST against the configured remote server
        using the current session.

        This uses the aiohttp library to perform a HTTP POST request asynchronously
        to the remote server.

        Don't call this coroutine directly on the transport, instead use
        :code:`execute` on a client or a session.

        :param document: the parsed GraphQL request
        :param variable_values: An optional Dict of variable values
        :param operation_name: An optional Operation name for the request
        :param extra_args: additional arguments to send to the aiohttp post method
        :param upload_files: Set to True if you want to put files in the variable values
        :returns: an ExecutionResult object.
        :raises TransportClosed: if the transport is not connected.
        :raises TransportServerError: if the answer is not a GraphQL result and
            the HTTP status is 400 or higher.
        :raises TransportProtocolError: if the answer is not a GraphQL result
            (not JSON, not a JSON object, or without "data"/"errors" keys).
        """

        query_str = print_ast(document)

        payload: Dict[str, Any] = {
            "query": query_str,
        }

        if operation_name:
            payload["operationName"] = operation_name

        if upload_files:

            # If the upload_files flag is set, then we need variable_values
            assert variable_values is not None

            # If we upload files, we will extract the files present in the
            # variable_values dict and replace them by null values
            nulled_variable_values, files = extract_files(
                variables=variable_values,
                file_classes=self.file_classes,
            )

            # Save the nulled variable values in the payload
            payload["variables"] = nulled_variable_values

            # Prepare aiohttp to send multipart-encoded data
            data = aiohttp.FormData()

            # Generate the file map
            # path is nested in a list because the spec allows multiple pointers
            # to the same file. But we don't support that.
            # Will generate something like {"0": ["variables.file"]}
            file_map = {str(i): [path] for i, path in enumerate(files)}

            # Enumerate the file streams
            # Will generate something like {'0': <_io.BufferedReader ...>}
            file_streams = {str(i): files[path] for i, path in enumerate(files)}

            # Add the payload to the operations field
            operations_str = self.json_serialize(payload)
            log.debug("operations %s", operations_str)
            data.add_field(
                "operations", operations_str, content_type="application/json"
            )

            # Add the file map field
            file_map_str = self.json_serialize(file_map)
            log.debug("file_map %s", file_map_str)
            data.add_field("map", file_map_str, content_type="application/json")

            # Add the extracted files as remaining fields
            for k, f in file_streams.items():
                name = getattr(f, "name", k)
                content_type = getattr(f, "content_type", None)

                data.add_field(k, f, filename=name, content_type=content_type)

            post_args: Dict[str, Any] = {"data": data}

        else:
            if variable_values:
                payload["variables"] = variable_values

            if log.isEnabledFor(logging.INFO):
                log.info(">>> %s", self.json_serialize(payload))

            post_args = {"json": payload}

        # Pass post_args to aiohttp post method
        if extra_args:
            post_args.update(extra_args)

        # Add headers for AppSync if requested
        if isinstance(self.auth, AppSyncAuthentication):
            post_args["headers"] = self.auth.get_headers(
                self.json_serialize(payload),
                {"content-type": "application/json"},
            )

        if self.session is None:
            raise TransportClosed("Transport is not connected")

        async with self.session.post(self.url, ssl=self.ssl, **post_args) as resp:

            # Saving latest response headers in the transport
            self.response_headers = resp.headers

            async def raise_response_error(resp: aiohttp.ClientResponse, reason: str):
                # We raise a TransportServerError if the status code is 400 or higher
                # We raise a TransportProtocolError in the other cases

                try:
                    # Raise a ClientResponseError if response status is 400 or higher
                    resp.raise_for_status()
                except ClientResponseError as e:
                    raise TransportServerError(str(e), e.status) from e

                result_text = await resp.text()
                raise TransportProtocolError(
                    f"Server did not return a GraphQL result: "
                    f"{reason}: "
                    f"{result_text}"
                )

            try:
                result = await resp.json(content_type=None)

                if log.isEnabledFor(logging.INFO):
                    result_text = await resp.text()
                    log.info("<<< %s", result_text)

            except Exception:
                await raise_response_error(resp, "Not a JSON answer")

            if result is None:
                await raise_response_error(resp, "Not a JSON answer")

            # A valid GraphQL answer is a JSON object; a list/number/string
            # would otherwise crash below with TypeError/AttributeError.
            if not isinstance(result, dict) or (
                "errors" not in result and "data" not in result
            ):
                await raise_response_error(resp, 'No "data" or "errors" keys in answer')

            return ExecutionResult(
                errors=result.get("errors"),
                data=result.get("data"),
                extensions=result.get("extensions"),
            )

    def subscribe(
        self,
        document: DocumentNode,
        variable_values: Optional[Dict[str, Any]] = None,
        operation_name: Optional[str] = None,
    ) -> AsyncGenerator[ExecutionResult, None]:
        """Subscribe is not supported on HTTP.

        :meta private:
        """
        raise NotImplementedError(" The HTTP transport does not support subscriptions")