kv_dq zero initialization to avoid NaNs from FA3 (#3632)
Summary:
X-link: facebookresearch/FBGEMM#708

Pull Request resolved: #3632

Running evals with an FP8 KV cache gives NaNs due to issues in FA3. For more context, see D68708685.

To reproduce:
> sh ai_codesign/gen_ai/disagg_generator_launcher/start_server_moe.sh -m 17b_text_sft -a " --ffn_quantize_mode=fp8_rowwise --attn_quantize_mode=fp8_rowwise --kv_cache_quantization=8 "

To mitigate these issues, change dequantize_fp8_cache's initialization of its output buffers from at::empty to at::zeros.
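For illustration, a minimal standalone libtorch sketch (not part of this commit) of why the switch matters: at::empty only reserves memory, so any element no kernel writes — such as padded positions past the valid sequence length that FA3 may still read — holds arbitrary bits that can decode to NaN, while at::zeros returns a fully initialized buffer.

#include <ATen/ATen.h>
#include <iostream>

int main() {
  // at::empty reserves storage but leaves its contents undefined; any
  // element that is never written may hold garbage, including NaN bit
  // patterns.
  auto uninit = at::empty({2, 4}, at::dtype(at::kBFloat16));

  // at::zeros performs the same allocation plus a fill, so every element
  // a downstream kernel reads is a well-defined 0.
  auto zeroed = at::zeros({2, 4}, at::dtype(at::kBFloat16));

  std::cout << zeroed << "\n"; // deterministic: all zeros
  return 0;
}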

Reviewed By: jasonjk-park

Differential Revision: D68574038

fbshipit-source-id: 3f3f5573d13f1b4046e6880363533eb1c2dfa268
ayaIbrah authored and facebook-github-bot committed Jan 31, 2025
1 parent 4965f35 commit 3266957
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu
@@ -1874,9 +1874,11 @@ std::tuple<at::Tensor, at::Tensor> dequantize_fp8_cache(
   // correct block_tables. (2) From outside, keep a persistent buffer that has a
   // matching shape with the original paged KV and feed the same buffer
   // into this function at every layer to reuse it and prevent allocation.
-  auto cache_K_dq = at::empty(
+  // FIXME: T213958042
+  auto cache_K_dq = at::zeros(
       {B_KV, MAX_T, N_KVH, D_H}, cache_K.options().dtype(at::kBFloat16));
-  auto cache_V_dq = at::empty(
+  // FIXME: T213958042
+  auto cache_V_dq = at::zeros(
       {B_KV, MAX_T, N_KVH, D_H}, cache_K.options().dtype(at::kBFloat16));
 
   if (B == 0) {
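A note on the tradeoff: at::zeros adds a device-wide fill on every call, which is presumably what the FIXME task tracks. The comment in the hunk already sketches alternative (2), where the caller owns one persistent buffer and passes it into dequantize_fp8_cache at every layer. Below is a hedged sketch of that caller-side pattern; the struct and its names are hypothetical, not FBGEMM API.

#include <ATen/ATen.h>

// Hypothetical caller-side helper: allocate and zero the dequantized KV
// buffers once, then reuse them for every layer instead of reallocating.
struct KVDequantBuffers {
  at::Tensor cache_K_dq;
  at::Tensor cache_V_dq;

  KVDequantBuffers(int64_t B_KV, int64_t MAX_T, int64_t N_KVH, int64_t D_H,
                   const at::TensorOptions& opts) {
    // One-time zero initialization; per-layer dequant kernels then
    // overwrite the valid regions in place.
    cache_K_dq =
        at::zeros({B_KV, MAX_T, N_KVH, D_H}, opts.dtype(at::kBFloat16));
    cache_V_dq =
        at::zeros({B_KV, MAX_T, N_KVH, D_H}, opts.dtype(at::kBFloat16));
  }
};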
