2025-12-01
This commit is contained in:
@@ -0,0 +1,514 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from contextlib import suppress
|
||||
from ssl import SSLContext
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union, cast
|
||||
|
||||
from graphql import DocumentNode, ExecutionResult, print_ast
|
||||
from websockets.datastructures import HeadersLike
|
||||
from websockets.typing import Subprotocol
|
||||
|
||||
from .exceptions import (
|
||||
TransportProtocolError,
|
||||
TransportQueryError,
|
||||
TransportServerError,
|
||||
)
|
||||
from .websockets_base import WebsocketsTransportBase
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WebsocketsTransport(WebsocketsTransportBase):
|
||||
""":ref:`Async Transport <async_transports>` used to execute GraphQL queries on
|
||||
remote servers with websocket connection.
|
||||
|
||||
This transport uses asyncio and the websockets library in order to send requests
|
||||
on a websocket connection.
|
||||
"""
|
||||
|
||||
# This transport supports two subprotocols and will autodetect the
|
||||
# subprotocol supported on the server
|
||||
APOLLO_SUBPROTOCOL = cast(Subprotocol, "graphql-ws")
|
||||
GRAPHQLWS_SUBPROTOCOL = cast(Subprotocol, "graphql-transport-ws")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
url: str,
|
||||
headers: Optional[HeadersLike] = None,
|
||||
ssl: Union[SSLContext, bool] = False,
|
||||
init_payload: Dict[str, Any] = {},
|
||||
connect_timeout: Optional[Union[int, float]] = 10,
|
||||
close_timeout: Optional[Union[int, float]] = 10,
|
||||
ack_timeout: Optional[Union[int, float]] = 10,
|
||||
keep_alive_timeout: Optional[Union[int, float]] = None,
|
||||
ping_interval: Optional[Union[int, float]] = None,
|
||||
pong_timeout: Optional[Union[int, float]] = None,
|
||||
answer_pings: bool = True,
|
||||
connect_args: Dict[str, Any] = {},
|
||||
subprotocols: Optional[List[Subprotocol]] = None,
|
||||
) -> None:
|
||||
"""Initialize the transport with the given parameters.
|
||||
|
||||
:param url: The GraphQL server URL. Example: 'wss://server.com:PORT/graphql'.
|
||||
:param headers: Dict of HTTP Headers.
|
||||
:param ssl: ssl_context of the connection. Use ssl=False to disable encryption
|
||||
:param init_payload: Dict of the payload sent in the connection_init message.
|
||||
:param connect_timeout: Timeout in seconds for the establishment
|
||||
of the websocket connection. If None is provided this will wait forever.
|
||||
:param close_timeout: Timeout in seconds for the close. If None is provided
|
||||
this will wait forever.
|
||||
:param ack_timeout: Timeout in seconds to wait for the connection_ack message
|
||||
from the server. If None is provided this will wait forever.
|
||||
:param keep_alive_timeout: Optional Timeout in seconds to receive
|
||||
a sign of liveness from the server.
|
||||
:param ping_interval: Delay in seconds between pings sent by the client to
|
||||
the backend for the graphql-ws protocol. None (by default) means that
|
||||
we don't send pings. Note: there are also pings sent by the underlying
|
||||
websockets protocol. See the
|
||||
:ref:`keepalive documentation <websockets_transport_keepalives>`
|
||||
for more information about this.
|
||||
:param pong_timeout: Delay in seconds to receive a pong from the backend
|
||||
after we sent a ping (only for the graphql-ws protocol).
|
||||
By default equal to half of the ping_interval.
|
||||
:param answer_pings: Whether the client answers the pings from the backend
|
||||
(for the graphql-ws protocol).
|
||||
By default: True
|
||||
:param connect_args: Other parameters forwarded to
|
||||
`websockets.connect <https://websockets.readthedocs.io/en/stable/reference/\
|
||||
client.html#opening-a-connection>`_
|
||||
:param subprotocols: list of subprotocols sent to the
|
||||
backend in the 'subprotocols' http header.
|
||||
By default: both apollo and graphql-ws subprotocols.
|
||||
"""
|
||||
|
||||
super().__init__(
|
||||
url,
|
||||
headers,
|
||||
ssl,
|
||||
init_payload,
|
||||
connect_timeout,
|
||||
close_timeout,
|
||||
ack_timeout,
|
||||
keep_alive_timeout,
|
||||
connect_args,
|
||||
)
|
||||
|
||||
self.ping_interval: Optional[Union[int, float]] = ping_interval
|
||||
self.pong_timeout: Optional[Union[int, float]]
|
||||
self.answer_pings: bool = answer_pings
|
||||
|
||||
if ping_interval is not None:
|
||||
if pong_timeout is None:
|
||||
self.pong_timeout = ping_interval / 2
|
||||
else:
|
||||
self.pong_timeout = pong_timeout
|
||||
|
||||
self.send_ping_task: Optional[asyncio.Future] = None
|
||||
|
||||
self.ping_received: asyncio.Event = asyncio.Event()
|
||||
"""ping_received is an asyncio Event which will fire each time
|
||||
a ping is received with the graphql-ws protocol"""
|
||||
|
||||
self.pong_received: asyncio.Event = asyncio.Event()
|
||||
"""pong_received is an asyncio Event which will fire each time
|
||||
a pong is received with the graphql-ws protocol"""
|
||||
|
||||
if subprotocols is None:
|
||||
self.supported_subprotocols = [
|
||||
self.APOLLO_SUBPROTOCOL,
|
||||
self.GRAPHQLWS_SUBPROTOCOL,
|
||||
]
|
||||
else:
|
||||
self.supported_subprotocols = subprotocols
|
||||
|
||||
async def _wait_ack(self) -> None:
|
||||
"""Wait for the connection_ack message. Keep alive messages are ignored"""
|
||||
|
||||
while True:
|
||||
init_answer = await self._receive()
|
||||
|
||||
answer_type, answer_id, execution_result = self._parse_answer(init_answer)
|
||||
|
||||
if answer_type == "connection_ack":
|
||||
return
|
||||
|
||||
if answer_type != "ka":
|
||||
raise TransportProtocolError(
|
||||
"Websocket server did not return a connection ack"
|
||||
)
|
||||
|
||||
async def _send_init_message_and_wait_ack(self) -> None:
|
||||
"""Send init message to the provided websocket and wait for the connection ACK.
|
||||
|
||||
If the answer is not a connection_ack message, we will return an Exception.
|
||||
"""
|
||||
|
||||
init_message = json.dumps(
|
||||
{"type": "connection_init", "payload": self.init_payload}
|
||||
)
|
||||
|
||||
await self._send(init_message)
|
||||
|
||||
# Wait for the connection_ack message or raise a TimeoutError
|
||||
await asyncio.wait_for(self._wait_ack(), self.ack_timeout)
|
||||
|
||||
async def _initialize(self):
|
||||
await self._send_init_message_and_wait_ack()
|
||||
|
||||
async def send_ping(self, payload: Optional[Any] = None) -> None:
|
||||
"""Send a ping message for the graphql-ws protocol"""
|
||||
|
||||
ping_message = {"type": "ping"}
|
||||
|
||||
if payload is not None:
|
||||
ping_message["payload"] = payload
|
||||
|
||||
await self._send(json.dumps(ping_message))
|
||||
|
||||
async def send_pong(self, payload: Optional[Any] = None) -> None:
|
||||
"""Send a pong message for the graphql-ws protocol"""
|
||||
|
||||
pong_message = {"type": "pong"}
|
||||
|
||||
if payload is not None:
|
||||
pong_message["payload"] = payload
|
||||
|
||||
await self._send(json.dumps(pong_message))
|
||||
|
||||
async def _send_stop_message(self, query_id: int) -> None:
|
||||
"""Send stop message to the provided websocket connection and query_id.
|
||||
|
||||
The server should afterwards return a 'complete' message.
|
||||
"""
|
||||
|
||||
stop_message = json.dumps({"id": str(query_id), "type": "stop"})
|
||||
|
||||
await self._send(stop_message)
|
||||
|
||||
async def _send_complete_message(self, query_id: int) -> None:
|
||||
"""Send a complete message for the provided query_id.
|
||||
|
||||
This is only for the graphql-ws protocol.
|
||||
"""
|
||||
|
||||
complete_message = json.dumps({"id": str(query_id), "type": "complete"})
|
||||
|
||||
await self._send(complete_message)
|
||||
|
||||
async def _stop_listener(self, query_id: int):
|
||||
"""Stop the listener corresponding to the query_id depending on the
|
||||
detected backend protocol.
|
||||
|
||||
For apollo: send a "stop" message
|
||||
(a "complete" message will be sent from the backend)
|
||||
|
||||
For graphql-ws: send a "complete" message and simulate the reception
|
||||
of a "complete" message from the backend
|
||||
"""
|
||||
log.debug(f"stop listener {query_id}")
|
||||
|
||||
if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL:
|
||||
await self._send_complete_message(query_id)
|
||||
await self.listeners[query_id].put(("complete", None))
|
||||
else:
|
||||
await self._send_stop_message(query_id)
|
||||
|
||||
async def _send_connection_terminate_message(self) -> None:
|
||||
"""Send a connection_terminate message to the provided websocket connection.
|
||||
|
||||
This message indicates that the connection will disconnect.
|
||||
"""
|
||||
|
||||
connection_terminate_message = json.dumps({"type": "connection_terminate"})
|
||||
|
||||
await self._send(connection_terminate_message)
|
||||
|
||||
async def _send_query(
|
||||
self,
|
||||
document: DocumentNode,
|
||||
variable_values: Optional[Dict[str, Any]] = None,
|
||||
operation_name: Optional[str] = None,
|
||||
) -> int:
|
||||
"""Send a query to the provided websocket connection.
|
||||
|
||||
We use an incremented id to reference the query.
|
||||
|
||||
Returns the used id for this query.
|
||||
"""
|
||||
|
||||
query_id = self.next_query_id
|
||||
self.next_query_id += 1
|
||||
|
||||
payload: Dict[str, Any] = {"query": print_ast(document)}
|
||||
if variable_values:
|
||||
payload["variables"] = variable_values
|
||||
if operation_name:
|
||||
payload["operationName"] = operation_name
|
||||
|
||||
query_type = "start"
|
||||
|
||||
if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL:
|
||||
query_type = "subscribe"
|
||||
|
||||
query_str = json.dumps(
|
||||
{"id": str(query_id), "type": query_type, "payload": payload}
|
||||
)
|
||||
|
||||
await self._send(query_str)
|
||||
|
||||
return query_id
|
||||
|
||||
async def _connection_terminate(self):
|
||||
if self.subprotocol == self.APOLLO_SUBPROTOCOL:
|
||||
await self._send_connection_terminate_message()
|
||||
|
||||
def _parse_answer_graphqlws(
|
||||
self, json_answer: Dict[str, Any]
|
||||
) -> Tuple[str, Optional[int], Optional[ExecutionResult]]:
|
||||
"""Parse the answer received from the server if the server supports the
|
||||
graphql-ws protocol.
|
||||
|
||||
Returns a list consisting of:
|
||||
- the answer_type (between:
|
||||
'connection_ack', 'ping', 'pong', 'data', 'error', 'complete')
|
||||
- the answer id (Integer) if received or None
|
||||
- an execution Result if the answer_type is 'data' or None
|
||||
|
||||
Differences with the apollo websockets protocol (superclass):
|
||||
- the "data" message is now called "next"
|
||||
- the "stop" message is now called "complete"
|
||||
- there is no connection_terminate or connection_error messages
|
||||
- instead of a unidirectional keep-alive (ka) message from server to client,
|
||||
there is now the possibility to send bidirectional ping/pong messages
|
||||
- connection_ack has an optional payload
|
||||
- the 'error' answer type returns a list of errors instead of a single error
|
||||
"""
|
||||
|
||||
answer_type: str = ""
|
||||
answer_id: Optional[int] = None
|
||||
execution_result: Optional[ExecutionResult] = None
|
||||
|
||||
try:
|
||||
answer_type = str(json_answer.get("type"))
|
||||
|
||||
if answer_type in ["next", "error", "complete"]:
|
||||
answer_id = int(str(json_answer.get("id")))
|
||||
|
||||
if answer_type == "next" or answer_type == "error":
|
||||
|
||||
payload = json_answer.get("payload")
|
||||
|
||||
if answer_type == "next":
|
||||
|
||||
if not isinstance(payload, dict):
|
||||
raise ValueError("payload is not a dict")
|
||||
|
||||
if "errors" not in payload and "data" not in payload:
|
||||
raise ValueError(
|
||||
"payload does not contain 'data' or 'errors' fields"
|
||||
)
|
||||
|
||||
execution_result = ExecutionResult(
|
||||
errors=payload.get("errors"),
|
||||
data=payload.get("data"),
|
||||
extensions=payload.get("extensions"),
|
||||
)
|
||||
|
||||
# Saving answer_type as 'data' to be understood with superclass
|
||||
answer_type = "data"
|
||||
|
||||
elif answer_type == "error":
|
||||
|
||||
if not isinstance(payload, list):
|
||||
raise ValueError("payload is not a list")
|
||||
|
||||
raise TransportQueryError(
|
||||
str(payload[0]), query_id=answer_id, errors=payload
|
||||
)
|
||||
|
||||
elif answer_type in ["ping", "pong", "connection_ack"]:
|
||||
self.payloads[answer_type] = json_answer.get("payload", None)
|
||||
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
if self.check_keep_alive_task is not None:
|
||||
self._next_keep_alive_message.set()
|
||||
|
||||
except ValueError as e:
|
||||
raise TransportProtocolError(
|
||||
f"Server did not return a GraphQL result: {json_answer}"
|
||||
) from e
|
||||
|
||||
return answer_type, answer_id, execution_result
|
||||
|
||||
def _parse_answer_apollo(
|
||||
self, json_answer: Dict[str, Any]
|
||||
) -> Tuple[str, Optional[int], Optional[ExecutionResult]]:
|
||||
"""Parse the answer received from the server if the server supports the
|
||||
apollo websockets protocol.
|
||||
|
||||
Returns a list consisting of:
|
||||
- the answer_type (between:
|
||||
'connection_ack', 'ka', 'connection_error', 'data', 'error', 'complete')
|
||||
- the answer id (Integer) if received or None
|
||||
- an execution Result if the answer_type is 'data' or None
|
||||
"""
|
||||
|
||||
answer_type: str = ""
|
||||
answer_id: Optional[int] = None
|
||||
execution_result: Optional[ExecutionResult] = None
|
||||
|
||||
try:
|
||||
answer_type = str(json_answer.get("type"))
|
||||
|
||||
if answer_type in ["data", "error", "complete"]:
|
||||
answer_id = int(str(json_answer.get("id")))
|
||||
|
||||
if answer_type == "data" or answer_type == "error":
|
||||
|
||||
payload = json_answer.get("payload")
|
||||
|
||||
if not isinstance(payload, dict):
|
||||
raise ValueError("payload is not a dict")
|
||||
|
||||
if answer_type == "data":
|
||||
|
||||
if "errors" not in payload and "data" not in payload:
|
||||
raise ValueError(
|
||||
"payload does not contain 'data' or 'errors' fields"
|
||||
)
|
||||
|
||||
execution_result = ExecutionResult(
|
||||
errors=payload.get("errors"),
|
||||
data=payload.get("data"),
|
||||
extensions=payload.get("extensions"),
|
||||
)
|
||||
|
||||
elif answer_type == "error":
|
||||
|
||||
raise TransportQueryError(
|
||||
str(payload), query_id=answer_id, errors=[payload]
|
||||
)
|
||||
|
||||
elif answer_type == "ka":
|
||||
# Keep-alive message
|
||||
if self.check_keep_alive_task is not None:
|
||||
self._next_keep_alive_message.set()
|
||||
elif answer_type == "connection_ack":
|
||||
pass
|
||||
elif answer_type == "connection_error":
|
||||
error_payload = json_answer.get("payload")
|
||||
raise TransportServerError(f"Server error: '{repr(error_payload)}'")
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
except ValueError as e:
|
||||
raise TransportProtocolError(
|
||||
f"Server did not return a GraphQL result: {json_answer}"
|
||||
) from e
|
||||
|
||||
return answer_type, answer_id, execution_result
|
||||
|
||||
def _parse_answer(
|
||||
self, answer: str
|
||||
) -> Tuple[str, Optional[int], Optional[ExecutionResult]]:
|
||||
"""Parse the answer received from the server depending on
|
||||
the detected subprotocol.
|
||||
"""
|
||||
try:
|
||||
json_answer = json.loads(answer)
|
||||
except ValueError:
|
||||
raise TransportProtocolError(
|
||||
f"Server did not return a GraphQL result: {answer}"
|
||||
)
|
||||
|
||||
if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL:
|
||||
return self._parse_answer_graphqlws(json_answer)
|
||||
|
||||
return self._parse_answer_apollo(json_answer)
|
||||
|
||||
async def _send_ping_coro(self) -> None:
|
||||
"""Coroutine to periodically send a ping from the client to the backend.
|
||||
|
||||
Only used for the graphql-ws protocol.
|
||||
|
||||
Send a ping every ping_interval seconds.
|
||||
Close the connection if a pong is not received within pong_timeout seconds.
|
||||
"""
|
||||
|
||||
assert self.ping_interval is not None
|
||||
|
||||
try:
|
||||
while True:
|
||||
await asyncio.sleep(self.ping_interval)
|
||||
|
||||
await self.send_ping()
|
||||
|
||||
await asyncio.wait_for(self.pong_received.wait(), self.pong_timeout)
|
||||
|
||||
# Reset for the next iteration
|
||||
self.pong_received.clear()
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
# No pong received in the appriopriate time, close with error
|
||||
# If the timeout happens during a close already in progress, do nothing
|
||||
if self.close_task is None:
|
||||
await self._fail(
|
||||
TransportServerError(
|
||||
f"No pong received after {self.pong_timeout!r} seconds"
|
||||
),
|
||||
clean_close=False,
|
||||
)
|
||||
|
||||
async def _handle_answer(
|
||||
self,
|
||||
answer_type: str,
|
||||
answer_id: Optional[int],
|
||||
execution_result: Optional[ExecutionResult],
|
||||
) -> None:
|
||||
|
||||
# Put the answer in the queue
|
||||
await super()._handle_answer(answer_type, answer_id, execution_result)
|
||||
|
||||
# Answer pong to ping for graphql-ws protocol
|
||||
if answer_type == "ping":
|
||||
self.ping_received.set()
|
||||
if self.answer_pings:
|
||||
await self.send_pong()
|
||||
|
||||
elif answer_type == "pong":
|
||||
self.pong_received.set()
|
||||
|
||||
async def _after_connect(self):
|
||||
|
||||
# Find the backend subprotocol returned in the response headers
|
||||
response_headers = self.websocket.response_headers
|
||||
try:
|
||||
self.subprotocol = response_headers["Sec-WebSocket-Protocol"]
|
||||
except KeyError:
|
||||
# If the server does not send the subprotocol header, using
|
||||
# the apollo subprotocol by default
|
||||
self.subprotocol = self.APOLLO_SUBPROTOCOL
|
||||
|
||||
log.debug(f"backend subprotocol returned: {self.subprotocol!r}")
|
||||
|
||||
async def _after_initialize(self):
|
||||
|
||||
# If requested, create a task to send periodic pings to the backend
|
||||
if (
|
||||
self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL
|
||||
and self.ping_interval is not None
|
||||
):
|
||||
|
||||
self.send_ping_task = asyncio.ensure_future(self._send_ping_coro())
|
||||
|
||||
async def _close_hook(self):
|
||||
|
||||
# Properly shut down the send ping task if enabled
|
||||
if self.send_ping_task is not None:
|
||||
self.send_ping_task.cancel()
|
||||
with suppress(asyncio.CancelledError):
|
||||
await self.send_ping_task
|
||||
self.send_ping_task = None
|
||||
Reference in New Issue
Block a user