
Question about the performance of GroupedLinear #1499

Open · XLzed opened this issue Feb 20, 2025 · 1 comment

XLzed commented Feb 20, 2025

Issue

When testing the performance of DeepSeek-V2-Lite with Megatron and TransformerEngine, I encountered an issue where GroupedLinear exhibits unusually long durations. The TEGroupedLinear forward typically takes about 1 ms in the nsys timeline, but there are anomalous events that exceed 200 ms. What could be causing this?

Environment

  • megatron-core r0.10.0
  • TransformerEngine 1.13.0+e5edd6c
  • image: nvcr.io/nvidia/pytorch:24.07-py3 + cudnn-9.5.1
  • 8xH800

Duration of GroupedLinear events

I cannot share the timeline itself for certain reasons. The following table lists the durations of abnormal and normal events extracted from the nsys timeline. Why is there such a large difference in duration between nvte_multi_stream_cublas_gemm and TERowParallelGroupedLinear, and why does the start of the abnormal nvte_multi_stream_cublas_gemm event lag behind the start of TERowParallelGroupedLinear by about 200 ms?

Also, if I directly micro-benchmark TEGroupedLinear with input tensors of the same shapes, the timings return to normal (a sketch of such a micro-benchmark is included after the table below). Is the training workflow affecting the execution efficiency of the kernel?

Name                                 | Start     | Duration    | TID
#TEGroupLinear forward               | 3.67097 s | 206.707 ms  | 179103
##TERowParallelGroupedLinear forward | 3.67099 s | 206.660 ms  | 179103
###nvte_multi_stream_cublas_gemm     | 3.87704 s | 387.909 μs  | 179103
#TEGroupLinear forward               | 1.4103 s  | 3.373 ms    | 179103
##TERowParallelGroupedLinear forward | 1.41032 s | 3.327 ms    | 179103
###nvte_multi_stream_cublas_gemm     | 1.41077 s | 1.008 ms    | 179103
#TEGroupLinear forward               | 2.58523 s | 3.103 ms    | 179103
##TERowParallelGroupedLinear forward | 2.58525 s | 3.055 ms    | 179103
###nvte_multi_stream_cublas_gemm     | 2.58579 s | 1.128 ms    | 179103
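
For concreteness, here is a minimal sketch of the kind of direct micro-benchmark mentioned above. It assumes the transformer_engine.pytorch.GroupedLinear API as of TE 1.13 (constructor taking num_gemms, in_features, out_features and a forward taking the input plus m_splits); exact keyword arguments may differ between versions. The shapes mirror the DeepSeek-V2-Lite MoE FFN used in the script further down.

import torch
import transformer_engine.pytorch as te

# Same shapes as the Megatron test below: 64 experts, hidden size 2048,
# MoE FFN size 1408, 4096 tokens routed to top-6 experts, split evenly.
num_gemms, hidden, ffn = 64, 2048, 1408
tokens = 4096 * 6
m_splits = [tokens // num_gemms] * num_gemms

layer = te.GroupedLinear(num_gemms, hidden, ffn, bias=False,
                         params_dtype=torch.bfloat16, device='cuda')
inp = torch.randn(tokens, hidden, dtype=torch.bfloat16, device='cuda')

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
for _ in range(10):                        # warm-up
    layer(inp, m_splits)
torch.cuda.synchronize()
start.record()
for _ in range(50):
    layer(inp, m_splits)
end.record()
torch.cuda.synchronize()
print(f"avg forward: {start.elapsed_time(end) / 50:.3f} ms")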

Optimization for Hopper?

Also, is there a plan to optimize GroupedLinear for the Hopper architecture? With the DeepSeek-V2 parameters, the TFLOPS on H800 do not improve significantly over A800, and overall performance is quite poor. The test results and code are as follows:

# H800
Average execution time: 0.0011188620805740357 s, tflops: 253.35369430928066
Average execution time: 0.001063387517929077 s, tflops: 133.2852966376957

# A800
Average execution time: 0.0018983731222152712 s, tflops: 149.32145752527958
Average execution time: 0.0013353574371337891 s, tflops: 106.13931283613297

from megatron.core.extensions.transformer_engine import TEColumnParallelGroupedLinear
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.parallel_state import set_tensor_model_parallel_world_size
import torch
from typing import Callable

def benchmark(benchmark_func: Callable, warmup_times = 10, benchmark_times = 50):
    warmup_elapsed_time_list = []
    benchmark_elapsed_time_list = []
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    for _ in range(warmup_times):
        start_event.record()
        benchmark_func()
        end_event.record()
        torch.cuda.synchronize()
        warmup_elapsed_time_list.append(start_event.elapsed_time(end_event))

    for _ in range(benchmark_times):
        start_event.record()
        benchmark_func()        
        end_event.record()
        torch.cuda.synchronize()
        benchmark_elapsed_time_list.append(start_event.elapsed_time(end_event))
    
    # print(f"warmup_elapsed_time: {warmup_elapsed_time_list}, benchmark_elapsed_time_list: {benchmark_elapsed_time_list}")
    avg_secs = sum(benchmark_elapsed_time_list) / benchmark_times / 1e3
    return avg_secs

def test_dsv2_lite_ep1():
    num_local_experts = 64
    hidden_size = 2048
    moe_ffn_hidden_size = 1408
    topk = 6
    seqlen = 4096
    # m_splits = [8, 83, 945, 0, 162, 4, 3, 0, 251, 510, 24, 140, 37, 0, 33, 10, 1, 0, 5, 0, 115, 0, 1, 1, 1, 188, 43, 1, 7, 0, 12, 0, 324, 5, 88, 0, 0, 58, 558, 219, 1296, 1155, 2, 1102, 6, 0, 115, 0, 106, 0, 10, 0, 698, 2, 594, 221, 351, 0, 1, 2, 2040, 0, 7, 743]
    num_even_tokens = seqlen * topk // num_local_experts
    m_splits = [num_even_tokens for _ in range(num_local_experts)]

    config = TransformerConfig(num_attention_heads=1, num_layers=1)
    config.params_dtype = torch.bfloat16
    config.use_cpu_initialization = False
    config.add_bias_linear = False
    config.gradient_accumulation_fusion = True
    
    # hack parallel states
    set_tensor_model_parallel_world_size(1)
    
    # column up linear
    linear_fc1 = TEColumnParallelGroupedLinear(
        num_gemms=num_local_experts,
        input_size=hidden_size,
        output_size=moe_ffn_hidden_size*2,
        config=config,
        init_method=config.init_method,
        bias=config.add_bias_linear,
        skip_bias_add=True,
        is_expert=True,
        tp_comm_buffer_name='fc1',
    )

    # down linear (also uses the column-parallel wrapper; with TP size 1 the
    # GEMM shapes match the row-parallel case)
    linear_fc2 = TEColumnParallelGroupedLinear(
        num_gemms=num_local_experts,
        input_size=moe_ffn_hidden_size,
        output_size=hidden_size,
        config=config,
        init_method=config.init_method,
        bias=config.add_bias_linear,
        skip_bias_add=True,
        is_expert=True,
        tp_comm_buffer_name='fc2',
    )

    up_inputs = torch.randn((topk*seqlen, hidden_size), dtype=torch.bfloat16, device='cuda')
    down_inputs = torch.randn((topk*seqlen, moe_ffn_hidden_size), dtype=torch.bfloat16, device='cuda')

    def up_linear():
        linear_fc1(up_inputs, m_splits)

    def down_linear():
        linear_fc2(down_inputs, m_splits)

    avg_secs = benchmark(up_linear)
    tflops = 2 * seqlen * topk * hidden_size * 2 * moe_ffn_hidden_size / avg_secs / 1e12
    print(f"Average execution time: {avg_secs} s, tflops: {tflops}")

    avg_secs = benchmark(down_linear)
    tflops = 2 * seqlen * topk * moe_ffn_hidden_size * hidden_size / avg_secs / 1e12
    print(f"Average execution time: {avg_secs} s, tflops: {tflops}")

test_dsv2_lite_ep1()

yaox12 (Collaborator) commented Feb 24, 2025

Can I summarize your questions into the following two?

  1. Q: Why is there such a large difference in duration between nvte_multi_stream_cublas_gemm and TERowParallelGroupedLinear?
    A: It's the CPU overhead of PyTorch ops, such as torch.split() and the torch.empty() calls (2 × num_gemms of them) under fused_multi_cast_transpose. You can capture them in Nsys by wrapping your code in the torch.autograd.profiler.emit_nvtx(enabled=True) context during profiling (see the sketch after this list). It's not trivial to eliminate these overheads.
  2. Q: Some abnormal iterations consume much more time than usual, while the micro-benchmark shows no problem.
    A: I have no idea about this one. Maybe you can enable NVTX for torch ops using the context mentioned above and see what it is actually doing there.
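
For reference, a minimal sketch of how the suggested context could be wrapped around the profiled steps during an nsys capture. emit_nvtx and the torch.cuda.nvtx range helpers are standard PyTorch APIs; train_step() and num_profiled_steps are placeholders for the actual training code.

import torch

# Launch under Nsight Systems, for example:
#   nsys profile -t cuda,nvtx -o grouped_linear_trace python <training script>
# emit_nvtx() makes every PyTorch op emit an NVTX range, so the CPU-side
# torch.split()/torch.empty() calls inside GroupedLinear become visible in the
# timeline next to nvte_multi_stream_cublas_gemm.
with torch.autograd.profiler.emit_nvtx(enabled=True):
    for step in range(num_profiled_steps):     # placeholder loop bound
        torch.cuda.nvtx.range_push(f"step {step}")
        train_step()                           # placeholder for the real training step
        torch.cuda.nvtx.range_pop()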
