Skip to content

Commit

Permalink
Fix more
Browse files Browse the repository at this point in the history
  • Loading branch information
mgehre-amd committed Feb 10, 2025
1 parent a0ed6ee commit e26304c
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
12 changes: 6 additions & 6 deletions lib/Conversion/TorchToLinalg/Linear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -833,12 +833,6 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
op, "only support padding from a list construct");
paddingIntValues = getTypeConvertedValues(rewriter, loc, getTypeConverter(),
paddingIntValues);
if (paddingIntValues.size() !=
cast<RankedTensorType>(input.getType()).getRank() - 2) {
// pytorch 2.5 generates one element padding = {0} for
// Conv2dWithValidPaddingModule
return rewriter.notifyMatchFailure(op, "unexpected number of padding");
}
SmallVector<Value> outputPaddingIntValues;
if (!getListConstructElements(op.getOutputPadding(),
outputPaddingIntValues))
Expand Down Expand Up @@ -1013,6 +1007,12 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
strideInts.clear();
strideInts.append(numSpatialDims, 1);
} else {
if (paddingIntValues.size() + 2 !=
cast<RankedTensorType>(input.getType()).getRank()) {
// pytorch 2.5 generates one element padding = {0} for
// Conv2dWithValidPaddingModule
return rewriter.notifyMatchFailure(op, "unexpected number of padding");
}
// Pad input
paddedInput = torch_to_linalg::getDynamicZeroPaddedTensor(
op, rewriter, input, paddingIntValues, /*unpaddedDims=*/2, pad);
Expand Down
2 changes: 1 addition & 1 deletion projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,7 +570,7 @@
FX_IMPORTER_XFAIL_SET |= {
"AtenSubFloatModule_basic",
"Conv2dWithValidPaddingModule_basic",
"Conv2dWithSamePaddingModule_basic",
"Conv3dWithValidPaddingModule_basic",
"ElementwiseAddScalar_NumToTensorFloat_Module_basic",
"EqIntModule_basic",
"GeFloatModule_basic",
Expand Down

0 comments on commit e26304c

Please sign in to comment.