Files
blender-portable-repo/scripts/addons/Rokoko Libraries/python311/graphql/utilities/extend_schema.py
T
2026-03-17 14:58:51 -06:00

701 lines
27 KiB
Python

from collections import defaultdict
from typing import (
Any,
Callable,
Collection,
DefaultDict,
Dict,
List,
Mapping,
Optional,
Union,
cast,
)
from ..language import (
DirectiveDefinitionNode,
DirectiveLocation,
DocumentNode,
EnumTypeDefinitionNode,
EnumTypeExtensionNode,
EnumValueDefinitionNode,
FieldDefinitionNode,
InputObjectTypeDefinitionNode,
InputObjectTypeExtensionNode,
InputValueDefinitionNode,
InterfaceTypeDefinitionNode,
InterfaceTypeExtensionNode,
ListTypeNode,
NamedTypeNode,
NonNullTypeNode,
ObjectTypeDefinitionNode,
ObjectTypeExtensionNode,
OperationType,
ScalarTypeDefinitionNode,
ScalarTypeExtensionNode,
SchemaExtensionNode,
SchemaDefinitionNode,
TypeDefinitionNode,
TypeExtensionNode,
TypeNode,
UnionTypeDefinitionNode,
UnionTypeExtensionNode,
)
from ..pyutils import inspect, merge_kwargs
from ..type import (
GraphQLArgument,
GraphQLArgumentMap,
GraphQLDeprecatedDirective,
GraphQLDirective,
GraphQLEnumType,
GraphQLEnumValue,
GraphQLEnumValueMap,
GraphQLField,
GraphQLFieldMap,
GraphQLInputField,
GraphQLInputObjectType,
GraphQLInputType,
GraphQLInputFieldMap,
GraphQLInterfaceType,
GraphQLList,
GraphQLNamedType,
GraphQLNonNull,
GraphQLNullableType,
GraphQLObjectType,
GraphQLOutputType,
GraphQLScalarType,
GraphQLSchema,
GraphQLSchemaKwargs,
GraphQLSpecifiedByDirective,
GraphQLType,
GraphQLUnionType,
assert_schema,
is_enum_type,
is_input_object_type,
is_interface_type,
is_list_type,
is_non_null_type,
is_object_type,
is_scalar_type,
is_union_type,
is_introspection_type,
is_specified_scalar_type,
introspection_types,
specified_scalar_types,
)
from .value_from_ast import value_from_ast
# Names exported by ``from ... import *``.
__all__ = ["extend_schema", "extend_schema_impl"]
def extend_schema(
    schema: GraphQLSchema,
    document_ast: DocumentNode,
    assume_valid: bool = False,
    assume_valid_sdl: bool = False,
) -> GraphQLSchema:
    """Return a new schema built from ``schema`` plus the given document.

    The document may hold GraphQL type definitions and type extensions. The schema
    passed in is never modified: since a schema is a graph of mutually referencing
    objects, extending it means producing a copy with the extensions applied along
    the way.

    If the document contributes nothing, the very same ``schema`` object is
    returned instead of a copy.

    Set ``assume_valid`` to ``True`` to mark the resulting schema as already valid.
    SDL validation of the document against the schema is skipped when either
    ``assume_valid`` or ``assume_valid_sdl`` is ``True``.

    Raises a ``TypeError`` if ``document_ast`` is not a ``DocumentNode``.
    """
    assert_schema(schema)

    if not isinstance(document_ast, DocumentNode):
        raise TypeError("Must provide valid Document AST.")

    skip_sdl_validation = assume_valid or assume_valid_sdl
    if not skip_sdl_validation:
        # Imported here rather than at module level, as in the original code.
        from ..validation.validate import assert_valid_sdl_extension

        assert_valid_sdl_extension(document_ast, schema)

    original_kwargs = schema.to_kwargs()
    new_kwargs = extend_schema_impl(original_kwargs, document_ast, assume_valid)

    # The implementation hands back the identical kwargs object when the document
    # contained nothing to apply; in that case no new schema is needed.
    if new_kwargs is original_kwargs:
        return schema
    return GraphQLSchema(**new_kwargs)
