Skip to content

Commit 24f7575

Browse files
authored
Plan stats (#828)
* Plan stats * Fix tensorstore tests
1 parent 701a22f commit 24f7575

File tree

11 files changed

+354
-88
lines changed

11 files changed

+354
-88
lines changed

cubed/array_api/array_object.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def _repr_html_(self):
6565
grid=grid,
6666
nbytes=nbytes,
6767
cbytes=cbytes,
68-
arrs_in_plan=f"{self._plan._finalize().num_arrays()} arrays in Plan",
68+
arrs_in_plan=f"{self.plan().num_arrays} arrays in Plan",
6969
arrtype="np.ndarray",
7070
)
7171

cubed/core/array.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,11 @@ def nbytes(self) -> int:
103103
"""Number of bytes in array"""
104104
return self.size * itemsize(self.dtype)
105105

106+
@property
107+
def nchunks(self) -> int:
108+
"""Number of chunks in array"""
109+
return reduce(mul, self.numblocks, 1)
110+
106111
@property
107112
def itemsize(self) -> int:
108113
"""Length of one array element in bytes"""

cubed/core/ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def from_array(x, chunks="auto", asarray=None, spec=None) -> "Array":
5858
chunks, x.shape, dtype=x.dtype, previous_chunks=previous_chunks
5959
)
6060

61-
if isinstance(x, zarr.Array): # zarr fast path
61+
if is_storage_array(x): # zarr fast path
6262
from cubed.array_api import Array
6363

6464
name = gensym()

cubed/core/optimization.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,15 @@ def is_primitive_op(node_dict):
134134
return "primitive_op" in node_dict
135135

136136

