2025-12-01

This commit is contained in:
2026-03-17 14:58:51 -06:00
parent 183e865f8b
commit 4b82b57113
6846 changed files with 954887 additions and 162606 deletions
@@ -0,0 +1,208 @@
"""GraphQL Language
The :mod:`graphql.language` package is responsible for parsing and operating on the
GraphQL language.
"""
from .source import Source
from .location import get_location, SourceLocation, FormattedSourceLocation
from .print_location import print_location, print_source_location
from .token_kind import TokenKind
from .lexer import Lexer
from .parser import parse, parse_type, parse_value, parse_const_value
from .printer import print_ast
from .visitor import (
visit,
Visitor,
ParallelVisitor,
VisitorAction,
VisitorKeyMap,
BREAK,
SKIP,
REMOVE,
IDLE,
)
from .ast import (
Location,
Token,
Node,
# Each kind of AST node
NameNode,
DocumentNode,
DefinitionNode,
ExecutableDefinitionNode,
OperationDefinitionNode,
OperationType,
VariableDefinitionNode,
VariableNode,
SelectionSetNode,
SelectionNode,
FieldNode,
ArgumentNode,
ConstArgumentNode,
FragmentSpreadNode,
InlineFragmentNode,
FragmentDefinitionNode,
ValueNode,
ConstValueNode,
IntValueNode,
FloatValueNode,
StringValueNode,
BooleanValueNode,
NullValueNode,
EnumValueNode,
ListValueNode,
ConstListValueNode,
ObjectValueNode,
ConstObjectValueNode,
ObjectFieldNode,
ConstObjectFieldNode,
DirectiveNode,
ConstDirectiveNode,
TypeNode,
NamedTypeNode,
ListTypeNode,
NonNullTypeNode,
TypeSystemDefinitionNode,
SchemaDefinitionNode,
OperationTypeDefinitionNode,
TypeDefinitionNode,
ScalarTypeDefinitionNode,
ObjectTypeDefinitionNode,
FieldDefinitionNode,
InputValueDefinitionNode,
InterfaceTypeDefinitionNode,
UnionTypeDefinitionNode,
EnumTypeDefinitionNode,
EnumValueDefinitionNode,
InputObjectTypeDefinitionNode,
DirectiveDefinitionNode,
TypeSystemExtensionNode,
SchemaExtensionNode,
TypeExtensionNode,
ScalarTypeExtensionNode,
ObjectTypeExtensionNode,
InterfaceTypeExtensionNode,
UnionTypeExtensionNode,
EnumTypeExtensionNode,
InputObjectTypeExtensionNode,
)
from .predicates import (
is_definition_node,
is_executable_definition_node,
is_selection_node,
is_value_node,
is_const_value_node,
is_type_node,
is_type_system_definition_node,
is_type_definition_node,
is_type_system_extension_node,
is_type_extension_node,
)
from .directive_locations import DirectiveLocation
# Public API of the graphql.language subpackage: everything imported above is
# re-exported here. Order is kept stable for documentation and star-imports.
__all__ = [
    "get_location",
    "SourceLocation",
    "FormattedSourceLocation",
    "print_location",
    "print_source_location",
    "TokenKind",
    "Lexer",
    "parse",
    "parse_value",
    "parse_const_value",
    "parse_type",
    "print_ast",
    "Source",
    "visit",
    "Visitor",
    "ParallelVisitor",
    "VisitorAction",
    "VisitorKeyMap",
    "BREAK",
    "SKIP",
    "REMOVE",
    "IDLE",
    "Location",
    "Token",
    "DirectiveLocation",
    "Node",
    "NameNode",
    "DocumentNode",
    "DefinitionNode",
    "ExecutableDefinitionNode",
    "OperationDefinitionNode",
    "OperationType",
    "VariableDefinitionNode",
    "VariableNode",
    "SelectionSetNode",
    "SelectionNode",
    "FieldNode",
    "ArgumentNode",
    "ConstArgumentNode",
    "FragmentSpreadNode",
    "InlineFragmentNode",
    "FragmentDefinitionNode",
    "ValueNode",
    "ConstValueNode",
    "IntValueNode",
    "FloatValueNode",
    "StringValueNode",
    "BooleanValueNode",
    "NullValueNode",
    "EnumValueNode",
    "ListValueNode",
    "ConstListValueNode",
    "ObjectValueNode",
    "ConstObjectValueNode",
    "ObjectFieldNode",
    "ConstObjectFieldNode",
    "DirectiveNode",
    "ConstDirectiveNode",
    "TypeNode",
    "NamedTypeNode",
    "ListTypeNode",
    "NonNullTypeNode",
    "TypeSystemDefinitionNode",
    "SchemaDefinitionNode",
    "OperationTypeDefinitionNode",
    "TypeDefinitionNode",
    "ScalarTypeDefinitionNode",
    "ObjectTypeDefinitionNode",
    "FieldDefinitionNode",
    "InputValueDefinitionNode",
    "InterfaceTypeDefinitionNode",
    "UnionTypeDefinitionNode",
    "EnumTypeDefinitionNode",
    "EnumValueDefinitionNode",
    "InputObjectTypeDefinitionNode",
    "DirectiveDefinitionNode",
    "TypeSystemExtensionNode",
    "SchemaExtensionNode",
    "TypeExtensionNode",
    "ScalarTypeExtensionNode",
    "ObjectTypeExtensionNode",
    "InterfaceTypeExtensionNode",
    "UnionTypeExtensionNode",
    "EnumTypeExtensionNode",
    "InputObjectTypeExtensionNode",
    "is_definition_node",
    "is_executable_definition_node",
    "is_selection_node",
    "is_value_node",
    "is_const_value_node",
    "is_type_node",
    "is_type_system_definition_node",
    "is_type_definition_node",
    "is_type_system_extension_node",
    "is_type_extension_node",
]
@@ -0,0 +1,807 @@
from copy import copy, deepcopy
from enum import Enum
from typing import Any, Dict, List, Tuple, Optional, Union
from .source import Source
from .token_kind import TokenKind
from ..pyutils import camel_to_snake
# Public names of this module: the Token/Location/Node machinery, every AST
# node class, and the visitor key table. Order is kept stable.
__all__ = [
    "Location",
    "Token",
    "Node",
    "NameNode",
    "DocumentNode",
    "DefinitionNode",
    "ExecutableDefinitionNode",
    "OperationDefinitionNode",
    "VariableDefinitionNode",
    "SelectionSetNode",
    "SelectionNode",
    "FieldNode",
    "ArgumentNode",
    "ConstArgumentNode",
    "FragmentSpreadNode",
    "InlineFragmentNode",
    "FragmentDefinitionNode",
    "ValueNode",
    "ConstValueNode",
    "VariableNode",
    "IntValueNode",
    "FloatValueNode",
    "StringValueNode",
    "BooleanValueNode",
    "NullValueNode",
    "EnumValueNode",
    "ListValueNode",
    "ConstListValueNode",
    "ObjectValueNode",
    "ConstObjectValueNode",
    "ObjectFieldNode",
    "ConstObjectFieldNode",
    "DirectiveNode",
    "ConstDirectiveNode",
    "TypeNode",
    "NamedTypeNode",
    "ListTypeNode",
    "NonNullTypeNode",
    "TypeSystemDefinitionNode",
    "SchemaDefinitionNode",
    "OperationType",
    "OperationTypeDefinitionNode",
    "TypeDefinitionNode",
    "ScalarTypeDefinitionNode",
    "ObjectTypeDefinitionNode",
    "FieldDefinitionNode",
    "InputValueDefinitionNode",
    "InterfaceTypeDefinitionNode",
    "UnionTypeDefinitionNode",
    "EnumTypeDefinitionNode",
    "EnumValueDefinitionNode",
    "InputObjectTypeDefinitionNode",
    "DirectiveDefinitionNode",
    "SchemaExtensionNode",
    "TypeExtensionNode",
    "TypeSystemExtensionNode",
    "ScalarTypeExtensionNode",
    "ObjectTypeExtensionNode",
    "InterfaceTypeExtensionNode",
    "UnionTypeExtensionNode",
    "EnumTypeExtensionNode",
    "InputObjectTypeExtensionNode",
    "QUERY_DOCUMENT_KEYS",
]
class Token:
    """AST Token

    Represents a range of characters represented by a lexical token within a Source.
    """

    __slots__ = "kind", "start", "end", "line", "column", "prev", "next", "value"

    kind: TokenKind  # the kind of token
    start: int  # the character offset at which this Node begins
    end: int  # the character offset at which this Node ends
    line: int  # the 1-indexed line number on which this Token appears
    column: int  # the 1-indexed column number at which this Token begins
    # for non-punctuation tokens, represents the interpreted value of the token:
    value: Optional[str]
    # Tokens exist as nodes in a double-linked-list amongst all tokens including
    # ignored tokens. <SOF> is always the first node and <EOF> the last.
    prev: Optional["Token"]
    next: Optional["Token"]

    def __init__(
        self,
        kind: TokenKind,
        start: int,
        end: int,
        line: int,
        column: int,
        value: Optional[str] = None,
    ) -> None:
        self.kind = kind
        self.start, self.end = start, end
        self.line, self.column = line, column
        self.value = value
        # linked-list neighbors are filled in later by the lexer
        self.prev = self.next = None

    def __str__(self) -> str:
        return self.desc

    def __repr__(self) -> str:
        """Print a simplified form when appearing in repr() or inspect()."""
        return f"<Token {self.desc} {self.line}:{self.column}>"

    def __inspect__(self) -> str:
        return repr(self)

    def __eq__(self, other: Any) -> bool:
        # Tokens compare equal to tokens with identical attributes (the linked
        # list neighbors are deliberately ignored), and also to a plain string
        # matching their description.
        if isinstance(other, Token):
            return (
                self.kind == other.kind
                and self.start == other.start
                and self.end == other.end
                and self.line == other.line
                and self.column == other.column
                and self.value == other.value
            )
        elif isinstance(other, str):
            return other == self.desc
        return False

    def __hash__(self) -> int:
        # must hash the same attributes that __eq__ compares
        return hash(
            (self.kind, self.start, self.end, self.line, self.column, self.value)
        )

    def __copy__(self) -> "Token":
        """Create a shallow copy of the token"""
        token = self.__class__(
            self.kind,
            self.start,
            self.end,
            self.line,
            self.column,
            self.value,
        )
        # only the backward link is carried over; ``next`` stays None (as set
        # by __init__), so a copy never drags the whole token chain along
        token.prev = self.prev
        return token

    def __deepcopy__(self, memo: Dict) -> "Token":
        """Allow only shallow copies to avoid recursion."""
        return copy(self)

    def __getstate__(self) -> Dict[str, Any]:
        """Remove the links when pickling.

        Keeping the links would make pickling a schema too expensive.
        """
        return {
            key: getattr(self, key)
            for key in self.__slots__
            if key not in {"prev", "next"}
        }

    def __setstate__(self, state: Dict[str, Any]) -> None:
        """Reset the links when un-pickling."""
        for key, value in state.items():
            setattr(self, key, value)
        self.prev = self.next = None

    @property
    def desc(self) -> str:
        """A helper property to describe a token as a string for debugging"""
        kind, value = self.kind.value, self.value
        return f"{kind} {value!r}" if value else kind
