2025-12-01
This commit is contained in:
@@ -0,0 +1,38 @@
|
||||
"""GraphQL Execution
|
||||
|
||||
The :mod:`graphql.execution` package is responsible for the execution phase of
|
||||
fulfilling a GraphQL request.
|
||||
"""
|
||||
|
||||
from .execute import (
|
||||
execute,
|
||||
execute_sync,
|
||||
default_field_resolver,
|
||||
default_type_resolver,
|
||||
ExecutionContext,
|
||||
ExecutionResult,
|
||||
FormattedExecutionResult,
|
||||
Middleware,
|
||||
)
|
||||
from .map_async_iterator import MapAsyncIterator
|
||||
from .subscribe import subscribe, create_source_event_stream
|
||||
from .middleware import MiddlewareManager
|
||||
from .values import get_argument_values, get_directive_values, get_variable_values
|
||||
|
||||
__all__ = [
|
||||
"create_source_event_stream",
|
||||
"execute",
|
||||
"execute_sync",
|
||||
"default_field_resolver",
|
||||
"default_type_resolver",
|
||||
"subscribe",
|
||||
"ExecutionContext",
|
||||
"ExecutionResult",
|
||||
"FormattedExecutionResult",
|
||||
"MapAsyncIterator",
|
||||
"Middleware",
|
||||
"MiddlewareManager",
|
||||
"get_argument_values",
|
||||
"get_directive_values",
|
||||
"get_variable_values",
|
||||
]
|
||||
@@ -0,0 +1,174 @@
|
||||
from typing import Any, Dict, List, Set, Union, cast
|
||||
|
||||
from ..language import (
|
||||
FieldNode,
|
||||
FragmentDefinitionNode,
|
||||
FragmentSpreadNode,
|
||||
InlineFragmentNode,
|
||||
SelectionSetNode,
|
||||
)
|
||||
from ..type import (
|
||||
GraphQLAbstractType,
|
||||
GraphQLIncludeDirective,
|
||||
GraphQLObjectType,
|
||||
GraphQLSchema,
|
||||
GraphQLSkipDirective,
|
||||
is_abstract_type,
|
||||
)
|
||||
from ..utilities.type_from_ast import type_from_ast
|
||||
from .values import get_directive_values
|
||||
|
||||
__all__ = ["collect_fields", "collect_sub_fields"]
|
||||
|
||||
|
||||
def collect_fields(
|
||||
schema: GraphQLSchema,
|
||||
fragments: Dict[str, FragmentDefinitionNode],
|
||||
variable_values: Dict[str, Any],
|
||||
runtime_type: GraphQLObjectType,
|
||||
selection_set: SelectionSetNode,
|
||||
) -> Dict[str, List[FieldNode]]:
|
||||
"""Collect fields.
|
||||
|
||||
Given a selection_set, collects all the fields and returns them.
|
||||
|
||||
collect_fields requires the "runtime type" of an object. For a field that
|
||||
returns an Interface or Union type, the "runtime type" will be the actual
|
||||
object type returned by that field.
|
||||
|
||||
For internal use only.
|
||||
"""
|
||||
fields: Dict[str, List[FieldNode]] = {}
|
||||
collect_fields_impl(
|
||||
schema, fragments, variable_values, runtime_type, selection_set, fields, set()
|
||||
)
|
||||
return fields
|
||||
|
||||
|
||||
def collect_sub_fields(
|
||||
schema: GraphQLSchema,
|
||||
fragments: Dict[str, FragmentDefinitionNode],
|
||||
variable_values: Dict[str, Any],
|
||||
return_type: GraphQLObjectType,
|
||||
field_nodes: List[FieldNode],
|
||||
) -> Dict[str, List[FieldNode]]:
|
||||
"""Collect sub fields.
|
||||
|
||||
Given a list of field nodes, collects all the subfields of the passed in fields,
|
||||
and returns them at the end.
|
||||
|
||||
collect_sub_fields requires the "return type" of an object. For a field that
|
||||
returns an Interface or Union type, the "return type" will be the actual
|
||||
object type returned by that field.
|
||||
|
||||
For internal use only.
|
||||
"""
|
||||
sub_field_nodes: Dict[str, List[FieldNode]] = {}
|
||||
visited_fragment_names: Set[str] = set()
|
||||
for node in field_nodes:
|
||||
if node.selection_set:
|
||||
collect_fields_impl(
|
||||
schema,
|
||||
fragments,
|
||||
variable_values,
|
||||
return_type,
|
||||
node.selection_set,
|
||||
sub_field_nodes,
|
||||
visited_fragment_names,
|
||||
)
|
||||
return sub_field_nodes
|
||||
|
||||
|
||||
def collect_fields_impl(
|
||||
schema: GraphQLSchema,
|
||||
fragments: Dict[str, FragmentDefinitionNode],
|
||||
variable_values: Dict[str, Any],
|
||||
runtime_type: GraphQLObjectType,
|
||||
selection_set: SelectionSetNode,
|
||||
fields: Dict[str, List[FieldNode]],
|
||||
visited_fragment_names: Set[str],
|
||||
) -> None:
|
||||
"""Collect fields (internal implementation)."""
|
||||
for selection in selection_set.selections:
|
||||
if isinstance(selection, FieldNode):
|
||||
if not should_include_node(variable_values, selection):
|
||||
continue
|
||||
name = get_field_entry_key(selection)
|
||||
fields.setdefault(name, []).append(selection)
|
||||
elif isinstance(selection, InlineFragmentNode):
|
||||
if not should_include_node(
|
||||
variable_values, selection
|
||||
) or not does_fragment_condition_match(schema, selection, runtime_type):
|
||||
continue
|
||||
collect_fields_impl(
|
||||
schema,
|
||||
fragments,
|
||||
variable_values,
|
||||
runtime_type,
|
||||
selection.selection_set,
|
||||
fields,
|
||||
visited_fragment_names,
|
||||
)
|
||||
elif isinstance(selection, FragmentSpreadNode): # pragma: no cover else
|
||||
frag_name = selection.name.value
|
||||
if frag_name in visited_fragment_names or not should_include_node(
|
||||
variable_values, selection
|
||||
):
|
||||
continue
|
||||
visited_fragment_names.add(frag_name)
|
||||
fragment = fragments.get(frag_name)
|
||||
if not fragment or not does_fragment_condition_match(
|
||||
schema, fragment, runtime_type
|
||||
):
|
||||
continue
|
||||
collect_fields_impl(
|
||||
schema,
|
||||
fragments,
|
||||
variable_values,
|
||||
runtime_type,
|
||||
fragment.selection_set,
|
||||
fields,
|
||||
visited_fragment_names,
|
||||
)
|
||||
|
||||
|
||||
def should_include_node(
|
||||
variable_values: Dict[str, Any],
|
||||
node: Union[FragmentSpreadNode, FieldNode, InlineFragmentNode],
|
||||
) -> bool:
|
||||
"""Check if node should be included
|
||||
|
||||
Determines if a field should be included based on the @include and @skip
|
||||
directives, where @skip has higher precedence than @include.
|
||||
"""
|
||||
skip = get_directive_values(GraphQLSkipDirective, node, variable_values)
|
||||
if skip and skip["if"]:
|
||||
return False
|
||||
|
||||
include = get_directive_values(GraphQLIncludeDirective, node, variable_values)
|
||||
if include and not include["if"]:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def does_fragment_condition_match(
|
||||
schema: GraphQLSchema,
|
||||
fragment: Union[FragmentDefinitionNode, InlineFragmentNode],
|
||||
type_: GraphQLObjectType,
|
||||
) -> bool:
|
||||
"""Determine if a fragment is applicable to the given type."""
|
||||
type_condition_node = fragment.type_condition
|
||||
if not type_condition_node:
|
||||
return True
|
||||
conditional_type = type_from_ast(schema, type_condition_node)
|
||||
if conditional_type is type_:
|
||||
return True
|
||||
if is_abstract_type(conditional_type):
|
||||
return schema.is_sub_type(cast(GraphQLAbstractType, conditional_type), type_)
|
||||
return False
|
||||
|
||||
|
||||
def get_field_entry_key(node: FieldNode) -> str:
|
||||
"""Implements the logic to compute the key of a given field's entry"""
|
||||
return node.alias.value if node.alias else node.name.value
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,115 @@
|
||||
from asyncio import CancelledError, Event, Task, ensure_future, wait
|
||||
from concurrent.futures import FIRST_COMPLETED
|
||||
from inspect import isasyncgen, isawaitable
|
||||
from types import TracebackType
|
||||
from typing import Any, AsyncIterable, Callable, Optional, Set, Type, Union
|
||||
|
||||
__all__ = ["MapAsyncIterator"]
|
||||
|
||||
|
||||
# noinspection PyAttributeOutsideInit
|
||||
class MapAsyncIterator:
|
||||
"""Map an AsyncIterable over a callback function.
|
||||
|
||||
Given an AsyncIterable and a callback function, return an AsyncIterator which
|
||||
produces values mapped via calling the callback function.
|
||||
|
||||
When the resulting AsyncIterator is closed, the underlying AsyncIterable will also
|
||||
be closed.
|
||||
"""
|
||||
|
||||
def __init__(self, iterable: AsyncIterable, callback: Callable) -> None:
|
||||
self.iterator = iterable.__aiter__()
|
||||
self.callback = callback
|
||||
self._close_event = Event()
|
||||
|
||||
def __aiter__(self) -> "MapAsyncIterator":
|
||||
"""Get the iterator object."""
|
||||
return self
|
||||
|
||||
async def __anext__(self) -> Any:
|
||||
"""Get the next value of the iterator."""
|
||||
if self.is_closed:
|
||||
if not isasyncgen(self.iterator):
|
||||
raise StopAsyncIteration
|
||||
value = await self.iterator.__anext__()
|
||||
else:
|
||||
aclose = ensure_future(self._close_event.wait())
|
||||
anext = ensure_future(self.iterator.__anext__())
|
||||
|
||||
try:
|
||||
pending: Set[Task] = (
|
||||
await wait([aclose, anext], return_when=FIRST_COMPLETED)
|
||||
)[1]
|
||||
except CancelledError:
|
||||
# cancel underlying tasks and close
|
||||
aclose.cancel()
|
||||
anext.cancel()
|
||||
await self.aclose()
|
||||
raise # re-raise the cancellation
|
||||
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
|
||||
if aclose.done():
|
||||
raise StopAsyncIteration
|
||||
|
||||
error = anext.exception()
|
||||
if error:
|
||||
raise error
|
||||
|
||||
value = anext.result()
|
||||
|
||||
result = self.callback(value)
|
||||
|
||||
return await result if isawaitable(result) else result
|
||||
|
||||
async def athrow(
|
||||
self,
|
||||
type_: Union[BaseException, Type[BaseException]],
|
||||
value: Optional[BaseException] = None,
|
||||
traceback: Optional[TracebackType] = None,
|
||||
) -> None:
|
||||
"""Throw an exception into the asynchronous iterator."""
|
||||
if self.is_closed:
|
||||
return
|
||||
if isinstance(type_, BaseException):
|
||||
value = type_
|
||||
type_ = type(value)
|
||||
traceback = value.__traceback__
|
||||
athrow = getattr(self.iterator, "athrow", None)
|
||||
if athrow:
|
||||
await athrow(type_ if value is None else value)
|
||||
else:
|
||||
await self.aclose()
|
||||
if value is None:
|
||||
if traceback is None:
|
||||
raise type_ # pragma: no cover
|
||||
value = type_ if isinstance(value, BaseException) else type_()
|
||||
if traceback is not None:
|
||||
value = value.with_traceback(traceback)
|
||||
raise value
|
||||
|
||||
async def aclose(self) -> None:
|
||||
"""Close the iterator."""
|
||||
if not self.is_closed:
|
||||
aclose = getattr(self.iterator, "aclose", None)
|
||||
if aclose:
|
||||
try:
|
||||
await aclose()
|
||||
except RuntimeError:
|
||||
pass
|
||||
self.is_closed = True
|
||||
|
||||
@property
|
||||
def is_closed(self) -> bool:
|
||||
"""Check whether the iterator is closed."""
|
||||
return self._close_event.is_set()
|
||||
|
||||
@is_closed.setter
|
||||
def is_closed(self, value: bool) -> None:
|
||||
"""Mark the iterator as closed."""
|
||||
if value:
|
||||
self._close_event.set()
|
||||
else:
|
||||
self._close_event.clear()
|
||||
@@ -0,0 +1,63 @@
|
||||
from functools import partial, reduce
|
||||
from inspect import isfunction
|
||||
|
||||
from typing import Callable, Iterator, Dict, List, Tuple, Any, Optional
|
||||
|
||||
__all__ = ["MiddlewareManager"]
|
||||
|
||||
GraphQLFieldResolver = Callable[..., Any]
|
||||
|
||||
|
||||
class MiddlewareManager:
|
||||
"""Manager for the middleware chain.
|
||||
|
||||
This class helps to wrap resolver functions with the provided middleware functions
|
||||
and/or objects. The functions take the next middleware function as first argument.
|
||||
If middleware is provided as an object, it must provide a method ``resolve`` that is
|
||||
used as the middleware function.
|
||||
|
||||
Note that since resolvers return "AwaitableOrValue"s, all middleware functions
|
||||
must be aware of this and check whether values are awaitable before awaiting them.
|
||||
"""
|
||||
|
||||
# allow custom attributes (not used internally)
|
||||
__slots__ = "__dict__", "middlewares", "_middleware_resolvers", "_cached_resolvers"
|
||||
|
||||
_cached_resolvers: Dict[GraphQLFieldResolver, GraphQLFieldResolver]
|
||||
_middleware_resolvers: Optional[List[Callable]]
|
||||
|
||||
def __init__(self, *middlewares: Any):
|
||||
self.middlewares = middlewares
|
||||
self._middleware_resolvers = (
|
||||
list(get_middleware_resolvers(middlewares)) if middlewares else None
|
||||
)
|
||||
self._cached_resolvers = {}
|
||||
|
||||
def get_field_resolver(
|
||||
self, field_resolver: GraphQLFieldResolver
|
||||
) -> GraphQLFieldResolver:
|
||||
"""Wrap the provided resolver with the middleware.
|
||||
|
||||
Returns a function that chains the middleware functions with the provided
|
||||
resolver function.
|
||||
"""
|
||||
if self._middleware_resolvers is None:
|
||||
return field_resolver
|
||||
if field_resolver not in self._cached_resolvers:
|
||||
self._cached_resolvers[field_resolver] = reduce(
|
||||
lambda chained_fns, next_fn: partial(next_fn, chained_fns),
|
||||
self._middleware_resolvers,
|
||||
field_resolver,
|
||||
)
|
||||
return self._cached_resolvers[field_resolver]
|
||||
|
||||
|
||||
def get_middleware_resolvers(middlewares: Tuple[Any, ...]) -> Iterator[Callable]:
|
||||
"""Get a list of resolver functions from a list of classes or functions."""
|
||||
for middleware in middlewares:
|
||||
if isfunction(middleware):
|
||||
yield middleware
|
||||
else: # middleware provided as object with 'resolve' method
|
||||
resolver_func = getattr(middleware, "resolve", None)
|
||||
if resolver_func is not None:
|
||||
yield resolver_func
|
||||
@@ -0,0 +1,212 @@
|
||||
from inspect import isawaitable
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterable,
|
||||
AsyncIterator,
|
||||
Dict,
|
||||
Optional,
|
||||
Union,
|
||||
)
|
||||
|
||||
from ..error import GraphQLError, located_error
|
||||
from ..execution.collect_fields import collect_fields
|
||||
from ..execution.execute import (
|
||||
assert_valid_execution_arguments,
|
||||
execute,
|
||||
get_field_def,
|
||||
ExecutionContext,
|
||||
ExecutionResult,
|
||||
)
|
||||
from ..execution.values import get_argument_values
|
||||
from ..language import DocumentNode
|
||||
from ..pyutils import Path, inspect
|
||||
from ..type import GraphQLFieldResolver, GraphQLSchema
|
||||
from .map_async_iterator import MapAsyncIterator
|
||||
|
||||
__all__ = ["subscribe", "create_source_event_stream"]
|
||||
|
||||
|
||||
async def subscribe(
|
||||
schema: GraphQLSchema,
|
||||
document: DocumentNode,
|
||||
root_value: Any = None,
|
||||
context_value: Any = None,
|
||||
variable_values: Optional[Dict[str, Any]] = None,
|
||||
operation_name: Optional[str] = None,
|
||||
field_resolver: Optional[GraphQLFieldResolver] = None,
|
||||
subscribe_field_resolver: Optional[GraphQLFieldResolver] = None,
|
||||
) -> Union[AsyncIterator[ExecutionResult], ExecutionResult]:
|
||||
"""Create a GraphQL subscription.
|
||||
|
||||
Implements the "Subscribe" algorithm described in the GraphQL spec.
|
||||
|
||||
Returns a coroutine object which yields either an AsyncIterator (if successful) or
|
||||
an ExecutionResult (client error). The coroutine will raise an exception if a server
|
||||
error occurs.
|
||||
|
||||
If the client-provided arguments to this function do not result in a compliant
|
||||
subscription, a GraphQL Response (ExecutionResult) with descriptive errors and no
|
||||
data will be returned.
|
||||
|
||||
If the source stream could not be created due to faulty subscription resolver logic
|
||||
or underlying systems, the coroutine object will yield a single ExecutionResult
|
||||
containing ``errors`` and no ``data``.
|
||||
|
||||
If the operation succeeded, the coroutine will yield an AsyncIterator, which yields
|
||||
a stream of ExecutionResults representing the response stream.
|
||||
"""
|
||||
result_or_stream = await create_source_event_stream(
|
||||
schema,
|
||||
document,
|
||||
root_value,
|
||||
context_value,
|
||||
variable_values,
|
||||
operation_name,
|
||||
subscribe_field_resolver,
|
||||
)
|
||||
if isinstance(result_or_stream, ExecutionResult):
|
||||
return result_or_stream
|
||||
|
||||
async def map_source_to_response(payload: Any) -> ExecutionResult:
|
||||
"""Map source to response.
|
||||
|
||||
For each payload yielded from a subscription, map it over the normal GraphQL
|
||||
:func:`~graphql.execute` function, with ``payload`` as the ``root_value``.
|
||||
This implements the "MapSourceToResponseEvent" algorithm described in the
|
||||
GraphQL specification. The :func:`~graphql.execute` function provides the
|
||||
"ExecuteSubscriptionEvent" algorithm, as it is nearly identical to the
|
||||
"ExecuteQuery" algorithm, for which :func:`~graphql.execute` is also used.
|
||||
"""
|
||||
result = execute(
|
||||
schema,
|
||||
document,
|
||||
payload,
|
||||
context_value,
|
||||
variable_values,
|
||||
operation_name,
|
||||
field_resolver,
|
||||
)
|
||||
return await result if isawaitable(result) else result
|
||||
|
||||
# Map every source value to a ExecutionResult value as described above.
|
||||
return MapAsyncIterator(result_or_stream, map_source_to_response)
|
||||
|
||||
|
||||
async def create_source_event_stream(
|
||||
schema: GraphQLSchema,
|
||||
document: DocumentNode,
|
||||
root_value: Any = None,
|
||||
context_value: Any = None,
|
||||
variable_values: Optional[Dict[str, Any]] = None,
|
||||
operation_name: Optional[str] = None,
|
||||
subscribe_field_resolver: Optional[GraphQLFieldResolver] = None,
|
||||
) -> Union[AsyncIterable[Any], ExecutionResult]:
|
||||
"""Create source event stream
|
||||
|
||||
Implements the "CreateSourceEventStream" algorithm described in the GraphQL
|
||||
specification, resolving the subscription source event stream.
|
||||
|
||||
Returns a coroutine that yields an AsyncIterable.
|
||||
|
||||
If the client-provided arguments to this function do not result in a compliant
|
||||
subscription, a GraphQL Response (ExecutionResult) with descriptive errors and no
|
||||
data will be returned.
|
||||
|
||||
If the source stream could not be created due to faulty subscription resolver logic
|
||||
or underlying systems, the coroutine object will yield a single ExecutionResult
|
||||
containing ``errors`` and no ``data``.
|
||||
|
||||
A source event stream represents a sequence of events, each of which triggers a
|
||||
GraphQL execution for that event.
|
||||
|
||||
This may be useful when hosting the stateful subscription service in a different
|
||||
process or machine than the stateless GraphQL execution engine, or otherwise
|
||||
separating these two steps. For more on this, see the "Supporting Subscriptions
|
||||
at Scale" information in the GraphQL spec.
|
||||
"""
|
||||
# If arguments are missing or incorrectly typed, this is an internal developer
|
||||
# mistake which should throw an early error.
|
||||
assert_valid_execution_arguments(schema, document, variable_values)
|
||||
|
||||
# If a valid context cannot be created due to incorrect arguments,
|
||||
# a "Response" with only errors is returned.
|
||||
context = ExecutionContext.build(
|
||||
schema,
|
||||
document,
|
||||
root_value,
|
||||
context_value,
|
||||
variable_values,
|
||||
operation_name,
|
||||
subscribe_field_resolver=subscribe_field_resolver,
|
||||
)
|
||||
|
||||
# Return early errors if execution context failed.
|
||||
if isinstance(context, list):
|
||||
return ExecutionResult(data=None, errors=context)
|
||||
|
||||
try:
|
||||
event_stream = await execute_subscription(context)
|
||||
|
||||
# Assert field returned an event stream, otherwise yield an error.
|
||||
if not isinstance(event_stream, AsyncIterable):
|
||||
raise TypeError(
|
||||
"Subscription field must return AsyncIterable."
|
||||
f" Received: {inspect(event_stream)}."
|
||||
)
|
||||
return event_stream
|
||||
|
||||
except GraphQLError as error:
|
||||
# Report it as an ExecutionResult, containing only errors and no data.
|
||||
return ExecutionResult(data=None, errors=[error])
|
||||
|
||||
|
||||
async def execute_subscription(context: ExecutionContext) -> AsyncIterable[Any]:
|
||||
schema = context.schema
|
||||
|
||||
root_type = schema.subscription_type
|
||||
if root_type is None:
|
||||
raise GraphQLError(
|
||||
"Schema is not configured to execute subscription operation.",
|
||||
context.operation,
|
||||
)
|
||||
|
||||
root_fields = collect_fields(
|
||||
schema,
|
||||
context.fragments,
|
||||
context.variable_values,
|
||||
root_type,
|
||||
context.operation.selection_set,
|
||||
)
|
||||
response_name, field_nodes = next(iter(root_fields.items()))
|
||||
field_def = get_field_def(schema, root_type, field_nodes[0])
|
||||
|
||||
if not field_def:
|
||||
field_name = field_nodes[0].name.value
|
||||
raise GraphQLError(
|
||||
f"The subscription field '{field_name}' is not defined.", field_nodes
|
||||
)
|
||||
|
||||
path = Path(None, response_name, root_type.name)
|
||||
info = context.build_resolve_info(field_def, field_nodes, root_type, path)
|
||||
|
||||
# Implements the "ResolveFieldEventStream" algorithm from GraphQL specification.
|
||||
# It differs from "ResolveFieldValue" due to providing a different `resolveFn`.
|
||||
|
||||
try:
|
||||
# Build a dictionary of arguments from the field.arguments AST, using the
|
||||
# variables scope to fulfill any variable references.
|
||||
args = get_argument_values(field_def, field_nodes[0], context.variable_values)
|
||||
|
||||
# Call the `subscribe()` resolver or the default resolver to produce an
|
||||
# AsyncIterable yielding raw payloads.
|
||||
resolve_fn = field_def.subscribe or context.subscribe_field_resolver
|
||||
|
||||
event_stream = resolve_fn(context.root_value, info, **args)
|
||||
if context.is_awaitable(event_stream):
|
||||
event_stream = await event_stream
|
||||
if isinstance(event_stream, Exception):
|
||||
raise event_stream
|
||||
|
||||
return event_stream
|
||||
except Exception as error:
|
||||
raise located_error(error, field_nodes, path.as_list())
|
||||
@@ -0,0 +1,251 @@
|
||||
from typing import Any, Callable, Collection, Dict, List, Optional, Union, cast
|
||||
|
||||
from ..error import GraphQLError
|
||||
from ..language import (
|
||||
DirectiveNode,
|
||||
EnumValueDefinitionNode,
|
||||
ExecutableDefinitionNode,
|
||||
FieldDefinitionNode,
|
||||
FieldNode,
|
||||
InputValueDefinitionNode,
|
||||
NullValueNode,
|
||||
SchemaDefinitionNode,
|
||||
SelectionNode,
|
||||
TypeDefinitionNode,
|
||||
TypeExtensionNode,
|
||||
VariableDefinitionNode,
|
||||
VariableNode,
|
||||
print_ast,
|
||||
)
|
||||
from ..pyutils import Undefined, inspect, print_path_list
|
||||
from ..type import (
|
||||
GraphQLDirective,
|
||||
GraphQLField,
|
||||
GraphQLInputType,
|
||||
GraphQLSchema,
|
||||
is_input_object_type,
|
||||
is_input_type,
|
||||
is_non_null_type,
|
||||
)
|
||||
from ..utilities.coerce_input_value import coerce_input_value
|
||||
from ..utilities.type_from_ast import type_from_ast
|
||||
from ..utilities.value_from_ast import value_from_ast
|
||||
|
||||
__all__ = ["get_argument_values", "get_directive_values", "get_variable_values"]
|
||||
|
||||
|
||||
CoercedVariableValues = Union[List[GraphQLError], Dict[str, Any]]
|
||||
|
||||
|
||||
def get_variable_values(
|
||||
schema: GraphQLSchema,
|
||||
var_def_nodes: Collection[VariableDefinitionNode],
|
||||
inputs: Dict[str, Any],
|
||||
max_errors: Optional[int] = None,
|
||||
) -> CoercedVariableValues:
|
||||
"""Get coerced variable values based on provided definitions.
|
||||
|
||||
Prepares a dict of variable values of the correct type based on the provided
|
||||
variable definitions and arbitrary input. If the input cannot be parsed to match
|
||||
the variable definitions, a GraphQLError will be raised.
|
||||
"""
|
||||
errors: List[GraphQLError] = []
|
||||
|
||||
def on_error(error: GraphQLError) -> None:
|
||||
if max_errors is not None and len(errors) >= max_errors:
|
||||
raise GraphQLError(
|
||||
"Too many errors processing variables,"
|
||||
" error limit reached. Execution aborted."
|
||||
)
|
||||
errors.append(error)
|
||||
|
||||
try:
|
||||
coerced = coerce_variable_values(schema, var_def_nodes, inputs, on_error)
|
||||
if not errors:
|
||||
return coerced
|
||||
except GraphQLError as e:
|
||||
errors.append(e)
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
def coerce_variable_values(
|
||||
schema: GraphQLSchema,
|
||||
var_def_nodes: Collection[VariableDefinitionNode],
|
||||
inputs: Dict[str, Any],
|
||||
on_error: Callable[[GraphQLError], None],
|
||||
) -> Dict[str, Any]:
|
||||
coerced_values: Dict[str, Any] = {}
|
||||
for var_def_node in var_def_nodes:
|
||||
var_name = var_def_node.variable.name.value
|
||||
var_type = type_from_ast(schema, var_def_node.type)
|
||||
if not is_input_type(var_type):
|
||||
# Must use input types for variables. This should be caught during
|
||||
# validation, however is checked again here for safety.
|
||||
var_type_str = print_ast(var_def_node.type)
|
||||
on_error(
|
||||
GraphQLError(
|
||||
f"Variable '${var_name}' expected value of type '{var_type_str}'"
|
||||
" which cannot be used as an input type.",
|
||||
var_def_node.type,
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
var_type = cast(GraphQLInputType, var_type)
|
||||
if var_name not in inputs:
|
||||
if var_def_node.default_value:
|
||||
coerced_values[var_name] = value_from_ast(
|
||||
var_def_node.default_value, var_type
|
||||
)
|
||||
elif is_non_null_type(var_type): # pragma: no cover else
|
||||
var_type_str = inspect(var_type)
|
||||
on_error(
|
||||
GraphQLError(
|
||||
f"Variable '${var_name}' of required type '{var_type_str}'"
|
||||
" was not provided.",
|
||||
var_def_node,
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
value = inputs[var_name]
|
||||
if value is None and is_non_null_type(var_type):
|
||||
var_type_str = inspect(var_type)
|
||||
on_error(
|
||||
GraphQLError(
|
||||
f"Variable '${var_name}' of non-null type '{var_type_str}'"
|
||||
" must not be null.",
|
||||
var_def_node,
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
def on_input_value_error(
|
||||
path: List[Union[str, int]], invalid_value: Any, error: GraphQLError
|
||||
) -> None:
|
||||
invalid_str = inspect(invalid_value)
|
||||
prefix = f"Variable '${var_name}' got invalid value {invalid_str}"
|
||||
if path:
|
||||
prefix += f" at '{var_name}{print_path_list(path)}'"
|
||||
on_error(
|
||||
GraphQLError(
|
||||
prefix + "; " + error.message,
|
||||
var_def_node,
|
||||
original_error=error,
|
||||
)
|
||||
)
|
||||
|
||||
coerced_values[var_name] = coerce_input_value(
|
||||
value, var_type, on_input_value_error
|
||||
)
|
||||
|
||||
return coerced_values
|
||||
|
||||
|
||||
def get_argument_values(
|
||||
type_def: Union[GraphQLField, GraphQLDirective],
|
||||
node: Union[FieldNode, DirectiveNode],
|
||||
variable_values: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Get coerced argument values based on provided definitions and nodes.
|
||||
|
||||
Prepares a dict of argument values given a list of argument definitions and list
|
||||
of argument AST nodes.
|
||||
"""
|
||||
coerced_values: Dict[str, Any] = {}
|
||||
arg_node_map = {arg.name.value: arg for arg in node.arguments or []}
|
||||
|
||||
for name, arg_def in type_def.args.items():
|
||||
arg_type = arg_def.type
|
||||
argument_node = arg_node_map.get(name)
|
||||
|
||||
if argument_node is None:
|
||||
value = arg_def.default_value
|
||||
if value is not Undefined:
|
||||
if is_input_object_type(arg_def.type):
|
||||
# coerce input value so that out_names are used
|
||||
value = coerce_input_value(value, arg_def.type)
|
||||
|
||||
coerced_values[arg_def.out_name or name] = value
|
||||
elif is_non_null_type(arg_type): # pragma: no cover else
|
||||
raise GraphQLError(
|
||||
f"Argument '{name}' of required type '{arg_type}'"
|
||||
" was not provided.",
|
||||
node,
|
||||
)
|
||||
continue # pragma: no cover
|
||||
|
||||
value_node = argument_node.value
|
||||
is_null = isinstance(argument_node.value, NullValueNode)
|
||||
|
||||
if isinstance(value_node, VariableNode):
|
||||
variable_name = value_node.name.value
|
||||
if variable_values is None or variable_name not in variable_values:
|
||||
value = arg_def.default_value
|
||||
if value is not Undefined:
|
||||
if is_input_object_type(arg_def.type):
|
||||
# coerce input value so that out_names are used
|
||||
value = coerce_input_value(value, arg_def.type)
|
||||
coerced_values[arg_def.out_name or name] = value
|
||||
elif is_non_null_type(arg_type): # pragma: no cover else
|
||||
raise GraphQLError(
|
||||
f"Argument '{name}' of required type '{arg_type}'"
|
||||
f" was provided the variable '${variable_name}'"
|
||||
" which was not provided a runtime value.",
|
||||
value_node,
|
||||
)
|
||||
continue # pragma: no cover
|
||||
is_null = variable_values[variable_name] is None
|
||||
|
||||
if is_null and is_non_null_type(arg_type):
|
||||
raise GraphQLError(
|
||||
f"Argument '{name}' of non-null type '{arg_type}' must not be null.",
|
||||
value_node,
|
||||
)
|
||||
|
||||
coerced_value = value_from_ast(value_node, arg_type, variable_values)
|
||||
if coerced_value is Undefined:
|
||||
# Note: `values_of_correct_type` validation should catch this before
|
||||
# execution. This is a runtime check to ensure execution does not
|
||||
# continue with an invalid argument value.
|
||||
raise GraphQLError(
|
||||
f"Argument '{name}' has invalid value {print_ast(value_node)}.",
|
||||
value_node,
|
||||
)
|
||||
coerced_values[arg_def.out_name or name] = coerced_value
|
||||
|
||||
return coerced_values
|
||||
|
||||
|
||||
NodeWithDirective = Union[
|
||||
EnumValueDefinitionNode,
|
||||
ExecutableDefinitionNode,
|
||||
FieldDefinitionNode,
|
||||
InputValueDefinitionNode,
|
||||
SelectionNode,
|
||||
SchemaDefinitionNode,
|
||||
TypeDefinitionNode,
|
||||
TypeExtensionNode,
|
||||
]
|
||||
|
||||
|
||||
def get_directive_values(
|
||||
directive_def: GraphQLDirective,
|
||||
node: NodeWithDirective,
|
||||
variable_values: Optional[Dict[str, Any]] = None,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Get coerced argument values based on provided nodes.
|
||||
|
||||
Prepares a dict of argument values given a directive definition and an AST node
|
||||
which may contain directives. Optionally also accepts a dict of variable values.
|
||||
|
||||
If the directive does not exist on the node, returns None.
|
||||
"""
|
||||
directives = node.directives
|
||||
if directives:
|
||||
directive_name = directive_def.name
|
||||
for directive in directives:
|
||||
if directive.name.value == directive_name:
|
||||
return get_argument_values(directive_def, directive, variable_values)
|
||||
return None
|
||||
Reference in New Issue
Block a user