Skip to content

Commit

Permalink
Additional information in error message (llvm#2783)
Browse files Browse the repository at this point in the history
See change in test for what the new message looks like.
  • Loading branch information
newling authored Jan 30, 2024
1 parent e18fceb commit 1e882f5
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 4 deletions.
22 changes: 19 additions & 3 deletions projects/jit_ir_common/csrc/jit_ir_importer/class_annotator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

#include "class_annotator.h"

#include <sstream>
#include <stdexcept>

using namespace torch_mlir;
Expand Down Expand Up @@ -150,11 +151,26 @@ ClassAnnotator::getOrCreateClassAnnotation(c10::ClassType *classType) {
}

static void fillArgAnnotations(MethodAnnotation &methodAnnotation,
std::vector<ArgAnnotation> argAnnotations,
const std::vector<ArgAnnotation> &argAnnotations,
torch::jit::Function *function) {
if (argAnnotations.size() != function->num_inputs()) {
throw std::invalid_argument("Arg annotations should have one entry per "
"function parameter (including self).");

std::ostringstream oss;
oss << "There must be one argument annotation per function parameter. "
<< "Including 'self' the number of argument annotations is: "
<< argAnnotations.size()
<< ". The number of function parameters is: " << function->num_inputs()
<< ". ";
const auto &args = function->getSchema().arguments();
if (args.size() > 0) {
oss << "The function signature is (";
oss << args[0];
for (auto iter = args.begin() + 1; iter != args.end(); iter++) {
oss << ", " << *iter;
}
oss << ')' << '.';
}
throw std::invalid_argument(oss.str());
}
if (!methodAnnotation.argAnnotations.has_value()) {
methodAnnotation.argAnnotations.emplace(function->num_inputs(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,10 @@ def forward(self, tensor):
try:
annotator.annotateArgs(class_type, ['forward'], [None])
except Exception as e:
# CHECK: Arg annotations should have one entry per function parameter (including self).
# CHECK: There must be one argument annotation per function parameter.
# CHECK-SAME: Including 'self' the number of argument annotations is: 1.
# CHECK-SAME: The number of function parameters is: 2.
# CHECK-SAME: The function signature is (__torch__.TestModule self, Tensor tensor)
print(e)

try:
Expand Down

0 comments on commit 1e882f5

Please sign in to comment.