Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions dataframely/_base_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,7 @@ def _build_rules(
rules["primary_key"] = Rule(~pl.struct(primary_keys).is_duplicated())

# Add column-specific rules
column_rules = {
f"{col_name}|{rule_name}": Rule(expr)
for col_name, column in columns.items()
for rule_name, expr in column.validation_rules(pl.col(col_name)).items()
}
rules.update(column_rules)
rules.update(_build_column_rules(columns))

# Add casting rules if requested. Here, we can simply check whether the nullability
# property of a column changes due to lenient dtype casting. Whenever casting fails,
Expand All @@ -70,6 +65,14 @@ def _build_rules(
return rules


def _build_column_rules(columns: dict[str, Column]) -> dict[str, Rule]:
    """Collect the validation rules of each column, namespaced per column.

    Args:
        columns: Mapping from column name to column definition.

    Returns:
        A mapping from ``"<column>|<rule>"`` to the :class:`Rule` wrapping the
        column's validation expression for that rule.
    """
    result: dict[str, Rule] = {}
    for name, column in columns.items():
        for rule_name, expr in column.validation_rules(pl.col(name)).items():
            result[f"{name}|{rule_name}"] = Rule(expr)
    return result


def _primary_keys(columns: dict[str, Column]) -> list[str]:
return list(k for k, col in columns.items() if col.primary_key)

Expand Down
46 changes: 34 additions & 12 deletions dataframely/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from dataframely._compat import deltalake

from ._base_schema import ORIGINAL_COLUMN_PREFIX, BaseSchema
from ._base_schema import ORIGINAL_COLUMN_PREFIX, BaseSchema, _build_column_rules
from ._compat import pa, sa
from ._rule import Rule, rule_from_dict, with_evaluation_rules
from ._serialization import (
Expand All @@ -39,7 +39,7 @@
from .columns import Column, column_from_dict
from .config import Config
from .exc import RuleValidationError, ValidationError, ValidationRequiredError
from .failure import FailureInfo
from .failure import FailureInfo, _compute_counts
from .random import Generator

if sys.version_info >= (3, 11):
Expand Down Expand Up @@ -209,6 +209,8 @@ def sample(
``overrides``.
ValueError: If ``overrides`` are specified as a sequence of mappings and
the mappings do not provide the same keys.
ValueError: If the values provided through `overrides` do not comply with
column-level validation rules of the schema.
ValueError: If no valid data frame can be found in the configured maximum
number of iterations.

Expand Down Expand Up @@ -273,6 +275,20 @@ def sample(
# frame.
values = pl.DataFrame()

# Check that the initial data frame complies with column-only rules
specified_columns = {
name: col for name, col in cls.columns().items() if name in values.columns
}
column_rules = _build_column_rules(specified_columns)
if len(column_rules) > 0:
evaluated_rules = cls._evaluate_rules(values.lazy(), column_rules)
if evaluated_rules.select(~pl.col("__final_valid__").any()).item():
counts = _compute_counts(evaluated_rules, list(column_rules.keys()))
raise ValueError(
"The provided overrides do not comply with the column-level "
f"rules of the schema. Rule violation counts: {counts}"
)

# Prepare expressions for columns that need to be preprocessed during sampling
# iterations.
sampling_overrides = cls._sampling_overrides()
Expand Down Expand Up @@ -551,18 +567,11 @@ def _filter_raw(

# Then, we filter the data frame
if len(rules) > 0:
lf_with_eval = lf.pipe(with_evaluation_rules, rules)

# At this point, `lf_with_eval` contains the following:
# - All relevant columns of the original data frame
# Evaluate the rules on the data frame
# - If `cast` is set to `True`, the columns with their original names are
# already cast and the original values are available in the columns
# prefixed with `ORIGINAL_COLUMN_PREFIX`.
# - One boolean column for each rule in `rules`
rule_columns = rules.keys()
df_evaluated = lf_with_eval.with_columns(
__final_valid__=pl.all_horizontal(pl.col(rule_columns).fill_null(True))
).collect()
df_evaluated = cls._evaluate_rules(lf, rules)

# For the output, partition `lf_evaluated` into the returned data frame `lf`
# and the invalid data frame
Expand All @@ -572,7 +581,7 @@ def _filter_raw(
.drop(
"__final_valid__",
cs.starts_with(ORIGINAL_COLUMN_PREFIX),
*rule_columns,
*rules.keys(),
)
)
else:
Expand All @@ -583,6 +592,19 @@ def _filter_raw(
df_evaluated,
)

@staticmethod
def _evaluate_rules(lf: pl.LazyFrame, rules: dict[str, Rule]) -> pl.DataFrame:
    """Evaluate validation rules on a lazy frame and collect the result.

    Args:
        lf: The lazy frame to evaluate the rules on.
        rules: Mapping from rule name to rule. One boolean column per rule is
            added to the frame.

    Returns:
        The collected data frame containing all columns of ``lf``, one boolean
        column per rule, and a ``__final_valid__`` column that is ``True`` for
        rows passing every rule.
    """
    evaluated = with_evaluation_rules(lf, rules)
    # ``fill_null(True)`` treats a null rule result as passing before the
    # horizontal AND across all rule columns.
    all_rules_pass = pl.all_horizontal(pl.col(rules.keys()).fill_null(True))
    return evaluated.with_columns(__final_valid__=all_rules_pass).collect()

# ------------------------------------ CASTING ----------------------------------- #

@overload
Expand Down
26 changes: 26 additions & 0 deletions tests/schema/test_sample.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) QuantCo 2025-2025
# SPDX-License-Identifier: BSD-3-Clause
from typing import Any

import numpy as np
import polars as pl
Expand Down Expand Up @@ -91,6 +92,11 @@ def _sampling_overrides(cls) -> dict[str, pl.Expr]:
return {"irrelevant_column": pl.col("irrelevant_column").cast(pl.String())}


# Schema with column-level validation rules (a numeric lower bound and a
# string regex), used to exercise validation of `overrides` in `sample`.
class MyAdvancedSchema(dy.Schema):
    # Float column with a minimum value of 20.0.
    a = dy.Float64(min=20.0)
    # String column that must match the regex ``abc*``.
    b = dy.String(regex=r"abc*")


# --------------------------------------- TESTS -------------------------------------- #


Expand Down Expand Up @@ -223,3 +229,23 @@ def test_sample_with_inconsistent_overrides_keys_raises() -> None:
{"b": 2},
]
)


@pytest.mark.parametrize(
    "overrides,expected_violations",
    [
        # Both values of "a" violate ``min=20.0`` -> two "a|min" violations.
        ({"a": [0, 1], "b": ["abcd", "abc"]}, r"\{'a\|min': 1\}".replace("1", "2")),
        # "b" does not match ``abc*`` -> one "b|regex" violation.
        ({"a": [20], "b": ["invalid"]}, r"\{'b\|regex': 1\}"),
    ],
)
def test_sample_invalid_override_values_raises(
    overrides: dict[str, Any], expected_violations: str
) -> None:
    """Overrides violating column-level rules must raise with exact counts.

    NOTE: ``pytest.raises(match=...)`` interprets the string as a regex, so the
    ``|`` in the rule names ("a|min", "b|regex") and the literal ``.`` must be
    escaped — an unescaped ``|`` turns the pattern into an alternation that
    matches almost anything and pins nothing.
    """
    with pytest.raises(
        ValueError,
        match=(
            r"The provided overrides do not comply with the column-level "
            r"rules of the schema\. Rule violation counts: " + expected_violations
        ),
    ):
        MyAdvancedSchema.sample(overrides=overrides)
Loading