Description
Our kwarg handling seems inconsistent with JAX's:
import jax
import jax.numpy as jnp

def f(x, y):
    return x * y

x = jnp.ones([2])
jax.vmap(f)(x, y=x).shape  # returns (2,)
import torch
import functorch

def f(x, y):
    return x * y

x = torch.ones(2)
functorch.vmap(f)(x, y=x).shape  # returns (2, 2)
We should decide what the semantics of kwarg handling should be. Right now it looks like we treat all kwargs as unbatched (and it is documented that way). There are a few options:
- we treat all kwargs as batched (this is kind of awkward for values that aren't Tensors)
- we treat all kwargs as unbatched (the current state)
- some mix of both? E.g., if an argument is a Tensor we treat it as batched but if it is not a Tensor (and not a pytree of Tensors) we treat it as unbatched.
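
To make the third option concrete, here is a rough user-level sketch, not a proposal for the actual implementation. The helper name `vmap_tensor_kwargs` is hypothetical, and it assumes `torch.vmap` batches positional arguments over dim 0 by default, the same way `functorch.vmap` does. It only checks for top-level Tensors; pytrees of Tensors would need extra handling.

```python
import torch


def vmap_tensor_kwargs(func):
    """Hypothetical helper (not a functorch API): Tensor kwargs are batched
    over dim 0, every other kwarg is passed through unbatched."""

    def wrapped(*args, **kwargs):
        # Split kwargs by type. Only top-level Tensors count as batched here.
        batched = {k: v for k, v in kwargs.items() if isinstance(v, torch.Tensor)}
        static = {k: v for k, v in kwargs.items() if k not in batched}
        names = list(batched)
        n_args = len(args)

        def positional(*flat):
            # Rebuild the Tensor kwargs from the trailing positional values.
            tensor_kwargs = dict(zip(names, flat[n_args:]))
            return func(*flat[:n_args], **tensor_kwargs, **static)

        # Everything passed positionally is batched over dim 0 (default in_dims).
        return torch.vmap(positional)(*args, *batched.values())

    return wrapped


def f(x, y, scale=1.0):
    return x * y * scale


x = torch.ones(2)
out = vmap_tensor_kwargs(f)(x, y=x, scale=3.0)
print(out.shape)  # torch.Size([2]), matching the JAX result above
```

Under the current semantics, the same effect is available without a helper by passing `y` positionally, e.g. `functorch.vmap(f)(x, x)`.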