2025-12-01
This commit is contained in:
@@ -0,0 +1,20 @@
|
||||
"""The primary :mod:`gql` package includes everything you need to
|
||||
execute GraphQL requests, with the exception of the transports
|
||||
which are optional:
|
||||
|
||||
- the :func:`gql <gql.gql>` method to parse a GraphQL query
|
||||
- the :class:`Client <gql.Client>` class as the entrypoint to execute requests
|
||||
and create sessions
|
||||
"""
|
||||
|
||||
from .__version__ import __version__
|
||||
from .client import Client
|
||||
from .gql import gql
|
||||
from .graphql_request import GraphQLRequest
|
||||
|
||||
__all__ = [
|
||||
"__version__",
|
||||
"gql",
|
||||
"Client",
|
||||
"GraphQLRequest",
|
||||
]
|
||||
@@ -0,0 +1 @@
|
||||
__version__ = "3.5.3"
|
||||
@@ -0,0 +1,541 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import signal as signal_module
|
||||
import sys
|
||||
import textwrap
|
||||
from argparse import ArgumentParser, Namespace, RawTextHelpFormatter
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from graphql import GraphQLError, print_schema
|
||||
from yarl import URL
|
||||
|
||||
from gql import Client, __version__, gql
|
||||
from gql.transport import AsyncTransport
|
||||
from gql.transport.exceptions import TransportQueryError
|
||||
|
||||
description = """
|
||||
Send GraphQL queries from the command line using http(s) or websockets.
|
||||
If used interactively, write your query, then use Ctrl-D (EOF) to execute it.
|
||||
"""
|
||||
|
||||
examples = """
|
||||
EXAMPLES
|
||||
========
|
||||
|
||||
# Simple query using https
|
||||
echo 'query { continent(code:"AF") { name } }' | \
|
||||
gql-cli https://countries.trevorblades.com
|
||||
|
||||
# Simple query using websockets
|
||||
echo 'query { continent(code:"AF") { name } }' | \
|
||||
gql-cli wss://countries.trevorblades.com/graphql
|
||||
|
||||
# Query with variable
|
||||
echo 'query getContinent($code:ID!) { continent(code:$code) { name } }' | \
|
||||
gql-cli https://countries.trevorblades.com --variables code:AF
|
||||
|
||||
# Interactive usage (insert your query in the terminal, then press Ctrl-D to execute it)
|
||||
gql-cli wss://countries.trevorblades.com/graphql --variables code:AF
|
||||
|
||||
# Execute query saved in a file
|
||||
cat query.gql | gql-cli wss://countries.trevorblades.com/graphql
|
||||
|
||||
# Print the schema of the backend
|
||||
gql-cli https://countries.trevorblades.com/graphql --print-schema
|
||||
|
||||
"""
|
||||
|
||||
|
||||
def positive_int_or_none(value_str: str) -> Optional[int]:
    """Parse *value_str* as a non-negative int, or return None for "none".

    :param value_str: the raw command line argument value
    :return: the parsed non-negative integer, or None if the string
        "none" (case insensitive) was given
    :raises ValueError: if the value is negative or is any other
        non-numeric string
    """
    # The literal string "none" (any capitalization) maps to None
    if value_str.lower() == "none":
        return None

    # Propagates ValueError for any non-integer input
    parsed = int(value_str)

    if parsed < 0:
        raise ValueError

    return parsed
|
||||
|
||||
|
||||
def get_parser(with_examples: bool = False) -> ArgumentParser:
    """Provides an ArgumentParser for the gql-cli script.

    This function is also used by sphinx to generate the script documentation.

    :param with_examples: set to False by default so that the examples are not
                          present in the sphinx docs (they are put there with
                          a different layout)
    :return: the configured ArgumentParser
    """

    parser = ArgumentParser(
        description=description,
        epilog=examples if with_examples else None,
        formatter_class=RawTextHelpFormatter,
    )
    parser.add_argument(
        "server", help="the server url starting with http://, https://, ws:// or wss://"
    )
    parser.add_argument(
        "-V",
        "--variables",
        nargs="*",
        help="query variables in the form key:json_value",
    )
    parser.add_argument(
        "-H", "--headers", nargs="*", help="http headers in the form key:value"
    )
    parser.add_argument("--version", action="version", version=f"v{__version__}")

    # --debug and --verbose both write to the same "loglevel" destination,
    # so only one of them may be given
    group = parser.add_mutually_exclusive_group()
    group.add_argument(
        "-d",
        "--debug",
        help="print lots of debugging statements (loglevel==DEBUG)",
        action="store_const",
        dest="loglevel",
        const=logging.DEBUG,
    )
    group.add_argument(
        "-v",
        "--verbose",
        help="show low level messages (loglevel==INFO)",
        action="store_const",
        dest="loglevel",
        const=logging.INFO,
    )
    parser.add_argument(
        "-o",
        "--operation-name",
        help="set the operation_name value",
        dest="operation_name",
    )
    parser.add_argument(
        "--print-schema",
        # Fixed typo: "instrospection" -> "introspection"
        help="get the schema from introspection and print it",
        action="store_true",
        dest="print_schema",
    )
    parser.add_argument(
        "--schema-download",
        nargs="*",
        help=textwrap.dedent(
            """select the introspection query arguments to download the schema.
            Only useful if --print-schema is used.
            By default, it will:

            - request field descriptions
            - not request deprecated input fields

            Possible options:

            - descriptions:false for a compact schema without comments
            - input_value_deprecation:true to download deprecated input fields
            - specified_by_url:true
            - schema_description:true
            - directive_is_repeatable:true"""
        ),
        dest="schema_download",
    )
    parser.add_argument(
        "--execute-timeout",
        help="set the execute_timeout argument of the Client (default: 10)",
        type=positive_int_or_none,
        default=10,
        dest="execute_timeout",
    )
    parser.add_argument(
        "--transport",
        default="auto",
        choices=[
            "auto",
            "aiohttp",
            "phoenix",
            "websockets",
            "appsync_http",
            "appsync_websockets",
        ],
        help=(
            "select the transport. 'auto' by default: "
            "aiohttp or websockets depending on url scheme"
        ),
        dest="transport",
    )

    appsync_description = """
By default, for an AppSync backend, the IAM authentication is chosen.

If you want API key or JWT authentication, you can provide one of the
following arguments:"""

    appsync_group = parser.add_argument_group(
        "AWS AppSync options", description=appsync_description
    )

    # API key and JWT auth are alternatives: only one may be selected
    appsync_auth_group = appsync_group.add_mutually_exclusive_group()

    appsync_auth_group.add_argument(
        "--api-key",
        help="Provide an API key for authentication",
        dest="api_key",
    )

    appsync_auth_group.add_argument(
        "--jwt",
        # Fixed grammar: "an JSON" -> "a JSON"
        help="Provide a JSON Web token for authentication",
        dest="jwt",
    )

    return parser
|
||||
|
||||
|
||||
def get_transport_args(args: Namespace) -> Dict[str, Any]:
    """Extract extra arguments necessary for the transport
    from the parsed command line args

    Will create a headers dict by splitting the colon
    in the --headers arguments

    :param args: parsed command line arguments
    :return: dict of extra transport keyword arguments
    :raises ValueError: if a --headers entry does not contain a colon
    """

    transport_args: Dict[str, Any] = {}

    # Parse the headers argument (merged the duplicated
    # "args.headers is not None" checks of the previous version)
    if args.headers is not None:

        headers = {}

        for header in args.headers:

            try:
                # Split only the first colon (throw a ValueError if no colon
                # is present), so header values may themselves contain colons
                header_key, header_value = header.split(":", 1)

            except ValueError:
                raise ValueError(f"Invalid header: {header}")

            headers[header_key] = header_value

        transport_args["headers"] = headers

    return transport_args
|
||||
|
||||
|
||||
def get_execute_args(args: Namespace) -> Dict[str, Any]:
    """Build the extra keyword arguments for the execute or subscribe calls
    from the parsed command line args.

    Picks up the operation name from --operation-name, and decodes every
    --variables entry of the form ``key:json_value``: the part after the
    first colon is parsed as JSON and, if that fails, a second attempt is
    made with double quotes wrapped around it so that simple strings can
    be passed as ``KEY:VALUE`` instead of ``KEY:\"VALUE\"``.

    :param args: parsed command line arguments
    :raises ValueError: if a --variables entry has no colon or an
        undecodable value
    """

    execute_args: Dict[str, Any] = {}

    # Forward the operation name if one was requested
    if args.operation_name is not None:
        execute_args["operation_name"] = args.operation_name

    if args.variables is not None:

        variables: Dict[str, Any] = {}

        for var in args.variables:

            # Split on the first colon only; values may contain colons
            key, sep, raw_value = var.partition(":")

            if not sep:
                # No colon at all in this entry
                raise ValueError(f"Invalid variable: {var}")

            try:
                decoded = json.loads(raw_value)
            except json.JSONDecodeError:
                # Retry with added quotes to accept bare string values
                try:
                    decoded = json.loads(f'"{raw_value}"')
                except json.JSONDecodeError:
                    raise ValueError(f"Invalid variable: {var}")

            variables[key] = decoded

        execute_args["variable_values"] = variables

    return execute_args
|
||||
|
||||
|
||||
def autodetect_transport(url: URL) -> str:
    """Detects which transport should be used depending on url.

    :param url: the parsed server url
    :return: "websockets" for ws/wss schemes, "aiohttp" for http/https
    """

    if url.scheme in ["http", "https"]:
        return "aiohttp"

    # Only the four schemes above are ever passed here (validated upstream)
    assert url.scheme in ["ws", "wss"]
    return "websockets"
|
||||
|
||||
|
||||
def get_transport(args: Namespace) -> Optional[AsyncTransport]:
    """Instantiate a transport from the parsed command line arguments

    :param args: parsed command line arguments
    :return: the configured transport, or None if an AppSync transport
        could not be created (a warning was already printed in that case)
    :raises ValueError: if the server url scheme is not supported
    """

    # Get the url scheme from server parameter
    url = URL(args.server)

    # Validate scheme
    if url.scheme not in ["http", "https", "ws", "wss"]:
        raise ValueError("URL protocol should be one of: http, https, ws, wss")

    # Get extra transport parameters from command line arguments
    # (headers)
    transport_args = get_transport_args(args)

    # Either use the requested transport or autodetect it
    if args.transport == "auto":
        transport_name = autodetect_transport(url)
    else:
        transport_name = args.transport

    # Import the correct transport class depending on the transport name.
    # Imports are done locally so optional transport dependencies are only
    # loaded when the corresponding transport is actually selected.
    if transport_name == "aiohttp":
        from gql.transport.aiohttp import AIOHTTPTransport

        return AIOHTTPTransport(url=args.server, **transport_args)

    elif transport_name == "phoenix":
        from gql.transport.phoenix_channel_websockets import (
            PhoenixChannelWebsocketsTransport,
        )

        return PhoenixChannelWebsocketsTransport(url=args.server, **transport_args)

    elif transport_name == "websockets":
        from gql.transport.websockets import WebsocketsTransport

        # Enable ssl only for the secure websockets scheme
        transport_args["ssl"] = url.scheme == "wss"

        return WebsocketsTransport(url=args.server, **transport_args)

    else:

        from gql.transport.appsync_auth import AppSyncAuthentication

        assert transport_name in ["appsync_http", "appsync_websockets"]
        assert url.host is not None

        auth: AppSyncAuthentication

        # Select the AppSync authentication method:
        # --api-key takes precedence, then --jwt, IAM being the default
        if args.api_key:
            from gql.transport.appsync_auth import AppSyncApiKeyAuthentication

            auth = AppSyncApiKeyAuthentication(host=url.host, api_key=args.api_key)

        elif args.jwt:
            from gql.transport.appsync_auth import AppSyncJWTAuthentication

            auth = AppSyncJWTAuthentication(host=url.host, jwt=args.jwt)

        else:
            from gql.transport.appsync_auth import AppSyncIAMAuthentication
            from botocore.exceptions import NoRegionError

            try:
                auth = AppSyncIAMAuthentication(host=url.host)
            except NoRegionError:
                # A warning message has been printed in the console
                return None

        transport_args["auth"] = auth

        if transport_name == "appsync_http":
            from gql.transport.aiohttp import AIOHTTPTransport

            return AIOHTTPTransport(url=args.server, **transport_args)

        else:
            from gql.transport.appsync_websockets import AppSyncWebsocketsTransport

            try:
                return AppSyncWebsocketsTransport(url=args.server, **transport_args)
            except Exception:
                # This is for the NoCredentialsError but we cannot import it here
                return None
|
||||
|
||||
|
||||
def get_introspection_args(args: Namespace) -> Dict:
    """Get the introspection args depending on the schema_download argument

    Each --schema-download entry must have the form ``key:true`` or
    ``key:false`` (value case insensitive), where key is one of the
    supported introspection query arguments.

    :param args: parsed command line arguments
    :return: dict mapping introspection argument names to booleans
    :raises ValueError: if an entry is malformed, uses an unknown key,
        or has a value other than true/false
    """

    introspection_args: Dict[str, bool] = {}

    possible_args = [
        "descriptions",
        "specified_by_url",
        "directive_is_repeatable",
        "schema_description",
        "input_value_deprecation",
    ]

    if args.schema_download is not None:
        for arg in args.schema_download:

            # Only the split is inside the try block now; the previous
            # version also raised its own ValueErrors inside the try,
            # which were then caught and re-raised by its own except.
            try:
                # Split only the first colon
                # (throws a ValueError if no colon is present)
                arg_key, arg_value = arg.split(":", 1)
            except ValueError:
                raise ValueError(f"Invalid schema_download: {args.schema_download}")

            if arg_key not in possible_args:
                raise ValueError(f"Invalid schema_download: {args.schema_download}")

            arg_value = arg_value.lower()
            if arg_value not in ["true", "false"]:
                raise ValueError(f"Invalid schema_download: {args.schema_download}")

            introspection_args[arg_key] = arg_value == "true"

    return introspection_args
|
||||
|
||||
|
||||
async def main(args: Namespace) -> int:
    """Main entrypoint of the gql-cli script

    :param args: The parsed command line arguments
    :return: The script exit code (0 = ok, 1 = error)
    """

    # Set requested log level
    if args.loglevel is not None:
        logging.basicConfig(level=args.loglevel)

    try:
        # Instantiate transport from command line arguments
        transport = get_transport(args)

        if transport is None:
            # get_transport already reported the problem (AppSync auth)
            return 1

        # Get extra execute parameters from command line arguments
        # (variables, operation_name)
        execute_args = get_execute_args(args)

    except ValueError as e:
        print(f"Error: {e}", file=sys.stderr)
        return 1

    # By default, the exit_code is 0 (everything is ok)
    exit_code = 0

    # Connect to the backend and provide a session
    async with Client(
        transport=transport,
        fetch_schema_from_transport=args.print_schema,
        introspection_args=get_introspection_args(args),
        execute_timeout=args.execute_timeout,
    ) as session:

        if args.print_schema:
            # With --print-schema we only dump the schema that was
            # fetched during connection, then exit
            schema_str = print_schema(session.client.schema)
            print(schema_str)

            return exit_code

        while True:

            # Read multiple lines from input and trim whitespaces
            # Will read until EOF character is received (Ctrl-D)
            query_str = sys.stdin.read().strip()

            # Exit if query is empty
            if len(query_str) == 0:
                break

            # Parse query, continue on error
            try:
                query = gql(query_str)
            except GraphQLError as e:
                print(e, file=sys.stderr)
                exit_code = 1
                continue

            # Execute or Subscribe the query depending on transport:
            # try a subscription first, printing one JSON line per result;
            # HTTP transports raise NotImplementedError and we fall back
            # to a single execute call
            try:
                try:
                    async for result in session.subscribe(query, **execute_args):
                        print(json.dumps(result))
                except KeyboardInterrupt:  # pragma: no cover
                    pass
                except NotImplementedError:
                    result = await session.execute(query, **execute_args)
                    print(json.dumps(result))
            except (GraphQLError, TransportQueryError) as e:
                # Query errors set the exit code but don't stop the loop
                print(e, file=sys.stderr)
                exit_code = 1

    return exit_code
|
||||
|
||||
|
||||
def gql_cli() -> None:
    """Synchronously invoke ``main`` with the parsed command line arguments.

    Formerly ``scripts/gql-cli``, now registered as an ``entry_point``
    """
    # Get arguments from command line
    parser = get_parser(with_examples=True)
    args = parser.parse_args()

    try:
        # Create a new asyncio event loop
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

        # Create a gql-cli task with the supplied arguments
        main_task = asyncio.ensure_future(main(args), loop=loop)

        # Add signal handlers to close gql-cli cleanly on Control-C.
        # The CTRL_C/CTRL_BREAK names only exist on some platforms, so
        # getattr returns None for the ones missing here.
        for signal_name in ["SIGINT", "SIGTERM", "CTRL_C_EVENT", "CTRL_BREAK_EVENT"]:
            signal = getattr(signal_module, signal_name, None)

            if signal is None:
                continue

            try:
                loop.add_signal_handler(signal, main_task.cancel)
            except NotImplementedError:  # pragma: no cover
                # not all signals supported on all platforms
                pass

        # Run the asyncio loop to execute the task
        exit_code = 0
        try:
            exit_code = loop.run_until_complete(main_task)
        finally:
            # Always close the loop, even if main raised
            loop.close()

        # Return with the correct exit code
        sys.exit(exit_code)
    except KeyboardInterrupt:  # pragma: no cover
        pass
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,24 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from graphql import DocumentNode, Source, parse
|
||||
|
||||
|
||||
def gql(request_string: str | Source) -> DocumentNode:
    """Given a string containing a GraphQL request, parse it into a Document.

    :param request_string: the GraphQL request as a String
    :type request_string: str | Source
    :return: a Document which can be later executed or subscribed by a
        :class:`Client <gql.client.Client>`, by an
        :class:`async session <gql.client.AsyncClientSession>` or by a
        :class:`sync session <gql.client.SyncClientSession>`

    :raises GraphQLError: if a syntax error is encountered.
    """
    # Plain strings are wrapped in a Source so that error messages can
    # reference the request by name
    if isinstance(request_string, str):
        return parse(Source(request_string, "GraphQL request"))

    if isinstance(request_string, Source):
        return parse(request_string)

    raise TypeError("Request must be passed as a string or Source object.")
|
||||
@@ -0,0 +1,37 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from graphql import DocumentNode, GraphQLSchema
|
||||
|
||||
from .utilities import serialize_variable_values
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class GraphQLRequest:
    """GraphQL Request to be executed."""

    document: DocumentNode
    """GraphQL query as AST Node object."""

    variable_values: Optional[Dict[str, Any]] = None
    """Dictionary of input parameters (Default: None)."""

    operation_name: Optional[str] = None
    """
    Name of the operation that shall be executed.
    Only required in multi-operation documents (Default: None).
    """

    def serialize_variable_values(self, schema: GraphQLSchema) -> "GraphQLRequest":
        """Return a new request with variable_values serialized for *schema*.

        Delegates to the module-level ``serialize_variable_values`` utility
        (imported from ``.utilities``; the method shadows that name on the
        class). Since the dataclass is frozen, a fresh instance is returned
        instead of mutating this one.

        :param schema: the GraphQL schema used to serialize the variables
        :return: a new GraphQLRequest with serialized variable values
        """
        # Serializing makes no sense without variable values to serialize
        assert self.variable_values

        return GraphQLRequest(
            document=self.document,
            variable_values=serialize_variable_values(
                schema=schema,
                document=self.document,
                variable_values=self.variable_values,
                operation_name=self.operation_name,
            ),
            operation_name=self.operation_name,
        )
