Skip to content

Commit

Permalink
dist unet example now generates data in parallel
Browse files Browse the repository at this point in the history
  • Loading branch information
rhewett committed Jul 28, 2021
1 parent 4923108 commit bf0d78a
Showing 1 changed file with 79 additions and 37 deletions.
116 changes: 79 additions & 37 deletions examples/distributed_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,45 +71,87 @@
n_batch = 1
batch_size = 1

scatter = distdl.nn.DistributedTranspose(P_root, P_unet)
# Compute the start/stop indices used below to cut this rank's portion out of
# the global sample axes.
from distdl.utilities.tensor_decomposition import compute_subtensor_shapes_balanced
from distdl.utilities.tensor_decomposition import compute_subtensor_start_indices
from distdl.utilities.tensor_decomposition import compute_subtensor_stop_indices
from distdl.utilities.torch import TensorStructure

# Describe the global input by its shape only; the decomposition helper is
# given this structure together with the partition shape.
global_input_tensor_structure = TensorStructure()
global_input_tensor_structure.shape = input_features
# Only partition dimensions from index 2 onward are used
# (presumably skipping the batch and channel dimensions -- TODO confirm).
subtensor_shapes = compute_subtensor_shapes_balanced(global_input_tensor_structure, P_unet.shape[2:])
subtensor_starts = compute_subtensor_start_indices(subtensor_shapes)
subtensor_stops = compute_subtensor_stop_indices(subtensor_shapes)
# Length-1 slices pick out the entry at this rank's partition index;
# squeeze() then drops the resulting singleton axes.
_slice = tuple([slice(i, i+1) for i in P_unet.index[2:]])
my_start = subtensor_starts[_slice].squeeze()
my_stop = subtensor_stops[_slice].squeeze()

MPI.COMM_WORLD.Barrier()
# Previous (removed) approach: generate every sample on the P_root partition
# and distribute it with `scatter`; ranks where P_root is not active pass
# zero-volume tensors through the same scatter calls.
# NOTE(review): indentation is lost in this view, so block nesting below has
# to be inferred from the statement order.
with torch.no_grad():
if P_root.active:

# Coordinate axes on [0, 1] over the full global extent of each dimension.
sample_spacing = [np.linspace(0, 1, f) for f in input_features]
sample_grid = np.meshgrid(*sample_spacing)

n_ellipses_target = 3
n_ellipses_noise = 2

timer.start("data gen")
batches = list()
for i in range(n_batch):
batch = list()
for j in range(batch_size):
# Add an image-mask tuple to the batch
batch.append(gen_data(sample_grid, n_ellipses_target, n_ellipses_noise))
# Concatenate the images (then the masks) of the batch along dim 0 and
# pass each through scatter.
img = torch.cat([im for im, ma in batch],dim=0)
img = scatter(img)
mask = torch.cat([ma for im, ma in batch],dim=0)
mask = scatter(mask)

batches.append((img, mask))
timer.stop("data gen", input_features)
else:
timer.start("data gen")
batches = list()
for i in range(n_batch):
# NOTE(review): `batch` is created but never appended to in this branch.
batch = list()
for j in range(batch_size):
# These ranks generate no data; they still call scatter, passing
# zero-volume tensors.
img = distdl.utilities.torch.zero_volume_tensor(batch_size)
img = scatter(img)
mask = distdl.utilities.torch.zero_volume_tensor(batch_size)
mask = scatter(mask)

batches.append((img, mask))
timer.stop("data gen", input_features)

# Build the full coordinate axis on [0, 1] for each dimension, then keep only
# the [my_start, my_stop) portion of each axis so that the mesh grid covers
# just the locally owned part of the domain.
t_sample_spacing = [np.linspace(0, 1, f) for f in input_features]
sample_spacing = []
for d in range(len(input_features)):
sample_spacing.append(t_sample_spacing[d][my_start[d]:my_stop[d]])
sample_grid = np.meshgrid(*sample_spacing)

# Ellipses are created from random parameters. Every rank seeds NumPy's
# global RNG identically and therefore generates the same sequence of
# parameters, so each rank can evaluate the same ellipses on its own part of
# the grid in parallel.
# NOTE(review): this relies on gen_data drawing its parameters from NumPy's
# global RNG -- gen_data is defined outside this view; confirm.
np.random.seed(0)

n_ellipses_target = 3
n_ellipses_noise = 2

timer.start("data gen")
batches = list()
for i in range(n_batch):
batch = list()
for j in range(batch_size):
# Add an image-mask tuple to the batch
batch.append(gen_data(sample_grid, n_ellipses_target, n_ellipses_noise))
# Concatenate the images (then the masks) of the batch along dim 0. No scatter
# is applied here: the tensors are stored as generated.
img = torch.cat([im for im, ma in batch],dim=0)
mask = torch.cat([ma for im, ma in batch],dim=0)
batches.append((img, mask))
timer.stop("data gen", input_features)

# Keep the original demo, which generates the data on one rank and scatters it, for posterity
#
# scatter = distdl.nn.DistributedTranspose(P_root, P_unet)
#
# with torch.no_grad():
# if P_root.active:

# sample_spacing = [np.linspace(0, 1, f) for f in input_features]
# sample_grid = np.meshgrid(*sample_spacing)

# n_ellipses_target = 3
# n_ellipses_noise = 2

# timer.start("data gen")
# batches = list()
# for i in range(n_batch):
# batch = list()
# for j in range(batch_size):
# # Add an image-mask tuple to the batch
# batch.append(gen_data(sample_grid, n_ellipses_target, n_ellipses_noise))
# img = torch.cat([im for im, ma in batch],dim=0)
# img = scatter(img)
# mask = torch.cat([ma for im, ma in batch],dim=0)
# mask = scatter(mask)

# batches.append((img, mask))
# timer.stop("data gen", input_features)
# else:
# timer.start("data gen")
# batches = list()
# for i in range(n_batch):
# batch = list()
# for j in range(batch_size):
# img = distdl.utilities.torch.zero_volume_tensor(batch_size)
# img = scatter(img)
# mask = distdl.utilities.torch.zero_volume_tensor(batch_size)
# mask = scatter(mask)

# batches.append((img, mask))
# timer.stop("data gen", input_features)

MPI.COMM_WORLD.Barrier()

Expand Down

0 comments on commit bf0d78a

Please sign in to comment.