// graph_compiler.cpp
//
// Command-line driver: parses an MLIR module from an input file, runs the
// gateway graph-to-kernel pass over it, and prints the resulting module.
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/FileUtilities.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/SourceMgr.h"
#include "Gateway.h"
namespace gateway {
/// Forward declaration of the pass-registration hook that main() calls.
/// The definition is not in this file; it is expected to be supplied by
/// another translation unit at link time.
void registerGraphToKernelPass();
} // namespace gateway
int main(int argc, char **argv) {
mlir::registerAsmPrinterCLOptions();
mlir::registerMLIRContextCLOptions();
mlir::registerPassManagerCLOptions();
llvm::cl::opt<std::string> inputFilename(llvm::cl::Positional, llvm::cl::desc("<input file>"), llvm::cl::init("-"));
llvm::cl::ParseCommandLineOptions(argc, argv, "Graph Compiler\n");
mlir::MLIRContext context;
context.getOrLoadDialect<gateway::GatewayDialect>();
context.getOrLoadDialect<mlir::func::FuncDialect>();
context.getOrLoadDialect<mlir::arith::ArithDialect>();
context.getOrLoadDialect<mlir::memref::MemRefDialect>();
mlir::OwningOpRef<mlir::ModuleOp> module;
llvm::SourceMgr sourceMgr;
mlir::SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context);
std::string errorMessage;
auto file = mlir::openInputFile(inputFilename, &errorMessage);
if (!file) {
llvm::errs() << errorMessage << "\n";
return 1;
}
sourceMgr.AddNewSourceBuffer(std::move(file), llvm::SMLoc());
module = mlir::parseSourceFile<mlir::ModuleOp>(sourceMgr, &context);
if (!module) {
llvm::errs() << "Error can't load file " << inputFilename << "\n";
return 1;
}
mlir::PassManager pm(&context);
gateway::registerGraphToKernelPass();
pm.addPass(mlir::createModuleToFunctionPassAdaptor(gateway::createGraphToKernelPass()));
if (mlir::failed(pm.run(*module))) {
llvm::errs() << "Error running pass manager\n";
return 1;
}
module->print(llvm::outs());
return 0;
}