|
||||
@@ -0,0 +1 @@
|
||||
# Marker file for PEP 561. The gql package uses inline types.
|
||||
@@ -0,0 +1,4 @@
|
||||
from .async_transport import AsyncTransport
|
||||
from .transport import Transport
|
||||
|
||||
__all__ = ["AsyncTransport", "Transport"]
|
||||
@@ -0,0 +1,386 @@
|
||||
import asyncio
|
||||
import functools
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import warnings
|
||||
from ssl import SSLContext
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
Callable,
|
||||
Dict,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
import aiohttp
|
||||
from aiohttp.client_exceptions import ClientResponseError
|
||||
from aiohttp.client_reqrep import Fingerprint
|
||||
from aiohttp.helpers import BasicAuth
|
||||
from aiohttp.typedefs import LooseCookies, LooseHeaders
|
||||
from graphql import DocumentNode, ExecutionResult, print_ast
|
||||
from multidict import CIMultiDictProxy
|
||||
|
||||
from ..utils import extract_files
|
||||
from .appsync_auth import AppSyncAuthentication
|
||||
from .async_transport import AsyncTransport
|
||||
from .exceptions import (
|
||||
TransportAlreadyConnected,
|
||||
TransportClosed,
|
||||
TransportProtocolError,
|
||||
TransportServerError,
|
||||
)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AIOHTTPTransport(AsyncTransport):
    """:ref:`Async Transport <async_transports>` to execute GraphQL queries
    on remote servers with an HTTP connection.

    This transport use the aiohttp library with asyncio.
    """

    # Types recognized as uploadable files when upload_files is requested
    file_classes: Tuple[Type[Any], ...] = (
        io.IOBase,
        aiohttp.StreamReader,
        AsyncGenerator,
    )

    def __init__(
        self,
        url: str,
        headers: Optional[LooseHeaders] = None,
        cookies: Optional[LooseCookies] = None,
        auth: Optional[Union[BasicAuth, "AppSyncAuthentication"]] = None,
        ssl: Union[SSLContext, bool, Fingerprint, str] = "ssl_warning",
        timeout: Optional[int] = None,
        ssl_close_timeout: Optional[Union[int, float]] = 10,
        json_serialize: Callable = json.dumps,
        client_session_args: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Initialize the transport with the given aiohttp parameters.

        :param url: The GraphQL server URL. Example: 'https://server.com:PORT/path'.
        :param headers: Dict of HTTP Headers.
        :param cookies: Dict of HTTP cookies.
        :param auth: BasicAuth object to enable Basic HTTP auth if needed
            Or Appsync Authentication class
        :param ssl: ssl_context of the connection. Use ssl=False to disable encryption
        :param timeout: total timeout in seconds used to build the
            aiohttp.ClientTimeout of the session (None = aiohttp default)
        :param ssl_close_timeout: Timeout in seconds to wait for the ssl connection
            to close properly
        :param json_serialize: Json serializer callable.
            By default json.dumps() function
        :param client_session_args: Dict of extra args passed to
            `aiohttp.ClientSession`_

        .. _aiohttp.ClientSession:
          https://docs.aiohttp.org/en/stable/client_reference.html#aiohttp.ClientSession
        """
        self.url: str = url
        self.headers: Optional[LooseHeaders] = headers
        self.cookies: Optional[LooseCookies] = cookies
        self.auth: Optional[Union[BasicAuth, "AppSyncAuthentication"]] = auth

        # "ssl_warning" is a sentinel default: keep the legacy ssl=False
        # behavior, but warn on https urls that certs are not verified
        if ssl == "ssl_warning":
            ssl = False
            if str(url).startswith("https"):
                warnings.warn(
                    "WARNING: By default, AIOHTTPTransport does not verify"
                    " ssl certificates. This will be fixed in the next major version."
                    " You can set ssl=True to force the ssl certificate verification"
                    " or ssl=False to disable this warning"
                )

        # After the sentinel is resolved, ssl can no longer be a str
        self.ssl: Union[SSLContext, bool, Fingerprint] = cast(
            Union[SSLContext, bool, Fingerprint], ssl
        )
        self.timeout: Optional[int] = timeout
        self.ssl_close_timeout: Optional[Union[int, float]] = ssl_close_timeout
        self.client_session_args = client_session_args
        self.session: Optional[aiohttp.ClientSession] = None
        # Declared here, assigned on each execute() call
        self.response_headers: Optional[CIMultiDictProxy[str]]
        self.json_serialize: Callable = json_serialize

    async def connect(self) -> None:
        """Coroutine which will create an aiohttp ClientSession() as self.session.

        Don't call this coroutine directly on the transport, instead use
        :code:`async with` on the client and this coroutine will be executed
        to create the session.

        Should be cleaned with a call to the close coroutine.

        :raises TransportAlreadyConnected: if a session already exists
        """

        if self.session is None:

            client_session_args: Dict[str, Any] = {
                "cookies": self.cookies,
                "headers": self.headers,
                # AppSync authentication is applied per-request through
                # signed headers (see execute), not via aiohttp's auth
                "auth": None
                if isinstance(self.auth, AppSyncAuthentication)
                else self.auth,
                "json_serialize": self.json_serialize,
            }

            if self.timeout is not None:
                client_session_args["timeout"] = aiohttp.ClientTimeout(
                    total=self.timeout
                )

            # Adding custom parameters passed from init
            if self.client_session_args:
                client_session_args.update(self.client_session_args)  # type: ignore

            log.debug("Connecting transport")

            self.session = aiohttp.ClientSession(**client_session_args)

        else:
            raise TransportAlreadyConnected("Transport is already connected")

    @staticmethod
    def create_aiohttp_closed_event(session) -> asyncio.Event:
        """Work around aiohttp issue that doesn't properly close transports on exit.

        See https://github.com/aio-libs/aiohttp/issues/1925#issuecomment-639080209

        Returns:
            An event that will be set once all transports have been properly closed.
        """

        ssl_transports = 0
        all_is_lost = asyncio.Event()

        def connection_lost(exc, orig_lost):
            nonlocal ssl_transports

            try:
                orig_lost(exc)
            finally:
                # Set the event once the last patched transport is gone
                ssl_transports -= 1
                if ssl_transports == 0:
                    all_is_lost.set()

        def eof_received(orig_eof_received):
            try:
                orig_eof_received()
            except AttributeError:  # pragma: no cover
                # It may happen that eof_received() is called after
                # _app_protocol and _transport are set to None.
                pass

        # Monkey-patch the ssl protocol of every pooled connection so we
        # are notified when its transport is actually closed
        for conn in session.connector._conns.values():
            for handler, _ in conn:
                proto = getattr(handler.transport, "_ssl_protocol", None)
                if proto is None:
                    continue

                ssl_transports += 1
                orig_lost = proto.connection_lost
                orig_eof_received = proto.eof_received

                proto.connection_lost = functools.partial(
                    connection_lost, orig_lost=orig_lost
                )
                proto.eof_received = functools.partial(
                    eof_received, orig_eof_received=orig_eof_received
                )

        # Nothing to wait for: set the event immediately
        if ssl_transports == 0:
            all_is_lost.set()

        return all_is_lost

    async def close(self) -> None:
        """Coroutine which will close the aiohttp session.

        Don't call this coroutine directly on the transport, instead use
        :code:`async with` on the client and this coroutine will be executed
        when you exit the async context manager.
        """
        if self.session is not None:

            log.debug("Closing transport")

            if (
                self.client_session_args
                and self.client_session_args.get("connector_owner") is False
            ):

                log.debug("connector_owner is False -> not closing connector")

            else:
                # Wait (up to ssl_close_timeout) for the underlying ssl
                # transports to really close; see create_aiohttp_closed_event
                closed_event = self.create_aiohttp_closed_event(self.session)
                await self.session.close()
                try:
                    await asyncio.wait_for(closed_event.wait(), self.ssl_close_timeout)
                except asyncio.TimeoutError:
                    pass

        self.session = None

    async def execute(
        self,
        document: DocumentNode,
        variable_values: Optional[Dict[str, Any]] = None,
        operation_name: Optional[str] = None,
        extra_args: Optional[Dict[str, Any]] = None,
        upload_files: bool = False,
    ) -> ExecutionResult:
        """Execute the provided document AST against the configured remote server
        using the current session.
        This uses the aiohttp library to perform a HTTP POST request asynchronously
        to the remote server.

        Don't call this coroutine directly on the transport, instead use
        :code:`execute` on a client or a session.

        :param document: the parsed GraphQL request
        :param variable_values: An optional Dict of variable values
        :param operation_name: An optional Operation name for the request
        :param extra_args: additional arguments to send to the aiohttp post method
        :param upload_files: Set to True if you want to put files in the variable values
        :returns: an ExecutionResult object.
        :raises TransportClosed: if the transport is not connected
        :raises TransportServerError: if the response status is 400 or higher
        :raises TransportProtocolError: if the answer is not a GraphQL result
        """

        query_str = print_ast(document)

        payload: Dict[str, Any] = {
            "query": query_str,
        }

        if operation_name:
            payload["operationName"] = operation_name

        if upload_files:

            # If the upload_files flag is set, then we need variable_values
            assert variable_values is not None

            # If we upload files, we will extract the files present in the
            # variable_values dict and replace them by null values
            nulled_variable_values, files = extract_files(
                variables=variable_values,
                file_classes=self.file_classes,
            )

            # Save the nulled variable values in the payload
            payload["variables"] = nulled_variable_values

            # Prepare aiohttp to send multipart-encoded data
            data = aiohttp.FormData()

            # Generate the file map
            # path is nested in a list because the spec allows multiple pointers
            # to the same file. But we don't support that.
            # Will generate something like {"0": ["variables.file"]}
            file_map = {str(i): [path] for i, path in enumerate(files)}

            # Enumerate the file streams
            # Will generate something like {'0': <_io.BufferedReader ...>}
            file_streams = {str(i): files[path] for i, path in enumerate(files)}

            # Add the payload to the operations field
            operations_str = self.json_serialize(payload)
            log.debug("operations %s", operations_str)
            data.add_field(
                "operations", operations_str, content_type="application/json"
            )

            # Add the file map field
            file_map_str = self.json_serialize(file_map)
            log.debug("file_map %s", file_map_str)
            data.add_field("map", file_map_str, content_type="application/json")

            # Add the extracted files as remaining fields
            for k, f in file_streams.items():
                name = getattr(f, "name", k)
                content_type = getattr(f, "content_type", None)

                data.add_field(k, f, filename=name, content_type=content_type)

            post_args: Dict[str, Any] = {"data": data}

        else:
            if variable_values:
                payload["variables"] = variable_values

            if log.isEnabledFor(logging.INFO):
                log.info(">>> %s", self.json_serialize(payload))

            post_args = {"json": payload}

        # Pass post_args to aiohttp post method
        if extra_args:
            post_args.update(extra_args)

        # Add headers for AppSync if requested
        if isinstance(self.auth, AppSyncAuthentication):
            post_args["headers"] = self.auth.get_headers(
                self.json_serialize(payload),
                {"content-type": "application/json"},
            )

        if self.session is None:
            raise TransportClosed("Transport is not connected")

        async with self.session.post(self.url, ssl=self.ssl, **post_args) as resp:

            # Saving latest response headers in the transport
            self.response_headers = resp.headers

            async def raise_response_error(resp: aiohttp.ClientResponse, reason: str):
                # We raise a TransportServerError if the status code is 400 or higher
                # We raise a TransportProtocolError in the other cases

                try:
                    # Raise a ClientResponseError if response status is 400 or higher
                    resp.raise_for_status()
                except ClientResponseError as e:
                    raise TransportServerError(str(e), e.status) from e

                result_text = await resp.text()
                raise TransportProtocolError(
                    f"Server did not return a GraphQL result: "
                    f"{reason}: "
                    f"{result_text}"
                )

            try:
                result = await resp.json(content_type=None)

                if log.isEnabledFor(logging.INFO):
                    result_text = await resp.text()
                    log.info("<<< %s", result_text)

            except Exception:
                # The body is not even JSON: raise_response_error always raises
                await raise_response_error(resp, "Not a JSON answer")

            if result is None:
                await raise_response_error(resp, "Not a JSON answer")

            if "errors" not in result and "data" not in result:
                await raise_response_error(resp, 'No "data" or "errors" keys in answer')

            return ExecutionResult(
                errors=result.get("errors"),
                data=result.get("data"),
                extensions=result.get("extensions"),
            )

    def subscribe(
        self,
        document: DocumentNode,
        variable_values: Optional[Dict[str, Any]] = None,
        operation_name: Optional[str] = None,
    ) -> AsyncGenerator[ExecutionResult, None]:
        """Subscribe is not supported on HTTP.

        :meta private:
        """
        raise NotImplementedError(" The HTTP transport does not support subscriptions")
|
||||
@@ -0,0 +1,221 @@
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from base64 import b64encode
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
|
||||
try:
|
||||
import botocore
|
||||
except ImportError: # pragma: no cover
|
||||
# botocore is only needed for the IAM AppSync authentication method
|
||||
pass
|
||||
|
||||
log = logging.getLogger("gql.transport.appsync")
|
||||
|
||||
|
||||
class AppSyncAuthentication(ABC):
    """Abstract base class for AWS AppSync authentication.

    Every concrete AWS authentication class must provide a
    :meth:`get_headers <gql.transport.appsync.AppSyncAuthentication.get_headers>`
    method which defines the headers used in the authentication process."""

    def get_auth_url(self, url: str) -> str:
        """
        :return: a url with base64 encoded headers used to establish
                 a websocket connection to the appsync-realtime-api.
        """
        # Serialize the auth headers compactly and base64-encode them so they
        # can travel in the websocket URL query string.
        auth_headers = self.get_headers()
        serialized = json.dumps(auth_headers, separators=(",", ":"))
        header_param = b64encode(serialized.encode()).decode()

        # Switch to the websocket scheme and to the realtime host.
        realtime_url = url.replace("https://", "wss://")
        realtime_url = realtime_url.replace("appsync-api", "appsync-realtime-api")

        # "e30=" is the base64 encoding of an empty JSON object: "{}".
        return f"{realtime_url}?header={header_param}&payload=e30="

    @abstractmethod
    def get_headers(
        self, data: Optional[str] = None, headers: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        raise NotImplementedError()  # pragma: no cover
|
||||
|
||||
|
||||
class AppSyncApiKeyAuthentication(AppSyncAuthentication):
    """AWS authentication class using an API key"""

    def __init__(self, host: str, api_key: str) -> None:
        """
        :param host: the host, something like:
            XXXXXXXXXXXXXXXXXXXXXXXXXX.appsync-api.REGION.amazonaws.com
        :param api_key: the API key
        """
        # If the caller passed the realtime host, normalize it back to the
        # regular API host.
        self._host = host.replace("appsync-realtime-api", "appsync-api")
        self.api_key = api_key

    def get_headers(
        self, data: Optional[str] = None, headers: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Return the host and API-key headers; *data* and *headers* are ignored."""
        auth_headers: Dict[str, Any] = {"host": self._host}
        auth_headers["x-api-key"] = self.api_key
        return auth_headers
|
||||
|
||||
|
||||
class AppSyncJWTAuthentication(AppSyncAuthentication):
    """AWS authentication class using a JWT access token.

    It can be used either for:

    - Amazon Cognito user pools
    - OpenID Connect (OIDC)
    """

    def __init__(self, host: str, jwt: str) -> None:
        """
        :param host: the host, something like:
            XXXXXXXXXXXXXXXXXXXXXXXXXX.appsync-api.REGION.amazonaws.com
        :param jwt: the JWT Access Token
        """
        # If the caller passed the realtime host, normalize it back to the
        # regular API host.
        self._host = host.replace("appsync-realtime-api", "appsync-api")
        self.jwt = jwt

    def get_headers(
        self, data: Optional[str] = None, headers: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Return the host and Authorization headers; *data* and *headers* are ignored."""
        auth_headers: Dict[str, Any] = {"host": self._host}
        auth_headers["Authorization"] = self.jwt
        return auth_headers
|
||||
|
||||
|
||||
class AppSyncIAMAuthentication(AppSyncAuthentication):
    """AWS authentication class using IAM.

    .. note::
        There is no need for you to use this class directly, you could instead
        instantiate the :class:`gql.transport.appsync.AppSyncWebsocketsTransport`
        without an auth argument.

    During initialization, this class will use botocore to attempt to
    find your IAM credentials, either from environment variables or
    from your AWS credentials file.
    """

    def __init__(
        self,
        host: str,
        region_name: Optional[str] = None,
        signer: Optional["botocore.auth.BaseSigner"] = None,
        request_creator: Optional[
            Callable[[Dict[str, Any]], "botocore.awsrequest.AWSRequest"]
        ] = None,
        credentials: Optional["botocore.credentials.Credentials"] = None,
        session: Optional["botocore.session.Session"] = None,
    ) -> None:
        """Initialize itself, saving the found credentials used
        to sign the headers later.

        if no credentials are found, then a NoCredentialsError is raised.

        :param host: the AppSync host; a realtime host is normalized to the
            API host before signing.
        :param region_name: AWS region; detected from the host or the AWS
            configuration when not given.
        :param signer: optional custom signer (defaults to SigV4Auth).
        :param request_creator: optional factory building an AWSRequest from
            a dict (defaults to botocore's create_request_object).
        :param credentials: optional explicit credentials (defaults to the
            credentials resolved by the botocore session).
        :param session: optional botocore session (a default one is created
            otherwise).
        """

        # Deferred imports: botocore is an optional dependency, only needed
        # for this IAM authentication method.
        from botocore.auth import SigV4Auth
        from botocore.awsrequest import create_request_object
        from botocore.session import get_session

        # Requests are signed against the appsync-api endpoint, so normalize
        # a realtime host back to the regular API host.
        self._host = host.replace("appsync-realtime-api", "appsync-api")
        self._session = session if session else get_session()
        self._credentials = (
            credentials if credentials else self._session.get_credentials()
        )
        self._service_name = "appsync"
        # Region detection may raise NoRegionError (see _detect_region_name).
        self._region_name = region_name or self._detect_region_name()
        self._signer = (
            signer
            if signer
            else SigV4Auth(self._credentials, self._service_name, self._region_name)
        )
        self._request_creator = (
            request_creator if request_creator else create_request_object
        )

    def _detect_region_name(self):
        """Try to detect the correct region_name.

        First try to extract the region_name from the host.

        If that does not work, then try to get the region_name from
        the aws configuration (~/.aws/config file) or the AWS_DEFAULT_REGION
        environment variable.

        If no region_name was found, then raise a NoRegionError exception."""

        from botocore.exceptions import NoRegionError

        # Regular expression from botocore.utils.validate_region
        m = re.search(
            r"appsync-api\.((?![0-9]+$)(?!-)[a-zA-Z0-9-]{,63}(?<!-))\.", self._host
        )

        if m:
            region_name = m.groups()[0]
            log.debug(f"Region name extracted from host: {region_name}")

        else:
            log.debug("Region name not found in host, trying default region name")
            # NOTE(review): _resolve_region_name is a private botocore API and
            # could change between botocore releases — confirm on upgrade.
            region_name = self._session._resolve_region_name(
                None, self._session.get_default_client_config()
            )

        if region_name is None:
            log.warning(
                "Region name not found. "
                "It was not possible to detect your region either from the host "
                "or from your default AWS configuration."
            )
            raise NoRegionError

        return region_name

    def get_headers(
        self, data: Optional[str] = None, headers: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Return SigV4-signed headers for the given request payload.

        :param data: serialized request body; when None, a "/connect"
            handshake request with an empty JSON body ("{}") is signed instead.
        :param headers: optional base headers to sign; defaults to the
            standard websocket connection headers below.
        :raises NoCredentialsError: when no AWS credentials could be found.
        """

        from botocore.exceptions import NoCredentialsError

        # Default headers for a websocket connection
        headers = headers or {
            "accept": "application/json, text/javascript",
            "content-encoding": "amz-1.0",
            "content-type": "application/json; charset=UTF-8",
        }

        # Build a botocore AWSRequest so the signer can compute the signature.
        request: "botocore.awsrequest.AWSRequest" = self._request_creator(
            {
                "method": "POST",
                "url": f"https://{self._host}/graphql{'' if data else '/connect'}",
                "headers": headers,
                "context": {},
                "body": data or "{}",
            }
        )

        try:
            self._signer.add_auth(request)
        except NoCredentialsError:
            log.warning(
                "Credentials not found for the IAM auth. "
                "Do you have default AWS credentials configured?",
            )
            raise

        headers = dict(request.headers)

        # Overwrite the host header with our normalized API host.
        headers["host"] = self._host

        if log.isEnabledFor(logging.DEBUG):
            # Dump the signed headers only when debug logging is enabled.
            headers_log = []
            headers_log.append("\n\nSigned headers:")
            for key, value in headers.items():
                headers_log.append(f" {key}: {value}")
            headers_log.append("\n")
            log.debug("\n".join(headers_log))

        return headers
|
||||
@@ -0,0 +1,214 @@
|
||||
import json
|
||||
import logging
|
||||
from ssl import SSLContext
|
||||
from typing import Any, Dict, Optional, Tuple, Union, cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from graphql import DocumentNode, ExecutionResult, print_ast
|
||||
|
||||
from .appsync_auth import AppSyncAuthentication, AppSyncIAMAuthentication
|
||||
from .exceptions import TransportProtocolError, TransportServerError
|
||||
from .websockets import WebsocketsTransport, WebsocketsTransportBase
|
||||
|
||||
log = logging.getLogger("gql.transport.appsync")
|
||||
|
||||
try:
|
||||
import botocore
|
||||
except ImportError: # pragma: no cover
|
||||
# botocore is only needed for the IAM AppSync authentication method
|
||||
pass
|
||||
|
||||
|
||||
class AppSyncWebsocketsTransport(WebsocketsTransportBase):
    """:ref:`Async Transport <async_transports>` used to execute GraphQL subscription on
    AWS appsync realtime endpoint.

    This transport uses asyncio and the websockets library in order to send requests
    on a websocket connection.
    """

    # Authentication instance used to build the connection URL and to
    # authorize every "start" message.
    auth: Optional[AppSyncAuthentication]

    def __init__(
        self,
        url: str,
        auth: Optional[AppSyncAuthentication] = None,
        session: Optional["botocore.session.Session"] = None,
        ssl: Union[SSLContext, bool] = False,
        connect_timeout: int = 10,
        close_timeout: int = 10,
        ack_timeout: int = 10,
        keep_alive_timeout: Optional[Union[int, float]] = None,
        connect_args: Dict[str, Any] = {},  # NOTE(review): mutable default; never mutated here
    ) -> None:
        """Initialize the transport with the given parameters.

        :param url: The GraphQL endpoint URL. Example:
            https://XXXXXXXXXXXXXXXXXXXXXXXXXX.appsync-api.REGION.amazonaws.com/graphql
        :param auth: Optional AWS authentication class which will provide the
                     necessary headers to be correctly authenticated. If this
                     argument is not provided, then we will try to authenticate
                     using IAM.
        :param ssl: ssl_context of the connection.
        :param connect_timeout: Timeout in seconds for the establishment
            of the websocket connection. If None is provided this will wait forever.
        :param close_timeout: Timeout in seconds for the close. If None is provided
            this will wait forever.
        :param ack_timeout: Timeout in seconds to wait for the connection_ack message
            from the server. If None is provided this will wait forever.
        :param keep_alive_timeout: Optional Timeout in seconds to receive
            a sign of liveness from the server.
        :param connect_args: Other parameters forwarded to websockets.connect
        """

        if not auth:

            # Extract host from url
            host = str(urlparse(url).netloc)

            # May raise NoRegionError or NoCredentialsError or ImportError
            auth = AppSyncIAMAuthentication(host=host, session=session)

        self.auth = auth

        # Replace the given URL by the realtime wss:// URL carrying the
        # base64-encoded auth headers in its query string.
        url = self.auth.get_auth_url(url)

        super().__init__(
            url,
            ssl=ssl,
            connect_timeout=connect_timeout,
            close_timeout=close_timeout,
            ack_timeout=ack_timeout,
            keep_alive_timeout=keep_alive_timeout,
            connect_args=connect_args,
        )

        # Using the same 'graphql-ws' protocol as the apollo protocol
        self.supported_subprotocols = [
            WebsocketsTransport.APOLLO_SUBPROTOCOL,
        ]
        self.subprotocol = WebsocketsTransport.APOLLO_SUBPROTOCOL

    def _parse_answer(
        self, answer: str
    ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]:
        """Parse the answer received from the server.

        Difference between apollo protocol and aws protocol:

        - aws protocol can return an error without an id
        - aws protocol will send start_ack messages

        Returns a list consisting of:
            - the answer_type:
              - 'connection_ack',
              - 'connection_error',
              - 'start_ack',
              - 'ka',
              - 'data',
              - 'error',
              - 'complete'
            - the answer id (Integer) if received or None
            - an execution Result if the answer_type is 'data' or None
        """

        answer_type: str = ""

        try:
            json_answer = json.loads(answer)

            answer_type = str(json_answer.get("type"))

            # AWS-specific acknowledgement of a "start" message.
            if answer_type == "start_ack":
                return ("start_ack", None, None)

            # AWS-specific: a global (id-less) error closes the transport.
            elif answer_type == "error" and "id" not in json_answer:
                error_payload = json_answer.get("payload")
                raise TransportServerError(f"Server error: '{error_payload!r}'")

            else:

                # Everything else follows the apollo protocol; delegate to
                # the apollo parser with an unbound-method call.
                return WebsocketsTransport._parse_answer_apollo(
                    cast(WebsocketsTransport, self), json_answer
                )

        except ValueError:
            # json.loads failed: the server sent something that is not JSON.
            raise TransportProtocolError(
                f"Server did not return a GraphQL result: {answer}"
            )

    async def _send_query(
        self,
        document: DocumentNode,
        variable_values: Optional[Dict[str, Any]] = None,
        operation_name: Optional[str] = None,
    ) -> int:
        """Serialize and send a "start" message for the given subscription.

        :return: the query id allocated for this subscription.
        """

        query_id = self.next_query_id

        self.next_query_id += 1

        data: Dict = {"query": print_ast(document)}

        if variable_values:
            data["variables"] = variable_values

        if operation_name:
            data["operationName"] = operation_name

        # AWS expects the GraphQL request serialized as a JSON *string*
        # inside the payload "data" field.
        serialized_data = json.dumps(data, separators=(",", ":"))

        payload = {"data": serialized_data}

        message: Dict = {
            "id": str(query_id),
            "type": "start",
            "payload": payload,
        }

        assert self.auth is not None

        # The authorization headers are signed over the serialized request
        # and attached under payload.extensions.
        message["payload"]["extensions"] = {
            "authorization": self.auth.get_headers(serialized_data)
        }

        await self._send(
            json.dumps(
                message,
                separators=(",", ":"),
            )
        )

        return query_id

    # Re-exported base implementation; only subscriptions are supported here.
    subscribe = WebsocketsTransportBase.subscribe
    """Send a subscription query and receive the results using
    a python async generator.

    Only subscriptions are supported, queries and mutations are forbidden.

    The results are sent as an ExecutionResult object.
    """

    async def execute(
        self,
        document: DocumentNode,
        variable_values: Optional[Dict[str, Any]] = None,
        operation_name: Optional[str] = None,
    ) -> ExecutionResult:
        """This method is not available.

        Only subscriptions are supported on the AWS realtime endpoint.

        :raise: AssertionError"""
        raise AssertionError(
            "execute method is not allowed for AppSyncWebsocketsTransport "
            "because only subscriptions are allowed on the realtime endpoint."
        )

    # Borrow the apollo-protocol plumbing from WebsocketsTransport; these are
    # plain function objects, so they bind to this class's instances.
    _initialize = WebsocketsTransport._initialize
    _stop_listener = WebsocketsTransport._send_stop_message  # type: ignore
    _send_init_message_and_wait_ack = (
        WebsocketsTransport._send_init_message_and_wait_ack
    )
    _wait_ack = WebsocketsTransport._wait_ack
|
||||
@@ -0,0 +1,50 @@
|
||||
import abc
|
||||
from typing import Any, AsyncGenerator, Dict, Optional
|
||||
|
||||
from graphql import DocumentNode, ExecutionResult
|
||||
|
||||
|
||||
class AsyncTransport(abc.ABC):
    """Abstract base class for all asynchronous transports.

    Subclasses must implement :meth:`connect`, :meth:`close`,
    :meth:`execute` and :meth:`subscribe`."""

    @abc.abstractmethod
    async def connect(self):
        """Coroutine used to create a connection to the specified address"""
        raise NotImplementedError(
            "Any AsyncTransport subclass must implement connect method"
        )  # pragma: no cover

    @abc.abstractmethod
    async def close(self):
        """Coroutine used to Close an established connection"""
        raise NotImplementedError(
            "Any AsyncTransport subclass must implement close method"
        )  # pragma: no cover

    @abc.abstractmethod
    async def execute(
        self,
        document: DocumentNode,
        variable_values: Optional[Dict[str, Any]] = None,
        operation_name: Optional[str] = None,
    ) -> ExecutionResult:
        """Execute the provided document AST for either a remote or local GraphQL
        Schema.

        :param document: GraphQL query as AST Node object.
        :param variable_values: Dictionary of input parameters (Default: None).
        :param operation_name: Name of the operation that shall be executed.
        """
        raise NotImplementedError(
            "Any AsyncTransport subclass must implement execute method"
        )  # pragma: no cover

    @abc.abstractmethod
    def subscribe(
        self,
        document: DocumentNode,
        variable_values: Optional[Dict[str, Any]] = None,
        operation_name: Optional[str] = None,
    ) -> AsyncGenerator[ExecutionResult, None]:
        """Send a query and receive the results using an async generator

        The query can be a graphql query, mutation or subscription

        The results are sent as an ExecutionResult object
        """
        raise NotImplementedError(
            "Any AsyncTransport subclass must implement subscribe method"
        )  # pragma: no cover
|
||||
@@ -0,0 +1,69 @@
|
||||
from typing import Any, List, Optional
|
||||
|
||||
|
||||
class TransportError(Exception):
    """Base class for all the Transport exceptions"""


class TransportProtocolError(TransportError):
    """Transport protocol error.

    The answer received from the server does not correspond to the transport protocol.
    """


class TransportServerError(TransportError):
    """The server returned a global error.

    This exception will close the transport connection.
    """

    # Optional status code (e.g. the HTTP status) associated with the error.
    code: Optional[int]

    def __init__(self, message: str, code: Optional[int] = None):
        """
        :param message: error message returned by the server.
        :param code: optional status code associated with the error.
        """
        # Modernized from super(TransportServerError, self).__init__ for
        # consistency with TransportQueryError below.
        super().__init__(message)
        self.code = code


class TransportQueryError(TransportError):
    """The server returned an error for a specific query.

    This exception should not close the transport connection.
    """

    # Id of the query which triggered the error (websocket transports).
    query_id: Optional[int]
    # "errors" entries of the GraphQL response, if any.
    errors: Optional[List[Any]]
    # "data" entry of the GraphQL response, if any.
    data: Optional[Any]
    # "extensions" entry of the GraphQL response, if any.
    extensions: Optional[Any]

    def __init__(
        self,
        msg: str,
        query_id: Optional[int] = None,
        errors: Optional[List[Any]] = None,
        data: Optional[Any] = None,
        extensions: Optional[Any] = None,
    ):
        """
        :param msg: error message.
        :param query_id: id of the query which triggered this error.
        :param errors: "errors" key of the GraphQL response.
        :param data: "data" key of the GraphQL response.
        :param extensions: "extensions" key of the GraphQL response.
        """
        super().__init__(msg)
        self.query_id = query_id
        self.errors = errors
        self.data = data
        self.extensions = extensions


class TransportClosed(TransportError):
    """Transport is already closed.

    This exception is generated when the client is trying to use the transport
    while the transport was previously closed.
    """


class TransportAlreadyConnected(TransportError):
    """Transport is already connected.

    Exception generated when the client is trying to connect to the transport
    while the transport is already connected.
    """
|
||||
@@ -0,0 +1,311 @@
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
import httpx
|
||||
from graphql import DocumentNode, ExecutionResult, print_ast
|
||||
|
||||
from ..utils import extract_files
|
||||
from . import AsyncTransport, Transport
|
||||
from .exceptions import (
|
||||
TransportAlreadyConnected,
|
||||
TransportClosed,
|
||||
TransportProtocolError,
|
||||
TransportServerError,
|
||||
)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _HTTPXTransport:
    """Shared plumbing for the sync and async httpx transports.

    Builds the POST arguments (including multipart file uploads) and converts
    httpx responses into ExecutionResult objects."""

    # Types recognized as file objects when upload_files=True.
    file_classes: Tuple[Type[Any], ...] = (io.IOBase,)

    # Headers of the latest response, saved for the caller's inspection.
    response_headers: Optional[httpx.Headers] = None

    def __init__(
        self,
        url: Union[str, httpx.URL],
        json_serialize: Callable = json.dumps,
        **kwargs,
    ):
        """Initialize the transport with the given httpx parameters.

        :param url: The GraphQL server URL. Example: 'https://server.com:PORT/path'.
        :param json_serialize: Json serializer callable.
                By default json.dumps() function.
        :param kwargs: Extra args passed to the `httpx` client.
        """
        self.url = url
        self.json_serialize = json_serialize
        self.kwargs = kwargs

    def _prepare_request(
        self,
        document: DocumentNode,
        variable_values: Optional[Dict[str, Any]] = None,
        operation_name: Optional[str] = None,
        extra_args: Optional[Dict[str, Any]] = None,
        upload_files: bool = False,
    ) -> Dict[str, Any]:
        """Build the keyword arguments for the httpx ``post`` call.

        :param document: GraphQL query as AST Node object.
        :param variable_values: Dictionary of input parameters.
        :param operation_name: Name of the operation to execute.
        :param extra_args: extra kwargs merged into the result.
        :param upload_files: whether to build a multipart upload request.
        """
        query_str = print_ast(document)

        payload: Dict[str, Any] = {
            "query": query_str,
        }

        if operation_name:
            payload["operationName"] = operation_name

        if upload_files:
            # If the upload_files flag is set, then we need variable_values
            assert variable_values is not None

            post_args = self._prepare_file_uploads(variable_values, payload)
        else:
            if variable_values:
                payload["variables"] = variable_values

            post_args = {"json": payload}

        # Log the payload
        if log.isEnabledFor(logging.DEBUG):
            log.debug(">>> %s", self.json_serialize(payload))

        # Pass post_args to httpx post method
        if extra_args:
            post_args.update(extra_args)

        return post_args

    def _prepare_file_uploads(self, variable_values, payload) -> Dict[str, Any]:
        """Build multipart POST arguments following the GraphQL multipart
        request specification ("operations" + "map" + one field per file)."""
        # If we upload files, we will extract the files present in the
        # variable_values dict and replace them by null values
        nulled_variable_values, files = extract_files(
            variables=variable_values,
            file_classes=self.file_classes,
        )

        # Save the nulled variable values in the payload
        payload["variables"] = nulled_variable_values

        # Prepare to send multipart-encoded data
        data: Dict[str, Any] = {}
        file_map: Dict[str, List[str]] = {}
        # Values are httpx file tuples: (name, fileobj) or
        # (name, fileobj, content_type) — hence Any, not str.
        file_streams: Dict[str, Tuple[Any, ...]] = {}

        for i, (path, f) in enumerate(files.items()):
            key = str(i)

            # Generate the file map
            # path is nested in a list because the spec allows multiple pointers
            # to the same file. But we don't support that.
            # Will generate something like {"0": ["variables.file"]}
            file_map[key] = [path]

            # Generate the file streams
            # Will generate something like
            # {"0": ("variables.file", <_io.BufferedReader ...>)}
            name = cast(str, getattr(f, "name", key))
            content_type = getattr(f, "content_type", None)

            if content_type is None:
                file_streams[key] = (name, f)
            else:
                file_streams[key] = (name, f, content_type)

        # Add the payload to the operations field
        operations_str = self.json_serialize(payload)
        log.debug("operations %s", operations_str)
        data["operations"] = operations_str

        # Add the file map field
        file_map_str = self.json_serialize(file_map)
        log.debug("file_map %s", file_map_str)
        data["map"] = file_map_str

        return {"data": data, "files": file_streams}

    def _prepare_result(self, response: httpx.Response) -> ExecutionResult:
        """Convert an httpx response into an ExecutionResult.

        :raises TransportServerError: on HTTP status >= 400.
        :raises TransportProtocolError: when the body is not a GraphQL result.
        """
        # Save latest response headers in transport
        self.response_headers = response.headers

        if log.isEnabledFor(logging.DEBUG):
            log.debug("<<< %s", response.text)

        try:
            result: Dict[str, Any] = response.json()

        except Exception:
            # _raise_response_error always raises, so `result` cannot be
            # read unbound below.
            self._raise_response_error(response, "Not a JSON answer")

        if "errors" not in result and "data" not in result:
            self._raise_response_error(response, 'No "data" or "errors" keys in answer')

        return ExecutionResult(
            errors=result.get("errors"),
            data=result.get("data"),
            extensions=result.get("extensions"),
        )

    def _raise_response_error(self, response: httpx.Response, reason: str):
        # We raise a TransportServerError if the status code is 400 or higher
        # We raise a TransportProtocolError in the other cases

        try:
            # Raise a HTTPError if response status is 400 or higher
            response.raise_for_status()
        except httpx.HTTPStatusError as e:
            raise TransportServerError(str(e), e.response.status_code) from e

        raise TransportProtocolError(
            f"Server did not return a GraphQL result: " f"{reason}: " f"{response.text}"
        )
|
||||
|
||||
|
||||
class HTTPXTransport(Transport, _HTTPXTransport):
    """:ref:`Sync Transport <sync_transports>` used to execute GraphQL queries
    on remote servers.

    The transport uses the httpx library to send HTTP POST requests.
    """

    # The inner httpx client; None while disconnected.
    client: Optional[httpx.Client] = None

    def connect(self):
        """Create the inner httpx client, raising if one already exists."""
        if self.client:
            raise TransportAlreadyConnected("Transport is already connected")

        log.debug("Connecting transport")

        self.client = httpx.Client(**self.kwargs)

    def execute(  # type: ignore
        self,
        document: DocumentNode,
        variable_values: Optional[Dict[str, Any]] = None,
        operation_name: Optional[str] = None,
        extra_args: Optional[Dict[str, Any]] = None,
        upload_files: bool = False,
    ) -> ExecutionResult:
        """Execute GraphQL query.

        Execute the provided document AST against the configured remote server. This
        uses the httpx library to perform a HTTP POST request to the remote server.

        :param document: GraphQL query as AST Node object.
        :param variable_values: Dictionary of input parameters (Default: None).
        :param operation_name: Name of the operation that shall be executed.
            Only required in multi-operation documents (Default: None).
        :param extra_args: additional arguments to send to the httpx post method
        :param upload_files: Set to True if you want to put files in the variable values
        :return: The result of execution.
            `data` is the result of executing the query, `errors` is null
            if no errors occurred, and is a non-empty array if an error occurred.
        """
        if not self.client:
            raise TransportClosed("Transport is not connected")

        # Build the POST arguments, perform the request, then convert the
        # response into an ExecutionResult.
        request_kwargs = self._prepare_request(
            document,
            variable_values,
            operation_name,
            extra_args,
            upload_files,
        )

        return self._prepare_result(self.client.post(self.url, **request_kwargs))

    def close(self):
        """Closing the transport by closing the inner session"""
        if self.client is None:
            return

        self.client.close()
        self.client = None
|
||||
|
||||
|
||||
class HTTPXAsyncTransport(AsyncTransport, _HTTPXTransport):
    """:ref:`Async Transport <async_transports>` used to execute GraphQL queries
    on remote servers.

    The transport uses the httpx library with anyio.
    """

    # The inner async httpx client; None while disconnected.
    client: Optional[httpx.AsyncClient] = None

    async def connect(self):
        """Create the inner async httpx client, raising if one already exists."""
        if self.client:
            raise TransportAlreadyConnected("Transport is already connected")

        log.debug("Connecting transport")

        self.client = httpx.AsyncClient(**self.kwargs)

    async def execute(
        self,
        document: DocumentNode,
        variable_values: Optional[Dict[str, Any]] = None,
        operation_name: Optional[str] = None,
        extra_args: Optional[Dict[str, Any]] = None,
        upload_files: bool = False,
    ) -> ExecutionResult:
        """Execute GraphQL query.

        Execute the provided document AST against the configured remote server. This
        uses the httpx library to perform a HTTP POST request asynchronously to the
        remote server.

        :param document: GraphQL query as AST Node object.
        :param variable_values: Dictionary of input parameters (Default: None).
        :param operation_name: Name of the operation that shall be executed.
            Only required in multi-operation documents (Default: None).
        :param extra_args: additional arguments to send to the httpx post method
        :param upload_files: Set to True if you want to put files in the variable values
        :return: The result of execution.
            `data` is the result of executing the query, `errors` is null
            if no errors occurred, and is a non-empty array if an error occurred.
        """
        if not self.client:
            raise TransportClosed("Transport is not connected")

        # Build the POST arguments, perform the request asynchronously,
        # then convert the response into an ExecutionResult.
        request_kwargs = self._prepare_request(
            document,
            variable_values,
            operation_name,
            extra_args,
            upload_files,
        )

        response = await self.client.post(self.url, **request_kwargs)

        return self._prepare_result(response)

    async def close(self):
        """Closing the transport by closing the inner session"""
        if self.client is None:
            return

        await self.client.aclose()
        self.client = None

    def subscribe(
        self,
        document: DocumentNode,
        variable_values: Optional[Dict[str, Any]] = None,
        operation_name: Optional[str] = None,
    ) -> AsyncGenerator[ExecutionResult, None]:
        """Subscribe is not supported on HTTP.

        :meta private:
        """
        raise NotImplementedError("The HTTP transport does not support subscriptions")
|
||||
@@ -0,0 +1,78 @@
|
||||
import asyncio
|
||||
from inspect import isawaitable
|
||||
from typing import AsyncGenerator, Awaitable, cast
|
||||
|
||||
from graphql import DocumentNode, ExecutionResult, GraphQLSchema, execute, subscribe
|
||||
|
||||
from gql.transport import AsyncTransport
|
||||
|
||||
|
||||
class LocalSchemaTransport(AsyncTransport):
    """A transport for executing GraphQL queries against a local schema."""

    def __init__(
        self,
        schema: GraphQLSchema,
    ):
        """Initialize the transport with the given local schema.

        :param schema: Local schema as GraphQLSchema object
        """
        self.schema = schema

    async def connect(self):
        """No connection needed on local transport"""
        pass

    async def close(self):
        """No close needed on local transport"""
        pass

    async def execute(
        self,
        document: DocumentNode,
        *args,
        **kwargs,
    ) -> ExecutionResult:
        """Execute the provided document AST on the local GraphQL Schema.

        Extra positional and keyword arguments are forwarded to
        graphql-core's ``execute`` function.
        """

        # graphql-core's execute() may return either an ExecutionResult
        # directly or an awaitable of one, depending on the resolvers.
        result_or_awaitable = execute(self.schema, document, *args, **kwargs)

        execution_result: ExecutionResult

        if isawaitable(result_or_awaitable):
            result_or_awaitable = cast(Awaitable[ExecutionResult], result_or_awaitable)
            execution_result = await result_or_awaitable
        else:
            result_or_awaitable = cast(ExecutionResult, result_or_awaitable)
            execution_result = result_or_awaitable

        return execution_result

    @staticmethod
    async def _await_if_necessary(obj):
        """This method is necessary to work with
        graphql-core versions < and >= 3.3.0a3"""
        # NOTE(review): execute() above uses isawaitable() while this uses
        # asyncio.iscoroutine(); the distinction looks deliberate for
        # graphql-core version compatibility — confirm before unifying.
        return await obj if asyncio.iscoroutine(obj) else obj

    async def subscribe(
        self,
        document: DocumentNode,
        *args,
        **kwargs,
    ) -> AsyncGenerator[ExecutionResult, None]:
        """Send a subscription and receive the results using an async generator

        The results are sent as an ExecutionResult object
        """

        # subscribe() may return an ExecutionResult (on validation errors)
        # or an async iterator of results; handle both shapes below.
        subscribe_result = await self._await_if_necessary(
            subscribe(self.schema, document, *args, **kwargs)
        )

        if isinstance(subscribe_result, ExecutionResult):
            yield subscribe_result

        else:
            async for result in subscribe_result:
                yield result
|
||||
+414
@@ -0,0 +1,414 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
from graphql import DocumentNode, ExecutionResult, print_ast
|
||||
from websockets.exceptions import ConnectionClosed
|
||||
|
||||
from .exceptions import (
|
||||
TransportProtocolError,
|
||||
TransportQueryError,
|
||||
TransportServerError,
|
||||
)
|
||||
from .websockets_base import WebsocketsTransportBase
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Subscription:
    """Bookkeeping record pairing a subscription's listener query_id with
    the query_id of its eventual 'unsubscribe' request."""

    def __init__(self, query_id: int) -> None:
        # Id of the listener that receives the subscription data.
        self.listener_id: int = query_id
        # Filled in once an 'unsubscribe' message has been sent.
        self.unsubscribe_id: Optional[int] = None
class PhoenixChannelWebsocketsTransport(WebsocketsTransportBase):
    """The PhoenixChannelWebsocketsTransport is an async transport
    which allows you to execute queries and subscriptions against an `Absinthe`_
    backend using the `Phoenix`_ framework `channels`_.

    .. _Absinthe: http://absinthe-graphql.org
    .. _Phoenix: https://www.phoenixframework.org
    .. _channels: https://hexdocs.pm/phoenix/Phoenix.Channel.html#content
    """

    def __init__(
        self,
        channel_name: str = "__absinthe__:control",
        heartbeat_interval: float = 30,
        *args,
        **kwargs,
    ) -> None:
        """Initialize the transport with the given parameters.

        :param channel_name: Channel on the server this transport will join.
            The default for Absinthe servers is "__absinthe__:control"
        :param heartbeat_interval: Interval in second between each heartbeat messages
            sent by the client
        """
        self.channel_name: str = channel_name
        self.heartbeat_interval: float = heartbeat_interval
        # Background heartbeat task; created in _initialize, cancelled in
        # _close_coro.
        self.heartbeat_task: Optional[asyncio.Future] = None
        # Active subscriptions keyed by the server-assigned subscriptionId.
        self.subscriptions: Dict[str, Subscription] = {}
        super(PhoenixChannelWebsocketsTransport, self).__init__(*args, **kwargs)

    async def _initialize(self) -> None:
        """Join the specified channel and wait for the connection ACK.

        If the answer is not a connection_ack message, we will return an Exception.
        """

        query_id = self.next_query_id
        self.next_query_id += 1

        # "phx_join" is the Phoenix event used to join a channel topic.
        init_message = json.dumps(
            {
                "topic": self.channel_name,
                "event": "phx_join",
                "payload": {},
                "ref": query_id,
            }
        )

        await self._send(init_message)

        # Wait for the connection_ack message or raise a TimeoutError
        init_answer = await asyncio.wait_for(self._receive(), self.ack_timeout)

        answer_type, answer_id, execution_result = self._parse_answer(init_answer)

        if answer_type != "reply":
            raise TransportProtocolError(
                "Websocket server did not return a connection ack"
            )

        async def heartbeat_coro():
            # Periodically ping the reserved "phoenix" topic so the server
            # keeps the channel alive.
            while True:
                await asyncio.sleep(self.heartbeat_interval)
                try:
                    query_id = self.next_query_id
                    self.next_query_id += 1

                    await self._send(
                        json.dumps(
                            {
                                "topic": "phoenix",
                                "event": "heartbeat",
                                "payload": {},
                                "ref": query_id,
                            }
                        )
                    )
                except ConnectionClosed:  # pragma: no cover
                    return

        self.heartbeat_task = asyncio.ensure_future(heartbeat_coro())

    async def _send_stop_message(self, query_id: int) -> None:
        """Send an 'unsubscribe' message to the Phoenix Channel referencing
        the listener's query_id, saving the query_id of the message.

        The server should afterwards return a 'phx_reply' message with
        the same query_id and subscription_id of the 'unsubscribe' request.
        """
        subscription_id = self._find_existing_subscription(query_id)

        unsubscribe_query_id = self.next_query_id
        self.next_query_id += 1

        # Save the ref so it can be matched in the reply
        self.subscriptions[subscription_id].unsubscribe_id = unsubscribe_query_id
        unsubscribe_message = json.dumps(
            {
                "topic": self.channel_name,
                "event": "unsubscribe",
                "payload": {"subscriptionId": subscription_id},
                "ref": unsubscribe_query_id,
            }
        )

        await self._send(unsubscribe_message)

    async def _stop_listener(self, query_id: int) -> None:
        # Stopping a listener simply means unsubscribing on the channel.
        await self._send_stop_message(query_id)

    async def _send_connection_terminate_message(self) -> None:
        """Send a phx_leave message to disconnect from the provided channel."""

        query_id = self.next_query_id
        self.next_query_id += 1

        connection_terminate_message = json.dumps(
            {
                "topic": self.channel_name,
                "event": "phx_leave",
                "payload": {},
                "ref": query_id,
            }
        )

        await self._send(connection_terminate_message)

    async def _connection_terminate(self):
        # Hook called by the base class during clean close.
        await self._send_connection_terminate_message()

    async def _send_query(
        self,
        document: DocumentNode,
        variable_values: Optional[Dict[str, Any]] = None,
        operation_name: Optional[str] = None,
    ) -> int:
        """Send a query to the provided websocket connection.

        We use an incremented id to reference the query.

        Returns the used id for this query.
        """

        query_id = self.next_query_id
        self.next_query_id += 1

        # "doc" is the Absinthe event carrying a GraphQL document.
        query_str = json.dumps(
            {
                "topic": self.channel_name,
                "event": "doc",
                "payload": {
                    "query": print_ast(document),
                    "variables": variable_values or {},
                },
                "ref": query_id,
            }
        )

        await self._send(query_str)

        return query_id

    def _parse_answer(
        self, answer: str
    ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]:
        """Parse the answer received from the server

        Returns a list consisting of:
            - the answer_type (between:
              'data', 'reply', 'complete', 'close')
            - the answer id (Integer) if received or None
            - an execution Result if the answer_type is 'data' or None
        """

        event: str = ""
        answer_id: Optional[int] = None
        answer_type: str = ""
        execution_result: Optional[ExecutionResult] = None
        subscription_id: Optional[str] = None

        # Small validation helpers; they raise ValueError, which is
        # converted into a TransportProtocolError at the bottom.
        def _get_value(d: Any, key: str, label: str) -> Any:
            if not isinstance(d, dict):
                raise ValueError(f"{label} is not a dict")

            return d.get(key)

        def _required_value(d: Any, key: str, label: str) -> Any:
            value = _get_value(d, key, label)
            if value is None:
                raise ValueError(f"null {key} in {label}")

            return value

        def _required_subscription_id(
            d: Any, label: str, must_exist: bool = False, must_not_exist=False
        ) -> str:
            subscription_id = str(_required_value(d, "subscriptionId", label))
            if must_exist and (subscription_id not in self.subscriptions):
                raise ValueError("unregistered subscriptionId")
            if must_not_exist and (subscription_id in self.subscriptions):
                raise ValueError("previously registered subscriptionId")

            return subscription_id

        def _validate_data_response(d: Any, label: str) -> dict:
            """Make sure query, mutation or subscription answer conforms.
            The GraphQL spec says only three keys are permitted.
            """
            if not isinstance(d, dict):
                raise ValueError(f"{label} is not a dict")

            keys = set(d.keys())
            invalid = keys - {"data", "errors", "extensions"}
            if len(invalid) > 0:
                raise ValueError(
                    f"{label} contains invalid items: " + ", ".join(invalid)
                )
            return d

        try:
            json_answer = json.loads(answer)

            event = str(_required_value(json_answer, "event", "answer"))

            if event == "subscription:data":
                # Data pushed by the server for a registered subscription.
                payload = _required_value(json_answer, "payload", "answer")

                subscription_id = _required_subscription_id(
                    payload, "payload", must_exist=True
                )

                result = _validate_data_response(payload.get("result"), "result")

                answer_type = "data"

                # Route the data to the listener that created the subscription.
                subscription = self.subscriptions[subscription_id]
                answer_id = subscription.listener_id

                execution_result = ExecutionResult(
                    data=result.get("data"),
                    errors=result.get("errors"),
                    extensions=result.get("extensions"),
                )

            elif event == "phx_reply":

                # Will generate a ValueError if 'ref' is not there
                # or if it is not an integer
                answer_id = int(_required_value(json_answer, "ref", "answer"))

                payload = _required_value(json_answer, "payload", "answer")

                status = _get_value(payload, "status", "payload")

                if status == "ok":
                    answer_type = "reply"

                    if answer_id in self.listeners:
                        # Reply to a request sent by a live listener.
                        response = _required_value(payload, "response", "payload")

                        if isinstance(response, dict) and "subscriptionId" in response:

                            # Subscription answer
                            subscription_id = _required_subscription_id(
                                response, "response", must_not_exist=True
                            )

                            self.subscriptions[subscription_id] = Subscription(
                                answer_id
                            )

                        else:
                            # Query or mutation answer
                            # GraphQL spec says only three keys are permitted
                            response = _validate_data_response(response, "response")

                            answer_type = "data"

                            execution_result = ExecutionResult(
                                data=response.get("data"),
                                errors=response.get("errors"),
                                extensions=response.get("extensions"),
                            )
                    else:
                        # No listener for this ref: it may be the reply to an
                        # 'unsubscribe' request we sent earlier.
                        (
                            registered_subscription_id,
                            listener_id,
                        ) = self._find_subscription(answer_id)
                        if registered_subscription_id is not None:
                            # Unsubscription answer
                            response = _required_value(payload, "response", "payload")
                            subscription_id = _required_subscription_id(
                                response, "response"
                            )

                            if subscription_id != registered_subscription_id:
                                raise ValueError("subscription id does not match")

                            answer_type = "complete"

                            # Report completion to the original listener.
                            answer_id = listener_id

                elif status == "error":
                    response = payload.get("response")

                    # Prefer the server-provided detail when present.
                    if isinstance(response, dict):
                        if "errors" in response:
                            raise TransportQueryError(
                                str(response.get("errors")), query_id=answer_id
                            )
                        elif "reason" in response:
                            raise TransportQueryError(
                                str(response.get("reason")), query_id=answer_id
                            )
                    raise TransportQueryError("reply error", query_id=answer_id)

                elif status == "timeout":
                    raise TransportQueryError("reply timeout", query_id=answer_id)
                else:
                    # missing or unrecognized status, just continue
                    pass

            elif event == "phx_error":
                # Sent if the channel has crashed
                # answer_id will be the "join_ref" for the channel
                # answer_id = int(json_answer.get("ref"))
                raise TransportServerError("Server error")
            elif event == "phx_close":
                answer_type = "close"
            else:
                raise ValueError("unrecognized event")

        except ValueError as e:
            log.error(f"Error parsing answer '{answer}': {e!r}")
            raise TransportProtocolError(
                f"Server did not return a GraphQL result: {e!s}"
            ) from e

        return answer_type, answer_id, execution_result

    async def _handle_answer(
        self,
        answer_type: str,
        answer_id: Optional[int],
        execution_result: Optional[ExecutionResult],
    ) -> None:
        # 'close' is specific to this protocol; everything else is handled
        # by the base class.
        if answer_type == "close":
            await self.close()
        else:
            await super()._handle_answer(answer_type, answer_id, execution_result)

    def _remove_listener(self, query_id: int) -> None:
        """If the listener was a subscription, remove that information."""
        try:
            subscription_id = self._find_existing_subscription(query_id)
            del self.subscriptions[subscription_id]
        except Exception:
            # The listener was not a subscription: nothing extra to clean up.
            pass
        super()._remove_listener(query_id)

    def _find_subscription(self, query_id: int) -> Tuple[Optional[str], int]:
        """Perform a reverse lookup to find the subscription id matching
        a listener's query_id.
        """
        # query_id may be either the listener's own id or the id of the
        # 'unsubscribe' request; both map back to the same subscription.
        for subscription_id, subscription in self.subscriptions.items():
            if query_id == subscription.listener_id:
                return subscription_id, query_id
            if query_id == subscription.unsubscribe_id:
                return subscription_id, subscription.listener_id
        return None, query_id

    def _find_existing_subscription(self, query_id: int) -> str:
        """Perform a reverse lookup to find the subscription id matching
        a listener's query_id.
        """
        subscription_id, _listener_id = self._find_subscription(query_id)

        if subscription_id is None:
            raise TransportProtocolError(
                f"No subscription registered for listener {query_id}"
            )
        return subscription_id

    async def _close_coro(self, e: Exception, clean_close: bool = True) -> None:
        # Stop the heartbeat before the base class tears the socket down.
        if self.heartbeat_task is not None:
            self.heartbeat_task.cancel()

        await super()._close_coro(e, clean_close)
@@ -0,0 +1,426 @@
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Collection, Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
import requests
|
||||
from graphql import DocumentNode, ExecutionResult, print_ast
|
||||
from requests.adapters import HTTPAdapter, Retry
|
||||
from requests.auth import AuthBase
|
||||
from requests.cookies import RequestsCookieJar
|
||||
from requests_toolbelt.multipart.encoder import MultipartEncoder
|
||||
|
||||
from gql.transport import Transport
|
||||
|
||||
from ..graphql_request import GraphQLRequest
|
||||
from ..utils import extract_files
|
||||
from .exceptions import (
|
||||
TransportAlreadyConnected,
|
||||
TransportClosed,
|
||||
TransportProtocolError,
|
||||
TransportServerError,
|
||||
)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RequestsHTTPTransport(Transport):
    """:ref:`Sync Transport <sync_transports>` used to execute GraphQL queries
    on remote servers.

    The transport uses the requests library to send HTTP POST requests.
    """

    # File-like classes recognized when extracting upload files from
    # variable values.
    file_classes: Tuple[Type[Any], ...] = (io.IOBase,)
    # HTTP status codes retried by default when retries > 0.
    _default_retry_codes = (429, 500, 502, 503, 504)

    def __init__(
        self,
        url: str,
        headers: Optional[Dict[str, Any]] = None,
        cookies: Optional[Union[Dict[str, Any], RequestsCookieJar]] = None,
        auth: Optional[AuthBase] = None,
        use_json: bool = True,
        timeout: Optional[int] = None,
        verify: Union[bool, str] = True,
        retries: int = 0,
        method: str = "POST",
        retry_backoff_factor: float = 0.1,
        retry_status_forcelist: Collection[int] = _default_retry_codes,
        **kwargs: Any,
    ):
        """Initialize the transport with the given request parameters.

        :param url: The GraphQL server URL.
        :param headers: Dictionary of HTTP Headers to send with the :class:`Request`
            (Default: None).
        :param cookies: Dict or CookieJar object to send with the :class:`Request`
            (Default: None).
        :param auth: Auth tuple or callable to enable Basic/Digest/Custom HTTP Auth
            (Default: None).
        :param use_json: Send request body as JSON instead of form-urlencoded
            (Default: True).
        :param timeout: Specifies a default timeout for requests (Default: None).
        :param verify: Either a boolean, in which case it controls whether we verify
            the server's TLS certificate, or a string, in which case it must be a path
            to a CA bundle to use. (Default: True).
        :param retries: Pre-setup of the requests' Session for performing retries
        :param method: HTTP method used for requests. (Default: POST).
        :param retry_backoff_factor: A backoff factor to apply between attempts after
            the second try. urllib3 will sleep for:
            {backoff factor} * (2 ** ({number of previous retries}))
        :param retry_status_forcelist: A set of integer HTTP status codes that we
            should force a retry on. A retry is initiated if the request method is
            in allowed_methods and the response status code is in status_forcelist.
            (Default: [429, 500, 502, 503, 504])
        :param kwargs: Optional arguments that ``request`` takes.
            These can be seen at the `requests`_ source code or the official `docs`_

        .. _requests: https://github.com/psf/requests/blob/master/requests/api.py
        .. _docs: https://requests.readthedocs.io/en/master/
        """
        self.url = url
        self.headers = headers
        self.cookies = cookies
        self.auth = auth
        self.use_json = use_json
        self.default_timeout = timeout
        self.verify = verify
        self.retries = retries
        self.method = method
        self.retry_backoff_factor = retry_backoff_factor
        self.retry_status_forcelist = retry_status_forcelist
        self.kwargs = kwargs

        # Created in connect(), destroyed in close().
        self.session = None

        # Headers of the last HTTP response, kept for caller inspection.
        self.response_headers = None

    def connect(self):
        """Create the inner requests Session, optionally with retry support.

        :raises TransportAlreadyConnected: if a session already exists.
        """
        if self.session is None:
            # Creating a session that can later be re-use to configure custom mechanisms
            self.session = requests.Session()

            # If we specified some retries, we provide a predefined retry-logic
            if self.retries > 0:
                adapter = HTTPAdapter(
                    max_retries=Retry(
                        total=self.retries,
                        backoff_factor=self.retry_backoff_factor,
                        status_forcelist=self.retry_status_forcelist,
                        allowed_methods=None,
                    )
                )
                # Install the retry adapter for both schemes.
                for prefix in "http://", "https://":
                    self.session.mount(prefix, adapter)
        else:
            raise TransportAlreadyConnected("Transport is already connected")

    def execute(  # type: ignore
        self,
        document: DocumentNode,
        variable_values: Optional[Dict[str, Any]] = None,
        operation_name: Optional[str] = None,
        timeout: Optional[int] = None,
        extra_args: Optional[Dict[str, Any]] = None,
        upload_files: bool = False,
    ) -> ExecutionResult:
        """Execute GraphQL query.

        Execute the provided document AST against the configured remote server. This
        uses the requests library to perform a HTTP POST request to the remote server.

        :param document: GraphQL query as AST Node object.
        :param variable_values: Dictionary of input parameters (Default: None).
        :param operation_name: Name of the operation that shall be executed.
            Only required in multi-operation documents (Default: None).
        :param timeout: Specifies a default timeout for requests (Default: None).
        :param extra_args: additional arguments to send to the requests post method
        :param upload_files: Set to True if you want to put files in the variable values
        :return: The result of execution.
            `data` is the result of executing the query, `errors` is null
            if no errors occurred, and is a non-empty array if an error occurred.
        """

        if not self.session:
            raise TransportClosed("Transport is not connected")

        query_str = print_ast(document)
        payload: Dict[str, Any] = {"query": query_str}

        if operation_name:
            payload["operationName"] = operation_name

        post_args = {
            "headers": self.headers,
            "auth": self.auth,
            "cookies": self.cookies,
            "timeout": timeout or self.default_timeout,
            "verify": self.verify,
        }

        if upload_files:
            # If the upload_files flag is set, then we need variable_values
            assert variable_values is not None

            # If we upload files, we will extract the files present in the
            # variable_values dict and replace them by null values
            nulled_variable_values, files = extract_files(
                variables=variable_values,
                file_classes=self.file_classes,
            )

            # Save the nulled variable values in the payload
            payload["variables"] = nulled_variable_values

            # Add the payload to the operations field
            operations_str = json.dumps(payload)
            log.debug("operations %s", operations_str)

            # Generate the file map
            # path is nested in a list because the spec allows multiple pointers
            # to the same file. But we don't support that.
            # Will generate something like {"0": ["variables.file"]}
            file_map = {str(i): [path] for i, path in enumerate(files)}

            # Enumerate the file streams
            # Will generate something like {'0': <_io.BufferedReader ...>}
            file_streams = {str(i): files[path] for i, path in enumerate(files)}

            # Add the file map field
            file_map_str = json.dumps(file_map)
            log.debug("file_map %s", file_map_str)

            fields = {"operations": operations_str, "map": file_map_str}

            # Add the extracted files as remaining fields
            for k, f in file_streams.items():
                name = getattr(f, "name", k)
                content_type = getattr(f, "content_type", None)

                if content_type is None:
                    fields[k] = (name, f)
                else:
                    fields[k] = (name, f, content_type)

            # Prepare requests http to send multipart-encoded data
            data = MultipartEncoder(fields=fields)

            post_args["data"] = data

            # Copy the headers dict before mutating it so the transport-level
            # self.headers is left untouched.
            if post_args["headers"] is None:
                post_args["headers"] = {}
            else:
                post_args["headers"] = {**post_args["headers"]}

            post_args["headers"]["Content-Type"] = data.content_type

        else:
            if variable_values:
                payload["variables"] = variable_values

            data_key = "json" if self.use_json else "data"
            post_args[data_key] = payload

            # Log the payload
            if log.isEnabledFor(logging.INFO):
                log.info(">>> %s", json.dumps(payload))

        # Pass kwargs to requests post method
        post_args.update(self.kwargs)

        # Pass post_args to requests post method
        if extra_args:
            post_args.update(extra_args)

        # Using the created session to perform requests
        response = self.session.request(
            self.method, self.url, **post_args  # type: ignore
        )
        self.response_headers = response.headers

        def raise_response_error(resp: requests.Response, reason: str):
            # We raise a TransportServerError if the status code is 400 or higher
            # We raise a TransportProtocolError in the other cases

            try:
                # Raise a HTTPError if response status is 400 or higher
                resp.raise_for_status()
            except requests.HTTPError as e:
                raise TransportServerError(str(e), e.response.status_code) from e

            result_text = resp.text
            raise TransportProtocolError(
                f"Server did not return a GraphQL result: "
                f"{reason}: "
                f"{result_text}"
            )

        try:
            result = response.json()

            if log.isEnabledFor(logging.INFO):
                log.info("<<< %s", response.text)

        except Exception:
            # raise_response_error always raises, so `result` cannot be
            # unbound below.
            raise_response_error(response, "Not a JSON answer")

        if "errors" not in result and "data" not in result:
            raise_response_error(response, 'No "data" or "errors" keys in answer')

        return ExecutionResult(
            errors=result.get("errors"),
            data=result.get("data"),
            extensions=result.get("extensions"),
        )

    def execute_batch(  # type: ignore
        self,
        reqs: List[GraphQLRequest],
        timeout: Optional[int] = None,
        extra_args: Optional[Dict[str, Any]] = None,
    ) -> List[ExecutionResult]:
        """Execute multiple GraphQL requests in a batch.

        Execute the provided requests against the configured remote server. This
        uses the requests library to perform a HTTP POST request to the remote server.

        :param reqs: GraphQL requests as a list of GraphQLRequest objects.
        :param timeout: Specifies a default timeout for requests (Default: None).
        :param extra_args: additional arguments to send to the requests post method
        :return: A list of results of execution.
            For every result `data` is the result of executing the query,
            `errors` is null if no errors occurred, and is a non-empty array
            if an error occurred.
        """

        if not self.session:
            raise TransportClosed("Transport is not connected")

        # Using the created session to perform requests
        response = self.session.request(
            self.method,
            self.url,
            **self._build_batch_post_args(reqs, timeout, extra_args),
        )
        self.response_headers = response.headers

        answers = self._extract_response(response)

        # Validate the overall shape of the batch answer before converting.
        self._validate_answer_is_a_list(answers)
        self._validate_num_of_answers_same_as_requests(reqs, answers)
        self._validate_every_answer_is_a_dict(answers)
        self._validate_data_and_errors_keys_in_answers(answers)

        return [self._answer_to_execution_result(answer) for answer in answers]

    def _answer_to_execution_result(self, result: Dict[str, Any]) -> ExecutionResult:
        # Convert one raw JSON answer dict into an ExecutionResult.
        return ExecutionResult(
            errors=result.get("errors"),
            data=result.get("data"),
            extensions=result.get("extensions"),
        )

    def _validate_answer_is_a_list(self, results: Any) -> None:
        # A batch answer must be a JSON array.
        if not isinstance(results, list):
            self._raise_invalid_result(
                str(results),
                "Answer is not a list",
            )

    def _validate_data_and_errors_keys_in_answers(
        self, results: List[Dict[str, Any]]
    ) -> None:
        # Every answer must carry at least one of the two standard keys.
        for result in results:
            if "errors" not in result and "data" not in result:
                self._raise_invalid_result(
                    str(results),
                    'No "data" or "errors" keys in answer',
                )

    def _validate_every_answer_is_a_dict(self, results: List[Dict[str, Any]]) -> None:
        for result in results:
            if not isinstance(result, dict):
                self._raise_invalid_result(str(results), "Not every answer is dict")

    def _validate_num_of_answers_same_as_requests(
        self,
        reqs: List[GraphQLRequest],
        results: List[Dict[str, Any]],
    ) -> None:
        # The server must answer each request exactly once, in order.
        if len(reqs) != len(results):
            self._raise_invalid_result(
                str(results),
                "Invalid answer length",
            )

    def _raise_invalid_result(self, result_text: str, reason: str) -> None:
        raise TransportProtocolError(
            f"Server did not return a valid GraphQL result: "
            f"{reason}: "
            f"{result_text}"
        )

    def _extract_response(self, response: requests.Response) -> Any:
        # Decode the batch response body; HTTP errors become
        # TransportServerError, anything non-JSON becomes
        # TransportProtocolError.
        try:
            response.raise_for_status()
            result = response.json()

            if log.isEnabledFor(logging.INFO):
                log.info("<<< %s", response.text)

        except requests.HTTPError as e:
            raise TransportServerError(str(e), e.response.status_code) from e

        except Exception:
            self._raise_invalid_result(str(response.text), "Not a JSON answer")

        return result

    def _build_batch_post_args(
        self,
        reqs: List[GraphQLRequest],
        timeout: Optional[int] = None,
        extra_args: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        # Assemble the keyword arguments for session.request for a batch.
        post_args: Dict[str, Any] = {
            "headers": self.headers,
            "auth": self.auth,
            "cookies": self.cookies,
            "timeout": timeout or self.default_timeout,
            "verify": self.verify,
        }

        data_key = "json" if self.use_json else "data"
        post_args[data_key] = [self._build_data(req) for req in reqs]

        # Log the payload
        if log.isEnabledFor(logging.INFO):
            log.info(">>> %s", json.dumps(post_args[data_key]))

        # Pass kwargs to requests post method
        post_args.update(self.kwargs)

        # Pass post_args to requests post method
        if extra_args:
            post_args.update(extra_args)

        return post_args

    def _build_data(self, req: GraphQLRequest) -> Dict[str, Any]:
        # Serialize one GraphQLRequest into the standard GraphQL-over-HTTP
        # payload shape.
        query_str = print_ast(req.document)
        payload: Dict[str, Any] = {"query": query_str}

        if req.operation_name:
            payload["operationName"] = req.operation_name

        if req.variable_values:
            payload["variables"] = req.variable_values

        return payload

    def close(self):
        """Closing the transport by closing the inner session"""
        if self.session:
            self.session.close()
            self.session = None
@@ -0,0 +1,51 @@
|
||||
import abc
|
||||
from typing import List
|
||||
|
||||
from graphql import DocumentNode, ExecutionResult
|
||||
|
||||
from ..graphql_request import GraphQLRequest
|
||||
|
||||
|
||||
class Transport(abc.ABC):
    """Base class for synchronous GraphQL transports.

    Concrete transports must implement :meth:`execute`; the other
    methods are optional hooks with safe defaults.
    """

    @abc.abstractmethod
    def execute(self, document: DocumentNode, *args, **kwargs) -> ExecutionResult:
        """Run a single GraphQL operation.

        The provided document AST is executed against either a remote
        or a local GraphQL schema.

        :param document: GraphQL query as AST Node or Document object.
        :return: ExecutionResult
        """
        raise NotImplementedError(
            "Any Transport subclass must implement execute method"
        )  # pragma: no cover

    def execute_batch(
        self,
        reqs: List[GraphQLRequest],
        *args,
        **kwargs,
    ) -> List[ExecutionResult]:
        """Run several GraphQL operations in one batch.

        The provided requests are executed against either a remote or a
        local GraphQL schema.

        :param reqs: GraphQL requests as a list of GraphQLRequest objects.
        :return: a list of ExecutionResult objects
        """
        raise NotImplementedError(
            "This Transport has not implemented the execute_batch method"
        )  # pragma: no cover

    def connect(self):
        """Establish a session with the transport."""
        pass  # pragma: no cover

    def close(self):
        """Close the transport.

        Optional hook: transports holding resources override it (for
        example RequestsHTTPTransport closes its session's connection
        pool here); the default is a no-op.
        """
        pass  # pragma: no cover
@@ -0,0 +1,514 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from contextlib import suppress
|
||||
from ssl import SSLContext
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union, cast
|
||||
|
||||
from graphql import DocumentNode, ExecutionResult, print_ast
|
||||
from websockets.datastructures import HeadersLike
|
||||
from websockets.typing import Subprotocol
|
||||
|
||||
from .exceptions import (
|
||||
TransportProtocolError,
|
||||
TransportQueryError,
|
||||
TransportServerError,
|
||||
)
|
||||
from .websockets_base import WebsocketsTransportBase
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WebsocketsTransport(WebsocketsTransportBase):
|
||||
""":ref:`Async Transport <async_transports>` used to execute GraphQL queries on
|
||||
remote servers with websocket connection.
|
||||
|
||||
This transport uses asyncio and the websockets library in order to send requests
|
||||
on a websocket connection.
|
||||
"""
|
||||
|
||||
# This transport supports two subprotocols and will autodetect the
|
||||
# subprotocol supported on the server
|
||||
APOLLO_SUBPROTOCOL = cast(Subprotocol, "graphql-ws")
|
||||
GRAPHQLWS_SUBPROTOCOL = cast(Subprotocol, "graphql-transport-ws")
|
||||
|
||||
def __init__(
    self,
    url: str,
    headers: Optional[HeadersLike] = None,
    ssl: Union[SSLContext, bool] = False,
    init_payload: Optional[Dict[str, Any]] = None,
    connect_timeout: Optional[Union[int, float]] = 10,
    close_timeout: Optional[Union[int, float]] = 10,
    ack_timeout: Optional[Union[int, float]] = 10,
    keep_alive_timeout: Optional[Union[int, float]] = None,
    ping_interval: Optional[Union[int, float]] = None,
    pong_timeout: Optional[Union[int, float]] = None,
    answer_pings: bool = True,
    connect_args: Optional[Dict[str, Any]] = None,
    subprotocols: Optional[List[Subprotocol]] = None,
) -> None:
    """Initialize the transport with the given parameters.

    :param url: The GraphQL server URL. Example: 'wss://server.com:PORT/graphql'.
    :param headers: Dict of HTTP Headers.
    :param ssl: ssl_context of the connection. Use ssl=False to disable encryption
    :param init_payload: Dict of the payload sent in the connection_init message.
        By default: a new empty dict.
    :param connect_timeout: Timeout in seconds for the establishment
        of the websocket connection. If None is provided this will wait forever.
    :param close_timeout: Timeout in seconds for the close. If None is provided
        this will wait forever.
    :param ack_timeout: Timeout in seconds to wait for the connection_ack message
        from the server. If None is provided this will wait forever.
    :param keep_alive_timeout: Optional Timeout in seconds to receive
        a sign of liveness from the server.
    :param ping_interval: Delay in seconds between pings sent by the client to
        the backend for the graphql-ws protocol. None (by default) means that
        we don't send pings. Note: there are also pings sent by the underlying
        websockets protocol. See the
        :ref:`keepalive documentation <websockets_transport_keepalives>`
        for more information about this.
    :param pong_timeout: Delay in seconds to receive a pong from the backend
        after we sent a ping (only for the graphql-ws protocol).
        By default equal to half of the ping_interval.
    :param answer_pings: Whether the client answers the pings from the backend
        (for the graphql-ws protocol).
        By default: True
    :param connect_args: Other parameters forwarded to
        `websockets.connect <https://websockets.readthedocs.io/en/stable/reference/\
client.html#opening-a-connection>`_
        By default: a new empty dict.
    :param subprotocols: list of subprotocols sent to the
        backend in the 'subprotocols' http header.
        By default: both apollo and graphql-ws subprotocols.
    """

    # FIX: the dict parameters previously used mutable default arguments
    # ({}), which are shared between all instances; use None sentinels and
    # create a fresh dict per instance instead.
    super().__init__(
        url,
        headers,
        ssl,
        {} if init_payload is None else init_payload,
        connect_timeout,
        close_timeout,
        ack_timeout,
        keep_alive_timeout,
        {} if connect_args is None else connect_args,
    )

    self.ping_interval: Optional[Union[int, float]] = ping_interval
    self.pong_timeout: Optional[Union[int, float]]
    self.answer_pings: bool = answer_pings

    if ping_interval is not None:
        if pong_timeout is None:
            self.pong_timeout = ping_interval / 2
        else:
            self.pong_timeout = pong_timeout
    else:
        # FIX: previously left undefined when pings were disabled, which
        # could raise AttributeError on access; always define the attribute.
        self.pong_timeout = None

    self.send_ping_task: Optional[asyncio.Future] = None

    self.ping_received: asyncio.Event = asyncio.Event()
    """ping_received is an asyncio Event which will fire each time
    a ping is received with the graphql-ws protocol"""

    self.pong_received: asyncio.Event = asyncio.Event()
    """pong_received is an asyncio Event which will fire each time
    a pong is received with the graphql-ws protocol"""

    if subprotocols is None:
        self.supported_subprotocols = [
            self.APOLLO_SUBPROTOCOL,
            self.GRAPHQLWS_SUBPROTOCOL,
        ]
    else:
        self.supported_subprotocols = subprotocols
|
||||
|
||||
async def _wait_ack(self) -> None:
|
||||
"""Wait for the connection_ack message. Keep alive messages are ignored"""
|
||||
|
||||
while True:
|
||||
init_answer = await self._receive()
|
||||
|
||||
answer_type, answer_id, execution_result = self._parse_answer(init_answer)
|
||||
|
||||
if answer_type == "connection_ack":
|
||||
return
|
||||
|
||||
if answer_type != "ka":
|
||||
raise TransportProtocolError(
|
||||
"Websocket server did not return a connection ack"
|
||||
)
|
||||
|
||||
async def _send_init_message_and_wait_ack(self) -> None:
|
||||
"""Send init message to the provided websocket and wait for the connection ACK.
|
||||
|
||||
If the answer is not a connection_ack message, we will return an Exception.
|
||||
"""
|
||||
|
||||
init_message = json.dumps(
|
||||
{"type": "connection_init", "payload": self.init_payload}
|
||||
)
|
||||
|
||||
await self._send(init_message)
|
||||
|
||||
# Wait for the connection_ack message or raise a TimeoutError
|
||||
await asyncio.wait_for(self._wait_ack(), self.ack_timeout)
|
||||
|
||||
async def _initialize(self):
    """Protocol-specific connection setup: send connection_init, await ack."""
    await self._send_init_message_and_wait_ack()
|
||||
|
||||
async def send_ping(self, payload: Optional[Any] = None) -> None:
    """Send a ping message for the graphql-ws protocol"""
    message: Dict[str, Any] = {"type": "ping"}
    if payload is not None:
        message["payload"] = payload
    await self._send(json.dumps(message))
|
||||
|
||||
async def send_pong(self, payload: Optional[Any] = None) -> None:
    """Send a pong message for the graphql-ws protocol"""
    message: Dict[str, Any] = {"type": "pong"}
    if payload is not None:
        message["payload"] = payload
    await self._send(json.dumps(message))
|
||||
|
||||
async def _send_stop_message(self, query_id: int) -> None:
|
||||
"""Send stop message to the provided websocket connection and query_id.
|
||||
|
||||
The server should afterwards return a 'complete' message.
|
||||
"""
|
||||
|
||||
stop_message = json.dumps({"id": str(query_id), "type": "stop"})
|
||||
|
||||
await self._send(stop_message)
|
||||
|
||||
async def _send_complete_message(self, query_id: int) -> None:
|
||||
"""Send a complete message for the provided query_id.
|
||||
|
||||
This is only for the graphql-ws protocol.
|
||||
"""
|
||||
|
||||
complete_message = json.dumps({"id": str(query_id), "type": "complete"})
|
||||
|
||||
await self._send(complete_message)
|
||||
|
||||
async def _stop_listener(self, query_id: int):
    """Stop the listener for query_id using the negotiated subprotocol.

    For apollo: send a "stop" message
    (a "complete" message will be sent from the backend)

    For graphql-ws: send a "complete" message and simulate the reception
    of a "complete" message from the backend
    """
    log.debug(f"stop listener {query_id}")

    if self.subprotocol != self.GRAPHQLWS_SUBPROTOCOL:
        await self._send_stop_message(query_id)
    else:
        await self._send_complete_message(query_id)
        await self.listeners[query_id].put(("complete", None))
|
||||
|
||||
async def _send_connection_terminate_message(self) -> None:
|
||||
"""Send a connection_terminate message to the provided websocket connection.
|
||||
|
||||
This message indicates that the connection will disconnect.
|
||||
"""
|
||||
|
||||
connection_terminate_message = json.dumps({"type": "connection_terminate"})
|
||||
|
||||
await self._send(connection_terminate_message)
|
||||
|
||||
async def _send_query(
    self,
    document: DocumentNode,
    variable_values: Optional[Dict[str, Any]] = None,
    operation_name: Optional[str] = None,
) -> int:
    """Send a query to the provided websocket connection.

    We use an incremented id to reference the query.

    Returns the used id for this query.
    """

    query_id = self.next_query_id
    self.next_query_id += 1

    payload: Dict[str, Any] = {"query": print_ast(document)}
    if variable_values:
        payload["variables"] = variable_values
    if operation_name:
        payload["operationName"] = operation_name

    # The apollo protocol starts operations with "start";
    # graphql-ws renamed this message to "subscribe".
    query_type = "start"

    if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL:
        query_type = "subscribe"

    query_str = json.dumps(
        {"id": str(query_id), "type": query_type, "payload": payload}
    )

    await self._send(query_str)

    return query_id
|
||||
|
||||
async def _connection_terminate(self):
    # Only the apollo subprotocol defines a connection_terminate message;
    # graphql-ws simply closes the websocket.
    if self.subprotocol == self.APOLLO_SUBPROTOCOL:
        await self._send_connection_terminate_message()
|
||||
|
||||
def _parse_answer_graphqlws(
    self, json_answer: Dict[str, Any]
) -> Tuple[str, Optional[int], Optional[ExecutionResult]]:
    """Parse the answer received from the server if the server supports the
    graphql-ws protocol.

    Returns a list consisting of:
        - the answer_type (between:
          'connection_ack', 'ping', 'pong', 'data', 'error', 'complete')
        - the answer id (Integer) if received or None
        - an execution Result if the answer_type is 'data' or None

    Differences with the apollo websockets protocol (superclass):
        - the "data" message is now called "next"
        - the "stop" message is now called "complete"
        - there is no connection_terminate or connection_error messages
        - instead of a unidirectional keep-alive (ka) message from server to client,
          there is now the possibility to send bidirectional ping/pong messages
        - connection_ack has an optional payload
        - the 'error' answer type returns a list of errors instead of a single error
    """

    answer_type: str = ""
    answer_id: Optional[int] = None
    execution_result: Optional[ExecutionResult] = None

    # Any malformed message is funnelled through ValueError into a
    # TransportProtocolError below.
    try:
        answer_type = str(json_answer.get("type"))

        if answer_type in ["next", "error", "complete"]:
            answer_id = int(str(json_answer.get("id")))

            if answer_type == "next" or answer_type == "error":

                payload = json_answer.get("payload")

                if answer_type == "next":

                    if not isinstance(payload, dict):
                        raise ValueError("payload is not a dict")

                    if "errors" not in payload and "data" not in payload:
                        raise ValueError(
                            "payload does not contain 'data' or 'errors' fields"
                        )

                    execution_result = ExecutionResult(
                        errors=payload.get("errors"),
                        data=payload.get("data"),
                        extensions=payload.get("extensions"),
                    )

                    # Saving answer_type as 'data' to be understood with superclass
                    answer_type = "data"

                elif answer_type == "error":

                    if not isinstance(payload, list):
                        raise ValueError("payload is not a list")

                    raise TransportQueryError(
                        str(payload[0]), query_id=answer_id, errors=payload
                    )

        elif answer_type in ["ping", "pong", "connection_ack"]:
            # Keep the optional payload of these messages for later inspection
            self.payloads[answer_type] = json_answer.get("payload", None)

        else:
            raise ValueError

        # Any message from the server counts as a sign of liveness
        if self.check_keep_alive_task is not None:
            self._next_keep_alive_message.set()

    except ValueError as e:
        raise TransportProtocolError(
            f"Server did not return a GraphQL result: {json_answer}"
        ) from e

    return answer_type, answer_id, execution_result
|
||||
|
||||
def _parse_answer_apollo(
    self, json_answer: Dict[str, Any]
) -> Tuple[str, Optional[int], Optional[ExecutionResult]]:
    """Parse the answer received from the server if the server supports the
    apollo websockets protocol.

    Returns a list consisting of:
        - the answer_type (between:
          'connection_ack', 'ka', 'connection_error', 'data', 'error', 'complete')
        - the answer id (Integer) if received or None
        - an execution Result if the answer_type is 'data' or None
    """

    answer_type: str = ""
    answer_id: Optional[int] = None
    execution_result: Optional[ExecutionResult] = None

    # Any malformed message is funnelled through ValueError into a
    # TransportProtocolError below.
    try:
        answer_type = str(json_answer.get("type"))

        if answer_type in ["data", "error", "complete"]:
            answer_id = int(str(json_answer.get("id")))

            if answer_type == "data" or answer_type == "error":

                payload = json_answer.get("payload")

                if not isinstance(payload, dict):
                    raise ValueError("payload is not a dict")

                if answer_type == "data":

                    if "errors" not in payload and "data" not in payload:
                        raise ValueError(
                            "payload does not contain 'data' or 'errors' fields"
                        )

                    execution_result = ExecutionResult(
                        errors=payload.get("errors"),
                        data=payload.get("data"),
                        extensions=payload.get("extensions"),
                    )

                elif answer_type == "error":

                    raise TransportQueryError(
                        str(payload), query_id=answer_id, errors=[payload]
                    )

        elif answer_type == "ka":
            # Keep-alive message
            if self.check_keep_alive_task is not None:
                self._next_keep_alive_message.set()
        elif answer_type == "connection_ack":
            pass
        elif answer_type == "connection_error":
            error_payload = json_answer.get("payload")
            raise TransportServerError(f"Server error: '{repr(error_payload)}'")
        else:
            raise ValueError

    except ValueError as e:
        raise TransportProtocolError(
            f"Server did not return a GraphQL result: {json_answer}"
        ) from e

    return answer_type, answer_id, execution_result
|
||||
|
||||
def _parse_answer(
    self, answer: str
) -> Tuple[str, Optional[int], Optional[ExecutionResult]]:
    """Decode the raw message and dispatch to the parser matching the
    negotiated subprotocol.
    """
    try:
        json_answer = json.loads(answer)
    except ValueError:
        raise TransportProtocolError(
            f"Server did not return a GraphQL result: {answer}"
        )

    if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL:
        parser = self._parse_answer_graphqlws
    else:
        parser = self._parse_answer_apollo

    return parser(json_answer)
|
||||
|
||||
async def _send_ping_coro(self) -> None:
    """Coroutine to periodically send a ping from the client to the backend.

    Only used for the graphql-ws protocol.

    Send a ping every ping_interval seconds.
    Close the connection if a pong is not received within pong_timeout seconds.
    """

    assert self.ping_interval is not None

    try:
        while True:
            await asyncio.sleep(self.ping_interval)

            await self.send_ping()

            await asyncio.wait_for(self.pong_received.wait(), self.pong_timeout)

            # Reset for the next iteration
            self.pong_received.clear()

    except asyncio.TimeoutError:
        # No pong received in the appropriate time, close with error
        # If the timeout happens during a close already in progress, do nothing
        if self.close_task is None:
            await self._fail(
                TransportServerError(
                    f"No pong received after {self.pong_timeout!r} seconds"
                ),
                clean_close=False,
            )
|
||||
|
||||
async def _handle_answer(
    self,
    answer_type: str,
    answer_id: Optional[int],
    execution_result: Optional[ExecutionResult],
) -> None:
    """Route a parsed answer to its listener, then react to ping/pong."""

    # Put the answer in the queue
    await super()._handle_answer(answer_type, answer_id, execution_result)

    # Answer pong to ping for graphql-ws protocol
    if answer_type == "ping":
        self.ping_received.set()
        if self.answer_pings:
            await self.send_pong()

    elif answer_type == "pong":
        self.pong_received.set()
|
||||
|
||||
async def _after_connect(self):
    """Detect which subprotocol the backend selected for this connection."""

    # Find the backend subprotocol returned in the response headers
    response_headers = self.websocket.response_headers
    try:
        self.subprotocol = response_headers["Sec-WebSocket-Protocol"]
    except KeyError:
        # If the server does not send the subprotocol header, using
        # the apollo subprotocol by default
        self.subprotocol = self.APOLLO_SUBPROTOCOL

    log.debug(f"backend subprotocol returned: {self.subprotocol!r}")
|
||||
|
||||
async def _after_initialize(self):
    """Start the periodic client-side ping task when configured."""

    # If requested, create a task to send periodic pings to the backend
    # (only meaningful for graphql-ws: apollo has no client ping message)
    if (
        self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL
        and self.ping_interval is not None
    ):

        self.send_ping_task = asyncio.ensure_future(self._send_ping_coro())
|
||||
|
||||
async def _close_hook(self):
    """Cancel and await the ping task so it does not outlive the connection."""

    # Properly shut down the send ping task if enabled
    if self.send_ping_task is not None:
        self.send_ping_task.cancel()
        with suppress(asyncio.CancelledError):
            await self.send_ping_task
        self.send_ping_task = None
|
||||
@@ -0,0 +1,669 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import warnings
|
||||
from abc import abstractmethod
|
||||
from contextlib import suppress
|
||||
from ssl import SSLContext
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union, cast
|
||||
|
||||
import websockets
|
||||
from graphql import DocumentNode, ExecutionResult
|
||||
from websockets.client import WebSocketClientProtocol
|
||||
from websockets.datastructures import Headers, HeadersLike
|
||||
from websockets.exceptions import ConnectionClosed
|
||||
from websockets.typing import Data, Subprotocol
|
||||
|
||||
from .async_transport import AsyncTransport
|
||||
from .exceptions import (
|
||||
TransportAlreadyConnected,
|
||||
TransportClosed,
|
||||
TransportProtocolError,
|
||||
TransportQueryError,
|
||||
TransportServerError,
|
||||
)
|
||||
|
||||
log = logging.getLogger("gql.transport.websockets")
|
||||
|
||||
ParsedAnswer = Tuple[str, Optional[ExecutionResult]]
|
||||
|
||||
|
||||
class ListenerQueue:
    """Special queue used for each query waiting for server answers

    If the server is stopped while the listener is still waiting,
    Then we send an exception to the queue and this exception will be raised
    to the consumer once all the previous messages have been consumed from the queue
    """

    def __init__(self, query_id: int, send_stop: bool) -> None:
        # Id of the query this queue belongs to
        self.query_id: int = query_id
        # Whether a stop message should still be sent for this query on close
        self.send_stop: bool = send_stop
        self._queue: asyncio.Queue = asyncio.Queue()
        # Once closed, new items are silently dropped (see put())
        self._closed: bool = False

    async def get(self) -> ParsedAnswer:
        """Return the next answer, raising any exception that was queued."""

        # Fast path: drain already-queued items without awaiting
        try:
            item = self._queue.get_nowait()
        except asyncio.QueueEmpty:
            item = await self._queue.get()

        self._queue.task_done()

        # If we receive an exception when reading the queue, we raise it
        if isinstance(item, Exception):
            self._closed = True
            raise item

        # Don't need to save new answers or
        # send the stop message if we already received the complete message
        answer_type, execution_result = item
        if answer_type == "complete":
            self.send_stop = False
            self._closed = True

        return item

    async def put(self, item: ParsedAnswer) -> None:
        """Queue an answer unless the listener is already closed."""

        if not self._closed:
            await self._queue.put(item)

    async def set_exception(self, exception: Exception) -> None:
        """Close the queue with an exception to be raised by the consumer."""

        # Put the exception in the queue
        await self._queue.put(exception)

        # Don't need to send stop messages in case of error
        self.send_stop = False
        self._closed = True
|
||||
|
||||
|
||||
class WebsocketsTransportBase(AsyncTransport):
|
||||
"""abstract :ref:`Async Transport <async_transports>` used to implement
|
||||
different websockets protocols.
|
||||
|
||||
This transport uses asyncio and the websockets library in order to send requests
|
||||
on a websocket connection.
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    url: str,
    headers: Optional[HeadersLike] = None,
    ssl: Union[SSLContext, bool] = False,
    init_payload: Optional[Dict[str, Any]] = None,
    connect_timeout: Optional[Union[int, float]] = 10,
    close_timeout: Optional[Union[int, float]] = 10,
    ack_timeout: Optional[Union[int, float]] = 10,
    keep_alive_timeout: Optional[Union[int, float]] = None,
    connect_args: Optional[Dict[str, Any]] = None,
) -> None:
    """Initialize the transport with the given parameters.

    :param url: The GraphQL server URL. Example: 'wss://server.com:PORT/graphql'.
    :param headers: Dict of HTTP Headers.
    :param ssl: ssl_context of the connection. Use ssl=False to disable encryption
    :param init_payload: Dict of the payload sent in the connection_init message.
        By default: a new empty dict.
    :param connect_timeout: Timeout in seconds for the establishment
        of the websocket connection. If None is provided this will wait forever.
    :param close_timeout: Timeout in seconds for the close. If None is provided
        this will wait forever.
    :param ack_timeout: Timeout in seconds to wait for the connection_ack message
        from the server. If None is provided this will wait forever.
    :param keep_alive_timeout: Optional Timeout in seconds to receive
        a sign of liveness from the server.
    :param connect_args: Other parameters forwarded to websockets.connect.
        By default: a new empty dict.
    """

    self.url: str = url
    self.headers: Optional[HeadersLike] = headers
    self.ssl: Union[SSLContext, bool] = ssl
    # FIX: init_payload and connect_args previously used mutable default
    # arguments ({}), shared between every instance of the transport;
    # use None sentinels and build a fresh dict per instance instead.
    self.init_payload: Dict[str, Any] = {} if init_payload is None else init_payload

    self.connect_timeout: Optional[Union[int, float]] = connect_timeout
    self.close_timeout: Optional[Union[int, float]] = close_timeout
    self.ack_timeout: Optional[Union[int, float]] = ack_timeout
    self.keep_alive_timeout: Optional[Union[int, float]] = keep_alive_timeout

    self.connect_args = {} if connect_args is None else connect_args

    self.websocket: Optional[WebSocketClientProtocol] = None
    self.next_query_id: int = 1
    self.listeners: Dict[int, ListenerQueue] = {}

    self.receive_data_task: Optional[asyncio.Future] = None
    self.check_keep_alive_task: Optional[asyncio.Future] = None
    self.close_task: Optional[asyncio.Future] = None

    # We need to set an event loop here if there is none
    # Or else we will not be able to create an asyncio.Event()
    try:
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore", message="There is no current event loop"
            )
            self._loop = asyncio.get_event_loop()
    except RuntimeError:
        self._loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self._loop)

    self._wait_closed: asyncio.Event = asyncio.Event()
    self._wait_closed.set()

    self._no_more_listeners: asyncio.Event = asyncio.Event()
    self._no_more_listeners.set()

    # NOTE: only created when keep-alives are enabled; users of this event
    # must check check_keep_alive_task first (as the parsers do)
    if self.keep_alive_timeout is not None:
        self._next_keep_alive_message: asyncio.Event = asyncio.Event()
        self._next_keep_alive_message.set()

    self.payloads: Dict[str, Any] = {}
    """payloads is a dict which will contain the payloads received
    for example with the graphql-ws protocol: 'ping', 'pong', 'connection_ack'"""

    self._connecting: bool = False

    self.close_exception: Optional[Exception] = None

    # The list of supported subprotocols should be defined in the subclass
    self.supported_subprotocols: List[Subprotocol] = []

    self.response_headers: Optional[Headers] = None
|
||||
|
||||
async def _initialize(self):
    """Hook to send the initialization messages after the connection
    and potentially wait for the backend ack.
    """
    pass  # pragma: no cover
|
||||
|
||||
async def _stop_listener(self, query_id: int):
    """Hook to stop to listen to a specific query.
    Will send a stop message in some subclasses.
    """
    pass  # pragma: no cover
|
||||
|
||||
async def _after_connect(self):
    """Hook to add custom code for subclasses after the connection
    has been established.
    """
    pass  # pragma: no cover
|
||||
|
||||
async def _after_initialize(self):
    """Hook to add custom code for subclasses after the initialization
    has been done.
    """
    pass  # pragma: no cover
|
||||
|
||||
async def _close_hook(self):
    """Hook to add custom code for subclasses for the connection close"""
    pass  # pragma: no cover
|
||||
|
||||
async def _connection_terminate(self):
    """Hook to notify the backend that the connection is about to close.

    Overridden in subclasses whose protocol defines a terminate message.
    (The previous docstring was copy-pasted from _after_initialize.)
    """
    pass  # pragma: no cover
|
||||
|
||||
async def _send(self, message: str) -> None:
    """Send the provided message to the websocket connection and log the message"""
    websocket = self.websocket
    if not websocket:
        raise TransportClosed(
            "Transport is not connected"
        ) from self.close_exception

    try:
        await websocket.send(message)
        log.info(">>> %s", message)
    except ConnectionClosed as e:
        # The connection dropped under us: fail the transport, then re-raise
        await self._fail(e, clean_close=False)
        raise e
|
||||
|
||||
async def _receive(self) -> str:
    """Wait the next message from the websocket connection and log the answer"""

    # It is possible that the websocket has been already closed in another task
    if self.websocket is None:
        raise TransportClosed("Transport is already closed")

    # Wait for the next websocket frame. Can raise ConnectionClosed
    data: Data = await self.websocket.recv()

    # recv() can return either str or bytes; this transport only
    # speaks text frames, so binary data is a protocol error
    if not isinstance(data, str):
        raise TransportProtocolError("Binary data received in the websocket")

    log.info("<<< %s", data)

    return data
|
||||
|
||||
@abstractmethod
async def _send_query(
    self,
    document: DocumentNode,
    variable_values: Optional[Dict[str, Any]] = None,
    operation_name: Optional[str] = None,
) -> int:
    """Send the query on the connection and return the allocated query id."""
    raise NotImplementedError  # pragma: no cover
|
||||
|
||||
@abstractmethod
def _parse_answer(
    self, answer: str
) -> Tuple[str, Optional[int], Optional[ExecutionResult]]:
    """Parse a raw server message into (answer_type, answer_id, result)."""
    raise NotImplementedError  # pragma: no cover
|
||||
|
||||
async def _check_ws_liveness(self) -> None:
    """Coroutine which will periodically check the liveness of the connection
    through keep-alive messages
    """

    try:
        while True:
            # _next_keep_alive_message is set by the parsers whenever any
            # message arrives from the server
            await asyncio.wait_for(
                self._next_keep_alive_message.wait(), self.keep_alive_timeout
            )

            # Reset for the next iteration
            self._next_keep_alive_message.clear()

    except asyncio.TimeoutError:
        # No keep-alive message in the appropriate interval, close with error
        # while trying to notify the server of a proper close (in case
        # the keep-alive interval of the client or server was not aligned
        # the connection still remains)

        # If the timeout happens during a close already in progress, do nothing
        if self.close_task is None:
            await self._fail(
                TransportServerError(
                    "No keep-alive message has been received within "
                    "the expected interval ('keep_alive_timeout' parameter)"
                ),
                clean_close=False,
            )

    except asyncio.CancelledError:
        # The client is probably closing, handle it properly
        pass
|
||||
|
||||
async def _receive_data_loop(self) -> None:
    """Main asyncio task which will listen to the incoming messages and will
    call the parse_answer and handle_answer methods of the subclass."""
    try:
        while True:

            # Wait the next answer from the websocket server
            try:
                answer = await self._receive()
            except (ConnectionClosed, TransportProtocolError) as e:
                await self._fail(e, clean_close=False)
                break
            except TransportClosed:
                break

            # Parse the answer
            try:
                answer_type, answer_id, execution_result = self._parse_answer(
                    answer
                )
            except TransportQueryError as e:
                # Received an exception for a specific query
                # ==> Add an exception to this query queue
                # The exception is raised for this specific query,
                # but the transport is not closed.
                assert isinstance(
                    e.query_id, int
                ), "TransportQueryError should have a query_id defined here"
                try:
                    await self.listeners[e.query_id].set_exception(e)
                except KeyError:
                    # Do nothing if no one is listening to this query_id
                    pass

                continue

            except (TransportServerError, TransportProtocolError) as e:
                # Received a global exception for this transport
                # ==> close the transport
                # The exception will be raised for all current queries.
                await self._fail(e, clean_close=False)
                break

            await self._handle_answer(answer_type, answer_id, execution_result)

    finally:
        log.debug("Exiting _receive_data_loop()")
|
||||
|
||||
async def _handle_answer(
    self,
    answer_type: str,
    answer_id: Optional[int],
    execution_result: Optional[ExecutionResult],
) -> None:
    """Route a parsed answer to the queue of the query which owns it.

    Answers carrying no id, or an id no listener is registered for,
    are silently dropped.
    """
    if answer_id is None:
        return

    try:
        await self.listeners[answer_id].put((answer_type, execution_result))
    except KeyError:
        # No one is listening to this query_id anymore
        pass
|
||||
|
||||
async def subscribe(
    self,
    document: DocumentNode,
    variable_values: Optional[Dict[str, Any]] = None,
    operation_name: Optional[str] = None,
    send_stop: Optional[bool] = True,
) -> AsyncGenerator[ExecutionResult, None]:
    """Send a query and receive the results using a python async generator.

    The query can be a graphql query, mutation or subscription.

    The results are sent as an ExecutionResult object.

    :param document: the parsed GraphQL request as an AST
    :param variable_values: optional variables for the request
    :param operation_name: optional name of the operation to execute
    :param send_stop: if True (default), send a 'stop' message to the server
        when the generator is cancelled or closed before completion
    """

    # Send the query and receive the id
    query_id: int = await self._send_query(
        document, variable_values, operation_name
    )

    # Create a queue to receive the answers for this query_id
    listener = ListenerQueue(query_id, send_stop=(send_stop is True))
    self.listeners[query_id] = listener

    # We will need to wait at close for this query to clean properly
    self._no_more_listeners.clear()

    try:
        # Loop over the received answers
        while True:

            # Wait for the answer from the queue of this query_id
            # This can raise a TransportError or ConnectionClosed exception.
            answer_type, execution_result = await listener.get()

            # If the received answer contains data,
            # Then we will yield the results back as an ExecutionResult object
            if execution_result is not None:
                yield execution_result

            # If we receive a 'complete' answer from the server,
            # Then we will end this async generator output without errors
            elif answer_type == "complete":
                log.debug(
                    f"Complete received for query {query_id} --> exit without error"
                )
                break

    except (asyncio.CancelledError, GeneratorExit) as e:
        # The consumer stopped iterating early: tell the server to stop
        # sending answers for this query (unless send_stop was disabled).
        log.debug(f"Exception in subscribe: {e!r}")
        if listener.send_stop:
            await self._stop_listener(query_id)
            listener.send_stop = False

    finally:
        # Always unregister the listener so close() can finish cleanly.
        log.debug(f"In subscribe finally for query_id {query_id}")
        self._remove_listener(query_id)
|
||||
|
||||
async def execute(
    self,
    document: DocumentNode,
    variable_values: Optional[Dict[str, Any]] = None,
    operation_name: Optional[str] = None,
) -> ExecutionResult:
    """Execute the provided document AST against the configured remote server
    using the current session.

    Send a query but close the async generator as soon as we have the first answer.

    The result is sent as an ExecutionResult object.

    :param document: the parsed GraphQL request as an AST
    :param variable_values: optional variables for the request
    :param operation_name: optional name of the operation to execute
    :raises TransportQueryError: if the server completed the query without
        sending any answer
    """
    first_result = None

    # send_stop=False: a single-answer query does not need a 'stop' message
    generator = self.subscribe(
        document, variable_values, operation_name, send_stop=False
    )

    async for result in generator:
        first_result = result

        # Note: we need to run generator.aclose() here or the finally block in
        # the subscribe will not be reached in pypy3 (python version 3.6.1)
        await generator.aclose()

        break

    if first_result is None:
        raise TransportQueryError(
            "Query completed without any answer received from the server"
        )

    return first_result
|
||||
|
||||
async def connect(self) -> None:
    """Coroutine which will:

    - connect to the websocket address
    - send the init message
    - wait for the connection acknowledge from the server
    - create an asyncio task which will be used to receive
      and parse the websocket answers

    Should be cleaned with a call to the close coroutine

    :raises TransportAlreadyConnected: if the transport is already connected
        or another connect is in progress
    :raises asyncio.TimeoutError: if the connection or the server ACK takes
        longer than the configured timeouts
    """

    log.debug("connect: starting")

    if self.websocket is None and not self._connecting:

        # Set connecting to True to avoid a race condition if user is trying
        # to connect twice using the same client at the same time
        self._connecting = True

        # If the ssl parameter is not provided,
        # generate the ssl value depending on the url
        ssl: Optional[Union[SSLContext, bool]]
        if self.ssl:
            ssl = self.ssl
        else:
            # ssl=None lets the websockets library pick based on the scheme
            ssl = True if self.url.startswith("wss") else None

        # Set default arguments used in the websockets.connect call
        connect_args: Dict[str, Any] = {
            "ssl": ssl,
            "extra_headers": self.headers,
            "subprotocols": self.supported_subprotocols,
        }

        # Adding custom parameters passed from init
        # (user-provided values override the defaults above)
        connect_args.update(self.connect_args)

        # Connection to the specified url
        # Generate a TimeoutError if taking more than connect_timeout seconds
        # Set the _connecting flag to False after in all cases
        try:
            self.websocket = await asyncio.wait_for(
                websockets.client.connect(self.url, **connect_args),
                self.connect_timeout,
            )
        finally:
            self._connecting = False

        self.websocket = cast(WebSocketClientProtocol, self.websocket)

        self.response_headers = self.websocket.response_headers

        # Run the after_connect hook of the subclass
        await self._after_connect()

        # Reset the per-connection state
        self.next_query_id = 1
        self.close_exception = None
        self._wait_closed.clear()

        # Send the init message and wait for the ack from the server
        # Note: This should generate a TimeoutError
        # if no ACKs are received within the ack_timeout
        try:
            await self._initialize()
        except ConnectionClosed as e:
            raise e
        except (TransportProtocolError, asyncio.TimeoutError) as e:
            # Protocol-level failure: tear down the fresh connection
            await self._fail(e, clean_close=False)
            raise e

        # Run the after_init hook of the subclass
        await self._after_initialize()

        # If specified, create a task to check liveness of the connection
        # through keep-alive messages
        if self.keep_alive_timeout is not None:
            self.check_keep_alive_task = asyncio.ensure_future(
                self._check_ws_liveness()
            )

        # Create a task to listen to the incoming websocket messages
        self.receive_data_task = asyncio.ensure_future(self._receive_data_loop())

    else:
        raise TransportAlreadyConnected("Transport is already connected")

    log.debug("connect: done")
|
||||
|
||||
def _remove_listener(self, query_id) -> None:
    """After exiting from a subscription, remove the listener and
    signal an event if this was the last listener for the client.
    """
    # pop with a default is a no-op when the id was already removed
    self.listeners.pop(query_id, None)

    remaining = len(self.listeners)
    log.debug(f"listener {query_id} deleted, {remaining} remaining")

    if not remaining:
        # Wake up anyone waiting in _clean_close for the last query to end
        self._no_more_listeners.set()
|
||||
|
||||
async def _clean_close(self, e: Exception) -> None:
    """Coroutine which will:

    - send stop messages for each active subscription to the server
    - send the connection terminate message

    :param e: the exception which triggered the close
    """

    # Send 'stop' message for all current queries.
    # Iterate over a snapshot of the listeners: _stop_listener awaits, so a
    # listener can be removed from self.listeners concurrently (when its
    # 'complete' answer arrives) which would otherwise raise
    # "RuntimeError: dictionary changed size during iteration".
    for query_id, listener in list(self.listeners.items()):

        if listener.send_stop:
            await self._stop_listener(query_id)
            listener.send_stop = False

    # Wait that there is no more listeners (we received 'complete' for all queries)
    try:
        await asyncio.wait_for(self._no_more_listeners.wait(), self.close_timeout)
    except asyncio.TimeoutError:  # pragma: no cover
        log.debug("Timer close_timeout fired")

    # Calling the subclass hook
    await self._connection_terminate()
|
||||
|
||||
async def _close_coro(self, e: Exception, clean_close: bool = True) -> None:
    """Coroutine which will:

    - do a clean_close if possible:
        - send stop messages for each active query to the server
        - send the connection terminate message
    - close the websocket connection
    - send the exception to all the remaining listeners

    :param e: the exception which triggered the close; saved in
        ``close_exception`` and forwarded to every remaining listener
    :param clean_close: if True, try to gracefully stop the queries and
        terminate the connection before closing the websocket
    """

    log.debug("_close_coro: starting")

    try:

        # We should always have an active websocket connection here
        assert self.websocket is not None

        # Properly shut down liveness checker if enabled
        if self.check_keep_alive_task is not None:
            # More info: https://stackoverflow.com/a/43810272/1113207
            self.check_keep_alive_task.cancel()
            with suppress(asyncio.CancelledError):
                await self.check_keep_alive_task

        # Calling the subclass close hook
        await self._close_hook()

        # Saving exception to raise it later if trying to use the transport
        # after it has already closed.
        self.close_exception = e

        if clean_close:
            log.debug("_close_coro: starting clean_close")
            try:
                await self._clean_close(e)
            except Exception as exc:  # pragma: no cover
                log.warning("Ignoring exception in _clean_close: " + repr(exc))

        log.debug("_close_coro: sending exception to listeners")

        # Send an exception to all remaining listeners.
        # Iterate over a snapshot: set_exception awaits, so a listener can be
        # removed from self.listeners concurrently (subscribe's finally block),
        # which would otherwise raise
        # "RuntimeError: dictionary changed size during iteration".
        for listener in list(self.listeners.values()):
            await listener.set_exception(e)

        log.debug("_close_coro: close websocket connection")

        await self.websocket.close()

        log.debug("_close_coro: websocket connection closed")

    except Exception as exc:  # pragma: no cover
        log.warning("Exception catched in _close_coro: " + repr(exc))

    finally:

        log.debug("_close_coro: start cleanup")

        self.websocket = None
        self.close_task = None
        self.check_keep_alive_task = None
        self._wait_closed.set()

        log.debug("_close_coro: exiting")
|
||||
|
||||
async def _fail(self, e: Exception, clean_close: bool = True) -> None:
    """Schedule the closing of the transport after an error.

    Only the first call schedules a close task; later calls just log the
    new exception. The close task is shielded so it completes even if the
    caller is cancelled.
    """
    log.debug("_fail: starting with exception: " + repr(e))

    if self.close_task is not None:
        # A close is already in progress: keep it, just trace the situation
        log.debug(
            "close_task is not None in _fail. Previous exception is: "
            + repr(self.close_exception)
            + " New exception is: "
            + repr(e)
        )
        return

    if self.websocket is None:
        log.debug("_fail started with self.websocket == None -> already closed")
        return

    close_future = asyncio.ensure_future(
        self._close_coro(e, clean_close=clean_close)
    )
    self.close_task = asyncio.shield(close_future)
|
||||
|
||||
async def close(self) -> None:
    """Close the transport at the user's request and wait for completion."""
    log.debug("close: starting")

    reason = TransportClosed("Websocket GraphQL transport closed by user")
    await self._fail(reason)
    await self.wait_closed()

    log.debug("close: done")
|
||||
|
||||
async def wait_closed(self) -> None:
    """Wait until the transport has fully closed (cleanup finished)."""
    log.debug("wait_close: starting")
    await self._wait_closed.wait()
    log.debug("wait_close: done")
|
||||
@@ -0,0 +1,19 @@
|
||||
from .build_client_schema import build_client_schema
|
||||
from .get_introspection_query_ast import get_introspection_query_ast
|
||||
from .node_tree import node_tree
|
||||
from .parse_result import parse_result
|
||||
from .serialize_variable_values import serialize_value, serialize_variable_values
|
||||
from .update_schema_enum import update_schema_enum
|
||||
from .update_schema_scalars import update_schema_scalar, update_schema_scalars
|
||||
|
||||
# Public API of the gql.utilities package.
__all__ = [
    "build_client_schema",
    "node_tree",
    "parse_result",
    "get_introspection_query_ast",
    "serialize_variable_values",
    "serialize_value",
    "update_schema_enum",
    "update_schema_scalars",
    "update_schema_scalar",
]
|
||||
@@ -0,0 +1,98 @@
|
||||
from graphql import GraphQLSchema, IntrospectionQuery
|
||||
from graphql import build_client_schema as build_client_schema_orig
|
||||
from graphql.pyutils import inspect
|
||||
from graphql.utilities.get_introspection_query import (
|
||||
DirectiveLocation,
|
||||
IntrospectionDirective,
|
||||
)
|
||||
|
||||
__all__ = ["build_client_schema"]


# Introspection JSON for the built-in `@include` directive, appended to an
# introspection result when the server did not report it.
INCLUDE_DIRECTIVE_JSON: IntrospectionDirective = {
    "name": "include",
    "description": (
        "Directs the executor to include this field or fragment "
        "only when the `if` argument is true."
    ),
    "locations": [
        DirectiveLocation.FIELD,
        DirectiveLocation.FRAGMENT_SPREAD,
        DirectiveLocation.INLINE_FRAGMENT,
    ],
    "args": [
        {
            "name": "if",
            "description": "Included when true.",
            "type": {
                "kind": "NON_NULL",
                "name": "None",
                "ofType": {"kind": "SCALAR", "name": "Boolean", "ofType": "None"},
            },
            "defaultValue": "None",
        }
    ],
}

# Introspection JSON for the built-in `@skip` directive, appended to an
# introspection result when the server did not report it.
SKIP_DIRECTIVE_JSON: IntrospectionDirective = {
    "name": "skip",
    "description": (
        "Directs the executor to skip this field or fragment "
        "when the `if` argument is true."
    ),
    "locations": [
        DirectiveLocation.FIELD,
        DirectiveLocation.FRAGMENT_SPREAD,
        DirectiveLocation.INLINE_FRAGMENT,
    ],
    "args": [
        {
            "name": "if",
            "description": "Skipped when true.",
            "type": {
                "kind": "NON_NULL",
                "name": "None",
                "ofType": {"kind": "SCALAR", "name": "Boolean", "ofType": "None"},
            },
            "defaultValue": "None",
        }
    ],
}
|
||||
|
||||
|
||||
def build_client_schema(introspection: IntrospectionQuery) -> GraphQLSchema:
    """This is an alternative to the graphql-core function
    :code:`build_client_schema` but with default include and skip directives
    added to the schema to fix
    `issue #278 <https://github.com/graphql-python/gql/issues/278>`_

    .. warning::
        This function will be removed once the issue
        `graphql-js#3419 <https://github.com/graphql/graphql-js/issues/3419>`_
        has been fixed and ported to graphql-core so don't use it
        outside gql.
    """

    # Validate the shape of the introspection result before touching it
    schema_part = (
        introspection.get("__schema") if isinstance(introspection, dict) else None
    )
    if not isinstance(schema_part, dict):
        raise TypeError(
            "Invalid or incomplete introspection result. Ensure that you"
            " are passing the 'data' attribute of an introspection response"
            f" and no 'errors' were returned alongside: {inspect(introspection)}."
        )

    schema_introspection = introspection["__schema"]

    directives = schema_introspection.get("directives")
    if directives is None:
        directives = []
        schema_introspection["directives"] = directives

    # Only append the built-in directives when the server did not report them
    reported = {directive["name"] for directive in directives}

    if "skip" not in reported:
        directives.append(SKIP_DIRECTIVE_JSON)

    if "include" not in reported:
        directives.append(INCLUDE_DIRECTIVE_JSON)

    return build_client_schema_orig(introspection, assume_valid=False)
|
||||
+142
@@ -0,0 +1,142 @@
|
||||
from itertools import repeat
|
||||
|
||||
from graphql import DocumentNode, GraphQLSchema
|
||||
|
||||
from gql.dsl import DSLFragment, DSLMetaField, DSLQuery, DSLSchema, dsl_gql
|
||||
|
||||
|
||||
def get_introspection_query_ast(
    descriptions: bool = True,
    specified_by_url: bool = False,
    directive_is_repeatable: bool = False,
    schema_description: bool = False,
    input_value_deprecation: bool = False,
    type_recursion_level: int = 7,
) -> DocumentNode:
    """Get a query for introspection as a document using the DSL module.

    Equivalent to the get_introspection_query function from graphql-core
    but using the DSL module and allowing to select the recursion level.

    Optionally, you can exclude descriptions, include specification URLs,
    include repeatability of directives, and specify whether to include
    the schema description as well.

    :param descriptions: include descriptions of types, fields, etc.
    :param specified_by_url: include the specifiedByURL field of scalars
    :param directive_is_repeatable: include the isRepeatable field of directives
    :param schema_description: include the schema description (only if
        descriptions is also True)
    :param input_value_deprecation: request deprecated input values as well
    :param type_recursion_level: how many nested ofType levels to unwrap
        in the TypeRef fragment
    :return: the introspection query as a parsed DocumentNode
    """

    # An empty schema is enough: only the introspection meta-types are used
    ds = DSLSchema(GraphQLSchema())

    fragment_FullType = DSLFragment("FullType").on(ds.__Type)
    fragment_InputValue = DSLFragment("InputValue").on(ds.__InputValue)
    fragment_TypeRef = DSLFragment("TypeRef").on(ds.__Type)

    schema = DSLMetaField("__schema")

    if descriptions and schema_description:
        schema.select(ds.__Schema.description)

    schema.select(
        ds.__Schema.queryType.select(ds.__Type.name),
        ds.__Schema.mutationType.select(ds.__Type.name),
        ds.__Schema.subscriptionType.select(ds.__Type.name),
    )

    schema.select(ds.__Schema.types.select(fragment_FullType))

    directives = ds.__Schema.directives.select(ds.__Directive.name)

    # Extra keyword arguments for args/inputFields when deprecated
    # input values are requested
    deprecated_expand = {}

    if input_value_deprecation:
        deprecated_expand = {
            "includeDeprecated": True,
        }

    if descriptions:
        directives.select(ds.__Directive.description)
    if directive_is_repeatable:
        directives.select(ds.__Directive.isRepeatable)
    directives.select(
        ds.__Directive.locations,
        ds.__Directive.args(**deprecated_expand).select(fragment_InputValue),
    )

    schema.select(directives)

    # FullType fragment: kind/name plus fields, input fields, interfaces,
    # enum values and possible types
    fragment_FullType.select(
        ds.__Type.kind,
        ds.__Type.name,
    )
    if descriptions:
        fragment_FullType.select(ds.__Type.description)
    if specified_by_url:
        fragment_FullType.select(ds.__Type.specifiedByURL)

    fields = ds.__Type.fields(includeDeprecated=True).select(ds.__Field.name)

    if descriptions:
        fields.select(ds.__Field.description)

    fields.select(
        ds.__Field.args(**deprecated_expand).select(fragment_InputValue),
        ds.__Field.type.select(fragment_TypeRef),
        ds.__Field.isDeprecated,
        ds.__Field.deprecationReason,
    )

    enum_values = ds.__Type.enumValues(includeDeprecated=True).select(
        ds.__EnumValue.name
    )

    if descriptions:
        enum_values.select(ds.__EnumValue.description)

    enum_values.select(
        ds.__EnumValue.isDeprecated,
        ds.__EnumValue.deprecationReason,
    )

    fragment_FullType.select(
        fields,
        ds.__Type.inputFields(**deprecated_expand).select(fragment_InputValue),
        ds.__Type.interfaces.select(fragment_TypeRef),
        enum_values,
        ds.__Type.possibleTypes.select(fragment_TypeRef),
    )

    # InputValue fragment: name, optional description, type, default value
    fragment_InputValue.select(ds.__InputValue.name)

    if descriptions:
        fragment_InputValue.select(ds.__InputValue.description)

    fragment_InputValue.select(
        ds.__InputValue.type.select(fragment_TypeRef),
        ds.__InputValue.defaultValue,
    )

    if input_value_deprecation:
        fragment_InputValue.select(
            ds.__InputValue.isDeprecated,
            ds.__InputValue.deprecationReason,
        )

    # TypeRef fragment: unwrap nested ofType up to type_recursion_level levels
    fragment_TypeRef.select(
        ds.__Type.kind,
        ds.__Type.name,
    )

    if type_recursion_level >= 1:
        current_field = ds.__Type.ofType.select(ds.__Type.kind, ds.__Type.name)
        fragment_TypeRef.select(current_field)

        for _ in repeat(None, type_recursion_level - 1):
            new_oftype = ds.__Type.ofType.select(ds.__Type.kind, ds.__Type.name)
            current_field.select(new_oftype)
            current_field = new_oftype

    query = DSLQuery(schema)

    query.name = "IntrospectionQuery"

    dsl_query = dsl_gql(query, fragment_FullType, fragment_InputValue, fragment_TypeRef)

    return dsl_query
|
||||
@@ -0,0 +1,92 @@
|
||||
from typing import Any, Iterable, List, Optional, Sized
|
||||
|
||||
from graphql import Node
|
||||
|
||||
|
||||
def _node_tree_recursive(
|
||||
obj: Any,
|
||||
*,
|
||||
indent: int = 0,
|
||||
ignored_keys: List,
|
||||
):
|
||||
|
||||
assert ignored_keys is not None
|
||||
|
||||
results = []
|
||||
|
||||
if hasattr(obj, "__slots__"):
|
||||
|
||||
results.append(" " * indent + f"{type(obj).__name__}")
|
||||
|
||||
try:
|
||||
keys = sorted(obj.keys)
|
||||
except AttributeError:
|
||||
# If the object has no keys attribute, print its repr and return.
|
||||
results.append(" " * (indent + 1) + repr(obj))
|
||||
else:
|
||||
for key in keys:
|
||||
if key in ignored_keys:
|
||||
continue
|
||||
attr_value = getattr(obj, key, None)
|
||||
results.append(" " * (indent + 1) + f"{key}:")
|
||||
if isinstance(attr_value, Iterable) and not isinstance(
|
||||
attr_value, (str, bytes)
|
||||
):
|
||||
if isinstance(attr_value, Sized) and len(attr_value) == 0:
|
||||
results.append(
|
||||
" " * (indent + 2) + f"empty {type(attr_value).__name__}"
|
||||
)
|
||||
else:
|
||||
for item in attr_value:
|
||||
results.append(
|
||||
_node_tree_recursive(
|
||||
item,
|
||||
indent=indent + 2,
|
||||
ignored_keys=ignored_keys,
|
||||
)
|
||||
)
|
||||
else:
|
||||
results.append(
|
||||
_node_tree_recursive(
|
||||
attr_value,
|
||||
indent=indent + 2,
|
||||
ignored_keys=ignored_keys,
|
||||
)
|
||||
)
|
||||
else:
|
||||
results.append(" " * indent + repr(obj))
|
||||
|
||||
return "\n".join(results)
|
||||
|
||||
|
||||
def node_tree(
    obj: Node,
    *,
    ignore_loc: bool = True,
    ignore_block: bool = True,
    ignored_keys: Optional[List] = None,
):
    """Method which returns a tree of Node elements as a String.

    Useful to debug deep DocumentNode instances created by gql or dsl_gql.

    NOTE: from gql version 3.6.0b4 the elements of each node are sorted to ignore
    small changes in graphql-core

    WARNING: the output of this method is not guaranteed and may change without notice.

    :param obj: the graphql-core Node to render
    :param ignore_loc: do not print 'loc' attributes (default True)
    :param ignore_block: do not print 'block' attributes of StringValueNode
        (default True)
    :param ignored_keys: extra attribute names to skip; the caller's list is
        not modified
    """

    assert isinstance(obj, Node)

    if ignored_keys is None:
        ignored_keys = []
    else:
        # Work on a copy so that the "loc"/"block" appends below do not
        # mutate the list object passed in by the caller.
        ignored_keys = list(ignored_keys)

    if ignore_loc:
        # We are ignoring loc attributes by default
        ignored_keys.append("loc")

    if ignore_block:
        # We are ignoring block attributes by default (in StringValueNode)
        ignored_keys.append("block")

    return _node_tree_recursive(obj, ignored_keys=ignored_keys)
|
||||
@@ -0,0 +1,446 @@
|
||||
import logging
|
||||
from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, Union, cast
|
||||
|
||||
from graphql import (
|
||||
IDLE,
|
||||
REMOVE,
|
||||
DocumentNode,
|
||||
FieldNode,
|
||||
FragmentDefinitionNode,
|
||||
FragmentSpreadNode,
|
||||
GraphQLError,
|
||||
GraphQLInterfaceType,
|
||||
GraphQLList,
|
||||
GraphQLNonNull,
|
||||
GraphQLObjectType,
|
||||
GraphQLSchema,
|
||||
GraphQLType,
|
||||
InlineFragmentNode,
|
||||
NameNode,
|
||||
Node,
|
||||
OperationDefinitionNode,
|
||||
SelectionSetNode,
|
||||
TypeInfo,
|
||||
TypeInfoVisitor,
|
||||
Visitor,
|
||||
is_leaf_type,
|
||||
print_ast,
|
||||
visit,
|
||||
)
|
||||
from graphql.language.visitor import VisitorActionEnum
|
||||
from graphql.pyutils import inspect
|
||||
|
||||
# Module-level logger for the result-parsing visitor
log = logging.getLogger(__name__)

# Equivalent to QUERY_DOCUMENT_KEYS but only for fields interesting to
# visit to parse the results
RESULT_DOCUMENT_KEYS: Dict[str, Tuple[str, ...]] = {
    "document": ("definitions",),
    "operation_definition": ("selection_set",),
    "selection_set": ("selections",),
    "field": ("selection_set",),
    "inline_fragment": ("selection_set",),
    "fragment_definition": ("selection_set",),
}
|
||||
|
||||
|
||||
def _ignore_non_null(type_: GraphQLType):
    """Removes the GraphQLNonNull wrappings around types."""
    return type_.of_type if isinstance(type_, GraphQLNonNull) else type_
|
||||
|
||||
|
||||
def _get_fragment(document, fragment_name):
    """Returns a fragment from the document."""
    for definition in document.definitions:
        if (
            isinstance(definition, FragmentDefinitionNode)
            and definition.name.value == fragment_name
        ):
            return definition

    raise GraphQLError(f'Fragment "{fragment_name}" not found in document!')
|
||||
|
||||
|
||||
class ParseResultVisitor(Visitor):
|
||||
def __init__(
|
||||
self,
|
||||
schema: GraphQLSchema,
|
||||
document: DocumentNode,
|
||||
node: Node,
|
||||
result: Dict[str, Any],
|
||||
type_info: TypeInfo,
|
||||
visit_fragment: bool = False,
|
||||
inside_list_level: int = 0,
|
||||
operation_name: Optional[str] = None,
|
||||
):
|
||||
"""Recursive Implementation of a Visitor class to parse results
|
||||
correspondind to a schema and a document.
|
||||
|
||||
Using a TypeInfo class to get the node types during traversal.
|
||||
|
||||
If we reach a list in the results, then we parse each
|
||||
item of the list recursively, traversing the same nodes
|
||||
of the query again.
|
||||
|
||||
During traversal, we keep the current position in the result
|
||||
in the result_stack field.
|
||||
|
||||
Alongside the field type, we calculate the "result type"
|
||||
which is computed from the field type and the current
|
||||
recursive level we are for this field
|
||||
(:code:`inside_list_level` argument).
|
||||
"""
|
||||
self.schema: GraphQLSchema = schema
|
||||
self.document: DocumentNode = document
|
||||
self.node: Node = node
|
||||
self.result: Dict[str, Any] = result
|
||||
self.type_info: TypeInfo = type_info
|
||||
self.visit_fragment: bool = visit_fragment
|
||||
self.inside_list_level = inside_list_level
|
||||
self.operation_name = operation_name
|
||||
|
||||
self.result_stack: List[Any] = []
|
||||
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
def current_result(self):
|
||||
try:
|
||||
return self.result_stack[-1]
|
||||
except IndexError:
|
||||
return self.result
|
||||
|
||||
@staticmethod
|
||||
def leave_document(node: DocumentNode, *_args: Any) -> Dict[str, Any]:
|
||||
results = cast(List[Dict[str, Any]], node.definitions)
|
||||
return {k: v for result in results for k, v in result.items()}
|
||||
|
||||
def enter_operation_definition(
|
||||
self, node: OperationDefinitionNode, *_args: Any
|
||||
) -> Union[None, VisitorActionEnum]:
|
||||
|
||||
if self.operation_name is not None:
|
||||
if not hasattr(node.name, "value"):
|
||||
return REMOVE # pragma: no cover
|
||||
|
||||
node.name = cast(NameNode, node.name)
|
||||
|
||||
if node.name.value != self.operation_name:
|
||||
log.debug(f"SKIPPING operation {node.name.value}")
|
||||
return REMOVE
|
||||
|
||||
return IDLE
|
||||
|
||||
@staticmethod
|
||||
def leave_operation_definition(
|
||||
node: OperationDefinitionNode, *_args: Any
|
||||
) -> Dict[str, Any]:
|
||||
selections = cast(List[Dict[str, Any]], node.selection_set)
|
||||
return {k: v for s in selections for k, v in s.items()}
|
||||
|
||||
@staticmethod
|
||||
def leave_selection_set(node: SelectionSetNode, *_args: Any) -> Dict[str, Any]:
|
||||
partial_results = cast(Dict[str, Any], node.selections)
|
||||
return partial_results
|
||||
|
||||
@staticmethod
|
||||
def in_first_field(path):
|
||||
return path.count("selections") <= 1
|
||||
|
||||
def get_current_result_type(self, path):
|
||||
field_type = self.type_info.get_type()
|
||||
|
||||
list_level = self.inside_list_level
|
||||
|
||||
result_type = _ignore_non_null(field_type)
|
||||
|
||||
if self.in_first_field(path):
|
||||
|
||||
while list_level > 0:
|
||||
assert isinstance(result_type, GraphQLList)
|
||||
result_type = _ignore_non_null(result_type.of_type)
|
||||
|
||||
list_level -= 1
|
||||
|
||||
return result_type
|
||||
|
||||
def enter_field(
|
||||
self,
|
||||
node: FieldNode,
|
||||
key: str,
|
||||
parent: Node,
|
||||
path: List[Node],
|
||||
ancestors: List[Node],
|
||||
) -> Union[None, VisitorActionEnum, Dict[str, Any]]:
|
||||
|
||||
name = node.alias.value if node.alias else node.name.value
|
||||
|
||||
if log.isEnabledFor(logging.DEBUG):
|
||||
log.debug(f"Enter field {name}")
|
||||
log.debug(f" path={path!r}")
|
||||
log.debug(f" current_result={self.current_result!r}")
|
||||
|
||||
if self.current_result is None:
|
||||
# Result was null for this field -> remove
|
||||
return REMOVE
|
||||
|
||||
elif isinstance(self.current_result, Mapping):
|
||||
|
||||
try:
|
||||
result_value = self.current_result[name]
|
||||
except KeyError:
|
||||
# Key not found in result.
|
||||
# Should never happen in theory with a correct GraphQL backend
|
||||
# Silently ignoring this field
|
||||
log.debug(f" Key {name} not found in result --> REMOVE")
|
||||
return REMOVE
|
||||
|
||||
log.debug(f" result_value={result_value}")
|
||||
|
||||
# We get the field_type from type_info
|
||||
field_type = self.type_info.get_type()
|
||||
|
||||
# We calculate a virtual "result type" depending on our recursion level.
|
||||
result_type = self.get_current_result_type(path)
|
||||
|
||||
# If the result for this field is a list, then we need
|
||||
# to recursively visit the same node multiple times for each
|
||||
# item in the list.
|
||||
if (
|
||||
not isinstance(result_value, Mapping)
|
||||
and isinstance(result_value, Iterable)
|
||||
and not isinstance(result_value, str)
|
||||
and not is_leaf_type(result_type)
|
||||
):
|
||||
|
||||
# Finding out the inner type of the list
|
||||
inner_type = _ignore_non_null(result_type.of_type)
|
||||
|
||||
if log.isEnabledFor(logging.DEBUG):
|
||||
log.debug(" List detected:")
|
||||
log.debug(f" field_type={inspect(field_type)}")
|
||||
log.debug(f" result_type={inspect(result_type)}")
|
||||
log.debug(f" inner_type={inspect(inner_type)}\n")
|
||||
|
||||
visits: List[Dict[str, Any]] = []
|
||||
|
||||
# Get parent type
|
||||
initial_type = self.type_info.get_parent_type()
|
||||
assert isinstance(
|
||||
initial_type, (GraphQLObjectType, GraphQLInterfaceType)
|
||||
)
|
||||
|
||||
# Get parent SelectionSet node
|
||||
selection_set_node = ancestors[-1]
|
||||
assert isinstance(selection_set_node, SelectionSetNode)
|
||||
|
||||
# Keep only the current node in a new selection set node
|
||||
new_node = SelectionSetNode(selections=[node])
|
||||
|
||||
for item in result_value:
|
||||
|
||||
new_result = {name: item}
|
||||
|
||||
if log.isEnabledFor(logging.DEBUG):
|
||||
log.debug(f" recursive new_result={new_result}")
|
||||
log.debug(f" recursive ast={print_ast(node)}")
|
||||
log.debug(f" recursive path={path!r}")
|
||||
log.debug(f" recursive initial_type={initial_type!r}\n")
|
||||
|
||||
if self.in_first_field(path):
|
||||
inside_list_level = self.inside_list_level + 1
|
||||
else:
|
||||
inside_list_level = 1
|
||||
|
||||
inner_visit = parse_result_recursive(
|
||||
self.schema,
|
||||
self.document,
|
||||
new_node,
|
||||
new_result,
|
||||
initial_type=initial_type,
|
||||
inside_list_level=inside_list_level,
|
||||
)
|
||||
log.debug(f" recursive result={inner_visit}\n")
|
||||
|
||||
inner_visit = cast(List[Dict[str, Any]], inner_visit)
|
||||
visits.append(inner_visit[0][name])
|
||||
|
||||
result_value = {name: visits}
|
||||
log.debug(f" recursive visits final result = {result_value}\n")
|
||||
return result_value
|
||||
|
||||
# If the result for this field is not a list, then add it
|
||||
# to the result stack so that it becomes the current_value
|
||||
# for the next inner fields
|
||||
self.result_stack.append(result_value)
|
||||
|
||||
return IDLE
|
||||
|
||||
raise GraphQLError(
|
||||
f"Invalid result for container of field {name}: {self.current_result!r}"
|
||||
)
|
||||
|
||||
def leave_field(
|
||||
self,
|
||||
node: FieldNode,
|
||||
key: str,
|
||||
parent: Node,
|
||||
path: List[Node],
|
||||
ancestors: List[Node],
|
||||
) -> Dict[str, Any]:
|
||||
|
||||
name = cast(str, node.alias.value if node.alias else node.name.value)
|
||||
|
||||
log.debug(f"Leave field {name}")
|
||||
|
||||
if self.current_result is None:
|
||||
|
||||
return_value = None
|
||||
|
||||
elif node.selection_set is None:
|
||||
|
||||
field_type = self.type_info.get_type()
|
||||
result_type = self.get_current_result_type(path)
|
||||
|
||||
if log.isEnabledFor(logging.DEBUG):
|
||||
log.debug(f" field type of {name} is {inspect(field_type)}")
|
||||
log.debug(f" result type of {name} is {inspect(result_type)}")
|
||||
|
||||
assert is_leaf_type(result_type)
|
||||
|
||||
# Finally parsing a single scalar using the parse_value method
|
||||
return_value = result_type.parse_value(self.current_result)
|
||||
else:
|
||||
|
||||
partial_results = cast(List[Dict[str, Any]], node.selection_set)
|
||||
|
||||
return_value = {k: v for pr in partial_results for k, v in pr.items()}
|
||||
|
||||
# Go up a level in the result stack
|
||||
self.result_stack.pop()
|
||||
|
||||
log.debug(f"Leave field {name}: returning {return_value}")
|
||||
|
||||
return {name: return_value}
|
||||
|
||||
# Fragments
|
||||
|
||||
def enter_fragment_definition(
    self, node: FragmentDefinitionNode, *_args: Any
) -> Union[None, VisitorActionEnum]:
    """Decide whether a fragment definition should be descended into.

    Fragment definitions are only traversed when this visitor was started
    specifically to parse a fragment (``visit_fragment=True``); in a normal
    visit they are removed, because fragments are resolved at spread sites.
    """
    if log.isEnabledFor(logging.DEBUG):
        log.debug(f"Enter fragment definition {node.name.value}.")
        log.debug(f"visit_fragment={self.visit_fragment!s}")

    return IDLE if self.visit_fragment else REMOVE
|
||||
|
||||
@staticmethod
def leave_fragment_definition(
    node: FragmentDefinitionNode, *_args: Any
) -> Dict[str, Any]:
    """Merge the already-parsed selections of a fragment into one dict."""
    merged: Dict[str, Any] = {}
    for selection in cast(List[Dict[str, Any]], node.selection_set):
        merged.update(selection)
    return merged
|
||||
|
||||
def leave_fragment_spread(
    self, node: FragmentSpreadNode, *_args: Any
) -> Dict[str, Any]:
    """Parse a fragment spread by recursively visiting its definition.

    The fragment definition is looked up in the document and visited with a
    fresh recursive parse, starting from the current result value.
    """
    fragment_name = node.name.value

    log.debug(f"Start recursive fragment visit {fragment_name}")

    definition = _get_fragment(self.document, fragment_name)

    fragment_result = parse_result_recursive(
        self.schema,
        self.document,
        definition,
        self.current_result,
        visit_fragment=True,
    )

    log.debug(
        f"Result of recursive fragment visit {fragment_name}: {fragment_result}"
    )

    return cast(Dict[str, Any], fragment_result)
|
||||
|
||||
@staticmethod
def leave_inline_fragment(node: InlineFragmentNode, *_args: Any) -> Dict[str, Any]:
    """Merge the parsed selections of an inline fragment into one dict."""
    merged: Dict[str, Any] = {}
    for selection in cast(List[Dict[str, Any]], node.selection_set):
        merged.update(selection)
    return merged
|
||||
|
||||
|
||||
def parse_result_recursive(
    schema: GraphQLSchema,
    document: DocumentNode,
    node: Node,
    result: Optional[Dict[str, Any]],
    initial_type: Optional[GraphQLType] = None,
    inside_list_level: int = 0,
    visit_fragment: bool = False,
    operation_name: Optional[str] = None,
) -> Any:
    """Visit *node* with a ParseResultVisitor to parse a serialized result.

    :param schema: the GraphQL schema used to resolve types
    :param document: the full query document (needed to resolve fragments)
    :param node: the AST node to visit (document, operation or fragment)
    :param result: the serialized result to parse; ``None`` short-circuits
    :param initial_type: optional starting type for the TypeInfo walker
    :param inside_list_level: nesting level when parsing inside lists
    :param visit_fragment: True when visiting a fragment definition directly
    :param operation_name: optional operation to select in the document
    :returns: the parsed result, or None when there is nothing to parse
    """
    if result is None:
        return None

    type_info = TypeInfo(schema, initial_type=initial_type)

    parse_visitor = ParseResultVisitor(
        schema,
        document,
        node,
        result,
        type_info=type_info,
        inside_list_level=inside_list_level,
        visit_fragment=visit_fragment,
        operation_name=operation_name,
    )

    # The TypeInfoVisitor wrapper keeps type_info in sync during the visit.
    return visit(
        node,
        TypeInfoVisitor(type_info, parse_visitor),
        visitor_keys=RESULT_DOCUMENT_KEYS,
    )
|
||||
|
||||
|
||||
def parse_result(
    schema: GraphQLSchema,
    document: DocumentNode,
    result: Optional[Dict[str, Any]],
    operation_name: Optional[str] = None,
) -> Optional[Dict[str, Any]]:
    """Unserialize a result received from a GraphQL backend.

    Given a schema, a query and a serialized result, produce a new result
    with parsed values.

    If the result contains only built-in GraphQL scalars (String, Int,
    Float, ...) then the parsed result should be unchanged. If it contains
    custom scalars or enums, those values are parsed with the parse_value
    method of their definition in the schema.

    :param schema: the GraphQL schema
    :param document: the document representing the query sent to the backend
    :param result: the serialized result received from the backend
    :param operation_name: the optional operation name
    :returns: a parsed result with scalars and enums parsed depending on
        their definition in the schema.
    """
    parsed = parse_result_recursive(
        schema, document, document, result, operation_name=operation_name
    )
    return parsed
|
||||
@@ -0,0 +1,130 @@
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from graphql import (
|
||||
DocumentNode,
|
||||
GraphQLEnumType,
|
||||
GraphQLError,
|
||||
GraphQLInputObjectType,
|
||||
GraphQLList,
|
||||
GraphQLNonNull,
|
||||
GraphQLScalarType,
|
||||
GraphQLSchema,
|
||||
GraphQLType,
|
||||
GraphQLWrappingType,
|
||||
OperationDefinitionNode,
|
||||
type_from_ast,
|
||||
)
|
||||
from graphql.pyutils import inspect
|
||||
|
||||
|
||||
def _get_document_operation(
    document: DocumentNode, operation_name: Optional[str] = None
) -> OperationDefinitionNode:
    """Return the operation which should be executed in the document.

    Raises a GraphQLError if a single operation cannot be retrieved.
    """
    selected: Optional[OperationDefinitionNode] = None

    for definition in document.definitions:
        if not isinstance(definition, OperationDefinitionNode):
            continue

        if operation_name is None:
            # Without a name, the document must contain exactly one operation.
            if selected:
                raise GraphQLError(
                    "Must provide operation name"
                    " if query contains multiple operations."
                )
            selected = definition
        elif definition.name and definition.name.value == operation_name:
            selected = definition

    if not selected:
        if operation_name is not None:
            raise GraphQLError(f"Unknown operation named '{operation_name}'.")

        # The following line should never happen normally as the document is
        # already verified before calling this function.
        raise GraphQLError("Must provide an operation.")  # pragma: no cover

    return selected
|
||||
|
||||
|
||||
def serialize_value(type_: GraphQLType, value: Any) -> Any:
    """Return the serialized form of *value* for the GraphQL type *type_*.

    Serializes recursively, unwrapping non-null and list wrappers and
    descending into input objects. Can be used to serialize Enums and/or
    Custom Scalars in variable values.

    :param type_: the GraphQL type
    :param value: the provided value
    :raises GraphQLError: when a non-null type receives None, or when the
        type cannot be serialized at all.
    """
    if value is None:
        if isinstance(type_, GraphQLNonNull):
            raise GraphQLError(f"Type {inspect(type_)} Cannot be None.")
        return None

    if isinstance(type_, GraphQLWrappingType):
        wrapped_type = type_.of_type

        if isinstance(type_, GraphQLNonNull):
            # Non-null with a non-None value: serialize the wrapped type.
            return serialize_value(wrapped_type, value)

        if isinstance(type_, GraphQLList):
            return [serialize_value(wrapped_type, item) for item in value]

    elif isinstance(type_, (GraphQLScalarType, GraphQLEnumType)):
        return type_.serialize(value)

    elif isinstance(type_, GraphQLInputObjectType):
        # Only the keys actually present in the value are serialized.
        return {
            field_name: serialize_value(field.type, value[field_name])
            for field_name, field in type_.fields.items()
            if field_name in value
        }

    raise GraphQLError(f"Impossible to serialize value with type: {inspect(type_)}.")
|
||||
|
||||
|
||||
def serialize_variable_values(
    schema: GraphQLSchema,
    document: DocumentNode,
    variable_values: Dict[str, Any],
    operation_name: Optional[str] = None,
) -> Dict[str, Any]:
    """Serialize the dictionary of variable values for a GraphQL document.

    Useful to serialize Enums and/or Custom Scalars in variable values.

    :param schema: the GraphQL schema
    :param document: the document representing the query sent to the backend
    :param variable_values: the dictionary of variable values which needs
        to be serialized.
    :param operation_name: the optional operation_name for the query.
    """
    # Find the operation in the document
    operation = _get_document_operation(document, operation_name=operation_name)

    serialized: Dict[str, Any] = {}

    # Serialize every variable value defined for the operation
    for var_def_node in operation.variable_definitions:
        var_name = var_def_node.variable.name.value

        # Variables not supplied by the caller are simply skipped.
        if var_name not in variable_values:
            continue

        var_type = type_from_ast(schema, var_def_node.type)
        assert var_type is not None

        serialized[var_name] = serialize_value(var_type, variable_values[var_name])

    return serialized
|
||||
@@ -0,0 +1,69 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Mapping, Type, Union, cast
|
||||
|
||||
from graphql import GraphQLEnumType, GraphQLSchema
|
||||
|
||||
|
||||
def update_schema_enum(
    schema: GraphQLSchema,
    name: str,
    values: Union[Dict[str, Any], Type[Enum]],
    use_enum_values: bool = False,
):
    """Update in the schema the GraphQLEnumType corresponding to the given name.

    Example::

        from enum import Enum

        class Color(Enum):
            RED = 0
            GREEN = 1
            BLUE = 2

        update_schema_enum(schema, 'Color', Color)

    :param schema: a GraphQL Schema already containing the GraphQLEnumType type.
    :param name: the name of the enum in the GraphQL schema
    :param values: Either a Python Enum or a dict of values. The keys of the
        provided values should correspond to the keys of the existing enum
        in the schema.
    :param use_enum_values: By default, we configure the GraphQLEnumType to
        serialize to enum instances (ie: .parse_value() returns Color.RED).
        If use_enum_values is set to True, then .parse_value() returns 0.
        use_enum_values=True is the default behaviour when passing an Enum
        to a GraphQLEnumType.
    """
    # A Python Enum class is first converted into a name -> value mapping.
    if isinstance(values, type) and issubclass(values, Enum):
        enum_class = cast(Type[Enum], values)
        if use_enum_values:
            values = {member.name: member.value for member in enum_class}
        else:
            values = {member.name: member for member in enum_class}

    if not isinstance(values, Mapping):
        raise TypeError(f"Invalid type for enum values: {type(values)}")

    # Find enum type in schema
    schema_enum = schema.get_type(name)

    if schema_enum is None:
        raise KeyError(f"Enum {name} not found in schema!")

    if not isinstance(schema_enum, GraphQLEnumType):
        raise TypeError(
            f'The type "{name}" is not a GraphQLEnumType, it is a {type(schema_enum)}'
        )

    # Replace all enum values
    for enum_name, enum_value in schema_enum.values.items():
        try:
            enum_value.value = values[enum_name]
        except KeyError:
            raise KeyError(f'Enum key "{enum_name}" not found in provided values!')

    # Invalidate the cached _value_lookup property so that the new values
    # are actually used on the next lookup.
    if "_value_lookup" in schema_enum.__dict__:
        del schema_enum.__dict__["_value_lookup"]
|
||||
@@ -0,0 +1,60 @@
|
||||
from typing import Iterable, List
|
||||
|
||||
from graphql import GraphQLScalarType, GraphQLSchema
|
||||
|
||||
|
||||
def update_schema_scalar(schema: GraphQLSchema, name: str, scalar: GraphQLScalarType):
    """Update the scalar in a schema with the scalar provided.

    This can be used to update the default Custom Scalar implementation
    when the schema has been provided from a text file or from introspection.

    :param schema: the GraphQL schema
    :param name: the name of the custom scalar type in the schema
    :param scalar: a provided scalar type
    :raises TypeError: if the provided scalar, or the type found in the
        schema, is not a GraphQLScalarType.
    :raises KeyError: if no type with that name exists in the schema.
    """
    if not isinstance(scalar, GraphQLScalarType):
        raise TypeError("Scalars should be instances of GraphQLScalarType.")

    schema_scalar = schema.get_type(name)

    if schema_scalar is None:
        raise KeyError(f"Scalar '{name}' not found in schema.")

    if not isinstance(schema_scalar, GraphQLScalarType):
        raise TypeError(
            f'The type "{name}" is not a GraphQLScalarType,'
            f" it is a {type(schema_scalar)}"
        )

    # Copy the conversion methods over. setattr is used because mypy has a
    # false positive here: https://github.com/python/mypy/issues/2427
    for method_name in ("serialize", "parse_value", "parse_literal"):
        setattr(schema_scalar, method_name, getattr(scalar, method_name))
|
||||
|
||||
|
||||
def update_schema_scalars(schema: GraphQLSchema, scalars: List[GraphQLScalarType]):
    """Update the scalars in a schema with the scalars provided.

    This can be used to update the default Custom Scalar implementation
    when the schema has been provided from a text file or from introspection.

    If the name of the provided scalar is different than the name of
    the custom scalar, then you should use the
    :func:`update_schema_scalar <gql.utilities.update_schema_scalar>` method instead.

    :param schema: the GraphQL schema
    :param scalars: a list of provided scalar types
    :raises TypeError: if scalars is not iterable or contains a non-scalar.
    """
    if not isinstance(scalars, Iterable):
        raise TypeError("Scalars argument should be a list of scalars.")

    for provided_scalar in scalars:
        if not isinstance(provided_scalar, GraphQLScalarType):
            raise TypeError("Scalars should be instances of GraphQLScalarType.")

        # The scalar in the schema is matched by the provided scalar's name.
        update_schema_scalar(schema, provided_scalar.name, provided_scalar)
|
||||
@@ -0,0 +1,58 @@
|
||||
"""Utilities to manipulate several python objects."""
|
||||
|
||||
from typing import Any, Dict, List, Tuple, Type
|
||||
|
||||
|
||||
# From this response in Stackoverflow
|
||||
# http://stackoverflow.com/a/19053800/1072990
|
||||
# Adapted from this Stackoverflow answer:
# http://stackoverflow.com/a/19053800/1072990
def to_camel_case(snake_str):
    """Convert a snake_case string to camelCase.

    The first component keeps its casing; every following component is
    capitalized with str.title(). An empty component (produced by
    consecutive underscores) is rendered as a literal underscore.
    """
    first, *rest = snake_str.split("_")
    return first + "".join(part.title() if part else "_" for part in rest)
|
||||
|
||||
|
||||
def extract_files(
    variables: Dict, file_classes: Tuple[Type[Any], ...]
) -> Tuple[Dict, Dict]:
    """Split file-like objects out of a variables dict.

    Returns a deep copy of *variables* in which every instance of one of
    *file_classes* has been replaced by None, together with a dict mapping
    the dotted path of each extracted object (e.g. ``variables.upload.0``)
    to the original file object.
    """
    extracted: Dict = {}

    def _copy_without_files(path, value):
        # Recursively copy `value`; file-like objects are pulled out into
        # `extracted`, keyed by their dotted path, and replaced by None.
        if isinstance(value, list):
            return [
                _copy_without_files(f"{path}.{index}", item)
                for index, item in enumerate(value)
            ]
        if isinstance(value, dict):
            return {
                key: _copy_without_files(f"{path}.{key}", item)
                for key, item in value.items()
            }
        if isinstance(value, file_classes):
            # Extract the file object and leave a null placeholder behind.
            extracted[path] = value
            return None
        # Base case: plain values pass through unchanged.
        return value

    nulled = _copy_without_files("variables", variables)

    return nulled, extracted
|
||||
|
||||
|
||||
def str_first_element(errors: List) -> str:
    """Return a string representation of the first element of *errors*.

    If a first element cannot be retrieved — the sequence is empty
    (IndexError), it is a mapping without a 0 key (KeyError), or it is not
    subscriptable at all (TypeError) — the whole object is stringified
    instead of raising.

    :param errors: typically a list of errors received from a backend.
    """
    try:
        first_error = errors[0]
    except (IndexError, KeyError, TypeError):
        # IndexError was previously not caught, so an empty list crashed
        # here instead of falling back to str(errors).
        first_error = errors

    return str(first_error)
|
||||
Reference in New Issue
Block a user