Skip to content

Commit

Permalink
add mse_loss_backward (llvm#2111)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidgens-cerebras authored May 12, 2023
1 parent de02b56 commit 17db2aa
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 0 deletions.
26 changes: 26 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -5795,6 +5795,32 @@ def Torch_AtenMseLossOp : Torch_Op<"aten.mse_loss", [
}];
}

// ODS record for `aten::mse_loss_backward`, the gradient op of `aten.mse_loss`.
// NOTE(review): this file is machine-generated (see the Python emitter that
// registers "aten::mse_loss_backward : (Tensor, Tensor, Tensor, int) -> (Tensor)");
// edits here should come from regenerating, not hand modification.
def Torch_AtenMseLossBackwardOp : Torch_Op<"aten.mse_loss_backward", [
    AllowsTypeRefinement,
    // No in-place aliasing: the op produces a fresh result tensor.
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "Generated op for `aten::mse_loss_backward : (Tensor, Tensor, Tensor, int) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$grad_output,
    AnyTorchTensorType:$self,
    AnyTorchTensorType:$target,
    // Reduction mode as an integer (presumably 0=none, 1=mean, 2=sum per
    // PyTorch convention — confirm against the upstream schema).
    Torch_IntType:$reduction
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    // Shared default parser/printer for torch ops: 4 operands, 1 result.
    ParseResult AtenMseLossBackwardOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 4, 1);
    }
    void AtenMseLossBackwardOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 4, 1);
    }
  }];
}

def Torch_AtenUpsampleNearest2dBackwardOp : Torch_Op<"aten.upsample_nearest2d_backward", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down
9 changes: 9 additions & 0 deletions python/torch_mlir/csrc/base_lazy_backend/shape_inference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,15 @@ compute_shape_div(const at::Tensor& self, const at::Scalar& other) {
return {Shape(self.scalar_type(), self.sizes().vec())};
}

/// Shape-inference hook for `aten::mse_loss_backward` on the lazy backend.
///
/// Reports the gradient's shape as identical to `self` (same rank, sizes,
/// and scalar type). `grad_output`, `target`, and `reduction` do not
/// influence the inferred shape here — the gradient w.r.t. the input is
/// elementwise over `self`, so only `self`'s metadata is consulted.
std::vector<torch::lazy::Shape>
compute_shape_mse_loss_backward(
    const at::Tensor& grad_output,
    const at::Tensor& self,
    const at::Tensor& target,
    int64_t reduction) {
  const auto result_dtype = self.scalar_type();
  const auto result_sizes = self.sizes().vec();
  return {Shape(result_dtype, result_sizes)};
}

std::vector<torch::lazy::Shape>
compute_shape_mul(const at::Tensor& self, const at::Scalar& other) {
return {Shape(self.scalar_type(), self.sizes().vec())};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::linalg_vector_norm : (Tensor, Scalar, int[]?, bool, int?) -> (Tensor)")
emit("aten::frobenius_norm.dim : (Tensor, int[], bool) -> (Tensor)")
emit("aten::mse_loss : (Tensor, Tensor, int) -> (Tensor)")
emit("aten::mse_loss_backward : (Tensor, Tensor, Tensor, int) -> (Tensor)")
emit("aten::upsample_nearest2d_backward : (Tensor, int[], int[], float?, float?) -> (Tensor)")
emit("aten::cross_entropy_loss : (Tensor, Tensor, Tensor?, int, int, float) -> (Tensor)")

Expand Down

0 comments on commit 17db2aa

Please sign in to comment.