def extend_schema_impl(
    schema_kwargs: GraphQLSchemaKwargs,
    document_ast: DocumentNode,
    assume_valid: bool = False,
) -> GraphQLSchemaKwargs:
    """Extend the given schema arguments with extensions from a given document.

    For internal use only.

    The definitions of ``document_ast`` are sorted into schema definition, schema
    extensions, type definitions, type extensions and directive definitions. If
    none of these are present, ``schema_kwargs`` itself (the same object) is
    returned so that callers can detect by identity that nothing changed.

    Otherwise new schema keyword arguments are returned in which every existing
    type and directive has been rebuilt so that it refers to the extended types,
    and in which the newly defined types and directives have been added. The
    description and extensions of the existing schema are carried over unless the
    document provides a schema definition with its own description.

    ``assume_valid`` is passed through unchanged into the returned arguments.

    A ``TypeError`` is raised when a named type is looked up that is neither a
    standard type nor part of the extended type map.
    """
    # Note: schema_kwargs should become a TypedDict once we require Python 3.8

    # Collect the type definitions and extensions found in the document.
    type_defs: List[TypeDefinitionNode] = []
    type_extensions_map: DefaultDict[str, Any] = defaultdict(list)

    # New directives and types are separate because directives and types can have
    # the same name. For example, a type named "skip".
    directive_defs: List[DirectiveDefinitionNode] = []

    schema_def: Optional[SchemaDefinitionNode] = None
    # Schema extensions are collected which may add additional operation types.
    schema_extensions: List[SchemaExtensionNode] = []

    for def_ in document_ast.definitions:
        if isinstance(def_, SchemaDefinitionNode):
            schema_def = def_
        elif isinstance(def_, SchemaExtensionNode):
            schema_extensions.append(def_)
        elif isinstance(def_, TypeDefinitionNode):
            type_defs.append(def_)
        elif isinstance(def_, TypeExtensionNode):
            extended_type_name = def_.name.value
            type_extensions_map[extended_type_name].append(def_)
        elif isinstance(def_, DirectiveDefinitionNode):
            directive_defs.append(def_)

    # If this document contains no new types, extensions, or directives then return
    # the same unmodified schema arguments.
    if (
        not type_extensions_map
        and not type_defs
        and not directive_defs
        and not schema_extensions
        and not schema_def
    ):
        return schema_kwargs

    # Below are functions used for producing this schema that have closed over this
    # scope and have access to the schema, cache, and newly defined types.
    # Note that ``type_map`` is only bound further down; the closures are not
    # called before that point.

    # noinspection PyTypeChecker,PyUnresolvedReferences
    def replace_type(type_: GraphQLType) -> GraphQLType:
        """Rebuild list/non-null wrappers around the replaced named type."""
        if is_list_type(type_):
            return GraphQLList(replace_type(type_.of_type))  # type: ignore
        if is_non_null_type(type_):
            return GraphQLNonNull(replace_type(type_.of_type))  # type: ignore
        return replace_named_type(type_)  # type: ignore

    def replace_named_type(type_: GraphQLNamedType) -> GraphQLNamedType:
        """Return the entry of the extended type map with the same name."""
        # Note: While this could make early assertions to get the correctly
        # typed values below, that would throw immediately while type system
        # validation with validate_schema() will produce more actionable results.
        return type_map[type_.name]

    # noinspection PyShadowingNames
    def replace_directive(directive: GraphQLDirective) -> GraphQLDirective:
        """Copy a directive with its arguments pointing to the replaced types."""
        kwargs = directive.to_kwargs()
        return GraphQLDirective(
            **merge_kwargs(
                kwargs,
                args={name: extend_arg(arg) for name, arg in kwargs["args"].items()},
            )
        )

    def extend_named_type(type_: GraphQLNamedType) -> GraphQLNamedType:
        """Dispatch to the extend function matching the kind of the given type."""
        if is_introspection_type(type_) or is_specified_scalar_type(type_):
            # Builtin types are not extended.
            return type_
        if is_scalar_type(type_):
            type_ = cast(GraphQLScalarType, type_)
            return extend_scalar_type(type_)
        if is_object_type(type_):
            type_ = cast(GraphQLObjectType, type_)
            return extend_object_type(type_)
        if is_interface_type(type_):
            type_ = cast(GraphQLInterfaceType, type_)
            return extend_interface_type(type_)
        if is_union_type(type_):
            type_ = cast(GraphQLUnionType, type_)
            return extend_union_type(type_)
        if is_enum_type(type_):
            type_ = cast(GraphQLEnumType, type_)
            return extend_enum_type(type_)
        if is_input_object_type(type_):
            type_ = cast(GraphQLInputObjectType, type_)
            return extend_input_object_type(type_)

        # Not reachable. All possible types have been considered.
        raise TypeError(f"Unexpected type: {inspect(type_)}.")  # pragma: no cover

    # noinspection PyShadowingNames
    def extend_input_object_type(
        type_: GraphQLInputObjectType,
    ) -> GraphQLInputObjectType:
        """Copy an input object type, adding the fields of its extensions."""
        kwargs = type_.to_kwargs()
        extensions = tuple(type_extensions_map[kwargs["name"]])

        return GraphQLInputObjectType(
            **merge_kwargs(
                kwargs,
                # The fields are passed as a thunk and therefore built lazily.
                fields=lambda: {
                    **{
                        name: GraphQLInputField(
                            **merge_kwargs(
                                field.to_kwargs(),
                                type_=replace_type(field.type),
                            )
                        )
                        for name, field in kwargs["fields"].items()
                    },
                    **build_input_field_map(extensions),
                },
                extension_ast_nodes=kwargs["extension_ast_nodes"] + extensions,
            )
        )

    def extend_enum_type(type_: GraphQLEnumType) -> GraphQLEnumType:
        """Copy an enum type, adding the values of its extensions."""
        kwargs = type_.to_kwargs()
        extensions = tuple(type_extensions_map[kwargs["name"]])

        return GraphQLEnumType(
            **merge_kwargs(
                kwargs,
                values={**kwargs["values"], **build_enum_value_map(extensions)},
                extension_ast_nodes=kwargs["extension_ast_nodes"] + extensions,
            )
        )

    def extend_scalar_type(type_: GraphQLScalarType) -> GraphQLScalarType:
        """Copy a scalar type, letting extensions override the specified-by URL."""
        kwargs = type_.to_kwargs()
        extensions = tuple(type_extensions_map[kwargs["name"]])

        # The last extension that provides a URL wins; otherwise keep the old one.
        specified_by_url = kwargs["specified_by_url"]
        for extension_node in extensions:
            specified_by_url = get_specified_by_url(extension_node) or specified_by_url

        return GraphQLScalarType(
            **merge_kwargs(
                kwargs,
                specified_by_url=specified_by_url,
                extension_ast_nodes=kwargs["extension_ast_nodes"] + extensions,
            )
        )

    # noinspection PyShadowingNames
    def extend_object_type(type_: GraphQLObjectType) -> GraphQLObjectType:
        """Copy an object type, adding interfaces and fields of its extensions."""
        kwargs = type_.to_kwargs()
        extensions = tuple(type_extensions_map[kwargs["name"]])

        return GraphQLObjectType(
            **merge_kwargs(
                kwargs,
                # Interfaces and fields are passed as thunks and built lazily.
                interfaces=lambda: [
                    cast(GraphQLInterfaceType, replace_named_type(interface))
                    for interface in kwargs["interfaces"]
                ]
                + build_interfaces(extensions),
                fields=lambda: {
                    **{
                        name: extend_field(field)
                        for name, field in kwargs["fields"].items()
                    },
                    **build_field_map(extensions),
                },
                extension_ast_nodes=kwargs["extension_ast_nodes"] + extensions,
            )
        )

    # noinspection PyShadowingNames
    def extend_interface_type(type_: GraphQLInterfaceType) -> GraphQLInterfaceType:
        """Copy an interface type, adding interfaces and fields of its extensions."""
        kwargs = type_.to_kwargs()
        extensions = tuple(type_extensions_map[kwargs["name"]])

        return GraphQLInterfaceType(
            **merge_kwargs(
                kwargs,
                # Interfaces and fields are passed as thunks and built lazily.
                interfaces=lambda: [
                    cast(GraphQLInterfaceType, replace_named_type(interface))
                    for interface in kwargs["interfaces"]
                ]
                + build_interfaces(extensions),
                fields=lambda: {
                    **{
                        name: extend_field(field)
                        for name, field in kwargs["fields"].items()
                    },
                    **build_field_map(extensions),
                },
                extension_ast_nodes=kwargs["extension_ast_nodes"] + extensions,
            )
        )

    def extend_union_type(type_: GraphQLUnionType) -> GraphQLUnionType:
        """Copy a union type, adding the member types of its extensions."""
        kwargs = type_.to_kwargs()
        extensions = tuple(type_extensions_map[kwargs["name"]])

        return GraphQLUnionType(
            **merge_kwargs(
                kwargs,
                # The member types are passed as a thunk and built lazily.
                types=lambda: [
                    cast(GraphQLObjectType, replace_named_type(member_type))
                    for member_type in kwargs["types"]
                ]
                + build_union_types(extensions),
                extension_ast_nodes=kwargs["extension_ast_nodes"] + extensions,
            )
        )

    # noinspection PyShadowingNames
    def extend_field(field: GraphQLField) -> GraphQLField:
        """Copy a field with its type and arguments using the replaced types."""
        return GraphQLField(
            **merge_kwargs(
                field.to_kwargs(),
                type_=replace_type(field.type),
                args={name: extend_arg(arg) for name, arg in field.args.items()},
            )
        )

    def extend_arg(arg: GraphQLArgument) -> GraphQLArgument:
        """Copy an argument with its type replaced."""
        return GraphQLArgument(
            **merge_kwargs(
                arg.to_kwargs(),
                type_=replace_type(arg.type),
            )
        )

    # noinspection PyShadowingNames
    def get_operation_types(
        nodes: Collection[Union[SchemaDefinitionNode, SchemaExtensionNode]]
    ) -> Dict[OperationType, GraphQLNamedType]:
        """Map the operations declared in the given schema nodes to named types."""
        # Note: While this could make early assertions to get the correctly
        # typed values below, that would throw immediately while type system
        # validation with validate_schema() will produce more actionable results.
        return {
            operation_type.operation: get_named_type(operation_type.type)
            for node in nodes
            for operation_type in node.operation_types or []
        }

    # noinspection PyShadowingNames
    def get_named_type(node: NamedTypeNode) -> GraphQLNamedType:
        """Resolve a named type node, preferring standard types.

        Raises a ``TypeError`` if the name is not known.
        """
        name = node.name.value
        type_ = std_type_map.get(name) or type_map.get(name)

        if not type_:
            raise TypeError(f"Unknown type: '{name}'.")
        return type_

    def get_wrapped_type(node: TypeNode) -> GraphQLType:
        """Resolve a type node, recreating its list and non-null wrappers."""
        if isinstance(node, ListTypeNode):
            return GraphQLList(get_wrapped_type(node.type))
        if isinstance(node, NonNullTypeNode):
            return GraphQLNonNull(
                cast(GraphQLNullableType, get_wrapped_type(node.type))
            )
        return get_named_type(cast(NamedTypeNode, node))

    def build_directive(node: DirectiveDefinitionNode) -> GraphQLDirective:
        """Build a directive from its definition node."""
        # Use a distinct loop variable so that the parameter is not shadowed.
        locations = [DirectiveLocation[location.value] for location in node.locations]

        return GraphQLDirective(
            name=node.name.value,
            description=node.description.value if node.description else None,
            locations=locations,
            is_repeatable=node.repeatable,
            args=build_argument_map(node.arguments),
            ast_node=node,
        )

    def build_field_map(
        nodes: Collection[
            Union[
                InterfaceTypeDefinitionNode,
                InterfaceTypeExtensionNode,
                ObjectTypeDefinitionNode,
                ObjectTypeExtensionNode,
            ]
        ],
    ) -> GraphQLFieldMap:
        """Build the fields declared in the given nodes, keyed by field name."""
        field_map: GraphQLFieldMap = {}
        for node in nodes:
            for field in node.fields or []:
                # Note: While this could make assertions to get the correctly typed
                # value, that would throw immediately while type system validation
                # with validate_schema() will produce more actionable results.
                field_map[field.name.value] = GraphQLField(
                    type_=cast(GraphQLOutputType, get_wrapped_type(field.type)),
                    description=field.description.value if field.description else None,
                    args=build_argument_map(field.arguments),
                    deprecation_reason=get_deprecation_reason(field),
                    ast_node=field,
                )
        return field_map

    def build_argument_map(
        args: Optional[Collection[InputValueDefinitionNode]],
    ) -> GraphQLArgumentMap:
        """Build the arguments for the given nodes, keyed by argument name."""
        arg_map: GraphQLArgumentMap = {}
        for arg in args or []:
            # Note: While this could make assertions to get the correctly typed
            # value, that would throw immediately while type system validation
            # with validate_schema() will produce more actionable results.
            type_ = cast(GraphQLInputType, get_wrapped_type(arg.type))
            arg_map[arg.name.value] = GraphQLArgument(
                type_=type_,
                description=arg.description.value if arg.description else None,
                default_value=value_from_ast(arg.default_value, type_),
                deprecation_reason=get_deprecation_reason(arg),
                ast_node=arg,
            )
        return arg_map

    def build_input_field_map(
        nodes: Collection[
            Union[InputObjectTypeDefinitionNode, InputObjectTypeExtensionNode]
        ],
    ) -> GraphQLInputFieldMap:
        """Build the input fields declared in the given nodes, keyed by name."""
        input_field_map: GraphQLInputFieldMap = {}
        for node in nodes:
            for field in node.fields or []:
                # Note: While this could make assertions to get the correctly typed
                # value, that would throw immediately while type system validation
                # with validate_schema() will produce more actionable results.
                type_ = cast(GraphQLInputType, get_wrapped_type(field.type))
                input_field_map[field.name.value] = GraphQLInputField(
                    type_=type_,
                    description=field.description.value if field.description else None,
                    default_value=value_from_ast(field.default_value, type_),
                    deprecation_reason=get_deprecation_reason(field),
                    ast_node=field,
                )
        return input_field_map

    def build_enum_value_map(
        nodes: Collection[Union[EnumTypeDefinitionNode, EnumTypeExtensionNode]]
    ) -> GraphQLEnumValueMap:
        """Build the enum values declared in the given nodes, keyed by name."""
        enum_value_map: GraphQLEnumValueMap = {}
        for node in nodes:
            for value in node.values or []:
                # Note: While this could make assertions to get the correctly typed
                # value, that would throw immediately while type system validation
                # with validate_schema() will produce more actionable results.
                value_name = value.name.value
                enum_value_map[value_name] = GraphQLEnumValue(
                    value=value_name,
                    description=value.description.value if value.description else None,
                    deprecation_reason=get_deprecation_reason(value),
                    ast_node=value,
                )
        return enum_value_map

    def build_interfaces(
        nodes: Collection[
            Union[
                InterfaceTypeDefinitionNode,
                InterfaceTypeExtensionNode,
                ObjectTypeDefinitionNode,
                ObjectTypeExtensionNode,
            ]
        ],
    ) -> List[GraphQLInterfaceType]:
        """Resolve the interfaces named in the given nodes, in document order."""
        interfaces: List[GraphQLInterfaceType] = []
        for node in nodes:
            for type_ in node.interfaces or []:
                # Note: While this could make assertions to get the correctly typed
                # value, that would throw immediately while type system validation
                # with validate_schema() will produce more actionable results.
                interfaces.append(cast(GraphQLInterfaceType, get_named_type(type_)))
        return interfaces

    def build_union_types(
        nodes: Collection[Union[UnionTypeDefinitionNode, UnionTypeExtensionNode]],
    ) -> List[GraphQLObjectType]:
        """Resolve the member types named in the given nodes, in document order."""
        types: List[GraphQLObjectType] = []
        for node in nodes:
            for type_ in node.types or []:
                # Note: While this could make assertions to get the correctly typed
                # value, that would throw immediately while type system validation
                # with validate_schema() will produce more actionable results.
                types.append(cast(GraphQLObjectType, get_named_type(type_)))
        return types

    def build_object_type(ast_node: ObjectTypeDefinitionNode) -> GraphQLObjectType:
        """Build an object type from its definition and its extensions."""
        extension_nodes = type_extensions_map[ast_node.name.value]
        all_nodes: List[Union[ObjectTypeDefinitionNode, ObjectTypeExtensionNode]] = [
            ast_node,
            *extension_nodes,
        ]
        return GraphQLObjectType(
            name=ast_node.name.value,
            description=ast_node.description.value if ast_node.description else None,
            interfaces=lambda: build_interfaces(all_nodes),
            fields=lambda: build_field_map(all_nodes),
            ast_node=ast_node,
            extension_ast_nodes=extension_nodes,
        )

    def build_interface_type(
        ast_node: InterfaceTypeDefinitionNode,
    ) -> GraphQLInterfaceType:
        """Build an interface type from its definition and its extensions."""
        extension_nodes = type_extensions_map[ast_node.name.value]
        all_nodes: List[
            Union[InterfaceTypeDefinitionNode, InterfaceTypeExtensionNode]
        ] = [ast_node, *extension_nodes]
        return GraphQLInterfaceType(
            name=ast_node.name.value,
            description=ast_node.description.value if ast_node.description else None,
            interfaces=lambda: build_interfaces(all_nodes),
            fields=lambda: build_field_map(all_nodes),
            ast_node=ast_node,
            extension_ast_nodes=extension_nodes,
        )

    def build_enum_type(ast_node: EnumTypeDefinitionNode) -> GraphQLEnumType:
        """Build an enum type from its definition and its extensions."""
        extension_nodes = type_extensions_map[ast_node.name.value]
        all_nodes: List[Union[EnumTypeDefinitionNode, EnumTypeExtensionNode]] = [
            ast_node,
            *extension_nodes,
        ]
        return GraphQLEnumType(
            name=ast_node.name.value,
            description=ast_node.description.value if ast_node.description else None,
            values=build_enum_value_map(all_nodes),
            ast_node=ast_node,
            extension_ast_nodes=extension_nodes,
        )

    def build_union_type(ast_node: UnionTypeDefinitionNode) -> GraphQLUnionType:
        """Build a union type from its definition and its extensions."""
        extension_nodes = type_extensions_map[ast_node.name.value]
        all_nodes: List[Union[UnionTypeDefinitionNode, UnionTypeExtensionNode]] = [
            ast_node,
            *extension_nodes,
        ]
        return GraphQLUnionType(
            name=ast_node.name.value,
            description=ast_node.description.value if ast_node.description else None,
            types=lambda: build_union_types(all_nodes),
            ast_node=ast_node,
            extension_ast_nodes=extension_nodes,
        )

    def build_scalar_type(ast_node: ScalarTypeDefinitionNode) -> GraphQLScalarType:
        """Build a scalar type from its definition node."""
        extension_nodes = type_extensions_map[ast_node.name.value]
        return GraphQLScalarType(
            name=ast_node.name.value,
            description=ast_node.description.value if ast_node.description else None,
            specified_by_url=get_specified_by_url(ast_node),
            ast_node=ast_node,
            extension_ast_nodes=extension_nodes,
        )

    def build_input_object_type(
        ast_node: InputObjectTypeDefinitionNode,
    ) -> GraphQLInputObjectType:
        """Build an input object type from its definition and its extensions."""
        extension_nodes = type_extensions_map[ast_node.name.value]
        all_nodes: List[
            Union[InputObjectTypeDefinitionNode, InputObjectTypeExtensionNode]
        ] = [ast_node, *extension_nodes]
        return GraphQLInputObjectType(
            name=ast_node.name.value,
            description=ast_node.description.value if ast_node.description else None,
            fields=lambda: build_input_field_map(all_nodes),
            ast_node=ast_node,
            extension_ast_nodes=extension_nodes,
        )

    # Dispatch table from the ``kind`` of a definition node to its build function.
    build_type_for_kind = cast(
        Dict[str, Callable[[TypeDefinitionNode], GraphQLNamedType]],
        {
            "object_type_definition": build_object_type,
            "interface_type_definition": build_interface_type,
            "enum_type_definition": build_enum_type,
            "union_type_definition": build_union_type,
            "scalar_type_definition": build_scalar_type,
            "input_object_type_definition": build_input_object_type,
        },
    )

    def build_type(ast_node: TypeDefinitionNode) -> GraphQLNamedType:
        """Build a named type from a type definition node of any kind."""
        try:
            # object_type_definition_node is built with _build_object_type etc.
            build_function = build_type_for_kind[ast_node.kind]
        except KeyError:  # pragma: no cover
            # Not reachable. All possible type definition nodes have been considered.
            raise TypeError(  # pragma: no cover
                f"Unexpected type definition node: {inspect(ast_node)}."
            )
        else:
            return build_function(ast_node)

    # First extend all existing types, then add the newly defined ones. Standard
    # types keep their standard instance even if the document defines them again.
    type_map: Dict[str, GraphQLNamedType] = {}
    for existing_type in schema_kwargs["types"] or ():
        type_map[existing_type.name] = extend_named_type(existing_type)
    for type_node in type_defs:
        name = type_node.name.value
        type_map[name] = std_type_map.get(name) or build_type(type_node)

    # Get the extended root operation types.
    operation_types: Dict[OperationType, GraphQLNamedType] = {}
    for operation_type in OperationType:
        original_type = schema_kwargs[operation_type.value]
        if original_type:
            operation_types[operation_type] = replace_named_type(original_type)
    # Then, incorporate schema definition and all schema extensions.
    if schema_def:
        operation_types.update(get_operation_types([schema_def]))
    if schema_extensions:
        operation_types.update(get_operation_types(schema_extensions))

    # A description given by a schema definition in the document takes precedence.
    # Otherwise keep the description of the existing schema, so that extending a
    # schema does not silently drop it while its ``ast_node`` is retained.
    if schema_def and schema_def.description:
        description = schema_def.description.value
    else:
        description = schema_kwargs.get("description")

    # Carry over the extensions of the existing schema as a shallow copy, so that
    # the original and the extended schema do not share one mutable dictionary.
    extensions = dict(schema_kwargs.get("extensions") or {})

    # Then produce and return the kwargs for a Schema with these types.
    get_operation = operation_types.get
    return GraphQLSchemaKwargs(
        query=get_operation(OperationType.QUERY),  # type: ignore
        mutation=get_operation(OperationType.MUTATION),  # type: ignore
        subscription=get_operation(OperationType.SUBSCRIPTION),  # type: ignore
        types=tuple(type_map.values()),
        directives=tuple(
            replace_directive(directive) for directive in schema_kwargs["directives"]
        )
        + tuple(build_directive(directive) for directive in directive_defs),
        description=description,
        extensions=extensions,
        ast_node=schema_def or schema_kwargs["ast_node"],
        extension_ast_nodes=schema_kwargs["extension_ast_nodes"]
        + tuple(schema_extensions),
        assume_valid=assume_valid,
    )
