We should check the performance of vmap x grad on the examples in pytorch/opacus:
- https://github.com/pytorch/opacus/tree/main/examples
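As a starting point, here is a minimal timing sketch for the vmap x grad path on a toy model. It is not taken from the Opacus examples: the model, the shapes, and the helper names (`make_model`, `per_sample_grads_vmap`, `per_sample_grads_loop`, `time_it`) are all made up for illustration, and it assumes the `torch.func` API from PyTorch 2.0+. The baseline is a naive per-sample loop rather than Opacus itself, so a real comparison would still need to swap in the actual example models and Opacus' own per-sample gradient path.

```python
import time

import torch
from torch.func import functional_call, grad, vmap


def make_model():
    # Hypothetical toy model; the Opacus examples use larger networks.
    torch.manual_seed(0)
    return torch.nn.Sequential(
        torch.nn.Linear(16, 32),
        torch.nn.ReLU(),
        torch.nn.Linear(32, 4),
    )


def per_sample_grads_vmap(model, x, y):
    # Per-sample gradients via vmap over grad of a single-sample loss.
    params = {k: v.detach() for k, v in model.named_parameters()}

    def loss_fn(p, xi, yi):
        # Run the module statelessly with parameters p on a batch of one.
        out = functional_call(model, p, (xi.unsqueeze(0),))
        return torch.nn.functional.cross_entropy(out, yi.unsqueeze(0))

    # Parameters are shared (None); inputs and targets are mapped over dim 0.
    return vmap(grad(loss_fn), in_dims=(None, 0, 0))(params, x, y)


def per_sample_grads_loop(model, x, y):
    # Naive baseline: one backward pass per sample.
    names = [k for k, _ in model.named_parameters()]
    params = list(model.parameters())
    rows = []
    for i in range(x.shape[0]):
        loss = torch.nn.functional.cross_entropy(model(x[i:i + 1]), y[i:i + 1])
        rows.append(torch.autograd.grad(loss, params))
    return {n: torch.stack([r[j] for r in rows]) for j, n in enumerate(names)}


def time_it(fn, *args, repeats=5):
    # Average wall-clock seconds per call after one warm-up call (CPU only;
    # on GPU you would also need torch.cuda.synchronize()).
    fn(*args)
    start = time.perf_counter()
    for _ in range(repeats):
        fn(*args)
    return (time.perf_counter() - start) / repeats


if __name__ == "__main__":
    model = make_model()
    x = torch.randn(64, 16)
    y = torch.randint(0, 4, (64,))
    print("vmap x grad:", time_it(per_sample_grads_vmap, model, x, y))
    print("python loop:", time_it(per_sample_grads_loop, model, x, y))
```

Both functions return a dict mapping parameter names to tensors with a leading batch dimension, so the results can be checked against each other before trusting any timings.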