251 lines
8.4 KiB
Python
251 lines
8.4 KiB
Python
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Union, cast
|
|
|
|
from ..error import GraphQLError
|
|
from ..language import (
|
|
DocumentNode,
|
|
FragmentDefinitionNode,
|
|
FragmentSpreadNode,
|
|
OperationDefinitionNode,
|
|
SelectionSetNode,
|
|
VariableNode,
|
|
Visitor,
|
|
VisitorAction,
|
|
visit,
|
|
)
|
|
from ..type import (
|
|
GraphQLArgument,
|
|
GraphQLCompositeType,
|
|
GraphQLDirective,
|
|
GraphQLEnumValue,
|
|
GraphQLField,
|
|
GraphQLInputType,
|
|
GraphQLOutputType,
|
|
GraphQLSchema,
|
|
)
|
|
from ..utilities import TypeInfo, TypeInfoVisitor
|
|
|
|
# Names exported by ``from <this module> import *``.
__all__ = [
    "ASTValidationContext",
    "SDLValidationContext",
    "ValidationContext",
    "VariableUsage",
    "VariableUsageVisitor",
]

# The two definition node kinds this module treats as owners of a selection set
# (operations and fragment definitions).
NodeWithSelectionSet = Union[OperationDefinitionNode, FragmentDefinitionNode]
|
|
|
|
|
|
class VariableUsage(NamedTuple):
    """Immutable record describing one place where a variable is referenced.

    Fields, in positional order:

    - ``node``: the variable node that was encountered.
    - ``type``: the input type expected at that position, or ``None``.
    - ``default_value``: the default value that applies at that position.
    """

    node: VariableNode
    type: Optional[GraphQLInputType]
    default_value: Any
|
|
|
|
|
|
class VariableUsageVisitor(Visitor):
    """Visitor that records every variable it enters in its ``usages`` list."""

    # Usages collected so far, in the order the variables were entered.
    usages: List[VariableUsage]

    def __init__(self, type_info: TypeInfo):
        """Create a visitor with an empty usage list.

        ``type_info`` is queried for the input type and default value each time
        a variable node is entered.
        """
        super().__init__()
        collected: List[VariableUsage] = []
        self.usages = collected
        # Keep the bound ``append`` around so entering a variable is one call.
        self._append_usage = collected.append
        self._type_info = type_info

    def enter_variable_definition(self, *_args: Any) -> VisitorAction:
        """Return ``SKIP`` for variable definitions.

        NOTE(review): presumably this keeps the traversal from descending into
        the definition, so the declared variable is not recorded as a usage.
        """
        return self.SKIP

    def enter_variable(self, node: VariableNode, *_args: Any) -> VisitorAction:
        """Record ``node`` with the input type and default value currently in effect."""
        info = self._type_info
        self._append_usage(
            VariableUsage(
                node=node,
                type=info.get_input_type(),
                default_value=info.get_default_value(),
            )
        )
        return None
|
|
|
|
|
|
class ASTValidationContext:
    """Shared context object for validating an AST.

    Validators receive an instance of this class as their context attribute.
    It exposes the document under validation, an error reporting hook, and
    memoized lookups for fragments and fragment spreads.
    """

    # The document being validated.
    document: DocumentNode

    # Fragment definitions by name; ``None`` until the first ``get_fragment`` call.
    _fragments: Optional[Dict[str, FragmentDefinitionNode]]
    # Memo of the spreads found beneath each selection set.
    _fragment_spreads: Dict[SelectionSetNode, List[FragmentSpreadNode]]
    # Memo of the fragments transitively reachable from each operation.
    _recursively_referenced_fragments: Dict[
        OperationDefinitionNode, List[FragmentDefinitionNode]
    ]

    def __init__(
        self, ast: DocumentNode, on_error: Callable[[GraphQLError], None]
    ) -> None:
        """Store the document and the error callback, and start with empty memos."""
        self.document = ast
        # The instance attribute shadows the no-op ``on_error`` method below.
        self.on_error = on_error  # type: ignore
        self._fragments = None
        self._fragment_spreads = {}
        self._recursively_referenced_fragments = {}

    def on_error(self, error: GraphQLError) -> None:
        """Do nothing; ``__init__`` replaces this per instance with the callback."""

    def report_error(self, error: GraphQLError) -> None:
        """Forward ``error`` to the ``on_error`` callback."""
        self.on_error(error)

    def get_fragment(self, name: str) -> Optional[FragmentDefinitionNode]:
        """Return the fragment definition called ``name``, or ``None`` if absent.

        The name index is built from the document's definitions on first use and
        then reused for later calls.
        """
        by_name = self._fragments
        if by_name is None:
            by_name = {}
            for definition in self.document.definitions:
                if isinstance(definition, FragmentDefinitionNode):
                    # A later definition with the same name replaces an earlier one.
                    by_name[definition.name.value] = definition
            self._fragments = by_name
        return by_name.get(name)

    def get_fragment_spreads(self, node: SelectionSetNode) -> List[FragmentSpreadNode]:
        """Return all fragment spreads found in ``node`` and its nested selection sets.

        Fragments are not followed: only selection sets nested directly in the
        tree under ``node`` are searched. The result is memoized per ``node``.
        """
        cached = self._fragment_spreads.get(node)
        if cached is not None:
            return cached

        found: List[FragmentSpreadNode] = []
        # Explicit stack; the most recently pushed selection set is handled first.
        pending = [node]
        while pending:
            current = pending.pop()
            for selection in current.selections:
                if isinstance(selection, FragmentSpreadNode):
                    found.append(selection)
                    continue
                # Any other selection is treated as possibly owning a selection set.
                nested = cast(NodeWithSelectionSet, selection).selection_set
                if nested:
                    pending.append(nested)

        self._fragment_spreads[node] = found
        return found

    def get_recursively_referenced_fragments(
        self, operation: OperationDefinitionNode
    ) -> List[FragmentDefinitionNode]:
        """Return every fragment reachable from ``operation`` through spreads.

        Each fragment name is resolved at most once, so cyclic references end the
        walk instead of looping. Spreads naming an unknown fragment are ignored.
        The result is memoized per ``operation``.
        """
        cached = self._recursively_referenced_fragments.get(operation)
        if cached is not None:
            return cached

        referenced: List[FragmentDefinitionNode] = []
        seen_names: Set[str] = set()
        pending = [operation.selection_set]
        while pending:
            selection_set = pending.pop()
            for spread in self.get_fragment_spreads(selection_set):
                spread_name = spread.name.value
                if spread_name in seen_names:
                    continue
                seen_names.add(spread_name)
                fragment = self.get_fragment(spread_name)
                if fragment:
                    referenced.append(fragment)
                    pending.append(fragment.selection_set)

        self._recursively_referenced_fragments[operation] = referenced
        return referenced
