Skip to content

Commit 6f6705c

Browse files
catch python buffer unsupported types (#3701)
1 parent 0860461 commit 6f6705c

File tree

1 file changed

+13
-0
lines changed

1 file changed

+13
-0
lines changed

src/py/migraphx_py.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,13 @@ migraphx::shape to_shape(const py::buffer_info& info)
263263
{
264264
migraphx::shape::type_t t;
265265
std::size_t n = 0;
266+
// Unsupported pybuffer types lead to undefined behaviour when comparing with migraphx type enum
267+
if(info.format == "z")
268+
{
269+
MIGRAPHX_THROW(
270+
"MIGRAPHX PYTHON: Unsupported data type. For fp8 and bf16 literals try using "
271+
"migraphx.generate_argument with migraphx.add_literal");
272+
}
266273
visit_types([&](auto as) {
267274
if(info.format == py::format_descriptor<decltype(as())>::format() or
268275
(info.format == "l" and py::format_descriptor<decltype(as())>::format() == "q") or
@@ -388,6 +395,12 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
388395
py::arg("op"),
389396
py::arg("args"),
390397
py::arg("mod_args") = std::vector<migraphx::module*>{})
398+
.def(
399+
"add_literal",
400+
[](migraphx::module& mm, migraphx::argument a) {
401+
return mm.add_literal(a.get_shape(), a.data());
402+
},
403+
py::arg("data"))
391404
.def(
392405
"add_literal",
393406
[](migraphx::module& mm, py::buffer data) {

0 commit comments

Comments
 (0)