Investigate cases in which subgroup size 16 is noticeably slower #1371

Closed

victor-eds opened this issue Jun 17, 2024 · 5 comments

@victor-eds (Contributor) commented Jun 17, 2024

After running the huggingface benchmarks with subgroup size 16 (https://github.com/intel/intel-xpu-backend-for-triton/actions/runs/9499093865), we saw some cases in which subgroup size 16 showed worse performance:

  • huggingface amp_bf16 inference PLBartForCausalLM: 3.58x vs 2.36x speedups
  • huggingface bf16 inference BartForCausalLM: 2.52x vs 1.19x speedups
  • huggingface f32 training T5Small: 0.65x vs 0.50x speedups

Investigate and create follow-up issues if needed, or write a report in this issue.

Env:

TIMESTAMP=20240613123916
JOB_NAME=huggingface-inference-amp_bf16
GITHUB_RUN_ID=9499093865
GITHUB_RUN_NUMBER=102
GITHUB_RUN_ATTEMPT=1
PYTHON_VERSION=3.10
PYTORCH_REPO=Stonepia/pytorch
PYTORCH_COMMIT_ID=22ce6c6508d1d13b263d4c8b1fd6b98505983e92
IPEX_REPO=
IPEX_COMMIT_ID=
LLVM_REPO=llvm/llvm-project
LLVM_COMMIT_ID=765206e050453018e861637a08a4520f29238074
BENCHMARK_REPO=weishi-deng/benchmark.git
BENCHMARK_COMMIT_ID=e90564f719a8df7daf0cff4a245404c435b7693a
TRITON_REPO=intel/intel-xpu-backend-for-triton
TRITON_COMMIT_ID=b25bf436956b8fce31bdd8dba68a9aa3959917e6
TORCHVISION_COMMIT_ID=
TORCHTEXT_COMMIT_ID=
TORCHAUDIO_COMMIT_ID=
TRANSFORMERS_VERSION=4.27.4
TIMM_COMMIT_ID=b9d43c7dcac1fe05e851dd7be7187b108af593d2
LIBIGC1_VERSION=1.0.16510.19-881
LEVEL_ZERO_VERSION=1.3.29138.29-881
GPU_DEVICE=Intel(R) Data Center GPU Max 1550
AGAMA_VERSION=881
@sommerlukas (Contributor) commented:

So far, we haven't been able to reproduce this behavior on the current llvm-target branch. Will investigate further.

@sommerlukas (Contributor) commented:

Work on this ticket was blocked by #1647 for most of last week. We were able to obtain data from a CI run at the end of last week and will present a summary offline in the meeting.

@sommerlukas (Contributor) commented:

As discussed offline, varying performance between different CI runs (based on the same commit) is still a problem for this investigation.

We'll still try to identify consistent outliers in the performance comparison between the two sub-group sizes and then investigate what causes the performance difference.

@sommerlukas (Contributor) commented:

Comparing CI runs 4 and 5, three outlier models were identified for which sub-group size 16 consistently provided worse performance than sub-group size 32 across both runs:

  • AllenaiLongformerBase training with float16 and amp_fp16
  • XLNetLMHeadModel training with float16 and amp_fp16
  • BlenderbotSmallForCausalLM inference with amp_fp16

For XLNetLMHeadModel, neither sub-group size provides a speedup over PyTorch execution, so the model was excluded from further investigation.

AllenaiLongformerBase

Comparing device timing from unitrace for both SG-sizes shows two Triton kernels that are among the GPU kernels that take up the most time.

For SG-size 32:

                                   Kernel,        Calls,            Time (ns),     Time (%),         Average (ns),             Min (ns),             Max (ns)
"triton_poi_fused_index_add_new_zeros_13",          179,            480605600,     8.665852,              2684947,                28800,              2943360
"triton_poi_fused_index_add_new_zeros_25",          178,            443019040,     7.988125,              2488871,                 6080,              2830400

For SG-size 16:

                                   Kernel,        Calls,            Time (ns),     Time (%),         Average (ns),             Min (ns),             Max (ns)
"triton_poi_fused_index_add_new_zeros_13",          178,            631492960,    10.590560,              3547713,                 2720,              4099520
"triton_poi_fused_index_add_new_zeros_25",          178,            629485440,    10.556891,              3536435,                 6080,              4015200

Whereas these two kernels each take up roughly 8% of the overall execution time of the model with SG-size 32, they each take up roughly 10.6% with SG-size 16.

The average execution time of the kernels also increases from ~2.6 ms to ~3.5 ms, a roughly 1.3x slowdown.

The kernels are attached to this comment. Both are rather simple, but both use tl.atomic_add. @chengjunlu confirmed in an offline chat that the performance of atomic operations is a known issue, so the performance difference potentially stems from this operation.
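
For illustration, a minimal sketch of an index_add-style pointwise kernel built around tl.atomic_add is shown below. This is not one of the attached kernels; the kernel name, arguments, and block size are assumptions made purely for this example.

```python
import triton
import triton.language as tl


@triton.jit
def index_add_sketch(idx_ptr, src_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    # Each program instance handles BLOCK contiguous elements of the source.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements
    # Gather destination indices and source values for this block.
    idx = tl.load(idx_ptr + offsets, mask=mask, other=0)
    val = tl.load(src_ptr + offsets, mask=mask, other=0.0)
    # Several work-items may target the same output slot, so the
    # scatter-accumulate has to use an atomic add; this is the operation
    # whose performance is suspected to differ between sub-group sizes.
    tl.atomic_add(out_ptr + idx, val, mask=mask)
```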

BlenderbotSmallForCausalLM

Comparing device timing from unitrace for both SG-sizes shows one Triton kernel that is among the GPU kernels that take up the most time.

For SG-size 32:

                                   Kernel,        Calls,            Time (ns),     Time (%),         Average (ns),             Min (ns),             Max (ns)
"triton_red_fused__log_softmax__to_copy_view_8",           15,             42188320,     6.827563,              2812554,              2708160,              2900960

For SG-size 16:

                                   Kernel,        Calls,            Time (ns),     Time (%),         Average (ns),             Min (ns),             Max (ns)
"triton_red_fused__log_softmax__to_copy_view_8",           15,             60684640,     9.484119,              4045642,              3979680,              4196800

Whereas the kernel takes up less than 7% of the overall execution time of the model with SG-size 32, it takes up 9.5% with SG-size 16.

The average execution time of the kernel also increases from 2.8ms to 4ms, a 1.43x slowdown.

The kernel is attached to this comment and is also rather simple; however, it uses a reduction. From previous investigations by @victor-eds, it is known that the current pattern used for reductions in the XPU backend is less efficient for SG-size 16, as it generates significantly more assembly instructions, so the performance difference most likely stems from this known issue.
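
For illustration, a minimal sketch of a row-wise log_softmax-style kernel that relies on block-wide reductions (tl.max and tl.sum) is shown below. It is not the attached kernel; names, arguments, and the one-block-per-row layout are simplifications made for this example.

```python
import triton
import triton.language as tl


@triton.jit
def log_softmax_rows_sketch(in_ptr, out_ptr, n_cols, BLOCK: tl.constexpr):
    # Each program instance handles one row; BLOCK must be >= n_cols.
    row = tl.program_id(axis=0)
    cols = tl.arange(0, BLOCK)
    mask = cols < n_cols
    x = tl.load(in_ptr + row * n_cols + cols, mask=mask, other=float("-inf"))
    # Two block-wide reductions: a max for numerical stability and a sum of
    # exponentials. These reductions are where the current XPU lowering for
    # SG-size 16 produces significantly more assembly instructions.
    row_max = tl.max(x, axis=0)
    shifted = x - row_max
    sum_exp = tl.sum(tl.exp(shifted), axis=0)
    tl.store(out_ptr + row * n_cols + cols, shifted - tl.log(sum_exp), mask=mask)
```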

add_new_zeros_25.txt
add_new_zeros_13.txt
softmax_to_copy_view_8.txt

@sommerlukas (Contributor) commented Aug 13, 2024

I filed #1867 and #1868 as follow-ups to investigate the root causes of the two outliers. That investigation currently has lower priority.

whitneywhtsang pushed a commit that referenced this issue Aug 15, 2024
Add a script to easily compare two runs of the "E2E performance" CI
workflow.

The script compares the speedup over PyTorch eager yielded by the two
different CI runs, prints an evaluation, and can also visualize the data
as a boxplot.

For more details on the usage of the script, see the accompanying
README.

This script was written and used for #1370 and #1371.

Closes #1848.

---------

Signed-off-by: Lukas Sommer <[email protected]>
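
For reference, a rough sketch of the kind of comparison the script performs; the actual script and its usage are documented in the accompanying README, and the file layout and column names below ("model", "speedup") are assumptions for illustration only.

```python
import pandas as pd
import matplotlib.pyplot as plt

# Speedups over PyTorch eager reported by two "E2E performance" CI runs.
run_a = pd.read_csv("run_a.csv")
run_b = pd.read_csv("run_b.csv")

# Line up the two runs per model and compute the relative change.
merged = run_a.merge(run_b, on="model", suffixes=("_a", "_b"))
merged["ratio"] = merged["speedup_b"] / merged["speedup_a"]

# Print the models whose relative performance changed the most.
print(merged.sort_values("ratio")[["model", "speedup_a", "speedup_b", "ratio"]])

# Visualize the distribution of speedups per run as a boxplot.
merged[["speedup_a", "speedup_b"]].plot.box()
plt.ylabel("Speedup over PyTorch eager")
plt.show()
```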