diff --git a/.gitignore b/.gitignore index e3db027..a1a2b25 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ *.swp +figure/ # Byte-compiled / optimized / DLL files __pycache__/ @@ -151,6 +152,7 @@ events* # logger logger searched_result +*searched_result # csv file *.csv diff --git a/architecture_main_file.py b/architecture_main_file.py index 4e42eef..459eb37 100644 --- a/architecture_main_file.py +++ b/architecture_main_file.py @@ -33,6 +33,8 @@ help='gpu number to use') parser.add_argument('--dataset', type=str, default='cifar10', \ help='using dataset') +parser.add_argument('--supernet_type', type=str, default='mobilenetv2', \ + help='supernet type') # SGD optimizer - weight parser.add_argument('--lr', type=float, default=0.1, \ @@ -126,11 +128,14 @@ def main(): # flops & param fnp = args.fnp if args.dataset == 'cifar10': - model = fbnet_builder.get_model(arch, cnt_classes=10).cuda() + if args.supernet_type == 'resnet_torchvision': + model = fbnet_builder.resnet18(pretrained=False, progress=False).cuda() + else: + model = fbnet_builder.get_model(arch, cnt_classes=10, supernet_type=args.supernet_type).cuda() elif args.dataset == 'cifar100': - model = fbnet_builder.get_model(arch, cnt_classes=100).cuda() + model = fbnet_builder.get_model(arch, cnt_classes=100, supernet_type=args.supernet_type).cuda() elif args.dataset == 'tiny_imagenet': - model = fbnet_builder.get_model(arch, cnt_classes=200).cuda() + model = fbnet_builder.get_model(arch, cnt_classes=200, supernet_type=args.supernet_type).cuda() model = model.apply(weights_init) @@ -166,7 +171,10 @@ def main(): compression_scheduler, optimizer = convert_model_to_quant(model.module.stages, yaml_path) else: compression_scheduler = None - print(model) + + #print(model) + #print(summary(model, input_size=(3, 32, 32))) + #### Scheduler if args.scheduler == 'MultiStepLR': milestones = args.milestones.split(' ') diff --git a/architecture_ploting(B&G).ipynb b/architecture_ploting(B&G).ipynb index 2fdc1bc..118e546 100644 --- a/architecture_ploting(B&G).ipynb +++ b/architecture_ploting(B&G).ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 185, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -11,7 +11,7 @@ "%matplotlib inline\n", "#The line above is necessary to show Matplotlib's plots inside a Jupyter Notebook\n", - "import cv2\n", + "import cv2 as cv\n", "from matplotlib import pyplot as plt\n", "from supernet_functions.lookup_table_builder import CANDIDATE_BLOCKS\n", }, { "cell_type": "code", - "execution_count": 271, + "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "['fbnet_a', 'fbnet_b', 'fbnet_c', 'fbnet_96_035_1', 'fbnet_samsung_s8', 'fbnet_iphonex', 'fbnet_cpu_sample1', 'fbnet_cpu_sample2', 'mb2_example', 'mb2', 'test1', 'test2', 'test3', 'test4', 'test5', '0526_mb2', 'm2_orig', 'FBNet_DoReFa_w2a2', 'test_200616', 'testo', 'cifar10_ngumbel_600_schedule5_1_flop', 'cifar10_ngumbel_600_schedule5_1', 'cifar10_ngumbel_180_schedule5_1', 'cifar100_ngumbel_180_schedule5_1_flop', 'cifar100_ngumbel_180_schedule5_1', 'test', 'fluctuation_exp_4', 'ngumbel_test', 'test_eval_min01', 'test_eval_min1', 'test_eval2', 'test_eval_300_600', 'test_eval_min01_300_600', 'test_eval_min01_10_600', 'test_config', 'test_exp_min01_180', 'test_exp_min1_180', 'test_exp_min01_600', '0815_test_eval_min01_step_15', '0815_test_eval_min01_step_40', '0815_test_eval_min01_step_70', '0815_test_eval_min01_step_95', 
'0815_test_eval_min01_step_125', '0815_test_eval_min01_step_150', '0815_test_eval_min01_step_180', '0815_test_eval2_step_15', '0815_test_eval2_step_40', '0815_test_eval2_step_70', '0815_test_eval2_step_95', '0815_test_eval2_step_125', '0815_test_eval2_step_150', '0815_test_eval2_step_180', '0817_test_eval_min01_cifar100', '0817_test_eval_min1_cifar100', '0817_test_eval2_cifar100', '0817_test_eval2_cifar100_15', '0817_test_eval2_cifar100_40', '0817_test_eval2_cifar100_70', '0817_test_eval2_cifar100_95', '0817_test_eval2_cifar100_125', '0817_test_eval2_cifar100_150', '0817_test_eval2_cifar100_180', '0817_test_eval_min1_cifar100_15', '0817_test_eval_min1_cifar100_40', '0817_test_eval_min1_cifar100_70', '0817_test_eval_min1_cifar100_95', '0817_test_eval_min1_cifar100_125', '0817_test_eval_min1_cifar100_150', '0817_test_eval_min1_cifar100_180', '0817_test_eval_min01_cifar100_15', '0817_test_eval_min01_cifar100_40', '0817_test_eval_min01_cifar100_70', '0817_test_eval_min01_cifar100_95', '0817_test_eval_min01_cifar100_125', '0817_test_eval_min01_cifar100_150', '0817_test_eval_min01_cifar100_180', '0819_orig_gumbel_N_1000_reg_1e_5_sampling', '0819_orig_gumbel_N_500_reg_1e_4_sampling', '0819_orig_gumbel_N_500_reg_1e_5_sampling', '0819_orig_gumbel_N_1000_reg_1e_4_sampling', '0819_orig_gumbel_N_500_reg_1e_4', '0819_orig_gumbel_N_1000_reg_1e_5', '0819_orig_gumbel_N_1000_reg_1e_4', '0819_orig_gumbel_N_500_reg_1e_5', 'm2_orig_TI', '0824_img200', '0824_img200_step_15', '0824_img200_step_70', '0824_img200_step_125', '0824_img200_step_180', '0902_img200_rs1', '0902_img200_rs2', '0902_img200_rs3', '0903_cifar10_rs1', '0903_cifar10_rs2', '0903_cifar10_rs3', '0903_cifar10_rs4', '0903_cifar10_rs5', '0903_cifar10_rs6', '0903_cifar10_rs7', '0903_cifar10_rs8', '0903_cifar100_rs1', '0903_cifar100_rs2', '0903_cifar100_rs3', '0903_cifar100_rs4', '0903_cifar100_rs5', '0903_cifar100_rs6', '0903_cifar100_rs7', '0903_cifar100_rs8', '0824_img200_step_40', '0824_img200_step_95', '0824_img200_step_150', '0902_img200_rs4', '0902_img200_rs5', '0902_img200_rs6', '0902_img200_rs7', '0902_img200_rs8', '0902_img200_rs9', '0902_img200_sampling']\n" + "['fbnet_a', 'fbnet_b', 'fbnet_c', 'fbnet_96_035_1', 'fbnet_samsung_s8', 'fbnet_iphonex', 'fbnet_cpu_sample1', 'fbnet_cpu_sample2', 'mb2_example', 'mb2', 'test1', 'test2', 'test3', 'test4', 'test5', '0526_mb2', 'm2_orig', 'FBNet_DoReFa_w2a2', 'test_200616', 'testo', 'cifar10_ngumbel_600_schedule5_1_flop', 'cifar10_ngumbel_600_schedule5_1', 'cifar10_ngumbel_180_schedule5_1', 'cifar100_ngumbel_180_schedule5_1_flop', 'cifar100_ngumbel_180_schedule5_1', 'test', 'fluctuation_exp_4', 'ngumbel_test', 'test_eval_min01', 'test_eval_min1', 'test_eval2', 'test_eval_300_600', 'test_eval_min01_300_600', 'test_eval_min01_10_600', 'test_config', 'test_exp_min01_180', 'test_exp_min1_180', 'test_exp_min01_600', '0815_test_eval_min01_step_15', '0815_test_eval_min01_step_40', '0815_test_eval_min01_step_70', '0815_test_eval_min01_step_95', '0815_test_eval_min01_step_125', '0815_test_eval_min01_step_150', '0815_test_eval_min01_step_180', '0815_test_eval2_step_15', '0815_test_eval2_step_40', '0815_test_eval2_step_70', '0815_test_eval2_step_95', '0815_test_eval2_step_125', '0815_test_eval2_step_150', '0815_test_eval2_step_180', '0817_test_eval_min01_cifar100', '0817_test_eval_min1_cifar100', '0817_test_eval2_cifar100', '0817_test_eval2_cifar100_15', '0817_test_eval2_cifar100_40', '0817_test_eval2_cifar100_70', '0817_test_eval2_cifar100_95', '0817_test_eval2_cifar100_125', '0817_test_eval2_cifar100_150', 
'0817_test_eval2_cifar100_180', '0817_test_eval_min1_cifar100_15', '0817_test_eval_min1_cifar100_40', '0817_test_eval_min1_cifar100_70', '0817_test_eval_min1_cifar100_95', '0817_test_eval_min1_cifar100_125', '0817_test_eval_min1_cifar100_150', '0817_test_eval_min1_cifar100_180', '0817_test_eval_min01_cifar100_15', '0817_test_eval_min01_cifar100_40', '0817_test_eval_min01_cifar100_70', '0817_test_eval_min01_cifar100_95', '0817_test_eval_min01_cifar100_125', '0817_test_eval_min01_cifar100_150', '0817_test_eval_min01_cifar100_180', 'mb_cos_flop_oh', 'mb_exp_flop_oh', 'mb_exp_flop_ws', 'mb_cos_flop_ws', 'mb_exp_noflop_ws', 'mb_cos_noflop_ws', 'mb_exp_noflop_oh', 'mb_cos_noflop_oh', '0819_orig_gumbel_N_1000_reg_1e_5_sampling', '0819_orig_gumbel_N_500_reg_1e_4_sampling', '0819_orig_gumbel_N_500_reg_1e_5_sampling', '0819_orig_gumbel_N_1000_reg_1e_4_sampling', '0819_orig_gumbel_N_500_reg_1e_4', '0819_orig_gumbel_N_1000_reg_1e_5', '0819_orig_gumbel_N_1000_reg_1e_4', '0819_orig_gumbel_N_500_reg_1e_5', 'm2_orig_TI']\n" ] }, { @@ -44,7 +44,7 @@ " 'skip']" ] }, - "execution_count": 271, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -58,12 +58,13 @@ "# check sampled architecture's structure\n", "architecture = model_dictionary[model_name]['block_op_type']\n", "\n", + "CANDIDATE_BLOCKS = CANDIDATE_BLOCKS['mobilenetv2']\n", "CANDIDATE_BLOCKS" ] }, { "cell_type": "code", - "execution_count": 272, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -112,7 +113,7 @@ "True" ] }, - "execution_count": 272, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } @@ -121,12 +122,12 @@ "candidate = len(CANDIDATE_BLOCKS)\n", "\n", "# width, height \n", - "one_center_y = int(height / 2)\n", "\n", "width = 1700\n", "height = 200\n", "step = 17\n", "\n", + "one_center_y = int(height / 2)\n", "\n", "# box color : blue , green, white\n", "# for matlplot lib\n", @@ -183,7 +184,7 @@ }, { "cell_type": "code", - "execution_count": 247, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -195,7 +196,7 @@ }, { "data": { - "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXcAAABKCAYAAABAUxQ5AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAATGElEQVR4nO2de5AV1Z3HP787L97MAMICAwEViWwCUWGRrMaIYMA1oaywtY5uBJctqsi6te5udH0kVlmprWwSa2tlS+NjN5pEVwXXTZBVMcFXKrggiKKMMCCPYZDnwAyPAeZ19o9z7nBnbt+Z7ju3+/Zcf5+pqel7Tt/Tv/7dPr/77V+fOUeMMSiKoiiFRSLfBiiKoii5R4O7oihKAaLBXVEUpQDR4K4oilKAaHBXFEUpQDS4K4qiFCChBHcRmSci20Vkp4jcE8YxFEVRlMxIrse5i0gRUAPMBeqA94AqY0x1Tg+kKIqiZCQM5f4nwE5jzC5jTDPwPLAghOMoiqIoGSgOoc2xwL6U13XAzO7eMGLECDNhwoScHLy5GYqLIRHDpwnt7dDaCqWl+bYkfIyxn0VZWb4t6XvodVKYhHGumzZtOmqMucCrLozg7gsRWQosBRg/fjwbN27MSbuNjTBgAJSU5KS5nNLSAk1NMHSofX2GMzTSGKiNwQxmIANDsK4z7bRTTz1ttPl+TzHFDGc4gtDWBidOQHk5iEAzzRzjWCAbBjCAIQwJanpgDIZjHKOFFt/vSZBgOMMpoiitrpVW6qnH4D/lWUYZ5ZQjCM3NcPYsDAn51Hu6/koppYIKBOmxLYPhBCc4w5mM+3hdu62tcOqUvU6iwGCop55WWgEY5H6ioL0dGhqgosL2iVwgInsz1YUR3PcD41JeV7qyThhjngCeAJg+fXrOEv9NTfabMY7Bva2tc3B/lVdZxrJAQeABHuAO7gjJwvOc5CTzmU8ttb7fM5nJrGENAxiAMZ077SY28W2+3dGp/LCEJfyIHwU1PTAttFBFFR/wge/3jGQkb/ImF5Aumnazm+u5ntOc9t3ePObxNE8jCK2tcOZM+MH9dV5nKUszXn/XcA3P87znF5gX3+f7vMALGesf5EGWsaxTWXs7nD4dXXA/xzkWspBq7CPAu7mb7/G9SI6d7BMVFZEcLpTg/h4wSUQmYoP6zcAtIRzHkwEDoMjftRg5RUXWviRnOcthDgdqo4mmHFvlTVLhHOGI7/dcwAUdgUIEBqUIomaaOczhQHcCJznpe9/eYDAc53igc02QoJ12z7o22jjCkUDBPVVBFxdD//6+35o1PV1/DTQEau8EJ7r1ode1m0jAwPBvRDtI3qUl7QzyGfWWrn0ibHKemTbGtAJ3AGuAT4AVxpituT5OJpqarEKOI0nl/nkgqVKU4CSV++eBpHL/PBB1nwgl526MeQV4Jas3nzoV3ANDh3ZInU7K3Rg4ftw+xfCLCIwY4S3/29rg6FHbrl9KSzuSbF2Ve04xxj5wOHs22PuGDw8lhxW6Sjl5MnhUqKjoE0/u0pR7U5N9gBGEIUNCvNhyR5pyP3fO9tkgDBwIgwfn1K4wiFq55+2Bakaeegp++EP/+4vAY4/BTTcBXXLubW2wZAn84Q/+2ysvhzfegMrK9LrPPoPrrrNPRfxy1VWwciUUFaXl3HPOfffBiy/637+sDFavhmnTcm5K15x7zlm+HB5+2P/+iQQ88wzMmROSQbkjLee+ciXcdVewRh56CG67Lee25Zq0nPs778B3vmMr/HLnnfbajzmFkHPvHU1NcMR/7hPopFbTcu4NDcHaa23NnNdpb7fKPYiyaGjoUPqhKnew6i7IuZaV2fMNgdBVyunTwc5VxKrCPkCacj9zJnif6CN5HU/lfvhwsLvjPpLX6fM593yjOfd4oDn37NGce2ESdZ8ouODel0bLFDJRq5RCIqrRMnEg6tEy+USVey9R5R4PVLlnjyr3wqQgRsvkE1Xu8UCVe/aocu87NDY2Ul1djTEGEWHy5MkMGzbMc9+e+oQxhtraWvbvt//zWVpaytSpUynNch4KVe4Rospd8YMq977D2rVrWbVqFTU1Nbz22mu8/PLLGff10yceeeQRNm/eTE1NDY8++ih792acXaBHVLlHiCp3xQ+q3PsOra2tzJ07l9mzZ7N+/XqqqzPPbO6nT5SUlFBVVcWwYcNoaGigPciQ0C6oco8QVe6KH1S5FyY6WqaXqHKPB6rcs0eVe2Gio2V6iSr3eKDKPXtUuRcmsVPuIvJzETksIh+nlA0Tkd+KyA73t8KVi4gsd2unbhGRy8M03gtV7vFAlXv2qHIvTOKo3J8G5nUpuwdYa4yZBKx1rwHmA5Pc71LgZ7kx0z+q3OOBKvfsUeVemMRunLsx5h0RmdCleAHwdbf9C+At4J9c+S+NXXX7/0SkXERGG2MOdHeMU6dg3Tq7PXYvfMG//Rhgxw44uu582e7d9q+0wZQTEGSertZW+PB9OJe2vAiUHYRprcGGGDU2QvW7YFLuJpL27RiB/RoMsCrL3j2w7jPvuklH8Vg6IjPt7fDRFjjtMeXKqSI4Nw3o57+9piZYvwX6pTzg3+cWXNw6BMwfE+hcDx6Adbu968bX2VVg/GKAbZ/AcY9Jm1oETn0ZgizI09IMGz+ECo/Fm/b2h7ap4HONCwCOH4N3t3V2z6ef2r+jdsGF+HedAXZ9CofWdb9fzXDgkswNNzTAu9X+c7dHLgZGZq7fsyfztZsc8VexDb6Y2SRP6uqgtodzTXIuAU1fhuSCUPtqYV1dgIN1oaYGDh2Cfv1g61Z3jj3YUtfN8erqYMMGO2nc7t2weTPU12dnW7ZDIUelBOyDwCi37bV+6lggLbinLrM3fPh4ampseenRYMEd4MAB2F2TXi7tMKEpWHBva4Ndu+C0x+pjA+vhS23BnNbUBDU7wHj0kAMt2OAegKNH6fBVV0adCBbcjYHavVDv0ZOaSqFlCoGCe3Mz7NwJpR5zke0bBUwhUK9taMh8roOOBQvuAPv3Q51He60JODuJQMG9rc0G3yEeMywfGArmSwQK7qdO2XP1co8cssE9CIcOZ/ZdkgMTsME9A8lrN+FzDq8TI+k2uB850rNNlfttcA/CsWM9t5ukpQjOXUJHcK+v9/9eLw4csLNPDx5shczBg71r79gxe10NHGj9tWdP8Fm8k/R6nLsxxohI4GXyui6zt3ixqzgErPTfjgDXfA2uqbKvO62h2gr8Ctjmv72yMvjzhXh/w+wF/hmCLIY0ejQsug0oTl9DtRR43H9TAFwxHRZP96gwwO+Ad/23VVQE3/wmcEV6XQPwU+BoANvKy+HWv7T9pusaqm+79oLwxUth8aUZKmuA//XfluBm+/2z9Lpz2M9hTwDb+vWHv7j5vKpJZRv2MgkyB+W48bBosVXJaWuongV+6b8tAb46C766uPv9+tN93nTMGFi0yN93lAHeBLqbXHvGDFg8o3NZ2hqqq4EAMzkDTJ0KUxf72/cM8O9Y1Qnwlctg8WXBjpfKgAF2+YfZs2H9eqiuho5Y1oWe1l
A1xmYhqqpg2DC77/z5MHly5uPffnvmumxHyxwSkdEA7m9yrS5f66eGiebc44Hm3LNHc+6FSexGy2RgFbDIbS8CfpNSfpsbNXMl0NhTvj3X6GiZeKCjZbJHR8sUJrEbLSMiz2Fv9ieLSJ2ILAH+BZgrIjuAOe412KX1dgE7gSeB74ZidTeoco8HqtyzR5V7YRLH0TJVGaqu89jXAH/TW6N6gyr3eKDKPXtUuRcmsVPufQ1V7vFAlXv2qHIvTGKn3PsaqtzjgSr37FHl3ncQEbZu3Uq/fv3YunUrRd0EHz99orm5mQ0bNjBkyBBqa2t7ZZsq9whR5a74QZV732HWrFkMHjyYmpoaSkpKuPbaazPu66dPLFy4kIMHD1JTU8OMGTOorAz63xznUeUeIarcFT+ocu87VFZWsjjTwPYu9NQnRISZM2cyc+bMnNimyj1CVLkrflDlXpj0lXHusUWVezxQ5Z49qtwLEx0t00tUuccDVe7Zo8q9MNHRMkVFbmKYACTOf0elKffi4mDtlZR4T/wAtrykJFh7xeddHLpyz+W59pLQVUrQ6ySR6HSdxJk05Z5I9KpPxJk05Z5IQGmpjfp+ieuteheiVu7xC+633AJXXx3sPRdf3LHZ1GQn/yopwX7oy5fbGaz8UlwMo7ymg8KWv/KKlVZ+GTKk4+JLKvehQaapDMIPfgDLlvnfX6T7WYl6QVKldEwIlWuWLoUbbwz2npDONdcklXvHxGELFsC0acEamTgx53aFQVK5d1wns2bB228Ha2TMmJzbFQbJPlHhMe10GMQvuI8Z06sPq5NyF4FLM00rmAVlZXCFxxSKPglVuYvARRfZ3xgQukoZN87+FiBpyn3UqMyCo4+TptwrKiBHo0Xihubce4nm3OOB5tyzR3PuhYnm3HtJXxotkyBBCcFyqUVBVoDoJSXuxy/FFCNuuYmuKiVBglJKacV/SivKcy2mONC5llDSca5dESQr33VsRzRaJvmZGLyXYygOGB568mHCQ0vmY7RM6mcT5TWmOfde0innHjO65tznMIff8/tAbYwjmlTEIAaxghWcC7DkRH/6U0YZkJ5zn8Y03uKtjIHEi1GeS2HknhJKeJInOY1/CVlCCRV4J0/HM57XeZ12/D8UrKCi48siLeceErOZzTu8k7F+KEM9A3Im7uM+ltrF1TwZz/i0srSce8iUUcbTPE2TW3FnLGOjOTDR59zFTuSYX6ZPn242btyYbzMURVH6FCKyyRjjtTZbPIK7iJwEtufbjm4YQbAV56JEbcueONsXZ9sg3vZ9nmz7gjHGc+nkuKRltmf69okDIrIxrvapbdkTZ/vibBvE2z61zVJwo2UURVEUDe6KoigFSVyC+xP5NqAH4myf2pY9cbYvzrZBvO1T24jJA1VFURQlt8RFuSuKoig5JO/BXUTmich2EdkpIvfk4fjjRORNEakWka0i8neufJiI/FZEdri/Fa5cRGS5s3eLiFwegY1FIrJZRFa71xNFZL2z4QURKXXlZe71Tlc/IQLbykXkRRHZJiKfiMisuPhORP7efaYfi8hzItIvn74TkZ+LyGER+TilLLCvRGSR23+HiCwK0bafus91i4j8j4iUp9Td62zbLiLfSCnPeX/2si2l7h9FxIjICPc6Ur91Z5+I/K3z31YR+UlKeTS+M8bk7RcoAj4FLgRKgQ+BKRHbMBq43G0PBmqAKcBPgHtc+T3Aj932DcCrgABXAusjsPEfgP8CVrvXK4Cb3fZjwDK3/V3gMbd9M/BCBLb9Avhrt10KlMfBd8BYYDfQP8Vni/PpO+BrwOXAxyllgXwFDAN2ub8VbrsiJNuuB4rd9o9TbJvi+moZMNH14aKw+rOXba58HLAG2AuMyIffuvHdtcDvgDL3emTUvgu14/twyixgTcrre4F782zTb4C52H+qGu3KRmPH4gM8DlSl7N+xX0j2VAJrgdnAanfRHk3pdB0+dBf6LLdd7PaTEG0big2g0qU8777DBvd9rjMXO999I9++AyZ0CQKBfAVUAY+nlHfaL5e2dam7CXjWbXfqp0nfhdmfvWwDXgSmAXs4H9wj91uGz3UFMMdjv8h8l++0TLIDJqlzZXnB3YpfBqwHRhljDriqg9Ax0UnUNv8bcDd0TFQyHGgwxiRn4Eo9fodtrr7R7R8WE4EjwFMubfQfIjKQGPjOGLMfeAioBQ5gfbGJ+PguSVBf5avP/BVWEcfCNhFZAOw3xnzYpSrvtjkuAa52Kb63RWRG1PblO7jHBhEZBPw3cKcxptPqHsZ+lUY+rEhEbgQOG2M2RX1snxRjb0d/Zoy5DDiNTS10kEffVQALsF9AY4CBwLyo7QhCvnzVEyJyP9AKPJtvWwBEZABwH/BAvm3phmLsXeOVwF3ACpGQlj3LQL6D+37oNM1hpSuLFBEpwQb2Z40xL7niQyIy2tWPBg678iht/lPgWyKyB3gem5p5GCgXkeTUEanH77DN1Q8F6kOyDay6qDPGrHevX8QG+zj4bg6w2xhzxBjTAryE9WdcfJckqK8i7TMishi4EbjVffnEwbaLsF/aH7q+UQm8LyJ/FAPbktQBLxnLBuyd94go7ct3cH8PmORGMJRiH2StitIA9236n8Anxph/TalaBSSfqC/C5uKT5be5p/JXAo0pt9U5xRhzrzGm0hgzAeubN4wxtwJvAgsz2Ja0eaHbPzQlaIw5COwTkeT6ddcB1cTAd9h0zJUiMsB9xknbYuG7FIL6ag1wvYhUuLuT611ZzhGRediU4LeMManLzKwCbhY7wmgiMAnYQET92RjzkTFmpDFmgusbddhBEQeJgd8cv8Y+VEVELsE+JD1KlL7L1QOFXjyIuAE7QuVT4P48HP8q7K3wFuAD93sDNt+6FtiBfeo9zO0vwCPO3o+A6RHZ+XXOj5a50F0QO4GVnH8i38+93unqL4zArq8AG53/fo0diRAL3wEPAtuAj4FfYUco5M13wHPY/H8LNiAtycZX2Pz3Tvd7e4i27cTmgZP94rGU/e93tm0H5qeU57w/e9nWpX4P5x+oRuq3bnxXCjzjrr33gdlR+07/Q1VRFKUAyXdaRlEURQkBDe6KoigFiAZ3RVGUAkSDu6IoSgGiwV1RFKUA0eCuKIpSgGhwVxRFKUA0uCuKohQg/w9LyWlw/EIBbwAAAABJRU5ErkJggg==\n", + "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXcAAABKCAYAAABAUxQ5AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAATMklEQVR4nO2de5AV1Z3HP787987ADAwzA8qiAwENoqyBoLhIVpOIYMA1oaywtaIbwWXLKrNu6e7Gd8WqrdRuTGJtrWxpfOxGk+iagGsisiom+EoFFwRRZEYYkOcgr5lhEBiY59k/Tt/Lnbl9h+47t/v2XH+fqanpPqcfv/7dPr/77V+fOUeMMSiKoijFRazQBiiKoij5R4O7oihKEaLBXVEUpQjR4K4oilKEaHBXFEUpQjS4K4qiFCGBBHcRmSsiW0Vku4jcG8Q5FEVRlOxIvvu5i0gJ0ADMARqB94CFxpj6vJ5IURRFyUoQyv3PgO3GmB3GmA7gV8D8AM6jKIqiZCEewDHPBfamrTcCM/rbYdSoUWb8+PF5OXlHB8TjEIvg24SeHujqgtLSQlsSPMbYz6KsrNCWDD70PilOgrjWDRs2NBljznKrCyK4e0JEbgVuBRg3bhzr16/Py3GPHoXyckgk8nK4vNLZCW1tMGKEXT/JSY5y1NcxhjOcCioCsK43PfTQTDPddHveJ06ckYxEELq74bPPoKoKRKCDDlpo8WVDOeVUUunXdN8YDC200Emn531ixBjJSEooyajrootmmjF4T3mWUUYVVQhCRwecOgWVAVy6wdBKK+20e95HEEYykrhLuOimm2aa6aHH8/ESJKihBkHo6oLjx+19EgYGQzPNdNEFwDDnJwx6eqC1FaqrbZvIByKyO1tdEMF9HzA2bb3WKeuFMeZJ4EmA6dOn5y3x39ZmvxmjGNy7u3sH91d5ldu4zVcQeJAHuZ3bA7LwNMc4xjzmsYc9nveZxCRWsYpyyjGmd6PdwAa+zbdTjcoLS1jCD/mhX9N900knC1nIB3zgeZ+zOZs3eZOzyBRNO9nJNVzDCU54Pt5c5vIMz6QC3smTwQX3O7iD13jN8z4VVPA6rzORiRl1LbQwi1kc4pDn401jGi/zMqWU0tMDJ06EF9zbaWcBC6jHvgK8m7v5Ht8L5dzJNlFdHcrpAgnu7wETRWQCNqjfANwYwHlcKS+HkkwxFQlKSqx9SU5xylejAGijLc9WuZNUOIc57Hmfszgr9UUlAsPSBFEHHRzikK8ngWMc87ztQDAYjnDE17XGiGVVq910c5jDvoJ7+hNcPA5Dh3re1TettPq61jbasn5uPfTQRJOv4x3hSGo5FoOK4B9EUySf0pL2+vmMBkrfNhE0ec9MG2O6gNuBVcDHwDJjTF2+z5ONtjarkKNIUrl/HkiqFMU/SeX+eSCp3D8PhN0mAsm5G2NeAV7Jaefjx/17YMSIlNTppdyNgSNH7FsMr4jAqFHu8r+7G5qa7HG9UlqaSrL1Ve55xRj7wuHUKX/7jRwZSA4rcJVy7Jj/qFBdPSje3GUo97Y2+wLDD5WVAd5s+SNDube32zbrh4oKGD48r3YFQdjKvWAvVLPy9NPwgx94314EHn8crr8e6JNz7+6GJUvgj3/0fryqKnjjDaitzaz79FO4+mr7VsQrV1wBy5dDSUlGzj3v3H8/vPCC9+3LymDlSpg6Ne+m9M25552lS+GRR7xvH4vBs8/C7NkBGZQ/MnLuy5fDXXf5O8jDD8PNN+fdtnyTkXN/5x34zndshVfuvNPe+xGnGHLuA6OtDQ57z98BvdRqRs69tdXf8bq6sud1enqscvejLFpbU0o/UOUOVt35udayMnu9ARC4Sjlxwt+1ilhVOAjIUO4nT/pvE4Mkr+Oq3A8d8vd0PEjyOoM+515oNOceDTTnnjuacy9Owm4TRRfcB1NvmWImbJVSTATdWyZKhN1bppCoch8gqtyjgSr33FHlXpwURW+ZQqLKPRqocs8dVe6Dh6NHj1JfX48xBhFh0qRJ1NTUuG57pjZhjGHPnj3s22f/57O0tJQpU6ZQmuM4FKrcQ0SVu+IFVe6Dh9WrV7NixQoaGhp47bXXePnll7Nu66VNPProo2zcuJGGhgYee+wxdu/OOrrAGVHlHiKq3BUvqHIfPHR1dTFnzhxmzZrF2rVrqa/PPrK5lzaRSCRYuHAhNTU1tLa20uOnS2gfVLmHiCp3xQuq3IsT7S0zQFS5RwNV7rmjyr040d4yA0SVezRQ5Z47qtyLk8gpdxH5mYgcEpHNaWU1IvI7Ednm/K12ykVEljpzp24SkUuCNN4NVe7RQJV77qhyL06iqNyfAeb2KbsXWG2MmQisdtYB5gETnd9bgZ/mx0zvqHKPBqrcc0eVe3ESuX7uxph3RGR8n+L5wNed5Z8DbwH3OOW/MHbW7f8TkSoRGWOM2d/fOY4fhzVr7PK5u+EL3u3HANu2QdOa02U7d9q/0g2TPwM/43R1dcGH70N7xvQiUHYApnb562J09CjUvwsm7Wkiad+2UdivQR+zsuzeBWs+da+b2ITL1BHZ6emBjzbBCZchV46XQPtUYIj347W1wdpNMCTtBf9eZ8LFukowf4qvaz2wH9bsdK8b12hngfGKAbZ8DEdcBm3qFDj+JfAzIU9nB6z/EKpdJm/aPRS6p4DLJE1ZOdIC727p7Z5PPrF/R++A8/DuOgPs+AQOrnGvO3Ih4N4V25Webti4CVpcvnBaEtA5FfDRFfv4MXh3MyTSho9J9vir3gIX4us2obER9rhcqxvtMWj7EiQnM9u7B9Y0+jhZHxoa4OBBGDIE6upg167Tsaw/e/urW7fODhq3cyds3AjNzbnZlmtXyNFpAfsAMNpZdps/9VwgI7inT7M3cuQ4GhpseWmTv+AOsH8/7GzILJceGN/mL7h3d8OOHXDCZfa7ima4uNuf09raoGEbGJdnpP2d4DK5Tb80NZHyVV9Gf+YvuBsDe3ZDs0tLaiuFzsn4Cu4dHbB9O5S6jEW2dzQwGV+ttrU1+7UOa/EX3AH27YNGl+N1xeDURHwF9+5uG3wrXUZY3j8CzMX4Cu7Hj9trdXOPHLTB3Q8HD7n7zggcr8VfcO+xQavdpU18NgS6L/Zn26l2K8jiLr38avfZ4O6Hlpbs90lfOkug/QJSwb252fu+buzfb0efHj7cCpkDBwZ2vJYWe19VVNix4nbt8j+Kd5IB93M3xhgR8T1NXt9p9hYvdioOAsu9H0eAr30VvrbQrveaQ7UL+CWwxfvxysrgLxfg/g2zG/gX8DMZ0pgxsOhmIJ45h2op8IT3QwFw6XRYPN2lwgC/B971fqySEvjmN4FLM+tagZ8ATT5sq6qCm/7atpu+c6i+7RzPDxdeBIsvylLZAPyv92MJzmi/f5FZ1479HHb5sG3IUPirG06rmnS2YG8TP2NQjh0HixbbPGnGHKqngF94P5YAX5kJX1mcWdcD/AZ8TCgI8YQdUdst6B4E/hV8zQQ8apQdjbgUMudQXQn4GMkZYMoUmLLY27Yngf/Aqk6AL0
+DxdP8nS+d8nJ7PbNmwdq1UF8PqVjWhzPNoWqM/dJbuBBqauy28+bBpEnZz3/LLdnrcu0tc1BExgA4f5NzxXmaPzVINOceDTTnnjuacy9OItdbJgsrgEXO8iLgpbTym51eM5cDR8+Ub8832lsmGmhvmdzR3jLFSeR6y4jI89iH/Uki0igiS4CHgDkisg2Y7ayDnVpvB7AdeAr4biBW94Mq92igyj13VLkXJ1HsLbMwS9XVLtsa4O8GatRAUOUeDVS5544q9+Ikcsp9sKHKPRqocs8dVe7FSeSU+2BDlXs0UOWeO6rcBw8iQl1dHUOGDKGuro6SfoKPlzbR0dHBunXrqKysZM+ePQOyTZV7iKhyV7ygyn3wMHPmTIYPH05DQwOJRIKrrroq67Ze2sSCBQs4cOAADQ0NXHbZZdTW+v1vjtOocg8RVe6KF1S5Dx5qa2tZnK1jex/O1CZEhBkzZjBjxoy82KbKPURUuSteUOVenAyWfu6RRZV7NFDlnjuq3IsT7S0zQFS5RwNV7rmjyr040d4yJSXOwDA+iJ3+jspQ7vG4v+MlEu4DP4AtTyT8HS9+2sWBK/d8XusACVyl+L1PYrFe90mUyVDusdiA2kSUyVDusRiUltqo75WoPqr3IWzlHr3gfuONcOWV/vb54hdTi21tdvCvRAL7oS9dakew8ko8DqPdhoPClr/yipVWXqmsTN18SeU+ws8wlX74/vfhttu8by/S/6hEAyCpUlIDQuWbW2+F667zt09A15pvkso9NXDY/Pkwdaq/g0yYkHe7giCp3FP3ycyZ8Pbb/g5yzjl5tysIkm2i2mXY6SCIXnA/55wBfVi9lLsIXJRtWMEcKCuDS12GUPRIoMpdBM4/3/5GgMBVytix9rcIyVDuo0dnFxyDnAzlXl0NeeotEjU05z5ANOceDTTnnjuacy9ONOc+QAZTb5kYMRL4y6WW+JkBYoAknB+vxIkjznQTfVVKjBillNKF95RWmNcaJ+7rWhMkUtfaF0Fy8l1qOeDeMvm81mR9rtdaiN4y6faGeY9pzn2A9Mq5R4y+OffZzOYP/MHXMcYSTipiGMNYxjLafUw5MZShlFEGZObcpzKVt3gLg/d5XUa7ToWRfxIkeIqnOIF3CZkgQTXuydNxjON1XqcH7y8Fq6lOBdCMnHseEYSHeIh7uMfzPjFijGOca10NNbzES3TiMt9gFiqoSAX4jJx7wJRRxjM8Q5sz4865nBvOiQk/5y52IMfCMn36dLN+/fpCm6EoijKoEJENxhi3udmiEdxF5BiwtdB29MMo/M04FyZqW+5E2b4o2wbRtu/zZNsXjDGuUydHJS2zNdu3TxQQkfVRtU9ty50o2xdl2yDa9qltlqLrLaMoiqJocFcURSlKohLcnyy0AWcgyvapbbkTZfuibBtE2z61jYi8UFUURVHyS1SUu6IoipJHCh7cRWSuiGwVke0icm8Bzj9WRN4UkXoRqRORO5zyGhH5nYhsc/5WO+UiIksdezeJyCUh2FgiIhtFZKWzPkFE1jo2/FpESp3yMmd9u1M/PgTbqkTkBRHZIiIfi8jMqPhORP7B+Uw3i8jzIjKkkL4TkZ+JyCER2ZxW5ttXIrLI2X6biCwK0LafOJ/rJhH5jYhUpdXd59i2VUS+kVae9/bsZlta3T+JiBGRUc56qH7rzz4R+XvHf3Ui8uO08nB8Z4wp2C9QAnwCnAeUAh8Ck0O2YQxwibM8HGgAJgM/Bu51yu8FfuQsXwu8CghwObA2BBv/EfhvYKWzvgy4wVl+HLjNWf4u8LizfAPw6xBs+znwt85yKVAVBd8B5wI7gaFpPltcSN8BXwUuATanlfnyFVAD7HD+VjvL1QHZdg0Qd5Z/lGbbZKetlgETnDZcElR7drPNKR8LrAJ2A6MK4bd+fHcV8HugzFk/O2zfBdrwPThlJrAqbf0+4L4C2/QSMAf7T1VjnLIx2L74AE8AC9O2T20XkD21wGpgFrDSuWmb0hpdyofOjT7TWY4720mAto3ABlDpU15w32GD+16nMccd332j0L4DxvcJAr58BSwEnkgr77VdPm3rU3c98Jyz3KudJn0XZHt2sw14AZgK7OJ0cA/db1k+12XAbJftQvNdodMyyQaYpNEpKwjOo/g0YC0w2hiz36k6AKmBTsK2+d+BuyE1UMlIoNUYkxyBK/38Kduc+qPO9kExATgMPO2kjf5TRCqIgO+MMfuAh4E9wH6sLzYQHd8l8eurQrWZv8Eq4kjYJiLzgX3GmA/7VBXcNocLgCudFN/bInJZ2PYVOrhHBhEZBvwPcKcxptfsHsZ+lYberUhErgMOGWM2hH1uj8Sxj6M/NcZMA05gUwspCui7amA+9gvoHKACmBu2HX4olK/OhIg8AHQBzxXaFgARKQfuBx4stC39EMc+NV4O3AUsEwlo2rMsFDq474NewxzWOmWhIiIJbGB/zhjzolN8UETGOPVjgENOeZg2/znwLRHZBfwKm5p5BKgSkeTQEennT9nm1I8AmgOyDay6aDTGrHXWX8AG+yj4bjaw0xhz2BjTCbyI9WdUfJfEr69CbTMishi4DrjJ+fKJgm3nY7+0P3TaRi3wvoj8SQRsS9IIvGgs67BP3qPCtK/Qwf09YKLTg6EU+yJrRZgGON+m/wV8bIz5t7SqFUDyjfoibC4+WX6z81b+cuBo2mN1XjHG3GeMqTXGjMf65g1jzE3Am8CCLLYlbV7gbB+YEjTGHAD2ikhy/rqrgXoi4DtsOuZyESl3PuOkbZHwXRp+fbUKuEZEqp2nk2ucsrwjInOxKcFvGWPSp5lZAdwgtofRBGAisI6Q2rMx5iNjzNnGmPFO22jEdoo4QAT85vBb7EtVROQC7EvSJsL0Xb5eKAzgRcS12B4qnwAPFOD8V2AfhTcBHzi/12LzrauBbdi33jXO9gI86tj7ETA9JDu/zuneMuc5N8R2YDmn38gPcda3O/XnhWDXl4H1jv9+i+2JEAnfAf8MbAE2A7/E9lAomO+A57H5/05sQFqSi6+w+e/tzu8tAdq2HZsHTraLx9O2f8CxbSswL6087+3ZzbY+9bs4/UI1VL/147tS4Fnn3nsfmBW27/Q/VBVFUYqQQqdlFEVRlADQ4K4oilKEaHBXFEUpQjS4K4qiFCEa3BVFUYoQDe6KoihFiAZ3RVGUIkSDu6IoShHy/5Zve9I1a0hyAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] @@ -221,9 +222,9 @@ ], "metadata": { "kernelspec": { - "display_name": "py37_tf1", + "display_name": "dist_test2", "language": "python", - "name": "py37_tf1" + "name": "dist_test2" }, "language_info": { "codemirror_mode": { @@ -235,7 +236,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.6" + "version": "3.6.9" } }, "nbformat": 4, diff --git a/fbnet_building_blocks/fbnet_builder.py b/fbnet_building_blocks/fbnet_builder.py index 2288a98..23665e0 100644 --- a/fbnet_building_blocks/fbnet_builder.py +++ b/fbnet_building_blocks/fbnet_builder.py @@ -192,6 +192,26 @@ def _get_divisible_by(num, divisible_by, min_val): "ir_k7_sep_e6": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock( C_in, C_out, 6, stride, kernel=7, cdw=True, **kwargs ), + # for SimpleNet + "s_k3": lambda C_in, C_out, expansion, stride, **kwargs: Simple( + C_in, C_out, stride, kernel=3 + ), + "s_k5": lambda C_in, C_out, expansion, stride, **kwargs: Simple( + C_in, C_out, stride, kernel=5 + ), + "s_k7": lambda C_in, C_out, expansion, stride, **kwargs: Simple( + C_in, C_out, stride, kernel=7 + ), + # for ResNet + "r_k3": lambda C_in, C_out, expansion, stride, **kwargs: ResidualPreAct( + C_in, C_out, stride, kernel=3 + ), + "r_k5": lambda C_in, C_out, expansion, stride, **kwargs: ResidualPreAct( + C_in, C_out, stride, kernel=5 + ), + "r_k7": lambda C_in, C_out, expansion, stride, **kwargs: ResidualPreAct( + C_in, C_out, stride, kernel=7 + ), } @@ -204,6 +224,120 @@ def forward(self, x): return x.view(-1, shape) +class DownSample(nn.Module): + def __init__(self, C_in, C_out, stride): + super(DownSample, self).__init__() + self.conv = Conv2d(C_in, C_out, kernel_size=1, stride=stride, bias=False) + #self.conv = nn.Conv2d(C_in, C_out, kernel_size=1, stride=stride, bias=False) + self.bn = BatchNorm2d(C_out) + + def forward(self, x): + out = self.conv(x) + out = self.bn(out) + + return out + + +class ResidualPreAct(nn.Module): + def __init__(self, C_in, C_out, stride, kernel): + super(ResidualPreAct, self).__init__() + self.output_depth = C_out # ANNA's code here + self.stride = stride + + if self.stride == 2: + input_channel = int(C_in/2) + else: + input_channel = C_in + + self.conv1 = ( + BNReluConv( + input_channel, + C_in, + kernel=kernel, + stride=stride, + pad=(kernel//2), + no_bias=1, + use_relu="relu", + bn_type="bn", + ) + ) + self.conv2 = ( + BNReluConv( + C_in, + C_in, + kernel=kernel, + stride=1, + pad=(kernel//2), + no_bias=1, + use_relu="relu", + bn_type="bn", + ) + ) + + if self.stride == 2: + self.downsample = DownSample( + input_channel, + C_in, + stride=2 + ) + + + def forward(self, x): + if self.stride == 2: + shortcut = self.downsample(x) + else: + shortcut = x + + out = self.conv1(x) + out = self.conv2(out) + out = out + shortcut + + return out + + def get_flops(self, x): + self.flops = count_conv_flop(self.conv, x) + out = self.conv(x) + return self.flops, out + + + +class Simple(nn.Module): + def __init__(self, C_in, C_out, stride, kernel): + super(Simple, self).__init__() + self.output_depth = C_out # ANNA's code here + self.conv = ( + ConvBNRelu( + C_in, + C_out, + kernel=kernel, + stride=stride, + pad=(kernel // 2), + no_bias=1, + use_relu="relu", + bn_type="bn", + ) + if C_in != C_out or stride != 1 + else None + ) + + def forward(self, x): + if self.conv: + out = self.conv(x) + else: + out = x + return out + + def get_flops(self, x): + if self.conv: + # self.flops = 0 + self.flops = count_conv_flop(self.conv, x) + out = self.conv(x) + else: + 
self.flops = 0 + out = x + return self.flops, out + + class Identity(nn.Module): def __init__(self, C_in, C_out, stride): super(Identity, self).__init__() @@ -363,6 +497,89 @@ def forward(self, x): ) + +class BNReluConv(nn.Sequential): + def __init__( + self, + input_depth, + output_depth, + kernel, + stride, + pad, + no_bias, + use_relu, + bn_type, + group=1, + *args, + **kwargs + ): + super(BNReluConv, self).__init__() + + assert use_relu in ["relu", None] + if isinstance(bn_type, (list, tuple)): + assert len(bn_type) == 2 + assert bn_type[0] == "gn" + gn_group = bn_type[1] + bn_type = bn_type[0] + assert bn_type in ["bn", "af", "gn", None] + assert stride in [1, 2, 4] + # for flops calculation + + self.stride = (stride, stride) + self.in_channels = input_depth + self.out_channels = output_depth + self.kernel_size = (kernel, kernel) + self.groups = group + + self.bn = None + if bn_type == "bn": + self.bn = BatchNorm2d(input_depth) + elif bn_type == "gn": + self.bn = nn.GroupNorm(num_groups=gn_group, num_channels=input_depth) + elif bn_type == "af": + self.bn = FrozenBatchNorm2d(input_depth) + # if bn_type is not None: + # self.add_module("bn", self.bn_op) + + self.relu = None + if use_relu == "relu": + self.relu = nn.ReLU(inplace=True) + # self.add_module("relu", self.activation) + + self.conv = Conv2d( + input_depth, + output_depth, + kernel_size=kernel, + stride=stride, + padding=pad, + bias=not no_bias, + groups=group, + *args, + **kwargs + ) + nn.init.kaiming_normal_(self.conv.weight, mode="fan_out", nonlinearity="relu") + if self.conv.bias is not None: + nn.init.constant_(self.conv.bias, 0.0) + # self.add_module("conv", self.op) + + + def get_flops(self, x, only_flops=False): + flops1 = count_conv_flop(self.conv, x) + + if only_flops: + return flops1 + + if self.bn != None: + y = self.bn(x) + + if self.relu != None: + y = self.relu(y) + + y = self.conv(y) + + return flops1, y + + class ConvBNRelu(nn.Sequential): def __init__( self, @@ -446,9 +663,6 @@ def get_flops(self, x, only_flops=False): - - - class SEModule(nn.Module): reduction = 4 @@ -742,7 +956,7 @@ def unify_arch_def(arch_def): assert "block_op_type" in arch_def _add_to_arch(ret["stages"], arch_def["block_op_type"], "block_op_type") del ret["block_op_type"] - + return ret @@ -806,6 +1020,34 @@ def add_first(self, stage_info, dim_in=3, pad=True): ) self.last_depth = out_depth return out + + def add_first_resnet(self, stage_info, dim_in=3, pad=True): + # stage_info: [c, s, kernel] + assert len(stage_info) >= 2 + channel = stage_info[0] + stride = stage_info[1] + out_depth = self._get_divisible_width(int(channel * self.width_ratio)) + kernel = 7 + if len(stage_info) > 2: + kernel = stage_info[2] + + out = ConvBNRelu( + dim_in, + out_depth, + kernel=kernel, + stride=stride, + pad=3, + no_bias=1, + use_relu="relu", + bn_type=self.bn_type, + ) + + self.last_depth = out_depth + return out + + def add_maxpool_resnet(self): + out = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + return out def add_blocks(self, blocks): """ blocks: [{}, {}, ...] 
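[Editor's note — illustration, not part of the patch] ResidualPreAct.get_flops above references self.conv, but the class only defines conv1, conv2, and (for stride 2) downsample, so calling it would raise AttributeError. A sketch of a version consistent with forward(), assuming count_conv_flop and BNReluConv.get_flops behave as defined earlier in this file:

    def get_flops(self, x):
        # mirror forward(): two BN-ReLU-Conv branches plus the optional 1x1 downsample
        shortcut = self.downsample(x) if self.stride == 2 else x
        flops1, out = self.conv1.get_flops(x)
        flops2, out = self.conv2.get_flops(out)
        self.flops = flops1 + flops2
        if self.stride == 2:
            self.flops += count_conv_flop(self.downsample.conv, x)
        return self.flops, out + shortcut

This keeps the (flops, output) return convention that Simple.get_flops and BNReluConv.get_flops use above, so callers can treat all candidate blocks uniformly.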
@@ -882,6 +1124,17 @@ def add_last_states(self, cnt_classes, dropout_ratio=0.2): ])) self.last_depth = cnt_classes return op + + def add_last_states_resnet(self, cnt_classes, dropout_ratio=0.2): + assert cnt_classes >= 1 + op = nn.Sequential(OrderedDict([ + ("avg_pool_k7", nn.AdaptiveAvgPool2d((1,1))), + ("flatten", Flatten()), + ("fc", nn.Linear(in_features=512, out_features=cnt_classes)), + ])) + self.last_depth = cnt_classes + return op + def _get_trunk_cfg(arch_def): num_stages = get_num_stages(arch_def) @@ -891,17 +1144,31 @@ def _get_trunk_cfg(arch_def): class FBNet(nn.Module): def __init__( - self, builder, arch_def, dim_in, cnt_classes=1000 + self, builder, arch_def, dim_in, cnt_classes=1000, supernet_type='mobilenetv2' ): super(FBNet, self).__init__() - self.first = builder.add_first(arch_def["first"], dim_in=dim_in) + + self.supernet_type = supernet_type + + if self.supernet_type == 'resnet': + self.first = builder.add_first_resnet(arch_def["first"], dim_in=dim_in) + self.maxpool = builder.add_maxpool_resnet() + else: + self.first = builder.add_first(arch_def["first"], dim_in=dim_in) trunk_cfg = _get_trunk_cfg(arch_def) self.stages = builder.add_blocks(trunk_cfg["stages"]) - self.last_stages = builder.add_last_states(cnt_classes) + if self.supernet_type == 'resnet': + self.last_stages = builder.add_last_states_resnet(cnt_classes) + else: + self.last_stages = builder.add_last_states(cnt_classes) self.cnt_classes = cnt_classes def forward(self, x): y = self.first(x) + + if self.supernet_type == 'resnet': + y = self.maxpool(y) + y = self.stages(y) y = self.last_stages(y) return y @@ -940,12 +1207,185 @@ def get_flops(self, x): print('last stage : ', flops3) return accumlated_flops -def get_model(arch, cnt_classes): +def get_model(arch, cnt_classes, supernet_type): # for reload updated arch importlib.reload(fbnet_modeldef) assert arch in fbnet_modeldef.MODEL_ARCH arch_def = fbnet_modeldef.MODEL_ARCH[arch] arch_def = unify_arch_def(arch_def) builder = FBNetBuilder(width_ratio=1.0, bn_type="bn", width_divisor=8, dw_skip_bn=True, dw_skip_relu=True) - model = FBNet(builder, arch_def, dim_in=3, cnt_classes=cnt_classes) + model = FBNet(builder, arch_def, dim_in=3, cnt_classes=cnt_classes, supernet_type=supernet_type) + return model + + +################################################## +# torchvision ResNet18 +################################################## + +def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=dilation, groups=groups, bias=False, dilation=dilation) + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, + base_width=64, dilation=1, norm_layer=None): + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError('BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = norm_layer(planes) + 
self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + +class ResNet(nn.Module): + + def __init__(self, block, layers, num_classes=10, zero_init_residual=False, + groups=1, width_per_group=64, replace_stride_with_dilation=None, + norm_layer=None): + super(ResNet, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + # each element in the tuple indicates if we should replace + # the 2x2 stride with a dilated convolution instead + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError("replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2, + dilate=replace_stride_with_dilation[0]) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2, + dilate=replace_stride_with_dilation[1]) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2, + dilate=replace_stride_with_dilation[2]) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + def _make_layer(self, block, planes, blocks, stride=1, dilate=False): + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation, norm_layer)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes, groups=self.groups, + base_width=self.base_width, dilation=self.dilation, + norm_layer=norm_layer)) + + return nn.Sequential(*layers) + + def _forward_impl(self, x): + # See note [TorchScript super()] + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + x = torch.flatten(x, 1) + x = self.fc(x) + + return x + + def forward(self, x): + return self._forward_impl(x) + + +def _resnet(arch, block, layers, pretrained, progress, **kwargs): + model = ResNet(block, layers, **kwargs) + #if pretrained: + # state_dict = load_state_dict_from_url(model_urls[arch], + # progress=progress) + # model.load_state_dict(state_dict) return model + +def resnet18(pretrained=False, progress=True, **kwargs): + r"""ResNet-18 model from + `"Deep Residual Learning for Image Recognition" `_ + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, + **kwargs) + diff --git a/fbnet_building_blocks/fbnet_modeldef.py b/fbnet_building_blocks/fbnet_modeldef.py index f585d3e..38c00eb 100644 --- a/fbnet_building_blocks/fbnet_modeldef.py +++ b/fbnet_building_blocks/fbnet_modeldef.py @@ -1566,6 +1566,198 @@ "backbone": [num for num in range(18)], }, }, + "mb_cos_flop_oh": { + "block_op_type": [ + ["skip"], + ["ir_k3_e6"], ["skip"], + ["ir_k5_e6"], ["skip"], ["skip"], + ["ir_k5_e6"], ["ir_k5_e3"], ["ir_k5_e1"], ["ir_k5_e3"], + ["ir_k5_e6"], ["ir_k5_e1"], ["ir_k5_e1"], + ["ir_k5_e6"], ["ir_k5_e1"], ["ir_k5_s2"], + ["skip"], + ], + "block_cfg": { + "first": [32, 1], + "stages": [ + [[1, 16, 1, 1]], # stage 1 + [[6, 24, 1, 2]], [[1, 24, 1, 1]], # stage 2 + [[6, 32, 1, 2]], [[1, 32, 1, 1]], [[1, 32, 1, 1]], # stage 3 + [[6, 64, 1, 2]], [[3, 64, 1, 1]], [[1, 64, 1, 1]], [[3, 64, 1, 1]], # stage 4 + [[6, 96, 1, 1]], [[1, 96, 1, 1]], [[1, 96, 1, 1]], # stage 5 + [[6, 160, 1, 2]], [[1, 160, 1, 1]], [[1, 160, 1, 1]], # stage 6 + [[1, 320, 1, 1]], # stage 7 + ], + "backbone": [num for num in range(18)], + }, + }, + "mb_exp_flop_oh": { + "block_op_type": [ + ["skip"], + ["skip"], ["ir_k3_s2"], + ["ir_k5_e6"], ["skip"], ["ir_k5_s2"], + ["ir_k5_e6"], ["skip"], ["skip"], ["skip"], + ["ir_k5_e6"], ["skip"], ["skip"], + ["ir_k5_e6"], ["ir_k5_s2"], ["ir_k5_s2"], + ["skip"], + ], + "block_cfg": { + "first": [32, 1], + "stages": [ + [[1, 16, 1, 1]], # stage 1 + [[1, 24, 1, 2]], [[1, 24, 1, 1]], # stage 
2 + [[6, 32, 1, 2]], [[1, 32, 1, 1]], [[1, 32, 1, 1]], # stage 3 + [[6, 64, 1, 2]], [[1, 64, 1, 1]], [[1, 64, 1, 1]], [[1, 64, 1, 1]], # stage 4 + [[6, 96, 1, 1]], [[1, 96, 1, 1]], [[1, 96, 1, 1]], # stage 5 + [[6, 160, 1, 2]], [[1, 160, 1, 1]], [[1, 160, 1, 1]], # stage 6 + [[1, 320, 1, 1]], # stage 7 + ], + "backbone": [num for num in range(18)], + }, + }, + "mb_exp_flop_ws": { + "block_op_type": [ + ["skip"], + ["skip"], ["ir_k3_s2"], + ["ir_k5_e6"], ["skip"], ["ir_k5_s2"], + ["ir_k5_e6"], ["skip"], ["skip"], ["skip"], + ["ir_k5_e6"], ["skip"], ["skip"], + ["ir_k5_e6"], ["ir_k5_s2"], ["ir_k5_s2"], + ["skip"], + ], + "block_cfg": { + "first": [32, 1], + "stages": [ + [[1, 16, 1, 1]], # stage 1 + [[1, 24, 1, 2]], [[1, 24, 1, 1]], # stage 2 + [[6, 32, 1, 2]], [[1, 32, 1, 1]], [[1, 32, 1, 1]], # stage 3 + [[6, 64, 1, 2]], [[1, 64, 1, 1]], [[1, 64, 1, 1]], [[1, 64, 1, 1]], # stage 4 + [[6, 96, 1, 1]], [[1, 96, 1, 1]], [[1, 96, 1, 1]], # stage 5 + [[6, 160, 1, 2]], [[1, 160, 1, 1]], [[1, 160, 1, 1]], # stage 6 + [[1, 320, 1, 1]], # stage 7 + ], + "backbone": [num for num in range(18)], + }, + }, + "mb_cos_flop_ws": { + "block_op_type": [ + ["skip"], + ["ir_k3_e6"], ["skip"], + ["ir_k5_e6"], ["skip"], ["skip"], + ["ir_k5_e6"], ["ir_k5_e3"], ["ir_k5_e1"], ["ir_k5_e3"], + ["ir_k5_e6"], ["ir_k5_e1"], ["ir_k5_e1"], + ["ir_k5_e6"], ["ir_k5_e1"], ["ir_k5_s2"], + ["skip"], + ], + "block_cfg": { + "first": [32, 1], + "stages": [ + [[1, 16, 1, 1]], # stage 1 + [[6, 24, 1, 2]], [[1, 24, 1, 1]], # stage 2 + [[6, 32, 1, 2]], [[1, 32, 1, 1]], [[1, 32, 1, 1]], # stage 3 + [[6, 64, 1, 2]], [[3, 64, 1, 1]], [[1, 64, 1, 1]], [[3, 64, 1, 1]], # stage 4 + [[6, 96, 1, 1]], [[1, 96, 1, 1]], [[1, 96, 1, 1]], # stage 5 + [[6, 160, 1, 2]], [[1, 160, 1, 1]], [[1, 160, 1, 1]], # stage 6 + [[1, 320, 1, 1]], # stage 7 + ], + "backbone": [num for num in range(18)], + }, + }, + "mb_exp_noflop_ws": { + "block_op_type": [ + ["ir_k3_e6"], + ["ir_k3_e6"], ["ir_k3_e6"], + ["ir_k5_e6"], ["skip"], ["ir_k5_e6"], + ["ir_k5_e6"], ["ir_k5_e6"], ["ir_k5_e6"], ["ir_k5_e6"], + ["ir_k5_e6"], ["ir_k5_e6"], ["ir_k5_e6"], + ["ir_k5_e6"], ["ir_k5_e3"], ["ir_k5_e6"], + ["skip"], + ], + "block_cfg": { + "first": [32, 1], + "stages": [ + [[6, 16, 1, 1]], # stage 1 + [[6, 24, 1, 2]], [[6, 24, 1, 1]], # stage 2 + [[6, 32, 1, 2]], [[1, 32, 1, 1]], [[6, 32, 1, 1]], # stage 3 + [[6, 64, 1, 2]], [[6, 64, 1, 1]], [[6, 64, 1, 1]], [[6, 64, 1, 1]], # stage 4 + [[6, 96, 1, 1]], [[6, 96, 1, 1]], [[6, 96, 1, 1]], # stage 5 + [[6, 160, 1, 2]], [[3, 160, 1, 1]], [[6, 160, 1, 1]], # stage 6 + [[1, 320, 1, 1]], # stage 7 + ], + "backbone": [num for num in range(18)], + }, + }, + "mb_cos_noflop_ws": { + "block_op_type": [ + ["ir_k3_e6"], + ["ir_k3_e6"], ["ir_k3_e6"], + ["ir_k5_e6"], ["ir_k5_e6"], ["ir_k5_e6"], + ["ir_k5_e6"], ["ir_k5_e6"], ["ir_k5_e6"], ["ir_k5_e6"], + ["ir_k5_e6"], ["ir_k5_e6"], ["ir_k5_e6"], + ["ir_k5_e6"], ["ir_k5_e6"], ["ir_k5_e6"], + ["skip"], + ], + "block_cfg": { + "first": [32, 1], + "stages": [ + [[6, 16, 1, 1]], # stage 1 + [[6, 24, 1, 2]], [[6, 24, 1, 1]], # stage 2 + [[6, 32, 1, 2]], [[6, 32, 1, 1]], [[6, 32, 1, 1]], # stage 3 + [[6, 64, 1, 2]], [[6, 64, 1, 1]], [[6, 64, 1, 1]], [[6, 64, 1, 1]], # stage 4 + [[6, 96, 1, 1]], [[6, 96, 1, 1]], [[6, 96, 1, 1]], # stage 5 + [[6, 160, 1, 2]], [[6, 160, 1, 1]], [[6, 160, 1, 1]], # stage 6 + [[1, 320, 1, 1]], # stage 7 + ], + "backbone": [num for num in range(18)], + }, + }, + "mb_exp_noflop_oh": { + "block_op_type": [ + ["ir_k3_e6"], + ["ir_k3_e6"], ["ir_k3_e6"], + ["ir_k5_e6"], ["skip"], 
["ir_k5_e6"], + ["ir_k5_e6"], ["ir_k5_e6"], ["ir_k5_e6"], ["ir_k5_e6"], + ["ir_k5_e6"], ["ir_k5_e6"], ["ir_k5_e6"], + ["ir_k5_e6"], ["ir_k5_e3"], ["ir_k5_e6"], + ["skip"], + ], + "block_cfg": { + "first": [32, 1], + "stages": [ + [[6, 16, 1, 1]], # stage 1 + [[6, 24, 1, 2]], [[6, 24, 1, 1]], # stage 2 + [[6, 32, 1, 2]], [[1, 32, 1, 1]], [[6, 32, 1, 1]], # stage 3 + [[6, 64, 1, 2]], [[6, 64, 1, 1]], [[6, 64, 1, 1]], [[6, 64, 1, 1]], # stage 4 + [[6, 96, 1, 1]], [[6, 96, 1, 1]], [[6, 96, 1, 1]], # stage 5 + [[6, 160, 1, 2]], [[3, 160, 1, 1]], [[6, 160, 1, 1]], # stage 6 + [[1, 320, 1, 1]], # stage 7 + ], + "backbone": [num for num in range(18)], + }, + }, + "mb_cos_noflop_oh": { + "block_op_type": [ + ["ir_k3_e6"], + ["ir_k3_e6"], ["ir_k3_e6"], + ["ir_k5_e6"], ["ir_k5_e6"], ["ir_k5_e6"], + ["ir_k5_e6"], ["ir_k5_e6"], ["ir_k5_e6"], ["ir_k5_e6"], + ["ir_k5_e6"], ["ir_k5_e6"], ["ir_k5_e6"], + ["ir_k5_e6"], ["ir_k5_e6"], ["ir_k5_e6"], + ["skip"], + ], + "block_cfg": { + "first": [32, 1], + "stages": [ + [[6, 16, 1, 1]], # stage 1 + [[6, 24, 1, 2]], [[6, 24, 1, 1]], # stage 2 + [[6, 32, 1, 2]], [[6, 32, 1, 1]], [[6, 32, 1, 1]], # stage 3 + [[6, 64, 1, 2]], [[6, 64, 1, 1]], [[6, 64, 1, 1]], [[6, 64, 1, 1]], # stage 4 + [[6, 96, 1, 1]], [[6, 96, 1, 1]], [[6, 96, 1, 1]], # stage 5 + [[6, 160, 1, 2]], [[6, 160, 1, 1]], [[6, 160, 1, 1]], # stage 6 + [[1, 320, 1, 1]], # stage 7 + ], + "backbone": [num for num in range(18)], + }, + }, "0819_orig_gumbel_N_1000_reg_1e_5_sampling": { "block_op_type": [ ["ir_k5_e6"], diff --git a/general_functions/utils.py b/general_functions/utils.py index 8cfa28f..4d6b05c 100644 --- a/general_functions/utils.py +++ b/general_functions/utils.py @@ -127,7 +127,9 @@ def check_tensor_in_list(atensor, alist): # "ir_k3_e6", "ir_k5_e6", "ir_k5_e6", "ir_k5_e1", "skip", "ir_k5_e3", "ir_k5_e6", "ir_k3_e1", # "ir_k5_e1", "ir_k5_e3", "ir_k5_e6", "ir_k5_e1", "ir_k5_e6", "ir_k5_e6", "ir_k3_e6"] # my_unique_name_for_ARCH = "my_unique_name_for_ARCH" -def writh_new_ARCH_to_fbnet_modeldef(ops_names, my_unique_name_for_ARCH): +def writh_new_ARCH_to_fbnet_modeldef(ops_names, my_unique_name_for_ARCH, supernet_type="mobilenetv2"): + print('-- ops_names:', ops_names) + print('-- my_unique_name_for_ARCH:', my_unique_name_for_ARCH) # assert len(ops_names) == 22 if my_unique_name_for_ARCH in MODEL_ARCH: print("The specification with the name", my_unique_name_for_ARCH, "already written \ @@ -136,39 +138,97 @@ def writh_new_ARCH_to_fbnet_modeldef(ops_names, my_unique_name_for_ARCH): assert my_unique_name_for_ARCH not in MODEL_ARCH ### create text to insert + if supernet_type=='simple': + + text_to_write = " \"" + my_unique_name_for_ARCH + "\": {\n\ + \"block_op_type\": [\n" + + ops = ["[\"" + str(op) + "\"], " for op in ops_names] + ops_lines = [ops[0], ops[1]] + ops_lines = [''.join(line) for line in ops_lines] + text_to_write += ' ' + '\n '.join(ops_lines) + + e = [(op_name[-1] if op_name[-2] == 'e' else '1') for op_name in ops_names] + + text_to_write += "\n\ + ],\n\ + \"block_cfg\": {\n\ + \"first\": [32, 1],\n\ + \"stages\": [\n\ + [[" + e[0] + ", 160, 1, 1]], # stage 1\n\ + [[" + e[1] + ", 320, 1, 1]], # stage 2\n\ + ],\n\ + \"backbone\": [num for num in range(3)],\n\ + },\n\ + },\n\ + }\ + " + + elif supernet_type=='resnet': + text_to_write = " \"" + my_unique_name_for_ARCH + "\": {\n\ + \"block_op_type\": [\n" + + ops = ["[\"" + str(op) + "\"], " for op in ops_names] + ops_lines = [ops[0:2], ops[2:4], ops[4:6], ops[6:8]] + ops_lines = [''.join(line) for line in ops_lines] + text_to_write += 
' ' + '\n '.join(ops_lines) + + e = [(op_name[-1] if op_name[-2] == 'e' else '1') for op_name in ops_names] + + text_to_write += "\n\ + ],\n\ + \"block_cfg\": {\n\ + \"first\": [64, 2],\n\ + \"stages\": [\n\ + [[0, 64, 1, 1]], # stage 1\n\ + [[0, 128, 1, 1]], # stage 2\n\ + [[0, 128, 1, 2]], # stage 3\n\ + [[0, 256, 1, 1]], # stage 4\n\ + [[0, 256, 1, 2]], # stage 5\n\ + [[0, 512, 1, 1]], # stage 6\n\ + [[0, 512, 1, 2]], # stage 7\n\ + [[0, 512, 1, 1]], # stage 8\n\ + ],\n\ + \"backbone\": [num for num in range(9)],\n\ + },\n\ + },\n\ + }\ + " - text_to_write = " \"" + my_unique_name_for_ARCH + "\": {\n\ - \"block_op_type\": [\n" - - ops = ["[\"" + str(op) + "\"], " for op in ops_names] - ops_lines = [ops[0], ops[1:3], ops[3:6], ops[6:10], ops[10:13], ops[13:16], ops[16]] - ops_lines = [''.join(line) for line in ops_lines] - text_to_write += ' ' + '\n '.join(ops_lines) - - e = [(op_name[-1] if op_name[-2] == 'e' else '1') for op_name in ops_names] - - text_to_write += "\n\ - ],\n\ - \"block_cfg\": {\n\ - \"first\": [32, 1],\n\ - \"stages\": [\n\ - [[" + e[0] + ", 16, 1, 1]], # stage 1\n\ - [[" + e[1] + ", 24, 1, 1]], [[" + e[2] + ", 24, 1, 1]], # stage 2\n\ - [[" + e[3] + ", 32, 1, 2]], [[" + e[4] + ", 32, 1, 1]], \ - [[" + e[5] + ", 32, 1, 1]], # stage 3\n\ - [[" + e[6] + ", 64, 1, 2]], [[" + e[7] + ", 64, 1, 1]], \ - [[" + e[8] + ", 64, 1, 1]], [[" + e[9] + ", 64, 1, 1]], # stage 4\n\ - [[" + e[10] + ", 96, 1, 1]], [[" + e[11] + ", 96, 1, 1]], \ - [[" + e[12] + ", 96, 1, 1]], # stage 5\n\ - [[" + e[13] + ", 160, 1, 2]], [[" + e[14] + ", 160, 1, 1]], \ - [[" + e[15] + ", 160, 1, 1]], # stage 6\n\ - [[" + e[16] + ", 320, 1, 1]], # stage 7\n\ + else: + text_to_write = " \"" + my_unique_name_for_ARCH + "\": {\n\ + \"block_op_type\": [\n" + + ops = ["[\"" + str(op) + "\"], " for op in ops_names] + ops_lines = [ops[0], ops[1:3], ops[3:6], ops[6:10], ops[10:13], ops[13:16], ops[16]] + ops_lines = [''.join(line) for line in ops_lines] + text_to_write += ' ' + '\n '.join(ops_lines) + + e = [(op_name[-1] if op_name[-2] == 'e' else '1') for op_name in ops_names] + + text_to_write += "\n\ ],\n\ - \"backbone\": [num for num in range(18)],\n\ + \"block_cfg\": {\n\ + \"first\": [32, 1],\n\ + \"stages\": [\n\ + [[" + e[0] + ", 16, 1, 1]], # stage 1\n\ + [[" + e[1] + ", 24, 1, 1]], [[" + e[2] + ", 24, 1, 1]], # stage 2\n\ + [[" + e[3] + ", 32, 1, 2]], [[" + e[4] + ", 32, 1, 1]], \ + [[" + e[5] + ", 32, 1, 1]], # stage 3\n\ + [[" + e[6] + ", 64, 1, 2]], [[" + e[7] + ", 64, 1, 1]], \ + [[" + e[8] + ", 64, 1, 1]], [[" + e[9] + ", 64, 1, 1]], # stage 4\n\ + [[" + e[10] + ", 96, 1, 1]], [[" + e[11] + ", 96, 1, 1]], \ + [[" + e[12] + ", 96, 1, 1]], # stage 5\n\ + [[" + e[13] + ", 160, 1, 2]], [[" + e[14] + ", 160, 1, 1]], \ + [[" + e[15] + ", 160, 1, 1]], # stage 6\n\ + [[" + e[16] + ", 320, 1, 1]], # stage 7\n\ + ],\n\ + \"backbone\": [num for num in range(18)],\n\ + },\n\ },\n\ - },\n\ - }\ - " + }\ + " + ### open file and find place to insert with open('./fbnet_building_blocks/fbnet_modeldef.py') as f1: lines = f1.readlines() diff --git a/supernet_functions/config_for_supernet.py b/supernet_functions/config_for_supernet.py index d47f5d3..47fc473 100644 --- a/supernet_functions/config_for_supernet.py +++ b/supernet_functions/config_for_supernet.py @@ -6,7 +6,9 @@ }, 'lookup_table' : { 'create_from_scratch' : False, - 'path_to_lookup_table' : './supernet_functions/lookup_table.txt', + 'path_to_lookup_table' : './supernet_functions/lookup_table_resnet.txt', + #'path_to_lookup_table' : 
'./supernet_functions/lookup_table_simple.txt', + #'path_to_lookup_table' : './supernet_functions/lookup_table.txt', 'number_of_runs' : 15 # each operation run number_of_runs times and then we will take average }, 'logging' : { @@ -14,7 +16,8 @@ 'path_to_tensorboard_logs' : './supernet_functions/logs/tb' }, 'dataloading' : { - 'batch_size' : 128, # 200, + 'batch_size' : 8, # 200, + #'batch_size' : 128, # 200, 'w_share_in_train' : 0.8, 'path_to_save_data' : './cifar10_data' }, @@ -30,11 +33,12 @@ 'loss' : { 'alpha' : 0.2, 'beta' : 0.3, - 'reg_lambda' : 1e-1, + 'reg_lambda' : 0, 'reg_loss_type' : 'add#linear' # 'add#linear', 'mul#log' }, 'train_settings' : { - 'cnt_epochs' : 180, # 90 + 'cnt_epochs' : 40, # 90 + #'cnt_epochs' : 180, # 90 'train_thetas_from_the_epoch' : 0, 'print_freq' : 50, 'path_to_save_model' : './supernet_functions/logs/best_model.pth', diff --git a/supernet_functions/lookup_table_builder.py b/supernet_functions/lookup_table_builder.py index cf4cebc..ff34f5c 100644 --- a/supernet_functions/lookup_table_builder.py +++ b/supernet_functions/lookup_table_builder.py @@ -8,56 +8,99 @@ # the settings from the page 4 of https://arxiv.org/pdf/1812.03443.pdf #### table 2 -CANDIDATE_BLOCKS = ["ir_k3_e1", "ir_k3_s2", "ir_k3_e3", + +CANDIDATE_BLOCKS = { + "simple": ["s_k3", "s_k5", "s_k7"], + "resnet": ["r_k3", "r_k5", "r_k7"], + "mobilenetv2": ["ir_k3_e1", "ir_k3_s2", "ir_k3_e3", "ir_k3_e6", "ir_k5_e1", "ir_k5_s2", "ir_k5_e3", "ir_k5_e6", "skip"] -SEARCH_SPACE = OrderedDict([ - #### table 1. input shapes of 22 searched layers (considering with strides) - # Note: the second and third dimentions are recommended (will not be used in training) and written just for debagging - # Imagenet - original - # ("input_shape", [(32, 112, 112), - # (16, 112, 112), (24, 56, 56), - # (24, 56, 56), (32, 28, 28), (32, 28, 28), - # (32, 28, 28), (64, 14, 14), (64, 14, 14), (64, 14, 14), - # (64, 14, 14), (96, 14, 14), (96, 14, 14), - # (96, 14, 14), (160, 7, 7), (160, 7, 7), - # (160, 7, 7)]), - - # cifar-10 - ("input_shape", [(32, 32, 32), - (16, 32, 32), (24, 32, 32), - (24, 32, 32), (32, 16, 16), (32, 28, 16), - (32, 16, 16), (64, 8, 8), (64, 8, 8), (64, 8, 8), - (64, 8, 8), (96, 8, 8), (96, 8, 8), - (96, 8, 8), (160, 4, 4), (160, 4, 4), - (160, 4, 4)]), - # table 1. filter numbers over the 22 layers - ("channel_size", [16, - 24, 24, - 32, 32, 32, - 64, 64, 64, 64, - 96, 96, 96, - 160, 160, 160, - 320]), - # table 1. strides over the 22 layers - # mobiletnet v2 - cifar 10 - ("strides", [1, - 1, 1, - 2, 1, 1, - 2, 1, 1, 1, - 1, 1, 1, - 2, 1, 1, - 1]) - - # # mobilenet v2 -imagenet - orig - # ("strides", [1, - # 2, 1, - # 2, 1, 1, - # 2, 1, 1, 1, - # 1, 1, 1, - # 2, 1, 1, - # 1]) -]) + } + +SEARCH_SPACE = { + "simple": OrderedDict([ + ("input_shape", [(32, 32, 32), + (160, 32, 32)]), + ("channel_size", [160, + 320]), + ("strides", [1, + 1]) + ]), + "resnet": OrderedDict([ + ("input_shape", [(64, 16, 16), + (64, 16, 16), + (128, 8, 8), + (128, 8, 8), + (256, 4, 4), + (256, 4, 4), + (512, 2, 2), + (512, 2, 2)]), + ("channel_size", [64, + 128, + 128, + 256, + 256, + 512, + 512, + 512]), + ("strides", [1, + 1, + 2, + 1, + 2, + 1, + 2, + 1]) + ]), + "mobilenetv2": OrderedDict([ + #### table 1. 
input shapes of 22 searched layers (considering with strides) + # Note: the second and third dimentions are recommended (will not be used in training) and written just for debagging + # Imagenet - original + # ("input_shape", [(32, 112, 112), + # (16, 112, 112), (24, 56, 56), + # (24, 56, 56), (32, 28, 28), (32, 28, 28), + # (32, 28, 28), (64, 14, 14), (64, 14, 14), (64, 14, 14), + # (64, 14, 14), (96, 14, 14), (96, 14, 14), + # (96, 14, 14), (160, 7, 7), (160, 7, 7), + # (160, 7, 7)]), + + # cifar-10 + ("input_shape", [(32, 32, 32), + (16, 32, 32), (24, 32, 32), + (24, 32, 32), (32, 16, 16), (32, 28, 16), + (32, 16, 16), (64, 8, 8), (64, 8, 8), (64, 8, 8), + (64, 8, 8), (96, 8, 8), (96, 8, 8), + (96, 8, 8), (160, 4, 4), (160, 4, 4), + (160, 4, 4)]), + # table 1. filter numbers over the 22 layers + ("channel_size", [16, + 24, 24, + 32, 32, 32, + 64, 64, 64, 64, + 96, 96, 96, + 160, 160, 160, + 320]), + # table 1. strides over the 22 layers + # mobiletnet v2 - cifar 10 + ("strides", [1, + 1, 1, + 2, 1, 1, + 2, 1, 1, 1, + 1, 1, 1, + 2, 1, 1, + 1]) + + # # mobilenet v2 -imagenet - orig + # ("strides", [1, + # 2, 1, + # 2, 1, 1, + # 2, 1, 1, 1, + # 1, 1, 1, + # 2, 1, 1, + # 1]) + ]) +} + # **** to recalculate latency use command: # l_table = LookUpTable(calulate_latency=True, path_to_file='lookup_table.txt', cnt_of_runs=50) @@ -67,14 +110,22 @@ # TODO - flops change class LookUpTable: - def __init__(self, candidate_blocks=CANDIDATE_BLOCKS, search_space=SEARCH_SPACE, - calulate_latency=False, path='./supernet_functions/lookup_table.txt'): - self.cnt_layers = len(search_space["input_shape"]) + def __init__(self, candidate_blocks_dict=CANDIDATE_BLOCKS, search_space_dict=SEARCH_SPACE, + calulate_latency=False, path='./supernet_functions/lookup_table.txt', supernet_type='mobilenetv2'): + + self.candidate_blocks = candidate_blocks_dict[supernet_type] + self.search_space = search_space_dict[supernet_type] + + #print('candidate_blocks:', self.candidate_blocks) + #print('search_space:', self.search_space) + + self.cnt_layers = len(self.search_space["input_shape"]) + # constructors for each operation - self.lookup_table_operations = {op_name : PRIMITIVES[op_name] for op_name in candidate_blocks} + self.lookup_table_operations = {op_name : PRIMITIVES[op_name] for op_name in self.candidate_blocks} # arguments for the ops constructors. 
one set of arguments for all 9 constructors at each layer # input_shapes just for convinience - self.layers_parameters, self.layers_input_shapes = self._generate_layers_parameters(search_space) + self.layers_parameters, self.layers_input_shapes = self._generate_layers_parameters(self.search_space) # lookup_table self.lookup_table_flops = None @@ -86,6 +137,12 @@ def __init__(self, candidate_blocks=CANDIDATE_BLOCKS, search_space=SEARCH_SPACE, # self._create_from_file(path_to_file=path) + def get_candidate_blocks(self): + return self.candidate_blocks + + def get_search_space(self): + return self.search_space + def _generate_layers_parameters(self, search_space): # layers_parameters are : C_in, C_out, expansion, stride layers_parameters = [(search_space["input_shape"][layer_id][0], @@ -158,7 +215,7 @@ def _read_lookup_table_from_file(self, path_to_file): flops = [line.strip('\n') for line in open(path_to_file)] ops_names = flops[0].split(" ") flops = [list(map(float, layer.split(" "))) for layer in flops[1:]] - + lookup_table_flops= [{op_name : flops[i][op_id] for op_id, op_name in enumerate(ops_names) } for i in range(self.cnt_layers)] diff --git a/supernet_functions/lookup_table_resnet.txt b/supernet_functions/lookup_table_resnet.txt new file mode 100644 index 0000000..ee338db --- /dev/null +++ b/supernet_functions/lookup_table_resnet.txt @@ -0,0 +1,18 @@ +r_k3 r_k5 r_k7 r_k3 r_k5 r_k7 ir_k5_e3 ir_k5_e6 skip +1867776.0 1081344.0 5603328.0 11206656.0 2392064.0 1605632.0 7176192.0 14352384.0 524288.0 +802816.0 475136.0 2408448.0 4816896.0 1064960.0 737280.0 3194880.0 6389760.0 393216.0 +1400832.0 811008.0 4202496.0 8404992.0 1794048.0 1204224.0 5382144.0 10764288.0 0 +841728.0 448512.0 2525184.0 5050368.0 940032.0 546816.0 2820096.0 5640192.0 196608.0 +598016.0 335872.0 1794048.0 3588096.0 729088.0 466944.0 2187264.0 4374528.0 0 +598016.0 335872.0 1794048.0 3588096.0 729088.0 466944.0 2187264.0 4374528.0 0 +411648.0 215040.0 1234944.0 2469888.0 444416.0 247808.0 1333248.0 2666496.0 131072.0 +561152.0 299008.0 1683456.0 3366912.0 626688.0 364544.0 1880064.0 3760128.0 0 +561152.0 299008.0 1683456.0 3366912.0 626688.0 364544.0 1880064.0 3760128.0 0 +561152.0 299008.0 1683456.0 3366912.0 626688.0 364544.0 1880064.0 3760128.0 0 +692224.0 364544.0 2076672.0 4153344.0 757760.0 430080.0 2273280.0 4546560.0 393216.0 +1234944.0 645120.0 3704832.0 7409664.0 1333248.0 743424.0 3999744.0 7999488.0 0 +1234944.0 645120.0 3704832.0 7409664.0 1333248.0 743424.0 3999744.0 7999488.0 0 +849408.0 431616.0 2548224.0 5096448.0 873984.0 456192.0 2621952.0 5243904.0 245760.0 +842240.0 432640.0 2526720.0 5053440.0 883200.0 473600.0 2649600.0 5299200.0 0 +842240.0 432640.0 2526720.0 5053440.0 883200.0 473600.0 2649600.0 5299200.0 0 +1251840.0 637440.0 3755520.0 7511040.0 1292800.0 678400.0 3878400.0 7756800.0 819200.0 diff --git a/supernet_functions/lookup_table_simple.txt b/supernet_functions/lookup_table_simple.txt new file mode 100644 index 0000000..6743534 --- /dev/null +++ b/supernet_functions/lookup_table_simple.txt @@ -0,0 +1,18 @@ +s_k3 s_k5 s_k7 r_k3 r_k5 r_k7 ir_k5_e3 ir_k5_e6 skip +1867776.0 1081344.0 5603328.0 11206656.0 2392064.0 1605632.0 7176192.0 14352384.0 524288.0 +802816.0 475136.0 2408448.0 4816896.0 1064960.0 737280.0 3194880.0 6389760.0 393216.0 +1400832.0 811008.0 4202496.0 8404992.0 1794048.0 1204224.0 5382144.0 10764288.0 0 +841728.0 448512.0 2525184.0 5050368.0 940032.0 546816.0 2820096.0 5640192.0 196608.0 +598016.0 335872.0 1794048.0 3588096.0 729088.0 466944.0 2187264.0 4374528.0 0 
diff --git a/supernet_functions/lookup_table_simple.txt b/supernet_functions/lookup_table_simple.txt
new file mode 100644
index 0000000..6743534
--- /dev/null
+++ b/supernet_functions/lookup_table_simple.txt
@@ -0,0 +1,18 @@
+s_k3 s_k5 s_k7 r_k3 r_k5 r_k7 ir_k5_e3 ir_k5_e6 skip
+1867776.0 1081344.0 5603328.0 11206656.0 2392064.0 1605632.0 7176192.0 14352384.0 524288.0
+802816.0 475136.0 2408448.0 4816896.0 1064960.0 737280.0 3194880.0 6389760.0 393216.0
+1400832.0 811008.0 4202496.0 8404992.0 1794048.0 1204224.0 5382144.0 10764288.0 0
+841728.0 448512.0 2525184.0 5050368.0 940032.0 546816.0 2820096.0 5640192.0 196608.0
+598016.0 335872.0 1794048.0 3588096.0 729088.0 466944.0 2187264.0 4374528.0 0
+598016.0 335872.0 1794048.0 3588096.0 729088.0 466944.0 2187264.0 4374528.0 0
+411648.0 215040.0 1234944.0 2469888.0 444416.0 247808.0 1333248.0 2666496.0 131072.0
+561152.0 299008.0 1683456.0 3366912.0 626688.0 364544.0 1880064.0 3760128.0 0
+561152.0 299008.0 1683456.0 3366912.0 626688.0 364544.0 1880064.0 3760128.0 0
+561152.0 299008.0 1683456.0 3366912.0 626688.0 364544.0 1880064.0 3760128.0 0
+692224.0 364544.0 2076672.0 4153344.0 757760.0 430080.0 2273280.0 4546560.0 393216.0
+1234944.0 645120.0 3704832.0 7409664.0 1333248.0 743424.0 3999744.0 7999488.0 0
+1234944.0 645120.0 3704832.0 7409664.0 1333248.0 743424.0 3999744.0 7999488.0 0
+849408.0 431616.0 2548224.0 5096448.0 873984.0 456192.0 2621952.0 5243904.0 245760.0
+842240.0 432640.0 2526720.0 5053440.0 883200.0 473600.0 2649600.0 5299200.0 0
+842240.0 432640.0 2526720.0 5053440.0 883200.0 473600.0 2649600.0 5299200.0 0
+1251840.0 637440.0 3755520.0 7511040.0 1292800.0 678400.0 3878400.0 7756800.0 819200.0
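With the tables and search spaces now keyed by supernet type, building the two lookup tables for a run looks like the following (an illustrative sketch based on the constructor signature above; it assumes a 'resnet' entry exists in both the CANDIDATE_BLOCKS and SEARCH_SPACE dicts, as the dispatch code later in this diff implies):

    from supernet_functions.lookup_table_builder import LookUpTable

    # note: 'calulate_latency' is the parameter's spelling in this codebase
    lookup_table = LookUpTable(calulate_latency=False,
                               path='./supernet_functions/lookup_table_resnet.txt',
                               supernet_type='resnet')
    params_lookup_table = LookUpTable(calulate_latency=False,
                                      path='./supernet_functions/params_lookup_table_resnet.txt',
                                      supernet_type='resnet')
    print(lookup_table.get_candidate_blocks())   # the candidate ops searched per layer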
diff --git a/supernet_functions/model_supernet.py b/supernet_functions/model_supernet.py
index cebcf6c..718b137 100644
--- a/supernet_functions/model_supernet.py
+++ b/supernet_functions/model_supernet.py
@@ -191,6 +191,89 @@ def get_flops(self, x, temperature):
         y = self.last_stages(y)
         return y, flops_list, params_list
 
+class FBNet_Stochastic_SuperNet_ResNet(nn.Module):
+    def __init__(self, lookup_table, params_lookup_table, cnt_classes=1000):
+        super(FBNet_Stochastic_SuperNet_ResNet, self).__init__()
+
+        data_shape = [1, 3, 32, 32]
+        # data_shape = [1, 3, 224, 224]
+        x = torch.zeros(data_shape).cuda()
+
+        self.first = ConvBNRelu(input_depth=3, output_depth=64, kernel=7, stride=2,
+                                pad=3, no_bias=1, use_relu="relu", bn_type="bn")
+        self.first_flops = self.first.get_flops(x, only_flops=True)
+        self.first_params = sum(p.numel() for p in self.first.parameters() if p.requires_grad)
+
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+        self.stages_to_search = nn.ModuleList([MixedOperation(
+                                                   lookup_table.layers_parameters[layer_id],
+                                                   lookup_table.lookup_table_operations,
+                                                   lookup_table.lookup_table_flops[layer_id],
+                                                   params_lookup_table.lookup_table_flops[layer_id])
+                                               for layer_id in range(lookup_table.cnt_layers)])
+        self.last_stages = nn.Sequential(OrderedDict([
+            ("avg_pool_k7", nn.AdaptiveAvgPool2d((1, 1))),
+            ("flatten", Flatten()),
+            ("fc", nn.Linear(in_features=512, out_features=cnt_classes)),
+        ]))
+
+        # last-stage flops estimate: a temporary 1x1 conv up to 1280 channels plus the
+        # fc weights (1280 x cnt_classes). Note that this ResNet head itself only has
+        # the 512 -> cnt_classes fc; the 1280-channel conv keeps the MobileNetV2-style accounting.
+        last_conv_temp = ConvBNRelu(input_depth=lookup_table.layers_parameters[-1][1], output_depth=1280, kernel=1, stride=1,
+                                    pad=0, no_bias=1, use_relu="relu", bn_type="bn")
+
+        # if the strides change, the 4x4 spatial size here must change too!!
+        data_shape = [1, lookup_table.layers_parameters[-1][1], 4, 4]
+        x = torch.zeros(data_shape).cuda()
+
+        self.last_stages_flops = last_conv_temp.get_flops(x, only_flops=True) + nn.Linear(in_features=1280, out_features=cnt_classes).weight.numel()
+
+        self.last_stages_params = sum(p.numel() for p in self.last_stages.parameters() if p.requires_grad)
+        del data_shape, x, last_conv_temp
+
+    def forward(self, x, temperature, flops_to_accumulate, params_to_accumulate, sampling_mode=None):
+        y = self.first(x)
+        # add flops and params from the first layer
+        flops_to_accumulate += self.first.get_flops(x, only_flops=True)
+        params_to_accumulate += self.first_params
+
+        y = self.maxpool(y)
+
+        for mixed_op in self.stages_to_search:
+            y, flops_to_accumulate, params_to_accumulate = mixed_op(y, temperature, flops_to_accumulate, params_to_accumulate, sampling_mode)
+        y = self.last_stages(y)
+
+        # add flops and params from the last stage
+        flops_to_accumulate += self.last_stages_flops
+        params_to_accumulate += self.last_stages_params
+
+        return y, flops_to_accumulate, params_to_accumulate
+
+    # TODO -
+    def get_flops(self, x, temperature):
+        flops_list = []
+        params_list = []
+
+        y = self.first(x)
+        y = self.maxpool(y)
+        for mixed_op in self.stages_to_search:
+            y = mixed_op.get_flops(y, temperature)
+
+        for mixed_op in self.stages_to_search:
+            flops_list.append(mixed_op.flops)
+            params_list.append(mixed_op.get_params())
+            print('flops', mixed_op.flops)
+            print('params', mixed_op.get_params())
+
+        y = self.last_stages(y)
+        return y, flops_list, params_list
+
+
 class SupernetLoss(nn.Module):
     def __init__(self, alpha, beta, reg_lambda, reg_loss_type, ref_value = 30 * 1e6, apply_flop_loss="False") :
         super(SupernetLoss, self).__init__()
diff --git a/supernet_functions/params_lookup_table_resnet.txt b/supernet_functions/params_lookup_table_resnet.txt
new file mode 100644
index 0000000..92656c2
--- /dev/null
+++ b/supernet_functions/params_lookup_table_resnet.txt
@@ -0,0 +1,9 @@
+r_k3 r_k5 r_k7
+1984 1216 5888
+896 576 2592
+1512 936 4440
+1720 1048 5032
+2528 1504 7456
+2528 1504 7456
+3616 2080 10592
+9152 5056 27200
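MixedOperation itself is defined earlier in model_supernet.py and is not shown in this diff; conceptually, each searched layer mixes its candidate blocks with Gumbel-softmax weights over the layer's thetas and adds the same weighted combination of lookup-table costs to the running totals. A simplified stand-in to illustrate the pattern (my own sketch; the real class also takes the layer parameters and handles sampling_mode and the params table):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class MixedOpSketch(nn.Module):
        """Illustrative stand-in for MixedOperation (not the project's class)."""
        def __init__(self, ops, flops_per_op):
            super().__init__()
            self.ops = nn.ModuleList(ops)                     # candidate blocks for this layer
            self.flops = torch.tensor(flops_per_op)           # costs from the lookup table
            self.thetas = nn.Parameter(torch.ones(len(ops)))  # architecture parameters

        def forward(self, x, temperature, flops_to_accumulate):
            # soft one-hot weights over candidates (Gumbel-softmax relaxation)
            weights = F.gumbel_softmax(self.thetas, tau=temperature)
            y = sum(w * op(x) for w, op in zip(weights, self.ops))
            # expected flops of this layer under the current architecture distribution
            flops_to_accumulate = flops_to_accumulate + (weights * self.flops.to(x.device)).sum()
            return y, flops_to_accumulate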
diff --git a/supernet_functions/training_functions_supernet.py b/supernet_functions/training_functions_supernet.py
index f6465ce..043046c 100644
--- a/supernet_functions/training_functions_supernet.py
+++ b/supernet_functions/training_functions_supernet.py
@@ -134,7 +134,9 @@ def train_loop(self, train_w_loader, train_thetas_loader, test_loader, model, sa
                                 info_for_logger="_theta_step_", sampling_mode=sampling_mode[2])
 
             theta_list = []
-            for i in range(17):
+
+            for i in range(8):  # 8 searched layers for the 'resnet' supernet type (17 for 'mobilenetv2')
+            #for i in range(17):
                 temp_list = self.theta_optimizer.param_groups[0]['params'][i].tolist()
                 theta_list.append(temp_list)
diff --git a/supernet_main_file.py b/supernet_main_file.py
index 73eb3f7..a70e3f0 100644
--- a/supernet_main_file.py
+++ b/supernet_main_file.py
@@ -14,7 +14,7 @@
 from general_functions.utils import get_logger, weights_init, load, create_directories_from_list, \
                                     check_tensor_in_list, writh_new_ARCH_to_fbnet_modeldef
 from supernet_functions.lookup_table_builder import LookUpTable
-from supernet_functions.model_supernet import FBNet_Stochastic_SuperNet, SupernetLoss
+from supernet_functions.model_supernet import FBNet_Stochastic_SuperNet, FBNet_Stochastic_SuperNet_ResNet, SupernetLoss
 from supernet_functions.training_functions_supernet import TrainerSupernet
 from supernet_functions.config_for_supernet import CONFIG_SUPERNET
 from fbnet_building_blocks.fbnet_modeldef import MODEL_ARCH
@@ -22,6 +22,7 @@
 import fbnet_building_blocks.fbnet_builder as fbnet_builder
 
+
 parser = argparse.ArgumentParser("action")
 parser.add_argument('--train_or_sample', type=str, default='', \
                     help='train means training of the supernet, sample means sample from supernet\'s results')
@@ -37,6 +38,8 @@
                     help='gpu number to use')
 parser.add_argument('--dataset', type=str, default='cifar10', \
                     help='using dataset')
+parser.add_argument('--supernet_type', type=str, default='mobilenetv2', \
+                    help='supernet type')
 
 # SGD optimizer - weight
 parser.add_argument('--w_lr', type=float, default=0.1, \
@@ -148,8 +151,8 @@ def train_supernet():
     writer = SummaryWriter(log_dir=join(save_path, 'tb'))
 
     #### lookup table contains all information about layers
-    lookup_table = LookUpTable(calulate_latency=False, path=args.flops_LUT_path)
-    params_lookup_table = LookUpTable(calulate_latency=False, path=args.params_LUT_path)
+    lookup_table = LookUpTable(calulate_latency=False, path=args.flops_LUT_path, supernet_type=args.supernet_type)
+    params_lookup_table = LookUpTable(calulate_latency=False, path=args.params_LUT_path, supernet_type=args.supernet_type)
 
     #### dataloading
     train_w_loader, train_thetas_loader = get_loaders(args.data_split,
@@ -163,7 +166,10 @@
     #### model
     if args.dataset == 'cifar10':
-        model = FBNet_Stochastic_SuperNet(lookup_table, params_lookup_table, cnt_classes=10).cuda()
+        if args.supernet_type == 'resnet':
+            model = FBNet_Stochastic_SuperNet_ResNet(lookup_table, params_lookup_table, cnt_classes=10).cuda()
+        else:
+            model = FBNet_Stochastic_SuperNet(lookup_table, params_lookup_table, cnt_classes=10).cuda()
     elif args.dataset == 'cifar100':
         model = FBNet_Stochastic_SuperNet(lookup_table, params_lookup_table, cnt_classes=100).cuda()
     elif args.dataset == 'tiny_imagenet':
@@ -188,7 +194,12 @@
     model = model.apply(weights_init)
     model = nn.DataParallel(model, device_ids=[0])
-    # print(model)
+
+    #print(model)
+
+    #from torchsummary import summary
+    #summary(model, (3, 32, 32))
+
     #### loss, optimizer and scheduler
     criterion = SupernetLoss(reg_loss_type=args.reg_loss_type, alpha=args.alpha, beta=args.beta,
                              reg_lambda=args.reg_lambda, ref_value=args.ref_value).cuda()
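SupernetLoss then combines the task loss with the resource totals the forward pass accumulated. The exact formula depends on reg_loss_type, alpha, beta, and reg_lambda; the classic FBNet-style variant multiplies cross-entropy by a scaled power of the expected cost. A rough sketch of that one variant only (my own reading of the signature above, not the project's exact implementation):

    import torch.nn as nn

    class SupernetLossSketch(nn.Module):
        """One plausible FBNet-style variant: CE scaled by (flops / ref_value)^beta."""
        def __init__(self, alpha=0.2, beta=0.6, ref_value=30 * 1e6):
            super().__init__()
            self.alpha, self.beta, self.ref_value = alpha, beta, ref_value
            self.ce = nn.CrossEntropyLoss()

        def forward(self, outs, targets, flops_to_accumulate):
            ce_loss = self.ce(outs, targets)
            flops_term = (flops_to_accumulate / self.ref_value) ** self.beta
            return self.alpha * ce_loss * flops_term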
".join(arch_operations)) - writh_new_ARCH_to_fbnet_modeldef(arch_operations, my_unique_name_for_ARCH=unique_name_of_arch) + writh_new_ARCH_to_fbnet_modeldef(arch_operations, my_unique_name_for_ARCH=unique_name_of_arch, supernet_type=args.supernet_type) logger.info("congratulations! new architecture " + unique_name_of_arch \ + " was written into fbnet_building_blocks/fbnet_modeldef.py") def check_flops(): #### lookup table consists all information about layers - lookup_table = LookUpTable(calulate_latency=False, path=args.flops_LUT_path) - params_lookup_table = LookUpTable(calulate_latency=False, path=args.params_LUT_path) + lookup_table = LookUpTable(calulate_latency=False, path=args.flops_LUT_path, supernet_type=args.supernet_type) + params_lookup_table = LookUpTable(calulate_latency=False, path=args.params_LUT_path, supernet_type=args.supernet_type) #### dataloading data_shape = [1, 3, 32, 32] @@ -288,7 +302,9 @@ def check_flops(): params_lookup_table.write_lookup_table_to_file(path_to_file=args.params_LUT_path, flops_list=params_list) - print(params_list) + #print(params_list) + + if __name__ == "__main__": # set gpu number to use diff --git a/yaml/FBNet_DoReFa_int2_120epoch.yaml b/yaml/FBNet_DoReFa_int2_120epoch.yaml new file mode 100644 index 0000000..e3bf9cf --- /dev/null +++ b/yaml/FBNet_DoReFa_int2_120epoch.yaml @@ -0,0 +1,31 @@ + +quantizers: + dorefa_quantizer: + class: DorefaQuantizer + bits_activations: 2 + bits_weights: 2 + + + +lr_schedulers: + training_lr: + # class: MultiStepMultiGammaLR + class: CosineAnnealingLR + T_max: 120 + eta_min: 0.001 + #milestones: [60, 120, 180] + #gammas: [0.1, 0.1, 0.2] + +policies: + - quantizer: + instance_name: dorefa_quantizer + starting_epoch: 0 + ending_epoch: 120 + frequency: 1 + + - lr_scheduler: + instance_name: training_lr + starting_epoch: 0 + ending_epoch: 120 + frequency: 1 + diff --git a/yaml/FBNet_DoReFa_int2_180epoch.yaml b/yaml/FBNet_DoReFa_int2_180epoch.yaml new file mode 100644 index 0000000..9f3fba2 --- /dev/null +++ b/yaml/FBNet_DoReFa_int2_180epoch.yaml @@ -0,0 +1,31 @@ + +quantizers: + dorefa_quantizer: + class: DorefaQuantizer + bits_activations: 2 + bits_weights: 2 + + + +lr_schedulers: + training_lr: + # class: MultiStepMultiGammaLR + class: CosineAnnealingLR + T_max: 180 + eta_min: 0.001 + #milestones: [60, 120, 180] + #gammas: [0.1, 0.1, 0.2] + +policies: + - quantizer: + instance_name: dorefa_quantizer + starting_epoch: 0 + ending_epoch: 180 + frequency: 1 + + - lr_scheduler: + instance_name: training_lr + starting_epoch: 0 + ending_epoch: 180 + frequency: 1 + diff --git a/yaml/FBNet_DoReFa_int2_240epoch.yaml b/yaml/FBNet_DoReFa_int2_240epoch.yaml new file mode 100644 index 0000000..49e359c --- /dev/null +++ b/yaml/FBNet_DoReFa_int2_240epoch.yaml @@ -0,0 +1,31 @@ + +quantizers: + dorefa_quantizer: + class: DorefaQuantizer + bits_activations: 2 + bits_weights: 2 + + + +lr_schedulers: + training_lr: + # class: MultiStepMultiGammaLR + class: CosineAnnealingLR + T_max: 240 + eta_min: 0.001 + #milestones: [60, 120, 180] + #gammas: [0.1, 0.1, 0.2] + +policies: + - quantizer: + instance_name: dorefa_quantizer + starting_epoch: 0 + ending_epoch: 240 + frequency: 1 + + - lr_scheduler: + instance_name: training_lr + starting_epoch: 0 + ending_epoch: 240 + frequency: 1 + diff --git a/yaml/mb_int2_80_10.yaml b/yaml/mb_int2_80_10.yaml new file mode 100644 index 0000000..8f7663c --- /dev/null +++ b/yaml/mb_int2_80_10.yaml @@ -0,0 +1,31 @@ + +quantizers: + dorefa_quantizer: + class: DorefaQuantizer + bits_activations: 2 + 
diff --git a/yaml/mb_int2_80_10.yaml b/yaml/mb_int2_80_10.yaml
new file mode 100644
index 0000000..8f7663c
--- /dev/null
+++ b/yaml/mb_int2_80_10.yaml
@@ -0,0 +1,31 @@
+
+quantizers:
+  dorefa_quantizer:
+    class: DorefaQuantizer
+    bits_activations: 2
+    bits_weights: 2
+
+
+
+lr_schedulers:
+  training_lr:
+    # class: MultiStepMultiGammaLR
+    class: CosineAnnealingLR
+    T_max: 80
+    eta_min: 0.001
+    #milestones: [60, 120, 180]
+    #gammas: [0.1, 0.1, 0.2]
+
+policies:
+  - quantizer:
+      instance_name: dorefa_quantizer
+      starting_epoch: 0
+      ending_epoch: 80
+      frequency: 1
+
+  - lr_scheduler:
+      instance_name: training_lr
+      starting_epoch: 0
+      ending_epoch: 80
+      frequency: 1
+