class Location:
    """AST Location

    Contains a range of UTF-8 character offsets and token references that identify the
    region of the source from which the AST derived.
    """

    __slots__ = (
        "start",
        "end",
        "start_token",
        "end_token",
        "source",
    )

    start: int  # character offset at which this Node begins
    end: int  # character offset at which this Node ends
    start_token: Token  # Token at which this Node begins
    end_token: Token  # Token at which this Node ends.
    source: Source  # Source document the AST represents

    def __init__(self, start_token: Token, end_token: Token, source: Source) -> None:
        # the numeric offsets are denormalized from the tokens for fast access
        self.start = start_token.start
        self.end = end_token.end
        self.start_token = start_token
        self.end_token = end_token
        self.source = source

    def __str__(self) -> str:
        return f"{self.start}:{self.end}"

    def __repr__(self) -> str:
        """Print a simplified form when appearing in repr() or inspect()."""
        return f"<Location {self.start}:{self.end}>"

    def __inspect__(self) -> str:
        return repr(self)

    def __eq__(self, other: Any) -> bool:
        # Locations compare equal by their offsets only; they also compare
        # equal to a 2-element list or tuple of (start, end).
        if isinstance(other, Location):
            return self.start == other.start and self.end == other.end
        elif isinstance(other, (list, tuple)) and len(other) == 2:
            return self.start == other[0] and self.end == other[1]
        return False

    # NOTE(review): Python 3 derives __ne__ from __eq__ automatically, so this
    # explicit definition appears redundant — presumably kept for symmetry.
    def __ne__(self, other: Any) -> bool:
        return not self == other

    def __hash__(self) -> int:
        return hash((self.start, self.end))
class OperationType(Enum):
    """The three kinds of operations a GraphQL document can define."""

    QUERY = "query"
    MUTATION = "mutation"
    SUBSCRIPTION = "subscription"
# Default map from node kinds to their node attributes (internal)
# Maps each node ``kind`` (snake_case) to the tuple of child-bearing attribute
# names, in visit order; used by the visitor to know which attributes to
# traverse. Kinds mapping to () are leaves.
QUERY_DOCUMENT_KEYS: Dict[str, Tuple[str, ...]] = {
    "name": (),
    "document": ("definitions",),
    "operation_definition": (
        "name",
        "variable_definitions",
        "directives",
        "selection_set",
    ),
    "variable_definition": ("variable", "type", "default_value", "directives"),
    "variable": ("name",),
    "selection_set": ("selections",),
    "field": ("alias", "name", "arguments", "directives", "selection_set"),
    "argument": ("name", "value"),
    "fragment_spread": ("name", "directives"),
    "inline_fragment": ("type_condition", "directives", "selection_set"),
    "fragment_definition": (
        # Note: fragment variable definitions are deprecated and will be removed in v3.3
        "name",
        "variable_definitions",
        "type_condition",
        "directives",
        "selection_set",
    ),
    "list_value": ("values",),
    "object_value": ("fields",),
    "object_field": ("name", "value"),
    "directive": ("name", "arguments"),
    "named_type": ("name",),
    "list_type": ("type",),
    "non_null_type": ("type",),
    "schema_definition": ("description", "directives", "operation_types"),
    "operation_type_definition": ("type",),
    "scalar_type_definition": ("description", "name", "directives"),
    "object_type_definition": (
        "description",
        "name",
        "interfaces",
        "directives",
        "fields",
    ),
    "field_definition": ("description", "name", "arguments", "type", "directives"),
    "input_value_definition": (
        "description",
        "name",
        "type",
        "default_value",
        "directives",
    ),
    "interface_type_definition": (
        "description",
        "name",
        "interfaces",
        "directives",
        "fields",
    ),
    "union_type_definition": ("description", "name", "directives", "types"),
    "enum_type_definition": ("description", "name", "directives", "values"),
    "enum_value_definition": ("description", "name", "directives"),
    "input_object_type_definition": ("description", "name", "directives", "fields"),
    "directive_definition": ("description", "name", "arguments", "locations"),
    "schema_extension": ("directives", "operation_types"),
    "scalar_type_extension": ("name", "directives"),
    "object_type_extension": ("name", "interfaces", "directives", "fields"),
    "interface_type_extension": ("name", "interfaces", "directives", "fields"),
    "union_type_extension": ("name", "directives", "types"),
    "enum_type_extension": ("name", "directives", "values"),
    "input_object_type_extension": ("name", "directives", "fields"),
}
# Base AST Node


class Node:
    """AST nodes"""

    # allow custom attributes and weak references (not used internally)
    __slots__ = "__dict__", "__weakref__", "loc", "_hash"

    loc: Optional[Location]  # source location, None for synthetic nodes

    kind: str = "ast"  # the kind of the node as a snake_case string
    keys: Tuple[str, ...] = ("loc",)  # the names of the attributes of this node

    def __init__(self, **kwargs: Any) -> None:
        """Initialize the node with the given keyword arguments."""
        for key in self.keys:
            value = kwargs.get(key)
            # normalize lists to tuples so that nodes remain hashable
            if isinstance(value, list):
                value = tuple(value)
            setattr(self, key, value)

    def __repr__(self) -> str:
        """Get a simple representation of the node."""
        name, loc = self.__class__.__name__, getattr(self, "loc", None)
        return f"{name} at {loc}" if loc else name

    def __eq__(self, other: Any) -> bool:
        """Test whether two nodes are equal (recursively)."""
        return (
            isinstance(other, Node)
            and self.__class__ == other.__class__
            and all(getattr(self, key) == getattr(other, key) for key in self.keys)
        )

    def __hash__(self) -> int:
        """Get a cached hash value for the node."""
        # Caching the hash values improves the performance of AST validators
        hashed = getattr(self, "_hash", None)
        if hashed is None:
            # temporarily store the id as the hash so that a cyclic reference
            # back to this node does not recurse forever while hashing below
            self._hash = id(self)  # avoid recursion
            hashed = hash(tuple(getattr(self, key) for key in self.keys))
            self._hash = hashed
        return hashed

    def __setattr__(self, key: str, value: Any) -> None:
        # reset cached hash value if attributes are changed
        if hasattr(self, "_hash") and key in self.keys:
            del self._hash
        super().__setattr__(key, value)

    def __copy__(self) -> "Node":
        """Create a shallow copy of the node."""
        return self.__class__(**{key: getattr(self, key) for key in self.keys})

    def __deepcopy__(self, memo: Dict) -> "Node":
        """Create a deep copy of the node"""
        # noinspection PyArgumentList
        return self.__class__(
            **{key: deepcopy(getattr(self, key), memo) for key in self.keys}
        )

    def __init_subclass__(cls) -> None:
        # Derive ``kind`` from the class name ("ConstFooNode" -> "foo") and
        # accumulate ``keys`` from the base classes plus this class's slots.
        super().__init_subclass__()
        name = cls.__name__
        try:
            name = name.removeprefix("Const").removesuffix("Node")
        except AttributeError:  # pragma: no cover (Python < 3.9)
            if name.startswith("Const"):
                name = name[5:]
            if name.endswith("Node"):
                name = name[:-4]
        cls.kind = camel_to_snake(name)
        keys: List[str] = []
        for base in cls.__bases__:
            # noinspection PyUnresolvedReferences
            keys.extend(base.keys)  # type: ignore
        # NOTE(review): a subclass that defines no __slots__ of its own (e.g.
        # ConstArgumentNode) resolves ``cls.__slots__`` to its base's slots
        # here, so those names appear twice in ``keys``; this looks harmless
        # for __eq__/__copy__ but worth confirming it is intended.
        keys.extend(cls.__slots__)
        cls.keys = tuple(keys)

    def to_dict(self, locations: bool = False) -> Dict:
        # local import, presumably to avoid a circular dependency at load time
        from ..utilities import ast_to_dict

        return ast_to_dict(self, locations)
# Name


class NameNode(Node):
    # a case-sensitive identifier (field, argument, type, directive name, ...)
    __slots__ = ("value",)

    value: str


# Document


class DocumentNode(Node):
    # root of a parsed GraphQL document
    __slots__ = ("definitions",)

    definitions: Tuple["DefinitionNode", ...]


class DefinitionNode(Node):
    # abstract base for all top-level definitions
    __slots__ = ()


class ExecutableDefinitionNode(DefinitionNode):
    # abstract base for operation and fragment definitions
    __slots__ = "name", "directives", "variable_definitions", "selection_set"

    name: Optional[NameNode]
    directives: Tuple["DirectiveNode", ...]
    variable_definitions: Tuple["VariableDefinitionNode", ...]
    selection_set: "SelectionSetNode"


class OperationDefinitionNode(ExecutableDefinitionNode):
    __slots__ = ("operation",)

    operation: OperationType


class VariableDefinitionNode(Node):
    __slots__ = "variable", "type", "default_value", "directives"

    variable: "VariableNode"
    type: "TypeNode"
    default_value: Optional["ConstValueNode"]
    directives: Tuple["ConstDirectiveNode", ...]


class SelectionSetNode(Node):
    __slots__ = ("selections",)

    selections: Tuple["SelectionNode", ...]


class SelectionNode(Node):
    # abstract base for field, fragment spread and inline fragment
    __slots__ = ("directives",)

    directives: Tuple["DirectiveNode", ...]


class FieldNode(SelectionNode):
    __slots__ = "alias", "name", "arguments", "selection_set"

    alias: Optional[NameNode]
    name: NameNode
    arguments: Tuple["ArgumentNode", ...]
    selection_set: Optional[SelectionSetNode]


class ArgumentNode(Node):
    __slots__ = "name", "value"

    name: NameNode
    value: "ValueNode"


class ConstArgumentNode(ArgumentNode):
    # narrows the ``value`` annotation only; adds no new slots
    value: "ConstValueNode"


# Fragments


class FragmentSpreadNode(SelectionNode):
    __slots__ = ("name",)

    name: NameNode


class InlineFragmentNode(SelectionNode):
    __slots__ = "type_condition", "selection_set"

    type_condition: "NamedTypeNode"
    selection_set: SelectionSetNode


class FragmentDefinitionNode(ExecutableDefinitionNode):
    __slots__ = ("type_condition",)

    # re-annotated: a fragment always has a name, unlike anonymous operations
    name: NameNode
    type_condition: "NamedTypeNode"
# Values


class ValueNode(Node):
    # abstract base for all value literals
    __slots__ = ()


class VariableNode(ValueNode):
    __slots__ = ("name",)

    name: NameNode


class IntValueNode(ValueNode):
    # value kept as the raw source string, not converted to int
    __slots__ = ("value",)

    value: str


class FloatValueNode(ValueNode):
    # value kept as the raw source string, not converted to float
    __slots__ = ("value",)

    value: str


class StringValueNode(ValueNode):
    __slots__ = "value", "block"

    value: str
    block: Optional[bool]  # True if written as a triple-quoted block string


class BooleanValueNode(ValueNode):
    __slots__ = ("value",)

    value: bool


class NullValueNode(ValueNode):
    __slots__ = ()


class EnumValueNode(ValueNode):
    __slots__ = ("value",)

    value: str


class ListValueNode(ValueNode):
    __slots__ = ("values",)

    values: Tuple[ValueNode, ...]


class ConstListValueNode(ListValueNode):
    # narrows the ``values`` annotation only; adds no new slots
    values: Tuple["ConstValueNode", ...]


class ObjectValueNode(ValueNode):
    __slots__ = ("fields",)

    fields: Tuple["ObjectFieldNode", ...]


class ConstObjectValueNode(ObjectValueNode):
    # narrows the ``fields`` annotation only; adds no new slots
    fields: Tuple["ConstObjectFieldNode", ...]


class ObjectFieldNode(Node):
    __slots__ = "name", "value"

    name: NameNode
    value: ValueNode


class ConstObjectFieldNode(ObjectFieldNode):
    # narrows the ``value`` annotation only; adds no new slots
    value: "ConstValueNode"


# Type alias for the constant (variable-free) subset of value nodes.
ConstValueNode = Union[
    IntValueNode,
    FloatValueNode,
    StringValueNode,
    BooleanValueNode,
    NullValueNode,
    EnumValueNode,
    ConstListValueNode,
    ConstObjectValueNode,
]
# Directives


class DirectiveNode(Node):
    __slots__ = "name", "arguments"

    name: NameNode
    arguments: Tuple[ArgumentNode, ...]


class ConstDirectiveNode(DirectiveNode):
    # narrows the ``arguments`` annotation only; adds no new slots
    arguments: Tuple[ConstArgumentNode, ...]


# Type Reference


class TypeNode(Node):
    # abstract base for type references
    __slots__ = ()


