🚀 Feature

To my understanding, the current implementation of ot.emd takes only two probability distributions and a cost matrix. Is there a batched variant of ot.emd that I am missing?
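For reference, a minimal sketch of how I call it today (assuming NumPy inputs and a recent POT release):

```python
import numpy as np
import ot

# Two uniform histograms over 5 and 6 support points.
a = np.ones(5) / 5
b = np.ones(6) / 6

# A single cost matrix between the two supports.
M = np.random.rand(5, 6)

G = ot.emd(a, b, M)       # optimal transport plan, shape (5, 6)
loss = ot.emd2(a, b, M)   # optimal transport cost (scalar)
```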
Motivation
As ot.emd and ot.emd2 are capable of computing gradients, a batched implementation of emd would help a lot to speed up training.
Is it possible to implement an ot.emd that takes batched input? If not, can you please explain why?
Also, if batched input is not possible, what is the best way to speed up the computation (other than using the regularized version)?
Thanks.
Actually, as detailed in the function documentation, ot.emd and ot.emd2 can take GPU tensors, but the solver is CPU-bound, so there is a memory-copy overhead when running on GPU. That overhead is relatively small on large problems, but it can be quite limiting when many small problems are solved repeatedly.
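For example, a minimal sketch with the torch backend (assuming POT >= 0.8 and a CUDA device; the solve itself still runs on CPU, so the GPU tensors get copied back and forth):

```python
import torch
import ot

device = "cuda" if torch.cuda.is_available() else "cpu"

a = torch.full((5,), 1.0 / 5, device=device)
b = torch.full((6,), 1.0 / 6, device=device)
M = torch.rand(5, 6, device=device, requires_grad=True)

# emd2 returns the exact OT cost and is differentiable w.r.t. M,
# but the network-flow solve itself runs on CPU.
loss = ot.emd2(a, b, M)
loss.backward()
print(M.grad.shape)  # torch.Size([5, 6])
```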
There is an OpenMP implementation of the solver that can benefit from multiple CPU cores, and in practice one can also call the solver in parallel on multiple independent problems. But there is currently no way to do a batched exact OT solve, since no network-flow solver is available on GPU yet. For a batched implementation, regularized OT is indeed the best option.
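For the batched, regularized route, here is a minimal batched Sinkhorn sketch in PyTorch (not POT's API, just an illustration; the batch shapes, reg value, and iteration count are placeholders, and the log-domain stabilization is omitted for brevity):

```python
import torch

def batched_sinkhorn(a, b, M, reg=0.05, n_iter=200):
    """Entropic OT for a batch of independent problems.

    a: (B, n) source histograms, b: (B, m) target histograms,
    M: (B, n, m) cost matrices. Returns the (B,) transport costs.
    """
    K = torch.exp(-M / reg)                       # (B, n, m) Gibbs kernel
    u = torch.ones_like(a)
    v = torch.ones_like(b)
    for _ in range(n_iter):
        # Alternating Sinkhorn scaling updates, vectorized over the batch.
        u = a / torch.bmm(K, v.unsqueeze(-1)).squeeze(-1).clamp_min(1e-30)
        v = b / torch.bmm(K.transpose(1, 2), u.unsqueeze(-1)).squeeze(-1).clamp_min(1e-30)
    P = u.unsqueeze(-1) * K * v.unsqueeze(-2)     # (B, n, m) transport plans
    return (P * M).sum(dim=(1, 2))                # transport cost per problem

# Hypothetical usage: a batch of 32 problems with 50 x 60 cost matrices.
B, n, m = 32, 50, 60
a = torch.full((B, n), 1.0 / n)
b = torch.full((B, m), 1.0 / m)
M = torch.rand(B, n, m, requires_grad=True)
losses = batched_sinkhorn(a, b, M)   # shape (B,), differentiable w.r.t. M
losses.sum().backward()
```

For the exact solver, an alternative is to fan the independent problems out over CPU processes (e.g. with joblib or multiprocessing), or, if your POT version exposes it, to use the numThreads argument of ot.emd / ot.emd2, which relies on the OpenMP build mentioned above.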