Skip to content
This repository has been archived by the owner on Dec 3, 2024. It is now read-only.

Commit

Permalink
Merge branch 'main' into macsz/plot
Browse files Browse the repository at this point in the history
  • Loading branch information
Maciej Szankin committed Feb 5, 2024
2 parents d6d9d22 + 1b9dc9f commit c23fb25
Show file tree
Hide file tree
Showing 8 changed files with 190 additions and 144 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/unit_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.7", "3.10"]
python-version: ["3.8", "3.10"]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -230,5 +230,5 @@ OMP_NUM_THREADS=28 mpirun \

## Legal Disclaimer and Notices

> This “research quality code” is for Non-Commercial purposes provided by Intel “As Is” without any express or implied warranty of any kind. Please see the dataset's applicable license for terms and conditions. Intel does not own the rights to this data set and does not confer any rights to it. Intel does not warrant or assume responsibility for the accuracy or completeness of any information, text, graphics, links or other items within the code. A thorough security review has not been performed on this code. Additionally, this repository may contain components that are out of date or contain known security vulnerabilities.
> This “research quality code” is for Non-Commercial purposes provided by Intel “As Is” without any express or implied warranty of any kind. Please see the dataset's applicable license for terms and conditions. Intel does not own the rights to this data set and does not confer any rights to it. Intel does not warrant or assume responsibility for the accuracy or completeness of any information, text, graphics, links or other items within the code. A thorough security review has not been performed on this code.
> ImageNet, WMT, SST2: Please see the dataset's applicable license for terms and conditions. Intel does not own the rights to this data set and does not confer any rights to it.
290 changes: 149 additions & 141 deletions dynast/search/search_tactic.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import pandas as pd
import torch.distributed as dist
from pymoo.core.problem import Problem

