2025-12-01

This commit is contained in:
2026-03-17 14:58:51 -06:00
parent 183e865f8b
commit 4b82b57113
6846 changed files with 954887 additions and 162606 deletions
@@ -0,0 +1,4 @@
from .async_transport import AsyncTransport
from .transport import Transport
__all__ = ["AsyncTransport", "Transport"]
@@ -0,0 +1,386 @@
import asyncio
import functools
import io
import json
import logging
import warnings
from ssl import SSLContext
from typing import (
Any,
AsyncGenerator,
Callable,
Dict,
Optional,
Tuple,
Type,
Union,
cast,
)
import aiohttp
from aiohttp.client_exceptions import ClientResponseError
from aiohttp.client_reqrep import Fingerprint
from aiohttp.helpers import BasicAuth
from aiohttp.typedefs import LooseCookies, LooseHeaders
from graphql import DocumentNode, ExecutionResult, print_ast
from multidict import CIMultiDictProxy
from ..utils import extract_files
from .appsync_auth import AppSyncAuthentication
from .async_transport import AsyncTransport
from .exceptions import (
TransportAlreadyConnected,
TransportClosed,
TransportProtocolError,
TransportServerError,
)
log = logging.getLogger(__name__)
class AIOHTTPTransport(AsyncTransport):
""":ref:`Async Transport <async_transports>` to execute GraphQL queries
on remote servers with an HTTP connection.
This transport use the aiohttp library with asyncio.
"""
file_classes: Tuple[Type[Any], ...] = (
io.IOBase,
aiohttp.StreamReader,
AsyncGenerator,
)
def __init__(
self,
url: str,
headers: Optional[LooseHeaders] = None,
cookies: Optional[LooseCookies] = None,
auth: Optional[Union[BasicAuth, "AppSyncAuthentication"]] = None,
ssl: Union[SSLContext, bool, Fingerprint, str] = "ssl_warning",
timeout: Optional[int] = None,
ssl_close_timeout: Optional[Union[int, float]] = 10,
json_serialize: Callable = json.dumps,
client_session_args: Optional[Dict[str, Any]] = None,
) -> None:
"""Initialize the transport with the given aiohttp parameters.
:param url: The GraphQL server URL. Example: 'https://server.com:PORT/path'.
:param headers: Dict of HTTP Headers.
:param cookies: Dict of HTTP cookies.
:param auth: BasicAuth object to enable Basic HTTP auth if needed
Or Appsync Authentication class
:param ssl: ssl_context of the connection. Use ssl=False to disable encryption
:param ssl_close_timeout: Timeout in seconds to wait for the ssl connection
to close properly
:param json_serialize: Json serializer callable.
By default json.dumps() function
:param client_session_args: Dict of extra args passed to
`aiohttp.ClientSession`_
.. _aiohttp.ClientSession:
https://docs.aiohttp.org/en/stable/client_reference.html#aiohttp.ClientSession
"""
self.url: str = url
self.headers: Optional[LooseHeaders] = headers
self.cookies: Optional[LooseCookies] = cookies
self.auth: Optional[Union[BasicAuth, "AppSyncAuthentication"]] = auth
if ssl == "ssl_warning":
ssl = False
if str(url).startswith("https"):
warnings.warn(
"WARNING: By default, AIOHTTPTransport does not verify"
" ssl certificates. This will be fixed in the next major version."
" You can set ssl=True to force the ssl certificate verification"
" or ssl=False to disable this warning"
)
self.ssl: Union[SSLContext, bool, Fingerprint] = cast(
Union[SSLContext, bool, Fingerprint], ssl
)
self.timeout: Optional[int] = timeout
self.ssl_close_timeout: Optional[Union[int, float]] = ssl_close_timeout
self.client_session_args = client_session_args
self.session: Optional[aiohttp.ClientSession] = None
self.response_headers: Optional[CIMultiDictProxy[str]]
self.json_serialize: Callable = json_serialize
async def connect(self) -> None:
"""Coroutine which will create an aiohttp ClientSession() as self.session.
Don't call this coroutine directly on the transport, instead use
:code:`async with` on the client and this coroutine will be executed
to create the session.
Should be cleaned with a call to the close coroutine.
"""
if self.session is None:
client_session_args: Dict[str, Any] = {
"cookies": self.cookies,
"headers": self.headers,
"auth": None
if isinstance(self.auth, AppSyncAuthentication)
else self.auth,
"json_serialize": self.json_serialize,
}
if self.timeout is not None:
client_session_args["timeout"] = aiohttp.ClientTimeout(
total=self.timeout
)
# Adding custom parameters passed from init
if self.client_session_args:
client_session_args.update(self.client_session_args) # type: ignore
log.debug("Connecting transport")
self.session = aiohttp.ClientSession(**client_session_args)
else:
raise TransportAlreadyConnected("Transport is already connected")
@staticmethod
def create_aiohttp_closed_event(session) -> asyncio.Event:
"""Work around aiohttp issue that doesn't properly close transports on exit.
See https://github.com/aio-libs/aiohttp/issues/1925#issuecomment-639080209
Returns:
An event that will be set once all transports have been properly closed.
"""
ssl_transports = 0
all_is_lost = asyncio.Event()
def connection_lost(exc, orig_lost):
nonlocal ssl_transports
try:
orig_lost(exc)
finally:
ssl_transports -= 1
if ssl_transports == 0:
all_is_lost.set()
def eof_received(orig_eof_received):
try:
orig_eof_received()
except AttributeError: # pragma: no cover
# It may happen that eof_received() is called after
# _app_protocol and _transport are set to None.
pass
for conn in session.connector._conns.values():
for handler, _ in conn:
proto = getattr(handler.transport, "_ssl_protocol", None)
if proto is None:
continue
ssl_transports += 1
orig_lost = proto.connection_lost
orig_eof_received = proto.eof_received
proto.connection_lost = functools.partial(
connection_lost, orig_lost=orig_lost
)
proto.eof_received = functools.partial(
eof_received, orig_eof_received=orig_eof_received
)
if ssl_transports == 0:
all_is_lost.set()
return all_is_lost
async def close(self) -> None:
"""Coroutine which will close the aiohttp session.
Don't call this coroutine directly on the transport, instead use
:code:`async with` on the client and this coroutine will be executed
when you exit the async context manager.
"""
if self.session is not None:
log.debug("Closing transport")
if (
self.client_session_args
and self.client_session_args.get("connector_owner") is False
):
log.debug("connector_owner is False -> not closing connector")
else:
closed_event = self.create_aiohttp_closed_event(self.session)
await self.session.close()
try:
await asyncio.wait_for(closed_event.wait(), self.ssl_close_timeout)
except asyncio.TimeoutError:
pass
self.session = None
async def execute(
self,
document: DocumentNode,
variable_values: Optional[Dict[str, Any]] = None,
operation_name: Optional[str] = None,
extra_args: Optional[Dict[str, Any]] = None,
upload_files: bool = False,
) -> ExecutionResult:
"""Execute the provided document AST against the configured remote server
using the current session.
This uses the aiohttp library to perform a HTTP POST request asynchronously
to the remote server.
Don't call this coroutine directly on the transport, instead use
:code:`execute` on a client or a session.
:param document: the parsed GraphQL request
:param variable_values: An optional Dict of variable values
:param operation_name: An optional Operation name for the request
:param extra_args: additional arguments to send to the aiohttp post method
:param upload_files: Set to True if you want to put files in the variable values
:returns: an ExecutionResult object.
"""
query_str = print_ast(document)
payload: Dict[str, Any] = {
"query": query_str,
}
if operation_name:
payload["operationName"] = operation_name
if upload_files:
# If the upload_files flag is set, then we need variable_values
assert variable_values is not None
# If we upload files, we will extract the files present in the
# variable_values dict and replace them by null values
nulled_variable_values, files = extract_files(
variables=variable_values,
file_classes=self.file_classes,
)
# Save the nulled variable values in the payload
payload["variables"] = nulled_variable_values
# Prepare aiohttp to send multipart-encoded data
data = aiohttp.FormData()
# Generate the file map
# path is nested in a list because the spec allows multiple pointers
# to the same file. But we don't support that.
# Will generate something like {"0": ["variables.file"]}
file_map = {str(i): [path] for i, path in enumerate(files)}
# Enumerate the file streams
# Will generate something like {'0': <_io.BufferedReader ...>}
file_streams = {str(i): files[path] for i, path in enumerate(files)}
# Add the payload to the operations field
operations_str = self.json_serialize(payload)
log.debug("operations %s", operations_str)
data.add_field(
"operations", operations_str, content_type="application/json"
)
# Add the file map field
file_map_str = self.json_serialize(file_map)
log.debug("file_map %s", file_map_str)
data.add_field("map", file_map_str, content_type="application/json")
# Add the extracted files as remaining fields
for k, f in file_streams.items():
name = getattr(f, "name", k)
content_type = getattr(f, "content_type", None)
data.add_field(k, f, filename=name, content_type=content_type)
post_args: Dict[str, Any] = {"data": data}
else:
if variable_values:
payload["variables"] = variable_values
if log.isEnabledFor(logging.INFO):
log.info(">>> %s", self.json_serialize(payload))
post_args = {"json": payload}
# Pass post_args to aiohttp post method
if extra_args:
post_args.update(extra_args)
# Add headers for AppSync if requested
if isinstance(self.auth, AppSyncAuthentication):
post_args["headers"] = self.auth.get_headers(
self.json_serialize(payload),
{"content-type": "application/json"},
)
if self.session is None:
raise TransportClosed("Transport is not connected")
async with self.session.post(self.url, ssl=self.ssl, **post_args) as resp:
# Saving latest response headers in the transport
self.response_headers = resp.headers
async def raise_response_error(resp: aiohttp.ClientResponse, reason: str):
# We raise a TransportServerError if the status code is 400 or higher
# We raise a TransportProtocolError in the other cases
try:
# Raise a ClientResponseError if response status is 400 or higher
resp.raise_for_status()
except ClientResponseError as e:
raise TransportServerError(str(e), e.status) from e
result_text = await resp.text()
raise TransportProtocolError(
f"Server did not return a GraphQL result: "
f"{reason}: "
f"{result_text}"
)
try:
result = await resp.json(content_type=None)
if log.isEnabledFor(logging.INFO):
result_text = await resp.text()
log.info("<<< %s", result_text)
except Exception:
await raise_response_error(resp, "Not a JSON answer")
if result is None:
await raise_response_error(resp, "Not a JSON answer")
if "errors" not in result and "data" not in result:
await raise_response_error(resp, 'No "data" or "errors" keys in answer')
return ExecutionResult(
errors=result.get("errors"),
data=result.get("data"),
extensions=result.get("extensions"),
)
def subscribe(
self,
document: DocumentNode,
variable_values: Optional[Dict[str, Any]] = None,
operation_name: Optional[str] = None,
) -> AsyncGenerator[ExecutionResult, None]:
"""Subscribe is not supported on HTTP.
:meta private:
"""
raise NotImplementedError(" The HTTP transport does not support subscriptions")
@@ -0,0 +1,221 @@
import json
import logging
import re
from abc import ABC, abstractmethod
from base64 import b64encode
from typing import Any, Callable, Dict, Optional
try:
import botocore
except ImportError: # pragma: no cover
# botocore is only needed for the IAM AppSync authentication method
pass
log = logging.getLogger("gql.transport.appsync")
class AppSyncAuthentication(ABC):
"""AWS authentication abstract base class
All AWS authentication class should have a
:meth:`get_headers <gql.transport.appsync.AppSyncAuthentication.get_headers>`
method which defines the headers used in the authentication process."""
def get_auth_url(self, url: str) -> str:
"""
:return: a url with base64 encoded headers used to establish
a websocket connection to the appsync-realtime-api.
"""
headers = self.get_headers()
encoded_headers = b64encode(
json.dumps(headers, separators=(",", ":")).encode()
).decode()
url_base = url.replace("https://", "wss://").replace(
"appsync-api", "appsync-realtime-api"
)
return f"{url_base}?header={encoded_headers}&payload=e30="
@abstractmethod
def get_headers(
self, data: Optional[str] = None, headers: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
raise NotImplementedError() # pragma: no cover
class AppSyncApiKeyAuthentication(AppSyncAuthentication):
"""AWS authentication class using an API key"""
def __init__(self, host: str, api_key: str) -> None:
"""
:param host: the host, something like:
XXXXXXXXXXXXXXXXXXXXXXXXXX.appsync-api.REGION.amazonaws.com
:param api_key: the API key
"""
self._host = host.replace("appsync-realtime-api", "appsync-api")
self.api_key = api_key
def get_headers(
self, data: Optional[str] = None, headers: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
return {"host": self._host, "x-api-key": self.api_key}
class AppSyncJWTAuthentication(AppSyncAuthentication):
"""AWS authentication class using a JWT access token.
It can be used either for:
- Amazon Cognito user pools
- OpenID Connect (OIDC)
"""
def __init__(self, host: str, jwt: str) -> None:
"""
:param host: the host, something like:
XXXXXXXXXXXXXXXXXXXXXXXXXX.appsync-api.REGION.amazonaws.com
:param jwt: the JWT Access Token
"""
self._host = host.replace("appsync-realtime-api", "appsync-api")
self.jwt = jwt
def get_headers(
self, data: Optional[str] = None, headers: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
return {"host": self._host, "Authorization": self.jwt}
class AppSyncIAMAuthentication(AppSyncAuthentication):
"""AWS authentication class using IAM.
.. note::
There is no need for you to use this class directly, you could instead
intantiate the :class:`gql.transport.appsync.AppSyncWebsocketsTransport`
without an auth argument.
During initialization, this class will use botocore to attempt to
find your IAM credentials, either from environment variables or
from your AWS credentials file.
"""
def __init__(
self,
host: str,
region_name: Optional[str] = None,
signer: Optional["botocore.auth.BaseSigner"] = None,
request_creator: Optional[
Callable[[Dict[str, Any]], "botocore.awsrequest.AWSRequest"]
] = None,
credentials: Optional["botocore.credentials.Credentials"] = None,
session: Optional["botocore.session.Session"] = None,
) -> None:
"""Initialize itself, saving the found credentials used
to sign the headers later.
if no credentials are found, then a NoCredentialsError is raised.
"""
from botocore.auth import SigV4Auth
from botocore.awsrequest import create_request_object
from botocore.session import get_session
self._host = host.replace("appsync-realtime-api", "appsync-api")
self._session = session if session else get_session()
self._credentials = (
credentials if credentials else self._session.get_credentials()
)
self._service_name = "appsync"
self._region_name = region_name or self._detect_region_name()
self._signer = (
signer
if signer
else SigV4Auth(self._credentials, self._service_name, self._region_name)
)
self._request_creator = (
request_creator if request_creator else create_request_object
)
def _detect_region_name(self):
"""Try to detect the correct region_name.
First try to extract the region_name from the host.
If that does not work, then try to get the region_name from
the aws configuration (~/.aws/config file) or the AWS_DEFAULT_REGION
environment variable.
If no region_name was found, then raise a NoRegionError exception."""
from botocore.exceptions import NoRegionError
# Regular expression from botocore.utils.validate_region
m = re.search(
r"appsync-api\.((?![0-9]+$)(?!-)[a-zA-Z0-9-]{,63}(?<!-))\.", self._host
)
if m:
region_name = m.groups()[0]
log.debug(f"Region name extracted from host: {region_name}")
else:
log.debug("Region name not found in host, trying default region name")
region_name = self._session._resolve_region_name(
None, self._session.get_default_client_config()
)
if region_name is None:
log.warning(
"Region name not found. "
"It was not possible to detect your region either from the host "
"or from your default AWS configuration."
)
raise NoRegionError
return region_name
def get_headers(
self, data: Optional[str] = None, headers: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
from botocore.exceptions import NoCredentialsError
# Default headers for a websocket connection
headers = headers or {
"accept": "application/json, text/javascript",
"content-encoding": "amz-1.0",
"content-type": "application/json; charset=UTF-8",
}
request: "botocore.awsrequest.AWSRequest" = self._request_creator(
{
"method": "POST",
"url": f"https://{self._host}/graphql{'' if data else '/connect'}",
"headers": headers,
"context": {},
"body": data or "{}",
}
)
try:
self._signer.add_auth(request)
except NoCredentialsError:
log.warning(
"Credentials not found for the IAM auth. "
"Do you have default AWS credentials configured?",
)
raise
headers = dict(request.headers)
headers["host"] = self._host
if log.isEnabledFor(logging.DEBUG):
headers_log = []
headers_log.append("\n\nSigned headers:")
for key, value in headers.items():
headers_log.append(f" {key}: {value}")
headers_log.append("\n")
log.debug("\n".join(headers_log))
return headers
@@ -0,0 +1,214 @@
import json
import logging
from ssl import SSLContext
from typing import Any, Dict, Optional, Tuple, Union, cast
from urllib.parse import urlparse
from graphql import DocumentNode, ExecutionResult, print_ast
from .appsync_auth import AppSyncAuthentication, AppSyncIAMAuthentication
from .exceptions import TransportProtocolError, TransportServerError
from .websockets import WebsocketsTransport, WebsocketsTransportBase
log = logging.getLogger("gql.transport.appsync")
try:
import botocore
except ImportError: # pragma: no cover
# botocore is only needed for the IAM AppSync authentication method
pass
class AppSyncWebsocketsTransport(WebsocketsTransportBase):
""":ref:`Async Transport <async_transports>` used to execute GraphQL subscription on
AWS appsync realtime endpoint.
This transport uses asyncio and the websockets library in order to send requests
on a websocket connection.
"""
auth: Optional[AppSyncAuthentication]
def __init__(
self,
url: str,
auth: Optional[AppSyncAuthentication] = None,
session: Optional["botocore.session.Session"] = None,
ssl: Union[SSLContext, bool] = False,
connect_timeout: int = 10,
close_timeout: int = 10,
ack_timeout: int = 10,
keep_alive_timeout: Optional[Union[int, float]] = None,
connect_args: Dict[str, Any] = {},
) -> None:
"""Initialize the transport with the given parameters.
:param url: The GraphQL endpoint URL. Example:
https://XXXXXXXXXXXXXXXXXXXXXXXXXX.appsync-api.REGION.amazonaws.com/graphql
:param auth: Optional AWS authentication class which will provide the
necessary headers to be correctly authenticated. If this
argument is not provided, then we will try to authenticate
using IAM.
:param ssl: ssl_context of the connection.
:param connect_timeout: Timeout in seconds for the establishment
of the websocket connection. If None is provided this will wait forever.
:param close_timeout: Timeout in seconds for the close. If None is provided
this will wait forever.
:param ack_timeout: Timeout in seconds to wait for the connection_ack message
from the server. If None is provided this will wait forever.
:param keep_alive_timeout: Optional Timeout in seconds to receive
a sign of liveness from the server.
:param connect_args: Other parameters forwarded to websockets.connect
"""
if not auth:
# Extract host from url
host = str(urlparse(url).netloc)
# May raise NoRegionError or NoCredentialsError or ImportError
auth = AppSyncIAMAuthentication(host=host, session=session)
self.auth = auth
url = self.auth.get_auth_url(url)
super().__init__(
url,
ssl=ssl,
connect_timeout=connect_timeout,
close_timeout=close_timeout,
ack_timeout=ack_timeout,
keep_alive_timeout=keep_alive_timeout,
connect_args=connect_args,
)
# Using the same 'graphql-ws' protocol as the apollo protocol
self.supported_subprotocols = [
WebsocketsTransport.APOLLO_SUBPROTOCOL,
]
self.subprotocol = WebsocketsTransport.APOLLO_SUBPROTOCOL
def _parse_answer(
self, answer: str
) -> Tuple[str, Optional[int], Optional[ExecutionResult]]:
"""Parse the answer received from the server.
Difference between apollo protocol and aws protocol:
- aws protocol can return an error without an id
- aws protocol will send start_ack messages
Returns a list consisting of:
- the answer_type:
- 'connection_ack',
- 'connection_error',
- 'start_ack',
- 'ka',
- 'data',
- 'error',
- 'complete'
- the answer id (Integer) if received or None
- an execution Result if the answer_type is 'data' or None
"""
answer_type: str = ""
try:
json_answer = json.loads(answer)
answer_type = str(json_answer.get("type"))
if answer_type == "start_ack":
return ("start_ack", None, None)
elif answer_type == "error" and "id" not in json_answer:
error_payload = json_answer.get("payload")
raise TransportServerError(f"Server error: '{error_payload!r}'")
else:
return WebsocketsTransport._parse_answer_apollo(
cast(WebsocketsTransport, self), json_answer
)
except ValueError:
raise TransportProtocolError(
f"Server did not return a GraphQL result: {answer}"
)
async def _send_query(
self,
document: DocumentNode,
variable_values: Optional[Dict[str, Any]] = None,
operation_name: Optional[str] = None,
) -> int:
query_id = self.next_query_id
self.next_query_id += 1
data: Dict = {"query": print_ast(document)}
if variable_values:
data["variables"] = variable_values
if operation_name:
data["operationName"] = operation_name
serialized_data = json.dumps(data, separators=(",", ":"))
payload = {"data": serialized_data}
message: Dict = {
"id": str(query_id),
"type": "start",
"payload": payload,
}
assert self.auth is not None
message["payload"]["extensions"] = {
"authorization": self.auth.get_headers(serialized_data)
}
await self._send(
json.dumps(
message,
separators=(",", ":"),
)
)
return query_id
subscribe = WebsocketsTransportBase.subscribe
"""Send a subscription query and receive the results using
a python async generator.
Only subscriptions are supported, queries and mutations are forbidden.
The results are sent as an ExecutionResult object.
"""
async def execute(
self,
document: DocumentNode,
variable_values: Optional[Dict[str, Any]] = None,
operation_name: Optional[str] = None,
) -> ExecutionResult:
"""This method is not available.
Only subscriptions are supported on the AWS realtime endpoint.
:raise: AssertionError"""
raise AssertionError(
"execute method is not allowed for AppSyncWebsocketsTransport "
"because only subscriptions are allowed on the realtime endpoint."
)
_initialize = WebsocketsTransport._initialize
_stop_listener = WebsocketsTransport._send_stop_message # type: ignore
_send_init_message_and_wait_ack = (
WebsocketsTransport._send_init_message_and_wait_ack
)
_wait_ack = WebsocketsTransport._wait_ack
@@ -0,0 +1,50 @@
import abc
from typing import Any, AsyncGenerator, Dict, Optional
from graphql import DocumentNode, ExecutionResult
class AsyncTransport(abc.ABC):
@abc.abstractmethod
async def connect(self):
"""Coroutine used to create a connection to the specified address"""
raise NotImplementedError(
"Any AsyncTransport subclass must implement connect method"
) # pragma: no cover
@abc.abstractmethod
async def close(self):
"""Coroutine used to Close an established connection"""
raise NotImplementedError(
"Any AsyncTransport subclass must implement close method"
) # pragma: no cover
@abc.abstractmethod
async def execute(
self,
document: DocumentNode,
variable_values: Optional[Dict[str, Any]] = None,
operation_name: Optional[str] = None,
) -> ExecutionResult:
"""Execute the provided document AST for either a remote or local GraphQL
Schema."""
raise NotImplementedError(
"Any AsyncTransport subclass must implement execute method"
) # pragma: no cover
@abc.abstractmethod
def subscribe(
self,
document: DocumentNode,
variable_values: Optional[Dict[str, Any]] = None,
operation_name: Optional[str] = None,
) -> AsyncGenerator[ExecutionResult, None]:
"""Send a query and receive the results using an async generator
The query can be a graphql query, mutation or subscription
The results are sent as an ExecutionResult object
"""
raise NotImplementedError(
"Any AsyncTransport subclass must implement subscribe method"
) # pragma: no cover
@@ -0,0 +1,69 @@
from typing import Any, List, Optional
class TransportError(Exception):
"""Base class for all the Transport exceptions"""
pass
class TransportProtocolError(TransportError):
"""Transport protocol error.
The answer received from the server does not correspond to the transport protocol.
"""
class TransportServerError(TransportError):
"""The server returned a global error.
This exception will close the transport connection.
"""
code: Optional[int]
def __init__(self, message: str, code: Optional[int] = None):
super(TransportServerError, self).__init__(message)
self.code = code
class TransportQueryError(TransportError):
"""The server returned an error for a specific query.
This exception should not close the transport connection.
"""
query_id: Optional[int]
errors: Optional[List[Any]]
data: Optional[Any]
extensions: Optional[Any]
def __init__(
self,
msg: str,
query_id: Optional[int] = None,
errors: Optional[List[Any]] = None,
data: Optional[Any] = None,
extensions: Optional[Any] = None,
):
super().__init__(msg)
self.query_id = query_id
self.errors = errors
self.data = data
self.extensions = extensions
class TransportClosed(TransportError):
"""Transport is already closed.
This exception is generated when the client is trying to use the transport
while the transport was previously closed.
"""
class TransportAlreadyConnected(TransportError):
"""Transport is already connected.
Exception generated when the client is trying to connect to the transport
while the transport is already connected.
"""
@@ -0,0 +1,311 @@
import io
import json
import logging
from typing import (
Any,
AsyncGenerator,
Callable,
Dict,
List,
Optional,
Tuple,
Type,
Union,
cast,
)
import httpx
from graphql import DocumentNode, ExecutionResult, print_ast
from ..utils import extract_files
from . import AsyncTransport, Transport
from .exceptions import (
TransportAlreadyConnected,
TransportClosed,
TransportProtocolError,
TransportServerError,
)
log = logging.getLogger(__name__)
class _HTTPXTransport:
file_classes: Tuple[Type[Any], ...] = (io.IOBase,)
response_headers: Optional[httpx.Headers] = None
def __init__(
self,
url: Union[str, httpx.URL],
json_serialize: Callable = json.dumps,
**kwargs,
):
"""Initialize the transport with the given httpx parameters.
:param url: The GraphQL server URL. Example: 'https://server.com:PORT/path'.
:param json_serialize: Json serializer callable.
By default json.dumps() function.
:param kwargs: Extra args passed to the `httpx` client.
"""
self.url = url
self.json_serialize = json_serialize
self.kwargs = kwargs
def _prepare_request(
self,
document: DocumentNode,
variable_values: Optional[Dict[str, Any]] = None,
operation_name: Optional[str] = None,
extra_args: Optional[Dict[str, Any]] = None,
upload_files: bool = False,
) -> Dict[str, Any]:
query_str = print_ast(document)
payload: Dict[str, Any] = {
"query": query_str,
}
if operation_name:
payload["operationName"] = operation_name
if upload_files:
# If the upload_files flag is set, then we need variable_values
assert variable_values is not None
post_args = self._prepare_file_uploads(variable_values, payload)
else:
if variable_values:
payload["variables"] = variable_values
post_args = {"json": payload}
# Log the payload
if log.isEnabledFor(logging.DEBUG):
log.debug(">>> %s", self.json_serialize(payload))
# Pass post_args to httpx post method
if extra_args:
post_args.update(extra_args)
return post_args
def _prepare_file_uploads(self, variable_values, payload) -> Dict[str, Any]:
# If we upload files, we will extract the files present in the
# variable_values dict and replace them by null values
nulled_variable_values, files = extract_files(
variables=variable_values,
file_classes=self.file_classes,
)
# Save the nulled variable values in the payload
payload["variables"] = nulled_variable_values
# Prepare to send multipart-encoded data
data: Dict[str, Any] = {}
file_map: Dict[str, List[str]] = {}
file_streams: Dict[str, Tuple[str, ...]] = {}
for i, (path, f) in enumerate(files.items()):
key = str(i)
# Generate the file map
# path is nested in a list because the spec allows multiple pointers
# to the same file. But we don't support that.
# Will generate something like {"0": ["variables.file"]}
file_map[key] = [path]
# Generate the file streams
# Will generate something like
# {"0": ("variables.file", <_io.BufferedReader ...>)}
name = cast(str, getattr(f, "name", key))
content_type = getattr(f, "content_type", None)
if content_type is None:
file_streams[key] = (name, f)
else:
file_streams[key] = (name, f, content_type)
# Add the payload to the operations field
operations_str = self.json_serialize(payload)
log.debug("operations %s", operations_str)
data["operations"] = operations_str
# Add the file map field
file_map_str = self.json_serialize(file_map)
log.debug("file_map %s", file_map_str)
data["map"] = file_map_str
return {"data": data, "files": file_streams}
def _prepare_result(self, response: httpx.Response) -> ExecutionResult:
# Save latest response headers in transport
self.response_headers = response.headers
if log.isEnabledFor(logging.DEBUG):
log.debug("<<< %s", response.text)
try:
result: Dict[str, Any] = response.json()
except Exception:
self._raise_response_error(response, "Not a JSON answer")
if "errors" not in result and "data" not in result:
self._raise_response_error(response, 'No "data" or "errors" keys in answer')
return ExecutionResult(
errors=result.get("errors"),
data=result.get("data"),
extensions=result.get("extensions"),
)
def _raise_response_error(self, response: httpx.Response, reason: str):
# We raise a TransportServerError if the status code is 400 or higher
# We raise a TransportProtocolError in the other cases
try:
# Raise a HTTPError if response status is 400 or higher
response.raise_for_status()
except httpx.HTTPStatusError as e:
raise TransportServerError(str(e), e.response.status_code) from e
raise TransportProtocolError(
f"Server did not return a GraphQL result: " f"{reason}: " f"{response.text}"
)
class HTTPXTransport(Transport, _HTTPXTransport):
""":ref:`Sync Transport <sync_transports>` used to execute GraphQL queries
on remote servers.
The transport uses the httpx library to send HTTP POST requests.
"""
client: Optional[httpx.Client] = None
def connect(self):
if self.client:
raise TransportAlreadyConnected("Transport is already connected")
log.debug("Connecting transport")
self.client = httpx.Client(**self.kwargs)
def execute( # type: ignore
self,
document: DocumentNode,
variable_values: Optional[Dict[str, Any]] = None,
operation_name: Optional[str] = None,
extra_args: Optional[Dict[str, Any]] = None,
upload_files: bool = False,
) -> ExecutionResult:
"""Execute GraphQL query.
Execute the provided document AST against the configured remote server. This
uses the httpx library to perform a HTTP POST request to the remote server.
:param document: GraphQL query as AST Node object.
:param variable_values: Dictionary of input parameters (Default: None).
:param operation_name: Name of the operation that shall be executed.
Only required in multi-operation documents (Default: None).
:param extra_args: additional arguments to send to the httpx post method
:param upload_files: Set to True if you want to put files in the variable values
:return: The result of execution.
`data` is the result of executing the query, `errors` is null
if no errors occurred, and is a non-empty array if an error occurred.
"""
if not self.client:
raise TransportClosed("Transport is not connected")
post_args = self._prepare_request(
document,
variable_values,
operation_name,
extra_args,
upload_files,
)
response = self.client.post(self.url, **post_args)
return self._prepare_result(response)
def close(self):
"""Closing the transport by closing the inner session"""
if self.client:
self.client.close()
self.client = None
class HTTPXAsyncTransport(AsyncTransport, _HTTPXTransport):
""":ref:`Async Transport <async_transports>` used to execute GraphQL queries
on remote servers.
The transport uses the httpx library with anyio.
"""
client: Optional[httpx.AsyncClient] = None
async def connect(self):
if self.client:
raise TransportAlreadyConnected("Transport is already connected")
log.debug("Connecting transport")
self.client = httpx.AsyncClient(**self.kwargs)
async def execute(
self,
document: DocumentNode,
variable_values: Optional[Dict[str, Any]] = None,
operation_name: Optional[str] = None,
extra_args: Optional[Dict[str, Any]] = None,
upload_files: bool = False,
) -> ExecutionResult:
"""Execute GraphQL query.
Execute the provided document AST against the configured remote server. This
uses the httpx library to perform a HTTP POST request asynchronously to the
remote server.
:param document: GraphQL query as AST Node object.
:param variable_values: Dictionary of input parameters (Default: None).
:param operation_name: Name of the operation that shall be executed.
Only required in multi-operation documents (Default: None).
:param extra_args: additional arguments to send to the httpx post method
:param upload_files: Set to True if you want to put files in the variable values
:return: The result of execution.
`data` is the result of executing the query, `errors` is null
if no errors occurred, and is a non-empty array if an error occurred.
"""
if not self.client:
raise TransportClosed("Transport is not connected")
post_args = self._prepare_request(
document,
variable_values,
operation_name,
extra_args,
upload_files,
)
response = await self.client.post(self.url, **post_args)
return self._prepare_result(response)
async def close(self):
"""Closing the transport by closing the inner session"""
if self.client:
await self.client.aclose()
self.client = None
def subscribe(
self,
document: DocumentNode,
variable_values: Optional[Dict[str, Any]] = None,
operation_name: Optional[str] = None,
) -> AsyncGenerator[ExecutionResult, None]:
"""Subscribe is not supported on HTTP.
:meta private:
"""
raise NotImplementedError("The HTTP transport does not support subscriptions")
@@ -0,0 +1,78 @@
import asyncio
from inspect import isawaitable
from typing import AsyncGenerator, Awaitable, cast
from graphql import DocumentNode, ExecutionResult, GraphQLSchema, execute, subscribe
from gql.transport import AsyncTransport
class LocalSchemaTransport(AsyncTransport):
"""A transport for executing GraphQL queries against a local schema."""
def __init__(
self,
schema: GraphQLSchema,
):
"""Initialize the transport with the given local schema.
:param schema: Local schema as GraphQLSchema object
"""
self.schema = schema
async def connect(self):
"""No connection needed on local transport"""
pass
async def close(self):
"""No close needed on local transport"""
pass
async def execute(
self,
document: DocumentNode,
*args,
**kwargs,
) -> ExecutionResult:
"""Execute the provided document AST for on a local GraphQL Schema."""
result_or_awaitable = execute(self.schema, document, *args, **kwargs)
execution_result: ExecutionResult
if isawaitable(result_or_awaitable):
result_or_awaitable = cast(Awaitable[ExecutionResult], result_or_awaitable)
execution_result = await result_or_awaitable
else:
result_or_awaitable = cast(ExecutionResult, result_or_awaitable)
execution_result = result_or_awaitable
return execution_result
@staticmethod
async def _await_if_necessary(obj):
"""This method is necessary to work with
graphql-core versions < and >= 3.3.0a3"""
return await obj if asyncio.iscoroutine(obj) else obj
async def subscribe(
self,
document: DocumentNode,
*args,
**kwargs,
) -> AsyncGenerator[ExecutionResult, None]:
"""Send a subscription and receive the results using an async generator
The results are sent as an ExecutionResult object
"""
subscribe_result = await self._await_if_necessary(
subscribe(self.schema, document, *args, **kwargs)
)
if isinstance(subscribe_result, ExecutionResult):
yield subscribe_result
else:
async for result in subscribe_result:
yield result
@@ -0,0 +1,414 @@
import asyncio
import json
import logging
from typing import Any, Dict, Optional, Tuple
from graphql import DocumentNode, ExecutionResult, print_ast
from websockets.exceptions import ConnectionClosed
from .exceptions import (
TransportProtocolError,
TransportQueryError,
TransportServerError,
)
from .websockets_base import WebsocketsTransportBase
log = logging.getLogger(__name__)
class Subscription:
"""Records listener_id and unsubscribe query_id for a subscription."""
def __init__(self, query_id: int) -> None:
self.listener_id: int = query_id
self.unsubscribe_id: Optional[int] = None
class PhoenixChannelWebsocketsTransport(WebsocketsTransportBase):
"""The PhoenixChannelWebsocketsTransport is an async transport
which allows you to execute queries and subscriptions against an `Absinthe`_
backend using the `Phoenix`_ framework `channels`_.
.. _Absinthe: http://absinthe-graphql.org
.. _Phoenix: https://www.phoenixframework.org
.. _channels: https://hexdocs.pm/phoenix/Phoenix.Channel.html#content
"""
def __init__(
self,
channel_name: str = "__absinthe__:control",
heartbeat_interval: float = 30,
*args,
**kwargs,
) -> None:
"""Initialize the transport with the given parameters.
:param channel_name: Channel on the server this transport will join.
The default for Absinthe servers is "__absinthe__:control"
:param heartbeat_interval: Interval in second between each heartbeat messages
sent by the client
"""
self.channel_name: str = channel_name
self.heartbeat_interval: float = heartbeat_interval
self.heartbeat_task: Optional[asyncio.Future] = None
self.subscriptions: Dict[str, Subscription] = {}
super(PhoenixChannelWebsocketsTransport, self).__init__(*args, **kwargs)
async def _initialize(self) -> None:
"""Join the specified channel and wait for the connection ACK.
If the answer is not a connection_ack message, we will return an Exception.
"""
query_id = self.next_query_id
self.next_query_id += 1
init_message = json.dumps(
{
"topic": self.channel_name,
"event": "phx_join",
"payload": {},
"ref": query_id,
}
)
await self._send(init_message)
# Wait for the connection_ack message or raise a TimeoutError
init_answer = await asyncio.wait_for(self._receive(), self.ack_timeout)
answer_type, answer_id, execution_result = self._parse_answer(init_answer)
if answer_type != "reply":
raise TransportProtocolError(
"Websocket server did not return a connection ack"
)
async def heartbeat_coro():
while True:
await asyncio.sleep(self.heartbeat_interval)
try:
query_id = self.next_query_id
self.next_query_id += 1
await self._send(
json.dumps(
{
"topic": "phoenix",
"event": "heartbeat",
"payload": {},
"ref": query_id,
}
)
)
except ConnectionClosed: # pragma: no cover
return
self.heartbeat_task = asyncio.ensure_future(heartbeat_coro())
async def _send_stop_message(self, query_id: int) -> None:
"""Send an 'unsubscribe' message to the Phoenix Channel referencing
the listener's query_id, saving the query_id of the message.
The server should afterwards return a 'phx_reply' message with
the same query_id and subscription_id of the 'unsubscribe' request.
"""
subscription_id = self._find_existing_subscription(query_id)
unsubscribe_query_id = self.next_query_id
self.next_query_id += 1
# Save the ref so it can be matched in the reply
self.subscriptions[subscription_id].unsubscribe_id = unsubscribe_query_id
unsubscribe_message = json.dumps(
{
"topic": self.channel_name,
"event": "unsubscribe",
"payload": {"subscriptionId": subscription_id},
"ref": unsubscribe_query_id,
}
)
await self._send(unsubscribe_message)
async def _stop_listener(self, query_id: int) -> None:
await self._send_stop_message(query_id)
async def _send_connection_terminate_message(self) -> None:
"""Send a phx_leave message to disconnect from the provided channel."""
query_id = self.next_query_id
self.next_query_id += 1
connection_terminate_message = json.dumps(
{
"topic": self.channel_name,
"event": "phx_leave",
"payload": {},
"ref": query_id,
}
)
await self._send(connection_terminate_message)
async def _connection_terminate(self):
await self._send_connection_terminate_message()
async def _send_query(
self,
document: DocumentNode,
variable_values: Optional[Dict[str, Any]] = None,
operation_name: Optional[str] = None,
) -> int:
"""Send a query to the provided websocket connection.
We use an incremented id to reference the query.
Returns the used id for this query.
"""
query_id = self.next_query_id
self.next_query_id += 1
query_str = json.dumps(
{
"topic": self.channel_name,
"event": "doc",
"payload": {
"query": print_ast(document),
"variables": variable_values or {},
},
"ref": query_id,
}
)
await self._send(query_str)
return query_id
def _parse_answer(
self, answer: str
) -> Tuple[str, Optional[int], Optional[ExecutionResult]]:
"""Parse the answer received from the server
Returns a list consisting of:
- the answer_type (between:
'data', 'reply', 'complete', 'close')
- the answer id (Integer) if received or None
- an execution Result if the answer_type is 'data' or None
"""
event: str = ""
answer_id: Optional[int] = None
answer_type: str = ""
execution_result: Optional[ExecutionResult] = None
subscription_id: Optional[str] = None
def _get_value(d: Any, key: str, label: str) -> Any:
if not isinstance(d, dict):
raise ValueError(f"{label} is not a dict")
return d.get(key)
def _required_value(d: Any, key: str, label: str) -> Any:
value = _get_value(d, key, label)
if value is None:
raise ValueError(f"null {key} in {label}")
return value
def _required_subscription_id(
d: Any, label: str, must_exist: bool = False, must_not_exist=False
) -> str:
subscription_id = str(_required_value(d, "subscriptionId", label))
if must_exist and (subscription_id not in self.subscriptions):
raise ValueError("unregistered subscriptionId")
if must_not_exist and (subscription_id in self.subscriptions):
raise ValueError("previously registered subscriptionId")
return subscription_id
def _validate_data_response(d: Any, label: str) -> dict:
"""Make sure query, mutation or subscription answer conforms.
The GraphQL spec says only three keys are permitted.
"""
if not isinstance(d, dict):
raise ValueError(f"{label} is not a dict")
keys = set(d.keys())
invalid = keys - {"data", "errors", "extensions"}
if len(invalid) > 0:
raise ValueError(
f"{label} contains invalid items: " + ", ".join(invalid)
)
return d
try:
json_answer = json.loads(answer)
event = str(_required_value(json_answer, "event", "answer"))
if event == "subscription:data":
payload = _required_value(json_answer, "payload", "answer")
subscription_id = _required_subscription_id(
payload, "payload", must_exist=True
)
result = _validate_data_response(payload.get("result"), "result")
answer_type = "data"
subscription = self.subscriptions[subscription_id]
answer_id = subscription.listener_id
execution_result = ExecutionResult(
data=result.get("data"),
errors=result.get("errors"),
extensions=result.get("extensions"),
)
elif event == "phx_reply":
# Will generate a ValueError if 'ref' is not there
# or if it is not an integer
answer_id = int(_required_value(json_answer, "ref", "answer"))
payload = _required_value(json_answer, "payload", "answer")
status = _get_value(payload, "status", "payload")
if status == "ok":
answer_type = "reply"
if answer_id in self.listeners:
response = _required_value(payload, "response", "payload")
if isinstance(response, dict) and "subscriptionId" in response:
# Subscription answer
subscription_id = _required_subscription_id(
response, "response", must_not_exist=True
)
self.subscriptions[subscription_id] = Subscription(
answer_id
)
else:
# Query or mutation answer
# GraphQL spec says only three keys are permitted
response = _validate_data_response(response, "response")
answer_type = "data"
execution_result = ExecutionResult(
data=response.get("data"),
errors=response.get("errors"),
extensions=response.get("extensions"),
)
else:
(
registered_subscription_id,
listener_id,
) = self._find_subscription(answer_id)
if registered_subscription_id is not None:
# Unsubscription answer
response = _required_value(payload, "response", "payload")
subscription_id = _required_subscription_id(
response, "response"
)
if subscription_id != registered_subscription_id:
raise ValueError("subscription id does not match")
answer_type = "complete"
answer_id = listener_id
elif status == "error":
response = payload.get("response")
if isinstance(response, dict):
if "errors" in response:
raise TransportQueryError(
str(response.get("errors")), query_id=answer_id
)
elif "reason" in response:
raise TransportQueryError(
str(response.get("reason")), query_id=answer_id
)
raise TransportQueryError("reply error", query_id=answer_id)
elif status == "timeout":
raise TransportQueryError("reply timeout", query_id=answer_id)
else:
# missing or unrecognized status, just continue
pass
elif event == "phx_error":
# Sent if the channel has crashed
# answer_id will be the "join_ref" for the channel
# answer_id = int(json_answer.get("ref"))
raise TransportServerError("Server error")
elif event == "phx_close":
answer_type = "close"
else:
raise ValueError("unrecognized event")
except ValueError as e:
log.error(f"Error parsing answer '{answer}': {e!r}")
raise TransportProtocolError(
f"Server did not return a GraphQL result: {e!s}"
) from e
return answer_type, answer_id, execution_result
async def _handle_answer(
self,
answer_type: str,
answer_id: Optional[int],
execution_result: Optional[ExecutionResult],
) -> None:
if answer_type == "close":
await self.close()
else:
await super()._handle_answer(answer_type, answer_id, execution_result)
def _remove_listener(self, query_id: int) -> None:
"""If the listener was a subscription, remove that information."""
try:
subscription_id = self._find_existing_subscription(query_id)
del self.subscriptions[subscription_id]
except Exception:
pass
super()._remove_listener(query_id)
def _find_subscription(self, query_id: int) -> Tuple[Optional[str], int]:
"""Perform a reverse lookup to find the subscription id matching
a listener's query_id.
"""
for subscription_id, subscription in self.subscriptions.items():
if query_id == subscription.listener_id:
return subscription_id, query_id
if query_id == subscription.unsubscribe_id:
return subscription_id, subscription.listener_id
return None, query_id
def _find_existing_subscription(self, query_id: int) -> str:
"""Perform a reverse lookup to find the subscription id matching
a listener's query_id.
"""
subscription_id, _listener_id = self._find_subscription(query_id)
if subscription_id is None:
raise TransportProtocolError(
f"No subscription registered for listener {query_id}"
)
return subscription_id
async def _close_coro(self, e: Exception, clean_close: bool = True) -> None:
if self.heartbeat_task is not None:
self.heartbeat_task.cancel()
await super()._close_coro(e, clean_close)
@@ -0,0 +1,426 @@
import io
import json
import logging
from typing import Any, Collection, Dict, List, Optional, Tuple, Type, Union
import requests
from graphql import DocumentNode, ExecutionResult, print_ast
from requests.adapters import HTTPAdapter, Retry
from requests.auth import AuthBase
from requests.cookies import RequestsCookieJar
from requests_toolbelt.multipart.encoder import MultipartEncoder
from gql.transport import Transport
from ..graphql_request import GraphQLRequest
from ..utils import extract_files
from .exceptions import (
TransportAlreadyConnected,
TransportClosed,
TransportProtocolError,
TransportServerError,
)
log = logging.getLogger(__name__)
class RequestsHTTPTransport(Transport):
""":ref:`Sync Transport <sync_transports>` used to execute GraphQL queries
on remote servers.
The transport uses the requests library to send HTTP POST requests.
"""
file_classes: Tuple[Type[Any], ...] = (io.IOBase,)
_default_retry_codes = (429, 500, 502, 503, 504)
def __init__(
self,
url: str,
headers: Optional[Dict[str, Any]] = None,
cookies: Optional[Union[Dict[str, Any], RequestsCookieJar]] = None,
auth: Optional[AuthBase] = None,
use_json: bool = True,
timeout: Optional[int] = None,
verify: Union[bool, str] = True,
retries: int = 0,
method: str = "POST",
retry_backoff_factor: float = 0.1,
retry_status_forcelist: Collection[int] = _default_retry_codes,
**kwargs: Any,
):
"""Initialize the transport with the given request parameters.
:param url: The GraphQL server URL.
:param headers: Dictionary of HTTP Headers to send with the :class:`Request`
(Default: None).
:param cookies: Dict or CookieJar object to send with the :class:`Request`
(Default: None).
:param auth: Auth tuple or callable to enable Basic/Digest/Custom HTTP Auth
(Default: None).
:param use_json: Send request body as JSON instead of form-urlencoded
(Default: True).
:param timeout: Specifies a default timeout for requests (Default: None).
:param verify: Either a boolean, in which case it controls whether we verify
the server's TLS certificate, or a string, in which case it must be a path
to a CA bundle to use. (Default: True).
:param retries: Pre-setup of the requests' Session for performing retries
:param method: HTTP method used for requests. (Default: POST).
:param retry_backoff_factor: A backoff factor to apply between attempts after
the second try. urllib3 will sleep for:
{backoff factor} * (2 ** ({number of previous retries}))
:param retry_status_forcelist: A set of integer HTTP status codes that we
should force a retry on. A retry is initiated if the request method is
in allowed_methods and the response status code is in status_forcelist.
(Default: [429, 500, 502, 503, 504])
:param kwargs: Optional arguments that ``request`` takes.
These can be seen at the `requests`_ source code or the official `docs`_
.. _requests: https://github.com/psf/requests/blob/master/requests/api.py
.. _docs: https://requests.readthedocs.io/en/master/
"""
self.url = url
self.headers = headers
self.cookies = cookies
self.auth = auth
self.use_json = use_json
self.default_timeout = timeout
self.verify = verify
self.retries = retries
self.method = method
self.retry_backoff_factor = retry_backoff_factor
self.retry_status_forcelist = retry_status_forcelist
self.kwargs = kwargs
self.session = None
self.response_headers = None
def connect(self):
if self.session is None:
# Creating a session that can later be re-use to configure custom mechanisms
self.session = requests.Session()
# If we specified some retries, we provide a predefined retry-logic
if self.retries > 0:
adapter = HTTPAdapter(
max_retries=Retry(
total=self.retries,
backoff_factor=self.retry_backoff_factor,
status_forcelist=self.retry_status_forcelist,
allowed_methods=None,
)
)
for prefix in "http://", "https://":
self.session.mount(prefix, adapter)
else:
raise TransportAlreadyConnected("Transport is already connected")
def execute( # type: ignore
self,
document: DocumentNode,
variable_values: Optional[Dict[str, Any]] = None,
operation_name: Optional[str] = None,
timeout: Optional[int] = None,
extra_args: Optional[Dict[str, Any]] = None,
upload_files: bool = False,
) -> ExecutionResult:
"""Execute GraphQL query.
Execute the provided document AST against the configured remote server. This
uses the requests library to perform a HTTP POST request to the remote server.
:param document: GraphQL query as AST Node object.
:param variable_values: Dictionary of input parameters (Default: None).
:param operation_name: Name of the operation that shall be executed.
Only required in multi-operation documents (Default: None).
:param timeout: Specifies a default timeout for requests (Default: None).
:param extra_args: additional arguments to send to the requests post method
:param upload_files: Set to True if you want to put files in the variable values
:return: The result of execution.
`data` is the result of executing the query, `errors` is null
if no errors occurred, and is a non-empty array if an error occurred.
"""
if not self.session:
raise TransportClosed("Transport is not connected")
query_str = print_ast(document)
payload: Dict[str, Any] = {"query": query_str}
if operation_name:
payload["operationName"] = operation_name
post_args = {
"headers": self.headers,
"auth": self.auth,
"cookies": self.cookies,
"timeout": timeout or self.default_timeout,
"verify": self.verify,
}
if upload_files:
# If the upload_files flag is set, then we need variable_values
assert variable_values is not None
# If we upload files, we will extract the files present in the
# variable_values dict and replace them by null values
nulled_variable_values, files = extract_files(
variables=variable_values,
file_classes=self.file_classes,
)
# Save the nulled variable values in the payload
payload["variables"] = nulled_variable_values
# Add the payload to the operations field
operations_str = json.dumps(payload)
log.debug("operations %s", operations_str)
# Generate the file map
# path is nested in a list because the spec allows multiple pointers
# to the same file. But we don't support that.
# Will generate something like {"0": ["variables.file"]}
file_map = {str(i): [path] for i, path in enumerate(files)}
# Enumerate the file streams
# Will generate something like {'0': <_io.BufferedReader ...>}
file_streams = {str(i): files[path] for i, path in enumerate(files)}
# Add the file map field
file_map_str = json.dumps(file_map)
log.debug("file_map %s", file_map_str)
fields = {"operations": operations_str, "map": file_map_str}
# Add the extracted files as remaining fields
for k, f in file_streams.items():
name = getattr(f, "name", k)
content_type = getattr(f, "content_type", None)
if content_type is None:
fields[k] = (name, f)
else:
fields[k] = (name, f, content_type)
# Prepare requests http to send multipart-encoded data
data = MultipartEncoder(fields=fields)
post_args["data"] = data
if post_args["headers"] is None:
post_args["headers"] = {}
else:
post_args["headers"] = {**post_args["headers"]}
post_args["headers"]["Content-Type"] = data.content_type
else:
if variable_values:
payload["variables"] = variable_values
data_key = "json" if self.use_json else "data"
post_args[data_key] = payload
# Log the payload
if log.isEnabledFor(logging.INFO):
log.info(">>> %s", json.dumps(payload))
# Pass kwargs to requests post method
post_args.update(self.kwargs)
# Pass post_args to requests post method
if extra_args:
post_args.update(extra_args)
# Using the created session to perform requests
response = self.session.request(
self.method, self.url, **post_args # type: ignore
)
self.response_headers = response.headers
def raise_response_error(resp: requests.Response, reason: str):
# We raise a TransportServerError if the status code is 400 or higher
# We raise a TransportProtocolError in the other cases
try:
# Raise a HTTPError if response status is 400 or higher
resp.raise_for_status()
except requests.HTTPError as e:
raise TransportServerError(str(e), e.response.status_code) from e
result_text = resp.text
raise TransportProtocolError(
f"Server did not return a GraphQL result: "
f"{reason}: "
f"{result_text}"
)
try:
result = response.json()
if log.isEnabledFor(logging.INFO):
log.info("<<< %s", response.text)
except Exception:
raise_response_error(response, "Not a JSON answer")
if "errors" not in result and "data" not in result:
raise_response_error(response, 'No "data" or "errors" keys in answer')
return ExecutionResult(
errors=result.get("errors"),
data=result.get("data"),
extensions=result.get("extensions"),
)
def execute_batch( # type: ignore
self,
reqs: List[GraphQLRequest],
timeout: Optional[int] = None,
extra_args: Optional[Dict[str, Any]] = None,
) -> List[ExecutionResult]:
"""Execute multiple GraphQL requests in a batch.
Execute the provided requests against the configured remote server. This
uses the requests library to perform a HTTP POST request to the remote server.
:param reqs: GraphQL requests as a list of GraphQLRequest objects.
:param timeout: Specifies a default timeout for requests (Default: None).
:param extra_args: additional arguments to send to the requests post method
:return: A list of results of execution.
For every result `data` is the result of executing the query,
`errors` is null if no errors occurred, and is a non-empty array
if an error occurred.
"""
if not self.session:
raise TransportClosed("Transport is not connected")
# Using the created session to perform requests
response = self.session.request(
self.method,
self.url,
**self._build_batch_post_args(reqs, timeout, extra_args),
)
self.response_headers = response.headers
answers = self._extract_response(response)
self._validate_answer_is_a_list(answers)
self._validate_num_of_answers_same_as_requests(reqs, answers)
self._validate_every_answer_is_a_dict(answers)
self._validate_data_and_errors_keys_in_answers(answers)
return [self._answer_to_execution_result(answer) for answer in answers]
def _answer_to_execution_result(self, result: Dict[str, Any]) -> ExecutionResult:
return ExecutionResult(
errors=result.get("errors"),
data=result.get("data"),
extensions=result.get("extensions"),
)
def _validate_answer_is_a_list(self, results: Any) -> None:
if not isinstance(results, list):
self._raise_invalid_result(
str(results),
"Answer is not a list",
)
def _validate_data_and_errors_keys_in_answers(
self, results: List[Dict[str, Any]]
) -> None:
for result in results:
if "errors" not in result and "data" not in result:
self._raise_invalid_result(
str(results),
'No "data" or "errors" keys in answer',
)
def _validate_every_answer_is_a_dict(self, results: List[Dict[str, Any]]) -> None:
for result in results:
if not isinstance(result, dict):
self._raise_invalid_result(str(results), "Not every answer is dict")
def _validate_num_of_answers_same_as_requests(
self,
reqs: List[GraphQLRequest],
results: List[Dict[str, Any]],
) -> None:
if len(reqs) != len(results):
self._raise_invalid_result(
str(results),
"Invalid answer length",
)
def _raise_invalid_result(self, result_text: str, reason: str) -> None:
raise TransportProtocolError(
f"Server did not return a valid GraphQL result: "
f"{reason}: "
f"{result_text}"
)
def _extract_response(self, response: requests.Response) -> Any:
try:
response.raise_for_status()
result = response.json()
if log.isEnabledFor(logging.INFO):
log.info("<<< %s", response.text)
except requests.HTTPError as e:
raise TransportServerError(str(e), e.response.status_code) from e
except Exception:
self._raise_invalid_result(str(response.text), "Not a JSON answer")
return result
def _build_batch_post_args(
self,
reqs: List[GraphQLRequest],
timeout: Optional[int] = None,
extra_args: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
post_args: Dict[str, Any] = {
"headers": self.headers,
"auth": self.auth,
"cookies": self.cookies,
"timeout": timeout or self.default_timeout,
"verify": self.verify,
}
data_key = "json" if self.use_json else "data"
post_args[data_key] = [self._build_data(req) for req in reqs]
# Log the payload
if log.isEnabledFor(logging.INFO):
log.info(">>> %s", json.dumps(post_args[data_key]))
# Pass kwargs to requests post method
post_args.update(self.kwargs)
# Pass post_args to requests post method
if extra_args:
post_args.update(extra_args)
return post_args
def _build_data(self, req: GraphQLRequest) -> Dict[str, Any]:
query_str = print_ast(req.document)
payload: Dict[str, Any] = {"query": query_str}
if req.operation_name:
payload["operationName"] = req.operation_name
if req.variable_values:
payload["variables"] = req.variable_values
return payload
def close(self):
"""Closing the transport by closing the inner session"""
if self.session:
self.session.close()
self.session = None
@@ -0,0 +1,51 @@
import abc
from typing import List
from graphql import DocumentNode, ExecutionResult
from ..graphql_request import GraphQLRequest
class Transport(abc.ABC):
@abc.abstractmethod
def execute(self, document: DocumentNode, *args, **kwargs) -> ExecutionResult:
"""Execute GraphQL query.
Execute the provided document AST for either a remote or local GraphQL Schema.
:param document: GraphQL query as AST Node or Document object.
:return: ExecutionResult
"""
raise NotImplementedError(
"Any Transport subclass must implement execute method"
) # pragma: no cover
def execute_batch(
self,
reqs: List[GraphQLRequest],
*args,
**kwargs,
) -> List[ExecutionResult]:
"""Execute multiple GraphQL requests in a batch.
Execute the provided requests for either a remote or local GraphQL Schema.
:param reqs: GraphQL requests as a list of GraphQLRequest objects.
:return: a list of ExecutionResult objects
"""
raise NotImplementedError(
"This Transport has not implemented the execute_batch method"
) # pragma: no cover
def connect(self):
"""Establish a session with the transport."""
pass # pragma: no cover
def close(self):
"""Close the transport
This method doesn't have to be implemented unless the transport would benefit
from it. This is currently used by the RequestsHTTPTransport transport to close
the session's connection pool.
"""
pass # pragma: no cover
@@ -0,0 +1,514 @@
import asyncio
import json
import logging
from contextlib import suppress
from ssl import SSLContext
from typing import Any, Dict, List, Optional, Tuple, Union, cast
from graphql import DocumentNode, ExecutionResult, print_ast
from websockets.datastructures import HeadersLike
from websockets.typing import Subprotocol
from .exceptions import (
TransportProtocolError,
TransportQueryError,
TransportServerError,
)
from .websockets_base import WebsocketsTransportBase
log = logging.getLogger(__name__)
class WebsocketsTransport(WebsocketsTransportBase):
""":ref:`Async Transport <async_transports>` used to execute GraphQL queries on
remote servers with websocket connection.
This transport uses asyncio and the websockets library in order to send requests
on a websocket connection.
"""
# This transport supports two subprotocols and will autodetect the
# subprotocol supported on the server
APOLLO_SUBPROTOCOL = cast(Subprotocol, "graphql-ws")
GRAPHQLWS_SUBPROTOCOL = cast(Subprotocol, "graphql-transport-ws")
def __init__(
self,
url: str,
headers: Optional[HeadersLike] = None,
ssl: Union[SSLContext, bool] = False,
init_payload: Dict[str, Any] = {},
connect_timeout: Optional[Union[int, float]] = 10,
close_timeout: Optional[Union[int, float]] = 10,
ack_timeout: Optional[Union[int, float]] = 10,
keep_alive_timeout: Optional[Union[int, float]] = None,
ping_interval: Optional[Union[int, float]] = None,
pong_timeout: Optional[Union[int, float]] = None,
answer_pings: bool = True,
connect_args: Dict[str, Any] = {},
subprotocols: Optional[List[Subprotocol]] = None,
) -> None:
"""Initialize the transport with the given parameters.
:param url: The GraphQL server URL. Example: 'wss://server.com:PORT/graphql'.
:param headers: Dict of HTTP Headers.
:param ssl: ssl_context of the connection. Use ssl=False to disable encryption
:param init_payload: Dict of the payload sent in the connection_init message.
:param connect_timeout: Timeout in seconds for the establishment
of the websocket connection. If None is provided this will wait forever.
:param close_timeout: Timeout in seconds for the close. If None is provided
this will wait forever.
:param ack_timeout: Timeout in seconds to wait for the connection_ack message
from the server. If None is provided this will wait forever.
:param keep_alive_timeout: Optional Timeout in seconds to receive
a sign of liveness from the server.
:param ping_interval: Delay in seconds between pings sent by the client to
the backend for the graphql-ws protocol. None (by default) means that
we don't send pings. Note: there are also pings sent by the underlying
websockets protocol. See the
:ref:`keepalive documentation <websockets_transport_keepalives>`
for more information about this.
:param pong_timeout: Delay in seconds to receive a pong from the backend
after we sent a ping (only for the graphql-ws protocol).
By default equal to half of the ping_interval.
:param answer_pings: Whether the client answers the pings from the backend
(for the graphql-ws protocol).
By default: True
:param connect_args: Other parameters forwarded to
`websockets.connect <https://websockets.readthedocs.io/en/stable/reference/\
client.html#opening-a-connection>`_
:param subprotocols: list of subprotocols sent to the
backend in the 'subprotocols' http header.
By default: both apollo and graphql-ws subprotocols.
"""
super().__init__(
url,
headers,
ssl,
init_payload,
connect_timeout,
close_timeout,
ack_timeout,
keep_alive_timeout,
connect_args,
)
self.ping_interval: Optional[Union[int, float]] = ping_interval
self.pong_timeout: Optional[Union[int, float]]
self.answer_pings: bool = answer_pings
if ping_interval is not None:
if pong_timeout is None:
self.pong_timeout = ping_interval / 2
else:
self.pong_timeout = pong_timeout
self.send_ping_task: Optional[asyncio.Future] = None
self.ping_received: asyncio.Event = asyncio.Event()
"""ping_received is an asyncio Event which will fire each time
a ping is received with the graphql-ws protocol"""
self.pong_received: asyncio.Event = asyncio.Event()
"""pong_received is an asyncio Event which will fire each time
a pong is received with the graphql-ws protocol"""
if subprotocols is None:
self.supported_subprotocols = [
self.APOLLO_SUBPROTOCOL,
self.GRAPHQLWS_SUBPROTOCOL,
]
else:
self.supported_subprotocols = subprotocols
async def _wait_ack(self) -> None:
"""Wait for the connection_ack message. Keep alive messages are ignored"""
while True:
init_answer = await self._receive()
answer_type, answer_id, execution_result = self._parse_answer(init_answer)
if answer_type == "connection_ack":
return
if answer_type != "ka":
raise TransportProtocolError(
"Websocket server did not return a connection ack"
)
async def _send_init_message_and_wait_ack(self) -> None:
"""Send init message to the provided websocket and wait for the connection ACK.
If the answer is not a connection_ack message, we will return an Exception.
"""
init_message = json.dumps(
{"type": "connection_init", "payload": self.init_payload}
)
await self._send(init_message)
# Wait for the connection_ack message or raise a TimeoutError
await asyncio.wait_for(self._wait_ack(), self.ack_timeout)
async def _initialize(self):
await self._send_init_message_and_wait_ack()
async def send_ping(self, payload: Optional[Any] = None) -> None:
"""Send a ping message for the graphql-ws protocol"""
ping_message = {"type": "ping"}
if payload is not None:
ping_message["payload"] = payload
await self._send(json.dumps(ping_message))
async def send_pong(self, payload: Optional[Any] = None) -> None:
"""Send a pong message for the graphql-ws protocol"""
pong_message = {"type": "pong"}
if payload is not None:
pong_message["payload"] = payload
await self._send(json.dumps(pong_message))
async def _send_stop_message(self, query_id: int) -> None:
"""Send stop message to the provided websocket connection and query_id.
The server should afterwards return a 'complete' message.
"""
stop_message = json.dumps({"id": str(query_id), "type": "stop"})
await self._send(stop_message)
async def _send_complete_message(self, query_id: int) -> None:
"""Send a complete message for the provided query_id.
This is only for the graphql-ws protocol.
"""
complete_message = json.dumps({"id": str(query_id), "type": "complete"})
await self._send(complete_message)
async def _stop_listener(self, query_id: int):
"""Stop the listener corresponding to the query_id depending on the
detected backend protocol.
For apollo: send a "stop" message
(a "complete" message will be sent from the backend)
For graphql-ws: send a "complete" message and simulate the reception
of a "complete" message from the backend
"""
log.debug(f"stop listener {query_id}")
if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL:
await self._send_complete_message(query_id)
await self.listeners[query_id].put(("complete", None))
else:
await self._send_stop_message(query_id)
async def _send_connection_terminate_message(self) -> None:
"""Send a connection_terminate message to the provided websocket connection.
This message indicates that the connection will disconnect.
"""
connection_terminate_message = json.dumps({"type": "connection_terminate"})
await self._send(connection_terminate_message)
async def _send_query(
self,
document: DocumentNode,
variable_values: Optional[Dict[str, Any]] = None,
operation_name: Optional[str] = None,
) -> int:
"""Send a query to the provided websocket connection.
We use an incremented id to reference the query.
Returns the used id for this query.
"""
query_id = self.next_query_id
self.next_query_id += 1
payload: Dict[str, Any] = {"query": print_ast(document)}
if variable_values:
payload["variables"] = variable_values
if operation_name:
payload["operationName"] = operation_name
query_type = "start"
if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL:
query_type = "subscribe"
query_str = json.dumps(
{"id": str(query_id), "type": query_type, "payload": payload}
)
await self._send(query_str)
return query_id
async def _connection_terminate(self):
if self.subprotocol == self.APOLLO_SUBPROTOCOL:
await self._send_connection_terminate_message()
def _parse_answer_graphqlws(
self, json_answer: Dict[str, Any]
) -> Tuple[str, Optional[int], Optional[ExecutionResult]]:
"""Parse the answer received from the server if the server supports the
graphql-ws protocol.
Returns a list consisting of:
- the answer_type (between:
'connection_ack', 'ping', 'pong', 'data', 'error', 'complete')
- the answer id (Integer) if received or None
- an execution Result if the answer_type is 'data' or None
Differences with the apollo websockets protocol (superclass):
- the "data" message is now called "next"
- the "stop" message is now called "complete"
- there is no connection_terminate or connection_error messages
- instead of a unidirectional keep-alive (ka) message from server to client,
there is now the possibility to send bidirectional ping/pong messages
- connection_ack has an optional payload
- the 'error' answer type returns a list of errors instead of a single error
"""
answer_type: str = ""
answer_id: Optional[int] = None
execution_result: Optional[ExecutionResult] = None
try:
answer_type = str(json_answer.get("type"))
if answer_type in ["next", "error", "complete"]:
answer_id = int(str(json_answer.get("id")))
if answer_type == "next" or answer_type == "error":
payload = json_answer.get("payload")
if answer_type == "next":
if not isinstance(payload, dict):
raise ValueError("payload is not a dict")
if "errors" not in payload and "data" not in payload:
raise ValueError(
"payload does not contain 'data' or 'errors' fields"
)
execution_result = ExecutionResult(
errors=payload.get("errors"),
data=payload.get("data"),
extensions=payload.get("extensions"),
)
# Saving answer_type as 'data' to be understood with superclass
answer_type = "data"
elif answer_type == "error":
if not isinstance(payload, list):
raise ValueError("payload is not a list")
raise TransportQueryError(
str(payload[0]), query_id=answer_id, errors=payload
)
elif answer_type in ["ping", "pong", "connection_ack"]:
self.payloads[answer_type] = json_answer.get("payload", None)
else:
raise ValueError
if self.check_keep_alive_task is not None:
self._next_keep_alive_message.set()
except ValueError as e:
raise TransportProtocolError(
f"Server did not return a GraphQL result: {json_answer}"
) from e
return answer_type, answer_id, execution_result
def _parse_answer_apollo(
self, json_answer: Dict[str, Any]
) -> Tuple[str, Optional[int], Optional[ExecutionResult]]:
"""Parse the answer received from the server if the server supports the
apollo websockets protocol.
Returns a list consisting of:
- the answer_type (between:
'connection_ack', 'ka', 'connection_error', 'data', 'error', 'complete')
- the answer id (Integer) if received or None
- an execution Result if the answer_type is 'data' or None
"""
answer_type: str = ""
answer_id: Optional[int] = None
execution_result: Optional[ExecutionResult] = None
try:
answer_type = str(json_answer.get("type"))
if answer_type in ["data", "error", "complete"]:
answer_id = int(str(json_answer.get("id")))
if answer_type == "data" or answer_type == "error":
payload = json_answer.get("payload")
if not isinstance(payload, dict):
raise ValueError("payload is not a dict")
if answer_type == "data":
if "errors" not in payload and "data" not in payload:
raise ValueError(
"payload does not contain 'data' or 'errors' fields"
)
execution_result = ExecutionResult(
errors=payload.get("errors"),
data=payload.get("data"),
extensions=payload.get("extensions"),
)
elif answer_type == "error":
raise TransportQueryError(
str(payload), query_id=answer_id, errors=[payload]
)
elif answer_type == "ka":
# Keep-alive message
if self.check_keep_alive_task is not None:
self._next_keep_alive_message.set()
elif answer_type == "connection_ack":
pass
elif answer_type == "connection_error":
error_payload = json_answer.get("payload")
raise TransportServerError(f"Server error: '{repr(error_payload)}'")
else:
raise ValueError
except ValueError as e:
raise TransportProtocolError(
f"Server did not return a GraphQL result: {json_answer}"
) from e
return answer_type, answer_id, execution_result
def _parse_answer(
self, answer: str
) -> Tuple[str, Optional[int], Optional[ExecutionResult]]:
"""Parse the answer received from the server depending on
the detected subprotocol.
"""
try:
json_answer = json.loads(answer)
except ValueError:
raise TransportProtocolError(
f"Server did not return a GraphQL result: {answer}"
)
if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL:
return self._parse_answer_graphqlws(json_answer)
return self._parse_answer_apollo(json_answer)
async def _send_ping_coro(self) -> None:
"""Coroutine to periodically send a ping from the client to the backend.
Only used for the graphql-ws protocol.
Send a ping every ping_interval seconds.
Close the connection if a pong is not received within pong_timeout seconds.
"""
assert self.ping_interval is not None
try:
while True:
await asyncio.sleep(self.ping_interval)
await self.send_ping()
await asyncio.wait_for(self.pong_received.wait(), self.pong_timeout)
# Reset for the next iteration
self.pong_received.clear()
except asyncio.TimeoutError:
# No pong received in the appriopriate time, close with error
# If the timeout happens during a close already in progress, do nothing
if self.close_task is None:
await self._fail(
TransportServerError(
f"No pong received after {self.pong_timeout!r} seconds"
),
clean_close=False,
)
async def _handle_answer(
self,
answer_type: str,
answer_id: Optional[int],
execution_result: Optional[ExecutionResult],
) -> None:
# Put the answer in the queue
await super()._handle_answer(answer_type, answer_id, execution_result)
# Answer pong to ping for graphql-ws protocol
if answer_type == "ping":
self.ping_received.set()
if self.answer_pings:
await self.send_pong()
elif answer_type == "pong":
self.pong_received.set()
async def _after_connect(self):
# Find the backend subprotocol returned in the response headers
response_headers = self.websocket.response_headers
try:
self.subprotocol = response_headers["Sec-WebSocket-Protocol"]
except KeyError:
# If the server does not send the subprotocol header, using
# the apollo subprotocol by default
self.subprotocol = self.APOLLO_SUBPROTOCOL
log.debug(f"backend subprotocol returned: {self.subprotocol!r}")
async def _after_initialize(self):
# If requested, create a task to send periodic pings to the backend
if (
self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL
and self.ping_interval is not None
):
self.send_ping_task = asyncio.ensure_future(self._send_ping_coro())
async def _close_hook(self):
# Properly shut down the send ping task if enabled
if self.send_ping_task is not None:
self.send_ping_task.cancel()
with suppress(asyncio.CancelledError):
await self.send_ping_task
self.send_ping_task = None
@@ -0,0 +1,669 @@
import asyncio
import logging
import warnings
from abc import abstractmethod
from contextlib import suppress
from ssl import SSLContext
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union, cast
import websockets
from graphql import DocumentNode, ExecutionResult
from websockets.client import WebSocketClientProtocol
from websockets.datastructures import Headers, HeadersLike
from websockets.exceptions import ConnectionClosed
from websockets.typing import Data, Subprotocol
from .async_transport import AsyncTransport
from .exceptions import (
TransportAlreadyConnected,
TransportClosed,
TransportProtocolError,
TransportQueryError,
TransportServerError,
)
log = logging.getLogger("gql.transport.websockets")
ParsedAnswer = Tuple[str, Optional[ExecutionResult]]
class ListenerQueue:
"""Special queue used for each query waiting for server answers
If the server is stopped while the listener is still waiting,
Then we send an exception to the queue and this exception will be raised
to the consumer once all the previous messages have been consumed from the queue
"""
def __init__(self, query_id: int, send_stop: bool) -> None:
self.query_id: int = query_id
self.send_stop: bool = send_stop
self._queue: asyncio.Queue = asyncio.Queue()
self._closed: bool = False
async def get(self) -> ParsedAnswer:
try:
item = self._queue.get_nowait()
except asyncio.QueueEmpty:
item = await self._queue.get()
self._queue.task_done()
# If we receive an exception when reading the queue, we raise it
if isinstance(item, Exception):
self._closed = True
raise item
# Don't need to save new answers or
# send the stop message if we already received the complete message
answer_type, execution_result = item
if answer_type == "complete":
self.send_stop = False
self._closed = True
return item
async def put(self, item: ParsedAnswer) -> None:
if not self._closed:
await self._queue.put(item)
async def set_exception(self, exception: Exception) -> None:
# Put the exception in the queue
await self._queue.put(exception)
# Don't need to send stop messages in case of error
self.send_stop = False
self._closed = True
class WebsocketsTransportBase(AsyncTransport):
"""abstract :ref:`Async Transport <async_transports>` used to implement
different websockets protocols.
This transport uses asyncio and the websockets library in order to send requests
on a websocket connection.
"""
def __init__(
self,
url: str,
headers: Optional[HeadersLike] = None,
ssl: Union[SSLContext, bool] = False,
init_payload: Dict[str, Any] = {},
connect_timeout: Optional[Union[int, float]] = 10,
close_timeout: Optional[Union[int, float]] = 10,
ack_timeout: Optional[Union[int, float]] = 10,
keep_alive_timeout: Optional[Union[int, float]] = None,
connect_args: Dict[str, Any] = {},
) -> None:
"""Initialize the transport with the given parameters.
:param url: The GraphQL server URL. Example: 'wss://server.com:PORT/graphql'.
:param headers: Dict of HTTP Headers.
:param ssl: ssl_context of the connection. Use ssl=False to disable encryption
:param init_payload: Dict of the payload sent in the connection_init message.
:param connect_timeout: Timeout in seconds for the establishment
of the websocket connection. If None is provided this will wait forever.
:param close_timeout: Timeout in seconds for the close. If None is provided
this will wait forever.
:param ack_timeout: Timeout in seconds to wait for the connection_ack message
from the server. If None is provided this will wait forever.
:param keep_alive_timeout: Optional Timeout in seconds to receive
a sign of liveness from the server.
:param connect_args: Other parameters forwarded to websockets.connect
"""
self.url: str = url
self.headers: Optional[HeadersLike] = headers
self.ssl: Union[SSLContext, bool] = ssl
self.init_payload: Dict[str, Any] = init_payload
self.connect_timeout: Optional[Union[int, float]] = connect_timeout
self.close_timeout: Optional[Union[int, float]] = close_timeout
self.ack_timeout: Optional[Union[int, float]] = ack_timeout
self.keep_alive_timeout: Optional[Union[int, float]] = keep_alive_timeout
self.connect_args = connect_args
self.websocket: Optional[WebSocketClientProtocol] = None
self.next_query_id: int = 1
self.listeners: Dict[int, ListenerQueue] = {}
self.receive_data_task: Optional[asyncio.Future] = None
self.check_keep_alive_task: Optional[asyncio.Future] = None
self.close_task: Optional[asyncio.Future] = None
# We need to set an event loop here if there is none
# Or else we will not be able to create an asyncio.Event()
try:
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore", message="There is no current event loop"
)
self._loop = asyncio.get_event_loop()
except RuntimeError:
self._loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._loop)
self._wait_closed: asyncio.Event = asyncio.Event()
self._wait_closed.set()
self._no_more_listeners: asyncio.Event = asyncio.Event()
self._no_more_listeners.set()
if self.keep_alive_timeout is not None:
self._next_keep_alive_message: asyncio.Event = asyncio.Event()
self._next_keep_alive_message.set()
self.payloads: Dict[str, Any] = {}
"""payloads is a dict which will contain the payloads received
for example with the graphql-ws protocol: 'ping', 'pong', 'connection_ack'"""
self._connecting: bool = False
self.close_exception: Optional[Exception] = None
# The list of supported subprotocols should be defined in the subclass
self.supported_subprotocols: List[Subprotocol] = []
self.response_headers: Optional[Headers] = None
async def _initialize(self):
"""Hook to send the initialization messages after the connection
and potentially wait for the backend ack.
"""
pass # pragma: no cover
async def _stop_listener(self, query_id: int):
"""Hook to stop to listen to a specific query.
Will send a stop message in some subclasses.
"""
pass # pragma: no cover
async def _after_connect(self):
"""Hook to add custom code for subclasses after the connection
has been established.
"""
pass # pragma: no cover
async def _after_initialize(self):
"""Hook to add custom code for subclasses after the initialization
has been done.
"""
pass # pragma: no cover
async def _close_hook(self):
"""Hook to add custom code for subclasses for the connection close"""
pass # pragma: no cover
async def _connection_terminate(self):
"""Hook to add custom code for subclasses after the initialization
has been done.
"""
pass # pragma: no cover
async def _send(self, message: str) -> None:
"""Send the provided message to the websocket connection and log the message"""
if not self.websocket:
raise TransportClosed(
"Transport is not connected"
) from self.close_exception
try:
await self.websocket.send(message)
log.info(">>> %s", message)
except ConnectionClosed as e:
await self._fail(e, clean_close=False)
raise e
async def _receive(self) -> str:
"""Wait the next message from the websocket connection and log the answer"""
# It is possible that the websocket has been already closed in another task
if self.websocket is None:
raise TransportClosed("Transport is already closed")
# Wait for the next websocket frame. Can raise ConnectionClosed
data: Data = await self.websocket.recv()
# websocket.recv() can return either str or bytes
# In our case, we should receive only str here
if not isinstance(data, str):
raise TransportProtocolError("Binary data received in the websocket")
answer: str = data
log.info("<<< %s", answer)
return answer
@abstractmethod
async def _send_query(
self,
document: DocumentNode,
variable_values: Optional[Dict[str, Any]] = None,
operation_name: Optional[str] = None,
) -> int:
raise NotImplementedError # pragma: no cover
@abstractmethod
def _parse_answer(
self, answer: str
) -> Tuple[str, Optional[int], Optional[ExecutionResult]]:
raise NotImplementedError # pragma: no cover
async def _check_ws_liveness(self) -> None:
"""Coroutine which will periodically check the liveness of the connection
through keep-alive messages
"""
try:
while True:
await asyncio.wait_for(
self._next_keep_alive_message.wait(), self.keep_alive_timeout
)
# Reset for the next iteration
self._next_keep_alive_message.clear()
except asyncio.TimeoutError:
# No keep-alive message in the appriopriate interval, close with error
# while trying to notify the server of a proper close (in case
# the keep-alive interval of the client or server was not aligned
# the connection still remains)
# If the timeout happens during a close already in progress, do nothing
if self.close_task is None:
await self._fail(
TransportServerError(
"No keep-alive message has been received within "
"the expected interval ('keep_alive_timeout' parameter)"
),
clean_close=False,
)
except asyncio.CancelledError:
# The client is probably closing, handle it properly
pass
async def _receive_data_loop(self) -> None:
"""Main asyncio task which will listen to the incoming messages and will
call the parse_answer and handle_answer methods of the subclass."""
try:
while True:
# Wait the next answer from the websocket server
try:
answer = await self._receive()
except (ConnectionClosed, TransportProtocolError) as e:
await self._fail(e, clean_close=False)
break
except TransportClosed:
break
# Parse the answer
try:
answer_type, answer_id, execution_result = self._parse_answer(
answer
)
except TransportQueryError as e:
# Received an exception for a specific query
# ==> Add an exception to this query queue
# The exception is raised for this specific query,
# but the transport is not closed.
assert isinstance(
e.query_id, int
), "TransportQueryError should have a query_id defined here"
try:
await self.listeners[e.query_id].set_exception(e)
except KeyError:
# Do nothing if no one is listening to this query_id
pass
continue
except (TransportServerError, TransportProtocolError) as e:
# Received a global exception for this transport
# ==> close the transport
# The exception will be raised for all current queries.
await self._fail(e, clean_close=False)
break
await self._handle_answer(answer_type, answer_id, execution_result)
finally:
log.debug("Exiting _receive_data_loop()")
async def _handle_answer(
self,
answer_type: str,
answer_id: Optional[int],
execution_result: Optional[ExecutionResult],
) -> None:
try:
# Put the answer in the queue
if answer_id is not None:
await self.listeners[answer_id].put((answer_type, execution_result))
except KeyError:
# Do nothing if no one is listening to this query_id.
pass
async def subscribe(
self,
document: DocumentNode,
variable_values: Optional[Dict[str, Any]] = None,
operation_name: Optional[str] = None,
send_stop: Optional[bool] = True,
) -> AsyncGenerator[ExecutionResult, None]:
"""Send a query and receive the results using a python async generator.
The query can be a graphql query, mutation or subscription.
The results are sent as an ExecutionResult object.
"""
# Send the query and receive the id
query_id: int = await self._send_query(
document, variable_values, operation_name
)
# Create a queue to receive the answers for this query_id
listener = ListenerQueue(query_id, send_stop=(send_stop is True))
self.listeners[query_id] = listener
# We will need to wait at close for this query to clean properly
self._no_more_listeners.clear()
try:
# Loop over the received answers
while True:
# Wait for the answer from the queue of this query_id
# This can raise a TransportError or ConnectionClosed exception.
answer_type, execution_result = await listener.get()
# If the received answer contains data,
# Then we will yield the results back as an ExecutionResult object
if execution_result is not None:
yield execution_result
# If we receive a 'complete' answer from the server,
# Then we will end this async generator output without errors
elif answer_type == "complete":
log.debug(
f"Complete received for query {query_id} --> exit without error"
)
break
except (asyncio.CancelledError, GeneratorExit) as e:
log.debug(f"Exception in subscribe: {e!r}")
if listener.send_stop:
await self._stop_listener(query_id)
listener.send_stop = False
finally:
log.debug(f"In subscribe finally for query_id {query_id}")
self._remove_listener(query_id)
async def execute(
self,
document: DocumentNode,
variable_values: Optional[Dict[str, Any]] = None,
operation_name: Optional[str] = None,
) -> ExecutionResult:
"""Execute the provided document AST against the configured remote server
using the current session.
Send a query but close the async generator as soon as we have the first answer.
The result is sent as an ExecutionResult object.
"""
first_result = None
generator = self.subscribe(
document, variable_values, operation_name, send_stop=False
)
async for result in generator:
first_result = result
# Note: we need to run generator.aclose() here or the finally block in
# the subscribe will not be reached in pypy3 (python version 3.6.1)
await generator.aclose()
break
if first_result is None:
raise TransportQueryError(
"Query completed without any answer received from the server"
)
return first_result
async def connect(self) -> None:
"""Coroutine which will:
- connect to the websocket address
- send the init message
- wait for the connection acknowledge from the server
- create an asyncio task which will be used to receive
and parse the websocket answers
Should be cleaned with a call to the close coroutine
"""
log.debug("connect: starting")
if self.websocket is None and not self._connecting:
# Set connecting to True to avoid a race condition if user is trying
# to connect twice using the same client at the same time
self._connecting = True
# If the ssl parameter is not provided,
# generate the ssl value depending on the url
ssl: Optional[Union[SSLContext, bool]]
if self.ssl:
ssl = self.ssl
else:
ssl = True if self.url.startswith("wss") else None
# Set default arguments used in the websockets.connect call
connect_args: Dict[str, Any] = {
"ssl": ssl,
"extra_headers": self.headers,
"subprotocols": self.supported_subprotocols,
}
# Adding custom parameters passed from init
connect_args.update(self.connect_args)
# Connection to the specified url
# Generate a TimeoutError if taking more than connect_timeout seconds
# Set the _connecting flag to False after in all cases
try:
self.websocket = await asyncio.wait_for(
websockets.client.connect(self.url, **connect_args),
self.connect_timeout,
)
finally:
self._connecting = False
self.websocket = cast(WebSocketClientProtocol, self.websocket)
self.response_headers = self.websocket.response_headers
# Run the after_connect hook of the subclass
await self._after_connect()
self.next_query_id = 1
self.close_exception = None
self._wait_closed.clear()
# Send the init message and wait for the ack from the server
# Note: This should generate a TimeoutError
# if no ACKs are received within the ack_timeout
try:
await self._initialize()
except ConnectionClosed as e:
raise e
except (TransportProtocolError, asyncio.TimeoutError) as e:
await self._fail(e, clean_close=False)
raise e
# Run the after_init hook of the subclass
await self._after_initialize()
# If specified, create a task to check liveness of the connection
# through keep-alive messages
if self.keep_alive_timeout is not None:
self.check_keep_alive_task = asyncio.ensure_future(
self._check_ws_liveness()
)
# Create a task to listen to the incoming websocket messages
self.receive_data_task = asyncio.ensure_future(self._receive_data_loop())
else:
raise TransportAlreadyConnected("Transport is already connected")
log.debug("connect: done")
def _remove_listener(self, query_id) -> None:
"""After exiting from a subscription, remove the listener and
signal an event if this was the last listener for the client.
"""
if query_id in self.listeners:
del self.listeners[query_id]
remaining = len(self.listeners)
log.debug(f"listener {query_id} deleted, {remaining} remaining")
if remaining == 0:
self._no_more_listeners.set()
async def _clean_close(self, e: Exception) -> None:
"""Coroutine which will:
- send stop messages for each active subscription to the server
- send the connection terminate message
"""
# Send 'stop' message for all current queries
for query_id, listener in self.listeners.items():
if listener.send_stop:
await self._stop_listener(query_id)
listener.send_stop = False
# Wait that there is no more listeners (we received 'complete' for all queries)
try:
await asyncio.wait_for(self._no_more_listeners.wait(), self.close_timeout)
except asyncio.TimeoutError: # pragma: no cover
log.debug("Timer close_timeout fired")
# Calling the subclass hook
await self._connection_terminate()
async def _close_coro(self, e: Exception, clean_close: bool = True) -> None:
"""Coroutine which will:
- do a clean_close if possible:
- send stop messages for each active query to the server
- send the connection terminate message
- close the websocket connection
- send the exception to all the remaining listeners
"""
log.debug("_close_coro: starting")
try:
# We should always have an active websocket connection here
assert self.websocket is not None
# Properly shut down liveness checker if enabled
if self.check_keep_alive_task is not None:
# More info: https://stackoverflow.com/a/43810272/1113207
self.check_keep_alive_task.cancel()
with suppress(asyncio.CancelledError):
await self.check_keep_alive_task
# Calling the subclass close hook
await self._close_hook()
# Saving exception to raise it later if trying to use the transport
# after it has already closed.
self.close_exception = e
if clean_close:
log.debug("_close_coro: starting clean_close")
try:
await self._clean_close(e)
except Exception as exc: # pragma: no cover
log.warning("Ignoring exception in _clean_close: " + repr(exc))
log.debug("_close_coro: sending exception to listeners")
# Send an exception to all remaining listeners
for query_id, listener in self.listeners.items():
await listener.set_exception(e)
log.debug("_close_coro: close websocket connection")
await self.websocket.close()
log.debug("_close_coro: websocket connection closed")
except Exception as exc: # pragma: no cover
log.warning("Exception catched in _close_coro: " + repr(exc))
finally:
log.debug("_close_coro: start cleanup")
self.websocket = None
self.close_task = None
self.check_keep_alive_task = None
self._wait_closed.set()
log.debug("_close_coro: exiting")
async def _fail(self, e: Exception, clean_close: bool = True) -> None:
log.debug("_fail: starting with exception: " + repr(e))
if self.close_task is None:
if self.websocket is None:
log.debug("_fail started with self.websocket == None -> already closed")
else:
self.close_task = asyncio.shield(
asyncio.ensure_future(self._close_coro(e, clean_close=clean_close))
)
else:
log.debug(
"close_task is not None in _fail. Previous exception is: "
+ repr(self.close_exception)
+ " New exception is: "
+ repr(e)
)
async def close(self) -> None:
log.debug("close: starting")
await self._fail(TransportClosed("Websocket GraphQL transport closed by user"))
await self.wait_closed()
log.debug("close: done")
async def wait_closed(self) -> None:
log.debug("wait_close: starting")
await self._wait_closed.wait()
log.debug("wait_close: done")