Skip to content

Commit aea14f3

Browse files
authored
feat(printer): Print Contextual Stmts in ToPython (#44)
1 parent 0b0904f commit aea14f3

File tree

4 files changed

+57
-15
lines changed

4 files changed

+57
-15
lines changed

include/mlc/printer/ast.h

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,29 +10,32 @@ namespace printer {
1010
using mlc::core::ObjectPath;
1111

1212
struct PrinterConfigObj : public Object {
13+
bool def_free_var = true;
1314
int32_t indent_spaces = 2;
1415
int8_t print_line_numbers = 0;
1516
int32_t num_context_lines = -1;
1617
mlc::List<ObjectPath> path_to_underline;
1718

1819
PrinterConfigObj() = default;
19-
explicit PrinterConfigObj(int32_t indent_spaces, int8_t print_line_numbers, int32_t num_context_lines,
20-
mlc::List<ObjectPath> path_to_underline)
21-
: indent_spaces(indent_spaces), print_line_numbers(print_line_numbers), num_context_lines(num_context_lines),
22-
path_to_underline(path_to_underline) {}
20+
explicit PrinterConfigObj(bool def_free_var, int32_t indent_spaces, int8_t print_line_numbers,
21+
int32_t num_context_lines, mlc::List<ObjectPath> path_to_underline)
22+
: def_free_var(def_free_var), indent_spaces(indent_spaces), print_line_numbers(print_line_numbers),
23+
num_context_lines(num_context_lines), path_to_underline(path_to_underline) {}
2324
MLC_DEF_DYN_TYPE(MLC_EXPORTS, PrinterConfigObj, Object, "mlc.printer.PrinterConfig");
2425
};
2526

2627
struct PrinterConfig : public ObjectRef {
2728
MLC_DEF_OBJ_REF(MLC_EXPORTS, PrinterConfig, PrinterConfigObj, ObjectRef)
29+
.Field("def_free_var", &PrinterConfigObj::def_free_var)
2830
.Field("indent_spaces", &PrinterConfigObj::indent_spaces)
2931
.Field("print_line_numbers", &PrinterConfigObj::print_line_numbers)
3032
.Field("num_context_lines", &PrinterConfigObj::num_context_lines)
3133
.Field("path_to_underline", &PrinterConfigObj::path_to_underline)
32-
.StaticFn("__init__", InitOf<PrinterConfigObj, int32_t, int8_t, int32_t, mlc::List<ObjectPath>>);
33-
explicit PrinterConfig(int32_t indent_spaces = 2, int8_t print_line_numbers = 0, int32_t num_context_lines = -1,
34-
mlc::List<ObjectPath> path_to_underline = {})
35-
: PrinterConfig(PrinterConfig::New(indent_spaces, print_line_numbers, num_context_lines, path_to_underline)) {}
34+
.StaticFn("__init__", InitOf<PrinterConfigObj, bool, int32_t, int8_t, int32_t, mlc::List<ObjectPath>>);
35+
explicit PrinterConfig(bool def_free_var = true, int32_t indent_spaces = 2, int8_t print_line_numbers = 0,
36+
int32_t num_context_lines = -1, mlc::List<ObjectPath> path_to_underline = {})
37+
: PrinterConfig(PrinterConfig::New(def_free_var, indent_spaces, print_line_numbers, num_context_lines,
38+
path_to_underline)) {}
3639
};
3740

3841
} // namespace printer
@@ -82,6 +85,17 @@ struct ExprObj : public ::mlc::Object {
8285
}; // struct ExprObj
8386

8487
struct Expr : public ::mlc::printer::Node {
88+
Expr Attr(mlc::Str name) const { return this->get()->Attr(name); }
89+
Expr Index(mlc::List<::mlc::printer::Expr> idx) const { return this->get()->Index(idx); }
90+
Expr Call(mlc::List<::mlc::printer::Expr> args) const { return this->get()->Call(args); }
91+
Expr CallKw(mlc::List<::mlc::printer::Expr> args, mlc::List<::mlc::Str> kwargs_keys,
92+
mlc::List<::mlc::printer::Expr> kwargs_values) const {
93+
return this->get()->CallKw(args, kwargs_keys, kwargs_values);
94+
}
95+
Expr AddPath(mlc::core::ObjectPath p) {
96+
this->get()->source_paths->push_back(p);
97+
return *this;
98+
}
8599
MLC_DEF_OBJ_REF(MLC_EXPORTS, Expr, ExprObj, ::mlc::printer::Node)
86100
.Field("source_paths", &ExprObj::source_paths)
87101
.StaticFn("__init__", ::mlc::InitOf<ExprObj, ::mlc::List<::mlc::core::ObjectPath>>)

include/mlc/printer/ir_printer.h

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,17 @@ struct IRPrinterObj : public Object {
139139
return ret;
140140
}
141141

142+
template <typename T> mlc::List<T> ApplyToList(const UList &list, const ObjectPath &p) const {
143+
// TODO: expose a Python interface
144+
int64_t n = list->size();
145+
mlc::List<T> args;
146+
args.reserve(n);
147+
for (int64_t i = 0; i < n; ++i) {
148+
args->push_back(this->operator()(list[i], p->WithListIndex(i)));
149+
}
150+
return args;
151+
}
152+
142153
void FramePush(const ObjectRef &frame) {
143154
frames.push_back(frame);
144155
frame_vars[frame] = mlc::UList();
@@ -178,14 +189,28 @@ struct IRPrinter : public ObjectRef {
178189
explicit IRPrinter(PrinterConfig cfg, mlc::Dict<Any, VarInfo> obj2info, mlc::Dict<Str, int64_t> defined_names,
179190
mlc::UList frames, mlc::UDict frame_vars)
180191
: IRPrinter(IRPrinter::New(cfg, obj2info, defined_names, frames, frame_vars)) {}
181-
182192
}; // struct IRPrinter
183193

184194
inline Str ToPython(const ObjectRef &obj, const PrinterConfig &cfg) {
185195
IRPrinter printer(cfg);
186-
printer->FramePush(DefaultFrame());
196+
DefaultFrame frame;
197+
printer->FramePush(frame);
187198
Node ret = ::mlc::Lib::IRPrint(obj, printer, ObjectPath::Root());
188199
printer->FramePop();
200+
if (frame->stmts->empty()) {
201+
return ret->ToPython(cfg);
202+
}
203+
if (const auto *block = ret.as<StmtBlockObj>()) {
204+
// TODO: support List::insert by iterator
205+
frame->stmts->insert(frame->stmts.size(), block->stmts->begin(), block->stmts->end());
206+
} else if (const auto *expr = ret.as<ExprObj>()) {
207+
frame->stmts->push_back(ExprStmt(mlc::List<ObjectPath>{}, Optional<Str>{}, Expr(expr)));
208+
} else if (const auto *stmt = ret.as<StmtObj>()) {
209+
frame->stmts->push_back(Stmt(stmt));
210+
} else {
211+
MLC_THROW(ValueError) << "Unsupported type: " << ret;
212+
}
213+
ret = StmtBlock(mlc::List<ObjectPath>{}, Optional<Str>{}, frame->stmts);
189214
return ret->ToPython(cfg);
190215
}
191216

python/mlc/printer/ast.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
@mlcd.c_class("mlc.printer.PrinterConfig")
1111
class PrinterConfig(Object):
12+
def_free_var: bool = True
1213
indent_spaces: int = 2
1314
print_line_numbers: int = 0
1415
num_context_lines: int = -1

tests/python/test_printer_ast.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
1+
from __future__ import annotations
2+
13
import itertools
2-
from typing import TYPE_CHECKING, Optional
4+
from typing import TYPE_CHECKING
35

46
import mlc.printer as mlcp
57
import pytest
68

79
if TYPE_CHECKING:
8-
import _pytest
10+
from _pytest.mark import ParameterSet
911

1012

1113
@pytest.mark.parametrize(
@@ -645,7 +647,7 @@ def test_print_expr_stmt_doc() -> None:
645647
],
646648
ids=itertools.count(),
647649
)
648-
def test_print_assert_doc(msg: Optional[mlcp.ast.Expr], expected: str) -> None:
650+
def test_print_assert_doc(msg: mlcp.ast.Expr | None, expected: str) -> None:
649651
test = mlcp.ast.Literal(True)
650652
doc = mlcp.ast.Assert(test, msg)
651653
assert doc.to_python().strip() == expected.strip()
@@ -774,7 +776,7 @@ def test_print_function_doc(
774776
args: list[mlcp.ast.Assign],
775777
decorators: list[mlcp.ast.Id],
776778
body: list[mlcp.ast.Stmt],
777-
return_type: Optional[mlcp.ast.Expr],
779+
return_type: mlcp.ast.Expr | None,
778780
expected: str,
779781
) -> None:
780782
doc = mlcp.ast.Function(mlcp.ast.Id("func"), args, decorators, return_type, body)
@@ -1103,7 +1105,7 @@ def test_print_invalid_multiline_doc_comment(doc: mlcp.ast.Stmt) -> None:
11031105
assert "cannot have newline" in str(e.value)
11041106

11051107

1106-
def generate_expr_precedence_test_cases() -> list["_pytest.mark.ParameterSet"]:
1108+
def generate_expr_precedence_test_cases() -> list[ParameterSet]:
11071109
x = mlcp.ast.Id("x")
11081110
y = mlcp.ast.Id("y")
11091111
z = mlcp.ast.Id("z")

0 commit comments

Comments
 (0)