Skip to content

Commit f2be2e6

Browse files
committed
feat: Add generator expr support
1 parent 4ddff4e commit f2be2e6

File tree

6 files changed

+138
-1
lines changed

6 files changed

+138
-1
lines changed

libs/astx-transpilers/src/astx_transpilers/python_string.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -713,3 +713,15 @@ def visit(self, node: astx.DoWhileStmt) -> str:
713713
body = self._generate_block(node.body)
714714
condition = self.visit(node.condition)
715715
return f"while True:\n{body}\n if not {condition}:\n break"
716+
717+
@dispatch # type: ignore[no-redef]
718+
def visit(self, node: astx.GeneratorExpr) -> str:
719+
"""Handle GeneratorExpr nodes."""
720+
ret_str = (
721+
f"{self.visit(node.element)} for {self.visit(node.target)}"
722+
f" in {self.visit(node.iterable)}"
723+
)
724+
if node.conditions:
725+
for condition in node.conditions:
726+
ret_str += f" if {self.visit(condition)}"
727+
return f"({ret_str})"

libs/astx-transpilers/tests/test_python.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1525,3 +1525,44 @@ def test_transpiler_do_while_expr() -> None:
15251525
assert generated_code == expected_code, (
15261526
f"Expected '{expected_code}', but got '{generated_code}'"
15271527
)
1528+
1529+
1530+
def test_transpiler_generator_expr() -> None:
1531+
"""Test astx.GeneratorExpr."""
1532+
gen_expr = astx.GeneratorExpr(
1533+
element=astx.Variable("x"),
1534+
target=astx.BinaryOp(
1535+
op_code="+", lhs=astx.Variable("x"), rhs=astx.Variable("x")
1536+
),
1537+
iterable=astx.Identifier("range(10)"),
1538+
conditions=[
1539+
astx.BinaryOp(
1540+
op_code=">", lhs=astx.Variable("x"), rhs=astx.LiteralInt32(3)
1541+
),
1542+
astx.BinaryOp(
1543+
op_code="<", lhs=astx.Variable("x"), rhs=astx.LiteralInt32(7)
1544+
),
1545+
],
1546+
)
1547+
generated_code = transpiler.visit(gen_expr)
1548+
expected_code = "(x for (x + x) in range(10) if (x > 3) if (x < 7))"
1549+
assert generated_code == expected_code, (
1550+
f"Expected '{expected_code}', but got '{generated_code}'"
1551+
)
1552+
1553+
1554+
def test_transpiler_generator_expr_no_conditions() -> None:
1555+
"""Test astx.GeneratorExpr with no conditions."""
1556+
gen_expr = astx.GeneratorExpr(
1557+
element=astx.Variable("x"),
1558+
target=astx.BinaryOp(
1559+
op_code="+", lhs=astx.Variable("x"), rhs=astx.Variable("x")
1560+
),
1561+
iterable=astx.Identifier("range(10)"),
1562+
)
1563+
1564+
generated_code = transpiler.visit(gen_expr)
1565+
expected_code = "(x for (x + x) in range(10))"
1566+
assert generated_code == expected_code, (
1567+
f"Expected '{expected_code}', but got '{generated_code}'"
1568+
)

libs/astx/src/astx/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@
6969
ForCountLoopStmt,
7070
ForRangeLoopExpr,
7171
ForRangeLoopStmt,
72+
GeneratorExpr,
7273
GotoStmt,
7374
IfExpr,
7475
IfStmt,
@@ -253,6 +254,7 @@ def get_version() -> str:
253254
"FunctionDef",
254255
"FunctionPrototype",
255256
"FunctionReturn",
257+
"GeneratorExpr",
256258
"GotoStmt",
257259
"Identifier",
258260
"IfExpr",

libs/astx/src/astx/base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,8 @@ class ASTKind(Enum):
136136
AsyncRangeLoopExprKind = -514
137137
DoWhileStmtKind = -515
138138
DoWhileExprKind = -516
139-
139+
GeneratorExprKind = -517
140+
140141
# data types
141142
NullDTKind = -600
142143
BooleanDTKind = -601

libs/astx/src/astx/flows.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -737,3 +737,61 @@ def __init__(
737737
def __str__(self) -> str:
738738
"""Return a string representation of the object."""
739739
return f"DoWhileExpr[{self.condition}]"
740+
741+
742+
@public
743+
@typechecked
744+
class GeneratorExpr(Expr):
745+
"""AST class for generator expressions."""
746+
747+
element: Expr
748+
target: Expr
749+
iterable: Expr
750+
conditions: list[Expr]
751+
752+
def __init__(
753+
self,
754+
element: Expr,
755+
target: Expr,
756+
iterable: Expr,
757+
conditions: Optional[list[Expr]] = None,
758+
loc: SourceLocation = NO_SOURCE_LOCATION,
759+
parent: Optional[ASTNodes] = None,
760+
) -> None:
761+
"""Initialize the GeneratorExpr instance."""
762+
super().__init__(loc=loc, parent=parent)
763+
self.element = element
764+
self.target = target
765+
self.iterable = iterable
766+
self.conditions = conditions if conditions is not None else []
767+
self.kind = ASTKind.GeneratorExprKind
768+
769+
def __str__(self) -> str:
770+
"""Return a string representation of the object."""
771+
ret_str = (
772+
f"GeneratorExpr[element={self.element}, target={self.target},"
773+
f" iterable={self.iterable},"
774+
)
775+
if self.conditions is not None:
776+
cons_list = []
777+
for cond in self.conditions:
778+
cons_list.append(str(cond))
779+
ret_str += f" conditions={(cons_list)}]"
780+
return ret_str
781+
782+
def get_struct(self, simplified: bool = False) -> ReprStruct:
783+
"""Return the AST structure of the object."""
784+
key = f"GENERATOR-EXPR#{id(self)}" if simplified else "GENERATOR-EXPR"
785+
value: ReprStruct = {
786+
"element": self.element.get_struct(simplified),
787+
"target": self.target.get_struct(simplified),
788+
"iterable": self.iterable.get_struct(simplified),
789+
"conditions": cast(
790+
ReprStruct,
791+
{
792+
str(cond): cond.get_struct(simplified)
793+
for cond in self.conditions
794+
},
795+
),
796+
}
797+
return self._prepare_struct(key, value, simplified)

libs/astx/tests/test_flows.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -386,3 +386,26 @@ def test_do_while_stmt() -> None:
386386
assert do_while_stmt.get_struct()
387387
assert do_while_stmt.get_struct(simplified=True)
388388
visualize(do_while_stmt.get_struct())
389+
390+
391+
def test_generator_expr() -> None:
392+
"""Test `GeneratorExpr` class."""
393+
gen_expr = astx.GeneratorExpr(
394+
element=astx.Variable("x"),
395+
target=astx.BinaryOp(
396+
op_code="+", lhs=astx.Variable("x"), rhs=astx.Variable("x")
397+
),
398+
iterable=astx.Identifier("range(10)"),
399+
conditions=[
400+
astx.BinaryOp(
401+
op_code=">", lhs=astx.Variable("x"), rhs=astx.LiteralInt32(3)
402+
),
403+
astx.BinaryOp(
404+
op_code="<", lhs=astx.Variable("x"), rhs=astx.LiteralInt32(7)
405+
),
406+
],
407+
)
408+
assert str(gen_expr)
409+
assert gen_expr.get_struct()
410+
assert gen_expr.get_struct(simplified=True)
411+
visualize(gen_expr.get_struct())

0 commit comments

Comments
 (0)