Skip to content

Commit 268e1f7

Browse files
authored
Merge pull request #30 from DrWacker/subscription_headers
Subscription headers
2 parents 72d491d + 990b9b4 commit 268e1f7

File tree

3 files changed

+68
-3
lines changed

3 files changed

+68
-3
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ repos:
33
rev: stable
44
hooks:
55
- id: black
6-
language_version: python3.7
6+
language_version: python3.8
77
- repo: https://gitlab.com/pycqa/flake8
88
rev: ""
99
hooks:

python_graphql_client/graphql_client.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,9 @@ def execute(
4747
)
4848

4949
result = requests.post(
50-
self.endpoint, json=request_body, headers=self.__request_headers(headers),
50+
self.endpoint,
51+
json=request_body,
52+
headers=self.__request_headers(headers),
5153
)
5254

5355
result.raise_for_status()
@@ -92,13 +94,17 @@ async def subscribe(
9294
)
9395

9496
async with websockets.connect(
95-
self.endpoint, subprotocols=["graphql-ws"]
97+
self.endpoint,
98+
subprotocols=["graphql-ws"],
99+
extra_headers=self.__request_headers(headers),
96100
) as websocket:
97101
await websocket.send(connection_init_message)
98102
await websocket.send(request_message)
99103
async for response_message in websocket:
100104
response_body = json.loads(response_message)
101105
if response_body["type"] == "connection_ack":
102106
logging.info("the server accepted the connection")
107+
elif response_body["type"] == "ka":
108+
logging.info("the server sent a keep alive message")
103109
else:
104110
handle(response_body["payload"])

tests/test_graphql_client.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,3 +248,62 @@ async def test_subscribe(self, mock_connect):
248248
call({"data": {"messageAdded": "two"}}),
249249
]
250250
)
251+
252+
@patch("logging.info")
253+
@patch("websockets.connect")
254+
async def test_does_not_crash_with_keep_alive(self, mock_connect, mock_info):
255+
"""Subsribe a GraphQL subscription."""
256+
mock_websocket = mock_connect.return_value.__aenter__.return_value
257+
mock_websocket.send = AsyncMock()
258+
mock_websocket.__aiter__.return_value = [
259+
'{"type": "ka"}',
260+
]
261+
262+
client = GraphqlClient(endpoint="ws://www.test-api.com/graphql")
263+
query = """
264+
subscription onMessageAdded {
265+
messageAdded
266+
}
267+
"""
268+
269+
await client.subscribe(query=query, handle=MagicMock())
270+
271+
mock_info.assert_has_calls([call("the server sent a keep alive message")])
272+
273+
@patch("websockets.connect")
274+
async def test_headers_passed_to_websocket_connect(self, mock_connect):
275+
"""Subsribe a GraphQL subscription."""
276+
mock_websocket = mock_connect.return_value.__aenter__.return_value
277+
mock_websocket.send = AsyncMock()
278+
mock_websocket.__aiter__.return_value = [
279+
'{"type": "data", "id": "1", "payload": {"data": {"messageAdded": "one"}}}',
280+
]
281+
282+
expected_endpoint = "ws://www.test-api.com/graphql"
283+
client = GraphqlClient(endpoint=expected_endpoint)
284+
285+
query = """
286+
subscription onMessageAdded {
287+
messageAdded
288+
}
289+
"""
290+
291+
mock_handle = MagicMock()
292+
293+
expected_headers = {"some": "header"}
294+
295+
await client.subscribe(
296+
query=query, handle=mock_handle, headers=expected_headers
297+
)
298+
299+
mock_connect.assert_called_with(
300+
expected_endpoint,
301+
subprotocols=["graphql-ws"],
302+
extra_headers=expected_headers,
303+
)
304+
305+
mock_handle.assert_has_calls(
306+
[
307+
call({"data": {"messageAdded": "one"}}),
308+
]
309+
)

0 commit comments

Comments
 (0)