diff --git a/diffusion/planners/__init__.py b/diffusion/planners/__init__.py
new file mode 100644
index 00000000..efafbb0a
--- /dev/null
+++ b/diffusion/planners/__init__.py
@@ -0,0 +1,8 @@
+# Copyright 2022 MosaicML Diffusion authors
+# SPDX-License-Identifier: Apache-2.0
+
+"""Composer checkpointing planners."""
+
+from diffusion.planners.lora_planner import LoraPlanner
+
+__all__ = ['LoraPlanner']
diff --git a/diffusion/planners/lora_planner.py b/diffusion/planners/lora_planner.py
new file mode 100644
index 00000000..3f856bd1
--- /dev/null
+++ b/diffusion/planners/lora_planner.py
@@ -0,0 +1,58 @@
+# Copyright 2022 MosaicML Diffusion authors
+# SPDX-License-Identifier: Apache-2.0
+
+"""LoRA Planner."""
+from torch.distributed.checkpoint._nested_dict import flatten_state_dict
+from torch.distributed.checkpoint._sharded_tensor_utils import _flatten_sharded_tensors
+from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner
+from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE, Metadata
+
+__all__ = ['LoraPlanner']
+
+
+class LoraPlanner(DefaultLoadPlanner):
+    """Load planner that remaps Composer checkpoint weight names to their LoRA ``base_layer`` equivalents."""
+
+    def set_up_planner(
+        self,
+        state_dict: STATE_DICT_TYPE,
+        metadata: Metadata,
+        is_coordinator: bool,
+    ) -> None:
+        """Sets up the planner for converting Composer to LoRA Checkpoint.
+
+        Takes all targeted modules and checks whether they have been LoRA processed. If not,
+        changes names of weights appropriately. If yes, doesn't change anything for autoresume
+        compatibility.
+
+        Args:
+            state_dict (STATE_DICT_TYPE): Original torch state dict.
+            metadata (Metadata): Metadata associated with the state dict.
+            is_coordinator (bool): Whether the machine this is running on is the coordinator of loading.
+        """
+        if 'state' not in state_dict:
+            super().set_up_planner(state_dict, metadata, is_coordinator)
+            return
+
+        self.original_state_dict = state_dict
+
+        state_dict = dict(state_dict.items())
+        state_dict['state'] = dict(state_dict['state'].items())
+        target_modules = ['to_k', 'to_v', 'to_q', 'to_out.0']
+
+        model_sd = state_dict['state']['model'] = dict(state_dict['state']['model'].items())
+        for key in list(model_sd):
+            for mod in target_modules:
+                if f'{mod}.weight' in key:
+                    model_sd[key.replace(mod, mod + '.base_layer')] = model_sd.pop(key)
+                    break
+
+        if self.flatten_sharded_tensors:
+            state_dict = _flatten_sharded_tensors(state_dict)
+
+        if self.flatten_state_dict:
+            state_dict, self.mappings = flatten_state_dict(state_dict)
+
+        self.state_dict = state_dict
+        self.metadata = metadata
+        self.is_coordinator = is_coordinator
diff --git a/diffusion/train.py b/diffusion/train.py
index becff0f1..3cc7392c 100644
--- a/diffusion/train.py
+++ b/diffusion/train.py
@@ -21,6 +21,7 @@
 
 from diffusion.models.autoencoder import ComposerAutoEncoder, ComposerDiffusersAutoEncoder
 from diffusion.models.t2i_transformer import ComposerTextToImageMMDiT
+from diffusion.planners import LoraPlanner
 
 
 def make_autoencoder_optimizer(config: DictConfig, model: ComposerModel) -> Optimizer:
@@ -206,19 +207,28 @@ def train(config: DictConfig) -> None:
                 print(f'Instantiating callbacks <{call_conf._target_}>')
                 callbacks.append(hydra.utils.instantiate(call_conf))
 
+    fsdp_config = None
+    if 'fsdp_config' in config.trainer:
+        fsdp_config = dict(config.trainer.fsdp_config)
+        del config.trainer.fsdp_config
+
+    if 'lora_rank' in config.model:
+        if fsdp_config is None:
+            raise ValueError('Using lora_rank requires trainer.fsdp_config to be specified.')
+        fsdp_config['load_planner'] = LoraPlanner
+
     scheduler = hydra.utils.instantiate(config.scheduler)
 
-    trainer: Trainer = hydra.utils.instantiate(
-        config.trainer,
-        train_dataloader=train_dataloader,
-        eval_dataloader=eval_set,
-        optimizers=optimizer,
-        model=model,
-        loggers=logger,
-        algorithms=algorithms,
-        schedulers=scheduler,
-        callbacks=callbacks,
-    )
+    trainer: Trainer = hydra.utils.instantiate(config.trainer,
+                                               train_dataloader=train_dataloader,
+                                               eval_dataloader=eval_set,
+                                               optimizers=optimizer,
+                                               model=model,
+                                               loggers=logger,
+                                               algorithms=algorithms,
+                                               schedulers=scheduler,
+                                               callbacks=callbacks,
+                                               fsdp_config=fsdp_config)
 
     def eval_and_then_train():
         if config.get('eval_first', True):