from dynast.predictors.predictor_manager import PredictorManager
from dynast.search.evolutionary import (
Expand Down Expand Up @@ -264,6 +265,154 @@ def get_best_configs(self, sort_by: str = None, ascending: bool = False, limit:
return df


class Evolutionary(NASBaseConfig):
search_manager: EvolutionaryManager = None
problem: Problem = None

def __init__(
self,
supernet: str,
optimization_metrics: list,
measurements: list,
num_evals: int,
results_path: str,
dataset_path: str = None,
verbose: bool = False,
search_algo: str = 'nsga2',
population: int = 50,
seed: int = 42,
batch_size: int = 128,
eval_batch_size: int = 128,
supernet_ckpt_path: str = None,
device: str = 'cpu',
test_fraction: float = 1.0,
mp_calibration_samples: int = 100,
dataloader_workers: int = 4,
metric_eval_fns: dict = None,
**kwargs,
):
super().__init__(
dataset_path=dataset_path,
supernet=supernet,
optimization_metrics=optimization_metrics,
measurements=measurements,
num_evals=num_evals,
results_path=results_path,
seed=seed,
population=population,
batch_size=batch_size,
eval_batch_size=eval_batch_size,
verbose=verbose,
search_algo=search_algo,
supernet_ckpt_path=supernet_ckpt_path,
device=device,
test_fraction=test_fraction,
mp_calibration_samples=mp_calibration_samples,
dataloader_workers=dataloader_workers,
metric_eval_fns=metric_eval_fns,
**kwargs,
)

def _init_evolutionary_manager(self):
# Following sets up the algorithm based on number of objectives
# Could be refractored at the expense of readability
if self.num_objectives == 1:
self.problem = EvolutionarySingleObjective(
evaluation_interface=self.validation_interface,
param_count=self.supernet_manager.param_count,
param_upperbound=self.supernet_manager.param_upperbound,
)
if self.search_algo == 'cmaes':
self.search_manager = EvolutionaryManager(
algorithm='cmaes',
seed=self.seed,
n_obj=self.num_objectives,
verbose=self.verbose,
)
self.search_manager.configure_cmaes(num_evals=self.num_evals)
else:
self.search_manager = EvolutionaryManager(
algorithm='ga',
seed=self.seed,
n_obj=self.num_objectives,
verbose=self.verbose,
)
self.search_manager.configure_ga(population=self.population, num_evals=self.num_evals)
elif self.num_objectives == 2:
self.problem = EvolutionaryMultiObjective(
evaluation_interface=self.validation_interface,
param_count=self.supernet_manager.param_count,
param_upperbound=self.supernet_manager.param_upperbound,
)
if self.search_algo == 'age':
self.search_manager = EvolutionaryManager(
algorithm='age',
seed=self.seed,
n_obj=self.num_objectives,
verbose=self.verbose,
)
self.search_manager.configure_age(population=self.population, num_evals=self.num_evals)
else:
self.search_manager = EvolutionaryManager(
algorithm='nsga2',
seed=self.seed,
n_obj=self.num_objectives,
verbose=self.verbose,
)
self.search_manager.configure_nsga2(population=self.population, num_evals=self.num_evals)
elif self.num_objectives == 3:
self.problem = EvolutionaryManyObjective(
evaluation_interface=self.validation_interface,
param_count=self.supernet_manager.param_count,
param_upperbound=self.supernet_manager.param_upperbound,
)
if self.search_algo == 'ctaea':
self.search_manager = EvolutionaryManager(
algorithm='ctaea',
seed=self.seed,
n_obj=self.num_objectives,
verbose=self.verbose,
)
self.search_manager.configure_ctaea(num_evals=self.num_evals)
elif self.search_algo == 'moead':
self.search_manager = EvolutionaryManager(
algorithm='moead',
seed=self.seed,
n_obj=self.num_objectives,
verbose=self.verbose,
)
self.search_manager.configure_moead(num_evals=self.num_evals)
else:
self.search_manager = EvolutionaryManager(
algorithm='unsga3',
seed=self.seed,
n_obj=self.num_objectives,
verbose=self.verbose,
)
self.search_manager.configure_unsga3(population=self.population, num_evals=self.num_evals)
else:
log.error('Number of objectives not supported. Update optimization_metrics!')

def search(self):
self._init_search()
self._init_evolutionary_manager()

results = self.search_manager.run_search(self.problem)

latest_population = results.pop.get('X')

log.info("Validated model architectures in file: {}".format(self.results_path))

output = list()
for individual in latest_population:
param_individual = self.supernet_manager.translate2param(individual)
if 'bootstrapnas' in self.supernet:
param_individual = BootstrapNASEncoding.convert_subnet_config_to_bootstrapnas(param_individual)
output.append(param_individual)

return output


class LINAS(NASBaseConfig):
"""The LINAS algorithm is a bi-objective optimization approach that explores the sub-networks
optimization space by iteratively training predictors and using evolutionary algorithms to
Expand Down Expand Up @@ -579,147 +728,6 @@ def search(self):
return output


class Evolutionary(NASBaseConfig):
def __init__(
self,
supernet,
optimization_metrics,
measurements,
num_evals,
results_path,
dataset_path: str = None,
seed=42,
population=50,
batch_size: int = 128,
eval_batch_size: int = 128,
verbose=False,
search_algo='nsga2',
supernet_ckpt_path=None,
test_fraction: float = 1.0,
mp_calibration_samples: int = 100,
dataloader_workers: int = 4,
device: str = 'cpu',
**kwargs,
):
super().__init__(
dataset_path=dataset_path,
supernet=supernet,
optimization_metrics=optimization_metrics,
measurements=measurements,
num_evals=num_evals,
results_path=results_path,
seed=seed,
population=population,
batch_size=batch_size,
eval_batch_size=eval_batch_size,
verbose=verbose,
search_algo=search_algo,
supernet_ckpt_path=supernet_ckpt_path,
device=device,
test_fraction=test_fraction,
mp_calibration_samples=mp_calibration_samples,
dataloader_workers=dataloader_workers,
**kwargs,
)

def search(self):
self._init_search()

# Following sets up the algorithm based on number of objectives
# Could be refractored at the expense of readability
if self.num_objectives == 1:
problem = EvolutionarySingleObjective(
evaluation_interface=self.validation_interface,
param_count=self.supernet_manager.param_count,
param_upperbound=self.supernet_manager.param_upperbound,
)
if self.search_algo == 'cmaes':
search_manager = EvolutionaryManager(
algorithm='cmaes',
seed=self.seed,
n_obj=self.num_objectives,
verbose=self.verbose,
)
search_manager.configure_cmaes(num_evals=self.num_evals)
else:
search_manager = EvolutionaryManager(
algorithm='ga',
seed=self.seed,
n_obj=self.num_objectives,
verbose=self.verbose,
)
search_manager.configure_ga(population=self.population, num_evals=self.num_evals)
elif self.num_objectives == 2:
problem = EvolutionaryMultiObjective(
evaluation_interface=self.validation_interface,
param_count=self.supernet_manager.param_count,
param_upperbound=self.supernet_manager.param_upperbound,
)
if self.search_algo == 'age':
search_manager = EvolutionaryManager(
algorithm='age',
seed=self.seed,
n_obj=self.num_objectives,
verbose=self.verbose,
)
search_manager.configure_age(population=self.population, num_evals=self.num_evals)
else:
search_manager = EvolutionaryManager(
algorithm='nsga2',
seed=self.seed,
n_obj=self.num_objectives,
verbose=self.verbose,
)
search_manager.configure_nsga2(population=self.population, num_evals=self.num_evals)
elif self.num_objectives == 3:
problem = EvolutionaryManyObjective(
evaluation_interface=self.validation_interface,
param_count=self.supernet_manager.param_count,
param_upperbound=self.supernet_manager.param_upperbound,
)
if self.search_algo == 'ctaea':
search_manager = EvolutionaryManager(
algorithm='ctaea',
seed=self.seed,
n_obj=self.num_objectives,
verbose=self.verbose,
)
search_manager.configure_ctaea(num_evals=self.num_evals)
elif self.search_algo == 'moead':
search_manager = EvolutionaryManager(
algorithm='moead',
seed=self.seed,
n_obj=self.num_objectives,
verbose=self.verbose,
)
search_manager.configure_moead(num_evals=self.num_evals)
else:
search_manager = EvolutionaryManager(
algorithm='unsga3',
seed=self.seed,
n_obj=self.num_objectives,
verbose=self.verbose,
)
search_manager.configure_unsga3(population=self.population, num_evals=self.num_evals)
else:
log.error('Number of objectives not supported. Update optimization_metrics!')

results = search_manager.run_search(problem)

latest_population = results.pop.get('X')

log.info("Validated model architectures in file: {}".format(self.results_path))

output = list()
for individual in latest_population:
param_individual = self.supernet_manager.translate2param(individual)
if 'bootstrapnas' in self.supernet:
param_individual = BootstrapNASEncoding.convert_subnet_config_to_bootstrapnas(param_individual)
output.append(param_individual)

return output


class RandomSearch(NASBaseConfig):
def __init__(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def validate_top1(self, subnet_cfg, device=None) -> float:
self.run_config,
init=False,
verbose=self.verbose,
no_gpu=False if 'cuda' in self.device else True,
)
run_manager.reset_running_statistics(net=subnet)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ def validate_quantized_top1(self, subnet_cfg, device=None) -> float:
self.run_config,
init=False,
verbose=self.verbose,
no_gpu=False if 'cuda' in self.device else True,
)
run_manager.reset_running_statistics(net=subnet)

Expand Down
4 changes: 3 additions & 1 deletion tests/scripts/config.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@ SEED=37
RESULTS_PATH="/tmp"
DATASET_IMAGENET_PATH="/datasets/imagenet-ilsvrc2012/"
DATASET_CIFAR10_PATH="/tmp/cifar10/"
DEVICE="cpu"
DATASET_SST2_PATH="/nfs/site/home/mszankin/store/nosnap/datasets/SST-2"
BATCH_SIZE=128
TEST_FRACTION=1.0

CHECKPOINT_VIT_BASE_IMAGENET_PATH="/tmp/vit/checkpoint.pth.tar"

CHECKPOINT_BERT_BASE_SST2_PATH="/nfs/site/home/mszankin/store/nosnap/models/glue_ckpt.pt"

########################################################################################################
# SHORT runs config. Shoud use random search tactic to allow for a very limited number of evaluations. #
Expand Down
18 changes: 18 additions & 0 deletions tests/scripts/run_bert_base_sst2_evolutionary_long.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#!/usr/bin/env bash

source $( dirname -- "$0"; )/config.sh

time ${RUN_COMMAND} \
--results_path ${RESULTS_PATH}/results_bert_base_sst2_evolutionary_long.csv \
--supernet bert_base_sst2 \
--supernet_ckpt_path ${CHECKPOINT_BERT_BASE_SST2_PATH} \
--dataset_path ${DATASET_SST2_PATH} \
--search_tactic evolutionary \
--population ${LONG_LINAS_POPULATION} \
--batch_size ${BATCH_SIZE} \
--seed ${SEED} \
--measurements macs accuracy_sst2 latency params \
--optimization_metrics macs accuracy_sst2 \
--num_evals ${LONG_LINAS_NUM_EVALS} \
--device ${DEVICE} \
--test_fraction ${TEST_FRACTION}
16 changes: 16 additions & 0 deletions tests/scripts/run_ofambv3_evolutionary_long.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#!/usr/bin/env bash

source $( dirname -- "$0"; )/config.sh

time ${RUN_COMMAND} \
--results_path ${RESULTS_PATH}/results_ofambv3_evolutionary_long.csv \
--supernet ofa_mbv3_d234_e346_k357_w1.0 \
--dataset_path ${DATASET_IMAGENET_PATH} \
--search_tactic evolutionary \
--population ${LONG_LINAS_POPULATION} \
--batch_size ${BATCH_SIZE} \
--seed ${SEED} \
--measurements macs accuracy_top1 \
--num_evals ${LONG_LINAS_NUM_EVALS} \
--device ${DEVICE} \
--test_fraction ${TEST_FRACTION}

0 comments on commit c23fb25

Please sign in to comment.