File tree 1 file changed +13
-0
lines changed
1 file changed +13
-0
lines changed Original file line number Diff line number Diff line change @@ -263,6 +263,13 @@ migraphx::shape to_shape(const py::buffer_info& info)
263
263
{
264
264
migraphx::shape::type_t t;
265
265
std::size_t n = 0 ;
266
+ // Unsupported pybuffer types lead to undefined behaviour when comparing with migraphx type enum
267
+ if (info.format == " z" )
268
+ {
269
+ MIGRAPHX_THROW (
270
+ " MIGRAPHX PYTHON: Unsupported data type. For fp8 and bf16 literals try using "
271
+ " migraphx.generate_argument with migraphx.add_literal" );
272
+ }
266
273
visit_types ([&](auto as) {
267
274
if (info.format == py::format_descriptor<decltype (as ())>::format () or
268
275
(info.format == " l" and py::format_descriptor<decltype (as ())>::format () == " q" ) or
@@ -388,6 +395,12 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
388
395
py::arg (" op" ),
389
396
py::arg (" args" ),
390
397
py::arg (" mod_args" ) = std::vector<migraphx::module*>{})
398
+ .def (
399
+ " add_literal" ,
400
+ [](migraphx::module& mm, migraphx::argument a) {
401
+ return mm.add_literal (a.get_shape (), a.data ());
402
+ },
403
+ py::arg (" data" ))
391
404
.def (
392
405
" add_literal" ,
393
406
[](migraphx::module& mm, py::buffer data) {
You can’t perform that action at this time.
0 commit comments