diff --git a/src/astx/__init__.py b/src/astx/__init__.py index 0ffa1e51..d3c80820 100644 --- a/src/astx/__init__.py +++ b/src/astx/__init__.py @@ -58,6 +58,8 @@ ThrowStmt, ) from astx.flows import ( + AsyncForRangeLoopExpr, + AsyncForRangeLoopStmt, CaseStmt, ForCountLoopExpr, ForCountLoopStmt, @@ -199,6 +201,8 @@ def get_version() -> str: "Argument", "Arguments", "AssignmentExpr", + "AsyncForRangeLoopExpr", + "AsyncForRangeLoopStmt", "AwaitExpr", "BinaryOp", "Block", diff --git a/src/astx/base.py b/src/astx/base.py index 84d4b5b9..416a3d31 100644 --- a/src/astx/base.py +++ b/src/astx/base.py @@ -127,6 +127,8 @@ class ASTKind(Enum): SwitchStmtKind = -509 GotoStmtKind = -511 WithStmtKind = -512 + AsyncRangeLoopStmtKind = -513 + AsyncRangeLoopExprKind = -514 # data types NullDTKind = -600 diff --git a/src/astx/flows.py b/src/astx/flows.py index 79f3196c..d9319fce 100644 --- a/src/astx/flows.py +++ b/src/astx/flows.py @@ -343,6 +343,131 @@ def get_struct(self, simplified: bool = False) -> ReprStruct: return self._prepare_struct(key, value, simplified) +@public +@typechecked +class AsyncForRangeLoopStmt(StatementType): + """AST class for asynchronous `For` Range Statement.""" + + variable: InlineVariableDeclaration + start: Optional[Expr] + end: Expr + step: Optional[Expr] + body: Block + + def __init__( + self, + variable: InlineVariableDeclaration, + start: Optional[Expr], + end: Expr, + step: Optional[Expr], + body: Block, + loc: SourceLocation = NO_SOURCE_LOCATION, + parent: Optional[ASTNodes] = None, + ) -> None: + """Initialize the AsyncForRangeLoopStmt instance.""" + super().__init__(loc=loc, parent=parent) + self.variable = variable + self.start = start + self.end = end + self.step = step + self.body = body + self.kind = ASTKind.AsyncRangeLoopStmtKind + + def __str__(self) -> str: + """Return a string that represents the object.""" + start = self.start + end = self.end + step = self.step + var_name = self.variable.name + return f"AsyncForRangeLoopStmt({var_name}=[{start}:{end}:{step}])" + + def get_struct(self, simplified: bool = False) -> ReprStruct: + """Return the AST structure of the object.""" + for_start = { + "start": {} + if self.start is None + else self.start.get_struct(simplified) + } + for_end = {"end": self.end.get_struct(simplified)} + for_step = { + "step": {} + if self.step is None + else self.step.get_struct(simplified) + } + for_body = self.body.get_struct(simplified) + + key = "ASYNC-FOR-RANGE-LOOP-STMT" + value: ReprStruct = { + **cast(DictDataTypesStruct, for_start), + **cast(DictDataTypesStruct, for_end), + **cast(DictDataTypesStruct, for_step), + **cast(DictDataTypesStruct, for_body), + } + return self._prepare_struct(key, value, simplified) + + +@public +@typechecked +class AsyncForRangeLoopExpr(Expr): + """AST class for asynchronous `For` Range Expression.""" + + variable: InlineVariableDeclaration + start: Optional[Expr] + end: Expr + step: Optional[Expr] + body: Block + + def __init__( + self, + variable: InlineVariableDeclaration, + start: Optional[Expr], + end: Expr, + step: Optional[Expr], + body: Block, + loc: SourceLocation = NO_SOURCE_LOCATION, + parent: Optional[ASTNodes] = None, + ) -> None: + """Initialize the AsyncForRangeLoopExpr instance.""" + super().__init__(loc=loc, parent=parent) + self.variable = variable + self.start = start + self.end = end + self.step = step + self.body = body + self.kind = ASTKind.AsyncRangeLoopExprKind + + def __str__(self) -> str: + """Return a string that represents the object.""" + var_name = self.variable.name + return f"AsyncForRangeLoopExpr[{var_name}]" + + def get_struct(self, simplified: bool = False) -> ReprStruct: + """Return the AST structure of the object.""" + for_var = {"var": self.variable.get_struct(simplified)} + for_start = { + "start": {} + if self.start is None + else self.start.get_struct(simplified) + } + for_end = {"end": self.end.get_struct(simplified)} + for_step = { + "step": {} + if self.step is None + else self.step.get_struct(simplified) + } + for_body = self.body.get_struct(simplified) + + key = "ASYNC-FOR-RANGE-LOOP-EXPR" + value: ReprStruct = { + **cast(DictDataTypesStruct, for_var), + **cast(DictDataTypesStruct, for_start), + **cast(DictDataTypesStruct, for_end), + **cast(DictDataTypesStruct, for_step), + **cast(DictDataTypesStruct, for_body), + } + return self._prepare_struct(key, value, simplified) + + @public @typechecked class WhileStmt(StatementType): diff --git a/src/astx/tools/transpilers/python.py b/src/astx/tools/transpilers/python.py index 6383bff1..b2d4a729 100644 --- a/src/astx/tools/transpilers/python.py +++ b/src/astx/tools/transpilers/python.py @@ -66,6 +66,31 @@ def visit(self, node: astx.AssignmentExpr) -> str: target_str = " = ".join(self.visit(target) for target in node.targets) return f"{target_str} = {self.visit(node.value)}" + @dispatch # type: ignore[no-redef] + def visit(self, node: astx.AsyncForRangeLoopExpr) -> str: + """Handle AsyncForRangeLoopExpr nodes.""" + if len(node.body) > 1: + raise ValueError( + "AsyncForRangeLoopExpr in Python just accept 1 node in the " + "body attribute." + ) + start = ( + self.visit(node.start) + if getattr(node, "start", None) is not None + else "0" + ) + end = self.visit(node.end) + step = ( + self.visit(node.step) + if getattr(node, "step", None) is not None + else "1" + ) + + return ( + f"result = [{self.visit(node.body).strip()} async for " + f"{node.variable.name} in range({start}, {end}, {step})]" + ) + @dispatch # type: ignore[no-redef] def visit(self, node: astx.AwaitExpr) -> str: """Handle AwaitExpr nodes.""" diff --git a/tests/test_flows.py b/tests/test_flows.py index 69068d57..83837cd6 100644 --- a/tests/test_flows.py +++ b/tests/test_flows.py @@ -6,6 +6,8 @@ from astx.base import SourceLocation from astx.blocks import Block from astx.flows import ( + AsyncForRangeLoopExpr, + AsyncForRangeLoopStmt, CaseStmt, ForCountLoopExpr, ForCountLoopStmt, @@ -155,6 +157,46 @@ def test_for_count_loop_expr() -> None: visualize(for_expr.get_struct()) +def test_async_for_range_loop_expr() -> None: + """Test `Async For Range Loop` expression`.""" + decl_a = InlineVariableDeclaration( + "a", type_=Int32(), value=LiteralInt32(-1) + ) + start = LiteralInt32(1) + end = LiteralInt32(10) + step = LiteralInt32(1) + body = Block() + body.append(LiteralInt32(2)) + for_expr = AsyncForRangeLoopExpr( + variable=decl_a, start=start, end=end, step=step, body=body + ) + + assert str(for_expr) + assert for_expr.get_struct() + assert for_expr.get_struct(simplified=True) + visualize(for_expr.get_struct()) + + +def test_async_for_range_loop_stmt() -> None: + """Test `Async For Range Loop` statement.""" + decl_a = InlineVariableDeclaration( + "a", type_=Int32(), value=LiteralInt32(-1) + ) + start = LiteralInt32(1) + end = LiteralInt32(10) + step = LiteralInt32(1) + body = Block() + body.append(LiteralInt32(2)) + for_stmt = AsyncForRangeLoopStmt( + variable=decl_a, start=start, end=end, step=step, body=body + ) + + assert str(for_stmt) + assert for_stmt.get_struct() + assert for_stmt.get_struct(simplified=True) + visualize(for_stmt.get_struct()) + + def test_while_expr() -> None: """Test `WhileExpr` class.""" # Define a condition: x < 5 diff --git a/tests/tools/transpilers/test_python.py b/tests/tools/transpilers/test_python.py index 4ce136d6..faa18bf6 100644 --- a/tests/tools/transpilers/test_python.py +++ b/tests/tools/transpilers/test_python.py @@ -434,6 +434,29 @@ def test_transpiler_for_range_loop_expr() -> None: ) +def test_transpiler_async_for_range_loop_expr() -> None: + """Test `Async For Range Loop` expression`.""" + decl_a = astx.InlineVariableDeclaration( + "a", type_=astx.Int32(), value=astx.LiteralInt32(-1) + ) + start = astx.LiteralInt32(0) + end = astx.LiteralInt32(10) + step = astx.LiteralInt32(1) + body = astx.Block() + body.append(astx.LiteralInt32(2)) + + for_expr = astx.AsyncForRangeLoopExpr( + variable=decl_a, start=start, end=end, step=step, body=body + ) + + generated_code = translate(for_expr) + expected_code = "result = [2 async for a in range(0, 10, 1)]" + + assert generated_code == expected_code, ( + f"Expected '{expected_code}', but got '{generated_code}'" + ) + + def test_transpiler_binary_op() -> None: """Test astx.BinaryOp for addition operation.""" # Create a BinaryOp node for the expression "x + y"