|
5 | 5 | import inspect
|
6 | 6 | import logging
|
7 | 7 | import contextvars
|
8 |
| -from typing import Any, Dict, List, Tuple, Optional, Awaitable, Generator |
| 8 | +from typing import Any, Dict, List, Tuple, Optional, Awaitable, Generator, AsyncIterator |
9 | 9 | from functools import wraps
|
10 | 10 | from contextlib import contextmanager
|
11 | 11 |
|
@@ -287,6 +287,15 @@ def decorator(func):
|
287 | 287 | async def wrapper(*func_args, **func_kwargs):
|
288 | 288 | if step_kwargs.get("name") is None:
|
289 | 289 | step_kwargs["name"] = func.__name__
|
| 290 | + |
| 291 | + # Check if function is an async generator |
| 292 | + if inspect.isasyncgenfunction(func): |
| 293 | + return handle_async_generator( |
| 294 | + func, func_args, func_kwargs, step_args, step_kwargs, |
| 295 | + inference_pipeline_id, context_kwarg, func_signature |
| 296 | + ) |
| 297 | + |
| 298 | + # Handle regular async functions |
290 | 299 | with create_step(
|
291 | 300 | *step_args, inference_pipeline_id=inference_pipeline_id, **step_kwargs
|
292 | 301 | ) as step:
|
@@ -327,6 +336,69 @@ async def wrapper(*func_args, **func_kwargs):
|
327 | 336 | raise exception
|
328 | 337 | return output
|
329 | 338 |
|
| 339 | + async def handle_async_generator( |
| 340 | + func, func_args, func_kwargs, step_args, step_kwargs, |
| 341 | + inference_pipeline_id, context_kwarg, func_signature |
| 342 | + ) -> AsyncIterator[Any]: |
| 343 | + """Handle async generator functions properly.""" |
| 344 | + with create_step( |
| 345 | + *step_args, inference_pipeline_id=inference_pipeline_id, **step_kwargs |
| 346 | + ) as step: |
| 347 | + collected_output = [] |
| 348 | + exception = None |
| 349 | + |
| 350 | + # Prepare inputs |
| 351 | + bound = func_signature.bind(*func_args, **func_kwargs) |
| 352 | + bound.apply_defaults() |
| 353 | + inputs = dict(bound.arguments) |
| 354 | + inputs.pop("self", None) |
| 355 | + inputs.pop("cls", None) |
| 356 | + |
| 357 | + if context_kwarg: |
| 358 | + if context_kwarg in inputs: |
| 359 | + log_context(inputs.get(context_kwarg)) |
| 360 | + else: |
| 361 | + logger.warning( |
| 362 | + "Context kwarg `%s` not found in inputs of the " |
| 363 | + "current function.", |
| 364 | + context_kwarg, |
| 365 | + ) |
| 366 | + |
| 367 | + try: |
| 368 | + # Get the async generator |
| 369 | + async_gen = func(*func_args, **func_kwargs) |
| 370 | + |
| 371 | + # Consume and collect all values while yielding them |
| 372 | + async for value in async_gen: |
| 373 | + collected_output.append(value) |
| 374 | + yield value # Maintain streaming behavior |
| 375 | + |
| 376 | + except Exception as exc: |
| 377 | + step.log(metadata={"Exceptions": str(exc)}) |
| 378 | + exception = exc |
| 379 | + raise |
| 380 | + finally: |
| 381 | + # Log complete output after streaming finishes |
| 382 | + end_time = time.time() |
| 383 | + latency = (end_time - step.start_time) * 1000 # in ms |
| 384 | + |
| 385 | + # Convert collected output to string representation |
| 386 | + if collected_output: |
| 387 | + # Handle different types of output |
| 388 | + if all(isinstance(item, str) for item in collected_output): |
| 389 | + complete_output = "".join(collected_output) |
| 390 | + else: |
| 391 | + complete_output = "".join(str(item) for item in collected_output) |
| 392 | + else: |
| 393 | + complete_output = "" |
| 394 | + |
| 395 | + step.log( |
| 396 | + inputs=inputs, |
| 397 | + output=complete_output, # Actual content, not generator object |
| 398 | + end_time=end_time, |
| 399 | + latency=latency, # Correct timing for full streaming |
| 400 | + ) |
| 401 | + |
330 | 402 | return wrapper
|
331 | 403 |
|
332 | 404 | return decorator
|
|
0 commit comments