2025-12-01
This commit is contained in:
@@ -0,0 +1,63 @@
|
||||
from functools import partial, reduce
|
||||
from inspect import isfunction
|
||||
|
||||
from typing import Callable, Iterator, Dict, List, Tuple, Any, Optional
|
||||
|
||||
__all__ = ["MiddlewareManager"]
|
||||
|
||||
GraphQLFieldResolver = Callable[..., Any]
|
||||
|
||||
|
||||
class MiddlewareManager:
|
||||
"""Manager for the middleware chain.
|
||||
|
||||
This class helps to wrap resolver functions with the provided middleware functions
|
||||
and/or objects. The functions take the next middleware function as first argument.
|
||||
If middleware is provided as an object, it must provide a method ``resolve`` that is
|
||||
used as the middleware function.
|
||||
|
||||
Note that since resolvers return "AwaitableOrValue"s, all middleware functions
|
||||
must be aware of this and check whether values are awaitable before awaiting them.
|
||||
"""
|
||||
|
||||
# allow custom attributes (not used internally)
|
||||
__slots__ = "__dict__", "middlewares", "_middleware_resolvers", "_cached_resolvers"
|
||||
|
||||
_cached_resolvers: Dict[GraphQLFieldResolver, GraphQLFieldResolver]
|
||||
_middleware_resolvers: Optional[List[Callable]]
|
||||
|
||||
def __init__(self, *middlewares: Any):
|
||||
self.middlewares = middlewares
|
||||
self._middleware_resolvers = (
|
||||
list(get_middleware_resolvers(middlewares)) if middlewares else None
|
||||
)
|
||||
self._cached_resolvers = {}
|
||||
|
||||
def get_field_resolver(
|
||||
self, field_resolver: GraphQLFieldResolver
|
||||
) -> GraphQLFieldResolver:
|
||||
"""Wrap the provided resolver with the middleware.
|
||||
|
||||
Returns a function that chains the middleware functions with the provided
|
||||
resolver function.
|
||||
"""
|
||||
if self._middleware_resolvers is None:
|
||||
return field_resolver
|
||||
if field_resolver not in self._cached_resolvers:
|
||||
self._cached_resolvers[field_resolver] = reduce(
|
||||
lambda chained_fns, next_fn: partial(next_fn, chained_fns),
|
||||
self._middleware_resolvers,
|
||||
field_resolver,
|
||||
)
|
||||
return self._cached_resolvers[field_resolver]
|
||||
|
||||
|
||||
def get_middleware_resolvers(middlewares: Tuple[Any, ...]) -> Iterator[Callable]:
|
||||
"""Get a list of resolver functions from a list of classes or functions."""
|
||||
for middleware in middlewares:
|
||||
if isfunction(middleware):
|
||||
yield middleware
|
||||
else: # middleware provided as object with 'resolve' method
|
||||
resolver_func = getattr(middleware, "resolve", None)
|
||||
if resolver_func is not None:
|
||||
yield resolver_func
|
||||
Reference in New Issue
Block a user