Skip to content

Commit

Permalink
Add parentheses to packed_bag_idx formula
Browse files Browse the repository at this point in the history
  • Loading branch information
avbokovoy committed Jan 14, 2025
1 parent a9af333 commit 1d6c81e
Showing 1 changed file with 2 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no
const uint32_t uint4_loads_per_row = div_round_up(D_bytes, sizeof(uint4));

for (uint32_t i = 0; i < OutputRowsPerThread; ++i) {
- const uint32_t packed_bag_idx = num_packed_bags > 1 ? threadIdx.x % NumUint4LoadsPerRow / uint4_loads_per_row : 0;
+ const uint32_t packed_bag_idx = num_packed_bags > 1 ? (threadIdx.x % NumUint4LoadsPerRow) / uint4_loads_per_row : 0;
uint32_t b = min(static_cast<uint32_t>(bb * num_packed_bags * OutputRowsPerThread + i * num_packed_bags + packed_bag_idx), static_cast<uint32_t>(B - 1));
int32_t indices_start = offsets[t * B + b];
int32_t indices_end = offsets[t * B + b + 1];
Expand Down Expand Up @@ -244,7 +244,7 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no
continue;
}
const uint32_t* row = reinterpret_cast<const uint32_t*>(&buffers[warp_idx][i][input_row_idx][0]);
- const int32_t packed_bag_idx = threadIdx.x / uints_per_row % num_packed_bags;
+ const int32_t packed_bag_idx = (threadIdx.x / uints_per_row) % num_packed_bags;
// scale and bias are at the beginning of each row.
// rationale: have scale/shift at start since these get loaded first
// and then broadcasted around so it might speed up the first cache miss.
Expand Down

0 comments on commit 1d6c81e

Please sign in to comment.