class NamedTypeNode(TypeNode):
    __slots__ = ("name",)

    name: NameNode


class ListTypeNode(TypeNode):
    __slots__ = ("type",)

    type: TypeNode


class NonNullTypeNode(TypeNode):
    __slots__ = ("type",)

    # non-null cannot wrap another non-null, hence the narrower union
    type: Union[NamedTypeNode, ListTypeNode]
# Type System Definition


class TypeSystemDefinitionNode(DefinitionNode):
    # abstract base for schema, type and directive definitions
    __slots__ = ()


class SchemaDefinitionNode(TypeSystemDefinitionNode):
    __slots__ = "description", "directives", "operation_types"

    description: Optional[StringValueNode]
    directives: Tuple[ConstDirectiveNode, ...]
    operation_types: Tuple["OperationTypeDefinitionNode", ...]


class OperationTypeDefinitionNode(Node):
    __slots__ = "operation", "type"

    operation: OperationType
    type: NamedTypeNode


# Type Definition


class TypeDefinitionNode(TypeSystemDefinitionNode):
    # abstract base for all named type definitions
    __slots__ = "description", "name", "directives"

    description: Optional[StringValueNode]
    name: NameNode
    directives: Tuple[DirectiveNode, ...]


class ScalarTypeDefinitionNode(TypeDefinitionNode):
    __slots__ = ()

    # narrows the inherited ``directives`` annotation to const directives
    directives: Tuple[ConstDirectiveNode, ...]


class ObjectTypeDefinitionNode(TypeDefinitionNode):
    __slots__ = "interfaces", "fields"

    interfaces: Tuple[NamedTypeNode, ...]
    directives: Tuple[ConstDirectiveNode, ...]
    fields: Tuple["FieldDefinitionNode", ...]


class FieldDefinitionNode(DefinitionNode):
    __slots__ = "description", "name", "directives", "arguments", "type"

    description: Optional[StringValueNode]
    name: NameNode
    directives: Tuple[ConstDirectiveNode, ...]
    arguments: Tuple["InputValueDefinitionNode", ...]
    type: TypeNode


class InputValueDefinitionNode(DefinitionNode):
    __slots__ = "description", "name", "directives", "type", "default_value"

    description: Optional[StringValueNode]
    name: NameNode
    directives: Tuple[ConstDirectiveNode, ...]
    type: TypeNode
    default_value: Optional[ConstValueNode]


class InterfaceTypeDefinitionNode(TypeDefinitionNode):
    __slots__ = "fields", "interfaces"

    fields: Tuple["FieldDefinitionNode", ...]
    directives: Tuple[ConstDirectiveNode, ...]
    interfaces: Tuple[NamedTypeNode, ...]


class UnionTypeDefinitionNode(TypeDefinitionNode):
    __slots__ = ("types",)

    directives: Tuple[ConstDirectiveNode, ...]
    types: Tuple[NamedTypeNode, ...]


class EnumTypeDefinitionNode(TypeDefinitionNode):
    __slots__ = ("values",)

    directives: Tuple[ConstDirectiveNode, ...]
    values: Tuple["EnumValueDefinitionNode", ...]


class EnumValueDefinitionNode(DefinitionNode):
    __slots__ = "description", "name", "directives"

    description: Optional[StringValueNode]
    name: NameNode
    directives: Tuple[ConstDirectiveNode, ...]


class InputObjectTypeDefinitionNode(TypeDefinitionNode):
    __slots__ = ("fields",)

    directives: Tuple[ConstDirectiveNode, ...]
    fields: Tuple[InputValueDefinitionNode, ...]


# Directive Definitions


class DirectiveDefinitionNode(TypeSystemDefinitionNode):
    __slots__ = "description", "name", "arguments", "repeatable", "locations"

    description: Optional[StringValueNode]
    name: NameNode
    arguments: Tuple[InputValueDefinitionNode, ...]
    repeatable: bool
    locations: Tuple[NameNode, ...]
# Type System Extensions


class SchemaExtensionNode(Node):
    __slots__ = "directives", "operation_types"

    directives: Tuple[ConstDirectiveNode, ...]
    operation_types: Tuple[OperationTypeDefinitionNode, ...]


# Type Extensions


class TypeExtensionNode(TypeSystemDefinitionNode):
    # abstract base for all named type extensions
    __slots__ = "name", "directives"

    name: NameNode
    directives: Tuple[ConstDirectiveNode, ...]


# Type alias covering both schema and type extensions.
TypeSystemExtensionNode = Union[SchemaExtensionNode, TypeExtensionNode]


class ScalarTypeExtensionNode(TypeExtensionNode):
    __slots__ = ()


class ObjectTypeExtensionNode(TypeExtensionNode):
    __slots__ = "interfaces", "fields"

    interfaces: Tuple[NamedTypeNode, ...]
    fields: Tuple[FieldDefinitionNode, ...]


class InterfaceTypeExtensionNode(TypeExtensionNode):
    __slots__ = "interfaces", "fields"

    interfaces: Tuple[NamedTypeNode, ...]
    fields: Tuple[FieldDefinitionNode, ...]


class UnionTypeExtensionNode(TypeExtensionNode):
    __slots__ = ("types",)

    types: Tuple[NamedTypeNode, ...]


class EnumTypeExtensionNode(TypeExtensionNode):
    __slots__ = ("values",)

    values: Tuple[EnumValueDefinitionNode, ...]


class InputObjectTypeExtensionNode(TypeExtensionNode):
    __slots__ = ("fields",)

    fields: Tuple[InputValueDefinitionNode, ...]
@@ -0,0 +1,155 @@
from typing import Collection, List
from sys import maxsize
# Public helpers for GraphQL block-string (triple-quoted) handling.
__all__ = [
    "dedent_block_string_lines",
    "is_printable_as_block_string",
    "print_block_string",
]
def dedent_block_string_lines(lines: Collection[str]) -> List[str]:
    """Produce the value of a block string from its parsed raw value.

    This function works similar to CoffeeScript's block string,
    Python's docstring trim or Ruby's strip_heredoc.

    It implements the GraphQL spec's BlockStringValue() static algorithm.

    Note that this is very similar to Python's inspect.cleandoc() function.
    The difference is that the latter also expands tabs to spaces and
    removes whitespace at the beginning of the first line. Python also has
    textwrap.dedent() which uses a completely different algorithm.

    For internal use only.
    """
    min_indent = maxsize
    first_content: Optional[int] = None
    last_content = -1
    for index, raw_line in enumerate(lines):
        indent = leading_white_space(raw_line)
        if indent == len(raw_line):
            continue  # lines that are all whitespace are ignored entirely
        if first_content is None:
            first_content = index
        last_content = index
        # The opening line (index 0) never contributes to the common indent.
        if index and indent < min_indent:
            min_indent = indent
    if first_content is None:
        first_content = 0
    # Strip the common indentation from every line except the first one.
    dedented = [
        raw_line[min_indent:] if index else raw_line
        for index, raw_line in enumerate(lines)
    ]
    # Drop the leading and trailing all-blank lines.
    return dedented[first_content : last_content + 1]


def leading_white_space(s: str) -> int:
    """Count the number of leading space or tab characters in ``s``."""
    count = 0
    while count < len(s) and s[count] in " \t":
        count += 1
    return count
def is_printable_as_block_string(value: str) -> bool:
    """Check whether the given string is printable as a block string.

    A string cannot round-trip through the block-string form if it has a
    leading or trailing blank line, contains low control characters, or
    indents every content line (the indent would be stripped on re-parse).

    For internal use only.
    """
    text = value if isinstance(value, str) else str(value)  # resolve lazy proxies
    if not text:
        return True  # the empty string is printable

    at_line_start = True  # still inside the leading whitespace of a line
    line_has_indent = False  # current line starts with space or tab
    every_line_indented = True  # all content lines so far were indented
    saw_content_line = False  # a non-blank line has been completed

    for ch in text:
        if ch == "\n":
            if at_line_start and not saw_content_line:
                return False  # has leading new line
            saw_content_line = True
            at_line_start = True
            line_has_indent = False
        elif ch == " " or ch == "\t":
            if at_line_start:
                line_has_indent = True
        elif ch <= "\x0f":
            return False  # low control character is not printable
        else:
            every_line_indented = every_line_indented and line_has_indent
            at_line_start = False

    if at_line_start:
        return False  # has trailing empty lines
    # A multi-line string whose every line is indented cannot round-trip.
    return not (every_line_indented and saw_content_line)
def print_block_string(value: str, minimize: bool = False) -> str:
    """Print a block string in the indented block form.

    Prints a block string in the indented block form by adding a leading and
    trailing blank line. However, if a block string starts with whitespace and
    is a single-line, adding a leading blank line would strip that whitespace.

    For internal use only.
    """
    text = value if isinstance(value, str) else str(value)  # resolve lazy proxies
    escaped = text.replace('"""', '\\"""')

    # Expand the raw value into independent lines.
    lines = escaped.splitlines() or [""]
    single_line = len(lines) == 1

    # When every continuation line is blank or indented, a leading new line
    # can be added without changing the parsed value.
    force_leading_newline = len(lines) > 1 and all(
        not line or line[0] in " \t" for line in lines[1:]
    )

    # A trailing escaped triple quote merely looks confusing; it does not
    # force a trailing new line by itself.
    trailing_escaped_triple = escaped.endswith('\\"""')
    # A trailing quote or backslash, however, does force one.
    trailing_quote = text.endswith('"') and not trailing_escaped_triple
    trailing_slash = text.endswith("\\")
    force_trailing_newline = trailing_quote or trailing_slash

    # Add leading and trailing new lines only if that improves readability
    # (and the caller did not ask for the minimal form).
    multiline_form = not minimize and (
        not single_line
        or len(text) > 70
        or force_trailing_newline
        or force_leading_newline
        or trailing_escaped_triple
    )

    # A single line starting with whitespace must keep that whitespace, so the
    # leading new line is skipped in that case.
    keep_leading_whitespace = single_line and bool(text) and text[0] in " \t"
    if (multiline_form and not keep_leading_whitespace) or force_leading_newline:
        head = "\n"
    else:
        head = ""
    tail = "\n" if multiline_form or force_trailing_newline else ""

    return f'"""{head}{escaped}{tail}"""'
@@ -0,0 +1,68 @@
__all__ = ["is_digit", "is_letter", "is_name_start", "is_name_continue"]

# On Python 3.7+ the fast built-in str.isascii() is available; older versions
# fall back to plain character-range comparisons with identical results.
if hasattr("", "isascii"):

    def is_digit(char: str) -> bool:
        """Check whether char is a digit

        For internal use by the lexer only.
        """
        return char.isascii() and char.isdigit()

    def is_letter(char: str) -> bool:
        """Check whether char is a plain ASCII letter

        For internal use by the lexer only.
        """
        return char.isascii() and char.isalpha()

    def is_name_start(char: str) -> bool:
        """Check whether char is allowed at the beginning of a GraphQL name

        For internal use by the lexer only.
        """
        return char.isascii() and (char.isalpha() or char == "_")

    def is_name_continue(char: str) -> bool:
        """Check whether char is allowed in the continuation of a GraphQL name

        For internal use by the lexer only.
        """
        return char.isascii() and (char.isalnum() or char == "_")

else:  # pragma: no cover (Python < 3.7)

    def is_digit(char: str) -> bool:
        """Check whether char is a digit

        For internal use by the lexer only.
        """
        return "0" <= char <= "9"

    def is_letter(char: str) -> bool:
        """Check whether char is a plain ASCII letter

        For internal use by the lexer only.
        """
        return "a" <= char <= "z" or "A" <= char <= "Z"

    def is_name_start(char: str) -> bool:
        """Check whether char is allowed at the beginning of a GraphQL name

        For internal use by the lexer only.
        """
        return "a" <= char <= "z" or "A" <= char <= "Z" or char == "_"

    def is_name_continue(char: str) -> bool:
        """Check whether char is allowed in the continuation of a GraphQL name

        For internal use by the lexer only.
        """
        return (
            "a" <= char <= "z"
            or "A" <= char <= "Z"
            or "0" <= char <= "9"
            or char == "_"
        )
