@@ -224,6 +224,76 @@ TVM_REGISTER_OP("relax.nn.pad")
224
224
.set_attr<FInferStructInfo>(" FInferStructInfo" , InferStructInfoPad)
225
225
.set_attr<Bool>(" FPurity" , Bool(true ));
226
226
227
+ /* relax.nn.pixel_shuffle */
228
+ TVM_REGISTER_NODE_TYPE (PixelShuffleAttrs);
229
+
230
+ Expr pixel_shuffle (Expr data, int upscale_factor) {
231
+ auto attrs = make_object<PixelShuffleAttrs>();
232
+ attrs->upscale_factor = upscale_factor;
233
+ static const Op& op = Op::Get (" relax.nn.pixel_shuffle" );
234
+ return Call (op, {data}, Attrs (attrs), {});
235
+ }
236
+
237
+ TVM_REGISTER_GLOBAL (" relax.op.nn.pixel_shuffle" ).set_body_typed(pixel_shuffle);
238
+
239
+ StructInfo InferStructInfoPixelShuffle (const Call& call, const BlockBuilder& ctx) {
240
+ Array<TensorStructInfo> input_sinfo = GetInputTensorStructInfo (call, ctx);
241
+ const auto * attrs = call->attrs .as <PixelShuffleAttrs>();
242
+ int r = attrs->upscale_factor ;
243
+ ICHECK_GT (r, 0 ) << " Upscale factor must be positive" ;
244
+
245
+ const TensorStructInfo& input = input_sinfo[0 ];
246
+ int ndim = input->ndim ;
247
+ ICHECK_GE (ndim, 3 ) << " PixelShuffle requires at least 3D input tensor" ;
248
+
249
+ if (!input->shape .defined ()) {
250
+ return TensorStructInfo (input->dtype , ndim);
251
+ }
252
+
253
+ const auto * shape = input->shape .as <ShapeExprNode>();
254
+ Array<PrimExpr> in_shape = shape->values ;
255
+
256
+ int channel_idx = ndim - 3 ;
257
+ int h_idx = ndim - 2 ;
258
+ int w_idx = ndim - 1 ;
259
+
260
+ PrimExpr c_in = in_shape[channel_idx];
261
+ PrimExpr h_in = in_shape[h_idx];
262
+ PrimExpr w_in = in_shape[w_idx];
263
+
264
+ PrimExpr r_expr = IntImm (DataType::Int (32 ), r);
265
+ PrimExpr r_squared = r_expr * r_expr;
266
+
267
+ const auto * c_in_imm = c_in.as <IntImmNode>();
268
+ const auto * r2_imm = r_squared.as <IntImmNode>();
269
+
270
+ ICHECK_EQ (c_in_imm->value % r2_imm->value , 0 )
271
+ << " Number of input channels must be divisible by the square of the upscale factor" ;
272
+
273
+ // Output shape:
274
+ Array<PrimExpr> out_shape;
275
+ for (int i = 0 ; i < ndim; ++i) {
276
+ if (i == channel_idx) {
277
+ out_shape.push_back (c_in / r_squared);
278
+ } else if (i == h_idx) {
279
+ out_shape.push_back (h_in * r_expr);
280
+ } else if (i == w_idx) {
281
+ out_shape.push_back (w_in * r_expr);
282
+ } else {
283
+ out_shape.push_back (in_shape[i]);
284
+ }
285
+ }
286
+
287
+ return TensorStructInfo (ShapeExpr (out_shape), input->dtype );
288
+ }
289
+
290
+ TVM_REGISTER_OP (" relax.nn.pixel_shuffle" )
291
+ .set_num_inputs(1 )
292
+ .add_argument(" data" , " Tensor" , " The input tensor." )
293
+ .set_attrs_type<PixelShuffleAttrs>()
294
+ .set_attr<FInferStructInfo>(" FInferStructInfo" , InferStructInfoPixelShuffle)
295
+ .set_attr<Bool>(" FPurity" , Bool(true ));
296
+
227
297
/* relax.nn.batchnorm */
228
298
bool NormCheckDtypeAndShape (const Call& call, const BlockBuilder& ctx,
229
299
const Array<TensorStructInfo>& input_sinfo, Array<Integer> axes) {
0 commit comments