2025-12-01

This commit is contained in:
2026-03-17 14:58:51 -06:00
parent 183e865f8b
commit 4b82b57113
6846 changed files with 954887 additions and 162606 deletions
@@ -0,0 +1,780 @@
"""GraphQL-core
The primary :mod:`graphql` package includes everything you need to define a GraphQL
schema and fulfill GraphQL requests.
GraphQL-core provides a reference implementation for the GraphQL specification
but is also a useful utility for operating on GraphQL files and building sophisticated
tools.
This top-level package exports a general purpose function for fulfilling all steps
of the GraphQL specification in a single operation, but also includes utilities
for every part of the GraphQL specification:
- Parsing the GraphQL language.
- Building a GraphQL type schema.
- Validating a GraphQL request against a type schema.
- Executing a GraphQL request against a type schema.
This also includes utility functions for operating on GraphQL types and GraphQL
documents to facilitate building tools.
You may also import from each sub-package directly. For example, the following two
import statements are equivalent::
from graphql import parse
from graphql.language import parse
The sub-packages of GraphQL-core 3 are:
- :mod:`graphql.language`: Parse and operate on the GraphQL language.
- :mod:`graphql.type`: Define GraphQL types and schema.
- :mod:`graphql.validation`: The Validation phase of fulfilling a GraphQL result.
- :mod:`graphql.execution`: The Execution phase of fulfilling a GraphQL request.
- :mod:`graphql.error`: Creating and formatting GraphQL errors.
- :mod:`graphql.utilities`:
Common useful computations upon the GraphQL language and type objects.
"""
# The GraphQL-core 3 and GraphQL.js version info.
from .version import version, version_info, version_js, version_info_js
# Utilities for compatibility with the Python language.
from .pyutils import Undefined, UndefinedType
# Create, format, and print GraphQL errors.
from .error import (
GraphQLError,
GraphQLErrorExtensions,
GraphQLFormattedError,
GraphQLSyntaxError,
located_error,
)
# Parse and operate on GraphQL language source files.
from .language import (
Source,
get_location,
# Print source location
print_location,
print_source_location,
# Lex
Lexer,
TokenKind,
# Parse
parse,
parse_value,
parse_const_value,
parse_type,
# Print
print_ast,
# Visit
visit,
ParallelVisitor,
Visitor,
VisitorAction,
VisitorKeyMap,
BREAK,
SKIP,
REMOVE,
IDLE,
DirectiveLocation,
# Predicates
is_definition_node,
is_executable_definition_node,
is_selection_node,
is_value_node,
is_const_value_node,
is_type_node,
is_type_system_definition_node,
is_type_definition_node,
is_type_system_extension_node,
is_type_extension_node,
# Types
SourceLocation,
Location,
Token,
# AST nodes
Node,
# Each kind of AST node
NameNode,
DocumentNode,
DefinitionNode,
ExecutableDefinitionNode,
OperationDefinitionNode,
OperationType,
VariableDefinitionNode,
VariableNode,
SelectionSetNode,
SelectionNode,
FieldNode,
ArgumentNode,
ConstArgumentNode,
FragmentSpreadNode,
InlineFragmentNode,
FragmentDefinitionNode,
ValueNode,
ConstValueNode,
IntValueNode,
FloatValueNode,
StringValueNode,
BooleanValueNode,
NullValueNode,
EnumValueNode,
ListValueNode,
ConstListValueNode,
ObjectValueNode,
ConstObjectValueNode,
ObjectFieldNode,
ConstObjectFieldNode,
DirectiveNode,
ConstDirectiveNode,
TypeNode,
NamedTypeNode,
ListTypeNode,
NonNullTypeNode,
TypeSystemDefinitionNode,
SchemaDefinitionNode,
OperationTypeDefinitionNode,
TypeDefinitionNode,
ScalarTypeDefinitionNode,
ObjectTypeDefinitionNode,
FieldDefinitionNode,
InputValueDefinitionNode,
InterfaceTypeDefinitionNode,
UnionTypeDefinitionNode,
EnumTypeDefinitionNode,
EnumValueDefinitionNode,
InputObjectTypeDefinitionNode,
DirectiveDefinitionNode,
TypeSystemExtensionNode,
SchemaExtensionNode,
TypeExtensionNode,
ScalarTypeExtensionNode,
ObjectTypeExtensionNode,
InterfaceTypeExtensionNode,
UnionTypeExtensionNode,
EnumTypeExtensionNode,
InputObjectTypeExtensionNode,
)
# Utilities for operating on GraphQL type schema and parsed sources.
from .utilities import (
# Produce the GraphQL query recommended for a full schema introspection.
# Accepts optional IntrospectionOptions.
get_introspection_query,
IntrospectionQuery,
# Get the target Operation from a Document.
get_operation_ast,
# Get the Type for the target Operation AST.
get_operation_root_type,
# Convert a GraphQLSchema to an IntrospectionQuery.
introspection_from_schema,
# Build a GraphQLSchema from an introspection result.
build_client_schema,
# Build a GraphQLSchema from a parsed GraphQL Schema language AST.
build_ast_schema,
# Build a GraphQLSchema from a GraphQL schema language document.
build_schema,
# Extend an existing GraphQLSchema from a parsed GraphQL Schema language AST.
extend_schema,
# Sort a GraphQLSchema.
lexicographic_sort_schema,
# Print a GraphQLSchema to GraphQL Schema language.
print_schema,
# Print a GraphQLType to GraphQL Schema language.
print_type,
# Prints the built-in introspection schema in the Schema Language format.
print_introspection_schema,
# Create a GraphQLType from a GraphQL language AST.
type_from_ast,
# Convert a language AST to a dictionary.
ast_to_dict,
# Create a Python value from a GraphQL language AST with a Type.
value_from_ast,
# Create a Python value from a GraphQL language AST without a Type.
value_from_ast_untyped,
# Create a GraphQL language AST from a Python value.
ast_from_value,
# A helper to use within recursive-descent visitors which need to be aware of the
# GraphQL type system.
TypeInfo,
TypeInfoVisitor,
# Coerce a Python value to a GraphQL type, or produce errors.
coerce_input_value,
# Concatenates multiple ASTs together.
concat_ast,
# Separate an AST into an AST per Operation.
separate_operations,
# Strip characters that are not significant to the validity or execution
# of a GraphQL document.
strip_ignored_characters,
# Comparators for types
is_equal_type,
is_type_sub_type_of,
do_types_overlap,
# Assert a string is a valid GraphQL name.
assert_valid_name,
# Determine if a string is a valid GraphQL name.
is_valid_name_error,
# Compare two GraphQLSchemas and detect breaking changes.
BreakingChange,
BreakingChangeType,
DangerousChange,
DangerousChangeType,
find_breaking_changes,
find_dangerous_changes,
)
# Create and operate on GraphQL type definitions and schema.
from .type import (
# Definitions
GraphQLSchema,
GraphQLDirective,
GraphQLScalarType,
GraphQLObjectType,
GraphQLInterfaceType,
GraphQLUnionType,
GraphQLEnumType,
GraphQLInputObjectType,
GraphQLList,
GraphQLNonNull,
# Standard GraphQL Scalars
specified_scalar_types,
GraphQLInt,
GraphQLFloat,
GraphQLString,
GraphQLBoolean,
GraphQLID,
# Int boundaries constants
GRAPHQL_MAX_INT,
GRAPHQL_MIN_INT,
# Built-in Directives defined by the Spec
specified_directives,
GraphQLIncludeDirective,
GraphQLSkipDirective,
GraphQLDeprecatedDirective,
GraphQLSpecifiedByDirective,
# "Enum" of Type Kinds
TypeKind,
# Constant Deprecation Reason
DEFAULT_DEPRECATION_REASON,
# GraphQL Types for introspection.
introspection_types,
# Meta-field definitions.
SchemaMetaFieldDef,
TypeMetaFieldDef,
TypeNameMetaFieldDef,
# Predicates
is_schema,
is_directive,
is_type,
is_scalar_type,
is_object_type,
is_interface_type,
is_union_type,
is_enum_type,
is_input_object_type,
is_list_type,
is_non_null_type,
is_input_type,
is_output_type,
is_leaf_type,
is_composite_type,
is_abstract_type,
is_wrapping_type,
is_nullable_type,
is_named_type,
is_required_argument,
is_required_input_field,
is_specified_scalar_type,
is_introspection_type,
is_specified_directive,
# Assertions
assert_schema,
assert_directive,
assert_type,
assert_scalar_type,
assert_object_type,
assert_interface_type,
assert_union_type,
assert_enum_type,
assert_input_object_type,
assert_list_type,
assert_non_null_type,
assert_input_type,
assert_output_type,
assert_leaf_type,
assert_composite_type,
assert_abstract_type,
assert_wrapping_type,
assert_nullable_type,
assert_named_type,
# Un-modifiers
get_nullable_type,
get_named_type,
# Thunk handling
resolve_thunk,
# Validate GraphQL schema.
validate_schema,
assert_valid_schema,
# Uphold the spec rules about naming
assert_name,
assert_enum_value_name,
# Types
GraphQLType,
GraphQLInputType,
GraphQLOutputType,
GraphQLLeafType,
GraphQLCompositeType,
GraphQLAbstractType,
GraphQLWrappingType,
GraphQLNullableType,
GraphQLNamedType,
GraphQLNamedInputType,
GraphQLNamedOutputType,
Thunk,
ThunkCollection,
ThunkMapping,
GraphQLArgument,
GraphQLArgumentMap,
GraphQLEnumValue,
GraphQLEnumValueMap,
GraphQLField,
GraphQLFieldMap,
GraphQLFieldResolver,
GraphQLInputField,
GraphQLInputFieldMap,
GraphQLScalarSerializer,
GraphQLScalarValueParser,
GraphQLScalarLiteralParser,
GraphQLIsTypeOfFn,
GraphQLResolveInfo,
ResponsePath,
GraphQLTypeResolver,
# Keyword args
GraphQLArgumentKwargs,
GraphQLDirectiveKwargs,
GraphQLEnumTypeKwargs,
GraphQLEnumValueKwargs,
GraphQLFieldKwargs,
GraphQLInputFieldKwargs,
GraphQLInputObjectTypeKwargs,
GraphQLInterfaceTypeKwargs,
GraphQLNamedTypeKwargs,
GraphQLObjectTypeKwargs,
GraphQLScalarTypeKwargs,
GraphQLSchemaKwargs,
GraphQLUnionTypeKwargs,
)
# Validate GraphQL queries.
from .validation import (
validate,
ValidationContext,
ValidationRule,
ASTValidationRule,
SDLValidationRule,
# All validation rules in the GraphQL Specification.
specified_rules,
# Individual validation rules.
ExecutableDefinitionsRule,
FieldsOnCorrectTypeRule,
FragmentsOnCompositeTypesRule,
KnownArgumentNamesRule,
KnownDirectivesRule,
KnownFragmentNamesRule,
KnownTypeNamesRule,
LoneAnonymousOperationRule,
NoFragmentCyclesRule,
NoUndefinedVariablesRule,
NoUnusedFragmentsRule,
NoUnusedVariablesRule,
OverlappingFieldsCanBeMergedRule,
PossibleFragmentSpreadsRule,
ProvidedRequiredArgumentsRule,
ScalarLeafsRule,
SingleFieldSubscriptionsRule,
UniqueArgumentNamesRule,
UniqueDirectivesPerLocationRule,
UniqueFragmentNamesRule,
UniqueInputFieldNamesRule,
UniqueOperationNamesRule,
UniqueVariableNamesRule,
ValuesOfCorrectTypeRule,
VariablesAreInputTypesRule,
VariablesInAllowedPositionRule,
# SDL-specific validation rules
LoneSchemaDefinitionRule,
UniqueOperationTypesRule,
UniqueTypeNamesRule,
UniqueEnumValueNamesRule,
UniqueFieldDefinitionNamesRule,
UniqueArgumentDefinitionNamesRule,
UniqueDirectiveNamesRule,
PossibleTypeExtensionsRule,
# Custom validation rules
NoDeprecatedCustomRule,
NoSchemaIntrospectionCustomRule,
)
# Execute GraphQL documents.
from .execution import (
execute,
execute_sync,
default_field_resolver,
default_type_resolver,
get_argument_values,
get_directive_values,
get_variable_values,
# Types
ExecutionContext,
ExecutionResult,
FormattedExecutionResult,
# Subscription
subscribe,
create_source_event_stream,
MapAsyncIterator,
# Middleware
Middleware,
MiddlewareManager,
)
# The primary entry point into fulfilling a GraphQL request.
from .graphql import graphql, graphql_sync
INVALID = Undefined # deprecated alias
# The GraphQL-core version info.
__version__ = version
__version_info__ = version_info
# The GraphQL.js version info.
__version_js__ = version_js
__version_info_js__ = version_info_js
__all__ = [
"version",
"version_info",
"version_js",
"version_info_js",
"graphql",
"graphql_sync",
"GraphQLSchema",
"GraphQLDirective",
"GraphQLScalarType",
"GraphQLObjectType",
"GraphQLInterfaceType",
"GraphQLUnionType",
"GraphQLEnumType",
"GraphQLInputObjectType",
"GraphQLList",
"GraphQLNonNull",
"specified_scalar_types",
"GraphQLInt",
"GraphQLFloat",
"GraphQLString",
"GraphQLBoolean",
"GraphQLID",
"GRAPHQL_MAX_INT",
"GRAPHQL_MIN_INT",
"specified_directives",
"GraphQLIncludeDirective",
"GraphQLSkipDirective",
"GraphQLDeprecatedDirective",
"GraphQLSpecifiedByDirective",
"TypeKind",
"DEFAULT_DEPRECATION_REASON",
"introspection_types",
"SchemaMetaFieldDef",
"TypeMetaFieldDef",
"TypeNameMetaFieldDef",
"is_schema",
"is_directive",
"is_type",
"is_scalar_type",
"is_object_type",
"is_interface_type",
"is_union_type",
"is_enum_type",
"is_input_object_type",
"is_list_type",
"is_non_null_type",
"is_input_type",
"is_output_type",
"is_leaf_type",
"is_composite_type",
"is_abstract_type",
"is_wrapping_type",
"is_nullable_type",
"is_named_type",
"is_required_argument",
"is_required_input_field",
"is_specified_scalar_type",
"is_introspection_type",
"is_specified_directive",
"assert_schema",
"assert_directive",
"assert_type",
"assert_scalar_type",
"assert_object_type",
"assert_interface_type",
"assert_union_type",
"assert_enum_type",
"assert_input_object_type",
"assert_list_type",
"assert_non_null_type",
"assert_input_type",
"assert_output_type",
"assert_leaf_type",
"assert_composite_type",
"assert_abstract_type",
"assert_wrapping_type",
"assert_nullable_type",
"assert_named_type",
"get_nullable_type",
"get_named_type",
"resolve_thunk",
"validate_schema",
"assert_valid_schema",
"assert_name",
"assert_enum_value_name",
"GraphQLType",
"GraphQLInputType",
"GraphQLOutputType",
"GraphQLLeafType",
"GraphQLCompositeType",
"GraphQLAbstractType",
"GraphQLWrappingType",
"GraphQLNullableType",
"GraphQLNamedType",
"GraphQLNamedInputType",
"GraphQLNamedOutputType",
"Thunk",
"ThunkCollection",
"ThunkMapping",
"GraphQLArgument",
"GraphQLArgumentMap",
"GraphQLEnumValue",
"GraphQLEnumValueMap",
"GraphQLField",
"GraphQLFieldMap",
"GraphQLFieldResolver",
"GraphQLInputField",
"GraphQLInputFieldMap",
"GraphQLScalarSerializer",
"GraphQLScalarValueParser",
"GraphQLScalarLiteralParser",
"GraphQLIsTypeOfFn",
"GraphQLResolveInfo",
"ResponsePath",
"GraphQLTypeResolver",
"GraphQLArgumentKwargs",
"GraphQLDirectiveKwargs",
"GraphQLEnumTypeKwargs",
"GraphQLEnumValueKwargs",
"GraphQLFieldKwargs",
"GraphQLInputFieldKwargs",
"GraphQLInputObjectTypeKwargs",
"GraphQLInterfaceTypeKwargs",
"GraphQLNamedTypeKwargs",
"GraphQLObjectTypeKwargs",
"GraphQLScalarTypeKwargs",
"GraphQLSchemaKwargs",
"GraphQLUnionTypeKwargs",
"Source",
"get_location",
"print_location",
"print_source_location",
"Lexer",
"TokenKind",
"parse",
"parse_value",
"parse_const_value",
"parse_type",
"print_ast",
"visit",
"ParallelVisitor",
"TypeInfoVisitor",
"Visitor",
"VisitorAction",
"VisitorKeyMap",
"BREAK",
"SKIP",
"REMOVE",
"IDLE",
"DirectiveLocation",
"is_definition_node",
"is_executable_definition_node",
"is_selection_node",
"is_value_node",
"is_const_value_node",
"is_type_node",
"is_type_system_definition_node",
"is_type_definition_node",
"is_type_system_extension_node",
"is_type_extension_node",
"SourceLocation",
"Location",
"Token",
"Node",
"NameNode",
"DocumentNode",
"DefinitionNode",
"ExecutableDefinitionNode",
"OperationDefinitionNode",
"OperationType",
"VariableDefinitionNode",
"VariableNode",
"SelectionSetNode",
"SelectionNode",
"FieldNode",
"ArgumentNode",
"ConstArgumentNode",
"FragmentSpreadNode",
"InlineFragmentNode",
"FragmentDefinitionNode",
"ValueNode",
"ConstValueNode",
"IntValueNode",
"FloatValueNode",
"StringValueNode",
"BooleanValueNode",
"NullValueNode",
"EnumValueNode",
"ListValueNode",
"ConstListValueNode",
"ObjectValueNode",
"ConstObjectValueNode",
"ObjectFieldNode",
"ConstObjectFieldNode",
"DirectiveNode",
"ConstDirectiveNode",
"TypeNode",
"NamedTypeNode",
"ListTypeNode",
"NonNullTypeNode",
"TypeSystemDefinitionNode",
"SchemaDefinitionNode",
"OperationTypeDefinitionNode",
"TypeDefinitionNode",
"ScalarTypeDefinitionNode",
"ObjectTypeDefinitionNode",
"FieldDefinitionNode",
"InputValueDefinitionNode",
"InterfaceTypeDefinitionNode",
"UnionTypeDefinitionNode",
"EnumTypeDefinitionNode",
"EnumValueDefinitionNode",
"InputObjectTypeDefinitionNode",
"DirectiveDefinitionNode",
"TypeSystemExtensionNode",
"SchemaExtensionNode",
"TypeExtensionNode",
"ScalarTypeExtensionNode",
"ObjectTypeExtensionNode",
"InterfaceTypeExtensionNode",
"UnionTypeExtensionNode",
"EnumTypeExtensionNode",
"InputObjectTypeExtensionNode",
"execute",
"execute_sync",
"default_field_resolver",
"default_type_resolver",
"get_argument_values",
"get_directive_values",
"get_variable_values",
"ExecutionContext",
"ExecutionResult",
"FormattedExecutionResult",
"Middleware",
"MiddlewareManager",
"subscribe",
"create_source_event_stream",
"MapAsyncIterator",
"validate",
"ValidationContext",
"ValidationRule",
"ASTValidationRule",
"SDLValidationRule",
"specified_rules",
"ExecutableDefinitionsRule",
"FieldsOnCorrectTypeRule",
"FragmentsOnCompositeTypesRule",
"KnownArgumentNamesRule",
"KnownDirectivesRule",
"KnownFragmentNamesRule",
"KnownTypeNamesRule",
"LoneAnonymousOperationRule",
"NoFragmentCyclesRule",
"NoUndefinedVariablesRule",
"NoUnusedFragmentsRule",
"NoUnusedVariablesRule",
"OverlappingFieldsCanBeMergedRule",
"PossibleFragmentSpreadsRule",
"ProvidedRequiredArgumentsRule",
"ScalarLeafsRule",
"SingleFieldSubscriptionsRule",
"UniqueArgumentNamesRule",
"UniqueDirectivesPerLocationRule",
"UniqueFragmentNamesRule",
"UniqueInputFieldNamesRule",
"UniqueOperationNamesRule",
"UniqueVariableNamesRule",
"ValuesOfCorrectTypeRule",
"VariablesAreInputTypesRule",
"VariablesInAllowedPositionRule",
"LoneSchemaDefinitionRule",
"UniqueOperationTypesRule",
"UniqueTypeNamesRule",
"UniqueEnumValueNamesRule",
"UniqueFieldDefinitionNamesRule",
"UniqueArgumentDefinitionNamesRule",
"UniqueDirectiveNamesRule",
"PossibleTypeExtensionsRule",
"NoDeprecatedCustomRule",
"NoSchemaIntrospectionCustomRule",
"GraphQLError",
"GraphQLErrorExtensions",
"GraphQLFormattedError",
"GraphQLSyntaxError",
"located_error",
"get_introspection_query",
"IntrospectionQuery",
"get_operation_ast",
"get_operation_root_type",
"introspection_from_schema",
"build_client_schema",
"build_ast_schema",
"build_schema",
"extend_schema",
"lexicographic_sort_schema",
"print_schema",
"print_type",
"print_introspection_schema",
"type_from_ast",
"value_from_ast",
"value_from_ast_untyped",
"ast_from_value",
"ast_to_dict",
"TypeInfo",
"coerce_input_value",
"concat_ast",
"separate_operations",
"strip_ignored_characters",
"is_equal_type",
"is_type_sub_type_of",
"do_types_overlap",
"assert_valid_name",
"is_valid_name_error",
"find_breaking_changes",
"find_dangerous_changes",
"BreakingChange",
"BreakingChangeType",
"DangerousChange",
"DangerousChangeType",
"Undefined",
"UndefinedType",
]
@@ -0,0 +1,19 @@
"""GraphQL Errors
The :mod:`graphql.error` package is responsible for creating and formatting GraphQL
errors.
"""
from .graphql_error import GraphQLError, GraphQLErrorExtensions, GraphQLFormattedError
from .syntax_error import GraphQLSyntaxError
from .located_error import located_error
__all__ = [
"GraphQLError",
"GraphQLErrorExtensions",
"GraphQLFormattedError",
"GraphQLSyntaxError",
"located_error",
]
@@ -0,0 +1,264 @@
from sys import exc_info
from typing import Any, Collection, Dict, List, Optional, Union, TYPE_CHECKING
try:
from typing import TypedDict
except ImportError: # Python < 3.8
from typing_extensions import TypedDict
if TYPE_CHECKING:
from ..language.ast import Node # noqa: F401
from ..language.location import (
SourceLocation,
FormattedSourceLocation,
) # noqa: F401
from ..language.source import Source # noqa: F401
__all__ = ["GraphQLError", "GraphQLErrorExtensions", "GraphQLFormattedError"]
# Custom extensions
GraphQLErrorExtensions = Dict[str, Any]
# Use a unique identifier name for your extension, for example the name of
# your library or project. Do not use a shortened identifier as this increases
# the risk of conflicts. We recommend you add at most one extension key,
# a dictionary which can contain all the values you need.
class GraphQLFormattedError(TypedDict, total=False):
"""Formatted GraphQL error"""
# A short, human-readable summary of the problem that **SHOULD NOT** change
# from occurrence to occurrence of the problem, except for purposes of localization.
message: str
# If an error can be associated to a particular point in the requested
# GraphQL document, it should contain a list of locations.
locations: List["FormattedSourceLocation"]
# If an error can be associated to a particular field in the GraphQL result,
# it _must_ contain an entry with the key `path` that details the path of
# the response field which experienced the error. This allows clients to
# identify whether a null result is intentional or caused by a runtime error.
path: List[Union[str, int]]
# Reserved for implementors to extend the protocol however they see fit,
# and hence there are no additional restrictions on its contents.
extensions: GraphQLErrorExtensions
class GraphQLError(Exception):
"""GraphQL Error
A GraphQLError describes an Error found during the parse, validate, or execute
phases of performing a GraphQL operation. In addition to a message, it also includes
information about the locations in a GraphQL document and/or execution result that
correspond to the Error.
"""
message: str
"""A message describing the Error for debugging purposes"""
locations: Optional[List["SourceLocation"]]
"""Source locations
A list of (line, column) locations within the source GraphQL document which
correspond to this error.
Errors during validation often contain multiple locations, for example to point out
two things with the same name. Errors during execution include a single location,
the field which produced the error.
"""
path: Optional[List[Union[str, int]]]
"""
A list of field names and array indexes describing the JSON-path into the execution
response which corresponds to this error.
Only included for errors during execution.
"""
nodes: Optional[List["Node"]]
"""A list of GraphQL AST Nodes corresponding to this error"""
source: Optional["Source"]
"""The source GraphQL document for the first location of this error
Note that if this Error represents more than one node, the source may not represent
nodes after the first node.
"""
positions: Optional[Collection[int]]
"""Error positions
A list of character offsets within the source GraphQL document which correspond
to this error.
"""
original_error: Optional[Exception]
"""The original error thrown from a field resolver during execution"""
extensions: Optional[GraphQLErrorExtensions]
"""Extension fields to add to the formatted error"""
__slots__ = (
"message",
"nodes",
"source",
"positions",
"locations",
"path",
"original_error",
"extensions",
)
__hash__ = Exception.__hash__
def __init__(
self,
message: str,
nodes: Union[Collection["Node"], "Node", None] = None,
source: Optional["Source"] = None,
positions: Optional[Collection[int]] = None,
path: Optional[Collection[Union[str, int]]] = None,
original_error: Optional[Exception] = None,
extensions: Optional[GraphQLErrorExtensions] = None,
) -> None:
super().__init__(message)
self.message = message
if path and not isinstance(path, list):
path = list(path)
self.path = path or None # type: ignore
self.original_error = original_error
# Compute list of blame nodes.
if nodes and not isinstance(nodes, list):
nodes = [nodes] # type: ignore
self.nodes = nodes or None # type: ignore
node_locations = (
[node.loc for node in nodes if node.loc] if nodes else [] # type: ignore
)
# Compute locations in the source for the given nodes/positions.
self.source = source
if not source and node_locations:
loc = node_locations[0]
if loc.source: # pragma: no cover else
self.source = loc.source
if not positions and node_locations:
positions = [loc.start for loc in node_locations]
self.positions = positions or None
if positions and source:
locations: Optional[List["SourceLocation"]] = [
source.get_location(pos) for pos in positions
]
else:
locations = [loc.source.get_location(loc.start) for loc in node_locations]
self.locations = locations or None
if original_error:
self.__traceback__ = original_error.__traceback__
if original_error.__cause__:
self.__cause__ = original_error.__cause__
elif original_error.__context__:
self.__context__ = original_error.__context__
if extensions is None:
original_extensions = getattr(original_error, "extensions", None)
if isinstance(original_extensions, dict):
extensions = original_extensions
self.extensions = extensions or {}
if not self.__traceback__:
self.__traceback__ = exc_info()[2]
def __str__(self) -> str:
# Lazy import to avoid a cyclic dependency between error and language
from ..language.print_location import print_location, print_source_location
output = [self.message]
if self.nodes:
for node in self.nodes:
if node.loc:
output.append(print_location(node.loc))
elif self.source and self.locations:
source = self.source
for location in self.locations:
output.append(print_source_location(source, location))
return "\n\n".join(output)
def __repr__(self) -> str:
args = [repr(self.message)]
if self.locations:
args.append(f"locations={self.locations!r}")
if self.path:
args.append(f"path={self.path!r}")
if self.extensions:
args.append(f"extensions={self.extensions!r}")
return f"{self.__class__.__name__}({', '.join(args)})"
def __eq__(self, other: Any) -> bool:
return (
isinstance(other, GraphQLError)
and self.__class__ == other.__class__
and all(
getattr(self, slot) == getattr(other, slot)
for slot in self.__slots__
if slot != "original_error"
)
) or (
isinstance(other, dict)
and "message" in other
and all(
slot in self.__slots__ and getattr(self, slot) == other.get(slot)
for slot in other
if slot != "original_error"
)
)
def __ne__(self, other: Any) -> bool:
return not self == other
@property
def formatted(self) -> GraphQLFormattedError:
"""Get error formatted according to the specification.
Given a GraphQLError, format it according to the rules described by the
"Response Format, Errors" section of the GraphQL Specification.
"""
formatted: GraphQLFormattedError = {
"message": self.message or "An unknown error occurred.",
}
if self.locations is not None:
formatted["locations"] = [location.formatted for location in self.locations]
if self.path is not None:
formatted["path"] = self.path
if self.extensions:
formatted["extensions"] = self.extensions
return formatted
def print_error(error: GraphQLError) -> str:
"""Print a GraphQLError to a string.
Represents useful location information about the error's position in the source.
.. deprecated:: 3.2
Please use ``str(error)`` instead. Will be removed in v3.3.
"""
if not isinstance(error, GraphQLError):
raise TypeError("Expected a GraphQLError.")
return str(error)
def format_error(error: GraphQLError) -> GraphQLFormattedError:
"""Format a GraphQL error.
Given a GraphQLError, format it according to the rules described by the "Response
Format, Errors" section of the GraphQL Specification.
.. deprecated:: 3.2
Please use ``error.formatted`` instead. Will be removed in v3.3.
"""
if not isinstance(error, GraphQLError):
raise TypeError("Expected a GraphQLError.")
return error.formatted
@@ -0,0 +1,50 @@
from typing import TYPE_CHECKING, Collection, Optional, Union
from ..pyutils import inspect
from .graphql_error import GraphQLError
if TYPE_CHECKING:
from ..language.ast import Node # noqa: F401
__all__ = ["located_error"]
def located_error(
original_error: Exception,
nodes: Optional[Union["None", Collection["Node"]]] = None,
path: Optional[Collection[Union[str, int]]] = None,
) -> GraphQLError:
"""Located GraphQL Error
Given an arbitrary Exception, presumably thrown while attempting to execute a
GraphQL operation, produce a new GraphQLError aware of the location in the document
responsible for the original Exception.
"""
# Sometimes a non-error is thrown, wrap it as a TypeError to ensure consistency.
if not isinstance(original_error, Exception):
original_error = TypeError(f"Unexpected error value: {inspect(original_error)}")
# Note: this uses a brand-check to support GraphQL errors originating from
# other contexts.
if isinstance(original_error, GraphQLError) and original_error.path is not None:
return original_error
try:
# noinspection PyUnresolvedReferences
message = str(original_error.message) # type: ignore
except AttributeError:
message = str(original_error)
try:
# noinspection PyUnresolvedReferences
source = original_error.source # type: ignore
except AttributeError:
source = None
try:
# noinspection PyUnresolvedReferences
positions = original_error.positions # type: ignore
except AttributeError:
positions = None
try:
# noinspection PyUnresolvedReferences
nodes = original_error.nodes or nodes # type: ignore
except AttributeError:
pass
return GraphQLError(message, nodes, source, positions, path, original_error)
@@ -0,0 +1,18 @@
from typing import TYPE_CHECKING
from .graphql_error import GraphQLError
if TYPE_CHECKING:
from ..language.source import Source # noqa: F401
__all__ = ["GraphQLSyntaxError"]
class GraphQLSyntaxError(GraphQLError):
"""A GraphQLError representing a syntax error."""
def __init__(self, source: "Source", position: int, description: str) -> None:
super().__init__(
f"Syntax Error: {description}", source=source, positions=[position]
)
self.description = description
@@ -0,0 +1,38 @@
"""GraphQL Execution
The :mod:`graphql.execution` package is responsible for the execution phase of
fulfilling a GraphQL request.
"""
from .execute import (
execute,
execute_sync,
default_field_resolver,
default_type_resolver,
ExecutionContext,
ExecutionResult,
FormattedExecutionResult,
Middleware,
)
from .map_async_iterator import MapAsyncIterator
from .subscribe import subscribe, create_source_event_stream
from .middleware import MiddlewareManager
from .values import get_argument_values, get_directive_values, get_variable_values
__all__ = [
"create_source_event_stream",
"execute",
"execute_sync",
"default_field_resolver",
"default_type_resolver",
"subscribe",
"ExecutionContext",
"ExecutionResult",
"FormattedExecutionResult",
"MapAsyncIterator",
"Middleware",
"MiddlewareManager",
"get_argument_values",
"get_directive_values",
"get_variable_values",
]
@@ -0,0 +1,174 @@
from typing import Any, Dict, List, Set, Union, cast
from ..language import (
FieldNode,
FragmentDefinitionNode,
FragmentSpreadNode,
InlineFragmentNode,
SelectionSetNode,
)
from ..type import (
GraphQLAbstractType,
GraphQLIncludeDirective,
GraphQLObjectType,
GraphQLSchema,
GraphQLSkipDirective,
is_abstract_type,
)
from ..utilities.type_from_ast import type_from_ast
from .values import get_directive_values
__all__ = ["collect_fields", "collect_sub_fields"]
def collect_fields(
schema: GraphQLSchema,
fragments: Dict[str, FragmentDefinitionNode],
variable_values: Dict[str, Any],
runtime_type: GraphQLObjectType,
selection_set: SelectionSetNode,
) -> Dict[str, List[FieldNode]]:
"""Collect fields.
Given a selection_set, collects all the fields and returns them.
collect_fields requires the "runtime type" of an object. For a field that
returns an Interface or Union type, the "runtime type" will be the actual
object type returned by that field.
For internal use only.
"""
fields: Dict[str, List[FieldNode]] = {}
collect_fields_impl(
schema, fragments, variable_values, runtime_type, selection_set, fields, set()
)
return fields
def collect_sub_fields(
schema: GraphQLSchema,
fragments: Dict[str, FragmentDefinitionNode],
variable_values: Dict[str, Any],
return_type: GraphQLObjectType,
field_nodes: List[FieldNode],
) -> Dict[str, List[FieldNode]]:
"""Collect sub fields.
Given a list of field nodes, collects all the subfields of the passed in fields,
and returns them at the end.
collect_sub_fields requires the "return type" of an object. For a field that
returns an Interface or Union type, the "return type" will be the actual
object type returned by that field.
For internal use only.
"""
sub_field_nodes: Dict[str, List[FieldNode]] = {}
visited_fragment_names: Set[str] = set()
for node in field_nodes:
if node.selection_set:
collect_fields_impl(
schema,
fragments,
variable_values,
return_type,
node.selection_set,
sub_field_nodes,
visited_fragment_names,
)
return sub_field_nodes
def collect_fields_impl(
schema: GraphQLSchema,
fragments: Dict[str, FragmentDefinitionNode],
variable_values: Dict[str, Any],
runtime_type: GraphQLObjectType,
selection_set: SelectionSetNode,
fields: Dict[str, List[FieldNode]],
visited_fragment_names: Set[str],
) -> None:
"""Collect fields (internal implementation)."""
for selection in selection_set.selections:
if isinstance(selection, FieldNode):
if not should_include_node(variable_values, selection):
continue
name = get_field_entry_key(selection)
fields.setdefault(name, []).append(selection)
elif isinstance(selection, InlineFragmentNode):
if not should_include_node(
variable_values, selection
) or not does_fragment_condition_match(schema, selection, runtime_type):
continue
collect_fields_impl(
schema,
fragments,
variable_values,
runtime_type,
selection.selection_set,
fields,
visited_fragment_names,
)
elif isinstance(selection, FragmentSpreadNode): # pragma: no cover else
frag_name = selection.name.value
if frag_name in visited_fragment_names or not should_include_node(
variable_values, selection
):
continue
visited_fragment_names.add(frag_name)
fragment = fragments.get(frag_name)
if not fragment or not does_fragment_condition_match(
schema, fragment, runtime_type
):
continue
collect_fields_impl(
schema,
fragments,
variable_values,
runtime_type,
fragment.selection_set,
fields,
visited_fragment_names,
)
def should_include_node(
variable_values: Dict[str, Any],
node: Union[FragmentSpreadNode, FieldNode, InlineFragmentNode],
) -> bool:
"""Check if node should be included
Determines if a field should be included based on the @include and @skip
directives, where @skip has higher precedence than @include.
"""
skip = get_directive_values(GraphQLSkipDirective, node, variable_values)
if skip and skip["if"]:
return False
include = get_directive_values(GraphQLIncludeDirective, node, variable_values)
if include and not include["if"]:
return False
return True
def does_fragment_condition_match(
schema: GraphQLSchema,
fragment: Union[FragmentDefinitionNode, InlineFragmentNode],
type_: GraphQLObjectType,
) -> bool:
"""Determine if a fragment is applicable to the given type."""
type_condition_node = fragment.type_condition
if not type_condition_node:
return True
conditional_type = type_from_ast(schema, type_condition_node)
if conditional_type is type_:
return True
if is_abstract_type(conditional_type):
return schema.is_sub_type(cast(GraphQLAbstractType, conditional_type), type_)
return False
def get_field_entry_key(node: FieldNode) -> str:
"""Implements the logic to compute the key of a given field's entry"""
return node.alias.value if node.alias else node.name.value
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,115 @@
from asyncio import CancelledError, Event, Task, ensure_future, wait
from concurrent.futures import FIRST_COMPLETED
from inspect import isasyncgen, isawaitable
from types import TracebackType
from typing import Any, AsyncIterable, Callable, Optional, Set, Type, Union
__all__ = ["MapAsyncIterator"]
# noinspection PyAttributeOutsideInit
class MapAsyncIterator:
"""Map an AsyncIterable over a callback function.
Given an AsyncIterable and a callback function, return an AsyncIterator which
produces values mapped via calling the callback function.
When the resulting AsyncIterator is closed, the underlying AsyncIterable will also
be closed.
"""
def __init__(self, iterable: AsyncIterable, callback: Callable) -> None:
self.iterator = iterable.__aiter__()
self.callback = callback
self._close_event = Event()
def __aiter__(self) -> "MapAsyncIterator":
"""Get the iterator object."""
return self
async def __anext__(self) -> Any:
"""Get the next value of the iterator."""
if self.is_closed:
if not isasyncgen(self.iterator):
raise StopAsyncIteration
value = await self.iterator.__anext__()
else:
aclose = ensure_future(self._close_event.wait())
anext = ensure_future(self.iterator.__anext__())
try:
pending: Set[Task] = (
await wait([aclose, anext], return_when=FIRST_COMPLETED)
)[1]
except CancelledError:
# cancel underlying tasks and close
aclose.cancel()
anext.cancel()
await self.aclose()
raise # re-raise the cancellation
for task in pending:
task.cancel()
if aclose.done():
raise StopAsyncIteration
error = anext.exception()
if error:
raise error
value = anext.result()
result = self.callback(value)
return await result if isawaitable(result) else result
async def athrow(
self,
type_: Union[BaseException, Type[BaseException]],
value: Optional[BaseException] = None,
traceback: Optional[TracebackType] = None,
) -> None:
"""Throw an exception into the asynchronous iterator."""
if self.is_closed:
return
if isinstance(type_, BaseException):
value = type_
type_ = type(value)
traceback = value.__traceback__
athrow = getattr(self.iterator, "athrow", None)
if athrow:
await athrow(type_ if value is None else value)
else:
await self.aclose()
if value is None:
if traceback is None:
raise type_ # pragma: no cover
value = type_ if isinstance(value, BaseException) else type_()
if traceback is not None:
value = value.with_traceback(traceback)
raise value
async def aclose(self) -> None:
"""Close the iterator."""
if not self.is_closed:
aclose = getattr(self.iterator, "aclose", None)
if aclose:
try:
await aclose()
except RuntimeError:
pass
self.is_closed = True
@property
def is_closed(self) -> bool:
"""Check whether the iterator is closed."""
return self._close_event.is_set()
@is_closed.setter
def is_closed(self, value: bool) -> None:
"""Mark the iterator as closed."""
if value:
self._close_event.set()
else:
self._close_event.clear()
@@ -0,0 +1,63 @@
from functools import partial, reduce
from inspect import isfunction
from typing import Callable, Iterator, Dict, List, Tuple, Any, Optional
__all__ = ["MiddlewareManager"]
GraphQLFieldResolver = Callable[..., Any]
class MiddlewareManager:
"""Manager for the middleware chain.
This class helps to wrap resolver functions with the provided middleware functions
and/or objects. The functions take the next middleware function as first argument.
If middleware is provided as an object, it must provide a method ``resolve`` that is
used as the middleware function.
Note that since resolvers return "AwaitableOrValue"s, all middleware functions
must be aware of this and check whether values are awaitable before awaiting them.
"""
# allow custom attributes (not used internally)
__slots__ = "__dict__", "middlewares", "_middleware_resolvers", "_cached_resolvers"
_cached_resolvers: Dict[GraphQLFieldResolver, GraphQLFieldResolver]
_middleware_resolvers: Optional[List[Callable]]
def __init__(self, *middlewares: Any):
self.middlewares = middlewares
self._middleware_resolvers = (
list(get_middleware_resolvers(middlewares)) if middlewares else None
)
self._cached_resolvers = {}
def get_field_resolver(
self, field_resolver: GraphQLFieldResolver
) -> GraphQLFieldResolver:
"""Wrap the provided resolver with the middleware.
Returns a function that chains the middleware functions with the provided
resolver function.
"""
if self._middleware_resolvers is None:
return field_resolver
if field_resolver not in self._cached_resolvers:
self._cached_resolvers[field_resolver] = reduce(
lambda chained_fns, next_fn: partial(next_fn, chained_fns),
self._middleware_resolvers,
field_resolver,
)
return self._cached_resolvers[field_resolver]
def get_middleware_resolvers(middlewares: Tuple[Any, ...]) -> Iterator[Callable]:
"""Get a list of resolver functions from a list of classes or functions."""
for middleware in middlewares:
if isfunction(middleware):
yield middleware
else: # middleware provided as object with 'resolve' method
resolver_func = getattr(middleware, "resolve", None)
if resolver_func is not None:
yield resolver_func
@@ -0,0 +1,212 @@
from inspect import isawaitable
from typing import (
Any,
AsyncIterable,
AsyncIterator,
Dict,
Optional,
Union,
)
from ..error import GraphQLError, located_error
from ..execution.collect_fields import collect_fields
from ..execution.execute import (
assert_valid_execution_arguments,
execute,
get_field_def,
ExecutionContext,
ExecutionResult,
)
from ..execution.values import get_argument_values
from ..language import DocumentNode
from ..pyutils import Path, inspect
from ..type import GraphQLFieldResolver, GraphQLSchema
from .map_async_iterator import MapAsyncIterator
__all__ = ["subscribe", "create_source_event_stream"]
async def subscribe(
schema: GraphQLSchema,
document: DocumentNode,
root_value: Any = None,
context_value: Any = None,
variable_values: Optional[Dict[str, Any]] = None,
operation_name: Optional[str] = None,
field_resolver: Optional[GraphQLFieldResolver] = None,
subscribe_field_resolver: Optional[GraphQLFieldResolver] = None,
) -> Union[AsyncIterator[ExecutionResult], ExecutionResult]:
"""Create a GraphQL subscription.
Implements the "Subscribe" algorithm described in the GraphQL spec.
Returns a coroutine object which yields either an AsyncIterator (if successful) or
an ExecutionResult (client error). The coroutine will raise an exception if a server
error occurs.
If the client-provided arguments to this function do not result in a compliant
subscription, a GraphQL Response (ExecutionResult) with descriptive errors and no
data will be returned.
If the source stream could not be created due to faulty subscription resolver logic
or underlying systems, the coroutine object will yield a single ExecutionResult
containing ``errors`` and no ``data``.
If the operation succeeded, the coroutine will yield an AsyncIterator, which yields
a stream of ExecutionResults representing the response stream.
"""
result_or_stream = await create_source_event_stream(
schema,
document,
root_value,
context_value,
variable_values,
operation_name,
subscribe_field_resolver,
)
if isinstance(result_or_stream, ExecutionResult):
return result_or_stream
async def map_source_to_response(payload: Any) -> ExecutionResult:
"""Map source to response.
For each payload yielded from a subscription, map it over the normal GraphQL
:func:`~graphql.execute` function, with ``payload`` as the ``root_value``.
This implements the "MapSourceToResponseEvent" algorithm described in the
GraphQL specification. The :func:`~graphql.execute` function provides the
"ExecuteSubscriptionEvent" algorithm, as it is nearly identical to the
"ExecuteQuery" algorithm, for which :func:`~graphql.execute` is also used.
"""
result = execute(
schema,
document,
payload,
context_value,
variable_values,
operation_name,
field_resolver,
)
return await result if isawaitable(result) else result
# Map every source value to a ExecutionResult value as described above.
return MapAsyncIterator(result_or_stream, map_source_to_response)
async def create_source_event_stream(
schema: GraphQLSchema,
document: DocumentNode,
root_value: Any = None,
context_value: Any = None,
variable_values: Optional[Dict[str, Any]] = None,
operation_name: Optional[str] = None,
subscribe_field_resolver: Optional[GraphQLFieldResolver] = None,
) -> Union[AsyncIterable[Any], ExecutionResult]:
"""Create source event stream
Implements the "CreateSourceEventStream" algorithm described in the GraphQL
specification, resolving the subscription source event stream.
Returns a coroutine that yields an AsyncIterable.
If the client-provided arguments to this function do not result in a compliant
subscription, a GraphQL Response (ExecutionResult) with descriptive errors and no
data will be returned.
If the source stream could not be created due to faulty subscription resolver logic
or underlying systems, the coroutine object will yield a single ExecutionResult
containing ``errors`` and no ``data``.
A source event stream represents a sequence of events, each of which triggers a
GraphQL execution for that event.
This may be useful when hosting the stateful subscription service in a different
process or machine than the stateless GraphQL execution engine, or otherwise
separating these two steps. For more on this, see the "Supporting Subscriptions
at Scale" information in the GraphQL spec.
"""
# If arguments are missing or incorrectly typed, this is an internal developer
# mistake which should throw an early error.
assert_valid_execution_arguments(schema, document, variable_values)
# If a valid context cannot be created due to incorrect arguments,
# a "Response" with only errors is returned.
context = ExecutionContext.build(
schema,
document,
root_value,
context_value,
variable_values,
operation_name,
subscribe_field_resolver=subscribe_field_resolver,
)
# Return early errors if execution context failed.
if isinstance(context, list):
return ExecutionResult(data=None, errors=context)
try:
event_stream = await execute_subscription(context)
# Assert field returned an event stream, otherwise yield an error.
if not isinstance(event_stream, AsyncIterable):
raise TypeError(
"Subscription field must return AsyncIterable."
f" Received: {inspect(event_stream)}."
)
return event_stream
except GraphQLError as error:
# Report it as an ExecutionResult, containing only errors and no data.
return ExecutionResult(data=None, errors=[error])
async def execute_subscription(context: ExecutionContext) -> AsyncIterable[Any]:
schema = context.schema
root_type = schema.subscription_type
if root_type is None:
raise GraphQLError(
"Schema is not configured to execute subscription operation.",
context.operation,
)
root_fields = collect_fields(
schema,
context.fragments,
context.variable_values,
root_type,
context.operation.selection_set,
)
response_name, field_nodes = next(iter(root_fields.items()))
field_def = get_field_def(schema, root_type, field_nodes[0])
if not field_def:
field_name = field_nodes[0].name.value
raise GraphQLError(
f"The subscription field '{field_name}' is not defined.", field_nodes
)
path = Path(None, response_name, root_type.name)
info = context.build_resolve_info(field_def, field_nodes, root_type, path)
# Implements the "ResolveFieldEventStream" algorithm from GraphQL specification.
# It differs from "ResolveFieldValue" due to providing a different `resolveFn`.
try:
# Build a dictionary of arguments from the field.arguments AST, using the
# variables scope to fulfill any variable references.
args = get_argument_values(field_def, field_nodes[0], context.variable_values)
# Call the `subscribe()` resolver or the default resolver to produce an
# AsyncIterable yielding raw payloads.
resolve_fn = field_def.subscribe or context.subscribe_field_resolver
event_stream = resolve_fn(context.root_value, info, **args)
if context.is_awaitable(event_stream):
event_stream = await event_stream
if isinstance(event_stream, Exception):
raise event_stream
return event_stream
except Exception as error:
raise located_error(error, field_nodes, path.as_list())
@@ -0,0 +1,251 @@
from typing import Any, Callable, Collection, Dict, List, Optional, Union, cast
from ..error import GraphQLError
from ..language import (
DirectiveNode,
EnumValueDefinitionNode,
ExecutableDefinitionNode,
FieldDefinitionNode,
FieldNode,
InputValueDefinitionNode,
NullValueNode,
SchemaDefinitionNode,
SelectionNode,
TypeDefinitionNode,
TypeExtensionNode,
VariableDefinitionNode,
VariableNode,
print_ast,
)
from ..pyutils import Undefined, inspect, print_path_list
from ..type import (
GraphQLDirective,
GraphQLField,
GraphQLInputType,
GraphQLSchema,
is_input_object_type,
is_input_type,
is_non_null_type,
)
from ..utilities.coerce_input_value import coerce_input_value
from ..utilities.type_from_ast import type_from_ast
from ..utilities.value_from_ast import value_from_ast
__all__ = ["get_argument_values", "get_directive_values", "get_variable_values"]
CoercedVariableValues = Union[List[GraphQLError], Dict[str, Any]]
def get_variable_values(
schema: GraphQLSchema,
var_def_nodes: Collection[VariableDefinitionNode],
inputs: Dict[str, Any],
max_errors: Optional[int] = None,
) -> CoercedVariableValues:
"""Get coerced variable values based on provided definitions.
Prepares a dict of variable values of the correct type based on the provided
variable definitions and arbitrary input. If the input cannot be parsed to match
the variable definitions, a GraphQLError will be raised.
"""
errors: List[GraphQLError] = []
def on_error(error: GraphQLError) -> None:
if max_errors is not None and len(errors) >= max_errors:
raise GraphQLError(
"Too many errors processing variables,"
" error limit reached. Execution aborted."
)
errors.append(error)
try:
coerced = coerce_variable_values(schema, var_def_nodes, inputs, on_error)
if not errors:
return coerced
except GraphQLError as e:
errors.append(e)
return errors
def coerce_variable_values(
schema: GraphQLSchema,
var_def_nodes: Collection[VariableDefinitionNode],
inputs: Dict[str, Any],
on_error: Callable[[GraphQLError], None],
) -> Dict[str, Any]:
coerced_values: Dict[str, Any] = {}
for var_def_node in var_def_nodes:
var_name = var_def_node.variable.name.value
var_type = type_from_ast(schema, var_def_node.type)
if not is_input_type(var_type):
# Must use input types for variables. This should be caught during
# validation, however is checked again here for safety.
var_type_str = print_ast(var_def_node.type)
on_error(
GraphQLError(
f"Variable '${var_name}' expected value of type '{var_type_str}'"
" which cannot be used as an input type.",
var_def_node.type,
)
)
continue
var_type = cast(GraphQLInputType, var_type)
if var_name not in inputs:
if var_def_node.default_value:
coerced_values[var_name] = value_from_ast(
var_def_node.default_value, var_type
)
elif is_non_null_type(var_type): # pragma: no cover else
var_type_str = inspect(var_type)
on_error(
GraphQLError(
f"Variable '${var_name}' of required type '{var_type_str}'"
" was not provided.",
var_def_node,
)
)
continue
value = inputs[var_name]
if value is None and is_non_null_type(var_type):
var_type_str = inspect(var_type)
on_error(
GraphQLError(
f"Variable '${var_name}' of non-null type '{var_type_str}'"
" must not be null.",
var_def_node,
)
)
continue
def on_input_value_error(
path: List[Union[str, int]], invalid_value: Any, error: GraphQLError
) -> None:
invalid_str = inspect(invalid_value)
prefix = f"Variable '${var_name}' got invalid value {invalid_str}"
if path:
prefix += f" at '{var_name}{print_path_list(path)}'"
on_error(
GraphQLError(
prefix + "; " + error.message,
var_def_node,
original_error=error,
)
)
coerced_values[var_name] = coerce_input_value(
value, var_type, on_input_value_error
)
return coerced_values
def get_argument_values(
type_def: Union[GraphQLField, GraphQLDirective],
node: Union[FieldNode, DirectiveNode],
variable_values: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""Get coerced argument values based on provided definitions and nodes.
Prepares a dict of argument values given a list of argument definitions and list
of argument AST nodes.
"""
coerced_values: Dict[str, Any] = {}
arg_node_map = {arg.name.value: arg for arg in node.arguments or []}
for name, arg_def in type_def.args.items():
arg_type = arg_def.type
argument_node = arg_node_map.get(name)
if argument_node is None:
value = arg_def.default_value
if value is not Undefined:
if is_input_object_type(arg_def.type):
# coerce input value so that out_names are used
value = coerce_input_value(value, arg_def.type)
coerced_values[arg_def.out_name or name] = value
elif is_non_null_type(arg_type): # pragma: no cover else
raise GraphQLError(
f"Argument '{name}' of required type '{arg_type}'"
" was not provided.",
node,
)
continue # pragma: no cover
value_node = argument_node.value
is_null = isinstance(argument_node.value, NullValueNode)
if isinstance(value_node, VariableNode):
variable_name = value_node.name.value
if variable_values is None or variable_name not in variable_values:
value = arg_def.default_value
if value is not Undefined:
if is_input_object_type(arg_def.type):
# coerce input value so that out_names are used
value = coerce_input_value(value, arg_def.type)
coerced_values[arg_def.out_name or name] = value
elif is_non_null_type(arg_type): # pragma: no cover else
raise GraphQLError(
f"Argument '{name}' of required type '{arg_type}'"
f" was provided the variable '${variable_name}'"
" which was not provided a runtime value.",
value_node,
)
continue # pragma: no cover
is_null = variable_values[variable_name] is None
if is_null and is_non_null_type(arg_type):
raise GraphQLError(
f"Argument '{name}' of non-null type '{arg_type}' must not be null.",
value_node,
)
coerced_value = value_from_ast(value_node, arg_type, variable_values)
if coerced_value is Undefined:
# Note: `values_of_correct_type` validation should catch this before
# execution. This is a runtime check to ensure execution does not
# continue with an invalid argument value.
raise GraphQLError(
f"Argument '{name}' has invalid value {print_ast(value_node)}.",
value_node,
)
coerced_values[arg_def.out_name or name] = coerced_value
return coerced_values
NodeWithDirective = Union[
EnumValueDefinitionNode,
ExecutableDefinitionNode,
FieldDefinitionNode,
InputValueDefinitionNode,
SelectionNode,
SchemaDefinitionNode,
TypeDefinitionNode,
TypeExtensionNode,
]
def get_directive_values(
directive_def: GraphQLDirective,
node: NodeWithDirective,
variable_values: Optional[Dict[str, Any]] = None,
) -> Optional[Dict[str, Any]]:
"""Get coerced argument values based on provided nodes.
Prepares a dict of argument values given a directive definition and an AST node
which may contain directives. Optionally also accepts a dict of variable values.
If the directive does not exist on the node, returns None.
"""
directives = node.directives
if directives:
directive_name = directive_def.name
for directive in directives:
if directive.name.value == directive_name:
return get_argument_values(directive_def, directive, variable_values)
return None
@@ -0,0 +1,198 @@
from asyncio import ensure_future
from inspect import isawaitable
from typing import Any, Callable, Dict, Optional, Type, Union
from .error import GraphQLError
from .execution import ExecutionContext, ExecutionResult, Middleware, execute
from .language import Source, parse
from .pyutils import AwaitableOrValue
from .type import (
GraphQLFieldResolver,
GraphQLSchema,
GraphQLTypeResolver,
validate_schema,
)
__all__ = ["graphql", "graphql_sync"]
async def graphql(
schema: GraphQLSchema,
source: Union[str, Source],
root_value: Any = None,
context_value: Any = None,
variable_values: Optional[Dict[str, Any]] = None,
operation_name: Optional[str] = None,
field_resolver: Optional[GraphQLFieldResolver] = None,
type_resolver: Optional[GraphQLTypeResolver] = None,
middleware: Optional[Middleware] = None,
execution_context_class: Optional[Type[ExecutionContext]] = None,
is_awaitable: Optional[Callable[[Any], bool]] = None,
) -> ExecutionResult:
"""Execute a GraphQL operation asynchronously.
This is the primary entry point function for fulfilling GraphQL operations by
parsing, validating, and executing a GraphQL document along side a GraphQL schema.
More sophisticated GraphQL servers, such as those which persist queries, may wish
to separate the validation and execution phases to a static time tooling step,
and a server runtime step.
Accepts the following arguments:
:arg schema:
The GraphQL type system to use when validating and executing a query.
:arg source:
A GraphQL language formatted string representing the requested operation.
:arg root_value:
The value provided as the first argument to resolver functions on the top level
type (e.g. the query object type).
:arg context_value:
The context value is provided as an attribute of the second argument
(the resolve info) to resolver functions. It is used to pass shared information
useful at any point during query execution, for example the currently logged in
user and connections to databases or other services.
:arg variable_values:
A mapping of variable name to runtime value to use for all variables defined
in the request string.
:arg operation_name:
The name of the operation to use if request string contains multiple possible
operations. Can be omitted if request string contains only one operation.
:arg field_resolver:
A resolver function to use when one is not provided by the schema.
If not provided, the default field resolver is used (which looks for a value
or method on the source value with the field's name).
:arg type_resolver:
A type resolver function to use when none is provided by the schema.
If not provided, the default type resolver is used (which looks for a
``__typename`` field or alternatively calls the
:meth:`~graphql.type.GraphQLObjectType.is_type_of` method).
:arg middleware:
The middleware to wrap the resolvers with
:arg execution_context_class:
The execution context class to use to build the context
:arg is_awaitable:
The predicate to be used for checking whether values are awaitable
"""
# Always return asynchronously for a consistent API.
result = graphql_impl(
schema,
source,
root_value,
context_value,
variable_values,
operation_name,
field_resolver,
type_resolver,
middleware,
execution_context_class,
is_awaitable,
)
if isawaitable(result):
return await result
return result
def assume_not_awaitable(_value: Any) -> bool:
"""Replacement for isawaitable if everything is assumed to be synchronous."""
return False
def graphql_sync(
schema: GraphQLSchema,
source: Union[str, Source],
root_value: Any = None,
context_value: Any = None,
variable_values: Optional[Dict[str, Any]] = None,
operation_name: Optional[str] = None,
field_resolver: Optional[GraphQLFieldResolver] = None,
type_resolver: Optional[GraphQLTypeResolver] = None,
middleware: Optional[Middleware] = None,
execution_context_class: Optional[Type[ExecutionContext]] = None,
check_sync: bool = False,
) -> ExecutionResult:
"""Execute a GraphQL operation synchronously.
The graphql_sync function also fulfills GraphQL operations by parsing, validating,
and executing a GraphQL document along side a GraphQL schema. However, it guarantees
to complete synchronously (or throw an error) assuming that all field resolvers
are also synchronous.
Set check_sync to True to still run checks that no awaitable values are returned.
"""
is_awaitable = (
check_sync
if callable(check_sync)
else (None if check_sync else assume_not_awaitable)
)
result = graphql_impl(
schema,
source,
root_value,
context_value,
variable_values,
operation_name,
field_resolver,
type_resolver,
middleware,
execution_context_class,
is_awaitable,
)
# Assert that the execution was synchronous.
if isawaitable(result):
ensure_future(result).cancel()
raise RuntimeError("GraphQL execution failed to complete synchronously.")
return result
def graphql_impl(
schema: GraphQLSchema,
source: Union[str, Source],
root_value: Any,
context_value: Any,
variable_values: Optional[Dict[str, Any]],
operation_name: Optional[str],
field_resolver: Optional[GraphQLFieldResolver],
type_resolver: Optional[GraphQLTypeResolver],
middleware: Optional[Middleware],
execution_context_class: Optional[Type[ExecutionContext]],
is_awaitable: Optional[Callable[[Any], bool]],
) -> AwaitableOrValue[ExecutionResult]:
"""Execute a query, return asynchronously only if necessary."""
# Validate Schema
schema_validation_errors = validate_schema(schema)
if schema_validation_errors:
return ExecutionResult(data=None, errors=schema_validation_errors)
# Parse
try:
document = parse(source)
except GraphQLError as error:
return ExecutionResult(data=None, errors=[error])
# Validate
from .validation import validate
validation_errors = validate(schema, document)
if validation_errors:
return ExecutionResult(data=None, errors=validation_errors)
# Execute
return execute(
schema,
document,
root_value,
context_value,
variable_values,
operation_name,
field_resolver,
type_resolver,
None,
middleware,
execution_context_class,
is_awaitable,
)
@@ -0,0 +1,208 @@
"""GraphQL Language
The :mod:`graphql.language` package is responsible for parsing and operating on the
GraphQL language.
"""
from .source import Source
from .location import get_location, SourceLocation, FormattedSourceLocation
from .print_location import print_location, print_source_location
from .token_kind import TokenKind
from .lexer import Lexer
from .parser import parse, parse_type, parse_value, parse_const_value
from .printer import print_ast
from .visitor import (
visit,
Visitor,
ParallelVisitor,
VisitorAction,
VisitorKeyMap,
BREAK,
SKIP,
REMOVE,
IDLE,
)
from .ast import (
Location,
Token,
Node,
# Each kind of AST node
NameNode,
DocumentNode,
DefinitionNode,
ExecutableDefinitionNode,
OperationDefinitionNode,
OperationType,
VariableDefinitionNode,
VariableNode,
SelectionSetNode,
SelectionNode,
FieldNode,
ArgumentNode,
ConstArgumentNode,
FragmentSpreadNode,
InlineFragmentNode,
FragmentDefinitionNode,
ValueNode,
ConstValueNode,
IntValueNode,
FloatValueNode,
StringValueNode,
BooleanValueNode,
NullValueNode,
EnumValueNode,
ListValueNode,
ConstListValueNode,
ObjectValueNode,
ConstObjectValueNode,
ObjectFieldNode,
ConstObjectFieldNode,
DirectiveNode,
ConstDirectiveNode,
TypeNode,
NamedTypeNode,
ListTypeNode,
NonNullTypeNode,
TypeSystemDefinitionNode,
SchemaDefinitionNode,
OperationTypeDefinitionNode,
TypeDefinitionNode,
ScalarTypeDefinitionNode,
ObjectTypeDefinitionNode,
FieldDefinitionNode,
InputValueDefinitionNode,
InterfaceTypeDefinitionNode,
UnionTypeDefinitionNode,
EnumTypeDefinitionNode,
EnumValueDefinitionNode,
InputObjectTypeDefinitionNode,
DirectiveDefinitionNode,
TypeSystemExtensionNode,
SchemaExtensionNode,
TypeExtensionNode,
ScalarTypeExtensionNode,
ObjectTypeExtensionNode,
InterfaceTypeExtensionNode,
UnionTypeExtensionNode,
EnumTypeExtensionNode,
InputObjectTypeExtensionNode,
)
from .predicates import (
is_definition_node,
is_executable_definition_node,
is_selection_node,
is_value_node,
is_const_value_node,
is_type_node,
is_type_system_definition_node,
is_type_definition_node,
is_type_system_extension_node,
is_type_extension_node,
)
from .directive_locations import DirectiveLocation
__all__ = [
"get_location",
"SourceLocation",
"FormattedSourceLocation",
"print_location",
"print_source_location",
"TokenKind",
"Lexer",
"parse",
"parse_value",
"parse_const_value",
"parse_type",
"print_ast",
"Source",
"visit",
"Visitor",
"ParallelVisitor",
"VisitorAction",
"VisitorKeyMap",
"BREAK",
"SKIP",
"REMOVE",
"IDLE",
"Location",
"Token",
"DirectiveLocation",
"Node",
"NameNode",
"DocumentNode",
"DefinitionNode",
"ExecutableDefinitionNode",
"OperationDefinitionNode",
"OperationType",
"VariableDefinitionNode",
"VariableNode",
"SelectionSetNode",
"SelectionNode",
"FieldNode",
"ArgumentNode",
"ConstArgumentNode",
"FragmentSpreadNode",
"InlineFragmentNode",
"FragmentDefinitionNode",
"ValueNode",
"ConstValueNode",
"IntValueNode",
"FloatValueNode",
"StringValueNode",
"BooleanValueNode",
"NullValueNode",
"EnumValueNode",
"ListValueNode",
"ConstListValueNode",
"ObjectValueNode",
"ConstObjectValueNode",
"ObjectFieldNode",
"ConstObjectFieldNode",
"DirectiveNode",
"ConstDirectiveNode",
"TypeNode",
"NamedTypeNode",
"ListTypeNode",
"NonNullTypeNode",
"TypeSystemDefinitionNode",
"SchemaDefinitionNode",
"OperationTypeDefinitionNode",
"TypeDefinitionNode",
"ScalarTypeDefinitionNode",
"ObjectTypeDefinitionNode",
"FieldDefinitionNode",
"InputValueDefinitionNode",
"InterfaceTypeDefinitionNode",
"UnionTypeDefinitionNode",
"EnumTypeDefinitionNode",
"EnumValueDefinitionNode",
"InputObjectTypeDefinitionNode",
"DirectiveDefinitionNode",
"TypeSystemExtensionNode",
"SchemaExtensionNode",
"TypeExtensionNode",
"ScalarTypeExtensionNode",
"ObjectTypeExtensionNode",
"InterfaceTypeExtensionNode",
"UnionTypeExtensionNode",
"EnumTypeExtensionNode",
"InputObjectTypeExtensionNode",
"is_definition_node",
"is_executable_definition_node",
"is_selection_node",
"is_value_node",
"is_const_value_node",
"is_type_node",
"is_type_system_definition_node",
"is_type_definition_node",
"is_type_system_extension_node",
"is_type_extension_node",
]
@@ -0,0 +1,807 @@
from copy import copy, deepcopy
from enum import Enum
from typing import Any, Dict, List, Tuple, Optional, Union
from .source import Source
from .token_kind import TokenKind
from ..pyutils import camel_to_snake
__all__ = [
"Location",
"Token",
"Node",
"NameNode",
"DocumentNode",
"DefinitionNode",
"ExecutableDefinitionNode",
"OperationDefinitionNode",
"VariableDefinitionNode",
"SelectionSetNode",
"SelectionNode",
"FieldNode",
"ArgumentNode",
"ConstArgumentNode",
"FragmentSpreadNode",
"InlineFragmentNode",
"FragmentDefinitionNode",
"ValueNode",
"ConstValueNode",
"VariableNode",
"IntValueNode",
"FloatValueNode",
"StringValueNode",
"BooleanValueNode",
"NullValueNode",
"EnumValueNode",
"ListValueNode",
"ConstListValueNode",
"ObjectValueNode",
"ConstObjectValueNode",
"ObjectFieldNode",
"ConstObjectFieldNode",
"DirectiveNode",
"ConstDirectiveNode",
"TypeNode",
"NamedTypeNode",
"ListTypeNode",
"NonNullTypeNode",
"TypeSystemDefinitionNode",
"SchemaDefinitionNode",
"OperationType",
"OperationTypeDefinitionNode",
"TypeDefinitionNode",
"ScalarTypeDefinitionNode",
"ObjectTypeDefinitionNode",
"FieldDefinitionNode",
"InputValueDefinitionNode",
"InterfaceTypeDefinitionNode",
"UnionTypeDefinitionNode",
"EnumTypeDefinitionNode",
"EnumValueDefinitionNode",
"InputObjectTypeDefinitionNode",
"DirectiveDefinitionNode",
"SchemaExtensionNode",
"TypeExtensionNode",
"TypeSystemExtensionNode",
"ScalarTypeExtensionNode",
"ObjectTypeExtensionNode",
"InterfaceTypeExtensionNode",
"UnionTypeExtensionNode",
"EnumTypeExtensionNode",
"InputObjectTypeExtensionNode",
"QUERY_DOCUMENT_KEYS",
]
class Token:
"""AST Token
Represents a range of characters represented by a lexical token within a Source.
"""
__slots__ = "kind", "start", "end", "line", "column", "prev", "next", "value"
kind: TokenKind # the kind of token
start: int # the character offset at which this Node begins
end: int # the character offset at which this Node ends
line: int # the 1-indexed line number on which this Token appears
column: int # the 1-indexed column number at which this Token begins
# for non-punctuation tokens, represents the interpreted value of the token:
value: Optional[str]
# Tokens exist as nodes in a double-linked-list amongst all tokens including
# ignored tokens. <SOF> is always the first node and <EOF> the last.
prev: Optional["Token"]
next: Optional["Token"]
def __init__(
self,
kind: TokenKind,
start: int,
end: int,
line: int,
column: int,
value: Optional[str] = None,
) -> None:
self.kind = kind
self.start, self.end = start, end
self.line, self.column = line, column
self.value = value
self.prev = self.next = None
def __str__(self) -> str:
return self.desc
def __repr__(self) -> str:
"""Print a simplified form when appearing in repr() or inspect()."""
return f"<Token {self.desc} {self.line}:{self.column}>"
def __inspect__(self) -> str:
return repr(self)
def __eq__(self, other: Any) -> bool:
if isinstance(other, Token):
return (
self.kind == other.kind
and self.start == other.start
and self.end == other.end
and self.line == other.line
and self.column == other.column
and self.value == other.value
)
elif isinstance(other, str):
return other == self.desc
return False
def __hash__(self) -> int:
return hash(
(self.kind, self.start, self.end, self.line, self.column, self.value)
)
def __copy__(self) -> "Token":
"""Create a shallow copy of the token"""
token = self.__class__(
self.kind,
self.start,
self.end,
self.line,
self.column,
self.value,
)
token.prev = self.prev
return token
def __deepcopy__(self, memo: Dict) -> "Token":
"""Allow only shallow copies to avoid recursion."""
return copy(self)
def __getstate__(self) -> Dict[str, Any]:
"""Remove the links when pickling.
Keeping the links would make pickling a schema too expensive.
"""
return {
key: getattr(self, key)
for key in self.__slots__
if key not in {"prev", "next"}
}
def __setstate__(self, state: Dict[str, Any]) -> None:
"""Reset the links when un-pickling."""
for key, value in state.items():
setattr(self, key, value)
self.prev = self.next = None
@property
def desc(self) -> str:
"""A helper property to describe a token as a string for debugging"""
kind, value = self.kind.value, self.value
return f"{kind} {value!r}" if value else kind
class Location:
"""AST Location
Contains a range of UTF-8 character offsets and token references that identify the
region of the source from which the AST derived.
"""
__slots__ = (
"start",
"end",
"start_token",
"end_token",
"source",
)
start: int # character offset at which this Node begins
end: int # character offset at which this Node ends
start_token: Token # Token at which this Node begins
end_token: Token # Token at which this Node ends.
source: Source # Source document the AST represents
def __init__(self, start_token: Token, end_token: Token, source: Source) -> None:
self.start = start_token.start
self.end = end_token.end
self.start_token = start_token
self.end_token = end_token
self.source = source
def __str__(self) -> str:
return f"{self.start}:{self.end}"
def __repr__(self) -> str:
"""Print a simplified form when appearing in repr() or inspect()."""
return f"<Location {self.start}:{self.end}>"
def __inspect__(self) -> str:
return repr(self)
def __eq__(self, other: Any) -> bool:
if isinstance(other, Location):
return self.start == other.start and self.end == other.end
elif isinstance(other, (list, tuple)) and len(other) == 2:
return self.start == other[0] and self.end == other[1]
return False
def __ne__(self, other: Any) -> bool:
return not self == other
def __hash__(self) -> int:
return hash((self.start, self.end))
class OperationType(Enum):
QUERY = "query"
MUTATION = "mutation"
SUBSCRIPTION = "subscription"
# Default map from node kinds to their node attributes (internal)
QUERY_DOCUMENT_KEYS: Dict[str, Tuple[str, ...]] = {
"name": (),
"document": ("definitions",),
"operation_definition": (
"name",
"variable_definitions",
"directives",
"selection_set",
),
"variable_definition": ("variable", "type", "default_value", "directives"),
"variable": ("name",),
"selection_set": ("selections",),
"field": ("alias", "name", "arguments", "directives", "selection_set"),
"argument": ("name", "value"),
"fragment_spread": ("name", "directives"),
"inline_fragment": ("type_condition", "directives", "selection_set"),
"fragment_definition": (
# Note: fragment variable definitions are deprecated and will be removed in v3.3
"name",
"variable_definitions",
"type_condition",
"directives",
"selection_set",
),
"list_value": ("values",),
"object_value": ("fields",),
"object_field": ("name", "value"),
"directive": ("name", "arguments"),
"named_type": ("name",),
"list_type": ("type",),
"non_null_type": ("type",),
"schema_definition": ("description", "directives", "operation_types"),
"operation_type_definition": ("type",),
"scalar_type_definition": ("description", "name", "directives"),
"object_type_definition": (
"description",
"name",
"interfaces",
"directives",
"fields",
),
"field_definition": ("description", "name", "arguments", "type", "directives"),
"input_value_definition": (
"description",
"name",
"type",
"default_value",
"directives",
),
"interface_type_definition": (
"description",
"name",
"interfaces",
"directives",
"fields",
),
"union_type_definition": ("description", "name", "directives", "types"),
"enum_type_definition": ("description", "name", "directives", "values"),
"enum_value_definition": ("description", "name", "directives"),
"input_object_type_definition": ("description", "name", "directives", "fields"),
"directive_definition": ("description", "name", "arguments", "locations"),
"schema_extension": ("directives", "operation_types"),
"scalar_type_extension": ("name", "directives"),
"object_type_extension": ("name", "interfaces", "directives", "fields"),
"interface_type_extension": ("name", "interfaces", "directives", "fields"),
"union_type_extension": ("name", "directives", "types"),
"enum_type_extension": ("name", "directives", "values"),
"input_object_type_extension": ("name", "directives", "fields"),
}
# Base AST Node
class Node:
"""AST nodes"""
# allow custom attributes and weak references (not used internally)
__slots__ = "__dict__", "__weakref__", "loc", "_hash"
loc: Optional[Location]
kind: str = "ast" # the kind of the node as a snake_case string
keys: Tuple[str, ...] = ("loc",) # the names of the attributes of this node
def __init__(self, **kwargs: Any) -> None:
"""Initialize the node with the given keyword arguments."""
for key in self.keys:
value = kwargs.get(key)
if isinstance(value, list):
value = tuple(value)
setattr(self, key, value)
def __repr__(self) -> str:
"""Get a simple representation of the node."""
name, loc = self.__class__.__name__, getattr(self, "loc", None)
return f"{name} at {loc}" if loc else name
def __eq__(self, other: Any) -> bool:
"""Test whether two nodes are equal (recursively)."""
return (
isinstance(other, Node)
and self.__class__ == other.__class__
and all(getattr(self, key) == getattr(other, key) for key in self.keys)
)
def __hash__(self) -> int:
"""Get a cached hash value for the node."""
# Caching the hash values improves the performance of AST validators
hashed = getattr(self, "_hash", None)
if hashed is None:
self._hash = id(self) # avoid recursion
hashed = hash(tuple(getattr(self, key) for key in self.keys))
self._hash = hashed
return hashed
def __setattr__(self, key: str, value: Any) -> None:
# reset cashed hash value if attributes are changed
if hasattr(self, "_hash") and key in self.keys:
del self._hash
super().__setattr__(key, value)
def __copy__(self) -> "Node":
"""Create a shallow copy of the node."""
return self.__class__(**{key: getattr(self, key) for key in self.keys})
def __deepcopy__(self, memo: Dict) -> "Node":
"""Create a deep copy of the node"""
# noinspection PyArgumentList
return self.__class__(
**{key: deepcopy(getattr(self, key), memo) for key in self.keys}
)
def __init_subclass__(cls) -> None:
super().__init_subclass__()
name = cls.__name__
try:
name = name.removeprefix("Const").removesuffix("Node")
except AttributeError: # pragma: no cover (Python < 3.9)
if name.startswith("Const"):
name = name[5:]
if name.endswith("Node"):
name = name[:-4]
cls.kind = camel_to_snake(name)
keys: List[str] = []
for base in cls.__bases__:
# noinspection PyUnresolvedReferences
keys.extend(base.keys) # type: ignore
keys.extend(cls.__slots__)
cls.keys = tuple(keys)
def to_dict(self, locations: bool = False) -> Dict:
from ..utilities import ast_to_dict
return ast_to_dict(self, locations)
# Name
class NameNode(Node):
__slots__ = ("value",)
value: str
# Document
class DocumentNode(Node):
__slots__ = ("definitions",)
definitions: Tuple["DefinitionNode", ...]
class DefinitionNode(Node):
__slots__ = ()
class ExecutableDefinitionNode(DefinitionNode):
__slots__ = "name", "directives", "variable_definitions", "selection_set"
name: Optional[NameNode]
directives: Tuple["DirectiveNode", ...]
variable_definitions: Tuple["VariableDefinitionNode", ...]
selection_set: "SelectionSetNode"
class OperationDefinitionNode(ExecutableDefinitionNode):
__slots__ = ("operation",)
operation: OperationType
class VariableDefinitionNode(Node):
__slots__ = "variable", "type", "default_value", "directives"
variable: "VariableNode"
type: "TypeNode"
default_value: Optional["ConstValueNode"]
directives: Tuple["ConstDirectiveNode", ...]
class SelectionSetNode(Node):
__slots__ = ("selections",)
selections: Tuple["SelectionNode", ...]
class SelectionNode(Node):
__slots__ = ("directives",)
directives: Tuple["DirectiveNode", ...]
class FieldNode(SelectionNode):
__slots__ = "alias", "name", "arguments", "selection_set"
alias: Optional[NameNode]
name: NameNode
arguments: Tuple["ArgumentNode", ...]
selection_set: Optional[SelectionSetNode]
class ArgumentNode(Node):
__slots__ = "name", "value"
name: NameNode
value: "ValueNode"
class ConstArgumentNode(ArgumentNode):
value: "ConstValueNode"
# Fragments
class FragmentSpreadNode(SelectionNode):
__slots__ = ("name",)
name: NameNode
class InlineFragmentNode(SelectionNode):
__slots__ = "type_condition", "selection_set"
type_condition: "NamedTypeNode"
selection_set: SelectionSetNode
class FragmentDefinitionNode(ExecutableDefinitionNode):
__slots__ = ("type_condition",)
name: NameNode
type_condition: "NamedTypeNode"
# Values
class ValueNode(Node):
__slots__ = ()
class VariableNode(ValueNode):
__slots__ = ("name",)
name: NameNode
class IntValueNode(ValueNode):
__slots__ = ("value",)
value: str
class FloatValueNode(ValueNode):
__slots__ = ("value",)
value: str
class StringValueNode(ValueNode):
__slots__ = "value", "block"
value: str
block: Optional[bool]
class BooleanValueNode(ValueNode):
__slots__ = ("value",)
value: bool
class NullValueNode(ValueNode):
__slots__ = ()
class EnumValueNode(ValueNode):
__slots__ = ("value",)
value: str
class ListValueNode(ValueNode):
__slots__ = ("values",)
values: Tuple[ValueNode, ...]
class ConstListValueNode(ListValueNode):
values: Tuple["ConstValueNode", ...]
class ObjectValueNode(ValueNode):
__slots__ = ("fields",)
fields: Tuple["ObjectFieldNode", ...]
class ConstObjectValueNode(ObjectValueNode):
fields: Tuple["ConstObjectFieldNode", ...]
class ObjectFieldNode(Node):
__slots__ = "name", "value"
name: NameNode
value: ValueNode
class ConstObjectFieldNode(ObjectFieldNode):
value: "ConstValueNode"
ConstValueNode = Union[
IntValueNode,
FloatValueNode,
StringValueNode,
BooleanValueNode,
NullValueNode,
EnumValueNode,
ConstListValueNode,
ConstObjectValueNode,
]
# Directives
class DirectiveNode(Node):
__slots__ = "name", "arguments"
name: NameNode
arguments: Tuple[ArgumentNode, ...]
class ConstDirectiveNode(DirectiveNode):
arguments: Tuple[ConstArgumentNode, ...]
# Type Reference
class TypeNode(Node):
__slots__ = ()
class NamedTypeNode(TypeNode):
__slots__ = ("name",)
name: NameNode
class ListTypeNode(TypeNode):
__slots__ = ("type",)
type: TypeNode
class NonNullTypeNode(TypeNode):
__slots__ = ("type",)
type: Union[NamedTypeNode, ListTypeNode]
# Type System Definition
class TypeSystemDefinitionNode(DefinitionNode):
__slots__ = ()
class SchemaDefinitionNode(TypeSystemDefinitionNode):
__slots__ = "description", "directives", "operation_types"
description: Optional[StringValueNode]
directives: Tuple[ConstDirectiveNode, ...]
operation_types: Tuple["OperationTypeDefinitionNode", ...]
class OperationTypeDefinitionNode(Node):
__slots__ = "operation", "type"
operation: OperationType
type: NamedTypeNode
# Type Definition
class TypeDefinitionNode(TypeSystemDefinitionNode):
__slots__ = "description", "name", "directives"
description: Optional[StringValueNode]
name: NameNode
directives: Tuple[DirectiveNode, ...]
class ScalarTypeDefinitionNode(TypeDefinitionNode):
__slots__ = ()
directives: Tuple[ConstDirectiveNode, ...]
class ObjectTypeDefinitionNode(TypeDefinitionNode):
__slots__ = "interfaces", "fields"
interfaces: Tuple[NamedTypeNode, ...]
directives: Tuple[ConstDirectiveNode, ...]
fields: Tuple["FieldDefinitionNode", ...]
class FieldDefinitionNode(DefinitionNode):
__slots__ = "description", "name", "directives", "arguments", "type"
description: Optional[StringValueNode]
name: NameNode
directives: Tuple[ConstDirectiveNode, ...]
arguments: Tuple["InputValueDefinitionNode", ...]
type: TypeNode
class InputValueDefinitionNode(DefinitionNode):
__slots__ = "description", "name", "directives", "type", "default_value"
description: Optional[StringValueNode]
name: NameNode
directives: Tuple[ConstDirectiveNode, ...]
type: TypeNode
default_value: Optional[ConstValueNode]
class InterfaceTypeDefinitionNode(TypeDefinitionNode):
__slots__ = "fields", "interfaces"
fields: Tuple["FieldDefinitionNode", ...]
directives: Tuple[ConstDirectiveNode, ...]
interfaces: Tuple[NamedTypeNode, ...]
class UnionTypeDefinitionNode(TypeDefinitionNode):
__slots__ = ("types",)
directives: Tuple[ConstDirectiveNode, ...]
types: Tuple[NamedTypeNode, ...]
class EnumTypeDefinitionNode(TypeDefinitionNode):
__slots__ = ("values",)
directives: Tuple[ConstDirectiveNode, ...]
values: Tuple["EnumValueDefinitionNode", ...]
class EnumValueDefinitionNode(DefinitionNode):
__slots__ = "description", "name", "directives"
description: Optional[StringValueNode]
name: NameNode
directives: Tuple[ConstDirectiveNode, ...]
class InputObjectTypeDefinitionNode(TypeDefinitionNode):
__slots__ = ("fields",)
directives: Tuple[ConstDirectiveNode, ...]
fields: Tuple[InputValueDefinitionNode, ...]
# Directive Definitions
class DirectiveDefinitionNode(TypeSystemDefinitionNode):
__slots__ = "description", "name", "arguments", "repeatable", "locations"
description: Optional[StringValueNode]
name: NameNode
arguments: Tuple[InputValueDefinitionNode, ...]
repeatable: bool
locations: Tuple[NameNode, ...]
# Type System Extensions
class SchemaExtensionNode(Node):
__slots__ = "directives", "operation_types"
directives: Tuple[ConstDirectiveNode, ...]
operation_types: Tuple[OperationTypeDefinitionNode, ...]
# Type Extensions
class TypeExtensionNode(TypeSystemDefinitionNode):
__slots__ = "name", "directives"
name: NameNode
directives: Tuple[ConstDirectiveNode, ...]
TypeSystemExtensionNode = Union[SchemaExtensionNode, TypeExtensionNode]
class ScalarTypeExtensionNode(TypeExtensionNode):
__slots__ = ()
class ObjectTypeExtensionNode(TypeExtensionNode):
__slots__ = "interfaces", "fields"
interfaces: Tuple[NamedTypeNode, ...]
fields: Tuple[FieldDefinitionNode, ...]
class InterfaceTypeExtensionNode(TypeExtensionNode):
__slots__ = "interfaces", "fields"
interfaces: Tuple[NamedTypeNode, ...]
fields: Tuple[FieldDefinitionNode, ...]
class UnionTypeExtensionNode(TypeExtensionNode):
__slots__ = ("types",)
types: Tuple[NamedTypeNode, ...]
class EnumTypeExtensionNode(TypeExtensionNode):
__slots__ = ("values",)
values: Tuple[EnumValueDefinitionNode, ...]
class InputObjectTypeExtensionNode(TypeExtensionNode):
__slots__ = ("fields",)
fields: Tuple[InputValueDefinitionNode, ...]
@@ -0,0 +1,155 @@
from typing import Collection, List
from sys import maxsize
__all__ = [
"dedent_block_string_lines",
"is_printable_as_block_string",
"print_block_string",
]
def dedent_block_string_lines(lines: Collection[str]) -> List[str]:
"""Produce the value of a block string from its parsed raw value.
This function works similar to CoffeeScript's block string,
Python's docstring trim or Ruby's strip_heredoc.
It implements the GraphQL spec's BlockStringValue() static algorithm.
Note that this is very similar to Python's inspect.cleandoc() function.
The difference is that the latter also expands tabs to spaces and
removes whitespace at the beginning of the first line. Python also has
textwrap.dedent() which uses a completely different algorithm.
For internal use only.
"""
common_indent = maxsize
first_non_empty_line = None
last_non_empty_line = -1
for i, line in enumerate(lines):
indent = leading_white_space(line)
if indent == len(line):
continue # skip empty lines
if first_non_empty_line is None:
first_non_empty_line = i
last_non_empty_line = i
if i and indent < common_indent:
common_indent = indent
if first_non_empty_line is None:
first_non_empty_line = 0
return [ # Remove common indentation from all lines but first.
line[common_indent:] if i else line for i, line in enumerate(lines)
][ # Remove leading and trailing blank lines.
first_non_empty_line : last_non_empty_line + 1
]
def leading_white_space(s: str) -> int:
i = 0
for c in s:
if c not in " \t":
return i
i += 1
return i
def is_printable_as_block_string(value: str) -> bool:
"""Check whether the given string is printable as a block string.
For internal use only.
"""
if not isinstance(value, str):
value = str(value) # resolve lazy string proxy object
if not value:
return True # emtpy string is printable
is_empty_line = True
has_indent = False
has_common_indent = True
seen_non_empty_line = False
for c in value:
if c == "\n":
if is_empty_line and not seen_non_empty_line:
return False # has leading new line
seen_non_empty_line = True
is_empty_line = True
has_indent = False
elif c in " \t":
has_indent = has_indent or is_empty_line
elif c <= "\x0f":
return False
else:
has_common_indent = has_common_indent and has_indent
is_empty_line = False
if is_empty_line:
return False # has trailing empty lines
if has_common_indent and seen_non_empty_line:
return False # has internal indent
return True
def print_block_string(value: str, minimize: bool = False) -> str:
"""Print a block string in the indented block form.
Prints a block string in the indented block form by adding a leading and
trailing blank line. However, if a block string starts with whitespace and
is a single-line, adding a leading blank line would strip that whitespace.
For internal use only.
"""
if not isinstance(value, str):
value = str(value) # resolve lazy string proxy object
escaped_value = value.replace('"""', '\\"""')
# Expand a block string's raw value into independent lines.
lines = escaped_value.splitlines() or [""]
num_lines = len(lines)
is_single_line = num_lines == 1
# If common indentation is found,
# we can fix some of those cases by adding a leading new line.
force_leading_new_line = num_lines > 1 and all(
not line or line[0] in " \t" for line in lines[1:]
)
# Trailing triple quotes just looks confusing but doesn't force trailing new line.
has_trailing_triple_quotes = escaped_value.endswith('\\"""')
# Trailing quote (single or double) or slash forces trailing new line
has_trailing_quote = value.endswith('"') and not has_trailing_triple_quotes
has_trailing_slash = value.endswith("\\")
force_trailing_new_line = has_trailing_quote or has_trailing_slash
print_as_multiple_lines = not minimize and (
# add leading and trailing new lines only if it improves readability
not is_single_line
or len(value) > 70
or force_trailing_new_line
or force_leading_new_line
or has_trailing_triple_quotes
)
# Format a multi-line block quote to account for leading space.
skip_leading_new_line = is_single_line and value and value[0] in " \t"
before = (
"\n"
if print_as_multiple_lines
and not skip_leading_new_line
or force_leading_new_line
else ""
)
after = "\n" if print_as_multiple_lines or force_trailing_new_line else ""
return f'"""{before}{escaped_value}{after}"""'
@@ -0,0 +1,68 @@
__all__ = ["is_digit", "is_letter", "is_name_start", "is_name_continue"]
try:
"string".isascii()
except AttributeError: # Python < 3.7
def is_digit(char: str) -> bool:
"""Check whether char is a digit
For internal use by the lexer only.
"""
return "0" <= char <= "9"
def is_letter(char: str) -> bool:
"""Check whether char is a plain ASCII letter
For internal use by the lexer only.
"""
return "a" <= char <= "z" or "A" <= char <= "Z"
def is_name_start(char: str) -> bool:
"""Check whether char is allowed at the beginning of a GraphQL name
For internal use by the lexer only.
"""
return "a" <= char <= "z" or "A" <= char <= "Z" or char == "_"
def is_name_continue(char: str) -> bool:
"""Check whether char is allowed in the continuation of a GraphQL name
For internal use by the lexer only.
"""
return (
"a" <= char <= "z"
or "A" <= char <= "Z"
or "0" <= char <= "9"
or char == "_"
)
else:
def is_digit(char: str) -> bool:
"""Check whether char is a digit
For internal use by the lexer only.
"""
return char.isascii() and char.isdigit()
def is_letter(char: str) -> bool:
"""Check whether char is a plain ASCII letter
For internal use by the lexer only.
"""
return char.isascii() and char.isalpha()
def is_name_start(char: str) -> bool:
"""Check whether char is allowed at the beginning of a GraphQL name
For internal use by the lexer only.
"""
return char.isascii() and (char.isalpha() or char == "_")
def is_name_continue(char: str) -> bool:
"""Check whether char is allowed in the continuation of a GraphQL name
For internal use by the lexer only.
"""
return char.isascii() and (char.isalnum() or char == "_")
@@ -0,0 +1,30 @@
from enum import Enum
__all__ = ["DirectiveLocation"]
class DirectiveLocation(Enum):
"""The enum type representing the directive location values."""
# Request Definitions
QUERY = "query"
MUTATION = "mutation"
SUBSCRIPTION = "subscription"
FIELD = "field"
FRAGMENT_DEFINITION = "fragment definition"
FRAGMENT_SPREAD = "fragment spread"
VARIABLE_DEFINITION = "variable definition"
INLINE_FRAGMENT = "inline fragment"
# Type System Definitions
SCHEMA = "schema"
SCALAR = "scalar"
OBJECT = "object"
FIELD_DEFINITION = "field definition"
ARGUMENT_DEFINITION = "argument definition"
INTERFACE = "interface"
UNION = "union"
ENUM = "enum"
ENUM_VALUE = "enum value"
INPUT_OBJECT = "input object"
INPUT_FIELD_DEFINITION = "input field definition"
@@ -0,0 +1,574 @@
from typing import List, NamedTuple, Optional
from ..error import GraphQLSyntaxError
from .ast import Token
from .block_string import dedent_block_string_lines
from .character_classes import is_digit, is_name_start, is_name_continue
from .source import Source
from .token_kind import TokenKind
__all__ = ["Lexer", "is_punctuator_token_kind"]
class EscapeSequence(NamedTuple):
"""The string value and lexed size of an escape sequence."""
value: str
size: int
class Lexer:
"""GraphQL Lexer
A Lexer is a stateful stream generator in that every time it is advanced, it returns
the next token in the Source. Assuming the source lexes, the final Token emitted by
the lexer will be of kind EOF, after which the lexer will repeatedly return the same
EOF token whenever called.
"""
def __init__(self, source: Source):
"""Given a Source object, initialize a Lexer for that source."""
self.source = source
self.token = self.last_token = Token(TokenKind.SOF, 0, 0, 0, 0)
self.line, self.line_start = 1, 0
def advance(self) -> Token:
"""Advance the token stream to the next non-ignored token."""
self.last_token = self.token
token = self.token = self.lookahead()
return token
def lookahead(self) -> Token:
"""Look ahead and return the next non-ignored token, but do not change state."""
token = self.token
if token.kind != TokenKind.EOF:
while True:
if token.next:
token = token.next
else:
# Read the next token and form a link in the token linked-list.
next_token = self.read_next_token(token.end)
token.next = next_token
next_token.prev = token
token = next_token
if token.kind != TokenKind.COMMENT:
break
return token
def print_code_point_at(self, location: int) -> str:
"""Print the code point at the given location.
Prints the code point (or end of file reference) at a given location in a
source for use in error messages.
Printable ASCII is printed quoted, while other points are printed in Unicode
code point form (ie. U+1234).
"""
body = self.source.body
if location >= len(body):
return TokenKind.EOF.value
char = body[location]
# Printable ASCII
if "\x20" <= char <= "\x7E":
return "'\"'" if char == '"' else f"'{char}'"
# Unicode code point
point = ord(
body[location : location + 2]
.encode("utf-16", "surrogatepass")
.decode("utf-16")
if is_supplementary_code_point(body, location)
else char
)
return f"U+{point:04X}"
def create_token(
self, kind: TokenKind, start: int, end: int, value: Optional[str] = None
) -> Token:
"""Create a token with line and column location information."""
line = self.line
col = 1 + start - self.line_start
return Token(kind, start, end, line, col, value)
def read_next_token(self, start: int) -> Token:
"""Get the next token from the source starting at the given position.
This skips over whitespace until it finds the next lexable token, then lexes
punctuators immediately or calls the appropriate helper function for more
complicated tokens.
"""
body = self.source.body
body_length = len(body)
position = start
while position < body_length:
char = body[position] # SourceCharacter
if char in " \t,\ufeff":
position += 1
continue
elif char == "\n":
position += 1
self.line += 1
self.line_start = position
continue
elif char == "\r":
if body[position + 1 : position + 2] == "\n":
position += 2
else:
position += 1
self.line += 1
self.line_start = position
continue
if char == "#":
return self.read_comment(position)
if char == '"':
if body[position + 1 : position + 3] == '""':
return self.read_block_string(position)
return self.read_string(position)
kind = _KIND_FOR_PUNCT.get(char)
if kind:
return self.create_token(kind, position, position + 1)
if is_digit(char) or char == "-":
return self.read_number(position, char)
if is_name_start(char):
return self.read_name(position)
if char == ".":
if body[position + 1 : position + 3] == "..":
return self.create_token(TokenKind.SPREAD, position, position + 3)
message = (
"Unexpected single quote character ('),"
' did you mean to use a double quote (")?'
if char == "'"
else (
f"Unexpected character: {self.print_code_point_at(position)}."
if is_unicode_scalar_value(char)
or is_supplementary_code_point(body, position)
else f"Invalid character: {self.print_code_point_at(position)}."
)
)
raise GraphQLSyntaxError(self.source, position, message)
return self.create_token(TokenKind.EOF, body_length, body_length)
def read_comment(self, start: int) -> Token:
"""Read a comment token from the source file."""
body = self.source.body
body_length = len(body)
position = start + 1
while position < body_length:
char = body[position]
if char in "\r\n":
break
if is_unicode_scalar_value(char):
position += 1
elif is_supplementary_code_point(body, position):
position += 2
else:
break # pragma: no cover
return self.create_token(
TokenKind.COMMENT,
start,
position,
body[start + 1 : position],
)
def read_number(self, start: int, first_char: str) -> Token:
"""Reads a number token from the source file.
This can be either a FloatValue or an IntValue,
depending on whether a FractionalPart or ExponentPart is encountered.
"""
body = self.source.body
position = start
char = first_char
is_float = False
if char == "-":
position += 1
char = body[position : position + 1]
if char == "0":
position += 1
char = body[position : position + 1]
if is_digit(char):
raise GraphQLSyntaxError(
self.source,
position,
"Invalid number, unexpected digit after 0:"
f" {self.print_code_point_at(position)}.",
)
else:
position = self.read_digits(position, char)
char = body[position : position + 1]
if char == ".":
is_float = True
position += 1
char = body[position : position + 1]
position = self.read_digits(position, char)
char = body[position : position + 1]
if char and char in "Ee":
is_float = True
position += 1
char = body[position : position + 1]
if char and char in "+-":
position += 1
char = body[position : position + 1]
position = self.read_digits(position, char)
char = body[position : position + 1]
# Numbers cannot be followed by . or NameStart
if char and (char == "." or is_name_start(char)):
raise GraphQLSyntaxError(
self.source,
position,
"Invalid number, expected digit but got:"
f" {self.print_code_point_at(position)}.",
)
return self.create_token(
TokenKind.FLOAT if is_float else TokenKind.INT,
start,
position,
body[start:position],
)
def read_digits(self, start: int, first_char: str) -> int:
"""Return the new position in the source after reading one or more digits."""
if not is_digit(first_char):
raise GraphQLSyntaxError(
self.source,
start,
"Invalid number, expected digit but got:"
f" {self.print_code_point_at(start)}.",
)
body = self.source.body
body_length = len(body)
position = start + 1
while position < body_length and is_digit(body[position]):
position += 1
return position
def read_string(self, start: int) -> Token:
"""Read a single-quote string token from the source file."""
body = self.source.body
body_length = len(body)
position = start + 1
chunk_start = position
value: List[str] = []
append = value.append
while position < body_length:
char = body[position]
if char == '"':
append(body[chunk_start:position])
return self.create_token(
TokenKind.STRING,
start,
position + 1,
"".join(value),
)
if char == "\\":
append(body[chunk_start:position])
escape = (
(
self.read_escaped_unicode_variable_width(position)
if body[position + 2 : position + 3] == "{"
else self.read_escaped_unicode_fixed_width(position)
)
if body[position + 1 : position + 2] == "u"
else self.read_escaped_character(position)
)
append(escape.value)
position += escape.size
chunk_start = position
continue
if char in "\r\n":
break
if is_unicode_scalar_value(char):
position += 1
elif is_supplementary_code_point(body, position):
position += 2
else:
raise GraphQLSyntaxError(
self.source,
position,
"Invalid character within String:"
f" {self.print_code_point_at(position)}.",
)
raise GraphQLSyntaxError(self.source, position, "Unterminated string.")
def read_escaped_unicode_variable_width(self, position: int) -> EscapeSequence:
body = self.source.body
point = 0
size = 3
max_size = min(12, len(body) - position)
# Cannot be larger than 12 chars (\u{00000000}).
while size < max_size:
char = body[position + size]
size += 1
if char == "}":
# Must be at least 5 chars (\u{0}) and encode a Unicode scalar value.
if size < 5 or not (
0 <= point <= 0xD7FF or 0xE000 <= point <= 0x10FFFF
):
break
return EscapeSequence(chr(point), size)
# Append this hex digit to the code point.
point = (point << 4) | read_hex_digit(char)
if point < 0:
break
raise GraphQLSyntaxError(
self.source,
position,
f"Invalid Unicode escape sequence: '{body[position: position + size]}'.",
)
def read_escaped_unicode_fixed_width(self, position: int) -> EscapeSequence:
body = self.source.body
code = read_16_bit_hex_code(body, position + 2)
if 0 <= code <= 0xD7FF or 0xE000 <= code <= 0x10FFFF:
return EscapeSequence(chr(code), 6)
# GraphQL allows JSON-style surrogate pair escape sequences, but only when
# a valid pair is formed.
if 0xD800 <= code <= 0xDBFF:
if body[position + 6 : position + 8] == "\\u":
trailing_code = read_16_bit_hex_code(body, position + 8)
if 0xDC00 <= trailing_code <= 0xDFFF:
return EscapeSequence(
(chr(code) + chr(trailing_code))
.encode("utf-16", "surrogatepass")
.decode("utf-16"),
12,
)
raise GraphQLSyntaxError(
self.source,
position,
f"Invalid Unicode escape sequence: '{body[position: position + 6]}'.",
)
def read_escaped_character(self, position: int) -> EscapeSequence:
body = self.source.body
value = _ESCAPED_CHARS.get(body[position + 1])
if value:
return EscapeSequence(value, 2)
raise GraphQLSyntaxError(
self.source,
position,
f"Invalid character escape sequence: '{body[position: position + 2]}'.",
)
def read_block_string(self, start: int) -> Token:
"""Read a block string token from the source file."""
body = self.source.body
body_length = len(body)
line_start = self.line_start
position = start + 3
chunk_start = position
current_line = ""
block_lines = []
while position < body_length:
char = body[position]
if char == '"' and body[position + 1 : position + 3] == '""':
current_line += body[chunk_start:position]
block_lines.append(current_line)
token = self.create_token(
TokenKind.BLOCK_STRING,
start,
position + 3,
# return a string of the lines joined with new lines
"\n".join(dedent_block_string_lines(block_lines)),
)
self.line += len(block_lines) - 1
self.line_start = line_start
return token
if char == "\\" and body[position + 1 : position + 4] == '"""':
current_line += body[chunk_start:position]
chunk_start = position + 1 # skip only slash
position += 4
continue
if char in "\r\n":
current_line += body[chunk_start:position]
block_lines.append(current_line)
if char == "\r" and body[position + 1 : position + 2] == "\n":
position += 2
else:
position += 1
current_line = ""
chunk_start = line_start = position
continue
if is_unicode_scalar_value(char):
position += 1
elif is_supplementary_code_point(body, position):
position += 2
else:
raise GraphQLSyntaxError(
self.source,
position,
"Invalid character within String:"
f" {self.print_code_point_at(position)}.",
)
raise GraphQLSyntaxError(self.source, position, "Unterminated string.")
def read_name(self, start: int) -> Token:
"""Read an alphanumeric + underscore name from the source."""
body = self.source.body
body_length = len(body)
position = start + 1
while position < body_length:
char = body[position]
if not is_name_continue(char):
break
position += 1
return self.create_token(TokenKind.NAME, start, position, body[start:position])
_punctuator_token_kinds = frozenset(
[
TokenKind.BANG,
TokenKind.DOLLAR,
TokenKind.AMP,
TokenKind.PAREN_L,
TokenKind.PAREN_R,
TokenKind.SPREAD,
TokenKind.COLON,
TokenKind.EQUALS,
TokenKind.AT,
TokenKind.BRACKET_L,
TokenKind.BRACKET_R,
TokenKind.BRACE_L,
TokenKind.PIPE,
TokenKind.BRACE_R,
]
)
def is_punctuator_token_kind(kind: TokenKind) -> bool:
"""Check whether the given token kind corresponds to a punctuator.
For internal use only.
"""
return kind in _punctuator_token_kinds
_KIND_FOR_PUNCT = {
"!": TokenKind.BANG,
"$": TokenKind.DOLLAR,
"&": TokenKind.AMP,
"(": TokenKind.PAREN_L,
")": TokenKind.PAREN_R,
":": TokenKind.COLON,
"=": TokenKind.EQUALS,
"@": TokenKind.AT,
"[": TokenKind.BRACKET_L,
"]": TokenKind.BRACKET_R,
"{": TokenKind.BRACE_L,
"}": TokenKind.BRACE_R,
"|": TokenKind.PIPE,
}
_ESCAPED_CHARS = {
'"': '"',
"/": "/",
"\\": "\\",
"b": "\b",
"f": "\f",
"n": "\n",
"r": "\r",
"t": "\t",
}
def read_16_bit_hex_code(body: str, position: int) -> int:
"""Read a 16bit hexadecimal string and return its positive integer value (0-65535).
Reads four hexadecimal characters and returns the positive integer that 16bit
hexadecimal string represents. For example, "000f" will return 15, and "dead"
will return 57005.
Returns a negative number if any char was not a valid hexadecimal digit.
"""
# read_hex_digit() returns -1 on error. ORing a negative value with any other
# value always produces a negative value.
return (
read_hex_digit(body[position]) << 12
| read_hex_digit(body[position + 1]) << 8
| read_hex_digit(body[position + 2]) << 4
| read_hex_digit(body[position + 3])
)
def read_hex_digit(char: str) -> int:
"""Read a hexadecimal character and returns its positive integer value (0-15).
'0' becomes 0, '9' becomes 9
'A' becomes 10, 'F' becomes 15
'a' becomes 10, 'f' becomes 15
Returns -1 if the provided character code was not a valid hexadecimal digit.
"""
if "0" <= char <= "9":
return ord(char) - 48
elif "A" <= char <= "F":
return ord(char) - 55
elif "a" <= char <= "f":
return ord(char) - 87
return -1
def is_unicode_scalar_value(char: str) -> bool:
"""Check whether this is a Unicode scalar value.
A Unicode scalar value is any Unicode code point except surrogate code
points. In other words, the inclusive ranges of values 0x0000 to 0xD7FF and
0xE000 to 0x10FFFF.
"""
return "\x00" <= char <= "\ud7ff" or "\ue000" <= char <= "\U0010ffff"
def is_supplementary_code_point(body: str, location: int) -> bool:
"""
Check whether the current location is a supplementary code point.
The GraphQL specification defines source text as a sequence of unicode scalar
values (which Unicode defines to exclude surrogate code points).
"""
try:
return (
"\ud800" <= body[location] <= "\udbff"
and "\udc00" <= body[location + 1] <= "\udfff"
)
except IndexError:
return False
@@ -0,0 +1,46 @@
from typing import Any, NamedTuple, TYPE_CHECKING
try:
from typing import TypedDict
except ImportError: # Python < 3.8
from typing_extensions import TypedDict
if TYPE_CHECKING:
from .source import Source # noqa: F401
__all__ = ["get_location", "SourceLocation", "FormattedSourceLocation"]
class FormattedSourceLocation(TypedDict):
"""Formatted source location"""
line: int
column: int
class SourceLocation(NamedTuple):
"""Represents a location in a Source."""
line: int
column: int
@property
def formatted(self) -> FormattedSourceLocation:
return dict(line=self.line, column=self.column)
def __eq__(self, other: Any) -> bool:
if isinstance(other, dict):
return self.formatted == other
return tuple(self) == other
def __ne__(self, other: Any) -> bool:
return not self == other
def get_location(source: "Source", position: int) -> SourceLocation:
"""Get the line and column for a character position in the source.
Takes a Source and a UTF-8 character offset, and returns the corresponding line and
column as a SourceLocation.
"""
return source.get_location(position)
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,86 @@
from .ast import (
Node,
DefinitionNode,
ExecutableDefinitionNode,
ListValueNode,
ObjectValueNode,
SchemaExtensionNode,
SelectionNode,
TypeDefinitionNode,
TypeExtensionNode,
TypeNode,
TypeSystemDefinitionNode,
ValueNode,
VariableNode,
)
__all__ = [
"is_definition_node",
"is_executable_definition_node",
"is_selection_node",
"is_value_node",
"is_const_value_node",
"is_type_node",
"is_type_system_definition_node",
"is_type_definition_node",
"is_type_system_extension_node",
"is_type_extension_node",
]
def is_definition_node(node: Node) -> bool:
"""Check whether the given node represents a definition."""
return isinstance(node, DefinitionNode)
def is_executable_definition_node(node: Node) -> bool:
"""Check whether the given node represents an executable definition."""
return isinstance(node, ExecutableDefinitionNode)
def is_selection_node(node: Node) -> bool:
"""Check whether the given node represents a selection."""
return isinstance(node, SelectionNode)
def is_value_node(node: Node) -> bool:
"""Check whether the given node represents a value."""
return isinstance(node, ValueNode)
def is_const_value_node(node: Node) -> bool:
"""Check whether the given node represents a constant value."""
return is_value_node(node) and (
any(is_const_value_node(value) for value in node.values)
if isinstance(node, ListValueNode)
else (
any(is_const_value_node(field.value) for field in node.fields)
if isinstance(node, ObjectValueNode)
else not isinstance(node, VariableNode)
)
)
def is_type_node(node: Node) -> bool:
"""Check whether the given node represents a type."""
return isinstance(node, TypeNode)
def is_type_system_definition_node(node: Node) -> bool:
"""Check whether the given node represents a type system definition."""
return isinstance(node, TypeSystemDefinitionNode)
def is_type_definition_node(node: Node) -> bool:
"""Check whether the given node represents a type definition."""
return isinstance(node, TypeDefinitionNode)
def is_type_system_extension_node(node: Node) -> bool:
"""Check whether the given node represents a type system extension."""
return isinstance(node, (SchemaExtensionNode, TypeExtensionNode))
def is_type_extension_node(node: Node) -> bool:
"""Check whether the given node represents a type extension."""
return isinstance(node, TypeExtensionNode)
@@ -0,0 +1,79 @@
import re
from typing import Optional, Tuple, cast
from .ast import Location
from .location import SourceLocation, get_location
from .source import Source
__all__ = ["print_location", "print_source_location"]
def print_location(location: Location) -> str:
"""Render a helpful description of the location in the GraphQL Source document."""
return print_source_location(
location.source, get_location(location.source, location.start)
)
_re_newline = re.compile(r"\r\n|[\n\r]")
def print_source_location(source: Source, source_location: SourceLocation) -> str:
"""Render a helpful description of the location in the GraphQL Source document."""
first_line_column_offset = source.location_offset.column - 1
body = "".rjust(first_line_column_offset) + source.body
line_index = source_location.line - 1
line_offset = source.location_offset.line - 1
line_num = source_location.line + line_offset
column_offset = first_line_column_offset if source_location.line == 1 else 0
column_num = source_location.column + column_offset
location_str = f"{source.name}:{line_num}:{column_num}\n"
lines = _re_newline.split(body) # works a bit different from splitlines()
location_line = lines[line_index]
# Special case for minified documents
if len(location_line) > 120:
sub_line_index, sub_line_column_num = divmod(column_num, 80)
sub_lines = [
location_line[i : i + 80] for i in range(0, len(location_line), 80)
]
return location_str + print_prefixed_lines(
(f"{line_num} |", sub_lines[0]),
*[("|", sub_line) for sub_line in sub_lines[1 : sub_line_index + 1]],
("|", "^".rjust(sub_line_column_num)),
(
"|",
(
sub_lines[sub_line_index + 1]
if sub_line_index < len(sub_lines) - 1
else None
),
),
)
return location_str + print_prefixed_lines(
(f"{line_num - 1} |", lines[line_index - 1] if line_index > 0 else None),
(f"{line_num} |", location_line),
("|", "^".rjust(column_num)),
(
f"{line_num + 1} |",
lines[line_index + 1] if line_index < len(lines) - 1 else None,
),
)
def print_prefixed_lines(*lines: Tuple[str, Optional[str]]) -> str:
"""Print lines specified like this: ("prefix", "string")"""
existing_lines = [
cast(Tuple[str, str], line) for line in lines if line[1] is not None
]
pad_len = max(len(line[0]) for line in existing_lines)
return "\n".join(
prefix.rjust(pad_len) + (" " + line if line else "")
for prefix, line in existing_lines
)
@@ -0,0 +1,83 @@
__all__ = ["print_string"]
def print_string(s: str) -> str:
"""Print a string as a GraphQL StringValue literal.
Replaces control characters and excluded characters (" U+0022 and \\ U+005C)
with escape sequences.
"""
if not isinstance(s, str):
s = str(s)
return f'"{s.translate(escape_sequences)}"'
escape_sequences = {
0x00: "\\u0000",
0x01: "\\u0001",
0x02: "\\u0002",
0x03: "\\u0003",
0x04: "\\u0004",
0x05: "\\u0005",
0x06: "\\u0006",
0x07: "\\u0007",
0x08: "\\b",
0x09: "\\t",
0x0A: "\\n",
0x0B: "\\u000B",
0x0C: "\\f",
0x0D: "\\r",
0x0E: "\\u000E",
0x0F: "\\u000F",
0x10: "\\u0010",
0x11: "\\u0011",
0x12: "\\u0012",
0x13: "\\u0013",
0x14: "\\u0014",
0x15: "\\u0015",
0x16: "\\u0016",
0x17: "\\u0017",
0x18: "\\u0018",
0x19: "\\u0019",
0x1A: "\\u001A",
0x1B: "\\u001B",
0x1C: "\\u001C",
0x1D: "\\u001D",
0x1E: "\\u001E",
0x1F: "\\u001F",
0x22: '\\"',
0x5C: "\\\\",
0x7F: "\\u007F",
0x80: "\\u0080",
0x81: "\\u0081",
0x82: "\\u0082",
0x83: "\\u0083",
0x84: "\\u0084",
0x85: "\\u0085",
0x86: "\\u0086",
0x87: "\\u0087",
0x88: "\\u0088",
0x89: "\\u0089",
0x8A: "\\u008A",
0x8B: "\\u008B",
0x8C: "\\u008C",
0x8D: "\\u008D",
0x8E: "\\u008E",
0x8F: "\\u008F",
0x90: "\\u0090",
0x91: "\\u0091",
0x92: "\\u0092",
0x93: "\\u0093",
0x94: "\\u0094",
0x95: "\\u0095",
0x96: "\\u0096",
0x97: "\\u0097",
0x98: "\\u0098",
0x99: "\\u0099",
0x9A: "\\u009A",
0x9B: "\\u009B",
0x9C: "\\u009C",
0x9D: "\\u009D",
0x9E: "\\u009E",
0x9F: "\\u009F",
}
@@ -0,0 +1,428 @@
from typing import Any, Collection, Optional
from ..language.ast import Node, OperationType
from .block_string import print_block_string
from .print_string import print_string
from .visitor import visit, Visitor
__all__ = ["print_ast"]
MAX_LINE_LENGTH = 80
Strings = Collection[str]
class PrintedNode:
"""A union type for all nodes that have been processed by the printer."""
alias: str
arguments: Strings
block: bool
default_value: str
definitions: Strings
description: str
directives: str
fields: Strings
interfaces: Strings
locations: Strings
name: str
operation: OperationType
operation_types: Strings
repeatable: bool
selection_set: str
selections: Strings
type: str
type_condition: str
types: Strings
value: str
values: Strings
variable: str
variable_definitions: Strings
def print_ast(ast: Node) -> str:
"""Convert an AST into a string.
The conversion is done using a set of reasonable formatting rules.
"""
return visit(ast, PrintAstVisitor())
class PrintAstVisitor(Visitor):
@staticmethod
def leave_name(node: PrintedNode, *_args: Any) -> str:
return node.value
@staticmethod
def leave_variable(node: PrintedNode, *_args: Any) -> str:
return f"${node.name}"
# Document
@staticmethod
def leave_document(node: PrintedNode, *_args: Any) -> str:
return join(node.definitions, "\n\n")
@staticmethod
def leave_operation_definition(node: PrintedNode, *_args: Any) -> str:
var_defs = wrap("(", join(node.variable_definitions, ", "), ")")
prefix = join(
(
node.operation.value,
join((node.name, var_defs)),
join(node.directives, " "),
),
" ",
)
# Anonymous queries with no directives or variable definitions can use the
# query short form.
return ("" if prefix == "query" else prefix + " ") + node.selection_set
@staticmethod
def leave_variable_definition(node: PrintedNode, *_args: Any) -> str:
return (
f"{node.variable}: {node.type}"
f"{wrap(' = ', node.default_value)}"
f"{wrap(' ', join(node.directives, ' '))}"
)
@staticmethod
def leave_selection_set(node: PrintedNode, *_args: Any) -> str:
return block(node.selections)
@staticmethod
def leave_field(node: PrintedNode, *_args: Any) -> str:
prefix = wrap("", node.alias, ": ") + node.name
args_line = prefix + wrap("(", join(node.arguments, ", "), ")")
if len(args_line) > MAX_LINE_LENGTH:
args_line = prefix + wrap("(\n", indent(join(node.arguments, "\n")), "\n)")
return join((args_line, join(node.directives, " "), node.selection_set), " ")
@staticmethod
def leave_argument(node: PrintedNode, *_args: Any) -> str:
return f"{node.name}: {node.value}"
# Fragments
@staticmethod
def leave_fragment_spread(node: PrintedNode, *_args: Any) -> str:
return f"...{node.name}{wrap(' ', join(node.directives, ' '))}"
@staticmethod
def leave_inline_fragment(node: PrintedNode, *_args: Any) -> str:
return join(
(
"...",
wrap("on ", node.type_condition),
join(node.directives, " "),
node.selection_set,
),
" ",
)
@staticmethod
def leave_fragment_definition(node: PrintedNode, *_args: Any) -> str:
# Note: fragment variable definitions are deprecated and will be removed in v3.3
return (
f"fragment {node.name}"
f"{wrap('(', join(node.variable_definitions, ', '), ')')}"
f" on {node.type_condition}"
f" {wrap('', join(node.directives, ' '), ' ')}"
f"{node.selection_set}"
)
# Value
@staticmethod
def leave_int_value(node: PrintedNode, *_args: Any) -> str:
return node.value
@staticmethod
def leave_float_value(node: PrintedNode, *_args: Any) -> str:
return node.value
@staticmethod
def leave_string_value(node: PrintedNode, *_args: Any) -> str:
if node.block:
return print_block_string(node.value)
return print_string(node.value)
@staticmethod
def leave_boolean_value(node: PrintedNode, *_args: Any) -> str:
return "true" if node.value else "false"
@staticmethod
def leave_null_value(_node: PrintedNode, *_args: Any) -> str:
return "null"
@staticmethod
def leave_enum_value(node: PrintedNode, *_args: Any) -> str:
return node.value
@staticmethod
def leave_list_value(node: PrintedNode, *_args: Any) -> str:
return f"[{join(node.values, ', ')}]"
@staticmethod
def leave_object_value(node: PrintedNode, *_args: Any) -> str:
return f"{{{join(node.fields, ', ')}}}"
@staticmethod
def leave_object_field(node: PrintedNode, *_args: Any) -> str:
return f"{node.name}: {node.value}"
# Directive
@staticmethod
def leave_directive(node: PrintedNode, *_args: Any) -> str:
return f"@{node.name}{wrap('(', join(node.arguments, ', '), ')')}"
# Type
@staticmethod
def leave_named_type(node: PrintedNode, *_args: Any) -> str:
return node.name
@staticmethod
def leave_list_type(node: PrintedNode, *_args: Any) -> str:
return f"[{node.type}]"
@staticmethod
def leave_non_null_type(node: PrintedNode, *_args: Any) -> str:
return f"{node.type}!"
# Type System Definitions
@staticmethod
def leave_schema_definition(node: PrintedNode, *_args: Any) -> str:
return wrap("", node.description, "\n") + join(
(
"schema",
join(node.directives, " "),
block(node.operation_types),
),
" ",
)
@staticmethod
def leave_operation_type_definition(node: PrintedNode, *_args: Any) -> str:
return f"{node.operation.value}: {node.type}"
@staticmethod
def leave_scalar_type_definition(node: PrintedNode, *_args: Any) -> str:
return wrap("", node.description, "\n") + join(
(
"scalar",
node.name,
join(node.directives, " "),
),
" ",
)
@staticmethod
def leave_object_type_definition(node: PrintedNode, *_args: Any) -> str:
return wrap("", node.description, "\n") + join(
(
"type",
node.name,
wrap("implements ", join(node.interfaces, " & ")),
join(node.directives, " "),
block(node.fields),
),
" ",
)
@staticmethod
def leave_field_definition(node: PrintedNode, *_args: Any) -> str:
args = node.arguments
args = (
wrap("(\n", indent(join(args, "\n")), "\n)")
if has_multiline_items(args)
else wrap("(", join(args, ", "), ")")
)
directives = wrap(" ", join(node.directives, " "))
return (
wrap("", node.description, "\n")
+ f"{node.name}{args}: {node.type}{directives}"
)
@staticmethod
def leave_input_value_definition(node: PrintedNode, *_args: Any) -> str:
return wrap("", node.description, "\n") + join(
(
f"{node.name}: {node.type}",
wrap("= ", node.default_value),
join(node.directives, " "),
),
" ",
)
@staticmethod
def leave_interface_type_definition(node: PrintedNode, *_args: Any) -> str:
return wrap("", node.description, "\n") + join(
(
"interface",
node.name,
wrap("implements ", join(node.interfaces, " & ")),
join(node.directives, " "),
block(node.fields),
),
" ",
)
@staticmethod
def leave_union_type_definition(node: PrintedNode, *_args: Any) -> str:
return wrap("", node.description, "\n") + join(
(
"union",
node.name,
join(node.directives, " "),
wrap("= ", join(node.types, " | ")),
),
" ",
)
@staticmethod
def leave_enum_type_definition(node: PrintedNode, *_args: Any) -> str:
return wrap("", node.description, "\n") + join(
("enum", node.name, join(node.directives, " "), block(node.values)), " "
)
@staticmethod
def leave_enum_value_definition(node: PrintedNode, *_args: Any) -> str:
return wrap("", node.description, "\n") + join(
(node.name, join(node.directives, " ")), " "
)
@staticmethod
def leave_input_object_type_definition(node: PrintedNode, *_args: Any) -> str:
return wrap("", node.description, "\n") + join(
("input", node.name, join(node.directives, " "), block(node.fields)), " "
)
@staticmethod
def leave_directive_definition(node: PrintedNode, *_args: Any) -> str:
args = node.arguments
args = (
wrap("(\n", indent(join(args, "\n")), "\n)")
if has_multiline_items(args)
else wrap("(", join(args, ", "), ")")
)
repeatable = " repeatable" if node.repeatable else ""
locations = join(node.locations, " | ")
return (
wrap("", node.description, "\n")
+ f"directive @{node.name}{args}{repeatable} on {locations}"
)
@staticmethod
def leave_schema_extension(node: PrintedNode, *_args: Any) -> str:
return join(
("extend schema", join(node.directives, " "), block(node.operation_types)),
" ",
)
@staticmethod
def leave_scalar_type_extension(node: PrintedNode, *_args: Any) -> str:
return join(("extend scalar", node.name, join(node.directives, " ")), " ")
@staticmethod
def leave_object_type_extension(node: PrintedNode, *_args: Any) -> str:
return join(
(
"extend type",
node.name,
wrap("implements ", join(node.interfaces, " & ")),
join(node.directives, " "),
block(node.fields),
),
" ",
)
@staticmethod
def leave_interface_type_extension(node: PrintedNode, *_args: Any) -> str:
return join(
(
"extend interface",
node.name,
wrap("implements ", join(node.interfaces, " & ")),
join(node.directives, " "),
block(node.fields),
),
" ",
)
@staticmethod
def leave_union_type_extension(node: PrintedNode, *_args: Any) -> str:
return join(
(
"extend union",
node.name,
join(node.directives, " "),
wrap("= ", join(node.types, " | ")),
),
" ",
)
@staticmethod
def leave_enum_type_extension(node: PrintedNode, *_args: Any) -> str:
return join(
("extend enum", node.name, join(node.directives, " "), block(node.values)),
" ",
)
@staticmethod
def leave_input_object_type_extension(node: PrintedNode, *_args: Any) -> str:
return join(
("extend input", node.name, join(node.directives, " "), block(node.fields)),
" ",
)
def join(strings: Optional[Strings], separator: str = "") -> str:
"""Join strings in a given collection.
Return an empty string if it is None or empty, otherwise join all items together
separated by separator if provided.
"""
return separator.join(s for s in strings if s) if strings else ""
def block(strings: Optional[Strings]) -> str:
"""Return strings inside a block.
Given a collection of strings, return a string with each item on its own line,
wrapped in an indented "{ }" block.
"""
return wrap("{\n", indent(join(strings, "\n")), "\n}")
def wrap(start: str, string: Optional[str], end: str = "") -> str:
"""Wrap string inside other strings at start and end.
If the string is not None or empty, then wrap with start and end, otherwise return
an empty string.
"""
return f"{start}{string}{end}" if string else ""
def indent(string: str) -> str:
"""Indent string with two spaces.
If the string is not None or empty, add two spaces at the beginning of every line
inside the string.
"""
return wrap(" ", string.replace("\n", "\n "))
def is_multiline(string: str) -> bool:
"""Check whether a string consists of multiple lines."""
return "\n" in string
def has_multiline_items(strings: Optional[Strings]) -> bool:
"""Check whether one of the items in the list has multiple lines."""
return any(is_multiline(item) for item in strings) if strings else False
@@ -0,0 +1,70 @@
from typing import Any
from .location import SourceLocation
__all__ = ["Source", "is_source"]
class Source:
"""A representation of source input to GraphQL."""
# allow custom attributes and weak references (not used internally)
__slots__ = "__weakref__", "__dict__", "body", "name", "location_offset"
def __init__(
self,
body: str,
name: str = "GraphQL request",
location_offset: SourceLocation = SourceLocation(1, 1),
) -> None:
"""Initialize source input.
The ``name`` and ``location_offset`` parameters are optional, but they are
useful for clients who store GraphQL documents in source files. For example,
if the GraphQL input starts at line 40 in a file named ``Foo.graphql``, it might
be useful for ``name`` to be ``"Foo.graphql"`` and location to be ``(40, 0)``.
The ``line`` and ``column`` attributes in ``location_offset`` are 1-indexed.
"""
self.body = body
self.name = name
if not isinstance(location_offset, SourceLocation):
location_offset = SourceLocation._make(location_offset)
if location_offset.line <= 0:
raise ValueError(
"line in location_offset is 1-indexed and must be positive."
)
if location_offset.column <= 0:
raise ValueError(
"column in location_offset is 1-indexed and must be positive."
)
self.location_offset = location_offset
def get_location(self, position: int) -> SourceLocation:
lines = self.body[:position].splitlines()
if lines:
line = len(lines)
column = len(lines[-1]) + 1
else:
line = 1
column = 1
return SourceLocation(line, column)
def __repr__(self) -> str:
return f"<{self.__class__.__name__} name={self.name!r}>"
def __eq__(self, other: Any) -> bool:
return (isinstance(other, Source) and other.body == self.body) or (
isinstance(other, str) and other == self.body
)
def __ne__(self, other: Any) -> bool:
return not self == other
def is_source(source: Any) -> bool:
"""Test if the given value is a Source object.
For internal use only.
"""
return isinstance(source, Source)
@@ -0,0 +1,30 @@
from enum import Enum
__all__ = ["TokenKind"]
class TokenKind(Enum):
"""The different kinds of tokens that the lexer emits"""
SOF = "<SOF>"
EOF = "<EOF>"
BANG = "!"
DOLLAR = "$"
AMP = "&"
PAREN_L = "("
PAREN_R = ")"
SPREAD = "..."
COLON = ":"
EQUALS = "="
AT = "@"
BRACKET_L = "["
BRACKET_R = "]"
BRACE_L = "{"
PIPE = "|"
BRACE_R = "}"
NAME = "Name"
INT = "Int"
FLOAT = "Float"
STRING = "String"
BLOCK_STRING = "BlockString"
COMMENT = "Comment"
@@ -0,0 +1,375 @@
from copy import copy
from enum import Enum
from typing import (
Any,
Callable,
Collection,
Dict,
List,
NamedTuple,
Optional,
Tuple,
Union,
)
from ..pyutils import inspect, snake_to_camel
from . import ast
from .ast import QUERY_DOCUMENT_KEYS, Node
__all__ = [
"Visitor",
"ParallelVisitor",
"VisitorAction",
"visit",
"BREAK",
"SKIP",
"REMOVE",
"IDLE",
]
class VisitorActionEnum(Enum):
"""Special return values for the visitor methods.
You can also use the values of this enum directly.
"""
BREAK = True
SKIP = False
REMOVE = Ellipsis
VisitorAction = Optional[VisitorActionEnum]
# Note that in GraphQL.js these are defined differently:
# BREAK = {}, SKIP = false, REMOVE = null, IDLE = undefined
BREAK = VisitorActionEnum.BREAK
SKIP = VisitorActionEnum.SKIP
REMOVE = VisitorActionEnum.REMOVE
IDLE = None
VisitorKeyMap = Dict[str, Tuple[str, ...]]
class EnterLeaveVisitor(NamedTuple):
"""Visitor with functions for entering and leaving."""
enter: Optional[Callable[..., Optional[VisitorAction]]]
leave: Optional[Callable[..., Optional[VisitorAction]]]
class Visitor:
"""Visitor that walks through an AST.
Visitors can define two generic methods "enter" and "leave". The former will be
called when a node is entered in the traversal, the latter is called after visiting
the node and its child nodes. These methods have the following signature::
def enter(self, node, key, parent, path, ancestors):
# The return value has the following meaning:
# IDLE (None): no action
# SKIP: skip visiting this node
# BREAK: stop visiting altogether
# REMOVE: delete this node
# any other value: replace this node with the returned value
return
def leave(self, node, key, parent, path, ancestors):
# The return value has the following meaning:
# IDLE (None) or SKIP: no action
# BREAK: stop visiting altogether
# REMOVE: delete this node
# any other value: replace this node with the returned value
return
The parameters have the following meaning:
:arg node: The current node being visiting.
:arg key: The index or key to this node from the parent node or Array.
:arg parent: the parent immediately above this node, which may be an Array.
:arg path: The key path to get to this node from the root node.
:arg ancestors: All nodes and Arrays visited before reaching parent
of this node. These correspond to array indices in ``path``.
Note: ancestors includes arrays which contain the parent of visited node.
You can also define node kind specific methods by suffixing them with an underscore
followed by the kind of the node to be visited. For instance, to visit ``field``
nodes, you would defined the methods ``enter_field()`` and/or ``leave_field()``,
with the same signature as above. If no kind specific method has been defined
for a given node, the generic method is called.
"""
# Provide special return values as attributes
BREAK, SKIP, REMOVE, IDLE = BREAK, SKIP, REMOVE, IDLE
enter_leave_map: Dict[str, EnterLeaveVisitor]
def __init_subclass__(cls) -> None:
"""Verify that all defined handlers are valid."""
super().__init_subclass__()
for attr, val in cls.__dict__.items():
if attr.startswith("_"):
continue
attr_kind = attr.split("_", 1)
if len(attr_kind) < 2:
kind: Optional[str] = None
else:
attr, kind = attr_kind
if attr in ("enter", "leave") and kind:
name = snake_to_camel(kind) + "Node"
node_cls = getattr(ast, name, None)
if (
not node_cls
or not isinstance(node_cls, type)
or not issubclass(node_cls, Node)
):
raise TypeError(f"Invalid AST node kind: {kind}.")
def __init__(self) -> None:
self.enter_leave_map = {}
def get_enter_leave_for_kind(self, kind: str) -> EnterLeaveVisitor:
"""Given a node kind, return the EnterLeaveVisitor for that kind."""
try:
return self.enter_leave_map[kind]
except KeyError:
enter_fn = getattr(self, f"enter_{kind}", None)
if not enter_fn:
enter_fn = getattr(self, "enter", None)
leave_fn = getattr(self, f"leave_{kind}", None)
if not leave_fn:
leave_fn = getattr(self, "leave", None)
enter_leave = EnterLeaveVisitor(enter_fn, leave_fn)
self.enter_leave_map[kind] = enter_leave
return enter_leave
def get_visit_fn(
self, kind: str, is_leaving: bool = False
) -> Optional[Callable[..., Optional[VisitorAction]]]:
"""Get the visit function for the given node kind and direction.
.. deprecated:: 3.2
Please use ``get_enter_leave_for_kind`` instead. Will be removed in v3.3.
"""
enter_leave = self.get_enter_leave_for_kind(kind)
return enter_leave.leave if is_leaving else enter_leave.enter
class Stack(NamedTuple):
"""A stack for the visit function."""
in_array: bool
idx: int
keys: Tuple[Node, ...]
edits: List[Tuple[Union[int, str], Node]]
prev: Any # 'Stack' (python/mypy/issues/731)
def visit(
root: Node, visitor: Visitor, visitor_keys: Optional[VisitorKeyMap] = None
) -> Any:
"""Visit each node in an AST.
:func:`~.visit` will walk through an AST using a depth-first traversal, calling the
visitor's enter methods at each node in the traversal, and calling the leave methods
after visiting that node and all of its child nodes.
By returning different values from the enter and leave methods, the behavior of the
visitor can be altered, including skipping over a sub-tree of the AST (by returning
False), editing the AST by returning a value or None to remove the value, or to stop
the whole traversal by returning :data:`~.BREAK`.
When using :func:`~.visit` to edit an AST, the original AST will not be modified,
and a new version of the AST with the changes applied will be returned from the
visit function.
To customize the node attributes to be used for traversal, you can provide a
dictionary visitor_keys mapping node kinds to node attributes.
"""
if not isinstance(root, Node):
raise TypeError(f"Not an AST Node: {inspect(root)}.")
if not isinstance(visitor, Visitor):
raise TypeError(f"Not an AST Visitor: {inspect(visitor)}.")
if visitor_keys is None:
visitor_keys = QUERY_DOCUMENT_KEYS
stack: Any = None
in_array = False
keys: Tuple[Node, ...] = (root,)
idx = -1
edits: List[Any] = []
node: Any = root
key: Any = None
parent: Any = None
path: List[Any] = []
path_append = path.append
path_pop = path.pop
ancestors: List[Any] = []
ancestors_append = ancestors.append
ancestors_pop = ancestors.pop
while True:
idx += 1
is_leaving = idx == len(keys)
is_edited = is_leaving and edits
if is_leaving:
key = path[-1] if ancestors else None
node = parent
parent = ancestors_pop() if ancestors else None
if is_edited:
if in_array:
node = list(node)
edit_offset = 0
for edit_key, edit_value in edits:
array_key = edit_key - edit_offset
if edit_value is REMOVE or edit_value is Ellipsis:
node.pop(array_key)
edit_offset += 1
else:
node[array_key] = edit_value
node = tuple(node)
else:
node = copy(node)
for edit_key, edit_value in edits:
setattr(node, edit_key, edit_value)
idx = stack.idx
keys = stack.keys
edits = stack.edits
in_array = stack.in_array
stack = stack.prev
elif parent:
if in_array:
key = idx
node = parent[key]
else:
key = keys[idx]
node = getattr(parent, key, None)
if node is None:
continue
path_append(key)
if isinstance(node, tuple):
result = None
else:
if not isinstance(node, Node):
raise TypeError(f"Invalid AST Node: {inspect(node)}.")
enter_leave = visitor.get_enter_leave_for_kind(node.kind)
visit_fn = enter_leave.leave if is_leaving else enter_leave.enter
if visit_fn:
result = visit_fn(node, key, parent, path, ancestors)
if result is BREAK or result is True:
break
if result is SKIP or result is False:
if not is_leaving:
path_pop()
continue
elif result is not None:
edits.append((key, result))
if not is_leaving:
if isinstance(result, Node):
node = result
else:
path_pop()
continue
else:
result = None
if result is None and is_edited:
edits.append((key, node))
if is_leaving:
if path:
path_pop()
else:
stack = Stack(in_array, idx, keys, edits, stack)
in_array = isinstance(node, tuple)
keys = node if in_array else visitor_keys.get(node.kind, ()) # type: ignore
idx = -1
edits = []
if parent:
ancestors_append(parent)
parent = node
if not stack:
break
if edits:
return edits[-1][1]
return root
class ParallelVisitor(Visitor):
"""A Visitor which delegates to many visitors to run in parallel.
Each visitor will be visited for each node before moving on.
If a prior visitor edits a node, no following visitors will see that node.
"""
def __init__(self, visitors: Collection[Visitor]):
"""Create a new visitor from the given list of parallel visitors."""
super().__init__()
self.visitors = visitors
self.skipping: List[Any] = [None] * len(visitors)
def get_enter_leave_for_kind(self, kind: str) -> EnterLeaveVisitor:
"""Given a node kind, return the EnterLeaveVisitor for that kind."""
try:
return self.enter_leave_map[kind]
except KeyError:
has_visitor = False
enter_list: List[Optional[Callable[..., Optional[VisitorAction]]]] = []
leave_list: List[Optional[Callable[..., Optional[VisitorAction]]]] = []
for visitor in self.visitors:
enter, leave = visitor.get_enter_leave_for_kind(kind)
if not has_visitor and (enter or leave):
has_visitor = True
enter_list.append(enter)
leave_list.append(leave)
if has_visitor:
def enter(node: Node, *args: Any) -> Optional[VisitorAction]:
skipping = self.skipping
for i, fn in enumerate(enter_list):
if not skipping[i]:
if fn:
result = fn(node, *args)
if result is SKIP or result is False:
skipping[i] = node
elif result is BREAK or result is True:
skipping[i] = BREAK
elif result is not None:
return result
return None
def leave(node: Node, *args: Any) -> Optional[VisitorAction]:
skipping = self.skipping
for i, fn in enumerate(leave_list):
if not skipping[i]:
if fn:
result = fn(node, *args)
if result is BREAK or result is True:
skipping[i] = BREAK
elif (
result is not None
and result is not SKIP
and result is not False
):
return result
elif skipping[i] is node:
skipping[i] = None
return None
else:
enter = leave = None
enter_leave = EnterLeaveVisitor(enter, leave)
self.enter_leave_map[kind] = enter_leave
return enter_leave
@@ -0,0 +1 @@
# Marker file for PEP 561. The graphql package uses inline types.
@@ -0,0 +1,65 @@
"""Python Utils
This package contains dependency-free Python utility functions used throughout the
codebase.
Each utility should belong in its own file and be the default export.
These functions are not part of the module interface and are subject to change.
"""
from .convert_case import camel_to_snake, snake_to_camel
from .cached_property import cached_property
from .description import (
Description,
is_description,
register_description,
unregister_description,
)
from .did_you_mean import did_you_mean
from .group_by import group_by
from .identity_func import identity_func
from .inspect import inspect
from .is_awaitable import is_awaitable
from .is_iterable import is_collection, is_iterable
from .natural_compare import natural_comparison_key
from .awaitable_or_value import AwaitableOrValue
from .suggestion_list import suggestion_list
from .frozen_error import FrozenError
from .frozen_list import FrozenList
from .frozen_dict import FrozenDict
from .merge_kwargs import merge_kwargs
from .path import Path
from .print_path_list import print_path_list
from .simple_pub_sub import SimplePubSub, SimplePubSubIterator
from .undefined import Undefined, UndefinedType
__all__ = [
"camel_to_snake",
"snake_to_camel",
"cached_property",
"did_you_mean",
"Description",
"group_by",
"is_description",
"register_description",
"unregister_description",
"identity_func",
"inspect",
"is_awaitable",
"is_collection",
"is_iterable",
"merge_kwargs",
"natural_comparison_key",
"AwaitableOrValue",
"suggestion_list",
"FrozenError",
"FrozenList",
"FrozenDict",
"Path",
"print_path_list",
"SimplePubSub",
"SimplePubSubIterator",
"Undefined",
"UndefinedType",
]
@@ -0,0 +1,8 @@
from typing import Awaitable, TypeVar, Union
__all__ = ["AwaitableOrValue"]
T = TypeVar("T")
AwaitableOrValue = Union[Awaitable[T], T]
@@ -0,0 +1,35 @@
from typing import Any, Callable, TYPE_CHECKING
if TYPE_CHECKING:
standard_cached_property = None
else:
try:
from functools import cached_property as standard_cached_property
except ImportError: # Python < 3.8
standard_cached_property = None
if standard_cached_property:
cached_property = standard_cached_property
else:
# Code taken from https://github.com/bottlepy/bottle
class CachedProperty:
"""A cached property.
A property that is only computed once per instance and then replaces itself with
an ordinary attribute. Deleting the attribute resets the property.
"""
def __init__(self, func: Callable) -> None:
self.__doc__ = getattr(func, "__doc__")
self.func = func
def __get__(self, obj: object, cls: type) -> Any:
if obj is None:
return self
value = obj.__dict__[self.func.__name__] = self.func(obj)
return value
cached_property = CachedProperty
__all__ = ["cached_property"]
@@ -0,0 +1,25 @@
# uses code from https://github.com/daveoncode/python-string-utils
import re
__all__ = ["camel_to_snake", "snake_to_camel"]
_re_camel_to_snake = re.compile(r"([a-z]|[A-Z0-9]+)(?=[A-Z])")
_re_snake_to_camel = re.compile(r"(_)([a-z\d])")
def camel_to_snake(s: str) -> str:
"""Convert from CamelCase to snake_case"""
return _re_camel_to_snake.sub(r"\1_", s).lower()
def snake_to_camel(s: str, upper: bool = True) -> str:
"""Convert from snake_case to CamelCase
If upper is set, then convert to upper CamelCase, otherwise the first character
keeps its case.
"""
s = _re_snake_to_camel.sub(lambda m: m.group(2).upper(), s)
if upper:
s = s[:1].upper() + s[1:]
return s
@@ -0,0 +1,59 @@
from typing import Any, Tuple, Union
__all__ = [
"Description",
"is_description",
"register_description",
"unregister_description",
]
class Description:
"""Type checker for human readable descriptions.
By default, only ordinary strings are accepted as descriptions,
but you can register() other classes that will also be allowed,
e.g. to support lazy string objects that are evaluated only at runtime.
If you register(object), any object will be allowed as description.
"""
bases: Union[type, Tuple[type, ...]] = str
@classmethod
def isinstance(cls, obj: Any) -> bool:
return isinstance(obj, cls.bases)
@classmethod
def register(cls, base: type) -> None:
"""Register a class that shall be accepted as a description."""
if not isinstance(base, type):
raise TypeError("Only types can be registered.")
if base is object:
cls.bases = object
elif cls.bases is object:
cls.bases = base
elif not isinstance(cls.bases, tuple):
if base is not cls.bases:
cls.bases = (cls.bases, base)
elif base not in cls.bases:
cls.bases += (base,)
@classmethod
def unregister(cls, base: type) -> None:
"""Unregister a class that shall no more be accepted as a description."""
if not isinstance(base, type):
raise TypeError("Only types can be unregistered.")
if isinstance(cls.bases, tuple):
if base in cls.bases: # pragma: no branch
cls.bases = tuple(b for b in cls.bases if b is not base)
if not cls.bases:
cls.bases = object
elif len(cls.bases) == 1:
cls.bases = cls.bases[0]
elif cls.bases is base:
cls.bases = object
is_description = Description.isinstance
register_description = Description.register
unregister_description = Description.unregister
@@ -0,0 +1,28 @@
from typing import Optional, Sequence
__all__ = ["did_you_mean"]
MAX_LENGTH = 5
def did_you_mean(suggestions: Sequence[str], sub_message: Optional[str] = None) -> str:
"""Given [ A, B, C ] return ' Did you mean A, B, or C?'"""
if not suggestions or not MAX_LENGTH:
return ""
parts = [" Did you mean "]
if sub_message:
parts.extend([sub_message, " "])
suggestions = suggestions[:MAX_LENGTH]
n = len(suggestions)
if n == 1:
parts.append(f"'{suggestions[0]}'?")
elif n == 2:
parts.append(f"'{suggestions[0]}' or '{suggestions[1]}'?")
else:
parts.extend(
[
", ".join(f"'{s}'" for s in suggestions[:-1]),
f", or '{suggestions[-1]}'?",
]
)
return "".join(parts)
@@ -0,0 +1,52 @@
from copy import deepcopy
from typing import Dict, TypeVar
from .frozen_error import FrozenError
__all__ = ["FrozenDict"]
KT = TypeVar("KT")
VT = TypeVar("VT")
class FrozenDict(Dict[KT, VT]):
"""Dictionary that can only be read, but not changed.
.. deprecated:: 3.2
Use dicts and the Mapping type instead. Will be removed in v3.3.
"""
def __delitem__(self, key):
raise FrozenError
def __setitem__(self, key, value):
raise FrozenError
def __iadd__(self, value):
raise FrozenError
def __hash__(self) -> int: # type: ignore
return hash(tuple(self.items()))
def __copy__(self) -> "FrozenDict":
return FrozenDict(self)
copy = __copy__
def __deepcopy__(self, memo: Dict) -> "FrozenDict":
return FrozenDict({k: deepcopy(v, memo) for k, v in self.items()})
def clear(self):
raise FrozenError
def pop(self, key, default=None):
raise FrozenError
def popitem(self):
raise FrozenError
def setdefault(self, key, default=None):
raise FrozenError
def update(self, other=None): # type: ignore
raise FrozenError
@@ -0,0 +1,5 @@
__all__ = ["FrozenError"]
class FrozenError(TypeError):
"""Error when trying to change a frozen (read only) collection."""
@@ -0,0 +1,70 @@
from copy import deepcopy
from typing import Dict, List, TypeVar
from .frozen_error import FrozenError
__all__ = ["FrozenList"]
T = TypeVar("T")
class FrozenList(List[T]):
"""List that can only be read, but not changed.
.. deprecated:: 3.2
Use tuples or lists and the Collection type instead. Will be removed in v3.3.
"""
def __delitem__(self, key):
raise FrozenError
def __setitem__(self, key, value):
raise FrozenError
def __add__(self, value):
if isinstance(value, tuple):
value = list(value)
return list.__add__(self, value)
def __iadd__(self, value):
raise FrozenError
def __mul__(self, value):
return list.__mul__(self, value)
def __imul__(self, value):
raise FrozenError
def __hash__(self) -> int: # type: ignore
return hash(tuple(self))
def __copy__(self) -> "FrozenList":
return FrozenList(self)
def __deepcopy__(self, memo: Dict) -> "FrozenList":
return FrozenList(deepcopy(value, memo) for value in self)
def append(self, x):
raise FrozenError
def extend(self, iterable):
raise FrozenError
def insert(self, i, x):
raise FrozenError
def remove(self, x):
raise FrozenError
def pop(self, i=None):
raise FrozenError
def clear(self):
raise FrozenError
def sort(self, *, key=None, reverse=False):
raise FrozenError
def reverse(self):
raise FrozenError
@@ -0,0 +1,16 @@
from collections import defaultdict
from typing import Callable, Collection, Dict, List, TypeVar
__all__ = ["group_by"]
K = TypeVar("K")
T = TypeVar("T")
def group_by(items: Collection[T], key_fn: Callable[[T], K]) -> Dict[K, List[T]]:
"""Group an unsorted collection of items by a key derived via a function."""
result: Dict[K, List[T]] = defaultdict(list)
for item in items:
key = key_fn(item)
result[key].append(item)
return result
@@ -0,0 +1,13 @@
from typing import cast, Any, TypeVar
from .undefined import Undefined
__all__ = ["identity_func"]
T = TypeVar("T")
def identity_func(x: T = cast(Any, Undefined), *_args: Any) -> T:
"""Return the first received argument."""
return x
@@ -0,0 +1,182 @@
from inspect import (
isclass,
ismethod,
isfunction,
isgeneratorfunction,
isgenerator,
iscoroutinefunction,
iscoroutine,
isasyncgenfunction,
isasyncgen,
)
from typing import Any, List
from .undefined import Undefined
__all__ = ["inspect"]
max_recursive_depth = 2
max_str_size = 240
max_list_size = 10
def inspect(value: Any) -> str:
"""Inspect value and a return string representation for error messages.
Used to print values in error messages. We do not use repr() in order to not
leak too much of the inner Python representation of unknown objects, and we
do not use json.dumps() because not all objects can be serialized as JSON and
we want to output strings with single quotes like Python repr() does it.
We also restrict the size of the representation by truncating strings and
collections and allowing only a maximum recursion depth.
"""
return inspect_recursive(value, [])
def inspect_recursive(value: Any, seen_values: List) -> str:
if value is None or value is Undefined or isinstance(value, (bool, float, complex)):
return repr(value)
if isinstance(value, (int, str, bytes, bytearray)):
return trunc_str(repr(value))
if len(seen_values) < max_recursive_depth and value not in seen_values:
# check if we have a custom inspect method
inspect_method = getattr(value, "__inspect__", None)
if inspect_method is not None and callable(inspect_method):
s = inspect_method()
if isinstance(s, str):
return trunc_str(s)
seen_values = [*seen_values, value]
return inspect_recursive(s, seen_values)
# recursively inspect collections
if isinstance(value, (list, tuple, dict, set, frozenset)):
if not value:
return repr(value)
seen_values = [*seen_values, value]
if isinstance(value, list):
items = value
elif isinstance(value, dict):
items = list(value.items())
else:
items = list(value)
items = trunc_list(items)
if isinstance(value, dict):
s = ", ".join(
(
"..."
if v is ELLIPSIS
else inspect_recursive(v[0], seen_values)
+ ": "
+ inspect_recursive(v[1], seen_values)
)
for v in items
)
else:
s = ", ".join(
"..." if v is ELLIPSIS else inspect_recursive(v, seen_values)
for v in items
)
if isinstance(value, tuple):
if len(items) == 1:
return f"({s},)"
return f"({s})"
if isinstance(value, (dict, set)):
return "{" + s + "}"
if isinstance(value, frozenset):
return f"frozenset({{{s}}})"
return f"[{s}]"
else:
# handle collections that are nested too deep
if isinstance(value, (list, tuple, dict, set, frozenset)):
if not value:
return repr(value)
if isinstance(value, list):
return "[...]"
if isinstance(value, tuple):
return "(...)"
if isinstance(value, dict):
return "{...}"
if isinstance(value, set):
return "set(...)"
return "frozenset(...)"
if isinstance(value, Exception):
type_ = "exception"
value = type(value)
elif isclass(value):
type_ = "exception class" if issubclass(value, Exception) else "class"
elif ismethod(value):
type_ = "method"
elif iscoroutinefunction(value):
type_ = "coroutine function"
elif isasyncgenfunction(value):
type_ = "async generator function"
elif isgeneratorfunction(value):
type_ = "generator function"
elif isfunction(value):
type_ = "function"
elif iscoroutine(value):
type_ = "coroutine"
elif isasyncgen(value):
type_ = "async generator"
elif isgenerator(value):
type_ = "generator"
else:
# stringify (only) the well-known GraphQL types
from ..type import (
GraphQLDirective,
GraphQLNamedType,
GraphQLScalarType,
GraphQLWrappingType,
)
if isinstance(
value,
(
GraphQLDirective,
GraphQLNamedType,
GraphQLScalarType,
GraphQLWrappingType,
),
):
return str(value)
try:
name = type(value).__name__
if not name or "<" in name or ">" in name:
raise AttributeError
except AttributeError:
return "<object>"
else:
return f"<{name} instance>"
try:
name = value.__name__
if not name or "<" in name or ">" in name:
raise AttributeError
except AttributeError:
return f"<{type_}>"
else:
return f"<{type_} {name}>"
def trunc_str(s: str) -> str:
"""Truncate strings to maximum length."""
if len(s) > max_str_size:
i = max(0, (max_str_size - 3) // 2)
j = max(0, max_str_size - 3 - i)
s = s[:i] + "..." + s[-j:]
return s
def trunc_list(s: List) -> List:
"""Truncate lists to maximum length."""
if len(s) > max_list_size:
i = max_list_size // 2
j = i - 1
s = s[:i] + [ELLIPSIS] + s[-j:]
return s
class InspectEllipsisType:
"""Singleton class for indicating ellipses in iterables."""
ELLIPSIS = InspectEllipsisType()
@@ -0,0 +1,24 @@
import inspect
from typing import Any
from types import CoroutineType, GeneratorType
__all__ = ["is_awaitable"]
CO_ITERABLE_COROUTINE = inspect.CO_ITERABLE_COROUTINE
def is_awaitable(value: Any) -> bool:
"""Return true if object can be passed to an ``await`` expression.
Instead of testing if the object is an instance of abc.Awaitable, it checks
the existence of an `__await__` attribute. This is much faster.
"""
return (
# check for coroutine objects
isinstance(value, CoroutineType)
# check for old-style generator based coroutine objects
or isinstance(value, GeneratorType)
and bool(value.gi_code.co_flags & CO_ITERABLE_COROUTINE)
# check for other awaitables (e.g. futures)
or hasattr(value, "__await__")
)
@@ -0,0 +1,24 @@
from collections.abc import Collection, Iterable, Mapping, ValuesView
from typing import Any
__all__ = ["is_collection", "is_iterable"]
collection_types: Any = Collection
if not isinstance({}.values(), Collection): # Python < 3.7.2
collection_types = (Collection, ValuesView)
iterable_types: Any = Iterable
not_iterable_types: Any = (bytes, bytearray, memoryview, str, Mapping)
def is_collection(value: Any) -> bool:
"""Check if value is a collection, but not a string or a mapping."""
return isinstance(value, collection_types) and not isinstance(
value, not_iterable_types
)
def is_iterable(value: Any) -> bool:
"""Check if value is an iterable, but not a string or a mapping."""
return isinstance(value, iterable_types) and not isinstance(
value, not_iterable_types
)
@@ -0,0 +1,8 @@
from typing import cast, Any, Dict, TypeVar
T = TypeVar("T")
def merge_kwargs(base_dict: T, **kwargs: Any) -> T:
"""Return arbitrary typed dictionary with some keyword args merged in."""
return cast(T, {**cast(Dict, base_dict), **kwargs})
@@ -0,0 +1,19 @@
import re
from typing import Tuple
from itertools import cycle
__all__ = ["natural_comparison_key"]
_re_digits = re.compile(r"(\d+)")
def natural_comparison_key(key: str) -> Tuple:
"""Comparison key function for sorting strings by natural sort order.
See: https://en.wikipedia.org/wiki/Natural_sort_order
"""
return tuple(
(int(part), part) if is_digit else part
for part, is_digit in zip(_re_digits.split(key), cycle((False, True)))
)
@@ -0,0 +1,28 @@
from typing import Any, List, NamedTuple, Optional, Union
__all__ = ["Path"]
class Path(NamedTuple):
"""A generic path of string or integer indices"""
prev: Any # Optional['Path'] (python/mypy/issues/731)
"""path with the previous indices"""
key: Union[str, int]
"""current index in the path (string or integer)"""
typename: Optional[str]
"""name of the parent type to avoid path ambiguity"""
def add_key(self, key: Union[str, int], typename: Optional[str] = None) -> "Path":
"""Return a new Path containing the given key."""
return Path(self, key, typename)
def as_list(self) -> List[Union[str, int]]:
"""Return a list of the path keys."""
flattened: List[Union[str, int]] = []
append = flattened.append
curr: Path = self
while curr:
append(curr.key)
curr = curr.prev
return flattened[::-1]
@@ -0,0 +1,6 @@
from typing import Collection, Union
def print_path_list(path: Collection[Union[str, int]]) -> str:
"""Build a string describing the path."""
return "".join(f"[{key}]" if isinstance(key, int) else f".{key}" for key in path)
@@ -0,0 +1,81 @@
from asyncio import Future, Queue, ensure_future, sleep
from inspect import isawaitable
from typing import Any, AsyncIterator, Callable, Optional, Set
try:
from asyncio import get_running_loop
except ImportError:
from asyncio import get_event_loop as get_running_loop # Python < 3.7
__all__ = ["SimplePubSub", "SimplePubSubIterator"]
class SimplePubSub:
"""A very simple publish-subscript system.
Creates an AsyncIterator from an EventEmitter.
Useful for mocking a PubSub system for tests.
"""
subscribers: Set[Callable]
def __init__(self) -> None:
self.subscribers = set()
def emit(self, event: Any) -> bool:
"""Emit an event."""
for subscriber in self.subscribers:
result = subscriber(event)
if isawaitable(result):
ensure_future(result)
return bool(self.subscribers)
def get_subscriber(
self, transform: Optional[Callable] = None
) -> "SimplePubSubIterator":
return SimplePubSubIterator(self, transform)
class SimplePubSubIterator(AsyncIterator):
def __init__(self, pubsub: SimplePubSub, transform: Optional[Callable]) -> None:
self.pubsub = pubsub
self.transform = transform
self.pull_queue: Queue[Future] = Queue()
self.push_queue: Queue[Any] = Queue()
self.listening = True
pubsub.subscribers.add(self.push_value)
def __aiter__(self) -> "SimplePubSubIterator":
return self
async def __anext__(self) -> Any:
if not self.listening:
raise StopAsyncIteration
await sleep(0)
if not self.push_queue.empty():
return await self.push_queue.get()
future = get_running_loop().create_future()
await self.pull_queue.put(future)
return future
async def aclose(self) -> None:
if self.listening:
await self.empty_queue()
async def empty_queue(self) -> None:
self.listening = False
self.pubsub.subscribers.remove(self.push_value)
while not self.pull_queue.empty():
future = await self.pull_queue.get()
future.cancel()
while not self.push_queue.empty():
await self.push_queue.get()
async def push_value(self, event: Any) -> None:
value = event if self.transform is None else self.transform(event)
if self.pull_queue.empty():
await self.push_queue.put(value)
else:
(await self.pull_queue.get()).set_result(value)
@@ -0,0 +1,109 @@
from typing import Collection, Optional, List
from .natural_compare import natural_comparison_key
__all__ = ["suggestion_list"]
def suggestion_list(input_: str, options: Collection[str]) -> List[str]:
"""Get list with suggestions for a given input.
Given an invalid input string and list of valid options, returns a filtered list
of valid options sorted based on their similarity with the input.
"""
options_by_distance = {}
lexical_distance = LexicalDistance(input_)
threshold = int(len(input_) * 0.4) + 1
for option in options:
distance = lexical_distance.measure(option, threshold)
if distance is not None:
options_by_distance[option] = distance
# noinspection PyShadowingNames
return sorted(
options_by_distance,
key=lambda option: (
options_by_distance.get(option, 0),
natural_comparison_key(option),
),
)
class LexicalDistance:
"""Computes the lexical distance between strings A and B.
The "distance" between two strings is given by counting the minimum number of edits
needed to transform string A into string B. An edit can be an insertion, deletion,
or substitution of a single character, or a swap of two adjacent characters.
This distance can be useful for detecting typos in input or sorting.
"""
_input: str
_input_lower_case: str
_input_list: List[int]
_rows: List[List[int]]
def __init__(self, input_: str):
self._input = input_
self._input_lower_case = input_.lower()
row_size = len(input_) + 1
self._input_list = list(map(ord, self._input_lower_case))
self._rows = [[0] * row_size, [0] * row_size, [0] * row_size]
def measure(self, option: str, threshold: int) -> Optional[int]:
if self._input == option:
return 0
option_lower_case = option.lower()
# Any case change counts as a single edit
if self._input_lower_case == option_lower_case:
return 1
a, b = list(map(ord, option_lower_case)), self._input_list
a_len, b_len = len(a), len(b)
if a_len < b_len:
a, b = b, a
a_len, b_len = b_len, a_len
if a_len - b_len > threshold:
return None
rows = self._rows
for j in range(b_len + 1):
rows[0][j] = j
for i in range(1, a_len + 1):
up_row = rows[(i - 1) % 3]
current_row = rows[i % 3]
smallest_cell = current_row[0] = i
for j in range(1, b_len + 1):
cost = 0 if a[i - 1] == b[j - 1] else 1
current_cell = min(
up_row[j] + 1, # delete
current_row[j - 1] + 1, # insert
up_row[j - 1] + cost, # substitute
)
if i > 1 and j > 1 and a[i - 1] == b[j - 2] and a[i - 2] == b[j - 1]:
# transposition
double_diagonal_cell = rows[(i - 2) % 3][j - 2]
current_cell = min(current_cell, double_diagonal_cell + 1)
if current_cell < smallest_cell:
smallest_cell = current_cell
current_row[j] = current_cell
# Early exit, since distance can't go smaller than smallest element
# of the previous row.
if smallest_cell > threshold:
return None
distance = rows[a_len % 3][b_len]
return distance if distance <= threshold else None
@@ -0,0 +1,34 @@
from typing import Any
__all__ = ["Undefined", "UndefinedType"]
class UndefinedType(ValueError):
"""Auxiliary class for creating the Undefined singleton."""
def __repr__(self) -> str:
return "Undefined"
__str__ = __repr__
def __hash__(self) -> int:
return hash(UndefinedType)
def __bool__(self) -> bool:
return False
def __eq__(self, other: Any) -> bool:
return other is Undefined
def __ne__(self, other: Any) -> bool:
return not self == other
# Used to indicate undefined or invalid values (like "undefined" in JavaScript):
Undefined = UndefinedType()
Undefined.__doc__ = """Symbol for undefined values
This singleton object is used to describe undefined or invalid values.
It can be used in places where you would use ``undefined`` in GraphQL.js.
"""
@@ -0,0 +1,16 @@
"""GraphQL Subscription
The :mod:`graphql.subscription` package is responsible for subscribing to updates
on specific data.
.. deprecated:: 3.2
This package has been deprecated with its exported functions integrated into the
:mod:`graphql.execution` package, to better conform with the terminology of the
GraphQL specification. For backwards compatibility, the :mod:`graphql.subscription`
package currently re-exports the moved functions from the :mod:`graphql.execution`
package. In v3.3, the :mod:`graphql.subscription` package will be dropped entirely.
"""
from ..execution import subscribe, create_source_event_stream, MapAsyncIterator
__all__ = ["subscribe", "create_source_event_stream", "MapAsyncIterator"]
@@ -0,0 +1,298 @@
"""GraphQL Type System
The :mod:`graphql.type` package is responsible for defining GraphQL types and schema.
"""
from ..pyutils import Path as ResponsePath
from .schema import (
# Predicate
is_schema,
# Assertion
assert_schema,
# GraphQL Schema definition
GraphQLSchema,
# Keyword Args
GraphQLSchemaKwargs,
)
# Uphold the spec rules about naming.
from .assert_name import assert_name, assert_enum_value_name
from .definition import (
# Predicates
is_type,
is_scalar_type,
is_object_type,
is_interface_type,
is_union_type,
is_enum_type,
is_input_object_type,
is_list_type,
is_non_null_type,
is_input_type,
is_output_type,
is_leaf_type,
is_composite_type,
is_abstract_type,
is_wrapping_type,
is_nullable_type,
is_named_type,
is_required_argument,
is_required_input_field,
# Assertions
assert_type,
assert_scalar_type,
assert_object_type,
assert_interface_type,
assert_union_type,
assert_enum_type,
assert_input_object_type,
assert_list_type,
assert_non_null_type,
assert_input_type,
assert_output_type,
assert_leaf_type,
assert_composite_type,
assert_abstract_type,
assert_wrapping_type,
assert_nullable_type,
assert_named_type,
# Un-modifiers
get_nullable_type,
get_named_type,
# Thunk handling
resolve_thunk,
# Definitions
GraphQLScalarType,
GraphQLObjectType,
GraphQLInterfaceType,
GraphQLUnionType,
GraphQLEnumType,
GraphQLInputObjectType,
# Type Wrappers
GraphQLList,
GraphQLNonNull,
# Types
GraphQLType,
GraphQLInputType,
GraphQLOutputType,
GraphQLLeafType,
GraphQLCompositeType,
GraphQLAbstractType,
GraphQLWrappingType,
GraphQLNullableType,
GraphQLNamedType,
GraphQLNamedInputType,
GraphQLNamedOutputType,
Thunk,
ThunkCollection,
ThunkMapping,
GraphQLArgument,
GraphQLArgumentMap,
GraphQLEnumValue,
GraphQLEnumValueMap,
GraphQLField,
GraphQLFieldMap,
GraphQLInputField,
GraphQLInputFieldMap,
GraphQLScalarSerializer,
GraphQLScalarValueParser,
GraphQLScalarLiteralParser,
# Keyword Args
GraphQLArgumentKwargs,
GraphQLEnumTypeKwargs,
GraphQLEnumValueKwargs,
GraphQLFieldKwargs,
GraphQLInputFieldKwargs,
GraphQLInputObjectTypeKwargs,
GraphQLInterfaceTypeKwargs,
GraphQLNamedTypeKwargs,
GraphQLObjectTypeKwargs,
GraphQLScalarTypeKwargs,
GraphQLUnionTypeKwargs,
# Resolvers
GraphQLFieldResolver,
GraphQLTypeResolver,
GraphQLIsTypeOfFn,
GraphQLResolveInfo,
)
from .directives import (
# Predicate
is_directive,
# Assertion
assert_directive,
# Directives Definition
GraphQLDirective,
# Built-in Directives defined by the Spec
is_specified_directive,
specified_directives,
GraphQLIncludeDirective,
GraphQLSkipDirective,
GraphQLDeprecatedDirective,
GraphQLSpecifiedByDirective,
# Keyword Args
GraphQLDirectiveKwargs,
# Constant Deprecation Reason
DEFAULT_DEPRECATION_REASON,
)
# Common built-in scalar instances.
from .scalars import (
# Predicate
is_specified_scalar_type,
# Standard GraphQL Scalars
specified_scalar_types,
GraphQLInt,
GraphQLFloat,
GraphQLString,
GraphQLBoolean,
GraphQLID,
# Int boundaries constants
GRAPHQL_MAX_INT,
GRAPHQL_MIN_INT,
)
from .introspection import (
# Predicate
is_introspection_type,
# GraphQL Types for introspection.
introspection_types,
# "Enum" of Type Kinds
TypeKind,
# Meta-field definitions.
SchemaMetaFieldDef,
TypeMetaFieldDef,
TypeNameMetaFieldDef,
)
# Validate GraphQL schema.
from .validate import validate_schema, assert_valid_schema
__all__ = [
"is_schema",
"assert_schema",
"assert_name",
"assert_enum_value_name",
"GraphQLSchema",
"GraphQLSchemaKwargs",
"is_type",
"is_scalar_type",
"is_object_type",
"is_interface_type",
"is_union_type",
"is_enum_type",
"is_input_object_type",
"is_list_type",
"is_non_null_type",
"is_input_type",
"is_output_type",
"is_leaf_type",
"is_composite_type",
"is_abstract_type",
"is_wrapping_type",
"is_nullable_type",
"is_named_type",
"is_required_argument",
"is_required_input_field",
"assert_type",
"assert_scalar_type",
"assert_object_type",
"assert_interface_type",
"assert_union_type",
"assert_enum_type",
"assert_input_object_type",
"assert_list_type",
"assert_non_null_type",
"assert_input_type",
"assert_output_type",
"assert_leaf_type",
"assert_composite_type",
"assert_abstract_type",
"assert_wrapping_type",
"assert_nullable_type",
"assert_named_type",
"get_nullable_type",
"get_named_type",
"resolve_thunk",
"GraphQLScalarType",
"GraphQLObjectType",
"GraphQLInterfaceType",
"GraphQLUnionType",
"GraphQLEnumType",
"GraphQLInputObjectType",
"GraphQLInputType",
"GraphQLArgument",
"GraphQLList",
"GraphQLNonNull",
"GraphQLType",
"GraphQLInputType",
"GraphQLOutputType",
"GraphQLLeafType",
"GraphQLCompositeType",
"GraphQLAbstractType",
"GraphQLWrappingType",
"GraphQLNullableType",
"GraphQLNamedType",
"GraphQLNamedInputType",
"GraphQLNamedOutputType",
"Thunk",
"ThunkCollection",
"ThunkMapping",
"GraphQLArgument",
"GraphQLArgumentMap",
"GraphQLEnumValue",
"GraphQLEnumValueMap",
"GraphQLField",
"GraphQLFieldMap",
"GraphQLInputField",
"GraphQLInputFieldMap",
"GraphQLScalarSerializer",
"GraphQLScalarValueParser",
"GraphQLScalarLiteralParser",
"GraphQLArgumentKwargs",
"GraphQLEnumTypeKwargs",
"GraphQLEnumValueKwargs",
"GraphQLFieldKwargs",
"GraphQLInputFieldKwargs",
"GraphQLInputObjectTypeKwargs",
"GraphQLInterfaceTypeKwargs",
"GraphQLNamedTypeKwargs",
"GraphQLObjectTypeKwargs",
"GraphQLScalarTypeKwargs",
"GraphQLUnionTypeKwargs",
"GraphQLFieldResolver",
"GraphQLTypeResolver",
"GraphQLIsTypeOfFn",
"GraphQLResolveInfo",
"ResponsePath",
"is_directive",
"assert_directive",
"is_specified_directive",
"specified_directives",
"GraphQLDirective",
"GraphQLIncludeDirective",
"GraphQLSkipDirective",
"GraphQLDeprecatedDirective",
"GraphQLSpecifiedByDirective",
"GraphQLDirectiveKwargs",
"DEFAULT_DEPRECATION_REASON",
"is_specified_scalar_type",
"specified_scalar_types",
"GraphQLInt",
"GraphQLFloat",
"GraphQLString",
"GraphQLBoolean",
"GraphQLID",
"GRAPHQL_MAX_INT",
"GRAPHQL_MIN_INT",
"is_introspection_type",
"introspection_types",
"TypeKind",
"SchemaMetaFieldDef",
"TypeMetaFieldDef",
"TypeNameMetaFieldDef",
"validate_schema",
"assert_valid_schema",
]
@@ -0,0 +1,29 @@
from ..error import GraphQLError
from ..language.character_classes import is_name_start, is_name_continue
__all__ = ["assert_name", "assert_enum_value_name"]
def assert_name(name: str) -> str:
"""Uphold the spec rules about naming."""
if name is None:
raise TypeError("Must provide name.")
if not isinstance(name, str):
raise TypeError("Expected name to be a string.")
if not name:
raise GraphQLError("Expected name to be a non-empty string.")
if not all(is_name_continue(char) for char in name[1:]):
raise GraphQLError(
f"Names must only contain [_a-zA-Z0-9] but {name!r} does not."
)
if not is_name_start(name[0]):
raise GraphQLError(f"Names must start with [_a-zA-Z] but {name!r} does not.")
return name
def assert_enum_value_name(name: str) -> str:
"""Uphold the spec rules about naming enum values."""
assert_name(name)
if name in {"true", "false", "null"}:
raise GraphQLError(f"Enum values cannot be named: {name}.")
return name
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,255 @@
from typing import Any, Collection, Dict, Optional, Tuple, cast
from ..language import DirectiveLocation, ast
from ..pyutils import inspect, is_description
from .assert_name import assert_name
from .definition import GraphQLArgument, GraphQLInputType, GraphQLNonNull, is_input_type
from .scalars import GraphQLBoolean, GraphQLString
try:
from typing import TypedDict
except ImportError: # Python < 3.8
from typing_extensions import TypedDict
__all__ = [
"is_directive",
"assert_directive",
"is_specified_directive",
"specified_directives",
"GraphQLDirective",
"GraphQLDirectiveKwargs",
"GraphQLIncludeDirective",
"GraphQLSkipDirective",
"GraphQLDeprecatedDirective",
"GraphQLSpecifiedByDirective",
"DirectiveLocation",
"DEFAULT_DEPRECATION_REASON",
]
class GraphQLDirectiveKwargs(TypedDict, total=False):
name: str
locations: Tuple[DirectiveLocation, ...]
args: Dict[str, GraphQLArgument]
is_repeatable: bool
description: Optional[str]
extensions: Dict[str, Any]
ast_node: Optional[ast.DirectiveDefinitionNode]
class GraphQLDirective:
"""GraphQL Directive
Directives are used by the GraphQL runtime as a way of modifying execution behavior.
Type system creators will usually not create these directly.
"""
name: str
locations: Tuple[DirectiveLocation, ...]
is_repeatable: bool
args: Dict[str, GraphQLArgument]
description: Optional[str]
extensions: Dict[str, Any]
ast_node: Optional[ast.DirectiveDefinitionNode]
def __init__(
self,
name: str,
locations: Collection[DirectiveLocation],
args: Optional[Dict[str, GraphQLArgument]] = None,
is_repeatable: bool = False,
description: Optional[str] = None,
extensions: Optional[Dict[str, Any]] = None,
ast_node: Optional[ast.DirectiveDefinitionNode] = None,
) -> None:
assert_name(name)
try:
locations = tuple(
(
value
if isinstance(value, DirectiveLocation)
else DirectiveLocation[cast(str, value)]
)
for value in locations
)
except (KeyError, TypeError):
raise TypeError(
f"{name} locations must be specified"
" as a collection of DirectiveLocation enum values."
)
if args is None:
args = {}
elif not isinstance(args, dict) or not all(
isinstance(key, str) for key in args
):
raise TypeError(f"{name} args must be a dict with argument names as keys.")
elif not all(
isinstance(value, GraphQLArgument) or is_input_type(value)
for value in args.values()
):
raise TypeError(
f"{name} args must be GraphQLArgument or input type objects."
)
else:
args = {
assert_name(name): (
value
if isinstance(value, GraphQLArgument)
else GraphQLArgument(cast(GraphQLInputType, value))
)
for name, value in args.items()
}
if not isinstance(is_repeatable, bool):
raise TypeError(f"{name} is_repeatable flag must be True or False.")
if ast_node and not isinstance(ast_node, ast.DirectiveDefinitionNode):
raise TypeError(f"{name} AST node must be a DirectiveDefinitionNode.")
if description is not None and not is_description(description):
raise TypeError(f"{name} description must be a string.")
if extensions is None:
extensions = {}
elif not isinstance(extensions, dict) or not all(
isinstance(key, str) for key in extensions
):
raise TypeError(f"{name} extensions must be a dictionary with string keys.")
self.name = name
self.locations = locations
self.args = args
self.is_repeatable = is_repeatable
self.description = description
self.extensions = extensions
self.ast_node = ast_node
def __str__(self) -> str:
return f"@{self.name}"
def __repr__(self) -> str:
return f"<{self.__class__.__name__}({self})>"
def __eq__(self, other: Any) -> bool:
return self is other or (
isinstance(other, GraphQLDirective)
and self.name == other.name
and self.locations == other.locations
and self.args == other.args
and self.is_repeatable == other.is_repeatable
and self.description == other.description
and self.extensions == other.extensions
)
def to_kwargs(self) -> GraphQLDirectiveKwargs:
return GraphQLDirectiveKwargs(
name=self.name,
locations=self.locations,
args=self.args,
is_repeatable=self.is_repeatable,
description=self.description,
extensions=self.extensions,
ast_node=self.ast_node,
)
def __copy__(self) -> "GraphQLDirective": # pragma: no cover
return self.__class__(**self.to_kwargs())
def is_directive(directive: Any) -> bool:
"""Test if the given value is a GraphQL directive."""
return isinstance(directive, GraphQLDirective)
def assert_directive(directive: Any) -> GraphQLDirective:
if not is_directive(directive):
raise TypeError(f"Expected {inspect(directive)} to be a GraphQL directive.")
return cast(GraphQLDirective, directive)
# Used to conditionally include fields or fragments.
GraphQLIncludeDirective = GraphQLDirective(
name="include",
locations=[
DirectiveLocation.FIELD,
DirectiveLocation.FRAGMENT_SPREAD,
DirectiveLocation.INLINE_FRAGMENT,
],
args={
"if": GraphQLArgument(
GraphQLNonNull(GraphQLBoolean), description="Included when true."
)
},
description="Directs the executor to include this field or fragment"
" only when the `if` argument is true.",
)
# Used to conditionally skip (exclude) fields or fragments:
GraphQLSkipDirective = GraphQLDirective(
name="skip",
locations=[
DirectiveLocation.FIELD,
DirectiveLocation.FRAGMENT_SPREAD,
DirectiveLocation.INLINE_FRAGMENT,
],
args={
"if": GraphQLArgument(
GraphQLNonNull(GraphQLBoolean), description="Skipped when true."
)
},
description="Directs the executor to skip this field or fragment"
" when the `if` argument is true.",
)
# Constant string used for default reason for a deprecation:
DEFAULT_DEPRECATION_REASON = "No longer supported"
# Used to declare element of a GraphQL schema as deprecated:
GraphQLDeprecatedDirective = GraphQLDirective(
name="deprecated",
locations=[
DirectiveLocation.FIELD_DEFINITION,
DirectiveLocation.ARGUMENT_DEFINITION,
DirectiveLocation.INPUT_FIELD_DEFINITION,
DirectiveLocation.ENUM_VALUE,
],
args={
"reason": GraphQLArgument(
GraphQLString,
description="Explains why this element was deprecated,"
" usually also including a suggestion for how to access"
" supported similar data."
" Formatted using the Markdown syntax, as specified by"
" [CommonMark](https://commonmark.org/).",
default_value=DEFAULT_DEPRECATION_REASON,
)
},
description="Marks an element of a GraphQL schema as no longer supported.",
)
# Used to provide a URL for specifying the behavior of custom scalar definitions:
GraphQLSpecifiedByDirective = GraphQLDirective(
name="specifiedBy",
locations=[DirectiveLocation.SCALAR],
args={
"url": GraphQLArgument(
GraphQLNonNull(GraphQLString),
description="The URL that specifies the behavior of this scalar.",
)
},
description="Exposes a URL that specifies the behavior of this scalar.",
)
specified_directives: Tuple[GraphQLDirective, ...] = (
GraphQLIncludeDirective,
GraphQLSkipDirective,
GraphQLDeprecatedDirective,
GraphQLSpecifiedByDirective,
)
"""A tuple with all directives from the GraphQL specification"""
def is_specified_directive(directive: GraphQLDirective) -> bool:
"""Check whether the given directive is one of the specified directives."""
return any(
specified_directive.name == directive.name
for specified_directive in specified_directives
)
@@ -0,0 +1,618 @@
from enum import Enum
from typing import Mapping
from .definition import (
GraphQLArgument,
GraphQLEnumType,
GraphQLEnumValue,
GraphQLField,
GraphQLList,
GraphQLNamedType,
GraphQLNonNull,
GraphQLObjectType,
is_abstract_type,
is_enum_type,
is_input_object_type,
is_interface_type,
is_list_type,
is_non_null_type,
is_object_type,
is_scalar_type,
is_union_type,
)
from ..language import DirectiveLocation, print_ast
from ..pyutils import inspect
from .scalars import GraphQLBoolean, GraphQLString
__all__ = [
"SchemaMetaFieldDef",
"TypeKind",
"TypeMetaFieldDef",
"TypeNameMetaFieldDef",
"introspection_types",
"is_introspection_type",
]
__Schema: GraphQLObjectType = GraphQLObjectType(
name="__Schema",
description="A GraphQL Schema defines the capabilities of a GraphQL"
" server. It exposes all available types and directives"
" on the server, as well as the entry points for query,"
" mutation, and subscription operations.",
fields=lambda: {
"description": GraphQLField(
GraphQLString, resolve=lambda schema, _info: schema.description
),
"types": GraphQLField(
GraphQLNonNull(GraphQLList(GraphQLNonNull(__Type))),
resolve=lambda schema, _info: schema.type_map.values(),
description="A list of all types supported by this server.",
),
"queryType": GraphQLField(
GraphQLNonNull(__Type),
resolve=lambda schema, _info: schema.query_type,
description="The type that query operations will be rooted at.",
),
"mutationType": GraphQLField(
__Type,
resolve=lambda schema, _info: schema.mutation_type,
description="If this server supports mutation, the type that"
" mutation operations will be rooted at.",
),
"subscriptionType": GraphQLField(
__Type,
resolve=lambda schema, _info: schema.subscription_type,
description="If this server support subscription, the type that"
" subscription operations will be rooted at.",
),
"directives": GraphQLField(
GraphQLNonNull(GraphQLList(GraphQLNonNull(__Directive))),
resolve=lambda schema, _info: schema.directives,
description="A list of all directives supported by this server.",
),
},
)
__Directive: GraphQLObjectType = GraphQLObjectType(
name="__Directive",
description="A Directive provides a way to describe alternate runtime"
" execution and type validation behavior in a GraphQL"
" document.\n\nIn some cases, you need to provide options"
" to alter GraphQL's execution behavior in ways field"
" arguments will not suffice, such as conditionally including"
" or skipping a field. Directives provide this by describing"
" additional information to the executor.",
fields=lambda: {
# Note: The fields onOperation, onFragment and onField are deprecated
"name": GraphQLField(
GraphQLNonNull(GraphQLString),
resolve=DirectiveResolvers.name,
),
"description": GraphQLField(
GraphQLString,
resolve=DirectiveResolvers.description,
),
"isRepeatable": GraphQLField(
GraphQLNonNull(GraphQLBoolean),
resolve=DirectiveResolvers.is_repeatable,
),
"locations": GraphQLField(
GraphQLNonNull(GraphQLList(GraphQLNonNull(__DirectiveLocation))),
resolve=DirectiveResolvers.locations,
),
"args": GraphQLField(
GraphQLNonNull(GraphQLList(GraphQLNonNull(__InputValue))),
args={
"includeDeprecated": GraphQLArgument(
GraphQLBoolean, default_value=False
)
},
resolve=DirectiveResolvers.args,
),
},
)
class DirectiveResolvers:
@staticmethod
def name(directive, _info):
return directive.name
@staticmethod
def description(directive, _info):
return directive.description
@staticmethod
def is_repeatable(directive, _info):
return directive.is_repeatable
@staticmethod
def locations(directive, _info):
return directive.locations
# noinspection PyPep8Naming
@staticmethod
def args(directive, _info, includeDeprecated=False):
items = directive.args.items()
return (
list(items)
if includeDeprecated
else [item for item in items if item[1].deprecation_reason is None]
)
__DirectiveLocation: GraphQLEnumType = GraphQLEnumType(
name="__DirectiveLocation",
description="A Directive can be adjacent to many parts of the GraphQL"
" language, a __DirectiveLocation describes one such possible"
" adjacencies.",
values={
"QUERY": GraphQLEnumValue(
DirectiveLocation.QUERY,
description="Location adjacent to a query operation.",
),
"MUTATION": GraphQLEnumValue(
DirectiveLocation.MUTATION,
description="Location adjacent to a mutation operation.",
),
"SUBSCRIPTION": GraphQLEnumValue(
DirectiveLocation.SUBSCRIPTION,
description="Location adjacent to a subscription operation.",
),
"FIELD": GraphQLEnumValue(
DirectiveLocation.FIELD, description="Location adjacent to a field."
),
"FRAGMENT_DEFINITION": GraphQLEnumValue(
DirectiveLocation.FRAGMENT_DEFINITION,
description="Location adjacent to a fragment definition.",
),
"FRAGMENT_SPREAD": GraphQLEnumValue(
DirectiveLocation.FRAGMENT_SPREAD,
description="Location adjacent to a fragment spread.",
),
"INLINE_FRAGMENT": GraphQLEnumValue(
DirectiveLocation.INLINE_FRAGMENT,
description="Location adjacent to an inline fragment.",
),
"VARIABLE_DEFINITION": GraphQLEnumValue(
DirectiveLocation.VARIABLE_DEFINITION,
description="Location adjacent to a variable definition.",
),
"SCHEMA": GraphQLEnumValue(
DirectiveLocation.SCHEMA,
description="Location adjacent to a schema definition.",
),
"SCALAR": GraphQLEnumValue(
DirectiveLocation.SCALAR,
description="Location adjacent to a scalar definition.",
),
"OBJECT": GraphQLEnumValue(
DirectiveLocation.OBJECT,
description="Location adjacent to an object type definition.",
),
"FIELD_DEFINITION": GraphQLEnumValue(
DirectiveLocation.FIELD_DEFINITION,
description="Location adjacent to a field definition.",
),
"ARGUMENT_DEFINITION": GraphQLEnumValue(
DirectiveLocation.ARGUMENT_DEFINITION,
description="Location adjacent to an argument definition.",
),
"INTERFACE": GraphQLEnumValue(
DirectiveLocation.INTERFACE,
description="Location adjacent to an interface definition.",
),
"UNION": GraphQLEnumValue(
DirectiveLocation.UNION,
description="Location adjacent to a union definition.",
),
"ENUM": GraphQLEnumValue(
DirectiveLocation.ENUM,
description="Location adjacent to an enum definition.",
),
"ENUM_VALUE": GraphQLEnumValue(
DirectiveLocation.ENUM_VALUE,
description="Location adjacent to an enum value definition.",
),
"INPUT_OBJECT": GraphQLEnumValue(
DirectiveLocation.INPUT_OBJECT,
description="Location adjacent to an input object type definition.",
),
"INPUT_FIELD_DEFINITION": GraphQLEnumValue(
DirectiveLocation.INPUT_FIELD_DEFINITION,
description="Location adjacent to an input object field definition.",
),
},
)
__Type: GraphQLObjectType = GraphQLObjectType(
name="__Type",
description="The fundamental unit of any GraphQL Schema is the type."
" There are many kinds of types in GraphQL as represented"
" by the `__TypeKind` enum.\n\nDepending on the kind of a"
" type, certain fields describe information about that type."
" Scalar types provide no information beyond a name, description"
" and optional `specifiedByURL`, while Enum types provide their values."
" Object and Interface types provide the fields they describe."
" Abstract types, Union and Interface, provide the Object"
" types possible at runtime. List and NonNull types compose"
" other types.",
fields=lambda: {
"kind": GraphQLField(GraphQLNonNull(__TypeKind), resolve=TypeResolvers.kind),
"name": GraphQLField(GraphQLString, resolve=TypeResolvers.name),
"description": GraphQLField(GraphQLString, resolve=TypeResolvers.description),
"specifiedByURL": GraphQLField(
GraphQLString, resolve=TypeResolvers.specified_by_url
),
"fields": GraphQLField(
GraphQLList(GraphQLNonNull(__Field)),
args={
"includeDeprecated": GraphQLArgument(
GraphQLBoolean, default_value=False
)
},
resolve=TypeResolvers.fields,
),
"interfaces": GraphQLField(
GraphQLList(GraphQLNonNull(__Type)), resolve=TypeResolvers.interfaces
),
"possibleTypes": GraphQLField(
GraphQLList(GraphQLNonNull(__Type)),
resolve=TypeResolvers.possible_types,
),
"enumValues": GraphQLField(
GraphQLList(GraphQLNonNull(__EnumValue)),
args={
"includeDeprecated": GraphQLArgument(
GraphQLBoolean, default_value=False
)
},
resolve=TypeResolvers.enum_values,
),
"inputFields": GraphQLField(
GraphQLList(GraphQLNonNull(__InputValue)),
args={
"includeDeprecated": GraphQLArgument(
GraphQLBoolean, default_value=False
)
},
resolve=TypeResolvers.input_fields,
),
"ofType": GraphQLField(__Type, resolve=TypeResolvers.of_type),
},
)
class TypeResolvers:
@staticmethod
def kind(type_, _info):
if is_scalar_type(type_):
return TypeKind.SCALAR
if is_object_type(type_):
return TypeKind.OBJECT
if is_interface_type(type_):
return TypeKind.INTERFACE
if is_union_type(type_):
return TypeKind.UNION
if is_enum_type(type_):
return TypeKind.ENUM
if is_input_object_type(type_):
return TypeKind.INPUT_OBJECT
if is_list_type(type_):
return TypeKind.LIST
if is_non_null_type(type_):
return TypeKind.NON_NULL
# Not reachable. All possible types have been considered.
raise TypeError(f"Unexpected type: {inspect(type_)}.") # pragma: no cover
@staticmethod
def name(type_, _info):
return getattr(type_, "name", None)
@staticmethod
def description(type_, _info):
return getattr(type_, "description", None)
@staticmethod
def specified_by_url(type_, _info):
return getattr(type_, "specified_by_url", None)
# noinspection PyPep8Naming
@staticmethod
def fields(type_, _info, includeDeprecated=False):
if is_object_type(type_) or is_interface_type(type_):
items = type_.fields.items()
return (
list(items)
if includeDeprecated
else [item for item in items if item[1].deprecation_reason is None]
)
@staticmethod
def interfaces(type_, _info):
if is_object_type(type_) or is_interface_type(type_):
return type_.interfaces
@staticmethod
def possible_types(type_, info):
if is_abstract_type(type_):
return info.schema.get_possible_types(type_)
# noinspection PyPep8Naming
@staticmethod
def enum_values(type_, _info, includeDeprecated=False):
if is_enum_type(type_):
items = type_.values.items()
return (
items
if includeDeprecated
else [item for item in items if item[1].deprecation_reason is None]
)
# noinspection PyPep8Naming
@staticmethod
def input_fields(type_, _info, includeDeprecated=False):
if is_input_object_type(type_):
items = type_.fields.items()
return (
items
if includeDeprecated
else [item for item in items if item[1].deprecation_reason is None]
)
@staticmethod
def of_type(type_, _info):
return getattr(type_, "of_type", None)
__Field: GraphQLObjectType = GraphQLObjectType(
name="__Field",
description="Object and Interface types are described by a list of Fields,"
" each of which has a name, potentially a list of arguments,"
" and a return type.",
fields=lambda: {
"name": GraphQLField(
GraphQLNonNull(GraphQLString), resolve=FieldResolvers.name
),
"description": GraphQLField(GraphQLString, resolve=FieldResolvers.description),
"args": GraphQLField(
GraphQLNonNull(GraphQLList(GraphQLNonNull(__InputValue))),
args={
"includeDeprecated": GraphQLArgument(
GraphQLBoolean, default_value=False
)
},
resolve=FieldResolvers.args,
),
"type": GraphQLField(GraphQLNonNull(__Type), resolve=FieldResolvers.type),
"isDeprecated": GraphQLField(
GraphQLNonNull(GraphQLBoolean),
resolve=FieldResolvers.is_deprecated,
),
"deprecationReason": GraphQLField(
GraphQLString, resolve=FieldResolvers.deprecation_reason
),
},
)
class FieldResolvers:
@staticmethod
def name(item, _info):
return item[0]
@staticmethod
def description(item, _info):
return item[1].description
# noinspection PyPep8Naming
@staticmethod
def args(item, _info, includeDeprecated=False):
items = item[1].args.items()
return (
items
if includeDeprecated
else [item for item in items if item[1].deprecation_reason is None]
)
@staticmethod
def type(item, _info):
return item[1].type
@staticmethod
def is_deprecated(item, _info):
return item[1].deprecation_reason is not None
@staticmethod
def deprecation_reason(item, _info):
return item[1].deprecation_reason
__InputValue: GraphQLObjectType = GraphQLObjectType(
name="__InputValue",
description="Arguments provided to Fields or Directives and the input"
" fields of an InputObject are represented as Input Values"
" which describe their type and optionally a default value.",
fields=lambda: {
"name": GraphQLField(
GraphQLNonNull(GraphQLString), resolve=InputValueFieldResolvers.name
),
"description": GraphQLField(
GraphQLString, resolve=InputValueFieldResolvers.description
),
"type": GraphQLField(
GraphQLNonNull(__Type), resolve=InputValueFieldResolvers.type
),
"defaultValue": GraphQLField(
GraphQLString,
description="A GraphQL-formatted string representing"
" the default value for this input value.",
resolve=InputValueFieldResolvers.default_value,
),
"isDeprecated": GraphQLField(
GraphQLNonNull(GraphQLBoolean),
resolve=InputValueFieldResolvers.is_deprecated,
),
"deprecationReason": GraphQLField(
GraphQLString, resolve=InputValueFieldResolvers.deprecation_reason
),
},
)
class InputValueFieldResolvers:
@staticmethod
def name(item, _info):
return item[0]
@staticmethod
def description(item, _info):
return item[1].description
@staticmethod
def type(item, _info):
return item[1].type
@staticmethod
def default_value(item, _info):
# Since ast_from_value needs graphql.type, it can only be imported later
from ..utilities import ast_from_value
value_ast = ast_from_value(item[1].default_value, item[1].type)
return print_ast(value_ast) if value_ast else None
@staticmethod
def is_deprecated(item, _info):
return item[1].deprecation_reason is not None
@staticmethod
def deprecation_reason(item, _info):
return item[1].deprecation_reason
__EnumValue: GraphQLObjectType = GraphQLObjectType(
name="__EnumValue",
description="One possible value for a given Enum. Enum values are unique"
" values, not a placeholder for a string or numeric value."
" However an Enum value is returned in a JSON response as a"
" string.",
fields=lambda: {
"name": GraphQLField(
GraphQLNonNull(GraphQLString), resolve=lambda item, _info: item[0]
),
"description": GraphQLField(
GraphQLString, resolve=lambda item, _info: item[1].description
),
"isDeprecated": GraphQLField(
GraphQLNonNull(GraphQLBoolean),
resolve=lambda item, _info: item[1].deprecation_reason is not None,
),
"deprecationReason": GraphQLField(
GraphQLString, resolve=lambda item, _info: item[1].deprecation_reason
),
},
)
class TypeKind(Enum):
SCALAR = "scalar"
OBJECT = "object"
INTERFACE = "interface"
UNION = "union"
ENUM = "enum"
INPUT_OBJECT = "input object"
LIST = "list"
NON_NULL = "non-null"
__TypeKind: GraphQLEnumType = GraphQLEnumType(
name="__TypeKind",
description="An enum describing what kind of type a given `__Type` is.",
values={
"SCALAR": GraphQLEnumValue(
TypeKind.SCALAR, description="Indicates this type is a scalar."
),
"OBJECT": GraphQLEnumValue(
TypeKind.OBJECT,
description="Indicates this type is an object."
" `fields` and `interfaces` are valid fields.",
),
"INTERFACE": GraphQLEnumValue(
TypeKind.INTERFACE,
description="Indicates this type is an interface."
" `fields`, `interfaces`, and `possibleTypes` are valid fields.",
),
"UNION": GraphQLEnumValue(
TypeKind.UNION,
description="Indicates this type is a union."
" `possibleTypes` is a valid field.",
),
"ENUM": GraphQLEnumValue(
TypeKind.ENUM,
description="Indicates this type is an enum."
" `enumValues` is a valid field.",
),
"INPUT_OBJECT": GraphQLEnumValue(
TypeKind.INPUT_OBJECT,
description="Indicates this type is an input object."
" `inputFields` is a valid field.",
),
"LIST": GraphQLEnumValue(
TypeKind.LIST,
description="Indicates this type is a list. `ofType` is a valid field.",
),
"NON_NULL": GraphQLEnumValue(
TypeKind.NON_NULL,
description="Indicates this type is a non-null."
" `ofType` is a valid field.",
),
},
)
SchemaMetaFieldDef = GraphQLField(
GraphQLNonNull(__Schema), # name = '__schema'
description="Access the current type schema of this server.",
args={},
resolve=lambda _source, info: info.schema,
)
TypeMetaFieldDef = GraphQLField(
__Type, # name = '__type'
description="Request the type information of a single type.",
args={"name": GraphQLArgument(GraphQLNonNull(GraphQLString))},
resolve=lambda _source, info, **args: info.schema.get_type(args["name"]),
)
TypeNameMetaFieldDef = GraphQLField(
GraphQLNonNull(GraphQLString), # name='__typename'
description="The name of the current Object type at runtime.",
args={},
resolve=lambda _source, info, **_args: info.parent_type.name,
)
# Since double underscore names are subject to name mangling in Python,
# the introspection classes are best imported via this dictionary:
introspection_types: Mapping[str, GraphQLNamedType] = { # treat as read-only
"__Schema": __Schema,
"__Directive": __Directive,
"__DirectiveLocation": __DirectiveLocation,
"__Type": __Type,
"__Field": __Field,
"__InputValue": __InputValue,
"__EnumValue": __EnumValue,
"__TypeKind": __TypeKind,
}
"""A mapping containing all introspection types with their names as keys"""
def is_introspection_type(type_: GraphQLNamedType) -> bool:
"""Check whether the given named GraphQL type is an introspection type."""
return type_.name in introspection_types
@@ -0,0 +1,321 @@
from math import isfinite
from typing import Any, Mapping
from ..error import GraphQLError
from ..pyutils import inspect
from ..language.ast import (
BooleanValueNode,
FloatValueNode,
IntValueNode,
StringValueNode,
ValueNode,
)
from ..language.printer import print_ast
from .definition import GraphQLNamedType, GraphQLScalarType
__all__ = [
"is_specified_scalar_type",
"specified_scalar_types",
"GraphQLInt",
"GraphQLFloat",
"GraphQLString",
"GraphQLBoolean",
"GraphQLID",
"GRAPHQL_MAX_INT",
"GRAPHQL_MIN_INT",
]
# As per the GraphQL Spec, Integers are only treated as valid
# when they can be represented as a 32-bit signed integer,
# providing the broadest support across platforms.
# n.b. JavaScript's numbers are safe between -(2^53 - 1) and 2^53 - 1
# because they are internally represented as IEEE 754 doubles,
# while Python's integers may be arbitrarily large.
GRAPHQL_MAX_INT = 2_147_483_647
"""Maximum possible Int value as per GraphQL Spec (32-bit signed integer)"""
GRAPHQL_MIN_INT = -2_147_483_648
"""Minimum possible Int value as per GraphQL Spec (32-bit signed integer)"""
def serialize_int(output_value: Any) -> int:
if isinstance(output_value, bool):
return 1 if output_value else 0
try:
if isinstance(output_value, int):
num = output_value
elif isinstance(output_value, float):
num = int(output_value)
if num != output_value:
raise ValueError
elif not output_value and isinstance(output_value, str):
output_value = ""
raise ValueError
else:
num = int(output_value) # raises ValueError if not an integer
except (OverflowError, ValueError, TypeError):
raise GraphQLError(
"Int cannot represent non-integer value: " + inspect(output_value)
)
if not GRAPHQL_MIN_INT <= num <= GRAPHQL_MAX_INT:
raise GraphQLError(
"Int cannot represent non 32-bit signed integer value: "
+ inspect(output_value)
)
return num
def coerce_int(input_value: Any) -> int:
if not (
isinstance(input_value, int) and not isinstance(input_value, bool)
) and not (
isinstance(input_value, float)
and isfinite(input_value)
and int(input_value) == input_value
):
raise GraphQLError(
"Int cannot represent non-integer value: " + inspect(input_value)
)
if not GRAPHQL_MIN_INT <= input_value <= GRAPHQL_MAX_INT:
raise GraphQLError(
"Int cannot represent non 32-bit signed integer value: "
+ inspect(input_value)
)
return int(input_value)
def parse_int_literal(value_node: ValueNode, _variables: Any = None) -> int:
"""Parse an integer value node in the AST."""
if not isinstance(value_node, IntValueNode):
raise GraphQLError(
"Int cannot represent non-integer value: " + print_ast(value_node),
value_node,
)
num = int(value_node.value)
if not GRAPHQL_MIN_INT <= num <= GRAPHQL_MAX_INT:
raise GraphQLError(
"Int cannot represent non 32-bit signed integer value: "
+ print_ast(value_node),
value_node,
)
return num
GraphQLInt = GraphQLScalarType(
name="Int",
description="The `Int` scalar type represents"
" non-fractional signed whole numeric values."
" Int can represent values between -(2^31) and 2^31 - 1.",
serialize=serialize_int,
parse_value=coerce_int,
parse_literal=parse_int_literal,
)
def serialize_float(output_value: Any) -> float:
if isinstance(output_value, bool):
return 1 if output_value else 0
try:
if not output_value and isinstance(output_value, str):
output_value = ""
raise ValueError
num = output_value if isinstance(output_value, float) else float(output_value)
if not isfinite(num):
raise ValueError
except (ValueError, TypeError):
raise GraphQLError(
"Float cannot represent non numeric value: " + inspect(output_value)
)
return num
def coerce_float(input_value: Any) -> float:
if not (
isinstance(input_value, int) and not isinstance(input_value, bool)
) and not (isinstance(input_value, float) and isfinite(input_value)):
raise GraphQLError(
"Float cannot represent non numeric value: " + inspect(input_value)
)
return float(input_value)
def parse_float_literal(value_node: ValueNode, _variables: Any = None) -> float:
"""Parse a float value node in the AST."""
if not isinstance(value_node, (FloatValueNode, IntValueNode)):
raise GraphQLError(
"Float cannot represent non numeric value: " + print_ast(value_node),
value_node,
)
return float(value_node.value)
GraphQLFloat = GraphQLScalarType(
name="Float",
description="The `Float` scalar type represents"
" signed double-precision fractional values"
" as specified by [IEEE 754]"
"(https://en.wikipedia.org/wiki/IEEE_floating_point).",
serialize=serialize_float,
parse_value=coerce_float,
parse_literal=parse_float_literal,
)
def serialize_string(output_value: Any) -> str:
if isinstance(output_value, str):
return output_value
if isinstance(output_value, bool):
return "true" if output_value else "false"
if isinstance(output_value, int) or (
isinstance(output_value, float) and isfinite(output_value)
):
return str(output_value)
# do not serialize builtin types as strings, but allow serialization of custom
# types via their `__str__` method
if type(output_value).__module__ == "builtins":
raise GraphQLError("String cannot represent value: " + inspect(output_value))
return str(output_value)
def coerce_string(input_value: Any) -> str:
if not isinstance(input_value, str):
raise GraphQLError(
"String cannot represent a non string value: " + inspect(input_value)
)
return input_value
def parse_string_literal(value_node: ValueNode, _variables: Any = None) -> str:
"""Parse a string value node in the AST."""
if not isinstance(value_node, StringValueNode):
raise GraphQLError(
"String cannot represent a non string value: " + print_ast(value_node),
value_node,
)
return value_node.value
GraphQLString = GraphQLScalarType(
name="String",
description="The `String` scalar type represents textual data,"
" represented as UTF-8 character sequences."
" The String type is most often used by GraphQL"
" to represent free-form human-readable text.",
serialize=serialize_string,
parse_value=coerce_string,
parse_literal=parse_string_literal,
)
def serialize_boolean(output_value: Any) -> bool:
if isinstance(output_value, bool):
return output_value
if isinstance(output_value, int) or (
isinstance(output_value, float) and isfinite(output_value)
):
return bool(output_value)
raise GraphQLError(
"Boolean cannot represent a non boolean value: " + inspect(output_value)
)
def coerce_boolean(input_value: Any) -> bool:
if not isinstance(input_value, bool):
raise GraphQLError(
"Boolean cannot represent a non boolean value: " + inspect(input_value)
)
return input_value
def parse_boolean_literal(value_node: ValueNode, _variables: Any = None) -> bool:
"""Parse a boolean value node in the AST."""
if not isinstance(value_node, BooleanValueNode):
raise GraphQLError(
"Boolean cannot represent a non boolean value: " + print_ast(value_node),
value_node,
)
return value_node.value
GraphQLBoolean = GraphQLScalarType(
name="Boolean",
description="The `Boolean` scalar type represents `true` or `false`.",
serialize=serialize_boolean,
parse_value=coerce_boolean,
parse_literal=parse_boolean_literal,
)
def serialize_id(output_value: Any) -> str:
if isinstance(output_value, str):
return output_value
if isinstance(output_value, int) and not isinstance(output_value, bool):
return str(output_value)
if (
isinstance(output_value, float)
and isfinite(output_value)
and int(output_value) == output_value
):
return str(int(output_value))
# do not serialize builtin types as IDs, but allow serialization of custom types
# via their `__str__` method
if type(output_value).__module__ == "builtins":
raise GraphQLError("ID cannot represent value: " + inspect(output_value))
return str(output_value)
def coerce_id(input_value: Any) -> str:
if isinstance(input_value, str):
return input_value
if isinstance(input_value, int) and not isinstance(input_value, bool):
return str(input_value)
if (
isinstance(input_value, float)
and isfinite(input_value)
and int(input_value) == input_value
):
return str(int(input_value))
raise GraphQLError("ID cannot represent value: " + inspect(input_value))
def parse_id_literal(value_node: ValueNode, _variables: Any = None) -> str:
"""Parse an ID value node in the AST."""
if not isinstance(value_node, (StringValueNode, IntValueNode)):
raise GraphQLError(
"ID cannot represent a non-string and non-integer value: "
+ print_ast(value_node),
value_node,
)
return value_node.value
GraphQLID = GraphQLScalarType(
name="ID",
description="The `ID` scalar type represents a unique identifier,"
" often used to refetch an object or as key for a cache."
" The ID type appears in a JSON response as a String; however,"
" it is not intended to be human-readable. When expected as an"
' input type, any string (such as `"4"`) or integer (such as'
" `4`) input value will be accepted as an ID.",
serialize=serialize_id,
parse_value=coerce_id,
parse_literal=parse_id_literal,
)
specified_scalar_types: Mapping[str, GraphQLScalarType] = {
type_.name: type_
for type_ in (
GraphQLString,
GraphQLInt,
GraphQLFloat,
GraphQLBoolean,
GraphQLID,
)
}
def is_specified_scalar_type(type_: GraphQLNamedType) -> bool:
"""Check whether the given named GraphQL type is a specified scalar type."""
return type_.name in specified_scalar_types
@@ -0,0 +1,503 @@
from copy import copy, deepcopy
from typing import (
Any,
Collection,
Dict,
List,
NamedTuple,
Optional,
Set,
Tuple,
Union,
cast,
)
from ..error import GraphQLError
from ..language import OperationType, ast
from ..pyutils import inspect, is_collection, is_description
from .definition import (
GraphQLAbstractType,
GraphQLInputObjectType,
GraphQLInputType,
GraphQLInterfaceType,
GraphQLNamedType,
GraphQLObjectType,
GraphQLType,
GraphQLUnionType,
GraphQLWrappingType,
get_named_type,
is_input_object_type,
is_interface_type,
is_object_type,
is_union_type,
is_wrapping_type,
)
from .directives import GraphQLDirective, is_directive, specified_directives
from .introspection import introspection_types
try:
from typing import TypedDict
except ImportError: # Python < 3.8
from typing_extensions import TypedDict
__all__ = ["GraphQLSchema", "GraphQLSchemaKwargs", "is_schema", "assert_schema"]
TypeMap = Dict[str, GraphQLNamedType]
class InterfaceImplementations(NamedTuple):
objects: List[GraphQLObjectType]
interfaces: List[GraphQLInterfaceType]
class GraphQLSchemaKwargs(TypedDict, total=False):
query: Optional[GraphQLObjectType]
mutation: Optional[GraphQLObjectType]
subscription: Optional[GraphQLObjectType]
types: Optional[Tuple[GraphQLNamedType, ...]]
directives: Tuple[GraphQLDirective, ...]
description: Optional[str]
extensions: Dict[str, Any]
ast_node: Optional[ast.SchemaDefinitionNode]
extension_ast_nodes: Tuple[ast.SchemaExtensionNode, ...]
assume_valid: bool
class GraphQLSchema:
"""Schema Definition
A Schema is created by supplying the root types of each type of operation, query
and mutation (optional). A schema definition is then supplied to the validator
and executor.
Schemas should be considered immutable once they are created. If you want to modify
a schema, modify the result of the ``to_kwargs()`` method and recreate the schema.
Example::
MyAppSchema = GraphQLSchema(
query=MyAppQueryRootType,
mutation=MyAppMutationRootType)
Note: When the schema is constructed, by default only the types that are
reachable by traversing the root types are included, other types must be
explicitly referenced.
Example::
character_interface = GraphQLInterfaceType('Character', ...)
human_type = GraphQLObjectType(
'Human', interfaces=[character_interface], ...)
droid_type = GraphQLObjectType(
'Droid', interfaces: [character_interface], ...)
schema = GraphQLSchema(
query=GraphQLObjectType('Query',
fields={'hero': GraphQLField(character_interface, ....)}),
...
# Since this schema references only the `Character` interface it's
# necessary to explicitly list the types that implement it if
# you want them to be included in the final schema.
types=[human_type, droid_type])
Note: If a list of ``directives`` is provided to GraphQLSchema, that will be the
exact list of directives represented and allowed. If ``directives`` is not provided,
then a default set of the specified directives (e.g. @include and @skip) will be
used. If you wish to provide *additional* directives to these specified directives,
you must explicitly declare them. Example::
MyAppSchema = GraphQLSchema(
...
directives=specified_directives + [my_custom_directive])
"""
query_type: Optional[GraphQLObjectType]
mutation_type: Optional[GraphQLObjectType]
subscription_type: Optional[GraphQLObjectType]
type_map: TypeMap
directives: Tuple[GraphQLDirective, ...]
description: Optional[str]
extensions: Dict[str, Any]
ast_node: Optional[ast.SchemaDefinitionNode]
extension_ast_nodes: Tuple[ast.SchemaExtensionNode, ...]
_implementations_map: Dict[str, InterfaceImplementations]
_sub_type_map: Dict[str, Set[str]]
_validation_errors: Optional[List[GraphQLError]]
def __init__(
self,
query: Optional[GraphQLObjectType] = None,
mutation: Optional[GraphQLObjectType] = None,
subscription: Optional[GraphQLObjectType] = None,
types: Optional[Collection[GraphQLNamedType]] = None,
directives: Optional[Collection[GraphQLDirective]] = None,
description: Optional[str] = None,
extensions: Optional[Dict[str, Any]] = None,
ast_node: Optional[ast.SchemaDefinitionNode] = None,
extension_ast_nodes: Optional[Collection[ast.SchemaExtensionNode]] = None,
assume_valid: bool = False,
) -> None:
"""Initialize GraphQL schema.
If this schema was built from a source known to be valid, then it may be marked
with ``assume_valid`` to avoid an additional type system validation.
"""
self._validation_errors = [] if assume_valid else None
# Check for common mistakes during construction to produce clear and early
# error messages, but we leave the specific tests for the validation.
if query and not isinstance(query, GraphQLType):
raise TypeError("Expected query to be a GraphQL type.")
if mutation and not isinstance(mutation, GraphQLType):
raise TypeError("Expected mutation to be a GraphQL type.")
if subscription and not isinstance(subscription, GraphQLType):
raise TypeError("Expected subscription to be a GraphQL type.")
if types is None:
types = []
else:
if not is_collection(types) or not all(
isinstance(type_, GraphQLType) for type_ in types
):
raise TypeError(
"Schema types must be specified as a collection of GraphQL types."
)
if directives is not None:
# noinspection PyUnresolvedReferences
if not is_collection(directives):
raise TypeError("Schema directives must be a collection.")
if not isinstance(directives, tuple):
directives = tuple(directives)
if description is not None and not is_description(description):
raise TypeError("Schema description must be a string.")
if extensions is None:
extensions = {}
elif not isinstance(extensions, dict) or not all(
isinstance(key, str) for key in extensions
):
raise TypeError("Schema extensions must be a dictionary with string keys.")
if ast_node and not isinstance(ast_node, ast.SchemaDefinitionNode):
raise TypeError("Schema AST node must be a SchemaDefinitionNode.")
if extension_ast_nodes:
if not is_collection(extension_ast_nodes) or not all(
isinstance(node, ast.SchemaExtensionNode)
for node in extension_ast_nodes
):
raise TypeError(
"Schema extension AST nodes must be specified"
" as a collection of SchemaExtensionNode instances."
)
if not isinstance(extension_ast_nodes, tuple):
extension_ast_nodes = tuple(extension_ast_nodes)
else:
extension_ast_nodes = ()
self.description = description
self.extensions = extensions
self.ast_node = ast_node
self.extension_ast_nodes = extension_ast_nodes
self.query_type = query
self.mutation_type = mutation
self.subscription_type = subscription
# Provide specified directives (e.g. @include and @skip) by default
self.directives = specified_directives if directives is None else directives
# To preserve order of user-provided types, we add first to add them to
# the set of "collected" types, so `collect_referenced_types` ignore them.
if types:
all_referenced_types = TypeSet.with_initial_types(types)
collect_referenced_types = all_referenced_types.collect_referenced_types
for type_ in types:
# When we are ready to process this type, we remove it from "collected"
# types and then add it together with all dependent types in the correct
# position.
del all_referenced_types[type_]
collect_referenced_types(type_)
else:
all_referenced_types = TypeSet()
collect_referenced_types = all_referenced_types.collect_referenced_types
if query:
collect_referenced_types(query)
if mutation:
collect_referenced_types(mutation)
if subscription:
collect_referenced_types(subscription)
for directive in self.directives:
# Directives are not validated until validate_schema() is called.
if is_directive(directive):
for arg in directive.args.values():
collect_referenced_types(arg.type)
collect_referenced_types(introspection_types["__Schema"])
# Storing the resulting map for reference by the schema.
type_map: TypeMap = {}
self.type_map = type_map
self._sub_type_map = {}
# Keep track of all implementations by interface name.
implementations_map: Dict[str, InterfaceImplementations] = {}
self._implementations_map = implementations_map
for named_type in all_referenced_types:
if not named_type:
continue
type_name = getattr(named_type, "name", None)
if not type_name:
raise TypeError(
"One of the provided types for building the Schema"
" is missing a name.",
)
if type_name in type_map:
raise TypeError(
"Schema must contain uniquely named types"
f" but contains multiple types named '{type_name}'."
)
type_map[type_name] = named_type
if is_interface_type(named_type):
named_type = cast(GraphQLInterfaceType, named_type)
# Store implementations by interface.
for iface in named_type.interfaces:
if is_interface_type(iface):
iface = cast(GraphQLInterfaceType, iface)
if iface.name in implementations_map:
implementations = implementations_map[iface.name]
else:
implementations = implementations_map[iface.name] = (
InterfaceImplementations(objects=[], interfaces=[])
)
implementations.interfaces.append(named_type)
elif is_object_type(named_type):
named_type = cast(GraphQLObjectType, named_type)
# Store implementations by objects.
for iface in named_type.interfaces:
if is_interface_type(iface):
iface = cast(GraphQLInterfaceType, iface)
if iface.name in implementations_map:
implementations = implementations_map[iface.name]
else:
implementations = implementations_map[iface.name] = (
InterfaceImplementations(objects=[], interfaces=[])
)
implementations.objects.append(named_type)
def to_kwargs(self) -> GraphQLSchemaKwargs:
return GraphQLSchemaKwargs(
query=self.query_type,
mutation=self.mutation_type,
subscription=self.subscription_type,
types=tuple(self.type_map.values()) or None,
directives=self.directives,
description=self.description,
extensions=self.extensions,
ast_node=self.ast_node,
extension_ast_nodes=self.extension_ast_nodes,
assume_valid=self._validation_errors is not None,
)
def __copy__(self) -> "GraphQLSchema": # pragma: no cover
return self.__class__(**self.to_kwargs())
def __deepcopy__(self, memo_: Dict) -> "GraphQLSchema":
from ..type import (
is_introspection_type,
is_specified_directive,
is_specified_scalar_type,
)
type_map: TypeMap = {
name: copy(type_)
for name, type_ in self.type_map.items()
if not is_introspection_type(type_) and not is_specified_scalar_type(type_)
}
types = type_map.values()
for type_ in types:
remap_named_type(type_, type_map)
directives = [
directive if is_specified_directive(directive) else copy(directive)
for directive in self.directives
]
for directive in directives:
remap_directive(directive, type_map)
return self.__class__(
self.query_type and cast(GraphQLObjectType, type_map[self.query_type.name]),
self.mutation_type
and cast(GraphQLObjectType, type_map[self.mutation_type.name]),
self.subscription_type
and cast(GraphQLObjectType, type_map[self.subscription_type.name]),
types,
directives,
self.description,
extensions=deepcopy(self.extensions),
ast_node=deepcopy(self.ast_node),
extension_ast_nodes=deepcopy(self.extension_ast_nodes),
assume_valid=True,
)
def get_root_type(self, operation: OperationType) -> Optional[GraphQLObjectType]:
return getattr(self, f"{operation.value}_type")
def get_type(self, name: str) -> Optional[GraphQLNamedType]:
return self.type_map.get(name)
def get_possible_types(
self, abstract_type: GraphQLAbstractType
) -> List[GraphQLObjectType]:
"""Get list of all possible concrete types for given abstract type."""
return (
cast(GraphQLUnionType, abstract_type).types
if is_union_type(abstract_type)
else self.get_implementations(
cast(GraphQLInterfaceType, abstract_type)
).objects
)
def get_implementations(
self, interface_type: GraphQLInterfaceType
) -> InterfaceImplementations:
return self._implementations_map.get(
interface_type.name, InterfaceImplementations(objects=[], interfaces=[])
)
def is_sub_type(
self,
abstract_type: GraphQLAbstractType,
maybe_sub_type: GraphQLNamedType,
) -> bool:
"""Check whether a type is a subtype of a given abstract type."""
types = self._sub_type_map.get(abstract_type.name)
if types is None:
types = set()
add = types.add
if is_union_type(abstract_type):
for type_ in cast(GraphQLUnionType, abstract_type).types:
add(type_.name)
else:
implementations = self.get_implementations(
cast(GraphQLInterfaceType, abstract_type)
)
for type_ in implementations.objects:
add(type_.name)
for type_ in implementations.interfaces:
add(type_.name)
self._sub_type_map[abstract_type.name] = types
return maybe_sub_type.name in types
def get_directive(self, name: str) -> Optional[GraphQLDirective]:
for directive in self.directives:
if directive.name == name:
return directive
return None
@property
def validation_errors(self) -> Optional[List[GraphQLError]]:
return self._validation_errors
class TypeSet(Dict[GraphQLNamedType, None]):
"""An ordered set of types that can be collected starting from initial types."""
@classmethod
def with_initial_types(cls, types: Collection[GraphQLType]) -> "TypeSet":
return cast(TypeSet, super().fromkeys(types))
def collect_referenced_types(self, type_: GraphQLType) -> None:
"""Recursive function supplementing the type starting from an initial type."""
named_type = get_named_type(type_)
if named_type in self:
return
self[named_type] = None
collect_referenced_types = self.collect_referenced_types
if is_union_type(named_type):
named_type = cast(GraphQLUnionType, named_type)
for member_type in named_type.types:
collect_referenced_types(member_type)
elif is_object_type(named_type) or is_interface_type(named_type):
named_type = cast(
Union[GraphQLObjectType, GraphQLInterfaceType], named_type
)
for interface_type in named_type.interfaces:
collect_referenced_types(interface_type)
for field in named_type.fields.values():
collect_referenced_types(field.type)
for arg in field.args.values():
collect_referenced_types(arg.type)
elif is_input_object_type(named_type):
named_type = cast(GraphQLInputObjectType, named_type)
for field in named_type.fields.values():
collect_referenced_types(field.type)
def is_schema(schema: Any) -> bool:
"""Test if the given value is a GraphQL schema."""
return isinstance(schema, GraphQLSchema)
def assert_schema(schema: Any) -> GraphQLSchema:
if not is_schema(schema):
raise TypeError(f"Expected {inspect(schema)} to be a GraphQL schema.")
return cast(GraphQLSchema, schema)
def remapped_type(type_: GraphQLType, type_map: TypeMap) -> GraphQLType:
"""Get a copy of the given type that uses this type map."""
if is_wrapping_type(type_):
type_ = cast(GraphQLWrappingType, type_)
return type_.__class__(remapped_type(type_.of_type, type_map))
type_ = cast(GraphQLNamedType, type_)
return type_map.get(type_.name, type_)
def remap_named_type(type_: GraphQLNamedType, type_map: TypeMap) -> None:
"""Change all references in the given named type to use this type map."""
if is_object_type(type_) or is_interface_type(type_):
type_ = cast(Union[GraphQLObjectType, GraphQLInterfaceType], type_)
type_.interfaces = [
type_map.get(interface_type.name, interface_type)
for interface_type in type_.interfaces
]
fields = type_.fields
for field_name, field in fields.items():
field = copy(field)
field.type = remapped_type(field.type, type_map)
args = field.args
for arg_name, arg in args.items():
arg = copy(arg)
arg.type = remapped_type(arg.type, type_map)
args[arg_name] = arg
fields[field_name] = field
elif is_union_type(type_):
type_ = cast(GraphQLUnionType, type_)
type_.types = [
type_map.get(member_type.name, member_type) for member_type in type_.types
]
elif is_input_object_type(type_):
type_ = cast(GraphQLInputObjectType, type_)
fields = type_.fields
for field_name, field in fields.items():
field = copy(field)
field.type = remapped_type(field.type, type_map)
fields[field_name] = field
def remap_directive(directive: GraphQLDirective, type_map: TypeMap) -> None:
"""Change all references in the given directive to use this type map."""
args = directive.args
for arg_name, arg in args.items():
arg = copy(arg) # noqa: PLW2901
arg.type = cast(GraphQLInputType, remapped_type(arg.type, type_map))
args[arg_name] = arg
@@ -0,0 +1,611 @@
from operator import attrgetter, itemgetter
from typing import (
Any,
Collection,
Dict,
List,
Optional,
Set,
Tuple,
Union,
cast,
)
from ..error import GraphQLError
from ..pyutils import inspect
from ..language import (
DirectiveNode,
InputValueDefinitionNode,
NamedTypeNode,
Node,
OperationType,
SchemaDefinitionNode,
SchemaExtensionNode,
)
from .definition import (
GraphQLEnumType,
GraphQLInputField,
GraphQLInputObjectType,
GraphQLInterfaceType,
GraphQLObjectType,
GraphQLUnionType,
is_enum_type,
is_input_object_type,
is_input_type,
is_interface_type,
is_named_type,
is_non_null_type,
is_object_type,
is_output_type,
is_union_type,
is_required_argument,
is_required_input_field,
)
from ..utilities.type_comparators import is_equal_type, is_type_sub_type_of
from .directives import is_directive, GraphQLDeprecatedDirective
from .introspection import is_introspection_type
from .schema import GraphQLSchema, assert_schema
__all__ = ["validate_schema", "assert_valid_schema"]
def validate_schema(schema: GraphQLSchema) -> List[GraphQLError]:
"""Validate a GraphQL schema.
Implements the "Type Validation" sub-sections of the specification's "Type System"
section.
Validation runs synchronously, returning a list of encountered errors, or an empty
list if no errors were encountered and the Schema is valid.
"""
# First check to ensure the provided value is in fact a GraphQLSchema.
assert_schema(schema)
# If this Schema has already been validated, return the previous results.
# noinspection PyProtectedMember
errors = schema._validation_errors
if errors is None:
# Validate the schema, producing a list of errors.
context = SchemaValidationContext(schema)
context.validate_root_types()
context.validate_directives()
context.validate_types()
# Persist the results of validation before returning to ensure validation does
# not run multiple times for this schema.
errors = context.errors
schema._validation_errors = errors
return errors
def assert_valid_schema(schema: GraphQLSchema) -> None:
"""Utility function which asserts a schema is valid.
Throws a TypeError if the schema is invalid.
"""
errors = validate_schema(schema)
if errors:
raise TypeError("\n\n".join(error.message for error in errors))
class SchemaValidationContext:
"""Utility class providing a context for schema validation."""
errors: List[GraphQLError]
schema: GraphQLSchema
def __init__(self, schema: GraphQLSchema):
self.errors = []
self.schema = schema
def report_error(
self,
message: str,
nodes: Union[Optional[Node], Collection[Optional[Node]]] = None,
) -> None:
if nodes and not isinstance(nodes, Node):
nodes = [node for node in nodes if node]
nodes = cast(Optional[Collection[Node]], nodes)
self.errors.append(GraphQLError(message, nodes))
def validate_root_types(self) -> None:
schema = self.schema
query_type = schema.query_type
if not query_type:
self.report_error("Query root type must be provided.", schema.ast_node)
elif not is_object_type(query_type):
self.report_error(
f"Query root type must be Object type, it cannot be {query_type}.",
get_operation_type_node(schema, OperationType.QUERY)
or query_type.ast_node,
)
mutation_type = schema.mutation_type
if mutation_type and not is_object_type(mutation_type):
self.report_error(
"Mutation root type must be Object type if provided,"
f" it cannot be {mutation_type}.",
get_operation_type_node(schema, OperationType.MUTATION)
or mutation_type.ast_node,
)
subscription_type = schema.subscription_type
if subscription_type and not is_object_type(subscription_type):
self.report_error(
"Subscription root type must be Object type if provided,"
f" it cannot be {subscription_type}.",
get_operation_type_node(schema, OperationType.SUBSCRIPTION)
or subscription_type.ast_node,
)
def validate_directives(self) -> None:
directives = self.schema.directives
for directive in directives:
# Ensure all directives are in fact GraphQL directives.
if not is_directive(directive):
self.report_error(
f"Expected directive but got: {inspect(directive)}.",
getattr(directive, "ast_node", None),
)
continue
# Ensure they are named correctly.
self.validate_name(directive)
# Ensure the arguments are valid.
for arg_name, arg in directive.args.items():
# Ensure they are named correctly.
self.validate_name(arg, arg_name)
# Ensure the type is an input type.
if not is_input_type(arg.type):
self.report_error(
f"The type of @{directive.name}({arg_name}:)"
f" must be Input Type but got: {inspect(arg.type)}.",
arg.ast_node,
)
if is_required_argument(arg) and arg.deprecation_reason is not None:
self.report_error(
f"Required argument @{directive.name}({arg_name}:)"
" cannot be deprecated.",
[
get_deprecated_directive_node(arg.ast_node),
arg.ast_node and arg.ast_node.type,
],
)
def validate_name(self, node: Any, name: Optional[str] = None) -> None:
# Ensure names are valid, however introspection types opt out.
try:
if not name:
name = node.name
name = cast(str, name)
ast_node = node.ast_node
except AttributeError: # pragma: no cover
pass
else:
if name.startswith("__"):
self.report_error(
f"Name {name!r} must not begin with '__',"
" which is reserved by GraphQL introspection.",
ast_node,
)
def validate_types(self) -> None:
validate_input_object_circular_refs = InputObjectCircularRefsValidator(self)
for type_ in self.schema.type_map.values():
# Ensure all provided types are in fact GraphQL type.
if not is_named_type(type_):
self.report_error(
f"Expected GraphQL named type but got: {inspect(type_)}.",
type_.ast_node if is_named_type(type_) else None,
)
continue
# Ensure it is named correctly (excluding introspection types).
if not is_introspection_type(type_):
self.validate_name(type_)
if is_object_type(type_):
type_ = cast(GraphQLObjectType, type_)
# Ensure fields are valid
self.validate_fields(type_)
# Ensure objects implement the interfaces they claim to.
self.validate_interfaces(type_)
elif is_interface_type(type_):
type_ = cast(GraphQLInterfaceType, type_)
# Ensure fields are valid.
self.validate_fields(type_)
# Ensure interfaces implement the interfaces they claim to.
self.validate_interfaces(type_)
elif is_union_type(type_):
type_ = cast(GraphQLUnionType, type_)
# Ensure Unions include valid member types.
self.validate_union_members(type_)
elif is_enum_type(type_):
type_ = cast(GraphQLEnumType, type_)
# Ensure Enums have valid values.
self.validate_enum_values(type_)
elif is_input_object_type(type_):
type_ = cast(GraphQLInputObjectType, type_)
# Ensure Input Object fields are valid.
self.validate_input_fields(type_)
# Ensure Input Objects do not contain non-nullable circular references
validate_input_object_circular_refs(type_)
def validate_fields(
self, type_: Union[GraphQLObjectType, GraphQLInterfaceType]
) -> None:
fields = type_.fields
# Objects and Interfaces both must define one or more fields.
if not fields:
self.report_error(
f"Type {type_.name} must define one or more fields.",
[type_.ast_node, *type_.extension_ast_nodes],
)
for field_name, field in fields.items():
# Ensure they are named correctly.
self.validate_name(field, field_name)
# Ensure the type is an output type
if not is_output_type(field.type):
self.report_error(
f"The type of {type_.name}.{field_name}"
f" must be Output Type but got: {inspect(field.type)}.",
field.ast_node and field.ast_node.type,
)
# Ensure the arguments are valid.
for arg_name, arg in field.args.items():
# Ensure they are named correctly.
self.validate_name(arg, arg_name)
# Ensure the type is an input type.
if not is_input_type(arg.type):
self.report_error(
f"The type of {type_.name}.{field_name}({arg_name}:)"
f" must be Input Type but got: {inspect(arg.type)}.",
arg.ast_node and arg.ast_node.type,
)
if is_required_argument(arg) and arg.deprecation_reason is not None:
self.report_error(
f"Required argument {type_.name}.{field_name}({arg_name}:)"
" cannot be deprecated.",
[
get_deprecated_directive_node(arg.ast_node),
arg.ast_node and arg.ast_node.type,
],
)
def validate_interfaces(
self, type_: Union[GraphQLObjectType, GraphQLInterfaceType]
) -> None:
iface_type_names: Set[str] = set()
for iface in type_.interfaces:
if not is_interface_type(iface):
self.report_error(
f"Type {type_.name} must only implement Interface"
f" types, it cannot implement {inspect(iface)}.",
get_all_implements_interface_nodes(type_, iface),
)
continue
if type_ is iface:
self.report_error(
f"Type {type_.name} cannot implement itself"
" because it would create a circular reference.",
get_all_implements_interface_nodes(type_, iface),
)
if iface.name in iface_type_names:
self.report_error(
f"Type {type_.name} can only implement {iface.name} once.",
get_all_implements_interface_nodes(type_, iface),
)
continue
iface_type_names.add(iface.name)
self.validate_type_implements_ancestors(type_, iface)
self.validate_type_implements_interface(type_, iface)
def validate_type_implements_interface(
self,
type_: Union[GraphQLObjectType, GraphQLInterfaceType],
iface: GraphQLInterfaceType,
) -> None:
type_fields, iface_fields = type_.fields, iface.fields
# Assert each interface field is implemented.
for field_name, iface_field in iface_fields.items():
type_field = type_fields.get(field_name)
# Assert interface field exists on object.
if not type_field:
self.report_error(
f"Interface field {iface.name}.{field_name}"
f" expected but {type_.name} does not provide it.",
[
iface_field.ast_node,
type_.ast_node,
*type_.extension_ast_nodes,
],
)
continue
# Assert interface field type is satisfied by type field type, by being
# a valid subtype (covariant).
if not is_type_sub_type_of(self.schema, type_field.type, iface_field.type):
self.report_error(
f"Interface field {iface.name}.{field_name}"
f" expects type {iface_field.type}"
f" but {type_.name}.{field_name}"
f" is type {type_field.type}.",
[
iface_field.ast_node and iface_field.ast_node.type,
type_field.ast_node and type_field.ast_node.type,
],
)
# Assert each interface field arg is implemented.
for arg_name, iface_arg in iface_field.args.items():
type_arg = type_field.args.get(arg_name)
# Assert interface field arg exists on object field.
if not type_arg:
self.report_error(
"Interface field argument"
f" {iface.name}.{field_name}({arg_name}:)"
f" expected but {type_.name}.{field_name}"
" does not provide it.",
[iface_arg.ast_node, type_field.ast_node],
)
continue
# Assert interface field arg type matches object field arg type
# (invariant).
if not is_equal_type(iface_arg.type, type_arg.type):
self.report_error(
"Interface field argument"
f" {iface.name}.{field_name}({arg_name}:)"
f" expects type {iface_arg.type}"
f" but {type_.name}.{field_name}({arg_name}:)"
f" is type {type_arg.type}.",
[
iface_arg.ast_node and iface_arg.ast_node.type,
type_arg.ast_node and type_arg.ast_node.type,
],
)
# Assert additional arguments must not be required.
for arg_name, type_arg in type_field.args.items():
iface_arg = iface_field.args.get(arg_name)
if not iface_arg and is_required_argument(type_arg):
self.report_error(
f"Object field {type_.name}.{field_name} includes"
f" required argument {arg_name} that is missing from"
f" the Interface field {iface.name}.{field_name}.",
[type_arg.ast_node, iface_field.ast_node],
)
def validate_type_implements_ancestors(
self,
type_: Union[GraphQLObjectType, GraphQLInterfaceType],
iface: GraphQLInterfaceType,
) -> None:
type_interfaces, iface_interfaces = type_.interfaces, iface.interfaces
for transitive in iface_interfaces:
if transitive not in type_interfaces:
self.report_error(
(
f"Type {type_.name} cannot implement {iface.name}"
" because it would create a circular reference."
if transitive is type_
else f"Type {type_.name} must implement {transitive.name}"
f" because it is implemented by {iface.name}."
),
get_all_implements_interface_nodes(iface, transitive)
+ get_all_implements_interface_nodes(type_, iface),
)
def validate_union_members(self, union: GraphQLUnionType) -> None:
member_types = union.types
if not member_types:
self.report_error(
f"Union type {union.name} must define one or more member types.",
[union.ast_node, *union.extension_ast_nodes],
)
included_type_names: Set[str] = set()
for member_type in member_types:
if is_object_type(member_type):
if member_type.name in included_type_names:
self.report_error(
f"Union type {union.name} can only include type"
f" {member_type.name} once.",
get_union_member_type_nodes(union, member_type.name),
)
else:
included_type_names.add(member_type.name)
else:
self.report_error(
f"Union type {union.name} can only include Object types,"
f" it cannot include {inspect(member_type)}.",
get_union_member_type_nodes(union, str(member_type)),
)
def validate_enum_values(self, enum_type: GraphQLEnumType) -> None:
enum_values = enum_type.values
if not enum_values:
self.report_error(
f"Enum type {enum_type.name} must define one or more values.",
[enum_type.ast_node, *enum_type.extension_ast_nodes],
)
for value_name, enum_value in enum_values.items():
# Ensure valid name.
self.validate_name(enum_value, value_name)
def validate_input_fields(self, input_obj: GraphQLInputObjectType) -> None:
fields = input_obj.fields
if not fields:
self.report_error(
f"Input Object type {input_obj.name}"
" must define one or more fields.",
[input_obj.ast_node, *input_obj.extension_ast_nodes],
)
# Ensure the arguments are valid
for field_name, field in fields.items():
# Ensure they are named correctly.
self.validate_name(field, field_name)
# Ensure the type is an input type.
if not is_input_type(field.type):
self.report_error(
f"The type of {input_obj.name}.{field_name}"
f" must be Input Type but got: {inspect(field.type)}.",
field.ast_node.type if field.ast_node else None,
)
if is_required_input_field(field) and field.deprecation_reason is not None:
self.report_error(
f"Required input field {input_obj.name}.{field_name}"
" cannot be deprecated.",
[
get_deprecated_directive_node(field.ast_node),
field.ast_node and field.ast_node.type,
],
)
def get_operation_type_node(
schema: GraphQLSchema, operation: OperationType
) -> Optional[Node]:
ast_node: Optional[Union[SchemaDefinitionNode, SchemaExtensionNode]]
for ast_node in [schema.ast_node, *(schema.extension_ast_nodes or ())]:
if ast_node:
operation_types = ast_node.operation_types
if operation_types: # pragma: no cover else
for operation_type in operation_types:
if operation_type.operation == operation:
return operation_type.type
return None
class InputObjectCircularRefsValidator:
"""Modified copy of algorithm from validation.rules.NoFragmentCycles"""
def __init__(self, context: SchemaValidationContext):
self.context = context
# Tracks already visited types to maintain O(N) and to ensure that cycles
# are not redundantly reported.
self.visited_types: Set[str] = set()
# Array of input fields used to produce meaningful errors
self.field_path: List[Tuple[str, GraphQLInputField]] = []
# Position in the type path
self.field_path_index_by_type_name: Dict[str, int] = {}
def __call__(self, input_obj: GraphQLInputObjectType) -> None:
"""Detect cycles recursively."""
# This does a straight-forward DFS to find cycles.
# It does not terminate when a cycle was found but continues to explore
# the graph to find all possible cycles.
name = input_obj.name
if name in self.visited_types:
return
self.visited_types.add(name)
self.field_path_index_by_type_name[name] = len(self.field_path)
for field_name, field in input_obj.fields.items():
if is_non_null_type(field.type) and is_input_object_type(
field.type.of_type
):
field_type = cast(GraphQLInputObjectType, field.type.of_type)
cycle_index = self.field_path_index_by_type_name.get(field_type.name)
self.field_path.append((field_name, field))
if cycle_index is None:
self(field_type)
else:
cycle_path = self.field_path[cycle_index:]
field_names = map(itemgetter(0), cycle_path)
self.context.report_error(
f"Cannot reference Input Object '{field_type.name}'"
" within itself through a series of non-null fields:"
f" '{'.'.join(field_names)}'.",
cast(
Collection[Node],
map(attrgetter("ast_node"), map(itemgetter(1), cycle_path)),
),
)
self.field_path.pop()
del self.field_path_index_by_type_name[name]
def get_all_implements_interface_nodes(
type_: Union[GraphQLObjectType, GraphQLInterfaceType], iface: GraphQLInterfaceType
) -> List[NamedTypeNode]:
ast_node = type_.ast_node
nodes = type_.extension_ast_nodes
if ast_node is not None:
nodes = [ast_node, *nodes] # type: ignore
implements_nodes: List[NamedTypeNode] = []
for node in nodes:
iface_nodes = node.interfaces
if iface_nodes: # pragma: no cover else
implements_nodes.extend(
iface_node
for iface_node in iface_nodes
if iface_node.name.value == iface.name
)
return implements_nodes
def get_union_member_type_nodes(
union: GraphQLUnionType, type_name: str
) -> List[NamedTypeNode]:
ast_node = union.ast_node
nodes = union.extension_ast_nodes
if ast_node is not None:
nodes = [ast_node, *nodes] # type: ignore
member_type_nodes: List[NamedTypeNode] = []
for node in nodes:
type_nodes = node.types
if type_nodes: # pragma: no cover else
member_type_nodes.extend(
type_node
for type_node in type_nodes
if type_node.name.value == type_name
)
return member_type_nodes
def get_deprecated_directive_node(
definition_node: Optional[Union[InputValueDefinitionNode]],
) -> Optional[DirectiveNode]:
directives = definition_node and definition_node.directives
if directives:
for directive in directives:
if (
directive.name.value == GraphQLDeprecatedDirective.name
): # pragma: no cover else
return directive
return None # pragma: no cover
@@ -0,0 +1,124 @@
"""GraphQL Utilities
The :mod:`graphql.utilities` package contains common useful computations to use with
the GraphQL language and type objects.
"""
# Produce the GraphQL query recommended for a full schema introspection.
from .get_introspection_query import get_introspection_query, IntrospectionQuery
# Get the target Operation from a Document.
from .get_operation_ast import get_operation_ast
# Get the Type for the target Operation AST.
from .get_operation_root_type import get_operation_root_type
# Convert a GraphQLSchema to an IntrospectionQuery.
from .introspection_from_schema import introspection_from_schema
# Build a GraphQLSchema from an introspection result.
from .build_client_schema import build_client_schema
# Build a GraphQLSchema from GraphQL Schema language.
from .build_ast_schema import build_ast_schema, build_schema
# Extend an existing GraphQLSchema from a parsed GraphQL Schema language AST.
from .extend_schema import extend_schema
# Sort a GraphQLSchema.
from .lexicographic_sort_schema import lexicographic_sort_schema
# Print a GraphQLSchema to GraphQL Schema language.
from .print_schema import (
print_introspection_schema,
print_schema,
print_type,
print_value, # deprecated
)
# Create a GraphQLType from a GraphQL language AST.
from .type_from_ast import type_from_ast
# Convert a language AST to a dictionary.
from .ast_to_dict import ast_to_dict
# Create a Python value from a GraphQL language AST with a type.
from .value_from_ast import value_from_ast
# Create a Python value from a GraphQL language AST without a type.
from .value_from_ast_untyped import value_from_ast_untyped
# Create a GraphQL language AST from a Python value.
from .ast_from_value import ast_from_value
# A helper to use within recursive-descent visitors which need to be aware of
# the GraphQL type system
from .type_info import TypeInfo, TypeInfoVisitor
# Coerce a Python value to a GraphQL type, or produce errors.
from .coerce_input_value import coerce_input_value
# Concatenate multiple ASTs together.
from .concat_ast import concat_ast
# Separate an AST into an AST per Operation.
from .separate_operations import separate_operations
# Strip characters that are not significant to the validity or execution
# of a GraphQL document.
from .strip_ignored_characters import strip_ignored_characters
# Comparators for types
from .type_comparators import is_equal_type, is_type_sub_type_of, do_types_overlap
# Assert that a string is a valid GraphQL name.
from .assert_valid_name import assert_valid_name, is_valid_name_error
# Compare two GraphQLSchemas and detect breaking changes.
from .find_breaking_changes import (
BreakingChange,
BreakingChangeType,
DangerousChange,
DangerousChangeType,
find_breaking_changes,
find_dangerous_changes,
)
__all__ = [
"BreakingChange",
"BreakingChangeType",
"DangerousChange",
"DangerousChangeType",
"IntrospectionQuery",
"TypeInfo",
"TypeInfoVisitor",
"assert_valid_name",
"ast_from_value",
"ast_to_dict",
"build_ast_schema",
"build_client_schema",
"build_schema",
"coerce_input_value",
"concat_ast",
"do_types_overlap",
"extend_schema",
"find_breaking_changes",
"find_dangerous_changes",
"get_introspection_query",
"get_operation_ast",
"get_operation_root_type",
"is_equal_type",
"is_type_sub_type_of",
"is_valid_name_error",
"introspection_from_schema",
"lexicographic_sort_schema",
"print_introspection_schema",
"print_schema",
"print_type",
"print_value",
"separate_operations",
"strip_ignored_characters",
"type_from_ast",
"value_from_ast",
"value_from_ast_untyped",
]
@@ -0,0 +1,38 @@
from typing import Optional
from ..type.assert_name import assert_name
from ..error import GraphQLError
__all__ = ["assert_valid_name", "is_valid_name_error"]
def assert_valid_name(name: str) -> str:
"""Uphold the spec rules about naming.
.. deprecated:: 3.2
Please use ``assert_name`` instead. Will be removed in v3.3.
"""
error = is_valid_name_error(name)
if error:
raise error
return name
def is_valid_name_error(name: str) -> Optional[GraphQLError]:
"""Return an Error if a name is invalid.
.. deprecated:: 3.2
Please use ``assert_name`` instead. Will be removed in v3.3.
"""
if not isinstance(name, str):
raise TypeError("Expected name to be a string.")
if name.startswith("__"):
return GraphQLError(
f"Name {name!r} must not begin with '__',"
" which is reserved by GraphQL introspection."
)
try:
assert_name(name)
except GraphQLError as error:
return error
return None
@@ -0,0 +1,139 @@
import re
from math import isfinite
from typing import Any, Mapping, Optional, cast
from ..language import (
BooleanValueNode,
EnumValueNode,
FloatValueNode,
IntValueNode,
ListValueNode,
NameNode,
NullValueNode,
ObjectFieldNode,
ObjectValueNode,
StringValueNode,
ValueNode,
)
from ..pyutils import inspect, is_iterable, Undefined
from ..type import (
GraphQLID,
GraphQLInputType,
GraphQLInputObjectType,
GraphQLList,
GraphQLNonNull,
is_enum_type,
is_input_object_type,
is_leaf_type,
is_list_type,
is_non_null_type,
)
__all__ = ["ast_from_value"]
_re_integer_string = re.compile("^-?(?:0|[1-9][0-9]*)$")
def ast_from_value(value: Any, type_: GraphQLInputType) -> Optional[ValueNode]:
"""Produce a GraphQL Value AST given a Python object.
This function will match Python/JSON values to GraphQL AST schema format by using
the suggested GraphQLInputType. For example::
ast_from_value('value', GraphQLString)
A GraphQL type must be provided, which will be used to interpret different Python
values.
================ =======================
JSON Value GraphQL Value
================ =======================
Object Input Object
Array List
Boolean Boolean
String String / Enum Value
Number Int / Float
Mixed Enum Value
null NullValue
================ =======================
"""
if is_non_null_type(type_):
type_ = cast(GraphQLNonNull, type_)
ast_value = ast_from_value(value, type_.of_type)
if isinstance(ast_value, NullValueNode):
return None
return ast_value
# only explicit None, not Undefined or NaN
if value is None:
return NullValueNode()
# undefined
if value is Undefined:
return None
# Convert Python list to GraphQL list. If the GraphQLType is a list, but the value
# is not a list, convert the value using the list's item type.
if is_list_type(type_):
type_ = cast(GraphQLList, type_)
item_type = type_.of_type
if is_iterable(value):
maybe_value_nodes = (ast_from_value(item, item_type) for item in value)
value_nodes = tuple(node for node in maybe_value_nodes if node)
return ListValueNode(values=value_nodes)
return ast_from_value(value, item_type)
# Populate the fields of the input object by creating ASTs from each value in the
# Python dict according to the fields in the input type.
if is_input_object_type(type_):
if value is None or not isinstance(value, Mapping):
return None
type_ = cast(GraphQLInputObjectType, type_)
field_items = (
(field_name, ast_from_value(value[field_name], field.type))
for field_name, field in type_.fields.items()
if field_name in value
)
field_nodes = tuple(
ObjectFieldNode(name=NameNode(value=field_name), value=field_value)
for field_name, field_value in field_items
if field_value
)
return ObjectValueNode(fields=field_nodes)
if is_leaf_type(type_):
# Since value is an internally represented value, it must be serialized to an
# externally represented value before converting into an AST.
serialized = type_.serialize(value) # type: ignore
if serialized is None or serialized is Undefined:
return None
# Others serialize based on their corresponding Python scalar types.
if isinstance(serialized, bool):
return BooleanValueNode(value=serialized)
# Python ints and floats correspond nicely to Int and Float values.
if isinstance(serialized, int):
return IntValueNode(value=str(serialized))
if isinstance(serialized, float) and isfinite(serialized):
value = str(serialized)
if value.endswith(".0"):
value = value[:-2]
return FloatValueNode(value=value)
if isinstance(serialized, str):
# Enum types use Enum literals.
if is_enum_type(type_):
return EnumValueNode(value=serialized)
# ID types can use Int literals.
if type_ is GraphQLID and _re_integer_string.match(serialized):
return IntValueNode(value=serialized)
return StringValueNode(value=serialized)
raise TypeError(f"Cannot convert value to AST: {inspect(serialized)}.")
# Not reachable. All possible input types have been considered.
raise TypeError(f"Unexpected input type: {inspect(type_)}.")
@@ -0,0 +1,59 @@
from typing import Any, Collection, Dict, List, Optional, overload
from ..language import Node, OperationType
from ..pyutils import is_iterable
__all__ = ["ast_to_dict"]
@overload
def ast_to_dict(
node: Node, locations: bool = False, cache: Optional[Dict[Node, Any]] = None
) -> Dict: ...
@overload
def ast_to_dict(
node: Collection[Node],
locations: bool = False,
cache: Optional[Dict[Node, Any]] = None,
) -> List[Node]: ...
@overload
def ast_to_dict(
node: OperationType,
locations: bool = False,
cache: Optional[Dict[Node, Any]] = None,
) -> str: ...
def ast_to_dict(
node: Any, locations: bool = False, cache: Optional[Dict[Node, Any]] = None
) -> Any:
"""Convert a language AST to a nested Python dictionary.
Set `locations` to True in order to get the locations as well.
"""
if isinstance(node, Node):
if cache is None:
cache = {}
elif node in cache:
return cache[node]
cache[node] = res = {}
res.update(
{
key: ast_to_dict(getattr(node, key), locations, cache)
for key in ("kind",) + node.keys[1:]
}
)
if locations:
loc = node.loc
if loc:
res["loc"] = dict(start=loc.start, end=loc.end)
return res
if is_iterable(node):
return [ast_to_dict(sub_node, locations, cache) for sub_node in node]
if isinstance(node, OperationType):
return node.value
return node
@@ -0,0 +1,103 @@
from typing import cast, Union
from ..language import DocumentNode, Source, parse
from ..type import (
GraphQLObjectType,
GraphQLSchema,
GraphQLSchemaKwargs,
specified_directives,
)
from .extend_schema import extend_schema_impl
__all__ = [
"build_ast_schema",
"build_schema",
]
def build_ast_schema(
document_ast: DocumentNode,
assume_valid: bool = False,
assume_valid_sdl: bool = False,
) -> GraphQLSchema:
"""Build a GraphQL Schema from a given AST.
This takes the ast of a schema document produced by the parse function in
src/language/parser.py.
If no schema definition is provided, then it will look for types named Query,
Mutation and Subscription.
Given that AST it constructs a GraphQLSchema. The resulting schema has no
resolve methods, so execution will use default resolvers.
When building a schema from a GraphQL service's introspection result, it might
be safe to assume the schema is valid. Set ``assume_valid`` to ``True`` to assume
the produced schema is valid. Set ``assume_valid_sdl`` to ``True`` to assume it is
already a valid SDL document.
"""
if not isinstance(document_ast, DocumentNode):
raise TypeError("Must provide valid Document AST.")
if not (assume_valid or assume_valid_sdl):
from ..validation.validate import assert_valid_sdl
assert_valid_sdl(document_ast)
empty_schema_kwargs = GraphQLSchemaKwargs(
query=None,
mutation=None,
subscription=None,
description=None,
types=(),
directives=(),
extensions={},
ast_node=None,
extension_ast_nodes=(),
assume_valid=False,
)
schema_kwargs = extend_schema_impl(empty_schema_kwargs, document_ast, assume_valid)
if not schema_kwargs["ast_node"]:
for type_ in schema_kwargs["types"] or ():
# Note: While this could make early assertions to get the correctly
# typed values below, that would throw immediately while type system
# validation with validate_schema() will produce more actionable results.
type_name = type_.name
if type_name == "Query":
schema_kwargs["query"] = cast(GraphQLObjectType, type_)
elif type_name == "Mutation":
schema_kwargs["mutation"] = cast(GraphQLObjectType, type_)
elif type_name == "Subscription":
schema_kwargs["subscription"] = cast(GraphQLObjectType, type_)
# If specified directives were not explicitly declared, add them.
directives = schema_kwargs["directives"]
directive_names = set(directive.name for directive in directives)
missing_directives = []
for directive in specified_directives:
if directive.name not in directive_names:
missing_directives.append(directive)
if missing_directives:
schema_kwargs["directives"] = directives + tuple(missing_directives)
return GraphQLSchema(**schema_kwargs)
def build_schema(
source: Union[str, Source],
assume_valid: bool = False,
assume_valid_sdl: bool = False,
no_location: bool = False,
allow_legacy_fragment_variables: bool = False,
) -> GraphQLSchema:
"""Build a GraphQLSchema directly from a source document."""
return build_ast_schema(
parse(
source,
no_location=no_location,
allow_legacy_fragment_variables=allow_legacy_fragment_variables,
),
assume_valid=assume_valid,
assume_valid_sdl=assume_valid_sdl,
)
@@ -0,0 +1,418 @@
from itertools import chain
from typing import cast, Callable, Collection, Dict, List, Union
from ..language import DirectiveLocation, parse_value
from ..pyutils import inspect, Undefined
from ..type import (
GraphQLArgument,
GraphQLDirective,
GraphQLEnumType,
GraphQLEnumValue,
GraphQLField,
GraphQLInputField,
GraphQLInputObjectType,
GraphQLInputType,
GraphQLInterfaceType,
GraphQLList,
GraphQLNamedType,
GraphQLNonNull,
GraphQLObjectType,
GraphQLOutputType,
GraphQLScalarType,
GraphQLSchema,
GraphQLType,
GraphQLUnionType,
TypeKind,
assert_interface_type,
assert_nullable_type,
assert_object_type,
introspection_types,
is_input_type,
is_output_type,
specified_scalar_types,
)
from .get_introspection_query import (
IntrospectionDirective,
IntrospectionEnumType,
IntrospectionField,
IntrospectionInterfaceType,
IntrospectionInputObjectType,
IntrospectionInputValue,
IntrospectionObjectType,
IntrospectionQuery,
IntrospectionScalarType,
IntrospectionType,
IntrospectionTypeRef,
IntrospectionUnionType,
)
from .value_from_ast import value_from_ast
__all__ = ["build_client_schema"]
def build_client_schema(
introspection: IntrospectionQuery, assume_valid: bool = False
) -> GraphQLSchema:
"""Build a GraphQLSchema for use by client tools.
Given the result of a client running the introspection query, creates and returns
a GraphQLSchema instance which can be then used with all GraphQL-core 3 tools,
but cannot be used to execute a query, as introspection does not represent the
"resolver", "parse" or "serialize" functions or any other server-internal
mechanisms.
This function expects a complete introspection result. Don't forget to check the
"errors" field of a server response before calling this function.
"""
if not isinstance(introspection, dict) or not isinstance(
introspection.get("__schema"), dict
):
raise TypeError(
"Invalid or incomplete introspection result. Ensure that you"
" are passing the 'data' attribute of an introspection response"
f" and no 'errors' were returned alongside: {inspect(introspection)}."
)
# Get the schema from the introspection result.
schema_introspection = introspection["__schema"]
# Given a type reference in introspection, return the GraphQLType instance,
# preferring cached instances before building new instances.
def get_type(type_ref: IntrospectionTypeRef) -> GraphQLType:
kind = type_ref.get("kind")
if kind == TypeKind.LIST.name:
item_ref = type_ref.get("ofType")
if not item_ref:
raise TypeError("Decorated type deeper than introspection query.")
item_ref = cast(IntrospectionTypeRef, item_ref)
return GraphQLList(get_type(item_ref))
if kind == TypeKind.NON_NULL.name:
nullable_ref = type_ref.get("ofType")
if not nullable_ref:
raise TypeError("Decorated type deeper than introspection query.")
nullable_ref = cast(IntrospectionTypeRef, nullable_ref)
nullable_type = get_type(nullable_ref)
return GraphQLNonNull(assert_nullable_type(nullable_type))
type_ref = cast(IntrospectionType, type_ref)
return get_named_type(type_ref)
def get_named_type(type_ref: IntrospectionType) -> GraphQLNamedType:
type_name = type_ref.get("name")
if not type_name:
raise TypeError(f"Unknown type reference: {inspect(type_ref)}.")
type_ = type_map.get(type_name)
if not type_:
raise TypeError(
f"Invalid or incomplete schema, unknown type: {type_name}."
" Ensure that a full introspection query is used in order"
" to build a client schema."
)
return type_
def get_object_type(type_ref: IntrospectionObjectType) -> GraphQLObjectType:
return assert_object_type(get_type(type_ref))
def get_interface_type(
type_ref: IntrospectionInterfaceType,
) -> GraphQLInterfaceType:
return assert_interface_type(get_type(type_ref))
# Given a type's introspection result, construct the correct GraphQLType instance.
def build_type(type_: IntrospectionType) -> GraphQLNamedType:
if type_ and "name" in type_ and "kind" in type_:
builder = type_builders.get(type_["kind"])
if builder: # pragma: no cover else
return builder(type_)
raise TypeError(
"Invalid or incomplete introspection result."
" Ensure that a full introspection query is used in order"
f" to build a client schema: {inspect(type_)}."
)
def build_scalar_def(
scalar_introspection: IntrospectionScalarType,
) -> GraphQLScalarType:
return GraphQLScalarType(
name=scalar_introspection["name"],
description=scalar_introspection.get("description"),
specified_by_url=scalar_introspection.get("specifiedByURL"),
)
def build_implementations_list(
implementing_introspection: Union[
IntrospectionObjectType, IntrospectionInterfaceType
],
) -> List[GraphQLInterfaceType]:
maybe_interfaces = implementing_introspection.get("interfaces")
if maybe_interfaces is None:
# Temporary workaround until GraphQL ecosystem will fully support
# 'interfaces' on interface types
if implementing_introspection["kind"] == TypeKind.INTERFACE.name:
return []
raise TypeError(
"Introspection result missing interfaces:"
f" {inspect(implementing_introspection)}."
)
interfaces = cast(Collection[IntrospectionInterfaceType], maybe_interfaces)
return [get_interface_type(interface) for interface in interfaces]
def build_object_def(
object_introspection: IntrospectionObjectType,
) -> GraphQLObjectType:
return GraphQLObjectType(
name=object_introspection["name"],
description=object_introspection.get("description"),
interfaces=lambda: build_implementations_list(object_introspection),
fields=lambda: build_field_def_map(object_introspection),
)
def build_interface_def(
interface_introspection: IntrospectionInterfaceType,
) -> GraphQLInterfaceType:
return GraphQLInterfaceType(
name=interface_introspection["name"],
description=interface_introspection.get("description"),
interfaces=lambda: build_implementations_list(interface_introspection),
fields=lambda: build_field_def_map(interface_introspection),
)
def build_union_def(
union_introspection: IntrospectionUnionType,
) -> GraphQLUnionType:
maybe_possible_types = union_introspection.get("possibleTypes")
if maybe_possible_types is None:
raise TypeError(
"Introspection result missing possibleTypes:"
f" {inspect(union_introspection)}."
)
possible_types = cast(Collection[IntrospectionObjectType], maybe_possible_types)
return GraphQLUnionType(
name=union_introspection["name"],
description=union_introspection.get("description"),
types=lambda: [get_object_type(type_) for type_ in possible_types],
)
def build_enum_def(enum_introspection: IntrospectionEnumType) -> GraphQLEnumType:
if enum_introspection.get("enumValues") is None:
raise TypeError(
"Introspection result missing enumValues:"
f" {inspect(enum_introspection)}."
)
return GraphQLEnumType(
name=enum_introspection["name"],
description=enum_introspection.get("description"),
values={
value_introspect["name"]: GraphQLEnumValue(
value=value_introspect["name"],
description=value_introspect.get("description"),
deprecation_reason=value_introspect.get("deprecationReason"),
)
for value_introspect in enum_introspection["enumValues"]
},
)
def build_input_object_def(
input_object_introspection: IntrospectionInputObjectType,
) -> GraphQLInputObjectType:
if input_object_introspection.get("inputFields") is None:
raise TypeError(
"Introspection result missing inputFields:"
f" {inspect(input_object_introspection)}."
)
return GraphQLInputObjectType(
name=input_object_introspection["name"],
description=input_object_introspection.get("description"),
fields=lambda: build_input_value_def_map(
input_object_introspection["inputFields"]
),
)
type_builders: Dict[str, Callable[[IntrospectionType], GraphQLNamedType]] = {
TypeKind.SCALAR.name: build_scalar_def, # type: ignore
TypeKind.OBJECT.name: build_object_def, # type: ignore
TypeKind.INTERFACE.name: build_interface_def, # type: ignore
TypeKind.UNION.name: build_union_def, # type: ignore
TypeKind.ENUM.name: build_enum_def, # type: ignore
TypeKind.INPUT_OBJECT.name: build_input_object_def, # type: ignore
}
def build_field_def_map(
type_introspection: Union[IntrospectionObjectType, IntrospectionInterfaceType],
) -> Dict[str, GraphQLField]:
if type_introspection.get("fields") is None:
raise TypeError(
f"Introspection result missing fields: {type_introspection}."
)
return {
field_introspection["name"]: build_field(field_introspection)
for field_introspection in type_introspection["fields"]
}
def build_field(field_introspection: IntrospectionField) -> GraphQLField:
type_introspection = cast(IntrospectionType, field_introspection["type"])
type_ = get_type(type_introspection)
if not is_output_type(type_):
raise TypeError(
"Introspection must provide output type for fields,"
f" but received: {inspect(type_)}."
)
type_ = cast(GraphQLOutputType, type_)
args_introspection = field_introspection.get("args")
if args_introspection is None:
raise TypeError(
"Introspection result missing field args:"
f" {inspect(field_introspection)}."
)
return GraphQLField(
type_,
args=build_argument_def_map(args_introspection),
description=field_introspection.get("description"),
deprecation_reason=field_introspection.get("deprecationReason"),
)
def build_argument_def_map(
argument_value_introspections: Collection[IntrospectionInputValue],
) -> Dict[str, GraphQLArgument]:
return {
argument_introspection["name"]: build_argument(argument_introspection)
for argument_introspection in argument_value_introspections
}
def build_argument(
argument_introspection: IntrospectionInputValue,
) -> GraphQLArgument:
type_introspection = cast(IntrospectionType, argument_introspection["type"])
type_ = get_type(type_introspection)
if not is_input_type(type_):
raise TypeError(
"Introspection must provide input type for arguments,"
f" but received: {inspect(type_)}."
)
type_ = cast(GraphQLInputType, type_)
default_value_introspection = argument_introspection.get("defaultValue")
default_value = (
Undefined
if default_value_introspection is None
else value_from_ast(parse_value(default_value_introspection), type_)
)
return GraphQLArgument(
type_,
default_value=default_value,
description=argument_introspection.get("description"),
deprecation_reason=argument_introspection.get("deprecationReason"),
)
def build_input_value_def_map(
input_value_introspections: Collection[IntrospectionInputValue],
) -> Dict[str, GraphQLInputField]:
return {
input_value_introspection["name"]: build_input_value(
input_value_introspection
)
for input_value_introspection in input_value_introspections
}
def build_input_value(
input_value_introspection: IntrospectionInputValue,
) -> GraphQLInputField:
type_introspection = cast(IntrospectionType, input_value_introspection["type"])
type_ = get_type(type_introspection)
if not is_input_type(type_):
raise TypeError(
"Introspection must provide input type for input fields,"
f" but received: {inspect(type_)}."
)
type_ = cast(GraphQLInputType, type_)
default_value_introspection = input_value_introspection.get("defaultValue")
default_value = (
Undefined
if default_value_introspection is None
else value_from_ast(parse_value(default_value_introspection), type_)
)
return GraphQLInputField(
type_,
default_value=default_value,
description=input_value_introspection.get("description"),
deprecation_reason=input_value_introspection.get("deprecationReason"),
)
def build_directive(
directive_introspection: IntrospectionDirective,
) -> GraphQLDirective:
if directive_introspection.get("args") is None:
raise TypeError(
"Introspection result missing directive args:"
f" {inspect(directive_introspection)}."
)
if directive_introspection.get("locations") is None:
raise TypeError(
"Introspection result missing directive locations:"
f" {inspect(directive_introspection)}."
)
return GraphQLDirective(
name=directive_introspection["name"],
description=directive_introspection.get("description"),
is_repeatable=directive_introspection.get("isRepeatable", False),
locations=list(
cast(
Collection[DirectiveLocation],
directive_introspection.get("locations"),
)
),
args=build_argument_def_map(directive_introspection["args"]),
)
# Iterate through all types, getting the type definition for each.
type_map: Dict[str, GraphQLNamedType] = {
type_introspection["name"]: build_type(type_introspection)
for type_introspection in schema_introspection["types"]
}
# Include standard types only if they are used.
for std_type_name, std_type in chain(
specified_scalar_types.items(), introspection_types.items()
):
if std_type_name in type_map:
type_map[std_type_name] = std_type
# Get the root Query, Mutation, and Subscription types.
query_type_ref = schema_introspection.get("queryType")
query_type = None if query_type_ref is None else get_object_type(query_type_ref)
mutation_type_ref = schema_introspection.get("mutationType")
mutation_type = (
None if mutation_type_ref is None else get_object_type(mutation_type_ref)
)
subscription_type_ref = schema_introspection.get("subscriptionType")
subscription_type = (
None
if subscription_type_ref is None
else get_object_type(subscription_type_ref)
)
# Get the directives supported by Introspection, assuming empty-set if directives
# were not queried for.
directive_introspections = schema_introspection.get("directives")
directives = (
[
build_directive(directive_introspection)
for directive_introspection in directive_introspections
]
if directive_introspections
else []
)
# Then produce and return a Schema with these types.
return GraphQLSchema(
query=query_type,
mutation=mutation_type,
subscription=subscription_type,
types=list(type_map.values()),
directives=directives,
description=schema_introspection.get("description"),
assume_valid=assume_valid,
)
@@ -0,0 +1,159 @@
from typing import Any, Callable, Dict, List, Optional, Union, cast
from ..error import GraphQLError
from ..pyutils import (
Path,
did_you_mean,
inspect,
is_iterable,
print_path_list,
suggestion_list,
Undefined,
)
from ..type import (
GraphQLInputObjectType,
GraphQLInputType,
GraphQLList,
GraphQLScalarType,
is_leaf_type,
is_input_object_type,
is_list_type,
is_non_null_type,
GraphQLNonNull,
)
__all__ = ["coerce_input_value"]
OnErrorCB = Callable[[List[Union[str, int]], Any, GraphQLError], None]
def default_on_error(
path: List[Union[str, int]], invalid_value: Any, error: GraphQLError
) -> None:
error_prefix = "Invalid value " + inspect(invalid_value)
if path:
error_prefix += f" at 'value{print_path_list(path)}'"
error.message = error_prefix + ": " + error.message
raise error
def coerce_input_value(
input_value: Any,
type_: GraphQLInputType,
on_error: OnErrorCB = default_on_error,
path: Optional[Path] = None,
) -> Any:
"""Coerce a Python value given a GraphQL Input Type."""
if is_non_null_type(type_):
if input_value is not None and input_value is not Undefined:
type_ = cast(GraphQLNonNull, type_)
return coerce_input_value(input_value, type_.of_type, on_error, path)
on_error(
path.as_list() if path else [],
input_value,
GraphQLError(
f"Expected non-nullable type '{inspect(type_)}' not to be None."
),
)
return Undefined
if input_value is None or input_value is Undefined:
# Explicitly return the value null.
return None
if is_list_type(type_):
type_ = cast(GraphQLList, type_)
item_type = type_.of_type
if is_iterable(input_value):
coerced_list: List[Any] = []
append_item = coerced_list.append
for index, item_value in enumerate(input_value):
append_item(
coerce_input_value(
item_value, item_type, on_error, Path(path, index, None)
)
)
return coerced_list
# Lists accept a non-list value as a list of one.
return [coerce_input_value(input_value, item_type, on_error, path)]
if is_input_object_type(type_):
type_ = cast(GraphQLInputObjectType, type_)
if not isinstance(input_value, dict):
on_error(
path.as_list() if path else [],
input_value,
GraphQLError(f"Expected type '{type_.name}' to be a mapping."),
)
return Undefined
coerced_dict: Dict[str, Any] = {}
fields = type_.fields
for field_name, field in fields.items():
field_value = input_value.get(field_name, Undefined)
if field_value is Undefined:
if field.default_value is not Undefined:
# Use out name as name if it exists (extension of GraphQL.js).
coerced_dict[field.out_name or field_name] = field.default_value
elif is_non_null_type(field.type): # pragma: no cover else
type_str = inspect(field.type)
on_error(
path.as_list() if path else [],
input_value,
GraphQLError(
f"Field '{field_name}' of required type '{type_str}'"
" was not provided."
),
)
continue
coerced_dict[field.out_name or field_name] = coerce_input_value(
field_value, field.type, on_error, Path(path, field_name, type_.name)
)
# Ensure every provided field is defined.
for field_name in input_value:
if field_name not in fields:
suggestions = suggestion_list(field_name, fields)
on_error(
path.as_list() if path else [],
input_value,
GraphQLError(
f"Field '{field_name}' is not defined by type '{type_.name}'."
+ did_you_mean(suggestions)
),
)
return type_.out_type(coerced_dict)
if is_leaf_type(type_):
# Scalars determine if a value is valid via `parse_value()`, which can throw to
# indicate failure. If it throws, maintain a reference to the original error.
type_ = cast(GraphQLScalarType, type_)
try:
parse_result = type_.parse_value(input_value)
except GraphQLError as error:
on_error(path.as_list() if path else [], input_value, error)
return Undefined
except Exception as error:
on_error(
path.as_list() if path else [],
input_value,
GraphQLError(
f"Expected type '{type_.name}'. {error}", original_error=error
),
)
return Undefined
if parse_result is Undefined:
on_error(
path.as_list() if path else [],
input_value,
GraphQLError(f"Expected type '{type_.name}'."),
)
return parse_result
# Not reachable. All possible input types have been considered.
raise TypeError(f"Unexpected input type: {inspect(type_)}.")
@@ -0,0 +1,18 @@
from itertools import chain
from typing import Collection
from ..language.ast import DocumentNode
__all__ = ["concat_ast"]
def concat_ast(asts: Collection[DocumentNode]) -> DocumentNode:
"""Concat ASTs.
Provided a collection of ASTs, presumably each from different files, concatenate
the ASTs together into batched AST, useful for validating many GraphQL source files
which together represent one conceptual application.
"""
return DocumentNode(
definitions=list(chain.from_iterable(document.definitions for document in asts))
)
@@ -0,0 +1,700 @@
from collections import defaultdict
from typing import (
Any,
Callable,
Collection,
DefaultDict,
Dict,
List,
Mapping,
Optional,
Union,
cast,
)
from ..language import (
DirectiveDefinitionNode,
DirectiveLocation,
DocumentNode,
EnumTypeDefinitionNode,
EnumTypeExtensionNode,
EnumValueDefinitionNode,
FieldDefinitionNode,
InputObjectTypeDefinitionNode,
InputObjectTypeExtensionNode,
InputValueDefinitionNode,
InterfaceTypeDefinitionNode,
InterfaceTypeExtensionNode,
ListTypeNode,
NamedTypeNode,
NonNullTypeNode,
ObjectTypeDefinitionNode,
ObjectTypeExtensionNode,
OperationType,
ScalarTypeDefinitionNode,
ScalarTypeExtensionNode,
SchemaExtensionNode,
SchemaDefinitionNode,
TypeDefinitionNode,
TypeExtensionNode,
TypeNode,
UnionTypeDefinitionNode,
UnionTypeExtensionNode,
)
from ..pyutils import inspect, merge_kwargs
from ..type import (
GraphQLArgument,
GraphQLArgumentMap,
GraphQLDeprecatedDirective,
GraphQLDirective,
GraphQLEnumType,
GraphQLEnumValue,
GraphQLEnumValueMap,
GraphQLField,
GraphQLFieldMap,
GraphQLInputField,
GraphQLInputObjectType,
GraphQLInputType,
GraphQLInputFieldMap,
GraphQLInterfaceType,
GraphQLList,
GraphQLNamedType,
GraphQLNonNull,
GraphQLNullableType,
GraphQLObjectType,
GraphQLOutputType,
GraphQLScalarType,
GraphQLSchema,
GraphQLSchemaKwargs,
GraphQLSpecifiedByDirective,
GraphQLType,
GraphQLUnionType,
assert_schema,
is_enum_type,
is_input_object_type,
is_interface_type,
is_list_type,
is_non_null_type,
is_object_type,
is_scalar_type,
is_union_type,
is_introspection_type,
is_specified_scalar_type,
introspection_types,
specified_scalar_types,
)
from .value_from_ast import value_from_ast
__all__ = [
"extend_schema",
"extend_schema_impl",
]
def extend_schema(
schema: GraphQLSchema,
document_ast: DocumentNode,
assume_valid: bool = False,
assume_valid_sdl: bool = False,
) -> GraphQLSchema:
"""Extend the schema with extensions from a given document.
Produces a new schema given an existing schema and a document which may contain
GraphQL type extensions and definitions. The original schema will remain unaltered.
Because a schema represents a graph of references, a schema cannot be extended
without effectively making an entire copy. We do not know until it's too late if
subgraphs remain unchanged.
This algorithm copies the provided schema, applying extensions while producing the
copy. The original schema remains unaltered.
When extending a schema with a known valid extension, it might be safe to assume the
schema is valid. Set ``assume_valid`` to ``True`` to assume the produced schema is
valid. Set ``assume_valid_sdl`` to ``True`` to assume it is already a valid SDL
document.
"""
assert_schema(schema)
if not isinstance(document_ast, DocumentNode):
raise TypeError("Must provide valid Document AST.")
if not (assume_valid or assume_valid_sdl):
from ..validation.validate import assert_valid_sdl_extension
assert_valid_sdl_extension(document_ast, schema)
schema_kwargs = schema.to_kwargs()
extended_kwargs = extend_schema_impl(schema_kwargs, document_ast, assume_valid)
return (
schema if schema_kwargs is extended_kwargs else GraphQLSchema(**extended_kwargs)
)
def extend_schema_impl(
schema_kwargs: GraphQLSchemaKwargs,
document_ast: DocumentNode,
assume_valid: bool = False,
) -> GraphQLSchemaKwargs:
"""Extend the given schema arguments with extensions from a given document.
For internal use only.
"""
# Note: schema_kwargs should become a TypedDict once we require Python 3.8
# Collect the type definitions and extensions found in the document.
type_defs: List[TypeDefinitionNode] = []
type_extensions_map: DefaultDict[str, Any] = defaultdict(list)
# New directives and types are separate because a directives and types can have the
# same name. For example, a type named "skip".
directive_defs: List[DirectiveDefinitionNode] = []
schema_def: Optional[SchemaDefinitionNode] = None
# Schema extensions are collected which may add additional operation types.
schema_extensions: List[SchemaExtensionNode] = []
for def_ in document_ast.definitions:
if isinstance(def_, SchemaDefinitionNode):
schema_def = def_
elif isinstance(def_, SchemaExtensionNode):
schema_extensions.append(def_)
elif isinstance(def_, TypeDefinitionNode):
type_defs.append(def_)
elif isinstance(def_, TypeExtensionNode):
extended_type_name = def_.name.value
type_extensions_map[extended_type_name].append(def_)
elif isinstance(def_, DirectiveDefinitionNode):
directive_defs.append(def_)
# If this document contains no new types, extensions, or directives then return the
# same unmodified GraphQLSchema instance.
if (
not type_extensions_map
and not type_defs
and not directive_defs
and not schema_extensions
and not schema_def
):
return schema_kwargs
# Below are functions used for producing this schema that have closed over this
# scope and have access to the schema, cache, and newly defined types.
# noinspection PyTypeChecker,PyUnresolvedReferences
def replace_type(type_: GraphQLType) -> GraphQLType:
if is_list_type(type_):
return GraphQLList(replace_type(type_.of_type)) # type: ignore
if is_non_null_type(type_):
return GraphQLNonNull(replace_type(type_.of_type)) # type: ignore
return replace_named_type(type_) # type: ignore
def replace_named_type(type_: GraphQLNamedType) -> GraphQLNamedType:
# Note: While this could make early assertions to get the correctly
# typed values below, that would throw immediately while type system
# validation with validate_schema() will produce more actionable results.
return type_map[type_.name]
# noinspection PyShadowingNames
def replace_directive(directive: GraphQLDirective) -> GraphQLDirective:
kwargs = directive.to_kwargs()
return GraphQLDirective(
**merge_kwargs(
kwargs,
args={name: extend_arg(arg) for name, arg in kwargs["args"].items()},
)
)
def extend_named_type(type_: GraphQLNamedType) -> GraphQLNamedType:
if is_introspection_type(type_) or is_specified_scalar_type(type_):
# Builtin types are not extended.
return type_
if is_scalar_type(type_):
type_ = cast(GraphQLScalarType, type_)
return extend_scalar_type(type_)
if is_object_type(type_):
type_ = cast(GraphQLObjectType, type_)
return extend_object_type(type_)
if is_interface_type(type_):
type_ = cast(GraphQLInterfaceType, type_)
return extend_interface_type(type_)
if is_union_type(type_):
type_ = cast(GraphQLUnionType, type_)
return extend_union_type(type_)
if is_enum_type(type_):
type_ = cast(GraphQLEnumType, type_)
return extend_enum_type(type_)
if is_input_object_type(type_):
type_ = cast(GraphQLInputObjectType, type_)
return extend_input_object_type(type_)
# Not reachable. All possible types have been considered.
raise TypeError(f"Unexpected type: {inspect(type_)}.") # pragma: no cover
# noinspection PyShadowingNames
def extend_input_object_type(
type_: GraphQLInputObjectType,
) -> GraphQLInputObjectType:
kwargs = type_.to_kwargs()
extensions = tuple(type_extensions_map[kwargs["name"]])
return GraphQLInputObjectType(
**merge_kwargs(
kwargs,
fields=lambda: {
**{
name: GraphQLInputField(
**merge_kwargs(
field.to_kwargs(),
type_=replace_type(field.type),
)
)
for name, field in kwargs["fields"].items()
},
**build_input_field_map(extensions),
},
extension_ast_nodes=kwargs["extension_ast_nodes"] + extensions,
)
)
def extend_enum_type(type_: GraphQLEnumType) -> GraphQLEnumType:
kwargs = type_.to_kwargs()
extensions = tuple(type_extensions_map[kwargs["name"]])
return GraphQLEnumType(
**merge_kwargs(
kwargs,
values={**kwargs["values"], **build_enum_value_map(extensions)},
extension_ast_nodes=kwargs["extension_ast_nodes"] + extensions,
)
)
def extend_scalar_type(type_: GraphQLScalarType) -> GraphQLScalarType:
kwargs = type_.to_kwargs()
extensions = tuple(type_extensions_map[kwargs["name"]])
specified_by_url = kwargs["specified_by_url"]
for extension_node in extensions:
specified_by_url = get_specified_by_url(extension_node) or specified_by_url
return GraphQLScalarType(
**merge_kwargs(
kwargs,
specified_by_url=specified_by_url,
extension_ast_nodes=kwargs["extension_ast_nodes"] + extensions,
)
)
# noinspection PyShadowingNames
def extend_object_type(type_: GraphQLObjectType) -> GraphQLObjectType:
kwargs = type_.to_kwargs()
extensions = tuple(type_extensions_map[kwargs["name"]])
return GraphQLObjectType(
**merge_kwargs(
kwargs,
interfaces=lambda: [
cast(GraphQLInterfaceType, replace_named_type(interface))
for interface in kwargs["interfaces"]
]
+ build_interfaces(extensions),
fields=lambda: {
**{
name: extend_field(field)
for name, field in kwargs["fields"].items()
},
**build_field_map(extensions),
},
extension_ast_nodes=kwargs["extension_ast_nodes"] + extensions,
)
)
# noinspection PyShadowingNames
def extend_interface_type(type_: GraphQLInterfaceType) -> GraphQLInterfaceType:
kwargs = type_.to_kwargs()
extensions = tuple(type_extensions_map[kwargs["name"]])
return GraphQLInterfaceType(
**merge_kwargs(
kwargs,
interfaces=lambda: [
cast(GraphQLInterfaceType, replace_named_type(interface))
for interface in kwargs["interfaces"]
]
+ build_interfaces(extensions),
fields=lambda: {
**{
name: extend_field(field)
for name, field in kwargs["fields"].items()
},
**build_field_map(extensions),
},
extension_ast_nodes=kwargs["extension_ast_nodes"] + extensions,
)
)
def extend_union_type(type_: GraphQLUnionType) -> GraphQLUnionType:
kwargs = type_.to_kwargs()
extensions = tuple(type_extensions_map[kwargs["name"]])
return GraphQLUnionType(
**merge_kwargs(
kwargs,
types=lambda: [
cast(GraphQLObjectType, replace_named_type(member_type))
for member_type in kwargs["types"]
]
+ build_union_types(extensions),
extension_ast_nodes=kwargs["extension_ast_nodes"] + extensions,
)
)
# noinspection PyShadowingNames
def extend_field(field: GraphQLField) -> GraphQLField:
return GraphQLField(
**merge_kwargs(
field.to_kwargs(),
type_=replace_type(field.type),
args={name: extend_arg(arg) for name, arg in field.args.items()},
)
)
def extend_arg(arg: GraphQLArgument) -> GraphQLArgument:
return GraphQLArgument(
**merge_kwargs(
arg.to_kwargs(),
type_=replace_type(arg.type),
)
)
# noinspection PyShadowingNames
def get_operation_types(
nodes: Collection[Union[SchemaDefinitionNode, SchemaExtensionNode]]
) -> Dict[OperationType, GraphQLNamedType]:
# Note: While this could make early assertions to get the correctly
# typed values below, that would throw immediately while type system
# validation with validate_schema() will produce more actionable results.
return {
operation_type.operation: get_named_type(operation_type.type)
for node in nodes
for operation_type in node.operation_types or []
}
# noinspection PyShadowingNames
def get_named_type(node: NamedTypeNode) -> GraphQLNamedType:
name = node.name.value
type_ = std_type_map.get(name) or type_map.get(name)
if not type_:
raise TypeError(f"Unknown type: '{name}'.")
return type_
def get_wrapped_type(node: TypeNode) -> GraphQLType:
if isinstance(node, ListTypeNode):
return GraphQLList(get_wrapped_type(node.type))
if isinstance(node, NonNullTypeNode):
return GraphQLNonNull(
cast(GraphQLNullableType, get_wrapped_type(node.type))
)
return get_named_type(cast(NamedTypeNode, node))
def build_directive(node: DirectiveDefinitionNode) -> GraphQLDirective:
locations = [DirectiveLocation[node.value] for node in node.locations]
return GraphQLDirective(
name=node.name.value,
description=node.description.value if node.description else None,
locations=locations,
is_repeatable=node.repeatable,
args=build_argument_map(node.arguments),
ast_node=node,
)
def build_field_map(
nodes: Collection[
Union[
InterfaceTypeDefinitionNode,
InterfaceTypeExtensionNode,
ObjectTypeDefinitionNode,
ObjectTypeExtensionNode,
]
],
) -> GraphQLFieldMap:
field_map: GraphQLFieldMap = {}
for node in nodes:
for field in node.fields or []:
# Note: While this could make assertions to get the correctly typed
# value, that would throw immediately while type system validation
# with validate_schema() will produce more actionable results.
field_map[field.name.value] = GraphQLField(
type_=cast(GraphQLOutputType, get_wrapped_type(field.type)),
description=field.description.value if field.description else None,
args=build_argument_map(field.arguments),
deprecation_reason=get_deprecation_reason(field),
ast_node=field,
)
return field_map
def build_argument_map(
args: Optional[Collection[InputValueDefinitionNode]],
) -> GraphQLArgumentMap:
arg_map: GraphQLArgumentMap = {}
for arg in args or []:
# Note: While this could make assertions to get the correctly typed
# value, that would throw immediately while type system validation
# with validate_schema() will produce more actionable results.
type_ = cast(GraphQLInputType, get_wrapped_type(arg.type))
arg_map[arg.name.value] = GraphQLArgument(
type_=type_,
description=arg.description.value if arg.description else None,
default_value=value_from_ast(arg.default_value, type_),
deprecation_reason=get_deprecation_reason(arg),
ast_node=arg,
)
return arg_map
def build_input_field_map(
nodes: Collection[
Union[InputObjectTypeDefinitionNode, InputObjectTypeExtensionNode]
],
) -> GraphQLInputFieldMap:
input_field_map: GraphQLInputFieldMap = {}
for node in nodes:
for field in node.fields or []:
# Note: While this could make assertions to get the correctly typed
# value, that would throw immediately while type system validation
# with validate_schema() will produce more actionable results.
type_ = cast(GraphQLInputType, get_wrapped_type(field.type))
input_field_map[field.name.value] = GraphQLInputField(
type_=type_,
description=field.description.value if field.description else None,
default_value=value_from_ast(field.default_value, type_),
deprecation_reason=get_deprecation_reason(field),
ast_node=field,
)
return input_field_map
def build_enum_value_map(
nodes: Collection[Union[EnumTypeDefinitionNode, EnumTypeExtensionNode]]
) -> GraphQLEnumValueMap:
enum_value_map: GraphQLEnumValueMap = {}
for node in nodes:
for value in node.values or []:
# Note: While this could make assertions to get the correctly typed
# value, that would throw immediately while type system validation
# with validate_schema() will produce more actionable results.
value_name = value.name.value
enum_value_map[value_name] = GraphQLEnumValue(
value=value_name,
description=value.description.value if value.description else None,
deprecation_reason=get_deprecation_reason(value),
ast_node=value,
)
return enum_value_map
def build_interfaces(
nodes: Collection[
Union[
InterfaceTypeDefinitionNode,
InterfaceTypeExtensionNode,
ObjectTypeDefinitionNode,
ObjectTypeExtensionNode,
]
],
) -> List[GraphQLInterfaceType]:
interfaces: List[GraphQLInterfaceType] = []
for node in nodes:
for type_ in node.interfaces or []:
# Note: While this could make assertions to get the correctly typed
# value, that would throw immediately while type system validation
# with validate_schema() will produce more actionable results.
interfaces.append(cast(GraphQLInterfaceType, get_named_type(type_)))
return interfaces
def build_union_types(
nodes: Collection[Union[UnionTypeDefinitionNode, UnionTypeExtensionNode]],
) -> List[GraphQLObjectType]:
types: List[GraphQLObjectType] = []
for node in nodes:
for type_ in node.types or []:
# Note: While this could make assertions to get the correctly typed
# value, that would throw immediately while type system validation
# with validate_schema() will produce more actionable results.
types.append(cast(GraphQLObjectType, get_named_type(type_)))
return types
def build_object_type(ast_node: ObjectTypeDefinitionNode) -> GraphQLObjectType:
extension_nodes = type_extensions_map[ast_node.name.value]
all_nodes: List[Union[ObjectTypeDefinitionNode, ObjectTypeExtensionNode]] = [
ast_node,
*extension_nodes,
]
return GraphQLObjectType(
name=ast_node.name.value,
description=ast_node.description.value if ast_node.description else None,
interfaces=lambda: build_interfaces(all_nodes),
fields=lambda: build_field_map(all_nodes),
ast_node=ast_node,
extension_ast_nodes=extension_nodes,
)
def build_interface_type(
ast_node: InterfaceTypeDefinitionNode,
) -> GraphQLInterfaceType:
extension_nodes = type_extensions_map[ast_node.name.value]
all_nodes: List[
Union[InterfaceTypeDefinitionNode, InterfaceTypeExtensionNode]
] = [ast_node, *extension_nodes]
return GraphQLInterfaceType(
name=ast_node.name.value,
description=ast_node.description.value if ast_node.description else None,
interfaces=lambda: build_interfaces(all_nodes),
fields=lambda: build_field_map(all_nodes),
ast_node=ast_node,
extension_ast_nodes=extension_nodes,
)
def build_enum_type(ast_node: EnumTypeDefinitionNode) -> GraphQLEnumType:
extension_nodes = type_extensions_map[ast_node.name.value]
all_nodes: List[Union[EnumTypeDefinitionNode, EnumTypeExtensionNode]] = [
ast_node,
*extension_nodes,
]
return GraphQLEnumType(
name=ast_node.name.value,
description=ast_node.description.value if ast_node.description else None,
values=build_enum_value_map(all_nodes),
ast_node=ast_node,
extension_ast_nodes=extension_nodes,
)
def build_union_type(ast_node: UnionTypeDefinitionNode) -> GraphQLUnionType:
extension_nodes = type_extensions_map[ast_node.name.value]
all_nodes: List[Union[UnionTypeDefinitionNode, UnionTypeExtensionNode]] = [
ast_node,
*extension_nodes,
]
return GraphQLUnionType(
name=ast_node.name.value,
description=ast_node.description.value if ast_node.description else None,
types=lambda: build_union_types(all_nodes),
ast_node=ast_node,
extension_ast_nodes=extension_nodes,
)
def build_scalar_type(ast_node: ScalarTypeDefinitionNode) -> GraphQLScalarType:
extension_nodes = type_extensions_map[ast_node.name.value]
return GraphQLScalarType(
name=ast_node.name.value,
description=ast_node.description.value if ast_node.description else None,
specified_by_url=get_specified_by_url(ast_node),
ast_node=ast_node,
extension_ast_nodes=extension_nodes,
)
def build_input_object_type(
ast_node: InputObjectTypeDefinitionNode,
) -> GraphQLInputObjectType:
extension_nodes = type_extensions_map[ast_node.name.value]
all_nodes: List[
Union[InputObjectTypeDefinitionNode, InputObjectTypeExtensionNode]
] = [ast_node, *extension_nodes]
return GraphQLInputObjectType(
name=ast_node.name.value,
description=ast_node.description.value if ast_node.description else None,
fields=lambda: build_input_field_map(all_nodes),
ast_node=ast_node,
extension_ast_nodes=extension_nodes,
)
build_type_for_kind = cast(
Dict[str, Callable[[TypeDefinitionNode], GraphQLNamedType]],
{
"object_type_definition": build_object_type,
"interface_type_definition": build_interface_type,
"enum_type_definition": build_enum_type,
"union_type_definition": build_union_type,
"scalar_type_definition": build_scalar_type,
"input_object_type_definition": build_input_object_type,
},
)
def build_type(ast_node: TypeDefinitionNode) -> GraphQLNamedType:
try:
# object_type_definition_node is built with _build_object_type etc.
build_function = build_type_for_kind[ast_node.kind]
except KeyError: # pragma: no cover
# Not reachable. All possible type definition nodes have been considered.
raise TypeError( # pragma: no cover
f"Unexpected type definition node: {inspect(ast_node)}."
)
else:
return build_function(ast_node)
type_map: Dict[str, GraphQLNamedType] = {}
for existing_type in schema_kwargs["types"] or ():
type_map[existing_type.name] = extend_named_type(existing_type)
for type_node in type_defs:
name = type_node.name.value
type_map[name] = std_type_map.get(name) or build_type(type_node)
# Get the extended root operation types.
operation_types: Dict[OperationType, GraphQLNamedType] = {}
for operation_type in OperationType:
original_type = schema_kwargs[operation_type.value]
if original_type:
operation_types[operation_type] = replace_named_type(original_type)
# Then, incorporate schema definition and all schema extensions.
if schema_def:
operation_types.update(get_operation_types([schema_def]))
if schema_extensions:
operation_types.update(get_operation_types(schema_extensions))
# Then produce and return the kwargs for a Schema with these types.
get_operation = operation_types.get
return GraphQLSchemaKwargs(
query=get_operation(OperationType.QUERY), # type: ignore
mutation=get_operation(OperationType.MUTATION), # type: ignore
subscription=get_operation(OperationType.SUBSCRIPTION), # type: ignore
types=tuple(type_map.values()),
directives=tuple(
replace_directive(directive) for directive in schema_kwargs["directives"]
)
+ tuple(build_directive(directive) for directive in directive_defs),
description=(
schema_def.description.value
if schema_def and schema_def.description
else None
),
extensions={},
ast_node=schema_def or schema_kwargs["ast_node"],
extension_ast_nodes=schema_kwargs["extension_ast_nodes"]
+ tuple(schema_extensions),
assume_valid=assume_valid,
)
std_type_map: Mapping[str, Union[GraphQLNamedType, GraphQLObjectType]] = {
**specified_scalar_types,
**introspection_types,
}
def get_deprecation_reason(
node: Union[EnumValueDefinitionNode, FieldDefinitionNode, InputValueDefinitionNode]
) -> Optional[str]:
"""Given a field or enum value node, get deprecation reason as string."""
from ..execution import get_directive_values
deprecated = get_directive_values(GraphQLDeprecatedDirective, node)
return deprecated["reason"] if deprecated else None
def get_specified_by_url(
node: Union[ScalarTypeDefinitionNode, ScalarTypeExtensionNode]
) -> Optional[str]:
"""Given a scalar node, return the string value for the specifiedByURL."""
from ..execution import get_directive_values
specified_by_url = get_directive_values(GraphQLSpecifiedByDirective, node)
return specified_by_url["url"] if specified_by_url else None
@@ -0,0 +1,620 @@
from enum import Enum
from typing import Any, Collection, Dict, List, NamedTuple, Union, cast
from ..language import print_ast
from ..pyutils import inspect, Undefined
from ..type import (
GraphQLEnumType,
GraphQLField,
GraphQLList,
GraphQLNamedType,
GraphQLNonNull,
GraphQLInputType,
GraphQLInterfaceType,
GraphQLObjectType,
GraphQLSchema,
GraphQLType,
GraphQLUnionType,
is_enum_type,
is_input_object_type,
is_interface_type,
is_list_type,
is_named_type,
is_non_null_type,
is_object_type,
is_required_argument,
is_required_input_field,
is_scalar_type,
is_specified_scalar_type,
is_union_type,
)
from ..utilities.sort_value_node import sort_value_node
from .ast_from_value import ast_from_value
__all__ = [
"BreakingChange",
"BreakingChangeType",
"DangerousChange",
"DangerousChangeType",
"find_breaking_changes",
"find_dangerous_changes",
]
class BreakingChangeType(Enum):
TYPE_REMOVED = 10
TYPE_CHANGED_KIND = 11
TYPE_REMOVED_FROM_UNION = 20
VALUE_REMOVED_FROM_ENUM = 21
REQUIRED_INPUT_FIELD_ADDED = 22
IMPLEMENTED_INTERFACE_REMOVED = 23
FIELD_REMOVED = 30
FIELD_CHANGED_KIND = 31
REQUIRED_ARG_ADDED = 40
ARG_REMOVED = 41
ARG_CHANGED_KIND = 42
DIRECTIVE_REMOVED = 50
DIRECTIVE_ARG_REMOVED = 51
REQUIRED_DIRECTIVE_ARG_ADDED = 52
DIRECTIVE_REPEATABLE_REMOVED = 53
DIRECTIVE_LOCATION_REMOVED = 54
class DangerousChangeType(Enum):
VALUE_ADDED_TO_ENUM = 60
TYPE_ADDED_TO_UNION = 61
OPTIONAL_INPUT_FIELD_ADDED = 62
OPTIONAL_ARG_ADDED = 63
IMPLEMENTED_INTERFACE_ADDED = 64
ARG_DEFAULT_VALUE_CHANGE = 65
class BreakingChange(NamedTuple):
type: BreakingChangeType
description: str
class DangerousChange(NamedTuple):
type: DangerousChangeType
description: str
Change = Union[BreakingChange, DangerousChange]
def find_breaking_changes(
old_schema: GraphQLSchema, new_schema: GraphQLSchema
) -> List[BreakingChange]:
"""Find breaking changes.
Given two schemas, returns a list containing descriptions of all the types of
breaking changes covered by the other functions down below.
"""
return [
change
for change in find_schema_changes(old_schema, new_schema)
if isinstance(change.type, BreakingChangeType)
]
def find_dangerous_changes(
old_schema: GraphQLSchema, new_schema: GraphQLSchema
) -> List[DangerousChange]:
"""Find dangerous changes.
Given two schemas, returns a list containing descriptions of all the types of
potentially dangerous changes covered by the other functions down below.
"""
return [
change
for change in find_schema_changes(old_schema, new_schema)
if isinstance(change.type, DangerousChangeType)
]
def find_schema_changes(
old_schema: GraphQLSchema, new_schema: GraphQLSchema
) -> List[Change]:
return find_type_changes(old_schema, new_schema) + find_directive_changes(
old_schema, new_schema
)
def find_directive_changes(
old_schema: GraphQLSchema, new_schema: GraphQLSchema
) -> List[Change]:
schema_changes: List[Change] = []
directives_diff = list_diff(old_schema.directives, new_schema.directives)
for directive in directives_diff.removed:
schema_changes.append(
BreakingChange(
BreakingChangeType.DIRECTIVE_REMOVED, f"{directive.name} was removed."
)
)
for old_directive, new_directive in directives_diff.persisted:
args_diff = dict_diff(old_directive.args, new_directive.args)
for arg_name, new_arg in args_diff.added.items():
if is_required_argument(new_arg):
schema_changes.append(
BreakingChange(
BreakingChangeType.REQUIRED_DIRECTIVE_ARG_ADDED,
f"A required arg {arg_name} on directive"
f" {old_directive.name} was added.",
)
)
for arg_name in args_diff.removed:
schema_changes.append(
BreakingChange(
BreakingChangeType.DIRECTIVE_ARG_REMOVED,
f"{arg_name} was removed from {new_directive.name}.",
)
)
if old_directive.is_repeatable and not new_directive.is_repeatable:
schema_changes.append(
BreakingChange(
BreakingChangeType.DIRECTIVE_REPEATABLE_REMOVED,
f"Repeatable flag was removed from {old_directive.name}.",
)
)
for location in old_directive.locations:
if location not in new_directive.locations:
schema_changes.append(
BreakingChange(
BreakingChangeType.DIRECTIVE_LOCATION_REMOVED,
f"{location.name} was removed from {new_directive.name}.",
)
)
return schema_changes
def find_type_changes(
old_schema: GraphQLSchema, new_schema: GraphQLSchema
) -> List[Change]:
schema_changes: List[Change] = []
types_diff = dict_diff(old_schema.type_map, new_schema.type_map)
for type_name, old_type in types_diff.removed.items():
schema_changes.append(
BreakingChange(
BreakingChangeType.TYPE_REMOVED,
(
f"Standard scalar {type_name} was removed"
" because it is not referenced anymore."
if is_specified_scalar_type(old_type)
else f"{type_name} was removed."
),
)
)
for type_name, (old_type, new_type) in types_diff.persisted.items():
if is_enum_type(old_type) and is_enum_type(new_type):
schema_changes.extend(find_enum_type_changes(old_type, new_type))
elif is_union_type(old_type) and is_union_type(new_type):
schema_changes.extend(find_union_type_changes(old_type, new_type))
elif is_input_object_type(old_type) and is_input_object_type(new_type):
schema_changes.extend(find_input_object_type_changes(old_type, new_type))
elif is_object_type(old_type) and is_object_type(new_type):
schema_changes.extend(find_field_changes(old_type, new_type))
schema_changes.extend(
find_implemented_interfaces_changes(old_type, new_type)
)
elif is_interface_type(old_type) and is_interface_type(new_type):
schema_changes.extend(find_field_changes(old_type, new_type))
schema_changes.extend(
find_implemented_interfaces_changes(old_type, new_type)
)
elif old_type.__class__ is not new_type.__class__:
schema_changes.append(
BreakingChange(
BreakingChangeType.TYPE_CHANGED_KIND,
f"{type_name} changed from {type_kind_name(old_type)}"
f" to {type_kind_name(new_type)}.",
)
)
return schema_changes
def find_input_object_type_changes(
old_type: Union[GraphQLObjectType, GraphQLInterfaceType],
new_type: Union[GraphQLObjectType, GraphQLInterfaceType],
) -> List[Change]:
schema_changes: List[Change] = []
fields_diff = dict_diff(old_type.fields, new_type.fields)
for field_name, new_field in fields_diff.added.items():
if is_required_input_field(new_field):
schema_changes.append(
BreakingChange(
BreakingChangeType.REQUIRED_INPUT_FIELD_ADDED,
f"A required field {field_name} on"
f" input type {old_type.name} was added.",
)
)
else:
schema_changes.append(
DangerousChange(
DangerousChangeType.OPTIONAL_INPUT_FIELD_ADDED,
f"An optional field {field_name} on"
f" input type {old_type.name} was added.",
)
)
for field_name in fields_diff.removed:
schema_changes.append(
BreakingChange(
BreakingChangeType.FIELD_REMOVED,
f"{old_type.name}.{field_name} was removed.",
)
)
for field_name, (old_field, new_field) in fields_diff.persisted.items():
is_safe = is_change_safe_for_input_object_field_or_field_arg(
old_field.type, new_field.type
)
if not is_safe:
schema_changes.append(
BreakingChange(
BreakingChangeType.FIELD_CHANGED_KIND,
f"{old_type.name}.{field_name} changed type"
f" from {old_field.type} to {new_field.type}.",
)
)
return schema_changes
def find_union_type_changes(
old_type: GraphQLUnionType, new_type: GraphQLUnionType
) -> List[Change]:
schema_changes: List[Change] = []
possible_types_diff = list_diff(old_type.types, new_type.types)
for possible_type in possible_types_diff.added:
schema_changes.append(
DangerousChange(
DangerousChangeType.TYPE_ADDED_TO_UNION,
f"{possible_type.name} was added" f" to union type {old_type.name}.",
)
)
for possible_type in possible_types_diff.removed:
schema_changes.append(
BreakingChange(
BreakingChangeType.TYPE_REMOVED_FROM_UNION,
f"{possible_type.name} was removed from union type {old_type.name}.",
)
)
return schema_changes
def find_enum_type_changes(
old_type: GraphQLEnumType, new_type: GraphQLEnumType
) -> List[Change]:
schema_changes: List[Change] = []
values_diff = dict_diff(old_type.values, new_type.values)
for value_name in values_diff.added:
schema_changes.append(
DangerousChange(
DangerousChangeType.VALUE_ADDED_TO_ENUM,
f"{value_name} was added to enum type {old_type.name}.",
)
)
for value_name in values_diff.removed:
schema_changes.append(
BreakingChange(
BreakingChangeType.VALUE_REMOVED_FROM_ENUM,
f"{value_name} was removed from enum type {old_type.name}.",
)
)
return schema_changes
def find_implemented_interfaces_changes(
old_type: Union[GraphQLObjectType, GraphQLInterfaceType],
new_type: Union[GraphQLObjectType, GraphQLInterfaceType],
) -> List[Change]:
schema_changes: List[Change] = []
interfaces_diff = list_diff(old_type.interfaces, new_type.interfaces)
for interface in interfaces_diff.added:
schema_changes.append(
DangerousChange(
DangerousChangeType.IMPLEMENTED_INTERFACE_ADDED,
f"{interface.name} added to interfaces implemented by {old_type.name}.",
)
)
for interface in interfaces_diff.removed:
schema_changes.append(
BreakingChange(
BreakingChangeType.IMPLEMENTED_INTERFACE_REMOVED,
f"{old_type.name} no longer implements interface {interface.name}.",
)
)
return schema_changes
def find_field_changes(
old_type: Union[GraphQLObjectType, GraphQLInterfaceType],
new_type: Union[GraphQLObjectType, GraphQLInterfaceType],
) -> List[Change]:
schema_changes: List[Change] = []
fields_diff = dict_diff(old_type.fields, new_type.fields)
for field_name in fields_diff.removed:
schema_changes.append(
BreakingChange(
BreakingChangeType.FIELD_REMOVED,
f"{old_type.name}.{field_name} was removed.",
)
)
for field_name, (old_field, new_field) in fields_diff.persisted.items():
schema_changes.extend(
find_arg_changes(old_type, field_name, old_field, new_field)
)
is_safe = is_change_safe_for_object_or_interface_field(
old_field.type, new_field.type
)
if not is_safe:
schema_changes.append(
BreakingChange(
BreakingChangeType.FIELD_CHANGED_KIND,
f"{old_type.name}.{field_name} changed type"
f" from {old_field.type} to {new_field.type}.",
)
)
return schema_changes
def find_arg_changes(
old_type: Union[GraphQLObjectType, GraphQLInterfaceType],
field_name: str,
old_field: GraphQLField,
new_field: GraphQLField,
) -> List[Change]:
schema_changes: List[Change] = []
args_diff = dict_diff(old_field.args, new_field.args)
for arg_name in args_diff.removed:
schema_changes.append(
BreakingChange(
BreakingChangeType.ARG_REMOVED,
f"{old_type.name}.{field_name} arg" f" {arg_name} was removed.",
)
)
for arg_name, (old_arg, new_arg) in args_diff.persisted.items():
is_safe = is_change_safe_for_input_object_field_or_field_arg(
old_arg.type, new_arg.type
)
if not is_safe:
schema_changes.append(
BreakingChange(
BreakingChangeType.ARG_CHANGED_KIND,
f"{old_type.name}.{field_name} arg"
f" {arg_name} has changed type from"
f" {old_arg.type} to {new_arg.type}.",
)
)
elif old_arg.default_value is not Undefined:
if new_arg.default_value is Undefined:
schema_changes.append(
DangerousChange(
DangerousChangeType.ARG_DEFAULT_VALUE_CHANGE,
f"{old_type.name}.{field_name} arg"
f" {arg_name} defaultValue was removed.",
)
)
else:
# Since we are looking only for client's observable changes we should
# compare default values in the same representation as they are
# represented inside introspection.
old_value_str = stringify_value(old_arg.default_value, old_arg.type)
new_value_str = stringify_value(new_arg.default_value, new_arg.type)
if old_value_str != new_value_str:
schema_changes.append(
DangerousChange(
DangerousChangeType.ARG_DEFAULT_VALUE_CHANGE,
f"{old_type.name}.{field_name} arg"
f" {arg_name} has changed defaultValue"
f" from {old_value_str} to {new_value_str}.",
)
)
for arg_name, new_arg in args_diff.added.items():
if is_required_argument(new_arg):
schema_changes.append(
BreakingChange(
BreakingChangeType.REQUIRED_ARG_ADDED,
f"A required arg {arg_name} on"
f" {old_type.name}.{field_name} was added.",
)
)
else:
schema_changes.append(
DangerousChange(
DangerousChangeType.OPTIONAL_ARG_ADDED,
f"An optional arg {arg_name} on"
f" {old_type.name}.{field_name} was added.",
)
)
return schema_changes
def is_change_safe_for_object_or_interface_field(
old_type: GraphQLType, new_type: GraphQLType
) -> bool:
if is_list_type(old_type):
return (
# if they're both lists, make sure underlying types are compatible
is_list_type(new_type)
and is_change_safe_for_object_or_interface_field(
cast(GraphQLList, old_type).of_type, cast(GraphQLList, new_type).of_type
)
) or (
# moving from nullable to non-null of same underlying type is safe
is_non_null_type(new_type)
and is_change_safe_for_object_or_interface_field(
old_type, cast(GraphQLNonNull, new_type).of_type
)
)
if is_non_null_type(old_type):
# if they're both non-null, make sure underlying types are compatible
return is_non_null_type(
new_type
) and is_change_safe_for_object_or_interface_field(
cast(GraphQLNonNull, old_type).of_type,
cast(GraphQLNonNull, new_type).of_type,
)
return (
# if they're both named types, see if their names are equivalent
is_named_type(new_type)
and cast(GraphQLNamedType, old_type).name
== cast(GraphQLNamedType, new_type).name
) or (
# moving from nullable to non-null of same underlying type is safe
is_non_null_type(new_type)
and is_change_safe_for_object_or_interface_field(
old_type, cast(GraphQLNonNull, new_type).of_type
)
)
def is_change_safe_for_input_object_field_or_field_arg(
old_type: GraphQLType, new_type: GraphQLType
) -> bool:
if is_list_type(old_type):
return is_list_type(
# if they're both lists, make sure underlying types are compatible
new_type
) and is_change_safe_for_input_object_field_or_field_arg(
cast(GraphQLList, old_type).of_type, cast(GraphQLList, new_type).of_type
)
if is_non_null_type(old_type):
return (
# if they're both non-null, make sure the underlying types are compatible
is_non_null_type(new_type)
and is_change_safe_for_input_object_field_or_field_arg(
cast(GraphQLNonNull, old_type).of_type,
cast(GraphQLNonNull, new_type).of_type,
)
) or (
# moving from non-null to nullable of same underlying type is safe
not is_non_null_type(new_type)
and is_change_safe_for_input_object_field_or_field_arg(
cast(GraphQLNonNull, old_type).of_type, new_type
)
)
return (
# if they're both named types, see if their names are equivalent
is_named_type(new_type)
and cast(GraphQLNamedType, old_type).name
== cast(GraphQLNamedType, new_type).name
)
def type_kind_name(type_: GraphQLNamedType) -> str:
if is_scalar_type(type_):
return "a Scalar type"
if is_object_type(type_):
return "an Object type"
if is_interface_type(type_):
return "an Interface type"
if is_union_type(type_):
return "a Union type"
if is_enum_type(type_):
return "an Enum type"
if is_input_object_type(type_):
return "an Input type"
# Not reachable. All possible output types have been considered.
raise TypeError(f"Unexpected type {inspect(type)}")
def stringify_value(value: Any, type_: GraphQLInputType) -> str:
ast = ast_from_value(value, type_)
if ast is None: # pragma: no cover
raise TypeError(f"Invalid value: {inspect(value)}")
return print_ast(sort_value_node(ast))
class ListDiff(NamedTuple):
"""Tuple with added, removed and persisted list items."""
added: List
removed: List
persisted: List
def list_diff(old_list: Collection, new_list: Collection) -> ListDiff:
"""Get differences between two lists of named items."""
added = []
persisted = []
removed = []
old_set = {item.name for item in old_list}
new_map = {item.name: item for item in new_list}
for old_item in old_list:
new_item = new_map.get(old_item.name)
if new_item:
persisted.append([old_item, new_item])
else:
removed.append(old_item)
for new_item in new_list:
if new_item.name not in old_set:
added.append(new_item)
return ListDiff(added, removed, persisted)
class DictDiff(NamedTuple):
"""Tuple with added, removed and persisted dict entries."""
added: Dict
removed: Dict
persisted: Dict
def dict_diff(old_dict: Dict, new_dict: Dict) -> DictDiff:
"""Get differences between two dicts."""
added = {}
removed = {}
persisted = {}
for old_name, old_item in old_dict.items():
new_item = new_dict.get(old_name)
if new_item:
persisted[old_name] = [old_item, new_item]
else:
removed[old_name] = old_item
for new_name, new_item in new_dict.items():
if new_name not in old_dict:
added[new_name] = new_item
return DictDiff(added, removed, persisted)
@@ -0,0 +1,298 @@
from textwrap import dedent
from typing import Any, Dict, List, Optional, Union
from ..language import DirectiveLocation
try:
from typing import Literal, TypedDict
except ImportError: # Python < 3.8
from typing_extensions import Literal, TypedDict # type: ignore
__all__ = [
"get_introspection_query",
"IntrospectionDirective",
"IntrospectionEnumType",
"IntrospectionField",
"IntrospectionInputObjectType",
"IntrospectionInputValue",
"IntrospectionInterfaceType",
"IntrospectionListType",
"IntrospectionNonNullType",
"IntrospectionObjectType",
"IntrospectionQuery",
"IntrospectionScalarType",
"IntrospectionSchema",
"IntrospectionType",
"IntrospectionTypeRef",
"IntrospectionUnionType",
]
def get_introspection_query(
descriptions: bool = True,
specified_by_url: bool = False,
directive_is_repeatable: bool = False,
schema_description: bool = False,
input_value_deprecation: bool = False,
) -> str:
"""Get a query for introspection.
Optionally, you can exclude descriptions, include specification URLs,
include repeatability of directives, and specify whether to include
the schema description as well.
"""
maybe_description = "description" if descriptions else ""
maybe_specified_by_url = "specifiedByURL" if specified_by_url else ""
maybe_directive_is_repeatable = "isRepeatable" if directive_is_repeatable else ""
maybe_schema_description = maybe_description if schema_description else ""
def input_deprecation(string: str) -> Optional[str]:
return string if input_value_deprecation else ""
return dedent(
f"""
query IntrospectionQuery {{
__schema {{
{maybe_schema_description}
queryType {{ name }}
mutationType {{ name }}
subscriptionType {{ name }}
types {{
...FullType
}}
directives {{
name
{maybe_description}
{maybe_directive_is_repeatable}
locations
args{input_deprecation("(includeDeprecated: true)")} {{
...InputValue
}}
}}
}}
}}
fragment FullType on __Type {{
kind
name
{maybe_description}
{maybe_specified_by_url}
fields(includeDeprecated: true) {{
name
{maybe_description}
args{input_deprecation("(includeDeprecated: true)")} {{
...InputValue
}}
type {{
...TypeRef
}}
isDeprecated
deprecationReason
}}
inputFields{input_deprecation("(includeDeprecated: true)")} {{
...InputValue
}}
interfaces {{
...TypeRef
}}
enumValues(includeDeprecated: true) {{
name
{maybe_description}
isDeprecated
deprecationReason
}}
possibleTypes {{
...TypeRef
}}
}}
fragment InputValue on __InputValue {{
name
{maybe_description}
type {{ ...TypeRef }}
defaultValue
{input_deprecation("isDeprecated")}
{input_deprecation("deprecationReason")}
}}
fragment TypeRef on __Type {{
kind
name
ofType {{
kind
name
ofType {{
kind
name
ofType {{
kind
name
ofType {{
kind
name
ofType {{
kind
name
ofType {{
kind
name
ofType {{
kind
name
ofType {{
kind
name
ofType {{
kind
name
}}
}}
}}
}}
}}
}}
}}
}}
}}
}}
"""
)
# Unfortunately, the following type definitions are a bit simplistic
# because of current restrictions in the typing system (mypy):
# - no recursion, see https://github.com/python/mypy/issues/731
# - no generic typed dicts, see https://github.com/python/mypy/issues/3863
# simplified IntrospectionNamedType to avoids cycles
SimpleIntrospectionType = Dict[str, Any]
class MaybeWithDescription(TypedDict, total=False):
description: Optional[str]
class WithName(MaybeWithDescription):
name: str
class MaybeWithSpecifiedByUrl(TypedDict, total=False):
specifiedByURL: Optional[str]
class WithDeprecated(TypedDict):
isDeprecated: bool
deprecationReason: Optional[str]
class MaybeWithDeprecated(TypedDict, total=False):
isDeprecated: bool
deprecationReason: Optional[str]
class IntrospectionInputValue(WithName, MaybeWithDeprecated):
type: SimpleIntrospectionType # should be IntrospectionInputType
defaultValue: Optional[str]
class IntrospectionField(WithName, WithDeprecated):
args: List[IntrospectionInputValue]
type: SimpleIntrospectionType # should be IntrospectionOutputType
class IntrospectionEnumValue(WithName, WithDeprecated):
pass
class MaybeWithIsRepeatable(TypedDict, total=False):
isRepeatable: bool
class IntrospectionDirective(WithName, MaybeWithIsRepeatable):
locations: List[DirectiveLocation]
args: List[IntrospectionInputValue]
class IntrospectionScalarType(WithName, MaybeWithSpecifiedByUrl):
kind: Literal["scalar"]
class IntrospectionInterfaceType(WithName):
kind: Literal["interface"]
fields: List[IntrospectionField]
interfaces: List[SimpleIntrospectionType] # should be InterfaceType
possibleTypes: List[SimpleIntrospectionType] # should be NamedType
class IntrospectionObjectType(WithName):
kind: Literal["object"]
fields: List[IntrospectionField]
interfaces: List[SimpleIntrospectionType] # should be InterfaceType
class IntrospectionUnionType(WithName):
kind: Literal["union"]
possibleTypes: List[SimpleIntrospectionType] # should be NamedType
class IntrospectionEnumType(WithName):
kind: Literal["enum"]
enumValues: List[IntrospectionEnumValue]
class IntrospectionInputObjectType(WithName):
kind: Literal["input_object"]
inputFields: List[IntrospectionInputValue]
IntrospectionType = Union[
IntrospectionScalarType,
IntrospectionObjectType,
IntrospectionInterfaceType,
IntrospectionUnionType,
IntrospectionEnumType,
IntrospectionInputObjectType,
]
IntrospectionOutputType = Union[
IntrospectionScalarType,
IntrospectionObjectType,
IntrospectionInterfaceType,
IntrospectionUnionType,
IntrospectionEnumType,
]
IntrospectionInputType = Union[
IntrospectionScalarType, IntrospectionEnumType, IntrospectionInputObjectType
]
class IntrospectionListType(TypedDict):
kind: Literal["list"]
ofType: SimpleIntrospectionType # should be IntrospectionType
class IntrospectionNonNullType(TypedDict):
kind: Literal["non_null"]
ofType: SimpleIntrospectionType # should be IntrospectionType
IntrospectionTypeRef = Union[
IntrospectionType, IntrospectionListType, IntrospectionNonNullType
]
class IntrospectionSchema(MaybeWithDescription):
queryType: IntrospectionObjectType
mutationType: Optional[IntrospectionObjectType]
subscriptionType: Optional[IntrospectionObjectType]
types: List[IntrospectionType]
directives: List[IntrospectionDirective]
class IntrospectionQuery(TypedDict):
"""The root typed dictionary for schema introspections."""
__schema: IntrospectionSchema
@@ -0,0 +1,29 @@
from typing import Optional
from ..language import DocumentNode, OperationDefinitionNode
__all__ = ["get_operation_ast"]
def get_operation_ast(
document_ast: DocumentNode, operation_name: Optional[str] = None
) -> Optional[OperationDefinitionNode]:
"""Get operation AST node.
Returns an operation AST given a document AST and optionally an operation
name. If a name is not provided, an operation is only returned if only one
is provided in the document.
"""
operation = None
for definition in document_ast.definitions:
if isinstance(definition, OperationDefinitionNode):
if operation_name is None:
# If no operation name was provided, only return an Operation if there
# is one defined in the document.
# Upon encountering the second, return None.
if operation:
return None
operation = definition
elif definition.name and definition.name.value == operation_name:
return definition
return operation
@@ -0,0 +1,46 @@
from typing import Union
from ..error import GraphQLError
from ..language import (
OperationType,
OperationDefinitionNode,
OperationTypeDefinitionNode,
)
from ..type import GraphQLObjectType, GraphQLSchema
__all__ = ["get_operation_root_type"]
def get_operation_root_type(
schema: GraphQLSchema,
operation: Union[OperationDefinitionNode, OperationTypeDefinitionNode],
) -> GraphQLObjectType:
"""Extract the root type of the operation from the schema.
.. deprecated:: 3.2
Please use `GraphQLSchema.getRootType` instead. Will be removed in v3.3.
"""
operation_type = operation.operation
if operation_type == OperationType.QUERY:
query_type = schema.query_type
if not query_type:
raise GraphQLError(
"Schema does not define the required query root type.", operation
)
return query_type
if operation_type == OperationType.MUTATION:
mutation_type = schema.mutation_type
if not mutation_type:
raise GraphQLError("Schema is not configured for mutations.", operation)
return mutation_type
if operation_type == OperationType.SUBSCRIPTION:
subscription_type = schema.subscription_type
if not subscription_type:
raise GraphQLError("Schema is not configured for subscriptions.", operation)
return subscription_type
raise GraphQLError(
"Can only have query, mutation and subscription operations.", operation
)
@@ -0,0 +1,46 @@
from typing import cast
from ..error import GraphQLError
from ..language import parse
from ..type import GraphQLSchema
from .get_introspection_query import get_introspection_query, IntrospectionQuery
__all__ = ["introspection_from_schema"]
def introspection_from_schema(
schema: GraphQLSchema,
descriptions: bool = True,
specified_by_url: bool = True,
directive_is_repeatable: bool = True,
schema_description: bool = True,
input_value_deprecation: bool = True,
) -> IntrospectionQuery:
"""Build an IntrospectionQuery from a GraphQLSchema
IntrospectionQuery is useful for utilities that care about type and field
relationships, but do not need to traverse through those relationships.
This is the inverse of build_client_schema. The primary use case is outside of the
server context, for instance when doing schema comparisons.
"""
document = parse(
get_introspection_query(
descriptions,
specified_by_url,
directive_is_repeatable,
schema_description,
input_value_deprecation,
)
)
from ..execution.execute import execute_sync, ExecutionResult
result = execute_sync(schema, document)
if not isinstance(result, ExecutionResult): # pragma: no cover
raise RuntimeError("Introspection cannot be executed")
if result.errors: # pragma: no cover
raise result.errors[0]
if not result.data: # pragma: no cover
raise GraphQLError("Introspection did not return a result")
return cast(IntrospectionQuery, result.data)
@@ -0,0 +1,189 @@
from typing import Collection, Dict, Optional, Tuple, Union, cast
from ..language import DirectiveLocation
from ..pyutils import inspect, merge_kwargs, natural_comparison_key
from ..type import (
GraphQLArgument,
GraphQLDirective,
GraphQLEnumType,
GraphQLEnumValue,
GraphQLField,
GraphQLInputField,
GraphQLInputObjectType,
GraphQLInputType,
GraphQLInterfaceType,
GraphQLList,
GraphQLNamedType,
GraphQLNonNull,
GraphQLObjectType,
GraphQLSchema,
GraphQLUnionType,
is_enum_type,
is_input_object_type,
is_interface_type,
is_introspection_type,
is_list_type,
is_non_null_type,
is_object_type,
is_scalar_type,
is_union_type,
)
__all__ = ["lexicographic_sort_schema"]
def lexicographic_sort_schema(schema: GraphQLSchema) -> GraphQLSchema:
"""Sort GraphQLSchema.
This function returns a sorted copy of the given GraphQLSchema.
"""
def replace_type(
type_: Union[GraphQLList, GraphQLNonNull, GraphQLNamedType]
) -> Union[GraphQLList, GraphQLNonNull, GraphQLNamedType]:
if is_list_type(type_):
return GraphQLList(replace_type(cast(GraphQLList, type_).of_type))
if is_non_null_type(type_):
return GraphQLNonNull(replace_type(cast(GraphQLNonNull, type_).of_type))
return replace_named_type(cast(GraphQLNamedType, type_))
def replace_named_type(type_: GraphQLNamedType) -> GraphQLNamedType:
return type_map[type_.name]
def replace_maybe_type(
maybe_type: Optional[GraphQLNamedType],
) -> Optional[GraphQLNamedType]:
return maybe_type and replace_named_type(maybe_type)
def sort_directive(directive: GraphQLDirective) -> GraphQLDirective:
return GraphQLDirective(
**merge_kwargs(
directive.to_kwargs(),
locations=sorted(directive.locations, key=sort_by_name_key),
args=sort_args(directive.args),
)
)
def sort_args(args_map: Dict[str, GraphQLArgument]) -> Dict[str, GraphQLArgument]:
args = {}
for name, arg in sorted(args_map.items()):
args[name] = GraphQLArgument(
**merge_kwargs(
arg.to_kwargs(),
type_=replace_type(cast(GraphQLNamedType, arg.type)),
)
)
return args
def sort_fields(fields_map: Dict[str, GraphQLField]) -> Dict[str, GraphQLField]:
fields = {}
for name, field in sorted(fields_map.items()):
fields[name] = GraphQLField(
**merge_kwargs(
field.to_kwargs(),
type_=replace_type(cast(GraphQLNamedType, field.type)),
args=sort_args(field.args),
)
)
return fields
def sort_input_fields(
fields_map: Dict[str, GraphQLInputField]
) -> Dict[str, GraphQLInputField]:
return {
name: GraphQLInputField(
cast(
GraphQLInputType, replace_type(cast(GraphQLNamedType, field.type))
),
description=field.description,
default_value=field.default_value,
ast_node=field.ast_node,
)
for name, field in sorted(fields_map.items())
}
def sort_types(array: Collection[GraphQLNamedType]) -> Tuple[GraphQLNamedType, ...]:
return tuple(
replace_named_type(type_) for type_ in sorted(array, key=sort_by_name_key)
)
def sort_named_type(type_: GraphQLNamedType) -> GraphQLNamedType:
if is_scalar_type(type_) or is_introspection_type(type_):
return type_
if is_object_type(type_):
type_ = cast(GraphQLObjectType, type_)
return GraphQLObjectType(
**merge_kwargs(
type_.to_kwargs(),
interfaces=lambda: sort_types(type_.interfaces),
fields=lambda: sort_fields(type_.fields),
)
)
if is_interface_type(type_):
type_ = cast(GraphQLInterfaceType, type_)
return GraphQLInterfaceType(
**merge_kwargs(
type_.to_kwargs(),
interfaces=lambda: sort_types(type_.interfaces),
fields=lambda: sort_fields(type_.fields),
)
)
if is_union_type(type_):
type_ = cast(GraphQLUnionType, type_)
return GraphQLUnionType(
**merge_kwargs(type_.to_kwargs(), types=lambda: sort_types(type_.types))
)
if is_enum_type(type_):
type_ = cast(GraphQLEnumType, type_)
return GraphQLEnumType(
**merge_kwargs(
type_.to_kwargs(),
values={
name: GraphQLEnumValue(
val.value,
description=val.description,
deprecation_reason=val.deprecation_reason,
ast_node=val.ast_node,
)
for name, val in sorted(type_.values.items())
},
)
)
if is_input_object_type(type_):
type_ = cast(GraphQLInputObjectType, type_)
return GraphQLInputObjectType(
**merge_kwargs(
type_.to_kwargs(),
fields=lambda: sort_input_fields(type_.fields),
)
)
# Not reachable. All possible types have been considered.
raise TypeError(f"Unexpected type: {inspect(type_)}.")
type_map: Dict[str, GraphQLNamedType] = {
type_.name: sort_named_type(type_)
for type_ in sorted(schema.type_map.values(), key=sort_by_name_key)
}
return GraphQLSchema(
types=type_map.values(),
directives=[
sort_directive(directive)
for directive in sorted(schema.directives, key=sort_by_name_key)
],
query=cast(Optional[GraphQLObjectType], replace_maybe_type(schema.query_type)),
mutation=cast(
Optional[GraphQLObjectType], replace_maybe_type(schema.mutation_type)
),
subscription=cast(
Optional[GraphQLObjectType], replace_maybe_type(schema.subscription_type)
),
ast_node=schema.ast_node,
)
def sort_by_name_key(
type_: Union[GraphQLNamedType, GraphQLDirective, DirectiveLocation]
) -> Tuple:
return natural_comparison_key(type_.name)
@@ -0,0 +1,298 @@
from typing import Any, Callable, Dict, List, Optional, Union, cast
from ..language import print_ast, StringValueNode
from ..language.block_string import is_printable_as_block_string
from ..pyutils import inspect
from ..type import (
DEFAULT_DEPRECATION_REASON,
GraphQLArgument,
GraphQLDirective,
GraphQLEnumType,
GraphQLEnumValue,
GraphQLInputObjectType,
GraphQLInputType,
GraphQLInterfaceType,
GraphQLNamedType,
GraphQLObjectType,
GraphQLScalarType,
GraphQLSchema,
GraphQLUnionType,
is_enum_type,
is_input_object_type,
is_interface_type,
is_introspection_type,
is_object_type,
is_scalar_type,
is_specified_directive,
is_specified_scalar_type,
is_union_type,
)
from .ast_from_value import ast_from_value
__all__ = ["print_schema", "print_introspection_schema", "print_type", "print_value"]
def print_schema(schema: GraphQLSchema) -> str:
return print_filtered_schema(
schema, lambda n: not is_specified_directive(n), is_defined_type
)
def print_introspection_schema(schema: GraphQLSchema) -> str:
return print_filtered_schema(schema, is_specified_directive, is_introspection_type)
def is_defined_type(type_: GraphQLNamedType) -> bool:
return not is_specified_scalar_type(type_) and not is_introspection_type(type_)
def print_filtered_schema(
schema: GraphQLSchema,
directive_filter: Callable[[GraphQLDirective], bool],
type_filter: Callable[[GraphQLNamedType], bool],
) -> str:
directives = filter(directive_filter, schema.directives)
types = filter(type_filter, schema.type_map.values())
return "\n\n".join(
(
*filter(None, (print_schema_definition(schema),)),
*map(print_directive, directives),
*map(print_type, types),
)
)
def print_schema_definition(schema: GraphQLSchema) -> Optional[str]:
if schema.description is None and is_schema_of_common_names(schema):
return None
operation_types = []
query_type = schema.query_type
if query_type:
operation_types.append(f" query: {query_type.name}")
mutation_type = schema.mutation_type
if mutation_type:
operation_types.append(f" mutation: {mutation_type.name}")
subscription_type = schema.subscription_type
if subscription_type:
operation_types.append(f" subscription: {subscription_type.name}")
return print_description(schema) + "schema {\n" + "\n".join(operation_types) + "\n}"
def is_schema_of_common_names(schema: GraphQLSchema) -> bool:
"""Check whether this schema uses the common naming convention.
GraphQL schema define root types for each type of operation. These types are the
same as any other type and can be named in any manner, however there is a common
naming convention:
schema {
query: Query
mutation: Mutation
subscription: Subscription
}
When using this naming convention, the schema description can be omitted.
"""
query_type = schema.query_type
if query_type and query_type.name != "Query":
return False
mutation_type = schema.mutation_type
if mutation_type and mutation_type.name != "Mutation":
return False
subscription_type = schema.subscription_type
return not subscription_type or subscription_type.name == "Subscription"
def print_type(type_: GraphQLNamedType) -> str:
if is_scalar_type(type_):
type_ = cast(GraphQLScalarType, type_)
return print_scalar(type_)
if is_object_type(type_):
type_ = cast(GraphQLObjectType, type_)
return print_object(type_)
if is_interface_type(type_):
type_ = cast(GraphQLInterfaceType, type_)
return print_interface(type_)
if is_union_type(type_):
type_ = cast(GraphQLUnionType, type_)
return print_union(type_)
if is_enum_type(type_):
type_ = cast(GraphQLEnumType, type_)
return print_enum(type_)
if is_input_object_type(type_):
type_ = cast(GraphQLInputObjectType, type_)
return print_input_object(type_)
# Not reachable. All possible types have been considered.
raise TypeError(f"Unexpected type: {inspect(type_)}.")
def print_scalar(type_: GraphQLScalarType) -> str:
return (
print_description(type_)
+ f"scalar {type_.name}"
+ print_specified_by_url(type_)
)
def print_implemented_interfaces(
type_: Union[GraphQLObjectType, GraphQLInterfaceType]
) -> str:
interfaces = type_.interfaces
return " implements " + " & ".join(i.name for i in interfaces) if interfaces else ""
def print_object(type_: GraphQLObjectType) -> str:
return (
print_description(type_)
+ f"type {type_.name}"
+ print_implemented_interfaces(type_)
+ print_fields(type_)
)
def print_interface(type_: GraphQLInterfaceType) -> str:
return (
print_description(type_)
+ f"interface {type_.name}"
+ print_implemented_interfaces(type_)
+ print_fields(type_)
)
def print_union(type_: GraphQLUnionType) -> str:
types = type_.types
possible_types = " = " + " | ".join(t.name for t in types) if types else ""
return print_description(type_) + f"union {type_.name}" + possible_types
def print_enum(type_: GraphQLEnumType) -> str:
values = [
print_description(value, " ", not i)
+ f" {name}"
+ print_deprecated(value.deprecation_reason)
for i, (name, value) in enumerate(type_.values.items())
]
return print_description(type_) + f"enum {type_.name}" + print_block(values)
def print_input_object(type_: GraphQLInputObjectType) -> str:
fields = [
print_description(field, " ", not i) + " " + print_input_value(name, field)
for i, (name, field) in enumerate(type_.fields.items())
]
return print_description(type_) + f"input {type_.name}" + print_block(fields)
def print_fields(type_: Union[GraphQLObjectType, GraphQLInterfaceType]) -> str:
fields = [
print_description(field, " ", not i)
+ f" {name}"
+ print_args(field.args, " ")
+ f": {field.type}"
+ print_deprecated(field.deprecation_reason)
for i, (name, field) in enumerate(type_.fields.items())
]
return print_block(fields)
def print_block(items: List[str]) -> str:
return " {\n" + "\n".join(items) + "\n}" if items else ""
def print_args(args: Dict[str, GraphQLArgument], indentation: str = "") -> str:
if not args:
return ""
# If every arg does not have a description, print them on one line.
if not any(arg.description for arg in args.values()):
return (
"("
+ ", ".join(print_input_value(name, arg) for name, arg in args.items())
+ ")"
)
return (
"(\n"
+ "\n".join(
print_description(arg, f" {indentation}", not i)
+ f" {indentation}"
+ print_input_value(name, arg)
for i, (name, arg) in enumerate(args.items())
)
+ f"\n{indentation})"
)
def print_input_value(name: str, arg: GraphQLArgument) -> str:
default_ast = ast_from_value(arg.default_value, arg.type)
arg_decl = f"{name}: {arg.type}"
if default_ast:
arg_decl += f" = {print_ast(default_ast)}"
return arg_decl + print_deprecated(arg.deprecation_reason)
def print_directive(directive: GraphQLDirective) -> str:
return (
print_description(directive)
+ f"directive @{directive.name}"
+ print_args(directive.args)
+ (" repeatable" if directive.is_repeatable else "")
+ " on "
+ " | ".join(location.name for location in directive.locations)
)
def print_deprecated(reason: Optional[str]) -> str:
if reason is None:
return ""
if reason != DEFAULT_DEPRECATION_REASON:
ast_value = print_ast(StringValueNode(value=reason))
return f" @deprecated(reason: {ast_value})"
return " @deprecated"
def print_specified_by_url(scalar: GraphQLScalarType) -> str:
if scalar.specified_by_url is None:
return ""
ast_value = print_ast(StringValueNode(value=scalar.specified_by_url))
return f" @specifiedBy(url: {ast_value})"
def print_description(
def_: Union[
GraphQLArgument,
GraphQLDirective,
GraphQLEnumValue,
GraphQLNamedType,
GraphQLSchema,
],
indentation: str = "",
first_in_block: bool = True,
) -> str:
description = def_.description
if description is None:
return ""
block_string = print_ast(
StringValueNode(
value=description, block=is_printable_as_block_string(description)
)
)
prefix = "\n" + indentation if indentation and not first_in_block else indentation
return prefix + block_string.replace("\n", "\n" + indentation) + "\n"
def print_value(value: Any, type_: GraphQLInputType) -> str:
"""@deprecated: Convenience function for printing a Python value"""
return print_ast(ast_from_value(value, type_)) # type: ignore
@@ -0,0 +1,101 @@
from typing import Any, Dict, List, Set
from ..language import (
DocumentNode,
FragmentDefinitionNode,
FragmentSpreadNode,
OperationDefinitionNode,
SelectionSetNode,
Visitor,
visit,
)
__all__ = ["separate_operations"]
DepGraph = Dict[str, List[str]]
def separate_operations(document_ast: DocumentNode) -> Dict[str, DocumentNode]:
"""Separate operations in a given AST document.
This function accepts a single AST document which may contain many operations and
fragments and returns a collection of AST documents each of which contains a single
operation as well the fragment definitions it refers to.
"""
operations: List[OperationDefinitionNode] = []
dep_graph: DepGraph = {}
# Populate metadata and build a dependency graph.
for definition_node in document_ast.definitions:
if isinstance(definition_node, OperationDefinitionNode):
operations.append(definition_node)
elif isinstance(
definition_node, FragmentDefinitionNode
): # pragma: no cover else
dep_graph[definition_node.name.value] = collect_dependencies(
definition_node.selection_set
)
# For each operation, produce a new synthesized AST which includes only what is
# necessary for completing that operation.
separated_document_asts: Dict[str, DocumentNode] = {}
for operation in operations:
dependencies: Set[str] = set()
for fragment_name in collect_dependencies(operation.selection_set):
collect_transitive_dependencies(dependencies, dep_graph, fragment_name)
# Provides the empty string for anonymous operations.
operation_name = operation.name.value if operation.name else ""
# The list of definition nodes to be included for this operation, sorted
# to retain the same order as the original document.
separated_document_asts[operation_name] = DocumentNode(
definitions=[
node
for node in document_ast.definitions
if node is operation
or (
isinstance(node, FragmentDefinitionNode)
and node.name.value in dependencies
)
]
)
return separated_document_asts
def collect_transitive_dependencies(
collected: Set[str], dep_graph: DepGraph, from_name: str
) -> None:
"""Collect transitive dependencies.
From a dependency graph, collects a list of transitive dependencies by recursing
through a dependency graph.
"""
if from_name not in collected:
collected.add(from_name)
immediate_deps = dep_graph.get(from_name)
if immediate_deps is not None:
for to_name in immediate_deps:
collect_transitive_dependencies(collected, dep_graph, to_name)
class DependencyCollector(Visitor):
dependencies: List[str]
def __init__(self) -> None:
super().__init__()
self.dependencies = []
self.add_dependency = self.dependencies.append
def enter_fragment_spread(self, node: FragmentSpreadNode, *_args: Any) -> None:
self.add_dependency(node.name.value)
def collect_dependencies(selection_set: SelectionSetNode) -> List[str]:
collector = DependencyCollector()
visit(selection_set, collector)
return collector.dependencies
@@ -0,0 +1,38 @@
from copy import copy
from typing import Tuple
from ..language import ListValueNode, ObjectFieldNode, ObjectValueNode, ValueNode
from ..pyutils import natural_comparison_key
__all__ = ["sort_value_node"]
def sort_value_node(value_node: ValueNode) -> ValueNode:
"""Sort ValueNode.
This function returns a sorted copy of the given ValueNode
For internal use only.
"""
if isinstance(value_node, ObjectValueNode):
value_node = copy(value_node)
value_node.fields = sort_fields(value_node.fields)
elif isinstance(value_node, ListValueNode):
value_node = copy(value_node)
value_node.values = tuple(sort_value_node(value) for value in value_node.values)
return value_node
def sort_field(field: ObjectFieldNode) -> ObjectFieldNode:
field = copy(field)
field.value = sort_value_node(field.value)
return field
def sort_fields(fields: Tuple[ObjectFieldNode, ...]) -> Tuple[ObjectFieldNode, ...]:
return tuple(
sorted(
(sort_field(field) for field in fields),
key=lambda field: natural_comparison_key(field.name.value),
)
)
@@ -0,0 +1,96 @@
from typing import Union, cast
from ..language import Lexer, TokenKind
from ..language.source import Source, is_source
from ..language.block_string import print_block_string
from ..language.lexer import is_punctuator_token_kind
__all__ = ["strip_ignored_characters"]
def strip_ignored_characters(source: Union[str, Source]) -> str:
"""Strip characters that are ignored anyway.
Strips characters that are not significant to the validity or execution
of a GraphQL document:
- UnicodeBOM
- WhiteSpace
- LineTerminator
- Comment
- Comma
- BlockString indentation
Note: It is required to have a delimiter character between neighboring
non-punctuator tokes and this function always uses single space as delimiter.
It is guaranteed that both input and output documents if parsed would result
in the exact same AST except for nodes location.
Warning: It is guaranteed that this function will always produce stable results.
However, it's not guaranteed that it will stay the same between different
releases due to bugfixes or changes in the GraphQL specification.
""" '''
Query example::
query SomeQuery($foo: String!, $bar: String) {
someField(foo: $foo, bar: $bar) {
a
b {
c
d
}
}
}
Becomes::
query SomeQuery($foo:String!$bar:String){someField(foo:$foo bar:$bar){a b{c d}}}
SDL example::
"""
Type description
"""
type Foo {
"""
Field description
"""
bar: String
}
Becomes::
"""Type description""" type Foo{"""Field description""" bar:String}
'''
source = cast(Source, source) if is_source(source) else Source(cast(str, source))
body = source.body
lexer = Lexer(source)
stripped_body = ""
was_last_added_token_non_punctuator = False
while lexer.advance().kind != TokenKind.EOF:
current_token = lexer.token
token_kind = current_token.kind
# Every two non-punctuator tokens should have space between them.
# Also prevent case of non-punctuator token following by spread resulting
# in invalid token (e.g.`1...` is invalid Float token).
is_non_punctuator = not is_punctuator_token_kind(current_token.kind)
if was_last_added_token_non_punctuator and (
is_non_punctuator or current_token.kind == TokenKind.SPREAD
):
stripped_body += " "
token_body = body[current_token.start : current_token.end]
if token_kind == TokenKind.BLOCK_STRING:
stripped_body += print_block_string(
current_token.value or "", minimize=True
)
else:
stripped_body += token_body
was_last_added_token_non_punctuator = is_non_punctuator
return stripped_body
@@ -0,0 +1,131 @@
from typing import cast
from ..type import (
GraphQLAbstractType,
GraphQLCompositeType,
GraphQLList,
GraphQLNonNull,
GraphQLObjectType,
GraphQLSchema,
GraphQLType,
is_abstract_type,
is_interface_type,
is_list_type,
is_non_null_type,
is_object_type,
)
__all__ = ["is_equal_type", "is_type_sub_type_of", "do_types_overlap"]
def is_equal_type(type_a: GraphQLType, type_b: GraphQLType) -> bool:
"""Check whether two types are equal.
Provided two types, return true if the types are equal (invariant)."""
# Equivalent types are equal.
if type_a is type_b:
return True
# If either type is non-null, the other must also be non-null.
if is_non_null_type(type_a) and is_non_null_type(type_b):
# noinspection PyUnresolvedReferences
return is_equal_type(type_a.of_type, type_b.of_type) # type:ignore
# If either type is a list, the other must also be a list.
if is_list_type(type_a) and is_list_type(type_b):
# noinspection PyUnresolvedReferences
return is_equal_type(type_a.of_type, type_b.of_type) # type:ignore
# Otherwise the types are not equal.
return False
def is_type_sub_type_of(
schema: GraphQLSchema, maybe_subtype: GraphQLType, super_type: GraphQLType
) -> bool:
"""Check whether a type is subtype of another type in a given schema.
Provided a type and a super type, return true if the first type is either equal or
a subset of the second super type (covariant).
"""
# Equivalent type is a valid subtype
if maybe_subtype is super_type:
return True
# If super_type is non-null, maybe_subtype must also be non-null.
if is_non_null_type(super_type):
if is_non_null_type(maybe_subtype):
return is_type_sub_type_of(
schema,
cast(GraphQLNonNull, maybe_subtype).of_type,
cast(GraphQLNonNull, super_type).of_type,
)
return False
elif is_non_null_type(maybe_subtype):
# If super_type is nullable, maybe_subtype may be non-null or nullable.
return is_type_sub_type_of(
schema, cast(GraphQLNonNull, maybe_subtype).of_type, super_type
)
# If super_type type is a list, maybeSubType type must also be a list.
if is_list_type(super_type):
if is_list_type(maybe_subtype):
return is_type_sub_type_of(
schema,
cast(GraphQLList, maybe_subtype).of_type,
cast(GraphQLList, super_type).of_type,
)
return False
elif is_list_type(maybe_subtype):
# If super_type is not a list, maybe_subtype must also be not a list.
return False
# If super_type type is abstract, check if it is super type of maybe_subtype.
# Otherwise, the child type is not a valid subtype of the parent type.
return (
is_abstract_type(super_type)
and (is_interface_type(maybe_subtype) or is_object_type(maybe_subtype))
and schema.is_sub_type(
cast(GraphQLAbstractType, super_type),
cast(GraphQLObjectType, maybe_subtype),
)
)
def do_types_overlap(
schema: GraphQLSchema, type_a: GraphQLCompositeType, type_b: GraphQLCompositeType
) -> bool:
"""Check whether two types overlap in a given schema.
Provided two composite types, determine if they "overlap". Two composite types
overlap when the Sets of possible concrete types for each intersect.
This is often used to determine if a fragment of a given type could possibly be
visited in a context of another type.
This function is commutative.
"""
# Equivalent types overlap
if type_a is type_b:
return True
if is_abstract_type(type_a):
type_a = cast(GraphQLAbstractType, type_a)
if is_abstract_type(type_b):
# If both types are abstract, then determine if there is any intersection
# between possible concrete types of each.
type_b = cast(GraphQLAbstractType, type_b)
return any(
schema.is_sub_type(type_b, type_)
for type_ in schema.get_possible_types(type_a)
)
# Determine if latter type is a possible concrete type of the former.
return schema.is_sub_type(type_a, type_b)
if is_abstract_type(type_b):
# Determine if former type is a possible concrete type of the latter.
type_b = cast(GraphQLAbstractType, type_b)
return schema.is_sub_type(type_b, type_a)
# Otherwise the types do not overlap.
return False
@@ -0,0 +1,65 @@
from typing import Optional, cast, overload
from ..language import ListTypeNode, NamedTypeNode, NonNullTypeNode, TypeNode
from ..pyutils import inspect
from ..type import (
GraphQLList,
GraphQLNamedType,
GraphQLNonNull,
GraphQLNullableType,
GraphQLSchema,
GraphQLType,
)
__all__ = ["type_from_ast"]
@overload
def type_from_ast(
schema: GraphQLSchema, type_node: NamedTypeNode
) -> Optional[GraphQLNamedType]: ...
@overload
def type_from_ast(
schema: GraphQLSchema, type_node: ListTypeNode
) -> Optional[GraphQLList]: ...
@overload
def type_from_ast(
schema: GraphQLSchema, type_node: NonNullTypeNode
) -> Optional[GraphQLNonNull]: ...
@overload
def type_from_ast(
schema: GraphQLSchema, type_node: TypeNode
) -> Optional[GraphQLType]: ...
def type_from_ast(
schema: GraphQLSchema,
type_node: TypeNode,
) -> Optional[GraphQLType]:
"""Get the GraphQL type definition from an AST node.
Given a Schema and an AST node describing a type, return a GraphQLType definition
which applies to that type. For example, if provided the parsed AST node for
``[User]``, a GraphQLList instance will be returned, containing the type called
"User" found in the schema. If a type called "User" is not found in the schema,
then None will be returned.
"""
inner_type: Optional[GraphQLType]
if isinstance(type_node, ListTypeNode):
inner_type = type_from_ast(schema, type_node.type)
return GraphQLList(inner_type) if inner_type else None
if isinstance(type_node, NonNullTypeNode):
inner_type = type_from_ast(schema, type_node.type)
inner_type = cast(GraphQLNullableType, inner_type)
return GraphQLNonNull(inner_type) if inner_type else None
if isinstance(type_node, NamedTypeNode):
return schema.get_type(type_node.name.value)
# Not reachable. All possible type nodes have been considered.
raise TypeError(f"Unexpected type node: {inspect(type_node)}.")
@@ -0,0 +1,321 @@
from typing import Any, Callable, List, Optional, Union, cast
from ..language import (
ArgumentNode,
DirectiveNode,
EnumValueNode,
FieldNode,
InlineFragmentNode,
ListValueNode,
Node,
ObjectFieldNode,
OperationDefinitionNode,
SelectionSetNode,
VariableDefinitionNode,
Visitor,
)
from ..pyutils import Undefined
from ..type import (
GraphQLArgument,
GraphQLCompositeType,
GraphQLDirective,
GraphQLEnumType,
GraphQLEnumValue,
GraphQLField,
GraphQLInputObjectType,
GraphQLInputType,
GraphQLInterfaceType,
GraphQLList,
GraphQLObjectType,
GraphQLOutputType,
GraphQLSchema,
GraphQLType,
is_composite_type,
is_input_type,
is_output_type,
get_named_type,
SchemaMetaFieldDef,
TypeMetaFieldDef,
TypeNameMetaFieldDef,
is_object_type,
is_interface_type,
get_nullable_type,
is_list_type,
is_input_object_type,
is_enum_type,
)
from .type_from_ast import type_from_ast
__all__ = ["TypeInfo", "TypeInfoVisitor"]
GetFieldDefFn = Callable[
[GraphQLSchema, GraphQLType, FieldNode], Optional[GraphQLField]
]
class TypeInfo:
"""Utility class for keeping track of type definitions.
TypeInfo is a utility class which, given a GraphQL schema, can keep track of the
current field and type definitions at any point in a GraphQL document AST during
a recursive descent by calling :meth:`enter(node) <.TypeInfo.enter>` and
:meth:`leave(node) <.TypeInfo.leave>`.
"""
def __init__(
self,
schema: GraphQLSchema,
initial_type: Optional[GraphQLType] = None,
get_field_def_fn: Optional[GetFieldDefFn] = None,
) -> None:
"""Initialize the TypeInfo for the given GraphQL schema.
Initial type may be provided in rare cases to facilitate traversals beginning
somewhere other than documents.
The optional last parameter is deprecated and will be removed in v3.3.
"""
self._schema = schema
self._type_stack: List[Optional[GraphQLOutputType]] = []
self._parent_type_stack: List[Optional[GraphQLCompositeType]] = []
self._input_type_stack: List[Optional[GraphQLInputType]] = []
self._field_def_stack: List[Optional[GraphQLField]] = []
self._default_value_stack: List[Any] = []
self._directive: Optional[GraphQLDirective] = None
self._argument: Optional[GraphQLArgument] = None
self._enum_value: Optional[GraphQLEnumValue] = None
self._get_field_def: GetFieldDefFn = get_field_def_fn or get_field_def
if initial_type:
if is_input_type(initial_type):
self._input_type_stack.append(cast(GraphQLInputType, initial_type))
if is_composite_type(initial_type):
self._parent_type_stack.append(cast(GraphQLCompositeType, initial_type))
if is_output_type(initial_type):
self._type_stack.append(cast(GraphQLOutputType, initial_type))
def get_type(self) -> Optional[GraphQLOutputType]:
if self._type_stack:
return self._type_stack[-1]
return None
def get_parent_type(self) -> Optional[GraphQLCompositeType]:
if self._parent_type_stack:
return self._parent_type_stack[-1]
return None
def get_input_type(self) -> Optional[GraphQLInputType]:
if self._input_type_stack:
return self._input_type_stack[-1]
return None
def get_parent_input_type(self) -> Optional[GraphQLInputType]:
if len(self._input_type_stack) > 1:
return self._input_type_stack[-2]
return None
def get_field_def(self) -> Optional[GraphQLField]:
if self._field_def_stack:
return self._field_def_stack[-1]
return None
def get_default_value(self) -> Any:
if self._default_value_stack:
return self._default_value_stack[-1]
return None
def get_directive(self) -> Optional[GraphQLDirective]:
return self._directive
def get_argument(self) -> Optional[GraphQLArgument]:
return self._argument
def get_enum_value(self) -> Optional[GraphQLEnumValue]:
return self._enum_value
def enter(self, node: Node) -> None:
method = getattr(self, "enter_" + node.kind, None)
if method:
method(node)
def leave(self, node: Node) -> None:
method = getattr(self, "leave_" + node.kind, None)
if method:
method()
# noinspection PyUnusedLocal
def enter_selection_set(self, node: SelectionSetNode) -> None:
named_type = get_named_type(self.get_type())
self._parent_type_stack.append(
cast(GraphQLCompositeType, named_type)
if is_composite_type(named_type)
else None
)
def enter_field(self, node: FieldNode) -> None:
parent_type = self.get_parent_type()
if parent_type:
field_def = self._get_field_def(self._schema, parent_type, node)
field_type = field_def.type if field_def else None
else:
field_def = field_type = None
self._field_def_stack.append(field_def)
self._type_stack.append(field_type if is_output_type(field_type) else None)
def enter_directive(self, node: DirectiveNode) -> None:
self._directive = self._schema.get_directive(node.name.value)
def enter_operation_definition(self, node: OperationDefinitionNode) -> None:
root_type = self._schema.get_root_type(node.operation)
self._type_stack.append(root_type if is_object_type(root_type) else None)
def enter_inline_fragment(self, node: InlineFragmentNode) -> None:
type_condition_ast = node.type_condition
output_type = (
type_from_ast(self._schema, type_condition_ast)
if type_condition_ast
else get_named_type(self.get_type())
)
self._type_stack.append(
cast(GraphQLOutputType, output_type)
if is_output_type(output_type)
else None
)
enter_fragment_definition = enter_inline_fragment
def enter_variable_definition(self, node: VariableDefinitionNode) -> None:
input_type = type_from_ast(self._schema, node.type)
self._input_type_stack.append(
cast(GraphQLInputType, input_type) if is_input_type(input_type) else None
)
def enter_argument(self, node: ArgumentNode) -> None:
field_or_directive = self.get_directive() or self.get_field_def()
if field_or_directive:
arg_def = field_or_directive.args.get(node.name.value)
arg_type = arg_def.type if arg_def else None
else:
arg_def = arg_type = None
self._argument = arg_def
self._default_value_stack.append(
arg_def.default_value if arg_def else Undefined
)
self._input_type_stack.append(arg_type if is_input_type(arg_type) else None)
# noinspection PyUnusedLocal
def enter_list_value(self, node: ListValueNode) -> None:
list_type = get_nullable_type(self.get_input_type()) # type: ignore
item_type = (
cast(GraphQLList, list_type).of_type
if is_list_type(list_type)
else list_type
)
# List positions never have a default value.
self._default_value_stack.append(Undefined)
self._input_type_stack.append(item_type if is_input_type(item_type) else None)
def enter_object_field(self, node: ObjectFieldNode) -> None:
object_type = get_named_type(self.get_input_type())
if is_input_object_type(object_type):
input_field = cast(GraphQLInputObjectType, object_type).fields.get(
node.name.value
)
input_field_type = input_field.type if input_field else None
else:
input_field = input_field_type = None
self._default_value_stack.append(
input_field.default_value if input_field else Undefined
)
self._input_type_stack.append(
input_field_type if is_input_type(input_field_type) else None
)
def enter_enum_value(self, node: EnumValueNode) -> None:
enum_type = get_named_type(self.get_input_type())
if is_enum_type(enum_type):
enum_value = cast(GraphQLEnumType, enum_type).values.get(node.value)
else:
enum_value = None
self._enum_value = enum_value
def leave_selection_set(self) -> None:
del self._parent_type_stack[-1:]
def leave_field(self) -> None:
del self._field_def_stack[-1:]
del self._type_stack[-1:]
def leave_directive(self) -> None:
self._directive = None
def leave_operation_definition(self) -> None:
del self._type_stack[-1:]
leave_inline_fragment = leave_operation_definition
leave_fragment_definition = leave_operation_definition
def leave_variable_definition(self) -> None:
del self._input_type_stack[-1:]
def leave_argument(self) -> None:
self._argument = None
del self._default_value_stack[-1:]
del self._input_type_stack[-1:]
def leave_list_value(self) -> None:
del self._default_value_stack[-1:]
del self._input_type_stack[-1:]
leave_object_field = leave_list_value
def leave_enum_value(self) -> None:
self._enum_value = None
def get_field_def(
schema: GraphQLSchema, parent_type: GraphQLType, field_node: FieldNode
) -> Optional[GraphQLField]:
"""Get field definition.
Not exactly the same as the executor's definition of
:func:`graphql.execution.get_field_def`, in this statically evaluated environment
we do not always have an Object type, and need to handle Interface and Union types.
"""
name = field_node.name.value
if name == "__schema" and schema.query_type is parent_type:
return SchemaMetaFieldDef
if name == "__type" and schema.query_type is parent_type:
return TypeMetaFieldDef
if name == "__typename" and is_composite_type(parent_type):
return TypeNameMetaFieldDef
if is_object_type(parent_type) or is_interface_type(parent_type):
parent_type = cast(Union[GraphQLObjectType, GraphQLInterfaceType], parent_type)
return parent_type.fields.get(name)
return None
class TypeInfoVisitor(Visitor):
"""A visitor which maintains a provided TypeInfo."""
def __init__(self, type_info: "TypeInfo", visitor: Visitor):
super().__init__()
self.type_info = type_info
self.visitor = visitor
def enter(self, node: Node, *args: Any) -> Any:
self.type_info.enter(node)
fn = self.visitor.get_enter_leave_for_kind(node.kind).enter
if fn:
result = fn(node, *args)
if result is not None:
self.type_info.leave(node)
if isinstance(result, Node):
self.type_info.enter(result)
return result
def leave(self, node: Node, *args: Any) -> Any:
fn = self.visitor.get_enter_leave_for_kind(node.kind).leave
result = fn(node, *args) if fn else None
self.type_info.leave(node)
return result
@@ -0,0 +1,149 @@
from typing import Any, Dict, List, Optional, cast
from ..language import (
ListValueNode,
NullValueNode,
ObjectValueNode,
ValueNode,
VariableNode,
)
from ..pyutils import inspect, Undefined
from ..type import (
GraphQLInputObjectType,
GraphQLInputType,
GraphQLList,
GraphQLNonNull,
GraphQLScalarType,
is_input_object_type,
is_leaf_type,
is_list_type,
is_non_null_type,
)
__all__ = ["value_from_ast"]
def value_from_ast(
value_node: Optional[ValueNode],
type_: GraphQLInputType,
variables: Optional[Dict[str, Any]] = None,
) -> Any:
"""Produce a Python value given a GraphQL Value AST.
A GraphQL type must be provided, which will be used to interpret different GraphQL
Value literals.
Returns ``Undefined`` when the value could not be validly coerced according
to the provided type.
=================== ============== ================
GraphQL Value JSON Value Python Value
=================== ============== ================
Input Object Object dict
List Array list
Boolean Boolean bool
String String str
Int / Float Number int / float
Enum Value Mixed Any
NullValue null None
=================== ============== ================
"""
if not value_node:
# When there is no node, then there is also no value.
# Importantly, this is different from returning the value null.
return Undefined
if isinstance(value_node, VariableNode):
variable_name = value_node.name.value
if not variables:
return Undefined
variable_value = variables.get(variable_name, Undefined)
if variable_value is None and is_non_null_type(type_):
return Undefined
# Note: This does no further checking that this variable is correct.
# This assumes that this query has been validated and the variable usage here
# is of the correct type.
return variable_value
if is_non_null_type(type_):
if isinstance(value_node, NullValueNode):
return Undefined
type_ = cast(GraphQLNonNull, type_)
return value_from_ast(value_node, type_.of_type, variables)
if isinstance(value_node, NullValueNode):
return None # This is explicitly returning the value None.
if is_list_type(type_):
type_ = cast(GraphQLList, type_)
item_type = type_.of_type
if isinstance(value_node, ListValueNode):
coerced_values: List[Any] = []
append_value = coerced_values.append
for item_node in value_node.values:
if is_missing_variable(item_node, variables):
# If an array contains a missing variable, it is either coerced to
# None or if the item type is non-null, it is considered invalid.
if is_non_null_type(item_type):
return Undefined
append_value(None)
else:
item_value = value_from_ast(item_node, item_type, variables)
if item_value is Undefined:
return Undefined
append_value(item_value)
return coerced_values
coerced_value = value_from_ast(value_node, item_type, variables)
if coerced_value is Undefined:
return Undefined
return [coerced_value]
if is_input_object_type(type_):
if not isinstance(value_node, ObjectValueNode):
return Undefined
type_ = cast(GraphQLInputObjectType, type_)
coerced_obj: Dict[str, Any] = {}
fields = type_.fields
field_nodes = {field.name.value: field for field in value_node.fields}
for field_name, field in fields.items():
field_node = field_nodes.get(field_name)
if not field_node or is_missing_variable(field_node.value, variables):
if field.default_value is not Undefined:
# Use out name as name if it exists (extension of GraphQL.js).
coerced_obj[field.out_name or field_name] = field.default_value
elif is_non_null_type(field.type): # pragma: no cover else
return Undefined
continue
field_value = value_from_ast(field_node.value, field.type, variables)
if field_value is Undefined:
return Undefined
coerced_obj[field.out_name or field_name] = field_value
return type_.out_type(coerced_obj)
if is_leaf_type(type_):
# Scalars fulfill parsing a literal value via `parse_literal()`. Invalid values
# represent a failure to parse correctly, in which case Undefined is returned.
type_ = cast(GraphQLScalarType, type_)
# noinspection PyBroadException
try:
if variables:
result = type_.parse_literal(value_node, variables)
else:
result = type_.parse_literal(value_node)
except Exception:
return Undefined
return result
# Not reachable. All possible input types have been considered.
raise TypeError(f"Unexpected input type: {inspect(type_)}.")
def is_missing_variable(
value_node: ValueNode, variables: Optional[Dict[str, Any]] = None
) -> bool:
"""Check if ``value_node`` is a variable not defined in the ``variables`` dict."""
return isinstance(value_node, VariableNode) and (
not variables or variables.get(value_node.name.value, Undefined) is Undefined
)
@@ -0,0 +1,110 @@
from math import nan
from typing import Any, Callable, Dict, Optional, Union
from ..language import (
ValueNode,
BooleanValueNode,
EnumValueNode,
FloatValueNode,
IntValueNode,
ListValueNode,
NullValueNode,
ObjectValueNode,
StringValueNode,
VariableNode,
)
from ..pyutils import inspect, Undefined
__all__ = ["value_from_ast_untyped"]
def value_from_ast_untyped(
value_node: ValueNode, variables: Optional[Dict[str, Any]] = None
) -> Any:
"""Produce a Python value given a GraphQL Value AST.
Unlike :func:`~graphql.utilities.value_from_ast`, no type is provided.
The resulting Python value will reflect the provided GraphQL value AST.
=================== ============== ================
GraphQL Value JSON Value Python Value
=================== ============== ================
Input Object Object dict
List Array list
Boolean Boolean bool
String / Enum String str
Int / Float Number int / float
Null null None
=================== ============== ================
"""
func = _value_from_kind_functions.get(value_node.kind)
if func:
return func(value_node, variables)
# Not reachable. All possible value nodes have been considered.
raise TypeError( # pragma: no cover
f"Unexpected value node: {inspect(value_node)}."
)
def value_from_null(_value_node: NullValueNode, _variables: Any) -> Any:
return None
def value_from_int(value_node: IntValueNode, _variables: Any) -> Any:
try:
return int(value_node.value)
except ValueError:
return nan
def value_from_float(value_node: FloatValueNode, _variables: Any) -> Any:
try:
return float(value_node.value)
except ValueError:
return nan
def value_from_string(
value_node: Union[BooleanValueNode, EnumValueNode, StringValueNode], _variables: Any
) -> Any:
return value_node.value
def value_from_list(
value_node: ListValueNode, variables: Optional[Dict[str, Any]]
) -> Any:
return [value_from_ast_untyped(node, variables) for node in value_node.values]
def value_from_object(
value_node: ObjectValueNode, variables: Optional[Dict[str, Any]]
) -> Any:
return {
field.name.value: value_from_ast_untyped(field.value, variables)
for field in value_node.fields
}
def value_from_variable(
value_node: VariableNode, variables: Optional[Dict[str, Any]]
) -> Any:
variable_name = value_node.name.value
if not variables:
return Undefined
return variables.get(variable_name, Undefined)
_value_from_kind_functions: Dict[str, Callable] = {
"null_value": value_from_null,
"int_value": value_from_int,
"float_value": value_from_float,
"string_value": value_from_string,
"enum_value": value_from_string,
"boolean_value": value_from_string,
"list_value": value_from_list,
"object_value": value_from_object,
"variable": value_from_variable,
}
@@ -0,0 +1,157 @@
"""GraphQL Validation
The :mod:`graphql.validation` package fulfills the Validation phase of fulfilling a
GraphQL result.
"""
from .validate import validate
from .validation_context import (
ASTValidationContext,
SDLValidationContext,
ValidationContext,
)
from .rules import ValidationRule, ASTValidationRule, SDLValidationRule
# All validation rules in the GraphQL Specification.
from .specified_rules import specified_rules
# Spec Section: "Executable Definitions"
from .rules.executable_definitions import ExecutableDefinitionsRule
# Spec Section: "Field Selections on Objects, Interfaces, and Unions Types"
from .rules.fields_on_correct_type import FieldsOnCorrectTypeRule
# Spec Section: "Fragments on Composite Types"
from .rules.fragments_on_composite_types import FragmentsOnCompositeTypesRule
# Spec Section: "Argument Names"
from .rules.known_argument_names import KnownArgumentNamesRule
# Spec Section: "Directives Are Defined"
from .rules.known_directives import KnownDirectivesRule
# Spec Section: "Fragment spread target defined"
from .rules.known_fragment_names import KnownFragmentNamesRule
# Spec Section: "Fragment Spread Type Existence"
from .rules.known_type_names import KnownTypeNamesRule
# Spec Section: "Lone Anonymous Operation"
from .rules.lone_anonymous_operation import LoneAnonymousOperationRule
# Spec Section: "Fragments must not form cycles"
from .rules.no_fragment_cycles import NoFragmentCyclesRule
# Spec Section: "All Variable Used Defined"
from .rules.no_undefined_variables import NoUndefinedVariablesRule
# Spec Section: "Fragments must be used"
from .rules.no_unused_fragments import NoUnusedFragmentsRule
# Spec Section: "All Variables Used"
from .rules.no_unused_variables import NoUnusedVariablesRule
# Spec Section: "Field Selection Merging"
from .rules.overlapping_fields_can_be_merged import OverlappingFieldsCanBeMergedRule
# Spec Section: "Fragment spread is possible"
from .rules.possible_fragment_spreads import PossibleFragmentSpreadsRule
# Spec Section: "Argument Optionality"
from .rules.provided_required_arguments import ProvidedRequiredArgumentsRule
# Spec Section: "Leaf Field Selections"
from .rules.scalar_leafs import ScalarLeafsRule
# Spec Section: "Subscriptions with Single Root Field"
from .rules.single_field_subscriptions import SingleFieldSubscriptionsRule
# Spec Section: "Argument Uniqueness"
from .rules.unique_argument_names import UniqueArgumentNamesRule
# Spec Section: "Directives Are Unique Per Location"
from .rules.unique_directives_per_location import UniqueDirectivesPerLocationRule
# Spec Section: "Fragment Name Uniqueness"
from .rules.unique_fragment_names import UniqueFragmentNamesRule
# Spec Section: "Input Object Field Uniqueness"
from .rules.unique_input_field_names import UniqueInputFieldNamesRule
# Spec Section: "Operation Name Uniqueness"
from .rules.unique_operation_names import UniqueOperationNamesRule
# Spec Section: "Variable Uniqueness"
from .rules.unique_variable_names import UniqueVariableNamesRule
# Spec Section: "Value Type Correctness"
from .rules.values_of_correct_type import ValuesOfCorrectTypeRule
# Spec Section: "Variables are Input Types"
from .rules.variables_are_input_types import VariablesAreInputTypesRule
# Spec Section: "All Variable Usages Are Allowed"
from .rules.variables_in_allowed_position import VariablesInAllowedPositionRule
# SDL-specific validation rules
from .rules.lone_schema_definition import LoneSchemaDefinitionRule
from .rules.unique_operation_types import UniqueOperationTypesRule
from .rules.unique_type_names import UniqueTypeNamesRule
from .rules.unique_enum_value_names import UniqueEnumValueNamesRule
from .rules.unique_field_definition_names import UniqueFieldDefinitionNamesRule
from .rules.unique_argument_definition_names import UniqueArgumentDefinitionNamesRule
from .rules.unique_directive_names import UniqueDirectiveNamesRule
from .rules.possible_type_extensions import PossibleTypeExtensionsRule
# Optional rules not defined by the GraphQL Specification
from .rules.custom.no_deprecated import NoDeprecatedCustomRule
from .rules.custom.no_schema_introspection import NoSchemaIntrospectionCustomRule
__all__ = [
"validate",
"ASTValidationContext",
"ASTValidationRule",
"SDLValidationContext",
"SDLValidationRule",
"ValidationContext",
"ValidationRule",
"specified_rules",
"ExecutableDefinitionsRule",
"FieldsOnCorrectTypeRule",
"FragmentsOnCompositeTypesRule",
"KnownArgumentNamesRule",
"KnownDirectivesRule",
"KnownFragmentNamesRule",
"KnownTypeNamesRule",
"LoneAnonymousOperationRule",
"NoFragmentCyclesRule",
"NoUndefinedVariablesRule",
"NoUnusedFragmentsRule",
"NoUnusedVariablesRule",
"OverlappingFieldsCanBeMergedRule",
"PossibleFragmentSpreadsRule",
"ProvidedRequiredArgumentsRule",
"ScalarLeafsRule",
"SingleFieldSubscriptionsRule",
"UniqueArgumentNamesRule",
"UniqueDirectivesPerLocationRule",
"UniqueFragmentNamesRule",
"UniqueInputFieldNamesRule",
"UniqueOperationNamesRule",
"UniqueVariableNamesRule",
"ValuesOfCorrectTypeRule",
"VariablesAreInputTypesRule",
"VariablesInAllowedPositionRule",
"LoneSchemaDefinitionRule",
"UniqueOperationTypesRule",
"UniqueTypeNamesRule",
"UniqueEnumValueNamesRule",
"UniqueFieldDefinitionNamesRule",
"UniqueArgumentDefinitionNamesRule",
"UniqueDirectiveNamesRule",
"PossibleTypeExtensionsRule",
"NoDeprecatedCustomRule",
"NoSchemaIntrospectionCustomRule",
]
@@ -0,0 +1,42 @@
"""graphql.validation.rules package"""
from ...error import GraphQLError
from ...language.visitor import Visitor
from ..validation_context import (
ASTValidationContext,
SDLValidationContext,
ValidationContext,
)
__all__ = ["ASTValidationRule", "SDLValidationRule", "ValidationRule"]
class ASTValidationRule(Visitor):
"""Visitor for validation of an AST."""
context: ASTValidationContext
def __init__(self, context: ASTValidationContext):
super().__init__()
self.context = context
def report_error(self, error: GraphQLError) -> None:
self.context.report_error(error)
class SDLValidationRule(ASTValidationRule):
"""Visitor for validation of an SDL AST."""
context: SDLValidationContext
def __init__(self, context: SDLValidationContext) -> None:
super().__init__(context)
class ValidationRule(ASTValidationRule):
"""Visitor for validation using a GraphQL schema."""
context: ValidationContext
def __init__(self, context: ValidationContext) -> None:
super().__init__(context)
@@ -0,0 +1 @@
"""graphql.validation.rules.custom package"""
@@ -0,0 +1,101 @@
from typing import Any, cast
from ....error import GraphQLError
from ....language import ArgumentNode, EnumValueNode, FieldNode, ObjectFieldNode
from ....type import GraphQLInputObjectType, get_named_type, is_input_object_type
from .. import ValidationRule
__all__ = ["NoDeprecatedCustomRule"]
class NoDeprecatedCustomRule(ValidationRule):
"""No deprecated
A GraphQL document is only valid if all selected fields and all used enum values
have not been deprecated.
Note: This rule is optional and is not part of the Validation section of the GraphQL
Specification. The main purpose of this rule is detection of deprecated usages and
not necessarily to forbid their use when querying a service.
"""
def enter_field(self, node: FieldNode, *_args: Any) -> None:
context = self.context
field_def = context.get_field_def()
if field_def:
deprecation_reason = field_def.deprecation_reason
if deprecation_reason is not None:
parent_type = context.get_parent_type()
parent_name = parent_type.name # type: ignore
self.report_error(
GraphQLError(
f"The field {parent_name}.{node.name.value}"
f" is deprecated. {deprecation_reason}",
node,
)
)
def enter_argument(self, node: ArgumentNode, *_args: Any) -> None:
context = self.context
arg_def = context.get_argument()
if arg_def:
deprecation_reason = arg_def.deprecation_reason
if deprecation_reason is not None:
directive_def = context.get_directive()
arg_name = node.name.value
if directive_def is None:
parent_type = context.get_parent_type()
parent_name = parent_type.name # type: ignore
field_def = context.get_field_def()
field_name = field_def.ast_node.name.value # type: ignore
self.report_error(
GraphQLError(
f"Field '{parent_name}.{field_name}' argument"
f" '{arg_name}' is deprecated. {deprecation_reason}",
node,
)
)
else:
self.report_error(
GraphQLError(
f"Directive '@{directive_def.name}' argument"
f" '{arg_name}' is deprecated. {deprecation_reason}",
node,
)
)
def enter_object_field(self, node: ObjectFieldNode, *_args: Any) -> None:
context = self.context
input_object_def = get_named_type(context.get_parent_input_type())
if is_input_object_type(input_object_def):
input_field_def = cast(GraphQLInputObjectType, input_object_def).fields.get(
node.name.value
)
if input_field_def:
deprecation_reason = input_field_def.deprecation_reason
if deprecation_reason is not None:
field_name = node.name.value
input_object_name = input_object_def.name # type: ignore
self.report_error(
GraphQLError(
f"The input field {input_object_name}.{field_name}"
f" is deprecated. {deprecation_reason}",
node,
)
)
def enter_enum_value(self, node: EnumValueNode, *_args: Any) -> None:
context = self.context
enum_value_def = context.get_enum_value()
if enum_value_def:
deprecation_reason = enum_value_def.deprecation_reason
if deprecation_reason is not None: # pragma: no cover else
enum_type_def = get_named_type(context.get_input_type())
enum_type_name = enum_type_def.name # type: ignore
self.report_error(
GraphQLError(
f"The enum value '{enum_type_name}.{node.value}'"
f" is deprecated. {deprecation_reason}",
node,
)
)
@@ -0,0 +1,31 @@
from typing import Any
from ....error import GraphQLError
from ....language import FieldNode
from ....type import get_named_type, is_introspection_type
from .. import ValidationRule
__all__ = ["NoSchemaIntrospectionCustomRule"]
class NoSchemaIntrospectionCustomRule(ValidationRule):
"""Prohibit introspection queries
A GraphQL document is only valid if all fields selected are not fields that
return an introspection type.
Note: This rule is optional and is not part of the Validation section of the
GraphQL Specification. This rule effectively disables introspection, which
does not reflect best practices and should only be done if absolutely necessary.
"""
def enter_field(self, node: FieldNode, *_args: Any) -> None:
type_ = get_named_type(self.context.get_type())
if type_ and is_introspection_type(type_):
self.report_error(
GraphQLError(
"GraphQL introspection has been disabled, but the requested query"
f" contained the field '{node.name.value}'.",
node,
)
)
@@ -0,0 +1,49 @@
from typing import Any, Union, cast
from ...error import GraphQLError
from ...language import (
DirectiveDefinitionNode,
DocumentNode,
ExecutableDefinitionNode,
SchemaDefinitionNode,
SchemaExtensionNode,
TypeDefinitionNode,
VisitorAction,
SKIP,
)
from . import ASTValidationRule
__all__ = ["ExecutableDefinitionsRule"]
class ExecutableDefinitionsRule(ASTValidationRule):
"""Executable definitions
A GraphQL document is only valid for execution if all definitions are either
operation or fragment definitions.
See https://spec.graphql.org/draft/#sec-Executable-Definitions
"""
def enter_document(self, node: DocumentNode, *_args: Any) -> VisitorAction:
for definition in node.definitions:
if not isinstance(definition, ExecutableDefinitionNode):
def_name = (
"schema"
if isinstance(
definition, (SchemaDefinitionNode, SchemaExtensionNode)
)
else "'{}'".format(
cast(
Union[DirectiveDefinitionNode, TypeDefinitionNode],
definition,
).name.value
)
)
self.report_error(
GraphQLError(
f"The {def_name} definition is not executable.",
definition,
)
)
return SKIP
@@ -0,0 +1,136 @@
from collections import defaultdict
from functools import cmp_to_key
from typing import Any, Dict, List, Union, cast
from ...type import (
GraphQLAbstractType,
GraphQLInterfaceType,
GraphQLObjectType,
GraphQLOutputType,
GraphQLSchema,
is_abstract_type,
is_interface_type,
is_object_type,
)
from ...error import GraphQLError
from ...language import FieldNode
from ...pyutils import did_you_mean, natural_comparison_key, suggestion_list
from . import ValidationRule
__all__ = ["FieldsOnCorrectTypeRule"]
class FieldsOnCorrectTypeRule(ValidationRule):
"""Fields on correct type
A GraphQL document is only valid if all fields selected are defined by the parent
type, or are an allowed meta field such as ``__typename``.
See https://spec.graphql.org/draft/#sec-Field-Selections
"""
def enter_field(self, node: FieldNode, *_args: Any) -> None:
type_ = self.context.get_parent_type()
if not type_:
return
field_def = self.context.get_field_def()
if field_def:
return
# This field doesn't exist, lets look for suggestions.
schema = self.context.schema
field_name = node.name.value
# First determine if there are any suggested types to condition on.
suggestion = did_you_mean(
get_suggested_type_names(schema, type_, field_name),
"to use an inline fragment on",
)
# If there are no suggested types, then perhaps this was a typo?
if not suggestion:
suggestion = did_you_mean(get_suggested_field_names(type_, field_name))
# Report an error, including helpful suggestions.
self.report_error(
GraphQLError(
f"Cannot query field '{field_name}' on type '{type_}'." + suggestion,
node,
)
)
def get_suggested_type_names(
schema: GraphQLSchema, type_: GraphQLOutputType, field_name: str
) -> List[str]:
"""
Get a list of suggested type names.
Go through all of the implementations of type, as well as the interfaces
that they implement. If any of those types include the provided field,
suggest them, sorted by how often the type is referenced.
"""
if not is_abstract_type(type_):
# Must be an Object type, which does not have possible fields.
return []
type_ = cast(GraphQLAbstractType, type_)
# Use a dict instead of a set for stable sorting when usage counts are the same
suggested_types: Dict[Union[GraphQLObjectType, GraphQLInterfaceType], None] = {}
usage_count: Dict[str, int] = defaultdict(int)
for possible_type in schema.get_possible_types(type_):
if field_name not in possible_type.fields:
continue
# This object type defines this field.
suggested_types[possible_type] = None
usage_count[possible_type.name] = 1
for possible_interface in possible_type.interfaces:
if field_name not in possible_interface.fields:
continue
# This interface type defines this field.
suggested_types[possible_interface] = None
usage_count[possible_interface.name] += 1
def cmp(
type_a: Union[GraphQLObjectType, GraphQLInterfaceType],
type_b: Union[GraphQLObjectType, GraphQLInterfaceType],
) -> int: # pragma: no cover
# Suggest both interface and object types based on how common they are.
usage_count_diff = usage_count[type_b.name] - usage_count[type_a.name]
if usage_count_diff:
return usage_count_diff
# Suggest super types first followed by subtypes
if is_interface_type(type_a) and schema.is_sub_type(
cast(GraphQLInterfaceType, type_a), type_b
):
return -1
if is_interface_type(type_b) and schema.is_sub_type(
cast(GraphQLInterfaceType, type_b), type_a
):
return 1
name_a = natural_comparison_key(type_a.name)
name_b = natural_comparison_key(type_b.name)
if name_a > name_b:
return 1
if name_a < name_b:
return -1
return 0
return [type_.name for type_ in sorted(suggested_types, key=cmp_to_key(cmp))]
def get_suggested_field_names(type_: GraphQLOutputType, field_name: str) -> List[str]:
"""Get a list of suggested field names.
For the field name provided, determine if there are any similar field names that may
be the result of a typo.
"""
if is_object_type(type_) or is_interface_type(type_):
possible_field_names = list(type_.fields) # type: ignore
return suggestion_list(field_name, possible_field_names)
# Otherwise, must be a Union type, which does not define fields.
return []
@@ -0,0 +1,53 @@
from typing import Any
from ...error import GraphQLError
from ...language import (
FragmentDefinitionNode,
InlineFragmentNode,
print_ast,
)
from ...type import is_composite_type
from ...utilities import type_from_ast
from . import ValidationRule
__all__ = ["FragmentsOnCompositeTypesRule"]
class FragmentsOnCompositeTypesRule(ValidationRule):
"""Fragments on composite type
Fragments use a type condition to determine if they apply, since fragments can only
be spread into a composite type (object, interface, or union), the type condition
must also be a composite type.
See https://spec.graphql.org/draft/#sec-Fragments-On-Composite-Types
"""
def enter_inline_fragment(self, node: InlineFragmentNode, *_args: Any) -> None:
type_condition = node.type_condition
if type_condition:
type_ = type_from_ast(self.context.schema, type_condition)
if type_ and not is_composite_type(type_):
type_str = print_ast(type_condition)
self.report_error(
GraphQLError(
"Fragment cannot condition"
f" on non composite type '{type_str}'.",
type_condition,
)
)
def enter_fragment_definition(
self, node: FragmentDefinitionNode, *_args: Any
) -> None:
type_condition = node.type_condition
type_ = type_from_ast(self.context.schema, type_condition)
if type_ and not is_composite_type(type_):
type_str = print_ast(type_condition)
self.report_error(
GraphQLError(
f"Fragment '{node.name.value}' cannot condition"
f" on non composite type '{type_str}'.",
type_condition,
)
)
@@ -0,0 +1,98 @@
from typing import cast, Any, Dict, List, Union
from ...error import GraphQLError
from ...language import (
ArgumentNode,
DirectiveDefinitionNode,
DirectiveNode,
SKIP,
VisitorAction,
)
from ...pyutils import did_you_mean, suggestion_list
from ...type import specified_directives
from . import ASTValidationRule, SDLValidationContext, ValidationContext
__all__ = ["KnownArgumentNamesRule", "KnownArgumentNamesOnDirectivesRule"]
class KnownArgumentNamesOnDirectivesRule(ASTValidationRule):
"""Known argument names on directives
A GraphQL directive is only valid if all supplied arguments are defined.
For internal use only.
"""
context: Union[ValidationContext, SDLValidationContext]
def __init__(self, context: Union[ValidationContext, SDLValidationContext]):
super().__init__(context)
directive_args: Dict[str, List[str]] = {}
schema = context.schema
defined_directives = schema.directives if schema else specified_directives
for directive in cast(List, defined_directives):
directive_args[directive.name] = list(directive.args)
ast_definitions = context.document.definitions
for def_ in ast_definitions:
if isinstance(def_, DirectiveDefinitionNode):
directive_args[def_.name.value] = [
arg.name.value for arg in def_.arguments or []
]
self.directive_args = directive_args
def enter_directive(
self, directive_node: DirectiveNode, *_args: Any
) -> VisitorAction:
directive_name = directive_node.name.value
known_args = self.directive_args.get(directive_name)
if directive_node.arguments and known_args is not None:
for arg_node in directive_node.arguments:
arg_name = arg_node.name.value
if arg_name not in known_args:
suggestions = suggestion_list(arg_name, known_args)
self.report_error(
GraphQLError(
f"Unknown argument '{arg_name}'"
f" on directive '@{directive_name}'."
+ did_you_mean(suggestions),
arg_node,
)
)
return SKIP
class KnownArgumentNamesRule(KnownArgumentNamesOnDirectivesRule):
"""Known argument names
A GraphQL field is only valid if all supplied arguments are defined by that field.
See https://spec.graphql.org/draft/#sec-Argument-Names
See https://spec.graphql.org/draft/#sec-Directives-Are-In-Valid-Locations
"""
context: ValidationContext
def __init__(self, context: ValidationContext):
super().__init__(context)
def enter_argument(self, arg_node: ArgumentNode, *args: Any) -> None:
context = self.context
arg_def = context.get_argument()
field_def = context.get_field_def()
parent_type = context.get_parent_type()
if not arg_def and field_def and parent_type:
arg_name = arg_node.name.value
field_name = args[3][-1].name.value
known_args_names = list(field_def.args)
suggestions = suggestion_list(arg_name, known_args_names)
context.report_error(
GraphQLError(
f"Unknown argument '{arg_name}'"
f" on field '{parent_type.name}.{field_name}'."
+ did_you_mean(suggestions),
arg_node,
)
)
@@ -0,0 +1,119 @@
from typing import cast, Any, Dict, List, Optional, Tuple, Union
from ...error import GraphQLError
from ...language import (
DirectiveLocation,
DirectiveDefinitionNode,
DirectiveNode,
Node,
OperationDefinitionNode,
)
from ...type import specified_directives
from . import ASTValidationRule, SDLValidationContext, ValidationContext
__all__ = ["KnownDirectivesRule"]
class KnownDirectivesRule(ASTValidationRule):
"""Known directives
A GraphQL document is only valid if all ``@directives`` are known by the schema and
legally positioned.
See https://spec.graphql.org/draft/#sec-Directives-Are-Defined
"""
context: Union[ValidationContext, SDLValidationContext]
def __init__(self, context: Union[ValidationContext, SDLValidationContext]):
super().__init__(context)
locations_map: Dict[str, Tuple[DirectiveLocation, ...]] = {}
schema = context.schema
defined_directives = (
schema.directives if schema else cast(List, specified_directives)
)
for directive in defined_directives:
locations_map[directive.name] = directive.locations
ast_definitions = context.document.definitions
for def_ in ast_definitions:
if isinstance(def_, DirectiveDefinitionNode):
locations_map[def_.name.value] = tuple(
DirectiveLocation[name.value] for name in def_.locations
)
self.locations_map = locations_map
def enter_directive(
self,
node: DirectiveNode,
_key: Any,
_parent: Any,
_path: Any,
ancestors: List[Node],
) -> None:
name = node.name.value
locations = self.locations_map.get(name)
if locations:
candidate_location = get_directive_location_for_ast_path(ancestors)
if candidate_location and candidate_location not in locations:
self.report_error(
GraphQLError(
f"Directive '@{name}'"
f" may not be used on {candidate_location.value}.",
node,
)
)
else:
self.report_error(GraphQLError(f"Unknown directive '@{name}'.", node))
_operation_location = {
"query": DirectiveLocation.QUERY,
"mutation": DirectiveLocation.MUTATION,
"subscription": DirectiveLocation.SUBSCRIPTION,
}
_directive_location = {
"field": DirectiveLocation.FIELD,
"fragment_spread": DirectiveLocation.FRAGMENT_SPREAD,
"inline_fragment": DirectiveLocation.INLINE_FRAGMENT,
"fragment_definition": DirectiveLocation.FRAGMENT_DEFINITION,
"variable_definition": DirectiveLocation.VARIABLE_DEFINITION,
"schema_definition": DirectiveLocation.SCHEMA,
"schema_extension": DirectiveLocation.SCHEMA,
"scalar_type_definition": DirectiveLocation.SCALAR,
"scalar_type_extension": DirectiveLocation.SCALAR,
"object_type_definition": DirectiveLocation.OBJECT,
"object_type_extension": DirectiveLocation.OBJECT,
"field_definition": DirectiveLocation.FIELD_DEFINITION,
"interface_type_definition": DirectiveLocation.INTERFACE,
"interface_type_extension": DirectiveLocation.INTERFACE,
"union_type_definition": DirectiveLocation.UNION,
"union_type_extension": DirectiveLocation.UNION,
"enum_type_definition": DirectiveLocation.ENUM,
"enum_type_extension": DirectiveLocation.ENUM,
"enum_value_definition": DirectiveLocation.ENUM_VALUE,
"input_object_type_definition": DirectiveLocation.INPUT_OBJECT,
"input_object_type_extension": DirectiveLocation.INPUT_OBJECT,
}
def get_directive_location_for_ast_path(
ancestors: List[Node],
) -> Optional[DirectiveLocation]:
applied_to = ancestors[-1]
if not isinstance(applied_to, Node): # pragma: no cover
raise TypeError("Unexpected error in directive.")
kind = applied_to.kind
if kind == "operation_definition":
applied_to = cast(OperationDefinitionNode, applied_to)
return _operation_location[applied_to.operation.value]
elif kind == "input_value_definition":
parent_node = ancestors[-3]
return (
DirectiveLocation.INPUT_FIELD_DEFINITION
if parent_node.kind == "input_object_type_definition"
else DirectiveLocation.ARGUMENT_DEFINITION
)
else:
return _directive_location.get(kind)
@@ -0,0 +1,25 @@
from typing import Any
from ...error import GraphQLError
from ...language import FragmentSpreadNode
from . import ValidationRule
__all__ = ["KnownFragmentNamesRule"]
class KnownFragmentNamesRule(ValidationRule):
"""Known fragment names
A GraphQL document is only valid if all ``...Fragment`` fragment spreads refer to
fragments defined in the same document.
See https://spec.graphql.org/draft/#sec-Fragment-spread-target-defined
"""
def enter_fragment_spread(self, node: FragmentSpreadNode, *_args: Any) -> None:
fragment_name = node.name.value
fragment = self.context.get_fragment(fragment_name)
if not fragment:
self.report_error(
GraphQLError(f"Unknown fragment '{fragment_name}'.", node.name)
)
@@ -0,0 +1,90 @@
from typing import Any, Collection, List, Union, cast
from ...error import GraphQLError
from ...language import (
is_type_definition_node,
is_type_system_definition_node,
is_type_system_extension_node,
Node,
NamedTypeNode,
TypeDefinitionNode,
)
from ...type import introspection_types, specified_scalar_types
from ...pyutils import did_you_mean, suggestion_list
from . import ASTValidationRule, ValidationContext, SDLValidationContext
__all__ = ["KnownTypeNamesRule"]
class KnownTypeNamesRule(ASTValidationRule):
"""Known type names
A GraphQL document is only valid if referenced types (specifically variable
definitions and fragment conditions) are defined by the type schema.
See https://spec.graphql.org/draft/#sec-Fragment-Spread-Type-Existence
"""
def __init__(self, context: Union[ValidationContext, SDLValidationContext]):
super().__init__(context)
schema = context.schema
self.existing_types_map = schema.type_map if schema else {}
defined_types = []
for def_ in context.document.definitions:
if is_type_definition_node(def_):
def_ = cast(TypeDefinitionNode, def_)
defined_types.append(def_.name.value)
self.defined_types = set(defined_types)
self.type_names = list(self.existing_types_map) + defined_types
def enter_named_type(
self,
node: NamedTypeNode,
_key: Any,
parent: Node,
_path: Any,
ancestors: List[Node],
) -> None:
type_name = node.name.value
if (
type_name not in self.existing_types_map
and type_name not in self.defined_types
):
try:
definition_node = ancestors[2]
except IndexError:
definition_node = parent
is_sdl = is_sdl_node(definition_node)
if is_sdl and type_name in standard_type_names:
return
suggested_types = suggestion_list(
type_name,
(
list(standard_type_names) + self.type_names
if is_sdl
else self.type_names
),
)
self.report_error(
GraphQLError(
f"Unknown type '{type_name}'." + did_you_mean(suggested_types),
node,
)
)
standard_type_names = set(specified_scalar_types).union(introspection_types)
def is_sdl_node(value: Union[Node, Collection[Node], None]) -> bool:
return (
value is not None
and not isinstance(value, list)
and (
is_type_system_definition_node(cast(Node, value))
or is_type_system_extension_node(cast(Node, value))
)
)
@@ -0,0 +1,37 @@
from typing import Any
from ...error import GraphQLError
from ...language import DocumentNode, OperationDefinitionNode
from . import ASTValidationContext, ASTValidationRule
__all__ = ["LoneAnonymousOperationRule"]
class LoneAnonymousOperationRule(ASTValidationRule):
"""Lone anonymous operation
A GraphQL document is only valid if when it contains an anonymous operation
(the query short-hand) that it contains only that one operation definition.
See https://spec.graphql.org/draft/#sec-Lone-Anonymous-Operation
"""
def __init__(self, context: ASTValidationContext):
super().__init__(context)
self.operation_count = 0
def enter_document(self, node: DocumentNode, *_args: Any) -> None:
self.operation_count = sum(
isinstance(definition, OperationDefinitionNode)
for definition in node.definitions
)
def enter_operation_definition(
self, node: OperationDefinitionNode, *_args: Any
) -> None:
if not node.name and self.operation_count > 1:
self.report_error(
GraphQLError(
"This anonymous operation must be the only defined operation.", node
)
)
@@ -0,0 +1,39 @@
from typing import Any
from ...error import GraphQLError
from ...language import SchemaDefinitionNode
from . import SDLValidationRule, SDLValidationContext
__all__ = ["LoneSchemaDefinitionRule"]
class LoneSchemaDefinitionRule(SDLValidationRule):
"""Lone Schema definition
A GraphQL document is only valid if it contains only one schema definition.
"""
def __init__(self, context: SDLValidationContext):
super().__init__(context)
old_schema = context.schema
self.already_defined = old_schema and (
old_schema.ast_node
or old_schema.query_type
or old_schema.mutation_type
or old_schema.subscription_type
)
self.schema_definitions_count = 0
def enter_schema_definition(self, node: SchemaDefinitionNode, *_args: Any) -> None:
if self.already_defined:
self.report_error(
GraphQLError(
"Cannot define a new schema within a schema extension.", node
)
)
else:
if self.schema_definitions_count:
self.report_error(
GraphQLError("Must provide only one schema definition.", node)
)
self.schema_definitions_count += 1
@@ -0,0 +1,81 @@
from typing import Any, Dict, List, Set
from ...error import GraphQLError
from ...language import FragmentDefinitionNode, FragmentSpreadNode, VisitorAction, SKIP
from . import ASTValidationContext, ASTValidationRule
__all__ = ["NoFragmentCyclesRule"]
class NoFragmentCyclesRule(ASTValidationRule):
"""No fragment cycles
The graph of fragment spreads must not form any cycles including spreading itself.
Otherwise an operation could infinitely spread or infinitely execute on cycles in
the underlying data.
See https://spec.graphql.org/draft/#sec-Fragment-spreads-must-not-form-cycles
"""
def __init__(self, context: ASTValidationContext):
super().__init__(context)
# Tracks already visited fragments to maintain O(N) and to ensure that
# cycles are not redundantly reported.
self.visited_frags: Set[str] = set()
# List of AST nodes used to produce meaningful errors
self.spread_path: List[FragmentSpreadNode] = []
# Position in the spread path
self.spread_path_index_by_name: Dict[str, int] = {}
@staticmethod
def enter_operation_definition(*_args: Any) -> VisitorAction:
return SKIP
def enter_fragment_definition(
self, node: FragmentDefinitionNode, *_args: Any
) -> VisitorAction:
self.detect_cycle_recursive(node)
return SKIP
def detect_cycle_recursive(self, fragment: FragmentDefinitionNode) -> None:
# This does a straight-forward DFS to find cycles.
# It does not terminate when a cycle was found but continues to explore
# the graph to find all possible cycles.
if fragment.name.value in self.visited_frags:
return
fragment_name = fragment.name.value
visited_frags = self.visited_frags
visited_frags.add(fragment_name)
spread_nodes = self.context.get_fragment_spreads(fragment.selection_set)
if not spread_nodes:
return
spread_path = self.spread_path
spread_path_index = self.spread_path_index_by_name
spread_path_index[fragment_name] = len(spread_path)
get_fragment = self.context.get_fragment
for spread_node in spread_nodes:
spread_name = spread_node.name.value
cycle_index = spread_path_index.get(spread_name)
spread_path.append(spread_node)
if cycle_index is None:
spread_fragment = get_fragment(spread_name)
if spread_fragment:
self.detect_cycle_recursive(spread_fragment)
else:
cycle_path = spread_path[cycle_index:]
via_path = ", ".join("'" + s.name.value + "'" for s in cycle_path[:-1])
self.report_error(
GraphQLError(
f"Cannot spread fragment '{spread_name}' within itself"
+ (f" via {via_path}." if via_path else "."),
cycle_path,
)
)
spread_path.pop()
del spread_path_index[fragment_name]
@@ -0,0 +1,50 @@
from typing import Any, Set
from ...error import GraphQLError
from ...language import OperationDefinitionNode, VariableDefinitionNode
from . import ValidationContext, ValidationRule
__all__ = ["NoUndefinedVariablesRule"]
class NoUndefinedVariablesRule(ValidationRule):
"""No undefined variables
A GraphQL operation is only valid if all variables encountered, both directly and
via fragment spreads, are defined by that operation.
See https://spec.graphql.org/draft/#sec-All-Variable-Uses-Defined
"""
def __init__(self, context: ValidationContext):
super().__init__(context)
self.defined_variable_names: Set[str] = set()
def enter_operation_definition(self, *_args: Any) -> None:
self.defined_variable_names.clear()
def leave_operation_definition(
self, operation: OperationDefinitionNode, *_args: Any
) -> None:
usages = self.context.get_recursive_variable_usages(operation)
defined_variables = self.defined_variable_names
for usage in usages:
node = usage.node
var_name = node.name.value
if var_name not in defined_variables:
self.report_error(
GraphQLError(
(
f"Variable '${var_name}' is not defined"
f" by operation '{operation.name.value}'."
if operation.name
else f"Variable '${var_name}' is not defined."
),
[node, operation],
)
)
def enter_variable_definition(
self, node: VariableDefinitionNode, *_args: Any
) -> None:
self.defined_variable_names.add(node.variable.name.value)
@@ -0,0 +1,53 @@
from typing import Any, List
from ...error import GraphQLError
from ...language import (
FragmentDefinitionNode,
OperationDefinitionNode,
VisitorAction,
SKIP,
)
from . import ASTValidationContext, ASTValidationRule
__all__ = ["NoUnusedFragmentsRule"]
class NoUnusedFragmentsRule(ASTValidationRule):
"""No unused fragments
A GraphQL document is only valid if all fragment definitions are spread within
operations, or spread within other fragments spread within operations.
See https://spec.graphql.org/draft/#sec-Fragments-Must-Be-Used
"""
def __init__(self, context: ASTValidationContext):
super().__init__(context)
self.operation_defs: List[OperationDefinitionNode] = []
self.fragment_defs: List[FragmentDefinitionNode] = []
def enter_operation_definition(
self, node: OperationDefinitionNode, *_args: Any
) -> VisitorAction:
self.operation_defs.append(node)
return SKIP
def enter_fragment_definition(
self, node: FragmentDefinitionNode, *_args: Any
) -> VisitorAction:
self.fragment_defs.append(node)
return SKIP
def leave_document(self, *_args: Any) -> None:
fragment_names_used = set()
get_fragments = self.context.get_recursively_referenced_fragments
for operation in self.operation_defs:
for fragment in get_fragments(operation):
fragment_names_used.add(fragment.name.value)
for fragment_def in self.fragment_defs:
frag_name = fragment_def.name.value
if frag_name not in fragment_names_used:
self.report_error(
GraphQLError(f"Fragment '{frag_name}' is never used.", fragment_def)
)

Some files were not shown because too many files have changed in this diff Show More