
Conversation

shuds13 (Member) commented on Jan 27, 2025

For comparison with libE.

Works in parallel on Perlmutter and uses GPUs on one node, but does not work across nodes.

Main loop

    for epoch in range(epochs):
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)

            optimizer.zero_grad()
            output = model(data)
            loss = nn.CrossEntropyLoss()(output, target)
            loss.backward()

            # Synchronize gradients across processes (libE would do this)
            for param in model.parameters():
                dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
                param.grad /= world_size

            optimizer.step()
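
For reference, a minimal sketch of the setup this loop assumes (process-group initialization, device selection, model/optimizer/data-loader creation). The Net class, dataset, and hyperparameters here are placeholders, not taken from this PR:

    import torch
    import torch.nn as nn
    import torch.distributed as dist
    from torch.utils.data import DataLoader

    # Assumes the launcher (e.g., torchrun or srun) sets the rendezvous
    # environment variables (RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT).
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # One GPU per rank on the node.
    device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")

    model = Net().to(device)  # Net is a placeholder model class
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)  # train_dataset assumed defined
    epochs = 10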

shuds13 self-assigned this on Jan 27, 2025
        torch.tensor(calc_in["grads_in"][offset:offset + grad_size].reshape(grad_shape),
                     device=param.grad.device)
    )
    offset += grad_size

shuds13 (Member, Author) commented:

If we need to average above, then we need to do it here as well:

    param.grad /= group_size  # SH averaging?
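
For context, a sketch of the unpacking loop that line would sit in, assuming calc_in["grads_in"] holds the flattened, summed gradients and group_size is the number of contributing workers; the param.grad.copy_ call is an assumption about the leading line cut off in the excerpt above:

    offset = 0
    for param in model.parameters():
        grad_size = param.grad.numel()
        grad_shape = param.grad.shape
        # Copy this parameter's slice of the flat gradient array into .grad
        # (the copy_ call is assumed; the excerpt omits the leading line).
        param.grad.copy_(
            torch.tensor(calc_in["grads_in"][offset:offset + grad_size].reshape(grad_shape),
                         device=param.grad.device)
        )
        param.grad /= group_size  # average if the workers' gradients were summed
        offset += grad_size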

    # SH TODO - consider inheriting dist (torch.distributed)
    # and overriding dist.all_reduce() to do libE stuff

shuds13 (Member, Author) commented:

torch.distributed is a module, not a class, so it cannot be subclassed directly, but these lines could be extracted into a provided function.
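
A minimal sketch of what such a provided function might look like, wrapping the synchronization lines from the main loop; the name sync_gradients is an assumption, not an existing libE or torch API:

    import torch.distributed as dist

    def sync_gradients(model, world_size):
        """Average gradients across ranks (hypothetical helper)."""
        for param in model.parameters():
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            param.grad /= world_size

The training loop would then call sync_gradients(model, world_size) between loss.backward() and optimizer.step(), and a libE-backed version could swap in its own implementation without changing the loop.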

    loss = nn.CrossEntropyLoss()(output, target)
    loss.backward()

    # Synchronize gradients across processes

shuds13 (Member, Author) commented:

If the optimizer step is done on the gen, then the gen does:

    optimizer.zero_grad()
    gradients.sum()   # i.e., sum the gradients received from the workers
    optimizer.step()
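
Spelled out, under the assumption that the gen keeps its own copy of the model and receives a list of flattened per-worker gradient arrays (grads_from_workers and group_size are placeholder names):

    import numpy as np
    import torch

    # Sum (and here also average) the gradients reported by the workers.
    summed = np.sum(grads_from_workers, axis=0) / group_size

    # Copy each parameter's slice into .grad, then take the step on the gen.
    optimizer.zero_grad()
    offset = 0
    for param in model.parameters():
        n = param.numel()
        param.grad = torch.tensor(summed[offset:offset + n].reshape(param.shape),
                                  device=param.device, dtype=param.dtype)
        offset += n
    optimizer.step()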
