Skip to content

Commit 9f13aba

Browse files
JoshuaGlaZZehen-249
authored andcommitted
feat: Add support for AsyncFor (arxlang#211)
1 parent 4ff752d commit 9f13aba

File tree

6 files changed

+221
-0
lines changed

6 files changed

+221
-0
lines changed

src/astx/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@
5959
ThrowStmt,
6060
)
6161
from astx.flows import (
62+
AsyncForRangeLoopExpr,
63+
AsyncForRangeLoopStmt,
6264
CaseStmt,
6365
ForCountLoopExpr,
6466
ForCountLoopStmt,
@@ -200,6 +202,8 @@ def get_version() -> str:
200202
"Argument",
201203
"Arguments",
202204
"AssignmentExpr",
205+
"AsyncForRangeLoopExpr",
206+
"AsyncForRangeLoopStmt",
203207
"AwaitExpr",
204208
"BinaryOp",
205209
"Block",

src/astx/base.py

+2
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,8 @@ class ASTKind(Enum):
129129
GotoStmtKind = -511
130130
WithStmtKind = -512
131131
GeneratorExprKind = -516
132+
AsyncRangeLoopStmtKind = -513
133+
AsyncRangeLoopExprKind = -514
132134

133135
# data types
134136
NullDTKind = -600

src/astx/flows.py

+125
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,131 @@ def get_struct(self, simplified: bool = False) -> ReprStruct:
344344
return self._prepare_struct(key, value, simplified)
345345

346346

347+
@public
348+
@typechecked
349+
class AsyncForRangeLoopStmt(StatementType):
350+
"""AST class for asynchronous `For` Range Statement."""
351+
352+
variable: InlineVariableDeclaration
353+
start: Optional[Expr]
354+
end: Expr
355+
step: Optional[Expr]
356+
body: Block
357+
358+
def __init__(
359+
self,
360+
variable: InlineVariableDeclaration,
361+
start: Optional[Expr],
362+
end: Expr,
363+
step: Optional[Expr],
364+
body: Block,
365+
loc: SourceLocation = NO_SOURCE_LOCATION,
366+
parent: Optional[ASTNodes] = None,
367+
) -> None:
368+
"""Initialize the AsyncForRangeLoopStmt instance."""
369+
super().__init__(loc=loc, parent=parent)
370+
self.variable = variable
371+
self.start = start
372+
self.end = end
373+
self.step = step
374+
self.body = body
375+
self.kind = ASTKind.AsyncRangeLoopStmtKind
376+
377+
def __str__(self) -> str:
378+
"""Return a string that represents the object."""
379+
start = self.start
380+
end = self.end
381+
step = self.step
382+
var_name = self.variable.name
383+
return f"AsyncForRangeLoopStmt({var_name}=[{start}:{end}:{step}])"
384+
385+
def get_struct(self, simplified: bool = False) -> ReprStruct:
386+
"""Return the AST structure of the object."""
387+
for_start = {
388+
"start": {}
389+
if self.start is None
390+
else self.start.get_struct(simplified)
391+
}
392+
for_end = {"end": self.end.get_struct(simplified)}
393+
for_step = {
394+
"step": {}
395+
if self.step is None
396+
else self.step.get_struct(simplified)
397+
}
398+
for_body = self.body.get_struct(simplified)
399+
400+
key = "ASYNC-FOR-RANGE-LOOP-STMT"
401+
value: ReprStruct = {
402+
**cast(DictDataTypesStruct, for_start),
403+
**cast(DictDataTypesStruct, for_end),
404+
**cast(DictDataTypesStruct, for_step),
405+
**cast(DictDataTypesStruct, for_body),
406+
}
407+
return self._prepare_struct(key, value, simplified)
408+
409+
410+
@public
411+
@typechecked
412+
class AsyncForRangeLoopExpr(Expr):
413+
"""AST class for asynchronous `For` Range Expression."""
414+
415+
variable: InlineVariableDeclaration
416+
start: Optional[Expr]
417+
end: Expr
418+
step: Optional[Expr]
419+
body: Block
420+
421+
def __init__(
422+
self,
423+
variable: InlineVariableDeclaration,
424+
start: Optional[Expr],
425+
end: Expr,
426+
step: Optional[Expr],
427+
body: Block,
428+
loc: SourceLocation = NO_SOURCE_LOCATION,
429+
parent: Optional[ASTNodes] = None,
430+
) -> None:
431+
"""Initialize the AsyncForRangeLoopExpr instance."""
432+
super().__init__(loc=loc, parent=parent)
433+
self.variable = variable
434+
self.start = start
435+
self.end = end
436+
self.step = step
437+
self.body = body
438+
self.kind = ASTKind.AsyncRangeLoopExprKind
439+
440+
def __str__(self) -> str:
441+
"""Return a string that represents the object."""
442+
var_name = self.variable.name
443+
return f"AsyncForRangeLoopExpr[{var_name}]"
444+
445+
def get_struct(self, simplified: bool = False) -> ReprStruct:
446+
"""Return the AST structure of the object."""
447+
for_var = {"var": self.variable.get_struct(simplified)}
448+
for_start = {
449+
"start": {}
450+
if self.start is None
451+
else self.start.get_struct(simplified)
452+
}
453+
for_end = {"end": self.end.get_struct(simplified)}
454+
for_step = {
455+
"step": {}
456+
if self.step is None
457+
else self.step.get_struct(simplified)
458+
}
459+
for_body = self.body.get_struct(simplified)
460+
461+
key = "ASYNC-FOR-RANGE-LOOP-EXPR"
462+
value: ReprStruct = {
463+
**cast(DictDataTypesStruct, for_var),
464+
**cast(DictDataTypesStruct, for_start),
465+
**cast(DictDataTypesStruct, for_end),
466+
**cast(DictDataTypesStruct, for_step),
467+
**cast(DictDataTypesStruct, for_body),
468+
}
469+
return self._prepare_struct(key, value, simplified)
470+
471+
347472
@public
348473
@typechecked
349474
class WhileStmt(StatementType):

src/astx/tools/transpilers/python.py

+25
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,31 @@ def visit(self, node: astx.AssignmentExpr) -> str:
6666
target_str = " = ".join(self.visit(target) for target in node.targets)
6767
return f"{target_str} = {self.visit(node.value)}"
6868

69+
@dispatch # type: ignore[no-redef]
70+
def visit(self, node: astx.AsyncForRangeLoopExpr) -> str:
71+
"""Handle AsyncForRangeLoopExpr nodes."""
72+
if len(node.body) > 1:
73+
raise ValueError(
74+
"AsyncForRangeLoopExpr in Python just accept 1 node in the "
75+
"body attribute."
76+
)
77+
start = (
78+
self.visit(node.start)
79+
if getattr(node, "start", None) is not None
80+
else "0"
81+
)
82+
end = self.visit(node.end)
83+
step = (
84+
self.visit(node.step)
85+
if getattr(node, "step", None) is not None
86+
else "1"
87+
)
88+
89+
return (
90+
f"result = [{self.visit(node.body).strip()} async for "
91+
f"{node.variable.name} in range({start}, {end}, {step})]"
92+
)
93+
6994
@dispatch # type: ignore[no-redef]
7095
def visit(self, node: astx.AwaitExpr) -> str:
7196
"""Handle AwaitExpr nodes."""

tests/test_flows.py

+42
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from astx.base import SourceLocation
77
from astx.blocks import Block
88
from astx.flows import (
9+
AsyncForRangeLoopExpr,
10+
AsyncForRangeLoopStmt,
911
CaseStmt,
1012
ForCountLoopExpr,
1113
ForCountLoopStmt,
@@ -155,6 +157,46 @@ def test_for_count_loop_expr() -> None:
155157
visualize(for_expr.get_struct())
156158

157159

160+
def test_async_for_range_loop_expr() -> None:
161+
"""Test `Async For Range Loop` expression`."""
162+
decl_a = InlineVariableDeclaration(
163+
"a", type_=Int32(), value=LiteralInt32(-1)
164+
)
165+
start = LiteralInt32(1)
166+
end = LiteralInt32(10)
167+
step = LiteralInt32(1)
168+
body = Block()
169+
body.append(LiteralInt32(2))
170+
for_expr = AsyncForRangeLoopExpr(
171+
variable=decl_a, start=start, end=end, step=step, body=body
172+
)
173+
174+
assert str(for_expr)
175+
assert for_expr.get_struct()
176+
assert for_expr.get_struct(simplified=True)
177+
visualize(for_expr.get_struct())
178+
179+
180+
def test_async_for_range_loop_stmt() -> None:
181+
"""Test `Async For Range Loop` statement."""
182+
decl_a = InlineVariableDeclaration(
183+
"a", type_=Int32(), value=LiteralInt32(-1)
184+
)
185+
start = LiteralInt32(1)
186+
end = LiteralInt32(10)
187+
step = LiteralInt32(1)
188+
body = Block()
189+
body.append(LiteralInt32(2))
190+
for_stmt = AsyncForRangeLoopStmt(
191+
variable=decl_a, start=start, end=end, step=step, body=body
192+
)
193+
194+
assert str(for_stmt)
195+
assert for_stmt.get_struct()
196+
assert for_stmt.get_struct(simplified=True)
197+
visualize(for_stmt.get_struct())
198+
199+
158200
def test_while_expr() -> None:
159201
"""Test `WhileExpr` class."""
160202
# Define a condition: x < 5

tests/tools/transpilers/test_python.py

+23
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,29 @@ def test_transpiler_for_range_loop_expr() -> None:
434434
)
435435

436436

437+
def test_transpiler_async_for_range_loop_expr() -> None:
438+
"""Test `Async For Range Loop` expression`."""
439+
decl_a = astx.InlineVariableDeclaration(
440+
"a", type_=astx.Int32(), value=astx.LiteralInt32(-1)
441+
)
442+
start = astx.LiteralInt32(0)
443+
end = astx.LiteralInt32(10)
444+
step = astx.LiteralInt32(1)
445+
body = astx.Block()
446+
body.append(astx.LiteralInt32(2))
447+
448+
for_expr = astx.AsyncForRangeLoopExpr(
449+
variable=decl_a, start=start, end=end, step=step, body=body
450+
)
451+
452+
generated_code = translate(for_expr)
453+
expected_code = "result = [2 async for a in range(0, 10, 1)]"
454+
455+
assert generated_code == expected_code, (
456+
f"Expected '{expected_code}', but got '{generated_code}'"
457+
)
458+
459+
437460
def test_transpiler_binary_op() -> None:
438461
"""Test astx.BinaryOp for addition operation."""
439462
# Create a BinaryOp node for the expression "x + y"

0 commit comments

Comments
 (0)