We should follow the semantics documented for `torch.nn.functional.scaled_dot_product_attention` (https://pytorch.org/docs/main/generated/torch.nn.functional.scaled_dot_product_attention.html) when implementing `enable_gqa`.
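As I read the reference implementation in those docs, `enable_gqa=True` repeats each key/value head along the head dimension (dim -3) with `repeat_interleave`, so that `Hq // Hkv` consecutive query heads share one KV head. It also requires `Hq` to be divisible by `Hkv`. Below is a rough sketch of just that part, assuming `(..., heads, seq, dim)` inputs. It leaves out masking, `is_causal` and dropout. `sdpa_gqa_reference` is a hypothetical helper name, not a PyTorch API.

```python
import math

import torch
import torch.nn.functional as F


def sdpa_gqa_reference(query, key, value, scale=None, enable_gqa=False):
    """Hypothetical reference SDPA (no mask, no dropout) with optional GQA.

    query: (..., Hq, L, E); key: (..., Hkv, S, E); value: (..., Hkv, S, Ev).
    """
    if enable_gqa:
        hq, hkv = query.size(-3), key.size(-3)
        if value.size(-3) != hkv:
            raise ValueError("key and value must have the same number of heads")
        if hq % hkv != 0:
            raise ValueError(f"query heads ({hq}) must be divisible by kv heads ({hkv})")
        groups = hq // hkv
        # Repeat each KV head `groups` times so query head i uses KV head i // groups.
        key = key.repeat_interleave(groups, dim=-3)
        value = value.repeat_interleave(groups, dim=-3)
    # Default scale is 1/sqrt(E), matching the documented default.
    scale_factor = 1.0 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight = torch.softmax(attn_weight, dim=-1)
    return attn_weight @ value


q = torch.randn(2, 8, 5, 16)  # 8 query heads
k = torch.randn(2, 2, 7, 16)  # 2 key/value heads -> groups of 4
v = torch.randn(2, 2, 7, 16)
out = sdpa_gqa_reference(q, k, v, enable_gqa=True)
print(out.shape)  # torch.Size([2, 8, 5, 16])
```

Because the sketch only expands K/V before running ordinary attention, comparing it with `F.scaled_dot_product_attention` on explicitly repeated `key`/`value` is a simple check for our implementation.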