batched ot.emd #392

Closed
@siddharthKatageri

Description

🚀 Feature

As far as I understand, the current implementation of ot.emd accepts only a single problem: two probability distributions and a cost matrix. Is there a version of ot.emd that takes batched input that I am missing?
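For reference, here is a minimal sketch of the per-pair loop that the single-problem API implies. `solve_batch` is a hypothetical helper, not part of POT, and it assumes only that the solver is called as `solver(a, b, M)`, which is how ot.emd and ot.emd2 are called.

```python
from typing import Any, Callable, List, Sequence


def solve_batch(
    solver: Callable[[Any, Any, Any], Any],
    a_batch: Sequence[Any],
    b_batch: Sequence[Any],
    M_batch: Sequence[Any],
) -> List[Any]:
    """Call a single-problem OT solver once per (a, b, M) triple.

    `solver` is assumed to follow the ot.emd / ot.emd2 calling
    convention: solver(a, b, M), where a and b are the two
    distributions and M is the cost matrix.
    """
    if not (len(a_batch) == len(b_batch) == len(M_batch)):
        raise ValueError("a_batch, b_batch and M_batch must have the same length")
    # No vectorization: each problem is solved independently, in order.
    return [solver(a, b, M) for a, b, M in zip(a_batch, b_batch, M_batch)]
```

With POT installed, `solve_batch(ot.emd2, a_batch, b_batch, M_batch)` would return one transport cost per problem. The loop itself gives no speed-up, which is the reason for this request.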

Motivation

Since ot.emd and ot.emd2 are capable of computing gradients, a batched implementation of emd would help a lot in speeding up training.

Is it possible to implement an ot.emd that takes batched input? If not, could you please explain why?
Also, if batched input is not possible, what is the best way to speed up the computation (other than using the regularized version)?

Thanks.
