Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support vectorized local reduction for p2p-based ReduceScatter overlap #1452

Open
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

erhoo82
Copy link
Collaborator

@erhoo82 erhoo82 commented Feb 4, 2025

Description

Vectorized load/store for p2p-based ReduceScatter overlap.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Vectorize the loads/stores of local reduction after TP ReduceScatter for (1) FP8 comm > FP32 reduce > BF16 out, (2) BF16 comm > FP32 reduce > BF16 out

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@erhoo82 erhoo82 requested a review from ksivaman February 5, 2025 16:48
@timmoon10 timmoon10 changed the base branch from release_v2.0 to main February 7, 2025 23:56
Comment on lines 2621 to 2631
#pragma unroll
for (int input_id = 1; input_id < num_inputs; ++input_id) {
loader.load(tid + num_aligned_elements_per_input * input_id, tot_input_size);
#pragma unroll
for (int i = 0; i < nvec; ++i) {
accum_buf[i] += static_cast<float>(loader.separate()[i]) * (*scale);
if (input_id == num_inputs - 1) {
storer.separate()[i] = static_cast<half_dtype>(accum_buf[i]);
}
}
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • Correctness bug when num_inputs == 1
  • Unrolling loop over num_inputs is not necessary since it's not known at compile-time
Suggested change
#pragma unroll
for (int input_id = 1; input_id < num_inputs; ++input_id) {
loader.load(tid + num_aligned_elements_per_input * input_id, tot_input_size);
#pragma unroll
for (int i = 0; i < nvec; ++i) {
accum_buf[i] += static_cast<float>(loader.separate()[i]) * (*scale);
if (input_id == num_inputs - 1) {
storer.separate()[i] = static_cast<half_dtype>(accum_buf[i]);
}
}
}
for (int input_id = 1; input_id < num_inputs; ++input_id) {
loader.load(tid + num_aligned_elements_per_input * input_id, tot_input_size);
#pragma unroll
for (int i = 0; i < nvec; ++i) {
accum_buf[i] += static_cast<float>(loader.separate()[i]) * (*scale);
}
}
#pragma unroll
for (int i = 0; i < nvec; ++i) {
storer.separate()[i] = static_cast<half_dtype>(accum_buf[i]);
}

Same issue in reduce_bf16_cuda.

}

template <typename fp8type>
void reduce_fp8_in_bf16_out(void *inputs, void *output, float *scale, int num_inputs,
int input_size, cudaStream_t stream) {
const int nvec = 32;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we're using this as a template arg, better to make it explicit that it is known at compile-time:

Suggested change
const int nvec = 32;
constexpr int nvec = 32;

Same issue in reduce_bf16.

transformer_engine::VectorizedLoader<fp8type, nvec, true> loader(inputs_fp8, tot_input_size);
transformer_engine::VectorizedStorer<half_dtype, nvec, true> storer(output_half, input_size);

const size_t tid = threadIdx.x + blockDim.x * blockIdx.x;
Copy link
Collaborator

@timmoon10 timmoon10 Feb 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we handle the case where the block size doesn't neatly divide the input size? Maybe we can fix with something like:

Suggested change
const size_t tid = threadIdx.x + blockDim.x * blockIdx.x;
const size_t tid = threadIdx.x + blockDim.x * blockIdx.x;
if (tid >= num_aligned_elements_per_input) {
return;
}

Alternatively, we can change how we configure the CUDA blocks:

  size_t num_threads = MAX_THREADS / 4;
  assert(num_aligned_elements_per_input % num_threads == 0);
  size_t num_blocks = num_aligned_elements_per_input / num_threads;
  dim3 block(num_threads);
  dim3 grid(num_blocks);

Same issue in reduce_bf16_cuda.

Signed-off-by: Sangkug Lym <[email protected]>
@erhoo82
Copy link
Collaborator Author

erhoo82 commented Feb 24, 2025

@timmoon10
Can you review this one again?

BTW, regarding the lint error, nvec should be compile-time constant?

@timmoon10 timmoon10 self-requested a review February 26, 2025 01:33
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants