2025-12-01
This commit is contained in:
@@ -0,0 +1,208 @@
|
||||
"""GraphQL Language
|
||||
|
||||
The :mod:`graphql.language` package is responsible for parsing and operating on the
|
||||
GraphQL language.
|
||||
"""
|
||||
|
||||
from .source import Source
|
||||
|
||||
from .location import get_location, SourceLocation, FormattedSourceLocation
|
||||
|
||||
from .print_location import print_location, print_source_location
|
||||
|
||||
from .token_kind import TokenKind
|
||||
|
||||
from .lexer import Lexer
|
||||
|
||||
from .parser import parse, parse_type, parse_value, parse_const_value
|
||||
|
||||
from .printer import print_ast
|
||||
|
||||
from .visitor import (
|
||||
visit,
|
||||
Visitor,
|
||||
ParallelVisitor,
|
||||
VisitorAction,
|
||||
VisitorKeyMap,
|
||||
BREAK,
|
||||
SKIP,
|
||||
REMOVE,
|
||||
IDLE,
|
||||
)
|
||||
|
||||
from .ast import (
|
||||
Location,
|
||||
Token,
|
||||
Node,
|
||||
# Each kind of AST node
|
||||
NameNode,
|
||||
DocumentNode,
|
||||
DefinitionNode,
|
||||
ExecutableDefinitionNode,
|
||||
OperationDefinitionNode,
|
||||
OperationType,
|
||||
VariableDefinitionNode,
|
||||
VariableNode,
|
||||
SelectionSetNode,
|
||||
SelectionNode,
|
||||
FieldNode,
|
||||
ArgumentNode,
|
||||
ConstArgumentNode,
|
||||
FragmentSpreadNode,
|
||||
InlineFragmentNode,
|
||||
FragmentDefinitionNode,
|
||||
ValueNode,
|
||||
ConstValueNode,
|
||||
IntValueNode,
|
||||
FloatValueNode,
|
||||
StringValueNode,
|
||||
BooleanValueNode,
|
||||
NullValueNode,
|
||||
EnumValueNode,
|
||||
ListValueNode,
|
||||
ConstListValueNode,
|
||||
ObjectValueNode,
|
||||
ConstObjectValueNode,
|
||||
ObjectFieldNode,
|
||||
ConstObjectFieldNode,
|
||||
DirectiveNode,
|
||||
ConstDirectiveNode,
|
||||
TypeNode,
|
||||
NamedTypeNode,
|
||||
ListTypeNode,
|
||||
NonNullTypeNode,
|
||||
TypeSystemDefinitionNode,
|
||||
SchemaDefinitionNode,
|
||||
OperationTypeDefinitionNode,
|
||||
TypeDefinitionNode,
|
||||
ScalarTypeDefinitionNode,
|
||||
ObjectTypeDefinitionNode,
|
||||
FieldDefinitionNode,
|
||||
InputValueDefinitionNode,
|
||||
InterfaceTypeDefinitionNode,
|
||||
UnionTypeDefinitionNode,
|
||||
EnumTypeDefinitionNode,
|
||||
EnumValueDefinitionNode,
|
||||
InputObjectTypeDefinitionNode,
|
||||
DirectiveDefinitionNode,
|
||||
TypeSystemExtensionNode,
|
||||
SchemaExtensionNode,
|
||||
TypeExtensionNode,
|
||||
ScalarTypeExtensionNode,
|
||||
ObjectTypeExtensionNode,
|
||||
InterfaceTypeExtensionNode,
|
||||
UnionTypeExtensionNode,
|
||||
EnumTypeExtensionNode,
|
||||
InputObjectTypeExtensionNode,
|
||||
)
|
||||
from .predicates import (
|
||||
is_definition_node,
|
||||
is_executable_definition_node,
|
||||
is_selection_node,
|
||||
is_value_node,
|
||||
is_const_value_node,
|
||||
is_type_node,
|
||||
is_type_system_definition_node,
|
||||
is_type_definition_node,
|
||||
is_type_system_extension_node,
|
||||
is_type_extension_node,
|
||||
)
|
||||
from .directive_locations import DirectiveLocation
|
||||
|
||||
__all__ = [
|
||||
"get_location",
|
||||
"SourceLocation",
|
||||
"FormattedSourceLocation",
|
||||
"print_location",
|
||||
"print_source_location",
|
||||
"TokenKind",
|
||||
"Lexer",
|
||||
"parse",
|
||||
"parse_value",
|
||||
"parse_const_value",
|
||||
"parse_type",
|
||||
"print_ast",
|
||||
"Source",
|
||||
"visit",
|
||||
"Visitor",
|
||||
"ParallelVisitor",
|
||||
"VisitorAction",
|
||||
"VisitorKeyMap",
|
||||
"BREAK",
|
||||
"SKIP",
|
||||
"REMOVE",
|
||||
"IDLE",
|
||||
"Location",
|
||||
"Token",
|
||||
"DirectiveLocation",
|
||||
"Node",
|
||||
"NameNode",
|
||||
"DocumentNode",
|
||||
"DefinitionNode",
|
||||
"ExecutableDefinitionNode",
|
||||
"OperationDefinitionNode",
|
||||
"OperationType",
|
||||
"VariableDefinitionNode",
|
||||
"VariableNode",
|
||||
"SelectionSetNode",
|
||||
"SelectionNode",
|
||||
"FieldNode",
|
||||
"ArgumentNode",
|
||||
"ConstArgumentNode",
|
||||
"FragmentSpreadNode",
|
||||
"InlineFragmentNode",
|
||||
"FragmentDefinitionNode",
|
||||
"ValueNode",
|
||||
"ConstValueNode",
|
||||
"IntValueNode",
|
||||
"FloatValueNode",
|
||||
"StringValueNode",
|
||||
"BooleanValueNode",
|
||||
"NullValueNode",
|
||||
"EnumValueNode",
|
||||
"ListValueNode",
|
||||
"ConstListValueNode",
|
||||
"ObjectValueNode",
|
||||
"ConstObjectValueNode",
|
||||
"ObjectFieldNode",
|
||||
"ConstObjectFieldNode",
|
||||
"DirectiveNode",
|
||||
"ConstDirectiveNode",
|
||||
"TypeNode",
|
||||
"NamedTypeNode",
|
||||
"ListTypeNode",
|
||||
"NonNullTypeNode",
|
||||
"TypeSystemDefinitionNode",
|
||||
"SchemaDefinitionNode",
|
||||
"OperationTypeDefinitionNode",
|
||||
"TypeDefinitionNode",
|
||||
"ScalarTypeDefinitionNode",
|
||||
"ObjectTypeDefinitionNode",
|
||||
"FieldDefinitionNode",
|
||||
"InputValueDefinitionNode",
|
||||
"InterfaceTypeDefinitionNode",
|
||||
"UnionTypeDefinitionNode",
|
||||
"EnumTypeDefinitionNode",
|
||||
"EnumValueDefinitionNode",
|
||||
"InputObjectTypeDefinitionNode",
|
||||
"DirectiveDefinitionNode",
|
||||
"TypeSystemExtensionNode",
|
||||
"SchemaExtensionNode",
|
||||
"TypeExtensionNode",
|
||||
"ScalarTypeExtensionNode",
|
||||
"ObjectTypeExtensionNode",
|
||||
"InterfaceTypeExtensionNode",
|
||||
"UnionTypeExtensionNode",
|
||||
"EnumTypeExtensionNode",
|
||||
"InputObjectTypeExtensionNode",
|
||||
"is_definition_node",
|
||||
"is_executable_definition_node",
|
||||
"is_selection_node",
|
||||
"is_value_node",
|
||||
"is_const_value_node",
|
||||
"is_type_node",
|
||||
"is_type_system_definition_node",
|
||||
"is_type_definition_node",
|
||||
"is_type_system_extension_node",
|
||||
"is_type_extension_node",
|
||||
]
|
||||
@@ -0,0 +1,807 @@
|
||||
from copy import copy, deepcopy
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Tuple, Optional, Union
|
||||
|
||||
from .source import Source
|
||||
from .token_kind import TokenKind
|
||||
from ..pyutils import camel_to_snake
|
||||
|
||||
__all__ = [
|
||||
"Location",
|
||||
"Token",
|
||||
"Node",
|
||||
"NameNode",
|
||||
"DocumentNode",
|
||||
"DefinitionNode",
|
||||
"ExecutableDefinitionNode",
|
||||
"OperationDefinitionNode",
|
||||
"VariableDefinitionNode",
|
||||
"SelectionSetNode",
|
||||
"SelectionNode",
|
||||
"FieldNode",
|
||||
"ArgumentNode",
|
||||
"ConstArgumentNode",
|
||||
"FragmentSpreadNode",
|
||||
"InlineFragmentNode",
|
||||
"FragmentDefinitionNode",
|
||||
"ValueNode",
|
||||
"ConstValueNode",
|
||||
"VariableNode",
|
||||
"IntValueNode",
|
||||
"FloatValueNode",
|
||||
"StringValueNode",
|
||||
"BooleanValueNode",
|
||||
"NullValueNode",
|
||||
"EnumValueNode",
|
||||
"ListValueNode",
|
||||
"ConstListValueNode",
|
||||
"ObjectValueNode",
|
||||
"ConstObjectValueNode",
|
||||
"ObjectFieldNode",
|
||||
"ConstObjectFieldNode",
|
||||
"DirectiveNode",
|
||||
"ConstDirectiveNode",
|
||||
"TypeNode",
|
||||
"NamedTypeNode",
|
||||
"ListTypeNode",
|
||||
"NonNullTypeNode",
|
||||
"TypeSystemDefinitionNode",
|
||||
"SchemaDefinitionNode",
|
||||
"OperationType",
|
||||
"OperationTypeDefinitionNode",
|
||||
"TypeDefinitionNode",
|
||||
"ScalarTypeDefinitionNode",
|
||||
"ObjectTypeDefinitionNode",
|
||||
"FieldDefinitionNode",
|
||||
"InputValueDefinitionNode",
|
||||
"InterfaceTypeDefinitionNode",
|
||||
"UnionTypeDefinitionNode",
|
||||
"EnumTypeDefinitionNode",
|
||||
"EnumValueDefinitionNode",
|
||||
"InputObjectTypeDefinitionNode",
|
||||
"DirectiveDefinitionNode",
|
||||
"SchemaExtensionNode",
|
||||
"TypeExtensionNode",
|
||||
"TypeSystemExtensionNode",
|
||||
"ScalarTypeExtensionNode",
|
||||
"ObjectTypeExtensionNode",
|
||||
"InterfaceTypeExtensionNode",
|
||||
"UnionTypeExtensionNode",
|
||||
"EnumTypeExtensionNode",
|
||||
"InputObjectTypeExtensionNode",
|
||||
"QUERY_DOCUMENT_KEYS",
|
||||
]
|
||||
|
||||
|
||||
class Token:
|
||||
"""AST Token
|
||||
|
||||
Represents a range of characters represented by a lexical token within a Source.
|
||||
"""
|
||||
|
||||
__slots__ = "kind", "start", "end", "line", "column", "prev", "next", "value"
|
||||
|
||||
kind: TokenKind # the kind of token
|
||||
start: int # the character offset at which this Node begins
|
||||
end: int # the character offset at which this Node ends
|
||||
line: int # the 1-indexed line number on which this Token appears
|
||||
column: int # the 1-indexed column number at which this Token begins
|
||||
# for non-punctuation tokens, represents the interpreted value of the token:
|
||||
value: Optional[str]
|
||||
# Tokens exist as nodes in a double-linked-list amongst all tokens including
|
||||
# ignored tokens. <SOF> is always the first node and <EOF> the last.
|
||||
prev: Optional["Token"]
|
||||
next: Optional["Token"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kind: TokenKind,
|
||||
start: int,
|
||||
end: int,
|
||||
line: int,
|
||||
column: int,
|
||||
value: Optional[str] = None,
|
||||
) -> None:
|
||||
self.kind = kind
|
||||
self.start, self.end = start, end
|
||||
self.line, self.column = line, column
|
||||
self.value = value
|
||||
self.prev = self.next = None
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.desc
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Print a simplified form when appearing in repr() or inspect()."""
|
||||
return f"<Token {self.desc} {self.line}:{self.column}>"
|
||||
|
||||
def __inspect__(self) -> str:
|
||||
return repr(self)
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if isinstance(other, Token):
|
||||
return (
|
||||
self.kind == other.kind
|
||||
and self.start == other.start
|
||||
and self.end == other.end
|
||||
and self.line == other.line
|
||||
and self.column == other.column
|
||||
and self.value == other.value
|
||||
)
|
||||
elif isinstance(other, str):
|
||||
return other == self.desc
|
||||
return False
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(
|
||||
(self.kind, self.start, self.end, self.line, self.column, self.value)
|
||||
)
|
||||
|
||||
def __copy__(self) -> "Token":
|
||||
"""Create a shallow copy of the token"""
|
||||
token = self.__class__(
|
||||
self.kind,
|
||||
self.start,
|
||||
self.end,
|
||||
self.line,
|
||||
self.column,
|
||||
self.value,
|
||||
)
|
||||
token.prev = self.prev
|
||||
return token
|
||||
|
||||
def __deepcopy__(self, memo: Dict) -> "Token":
|
||||
"""Allow only shallow copies to avoid recursion."""
|
||||
return copy(self)
|
||||
|
||||
def __getstate__(self) -> Dict[str, Any]:
|
||||
"""Remove the links when pickling.
|
||||
|
||||
Keeping the links would make pickling a schema too expensive.
|
||||
"""
|
||||
return {
|
||||
key: getattr(self, key)
|
||||
for key in self.__slots__
|
||||
if key not in {"prev", "next"}
|
||||
}
|
||||
|
||||
def __setstate__(self, state: Dict[str, Any]) -> None:
|
||||
"""Reset the links when un-pickling."""
|
||||
for key, value in state.items():
|
||||
setattr(self, key, value)
|
||||
self.prev = self.next = None
|
||||
|
||||
@property
|
||||
def desc(self) -> str:
|
||||
"""A helper property to describe a token as a string for debugging"""
|
||||
kind, value = self.kind.value, self.value
|
||||
return f"{kind} {value!r}" if value else kind
|
||||
|
||||
|
||||
class Location:
|
||||
"""AST Location
|
||||
|
||||
Contains a range of UTF-8 character offsets and token references that identify the
|
||||
region of the source from which the AST derived.
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
"start",
|
||||
"end",
|
||||
"start_token",
|
||||
"end_token",
|
||||
"source",
|
||||
)
|
||||
|
||||
start: int # character offset at which this Node begins
|
||||
end: int # character offset at which this Node ends
|
||||
start_token: Token # Token at which this Node begins
|
||||
end_token: Token # Token at which this Node ends.
|
||||
source: Source # Source document the AST represents
|
||||
|
||||
def __init__(self, start_token: Token, end_token: Token, source: Source) -> None:
|
||||
self.start = start_token.start
|
||||
self.end = end_token.end
|
||||
self.start_token = start_token
|
||||
self.end_token = end_token
|
||||
self.source = source
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.start}:{self.end}"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Print a simplified form when appearing in repr() or inspect()."""
|
||||
return f"<Location {self.start}:{self.end}>"
|
||||
|
||||
def __inspect__(self) -> str:
|
||||
return repr(self)
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if isinstance(other, Location):
|
||||
return self.start == other.start and self.end == other.end
|
||||
elif isinstance(other, (list, tuple)) and len(other) == 2:
|
||||
return self.start == other[0] and self.end == other[1]
|
||||
return False
|
||||
|
||||
def __ne__(self, other: Any) -> bool:
|
||||
return not self == other
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash((self.start, self.end))
|
||||
|
||||
|
||||
class OperationType(Enum):
|
||||
|
||||
QUERY = "query"
|
||||
MUTATION = "mutation"
|
||||
SUBSCRIPTION = "subscription"
|
||||
|
||||
|
||||
# Default map from node kinds to their node attributes (internal)
|
||||
QUERY_DOCUMENT_KEYS: Dict[str, Tuple[str, ...]] = {
|
||||
"name": (),
|
||||
"document": ("definitions",),
|
||||
"operation_definition": (
|
||||
"name",
|
||||
"variable_definitions",
|
||||
"directives",
|
||||
"selection_set",
|
||||
),
|
||||
"variable_definition": ("variable", "type", "default_value", "directives"),
|
||||
"variable": ("name",),
|
||||
"selection_set": ("selections",),
|
||||
"field": ("alias", "name", "arguments", "directives", "selection_set"),
|
||||
"argument": ("name", "value"),
|
||||
"fragment_spread": ("name", "directives"),
|
||||
"inline_fragment": ("type_condition", "directives", "selection_set"),
|
||||
"fragment_definition": (
|
||||
# Note: fragment variable definitions are deprecated and will be removed in v3.3
|
||||
"name",
|
||||
"variable_definitions",
|
||||
"type_condition",
|
||||
"directives",
|
||||
"selection_set",
|
||||
),
|
||||
"list_value": ("values",),
|
||||
"object_value": ("fields",),
|
||||
"object_field": ("name", "value"),
|
||||
"directive": ("name", "arguments"),
|
||||
"named_type": ("name",),
|
||||
"list_type": ("type",),
|
||||
"non_null_type": ("type",),
|
||||
"schema_definition": ("description", "directives", "operation_types"),
|
||||
"operation_type_definition": ("type",),
|
||||
"scalar_type_definition": ("description", "name", "directives"),
|
||||
"object_type_definition": (
|
||||
"description",
|
||||
"name",
|
||||
"interfaces",
|
||||
"directives",
|
||||
"fields",
|
||||
),
|
||||
"field_definition": ("description", "name", "arguments", "type", "directives"),
|
||||
"input_value_definition": (
|
||||
"description",
|
||||
"name",
|
||||
"type",
|
||||
"default_value",
|
||||
"directives",
|
||||
),
|
||||
"interface_type_definition": (
|
||||
"description",
|
||||
"name",
|
||||
"interfaces",
|
||||
"directives",
|
||||
"fields",
|
||||
),
|
||||
"union_type_definition": ("description", "name", "directives", "types"),
|
||||
"enum_type_definition": ("description", "name", "directives", "values"),
|
||||
"enum_value_definition": ("description", "name", "directives"),
|
||||
"input_object_type_definition": ("description", "name", "directives", "fields"),
|
||||
"directive_definition": ("description", "name", "arguments", "locations"),
|
||||
"schema_extension": ("directives", "operation_types"),
|
||||
"scalar_type_extension": ("name", "directives"),
|
||||
"object_type_extension": ("name", "interfaces", "directives", "fields"),
|
||||
"interface_type_extension": ("name", "interfaces", "directives", "fields"),
|
||||
"union_type_extension": ("name", "directives", "types"),
|
||||
"enum_type_extension": ("name", "directives", "values"),
|
||||
"input_object_type_extension": ("name", "directives", "fields"),
|
||||
}
|
||||
|
||||
|
||||
# Base AST Node
|
||||
|
||||
|
||||
class Node:
|
||||
"""AST nodes"""
|
||||
|
||||
# allow custom attributes and weak references (not used internally)
|
||||
__slots__ = "__dict__", "__weakref__", "loc", "_hash"
|
||||
|
||||
loc: Optional[Location]
|
||||
|
||||
kind: str = "ast" # the kind of the node as a snake_case string
|
||||
keys: Tuple[str, ...] = ("loc",) # the names of the attributes of this node
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
"""Initialize the node with the given keyword arguments."""
|
||||
for key in self.keys:
|
||||
value = kwargs.get(key)
|
||||
if isinstance(value, list):
|
||||
value = tuple(value)
|
||||
setattr(self, key, value)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Get a simple representation of the node."""
|
||||
name, loc = self.__class__.__name__, getattr(self, "loc", None)
|
||||
return f"{name} at {loc}" if loc else name
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
"""Test whether two nodes are equal (recursively)."""
|
||||
return (
|
||||
isinstance(other, Node)
|
||||
and self.__class__ == other.__class__
|
||||
and all(getattr(self, key) == getattr(other, key) for key in self.keys)
|
||||
)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
"""Get a cached hash value for the node."""
|
||||
# Caching the hash values improves the performance of AST validators
|
||||
hashed = getattr(self, "_hash", None)
|
||||
if hashed is None:
|
||||
self._hash = id(self) # avoid recursion
|
||||
hashed = hash(tuple(getattr(self, key) for key in self.keys))
|
||||
self._hash = hashed
|
||||
return hashed
|
||||
|
||||
def __setattr__(self, key: str, value: Any) -> None:
|
||||
# reset cashed hash value if attributes are changed
|
||||
if hasattr(self, "_hash") and key in self.keys:
|
||||
del self._hash
|
||||
super().__setattr__(key, value)
|
||||
|
||||
def __copy__(self) -> "Node":
|
||||
"""Create a shallow copy of the node."""
|
||||
return self.__class__(**{key: getattr(self, key) for key in self.keys})
|
||||
|
||||
def __deepcopy__(self, memo: Dict) -> "Node":
|
||||
"""Create a deep copy of the node"""
|
||||
# noinspection PyArgumentList
|
||||
return self.__class__(
|
||||
**{key: deepcopy(getattr(self, key), memo) for key in self.keys}
|
||||
)
|
||||
|
||||
def __init_subclass__(cls) -> None:
|
||||
super().__init_subclass__()
|
||||
name = cls.__name__
|
||||
try:
|
||||
name = name.removeprefix("Const").removesuffix("Node")
|
||||
except AttributeError: # pragma: no cover (Python < 3.9)
|
||||
if name.startswith("Const"):
|
||||
name = name[5:]
|
||||
if name.endswith("Node"):
|
||||
name = name[:-4]
|
||||
cls.kind = camel_to_snake(name)
|
||||
keys: List[str] = []
|
||||
for base in cls.__bases__:
|
||||
# noinspection PyUnresolvedReferences
|
||||
keys.extend(base.keys) # type: ignore
|
||||
keys.extend(cls.__slots__)
|
||||
cls.keys = tuple(keys)
|
||||
|
||||
def to_dict(self, locations: bool = False) -> Dict:
|
||||
from ..utilities import ast_to_dict
|
||||
|
||||
return ast_to_dict(self, locations)
|
||||
|
||||
|
||||
# Name
|
||||
|
||||
|
||||
class NameNode(Node):
|
||||
__slots__ = ("value",)
|
||||
|
||||
value: str
|
||||
|
||||
|
||||
# Document
|
||||
|
||||
|
||||
class DocumentNode(Node):
|
||||
__slots__ = ("definitions",)
|
||||
|
||||
definitions: Tuple["DefinitionNode", ...]
|
||||
|
||||
|
||||
class DefinitionNode(Node):
|
||||
__slots__ = ()
|
||||
|
||||
|
||||
class ExecutableDefinitionNode(DefinitionNode):
|
||||
__slots__ = "name", "directives", "variable_definitions", "selection_set"
|
||||
|
||||
name: Optional[NameNode]
|
||||
directives: Tuple["DirectiveNode", ...]
|
||||
variable_definitions: Tuple["VariableDefinitionNode", ...]
|
||||
selection_set: "SelectionSetNode"
|
||||
|
||||
|
||||
class OperationDefinitionNode(ExecutableDefinitionNode):
|
||||
__slots__ = ("operation",)
|
||||
|
||||
operation: OperationType
|
||||
|
||||
|
||||
class VariableDefinitionNode(Node):
|
||||
__slots__ = "variable", "type", "default_value", "directives"
|
||||
|
||||
variable: "VariableNode"
|
||||
type: "TypeNode"
|
||||
default_value: Optional["ConstValueNode"]
|
||||
directives: Tuple["ConstDirectiveNode", ...]
|
||||
|
||||
|
||||
class SelectionSetNode(Node):
|
||||
__slots__ = ("selections",)
|
||||
|
||||
selections: Tuple["SelectionNode", ...]
|
||||
|
||||
|
||||
class SelectionNode(Node):
|
||||
__slots__ = ("directives",)
|
||||
|
||||
directives: Tuple["DirectiveNode", ...]
|
||||
|
||||
|
||||
class FieldNode(SelectionNode):
|
||||
__slots__ = "alias", "name", "arguments", "selection_set"
|
||||
|
||||
alias: Optional[NameNode]
|
||||
name: NameNode
|
||||
arguments: Tuple["ArgumentNode", ...]
|
||||
selection_set: Optional[SelectionSetNode]
|
||||
|
||||
|
||||
class ArgumentNode(Node):
|
||||
__slots__ = "name", "value"
|
||||
|
||||
name: NameNode
|
||||
value: "ValueNode"
|
||||
|
||||
|
||||
class ConstArgumentNode(ArgumentNode):
|
||||
|
||||
value: "ConstValueNode"
|
||||
|
||||
|
||||
# Fragments
|
||||
|
||||
|
||||
class FragmentSpreadNode(SelectionNode):
|
||||
__slots__ = ("name",)
|
||||
|
||||
name: NameNode
|
||||
|
||||
|
||||
class InlineFragmentNode(SelectionNode):
|
||||
__slots__ = "type_condition", "selection_set"
|
||||
|
||||
type_condition: "NamedTypeNode"
|
||||
selection_set: SelectionSetNode
|
||||
|
||||
|
||||
class FragmentDefinitionNode(ExecutableDefinitionNode):
|
||||
__slots__ = ("type_condition",)
|
||||
|
||||
name: NameNode
|
||||
type_condition: "NamedTypeNode"
|
||||
|
||||
|
||||
# Values
|
||||
|
||||
|
||||
class ValueNode(Node):
|
||||
__slots__ = ()
|
||||
|
||||
|
||||
class VariableNode(ValueNode):
|
||||
__slots__ = ("name",)
|
||||
|
||||
name: NameNode
|
||||
|
||||
|
||||
class IntValueNode(ValueNode):
|
||||
__slots__ = ("value",)
|
||||
|
||||
value: str
|
||||
|
||||
|
||||
class FloatValueNode(ValueNode):
|
||||
__slots__ = ("value",)
|
||||
|
||||
value: str
|
||||
|
||||
|
||||
class StringValueNode(ValueNode):
|
||||
__slots__ = "value", "block"
|
||||
|
||||
value: str
|
||||
block: Optional[bool]
|
||||
|
||||
|
||||
class BooleanValueNode(ValueNode):
|
||||
__slots__ = ("value",)
|
||||
|
||||
value: bool
|
||||
|
||||
|
||||
class NullValueNode(ValueNode):
|
||||
__slots__ = ()
|
||||
|
||||
|
||||
class EnumValueNode(ValueNode):
|
||||
__slots__ = ("value",)
|
||||
|
||||
value: str
|
||||
|
||||
|
||||
class ListValueNode(ValueNode):
|
||||
__slots__ = ("values",)
|
||||
|
||||
values: Tuple[ValueNode, ...]
|
||||
|
||||
|
||||
class ConstListValueNode(ListValueNode):
|
||||
|
||||
values: Tuple["ConstValueNode", ...]
|
||||
|
||||
|
||||
class ObjectValueNode(ValueNode):
|
||||
__slots__ = ("fields",)
|
||||
|
||||
fields: Tuple["ObjectFieldNode", ...]
|
||||
|
||||
|
||||
class ConstObjectValueNode(ObjectValueNode):
|
||||
|
||||
fields: Tuple["ConstObjectFieldNode", ...]
|
||||
|
||||
|
||||
class ObjectFieldNode(Node):
|
||||
__slots__ = "name", "value"
|
||||
|
||||
name: NameNode
|
||||
value: ValueNode
|
||||
|
||||
|
||||
class ConstObjectFieldNode(ObjectFieldNode):
|
||||
|
||||
value: "ConstValueNode"
|
||||
|
||||
|
||||
ConstValueNode = Union[
|
||||
IntValueNode,
|
||||
FloatValueNode,
|
||||
StringValueNode,
|
||||
BooleanValueNode,
|
||||
NullValueNode,
|
||||
EnumValueNode,
|
||||
ConstListValueNode,
|
||||
ConstObjectValueNode,
|
||||
]
|
||||
|
||||
|
||||
# Directives
|
||||
|
||||
|
||||
class DirectiveNode(Node):
|
||||
__slots__ = "name", "arguments"
|
||||
|
||||
name: NameNode
|
||||
arguments: Tuple[ArgumentNode, ...]
|
||||
|
||||
|
||||
class ConstDirectiveNode(DirectiveNode):
|
||||
|
||||
arguments: Tuple[ConstArgumentNode, ...]
|
||||
|
||||
|
||||
# Type Reference
|
||||
|
||||
|
||||
class TypeNode(Node):
|
||||
__slots__ = ()
|
||||
|
||||
|
||||
class NamedTypeNode(TypeNode):
|
||||
__slots__ = ("name",)
|
||||
|
||||
name: NameNode
|
||||
|
||||
|
||||
class ListTypeNode(TypeNode):
|
||||
__slots__ = ("type",)
|
||||
|
||||
type: TypeNode
|
||||
|
||||
|
||||
class NonNullTypeNode(TypeNode):
|
||||
__slots__ = ("type",)
|
||||
|
||||
type: Union[NamedTypeNode, ListTypeNode]
|
||||
|
||||
|
||||
# Type System Definition
|
||||
|
||||
|
||||
class TypeSystemDefinitionNode(DefinitionNode):
|
||||
__slots__ = ()
|
||||
|
||||
|
||||
class SchemaDefinitionNode(TypeSystemDefinitionNode):
|
||||
__slots__ = "description", "directives", "operation_types"
|
||||
|
||||
description: Optional[StringValueNode]
|
||||
directives: Tuple[ConstDirectiveNode, ...]
|
||||
operation_types: Tuple["OperationTypeDefinitionNode", ...]
|
||||
|
||||
|
||||
class OperationTypeDefinitionNode(Node):
|
||||
__slots__ = "operation", "type"
|
||||
|
||||
operation: OperationType
|
||||
type: NamedTypeNode
|
||||
|
||||
|
||||
# Type Definition
|
||||
|
||||
|
||||
class TypeDefinitionNode(TypeSystemDefinitionNode):
|
||||
__slots__ = "description", "name", "directives"
|
||||
|
||||
description: Optional[StringValueNode]
|
||||
name: NameNode
|
||||
directives: Tuple[DirectiveNode, ...]
|
||||
|
||||
|
||||
class ScalarTypeDefinitionNode(TypeDefinitionNode):
|
||||
__slots__ = ()
|
||||
|
||||
directives: Tuple[ConstDirectiveNode, ...]
|
||||
|
||||
|
||||
class ObjectTypeDefinitionNode(TypeDefinitionNode):
|
||||
__slots__ = "interfaces", "fields"
|
||||
|
||||
interfaces: Tuple[NamedTypeNode, ...]
|
||||
directives: Tuple[ConstDirectiveNode, ...]
|
||||
fields: Tuple["FieldDefinitionNode", ...]
|
||||
|
||||
|
||||
class FieldDefinitionNode(DefinitionNode):
|
||||
__slots__ = "description", "name", "directives", "arguments", "type"
|
||||
|
||||
description: Optional[StringValueNode]
|
||||
name: NameNode
|
||||
directives: Tuple[ConstDirectiveNode, ...]
|
||||
arguments: Tuple["InputValueDefinitionNode", ...]
|
||||
type: TypeNode
|
||||
|
||||
|
||||
class InputValueDefinitionNode(DefinitionNode):
|
||||
__slots__ = "description", "name", "directives", "type", "default_value"
|
||||
|
||||
description: Optional[StringValueNode]
|
||||
name: NameNode
|
||||
directives: Tuple[ConstDirectiveNode, ...]
|
||||
type: TypeNode
|
||||
default_value: Optional[ConstValueNode]
|
||||
|
||||
|
||||
class InterfaceTypeDefinitionNode(TypeDefinitionNode):
|
||||
__slots__ = "fields", "interfaces"
|
||||
|
||||
fields: Tuple["FieldDefinitionNode", ...]
|
||||
directives: Tuple[ConstDirectiveNode, ...]
|
||||
interfaces: Tuple[NamedTypeNode, ...]
|
||||
|
||||
|
||||
class UnionTypeDefinitionNode(TypeDefinitionNode):
|
||||
__slots__ = ("types",)
|
||||
|
||||
directives: Tuple[ConstDirectiveNode, ...]
|
||||
types: Tuple[NamedTypeNode, ...]
|
||||
|
||||
|
||||
class EnumTypeDefinitionNode(TypeDefinitionNode):
|
||||
__slots__ = ("values",)
|
||||
|
||||
directives: Tuple[ConstDirectiveNode, ...]
|
||||
values: Tuple["EnumValueDefinitionNode", ...]
|
||||
|
||||
|
||||
class EnumValueDefinitionNode(DefinitionNode):
|
||||
__slots__ = "description", "name", "directives"
|
||||
|
||||
description: Optional[StringValueNode]
|
||||
name: NameNode
|
||||
directives: Tuple[ConstDirectiveNode, ...]
|
||||
|
||||
|
||||
class InputObjectTypeDefinitionNode(TypeDefinitionNode):
|
||||
__slots__ = ("fields",)
|
||||
|
||||
directives: Tuple[ConstDirectiveNode, ...]
|
||||
fields: Tuple[InputValueDefinitionNode, ...]
|
||||
|
||||
|
||||
# Directive Definitions
|
||||
|
||||
|
||||
class DirectiveDefinitionNode(TypeSystemDefinitionNode):
|
||||
__slots__ = "description", "name", "arguments", "repeatable", "locations"
|
||||
|
||||
description: Optional[StringValueNode]
|
||||
name: NameNode
|
||||
arguments: Tuple[InputValueDefinitionNode, ...]
|
||||
repeatable: bool
|
||||
locations: Tuple[NameNode, ...]
|
||||
|
||||
|
||||
# Type System Extensions
|
||||
|
||||
|
||||
class SchemaExtensionNode(Node):
|
||||
__slots__ = "directives", "operation_types"
|
||||
|
||||
directives: Tuple[ConstDirectiveNode, ...]
|
||||
operation_types: Tuple[OperationTypeDefinitionNode, ...]
|
||||
|
||||
|
||||
# Type Extensions
|
||||
|
||||
|
||||
class TypeExtensionNode(TypeSystemDefinitionNode):
|
||||
__slots__ = "name", "directives"
|
||||
|
||||
name: NameNode
|
||||
directives: Tuple[ConstDirectiveNode, ...]
|
||||
|
||||
|
||||
TypeSystemExtensionNode = Union[SchemaExtensionNode, TypeExtensionNode]
|
||||
|
||||
|
||||
class ScalarTypeExtensionNode(TypeExtensionNode):
|
||||
__slots__ = ()
|
||||
|
||||
|
||||
class ObjectTypeExtensionNode(TypeExtensionNode):
|
||||
__slots__ = "interfaces", "fields"
|
||||
|
||||
interfaces: Tuple[NamedTypeNode, ...]
|
||||
fields: Tuple[FieldDefinitionNode, ...]
|
||||
|
||||
|
||||
class InterfaceTypeExtensionNode(TypeExtensionNode):
|
||||
__slots__ = "interfaces", "fields"
|
||||
|
||||
interfaces: Tuple[NamedTypeNode, ...]
|
||||
fields: Tuple[FieldDefinitionNode, ...]
|
||||
|
||||
|
||||
class UnionTypeExtensionNode(TypeExtensionNode):
|
||||
__slots__ = ("types",)
|
||||
|
||||
types: Tuple[NamedTypeNode, ...]
|
||||
|
||||
|
||||
class EnumTypeExtensionNode(TypeExtensionNode):
|
||||
__slots__ = ("values",)
|
||||
|
||||
values: Tuple[EnumValueDefinitionNode, ...]
|
||||
|
||||
|
||||
class InputObjectTypeExtensionNode(TypeExtensionNode):
|
||||
__slots__ = ("fields",)
|
||||
|
||||
fields: Tuple[InputValueDefinitionNode, ...]
|
||||
@@ -0,0 +1,155 @@
|
||||
from typing import Collection, List
|
||||
from sys import maxsize
|
||||
|
||||
__all__ = [
|
||||
"dedent_block_string_lines",
|
||||
"is_printable_as_block_string",
|
||||
"print_block_string",
|
||||
]
|
||||
|
||||
|
||||
def dedent_block_string_lines(lines: Collection[str]) -> List[str]:
|
||||
"""Produce the value of a block string from its parsed raw value.
|
||||
|
||||
This function works similar to CoffeeScript's block string,
|
||||
Python's docstring trim or Ruby's strip_heredoc.
|
||||
|
||||
It implements the GraphQL spec's BlockStringValue() static algorithm.
|
||||
|
||||
Note that this is very similar to Python's inspect.cleandoc() function.
|
||||
The difference is that the latter also expands tabs to spaces and
|
||||
removes whitespace at the beginning of the first line. Python also has
|
||||
textwrap.dedent() which uses a completely different algorithm.
|
||||
|
||||
For internal use only.
|
||||
"""
|
||||
common_indent = maxsize
|
||||
first_non_empty_line = None
|
||||
last_non_empty_line = -1
|
||||
|
||||
for i, line in enumerate(lines):
|
||||
indent = leading_white_space(line)
|
||||
|
||||
if indent == len(line):
|
||||
continue # skip empty lines
|
||||
|
||||
if first_non_empty_line is None:
|
||||
first_non_empty_line = i
|
||||
last_non_empty_line = i
|
||||
|
||||
if i and indent < common_indent:
|
||||
common_indent = indent
|
||||
|
||||
if first_non_empty_line is None:
|
||||
first_non_empty_line = 0
|
||||
|
||||
return [ # Remove common indentation from all lines but first.
|
||||
line[common_indent:] if i else line for i, line in enumerate(lines)
|
||||
][ # Remove leading and trailing blank lines.
|
||||
first_non_empty_line : last_non_empty_line + 1
|
||||
]
|
||||
|
||||
|
||||
def leading_white_space(s: str) -> int:
|
||||
i = 0
|
||||
for c in s:
|
||||
if c not in " \t":
|
||||
return i
|
||||
i += 1
|
||||
return i
|
||||
|
||||
|
||||
def is_printable_as_block_string(value: str) -> bool:
|
||||
"""Check whether the given string is printable as a block string.
|
||||
|
||||
For internal use only.
|
||||
"""
|
||||
if not isinstance(value, str):
|
||||
value = str(value) # resolve lazy string proxy object
|
||||
|
||||
if not value:
|
||||
return True # emtpy string is printable
|
||||
|
||||
is_empty_line = True
|
||||
has_indent = False
|
||||
has_common_indent = True
|
||||
seen_non_empty_line = False
|
||||
|
||||
for c in value:
|
||||
if c == "\n":
|
||||
if is_empty_line and not seen_non_empty_line:
|
||||
return False # has leading new line
|
||||
seen_non_empty_line = True
|
||||
is_empty_line = True
|
||||
has_indent = False
|
||||
elif c in " \t":
|
||||
has_indent = has_indent or is_empty_line
|
||||
elif c <= "\x0f":
|
||||
return False
|
||||
else:
|
||||
has_common_indent = has_common_indent and has_indent
|
||||
is_empty_line = False
|
||||
|
||||
if is_empty_line:
|
||||
return False # has trailing empty lines
|
||||
|
||||
if has_common_indent and seen_non_empty_line:
|
||||
return False # has internal indent
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def print_block_string(value: str, minimize: bool = False) -> str:
|
||||
"""Print a block string in the indented block form.
|
||||
|
||||
Prints a block string in the indented block form by adding a leading and
|
||||
trailing blank line. However, if a block string starts with whitespace and
|
||||
is a single-line, adding a leading blank line would strip that whitespace.
|
||||
|
||||
For internal use only.
|
||||
"""
|
||||
if not isinstance(value, str):
|
||||
value = str(value) # resolve lazy string proxy object
|
||||
|
||||
escaped_value = value.replace('"""', '\\"""')
|
||||
|
||||
# Expand a block string's raw value into independent lines.
|
||||
lines = escaped_value.splitlines() or [""]
|
||||
num_lines = len(lines)
|
||||
is_single_line = num_lines == 1
|
||||
|
||||
# If common indentation is found,
|
||||
# we can fix some of those cases by adding a leading new line.
|
||||
force_leading_new_line = num_lines > 1 and all(
|
||||
not line or line[0] in " \t" for line in lines[1:]
|
||||
)
|
||||
|
||||
# Trailing triple quotes just looks confusing but doesn't force trailing new line.
|
||||
has_trailing_triple_quotes = escaped_value.endswith('\\"""')
|
||||
|
||||
# Trailing quote (single or double) or slash forces trailing new line
|
||||
has_trailing_quote = value.endswith('"') and not has_trailing_triple_quotes
|
||||
has_trailing_slash = value.endswith("\\")
|
||||
force_trailing_new_line = has_trailing_quote or has_trailing_slash
|
||||
|
||||
print_as_multiple_lines = not minimize and (
|
||||
# add leading and trailing new lines only if it improves readability
|
||||
not is_single_line
|
||||
or len(value) > 70
|
||||
or force_trailing_new_line
|
||||
or force_leading_new_line
|
||||
or has_trailing_triple_quotes
|
||||
)
|
||||
|
||||
# Format a multi-line block quote to account for leading space.
|
||||
skip_leading_new_line = is_single_line and value and value[0] in " \t"
|
||||
before = (
|
||||
"\n"
|
||||
if print_as_multiple_lines
|
||||
and not skip_leading_new_line
|
||||
or force_leading_new_line
|
||||
else ""
|
||||
)
|
||||
after = "\n" if print_as_multiple_lines or force_trailing_new_line else ""
|
||||
|
||||
return f'"""{before}{escaped_value}{after}"""'
|
||||
@@ -0,0 +1,68 @@
|
||||
__all__ = ["is_digit", "is_letter", "is_name_start", "is_name_continue"]
|
||||
|
||||
try:
|
||||
"string".isascii()
|
||||
except AttributeError: # Python < 3.7
|
||||
|
||||
def is_digit(char: str) -> bool:
|
||||
"""Check whether char is a digit
|
||||
|
||||
For internal use by the lexer only.
|
||||
"""
|
||||
return "0" <= char <= "9"
|
||||
|
||||
def is_letter(char: str) -> bool:
|
||||
"""Check whether char is a plain ASCII letter
|
||||
|
||||
For internal use by the lexer only.
|
||||
"""
|
||||
return "a" <= char <= "z" or "A" <= char <= "Z"
|
||||
|
||||
def is_name_start(char: str) -> bool:
|
||||
"""Check whether char is allowed at the beginning of a GraphQL name
|
||||
|
||||
For internal use by the lexer only.
|
||||
"""
|
||||
return "a" <= char <= "z" or "A" <= char <= "Z" or char == "_"
|
||||
|
||||
def is_name_continue(char: str) -> bool:
|
||||
"""Check whether char is allowed in the continuation of a GraphQL name
|
||||
|
||||
For internal use by the lexer only.
|
||||
"""
|
||||
return (
|
||||
"a" <= char <= "z"
|
||||
or "A" <= char <= "Z"
|
||||
or "0" <= char <= "9"
|
||||
or char == "_"
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
def is_digit(char: str) -> bool:
|
||||
"""Check whether char is a digit
|
||||
|
||||
For internal use by the lexer only.
|
||||
"""
|
||||
return char.isascii() and char.isdigit()
|
||||
|
||||
def is_letter(char: str) -> bool:
|
||||
"""Check whether char is a plain ASCII letter
|
||||
|
||||
For internal use by the lexer only.
|
||||
"""
|
||||
return char.isascii() and char.isalpha()
|
||||
|
||||
def is_name_start(char: str) -> bool:
|
||||
"""Check whether char is allowed at the beginning of a GraphQL name
|
||||
|
||||
For internal use by the lexer only.
|
||||
"""
|
||||
return char.isascii() and (char.isalpha() or char == "_")
|
||||
|
||||
def is_name_continue(char: str) -> bool:
|
||||
"""Check whether char is allowed in the continuation of a GraphQL name
|
||||
|
||||
For internal use by the lexer only.
|
||||
"""
|
||||
return char.isascii() and (char.isalnum() or char == "_")
|
||||
@@ -0,0 +1,30 @@
|
||||
from enum import Enum
|
||||
|
||||
__all__ = ["DirectiveLocation"]
|
||||
|
||||
|
||||
class DirectiveLocation(Enum):
|
||||
"""The enum type representing the directive location values."""
|
||||
|
||||
# Request Definitions
|
||||
QUERY = "query"
|
||||
MUTATION = "mutation"
|
||||
SUBSCRIPTION = "subscription"
|
||||
FIELD = "field"
|
||||
FRAGMENT_DEFINITION = "fragment definition"
|
||||
FRAGMENT_SPREAD = "fragment spread"
|
||||
VARIABLE_DEFINITION = "variable definition"
|
||||
INLINE_FRAGMENT = "inline fragment"
|
||||
|
||||
# Type System Definitions
|
||||
SCHEMA = "schema"
|
||||
SCALAR = "scalar"
|
||||
OBJECT = "object"
|
||||
FIELD_DEFINITION = "field definition"
|
||||
ARGUMENT_DEFINITION = "argument definition"
|
||||
INTERFACE = "interface"
|
||||
UNION = "union"
|
||||
ENUM = "enum"
|
||||
ENUM_VALUE = "enum value"
|
||||
INPUT_OBJECT = "input object"
|
||||
INPUT_FIELD_DEFINITION = "input field definition"
|
||||
@@ -0,0 +1,574 @@
|
||||
from typing import List, NamedTuple, Optional
|
||||
|
||||
from ..error import GraphQLSyntaxError
|
||||
from .ast import Token
|
||||
from .block_string import dedent_block_string_lines
|
||||
from .character_classes import is_digit, is_name_start, is_name_continue
|
||||
from .source import Source
|
||||
from .token_kind import TokenKind
|
||||
|
||||
__all__ = ["Lexer", "is_punctuator_token_kind"]
|
||||
|
||||
|
||||
class EscapeSequence(NamedTuple):
|
||||
"""The string value and lexed size of an escape sequence."""
|
||||
|
||||
value: str
|
||||
size: int
|
||||
|
||||
|
||||
class Lexer:
|
||||
"""GraphQL Lexer
|
||||
|
||||
A Lexer is a stateful stream generator in that every time it is advanced, it returns
|
||||
the next token in the Source. Assuming the source lexes, the final Token emitted by
|
||||
the lexer will be of kind EOF, after which the lexer will repeatedly return the same
|
||||
EOF token whenever called.
|
||||
"""
|
||||
|
||||
def __init__(self, source: Source):
|
||||
"""Given a Source object, initialize a Lexer for that source."""
|
||||
self.source = source
|
||||
self.token = self.last_token = Token(TokenKind.SOF, 0, 0, 0, 0)
|
||||
self.line, self.line_start = 1, 0
|
||||
|
||||
def advance(self) -> Token:
|
||||
"""Advance the token stream to the next non-ignored token."""
|
||||
self.last_token = self.token
|
||||
token = self.token = self.lookahead()
|
||||
return token
|
||||
|
||||
def lookahead(self) -> Token:
|
||||
"""Look ahead and return the next non-ignored token, but do not change state."""
|
||||
token = self.token
|
||||
if token.kind != TokenKind.EOF:
|
||||
while True:
|
||||
if token.next:
|
||||
token = token.next
|
||||
else:
|
||||
# Read the next token and form a link in the token linked-list.
|
||||
next_token = self.read_next_token(token.end)
|
||||
token.next = next_token
|
||||
next_token.prev = token
|
||||
token = next_token
|
||||
if token.kind != TokenKind.COMMENT:
|
||||
break
|
||||
return token
|
||||
|
||||
def print_code_point_at(self, location: int) -> str:
|
||||
"""Print the code point at the given location.
|
||||
|
||||
Prints the code point (or end of file reference) at a given location in a
|
||||
source for use in error messages.
|
||||
|
||||
Printable ASCII is printed quoted, while other points are printed in Unicode
|
||||
code point form (ie. U+1234).
|
||||
"""
|
||||
body = self.source.body
|
||||
if location >= len(body):
|
||||
return TokenKind.EOF.value
|
||||
char = body[location]
|
||||
# Printable ASCII
|
||||
if "\x20" <= char <= "\x7E":
|
||||
return "'\"'" if char == '"' else f"'{char}'"
|
||||
# Unicode code point
|
||||
point = ord(
|
||||
body[location : location + 2]
|
||||
.encode("utf-16", "surrogatepass")
|
||||
.decode("utf-16")
|
||||
if is_supplementary_code_point(body, location)
|
||||
else char
|
||||
)
|
||||
return f"U+{point:04X}"
|
||||
|
||||
def create_token(
|
||||
self, kind: TokenKind, start: int, end: int, value: Optional[str] = None
|
||||
) -> Token:
|
||||
"""Create a token with line and column location information."""
|
||||
line = self.line
|
||||
col = 1 + start - self.line_start
|
||||
return Token(kind, start, end, line, col, value)
|
||||
|
||||
def read_next_token(self, start: int) -> Token:
|
||||
"""Get the next token from the source starting at the given position.
|
||||
|
||||
This skips over whitespace until it finds the next lexable token, then lexes
|
||||
punctuators immediately or calls the appropriate helper function for more
|
||||
complicated tokens.
|
||||
"""
|
||||
body = self.source.body
|
||||
body_length = len(body)
|
||||
position = start
|
||||
|
||||
while position < body_length:
|
||||
char = body[position] # SourceCharacter
|
||||
|
||||
if char in " \t,\ufeff":
|
||||
position += 1
|
||||
continue
|
||||
elif char == "\n":
|
||||
position += 1
|
||||
self.line += 1
|
||||
self.line_start = position
|
||||
continue
|
||||
elif char == "\r":
|
||||
if body[position + 1 : position + 2] == "\n":
|
||||
position += 2
|
||||
else:
|
||||
position += 1
|
||||
self.line += 1
|
||||
self.line_start = position
|
||||
continue
|
||||
|
||||
if char == "#":
|
||||
return self.read_comment(position)
|
||||
|
||||
if char == '"':
|
||||
if body[position + 1 : position + 3] == '""':
|
||||
return self.read_block_string(position)
|
||||
return self.read_string(position)
|
||||
|
||||
kind = _KIND_FOR_PUNCT.get(char)
|
||||
if kind:
|
||||
return self.create_token(kind, position, position + 1)
|
||||
|
||||
if is_digit(char) or char == "-":
|
||||
return self.read_number(position, char)
|
||||
|
||||
if is_name_start(char):
|
||||
return self.read_name(position)
|
||||
|
||||
if char == ".":
|
||||
if body[position + 1 : position + 3] == "..":
|
||||
return self.create_token(TokenKind.SPREAD, position, position + 3)
|
||||
|
||||
message = (
|
||||
"Unexpected single quote character ('),"
|
||||
' did you mean to use a double quote (")?'
|
||||
if char == "'"
|
||||
else (
|
||||
f"Unexpected character: {self.print_code_point_at(position)}."
|
||||
if is_unicode_scalar_value(char)
|
||||
or is_supplementary_code_point(body, position)
|
||||
else f"Invalid character: {self.print_code_point_at(position)}."
|
||||
)
|
||||
)
|
||||
|
||||
raise GraphQLSyntaxError(self.source, position, message)
|
||||
|
||||
return self.create_token(TokenKind.EOF, body_length, body_length)
|
||||
|
||||
def read_comment(self, start: int) -> Token:
|
||||
"""Read a comment token from the source file."""
|
||||
body = self.source.body
|
||||
body_length = len(body)
|
||||
|
||||
position = start + 1
|
||||
while position < body_length:
|
||||
char = body[position]
|
||||
if char in "\r\n":
|
||||
break
|
||||
if is_unicode_scalar_value(char):
|
||||
position += 1
|
||||
elif is_supplementary_code_point(body, position):
|
||||
position += 2
|
||||
else:
|
||||
break # pragma: no cover
|
||||
|
||||
return self.create_token(
|
||||
TokenKind.COMMENT,
|
||||
start,
|
||||
position,
|
||||
body[start + 1 : position],
|
||||
)
|
||||
|
||||
def read_number(self, start: int, first_char: str) -> Token:
|
||||
"""Reads a number token from the source file.
|
||||
|
||||
This can be either a FloatValue or an IntValue,
|
||||
depending on whether a FractionalPart or ExponentPart is encountered.
|
||||
"""
|
||||
body = self.source.body
|
||||
position = start
|
||||
char = first_char
|
||||
is_float = False
|
||||
|
||||
if char == "-":
|
||||
position += 1
|
||||
char = body[position : position + 1]
|
||||
if char == "0":
|
||||
position += 1
|
||||
char = body[position : position + 1]
|
||||
if is_digit(char):
|
||||
raise GraphQLSyntaxError(
|
||||
self.source,
|
||||
position,
|
||||
"Invalid number, unexpected digit after 0:"
|
||||
f" {self.print_code_point_at(position)}.",
|
||||
)
|
||||
else:
|
||||
position = self.read_digits(position, char)
|
||||
char = body[position : position + 1]
|
||||
if char == ".":
|
||||
is_float = True
|
||||
position += 1
|
||||
char = body[position : position + 1]
|
||||
position = self.read_digits(position, char)
|
||||
char = body[position : position + 1]
|
||||
if char and char in "Ee":
|
||||
is_float = True
|
||||
position += 1
|
||||
char = body[position : position + 1]
|
||||
if char and char in "+-":
|
||||
position += 1
|
||||
char = body[position : position + 1]
|
||||
position = self.read_digits(position, char)
|
||||
char = body[position : position + 1]
|
||||
|
||||
# Numbers cannot be followed by . or NameStart
|
||||
if char and (char == "." or is_name_start(char)):
|
||||
raise GraphQLSyntaxError(
|
||||
self.source,
|
||||
position,
|
||||
"Invalid number, expected digit but got:"
|
||||
f" {self.print_code_point_at(position)}.",
|
||||
)
|
||||
|
||||
return self.create_token(
|
||||
TokenKind.FLOAT if is_float else TokenKind.INT,
|
||||
start,
|
||||
position,
|
||||
body[start:position],
|
||||
)
|
||||
|
||||
def read_digits(self, start: int, first_char: str) -> int:
|
||||
"""Return the new position in the source after reading one or more digits."""
|
||||
if not is_digit(first_char):
|
||||
raise GraphQLSyntaxError(
|
||||
self.source,
|
||||
start,
|
||||
"Invalid number, expected digit but got:"
|
||||
f" {self.print_code_point_at(start)}.",
|
||||
)
|
||||
|
||||
body = self.source.body
|
||||
body_length = len(body)
|
||||
position = start + 1
|
||||
while position < body_length and is_digit(body[position]):
|
||||
position += 1
|
||||
return position
|
||||
|
||||
def read_string(self, start: int) -> Token:
|
||||
"""Read a single-quote string token from the source file."""
|
||||
body = self.source.body
|
||||
body_length = len(body)
|
||||
position = start + 1
|
||||
chunk_start = position
|
||||
value: List[str] = []
|
||||
append = value.append
|
||||
|
||||
while position < body_length:
|
||||
char = body[position]
|
||||
|
||||
if char == '"':
|
||||
append(body[chunk_start:position])
|
||||
return self.create_token(
|
||||
TokenKind.STRING,
|
||||
start,
|
||||
position + 1,
|
||||
"".join(value),
|
||||
)
|
||||
|
||||
if char == "\\":
|
||||
append(body[chunk_start:position])
|
||||
escape = (
|
||||
(
|
||||
self.read_escaped_unicode_variable_width(position)
|
||||
if body[position + 2 : position + 3] == "{"
|
||||
else self.read_escaped_unicode_fixed_width(position)
|
||||
)
|
||||
if body[position + 1 : position + 2] == "u"
|
||||
else self.read_escaped_character(position)
|
||||
)
|
||||
append(escape.value)
|
||||
position += escape.size
|
||||
chunk_start = position
|
||||
continue
|
||||
|
||||
if char in "\r\n":
|
||||
break
|
||||
|
||||
if is_unicode_scalar_value(char):
|
||||
position += 1
|
||||
elif is_supplementary_code_point(body, position):
|
||||
position += 2
|
||||
else:
|
||||
raise GraphQLSyntaxError(
|
||||
self.source,
|
||||
position,
|
||||
"Invalid character within String:"
|
||||
f" {self.print_code_point_at(position)}.",
|
||||
)
|
||||
|
||||
raise GraphQLSyntaxError(self.source, position, "Unterminated string.")
|
||||
|
||||
def read_escaped_unicode_variable_width(self, position: int) -> EscapeSequence:
|
||||
body = self.source.body
|
||||
point = 0
|
||||
size = 3
|
||||
max_size = min(12, len(body) - position)
|
||||
# Cannot be larger than 12 chars (\u{00000000}).
|
||||
while size < max_size:
|
||||
char = body[position + size]
|
||||
size += 1
|
||||
if char == "}":
|
||||
# Must be at least 5 chars (\u{0}) and encode a Unicode scalar value.
|
||||
if size < 5 or not (
|
||||
0 <= point <= 0xD7FF or 0xE000 <= point <= 0x10FFFF
|
||||
):
|
||||
break
|
||||
return EscapeSequence(chr(point), size)
|
||||
# Append this hex digit to the code point.
|
||||
point = (point << 4) | read_hex_digit(char)
|
||||
if point < 0:
|
||||
break
|
||||
|
||||
raise GraphQLSyntaxError(
|
||||
self.source,
|
||||
position,
|
||||
f"Invalid Unicode escape sequence: '{body[position: position + size]}'.",
|
||||
)
|
||||
|
||||
def read_escaped_unicode_fixed_width(self, position: int) -> EscapeSequence:
|
||||
body = self.source.body
|
||||
code = read_16_bit_hex_code(body, position + 2)
|
||||
|
||||
if 0 <= code <= 0xD7FF or 0xE000 <= code <= 0x10FFFF:
|
||||
return EscapeSequence(chr(code), 6)
|
||||
|
||||
# GraphQL allows JSON-style surrogate pair escape sequences, but only when
|
||||
# a valid pair is formed.
|
||||
if 0xD800 <= code <= 0xDBFF:
|
||||
if body[position + 6 : position + 8] == "\\u":
|
||||
trailing_code = read_16_bit_hex_code(body, position + 8)
|
||||
if 0xDC00 <= trailing_code <= 0xDFFF:
|
||||
return EscapeSequence(
|
||||
(chr(code) + chr(trailing_code))
|
||||
.encode("utf-16", "surrogatepass")
|
||||
.decode("utf-16"),
|
||||
12,
|
||||
)
|
||||
|
||||
raise GraphQLSyntaxError(
|
||||
self.source,
|
||||
position,
|
||||
f"Invalid Unicode escape sequence: '{body[position: position + 6]}'.",
|
||||
)
|
||||
|
||||
def read_escaped_character(self, position: int) -> EscapeSequence:
|
||||
body = self.source.body
|
||||
value = _ESCAPED_CHARS.get(body[position + 1])
|
||||
if value:
|
||||
return EscapeSequence(value, 2)
|
||||
raise GraphQLSyntaxError(
|
||||
self.source,
|
||||
position,
|
||||
f"Invalid character escape sequence: '{body[position: position + 2]}'.",
|
||||
)
|
||||
|
||||
def read_block_string(self, start: int) -> Token:
|
||||
"""Read a block string token from the source file."""
|
||||
body = self.source.body
|
||||
body_length = len(body)
|
||||
line_start = self.line_start
|
||||
|
||||
position = start + 3
|
||||
chunk_start = position
|
||||
current_line = ""
|
||||
|
||||
block_lines = []
|
||||
while position < body_length:
|
||||
char = body[position]
|
||||
|
||||
if char == '"' and body[position + 1 : position + 3] == '""':
|
||||
current_line += body[chunk_start:position]
|
||||
block_lines.append(current_line)
|
||||
|
||||
token = self.create_token(
|
||||
TokenKind.BLOCK_STRING,
|
||||
start,
|
||||
position + 3,
|
||||
# return a string of the lines joined with new lines
|
||||
"\n".join(dedent_block_string_lines(block_lines)),
|
||||
)
|
||||
|
||||
self.line += len(block_lines) - 1
|
||||
self.line_start = line_start
|
||||
return token
|
||||
|
||||
if char == "\\" and body[position + 1 : position + 4] == '"""':
|
||||
current_line += body[chunk_start:position]
|
||||
chunk_start = position + 1 # skip only slash
|
||||
position += 4
|
||||
continue
|
||||
|
||||
if char in "\r\n":
|
||||
current_line += body[chunk_start:position]
|
||||
block_lines.append(current_line)
|
||||
|
||||
if char == "\r" and body[position + 1 : position + 2] == "\n":
|
||||
position += 2
|
||||
else:
|
||||
position += 1
|
||||
|
||||
current_line = ""
|
||||
chunk_start = line_start = position
|
||||
continue
|
||||
|
||||
if is_unicode_scalar_value(char):
|
||||
position += 1
|
||||
elif is_supplementary_code_point(body, position):
|
||||
position += 2
|
||||
else:
|
||||
raise GraphQLSyntaxError(
|
||||
self.source,
|
||||
position,
|
||||
"Invalid character within String:"
|
||||
f" {self.print_code_point_at(position)}.",
|
||||
)
|
||||
|
||||
raise GraphQLSyntaxError(self.source, position, "Unterminated string.")
|
||||
|
||||
def read_name(self, start: int) -> Token:
|
||||
"""Read an alphanumeric + underscore name from the source."""
|
||||
body = self.source.body
|
||||
body_length = len(body)
|
||||
position = start + 1
|
||||
|
||||
while position < body_length:
|
||||
char = body[position]
|
||||
if not is_name_continue(char):
|
||||
break
|
||||
position += 1
|
||||
|
||||
return self.create_token(TokenKind.NAME, start, position, body[start:position])
|
||||
|
||||
|
||||
_punctuator_token_kinds = frozenset(
|
||||
[
|
||||
TokenKind.BANG,
|
||||
TokenKind.DOLLAR,
|
||||
TokenKind.AMP,
|
||||
TokenKind.PAREN_L,
|
||||
TokenKind.PAREN_R,
|
||||
TokenKind.SPREAD,
|
||||
TokenKind.COLON,
|
||||
TokenKind.EQUALS,
|
||||
TokenKind.AT,
|
||||
TokenKind.BRACKET_L,
|
||||
TokenKind.BRACKET_R,
|
||||
TokenKind.BRACE_L,
|
||||
TokenKind.PIPE,
|
||||
TokenKind.BRACE_R,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def is_punctuator_token_kind(kind: TokenKind) -> bool:
|
||||
"""Check whether the given token kind corresponds to a punctuator.
|
||||
|
||||
For internal use only.
|
||||
"""
|
||||
return kind in _punctuator_token_kinds
|
||||
|
||||
|
||||
_KIND_FOR_PUNCT = {
|
||||
"!": TokenKind.BANG,
|
||||
"$": TokenKind.DOLLAR,
|
||||
"&": TokenKind.AMP,
|
||||
"(": TokenKind.PAREN_L,
|
||||
")": TokenKind.PAREN_R,
|
||||
":": TokenKind.COLON,
|
||||
"=": TokenKind.EQUALS,
|
||||
"@": TokenKind.AT,
|
||||
"[": TokenKind.BRACKET_L,
|
||||
"]": TokenKind.BRACKET_R,
|
||||
"{": TokenKind.BRACE_L,
|
||||
"}": TokenKind.BRACE_R,
|
||||
"|": TokenKind.PIPE,
|
||||
}
|
||||
|
||||
|
||||
_ESCAPED_CHARS = {
|
||||
'"': '"',
|
||||
"/": "/",
|
||||
"\\": "\\",
|
||||
"b": "\b",
|
||||
"f": "\f",
|
||||
"n": "\n",
|
||||
"r": "\r",
|
||||
"t": "\t",
|
||||
}
|
||||
|
||||
|
||||
def read_16_bit_hex_code(body: str, position: int) -> int:
|
||||
"""Read a 16bit hexadecimal string and return its positive integer value (0-65535).
|
||||
|
||||
Reads four hexadecimal characters and returns the positive integer that 16bit
|
||||
hexadecimal string represents. For example, "000f" will return 15, and "dead"
|
||||
will return 57005.
|
||||
|
||||
Returns a negative number if any char was not a valid hexadecimal digit.
|
||||
"""
|
||||
# read_hex_digit() returns -1 on error. ORing a negative value with any other
|
||||
# value always produces a negative value.
|
||||
return (
|
||||
read_hex_digit(body[position]) << 12
|
||||
| read_hex_digit(body[position + 1]) << 8
|
||||
| read_hex_digit(body[position + 2]) << 4
|
||||
| read_hex_digit(body[position + 3])
|
||||
)
|
||||
|
||||
|
||||
def read_hex_digit(char: str) -> int:
|
||||
"""Read a hexadecimal character and returns its positive integer value (0-15).
|
||||
|
||||
'0' becomes 0, '9' becomes 9
|
||||
'A' becomes 10, 'F' becomes 15
|
||||
'a' becomes 10, 'f' becomes 15
|
||||
|
||||
Returns -1 if the provided character code was not a valid hexadecimal digit.
|
||||
"""
|
||||
if "0" <= char <= "9":
|
||||
return ord(char) - 48
|
||||
elif "A" <= char <= "F":
|
||||
return ord(char) - 55
|
||||
elif "a" <= char <= "f":
|
||||
return ord(char) - 87
|
||||
return -1
|
||||
|
||||
|
||||
def is_unicode_scalar_value(char: str) -> bool:
|
||||
"""Check whether this is a Unicode scalar value.
|
||||
|
||||
A Unicode scalar value is any Unicode code point except surrogate code
|
||||
points. In other words, the inclusive ranges of values 0x0000 to 0xD7FF and
|
||||
0xE000 to 0x10FFFF.
|
||||
"""
|
||||
return "\x00" <= char <= "\ud7ff" or "\ue000" <= char <= "\U0010ffff"
|
||||
|
||||
|
||||
def is_supplementary_code_point(body: str, location: int) -> bool:
|
||||
"""
|
||||
Check whether the current location is a supplementary code point.
|
||||
|
||||
The GraphQL specification defines source text as a sequence of unicode scalar
|
||||
values (which Unicode defines to exclude surrogate code points).
|
||||
"""
|
||||
try:
|
||||
return (
|
||||
"\ud800" <= body[location] <= "\udbff"
|
||||
and "\udc00" <= body[location + 1] <= "\udfff"
|
||||
)
|
||||
except IndexError:
|
||||
return False
|
||||
@@ -0,0 +1,46 @@
|
||||
from typing import Any, NamedTuple, TYPE_CHECKING
|
||||
|
||||
try:
|
||||
from typing import TypedDict
|
||||
except ImportError: # Python < 3.8
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .source import Source # noqa: F401
|
||||
|
||||
__all__ = ["get_location", "SourceLocation", "FormattedSourceLocation"]
|
||||
|
||||
|
||||
class FormattedSourceLocation(TypedDict):
|
||||
"""Formatted source location"""
|
||||
|
||||
line: int
|
||||
column: int
|
||||
|
||||
|
||||
class SourceLocation(NamedTuple):
|
||||
"""Represents a location in a Source."""
|
||||
|
||||
line: int
|
||||
column: int
|
||||
|
||||
@property
|
||||
def formatted(self) -> FormattedSourceLocation:
|
||||
return dict(line=self.line, column=self.column)
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if isinstance(other, dict):
|
||||
return self.formatted == other
|
||||
return tuple(self) == other
|
||||
|
||||
def __ne__(self, other: Any) -> bool:
|
||||
return not self == other
|
||||
|
||||
|
||||
def get_location(source: "Source", position: int) -> SourceLocation:
|
||||
"""Get the line and column for a character position in the source.
|
||||
|
||||
Takes a Source and a UTF-8 character offset, and returns the corresponding line and
|
||||
column as a SourceLocation.
|
||||
"""
|
||||
return source.get_location(position)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,86 @@
|
||||
from .ast import (
|
||||
Node,
|
||||
DefinitionNode,
|
||||
ExecutableDefinitionNode,
|
||||
ListValueNode,
|
||||
ObjectValueNode,
|
||||
SchemaExtensionNode,
|
||||
SelectionNode,
|
||||
TypeDefinitionNode,
|
||||
TypeExtensionNode,
|
||||
TypeNode,
|
||||
TypeSystemDefinitionNode,
|
||||
ValueNode,
|
||||
VariableNode,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"is_definition_node",
|
||||
"is_executable_definition_node",
|
||||
"is_selection_node",
|
||||
"is_value_node",
|
||||
"is_const_value_node",
|
||||
"is_type_node",
|
||||
"is_type_system_definition_node",
|
||||
"is_type_definition_node",
|
||||
"is_type_system_extension_node",
|
||||
"is_type_extension_node",
|
||||
]
|
||||
|
||||
|
||||
def is_definition_node(node: Node) -> bool:
|
||||
"""Check whether the given node represents a definition."""
|
||||
return isinstance(node, DefinitionNode)
|
||||
|
||||
|
||||
def is_executable_definition_node(node: Node) -> bool:
|
||||
"""Check whether the given node represents an executable definition."""
|
||||
return isinstance(node, ExecutableDefinitionNode)
|
||||
|
||||
|
||||
def is_selection_node(node: Node) -> bool:
|
||||
"""Check whether the given node represents a selection."""
|
||||
return isinstance(node, SelectionNode)
|
||||
|
||||
|
||||
def is_value_node(node: Node) -> bool:
|
||||
"""Check whether the given node represents a value."""
|
||||
return isinstance(node, ValueNode)
|
||||
|
||||
|
||||
def is_const_value_node(node: Node) -> bool:
|
||||
"""Check whether the given node represents a constant value."""
|
||||
return is_value_node(node) and (
|
||||
any(is_const_value_node(value) for value in node.values)
|
||||
if isinstance(node, ListValueNode)
|
||||
else (
|
||||
any(is_const_value_node(field.value) for field in node.fields)
|
||||
if isinstance(node, ObjectValueNode)
|
||||
else not isinstance(node, VariableNode)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def is_type_node(node: Node) -> bool:
|
||||
"""Check whether the given node represents a type."""
|
||||
return isinstance(node, TypeNode)
|
||||
|
||||
|
||||
def is_type_system_definition_node(node: Node) -> bool:
|
||||
"""Check whether the given node represents a type system definition."""
|
||||
return isinstance(node, TypeSystemDefinitionNode)
|
||||
|
||||
|
||||
def is_type_definition_node(node: Node) -> bool:
|
||||
"""Check whether the given node represents a type definition."""
|
||||
return isinstance(node, TypeDefinitionNode)
|
||||
|
||||
|
||||
def is_type_system_extension_node(node: Node) -> bool:
|
||||
"""Check whether the given node represents a type system extension."""
|
||||
return isinstance(node, (SchemaExtensionNode, TypeExtensionNode))
|
||||
|
||||
|
||||
def is_type_extension_node(node: Node) -> bool:
|
||||
"""Check whether the given node represents a type extension."""
|
||||
return isinstance(node, TypeExtensionNode)
|
||||
@@ -0,0 +1,79 @@
|
||||
import re
|
||||
from typing import Optional, Tuple, cast
|
||||
|
||||
from .ast import Location
|
||||
from .location import SourceLocation, get_location
|
||||
from .source import Source
|
||||
|
||||
|
||||
__all__ = ["print_location", "print_source_location"]
|
||||
|
||||
|
||||
def print_location(location: Location) -> str:
|
||||
"""Render a helpful description of the location in the GraphQL Source document."""
|
||||
return print_source_location(
|
||||
location.source, get_location(location.source, location.start)
|
||||
)
|
||||
|
||||
|
||||
_re_newline = re.compile(r"\r\n|[\n\r]")
|
||||
|
||||
|
||||
def print_source_location(source: Source, source_location: SourceLocation) -> str:
|
||||
"""Render a helpful description of the location in the GraphQL Source document."""
|
||||
first_line_column_offset = source.location_offset.column - 1
|
||||
body = "".rjust(first_line_column_offset) + source.body
|
||||
|
||||
line_index = source_location.line - 1
|
||||
line_offset = source.location_offset.line - 1
|
||||
line_num = source_location.line + line_offset
|
||||
|
||||
column_offset = first_line_column_offset if source_location.line == 1 else 0
|
||||
column_num = source_location.column + column_offset
|
||||
location_str = f"{source.name}:{line_num}:{column_num}\n"
|
||||
|
||||
lines = _re_newline.split(body) # works a bit different from splitlines()
|
||||
location_line = lines[line_index]
|
||||
|
||||
# Special case for minified documents
|
||||
if len(location_line) > 120:
|
||||
sub_line_index, sub_line_column_num = divmod(column_num, 80)
|
||||
sub_lines = [
|
||||
location_line[i : i + 80] for i in range(0, len(location_line), 80)
|
||||
]
|
||||
|
||||
return location_str + print_prefixed_lines(
|
||||
(f"{line_num} |", sub_lines[0]),
|
||||
*[("|", sub_line) for sub_line in sub_lines[1 : sub_line_index + 1]],
|
||||
("|", "^".rjust(sub_line_column_num)),
|
||||
(
|
||||
"|",
|
||||
(
|
||||
sub_lines[sub_line_index + 1]
|
||||
if sub_line_index < len(sub_lines) - 1
|
||||
else None
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
return location_str + print_prefixed_lines(
|
||||
(f"{line_num - 1} |", lines[line_index - 1] if line_index > 0 else None),
|
||||
(f"{line_num} |", location_line),
|
||||
("|", "^".rjust(column_num)),
|
||||
(
|
||||
f"{line_num + 1} |",
|
||||
lines[line_index + 1] if line_index < len(lines) - 1 else None,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def print_prefixed_lines(*lines: Tuple[str, Optional[str]]) -> str:
|
||||
"""Print lines specified like this: ("prefix", "string")"""
|
||||
existing_lines = [
|
||||
cast(Tuple[str, str], line) for line in lines if line[1] is not None
|
||||
]
|
||||
pad_len = max(len(line[0]) for line in existing_lines)
|
||||
return "\n".join(
|
||||
prefix.rjust(pad_len) + (" " + line if line else "")
|
||||
for prefix, line in existing_lines
|
||||
)
|
||||
@@ -0,0 +1,83 @@
|
||||
__all__ = ["print_string"]
|
||||
|
||||
|
||||
def print_string(s: str) -> str:
|
||||
"""Print a string as a GraphQL StringValue literal.
|
||||
|
||||
Replaces control characters and excluded characters (" U+0022 and \\ U+005C)
|
||||
with escape sequences.
|
||||
"""
|
||||
if not isinstance(s, str):
|
||||
s = str(s)
|
||||
return f'"{s.translate(escape_sequences)}"'
|
||||
|
||||
|
||||
escape_sequences = {
|
||||
0x00: "\\u0000",
|
||||
0x01: "\\u0001",
|
||||
0x02: "\\u0002",
|
||||
0x03: "\\u0003",
|
||||
0x04: "\\u0004",
|
||||
0x05: "\\u0005",
|
||||
0x06: "\\u0006",
|
||||
0x07: "\\u0007",
|
||||
0x08: "\\b",
|
||||
0x09: "\\t",
|
||||
0x0A: "\\n",
|
||||
0x0B: "\\u000B",
|
||||
0x0C: "\\f",
|
||||
0x0D: "\\r",
|
||||
0x0E: "\\u000E",
|
||||
0x0F: "\\u000F",
|
||||
0x10: "\\u0010",
|
||||
0x11: "\\u0011",
|
||||
0x12: "\\u0012",
|
||||
0x13: "\\u0013",
|
||||
0x14: "\\u0014",
|
||||
0x15: "\\u0015",
|
||||
0x16: "\\u0016",
|
||||
0x17: "\\u0017",
|
||||
0x18: "\\u0018",
|
||||
0x19: "\\u0019",
|
||||
0x1A: "\\u001A",
|
||||
0x1B: "\\u001B",
|
||||
0x1C: "\\u001C",
|
||||
0x1D: "\\u001D",
|
||||
0x1E: "\\u001E",
|
||||
0x1F: "\\u001F",
|
||||
0x22: '\\"',
|
||||
0x5C: "\\\\",
|
||||
0x7F: "\\u007F",
|
||||
0x80: "\\u0080",
|
||||
0x81: "\\u0081",
|
||||
0x82: "\\u0082",
|
||||
0x83: "\\u0083",
|
||||
0x84: "\\u0084",
|
||||
0x85: "\\u0085",
|
||||
0x86: "\\u0086",
|
||||
0x87: "\\u0087",
|
||||
0x88: "\\u0088",
|
||||
0x89: "\\u0089",
|
||||
0x8A: "\\u008A",
|
||||
0x8B: "\\u008B",
|
||||
0x8C: "\\u008C",
|
||||
0x8D: "\\u008D",
|
||||
0x8E: "\\u008E",
|
||||
0x8F: "\\u008F",
|
||||
0x90: "\\u0090",
|
||||
0x91: "\\u0091",
|
||||
0x92: "\\u0092",
|
||||
0x93: "\\u0093",
|
||||
0x94: "\\u0094",
|
||||
0x95: "\\u0095",
|
||||
0x96: "\\u0096",
|
||||
0x97: "\\u0097",
|
||||
0x98: "\\u0098",
|
||||
0x99: "\\u0099",
|
||||
0x9A: "\\u009A",
|
||||
0x9B: "\\u009B",
|
||||
0x9C: "\\u009C",
|
||||
0x9D: "\\u009D",
|
||||
0x9E: "\\u009E",
|
||||
0x9F: "\\u009F",
|
||||
}
|
||||
@@ -0,0 +1,428 @@
|
||||
from typing import Any, Collection, Optional
|
||||
|
||||
from ..language.ast import Node, OperationType
|
||||
from .block_string import print_block_string
|
||||
from .print_string import print_string
|
||||
from .visitor import visit, Visitor
|
||||
|
||||
__all__ = ["print_ast"]
|
||||
|
||||
|
||||
MAX_LINE_LENGTH = 80
|
||||
|
||||
Strings = Collection[str]
|
||||
|
||||
|
||||
class PrintedNode:
|
||||
"""A union type for all nodes that have been processed by the printer."""
|
||||
|
||||
alias: str
|
||||
arguments: Strings
|
||||
block: bool
|
||||
default_value: str
|
||||
definitions: Strings
|
||||
description: str
|
||||
directives: str
|
||||
fields: Strings
|
||||
interfaces: Strings
|
||||
locations: Strings
|
||||
name: str
|
||||
operation: OperationType
|
||||
operation_types: Strings
|
||||
repeatable: bool
|
||||
selection_set: str
|
||||
selections: Strings
|
||||
type: str
|
||||
type_condition: str
|
||||
types: Strings
|
||||
value: str
|
||||
values: Strings
|
||||
variable: str
|
||||
variable_definitions: Strings
|
||||
|
||||
|
||||
def print_ast(ast: Node) -> str:
|
||||
"""Convert an AST into a string.
|
||||
|
||||
The conversion is done using a set of reasonable formatting rules.
|
||||
"""
|
||||
return visit(ast, PrintAstVisitor())
|
||||
|
||||
|
||||
class PrintAstVisitor(Visitor):
|
||||
@staticmethod
|
||||
def leave_name(node: PrintedNode, *_args: Any) -> str:
|
||||
return node.value
|
||||
|
||||
@staticmethod
|
||||
def leave_variable(node: PrintedNode, *_args: Any) -> str:
|
||||
return f"${node.name}"
|
||||
|
||||
# Document
|
||||
|
||||
@staticmethod
|
||||
def leave_document(node: PrintedNode, *_args: Any) -> str:
|
||||
return join(node.definitions, "\n\n")
|
||||
|
||||
@staticmethod
|
||||
def leave_operation_definition(node: PrintedNode, *_args: Any) -> str:
|
||||
var_defs = wrap("(", join(node.variable_definitions, ", "), ")")
|
||||
prefix = join(
|
||||
(
|
||||
node.operation.value,
|
||||
join((node.name, var_defs)),
|
||||
join(node.directives, " "),
|
||||
),
|
||||
" ",
|
||||
)
|
||||
# Anonymous queries with no directives or variable definitions can use the
|
||||
# query short form.
|
||||
return ("" if prefix == "query" else prefix + " ") + node.selection_set
|
||||
|
||||
@staticmethod
|
||||
def leave_variable_definition(node: PrintedNode, *_args: Any) -> str:
|
||||
return (
|
||||
f"{node.variable}: {node.type}"
|
||||
f"{wrap(' = ', node.default_value)}"
|
||||
f"{wrap(' ', join(node.directives, ' '))}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def leave_selection_set(node: PrintedNode, *_args: Any) -> str:
|
||||
return block(node.selections)
|
||||
|
||||
@staticmethod
|
||||
def leave_field(node: PrintedNode, *_args: Any) -> str:
|
||||
prefix = wrap("", node.alias, ": ") + node.name
|
||||
args_line = prefix + wrap("(", join(node.arguments, ", "), ")")
|
||||
|
||||
if len(args_line) > MAX_LINE_LENGTH:
|
||||
args_line = prefix + wrap("(\n", indent(join(node.arguments, "\n")), "\n)")
|
||||
|
||||
return join((args_line, join(node.directives, " "), node.selection_set), " ")
|
||||
|
||||
@staticmethod
|
||||
def leave_argument(node: PrintedNode, *_args: Any) -> str:
|
||||
return f"{node.name}: {node.value}"
|
||||
|
||||
# Fragments
|
||||
|
||||
@staticmethod
|
||||
def leave_fragment_spread(node: PrintedNode, *_args: Any) -> str:
|
||||
return f"...{node.name}{wrap(' ', join(node.directives, ' '))}"
|
||||
|
||||
@staticmethod
|
||||
def leave_inline_fragment(node: PrintedNode, *_args: Any) -> str:
|
||||
return join(
|
||||
(
|
||||
"...",
|
||||
wrap("on ", node.type_condition),
|
||||
join(node.directives, " "),
|
||||
node.selection_set,
|
||||
),
|
||||
" ",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def leave_fragment_definition(node: PrintedNode, *_args: Any) -> str:
|
||||
# Note: fragment variable definitions are deprecated and will be removed in v3.3
|
||||
return (
|
||||
f"fragment {node.name}"
|
||||
f"{wrap('(', join(node.variable_definitions, ', '), ')')}"
|
||||
f" on {node.type_condition}"
|
||||
f" {wrap('', join(node.directives, ' '), ' ')}"
|
||||
f"{node.selection_set}"
|
||||
)
|
||||
|
||||
# Value
|
||||
|
||||
@staticmethod
|
||||
def leave_int_value(node: PrintedNode, *_args: Any) -> str:
|
||||
return node.value
|
||||
|
||||
@staticmethod
|
||||
def leave_float_value(node: PrintedNode, *_args: Any) -> str:
|
||||
return node.value
|
||||
|
||||
@staticmethod
|
||||
def leave_string_value(node: PrintedNode, *_args: Any) -> str:
|
||||
if node.block:
|
||||
return print_block_string(node.value)
|
||||
return print_string(node.value)
|
||||
|
||||
@staticmethod
|
||||
def leave_boolean_value(node: PrintedNode, *_args: Any) -> str:
|
||||
return "true" if node.value else "false"
|
||||
|
||||
@staticmethod
|
||||
def leave_null_value(_node: PrintedNode, *_args: Any) -> str:
|
||||
return "null"
|
||||
|
||||
@staticmethod
|
||||
def leave_enum_value(node: PrintedNode, *_args: Any) -> str:
|
||||
return node.value
|
||||
|
||||
@staticmethod
|
||||
def leave_list_value(node: PrintedNode, *_args: Any) -> str:
|
||||
return f"[{join(node.values, ', ')}]"
|
||||
|
||||
@staticmethod
|
||||
def leave_object_value(node: PrintedNode, *_args: Any) -> str:
|
||||
return f"{{{join(node.fields, ', ')}}}"
|
||||
|
||||
@staticmethod
|
||||
def leave_object_field(node: PrintedNode, *_args: Any) -> str:
|
||||
return f"{node.name}: {node.value}"
|
||||
|
||||
# Directive
|
||||
|
||||
@staticmethod
|
||||
def leave_directive(node: PrintedNode, *_args: Any) -> str:
|
||||
return f"@{node.name}{wrap('(', join(node.arguments, ', '), ')')}"
|
||||
|
||||
# Type
|
||||
|
||||
@staticmethod
|
||||
def leave_named_type(node: PrintedNode, *_args: Any) -> str:
|
||||
return node.name
|
||||
|
||||
@staticmethod
|
||||
def leave_list_type(node: PrintedNode, *_args: Any) -> str:
|
||||
return f"[{node.type}]"
|
||||
|
||||
@staticmethod
|
||||
def leave_non_null_type(node: PrintedNode, *_args: Any) -> str:
|
||||
return f"{node.type}!"
|
||||
|
||||
# Type System Definitions
|
||||
|
||||
@staticmethod
|
||||
def leave_schema_definition(node: PrintedNode, *_args: Any) -> str:
|
||||
return wrap("", node.description, "\n") + join(
|
||||
(
|
||||
"schema",
|
||||
join(node.directives, " "),
|
||||
block(node.operation_types),
|
||||
),
|
||||
" ",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def leave_operation_type_definition(node: PrintedNode, *_args: Any) -> str:
|
||||
return f"{node.operation.value}: {node.type}"
|
||||
|
||||
@staticmethod
|
||||
def leave_scalar_type_definition(node: PrintedNode, *_args: Any) -> str:
|
||||
return wrap("", node.description, "\n") + join(
|
||||
(
|
||||
"scalar",
|
||||
node.name,
|
||||
join(node.directives, " "),
|
||||
),
|
||||
" ",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def leave_object_type_definition(node: PrintedNode, *_args: Any) -> str:
|
||||
return wrap("", node.description, "\n") + join(
|
||||
(
|
||||
"type",
|
||||
node.name,
|
||||
wrap("implements ", join(node.interfaces, " & ")),
|
||||
join(node.directives, " "),
|
||||
block(node.fields),
|
||||
),
|
||||
" ",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def leave_field_definition(node: PrintedNode, *_args: Any) -> str:
|
||||
args = node.arguments
|
||||
args = (
|
||||
wrap("(\n", indent(join(args, "\n")), "\n)")
|
||||
if has_multiline_items(args)
|
||||
else wrap("(", join(args, ", "), ")")
|
||||
)
|
||||
directives = wrap(" ", join(node.directives, " "))
|
||||
return (
|
||||
wrap("", node.description, "\n")
|
||||
+ f"{node.name}{args}: {node.type}{directives}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def leave_input_value_definition(node: PrintedNode, *_args: Any) -> str:
|
||||
return wrap("", node.description, "\n") + join(
|
||||
(
|
||||
f"{node.name}: {node.type}",
|
||||
wrap("= ", node.default_value),
|
||||
join(node.directives, " "),
|
||||
),
|
||||
" ",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def leave_interface_type_definition(node: PrintedNode, *_args: Any) -> str:
|
||||
return wrap("", node.description, "\n") + join(
|
||||
(
|
||||
"interface",
|
||||
node.name,
|
||||
wrap("implements ", join(node.interfaces, " & ")),
|
||||
join(node.directives, " "),
|
||||
block(node.fields),
|
||||
),
|
||||
" ",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def leave_union_type_definition(node: PrintedNode, *_args: Any) -> str:
|
||||
return wrap("", node.description, "\n") + join(
|
||||
(
|
||||
"union",
|
||||
node.name,
|
||||
join(node.directives, " "),
|
||||
wrap("= ", join(node.types, " | ")),
|
||||
),
|
||||
" ",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def leave_enum_type_definition(node: PrintedNode, *_args: Any) -> str:
|
||||
return wrap("", node.description, "\n") + join(
|
||||
("enum", node.name, join(node.directives, " "), block(node.values)), " "
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def leave_enum_value_definition(node: PrintedNode, *_args: Any) -> str:
|
||||
return wrap("", node.description, "\n") + join(
|
||||
(node.name, join(node.directives, " ")), " "
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def leave_input_object_type_definition(node: PrintedNode, *_args: Any) -> str:
|
||||
return wrap("", node.description, "\n") + join(
|
||||
("input", node.name, join(node.directives, " "), block(node.fields)), " "
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def leave_directive_definition(node: PrintedNode, *_args: Any) -> str:
|
||||
args = node.arguments
|
||||
args = (
|
||||
wrap("(\n", indent(join(args, "\n")), "\n)")
|
||||
if has_multiline_items(args)
|
||||
else wrap("(", join(args, ", "), ")")
|
||||
)
|
||||
repeatable = " repeatable" if node.repeatable else ""
|
||||
locations = join(node.locations, " | ")
|
||||
return (
|
||||
wrap("", node.description, "\n")
|
||||
+ f"directive @{node.name}{args}{repeatable} on {locations}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def leave_schema_extension(node: PrintedNode, *_args: Any) -> str:
|
||||
return join(
|
||||
("extend schema", join(node.directives, " "), block(node.operation_types)),
|
||||
" ",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def leave_scalar_type_extension(node: PrintedNode, *_args: Any) -> str:
|
||||
return join(("extend scalar", node.name, join(node.directives, " ")), " ")
|
||||
|
||||
@staticmethod
|
||||
def leave_object_type_extension(node: PrintedNode, *_args: Any) -> str:
|
||||
return join(
|
||||
(
|
||||
"extend type",
|
||||
node.name,
|
||||
wrap("implements ", join(node.interfaces, " & ")),
|
||||
join(node.directives, " "),
|
||||
block(node.fields),
|
||||
),
|
||||
" ",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def leave_interface_type_extension(node: PrintedNode, *_args: Any) -> str:
|
||||
return join(
|
||||
(
|
||||
"extend interface",
|
||||
node.name,
|
||||
wrap("implements ", join(node.interfaces, " & ")),
|
||||
join(node.directives, " "),
|
||||
block(node.fields),
|
||||
),
|
||||
" ",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def leave_union_type_extension(node: PrintedNode, *_args: Any) -> str:
|
||||
return join(
|
||||
(
|
||||
"extend union",
|
||||
node.name,
|
||||
join(node.directives, " "),
|
||||
wrap("= ", join(node.types, " | ")),
|
||||
),
|
||||
" ",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def leave_enum_type_extension(node: PrintedNode, *_args: Any) -> str:
|
||||
return join(
|
||||
("extend enum", node.name, join(node.directives, " "), block(node.values)),
|
||||
" ",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def leave_input_object_type_extension(node: PrintedNode, *_args: Any) -> str:
|
||||
return join(
|
||||
("extend input", node.name, join(node.directives, " "), block(node.fields)),
|
||||
" ",
|
||||
)
|
||||
|
||||
|
||||
def join(strings: Optional[Strings], separator: str = "") -> str:
|
||||
"""Join strings in a given collection.
|
||||
|
||||
Return an empty string if it is None or empty, otherwise join all items together
|
||||
separated by separator if provided.
|
||||
"""
|
||||
return separator.join(s for s in strings if s) if strings else ""
|
||||
|
||||
|
||||
def block(strings: Optional[Strings]) -> str:
|
||||
"""Return strings inside a block.
|
||||
|
||||
Given a collection of strings, return a string with each item on its own line,
|
||||
wrapped in an indented "{ }" block.
|
||||
"""
|
||||
return wrap("{\n", indent(join(strings, "\n")), "\n}")
|
||||
|
||||
|
||||
def wrap(start: str, string: Optional[str], end: str = "") -> str:
|
||||
"""Wrap string inside other strings at start and end.
|
||||
|
||||
If the string is not None or empty, then wrap with start and end, otherwise return
|
||||
an empty string.
|
||||
"""
|
||||
return f"{start}{string}{end}" if string else ""
|
||||
|
||||
|
||||
def indent(string: str) -> str:
|
||||
"""Indent string with two spaces.
|
||||
|
||||
If the string is not None or empty, add two spaces at the beginning of every line
|
||||
inside the string.
|
||||
"""
|
||||
return wrap(" ", string.replace("\n", "\n "))
|
||||
|
||||
|
||||
def is_multiline(string: str) -> bool:
|
||||
"""Check whether a string consists of multiple lines."""
|
||||
return "\n" in string
|
||||
|
||||
|
||||
def has_multiline_items(strings: Optional[Strings]) -> bool:
|
||||
"""Check whether one of the items in the list has multiple lines."""
|
||||
return any(is_multiline(item) for item in strings) if strings else False
|
||||
@@ -0,0 +1,70 @@
|
||||
from typing import Any
|
||||
|
||||
from .location import SourceLocation
|
||||
|
||||
__all__ = ["Source", "is_source"]
|
||||
|
||||
|
||||
class Source:
|
||||
"""A representation of source input to GraphQL."""
|
||||
|
||||
# allow custom attributes and weak references (not used internally)
|
||||
__slots__ = "__weakref__", "__dict__", "body", "name", "location_offset"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
body: str,
|
||||
name: str = "GraphQL request",
|
||||
location_offset: SourceLocation = SourceLocation(1, 1),
|
||||
) -> None:
|
||||
"""Initialize source input.
|
||||
|
||||
The ``name`` and ``location_offset`` parameters are optional, but they are
|
||||
useful for clients who store GraphQL documents in source files. For example,
|
||||
if the GraphQL input starts at line 40 in a file named ``Foo.graphql``, it might
|
||||
be useful for ``name`` to be ``"Foo.graphql"`` and location to be ``(40, 0)``.
|
||||
|
||||
The ``line`` and ``column`` attributes in ``location_offset`` are 1-indexed.
|
||||
"""
|
||||
self.body = body
|
||||
self.name = name
|
||||
if not isinstance(location_offset, SourceLocation):
|
||||
location_offset = SourceLocation._make(location_offset)
|
||||
if location_offset.line <= 0:
|
||||
raise ValueError(
|
||||
"line in location_offset is 1-indexed and must be positive."
|
||||
)
|
||||
if location_offset.column <= 0:
|
||||
raise ValueError(
|
||||
"column in location_offset is 1-indexed and must be positive."
|
||||
)
|
||||
self.location_offset = location_offset
|
||||
|
||||
def get_location(self, position: int) -> SourceLocation:
|
||||
lines = self.body[:position].splitlines()
|
||||
if lines:
|
||||
line = len(lines)
|
||||
column = len(lines[-1]) + 1
|
||||
else:
|
||||
line = 1
|
||||
column = 1
|
||||
return SourceLocation(line, column)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<{self.__class__.__name__} name={self.name!r}>"
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
return (isinstance(other, Source) and other.body == self.body) or (
|
||||
isinstance(other, str) and other == self.body
|
||||
)
|
||||
|
||||
def __ne__(self, other: Any) -> bool:
|
||||
return not self == other
|
||||
|
||||
|
||||
def is_source(source: Any) -> bool:
|
||||
"""Test if the given value is a Source object.
|
||||
|
||||
For internal use only.
|
||||
"""
|
||||
return isinstance(source, Source)
|
||||
@@ -0,0 +1,30 @@
|
||||
from enum import Enum
|
||||
|
||||
__all__ = ["TokenKind"]
|
||||
|
||||
|
||||
class TokenKind(Enum):
|
||||
"""The different kinds of tokens that the lexer emits"""
|
||||
|
||||
SOF = "<SOF>"
|
||||
EOF = "<EOF>"
|
||||
BANG = "!"
|
||||
DOLLAR = "$"
|
||||
AMP = "&"
|
||||
PAREN_L = "("
|
||||
PAREN_R = ")"
|
||||
SPREAD = "..."
|
||||
COLON = ":"
|
||||
EQUALS = "="
|
||||
AT = "@"
|
||||
BRACKET_L = "["
|
||||
BRACKET_R = "]"
|
||||
BRACE_L = "{"
|
||||
PIPE = "|"
|
||||
BRACE_R = "}"
|
||||
NAME = "Name"
|
||||
INT = "Int"
|
||||
FLOAT = "Float"
|
||||
STRING = "String"
|
||||
BLOCK_STRING = "BlockString"
|
||||
COMMENT = "Comment"
|
||||
@@ -0,0 +1,375 @@
|
||||
from copy import copy
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Collection,
|
||||
Dict,
|
||||
List,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
from ..pyutils import inspect, snake_to_camel
|
||||
from . import ast
|
||||
from .ast import QUERY_DOCUMENT_KEYS, Node
|
||||
|
||||
__all__ = [
|
||||
"Visitor",
|
||||
"ParallelVisitor",
|
||||
"VisitorAction",
|
||||
"visit",
|
||||
"BREAK",
|
||||
"SKIP",
|
||||
"REMOVE",
|
||||
"IDLE",
|
||||
]
|
||||
|
||||
|
||||
class VisitorActionEnum(Enum):
|
||||
"""Special return values for the visitor methods.
|
||||
|
||||
You can also use the values of this enum directly.
|
||||
"""
|
||||
|
||||
BREAK = True
|
||||
SKIP = False
|
||||
REMOVE = Ellipsis
|
||||
|
||||
|
||||
VisitorAction = Optional[VisitorActionEnum]
|
||||
|
||||
# Note that in GraphQL.js these are defined differently:
|
||||
# BREAK = {}, SKIP = false, REMOVE = null, IDLE = undefined
|
||||
|
||||
BREAK = VisitorActionEnum.BREAK
|
||||
SKIP = VisitorActionEnum.SKIP
|
||||
REMOVE = VisitorActionEnum.REMOVE
|
||||
IDLE = None
|
||||
|
||||
VisitorKeyMap = Dict[str, Tuple[str, ...]]
|
||||
|
||||
|
||||
class EnterLeaveVisitor(NamedTuple):
|
||||
"""Visitor with functions for entering and leaving."""
|
||||
|
||||
enter: Optional[Callable[..., Optional[VisitorAction]]]
|
||||
leave: Optional[Callable[..., Optional[VisitorAction]]]
|
||||
|
||||
|
||||
class Visitor:
|
||||
"""Visitor that walks through an AST.
|
||||
|
||||
Visitors can define two generic methods "enter" and "leave". The former will be
|
||||
called when a node is entered in the traversal, the latter is called after visiting
|
||||
the node and its child nodes. These methods have the following signature::
|
||||
|
||||
def enter(self, node, key, parent, path, ancestors):
|
||||
# The return value has the following meaning:
|
||||
# IDLE (None): no action
|
||||
# SKIP: skip visiting this node
|
||||
# BREAK: stop visiting altogether
|
||||
# REMOVE: delete this node
|
||||
# any other value: replace this node with the returned value
|
||||
return
|
||||
|
||||
def leave(self, node, key, parent, path, ancestors):
|
||||
# The return value has the following meaning:
|
||||
# IDLE (None) or SKIP: no action
|
||||
# BREAK: stop visiting altogether
|
||||
# REMOVE: delete this node
|
||||
# any other value: replace this node with the returned value
|
||||
return
|
||||
|
||||
The parameters have the following meaning:
|
||||
|
||||
:arg node: The current node being visiting.
|
||||
:arg key: The index or key to this node from the parent node or Array.
|
||||
:arg parent: the parent immediately above this node, which may be an Array.
|
||||
:arg path: The key path to get to this node from the root node.
|
||||
:arg ancestors: All nodes and Arrays visited before reaching parent
|
||||
of this node. These correspond to array indices in ``path``.
|
||||
Note: ancestors includes arrays which contain the parent of visited node.
|
||||
|
||||
You can also define node kind specific methods by suffixing them with an underscore
|
||||
followed by the kind of the node to be visited. For instance, to visit ``field``
|
||||
nodes, you would defined the methods ``enter_field()`` and/or ``leave_field()``,
|
||||
with the same signature as above. If no kind specific method has been defined
|
||||
for a given node, the generic method is called.
|
||||
"""
|
||||
|
||||
# Provide special return values as attributes
|
||||
BREAK, SKIP, REMOVE, IDLE = BREAK, SKIP, REMOVE, IDLE
|
||||
|
||||
enter_leave_map: Dict[str, EnterLeaveVisitor]
|
||||
|
||||
def __init_subclass__(cls) -> None:
|
||||
"""Verify that all defined handlers are valid."""
|
||||
super().__init_subclass__()
|
||||
for attr, val in cls.__dict__.items():
|
||||
if attr.startswith("_"):
|
||||
continue
|
||||
attr_kind = attr.split("_", 1)
|
||||
if len(attr_kind) < 2:
|
||||
kind: Optional[str] = None
|
||||
else:
|
||||
attr, kind = attr_kind
|
||||
if attr in ("enter", "leave") and kind:
|
||||
name = snake_to_camel(kind) + "Node"
|
||||
node_cls = getattr(ast, name, None)
|
||||
if (
|
||||
not node_cls
|
||||
or not isinstance(node_cls, type)
|
||||
or not issubclass(node_cls, Node)
|
||||
):
|
||||
raise TypeError(f"Invalid AST node kind: {kind}.")
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.enter_leave_map = {}
|
||||
|
||||
def get_enter_leave_for_kind(self, kind: str) -> EnterLeaveVisitor:
|
||||
"""Given a node kind, return the EnterLeaveVisitor for that kind."""
|
||||
try:
|
||||
return self.enter_leave_map[kind]
|
||||
except KeyError:
|
||||
enter_fn = getattr(self, f"enter_{kind}", None)
|
||||
if not enter_fn:
|
||||
enter_fn = getattr(self, "enter", None)
|
||||
leave_fn = getattr(self, f"leave_{kind}", None)
|
||||
if not leave_fn:
|
||||
leave_fn = getattr(self, "leave", None)
|
||||
enter_leave = EnterLeaveVisitor(enter_fn, leave_fn)
|
||||
self.enter_leave_map[kind] = enter_leave
|
||||
return enter_leave
|
||||
|
||||
def get_visit_fn(
|
||||
self, kind: str, is_leaving: bool = False
|
||||
) -> Optional[Callable[..., Optional[VisitorAction]]]:
|
||||
"""Get the visit function for the given node kind and direction.
|
||||
|
||||
.. deprecated:: 3.2
|
||||
Please use ``get_enter_leave_for_kind`` instead. Will be removed in v3.3.
|
||||
"""
|
||||
enter_leave = self.get_enter_leave_for_kind(kind)
|
||||
return enter_leave.leave if is_leaving else enter_leave.enter
|
||||
|
||||
|
||||
class Stack(NamedTuple):
|
||||
"""A stack for the visit function."""
|
||||
|
||||
in_array: bool
|
||||
idx: int
|
||||
keys: Tuple[Node, ...]
|
||||
edits: List[Tuple[Union[int, str], Node]]
|
||||
prev: Any # 'Stack' (python/mypy/issues/731)
|
||||
|
||||
|
||||
def visit(
|
||||
root: Node, visitor: Visitor, visitor_keys: Optional[VisitorKeyMap] = None
|
||||
) -> Any:
|
||||
"""Visit each node in an AST.
|
||||
|
||||
:func:`~.visit` will walk through an AST using a depth-first traversal, calling the
|
||||
visitor's enter methods at each node in the traversal, and calling the leave methods
|
||||
after visiting that node and all of its child nodes.
|
||||
|
||||
By returning different values from the enter and leave methods, the behavior of the
|
||||
visitor can be altered, including skipping over a sub-tree of the AST (by returning
|
||||
False), editing the AST by returning a value or None to remove the value, or to stop
|
||||
the whole traversal by returning :data:`~.BREAK`.
|
||||
|
||||
When using :func:`~.visit` to edit an AST, the original AST will not be modified,
|
||||
and a new version of the AST with the changes applied will be returned from the
|
||||
visit function.
|
||||
|
||||
To customize the node attributes to be used for traversal, you can provide a
|
||||
dictionary visitor_keys mapping node kinds to node attributes.
|
||||
"""
|
||||
if not isinstance(root, Node):
|
||||
raise TypeError(f"Not an AST Node: {inspect(root)}.")
|
||||
if not isinstance(visitor, Visitor):
|
||||
raise TypeError(f"Not an AST Visitor: {inspect(visitor)}.")
|
||||
if visitor_keys is None:
|
||||
visitor_keys = QUERY_DOCUMENT_KEYS
|
||||
|
||||
stack: Any = None
|
||||
in_array = False
|
||||
keys: Tuple[Node, ...] = (root,)
|
||||
idx = -1
|
||||
edits: List[Any] = []
|
||||
node: Any = root
|
||||
key: Any = None
|
||||
parent: Any = None
|
||||
path: List[Any] = []
|
||||
path_append = path.append
|
||||
path_pop = path.pop
|
||||
ancestors: List[Any] = []
|
||||
ancestors_append = ancestors.append
|
||||
ancestors_pop = ancestors.pop
|
||||
|
||||
while True:
|
||||
idx += 1
|
||||
is_leaving = idx == len(keys)
|
||||
is_edited = is_leaving and edits
|
||||
if is_leaving:
|
||||
key = path[-1] if ancestors else None
|
||||
node = parent
|
||||
parent = ancestors_pop() if ancestors else None
|
||||
if is_edited:
|
||||
if in_array:
|
||||
node = list(node)
|
||||
edit_offset = 0
|
||||
for edit_key, edit_value in edits:
|
||||
array_key = edit_key - edit_offset
|
||||
if edit_value is REMOVE or edit_value is Ellipsis:
|
||||
node.pop(array_key)
|
||||
edit_offset += 1
|
||||
else:
|
||||
node[array_key] = edit_value
|
||||
node = tuple(node)
|
||||
else:
|
||||
node = copy(node)
|
||||
for edit_key, edit_value in edits:
|
||||
setattr(node, edit_key, edit_value)
|
||||
idx = stack.idx
|
||||
keys = stack.keys
|
||||
edits = stack.edits
|
||||
in_array = stack.in_array
|
||||
stack = stack.prev
|
||||
elif parent:
|
||||
if in_array:
|
||||
key = idx
|
||||
node = parent[key]
|
||||
else:
|
||||
key = keys[idx]
|
||||
node = getattr(parent, key, None)
|
||||
if node is None:
|
||||
continue
|
||||
path_append(key)
|
||||
|
||||
if isinstance(node, tuple):
|
||||
result = None
|
||||
else:
|
||||
if not isinstance(node, Node):
|
||||
raise TypeError(f"Invalid AST Node: {inspect(node)}.")
|
||||
enter_leave = visitor.get_enter_leave_for_kind(node.kind)
|
||||
visit_fn = enter_leave.leave if is_leaving else enter_leave.enter
|
||||
if visit_fn:
|
||||
result = visit_fn(node, key, parent, path, ancestors)
|
||||
|
||||
if result is BREAK or result is True:
|
||||
break
|
||||
|
||||
if result is SKIP or result is False:
|
||||
if not is_leaving:
|
||||
path_pop()
|
||||
continue
|
||||
|
||||
elif result is not None:
|
||||
edits.append((key, result))
|
||||
if not is_leaving:
|
||||
if isinstance(result, Node):
|
||||
node = result
|
||||
else:
|
||||
path_pop()
|
||||
continue
|
||||
else:
|
||||
result = None
|
||||
|
||||
if result is None and is_edited:
|
||||
edits.append((key, node))
|
||||
|
||||
if is_leaving:
|
||||
if path:
|
||||
path_pop()
|
||||
else:
|
||||
stack = Stack(in_array, idx, keys, edits, stack)
|
||||
in_array = isinstance(node, tuple)
|
||||
keys = node if in_array else visitor_keys.get(node.kind, ()) # type: ignore
|
||||
idx = -1
|
||||
edits = []
|
||||
if parent:
|
||||
ancestors_append(parent)
|
||||
parent = node
|
||||
|
||||
if not stack:
|
||||
break
|
||||
|
||||
if edits:
|
||||
return edits[-1][1]
|
||||
|
||||
return root
|
||||
|
||||
|
||||
class ParallelVisitor(Visitor):
|
||||
"""A Visitor which delegates to many visitors to run in parallel.
|
||||
|
||||
Each visitor will be visited for each node before moving on.
|
||||
|
||||
If a prior visitor edits a node, no following visitors will see that node.
|
||||
"""
|
||||
|
||||
def __init__(self, visitors: Collection[Visitor]):
|
||||
"""Create a new visitor from the given list of parallel visitors."""
|
||||
super().__init__()
|
||||
self.visitors = visitors
|
||||
self.skipping: List[Any] = [None] * len(visitors)
|
||||
|
||||
def get_enter_leave_for_kind(self, kind: str) -> EnterLeaveVisitor:
|
||||
"""Given a node kind, return the EnterLeaveVisitor for that kind."""
|
||||
try:
|
||||
return self.enter_leave_map[kind]
|
||||
except KeyError:
|
||||
has_visitor = False
|
||||
enter_list: List[Optional[Callable[..., Optional[VisitorAction]]]] = []
|
||||
leave_list: List[Optional[Callable[..., Optional[VisitorAction]]]] = []
|
||||
for visitor in self.visitors:
|
||||
enter, leave = visitor.get_enter_leave_for_kind(kind)
|
||||
if not has_visitor and (enter or leave):
|
||||
has_visitor = True
|
||||
enter_list.append(enter)
|
||||
leave_list.append(leave)
|
||||
|
||||
if has_visitor:
|
||||
|
||||
def enter(node: Node, *args: Any) -> Optional[VisitorAction]:
|
||||
skipping = self.skipping
|
||||
for i, fn in enumerate(enter_list):
|
||||
if not skipping[i]:
|
||||
if fn:
|
||||
result = fn(node, *args)
|
||||
if result is SKIP or result is False:
|
||||
skipping[i] = node
|
||||
elif result is BREAK or result is True:
|
||||
skipping[i] = BREAK
|
||||
elif result is not None:
|
||||
return result
|
||||
return None
|
||||
|
||||
def leave(node: Node, *args: Any) -> Optional[VisitorAction]:
|
||||
skipping = self.skipping
|
||||
for i, fn in enumerate(leave_list):
|
||||
if not skipping[i]:
|
||||
if fn:
|
||||
result = fn(node, *args)
|
||||
if result is BREAK or result is True:
|
||||
skipping[i] = BREAK
|
||||
elif (
|
||||
result is not None
|
||||
and result is not SKIP
|
||||
and result is not False
|
||||
):
|
||||
return result
|
||||
elif skipping[i] is node:
|
||||
skipping[i] = None
|
||||
return None
|
||||
|
||||
else:
|
||||
|
||||
enter = leave = None
|
||||
|
||||
enter_leave = EnterLeaveVisitor(enter, leave)
|
||||
self.enter_leave_map[kind] = enter_leave
|
||||
return enter_leave
|
||||
Reference in New Issue
Block a user