55import tempfile
66import uuid
77from datetime import datetime
8+ from enum import Enum
89from functools import lru_cache
910from typing import Any , Callable , Dict , Optional
1011
1112import networkx as nx
1213
13- from cubed .core .optimization import multiple_inputs_optimize_dag
14+ from cubed .core .optimization import is_input_array , multiple_inputs_optimize_dag
1415from cubed .primitive .blockwise import BlockwiseSpec
1516from cubed .primitive .types import PrimitiveOperation
16- from cubed .runtime .pipeline import visit_nodes
17+ from cubed .runtime .pipeline import visit_node_generations
1718from cubed .runtime .types import ComputeEndEvent , ComputeStartEvent , CubedPipeline
1819from cubed .storage .store import is_storage_array
1920from cubed .storage .zarr import LazyZarrArray , open_if_lazy_zarr_array
@@ -283,6 +284,12 @@ def _finalize(
283284 return FinalizedPlan (nx .freeze (dag ), self .array_names , optimize_graph )
284285
285286
class ArrayRole(Enum):
    """The part an array plays in a computation."""

    # Array that the computation reads as an input.
    INPUT = "input"
    # Array that is neither an input nor an output of the computation.
    INTERMEDIATE = "intermediate"
    # Array that is an output of the computation.
    OUTPUT = "output"
291+
292+
286293class FinalizedPlan :
287294 """A plan that is ready to be run.
288295
@@ -297,6 +304,217 @@ def __init__(self, dag, array_names, optimized):
297304 self .dag = dag
298305 self .array_names = array_names
299306 self .optimized = optimized
307+ self ._calculate_stats ()
308+
309+ self .input_array_names = []
310+ self .intermediate_array_names = []
311+ for name , node in self .dag .nodes (data = True ):
312+ if node .get ("type" , None ) == "array" :
313+ if is_input_array (self .dag , name ):
314+ self .input_array_names .append (name )
315+ elif name not in self .array_names :
316+ self .intermediate_array_names .append (name )
317+
def _calculate_stats(self) -> None:
    """Walk the DAG once and cache the statistics exposed as properties.

    Requires ``self.dag`` and ``self.array_names`` to be set. Every private
    counter is reset to zero and then filled in from the DAG nodes:

    * ``"op"`` nodes that carry a ``primitive_op`` contribute to the
      operation, task and memory figures;
    * ``"array"`` nodes are always counted in ``_num_arrays``, but only
      materialized arrays (``LazyZarrArray`` instances or storage arrays)
      contribute to the ``_total_*`` counters.

    A materialized array is classified as

    * *input* if ``is_input_array`` says so -- counted as read;
    * *output* if it is not an input and its name is in
      ``self.array_names`` -- counted as written;
    * *intermediate* otherwise -- counted as both written and read.
    """
    # One stage per node generation of the DAG.
    self._num_stages = len(list(visit_node_generations(self.dag)))

    # Operation-level figures.
    self._allowed_mem = 0
    self._max_projected_mem = 0
    self._num_arrays = 0
    self._num_primitive_ops = 0
    self._num_tasks = 0

    # Read/write figures.
    self._total_narrays_read = 0
    self._total_nbytes_read = 0
    self._total_nchunks_read = 0
    self._total_narrays_written = 0
    self._total_nbytes_written = 0
    self._total_nchunks_written = 0

    # Per-role figures.
    self._total_input_narrays = 0
    self._total_input_nbytes = 0
    self._total_input_nchunks = 0
    self._total_intermediate_narrays = 0
    self._total_intermediate_nbytes = 0
    self._total_intermediate_nchunks = 0
    self._total_output_narrays = 0
    self._total_output_nbytes = 0
    self._total_output_nchunks = 0

    # Figures over all materialized arrays.
    self._total_narrays = 0
    self._total_nbytes = 0
    self._total_nchunks = 0

    for name, node in self.dag.nodes(data=True):
        node_type = node.get("type", None)

        if node_type == "op":
            primitive_op = node.get("primitive_op", None)
            if primitive_op is None:
                continue
            # allowed mem is the same for all ops
            self._allowed_mem = primitive_op.allowed_mem
            self._max_projected_mem = max(
                primitive_op.projected_mem, self._max_projected_mem
            )
            self._num_primitive_ops += 1
            self._num_tasks += primitive_op.num_tasks
            continue

        if node_type != "array":
            continue

        array = node["target"]
        self._num_arrays += 1
        if not (isinstance(array, LazyZarrArray) or is_storage_array(array)):
            continue  # not materialized

        # From here on the array is known to be materialized, so no branch
        # below needs to test for that again.
        nbytes = array.nbytes
        nchunks = array.nchunks

        self._total_narrays += 1
        self._total_nbytes += nbytes
        self._total_nchunks += nchunks

        if is_input_array(self.dag, name):
            # Inputs are read but never written by the plan.
            self._total_narrays_read += 1
            self._total_nbytes_read += nbytes
            self._total_nchunks_read += nchunks
            self._total_input_narrays += 1
            self._total_input_nbytes += nbytes
            self._total_input_nchunks += nchunks
            continue

        # Every materialized non-input array is written by the plan.
        self._total_narrays_written += 1
        self._total_nbytes_written += nbytes
        self._total_nchunks_written += nchunks

        if name in self.array_names:
            self._total_output_narrays += 1
            self._total_output_nbytes += nbytes
            self._total_output_nchunks += nchunks
        else:
            # Intermediate arrays are also counted as read.
            self._total_narrays_read += 1
            self._total_nbytes_read += nbytes
            self._total_nchunks_read += nchunks
            self._total_intermediate_narrays += 1
            self._total_intermediate_nbytes += nbytes
            self._total_intermediate_nchunks += nchunks
387+
def array_role(self, name) -> ArrayRole:
    """Return whether the named array is an input, intermediate or output.

    The name is looked up in ``input_array_names`` first, then in
    ``intermediate_array_names``, and finally in ``array_names``; the first
    match decides the role.

    Raises:
        ValueError: if ``name`` is in none of those collections.
    """
    if name in self.input_array_names:
        return ArrayRole.INPUT
    if name in self.intermediate_array_names:
        return ArrayRole.INTERMEDIATE
    if name in self.array_names:
        return ArrayRole.OUTPUT
    raise ValueError(f"Plan does not contain an array with name {name}")
398+
@property
def allowed_mem(self) -> int:
    """Memory available to a worker for running a task, in bytes.

    Taken from the plan's primitive operations; 0 if the plan has none.
    """
    return self._allowed_mem
403+
@property
def max_projected_mem(self) -> int:
    """Maximum projected memory across all tasks needed to execute this plan.

    This is the largest ``projected_mem`` of any primitive operation, or 0
    if the plan has none.
    """
    return self._max_projected_mem
408+
@property
def num_arrays(self) -> int:
    """Number of array nodes in this plan, whether materialized or not."""
    return self._num_arrays
413+
@property
def num_stages(self) -> int:
    """Number of stages in this plan.

    The count includes the initial stage that creates arrays.
    """
    return self._num_stages
418+
@property
def num_primitive_ops(self) -> int:
    """Number of primitive operations in this plan."""
    return self._num_primitive_ops
423+
@property
def num_tasks(self) -> int:
    """Number of tasks needed to execute this plan.

    This is the sum of ``num_tasks`` over the plan's primitive operations.
    """
    return self._num_tasks
428+
@property
def total_narrays_read(self) -> int:
    """Number of materialized arrays that are read in this plan.

    Input arrays and intermediate arrays are counted.
    """
    return self._total_narrays_read
433+
@property
def total_nbytes_read(self) -> int:
    """Number of bytes in the materialized arrays that are read in this plan.

    Input arrays and intermediate arrays are counted.
    """
    return self._total_nbytes_read
438+
@property
def total_nchunks_read(self) -> int:
    """Number of chunks in the materialized arrays that are read in this plan.

    Input arrays and intermediate arrays are counted.
    """
    return self._total_nchunks_read
443+
@property
def total_narrays_written(self) -> int:
    """Number of materialized arrays that are written in this plan.

    Every materialized array that is not an input is counted, i.e. both
    intermediate and output arrays.
    """
    return self._total_narrays_written
448+
@property
def total_nbytes_written(self) -> int:
    """Number of bytes in the materialized arrays written in this plan.

    Every materialized array that is not an input is counted, i.e. both
    intermediate and output arrays.
    """
    return self._total_nbytes_written
453+
@property
def total_nchunks_written(self) -> int:
    """Number of chunks in the materialized arrays written in this plan.

    Every materialized array that is not an input is counted, i.e. both
    intermediate and output arrays.
    """
    return self._total_nchunks_written
458+
@property
def total_input_narrays(self) -> int:
    """Number of materialized input arrays in this plan."""
    return self._total_input_narrays
463+
@property
def total_input_nbytes(self) -> int:
    """Number of bytes across all materialized input arrays in this plan."""
    return self._total_input_nbytes
468+
@property
def total_input_nchunks(self) -> int:
    """Number of chunks across all materialized input arrays in this plan."""
    return self._total_input_nchunks
473+
@property
def total_intermediate_narrays(self) -> int:
    """Number of materialized intermediate arrays in this plan.

    An intermediate array is neither an input nor one of the plan's named
    (output) arrays.
    """
    return self._total_intermediate_narrays
478+
@property
def total_intermediate_nbytes(self) -> int:
    """Number of bytes across all materialized intermediate arrays.

    An intermediate array is neither an input nor one of the plan's named
    (output) arrays.
    """
    return self._total_intermediate_nbytes
483+
@property
def total_intermediate_nchunks(self) -> int:
    """Number of chunks across all materialized intermediate arrays.

    An intermediate array is neither an input nor one of the plan's named
    (output) arrays.
    """
    return self._total_intermediate_nchunks
488+
@property
def total_output_narrays(self) -> int:
    """Number of materialized output arrays in this plan.

    An output array is a non-input array whose name is in ``array_names``.
    """
    return self._total_output_narrays
493+
@property
def total_output_nbytes(self) -> int:
    """Number of bytes across all materialized output arrays in this plan.

    An output array is a non-input array whose name is in ``array_names``.
    """
    return self._total_output_nbytes
498+
@property
def total_output_nchunks(self) -> int:
    """Number of chunks across all materialized output arrays in this plan.

    An output array is a non-input array whose name is in ``array_names``.
    """
    return self._total_output_nchunks
503+
@property
def total_narrays(self) -> int:
    """Number of materialized arrays in this plan, whatever their role."""
    return self._total_narrays
508+
@property
def total_nbytes(self) -> int:
    """Number of bytes across all materialized arrays in this plan."""
    return self._total_nbytes
513+
@property
def total_nchunks(self) -> int:
    """Number of chunks across all materialized arrays in this plan."""
    return self._total_nchunks
300518
301519 def execute (
302520 self ,
@@ -318,7 +536,7 @@ def execute(
318536 compute_id = f"compute-{ datetime .now ().strftime ('%Y%m%dT%H%M%S.%f' )} "
319537
320538 if callbacks is not None :
321- event = ComputeStartEvent (compute_id , dag )
539+ event = ComputeStartEventWithPlan (compute_id , dag , self )
322540 [callback .on_compute_start (event ) for callback in callbacks ]
323541 executor .execute_dag (
324542 dag ,
@@ -352,9 +570,9 @@ def visualize(
352570 "rankdir" : rankdir ,
353571 "label" : (
354572 # note that \l is used to left-justify each line (see https://www.graphviz.org/docs/attrs/nojustify/)
355- rf"num tasks: { self .num_tasks () } \l"
356- rf"max projected memory: { memory_repr (self .max_projected_mem () )} \l"
357- rf"total nbytes written: { memory_repr (self .total_nbytes_written () )} \l"
573+ rf"num tasks: { self .num_tasks } \l"
574+ rf"max projected memory: { memory_repr (self .max_projected_mem )} \l"
575+ rf"total nbytes written: { memory_repr (self .total_nbytes_written )} \l"
358576 rf"optimized: { self .optimized } \l"
359577 ),
360578 "labelloc" : "bottom" ,
@@ -495,37 +713,11 @@ def visualize(
495713 pass
496714 return None
497715
498- def max_projected_mem (self ) -> int :
499- """Return the maximum projected memory across all tasks to execute this plan."""
500- projected_mem_values = [
501- node ["primitive_op" ].projected_mem for _ , node in visit_nodes (self .dag )
502- ]
503- return max (projected_mem_values ) if len (projected_mem_values ) > 0 else 0
504-
505- def num_arrays (self ) -> int :
506- """Return the number of arrays in this plan."""
507- return sum (d .get ("type" ) == "array" for _ , d in self .dag .nodes (data = True ))
508716
509- def num_primitive_ops (self ) -> int :
510- """Return the number of primitive operations in this plan."""
511- return len (list (visit_nodes (self .dag )))
512-
513- def num_tasks (self ) -> int :
514- """Return the number of tasks needed to execute this plan."""
515- tasks = 0
516- for _ , node in visit_nodes (self .dag ):
517- tasks += node ["primitive_op" ].num_tasks
518- return tasks
519-
520- def total_nbytes_written (self ) -> int :
521- """Return the total number of bytes written for all materialized arrays in this plan."""
522- nbytes = 0
523- for _ , d in self .dag .nodes (data = True ):
524- if d .get ("type" ) == "array" :
525- target = d ["target" ]
526- if isinstance (target , LazyZarrArray ):
527- nbytes += target .nbytes
528- return nbytes
@dataclasses.dataclass
class ComputeStartEventWithPlan(ComputeStartEvent):
    """A ``ComputeStartEvent`` that additionally carries the plan being run."""

    plan: FinalizedPlan
    """The plan."""
529721
530722
531723def arrays_to_dag (* arrays ):
0 commit comments