-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathop_sched.cpp
58 lines (50 loc) · 2.24 KB
/
op_sched.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
#include <tensorflow/lite/simple_memory_arena.h>
#include <chrono>
#include <filesystem>
#include <fstream>
#include <hmcos/sched/life.hpp>
#include <hmcos/sched/pass.hpp>
#include <hmcos/sched/plan.hpp>
#include <hmcos/sched/sched.hpp>
#include <hmcos/util/viz.hpp>
using namespace hmcos;
using namespace std::chrono;
/// Executes `code` and logs how long it took, in whole milliseconds, at INFO
/// level using the format "{} ms".
///
/// The interval is measured with std::chrono::steady_clock: it is monotonic,
/// so the reported duration cannot go negative or jump when the system
/// wall-clock is adjusted while `code` is running.
///
/// The expansion is a plain brace block (not do/while), so a call site may
/// omit the trailing semicolon. Names declared inside `code` are scoped to
/// that block and are not visible after the macro.
#define TIME_CODE(code)                                                  \
    {                                                                    \
        const auto timeCodeBegin_ = std::chrono::steady_clock::now();    \
        code;                                                            \
        const auto timeCodeMs_ =                                         \
            std::chrono::duration_cast<std::chrono::milliseconds>(       \
                std::chrono::steady_clock::now() - timeCodeBegin_)       \
                .count();                                                \
        LOG(INFO) << fmt::format("{} ms", timeCodeMs_);                  \
    }
/// Computes the buffer size a tflite::SimpleMemoryArena requires to hold every
/// value described by `stat`.
///
/// Each entry of `stat.values` is registered with the arena using its type
/// size, its index as the tensor id, and the node interval
/// [val.gen, val.kill - 1].
/// NOTE(review): the `- 1` suggests `kill` is an exclusive bound — confirm
/// against LifetimeStat.
///
/// @param stat Lifetime statistics of the values of a schedule.
/// @return The arena's RequiredBufferSize() after all allocations; callers in
///         this file divide it by 1024 and report it as KB.
static uint64_t computeArenaSize(const LifetimeStat &stat) {
    // Both the arena alignment and each allocation's alignment use this value.
    constexpr int kAlignment = 64;

    std::vector<tflite::ArenaAllocWithUsageInterval> allocs(stat.values.size());

    // Value-initialize the context: its address is handed to the arena, and a
    // default-initialized C struct would expose indeterminate pointers to any
    // code path in the library that touches it.
    TfLiteContext ctx{};

    tflite::SimpleMemoryArena arena(kAlignment);
    for (auto [i, val] : EnumRange(stat.values)) {
        arena.Allocate(&ctx, kAlignment, val.value->type.Size(), i, val.gen,
                       val.kill - 1, &allocs[i]);
    }
    return arena.RequiredBufferSize();
}
int main(int argc, char const *argv[]) {
// Initialize glog
FLAGS_minloglevel = 0;
google::LogToStderr();
google::InitGoogleLogging(argv[0]);
// Build compitation graph from ONNX model
std::ifstream ifs(argv[1], std::ifstream::binary);
onnx::ModelProto model;
model.ParseFromIstream(&ifs);
ifs.close();
Graph graph(model, std::filesystem::path(argv[1]).stem().string());
model.Clear();
// Schedule hierarchical graph
std::vector<OpRef> sched;
TIME_CODE(sched = HierarchicalSchedule(graph);)
LOG(INFO) << "HMCOS Peak: " << EstimatePeak(sched, graph.inputs) / 1024 << " KB";
LOG(INFO) << "HMCOS Arena Size: " << computeArenaSize(ComputeLifetime(sched, graph)) / 1024 << " KB";
sched = ReversePostOrder(graph);
LOG(INFO) << "RPO Peak: " << EstimatePeak(sched, graph.inputs) / 1024 << " KB";
LOG(INFO) << "RPO Arena Size: " << computeArenaSize(ComputeLifetime(sched, graph)) / 1024 << " KB";
return 0;
}