2025-12-01
This commit is contained in:
@@ -0,0 +1,157 @@
|
||||
"""GraphQL Validation
|
||||
|
||||
The :mod:`graphql.validation` package fulfills the Validation phase of fulfilling a
|
||||
GraphQL result.
|
||||
"""
|
||||
|
||||
from .validate import validate
|
||||
|
||||
from .validation_context import (
|
||||
ASTValidationContext,
|
||||
SDLValidationContext,
|
||||
ValidationContext,
|
||||
)
|
||||
|
||||
from .rules import ValidationRule, ASTValidationRule, SDLValidationRule
|
||||
|
||||
# All validation rules in the GraphQL Specification.
|
||||
from .specified_rules import specified_rules
|
||||
|
||||
# Spec Section: "Executable Definitions"
|
||||
from .rules.executable_definitions import ExecutableDefinitionsRule
|
||||
|
||||
# Spec Section: "Field Selections on Objects, Interfaces, and Unions Types"
|
||||
from .rules.fields_on_correct_type import FieldsOnCorrectTypeRule
|
||||
|
||||
# Spec Section: "Fragments on Composite Types"
|
||||
from .rules.fragments_on_composite_types import FragmentsOnCompositeTypesRule
|
||||
|
||||
# Spec Section: "Argument Names"
|
||||
from .rules.known_argument_names import KnownArgumentNamesRule
|
||||
|
||||
# Spec Section: "Directives Are Defined"
|
||||
from .rules.known_directives import KnownDirectivesRule
|
||||
|
||||
# Spec Section: "Fragment spread target defined"
|
||||
from .rules.known_fragment_names import KnownFragmentNamesRule
|
||||
|
||||
# Spec Section: "Fragment Spread Type Existence"
|
||||
from .rules.known_type_names import KnownTypeNamesRule
|
||||
|
||||
# Spec Section: "Lone Anonymous Operation"
|
||||
from .rules.lone_anonymous_operation import LoneAnonymousOperationRule
|
||||
|
||||
# Spec Section: "Fragments must not form cycles"
|
||||
from .rules.no_fragment_cycles import NoFragmentCyclesRule
|
||||
|
||||
# Spec Section: "All Variable Used Defined"
|
||||
from .rules.no_undefined_variables import NoUndefinedVariablesRule
|
||||
|
||||
# Spec Section: "Fragments must be used"
|
||||
from .rules.no_unused_fragments import NoUnusedFragmentsRule
|
||||
|
||||
# Spec Section: "All Variables Used"
|
||||
from .rules.no_unused_variables import NoUnusedVariablesRule
|
||||
|
||||
# Spec Section: "Field Selection Merging"
|
||||
from .rules.overlapping_fields_can_be_merged import OverlappingFieldsCanBeMergedRule
|
||||
|
||||
# Spec Section: "Fragment spread is possible"
|
||||
from .rules.possible_fragment_spreads import PossibleFragmentSpreadsRule
|
||||
|
||||
# Spec Section: "Argument Optionality"
|
||||
from .rules.provided_required_arguments import ProvidedRequiredArgumentsRule
|
||||
|
||||
# Spec Section: "Leaf Field Selections"
|
||||
from .rules.scalar_leafs import ScalarLeafsRule
|
||||
|
||||
# Spec Section: "Subscriptions with Single Root Field"
|
||||
from .rules.single_field_subscriptions import SingleFieldSubscriptionsRule
|
||||
|
||||
# Spec Section: "Argument Uniqueness"
|
||||
from .rules.unique_argument_names import UniqueArgumentNamesRule
|
||||
|
||||
# Spec Section: "Directives Are Unique Per Location"
|
||||
from .rules.unique_directives_per_location import UniqueDirectivesPerLocationRule
|
||||
|
||||
# Spec Section: "Fragment Name Uniqueness"
|
||||
from .rules.unique_fragment_names import UniqueFragmentNamesRule
|
||||
|
||||
# Spec Section: "Input Object Field Uniqueness"
|
||||
from .rules.unique_input_field_names import UniqueInputFieldNamesRule
|
||||
|
||||
# Spec Section: "Operation Name Uniqueness"
|
||||
from .rules.unique_operation_names import UniqueOperationNamesRule
|
||||
|
||||
# Spec Section: "Variable Uniqueness"
|
||||
from .rules.unique_variable_names import UniqueVariableNamesRule
|
||||
|
||||
# Spec Section: "Value Type Correctness"
|
||||
from .rules.values_of_correct_type import ValuesOfCorrectTypeRule
|
||||
|
||||
# Spec Section: "Variables are Input Types"
|
||||
from .rules.variables_are_input_types import VariablesAreInputTypesRule
|
||||
|
||||
# Spec Section: "All Variable Usages Are Allowed"
|
||||
from .rules.variables_in_allowed_position import VariablesInAllowedPositionRule
|
||||
|
||||
# SDL-specific validation rules
|
||||
from .rules.lone_schema_definition import LoneSchemaDefinitionRule
|
||||
from .rules.unique_operation_types import UniqueOperationTypesRule
|
||||
from .rules.unique_type_names import UniqueTypeNamesRule
|
||||
from .rules.unique_enum_value_names import UniqueEnumValueNamesRule
|
||||
from .rules.unique_field_definition_names import UniqueFieldDefinitionNamesRule
|
||||
from .rules.unique_argument_definition_names import UniqueArgumentDefinitionNamesRule
|
||||
from .rules.unique_directive_names import UniqueDirectiveNamesRule
|
||||
from .rules.possible_type_extensions import PossibleTypeExtensionsRule
|
||||
|
||||
# Optional rules not defined by the GraphQL Specification
|
||||
from .rules.custom.no_deprecated import NoDeprecatedCustomRule
|
||||
from .rules.custom.no_schema_introspection import NoSchemaIntrospectionCustomRule
|
||||
|
||||
__all__ = [
|
||||
"validate",
|
||||
"ASTValidationContext",
|
||||
"ASTValidationRule",
|
||||
"SDLValidationContext",
|
||||
"SDLValidationRule",
|
||||
"ValidationContext",
|
||||
"ValidationRule",
|
||||
"specified_rules",
|
||||
"ExecutableDefinitionsRule",
|
||||
"FieldsOnCorrectTypeRule",
|
||||
"FragmentsOnCompositeTypesRule",
|
||||
"KnownArgumentNamesRule",
|
||||
"KnownDirectivesRule",
|
||||
"KnownFragmentNamesRule",
|
||||
"KnownTypeNamesRule",
|
||||
"LoneAnonymousOperationRule",
|
||||
"NoFragmentCyclesRule",
|
||||
"NoUndefinedVariablesRule",
|
||||
"NoUnusedFragmentsRule",
|
||||
"NoUnusedVariablesRule",
|
||||
"OverlappingFieldsCanBeMergedRule",
|
||||
"PossibleFragmentSpreadsRule",
|
||||
"ProvidedRequiredArgumentsRule",
|
||||
"ScalarLeafsRule",
|
||||
"SingleFieldSubscriptionsRule",
|
||||
"UniqueArgumentNamesRule",
|
||||
"UniqueDirectivesPerLocationRule",
|
||||
"UniqueFragmentNamesRule",
|
||||
"UniqueInputFieldNamesRule",
|
||||
"UniqueOperationNamesRule",
|
||||
"UniqueVariableNamesRule",
|
||||
"ValuesOfCorrectTypeRule",
|
||||
"VariablesAreInputTypesRule",
|
||||
"VariablesInAllowedPositionRule",
|
||||
"LoneSchemaDefinitionRule",
|
||||
"UniqueOperationTypesRule",
|
||||
"UniqueTypeNamesRule",
|
||||
"UniqueEnumValueNamesRule",
|
||||
"UniqueFieldDefinitionNamesRule",
|
||||
"UniqueArgumentDefinitionNamesRule",
|
||||
"UniqueDirectiveNamesRule",
|
||||
"PossibleTypeExtensionsRule",
|
||||
"NoDeprecatedCustomRule",
|
||||
"NoSchemaIntrospectionCustomRule",
|
||||
]
|
||||
@@ -0,0 +1,42 @@
|
||||
"""graphql.validation.rules package"""
|
||||
|
||||
from ...error import GraphQLError
|
||||
from ...language.visitor import Visitor
|
||||
from ..validation_context import (
|
||||
ASTValidationContext,
|
||||
SDLValidationContext,
|
||||
ValidationContext,
|
||||
)
|
||||
|
||||
__all__ = ["ASTValidationRule", "SDLValidationRule", "ValidationRule"]
|
||||
|
||||
|
||||
class ASTValidationRule(Visitor):
|
||||
"""Visitor for validation of an AST."""
|
||||
|
||||
context: ASTValidationContext
|
||||
|
||||
def __init__(self, context: ASTValidationContext):
|
||||
super().__init__()
|
||||
self.context = context
|
||||
|
||||
def report_error(self, error: GraphQLError) -> None:
|
||||
self.context.report_error(error)
|
||||
|
||||
|
||||
class SDLValidationRule(ASTValidationRule):
|
||||
"""Visitor for validation of an SDL AST."""
|
||||
|
||||
context: SDLValidationContext
|
||||
|
||||
def __init__(self, context: SDLValidationContext) -> None:
|
||||
super().__init__(context)
|
||||
|
||||
|
||||
class ValidationRule(ASTValidationRule):
|
||||
"""Visitor for validation using a GraphQL schema."""
|
||||
|
||||
context: ValidationContext
|
||||
|
||||
def __init__(self, context: ValidationContext) -> None:
|
||||
super().__init__(context)
|
||||
@@ -0,0 +1 @@
|
||||
"""graphql.validation.rules.custom package"""
|
||||
+101
@@ -0,0 +1,101 @@
|
||||
from typing import Any, cast
|
||||
|
||||
from ....error import GraphQLError
|
||||
from ....language import ArgumentNode, EnumValueNode, FieldNode, ObjectFieldNode
|
||||
from ....type import GraphQLInputObjectType, get_named_type, is_input_object_type
|
||||
from .. import ValidationRule
|
||||
|
||||
__all__ = ["NoDeprecatedCustomRule"]
|
||||
|
||||
|
||||
class NoDeprecatedCustomRule(ValidationRule):
|
||||
"""No deprecated
|
||||
|
||||
A GraphQL document is only valid if all selected fields and all used enum values
|
||||
have not been deprecated.
|
||||
|
||||
Note: This rule is optional and is not part of the Validation section of the GraphQL
|
||||
Specification. The main purpose of this rule is detection of deprecated usages and
|
||||
not necessarily to forbid their use when querying a service.
|
||||
"""
|
||||
|
||||
def enter_field(self, node: FieldNode, *_args: Any) -> None:
|
||||
context = self.context
|
||||
field_def = context.get_field_def()
|
||||
if field_def:
|
||||
deprecation_reason = field_def.deprecation_reason
|
||||
if deprecation_reason is not None:
|
||||
parent_type = context.get_parent_type()
|
||||
parent_name = parent_type.name # type: ignore
|
||||
self.report_error(
|
||||
GraphQLError(
|
||||
f"The field {parent_name}.{node.name.value}"
|
||||
f" is deprecated. {deprecation_reason}",
|
||||
node,
|
||||
)
|
||||
)
|
||||
|
||||
def enter_argument(self, node: ArgumentNode, *_args: Any) -> None:
|
||||
context = self.context
|
||||
arg_def = context.get_argument()
|
||||
if arg_def:
|
||||
deprecation_reason = arg_def.deprecation_reason
|
||||
if deprecation_reason is not None:
|
||||
directive_def = context.get_directive()
|
||||
arg_name = node.name.value
|
||||
if directive_def is None:
|
||||
parent_type = context.get_parent_type()
|
||||
parent_name = parent_type.name # type: ignore
|
||||
field_def = context.get_field_def()
|
||||
field_name = field_def.ast_node.name.value # type: ignore
|
||||
self.report_error(
|
||||
GraphQLError(
|
||||
f"Field '{parent_name}.{field_name}' argument"
|
||||
f" '{arg_name}' is deprecated. {deprecation_reason}",
|
||||
node,
|
||||
)
|
||||
)
|
||||
else:
|
||||
self.report_error(
|
||||
GraphQLError(
|
||||
f"Directive '@{directive_def.name}' argument"
|
||||
f" '{arg_name}' is deprecated. {deprecation_reason}",
|
||||
node,
|
||||
)
|
||||
)
|
||||
|
||||
def enter_object_field(self, node: ObjectFieldNode, *_args: Any) -> None:
|
||||
context = self.context
|
||||
input_object_def = get_named_type(context.get_parent_input_type())
|
||||
if is_input_object_type(input_object_def):
|
||||
input_field_def = cast(GraphQLInputObjectType, input_object_def).fields.get(
|
||||
node.name.value
|
||||
)
|
||||
if input_field_def:
|
||||
deprecation_reason = input_field_def.deprecation_reason
|
||||
if deprecation_reason is not None:
|
||||
field_name = node.name.value
|
||||
input_object_name = input_object_def.name # type: ignore
|
||||
self.report_error(
|
||||
GraphQLError(
|
||||
f"The input field {input_object_name}.{field_name}"
|
||||
f" is deprecated. {deprecation_reason}",
|
||||
node,
|
||||
)
|
||||
)
|
||||
|
||||
def enter_enum_value(self, node: EnumValueNode, *_args: Any) -> None:
|
||||
context = self.context
|
||||
enum_value_def = context.get_enum_value()
|
||||
if enum_value_def:
|
||||
deprecation_reason = enum_value_def.deprecation_reason
|
||||
if deprecation_reason is not None: # pragma: no cover else
|
||||
enum_type_def = get_named_type(context.get_input_type())
|
||||
enum_type_name = enum_type_def.name # type: ignore
|
||||
self.report_error(
|
||||
GraphQLError(
|
||||
f"The enum value '{enum_type_name}.{node.value}'"
|
||||
f" is deprecated. {deprecation_reason}",
|
||||
node,
|
||||
)
|
||||
)
|
||||
+31
@@ -0,0 +1,31 @@
|
||||
from typing import Any
|
||||
|
||||
from ....error import GraphQLError
|
||||
from ....language import FieldNode
|
||||
from ....type import get_named_type, is_introspection_type
|
||||
from .. import ValidationRule
|
||||
|
||||
__all__ = ["NoSchemaIntrospectionCustomRule"]
|
||||
|
||||
|
||||
class NoSchemaIntrospectionCustomRule(ValidationRule):
|
||||
"""Prohibit introspection queries
|
||||
|
||||
A GraphQL document is only valid if all fields selected are not fields that
|
||||
return an introspection type.
|
||||
|
||||
Note: This rule is optional and is not part of the Validation section of the
|
||||
GraphQL Specification. This rule effectively disables introspection, which
|
||||
does not reflect best practices and should only be done if absolutely necessary.
|
||||
"""
|
||||
|
||||
def enter_field(self, node: FieldNode, *_args: Any) -> None:
|
||||
type_ = get_named_type(self.context.get_type())
|
||||
if type_ and is_introspection_type(type_):
|
||||
self.report_error(
|
||||
GraphQLError(
|
||||
"GraphQL introspection has been disabled, but the requested query"
|
||||
f" contained the field '{node.name.value}'.",
|
||||
node,
|
||||
)
|
||||
)
|
||||
+49
@@ -0,0 +1,49 @@
|
||||
from typing import Any, Union, cast
|
||||
|
||||
from ...error import GraphQLError
|
||||
from ...language import (
|
||||
DirectiveDefinitionNode,
|
||||
DocumentNode,
|
||||
ExecutableDefinitionNode,
|
||||
SchemaDefinitionNode,
|
||||
SchemaExtensionNode,
|
||||
TypeDefinitionNode,
|
||||
VisitorAction,
|
||||
SKIP,
|
||||
)
|
||||
from . import ASTValidationRule
|
||||
|
||||
__all__ = ["ExecutableDefinitionsRule"]
|
||||
|
||||
|
||||
class ExecutableDefinitionsRule(ASTValidationRule):
|
||||
"""Executable definitions
|
||||
|
||||
A GraphQL document is only valid for execution if all definitions are either
|
||||
operation or fragment definitions.
|
||||
|
||||
See https://spec.graphql.org/draft/#sec-Executable-Definitions
|
||||
"""
|
||||
|
||||
def enter_document(self, node: DocumentNode, *_args: Any) -> VisitorAction:
|
||||
for definition in node.definitions:
|
||||
if not isinstance(definition, ExecutableDefinitionNode):
|
||||
def_name = (
|
||||
"schema"
|
||||
if isinstance(
|
||||
definition, (SchemaDefinitionNode, SchemaExtensionNode)
|
||||
)
|
||||
else "'{}'".format(
|
||||
cast(
|
||||
Union[DirectiveDefinitionNode, TypeDefinitionNode],
|
||||
definition,
|
||||
).name.value
|
||||
)
|
||||
)
|
||||
self.report_error(
|
||||
GraphQLError(
|
||||
f"The {def_name} definition is not executable.",
|
||||
definition,
|
||||
)
|
||||
)
|
||||
return SKIP
|
||||
+136
@@ -0,0 +1,136 @@
|
||||
from collections import defaultdict
|
||||
from functools import cmp_to_key
|
||||
from typing import Any, Dict, List, Union, cast
|
||||
|
||||
from ...type import (
|
||||
GraphQLAbstractType,
|
||||
GraphQLInterfaceType,
|
||||
GraphQLObjectType,
|
||||
GraphQLOutputType,
|
||||
GraphQLSchema,
|
||||
is_abstract_type,
|
||||
is_interface_type,
|
||||
is_object_type,
|
||||
)
|
||||
from ...error import GraphQLError
|
||||
from ...language import FieldNode
|
||||
from ...pyutils import did_you_mean, natural_comparison_key, suggestion_list
|
||||
from . import ValidationRule
|
||||
|
||||
__all__ = ["FieldsOnCorrectTypeRule"]
|
||||
|
||||
|
||||
class FieldsOnCorrectTypeRule(ValidationRule):
|
||||
"""Fields on correct type
|
||||
|
||||
A GraphQL document is only valid if all fields selected are defined by the parent
|
||||
type, or are an allowed meta field such as ``__typename``.
|
||||
|
||||
See https://spec.graphql.org/draft/#sec-Field-Selections
|
||||
"""
|
||||
|
||||
def enter_field(self, node: FieldNode, *_args: Any) -> None:
|
||||
type_ = self.context.get_parent_type()
|
||||
if not type_:
|
||||
return
|
||||
field_def = self.context.get_field_def()
|
||||
if field_def:
|
||||
return
|
||||
# This field doesn't exist, lets look for suggestions.
|
||||
schema = self.context.schema
|
||||
field_name = node.name.value
|
||||
|
||||
# First determine if there are any suggested types to condition on.
|
||||
suggestion = did_you_mean(
|
||||
get_suggested_type_names(schema, type_, field_name),
|
||||
"to use an inline fragment on",
|
||||
)
|
||||
|
||||
# If there are no suggested types, then perhaps this was a typo?
|
||||
if not suggestion:
|
||||
suggestion = did_you_mean(get_suggested_field_names(type_, field_name))
|
||||
|
||||
# Report an error, including helpful suggestions.
|
||||
self.report_error(
|
||||
GraphQLError(
|
||||
f"Cannot query field '{field_name}' on type '{type_}'." + suggestion,
|
||||
node,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def get_suggested_type_names(
|
||||
schema: GraphQLSchema, type_: GraphQLOutputType, field_name: str
|
||||
) -> List[str]:
|
||||
"""
|
||||
Get a list of suggested type names.
|
||||
|
||||
Go through all of the implementations of type, as well as the interfaces
|
||||
that they implement. If any of those types include the provided field,
|
||||
suggest them, sorted by how often the type is referenced.
|
||||
"""
|
||||
if not is_abstract_type(type_):
|
||||
# Must be an Object type, which does not have possible fields.
|
||||
return []
|
||||
|
||||
type_ = cast(GraphQLAbstractType, type_)
|
||||
# Use a dict instead of a set for stable sorting when usage counts are the same
|
||||
suggested_types: Dict[Union[GraphQLObjectType, GraphQLInterfaceType], None] = {}
|
||||
usage_count: Dict[str, int] = defaultdict(int)
|
||||
for possible_type in schema.get_possible_types(type_):
|
||||
if field_name not in possible_type.fields:
|
||||
continue
|
||||
|
||||
# This object type defines this field.
|
||||
suggested_types[possible_type] = None
|
||||
usage_count[possible_type.name] = 1
|
||||
|
||||
for possible_interface in possible_type.interfaces:
|
||||
if field_name not in possible_interface.fields:
|
||||
continue
|
||||
|
||||
# This interface type defines this field.
|
||||
suggested_types[possible_interface] = None
|
||||
usage_count[possible_interface.name] += 1
|
||||
|
||||
def cmp(
|
||||
type_a: Union[GraphQLObjectType, GraphQLInterfaceType],
|
||||
type_b: Union[GraphQLObjectType, GraphQLInterfaceType],
|
||||
) -> int: # pragma: no cover
|
||||
# Suggest both interface and object types based on how common they are.
|
||||
usage_count_diff = usage_count[type_b.name] - usage_count[type_a.name]
|
||||
if usage_count_diff:
|
||||
return usage_count_diff
|
||||
|
||||
# Suggest super types first followed by subtypes
|
||||
if is_interface_type(type_a) and schema.is_sub_type(
|
||||
cast(GraphQLInterfaceType, type_a), type_b
|
||||
):
|
||||
return -1
|
||||
if is_interface_type(type_b) and schema.is_sub_type(
|
||||
cast(GraphQLInterfaceType, type_b), type_a
|
||||
):
|
||||
return 1
|
||||
|
||||
name_a = natural_comparison_key(type_a.name)
|
||||
name_b = natural_comparison_key(type_b.name)
|
||||
if name_a > name_b:
|
||||
return 1
|
||||
if name_a < name_b:
|
||||
return -1
|
||||
return 0
|
||||
|
||||
return [type_.name for type_ in sorted(suggested_types, key=cmp_to_key(cmp))]
|
||||
|
||||
|
||||
def get_suggested_field_names(type_: GraphQLOutputType, field_name: str) -> List[str]:
|
||||
"""Get a list of suggested field names.
|
||||
|
||||
For the field name provided, determine if there are any similar field names that may
|
||||
be the result of a typo.
|
||||
"""
|
||||
if is_object_type(type_) or is_interface_type(type_):
|
||||
possible_field_names = list(type_.fields) # type: ignore
|
||||
return suggestion_list(field_name, possible_field_names)
|
||||
# Otherwise, must be a Union type, which does not define fields.
|
||||
return []
|
||||
+53
@@ -0,0 +1,53 @@
|
||||
from typing import Any
|
||||
|
||||
from ...error import GraphQLError
|
||||
from ...language import (
|
||||
FragmentDefinitionNode,
|
||||
InlineFragmentNode,
|
||||
print_ast,
|
||||
)
|
||||
from ...type import is_composite_type
|
||||
from ...utilities import type_from_ast
|
||||
from . import ValidationRule
|
||||
|
||||
__all__ = ["FragmentsOnCompositeTypesRule"]
|
||||
|
||||
|
||||
class FragmentsOnCompositeTypesRule(ValidationRule):
|
||||
"""Fragments on composite type
|
||||
|
||||
Fragments use a type condition to determine if they apply, since fragments can only
|
||||
be spread into a composite type (object, interface, or union), the type condition
|
||||
must also be a composite type.
|
||||
|
||||
See https://spec.graphql.org/draft/#sec-Fragments-On-Composite-Types
|
||||
"""
|
||||
|
||||
def enter_inline_fragment(self, node: InlineFragmentNode, *_args: Any) -> None:
|
||||
type_condition = node.type_condition
|
||||
if type_condition:
|
||||
type_ = type_from_ast(self.context.schema, type_condition)
|
||||
if type_ and not is_composite_type(type_):
|
||||
type_str = print_ast(type_condition)
|
||||
self.report_error(
|
||||
GraphQLError(
|
||||
"Fragment cannot condition"
|
||||
f" on non composite type '{type_str}'.",
|
||||
type_condition,
|
||||
)
|
||||
)
|
||||
|
||||
def enter_fragment_definition(
|
||||
self, node: FragmentDefinitionNode, *_args: Any
|
||||
) -> None:
|
||||
type_condition = node.type_condition
|
||||
type_ = type_from_ast(self.context.schema, type_condition)
|
||||
if type_ and not is_composite_type(type_):
|
||||
type_str = print_ast(type_condition)
|
||||
self.report_error(
|
||||
GraphQLError(
|
||||
f"Fragment '{node.name.value}' cannot condition"
|
||||
f" on non composite type '{type_str}'.",
|
||||
type_condition,
|
||||
)
|
||||
)
|
||||
+98
@@ -0,0 +1,98 @@
|
||||
from typing import cast, Any, Dict, List, Union
|
||||
|
||||
from ...error import GraphQLError
|
||||
from ...language import (
|
||||
ArgumentNode,
|
||||
DirectiveDefinitionNode,
|
||||
DirectiveNode,
|
||||
SKIP,
|
||||
VisitorAction,
|
||||
)
|
||||
from ...pyutils import did_you_mean, suggestion_list
|
||||
from ...type import specified_directives
|
||||
from . import ASTValidationRule, SDLValidationContext, ValidationContext
|
||||
|
||||
__all__ = ["KnownArgumentNamesRule", "KnownArgumentNamesOnDirectivesRule"]
|
||||
|
||||
|
||||
class KnownArgumentNamesOnDirectivesRule(ASTValidationRule):
|
||||
"""Known argument names on directives
|
||||
|
||||
A GraphQL directive is only valid if all supplied arguments are defined.
|
||||
|
||||
For internal use only.
|
||||
"""
|
||||
|
||||
context: Union[ValidationContext, SDLValidationContext]
|
||||
|
||||
def __init__(self, context: Union[ValidationContext, SDLValidationContext]):
|
||||
super().__init__(context)
|
||||
directive_args: Dict[str, List[str]] = {}
|
||||
|
||||
schema = context.schema
|
||||
defined_directives = schema.directives if schema else specified_directives
|
||||
for directive in cast(List, defined_directives):
|
||||
directive_args[directive.name] = list(directive.args)
|
||||
|
||||
ast_definitions = context.document.definitions
|
||||
for def_ in ast_definitions:
|
||||
if isinstance(def_, DirectiveDefinitionNode):
|
||||
directive_args[def_.name.value] = [
|
||||
arg.name.value for arg in def_.arguments or []
|
||||
]
|
||||
|
||||
self.directive_args = directive_args
|
||||
|
||||
def enter_directive(
|
||||
self, directive_node: DirectiveNode, *_args: Any
|
||||
) -> VisitorAction:
|
||||
directive_name = directive_node.name.value
|
||||
known_args = self.directive_args.get(directive_name)
|
||||
if directive_node.arguments and known_args is not None:
|
||||
for arg_node in directive_node.arguments:
|
||||
arg_name = arg_node.name.value
|
||||
if arg_name not in known_args:
|
||||
suggestions = suggestion_list(arg_name, known_args)
|
||||
self.report_error(
|
||||
GraphQLError(
|
||||
f"Unknown argument '{arg_name}'"
|
||||
f" on directive '@{directive_name}'."
|
||||
+ did_you_mean(suggestions),
|
||||
arg_node,
|
||||
)
|
||||
)
|
||||
return SKIP
|
||||
|
||||
|
||||
class KnownArgumentNamesRule(KnownArgumentNamesOnDirectivesRule):
|
||||
"""Known argument names
|
||||
|
||||
A GraphQL field is only valid if all supplied arguments are defined by that field.
|
||||
|
||||
See https://spec.graphql.org/draft/#sec-Argument-Names
|
||||
See https://spec.graphql.org/draft/#sec-Directives-Are-In-Valid-Locations
|
||||
"""
|
||||
|
||||
context: ValidationContext
|
||||
|
||||
def __init__(self, context: ValidationContext):
|
||||
super().__init__(context)
|
||||
|
||||
def enter_argument(self, arg_node: ArgumentNode, *args: Any) -> None:
|
||||
context = self.context
|
||||
arg_def = context.get_argument()
|
||||
field_def = context.get_field_def()
|
||||
parent_type = context.get_parent_type()
|
||||
if not arg_def and field_def and parent_type:
|
||||
arg_name = arg_node.name.value
|
||||
field_name = args[3][-1].name.value
|
||||
known_args_names = list(field_def.args)
|
||||
suggestions = suggestion_list(arg_name, known_args_names)
|
||||
context.report_error(
|
||||
GraphQLError(
|
||||
f"Unknown argument '{arg_name}'"
|
||||
f" on field '{parent_type.name}.{field_name}'."
|
||||
+ did_you_mean(suggestions),
|
||||
arg_node,
|
||||
)
|
||||
)
|
||||
+119
@@ -0,0 +1,119 @@
|
||||
from typing import cast, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from ...error import GraphQLError
|
||||
from ...language import (
|
||||
DirectiveLocation,
|
||||
DirectiveDefinitionNode,
|
||||
DirectiveNode,
|
||||
Node,
|
||||
OperationDefinitionNode,
|
||||
)
|
||||
from ...type import specified_directives
|
||||
from . import ASTValidationRule, SDLValidationContext, ValidationContext
|
||||
|
||||
__all__ = ["KnownDirectivesRule"]
|
||||
|
||||
|
||||
class KnownDirectivesRule(ASTValidationRule):
|
||||
"""Known directives
|
||||
|
||||
A GraphQL document is only valid if all ``@directives`` are known by the schema and
|
||||
legally positioned.
|
||||
|
||||
See https://spec.graphql.org/draft/#sec-Directives-Are-Defined
|
||||
"""
|
||||
|
||||
context: Union[ValidationContext, SDLValidationContext]
|
||||
|
||||
def __init__(self, context: Union[ValidationContext, SDLValidationContext]):
|
||||
super().__init__(context)
|
||||
locations_map: Dict[str, Tuple[DirectiveLocation, ...]] = {}
|
||||
|
||||
schema = context.schema
|
||||
defined_directives = (
|
||||
schema.directives if schema else cast(List, specified_directives)
|
||||
)
|
||||
for directive in defined_directives:
|
||||
locations_map[directive.name] = directive.locations
|
||||
ast_definitions = context.document.definitions
|
||||
for def_ in ast_definitions:
|
||||
if isinstance(def_, DirectiveDefinitionNode):
|
||||
locations_map[def_.name.value] = tuple(
|
||||
DirectiveLocation[name.value] for name in def_.locations
|
||||
)
|
||||
self.locations_map = locations_map
|
||||
|
||||
def enter_directive(
|
||||
self,
|
||||
node: DirectiveNode,
|
||||
_key: Any,
|
||||
_parent: Any,
|
||||
_path: Any,
|
||||
ancestors: List[Node],
|
||||
) -> None:
|
||||
name = node.name.value
|
||||
locations = self.locations_map.get(name)
|
||||
if locations:
|
||||
candidate_location = get_directive_location_for_ast_path(ancestors)
|
||||
if candidate_location and candidate_location not in locations:
|
||||
self.report_error(
|
||||
GraphQLError(
|
||||
f"Directive '@{name}'"
|
||||
f" may not be used on {candidate_location.value}.",
|
||||
node,
|
||||
)
|
||||
)
|
||||
else:
|
||||
self.report_error(GraphQLError(f"Unknown directive '@{name}'.", node))
|
||||
|
||||
|
||||
_operation_location = {
|
||||
"query": DirectiveLocation.QUERY,
|
||||
"mutation": DirectiveLocation.MUTATION,
|
||||
"subscription": DirectiveLocation.SUBSCRIPTION,
|
||||
}
|
||||
|
||||
_directive_location = {
|
||||
"field": DirectiveLocation.FIELD,
|
||||
"fragment_spread": DirectiveLocation.FRAGMENT_SPREAD,
|
||||
"inline_fragment": DirectiveLocation.INLINE_FRAGMENT,
|
||||
"fragment_definition": DirectiveLocation.FRAGMENT_DEFINITION,
|
||||
"variable_definition": DirectiveLocation.VARIABLE_DEFINITION,
|
||||
"schema_definition": DirectiveLocation.SCHEMA,
|
||||
"schema_extension": DirectiveLocation.SCHEMA,
|
||||
"scalar_type_definition": DirectiveLocation.SCALAR,
|
||||
"scalar_type_extension": DirectiveLocation.SCALAR,
|
||||
"object_type_definition": DirectiveLocation.OBJECT,
|
||||
"object_type_extension": DirectiveLocation.OBJECT,
|
||||
"field_definition": DirectiveLocation.FIELD_DEFINITION,
|
||||
"interface_type_definition": DirectiveLocation.INTERFACE,
|
||||
"interface_type_extension": DirectiveLocation.INTERFACE,
|
||||
"union_type_definition": DirectiveLocation.UNION,
|
||||
"union_type_extension": DirectiveLocation.UNION,
|
||||
"enum_type_definition": DirectiveLocation.ENUM,
|
||||
"enum_type_extension": DirectiveLocation.ENUM,
|
||||
"enum_value_definition": DirectiveLocation.ENUM_VALUE,
|
||||
"input_object_type_definition": DirectiveLocation.INPUT_OBJECT,
|
||||
"input_object_type_extension": DirectiveLocation.INPUT_OBJECT,
|
||||
}
|
||||
|
||||
|
||||
def get_directive_location_for_ast_path(
|
||||
ancestors: List[Node],
|
||||
) -> Optional[DirectiveLocation]:
|
||||
applied_to = ancestors[-1]
|
||||
if not isinstance(applied_to, Node): # pragma: no cover
|
||||
raise TypeError("Unexpected error in directive.")
|
||||
kind = applied_to.kind
|
||||
if kind == "operation_definition":
|
||||
applied_to = cast(OperationDefinitionNode, applied_to)
|
||||
return _operation_location[applied_to.operation.value]
|
||||
elif kind == "input_value_definition":
|
||||
parent_node = ancestors[-3]
|
||||
return (
|
||||
DirectiveLocation.INPUT_FIELD_DEFINITION
|
||||
if parent_node.kind == "input_object_type_definition"
|
||||
else DirectiveLocation.ARGUMENT_DEFINITION
|
||||
)
|
||||
else:
|
||||
return _directive_location.get(kind)
|
||||
+25
@@ -0,0 +1,25 @@
|
||||
from typing import Any
|
||||
|
||||
from ...error import GraphQLError
|
||||
from ...language import FragmentSpreadNode
|
||||
from . import ValidationRule
|
||||
|
||||
__all__ = ["KnownFragmentNamesRule"]
|
||||
|
||||
|
||||
class KnownFragmentNamesRule(ValidationRule):
|
||||
"""Known fragment names
|
||||
|
||||
A GraphQL document is only valid if all ``...Fragment`` fragment spreads refer to
|
||||
fragments defined in the same document.
|
||||
|
||||
See https://spec.graphql.org/draft/#sec-Fragment-spread-target-defined
|
||||
"""
|
||||
|
||||
def enter_fragment_spread(self, node: FragmentSpreadNode, *_args: Any) -> None:
|
||||
fragment_name = node.name.value
|
||||
fragment = self.context.get_fragment(fragment_name)
|
||||
if not fragment:
|
||||
self.report_error(
|
||||
GraphQLError(f"Unknown fragment '{fragment_name}'.", node.name)
|
||||
)
|
||||
+90
@@ -0,0 +1,90 @@
|
||||
from typing import Any, Collection, List, Union, cast
|
||||
|
||||
from ...error import GraphQLError
|
||||
from ...language import (
|
||||
is_type_definition_node,
|
||||
is_type_system_definition_node,
|
||||
is_type_system_extension_node,
|
||||
Node,
|
||||
NamedTypeNode,
|
||||
TypeDefinitionNode,
|
||||
)
|
||||
from ...type import introspection_types, specified_scalar_types
|
||||
from ...pyutils import did_you_mean, suggestion_list
|
||||
from . import ASTValidationRule, ValidationContext, SDLValidationContext
|
||||
|
||||
__all__ = ["KnownTypeNamesRule"]
|
||||
|
||||
|
||||
class KnownTypeNamesRule(ASTValidationRule):
|
||||
"""Known type names
|
||||
|
||||
A GraphQL document is only valid if referenced types (specifically variable
|
||||
definitions and fragment conditions) are defined by the type schema.
|
||||
|
||||
See https://spec.graphql.org/draft/#sec-Fragment-Spread-Type-Existence
|
||||
"""
|
||||
|
||||
def __init__(self, context: Union[ValidationContext, SDLValidationContext]):
|
||||
super().__init__(context)
|
||||
schema = context.schema
|
||||
self.existing_types_map = schema.type_map if schema else {}
|
||||
|
||||
defined_types = []
|
||||
for def_ in context.document.definitions:
|
||||
if is_type_definition_node(def_):
|
||||
def_ = cast(TypeDefinitionNode, def_)
|
||||
defined_types.append(def_.name.value)
|
||||
self.defined_types = set(defined_types)
|
||||
|
||||
self.type_names = list(self.existing_types_map) + defined_types
|
||||
|
||||
def enter_named_type(
|
||||
self,
|
||||
node: NamedTypeNode,
|
||||
_key: Any,
|
||||
parent: Node,
|
||||
_path: Any,
|
||||
ancestors: List[Node],
|
||||
) -> None:
|
||||
type_name = node.name.value
|
||||
if (
|
||||
type_name not in self.existing_types_map
|
||||
and type_name not in self.defined_types
|
||||
):
|
||||
try:
|
||||
definition_node = ancestors[2]
|
||||
except IndexError:
|
||||
definition_node = parent
|
||||
is_sdl = is_sdl_node(definition_node)
|
||||
if is_sdl and type_name in standard_type_names:
|
||||
return
|
||||
|
||||
suggested_types = suggestion_list(
|
||||
type_name,
|
||||
(
|
||||
list(standard_type_names) + self.type_names
|
||||
if is_sdl
|
||||
else self.type_names
|
||||
),
|
||||
)
|
||||
self.report_error(
|
||||
GraphQLError(
|
||||
f"Unknown type '{type_name}'." + did_you_mean(suggested_types),
|
||||
node,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
standard_type_names = set(specified_scalar_types).union(introspection_types)
|
||||
|
||||
|
||||
def is_sdl_node(value: Union[Node, Collection[Node], None]) -> bool:
|
||||
return (
|
||||
value is not None
|
||||
and not isinstance(value, list)
|
||||
and (
|
||||
is_type_system_definition_node(cast(Node, value))
|
||||
or is_type_system_extension_node(cast(Node, value))
|
||||
)
|
||||
)
|
||||
+37
@@ -0,0 +1,37 @@
|
||||
from typing import Any
|
||||
|
||||
from ...error import GraphQLError
|
||||
from ...language import DocumentNode, OperationDefinitionNode
|
||||
from . import ASTValidationContext, ASTValidationRule
|
||||
|
||||
__all__ = ["LoneAnonymousOperationRule"]
|
||||
|
||||
|
||||
class LoneAnonymousOperationRule(ASTValidationRule):
|
||||
"""Lone anonymous operation
|
||||
|
||||
A GraphQL document is only valid if when it contains an anonymous operation
|
||||
(the query short-hand) that it contains only that one operation definition.
|
||||
|
||||
See https://spec.graphql.org/draft/#sec-Lone-Anonymous-Operation
|
||||
"""
|
||||
|
||||
def __init__(self, context: ASTValidationContext):
|
||||
super().__init__(context)
|
||||
self.operation_count = 0
|
||||
|
||||
def enter_document(self, node: DocumentNode, *_args: Any) -> None:
|
||||
self.operation_count = sum(
|
||||
isinstance(definition, OperationDefinitionNode)
|
||||
for definition in node.definitions
|
||||
)
|
||||
|
||||
def enter_operation_definition(
|
||||
self, node: OperationDefinitionNode, *_args: Any
|
||||
) -> None:
|
||||
if not node.name and self.operation_count > 1:
|
||||
self.report_error(
|
||||
GraphQLError(
|
||||
"This anonymous operation must be the only defined operation.", node
|
||||
)
|
||||
)
|
||||
+39
@@ -0,0 +1,39 @@
|
||||
from typing import Any
|
||||
|
||||
from ...error import GraphQLError
|
||||
from ...language import SchemaDefinitionNode
|
||||
from . import SDLValidationRule, SDLValidationContext
|
||||
|
||||
__all__ = ["LoneSchemaDefinitionRule"]
|
||||
|
||||
|
||||
class LoneSchemaDefinitionRule(SDLValidationRule):
|
||||
"""Lone Schema definition
|
||||
|
||||
A GraphQL document is only valid if it contains only one schema definition.
|
||||
"""
|
||||
|
||||
def __init__(self, context: SDLValidationContext):
|
||||
super().__init__(context)
|
||||
old_schema = context.schema
|
||||
self.already_defined = old_schema and (
|
||||
old_schema.ast_node
|
||||
or old_schema.query_type
|
||||
or old_schema.mutation_type
|
||||
or old_schema.subscription_type
|
||||
)
|
||||
self.schema_definitions_count = 0
|
||||
|
||||
def enter_schema_definition(self, node: SchemaDefinitionNode, *_args: Any) -> None:
|
||||
if self.already_defined:
|
||||
self.report_error(
|
||||
GraphQLError(
|
||||
"Cannot define a new schema within a schema extension.", node
|
||||
)
|
||||
)
|
||||
else:
|
||||
if self.schema_definitions_count:
|
||||
self.report_error(
|
||||
GraphQLError("Must provide only one schema definition.", node)
|
||||
)
|
||||
self.schema_definitions_count += 1
|
||||
+81
@@ -0,0 +1,81 @@
|
||||
from typing import Any, Dict, List, Set
|
||||
|
||||
from ...error import GraphQLError
|
||||
from ...language import FragmentDefinitionNode, FragmentSpreadNode, VisitorAction, SKIP
|
||||
from . import ASTValidationContext, ASTValidationRule
|
||||
|
||||
__all__ = ["NoFragmentCyclesRule"]
|
||||
|
||||
|
||||
class NoFragmentCyclesRule(ASTValidationRule):
|
||||
"""No fragment cycles
|
||||
|
||||
The graph of fragment spreads must not form any cycles including spreading itself.
|
||||
Otherwise an operation could infinitely spread or infinitely execute on cycles in
|
||||
the underlying data.
|
||||
|
||||
See https://spec.graphql.org/draft/#sec-Fragment-spreads-must-not-form-cycles
|
||||
"""
|
||||
|
||||
def __init__(self, context: ASTValidationContext):
|
||||
super().__init__(context)
|
||||
# Tracks already visited fragments to maintain O(N) and to ensure that
|
||||
# cycles are not redundantly reported.
|
||||
self.visited_frags: Set[str] = set()
|
||||
# List of AST nodes used to produce meaningful errors
|
||||
self.spread_path: List[FragmentSpreadNode] = []
|
||||
# Position in the spread path
|
||||
self.spread_path_index_by_name: Dict[str, int] = {}
|
||||
|
||||
@staticmethod
|
||||
def enter_operation_definition(*_args: Any) -> VisitorAction:
|
||||
return SKIP
|
||||
|
||||
def enter_fragment_definition(
|
||||
self, node: FragmentDefinitionNode, *_args: Any
|
||||
) -> VisitorAction:
|
||||
self.detect_cycle_recursive(node)
|
||||
return SKIP
|
||||
|
||||
def detect_cycle_recursive(self, fragment: FragmentDefinitionNode) -> None:
|
||||
# This does a straight-forward DFS to find cycles.
|
||||
# It does not terminate when a cycle was found but continues to explore
|
||||
# the graph to find all possible cycles.
|
||||
if fragment.name.value in self.visited_frags:
|
||||
return
|
||||
|
||||
fragment_name = fragment.name.value
|
||||
visited_frags = self.visited_frags
|
||||
visited_frags.add(fragment_name)
|
||||
|
||||
spread_nodes = self.context.get_fragment_spreads(fragment.selection_set)
|
||||
if not spread_nodes:
|
||||
return
|
||||
|
||||
spread_path = self.spread_path
|
||||
spread_path_index = self.spread_path_index_by_name
|
||||
spread_path_index[fragment_name] = len(spread_path)
|
||||
get_fragment = self.context.get_fragment
|
||||
|
||||
for spread_node in spread_nodes:
|
||||
spread_name = spread_node.name.value
|
||||
cycle_index = spread_path_index.get(spread_name)
|
||||
|
||||
spread_path.append(spread_node)
|
||||
if cycle_index is None:
|
||||
spread_fragment = get_fragment(spread_name)
|
||||
if spread_fragment:
|
||||
self.detect_cycle_recursive(spread_fragment)
|
||||
else:
|
||||
cycle_path = spread_path[cycle_index:]
|
||||
via_path = ", ".join("'" + s.name.value + "'" for s in cycle_path[:-1])
|
||||
self.report_error(
|
||||
GraphQLError(
|
||||
f"Cannot spread fragment '{spread_name}' within itself"
|
||||
+ (f" via {via_path}." if via_path else "."),
|
||||
cycle_path,
|
||||
)
|
||||
)
|
||||
spread_path.pop()
|
||||
|
||||
del spread_path_index[fragment_name]
|
||||
+50
@@ -0,0 +1,50 @@
|
||||
from typing import Any, Set
|
||||
|
||||
from ...error import GraphQLError
|
||||
from ...language import OperationDefinitionNode, VariableDefinitionNode
|
||||
from . import ValidationContext, ValidationRule
|
||||
|
||||
__all__ = ["NoUndefinedVariablesRule"]
|
||||
|
||||
|
||||
class NoUndefinedVariablesRule(ValidationRule):
|
||||
"""No undefined variables
|
||||
|
||||
A GraphQL operation is only valid if all variables encountered, both directly and
|
||||
via fragment spreads, are defined by that operation.
|
||||
|
||||
See https://spec.graphql.org/draft/#sec-All-Variable-Uses-Defined
|
||||
"""
|
||||
|
||||
def __init__(self, context: ValidationContext):
|
||||
super().__init__(context)
|
||||
self.defined_variable_names: Set[str] = set()
|
||||
|
||||
def enter_operation_definition(self, *_args: Any) -> None:
|
||||
self.defined_variable_names.clear()
|
||||
|
||||
def leave_operation_definition(
|
||||
self, operation: OperationDefinitionNode, *_args: Any
|
||||
) -> None:
|
||||
usages = self.context.get_recursive_variable_usages(operation)
|
||||
defined_variables = self.defined_variable_names
|
||||
for usage in usages:
|
||||
node = usage.node
|
||||
var_name = node.name.value
|
||||
if var_name not in defined_variables:
|
||||
self.report_error(
|
||||
GraphQLError(
|
||||
(
|
||||
f"Variable '${var_name}' is not defined"
|
||||
f" by operation '{operation.name.value}'."
|
||||
if operation.name
|
||||
else f"Variable '${var_name}' is not defined."
|
||||
),
|
||||
[node, operation],
|
||||
)
|
||||
)
|
||||
|
||||
def enter_variable_definition(
|
||||
self, node: VariableDefinitionNode, *_args: Any
|
||||
) -> None:
|
||||
self.defined_variable_names.add(node.variable.name.value)
|
||||
+53
@@ -0,0 +1,53 @@
|
||||
from typing import Any, List
|
||||
|
||||
from ...error import GraphQLError
|
||||
from ...language import (
|
||||
FragmentDefinitionNode,
|
||||
OperationDefinitionNode,
|
||||
VisitorAction,
|
||||
SKIP,
|
||||
)
|
||||
from . import ASTValidationContext, ASTValidationRule
|
||||
|
||||
__all__ = ["NoUnusedFragmentsRule"]
|
||||
|
||||
|
||||
class NoUnusedFragmentsRule(ASTValidationRule):
|
||||
"""No unused fragments
|
||||
|
||||
A GraphQL document is only valid if all fragment definitions are spread within
|
||||
operations, or spread within other fragments spread within operations.
|
||||
|
||||
See https://spec.graphql.org/draft/#sec-Fragments-Must-Be-Used
|
||||
"""
|
||||
|
||||
def __init__(self, context: ASTValidationContext):
|
||||
super().__init__(context)
|
||||
self.operation_defs: List[OperationDefinitionNode] = []
|
||||
self.fragment_defs: List[FragmentDefinitionNode] = []
|
||||
|
||||
def enter_operation_definition(
|
||||
self, node: OperationDefinitionNode, *_args: Any
|
||||
) -> VisitorAction:
|
||||
self.operation_defs.append(node)
|
||||
return SKIP
|
||||
|
||||
def enter_fragment_definition(
|
||||
self, node: FragmentDefinitionNode, *_args: Any
|
||||
) -> VisitorAction:
|
||||
self.fragment_defs.append(node)
|
||||
return SKIP
|
||||
|
||||
def leave_document(self, *_args: Any) -> None:
|
||||
fragment_names_used = set()
|
||||
get_fragments = self.context.get_recursively_referenced_fragments
|
||||
for operation in self.operation_defs:
|
||||
for fragment in get_fragments(operation):
|
||||
fragment_names_used.add(fragment.name.value)
|
||||
|
||||
for fragment_def in self.fragment_defs:
|
||||
frag_name = fragment_def.name.value
|
||||
if frag_name not in fragment_names_used:
|
||||
self.report_error(
|
||||
GraphQLError(f"Fragment '{frag_name}' is never used.", fragment_def)
|
||||
)
|
||||
+53
@@ -0,0 +1,53 @@
|
||||
from typing import Any, List, Set
|
||||
|
||||
from ...error import GraphQLError
|
||||
from ...language import OperationDefinitionNode, VariableDefinitionNode
|
||||
from . import ValidationContext, ValidationRule
|
||||
|
||||
__all__ = ["NoUnusedVariablesRule"]
|
||||
|
||||
|
||||
class NoUnusedVariablesRule(ValidationRule):
|
||||
"""No unused variables
|
||||
|
||||
A GraphQL operation is only valid if all variables defined by an operation are used,
|
||||
either directly or within a spread fragment.
|
||||
|
||||
See https://spec.graphql.org/draft/#sec-All-Variables-Used
|
||||
"""
|
||||
|
||||
def __init__(self, context: ValidationContext):
|
||||
super().__init__(context)
|
||||
self.variable_defs: List[VariableDefinitionNode] = []
|
||||
|
||||
def enter_operation_definition(self, *_args: Any) -> None:
|
||||
self.variable_defs.clear()
|
||||
|
||||
def leave_operation_definition(
|
||||
self, operation: OperationDefinitionNode, *_args: Any
|
||||
) -> None:
|
||||
variable_name_used: Set[str] = set()
|
||||
usages = self.context.get_recursive_variable_usages(operation)
|
||||
|
||||
for usage in usages:
|
||||
variable_name_used.add(usage.node.name.value)
|
||||
|
||||
for variable_def in self.variable_defs:
|
||||
variable_name = variable_def.variable.name.value
|
||||
if variable_name not in variable_name_used:
|
||||
self.report_error(
|
||||
GraphQLError(
|
||||
(
|
||||
f"Variable '${variable_name}' is never used"
|
||||
f" in operation '{operation.name.value}'."
|
||||
if operation.name
|
||||
else f"Variable '${variable_name}' is never used."
|
||||
),
|
||||
variable_def,
|
||||
)
|
||||
)
|
||||
|
||||
def enter_variable_definition(
|
||||
self, definition: VariableDefinitionNode, *_args: Any
|
||||
) -> None:
|
||||
self.variable_defs.append(definition)
|
||||
+783
@@ -0,0 +1,783 @@
|
||||
from itertools import chain
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union, cast
|
||||
|
||||
from ...error import GraphQLError
|
||||
from ...language import (
|
||||
DirectiveNode,
|
||||
FieldNode,
|
||||
FragmentDefinitionNode,
|
||||
FragmentSpreadNode,
|
||||
InlineFragmentNode,
|
||||
SelectionSetNode,
|
||||
ValueNode,
|
||||
print_ast,
|
||||
)
|
||||
from ...type import (
|
||||
GraphQLCompositeType,
|
||||
GraphQLField,
|
||||
GraphQLList,
|
||||
GraphQLNamedType,
|
||||
GraphQLNonNull,
|
||||
GraphQLOutputType,
|
||||
get_named_type,
|
||||
is_interface_type,
|
||||
is_leaf_type,
|
||||
is_list_type,
|
||||
is_non_null_type,
|
||||
is_object_type,
|
||||
)
|
||||
from ...utilities import type_from_ast
|
||||
from ...utilities.sort_value_node import sort_value_node
|
||||
from . import ValidationContext, ValidationRule
|
||||
|
||||
MYPY = False
|
||||
|
||||
__all__ = ["OverlappingFieldsCanBeMergedRule"]
|
||||
|
||||
|
||||
def reason_message(reason: "ConflictReasonMessage") -> str:
|
||||
if isinstance(reason, list):
|
||||
return " and ".join(
|
||||
f"subfields '{response_name}' conflict"
|
||||
f" because {reason_message(sub_reason)}"
|
||||
for response_name, sub_reason in reason
|
||||
)
|
||||
return reason
|
||||
|
||||
|
||||
class OverlappingFieldsCanBeMergedRule(ValidationRule):
|
||||
"""Overlapping fields can be merged
|
||||
|
||||
A selection set is only valid if all fields (including spreading any fragments)
|
||||
either correspond to distinct response names or can be merged without ambiguity.
|
||||
|
||||
See https://spec.graphql.org/draft/#sec-Field-Selection-Merging
|
||||
"""
|
||||
|
||||
def __init__(self, context: ValidationContext):
|
||||
super().__init__(context)
|
||||
# A memoization for when two fragments are compared "between" each other for
|
||||
# conflicts. Two fragments may be compared many times, so memoizing this can
|
||||
# dramatically improve the performance of this validator.
|
||||
self.compared_fragment_pairs = PairSet()
|
||||
|
||||
# A cache for the "field map" and list of fragment names found in any given
|
||||
# selection set. Selection sets may be asked for this information multiple
|
||||
# times, so this improves the performance of this validator.
|
||||
self.cached_fields_and_fragment_names: Dict = {}
|
||||
|
||||
def enter_selection_set(self, selection_set: SelectionSetNode, *_args: Any) -> None:
|
||||
conflicts = find_conflicts_within_selection_set(
|
||||
self.context,
|
||||
self.cached_fields_and_fragment_names,
|
||||
self.compared_fragment_pairs,
|
||||
self.context.get_parent_type(),
|
||||
selection_set,
|
||||
)
|
||||
for (reason_name, reason), fields1, fields2 in conflicts:
|
||||
reason_msg = reason_message(reason)
|
||||
self.report_error(
|
||||
GraphQLError(
|
||||
f"Fields '{reason_name}' conflict because {reason_msg}."
|
||||
" Use different aliases on the fields to fetch both"
|
||||
" if this was intentional.",
|
||||
fields1 + fields2,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
Conflict = Tuple["ConflictReason", List[FieldNode], List[FieldNode]]
|
||||
# Field name and reason.
|
||||
ConflictReason = Tuple[str, "ConflictReasonMessage"]
|
||||
# Reason is a string, or a nested list of conflicts.
|
||||
if MYPY: # recursive types not fully supported yet (/python/mypy/issues/731)
|
||||
ConflictReasonMessage = Union[str, List]
|
||||
else:
|
||||
ConflictReasonMessage = Union[str, List[ConflictReason]]
|
||||
# Tuple defining a field node in a context.
|
||||
NodeAndDef = Tuple[GraphQLCompositeType, FieldNode, Optional[GraphQLField]]
|
||||
# Dictionary of lists of those.
|
||||
NodeAndDefCollection = Dict[str, List[NodeAndDef]]
|
||||
|
||||
|
||||
# Algorithm:
|
||||
#
|
||||
# Conflicts occur when two fields exist in a query which will produce the same
|
||||
# response name, but represent differing values, thus creating a conflict.
|
||||
# The algorithm below finds all conflicts via making a series of comparisons
|
||||
# between fields. In order to compare as few fields as possible, this makes
|
||||
# a series of comparisons "within" sets of fields and "between" sets of fields.
|
||||
#
|
||||
# Given any selection set, a collection produces both a set of fields by
|
||||
# also including all inline fragments, as well as a list of fragments
|
||||
# referenced by fragment spreads.
|
||||
#
|
||||
# A) Each selection set represented in the document first compares "within" its
|
||||
# collected set of fields, finding any conflicts between every pair of
|
||||
# overlapping fields.
|
||||
# Note: This is the#only time* that a the fields "within" a set are compared
|
||||
# to each other. After this only fields "between" sets are compared.
|
||||
#
|
||||
# B) Also, if any fragment is referenced in a selection set, then a
|
||||
# comparison is made "between" the original set of fields and the
|
||||
# referenced fragment.
|
||||
#
|
||||
# C) Also, if multiple fragments are referenced, then comparisons
|
||||
# are made "between" each referenced fragment.
|
||||
#
|
||||
# D) When comparing "between" a set of fields and a referenced fragment, first
|
||||
# a comparison is made between each field in the original set of fields and
|
||||
# each field in the the referenced set of fields.
|
||||
#
|
||||
# E) Also, if any fragment is referenced in the referenced selection set,
|
||||
# then a comparison is made "between" the original set of fields and the
|
||||
# referenced fragment (recursively referring to step D).
|
||||
#
|
||||
# F) When comparing "between" two fragments, first a comparison is made between
|
||||
# each field in the first referenced set of fields and each field in the the
|
||||
# second referenced set of fields.
|
||||
#
|
||||
# G) Also, any fragments referenced by the first must be compared to the
|
||||
# second, and any fragments referenced by the second must be compared to the
|
||||
# first (recursively referring to step F).
|
||||
#
|
||||
# H) When comparing two fields, if both have selection sets, then a comparison
|
||||
# is made "between" both selection sets, first comparing the set of fields in
|
||||
# the first selection set with the set of fields in the second.
|
||||
#
|
||||
# I) Also, if any fragment is referenced in either selection set, then a
|
||||
# comparison is made "between" the other set of fields and the
|
||||
# referenced fragment.
|
||||
#
|
||||
# J) Also, if two fragments are referenced in both selection sets, then a
|
||||
# comparison is made "between" the two fragments.
|
||||
|
||||
|
||||
def find_conflicts_within_selection_set(
|
||||
context: ValidationContext,
|
||||
cached_fields_and_fragment_names: Dict,
|
||||
compared_fragment_pairs: "PairSet",
|
||||
parent_type: Optional[GraphQLNamedType],
|
||||
selection_set: SelectionSetNode,
|
||||
) -> List[Conflict]:
|
||||
"""Find conflicts within selection set.
|
||||
|
||||
Find all conflicts found "within" a selection set, including those found via
|
||||
spreading in fragments.
|
||||
|
||||
Called when visiting each SelectionSet in the GraphQL Document.
|
||||
"""
|
||||
conflicts: List[Conflict] = []
|
||||
|
||||
field_map, fragment_names = get_fields_and_fragment_names(
|
||||
context, cached_fields_and_fragment_names, parent_type, selection_set
|
||||
)
|
||||
|
||||
# (A) Find all conflicts "within" the fields of this selection set.
|
||||
# Note: this is the *only place* `collect_conflicts_within` is called.
|
||||
collect_conflicts_within(
|
||||
context,
|
||||
conflicts,
|
||||
cached_fields_and_fragment_names,
|
||||
compared_fragment_pairs,
|
||||
field_map,
|
||||
)
|
||||
|
||||
if fragment_names:
|
||||
# (B) Then collect conflicts between these fields and those represented by each
|
||||
# spread fragment name found.
|
||||
for i, fragment_name in enumerate(fragment_names):
|
||||
collect_conflicts_between_fields_and_fragment(
|
||||
context,
|
||||
conflicts,
|
||||
cached_fields_and_fragment_names,
|
||||
compared_fragment_pairs,
|
||||
False,
|
||||
field_map,
|
||||
fragment_name,
|
||||
)
|
||||
# (C) Then compare this fragment with all other fragments found in this
|
||||
# selection set to collect conflicts within fragments spread together.
|
||||
# This compares each item in the list of fragment names to every other
|
||||
# item in that same list (except for itself).
|
||||
for other_fragment_name in fragment_names[i + 1 :]:
|
||||
collect_conflicts_between_fragments(
|
||||
context,
|
||||
conflicts,
|
||||
cached_fields_and_fragment_names,
|
||||
compared_fragment_pairs,
|
||||
False,
|
||||
fragment_name,
|
||||
other_fragment_name,
|
||||
)
|
||||
|
||||
return conflicts
|
||||
|
||||
|
||||
def collect_conflicts_between_fields_and_fragment(
|
||||
context: ValidationContext,
|
||||
conflicts: List[Conflict],
|
||||
cached_fields_and_fragment_names: Dict,
|
||||
compared_fragment_pairs: "PairSet",
|
||||
are_mutually_exclusive: bool,
|
||||
field_map: NodeAndDefCollection,
|
||||
fragment_name: str,
|
||||
) -> None:
|
||||
"""Collect conflicts between fields and fragment.
|
||||
|
||||
Collect all conflicts found between a set of fields and a fragment reference
|
||||
including via spreading in any nested fragments.
|
||||
"""
|
||||
fragment = context.get_fragment(fragment_name)
|
||||
if not fragment:
|
||||
return None
|
||||
|
||||
field_map2, referenced_fragment_names = get_referenced_fields_and_fragment_names(
|
||||
context, cached_fields_and_fragment_names, fragment
|
||||
)
|
||||
|
||||
# Do not compare a fragment's fieldMap to itself.
|
||||
if field_map is field_map2:
|
||||
return
|
||||
|
||||
# (D) First collect any conflicts between the provided collection of fields and the
|
||||
# collection of fields represented by the given fragment.
|
||||
collect_conflicts_between(
|
||||
context,
|
||||
conflicts,
|
||||
cached_fields_and_fragment_names,
|
||||
compared_fragment_pairs,
|
||||
are_mutually_exclusive,
|
||||
field_map,
|
||||
field_map2,
|
||||
)
|
||||
|
||||
# (E) Then collect any conflicts between the provided collection of fields and any
|
||||
# fragment names found in the given fragment.
|
||||
for referenced_fragment_name in referenced_fragment_names:
|
||||
# Memoize so two fragments are not compared for conflicts more than once.
|
||||
if compared_fragment_pairs.has(
|
||||
referenced_fragment_name, fragment_name, are_mutually_exclusive
|
||||
):
|
||||
continue
|
||||
compared_fragment_pairs.add(
|
||||
referenced_fragment_name, fragment_name, are_mutually_exclusive
|
||||
)
|
||||
|
||||
collect_conflicts_between_fields_and_fragment(
|
||||
context,
|
||||
conflicts,
|
||||
cached_fields_and_fragment_names,
|
||||
compared_fragment_pairs,
|
||||
are_mutually_exclusive,
|
||||
field_map,
|
||||
referenced_fragment_name,
|
||||
)
|
||||
|
||||
|
||||
def collect_conflicts_between_fragments(
|
||||
context: ValidationContext,
|
||||
conflicts: List[Conflict],
|
||||
cached_fields_and_fragment_names: Dict,
|
||||
compared_fragment_pairs: "PairSet",
|
||||
are_mutually_exclusive: bool,
|
||||
fragment_name1: str,
|
||||
fragment_name2: str,
|
||||
) -> None:
|
||||
"""Collect conflicts between fragments.
|
||||
|
||||
Collect all conflicts found between two fragments, including via spreading in any
|
||||
nested fragments.
|
||||
"""
|
||||
# No need to compare a fragment to itself.
|
||||
if fragment_name1 == fragment_name2:
|
||||
return
|
||||
|
||||
# Memoize so two fragments are not compared for conflicts more than once.
|
||||
if compared_fragment_pairs.has(
|
||||
fragment_name1, fragment_name2, are_mutually_exclusive
|
||||
):
|
||||
return
|
||||
compared_fragment_pairs.add(fragment_name1, fragment_name2, are_mutually_exclusive)
|
||||
|
||||
fragment1 = context.get_fragment(fragment_name1)
|
||||
fragment2 = context.get_fragment(fragment_name2)
|
||||
if not fragment1 or not fragment2:
|
||||
return None
|
||||
|
||||
field_map1, referenced_fragment_names1 = get_referenced_fields_and_fragment_names(
|
||||
context, cached_fields_and_fragment_names, fragment1
|
||||
)
|
||||
|
||||
field_map2, referenced_fragment_names2 = get_referenced_fields_and_fragment_names(
|
||||
context, cached_fields_and_fragment_names, fragment2
|
||||
)
|
||||
|
||||
# (F) First, collect all conflicts between these two collections of fields
|
||||
# (not including any nested fragments)
|
||||
collect_conflicts_between(
|
||||
context,
|
||||
conflicts,
|
||||
cached_fields_and_fragment_names,
|
||||
compared_fragment_pairs,
|
||||
are_mutually_exclusive,
|
||||
field_map1,
|
||||
field_map2,
|
||||
)
|
||||
|
||||
# (G) Then collect conflicts between the first fragment and any nested fragments
|
||||
# spread in the second fragment.
|
||||
for referenced_fragment_name2 in referenced_fragment_names2:
|
||||
collect_conflicts_between_fragments(
|
||||
context,
|
||||
conflicts,
|
||||
cached_fields_and_fragment_names,
|
||||
compared_fragment_pairs,
|
||||
are_mutually_exclusive,
|
||||
fragment_name1,
|
||||
referenced_fragment_name2,
|
||||
)
|
||||
|
||||
# (G) Then collect conflicts between the second fragment and any nested fragments
|
||||
# spread in the first fragment.
|
||||
for referenced_fragment_name1 in referenced_fragment_names1:
|
||||
collect_conflicts_between_fragments(
|
||||
context,
|
||||
conflicts,
|
||||
cached_fields_and_fragment_names,
|
||||
compared_fragment_pairs,
|
||||
are_mutually_exclusive,
|
||||
referenced_fragment_name1,
|
||||
fragment_name2,
|
||||
)
|
||||
|
||||
|
||||
def find_conflicts_between_sub_selection_sets(
|
||||
context: ValidationContext,
|
||||
cached_fields_and_fragment_names: Dict,
|
||||
compared_fragment_pairs: "PairSet",
|
||||
are_mutually_exclusive: bool,
|
||||
parent_type1: Optional[GraphQLNamedType],
|
||||
selection_set1: SelectionSetNode,
|
||||
parent_type2: Optional[GraphQLNamedType],
|
||||
selection_set2: SelectionSetNode,
|
||||
) -> List[Conflict]:
|
||||
"""Find conflicts between sub selection sets.
|
||||
|
||||
Find all conflicts found between two selection sets, including those found via
|
||||
spreading in fragments. Called when determining if conflicts exist between the
|
||||
sub-fields of two overlapping fields.
|
||||
"""
|
||||
conflicts: List[Conflict] = []
|
||||
|
||||
field_map1, fragment_names1 = get_fields_and_fragment_names(
|
||||
context, cached_fields_and_fragment_names, parent_type1, selection_set1
|
||||
)
|
||||
field_map2, fragment_names2 = get_fields_and_fragment_names(
|
||||
context, cached_fields_and_fragment_names, parent_type2, selection_set2
|
||||
)
|
||||
|
||||
# (H) First, collect all conflicts between these two collections of field.
|
||||
collect_conflicts_between(
|
||||
context,
|
||||
conflicts,
|
||||
cached_fields_and_fragment_names,
|
||||
compared_fragment_pairs,
|
||||
are_mutually_exclusive,
|
||||
field_map1,
|
||||
field_map2,
|
||||
)
|
||||
|
||||
# (I) Then collect conflicts between the first collection of fields and those
|
||||
# referenced by each fragment name associated with the second.
|
||||
if fragment_names2:
|
||||
for fragment_name2 in fragment_names2:
|
||||
collect_conflicts_between_fields_and_fragment(
|
||||
context,
|
||||
conflicts,
|
||||
cached_fields_and_fragment_names,
|
||||
compared_fragment_pairs,
|
||||
are_mutually_exclusive,
|
||||
field_map1,
|
||||
fragment_name2,
|
||||
)
|
||||
|
||||
# (I) Then collect conflicts between the second collection of fields and those
|
||||
# referenced by each fragment name associated with the first.
|
||||
if fragment_names1:
|
||||
for fragment_name1 in fragment_names1:
|
||||
collect_conflicts_between_fields_and_fragment(
|
||||
context,
|
||||
conflicts,
|
||||
cached_fields_and_fragment_names,
|
||||
compared_fragment_pairs,
|
||||
are_mutually_exclusive,
|
||||
field_map2,
|
||||
fragment_name1,
|
||||
)
|
||||
|
||||
# (J) Also collect conflicts between any fragment names by the first and fragment
|
||||
# names by the second. This compares each item in the first set of names to each
|
||||
# item in the second set of names.
|
||||
for fragment_name1 in fragment_names1:
|
||||
for fragment_name2 in fragment_names2:
|
||||
collect_conflicts_between_fragments(
|
||||
context,
|
||||
conflicts,
|
||||
cached_fields_and_fragment_names,
|
||||
compared_fragment_pairs,
|
||||
are_mutually_exclusive,
|
||||
fragment_name1,
|
||||
fragment_name2,
|
||||
)
|
||||
|
||||
return conflicts
|
||||
|
||||
|
||||
def collect_conflicts_within(
|
||||
context: ValidationContext,
|
||||
conflicts: List[Conflict],
|
||||
cached_fields_and_fragment_names: Dict,
|
||||
compared_fragment_pairs: "PairSet",
|
||||
field_map: NodeAndDefCollection,
|
||||
) -> None:
|
||||
"""Collect all Conflicts "within" one collection of fields."""
|
||||
# A field map is a keyed collection, where each key represents a response name and
|
||||
# the value at that key is a list of all fields which provide that response name.
|
||||
# For every response name, if there are multiple fields, they must be compared to
|
||||
# find a potential conflict.
|
||||
for response_name, fields in field_map.items():
|
||||
# This compares every field in the list to every other field in this list
|
||||
# (except to itself). If the list only has one item, nothing needs to be
|
||||
# compared.
|
||||
if len(fields) > 1:
|
||||
for i, field in enumerate(fields):
|
||||
for other_field in fields[i + 1 :]:
|
||||
conflict = find_conflict(
|
||||
context,
|
||||
cached_fields_and_fragment_names,
|
||||
compared_fragment_pairs,
|
||||
# within one collection is never mutually exclusive
|
||||
False,
|
||||
response_name,
|
||||
field,
|
||||
other_field,
|
||||
)
|
||||
if conflict:
|
||||
conflicts.append(conflict)
|
||||
|
||||
|
||||
def collect_conflicts_between(
|
||||
context: ValidationContext,
|
||||
conflicts: List[Conflict],
|
||||
cached_fields_and_fragment_names: Dict,
|
||||
compared_fragment_pairs: "PairSet",
|
||||
parent_fields_are_mutually_exclusive: bool,
|
||||
field_map1: NodeAndDefCollection,
|
||||
field_map2: NodeAndDefCollection,
|
||||
) -> None:
|
||||
"""Collect all Conflicts between two collections of fields.
|
||||
|
||||
This is similar to, but different from the :func:`~.collect_conflicts_within`
|
||||
function above. This check assumes that :func:`~.collect_conflicts_within` has
|
||||
already been called on each provided collection of fields. This is true because
|
||||
this validator traverses each individual selection set.
|
||||
"""
|
||||
# A field map is a keyed collection, where each key represents a response name and
|
||||
# the value at that key is a list of all fields which provide that response name.
|
||||
# For any response name which appears in both provided field maps, each field from
|
||||
# the first field map must be compared to every field in the second field map to
|
||||
# find potential conflicts.
|
||||
for response_name, fields1 in field_map1.items():
|
||||
fields2 = field_map2.get(response_name)
|
||||
if fields2:
|
||||
for field1 in fields1:
|
||||
for field2 in fields2:
|
||||
conflict = find_conflict(
|
||||
context,
|
||||
cached_fields_and_fragment_names,
|
||||
compared_fragment_pairs,
|
||||
parent_fields_are_mutually_exclusive,
|
||||
response_name,
|
||||
field1,
|
||||
field2,
|
||||
)
|
||||
if conflict:
|
||||
conflicts.append(conflict)
|
||||
|
||||
|
||||
def find_conflict(
|
||||
context: ValidationContext,
|
||||
cached_fields_and_fragment_names: Dict,
|
||||
compared_fragment_pairs: "PairSet",
|
||||
parent_fields_are_mutually_exclusive: bool,
|
||||
response_name: str,
|
||||
field1: NodeAndDef,
|
||||
field2: NodeAndDef,
|
||||
) -> Optional[Conflict]:
|
||||
"""Find conflict.
|
||||
|
||||
Determines if there is a conflict between two particular fields, including comparing
|
||||
their sub-fields.
|
||||
"""
|
||||
parent_type1, node1, def1 = field1
|
||||
parent_type2, node2, def2 = field2
|
||||
|
||||
# If it is known that two fields could not possibly apply at the same time, due to
|
||||
# the parent types, then it is safe to permit them to diverge in aliased field or
|
||||
# arguments used as they will not present any ambiguity by differing. It is known
|
||||
# that two parent types could never overlap if they are different Object types.
|
||||
# Interface or Union types might overlap - if not in the current state of the
|
||||
# schema, then perhaps in some future version, thus may not safely diverge.
|
||||
are_mutually_exclusive = parent_fields_are_mutually_exclusive or (
|
||||
parent_type1 != parent_type2
|
||||
and is_object_type(parent_type1)
|
||||
and is_object_type(parent_type2)
|
||||
)
|
||||
|
||||
# The return type for each field.
|
||||
type1 = cast(Optional[GraphQLOutputType], def1 and def1.type)
|
||||
type2 = cast(Optional[GraphQLOutputType], def2 and def2.type)
|
||||
|
||||
if not are_mutually_exclusive:
|
||||
# Two aliases must refer to the same field.
|
||||
name1 = node1.name.value
|
||||
name2 = node2.name.value
|
||||
if name1 != name2:
|
||||
return (
|
||||
(response_name, f"'{name1}' and '{name2}' are different fields"),
|
||||
[node1],
|
||||
[node2],
|
||||
)
|
||||
|
||||
# Two field calls must have the same arguments.
|
||||
if not same_arguments(node1, node2):
|
||||
return (response_name, "they have differing arguments"), [node1], [node2]
|
||||
|
||||
if type1 and type2 and do_types_conflict(type1, type2):
|
||||
return (
|
||||
(response_name, f"they return conflicting types '{type1}' and '{type2}'"),
|
||||
[node1],
|
||||
[node2],
|
||||
)
|
||||
|
||||
# Collect and compare sub-fields. Use the same "visited fragment names" list for
|
||||
# both collections so fields in a fragment reference are never compared to
|
||||
# themselves.
|
||||
selection_set1 = node1.selection_set
|
||||
selection_set2 = node2.selection_set
|
||||
if selection_set1 and selection_set2:
|
||||
conflicts = find_conflicts_between_sub_selection_sets(
|
||||
context,
|
||||
cached_fields_and_fragment_names,
|
||||
compared_fragment_pairs,
|
||||
are_mutually_exclusive,
|
||||
get_named_type(type1),
|
||||
selection_set1,
|
||||
get_named_type(type2),
|
||||
selection_set2,
|
||||
)
|
||||
return subfield_conflicts(conflicts, response_name, node1, node2)
|
||||
|
||||
return None # no conflict
|
||||
|
||||
|
||||
def same_arguments(
|
||||
node1: Union[FieldNode, DirectiveNode], node2: Union[FieldNode, DirectiveNode]
|
||||
) -> bool:
|
||||
args1 = node1.arguments
|
||||
args2 = node2.arguments
|
||||
|
||||
if not args1:
|
||||
return not args2
|
||||
|
||||
if not args2:
|
||||
return False
|
||||
|
||||
if len(args1) != len(args2):
|
||||
return False # pragma: no cover
|
||||
|
||||
values2 = {arg.name.value: arg.value for arg in args2}
|
||||
|
||||
for arg1 in args1:
|
||||
value1 = arg1.value
|
||||
value2 = values2.get(arg1.name.value)
|
||||
if value2 is None or stringify_value(value1) != stringify_value(value2):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def stringify_value(value: ValueNode) -> str:
|
||||
return print_ast(sort_value_node(value))
|
||||
|
||||
|
||||
def do_types_conflict(type1: GraphQLOutputType, type2: GraphQLOutputType) -> bool:
|
||||
"""Check whether two types conflict
|
||||
|
||||
Two types conflict if both types could not apply to a value simultaneously.
|
||||
Composite types are ignored as their individual field types will be compared later
|
||||
recursively. However List and Non-Null types must match.
|
||||
"""
|
||||
if is_list_type(type1):
|
||||
return (
|
||||
do_types_conflict(
|
||||
cast(GraphQLList, type1).of_type, cast(GraphQLList, type2).of_type
|
||||
)
|
||||
if is_list_type(type2)
|
||||
else True
|
||||
)
|
||||
if is_list_type(type2):
|
||||
return True
|
||||
if is_non_null_type(type1):
|
||||
return (
|
||||
do_types_conflict(
|
||||
cast(GraphQLNonNull, type1).of_type, cast(GraphQLNonNull, type2).of_type
|
||||
)
|
||||
if is_non_null_type(type2)
|
||||
else True
|
||||
)
|
||||
if is_non_null_type(type2):
|
||||
return True
|
||||
if is_leaf_type(type1) or is_leaf_type(type2):
|
||||
return type1 is not type2
|
||||
return False
|
||||
|
||||
|
||||
def get_fields_and_fragment_names(
|
||||
context: ValidationContext,
|
||||
cached_fields_and_fragment_names: Dict,
|
||||
parent_type: Optional[GraphQLNamedType],
|
||||
selection_set: SelectionSetNode,
|
||||
) -> Tuple[NodeAndDefCollection, List[str]]:
|
||||
"""Get fields and referenced fragment names
|
||||
|
||||
Given a selection set, return the collection of fields (a mapping of response name
|
||||
to field nodes and definitions) as well as a list of fragment names referenced via
|
||||
fragment spreads.
|
||||
"""
|
||||
cached = cached_fields_and_fragment_names.get(selection_set)
|
||||
if not cached:
|
||||
node_and_defs: NodeAndDefCollection = {}
|
||||
fragment_names: Dict[str, bool] = {}
|
||||
collect_fields_and_fragment_names(
|
||||
context, parent_type, selection_set, node_and_defs, fragment_names
|
||||
)
|
||||
cached = (node_and_defs, list(fragment_names))
|
||||
cached_fields_and_fragment_names[selection_set] = cached
|
||||
return cached
|
||||
|
||||
|
||||
def get_referenced_fields_and_fragment_names(
|
||||
context: ValidationContext,
|
||||
cached_fields_and_fragment_names: Dict,
|
||||
fragment: FragmentDefinitionNode,
|
||||
) -> Tuple[NodeAndDefCollection, List[str]]:
|
||||
"""Get referenced fields and nested fragment names
|
||||
|
||||
Given a reference to a fragment, return the represented collection of fields as well
|
||||
as a list of nested fragment names referenced via fragment spreads.
|
||||
"""
|
||||
# Short-circuit building a type from the node if possible.
|
||||
cached = cached_fields_and_fragment_names.get(fragment.selection_set)
|
||||
if cached:
|
||||
return cached
|
||||
|
||||
fragment_type = type_from_ast(context.schema, fragment.type_condition)
|
||||
return get_fields_and_fragment_names(
|
||||
context, cached_fields_and_fragment_names, fragment_type, fragment.selection_set
|
||||
)
|
||||
|
||||
|
||||
def collect_fields_and_fragment_names(
|
||||
context: ValidationContext,
|
||||
parent_type: Optional[GraphQLNamedType],
|
||||
selection_set: SelectionSetNode,
|
||||
node_and_defs: NodeAndDefCollection,
|
||||
fragment_names: Dict[str, bool],
|
||||
) -> None:
|
||||
for selection in selection_set.selections:
|
||||
if isinstance(selection, FieldNode):
|
||||
field_name = selection.name.value
|
||||
field_def = (
|
||||
parent_type.fields.get(field_name) # type: ignore
|
||||
if is_object_type(parent_type) or is_interface_type(parent_type)
|
||||
else None
|
||||
)
|
||||
response_name = selection.alias.value if selection.alias else field_name
|
||||
if not node_and_defs.get(response_name):
|
||||
node_and_defs[response_name] = []
|
||||
node_and_defs[response_name].append(
|
||||
cast(NodeAndDef, (parent_type, selection, field_def))
|
||||
)
|
||||
elif isinstance(selection, FragmentSpreadNode):
|
||||
fragment_names[selection.name.value] = True
|
||||
elif isinstance(selection, InlineFragmentNode): # pragma: no cover else
|
||||
type_condition = selection.type_condition
|
||||
inline_fragment_type = (
|
||||
type_from_ast(context.schema, type_condition)
|
||||
if type_condition
|
||||
else parent_type
|
||||
)
|
||||
collect_fields_and_fragment_names(
|
||||
context,
|
||||
inline_fragment_type,
|
||||
selection.selection_set,
|
||||
node_and_defs,
|
||||
fragment_names,
|
||||
)
|
||||
|
||||
|
||||
def subfield_conflicts(
|
||||
conflicts: List[Conflict], response_name: str, node1: FieldNode, node2: FieldNode
|
||||
) -> Optional[Conflict]:
|
||||
"""Check whether there are conflicts between sub-fields.
|
||||
|
||||
Given a series of Conflicts which occurred between two sub-fields, generate a single
|
||||
Conflict.
|
||||
"""
|
||||
if conflicts:
|
||||
return (
|
||||
(response_name, [conflict[0] for conflict in conflicts]),
|
||||
list(chain([node1], *[conflict[1] for conflict in conflicts])),
|
||||
list(chain([node2], *[conflict[2] for conflict in conflicts])),
|
||||
)
|
||||
return None # no conflict
|
||||
|
||||
|
||||
class PairSet:
|
||||
"""Pair set
|
||||
|
||||
A way to keep track of pairs of things when the ordering of the pair doesn't matter.
|
||||
"""
|
||||
|
||||
__slots__ = ("_data",)
|
||||
|
||||
_data: Dict[str, Dict[str, bool]]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._data = {}
|
||||
|
||||
def has(self, a: str, b: str, are_mutually_exclusive: bool) -> bool:
|
||||
key1, key2 = (a, b) if a < b else (b, a)
|
||||
|
||||
map_ = self._data.get(key1)
|
||||
if map_ is None:
|
||||
return False
|
||||
result = map_.get(key2)
|
||||
if result is None:
|
||||
return False
|
||||
|
||||
# are_mutually_exclusive being False is a superset of being True,
|
||||
# hence if we want to know if this PairSet "has" these two with no exclusivity,
|
||||
# we have to ensure it was added as such.
|
||||
return True if are_mutually_exclusive else are_mutually_exclusive == result
|
||||
|
||||
def add(self, a: str, b: str, are_mutually_exclusive: bool) -> None:
|
||||
key1, key2 = (a, b) if a < b else (b, a)
|
||||
|
||||
map_ = self._data.get(key1)
|
||||
if map_ is None:
|
||||
self._data[key1] = {key2: are_mutually_exclusive}
|
||||
else:
|
||||
map_[key2] = are_mutually_exclusive
|
||||
+66
@@ -0,0 +1,66 @@
|
||||
from typing import cast, Any, Optional
|
||||
|
||||
from ...error import GraphQLError
|
||||
from ...language import FragmentSpreadNode, InlineFragmentNode
|
||||
from ...type import GraphQLCompositeType, is_composite_type
|
||||
from ...utilities import do_types_overlap, type_from_ast
|
||||
from . import ValidationRule
|
||||
|
||||
__all__ = ["PossibleFragmentSpreadsRule"]
|
||||
|
||||
|
||||
class PossibleFragmentSpreadsRule(ValidationRule):
|
||||
"""Possible fragment spread
|
||||
|
||||
A fragment spread is only valid if the type condition could ever possibly be true:
|
||||
if there is a non-empty intersection of the possible parent types, and possible
|
||||
types which pass the type condition.
|
||||
"""
|
||||
|
||||
def enter_inline_fragment(self, node: InlineFragmentNode, *_args: Any) -> None:
|
||||
context = self.context
|
||||
frag_type = context.get_type()
|
||||
parent_type = context.get_parent_type()
|
||||
if (
|
||||
is_composite_type(frag_type)
|
||||
and is_composite_type(parent_type)
|
||||
and not do_types_overlap(
|
||||
context.schema,
|
||||
cast(GraphQLCompositeType, frag_type),
|
||||
cast(GraphQLCompositeType, parent_type),
|
||||
)
|
||||
):
|
||||
context.report_error(
|
||||
GraphQLError(
|
||||
f"Fragment cannot be spread here as objects"
|
||||
f" of type '{parent_type}' can never be of type '{frag_type}'.",
|
||||
node,
|
||||
)
|
||||
)
|
||||
|
||||
def enter_fragment_spread(self, node: FragmentSpreadNode, *_args: Any) -> None:
|
||||
context = self.context
|
||||
frag_name = node.name.value
|
||||
frag_type = self.get_fragment_type(frag_name)
|
||||
parent_type = context.get_parent_type()
|
||||
if (
|
||||
frag_type
|
||||
and parent_type
|
||||
and not do_types_overlap(context.schema, frag_type, parent_type)
|
||||
):
|
||||
context.report_error(
|
||||
GraphQLError(
|
||||
f"Fragment '{frag_name}' cannot be spread here as objects"
|
||||
f" of type '{parent_type}' can never be of type '{frag_type}'.",
|
||||
node,
|
||||
)
|
||||
)
|
||||
|
||||
def get_fragment_type(self, name: str) -> Optional[GraphQLCompositeType]:
|
||||
context = self.context
|
||||
frag = context.get_fragment(name)
|
||||
if frag:
|
||||
type_ = type_from_ast(context.schema, frag.type_condition)
|
||||
if is_composite_type(type_):
|
||||
return cast(GraphQLCompositeType, type_)
|
||||
return None
|
||||
+109
@@ -0,0 +1,109 @@
|
||||
import re
|
||||
from functools import partial
|
||||
from typing import Any, Optional
|
||||
|
||||
from ...error import GraphQLError
|
||||
from ...language import TypeDefinitionNode, TypeExtensionNode
|
||||
from ...pyutils import did_you_mean, inspect, suggestion_list
|
||||
from ...type import (
|
||||
is_enum_type,
|
||||
is_input_object_type,
|
||||
is_interface_type,
|
||||
is_object_type,
|
||||
is_scalar_type,
|
||||
is_union_type,
|
||||
)
|
||||
from . import SDLValidationContext, SDLValidationRule
|
||||
|
||||
__all__ = ["PossibleTypeExtensionsRule"]
|
||||
|
||||
|
||||
class PossibleTypeExtensionsRule(SDLValidationRule):
|
||||
"""Possible type extension
|
||||
|
||||
A type extension is only valid if the type is defined and has the same kind.
|
||||
"""
|
||||
|
||||
def __init__(self, context: SDLValidationContext):
|
||||
super().__init__(context)
|
||||
self.schema = context.schema
|
||||
self.defined_types = {
|
||||
def_.name.value: def_
|
||||
for def_ in context.document.definitions
|
||||
if isinstance(def_, TypeDefinitionNode)
|
||||
}
|
||||
|
||||
def check_extension(self, node: TypeExtensionNode, *_args: Any) -> None:
|
||||
schema = self.schema
|
||||
type_name = node.name.value
|
||||
def_node = self.defined_types.get(type_name)
|
||||
existing_type = schema.get_type(type_name) if schema else None
|
||||
|
||||
expected_kind: Optional[str]
|
||||
if def_node:
|
||||
expected_kind = def_kind_to_ext_kind(def_node.kind)
|
||||
elif existing_type:
|
||||
expected_kind = type_to_ext_kind(existing_type)
|
||||
else:
|
||||
expected_kind = None
|
||||
|
||||
if expected_kind:
|
||||
if expected_kind != node.kind:
|
||||
kind_str = extension_kind_to_type_name(node.kind)
|
||||
self.report_error(
|
||||
GraphQLError(
|
||||
f"Cannot extend non-{kind_str} type '{type_name}'.",
|
||||
[def_node, node] if def_node else node,
|
||||
)
|
||||
)
|
||||
else:
|
||||
all_type_names = list(self.defined_types)
|
||||
if self.schema:
|
||||
all_type_names.extend(self.schema.type_map)
|
||||
suggested_types = suggestion_list(type_name, all_type_names)
|
||||
self.report_error(
|
||||
GraphQLError(
|
||||
f"Cannot extend type '{type_name}' because it is not defined."
|
||||
+ did_you_mean(suggested_types),
|
||||
node.name,
|
||||
)
|
||||
)
|
||||
|
||||
enter_scalar_type_extension = enter_object_type_extension = check_extension
|
||||
enter_interface_type_extension = enter_union_type_extension = check_extension
|
||||
enter_enum_type_extension = enter_input_object_type_extension = check_extension
|
||||
|
||||
|
||||
def_kind_to_ext_kind = partial(re.compile("(?<=_type_)definition$").sub, "extension")
|
||||
|
||||
|
||||
def type_to_ext_kind(type_: Any) -> str:
|
||||
if is_scalar_type(type_):
|
||||
return "scalar_type_extension"
|
||||
if is_object_type(type_):
|
||||
return "object_type_extension"
|
||||
if is_interface_type(type_):
|
||||
return "interface_type_extension"
|
||||
if is_union_type(type_):
|
||||
return "union_type_extension"
|
||||
if is_enum_type(type_):
|
||||
return "enum_type_extension"
|
||||
if is_input_object_type(type_):
|
||||
return "input_object_type_extension"
|
||||
|
||||
# Not reachable. All possible types have been considered.
|
||||
raise TypeError(f"Unexpected type: {inspect(type_)}.")
|
||||
|
||||
|
||||
_type_names_for_extension_kinds = {
|
||||
"scalar_type_extension": "scalar",
|
||||
"object_type_extension": "object",
|
||||
"interface_type_extension": "interface",
|
||||
"union_type_extension": "union",
|
||||
"enum_type_extension": "enum",
|
||||
"input_object_type_extension": "input object",
|
||||
}
|
||||
|
||||
|
||||
def extension_kind_to_type_name(kind: str) -> str:
|
||||
return _type_names_for_extension_kinds.get(kind, "unknown type")
|
||||
+119
@@ -0,0 +1,119 @@
|
||||
from typing import cast, Any, Dict, List, Union
|
||||
|
||||
from ...error import GraphQLError
|
||||
from ...language import (
|
||||
DirectiveDefinitionNode,
|
||||
DirectiveNode,
|
||||
FieldNode,
|
||||
InputValueDefinitionNode,
|
||||
NonNullTypeNode,
|
||||
TypeNode,
|
||||
VisitorAction,
|
||||
SKIP,
|
||||
print_ast,
|
||||
)
|
||||
from ...type import GraphQLArgument, is_required_argument, is_type, specified_directives
|
||||
from . import ASTValidationRule, SDLValidationContext, ValidationContext
|
||||
|
||||
__all__ = ["ProvidedRequiredArgumentsRule", "ProvidedRequiredArgumentsOnDirectivesRule"]
|
||||
|
||||
|
||||
class ProvidedRequiredArgumentsOnDirectivesRule(ASTValidationRule):
|
||||
"""Provided required arguments on directives
|
||||
|
||||
A directive is only valid if all required (non-null without a default value)
|
||||
arguments have been provided.
|
||||
|
||||
For internal use only.
|
||||
"""
|
||||
|
||||
context: Union[ValidationContext, SDLValidationContext]
|
||||
|
||||
def __init__(self, context: Union[ValidationContext, SDLValidationContext]):
|
||||
super().__init__(context)
|
||||
required_args_map: Dict[
|
||||
str, Dict[str, Union[GraphQLArgument, InputValueDefinitionNode]]
|
||||
] = {}
|
||||
|
||||
schema = context.schema
|
||||
defined_directives = schema.directives if schema else specified_directives
|
||||
for directive in cast(List, defined_directives):
|
||||
required_args_map[directive.name] = {
|
||||
name: arg
|
||||
for name, arg in directive.args.items()
|
||||
if is_required_argument(arg)
|
||||
}
|
||||
|
||||
ast_definitions = context.document.definitions
|
||||
for def_ in ast_definitions:
|
||||
if isinstance(def_, DirectiveDefinitionNode):
|
||||
required_args_map[def_.name.value] = {
|
||||
arg.name.value: arg
|
||||
for arg in filter(is_required_argument_node, def_.arguments or ())
|
||||
}
|
||||
|
||||
self.required_args_map = required_args_map
|
||||
|
||||
def leave_directive(self, directive_node: DirectiveNode, *_args: Any) -> None:
|
||||
# Validate on leave to allow for deeper errors to appear first.
|
||||
directive_name = directive_node.name.value
|
||||
required_args = self.required_args_map.get(directive_name)
|
||||
if required_args:
|
||||
|
||||
arg_nodes = directive_node.arguments or ()
|
||||
arg_node_set = {arg.name.value for arg in arg_nodes}
|
||||
for arg_name in required_args:
|
||||
if arg_name not in arg_node_set:
|
||||
arg_type = required_args[arg_name].type
|
||||
arg_type_str = (
|
||||
str(arg_type)
|
||||
if is_type(arg_type)
|
||||
else print_ast(cast(TypeNode, arg_type))
|
||||
)
|
||||
self.report_error(
|
||||
GraphQLError(
|
||||
f"Directive '@{directive_name}' argument '{arg_name}'"
|
||||
f" of type '{arg_type_str}' is required,"
|
||||
" but it was not provided.",
|
||||
directive_node,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class ProvidedRequiredArgumentsRule(ProvidedRequiredArgumentsOnDirectivesRule):
|
||||
"""Provided required arguments
|
||||
|
||||
A field or directive is only valid if all required (non-null without a default
|
||||
value) field arguments have been provided.
|
||||
"""
|
||||
|
||||
context: ValidationContext
|
||||
|
||||
def __init__(self, context: ValidationContext):
|
||||
super().__init__(context)
|
||||
|
||||
def leave_field(self, field_node: FieldNode, *_args: Any) -> VisitorAction:
|
||||
# Validate on leave to allow for deeper errors to appear first.
|
||||
field_def = self.context.get_field_def()
|
||||
if not field_def:
|
||||
return SKIP
|
||||
arg_nodes = field_node.arguments or ()
|
||||
|
||||
arg_node_map = {arg.name.value: arg for arg in arg_nodes}
|
||||
for arg_name, arg_def in field_def.args.items():
|
||||
arg_node = arg_node_map.get(arg_name)
|
||||
if not arg_node and is_required_argument(arg_def):
|
||||
self.report_error(
|
||||
GraphQLError(
|
||||
f"Field '{field_node.name.value}' argument '{arg_name}'"
|
||||
f" of type '{arg_def.type}' is required,"
|
||||
" but it was not provided.",
|
||||
field_node,
|
||||
)
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def is_required_argument_node(arg: InputValueDefinitionNode) -> bool:
|
||||
return isinstance(arg.type, NonNullTypeNode) and arg.default_value is None
|
||||
@@ -0,0 +1,41 @@
|
||||
from typing import Any
|
||||
|
||||
from ...error import GraphQLError
|
||||
from ...language import FieldNode
|
||||
from ...type import get_named_type, is_leaf_type
|
||||
from . import ValidationRule
|
||||
|
||||
__all__ = ["ScalarLeafsRule"]
|
||||
|
||||
|
||||
class ScalarLeafsRule(ValidationRule):
|
||||
"""Scalar leafs
|
||||
|
||||
A GraphQL document is valid only if all leaf fields (fields without sub selections)
|
||||
are of scalar or enum types.
|
||||
"""
|
||||
|
||||
def enter_field(self, node: FieldNode, *_args: Any) -> None:
|
||||
type_ = self.context.get_type()
|
||||
if type_:
|
||||
selection_set = node.selection_set
|
||||
if is_leaf_type(get_named_type(type_)):
|
||||
if selection_set:
|
||||
field_name = node.name.value
|
||||
self.report_error(
|
||||
GraphQLError(
|
||||
f"Field '{field_name}' must not have a selection"
|
||||
f" since type '{type_}' has no subfields.",
|
||||
selection_set,
|
||||
)
|
||||
)
|
||||
elif not selection_set:
|
||||
field_name = node.name.value
|
||||
self.report_error(
|
||||
GraphQLError(
|
||||
f"Field '{field_name}' of type '{type_}'"
|
||||
" must have a selection of subfields."
|
||||
f" Did you mean '{field_name} {{ ... }}'?",
|
||||
node,
|
||||
)
|
||||
)
|
||||
+85
@@ -0,0 +1,85 @@
|
||||
from typing import Any, Dict, cast
|
||||
|
||||
from ...error import GraphQLError
|
||||
from ...execution.collect_fields import collect_fields
|
||||
from ...language import (
|
||||
FieldNode,
|
||||
FragmentDefinitionNode,
|
||||
OperationDefinitionNode,
|
||||
OperationType,
|
||||
)
|
||||
from . import ValidationRule
|
||||
|
||||
__all__ = ["SingleFieldSubscriptionsRule"]
|
||||
|
||||
|
||||
class SingleFieldSubscriptionsRule(ValidationRule):
|
||||
"""Subscriptions must only include a single non-introspection field.
|
||||
|
||||
A GraphQL subscription is valid only if it contains a single root field and
|
||||
that root field is not an introspection field.
|
||||
|
||||
See https://spec.graphql.org/draft/#sec-Single-root-field
|
||||
"""
|
||||
|
||||
def enter_operation_definition(
|
||||
self, node: OperationDefinitionNode, *_args: Any
|
||||
) -> None:
|
||||
if node.operation != OperationType.SUBSCRIPTION:
|
||||
return
|
||||
schema = self.context.schema
|
||||
subscription_type = schema.subscription_type
|
||||
if subscription_type:
|
||||
operation_name = node.name.value if node.name else None
|
||||
variable_values: Dict[str, Any] = {}
|
||||
document = self.context.document
|
||||
fragments: Dict[str, FragmentDefinitionNode] = {
|
||||
definition.name.value: definition
|
||||
for definition in document.definitions
|
||||
if isinstance(definition, FragmentDefinitionNode)
|
||||
}
|
||||
fields = collect_fields(
|
||||
schema,
|
||||
fragments,
|
||||
variable_values,
|
||||
subscription_type,
|
||||
node.selection_set,
|
||||
)
|
||||
if len(fields) > 1:
|
||||
field_selection_lists = list(fields.values())
|
||||
extra_field_selection_lists = field_selection_lists[1:]
|
||||
extra_field_selection = [
|
||||
field
|
||||
for fields in extra_field_selection_lists
|
||||
for field in (
|
||||
fields
|
||||
if isinstance(fields, list)
|
||||
else [cast(FieldNode, fields)]
|
||||
)
|
||||
]
|
||||
self.report_error(
|
||||
GraphQLError(
|
||||
(
|
||||
"Anonymous Subscription"
|
||||
if operation_name is None
|
||||
else f"Subscription '{operation_name}'"
|
||||
)
|
||||
+ " must select only one top level field.",
|
||||
extra_field_selection,
|
||||
)
|
||||
)
|
||||
for field_nodes in fields.values():
|
||||
field = field_nodes[0]
|
||||
field_name = field.name.value
|
||||
if field_name.startswith("__"):
|
||||
self.report_error(
|
||||
GraphQLError(
|
||||
(
|
||||
"Anonymous Subscription"
|
||||
if operation_name is None
|
||||
else f"Subscription '{operation_name}'"
|
||||
)
|
||||
+ " must not select an introspection top level field.",
|
||||
field_nodes,
|
||||
)
|
||||
)
|
||||
+83
@@ -0,0 +1,83 @@
|
||||
from operator import attrgetter
|
||||
from typing import Any, Collection
|
||||
|
||||
from ...error import GraphQLError
|
||||
from ...language import (
|
||||
DirectiveDefinitionNode,
|
||||
FieldDefinitionNode,
|
||||
InputValueDefinitionNode,
|
||||
InterfaceTypeDefinitionNode,
|
||||
InterfaceTypeExtensionNode,
|
||||
NameNode,
|
||||
ObjectTypeDefinitionNode,
|
||||
ObjectTypeExtensionNode,
|
||||
VisitorAction,
|
||||
SKIP,
|
||||
)
|
||||
from ...pyutils import group_by
|
||||
from . import SDLValidationRule
|
||||
|
||||
__all__ = ["UniqueArgumentDefinitionNamesRule"]
|
||||
|
||||
|
||||
class UniqueArgumentDefinitionNamesRule(SDLValidationRule):
|
||||
"""Unique argument definition names
|
||||
|
||||
A GraphQL Object or Interface type is only valid if all its fields have uniquely
|
||||
named arguments.
|
||||
A GraphQL Directive is only valid if all its arguments are uniquely named.
|
||||
|
||||
See https://spec.graphql.org/draft/#sec-Argument-Uniqueness
|
||||
"""
|
||||
|
||||
def enter_directive_definition(
|
||||
self, node: DirectiveDefinitionNode, *_args: Any
|
||||
) -> VisitorAction:
|
||||
return self.check_arg_uniqueness(f"@{node.name.value}", node.arguments)
|
||||
|
||||
def enter_interface_type_definition(
|
||||
self, node: InterfaceTypeDefinitionNode, *_args: Any
|
||||
) -> VisitorAction:
|
||||
return self.check_arg_uniqueness_per_field(node.name, node.fields)
|
||||
|
||||
def enter_interface_type_extension(
|
||||
self, node: InterfaceTypeExtensionNode, *_args: Any
|
||||
) -> VisitorAction:
|
||||
return self.check_arg_uniqueness_per_field(node.name, node.fields)
|
||||
|
||||
def enter_object_type_definition(
|
||||
self, node: ObjectTypeDefinitionNode, *_args: Any
|
||||
) -> VisitorAction:
|
||||
return self.check_arg_uniqueness_per_field(node.name, node.fields)
|
||||
|
||||
def enter_object_type_extension(
|
||||
self, node: ObjectTypeExtensionNode, *_args: Any
|
||||
) -> VisitorAction:
|
||||
return self.check_arg_uniqueness_per_field(node.name, node.fields)
|
||||
|
||||
def check_arg_uniqueness_per_field(
|
||||
self,
|
||||
name: NameNode,
|
||||
fields: Collection[FieldDefinitionNode],
|
||||
) -> VisitorAction:
|
||||
type_name = name.value
|
||||
for field_def in fields:
|
||||
field_name = field_def.name.value
|
||||
argument_nodes = field_def.arguments or ()
|
||||
self.check_arg_uniqueness(f"{type_name}.{field_name}", argument_nodes)
|
||||
return SKIP
|
||||
|
||||
def check_arg_uniqueness(
|
||||
self, parent_name: str, argument_nodes: Collection[InputValueDefinitionNode]
|
||||
) -> VisitorAction:
|
||||
seen_args = group_by(argument_nodes, attrgetter("name.value"))
|
||||
for arg_name, arg_nodes in seen_args.items():
|
||||
if len(arg_nodes) > 1:
|
||||
self.report_error(
|
||||
GraphQLError(
|
||||
f"Argument '{parent_name}({arg_name}:)'"
|
||||
" can only be defined once.",
|
||||
[node.name for node in arg_nodes],
|
||||
)
|
||||
)
|
||||
return SKIP
|
||||
+37
@@ -0,0 +1,37 @@
|
||||
from operator import attrgetter
|
||||
from typing import Any, Collection
|
||||
|
||||
from ...error import GraphQLError
|
||||
from ...language import ArgumentNode, DirectiveNode, FieldNode
|
||||
from ...pyutils import group_by
|
||||
from . import ASTValidationRule
|
||||
|
||||
__all__ = ["UniqueArgumentNamesRule"]
|
||||
|
||||
|
||||
class UniqueArgumentNamesRule(ASTValidationRule):
|
||||
"""Unique argument names
|
||||
|
||||
A GraphQL field or directive is only valid if all supplied arguments are uniquely
|
||||
named.
|
||||
|
||||
See https://spec.graphql.org/draft/#sec-Argument-Names
|
||||
"""
|
||||
|
||||
def enter_field(self, node: FieldNode, *_args: Any) -> None:
|
||||
self.check_arg_uniqueness(node.arguments)
|
||||
|
||||
def enter_directive(self, node: DirectiveNode, *args: Any) -> None:
|
||||
self.check_arg_uniqueness(node.arguments)
|
||||
|
||||
def check_arg_uniqueness(self, argument_nodes: Collection[ArgumentNode]) -> None:
|
||||
seen_args = group_by(argument_nodes, attrgetter("name.value"))
|
||||
|
||||
for arg_name, arg_nodes in seen_args.items():
|
||||
if len(arg_nodes) > 1:
|
||||
self.report_error(
|
||||
GraphQLError(
|
||||
f"There can be only one argument named '{arg_name}'.",
|
||||
[node.name for node in arg_nodes],
|
||||
)
|
||||
)
|
||||
+46
@@ -0,0 +1,46 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
from ...error import GraphQLError
|
||||
from ...language import DirectiveDefinitionNode, NameNode, VisitorAction, SKIP
|
||||
from . import SDLValidationContext, SDLValidationRule
|
||||
|
||||
__all__ = ["UniqueDirectiveNamesRule"]
|
||||
|
||||
|
||||
class UniqueDirectiveNamesRule(SDLValidationRule):
|
||||
"""Unique directive names
|
||||
|
||||
A GraphQL document is only valid if all defined directives have unique names.
|
||||
"""
|
||||
|
||||
def __init__(self, context: SDLValidationContext):
|
||||
super().__init__(context)
|
||||
self.known_directive_names: Dict[str, NameNode] = {}
|
||||
self.schema = context.schema
|
||||
|
||||
def enter_directive_definition(
|
||||
self, node: DirectiveDefinitionNode, *_args: Any
|
||||
) -> VisitorAction:
|
||||
directive_name = node.name.value
|
||||
|
||||
if self.schema and self.schema.get_directive(directive_name):
|
||||
self.report_error(
|
||||
GraphQLError(
|
||||
f"Directive '@{directive_name}' already exists in the schema."
|
||||
" It cannot be redefined.",
|
||||
node.name,
|
||||
)
|
||||
)
|
||||
else:
|
||||
if directive_name in self.known_directive_names:
|
||||
self.report_error(
|
||||
GraphQLError(
|
||||
f"There can be only one directive named '@{directive_name}'.",
|
||||
[self.known_directive_names[directive_name], node.name],
|
||||
)
|
||||
)
|
||||
else:
|
||||
self.known_directive_names[directive_name] = node.name
|
||||
return SKIP
|
||||
|
||||
return None
|
||||
+85
@@ -0,0 +1,85 @@
|
||||
from collections import defaultdict
|
||||
from typing import Any, Dict, List, Union, cast
|
||||
|
||||
from ...error import GraphQLError
|
||||
from ...language import (
|
||||
DirectiveDefinitionNode,
|
||||
DirectiveNode,
|
||||
Node,
|
||||
SchemaDefinitionNode,
|
||||
SchemaExtensionNode,
|
||||
TypeDefinitionNode,
|
||||
TypeExtensionNode,
|
||||
is_type_definition_node,
|
||||
is_type_extension_node,
|
||||
)
|
||||
from ...type import specified_directives
|
||||
from . import ASTValidationRule, SDLValidationContext, ValidationContext
|
||||
|
||||
__all__ = ["UniqueDirectivesPerLocationRule"]
|
||||
|
||||
|
||||
class UniqueDirectivesPerLocationRule(ASTValidationRule):
|
||||
"""Unique directive names per location
|
||||
|
||||
A GraphQL document is only valid if all non-repeatable directives at a given
|
||||
location are uniquely named.
|
||||
|
||||
See https://spec.graphql.org/draft/#sec-Directives-Are-Unique-Per-Location
|
||||
"""
|
||||
|
||||
context: Union[ValidationContext, SDLValidationContext]
|
||||
|
||||
def __init__(self, context: Union[ValidationContext, SDLValidationContext]):
|
||||
super().__init__(context)
|
||||
unique_directive_map: Dict[str, bool] = {}
|
||||
|
||||
schema = context.schema
|
||||
defined_directives = (
|
||||
schema.directives if schema else cast(List, specified_directives)
|
||||
)
|
||||
for directive in defined_directives:
|
||||
unique_directive_map[directive.name] = not directive.is_repeatable
|
||||
|
||||
ast_definitions = context.document.definitions
|
||||
for def_ in ast_definitions:
|
||||
if isinstance(def_, DirectiveDefinitionNode):
|
||||
unique_directive_map[def_.name.value] = not def_.repeatable
|
||||
self.unique_directive_map = unique_directive_map
|
||||
|
||||
self.schema_directives: Dict[str, DirectiveNode] = {}
|
||||
self.type_directives_map: Dict[str, Dict[str, DirectiveNode]] = defaultdict(
|
||||
dict
|
||||
)
|
||||
|
||||
# Many different AST nodes may contain directives. Rather than listing them all,
|
||||
# just listen for entering any node, and check to see if it defines any directives.
|
||||
def enter(self, node: Node, *_args: Any) -> None:
|
||||
directives = getattr(node, "directives", None)
|
||||
if not directives:
|
||||
return
|
||||
directives = cast(List[DirectiveNode], directives)
|
||||
|
||||
if isinstance(node, (SchemaDefinitionNode, SchemaExtensionNode)):
|
||||
seen_directives = self.schema_directives
|
||||
elif is_type_definition_node(node) or is_type_extension_node(node):
|
||||
node = cast(Union[TypeDefinitionNode, TypeExtensionNode], node)
|
||||
type_name = node.name.value
|
||||
seen_directives = self.type_directives_map[type_name]
|
||||
else:
|
||||
seen_directives = {}
|
||||
|
||||
for directive in directives:
|
||||
directive_name = directive.name.value
|
||||
|
||||
if self.unique_directive_map.get(directive_name):
|
||||
if directive_name in seen_directives:
|
||||
self.report_error(
|
||||
GraphQLError(
|
||||
f"The directive '@{directive_name}'"
|
||||
" can only be used once at this location.",
|
||||
[seen_directives[directive_name], directive],
|
||||
)
|
||||
)
|
||||
else:
|
||||
seen_directives[directive_name] = directive
|
||||
+61
@@ -0,0 +1,61 @@
|
||||
from collections import defaultdict
|
||||
from typing import cast, Any, Dict
|
||||
|
||||
from ...error import GraphQLError
|
||||
from ...language import NameNode, EnumTypeDefinitionNode, VisitorAction, SKIP
|
||||
from ...type import is_enum_type, GraphQLEnumType
|
||||
from . import SDLValidationContext, SDLValidationRule
|
||||
|
||||
__all__ = ["UniqueEnumValueNamesRule"]
|
||||
|
||||
|
||||
class UniqueEnumValueNamesRule(SDLValidationRule):
|
||||
"""Unique enum value names
|
||||
|
||||
A GraphQL enum type is only valid if all its values are uniquely named.
|
||||
"""
|
||||
|
||||
def __init__(self, context: SDLValidationContext):
|
||||
super().__init__(context)
|
||||
schema = context.schema
|
||||
self.existing_type_map = schema.type_map if schema else {}
|
||||
self.known_value_names: Dict[str, Dict[str, NameNode]] = defaultdict(dict)
|
||||
|
||||
def check_value_uniqueness(
|
||||
self, node: EnumTypeDefinitionNode, *_args: Any
|
||||
) -> VisitorAction:
|
||||
existing_type_map = self.existing_type_map
|
||||
type_name = node.name.value
|
||||
value_names = self.known_value_names[type_name]
|
||||
|
||||
for value_def in node.values or []:
|
||||
value_name = value_def.name.value
|
||||
|
||||
existing_type = existing_type_map.get(type_name)
|
||||
if (
|
||||
is_enum_type(existing_type)
|
||||
and value_name in cast(GraphQLEnumType, existing_type).values
|
||||
):
|
||||
self.report_error(
|
||||
GraphQLError(
|
||||
f"Enum value '{type_name}.{value_name}'"
|
||||
" already exists in the schema."
|
||||
" It cannot also be defined in this type extension.",
|
||||
value_def.name,
|
||||
)
|
||||
)
|
||||
elif value_name in value_names:
|
||||
self.report_error(
|
||||
GraphQLError(
|
||||
f"Enum value '{type_name}.{value_name}'"
|
||||
" can only be defined once.",
|
||||
[value_names[value_name], value_def.name],
|
||||
)
|
||||
)
|
||||
else:
|
||||
value_names[value_name] = value_def.name
|
||||
|
||||
return SKIP
|
||||
|
||||
enter_enum_type_definition = check_value_uniqueness
|
||||
enter_enum_type_extension = check_value_uniqueness
|
||||
+67
@@ -0,0 +1,67 @@
|
||||
from collections import defaultdict
|
||||
from typing import Any, Dict
|
||||
|
||||
from ...error import GraphQLError
|
||||
from ...language import NameNode, ObjectTypeDefinitionNode, VisitorAction, SKIP
|
||||
from ...type import is_object_type, is_interface_type, is_input_object_type
|
||||
from . import SDLValidationContext, SDLValidationRule
|
||||
|
||||
__all__ = ["UniqueFieldDefinitionNamesRule"]
|
||||
|
||||
|
||||
class UniqueFieldDefinitionNamesRule(SDLValidationRule):
|
||||
"""Unique field definition names
|
||||
|
||||
A GraphQL complex type is only valid if all its fields are uniquely named.
|
||||
"""
|
||||
|
||||
def __init__(self, context: SDLValidationContext):
|
||||
super().__init__(context)
|
||||
schema = context.schema
|
||||
self.existing_type_map = schema.type_map if schema else {}
|
||||
self.known_field_names: Dict[str, Dict[str, NameNode]] = defaultdict(dict)
|
||||
|
||||
def check_field_uniqueness(
|
||||
self, node: ObjectTypeDefinitionNode, *_args: Any
|
||||
) -> VisitorAction:
|
||||
existing_type_map = self.existing_type_map
|
||||
type_name = node.name.value
|
||||
field_names = self.known_field_names[type_name]
|
||||
|
||||
for field_def in node.fields or []:
|
||||
field_name = field_def.name.value
|
||||
|
||||
if has_field(existing_type_map.get(type_name), field_name):
|
||||
self.report_error(
|
||||
GraphQLError(
|
||||
f"Field '{type_name}.{field_name}'"
|
||||
" already exists in the schema."
|
||||
" It cannot also be defined in this type extension.",
|
||||
field_def.name,
|
||||
)
|
||||
)
|
||||
elif field_name in field_names:
|
||||
self.report_error(
|
||||
GraphQLError(
|
||||
f"Field '{type_name}.{field_name}'"
|
||||
" can only be defined once.",
|
||||
[field_names[field_name], field_def.name],
|
||||
)
|
||||
)
|
||||
else:
|
||||
field_names[field_name] = field_def.name
|
||||
|
||||
return SKIP
|
||||
|
||||
enter_input_object_type_definition = check_field_uniqueness
|
||||
enter_input_object_type_extension = check_field_uniqueness
|
||||
enter_interface_type_definition = check_field_uniqueness
|
||||
enter_interface_type_extension = check_field_uniqueness
|
||||
enter_object_type_definition = check_field_uniqueness
|
||||
enter_object_type_extension = check_field_uniqueness
|
||||
|
||||
|
||||
def has_field(type_: Any, field_name: str) -> bool:
|
||||
if is_object_type(type_) or is_interface_type(type_) or is_input_object_type(type_):
|
||||
return field_name in type_.fields
|
||||
return False
|
||||
+40
@@ -0,0 +1,40 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
from ...error import GraphQLError
|
||||
from ...language import NameNode, FragmentDefinitionNode, VisitorAction, SKIP
|
||||
from . import ASTValidationContext, ASTValidationRule
|
||||
|
||||
__all__ = ["UniqueFragmentNamesRule"]
|
||||
|
||||
|
||||
class UniqueFragmentNamesRule(ASTValidationRule):
|
||||
"""Unique fragment names
|
||||
|
||||
A GraphQL document is only valid if all defined fragments have unique names.
|
||||
|
||||
See https://spec.graphql.org/draft/#sec-Fragment-Name-Uniqueness
|
||||
"""
|
||||
|
||||
def __init__(self, context: ASTValidationContext):
|
||||
super().__init__(context)
|
||||
self.known_fragment_names: Dict[str, NameNode] = {}
|
||||
|
||||
@staticmethod
|
||||
def enter_operation_definition(*_args: Any) -> VisitorAction:
|
||||
return SKIP
|
||||
|
||||
def enter_fragment_definition(
|
||||
self, node: FragmentDefinitionNode, *_args: Any
|
||||
) -> VisitorAction:
|
||||
known_fragment_names = self.known_fragment_names
|
||||
fragment_name = node.name.value
|
||||
if fragment_name in known_fragment_names:
|
||||
self.report_error(
|
||||
GraphQLError(
|
||||
f"There can be only one fragment named '{fragment_name}'.",
|
||||
[known_fragment_names[fragment_name], node.name],
|
||||
)
|
||||
)
|
||||
else:
|
||||
known_fragment_names[fragment_name] = node.name
|
||||
return SKIP
|
||||
+42
@@ -0,0 +1,42 @@
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from ...error import GraphQLError
|
||||
from ...language import NameNode, ObjectFieldNode
|
||||
from . import ASTValidationContext, ASTValidationRule
|
||||
|
||||
__all__ = ["UniqueInputFieldNamesRule"]
|
||||
|
||||
|
||||
class UniqueInputFieldNamesRule(ASTValidationRule):
|
||||
"""Unique input field names
|
||||
|
||||
A GraphQL input object value is only valid if all supplied fields are uniquely
|
||||
named.
|
||||
|
||||
See https://spec.graphql.org/draft/#sec-Input-Object-Field-Uniqueness
|
||||
"""
|
||||
|
||||
def __init__(self, context: ASTValidationContext):
|
||||
super().__init__(context)
|
||||
self.known_names_stack: List[Dict[str, NameNode]] = []
|
||||
self.known_names: Dict[str, NameNode] = {}
|
||||
|
||||
def enter_object_value(self, *_args: Any) -> None:
|
||||
self.known_names_stack.append(self.known_names)
|
||||
self.known_names = {}
|
||||
|
||||
def leave_object_value(self, *_args: Any) -> None:
|
||||
self.known_names = self.known_names_stack.pop()
|
||||
|
||||
def enter_object_field(self, node: ObjectFieldNode, *_args: Any) -> None:
|
||||
known_names = self.known_names
|
||||
field_name = node.name.value
|
||||
if field_name in known_names:
|
||||
self.report_error(
|
||||
GraphQLError(
|
||||
f"There can be only one input field named '{field_name}'.",
|
||||
[known_names[field_name], node.name],
|
||||
)
|
||||
)
|
||||
else:
|
||||
known_names[field_name] = node.name
|
||||
+42
@@ -0,0 +1,42 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
from ...error import GraphQLError
|
||||
from ...language import NameNode, OperationDefinitionNode, VisitorAction, SKIP
|
||||
from . import ASTValidationContext, ASTValidationRule
|
||||
|
||||
__all__ = ["UniqueOperationNamesRule"]
|
||||
|
||||
|
||||
class UniqueOperationNamesRule(ASTValidationRule):
|
||||
"""Unique operation names
|
||||
|
||||
A GraphQL document is only valid if all defined operations have unique names.
|
||||
|
||||
See https://spec.graphql.org/draft/#sec-Operation-Name-Uniqueness
|
||||
"""
|
||||
|
||||
def __init__(self, context: ASTValidationContext):
|
||||
super().__init__(context)
|
||||
self.known_operation_names: Dict[str, NameNode] = {}
|
||||
|
||||
def enter_operation_definition(
|
||||
self, node: OperationDefinitionNode, *_args: Any
|
||||
) -> VisitorAction:
|
||||
operation_name = node.name
|
||||
if operation_name:
|
||||
known_operation_names = self.known_operation_names
|
||||
if operation_name.value in known_operation_names:
|
||||
self.report_error(
|
||||
GraphQLError(
|
||||
"There can be only one operation"
|
||||
f" named '{operation_name.value}'.",
|
||||
[known_operation_names[operation_name.value], operation_name],
|
||||
)
|
||||
)
|
||||
else:
|
||||
known_operation_names[operation_name.value] = operation_name
|
||||
return SKIP
|
||||
|
||||
@staticmethod
|
||||
def enter_fragment_definition(*_args: Any) -> VisitorAction:
|
||||
return SKIP
|
||||
+69
@@ -0,0 +1,69 @@
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
from ...error import GraphQLError
|
||||
from ...language import (
|
||||
OperationTypeDefinitionNode,
|
||||
OperationType,
|
||||
SchemaDefinitionNode,
|
||||
SchemaExtensionNode,
|
||||
VisitorAction,
|
||||
SKIP,
|
||||
)
|
||||
from ...type import GraphQLObjectType
|
||||
from . import SDLValidationContext, SDLValidationRule
|
||||
|
||||
__all__ = ["UniqueOperationTypesRule"]
|
||||
|
||||
|
||||
class UniqueOperationTypesRule(SDLValidationRule):
|
||||
"""Unique operation types
|
||||
|
||||
A GraphQL document is only valid if it has only one type per operation.
|
||||
"""
|
||||
|
||||
def __init__(self, context: SDLValidationContext):
|
||||
super().__init__(context)
|
||||
schema = context.schema
|
||||
self.defined_operation_types: Dict[
|
||||
OperationType, OperationTypeDefinitionNode
|
||||
] = {}
|
||||
self.existing_operation_types: Dict[
|
||||
OperationType, Optional[GraphQLObjectType]
|
||||
] = (
|
||||
{
|
||||
OperationType.QUERY: schema.query_type,
|
||||
OperationType.MUTATION: schema.mutation_type,
|
||||
OperationType.SUBSCRIPTION: schema.subscription_type,
|
||||
}
|
||||
if schema
|
||||
else {}
|
||||
)
|
||||
self.schema = schema
|
||||
|
||||
def check_operation_types(
|
||||
self, node: Union[SchemaDefinitionNode, SchemaExtensionNode], *_args: Any
|
||||
) -> VisitorAction:
|
||||
for operation_type in node.operation_types or []:
|
||||
operation = operation_type.operation
|
||||
already_defined_operation_type = self.defined_operation_types.get(operation)
|
||||
|
||||
if self.existing_operation_types.get(operation):
|
||||
self.report_error(
|
||||
GraphQLError(
|
||||
f"Type for {operation.value} already defined in the schema."
|
||||
" It cannot be redefined.",
|
||||
operation_type,
|
||||
)
|
||||
)
|
||||
elif already_defined_operation_type:
|
||||
self.report_error(
|
||||
GraphQLError(
|
||||
f"There can be only one {operation.value} type in schema.",
|
||||
[already_defined_operation_type, operation_type],
|
||||
)
|
||||
)
|
||||
else:
|
||||
self.defined_operation_types[operation] = operation_type
|
||||
return SKIP
|
||||
|
||||
enter_schema_definition = enter_schema_extension = check_operation_types
|
||||
+48
@@ -0,0 +1,48 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
from ...error import GraphQLError
|
||||
from ...language import NameNode, TypeDefinitionNode, VisitorAction, SKIP
|
||||
from . import SDLValidationContext, SDLValidationRule
|
||||
|
||||
__all__ = ["UniqueTypeNamesRule"]
|
||||
|
||||
|
||||
class UniqueTypeNamesRule(SDLValidationRule):
|
||||
"""Unique type names
|
||||
|
||||
A GraphQL document is only valid if all defined types have unique names.
|
||||
"""
|
||||
|
||||
def __init__(self, context: SDLValidationContext):
|
||||
super().__init__(context)
|
||||
self.known_type_names: Dict[str, NameNode] = {}
|
||||
self.schema = context.schema
|
||||
|
||||
def check_type_name(self, node: TypeDefinitionNode, *_args: Any) -> VisitorAction:
|
||||
type_name = node.name.value
|
||||
|
||||
if self.schema and self.schema.get_type(type_name):
|
||||
self.report_error(
|
||||
GraphQLError(
|
||||
f"Type '{type_name}' already exists in the schema."
|
||||
" It cannot also be defined in this type definition.",
|
||||
node.name,
|
||||
)
|
||||
)
|
||||
else:
|
||||
if type_name in self.known_type_names:
|
||||
self.report_error(
|
||||
GraphQLError(
|
||||
f"There can be only one type named '{type_name}'.",
|
||||
[self.known_type_names[type_name], node.name],
|
||||
)
|
||||
)
|
||||
else:
|
||||
self.known_type_names[type_name] = node.name
|
||||
return SKIP
|
||||
|
||||
return None
|
||||
|
||||
enter_scalar_type_definition = enter_object_type_definition = check_type_name
|
||||
enter_interface_type_definition = enter_union_type_definition = check_type_name
|
||||
enter_enum_type_definition = enter_input_object_type_definition = check_type_name
|
||||
+34
@@ -0,0 +1,34 @@
|
||||
from operator import attrgetter
|
||||
from typing import Any
|
||||
|
||||
from ...error import GraphQLError
|
||||
from ...language import OperationDefinitionNode
|
||||
from ...pyutils import group_by
|
||||
from . import ASTValidationRule
|
||||
|
||||
__all__ = ["UniqueVariableNamesRule"]
|
||||
|
||||
|
||||
class UniqueVariableNamesRule(ASTValidationRule):
|
||||
"""Unique variable names
|
||||
|
||||
A GraphQL operation is only valid if all its variables are uniquely named.
|
||||
"""
|
||||
|
||||
def enter_operation_definition(
|
||||
self, node: OperationDefinitionNode, *_args: Any
|
||||
) -> None:
|
||||
variable_definitions = node.variable_definitions
|
||||
|
||||
seen_variable_definitions = group_by(
|
||||
variable_definitions, attrgetter("variable.name.value")
|
||||
)
|
||||
|
||||
for variable_name, variable_nodes in seen_variable_definitions.items():
|
||||
if len(variable_nodes) > 1:
|
||||
self.report_error(
|
||||
GraphQLError(
|
||||
f"There can be only one variable named '${variable_name}'.",
|
||||
[node.variable.name for node in variable_nodes],
|
||||
)
|
||||
)
|
||||
+163
@@ -0,0 +1,163 @@
|
||||
from typing import cast, Any
|
||||
|
||||
from ...error import GraphQLError
|
||||
from ...language import (
|
||||
BooleanValueNode,
|
||||
EnumValueNode,
|
||||
FloatValueNode,
|
||||
IntValueNode,
|
||||
NullValueNode,
|
||||
ListValueNode,
|
||||
ObjectFieldNode,
|
||||
ObjectValueNode,
|
||||
StringValueNode,
|
||||
ValueNode,
|
||||
VisitorAction,
|
||||
SKIP,
|
||||
print_ast,
|
||||
)
|
||||
from ...pyutils import did_you_mean, suggestion_list, Undefined
|
||||
from ...type import (
|
||||
GraphQLInputObjectType,
|
||||
GraphQLScalarType,
|
||||
get_named_type,
|
||||
get_nullable_type,
|
||||
is_input_object_type,
|
||||
is_leaf_type,
|
||||
is_list_type,
|
||||
is_non_null_type,
|
||||
is_required_input_field,
|
||||
)
|
||||
from . import ValidationRule
|
||||
|
||||
__all__ = ["ValuesOfCorrectTypeRule"]
|
||||
|
||||
|
||||
class ValuesOfCorrectTypeRule(ValidationRule):
|
||||
"""Value literals of correct type
|
||||
|
||||
A GraphQL document is only valid if all value literals are of the type expected at
|
||||
their position.
|
||||
|
||||
See https://spec.graphql.org/draft/#sec-Values-of-Correct-Type
|
||||
"""
|
||||
|
||||
def enter_list_value(self, node: ListValueNode, *_args: Any) -> VisitorAction:
|
||||
# Note: TypeInfo will traverse into a list's item type, so look to the parent
|
||||
# input type to check if it is a list.
|
||||
type_ = get_nullable_type(self.context.get_parent_input_type()) # type: ignore
|
||||
if not is_list_type(type_):
|
||||
self.is_valid_value_node(node)
|
||||
return SKIP # Don't traverse further.
|
||||
return None
|
||||
|
||||
def enter_object_value(self, node: ObjectValueNode, *_args: Any) -> VisitorAction:
|
||||
type_ = get_named_type(self.context.get_input_type())
|
||||
if not is_input_object_type(type_):
|
||||
self.is_valid_value_node(node)
|
||||
return SKIP # Don't traverse further.
|
||||
type_ = cast(GraphQLInputObjectType, type_)
|
||||
# Ensure every required field exists.
|
||||
field_node_map = {field.name.value: field for field in node.fields}
|
||||
for field_name, field_def in type_.fields.items():
|
||||
field_node = field_node_map.get(field_name)
|
||||
if not field_node and is_required_input_field(field_def):
|
||||
field_type = field_def.type
|
||||
self.report_error(
|
||||
GraphQLError(
|
||||
f"Field '{type_.name}.{field_name}' of required type"
|
||||
f" '{field_type}' was not provided.",
|
||||
node,
|
||||
)
|
||||
)
|
||||
return None
|
||||
|
||||
def enter_object_field(self, node: ObjectFieldNode, *_args: Any) -> None:
|
||||
parent_type = get_named_type(self.context.get_parent_input_type())
|
||||
field_type = self.context.get_input_type()
|
||||
if not field_type and is_input_object_type(parent_type):
|
||||
parent_type = cast(GraphQLInputObjectType, parent_type)
|
||||
suggestions = suggestion_list(node.name.value, list(parent_type.fields))
|
||||
self.report_error(
|
||||
GraphQLError(
|
||||
f"Field '{node.name.value}'"
|
||||
f" is not defined by type '{parent_type.name}'."
|
||||
+ did_you_mean(suggestions),
|
||||
node,
|
||||
)
|
||||
)
|
||||
|
||||
def enter_null_value(self, node: NullValueNode, *_args: Any) -> None:
|
||||
type_ = self.context.get_input_type()
|
||||
if is_non_null_type(type_):
|
||||
self.report_error(
|
||||
GraphQLError(
|
||||
f"Expected value of type '{type_}', found {print_ast(node)}.", node
|
||||
)
|
||||
)
|
||||
|
||||
def enter_enum_value(self, node: EnumValueNode, *_args: Any) -> None:
|
||||
self.is_valid_value_node(node)
|
||||
|
||||
def enter_int_value(self, node: IntValueNode, *_args: Any) -> None:
|
||||
self.is_valid_value_node(node)
|
||||
|
||||
def enter_float_value(self, node: FloatValueNode, *_args: Any) -> None:
|
||||
self.is_valid_value_node(node)
|
||||
|
||||
def enter_string_value(self, node: StringValueNode, *_args: Any) -> None:
|
||||
self.is_valid_value_node(node)
|
||||
|
||||
def enter_boolean_value(self, node: BooleanValueNode, *_args: Any) -> None:
|
||||
self.is_valid_value_node(node)
|
||||
|
||||
def is_valid_value_node(self, node: ValueNode) -> None:
|
||||
"""Check whether this is a valid value node.
|
||||
|
||||
Any value literal may be a valid representation of a Scalar, depending on that
|
||||
scalar type.
|
||||
"""
|
||||
# Report any error at the full type expected by the location.
|
||||
location_type = self.context.get_input_type()
|
||||
if not location_type:
|
||||
return
|
||||
|
||||
type_ = get_named_type(location_type)
|
||||
|
||||
if not is_leaf_type(type_):
|
||||
self.report_error(
|
||||
GraphQLError(
|
||||
f"Expected value of type '{location_type}',"
|
||||
f" found {print_ast(node)}.",
|
||||
node,
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
# Scalars determine if a literal value is valid via `parse_literal()` which may
|
||||
# throw or return an invalid value to indicate failure.
|
||||
type_ = cast(GraphQLScalarType, type_)
|
||||
try:
|
||||
parse_result = type_.parse_literal(node)
|
||||
if parse_result is Undefined:
|
||||
self.report_error(
|
||||
GraphQLError(
|
||||
f"Expected value of type '{location_type}',"
|
||||
f" found {print_ast(node)}.",
|
||||
node,
|
||||
)
|
||||
)
|
||||
except GraphQLError as error:
|
||||
self.report_error(error)
|
||||
except Exception as error:
|
||||
self.report_error(
|
||||
GraphQLError(
|
||||
f"Expected value of type '{location_type}',"
|
||||
f" found {print_ast(node)}; {error}",
|
||||
node,
|
||||
# Ensure a reference to the original error is maintained.
|
||||
original_error=error,
|
||||
)
|
||||
)
|
||||
|
||||
return
|
||||
+36
@@ -0,0 +1,36 @@
|
||||
from typing import Any
|
||||
|
||||
from ...error import GraphQLError
|
||||
from ...language import VariableDefinitionNode, print_ast
|
||||
from ...type import is_input_type
|
||||
from ...utilities import type_from_ast
|
||||
from . import ValidationRule
|
||||
|
||||
__all__ = ["VariablesAreInputTypesRule"]
|
||||
|
||||
|
||||
class VariablesAreInputTypesRule(ValidationRule):
|
||||
"""Variables are input types
|
||||
|
||||
A GraphQL operation is only valid if all the variables it defines are of input types
|
||||
(scalar, enum, or input object).
|
||||
|
||||
See https://spec.graphql.org/draft/#sec-Variables-Are-Input-Types
|
||||
"""
|
||||
|
||||
def enter_variable_definition(
|
||||
self, node: VariableDefinitionNode, *_args: Any
|
||||
) -> None:
|
||||
type_ = type_from_ast(self.context.schema, node.type)
|
||||
|
||||
# If the variable type is not an input type, return an error.
|
||||
if type_ is not None and not is_input_type(type_):
|
||||
variable_name = node.variable.name.value
|
||||
type_name = print_ast(node.type)
|
||||
self.report_error(
|
||||
GraphQLError(
|
||||
f"Variable '${variable_name}'"
|
||||
f" cannot be non-input type '{type_name}'.",
|
||||
node.type,
|
||||
)
|
||||
)
|
||||
+93
@@ -0,0 +1,93 @@
|
||||
from typing import Any, Dict, Optional, cast
|
||||
|
||||
from ...error import GraphQLError
|
||||
from ...language import (
|
||||
NullValueNode,
|
||||
OperationDefinitionNode,
|
||||
ValueNode,
|
||||
VariableDefinitionNode,
|
||||
)
|
||||
from ...pyutils import Undefined
|
||||
from ...type import GraphQLNonNull, GraphQLSchema, GraphQLType, is_non_null_type
|
||||
from ...utilities import type_from_ast, is_type_sub_type_of
|
||||
from . import ValidationContext, ValidationRule
|
||||
|
||||
__all__ = ["VariablesInAllowedPositionRule"]
|
||||
|
||||
|
||||
class VariablesInAllowedPositionRule(ValidationRule):
|
||||
"""Variables in allowed position
|
||||
|
||||
Variable usages must be compatible with the arguments they are passed to.
|
||||
|
||||
See https://spec.graphql.org/draft/#sec-All-Variable-Usages-are-Allowed
|
||||
"""
|
||||
|
||||
def __init__(self, context: ValidationContext):
|
||||
super().__init__(context)
|
||||
self.var_def_map: Dict[str, Any] = {}
|
||||
|
||||
def enter_operation_definition(self, *_args: Any) -> None:
|
||||
self.var_def_map.clear()
|
||||
|
||||
def leave_operation_definition(
|
||||
self, operation: OperationDefinitionNode, *_args: Any
|
||||
) -> None:
|
||||
var_def_map = self.var_def_map
|
||||
usages = self.context.get_recursive_variable_usages(operation)
|
||||
|
||||
for usage in usages:
|
||||
node, type_ = usage.node, usage.type
|
||||
default_value = usage.default_value
|
||||
var_name = node.name.value
|
||||
var_def = var_def_map.get(var_name)
|
||||
if var_def and type_:
|
||||
# A var type is allowed if it is the same or more strict (e.g. is a
|
||||
# subtype of) than the expected type. It can be more strict if the
|
||||
# variable type is non-null when the expected type is nullable. If both
|
||||
# are list types, the variable item type can be more strict than the
|
||||
# expected item type (contravariant).
|
||||
schema = self.context.schema
|
||||
var_type = type_from_ast(schema, var_def.type)
|
||||
if var_type and not allowed_variable_usage(
|
||||
schema, var_type, var_def.default_value, type_, default_value
|
||||
):
|
||||
self.report_error(
|
||||
GraphQLError(
|
||||
f"Variable '${var_name}' of type '{var_type}' used"
|
||||
f" in position expecting type '{type_}'.",
|
||||
[var_def, node],
|
||||
)
|
||||
)
|
||||
|
||||
def enter_variable_definition(
|
||||
self, node: VariableDefinitionNode, *_args: Any
|
||||
) -> None:
|
||||
self.var_def_map[node.variable.name.value] = node
|
||||
|
||||
|
||||
def allowed_variable_usage(
|
||||
schema: GraphQLSchema,
|
||||
var_type: GraphQLType,
|
||||
var_default_value: Optional[ValueNode],
|
||||
location_type: GraphQLType,
|
||||
location_default_value: Any,
|
||||
) -> bool:
|
||||
"""Check for allowed variable usage.
|
||||
|
||||
Returns True if the variable is allowed in the location it was found, which includes
|
||||
considering if default values exist for either the variable or the location at which
|
||||
it is located.
|
||||
"""
|
||||
if is_non_null_type(location_type) and not is_non_null_type(var_type):
|
||||
has_non_null_variable_default_value = (
|
||||
var_default_value is not None
|
||||
and not isinstance(var_default_value, NullValueNode)
|
||||
)
|
||||
has_location_default_value = location_default_value is not Undefined
|
||||
if not has_non_null_variable_default_value and not has_location_default_value:
|
||||
return False
|
||||
location_type = cast(GraphQLNonNull, location_type)
|
||||
nullable_location_type = location_type.of_type
|
||||
return is_type_sub_type_of(schema, var_type, nullable_location_type)
|
||||
return is_type_sub_type_of(schema, var_type, location_type)
|
||||
@@ -0,0 +1,157 @@
|
||||
from typing import Tuple, Type
|
||||
|
||||
from .rules import ASTValidationRule
|
||||
|
||||
# Spec Section: "Executable Definitions"
|
||||
from .rules.executable_definitions import ExecutableDefinitionsRule
|
||||
|
||||
# Spec Section: "Operation Name Uniqueness"
|
||||
from .rules.unique_operation_names import UniqueOperationNamesRule
|
||||
|
||||
# Spec Section: "Lone Anonymous Operation"
|
||||
from .rules.lone_anonymous_operation import LoneAnonymousOperationRule
|
||||
|
||||
# Spec Section: "Subscriptions with Single Root Field"
|
||||
from .rules.single_field_subscriptions import SingleFieldSubscriptionsRule
|
||||
|
||||
# Spec Section: "Fragment Spread Type Existence"
|
||||
from .rules.known_type_names import KnownTypeNamesRule
|
||||
|
||||
# Spec Section: "Fragments on Composite Types"
|
||||
from .rules.fragments_on_composite_types import FragmentsOnCompositeTypesRule
|
||||
|
||||
# Spec Section: "Variables are Input Types"
|
||||
from .rules.variables_are_input_types import VariablesAreInputTypesRule
|
||||
|
||||
# Spec Section: "Leaf Field Selections"
|
||||
from .rules.scalar_leafs import ScalarLeafsRule
|
||||
|
||||
# Spec Section: "Field Selections on Objects, Interfaces, and Unions Types"
|
||||
from .rules.fields_on_correct_type import FieldsOnCorrectTypeRule
|
||||
|
||||
# Spec Section: "Fragment Name Uniqueness"
|
||||
from .rules.unique_fragment_names import UniqueFragmentNamesRule
|
||||
|
||||
# Spec Section: "Fragment spread target defined"
|
||||
from .rules.known_fragment_names import KnownFragmentNamesRule
|
||||
|
||||
# Spec Section: "Fragments must be used"
|
||||
from .rules.no_unused_fragments import NoUnusedFragmentsRule
|
||||
|
||||
# Spec Section: "Fragment spread is possible"
|
||||
from .rules.possible_fragment_spreads import PossibleFragmentSpreadsRule
|
||||
|
||||
# Spec Section: "Fragments must not form cycles"
|
||||
from .rules.no_fragment_cycles import NoFragmentCyclesRule
|
||||
|
||||
# Spec Section: "Variable Uniqueness"
|
||||
from .rules.unique_variable_names import UniqueVariableNamesRule
|
||||
|
||||
# Spec Section: "All Variable Used Defined"
|
||||
from .rules.no_undefined_variables import NoUndefinedVariablesRule
|
||||
|
||||
# Spec Section: "All Variables Used"
|
||||
from .rules.no_unused_variables import NoUnusedVariablesRule
|
||||
|
||||
# Spec Section: "Directives Are Defined"
|
||||
from .rules.known_directives import KnownDirectivesRule
|
||||
|
||||
# Spec Section: "Directives Are Unique Per Location"
|
||||
from .rules.unique_directives_per_location import UniqueDirectivesPerLocationRule
|
||||
|
||||
# Spec Section: "Argument Names"
|
||||
from .rules.known_argument_names import KnownArgumentNamesRule
|
||||
from .rules.known_argument_names import KnownArgumentNamesOnDirectivesRule
|
||||
|
||||
# Spec Section: "Argument Uniqueness"
|
||||
from .rules.unique_argument_names import UniqueArgumentNamesRule
|
||||
|
||||
# Spec Section: "Value Type Correctness"
|
||||
from .rules.values_of_correct_type import ValuesOfCorrectTypeRule
|
||||
|
||||
# Spec Section: "Argument Optionality"
|
||||
from .rules.provided_required_arguments import ProvidedRequiredArgumentsRule
|
||||
from .rules.provided_required_arguments import ProvidedRequiredArgumentsOnDirectivesRule
|
||||
|
||||
# Spec Section: "All Variable Usages Are Allowed"
|
||||
from .rules.variables_in_allowed_position import VariablesInAllowedPositionRule
|
||||
|
||||
# Spec Section: "Field Selection Merging"
|
||||
from .rules.overlapping_fields_can_be_merged import OverlappingFieldsCanBeMergedRule
|
||||
|
||||
# Spec Section: "Input Object Field Uniqueness"
|
||||
from .rules.unique_input_field_names import UniqueInputFieldNamesRule
|
||||
|
||||
# Schema definition language:
|
||||
from .rules.lone_schema_definition import LoneSchemaDefinitionRule
|
||||
from .rules.unique_operation_types import UniqueOperationTypesRule
|
||||
from .rules.unique_type_names import UniqueTypeNamesRule
|
||||
from .rules.unique_enum_value_names import UniqueEnumValueNamesRule
|
||||
from .rules.unique_field_definition_names import UniqueFieldDefinitionNamesRule
|
||||
from .rules.unique_argument_definition_names import UniqueArgumentDefinitionNamesRule
|
||||
from .rules.unique_directive_names import UniqueDirectiveNamesRule
|
||||
from .rules.possible_type_extensions import PossibleTypeExtensionsRule
|
||||
|
||||
__all__ = ["specified_rules", "specified_sdl_rules"]
|
||||
|
||||
|
||||
# This list includes all validation rules defined by the GraphQL spec.
|
||||
#
|
||||
# The order of the rules in this list has been adjusted to lead to the
|
||||
# most clear output when encountering multiple validation errors.
|
||||
|
||||
specified_rules: Tuple[Type[ASTValidationRule], ...] = (
|
||||
ExecutableDefinitionsRule,
|
||||
UniqueOperationNamesRule,
|
||||
LoneAnonymousOperationRule,
|
||||
SingleFieldSubscriptionsRule,
|
||||
KnownTypeNamesRule,
|
||||
FragmentsOnCompositeTypesRule,
|
||||
VariablesAreInputTypesRule,
|
||||
ScalarLeafsRule,
|
||||
FieldsOnCorrectTypeRule,
|
||||
UniqueFragmentNamesRule,
|
||||
KnownFragmentNamesRule,
|
||||
NoUnusedFragmentsRule,
|
||||
PossibleFragmentSpreadsRule,
|
||||
NoFragmentCyclesRule,
|
||||
UniqueVariableNamesRule,
|
||||
NoUndefinedVariablesRule,
|
||||
NoUnusedVariablesRule,
|
||||
KnownDirectivesRule,
|
||||
UniqueDirectivesPerLocationRule,
|
||||
KnownArgumentNamesRule,
|
||||
UniqueArgumentNamesRule,
|
||||
ValuesOfCorrectTypeRule,
|
||||
ProvidedRequiredArgumentsRule,
|
||||
VariablesInAllowedPositionRule,
|
||||
OverlappingFieldsCanBeMergedRule,
|
||||
UniqueInputFieldNamesRule,
|
||||
)
|
||||
"""A tuple with all validation rules defined by the GraphQL specification.
|
||||
|
||||
The order of the rules in this tuple has been adjusted to lead to the
|
||||
most clear output when encountering multiple validation errors.
|
||||
"""
|
||||
|
||||
specified_sdl_rules: Tuple[Type[ASTValidationRule], ...] = (
|
||||
LoneSchemaDefinitionRule,
|
||||
UniqueOperationTypesRule,
|
||||
UniqueTypeNamesRule,
|
||||
UniqueEnumValueNamesRule,
|
||||
UniqueFieldDefinitionNamesRule,
|
||||
UniqueArgumentDefinitionNamesRule,
|
||||
UniqueDirectiveNamesRule,
|
||||
KnownTypeNamesRule,
|
||||
KnownDirectivesRule,
|
||||
UniqueDirectivesPerLocationRule,
|
||||
PossibleTypeExtensionsRule,
|
||||
KnownArgumentNamesOnDirectivesRule,
|
||||
UniqueArgumentNamesRule,
|
||||
UniqueInputFieldNamesRule,
|
||||
ProvidedRequiredArgumentsOnDirectivesRule,
|
||||
)
|
||||
"""This tuple includes all rules for validating SDL.
|
||||
|
||||
For internal use only.
|
||||
"""
|
||||
@@ -0,0 +1,133 @@
|
||||
from typing import Collection, List, Optional, Type
|
||||
|
||||
from ..error import GraphQLError
|
||||
from ..language import DocumentNode, ParallelVisitor, visit
|
||||
from ..pyutils import inspect, is_collection
|
||||
from ..type import GraphQLSchema, assert_valid_schema
|
||||
from ..utilities import TypeInfo, TypeInfoVisitor
|
||||
from .rules import ASTValidationRule
|
||||
from .specified_rules import specified_rules, specified_sdl_rules
|
||||
from .validation_context import SDLValidationContext, ValidationContext
|
||||
|
||||
__all__ = ["assert_valid_sdl", "assert_valid_sdl_extension", "validate", "validate_sdl"]
|
||||
|
||||
|
||||
class ValidationAbortedError(RuntimeError):
|
||||
"""Error when a validation has been aborted (error limit reached)."""
|
||||
|
||||
|
||||
def validate(
|
||||
schema: GraphQLSchema,
|
||||
document_ast: DocumentNode,
|
||||
rules: Optional[Collection[Type[ASTValidationRule]]] = None,
|
||||
max_errors: Optional[int] = None,
|
||||
type_info: Optional[TypeInfo] = None,
|
||||
) -> List[GraphQLError]:
|
||||
"""Implements the "Validation" section of the spec.
|
||||
|
||||
Validation runs synchronously, returning a list of encountered errors, or an empty
|
||||
list if no errors were encountered and the document is valid.
|
||||
|
||||
A list of specific validation rules may be provided. If not provided, the default
|
||||
list of rules defined by the GraphQL specification will be used.
|
||||
|
||||
Each validation rule is a ValidationRule object which is a visitor object that holds
|
||||
a ValidationContext (see the language/visitor API). Visitor methods are expected to
|
||||
return GraphQLErrors, or lists of GraphQLErrors when invalid.
|
||||
|
||||
Validate will stop validation after a ``max_errors`` limit has been reached.
|
||||
Attackers can send pathologically invalid queries to induce a DoS attack,
|
||||
so by default ``max_errors`` set to 100 errors.
|
||||
|
||||
Providing a custom TypeInfo instance is deprecated and will be removed in v3.3.
|
||||
"""
|
||||
if not document_ast or not isinstance(document_ast, DocumentNode):
|
||||
raise TypeError("Must provide document.")
|
||||
# If the schema used for validation is invalid, throw an error.
|
||||
assert_valid_schema(schema)
|
||||
if max_errors is None:
|
||||
max_errors = 100
|
||||
elif not isinstance(max_errors, int):
|
||||
raise TypeError("The maximum number of errors must be passed as an int.")
|
||||
if type_info is None:
|
||||
type_info = TypeInfo(schema)
|
||||
elif not isinstance(type_info, TypeInfo):
|
||||
raise TypeError(f"Not a TypeInfo object: {inspect(type_info)}.")
|
||||
if rules is None:
|
||||
rules = specified_rules
|
||||
elif not is_collection(rules) or not all(
|
||||
isinstance(rule, type) and issubclass(rule, ASTValidationRule) for rule in rules
|
||||
):
|
||||
raise TypeError(
|
||||
"Rules must be specified as a collection of ASTValidationRule subclasses."
|
||||
)
|
||||
|
||||
errors: List[GraphQLError] = []
|
||||
|
||||
def on_error(error: GraphQLError) -> None:
|
||||
if len(errors) >= max_errors:
|
||||
errors.append(
|
||||
GraphQLError(
|
||||
"Too many validation errors, error limit reached."
|
||||
" Validation aborted."
|
||||
)
|
||||
)
|
||||
raise ValidationAbortedError
|
||||
errors.append(error)
|
||||
|
||||
context = ValidationContext(schema, document_ast, type_info, on_error)
|
||||
|
||||
# This uses a specialized visitor which runs multiple visitors in parallel,
|
||||
# while maintaining the visitor skip and break API.
|
||||
visitors = [rule(context) for rule in rules]
|
||||
|
||||
# Visit the whole document with each instance of all provided rules.
|
||||
try:
|
||||
visit(document_ast, TypeInfoVisitor(type_info, ParallelVisitor(visitors)))
|
||||
except ValidationAbortedError:
|
||||
pass
|
||||
return errors
|
||||
|
||||
|
||||
def validate_sdl(
|
||||
document_ast: DocumentNode,
|
||||
schema_to_extend: Optional[GraphQLSchema] = None,
|
||||
rules: Optional[Collection[Type[ASTValidationRule]]] = None,
|
||||
) -> List[GraphQLError]:
|
||||
"""Validate an SDL document.
|
||||
|
||||
For internal use only.
|
||||
"""
|
||||
errors: List[GraphQLError] = []
|
||||
context = SDLValidationContext(document_ast, schema_to_extend, errors.append)
|
||||
if rules is None:
|
||||
rules = specified_sdl_rules
|
||||
visitors = [rule(context) for rule in rules]
|
||||
visit(document_ast, ParallelVisitor(visitors))
|
||||
return errors
|
||||
|
||||
|
||||
def assert_valid_sdl(document_ast: DocumentNode) -> None:
|
||||
"""Assert document is valid SDL.
|
||||
|
||||
Utility function which asserts a SDL document is valid by throwing an error if it
|
||||
is invalid.
|
||||
"""
|
||||
|
||||
errors = validate_sdl(document_ast)
|
||||
if errors:
|
||||
raise TypeError("\n\n".join(error.message for error in errors))
|
||||
|
||||
|
||||
def assert_valid_sdl_extension(
|
||||
document_ast: DocumentNode, schema: GraphQLSchema
|
||||
) -> None:
|
||||
"""Assert document is a valid SDL extension.
|
||||
|
||||
Utility function which asserts a SDL document is valid by throwing an error if it
|
||||
is invalid.
|
||||
"""
|
||||
|
||||
errors = validate_sdl(document_ast, schema)
|
||||
if errors:
|
||||
raise TypeError("\n\n".join(error.message for error in errors))
|
||||
@@ -0,0 +1,250 @@
|
||||
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Union, cast
|
||||
|
||||
from ..error import GraphQLError
|
||||
from ..language import (
|
||||
DocumentNode,
|
||||
FragmentDefinitionNode,
|
||||
FragmentSpreadNode,
|
||||
OperationDefinitionNode,
|
||||
SelectionSetNode,
|
||||
VariableNode,
|
||||
Visitor,
|
||||
VisitorAction,
|
||||
visit,
|
||||
)
|
||||
from ..type import (
|
||||
GraphQLArgument,
|
||||
GraphQLCompositeType,
|
||||
GraphQLDirective,
|
||||
GraphQLEnumValue,
|
||||
GraphQLField,
|
||||
GraphQLInputType,
|
||||
GraphQLOutputType,
|
||||
GraphQLSchema,
|
||||
)
|
||||
from ..utilities import TypeInfo, TypeInfoVisitor
|
||||
|
||||
__all__ = [
|
||||
"ASTValidationContext",
|
||||
"SDLValidationContext",
|
||||
"ValidationContext",
|
||||
"VariableUsage",
|
||||
"VariableUsageVisitor",
|
||||
]
|
||||
|
||||
NodeWithSelectionSet = Union[OperationDefinitionNode, FragmentDefinitionNode]
|
||||
|
||||
|
||||
class VariableUsage(NamedTuple):
|
||||
node: VariableNode
|
||||
type: Optional[GraphQLInputType]
|
||||
default_value: Any
|
||||
|
||||
|
||||
class VariableUsageVisitor(Visitor):
|
||||
"""Visitor adding all variable usages to a given list."""
|
||||
|
||||
usages: List[VariableUsage]
|
||||
|
||||
def __init__(self, type_info: TypeInfo):
|
||||
super().__init__()
|
||||
self.usages = []
|
||||
self._append_usage = self.usages.append
|
||||
self._type_info = type_info
|
||||
|
||||
def enter_variable_definition(self, *_args: Any) -> VisitorAction:
|
||||
return self.SKIP
|
||||
|
||||
def enter_variable(self, node: VariableNode, *_args: Any) -> VisitorAction:
|
||||
type_info = self._type_info
|
||||
usage = VariableUsage(
|
||||
node, type_info.get_input_type(), type_info.get_default_value()
|
||||
)
|
||||
self._append_usage(usage)
|
||||
return None
|
||||
|
||||
|
||||
class ASTValidationContext:
|
||||
"""Utility class providing a context for validation of an AST.
|
||||
|
||||
An instance of this class is passed as the context attribute to all Validators,
|
||||
allowing access to commonly useful contextual information from within a validation
|
||||
rule.
|
||||
"""
|
||||
|
||||
document: DocumentNode
|
||||
|
||||
_fragments: Optional[Dict[str, FragmentDefinitionNode]]
|
||||
_fragment_spreads: Dict[SelectionSetNode, List[FragmentSpreadNode]]
|
||||
_recursively_referenced_fragments: Dict[
|
||||
OperationDefinitionNode, List[FragmentDefinitionNode]
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self, ast: DocumentNode, on_error: Callable[[GraphQLError], None]
|
||||
) -> None:
|
||||
self.document = ast
|
||||
self.on_error = on_error # type: ignore
|
||||
self._fragments = None
|
||||
self._fragment_spreads = {}
|
||||
self._recursively_referenced_fragments = {}
|
||||
|
||||
def on_error(self, error: GraphQLError) -> None:
|
||||
pass
|
||||
|
||||
def report_error(self, error: GraphQLError) -> None:
|
||||
self.on_error(error)
|
||||
|
||||
def get_fragment(self, name: str) -> Optional[FragmentDefinitionNode]:
|
||||
fragments = self._fragments
|
||||
if fragments is None:
|
||||
fragments = {
|
||||
statement.name.value: statement
|
||||
for statement in self.document.definitions
|
||||
if isinstance(statement, FragmentDefinitionNode)
|
||||
}
|
||||
|
||||
self._fragments = fragments
|
||||
return fragments.get(name)
|
||||
|
||||
def get_fragment_spreads(self, node: SelectionSetNode) -> List[FragmentSpreadNode]:
|
||||
spreads = self._fragment_spreads.get(node)
|
||||
if spreads is None:
|
||||
spreads = []
|
||||
append_spread = spreads.append
|
||||
sets_to_visit = [node]
|
||||
append_set = sets_to_visit.append
|
||||
pop_set = sets_to_visit.pop
|
||||
while sets_to_visit:
|
||||
visited_set = pop_set()
|
||||
for selection in visited_set.selections:
|
||||
if isinstance(selection, FragmentSpreadNode):
|
||||
append_spread(selection)
|
||||
else:
|
||||
set_to_visit = cast(
|
||||
NodeWithSelectionSet, selection
|
||||
).selection_set
|
||||
if set_to_visit:
|
||||
append_set(set_to_visit)
|
||||
self._fragment_spreads[node] = spreads
|
||||
return spreads
|
||||
|
||||
def get_recursively_referenced_fragments(
|
||||
self, operation: OperationDefinitionNode
|
||||
) -> List[FragmentDefinitionNode]:
|
||||
fragments = self._recursively_referenced_fragments.get(operation)
|
||||
if fragments is None:
|
||||
fragments = []
|
||||
append_fragment = fragments.append
|
||||
collected_names: Set[str] = set()
|
||||
add_name = collected_names.add
|
||||
nodes_to_visit = [operation.selection_set]
|
||||
append_node = nodes_to_visit.append
|
||||
pop_node = nodes_to_visit.pop
|
||||
get_fragment = self.get_fragment
|
||||
get_fragment_spreads = self.get_fragment_spreads
|
||||
while nodes_to_visit:
|
||||
visited_node = pop_node()
|
||||
for spread in get_fragment_spreads(visited_node):
|
||||
frag_name = spread.name.value
|
||||
if frag_name not in collected_names:
|
||||
add_name(frag_name)
|
||||
fragment = get_fragment(frag_name)
|
||||
if fragment:
|
||||
append_fragment(fragment)
|
||||
append_node(fragment.selection_set)
|
||||
self._recursively_referenced_fragments[operation] = fragments
|
||||
return fragments
|
||||
|
||||
|
||||
class SDLValidationContext(ASTValidationContext):
|
||||
"""Utility class providing a context for validation of an SDL AST.
|
||||
|
||||
An instance of this class is passed as the context attribute to all Validators,
|
||||
allowing access to commonly useful contextual information from within a validation
|
||||
rule.
|
||||
"""
|
||||
|
||||
schema: Optional[GraphQLSchema]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ast: DocumentNode,
|
||||
schema: Optional[GraphQLSchema],
|
||||
on_error: Callable[[GraphQLError], None],
|
||||
) -> None:
|
||||
super().__init__(ast, on_error)
|
||||
self.schema = schema
|
||||
|
||||
|
||||
class ValidationContext(ASTValidationContext):
|
||||
"""Utility class providing a context for validation using a GraphQL schema.
|
||||
|
||||
An instance of this class is passed as the context attribute to all Validators,
|
||||
allowing access to commonly useful contextual information from within a validation
|
||||
rule.
|
||||
"""
|
||||
|
||||
schema: GraphQLSchema
|
||||
|
||||
_type_info: TypeInfo
|
||||
_variable_usages: Dict[NodeWithSelectionSet, List[VariableUsage]]
|
||||
_recursive_variable_usages: Dict[OperationDefinitionNode, List[VariableUsage]]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
schema: GraphQLSchema,
|
||||
ast: DocumentNode,
|
||||
type_info: TypeInfo,
|
||||
on_error: Callable[[GraphQLError], None],
|
||||
) -> None:
|
||||
super().__init__(ast, on_error)
|
||||
self.schema = schema
|
||||
self._type_info = type_info
|
||||
self._variable_usages = {}
|
||||
self._recursive_variable_usages = {}
|
||||
|
||||
def get_variable_usages(self, node: NodeWithSelectionSet) -> List[VariableUsage]:
|
||||
usages = self._variable_usages.get(node)
|
||||
if usages is None:
|
||||
usage_visitor = VariableUsageVisitor(self._type_info)
|
||||
visit(node, TypeInfoVisitor(self._type_info, usage_visitor))
|
||||
usages = usage_visitor.usages
|
||||
self._variable_usages[node] = usages
|
||||
return usages
|
||||
|
||||
def get_recursive_variable_usages(
|
||||
self, operation: OperationDefinitionNode
|
||||
) -> List[VariableUsage]:
|
||||
usages = self._recursive_variable_usages.get(operation)
|
||||
if usages is None:
|
||||
get_variable_usages = self.get_variable_usages
|
||||
usages = get_variable_usages(operation)
|
||||
for fragment in self.get_recursively_referenced_fragments(operation):
|
||||
usages.extend(get_variable_usages(fragment))
|
||||
self._recursive_variable_usages[operation] = usages
|
||||
return usages
|
||||
|
||||
def get_type(self) -> Optional[GraphQLOutputType]:
|
||||
return self._type_info.get_type()
|
||||
|
||||
def get_parent_type(self) -> Optional[GraphQLCompositeType]:
|
||||
return self._type_info.get_parent_type()
|
||||
|
||||
def get_input_type(self) -> Optional[GraphQLInputType]:
|
||||
return self._type_info.get_input_type()
|
||||
|
||||
def get_parent_input_type(self) -> Optional[GraphQLInputType]:
|
||||
return self._type_info.get_parent_input_type()
|
||||
|
||||
def get_field_def(self) -> Optional[GraphQLField]:
|
||||
return self._type_info.get_field_def()
|
||||
|
||||
def get_directive(self) -> Optional[GraphQLDirective]:
|
||||
return self._type_info.get_directive()
|
||||
|
||||
def get_argument(self) -> Optional[GraphQLArgument]:
|
||||
return self._type_info.get_argument()
|
||||
|
||||
def get_enum_value(self) -> Optional[GraphQLEnumValue]:
|
||||
return self._type_info.get_enum_value()
|
||||
Reference in New Issue
Block a user