Skip to content

Commit ed1c4f3

Browse files
authored
Replace ruff with black API for AST rendering (#80)
* introduce logger * Fixed compatibility with pydiverse common 0.3.12 * prepare release 0.5.6 * replace ruff subprocess with black API * Fix cast constraints. They should not depend on attributes.
1 parent 450254c commit ed1c4f3

File tree

10 files changed

+507
-92
lines changed

10 files changed

+507
-92
lines changed

docs/source/changelog.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,13 @@
11
# Changelog
22

3+
## 0.5.6 (2025-08-22)
4+
- improved regex behavior in .str.contains()
5+
- AST repr is printing user friendly syntax (python-like black formatted)
6+
- ColExpr repr is much more intuitive
7+
- fixed handling of const Enum (mutate(enum_col="value"))
8+
- added columns verb to simplify [c.name for c in tbl]
9+
- compatible with pydiverse.common 0.3.12
10+
311
## 0.5.5 (2025-08-20)
412
- don't suffix all joined columns if only the key columns overlap
513
- add query to __str__/__repr__ operators of Table

pixi.lock

Lines changed: 399 additions & 44 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pixi.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,9 @@ python = ">=3.10.18,<3.14"
88
pandas = ">=2.3.0,<3"
99
pyarrow = ">=20.0.0,<22"
1010
polars = ">=1.30.0,<2"
11-
pydiverse-common = ">=0.3.5,<0.4"
11+
pydiverse-common = ">=0.3.12,<0.4"
1212
sqlalchemy = ">=2.0.41,<3"
13+
black = ">=20" # used for printing ASTs
1314

1415
[host-dependencies]
1516
pip = "*"
@@ -24,6 +25,7 @@ typos = "*"
2425
pixi-pycharm = ">=0.0.6"
2526
pytest = ">=7.1.2"
2627
pytest-xdist = ">=2.5.0"
28+
structlog = ">=25.4.0,<26"
2729

2830
[feature.release.dependencies]
2931
hatch = ">=1.12.0"

pyproject.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "pydiverse-transform"
3-
version = "0.5.5"
3+
version = "0.5.6"
44
description = "Pipe based dataframe manipulation library that can also transform data on SQL databases"
55
authors = [
66
{ name = "QuantCo, Inc." },
@@ -30,8 +30,9 @@ dependencies = [
3030
"pandas >=2.3.0,<3",
3131
"pyarrow >=20.0.0,<22",
3232
"polars >=1.30.0,<2",
33-
"pydiverse-common >=0.3.5,<0.4",
33+
"pydiverse-common >=0.3.12,<0.4",
3434
"sqlalchemy >=2.0.41,<3",
35+
"black >=20",
3536
]
3637

3738
[tool.hatch.build.targets.wheel]

src/pydiverse/transform/_internal/backend/polars.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -608,6 +608,7 @@ def _is_in(x, *values):
608608

609609
@impl(ops.str_contains)
610610
def _str_contains(x, y, allow_regex, true_if_regex_unsupported):
611+
_ = true_if_regex_unsupported
611612
return x.str.contains(y, literal=not pl.select(allow_regex).item())
612613

613614
@impl(ops.str_starts_with)

src/pydiverse/transform/_internal/tree/ast.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22
# SPDX-License-Identifier: BSD-3-Clause
33

44
import itertools
5-
import subprocess
65
from collections.abc import Callable, Iterable
76
from uuid import UUID
87

8+
from pydiverse.transform._internal.util.warnings import warn
9+
910

1011
class AstNode:
1112
__slots__ = ["name"]
@@ -99,17 +100,13 @@ def next_nd(root: AstNode, cond: Callable[["AstNode"], bool]):
99100
+ ")"
100101
)
101102
try:
102-
proc = subprocess.run(
103-
["ruff", "format", "-"], # '-' tells Ruff to read from stdin
104-
input=unformatted.encode(),
105-
capture_output=True,
106-
check=True,
107-
)
108-
except subprocess.CalledProcessError as e:
109-
print(unformatted)
110-
print(e.stderr.decode())
111-
raise e
112-
return proc.stdout.decode()
103+
import black
104+
105+
formatted = black.format_str(unformatted, mode=black.Mode(line_length=120))
106+
return formatted
107+
except Exception:
108+
warn("Could not format AST representation with `black`.")
109+
return unformatted
113110

114111
def short_name(self) -> str:
115112
raise NotImplementedError()

src/pydiverse/transform/_internal/tree/col_expr.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
Bool,
2424
Date,
2525
Datetime,
26+
Decimal,
2627
Dtype,
2728
Duration,
2829
Enum,
@@ -2798,6 +2799,10 @@ def is_valid_cast(source, target) -> bool:
27982799
}
27992800

