64 lines
2.4 KiB
Python
64 lines
2.4 KiB
Python
from functools import partial, reduce
|
|
from inspect import isfunction
|
|
|
|
from typing import Callable, Iterator, Dict, List, Tuple, Any, Optional
|
|
|
|
# Names exported by ``from ... import *``.
__all__ = ["MiddlewareManager"]

# A field resolver is any callable; neither its arguments nor its (possibly
# awaitable) result are constrained here.
GraphQLFieldResolver = Callable[..., Any]


class MiddlewareManager:
    """Chain middleware around GraphQL field resolvers.

    The manager is built from middleware given either as plain functions or as
    objects exposing a ``resolve`` method. Every middleware callable receives
    the next callable of the chain as its first argument, followed by the
    arguments the resolver itself is called with.

    Middleware wraps the chain built so far in the order it was given, so the
    last middleware ends up outermost and runs first when the wrapped resolver
    is called.

    Resolvers may return plain values or awaitables ("AwaitableOrValue"), so
    every middleware has to check whether a value is awaitable before awaiting
    it.
    """

    # "__dict__" is kept in the slots so that users can attach their own
    # attributes (they are not used by this class).
    __slots__ = "__dict__", "middlewares", "_middleware_resolvers", "_cached_resolvers"

    # original resolver -> resolver wrapped in the middleware chain
    _cached_resolvers: Dict[GraphQLFieldResolver, GraphQLFieldResolver]
    # callables extracted from the middleware, or None when the constructor
    # received no middleware at all
    _middleware_resolvers: Optional[List[Callable]]

    def __init__(self, *middlewares: Any):
        self.middlewares = middlewares
        if middlewares:
            self._middleware_resolvers = list(get_middleware_resolvers(middlewares))
        else:
            self._middleware_resolvers = None
        self._cached_resolvers = {}

    def get_field_resolver(
        self, field_resolver: GraphQLFieldResolver
    ) -> GraphQLFieldResolver:
        """Return ``field_resolver`` wrapped in the middleware chain.

        When the constructor received no middleware, the resolver is returned
        unchanged. Otherwise the wrapped resolver is built once per distinct
        ``field_resolver`` and cached, so repeated calls return the very same
        object.
        """
        chain = self._middleware_resolvers
        if chain is None:
            return field_resolver
        cache = self._cached_resolvers
        try:
            return cache[field_resolver]
        except KeyError:
            pass
        wrapped = field_resolver
        for middleware_fn in chain:
            # each middleware gets the chain built so far as its "next" argument
            wrapped = partial(middleware_fn, wrapped)
        cache[field_resolver] = wrapped
        return wrapped
|
|
|
|
|
|
def get_middleware_resolvers(middlewares: Tuple[Any, ...]) -> Iterator[Callable]:
    """Get the resolver functions for the given middleware.

    Each entry of ``middlewares`` yields a callable that takes the next
    resolver of the chain as its first argument:

    * plain Python functions are used as they are;
    * any other object with a ``resolve`` attribute (e.g. a middleware
      instance, or a class defining ``resolve``) contributes that attribute;
    * any other callable that is not a class (bound methods,
      ``functools.partial`` objects, builtins, instances defining
      ``__call__``) is used as it is.

    Entries matching none of these (a class without ``resolve``, or a
    non-callable object) are skipped.
    """
    for middleware in middlewares:
        if isfunction(middleware):
            yield middleware
            continue
        # middleware provided as object (or class) with a 'resolve' method
        resolver_func = getattr(middleware, "resolve", None)
        if resolver_func is not None:
            yield resolver_func
        elif callable(middleware) and not isinstance(middleware, type):
            # isfunction() is False for bound methods, partials, builtins and
            # callable instances; use them directly instead of dropping them.
            # Classes are excluded: calling one would instantiate it rather
            # than resolve the field.
            yield middleware
|