@@ -0,0 +1,30 @@
from enum import Enum
__all__ = ["DirectiveLocation"]
class DirectiveLocation(Enum):
    """The enum type representing the directive location values."""

    # Note: the enum values are human-readable, space-separated names.

    # Request Definitions
    QUERY = "query"
    MUTATION = "mutation"
    SUBSCRIPTION = "subscription"
    FIELD = "field"
    FRAGMENT_DEFINITION = "fragment definition"
    FRAGMENT_SPREAD = "fragment spread"
    VARIABLE_DEFINITION = "variable definition"
    INLINE_FRAGMENT = "inline fragment"

    # Type System Definitions
    SCHEMA = "schema"
    SCALAR = "scalar"
    OBJECT = "object"
    FIELD_DEFINITION = "field definition"
    ARGUMENT_DEFINITION = "argument definition"
    INTERFACE = "interface"
    UNION = "union"
    ENUM = "enum"
    ENUM_VALUE = "enum value"
    INPUT_OBJECT = "input object"
    INPUT_FIELD_DEFINITION = "input field definition"
@@ -0,0 +1,574 @@
from typing import List, NamedTuple, Optional
from ..error import GraphQLSyntaxError
from .ast import Token
from .block_string import dedent_block_string_lines
from .character_classes import is_digit, is_name_start, is_name_continue
from .source import Source
from .token_kind import TokenKind
__all__ = ["Lexer", "is_punctuator_token_kind"]
class EscapeSequence(NamedTuple):
    """The string value and lexed size of an escape sequence."""

    value: str  # the decoded character(s) of the escape sequence
    size: int  # number of source characters consumed by the sequence
class Lexer:
"""GraphQL Lexer
A Lexer is a stateful stream generator in that every time it is advanced, it returns
the next token in the Source. Assuming the source lexes, the final Token emitted by
the lexer will be of kind EOF, after which the lexer will repeatedly return the same
EOF token whenever called.
"""
def __init__(self, source: Source):
    """Given a Source object, initialize a Lexer for that source."""
    self.source = source
    # Both the current and previous token start out as the <SOF> sentinel.
    self.token = self.last_token = Token(TokenKind.SOF, 0, 0, 0, 0)
    # 1-indexed current line and the offset where that line starts.
    self.line, self.line_start = 1, 0
def advance(self) -> Token:
    """Advance the token stream to the next non-ignored token."""
    self.last_token = self.token
    # lookahead() both finds and links the next token; here we also commit it.
    token = self.token = self.lookahead()
    return token
def lookahead(self) -> Token:
    """Look ahead and return the next non-ignored token, but do not change state."""
    token = self.token
    # Once <EOF> is reached it is returned forever.
    if token.kind != TokenKind.EOF:
        while True:
            if token.next:
                # a later token was already lexed; reuse it
                token = token.next
            else:
                # Read the next token and form a link in the token linked-list.
                next_token = self.read_next_token(token.end)
                token.next = next_token
                next_token.prev = token
                token = next_token
            # comment tokens are kept in the linked list but skipped here
            if token.kind != TokenKind.COMMENT:
                break
    return token
def print_code_point_at(self, location: int) -> str:
    """Print the code point at the given location.

    Prints the code point (or end of file reference) at a given location in a
    source for use in error messages.

    Printable ASCII is printed quoted, while other points are printed in Unicode
    code point form (ie. U+1234).
    """
    body = self.source.body
    if location >= len(body):
        return TokenKind.EOF.value
    char = body[location]
    # Printable ASCII
    if "\x20" <= char <= "\x7E":
        return "'\"'" if char == '"' else f"'{char}'"
    # Unicode code point: if this position starts a surrogate pair, re-decode
    # the two code units through UTF-16 to obtain the single supplementary
    # code point they encode.
    point = ord(
        body[location : location + 2]
        .encode("utf-16", "surrogatepass")
        .decode("utf-16")
        if is_supplementary_code_point(body, location)
        else char
    )
    return f"U+{point:04X}"
def create_token(
    self, kind: TokenKind, start: int, end: int, value: Optional[str] = None
) -> Token:
    """Create a token with line and column location information."""
    line = self.line
    # column is 1-indexed relative to the start of the current line
    col = 1 + start - self.line_start
    return Token(kind, start, end, line, col, value)
def read_next_token(self, start: int) -> Token:
    """Get the next token from the source starting at the given position.

    This skips over whitespace until it finds the next lexable token, then lexes
    punctuators immediately or calls the appropriate helper function for more
    complicated tokens.
    """
    body = self.source.body
    body_length = len(body)
    position = start
    while position < body_length:
        char = body[position]  # SourceCharacter
        # Ignored characters: space, tab, comma and the byte order mark.
        if char in " \t,\ufeff":
            position += 1
            continue
        elif char == "\n":
            position += 1
            self.line += 1
            self.line_start = position
            continue
        elif char == "\r":
            # a \r\n pair counts as a single line break
            if body[position + 1 : position + 2] == "\n":
                position += 2
            else:
                position += 1
            self.line += 1
            self.line_start = position
            continue
        if char == "#":
            return self.read_comment(position)
        if char == '"':
            # three quotes in a row start a block string
            if body[position + 1 : position + 3] == '""':
                return self.read_block_string(position)
            return self.read_string(position)
        # single-character punctuators are resolved through a lookup table
        # (presumably defined elsewhere in this module)
        kind = _KIND_FOR_PUNCT.get(char)
        if kind:
            return self.create_token(kind, position, position + 1)
        if is_digit(char) or char == "-":
            return self.read_number(position, char)
        if is_name_start(char):
            return self.read_name(position)
        if char == ".":
            # only the three-dot spread "..." is valid; a lone dot falls
            # through to the error below
            if body[position + 1 : position + 3] == "..":
                return self.create_token(TokenKind.SPREAD, position, position + 3)
        # Anything else is an error; pick the most helpful message.
        message = (
            "Unexpected single quote character ('),"
            ' did you mean to use a double quote (")?'
            if char == "'"
            else (
                f"Unexpected character: {self.print_code_point_at(position)}."
                if is_unicode_scalar_value(char)
                or is_supplementary_code_point(body, position)
                else f"Invalid character: {self.print_code_point_at(position)}."
            )
        )
        raise GraphQLSyntaxError(self.source, position, message)
    # end of input reached: emit a zero-width <EOF> token
    return self.create_token(TokenKind.EOF, body_length, body_length)
def read_comment(self, start: int) -> Token:
    """Read a comment token from the source file."""
    body = self.source.body
    size = len(body)
    position = start + 1  # skip the leading '#'
    while position < size:
        char = body[position]
        if char == "\r" or char == "\n":
            # A line terminator ends the comment.
            break
        if is_unicode_scalar_value(char):
            position += 1
        elif is_supplementary_code_point(body, position):
            # A surrogate pair occupies two code units.
            position += 2
        else:
            break  # pragma: no cover
    # The token value excludes the leading '#'.
    return self.create_token(
        TokenKind.COMMENT, start, position, body[start + 1 : position]
    )
def read_number(self, start: int, first_char: str) -> Token:
    """Reads a number token from the source file.

    This can be either a FloatValue or an IntValue,
    depending on whether a FractionalPart or ExponentPart is encountered.
    """
    body = self.source.body
    position = start
    char = first_char
    is_float = False
    # Optional negative sign.
    if char == "-":
        position += 1
        char = body[position : position + 1]
    # Integer part: a lone zero must not be followed by another digit.
    if char == "0":
        position += 1
        char = body[position : position + 1]
        if is_digit(char):
            raise GraphQLSyntaxError(
                self.source,
                position,
                "Invalid number, unexpected digit after 0:"
                f" {self.print_code_point_at(position)}.",
            )
    else:
        position = self.read_digits(position, char)
        char = body[position : position + 1]
    # Fractional part.
    if char == ".":
        is_float = True
        position += 1
        char = body[position : position + 1]
        position = self.read_digits(position, char)
        char = body[position : position + 1]
    # Exponent part with an optional sign.
    if char and char in "Ee":
        is_float = True
        position += 1
        char = body[position : position + 1]
        if char and char in "+-":
            position += 1
            char = body[position : position + 1]
        position = self.read_digits(position, char)
        char = body[position : position + 1]
    # Numbers cannot be followed by . or NameStart
    if char and (char == "." or is_name_start(char)):
        raise GraphQLSyntaxError(
            self.source,
            position,
            "Invalid number, expected digit but got:"
            f" {self.print_code_point_at(position)}.",
        )
    return self.create_token(
        TokenKind.FLOAT if is_float else TokenKind.INT,
        start,
        position,
        body[start:position],
    )
def read_digits(self, start: int, first_char: str) -> int:
    """Return the new position in the source after reading one or more digits."""
    if not is_digit(first_char):
        # At least one digit is required.
        raise GraphQLSyntaxError(
            self.source,
            start,
            "Invalid number, expected digit but got:"
            f" {self.print_code_point_at(start)}.",
        )
    body = self.source.body
    size = len(body)
    position = start + 1  # first_char has already been validated
    while position < size and is_digit(body[position]):
        position += 1
    return position
def read_string(self, start: int) -> Token:
    """Read a single-quote string token from the source file."""
    body = self.source.body
    body_length = len(body)
    position = start + 1
    chunk_start = position
    value: List[str] = []
    # Bind the append method once so the scanning loop avoids attribute lookups.
    append = value.append
    while position < body_length:
        char = body[position]
        if char == '"':
            # Closing quote: flush the final raw chunk and emit the token.
            append(body[chunk_start:position])
            return self.create_token(
                TokenKind.STRING,
                start,
                position + 1,
                "".join(value),
            )
        if char == "\\":
            # Escape sequence: flush the raw chunk read so far, then decode
            # either \u{...}, \uXXXX or a single-character escape.
            append(body[chunk_start:position])
            escape = (
                (
                    self.read_escaped_unicode_variable_width(position)
                    if body[position + 2 : position + 3] == "{"
                    else self.read_escaped_unicode_fixed_width(position)
                )
                if body[position + 1 : position + 2] == "u"
                else self.read_escaped_character(position)
            )
            append(escape.value)
            position += escape.size
            chunk_start = position
            continue
        if char in "\r\n":
            # Line terminators are not allowed inside single-quote strings.
            break
        if is_unicode_scalar_value(char):
            position += 1
        elif is_supplementary_code_point(body, position):
            # A surrogate pair occupies two code units.
            position += 2
        else:
            raise GraphQLSyntaxError(
                self.source,
                position,
                "Invalid character within String:"
                f" {self.print_code_point_at(position)}.",
            )
    raise GraphQLSyntaxError(self.source, position, "Unterminated string.")
def read_escaped_unicode_variable_width(self, position: int) -> EscapeSequence:
    """Read a variable-width ``\\u{...}`` Unicode escape sequence.

    ``position`` points at the backslash; the returned size spans the whole
    escape from the backslash through the closing brace.
    """
    body = self.source.body
    point = 0
    size = 3  # already accounts for the leading '\u{'
    max_size = min(12, len(body) - position)
    # Cannot be larger than 12 chars (\u{00000000}).
    while size < max_size:
        char = body[position + size]
        size += 1
        if char == "}":
            # Must be at least 5 chars (\u{0}) and encode a Unicode scalar value.
            if size < 5 or not (
                0 <= point <= 0xD7FF or 0xE000 <= point <= 0x10FFFF
            ):
                break
            return EscapeSequence(chr(point), size)
        # Append this hex digit to the code point.
        point = (point << 4) | read_hex_digit(char)
        if point < 0:
            # read_hex_digit() returned -1: char was not a hexadecimal digit.
            break
    raise GraphQLSyntaxError(
        self.source,
        position,
        f"Invalid Unicode escape sequence: '{body[position: position + size]}'.",
    )
