@@ -36,7 +36,11 @@ def _(article):
3636"""
3737from __future__ import annotations
3838
39+ import asyncio
3940import enum
41+ import functools
42+ import inspect
43+ from contextlib import contextmanager
4044from dataclasses import dataclass , field
4145from itertools import count
4246from typing import Any , Callable , Iterable , TypeVar
@@ -66,6 +70,7 @@ class StepFunctionContext:
6670 parser : StepParser
6771 converters : dict [str , Callable [..., Any ]] = field (default_factory = dict )
6872 target_fixture : str | None = None
73+ is_async : bool = False
6974
7075
7176def get_step_fixture_name (step : Step ) -> str :
@@ -86,6 +91,7 @@ def given(
8691 {<param_name>: <converter function>}.
8792 :param target_fixture: Target fixture name to replace by steps definition function.
8893 :param stacklevel: Stack level to find the caller frame. This is used when injecting the step definition fixture.
94+ :param is_async: True if the step is asynchronous. (Default: False)
8995
9096 :return: Decorator function for the step.
9197 """
@@ -105,6 +111,7 @@ def when(
105111 {<param_name>: <converter function>}.
106112 :param target_fixture: Target fixture name to replace by steps definition function.
107113 :param stacklevel: Stack level to find the caller frame. This is used when injecting the step definition fixture.
114+ :param is_async: True if the step is asynchronous. (Default: False)
108115
109116 :return: Decorator function for the step.
110117 """
@@ -124,6 +131,7 @@ def then(
124131 {<param_name>: <converter function>}.
125132 :param target_fixture: Target fixture name to replace by steps definition function.
126133 :param stacklevel: Stack level to find the caller frame. This is used when injecting the step definition fixture.
134+ :param is_async: True if the step is asynchronous. (Default: False)
127135
128136 :return: Decorator function for the step.
129137 """
@@ -144,6 +152,7 @@ def step(
144152 :param converters: Optional step arguments converters mapping.
145153 :param target_fixture: Optional fixture name to replace by step definition.
146154 :param stacklevel: Stack level to find the caller frame. This is used when injecting the step definition fixture.
155+ :param is_async: True if the step is asynchronous. (Default: False)
147156
148157 :return: Decorator function for the step.
149158
@@ -159,6 +168,11 @@ def step(
159168 def decorator (func : TCallable ) -> TCallable :
160169 parser = get_parser (name )
161170
171+ if inspect .isasyncgenfunction (func ):
172+ func = wrap_asyncgen (func )
173+ elif inspect .iscoroutinefunction (func ):
174+ func = wrap_coroutine (func )
175+
162176 context = StepFunctionContext (
163177 type = type_ ,
164178 step_func = func ,
@@ -177,11 +191,79 @@ def step_function_marker() -> StepFunctionContext:
177191 f"{ StepNamePrefix .step_def .value } _{ type_ or '*' } _{ parser .name } " , seen = caller_locals .keys ()
178192 )
179193 caller_locals [fixture_step_name ] = pytest .fixture (name = fixture_step_name )(step_function_marker )
194+
180195 return func
181196
182197 return decorator
183198
184199
200+ def _synchronize (func : Callable , async_wrapper : Callable ) -> Callable :
201+ """Provide a synchronous wrapper for an async function or generator.
202+
203+ :param func: The async function / generator to wrap.
204+ :param async_wrapper: A function taking an event loop and either a
205+ coroutine or an async_generator (the result of calling func)
206+ and returning the result.
207+
208+ :returns: The wrapped async function.
209+ """
210+
211+ @functools .wraps (func )
212+ def _wrapper (* args , ** kwargs ):
213+ try :
214+ loop , created = asyncio .get_running_loop (), False
215+ except RuntimeError :
216+ loop , created = asyncio .get_event_loop_policy ().new_event_loop (), True
217+
218+ try :
219+ yield async_wrapper (loop , func (* args , ** kwargs ))
220+ except :
221+ raise
222+ finally :
223+ if created :
224+ loop .close ()
225+
226+ return _wrapper
227+
228+
229+ def wrap_asyncgen (func : Callable ) -> Callable :
230+ """Wrapper for an async_generator function.
231+
232+ :param func: The function to wrap.
233+
234+ :returns: The wrapped function. The wrapped function will raise ValueError
235+ if the generator yields more than once.
236+ """
237+
238+ def _wrapper (loop : asyncio .events .AbstractEventLoop , async_obj ):
239+ result = loop .run_until_complete (async_obj .__anext__ ())
240+ try :
241+ loop .run_until_complete (async_obj .__anext__ ())
242+ except StopAsyncIteration :
243+ pass
244+ else :
245+ msg = "Async genetator should yield only once."
246+ raise ValueError (msg )
247+
248+ return result
249+
250+ return _synchronize (func , _wrapper )
251+
252+
253+ def wrap_coroutine (func : Callable ) -> Callable :
254+ """Wrapper for a coroutine function.
255+
256+ :param func: The function to wrap.
257+
258+ :returns: The wrapped function.
259+ """
260+
261+ def _wrapper (loop : asyncio .events .AbstractEventLoop , async_obj ):
262+ return loop .run_until_complete (async_obj )
263+
264+ return _synchronize (func , _wrapper )
265+
266+
185267def find_unique_name (name : str , seen : Iterable [str ]) -> str :
186268 """Find unique name among a set of strings.
187269
0 commit comments