2025-12-01

This commit is contained in:
2026-03-17 14:58:51 -06:00
parent 183e865f8b
commit 4b82b57113
6846 changed files with 954887 additions and 162606 deletions
@@ -0,0 +1,157 @@
"""GraphQL Validation
The :mod:`graphql.validation` package fulfills the Validation phase of fulfilling a
GraphQL result.
"""
from .validate import validate
from .validation_context import (
ASTValidationContext,
SDLValidationContext,
ValidationContext,
)
from .rules import ValidationRule, ASTValidationRule, SDLValidationRule
# All validation rules in the GraphQL Specification.
from .specified_rules import specified_rules
# Spec Section: "Executable Definitions"
from .rules.executable_definitions import ExecutableDefinitionsRule
# Spec Section: "Field Selections on Objects, Interfaces, and Unions Types"
from .rules.fields_on_correct_type import FieldsOnCorrectTypeRule
# Spec Section: "Fragments on Composite Types"
from .rules.fragments_on_composite_types import FragmentsOnCompositeTypesRule
# Spec Section: "Argument Names"
from .rules.known_argument_names import KnownArgumentNamesRule
# Spec Section: "Directives Are Defined"
from .rules.known_directives import KnownDirectivesRule
# Spec Section: "Fragment spread target defined"
from .rules.known_fragment_names import KnownFragmentNamesRule
# Spec Section: "Fragment Spread Type Existence"
from .rules.known_type_names import KnownTypeNamesRule
# Spec Section: "Lone Anonymous Operation"
from .rules.lone_anonymous_operation import LoneAnonymousOperationRule
# Spec Section: "Fragments must not form cycles"
from .rules.no_fragment_cycles import NoFragmentCyclesRule
# Spec Section: "All Variable Used Defined"
from .rules.no_undefined_variables import NoUndefinedVariablesRule
# Spec Section: "Fragments must be used"
from .rules.no_unused_fragments import NoUnusedFragmentsRule
# Spec Section: "All Variables Used"
from .rules.no_unused_variables import NoUnusedVariablesRule
# Spec Section: "Field Selection Merging"
from .rules.overlapping_fields_can_be_merged import OverlappingFieldsCanBeMergedRule
# Spec Section: "Fragment spread is possible"
from .rules.possible_fragment_spreads import PossibleFragmentSpreadsRule
# Spec Section: "Argument Optionality"
from .rules.provided_required_arguments import ProvidedRequiredArgumentsRule
# Spec Section: "Leaf Field Selections"
from .rules.scalar_leafs import ScalarLeafsRule
# Spec Section: "Subscriptions with Single Root Field"
from .rules.single_field_subscriptions import SingleFieldSubscriptionsRule
# Spec Section: "Argument Uniqueness"
from .rules.unique_argument_names import UniqueArgumentNamesRule
# Spec Section: "Directives Are Unique Per Location"
from .rules.unique_directives_per_location import UniqueDirectivesPerLocationRule
# Spec Section: "Fragment Name Uniqueness"
from .rules.unique_fragment_names import UniqueFragmentNamesRule
# Spec Section: "Input Object Field Uniqueness"
from .rules.unique_input_field_names import UniqueInputFieldNamesRule
# Spec Section: "Operation Name Uniqueness"
from .rules.unique_operation_names import UniqueOperationNamesRule
# Spec Section: "Variable Uniqueness"
from .rules.unique_variable_names import UniqueVariableNamesRule
# Spec Section: "Value Type Correctness"
from .rules.values_of_correct_type import ValuesOfCorrectTypeRule
# Spec Section: "Variables are Input Types"
from .rules.variables_are_input_types import VariablesAreInputTypesRule
# Spec Section: "All Variable Usages Are Allowed"
from .rules.variables_in_allowed_position import VariablesInAllowedPositionRule
# SDL-specific validation rules
from .rules.lone_schema_definition import LoneSchemaDefinitionRule
from .rules.unique_operation_types import UniqueOperationTypesRule
from .rules.unique_type_names import UniqueTypeNamesRule
from .rules.unique_enum_value_names import UniqueEnumValueNamesRule
from .rules.unique_field_definition_names import UniqueFieldDefinitionNamesRule
from .rules.unique_argument_definition_names import UniqueArgumentDefinitionNamesRule
from .rules.unique_directive_names import UniqueDirectiveNamesRule
from .rules.possible_type_extensions import PossibleTypeExtensionsRule
# Optional rules not defined by the GraphQL Specification
from .rules.custom.no_deprecated import NoDeprecatedCustomRule
from .rules.custom.no_schema_introspection import NoSchemaIntrospectionCustomRule
__all__ = [
"validate",
"ASTValidationContext",
"ASTValidationRule",
"SDLValidationContext",
"SDLValidationRule",
"ValidationContext",
"ValidationRule",
"specified_rules",
"ExecutableDefinitionsRule",
"FieldsOnCorrectTypeRule",
"FragmentsOnCompositeTypesRule",
"KnownArgumentNamesRule",
"KnownDirectivesRule",
"KnownFragmentNamesRule",
"KnownTypeNamesRule",
"LoneAnonymousOperationRule",
"NoFragmentCyclesRule",
"NoUndefinedVariablesRule",
"NoUnusedFragmentsRule",
"NoUnusedVariablesRule",
"OverlappingFieldsCanBeMergedRule",
"PossibleFragmentSpreadsRule",
"ProvidedRequiredArgumentsRule",
"ScalarLeafsRule",
"SingleFieldSubscriptionsRule",
"UniqueArgumentNamesRule",
"UniqueDirectivesPerLocationRule",
"UniqueFragmentNamesRule",
"UniqueInputFieldNamesRule",
"UniqueOperationNamesRule",
"UniqueVariableNamesRule",
"ValuesOfCorrectTypeRule",
"VariablesAreInputTypesRule",
"VariablesInAllowedPositionRule",
"LoneSchemaDefinitionRule",
"UniqueOperationTypesRule",
"UniqueTypeNamesRule",
"UniqueEnumValueNamesRule",
"UniqueFieldDefinitionNamesRule",
"UniqueArgumentDefinitionNamesRule",
"UniqueDirectiveNamesRule",
"PossibleTypeExtensionsRule",
"NoDeprecatedCustomRule",
"NoSchemaIntrospectionCustomRule",
]
@@ -0,0 +1,42 @@
"""graphql.validation.rules package"""
from ...error import GraphQLError
from ...language.visitor import Visitor
from ..validation_context import (
ASTValidationContext,
SDLValidationContext,
ValidationContext,
)
__all__ = ["ASTValidationRule", "SDLValidationRule", "ValidationRule"]
class ASTValidationRule(Visitor):
"""Visitor for validation of an AST."""
context: ASTValidationContext
def __init__(self, context: ASTValidationContext):
super().__init__()
self.context = context
def report_error(self, error: GraphQLError) -> None:
self.context.report_error(error)
class SDLValidationRule(ASTValidationRule):
"""Visitor for validation of an SDL AST."""
context: SDLValidationContext
def __init__(self, context: SDLValidationContext) -> None:
super().__init__(context)
class ValidationRule(ASTValidationRule):
"""Visitor for validation using a GraphQL schema."""
context: ValidationContext
def __init__(self, context: ValidationContext) -> None:
super().__init__(context)
@@ -0,0 +1 @@
"""graphql.validation.rules.custom package"""
@@ -0,0 +1,101 @@
from typing import Any, cast
from ....error import GraphQLError
from ....language import ArgumentNode, EnumValueNode, FieldNode, ObjectFieldNode
from ....type import GraphQLInputObjectType, get_named_type, is_input_object_type
from .. import ValidationRule
__all__ = ["NoDeprecatedCustomRule"]
class NoDeprecatedCustomRule(ValidationRule):
"""No deprecated
A GraphQL document is only valid if all selected fields and all used enum values
have not been deprecated.
Note: This rule is optional and is not part of the Validation section of the GraphQL
Specification. The main purpose of this rule is detection of deprecated usages and
not necessarily to forbid their use when querying a service.
"""
def enter_field(self, node: FieldNode, *_args: Any) -> None:
context = self.context
field_def = context.get_field_def()
if field_def:
deprecation_reason = field_def.deprecation_reason
if deprecation_reason is not None:
parent_type = context.get_parent_type()
parent_name = parent_type.name # type: ignore
self.report_error(
GraphQLError(
f"The field {parent_name}.{node.name.value}"
f" is deprecated. {deprecation_reason}",
node,
)
)
def enter_argument(self, node: ArgumentNode, *_args: Any) -> None:
context = self.context
arg_def = context.get_argument()
if arg_def:
deprecation_reason = arg_def.deprecation_reason
if deprecation_reason is not None:
directive_def = context.get_directive()
arg_name = node.name.value
if directive_def is None:
parent_type = context.get_parent_type()
parent_name = parent_type.name # type: ignore
field_def = context.get_field_def()
field_name = field_def.ast_node.name.value # type: ignore
self.report_error(
GraphQLError(
f"Field '{parent_name}.{field_name}' argument"
f" '{arg_name}' is deprecated. {deprecation_reason}",
node,
)
)
else:
self.report_error(
GraphQLError(
f"Directive '@{directive_def.name}' argument"
f" '{arg_name}' is deprecated. {deprecation_reason}",
node,
)
)
def enter_object_field(self, node: ObjectFieldNode, *_args: Any) -> None:
context = self.context
input_object_def = get_named_type(context.get_parent_input_type())
if is_input_object_type(input_object_def):
input_field_def = cast(GraphQLInputObjectType, input_object_def).fields.get(
node.name.value
)
if input_field_def:
deprecation_reason = input_field_def.deprecation_reason
if deprecation_reason is not None:
field_name = node.name.value
input_object_name = input_object_def.name # type: ignore
self.report_error(
GraphQLError(
f"The input field {input_object_name}.{field_name}"
f" is deprecated. {deprecation_reason}",
node,
)
)
def enter_enum_value(self, node: EnumValueNode, *_args: Any) -> None:
context = self.context
enum_value_def = context.get_enum_value()
if enum_value_def:
deprecation_reason = enum_value_def.deprecation_reason
if deprecation_reason is not None: # pragma: no cover else
enum_type_def = get_named_type(context.get_input_type())
enum_type_name = enum_type_def.name # type: ignore
self.report_error(
GraphQLError(
f"The enum value '{enum_type_name}.{node.value}'"
f" is deprecated. {deprecation_reason}",
node,
)
)
@@ -0,0 +1,31 @@
from typing import Any
from ....error import GraphQLError
from ....language import FieldNode
from ....type import get_named_type, is_introspection_type
from .. import ValidationRule
__all__ = ["NoSchemaIntrospectionCustomRule"]
class NoSchemaIntrospectionCustomRule(ValidationRule):
"""Prohibit introspection queries
A GraphQL document is only valid if all fields selected are not fields that
return an introspection type.
Note: This rule is optional and is not part of the Validation section of the
GraphQL Specification. This rule effectively disables introspection, which
does not reflect best practices and should only be done if absolutely necessary.
"""
def enter_field(self, node: FieldNode, *_args: Any) -> None:
type_ = get_named_type(self.context.get_type())
if type_ and is_introspection_type(type_):
self.report_error(
GraphQLError(
"GraphQL introspection has been disabled, but the requested query"
f" contained the field '{node.name.value}'.",
node,
)
)
@@ -0,0 +1,49 @@
from typing import Any, Union, cast
from ...error import GraphQLError
from ...language import (
DirectiveDefinitionNode,
DocumentNode,
ExecutableDefinitionNode,
SchemaDefinitionNode,
SchemaExtensionNode,
TypeDefinitionNode,
VisitorAction,
SKIP,
)
from . import ASTValidationRule
__all__ = ["ExecutableDefinitionsRule"]
class ExecutableDefinitionsRule(ASTValidationRule):
"""Executable definitions
A GraphQL document is only valid for execution if all definitions are either
operation or fragment definitions.
See https://spec.graphql.org/draft/#sec-Executable-Definitions
"""
def enter_document(self, node: DocumentNode, *_args: Any) -> VisitorAction:
for definition in node.definitions:
if not isinstance(definition, ExecutableDefinitionNode):
def_name = (
"schema"
if isinstance(
definition, (SchemaDefinitionNode, SchemaExtensionNode)
)
else "'{}'".format(
cast(
Union[DirectiveDefinitionNode, TypeDefinitionNode],
definition,
).name.value
)
)
self.report_error(
GraphQLError(
f"The {def_name} definition is not executable.",
definition,
)
)
return SKIP
@@ -0,0 +1,136 @@
from collections import defaultdict
from functools import cmp_to_key
from typing import Any, Dict, List, Union, cast
from ...type import (
GraphQLAbstractType,
GraphQLInterfaceType,
GraphQLObjectType,
GraphQLOutputType,
GraphQLSchema,
is_abstract_type,
is_interface_type,
is_object_type,
)
from ...error import GraphQLError
from ...language import FieldNode
from ...pyutils import did_you_mean, natural_comparison_key, suggestion_list
from . import ValidationRule
__all__ = ["FieldsOnCorrectTypeRule"]
class FieldsOnCorrectTypeRule(ValidationRule):
"""Fields on correct type
A GraphQL document is only valid if all fields selected are defined by the parent
type, or are an allowed meta field such as ``__typename``.
See https://spec.graphql.org/draft/#sec-Field-Selections
"""
def enter_field(self, node: FieldNode, *_args: Any) -> None:
type_ = self.context.get_parent_type()
if not type_:
return
field_def = self.context.get_field_def()
if field_def:
return
# This field doesn't exist, lets look for suggestions.
schema = self.context.schema
field_name = node.name.value
# First determine if there are any suggested types to condition on.
suggestion = did_you_mean(
get_suggested_type_names(schema, type_, field_name),
"to use an inline fragment on",
)
# If there are no suggested types, then perhaps this was a typo?
if not suggestion:
suggestion = did_you_mean(get_suggested_field_names(type_, field_name))
# Report an error, including helpful suggestions.
self.report_error(
GraphQLError(
f"Cannot query field '{field_name}' on type '{type_}'." + suggestion,
node,
)
)
def get_suggested_type_names(
schema: GraphQLSchema, type_: GraphQLOutputType, field_name: str
) -> List[str]:
"""
Get a list of suggested type names.
Go through all of the implementations of type, as well as the interfaces
that they implement. If any of those types include the provided field,
suggest them, sorted by how often the type is referenced.
"""
if not is_abstract_type(type_):
# Must be an Object type, which does not have possible fields.
return []
type_ = cast(GraphQLAbstractType, type_)
# Use a dict instead of a set for stable sorting when usage counts are the same
suggested_types: Dict[Union[GraphQLObjectType, GraphQLInterfaceType], None] = {}
usage_count: Dict[str, int] = defaultdict(int)
for possible_type in schema.get_possible_types(type_):
if field_name not in possible_type.fields:
continue
# This object type defines this field.
suggested_types[possible_type] = None
usage_count[possible_type.name] = 1
for possible_interface in possible_type.interfaces:
if field_name not in possible_interface.fields:
continue
# This interface type defines this field.
suggested_types[possible_interface] = None
usage_count[possible_interface.name] += 1
def cmp(
type_a: Union[GraphQLObjectType, GraphQLInterfaceType],
type_b: Union[GraphQLObjectType, GraphQLInterfaceType],
) -> int: # pragma: no cover
# Suggest both interface and object types based on how common they are.
usage_count_diff = usage_count[type_b.name] - usage_count[type_a.name]
if usage_count_diff:
return usage_count_diff
# Suggest super types first followed by subtypes
if is_interface_type(type_a) and schema.is_sub_type(
cast(GraphQLInterfaceType, type_a), type_b
):
return -1
if is_interface_type(type_b) and schema.is_sub_type(
cast(GraphQLInterfaceType, type_b), type_a
):
return 1
name_a = natural_comparison_key(type_a.name)
name_b = natural_comparison_key(type_b.name)
if name_a > name_b:
return 1
if name_a < name_b:
return -1
return 0
return [type_.name for type_ in sorted(suggested_types, key=cmp_to_key(cmp))]
def get_suggested_field_names(type_: GraphQLOutputType, field_name: str) -> List[str]:
"""Get a list of suggested field names.
For the field name provided, determine if there are any similar field names that may
be the result of a typo.
"""
if is_object_type(type_) or is_interface_type(type_):
possible_field_names = list(type_.fields) # type: ignore
return suggestion_list(field_name, possible_field_names)
# Otherwise, must be a Union type, which does not define fields.
return []
@@ -0,0 +1,53 @@
from typing import Any
from ...error import GraphQLError
from ...language import (
FragmentDefinitionNode,
InlineFragmentNode,
print_ast,
)
from ...type import is_composite_type
from ...utilities import type_from_ast
from . import ValidationRule
__all__ = ["FragmentsOnCompositeTypesRule"]
class FragmentsOnCompositeTypesRule(ValidationRule):
"""Fragments on composite type
Fragments use a type condition to determine if they apply, since fragments can only
be spread into a composite type (object, interface, or union), the type condition
must also be a composite type.
See https://spec.graphql.org/draft/#sec-Fragments-On-Composite-Types
"""
def enter_inline_fragment(self, node: InlineFragmentNode, *_args: Any) -> None:
type_condition = node.type_condition
if type_condition:
type_ = type_from_ast(self.context.schema, type_condition)
if type_ and not is_composite_type(type_):
type_str = print_ast(type_condition)
self.report_error(
GraphQLError(
"Fragment cannot condition"
f" on non composite type '{type_str}'.",
type_condition,
)
)
def enter_fragment_definition(
self, node: FragmentDefinitionNode, *_args: Any
) -> None:
type_condition = node.type_condition
type_ = type_from_ast(self.context.schema, type_condition)
if type_ and not is_composite_type(type_):
type_str = print_ast(type_condition)
self.report_error(
GraphQLError(
f"Fragment '{node.name.value}' cannot condition"
f" on non composite type '{type_str}'.",
type_condition,
)
)
@@ -0,0 +1,98 @@
from typing import cast, Any, Dict, List, Union
from ...error import GraphQLError
from ...language import (
ArgumentNode,
DirectiveDefinitionNode,
DirectiveNode,
SKIP,
VisitorAction,
)
from ...pyutils import did_you_mean, suggestion_list
from ...type import specified_directives
from . import ASTValidationRule, SDLValidationContext, ValidationContext
__all__ = ["KnownArgumentNamesRule", "KnownArgumentNamesOnDirectivesRule"]
class KnownArgumentNamesOnDirectivesRule(ASTValidationRule):
"""Known argument names on directives
A GraphQL directive is only valid if all supplied arguments are defined.
For internal use only.
"""
context: Union[ValidationContext, SDLValidationContext]
def __init__(self, context: Union[ValidationContext, SDLValidationContext]):
super().__init__(context)
directive_args: Dict[str, List[str]] = {}
schema = context.schema
defined_directives = schema.directives if schema else specified_directives
for directive in cast(List, defined_directives):
directive_args[directive.name] = list(directive.args)
ast_definitions = context.document.definitions
for def_ in ast_definitions:
if isinstance(def_, DirectiveDefinitionNode):
directive_args[def_.name.value] = [
arg.name.value for arg in def_.arguments or []
]
self.directive_args = directive_args
def enter_directive(
self, directive_node: DirectiveNode, *_args: Any
) -> VisitorAction:
directive_name = directive_node.name.value
known_args = self.directive_args.get(directive_name)
if directive_node.arguments and known_args is not None:
for arg_node in directive_node.arguments:
arg_name = arg_node.name.value
if arg_name not in known_args:
suggestions = suggestion_list(arg_name, known_args)
self.report_error(
GraphQLError(
f"Unknown argument '{arg_name}'"
f" on directive '@{directive_name}'."
+ did_you_mean(suggestions),
arg_node,
)
)
return SKIP
class KnownArgumentNamesRule(KnownArgumentNamesOnDirectivesRule):
"""Known argument names
A GraphQL field is only valid if all supplied arguments are defined by that field.
See https://spec.graphql.org/draft/#sec-Argument-Names
See https://spec.graphql.org/draft/#sec-Directives-Are-In-Valid-Locations
"""
context: ValidationContext
def __init__(self, context: ValidationContext):
super().__init__(context)
def enter_argument(self, arg_node: ArgumentNode, *args: Any) -> None:
context = self.context
arg_def = context.get_argument()
field_def = context.get_field_def()
parent_type = context.get_parent_type()
if not arg_def and field_def and parent_type:
arg_name = arg_node.name.value
field_name = args[3][-1].name.value
known_args_names = list(field_def.args)
suggestions = suggestion_list(arg_name, known_args_names)
context.report_error(
GraphQLError(
f"Unknown argument '{arg_name}'"
f" on field '{parent_type.name}.{field_name}'."
+ did_you_mean(suggestions),
arg_node,
)
)
@@ -0,0 +1,119 @@
from typing import cast, Any, Dict, List, Optional, Tuple, Union
from ...error import GraphQLError
from ...language import (
DirectiveLocation,
DirectiveDefinitionNode,
DirectiveNode,
Node,
OperationDefinitionNode,
)
from ...type import specified_directives
from . import ASTValidationRule, SDLValidationContext, ValidationContext
__all__ = ["KnownDirectivesRule"]
class KnownDirectivesRule(ASTValidationRule):
"""Known directives
A GraphQL document is only valid if all ``@directives`` are known by the schema and
legally positioned.
See https://spec.graphql.org/draft/#sec-Directives-Are-Defined
"""
context: Union[ValidationContext, SDLValidationContext]
def __init__(self, context: Union[ValidationContext, SDLValidationContext]):
super().__init__(context)
locations_map: Dict[str, Tuple[DirectiveLocation, ...]] = {}
schema = context.schema
defined_directives = (
schema.directives if schema else cast(List, specified_directives)
)
for directive in defined_directives:
locations_map[directive.name] = directive.locations
ast_definitions = context.document.definitions
for def_ in ast_definitions:
if isinstance(def_, DirectiveDefinitionNode):
locations_map[def_.name.value] = tuple(
DirectiveLocation[name.value] for name in def_.locations
)
self.locations_map = locations_map
def enter_directive(
self,
node: DirectiveNode,
_key: Any,
_parent: Any,
_path: Any,
ancestors: List[Node],
) -> None:
name = node.name.value
locations = self.locations_map.get(name)
if locations:
candidate_location = get_directive_location_for_ast_path(ancestors)
if candidate_location and candidate_location not in locations:
self.report_error(
GraphQLError(
f"Directive '@{name}'"
f" may not be used on {candidate_location.value}.",
node,
)
)
else:
self.report_error(GraphQLError(f"Unknown directive '@{name}'.", node))
_operation_location = {
"query": DirectiveLocation.QUERY,
"mutation": DirectiveLocation.MUTATION,
"subscription": DirectiveLocation.SUBSCRIPTION,
}
_directive_location = {
"field": DirectiveLocation.FIELD,
"fragment_spread": DirectiveLocation.FRAGMENT_SPREAD,
"inline_fragment": DirectiveLocation.INLINE_FRAGMENT,
"fragment_definition": DirectiveLocation.FRAGMENT_DEFINITION,
"variable_definition": DirectiveLocation.VARIABLE_DEFINITION,
"schema_definition": DirectiveLocation.SCHEMA,
"schema_extension": DirectiveLocation.SCHEMA,
"scalar_type_definition": DirectiveLocation.SCALAR,
"scalar_type_extension": DirectiveLocation.SCALAR,
"object_type_definition": DirectiveLocation.OBJECT,
"object_type_extension": DirectiveLocation.OBJECT,
"field_definition": DirectiveLocation.FIELD_DEFINITION,
"interface_type_definition": DirectiveLocation.INTERFACE,
"interface_type_extension": DirectiveLocation.INTERFACE,
"union_type_definition": DirectiveLocation.UNION,
"union_type_extension": DirectiveLocation.UNION,
"enum_type_definition": DirectiveLocation.ENUM,
"enum_type_extension": DirectiveLocation.ENUM,
"enum_value_definition": DirectiveLocation.ENUM_VALUE,
"input_object_type_definition": DirectiveLocation.INPUT_OBJECT,
"input_object_type_extension": DirectiveLocation.INPUT_OBJECT,
}
def get_directive_location_for_ast_path(
ancestors: List[Node],
) -> Optional[DirectiveLocation]:
applied_to = ancestors[-1]
if not isinstance(applied_to, Node): # pragma: no cover
raise TypeError("Unexpected error in directive.")
kind = applied_to.kind
if kind == "operation_definition":
applied_to = cast(OperationDefinitionNode, applied_to)
return _operation_location[applied_to.operation.value]
elif kind == "input_value_definition":
parent_node = ancestors[-3]
return (
DirectiveLocation.INPUT_FIELD_DEFINITION
if parent_node.kind == "input_object_type_definition"
else DirectiveLocation.ARGUMENT_DEFINITION
)
else:
return _directive_location.get(kind)
@@ -0,0 +1,25 @@
from typing import Any
from ...error import GraphQLError
from ...language import FragmentSpreadNode
from . import ValidationRule
__all__ = ["KnownFragmentNamesRule"]
class KnownFragmentNamesRule(ValidationRule):
"""Known fragment names
A GraphQL document is only valid if all ``...Fragment`` fragment spreads refer to
fragments defined in the same document.
See https://spec.graphql.org/draft/#sec-Fragment-spread-target-defined
"""
def enter_fragment_spread(self, node: FragmentSpreadNode, *_args: Any) -> None:
fragment_name = node.name.value
fragment = self.context.get_fragment(fragment_name)
if not fragment:
self.report_error(
GraphQLError(f"Unknown fragment '{fragment_name}'.", node.name)
)
@@ -0,0 +1,90 @@
from typing import Any, Collection, List, Union, cast
from ...error import GraphQLError
from ...language import (
is_type_definition_node,
is_type_system_definition_node,
is_type_system_extension_node,
Node,
NamedTypeNode,
TypeDefinitionNode,
)
from ...type import introspection_types, specified_scalar_types
from ...pyutils import did_you_mean, suggestion_list
from . import ASTValidationRule, ValidationContext, SDLValidationContext
__all__ = ["KnownTypeNamesRule"]
class KnownTypeNamesRule(ASTValidationRule):
"""Known type names
A GraphQL document is only valid if referenced types (specifically variable
definitions and fragment conditions) are defined by the type schema.
See https://spec.graphql.org/draft/#sec-Fragment-Spread-Type-Existence
"""
def __init__(self, context: Union[ValidationContext, SDLValidationContext]):
super().__init__(context)
schema = context.schema
self.existing_types_map = schema.type_map if schema else {}
defined_types = []
for def_ in context.document.definitions:
if is_type_definition_node(def_):
def_ = cast(TypeDefinitionNode, def_)
defined_types.append(def_.name.value)
self.defined_types = set(defined_types)
self.type_names = list(self.existing_types_map) + defined_types
def enter_named_type(
self,
node: NamedTypeNode,
_key: Any,
parent: Node,
_path: Any,
ancestors: List[Node],
) -> None:
type_name = node.name.value
if (
type_name not in self.existing_types_map
and type_name not in self.defined_types
):
try:
definition_node = ancestors[2]
except IndexError:
definition_node = parent
is_sdl = is_sdl_node(definition_node)
if is_sdl and type_name in standard_type_names:
return
suggested_types = suggestion_list(
type_name,
(
list(standard_type_names) + self.type_names
if is_sdl
else self.type_names
),
)
self.report_error(
GraphQLError(
f"Unknown type '{type_name}'." + did_you_mean(suggested_types),
node,
)
)
standard_type_names = set(specified_scalar_types).union(introspection_types)
def is_sdl_node(value: Union[Node, Collection[Node], None]) -> bool:
return (
value is not None
and not isinstance(value, list)
and (
is_type_system_definition_node(cast(Node, value))
or is_type_system_extension_node(cast(Node, value))
)
)
@@ -0,0 +1,37 @@
from typing import Any
from ...error import GraphQLError
from ...language import DocumentNode, OperationDefinitionNode
from . import ASTValidationContext, ASTValidationRule
__all__ = ["LoneAnonymousOperationRule"]
class LoneAnonymousOperationRule(ASTValidationRule):
"""Lone anonymous operation
A GraphQL document is only valid if when it contains an anonymous operation
(the query short-hand) that it contains only that one operation definition.
See https://spec.graphql.org/draft/#sec-Lone-Anonymous-Operation
"""
def __init__(self, context: ASTValidationContext):
super().__init__(context)
self.operation_count = 0
def enter_document(self, node: DocumentNode, *_args: Any) -> None:
self.operation_count = sum(
isinstance(definition, OperationDefinitionNode)
for definition in node.definitions
)
def enter_operation_definition(
self, node: OperationDefinitionNode, *_args: Any
) -> None:
if not node.name and self.operation_count > 1:
self.report_error(
GraphQLError(
"This anonymous operation must be the only defined operation.", node
)
)
@@ -0,0 +1,39 @@
from typing import Any
from ...error import GraphQLError
from ...language import SchemaDefinitionNode
from . import SDLValidationRule, SDLValidationContext
__all__ = ["LoneSchemaDefinitionRule"]
class LoneSchemaDefinitionRule(SDLValidationRule):
"""Lone Schema definition
A GraphQL document is only valid if it contains only one schema definition.
"""
def __init__(self, context: SDLValidationContext):
super().__init__(context)
old_schema = context.schema
self.already_defined = old_schema and (
old_schema.ast_node
or old_schema.query_type
or old_schema.mutation_type
or old_schema.subscription_type
)
self.schema_definitions_count = 0
def enter_schema_definition(self, node: SchemaDefinitionNode, *_args: Any) -> None:
if self.already_defined:
self.report_error(
GraphQLError(
"Cannot define a new schema within a schema extension.", node
)
)
else:
if self.schema_definitions_count:
self.report_error(
GraphQLError("Must provide only one schema definition.", node)
)
self.schema_definitions_count += 1
@@ -0,0 +1,81 @@
from typing import Any, Dict, List, Set
from ...error import GraphQLError
from ...language import FragmentDefinitionNode, FragmentSpreadNode, VisitorAction, SKIP
from . import ASTValidationContext, ASTValidationRule
__all__ = ["NoFragmentCyclesRule"]
class NoFragmentCyclesRule(ASTValidationRule):
"""No fragment cycles
The graph of fragment spreads must not form any cycles including spreading itself.
Otherwise an operation could infinitely spread or infinitely execute on cycles in
the underlying data.
See https://spec.graphql.org/draft/#sec-Fragment-spreads-must-not-form-cycles
"""
def __init__(self, context: ASTValidationContext):
super().__init__(context)
# Tracks already visited fragments to maintain O(N) and to ensure that
# cycles are not redundantly reported.
self.visited_frags: Set[str] = set()
# List of AST nodes used to produce meaningful errors
self.spread_path: List[FragmentSpreadNode] = []
# Position in the spread path
self.spread_path_index_by_name: Dict[str, int] = {}
@staticmethod
def enter_operation_definition(*_args: Any) -> VisitorAction:
return SKIP
def enter_fragment_definition(
self, node: FragmentDefinitionNode, *_args: Any
) -> VisitorAction:
self.detect_cycle_recursive(node)
return SKIP
def detect_cycle_recursive(self, fragment: FragmentDefinitionNode) -> None:
# This does a straight-forward DFS to find cycles.
# It does not terminate when a cycle was found but continues to explore
# the graph to find all possible cycles.
if fragment.name.value in self.visited_frags:
return
fragment_name = fragment.name.value
visited_frags = self.visited_frags
visited_frags.add(fragment_name)
spread_nodes = self.context.get_fragment_spreads(fragment.selection_set)
if not spread_nodes:
return
spread_path = self.spread_path
spread_path_index = self.spread_path_index_by_name
spread_path_index[fragment_name] = len(spread_path)
get_fragment = self.context.get_fragment
for spread_node in spread_nodes:
spread_name = spread_node.name.value
cycle_index = spread_path_index.get(spread_name)
spread_path.append(spread_node)
if cycle_index is None:
spread_fragment = get_fragment(spread_name)
if spread_fragment:
self.detect_cycle_recursive(spread_fragment)
else:
cycle_path = spread_path[cycle_index:]
via_path = ", ".join("'" + s.name.value + "'" for s in cycle_path[:-1])
self.report_error(
GraphQLError(
f"Cannot spread fragment '{spread_name}' within itself"
+ (f" via {via_path}." if via_path else "."),
cycle_path,
)
)
spread_path.pop()
del spread_path_index[fragment_name]
@@ -0,0 +1,50 @@
from typing import Any, Set
from ...error import GraphQLError
from ...language import OperationDefinitionNode, VariableDefinitionNode
from . import ValidationContext, ValidationRule
__all__ = ["NoUndefinedVariablesRule"]
class NoUndefinedVariablesRule(ValidationRule):
"""No undefined variables
A GraphQL operation is only valid if all variables encountered, both directly and
via fragment spreads, are defined by that operation.
See https://spec.graphql.org/draft/#sec-All-Variable-Uses-Defined
"""
def __init__(self, context: ValidationContext):
super().__init__(context)
self.defined_variable_names: Set[str] = set()
def enter_operation_definition(self, *_args: Any) -> None:
self.defined_variable_names.clear()
def leave_operation_definition(
self, operation: OperationDefinitionNode, *_args: Any
) -> None:
usages = self.context.get_recursive_variable_usages(operation)
defined_variables = self.defined_variable_names
for usage in usages:
node = usage.node
var_name = node.name.value
if var_name not in defined_variables:
self.report_error(
GraphQLError(
(
f"Variable '${var_name}' is not defined"
f" by operation '{operation.name.value}'."
if operation.name
else f"Variable '${var_name}' is not defined."
),
[node, operation],
)
)
def enter_variable_definition(
self, node: VariableDefinitionNode, *_args: Any
) -> None:
self.defined_variable_names.add(node.variable.name.value)
@@ -0,0 +1,53 @@
from typing import Any, List
from ...error import GraphQLError
from ...language import (
FragmentDefinitionNode,
OperationDefinitionNode,
VisitorAction,
SKIP,
)
from . import ASTValidationContext, ASTValidationRule
__all__ = ["NoUnusedFragmentsRule"]
class NoUnusedFragmentsRule(ASTValidationRule):
"""No unused fragments
A GraphQL document is only valid if all fragment definitions are spread within
operations, or spread within other fragments spread within operations.
See https://spec.graphql.org/draft/#sec-Fragments-Must-Be-Used
"""
def __init__(self, context: ASTValidationContext):
super().__init__(context)
self.operation_defs: List[OperationDefinitionNode] = []
self.fragment_defs: List[FragmentDefinitionNode] = []
def enter_operation_definition(
self, node: OperationDefinitionNode, *_args: Any
) -> VisitorAction:
self.operation_defs.append(node)
return SKIP
def enter_fragment_definition(
self, node: FragmentDefinitionNode, *_args: Any
) -> VisitorAction:
self.fragment_defs.append(node)
return SKIP
def leave_document(self, *_args: Any) -> None:
fragment_names_used = set()
get_fragments = self.context.get_recursively_referenced_fragments
for operation in self.operation_defs:
for fragment in get_fragments(operation):
fragment_names_used.add(fragment.name.value)
for fragment_def in self.fragment_defs:
frag_name = fragment_def.name.value
if frag_name not in fragment_names_used:
self.report_error(
GraphQLError(f"Fragment '{frag_name}' is never used.", fragment_def)
)
@@ -0,0 +1,53 @@
from typing import Any, List, Set
from ...error import GraphQLError
from ...language import OperationDefinitionNode, VariableDefinitionNode
from . import ValidationContext, ValidationRule
__all__ = ["NoUnusedVariablesRule"]
class NoUnusedVariablesRule(ValidationRule):
"""No unused variables
A GraphQL operation is only valid if all variables defined by an operation are used,
either directly or within a spread fragment.
See https://spec.graphql.org/draft/#sec-All-Variables-Used
"""
def __init__(self, context: ValidationContext):
super().__init__(context)
self.variable_defs: List[VariableDefinitionNode] = []
def enter_operation_definition(self, *_args: Any) -> None:
self.variable_defs.clear()
def leave_operation_definition(
self, operation: OperationDefinitionNode, *_args: Any
) -> None:
variable_name_used: Set[str] = set()
usages = self.context.get_recursive_variable_usages(operation)
for usage in usages:
variable_name_used.add(usage.node.name.value)
for variable_def in self.variable_defs:
variable_name = variable_def.variable.name.value
if variable_name not in variable_name_used:
self.report_error(
GraphQLError(
(
f"Variable '${variable_name}' is never used"
f" in operation '{operation.name.value}'."
if operation.name
else f"Variable '${variable_name}' is never used."
),
variable_def,
)
)
def enter_variable_definition(
self, definition: VariableDefinitionNode, *_args: Any
) -> None:
self.variable_defs.append(definition)
@@ -0,0 +1,783 @@
from itertools import chain
from typing import Any, Dict, List, Optional, Tuple, Union, cast
from ...error import GraphQLError
from ...language import (
DirectiveNode,
FieldNode,
FragmentDefinitionNode,
FragmentSpreadNode,
InlineFragmentNode,
SelectionSetNode,
ValueNode,
print_ast,
)
from ...type import (
GraphQLCompositeType,
GraphQLField,
GraphQLList,
GraphQLNamedType,
GraphQLNonNull,
GraphQLOutputType,
get_named_type,
is_interface_type,
is_leaf_type,
is_list_type,
is_non_null_type,
is_object_type,
)
from ...utilities import type_from_ast
from ...utilities.sort_value_node import sort_value_node
from . import ValidationContext, ValidationRule
MYPY = False
__all__ = ["OverlappingFieldsCanBeMergedRule"]
def reason_message(reason: "ConflictReasonMessage") -> str:
if isinstance(reason, list):
return " and ".join(
f"subfields '{response_name}' conflict"
f" because {reason_message(sub_reason)}"
for response_name, sub_reason in reason
)
return reason
class OverlappingFieldsCanBeMergedRule(ValidationRule):
"""Overlapping fields can be merged
A selection set is only valid if all fields (including spreading any fragments)
either correspond to distinct response names or can be merged without ambiguity.
See https://spec.graphql.org/draft/#sec-Field-Selection-Merging
"""
def __init__(self, context: ValidationContext):
super().__init__(context)
# A memoization for when two fragments are compared "between" each other for
# conflicts. Two fragments may be compared many times, so memoizing this can
# dramatically improve the performance of this validator.
self.compared_fragment_pairs = PairSet()
# A cache for the "field map" and list of fragment names found in any given
# selection set. Selection sets may be asked for this information multiple
# times, so this improves the performance of this validator.
self.cached_fields_and_fragment_names: Dict = {}
def enter_selection_set(self, selection_set: SelectionSetNode, *_args: Any) -> None:
conflicts = find_conflicts_within_selection_set(
self.context,
self.cached_fields_and_fragment_names,
self.compared_fragment_pairs,
self.context.get_parent_type(),
selection_set,
)
for (reason_name, reason), fields1, fields2 in conflicts:
reason_msg = reason_message(reason)
self.report_error(
GraphQLError(
f"Fields '{reason_name}' conflict because {reason_msg}."
" Use different aliases on the fields to fetch both"
" if this was intentional.",
fields1 + fields2,
)
)
Conflict = Tuple["ConflictReason", List[FieldNode], List[FieldNode]]
# Field name and reason.
ConflictReason = Tuple[str, "ConflictReasonMessage"]
# Reason is a string, or a nested list of conflicts.
if MYPY: # recursive types not fully supported yet (/python/mypy/issues/731)
ConflictReasonMessage = Union[str, List]
else:
ConflictReasonMessage = Union[str, List[ConflictReason]]
# Tuple defining a field node in a context.
NodeAndDef = Tuple[GraphQLCompositeType, FieldNode, Optional[GraphQLField]]
# Dictionary of lists of those.
NodeAndDefCollection = Dict[str, List[NodeAndDef]]
# Algorithm:
#
# Conflicts occur when two fields exist in a query which will produce the same
# response name, but represent differing values, thus creating a conflict.
# The algorithm below finds all conflicts via making a series of comparisons
# between fields. In order to compare as few fields as possible, this makes
# a series of comparisons "within" sets of fields and "between" sets of fields.
#
# Given any selection set, a collection produces both a set of fields by
# also including all inline fragments, as well as a list of fragments
# referenced by fragment spreads.
#
# A) Each selection set represented in the document first compares "within" its
# collected set of fields, finding any conflicts between every pair of
# overlapping fields.
# Note: This is the#only time* that a the fields "within" a set are compared
# to each other. After this only fields "between" sets are compared.
#
# B) Also, if any fragment is referenced in a selection set, then a
# comparison is made "between" the original set of fields and the
# referenced fragment.
#
# C) Also, if multiple fragments are referenced, then comparisons
# are made "between" each referenced fragment.
#
# D) When comparing "between" a set of fields and a referenced fragment, first
# a comparison is made between each field in the original set of fields and
# each field in the the referenced set of fields.
#
# E) Also, if any fragment is referenced in the referenced selection set,
# then a comparison is made "between" the original set of fields and the
# referenced fragment (recursively referring to step D).
#
# F) When comparing "between" two fragments, first a comparison is made between
# each field in the first referenced set of fields and each field in the the
# second referenced set of fields.
#
# G) Also, any fragments referenced by the first must be compared to the
# second, and any fragments referenced by the second must be compared to the
# first (recursively referring to step F).
#
# H) When comparing two fields, if both have selection sets, then a comparison
# is made "between" both selection sets, first comparing the set of fields in
# the first selection set with the set of fields in the second.
#
# I) Also, if any fragment is referenced in either selection set, then a
# comparison is made "between" the other set of fields and the
# referenced fragment.
#
# J) Also, if two fragments are referenced in both selection sets, then a
# comparison is made "between" the two fragments.
def find_conflicts_within_selection_set(
context: ValidationContext,
cached_fields_and_fragment_names: Dict,
compared_fragment_pairs: "PairSet",
parent_type: Optional[GraphQLNamedType],
selection_set: SelectionSetNode,
) -> List[Conflict]:
"""Find conflicts within selection set.
Find all conflicts found "within" a selection set, including those found via
spreading in fragments.
Called when visiting each SelectionSet in the GraphQL Document.
"""
conflicts: List[Conflict] = []
field_map, fragment_names = get_fields_and_fragment_names(
context, cached_fields_and_fragment_names, parent_type, selection_set
)
# (A) Find all conflicts "within" the fields of this selection set.
# Note: this is the *only place* `collect_conflicts_within` is called.
collect_conflicts_within(
context,
conflicts,
cached_fields_and_fragment_names,
compared_fragment_pairs,
field_map,
)
if fragment_names:
# (B) Then collect conflicts between these fields and those represented by each
# spread fragment name found.
for i, fragment_name in enumerate(fragment_names):
collect_conflicts_between_fields_and_fragment(
context,
conflicts,
cached_fields_and_fragment_names,
compared_fragment_pairs,
False,
field_map,
fragment_name,
)
# (C) Then compare this fragment with all other fragments found in this
# selection set to collect conflicts within fragments spread together.
# This compares each item in the list of fragment names to every other
# item in that same list (except for itself).
for other_fragment_name in fragment_names[i + 1 :]:
collect_conflicts_between_fragments(
context,
conflicts,
cached_fields_and_fragment_names,
compared_fragment_pairs,
False,
fragment_name,
other_fragment_name,
)
return conflicts
def collect_conflicts_between_fields_and_fragment(
context: ValidationContext,
conflicts: List[Conflict],
cached_fields_and_fragment_names: Dict,
compared_fragment_pairs: "PairSet",
are_mutually_exclusive: bool,
field_map: NodeAndDefCollection,
fragment_name: str,
) -> None:
"""Collect conflicts between fields and fragment.
Collect all conflicts found between a set of fields and a fragment reference
including via spreading in any nested fragments.
"""
fragment = context.get_fragment(fragment_name)
if not fragment:
return None
field_map2, referenced_fragment_names = get_referenced_fields_and_fragment_names(
context, cached_fields_and_fragment_names, fragment
)
# Do not compare a fragment's fieldMap to itself.
if field_map is field_map2:
return
# (D) First collect any conflicts between the provided collection of fields and the
# collection of fields represented by the given fragment.
collect_conflicts_between(
context,
conflicts,
cached_fields_and_fragment_names,
compared_fragment_pairs,
are_mutually_exclusive,
field_map,
field_map2,
)
# (E) Then collect any conflicts between the provided collection of fields and any
# fragment names found in the given fragment.
for referenced_fragment_name in referenced_fragment_names:
# Memoize so two fragments are not compared for conflicts more than once.
if compared_fragment_pairs.has(
referenced_fragment_name, fragment_name, are_mutually_exclusive
):
continue
compared_fragment_pairs.add(
referenced_fragment_name, fragment_name, are_mutually_exclusive
)
collect_conflicts_between_fields_and_fragment(
context,
conflicts,
cached_fields_and_fragment_names,
compared_fragment_pairs,
are_mutually_exclusive,
field_map,
referenced_fragment_name,
)
def collect_conflicts_between_fragments(
context: ValidationContext,
conflicts: List[Conflict],
cached_fields_and_fragment_names: Dict,
compared_fragment_pairs: "PairSet",
are_mutually_exclusive: bool,
fragment_name1: str,
fragment_name2: str,
) -> None:
"""Collect conflicts between fragments.
Collect all conflicts found between two fragments, including via spreading in any
nested fragments.
"""
# No need to compare a fragment to itself.
if fragment_name1 == fragment_name2:
return
# Memoize so two fragments are not compared for conflicts more than once.
if compared_fragment_pairs.has(
fragment_name1, fragment_name2, are_mutually_exclusive
):
return
compared_fragment_pairs.add(fragment_name1, fragment_name2, are_mutually_exclusive)
fragment1 = context.get_fragment(fragment_name1)
fragment2 = context.get_fragment(fragment_name2)
if not fragment1 or not fragment2:
return None
field_map1, referenced_fragment_names1 = get_referenced_fields_and_fragment_names(
context, cached_fields_and_fragment_names, fragment1
)
field_map2, referenced_fragment_names2 = get_referenced_fields_and_fragment_names(
context, cached_fields_and_fragment_names, fragment2
)
# (F) First, collect all conflicts between these two collections of fields
# (not including any nested fragments)
collect_conflicts_between(
context,
conflicts,
cached_fields_and_fragment_names,
compared_fragment_pairs,
are_mutually_exclusive,
field_map1,
field_map2,
)
# (G) Then collect conflicts between the first fragment and any nested fragments
# spread in the second fragment.
for referenced_fragment_name2 in referenced_fragment_names2:
collect_conflicts_between_fragments(
context,
conflicts,
cached_fields_and_fragment_names,
compared_fragment_pairs,
are_mutually_exclusive,
fragment_name1,
referenced_fragment_name2,
)
# (G) Then collect conflicts between the second fragment and any nested fragments
# spread in the first fragment.
for referenced_fragment_name1 in referenced_fragment_names1:
collect_conflicts_between_fragments(
context,
conflicts,
cached_fields_and_fragment_names,
compared_fragment_pairs,
are_mutually_exclusive,
referenced_fragment_name1,
fragment_name2,
)
def find_conflicts_between_sub_selection_sets(
context: ValidationContext,
cached_fields_and_fragment_names: Dict,
compared_fragment_pairs: "PairSet",
are_mutually_exclusive: bool,
parent_type1: Optional[GraphQLNamedType],
selection_set1: SelectionSetNode,
parent_type2: Optional[GraphQLNamedType],
selection_set2: SelectionSetNode,
) -> List[Conflict]:
"""Find conflicts between sub selection sets.
Find all conflicts found between two selection sets, including those found via
spreading in fragments. Called when determining if conflicts exist between the
sub-fields of two overlapping fields.
"""
conflicts: List[Conflict] = []
field_map1, fragment_names1 = get_fields_and_fragment_names(
context, cached_fields_and_fragment_names, parent_type1, selection_set1
)
field_map2, fragment_names2 = get_fields_and_fragment_names(
context, cached_fields_and_fragment_names, parent_type2, selection_set2
)
# (H) First, collect all conflicts between these two collections of field.
collect_conflicts_between(
context,
conflicts,
cached_fields_and_fragment_names,
compared_fragment_pairs,
are_mutually_exclusive,
field_map1,
field_map2,
)
# (I) Then collect conflicts between the first collection of fields and those
# referenced by each fragment name associated with the second.
if fragment_names2:
for fragment_name2 in fragment_names2:
collect_conflicts_between_fields_and_fragment(
context,
conflicts,
cached_fields_and_fragment_names,
compared_fragment_pairs,
are_mutually_exclusive,
field_map1,
fragment_name2,
)
# (I) Then collect conflicts between the second collection of fields and those
# referenced by each fragment name associated with the first.
if fragment_names1:
for fragment_name1 in fragment_names1:
collect_conflicts_between_fields_and_fragment(
context,
conflicts,
cached_fields_and_fragment_names,
compared_fragment_pairs,
are_mutually_exclusive,
field_map2,
fragment_name1,
)
# (J) Also collect conflicts between any fragment names by the first and fragment
# names by the second. This compares each item in the first set of names to each
# item in the second set of names.
for fragment_name1 in fragment_names1:
for fragment_name2 in fragment_names2:
collect_conflicts_between_fragments(
context,
conflicts,
cached_fields_and_fragment_names,
compared_fragment_pairs,
are_mutually_exclusive,
fragment_name1,
fragment_name2,
)
return conflicts
def collect_conflicts_within(
context: ValidationContext,
conflicts: List[Conflict],
cached_fields_and_fragment_names: Dict,
compared_fragment_pairs: "PairSet",
field_map: NodeAndDefCollection,
) -> None:
"""Collect all Conflicts "within" one collection of fields."""
# A field map is a keyed collection, where each key represents a response name and
# the value at that key is a list of all fields which provide that response name.
# For every response name, if there are multiple fields, they must be compared to
# find a potential conflict.
for response_name, fields in field_map.items():
# This compares every field in the list to every other field in this list
# (except to itself). If the list only has one item, nothing needs to be
# compared.
if len(fields) > 1:
for i, field in enumerate(fields):
for other_field in fields[i + 1 :]:
conflict = find_conflict(
context,
cached_fields_and_fragment_names,
compared_fragment_pairs,
# within one collection is never mutually exclusive
False,
response_name,
field,
other_field,
)
if conflict:
conflicts.append(conflict)
def collect_conflicts_between(
context: ValidationContext,
conflicts: List[Conflict],
cached_fields_and_fragment_names: Dict,
compared_fragment_pairs: "PairSet",
parent_fields_are_mutually_exclusive: bool,
field_map1: NodeAndDefCollection,
field_map2: NodeAndDefCollection,
) -> None:
"""Collect all Conflicts between two collections of fields.
This is similar to, but different from the :func:`~.collect_conflicts_within`
function above. This check assumes that :func:`~.collect_conflicts_within` has
already been called on each provided collection of fields. This is true because
this validator traverses each individual selection set.
"""
# A field map is a keyed collection, where each key represents a response name and
# the value at that key is a list of all fields which provide that response name.
# For any response name which appears in both provided field maps, each field from
# the first field map must be compared to every field in the second field map to
# find potential conflicts.
for response_name, fields1 in field_map1.items():
fields2 = field_map2.get(response_name)
if fields2:
for field1 in fields1:
for field2 in fields2:
conflict = find_conflict(
context,
cached_fields_and_fragment_names,
compared_fragment_pairs,
parent_fields_are_mutually_exclusive,
response_name,
field1,
field2,
)
if conflict:
conflicts.append(conflict)
def find_conflict(
context: ValidationContext,
cached_fields_and_fragment_names: Dict,
compared_fragment_pairs: "PairSet",
parent_fields_are_mutually_exclusive: bool,
response_name: str,
field1: NodeAndDef,
field2: NodeAndDef,
) -> Optional[Conflict]:
"""Find conflict.
Determines if there is a conflict between two particular fields, including comparing
their sub-fields.
"""
parent_type1, node1, def1 = field1
parent_type2, node2, def2 = field2
# If it is known that two fields could not possibly apply at the same time, due to
# the parent types, then it is safe to permit them to diverge in aliased field or
# arguments used as they will not present any ambiguity by differing. It is known
# that two parent types could never overlap if they are different Object types.
# Interface or Union types might overlap - if not in the current state of the
# schema, then perhaps in some future version, thus may not safely diverge.
are_mutually_exclusive = parent_fields_are_mutually_exclusive or (
parent_type1 != parent_type2
and is_object_type(parent_type1)
and is_object_type(parent_type2)
)
# The return type for each field.
type1 = cast(Optional[GraphQLOutputType], def1 and def1.type)
type2 = cast(Optional[GraphQLOutputType], def2 and def2.type)
if not are_mutually_exclusive:
# Two aliases must refer to the same field.
name1 = node1.name.value
name2 = node2.name.value
if name1 != name2:
return (
(response_name, f"'{name1}' and '{name2}' are different fields"),
[node1],
[node2],
)
# Two field calls must have the same arguments.
if not same_arguments(node1, node2):
return (response_name, "they have differing arguments"), [node1], [node2]
if type1 and type2 and do_types_conflict(type1, type2):
return (
(response_name, f"they return conflicting types '{type1}' and '{type2}'"),
[node1],
[node2],
)
# Collect and compare sub-fields. Use the same "visited fragment names" list for
# both collections so fields in a fragment reference are never compared to
# themselves.
selection_set1 = node1.selection_set
selection_set2 = node2.selection_set
if selection_set1 and selection_set2:
conflicts = find_conflicts_between_sub_selection_sets(
context,
cached_fields_and_fragment_names,
compared_fragment_pairs,
are_mutually_exclusive,
get_named_type(type1),
selection_set1,
get_named_type(type2),
selection_set2,
)
return subfield_conflicts(conflicts, response_name, node1, node2)
return None # no conflict
def same_arguments(
node1: Union[FieldNode, DirectiveNode], node2: Union[FieldNode, DirectiveNode]
) -> bool:
args1 = node1.arguments
args2 = node2.arguments
if not args1:
return not args2
if not args2:
return False
if len(args1) != len(args2):
return False # pragma: no cover
values2 = {arg.name.value: arg.value for arg in args2}
for arg1 in args1:
value1 = arg1.value
value2 = values2.get(arg1.name.value)
if value2 is None or stringify_value(value1) != stringify_value(value2):
return False
return True
def stringify_value(value: ValueNode) -> str:
return print_ast(sort_value_node(value))
def do_types_conflict(type1: GraphQLOutputType, type2: GraphQLOutputType) -> bool:
"""Check whether two types conflict
Two types conflict if both types could not apply to a value simultaneously.
Composite types are ignored as their individual field types will be compared later
recursively. However List and Non-Null types must match.
"""
if is_list_type(type1):
return (
do_types_conflict(
cast(GraphQLList, type1).of_type, cast(GraphQLList, type2).of_type
)
if is_list_type(type2)
else True
)
if is_list_type(type2):
return True
if is_non_null_type(type1):
return (
do_types_conflict(
cast(GraphQLNonNull, type1).of_type, cast(GraphQLNonNull, type2).of_type
)
if is_non_null_type(type2)
else True
)
if is_non_null_type(type2):
return True
if is_leaf_type(type1) or is_leaf_type(type2):
return type1 is not type2
return False
def get_fields_and_fragment_names(
context: ValidationContext,
cached_fields_and_fragment_names: Dict,
parent_type: Optional[GraphQLNamedType],
selection_set: SelectionSetNode,
) -> Tuple[NodeAndDefCollection, List[str]]:
"""Get fields and referenced fragment names
Given a selection set, return the collection of fields (a mapping of response name
to field nodes and definitions) as well as a list of fragment names referenced via
fragment spreads.
"""
cached = cached_fields_and_fragment_names.get(selection_set)
if not cached:
node_and_defs: NodeAndDefCollection = {}
fragment_names: Dict[str, bool] = {}
collect_fields_and_fragment_names(
context, parent_type, selection_set, node_and_defs, fragment_names
)
cached = (node_and_defs, list(fragment_names))
cached_fields_and_fragment_names[selection_set] = cached
return cached
def get_referenced_fields_and_fragment_names(
context: ValidationContext,
cached_fields_and_fragment_names: Dict,
fragment: FragmentDefinitionNode,
) -> Tuple[NodeAndDefCollection, List[str]]:
"""Get referenced fields and nested fragment names
Given a reference to a fragment, return the represented collection of fields as well
as a list of nested fragment names referenced via fragment spreads.
"""
# Short-circuit building a type from the node if possible.
cached = cached_fields_and_fragment_names.get(fragment.selection_set)
if cached:
return cached
fragment_type = type_from_ast(context.schema, fragment.type_condition)
return get_fields_and_fragment_names(
context, cached_fields_and_fragment_names, fragment_type, fragment.selection_set
)
def collect_fields_and_fragment_names(
context: ValidationContext,
parent_type: Optional[GraphQLNamedType],
selection_set: SelectionSetNode,
node_and_defs: NodeAndDefCollection,
fragment_names: Dict[str, bool],
) -> None:
for selection in selection_set.selections:
if isinstance(selection, FieldNode):
field_name = selection.name.value
field_def = (
parent_type.fields.get(field_name) # type: ignore
if is_object_type(parent_type) or is_interface_type(parent_type)
else None
)
response_name = selection.alias.value if selection.alias else field_name
if not node_and_defs.get(response_name):
node_and_defs[response_name] = []
node_and_defs[response_name].append(
cast(NodeAndDef, (parent_type, selection, field_def))
)
elif isinstance(selection, FragmentSpreadNode):
fragment_names[selection.name.value] = True
elif isinstance(selection, InlineFragmentNode): # pragma: no cover else
type_condition = selection.type_condition
inline_fragment_type = (
type_from_ast(context.schema, type_condition)
if type_condition
else parent_type
)
collect_fields_and_fragment_names(
context,
inline_fragment_type,
selection.selection_set,
node_and_defs,
fragment_names,
)
def subfield_conflicts(
conflicts: List[Conflict], response_name: str, node1: FieldNode, node2: FieldNode
) -> Optional[Conflict]:
"""Check whether there are conflicts between sub-fields.
Given a series of Conflicts which occurred between two sub-fields, generate a single
Conflict.
"""
if conflicts:
return (
(response_name, [conflict[0] for conflict in conflicts]),
list(chain([node1], *[conflict[1] for conflict in conflicts])),
list(chain([node2], *[conflict[2] for conflict in conflicts])),
)
return None # no conflict
class PairSet:
"""Pair set
A way to keep track of pairs of things when the ordering of the pair doesn't matter.
"""
__slots__ = ("_data",)
_data: Dict[str, Dict[str, bool]]
def __init__(self) -> None:
self._data = {}
def has(self, a: str, b: str, are_mutually_exclusive: bool) -> bool:
key1, key2 = (a, b) if a < b else (b, a)
map_ = self._data.get(key1)
if map_ is None:
return False
result = map_.get(key2)
if result is None:
return False
# are_mutually_exclusive being False is a superset of being True,
# hence if we want to know if this PairSet "has" these two with no exclusivity,
# we have to ensure it was added as such.
return True if are_mutually_exclusive else are_mutually_exclusive == result
def add(self, a: str, b: str, are_mutually_exclusive: bool) -> None:
key1, key2 = (a, b) if a < b else (b, a)
map_ = self._data.get(key1)
if map_ is None:
self._data[key1] = {key2: are_mutually_exclusive}
else:
map_[key2] = are_mutually_exclusive
@@ -0,0 +1,66 @@
from typing import cast, Any, Optional
from ...error import GraphQLError
from ...language import FragmentSpreadNode, InlineFragmentNode
from ...type import GraphQLCompositeType, is_composite_type
from ...utilities import do_types_overlap, type_from_ast
from . import ValidationRule
__all__ = ["PossibleFragmentSpreadsRule"]
class PossibleFragmentSpreadsRule(ValidationRule):
"""Possible fragment spread
A fragment spread is only valid if the type condition could ever possibly be true:
if there is a non-empty intersection of the possible parent types, and possible
types which pass the type condition.
"""
def enter_inline_fragment(self, node: InlineFragmentNode, *_args: Any) -> None:
context = self.context
frag_type = context.get_type()
parent_type = context.get_parent_type()
if (
is_composite_type(frag_type)
and is_composite_type(parent_type)
and not do_types_overlap(
context.schema,
cast(GraphQLCompositeType, frag_type),
cast(GraphQLCompositeType, parent_type),
)
):
context.report_error(
GraphQLError(
f"Fragment cannot be spread here as objects"
f" of type '{parent_type}' can never be of type '{frag_type}'.",
node,
)
)
def enter_fragment_spread(self, node: FragmentSpreadNode, *_args: Any) -> None:
context = self.context
frag_name = node.name.value
frag_type = self.get_fragment_type(frag_name)
parent_type = context.get_parent_type()
if (
frag_type
and parent_type
and not do_types_overlap(context.schema, frag_type, parent_type)
):
context.report_error(
GraphQLError(
f"Fragment '{frag_name}' cannot be spread here as objects"
f" of type '{parent_type}' can never be of type '{frag_type}'.",
node,
)
)
def get_fragment_type(self, name: str) -> Optional[GraphQLCompositeType]:
context = self.context
frag = context.get_fragment(name)
if frag:
type_ = type_from_ast(context.schema, frag.type_condition)
if is_composite_type(type_):
return cast(GraphQLCompositeType, type_)
return None
@@ -0,0 +1,109 @@
import re
from functools import partial
from typing import Any, Optional
from ...error import GraphQLError
from ...language import TypeDefinitionNode, TypeExtensionNode
from ...pyutils import did_you_mean, inspect, suggestion_list
from ...type import (
is_enum_type,
is_input_object_type,
is_interface_type,
is_object_type,
is_scalar_type,
is_union_type,
)
from . import SDLValidationContext, SDLValidationRule
__all__ = ["PossibleTypeExtensionsRule"]
class PossibleTypeExtensionsRule(SDLValidationRule):
"""Possible type extension
A type extension is only valid if the type is defined and has the same kind.
"""
def __init__(self, context: SDLValidationContext):
super().__init__(context)
self.schema = context.schema
self.defined_types = {
def_.name.value: def_
for def_ in context.document.definitions
if isinstance(def_, TypeDefinitionNode)
}
def check_extension(self, node: TypeExtensionNode, *_args: Any) -> None:
schema = self.schema
type_name = node.name.value
def_node = self.defined_types.get(type_name)
existing_type = schema.get_type(type_name) if schema else None
expected_kind: Optional[str]
if def_node:
expected_kind = def_kind_to_ext_kind(def_node.kind)
elif existing_type:
expected_kind = type_to_ext_kind(existing_type)
else:
expected_kind = None
if expected_kind:
if expected_kind != node.kind:
kind_str = extension_kind_to_type_name(node.kind)
self.report_error(
GraphQLError(
f"Cannot extend non-{kind_str} type '{type_name}'.",
[def_node, node] if def_node else node,
)
)
else:
all_type_names = list(self.defined_types)
if self.schema:
all_type_names.extend(self.schema.type_map)
suggested_types = suggestion_list(type_name, all_type_names)
self.report_error(
GraphQLError(
f"Cannot extend type '{type_name}' because it is not defined."
+ did_you_mean(suggested_types),
node.name,
)
)
enter_scalar_type_extension = enter_object_type_extension = check_extension
enter_interface_type_extension = enter_union_type_extension = check_extension
enter_enum_type_extension = enter_input_object_type_extension = check_extension
def_kind_to_ext_kind = partial(re.compile("(?<=_type_)definition$").sub, "extension")
def type_to_ext_kind(type_: Any) -> str:
if is_scalar_type(type_):
return "scalar_type_extension"
if is_object_type(type_):
return "object_type_extension"
if is_interface_type(type_):
return "interface_type_extension"
if is_union_type(type_):
return "union_type_extension"
if is_enum_type(type_):
return "enum_type_extension"
if is_input_object_type(type_):
return "input_object_type_extension"
# Not reachable. All possible types have been considered.
raise TypeError(f"Unexpected type: {inspect(type_)}.")
_type_names_for_extension_kinds = {
"scalar_type_extension": "scalar",
"object_type_extension": "object",
"interface_type_extension": "interface",
"union_type_extension": "union",
"enum_type_extension": "enum",
"input_object_type_extension": "input object",
}
def extension_kind_to_type_name(kind: str) -> str:
return _type_names_for_extension_kinds.get(kind, "unknown type")
@@ -0,0 +1,119 @@
from typing import cast, Any, Dict, List, Union
from ...error import GraphQLError
from ...language import (
DirectiveDefinitionNode,
DirectiveNode,
FieldNode,
InputValueDefinitionNode,
NonNullTypeNode,
TypeNode,
VisitorAction,
SKIP,
print_ast,
)
from ...type import GraphQLArgument, is_required_argument, is_type, specified_directives
from . import ASTValidationRule, SDLValidationContext, ValidationContext
__all__ = ["ProvidedRequiredArgumentsRule", "ProvidedRequiredArgumentsOnDirectivesRule"]
class ProvidedRequiredArgumentsOnDirectivesRule(ASTValidationRule):
"""Provided required arguments on directives
A directive is only valid if all required (non-null without a default value)
arguments have been provided.
For internal use only.
"""
context: Union[ValidationContext, SDLValidationContext]
def __init__(self, context: Union[ValidationContext, SDLValidationContext]):
super().__init__(context)
required_args_map: Dict[
str, Dict[str, Union[GraphQLArgument, InputValueDefinitionNode]]
] = {}
schema = context.schema
defined_directives = schema.directives if schema else specified_directives
for directive in cast(List, defined_directives):
required_args_map[directive.name] = {
name: arg
for name, arg in directive.args.items()
if is_required_argument(arg)
}
ast_definitions = context.document.definitions
for def_ in ast_definitions:
if isinstance(def_, DirectiveDefinitionNode):
required_args_map[def_.name.value] = {
arg.name.value: arg
for arg in filter(is_required_argument_node, def_.arguments or ())
}
self.required_args_map = required_args_map
def leave_directive(self, directive_node: DirectiveNode, *_args: Any) -> None:
# Validate on leave to allow for deeper errors to appear first.
directive_name = directive_node.name.value
required_args = self.required_args_map.get(directive_name)
if required_args:
arg_nodes = directive_node.arguments or ()
arg_node_set = {arg.name.value for arg in arg_nodes}
for arg_name in required_args:
if arg_name not in arg_node_set:
arg_type = required_args[arg_name].type
arg_type_str = (
str(arg_type)
if is_type(arg_type)
else print_ast(cast(TypeNode, arg_type))
)
self.report_error(
GraphQLError(
f"Directive '@{directive_name}' argument '{arg_name}'"
f" of type '{arg_type_str}' is required,"
" but it was not provided.",
directive_node,
)
)
class ProvidedRequiredArgumentsRule(ProvidedRequiredArgumentsOnDirectivesRule):
"""Provided required arguments
A field or directive is only valid if all required (non-null without a default
value) field arguments have been provided.
"""
context: ValidationContext
def __init__(self, context: ValidationContext):
super().__init__(context)
def leave_field(self, field_node: FieldNode, *_args: Any) -> VisitorAction:
# Validate on leave to allow for deeper errors to appear first.
field_def = self.context.get_field_def()
if not field_def:
return SKIP
arg_nodes = field_node.arguments or ()
arg_node_map = {arg.name.value: arg for arg in arg_nodes}
for arg_name, arg_def in field_def.args.items():
arg_node = arg_node_map.get(arg_name)
if not arg_node and is_required_argument(arg_def):
self.report_error(
GraphQLError(
f"Field '{field_node.name.value}' argument '{arg_name}'"
f" of type '{arg_def.type}' is required,"
" but it was not provided.",
field_node,
)
)
return None
def is_required_argument_node(arg: InputValueDefinitionNode) -> bool:
return isinstance(arg.type, NonNullTypeNode) and arg.default_value is None
@@ -0,0 +1,41 @@
from typing import Any
from ...error import GraphQLError
from ...language import FieldNode
from ...type import get_named_type, is_leaf_type
from . import ValidationRule
__all__ = ["ScalarLeafsRule"]
class ScalarLeafsRule(ValidationRule):
"""Scalar leafs
A GraphQL document is valid only if all leaf fields (fields without sub selections)
are of scalar or enum types.
"""
def enter_field(self, node: FieldNode, *_args: Any) -> None:
type_ = self.context.get_type()
if type_:
selection_set = node.selection_set
if is_leaf_type(get_named_type(type_)):
if selection_set:
field_name = node.name.value
self.report_error(
GraphQLError(
f"Field '{field_name}' must not have a selection"
f" since type '{type_}' has no subfields.",
selection_set,
)
)
elif not selection_set:
field_name = node.name.value
self.report_error(
GraphQLError(
f"Field '{field_name}' of type '{type_}'"
" must have a selection of subfields."
f" Did you mean '{field_name} {{ ... }}'?",
node,
)
)
@@ -0,0 +1,85 @@
from typing import Any, Dict, cast
from ...error import GraphQLError
from ...execution.collect_fields import collect_fields
from ...language import (
FieldNode,
FragmentDefinitionNode,
OperationDefinitionNode,
OperationType,
)
from . import ValidationRule
__all__ = ["SingleFieldSubscriptionsRule"]
class SingleFieldSubscriptionsRule(ValidationRule):
"""Subscriptions must only include a single non-introspection field.
A GraphQL subscription is valid only if it contains a single root field and
that root field is not an introspection field.
See https://spec.graphql.org/draft/#sec-Single-root-field
"""
def enter_operation_definition(
self, node: OperationDefinitionNode, *_args: Any
) -> None:
if node.operation != OperationType.SUBSCRIPTION:
return
schema = self.context.schema
subscription_type = schema.subscription_type
if subscription_type:
operation_name = node.name.value if node.name else None
variable_values: Dict[str, Any] = {}
document = self.context.document
fragments: Dict[str, FragmentDefinitionNode] = {
definition.name.value: definition
for definition in document.definitions
if isinstance(definition, FragmentDefinitionNode)
}
fields = collect_fields(
schema,
fragments,
variable_values,
subscription_type,
node.selection_set,
)
if len(fields) > 1:
field_selection_lists = list(fields.values())
extra_field_selection_lists = field_selection_lists[1:]
extra_field_selection = [
field
for fields in extra_field_selection_lists
for field in (
fields
if isinstance(fields, list)
else [cast(FieldNode, fields)]
)
]
self.report_error(
GraphQLError(
(
"Anonymous Subscription"
if operation_name is None
else f"Subscription '{operation_name}'"
)
+ " must select only one top level field.",
extra_field_selection,
)
)
for field_nodes in fields.values():
field = field_nodes[0]
field_name = field.name.value
if field_name.startswith("__"):
self.report_error(
GraphQLError(
(
"Anonymous Subscription"
if operation_name is None
else f"Subscription '{operation_name}'"
)
+ " must not select an introspection top level field.",
field_nodes,
)
)
@@ -0,0 +1,83 @@
from operator import attrgetter
from typing import Any, Collection
from ...error import GraphQLError
from ...language import (
DirectiveDefinitionNode,
FieldDefinitionNode,
InputValueDefinitionNode,
InterfaceTypeDefinitionNode,
InterfaceTypeExtensionNode,
NameNode,
ObjectTypeDefinitionNode,
ObjectTypeExtensionNode,
VisitorAction,
SKIP,
)
from ...pyutils import group_by
from . import SDLValidationRule
__all__ = ["UniqueArgumentDefinitionNamesRule"]
class UniqueArgumentDefinitionNamesRule(SDLValidationRule):
"""Unique argument definition names
A GraphQL Object or Interface type is only valid if all its fields have uniquely
named arguments.
A GraphQL Directive is only valid if all its arguments are uniquely named.
See https://spec.graphql.org/draft/#sec-Argument-Uniqueness
"""
def enter_directive_definition(
self, node: DirectiveDefinitionNode, *_args: Any
) -> VisitorAction:
return self.check_arg_uniqueness(f"@{node.name.value}", node.arguments)
def enter_interface_type_definition(
self, node: InterfaceTypeDefinitionNode, *_args: Any
) -> VisitorAction:
return self.check_arg_uniqueness_per_field(node.name, node.fields)
def enter_interface_type_extension(
self, node: InterfaceTypeExtensionNode, *_args: Any
) -> VisitorAction:
return self.check_arg_uniqueness_per_field(node.name, node.fields)
def enter_object_type_definition(
self, node: ObjectTypeDefinitionNode, *_args: Any
) -> VisitorAction:
return self.check_arg_uniqueness_per_field(node.name, node.fields)
def enter_object_type_extension(
self, node: ObjectTypeExtensionNode, *_args: Any
) -> VisitorAction:
return self.check_arg_uniqueness_per_field(node.name, node.fields)
def check_arg_uniqueness_per_field(
self,
name: NameNode,
fields: Collection[FieldDefinitionNode],
) -> VisitorAction:
type_name = name.value
for field_def in fields:
field_name = field_def.name.value
argument_nodes = field_def.arguments or ()
self.check_arg_uniqueness(f"{type_name}.{field_name}", argument_nodes)
return SKIP
def check_arg_uniqueness(
self, parent_name: str, argument_nodes: Collection[InputValueDefinitionNode]
) -> VisitorAction:
seen_args = group_by(argument_nodes, attrgetter("name.value"))
for arg_name, arg_nodes in seen_args.items():
if len(arg_nodes) > 1:
self.report_error(
GraphQLError(
f"Argument '{parent_name}({arg_name}:)'"
" can only be defined once.",
[node.name for node in arg_nodes],
)
)
return SKIP
@@ -0,0 +1,37 @@
from operator import attrgetter
from typing import Any, Collection
from ...error import GraphQLError
from ...language import ArgumentNode, DirectiveNode, FieldNode
from ...pyutils import group_by
from . import ASTValidationRule
__all__ = ["UniqueArgumentNamesRule"]
class UniqueArgumentNamesRule(ASTValidationRule):
"""Unique argument names
A GraphQL field or directive is only valid if all supplied arguments are uniquely
named.
See https://spec.graphql.org/draft/#sec-Argument-Names
"""
def enter_field(self, node: FieldNode, *_args: Any) -> None:
self.check_arg_uniqueness(node.arguments)
def enter_directive(self, node: DirectiveNode, *args: Any) -> None:
self.check_arg_uniqueness(node.arguments)
def check_arg_uniqueness(self, argument_nodes: Collection[ArgumentNode]) -> None:
seen_args = group_by(argument_nodes, attrgetter("name.value"))
for arg_name, arg_nodes in seen_args.items():
if len(arg_nodes) > 1:
self.report_error(
GraphQLError(
f"There can be only one argument named '{arg_name}'.",
[node.name for node in arg_nodes],
)
)
@@ -0,0 +1,46 @@
from typing import Any, Dict
from ...error import GraphQLError
from ...language import DirectiveDefinitionNode, NameNode, VisitorAction, SKIP
from . import SDLValidationContext, SDLValidationRule
__all__ = ["UniqueDirectiveNamesRule"]
class UniqueDirectiveNamesRule(SDLValidationRule):
"""Unique directive names
A GraphQL document is only valid if all defined directives have unique names.
"""
def __init__(self, context: SDLValidationContext):
super().__init__(context)
self.known_directive_names: Dict[str, NameNode] = {}
self.schema = context.schema
def enter_directive_definition(
self, node: DirectiveDefinitionNode, *_args: Any
) -> VisitorAction:
directive_name = node.name.value
if self.schema and self.schema.get_directive(directive_name):
self.report_error(
GraphQLError(
f"Directive '@{directive_name}' already exists in the schema."
" It cannot be redefined.",
node.name,
)
)
else:
if directive_name in self.known_directive_names:
self.report_error(
GraphQLError(
f"There can be only one directive named '@{directive_name}'.",
[self.known_directive_names[directive_name], node.name],
)
)
else:
self.known_directive_names[directive_name] = node.name
return SKIP
return None
@@ -0,0 +1,85 @@
from collections import defaultdict
from typing import Any, Dict, List, Union, cast
from ...error import GraphQLError
from ...language import (
DirectiveDefinitionNode,
DirectiveNode,
Node,
SchemaDefinitionNode,
SchemaExtensionNode,
TypeDefinitionNode,
TypeExtensionNode,
is_type_definition_node,
is_type_extension_node,
)
from ...type import specified_directives
from . import ASTValidationRule, SDLValidationContext, ValidationContext
__all__ = ["UniqueDirectivesPerLocationRule"]
class UniqueDirectivesPerLocationRule(ASTValidationRule):
"""Unique directive names per location
A GraphQL document is only valid if all non-repeatable directives at a given
location are uniquely named.
See https://spec.graphql.org/draft/#sec-Directives-Are-Unique-Per-Location
"""
context: Union[ValidationContext, SDLValidationContext]
def __init__(self, context: Union[ValidationContext, SDLValidationContext]):
super().__init__(context)
unique_directive_map: Dict[str, bool] = {}
schema = context.schema
defined_directives = (
schema.directives if schema else cast(List, specified_directives)
)
for directive in defined_directives:
unique_directive_map[directive.name] = not directive.is_repeatable
ast_definitions = context.document.definitions
for def_ in ast_definitions:
if isinstance(def_, DirectiveDefinitionNode):
unique_directive_map[def_.name.value] = not def_.repeatable
self.unique_directive_map = unique_directive_map
self.schema_directives: Dict[str, DirectiveNode] = {}
self.type_directives_map: Dict[str, Dict[str, DirectiveNode]] = defaultdict(
dict
)
# Many different AST nodes may contain directives. Rather than listing them all,
# just listen for entering any node, and check to see if it defines any directives.
def enter(self, node: Node, *_args: Any) -> None:
directives = getattr(node, "directives", None)
if not directives:
return
directives = cast(List[DirectiveNode], directives)
if isinstance(node, (SchemaDefinitionNode, SchemaExtensionNode)):
seen_directives = self.schema_directives
elif is_type_definition_node(node) or is_type_extension_node(node):
node = cast(Union[TypeDefinitionNode, TypeExtensionNode], node)
type_name = node.name.value
seen_directives = self.type_directives_map[type_name]
else:
seen_directives = {}
for directive in directives:
directive_name = directive.name.value
if self.unique_directive_map.get(directive_name):
if directive_name in seen_directives:
self.report_error(
GraphQLError(
f"The directive '@{directive_name}'"
" can only be used once at this location.",
[seen_directives[directive_name], directive],
)
)
else:
seen_directives[directive_name] = directive
@@ -0,0 +1,61 @@
from collections import defaultdict
from typing import cast, Any, Dict
from ...error import GraphQLError
from ...language import NameNode, EnumTypeDefinitionNode, VisitorAction, SKIP
from ...type import is_enum_type, GraphQLEnumType
from . import SDLValidationContext, SDLValidationRule
__all__ = ["UniqueEnumValueNamesRule"]
class UniqueEnumValueNamesRule(SDLValidationRule):
"""Unique enum value names
A GraphQL enum type is only valid if all its values are uniquely named.
"""
def __init__(self, context: SDLValidationContext):
super().__init__(context)
schema = context.schema
self.existing_type_map = schema.type_map if schema else {}
self.known_value_names: Dict[str, Dict[str, NameNode]] = defaultdict(dict)
def check_value_uniqueness(
self, node: EnumTypeDefinitionNode, *_args: Any
) -> VisitorAction:
existing_type_map = self.existing_type_map
type_name = node.name.value
value_names = self.known_value_names[type_name]
for value_def in node.values or []:
value_name = value_def.name.value
existing_type = existing_type_map.get(type_name)
if (
is_enum_type(existing_type)
and value_name in cast(GraphQLEnumType, existing_type).values
):
self.report_error(
GraphQLError(
f"Enum value '{type_name}.{value_name}'"
" already exists in the schema."
" It cannot also be defined in this type extension.",
value_def.name,
)
)
elif value_name in value_names:
self.report_error(
GraphQLError(
f"Enum value '{type_name}.{value_name}'"
" can only be defined once.",
[value_names[value_name], value_def.name],
)
)
else:
value_names[value_name] = value_def.name
return SKIP
enter_enum_type_definition = check_value_uniqueness
enter_enum_type_extension = check_value_uniqueness
@@ -0,0 +1,67 @@
from collections import defaultdict
from typing import Any, Dict
from ...error import GraphQLError
from ...language import NameNode, ObjectTypeDefinitionNode, VisitorAction, SKIP
from ...type import is_object_type, is_interface_type, is_input_object_type
from . import SDLValidationContext, SDLValidationRule
__all__ = ["UniqueFieldDefinitionNamesRule"]
class UniqueFieldDefinitionNamesRule(SDLValidationRule):
"""Unique field definition names
A GraphQL complex type is only valid if all its fields are uniquely named.
"""
def __init__(self, context: SDLValidationContext):
super().__init__(context)
schema = context.schema
self.existing_type_map = schema.type_map if schema else {}
self.known_field_names: Dict[str, Dict[str, NameNode]] = defaultdict(dict)
def check_field_uniqueness(
self, node: ObjectTypeDefinitionNode, *_args: Any
) -> VisitorAction:
existing_type_map = self.existing_type_map
type_name = node.name.value
field_names = self.known_field_names[type_name]
for field_def in node.fields or []:
field_name = field_def.name.value
if has_field(existing_type_map.get(type_name), field_name):
self.report_error(
GraphQLError(
f"Field '{type_name}.{field_name}'"
" already exists in the schema."
" It cannot also be defined in this type extension.",
field_def.name,
)
)
elif field_name in field_names:
self.report_error(
GraphQLError(
f"Field '{type_name}.{field_name}'"
" can only be defined once.",
[field_names[field_name], field_def.name],
)
)
else:
field_names[field_name] = field_def.name
return SKIP
enter_input_object_type_definition = check_field_uniqueness
enter_input_object_type_extension = check_field_uniqueness
enter_interface_type_definition = check_field_uniqueness
enter_interface_type_extension = check_field_uniqueness
enter_object_type_definition = check_field_uniqueness
enter_object_type_extension = check_field_uniqueness
def has_field(type_: Any, field_name: str) -> bool:
if is_object_type(type_) or is_interface_type(type_) or is_input_object_type(type_):
return field_name in type_.fields
return False
@@ -0,0 +1,40 @@
from typing import Any, Dict
from ...error import GraphQLError
from ...language import NameNode, FragmentDefinitionNode, VisitorAction, SKIP
from . import ASTValidationContext, ASTValidationRule
__all__ = ["UniqueFragmentNamesRule"]
class UniqueFragmentNamesRule(ASTValidationRule):
"""Unique fragment names
A GraphQL document is only valid if all defined fragments have unique names.
See https://spec.graphql.org/draft/#sec-Fragment-Name-Uniqueness
"""
def __init__(self, context: ASTValidationContext):
super().__init__(context)
self.known_fragment_names: Dict[str, NameNode] = {}
@staticmethod
def enter_operation_definition(*_args: Any) -> VisitorAction:
return SKIP
def enter_fragment_definition(
self, node: FragmentDefinitionNode, *_args: Any
) -> VisitorAction:
known_fragment_names = self.known_fragment_names
fragment_name = node.name.value
if fragment_name in known_fragment_names:
self.report_error(
GraphQLError(
f"There can be only one fragment named '{fragment_name}'.",
[known_fragment_names[fragment_name], node.name],
)
)
else:
known_fragment_names[fragment_name] = node.name
return SKIP
@@ -0,0 +1,42 @@
from typing import Any, Dict, List
from ...error import GraphQLError
from ...language import NameNode, ObjectFieldNode
from . import ASTValidationContext, ASTValidationRule
__all__ = ["UniqueInputFieldNamesRule"]
class UniqueInputFieldNamesRule(ASTValidationRule):
"""Unique input field names
A GraphQL input object value is only valid if all supplied fields are uniquely
named.
See https://spec.graphql.org/draft/#sec-Input-Object-Field-Uniqueness
"""
def __init__(self, context: ASTValidationContext):
super().__init__(context)
self.known_names_stack: List[Dict[str, NameNode]] = []
self.known_names: Dict[str, NameNode] = {}
def enter_object_value(self, *_args: Any) -> None:
self.known_names_stack.append(self.known_names)
self.known_names = {}
def leave_object_value(self, *_args: Any) -> None:
self.known_names = self.known_names_stack.pop()
def enter_object_field(self, node: ObjectFieldNode, *_args: Any) -> None:
known_names = self.known_names
field_name = node.name.value
if field_name in known_names:
self.report_error(
GraphQLError(
f"There can be only one input field named '{field_name}'.",
[known_names[field_name], node.name],
)
)
else:
known_names[field_name] = node.name
@@ -0,0 +1,42 @@
from typing import Any, Dict
from ...error import GraphQLError
from ...language import NameNode, OperationDefinitionNode, VisitorAction, SKIP
from . import ASTValidationContext, ASTValidationRule
__all__ = ["UniqueOperationNamesRule"]
class UniqueOperationNamesRule(ASTValidationRule):
"""Unique operation names
A GraphQL document is only valid if all defined operations have unique names.
See https://spec.graphql.org/draft/#sec-Operation-Name-Uniqueness
"""
def __init__(self, context: ASTValidationContext):
super().__init__(context)
self.known_operation_names: Dict[str, NameNode] = {}
def enter_operation_definition(
self, node: OperationDefinitionNode, *_args: Any
) -> VisitorAction:
operation_name = node.name
if operation_name:
known_operation_names = self.known_operation_names
if operation_name.value in known_operation_names:
self.report_error(
GraphQLError(
"There can be only one operation"
f" named '{operation_name.value}'.",
[known_operation_names[operation_name.value], operation_name],
)
)
else:
known_operation_names[operation_name.value] = operation_name
return SKIP
@staticmethod
def enter_fragment_definition(*_args: Any) -> VisitorAction:
return SKIP
@@ -0,0 +1,69 @@
from typing import Any, Dict, Optional, Union
from ...error import GraphQLError
from ...language import (
OperationTypeDefinitionNode,
OperationType,
SchemaDefinitionNode,
SchemaExtensionNode,
VisitorAction,
SKIP,
)
from ...type import GraphQLObjectType
from . import SDLValidationContext, SDLValidationRule
__all__ = ["UniqueOperationTypesRule"]
class UniqueOperationTypesRule(SDLValidationRule):
"""Unique operation types
A GraphQL document is only valid if it has only one type per operation.
"""
def __init__(self, context: SDLValidationContext):
super().__init__(context)
schema = context.schema
self.defined_operation_types: Dict[
OperationType, OperationTypeDefinitionNode
] = {}
self.existing_operation_types: Dict[
OperationType, Optional[GraphQLObjectType]
] = (
{
OperationType.QUERY: schema.query_type,
OperationType.MUTATION: schema.mutation_type,
OperationType.SUBSCRIPTION: schema.subscription_type,
}
if schema
else {}
)
self.schema = schema
def check_operation_types(
self, node: Union[SchemaDefinitionNode, SchemaExtensionNode], *_args: Any
) -> VisitorAction:
for operation_type in node.operation_types or []:
operation = operation_type.operation
already_defined_operation_type = self.defined_operation_types.get(operation)
if self.existing_operation_types.get(operation):
self.report_error(
GraphQLError(
f"Type for {operation.value} already defined in the schema."
" It cannot be redefined.",
operation_type,
)
)
elif already_defined_operation_type:
self.report_error(
GraphQLError(
f"There can be only one {operation.value} type in schema.",
[already_defined_operation_type, operation_type],
)
)
else:
self.defined_operation_types[operation] = operation_type
return SKIP
enter_schema_definition = enter_schema_extension = check_operation_types
@@ -0,0 +1,48 @@
from typing import Any, Dict
from ...error import GraphQLError
from ...language import NameNode, TypeDefinitionNode, VisitorAction, SKIP
from . import SDLValidationContext, SDLValidationRule
__all__ = ["UniqueTypeNamesRule"]
class UniqueTypeNamesRule(SDLValidationRule):
"""Unique type names
A GraphQL document is only valid if all defined types have unique names.
"""
def __init__(self, context: SDLValidationContext):
super().__init__(context)
self.known_type_names: Dict[str, NameNode] = {}
self.schema = context.schema
def check_type_name(self, node: TypeDefinitionNode, *_args: Any) -> VisitorAction:
type_name = node.name.value
if self.schema and self.schema.get_type(type_name):
self.report_error(
GraphQLError(
f"Type '{type_name}' already exists in the schema."
" It cannot also be defined in this type definition.",
node.name,
)
)
else:
if type_name in self.known_type_names:
self.report_error(
GraphQLError(
f"There can be only one type named '{type_name}'.",
[self.known_type_names[type_name], node.name],
)
)
else:
self.known_type_names[type_name] = node.name
return SKIP
return None
enter_scalar_type_definition = enter_object_type_definition = check_type_name
enter_interface_type_definition = enter_union_type_definition = check_type_name
enter_enum_type_definition = enter_input_object_type_definition = check_type_name
@@ -0,0 +1,34 @@
from operator import attrgetter
from typing import Any
from ...error import GraphQLError
from ...language import OperationDefinitionNode
from ...pyutils import group_by
from . import ASTValidationRule
__all__ = ["UniqueVariableNamesRule"]
class UniqueVariableNamesRule(ASTValidationRule):
"""Unique variable names
A GraphQL operation is only valid if all its variables are uniquely named.
"""
def enter_operation_definition(
self, node: OperationDefinitionNode, *_args: Any
) -> None:
variable_definitions = node.variable_definitions
seen_variable_definitions = group_by(
variable_definitions, attrgetter("variable.name.value")
)
for variable_name, variable_nodes in seen_variable_definitions.items():
if len(variable_nodes) > 1:
self.report_error(
GraphQLError(
f"There can be only one variable named '${variable_name}'.",
[node.variable.name for node in variable_nodes],
)
)
@@ -0,0 +1,163 @@
from typing import cast, Any
from ...error import GraphQLError
from ...language import (
BooleanValueNode,
EnumValueNode,
FloatValueNode,
IntValueNode,
NullValueNode,
ListValueNode,
ObjectFieldNode,
ObjectValueNode,
StringValueNode,
ValueNode,
VisitorAction,
SKIP,
print_ast,
)
from ...pyutils import did_you_mean, suggestion_list, Undefined
from ...type import (
GraphQLInputObjectType,
GraphQLScalarType,
get_named_type,
get_nullable_type,
is_input_object_type,
is_leaf_type,
is_list_type,
is_non_null_type,
is_required_input_field,
)
from . import ValidationRule
__all__ = ["ValuesOfCorrectTypeRule"]
class ValuesOfCorrectTypeRule(ValidationRule):
"""Value literals of correct type
A GraphQL document is only valid if all value literals are of the type expected at
their position.
See https://spec.graphql.org/draft/#sec-Values-of-Correct-Type
"""
def enter_list_value(self, node: ListValueNode, *_args: Any) -> VisitorAction:
# Note: TypeInfo will traverse into a list's item type, so look to the parent
# input type to check if it is a list.
type_ = get_nullable_type(self.context.get_parent_input_type()) # type: ignore
if not is_list_type(type_):
self.is_valid_value_node(node)
return SKIP # Don't traverse further.
return None
def enter_object_value(self, node: ObjectValueNode, *_args: Any) -> VisitorAction:
type_ = get_named_type(self.context.get_input_type())
if not is_input_object_type(type_):
self.is_valid_value_node(node)
return SKIP # Don't traverse further.
type_ = cast(GraphQLInputObjectType, type_)
# Ensure every required field exists.
field_node_map = {field.name.value: field for field in node.fields}
for field_name, field_def in type_.fields.items():
field_node = field_node_map.get(field_name)
if not field_node and is_required_input_field(field_def):
field_type = field_def.type
self.report_error(
GraphQLError(
f"Field '{type_.name}.{field_name}' of required type"
f" '{field_type}' was not provided.",
node,
)
)
return None
def enter_object_field(self, node: ObjectFieldNode, *_args: Any) -> None:
parent_type = get_named_type(self.context.get_parent_input_type())
field_type = self.context.get_input_type()
if not field_type and is_input_object_type(parent_type):
parent_type = cast(GraphQLInputObjectType, parent_type)
suggestions = suggestion_list(node.name.value, list(parent_type.fields))
self.report_error(
GraphQLError(
f"Field '{node.name.value}'"
f" is not defined by type '{parent_type.name}'."
+ did_you_mean(suggestions),
node,
)
)
def enter_null_value(self, node: NullValueNode, *_args: Any) -> None:
type_ = self.context.get_input_type()
if is_non_null_type(type_):
self.report_error(
GraphQLError(
f"Expected value of type '{type_}', found {print_ast(node)}.", node
)
)
def enter_enum_value(self, node: EnumValueNode, *_args: Any) -> None:
self.is_valid_value_node(node)
def enter_int_value(self, node: IntValueNode, *_args: Any) -> None:
self.is_valid_value_node(node)
def enter_float_value(self, node: FloatValueNode, *_args: Any) -> None:
self.is_valid_value_node(node)
def enter_string_value(self, node: StringValueNode, *_args: Any) -> None:
self.is_valid_value_node(node)
def enter_boolean_value(self, node: BooleanValueNode, *_args: Any) -> None:
self.is_valid_value_node(node)
def is_valid_value_node(self, node: ValueNode) -> None:
"""Check whether this is a valid value node.
Any value literal may be a valid representation of a Scalar, depending on that
scalar type.
"""
# Report any error at the full type expected by the location.
location_type = self.context.get_input_type()
if not location_type:
return
type_ = get_named_type(location_type)
if not is_leaf_type(type_):
self.report_error(
GraphQLError(
f"Expected value of type '{location_type}',"
f" found {print_ast(node)}.",
node,
)
)
return
# Scalars determine if a literal value is valid via `parse_literal()` which may
# throw or return an invalid value to indicate failure.
type_ = cast(GraphQLScalarType, type_)
try:
parse_result = type_.parse_literal(node)
if parse_result is Undefined:
self.report_error(
GraphQLError(
f"Expected value of type '{location_type}',"
f" found {print_ast(node)}.",
node,
)
)
except GraphQLError as error:
self.report_error(error)
except Exception as error:
self.report_error(
GraphQLError(
f"Expected value of type '{location_type}',"
f" found {print_ast(node)}; {error}",
node,
# Ensure a reference to the original error is maintained.
original_error=error,
)
)
return
@@ -0,0 +1,36 @@
from typing import Any
from ...error import GraphQLError
from ...language import VariableDefinitionNode, print_ast
from ...type import is_input_type
from ...utilities import type_from_ast
from . import ValidationRule
__all__ = ["VariablesAreInputTypesRule"]
class VariablesAreInputTypesRule(ValidationRule):
"""Variables are input types
A GraphQL operation is only valid if all the variables it defines are of input types
(scalar, enum, or input object).
See https://spec.graphql.org/draft/#sec-Variables-Are-Input-Types
"""
def enter_variable_definition(
self, node: VariableDefinitionNode, *_args: Any
) -> None:
type_ = type_from_ast(self.context.schema, node.type)
# If the variable type is not an input type, return an error.
if type_ is not None and not is_input_type(type_):
variable_name = node.variable.name.value
type_name = print_ast(node.type)
self.report_error(
GraphQLError(
f"Variable '${variable_name}'"
f" cannot be non-input type '{type_name}'.",
node.type,
)
)
@@ -0,0 +1,93 @@
from typing import Any, Dict, Optional, cast
from ...error import GraphQLError
from ...language import (
NullValueNode,
OperationDefinitionNode,
ValueNode,
VariableDefinitionNode,
)
from ...pyutils import Undefined
from ...type import GraphQLNonNull, GraphQLSchema, GraphQLType, is_non_null_type
from ...utilities import type_from_ast, is_type_sub_type_of
from . import ValidationContext, ValidationRule
__all__ = ["VariablesInAllowedPositionRule"]
class VariablesInAllowedPositionRule(ValidationRule):
"""Variables in allowed position
Variable usages must be compatible with the arguments they are passed to.
See https://spec.graphql.org/draft/#sec-All-Variable-Usages-are-Allowed
"""
def __init__(self, context: ValidationContext):
super().__init__(context)
self.var_def_map: Dict[str, Any] = {}
def enter_operation_definition(self, *_args: Any) -> None:
self.var_def_map.clear()
def leave_operation_definition(
self, operation: OperationDefinitionNode, *_args: Any
) -> None:
var_def_map = self.var_def_map
usages = self.context.get_recursive_variable_usages(operation)
for usage in usages:
node, type_ = usage.node, usage.type
default_value = usage.default_value
var_name = node.name.value
var_def = var_def_map.get(var_name)
if var_def and type_:
# A var type is allowed if it is the same or more strict (e.g. is a
# subtype of) than the expected type. It can be more strict if the
# variable type is non-null when the expected type is nullable. If both
# are list types, the variable item type can be more strict than the
# expected item type (contravariant).
schema = self.context.schema
var_type = type_from_ast(schema, var_def.type)
if var_type and not allowed_variable_usage(
schema, var_type, var_def.default_value, type_, default_value
):
self.report_error(
GraphQLError(
f"Variable '${var_name}' of type '{var_type}' used"
f" in position expecting type '{type_}'.",
[var_def, node],
)
)
def enter_variable_definition(
self, node: VariableDefinitionNode, *_args: Any
) -> None:
self.var_def_map[node.variable.name.value] = node
def allowed_variable_usage(
schema: GraphQLSchema,
var_type: GraphQLType,
var_default_value: Optional[ValueNode],
location_type: GraphQLType,
location_default_value: Any,
) -> bool:
"""Check for allowed variable usage.
Returns True if the variable is allowed in the location it was found, which includes
considering if default values exist for either the variable or the location at which
it is located.
"""
if is_non_null_type(location_type) and not is_non_null_type(var_type):
has_non_null_variable_default_value = (
var_default_value is not None
and not isinstance(var_default_value, NullValueNode)
)
has_location_default_value = location_default_value is not Undefined
if not has_non_null_variable_default_value and not has_location_default_value:
return False
location_type = cast(GraphQLNonNull, location_type)
nullable_location_type = location_type.of_type
return is_type_sub_type_of(schema, var_type, nullable_location_type)
return is_type_sub_type_of(schema, var_type, location_type)
@@ -0,0 +1,157 @@
from typing import Tuple, Type
from .rules import ASTValidationRule
# Spec Section: "Executable Definitions"
from .rules.executable_definitions import ExecutableDefinitionsRule
# Spec Section: "Operation Name Uniqueness"
from .rules.unique_operation_names import UniqueOperationNamesRule
# Spec Section: "Lone Anonymous Operation"
from .rules.lone_anonymous_operation import LoneAnonymousOperationRule
# Spec Section: "Subscriptions with Single Root Field"
from .rules.single_field_subscriptions import SingleFieldSubscriptionsRule
# Spec Section: "Fragment Spread Type Existence"
from .rules.known_type_names import KnownTypeNamesRule
# Spec Section: "Fragments on Composite Types"
from .rules.fragments_on_composite_types import FragmentsOnCompositeTypesRule
# Spec Section: "Variables are Input Types"
from .rules.variables_are_input_types import VariablesAreInputTypesRule
# Spec Section: "Leaf Field Selections"
from .rules.scalar_leafs import ScalarLeafsRule
# Spec Section: "Field Selections on Objects, Interfaces, and Unions Types"
from .rules.fields_on_correct_type import FieldsOnCorrectTypeRule
# Spec Section: "Fragment Name Uniqueness"
from .rules.unique_fragment_names import UniqueFragmentNamesRule
# Spec Section: "Fragment spread target defined"
from .rules.known_fragment_names import KnownFragmentNamesRule
# Spec Section: "Fragments must be used"
from .rules.no_unused_fragments import NoUnusedFragmentsRule
# Spec Section: "Fragment spread is possible"
from .rules.possible_fragment_spreads import PossibleFragmentSpreadsRule
# Spec Section: "Fragments must not form cycles"
from .rules.no_fragment_cycles import NoFragmentCyclesRule
# Spec Section: "Variable Uniqueness"
from .rules.unique_variable_names import UniqueVariableNamesRule
# Spec Section: "All Variable Used Defined"
from .rules.no_undefined_variables import NoUndefinedVariablesRule
# Spec Section: "All Variables Used"
from .rules.no_unused_variables import NoUnusedVariablesRule
# Spec Section: "Directives Are Defined"
from .rules.known_directives import KnownDirectivesRule
# Spec Section: "Directives Are Unique Per Location"
from .rules.unique_directives_per_location import UniqueDirectivesPerLocationRule
# Spec Section: "Argument Names"
from .rules.known_argument_names import KnownArgumentNamesRule
from .rules.known_argument_names import KnownArgumentNamesOnDirectivesRule
# Spec Section: "Argument Uniqueness"
from .rules.unique_argument_names import UniqueArgumentNamesRule
# Spec Section: "Value Type Correctness"
from .rules.values_of_correct_type import ValuesOfCorrectTypeRule
# Spec Section: "Argument Optionality"
from .rules.provided_required_arguments import ProvidedRequiredArgumentsRule
from .rules.provided_required_arguments import ProvidedRequiredArgumentsOnDirectivesRule
# Spec Section: "All Variable Usages Are Allowed"
from .rules.variables_in_allowed_position import VariablesInAllowedPositionRule
# Spec Section: "Field Selection Merging"
from .rules.overlapping_fields_can_be_merged import OverlappingFieldsCanBeMergedRule
# Spec Section: "Input Object Field Uniqueness"
from .rules.unique_input_field_names import UniqueInputFieldNamesRule
# Schema definition language:
from .rules.lone_schema_definition import LoneSchemaDefinitionRule
from .rules.unique_operation_types import UniqueOperationTypesRule
from .rules.unique_type_names import UniqueTypeNamesRule
from .rules.unique_enum_value_names import UniqueEnumValueNamesRule
from .rules.unique_field_definition_names import UniqueFieldDefinitionNamesRule
from .rules.unique_argument_definition_names import UniqueArgumentDefinitionNamesRule
from .rules.unique_directive_names import UniqueDirectiveNamesRule
from .rules.possible_type_extensions import PossibleTypeExtensionsRule
__all__ = ["specified_rules", "specified_sdl_rules"]
# This list includes all validation rules defined by the GraphQL spec.
#
# The order of the rules in this list has been adjusted to lead to the
# most clear output when encountering multiple validation errors.
specified_rules: Tuple[Type[ASTValidationRule], ...] = (
ExecutableDefinitionsRule,
UniqueOperationNamesRule,
LoneAnonymousOperationRule,
SingleFieldSubscriptionsRule,
KnownTypeNamesRule,
FragmentsOnCompositeTypesRule,
VariablesAreInputTypesRule,
ScalarLeafsRule,
FieldsOnCorrectTypeRule,
UniqueFragmentNamesRule,
KnownFragmentNamesRule,
NoUnusedFragmentsRule,
PossibleFragmentSpreadsRule,
NoFragmentCyclesRule,
UniqueVariableNamesRule,
NoUndefinedVariablesRule,
NoUnusedVariablesRule,
KnownDirectivesRule,
UniqueDirectivesPerLocationRule,
KnownArgumentNamesRule,
UniqueArgumentNamesRule,
ValuesOfCorrectTypeRule,
ProvidedRequiredArgumentsRule,
VariablesInAllowedPositionRule,
OverlappingFieldsCanBeMergedRule,
UniqueInputFieldNamesRule,
)
"""A tuple with all validation rules defined by the GraphQL specification.
The order of the rules in this tuple has been adjusted to lead to the
most clear output when encountering multiple validation errors.
"""
specified_sdl_rules: Tuple[Type[ASTValidationRule], ...] = (
LoneSchemaDefinitionRule,
UniqueOperationTypesRule,
UniqueTypeNamesRule,
UniqueEnumValueNamesRule,
UniqueFieldDefinitionNamesRule,
UniqueArgumentDefinitionNamesRule,
UniqueDirectiveNamesRule,
KnownTypeNamesRule,
KnownDirectivesRule,
UniqueDirectivesPerLocationRule,
PossibleTypeExtensionsRule,
KnownArgumentNamesOnDirectivesRule,
UniqueArgumentNamesRule,
UniqueInputFieldNamesRule,
ProvidedRequiredArgumentsOnDirectivesRule,
)
"""This tuple includes all rules for validating SDL.
For internal use only.
"""
@@ -0,0 +1,133 @@
from typing import Collection, List, Optional, Type
from ..error import GraphQLError
from ..language import DocumentNode, ParallelVisitor, visit
from ..pyutils import inspect, is_collection
from ..type import GraphQLSchema, assert_valid_schema
from ..utilities import TypeInfo, TypeInfoVisitor
from .rules import ASTValidationRule
from .specified_rules import specified_rules, specified_sdl_rules
from .validation_context import SDLValidationContext, ValidationContext
__all__ = ["assert_valid_sdl", "assert_valid_sdl_extension", "validate", "validate_sdl"]
class ValidationAbortedError(RuntimeError):
"""Error when a validation has been aborted (error limit reached)."""
def validate(
schema: GraphQLSchema,
document_ast: DocumentNode,
rules: Optional[Collection[Type[ASTValidationRule]]] = None,
max_errors: Optional[int] = None,
type_info: Optional[TypeInfo] = None,
) -> List[GraphQLError]:
"""Implements the "Validation" section of the spec.
Validation runs synchronously, returning a list of encountered errors, or an empty
list if no errors were encountered and the document is valid.
A list of specific validation rules may be provided. If not provided, the default
list of rules defined by the GraphQL specification will be used.
Each validation rule is a ValidationRule object which is a visitor object that holds
a ValidationContext (see the language/visitor API). Visitor methods are expected to
return GraphQLErrors, or lists of GraphQLErrors when invalid.
Validate will stop validation after a ``max_errors`` limit has been reached.
Attackers can send pathologically invalid queries to induce a DoS attack,
so by default ``max_errors`` set to 100 errors.
Providing a custom TypeInfo instance is deprecated and will be removed in v3.3.
"""
if not document_ast or not isinstance(document_ast, DocumentNode):
raise TypeError("Must provide document.")
# If the schema used for validation is invalid, throw an error.
assert_valid_schema(schema)
if max_errors is None:
max_errors = 100
elif not isinstance(max_errors, int):
raise TypeError("The maximum number of errors must be passed as an int.")
if type_info is None:
type_info = TypeInfo(schema)
elif not isinstance(type_info, TypeInfo):
raise TypeError(f"Not a TypeInfo object: {inspect(type_info)}.")
if rules is None:
rules = specified_rules
elif not is_collection(rules) or not all(
isinstance(rule, type) and issubclass(rule, ASTValidationRule) for rule in rules
):
raise TypeError(
"Rules must be specified as a collection of ASTValidationRule subclasses."
)
errors: List[GraphQLError] = []
def on_error(error: GraphQLError) -> None:
if len(errors) >= max_errors:
errors.append(
GraphQLError(
"Too many validation errors, error limit reached."
" Validation aborted."
)
)
raise ValidationAbortedError
errors.append(error)
context = ValidationContext(schema, document_ast, type_info, on_error)
# This uses a specialized visitor which runs multiple visitors in parallel,
# while maintaining the visitor skip and break API.
visitors = [rule(context) for rule in rules]
# Visit the whole document with each instance of all provided rules.
try:
visit(document_ast, TypeInfoVisitor(type_info, ParallelVisitor(visitors)))
except ValidationAbortedError:
pass
return errors
def validate_sdl(
document_ast: DocumentNode,
schema_to_extend: Optional[GraphQLSchema] = None,
rules: Optional[Collection[Type[ASTValidationRule]]] = None,
) -> List[GraphQLError]:
"""Validate an SDL document.
For internal use only.
"""
errors: List[GraphQLError] = []
context = SDLValidationContext(document_ast, schema_to_extend, errors.append)
if rules is None:
rules = specified_sdl_rules
visitors = [rule(context) for rule in rules]
visit(document_ast, ParallelVisitor(visitors))
return errors
def assert_valid_sdl(document_ast: DocumentNode) -> None:
"""Assert document is valid SDL.
Utility function which asserts a SDL document is valid by throwing an error if it
is invalid.
"""
errors = validate_sdl(document_ast)
if errors:
raise TypeError("\n\n".join(error.message for error in errors))
def assert_valid_sdl_extension(
document_ast: DocumentNode, schema: GraphQLSchema
) -> None:
"""Assert document is a valid SDL extension.
Utility function which asserts a SDL document is valid by throwing an error if it
is invalid.
"""
errors = validate_sdl(document_ast, schema)
if errors:
raise TypeError("\n\n".join(error.message for error in errors))
@@ -0,0 +1,250 @@
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Union, cast
from ..error import GraphQLError
from ..language import (
DocumentNode,
FragmentDefinitionNode,
FragmentSpreadNode,
OperationDefinitionNode,
SelectionSetNode,
VariableNode,
Visitor,
VisitorAction,
visit,
)
from ..type import (
GraphQLArgument,
GraphQLCompositeType,
GraphQLDirective,
GraphQLEnumValue,
GraphQLField,
GraphQLInputType,
GraphQLOutputType,
GraphQLSchema,
)
from ..utilities import TypeInfo, TypeInfoVisitor
__all__ = [
"ASTValidationContext",
"SDLValidationContext",
"ValidationContext",
"VariableUsage",
"VariableUsageVisitor",
]
NodeWithSelectionSet = Union[OperationDefinitionNode, FragmentDefinitionNode]
class VariableUsage(NamedTuple):
node: VariableNode
type: Optional[GraphQLInputType]
default_value: Any
class VariableUsageVisitor(Visitor):
"""Visitor adding all variable usages to a given list."""
usages: List[VariableUsage]
def __init__(self, type_info: TypeInfo):
super().__init__()
self.usages = []
self._append_usage = self.usages.append
self._type_info = type_info
def enter_variable_definition(self, *_args: Any) -> VisitorAction:
return self.SKIP
def enter_variable(self, node: VariableNode, *_args: Any) -> VisitorAction:
type_info = self._type_info
usage = VariableUsage(
node, type_info.get_input_type(), type_info.get_default_value()
)
self._append_usage(usage)
return None
class ASTValidationContext:
"""Utility class providing a context for validation of an AST.
An instance of this class is passed as the context attribute to all Validators,
allowing access to commonly useful contextual information from within a validation
rule.
"""
document: DocumentNode
_fragments: Optional[Dict[str, FragmentDefinitionNode]]
_fragment_spreads: Dict[SelectionSetNode, List[FragmentSpreadNode]]
_recursively_referenced_fragments: Dict[
OperationDefinitionNode, List[FragmentDefinitionNode]
]
def __init__(
self, ast: DocumentNode, on_error: Callable[[GraphQLError], None]
) -> None:
self.document = ast
self.on_error = on_error # type: ignore
self._fragments = None
self._fragment_spreads = {}
self._recursively_referenced_fragments = {}
def on_error(self, error: GraphQLError) -> None:
pass
def report_error(self, error: GraphQLError) -> None:
self.on_error(error)
def get_fragment(self, name: str) -> Optional[FragmentDefinitionNode]:
fragments = self._fragments
if fragments is None:
fragments = {
statement.name.value: statement
for statement in self.document.definitions
if isinstance(statement, FragmentDefinitionNode)
}
self._fragments = fragments
return fragments.get(name)
def get_fragment_spreads(self, node: SelectionSetNode) -> List[FragmentSpreadNode]:
spreads = self._fragment_spreads.get(node)
if spreads is None:
spreads = []
append_spread = spreads.append
sets_to_visit = [node]
append_set = sets_to_visit.append
pop_set = sets_to_visit.pop
while sets_to_visit:
visited_set = pop_set()
for selection in visited_set.selections:
if isinstance(selection, FragmentSpreadNode):
append_spread(selection)
else:
set_to_visit = cast(
NodeWithSelectionSet, selection
).selection_set
if set_to_visit:
append_set(set_to_visit)
self._fragment_spreads[node] = spreads
return spreads
def get_recursively_referenced_fragments(
self, operation: OperationDefinitionNode
) -> List[FragmentDefinitionNode]:
fragments = self._recursively_referenced_fragments.get(operation)
if fragments is None:
fragments = []
append_fragment = fragments.append
collected_names: Set[str] = set()
add_name = collected_names.add
nodes_to_visit = [operation.selection_set]
append_node = nodes_to_visit.append
pop_node = nodes_to_visit.pop
get_fragment = self.get_fragment
get_fragment_spreads = self.get_fragment_spreads
while nodes_to_visit:
visited_node = pop_node()
for spread in get_fragment_spreads(visited_node):
frag_name = spread.name.value
if frag_name not in collected_names:
add_name(frag_name)
fragment = get_fragment(frag_name)
if fragment:
append_fragment(fragment)
append_node(fragment.selection_set)
self._recursively_referenced_fragments[operation] = fragments
return fragments
class SDLValidationContext(ASTValidationContext):
"""Utility class providing a context for validation of an SDL AST.
An instance of this class is passed as the context attribute to all Validators,
allowing access to commonly useful contextual information from within a validation
rule.
"""
schema: Optional[GraphQLSchema]
def __init__(
self,
ast: DocumentNode,
schema: Optional[GraphQLSchema],
on_error: Callable[[GraphQLError], None],
) -> None:
super().__init__(ast, on_error)
self.schema = schema
class ValidationContext(ASTValidationContext):
"""Utility class providing a context for validation using a GraphQL schema.
An instance of this class is passed as the context attribute to all Validators,
allowing access to commonly useful contextual information from within a validation
rule.
"""
schema: GraphQLSchema
_type_info: TypeInfo
_variable_usages: Dict[NodeWithSelectionSet, List[VariableUsage]]
_recursive_variable_usages: Dict[OperationDefinitionNode, List[VariableUsage]]
def __init__(
self,
schema: GraphQLSchema,
ast: DocumentNode,
type_info: TypeInfo,
on_error: Callable[[GraphQLError], None],
) -> None:
super().__init__(ast, on_error)
self.schema = schema
self._type_info = type_info
self._variable_usages = {}
self._recursive_variable_usages = {}
def get_variable_usages(self, node: NodeWithSelectionSet) -> List[VariableUsage]:
usages = self._variable_usages.get(node)
if usages is None:
usage_visitor = VariableUsageVisitor(self._type_info)
visit(node, TypeInfoVisitor(self._type_info, usage_visitor))
usages = usage_visitor.usages
self._variable_usages[node] = usages
return usages
def get_recursive_variable_usages(
self, operation: OperationDefinitionNode
) -> List[VariableUsage]:
usages = self._recursive_variable_usages.get(operation)
if usages is None:
get_variable_usages = self.get_variable_usages
usages = get_variable_usages(operation)
for fragment in self.get_recursively_referenced_fragments(operation):
usages.extend(get_variable_usages(fragment))
self._recursive_variable_usages[operation] = usages
return usages
def get_type(self) -> Optional[GraphQLOutputType]:
return self._type_info.get_type()
def get_parent_type(self) -> Optional[GraphQLCompositeType]:
return self._type_info.get_parent_type()
def get_input_type(self) -> Optional[GraphQLInputType]:
return self._type_info.get_input_type()
def get_parent_input_type(self) -> Optional[GraphQLInputType]:
return self._type_info.get_parent_input_type()
def get_field_def(self) -> Optional[GraphQLField]:
return self._type_info.get_field_def()
def get_directive(self) -> Optional[GraphQLDirective]:
return self._type_info.get_directive()
def get_argument(self) -> Optional[GraphQLArgument]:
return self._type_info.get_argument()
def get_enum_value(self) -> Optional[GraphQLEnumValue]:
return self._type_info.get_enum_value()