diff --git a/README.md b/README.md index 53ec852..04fb5c9 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,6 @@ # POET -By Shishir G. Patil, Paras Jain, Prabal Dutta, Ion Stoica, and Joseph E. Gonzalez ([Project Website](https://poet.cs.berkeley.edu/)) + +By Shishir G. Patil, Paras Jain, Prabal Dutta, Ion Stoica, and Joseph E. Gonzalez ([Project Website](https://poet.cs.berkeley.edu/)) ![](https://github.com/ShishirPatil/poet/blob/gh-pages/assets/img/logo.png) @@ -12,7 +13,6 @@ ResNets on smartphones and tiny ARM Cortex-M devices :muscle: Reach out to us at [sgp@berkeley.edu](mailto:sgp@berkeley.edu), if you have large models that you are trying to train - be it on GPUs, or your commodity edge devices such as laptops, smartphones, raspberry-pis, ARM Cortex M and A class, fitbits, etc. - ## Get Started [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1iup_edJd9zB1tfVBHXLmkWOT5yoSmXzz?usp=sharing) ### Installation @@ -35,13 +35,13 @@ If you are affiliated with an academic institution, you can acquire a free Gurob 1. Create a free Gurobi account [here](https://pages.gurobi.com/registration). Make sure to specify the Academic user option. - image +image 2. Complete the rest of the Gurobi account creation process, which will include creating a password and verifying your email address. 3. Login to the Gurobi [Web License Manager](https://license.gurobi.com/) using your new account. 4. Create and download a new Web License file. It will be called `gurobi.lic`. - image +image 5. Move the `gurobi.lic` file to your home directory (i.e. to `~/gurobi.lic` on MacOS/Linux, or `C:\Users\YOUR_USERNAME\gurobi.lic` on Windows). @@ -71,10 +71,10 @@ solve( ) ``` - ## Key ideas From our [paper at ICML 2022](https://arxiv.org/abs/2207.07697): + ```text In this work, we show that paging and rematerialization are highly complementary. 
By carefully rematerializing cheap operations while paging results of expensive operations diff --git a/poet-server/README.md b/poet-server/README.md index dba92c2..37a5164 100644 --- a/poet-server/README.md +++ b/poet-server/README.md @@ -41,7 +41,6 @@ docker run -p 80:80 -v ~/gurobi.lic:/opt/gurobi/gurobi.lic public.ecr.aws/shishi Or, you can build the docker container yourself following the steps below. - 1. Ensure you have [Docker Compose](https://docs.docker.com/compose/install/) installed. 2. Clone this repository by running `git clone https://github.com/ShishirPatil/poet`. 3. If using Gurobi, move the `gurobi.lic` file you downloaded in the previous step to the `poet-server` directory of this repository (i.e. to `poet-server/gurobi.lic`). @@ -53,7 +52,6 @@ Or, you can build the docker container yourself following the steps below. Ensure that you have moved the `gurobi.lic` file (if you want to use the Gurobi optimizer) you downloaded earlier to the EC2 instance. Ensure that Port 80 is open for ingress traffic. - ## Making Requests To issue requests to the POET server, you can use the following Python code. 
def solve_wrapper(params):
    """Call ``solve`` with a dict of keyword arguments.

    ``Executor.map`` hands each work item to the callable as one positional
    argument; this module-level adapter expands that dict into keyword
    arguments for ``solve``.

    :param params: Mapping of ``solve`` keyword-argument names to values.
    :return: Whatever ``solve`` returns for those arguments.
    """
    return solve(**params)


def pareto(
    model: Literal[
        "linear",
        "vgg16",
        "vgg16_cifar",
        "resnet18",
        "resnet50",
        "resnet18_cifar",
        "bert",
        "transformer",
    ],
    platform: Literal["m0", "a72", "a72nocache", "m4", "jetsontx2"],
    # ram_budget: float,
    runtime_budget: float = 1.4,
    mem_power_scale=1.0,
    batch_size=1,
    ram_budget_samples: int = 20,
    solver: Literal["gurobipy", "pulp-gurobi", "pulp-cbc"] = "gurobipy",
    time_limit_s: float = 1e100,
    solve_threads: int = 4,
    total_threads: int = os.cpu_count(),
    filename: str = "pareto.png",
):
    """Solve one model/platform pair over a sweep of RAM budgets and plot the results.

    The RAM budgets are ``ram_budget_samples`` evenly spaced values between the
    start and end budgets returned by ``get_chipset_and_net`` for ``model``.
    Each budget is solved in a worker process; every result is printed and
    added to a live plot of RAM budget against total power cost (CPU cost plus
    paging cost). The finished figure is saved to ``filename`` and then shown
    in a blocking window.

    :param model: The model to solve for.
    :param platform: The platform to solve for.
    :param runtime_budget: Forwarded unchanged to ``solve``.
    :param mem_power_scale: Forwarded to ``get_chipset_and_net`` and ``solve``.
    :param batch_size: Forwarded to ``get_chipset_and_net`` and ``solve``.
    :param ram_budget_samples: Number of RAM budgets to sample.
    :param solver: Forwarded unchanged to ``solve``.
    :param time_limit_s: Forwarded unchanged to ``solve``.
    :param solve_threads: Threads handed to each individual solve; must be >= 1.
    :param total_threads: Thread budget shared by all concurrent solves. A
        falsy value (``None`` or 0) falls back to ``os.cpu_count()``, then 1.
    :param filename: Path the figure is saved to.
    :raises ValueError: If ``solve_threads`` is smaller than 1.
    """
    if solve_threads < 1:
        raise ValueError(f"solve_threads must be >= 1, got {solve_threads}")

    # os.cpu_count() can return None, and the integer division below can hit 0,
    # which ProcessPoolExecutor rejects; always run at least one worker.
    available_threads = total_threads or os.cpu_count() or 1
    max_workers = max(1, available_threads // solve_threads)

    # Interactive mode lets points be drawn as soon as each solve is reported.
    plt.ion()

    chipset, net, ram_budget_start, ram_budget_end = get_chipset_and_net(
        platform=platform,
        model=model,
        batch_size=batch_size,
        mem_power_scale=mem_power_scale,
    )

    # Largest per-layer "memory_bytes" entry; printed next to the sweep bounds
    # and their ratios to it.
    base_memory = max(get_net_costs(net=net, device=chipset)["memory_bytes"])
    print(
        base_memory,
        ram_budget_start,
        ram_budget_end,
        ram_budget_start / base_memory,
        ram_budget_end / base_memory,
    )

    ram_budget_range = np.linspace(ram_budget_start, ram_budget_end, ram_budget_samples)

    g, *_ = make_dfgraph_costs(net=net, device=chipset)
    total_runtime = sum(g.cost_cpu.values())
    total_ram = sum(g.cost_ram[i] for i in g.vfwd)
    print(f"Total runtime of graph (forward + backward) = {total_runtime}")
    print(f"Total RAM consumption of forward pass = {total_ram}")
    print(f"### --- ### Total RAM consumption of forward pass = {total_ram}")
    print(max_workers)

    # One keyword-argument dict per sampled RAM budget.
    jobs = [
        dict(
            model=model,
            platform=platform,
            ram_budget=ram_budget,
            runtime_budget=runtime_budget,
            mem_power_scale=mem_power_scale,
            batch_size=batch_size,
            solver=solver,
            time_limit_s=time_limit_s,
            solve_threads=solve_threads,
        )
        for ram_budget in ram_budget_range
    ]

    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        for result in executor.map(solve_wrapper, jobs):
            print(
                result.total_power_cost_cpu,
                result.total_power_cost_page,
                result.ram_budget,
            )
            # A result without a CPU power cost is plotted at -1.
            # NOTE(review): presumably None means no feasible solution — confirm.
            if result.total_power_cost_cpu is None:
                total_power = -1
            else:
                total_power = result.total_power_cost_cpu + result.total_power_cost_page
            plt.plot(result.ram_budget, total_power, "r.")
            plt.draw()
            plt.pause(0.1)

    print("Done!")
    plt.savefig(filename)
    plt.show(block=True)


if __name__ == "__main__":
    pareto(model="resnet18_cifar", platform="m4", runtime_budget=1.2, time_limit_s=600)
""" - chipset, net = get_chipset_and_net( + chipset, net, *_ = get_chipset_and_net( platform=platform, model=model, batch_size=batch_size, @@ -142,9 +148,23 @@ def solve( "--model", type=str, required=True, - choices=["vgg16", "vgg16_cifar", "resnet18", "resnet50", "resnet18_cifar", "bert", "transformer", "linear"], + choices=[ + "vgg16", + "vgg16_cifar", + "resnet18", + "resnet50", + "resnet18_cifar", + "bert", + "transformer", + "linear", + ], + ) + parser.add_argument( + "--platform", + type=str, + required=True, + choices=["m0", "a72", "a72nocache", "m4", "jetsontx2"], ) - parser.add_argument("--platform", type=str, required=True, choices=["m0", "a72", "a72nocache", "m4", "jetsontx2"]) parser.add_argument("--ram-budget", type=int, required=True) parser.add_argument("--runtime-budget", type=float, required=True) parser.add_argument("--batch-size", type=int, default=1) @@ -153,7 +173,12 @@ def solve( parser.add_argument("--remat", action="store_true", default=True) parser.add_argument("--time-limit-s", type=int, default=1e100) parser.add_argument("--solve-threads", type=int, default=4) - parser.add_argument("--solver", type=str, default="gurobipy", choices=["gurobipy", "pulp-gurobi", "pulp-cbc"]) + parser.add_argument( + "--solver", + type=str, + default="gurobipy", + choices=["gurobipy", "pulp-gurobi", "pulp-cbc"], + ) parser.add_argument("--use-actual-gurobi", action="store_true", default=False) parser.add_argument("--print-power-costs", action="store_true", default=False) parser.add_argument("--print-graph-info", action="store_true", default=True) diff --git a/poet/util.py b/poet/util.py index acaabc4..a83d6e4 100644 --- a/poet/util.py +++ b/poet/util.py @@ -1,5 +1,5 @@ -from dataclasses import dataclass import pickle +from dataclasses import dataclass from pathlib import Path from typing import List @@ -52,7 +52,12 @@ def make_dfgraph_costs(net, device): for idx, (layer, specs) in enumerate(zip(net, per_layer_specs)): layer_name = "layer{}_{}".format(idx, 
layer.__class__.__name__) layer_names[layer] = layer_name - gb.add_node(layer_name, cpu_cost=specs["runtime_ms"], ram_cost=specs["memory_bytes"], backward=isinstance(layer, GradientLayer)) + gb.add_node( + layer_name, + cpu_cost=specs["runtime_ms"], + ram_cost=specs["memory_bytes"], + backward=isinstance(layer, GradientLayer), + ) gb.set_parameter_cost(gb.parameter_cost + specs["param_memory_bytes"]) page_in_cost_dict[layer_name] = specs["pagein_cost_joules"] page_out_cost_dict[layer_name] = specs["pageout_cost_joules"] @@ -91,36 +96,61 @@ def get_chipset_and_net(platform: str, model: str, batch_size: int, mem_power_sc elif platform == "jetsontx2": chipset = JetsonTX2 else: - raise NotImplementedError() + raise NotImplementedError(f"Platform {platform} not implemented.") chipset["MEMORY_POWER"] *= mem_power_scale if model == "linear": net = make_linear_network() + # TODO: these were randomly picked + ram_budget_start = 1.0e07 + ram_budget_end = 1.0e08 elif model == "vgg16": net = vgg16(batch_size) + ram_budget_start = 2.57e07 + ram_budget_end = 1.15e08 elif model == "vgg16_cifar": net = vgg16(batch_size, 10, (3, 32, 32)) + ram_budget_start = 2.57e07 / 49 + ram_budget_end = 1.15e08 / 49 elif model == "resnet18": net = resnet18(batch_size) + ram_budget_start = 6.42e06 + ram_budget_end = 2.85e07 elif model == "resnet50": net = resnet50(batch_size) + ram_budget_start = 6.97e06 + ram_budget_end = 1.27e08 elif model == "resnet18_cifar": net = resnet18_cifar(batch_size, 10, (3, 32, 32)) + ram_budget_start = 196608 + ram_budget_end = 2339408 elif model == "bert": net = BERTBase(SEQ_LEN=512, HIDDEN_DIM=768, I=64, HEADS=12, NUM_TRANSFORMER_BLOCKS=12) + # TODO: this is very broken + ram_budget_start = 1e6 + ram_budget_end = 1e9 elif model == "transformer": net = BERTBase(SEQ_LEN=512, HIDDEN_DIM=768, I=64, HEADS=12, NUM_TRANSFORMER_BLOCKS=1) + ram_budget_start = 1e5 + ram_budget_end = 7e7 else: - raise NotImplementedError() + raise NotImplementedError(f"Model {model} not 
class SolveStrategy(Enum):
    """Identifiers for the solving strategies known to this module.

    Every member's value is its own name as a string, so
    ``SolveStrategy("LB_LP")`` looks a member up by name. The class methods
    map members to display labels and plot styling.
    """

    NOT_SPECIFIED = "NOT_SPECIFIED"
    CHEN_SQRTN = "CHEN_SQRTN"
    CHEN_GREEDY = "CHEN_GREEDY"
    CHEN_SQRTN_NOAP = "CHEN_SQRTN_NOAP"
    CHEN_GREEDY_NOAP = "CHEN_GREEDY_NOAP"
    OPTIMAL_ILP_GC = "OPTIMAL_ILP_GC"
    CHECKPOINT_LAST_NODE = "CHECKPOINT_LAST_NODE"
    CHECKPOINT_ALL = "CHECKPOINT_ALL"
    CHECKPOINT_ALL_AP = "CHECKPOINT_ALL_AP"
    GRIEWANK_LOGN = "GRIEWANK_LOGN"
    APPROX_DET_ROUND_LP_SWEEP = "APPROX_DET_ROUND_LP_SWEEP"
    APPROX_DET_ROUND_LP_05_THRESH = "APPROX_DET_ROUND_LP_05_THRESH"
    APPROX_DET_RANDOM_THRESH_ROUND_LP = "APPROX_DET_RANDOM_THRESH_ROUND_LP"
    APPROX_RANDOMIZED_ROUND = "APPROX_RANDOMIZED_ROUND"
    LB_LP = "LB_LP"
    SIMRD = "SIMRD"
    SIMRD_MSPS = "SIMRD_MSPS"

    @classmethod
    def get_description(cls, val, model_name=None):
        """Return the display label for strategy ``val``.

        :param val: The ``SolveStrategy`` member to describe.
        :param model_name: Optional model name. ``"VGG16"``, ``"VGG19"`` and
            ``"MobileNet"`` are treated as linear models, which changes the
            labels of ``CHEN_SQRTN_NOAP`` and ``GRIEWANK_LOGN``.
        :return: The label string (some labels contain LaTeX math markup).
        :raises KeyError: If ``val`` has no label (e.g. ``NOT_SPECIFIED``).
        """
        is_linear = model_name in ("VGG16", "VGG19", "MobileNet")
        descriptions = {
            cls.CHEN_SQRTN: "AP $\\sqrt{n}$",
            cls.CHEN_GREEDY: "AP greedy",
            cls.CHEN_SQRTN_NOAP: "Generalized $\\sqrt{n}$" if not is_linear else "Chen et al. $\\sqrt{n}$",
            cls.CHEN_GREEDY_NOAP: "Generalized greedy",
            cls.OPTIMAL_ILP_GC: "Optimal MILP (proposed)",
            cls.CHECKPOINT_LAST_NODE: "Checkpoint last node",
            cls.CHECKPOINT_ALL: "Checkpoint all (ideal)",
            cls.CHECKPOINT_ALL_AP: "Checkpoint all APs",
            cls.GRIEWANK_LOGN: "Griewank et al. $\\log~n$" if is_linear else "AP $\\log~n$",
            cls.APPROX_DET_ROUND_LP_SWEEP: "Approximation via deterministic rounding of LP relaxation w/ threshold sweep",
            cls.APPROX_DET_RANDOM_THRESH_ROUND_LP: "Approximation via deterministic rounding of LP relaxation with random thresholds",
            cls.APPROX_DET_ROUND_LP_05_THRESH: "Approximation via deterministic rounding of LP relaxation w/ 0.5 threshold",
            cls.APPROX_RANDOMIZED_ROUND: "Approximation via randomized rounding of LP relaxation",
            cls.LB_LP: "Lower bound via LP relaxation",
            cls.SIMRD: "Dynamic Tensor Rematerialization",
            cls.SIMRD_MSPS: "Capuchin MSPS heuristic from DTR",
        }
        try:
            return descriptions[val]
        except KeyError:
            # Same exception type as a plain dict lookup, but with a message in
            # the style of get_plot_params.
            raise KeyError(f"No description for strategy {val}") from None

    # todo move this to experiments codebase
    @classmethod
    def get_plot_params(cls, val):
        """Return the plot styling tuple for strategy ``val``.

        :param val: The ``SolveStrategy`` member to style.
        :return: A 3-tuple whose last item is a size derived from matplotlib's
            ``rcParams["lines.markersize"]`` (half, full, or 1.5x); the first
            two items are presumably a matplotlib color code and marker code —
            TODO confirm against the plotting caller.
        :raises NotImplementedError: If ``val`` has no plotting parameters.
        """
        # Imported here so matplotlib is only needed when plotting.
        from matplotlib import rcParams

        fullsize = rcParams["lines.markersize"]
        halfsize = fullsize / 2
        bigger = fullsize * 1.5
        mapping = {
            cls.CHEN_SQRTN: ("c", "D", halfsize),
            cls.CHEN_SQRTN_NOAP: ("c", "^", halfsize),
            cls.CHEN_GREEDY: ("g", ".", fullsize),
            cls.CHEN_GREEDY_NOAP: ("g", "+", fullsize),
            cls.CHECKPOINT_ALL: ("k", "*", bigger),
            cls.CHECKPOINT_ALL_AP: ("b", "x", fullsize),
            cls.GRIEWANK_LOGN: ("m", "p", fullsize),
            cls.OPTIMAL_ILP_GC: ("r", "s", halfsize),
            cls.APPROX_DET_ROUND_LP_SWEEP: ("r", "*", fullsize),
            cls.APPROX_DET_ROUND_LP_05_THRESH: ("r", "^", halfsize),
            cls.APPROX_DET_RANDOM_THRESH_ROUND_LP: ("r", "x", fullsize),
            cls.APPROX_RANDOMIZED_ROUND: ("r", "+", fullsize),
            cls.LB_LP: ("r", "p", fullsize),
            cls.SIMRD: ("r", ".", fullsize),
            cls.SIMRD_MSPS: ("m", ".", fullsize),
        }
        if val in mapping:
            return mapping[val]
        raise NotImplementedError(f"No plotting parameters for strategy {val}")
class ImposedSchedule(Enum):
    """Named schedule-imposition modes.

    Every member's value is its own name as a string. What each mode means is
    decided by the code that consumes this enum, which is not part of this
    module.
    """

    COVER_LAST_NODE = "COVER_LAST_NODE"
    COVER_ALL_NODES = "COVER_ALL_NODES"
    FULL_SCHEDULE = "FULL_SCHEDULE"

    def __str__(self) -> str:
        """Return the bare member value instead of ``ImposedSchedule.NAME``."""
        return self.value