2025-12-01

This commit is contained in:
2026-03-17 14:58:51 -06:00
parent 183e865f8b
commit 4b82b57113
6846 changed files with 954887 additions and 162606 deletions
@@ -0,0 +1,38 @@
"""GraphQL Execution
The :mod:`graphql.execution` package is responsible for the execution phase of
fulfilling a GraphQL request.
"""
from .execute import (
execute,
execute_sync,
default_field_resolver,
default_type_resolver,
ExecutionContext,
ExecutionResult,
FormattedExecutionResult,
Middleware,
)
from .map_async_iterator import MapAsyncIterator
from .subscribe import subscribe, create_source_event_stream
from .middleware import MiddlewareManager
from .values import get_argument_values, get_directive_values, get_variable_values
__all__ = [
"create_source_event_stream",
"execute",
"execute_sync",
"default_field_resolver",
"default_type_resolver",
"subscribe",
"ExecutionContext",
"ExecutionResult",
"FormattedExecutionResult",
"MapAsyncIterator",
"Middleware",
"MiddlewareManager",
"get_argument_values",
"get_directive_values",
"get_variable_values",
]
@@ -0,0 +1,174 @@
from typing import Any, Dict, List, Set, Union, cast
from ..language import (
FieldNode,
FragmentDefinitionNode,
FragmentSpreadNode,
InlineFragmentNode,
SelectionSetNode,
)
from ..type import (
GraphQLAbstractType,
GraphQLIncludeDirective,
GraphQLObjectType,
GraphQLSchema,
GraphQLSkipDirective,
is_abstract_type,
)
from ..utilities.type_from_ast import type_from_ast
from .values import get_directive_values
__all__ = ["collect_fields", "collect_sub_fields"]
def collect_fields(
schema: GraphQLSchema,
fragments: Dict[str, FragmentDefinitionNode],
variable_values: Dict[str, Any],
runtime_type: GraphQLObjectType,
selection_set: SelectionSetNode,
) -> Dict[str, List[FieldNode]]:
"""Collect fields.
Given a selection_set, collects all the fields and returns them.
collect_fields requires the "runtime type" of an object. For a field that
returns an Interface or Union type, the "runtime type" will be the actual
object type returned by that field.
For internal use only.
"""
fields: Dict[str, List[FieldNode]] = {}
collect_fields_impl(
schema, fragments, variable_values, runtime_type, selection_set, fields, set()
)
return fields
def collect_sub_fields(
schema: GraphQLSchema,
fragments: Dict[str, FragmentDefinitionNode],
variable_values: Dict[str, Any],
return_type: GraphQLObjectType,
field_nodes: List[FieldNode],
) -> Dict[str, List[FieldNode]]:
"""Collect sub fields.
Given a list of field nodes, collects all the subfields of the passed in fields,
and returns them at the end.
collect_sub_fields requires the "return type" of an object. For a field that
returns an Interface or Union type, the "return type" will be the actual
object type returned by that field.
For internal use only.
"""
sub_field_nodes: Dict[str, List[FieldNode]] = {}
visited_fragment_names: Set[str] = set()
for node in field_nodes:
if node.selection_set:
collect_fields_impl(
schema,
fragments,
variable_values,
return_type,
node.selection_set,
sub_field_nodes,
visited_fragment_names,
)
return sub_field_nodes
def collect_fields_impl(
schema: GraphQLSchema,
fragments: Dict[str, FragmentDefinitionNode],
variable_values: Dict[str, Any],
runtime_type: GraphQLObjectType,
selection_set: SelectionSetNode,
fields: Dict[str, List[FieldNode]],
visited_fragment_names: Set[str],
) -> None:
"""Collect fields (internal implementation)."""
for selection in selection_set.selections:
if isinstance(selection, FieldNode):
if not should_include_node(variable_values, selection):
continue
name = get_field_entry_key(selection)
fields.setdefault(name, []).append(selection)
elif isinstance(selection, InlineFragmentNode):
if not should_include_node(
variable_values, selection
) or not does_fragment_condition_match(schema, selection, runtime_type):
continue
collect_fields_impl(
schema,
fragments,
variable_values,
runtime_type,
selection.selection_set,
fields,
visited_fragment_names,
)
elif isinstance(selection, FragmentSpreadNode): # pragma: no cover else
frag_name = selection.name.value
if frag_name in visited_fragment_names or not should_include_node(
variable_values, selection
):
continue
visited_fragment_names.add(frag_name)
fragment = fragments.get(frag_name)
if not fragment or not does_fragment_condition_match(
schema, fragment, runtime_type
):
continue
collect_fields_impl(
schema,
fragments,
variable_values,
runtime_type,
fragment.selection_set,
fields,
visited_fragment_names,
)
def should_include_node(
variable_values: Dict[str, Any],
node: Union[FragmentSpreadNode, FieldNode, InlineFragmentNode],
) -> bool:
"""Check if node should be included
Determines if a field should be included based on the @include and @skip
directives, where @skip has higher precedence than @include.
"""
skip = get_directive_values(GraphQLSkipDirective, node, variable_values)
if skip and skip["if"]:
return False
include = get_directive_values(GraphQLIncludeDirective, node, variable_values)
if include and not include["if"]:
return False
return True
def does_fragment_condition_match(
schema: GraphQLSchema,
fragment: Union[FragmentDefinitionNode, InlineFragmentNode],
type_: GraphQLObjectType,
) -> bool:
"""Determine if a fragment is applicable to the given type."""
type_condition_node = fragment.type_condition
if not type_condition_node:
return True
conditional_type = type_from_ast(schema, type_condition_node)
if conditional_type is type_:
return True
if is_abstract_type(conditional_type):
return schema.is_sub_type(cast(GraphQLAbstractType, conditional_type), type_)
return False
def get_field_entry_key(node: FieldNode) -> str:
"""Implements the logic to compute the key of a given field's entry"""
return node.alias.value if node.alias else node.name.value
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,115 @@
from asyncio import CancelledError, Event, Task, ensure_future, wait
from concurrent.futures import FIRST_COMPLETED
from inspect import isasyncgen, isawaitable
from types import TracebackType
from typing import Any, AsyncIterable, Callable, Optional, Set, Type, Union
__all__ = ["MapAsyncIterator"]
# noinspection PyAttributeOutsideInit
class MapAsyncIterator:
"""Map an AsyncIterable over a callback function.
Given an AsyncIterable and a callback function, return an AsyncIterator which
produces values mapped via calling the callback function.
When the resulting AsyncIterator is closed, the underlying AsyncIterable will also
be closed.
"""
def __init__(self, iterable: AsyncIterable, callback: Callable) -> None:
self.iterator = iterable.__aiter__()
self.callback = callback
self._close_event = Event()
def __aiter__(self) -> "MapAsyncIterator":
"""Get the iterator object."""
return self
async def __anext__(self) -> Any:
"""Get the next value of the iterator."""
if self.is_closed:
if not isasyncgen(self.iterator):
raise StopAsyncIteration
value = await self.iterator.__anext__()
else:
aclose = ensure_future(self._close_event.wait())
anext = ensure_future(self.iterator.__anext__())
try:
pending: Set[Task] = (
await wait([aclose, anext], return_when=FIRST_COMPLETED)
)[1]
except CancelledError:
# cancel underlying tasks and close
aclose.cancel()
anext.cancel()
await self.aclose()
raise # re-raise the cancellation
for task in pending:
task.cancel()
if aclose.done():
raise StopAsyncIteration
error = anext.exception()
if error:
raise error
value = anext.result()
result = self.callback(value)
return await result if isawaitable(result) else result
async def athrow(
self,
type_: Union[BaseException, Type[BaseException]],
value: Optional[BaseException] = None,
traceback: Optional[TracebackType] = None,
) -> None:
"""Throw an exception into the asynchronous iterator."""
if self.is_closed:
return
if isinstance(type_, BaseException):
value = type_
type_ = type(value)
traceback = value.__traceback__
athrow = getattr(self.iterator, "athrow", None)
if athrow:
await athrow(type_ if value is None else value)
else:
await self.aclose()
if value is None:
if traceback is None:
raise type_ # pragma: no cover
value = type_ if isinstance(value, BaseException) else type_()
if traceback is not None:
value = value.with_traceback(traceback)
raise value
async def aclose(self) -> None:
"""Close the iterator."""
if not self.is_closed:
aclose = getattr(self.iterator, "aclose", None)
if aclose:
try:
await aclose()
except RuntimeError:
pass
self.is_closed = True
@property
def is_closed(self) -> bool:
"""Check whether the iterator is closed."""
return self._close_event.is_set()
@is_closed.setter
def is_closed(self, value: bool) -> None:
"""Mark the iterator as closed."""
if value:
self._close_event.set()
else:
self._close_event.clear()
@@ -0,0 +1,63 @@
from functools import partial, reduce
from inspect import isfunction
from typing import Callable, Iterator, Dict, List, Tuple, Any, Optional
__all__ = ["MiddlewareManager"]
GraphQLFieldResolver = Callable[..., Any]
class MiddlewareManager:
"""Manager for the middleware chain.
This class helps to wrap resolver functions with the provided middleware functions
and/or objects. The functions take the next middleware function as first argument.
If middleware is provided as an object, it must provide a method ``resolve`` that is
used as the middleware function.
Note that since resolvers return "AwaitableOrValue"s, all middleware functions
must be aware of this and check whether values are awaitable before awaiting them.
"""
# allow custom attributes (not used internally)
__slots__ = "__dict__", "middlewares", "_middleware_resolvers", "_cached_resolvers"
_cached_resolvers: Dict[GraphQLFieldResolver, GraphQLFieldResolver]
_middleware_resolvers: Optional[List[Callable]]
def __init__(self, *middlewares: Any):
self.middlewares = middlewares
self._middleware_resolvers = (
list(get_middleware_resolvers(middlewares)) if middlewares else None
)
self._cached_resolvers = {}
def get_field_resolver(
self, field_resolver: GraphQLFieldResolver
) -> GraphQLFieldResolver:
"""Wrap the provided resolver with the middleware.
Returns a function that chains the middleware functions with the provided
resolver function.
"""
if self._middleware_resolvers is None:
return field_resolver
if field_resolver not in self._cached_resolvers:
self._cached_resolvers[field_resolver] = reduce(
lambda chained_fns, next_fn: partial(next_fn, chained_fns),
self._middleware_resolvers,
field_resolver,
)
return self._cached_resolvers[field_resolver]
def get_middleware_resolvers(middlewares: Tuple[Any, ...]) -> Iterator[Callable]:
"""Get a list of resolver functions from a list of classes or functions."""
for middleware in middlewares:
if isfunction(middleware):
yield middleware
else: # middleware provided as object with 'resolve' method
resolver_func = getattr(middleware, "resolve", None)
if resolver_func is not None:
yield resolver_func
@@ -0,0 +1,212 @@
from inspect import isawaitable
from typing import (
Any,
AsyncIterable,
AsyncIterator,
Dict,
Optional,
Union,
)
from ..error import GraphQLError, located_error
from ..execution.collect_fields import collect_fields
from ..execution.execute import (
assert_valid_execution_arguments,
execute,
get_field_def,
ExecutionContext,
ExecutionResult,
)
from ..execution.values import get_argument_values
from ..language import DocumentNode
from ..pyutils import Path, inspect
from ..type import GraphQLFieldResolver, GraphQLSchema
from .map_async_iterator import MapAsyncIterator
__all__ = ["subscribe", "create_source_event_stream"]
async def subscribe(
schema: GraphQLSchema,
document: DocumentNode,
root_value: Any = None,
context_value: Any = None,
variable_values: Optional[Dict[str, Any]] = None,
operation_name: Optional[str] = None,
field_resolver: Optional[GraphQLFieldResolver] = None,
subscribe_field_resolver: Optional[GraphQLFieldResolver] = None,
) -> Union[AsyncIterator[ExecutionResult], ExecutionResult]:
"""Create a GraphQL subscription.
Implements the "Subscribe" algorithm described in the GraphQL spec.
Returns a coroutine object which yields either an AsyncIterator (if successful) or
an ExecutionResult (client error). The coroutine will raise an exception if a server
error occurs.
If the client-provided arguments to this function do not result in a compliant
subscription, a GraphQL Response (ExecutionResult) with descriptive errors and no
data will be returned.
If the source stream could not be created due to faulty subscription resolver logic
or underlying systems, the coroutine object will yield a single ExecutionResult
containing ``errors`` and no ``data``.
If the operation succeeded, the coroutine will yield an AsyncIterator, which yields
a stream of ExecutionResults representing the response stream.
"""
result_or_stream = await create_source_event_stream(
schema,
document,
root_value,
context_value,
variable_values,
operation_name,
subscribe_field_resolver,
)
if isinstance(result_or_stream, ExecutionResult):
return result_or_stream
async def map_source_to_response(payload: Any) -> ExecutionResult:
"""Map source to response.
For each payload yielded from a subscription, map it over the normal GraphQL
:func:`~graphql.execute` function, with ``payload`` as the ``root_value``.
This implements the "MapSourceToResponseEvent" algorithm described in the
GraphQL specification. The :func:`~graphql.execute` function provides the
"ExecuteSubscriptionEvent" algorithm, as it is nearly identical to the
"ExecuteQuery" algorithm, for which :func:`~graphql.execute` is also used.
"""
result = execute(
schema,
document,
payload,
context_value,
variable_values,
operation_name,
field_resolver,
)
return await result if isawaitable(result) else result
# Map every source value to a ExecutionResult value as described above.
return MapAsyncIterator(result_or_stream, map_source_to_response)
async def create_source_event_stream(
schema: GraphQLSchema,
document: DocumentNode,
root_value: Any = None,
context_value: Any = None,
variable_values: Optional[Dict[str, Any]] = None,
operation_name: Optional[str] = None,
subscribe_field_resolver: Optional[GraphQLFieldResolver] = None,
) -> Union[AsyncIterable[Any], ExecutionResult]:
"""Create source event stream
Implements the "CreateSourceEventStream" algorithm described in the GraphQL
specification, resolving the subscription source event stream.
Returns a coroutine that yields an AsyncIterable.
If the client-provided arguments to this function do not result in a compliant
subscription, a GraphQL Response (ExecutionResult) with descriptive errors and no
data will be returned.
If the source stream could not be created due to faulty subscription resolver logic
or underlying systems, the coroutine object will yield a single ExecutionResult
containing ``errors`` and no ``data``.
A source event stream represents a sequence of events, each of which triggers a
GraphQL execution for that event.
This may be useful when hosting the stateful subscription service in a different
process or machine than the stateless GraphQL execution engine, or otherwise
separating these two steps. For more on this, see the "Supporting Subscriptions
at Scale" information in the GraphQL spec.
"""
# If arguments are missing or incorrectly typed, this is an internal developer
# mistake which should throw an early error.
assert_valid_execution_arguments(schema, document, variable_values)
# If a valid context cannot be created due to incorrect arguments,
# a "Response" with only errors is returned.
context = ExecutionContext.build(
schema,
document,
root_value,
context_value,
variable_values,
operation_name,
subscribe_field_resolver=subscribe_field_resolver,
)
# Return early errors if execution context failed.
if isinstance(context, list):
return ExecutionResult(data=None, errors=context)
try:
event_stream = await execute_subscription(context)
# Assert field returned an event stream, otherwise yield an error.
if not isinstance(event_stream, AsyncIterable):
raise TypeError(
"Subscription field must return AsyncIterable."
f" Received: {inspect(event_stream)}."
)
return event_stream
except GraphQLError as error:
# Report it as an ExecutionResult, containing only errors and no data.
return ExecutionResult(data=None, errors=[error])
async def execute_subscription(context: ExecutionContext) -> AsyncIterable[Any]:
schema = context.schema
root_type = schema.subscription_type
if root_type is None:
raise GraphQLError(
"Schema is not configured to execute subscription operation.",
context.operation,
)
root_fields = collect_fields(
schema,
context.fragments,
context.variable_values,
root_type,
context.operation.selection_set,
)
response_name, field_nodes = next(iter(root_fields.items()))
field_def = get_field_def(schema, root_type, field_nodes[0])
if not field_def:
field_name = field_nodes[0].name.value
raise GraphQLError(
f"The subscription field '{field_name}' is not defined.", field_nodes
)
path = Path(None, response_name, root_type.name)
info = context.build_resolve_info(field_def, field_nodes, root_type, path)
# Implements the "ResolveFieldEventStream" algorithm from GraphQL specification.
# It differs from "ResolveFieldValue" due to providing a different `resolveFn`.
try:
# Build a dictionary of arguments from the field.arguments AST, using the
# variables scope to fulfill any variable references.
args = get_argument_values(field_def, field_nodes[0], context.variable_values)
# Call the `subscribe()` resolver or the default resolver to produce an
# AsyncIterable yielding raw payloads.
resolve_fn = field_def.subscribe or context.subscribe_field_resolver
event_stream = resolve_fn(context.root_value, info, **args)
if context.is_awaitable(event_stream):
event_stream = await event_stream
if isinstance(event_stream, Exception):
raise event_stream
return event_stream
except Exception as error:
raise located_error(error, field_nodes, path.as_list())
@@ -0,0 +1,251 @@
from typing import Any, Callable, Collection, Dict, List, Optional, Union, cast
from ..error import GraphQLError
from ..language import (
DirectiveNode,
EnumValueDefinitionNode,
ExecutableDefinitionNode,
FieldDefinitionNode,
FieldNode,
InputValueDefinitionNode,
NullValueNode,
SchemaDefinitionNode,
SelectionNode,
TypeDefinitionNode,
TypeExtensionNode,
VariableDefinitionNode,
VariableNode,
print_ast,
)
from ..pyutils import Undefined, inspect, print_path_list
from ..type import (
GraphQLDirective,
GraphQLField,
GraphQLInputType,
GraphQLSchema,
is_input_object_type,
is_input_type,
is_non_null_type,
)
from ..utilities.coerce_input_value import coerce_input_value
from ..utilities.type_from_ast import type_from_ast
from ..utilities.value_from_ast import value_from_ast
__all__ = ["get_argument_values", "get_directive_values", "get_variable_values"]
CoercedVariableValues = Union[List[GraphQLError], Dict[str, Any]]
def get_variable_values(
schema: GraphQLSchema,
var_def_nodes: Collection[VariableDefinitionNode],
inputs: Dict[str, Any],
max_errors: Optional[int] = None,
) -> CoercedVariableValues:
"""Get coerced variable values based on provided definitions.
Prepares a dict of variable values of the correct type based on the provided
variable definitions and arbitrary input. If the input cannot be parsed to match
the variable definitions, a GraphQLError will be raised.
"""
errors: List[GraphQLError] = []
def on_error(error: GraphQLError) -> None:
if max_errors is not None and len(errors) >= max_errors:
raise GraphQLError(
"Too many errors processing variables,"
" error limit reached. Execution aborted."
)
errors.append(error)
try:
coerced = coerce_variable_values(schema, var_def_nodes, inputs, on_error)
if not errors:
return coerced
except GraphQLError as e:
errors.append(e)
return errors
def coerce_variable_values(
schema: GraphQLSchema,
var_def_nodes: Collection[VariableDefinitionNode],
inputs: Dict[str, Any],
on_error: Callable[[GraphQLError], None],
) -> Dict[str, Any]:
coerced_values: Dict[str, Any] = {}
for var_def_node in var_def_nodes:
var_name = var_def_node.variable.name.value
var_type = type_from_ast(schema, var_def_node.type)
if not is_input_type(var_type):
# Must use input types for variables. This should be caught during
# validation, however is checked again here for safety.
var_type_str = print_ast(var_def_node.type)
on_error(
GraphQLError(
f"Variable '${var_name}' expected value of type '{var_type_str}'"
" which cannot be used as an input type.",
var_def_node.type,
)
)
continue
var_type = cast(GraphQLInputType, var_type)
if var_name not in inputs:
if var_def_node.default_value:
coerced_values[var_name] = value_from_ast(
var_def_node.default_value, var_type
)
elif is_non_null_type(var_type): # pragma: no cover else
var_type_str = inspect(var_type)
on_error(
GraphQLError(
f"Variable '${var_name}' of required type '{var_type_str}'"
" was not provided.",
var_def_node,
)
)
continue
value = inputs[var_name]
if value is None and is_non_null_type(var_type):
var_type_str = inspect(var_type)
on_error(
GraphQLError(
f"Variable '${var_name}' of non-null type '{var_type_str}'"
" must not be null.",
var_def_node,
)
)
continue
def on_input_value_error(
path: List[Union[str, int]], invalid_value: Any, error: GraphQLError
) -> None:
invalid_str = inspect(invalid_value)
prefix = f"Variable '${var_name}' got invalid value {invalid_str}"
if path:
prefix += f" at '{var_name}{print_path_list(path)}'"
on_error(
GraphQLError(
prefix + "; " + error.message,
var_def_node,
original_error=error,
)
)
coerced_values[var_name] = coerce_input_value(
value, var_type, on_input_value_error
)
return coerced_values
def get_argument_values(
type_def: Union[GraphQLField, GraphQLDirective],
node: Union[FieldNode, DirectiveNode],
variable_values: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""Get coerced argument values based on provided definitions and nodes.
Prepares a dict of argument values given a list of argument definitions and list
of argument AST nodes.
"""
coerced_values: Dict[str, Any] = {}
arg_node_map = {arg.name.value: arg for arg in node.arguments or []}
for name, arg_def in type_def.args.items():
arg_type = arg_def.type
argument_node = arg_node_map.get(name)
if argument_node is None:
value = arg_def.default_value
if value is not Undefined:
if is_input_object_type(arg_def.type):
# coerce input value so that out_names are used
value = coerce_input_value(value, arg_def.type)
coerced_values[arg_def.out_name or name] = value
elif is_non_null_type(arg_type): # pragma: no cover else
raise GraphQLError(
f"Argument '{name}' of required type '{arg_type}'"
" was not provided.",
node,
)
continue # pragma: no cover
value_node = argument_node.value
is_null = isinstance(argument_node.value, NullValueNode)
if isinstance(value_node, VariableNode):
variable_name = value_node.name.value
if variable_values is None or variable_name not in variable_values:
value = arg_def.default_value
if value is not Undefined:
if is_input_object_type(arg_def.type):
# coerce input value so that out_names are used
value = coerce_input_value(value, arg_def.type)
coerced_values[arg_def.out_name or name] = value
elif is_non_null_type(arg_type): # pragma: no cover else
raise GraphQLError(
f"Argument '{name}' of required type '{arg_type}'"
f" was provided the variable '${variable_name}'"
" which was not provided a runtime value.",
value_node,
)
continue # pragma: no cover
is_null = variable_values[variable_name] is None
if is_null and is_non_null_type(arg_type):
raise GraphQLError(
f"Argument '{name}' of non-null type '{arg_type}' must not be null.",
value_node,
)
coerced_value = value_from_ast(value_node, arg_type, variable_values)
if coerced_value is Undefined:
# Note: `values_of_correct_type` validation should catch this before
# execution. This is a runtime check to ensure execution does not
# continue with an invalid argument value.
raise GraphQLError(
f"Argument '{name}' has invalid value {print_ast(value_node)}.",
value_node,
)
coerced_values[arg_def.out_name or name] = coerced_value
return coerced_values
NodeWithDirective = Union[
EnumValueDefinitionNode,
ExecutableDefinitionNode,
FieldDefinitionNode,
InputValueDefinitionNode,
SelectionNode,
SchemaDefinitionNode,
TypeDefinitionNode,
TypeExtensionNode,
]
def get_directive_values(
directive_def: GraphQLDirective,
node: NodeWithDirective,
variable_values: Optional[Dict[str, Any]] = None,
) -> Optional[Dict[str, Any]]:
"""Get coerced argument values based on provided nodes.
Prepares a dict of argument values given a directive definition and an AST node
which may contain directives. Optionally also accepts a dict of variable values.
If the directive does not exist on the node, returns None.
"""
directives = node.directives
if directives:
directive_name = directive_def.name
for directive in directives:
if directive.name.value == directive_name:
return get_argument_values(directive_def, directive, variable_values)
return None