diff --git a/py/src/braintrust/functions/invoke.py b/py/src/braintrust/functions/invoke.py index 9bd0a41a8..e473870ee 100644 --- a/py/src/braintrust/functions/invoke.py +++ b/py/src/braintrust/functions/invoke.py @@ -211,13 +211,20 @@ def init_function(project_name: str, slug: str, version: Optional[str] = None): :return: A function that can be used as a task or scorer. """ - def f(*args: Any, **kwargs: Any) -> Any: - if len(args) > 0: - # Task. - return invoke(project_name=project_name, slug=slug, version=version, input=args[0]) + def f(input, hooks=None): + # When used as a task, hooks may be provided with metadata + # When used as a scorer, all args come via kwargs in the input dict + if hooks is not None and not isinstance(hooks, str): + # Task mode with hooks object + metadata = hooks.metadata if hasattr(hooks, 'metadata') else None + return invoke(project_name=project_name, slug=slug, version=version, input=input, metadata=metadata) + elif isinstance(input, dict) and ('output' in input or 'expected' in input or 'metadata' in input): + # Scorer mode - input is a dict with scorer args + metadata = input.get('metadata') + return invoke(project_name=project_name, slug=slug, version=version, input=input, metadata=metadata) else: - # Scorer. - return invoke(project_name=project_name, slug=slug, version=version, input=kwargs) + # Task mode without hooks (backward compatibility) or hooks is not actually hooks + return invoke(project_name=project_name, slug=slug, version=version, input=input) f.__name__ = f"init_function-{project_name}-{slug}-{version or 'latest'}" return f