Skip to content

Commit 28af91b

Browse files
committed
ckp: implement blocked time stepping
1 parent f695c7c commit 28af91b

File tree

1 file changed

+53
-24
lines changed

1 file changed

+53
-24
lines changed

pyrevolve/pyrevolve.py

Lines changed: 53 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from abc import ABCMeta, abstractproperty, abstractmethod
22
import numpy as np
3+
import math
34
from . import crevolve as cr
45
from .compression import init_compression as init
56
from .schedulers import CRevolve, HRevolve, Action, Architecture
@@ -69,11 +70,12 @@ def __init__(
6970
fwd_operator,
7071
rev_operator,
7172
n_checkpoints,
72-
n_timesteps,
73+
op_timesteps,
7374
storage_list=None,
7475
scheduler=None,
7576
timings=None,
7677
profiler=None,
78+
block_size=1,
7779
):
7880
"""
7981
Initialises checkpointer for a given forward- and reverse operator, a
@@ -82,10 +84,10 @@ def __init__(
8284
methods and a scheduler object must be provided as well. Otherwise
8385
NumpyStorage and CRevolve are used as default
8486
"""
85-
if n_timesteps is None:
87+
if op_timesteps is None:
8688
raise Exception(
8789
"Online checkpointing not yet supported. Specify \
88-
number of time steps!"
90+
number of Operator time steps!"
8991
)
9092

