Skip to content

Commit f9964c3

Browse files
committed
feature: template
1 parent 8d53459 commit f9964c3

File tree

8 files changed

+158
-2
lines changed

8 files changed

+158
-2
lines changed

README_CN.md

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,29 @@ pip install -r requirements.txt
5454

5555
## 开始使用
5656

57-
### 初始化git仓库
57+
### 新建项目文件夹并初始化git仓库
58+
59+
```shell
60+
mkdir my_deeplearning_project
61+
cd my_deeplearning_project
62+
git init
63+
```
5864

5965
### 添加easytorch子仓库
6066

61-
###
67+
```shell
68+
git submodule add https://github.com/cnstark/easytorch.git
69+
git add .
70+
git commit -m "init by easytorch"
71+
```
72+
73+
### 复制EasyTorch模板至工作目录
74+
75+
```shell
76+
cp easytorch/examples/template/* .
77+
```
78+
79+
*接下来就可以使用EasyTorch构建你的深度学习项目。*
6280

6381
## README 徽章
6482

examples/template/configs/__init__.py

Whitespace-only changes.
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import os
2+
from easydict import EasyDict
3+
4+
from runners.runner_template import RunnerTemplate
5+
6+
CFG = EasyDict()
7+
8+
CFG.DESC = 'config template'
9+
CFG.RUNNER = RunnerTemplate
10+
CFG.GPU_NUM = 1
11+
12+
CFG.MODEL = EasyDict()
13+
CFG.MODEL.NAME = 'model_template'
14+
15+
CFG.TRAIN = EasyDict()
16+
17+
CFG.TRAIN.NUM_EPOCHS = 100
18+
CFG.TRAIN.CKPT_SAVE_DIR = os.path.join(
19+
'checkpoints',
20+
'_'.join([CFG.MODEL.NAME, str(CFG.TRAIN.NUM_EPOCHS)])
21+
)
22+
CFG.TRAIN.CKPT_SAVE_STRATEGY = None
23+
24+
CFG.TRAIN.OPTIM = EasyDict()
25+
CFG.TRAIN.OPTIM.TYPE = 'SGD'
26+
CFG.TRAIN.OPTIM.PARAM = {
27+
'lr': 0.002,
28+
'momentum': 0.1,
29+
}
30+
31+
CFG.TRAIN.DATA = EasyDict()
32+
CFG.TRAIN.DATA.BATCH_SIZE = 4
33+
CFG.TRAIN.DATA.SHUFFLE = True
34+
35+
CFG.VAL = EasyDict()
36+
37+
CFG.VAL.INTERVAL = 1
38+
39+
CFG.VAL.DATA = EasyDict()
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from .model_template import ModelTemplate
2+
3+
4+
MODEL_DICT = {
5+
'model_template': ModelTemplate
6+
# other models...
7+
}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from torch import nn
2+
3+
4+
class ModelTemplate(nn.Module):
5+
def __init__(self):
6+
super().__init__()
7+
self.op = nn.Sequential(
8+
nn.Conv2d(1, 20, 5),
9+
nn.ReLU(inplace=True)
10+
)
11+
12+
def forward(self, x):
13+
return self.op(x)

examples/template/runners/__init__.py

Whitespace-only changes.
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import torch
2+
from torch import nn
3+
from torch.utils.data import Dataset
4+
5+
from easytorch import Runner
6+
7+
from ..models import MODEL_DICT
8+
9+
10+
class RunnerTemplate(Runner):
11+
def __init__(self, cfg: dict):
12+
super().__init__(cfg)
13+
14+
def init_training(self, cfg: dict):
15+
super().init_training(cfg)
16+
17+
# init loss
18+
# e.g.
19+
# self.loss = nn.MSELoss()
20+
# self.loss = self.to_running_device(self.loss)
21+
22+
# register meters by calling:
23+
# self.register_epoch_meter('train_loss', 'train', '{:.2f}')
24+
25+
def init_validation(self, cfg: dict):
26+
super().init_validation(cfg)
27+
28+
# self.register_epoch_meter('val_acc', 'val', '{:.2f}%')
29+
30+
@staticmethod
31+
def define_model(cfg: dict) -> nn.Module:
32+
return MODEL_DICT[cfg['MODEL']['NAME']](**cfg['MODEL'].get('PARAM', {}))
33+
34+
@staticmethod
35+
def build_train_dataset(cfg: dict) -> Dataset:
36+
# return your train Dataset
37+
pass
38+
39+
@staticmethod
40+
def build_val_dataset(cfg: dict):
41+
# return your val Dataset
42+
pass
43+
44+
def train_iters(self, epoch: int, iter_index: int, data: torch.Tensor or tuple) -> torch.Tensor:
45+
# forward and compute loss
46+
# update meters if necessary
47+
# return loss (will be auto backward and update params) or don't return any thing
48+
49+
# e.g.
50+
# _input, _target = data
51+
# _input = self.to_running_device(_input)
52+
# _target = self.to_running_device(_target)
53+
#
54+
# output = self.model(_input)
55+
# loss = self.loss(output, _target)
56+
# self.update_epoch_meter('train_loss', loss.item())
57+
# return loss
58+
pass
59+
60+
def val_iters(self, iter_index: int, data: torch.Tensor or tuple):
61+
# forward and update meters
62+
pass

examples/template/train.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from argparse import ArgumentParser
2+
3+
from easytorch.easytorch import launch_training
4+
5+
6+
def parse_args():
7+
parser = ArgumentParser(description='Welcome to EasyTorch!')
8+
parser.add_argument('-c', '--cfg', help='training config', required=True)
9+
parser.add_argument('--gpus', help='visible gpus', type=str)
10+
parser.add_argument('--tf32', help='enable tf32 on Ampere device', action='store_true')
11+
return parser.parse_args()
12+
13+
14+
if __name__ == "__main__":
15+
args = parse_args()
16+
17+
launch_training(args.cfg, args.gpus, args.tf32)

0 commit comments

Comments
 (0)