Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
*.swp
figure/

led / optimized / DLL files
__pycache__/
Expand Down Expand Up @@ -151,6 +152,7 @@ events*
# logger
logger
searched_result
*searched_result

# csv file
*.csv
Expand Down
16 changes: 12 additions & 4 deletions architecture_main_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
help='gpu number to use')
parser.add_argument('--dataset', type=str, default='cifar10', \
help='using dataset')
parser.add_argument('--supernet_type', type=str, default='mobilenetv2', \
help='supernet type')

# SGD optimizer - weight
parser.add_argument('--lr', type=float, default=0.1, \
Expand Down Expand Up @@ -126,11 +128,14 @@ def main():
# flops & param
fnp = args.fnp
if args.dataset == 'cifar10':
model = fbnet_builder.get_model(arch, cnt_classes=10).cuda()
if args.supernet_type == 'resnet_torchvision':
model = fbnet_builder.resnet18(pretrained=False, progress=False).cuda()
else:
model = fbnet_builder.get_model(arch, cnt_classes=10, supernet_type=args.supernet_type).cuda()
elif args.dataset == 'cifar100':
model = fbnet_builder.get_model(arch, cnt_classes=100).cuda()
model = fbnet_builder.get_model(arch, cnt_classes=100, supernet_type=args.supernet_type).cuda()
elif args.dataset == 'tiny_imagenet':
model = fbnet_builder.get_model(arch, cnt_classes=200).cuda()
model = fbnet_builder.get_model(arch, cnt_classes=200, supernet_type=args.supernet_type).cuda()

model = model.apply(weights_init)

Expand Down Expand Up @@ -166,7 +171,10 @@ def main():
compression_scheduler, optimizer = convert_model_to_quant(model.module.stages, yaml_path)
else:
compression_scheduler = None
print(model)

#print(model)
#print(summary(model, input_size=(3, 32, 32)))

#### Scheduler
if args.scheduler == 'MultiStepLR':
milestones = args.milestones.split(' ')
Expand Down
27 changes: 14 additions & 13 deletions architecture_ploting(B&G).ipynb

Large diffs are not rendered by default.

Loading