diff --git a/python_graphql_client/graphql_client.py b/python_graphql_client/graphql_client.py index 1d3d104..4ecce25 100644 --- a/python_graphql_client/graphql_client.py +++ b/python_graphql_client/graphql_client.py @@ -82,6 +82,7 @@ async def subscribe( operation_name: str = None, headers: dict = {}, init_payload: dict = {}, + ws_subprotocol: str = "graphql-ws", ): """Make asynchronous request for GraphQL subscription.""" connection_init_message = json.dumps( @@ -91,13 +92,19 @@ async def subscribe( request_body = self.__request_body( query=query, variables=variables, operation_name=operation_name ) - request_message = json.dumps( - {"type": "start", "id": "1", "payload": request_body} - ) + + if ws_subprotocol == "graphql-ws": + protocol_specific_request = {"type": "start", "id": "1", "payload": request_body} + elif ws_subprotocol == "graphql-transport-ws": + protocol_specific_request = {"type": "subscribe", "id": "1", "payload": request_body} + else: + raise ValueError(f"Unknown subprotocol {ws_subprotocol}") + + request_message = json.dumps(protocol_specific_request) async with websockets.connect( self.endpoint, - subprotocols=["graphql-ws"], + subprotocols=[ws_subprotocol], extra_headers={**self.headers, **headers}, ) as websocket: await websocket.send(connection_init_message)