Skip to content

Commit 4d75039

Browse files
authored
[eudsl-tblgen] fix triton attrs (#84)
1 parent ddfd299 commit 4d75039

File tree

4 files changed

+133
-47
lines changed

4 files changed

+133
-47
lines changed

projects/eudsl-tblgen/src/eudsl_tblgen/cmake/eudsl_tblgen-config.cmake

+3-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
# copy-pasta from AddMLIR.cmake/AddLLVM.cmake/TableGen.cmake
77

88
function(eudsl_tblgen target)
9-
cmake_parse_arguments(ARG "" "TD_FILE;OUTPUT_DIRECTORY;KIND" "INCLUDES;DEPENDS" ${ARGN})
9+
cmake_parse_arguments(ARG "" "TD_FILE;OUTPUT_DIRECTORY;KIND" "INCLUDES;DEPENDS;INCLUDE;EXCLUDE" ${ARGN})
1010
if (IS_ABSOLUTE ${ARG_TD_FILE})
1111
set(EUDSL_TBLGEN_TD_FILE_INPUT_ABSOLUTE ${ARG_TD_FILE})
1212
else()
@@ -27,6 +27,8 @@ function(eudsl_tblgen target)
2727
${Python_EXECUTABLE} -Wignore -m eudsl_tblgen.mlir ${EUDSL_TBLGEN_TD_FILE_INPUT_ABSOLUTE}
2828
-k ${ARG_KIND} -I ${eudsl_tblgen_includes}
2929
-o "${ARG_OUTPUT_DIRECTORY}"
30+
--include ${ARG_INCLUDE}
31+
--exclude ${ARG_EXCLUDE}
3032
)
3133

3234
get_filename_component(_prefix ${EUDSL_TBLGEN_TD_FILE_INPUT_ABSOLUTE} NAME_WE)

projects/eudsl-tblgen/src/eudsl_tblgen/mlir/__init__.py

+66-22
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import warnings
66
from dataclasses import dataclass
77
import re
8+
from functools import lru_cache
89
from textwrap import dedent
910

1011
from .. import AttrOrTypeParameter
@@ -32,30 +33,44 @@ def map_cpp_to_c_type(t):
3233
element_ty_reg = re.compile(r"ArrayRef<(\w+)>")
3334

3435

35-
@dataclass
36+
@dataclass(frozen=True)
3637
class Param:
3738
class_name: str
3839
param_name: str
3940
c_type: str
4041
cpp_type: str
4142
param_def: AttrOrTypeParameter
4243

44+
@lru_cache(maxsize=1)
4345
def c_param_str(self):
4446
return f"{self.c_type} {self.param_name}"
4547

4648
@property
49+
@lru_cache(maxsize=1)
4750
def getter_name(self):
4851
return f"mlir{self.class_name}Get{self.param_name}"
4952

5053
# TODO(max): bad heuristic - should look inside param_def
54+
@lru_cache(maxsize=1)
5155
def needs_wrap_unwrap(self):
5256
return self.cpp_type != self.c_type
5357

58+
@property
59+
@lru_cache(maxsize=1)
60+
def is_optional(self):
61+
return self.param_def.is_optional()
62+
63+
@property
64+
@lru_cache(maxsize=1)
65+
def default_value(self):
66+
return self.param_def.get_default_value()
5467

55-
@dataclass
68+
69+
@dataclass(frozen=True)
5670
class ArrayRefParam(Param):
5771
c_element_type: str
5872

73+
@lru_cache(maxsize=1)
5974
def c_param_str(self):
6075
return f"{self.c_element_type} *{self.param_name}, unsigned n{self.param_name}s"
6176

