
[sharktank] Coerce paged attention args' dtype to avoid mismatch #994

Merged
merged 1 commit into nod-ai:main on Feb 25, 2025

Conversation

sogartar
Contributor

With the introduction of the KV cache dtype config option, we may encounter configurations that mix the dtypes of the attention op's arguments, for example when the KV cache is stored in lower precision.

With this change all attention args have their dtype converted so they match.
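For illustration, here is a minimal sketch (assuming PyTorch, not sharktank's actual implementation) of the kind of coercion this change performs: every argument of the attention op is cast to a common dtype so that a lower-precision KV cache cannot cause a dtype mismatch.

import torch

# Hypothetical helper, not sharktank's real code: cast q, k, v and the mask to
# one attention dtype before calling the attention op, so e.g. a float8 KV cache
# mixed with bfloat16 activations no longer trips a dtype-mismatch error.
def coerced_attention(q, k, v, mask=None, attention_dtype=torch.bfloat16):
    q = q.to(attention_dtype)
    k = k.to(attention_dtype)
    v = v.to(attention_dtype)
    if mask is not None:
        mask = mask.to(attention_dtype)
    return torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)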

@sogartar
Contributor Author

Fixes #896 (comment).

@AmosLewis

(.venv) ➜  32 python3 -m sharktank.examples.export_paged_llm_v1 --irpa-file=/sharedfile/llama3_8b_fp8.irpa  \
--output-mlir=/sharedfile/32/fp8_32.mlir \
--output-config=/sharedfile/32/config_32.json \
--bs=1 --attention-kernel torch \
--attention-dtype=float8_e4m3fnuz --activation-dtype=bfloat16 \
--use-hf
/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/iree/turbine/aot/params.py:163: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:206.)
  return torch.from_numpy(wrapper)
Exporting prefill_bs1
/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/_export/non_strict_utils.py:520: UserWarning: Tensor.T is deprecated on 0-D tensors. This function is the identity in these cases. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3691.)
  return func(*args, **kwargs)
Exporting decode_bs1
GENERATED!
Exporting
Saving to '/sharedfile/32/fp8_32.mlir'
(.venv) ➜  32 iree-compile /sharedfile/32/fp8_32.mlir \
  --iree-hip-target=gfx942 \
  -o=/sharedfile/32/fp8_32.vmfb \
  --iree-hal-target-device=hip \
  --iree-dispatch-creation-enable-aggressive-fusion=true \
  --iree-global-opt-propagate-transposes=true \
  --iree-opt-aggressively-propagate-transposes=true \
  --iree-opt-data-tiling=false \
  --iree-preprocessing-pass-pipeline='builtin.module(util.func(iree-preprocessing-generalize-linalg-matmul-experimental))' \
  --iree-hal-indirect-command-buffers=true \
  --iree-stream-resource-memory-model=discrete \
  --iree-hal-memoization=true \
  --iree-opt-strip-assertions
LLVM ERROR: unhandled type for getConstantWithGivenDtypeAndValue
Please report issues to https://github.com/iree-org/iree/issues and include the crash backtrace.
Stack dump without symbol names (ensure you have llvm-symbolizer in your PATH or set the environment var `LLVM_SYMBOLIZER_PATH` to point to it):
0  libIREECompiler.so 0x000077899f9bb488 llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) + 40
1  libIREECompiler.so 0x000077899f9b922e llvm::sys::RunSignalHandlers() + 238
LLVM ERROR: unhandled type for getConstantWithGivenDtypeAndValue

@sogartar
Contributor Author

This fixes the dtype mismatch.
We still need to specify the correct attention dtype.
Now --attention-dtype actually means what it is meant to mean. Before, it was driving the KV cache dtype and the dtype used when creating the attention mask. You are now essentially trying to do attention in f8, which we never actually did before; it was done in bf16 even though you had --attention-dtype=float8_e4m3fnuz.

Try

python3 -m sharktank.examples.export_paged_llm_v1 --irpa-file=/sharedfile/llama3_8b_fp8.irpa  \
  --output-mlir=/sharedfile/32/fp8_32.mlir \
  --output-config=/sharedfile/32/config_32.json \
  --bs=1 --attention-kernel torch \
  --attention-dtype=bfloat16 \
  --kv-cache-dtype=float8_e4m3fnuz \
  --activation-dtype=bfloat16 \
  --use-hf
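(Illustrative note: a minimal sketch of the flag semantics, assuming PyTorch and not sharktank's actual code. --kv-cache-dtype controls the storage precision of the paged KV cache, while --attention-dtype controls the precision the attention math runs in; cached K/V are cast back to the attention dtype when the attention op reads them.)

import torch

# Requires a PyTorch build that provides float8 dtypes.
kv_cache_dtype = torch.float8_e4m3fnuz  # --kv-cache-dtype: storage precision of the cache
attention_dtype = torch.bfloat16        # --attention-dtype: precision of the attention math

# Cache pages are stored in f8 (shapes here are arbitrary placeholders)...
k_cache = torch.randn(1, 8, 128, 64).to(kv_cache_dtype)
v_cache = torch.randn(1, 8, 128, 64).to(kv_cache_dtype)
q = torch.randn(1, 8, 128, 64, dtype=attention_dtype)

# ...but are coerced back to bf16 when the attention op consumes them.
out = torch.nn.functional.scaled_dot_product_attention(
    q, k_cache.to(attention_dtype), v_cache.to(attention_dtype))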

@sogartar
Contributor Author

Or better yet

python3 -m sharktank.examples.export_paged_llm_v1 --irpa-file=/sharedfile/llama3_8b_fp8.irpa  \
  --output-mlir=/sharedfile/32/fp8_32.mlir \
  --output-config=/sharedfile/32/config_32.json \
  --bs=1 --attention-kernel torch \
  --attention-dtype=bfloat16 \
  --kv-cache-dtype=bfloat16 \
  --activation-dtype=bfloat16 \
  --use-hf

Use this if you don't have the change iree-org/iree#20005.

@AmosLewis

python3 -m sharktank.examples.export_paged_llm_v1 --irpa-file=/sharedfile/llama3_8b_fp8.irpa  \
  --output-mlir=/sharedfile/32/fp8_32.mlir \
  --output-config=/sharedfile/32/config_32.json \
  --bs=1 --attention-kernel torch \
  --attention-dtype=bfloat16 \
  --kv-cache-dtype=float8_e4m3fnuz \
  --activation-dtype=bfloat16 \
  --use-hf

These flags work.

@AmosLewis

Would it have an impact on the flags we use in #907? I guess we should also set --attention-dtype=bfloat16 --kv-cache-dtype=float8_e4m3fnuz there.

@AmosLewis

AmosLewis commented Feb 24, 2025

  --attention-dtype=bfloat16 \
  --kv-cache-dtype=bfloat16 \

This export flag combination does not work when compiling with iree-compile. https://sharkpublic.blob.core.windows.net/sharkpublic/chi/llama/fp8_32_kv16.mlir
llama_fp8_kv16_compile_bug.txt

iree-compile /sharedfile/32/fp8_32.mlir \
  --iree-hip-target=gfx942 \
  -o=/sharedfile/32/fp8_32.vmfb \
  --iree-hal-target-device=hip \
  --iree-dispatch-creation-enable-aggressive-fusion=true \
  --iree-global-opt-propagate-transposes=true \
  --iree-opt-aggressively-propagate-transposes=true \
  --iree-opt-data-tiling=false \
  --iree-preprocessing-pass-pipeline='builtin.module(util.func(iree-preprocessing-generalize-linalg-matmul-experimental))' \
  --iree-hal-indirect-command-buffers=true \
  --iree-stream-resource-memory-model=discrete \
  --iree-hal-memoization=true \
  --iree-opt-strip-assertions

/sharedfile/32/fp8_32_kv16.mlir:9019:13: error: 'func.func' op failed on workgroup distribution verification
    %3349 = torch.aten.index_put %3347, %3348, %3343, %false_2832 : !torch.vtensor<[?,32,8,128],bf16>, !torch.list<optional<vtensor>>, !torch.vtensor<[?,32,8,128],bf16>, !torch.bool -> !torch.vtensor<[?,32,8,128],bf16>
            ^
/sharedfile/32/fp8_32_kv16.mlir:9019:13: note: see current operation:

/sharedfile/32/fp8_32_kv16.mlir:28291:14: error: failed to run translation of source executable to target executable for backend #hal.executable.target<"rocm", "rocm-hsaco-fb", {abi = "hip", iree.gpu.target = #iree_gpu.target<arch = "gfx942", features = "", wgp = <compute =  fp64|fp32|fp16|int64|int32|int16|int8, storage =  b64|b32|b16|b8, subgroup =  shuffle|arithmetic, dot =  dp4xi8toi32, mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>, <MFMA_F64_16x16x4_F64>, <MFMA_F32_16x16x16_BF16>, <MFMA_F32_32x32x8_BF16>, <MFMA_F32_16x16x32_F8E5M2FNUZ>, <MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>], subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647], max_load_instruction_bits = 128, simds_per_wgp = 4, vgpr_space_bits = 16384>>, ukernels = "none"}>
    %10763 = torch.aten.index_put %10761, %10762, %10757, %false_11018 : !torch.vtensor<[?,32,8,128],bf16>, !torch.list<optional<vtensor>>, !torch.vtensor<[?,32,8,128],bf16>, !torch.bool -> !torch.vtensor<[?,32,8,128],bf16>

@sogartar
Contributor Author

This change in particular should not introduce a problem. But #907 may have merge conflicts with #896.

@sogartar
Contributor Author

sogartar commented Feb 25, 2025

The failure is unrelated to this PR, as it also fails here.

@sogartar sogartar merged commit 7889447 into nod-ai:main Feb 25, 2025
35 of 36 checks passed
@AmosLewis

AmosLewis commented Feb 25, 2025

This fixes the dtype mismatch. We still need to specify the correct attention dtype. Now --attention-dtype actually means what it is meant to mean. Before, it was driving the KV cache dtype and the dtype used when creating the attention mask. You are now essentially trying to do attention in f8, which we never actually did before; it was done in bf16 even though you had --attention-dtype=float8_e4m3fnuz.

Try

python3 -m sharktank.examples.export_paged_llm_v1 --irpa-file=/sharedfile/llama3_8b_fp8.irpa  \
  --output-mlir=/sharedfile/32/fp8_32.mlir \
  --output-config=/sharedfile/32/config_32.json \
  --bs=1 --attention-kernel torch \
  --attention-dtype=bfloat16 \
  --kv-cache-dtype=float8_e4m3fnuz \
  --activation-dtype=bfloat16 \
  --use-hf

I feel confused when you say "You are essentially trying now to do attention in f8" but then suggest setting --attention-dtype=bfloat16. By "attention in f8", I assume you mean the KV cache in f8? Then why do we need --attention-dtype at all?
So much ambiguity around "attention".