# Lookup of the built-in named types (specified scalars and introspection types).
# It is consulted before the extended type map when resolving a type name, and
# before building a type for a definition node, so that these names always map to
# the standard instances.
std_type_map: Mapping[str, Union[GraphQLNamedType, GraphQLObjectType]] = {
    **specified_scalar_types,
    **introspection_types,
}
def get_deprecation_reason(
    node: Union[EnumValueDefinitionNode, FieldDefinitionNode, InputValueDefinitionNode]
) -> Optional[str]:
    """Return the deprecation reason attached to the given node, if any.

    The values of the ``deprecated`` directive are read from the node; when there
    are none, ``None`` is returned, otherwise the value of its ``reason`` argument.
    """
    # Imported here rather than at module level, as in the original code.
    from ..execution import get_directive_values

    directive_args = get_directive_values(GraphQLDeprecatedDirective, node)
    if not directive_args:
        return None
    return directive_args["reason"]
def get_specified_by_url(
    node: Union[ScalarTypeDefinitionNode, ScalarTypeExtensionNode]
) -> Optional[str]:
    """Return the specifiedByURL string attached to the given scalar node, if any.

    The values of the ``specifiedBy`` directive are read from the node; when there
    are none, ``None`` is returned, otherwise the value of its ``url`` argument.
    """
    # Imported here rather than at module level, as in the original code.
    from ..execution import get_directive_values

    directive_args = get_directive_values(GraphQLSpecifiedByDirective, node)
    if not directive_args:
        return None
    return directive_args["url"]