From 1f03f1e2fe62211165991a02bd6d7e989ba9ed46 Mon Sep 17 00:00:00 2001
From: Moritz Potthoff
Date: Fri, 10 Oct 2025 21:01:33 +0200
Subject: [PATCH 1/5] init

---
 dataframely/schema.py       | 17 +++++++++++++++++
 tests/schema/test_sample.py | 12 ++++++++++++
 2 files changed, 29 insertions(+)

diff --git a/dataframely/schema.py b/dataframely/schema.py
index b724bfba..c4234907 100644
--- a/dataframely/schema.py
+++ b/dataframely/schema.py
@@ -256,6 +256,23 @@ def sample(
             # frame.
             values = pl.DataFrame()
+
+        # Check that the initial data frame complies with column-only rules
+        specified_columns = {
+            name: col for name, col in cls.columns().items() if name in values.columns
+        }
+        # TODO make this accessible instead of copy-pasting
+        column_rules = {
+            f"{col_name}|{rule_name}": Rule(expr)
+            for col_name, column in specified_columns.items()
+            for rule_name, expr in column.validation_rules(pl.col(col_name)).items()
+        }
+        # TODO do we need to cast here?
+        values_filtered, _ = cls._filter_raw(values, column_rules, cast=False)
+        if len(values_filtered) != len(values):
+            raise ValueError(
+                "The provided overrides do not comply with the column-level rules of the schema."
+            )
         # Prepare expressions for columns that need to be preprocessed during sampling
         # iterations.
         sampling_overrides = cls._sampling_overrides()
diff --git a/tests/schema/test_sample.py b/tests/schema/test_sample.py
index 7da0e332..a64307fb 100644
--- a/tests/schema/test_sample.py
+++ b/tests/schema/test_sample.py
@@ -91,6 +91,10 @@ def _sampling_overrides(cls) -> dict[str, pl.Expr]:
         return {"irrelevant_column": pl.col("irrelevant_column").cast(pl.String())}
 
 
+class MyAdvancedSchema(dy.Schema):
+    a = dy.Float64(min=20.0)
+
+
 # --------------------------------------- TESTS -------------------------------------- #
 
 
@@ -206,3 +210,11 @@ def test_sample_raises_superfluous_column_override() -> None:
         match=r"`_sampling_overrides` for columns that are not in the schema",
     ):
         SchemaWithIrrelevantColumnPreProcessing.sample(100)
+
+
+def test_sample_invalid_override_values_raises() -> None:
+    with pytest.raises(
+        ValueError,
+        match=r"The provided overrides do not comply with the column-level rules of the schema.",
+    ):
+        MyAdvancedSchema.sample(overrides={"a": [0, 1]})

From 7fa26eb028caf0700d45bb7225b33e8c117c22ff Mon Sep 17 00:00:00 2001
From: Moritz Potthoff
Date: Fri, 10 Oct 2025 21:34:50 +0200
Subject: [PATCH 2/5] refactor

---
 dataframely/schema.py | 18 +++++++++++-------
 1 file changed, 11 insertions(+), 7 deletions(-)

diff --git a/dataframely/schema.py b/dataframely/schema.py
index c4234907..35618e21 100644
--- a/dataframely/schema.py
+++ b/dataframely/schema.py
@@ -260,18 +260,22 @@ def sample(
         specified_columns = {
             name: col for name, col in cls.columns().items() if name in values.columns
         }
-        # TODO make this accessible instead of copy-pasting
+        # TODO make this logic accessible and reuse it instead of copy-pasting
         column_rules = {
             f"{col_name}|{rule_name}": Rule(expr)
             for col_name, column in specified_columns.items()
             for rule_name, expr in column.validation_rules(pl.col(col_name)).items()
         }
-        # TODO do we need to cast here?
-        values_filtered, _ = cls._filter_raw(values, column_rules, cast=False)
-        if len(values_filtered) != len(values):
-            raise ValueError(
-                "The provided overrides do not comply with the column-level rules of the schema."
-            )
+        if len(column_rules) > 0:
+            lf_with_eval = values.lazy().pipe(with_evaluation_rules, column_rules)
+            rule_columns = column_rules.keys()
+            df_evaluated = lf_with_eval.with_columns(
+                __final_valid__=pl.all_horizontal(pl.col(rule_columns).fill_null(True))
+            ).collect()
+            if df_evaluated.select(~pl.col("__final_valid__").any()).item():
+                raise ValueError(
+                    "The provided overrides do not comply with the column-level rules of the schema."
+                )
         # Prepare expressions for columns that need to be preprocessed during sampling
         # iterations.
         sampling_overrides = cls._sampling_overrides()

From 9e9c43776f91218d7e5b201b27f29fd4c3f82bde Mon Sep 17 00:00:00 2001
From: Moritz Potthoff
Date: Fri, 10 Oct 2025 21:52:15 +0200
Subject: [PATCH 3/5] refactor

---
 dataframely/_base_schema.py | 15 +++++++-----
 dataframely/schema.py       | 46 ++++++++++++++++++-------------------
 2 files changed, 31 insertions(+), 30 deletions(-)

diff --git a/dataframely/_base_schema.py b/dataframely/_base_schema.py
index 524f60f5..03b5dc92 100644
--- a/dataframely/_base_schema.py
+++ b/dataframely/_base_schema.py
@@ -45,12 +45,7 @@ def _build_rules(
         rules["primary_key"] = Rule(~pl.struct(primary_keys).is_duplicated())
 
     # Add column-specific rules
-    column_rules = {
-        f"{col_name}|{rule_name}": Rule(expr)
-        for col_name, column in columns.items()
-        for rule_name, expr in column.validation_rules(pl.col(col_name)).items()
-    }
-    rules.update(column_rules)
+    rules.update(_build_column_rules(columns))
 
     # Add casting rules if requested. Here, we can simply check whether the nullability
     # property of a column changes due to lenient dtype casting. Whenever casting fails,
@@ -70,6 +65,14 @@ def _build_rules(
     return rules
 
 
+def _build_column_rules(columns: dict[str, Column]) -> dict[str, Rule]:
+    return {
+        f"{col_name}|{rule_name}": Rule(expr)
+        for col_name, column in columns.items()
+        for rule_name, expr in column.validation_rules(pl.col(col_name)).items()
+    }
+
+
 def _primary_keys(columns: dict[str, Column]) -> list[str]:
     return list(k for k, col in columns.items() if col.primary_key)
 
diff --git a/dataframely/schema.py b/dataframely/schema.py
index 35618e21..ddcc3adc 100644
--- a/dataframely/schema.py
+++ b/dataframely/schema.py
@@ -19,7 +19,7 @@
 
 from dataframely._compat import deltalake
 
-from ._base_schema import ORIGINAL_COLUMN_PREFIX, BaseSchema
+from ._base_schema import ORIGINAL_COLUMN_PREFIX, BaseSchema, _build_column_rules
 from ._compat import pa, sa
 from ._rule import Rule, rule_from_dict, with_evaluation_rules
 from ._serialization import (
@@ -260,21 +260,13 @@ def sample(
         specified_columns = {
             name: col for name, col in cls.columns().items() if name in values.columns
         }
-        # TODO make this logic accessible and reuse it instead of copy-pasting
-        column_rules = {
-            f"{col_name}|{rule_name}": Rule(expr)
-            for col_name, column in specified_columns.items()
-            for rule_name, expr in column.validation_rules(pl.col(col_name)).items()
-        }
+        column_rules = _build_column_rules(specified_columns)
         if len(column_rules) > 0:
-            lf_with_eval = values.lazy().pipe(with_evaluation_rules, column_rules)
-            rule_columns = column_rules.keys()
-            df_evaluated = lf_with_eval.with_columns(
-                __final_valid__=pl.all_horizontal(pl.col(rule_columns).fill_null(True))
-            ).collect()
-            if df_evaluated.select(~pl.col("__final_valid__").any()).item():
+            evaluated_rules = cls._evaluate_rules(values.lazy(), column_rules)
+            if evaluated_rules.select(~pl.col("__final_valid__").any()).item():
                 raise ValueError(
-                    "The provided overrides do not comply with the column-level rules of the schema."
+                    "The provided overrides do not comply with the column-level "
+                    "rules of the schema."
                 )
         # Prepare expressions for columns that need to be preprocessed during sampling
         # iterations.
@@ -555,18 +547,11 @@ def _filter_raw(
 
         # Then, we filter the data frame
         if len(rules) > 0:
-            lf_with_eval = lf.pipe(with_evaluation_rules, rules)
-
-            # At this point, `lf_with_eval` contains the following:
-            # - All relevant columns of the original data frame
+            # Evaluate the rules on the data frame
             # - If `cast` is set to `True`, the columns with their original names are
             #   already cast and the original values are available in the columns
             #   prefixed with `ORIGINAL_COLUMN_PREFIX`.
-            # - One boolean column for each rule in `rules`
-            rule_columns = rules.keys()
-            df_evaluated = lf_with_eval.with_columns(
-                __final_valid__=pl.all_horizontal(pl.col(rule_columns).fill_null(True))
-            ).collect()
+            df_evaluated = cls._evaluate_rules(lf, rules)
 
             # For the output, partition `lf_evaluated` into the returned data frame `lf`
             # and the invalid data frame
@@ -575,7 +560,7 @@ def _filter_raw(
                 .drop(
                     "__final_valid__",
                     cs.starts_with(ORIGINAL_COLUMN_PREFIX),
-                    *rule_columns,
+                    *rules.keys(),
                 )
             )
         else:
@@ -587,6 +572,19 @@
             df_evaluated,
         )
 
+    @staticmethod
+    def _evaluate_rules(lf: pl.LazyFrame, rules: dict[str, Rule]) -> pl.DataFrame:
+        lf_with_eval = lf.pipe(with_evaluation_rules, rules)
+
+        # At this point, `lf_with_eval` contains the following:
+        # - All relevant columns of the original data frame
+        # - One boolean column for each rule in `rules`
+        rule_columns = rules.keys()
+        df_evaluated = lf_with_eval.with_columns(
+            __final_valid__=pl.all_horizontal(pl.col(rule_columns).fill_null(True))
+        ).collect()
+        return df_evaluated
+
     # ------------------------------------ CASTING ----------------------------------- #
 
     @overload

From d745ab758550fc5767822289d091b6065b031bfd Mon Sep 17 00:00:00 2001
From: Moritz Potthoff
Date: Fri, 10 Oct 2025 22:01:15 +0200
Subject: [PATCH 4/5] docs

---
 dataframely/schema.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/dataframely/schema.py b/dataframely/schema.py
index ddcc3adc..d04a0498 100644
--- a/dataframely/schema.py
+++ b/dataframely/schema.py
@@ -207,6 +207,8 @@ def sample(
         Raises:
             ValueError: If ``num_rows`` is not equal to the length of the values in
                 ``overrides``.
+            ValueError: If the values provided through ``overrides`` do not comply with
+                column-level validation rules of the schema.
             ValueError: If no valid data frame can be found in the configured maximum
                 number of iterations.
 
From 321c64fbbe0748f8dc57578e15bc8770b3532e71 Mon Sep 17 00:00:00 2001
From: Moritz Potthoff
Date: Fri, 10 Oct 2025 22:12:46 +0200
Subject: [PATCH 5/5] More tests

---
 dataframely/schema.py       |  5 +++--
 tests/schema/test_sample.py | 20 +++++++++++++++++---
 2 files changed, 20 insertions(+), 5 deletions(-)

diff --git a/dataframely/schema.py b/dataframely/schema.py
index d04a0498..98a16e85 100644
--- a/dataframely/schema.py
+++ b/dataframely/schema.py
@@ -39,7 +39,7 @@
 from .columns import Column, column_from_dict
 from .config import Config
 from .exc import RuleValidationError, ValidationError, ValidationRequiredError
-from .failure import FailureInfo
+from .failure import FailureInfo, _compute_counts
 from .random import Generator
 
 if sys.version_info >= (3, 11):
@@ -266,9 +266,10 @@ def sample(
         if len(column_rules) > 0:
             evaluated_rules = cls._evaluate_rules(values.lazy(), column_rules)
             if evaluated_rules.select(~pl.col("__final_valid__").any()).item():
+                counts = _compute_counts(evaluated_rules, list(column_rules.keys()))
                 raise ValueError(
                     "The provided overrides do not comply with the column-level "
-                    "rules of the schema."
+                    f"rules of the schema. Rule violation counts: {counts}"
                 )
         # Prepare expressions for columns that need to be preprocessed during sampling
         # iterations.
diff --git a/tests/schema/test_sample.py b/tests/schema/test_sample.py
index a64307fb..0bb837d5 100644
--- a/tests/schema/test_sample.py
+++ b/tests/schema/test_sample.py
@@ -1,5 +1,6 @@
 # Copyright (c) QuantCo 2025-2025
 # SPDX-License-Identifier: BSD-3-Clause
+from typing import Any
 
 import numpy as np
 import polars as pl
@@ -93,6 +94,7 @@ def _sampling_overrides(cls) -> dict[str, pl.Expr]:
 
 class MyAdvancedSchema(dy.Schema):
     a = dy.Float64(min=20.0)
+    b = dy.String(regex=r"abc*")
 
 
 # --------------------------------------- TESTS -------------------------------------- #
@@ -212,9 +214,21 @@ def test_sample_raises_superfluous_column_override() -> None:
         SchemaWithIrrelevantColumnPreProcessing.sample(100)
 
 
-def test_sample_invalid_override_values_raises() -> None:
+@pytest.mark.parametrize(
+    "overrides,expected_violations",
+    [
+        ({"a": [0, 1], "b": ["abcd", "abc"]}, r"\{'a\|min': 2\}"),
+        ({"a": [20], "b": ["invalid"]}, r"\{'b\|regex': 1\}"),
+    ],
+)
+def test_sample_invalid_override_values_raises(
+    overrides: dict[str, Any], expected_violations: str
+) -> None:
     with pytest.raises(
         ValueError,
-        match=r"The provided overrides do not comply with the column-level rules of the schema.",
+        match=(
+            r"The provided overrides do not comply with the column-level "
+            r"rules of the schema. Rule violation counts: " + expected_violations
+        ),
     ):
-        MyAdvancedSchema.sample(overrides={"a": [0, 1]})
+        MyAdvancedSchema.sample(overrides=overrides)
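
Taken together, the series makes Schema.sample fail fast when user-supplied overrides violate column-level rules, and patch 5 enriches the error message with per-rule violation counts. Below is a minimal usage sketch of the new behaviour, not part of the patches themselves; it assumes only the dataframely API already exercised in the tests above (dy.Schema, dy.Float64, dy.String, Schema.sample with an overrides mapping), and the example message is indicative of the _compute_counts output referenced in patch 5.

# Sketch only: mirrors the schema defined in tests/schema/test_sample.py of this series.
import dataframely as dy


class MyAdvancedSchema(dy.Schema):
    a = dy.Float64(min=20.0)      # column-level rule: a >= 20.0
    b = dy.String(regex=r"abc*")  # column-level rule: b matches "abc*"


try:
    # Both "a" values violate the min=20.0 rule, so sampling raises immediately
    # instead of looping until the configured iteration limit is exhausted.
    MyAdvancedSchema.sample(overrides={"a": [0.0, 1.0], "b": ["abc", "abcc"]})
except ValueError as exc:
    # Expected shape of the message after patch 5, e.g.:
    # "The provided overrides do not comply with the column-level rules of the
    #  schema. Rule violation counts: {'a|min': 2}"
    print(exc)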