-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmain.py
23 lines (20 loc) · 774 Bytes
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
from model import Train
from model import TorchModel
from src import Evaluation
from data import data_generator
def main():
    """Train the torch model, save it, and evaluate the saved ONNX file.

    Runs the pipeline in a fixed order: generate the data splits, train on
    the training split, save the model, then run the onnxruntime, caffe2
    and tensorflow evaluators against the test split.
    """
    # Generate data samples for training and test
    x_train, x_test, y_train, y_test = data_generator(
        num_samples=100, visualize_plot=False
    )

    # Initialize training class (it only receives the training split)
    trainer = Train(x_train, y_train)
    # Train torch model
    trainer.train_torch_model()
    # Save model
    trainer.save_model()

    # Once the model is trained and saved as torch_model.onnx,
    # it will be loaded and evaluated with onnxruntime, caffe2 and tensorflow.
    # NOTE(review): this path is assumed to be where save_model() writes the
    # file -- confirm against the Train class.
    evaluation = Evaluation(x_test, y_test, 'onnx/torch_model.onnx')

    # Trigger the evaluators
    evaluation.onnxruntime_evaluation()
    evaluation.caffe2_evaluation()
    evaluation.tensorflow_evaluation()


if __name__ == '__main__':
    main()