5
5
import warnings
6
6
from dataclasses import dataclass
7
7
import re
8
+ from functools import lru_cache
8
9
from textwrap import dedent
9
10
10
11
from .. import AttrOrTypeParameter
@@ -32,30 +33,44 @@ def map_cpp_to_c_type(t):
32
33
element_ty_reg = re .compile (r"ArrayRef<(\w+)>" )
33
34
34
35
35
- @dataclass
36
+ @dataclass ( frozen = True )
36
37
class Param :
37
38
class_name : str
38
39
param_name : str
39
40
c_type : str
40
41
cpp_type : str
41
42
param_def : AttrOrTypeParameter
42
43
44
+ @lru_cache (maxsize = 1 )
43
45
def c_param_str (self ):
44
46
return f"{ self .c_type } { self .param_name } "
45
47
46
48
@property
49
+ @lru_cache (maxsize = 1 )
47
50
def getter_name (self ):
48
51
return f"mlir{ self .class_name } Get{ self .param_name } "
49
52
50
53
# TODO(max): bad heuristic - should look inside param_def
54
+ @lru_cache (maxsize = 1 )
51
55
def needs_wrap_unwrap (self ):
52
56
return self .cpp_type != self .c_type
53
57
58
+ @property
59
+ @lru_cache (maxsize = 1 )
60
+ def is_optional (self ):
61
+ return self .param_def .is_optional ()
62
+
63
+ @property
64
+ @lru_cache (maxsize = 1 )
65
+ def default_value (self ):
66
+ return self .param_def .get_default_value ()
54
67
55
- @dataclass
68
+
69
+ @dataclass (frozen = True )
56
70
class ArrayRefParam (Param ):
57
71
c_element_type : str
58
72
73
+ @lru_cache (maxsize = 1 )
59
74
def c_param_str (self ):
60
75
return f"{ self .c_element_type } *{ self .param_name } , unsigned n{ self .param_name } s"
61
76
@@ -100,7 +115,7 @@ def emit_c_attr_or_type_builder(
100
115
cclass_kind : CClassKind , class_name , params : list [AttrOrTypeParameter ]
101
116
):
102
117
mapped_params = map_params (class_name , params )
103
- sig = f"""{ cclass_kind } mlir{ class_name } { 'Attr' if cclass_kind == CClassKind . ATTRIBUTE else 'Type' } Get({ ', ' .join ([p .c_param_str () for p in mapped_params ])} , MlirContext mlirContext)"""
118
+ sig = f"""{ cclass_kind } mlir{ class_name } { cclass_kind . replace ( 'Mlir' , '' ) } Get({ ', ' .join ([p .c_param_str () for p in mapped_params ])} , MlirContext mlirContext)"""
104
119
decl = f"""MLIR_CAPI_EXPORTED { sig } ;"""
105
120
defn = dedent (
106
121
f"""
@@ -162,10 +177,6 @@ def emit_attr_or_type_nanobind_class(
162
177
):
163
178
mapped_params = map_params (class_name , params )
164
179
165
- mlir_attr_or_mlir_type = (
166
- "MlirAttribute" if cclass_kind == CClassKind .ATTRIBUTE else "MlirType"
167
- )
168
-
169
180
helper_decls = []
170
181
helper_defns = []
171
182
helper_decls .append (
@@ -181,12 +192,12 @@ def emit_attr_or_type_nanobind_class(
181
192
)
182
193
)
183
194
helper_decls .append (
184
- f"MLIR_CAPI_EXPORTED bool isaMlir{ class_name } ({ mlir_attr_or_mlir_type } thing);"
195
+ f"MLIR_CAPI_EXPORTED bool isaMlir{ class_name } ({ cclass_kind } thing);"
185
196
)
186
197
helper_defns .append (
187
198
dedent (
188
199
f"""\
189
- bool isaMlir{ class_name } ({ mlir_attr_or_mlir_type } thing) {{
200
+ bool isaMlir{ class_name } ({ cclass_kind } thing) {{
190
201
return isa<{ class_name } >(unwrap(thing));
191
202
}}
192
203
"""
@@ -196,13 +207,16 @@ def emit_attr_or_type_nanobind_class(
196
207
params_str = []
197
208
for mp in mapped_params :
198
209
if isinstance (mp , ArrayRefParam ):
199
- params_str . append ( f"std::vector<{ mp .c_element_type } > & { mp . param_name } " )
210
+ typ = f"std::vector<{ mp .c_element_type } >&"
200
211
else :
201
- params_str .append (f"{ mp .c_type } { mp .param_name } " )
212
+ typ = f"{ mp .c_type } "
213
+ if mp .is_optional :
214
+ typ = f"std::optional<{ typ } >"
215
+ params_str .append (f"{ typ } { mp .param_name } " )
202
216
params_str = ", " .join (params_str )
203
217
s = dedent (
204
218
f"""
205
- auto nb{ class_name } = { 'mlir_attribute_subclass' if cclass_kind == CClassKind . ATTRIBUTE else 'mlir_type_subclass' } (m, "{ class_name } ", isaMlir{ class_name } , mlir{ class_name } GetTypeID);
219
+ auto nb{ class_name } = { underscore ( cclass_kind ) } _subclass (m, "{ class_name } ", isaMlir{ class_name } , mlir{ class_name } GetTypeID);
206
220
nb{ class_name } .def_staticmethod("get", []({ params_str } , MlirContext context) {{
207
221
"""
208
222
)
@@ -211,10 +225,29 @@ def emit_attr_or_type_nanobind_class(
211
225
help_str = []
212
226
for mp in mapped_params :
213
227
if isinstance (mp , ArrayRefParam ):
214
- arg_str .append (f"{ mp .param_name } .data(), { mp .param_name } .size()" )
228
+ if mp .is_optional :
229
+ arg_str .append (
230
+ f"{ mp .param_name } .has_value() ? { mp .param_name } ->data() : nullptr, { mp .param_name } .has_value() ? { mp .param_name } ->size() : 0"
231
+ )
232
+ else :
233
+ arg_str .append (f"{ mp .param_name } .data(), { mp .param_name } .size()" )
215
234
else :
216
- arg_str .append (f"{ mp .param_name } " )
217
- help_str .append (f'"{ underscore (mp .param_name )} "_a' )
235
+ if (default_val := mp .default_value ) and mp .needs_wrap_unwrap ():
236
+ default_val = f"wrap({ default_val } )"
237
+ arg_str .append (
238
+ f"{ mp .param_name } .has_value() ? *{ mp .param_name } : { default_val } "
239
+ )
240
+ elif mp .default_value and not mp .needs_wrap_unwrap ():
241
+ arg_str .append (f"*{ mp .param_name } " )
242
+ else :
243
+ arg_str .append (f"{ mp .param_name } " )
244
+
245
+ if (default_val := mp .default_value ) and not mp .needs_wrap_unwrap ():
246
+ help_str .append (f'"{ underscore (mp .param_name )} "_a = { default_val } ' )
247
+ elif mp .is_optional :
248
+ help_str .append (f'"{ underscore (mp .param_name )} "_a = nb::none()' )
249
+ else :
250
+ help_str .append (f'"{ underscore (mp .param_name )} "_a' )
218
251
arg_str .append ("context" )
219
252
arg_str = ", " .join (arg_str )
220
253
@@ -223,7 +256,7 @@ def emit_attr_or_type_nanobind_class(
223
256
224
257
s += dedent (
225
258
f"""\
226
- return mlir{ class_name } { 'Attr' if cclass_kind == CClassKind . ATTRIBUTE else 'Type' } Get({ arg_str } );
259
+ return mlir{ class_name } { cclass_kind . replace ( 'Mlir' , '' ) } Get({ arg_str } );
227
260
}}, { help_str } );
228
261
"""
229
262
)
@@ -232,7 +265,7 @@ def emit_attr_or_type_nanobind_class(
232
265
if isinstance (mp , ArrayRefParam ):
233
266
s += dedent (
234
267
f"""
235
- nb{ class_name } .def_property_readonly("{ underscore (mp .param_name )} ", []({ mlir_attr_or_mlir_type } self) {{
268
+ nb{ class_name } .def_property_readonly("{ underscore (mp .param_name )} ", []({ cclass_kind } self) {{
236
269
unsigned n{ mp .param_name } s;
237
270
{ mp .c_element_type } * { mp .param_name } Ptr;
238
271
{ mp .getter_name } (self, &{ mp .param_name } Ptr, &n{ mp .param_name } s);
@@ -243,7 +276,7 @@ def emit_attr_or_type_nanobind_class(
243
276
else :
244
277
s += dedent (
245
278
f"""
246
- nb{ class_name } .def_property_readonly("{ underscore (mp .param_name )} ", []({ 'MlirAttribute' if cclass_kind == CClassKind . ATTRIBUTE else 'MlirType' } self) {{
279
+ nb{ class_name } .def_property_readonly("{ underscore (mp .param_name )} ", []({ cclass_kind } self) {{
247
280
return { mp .getter_name } (self);
248
281
}});
249
282
"""
@@ -252,16 +285,27 @@ def emit_attr_or_type_nanobind_class(
252
285
return helper_decls , helper_defns , s
253
286
254
287
255
- def emit_decls_defns_nbclasses (cclass_kind : CClassKind , defs ):
288
+ def emit_decls_defns_nbclasses (
289
+ cclass_kind : CClassKind , defs , include = None , exclude = None
290
+ ):
291
+ if include or exclude :
292
+ assert not (include and exclude ), f"only include or exclude allowed"
293
+ if exclude is None :
294
+ exclude = set ()
256
295
decls = []
257
296
defns = []
258
297
nbclasses = []
259
298
for d in defs :
299
+ name = d .get_name ()
300
+ if include is not None and name not in include :
301
+ continue
302
+ if d .get_name () in exclude :
303
+ continue
304
+ base_class_name = d .get_cpp_base_class_name ()
305
+ assert base_class_name in {"::mlir::Attribute" , "::mlir::Type" }
306
+ class_name = d .get_cpp_class_name ()
260
307
params = list (d .get_parameters ())
261
308
if params :
262
- base_class_name = d .get_cpp_base_class_name ()
263
- assert base_class_name in {"::mlir::Attribute" , "::mlir::Type" }
264
- class_name = d .get_cpp_class_name ()
265
309
decl , defn = emit_c_attr_or_type_builder (cclass_kind , class_name , params )
266
310
decls .append (decl )
267
311
defns .append (defn )
0 commit comments