[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Feb 22, 2025
1 parent 0341de7 commit c699139
Showing 14 changed files with 540 additions and 450 deletions.
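The Python changes below are consistent with black-style reformatting (rewrapping calls that exceed the line length, normalizing "#comment" to "# comment") and the C++ changes with clang-format-style reflow. The repository's hook configuration is not part of this diff, so the specific formatter and the 100-character line length used in the following sketch are assumptions, not facts taken from the commit:

# Minimal sketch (not part of this commit): reproducing the kind of rewrap seen in the
# test file below, assuming the Python hook is black with a 100-character line length.
import black

src = (
    "model_configs_infer = {\n"
    '    "infer_0": ModelConfig(4, 16, 16, 64, 64, 64, 0.0, "no_mask", "no_bias",'
    " total_requests=8, max_ctx_len=16),\n"
    '    #"infer_1": ModelConfig(2, 16, 4, 64, 66, 66, 0.0, "no_mask", "no_bias", total_requests=6),\n'
    "}\n"
)

# black wraps the over-long ModelConfig call across lines and normalizes the comment
# spacing, matching the edits applied to tests/pytorch/fused_attn/test_paged_attn.py.
formatted = black.format_str(src, mode=black.Mode(line_length=100))
print(formatted)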
152 changes: 98 additions & 54 deletions tests/pytorch/fused_attn/test_paged_attn.py
@@ -39,8 +39,10 @@

model_configs_infer = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"infer_0": ModelConfig(4, 16, 16, 64, 64, 64, 0.0, "no_mask", "no_bias", total_requests=8, max_ctx_len=16),
#"infer_1": ModelConfig(2, 16, 4, 64, 66, 66, 0.0, "no_mask", "no_bias", total_requests=6),
"infer_0": ModelConfig(
4, 16, 16, 64, 64, 64, 0.0, "no_mask", "no_bias", total_requests=8, max_ctx_len=16
),
# "infer_1": ModelConfig(2, 16, 4, 64, 66, 66, 0.0, "no_mask", "no_bias", total_requests=6),
}

qkv_formats = ["bshd", "sbhd", "thd"]
@@ -49,9 +51,11 @@
def to_pretty_string(x: torch.Tensor):
return "[" + ",".join(["{:>3s}".format(str(i)) for i in x.tolist()]) + "]"


def round_up(a: int, b: int):
return b * math.ceil(a / b)


class Simulation:
def __init__(
self,
@@ -71,32 +75,32 @@ def __init__(
self.max_gen_len = max_seq_len - self.max_ctx_len

# simulate sequence ids in monotonically increasing fashion
self.seq_ids = torch.range(0, total_requests-1, dtype=torch.int32, device="cpu")
self.seq_ids = torch.range(0, total_requests - 1, dtype=torch.int32, device="cpu")

# simulate context lengths in Uniform distribution
self.context_lens = torch.randint(
1, self.max_ctx_len, [total_requests], dtype=torch.int32, device="cpu"
)
#self.context_lens = 4 * torch.ones(total_requests, dtype=torch.int32, device="cpu")
# self.context_lens = 4 * torch.ones(total_requests, dtype=torch.int32, device="cpu")

# simulate gen lengths in Exponential distribution
gen_dist = Exponential(1 / self.max_gen_len)
gen_lens = gen_dist.sample((total_requests,))
gen_lens = torch.where(gen_lens > self.max_gen_len, self.max_gen_len, gen_lens).to(
dtype=torch.int32, device="cpu"
)
self.gen_lens = torch.where(gen_lens == 0, 1, gen_lens).to(
dtype=torch.int32, device="cpu"
)
#self.gen_lens = 4 * torch.ones(total_requests, dtype=torch.int32, device="cpu")
self.gen_lens = torch.where(gen_lens == 0, 1, gen_lens).to(dtype=torch.int32, device="cpu")
# self.gen_lens = 4 * torch.ones(total_requests, dtype=torch.int32, device="cpu")

# simulate arrival times in Poisson distribution
if poisson_rate is None:
self.poisson_rate = torch.randint(1, max_batch_size, [1]).item()
interval_dist = Exponential(self.poisson_rate)
arrival_intervals = interval_dist.sample((total_requests,))
self.arrival_times = torch.cumsum(arrival_intervals, dim=0).to(dtype=torch.int32, device="cpu")
#self.arrival_times = torch.zeros(total_requests, dtype=torch.int32, device="cpu")
self.arrival_times = torch.cumsum(arrival_intervals, dim=0).to(
dtype=torch.int32, device="cpu"
)
# self.arrival_times = torch.zeros(total_requests, dtype=torch.int32, device="cpu")
self.last_arrival = self.arrival_times.max().item()

# initialize tensors
@@ -144,10 +148,10 @@ def print_step(self, logger):
def print_summary(self, logger):
logger.info("Summary:")
logger.info(" {:<18s}: {}".format("total steps taken", self.t))
logger.info(" {:<18s}: {}".format("arrival_times", to_pretty_string(self.arrival_times)))
logger.info(" {:<18s}: {}".format("serving_times", to_pretty_string(self.serving_times)))
logger.info(" {:<18s}: {}".format("total_gen_lens", to_pretty_string(self.gen_lens)))
logger.info(" {:<18s}: {}".format("complete_times", to_pretty_string(self.complete_times)))
logger.info(" {:<18s}: {}".format("arrival_times", to_pretty_string(self.arrival_times)))
logger.info(" {:<18s}: {}".format("serving_times", to_pretty_string(self.serving_times)))
logger.info(" {:<18s}: {}".format("total_gen_lens", to_pretty_string(self.gen_lens)))
logger.info(" {:<18s}: {}".format("complete_times", to_pretty_string(self.complete_times)))

def add_new_seqs(self, new_seq_ids):
# get ctx_lens for new seqs
@@ -194,11 +198,11 @@ def step(self, dynamic_fill: bool = True):
self.t_total_lens = self.t_ctx_lens + self.t_gen_lens


@pytest.mark.parametrize("dtype", [torch.float16])#param_types)
@pytest.mark.parametrize("dtype", [torch.float16]) # param_types)
@pytest.mark.parametrize("model", model_configs_infer.keys())
@pytest.mark.parametrize("qkv_format", ["thd"])#qkv_formats)
@pytest.mark.parametrize("qkv_format", ["thd"]) # qkv_formats)
@pytest.mark.parametrize("is_paged", [False, True])
@pytest.mark.parametrize("backend", ["FusedAttention"])#, "FlashAttention", "UnfusedAttention"])
@pytest.mark.parametrize("backend", ["FusedAttention"]) # , "FlashAttention", "UnfusedAttention"])
@pytest.mark.parametrize("is_cuda_graph", [False, True])
def test_paged_attn(dtype, model, qkv_format, is_paged, backend, is_cuda_graph):
reset_rng_states()
@@ -253,7 +257,7 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, is_cuda_graph):
# generate data for all requests
assert (
config.max_seqlen_q == config.max_seqlen_kv
), "This test only simulates max_seqlen_q = max_seqlen_kv."
), "This test only simulates max_seqlen_q = max_seqlen_kv."
q = 0.1 * torch.randn(
(config.total_requests, config.max_seqlen_kv, config.num_heads, config.head_dim_qk),
dtype=dtype,
@@ -297,7 +301,7 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, is_cuda_graph):
max_ctx_len=config.max_ctx_len,
max_batch_size=max_batch_size,
poisson_rate=2,
)
)
sim.print_setup(logger)

# initialize inference_params
@@ -322,41 +326,45 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, is_cuda_graph):
if is_cuda_graph:
t_seq_ids = torch.range(0, max_batch_size, dtype=torch.int32, device="cpu")
step_lens = config.max_ctx_len * torch.ones(max_batch_size, dtype=torch.int32, device="cpu")
step_dict = OrderedDict(
zip(t_seq_ids.tolist(), step_lens.tolist())
)
step_dict = OrderedDict(zip(t_seq_ids.tolist(), step_lens.tolist()))
inference_params.pre_step(step_dict)

if qkv_format == "bshd":
shape = [ config.batch_size, config.max_ctx_len]
shape = [config.batch_size, config.max_ctx_len]
if qkv_format == "sbhd":
shape = [ config.max_ctx_len, config.batch_size]
shape = [config.max_ctx_len, config.batch_size]
if qkv_format == "thd":
shape = [ config.batch_size * config.max_ctx_len]
shape = [config.batch_size * config.max_ctx_len]

def gen_data():
return [torch.ones(
*shape,
config.num_heads,
config.head_dim_qk,
device="cuda",
dtype=dtype,
) for _ in range(3)]
return [
torch.ones(
*shape,
config.num_heads,
config.head_dim_qk,
device="cuda",
dtype=dtype,
)
for _ in range(3)
]

sample_kwargs = {}
sample_kwargs["cu_seqlens_q"] = torch.linspace( 0,
sample_kwargs["cu_seqlens_q"] = torch.linspace(
0,
config.batch_size * config.max_ctx_len,
steps=config.batch_size+1,
steps=config.batch_size + 1,
device="cuda",
dtype=torch.int32,
)
sample_kwargs["cu_seqlens_kv"] = torch.linspace( 0,
sample_kwargs["cu_seqlens_kv"] = torch.linspace(
0,
config.batch_size * config.max_ctx_len,
steps=config.batch_size+1,
steps=config.batch_size + 1,
device="cuda",
dtype=torch.int32,
)
sample_kwargs["inference_params"] = inference_params
sample_kwargs["attn_mask_type"] = "padding" #_causal"
sample_kwargs["attn_mask_type"] = "padding" # _causal"
sample_kwargs["max_seqlen_q"] = config.max_ctx_len
sample_kwargs["max_seqlen_kv"] = config.max_seqlen_kv
sample_kwargs["qkv_format"] = qkv_format
@@ -386,7 +394,7 @@ def gen_data():
max_tokens = config.batch_size * config.max_ctx_len
while True:
# prepare batch for the current step
dynamic_fill = True #inference_params.is_paged
dynamic_fill = True # inference_params.is_paged
sim.step(dynamic_fill=dynamic_fill)
sim.print_step(logger)

@@ -427,9 +435,47 @@ def gen_data():
dim=0,
)
if is_cuda_graph:
incremental_q = torch.cat([incremental_q, torch.zeros([max_tokens - sum(sim.step_lens), config.num_heads, config.head_dim_qk], dtype=dtype, device=incremental_q.device)], dim=0)
incremental_k = torch.cat([incremental_k, torch.zeros([max_tokens - sum(sim.step_lens), config.num_gqa_groups, config.head_dim_v], dtype=dtype, device=incremental_k.device)], dim=0)
incremental_v = torch.cat([incremental_v, torch.zeros([max_tokens - sum(sim.step_lens), config.num_gqa_groups, config.head_dim_v], dtype=dtype, device=incremental_v.device)], dim=0)
incremental_q = torch.cat(
[
incremental_q,
torch.zeros(
[max_tokens - sum(sim.step_lens), config.num_heads, config.head_dim_qk],
dtype=dtype,
device=incremental_q.device,
),
],
dim=0,
)
incremental_k = torch.cat(
[
incremental_k,
torch.zeros(
[
max_tokens - sum(sim.step_lens),
config.num_gqa_groups,
config.head_dim_v,
],
dtype=dtype,
device=incremental_k.device,
),
],
dim=0,
)
incremental_v = torch.cat(
[
incremental_v,
torch.zeros(
[
max_tokens - sum(sim.step_lens),
config.num_gqa_groups,
config.head_dim_v,
],
dtype=dtype,
device=incremental_v.device,
),
],
dim=0,
)
else:
incremental_q = torch.zeros(
batch_size,
@@ -472,9 +518,7 @@ def gen_data():
cu_seqlens_q[1 : sim.t_batch_size + 1] = torch.cumsum(sim.step_lens, dim=0)
cu_seqlens_kv = torch.zeros(batch_size + 1, dtype=torch.int32, device="cuda")
cu_seqlens_kv[1 : sim.t_batch_size + 1] = torch.cumsum(sim.t_total_lens, dim=0)
step_dict = OrderedDict(
zip(sim.t_seq_ids.tolist(), sim.step_lens.tolist())
)
step_dict = OrderedDict(zip(sim.t_seq_ids.tolist(), sim.step_lens.tolist()))
inference_params.pre_step(step_dict)
if inference_params.is_paged:
inference_params.cache_manager.print_cache()
@@ -485,7 +529,7 @@ def gen_data():
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
inference_params=inference_params,
attn_mask_type="padding", #_causal",
attn_mask_type="padding", # _causal",
max_seqlen_q=max_seqlen_q,
max_seqlen_kv=config.max_seqlen_kv,
qkv_format=qkv_format,
@@ -508,29 +552,29 @@ def gen_data():
token_index = -1 if inference_params.is_output_right_aligned else sim.step_lens[i] - 1
if qkv_format == "bshd":
torch.testing.assert_close(
#full_output[seq, sim.t_total_lens[i] - sim.step_lens[i]:sim.t_total_lens[i] - 1, :],
#line_output[:sim.step_lens[i] - 1, i, :],
# full_output[seq, sim.t_total_lens[i] - sim.step_lens[i]:sim.t_total_lens[i] - 1, :],
# line_output[:sim.step_lens[i] - 1, i, :],
full_output[seq, sim.t_total_lens[i] - 1, :],
line_output[i, token_index, :],
atol=tols[dtype],
rtol=tols[dtype],
)
if qkv_format == "sbhd":
torch.testing.assert_close(
#full_output[seq, sim.t_total_lens[i] - sim.step_lens[i]:sim.t_total_lens[i] - 1, :],
#line_output[:sim.step_lens[i] - 1, i, :],
# full_output[seq, sim.t_total_lens[i] - sim.step_lens[i]:sim.t_total_lens[i] - 1, :],
# line_output[:sim.step_lens[i] - 1, i, :],
full_output[seq, sim.t_total_lens[i] - 1, :],
line_output[token_index, i, :],
atol=tols[dtype],
rtol=tols[dtype],
)
if qkv_format == "thd":
#print('i ', i, seq, cu_seqlens_q)
#print(full_output[seq, sim.t_total_lens[i] - 1, :4])
#print(line_output[cu_seqlens_q[i + 1] - 1, :4])
# print('i ', i, seq, cu_seqlens_q)
# print(full_output[seq, sim.t_total_lens[i] - 1, :4])
# print(line_output[cu_seqlens_q[i + 1] - 1, :4])
torch.testing.assert_close(
#full_output[seq, sim.t_total_lens[i] - sim.step_lens[i]:sim.t_total_lens[i] - 1, :],
#line_output[cu_seqlens_q[i]:cu_seqlens_q[i + 1] - 1, :],
# full_output[seq, sim.t_total_lens[i] - sim.step_lens[i]:sim.t_total_lens[i] - 1, :],
# line_output[cu_seqlens_q[i]:cu_seqlens_q[i + 1] - 1, :],
full_output[seq, sim.t_total_lens[i] - 1, :],
line_output[cu_seqlens_q[i + 1] - 1, :],
atol=tols[dtype],
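For orientation, the Simulation class touched in this file builds its request arrival schedule from exponential inter-arrival gaps accumulated into a Poisson process. A minimal standalone sketch of that construction (illustrative values only, assuming nothing beyond torch; not taken from the repository):

# Standalone sketch of the arrival-time construction in Simulation.__init__ above:
# exponential inter-arrival gaps with rate `poisson_rate`, cumulatively summed and
# truncated to integer simulation steps.
import torch
from torch.distributions import Exponential

total_requests = 8
poisson_rate = 2.0  # expected number of arrivals per simulation step

interval_dist = Exponential(poisson_rate)                    # gaps ~ Exp(rate)
arrival_intervals = interval_dist.sample((total_requests,))  # one gap per request
arrival_times = torch.cumsum(arrival_intervals, dim=0).to(dtype=torch.int32)

print(arrival_times)  # non-decreasing step indices at which each request arrives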
33 changes: 17 additions & 16 deletions transformer_engine/common/fused_attn/fused_attn.cpp
@@ -286,8 +286,10 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
(qkv_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90 &&
((cudnn_runtime_version >= 90100 && num_attn_heads == num_gqa_groups) ||
cudnn_runtime_version >= 90600)) ||
((q_format == NVTE_QKV_Format::NVTE_SBHD || q_format == NVTE_QKV_Format::NVTE_BSHD || (q_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90) ||
kv_format == NVTE_QKV_Format::NVTE_SBHD || kv_format == NVTE_QKV_Format::NVTE_BSHD || (kv_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90)) &&
((q_format == NVTE_QKV_Format::NVTE_SBHD || q_format == NVTE_QKV_Format::NVTE_BSHD ||
(q_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90) ||
kv_format == NVTE_QKV_Format::NVTE_SBHD || kv_format == NVTE_QKV_Format::NVTE_BSHD ||
(kv_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90)) &&
cudnn_runtime_version >= 90700)) &&
// sliding window
// pre-9.2: full attn, causal
@@ -538,16 +540,15 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
}
}
// NVTE fused attention FWD with packed KV
void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias,
NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
const NVTETensor cu_seqlens_q_padded,
const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state,
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
int64_t window_size_left, int64_t window_size_right,
NVTETensor workspace, cudaStream_t stream) {
void nvte_fused_attn_fwd_kvpacked(
const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, NVTETensor S, NVTETensor O,
NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state,
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
int64_t window_size_left, int64_t window_size_right, NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_fwd_kvpacked);
using namespace transformer_engine;
const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor *>(cu_seqlens_q);
@@ -637,10 +638,10 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const
fused_attn_arbitrary_seqlen_fwd_kvpacked(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, num_pages_k, num_pages_v,
page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, attn_scale,
dropout, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, input_Q,
input_KV, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv,
input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_page_table_k, input_page_table_v, input_rng_state, wkspace, stream,
handle);
dropout, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right,
input_Q, input_KV, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q,
input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded,
input_page_table_k, input_page_table_v, input_rng_state, wkspace, stream, handle);
#else
NVTE_ERROR(
"cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
