from collections import defaultdict
from typing import (
    Any,
    Callable,
    Collection,
    DefaultDict,
    Dict,
    List,
    Mapping,
    Optional,
    Union,
    cast,
)

from ..language import (
    DirectiveDefinitionNode,
    DirectiveLocation,
    DocumentNode,
    EnumTypeDefinitionNode,
    EnumTypeExtensionNode,
    EnumValueDefinitionNode,
    FieldDefinitionNode,
    InputObjectTypeDefinitionNode,
    InputObjectTypeExtensionNode,
    InputValueDefinitionNode,
    InterfaceTypeDefinitionNode,
    InterfaceTypeExtensionNode,
    ListTypeNode,
    NamedTypeNode,
    NonNullTypeNode,
    ObjectTypeDefinitionNode,
    ObjectTypeExtensionNode,
    OperationType,
    ScalarTypeDefinitionNode,
    ScalarTypeExtensionNode,
    SchemaExtensionNode,
    SchemaDefinitionNode,
    TypeDefinitionNode,
    TypeExtensionNode,
    TypeNode,
    UnionTypeDefinitionNode,
    UnionTypeExtensionNode,
)
from ..pyutils import inspect, merge_kwargs
from ..type import (
    GraphQLArgument,
    GraphQLArgumentMap,
    GraphQLDeprecatedDirective,
    GraphQLDirective,
    GraphQLEnumType,
    GraphQLEnumValue,
    GraphQLEnumValueMap,
    GraphQLField,
    GraphQLFieldMap,
    GraphQLInputField,
    GraphQLInputObjectType,
    GraphQLInputType,
    GraphQLInputFieldMap,
    GraphQLInterfaceType,
    GraphQLList,
    GraphQLNamedType,
    GraphQLNonNull,
    GraphQLNullableType,
    GraphQLObjectType,
    GraphQLOutputType,
    GraphQLScalarType,
    GraphQLSchema,
    GraphQLSchemaKwargs,
    GraphQLSpecifiedByDirective,
    GraphQLType,
    GraphQLUnionType,
    assert_schema,
    is_enum_type,
    is_input_object_type,
    is_interface_type,
    is_list_type,
    is_non_null_type,
    is_object_type,
    is_scalar_type,
    is_union_type,
    is_introspection_type,
    is_specified_scalar_type,
    introspection_types,
    specified_scalar_types,
)
from .value_from_ast import value_from_ast

__all__ = [
    "extend_schema",
    "extend_schema_impl",
]


def extend_schema(
    schema: GraphQLSchema,
    document_ast: DocumentNode,
    assume_valid: bool = False,
    assume_valid_sdl: bool = False,
) -> GraphQLSchema:
    """Extend the schema with extensions from a given document.

    Produces a new schema given an existing schema and a document which may contain
    GraphQL type extensions and definitions. The original schema will remain
    unaltered.

    Because a schema represents a graph of references, a schema cannot be extended
    without effectively making an entire copy. We do not know until it's too late if
    subgraphs remain unchanged.

    This algorithm copies the provided schema, applying extensions while producing
    the copy. The original schema remains unaltered.

    When extending a schema with a known valid extension, it might be safe to assume
    the schema is valid. Set ``assume_valid`` to ``True`` to assume the produced
    schema is valid. Set ``assume_valid_sdl`` to ``True`` to assume it is already
    a valid SDL document.

    If the document contains no type, directive or schema definitions and no
    extensions, the given ``schema`` instance itself is returned.

    Raises ``TypeError`` if ``document_ast`` is not a ``DocumentNode``. Unless
    ``assume_valid`` or ``assume_valid_sdl`` is set, the document is first checked
    with ``assert_valid_sdl_extension`` (presumably raising on invalid SDL).
    """
    assert_schema(schema)

    if not isinstance(document_ast, DocumentNode):
        raise TypeError("Must provide valid Document AST.")

    if not (assume_valid or assume_valid_sdl):
        # Imported lazily here, as in the original module layout.
        from ..validation.validate import assert_valid_sdl_extension

        assert_valid_sdl_extension(document_ast, schema)

    schema_kwargs = schema.to_kwargs()
    extended_kwargs = extend_schema_impl(schema_kwargs, document_ast, assume_valid)
    # extend_schema_impl hands back the very same kwargs object when there was
    # nothing to extend, in which case the original schema can be reused.
    return (
        schema if schema_kwargs is extended_kwargs else GraphQLSchema(**extended_kwargs)
    )