@@ -100,7 +115,7 @@ def emit_c_attr_or_type_builder(
100115
cclass_kind: CClassKind, class_name, params: list[AttrOrTypeParameter]
101116
):
102117
mapped_params = map_params(class_name, params)
103-
sig = f"""{cclass_kind} mlir{class_name}{'Attr' if cclass_kind == CClassKind.ATTRIBUTE else 'Type'}Get({', '.join([p.c_param_str() for p in mapped_params])}, MlirContext mlirContext)"""
118+
sig = f"""{cclass_kind} mlir{class_name}{cclass_kind.replace('Mlir', '')}Get({', '.join([p.c_param_str() for p in mapped_params])}, MlirContext mlirContext)"""
104119
decl = f"""MLIR_CAPI_EXPORTED {sig};"""
105120
defn = dedent(
106121
f"""
@@ -162,10 +177,6 @@ def emit_attr_or_type_nanobind_class(
162177
):
163178
mapped_params = map_params(class_name, params)
164179

165-
mlir_attr_or_mlir_type = (
166-
"MlirAttribute" if cclass_kind == CClassKind.ATTRIBUTE else "MlirType"
167-
)
168-
169180
helper_decls = []
170181
helper_defns = []
171182
helper_decls.append(
@@ -181,12 +192,12 @@ def emit_attr_or_type_nanobind_class(
181192
)
182193
)
183194
helper_decls.append(
184-
f"MLIR_CAPI_EXPORTED bool isaMlir{class_name}({mlir_attr_or_mlir_type} thing);"
195+
f"MLIR_CAPI_EXPORTED bool isaMlir{class_name}({cclass_kind} thing);"
185196
)
186197
helper_defns.append(
187198
dedent(
188199
f"""\
189-
bool isaMlir{class_name}({mlir_attr_or_mlir_type} thing) {{
200+
bool isaMlir{class_name}({cclass_kind} thing) {{
190201
return isa<{class_name}>(unwrap(thing));
191202
}}
192203
"""
@@ -196,13 +207,16 @@ def emit_attr_or_type_nanobind_class(
196207
params_str = []
197208
for mp in mapped_params:
198209
if isinstance(mp, ArrayRefParam):
199-
params_str.append(f"std::vector<{mp.c_element_type}> &{mp.param_name}")
210+
typ = f"std::vector<{mp.c_element_type}>&"
200211
else:
201-
params_str.append(f"{mp.c_type} {mp.param_name}")
212+
typ = f"{mp.c_type}"
213+
if mp.is_optional:
214+
typ = f"std::optional<{typ}>"
215+
params_str.append(f"{typ} {mp.param_name}")
202216
params_str = ", ".join(params_str)
203217
s = dedent(
204218
f"""
205-
auto nb{class_name} = {'mlir_attribute_subclass' if cclass_kind == CClassKind.ATTRIBUTE else 'mlir_type_subclass'}(m, "{class_name}", isaMlir{class_name}, mlir{class_name}GetTypeID);
219+
auto nb{class_name} = {underscore(cclass_kind)}_subclass(m, "{class_name}", isaMlir{class_name}, mlir{class_name}GetTypeID);
206220
nb{class_name}.def_staticmethod("get", []({params_str}, MlirContext context) {{
207221
"""
208222
)
@@ -211,10 +225,29 @@ def emit_attr_or_type_nanobind_class(
211225
help_str = []
212226
for mp in mapped_params:
213227
if isinstance(mp, ArrayRefParam):
214-
arg_str.append(f"{mp.param_name}.data(), {mp.param_name}.size()")
228+
if mp.is_optional:
229+
arg_str.append(
230+
f"{mp.param_name}.has_value() ? {mp.param_name}->data() : nullptr, {mp.param_name}.has_value() ? {mp.param_name}->size() : 0"
231+
)
232+
else:
233+
arg_str.append(f"{mp.param_name}.data(), {mp.param_name}.size()")
215234
else:
216-
arg_str.append(f"{mp.param_name}")
217-
help_str.append(f'"{underscore(mp.param_name)}"_a')
235+
if (default_val := mp.default_value) and mp.needs_wrap_unwrap():
236+
default_val = f"wrap({default_val})"
237+
arg_str.append(
238+
f"{mp.param_name}.has_value() ? *{mp.param_name} : {default_val}"
239+
)
240+
elif mp.default_value and not mp.needs_wrap_unwrap():
241+
arg_str.append(f"*{mp.param_name}")
242+
else:
243+
arg_str.append(f"{mp.param_name}")
244+
245+
if (default_val := mp.default_value) and not mp.needs_wrap_unwrap():
246+
help_str.append(f'"{underscore(mp.param_name)}"_a = {default_val}')
247+
elif mp.is_optional:
248+
help_str.append(f'"{underscore(mp.param_name)}"_a = nb::none()')
249+
else:
250+
help_str.append(f'"{underscore(mp.param_name)}"_a')
218251
arg_str.append("context")
219252
arg_str = ", ".join(arg_str)
220253

@@ -223,7 +256,7 @@ def emit_attr_or_type_nanobind_class(
223256

224257
s += dedent(
225258
f"""\
226-
return mlir{class_name}{'Attr' if cclass_kind == CClassKind.ATTRIBUTE else 'Type'}Get({arg_str});
259+
return mlir{class_name}{cclass_kind.replace('Mlir', '')}Get({arg_str});
227260
}}, {help_str});
228261
"""
229262
)
@@ -232,7 +265,7 @@ def emit_attr_or_type_nanobind_class(
232265
if isinstance(mp, ArrayRefParam):
233266
s += dedent(
234267
f"""
235-
nb{class_name}.def_property_readonly("{underscore(mp.param_name)}", []({mlir_attr_or_mlir_type} self) {{
268+
nb{class_name}.def_property_readonly("{underscore(mp.param_name)}", []({cclass_kind} self) {{
236269
unsigned n{mp.param_name}s;
237270
{mp.c_element_type}* {mp.param_name}Ptr;
238271
{mp.getter_name}(self, &{mp.param_name}Ptr, &n{mp.param_name}s);
@@ -243,7 +276,7 @@ def emit_attr_or_type_nanobind_class(
243276
else:
244277
s += dedent(
245278
f"""
246-
nb{class_name}.def_property_readonly("{underscore(mp.param_name)}", []({'MlirAttribute' if cclass_kind == CClassKind.ATTRIBUTE else 'MlirType'} self) {{
279+
nb{class_name}.def_property_readonly("{underscore(mp.param_name)}", []({cclass_kind} self) {{
247280
return {mp.getter_name}(self);
248281
}});
249282
"""
@@ -252,16 +285,27 @@ def emit_attr_or_type_nanobind_class(
252285
return helper_decls, helper_defns, s
253286

254287

255-
def emit_decls_defns_nbclasses(cclass_kind: CClassKind, defs):
288+
def emit_decls_defns_nbclasses(
289+
cclass_kind: CClassKind, defs, include=None, exclude=None
290+
):
291+
if include or exclude:
292+
assert not (include and exclude), f"only include or exclude allowed"
293+
if exclude is None:
294+
exclude = set()
256295
decls = []
257296
defns = []
258297
nbclasses = []
259298
for d in defs:
299+
name = d.get_name()
300+
if include is not None and name not in include:
301+
continue
302+
if d.get_name() in exclude:
303+
continue
304+
base_class_name = d.get_cpp_base_class_name()
305+
assert base_class_name in {"::mlir::Attribute", "::mlir::Type"}
306+
class_name = d.get_cpp_class_name()
260307
params = list(d.get_parameters())
261308
if params:
262-
base_class_name = d.get_cpp_base_class_name()
263-
assert base_class_name in {"::mlir::Attribute", "::mlir::Type"}
264-
class_name = d.get_cpp_class_name()
265309
decl, defn = emit_c_attr_or_type_builder(cclass_kind, class_name, params)
266310
decls.append(decl)
267311
defns.append(defn)

projects/eudsl-tblgen/src/eudsl_tblgen/mlir/__main__.py

+24-3
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,13 @@ def __call__(self, parser, namespace, values, option_string=None):
4343
setattr(namespace, self.dest, value)
4444

4545

46-
def emit_attrs_or_types(kind, rk, output_dir, output_prefix):
46+
def emit_attrs_or_types(
47+
kind, rk, output_dir, output_prefix, include=None, exclude=None
48+
):
4749
all_defs = collect_all_attr_or_type_defs(collect_all_defs(rk))
48-
decls, defns, nbclasses = emit_decls_defns_nbclasses(kind, all_defs)
50+
decls, defns, nbclasses = emit_decls_defns_nbclasses(
51+
kind, all_defs, include, exclude
52+
)
4953

5054
attr_decls = open(output_dir / f"{output_prefix}_{kind}_decls.h.inc", "w")
5155
attr_defns = open(output_dir / f"{output_prefix}_{kind}_defns.cpp.inc", "w")
@@ -74,7 +78,14 @@ def main(args):
7478
str(args.td_file),
7579
[str(ip) for ip in args.include_paths],
7680
)
77-
emit_attrs_or_types(args.kind, defs_rk, args.output_dir, args.output_prefix)
81+
emit_attrs_or_types(
82+
args.kind,
83+
defs_rk,
84+
args.output_dir,
85+
args.output_prefix,
86+
include=args.include,
87+
exclude=args.exclude,
88+
)
7889

7990

8091
if __name__ == "__main__":
@@ -83,8 +94,18 @@ def main(args):
8394
args.add_argument("-o", "--output-dir", type=Path, required=True)
8495
args.add_argument("-k", "--kind", type=CClassKind, action=EnumAction, required=True)
8596
args.add_argument("-I", "--include-paths", nargs="+", type=Path, required=True)
97+
args.add_argument("--exclude", nargs="*")
98+
args.add_argument("--include", nargs="*")
8699

87100
args = args.parse_args()
101+
if args.include:
102+
args.include = set(args.include)
103+
else:
104+
args.include = None
105+
if args.exclude:
106+
args.exclude = set(args.exclude)
107+
else:
108+
args.exclude = None
88109
args.output_prefix = Path(args.td_file).stem
89110

90111
main(args)

projects/eudsl-tblgen/tests/test_c_api_emission.py

+40-21
Original file line numberDiff line numberDiff line change
@@ -40,35 +40,54 @@ def test_attrs(record_keeper_triton_gpu_attrs):
4040
all_defs = collect_all_attr_or_type_defs(
4141
collect_all_defs(record_keeper_triton_gpu_attrs)
4242
)
43-
decls, defns, nbclasses = emit_decls_defns_nbclasses(CClassKind.ATTRIBUTE, all_defs)
43+
decls, defns, nbclasses = emit_decls_defns_nbclasses(
44+
CClassKind.ATTRIBUTE, all_defs, exclude={"BlockedEncodingAttr", "SliceEncodingAttr"}
45+
)
46+
47+
dump_dir = Path(__file__).parent
4448

45-
print()
46-
for d in decls:
47-
print(d)
48-
for d in defns:
49-
print(d)
49+
with open(f"{dump_dir}/TritonGPUAttrDefs_MlirAttribute_decls.h.inc", "w") as f:
50+
for d in decls:
51+
print(d, file=f)
52+
with open(f"{dump_dir}/TritonGPUAttrDefs_MlirAttribute_defns.cpp.inc", "w") as f:
53+
for d in defns:
54+
print(d, file=f)
5055
for hdecl, hdefn, n in nbclasses:
51-
for h in hdecl:
52-
print(h)
53-
for h in hdefn:
54-
print(h)
55-
print(n)
56+
with open(f"{dump_dir}/TritonGPUAttrDefs_MlirAttribute_decls.h.inc", "a") as f:
57+
for h in hdecl:
58+
print(h, file=f)
59+
with open(
60+
f"{dump_dir}/TritonGPUAttrDefs_MlirAttribute_defns.cpp.inc", "a"
61+
) as f:
62+
for h in hdefn:
63+
print(h, file=f)
64+
with open(
65+
f"{dump_dir}/TritonGPUAttrDefs_MlirAttribute_nbclasses.cpp.inc", "w"
66+
) as f:
67+
for *_, n in nbclasses:
68+
print(n, file=f)
5669

5770

5871
def test_types(record_keeper_triton_gpu_types):
5972
all_defs = collect_all_attr_or_type_defs(
6073
collect_all_defs(record_keeper_triton_gpu_types)
6174
)
6275
decls, defns, nbclasses = emit_decls_defns_nbclasses(CClassKind.TYPE, all_defs)
76+
dump_dir = Path(__file__).parent
6377

64-
print()
65-
for d in decls:
66-
print(d)
67-
for d in defns:
68-
print(d)
78+
with open(f"{dump_dir}/TritonGPUTypes_MlirType_decls.h.inc", "w") as f:
79+
for d in decls:
80+
print(d, file=f)
81+
with open(f"{dump_dir}/TritonGPUTypes_MlirType_defns.cpp.inc", "w") as f:
82+
for d in defns:
83+
print(d, file=f)
6984
for hdecl, hdefn, n in nbclasses:
70-
for h in hdecl:
71-
print(h)
72-
for h in hdefn:
73-
print(h)
74-
print(n)
85+
with open(f"{dump_dir}/TritonGPUTypes_MlirType_decls.h.inc", "a") as f:
86+
for h in hdecl:
87+
print(h, file=f)
88+
with open(f"{dump_dir}/TritonGPUTypes_MlirType_defns.cpp.inc", "a") as f:
89+
for h in hdefn:
90+
print(h, file=f)
91+
with open(f"{dump_dir}/TritonGPUTypes_MlirType_nbclasses.cpp.inc", "w") as f:
92+
for *_, n in nbclasses:
93+
print(n, file=f)

0 commit comments

Comments
 (0)