Skip to content

Commit 47c8eb6

Browse files
committed
Add Dict Comprehension support
1 parent 47e06d5 commit 47c8eb6

File tree

7 files changed

+239
-0
lines changed

7 files changed

+239
-0
lines changed

docs/tutorials/dict_comprehension.ipynb

+126
Large diffs are not rendered by default.

libs/astx-transpilers/src/astx_transpilers/python_string.py

+13
Original file line numberDiff line numberDiff line change
@@ -699,3 +699,16 @@ def visit(self, node: astx.LiteralDict) -> str:
699699
for key, value in node.elements.items()
700700
)
701701
return f"{{{items_code}}}"
702+
703+
@dispatch # type: ignore[no-redef]
704+
def visit(self, node: astx.DictComprehension) -> str:
705+
"""Handle DictComprehension nodes."""
706+
ret_str = (
707+
f"{{{self.visit(node.key)}: {self.visit(node.value)} "
708+
f" for {self.visit(node.iterable)}"
709+
)
710+
if node.conditions is not None:
711+
for cond in node.conditions:
712+
ret_str += f" if {self.visit(cond)}"
713+
ret_str += "}"
714+
return ret_str

libs/astx-transpilers/tests/test_python.py

+18
Original file line numberDiff line numberDiff line change
@@ -1424,3 +1424,21 @@ def test_transpiler_literal_dict() -> None:
14241424
1: 10,
14251425
2: 20,
14261426
}, f"Expected '{expected_code}', but got '{generated_code}'"
1427+
1428+
1429+
def test_transpiler_dict_comprehension() -> None:
1430+
"""Test astx.DictComprehension."""
1431+
dict_comp = astx.DictComprehension(
1432+
key=astx.LiteralString("x"),
1433+
value=astx.LiteralString("x*x"),
1434+
iterable=astx.LiteralString("range(10)"),
1435+
conditions=[
1436+
astx.LiteralString("x > 5"),
1437+
astx.LiteralString("x < 7"),
1438+
],
1439+
)
1440+
generated_code = transpiler.visit(dict_comp)
1441+
expected_code = "{x: x*x for x in range(10) if x > 5 and x < 7}"
1442+
assert generated_code == expected_code, (
1443+
f"Expected '{expected_code}', but got '{generated_code}'"
1444+
)

libs/astx/src/astx/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
AsyncForRangeLoopExpr,
6464
AsyncForRangeLoopStmt,
6565
CaseStmt,
66+
DictComprehension,
6667
ForCountLoopExpr,
6768
ForCountLoopStmt,
6869
ForRangeLoopExpr,
@@ -230,6 +231,7 @@ def get_version() -> str:
230231
"Date",
231232
"DateTime",
232233
"DeleteStmt",
234+
"DictComprehension",
233235
"DictType",
234236
"EnumDeclStmt",
235237
"ExceptionHandlerStmt",

libs/astx/src/astx/base.py

+1
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ class ASTKind(Enum):
134134
WithStmtKind = -512
135135
AsyncRangeLoopStmtKind = -513
136136
AsyncRangeLoopExprKind = -514
137+
DictComprehensionKind = -515
137138

138139
# data types
139140
NullDTKind = -600

libs/astx/src/astx/flows.py

+60
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818
StatementType,
1919
)
2020
from astx.blocks import Block
21+
from astx.callables import (
22+
Comprehension,
23+
)
2124
from astx.tools.typing import typechecked
2225
from astx.variables import InlineVariableDeclaration
2326

@@ -691,3 +694,60 @@ def get_struct(self, simplified: bool = False) -> ReprStruct:
691694
key = f"GOTO-STMT[{self.label.value}]"
692695
value: DictDataTypesStruct = {}
693696
return self._prepare_struct(key, value, simplified)
697+
698+
699+
@public
700+
@typechecked
701+
class DictComprehension(Comprehension):
702+
"""AST node for dictionary comprehensions."""
703+
704+
key: Expr
705+
value: Expr
706+
707+
def __init__(
708+
self,
709+
key: Expr,
710+
value: Expr,
711+
iterable: Expr,
712+
conditions: Optional[list[Expr]] = None,
713+
is_async: bool = False,
714+
loc: SourceLocation = NO_SOURCE_LOCATION,
715+
parent: Optional[ASTNodes] = None,
716+
) -> None:
717+
super().__init__(
718+
target=key,
719+
iterable=iterable,
720+
conditions=conditions,
721+
is_async=is_async,
722+
loc=loc,
723+
parent=parent,
724+
)
725+
self.key = key
726+
self.value = value
727+
self.kind = ASTKind.DictComprehensionKind
728+
729+
def __str__(self) -> str:
730+
"""Return a string representation of the object."""
731+
ret_str = (
732+
f"{{{self.key}: {self.value} for {self.key} in {self.iterable}"
733+
)
734+
if self.conditions is not None:
735+
for cond in self.conditions:
736+
ret_str += f" if {cond}"
737+
ret_str += "}"
738+
return ret_str
739+
740+
def get_struct(self, simplified: bool = False) -> ReprStruct:
741+
"""Return the AST structure of the object."""
742+
value: ReprStruct = {
743+
"key": self.key.get_struct(simplified),
744+
"value": self.value.get_struct(simplified),
745+
"iterable": self.iterable.get_struct(simplified),
746+
"conditions": cast(
747+
ReprStruct,
748+
[cond.get_struct(simplified) for cond in self.conditions],
749+
),
750+
}
751+
752+
key = "DICT-COMPREHENSION"
753+
return self._prepare_struct(key, value, simplified)

libs/astx/tests/test_flows.py

+19
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
AsyncForRangeLoopExpr,
1010
AsyncForRangeLoopStmt,
1111
CaseStmt,
12+
DictComprehension,
1213
ForCountLoopExpr,
1314
ForCountLoopStmt,
1415
ForRangeLoopExpr,
@@ -342,3 +343,21 @@ def test_comprehension() -> None:
342343
assert str(comp) == expected_str
343344
assert comp.get_struct(simplified=True)
344345
assert comp.get_struct(simplified=False)
346+
347+
348+
def test_dict_comprehension() -> None:
349+
"""Test `DictComprehension` class."""
350+
dict_comprehension = DictComprehension(
351+
key=LiteralString("x"),
352+
value=LiteralString("x*x"),
353+
iterable=LiteralString("range(10)"),
354+
conditions=[
355+
LiteralString("x > 5"),
356+
LiteralString("x < 7"),
357+
],
358+
)
359+
360+
assert str(dict_comprehension)
361+
assert dict_comprehension.get_struct()
362+
assert dict_comprehension.get_struct(simplified=True)
363+
visualize(dict_comprehension.get_struct())

0 commit comments

Comments
 (0)