Skip to content

feat: add f-string literal #270

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions libs/astx-transpilers/src/astx_transpilers/python_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,51 @@ def visit(self, node: astx.FinallyHandlerStmt) -> str:
body_str = self._generate_block(node.body)
return f"finally:\n{body_str}"

@dispatch # type: ignore[no-redef]
def visit(self, node: astx.FormattedValue) -> str:
"""Handle FormattedValue nodes (f-string parts like {x!r:.2f})."""
value_str = self.visit(node.value)
if not isinstance(
node.value, (astx.Variable, astx.Identifier, astx.Literal)
):
if not (value_str.startswith("(") and value_str.endswith(")")):
value_str = f"({value_str})"

conv_char = f"!{chr(node.conversion)}" if node.conversion else ""
format_spec_str = ""
if node.format_spec:
if isinstance(node.format_spec, astx.LiteralString):
format_spec_inner = node.format_spec.value
format_spec_str = f":{format_spec_inner}"
else:
format_spec_inner_expr = self.visit(node.format_spec)
format_spec_str = f":{{{format_spec_inner_expr}}}"

return f"{{{value_str}{conv_char}{format_spec_str}}}"

@dispatch # type: ignore[no-redef]
def visit(self, node: astx.LiteralString) -> str:
"""Handle LiteralString nodes. Escapes braces for f-strings."""
escaped_value = node.value.replace("{", "{{").replace("}", "}}")
return repr(escaped_value)

@dispatch # type: ignore[no-redef]
def visit(self, node: astx.JoinedStr) -> str:
"""Handle JoinedStr nodes (f-string)."""
parts = []
for value_node in node.values:
if isinstance(value_node, astx.LiteralString):
parts.append(
value_node.value.replace("{", "{{").replace("}", "}}")
)
elif isinstance(value_node, astx.FormattedValue):
parts.append(self.visit(value_node))
else:
raise TypeError(
f"Unexpected node type in JoinedStr: {type(value_node)}"
)
return f"f'{''.join(parts)}'"

@dispatch # type: ignore[no-redef]
def visit(self, node: astx.ForRangeLoopExpr) -> str:
"""Handle ForRangeLoopExpr nodes."""
Expand Down
110 changes: 110 additions & 0 deletions libs/astx-transpilers/tests/test_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,116 @@ def test_transpiler_literal_utf8_string() -> None:
)


def test_transpiler_formatted_value_simple() -> None:
"""Test transpiling astx.FormattedValue simple case."""
var_x = astx.Variable("x")
fmt_val = astx.FormattedValue(value=var_x)
generated_code = transpiler.visit(fmt_val)
expected_code = "{x}"
assert generated_code == expected_code
check_transpilation(f"f'{generated_code}'")


def test_transpiler_formatted_value_full() -> None:
"""Test transpiling astx.FormattedValue with conversion and format spec."""
var_y = astx.Variable("y")
fmt_spec = astx.LiteralString(".2f")
fmt_val = astx.FormattedValue(
value=var_y,
conversion=ord("r"),
format_spec=fmt_spec,
)
generated_code = transpiler.visit(fmt_val)
expected_code = "{y!r:.2f}"
assert generated_code == expected_code
check_transpilation(f"f'{generated_code}'")


def test_transpiler_formatted_value_complex_expr() -> None:
"""Test transpiling astx.FormattedValue with a complex expression."""
expr = astx.BinaryOp(
op_code="+", lhs=astx.Variable("a"), rhs=astx.LiteralInt32(1)
)
fmt_val = astx.FormattedValue(value=expr)
generated_code = transpiler.visit(fmt_val)
expected_code = "{(a + 1)}"
assert generated_code == expected_code
check_transpilation(f"f'{generated_code}'")


def test_transpiler_formatted_value_complex_format_spec() -> None:
"""Test FormattedValue with expression as format spec."""
var_val = astx.Variable("value")
var_width = astx.Variable("width")
fmt_val = astx.FormattedValue(value=var_val, format_spec=var_width)
generated_code = transpiler.visit(fmt_val)
expected_code = "{value:{width}}"
assert generated_code == expected_code
check_transpilation(f"f'{generated_code}'")


