2025-12-01

This commit is contained in:
2026-03-17 14:58:51 -06:00
parent 183e865f8b
commit 4b82b57113
6846 changed files with 954887 additions and 162606 deletions
@@ -0,0 +1,20 @@
"""The primary :mod:`gql` package includes everything you need to
execute GraphQL requests, with the exception of the transports
which are optional:
- the :func:`gql <gql.gql>` method to parse a GraphQL query
- the :class:`Client <gql.Client>` class as the entrypoint to execute requests
and create sessions
"""
from .__version__ import __version__
from .client import Client
from .gql import gql
from .graphql_request import GraphQLRequest
__all__ = [
"__version__",
"gql",
"Client",
"GraphQLRequest",
]
@@ -0,0 +1 @@
__version__ = "3.5.3"
@@ -0,0 +1,541 @@
import asyncio
import json
import logging
import signal as signal_module
import sys
import textwrap
from argparse import ArgumentParser, Namespace, RawTextHelpFormatter
from typing import Any, Dict, Optional
from graphql import GraphQLError, print_schema
from yarl import URL
from gql import Client, __version__, gql
from gql.transport import AsyncTransport
from gql.transport.exceptions import TransportQueryError
description = """
Send GraphQL queries from the command line using http(s) or websockets.
If used interactively, write your query, then use Ctrl-D (EOF) to execute it.
"""
examples = """
EXAMPLES
========
# Simple query using https
echo 'query { continent(code:"AF") { name } }' | \
gql-cli https://countries.trevorblades.com
# Simple query using websockets
echo 'query { continent(code:"AF") { name } }' | \
gql-cli wss://countries.trevorblades.com/graphql
# Query with variable
echo 'query getContinent($code:ID!) { continent(code:$code) { name } }' | \
gql-cli https://countries.trevorblades.com --variables code:AF
# Interactive usage (insert your query in the terminal, then press Ctrl-D to execute it)
gql-cli wss://countries.trevorblades.com/graphql --variables code:AF
# Execute query saved in a file
cat query.gql | gql-cli wss://countries.trevorblades.com/graphql
# Print the schema of the backend
gql-cli https://countries.trevorblades.com/graphql --print-schema
"""
def positive_int_or_none(value_str: str) -> Optional[int]:
"""Convert a string argument value into either an int or None.
Raise a ValueError if the argument is negative or a string which is not "none"
"""
try:
value_int = int(value_str)
except ValueError:
if value_str.lower() == "none":
return None
else:
raise
if value_int < 0:
raise ValueError
return value_int
def get_parser(with_examples: bool = False) -> ArgumentParser:
"""Provides an ArgumentParser for the gql-cli script.
This function is also used by sphinx to generate the script documentation.
:param with_examples: set to False by default so that the examples are not
present in the sphinx docs (they are put there with
a different layout)
"""
parser = ArgumentParser(
description=description,
epilog=examples if with_examples else None,
formatter_class=RawTextHelpFormatter,
)
parser.add_argument(
"server", help="the server url starting with http://, https://, ws:// or wss://"
)
parser.add_argument(
"-V",
"--variables",
nargs="*",
help="query variables in the form key:json_value",
)
parser.add_argument(
"-H", "--headers", nargs="*", help="http headers in the form key:value"
)
parser.add_argument("--version", action="version", version=f"v{__version__}")
group = parser.add_mutually_exclusive_group()
group.add_argument(
"-d",
"--debug",
help="print lots of debugging statements (loglevel==DEBUG)",
action="store_const",
dest="loglevel",
const=logging.DEBUG,
)
group.add_argument(
"-v",
"--verbose",
help="show low level messages (loglevel==INFO)",
action="store_const",
dest="loglevel",
const=logging.INFO,
)
parser.add_argument(
"-o",
"--operation-name",
help="set the operation_name value",
dest="operation_name",
)
parser.add_argument(
"--print-schema",
help="get the schema from instrospection and print it",
action="store_true",
dest="print_schema",
)
parser.add_argument(
"--schema-download",
nargs="*",
help=textwrap.dedent(
"""select the introspection query arguments to download the schema.
Only useful if --print-schema is used.
By default, it will:
- request field descriptions
- not request deprecated input fields
Possible options:
- descriptions:false for a compact schema without comments
- input_value_deprecation:true to download deprecated input fields
- specified_by_url:true
- schema_description:true
- directive_is_repeatable:true"""
),
dest="schema_download",
)
parser.add_argument(
"--execute-timeout",
help="set the execute_timeout argument of the Client (default: 10)",
type=positive_int_or_none,
default=10,
dest="execute_timeout",
)
parser.add_argument(
"--transport",
default="auto",
choices=[
"auto",
"aiohttp",
"phoenix",
"websockets",
"appsync_http",
"appsync_websockets",
],
help=(
"select the transport. 'auto' by default: "
"aiohttp or websockets depending on url scheme"
),
dest="transport",
)
appsync_description = """
By default, for an AppSync backend, the IAM authentication is chosen.
If you want API key or JWT authentication, you can provide one of the
following arguments:"""
appsync_group = parser.add_argument_group(
"AWS AppSync options", description=appsync_description
)
appsync_auth_group = appsync_group.add_mutually_exclusive_group()
appsync_auth_group.add_argument(
"--api-key",
help="Provide an API key for authentication",
dest="api_key",
)
appsync_auth_group.add_argument(
"--jwt",
help="Provide an JSON Web token for authentication",
dest="jwt",
)
return parser
def get_transport_args(args: Namespace) -> Dict[str, Any]:
"""Extract extra arguments necessary for the transport
from the parsed command line args
Will create a headers dict by splitting the colon
in the --headers arguments
:param args: parsed command line arguments
"""
transport_args: Dict[str, Any] = {}
# Parse the headers argument
headers = {}
if args.headers is not None:
for header in args.headers:
try:
# Split only the first colon (throw a ValueError if no colon is present)
header_key, header_value = header.split(":", 1)
headers[header_key] = header_value
except ValueError:
raise ValueError(f"Invalid header: {header}")
if args.headers is not None:
transport_args["headers"] = headers
return transport_args
def get_execute_args(args: Namespace) -> Dict[str, Any]:
"""Extract extra arguments necessary for the execute or subscribe
methods from the parsed command line args
Extract the operation_name
Extract the variable_values from the --variables argument
by splitting the first colon, then loads the json value,
We try to add double quotes around the value if it does not work first
in order to simplify the passing of simple string values
(we allow --variables KEY:VALUE instead of KEY:\"VALUE\")
:param args: parsed command line arguments
"""
execute_args: Dict[str, Any] = {}
# Parse the operation_name argument
if args.operation_name is not None:
execute_args["operation_name"] = args.operation_name
# Parse the variables argument
if args.variables is not None:
variables = {}
for var in args.variables:
try:
# Split only the first colon
# (throw a ValueError if no colon is present)
variable_key, variable_json_value = var.split(":", 1)
# Extract the json value,
# trying with double quotes if it does not work
try:
variable_value = json.loads(variable_json_value)
except json.JSONDecodeError:
try:
variable_value = json.loads(f'"{variable_json_value}"')
except json.JSONDecodeError:
raise ValueError
# Save the value in the variables dict
variables[variable_key] = variable_value
except ValueError:
raise ValueError(f"Invalid variable: {var}")
execute_args["variable_values"] = variables
return execute_args
def autodetect_transport(url: URL) -> str:
"""Detects which transport should be used depending on url."""
if url.scheme in ["ws", "wss"]:
transport_name = "websockets"
else:
assert url.scheme in ["http", "https"]
transport_name = "aiohttp"
return transport_name
def get_transport(args: Namespace) -> Optional[AsyncTransport]:
"""Instantiate a transport from the parsed command line arguments
:param args: parsed command line arguments
"""
# Get the url scheme from server parameter
url = URL(args.server)
# Validate scheme
if url.scheme not in ["http", "https", "ws", "wss"]:
raise ValueError("URL protocol should be one of: http, https, ws, wss")
# Get extra transport parameters from command line arguments
# (headers)
transport_args = get_transport_args(args)
# Either use the requested transport or autodetect it
if args.transport == "auto":
transport_name = autodetect_transport(url)
else:
transport_name = args.transport
# Import the correct transport class depending on the transport name
if transport_name == "aiohttp":
from gql.transport.aiohttp import AIOHTTPTransport
return AIOHTTPTransport(url=args.server, **transport_args)
elif transport_name == "phoenix":
from gql.transport.phoenix_channel_websockets import (
PhoenixChannelWebsocketsTransport,
)
return PhoenixChannelWebsocketsTransport(url=args.server, **transport_args)
elif transport_name == "websockets":
from gql.transport.websockets import WebsocketsTransport
transport_args["ssl"] = url.scheme == "wss"
return WebsocketsTransport(url=args.server, **transport_args)
else:
from gql.transport.appsync_auth import AppSyncAuthentication
assert transport_name in ["appsync_http", "appsync_websockets"]
assert url.host is not None
auth: AppSyncAuthentication
if args.api_key:
from gql.transport.appsync_auth import AppSyncApiKeyAuthentication
auth = AppSyncApiKeyAuthentication(host=url.host, api_key=args.api_key)
elif args.jwt:
from gql.transport.appsync_auth import AppSyncJWTAuthentication
auth = AppSyncJWTAuthentication(host=url.host, jwt=args.jwt)
else:
from gql.transport.appsync_auth import AppSyncIAMAuthentication
from botocore.exceptions import NoRegionError
try:
auth = AppSyncIAMAuthentication(host=url.host)
except NoRegionError:
# A warning message has been printed in the console
return None
transport_args["auth"] = auth
if transport_name == "appsync_http":
from gql.transport.aiohttp import AIOHTTPTransport
return AIOHTTPTransport(url=args.server, **transport_args)
else:
from gql.transport.appsync_websockets import AppSyncWebsocketsTransport
try:
return AppSyncWebsocketsTransport(url=args.server, **transport_args)
except Exception:
# This is for the NoCredentialsError but we cannot import it here
return None
def get_introspection_args(args: Namespace) -> Dict:
"""Get the introspection args depending on the schema_download argument"""
# Parse the headers argument
introspection_args = {}
possible_args = [
"descriptions",
"specified_by_url",
"directive_is_repeatable",
"schema_description",
"input_value_deprecation",
]
if args.schema_download is not None:
for arg in args.schema_download:
try:
# Split only the first colon (throw a ValueError if no colon is present)
arg_key, arg_value = arg.split(":", 1)
if arg_key not in possible_args:
raise ValueError(f"Invalid schema_download: {args.schema_download}")
arg_value = arg_value.lower()
if arg_value not in ["true", "false"]:
raise ValueError(f"Invalid schema_download: {args.schema_download}")
introspection_args[arg_key] = arg_value == "true"
except ValueError:
raise ValueError(f"Invalid schema_download: {args.schema_download}")
return introspection_args
async def main(args: Namespace) -> int:
"""Main entrypoint of the gql-cli script
:param args: The parsed command line arguments
:return: The script exit code (0 = ok, 1 = error)
"""
# Set requested log level
if args.loglevel is not None:
logging.basicConfig(level=args.loglevel)
try:
# Instantiate transport from command line arguments
transport = get_transport(args)
if transport is None:
return 1
# Get extra execute parameters from command line arguments
# (variables, operation_name)
execute_args = get_execute_args(args)
except ValueError as e:
print(f"Error: {e}", file=sys.stderr)
return 1
# By default, the exit_code is 0 (everything is ok)
exit_code = 0
# Connect to the backend and provide a session
async with Client(
transport=transport,
fetch_schema_from_transport=args.print_schema,
introspection_args=get_introspection_args(args),
execute_timeout=args.execute_timeout,
) as session:
if args.print_schema:
schema_str = print_schema(session.client.schema)
print(schema_str)
return exit_code
while True:
# Read multiple lines from input and trim whitespaces
# Will read until EOF character is received (Ctrl-D)
query_str = sys.stdin.read().strip()
# Exit if query is empty
if len(query_str) == 0:
break
# Parse query, continue on error
try:
query = gql(query_str)
except GraphQLError as e:
print(e, file=sys.stderr)
exit_code = 1
continue
# Execute or Subscribe the query depending on transport
try:
try:
async for result in session.subscribe(query, **execute_args):
print(json.dumps(result))
except KeyboardInterrupt: # pragma: no cover
pass
except NotImplementedError:
result = await session.execute(query, **execute_args)
print(json.dumps(result))
except (GraphQLError, TransportQueryError) as e:
print(e, file=sys.stderr)
exit_code = 1
return exit_code
def gql_cli() -> None:
"""Synchronously invoke ``main`` with the parsed command line arguments.
Formerly ``scripts/gql-cli``, now registered as an ``entry_point``
"""
# Get arguments from command line
parser = get_parser(with_examples=True)
args = parser.parse_args()
try:
# Create a new asyncio event loop
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# Create a gql-cli task with the supplied arguments
main_task = asyncio.ensure_future(main(args), loop=loop)
# Add signal handlers to close gql-cli cleanly on Control-C
for signal_name in ["SIGINT", "SIGTERM", "CTRL_C_EVENT", "CTRL_BREAK_EVENT"]:
signal = getattr(signal_module, signal_name, None)
if signal is None:
continue
try:
loop.add_signal_handler(signal, main_task.cancel)
except NotImplementedError: # pragma: no cover
# not all signals supported on all platforms
pass
# Run the asyncio loop to execute the task
exit_code = 0
try:
exit_code = loop.run_until_complete(main_task)
finally:
loop.close()
# Return with the correct exit code
sys.exit(exit_code)
except KeyboardInterrupt: # pragma: no cover
pass
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,24 @@
from __future__ import annotations
from graphql import DocumentNode, Source, parse
def gql(request_string: str | Source) -> DocumentNode:
"""Given a string containing a GraphQL request, parse it into a Document.
:param request_string: the GraphQL request as a String
:type request_string: str | Source
:return: a Document which can be later executed or subscribed by a
:class:`Client <gql.client.Client>`, by an
:class:`async session <gql.client.AsyncClientSession>` or by a
:class:`sync session <gql.client.SyncClientSession>`
:raises GraphQLError: if a syntax error is encountered.
"""
if isinstance(request_string, Source):
source = request_string
elif isinstance(request_string, str):
source = Source(request_string, "GraphQL request")
else:
raise TypeError("Request must be passed as a string or Source object.")
return parse(source)
@@ -0,0 +1,37 @@
from dataclasses import dataclass
from typing import Any, Dict, Optional
from graphql import DocumentNode, GraphQLSchema
from .utilities import serialize_variable_values
@dataclass(frozen=True)
class GraphQLRequest:
"""GraphQL Request to be executed."""
document: DocumentNode
"""GraphQL query as AST Node object."""
variable_values: Optional[Dict[str, Any]] = None
"""Dictionary of input parameters (Default: None)."""
operation_name: Optional[str] = None
"""
Name of the operation that shall be executed.
Only required in multi-operation documents (Default: None).
"""
def serialize_variable_values(self, schema: GraphQLSchema) -> "GraphQLRequest":
assert self.variable_values
return GraphQLRequest(
document=self.document,
variable_values=serialize_variable_values(
schema=schema,
document=self.document,
variable_values=self.variable_values,
operation_name=self.operation_name,
),
operation_name=self.operation_name,
)
@@ -0,0 +1 @@
# Marker file for PEP 561. The gql package uses inline types.
@@ -0,0 +1,4 @@
from .async_transport import AsyncTransport
from .transport import Transport
__all__ = ["AsyncTransport", "Transport"]
@@ -0,0 +1,386 @@
import asyncio
import functools
import io
import json
import logging
import warnings
from ssl import SSLContext
from typing import (
Any,
AsyncGenerator,
Callable,
Dict,
Optional,
Tuple,
Type,
Union,
cast,
)
import aiohttp
from aiohttp.client_exceptions import ClientResponseError
from aiohttp.client_reqrep import Fingerprint
from aiohttp.helpers import BasicAuth
from aiohttp.typedefs import LooseCookies, LooseHeaders
from graphql import DocumentNode, ExecutionResult, print_ast
from multidict import CIMultiDictProxy
from ..utils import extract_files
from .appsync_auth import AppSyncAuthentication
from .async_transport import AsyncTransport
from .exceptions import (
TransportAlreadyConnected,
TransportClosed,
TransportProtocolError,
TransportServerError,
)
log = logging.getLogger(__name__)
class AIOHTTPTransport(AsyncTransport):
""":ref:`Async Transport <async_transports>` to execute GraphQL queries
on remote servers with an HTTP connection.
This transport use the aiohttp library with asyncio.
"""
file_classes: Tuple[Type[Any], ...] = (
io.IOBase,
aiohttp.StreamReader,
AsyncGenerator,
)
def __init__(
self,
url: str,
headers: Optional[LooseHeaders] = None,
cookies: Optional[LooseCookies] = None,
auth: Optional[Union[BasicAuth, "AppSyncAuthentication"]] = None,
ssl: Union[SSLContext, bool, Fingerprint, str] = "ssl_warning",
timeout: Optional[int] = None,
ssl_close_timeout: Optional[Union[int, float]] = 10,
json_serialize: Callable = json.dumps,
client_session_args: Optional[Dict[str, Any]] = None,
) -> None:
"""Initialize the transport with the given aiohttp parameters.
:param url: The GraphQL server URL. Example: 'https://server.com:PORT/path'.
:param headers: Dict of HTTP Headers.
:param cookies: Dict of HTTP cookies.
:param auth: BasicAuth object to enable Basic HTTP auth if needed
Or Appsync Authentication class
:param ssl: ssl_context of the connection. Use ssl=False to disable encryption
:param ssl_close_timeout: Timeout in seconds to wait for the ssl connection
to close properly
:param json_serialize: Json serializer callable.
By default json.dumps() function
:param client_session_args: Dict of extra args passed to
`aiohttp.ClientSession`_
.. _aiohttp.ClientSession:
https://docs.aiohttp.org/en/stable/client_reference.html#aiohttp.ClientSession
"""
self.url: str = url
self.headers: Optional[LooseHeaders] = headers
self.cookies: Optional[LooseCookies] = cookies
self.auth: Optional[Union[BasicAuth, "AppSyncAuthentication"]] = auth
if ssl == "ssl_warning":
ssl = False
if str(url).startswith("https"):
warnings.warn(
"WARNING: By default, AIOHTTPTransport does not verify"
" ssl certificates. This will be fixed in the next major version."
" You can set ssl=True to force the ssl certificate verification"
" or ssl=False to disable this warning"
)
self.ssl: Union[SSLContext, bool, Fingerprint] = cast(
Union[SSLContext, bool, Fingerprint], ssl
)
self.timeout: Optional[int] = timeout
self.ssl_close_timeout: Optional[Union[int, float]] = ssl_close_timeout
self.client_session_args = client_session_args
self.session: Optional[aiohttp.ClientSession] = None
self.response_headers: Optional[CIMultiDictProxy[str]]
self.json_serialize: Callable = json_serialize
async def connect(self) -> None:
"""Coroutine which will create an aiohttp ClientSession() as self.session.
Don't call this coroutine directly on the transport, instead use
:code:`async with` on the client and this coroutine will be executed
to create the session.
Should be cleaned with a call to the close coroutine.
"""
if self.session is None:
client_session_args: Dict[str, Any] = {
"cookies": self.cookies,
"headers": self.headers,
"auth": None
if isinstance(self.auth, AppSyncAuthentication)
else self.auth,
"json_serialize": self.json_serialize,
}
if self.timeout is not None:
client_session_args["timeout"] = aiohttp.ClientTimeout(
total=self.timeout
)
# Adding custom parameters passed from init
if self.client_session_args:
client_session_args.update(self.client_session_args) # type: ignore
log.debug("Connecting transport")
self.session = aiohttp.ClientSession(**client_session_args)
else:
raise TransportAlreadyConnected("Transport is already connected")
@staticmethod
def create_aiohttp_closed_event(session) -> asyncio.Event:
"""Work around aiohttp issue that doesn't properly close transports on exit.
See https://github.com/aio-libs/aiohttp/issues/1925#issuecomment-639080209
Returns:
An event that will be set once all transports have been properly closed.
"""
ssl_transports = 0
all_is_lost = asyncio.Event()
def connection_lost(exc, orig_lost):
nonlocal ssl_transports
try:
orig_lost(exc)
finally:
ssl_transports -= 1
if ssl_transports == 0:
all_is_lost.set()
def eof_received(orig_eof_received):
try:
orig_eof_received()
except AttributeError: # pragma: no cover
# It may happen that eof_received() is called after
# _app_protocol and _transport are set to None.
pass
for conn in session.connector._conns.values():
for handler, _ in conn:
proto = getattr(handler.transport, "_ssl_protocol", None)
if proto is None:
continue
ssl_transports += 1
orig_lost = proto.connection_lost
orig_eof_received = proto.eof_received
proto.connection_lost = functools.partial(
connection_lost, orig_lost=orig_lost
)
proto.eof_received = functools.partial(
eof_received, orig_eof_received=orig_eof_received
)
if ssl_transports == 0:
all_is_lost.set()
return all_is_lost
async def close(self) -> None:
"""Coroutine which will close the aiohttp session.
Don't call this coroutine directly on the transport, instead use
:code:`async with` on the client and this coroutine will be executed
when you exit the async context manager.
"""
if self.session is not None:
log.debug("Closing transport")
if (
self.client_session_args
and self.client_session_args.get("connector_owner") is False
):
log.debug("connector_owner is False -> not closing connector")
else:
closed_event = self.create_aiohttp_closed_event(self.session)
await self.session.close()
try:
await asyncio.wait_for(closed_event.wait(), self.ssl_close_timeout)
except asyncio.TimeoutError:
pass
self.session = None
async def execute(
self,
document: DocumentNode,
variable_values: Optional[Dict[str, Any]] = None,
operation_name: Optional[str] = None,
extra_args: Optional[Dict[str, Any]] = None,
upload_files: bool = False,
) -> ExecutionResult:
"""Execute the provided document AST against the configured remote server
using the current session.
This uses the aiohttp library to perform a HTTP POST request asynchronously
to the remote server.
Don't call this coroutine directly on the transport, instead use
:code:`execute` on a client or a session.
:param document: the parsed GraphQL request
:param variable_values: An optional Dict of variable values
:param operation_name: An optional Operation name for the request
:param extra_args: additional arguments to send to the aiohttp post method
:param upload_files: Set to True if you want to put files in the variable values
:returns: an ExecutionResult object.
"""
query_str = print_ast(document)
payload: Dict[str, Any] = {
"query": query_str,
}
if operation_name:
payload["operationName"] = operation_name
if upload_files:
# If the upload_files flag is set, then we need variable_values
assert variable_values is not None
# If we upload files, we will extract the files present in the
# variable_values dict and replace them by null values
nulled_variable_values, files = extract_files(
variables=variable_values,
file_classes=self.file_classes,
)
# Save the nulled variable values in the payload
payload["variables"] = nulled_variable_values
# Prepare aiohttp to send multipart-encoded data
data = aiohttp.FormData()
# Generate the file map
# path is nested in a list because the spec allows multiple pointers
# to the same file. But we don't support that.
# Will generate something like {"0": ["variables.file"]}
file_map = {str(i): [path] for i, path in enumerate(files)}
# Enumerate the file streams
# Will generate something like {'0': <_io.BufferedReader ...>}
file_streams = {str(i): files[path] for i, path in enumerate(files)}
# Add the payload to the operations field
operations_str = self.json_serialize(payload)
log.debug("operations %s", operations_str)
data.add_field(
"operations", operations_str, content_type="application/json"
)
# Add the file map field
file_map_str = self.json_serialize(file_map)
log.debug("file_map %s", file_map_str)
data.add_field("map", file_map_str, content_type="application/json")
# Add the extracted files as remaining fields
for k, f in file_streams.items():
name = getattr(f, "name", k)
content_type = getattr(f, "content_type", None)
data.add_field(k, f, filename=name, content_type=content_type)
post_args: Dict[str, Any] = {"data": data}
else:
if variable_values:
payload["variables"] = variable_values
if log.isEnabledFor(logging.INFO):
log.info(">>> %s", self.json_serialize(payload))
post_args = {"json": payload}
# Pass post_args to aiohttp post method
if extra_args:
post_args.update(extra_args)
# Add headers for AppSync if requested
if isinstance(self.auth, AppSyncAuthentication):
post_args["headers"] = self.auth.get_headers(
self.json_serialize(payload),
{"content-type": "application/json"},
)
if self.session is None:
raise TransportClosed("Transport is not connected")
async with self.session.post(self.url, ssl=self.ssl, **post_args) as resp:
# Saving latest response headers in the transport
self.response_headers = resp.headers
async def raise_response_error(resp: aiohttp.ClientResponse, reason: str):
# We raise a TransportServerError if the status code is 400 or higher
# We raise a TransportProtocolError in the other cases
try:
# Raise a ClientResponseError if response status is 400 or higher
resp.raise_for_status()
except ClientResponseError as e:
raise TransportServerError(str(e), e.status) from e
result_text = await resp.text()
raise TransportProtocolError(
f"Server did not return a GraphQL result: "
f"{reason}: "
f"{result_text}"
)
try:
result = await resp.json(content_type=None)
if log.isEnabledFor(logging.INFO):
result_text = await resp.text()
log.info("<<< %s", result_text)
except Exception:
await raise_response_error(resp, "Not a JSON answer")
if result is None:
await raise_response_error(resp, "Not a JSON answer")
if "errors" not in result and "data" not in result:
await raise_response_error(resp, 'No "data" or "errors" keys in answer')
return ExecutionResult(
errors=result.get("errors"),
data=result.get("data"),
extensions=result.get("extensions"),
)
def subscribe(
self,
document: DocumentNode,
variable_values: Optional[Dict[str, Any]] = None,
operation_name: Optional[str] = None,
) -> AsyncGenerator[ExecutionResult, None]:
"""Subscribe is not supported on HTTP.
:meta private:
"""
raise NotImplementedError(" The HTTP transport does not support subscriptions")
@@ -0,0 +1,221 @@
import json
import logging
import re
from abc import ABC, abstractmethod
from base64 import b64encode
from typing import Any, Callable, Dict, Optional
try:
import botocore
except ImportError: # pragma: no cover
# botocore is only needed for the IAM AppSync authentication method
pass
log = logging.getLogger("gql.transport.appsync")
class AppSyncAuthentication(ABC):
"""AWS authentication abstract base class
All AWS authentication class should have a
:meth:`get_headers <gql.transport.appsync.AppSyncAuthentication.get_headers>`
method which defines the headers used in the authentication process."""
def get_auth_url(self, url: str) -> str:
"""
:return: a url with base64 encoded headers used to establish
a websocket connection to the appsync-realtime-api.
"""
headers = self.get_headers()
encoded_headers = b64encode(
json.dumps(headers, separators=(",", ":")).encode()
).decode()
url_base = url.replace("https://", "wss://").replace(
"appsync-api", "appsync-realtime-api"
)
return f"{url_base}?header={encoded_headers}&payload=e30="
@abstractmethod
def get_headers(
self, data: Optional[str] = None, headers: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
raise NotImplementedError() # pragma: no cover
class AppSyncApiKeyAuthentication(AppSyncAuthentication):
"""AWS authentication class using an API key"""
def __init__(self, host: str, api_key: str) -> None:
"""
:param host: the host, something like:
XXXXXXXXXXXXXXXXXXXXXXXXXX.appsync-api.REGION.amazonaws.com
:param api_key: the API key
"""
self._host = host.replace("appsync-realtime-api", "appsync-api")
self.api_key = api_key
def get_headers(
self, data: Optional[str] = None, headers: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
return {"host": self._host, "x-api-key": self.api_key}
class AppSyncJWTAuthentication(AppSyncAuthentication):
"""AWS authentication class using a JWT access token.
It can be used either for:
- Amazon Cognito user pools
- OpenID Connect (OIDC)
"""
def __init__(self, host: str, jwt: str) -> None:
"""
:param host: the host, something like:
XXXXXXXXXXXXXXXXXXXXXXXXXX.appsync-api.REGION.amazonaws.com
:param jwt: the JWT Access Token
"""
self._host = host.replace("appsync-realtime-api", "appsync-api")
self.jwt = jwt
def get_headers(
self, data: Optional[str] = None, headers: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
return {"host": self._host, "Authorization": self.jwt}
class AppSyncIAMAuthentication(AppSyncAuthentication):
"""AWS authentication class using IAM.
.. note::
There is no need for you to use this class directly, you could instead
intantiate the :class:`gql.transport.appsync.AppSyncWebsocketsTransport`
without an auth argument.
During initialization, this class will use botocore to attempt to
find your IAM credentials, either from environment variables or
from your AWS credentials file.
"""
def __init__(
self,
host: str,
region_name: Optional[str] = None,
signer: Optional["botocore.auth.BaseSigner"] = None,
request_creator: Optional[
Callable[[Dict[str, Any]], "botocore.awsrequest.AWSRequest"]
] = None,
credentials: Optional["botocore.credentials.Credentials"] = None,
session: Optional["botocore.session.Session"] = None,
) -> None:
"""Initialize itself, saving the found credentials used
to sign the headers later.
if no credentials are found, then a NoCredentialsError is raised.
"""
from botocore.auth import SigV4Auth
from botocore.awsrequest import create_request_object
from botocore.session import get_session
self._host = host.replace("appsync-realtime-api", "appsync-api")
self._session = session if session else get_session()
self._credentials = (
credentials if credentials else self._session.get_credentials()
)
self._service_name = "appsync"
self._region_name = region_name or self._detect_region_name()
self._signer = (
signer
if signer
else SigV4Auth(self._credentials, self._service_name, self._region_name)
)
self._request_creator = (
request_creator if request_creator else create_request_object
)
def _detect_region_name(self):
"""Try to detect the correct region_name.
First try to extract the region_name from the host.
If that does not work, then try to get the region_name from
the aws configuration (~/.aws/config file) or the AWS_DEFAULT_REGION
environment variable.
If no region_name was found, then raise a NoRegionError exception."""
from botocore.exceptions import NoRegionError
# Regular expression from botocore.utils.validate_region
m = re.search(
r"appsync-api\.((?![0-9]+$)(?!-)[a-zA-Z0-9-]{,63}(?<!-))\.", self._host
)
if m:
region_name = m.groups()[0]
log.debug(f"Region name extracted from host: {region_name}")
else:
log.debug("Region name not found in host, trying default region name")
region_name = self._session._resolve_region_name(
None, self._session.get_default_client_config()
)
if region_name is None:
log.warning(
"Region name not found. "
"It was not possible to detect your region either from the host "
"or from your default AWS configuration."
)
raise NoRegionError
return region_name
def get_headers(
self, data: Optional[str] = None, headers: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
from botocore.exceptions import NoCredentialsError
# Default headers for a websocket connection
headers = headers or {
"accept": "application/json, text/javascript",
"content-encoding": "amz-1.0",
"content-type": "application/json; charset=UTF-8",
}
request: "botocore.awsrequest.AWSRequest" = self._request_creator(
{
"method": "POST",
"url": f"https://{self._host}/graphql{'' if data else '/connect'}",
"headers": headers,
"context": {},
"body": data or "{}",
}
)
try:
self._signer.add_auth(request)
except NoCredentialsError:
log.warning(
"Credentials not found for the IAM auth. "
"Do you have default AWS credentials configured?",
)
raise
headers = dict(request.headers)
headers["host"] = self._host
if log.isEnabledFor(logging.DEBUG):
headers_log = []
headers_log.append("\n\nSigned headers:")
for key, value in headers.items():
headers_log.append(f" {key}: {value}")
headers_log.append("\n")
log.debug("\n".join(headers_log))
return headers
@@ -0,0 +1,214 @@
import json
import logging
from ssl import SSLContext
from typing import Any, Dict, Optional, Tuple, Union, cast
from urllib.parse import urlparse
from graphql import DocumentNode, ExecutionResult, print_ast
from .appsync_auth import AppSyncAuthentication, AppSyncIAMAuthentication
from .exceptions import TransportProtocolError, TransportServerError
from .websockets import WebsocketsTransport, WebsocketsTransportBase
log = logging.getLogger("gql.transport.appsync")
try:
import botocore
except ImportError: # pragma: no cover
# botocore is only needed for the IAM AppSync authentication method
pass
class AppSyncWebsocketsTransport(WebsocketsTransportBase):
""":ref:`Async Transport <async_transports>` used to execute GraphQL subscription on
AWS appsync realtime endpoint.
This transport uses asyncio and the websockets library in order to send requests
on a websocket connection.
"""
auth: Optional[AppSyncAuthentication]
def __init__(
self,
url: str,
auth: Optional[AppSyncAuthentication] = None,
session: Optional["botocore.session.Session"] = None,
ssl: Union[SSLContext, bool] = False,
connect_timeout: int = 10,
close_timeout: int = 10,
ack_timeout: int = 10,
keep_alive_timeout: Optional[Union[int, float]] = None,
connect_args: Dict[str, Any] = {},
) -> None:
"""Initialize the transport with the given parameters.
:param url: The GraphQL endpoint URL. Example:
https://XXXXXXXXXXXXXXXXXXXXXXXXXX.appsync-api.REGION.amazonaws.com/graphql
:param auth: Optional AWS authentication class which will provide the
necessary headers to be correctly authenticated. If this
argument is not provided, then we will try to authenticate
using IAM.
:param ssl: ssl_context of the connection.
:param connect_timeout: Timeout in seconds for the establishment
of the websocket connection. If None is provided this will wait forever.
:param close_timeout: Timeout in seconds for the close. If None is provided
this will wait forever.
:param ack_timeout: Timeout in seconds to wait for the connection_ack message
from the server. If None is provided this will wait forever.
:param keep_alive_timeout: Optional Timeout in seconds to receive
a sign of liveness from the server.
:param connect_args: Other parameters forwarded to websockets.connect
"""
if not auth:
# Extract host from url
host = str(urlparse(url).netloc)
# May raise NoRegionError or NoCredentialsError or ImportError
auth = AppSyncIAMAuthentication(host=host, session=session)
self.auth = auth
url = self.auth.get_auth_url(url)
super().__init__(
url,
ssl=ssl,
connect_timeout=connect_timeout,
close_timeout=close_timeout,
ack_timeout=ack_timeout,
keep_alive_timeout=keep_alive_timeout,
connect_args=connect_args,
)
# Using the same 'graphql-ws' protocol as the apollo protocol
self.supported_subprotocols = [
WebsocketsTransport.APOLLO_SUBPROTOCOL,
]
self.subprotocol = WebsocketsTransport.APOLLO_SUBPROTOCOL
def _parse_answer(
self, answer: str
) -> Tuple[str, Optional[int], Optional[ExecutionResult]]:
"""Parse the answer received from the server.
Difference between apollo protocol and aws protocol:
- aws protocol can return an error without an id
- aws protocol will send start_ack messages
Returns a list consisting of:
- the answer_type:
- 'connection_ack',
- 'connection_error',
- 'start_ack',
- 'ka',
- 'data',
- 'error',
- 'complete'
- the answer id (Integer) if received or None
- an execution Result if the answer_type is 'data' or None
"""
answer_type: str = ""
try:
json_answer = json.loads(answer)
answer_type = str(json_answer.get("type"))
if answer_type == "start_ack":
return ("start_ack", None, None)
elif answer_type == "error" and "id" not in json_answer:
error_payload = json_answer.get("payload")
raise TransportServerError(f"Server error: '{error_payload!r}'")
else:
return WebsocketsTransport._parse_answer_apollo(
cast(WebsocketsTransport, self), json_answer
)
except ValueError:
raise TransportProtocolError(
f"Server did not return a GraphQL result: {answer}"
)
async def _send_query(
self,
document: DocumentNode,
variable_values: Optional[Dict[str, Any]] = None,
operation_name: Optional[str] = None,
) -> int:
query_id = self.next_query_id
self.next_query_id += 1
data: Dict = {"query": print_ast(document)}
if variable_values:
data["variables"] = variable_values
if operation_name:
data["operationName"] = operation_name
serialized_data = json.dumps(data, separators=(",", ":"))
payload = {"data": serialized_data}
message: Dict = {
"id": str(query_id),
"type": "start",
"payload": payload,
}
assert self.auth is not None
message["payload"]["extensions"] = {
"authorization": self.auth.get_headers(serialized_data)
}
await self._send(
json.dumps(
message,
separators=(",", ":"),
)
)
return query_id
subscribe = WebsocketsTransportBase.subscribe
"""Send a subscription query and receive the results using
a python async generator.
Only subscriptions are supported, queries and mutations are forbidden.
The results are sent as an ExecutionResult object.
"""
async def execute(
self,
document: DocumentNode,
variable_values: Optional[Dict[str, Any]] = None,
operation_name: Optional[str] = None,
) -> ExecutionResult:
"""This method is not available.
Only subscriptions are supported on the AWS realtime endpoint.
:raise: AssertionError"""
raise AssertionError(
"execute method is not allowed for AppSyncWebsocketsTransport "
"because only subscriptions are allowed on the realtime endpoint."
)
_initialize = WebsocketsTransport._initialize
_stop_listener = WebsocketsTransport._send_stop_message # type: ignore
_send_init_message_and_wait_ack = (
WebsocketsTransport._send_init_message_and_wait_ack
)
_wait_ack = WebsocketsTransport._wait_ack
@@ -0,0 +1,50 @@
import abc
from typing import Any, AsyncGenerator, Dict, Optional
from graphql import DocumentNode, ExecutionResult
class AsyncTransport(abc.ABC):
@abc.abstractmethod
async def connect(self):
"""Coroutine used to create a connection to the specified address"""
raise NotImplementedError(
"Any AsyncTransport subclass must implement connect method"
) # pragma: no cover
@abc.abstractmethod
async def close(self):
"""Coroutine used to Close an established connection"""
raise NotImplementedError(
"Any AsyncTransport subclass must implement close method"
) # pragma: no cover
@abc.abstractmethod
async def execute(
self,
document: DocumentNode,
variable_values: Optional[Dict[str, Any]] = None,
operation_name: Optional[str] = None,
) -> ExecutionResult:
"""Execute the provided document AST for either a remote or local GraphQL
Schema."""
raise NotImplementedError(
"Any AsyncTransport subclass must implement execute method"
) # pragma: no cover
@abc.abstractmethod
def subscribe(
self,
document: DocumentNode,
variable_values: Optional[Dict[str, Any]] = None,
operation_name: Optional[str] = None,
) -> AsyncGenerator[ExecutionResult, None]:
"""Send a query and receive the results using an async generator
The query can be a graphql query, mutation or subscription
The results are sent as an ExecutionResult object
"""
raise NotImplementedError(
"Any AsyncTransport subclass must implement subscribe method"
) # pragma: no cover
@@ -0,0 +1,69 @@
from typing import Any, List, Optional
class TransportError(Exception):
"""Base class for all the Transport exceptions"""
pass
class TransportProtocolError(TransportError):
"""Transport protocol error.
The answer received from the server does not correspond to the transport protocol.
"""
class TransportServerError(TransportError):
"""The server returned a global error.
This exception will close the transport connection.
"""
code: Optional[int]
def __init__(self, message: str, code: Optional[int] = None):
super(TransportServerError, self).__init__(message)
self.code = code
class TransportQueryError(TransportError):
"""The server returned an error for a specific query.
This exception should not close the transport connection.
"""
query_id: Optional[int]
errors: Optional[List[Any]]
data: Optional[Any]
extensions: Optional[Any]
def __init__(
self,
msg: str,
query_id: Optional[int] = None,
errors: Optional[List[Any]] = None,
data: Optional[Any] = None,
extensions: Optional[Any] = None,
):
super().__init__(msg)
self.query_id = query_id
self.errors = errors
self.data = data
self.extensions = extensions
class TransportClosed(TransportError):
"""Transport is already closed.
This exception is generated when the client is trying to use the transport
while the transport was previously closed.
"""
class TransportAlreadyConnected(TransportError):
"""Transport is already connected.
Exception generated when the client is trying to connect to the transport
while the transport is already connected.
"""
@@ -0,0 +1,311 @@
import io
import json
import logging
from typing import (
Any,
AsyncGenerator,
Callable,
Dict,
List,
Optional,
Tuple,
Type,
Union,
cast,
)
import httpx
from graphql import DocumentNode, ExecutionResult, print_ast
from ..utils import extract_files
from . import AsyncTransport, Transport
from .exceptions import (
TransportAlreadyConnected,
TransportClosed,
TransportProtocolError,
TransportServerError,
)
log = logging.getLogger(__name__)
class _HTTPXTransport:
file_classes: Tuple[Type[Any], ...] = (io.IOBase,)
response_headers: Optional[httpx.Headers] = None
def __init__(
self,
url: Union[str, httpx.URL],
json_serialize: Callable = json.dumps,
**kwargs,
):
"""Initialize the transport with the given httpx parameters.
:param url: The GraphQL server URL. Example: 'https://server.com:PORT/path'.
:param json_serialize: Json serializer callable.
By default json.dumps() function.
:param kwargs: Extra args passed to the `httpx` client.
"""
self.url = url
self.json_serialize = json_serialize
self.kwargs = kwargs
def _prepare_request(
self,
document: DocumentNode,
variable_values: Optional[Dict[str, Any]] = None,
operation_name: Optional[str] = None,
extra_args: Optional[Dict[str, Any]] = None,
upload_files: bool = False,
) -> Dict[str, Any]:
query_str = print_ast(document)
payload: Dict[str, Any] = {
"query": query_str,
}
if operation_name:
payload["operationName"] = operation_name
if upload_files:
# If the upload_files flag is set, then we need variable_values
assert variable_values is not None
post_args = self._prepare_file_uploads(variable_values, payload)
else:
if variable_values:
payload["variables"] = variable_values
post_args = {"json": payload}
# Log the payload
if log.isEnabledFor(logging.DEBUG):
log.debug(">>> %s", self.json_serialize(payload))
# Pass post_args to httpx post method
if extra_args:
post_args.update(extra_args)
return post_args
def _prepare_file_uploads(self, variable_values, payload) -> Dict[str, Any]:
# If we upload files, we will extract the files present in the
# variable_values dict and replace them by null values
nulled_variable_values, files = extract_files(
variables=variable_values,
file_classes=self.file_classes,
)
# Save the nulled variable values in the payload
payload["variables"] = nulled_variable_values
# Prepare to send multipart-encoded data
data: Dict[str, Any] = {}
file_map: Dict[str, List[str]] = {}
file_streams: Dict[str, Tuple[str, ...]] = {}
for i, (path, f) in enumerate(files.items()):
key = str(i)
# Generate the file map
# path is nested in a list because the spec allows multiple pointers
# to the same file. But we don't support that.
# Will generate something like {"0": ["variables.file"]}
file_map[key] = [path]
# Generate the file streams
# Will generate something like
# {"0": ("variables.file", <_io.BufferedReader ...>)}
name = cast(str, getattr(f, "name", key))
content_type = getattr(f, "content_type", None)
if content_type is None:
file_streams[key] = (name, f)
else:
file_streams[key] = (name, f, content_type)
# Add the payload to the operations field
operations_str = self.json_serialize(payload)
log.debug("operations %s", operations_str)
data["operations"] = operations_str
# Add the file map field
file_map_str = self.json_serialize(file_map)
log.debug("file_map %s", file_map_str)
data["map"] = file_map_str
return {"data": data, "files": file_streams}
def _prepare_result(self, response: httpx.Response) -> ExecutionResult:
# Save latest response headers in transport
self.response_headers = response.headers
if log.isEnabledFor(logging.DEBUG):
log.debug("<<< %s", response.text)
try:
result: Dict[str, Any] = response.json()
except Exception:
self._raise_response_error(response, "Not a JSON answer")
if "errors" not in result and "data" not in result:
self._raise_response_error(response, 'No "data" or "errors" keys in answer')
return ExecutionResult(
errors=result.get("errors"),
data=result.get("data"),
extensions=result.get("extensions"),
)
def _raise_response_error(self, response: httpx.Response, reason: str):
# We raise a TransportServerError if the status code is 400 or higher
# We raise a TransportProtocolError in the other cases
try:
# Raise a HTTPError if response status is 400 or higher
response.raise_for_status()
except httpx.HTTPStatusError as e:
raise TransportServerError(str(e), e.response.status_code) from e
raise TransportProtocolError(
f"Server did not return a GraphQL result: " f"{reason}: " f"{response.text}"
)
class HTTPXTransport(Transport, _HTTPXTransport):
""":ref:`Sync Transport <sync_transports>` used to execute GraphQL queries
on remote servers.
The transport uses the httpx library to send HTTP POST requests.
"""
client: Optional[httpx.Client] = None
def connect(self):
if self.client:
raise TransportAlreadyConnected("Transport is already connected")
log.debug("Connecting transport")
self.client = httpx.Client(**self.kwargs)
def execute( # type: ignore
self,
document: DocumentNode,
variable_values: Optional[Dict[str, Any]] = None,
operation_name: Optional[str] = None,
extra_args: Optional[Dict[str, Any]] = None,
upload_files: bool = False,
) -> ExecutionResult:
"""Execute GraphQL query.
Execute the provided document AST against the configured remote server. This
uses the httpx library to perform a HTTP POST request to the remote server.
:param document: GraphQL query as AST Node object.
:param variable_values: Dictionary of input parameters (Default: None).
:param operation_name: Name of the operation that shall be executed.
Only required in multi-operation documents (Default: None).
:param extra_args: additional arguments to send to the httpx post method
:param upload_files: Set to True if you want to put files in the variable values
:return: The result of execution.
`data` is the result of executing the query, `errors` is null
if no errors occurred, and is a non-empty array if an error occurred.
"""
if not self.client:
raise TransportClosed("Transport is not connected")
post_args = self._prepare_request(
document,
variable_values,
operation_name,
extra_args,
upload_files,
)
response = self.client.post(self.url, **post_args)
return self._prepare_result(response)
def close(self):
"""Closing the transport by closing the inner session"""
if self.client:
self.client.close()
self.client = None
class HTTPXAsyncTransport(AsyncTransport, _HTTPXTransport):
""":ref:`Async Transport <async_transports>` used to execute GraphQL queries
on remote servers.
The transport uses the httpx library with anyio.
"""
client: Optional[httpx.AsyncClient] = None
async def connect(self):
if self.client:
raise TransportAlreadyConnected("Transport is already connected")
log.debug("Connecting transport")
self.client = httpx.AsyncClient(**self.kwargs)
async def execute(
self,
document: DocumentNode,
variable_values: Optional[Dict[str, Any]] = None,
operation_name: Optional[str] = None,
extra_args: Optional[Dict[str, Any]] = None,
upload_files: bool = False,
) -> ExecutionResult:
"""Execute GraphQL query.
Execute the provided document AST against the configured remote server. This
uses the httpx library to perform a HTTP POST request asynchronously to the
remote server.
:param document: GraphQL query as AST Node object.
:param variable_values: Dictionary of input parameters (Default: None).
:param operation_name: Name of the operation that shall be executed.
Only required in multi-operation documents (Default: None).
:param extra_args: additional arguments to send to the httpx post method
:param upload_files: Set to True if you want to put files in the variable values
:return: The result of execution.
`data` is the result of executing the query, `errors` is null
if no errors occurred, and is a non-empty array if an error occurred.
"""
if not self.client:
raise TransportClosed("Transport is not connected")
post_args = self._prepare_request(
document,
variable_values,
operation_name,
extra_args,
upload_files,
)
response = await self.client.post(self.url, **post_args)
return self._prepare_result(response)
async def close(self):
"""Closing the transport by closing the inner session"""
if self.client:
await self.client.aclose()
self.client = None
def subscribe(
self,
document: DocumentNode,
variable_values: Optional[Dict[str, Any]] = None,
operation_name: Optional[str] = None,
) -> AsyncGenerator[ExecutionResult, None]:
"""Subscribe is not supported on HTTP.
:meta private:
"""
raise NotImplementedError("The HTTP transport does not support subscriptions")
@@ -0,0 +1,78 @@
import asyncio
from inspect import isawaitable
from typing import AsyncGenerator, Awaitable, cast
from graphql import DocumentNode, ExecutionResult, GraphQLSchema, execute, subscribe
from gql.transport import AsyncTransport
class LocalSchemaTransport(AsyncTransport):
"""A transport for executing GraphQL queries against a local schema."""
def __init__(
self,
schema: GraphQLSchema,
):
"""Initialize the transport with the given local schema.
:param schema: Local schema as GraphQLSchema object
"""
self.schema = schema
async def connect(self):
"""No connection needed on local transport"""
pass
async def close(self):
"""No close needed on local transport"""
pass
async def execute(
self,
document: DocumentNode,
*args,
**kwargs,
) -> ExecutionResult:
"""Execute the provided document AST for on a local GraphQL Schema."""
result_or_awaitable = execute(self.schema, document, *args, **kwargs)
execution_result: ExecutionResult
if isawaitable(result_or_awaitable):
result_or_awaitable = cast(Awaitable[ExecutionResult], result_or_awaitable)
execution_result = await result_or_awaitable
else:
result_or_awaitable = cast(ExecutionResult, result_or_awaitable)
execution_result = result_or_awaitable
return execution_result
@staticmethod
async def _await_if_necessary(obj):
"""This method is necessary to work with
graphql-core versions < and >= 3.3.0a3"""
return await obj if asyncio.iscoroutine(obj) else obj
async def subscribe(
self,
document: DocumentNode,
*args,
**kwargs,
) -> AsyncGenerator[ExecutionResult, None]:
"""Send a subscription and receive the results using an async generator
The results are sent as an ExecutionResult object
"""
subscribe_result = await self._await_if_necessary(
subscribe(self.schema, document, *args, **kwargs)
)
if isinstance(subscribe_result, ExecutionResult):
yield subscribe_result
else:
async for result in subscribe_result:
yield result
@@ -0,0 +1,414 @@
import asyncio
import json
import logging
from typing import Any, Dict, Optional, Tuple
from graphql import DocumentNode, ExecutionResult, print_ast
from websockets.exceptions import ConnectionClosed
from .exceptions import (
TransportProtocolError,
TransportQueryError,
TransportServerError,
)
from .websockets_base import WebsocketsTransportBase
log = logging.getLogger(__name__)
class Subscription:
"""Records listener_id and unsubscribe query_id for a subscription."""
def __init__(self, query_id: int) -> None:
self.listener_id: int = query_id
self.unsubscribe_id: Optional[int] = None
class PhoenixChannelWebsocketsTransport(WebsocketsTransportBase):
"""The PhoenixChannelWebsocketsTransport is an async transport
which allows you to execute queries and subscriptions against an `Absinthe`_
backend using the `Phoenix`_ framework `channels`_.
.. _Absinthe: http://absinthe-graphql.org
.. _Phoenix: https://www.phoenixframework.org
.. _channels: https://hexdocs.pm/phoenix/Phoenix.Channel.html#content
"""
def __init__(
self,
channel_name: str = "__absinthe__:control",
heartbeat_interval: float = 30,
*args,
**kwargs,
) -> None:
"""Initialize the transport with the given parameters.
:param channel_name: Channel on the server this transport will join.
The default for Absinthe servers is "__absinthe__:control"
:param heartbeat_interval: Interval in second between each heartbeat messages
sent by the client
"""
self.channel_name: str = channel_name
self.heartbeat_interval: float = heartbeat_interval
self.heartbeat_task: Optional[asyncio.Future] = None
self.subscriptions: Dict[str, Subscription] = {}
super(PhoenixChannelWebsocketsTransport, self).__init__(*args, **kwargs)
async def _initialize(self) -> None:
"""Join the specified channel and wait for the connection ACK.
If the answer is not a connection_ack message, we will return an Exception.
"""
query_id = self.next_query_id
self.next_query_id += 1
init_message = json.dumps(
{
"topic": self.channel_name,
"event": "phx_join",
"payload": {},
"ref": query_id,
}
)
await self._send(init_message)
# Wait for the connection_ack message or raise a TimeoutError
init_answer = await asyncio.wait_for(self._receive(), self.ack_timeout)
answer_type, answer_id, execution_result = self._parse_answer(init_answer)
if answer_type != "reply":
raise TransportProtocolError(
"Websocket server did not return a connection ack"
)
async def heartbeat_coro():
while True:
await asyncio.sleep(self.heartbeat_interval)
try:
query_id = self.next_query_id
self.next_query_id += 1
await self._send(
json.dumps(
{
"topic": "phoenix",
"event": "heartbeat",
"payload": {},
"ref": query_id,
}
)
)
except ConnectionClosed: # pragma: no cover
return
self.heartbeat_task = asyncio.ensure_future(heartbeat_coro())
async def _send_stop_message(self, query_id: int) -> None:
"""Send an 'unsubscribe' message to the Phoenix Channel referencing
the listener's query_id, saving the query_id of the message.
The server should afterwards return a 'phx_reply' message with
the same query_id and subscription_id of the 'unsubscribe' request.
"""
subscription_id = self._find_existing_subscription(query_id)
unsubscribe_query_id = self.next_query_id
self.next_query_id += 1
# Save the ref so it can be matched in the reply
self.subscriptions[subscription_id].unsubscribe_id = unsubscribe_query_id
unsubscribe_message = json.dumps(
{
"topic": self.channel_name,
"event": "unsubscribe",
"payload": {"subscriptionId": subscription_id},
"ref": unsubscribe_query_id,
}
)
await self._send(unsubscribe_message)
async def _stop_listener(self, query_id: int) -> None:
await self._send_stop_message(query_id)
async def _send_connection_terminate_message(self) -> None:
"""Send a phx_leave message to disconnect from the provided channel."""
query_id = self.next_query_id
self.next_query_id += 1
connection_terminate_message = json.dumps(
{
"topic": self.channel_name,
"event": "phx_leave",
"payload": {},
"ref": query_id,
}
)
await self._send(connection_terminate_message)
async def _connection_terminate(self):
await self._send_connection_terminate_message()
async def _send_query(
self,
document: DocumentNode,
variable_values: Optional[Dict[str, Any]] = None,
operation_name: Optional[str] = None,
) -> int:
"""Send a query to the provided websocket connection.
We use an incremented id to reference the query.
Returns the used id for this query.
"""
query_id = self.next_query_id
self.next_query_id += 1
query_str = json.dumps(
{
"topic": self.channel_name,
"event": "doc",
"payload": {
"query": print_ast(document),
"variables": variable_values or {},
},
"ref": query_id,
}
)
await self._send(query_str)
return query_id
def _parse_answer(
self, answer: str
) -> Tuple[str, Optional[int], Optional[ExecutionResult]]:
"""Parse the answer received from the server
Returns a list consisting of:
- the answer_type (between:
'data', 'reply', 'complete', 'close')
- the answer id (Integer) if received or None
- an execution Result if the answer_type is 'data' or None
"""
event: str = ""
answer_id: Optional[int] = None
answer_type: str = ""
execution_result: Optional[ExecutionResult] = None
subscription_id: Optional[str] = None
def _get_value(d: Any, key: str, label: str) -> Any:
if not isinstance(d, dict):
raise ValueError(f"{label} is not a dict")
return d.get(key)
def _required_value(d: Any, key: str, label: str) -> Any:
value = _get_value(d, key, label)
if value is None:
raise ValueError(f"null {key} in {label}")
return value
def _required_subscription_id(
d: Any, label: str, must_exist: bool = False, must_not_exist=False
) -> str:
subscription_id = str(_required_value(d, "subscriptionId", label))
if must_exist and (subscription_id not in self.subscriptions):
raise ValueError("unregistered subscriptionId")
if must_not_exist and (subscription_id in self.subscriptions):
raise ValueError("previously registered subscriptionId")
return subscription_id
def _validate_data_response(d: Any, label: str) -> dict:
"""Make sure query, mutation or subscription answer conforms.
The GraphQL spec says only three keys are permitted.
"""
if not isinstance(d, dict):
raise ValueError(f"{label} is not a dict")
keys = set(d.keys())
invalid = keys - {"data", "errors", "extensions"}
if len(invalid) > 0:
raise ValueError(
f"{label} contains invalid items: " + ", ".join(invalid)
)
return d
try:
json_answer = json.loads(answer)
event = str(_required_value(json_answer, "event", "answer"))
if event == "subscription:data":
payload = _required_value(json_answer, "payload", "answer")
subscription_id = _required_subscription_id(
payload, "payload", must_exist=True
)
result = _validate_data_response(payload.get("result"), "result")
answer_type = "data"
subscription = self.subscriptions[subscription_id]
answer_id = subscription.listener_id
execution_result = ExecutionResult(
data=result.get("data"),
errors=result.get("errors"),
extensions=result.get("extensions"),
)
elif event == "phx_reply":
# Will generate a ValueError if 'ref' is not there
# or if it is not an integer
answer_id = int(_required_value(json_answer, "ref", "answer"))
payload = _required_value(json_answer, "payload", "answer")
status = _get_value(payload, "status", "payload")
if status == "ok":
answer_type = "reply"
if answer_id in self.listeners:
response = _required_value(payload, "response", "payload")
if isinstance(response, dict) and "subscriptionId" in response:
# Subscription answer
subscription_id = _required_subscription_id(
response, "response", must_not_exist=True
)
self.subscriptions[subscription_id] = Subscription(
answer_id
)
else:
# Query or mutation answer
# GraphQL spec says only three keys are permitted
response = _validate_data_response(response, "response")
answer_type = "data"
execution_result = ExecutionResult(
data=response.get("data"),
errors=response.get("errors"),
extensions=response.get("extensions"),
)
else:
(
registered_subscription_id,
listener_id,
) = self._find_subscription(answer_id)
if registered_subscription_id is not None:
# Unsubscription answer
response = _required_value(payload, "response", "payload")
subscription_id = _required_subscription_id(
response, "response"
)
if subscription_id != registered_subscription_id:
raise ValueError("subscription id does not match")
answer_type = "complete"
answer_id = listener_id
elif status == "error":
response = payload.get("response")
if isinstance(response, dict):
if "errors" in response:
raise TransportQueryError(
str(response.get("errors")), query_id=answer_id
)
elif "reason" in response:
raise TransportQueryError(
str(response.get("reason")), query_id=answer_id
)
raise TransportQueryError("reply error", query_id=answer_id)
elif status == "timeout":
raise TransportQueryError("reply timeout", query_id=answer_id)
else:
# missing or unrecognized status, just continue
pass
elif event == "phx_error":
# Sent if the channel has crashed
# answer_id will be the "join_ref" for the channel
# answer_id = int(json_answer.get("ref"))
raise TransportServerError("Server error")
elif event == "phx_close":
answer_type = "close"
else:
raise ValueError("unrecognized event")
except ValueError as e:
log.error(f"Error parsing answer '{answer}': {e!r}")
raise TransportProtocolError(
f"Server did not return a GraphQL result: {e!s}"
) from e
return answer_type, answer_id, execution_result
async def _handle_answer(
self,
answer_type: str,
answer_id: Optional[int],
execution_result: Optional[ExecutionResult],
) -> None:
if answer_type == "close":
await self.close()
else:
await super()._handle_answer(answer_type, answer_id, execution_result)
def _remove_listener(self, query_id: int) -> None:
"""If the listener was a subscription, remove that information."""
try:
subscription_id = self._find_existing_subscription(query_id)
del self.subscriptions[subscription_id]
except Exception:
pass
super()._remove_listener(query_id)
def _find_subscription(self, query_id: int) -> Tuple[Optional[str], int]:
"""Perform a reverse lookup to find the subscription id matching
a listener's query_id.
"""
for subscription_id, subscription in self.subscriptions.items():
if query_id == subscription.listener_id:
return subscription_id, query_id
if query_id == subscription.unsubscribe_id:
return subscription_id, subscription.listener_id
return None, query_id
def _find_existing_subscription(self, query_id: int) -> str:
"""Perform a reverse lookup to find the subscription id matching
a listener's query_id.
"""
subscription_id, _listener_id = self._find_subscription(query_id)
if subscription_id is None:
raise TransportProtocolError(
f"No subscription registered for listener {query_id}"
)
return subscription_id
async def _close_coro(self, e: Exception, clean_close: bool = True) -> None:
if self.heartbeat_task is not None:
self.heartbeat_task.cancel()
await super()._close_coro(e, clean_close)
@@ -0,0 +1,426 @@
import io
import json
import logging
from typing import Any, Collection, Dict, List, Optional, Tuple, Type, Union
import requests
from graphql import DocumentNode, ExecutionResult, print_ast
from requests.adapters import HTTPAdapter, Retry
from requests.auth import AuthBase
from requests.cookies import RequestsCookieJar
from requests_toolbelt.multipart.encoder import MultipartEncoder
from gql.transport import Transport
from ..graphql_request import GraphQLRequest
from ..utils import extract_files
from .exceptions import (
TransportAlreadyConnected,
TransportClosed,
TransportProtocolError,
TransportServerError,
)
log = logging.getLogger(__name__)
class RequestsHTTPTransport(Transport):
""":ref:`Sync Transport <sync_transports>` used to execute GraphQL queries
on remote servers.
The transport uses the requests library to send HTTP POST requests.
"""
file_classes: Tuple[Type[Any], ...] = (io.IOBase,)
_default_retry_codes = (429, 500, 502, 503, 504)
def __init__(
self,
url: str,
headers: Optional[Dict[str, Any]] = None,
cookies: Optional[Union[Dict[str, Any], RequestsCookieJar]] = None,
auth: Optional[AuthBase] = None,
use_json: bool = True,
timeout: Optional[int] = None,
verify: Union[bool, str] = True,
retries: int = 0,
method: str = "POST",
retry_backoff_factor: float = 0.1,
retry_status_forcelist: Collection[int] = _default_retry_codes,
**kwargs: Any,
):
"""Initialize the transport with the given request parameters.
:param url: The GraphQL server URL.
:param headers: Dictionary of HTTP Headers to send with the :class:`Request`
(Default: None).
:param cookies: Dict or CookieJar object to send with the :class:`Request`
(Default: None).
:param auth: Auth tuple or callable to enable Basic/Digest/Custom HTTP Auth
(Default: None).
:param use_json: Send request body as JSON instead of form-urlencoded
(Default: True).
:param timeout: Specifies a default timeout for requests (Default: None).
:param verify: Either a boolean, in which case it controls whether we verify
the server's TLS certificate, or a string, in which case it must be a path
to a CA bundle to use. (Default: True).
:param retries: Pre-setup of the requests' Session for performing retries
:param method: HTTP method used for requests. (Default: POST).
:param retry_backoff_factor: A backoff factor to apply between attempts after
the second try. urllib3 will sleep for:
{backoff factor} * (2 ** ({number of previous retries}))
:param retry_status_forcelist: A set of integer HTTP status codes that we
should force a retry on. A retry is initiated if the request method is
in allowed_methods and the response status code is in status_forcelist.
(Default: [429, 500, 502, 503, 504])
:param kwargs: Optional arguments that ``request`` takes.
These can be seen at the `requests`_ source code or the official `docs`_
.. _requests: https://github.com/psf/requests/blob/master/requests/api.py
.. _docs: https://requests.readthedocs.io/en/master/
"""
self.url = url
self.headers = headers
self.cookies = cookies
self.auth = auth
self.use_json = use_json
self.default_timeout = timeout
self.verify = verify
self.retries = retries
self.method = method
self.retry_backoff_factor = retry_backoff_factor
self.retry_status_forcelist = retry_status_forcelist
self.kwargs = kwargs
self.session = None
self.response_headers = None
def connect(self):
if self.session is None:
# Creating a session that can later be re-use to configure custom mechanisms
self.session = requests.Session()
# If we specified some retries, we provide a predefined retry-logic
if self.retries > 0:
adapter = HTTPAdapter(
max_retries=Retry(
total=self.retries,
backoff_factor=self.retry_backoff_factor,
status_forcelist=self.retry_status_forcelist,
allowed_methods=None,
)
)
for prefix in "http://", "https://":
self.session.mount(prefix, adapter)
else:
raise TransportAlreadyConnected("Transport is already connected")
def execute( # type: ignore
self,
document: DocumentNode,
variable_values: Optional[Dict[str, Any]] = None,
operation_name: Optional[str] = None,
timeout: Optional[int] = None,
extra_args: Optional[Dict[str, Any]] = None,
upload_files: bool = False,
) -> ExecutionResult:
"""Execute GraphQL query.
Execute the provided document AST against the configured remote server. This
uses the requests library to perform a HTTP POST request to the remote server.
:param document: GraphQL query as AST Node object.
:param variable_values: Dictionary of input parameters (Default: None).
:param operation_name: Name of the operation that shall be executed.
Only required in multi-operation documents (Default: None).
:param timeout: Specifies a default timeout for requests (Default: None).
:param extra_args: additional arguments to send to the requests post method
:param upload_files: Set to True if you want to put files in the variable values
:return: The result of execution.
`data` is the result of executing the query, `errors` is null
if no errors occurred, and is a non-empty array if an error occurred.
"""
if not self.session:
raise TransportClosed("Transport is not connected")
query_str = print_ast(document)
payload: Dict[str, Any] = {"query": query_str}
if operation_name:
payload["operationName"] = operation_name
post_args = {
"headers": self.headers,
"auth": self.auth,
"cookies": self.cookies,
"timeout": timeout or self.default_timeout,
"verify": self.verify,
}
if upload_files:
# If the upload_files flag is set, then we need variable_values
assert variable_values is not None
# If we upload files, we will extract the files present in the
# variable_values dict and replace them by null values
nulled_variable_values, files = extract_files(
variables=variable_values,
file_classes=self.file_classes,
)
# Save the nulled variable values in the payload
payload["variables"] = nulled_variable_values
# Add the payload to the operations field
operations_str = json.dumps(payload)
log.debug("operations %s", operations_str)
# Generate the file map
# path is nested in a list because the spec allows multiple pointers
# to the same file. But we don't support that.
# Will generate something like {"0": ["variables.file"]}
file_map = {str(i): [path] for i, path in enumerate(files)}
# Enumerate the file streams
# Will generate something like {'0': <_io.BufferedReader ...>}
file_streams = {str(i): files[path] for i, path in enumerate(files)}
# Add the file map field
file_map_str = json.dumps(file_map)
log.debug("file_map %s", file_map_str)
fields = {"operations": operations_str, "map": file_map_str}
# Add the extracted files as remaining fields
for k, f in file_streams.items():
name = getattr(f, "name", k)
content_type = getattr(f, "content_type", None)
if content_type is None:
fields[k] = (name, f)
else:
fields[k] = (name, f, content_type)
# Prepare requests http to send multipart-encoded data
data = MultipartEncoder(fields=fields)
post_args["data"] = data
if post_args["headers"] is None:
post_args["headers"] = {}
else:
post_args["headers"] = {**post_args["headers"]}
post_args["headers"]["Content-Type"] = data.content_type
else:
if variable_values:
payload["variables"] = variable_values
data_key = "json" if self.use_json else "data"
post_args[data_key] = payload
# Log the payload
if log.isEnabledFor(logging.INFO):
log.info(">>> %s", json.dumps(payload))
# Pass kwargs to requests post method
post_args.update(self.kwargs)
# Pass post_args to requests post method
if extra_args:
post_args.update(extra_args)
# Using the created session to perform requests
response = self.session.request(
self.method, self.url, **post_args # type: ignore
)
self.response_headers = response.headers
def raise_response_error(resp: requests.Response, reason: str):
# We raise a TransportServerError if the status code is 400 or higher
# We raise a TransportProtocolError in the other cases
try:
# Raise a HTTPError if response status is 400 or higher
resp.raise_for_status()
except requests.HTTPError as e:
raise TransportServerError(str(e), e.response.status_code) from e
result_text = resp.text
raise TransportProtocolError(
f"Server did not return a GraphQL result: "
f"{reason}: "
f"{result_text}"
)
try:
result = response.json()
if log.isEnabledFor(logging.INFO):
log.info("<<< %s", response.text)
except Exception:
raise_response_error(response, "Not a JSON answer")
if "errors" not in result and "data" not in result:
raise_response_error(response, 'No "data" or "errors" keys in answer')
return ExecutionResult(
errors=result.get("errors"),
data=result.get("data"),
extensions=result.get("extensions"),
)
def execute_batch( # type: ignore
self,
reqs: List[GraphQLRequest],
timeout: Optional[int] = None,
extra_args: Optional[Dict[str, Any]] = None,
) -> List[ExecutionResult]:
"""Execute multiple GraphQL requests in a batch.
Execute the provided requests against the configured remote server. This
uses the requests library to perform a HTTP POST request to the remote server.
:param reqs: GraphQL requests as a list of GraphQLRequest objects.
:param timeout: Specifies a default timeout for requests (Default: None).
:param extra_args: additional arguments to send to the requests post method
:return: A list of results of execution.
For every result `data` is the result of executing the query,
`errors` is null if no errors occurred, and is a non-empty array
if an error occurred.
"""
if not self.session:
raise TransportClosed("Transport is not connected")
# Using the created session to perform requests
response = self.session.request(
self.method,
self.url,
**self._build_batch_post_args(reqs, timeout, extra_args),
)
self.response_headers = response.headers
answers = self._extract_response(response)
self._validate_answer_is_a_list(answers)
self._validate_num_of_answers_same_as_requests(reqs, answers)
self._validate_every_answer_is_a_dict(answers)
self._validate_data_and_errors_keys_in_answers(answers)
return [self._answer_to_execution_result(answer) for answer in answers]
def _answer_to_execution_result(self, result: Dict[str, Any]) -> ExecutionResult:
return ExecutionResult(
errors=result.get("errors"),
data=result.get("data"),
extensions=result.get("extensions"),
)
def _validate_answer_is_a_list(self, results: Any) -> None:
if not isinstance(results, list):
self._raise_invalid_result(
str(results),
"Answer is not a list",
)
def _validate_data_and_errors_keys_in_answers(
self, results: List[Dict[str, Any]]
) -> None:
for result in results:
if "errors" not in result and "data" not in result:
self._raise_invalid_result(
str(results),
'No "data" or "errors" keys in answer',
)
def _validate_every_answer_is_a_dict(self, results: List[Dict[str, Any]]) -> None:
for result in results:
if not isinstance(result, dict):
self._raise_invalid_result(str(results), "Not every answer is dict")
def _validate_num_of_answers_same_as_requests(
self,
reqs: List[GraphQLRequest],
results: List[Dict[str, Any]],
) -> None:
if len(reqs) != len(results):
self._raise_invalid_result(
str(results),
"Invalid answer length",
)
def _raise_invalid_result(self, result_text: str, reason: str) -> None:
raise TransportProtocolError(
f"Server did not return a valid GraphQL result: "
f"{reason}: "
f"{result_text}"
)
def _extract_response(self, response: requests.Response) -> Any:
try:
response.raise_for_status()
result = response.json()
if log.isEnabledFor(logging.INFO):
log.info("<<< %s", response.text)
except requests.HTTPError as e:
raise TransportServerError(str(e), e.response.status_code) from e
except Exception:
self._raise_invalid_result(str(response.text), "Not a JSON answer")
return result
def _build_batch_post_args(
self,
reqs: List[GraphQLRequest],
timeout: Optional[int] = None,
extra_args: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
post_args: Dict[str, Any] = {
"headers": self.headers,
"auth": self.auth,
"cookies": self.cookies,
"timeout": timeout or self.default_timeout,
"verify": self.verify,
}
data_key = "json" if self.use_json else "data"
post_args[data_key] = [self._build_data(req) for req in reqs]
# Log the payload
if log.isEnabledFor(logging.INFO):
log.info(">>> %s", json.dumps(post_args[data_key]))
# Pass kwargs to requests post method
post_args.update(self.kwargs)
# Pass post_args to requests post method
if extra_args:
post_args.update(extra_args)
return post_args
def _build_data(self, req: GraphQLRequest) -> Dict[str, Any]:
query_str = print_ast(req.document)
payload: Dict[str, Any] = {"query": query_str}
if req.operation_name:
payload["operationName"] = req.operation_name
if req.variable_values:
payload["variables"] = req.variable_values
return payload
def close(self):
"""Closing the transport by closing the inner session"""
if self.session:
self.session.close()
self.session = None
@@ -0,0 +1,51 @@
import abc
from typing import List
from graphql import DocumentNode, ExecutionResult
from ..graphql_request import GraphQLRequest
class Transport(abc.ABC):
@abc.abstractmethod
def execute(self, document: DocumentNode, *args, **kwargs) -> ExecutionResult:
"""Execute GraphQL query.
Execute the provided document AST for either a remote or local GraphQL Schema.
:param document: GraphQL query as AST Node or Document object.
:return: ExecutionResult
"""
raise NotImplementedError(
"Any Transport subclass must implement execute method"
) # pragma: no cover
def execute_batch(
self,
reqs: List[GraphQLRequest],
*args,
**kwargs,
) -> List[ExecutionResult]:
"""Execute multiple GraphQL requests in a batch.
Execute the provided requests for either a remote or local GraphQL Schema.
:param reqs: GraphQL requests as a list of GraphQLRequest objects.
:return: a list of ExecutionResult objects
"""
raise NotImplementedError(
"This Transport has not implemented the execute_batch method"
) # pragma: no cover
def connect(self):
"""Establish a session with the transport."""
pass # pragma: no cover
def close(self):
"""Close the transport
This method doesn't have to be implemented unless the transport would benefit
from it. This is currently used by the RequestsHTTPTransport transport to close
the session's connection pool.
"""
pass # pragma: no cover
@@ -0,0 +1,514 @@
import asyncio
import json
import logging
from contextlib import suppress
from ssl import SSLContext
from typing import Any, Dict, List, Optional, Tuple, Union, cast
from graphql import DocumentNode, ExecutionResult, print_ast
from websockets.datastructures import HeadersLike
from websockets.typing import Subprotocol
from .exceptions import (
TransportProtocolError,
TransportQueryError,
TransportServerError,
)
from .websockets_base import WebsocketsTransportBase
log = logging.getLogger(__name__)
class WebsocketsTransport(WebsocketsTransportBase):
""":ref:`Async Transport <async_transports>` used to execute GraphQL queries on
remote servers with websocket connection.
This transport uses asyncio and the websockets library in order to send requests
on a websocket connection.
"""
# This transport supports two subprotocols and will autodetect the
# subprotocol supported on the server
APOLLO_SUBPROTOCOL = cast(Subprotocol, "graphql-ws")
GRAPHQLWS_SUBPROTOCOL = cast(Subprotocol, "graphql-transport-ws")
def __init__(
self,
url: str,
headers: Optional[HeadersLike] = None,
ssl: Union[SSLContext, bool] = False,
init_payload: Dict[str, Any] = {},
connect_timeout: Optional[Union[int, float]] = 10,
close_timeout: Optional[Union[int, float]] = 10,
ack_timeout: Optional[Union[int, float]] = 10,
keep_alive_timeout: Optional[Union[int, float]] = None,
ping_interval: Optional[Union[int, float]] = None,
pong_timeout: Optional[Union[int, float]] = None,
answer_pings: bool = True,
connect_args: Dict[str, Any] = {},
subprotocols: Optional[List[Subprotocol]] = None,
) -> None:
"""Initialize the transport with the given parameters.
:param url: The GraphQL server URL. Example: 'wss://server.com:PORT/graphql'.
:param headers: Dict of HTTP Headers.
:param ssl: ssl_context of the connection. Use ssl=False to disable encryption
:param init_payload: Dict of the payload sent in the connection_init message.
:param connect_timeout: Timeout in seconds for the establishment
of the websocket connection. If None is provided this will wait forever.
:param close_timeout: Timeout in seconds for the close. If None is provided
this will wait forever.
:param ack_timeout: Timeout in seconds to wait for the connection_ack message
from the server. If None is provided this will wait forever.
:param keep_alive_timeout: Optional Timeout in seconds to receive
a sign of liveness from the server.
:param ping_interval: Delay in seconds between pings sent by the client to
the backend for the graphql-ws protocol. None (by default) means that
we don't send pings. Note: there are also pings sent by the underlying
websockets protocol. See the
:ref:`keepalive documentation <websockets_transport_keepalives>`
for more information about this.
:param pong_timeout: Delay in seconds to receive a pong from the backend
after we sent a ping (only for the graphql-ws protocol).
By default equal to half of the ping_interval.
:param answer_pings: Whether the client answers the pings from the backend
(for the graphql-ws protocol).
By default: True
:param connect_args: Other parameters forwarded to
`websockets.connect <https://websockets.readthedocs.io/en/stable/reference/\
client.html#opening-a-connection>`_
:param subprotocols: list of subprotocols sent to the
backend in the 'subprotocols' http header.
By default: both apollo and graphql-ws subprotocols.
"""
super().__init__(
url,
headers,
ssl,
init_payload,
connect_timeout,
close_timeout,
ack_timeout,
keep_alive_timeout,
connect_args,
)
self.ping_interval: Optional[Union[int, float]] = ping_interval
self.pong_timeout: Optional[Union[int, float]]
self.answer_pings: bool = answer_pings
if ping_interval is not None:
if pong_timeout is None:
self.pong_timeout = ping_interval / 2
else:
self.pong_timeout = pong_timeout
self.send_ping_task: Optional[asyncio.Future] = None
self.ping_received: asyncio.Event = asyncio.Event()
"""ping_received is an asyncio Event which will fire each time
a ping is received with the graphql-ws protocol"""
self.pong_received: asyncio.Event = asyncio.Event()
"""pong_received is an asyncio Event which will fire each time
a pong is received with the graphql-ws protocol"""
if subprotocols is None:
self.supported_subprotocols = [
self.APOLLO_SUBPROTOCOL,
self.GRAPHQLWS_SUBPROTOCOL,
]
else:
self.supported_subprotocols = subprotocols
async def _wait_ack(self) -> None:
"""Wait for the connection_ack message. Keep alive messages are ignored"""
while True:
init_answer = await self._receive()
answer_type, answer_id, execution_result = self._parse_answer(init_answer)
if answer_type == "connection_ack":
return
if answer_type != "ka":
raise TransportProtocolError(
"Websocket server did not return a connection ack"
)
async def _send_init_message_and_wait_ack(self) -> None:
"""Send init message to the provided websocket and wait for the connection ACK.
If the answer is not a connection_ack message, we will return an Exception.
"""
init_message = json.dumps(
{"type": "connection_init", "payload": self.init_payload}
)
await self._send(init_message)
# Wait for the connection_ack message or raise a TimeoutError
await asyncio.wait_for(self._wait_ack(), self.ack_timeout)
async def _initialize(self):
await self._send_init_message_and_wait_ack()
async def send_ping(self, payload: Optional[Any] = None) -> None:
"""Send a ping message for the graphql-ws protocol"""
ping_message = {"type": "ping"}
if payload is not None:
ping_message["payload"] = payload
await self._send(json.dumps(ping_message))
async def send_pong(self, payload: Optional[Any] = None) -> None:
"""Send a pong message for the graphql-ws protocol"""
pong_message = {"type": "pong"}
if payload is not None:
pong_message["payload"] = payload
await self._send(json.dumps(pong_message))
async def _send_stop_message(self, query_id: int) -> None:
"""Send stop message to the provided websocket connection and query_id.
The server should afterwards return a 'complete' message.
"""
stop_message = json.dumps({"id": str(query_id), "type": "stop"})
await self._send(stop_message)
async def _send_complete_message(self, query_id: int) -> None:
"""Send a complete message for the provided query_id.
This is only for the graphql-ws protocol.
"""
complete_message = json.dumps({"id": str(query_id), "type": "complete"})
await self._send(complete_message)
async def _stop_listener(self, query_id: int):
"""Stop the listener corresponding to the query_id depending on the
detected backend protocol.
For apollo: send a "stop" message
(a "complete" message will be sent from the backend)
For graphql-ws: send a "complete" message and simulate the reception
of a "complete" message from the backend
"""
log.debug(f"stop listener {query_id}")
if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL:
await self._send_complete_message(query_id)
await self.listeners[query_id].put(("complete", None))
else:
await self._send_stop_message(query_id)
async def _send_connection_terminate_message(self) -> None:
"""Send a connection_terminate message to the provided websocket connection.
This message indicates that the connection will disconnect.
"""
connection_terminate_message = json.dumps({"type": "connection_terminate"})
await self._send(connection_terminate_message)
async def _send_query(
self,
document: DocumentNode,
variable_values: Optional[Dict[str, Any]] = None,
operation_name: Optional[str] = None,
) -> int:
"""Send a query to the provided websocket connection.
We use an incremented id to reference the query.
Returns the used id for this query.
"""
query_id = self.next_query_id
self.next_query_id += 1
payload: Dict[str, Any] = {"query": print_ast(document)}
if variable_values:
payload["variables"] = variable_values
if operation_name:
payload["operationName"] = operation_name
query_type = "start"
if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL:
query_type = "subscribe"
query_str = json.dumps(
{"id": str(query_id), "type": query_type, "payload": payload}
)
await self._send(query_str)
return query_id
async def _connection_terminate(self):
if self.subprotocol == self.APOLLO_SUBPROTOCOL:
await self._send_connection_terminate_message()
def _parse_answer_graphqlws(
self, json_answer: Dict[str, Any]
) -> Tuple[str, Optional[int], Optional[ExecutionResult]]:
"""Parse the answer received from the server if the server supports the
graphql-ws protocol.
Returns a list consisting of:
- the answer_type (between:
'connection_ack', 'ping', 'pong', 'data', 'error', 'complete')
- the answer id (Integer) if received or None
- an execution Result if the answer_type is 'data' or None
Differences with the apollo websockets protocol (superclass):
- the "data" message is now called "next"
- the "stop" message is now called "complete"
- there is no connection_terminate or connection_error messages
- instead of a unidirectional keep-alive (ka) message from server to client,
there is now the possibility to send bidirectional ping/pong messages
- connection_ack has an optional payload
- the 'error' answer type returns a list of errors instead of a single error
"""
answer_type: str = ""
answer_id: Optional[int] = None
execution_result: Optional[ExecutionResult] = None
try:
answer_type = str(json_answer.get("type"))
if answer_type in ["next", "error", "complete"]:
answer_id = int(str(json_answer.get("id")))
if answer_type == "next" or answer_type == "error":
payload = json_answer.get("payload")
if answer_type == "next":
if not isinstance(payload, dict):
raise ValueError("payload is not a dict")
if "errors" not in payload and "data" not in payload:
raise ValueError(
"payload does not contain 'data' or 'errors' fields"
)
execution_result = ExecutionResult(
errors=payload.get("errors"),
data=payload.get("data"),
extensions=payload.get("extensions"),
)
# Saving answer_type as 'data' to be understood with superclass
answer_type = "data"
elif answer_type == "error":
if not isinstance(payload, list):
raise ValueError("payload is not a list")
raise TransportQueryError(
str(payload[0]), query_id=answer_id, errors=payload
)
elif answer_type in ["ping", "pong", "connection_ack"]:
self.payloads[answer_type] = json_answer.get("payload", None)
else:
raise ValueError
if self.check_keep_alive_task is not None:
self._next_keep_alive_message.set()
except ValueError as e:
raise TransportProtocolError(
f"Server did not return a GraphQL result: {json_answer}"
) from e
return answer_type, answer_id, execution_result
def _parse_answer_apollo(
self, json_answer: Dict[str, Any]
) -> Tuple[str, Optional[int], Optional[ExecutionResult]]:
"""Parse the answer received from the server if the server supports the
apollo websockets protocol.
Returns a list consisting of:
- the answer_type (between:
'connection_ack', 'ka', 'connection_error', 'data', 'error', 'complete')
- the answer id (Integer) if received or None
- an execution Result if the answer_type is 'data' or None
"""
answer_type: str = ""
answer_id: Optional[int] = None
execution_result: Optional[ExecutionResult] = None
try:
answer_type = str(json_answer.get("type"))
if answer_type in ["data", "error", "complete"]:
answer_id = int(str(json_answer.get("id")))
if answer_type == "data" or answer_type == "error":
payload = json_answer.get("payload")
if not isinstance(payload, dict):
raise ValueError("payload is not a dict")
if answer_type == "data":
if "errors" not in payload and "data" not in payload:
raise ValueError(
"payload does not contain 'data' or 'errors' fields"
)
execution_result = ExecutionResult(
errors=payload.get("errors"),
data=payload.get("data"),
extensions=payload.get("extensions"),
)
elif answer_type == "error":
raise TransportQueryError(
str(payload), query_id=answer_id, errors=[payload]
)
elif answer_type == "ka":
# Keep-alive message
if self.check_keep_alive_task is not None:
self._next_keep_alive_message.set()
elif answer_type == "connection_ack":
pass
elif answer_type == "connection_error":
error_payload = json_answer.get("payload")
raise TransportServerError(f"Server error: '{repr(error_payload)}'")
else:
raise ValueError
except ValueError as e:
raise TransportProtocolError(
f"Server did not return a GraphQL result: {json_answer}"
) from e
return answer_type, answer_id, execution_result
def _parse_answer(
self, answer: str
) -> Tuple[str, Optional[int], Optional[ExecutionResult]]:
"""Parse the answer received from the server depending on
the detected subprotocol.
"""
try:
json_answer = json.loads(answer)
except ValueError:
raise TransportProtocolError(
f"Server did not return a GraphQL result: {answer}"
)
if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL:
return self._parse_answer_graphqlws(json_answer)
return self._parse_answer_apollo(json_answer)
async def _send_ping_coro(self) -> None:
"""Coroutine to periodically send a ping from the client to the backend.
Only used for the graphql-ws protocol.
Send a ping every ping_interval seconds.
Close the connection if a pong is not received within pong_timeout seconds.
"""
assert self.ping_interval is not None
try:
while True:
await asyncio.sleep(self.ping_interval)
await self.send_ping()
await asyncio.wait_for(self.pong_received.wait(), self.pong_timeout)
# Reset for the next iteration
self.pong_received.clear()
except asyncio.TimeoutError:
# No pong received in the appriopriate time, close with error
# If the timeout happens during a close already in progress, do nothing
if self.close_task is None:
await self._fail(
TransportServerError(
f"No pong received after {self.pong_timeout!r} seconds"
),
clean_close=False,
)
async def _handle_answer(
self,
answer_type: str,
answer_id: Optional[int],
execution_result: Optional[ExecutionResult],
) -> None:
# Put the answer in the queue
await super()._handle_answer(answer_type, answer_id, execution_result)
# Answer pong to ping for graphql-ws protocol
if answer_type == "ping":
self.ping_received.set()
if self.answer_pings:
await self.send_pong()
elif answer_type == "pong":
self.pong_received.set()
async def _after_connect(self):
# Find the backend subprotocol returned in the response headers
response_headers = self.websocket.response_headers
try:
self.subprotocol = response_headers["Sec-WebSocket-Protocol"]
except KeyError:
# If the server does not send the subprotocol header, using
# the apollo subprotocol by default
self.subprotocol = self.APOLLO_SUBPROTOCOL
log.debug(f"backend subprotocol returned: {self.subprotocol!r}")
async def _after_initialize(self):
# If requested, create a task to send periodic pings to the backend
if (
self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL
and self.ping_interval is not None
):
self.send_ping_task = asyncio.ensure_future(self._send_ping_coro())
async def _close_hook(self):
# Properly shut down the send ping task if enabled
if self.send_ping_task is not None:
self.send_ping_task.cancel()
with suppress(asyncio.CancelledError):
await self.send_ping_task
self.send_ping_task = None
@@ -0,0 +1,669 @@
import asyncio
import logging
import warnings
from abc import abstractmethod
from contextlib import suppress
from ssl import SSLContext
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union, cast
import websockets
from graphql import DocumentNode, ExecutionResult
from websockets.client import WebSocketClientProtocol
from websockets.datastructures import Headers, HeadersLike
from websockets.exceptions import ConnectionClosed
from websockets.typing import Data, Subprotocol
from .async_transport import AsyncTransport
from .exceptions import (
TransportAlreadyConnected,
TransportClosed,
TransportProtocolError,
TransportQueryError,
TransportServerError,
)
log = logging.getLogger("gql.transport.websockets")
ParsedAnswer = Tuple[str, Optional[ExecutionResult]]
class ListenerQueue:
"""Special queue used for each query waiting for server answers
If the server is stopped while the listener is still waiting,
Then we send an exception to the queue and this exception will be raised
to the consumer once all the previous messages have been consumed from the queue
"""
def __init__(self, query_id: int, send_stop: bool) -> None:
self.query_id: int = query_id
self.send_stop: bool = send_stop
self._queue: asyncio.Queue = asyncio.Queue()
self._closed: bool = False
async def get(self) -> ParsedAnswer:
try:
item = self._queue.get_nowait()
except asyncio.QueueEmpty:
item = await self._queue.get()
self._queue.task_done()
# If we receive an exception when reading the queue, we raise it
if isinstance(item, Exception):
self._closed = True
raise item
# Don't need to save new answers or
# send the stop message if we already received the complete message
answer_type, execution_result = item
if answer_type == "complete":
self.send_stop = False
self._closed = True
return item
async def put(self, item: ParsedAnswer) -> None:
if not self._closed:
await self._queue.put(item)
async def set_exception(self, exception: Exception) -> None:
# Put the exception in the queue
await self._queue.put(exception)
# Don't need to send stop messages in case of error
self.send_stop = False
self._closed = True
class WebsocketsTransportBase(AsyncTransport):
"""abstract :ref:`Async Transport <async_transports>` used to implement
different websockets protocols.
This transport uses asyncio and the websockets library in order to send requests
on a websocket connection.
"""
def __init__(
self,
url: str,
headers: Optional[HeadersLike] = None,
ssl: Union[SSLContext, bool] = False,
init_payload: Dict[str, Any] = {},
connect_timeout: Optional[Union[int, float]] = 10,
close_timeout: Optional[Union[int, float]] = 10,
ack_timeout: Optional[Union[int, float]] = 10,
keep_alive_timeout: Optional[Union[int, float]] = None,
connect_args: Dict[str, Any] = {},
) -> None:
"""Initialize the transport with the given parameters.
:param url: The GraphQL server URL. Example: 'wss://server.com:PORT/graphql'.
:param headers: Dict of HTTP Headers.
:param ssl: ssl_context of the connection. Use ssl=False to disable encryption
:param init_payload: Dict of the payload sent in the connection_init message.
:param connect_timeout: Timeout in seconds for the establishment
of the websocket connection. If None is provided this will wait forever.
:param close_timeout: Timeout in seconds for the close. If None is provided
this will wait forever.
:param ack_timeout: Timeout in seconds to wait for the connection_ack message
from the server. If None is provided this will wait forever.
:param keep_alive_timeout: Optional Timeout in seconds to receive
a sign of liveness from the server.
:param connect_args: Other parameters forwarded to websockets.connect
"""
self.url: str = url
self.headers: Optional[HeadersLike] = headers
self.ssl: Union[SSLContext, bool] = ssl
self.init_payload: Dict[str, Any] = init_payload
self.connect_timeout: Optional[Union[int, float]] = connect_timeout
self.close_timeout: Optional[Union[int, float]] = close_timeout
self.ack_timeout: Optional[Union[int, float]] = ack_timeout
self.keep_alive_timeout: Optional[Union[int, float]] = keep_alive_timeout
self.connect_args = connect_args
self.websocket: Optional[WebSocketClientProtocol] = None
self.next_query_id: int = 1
self.listeners: Dict[int, ListenerQueue] = {}
self.receive_data_task: Optional[asyncio.Future] = None
self.check_keep_alive_task: Optional[asyncio.Future] = None
self.close_task: Optional[asyncio.Future] = None
# We need to set an event loop here if there is none
# Or else we will not be able to create an asyncio.Event()
try:
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore", message="There is no current event loop"
)
self._loop = asyncio.get_event_loop()
except RuntimeError:
self._loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._loop)
self._wait_closed: asyncio.Event = asyncio.Event()
self._wait_closed.set()
self._no_more_listeners: asyncio.Event = asyncio.Event()
self._no_more_listeners.set()
if self.keep_alive_timeout is not None:
self._next_keep_alive_message: asyncio.Event = asyncio.Event()
self._next_keep_alive_message.set()
self.payloads: Dict[str, Any] = {}
"""payloads is a dict which will contain the payloads received
for example with the graphql-ws protocol: 'ping', 'pong', 'connection_ack'"""
self._connecting: bool = False
self.close_exception: Optional[Exception] = None
# The list of supported subprotocols should be defined in the subclass
self.supported_subprotocols: List[Subprotocol] = []
self.response_headers: Optional[Headers] = None
async def _initialize(self):
"""Hook to send the initialization messages after the connection
and potentially wait for the backend ack.
"""
pass # pragma: no cover
async def _stop_listener(self, query_id: int):
"""Hook to stop to listen to a specific query.
Will send a stop message in some subclasses.
"""
pass # pragma: no cover
async def _after_connect(self):
"""Hook to add custom code for subclasses after the connection
has been established.
"""
pass # pragma: no cover
async def _after_initialize(self):
"""Hook to add custom code for subclasses after the initialization
has been done.
"""
pass # pragma: no cover
async def _close_hook(self):
"""Hook to add custom code for subclasses for the connection close"""
pass # pragma: no cover
async def _connection_terminate(self):
"""Hook to add custom code for subclasses after the initialization
has been done.
"""
pass # pragma: no cover
async def _send(self, message: str) -> None:
"""Send the provided message to the websocket connection and log the message"""
if not self.websocket:
raise TransportClosed(
"Transport is not connected"
) from self.close_exception
try:
await self.websocket.send(message)
log.info(">>> %s", message)
except ConnectionClosed as e:
await self._fail(e, clean_close=False)
raise e
async def _receive(self) -> str:
"""Wait the next message from the websocket connection and log the answer"""
# It is possible that the websocket has been already closed in another task
if self.websocket is None:
raise TransportClosed("Transport is already closed")
# Wait for the next websocket frame. Can raise ConnectionClosed
data: Data = await self.websocket.recv()
# websocket.recv() can return either str or bytes
# In our case, we should receive only str here
if not isinstance(data, str):
raise TransportProtocolError("Binary data received in the websocket")
answer: str = data
log.info("<<< %s", answer)
return answer
@abstractmethod
async def _send_query(
self,
document: DocumentNode,
variable_values: Optional[Dict[str, Any]] = None,
operation_name: Optional[str] = None,
) -> int:
raise NotImplementedError # pragma: no cover
@abstractmethod
def _parse_answer(
self, answer: str
) -> Tuple[str, Optional[int], Optional[ExecutionResult]]:
raise NotImplementedError # pragma: no cover
async def _check_ws_liveness(self) -> None:
"""Coroutine which will periodically check the liveness of the connection
through keep-alive messages
"""
try:
while True:
await asyncio.wait_for(
self._next_keep_alive_message.wait(), self.keep_alive_timeout
)
# Reset for the next iteration
self._next_keep_alive_message.clear()
except asyncio.TimeoutError:
# No keep-alive message in the appriopriate interval, close with error
# while trying to notify the server of a proper close (in case
# the keep-alive interval of the client or server was not aligned
# the connection still remains)
# If the timeout happens during a close already in progress, do nothing
if self.close_task is None:
await self._fail(
TransportServerError(
"No keep-alive message has been received within "
"the expected interval ('keep_alive_timeout' parameter)"
),
clean_close=False,
)
except asyncio.CancelledError:
# The client is probably closing, handle it properly
pass
async def _receive_data_loop(self) -> None:
"""Main asyncio task which will listen to the incoming messages and will
call the parse_answer and handle_answer methods of the subclass."""
try:
while True:
# Wait the next answer from the websocket server
try:
answer = await self._receive()
except (ConnectionClosed, TransportProtocolError) as e:
await self._fail(e, clean_close=False)
break
except TransportClosed:
break
# Parse the answer
try:
answer_type, answer_id, execution_result = self._parse_answer(
answer
)
except TransportQueryError as e:
# Received an exception for a specific query
# ==> Add an exception to this query queue
# The exception is raised for this specific query,
# but the transport is not closed.
assert isinstance(
e.query_id, int
), "TransportQueryError should have a query_id defined here"
try:
await self.listeners[e.query_id].set_exception(e)
except KeyError:
# Do nothing if no one is listening to this query_id
pass
continue
except (TransportServerError, TransportProtocolError) as e:
# Received a global exception for this transport
# ==> close the transport
# The exception will be raised for all current queries.
await self._fail(e, clean_close=False)
break
await self._handle_answer(answer_type, answer_id, execution_result)
finally:
log.debug("Exiting _receive_data_loop()")
async def _handle_answer(
self,
answer_type: str,
answer_id: Optional[int],
execution_result: Optional[ExecutionResult],
) -> None:
try:
# Put the answer in the queue
if answer_id is not None:
await self.listeners[answer_id].put((answer_type, execution_result))
except KeyError:
# Do nothing if no one is listening to this query_id.
pass
async def subscribe(
self,
document: DocumentNode,
variable_values: Optional[Dict[str, Any]] = None,
operation_name: Optional[str] = None,
send_stop: Optional[bool] = True,
) -> AsyncGenerator[ExecutionResult, None]:
"""Send a query and receive the results using a python async generator.
The query can be a graphql query, mutation or subscription.
The results are sent as an ExecutionResult object.
"""
# Send the query and receive the id
query_id: int = await self._send_query(
document, variable_values, operation_name
)
# Create a queue to receive the answers for this query_id
listener = ListenerQueue(query_id, send_stop=(send_stop is True))
self.listeners[query_id] = listener
# We will need to wait at close for this query to clean properly
self._no_more_listeners.clear()
try:
# Loop over the received answers
while True:
# Wait for the answer from the queue of this query_id
# This can raise a TransportError or ConnectionClosed exception.
answer_type, execution_result = await listener.get()
# If the received answer contains data,
# Then we will yield the results back as an ExecutionResult object
if execution_result is not None:
yield execution_result
# If we receive a 'complete' answer from the server,
# Then we will end this async generator output without errors
elif answer_type == "complete":
log.debug(
f"Complete received for query {query_id} --> exit without error"
)
break
except (asyncio.CancelledError, GeneratorExit) as e:
log.debug(f"Exception in subscribe: {e!r}")
if listener.send_stop:
await self._stop_listener(query_id)
listener.send_stop = False
finally:
log.debug(f"In subscribe finally for query_id {query_id}")
self._remove_listener(query_id)
async def execute(
self,
document: DocumentNode,
variable_values: Optional[Dict[str, Any]] = None,
operation_name: Optional[str] = None,
) -> ExecutionResult:
"""Execute the provided document AST against the configured remote server
using the current session.
Send a query but close the async generator as soon as we have the first answer.
The result is sent as an ExecutionResult object.
"""
first_result = None
generator = self.subscribe(
document, variable_values, operation_name, send_stop=False
)
async for result in generator:
first_result = result
# Note: we need to run generator.aclose() here or the finally block in
# the subscribe will not be reached in pypy3 (python version 3.6.1)
await generator.aclose()
break
if first_result is None:
raise TransportQueryError(
"Query completed without any answer received from the server"
)
return first_result
async def connect(self) -> None:
"""Coroutine which will:
- connect to the websocket address
- send the init message
- wait for the connection acknowledge from the server
- create an asyncio task which will be used to receive
and parse the websocket answers
Should be cleaned with a call to the close coroutine
"""
log.debug("connect: starting")
if self.websocket is None and not self._connecting:
# Set connecting to True to avoid a race condition if user is trying
# to connect twice using the same client at the same time
self._connecting = True
# If the ssl parameter is not provided,
# generate the ssl value depending on the url
ssl: Optional[Union[SSLContext, bool]]
if self.ssl:
ssl = self.ssl
else:
ssl = True if self.url.startswith("wss") else None
# Set default arguments used in the websockets.connect call
connect_args: Dict[str, Any] = {
"ssl": ssl,
"extra_headers": self.headers,
"subprotocols": self.supported_subprotocols,
}
# Adding custom parameters passed from init
connect_args.update(self.connect_args)
# Connection to the specified url
# Generate a TimeoutError if taking more than connect_timeout seconds
# Set the _connecting flag to False after in all cases
try:
self.websocket = await asyncio.wait_for(
websockets.client.connect(self.url, **connect_args),
self.connect_timeout,
)
finally:
self._connecting = False
self.websocket = cast(WebSocketClientProtocol, self.websocket)
self.response_headers = self.websocket.response_headers
# Run the after_connect hook of the subclass
await self._after_connect()
self.next_query_id = 1
self.close_exception = None
self._wait_closed.clear()
# Send the init message and wait for the ack from the server
# Note: This should generate a TimeoutError
# if no ACKs are received within the ack_timeout
try:
await self._initialize()
except ConnectionClosed as e:
raise e
except (TransportProtocolError, asyncio.TimeoutError) as e:
await self._fail(e, clean_close=False)
raise e
# Run the after_init hook of the subclass
await self._after_initialize()
# If specified, create a task to check liveness of the connection
# through keep-alive messages
if self.keep_alive_timeout is not None:
self.check_keep_alive_task = asyncio.ensure_future(
self._check_ws_liveness()
)
# Create a task to listen to the incoming websocket messages
self.receive_data_task = asyncio.ensure_future(self._receive_data_loop())
else:
raise TransportAlreadyConnected("Transport is already connected")
log.debug("connect: done")
def _remove_listener(self, query_id) -> None:
"""After exiting from a subscription, remove the listener and
signal an event if this was the last listener for the client.
"""
if query_id in self.listeners:
del self.listeners[query_id]
remaining = len(self.listeners)
log.debug(f"listener {query_id} deleted, {remaining} remaining")
if remaining == 0:
self._no_more_listeners.set()
async def _clean_close(self, e: Exception) -> None:
"""Coroutine which will:
- send stop messages for each active subscription to the server
- send the connection terminate message
"""
# Send 'stop' message for all current queries
for query_id, listener in self.listeners.items():
if listener.send_stop:
await self._stop_listener(query_id)
listener.send_stop = False
# Wait that there is no more listeners (we received 'complete' for all queries)
try:
await asyncio.wait_for(self._no_more_listeners.wait(), self.close_timeout)
except asyncio.TimeoutError: # pragma: no cover
log.debug("Timer close_timeout fired")
# Calling the subclass hook
await self._connection_terminate()
async def _close_coro(self, e: Exception, clean_close: bool = True) -> None:
"""Coroutine which will:
- do a clean_close if possible:
- send stop messages for each active query to the server
- send the connection terminate message
- close the websocket connection
- send the exception to all the remaining listeners
"""
log.debug("_close_coro: starting")
try:
# We should always have an active websocket connection here
assert self.websocket is not None
# Properly shut down liveness checker if enabled
if self.check_keep_alive_task is not None:
# More info: https://stackoverflow.com/a/43810272/1113207
self.check_keep_alive_task.cancel()
with suppress(asyncio.CancelledError):
await self.check_keep_alive_task
# Calling the subclass close hook
await self._close_hook()
# Saving exception to raise it later if trying to use the transport
# after it has already closed.
self.close_exception = e
if clean_close:
log.debug("_close_coro: starting clean_close")
try:
await self._clean_close(e)
except Exception as exc: # pragma: no cover
log.warning("Ignoring exception in _clean_close: " + repr(exc))
log.debug("_close_coro: sending exception to listeners")
# Send an exception to all remaining listeners
for query_id, listener in self.listeners.items():
await listener.set_exception(e)
log.debug("_close_coro: close websocket connection")
await self.websocket.close()
log.debug("_close_coro: websocket connection closed")
except Exception as exc: # pragma: no cover
log.warning("Exception catched in _close_coro: " + repr(exc))
finally:
log.debug("_close_coro: start cleanup")
self.websocket = None
self.close_task = None
self.check_keep_alive_task = None
self._wait_closed.set()
log.debug("_close_coro: exiting")
async def _fail(self, e: Exception, clean_close: bool = True) -> None:
log.debug("_fail: starting with exception: " + repr(e))
if self.close_task is None:
if self.websocket is None:
log.debug("_fail started with self.websocket == None -> already closed")
else:
self.close_task = asyncio.shield(
asyncio.ensure_future(self._close_coro(e, clean_close=clean_close))
)
else:
log.debug(
"close_task is not None in _fail. Previous exception is: "
+ repr(self.close_exception)
+ " New exception is: "
+ repr(e)
)
async def close(self) -> None:
log.debug("close: starting")
await self._fail(TransportClosed("Websocket GraphQL transport closed by user"))
await self.wait_closed()
log.debug("close: done")
async def wait_closed(self) -> None:
log.debug("wait_close: starting")
await self._wait_closed.wait()
log.debug("wait_close: done")
@@ -0,0 +1,19 @@
from .build_client_schema import build_client_schema
from .get_introspection_query_ast import get_introspection_query_ast
from .node_tree import node_tree
from .parse_result import parse_result
from .serialize_variable_values import serialize_value, serialize_variable_values
from .update_schema_enum import update_schema_enum
from .update_schema_scalars import update_schema_scalar, update_schema_scalars
__all__ = [
"build_client_schema",
"node_tree",
"parse_result",
"get_introspection_query_ast",
"serialize_variable_values",
"serialize_value",
"update_schema_enum",
"update_schema_scalars",
"update_schema_scalar",
]
@@ -0,0 +1,98 @@
from graphql import GraphQLSchema, IntrospectionQuery
from graphql import build_client_schema as build_client_schema_orig
from graphql.pyutils import inspect
from graphql.utilities.get_introspection_query import (
DirectiveLocation,
IntrospectionDirective,
)
__all__ = ["build_client_schema"]
INCLUDE_DIRECTIVE_JSON: IntrospectionDirective = {
"name": "include",
"description": (
"Directs the executor to include this field or fragment "
"only when the `if` argument is true."
),
"locations": [
DirectiveLocation.FIELD,
DirectiveLocation.FRAGMENT_SPREAD,
DirectiveLocation.INLINE_FRAGMENT,
],
"args": [
{
"name": "if",
"description": "Included when true.",
"type": {
"kind": "NON_NULL",
"name": "None",
"ofType": {"kind": "SCALAR", "name": "Boolean", "ofType": "None"},
},
"defaultValue": "None",
}
],
}
SKIP_DIRECTIVE_JSON: IntrospectionDirective = {
"name": "skip",
"description": (
"Directs the executor to skip this field or fragment "
"when the `if` argument is true."
),
"locations": [
DirectiveLocation.FIELD,
DirectiveLocation.FRAGMENT_SPREAD,
DirectiveLocation.INLINE_FRAGMENT,
],
"args": [
{
"name": "if",
"description": "Skipped when true.",
"type": {
"kind": "NON_NULL",
"name": "None",
"ofType": {"kind": "SCALAR", "name": "Boolean", "ofType": "None"},
},
"defaultValue": "None",
}
],
}
def build_client_schema(introspection: IntrospectionQuery) -> GraphQLSchema:
"""This is an alternative to the graphql-core function
:code:`build_client_schema` but with default include and skip directives
added to the schema to fix
`issue #278 <https://github.com/graphql-python/gql/issues/278>`_
.. warning::
This function will be removed once the issue
`graphql-js#3419 <https://github.com/graphql/graphql-js/issues/3419>`_
has been fixed and ported to graphql-core so don't use it
outside gql.
"""
if not isinstance(introspection, dict) or not isinstance(
introspection.get("__schema"), dict
):
raise TypeError(
"Invalid or incomplete introspection result. Ensure that you"
" are passing the 'data' attribute of an introspection response"
f" and no 'errors' were returned alongside: {inspect(introspection)}."
)
schema_introspection = introspection["__schema"]
directives = schema_introspection.get("directives", None)
if directives is None:
schema_introspection["directives"] = directives = []
if not any(directive["name"] == "skip" for directive in directives):
directives.append(SKIP_DIRECTIVE_JSON)
if not any(directive["name"] == "include" for directive in directives):
directives.append(INCLUDE_DIRECTIVE_JSON)
return build_client_schema_orig(introspection, assume_valid=False)
@@ -0,0 +1,142 @@
from itertools import repeat
from graphql import DocumentNode, GraphQLSchema
from gql.dsl import DSLFragment, DSLMetaField, DSLQuery, DSLSchema, dsl_gql
def get_introspection_query_ast(
descriptions: bool = True,
specified_by_url: bool = False,
directive_is_repeatable: bool = False,
schema_description: bool = False,
input_value_deprecation: bool = False,
type_recursion_level: int = 7,
) -> DocumentNode:
"""Get a query for introspection as a document using the DSL module.
Equivalent to the get_introspection_query function from graphql-core
but using the DSL module and allowing to select the recursion level.
Optionally, you can exclude descriptions, include specification URLs,
include repeatability of directives, and specify whether to include
the schema description as well.
"""
ds = DSLSchema(GraphQLSchema())
fragment_FullType = DSLFragment("FullType").on(ds.__Type)
fragment_InputValue = DSLFragment("InputValue").on(ds.__InputValue)
fragment_TypeRef = DSLFragment("TypeRef").on(ds.__Type)
schema = DSLMetaField("__schema")
if descriptions and schema_description:
schema.select(ds.__Schema.description)
schema.select(
ds.__Schema.queryType.select(ds.__Type.name),
ds.__Schema.mutationType.select(ds.__Type.name),
ds.__Schema.subscriptionType.select(ds.__Type.name),
)
schema.select(ds.__Schema.types.select(fragment_FullType))
directives = ds.__Schema.directives.select(ds.__Directive.name)
deprecated_expand = {}
if input_value_deprecation:
deprecated_expand = {
"includeDeprecated": True,
}
if descriptions:
directives.select(ds.__Directive.description)
if directive_is_repeatable:
directives.select(ds.__Directive.isRepeatable)
directives.select(
ds.__Directive.locations,
ds.__Directive.args(**deprecated_expand).select(fragment_InputValue),
)
schema.select(directives)
fragment_FullType.select(
ds.__Type.kind,
ds.__Type.name,
)
if descriptions:
fragment_FullType.select(ds.__Type.description)
if specified_by_url:
fragment_FullType.select(ds.__Type.specifiedByURL)
fields = ds.__Type.fields(includeDeprecated=True).select(ds.__Field.name)
if descriptions:
fields.select(ds.__Field.description)
fields.select(
ds.__Field.args(**deprecated_expand).select(fragment_InputValue),
ds.__Field.type.select(fragment_TypeRef),
ds.__Field.isDeprecated,
ds.__Field.deprecationReason,
)
enum_values = ds.__Type.enumValues(includeDeprecated=True).select(
ds.__EnumValue.name
)
if descriptions:
enum_values.select(ds.__EnumValue.description)
enum_values.select(
ds.__EnumValue.isDeprecated,
ds.__EnumValue.deprecationReason,
)
fragment_FullType.select(
fields,
ds.__Type.inputFields(**deprecated_expand).select(fragment_InputValue),
ds.__Type.interfaces.select(fragment_TypeRef),
enum_values,
ds.__Type.possibleTypes.select(fragment_TypeRef),
)
fragment_InputValue.select(ds.__InputValue.name)
if descriptions:
fragment_InputValue.select(ds.__InputValue.description)
fragment_InputValue.select(
ds.__InputValue.type.select(fragment_TypeRef),
ds.__InputValue.defaultValue,
)
if input_value_deprecation:
fragment_InputValue.select(
ds.__InputValue.isDeprecated,
ds.__InputValue.deprecationReason,
)
fragment_TypeRef.select(
ds.__Type.kind,
ds.__Type.name,
)
if type_recursion_level >= 1:
current_field = ds.__Type.ofType.select(ds.__Type.kind, ds.__Type.name)
fragment_TypeRef.select(current_field)
for _ in repeat(None, type_recursion_level - 1):
new_oftype = ds.__Type.ofType.select(ds.__Type.kind, ds.__Type.name)
current_field.select(new_oftype)
current_field = new_oftype
query = DSLQuery(schema)
query.name = "IntrospectionQuery"
dsl_query = dsl_gql(query, fragment_FullType, fragment_InputValue, fragment_TypeRef)
return dsl_query
@@ -0,0 +1,92 @@
from typing import Any, Iterable, List, Optional, Sized
from graphql import Node
def _node_tree_recursive(
obj: Any,
*,
indent: int = 0,
ignored_keys: List,
):
assert ignored_keys is not None
results = []
if hasattr(obj, "__slots__"):
results.append(" " * indent + f"{type(obj).__name__}")
try:
keys = sorted(obj.keys)
except AttributeError:
# If the object has no keys attribute, print its repr and return.
results.append(" " * (indent + 1) + repr(obj))
else:
for key in keys:
if key in ignored_keys:
continue
attr_value = getattr(obj, key, None)
results.append(" " * (indent + 1) + f"{key}:")
if isinstance(attr_value, Iterable) and not isinstance(
attr_value, (str, bytes)
):
if isinstance(attr_value, Sized) and len(attr_value) == 0:
results.append(
" " * (indent + 2) + f"empty {type(attr_value).__name__}"
)
else:
for item in attr_value:
results.append(
_node_tree_recursive(
item,
indent=indent + 2,
ignored_keys=ignored_keys,
)
)
else:
results.append(
_node_tree_recursive(
attr_value,
indent=indent + 2,
ignored_keys=ignored_keys,
)
)
else:
results.append(" " * indent + repr(obj))
return "\n".join(results)
def node_tree(
obj: Node,
*,
ignore_loc: bool = True,
ignore_block: bool = True,
ignored_keys: Optional[List] = None,
):
"""Method which returns a tree of Node elements as a String.
Useful to debug deep DocumentNode instances created by gql or dsl_gql.
NOTE: from gql version 3.6.0b4 the elements of each node are sorted to ignore
small changes in graphql-core
WARNING: the output of this method is not guaranteed and may change without notice.
"""
assert isinstance(obj, Node)
if ignored_keys is None:
ignored_keys = []
if ignore_loc:
# We are ignoring loc attributes by default
ignored_keys.append("loc")
if ignore_block:
# We are ignoring block attributes by default (in StringValueNode)
ignored_keys.append("block")
return _node_tree_recursive(obj, ignored_keys=ignored_keys)
@@ -0,0 +1,446 @@
import logging
from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, Union, cast
from graphql import (
IDLE,
REMOVE,
DocumentNode,
FieldNode,
FragmentDefinitionNode,
FragmentSpreadNode,
GraphQLError,
GraphQLInterfaceType,
GraphQLList,
GraphQLNonNull,
GraphQLObjectType,
GraphQLSchema,
GraphQLType,
InlineFragmentNode,
NameNode,
Node,
OperationDefinitionNode,
SelectionSetNode,
TypeInfo,
TypeInfoVisitor,
Visitor,
is_leaf_type,
print_ast,
visit,
)
from graphql.language.visitor import VisitorActionEnum
from graphql.pyutils import inspect
log = logging.getLogger(__name__)
# Equivalent to QUERY_DOCUMENT_KEYS but only for fields interesting to
# visit to parse the results
RESULT_DOCUMENT_KEYS: Dict[str, Tuple[str, ...]] = {
"document": ("definitions",),
"operation_definition": ("selection_set",),
"selection_set": ("selections",),
"field": ("selection_set",),
"inline_fragment": ("selection_set",),
"fragment_definition": ("selection_set",),
}
def _ignore_non_null(type_: GraphQLType):
"""Removes the GraphQLNonNull wrappings around types."""
if isinstance(type_, GraphQLNonNull):
return type_.of_type
else:
return type_
def _get_fragment(document, fragment_name):
"""Returns a fragment from the document."""
for definition in document.definitions:
if isinstance(definition, FragmentDefinitionNode):
if definition.name.value == fragment_name:
return definition
raise GraphQLError(f'Fragment "{fragment_name}" not found in document!')
class ParseResultVisitor(Visitor):
def __init__(
self,
schema: GraphQLSchema,
document: DocumentNode,
node: Node,
result: Dict[str, Any],
type_info: TypeInfo,
visit_fragment: bool = False,
inside_list_level: int = 0,
operation_name: Optional[str] = None,
):
"""Recursive Implementation of a Visitor class to parse results
correspondind to a schema and a document.
Using a TypeInfo class to get the node types during traversal.
If we reach a list in the results, then we parse each
item of the list recursively, traversing the same nodes
of the query again.
During traversal, we keep the current position in the result
in the result_stack field.
Alongside the field type, we calculate the "result type"
which is computed from the field type and the current
recursive level we are for this field
(:code:`inside_list_level` argument).
"""
self.schema: GraphQLSchema = schema
self.document: DocumentNode = document
self.node: Node = node
self.result: Dict[str, Any] = result
self.type_info: TypeInfo = type_info
self.visit_fragment: bool = visit_fragment
self.inside_list_level = inside_list_level
self.operation_name = operation_name
self.result_stack: List[Any] = []
super().__init__()
@property
def current_result(self):
try:
return self.result_stack[-1]
except IndexError:
return self.result
@staticmethod
def leave_document(node: DocumentNode, *_args: Any) -> Dict[str, Any]:
results = cast(List[Dict[str, Any]], node.definitions)
return {k: v for result in results for k, v in result.items()}
def enter_operation_definition(
self, node: OperationDefinitionNode, *_args: Any
) -> Union[None, VisitorActionEnum]:
if self.operation_name is not None:
if not hasattr(node.name, "value"):
return REMOVE # pragma: no cover
node.name = cast(NameNode, node.name)
if node.name.value != self.operation_name:
log.debug(f"SKIPPING operation {node.name.value}")
return REMOVE
return IDLE
@staticmethod
def leave_operation_definition(
node: OperationDefinitionNode, *_args: Any
) -> Dict[str, Any]:
selections = cast(List[Dict[str, Any]], node.selection_set)
return {k: v for s in selections for k, v in s.items()}
@staticmethod
def leave_selection_set(node: SelectionSetNode, *_args: Any) -> Dict[str, Any]:
partial_results = cast(Dict[str, Any], node.selections)
return partial_results
@staticmethod
def in_first_field(path):
return path.count("selections") <= 1
def get_current_result_type(self, path):
field_type = self.type_info.get_type()
list_level = self.inside_list_level
result_type = _ignore_non_null(field_type)
if self.in_first_field(path):
while list_level > 0:
assert isinstance(result_type, GraphQLList)
result_type = _ignore_non_null(result_type.of_type)
list_level -= 1
return result_type
def enter_field(
self,
node: FieldNode,
key: str,
parent: Node,
path: List[Node],
ancestors: List[Node],
) -> Union[None, VisitorActionEnum, Dict[str, Any]]:
name = node.alias.value if node.alias else node.name.value
if log.isEnabledFor(logging.DEBUG):
log.debug(f"Enter field {name}")
log.debug(f" path={path!r}")
log.debug(f" current_result={self.current_result!r}")
if self.current_result is None:
# Result was null for this field -> remove
return REMOVE
elif isinstance(self.current_result, Mapping):
try:
result_value = self.current_result[name]
except KeyError:
# Key not found in result.
# Should never happen in theory with a correct GraphQL backend
# Silently ignoring this field
log.debug(f" Key {name} not found in result --> REMOVE")
return REMOVE
log.debug(f" result_value={result_value}")
# We get the field_type from type_info
field_type = self.type_info.get_type()
# We calculate a virtual "result type" depending on our recursion level.
result_type = self.get_current_result_type(path)
# If the result for this field is a list, then we need
# to recursively visit the same node multiple times for each
# item in the list.
if (
not isinstance(result_value, Mapping)
and isinstance(result_value, Iterable)
and not isinstance(result_value, str)
and not is_leaf_type(result_type)
):
# Finding out the inner type of the list
inner_type = _ignore_non_null(result_type.of_type)
if log.isEnabledFor(logging.DEBUG):
log.debug(" List detected:")
log.debug(f" field_type={inspect(field_type)}")
log.debug(f" result_type={inspect(result_type)}")
log.debug(f" inner_type={inspect(inner_type)}\n")
visits: List[Dict[str, Any]] = []
# Get parent type
initial_type = self.type_info.get_parent_type()
assert isinstance(
initial_type, (GraphQLObjectType, GraphQLInterfaceType)
)
# Get parent SelectionSet node
selection_set_node = ancestors[-1]
assert isinstance(selection_set_node, SelectionSetNode)
# Keep only the current node in a new selection set node
new_node = SelectionSetNode(selections=[node])
for item in result_value:
new_result = {name: item}
if log.isEnabledFor(logging.DEBUG):
log.debug(f" recursive new_result={new_result}")
log.debug(f" recursive ast={print_ast(node)}")
log.debug(f" recursive path={path!r}")
log.debug(f" recursive initial_type={initial_type!r}\n")
if self.in_first_field(path):
inside_list_level = self.inside_list_level + 1
else:
inside_list_level = 1
inner_visit = parse_result_recursive(
self.schema,
self.document,
new_node,
new_result,
initial_type=initial_type,
inside_list_level=inside_list_level,
)
log.debug(f" recursive result={inner_visit}\n")
inner_visit = cast(List[Dict[str, Any]], inner_visit)
visits.append(inner_visit[0][name])
result_value = {name: visits}
log.debug(f" recursive visits final result = {result_value}\n")
return result_value
# If the result for this field is not a list, then add it
# to the result stack so that it becomes the current_value
# for the next inner fields
self.result_stack.append(result_value)
return IDLE
raise GraphQLError(
f"Invalid result for container of field {name}: {self.current_result!r}"
)
def leave_field(
self,
node: FieldNode,
key: str,
parent: Node,
path: List[Node],
ancestors: List[Node],
) -> Dict[str, Any]:
name = cast(str, node.alias.value if node.alias else node.name.value)
log.debug(f"Leave field {name}")
if self.current_result is None:
return_value = None
elif node.selection_set is None:
field_type = self.type_info.get_type()
result_type = self.get_current_result_type(path)
if log.isEnabledFor(logging.DEBUG):
log.debug(f" field type of {name} is {inspect(field_type)}")
log.debug(f" result type of {name} is {inspect(result_type)}")
assert is_leaf_type(result_type)
# Finally parsing a single scalar using the parse_value method
return_value = result_type.parse_value(self.current_result)
else:
partial_results = cast(List[Dict[str, Any]], node.selection_set)
return_value = {k: v for pr in partial_results for k, v in pr.items()}
# Go up a level in the result stack
self.result_stack.pop()
log.debug(f"Leave field {name}: returning {return_value}")
return {name: return_value}
# Fragments
def enter_fragment_definition(
self, node: FragmentDefinitionNode, *_args: Any
) -> Union[None, VisitorActionEnum]:
if log.isEnabledFor(logging.DEBUG):
log.debug(f"Enter fragment definition {node.name.value}.")
log.debug(f"visit_fragment={self.visit_fragment!s}")
if self.visit_fragment:
return IDLE
else:
return REMOVE
@staticmethod
def leave_fragment_definition(
node: FragmentDefinitionNode, *_args: Any
) -> Dict[str, Any]:
selections = cast(List[Dict[str, Any]], node.selection_set)
return {k: v for s in selections for k, v in s.items()}
def leave_fragment_spread(
self, node: FragmentSpreadNode, *_args: Any
) -> Dict[str, Any]:
fragment_name = node.name.value
log.debug(f"Start recursive fragment visit {fragment_name}")
fragment_node = _get_fragment(self.document, fragment_name)
fragment_result = parse_result_recursive(
self.schema,
self.document,
fragment_node,
self.current_result,
visit_fragment=True,
)
log.debug(
f"Result of recursive fragment visit {fragment_name}: {fragment_result}"
)
return cast(Dict[str, Any], fragment_result)
@staticmethod
def leave_inline_fragment(node: InlineFragmentNode, *_args: Any) -> Dict[str, Any]:
selections = cast(List[Dict[str, Any]], node.selection_set)
return {k: v for s in selections for k, v in s.items()}
def parse_result_recursive(
schema: GraphQLSchema,
document: DocumentNode,
node: Node,
result: Optional[Dict[str, Any]],
initial_type: Optional[GraphQLType] = None,
inside_list_level: int = 0,
visit_fragment: bool = False,
operation_name: Optional[str] = None,
) -> Any:
if result is None:
return None
type_info = TypeInfo(schema, initial_type=initial_type)
visited = visit(
node,
TypeInfoVisitor(
type_info,
ParseResultVisitor(
schema,
document,
node,
result,
type_info=type_info,
inside_list_level=inside_list_level,
visit_fragment=visit_fragment,
operation_name=operation_name,
),
),
visitor_keys=RESULT_DOCUMENT_KEYS,
)
return visited
def parse_result(
schema: GraphQLSchema,
document: DocumentNode,
result: Optional[Dict[str, Any]],
operation_name: Optional[str] = None,
) -> Optional[Dict[str, Any]]:
"""Unserialize a result received from a GraphQL backend.
:param schema: the GraphQL schema
:param document: the document representing the query sent to the backend
:param result: the serialized result received from the backend
:param operation_name: the optional operation name
:returns: a parsed result with scalars and enums parsed depending on
their definition in the schema.
Given a schema, a query and a serialized result,
provide a new result with parsed values.
If the result contains only built-in GraphQL scalars (String, Int, Float, ...)
then the parsed result should be unchanged.
If the result contains custom scalars or enums, then those values
will be parsed with the parse_value method of the custom scalar or enum
definition in the schema."""
return parse_result_recursive(
schema, document, document, result, operation_name=operation_name
)
@@ -0,0 +1,130 @@
from typing import Any, Dict, Optional
from graphql import (
DocumentNode,
GraphQLEnumType,
GraphQLError,
GraphQLInputObjectType,
GraphQLList,
GraphQLNonNull,
GraphQLScalarType,
GraphQLSchema,
GraphQLType,
GraphQLWrappingType,
OperationDefinitionNode,
type_from_ast,
)
from graphql.pyutils import inspect
def _get_document_operation(
document: DocumentNode, operation_name: Optional[str] = None
) -> OperationDefinitionNode:
"""Returns the operation which should be executed in the document.
Raises a GraphQLError if a single operation cannot be retrieved.
"""
operation: Optional[OperationDefinitionNode] = None
for definition in document.definitions:
if isinstance(definition, OperationDefinitionNode):
if operation_name is None:
if operation:
raise GraphQLError(
"Must provide operation name"
" if query contains multiple operations."
)
operation = definition
elif definition.name and definition.name.value == operation_name:
operation = definition
if not operation:
if operation_name is not None:
raise GraphQLError(f"Unknown operation named '{operation_name}'.")
# The following line should never happen normally as the document is
# already verified before calling this function.
raise GraphQLError("Must provide an operation.") # pragma: no cover
return operation
def serialize_value(type_: GraphQLType, value: Any) -> Any:
"""Given a GraphQL type and a Python value, return the serialized value.
This method will serialize the value recursively, entering into
lists and dicts.
Can be used to serialize Enums and/or Custom Scalars in variable values.
:param type_: the GraphQL type
:param value: the provided value
"""
if value is None:
if isinstance(type_, GraphQLNonNull):
# raise GraphQLError(f"Type {type_.of_type.name} Cannot be None.")
raise GraphQLError(f"Type {inspect(type_)} Cannot be None.")
else:
return None
if isinstance(type_, GraphQLWrappingType):
inner_type = type_.of_type
if isinstance(type_, GraphQLNonNull):
return serialize_value(inner_type, value)
elif isinstance(type_, GraphQLList):
return [serialize_value(inner_type, v) for v in value]
elif isinstance(type_, (GraphQLScalarType, GraphQLEnumType)):
return type_.serialize(value)
elif isinstance(type_, GraphQLInputObjectType):
return {
field_name: serialize_value(field.type, value[field_name])
for field_name, field in type_.fields.items()
if field_name in value
}
raise GraphQLError(f"Impossible to serialize value with type: {inspect(type_)}.")
def serialize_variable_values(
schema: GraphQLSchema,
document: DocumentNode,
variable_values: Dict[str, Any],
operation_name: Optional[str] = None,
) -> Dict[str, Any]:
"""Given a GraphQL document and a schema, serialize the Dictionary of
variable values.
Useful to serialize Enums and/or Custom Scalars in variable values.
:param schema: the GraphQL schema
:param document: the document representing the query sent to the backend
:param variable_values: the dictionnary of variable values which needs
to be serialized.
:param operation_name: the optional operation_name for the query.
"""
parsed_variable_values: Dict[str, Any] = {}
# Find the operation in the document
operation = _get_document_operation(document, operation_name=operation_name)
# Serialize every variable value defined for the operation
for var_def_node in operation.variable_definitions:
var_name = var_def_node.variable.name.value
var_type = type_from_ast(schema, var_def_node.type)
if var_name in variable_values:
assert var_type is not None
var_value = variable_values[var_name]
parsed_variable_values[var_name] = serialize_value(var_type, var_value)
return parsed_variable_values
@@ -0,0 +1,69 @@
from enum import Enum
from typing import Any, Dict, Mapping, Type, Union, cast
from graphql import GraphQLEnumType, GraphQLSchema
def update_schema_enum(
schema: GraphQLSchema,
name: str,
values: Union[Dict[str, Any], Type[Enum]],
use_enum_values: bool = False,
):
"""Update in the schema the GraphQLEnumType corresponding to the given name.
Example::
from enum import Enum
class Color(Enum):
RED = 0
GREEN = 1
BLUE = 2
update_schema_enum(schema, 'Color', Color)
:param schema: a GraphQL Schema already containing the GraphQLEnumType type.
:param name: the name of the enum in the GraphQL schema
:param values: Either a Python Enum or a dict of values. The keys of the provided
values should correspond to the keys of the existing enum in the schema.
:param use_enum_values: By default, we configure the GraphQLEnumType to serialize
to enum instances (ie: .parse_value() returns Color.RED).
If use_enum_values is set to True, then .parse_value() returns 0.
use_enum_values=True is the defaut behaviour when passing an Enum
to a GraphQLEnumType.
"""
# Convert Enum values to Dict
if isinstance(values, type):
if issubclass(values, Enum):
values = cast(Type[Enum], values)
if use_enum_values:
values = {enum.name: enum.value for enum in values}
else:
values = {enum.name: enum for enum in values}
if not isinstance(values, Mapping):
raise TypeError(f"Invalid type for enum values: {type(values)}")
# Find enum type in schema
schema_enum = schema.get_type(name)
if schema_enum is None:
raise KeyError(f"Enum {name} not found in schema!")
if not isinstance(schema_enum, GraphQLEnumType):
raise TypeError(
f'The type "{name}" is not a GraphQLEnumType, it is a {type(schema_enum)}'
)
# Replace all enum values
for enum_name, enum_value in schema_enum.values.items():
try:
enum_value.value = values[enum_name]
except KeyError:
raise KeyError(f'Enum key "{enum_name}" not found in provided values!')
# Delete the _value_lookup cached property
if "_value_lookup" in schema_enum.__dict__:
del schema_enum.__dict__["_value_lookup"]
@@ -0,0 +1,60 @@
from typing import Iterable, List
from graphql import GraphQLScalarType, GraphQLSchema
def update_schema_scalar(schema: GraphQLSchema, name: str, scalar: GraphQLScalarType):
"""Update the scalar in a schema with the scalar provided.
:param schema: the GraphQL schema
:param name: the name of the custom scalar type in the schema
:param scalar: a provided scalar type
This can be used to update the default Custom Scalar implementation
when the schema has been provided from a text file or from introspection.
"""
if not isinstance(scalar, GraphQLScalarType):
raise TypeError("Scalars should be instances of GraphQLScalarType.")
schema_scalar = schema.get_type(name)
if schema_scalar is None:
raise KeyError(f"Scalar '{name}' not found in schema.")
if not isinstance(schema_scalar, GraphQLScalarType):
raise TypeError(
f'The type "{name}" is not a GraphQLScalarType,'
f" it is a {type(schema_scalar)}"
)
# Update the conversion methods
# Using setattr because mypy has a false positive
# https://github.com/python/mypy/issues/2427
setattr(schema_scalar, "serialize", scalar.serialize)
setattr(schema_scalar, "parse_value", scalar.parse_value)
setattr(schema_scalar, "parse_literal", scalar.parse_literal)
def update_schema_scalars(schema: GraphQLSchema, scalars: List[GraphQLScalarType]):
"""Update the scalars in a schema with the scalars provided.
:param schema: the GraphQL schema
:param scalars: a list of provided scalar types
This can be used to update the default Custom Scalar implementation
when the schema has been provided from a text file or from introspection.
If the name of the provided scalar is different than the name of
the custom scalar, then you should use the
:func:`update_schema_scalar <gql.utilities.update_schema_scalar>` method instead.
"""
if not isinstance(scalars, Iterable):
raise TypeError("Scalars argument should be a list of scalars.")
for scalar in scalars:
if not isinstance(scalar, GraphQLScalarType):
raise TypeError("Scalars should be instances of GraphQLScalarType.")
update_schema_scalar(schema, scalar.name, scalar)
@@ -0,0 +1,58 @@
"""Utilities to manipulate several python objects."""
from typing import Any, Dict, List, Tuple, Type
# From this response in Stackoverflow
# http://stackoverflow.com/a/19053800/1072990
def to_camel_case(snake_str):
components = snake_str.split("_")
# We capitalize the first letter of each component except the first one
# with the 'title' method and join them together.
return components[0] + "".join(x.title() if x else "_" for x in components[1:])
def extract_files(
variables: Dict, file_classes: Tuple[Type[Any], ...]
) -> Tuple[Dict, Dict]:
files = {}
def recurse_extract(path, obj):
"""
recursively traverse obj, doing a deepcopy, but
replacing any file-like objects with nulls and
shunting the originals off to the side.
"""
nonlocal files
if isinstance(obj, list):
nulled_obj = []
for key, value in enumerate(obj):
value = recurse_extract(f"{path}.{key}", value)
nulled_obj.append(value)
return nulled_obj
elif isinstance(obj, dict):
nulled_obj = {}
for key, value in obj.items():
value = recurse_extract(f"{path}.{key}", value)
nulled_obj[key] = value
return nulled_obj
elif isinstance(obj, file_classes):
# extract obj from its parent and put it into files instead.
files[path] = obj
return None
else:
# base case: pass through unchanged
return obj
nulled_variables = recurse_extract("variables", variables)
return nulled_variables, files
def str_first_element(errors: List) -> str:
try:
first_error = errors[0]
except (KeyError, TypeError):
first_error = errors
return str(first_error)