
PyTorch Lightning #101

Open
javier-alvarez opened this issue Jan 31, 2022 · 2 comments

@javier-alvarez

Do you know if ORT works with PyTorch Lightning?

I am trying it, but I am getting:

raise new_exception(raised_exception) from raised_exception

onnxruntime.training.ortmodule._fallback.ORTModuleTorchModelException: ORTModule does not support adding modules to it.
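
(For context, the basic torch-ort pattern outside Lightning is to wrap an ordinary nn.Module directly, per the torch-ort README; a minimal sketch, with a placeholder Linear model:)

import torch
from torch_ort import ORTModule

# Minimal sketch: wrap an ordinary nn.Module directly in ORTModule.
# The Linear model and input shapes are placeholders.
model = torch.nn.Linear(10, 2)
model = ORTModule(model)
output = model(torch.randn(4, 10))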

Also, is there a way to configure ORT automatically when you install the package with conda?

Currently I have to call this from my code:

from onnxruntime.training.ortmodule.torch_cpp_extensions import install as ortmodule_install
ortmodule_install.build_torch_cpp_extensions()
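
(If I understand the torch-ort setup correctly, the same build step can be done once after installation with the command below, rather than from code; the exact invocation may depend on the installed version.)

python -m torch_ort.configure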

Does it work with PyTorch 1.10 and CUDA 11?

Thanks!

@javier-alvarez (Author)

OK, the issue here is sync_batchnorm=True; if I switch it to False, it works.

Any ideas about how to convert the model to ORT after the sync_batchnorm conversion happens?
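
One direction that might work (sketch below, untested): run the SyncBatchNorm conversion manually before wrapping in ORTModule, so the module tree is final before ORT sees it. MyModel is a placeholder for the inner nn.Module.

import torch
import torch_ort

# Sketch only: convert BatchNorm -> SyncBatchNorm first, then wrap,
# so no modules are added to the ORTModule afterwards.
model = MyModel()  # placeholder for the inner nn.Module
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
model = torch_ort.ORTModule(model)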

@javier-alvarez (Author) commented Jan 31, 2022

I have managed to hack it with this:

import torch_ort
from pytorch_lightning.plugins import DDPPlugin
from torch.nn import Module
from torch.nn.parallel.distributed import DistributedDataParallel

class ORTPlugin(DDPPlugin):
    def _setup_model(self, model: Module) -> DistributedDataParallel:
        """Wrap the inner model in ORTModule, then hand everything to DDP."""
        from onnxruntime.training.ortmodule.torch_cpp_extensions import install as ortmodule_install

        # Build the ORTModule C++ extensions before the first wrap.
        ortmodule_install.build_torch_cpp_extensions()
        # Here `model` is Lightning's LightningDistributedModule, `model.module`
        # is the LightningModule, and `.model` is my inner nn.Module. By this
        # point the sync_batchnorm conversion has already run, so wrapping only
        # the inner module avoids "ORTModule does not support adding modules to it".
        model.module.model = torch_ort.ORTModule(model.module.model)
        return DistributedDataParallel(module=model, device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs)
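
It's used roughly like this (sketch against the Lightning ~1.5 Trainer API I am on; argument names may differ in other versions, and the device count is a placeholder):

import pytorch_lightning as pl

# Sketch only: pass the custom plugin in place of the stock DDP strategy.
trainer = pl.Trainer(
    accelerator="gpu",
    devices=2,
    strategy=ORTPlugin(),
    sync_batchnorm=True,
)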

If you have a better way to fix it, please let me know.
