diff --git a/src/include/migraphx/op/resize.hpp b/src/include/migraphx/op/resize.hpp index a2c8aa390a1..c0d6f43fe5f 100644 --- a/src/include/migraphx/op/resize.hpp +++ b/src/include/migraphx/op/resize.hpp @@ -136,6 +136,108 @@ struct resize std::string name() const { return "resize"; } + private: + // Helper struct to hold interpolation parameters for one dimension + struct interp_params + { + std::size_t i0; // lower index + std::size_t i1; // upper index + double weight; // interpolation weight (0.0 to 1.0) + }; + + // Compute interpolation parameters for a single dimension + template + static interp_params compute_interp_params_1d(std::size_t in_len, + std::size_t out_len, + std::size_t out_idx, + float scale, + const IdxOp& idx_op) + { + // Handle degenerate dimension (length 1) to avoid NaNs + if(in_len <= 1) + { + return {0, 0, 0.0}; + } + + // Compute the original floating-point coordinate + double coord = idx_op(in_len, out_len, out_idx, scale); + + // Clamp to valid input range [0, in_len-1] + double max_c = in_len > 0 ? static_cast(in_len - 1) : 0.0; + coord = std::max(0.0, std::min(max_c, coord)); + + std::size_t base = std::floor(coord); + std::size_t next = std::min(base + 1, (in_len == 0 ? 0 : in_len - 1)); + double frac = coord - static_cast(base); + + return {base, next, frac}; + } + + // Compute input indices for nearest neighbor mode + template + static std::vector + compute_nearest_indices(const std::vector& in_lens, + const std::vector& out_lens, + const std::vector& out_idx_v, + const std::vector& vec_scale, + const NearestOp& nearest_op, + const IdxOp& idx_op) + { + std::vector in_idx(out_idx_v.size()); + for(std::size_t i = 0; i < out_idx_v.size(); ++i) + { + auto idx_val = idx_op(in_lens[i], out_lens[i], out_idx_v[i], vec_scale[i]); + in_idx[i] = nearest_op(in_lens[i], idx_val); + } + return in_idx; + } + + // Perform N-D multilinear interpolation for a single output point + template + static double compute_linear_interp_point(const Data& data, + const std::vector& in_lens, + const std::vector& out_lens, + const std::vector& out_idx_v, + const std::vector& vec_scale, + const IdxOp& idx_op) + { + const std::size_t ndim = out_idx_v.size(); + + // Precompute interpolation parameters for each dimension + std::vector params(ndim); + for(std::size_t d = 0; d < ndim; d++) + { + params[d] = compute_interp_params_1d( + in_lens[d], out_lens[d], out_idx_v[d], vec_scale[d], idx_op); + } + + // Accumulate over 2^ndim corners + double acc = 0.0; + const std::size_t corners = (ndim == 0) ? 1 : (1ULL << ndim); + std::vector in_idx(ndim); + + for(std::size_t mask = 0; mask < corners; ++mask) + { + double w = 1.0; + for(std::size_t d = 0; d < ndim; ++d) + { + const bool use_high = ((mask >> d) & 1U) != 0U; + w *= use_high ? params[d].weight : (1.0 - params[d].weight); + in_idx[d] = use_high ? params[d].i1 : params[d].i0; + } + + if(w == 0.0) + continue; + + using in_value_t = typename Data::value_type; + in_value_t v = data(in_idx.begin(), in_idx.end()); + acc += w * static_cast(v); + } + + return acc; + } + + public: template static auto reflect(Self& self, F f) { @@ -150,8 +252,9 @@ struct resize { check_shapes{inputs, *this, true}.has(1, 2); - if(mode != "nearest") - MIGRAPHX_THROW("RESIZE: Only Nearest mode is supported"); + // Allow nearest and linear; still reject others + if(mode != "nearest" and mode != "linear") + MIGRAPHX_THROW("RESIZE: Only 'nearest' and 'linear' modes are supported"); // Inputs are X, sizes or scale, ROI and axes not supported. if(inputs.size() == 1) @@ -203,9 +306,22 @@ struct resize // compute() method. For any other target, there must be a compiler pass that replaces // this operation with a fixed-size output at runtime. std::size_t max_val = std::numeric_limits::max(); - std::vector dyn_dims(inputs.back().lens().at(0), - shape::dynamic_dimension{0, max_val}); - return {inputs.front().type(), dyn_dims}; + auto input = inputs.front().to_dynamic(); + std::vector dyn_dims(input.ndim(), {0, max_val}); + + if(not scales.empty()) + { + for(std::size_t i = 0; i < scales.size(); i++) + { + dyn_dims[i].min = static_cast(input.dyn_dims()[i].min * scales[i]); + if(input.dyn_dims()[i].max != max_val) + { + dyn_dims[i].max = + static_cast(input.dyn_dims()[i].max * scales[i]); + } + } + } + return {input.type(), dyn_dims}; } } @@ -230,7 +346,7 @@ struct resize in_lens.begin(), vec_scale.begin(), [](size_t out_len, size_t in_len) { - return (in_len == 0 ? 1.f + return (in_len == 0 ? 1.0f : static_cast(out_len) / in_len); }); } @@ -268,7 +384,6 @@ struct resize else { // read the scale from args[1] - // std::copy(input.begin(), input.end(), vec_scale.begin()); // compute the output dimensions from the given scales. This computation // always rounds down, unlike the internal computation in Nearest mode @@ -286,22 +401,41 @@ struct resize shape output_shape = {args[0].get_shape().type(), out_lens}; argument result{output_shape}; - auto nearest_op = get_nearest_op(nearest_mode); - auto idx_op = get_original_idx_op(coordinate_transformation_mode); - - // Populate each element in output by selecting "nearest" item in input. - visit_all(result, args[0])([&](auto output, auto data) { - migraphx::shape out_comp_shape{data.get_shape().type(), out_lens}; - shape_for_each(out_comp_shape, [&](const auto& out_idx_v, size_t out_idx) { - std::vector in_idx(out_idx_v.size()); - for(auto ii = 0; ii < out_idx_v.size(); ++ii) - { - auto idx_val = idx_op(in_lens[ii], out_lens[ii], out_idx_v[ii], vec_scale[ii]); - in_idx[ii] = nearest_op(in_lens[ii], idx_val); - } - output[out_idx] = data(in_idx.begin(), in_idx.end()); + + auto idx_op = get_original_idx_op(coordinate_transformation_mode); + + if(mode == "nearest") + { + auto nearest_op = get_nearest_op(nearest_mode); + // Populate each element in output by selecting "nearest" item in input. + visit_all(result, args[0])([&](auto output, auto data) { + migraphx::shape out_comp_shape{data.get_shape().type(), out_lens}; + shape_for_each(out_comp_shape, [&](const auto& out_idx_v, size_t out_idx) { + auto in_idx = compute_nearest_indices( + in_lens, out_lens, out_idx_v, vec_scale, nearest_op, idx_op); + output[out_idx] = data(in_idx.begin(), in_idx.end()); + }); + }); + } + else if(mode == "linear") + { + // N-D multilinear interpolation + visit_all(result, args[0])([&](auto output, auto data) { + migraphx::shape out_comp_shape{data.get_shape().type(), out_lens}; + shape_for_each(out_comp_shape, [&](const auto& out_idx_v, std::size_t out_idx) { + double acc = compute_linear_interp_point( + data, in_lens, out_lens, out_idx_v, vec_scale, idx_op); + + using out_value_t = typename decltype(output)::value_type; + output[out_idx] = static_cast(acc); + }); }); - }); + } + else + { + MIGRAPHX_THROW("RESIZE: Unsupported mode in compute()"); + } + return result; } }; diff --git a/src/onnx/parse_resize.cpp b/src/onnx/parse_resize.cpp index d2d70fd024e..239e206da8e 100644 --- a/src/onnx/parse_resize.cpp +++ b/src/onnx/parse_resize.cpp @@ -470,6 +470,20 @@ struct parse_resize : op_parser auto out_lens = resize.out_lens; auto vec_scale = resize.vec_scale; + if(args_0->get_shape().dynamic()) + { + // Resize's compute_shape() will read scales_sizes_arg as "scales" or "sizes" + // depending on its data type + + return info.add_instruction( + make_op("resize", + {{"mode", resize.get_mode()}, + {"scales", vec_scale}, + {"coordinate_transformation_mode", resize.get_coord_trans_mode()}}), + args_0, + resize.get_scales_sizes_arg()); + } + // out_lens and other variables can't be populated if non-constant (runtime) size // inputs. if(not resize.is_constant_scale_input()) diff --git a/test/onnx/verify/resize_downsample_linear_dyn_test.cpp b/test/onnx/verify/resize_downsample_linear_dyn_test.cpp new file mode 100644 index 00000000000..f8411ce7dab --- /dev/null +++ b/test/onnx/verify/resize_downsample_linear_dyn_test.cpp @@ -0,0 +1,51 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include +#include +#include + +TEST_CASE(resize_downsample_linear_dyn_test) +{ + using migraphx::half; + migraphx::onnx_options options; + options.map_dyn_input_dims = {{"X", {{1, 1}, {1, 1}, {2, 3}, {4, 8}}}}; + migraphx::program p = read_onnx("resize_downsample_linear_half_test.onnx", options); + p.compile(migraphx::make_target("ref")); + + migraphx::shape sx{migraphx::shape::half_type, {1, 1, 2, 4}}; + std::vector dx = {half{1}, half{2}, half{3}, half{4}, half{5}, half{6}, half{7}, half{8}}; + + migraphx::parameter_map pp; + pp["X"] = migraphx::argument(sx, dx.data()); + + auto result = p.eval(pp).back(); + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + // Expected output was calculated without any quantization + std::vector gold = {half{2.8333333}, half{4.833333}}; + + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +} diff --git a/test/onnx/verify/resize_upsample_linear_dyn_test.cpp b/test/onnx/verify/resize_upsample_linear_dyn_test.cpp new file mode 100644 index 00000000000..90476115b3e --- /dev/null +++ b/test/onnx/verify/resize_upsample_linear_dyn_test.cpp @@ -0,0 +1,51 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include +#include +#include + +TEST_CASE(resize_upsample_linear_dyn_test) +{ + migraphx::onnx_options options; + options.map_dyn_input_dims = {{"X", {{1, 1}, {1, 1}, {2, 3}, {2, 3}}}}; + + migraphx::program p = read_onnx("resize_upsample_linear_test.onnx", options); + p.compile(migraphx::make_target("ref")); + + migraphx::shape sx{migraphx::shape::float_type, {1, 1, 2, 2}}; + std::vector dx = {1.0f, 2.0f, 3.0f, 4.0f}; + + migraphx::parameter_map pp; + pp["X"] = migraphx::argument(sx, dx.data()); + + auto result = p.eval(pp).back(); + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = { + 1, 1.25, 1.75, 2, 1.5, 1.75, 2.25, 2.5, 2.5, 2.75, 3.25, 3.5, 3, 3.25, 3.75, 4}; + + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +} diff --git a/test/ref/resize.cpp b/test/ref/resize.cpp index 703c2bc040e..2b53a727b08 100644 --- a/test/ref/resize.cpp +++ b/test/ref/resize.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -573,3 +573,163 @@ TEST_CASE(resize_fail_test_8) a1); })); } + +TEST_CASE(resize_linear_1x1_degenerate_test) +{ + // Test degenerate 1x1 input case with linear mode + // This tests the special handling in resize.hpp + // which prevents NaN values when in_lens[d] <= 1 + migraphx::program p; + auto* mm = p.get_main_module(); + + // Single pixel input with value 5.0 + migraphx::shape s{migraphx::shape::float_type, {1, 1, 1, 1}}; + std::vector data = {5.0f}; + auto a0 = mm->add_literal(migraphx::literal{s, data}); + + // Upsample 1x1 -> 4x4 using linear interpolation + mm->add_instruction(migraphx::make_op("resize", + {{"sizes", {1, 1, 4, 4}}, + {"mode", "linear"}, + {"coordinate_transformation_mode", "half_pixel"}}), + a0); + p.compile(migraphx::make_target("ref")); + auto result = p.eval({}).back(); + + std::vector res_data; + result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); }); + + // All output values should be exactly 5.0 (the single input value) + // since there's only one value to interpolate from + // clang-format off + std::vector golden = { + 5.0f, 5.0f, 5.0f, 5.0f, + 5.0f, 5.0f, 5.0f, 5.0f, + 5.0f, 5.0f, 5.0f, 5.0f, + 5.0f, 5.0f, 5.0f, 5.0f + }; + // clang-format on + + EXPECT(migraphx::verify::verify_rms_range(res_data, golden)); +} + +TEST_CASE(resize_linear_1x4_degenerate_height_test) +{ + // Test degenerate case where only height is 1, width is normal + // This tests partial degenerate dimensions + migraphx::program p; + auto* mm = p.get_main_module(); + + // 1x4 input (height=1, width=4) + migraphx::shape s{migraphx::shape::float_type, {1, 1, 1, 4}}; + std::vector data = {1.0f, 2.0f, 3.0f, 4.0f}; + auto a0 = mm->add_literal(migraphx::literal{s, data}); + + // Upsample to 3x8 - height should replicate, width should interpolate + mm->add_instruction(migraphx::make_op("resize", + {{"sizes", {1, 1, 3, 8}}, + {"mode", "linear"}, + {"coordinate_transformation_mode", "half_pixel"}}), + a0); + p.compile(migraphx::make_target("ref")); + auto result = p.eval({}).back(); + + std::vector res_data; + result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); }); + + // Height dimension (size 1) should just replicate + // Width dimension should linearly interpolate 1,2,3,4 -> 8 values + // clang-format off + std::vector golden = { + 1.0f, 1.25f, 1.75f, 2.25f, 2.75f, 3.25f, 3.75f, 4.0f, + 1.0f, 1.25f, 1.75f, 2.25f, 2.75f, 3.25f, 3.75f, 4.0f, + 1.0f, 1.25f, 1.75f, 2.25f, 2.75f, 3.25f, 3.75f, 4.0f + }; + // clang-format on + + EXPECT(migraphx::verify::verify_rms_range(res_data, golden)); +} + +TEST_CASE(resize_linear_1x1_with_scales_test) +{ + // Test 1x1 degenerate case using scales instead of sizes + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape s{migraphx::shape::float_type, {1, 1, 1, 1}}; + std::vector data = {7.5f}; + auto a0 = mm->add_literal(migraphx::literal{s, data}); + + // Use scales attribute: 1x1 -> 5x5 (5x scale) + mm->add_instruction(migraphx::make_op("resize", + {{"scales", {1.0f, 1.0f, 5.0f, 5.0f}}, + {"mode", "linear"}, + {"coordinate_transformation_mode", "asymmetric"}}), + a0); + p.compile(migraphx::make_target("ref")); + auto result = p.eval({}).back(); + + std::vector res_data; + result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); }); + + // All 25 output values should be 7.5 + std::vector golden(25, 7.5f); + + EXPECT(migraphx::verify::verify_rms_range(res_data, golden)); +} + +TEST_CASE(resize_linear_1x1_align_corners_test) +{ + // Test 1x1 with align_corners coordinate transformation + // This ensures the degenerate handling works with different coord transforms + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape s{migraphx::shape::float_type, {2, 3, 1, 1}}; + std::vector data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + auto a0 = mm->add_literal(migraphx::literal{s, data}); + + // Batch=2, Channels=3, 1x1 -> 3x3 spatial + mm->add_instruction(migraphx::make_op("resize", + {{"scales", {1.0f, 1.0f, 3.0f, 3.0f}}, + {"mode", "linear"}, + {"coordinate_transformation_mode", "align_corners"}}), + a0); + p.compile(migraphx::make_target("ref")); + auto result = p.eval({}).back(); + + std::vector res_data; + result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); }); + + // Each channel should replicate its single value across all 9 spatial positions + // clang-format off + std::vector golden = { + // Batch 0, Channel 0 (value 1.0) + 1.0f, 1.0f, 1.0f, + 1.0f, 1.0f, 1.0f, + 1.0f, 1.0f, 1.0f, + // Batch 0, Channel 1 (value 2.0) + 2.0f, 2.0f, 2.0f, + 2.0f, 2.0f, 2.0f, + 2.0f, 2.0f, 2.0f, + // Batch 0, Channel 2 (value 3.0) + 3.0f, 3.0f, 3.0f, + 3.0f, 3.0f, 3.0f, + 3.0f, 3.0f, 3.0f, + // Batch 1, Channel 0 (value 4.0) + 4.0f, 4.0f, 4.0f, + 4.0f, 4.0f, 4.0f, + 4.0f, 4.0f, 4.0f, + // Batch 1, Channel 1 (value 5.0) + 5.0f, 5.0f, 5.0f, + 5.0f, 5.0f, 5.0f, + 5.0f, 5.0f, 5.0f, + // Batch 1, Channel 2 (value 6.0) + 6.0f, 6.0f, 6.0f, + 6.0f, 6.0f, 6.0f, + 6.0f, 6.0f, 6.0f + }; + // clang-format on + + EXPECT(migraphx::verify::verify_rms_range(res_data, golden)); +}