2025-12-01
This commit is contained in:
@@ -0,0 +1,250 @@
|
||||
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Union, cast
|
||||
|
||||
from ..error import GraphQLError
|
||||
from ..language import (
|
||||
DocumentNode,
|
||||
FragmentDefinitionNode,
|
||||
FragmentSpreadNode,
|
||||
OperationDefinitionNode,
|
||||
SelectionSetNode,
|
||||
VariableNode,
|
||||
Visitor,
|
||||
VisitorAction,
|
||||
visit,
|
||||
)
|
||||
from ..type import (
|
||||
GraphQLArgument,
|
||||
GraphQLCompositeType,
|
||||
GraphQLDirective,
|
||||
GraphQLEnumValue,
|
||||
GraphQLField,
|
||||
GraphQLInputType,
|
||||
GraphQLOutputType,
|
||||
GraphQLSchema,
|
||||
)
|
||||
from ..utilities import TypeInfo, TypeInfoVisitor
|
||||
|
||||
__all__ = [
|
||||
"ASTValidationContext",
|
||||
"SDLValidationContext",
|
||||
"ValidationContext",
|
||||
"VariableUsage",
|
||||
"VariableUsageVisitor",
|
||||
]
|
||||
|
||||
NodeWithSelectionSet = Union[OperationDefinitionNode, FragmentDefinitionNode]
|
||||
|
||||
|
||||
class VariableUsage(NamedTuple):
|
||||
node: VariableNode
|
||||
type: Optional[GraphQLInputType]
|
||||
default_value: Any
|
||||
|
||||
|
||||
class VariableUsageVisitor(Visitor):
|
||||
"""Visitor adding all variable usages to a given list."""
|
||||
|
||||
usages: List[VariableUsage]
|
||||
|
||||
def __init__(self, type_info: TypeInfo):
|
||||
super().__init__()
|
||||
self.usages = []
|
||||
self._append_usage = self.usages.append
|
||||
self._type_info = type_info
|
||||
|
||||
def enter_variable_definition(self, *_args: Any) -> VisitorAction:
|
||||
return self.SKIP
|
||||
|
||||
def enter_variable(self, node: VariableNode, *_args: Any) -> VisitorAction:
|
||||
type_info = self._type_info
|
||||
usage = VariableUsage(
|
||||
node, type_info.get_input_type(), type_info.get_default_value()
|
||||
)
|
||||
self._append_usage(usage)
|
||||
return None
|
||||
|
||||
|
||||
class ASTValidationContext:
|
||||
"""Utility class providing a context for validation of an AST.
|
||||
|
||||
An instance of this class is passed as the context attribute to all Validators,
|
||||
allowing access to commonly useful contextual information from within a validation
|
||||
rule.
|
||||
"""
|
||||
|
||||
document: DocumentNode
|
||||
|
||||
_fragments: Optional[Dict[str, FragmentDefinitionNode]]
|
||||
_fragment_spreads: Dict[SelectionSetNode, List[FragmentSpreadNode]]
|
||||
_recursively_referenced_fragments: Dict[
|
||||
OperationDefinitionNode, List[FragmentDefinitionNode]
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self, ast: DocumentNode, on_error: Callable[[GraphQLError], None]
|
||||
) -> None:
|
||||
self.document = ast
|
||||
self.on_error = on_error # type: ignore
|
||||
self._fragments = None
|
||||
self._fragment_spreads = {}
|
||||
self._recursively_referenced_fragments = {}
|
||||
|
||||
def on_error(self, error: GraphQLError) -> None:
|
||||
pass
|
||||
|
||||
def report_error(self, error: GraphQLError) -> None:
|
||||
self.on_error(error)
|
||||
|
||||
def get_fragment(self, name: str) -> Optional[FragmentDefinitionNode]:
|
||||
fragments = self._fragments
|
||||
if fragments is None:
|
||||
fragments = {
|
||||
statement.name.value: statement
|
||||
for statement in self.document.definitions
|
||||
if isinstance(statement, FragmentDefinitionNode)
|
||||
}
|
||||
|
||||
self._fragments = fragments
|
||||
return fragments.get(name)
|
||||
|
||||
def get_fragment_spreads(self, node: SelectionSetNode) -> List[FragmentSpreadNode]:
|
||||
spreads = self._fragment_spreads.get(node)
|
||||
if spreads is None:
|
||||
spreads = []
|
||||
append_spread = spreads.append
|
||||
sets_to_visit = [node]
|
||||
append_set = sets_to_visit.append
|
||||
pop_set = sets_to_visit.pop
|
||||
while sets_to_visit:
|
||||
visited_set = pop_set()
|
||||
for selection in visited_set.selections:
|
||||
if isinstance(selection, FragmentSpreadNode):
|
||||
append_spread(selection)
|
||||
else:
|
||||
set_to_visit = cast(
|
||||
NodeWithSelectionSet, selection
|
||||
).selection_set
|
||||
if set_to_visit:
|
||||
append_set(set_to_visit)
|
||||
self._fragment_spreads[node] = spreads
|
||||
return spreads
|
||||
|
||||
def get_recursively_referenced_fragments(
|
||||
self, operation: OperationDefinitionNode
|
||||
) -> List[FragmentDefinitionNode]:
|
||||
fragments = self._recursively_referenced_fragments.get(operation)
|
||||
if fragments is None:
|
||||
fragments = []
|
||||
append_fragment = fragments.append
|
||||
collected_names: Set[str] = set()
|
||||
add_name = collected_names.add
|
||||
nodes_to_visit = [operation.selection_set]
|
||||
append_node = nodes_to_visit.append
|
||||
pop_node = nodes_to_visit.pop
|
||||
get_fragment = self.get_fragment
|
||||
get_fragment_spreads = self.get_fragment_spreads
|
||||
while nodes_to_visit:
|
||||
visited_node = pop_node()
|
||||
for spread in get_fragment_spreads(visited_node):
|
||||
frag_name = spread.name.value
|
||||
if frag_name not in collected_names:
|
||||
add_name(frag_name)
|
||||
fragment = get_fragment(frag_name)
|
||||
if fragment:
|
||||
append_fragment(fragment)
|
||||
append_node(fragment.selection_set)
|
||||
self._recursively_referenced_fragments[operation] = fragments
|
||||
return fragments
|
||||
|
||||
|
||||
class SDLValidationContext(ASTValidationContext):
|
||||
"""Utility class providing a context for validation of an SDL AST.
|
||||
|
||||
An instance of this class is passed as the context attribute to all Validators,
|
||||
allowing access to commonly useful contextual information from within a validation
|
||||
rule.
|
||||
"""
|
||||
|
||||
schema: Optional[GraphQLSchema]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ast: DocumentNode,
|
||||
schema: Optional[GraphQLSchema],
|
||||
on_error: Callable[[GraphQLError], None],
|
||||
) -> None:
|
||||
super().__init__(ast, on_error)
|
||||
self.schema = schema
|
||||
|
||||
|
||||
class ValidationContext(ASTValidationContext):
|
||||
"""Utility class providing a context for validation using a GraphQL schema.
|
||||
|
||||
An instance of this class is passed as the context attribute to all Validators,
|
||||
allowing access to commonly useful contextual information from within a validation
|
||||
rule.
|
||||
"""
|
||||
|
||||
schema: GraphQLSchema
|
||||
|
||||
_type_info: TypeInfo
|
||||
_variable_usages: Dict[NodeWithSelectionSet, List[VariableUsage]]
|
||||
_recursive_variable_usages: Dict[OperationDefinitionNode, List[VariableUsage]]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
schema: GraphQLSchema,
|
||||
ast: DocumentNode,
|
||||
type_info: TypeInfo,
|
||||
on_error: Callable[[GraphQLError], None],
|
||||
) -> None:
|
||||
super().__init__(ast, on_error)
|
||||
self.schema = schema
|
||||
self._type_info = type_info
|
||||
self._variable_usages = {}
|
||||
self._recursive_variable_usages = {}
|
||||
|
||||
def get_variable_usages(self, node: NodeWithSelectionSet) -> List[VariableUsage]:
|
||||
usages = self._variable_usages.get(node)
|
||||
if usages is None:
|
||||
usage_visitor = VariableUsageVisitor(self._type_info)
|
||||
visit(node, TypeInfoVisitor(self._type_info, usage_visitor))
|
||||
usages = usage_visitor.usages
|
||||
self._variable_usages[node] = usages
|
||||
return usages
|
||||
|
||||
def get_recursive_variable_usages(
|
||||
self, operation: OperationDefinitionNode
|
||||
) -> List[VariableUsage]:
|
||||
usages = self._recursive_variable_usages.get(operation)
|
||||
if usages is None:
|
||||
get_variable_usages = self.get_variable_usages
|
||||
usages = get_variable_usages(operation)
|
||||
for fragment in self.get_recursively_referenced_fragments(operation):
|
||||
usages.extend(get_variable_usages(fragment))
|
||||
self._recursive_variable_usages[operation] = usages
|
||||
return usages
|
||||
|
||||
def get_type(self) -> Optional[GraphQLOutputType]:
|
||||
return self._type_info.get_type()
|
||||
|
||||
def get_parent_type(self) -> Optional[GraphQLCompositeType]:
|
||||
return self._type_info.get_parent_type()
|
||||
|
||||
def get_input_type(self) -> Optional[GraphQLInputType]:
|
||||
return self._type_info.get_input_type()
|
||||
|
||||
def get_parent_input_type(self) -> Optional[GraphQLInputType]:
|
||||
return self._type_info.get_parent_input_type()
|
||||
|
||||
def get_field_def(self) -> Optional[GraphQLField]:
|
||||
return self._type_info.get_field_def()
|
||||
|
||||
def get_directive(self) -> Optional[GraphQLDirective]:
|
||||
return self._type_info.get_directive()
|
||||
|
||||
def get_argument(self) -> Optional[GraphQLArgument]:
|
||||
return self._type_info.get_argument()
|
||||
|
||||
def get_enum_value(self) -> Optional[GraphQLEnumValue]:
|
||||
return self._type_info.get_enum_value()
|
||||
Reference in New Issue
Block a user