def test_transpiler_joined_str_simple() -> None:
"""Test transpiling a simple astx.JoinedStr."""
lit1 = astx.LiteralString("Hello ")
var_name = astx.Variable("name")
fmt_val = astx.FormattedValue(value=var_name)
lit2 = astx.LiteralString("!")
joined = astx.JoinedStr(values=[lit1, fmt_val, lit2])
generated_code = translate(joined)
expected_code = "f'Hello {name}!'"
assert generated_code == expected_code


def test_transpiler_joined_str_complex() -> None:
"""Test transpiling a more complex astx.JoinedStr."""
lit1 = astx.LiteralString("Value: ")
var_y = astx.Variable("y")
fmt_spec = astx.LiteralString(".2f")
fmt_val = astx.FormattedValue(
value=var_y,
conversion=ord("r"),
format_spec=fmt_spec,
)
lit2 = astx.LiteralString(" (result)")
joined = astx.JoinedStr(values=[lit1, fmt_val, lit2])
generated_code = translate(joined)
expected_code = "f'Value: {y!r:.2f} (result)'"
assert generated_code == expected_code


def test_transpiler_joined_str_with_escapes() -> None:
"""Test transpiling JoinedStr with literal braces."""
lit1 = astx.LiteralString("Show {curly} braces and ")
var_x = astx.Variable("x")
fmt_val = astx.FormattedValue(value=var_x)
joined = astx.JoinedStr(values=[lit1, fmt_val])
generated_code = translate(joined)
expected_code = "f'Show {{curly}} braces and {x}'"
assert generated_code == expected_code


def test_transpiler_joined_str_nested_format_spec() -> None:
"""Test transpiling JoinedStr with nested format specifier expression."""
lit1 = astx.LiteralString("Result: ")
var_val = astx.Variable("value")
var_width = astx.Variable("width")
fmt_val = astx.FormattedValue(value=var_val, format_spec=var_width)
joined = astx.JoinedStr(values=[lit1, fmt_val])
generated_code = translate(joined)
expected_code = "f'Result: {value:{width}}'"
assert generated_code == expected_code


def test_transpiler_joined_str_only_literals() -> None:
"""Test transpiling JoinedStr with only literal parts."""
lit1 = astx.LiteralString("Just a literal ")
lit2 = astx.LiteralString("string with {escaped} braces.")
joined = astx.JoinedStr(values=[lit1, lit2])
generated_code = translate(joined)
expected_code = "f'Just a literal string with {{escaped}} braces.'"
assert generated_code == expected_code


def test_transpiler_for_range_loop_expr() -> None:
"""Test `For Range Loop` expression`."""
decl_a = astx.InlineVariableDeclaration(
Expand Down
4 changes: 4 additions & 0 deletions libs/astx/src/astx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@
WhileStmt,
)
from astx.literals import (
FormattedValue,
JoinedStr,
Literal,
LiteralBoolean,
LiteralComplex,
Expand Down Expand Up @@ -265,6 +267,7 @@ def get_version() -> str:
"ForCountLoopStmt",
"ForRangeLoopExpr",
"ForRangeLoopStmt",
"FormattedValue",
"FunctionAsyncDef",
"FunctionCall",
"FunctionDef",
Expand All @@ -285,6 +288,7 @@ def get_version() -> str:
"Int32",
"Int64",
"Integer",
"JoinedStr",
"LambdaExpr",
"ListComprehension",
"ListType",
Expand Down
2 changes: 2 additions & 0 deletions libs/astx/src/astx/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,8 @@ class ASTKind(Enum):
TimeDTKind = -624
DateDTKind = -625
DateTimeDTKind = -626
FormattedValueKind = -650
JoinedStrKind = -651

# imports(packages)
ImportStmtKind = -700
Expand Down
5 changes: 5 additions & 0 deletions libs/astx/src/astx/literals/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
LiteralUInt128,
)
from astx.literals.string import (
FormattedValue,
JoinedStr,
LiteralString,
LiteralUTF8Char,
LiteralUTF8String,
Expand All @@ -44,6 +46,8 @@
)

