🚀 Feature
As far as I understand, the current implementation of `ot.emd` takes only two probability distributions and a cost matrix. Am I missing an implementation of `ot.emd` that accepts batched input?
Motivation
Since `ot.emd` and `ot.emd2` can compute gradients, a batched implementation of EMD would significantly speed up training.
Would it be possible to implement a version of `ot.emd` that accepts batched input? If not, could you please explain why?
Also, if batched input is not possible, what is the best way to speed up the computation (other than using the regularized version)?
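One special case where exact (unregularized) OT can be fully batched is 1D data with equal sample counts and uniform weights: for a cost |x - y|^p with p >= 1, the optimal plan matches sorted samples, so no LP solve is needed. The sketch below relies on that property; `batched_wasserstein_1d` is a hypothetical name, and the approach does not apply to general cost matrices, where each pair still needs its own LP.

```python
import numpy as np


def batched_wasserstein_1d(u, v, p=2):
    """Hypothetical helper: exact W_p^p between 1D empirical distributions.

    u, v: arrays of shape (B, n), equal sample counts, uniform weights.
    In 1D the monotone (sorted) matching is optimal for |x - y|^p, p >= 1.
    """
    u_sorted = np.sort(u, axis=1)  # sort each batch row independently
    v_sorted = np.sort(v, axis=1)
    return np.mean(np.abs(u_sorted - v_sorted) ** p, axis=1)


u = np.array([[0.0, 1.0, 2.0], [3.0, 1.0, 2.0]])
v = np.array([[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]])
out = batched_wasserstein_1d(u, v)
print(out)
```

This runs in O(B n log n) with no Python loop over the batch.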
Thanks.