diff --git a/README_CN.md b/README_CN.md index 1abf244..cd69fbc 100644 --- a/README_CN.md +++ b/README_CN.md @@ -8,7 +8,7 @@ --- -Easytorch是一个基于PyTorch的开源神经网络训练框架,封装了PyTorch项目中常用的功能,帮助用户快速构建深度学习项目。 +EasyTorch是一个基于PyTorch的开源神经网络训练框架,封装了PyTorch项目中常用的功能,帮助用户快速构建深度学习项目。 ## 功能亮点 @@ -35,9 +35,7 @@ python >= 3.6 (推荐 >= 3.7) ### PyTorch及CUDA -[pytorch](https://pytorch.org/) >= 1.4(推荐 >= 1.7) - -[CUDA](https://developer.nvidia.com/zh-cn/cuda-toolkit) >= 9.2 (推荐 >= 11.0) +[pytorch](https://pytorch.org/) >= 1.4(推荐 >= 1.7),如需使用CUDA,安装PyTorch时选择对应CUDA版本编译的包。 注意:如需使用安培(Ampere)架构GPU,PyTorch版本需 >= 1.7且CUDA版本 >= 11.0。 @@ -51,6 +49,23 @@ pip install -r requirements.txt * [线性回归](examples/linear_regression) * [MNIST手写数字识别](examples/mnist) +* [ImageNet图像分类](examples/imagenet) + +## 开始使用 + +### 安装EasyTorch + +```shell +pip install easy-torch +``` + +### 复制EasyTorch模板至工作目录 + +```shell +cp -r easytorch/examples/template/* . +``` + +*接下来就可以使用EasyTorch构建你的深度学习项目。* ## README 徽章 diff --git a/docs/config.md b/docs/config.md new file mode 100644 index 0000000..e69de29 diff --git a/examples/template/configs/__init__.py b/examples/template/configs/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/examples/template/configs/config_template.py b/examples/template/configs/config_template.py new file mode 100644 index 0000000..9aec2e7 --- /dev/null +++ b/examples/template/configs/config_template.py @@ -0,0 +1,39 @@ +import os +from easydict import EasyDict + +from runners.runner_template import RunnerTemplate + +CFG = EasyDict() + +CFG.DESC = 'config template' +CFG.RUNNER = RunnerTemplate +CFG.GPU_NUM = 1 + +CFG.MODEL = EasyDict() +CFG.MODEL.NAME = 'model_template' + +CFG.TRAIN = EasyDict() + +CFG.TRAIN.NUM_EPOCHS = 100 +CFG.TRAIN.CKPT_SAVE_DIR = os.path.join( + 'checkpoints', + '_'.join([CFG.MODEL.NAME, str(CFG.TRAIN.NUM_EPOCHS)]) +) +CFG.TRAIN.CKPT_SAVE_STRATEGY = None + +CFG.TRAIN.OPTIM = EasyDict() +CFG.TRAIN.OPTIM.TYPE = 'SGD' +CFG.TRAIN.OPTIM.PARAM = { + 'lr': 0.002, + 'momentum': 0.1, +} + +CFG.TRAIN.DATA = EasyDict() +CFG.TRAIN.DATA.BATCH_SIZE = 4 +CFG.TRAIN.DATA.SHUFFLE = True + +CFG.VAL = EasyDict() + +CFG.VAL.INTERVAL = 1 + +CFG.VAL.DATA = EasyDict() diff --git a/examples/template/models/__init__.py b/examples/template/models/__init__.py new file mode 100644 index 0000000..8c21dab --- /dev/null +++ b/examples/template/models/__init__.py @@ -0,0 +1,7 @@ +from .model_template import ModelTemplate + + +MODEL_DICT = { + 'model_template': ModelTemplate + # other models... +} diff --git a/examples/template/models/model_template.py b/examples/template/models/model_template.py new file mode 100644 index 0000000..869beae --- /dev/null +++ b/examples/template/models/model_template.py @@ -0,0 +1,13 @@ +from torch import nn + + +class ModelTemplate(nn.Module): + def __init__(self): + super().__init__() + self.op = nn.Sequential( + nn.Conv2d(1, 20, 5), + nn.ReLU(inplace=True) + ) + + def forward(self, x): + return self.op(x) diff --git a/examples/template/runners/__init__.py b/examples/template/runners/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/examples/template/runners/runner_template.py b/examples/template/runners/runner_template.py new file mode 100644 index 0000000..05a7a34 --- /dev/null +++ b/examples/template/runners/runner_template.py @@ -0,0 +1,64 @@ +from typing import Dict + +import torch +from torch import nn +from torch.utils.data import Dataset + +from easytorch import Runner + +from ..models import MODEL_DICT + + +class RunnerTemplate(Runner): + def __init__(self, cfg: Dict): + super().__init__(cfg) + + def init_training(self, cfg: Dict): + super().init_training(cfg) + + # init loss + # e.g. + # self.loss = nn.MSELoss() + # self.loss = self.to_running_device(self.loss) + + # register meters by calling: + # self.register_epoch_meter('train_loss', 'train', '{:.2f}') + + def init_validation(self, cfg: Dict): + super().init_validation(cfg) + + # self.register_epoch_meter('val_acc', 'val', '{:.2f}%') + + @staticmethod + def define_model(cfg: Dict) -> nn.Module: + return MODEL_DICT[cfg['MODEL']['NAME']](**cfg['MODEL'].get('PARAM', {})) + + @staticmethod + def build_train_dataset(cfg: Dict) -> Dataset: + # return your train Dataset + pass + + @staticmethod + def build_val_dataset(cfg: Dict): + # return your val Dataset + pass + + def train_iters(self, epoch: int, iter_index: int, data: torch.Tensor or tuple) -> torch.Tensor: + # forward and compute loss + # update meters if necessary + # return loss (will be auto backward and update params) or don't return any thing + + # e.g. + # _input, _target = data + # _input = self.to_running_device(_input) + # _target = self.to_running_device(_target) + # + # output = self.model(_input) + # loss = self.loss(output, _target) + # self.update_epoch_meter('train_loss', loss.item()) + # return loss + pass + + def val_iters(self, iter_index: int, data: torch.Tensor or tuple): + # forward and update meters + pass diff --git a/requirements.txt b/requirements.txt index 42747b4..9650486 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ torch>=1.7 +torchvision easydict tensorboard tqdm