1
1
from abc import ABCMeta , abstractproperty , abstractmethod
2
2
import numpy as np
3
+ import math
3
4
from . import crevolve as cr
4
5
from .compression import init_compression as init
5
6
from .schedulers import CRevolve , HRevolve , Action , Architecture
@@ -69,11 +70,12 @@ def __init__(
69
70
fwd_operator ,
70
71
rev_operator ,
71
72
n_checkpoints ,
72
- n_timesteps ,
73
+ op_timesteps ,
73
74
storage_list = None ,
74
75
scheduler = None ,
75
76
timings = None ,
76
77
profiler = None ,
78
+ block_size = 1 ,
77
79
):
78
80
"""
79
81
Initialises checkpointer for a given forward- and reverse operator, a
@@ -82,10 +84,10 @@ def __init__(
82
84
methods and a scheduler object must be provided as well. Otherwise
83
85
NumpyStorage and CRevolve are used as default
84
86
"""
85
- if n_timesteps is None :
87
+ if op_timesteps is None :
86
88
raise Exception (
87
89
"Online checkpointing not yet supported. Specify \
88
- number of time steps!"
90
+ number of Operator time steps!"
89
91
)
90
92
91
93
if profiler is None :
@@ -100,12 +102,15 @@ def __init__(
100
102
101
103
self .checkpoint = checkpoint
102
104
self .n_checkpoints = n_checkpoints
103
- self .n_timesteps = n_timesteps
105
+ self .block_size = block_size
106
+ self .op_timesteps = op_timesteps
104
107
self .timings = timings
105
108
self .fwd_operator = fwd_operator
106
109
self .rev_operator = rev_operator
107
110
self .scheduler = scheduler
108
111
112
+ self .cp_timesteps = int (math .ceil (self .op_timesteps / self .block_size ))
113
+
109
114
def addStorage (self , new_storage ):
110
115
self .storage_list .append (new_storage )
111
116
@@ -171,6 +176,20 @@ def addByteStorage(self, compression_params):
171
176
def makespan (self ):
172
177
return 0
173
178
179
@property
def op_old_capo(self):
    """Return the scheduler's ``old_capo`` scaled by ``block_size``.

    The scheduler position is multiplied by ``self.block_size`` to give
    the value passed as ``t_start`` to the operators. No upper bound is
    applied here.
    """
    previous_position = self.scheduler.old_capo
    return previous_position * self.block_size
182
+
183
@property
def op_capo(self):
    """Return the scheduler's ``capo`` scaled by ``block_size``, capped.

    The scaled value is limited to at most ``self.op_timesteps``, so the
    result never exceeds the number of operator time steps even when the
    last block is only partially filled.
    """
    scaled_position = self.scheduler.capo * self.block_size
    # min() keeps its first argument on ties, matching the original
    # conditional, which returned op_timesteps unless the scaled value
    # was strictly smaller.
    return min(self.op_timesteps, scaled_position)
187
+
188
@property
def next_op_capo(self):
    """Return ``(capo + 1)`` scaled by ``block_size``, capped.

    This is the scaled position one scheduler step beyond the current
    ``capo``, limited to at most ``self.op_timesteps``.
    """
    following_position = (self.scheduler.capo + 1) * self.block_size
    # min() keeps its first argument on ties, matching the original
    # conditional, which returned op_timesteps unless the scaled value
    # was strictly smaller.
    return min(self.op_timesteps, following_position)
192
+
174
193
def apply_forward (self ):
175
194
"""Executes only the forward computation while storing checkpoints,
176
195
then returns."""
@@ -181,7 +200,7 @@ def apply_forward(self):
181
200
# advance forward computation
182
201
with self .profiler .get_timer ("forward" , "advance" ):
183
202
self .fwd_operator .apply (
184
- t_start = self .scheduler . old_capo , t_end = self .scheduler . capo
203
+ t_start = self .op_old_capo , t_end = self .op_capo
185
204
)
186
205
elif action .type == Action .TAKESHOT :
187
206
# take a snapshot: copy from workspace into storage
@@ -195,7 +214,7 @@ def apply_forward(self):
195
214
# final step in the forward computation
196
215
with self .profiler .get_timer ("forward" , "lastfw" ):
197
216
self .fwd_operator .apply (
198
- t_start = self .scheduler . old_capo , t_end = self .n_timesteps
217
+ t_start = self .op_old_capo , t_end = self .op_timesteps
199
218
)
200
219
break
201
220
elif action .type == Action .REVERSE :
@@ -220,10 +239,10 @@ def apply_reverse(self):
220
239
# advance adjoint computation by a single step
221
240
with self .profiler .get_timer ("reverse" , "reverse" ):
222
241
self .fwd_operator .apply (
223
- t_start = self .scheduler . capo , t_end = self .scheduler . capo + 1
242
+ t_start = self .op_capo , t_end = self .next_op_capo
224
243
)
225
244
self .rev_operator .apply (
226
- t_start = self .scheduler . capo , t_end = self .scheduler . capo + 1
245
+ t_start = self .op_capo , t_end = self .next_op_capo
227
246
)
228
247
elif action .type == Action .REVSTART :
229
248
"""Sets the rev_operator to 'nt' only if its not already there.
@@ -232,7 +251,7 @@ def apply_reverse(self):
232
251
"""
233
252
with self .profiler .get_timer ("reverse" , "reverse" ):
234
253
self .rev_operator .apply (
235
- t_start = self .scheduler . capo , t_end = self .scheduler . capo + 1
254
+ t_start = self .op_capo , t_end = self .next_op_capo
236
255
)
237
256
elif action .type == Action .TAKESHOT :
238
257
# take a snapshot: copy from workspace into storage
@@ -242,7 +261,7 @@ def apply_reverse(self):
242
261
# advance forward computation
243
262
with self .profiler .get_timer ("reverse" , "advance" ):
244
263
self .fwd_operator .apply (
245
- t_start = self .scheduler . old_capo , t_end = self .scheduler . capo
264
+ t_start = self .op_old_capo , t_end = self .op_capo
246
265
)
247
266
elif action .type == Action .RESTORE :
248
267
# restore a snapshot: copy from storage into workspace
@@ -294,13 +313,14 @@ def __init__(
294
313
fwd_operator ,
295
314
rev_operator ,
296
315
n_checkpoints ,
297
- n_timesteps ,
316
+ op_timesteps ,
298
317
timings = None ,
299
318
profiler = None ,
300
319
compression_params = None ,
301
320
diskstorage = False ,
302
321
filedir = "./" ,
303
322
singlefile = True ,
323
+ block_size = 1 ,
304
324
):
305
325
"""
306
326
Initializes a single-level Revolver
@@ -309,7 +329,7 @@ def __init__(
309
329
fwd_operator: forward operator
310
330
rev_operator: backward operator
311
331
n_checkpoints: number of checkpoints
312
- n_timesteps: number of timesteps
332
+ op_timesteps: number of timesteps
313
333
timings: timings
314
334
profiler: Profiler
315
335
compression_params: compression scheme
@@ -322,20 +342,21 @@ def __init__(
322
342
fwd_operator ,
323
343
rev_operator ,
324
344
n_checkpoints ,
325
- n_timesteps ,
345
+ op_timesteps ,
326
346
timings = timings ,
327
347
profiler = profiler ,
348
+ block_size = block_size ,
328
349
)
329
350
330
351
self .filedir = filedir
331
352
self .singlefile = singlefile
332
353
333
354
if n_checkpoints is None :
334
- self .n_checkpoints = cr .adjust (n_timesteps )
355
+ self .n_checkpoints = cr .adjust (self . cp_timesteps )
335
356
else :
336
357
self .n_checkpoints = n_checkpoints
337
358
338
- self .scheduler = CRevolve (self .n_checkpoints , self .n_timesteps )
359
+ self .scheduler = CRevolve (self .n_checkpoints , self .cp_timesteps )
339
360
340
361
# remove storage list to avoid memory overflow
341
362
self .resetStorageList ()
@@ -369,13 +390,14 @@ def __init__(
369
390
checkpoint ,
370
391
fwd_operator ,
371
392
rev_operator ,
372
- n_timesteps ,
393
+ op_timesteps ,
373
394
storage_list ,
374
395
timings = None ,
375
396
profiler = None ,
376
397
uf = 1 ,
377
398
ub = 1 ,
378
399
up = 1 ,
400
+ block_size = 1 ,
379
401
):
380
402
"""
381
403
Initializes a multi-level Revolver using HRevolve
@@ -390,7 +412,7 @@ def __init__(
390
412
fwd_operator: forward operator
391
413
rev_operator: backward operator
392
414
n_checkpoints: number of checkpoints
393
- n_timesteps : number of timesteps
415
+ op_timesteps : number of timesteps
394
416
timings: timings
395
417
profiler: profiler
396
418
storage_list: list of storage objects
@@ -402,12 +424,14 @@ def __init__(
402
424
checkpoint ,
403
425
fwd_operator ,
404
426
rev_operator ,
405
- n_timesteps ,
406
- n_timesteps ,
427
+ op_timesteps ,
428
+ op_timesteps ,
407
429
storage_list = storage_list ,
408
430
timings = timings ,
409
431
profiler = profiler ,
432
+ block_size = block_size ,
410
433
)
434
+
411
435
self .uf = uf # forward cost (default=1)
412
436
self .ub = ub # backward cost (default=1)
413
437
self .up = up # turn cost (default=1)
@@ -431,7 +455,8 @@ def reload_scheduler(self, uf=1, ub=1, up=1):
431
455
self .up = up
432
456
self .arch = Architecture (self .storage_list )
433
457
self .scheduler = HRevolve (
434
- self .n_checkpoints , self .n_timesteps , self .arch , self .uf , self .ub , self .up
458
+ self .n_checkpoints , self .cp_timesteps , self .arch , self .uf , self .ub ,
459
+ self .up
435
460
)
436
461
else :
437
462
raise ValueError (
@@ -471,21 +496,23 @@ def __init__(
471
496
fwd_operator ,
472
497
rev_operator ,
473
498
n_checkpoints ,
474
- n_timesteps ,
499
+ op_timesteps ,
475
500
timings = None ,
476
501
profiler = None ,
477
502
compression_params = None ,
503
+ block_size = 1 ,
478
504
):
479
505
super ().__init__ (
480
506
checkpoint ,
481
507
fwd_operator ,
482
508
rev_operator ,
483
509
n_checkpoints ,
484
- n_timesteps ,
510
+ op_timesteps ,
485
511
timings = timings ,
486
512
profiler = profiler ,
487
513
compression_params = compression_params ,
488
514
diskstorage = False ,
515
+ block_size = block_size ,
489
516
)
490
517
491
518
@@ -510,23 +537,25 @@ def __init__(
510
537
fwd_operator ,
511
538
rev_operator ,
512
539
n_checkpoints ,
513
- n_timesteps ,
540
+ op_timesteps ,
514
541
timings = None ,
515
542
profiler = None ,
516
543
filedir = "./" ,
517
544
singlefile = True ,
545
+ block_size = 1 ,
518
546
):
519
547
super ().__init__ (
520
548
checkpoint ,
521
549
fwd_operator ,
522
550
rev_operator ,
523
551
n_checkpoints ,
524
- n_timesteps ,
552
+ op_timesteps ,
525
553
timings = timings ,
526
554
profiler = profiler ,
527
555
diskstorage = True ,
528
556
filedir = filedir ,
529
557
singlefile = singlefile ,
558
+ block_size = block_size ,
530
559
)
531
560
532
561
0 commit comments