Skip to content

Commit 4955f3b

Browse files
Deivanayaki-Sdeivanayakisankaralingam
and
deivanayakisankaralingam
authored
[Relax][PyTorch] Add Pixel Shuffle Op Support for Exported Program and FX graph (#17886)
* add pixel shuffle op into torch frontends * fix end of files formatting issue * fix trailing whitespaces issue * fix lint issues * fix long line code formatting * add arg check condition * fix lint issue in base fx graph script * add condition in struct info function * fix lint issue in struct info func --------- Co-authored-by: deivanayakisankaralingam <deiva@Deivanayaki>
1 parent cbba0e3 commit 4955f3b

File tree

13 files changed

+308
-0
lines changed

13 files changed

+308
-0
lines changed

include/tvm/relax/attrs/nn.h

+9
Original file line numberDiff line numberDiff line change
@@ -603,6 +603,15 @@ struct PadAttrs : public tvm::AttrsNode<PadAttrs> {
603603
}
604604
};
605605

606+
/*!
 * \brief Attributes used for the pixel shuffle operator.
 *
 * Holds the single integer parameter carried by a pixel shuffle call.
 */
607+
struct PixelShuffleAttrs : public tvm::AttrsNode<PixelShuffleAttrs> {
608+
/*! \brief Integer scale factor applied to the spatial dimensions. */
int upscale_factor;
609+
610+
// Declares the attrs node under the type key "relax.attrs.PixelShuffleAttrs"
// and exposes upscale_factor as its only reflected field.
TVM_DECLARE_ATTRS(PixelShuffleAttrs, "relax.attrs.PixelShuffleAttrs") {
611+
TVM_ATTR_FIELD(upscale_factor).describe("Scale factor for spatial upsampling.");
612+
}
613+
};
614+
606615
} // namespace relax
607616
} // namespace tvm
608617

python/tvm/relax/frontend/torch/base_fx_graph_translator.py

+9
Original file line numberDiff line numberDiff line change
@@ -913,6 +913,15 @@ def _pad(self, node: fx.Node) -> relax.Var:
913913

914914
return self.block_builder.emit(relax.op.nn.pad(x, pad_width, mode, value))
915915

916+
def _pixel_shuffle(self, node: fx.Node) -> relax.Var:
    """Translate a functional pixel-shuffle node into ``relax.op.nn.pixel_shuffle``.

    Parameters
    ----------
    node : fx.Node
        The graph node. ``node.args[0]`` is the input tensor node; the upscale
        factor is taken from ``node.args[1]`` or, when it was passed by
        keyword, from ``node.kwargs["upscale_factor"]``.

    Returns
    -------
    relax.Var
        The variable bound to the emitted pixel-shuffle call.

    Raises
    ------
    AssertionError
        If the upscale factor is missing or is not an ``int``.
    """
    data = self.env[node.args[0]]
    # A call such as F.pixel_shuffle(x, upscale_factor=2) records the factor in
    # node.kwargs and leaves node.args with a single entry, so indexing
    # node.args[1] unconditionally would raise IndexError.
    if len(node.args) > 1:
        upscale_factor = node.args[1]
    else:
        upscale_factor = node.kwargs.get("upscale_factor")
    assert isinstance(
        upscale_factor, int
    ), "PixelShuffle only accepts an integer upscale_factor."

    return self.block_builder.emit(relax.op.nn.pixel_shuffle(data, upscale_factor))
924+
916925
def _scaled_dot_product_attention(self, node: fx.Node) -> relax.Var:
917926
transpose_S_H = lambda tensor: relax.op.permute_dims(tensor, [0, 2, 1, 3])
918927
query = transpose_S_H(self.env[node.args[0]])

python/tvm/relax/frontend/torch/exported_program_translator.py