__all__ = [
"FormattedValue",
"JoinedStr",
"Literal",
"LiteralBoolean",
"LiteralComplex",
Expand All @@ -55,6 +59,7 @@
"LiteralFloat16",
"LiteralFloat32",
"LiteralFloat64",
"LiteralFormattedString",
"LiteralInt8",
"LiteralInt16",
"LiteralInt32",
Expand Down
107 changes: 106 additions & 1 deletion libs/astx/src/astx/literals/string.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,24 @@
"""ASTx Data Types module."""
"""ASTx Data Types module for strings."""

from __future__ import annotations

from typing import Any, Dict, List, Optional

from public import public

from astx.base import (
NO_SOURCE_LOCATION,
ASTKind,
ASTNodes,
Expr,
Identifier,
ReprStruct,
SourceLocation,
)
from astx.literals.base import Literal
from astx.tools.typing import typechecked
from astx.types.string import String, UTF8Char, UTF8String
from astx.variables import Variable


@public
Expand Down Expand Up @@ -87,3 +94,101 @@ def get_struct(self, simplified: bool = False) -> ReprStruct:
key = f"LiteralUTF8Char: {self.value}"
value = self.value
return self._prepare_struct(key, value, simplified)


@public
@typechecked
class FormattedValue(Expr):
"""Represents formatted value parts within a JoinedStr (e.g., {x:.2f})."""

value: Expr
conversion: Optional[int]
format_spec: Optional[Expr]

kind: ASTKind = ASTKind.FormattedValueKind

def __init__(
self,
value: Expr,
conversion: Optional[int] = None,
format_spec: Optional[Expr] = None,
loc: SourceLocation = NO_SOURCE_LOCATION,
parent: Optional[ASTNodes] = None,
) -> None:
super().__init__(loc=loc, parent=parent)
self.value = value
self.conversion = conversion
self.format_spec = format_spec

def __str__(self) -> str:
"""Return a string representation of the formatted value part."""
if isinstance(self.value, Variable):
value_str = self.value.name
elif isinstance(self.value, Identifier):
value_str = self.value.value
elif isinstance(self.value, LiteralString):
value_str = repr(self.value.value)
elif isinstance(self.value, Literal):
value_str = str(self.value.value)
else:
value_str = str(self.value)

conv_char = f"!{chr(self.conversion)}" if self.conversion else ""

fmt_spec_inner_str = ""
if isinstance(self.format_spec, LiteralString):
fmt_spec_inner_str = self.format_spec.value
elif self.format_spec is not None:
fmt_spec_inner_str = str(self.format_spec)

fmt_spec_str = f":{fmt_spec_inner_str}" if self.format_spec else ""

return f"FormattedValue({value_str}{conv_char}{fmt_spec_str})"

def get_struct(self, simplified: bool = False) -> ReprStruct:
"""Return the AST structure of the object."""
content: Dict[str, Any] = {"value": self.value.get_struct(simplified)}
if self.conversion is not None:
content["conversion"] = chr(self.conversion)
if self.format_spec is not None:
content["format_spec"] = self.format_spec.get_struct(simplified)

key = "FormattedValue"
return self._prepare_struct(key, content, simplified)


@public
@typechecked
class JoinedStr(Expr):
"""Represents an f-string literal (e.g., f'hello {name}')."""

values: List[Expr]

kind: ASTKind = ASTKind.JoinedStrKind

def __init__(
self,
values: List[Expr],
loc: SourceLocation = NO_SOURCE_LOCATION,
parent: Optional[ASTNodes] = None,
) -> None:
super().__init__(loc=loc, parent=parent)
for val in values:
if not isinstance(val, (LiteralString, FormattedValue)):
raise TypeError(
"JoinedStr values must be LiteralString or FormattedValue"
)
self.values = values

def __str__(self) -> str:
"""Return a string representation of the joined string structure."""
value_strs = [str(v) for v in self.values]
return f"JoinedStr([{', '.join(value_strs)}])"

def get_struct(self, simplified: bool = False) -> ReprStruct:
"""Return the AST structure of the object."""
content: Dict[str, Any] = {
"values": [v.get_struct(simplified) for v in self.values]
}
key = "JoinedStr"
return self._prepare_struct(key, content, simplified)
Loading
Loading