当前项目已提供基于 torch-pruning 的结构化通道剪枝 + 微调框架,支持从基座模型 checkpoint 出发,完成:
- 按模型名读取基座模型根目录下的
best_model.pth符号链接 - 恢复默认模型结构并严格加载权重
- 执行多轮 iterative structured pruning
- 每轮剪枝后进行微调恢复(可选)
- 仅在最终轮保存包含权重、拓扑与元数据的剪枝 checkpoint
当前阶段暂不包含 ONNX/ATC 导出,也不包含 QAT。
剪枝阶段当前复用基座模型的 Warmup + Cosine 调度器实现,但默认超参已针对“剪枝后微调恢复”做了下调。
base checkpoint
-> 恢复基座模型
-> iterative structured pruning
-> 每轮提取 channel_cfg / architecture_signature
-> 每轮微调恢复(可选)
-> 仅最终轮保存 pruning checkpoint
torch-pruning结构化通道剪枝- iterative pruning 多轮剪枝
- 基座模型按
output/base_model/<model>/best_model.pth自动解析 - 剪枝后拓扑通过
channel_cfg显式保存 - 剪枝 checkpoint 作为后续 QAT 的上游输入
项目根目录执行:
uv run src/pruning_main.py --help| 参数 | 默认值 | 说明 |
|---|---|---|
--model |
必填 | 基座模型名,将自动解析 output/base_model/<model>/best_model.pth |
--model_path |
best_pruned_model.pth |
剪枝后最佳模型文件名 |
--data_dir |
Data |
数据集路径 |
--data_dtype |
fp16 |
数据集输出 tensor 精度,可选 fp16 / fp32 |
--full_load |
False |
是否全量加载数据集 |
--num_workers |
None |
DataLoader 工作线程数 |
--prefetch_factor |
2 |
DataLoader 预取因子 |
--persistent_workers |
True |
是否保持 DataLoader 工作线程 |
--pin_memory |
True |
是否启用 pin_memory |
--pruning_ratio |
0.3 |
最终总剪枝率,会按十进制四舍五入规范到 2 位小数 |
--pruning_steps |
5 |
iterative pruning 的剪枝轮数 |
--global_pruning |
True |
是否启用全局剪枝 |
--ignore_fc |
True |
是否默认忽略分类头 |
--finetune_epochs |
10 |
每轮剪枝后的微调轮数 |
--batch_size |
64 |
批次大小 |
--lr |
1e-4 |
微调学习率 |
--weight_decay |
1e-4 |
权重衰减 |
--warmup_ratio |
0.05 |
Warmup 占总步数比例 |
--warmup_steps |
0 |
Warmup 步数,0 表示使用 warmup_ratio |
--min_lr |
1e-7 |
最小学习率 |
--cudnn_benchmark |
True |
是否启用 cuDNN benchmark |
--cudnn_deterministic |
False |
是否启用 cuDNN 确定性算法 |
--evaluate_test |
True |
微调结束后是否评估测试集 |
uv run src/pruning_main.py --model resnet6_2duv run src/pruning_main.py \
--model resnet18_2d \
--pruning_ratio 0.25 \
--pruning_steps 5 \
--global_pruning False \
--finetune_epochs 12uv run src/pruning_main.py \
--model resnet14_2d \
--finetune_epochs 0 \
--evaluate_test False剪枝入口不会手动接收外部 checkpoint 路径,而是固定读取:
output/base_model/<model>/best_model.pth
要求:
- 该路径存在
- 若为符号链接,则链接必须可正常解析
- 加载后的 checkpoint 中
model_structure.model_name必须与--model一致
若上述任一条件不满足,剪枝入口会直接报错退出。
output/pruning/<model>/ratio<ratio>_steps<steps>_<global|local>_ft<epochs>_bs<batch_size>/
典型产物包括:
best_pruned_model.pthbest_pruned_info.txt(每轮一行最佳摘要)pruning_summary.jsonConfusion_matrix.png(仅最终测试阶段生成)runs/
model_state_dictmodel_structuremodel_namemodel_kwargschannel_cfgarchitecture_signature
pruning_metapruning_stepstarget_total_ratiostep_ratioglobal_pruningignored_layersexample_input_shapetorch_pruning_versionparams_before / params_aftermacs_before / macs_after
train_contextcheckpoint_link_pathresolved_checkpoint_path
best_accbest_val_loss
- 剪枝阶段当前读取的是基座模型 checkpoint,因此模型恢复入口优先使用默认
load_model_map()。 --pruning_ratio的有效精度固定为 2 位小数;summary、checkpoint 和输出目录都会使用同一个规范值。- 剪枝后的完整拓扑通过实际模型提取,不依赖默认模板反推。
- 多轮剪枝过程中,中间轮的最佳权重只保留在内存中作为下一轮输入,不落盘。
- 若开启
--evaluate_test,仅最终模型测试阶段会生成一次Confusion_matrix.png。 - 后续 QAT 可直接以剪枝 checkpoint 中保存的
channel_cfg和权重为输入继续恢复。