+1
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,7 @@ def create_convert_map(
311311
"log_softmax.int": self._log_softmax,
312312
"neg.default": self._unary_op(relax.op.negative),
313313
"pad.default": self._pad,
314+
"pixel_shuffle.default": self._pixel_shuffle,
314315
"prelu.default": self._prelu,
315316
"reciprocal.default": self._reciprocal,
316317
"relu.default": self._unary_op(relax.op.nn.relu),

python/tvm/relax/frontend/torch/fx_translator.py

+9
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,13 @@ def _max_pool2d_module(self, node: fx.Node) -> relax.Var:
460460

461461
return self._max_pool2d_impl(x, kernel_size, stride, padding, dilation, ceil_mode)
462462

463+
def _pixel_shuffle_module(self, node: fx.Node) -> relax.Var:
    """Translate a pixel-shuffle module call into ``relax.op.nn.pixel_shuffle``.

    The input tensor is looked up in ``self.env`` from the node's first
    argument; the upscale factor is read from the module instance registered
    under ``node.target`` rather than from the node's arguments.
    """
    inp = self.env[node.args[0]]
    factor = self.named_modules[node.target].upscale_factor
    shuffled = relax.op.nn.pixel_shuffle(inp, factor)
    return self.block_builder.emit(shuffled)
469+
463470
########## Linear Interpolation ##########
464471

465472
def _lerp(self, node: fx.Node) -> relax.Var:
@@ -656,6 +663,7 @@ def create_convert_map(
656663
nn.Linear: self._linear_module,
657664
nn.MaxPool2d: self._max_pool2d_module,
658665
nn.modules.sparse.Embedding: self._embedding_module,
666+
nn.PixelShuffle: self._pixel_shuffle_module,
659667
# tensor manipulation
660668
nn.Flatten: self._flatten_module,
661669
## call_function and call_method
@@ -695,6 +703,7 @@ def create_convert_map(
695703
"log_softmax": self._log_softmax,
696704
"neg": self._unary_op(relax.op.negative),
697705
"pad": self._pad,
706+
"pixel_shuffle": self._pixel_shuffle,
698707
"prelu": self._prelu,
699708
"reciprocal": self._reciprocal,
700709
"relu": self._unary_op(relax.op.nn.relu),

python/tvm/relax/op/nn/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
max_pool3d,
4444
nll_loss,
4545
pad,
46+
pixel_shuffle,
4647
prelu,
4748
relu,
4849
rms_norm,

python/tvm/relax/op/nn/nn.py

+33
Original file line numberDiff line numberDiff line change
@@ -549,6 +549,39 @@ def pad(
549549
return _ffi_api.pad(data, pad_width, pad_mode, pad_value)
550550

551551

552+
def pixel_shuffle(data: Expr, upscale_factor: int):
    r"""Pixel shuffle operator.

    Rearranges the elements of a tensor of shape (N, C * r^2, H, W) into a
    tensor of shape (N, C, H * r, W * r), where `r` is the upscale factor.
    The operation is commonly used for efficient sub-pixel convolution in
    image super-resolution models.

    Parameters
    ----------
    data : relax.Expr
        The input tensor, laid out as (N, C * r^2, H, W) with `r` being the
        upscale factor.

    upscale_factor : int
        The upscaling factor `r`, i.e. the amount by which the spatial
        resolution (height and width) is increased.

    Returns
    -------
    result : relax.Expr
        The rearranged tensor with shape (N, C, H * r, W * r).

    Example
    -------
    An input of shape (1, 8, 10, 15) with `upscale_factor` equal to 2
    produces an output of shape (1, 2, 20, 30).
    """
    # Thin wrapper: all work happens in the registered FFI function.
    shuffled = _ffi_api.pixel_shuffle(data, upscale_factor)
    return shuffled
583+
584+
552585
def max_pool1d(
553586
data: Expr,
554587
pool_size: Union[int, Tuple[int, int]] = (1,),

python/tvm/relax/transform/legalize_ops/nn.py

+6
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,12 @@ def _nn_pad(bb: BlockBuilder, call: Call) -> Expr:
249249
)
250250

251251

252+
@register_legalize("relax.nn.pixel_shuffle")
def _nn_pixel_shuffle(bb: BlockBuilder, call: Call) -> Expr:
    """Legalize ``relax.nn.pixel_shuffle`` by calling the TOPI pixel-shuffle compute.

    The upscale factor is read from the call attributes and forwarded to
    ``topi.nn.pixel_shuffle`` together with the call's single tensor argument.
    """
    factor = call.attrs.upscale_factor
    data = call.args[0]
    return bb.call_te(topi.nn.pixel_shuffle, data, upscale_factor=factor)
256+
257+
252258
@register_legalize("relax.nn.max_pool1d")
253259
def _nn_max_pool1d(bb: BlockBuilder, call: Call) -> Expr:
254260
if call.attrs.out_layout != call.attrs.layout:

python/tvm/topi/nn/pixel_shuffle.py

+75
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""TVM operator pixel shuffle compute."""
18+
from __future__ import absolute_import
19+
20+
import tvm
21+
22+
23+
def pixel_shuffle(data, upscale_factor, name="PixelShuffle"):
    """Rearrange a tensor of shape [..., C * r * r, H, W] into [..., C, H * r, W * r].

    Parameters
    ----------
    data : tvm.te.Tensor
        N-D input tensor with at least 3 dimensions. The channel axis is at
        index -3, followed by height and width.

    upscale_factor : int
        The upscale factor (r). Must be a positive Python integer.

    name : str
        Name of the output tensor.

    Returns
    -------
    output : tvm.te.Tensor
        Pixel shuffled tensor with shape [..., C, H*r, W*r]
    """
    assert isinstance(upscale_factor, int) and upscale_factor > 0
    rank = len(data.shape)
    assert rank >= 3, "Input must be at least 3D"

    r = tvm.tir.const(upscale_factor, "int32")
    lead_dims = list(data.shape[:-3])
    in_channels, in_height, in_width = data.shape[-3], data.shape[-2], data.shape[-1]

    out_shape = lead_dims + [
        tvm.tir.floordiv(in_channels, r * r),
        in_height * r,
        in_width * r,
    ]

    def _gather(*out_idx):
        # Leading (batch-like) indices pass through untouched.
        lead_idx = out_idx[:-3]
        out_c, out_h, out_w = out_idx[-3], out_idx[-2], out_idx[-1]

        # Each output pixel maps back to one input pixel plus an (row, col)
        # offset inside its r x r block.
        src_h = tvm.tir.floordiv(out_h, r)
        sub_h = out_h % r

        src_w = tvm.tir.floordiv(out_w, r)
        sub_w = out_w % r

        # The in-block offset selects which of the r * r input channels
        # belonging to this output channel is read.
        src_c = (out_c * r * r) + (sub_h * r) + sub_w

        return data[lead_idx + (src_c, src_h, src_w)]

    return tvm.te.compute(out_shape, _gather, name=name)

src/relax/op/nn/nn.cc

+70
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,76 @@ TVM_REGISTER_OP("relax.nn.pad")
224224
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoPad)
225225
.set_attr<Bool>("FPurity", Bool(true));
226226

227+
/* relax.nn.pixel_shuffle */
228+
TVM_REGISTER_NODE_TYPE(PixelShuffleAttrs);
229+
230+
Expr pixel_shuffle(Expr data, int upscale_factor) {
231+
auto attrs = make_object<PixelShuffleAttrs>();
232+
attrs->upscale_factor = upscale_factor;
233+
static const Op& op = Op::Get("relax.nn.pixel_shuffle");
234+
return Call(op, {data}, Attrs(attrs), {});
235+
}
236+
237+
TVM_REGISTER_GLOBAL("relax.op.nn.pixel_shuffle").set_body_typed(pixel_shuffle);
238+
239+
/*!
 * \brief Infer the output struct info of relax.nn.pixel_shuffle.
 *
 * For an input whose last three dimensions are (C, H, W) and an upscale
 * factor r, the output keeps all leading dimensions and ends in
 * (C / (r * r), H * r, W * r). The dtype is propagated unchanged.
 *
 * \param call The pixel shuffle call; its attrs must be PixelShuffleAttrs.
 * \param ctx The block builder used to fetch the input struct info.
 * \return A TensorStructInfo with a concrete shape when the input shape is a
 *         ShapeExpr; otherwise one carrying only dtype and rank.
 */
StructInfo InferStructInfoPixelShuffle(const Call& call, const BlockBuilder& ctx) {
  Array<TensorStructInfo> input_sinfo = GetInputTensorStructInfo(call, ctx);
  const auto* attrs = call->attrs.as<PixelShuffleAttrs>();
  ICHECK(attrs != nullptr) << "PixelShuffle call must carry PixelShuffleAttrs";
  int r = attrs->upscale_factor;
  ICHECK_GT(r, 0) << "Upscale factor must be positive";

  const TensorStructInfo& input = input_sinfo[0];
  int ndim = input->ndim;
  // A tensor of statically unknown rank cannot be validated here; propagate
  // the dtype and leave the rank unknown instead of failing the rank check.
  if (input->IsUnknownNdim()) {
    return TensorStructInfo(input->dtype, kUnknownNDim);
  }
  ICHECK_GE(ndim, 3) << "PixelShuffle requires at least 3D input tensor";

  if (!input->shape.defined()) {
    return TensorStructInfo(input->dtype, ndim);
  }

  // The shape may be defined without being a literal ShapeExpr (for example a
  // shape variable). Its values are not inspectable then, so only dtype and
  // rank are reported.
  const auto* shape = input->shape.as<ShapeExprNode>();
  if (shape == nullptr) {
    return TensorStructInfo(input->dtype, ndim);
  }
  Array<PrimExpr> in_shape = shape->values;

  int channel_idx = ndim - 3;
  int h_idx = ndim - 2;
  int w_idx = ndim - 1;

  PrimExpr c_in = in_shape[channel_idx];
  PrimExpr h_in = in_shape[h_idx];
  PrimExpr w_in = in_shape[w_idx];

  PrimExpr r_expr = IntImm(DataType::Int(32), r);
  PrimExpr r_squared = r_expr * r_expr;

  // Divisibility can only be verified when the channel extent is a constant;
  // a symbolic channel dimension is accepted as-is.
  const auto* c_in_imm = c_in.as<IntImmNode>();
  const auto* r2_imm = r_squared.as<IntImmNode>();
  if (c_in_imm != nullptr && r2_imm != nullptr) {
    ICHECK_EQ(c_in_imm->value % r2_imm->value, 0)
        << "Number of input channels must be divisible by the square of the upscale factor";
  }

  // Output shape: leading dimensions are copied, the last three are rewritten.
  // floordiv matches the channel computation used by the TOPI compute and
  // folds to the same constant as plain division for static shapes.
  Array<PrimExpr> out_shape = in_shape;
  out_shape.Set(channel_idx, floordiv(c_in, r_squared));
  out_shape.Set(h_idx, h_in * r_expr);
  out_shape.Set(w_idx, w_in * r_expr);

  return TensorStructInfo(ShapeExpr(out_shape), input->dtype);
}
289+
290+
// Operator registration for "relax.nn.pixel_shuffle": one tensor input named
// "data", attributes of type PixelShuffleAttrs, struct info inferred by
// InferStructInfoPixelShuffle, and the "FPurity" attribute set to true.
TVM_REGISTER_OP("relax.nn.pixel_shuffle")
291+
.set_num_inputs(1)
292+
.add_argument("data", "Tensor", "The input tensor.")
293+
.set_attrs_type<PixelShuffleAttrs>()
294+
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoPixelShuffle)
295+
.set_attr<Bool>("FPurity", Bool(true));
296+
227297
/* relax.nn.batchnorm */
228298
bool NormCheckDtypeAndShape(const Call& call, const BlockBuilder& ctx,
229299
const Array<TensorStructInfo>& input_sinfo, Array<Integer> axes) {

src/relax/op/nn/nn.h

+3
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,9 @@ Expr softplus(Expr data, double beta, double threshold);
7575
/*! \brief LogSoftmax function. */
7676
Expr log_softmax(Expr data, int axis);
7777

78+
/*!
 * \brief Pixel Shuffle function.
 * \param data The input expression.
 * \param upscale_factor The integer upscale factor.
 * \return The resulting expression.
 */
79+
Expr pixel_shuffle(Expr data, int upscale_factor);
80+
7881
/*! \brief Compute batch normalization. */
7982
Expr batch_norm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr moving_var, //
8083
int axis, double epsilon, bool center, bool scale, double momentum, bool training);

tests/python/relax/test_frontend_from_exported_program.py

+36
Original file line numberDiff line numberDiff line change
@@ -2017,6 +2017,42 @@ def main(
20172017
verify_model(PadModel(pad=[1, 1, 2, 2], mode="circular"), example_args, {}, expected_circular)
20182018

20192019

2020+
def test_pixel_shuffle():
    """Module and functional pixel shuffle must both import as R.nn.pixel_shuffle."""

    class ModuleForm(torch.nn.Module):
        # Uses the torch.nn.PixelShuffle layer.
        def __init__(self, upscale_factor=2):
            super().__init__()
            self.pixel_shuffle = torch.nn.PixelShuffle(upscale_factor)

        def forward(self, x):
            return self.pixel_shuffle(x)

    class FunctionalForm(torch.nn.Module):
        # Uses torch.nn.functional.pixel_shuffle directly.
        def __init__(self, upscale_factor=2):
            super().__init__()
            self.upscale_factor = upscale_factor

        def forward(self, x):
            return torch.nn.functional.pixel_shuffle(x, self.upscale_factor)

    @tvm.script.ir_module
    class expected:
        @R.function
        def main(
            x: R.Tensor((1, 8, 10, 15), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 2, 20, 30), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((1, 2, 20, 30), dtype="float32") = R.nn.pixel_shuffle(
                    x, upscale_factor=2
                )
                gv: R.Tuple(R.Tensor((1, 2, 20, 30), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    example_args = (torch.randn(1, 8, 10, 15, dtype=torch.float32),)
    # Both variants are expected to produce the same Relax module.
    for model_cls in (ModuleForm, FunctionalForm):
        verify_model(model_cls(upscale_factor=2), example_args, {}, expected)
2054+
2055+
20202056
def test_einsum():
20212057
class Einsum1(Module):
20222058
def __init__(self):

tests/python/relax/test_frontend_from_fx.py

+36
Original file line numberDiff line numberDiff line change
@@ -592,6 +592,42 @@ def main(
592592
verify_model(PadModel(pad=[1, 1, 2, 2], mode="circular"), input_infos, {}, expected_circular)
593593

594594

595+
def test_pixel_shuffle():
    """Module and functional pixel shuffle must both import as R.nn.pixel_shuffle."""

    class ModuleForm(torch.nn.Module):
        # Uses the torch.nn.PixelShuffle layer.
        def __init__(self, upscale_factor=2):
            super().__init__()
            self.pixel_shuffle = torch.nn.PixelShuffle(upscale_factor)

        def forward(self, x):
            return self.pixel_shuffle(x)

    class FunctionalForm(torch.nn.Module):
        # Uses torch.nn.functional.pixel_shuffle directly.
        def __init__(self, upscale_factor=2):
            super().__init__()
            self.upscale_factor = upscale_factor

        def forward(self, x):
            return torch.nn.functional.pixel_shuffle(x, self.upscale_factor)

    @tvm.script.ir_module
    class expected:
        @R.function
        def main(
            inp_0: R.Tensor((1, 8, 10, 15), dtype="float32")
        ) -> R.Tensor((1, 2, 20, 30), dtype="float32"):
            with R.dataflow():
                lv: R.Tensor((1, 2, 20, 30), dtype="float32") = R.nn.pixel_shuffle(
                    inp_0, upscale_factor=2
                )
                gv: R.Tensor((1, 2, 20, 30), dtype="float32") = lv
                R.output(gv)
            return gv

    input_infos = [([1, 8, 10, 15], "float32")]
    # Both variants are expected to produce the same Relax module.
    for model_cls in (ModuleForm, FunctionalForm):
        verify_model(model_cls(2), input_infos, {}, expected)
629+
630+
595631
def test_linear():
596632
# nn.Linear
597633
class Dense1(Module):

0 commit comments

Comments
 (0)