def read_escaped_unicode_fixed_width(self, position: int) -> EscapeSequence:
    """Read a fixed-width ``\\uXXXX`` Unicode escape sequence.

    ``position`` points at the backslash. A valid surrogate pair consumes two
    consecutive ``\\uXXXX`` escapes (12 characters in total).
    """
    body = self.source.body
    code = read_16_bit_hex_code(body, position + 2)
    if 0 <= code <= 0xD7FF or 0xE000 <= code <= 0x10FFFF:
        # A single escape encoding a Unicode scalar value.
        return EscapeSequence(chr(code), 6)
    # GraphQL allows JSON-style surrogate pair escape sequences, but only when
    # a valid pair is formed.
    if 0xD800 <= code <= 0xDBFF:
        if body[position + 6 : position + 8] == "\\u":
            trailing_code = read_16_bit_hex_code(body, position + 8)
            if 0xDC00 <= trailing_code <= 0xDFFF:
                # Recombine the lead/trail surrogates into one code point.
                return EscapeSequence(
                    (chr(code) + chr(trailing_code))
                    .encode("utf-16", "surrogatepass")
                    .decode("utf-16"),
                    12,
                )
    raise GraphQLSyntaxError(
        self.source,
        position,
        f"Invalid Unicode escape sequence: '{body[position: position + 6]}'.",
    )
def read_escaped_character(self, position: int) -> EscapeSequence:
    """Read a single-character escape sequence such as \\n or \\t."""
    body = self.source.body
    # Look up the character following the backslash.
    replacement = _ESCAPED_CHARS.get(body[position + 1])
    if replacement is None:
        raise GraphQLSyntaxError(
            self.source,
            position,
            f"Invalid character escape sequence: '{body[position: position + 2]}'.",
        )
    return EscapeSequence(replacement, 2)
def read_block_string(self, start: int) -> Token:
    """Read a block string token from the source file."""
    body = self.source.body
    body_length = len(body)
    line_start = self.line_start
    position = start + 3  # skip the opening '"""'
    chunk_start = position
    current_line = ""
    block_lines = []
    while position < body_length:
        char = body[position]
        if char == '"' and body[position + 1 : position + 3] == '""':
            # Closing '"""': flush the final line and emit the token.
            current_line += body[chunk_start:position]
            block_lines.append(current_line)
            token = self.create_token(
                TokenKind.BLOCK_STRING,
                start,
                position + 3,
                # return a string of the lines joined with new lines
                "\n".join(dedent_block_string_lines(block_lines)),
            )
            # Update the lexer state for the lines consumed by this token.
            self.line += len(block_lines) - 1
            self.line_start = line_start
            return token
        if char == "\\" and body[position + 1 : position + 4] == '"""':
            # An escaped triple quote keeps the quotes but drops the backslash.
            current_line += body[chunk_start:position]
            chunk_start = position + 1  # skip only slash
            position += 4
            continue
        if char in "\r\n":
            # Line terminator ends the current line; CRLF counts as one.
            current_line += body[chunk_start:position]
            block_lines.append(current_line)
            if char == "\r" and body[position + 1 : position + 2] == "\n":
                position += 2
            else:
                position += 1
            current_line = ""
            chunk_start = line_start = position
            continue
        if is_unicode_scalar_value(char):
            position += 1
        elif is_supplementary_code_point(body, position):
            # A surrogate pair occupies two code units.
            position += 2
        else:
            raise GraphQLSyntaxError(
                self.source,
                position,
                "Invalid character within String:"
                f" {self.print_code_point_at(position)}.",
            )
    raise GraphQLSyntaxError(self.source, position, "Unterminated string.")
def read_name(self, start: int) -> Token:
    """Read an alphanumeric + underscore name from the source."""
    body = self.source.body
    size = len(body)
    # The first character was already validated by the caller.
    end = start + 1
    while end < size and is_name_continue(body[end]):
        end += 1
    return self.create_token(TokenKind.NAME, start, end, body[start:end])
# Token kinds that represent punctuation in the GraphQL grammar
# (queried via is_punctuator_token_kind below).
_punctuator_token_kinds = frozenset(
    [
        TokenKind.BANG,
        TokenKind.DOLLAR,
        TokenKind.AMP,
        TokenKind.PAREN_L,
        TokenKind.PAREN_R,
        TokenKind.SPREAD,
        TokenKind.COLON,
        TokenKind.EQUALS,
        TokenKind.AT,
        TokenKind.BRACKET_L,
        TokenKind.BRACKET_R,
        TokenKind.BRACE_L,
        TokenKind.PIPE,
        TokenKind.BRACE_R,
    ]
)
def is_punctuator_token_kind(kind: TokenKind) -> bool:
    """Check whether the given token kind corresponds to a punctuator.

    Tests membership in the ``_punctuator_token_kinds`` frozenset above.

    For internal use only.
    """
    return kind in _punctuator_token_kinds
# Maps each single-character punctuator to its token kind; the three-character
# spread token '...' is handled separately in read_next_token.
_KIND_FOR_PUNCT = {
    "!": TokenKind.BANG,
    "$": TokenKind.DOLLAR,
    "&": TokenKind.AMP,
    "(": TokenKind.PAREN_L,
    ")": TokenKind.PAREN_R,
    ":": TokenKind.COLON,
    "=": TokenKind.EQUALS,
    "@": TokenKind.AT,
    "[": TokenKind.BRACKET_L,
    "]": TokenKind.BRACKET_R,
    "{": TokenKind.BRACE_L,
    "}": TokenKind.BRACE_R,
    "|": TokenKind.PIPE,
}
# Maps the character following a backslash in a single-character escape
# sequence to its replacement text (used by read_escaped_character).
_ESCAPED_CHARS = {
    '"': '"',
    "/": "/",
    "\\": "\\",
    "b": "\b",
    "f": "\f",
    "n": "\n",
    "r": "\r",
    "t": "\t",
}
def read_16_bit_hex_code(body: str, position: int) -> int:
    """Read a 16bit hexadecimal string and return its positive integer value (0-65535).

    Reads four hexadecimal characters and returns the positive integer that 16bit
    hexadecimal string represents. For example, "000f" will return 15, and "dead"
    will return 57005.

    Returns a negative number if any char was not a valid hexadecimal digit.
    """
    code = 0
    for offset in range(4):
        # read_hex_digit() returns -1 on error. ORing a negative value with any
        # other value always produces a negative value, so a single bad digit
        # poisons the whole result.
        code = (code << 4) | read_hex_digit(body[position + offset])
    return code
def read_hex_digit(char: str) -> int:
    """Read a hexadecimal character and returns its positive integer value (0-15).

    '0' becomes 0, '9' becomes 9
    'A' becomes 10, 'F' becomes 15
    'a' becomes 10, 'f' becomes 15

    Returns -1 if the provided character code was not a valid hexadecimal digit.
    """
    code = ord(char)
    if 48 <= code <= 57:  # '0'-'9'
        return code - 48
    if 65 <= code <= 70:  # 'A'-'F'
        return code - 55
    if 97 <= code <= 102:  # 'a'-'f'
        return code - 87
    return -1
def is_unicode_scalar_value(char: str) -> bool:
    """Check whether this is a Unicode scalar value.

    A Unicode scalar value is any Unicode code point except surrogate code
    points. In other words, the inclusive ranges of values 0x0000 to 0xD7FF and
    0xE000 to 0x10FFFF.
    """
    if "\x00" <= char <= "\ud7ff":
        return True
    return "\ue000" <= char <= "\U0010ffff"
def is_supplementary_code_point(body: str, location: int) -> bool:
    """
    Check whether the current location is a supplementary code point.

    The GraphQL specification defines source text as a sequence of unicode scalar
    values (which Unicode defines to exclude surrogate code points).
    """
    try:
        lead, trail = body[location], body[location + 1]
    except IndexError:
        # Not enough characters left for a surrogate pair.
        return False
    return "\ud800" <= lead <= "\udbff" and "\udc00" <= trail <= "\udfff"
@@ -0,0 +1,46 @@
from typing import Any, NamedTuple, TYPE_CHECKING
try:
from typing import TypedDict
except ImportError: # Python < 3.8
from typing_extensions import TypedDict
if TYPE_CHECKING:
from .source import Source # noqa: F401
__all__ = ["get_location", "SourceLocation", "FormattedSourceLocation"]
class FormattedSourceLocation(TypedDict):
    """Formatted source location"""

    # 1-indexed line number within the source document.
    line: int
    # 1-indexed column number within the line.
    column: int
class SourceLocation(NamedTuple):
    """Represents a location in a Source."""

    line: int  # 1-indexed line number
    column: int  # 1-indexed column number

    @property
    def formatted(self) -> FormattedSourceLocation:
        """Return the location as a dict with 'line' and 'column' keys."""
        return {"line": self.line, "column": self.column}

    def __eq__(self, other: Any) -> bool:
        """Compare with another location, a (line, column) tuple, or a dict."""
        if isinstance(other, dict):
            return self.formatted == other
        return tuple(self) == other

    def __ne__(self, other: Any) -> bool:
        return not self == other
def get_location(source: "Source", position: int) -> SourceLocation:
    """Get the line and column for a character position in the source.

    Takes a Source and a UTF-8 character offset, and returns the corresponding line and
    column as a SourceLocation.
    """
    # Delegates to the Source object, which owns the document body.
    return source.get_location(position)
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,86 @@
from .ast import (
Node,
DefinitionNode,
ExecutableDefinitionNode,
ListValueNode,
ObjectValueNode,
SchemaExtensionNode,
SelectionNode,
TypeDefinitionNode,
TypeExtensionNode,
TypeNode,
TypeSystemDefinitionNode,
ValueNode,
VariableNode,
)
__all__ = [
"is_definition_node",
"is_executable_definition_node",
"is_selection_node",
"is_value_node",
"is_const_value_node",
"is_type_node",
"is_type_system_definition_node",
"is_type_definition_node",
"is_type_system_extension_node",
"is_type_extension_node",
]
def is_definition_node(node: Node) -> bool:
    """Check whether the given node represents a definition.

    True for any subclass of ``DefinitionNode``.
    """
    return isinstance(node, DefinitionNode)
def is_executable_definition_node(node: Node) -> bool:
    """Check whether the given node represents an executable definition.

    True for any subclass of ``ExecutableDefinitionNode``.
    """
    return isinstance(node, ExecutableDefinitionNode)
def is_selection_node(node: Node) -> bool:
    """Check whether the given node represents a selection.

    True for any subclass of ``SelectionNode``.
    """
    return isinstance(node, SelectionNode)
def is_value_node(node: Node) -> bool:
    """Check whether the given node represents a value.

    True for any subclass of ``ValueNode``.
    """
    return isinstance(node, ValueNode)
def is_const_value_node(node: Node) -> bool:
    """Check whether the given node represents a constant value.

    A value node is constant if it contains no variables anywhere within it:
    a list or object value is constant only when *all* of its entries are
    constant, and a plain variable is never constant.
    """
    if not is_value_node(node):
        return False
    if isinstance(node, ListValueNode):
        # Fix: require all entries to be constant. The previous any() check
        # misclassified the empty list [] as non-constant and mixed lists
        # such as [1, $var] as constant.
        return all(is_const_value_node(value) for value in node.values)
    if isinstance(node, ObjectValueNode):
        return all(is_const_value_node(field.value) for field in node.fields)
    return not isinstance(node, VariableNode)
def is_type_node(node: Node) -> bool:
    """Check whether the given node represents a type.

    True for any subclass of ``TypeNode``.
    """
    return isinstance(node, TypeNode)
def is_type_system_definition_node(node: Node) -> bool:
    """Check whether the given node represents a type system definition.

    True for any subclass of ``TypeSystemDefinitionNode``.
    """
    return isinstance(node, TypeSystemDefinitionNode)
