Skip to content

Commit 53fc9ec

Browse files
committed
feat: Add Dict Comprehension support (#198)
1 parent 0fc8192 commit 53fc9ec

File tree

6 files changed

+126
-0
lines changed

6 files changed

+126
-0
lines changed

libs/astx-transpilers/src/astx_transpilers/python_string.py

+12
Original file line numberDiff line numberDiff line change
@@ -731,3 +731,15 @@ def visit(self, node: astx.GeneratorExpr) -> str:
731731
ret_str += f" if {self.visit(cond)}"
732732

733733
return f"({ret_str})"
734+
735+
@dispatch # type: ignore[no-redef]
736+
def visit(self, node: astx.DictComprehension) -> str:
737+
"""Handle DictComprehension nodes."""
738+
ret_str = (
739+
f"{self.visit(node.key)}: {self.visit(node.value)}"
740+
f" for {self.visit(node.target)} in {self.visit(node.iterable)}"
741+
)
742+
if hasattr(node, "conditions") and node.conditions:
743+
for cond in node.conditions:
744+
ret_str += f" if {self.visit(cond)}"
745+
return f"{{{ret_str}}}"

libs/astx-transpilers/tests/test_python.py

+25
Original file line numberDiff line numberDiff line change
@@ -1589,3 +1589,28 @@ def test_transpiler_generator_expr_no_conditions() -> None:
15891589
assert generated_code == expected_code, (
15901590
f"Expected '{expected_code}', but got '{generated_code}'"
15911591
)
1592+
1593+
1594+
def test_transpiler_dict_comprehension() -> None:
1595+
"""Test astx.DictComprehension."""
1596+
dict_comp = astx.DictComprehension(
1597+
key=astx.Variable("x"),
1598+
value=astx.BinaryOp(
1599+
op_code="*", lhs=astx.Variable("x"), rhs=astx.Variable("x")
1600+
),
1601+
target=astx.Variable("x"),
1602+
iterable=astx.Variable("range(10)"),
1603+
conditions=[
1604+
astx.BinaryOp(
1605+
op_code=">", lhs=astx.Variable("x"), rhs=astx.LiteralInt32(5)
1606+
),
1607+
astx.BinaryOp(
1608+
op_code="<", lhs=astx.Variable("x"), rhs=astx.LiteralInt32(7)
1609+
),
1610+
],
1611+
)
1612+
generated_code = transpiler.visit(dict_comp)
1613+
expected_code = "{x: (x * x) for x in range(10) if (x > 5) if (x < 7)}"
1614+
assert generated_code == expected_code, (
1615+
f"Expected '{expected_code}', but got '{generated_code}'"
1616+
)

libs/astx/src/astx/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
AsyncForRangeLoopExpr,
6565
AsyncForRangeLoopStmt,
6666
CaseStmt,
67+
DictComprehension,
6768
DoWhileExpr,
6869
DoWhileStmt,
6970
ForCountLoopExpr,
@@ -234,6 +235,7 @@ def get_version() -> str:
234235
"Date",
235236
"DateTime",
236237
"DeleteStmt",
238+
"DictComprehension",
237239
"DictType",
238240
"DoWhileExpr",
239241
"DoWhileStmt",

libs/astx/src/astx/base.py

+1
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ class ASTKind(Enum):
138138
DoWhileStmtKind = -515
139139
DoWhileExprKind = -516
140140
GeneratorExprKind = -517
141+
DictComprehensionKind = -518
141142

142143
# data types
143144
NullDTKind = -600

libs/astx/src/astx/flows.py

+66
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818
StatementType,
1919
)
2020
from astx.blocks import Block
21+
from astx.callables import (
22+
Comprehension,
23+
)
2124
from astx.tools.typing import typechecked
2225
from astx.variables import InlineVariableDeclaration
2326

@@ -792,3 +795,66 @@ def get_struct(self, simplified: bool = False) -> ReprStruct:
792795
else ASTNodes().get_struct(simplified),
793796
}
794797
return self._prepare_struct(key, value, simplified)
798+
799+
800+
@public
801+
@typechecked
802+
class DictComprehension(Comprehension):
803+
"""AST node for dictionary comprehensions."""
804+
805+
key: Expr
806+
value: Expr
807+
808+
def __init__(
809+
self,
810+
key: Expr,
811+
value: Expr,
812+
target: Expr,
813+
iterable: Expr,
814+
conditions: Optional[Iterable[Expr] | ASTNodes[Expr]] = None,
815+
is_async: bool = False,
816+
loc: SourceLocation = NO_SOURCE_LOCATION,
817+
parent: Optional[ASTNodes] = None,
818+
) -> None:
819+
super().__init__(
820+
target=target,
821+
iterable=iterable,
822+
conditions=conditions,
823+
is_async=is_async,
824+
loc=loc,
825+
parent=parent,
826+
)
827+
self.key = key
828+
self.value = value
829+
self.kind = ASTKind.DictComprehensionKind
830+
831+
def __str__(self) -> str:
832+
"""Return a string representation of the object."""
833+
conditions_str = []
834+
if hasattr(self, "conditions") and self.conditions:
835+
conditions_str = [str(cond) for cond in self.conditions]
836+
837+
ret_str = (
838+
f"DictComprehension[key={self.key}, value={self.value},"
839+
f" target={self.target}, iterable={self.iterable}, "
840+
f" conditions={conditions_str},"
841+
f" is_async={self.is_async}]"
842+
)
843+
return ret_str
844+
845+
def get_struct(self, simplified: bool = False) -> ReprStruct:
846+
"""Return the AST structure of the object."""
847+
value: DictDataTypesStruct = {
848+
"key": self.key.get_struct(simplified),
849+
"value": self.value.get_struct(simplified),
850+
"target": self.target.get_struct(simplified),
851+
"iterable": self.iterable.get_struct(simplified),
852+
}
853+
if hasattr(self, "conditions") and self.conditions:
854+
value["conditions"] = self.conditions.get_struct(simplified)
855+
key = (
856+
f"DICT-COMPREHENSION#{id(self)}"
857+
if simplified
858+
else "DICT-COMPREHENSION"
859+
)
860+
return self._prepare_struct(key, value, simplified)

libs/astx/tests/test_flows.py

+20
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
AsyncForRangeLoopExpr,
1010
AsyncForRangeLoopStmt,
1111
CaseStmt,
12+
DictComprehension,
1213
DoWhileExpr,
1314
DoWhileStmt,
1415
ForCountLoopExpr,
@@ -437,3 +438,22 @@ def test_generator_expr_2() -> None:
437438
assert gen_expr.get_struct()
438439
assert gen_expr.get_struct(simplified=True)
439440
visualize(gen_expr.get_struct())
441+
442+
443+
def test_dict_comprehension() -> None:
444+
"""Test `DictComprehension` class."""
445+
dict_comprehension = DictComprehension(
446+
key=Variable("x"),
447+
value=BinaryOp(op_code="*", lhs=Variable("x"), rhs=Variable("x")),
448+
target=Variable("x"),
449+
iterable=Variable("range(10)"),
450+
conditions=[
451+
BinaryOp(op_code=">", lhs=Variable("x"), rhs=LiteralInt32(5)),
452+
BinaryOp(op_code="<", lhs=Variable("x"), rhs=LiteralInt32(7)),
453+
],
454+
)
455+
456+
assert str(dict_comprehension)
457+
assert dict_comprehension.get_struct()
458+
assert dict_comprehension.get_struct(simplified=True)
459+
visualize(dict_comprehension.get_struct())

0 commit comments

Comments
 (0)