Skip to content

Commit 7c3f8b8

Browse files
committed
Add generator expr support
1 parent 3499c7b commit 7c3f8b8

File tree

6 files changed

+137
-0
lines changed

6 files changed

+137
-0
lines changed

libs/astx-transpilers/src/astx_transpilers/python_string.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -699,3 +699,15 @@ def visit(self, node: astx.LiteralDict) -> str:
699699
for key, value in node.elements.items()
700700
)
701701
return f"{{{items_code}}}"
702+
703+
@dispatch # type: ignore[no-redef]
704+
def visit(self, node: astx.GeneratorExpr) -> str:
705+
"""Handle GeneratorExpr nodes."""
706+
ret_str = (
707+
f"{self.visit(node.element)} for {self.visit(node.target)}"
708+
f" in {self.visit(node.iterable)}"
709+
)
710+
if node.conditions:
711+
for condition in node.conditions:
712+
ret_str += f" if {self.visit(condition)}"
713+
return f"({ret_str})"

libs/astx-transpilers/tests/test_python.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1424,3 +1424,44 @@ def test_transpiler_literal_dict() -> None:
14241424
1: 10,
14251425
2: 20,
14261426
}, f"Expected '{expected_code}', but got '{generated_code}'"
1427+
1428+
1429+
def test_transpiler_generator_expr() -> None:
1430+
"""Test astx.GeneratorExpr."""
1431+
gen_expr = astx.GeneratorExpr(
1432+
element=astx.Variable("x"),
1433+
target=astx.BinaryOp(
1434+
op_code="+", lhs=astx.Variable("x"), rhs=astx.Variable("x")
1435+
),
1436+
iterable=astx.Identifier("range(10)"),
1437+
conditions=[
1438+
astx.BinaryOp(
1439+
op_code=">", lhs=astx.Variable("x"), rhs=astx.LiteralInt32(3)
1440+
),
1441+
astx.BinaryOp(
1442+
op_code="<", lhs=astx.Variable("x"), rhs=astx.LiteralInt32(7)
1443+
),
1444+
],
1445+
)
1446+
generated_code = transpiler.visit(gen_expr)
1447+
expected_code = "(x for (x + x) in range(10) if (x > 3) if (x < 7))"
1448+
assert generated_code == expected_code, (
1449+
f"Expected '{expected_code}', but got '{generated_code}'"
1450+
)
1451+
1452+
1453+
def test_transpiler_generator_expr_no_conditions() -> None:
1454+
"""Test astx.GeneratorExpr with no conditions."""
1455+
gen_expr = astx.GeneratorExpr(
1456+
element=astx.Variable("x"),
1457+
target=astx.BinaryOp(
1458+
op_code="+", lhs=astx.Variable("x"), rhs=astx.Variable("x")
1459+
),
1460+
iterable=astx.Identifier("range(10)"),
1461+
)
1462+
1463+
generated_code = transpiler.visit(gen_expr)
1464+
expected_code = "(x for (x + x) in range(10))"
1465+
assert generated_code == expected_code, (
1466+
f"Expected '{expected_code}', but got '{generated_code}'"
1467+
)

libs/astx/src/astx/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
ForCountLoopStmt,
6868
ForRangeLoopExpr,
6969
ForRangeLoopStmt,
70+
GeneratorExpr,
7071
GotoStmt,
7172
IfExpr,
7273
IfStmt,
@@ -249,6 +250,7 @@ def get_version() -> str:
249250
"FunctionDef",
250251
"FunctionPrototype",
251252
"FunctionReturn",
253+
"GeneratorExpr",
252254
"GotoStmt",
253255
"Identifier",
254256
"IfExpr",

libs/astx/src/astx/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ class ASTKind(Enum):
134134
WithStmtKind = -512
135135
AsyncRangeLoopStmtKind = -513
136136
AsyncRangeLoopExprKind = -514
137+
GeneratorExprKind = -516
137138

138139
# data types
139140
NullDTKind = -600

libs/astx/src/astx/flows.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -691,3 +691,61 @@ def get_struct(self, simplified: bool = False) -> ReprStruct:
691691
key = f"GOTO-STMT[{self.label.value}]"
692692
value: DictDataTypesStruct = {}
693693
return self._prepare_struct(key, value, simplified)
694+
695+
696+
@public
697+
@typechecked
698+
class GeneratorExpr(Expr):
699+
"""AST class for generator expressions."""
700+
701+
element: Expr
702+
target: Expr
703+
iterable: Expr
704+
conditions: list[Expr]
705+
706+
def __init__(
707+
self,
708+
element: Expr,
709+
target: Expr,
710+
iterable: Expr,
711+
conditions: Optional[list[Expr]] = None,
712+
loc: SourceLocation = NO_SOURCE_LOCATION,
713+
parent: Optional[ASTNodes] = None,
714+
) -> None:
715+
"""Initialize the GeneratorExpr instance."""
716+
super().__init__(loc=loc, parent=parent)
717+
self.element = element
718+
self.target = target
719+
self.iterable = iterable
720+
self.conditions = conditions if conditions is not None else []
721+
self.kind = ASTKind.GeneratorExprKind
722+
723+
def __str__(self) -> str:
724+
"""Return a string representation of the object."""
725+
ret_str = (
726+
f"GeneratorExpr[element={self.element}, target={self.target},"
727+
f" iterable={self.iterable},"
728+
)
729+
if self.conditions is not None:
730+
cons_list = []
731+
for cond in self.conditions:
732+
cons_list.append(str(cond))
733+
ret_str += f" conditions={(cons_list)}]"
734+
return ret_str
735+
736+
def get_struct(self, simplified: bool = False) -> ReprStruct:
737+
"""Return the AST structure of the object."""
738+
key = f"GENERATOR-EXPR#{id(self)}" if simplified else "GENERATOR-EXPR"
739+
value: ReprStruct = {
740+
"element": self.element.get_struct(simplified),
741+
"target": self.target.get_struct(simplified),
742+
"iterable": self.iterable.get_struct(simplified),
743+
"conditions": cast(
744+
ReprStruct,
745+
{
746+
str(cond): cond.get_struct(simplified)
747+
for cond in self.conditions
748+
},
749+
),
750+
}
751+
return self._prepare_struct(key, value, simplified)

libs/astx/tests/test_flows.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,3 +342,26 @@ def test_comprehension() -> None:
342342
assert str(comp) == expected_str
343343
assert comp.get_struct(simplified=True)
344344
assert comp.get_struct(simplified=False)
345+
346+
347+
def test_generator_expr() -> None:
348+
"""Test `GeneratorExpr` class."""
349+
gen_expr = astx.GeneratorExpr(
350+
element=astx.Variable("x"),
351+
target=astx.BinaryOp(
352+
op_code="+", lhs=astx.Variable("x"), rhs=astx.Variable("x")
353+
),
354+
iterable=astx.Identifier("range(10)"),
355+
conditions=[
356+
astx.BinaryOp(
357+
op_code=">", lhs=astx.Variable("x"), rhs=astx.LiteralInt32(3)
358+
),
359+
astx.BinaryOp(
360+
op_code="<", lhs=astx.Variable("x"), rhs=astx.LiteralInt32(7)
361+
),
362+
],
363+
)
364+
assert str(gen_expr)
365+
assert gen_expr.get_struct()
366+
assert gen_expr.get_struct(simplified=True)
367+
visualize(gen_expr.get_struct())

0 commit comments

Comments
 (0)