def is_type_definition_node(node: Node) -> bool:
    """Check whether the given node represents a type definition.

    True for any subclass of ``TypeDefinitionNode``.
    """
    return isinstance(node, TypeDefinitionNode)
def is_type_system_extension_node(node: Node) -> bool:
    """Check whether the given node represents a type system extension.

    Covers both schema extensions and type extensions.
    """
    return isinstance(node, (SchemaExtensionNode, TypeExtensionNode))
def is_type_extension_node(node: Node) -> bool:
    """Check whether the given node represents a type extension.

    True for any subclass of ``TypeExtensionNode``.
    """
    return isinstance(node, TypeExtensionNode)
@@ -0,0 +1,79 @@
import re
from typing import Optional, Tuple, cast
from .ast import Location
from .location import SourceLocation, get_location
from .source import Source
__all__ = ["print_location", "print_source_location"]
def print_location(location: Location) -> str:
    """Render a helpful description of the location in the GraphQL Source document."""
    # Convert the location's character offset into a line/column pair and
    # delegate the rendering to print_source_location.
    return print_source_location(
        location.source, get_location(location.source, location.start)
    )
# Splits on the GraphQL line terminators; CRLF is listed first so it is
# consumed as a single terminator rather than as CR followed by LF.
_re_newline = re.compile(r"\r\n|[\n\r]")
def print_source_location(source: Source, source_location: SourceLocation) -> str:
    """Render a helpful description of the location in the GraphQL Source document."""
    # Pad the body so reported columns line up with the embedding document
    # (location_offset describes where the GraphQL text starts in a host file).
    first_line_column_offset = source.location_offset.column - 1
    body = "".rjust(first_line_column_offset) + source.body
    line_index = source_location.line - 1
    line_offset = source.location_offset.line - 1
    line_num = source_location.line + line_offset
    # Only the first line is shifted by the column offset.
    column_offset = first_line_column_offset if source_location.line == 1 else 0
    column_num = source_location.column + column_offset
    location_str = f"{source.name}:{line_num}:{column_num}\n"
    lines = _re_newline.split(body)  # works a bit different from splitlines()
    location_line = lines[line_index]
    # Special case for minified documents
    if len(location_line) > 120:
        # Break the long line into 80-character sub-lines and place the caret
        # under the sub-line that contains the reported column.
        sub_line_index, sub_line_column_num = divmod(column_num, 80)
        sub_lines = [
            location_line[i : i + 80] for i in range(0, len(location_line), 80)
        ]
        return location_str + print_prefixed_lines(
            (f"{line_num} |", sub_lines[0]),
            *[("|", sub_line) for sub_line in sub_lines[1 : sub_line_index + 1]],
            ("|", "^".rjust(sub_line_column_num)),
            (
                "|",
                (
                    sub_lines[sub_line_index + 1]
                    if sub_line_index < len(sub_lines) - 1
                    else None
                ),
            ),
        )
    # Normal case: show the offending line with one line of context on each side.
    return location_str + print_prefixed_lines(
        (f"{line_num - 1} |", lines[line_index - 1] if line_index > 0 else None),
        (f"{line_num} |", location_line),
        ("|", "^".rjust(column_num)),
        (
            f"{line_num + 1} |",
            lines[line_index + 1] if line_index < len(lines) - 1 else None,
        ),
    )
def print_prefixed_lines(*lines: Tuple[str, Optional[str]]) -> str:
    """Print lines specified like this: ("prefix", "string")

    Entries whose string is None are dropped; prefixes are right-aligned to
    the widest remaining prefix.
    """
    kept = [cast(Tuple[str, str], entry) for entry in lines if entry[1] is not None]
    width = max(len(prefix) for prefix, _ in kept)
    printed = []
    for prefix, text in kept:
        # Only add the separating space when there is text to show.
        printed.append(prefix.rjust(width) + (" " + text if text else ""))
    return "\n".join(printed)
@@ -0,0 +1,83 @@
__all__ = ["print_string"]
def print_string(s: str) -> str:
    """Print a string as a GraphQL StringValue literal.

    Replaces control characters and excluded characters (" U+0022 and \\ U+005C)
    with escape sequences.
    """
    # Coerce non-string inputs before translating.
    text = s if isinstance(s, str) else str(s)
    escaped = text.translate(escape_sequences)
    return f'"{escaped}"'
# Translation table for str.translate() mapping every code point that must be
# escaped in a GraphQL StringValue to its escape sequence: the C0 controls
# (U+0000-U+001F), the excluded characters '"' and '\', and DEL plus the C1
# controls (U+007F-U+009F).
escape_sequences = {
    # Default form: \uXXXX with uppercase hexadecimal digits.
    **{point: f"\\u{point:04X}" for point in range(0x00, 0x20)},
    **{point: f"\\u{point:04X}" for point in range(0x7F, 0xA0)},
    # Short forms defined by the GraphQL grammar override the default.
    0x08: "\\b",
    0x09: "\\t",
    0x0A: "\\n",
    0x0C: "\\f",
    0x0D: "\\r",
    0x22: '\\"',
    0x5C: "\\\\",
}
@@ -0,0 +1,428 @@
from typing import Any, Collection, Optional
from ..language.ast import Node, OperationType
from .block_string import print_block_string
from .print_string import print_string
from .visitor import visit, Visitor
__all__ = ["print_ast"]
# Threshold beyond which field argument lists are printed one per line
# (see leave_field).
MAX_LINE_LENGTH = 80

# Collection of already-printed child strings as seen by the visitor callbacks.
Strings = Collection[str]
class PrintedNode:
    """A union type for all nodes that have been processed by the printer.

    Declares, for type checking only, every attribute that any ``leave_*``
    method of the printer may read; by the time those methods run, child
    fields hold printed strings instead of AST nodes.
    """

    alias: str
    arguments: Strings
    block: bool
    default_value: str
    definitions: Strings
    description: str
    directives: str
    fields: Strings
    interfaces: Strings
    locations: Strings
    name: str
    operation: OperationType
    operation_types: Strings
    repeatable: bool
    selection_set: str
    selections: Strings
    type: str
    type_condition: str
    types: Strings
    value: str
    values: Strings
    variable: str
    variable_definitions: Strings
def print_ast(ast: Node) -> str:
    """Convert an AST into a string.

    The conversion is done using a set of reasonable formatting rules.
    The AST is traversed with a PrintAstVisitor, which serializes each node
    when it is left.
    """
    return visit(ast, PrintAstVisitor())
class PrintAstVisitor(Visitor):
    """Visitor that serializes AST nodes to GraphQL document strings.

    Every ``leave_*`` method receives a node whose child fields have already
    been replaced by their printed string representations (see PrintedNode)
    and returns the printed form of the node itself.
    """

    @staticmethod
    def leave_name(node: PrintedNode, *_args: Any) -> str:
        return node.value

    @staticmethod
    def leave_variable(node: PrintedNode, *_args: Any) -> str:
        return f"${node.name}"

    # Document

    @staticmethod
    def leave_document(node: PrintedNode, *_args: Any) -> str:
        return join(node.definitions, "\n\n")

    @staticmethod
    def leave_operation_definition(node: PrintedNode, *_args: Any) -> str:
        var_defs = wrap("(", join(node.variable_definitions, ", "), ")")
        prefix = join(
            (
                node.operation.value,
                join((node.name, var_defs)),
                join(node.directives, " "),
            ),
            " ",
        )
        # Anonymous queries with no directives or variable definitions can use the
        # query short form.
        return ("" if prefix == "query" else prefix + " ") + node.selection_set

    @staticmethod
    def leave_variable_definition(node: PrintedNode, *_args: Any) -> str:
        return (
            f"{node.variable}: {node.type}"
            f"{wrap(' = ', node.default_value)}"
            f"{wrap(' ', join(node.directives, ' '))}"
        )

    @staticmethod
    def leave_selection_set(node: PrintedNode, *_args: Any) -> str:
        return block(node.selections)

    @staticmethod
    def leave_field(node: PrintedNode, *_args: Any) -> str:
        prefix = wrap("", node.alias, ": ") + node.name
        args_line = prefix + wrap("(", join(node.arguments, ", "), ")")
        # Arguments are printed one per line when the single-line form
        # would exceed MAX_LINE_LENGTH.
        if len(args_line) > MAX_LINE_LENGTH:
            args_line = prefix + wrap("(\n", indent(join(node.arguments, "\n")), "\n)")
        return join((args_line, join(node.directives, " "), node.selection_set), " ")

    @staticmethod
    def leave_argument(node: PrintedNode, *_args: Any) -> str:
        return f"{node.name}: {node.value}"

    # Fragments

    @staticmethod
    def leave_fragment_spread(node: PrintedNode, *_args: Any) -> str:
        return f"...{node.name}{wrap(' ', join(node.directives, ' '))}"

    @staticmethod
    def leave_inline_fragment(node: PrintedNode, *_args: Any) -> str:
        return join(
            (
                "...",
                wrap("on ", node.type_condition),
                join(node.directives, " "),
                node.selection_set,
            ),
            " ",
        )

    @staticmethod
    def leave_fragment_definition(node: PrintedNode, *_args: Any) -> str:
        # Note: fragment variable definitions are deprecated and will be removed in v3.3
        return (
            f"fragment {node.name}"
            f"{wrap('(', join(node.variable_definitions, ', '), ')')}"
            f" on {node.type_condition}"
            f" {wrap('', join(node.directives, ' '), ' ')}"
            f"{node.selection_set}"
        )

    # Value

    @staticmethod
    def leave_int_value(node: PrintedNode, *_args: Any) -> str:
        return node.value

    @staticmethod
    def leave_float_value(node: PrintedNode, *_args: Any) -> str:
        return node.value

    @staticmethod
    def leave_string_value(node: PrintedNode, *_args: Any) -> str:
        # Block strings use the triple-quoted form.
        if node.block:
            return print_block_string(node.value)
        return print_string(node.value)

    @staticmethod
    def leave_boolean_value(node: PrintedNode, *_args: Any) -> str:
        return "true" if node.value else "false"

    @staticmethod
    def leave_null_value(_node: PrintedNode, *_args: Any) -> str:
        return "null"

    @staticmethod
    def leave_enum_value(node: PrintedNode, *_args: Any) -> str:
        return node.value

    @staticmethod
    def leave_list_value(node: PrintedNode, *_args: Any) -> str:
        return f"[{join(node.values, ', ')}]"

    @staticmethod
    def leave_object_value(node: PrintedNode, *_args: Any) -> str:
        return f"{{{join(node.fields, ', ')}}}"

    @staticmethod
    def leave_object_field(node: PrintedNode, *_args: Any) -> str:
        return f"{node.name}: {node.value}"

    # Directive

    @staticmethod
    def leave_directive(node: PrintedNode, *_args: Any) -> str:
        return f"@{node.name}{wrap('(', join(node.arguments, ', '), ')')}"

    # Type

    @staticmethod
    def leave_named_type(node: PrintedNode, *_args: Any) -> str:
        return node.name

    @staticmethod
    def leave_list_type(node: PrintedNode, *_args: Any) -> str:
        return f"[{node.type}]"

    @staticmethod
    def leave_non_null_type(node: PrintedNode, *_args: Any) -> str:
        return f"{node.type}!"

    # Type System Definitions

    @staticmethod
    def leave_schema_definition(node: PrintedNode, *_args: Any) -> str:
        return wrap("", node.description, "\n") + join(
            (
                "schema",
                join(node.directives, " "),
                block(node.operation_types),
            ),
            " ",
        )

    @staticmethod
    def leave_operation_type_definition(node: PrintedNode, *_args: Any) -> str:
        return f"{node.operation.value}: {node.type}"

    @staticmethod
    def leave_scalar_type_definition(node: PrintedNode, *_args: Any) -> str:
        return wrap("", node.description, "\n") + join(
            (
                "scalar",
                node.name,
                join(node.directives, " "),
            ),
            " ",
        )

    @staticmethod
    def leave_object_type_definition(node: PrintedNode, *_args: Any) -> str:
        return wrap("", node.description, "\n") + join(
            (
                "type",
                node.name,
                wrap("implements ", join(node.interfaces, " & ")),
                join(node.directives, " "),
                block(node.fields),
            ),
            " ",
        )

    @staticmethod
    def leave_field_definition(node: PrintedNode, *_args: Any) -> str:
        args = node.arguments
        # Multiline argument definitions are printed one per line.
        args = (
            wrap("(\n", indent(join(args, "\n")), "\n)")
            if has_multiline_items(args)
            else wrap("(", join(args, ", "), ")")
        )
        directives = wrap(" ", join(node.directives, " "))
        return (
            wrap("", node.description, "\n")
            + f"{node.name}{args}: {node.type}{directives}"
        )

    @staticmethod
    def leave_input_value_definition(node: PrintedNode, *_args: Any) -> str:
        return wrap("", node.description, "\n") + join(
            (
                f"{node.name}: {node.type}",
                wrap("= ", node.default_value),
                join(node.directives, " "),
            ),
            " ",
        )

    @staticmethod
    def leave_interface_type_definition(node: PrintedNode, *_args: Any) -> str:
        return wrap("", node.description, "\n") + join(
            (
                "interface",
                node.name,
                wrap("implements ", join(node.interfaces, " & ")),
                join(node.directives, " "),
                block(node.fields),
            ),
            " ",
        )

    @staticmethod
    def leave_union_type_definition(node: PrintedNode, *_args: Any) -> str:
        return wrap("", node.description, "\n") + join(
            (
                "union",
                node.name,
                join(node.directives, " "),
                wrap("= ", join(node.types, " | ")),
            ),
            " ",
        )

    @staticmethod
    def leave_enum_type_definition(node: PrintedNode, *_args: Any) -> str:
        return wrap("", node.description, "\n") + join(
            ("enum", node.name, join(node.directives, " "), block(node.values)), " "
        )

    @staticmethod
    def leave_enum_value_definition(node: PrintedNode, *_args: Any) -> str:
        return wrap("", node.description, "\n") + join(
            (node.name, join(node.directives, " ")), " "
        )

    @staticmethod
    def leave_input_object_type_definition(node: PrintedNode, *_args: Any) -> str:
        return wrap("", node.description, "\n") + join(
            ("input", node.name, join(node.directives, " "), block(node.fields)), " "
        )

    @staticmethod
    def leave_directive_definition(node: PrintedNode, *_args: Any) -> str:
        args = node.arguments
        # Multiline argument definitions are printed one per line.
        args = (
            wrap("(\n", indent(join(args, "\n")), "\n)")
            if has_multiline_items(args)
            else wrap("(", join(args, ", "), ")")
        )
        repeatable = " repeatable" if node.repeatable else ""
        locations = join(node.locations, " | ")
        return (
            wrap("", node.description, "\n")
            + f"directive @{node.name}{args}{repeatable} on {locations}"
        )

    @staticmethod
    def leave_schema_extension(node: PrintedNode, *_args: Any) -> str:
        return join(
            ("extend schema", join(node.directives, " "), block(node.operation_types)),
            " ",
        )

    @staticmethod
    def leave_scalar_type_extension(node: PrintedNode, *_args: Any) -> str:
        return join(("extend scalar", node.name, join(node.directives, " ")), " ")

    @staticmethod
    def leave_object_type_extension(node: PrintedNode, *_args: Any) -> str:
        return join(
            (
                "extend type",
                node.name,
                wrap("implements ", join(node.interfaces, " & ")),
                join(node.directives, " "),
                block(node.fields),
            ),
            " ",
        )

    @staticmethod
    def leave_interface_type_extension(node: PrintedNode, *_args: Any) -> str:
        return join(
            (
                "extend interface",
                node.name,
                wrap("implements ", join(node.interfaces, " & ")),
                join(node.directives, " "),
                block(node.fields),
            ),
            " ",
        )

    @staticmethod
    def leave_union_type_extension(node: PrintedNode, *_args: Any) -> str:
        return join(
            (
                "extend union",
                node.name,
                join(node.directives, " "),
                wrap("= ", join(node.types, " | ")),
            ),
            " ",
        )

    @staticmethod
    def leave_enum_type_extension(node: PrintedNode, *_args: Any) -> str:
        return join(
            ("extend enum", node.name, join(node.directives, " "), block(node.values)),
            " ",
        )

    @staticmethod
    def leave_input_object_type_extension(node: PrintedNode, *_args: Any) -> str:
        return join(
            ("extend input", node.name, join(node.directives, " "), block(node.fields)),
            " ",
        )
