diff --git a/KernelBench/level1/97_ScaledDotProductAttention.py b/KernelBench/level1/97_ScaledDotProductAttention.py
index feff1490..ba4bd02b 100644
--- a/KernelBench/level1/97_ScaledDotProductAttention.py
+++ b/KernelBench/level1/97_ScaledDotProductAttention.py
@@ -15,9 +15,9 @@ def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Te
 embedding_dimension = 1024
 
 def get_inputs():
-    Q = torch.rand(batch_size, num_heads, sequence_length, embedding_dimension, device='cuda', dtype=torch.float16)
-    K = torch.rand(batch_size, num_heads, sequence_length, embedding_dimension, device='cuda', dtype=torch.float16)
-    V = torch.rand(batch_size, num_heads, sequence_length, embedding_dimension, device='cuda', dtype=torch.float16)
+    Q = torch.rand(batch_size, num_heads, sequence_length, embedding_dimension)
+    K = torch.rand(batch_size, num_heads, sequence_length, embedding_dimension)
+    V = torch.rand(batch_size, num_heads, sequence_length, embedding_dimension)
     return [Q, K, V]
 
 def get_init_inputs():