Skip to content

Add dict comprehension Support #217

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 177 additions & 0 deletions docs/tutorials/dict-comprehension.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion docs/tutorials/for-loop.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"id": "45319c53",
"metadata": {},
"outputs": [],
Expand Down
2 changes: 2 additions & 0 deletions src/astx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
AsyncForRangeLoopExpr,
AsyncForRangeLoopStmt,
CaseStmt,
DictComprehension,
ForCountLoopExpr,
ForCountLoopStmt,
ForRangeLoopExpr,
Expand Down Expand Up @@ -224,6 +225,7 @@ def get_version() -> str:
"DataTypeOps",
"Date",
"DateTime",
"DictComprehension",
"DictType",
"EnumDeclStmt",
"ExceptionHandlerStmt",
Expand Down
1 change: 1 addition & 0 deletions src/astx/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ class ASTKind(Enum):
WithStmtKind = -512
AsyncRangeLoopStmtKind = -513
AsyncRangeLoopExprKind = -514
DictComprehensionKind = -515

# data types
NullDTKind = -600
Expand Down
85 changes: 84 additions & 1 deletion src/astx/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

from typing import Optional, cast
from typing import Optional, Union, cast

from public import public

Expand All @@ -18,6 +18,7 @@
StatementType,
)
from astx.blocks import Block
from astx.literals import LiteralList, LiteralSet, LiteralString, LiteralTuple
from astx.tools.typing import typechecked
from astx.variables import InlineVariableDeclaration

Expand Down Expand Up @@ -687,3 +688,85 @@ def get_struct(self, simplified: bool = False) -> ReprStruct:
key = f"GOTO-STMT[{self.label.value}]"
value: DictDataTypesStruct = {}
return self._prepare_struct(key, value, simplified)


@public
@typechecked
class DictComprehension(Expr):
"""AST class for function `DictComprehension` statement."""

def __init__(
self,
key: Identifier,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

key and value probably should be Expr instead of Indentifier ... but maybe we need to discuss a bit other part of this class.

let's consider this example:

>>> ast.dump(ast.parse("{'y': x*x for x in my_list}"))
"Module(body=[Expr(value=DictComp(key=Constant(value='y'), value=BinOp(left=Name(id='x', ctx=Load()), op=Mult(), right=Name(id='x', ctx=Load())), generators=[comprehension(target=Name(id='x', ctx=Store()), iter=Name(id='my_list', ctx=Load()), ifs=[], is_async=0)]))], type_ignores=[])"

se it seems we will also need to add ifs and the comprehension class as well.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will work on it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually did considered Expr for these variables but during pre-commit checks I was failing some checks bcoz, Expr did not have value attribute. But my tests was doing fine coz I had initialized key value as identifiers in the object creation.

value: Identifier,
iterable: Identifier,
iterator: Union[
LiteralList, LiteralTuple, LiteralSet, LiteralString, Identifier
],
loc: SourceLocation = NO_SOURCE_LOCATION,
parent: Optional[ASTNodes] = None,
) -> None:
"""Initialize the Return instance."""
super().__init__(loc, parent)
self.key = key
self.value = value
self.iterable = iterable
self.iterator = iterator
self.kind = ASTKind.DictComprehensionKind

def __str__(self) -> str:
"""Return a string representation of the object."""
if isinstance(self.iterator, LiteralList):
elements = ", ".join(str(e.value) for e in self.iterator.elements)
return (
f"{{{self.key.value}: {self.value.value} for "
f"{self.iterable.value} in [{elements}]}}"
)
elif isinstance(self.iterator, LiteralTuple):
elements = ", ".join(str(e.value) for e in self.iterator.elements)
return (
f"{{{self.key.value}: {self.value.value} for "
f"{self.iterable.value} in ({elements})}}"
)
elif isinstance(self.iterator, LiteralSet):
elements = ", ".join(str(e.value) for e in self.iterator.elements)
return (
f"{{{self.key.value}: {self.value.value} for "
f"{self.iterable.value} in {{{elements}}}}}"
)
elif isinstance(self.iterator, Identifier):
return (
f"{{{self.key.value}: {self.value.value} for "
f"{self.iterable.value} in {self.iterator.value}}}"
)
else:
return (
f"{{{self.key.value}: {self.value.value} for "
f"{self.iterable.value} in {self.iterator}}}"
)

