Description
The backward formulas for operators like torch.diag call in-place operations. We cannot vmap over these, so we want some way to functionalize the backward formula. However, "conditional functionalization" (#235) currently only works at the operator level: we can say, "given this native_functions operator, we will apply the functionalization pass to it".
The goal of this issue is to figure out how to functionalize the backward pass of certain operations when vmap is active.
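To make the composition concrete, here is a minimal sketch using the torch.func APIs (functorch at the time this issue was filed); whether it actually errors depends on the build, but this is the pattern we want to support:

```python
import torch
from torch.func import grad, vmap

def f(x):
    # torch.diag's backward formula is written with in-place ops, which is
    # what makes running it under vmap problematic.
    return torch.diag(x).sum()

x = torch.randn(8, 5)  # a batch of 8 vectors

# Per-sample gradients: vmap runs diag's backward on batched tensors, so the
# in-place ops in that backward formula need to be functionalized.
per_sample_grads = vmap(grad(f))(x)
```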
Approach 1
Just register the backward formula as a composite operator. This would also be useful for future work like dynamic shape tracing, where we need the backward formula to be available in Python.
Tradeoffs: There is resistance to making backward formulas into operators because (1) it increases the size of our operator library and (2) the native_functions operator library is potentially an FC/BC surface. Perhaps we would have to couple this approach with some BE improvements, like being able to mark a native_functions operator as "never serializable"?
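For illustration only (the namespace, schema, and dispatch key choice below are hypothetical, not the actual ATen entries), a functional backward for torch.diag can be written out-of-place with diagonal_scatter and registered as a composite operator, e.g. via torch.library:

```python
import torch

# Hypothetical namespace and schema: a sketch of what "backward formula as a
# composite operator" could look like from Python.
lib = torch.library.Library("functionalize_sketch", "DEF")
lib.define("diag_backward(Tensor grad_output, int[] input_sizes) -> Tensor")

def diag_backward(grad_output, input_sizes):
    if len(input_sizes) == 1:
        # forward was vector -> matrix; the backward just reads the diagonal
        return torch.diagonal(grad_output)
    # forward was matrix -> diagonal vector; the existing kernel writes the
    # gradient into a zeros tensor in-place, diagonal_scatter is the
    # out-of-place equivalent
    zeros = grad_output.new_zeros(input_sizes)
    return torch.diagonal_scatter(zeros, grad_output)

lib.impl("diag_backward", diag_backward, "CompositeImplicitAutograd")
```

Because the registration is composite, vmap and functionalization can decompose it into existing operators instead of needing a dedicated kernel.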
Approach 2
Use Autograd hooks. For Autograd hooks to work, we would need to:
- be able to enable functionalization before the autograd node gets called
- be able to unconditionally disable functionalization after the autograd node is done executing. This needs to happen even if an error occurs during the backward pass.
It's unclear how much work this approach would require.
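A rough sketch of what the hook approach could look like, using autograd Node pre-/post-hooks; enable_functionalization / disable_functionalization are hypothetical placeholders for whatever internal toggle this would actually use, and the sketch also shows the error-handling gap called out above:

```python
import torch

# Hypothetical toggles: stand-ins for whatever internal API actually flips
# the Functionalize dispatch key on and off.
def enable_functionalization():
    ...

def disable_functionalization():
    ...

x = torch.randn(5, requires_grad=True)
y = torch.diag(x)
node = y.grad_fn  # DiagBackward: the node whose backward we want to functionalize

# Pre-hook: turn functionalization on right before this node runs.
node.register_prehook(lambda grad_outputs: enable_functionalization())

# Post-hook: turn it back off after the node finishes. Caveat: if the backward
# raises, this hook never fires, which is exactly the "even if an error occurs"
# requirement above that plain hooks do not cover.
node.register_hook(lambda grad_inputs, grad_outputs: disable_functionalization())

y.sum().backward()
```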
Test Cases
Here are some operators whose backward passes we would like to conditionally functionalize when vmap is active (a rough test sketch follows the list):
- torch.masked_select
- torch.clamp
- torch.diag
- torch.index_select
- torch.topk
- torch.prod
- torch.mode
- torch.sort
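For reference, a throwaway sketch (not the real test suite) of exercising vmap over the backward pass for a few of these; each lambda maps a 1-D sample to a scalar so grad applies directly:

```python
import torch
from torch.func import grad, vmap

# Illustrative only: a handful of the ops above, each reduced to a scalar.
cases = {
    "torch.diag": lambda x: torch.diag(x).sum(),
    "torch.clamp": lambda x: torch.clamp(x, min=0.0).sum(),
    "torch.index_select": lambda x: torch.index_select(x, 0, torch.tensor([0, 2])).sum(),
    "torch.prod": lambda x: torch.prod(x),
}

x = torch.randn(8, 5)  # batch of 8 samples
for name, fn in cases.items():
    try:
        vmap(grad(fn))(x)
        print(f"{name}: vmap over backward OK")
    except Exception as err:
        print(f"{name}: {type(err).__name__}: {err}")
```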