Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions src/ragas/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
__all__ = ["Experiment", "experiment", "version_experiment"]

import asyncio
import inspect
import typing as t
from pathlib import Path

Expand Down Expand Up @@ -98,10 +99,33 @@ def __init__(
self.experiment_model = experiment_model
self.default_backend = default_backend
self.name_prefix = name_prefix
# Store function signature for validation
self.signature = inspect.signature(func)
# Preserve function metadata
self.__name__ = getattr(func, "__name__", "experiment_function")
self.__doc__ = getattr(func, "__doc__", None)

def _validate_function_parameters(self, *args, **kwargs) -> None:
    """Check that the wrapped function accepts the supplied arguments.

    The arguments are bound against ``self.signature`` without invoking the
    function, so a mismatch is reported before any work is done.

    Args:
        *args: Positional arguments intended for the wrapped function.
        **kwargs: Keyword arguments intended for the wrapped function.

    Raises:
        ValueError: If the arguments cannot be bound. The message names the
            function, lists its parameters and includes the text of the
            underlying ``TypeError``, which is chained as the cause.
    """
    try:
        # Try to bind the arguments to the function signature
        self.signature.bind(*args, **kwargs)
    except TypeError as e:
        func_name = getattr(self.func, "__name__", "experiment_function")

        param_info = []
        for name, param in self.signature.parameters.items():
            # *args / **kwargs never carry a default, so they must be
            # recognised by kind or they would be reported as required.
            if param.kind is inspect.Parameter.VAR_POSITIONAL:
                param_info.append(f"*{name}")
            elif param.kind is inspect.Parameter.VAR_KEYWORD:
                param_info.append(f"**{name}")
            # Identity check: ``==`` would call the default value's own
            # ``__eq__``, which may raise or return a non-boolean.
            elif param.default is inspect.Parameter.empty:
                param_info.append(f"{name} (required)")
            else:
                param_info.append(f"{name} (optional)")

        expected_params = ", ".join(param_info)

        raise ValueError(
            f"Parameter validation failed for experiment function '{func_name}()'. "
            f"Expected parameters: [{expected_params}]. "
            f"Original error: {str(e)}"
        ) from e

async def __call__(self, *args, **kwargs) -> t.Any:
"""Call the original function."""
if asyncio.iscoroutinefunction(self.func):
Expand All @@ -118,6 +142,12 @@ async def arun(
**kwargs,
) -> "Experiment":
"""Run the experiment against a dataset."""
# Validate function parameters before any setup
# Use the first dataset item as a representative sample for validation
if len(dataset) > 0:
sample_item = next(iter(dataset))
self._validate_function_parameters(sample_item, *args, **kwargs)

# Generate name if not provided
if name is None:
name = memorable_names.generate_unique_name()
Expand Down
175 changes: 175 additions & 0 deletions tests/unit/test_experiment_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
"""Tests for experiment parameter validation functionality."""

import pytest

from ragas.experimental import Dataset, experiment


class TestExperimentValidation:
    """Checks on the argument validation done when an experiment is run."""

    def setup_method(self):
        """Create the two-row in-memory dataset used by the tests."""
        self.dataset = Dataset(name="test_dataset", backend="inmemory")
        for row in (
            {"question": "What is 2+2?", "context": "Math"},
            {"question": "What is 3+3?", "context": "More math"},
        ):
            self.dataset.append(row)

    @pytest.mark.asyncio
    async def test_valid_single_parameter_experiment(self):
        """A function that takes only the row runs over every dataset item."""

        @experiment()
        async def single_param_experiment(row):
            answer = f"Answer to: {row['question']}"
            return {"result": answer, "score": 1.0}

        # No extra arguments are needed, so nothing should be rejected.
        outcome = await single_param_experiment.arun(self.dataset)
        assert len(outcome) == 2

    @pytest.mark.asyncio
    async def test_valid_multi_parameter_experiment(self):
        """Extra parameters supplied by keyword are accepted."""

        @experiment()
        async def multi_param_experiment(row, evaluator_llm, flag=True):
            answer = f"Answer using {evaluator_llm}"
            return {
                "result": answer,
                "flag": flag,
                "score": 1.0,
            }

        # Both the required and the defaulted parameter are given explicitly.
        outcome = await multi_param_experiment.arun(
            self.dataset,
            evaluator_llm="gpt-4",
            flag=False,
        )
        assert len(outcome) == 2

    @pytest.mark.asyncio
    async def test_missing_required_parameter(self):
        """Omitting a required parameter is reported as a ValueError."""

        @experiment()
        async def multi_param_experiment(row, evaluator_llm, flag=True):
            return {"result": "test", "score": 1.0}

        # `evaluator_llm` is never supplied; `abc` is not a parameter at all.
        with pytest.raises(ValueError) as caught:
            await multi_param_experiment.arun(self.dataset, abc=123)

        message = str(caught.value)
        for fragment in (
            "Parameter validation failed",
            "multi_param_experiment()",
            "evaluator_llm (required)",
            "missing a required argument: 'evaluator_llm'",
        ):
            assert fragment in message

    @pytest.mark.asyncio
    async def test_validation_catches_parameter_binding_errors(self):
        """A badly named keyword is rejected; the corrected call succeeds."""

        @experiment()
        async def strict_param_experiment(row, required_str, required_int=10):
            return {"result": "test", "score": 1.0}

        # Part 1: `required_str` is absent and `wrong_param_name` is not a
        # parameter of the function, so the arguments cannot be bound.
        with pytest.raises(ValueError) as caught:
            await strict_param_experiment.arun(
                self.dataset,
                wrong_param_name="value",
                name="test",
            )

        message = str(caught.value)
        for fragment in (
            "Parameter validation failed",
            "strict_param_experiment()",
        ):
            assert fragment in message

        # Part 2: the same function with correctly named arguments.
        outcome = await strict_param_experiment.arun(
            self.dataset,
            required_str="valid_value",
            required_int=20,
            name="test",
        )
        assert len(outcome) == 2

    @pytest.mark.asyncio
    async def test_unexpected_keyword_arguments(self):
        """A keyword the function does not declare is reported as a ValueError."""

        @experiment()
        async def single_param_experiment(row):
            return {"result": "test", "score": 1.0}

        # The function declares only `row`, so the extra keyword cannot bind.
        with pytest.raises(ValueError) as caught:
            await single_param_experiment.arun(
                self.dataset,
                unexpected_kwarg="value",
            )

        message = str(caught.value)
        for fragment in (
            "Parameter validation failed",
            "single_param_experiment()",
        ):
            assert fragment in message

    @pytest.mark.asyncio
    async def test_validation_happens_before_backend_resolution(self):
        """The validation error is raised even when the backend name is bogus."""

        @experiment()
        async def invalid_experiment(row, required_param):
            return {"result": "test", "score": 1.0}

        # The backend name does not exist; the test expects the parameter
        # check to fire first, so the error seen is the validation one.
        with pytest.raises(ValueError) as caught:
            await invalid_experiment.arun(
                self.dataset,
                backend="nonexistent_backend",
            )

        message = str(caught.value)
        for fragment in (
            "Parameter validation failed",
            "required_param (required)",
        ):
            assert fragment in message

    @pytest.mark.asyncio
    async def test_empty_dataset_skips_validation(self):
        """With no rows there is no sample to validate against, so no error."""

        @experiment()
        async def invalid_experiment(row, required_param):
            return {"result": "test", "score": 1.0}

        no_rows = Dataset(name="empty", backend="inmemory")

        # `required_param` is missing, yet the run is expected to complete.
        outcome = await invalid_experiment.arun(no_rows)
        assert len(outcome) == 0

    @pytest.mark.asyncio
    async def test_function_with_kwargs(self):
        """A function declaring **kwargs accepts arbitrary keywords."""

        @experiment()
        async def kwargs_experiment(row, **kwargs):
            label = f"Row: {row['question']}"
            return {"result": label, "kwargs": kwargs, "score": 1.0}

        # Neither keyword is a named parameter; both land in **kwargs.
        outcome = await kwargs_experiment.arun(
            self.dataset,
            extra_param="value",
            another_param=42,
        )
        assert len(outcome) == 2

    @pytest.mark.asyncio
    async def test_function_with_args_and_kwargs(self):
        """A function declaring *args and **kwargs accepts extra keywords."""

        @experiment()
        async def flexible_experiment(row, *args, **kwargs):
            label = f"Row: {row['question']}"
            return {
                "result": label,
                "args": args,
                "kwargs": kwargs,
                "score": 1.0,
            }

        # `extra_kwarg` and `another_kwarg` are not named parameters of the
        # function, so they are collected by its **kwargs.
        outcome = await flexible_experiment.arun(
            self.dataset,
            name="test",
            extra_kwarg="value",
            another_kwarg=42,
        )
        assert len(outcome) == 2
Loading