diff --git a/pyproject.toml b/pyproject.toml
index adb75a3d..818d9fe0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -39,7 +39,8 @@ dependencies = [
   "lightly>=1.4.25",
   "h5py>=3.10.0",
   "geobench>=1.0.0",
-  "mlflow>=2.12.1"
+  "mlflow>=2.12.1",
+  "lightning<=2.2.5"
 ]
 
 [project.optional-dependencies]
diff --git a/src/terratorch/models/backbones/__init__.py b/src/terratorch/models/backbones/__init__.py
index 9f91a5ea..50c59bd5 100644
--- a/src/terratorch/models/backbones/__init__.py
+++ b/src/terratorch/models/backbones/__init__.py
@@ -1,6 +1,3 @@
 # import so they get registered
-from terratorch.models.backbones.prithvi_vit import TemporalViTEncoder
-
-__all__ = ["TemporalViTEncoder"]
-__all__ = ["TemporalViTEncoder"]
-__all__ = ["TemporalViTEncoder"]
+import terratorch.models.backbones.prithvi_vit
+import terratorch.models.backbones.prithvi_swin
\ No newline at end of file
diff --git a/src/terratorch/models/backbones/prithvi_swin.py b/src/terratorch/models/backbones/prithvi_swin.py
index 71aea43a..d83edac0 100644
--- a/src/terratorch/models/backbones/prithvi_swin.py
+++ b/src/terratorch/models/backbones/prithvi_swin.py
@@ -36,15 +36,6 @@ def _cfg(file: Path = "", **kwargs) -> dict:
         **kwargs,
     }
 
-default_cfgs = generate_default_cfgs(
-    {
-        "prithvi_swin_90_us": {
-            "hf_hub_id": "ibm-nasa-geospatial/Prithvi-100M",
-            "hf_hub_filename": "Prithvi_100M.pt"
-        }
-    }
-)
-
 def convert_weights_swin2mmseg(ckpt):
     # from https://github.com/open-mmlab/mmsegmentation/blob/main/tools/model_converters/swin2mmseg.py
     new_ckpt = OrderedDict()
@@ -215,37 +206,6 @@ def prepare_features_for_image_model(x):
     return model
 
 
-@register_model
-def prithvi_swin_90_us(
-    pretrained: bool = False,  # noqa: FBT002, FBT001
-    pretrained_bands: list[HLSBands] | None = None,
-    bands: list[int] | None = None,
-    **kwargs,
-) -> MMSegSwinTransformer:
-    """Prithvi Swin 90M"""
-    if pretrained_bands is None:
-        pretrained_bands = PRETRAINED_BANDS
-    if bands is None:
-        bands = pretrained_bands
-        logging.info(
-            f"Model bands not passed. Assuming bands are ordered in the same way as {PRETRAINED_BANDS}.\
-            Pretrained patch_embed layer may be misaligned with current bands"
-        )
-
-    model_args = {
-        "patch_size": 4,
-        "window_size": 7,
-        "embed_dim": 128,
-        "depths": (2, 2, 18, 2),
-        "in_chans": 6,
-        "num_heads": (4, 8, 16, 32),
-    }
-    transformer = _create_swin_mmseg_transformer(
-        "prithvi_swin_90_us", pretrained_bands, bands, pretrained=pretrained, **dict(model_args, **kwargs)
-    )
-    return transformer
-
-
 @register_model
 def prithvi_swin_B(
     pretrained: bool = False,  # noqa: FBT002, FBT001
diff --git a/tests/test_backbones.py b/tests/test_backbones.py
index 17b3f6ca..6b8c36b9 100644
--- a/tests/test_backbones.py
+++ b/tests/test_backbones.py
@@ -28,7 +28,7 @@ def input_386():
     return torch.ones((1, NUM_CHANNELS, 386, 386))
 
 
-@pytest.mark.parametrize("model_name", ["prithvi_vit_100", "prithvi_vit_300"]) #["prithvi_swin_90_us", "prithvi_vit_100", "prithvi_vit_300"])
+@pytest.mark.parametrize("model_name", ["prithvi_vit_100", "prithvi_vit_300", "prithvi_swin_B"])
 @pytest.mark.parametrize("test_input", ["input_224", "input_512"])
 def test_can_create_backbones_from_timm(model_name, test_input, request):
     backbone = timm.create_model(model_name, pretrained=False)
@@ -36,7 +36,7 @@ def test_can_create_backbones_from_timm(model_name, test_input, request):
     backbone(input_tensor)
 
 
-@pytest.mark.parametrize("model_name", ["prithvi_vit_100", "prithvi_vit_300"])
+@pytest.mark.parametrize("model_name", ["prithvi_vit_100", "prithvi_vit_300", "prithvi_swin_B"])
 @pytest.mark.parametrize("test_input", ["input_224", "input_512"])
 def test_can_create_backbones_from_timm_features_only(model_name, test_input, request):
     backbone = timm.create_model(model_name, pretrained=False, features_only=True)
diff --git a/tests/test_prithvi_tasks.py b/tests/test_prithvi_tasks.py
index 66a911eb..ca5280a3 100644
--- a/tests/test_prithvi_tasks.py
+++ b/tests/test_prithvi_tasks.py
@@ -20,7 +20,7 @@ def model_input() -> torch.Tensor:
     return torch.ones((1, NUM_CHANNELS, 224, 224))
 
 
-@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300"])
+@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300", "prithvi_swin_B"])
 @pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"])
 @pytest.mark.parametrize("loss", ["ce", "jaccard", "focal", "dice"])
 def test_create_segmentation_task(backbone, decoder, loss, model_factory: PrithviModelFactory):
@@ -38,7 +38,7 @@ def test_create_segmentation_task(backbone, decoder, loss, model_factory: Prithv
     )
 
 
-@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300"])
+@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300", "prithvi_swin_B"])
 @pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"])
 @pytest.mark.parametrize("loss", ["mae", "rmse", "huber"])
 def test_create_regression_task(backbone, decoder, loss, model_factory: PrithviModelFactory):
@@ -55,7 +55,7 @@ def test_create_regression_task(backbone, decoder, loss, model_factory: PrithviM
     )
 
 
-@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300"])
+@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300", "prithvi_swin_B"])
 @pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"])
 @pytest.mark.parametrize("loss", ["ce", "bce", "jaccard", "focal"])
 def test_create_classification_task(backbone, decoder, loss, model_factory: PrithviModelFactory):