import asyncio
import json
import logging
from contextlib import suppress
from ssl import SSLContext
from typing import Any, Dict, List, Optional, Tuple, Union, cast

from graphql import DocumentNode, ExecutionResult, print_ast
from websockets.datastructures import HeadersLike
from websockets.typing import Subprotocol

from .exceptions import (
    TransportProtocolError,
    TransportQueryError,
    TransportServerError,
)
from .websockets_base import WebsocketsTransportBase

log = logging.getLogger(__name__)


class WebsocketsTransport(WebsocketsTransportBase):
    """:ref:`Async Transport ` used to execute GraphQL queries on
    remote servers with websocket connection.

    This transport uses asyncio and the websockets library in order to send requests
    on a websocket connection.
    """

    # This transport supports two subprotocols and will autodetect the
    # subprotocol supported on the server
    APOLLO_SUBPROTOCOL = cast(Subprotocol, "graphql-ws")
    GRAPHQLWS_SUBPROTOCOL = cast(Subprotocol, "graphql-transport-ws")

    def __init__(
        self,
        url: str,
        headers: Optional[HeadersLike] = None,
        ssl: Union[SSLContext, bool] = False,
        init_payload: Optional[Dict[str, Any]] = None,
        connect_timeout: Optional[Union[int, float]] = 10,
        close_timeout: Optional[Union[int, float]] = 10,
        ack_timeout: Optional[Union[int, float]] = 10,
        keep_alive_timeout: Optional[Union[int, float]] = None,
        ping_interval: Optional[Union[int, float]] = None,
        pong_timeout: Optional[Union[int, float]] = None,
        answer_pings: bool = True,
        connect_args: Optional[Dict[str, Any]] = None,
        subprotocols: Optional[List[Subprotocol]] = None,
    ) -> None:
        """Initialize the transport with the given parameters.

        :param url: The GraphQL server URL. Example: 'wss://server.com:PORT/graphql'.
        :param headers: Dict of HTTP Headers.
        :param ssl: ssl_context of the connection. Use ssl=False to disable encryption
        :param init_payload: Dict of the payload sent in the connection_init message.
            By default: an empty dict (a new one for each transport).
        :param connect_timeout: Timeout in seconds for the establishment
            of the websocket connection. If None is provided this will wait forever.
        :param close_timeout: Timeout in seconds for the close. If None is provided
            this will wait forever.
        :param ack_timeout: Timeout in seconds to wait for the connection_ack message
            from the server. If None is provided this will wait forever.
        :param keep_alive_timeout: Optional Timeout in seconds to receive
            a sign of liveness from the server.
        :param ping_interval: Delay in seconds between pings sent by the client to
            the backend for the graphql-ws protocol. None (by default) means that
            we don't send pings. Note: there are also pings sent by the underlying
            websockets protocol. See the
            :ref:`keepalive documentation `
            for more information about this.
        :param pong_timeout: Delay in seconds to receive a pong from the backend
            after we sent a ping (only for the graphql-ws protocol).
            By default equal to half of the ping_interval.
        :param answer_pings: Whether the client answers the pings from the backend
            (for the graphql-ws protocol).
            By default: True
        :param connect_args: Other parameters forwarded to
            `websockets.connect `_
            By default: an empty dict (a new one for each transport).
        :param subprotocols: list of subprotocols sent to the
            backend in the 'subprotocols' http header.
            By default: both apollo and graphql-ws subprotocols.
        """

        # Build a fresh dict per instance instead of using a shared mutable
        # default argument, so no state can leak between transports.
        if init_payload is None:
            init_payload = {}
        if connect_args is None:
            connect_args = {}

        super().__init__(
            url,
            headers,
            ssl,
            init_payload,
            connect_timeout,
            close_timeout,
            ack_timeout,
            keep_alive_timeout,
            connect_args,
        )

        self.ping_interval: Optional[Union[int, float]] = ping_interval

        # Always assign pong_timeout so the attribute exists even when the
        # client does not send pings; when pings are enabled and no explicit
        # timeout was given, default to half of the ping interval.
        self.pong_timeout: Optional[Union[int, float]] = pong_timeout
        if ping_interval is not None and pong_timeout is None:
            self.pong_timeout = ping_interval / 2

        self.answer_pings: bool = answer_pings

        self.send_ping_task: Optional[asyncio.Future] = None

        self.ping_received: asyncio.Event = asyncio.Event()
        """ping_received is an asyncio Event which will fire  each time
        a ping is received with the graphql-ws protocol"""

        self.pong_received: asyncio.Event = asyncio.Event()
        """pong_received is an asyncio Event which will fire  each time
        a pong is received with the graphql-ws protocol"""

        if subprotocols is None:
            self.supported_subprotocols = [
                self.APOLLO_SUBPROTOCOL,
                self.GRAPHQLWS_SUBPROTOCOL,
            ]
        else:
            self.supported_subprotocols = subprotocols

    async def _wait_ack(self) -> None:
        """Wait for the connection_ack message. Keep alive messages are ignored.

        :raises TransportProtocolError: if a message other than connection_ack
            or ka is received first.
        """

        while True:
            init_answer = await self._receive()

            answer_type, _, _ = self._parse_answer(init_answer)

            if answer_type == "connection_ack":
                return

            if answer_type != "ka":
                raise TransportProtocolError(
                    "Websocket server did not return a connection ack"
                )

    async def _send_init_message_and_wait_ack(self) -> None:
        """Send init message to the provided websocket and wait for the connection ACK.

        If the answer is not a connection_ack message, an exception is raised
        (TransportProtocolError); if no answer arrives within ack_timeout,
        asyncio.TimeoutError is raised.
        """

        init_message = json.dumps(
            {"type": "connection_init", "payload": self.init_payload}
        )

        await self._send(init_message)

        # Wait for the connection_ack message or raise a TimeoutError
        await asyncio.wait_for(self._wait_ack(), self.ack_timeout)

    async def _initialize(self):
        await self._send_init_message_and_wait_ack()

    async def send_ping(self, payload: Optional[Any] = None) -> None:
        """Send a ping message for the graphql-ws protocol"""

        ping_message: Dict[str, Any] = {"type": "ping"}

        if payload is not None:
            ping_message["payload"] = payload

        await self._send(json.dumps(ping_message))

    async def send_pong(self, payload: Optional[Any] = None) -> None:
        """Send a pong message for the graphql-ws protocol"""

        pong_message: Dict[str, Any] = {"type": "pong"}

        if payload is not None:
            pong_message["payload"] = payload

        await self._send(json.dumps(pong_message))

    async def _send_stop_message(self, query_id: int) -> None:
        """Send stop message to the provided websocket connection and query_id.

        The server should afterwards return a 'complete' message.
        """

        stop_message = json.dumps({"id": str(query_id), "type": "stop"})

        await self._send(stop_message)

    async def _send_complete_message(self, query_id: int) -> None:
        """Send a complete message for the provided query_id.

        This is only for the graphql-ws protocol.
        """

        complete_message = json.dumps({"id": str(query_id), "type": "complete"})

        await self._send(complete_message)

    async def _stop_listener(self, query_id: int):
        """Stop the listener corresponding to the query_id depending on the
        detected backend protocol.

        For apollo: send a "stop" message
                    (a "complete" message will be sent from the backend)

        For graphql-ws: send a "complete" message and simulate the reception
                        of a "complete" message from the backend
        """
        log.debug(f"stop listener {query_id}")

        if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL:
            await self._send_complete_message(query_id)
            await self.listeners[query_id].put(("complete", None))
        else:
            await self._send_stop_message(query_id)

    async def _send_connection_terminate_message(self) -> None:
        """Send a connection_terminate message to the provided websocket connection.

        This message indicates that the connection will disconnect.
        """

        connection_terminate_message = json.dumps({"type": "connection_terminate"})

        await self._send(connection_terminate_message)

    async def _send_query(
        self,
        document: DocumentNode,
        variable_values: Optional[Dict[str, Any]] = None,
        operation_name: Optional[str] = None,
    ) -> int:
        """Send a query to the provided websocket connection.

        We use an incremented id to reference the query.

        Returns the used id for this query.
        """

        query_id = self.next_query_id
        self.next_query_id += 1

        payload: Dict[str, Any] = {"query": print_ast(document)}
        if variable_values:
            payload["variables"] = variable_values
        if operation_name:
            payload["operationName"] = operation_name

        query_type = "start"

        if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL:
            query_type = "subscribe"

        query_str = json.dumps(
            {"id": str(query_id), "type": query_type, "payload": payload}
        )

        await self._send(query_str)

        return query_id

    async def _connection_terminate(self):
        # Only the apollo protocol has a connection_terminate message
        if self.subprotocol == self.APOLLO_SUBPROTOCOL:
            await self._send_connection_terminate_message()

    def _parse_answer_graphqlws(
        self, json_answer: Dict[str, Any]
    ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]:
        """Parse the answer received from the server if the server supports the
        graphql-ws protocol.

        Returns a tuple consisting of:
            - the answer_type (between:
              'connection_ack', 'ping', 'pong', 'data', 'error', 'complete')
            - the answer id (Integer) if received or None
            - an execution Result if the answer_type is 'data' or None

        Differences with the apollo websockets protocol (superclass):
            - the "data" message is now called "next"
            - the "stop" message is now called "complete"
            - there is no connection_terminate or connection_error messages
            - instead of a unidirectional keep-alive (ka) message from server to client,
              there is now the possibility to send bidirectional ping/pong messages
            - connection_ack has an optional payload
            - the 'error' answer type returns a list of errors instead of a single error

        :raises TransportQueryError: on an 'error' message from the server.
        :raises TransportProtocolError: on a malformed or unknown message.
        """

        answer_type: str = ""
        answer_id: Optional[int] = None
        execution_result: Optional[ExecutionResult] = None

        try:
            answer_type = str(json_answer.get("type"))

            if answer_type in ["next", "error", "complete"]:
                answer_id = int(str(json_answer.get("id")))

                if answer_type == "next" or answer_type == "error":

                    payload = json_answer.get("payload")

                    if answer_type == "next":

                        if not isinstance(payload, dict):
                            raise ValueError("payload is not a dict")

                        if "errors" not in payload and "data" not in payload:
                            raise ValueError(
                                "payload does not contain 'data' or 'errors' fields"
                            )

                        execution_result = ExecutionResult(
                            errors=payload.get("errors"),
                            data=payload.get("data"),
                            extensions=payload.get("extensions"),
                        )

                        # Saving answer_type as 'data' to be understood with superclass
                        answer_type = "data"

                    elif answer_type == "error":

                        if not isinstance(payload, list):
                            raise ValueError("payload is not a list")

                        # An empty error list would otherwise make payload[0]
                        # raise an IndexError that escapes the protocol-error
                        # handling below.
                        if not payload:
                            raise ValueError("payload is an empty list")

                        raise TransportQueryError(
                            str(payload[0]), query_id=answer_id, errors=payload
                        )

            elif answer_type in ["ping", "pong", "connection_ack"]:
                self.payloads[answer_type] = json_answer.get("payload", None)

            else:
                raise ValueError

            # Any valid message from the server is a sign of liveness
            if self.check_keep_alive_task is not None:
                self._next_keep_alive_message.set()

        except ValueError as e:
            raise TransportProtocolError(
                f"Server did not return a GraphQL result: {json_answer}"
            ) from e

        return answer_type, answer_id, execution_result

    def _parse_answer_apollo(
        self, json_answer: Dict[str, Any]
    ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]:
        """Parse the answer received from the server if the server supports the
        apollo websockets protocol.

        Returns a tuple consisting of:
            - the answer_type (between:
              'connection_ack', 'ka', 'connection_error', 'data', 'error', 'complete')
            - the answer id (Integer) if received or None
            - an execution Result if the answer_type is 'data' or None

        :raises TransportQueryError: on an 'error' message from the server.
        :raises TransportServerError: on a 'connection_error' message.
        :raises TransportProtocolError: on a malformed or unknown message.
        """

        answer_type: str = ""
        answer_id: Optional[int] = None
        execution_result: Optional[ExecutionResult] = None

        try:
            answer_type = str(json_answer.get("type"))

            if answer_type in ["data", "error", "complete"]:
                answer_id = int(str(json_answer.get("id")))

                if answer_type == "data" or answer_type == "error":

                    payload = json_answer.get("payload")

                    if not isinstance(payload, dict):
                        raise ValueError("payload is not a dict")

                    if answer_type == "data":

                        if "errors" not in payload and "data" not in payload:
                            raise ValueError(
                                "payload does not contain 'data' or 'errors' fields"
                            )

                        execution_result = ExecutionResult(
                            errors=payload.get("errors"),
                            data=payload.get("data"),
                            extensions=payload.get("extensions"),
                        )

                    elif answer_type == "error":

                        raise TransportQueryError(
                            str(payload), query_id=answer_id, errors=[payload]
                        )

            elif answer_type == "ka":
                # Keep-alive message
                if self.check_keep_alive_task is not None:
                    self._next_keep_alive_message.set()
            elif answer_type == "connection_ack":
                pass
            elif answer_type == "connection_error":
                error_payload = json_answer.get("payload")
                raise TransportServerError(f"Server error: '{repr(error_payload)}'")
            else:
                raise ValueError

        except ValueError as e:
            raise TransportProtocolError(
                f"Server did not return a GraphQL result: {json_answer}"
            ) from e

        return answer_type, answer_id, execution_result

    def _parse_answer(
        self, answer: str
    ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]:
        """Parse the answer received from the server depending on
        the detected subprotocol.

        :raises TransportProtocolError: if the answer is not valid JSON or is
            not a JSON object.
        """
        try:
            json_answer = json.loads(answer)
        except ValueError as e:
            raise TransportProtocolError(
                f"Server did not return a GraphQL result: {answer}"
            ) from e

        # Both protocol parsers call json_answer.get(...): a JSON array,
        # string or number would otherwise raise an uncaught AttributeError.
        if not isinstance(json_answer, dict):
            raise TransportProtocolError(
                f"Server did not return a GraphQL result: {answer}"
            )

        if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL:
            return self._parse_answer_graphqlws(json_answer)

        return self._parse_answer_apollo(json_answer)

    async def _send_ping_coro(self) -> None:
        """Coroutine to periodically send a ping from the client to the backend.

        Only used for the graphql-ws protocol.

        Send a ping every ping_interval seconds.
        Close the connection if a pong is not received within pong_timeout seconds.
        """

        assert self.ping_interval is not None

        try:
            while True:
                await asyncio.sleep(self.ping_interval)

                # Clear right before pinging so that only a pong received
                # after this ping counts; an unsolicited pong from the server
                # (allowed at any time by the protocol) received earlier must
                # not satisfy the wait below.
                self.pong_received.clear()

                await self.send_ping()

                await asyncio.wait_for(self.pong_received.wait(), self.pong_timeout)

        except asyncio.TimeoutError:
            # No pong received in the appropriate time, close with error
            # If the timeout happens during a close already in progress, do nothing
            if self.close_task is None:
                await self._fail(
                    TransportServerError(
                        f"No pong received after {self.pong_timeout!r} seconds"
                    ),
                    clean_close=False,
                )

    async def _handle_answer(
        self,
        answer_type: str,
        answer_id: Optional[int],
        execution_result: Optional[ExecutionResult],
    ) -> None:

        # Put the answer in the queue
        await super()._handle_answer(answer_type, answer_id, execution_result)

        # Answer pong to ping for graphql-ws protocol
        if answer_type == "ping":
            self.ping_received.set()
            if self.answer_pings:
                await self.send_pong()

        elif answer_type == "pong":
            self.pong_received.set()

    async def _after_connect(self):

        # Find the backend subprotocol returned in the response headers
        response_headers = self.websocket.response_headers
        try:
            self.subprotocol = response_headers["Sec-WebSocket-Protocol"]
        except KeyError:
            # If the server does not send the subprotocol header, using
            # the apollo subprotocol by default
            self.subprotocol = self.APOLLO_SUBPROTOCOL

        log.debug(f"backend subprotocol returned: {self.subprotocol!r}")

    async def _after_initialize(self):

        # If requested, create a task to send periodic pings to the backend
        if (
            self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL
            and self.ping_interval is not None
        ):

            self.send_ping_task = asyncio.ensure_future(self._send_ping_coro())

    async def _close_hook(self):

        # Properly shut down the send ping task if enabled
        if self.send_ping_task is not None:
            self.send_ping_task.cancel()
            with suppress(asyncio.CancelledError):
                await self.send_ping_task
            self.send_ping_task = None