Description
Hello, I am using functorch for my project. It's a good substitute for JAX and makes life easy. However, it does not support the GPU on Apple M1 chips, which is a significant time constraint for me because I work on an M1-based system. Now that PyTorch supports the M1 GPU (through the `mps` device), I think functorch should make this feature available as well.