Skip to content

Commit 1614073

Browse files
committed
feat: Add Dict Comprehension support
1 parent 4ddff4e commit 1614073

File tree

6 files changed

+131
-0
lines changed

6 files changed

+131
-0
lines changed

libs/astx-transpilers/src/astx_transpilers/python_string.py

+12
Original file line numberDiff line numberDiff line change
@@ -713,3 +713,15 @@ def visit(self, node: astx.DoWhileStmt) -> str:
713713
body = self._generate_block(node.body)
714714
condition = self.visit(node.condition)
715715
return f"while True:\n{body}\n if not {condition}:\n break"
716+
717+
@dispatch # type: ignore[no-redef]
718+
def visit(self, node: astx.DictComprehension) -> str:
719+
"""Handle DictComprehension nodes."""
720+
ret_str = (
721+
f"{self.visit(node.key)}: {self.visit(node.value)}"
722+
f" for {self.visit(node.target)} in {self.visit(node.iterable)}"
723+
)
724+
if node.conditions is not None:
725+
for cond in node.conditions:
726+
ret_str += f" if {self.visit(cond)}"
727+
return f"{{{ret_str}}}"

libs/astx-transpilers/tests/test_python.py

+25
Original file line numberDiff line numberDiff line change
@@ -1525,3 +1525,28 @@ def test_transpiler_do_while_expr() -> None:
15251525
assert generated_code == expected_code, (
15261526
f"Expected '{expected_code}', but got '{generated_code}'"
15271527
)
1528+
1529+
1530+
def test_transpiler_dict_comprehension() -> None:
1531+
"""Test astx.DictComprehension."""
1532+
dict_comp = astx.DictComprehension(
1533+
key=astx.Variable("x"),
1534+
value=astx.BinaryOp(
1535+
op_code="*", lhs=astx.Variable("x"), rhs=astx.Variable("x")
1536+
),
1537+
target=astx.Variable("x"),
1538+
iterable=astx.Variable("range(10)"),
1539+
conditions=[
1540+
astx.BinaryOp(
1541+
op_code=">", lhs=astx.Variable("x"), rhs=astx.LiteralInt32(5)
1542+
),
1543+
astx.BinaryOp(
1544+
op_code="<", lhs=astx.Variable("x"), rhs=astx.LiteralInt32(7)
1545+
),
1546+
],
1547+
)
1548+
generated_code = transpiler.visit(dict_comp)
1549+
expected_code = "{x: (x * x) for x in range(10) if (x > 5) if (x < 7)}"
1550+
assert generated_code == expected_code, (
1551+
f"Expected '{expected_code}', but got '{generated_code}'"
1552+
)

libs/astx/src/astx/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
AsyncForRangeLoopExpr,
6464
AsyncForRangeLoopStmt,
6565
CaseStmt,
66+
DictComprehension,
6667
DoWhileExpr,
6768
DoWhileStmt,
6869
ForCountLoopExpr,
@@ -232,6 +233,7 @@ def get_version() -> str:
232233
"Date",
233234
"DateTime",
234235
"DeleteStmt",
236+
"DictComprehension",
235237
"DictType",
236238
"DoWhileExpr",
237239
"DoWhileStmt",

libs/astx/src/astx/base.py

+1
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ class ASTKind(Enum):
136136
AsyncRangeLoopExprKind = -514
137137
DoWhileStmtKind = -515
138138
DoWhileExprKind = -516
139+
DictComprehensionKind = -517
139140

140141
# data types
141142
NullDTKind = -600

libs/astx/src/astx/flows.py

+71
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818
StatementType,
1919
)
2020
from astx.blocks import Block
21+
from astx.callables import (
22+
Comprehension,
23+
)
2124
from astx.tools.typing import typechecked
2225
from astx.variables import InlineVariableDeclaration
2326

@@ -737,3 +740,71 @@ def __init__(
737740
def __str__(self) -> str:
738741
"""Return a string representation of the object."""
739742
return f"DoWhileExpr[{self.condition}]"
743+
744+
745+
@public
746+
@typechecked
747+
class DictComprehension(Comprehension):
748+
"""AST node for dictionary comprehensions."""
749+
750+
key: Expr
751+
value: Expr
752+
753+
def __init__(
754+
self,
755+
key: Expr,
756+
value: Expr,
757+
target: Expr,
758+
iterable: Expr,
759+
conditions: Optional[list[Expr]] = None,
760+
is_async: bool = False,
761+
loc: SourceLocation = NO_SOURCE_LOCATION,
762+
parent: Optional[ASTNodes] = None,
763+
) -> None:
764+
super().__init__(
765+
target=target,
766+
iterable=iterable,
767+
conditions=conditions,
768+
is_async=is_async,
769+
loc=loc,
770+
parent=parent,
771+
)
772+
self.key = key
773+
self.value = value
774+
self.kind = ASTKind.DictComprehensionKind
775+
776+
def __str__(self) -> str:
777+
"""Return a string representation of the object."""
778+
ret_str = (
779+
f"DictComprehension[key={self.key}, value={self.value},"
780+
f" target={self.target}, iterable={self.iterable},"
781+
)
782+
if self.conditions is not None:
783+
cons_list = []
784+
for cond in self.conditions:
785+
cons_list.append(str(cond))
786+
ret_str += f" conditions={(cons_list)}]"
787+
return ret_str
788+
789+
def get_struct(self, simplified: bool = False) -> ReprStruct:
790+
"""Return the AST structure of the object."""
791+
value: ReprStruct = {
792+
"key": self.key.get_struct(simplified),
793+
"value": self.value.get_struct(simplified),
794+
"target": self.value.get_struct(simplified),
795+
"iterable": self.iterable.get_struct(simplified),
796+
"conditions": cast(
797+
ReprStruct,
798+
{
799+
str(cond): cond.get_struct(simplified)
800+
for cond in self.conditions
801+
},
802+
),
803+
}
804+
805+
key = (
806+
f"DICT-COMPREHENSION#{id(self)}"
807+
if simplified
808+
else "DICT-COMPREHENSION"
809+
)
810+
return self._prepare_struct(key, value, simplified)

libs/astx/tests/test_flows.py

+20
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
AsyncForRangeLoopExpr,
1010
AsyncForRangeLoopStmt,
1111
CaseStmt,
12+
DictComprehension,
1213
DoWhileExpr,
1314
DoWhileStmt,
1415
ForCountLoopExpr,
@@ -386,3 +387,22 @@ def test_do_while_stmt() -> None:
386387
assert do_while_stmt.get_struct()
387388
assert do_while_stmt.get_struct(simplified=True)
388389
visualize(do_while_stmt.get_struct())
390+
391+
392+
def test_dict_comprehension() -> None:
393+
"""Test `DictComprehension` class."""
394+
dict_comprehension = DictComprehension(
395+
key=Variable("x"),
396+
value=BinaryOp(op_code="*", lhs=Variable("x"), rhs=Variable("x")),
397+
target=Variable("x"),
398+
iterable=Variable("range(10)"),
399+
conditions=[
400+
BinaryOp(op_code=">", lhs=Variable("x"), rhs=LiteralInt32(5)),
401+
BinaryOp(op_code="<", lhs=Variable("x"), rhs=LiteralInt32(7)),
402+
],
403+
)
404+
405+
assert str(dict_comprehension)
406+
assert dict_comprehension.get_struct()
407+
assert dict_comprehension.get_struct(simplified=True)
408+
visualize(dict_comprehension.get_struct())

0 commit comments

Comments
 (0)