diff --git a/README.md b/README.md
index 53ec852..04fb5c9 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,6 @@
# POET
-By Shishir G. Patil, Paras Jain, Prabal Dutta, Ion Stoica, and Joseph E. Gonzalez ([Project Website](https://poet.cs.berkeley.edu/))
+
+By Shishir G. Patil, Paras Jain, Prabal Dutta, Ion Stoica, and Joseph E. Gonzalez ([Project Website](https://poet.cs.berkeley.edu/))
![](https://github.com/ShishirPatil/poet/blob/gh-pages/assets/img/logo.png)
@@ -12,7 +13,6 @@ ResNets on smartphones and tiny ARM Cortex-M devices :muscle:
Reach out to us at [sgp@berkeley.edu](mailto:sgp@berkeley.edu), if you have large models that you are trying to train - be it on GPUs, or your commodity edge devices such as laptops, smartphones, raspberry-pis, ARM Cortex M and A class, fitbits, etc.
-
## Get Started [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1iup_edJd9zB1tfVBHXLmkWOT5yoSmXzz?usp=sharing)
### Installation
@@ -35,13 +35,13 @@ If you are affiliated with an academic institution, you can acquire a free Gurob
1. Create a free Gurobi account [here](https://pages.gurobi.com/registration). Make sure to specify the Academic user option.
-
+
2. Complete the rest of the Gurobi account creation process, which will include creating a password and verifying your email address.
3. Login to the Gurobi [Web License Manager](https://license.gurobi.com/) using your new account.
4. Create and download a new Web License file. It will be called `gurobi.lic`.
-
+
5. Move the `gurobi.lic` file to your home directory (i.e. to `~/gurobi.lic` on MacOS/Linux, or `C:\Users\YOUR_USERNAME\gurobi.lic` on Windows).
@@ -71,10 +71,10 @@ solve(
)
```
-
## Key ideas
From our [paper at ICML 2022](https://arxiv.org/abs/2207.07697):
+
```text
In this work, we show that paging and rematerialization are highly complementary.
By carefully rematerializing cheap operations while paging results of expensive operations
diff --git a/poet-server/README.md b/poet-server/README.md
index dba92c2..37a5164 100644
--- a/poet-server/README.md
+++ b/poet-server/README.md
@@ -41,7 +41,6 @@ docker run -p 80:80 -v ~/gurobi.lic:/opt/gurobi/gurobi.lic public.ecr.aws/shishi
Or, you can build the docker container yourself following the steps below.
-
1. Ensure you have [Docker Compose](https://docs.docker.com/compose/install/) installed.
2. Clone this repository by running `git clone https://github.com/ShishirPatil/poet`.
3. If using Gurobi, move the `gurobi.lic` file you downloaded in the previous step to the `poet-server` directory of this repository (i.e. to `poet-server/gurobi.lic`).
@@ -53,7 +52,6 @@ Or, you can build the docker container yourself following the steps below.
Ensure that you have moved the `gurobi.lic` file (if you want to use the Gurobi optimizer) you downloaded earlier to the EC2 instance. Ensure that Port 80 is open for ingress traffic.
-
## Making Requests
To issue requests to the POET server, you can use the following Python code. Here, we use the demo POET-server hosted at IP `35.184.186.64`:
@@ -70,6 +68,3 @@ response = requests.get("http://35.184.186.64/solve", {
print(response.json())
```
-
-
-
diff --git a/poet/pareto.py b/poet/pareto.py
new file mode 100644
index 0000000..45304e2
--- /dev/null
+++ b/poet/pareto.py
@@ -0,0 +1,104 @@
+import os
+from concurrent.futures import ProcessPoolExecutor
+from typing import Literal
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+from poet import solve
+from poet.util import get_chipset_and_net, get_net_costs, make_dfgraph_costs
+
+
+def solve_wrapper(params):
+ return solve(**params)
+
+
+def pareto(
+ model: Literal[
+ "linear",
+ "vgg16",
+ "vgg16_cifar",
+ "resnet18",
+ "resnet50",
+ "resnet18_cifar",
+ "bert",
+ "transformer",
+ ],
+ platform: Literal["m0", "a72", "a72nocache", "m4", "jetsontx2"],
+ # ram_budget: float,
+ runtime_budget: float = 1.4,
+ mem_power_scale=1.0,
+ batch_size=1,
+ ram_budget_samples: int = 20,
+ solver: Literal["gurobipy", "pulp-gurobi", "pulp-cbc"] = "gurobipy",
+ time_limit_s: float = 1e100,
+ solve_threads: int = 4,
+ total_threads: int = os.cpu_count(),
+ filename: str = "pareto.png",
+):
+ plt.ion()
+
+ chipset, net, ram_budget_start, ram_budget_end = get_chipset_and_net(
+ platform=platform,
+ model=model,
+ batch_size=batch_size,
+ mem_power_scale=mem_power_scale,
+ )
+
+ base_memory = max(get_net_costs(net=net, device=chipset)["memory_bytes"])
+ print(
+ base_memory,
+ ram_budget_start,
+ ram_budget_end,
+ ram_budget_start / base_memory,
+ ram_budget_end / base_memory,
+ )
+
+ ram_budget_range = np.linspace(ram_budget_start, ram_budget_end, ram_budget_samples)
+
+ g, *_ = make_dfgraph_costs(net=net, device=chipset)
+ total_runtime = sum(g.cost_cpu.values())
+ total_ram = sum(g.cost_ram[i] for i in g.vfwd)
+ print(f"Total runtime of graph (forward + backward) = {total_runtime}")
+ print(f"Total RAM consumption of forward pass = {total_ram}")
+ print(f"### --- ### Total RAM consumption of forward pass = {total_ram}")
+ print(total_threads // solve_threads)
+
+ with ProcessPoolExecutor(max_workers=total_threads // solve_threads) as executor:
+ for result in executor.map(
+ solve_wrapper,
+ [
+ dict(
+ model=model,
+ platform=platform,
+ ram_budget=ram_budget,
+ runtime_budget=runtime_budget,
+ mem_power_scale=mem_power_scale,
+ batch_size=batch_size,
+ solver=solver,
+ time_limit_s=time_limit_s,
+ solve_threads=solve_threads,
+ )
+ for ram_budget in ram_budget_range
+ ],
+ ):
+ print(
+ result.total_power_cost_cpu,
+ result.total_power_cost_page,
+ result.ram_budget,
+ )
+ plt.plot(
+ result.ram_budget,
+ -1 if result.total_power_cost_cpu is None else result.total_power_cost_cpu + result.total_power_cost_page,
+ "r.",
+ )
+ plt.draw()
+ plt.pause(0.1)
+
+ print("Done!")
+ plt.savefig(filename)
+ plt.show(block=True)
+
+
+if __name__ == "__main__":
+ pareto(model="resnet18_cifar", platform="m4", runtime_budget=1.2, time_limit_s=600)
diff --git a/poet/solve.py b/poet/solve.py
index 7a74c7a..0b02fea 100644
--- a/poet/solve.py
+++ b/poet/solve.py
@@ -2,12 +2,18 @@
from typing import Literal, Optional
import numpy as np
+from gurobipy import GRB, GurobiError
from poet import solve
from poet.poet_solver import POETSolver
from poet.poet_solver_gurobi import POETSolverGurobi
-from poet.util import get_chipset_and_net, make_dfgraph_costs, plot_dfgraph, print_result, POETResult
-from gurobipy import GRB, GurobiError
+from poet.util import (
+ POETResult,
+ get_chipset_and_net,
+ make_dfgraph_costs,
+ plot_dfgraph,
+ print_result,
+)
def solve(
@@ -48,7 +54,7 @@ def solve(
:param time_limit_s: The time limit for solving in seconds.
:param solve_threads: The number of threads to use for solving.
"""
- chipset, net = get_chipset_and_net(
+ chipset, net, *_ = get_chipset_and_net(
platform=platform,
model=model,
batch_size=batch_size,
@@ -142,9 +148,23 @@ def solve(
"--model",
type=str,
required=True,
- choices=["vgg16", "vgg16_cifar", "resnet18", "resnet50", "resnet18_cifar", "bert", "transformer", "linear"],
+ choices=[
+ "vgg16",
+ "vgg16_cifar",
+ "resnet18",
+ "resnet50",
+ "resnet18_cifar",
+ "bert",
+ "transformer",
+ "linear",
+ ],
+ )
+ parser.add_argument(
+ "--platform",
+ type=str,
+ required=True,
+ choices=["m0", "a72", "a72nocache", "m4", "jetsontx2"],
)
- parser.add_argument("--platform", type=str, required=True, choices=["m0", "a72", "a72nocache", "m4", "jetsontx2"])
parser.add_argument("--ram-budget", type=int, required=True)
parser.add_argument("--runtime-budget", type=float, required=True)
parser.add_argument("--batch-size", type=int, default=1)
@@ -153,7 +173,12 @@ def solve(
parser.add_argument("--remat", action="store_true", default=True)
parser.add_argument("--time-limit-s", type=int, default=1e100)
parser.add_argument("--solve-threads", type=int, default=4)
- parser.add_argument("--solver", type=str, default="gurobipy", choices=["gurobipy", "pulp-gurobi", "pulp-cbc"])
+ parser.add_argument(
+ "--solver",
+ type=str,
+ default="gurobipy",
+ choices=["gurobipy", "pulp-gurobi", "pulp-cbc"],
+ )
parser.add_argument("--use-actual-gurobi", action="store_true", default=False)
parser.add_argument("--print-power-costs", action="store_true", default=False)
parser.add_argument("--print-graph-info", action="store_true", default=True)
diff --git a/poet/util.py b/poet/util.py
index acaabc4..a83d6e4 100644
--- a/poet/util.py
+++ b/poet/util.py
@@ -1,5 +1,5 @@
-from dataclasses import dataclass
import pickle
+from dataclasses import dataclass
from pathlib import Path
from typing import List
@@ -52,7 +52,12 @@ def make_dfgraph_costs(net, device):
for idx, (layer, specs) in enumerate(zip(net, per_layer_specs)):
layer_name = "layer{}_{}".format(idx, layer.__class__.__name__)
layer_names[layer] = layer_name
- gb.add_node(layer_name, cpu_cost=specs["runtime_ms"], ram_cost=specs["memory_bytes"], backward=isinstance(layer, GradientLayer))
+ gb.add_node(
+ layer_name,
+ cpu_cost=specs["runtime_ms"],
+ ram_cost=specs["memory_bytes"],
+ backward=isinstance(layer, GradientLayer),
+ )
gb.set_parameter_cost(gb.parameter_cost + specs["param_memory_bytes"])
page_in_cost_dict[layer_name] = specs["pagein_cost_joules"]
page_out_cost_dict[layer_name] = specs["pageout_cost_joules"]
@@ -91,36 +96,61 @@ def get_chipset_and_net(platform: str, model: str, batch_size: int, mem_power_sc
elif platform == "jetsontx2":
chipset = JetsonTX2
else:
- raise NotImplementedError()
+ raise NotImplementedError(f"Platform {platform} not implemented.")
chipset["MEMORY_POWER"] *= mem_power_scale
if model == "linear":
net = make_linear_network()
+ # TODO: these were randomly picked
+ ram_budget_start = 1.0e07
+ ram_budget_end = 1.0e08
elif model == "vgg16":
net = vgg16(batch_size)
+ ram_budget_start = 2.57e07
+ ram_budget_end = 1.15e08
elif model == "vgg16_cifar":
net = vgg16(batch_size, 10, (3, 32, 32))
+ ram_budget_start = 2.57e07 / 49
+ ram_budget_end = 1.15e08 / 49
elif model == "resnet18":
net = resnet18(batch_size)
+ ram_budget_start = 6.42e06
+ ram_budget_end = 2.85e07
elif model == "resnet50":
net = resnet50(batch_size)
+ ram_budget_start = 6.97e06
+ ram_budget_end = 1.27e08
elif model == "resnet18_cifar":
net = resnet18_cifar(batch_size, 10, (3, 32, 32))
+ ram_budget_start = 196608
+ ram_budget_end = 2339408
elif model == "bert":
net = BERTBase(SEQ_LEN=512, HIDDEN_DIM=768, I=64, HEADS=12, NUM_TRANSFORMER_BLOCKS=12)
+ # TODO: this is very broken
+ ram_budget_start = 1e6
+ ram_budget_end = 1e9
elif model == "transformer":
net = BERTBase(SEQ_LEN=512, HIDDEN_DIM=768, I=64, HEADS=12, NUM_TRANSFORMER_BLOCKS=1)
+ ram_budget_start = 1e5
+ ram_budget_end = 7e7
else:
- raise NotImplementedError()
+ raise NotImplementedError(f"Model {model} not implemented.")
- return chipset, net
+ return chipset, net, ram_budget_start, ram_budget_end
def plot_network(
- platform: str, model: str, directory: str, batch_size: int = 1, mem_power_scale: float = 1.0, format="pdf", quiet=True, name=""
+ platform: str,
+ model: str,
+ directory: str,
+ batch_size: int = 1,
+ mem_power_scale: float = 1.0,
+ format="pdf",
+ quiet=True,
+ name="",
):
- chipset, net = get_chipset_and_net(platform, model, batch_size, mem_power_scale)
+ chipset, net, *_ = get_chipset_and_net(platform, model, batch_size, mem_power_scale)
g, *_ = make_dfgraph_costs(net, chipset)
plot_dfgraph(g, directory, format, quiet, name)
diff --git a/poet/utils/checkmate/core/enum_strategy.py b/poet/utils/checkmate/core/enum_strategy.py
new file mode 100644
index 0000000..1b80206
--- /dev/null
+++ b/poet/utils/checkmate/core/enum_strategy.py
@@ -0,0 +1,81 @@
+from enum import Enum
+
+
+class SolveStrategy(Enum):
+ NOT_SPECIFIED = "NOT_SPECIFIED"
+ CHEN_SQRTN = "CHEN_SQRTN"
+ CHEN_GREEDY = "CHEN_GREEDY"
+ CHEN_SQRTN_NOAP = "CHEN_SQRTN_NOAP"
+ CHEN_GREEDY_NOAP = "CHEN_GREEDY_NOAP"
+ OPTIMAL_ILP_GC = "OPTIMAL_ILP_GC"
+ CHECKPOINT_LAST_NODE = "CHECKPOINT_LAST_NODE"
+ CHECKPOINT_ALL = "CHECKPOINT_ALL"
+ CHECKPOINT_ALL_AP = "CHECKPOINT_ALL_AP"
+ GRIEWANK_LOGN = "GRIEWANK_LOGN"
+ APPROX_DET_ROUND_LP_SWEEP = "APPROX_DET_ROUND_LP_SWEEP"
+ APPROX_DET_ROUND_LP_05_THRESH = "APPROX_DET_ROUND_LP_05_THRESH"
+ APPROX_DET_RANDOM_THRESH_ROUND_LP = "APPROX_DET_RANDOM_THRESH_ROUND_LP"
+ APPROX_RANDOMIZED_ROUND = "APPROX_RANDOMIZED_ROUND"
+ LB_LP = "LB_LP"
+ SIMRD = "SIMRD"
+ SIMRD_MSPS = "SIMRD_MSPS"
+
+ @classmethod
+ def get_description(cls, val, model_name=None):
+ is_linear = model_name in ("VGG16", "VGG19", "MobileNet")
+ return {
+ cls.CHEN_SQRTN: "AP $\\sqrt{n}$",
+ cls.CHEN_GREEDY: "AP greedy",
+ cls.CHEN_SQRTN_NOAP: "Generalized $\\sqrt{n}$" if not is_linear else "Chen et al. $\\sqrt{n}$",
+ cls.CHEN_GREEDY_NOAP: "Generalized greedy",
+ cls.OPTIMAL_ILP_GC: "Optimal MILP (proposed)",
+ cls.CHECKPOINT_LAST_NODE: "Checkpoint last node",
+ cls.CHECKPOINT_ALL: "Checkpoint all (ideal)",
+ cls.CHECKPOINT_ALL_AP: "Checkpoint all APs",
+ cls.GRIEWANK_LOGN: "Griewank et al. $\\log~n$" if is_linear else "AP $\\log~n$",
+ cls.APPROX_DET_ROUND_LP_SWEEP: "Approximation via deterministic rounding of LP relaxation w/ threshold sweep",
+ cls.APPROX_DET_RANDOM_THRESH_ROUND_LP: "Approximation via deterministic rounding of LP relaxation with random thresholds",
+ cls.APPROX_DET_ROUND_LP_05_THRESH: "Approximation via deterministic rounding of LP relaxation w/ 0.5 threshold",
+ cls.APPROX_RANDOMIZED_ROUND: "Approximation via randomized rounding of LP relaxation",
+ cls.LB_LP: "Lower bound via LP relaxation",
+ cls.SIMRD: "Dynamic Tensor Rematerialization",
+ cls.SIMRD_MSPS: "Capuchin MSPS heuristic from DTR",
+ }[val]
+
+ # todo move this to experiments codebase
+ @classmethod
+ def get_plot_params(cls, val):
+ from matplotlib import rcParams
+
+ fullsize = rcParams["lines.markersize"]
+ halfsize = fullsize / 2
+ bigger = fullsize * 1.5
+ mapping = {
+ cls.CHEN_SQRTN: ("c", "D", halfsize),
+ cls.CHEN_SQRTN_NOAP: ("c", "^", halfsize),
+ cls.CHEN_GREEDY: ("g", ".", fullsize),
+ cls.CHEN_GREEDY_NOAP: ("g", "+", fullsize),
+ cls.CHECKPOINT_ALL: ("k", "*", bigger),
+ cls.CHECKPOINT_ALL_AP: ("b", "x", fullsize),
+ cls.GRIEWANK_LOGN: ("m", "p", fullsize),
+ cls.OPTIMAL_ILP_GC: ("r", "s", halfsize),
+ cls.APPROX_DET_ROUND_LP_SWEEP: ("r", "*", fullsize),
+ cls.APPROX_DET_ROUND_LP_05_THRESH: ("r", "^", halfsize),
+ cls.APPROX_DET_RANDOM_THRESH_ROUND_LP: ("r", "x", fullsize),
+ cls.APPROX_RANDOMIZED_ROUND: ("r", "+", fullsize),
+ cls.LB_LP: ("r", "p", fullsize),
+ cls.SIMRD: ("r", ".", fullsize),
+ cls.SIMRD_MSPS: ("m", ".", fullsize),
+ }
+ if val in mapping:
+ return mapping[val]
+ raise NotImplementedError("No plotting parameters for strategy {}".format(val))
+
+
+class ImposedSchedule(Enum):
+ COVER_LAST_NODE = "COVER_LAST_NODE"
+ COVER_ALL_NODES = "COVER_ALL_NODES"
+ FULL_SCHEDULE = "FULL_SCHEDULE"
+
+ def __str__(self):
+ return self.value