137+
def is_input_array(dag, name):
138+
"""Return True if a node is an array that is an input to the graph (i.e. has no predecessor arrays)"""
139+
nodes = dict(dag.nodes(data=True))
140+
pre_list = list(predecessors_unordered(dag, name))
141+
assert len(pre_list) == 1 # each array is produced by a single op
142+
pre = pre_list[0]
143+
return not is_primitive_op(nodes[pre])
144+
145+
137146
def is_fusable_with_predecessors(node_dict):
138147
"""Return True if a node is a primitive op and can be fused with its predecessors."""
139148
return (

cubed/core/plan.py

Lines changed: 228 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,16 @@
55
import tempfile
66
import uuid
77
from datetime import datetime
8+
from enum import Enum
89
from functools import lru_cache
910
from typing import Any, Callable, Dict, Optional
1011

1112
import networkx as nx
1213

13-
from cubed.core.optimization import multiple_inputs_optimize_dag
14+
from cubed.core.optimization import is_input_array, multiple_inputs_optimize_dag
1415
from cubed.primitive.blockwise import BlockwiseSpec
1516
from cubed.primitive.types import PrimitiveOperation
16-
from cubed.runtime.pipeline import visit_nodes
17+
from cubed.runtime.pipeline import visit_node_generations
1718
from cubed.runtime.types import ComputeEndEvent, ComputeStartEvent, CubedPipeline
1819
from cubed.storage.store import is_storage_array
1920
from cubed.storage.zarr import LazyZarrArray, open_if_lazy_zarr_array
@@ -283,6 +284,12 @@ def _finalize(
283284
return FinalizedPlan(nx.freeze(dag), self.array_names, optimize_graph)
284285

285286

287+
class ArrayRole(Enum):
288+
INPUT = "input"
289+
INTERMEDIATE = "intermediate"
290+
OUTPUT = "output"
291+
292+
286293
class FinalizedPlan:
287294
"""A plan that is ready to be run.
288295
@@ -297,6 +304,217 @@ def __init__(self, dag, array_names, optimized):
297304
self.dag = dag
298305
self.array_names = array_names
299306
self.optimized = optimized
307+
self._calculate_stats()
308+
309+
self.input_array_names = []
310+
self.intermediate_array_names = []
311+
for name, node in self.dag.nodes(data=True):
312+
if node.get("type", None) == "array":
313+
if is_input_array(self.dag, name):
314+
self.input_array_names.append(name)
315+
elif name not in self.array_names:
316+
self.intermediate_array_names.append(name)
317+
318+
def _calculate_stats(self) -> None:
319+
self._num_stages = len(list(visit_node_generations(self.dag)))
320+
321+
self._allowed_mem = 0
322+
self._max_projected_mem = 0
323+
self._num_arrays = 0
324+
self._num_primitive_ops = 0
325+
self._num_tasks = 0
326+
self._total_narrays_read = 0
327+
self._total_nbytes_read = 0
328+
self._total_nchunks_read = 0
329+
self._total_narrays_written = 0
330+
self._total_nbytes_written = 0
331+
self._total_nchunks_written = 0
332+
self._total_input_narrays = 0
333+
self._total_input_nbytes = 0
334+
self._total_input_nchunks = 0
335+
self._total_intermediate_narrays = 0
336+
self._total_intermediate_nbytes = 0
337+
self._total_intermediate_nchunks = 0
338+
self._total_output_narrays = 0
339+
self._total_output_nbytes = 0
340+
self._total_output_nchunks = 0
341+
self._total_narrays = 0
342+
self._total_nbytes = 0
343+
self._total_nchunks = 0
344+
345+
for name, node in self.dag.nodes(data=True):
346+
node_type = node.get("type", None)
347+
if node_type == "op":
348+
primitive_op = node.get("primitive_op", None)
349+
if primitive_op is not None:
350+
# allowed mem is the same for all ops
351+
self._allowed_mem = primitive_op.allowed_mem
352+
self._max_projected_mem = max(
353+
primitive_op.projected_mem, self._max_projected_mem
354+
)
355+
self._num_primitive_ops += 1
356+
self._num_tasks += primitive_op.num_tasks
357+
elif node_type == "array":
358+
array = node["target"]
359+
self._num_arrays += 1
360+
if not (isinstance(array, LazyZarrArray) or is_storage_array(array)):
361+
continue # not materialized
362+
self._total_narrays += 1
363+
self._total_nbytes += array.nbytes
364+
self._total_nchunks += array.nchunks
365+
if is_input_array(self.dag, name):
366+
self._total_narrays_read += 1
367+
self._total_nbytes_read += array.nbytes
368+
self._total_nchunks_read += array.nchunks
369+
self._total_input_narrays += 1
370+
self._total_input_nbytes += array.nbytes
371+
self._total_input_nchunks += array.nchunks
372+
elif isinstance(array, LazyZarrArray) or is_storage_array(array):
373+
self._total_narrays_written += 1
374+
self._total_nbytes_written += array.nbytes
375+
self._total_nchunks_written += array.nchunks
376+
if name in self.array_names:
377+
self._total_output_narrays += 1
378+
self._total_output_nbytes += array.nbytes
379+
self._total_output_nchunks += array.nchunks
380+
else:
381+
self._total_narrays_read += 1
382+
self._total_nbytes_read += array.nbytes
383+
self._total_nchunks_read += array.nchunks
384+
self._total_intermediate_narrays += 1
385+
self._total_intermediate_nbytes += array.nbytes
386+
self._total_intermediate_nchunks += array.nchunks
387+
388+
def array_role(self, name) -> ArrayRole:
389+
"""The role of the array in the computation: an input, intermediate or output."""
390+
if name in self.input_array_names:
391+
return ArrayRole.INPUT
392+
elif name in self.intermediate_array_names:
393+
return ArrayRole.INTERMEDIATE
394+
elif name in self.array_names:
395+
return ArrayRole.OUTPUT
396+
else:
397+
raise ValueError(f"Plan does not contain an array with name {name}")
398+
399+
@property
400+
def allowed_mem(self) -> int:
401+
"""The total memory available to a worker for running a task, in bytes."""
402+
return self._allowed_mem
403+
404+
@property
405+
def max_projected_mem(self) -> int:
406+
"""The maximum projected memory across all tasks to execute this plan."""
407+
return self._max_projected_mem
408+
409+
@property
410+
def num_arrays(self) -> int:
411+
"""The number of arrays in this plan."""
412+
return self._num_arrays
413+
414+
@property
415+
def num_stages(self) -> int:
416+
"""The number of stages in this plan (including the initial stage to create arrays)."""
417+
return self._num_stages
418+
419+
@property
420+
def num_primitive_ops(self) -> int:
421+
"""The number of primitive operations in this plan."""
422+
return self._num_primitive_ops
423+
424+
@property
425+
def num_tasks(self) -> int:
426+
"""The number of tasks needed to execute this plan."""
427+
return self._num_tasks
428+
429+
@property
430+
def total_narrays_read(self) -> int:
431+
"""The total number of arrays read from for all materialized arrays in this plan."""
432+
return self._total_narrays_read
433+
434+
@property
435+
def total_nbytes_read(self) -> int:
436+
"""The total number of bytes read for all materialized arrays in this plan."""
437+
return self._total_nbytes_read
438+
439+
@property
440+
def total_nchunks_read(self) -> int:
441+
"""The total number of chunks read for all materialized arrays in this plan."""
442+
return self._total_nchunks_read
443+
444+
@property
445+
def total_narrays_written(self) -> int:
446+
"""The total number of arrays written to for all materialized arrays in this plan."""
447+
return self._total_narrays_written
448+
449+
@property
450+
def total_nbytes_written(self) -> int:
451+
"""The total number of bytes written for all materialized arrays in this plan."""
452+
return self._total_nbytes_written
453+
454+
@property
455+
def total_nchunks_written(self) -> int:
456+
"""The total number of chunks written for all materialized arrays in this plan."""
457+
return self._total_nchunks_written
458+
459+
@property
460+
def total_input_narrays(self) -> int:
461+
"""The total number of materialized input arrays in this plan."""
462+
return self._total_input_narrays
463+
464+
@property
465+
def total_input_nbytes(self) -> int:
466+
"""The total number of bytes for all materialized input arrays in this plan."""
467+
return self._total_input_nbytes
468+
469+
@property
470+
def total_input_nchunks(self) -> int:
471+
"""The total number of chunks for all materialized input arrays in this plan."""
472+
return self._total_input_nchunks
473+
474+
@property
475+
def total_intermediate_narrays(self) -> int:
476+
"""The total number of intermediate arrays in this plan."""
477+
return self._total_intermediate_narrays
478+
479+
@property
480+
def total_intermediate_nbytes(self) -> int:
481+
"""The total number of bytes for all intermediate arrays in this plan."""
482+
return self._total_intermediate_nbytes
483+
484+
@property
485+
def total_intermediate_nchunks(self) -> int:
486+
"""The total number of chunks for all intermediate arrays in this plan."""
487+
return self._total_intermediate_nchunks
488+
489+
@property
490+
def total_output_narrays(self) -> int:
491+
"""The total number of output arrays in this plan."""
492+
return self._total_output_narrays
493+
494+
@property
495+
def total_output_nbytes(self) -> int:
496+
"""The total number of bytes for all output arrays in this plan."""
497+
return self._total_output_nbytes
498+
499+
@property
500+
def total_output_nchunks(self) -> int:
501+
"""The total number of chunks for all output arrays in this plan."""
502+
return self._total_output_nchunks
503+
504+
@property
505+
def total_narrays(self) -> int:
506+
"""The total number of materialized arrays in this plan."""
507+
return self._total_narrays
508+
509+
@property
510+
def total_nbytes(self) -> int:
511+
"""The total number of bytes for all materialized arrays in this plan."""
512+
return self._total_nbytes
513+
514+
@property
515+
def total_nchunks(self) -> int:
516+
"""The total number of chunks for all materialized arrays in this plan."""
517+
return self._total_nchunks
300518

301519
def execute(
302520
self,
@@ -318,7 +536,7 @@ def execute(
318536
compute_id = f"compute-{datetime.now().strftime('%Y%m%dT%H%M%S.%f')}"
319537

320538
if callbacks is not None:
321-
event = ComputeStartEvent(compute_id, dag)
539+
event = ComputeStartEventWithPlan(compute_id, dag, self)
322540
[callback.on_compute_start(event) for callback in callbacks]
323541
executor.execute_dag(
324542
dag,
@@ -352,9 +570,9 @@ def visualize(
352570
"rankdir": rankdir,
353571
"label": (
354572
# note that \l is used to left-justify each line (see https://www.graphviz.org/docs/attrs/nojustify/)
355-
rf"num tasks: {self.num_tasks()}\l"
356-
rf"max projected memory: {memory_repr(self.max_projected_mem())}\l"
357-
rf"total nbytes written: {memory_repr(self.total_nbytes_written())}\l"
573+
rf"num tasks: {self.num_tasks}\l"
574+
rf"max projected memory: {memory_repr(self.max_projected_mem)}\l"
575+
rf"total nbytes written: {memory_repr(self.total_nbytes_written)}\l"
358576
rf"optimized: {self.optimized}\l"
359577
),
360578
"labelloc": "bottom",
@@ -495,37 +713,11 @@ def visualize(
495713
pass
496714
return None
497715

498-
def max_projected_mem(self) -> int:
499-
"""Return the maximum projected memory across all tasks to execute this plan."""
500-
projected_mem_values = [
501-
node["primitive_op"].projected_mem for _, node in visit_nodes(self.dag)
502-
]
503-
return max(projected_mem_values) if len(projected_mem_values) > 0 else 0
504-
505-
def num_arrays(self) -> int:
506-
"""Return the number of arrays in this plan."""
507-
return sum(d.get("type") == "array" for _, d in self.dag.nodes(data=True))
508716

509-
def num_primitive_ops(self) -> int:
510-
"""Return the number of primitive operations in this plan."""
511-
return len(list(visit_nodes(self.dag)))
512-
513-
def num_tasks(self) -> int:
514-
"""Return the number of tasks needed to execute this plan."""
515-
tasks = 0
516-
for _, node in visit_nodes(self.dag):
517-
tasks += node["primitive_op"].num_tasks
518-
return tasks
519-
520-
def total_nbytes_written(self) -> int:
521-
"""Return the total number of bytes written for all materialized arrays in this plan."""
522-
nbytes = 0
523-
for _, d in self.dag.nodes(data=True):
524-
if d.get("type") == "array":
525-
target = d["target"]
526-
if isinstance(target, LazyZarrArray):
527-
nbytes += target.nbytes
528-
return nbytes
717+
@dataclasses.dataclass
718+
class ComputeStartEventWithPlan(ComputeStartEvent):
719+
plan: FinalizedPlan
720+
"""The plan."""
529721

530722

531723
def arrays_to_dag(*arrays):

0 commit comments

Comments
 (0)