28002801
source = types.without_const(source)
2802+
if isinstance(source, String):
2803+
source = String() # drop max_length
2804+
elif isinstance(source, Decimal):
2805+
source = Decimal() # standardize precision/scale
28012806

28022807
if isinstance(source, String) and isinstance(target, Enum):
28032808
return True

src/pydiverse/transform/_internal/tree/types.py

Lines changed: 48 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,28 @@ def converts_to(source: Dtype, target: Dtype) -> bool:
103103
source = without_const(source)
104104
if isinstance(source, List):
105105
return isinstance(target, List) and converts_to(source.inner, target.inner)
106-
if isinstance(source, Enum):
107-
return target == source or target == String()
106+
if isinstance(source, Enum | String):
107+
return (
108+
target == source
109+
or target == String()
110+
or (
111+
type(target) is String
112+
and source.max_length is not None
113+
and target.max_length > source.max_length
114+
)
115+
)
116+
if isinstance(source, Decimal):
117+
return (
118+
target == source
119+
or target in FLOAT_SUBTYPES
120+
or target == Float()
121+
or target == Decimal()
122+
or (
123+
isinstance(target, Decimal)
124+
and target.scale >= source.scale
125+
and (target.precision - target.scale >= source.precision - source.scale)
126+
)
127+
)
108128
return target in IMPLICIT_CONVS[source]
109129

110130

@@ -117,7 +137,7 @@ def to_python(dtype: Dtype):
117137
return float
118138
elif isinstance(dtype, List):
119139
return list
120-
elif isinstance(dtype, Enum):
140+
elif isinstance(dtype, Enum | String):
121141
return str
122142

