# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
-from typing import Callable, Tuple, Union
+from typing import Callable, Tuple, Union, List
import math
import torch
import pytest
from transformer_engine.pytorch.attention.rope import (
    RotaryPositionEmbedding,
    apply_rotary_pos_emb,
+    apply_fused_qkv_rotary_pos_emb,
)


# Gradient is a broadcasted scalar
-def _overlapping_grad(output: torch.Tensor) -> torch.Tensor:
-    return output.sum() * 2
+def _overlapping_grad(output: Union[List[torch.Tensor], torch.Tensor]) -> torch.Tensor:
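+    """Scalar loss whose gradient is a broadcast constant; accepts a tensor or a list of tensors."""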
+    if isinstance(output, List):
+        return sum(t.sum() * 2 for t in output)
+    else:
+        return output.sum() * 2


# Gradient is a full tensor
-def _non_overlapping_grad(output: torch.Tensor) -> torch.Tensor:
-    t = torch.ones_like(output)
-    return torch.sum(output * t)
+def _non_overlapping_grad(output: Union[List[torch.Tensor], torch.Tensor]) -> torch.Tensor:
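+    """Loss whose gradient is a full ones tensor; accepts a tensor or a list of tensors."""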
+    if isinstance(output, List):
+        return sum(torch.sum(t * torch.ones_like(t)) for t in output)
+    else:
+        t = torch.ones_like(output)
+        return torch.sum(output * t)


@pytest.mark.parametrize("start_positions", [True, False])
@@ -238,3 +245,131 @@ def test_fused_rope_thd(
            torch.testing.assert_close(grad_fused, grad_unfused)

        assert output_fused.is_contiguous()
+
+
+@pytest.mark.parametrize("start_positions", [True, False])
+@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
+@pytest.mark.parametrize("seq_length", [2, 8, 2048, 4096])
+@pytest.mark.parametrize("hidden_size", [64, 128, 256])
+@pytest.mark.parametrize("rotary_percent", [0.5, 1.0])
+@pytest.mark.parametrize("margin", [0, 10])
+@pytest.mark.parametrize("tensor_format", ["sbhd", "bshd"])
+@pytest.mark.parametrize("loss_func", [_overlapping_grad, _non_overlapping_grad])
+@pytest.mark.parametrize("cp_size", [1, 2])
+@pytest.mark.parametrize("interleaved", [True, False])
+def test_fused_qkv_rope(
+    dtype: torch.dtype,
+    seq_length: int,
+    hidden_size: int,
+    rotary_percent: float,
+    margin: int,
+    tensor_format: str,
+    loss_func: Callable,
+    cp_size: int,
+    interleaved: bool,
+    start_positions: bool,
+) -> None:
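+    """Compare apply_fused_qkv_rotary_pos_emb on a packed QKV tensor against a reference
+    path that splits the tensor and applies apply_rotary_pos_emb to Q and K separately,
+    checking forward outputs and (when start_positions is None) the input gradients.
+    """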
+    if margin == 0 and start_positions == True:
+        # This makes sure that the `start_positions` offsets being applied
+        # stay within the maximum length of the rope embeddings.
+        pytest.skip("Skipping test with margin=0 and start_positions=True")
+
+    if start_positions == True and cp_size > 1:
+        # `start_positions` is only supported for `cp_size=1` and inference.
+        pytest.skip("Skipping test with cp_size>1 and start_positions=True")
+
+    if seq_length - margin < 0:
+        pytest.skip("Skipping test with seq_length - margin < 0")
+
+    device = torch.device("cuda:0")
+    batch_size, head_num = 2, 64
+
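+    # Packed QKV input: the last dim holds 6 * hidden_size channels per head, split below
+    # as [4 * hidden_size (Q), hidden_size (K), hidden_size (V)], i.e. a GQA-style layout
+    # with 4 query heads per KV head.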
+    t = torch.rand(
+        (seq_length - margin, batch_size, head_num, hidden_size * 6),
+        dtype=dtype,
+        device=device,
+    )
+
+    # Get arbitrary offsets to be used with RoPE for all the sequences
+    start_positions = (
+        torch.randint(0, margin, (batch_size,), dtype=torch.int32, device=device)
+        if start_positions
+        else None
+    )
+
+    if tensor_format == "bshd":
+        t = t.transpose(0, 1).contiguous()
+    t.requires_grad = True
+
+    rotary_pos_emb_q = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved)
+    emb_q = rotary_pos_emb_q(seq_length * cp_size)
+    rotary_pos_emb_k = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved)
+    emb_k = rotary_pos_emb_k(seq_length * cp_size)
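+    # Q and K embeddings are generated for the full context-parallel sequence length
+    # (seq_length * cp_size); the per-rank calls below pass cp_size/cp_rank so the
+    # kernels can select the portion belonging to each rank.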
+
+    for cp_rank in range(cp_size):
+        # unfused
+        # The fused kernel computes in float32 internally, so we force the unfused func to use float32
+        # for more accurate comparison
+
+        t_clone = t.clone()
+        (query, key, value) = torch.split(
+            t_clone, [hidden_size * 4, hidden_size, hidden_size], dim=3
+        )
+        query = query.reshape(query.shape[0], query.shape[1], head_num * 4, hidden_size)
+
+        query_unfused = apply_rotary_pos_emb(
+            query,
+            emb_q,
+            tensor_format=tensor_format,
+            start_positions=start_positions,
+            interleaved=interleaved,
+            fused=True,
+            cp_size=cp_size,
+            cp_rank=cp_rank,
+        ).to(dtype)
+
+        key_unfused = apply_rotary_pos_emb(
+            key,
+            emb_k,
+            tensor_format=tensor_format,
+            start_positions=start_positions,
+            interleaved=interleaved,
+            fused=True,
+            cp_size=cp_size,
+            cp_rank=cp_rank,
+        ).to(dtype)
+
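+        # V is passed through untouched in the reference path; RoPE rotates only Q and K.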
+        value_unfused = value
+        loss_unfused = loss_func([query_unfused, key_unfused, value_unfused])
+
+        if not isinstance(start_positions, torch.Tensor):
+            loss_unfused.backward()
+            grad_unfused = t.grad.detach().clone()
+
+        t.grad = None
+
+        # fused
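+        # The fused path consumes the packed tensor directly; qkv_split_arg_list describes
+        # how to carve it into Q/K/V, and the test expects emb_q applied to Q, emb_k applied
+        # to K, and V returned unrotated.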
+        query_fused, key_fused, value_fused = apply_fused_qkv_rotary_pos_emb(
+            t,
+            emb_q,
+            emb_k,
+            tensor_format=tensor_format,
+            start_positions=start_positions,
+            interleaved=interleaved,
+            cp_size=cp_size,
+            cp_rank=cp_rank,
+            qkv_split_arg_list=[hidden_size * 4, hidden_size, hidden_size],
+        )
+        loss_fused = loss_func([query_fused, key_fused, value_fused])
+
+        if not isinstance(start_positions, torch.Tensor):
+            loss_fused.backward()
+            grad_fused = t.grad.detach().clone()
+        t.grad = None
+
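+        # Forward outputs of the fused path must match the reference Q/K and the untouched V.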
+        torch.testing.assert_close(query_fused, query_unfused)
+        torch.testing.assert_close(key_fused, key_unfused)
+        torch.testing.assert_close(value_fused, value_unfused)
+
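+        # Gradients are compared only when start_positions is None, since backward is
+        # skipped above when start_positions offsets are used.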
+        if not isinstance(start_positions, torch.Tensor):
+            torch.testing.assert_close(grad_fused, grad_unfused)