def extend_schema_impl(
    schema_kwargs: GraphQLSchemaKwargs,
    document_ast: DocumentNode,
    assume_valid: bool = False,
) -> GraphQLSchemaKwargs:
    """Extend the given schema arguments with extensions from a given document.

    For internal use only.

    Returns ``schema_kwargs`` itself (the identical object) when the document
    contains no type, directive or schema definitions and no extensions, so that
    callers can detect the "nothing changed" case with an identity check.
    Otherwise returns new keyword arguments for constructing a ``GraphQLSchema``
    in which every existing named type has been copied (with its extensions
    applied) and every newly defined type and directive has been built.
    """
    # Note: schema_kwargs should become a TypedDict once we require Python 3.8

    # Collect the type definitions and extensions found in the document.
    type_defs: List[TypeDefinitionNode] = []
    type_extensions_map: DefaultDict[str, Any] = defaultdict(list)

    # New directives and types are kept separate because directives and types can
    # have the same name. For example, a type named "skip".
    directive_defs: List[DirectiveDefinitionNode] = []

    schema_def: Optional[SchemaDefinitionNode] = None
    # Schema extensions are collected which may add additional operation types.
    schema_extensions: List[SchemaExtensionNode] = []

    for def_ in document_ast.definitions:
        if isinstance(def_, SchemaDefinitionNode):
            schema_def = def_
        elif isinstance(def_, SchemaExtensionNode):
            schema_extensions.append(def_)
        elif isinstance(def_, TypeDefinitionNode):
            type_defs.append(def_)
        elif isinstance(def_, TypeExtensionNode):
            extended_type_name = def_.name.value
            type_extensions_map[extended_type_name].append(def_)
        elif isinstance(def_, DirectiveDefinitionNode):
            directive_defs.append(def_)

    # If this document contains no new types, extensions, or directives then return
    # the same unmodified schema kwargs object.
    if (
        not type_extensions_map
        and not type_defs
        and not directive_defs
        and not schema_extensions
        and not schema_def
    ):
        return schema_kwargs

    # Below are functions used for producing this schema that have closed over this
    # scope and have access to the schema, cache, and newly defined types.

    # noinspection PyTypeChecker,PyUnresolvedReferences
    def replace_type(type_: GraphQLType) -> GraphQLType:
        """Rebuild list/non-null wrappers around the matching type from type_map."""
        if is_list_type(type_):
            return GraphQLList(replace_type(type_.of_type))  # type: ignore
        if is_non_null_type(type_):
            return GraphQLNonNull(replace_type(type_.of_type))  # type: ignore
        return replace_named_type(type_)  # type: ignore

    def replace_named_type(type_: GraphQLNamedType) -> GraphQLNamedType:
        """Return the type from type_map that has the same name as ``type_``."""
        # Note: While this could make early assertions to get the correctly
        # typed values below, that would throw immediately while type system
        # validation with validate_schema() will produce more actionable results.
        return type_map[type_.name]

    # noinspection PyShadowingNames
    def replace_directive(directive: GraphQLDirective) -> GraphQLDirective:
        """Copy an existing directive, re-pointing its argument types."""
        kwargs = directive.to_kwargs()
        return GraphQLDirective(
            **merge_kwargs(
                kwargs,
                args={name: extend_arg(arg) for name, arg in kwargs["args"].items()},
            )
        )

    def extend_named_type(type_: GraphQLNamedType) -> GraphQLNamedType:
        """Copy an existing named type, applying the extensions found for it."""
        if is_introspection_type(type_) or is_specified_scalar_type(type_):
            # Builtin types are not extended.
            return type_
        if is_scalar_type(type_):
            type_ = cast(GraphQLScalarType, type_)
            return extend_scalar_type(type_)
        if is_object_type(type_):
            type_ = cast(GraphQLObjectType, type_)
            return extend_object_type(type_)
        if is_interface_type(type_):
            type_ = cast(GraphQLInterfaceType, type_)
            return extend_interface_type(type_)
        if is_union_type(type_):
            type_ = cast(GraphQLUnionType, type_)
            return extend_union_type(type_)
        if is_enum_type(type_):
            type_ = cast(GraphQLEnumType, type_)
            return extend_enum_type(type_)
        if is_input_object_type(type_):
            type_ = cast(GraphQLInputObjectType, type_)
            return extend_input_object_type(type_)

        # Not reachable. All possible types have been considered.
        raise TypeError(f"Unexpected type: {inspect(type_)}.")  # pragma: no cover

    # noinspection PyShadowingNames
    def extend_input_object_type(
        type_: GraphQLInputObjectType,
    ) -> GraphQLInputObjectType:
        """Copy an input object type, adding the fields from its extensions."""
        kwargs = type_.to_kwargs()
        extensions = tuple(type_extensions_map[kwargs["name"]])

        return GraphQLInputObjectType(
            **merge_kwargs(
                kwargs,
                # A thunk: type_map may still be incomplete when this type is
                # created, so the type_map lookups are deferred.
                fields=lambda: {
                    **{
                        name: GraphQLInputField(
                            **merge_kwargs(
                                field.to_kwargs(),
                                type_=replace_type(field.type),
                            )
                        )
                        for name, field in kwargs["fields"].items()
                    },
                    **build_input_field_map(extensions),
                },
                extension_ast_nodes=kwargs["extension_ast_nodes"] + extensions,
            )
        )

    def extend_enum_type(type_: GraphQLEnumType) -> GraphQLEnumType:
        """Copy an enum type, adding the values from its extensions."""
        kwargs = type_.to_kwargs()
        extensions = tuple(type_extensions_map[kwargs["name"]])

        return GraphQLEnumType(
            **merge_kwargs(
                kwargs,
                values={**kwargs["values"], **build_enum_value_map(extensions)},
                extension_ast_nodes=kwargs["extension_ast_nodes"] + extensions,
            )
        )

    def extend_scalar_type(type_: GraphQLScalarType) -> GraphQLScalarType:
        """Copy a scalar type; a later @specifiedBy URL overrides an earlier one."""
        kwargs = type_.to_kwargs()
        extensions = tuple(type_extensions_map[kwargs["name"]])

        specified_by_url = kwargs["specified_by_url"]
        for extension_node in extensions:
            specified_by_url = get_specified_by_url(extension_node) or specified_by_url

        return GraphQLScalarType(
            **merge_kwargs(
                kwargs,
                specified_by_url=specified_by_url,
                extension_ast_nodes=kwargs["extension_ast_nodes"] + extensions,
            )
        )

    # noinspection PyShadowingNames
    def extend_object_type(type_: GraphQLObjectType) -> GraphQLObjectType:
        """Copy an object type, adding interfaces and fields from its extensions."""
        kwargs = type_.to_kwargs()
        extensions = tuple(type_extensions_map[kwargs["name"]])

        return GraphQLObjectType(
            **merge_kwargs(
                kwargs,
                # Thunks: type_map may still be incomplete at this point.
                interfaces=lambda: [
                    cast(GraphQLInterfaceType, replace_named_type(interface))
                    for interface in kwargs["interfaces"]
                ]
                + build_interfaces(extensions),
                fields=lambda: {
                    **{
                        name: extend_field(field)
                        for name, field in kwargs["fields"].items()
                    },
                    **build_field_map(extensions),
                },
                extension_ast_nodes=kwargs["extension_ast_nodes"] + extensions,
            )
        )

    # noinspection PyShadowingNames
    def extend_interface_type(type_: GraphQLInterfaceType) -> GraphQLInterfaceType:
        """Copy an interface type, adding interfaces and fields from extensions."""
        kwargs = type_.to_kwargs()
        extensions = tuple(type_extensions_map[kwargs["name"]])

        return GraphQLInterfaceType(
            **merge_kwargs(
                kwargs,
                # Thunks: type_map may still be incomplete at this point.
                interfaces=lambda: [
                    cast(GraphQLInterfaceType, replace_named_type(interface))
                    for interface in kwargs["interfaces"]
                ]
                + build_interfaces(extensions),
                fields=lambda: {
                    **{
                        name: extend_field(field)
                        for name, field in kwargs["fields"].items()
                    },
                    **build_field_map(extensions),
                },
                extension_ast_nodes=kwargs["extension_ast_nodes"] + extensions,
            )
        )

    def extend_union_type(type_: GraphQLUnionType) -> GraphQLUnionType:
        """Copy a union type, adding the member types from its extensions."""
        kwargs = type_.to_kwargs()
        extensions = tuple(type_extensions_map[kwargs["name"]])

        return GraphQLUnionType(
            **merge_kwargs(
                kwargs,
                # A thunk: type_map may still be incomplete at this point.
                types=lambda: [
                    cast(GraphQLObjectType, replace_named_type(member_type))
                    for member_type in kwargs["types"]
                ]
                + build_union_types(extensions),
                extension_ast_nodes=kwargs["extension_ast_nodes"] + extensions,
            )
        )

    # noinspection PyShadowingNames
    def extend_field(field: GraphQLField) -> GraphQLField:
        """Copy a field, re-pointing its type and argument types into type_map."""
        return GraphQLField(
            **merge_kwargs(
                field.to_kwargs(),
                type_=replace_type(field.type),
                args={name: extend_arg(arg) for name, arg in field.args.items()},
            )
        )

    def extend_arg(arg: GraphQLArgument) -> GraphQLArgument:
        """Copy an argument, re-pointing its type into type_map."""
        return GraphQLArgument(
            **merge_kwargs(
                arg.to_kwargs(),
                type_=replace_type(arg.type),
            )
        )

    # noinspection PyShadowingNames
    def get_operation_types(
        nodes: Collection[Union[SchemaDefinitionNode, SchemaExtensionNode]]
    ) -> Dict[OperationType, GraphQLNamedType]:
        """Map each operation declared in the given schema nodes to its root type.

        Later nodes win when the same operation is declared more than once.
        """
        # Note: While this could make early assertions to get the correctly
        # typed values below, that would throw immediately while type system
        # validation with validate_schema() will produce more actionable results.
        return {
            operation_type.operation: get_named_type(operation_type.type)
            for node in nodes
            for operation_type in node.operation_types or []
        }

    # noinspection PyShadowingNames
    def get_named_type(node: NamedTypeNode) -> GraphQLNamedType:
        """Resolve a named type reference, preferring the standard types.

        Raises ``TypeError`` if the name is neither a standard type nor in type_map.
        """
        name = node.name.value
        type_ = std_type_map.get(name) or type_map.get(name)

        if not type_:
            raise TypeError(f"Unknown type: '{name}'.")
        return type_

    def get_wrapped_type(node: TypeNode) -> GraphQLType:
        """Convert a (list / non-null / named) type reference node into a type."""
        if isinstance(node, ListTypeNode):
            return GraphQLList(get_wrapped_type(node.type))
        if isinstance(node, NonNullTypeNode):
            return GraphQLNonNull(
                cast(GraphQLNullableType, get_wrapped_type(node.type))
            )
        return get_named_type(cast(NamedTypeNode, node))

    def build_directive(node: DirectiveDefinitionNode) -> GraphQLDirective:
        """Build a new directive from its definition node."""
        locations = [DirectiveLocation[location.value] for location in node.locations]

        return GraphQLDirective(
            name=node.name.value,
            description=node.description.value if node.description else None,
            locations=locations,
            is_repeatable=node.repeatable,
            args=build_argument_map(node.arguments),
            ast_node=node,
        )

    def build_field_map(
        nodes: Collection[
            Union[
                InterfaceTypeDefinitionNode,
                InterfaceTypeExtensionNode,
                ObjectTypeDefinitionNode,
                ObjectTypeExtensionNode,
            ]
        ],
    ) -> GraphQLFieldMap:
        """Build the output fields declared in the given nodes (later ones win)."""
        field_map: GraphQLFieldMap = {}
        for node in nodes:
            for field in node.fields or []:
                # Note: While this could make assertions to get the correctly typed
                # value, that would throw immediately while type system validation
                # with validate_schema() will produce more actionable results.
                field_map[field.name.value] = GraphQLField(
                    type_=cast(GraphQLOutputType, get_wrapped_type(field.type)),
                    description=field.description.value if field.description else None,
                    args=build_argument_map(field.arguments),
                    deprecation_reason=get_deprecation_reason(field),
                    ast_node=field,
                )
        return field_map

    def build_argument_map(
        args: Optional[Collection[InputValueDefinitionNode]],
    ) -> GraphQLArgumentMap:
        """Build the arguments declared by the given nodes (``None`` gives ``{}``)."""
        arg_map: GraphQLArgumentMap = {}
        for arg in args or []:
            # Note: While this could make assertions to get the correctly typed
            # value, that would throw immediately while type system validation
            # with validate_schema() will produce more actionable results.
            type_ = cast(GraphQLInputType, get_wrapped_type(arg.type))
            arg_map[arg.name.value] = GraphQLArgument(
                type_=type_,
                description=arg.description.value if arg.description else None,
                default_value=value_from_ast(arg.default_value, type_),
                deprecation_reason=get_deprecation_reason(arg),
                ast_node=arg,
            )
        return arg_map

    def build_input_field_map(
        nodes: Collection[
            Union[InputObjectTypeDefinitionNode, InputObjectTypeExtensionNode]
        ],
    ) -> GraphQLInputFieldMap:
        """Build the input fields declared in the given nodes (later ones win)."""
        input_field_map: GraphQLInputFieldMap = {}
        for node in nodes:
            for field in node.fields or []:
                # Note: While this could make assertions to get the correctly typed
                # value, that would throw immediately while type system validation
                # with validate_schema() will produce more actionable results.
                type_ = cast(GraphQLInputType, get_wrapped_type(field.type))
                input_field_map[field.name.value] = GraphQLInputField(
                    type_=type_,
                    description=field.description.value if field.description else None,
                    default_value=value_from_ast(field.default_value, type_),
                    deprecation_reason=get_deprecation_reason(field),
                    ast_node=field,
                )
        return input_field_map

    def build_enum_value_map(
        nodes: Collection[Union[EnumTypeDefinitionNode, EnumTypeExtensionNode]]
    ) -> GraphQLEnumValueMap:
        """Build the enum values declared in the given nodes.

        Each value's internal value is its name.
        """
        enum_value_map: GraphQLEnumValueMap = {}
        for node in nodes:
            for value in node.values or []:
                # Note: While this could make assertions to get the correctly typed
                # value, that would throw immediately while type system validation
                # with validate_schema() will produce more actionable results.
                value_name = value.name.value
                enum_value_map[value_name] = GraphQLEnumValue(
                    value=value_name,
                    description=value.description.value if value.description else None,
                    deprecation_reason=get_deprecation_reason(value),
                    ast_node=value,
                )
        return enum_value_map

    def build_interfaces(
        nodes: Collection[
            Union[
                InterfaceTypeDefinitionNode,
                InterfaceTypeExtensionNode,
                ObjectTypeDefinitionNode,
                ObjectTypeExtensionNode,
            ]
        ],
    ) -> List[GraphQLInterfaceType]:
        """Resolve the interfaces implemented in the given nodes, in order."""
        interfaces: List[GraphQLInterfaceType] = []
        for node in nodes:
            for type_ in node.interfaces or []:
                # Note: While this could make assertions to get the correctly typed
                # value, that would throw immediately while type system validation
                # with validate_schema() will produce more actionable results.
                interfaces.append(cast(GraphQLInterfaceType, get_named_type(type_)))
        return interfaces

    def build_union_types(
        nodes: Collection[Union[UnionTypeDefinitionNode, UnionTypeExtensionNode]],
    ) -> List[GraphQLObjectType]:
        """Resolve the union member types listed in the given nodes, in order."""
        types: List[GraphQLObjectType] = []
        for node in nodes:
            for type_ in node.types or []:
                # Note: While this could make assertions to get the correctly typed
                # value, that would throw immediately while type system validation
                # with validate_schema() will produce more actionable results.
                types.append(cast(GraphQLObjectType, get_named_type(type_)))
        return types

    def build_object_type(ast_node: ObjectTypeDefinitionNode) -> GraphQLObjectType:
        """Build a new object type from its definition plus same-document extensions."""
        extension_nodes = type_extensions_map[ast_node.name.value]
        all_nodes: List[Union[ObjectTypeDefinitionNode, ObjectTypeExtensionNode]] = [
            ast_node,
            *extension_nodes,
        ]
        return GraphQLObjectType(
            name=ast_node.name.value,
            description=ast_node.description.value if ast_node.description else None,
            interfaces=lambda: build_interfaces(all_nodes),
            fields=lambda: build_field_map(all_nodes),
            ast_node=ast_node,
            extension_ast_nodes=extension_nodes,
        )

    def build_interface_type(
        ast_node: InterfaceTypeDefinitionNode,
    ) -> GraphQLInterfaceType:
        """Build a new interface type from its definition plus extensions."""
        extension_nodes = type_extensions_map[ast_node.name.value]
        all_nodes: List[
            Union[InterfaceTypeDefinitionNode, InterfaceTypeExtensionNode]
        ] = [ast_node, *extension_nodes]
        return GraphQLInterfaceType(
            name=ast_node.name.value,
            description=ast_node.description.value if ast_node.description else None,
            interfaces=lambda: build_interfaces(all_nodes),
            fields=lambda: build_field_map(all_nodes),
            ast_node=ast_node,
            extension_ast_nodes=extension_nodes,
        )

    def build_enum_type(ast_node: EnumTypeDefinitionNode) -> GraphQLEnumType:
        """Build a new enum type from its definition plus same-document extensions."""
        extension_nodes = type_extensions_map[ast_node.name.value]
        all_nodes: List[Union[EnumTypeDefinitionNode, EnumTypeExtensionNode]] = [
            ast_node,
            *extension_nodes,
        ]
        return GraphQLEnumType(
            name=ast_node.name.value,
            description=ast_node.description.value if ast_node.description else None,
            values=build_enum_value_map(all_nodes),
            ast_node=ast_node,
            extension_ast_nodes=extension_nodes,
        )

    def build_union_type(ast_node: UnionTypeDefinitionNode) -> GraphQLUnionType:
        """Build a new union type from its definition plus same-document extensions."""
        extension_nodes = type_extensions_map[ast_node.name.value]
        all_nodes: List[Union[UnionTypeDefinitionNode, UnionTypeExtensionNode]] = [
            ast_node,
            *extension_nodes,
        ]
        return GraphQLUnionType(
            name=ast_node.name.value,
            description=ast_node.description.value if ast_node.description else None,
            types=lambda: build_union_types(all_nodes),
            ast_node=ast_node,
            extension_ast_nodes=extension_nodes,
        )

    def build_scalar_type(ast_node: ScalarTypeDefinitionNode) -> GraphQLScalarType:
        """Build a new scalar type from its definition plus same-document extensions.

        A ``@specifiedBy`` URL on an extension overrides the one on the definition,
        the same way extend_scalar_type() treats already existing scalars.
        """
        extension_nodes = type_extensions_map[ast_node.name.value]
        specified_by_url = get_specified_by_url(ast_node)
        for extension_node in extension_nodes:
            specified_by_url = get_specified_by_url(extension_node) or specified_by_url
        return GraphQLScalarType(
            name=ast_node.name.value,
            description=ast_node.description.value if ast_node.description else None,
            specified_by_url=specified_by_url,
            ast_node=ast_node,
            extension_ast_nodes=extension_nodes,
        )

    def build_input_object_type(
        ast_node: InputObjectTypeDefinitionNode,
    ) -> GraphQLInputObjectType:
        """Build a new input object type from its definition plus extensions."""
        extension_nodes = type_extensions_map[ast_node.name.value]
        all_nodes: List[
            Union[InputObjectTypeDefinitionNode, InputObjectTypeExtensionNode]
        ] = [ast_node, *extension_nodes]
        return GraphQLInputObjectType(
            name=ast_node.name.value,
            description=ast_node.description.value if ast_node.description else None,
            fields=lambda: build_input_field_map(all_nodes),
            ast_node=ast_node,
            extension_ast_nodes=extension_nodes,
        )

    build_type_for_kind = cast(
        Dict[str, Callable[[TypeDefinitionNode], GraphQLNamedType]],
        {
            "object_type_definition": build_object_type,
            "interface_type_definition": build_interface_type,
            "enum_type_definition": build_enum_type,
            "union_type_definition": build_union_type,
            "scalar_type_definition": build_scalar_type,
            "input_object_type_definition": build_input_object_type,
        },
    )

    def build_type(ast_node: TypeDefinitionNode) -> GraphQLNamedType:
        """Build a new named type by dispatching on the definition node's kind."""
        try:
            # e.g. an "object_type_definition" node is built with build_object_type
            build_function = build_type_for_kind[ast_node.kind]
        except KeyError:  # pragma: no cover
            # Not reachable. All possible type definition nodes have been considered.
            raise TypeError(  # pragma: no cover
                f"Unexpected type definition node: {inspect(ast_node)}."
            )
        else:
            return build_function(ast_node)

    type_map: Dict[str, GraphQLNamedType] = {}
    for existing_type in schema_kwargs["types"] or ():
        type_map[existing_type.name] = extend_named_type(existing_type)
    for type_node in type_defs:
        name = type_node.name.value
        # Standard scalar / introspection names resolve to the builtin instances
        # instead of being rebuilt from the definition node.
        type_map[name] = std_type_map.get(name) or build_type(type_node)

    # Get the extended root operation types.
    operation_types: Dict[OperationType, GraphQLNamedType] = {}
    for operation_type in OperationType:
        original_type = schema_kwargs[operation_type.value]
        if original_type:
            operation_types[operation_type] = replace_named_type(original_type)
    # Then, incorporate schema definition and all schema extensions.
    if schema_def:
        operation_types.update(get_operation_types([schema_def]))
    if schema_extensions:
        operation_types.update(get_operation_types(schema_extensions))

    # Then produce and return the kwargs for a Schema with these types.
    get_operation = operation_types.get
    return GraphQLSchemaKwargs(
        query=get_operation(OperationType.QUERY),  # type: ignore
        mutation=get_operation(OperationType.MUTATION),  # type: ignore
        subscription=get_operation(OperationType.SUBSCRIPTION),  # type: ignore
        types=tuple(type_map.values()),
        directives=tuple(
            replace_directive(directive) for directive in schema_kwargs["directives"]
        )
        + tuple(build_directive(directive) for directive in directive_defs),
        # NOTE(review): the description comes only from a schema definition in
        # this document; an existing description is not carried over — confirm.
        description=(
            schema_def.description.value
            if schema_def and schema_def.description
            else None
        ),
        # NOTE(review): schema-level extensions are reset rather than copied
        # from schema_kwargs — confirm this is intended.
        extensions={},
        ast_node=schema_def or schema_kwargs["ast_node"],
        extension_ast_nodes=schema_kwargs["extension_ast_nodes"]
        + tuple(schema_extensions),
        assume_valid=assume_valid,
    )


# Builtin named types (specified scalars and introspection types) by name.
std_type_map: Mapping[str, Union[GraphQLNamedType, GraphQLObjectType]] = {
    **specified_scalar_types,
    **introspection_types,
}


def get_deprecation_reason(
    node: Union[EnumValueDefinitionNode, FieldDefinitionNode, InputValueDefinitionNode]
) -> Optional[str]:
    """Given a field, argument, input field or enum value node, get the reason.

    Returns the ``reason`` of the node's ``@deprecated`` directive as a string, or
    ``None`` if the node is not deprecated.
    """
    from ..execution import get_directive_values

    deprecated = get_directive_values(GraphQLDeprecatedDirective, node)
    return deprecated["reason"] if deprecated else None


def get_specified_by_url(
    node: Union[ScalarTypeDefinitionNode, ScalarTypeExtensionNode]
) -> Optional[str]:
    """Given a scalar node, return the string value for the specifiedByURL.

    Returns ``None`` if the node carries no ``@specifiedBy`` directive.
    """
    from ..execution import get_directive_values

    specified_by_url = get_directive_values(GraphQLSpecifiedByDirective, node)
    return specified_by_url["url"] if specified_by_url else None