
vmap kwarg handling #244

Open
@zou3519

Our vmap kwarg handling is inconsistent with JAX's. Passing the same array positionally and as a keyword argument gives different output shapes:

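```python
import jax
import jax.numpy as jnp

def f(x, y):
    return x * y

x = jnp.ones([2])
# jax.vmap maps keyword arguments over their leading axis, so each
# per-example call computes x[i] * y[i] and the result has shape (2,).
print(jax.vmap(f)(x, y=x).shape)  # (2,)
```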
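
```python
import torch
import functorch

def f(x, y):
    return x * y

x = torch.ones(2)
# functorch passes kwargs to every per-example call unbatched, so each call
# computes x[i] * y with y of shape (2,), and the stacked result is (2, 2).
print(functorch.vmap(f)(x, y=x).shape)  # torch.Size([2, 2])
```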
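To see where the two shapes come from, here is a minimal loop-based sketch of vmap semantics in plain Python. It assumes that vmap over dimension 0 is equivalent to calling `f` once per example and stacking the results. Nested lists stand in for arrays. `loop_vmap`, `mul`, `shape` and the `batch_kwargs` flag are hypothetical names used only for illustration; they are not functorch or JAX APIs.

```python
from typing import Any, Callable, List


def mul(a: Any, b: Any) -> Any:
    """Elementwise multiply with scalar broadcasting (stand-in for x * y)."""
    if isinstance(a, list) and isinstance(b, list):
        return [mul(ai, bi) for ai, bi in zip(a, b)]
    if isinstance(a, list):
        return [mul(ai, b) for ai in a]
    if isinstance(b, list):
        return [mul(a, bi) for bi in b]
    return a * b


def shape(v: Any) -> tuple:
    """Shape of a non-empty rectangular nested list; scalars have shape ()."""
    if isinstance(v, list):
        return (len(v),) + shape(v[0])
    return ()


def loop_vmap(f: Callable[..., Any], batch_kwargs: bool) -> Callable[..., List[Any]]:
    """Hypothetical reference vmap over dim 0: one call of f per example.

    Positional args are always sliced along dim 0. With batch_kwargs=True,
    kwargs are sliced too (JAX-like). Otherwise every call receives the full
    kwarg value (functorch's documented behavior).
    """
    def wrapped(*args: Any, **kwargs: Any) -> List[Any]:
        batch_size = len(args[0])
        results = []
        for i in range(batch_size):
            call_args = [a[i] for a in args]
            call_kwargs = {k: (v[i] if batch_kwargs else v) for k, v in kwargs.items()}
            results.append(f(*call_args, **call_kwargs))
        return results
    return wrapped


def f(x, y):
    return mul(x, y)


x = [1.0, 1.0]
print(shape(loop_vmap(f, batch_kwargs=True)(x, y=x)))   # (2,)   like jax.vmap
print(shape(loop_vmap(f, batch_kwargs=False)(x, y=x)))  # (2, 2) like functorch.vmap
```

The only difference between the two results is whether `y` is indexed per example or passed through whole.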

We should decide what the semantics of kwarg handling should be. Right now we treat all kwargs as unbatched, and the documentation describes this behavior. There are a few options:

  • treat all kwargs as batched (awkward for non-Tensor values such as Python numbers or strings, which have no batch dimension to map over)
  • treat all kwargs as unbatched (the current behavior)
  • some mix of both: for example, treat a kwarg as batched if it is a Tensor or a pytree of Tensors, and as unbatched otherwise
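The third option depends on a single classification rule. Below is a minimal sketch of that rule. It assumes pytrees are built from lists, tuples and dicts. `FakeTensor`, `leaves`, `kwarg_is_batched` and `split_kwargs` are hypothetical names, and `FakeTensor` stands in for `torch.Tensor` so the example runs without PyTorch. A real implementation would use the library's own pytree flattening instead.

```python
from dataclasses import dataclass
from typing import Any, Dict, List


@dataclass
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor: a batch of rows along dim 0."""
    rows: List[Any]


def leaves(tree: Any) -> List[Any]:
    """Flatten a pytree made of lists, tuples and dicts into its leaves."""
    if isinstance(tree, (list, tuple)):
        return [leaf for sub in tree for leaf in leaves(sub)]
    if isinstance(tree, dict):
        return [leaf for sub in tree.values() for leaf in leaves(sub)]
    return [tree]


def kwarg_is_batched(value: Any) -> bool:
    """Mixed policy: batch a kwarg iff it is a Tensor or a pytree of only Tensors.

    Empty containers have nothing to batch, so they count as unbatched.
    """
    flat = leaves(value)
    return len(flat) > 0 and all(isinstance(leaf, FakeTensor) for leaf in flat)


def split_kwargs(kwargs: Dict[str, Any]) -> Dict[str, bool]:
    """Report, per kwarg name, whether the mixed policy would batch it."""
    return {name: kwarg_is_batched(v) for name, v in kwargs.items()}


t = FakeTensor(rows=[1.0, 2.0])
print(split_kwargs({"y": t, "pair": (t, t), "alpha": 0.5, "mode": "sum", "mixed": [t, 3]}))
# {'y': True, 'pair': True, 'alpha': False, 'mode': False, 'mixed': False}
```

In this sketch, a pytree that mixes Tensors and non-Tensors (`mixed` above) is treated as unbatched as a whole. The design would still need to choose among this behavior, batching only the Tensor leaves, or raising an error.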

Labels: needs design (there is ambiguity around design)