def join(strings: Optional[Collection[str]], separator: str = "") -> str:
    """Join strings in a given collection.

    Return an empty string if it is None or empty, otherwise join all items together
    separated by separator if provided. Empty items are skipped.
    """
    if not strings:
        return ""
    return separator.join(item for item in strings if item)
def block(strings: Optional[Collection[str]]) -> str:
    """Return strings inside a block.

    Given a collection of strings, return a string with each item on its own line,
    wrapped in an indented "{ }" block.
    """
    content = indent(join(strings, "\n"))
    return wrap("{\n", content, "\n}")
def wrap(start: str, string: Optional[str], end: str = "") -> str:
    """Wrap string inside other strings at start and end.

    If the string is not None or empty, then wrap with start and end, otherwise return
    an empty string.
    """
    if not string:
        return ""
    return start + string + end
def indent(string: str) -> str:
    """Indent string with two spaces.

    If the string is not None or empty, add two spaces at the beginning of every line
    inside the string.
    """
    # Indent continuation lines first, then let wrap() indent the first line
    # (wrap returns "" for an empty string).
    indented = string.replace("\n", "\n  ")
    return wrap("  ", indented)
def is_multiline(string: str) -> bool:
    """Check whether a string consists of multiple lines."""
    return string.find("\n") >= 0
def has_multiline_items(strings: Optional[Collection[str]]) -> bool:
    """Check whether one of the items in the list has multiple lines."""
    if not strings:
        return False
    return any(is_multiline(item) for item in strings)
@@ -0,0 +1,70 @@
from typing import Any
from .location import SourceLocation
__all__ = ["Source", "is_source"]
class Source:
    """A representation of source input to GraphQL."""

    # allow custom attributes and weak references (not used internally)
    __slots__ = "__weakref__", "__dict__", "body", "name", "location_offset"

    def __init__(
        self,
        body: str,
        name: str = "GraphQL request",
        location_offset: SourceLocation = SourceLocation(1, 1),
    ) -> None:
        """Initialize source input.

        The ``name`` and ``location_offset`` parameters are optional, but they are
        useful for clients who store GraphQL documents in source files. For example,
        if the GraphQL input starts at line 40 in a file named ``Foo.graphql``, it might
        be useful for ``name`` to be ``"Foo.graphql"`` and location to be ``(40, 0)``.

        The ``line`` and ``column`` attributes in ``location_offset`` are 1-indexed.

        Raises ValueError if ``location_offset`` has a non-positive line or column.
        """
        self.body = body
        self.name = name
        if not isinstance(location_offset, SourceLocation):
            # Accept any (line, column) pair and normalize it.
            location_offset = SourceLocation._make(location_offset)
        if location_offset.line <= 0:
            raise ValueError(
                "line in location_offset is 1-indexed and must be positive."
            )
        if location_offset.column <= 0:
            raise ValueError(
                "column in location_offset is 1-indexed and must be positive."
            )
        self.location_offset = location_offset

    def get_location(self, position: int) -> SourceLocation:
        """Get the 1-indexed line and column for a character position.

        Only the GraphQL line terminators LF, CR and CRLF start a new line.

        Fix: the previous ``str.splitlines()``-based implementation dropped a
        terminator ending directly at ``position``, so a position at the start
        of a line was reported at the end of the previous line; it also split
        on non-GraphQL terminators such as form feed or U+2028.
        """
        body = self.body
        line = 1
        line_start = 0
        pos = 0
        scan_end = min(position, len(body))
        while pos < scan_end:
            char = body[pos]
            if char == "\n":
                pos += 1
                line += 1
                line_start = pos
            elif char == "\r":
                # A CRLF pair counts as a single line terminator.
                pos += 2 if body[pos + 1 : pos + 2] == "\n" else 1
                line += 1
                line_start = pos
            else:
                pos += 1
        return SourceLocation(line, position - line_start + 1)

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__} name={self.name!r}>"

    def __eq__(self, other: Any) -> bool:
        # A Source compares equal to another Source with the same body,
        # and also directly to its body string.
        return (isinstance(other, Source) and other.body == self.body) or (
            isinstance(other, str) and other == self.body
        )

    # NOTE(review): defining __eq__ without __hash__ makes Source unhashable;
    # confirm this is intended before using Sources as dict keys or set members.

    def __ne__(self, other: Any) -> bool:
        return not self == other
def is_source(source: Any) -> bool:
    """Test if the given value is a Source object.

    For internal use only.
    """
    return isinstance(source, Source)
@@ -0,0 +1,30 @@
from enum import Enum
__all__ = ["TokenKind"]
class TokenKind(Enum):
    """The different kinds of tokens that the lexer emits"""

    # Sentinels marking the two ends of the token stream.
    SOF = "<SOF>"
    EOF = "<EOF>"

    # Punctuators.
    BANG = "!"
    DOLLAR = "$"
    AMP = "&"
    PAREN_L = "("
    PAREN_R = ")"
    SPREAD = "..."
    COLON = ":"
    EQUALS = "="
    AT = "@"
    BRACKET_L = "["
    BRACKET_R = "]"
    BRACE_L = "{"
    PIPE = "|"
    BRACE_R = "}"

    # Lexical tokens with a value.
    NAME = "Name"
    INT = "Int"
    FLOAT = "Float"
    STRING = "String"
    BLOCK_STRING = "BlockString"
    COMMENT = "Comment"
@@ -0,0 +1,375 @@
from copy import copy
from enum import Enum
from typing import (
Any,
Callable,
Collection,
Dict,
List,
NamedTuple,
Optional,
Tuple,
Union,
)
from ..pyutils import inspect, snake_to_camel
from . import ast
from .ast import QUERY_DOCUMENT_KEYS, Node
__all__ = [
"Visitor",
"ParallelVisitor",
"VisitorAction",
"visit",
"BREAK",
"SKIP",
"REMOVE",
"IDLE",
]
class VisitorActionEnum(Enum):
    """Special return values for the visitor methods.

    You can also use the values of this enum directly.
    """

    BREAK = True  # stop visiting altogether
    SKIP = False  # skip visiting this node
    REMOVE = Ellipsis  # delete this node
# Type for the value returned by visitor methods; None means "no action".
VisitorAction = Optional[VisitorActionEnum]

# Note that in GraphQL.js these are defined differently:
# BREAK = {}, SKIP = false, REMOVE = null, IDLE = undefined
BREAK = VisitorActionEnum.BREAK
SKIP = VisitorActionEnum.SKIP
REMOVE = VisitorActionEnum.REMOVE
IDLE = None  # no action

# Custom key maps for traversal; presumably maps a node kind to the tuple of
# child keys to visit (compare QUERY_DOCUMENT_KEYS) — confirm against visit().
VisitorKeyMap = Dict[str, Tuple[str, ...]]
class EnterLeaveVisitor(NamedTuple):
    """Visitor with functions for entering and leaving."""

    # Callback invoked when a node is entered, or None if there is none.
    enter: Optional[Callable[..., Optional[VisitorAction]]]
    # Callback invoked when a node is left, or None if there is none.
    leave: Optional[Callable[..., Optional[VisitorAction]]]