|
|
|
|
|
|
class SDLValidationContext(ASTValidationContext):
    """Validation context for an SDL AST.

    Validators receive an instance of this class as their context attribute.
    On top of what the base context offers, it carries an optional schema.
    """

    # Schema supplied by the caller; may be ``None``.
    schema: Optional[GraphQLSchema]

    def __init__(
        self,
        ast: DocumentNode,
        schema: Optional[GraphQLSchema],
        on_error: Callable[[GraphQLError], None],
    ) -> None:
        """Initialise the base context and remember the (possibly absent) schema."""
        super().__init__(ast, on_error)
        self.schema = schema
|
|
|
|
|
|
class ValidationContext(ASTValidationContext):
    """Utility class providing a context for validation using a GraphQL schema.

    An instance of this class is passed as the context attribute to all Validators,
    allowing access to commonly useful contextual information from within a validation
    rule.
    """

    # Schema the document is validated against.
    schema: GraphQLSchema

    # Type tracker the accessor methods below delegate to.
    _type_info: TypeInfo
    # Memo of the variable usages found directly inside an operation or fragment.
    _variable_usages: Dict[NodeWithSelectionSet, List[VariableUsage]]
    # Memo of the usages of an operation plus all fragments it references.
    _recursive_variable_usages: Dict[OperationDefinitionNode, List[VariableUsage]]

    def __init__(
        self,
        schema: GraphQLSchema,
        ast: DocumentNode,
        type_info: TypeInfo,
        on_error: Callable[[GraphQLError], None],
    ) -> None:
        """Initialise the base context, then store schema and type info."""
        super().__init__(ast, on_error)
        self.schema = schema
        self._type_info = type_info
        self._variable_usages = {}
        self._recursive_variable_usages = {}

    def get_variable_usages(self, node: NodeWithSelectionSet) -> List[VariableUsage]:
        """Return the variable usages found by visiting ``node``.

        Fragment spreads are not followed. The result is memoized per ``node``,
        and the returned list is the cached object itself.
        """
        usages = self._variable_usages.get(node)
        if usages is None:
            usage_visitor = VariableUsageVisitor(self._type_info)
            # Wrap in TypeInfoVisitor so the type info is updated during the walk.
            visit(node, TypeInfoVisitor(self._type_info, usage_visitor))
            usages = usage_visitor.usages
            self._variable_usages[node] = usages
        return usages

    def get_recursive_variable_usages(
        self, operation: OperationDefinitionNode
    ) -> List[VariableUsage]:
        """Return usages in ``operation`` and in every fragment it references.

        The result is memoized per ``operation``.
        """
        usages = self._recursive_variable_usages.get(operation)
        if usages is None:
            get_variable_usages = self.get_variable_usages
            # Copy before extending: get_variable_usages() hands back the list it
            # caches, and extending it in place would make later
            # get_variable_usages(operation) calls include fragment usages too.
            usages = list(get_variable_usages(operation))
            for fragment in self.get_recursively_referenced_fragments(operation):
                usages.extend(get_variable_usages(fragment))
            self._recursive_variable_usages[operation] = usages
        return usages

    def get_type(self) -> Optional[GraphQLOutputType]:
        """Return the result of ``get_type()`` on the type info."""
        return self._type_info.get_type()

    def get_parent_type(self) -> Optional[GraphQLCompositeType]:
        """Return the result of ``get_parent_type()`` on the type info."""
        return self._type_info.get_parent_type()

    def get_input_type(self) -> Optional[GraphQLInputType]:
        """Return the result of ``get_input_type()`` on the type info."""
        return self._type_info.get_input_type()

    def get_parent_input_type(self) -> Optional[GraphQLInputType]:
        """Return the result of ``get_parent_input_type()`` on the type info."""
        return self._type_info.get_parent_input_type()

    def get_field_def(self) -> Optional[GraphQLField]:
        """Return the result of ``get_field_def()`` on the type info."""
        return self._type_info.get_field_def()

    def get_directive(self) -> Optional[GraphQLDirective]:
        """Return the result of ``get_directive()`` on the type info."""
        return self._type_info.get_directive()

    def get_argument(self) -> Optional[GraphQLArgument]:
        """Return the result of ``get_argument()`` on the type info."""
        return self._type_info.get_argument()

    def get_enum_value(self) -> Optional[GraphQLEnumValue]:
        """Return the result of ``get_enum_value()`` on the type info."""
        return self._type_info.get_enum_value()
|