def get_struct(self, simplified: bool = False) -> ReprStruct:
"""Return the AST structure of the object."""
key = "DICT-COMPREHENSION"
value: DictDataTypesStruct = {
"key": self.key.get_struct(simplified),
"value": self.value.get_struct(simplified),
"iterable": self.iterable.get_struct(simplified),
}

if isinstance(self.iterator, (LiteralList, LiteralSet, LiteralTuple)):
elements_struct = [
element.get_struct(simplified)
for element in self.iterator.elements
]
iterator_struct = {
f"{self.iterator.__class__.__name__}": {
f"ELEMENT[{i}]": element
for i, element in enumerate(elements_struct)
}
}
value["iterator"] = cast(DictDataTypesStruct, iterator_struct)
else:
value["iterator"] = self.iterator.get_struct(simplified)

return self._prepare_struct(key, value, simplified)
10 changes: 10 additions & 0 deletions src/astx/tools/transpilers/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,3 +682,13 @@ def visit(self, node: astx.LiteralDict) -> str:
for key, value in node.elements.items()
)
return f"{{{items_code}}}"

@dispatch # type: ignore[no-redef]
def visit(self, node: astx.DictComprehension) -> str:
"""Handle DictComprehension nodes."""
key = self.visit(node.key)
value = self.visit(node.value)
iterable = self.visit(node.iterable)
iterator = self.visit(node.iterator)

return f"{{{key}: {value} for {iterable} in {iterator}}}"
19 changes: 19 additions & 0 deletions tests/test_flows.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
"""Tests for control flow statements."""

# for running tests from local astx module
# import os,sys
# # print(os.path.abspath(os.path.join(os.getcwd(), ".", "src")))
# sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), ".", "src")))

import astx
import pytest

Expand Down Expand Up @@ -323,3 +328,17 @@ def test_goto_stmt() -> None:
assert goto_stmt.get_struct()
assert goto_stmt.get_struct(simplified=True)
visualize(goto_stmt.get_struct())


def test_dict_comprehension() -> None:
"""Test `DictComprehension` class."""
dict_comp = astx.DictComprehension(
key=astx.Identifier("x"),
value=astx.Identifier("x*x"),
iterable=astx.Identifier("x"),
iterator=astx.Identifier("my_list"),
)
assert str(dict_comp) == "{x: x*x for x in my_list}"
assert dict_comp.get_struct()
assert dict_comp.get_struct(simplified=True)
visualize(dict_comp.get_struct())
42 changes: 42 additions & 0 deletions tests/tools/transpilers/test_python.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
"""Test Python Transpiler."""

# for running tests from local astx module
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please remove these comments.

btw, if you install with poetry install, you shouldn't need any of these hacks

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

okay sure I will do it. I actually did used poetry.

# import os,sys
# # print(os.path.abspath(os.path.join(os.getcwd(), ".", "src")))
# sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), ".", "src")))

import ast
import sys

Expand Down Expand Up @@ -1397,3 +1402,40 @@ def test_transpiler_literal_dict() -> None:
1: 10,
2: 20,
}, f"Expected '{expected_code}', but got '{generated_code}'"


def test_dict_comprehension() -> None:
"""Test astx.DictComprehension."""
dict_comp_1 = astx.DictComprehension(
key=astx.Identifier("x"),
value=astx.Identifier("x*x"),
iterable=astx.Identifier("x"),
iterator=astx.Identifier("my_list"),
)

generated_code_1 = transpiler.visit(dict_comp_1)
expected_code_1 = "{x: x*x for x in my_list}"

dict_comp_2 = astx.DictComprehension(
key=astx.Identifier("x"),
value=astx.Identifier("x*x"),
iterable=astx.Identifier("x"),
iterator=astx.LiteralList(
elements=[
astx.LiteralInt32(10),
astx.LiteralInt32(20),
astx.LiteralInt32(30),
]
),
)

generated_code_2 = transpiler.visit(dict_comp_2)
expected_code_2 = "{x: x*x for x in [10, 20, 30]}"

assert generated_code_1 == expected_code_1, (
f"Expected code: {expected_code_1} ;"
f" Generated code: {generated_code_1}"
)
assert generated_code_2 == expected_code_2, (
f"Expected code: {expected_code_2} ;Generated code: {generated_code_2}"
)
Loading