9193
if profiler is None:
@@ -100,12 +102,15 @@ def __init__(
100102

101103
self.checkpoint = checkpoint
102104
self.n_checkpoints = n_checkpoints
103-
self.n_timesteps = n_timesteps
105+
self.block_size = block_size
106+
self.op_timesteps = op_timesteps
104107
self.timings = timings
105108
self.fwd_operator = fwd_operator
106109
self.rev_operator = rev_operator
107110
self.scheduler = scheduler
108111

112+
self.cp_timesteps = int(math.ceil(self.op_timesteps / self.block_size))
113+
109114
def addStorage(self, new_storage):
110115
self.storage_list.append(new_storage)
111116

@@ -171,6 +176,20 @@ def addByteStorage(self, compression_params):
171176
def makespan(self):
172177
return 0
173178

179+
@property
def op_old_capo(self):
    """Operator time index matching the scheduler's previous position.

    The scheduler counts in blocks of ``block_size`` operator time steps,
    so its ``old_capo`` is scaled by ``block_size`` to obtain an operator
    index. The result is capped at ``op_timesteps`` (as ``op_capo`` and
    ``next_op_capo`` do) because the last block may be shorter than
    ``block_size`` when ``op_timesteps`` is not a multiple of it.
    """
    return min(self.scheduler.old_capo * self.block_size, self.op_timesteps)
182+
183+
@property
def op_capo(self):
    """Operator time index matching the scheduler's current position.

    The scheduler's ``capo`` counts blocks of ``block_size`` operator time
    steps; scaling converts it to an operator index. The value is capped
    at ``op_timesteps`` so that a partial final block never points past
    the end of the operator's time range.
    """
    return min(self.scheduler.capo * self.block_size, self.op_timesteps)
187+
188+
@property
def next_op_capo(self):
    """Operator time index one block after the scheduler's current position.

    Computed as ``(capo + 1) * block_size`` and capped at ``op_timesteps``,
    so the block starting at ``op_capo`` ends at the final operator time
    step when fewer than ``block_size`` steps remain.
    """
    return min((self.scheduler.capo + 1) * self.block_size, self.op_timesteps)
192+
174193
def apply_forward(self):
175194
"""Executes only the forward computation while storing checkpoints,
176195
then returns."""
@@ -181,7 +200,7 @@ def apply_forward(self):
181200
# advance forward computation
182201
with self.profiler.get_timer("forward", "advance"):
183202
self.fwd_operator.apply(
184-
t_start=self.scheduler.old_capo, t_end=self.scheduler.capo
203+
t_start=self.op_old_capo, t_end=self.op_capo
185204
)
186205
elif action.type == Action.TAKESHOT:
187206
# take a snapshot: copy from workspace into storage
@@ -195,7 +214,7 @@ def apply_forward(self):
195214
# final step in the forward computation
196215
with self.profiler.get_timer("forward", "lastfw"):
197216
self.fwd_operator.apply(
198-
t_start=self.scheduler.old_capo, t_end=self.n_timesteps
217+
t_start=self.op_old_capo, t_end=self.op_timesteps
199218
)
200219
break
201220
elif action.type == Action.REVERSE:
@@ -220,10 +239,10 @@ def apply_reverse(self):
220239
# advance adjoint computation by a single step
221240
with self.profiler.get_timer("reverse", "reverse"):
222241
self.fwd_operator.apply(
223-
t_start=self.scheduler.capo, t_end=self.scheduler.capo + 1
242+
t_start=self.op_capo, t_end=self.next_op_capo
224243
)
225244
self.rev_operator.apply(
226-
t_start=self.scheduler.capo, t_end=self.scheduler.capo + 1
245+
t_start=self.op_capo, t_end=self.next_op_capo
227246
)
228247
elif action.type == Action.REVSTART:
229248
"""Sets the rev_operator to 'nt' only if its not already there.
@@ -232,7 +251,7 @@ def apply_reverse(self):
232251
"""
233252
with self.profiler.get_timer("reverse", "reverse"):
234253
self.rev_operator.apply(
235-
t_start=self.scheduler.capo, t_end=self.scheduler.capo + 1
254+
t_start=self.op_capo, t_end=self.next_op_capo
236255
)
237256
elif action.type == Action.TAKESHOT:
238257
# take a snapshot: copy from workspace into storage
@@ -242,7 +261,7 @@ def apply_reverse(self):
242261
# advance forward computation
243262
with self.profiler.get_timer("reverse", "advance"):
244263
self.fwd_operator.apply(
245-
t_start=self.scheduler.old_capo, t_end=self.scheduler.capo
264+
t_start=self.op_old_capo, t_end=self.op_capo
246265
)
247266
elif action.type == Action.RESTORE:
248267
# restore a snapshot: copy from storage into workspace
@@ -294,13 +313,14 @@ def __init__(
294313
fwd_operator,
295314
rev_operator,
296315
n_checkpoints,
297-
n_timesteps,
316+
op_timesteps,
298317
timings=None,
299318
profiler=None,
300319
compression_params=None,
301320
diskstorage=False,
302321
filedir="./",
303322
singlefile=True,
323+
block_size=1,
304324
):
305325
"""
306326
Initializes a single-level Revolver
@@ -309,7 +329,7 @@ def __init__(
309329
fwd_operator: forward operator
310330
rev_operator: backward operator
311331
n_checkpoints: number of checkpoints
312-
n_timesteps: number of timesteps
332+
op_timesteps: number of timesteps
313333
timings: timings
314334
profiler: Profiler
315335
compression_params: compression scheme
@@ -322,20 +342,21 @@ def __init__(
322342
fwd_operator,
323343
rev_operator,
324344
n_checkpoints,
325-
n_timesteps,
345+
op_timesteps,
326346
timings=timings,
327347
profiler=profiler,
348+
block_size=block_size,
328349
)
329350

330351
self.filedir = filedir
331352
self.singlefile = singlefile
332353

333354
if n_checkpoints is None:
334-
self.n_checkpoints = cr.adjust(n_timesteps)
355+
self.n_checkpoints = cr.adjust(self.cp_timesteps)
335356
else:
336357
self.n_checkpoints = n_checkpoints
337358

338-
self.scheduler = CRevolve(self.n_checkpoints, self.n_timesteps)
359+
self.scheduler = CRevolve(self.n_checkpoints, self.cp_timesteps)
339360

340361
# remove storage list to avoid memory overflow
341362
self.resetStorageList()
@@ -369,13 +390,14 @@ def __init__(
369390
checkpoint,
370391
fwd_operator,
371392
rev_operator,
372-
n_timesteps,
393+
op_timesteps,
373394
storage_list,
374395
timings=None,
375396
profiler=None,
376397
uf=1,
377398
ub=1,
378399
up=1,
400+
block_size=1,
379401
):
380402
"""
381403
Initializes a multi-level Revolver using HRevolve
@@ -390,7 +412,7 @@ def __init__(
390412
fwd_operator: forward operator
391413
rev_operator: backward operator
392414
n_checkpoints: number of checkpoints
393-
n_timesteps: number of timesteps
415+
op_timesteps: number of timesteps
394416
timings: timings
395417
profiler: profiler
396418
storage_list: list of storage objects
@@ -402,12 +424,14 @@ def __init__(
402424
checkpoint,
403425
fwd_operator,
404426
rev_operator,
405-
n_timesteps,
406-
n_timesteps,
427+
op_timesteps,
428+
op_timesteps,
407429
storage_list=storage_list,
408430
timings=timings,
409431
profiler=profiler,
432+
block_size=block_size,
410433
)
434+
411435
self.uf = uf # forward cost (default=1)
412436
self.ub = ub # backward cost (default=1)
413437
self.up = up # turn cost (default=1)
@@ -431,7 +455,8 @@ def reload_scheduler(self, uf=1, ub=1, up=1):
431455
self.up = up
432456
self.arch = Architecture(self.storage_list)
433457
self.scheduler = HRevolve(
434-
self.n_checkpoints, self.n_timesteps, self.arch, self.uf, self.ub, self.up
458+
self.n_checkpoints, self.cp_timesteps, self.arch, self.uf, self.ub,
459+
self.up
435460
)
436461
else:
437462
raise ValueError(
@@ -471,21 +496,23 @@ def __init__(
471496
fwd_operator,
472497
rev_operator,
473498
n_checkpoints,
474-
n_timesteps,
499+
op_timesteps,
475500
timings=None,
476501
profiler=None,
477502
compression_params=None,
503+
block_size=1,
478504
):
479505
super().__init__(
480506
checkpoint,
481507
fwd_operator,
482508
rev_operator,
483509
n_checkpoints,
484-
n_timesteps,
510+
op_timesteps,
485511
timings=timings,
486512
profiler=profiler,
487513
compression_params=compression_params,
488514
diskstorage=False,
515+
block_size=block_size,
489516
)
490517

491518

@@ -510,23 +537,25 @@ def __init__(
510537
fwd_operator,
511538
rev_operator,
512539
n_checkpoints,
513-
n_timesteps,
540+
op_timesteps,
514541
timings=None,
515542
profiler=None,
516543
filedir="./",
517544
singlefile=True,
545+
block_size=1,
518546
):
519547
super().__init__(
520548
checkpoint,
521549
fwd_operator,
522550
rev_operator,
523551
n_checkpoints,
524-
n_timesteps,
552+
op_timesteps,
525553
timings=timings,
526554
profiler=profiler,
527555
diskstorage=True,
528556
filedir=filedir,
529557
singlefile=singlefile,
558+
block_size=block_size,
530559
)
531560

532561

0 commit comments

Comments
 (0)