Skip to content

Commit

Permalink
Feat: add support for na_values and keep_default_na in csv_settings (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
georgesittas authored Feb 21, 2025
1 parent d566301 commit bb43e5f
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 3 deletions.
3 changes: 3 additions & 0 deletions docs/reference/model_configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -274,5 +274,8 @@ Options specified within the `kind` property's `csv_settings` property (override
| `skipinitialspace` | Skip spaces after delimiter. More information at the [Pandas documentation](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html). | bool | N |
| `lineterminator` | Character used to denote a line break. More information at the [Pandas documentation](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html). | str | N |
| `encoding` | Encoding to use for UTF when reading/writing (ex. 'utf-8'). More information at the [Pandas documentation](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html). | str | N |
| `na_values` | An array of values that should be recognized as NA/NaN. In order to specify such an array per column, a mapping in the form of `(col1 = (v1, v2, ...), col2 = ...)` can be passed instead. These values can be integers, strings, booleans or NULL, and they are converted to their corresponding Python values. More information at the [Pandas documentation](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html). | array[value] \| array[array[key = value]] | N |
| `keep_default_na` | Whether or not to include the default NaN values when parsing the data. More information at the [Pandas documentation](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html). | bool | N |


Python model kind `name` enum value: [ModelKindName.SEED](https://sqlmesh.readthedocs.io/en/stable/_readthedocs/html/sqlmesh/core/model/kind.html#ModelKindName)
3 changes: 2 additions & 1 deletion sqlmesh/core/model/kind.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,9 +640,10 @@ def to_expression(

@property
def data_hash_values(self) -> t.List[t.Optional[str]]:
csv_setting_values = (self.csv_settings or CsvSettings()).dict().values()
return [
*super().data_hash_values,
*(self.csv_settings or CsvSettings()).dict().values(),
*(v if isinstance(v, (str, type(None))) else str(v) for v in csv_setting_values),
]

@property
Expand Down
41 changes: 40 additions & 1 deletion sqlmesh/core/model/seed.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import logging
import typing as t
import zlib
from io import StringIO
Expand All @@ -8,12 +9,18 @@
import pandas as pd
from sqlglot import exp
from sqlglot.dialects.dialect import UNESCAPED_SEQUENCES
from sqlglot.helper import seq_get
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers

from sqlmesh.core.model.common import parse_bool
from sqlmesh.utils.pandas import columns_to_types_from_df
from sqlmesh.utils.pydantic import PydanticModel, field_validator

logger = logging.getLogger(__name__)

NaHashables = t.List[t.Union[int, str, bool, t.Literal[None]]]
NaValues = t.Union[NaHashables, t.Dict[str, NaHashables]]


class CsvSettings(PydanticModel):
"""Settings for CSV seeds."""
Expand All @@ -25,8 +32,10 @@ class CsvSettings(PydanticModel):
skipinitialspace: t.Optional[bool] = None
lineterminator: t.Optional[str] = None
encoding: t.Optional[str] = None
na_values: t.Optional[NaValues] = None
keep_default_na: t.Optional[bool] = None

@field_validator("doublequote", "skipinitialspace", mode="before")
@field_validator("doublequote", "skipinitialspace", "keep_default_na", mode="before")
@classmethod
def _bool_validator(cls, v: t.Any) -> t.Optional[bool]:
if v is None:
Expand All @@ -46,6 +55,36 @@ def _str_validator(cls, v: t.Any) -> t.Optional[str]:
v = v.this
return UNESCAPED_SEQUENCES.get(v, v)

@field_validator("na_values", mode="before")
@classmethod
def _na_values_validator(cls, v: t.Any) -> t.Optional[NaValues]:
if v is None or not isinstance(v, exp.Expression):
return v

try:
if isinstance(v, exp.Paren) or not isinstance(v, (exp.Tuple, exp.Array)):
v = exp.Tuple(expressions=[v.unnest()])

expressions = v.expressions
if isinstance(seq_get(expressions, 0), (exp.PropertyEQ, exp.EQ)):
return {
e.left.name: [
rhs_val.to_py()
for rhs_val in (
[e.right.unnest()]
if isinstance(e.right, exp.Paren)
else e.right.expressions
)
]
for e in expressions
}

return [e.to_py() for e in expressions]
except ValueError as e:
logger.warning(f"Failed to coerce na_values '{v}', proceeding with defaults. {str(e)}")

return None


class CsvSeedReader:
def __init__(self, content: str, dialect: str, settings: CsvSettings):
Expand Down
66 changes: 65 additions & 1 deletion tests/core/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -897,6 +897,8 @@ def test_seed_csv_settings():
csv_settings (
quotechar = '''',
escapechar = '\\',
keep_default_na = false,
na_values = (id = [1, '2', false, null], alias = ('foo'))
),
),
columns (
Expand All @@ -910,7 +912,39 @@ def test_seed_csv_settings():
model = load_sql_based_model(expressions, path=Path("./examples/sushi/models/test_model.sql"))

assert isinstance(model.kind, SeedKind)
assert model.kind.csv_settings == CsvSettings(quotechar="'", escapechar="\\")
assert model.kind.csv_settings == CsvSettings(
quotechar="'",
escapechar="\\",
na_values={"id": [1, "2", False, None], "alias": ["foo"]},
keep_default_na=False,
)
assert model.kind.data_hash_values == [
"SEED",
"'",
"\\",
"{'id': [1, '2', False, None], 'alias': ['foo']}",
"False",
]

expressions = d.parse(
"""
MODEL (
name db.seed,
kind SEED (
path '../seeds/waiter_names.csv',
csv_settings (
na_values = ('#N/A', 'other')
),
),
);
"""
)

model = load_sql_based_model(expressions, path=Path("./examples/sushi/models/test_model.sql"))

assert isinstance(model.kind, SeedKind)
assert model.kind.csv_settings == CsvSettings(na_values=["#N/A", "other"])
assert model.kind.data_hash_values == ["SEED", "['#N/A', 'other']"]


def test_seed_marker_substitution():
Expand Down Expand Up @@ -7755,3 +7789,33 @@ def get_current_date(evaluator):
FROM "discount_promotion_dates" AS "discount_promotion_dates"
""",
)


def test_seed_dont_coerce_na_into_null(tmp_path):
model_csv_path = (tmp_path / "model.csv").absolute()

with open(model_csv_path, "w", encoding="utf-8") as fd:
fd.write("code\nNA")

expressions = d.parse(
f"""
MODEL (
name db.seed,
kind SEED (
path '{str(model_csv_path)}',
csv_settings (
-- override NaN handling, such that no value can be coerced into NaN
keep_default_na = false,
na_values = (),
),
),
);
"""
)

model = load_sql_based_model(expressions, path=Path("./examples/sushi/models/test_model.sql"))

assert isinstance(model.kind, SeedKind)
assert model.seed is not None
assert len(model.seed.content) > 0
assert next(model.render(context=None)).to_dict() == {"code": {0: "NA"}}

0 comments on commit bb43e5f

Please sign in to comment.