class Visitor:
    """Visitor that walks through an AST.

    Visitors can define two generic methods "enter" and "leave". The former will be
    called when a node is entered in the traversal, the latter is called after visiting
    the node and its child nodes. These methods have the following signature::

        def enter(self, node, key, parent, path, ancestors):
            # The return value has the following meaning:
            # IDLE (None): no action
            # SKIP: skip visiting this node
            # BREAK: stop visiting altogether
            # REMOVE: delete this node
            # any other value: replace this node with the returned value
            return

        def leave(self, node, key, parent, path, ancestors):
            # The return value has the following meaning:
            # IDLE (None) or SKIP: no action
            # BREAK: stop visiting altogether
            # REMOVE: delete this node
            # any other value: replace this node with the returned value
            return

    The parameters have the following meaning:

    :arg node: The current node being visiting.
    :arg key: The index or key to this node from the parent node or Array.
    :arg parent: the parent immediately above this node, which may be an Array.
    :arg path: The key path to get to this node from the root node.
    :arg ancestors: All nodes and Arrays visited before reaching parent
        of this node. These correspond to array indices in ``path``.

    Note: ancestors includes arrays which contain the parent of visited node.

    You can also define node kind specific methods by suffixing them with an underscore
    followed by the kind of the node to be visited. For instance, to visit ``field``
    nodes, you would defined the methods ``enter_field()`` and/or ``leave_field()``,
    with the same signature as above. If no kind specific method has been defined
    for a given node, the generic method is called.
    """

    # Provide special return values as attributes
    BREAK, SKIP, REMOVE, IDLE = BREAK, SKIP, REMOVE, IDLE

    # Cache mapping node kinds to their resolved enter/leave callbacks.
    enter_leave_map: Dict[str, EnterLeaveVisitor]

    def __init_subclass__(cls) -> None:
        """Verify that all defined handlers are valid."""
        super().__init_subclass__()
        for attr, val in cls.__dict__.items():
            if attr.startswith("_"):
                continue  # private attributes are never handlers
            # Split e.g. "enter_field" into ("enter", "field"); a bare
            # "enter" or "leave" handler has no kind suffix.
            attr_kind = attr.split("_", 1)
            if len(attr_kind) < 2:
                kind: Optional[str] = None
            else:
                attr, kind = attr_kind
            if attr in ("enter", "leave") and kind:
                # A kind-specific handler: the suffix must name an existing
                # AST node class, e.g. "field" -> ast.FieldNode.
                name = snake_to_camel(kind) + "Node"
                node_cls = getattr(ast, name, None)
                if (
                    not node_cls
                    or not isinstance(node_cls, type)
                    or not issubclass(node_cls, Node)
                ):
                    raise TypeError(f"Invalid AST node kind: {kind}.")

    def __init__(self) -> None:
        # Start with an empty per-instance cache of resolved handlers.
        self.enter_leave_map = {}

    def get_enter_leave_for_kind(self, kind: str) -> EnterLeaveVisitor:
        """Given a node kind, return the EnterLeaveVisitor for that kind."""
        try:
            return self.enter_leave_map[kind]
        except KeyError:
            # Prefer the kind-specific handler; fall back to the generic one.
            enter_fn = getattr(self, f"enter_{kind}", None)
            if not enter_fn:
                enter_fn = getattr(self, "enter", None)
            leave_fn = getattr(self, f"leave_{kind}", None)
            if not leave_fn:
                leave_fn = getattr(self, "leave", None)
            enter_leave = EnterLeaveVisitor(enter_fn, leave_fn)
            self.enter_leave_map[kind] = enter_leave  # cache for next lookup
            return enter_leave

    def get_visit_fn(
        self, kind: str, is_leaving: bool = False
    ) -> Optional[Callable[..., Optional[VisitorAction]]]:
        """Get the visit function for the given node kind and direction.

        .. deprecated:: 3.2
           Please use ``get_enter_leave_for_kind`` instead. Will be removed in v3.3.
        """
        enter_leave = self.get_enter_leave_for_kind(kind)
        return enter_leave.leave if is_leaving else enter_leave.enter
class Stack(NamedTuple):
    """A stack for the visit function."""

    # True if the frame traverses a tuple (array) of nodes.
    in_array: bool
    # Index of the key/child currently being visited in this frame.
    idx: int
    # The keys (or, when in_array, the child nodes) being traversed.
    keys: Tuple[Node, ...]
    # Pending (key, value) edits collected while visiting children.
    edits: List[Tuple[Union[int, str], Node]]
    # The parent stack frame, or None at the root.
    prev: Any  # 'Stack' (python/mypy/issues/731)
def visit(
    root: Node, visitor: Visitor, visitor_keys: Optional[VisitorKeyMap] = None
) -> Any:
    """Visit each node in an AST.

    :func:`~.visit` will walk through an AST using a depth-first traversal, calling the
    visitor's enter methods at each node in the traversal, and calling the leave methods
    after visiting that node and all of its child nodes.

    By returning different values from the enter and leave methods, the behavior of the
    visitor can be altered, including skipping over a sub-tree of the AST (by returning
    False), editing the AST by returning a value or None to remove the value, or to stop
    the whole traversal by returning :data:`~.BREAK`.

    When using :func:`~.visit` to edit an AST, the original AST will not be modified,
    and a new version of the AST with the changes applied will be returned from the
    visit function.

    To customize the node attributes to be used for traversal, you can provide a
    dictionary visitor_keys mapping node kinds to node attributes.

    :arg root: the AST node to start the traversal from.
    :arg visitor: the Visitor whose enter/leave methods are invoked.
    :arg visitor_keys: optional custom traversal key map.
    :returns: the (possibly edited) root node.
    :raises TypeError: if ``root`` is not a Node or ``visitor`` not a Visitor,
        or if an invalid node is encountered during traversal.
    """
    if not isinstance(root, Node):
        raise TypeError(f"Not an AST Node: {inspect(root)}.")
    if not isinstance(visitor, Visitor):
        raise TypeError(f"Not an AST Visitor: {inspect(visitor)}.")
    if visitor_keys is None:
        visitor_keys = QUERY_DOCUMENT_KEYS

    # Traversal state: an explicit stack of frames replaces recursion.
    stack: Any = None
    in_array = False  # whether the current frame iterates a tuple of nodes
    keys: Tuple[Node, ...] = (root,)
    idx = -1  # index into ``keys`` for the current frame
    edits: List[Any] = []  # (key, value) edits collected for the current node
    node: Any = root
    key: Any = None
    parent: Any = None
    path: List[Any] = []
    # Bind frequently used bound methods to locals for speed in the hot loop.
    path_append = path.append
    path_pop = path.pop
    ancestors: List[Any] = []
    ancestors_append = ancestors.append
    ancestors_pop = ancestors.pop

    while True:
        idx += 1
        is_leaving = idx == len(keys)  # all children of this frame visited?
        is_edited = is_leaving and edits
        if is_leaving:
            # Pop back to the parent frame, applying any collected edits.
            key = path[-1] if ancestors else None
            node = parent
            parent = ancestors_pop() if ancestors else None
            if is_edited:
                if in_array:
                    # Rebuild the tuple; removals shift later indices left,
                    # which ``edit_offset`` compensates for.
                    node = list(node)
                    edit_offset = 0
                    for edit_key, edit_value in edits:
                        array_key = edit_key - edit_offset
                        if edit_value is REMOVE or edit_value is Ellipsis:
                            node.pop(array_key)
                            edit_offset += 1
                        else:
                            node[array_key] = edit_value
                    node = tuple(node)
                else:
                    # Shallow-copy the node so the original AST stays intact.
                    node = copy(node)
                    for edit_key, edit_value in edits:
                        setattr(node, edit_key, edit_value)
            # Restore the parent frame's traversal state.
            idx = stack.idx
            keys = stack.keys
            edits = stack.edits
            in_array = stack.in_array
            stack = stack.prev
        elif parent:
            # Descend to the next child of the current parent.
            if in_array:
                key = idx
                node = parent[key]
            else:
                key = keys[idx]
                node = getattr(parent, key, None)
            if node is None:
                continue  # optional attribute not present; skip it
            path_append(key)
        if isinstance(node, tuple):
            result = None  # arrays themselves have no visit functions
        else:
            if not isinstance(node, Node):
                raise TypeError(f"Invalid AST Node: {inspect(node)}.")
            enter_leave = visitor.get_enter_leave_for_kind(node.kind)
            visit_fn = enter_leave.leave if is_leaving else enter_leave.enter
            if visit_fn:
                result = visit_fn(node, key, parent, path, ancestors)
                if result is BREAK or result is True:
                    break  # abort the whole traversal
                if result is SKIP or result is False:
                    if not is_leaving:
                        path_pop()
                        continue  # do not descend into this node's children
                elif result is not None:
                    # Any other value replaces this node (REMOVE deletes it).
                    edits.append((key, result))
                    if not is_leaving:
                        if isinstance(result, Node):
                            node = result  # descend into the replacement
                        else:
                            path_pop()
                            continue
            else:
                result = None
        if result is None and is_edited:
            # Propagate edits made below this node up to its parent's edits.
            edits.append((key, node))
        if is_leaving:
            if path:
                path_pop()
        else:
            # Push a new frame and descend into the current node's children.
            stack = Stack(in_array, idx, keys, edits, stack)
            in_array = isinstance(node, tuple)
            keys = node if in_array else visitor_keys.get(node.kind, ())  # type: ignore
            idx = -1
            edits = []
            if parent:
                ancestors_append(parent)
            parent = node
        if not stack:
            break

    if edits:
        # The final edit recorded for the root is the new (replaced) root.
        return edits[-1][1]
    return root
class ParallelVisitor(Visitor):
    """A Visitor which delegates to many visitors to run in parallel.

    Each visitor will be visited for each node before moving on.

    If a prior visitor edits a node, no following visitors will see that node.
    """

    def __init__(self, visitors: Collection[Visitor]):
        """Create a new visitor from the given list of parallel visitors."""
        super().__init__()
        self.visitors = visitors
        # Per-visitor skip state: None = active, a node = skipping that
        # node's subtree, BREAK = that visitor has stopped entirely.
        self.skipping: List[Any] = [None] * len(visitors)

    def get_enter_leave_for_kind(self, kind: str) -> EnterLeaveVisitor:
        """Given a node kind, return the EnterLeaveVisitor for that kind."""
        try:
            return self.enter_leave_map[kind]
        except KeyError:
            # Collect every child visitor's handlers for this kind.
            has_visitor = False
            enter_list: List[Optional[Callable[..., Optional[VisitorAction]]]] = []
            leave_list: List[Optional[Callable[..., Optional[VisitorAction]]]] = []
            for visitor in self.visitors:
                enter, leave = visitor.get_enter_leave_for_kind(kind)
                if not has_visitor and (enter or leave):
                    has_visitor = True
                enter_list.append(enter)
                leave_list.append(leave)

            if has_visitor:
                # Build combined closures that fan out to all child visitors,
                # tracking each visitor's skip state in ``self.skipping``.

                def enter(node: Node, *args: Any) -> Optional[VisitorAction]:
                    skipping = self.skipping
                    for i, fn in enumerate(enter_list):
                        if not skipping[i]:
                            if fn:
                                result = fn(node, *args)
                                if result is SKIP or result is False:
                                    # Skip this subtree for visitor i only.
                                    skipping[i] = node
                                elif result is BREAK or result is True:
                                    # Stop visitor i for the rest of the walk.
                                    skipping[i] = BREAK
                                elif result is not None:
                                    # First edit wins; later visitors will
                                    # not see the original node.
                                    return result
                    return None

                def leave(node: Node, *args: Any) -> Optional[VisitorAction]:
                    skipping = self.skipping
                    for i, fn in enumerate(leave_list):
                        if not skipping[i]:
                            if fn:
                                result = fn(node, *args)
                                if result is BREAK or result is True:
                                    skipping[i] = BREAK
                                elif (
                                    result is not None
                                    and result is not SKIP
                                    and result is not False
                                ):
                                    return result
                        elif skipping[i] is node:
                            # Leaving the skipped subtree: reactivate visitor i.
                            skipping[i] = None
                    return None

            else:
                enter = leave = None

            enter_leave = EnterLeaveVisitor(enter, leave)
            self.enter_leave_map[kind] = enter_leave  # cache for next lookup
            return enter_leave