123143
return {
@@ -187,13 +207,23 @@ def lca_type(dtypes: list[Dtype]) -> Dtype:
187207

188208
return List(lca_type([dtype.inner for dtype in dtypes]))
189209

190-
if any(isinstance(dtype, Enum) for dtype in dtypes):
210+
if any(isinstance(dtype, Enum | String) for dtype in dtypes):
191211
if all(dtype == dtypes[0] for dtype in dtypes):
192212
return copy.copy(dtypes[0])
193213
if all(isinstance(dtype, Enum | String) for dtype in dtypes):
194214
return String()
195215
raise DataTypeError(f"incompatible types `{', '.join(str(d) for d in dtypes)}`")
196216

217+
if any(isinstance(dtype, Decimal) for dtype in dtypes):
218+
if all(dtype == dtypes[0] for dtype in dtypes):
219+
return copy.copy(dtypes[0])
220+
if all(isinstance(dtype, Decimal) for dtype in dtypes):
221+
precision_diff = max(dtype.precision - dtype.scale for dtype in dtypes)
222+
scale = max(dtype.scale for dtype in dtypes)
223+
precision = precision_diff + scale
224+
return Decimal(precision, scale)
225+
raise DataTypeError(f"incompatible types `{', '.join(str(d) for d in dtypes)}`")
226+
197227
if not (
198228
common_ancestors := functools.reduce(
199229
operator.and_,
@@ -253,8 +283,12 @@ def is_subtype(dtype: Dtype) -> bool:
253283
def implicit_conversions(dtype: Dtype) -> list[Dtype]:
254284
if isinstance(dtype, List):
255285
return [List(inner) for inner in implicit_conversions(dtype.inner)]
256-
if isinstance(dtype, Enum):
257-
return [dtype, String()]
286+
if isinstance(dtype, Enum | String):
287+
return [String()] + ([dtype] if dtype.max_length is not None else [])
288+
if isinstance(dtype, Decimal):
289+
return (
290+
list(FLOAT_SUBTYPES) + [Float()] + ([dtype] if dtype != Decimal() else [])
291+
)
258292
return list(IMPLICIT_CONVS[dtype].keys())
259293

260294

@@ -303,8 +337,14 @@ def conversion_cost(dtype: Dtype, target: Dtype) -> tuple[int, int]:
303337
dtype = without_const(dtype)
304338
if isinstance(dtype, List):
305339
return conversion_cost(dtype.inner, target.inner)
306-
if isinstance(dtype, Enum):
307-
return (0, 0) if dtype == target else (0, 1)
340+
if isinstance(dtype, Enum | String | Decimal):
341+
return (
342+
(0, 0)
343+
if dtype == target
344+
else (0, 1)
345+
if type(dtype) is type(target)
346+
else (0, 2)
347+
)
308348
return IMPLICIT_CONVS[dtype][target]
309349

310350

tests/conftest.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
# Copyright (c) QuantCo and pydiverse contributors 2025-2025
22
# SPDX-License-Identifier: BSD-3-Clause
3+
import logging
34

45
import pytest
56

7+
from pydiverse.common.util.structlog import setup_logging
8+
69
# Setup
710

811

@@ -37,3 +40,6 @@ def pytest_collection_modifyitems(config: pytest.Config, items):
3740
for kw in item.keywords
3841
):
3942
item.add_marker(skip)
43+
44+
45+
setup_logging(log_level=logging.INFO)

tests/test_polars_table.py

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,15 @@
55

66
import polars as pl
77
import pytest
8+
import structlog
89

910
import pydiverse.transform as pdt
1011
from pydiverse.transform._internal.errors import ColumnNotFoundError, DataTypeError
1112
from pydiverse.transform.extended import *
1213
from tests.util import assert_equal
1314

15+
logger = structlog.get_logger(__name__)
16+
1417
df1 = pl.DataFrame(
1518
{
1619
"col1": [1, 2, 3, 4],
@@ -845,6 +848,7 @@ def test_col_html_repr(self, tbl1):
845848

846849
def test_expr_str(self, tbl1):
847850
expr_str = str(tbl1.col1 * 2)
851+
logger.info(expr_str)
848852
assert "failed" not in expr_str
849853

850854
def test_expr_html_repr(self, tbl1):
@@ -907,7 +911,7 @@ def test_verb_ast_repr(self, tbl3, tbl4):
907911
>> alias("tbl 42")
908912
)
909913

910-
(
914+
ast_str = (
911915
intermed
912916
>> left_join(
913917
tbl4
@@ -924,27 +928,23 @@ def test_verb_ast_repr(self, tbl3, tbl4):
924928

925929
series = pl.Series([2**i for i in range(12)])
926930

927-
assert (
928-
"tbl__1729.j"
929-
in (
930-
tbl3
931-
>> mutate(
932-
u=C.col1 + 5,
933-
v=(tbl3.col2.exp() + tbl3.col4) * tbl3.col1,
934-
w=pdt.max(tbl3.col2, tbl3.col1, tbl3.col5.str.len()),
935-
x=pdt.count(),
936-
y=eval_aligned(
937-
(
938-
tbl3
939-
>> mutate(j=42)
940-
>> alias("tbl. 1729")
941-
>> filter(C.col2 > 0)
942-
).j
943-
+ tbl3.col1
944-
),
945-
z=eval_aligned(series),
946-
)
947-
>> select(tbl3.col1, tbl3.col4)
948-
>> left_join(tbl4, on="col1")
949-
)._ast.ast_repr()
950-
)
931+
ast_str = (
932+
tbl3
933+
>> mutate(
934+
u=C.col1 + 5,
935+
v=(tbl3.col2.exp() + tbl3.col4) * tbl3.col1,
936+
w=pdt.max(tbl3.col2, tbl3.col1, tbl3.col5.str.len()),
937+
x=pdt.count(),
938+
y=eval_aligned(
939+
(tbl3 >> mutate(j=42) >> alias("tbl. 1729") >> filter(C.col2 > 0)).j
940+
+ tbl3.col1
941+
),
942+
z=eval_aligned(series),
943+
)
944+
>> select(tbl3.col1, tbl3.col4)
945+
>> left_join(tbl4, on="col1")
946+
)._ast.ast_repr()
947+
948+
logger.info(f"AST:\n{ast_str}")
949+
950+
assert "tbl__1729.j" in ast_str

0 commit comments

Comments
 (0)