from pytorch_lightning.plugins import DDPPlugin
import torch_ort
from pytorch_lightning.overrides import LightningDistributedModule
from torch.nn import Module
from torch.nn.parallel.distributed import DistributedDataParallel


class ORTPlugin(DDPPlugin):
    def _setup_model(self, model: Module) -> DistributedDataParallel:
        """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
        # Build the ORTModule torch C++ extensions before wrapping the model.
        from onnxruntime.training.ortmodule.torch_cpp_extensions import install as ortmodule_install
        ortmodule_install.build_torch_cpp_extensions()
        # Wrap the inner model with ORTModule, then hand everything to DDP.
        model.module.model = torch_ort.ORTModule(model.module.model)
        return DistributedDataParallel(module=model, device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs)
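For reference, this is roughly how the plugin would be passed to the Trainer (a minimal sketch; my_module stands in for my LightningModule, and the exact Trainer arguments vary across PyTorch Lightning 1.x versions):

import pytorch_lightning as pl

# Sketch only: hand the custom plugin to the Trainer so the DDP setup
# goes through ORTPlugin._setup_model. Argument names differ between
# PyTorch Lightning versions.
trainer = pl.Trainer(gpus=2, plugins=[ORTPlugin()])
trainer.fit(my_module)  # my_module is a placeholder LightningModule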
If you have a better way to fix it, please let me know.
Do you know whether ORT works with PyTorch Lightning?
I am trying it, but I am getting:
onnxruntime.training.ortmodule._fallback.ORTModuleTorchModelException: ORTModule does not support adding modules to it.
Also, is there a way to configure ORT automatically when the package is installed with conda?
Currently I have to call this from my code:
from onnxruntime.training.ortmodule.torch_cpp_extensions import install as ortmodule_install
ortmodule_install.build_torch_cpp_extensions()
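As a stopgap, the call could be guarded so the build only runs once per environment, roughly like this (just a sketch; the helper name and marker-file path are placeholders, not part of onnxruntime):

from pathlib import Path

def ensure_ortmodule_extensions_built():
    # Placeholder helper: build the ORTModule torch C++ extensions the first
    # time, then drop a marker file so later runs can skip the build step.
    marker = Path.home() / ".cache" / "ortmodule_extensions_built"  # placeholder path
    if marker.exists():
        return
    from onnxruntime.training.ortmodule.torch_cpp_extensions import install as ortmodule_install
    ortmodule_install.build_torch_cpp_extensions()
    marker.parent.mkdir(parents=True, exist_ok=True)
    marker.touch()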
Does it work with PyTorch 1.10 and CUDA 11?
Thanks!