Skip to content

Commit 808526d

Browse files
committed
Add Dict Comprehension support
1 parent 77e5610 commit 808526d

File tree

7 files changed

+307
-0
lines changed

7 files changed

+307
-0
lines changed

docs/tutorials/dict_comprehension.ipynb

+183
Large diffs are not rendered by default.

libs/astx-transpilers/src/astx_transpilers/python_string.py

+12
Original file line numberDiff line numberDiff line change
@@ -699,3 +699,15 @@ def visit(self, node: astx.LiteralDict) -> str:
699699
for key, value in node.elements.items()
700700
)
701701
return f"{{{items_code}}}"
702+
703+
@dispatch # type: ignore[no-redef]
704+
def visit(self, node: astx.DictComprehension) -> str:
705+
"""Handle DictComprehension nodes."""
706+
ret_str = (
707+
f"{self.visit(node.key)}: {self.visit(node.value)}"
708+
f" for {self.visit(node.target)} in {self.visit(node.iterable)}"
709+
)
710+
if node.conditions is not None:
711+
for cond in node.conditions:
712+
ret_str += f" if {self.visit(cond)}"
713+
return f"{{{ret_str}}}"

libs/astx-transpilers/tests/test_python.py

+25
Original file line numberDiff line numberDiff line change
@@ -1424,3 +1424,28 @@ def test_transpiler_literal_dict() -> None:
14241424
1: 10,
14251425
2: 20,
14261426
}, f"Expected '{expected_code}', but got '{generated_code}'"
1427+
1428+
1429+
def test_transpiler_dict_comprehension() -> None:
1430+
"""Test astx.DictComprehension."""
1431+
dict_comp = astx.DictComprehension(
1432+
key=astx.Variable("x"),
1433+
value=astx.BinaryOp(
1434+
op_code="*", lhs=astx.Variable("x"), rhs=astx.Variable("x")
1435+
),
1436+
target=astx.Variable("x"),
1437+
iterable=astx.Variable("range(10)"),
1438+
conditions=[
1439+
astx.BinaryOp(
1440+
op_code=">", lhs=astx.Variable("x"), rhs=astx.LiteralInt32(5)
1441+
),
1442+
astx.BinaryOp(
1443+
op_code="<", lhs=astx.Variable("x"), rhs=astx.LiteralInt32(7)
1444+
),
1445+
],
1446+
)
1447+
generated_code = transpiler.visit(dict_comp)
1448+
expected_code = "{x: (x * x) for x in range(10) if (x > 5) if (x < 7)}"
1449+
assert generated_code == expected_code, (
1450+
f"Expected '{expected_code}', but got '{generated_code}'"
1451+
)

libs/astx/src/astx/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
AsyncForRangeLoopExpr,
6464
AsyncForRangeLoopStmt,
6565
CaseStmt,
66+
DictComprehension,
6667
ForCountLoopExpr,
6768
ForCountLoopStmt,
6869
ForRangeLoopExpr,
@@ -230,6 +231,7 @@ def get_version() -> str:
230231
"Date",
231232
"DateTime",
232233
"DeleteStmt",
234+
"DictComprehension",
233235
"DictType",
234236
"EnumDeclStmt",
235237
"ExceptionHandlerStmt",

libs/astx/src/astx/base.py

+1
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ class ASTKind(Enum):
134134
WithStmtKind = -512
135135
AsyncRangeLoopStmtKind = -513
136136
AsyncRangeLoopExprKind = -514
137+
DictComprehensionKind = -515
137138

138139
# data types
139140
NullDTKind = -600

libs/astx/src/astx/flows.py

+64
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818
StatementType,
1919
)
2020
from astx.blocks import Block
21+
from astx.callables import (
22+
Comprehension,
23+
)
2124
from astx.tools.typing import typechecked
2225
from astx.variables import InlineVariableDeclaration
2326

@@ -691,3 +694,64 @@ def get_struct(self, simplified: bool = False) -> ReprStruct:
691694
key = f"GOTO-STMT[{self.label.value}]"
692695
value: DictDataTypesStruct = {}
693696
return self._prepare_struct(key, value, simplified)
697+
698+
699+
@public
700+
@typechecked
701+
class DictComprehension(Comprehension):
702+
"""AST node for dictionary comprehensions."""
703+
704+
key: Expr
705+
value: Expr
706+
707+
def __init__(
708+
self,
709+
key: Expr,
710+
value: Expr,
711+
target: Expr,
712+
iterable: Expr,
713+
conditions: Optional[list[Expr]] = None,
714+
is_async: bool = False,
715+
loc: SourceLocation = NO_SOURCE_LOCATION,
716+
parent: Optional[ASTNodes] = None,
717+
) -> None:
718+
super().__init__(
719+
target=target,
720+
iterable=iterable,
721+
conditions=conditions,
722+
is_async=is_async,
723+
loc=loc,
724+
parent=parent,
725+
)
726+
self.key = key
727+
self.value = value
728+
self.kind = ASTKind.DictComprehensionKind
729+
730+
def __str__(self) -> str:
731+
"""Return a string representation of the object."""
732+
ret_str = (
733+
f"{self.key}: {self.value} for {self.target} in {self.iterable}"
734+
)
735+
if self.conditions is not None:
736+
for cond in self.conditions:
737+
ret_str += f" if {cond}"
738+
return f"{{{ret_str}}}"
739+
740+
def get_struct(self, simplified: bool = False) -> ReprStruct:
741+
"""Return the AST structure of the object."""
742+
value: ReprStruct = {
743+
"key": self.key.get_struct(simplified),
744+
"value": self.value.get_struct(simplified),
745+
"target": self.value.get_struct(simplified),
746+
"iterable": self.iterable.get_struct(simplified),
747+
"conditions": cast(
748+
ReprStruct,
749+
{
750+
str(cond): cond.get_struct(simplified)
751+
for cond in self.conditions
752+
},
753+
),
754+
}
755+
756+
key = "DICT-COMPREHENSION"
757+
return self._prepare_struct(key, value, simplified)

libs/astx/tests/test_flows.py

+20
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
AsyncForRangeLoopExpr,
1010
AsyncForRangeLoopStmt,
1111
CaseStmt,
12+
DictComprehension,
1213
ForCountLoopExpr,
1314
ForCountLoopStmt,
1415
ForRangeLoopExpr,
@@ -342,3 +343,22 @@ def test_comprehension() -> None:
342343
assert str(comp) == expected_str
343344
assert comp.get_struct(simplified=True)
344345
assert comp.get_struct(simplified=False)
346+
347+
348+
def test_dict_comprehension() -> None:
349+
"""Test `DictComprehension` class."""
350+
dict_comprehension = DictComprehension(
351+
key=Variable("x"),
352+
value=BinaryOp(op_code="*", lhs=Variable("x"), rhs=Variable("x")),
353+
target=Variable("x"),
354+
iterable=Variable("range(10)"),
355+
conditions=[
356+
BinaryOp(op_code=">", lhs=Variable("x"), rhs=LiteralInt32(5)),
357+
BinaryOp(op_code="<", lhs=Variable("x"), rhs=LiteralInt32(7)),
358+
],
359+
)
360+
361+
assert str(dict_comprehension)
362+
assert dict_comprehension.get_struct()
363+
assert dict_comprehension.get_struct(simplified=True)
364+
visualize(dict_comprehension.get_struct())

0 commit comments

Comments
 (0)