diff --git a/.github/workflows/discord_release_notification.yml b/.github/workflows/discord_release_notification.yml
new file mode 100644
index 000000000..47367714b
--- /dev/null
+++ b/.github/workflows/discord_release_notification.yml
@@ -0,0 +1,37 @@
+name: Discord Release Notification
+
+on:
+  release:
+    types: [published]
+
+jobs:
+  notify-discord:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Send Notification to Discord
+        env:
+          DISCORD_WEBHOOK: ${{ secrets.DISCORD_WEBHOOK }}
+        # We truncate the description at the models section (starting with ### Models)
+        # to keep the message short.
+        # We also have to format the release description for it to be valid JSON.
+        # This is done by piping the description to jq.
+        run: |
+          DESCRIPTION=$(echo '${{ github.event.release.body }}' | awk '/### Models/{exit}1' | jq -aRs .)
+          curl -H "Content-Type: application/json" \
+            -X POST \
+            -d @- \
+            "${DISCORD_WEBHOOK}" << EOF
+          {
+            "username": "Lightly",
+            "avatar_url": "https://avatars.githubusercontent.com/u/50146475",
+            "content": "Lightly ${{ github.event.release.tag_name }} has been released!",
+            "embeds": [
+              {
+                "title": "${{ github.event.release.name }}",
+                "url": "${{ github.event.release.html_url }}",
+                "color": 5814783,
+                "description": $DESCRIPTION
+              }
+            ]
+          }
+          EOF
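
For readers who find the `awk`/`jq` pipeline above opaque, here is a rough Python equivalent of what the `run` step computes (a sketch for illustration only; the sample release body is invented):

```python
import json

def build_description(release_body: str) -> str:
    # Mimics `awk '/### Models/{exit}1'`: keep every line before the first
    # line matching "### Models".
    kept = []
    for line in release_body.splitlines(keepends=True):
        if "### Models" in line:
            break
        kept.append(line)
    # Mimics `jq -aRs .`: emit the remaining text as one escaped JSON string.
    return json.dumps("".join(kept))

print(build_description("## Changes\nBug fixes.\n### Models\n| BYOL | ... |\n"))
# Output: "## Changes\nBug fixes.\n"
```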
diff --git a/.github/workflows/test_code_format.yml b/.github/workflows/test_code_format.yml
index 504756df3..61eb862fb 100644
--- a/.github/workflows/test_code_format.yml
+++ b/.github/workflows/test_code_format.yml
@@ -9,7 +9,6 @@ jobs:
   test:
     name: Check
     runs-on: ubuntu-latest
-
     steps:
       - name: Checkout Code
         uses: actions/checkout@v3
@@ -21,7 +20,7 @@ jobs:
       - name: Set up Python
        uses: actions/setup-python@v4
         with:
-          python-version: "3.10"
+          python-version: "3.7"
       - uses: actions/cache@v2
         with:
           path: ${{ env.pythonLocation }}
@@ -30,5 +29,7 @@ jobs:
         run: pip install -e '.[all]'
       - name: Run Format Check
         run: |
-          export LIGHTLY_SERVER_LOCATION="localhost:-1"
           make format-check
+      - name: Run Type Check
+        run: |
+          make type-check
diff --git a/.gitignore b/.gitignore
index 031ae8795..83d4d1fa2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,6 +5,7 @@ lightning_logs/
**lightning_logs/
**/__MACOSX
datasets/
+dist/
docs/source/tutorials/package/*
docs/source/tutorials/platform/*
docs/source/tutorials_source/platform/data
diff --git a/Makefile b/Makefile
index ead2c748d..fcb2ab17b 100644
--- a/Makefile
+++ b/Makefile
@@ -63,8 +63,15 @@ test:
 test-fast:
	pytest tests
 
-# run format checks and tests
-all-checks: format-check test
+## check typing
+type-check:
+	mypy lightly tests
+
+## run format and type checks
+static-checks: format-check type-check
+
+## run format checks and tests
+all-checks: static-checks test
 
 ## build source and wheel package
 dist: clean
diff --git a/README.md b/README.md
index 7938b21e8..321b09125 100644
--- a/README.md
+++ b/README.md
@@ -1,29 +1,31 @@
-![Lightly Logo](docs/logos/lightly_logo_crop.png)
+![Lightly SSL self-supervised learning Logo](docs/logos/lightly_SSL_logo_crop.png)
![GitHub](https://img.shields.io/github/license/lightly-ai/lightly)
![Unit Tests](https://github.com/lightly-ai/lightly/workflows/Unit%20Tests/badge.svg)
[![PyPI](https://img.shields.io/pypi/v/lightly)](https://pypi.org/project/lightly/)
-[![Downloads](https://pepy.tech/badge/lightly)](https://pepy.tech/project/lightly)
+[![Downloads](https://static.pepy.tech/badge/lightly)](https://pepy.tech/project/lightly)
[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
+[![Discord](https://img.shields.io/discord/752876370337726585?logo=discord&logoColor=white&label=discord&color=7289da)](https://discord.gg/xvNJW94)
-Lightly is a computer vision framework for self-supervised learning.
-> We, at [Lightly](https://www.lightly.ai), are passionate engineers who want to make deep learning more efficient. That's why - together with our community - we want to popularize the use of self-supervised methods to understand and curate raw image data. Our solution can be applied before any data annotation step and the learned representations can be used to visualize and analyze datasets. This allows to select the best set of samples for model training through advanced filtering.
+Lightly SSL is a computer vision framework for self-supervised learning.
-- [Homepage](https://www.lightly.ai)
-- [Web-App](https://app.lightly.ai)
- [Documentation](https://docs.lightly.ai/self-supervised-learning/)
-- [Lightly Solution Documentation (Lightly Worker & API)](https://docs.lightly.ai/)
- [Github](https://github.com/lightly-ai/lightly)
- [Discord](https://discord.gg/xvNJW94) (We have weekly paper sessions!)
+We've also built a whole platform on top, with additional features for active learning
+and [data curation](https://docs.lightly.ai/docs/what-is-lightly). If you're interested in the
+Lightly Worker Solution to easily process millions of samples and run [powerful algorithms](https://docs.lightly.ai/docs/selection)
+on your data, check out [lightly.ai](https://www.lightly.ai). It's free to get started!
+
## Features
-Lightly offers features like
+This self-supervised learning framework offers the following features:
-- Modular framework which exposes low-level building blocks such as loss functions and
+- Modular framework, which exposes low-level building blocks such as loss functions and
model heads.
- Easy to use and written in a PyTorch like style.
- Supports custom backbone models for self-supervised pre-training.
@@ -66,17 +68,6 @@ Want to jump to the tutorials and see Lightly in action?
- [Use Lightly with Custom Augmentations](https://docs.lightly.ai/self-supervised-learning/tutorials/package/tutorial_custom_augmentations.html)
- [Pre-train a Detectron2 Backbone with Lightly](https://docs.lightly.ai/self-supervised-learning/tutorials/package/tutorial_pretrain_detectron2.html)
-Tutorials for the Lightly Solution (Lightly Worker & API):
-
-- [General Docs of Lightly Solution](https://docs.lightly.ai)
-- [Active Learning Using YOLOv7 and Comma10k](https://docs.lightly.ai/docs/active-learning-yolov7)
-- [Active Learning for Driveable Area Segmentation Using Cityscapes](https://docs.lightly.ai/docs/active-learning-for-driveable-area-segmentation-using-cityscapes)
-- [Active Learning for Transactions of Images](https://docs.lightly.ai/docs/active-learning-for-transactions-of-images)
-- [Improving YOLOv8 using Active Learning on Videos](https://docs.lightly.ai/docs/active-learning-yolov8-video)
-- [Assertion-based Active Learning with YOLOv8](https://docs.lightly.ai/docs/assertion-based-active-learning-tutorial)
-- and more ...
-
-
Community and partner projects:
- [On-Device Deep Learning with Lightly on an ARM microcontroller](https://github.com/ARM-software/EndpointAI/tree/master/ProofOfConcepts/Vision/OpenMvMaskDefaults)
@@ -105,9 +96,6 @@ pip3 install lightly
We strongly recommend that you install Lightly in a dedicated virtualenv, to avoid
conflicting with your system packages.
-If you only want to install the API client without torch and torchvision dependencies
-follow the docs on [how to install the Lightly Python Client](https://docs.lightly.ai/docs/install-lightly#install-the-lightly-python-client).
-
### Lightly in Action
@@ -274,75 +262,48 @@ We provide multi-GPU training examples with distributed gather and synchronized
## Benchmarks
Implemented models and their performance on various datasets. Hyperparameters are not
-tuned for maximum accuracy. For detailed results and more info about the benchmarks click
+tuned for maximum accuracy. For detailed results and more information about the benchmarks click
[here](https://docs.lightly.ai/self-supervised-learning/getting_started/benchmarks.html).
-### Imagenet
+### ImageNet1k
+
+[ImageNet1k benchmarks](https://docs.lightly.ai/self-supervised-learning/getting_started/benchmarks.html#imagenet1k)
-> **Note**: Evaluation settings are based on these papers:
-> * Linear: [SimCLR](https://arxiv.org/abs/2002.05709)
-> * Finetune: [SimCLR](https://arxiv.org/abs/2002.05709)
-> * KNN: [InstDisc](https://arxiv.org/abs/1805.01978)
->
-> See the [benchmarking scripts](./benchmarks/imagenet/resnet50/) for details.
+**Note**: Evaluation settings are based on these papers:
+ * Linear: [SimCLR](https://arxiv.org/abs/2002.05709)
+ * Finetune: [SimCLR](https://arxiv.org/abs/2002.05709)
+ * KNN: [InstDisc](https://arxiv.org/abs/1805.01978)
+
+See the [benchmarking scripts](./benchmarks/imagenet/resnet50/) for details.
-| Model | Backbone | Batch Size | Epochs | Linear Top1 | Finetune Top1 | KNN Top1 | Tensorboard | Checkpoint |
+
+| Model | Backbone | Batch Size | Epochs | Linear Top1 | Finetune Top1 | kNN Top1 | Tensorboard | Checkpoint |
|----------------|----------|------------|--------|-------------|---------------|----------|-------------|------------|
+| BarlowTwins | Res50 | 256 | 100 | 62.9 | 72.6 | 45.6 | [link](https://tensorboard.dev/experiment/NxyNRiQsQjWZ82I9b0PvKg/) | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_barlowtwins_2023-08-18_00-11-03/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt) |
| BYOL | Res50 | 256 | 100 | 62.4 | 74.0 | 45.6 | [link](https://tensorboard.dev/experiment/Z0iG2JLaTJe5nuBD7DK1bg) | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_byol_2023-07-10_10-37-32/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt) |
| DINO | Res50 | 128 | 100 | 68.2 | 72.5 | 49.9 | [link](https://tensorboard.dev/experiment/DvKHX9sNSWWqDrRksllPLA) | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_dino_2023-06-06_13-59-48/pretrain/version_0/checkpoints/epoch%3D99-step%3D1000900.ckpt) |
| SimCLR* | Res50 | 256 | 100 | 63.2 | 73.9 | 44.8 | [link](https://tensorboard.dev/experiment/Ugol97adQdezgcVibDYMMA) | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_simclr_2023-06-22_09-11-13/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt) |
| SimCLR* + DCL | Res50 | 256 | 100 | 65.1 | 73.5 | 49.6 | [link](https://tensorboard.dev/experiment/k4ZonZ77QzmBkc0lXswQlg/) | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_dcl_2023-07-04_16-51-40/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt) |
| SimCLR* + DCLW | Res50 | 256 | 100 | 64.5 | 73.2 | 48.5 | [link](https://tensorboard.dev/experiment/TrALnpwFQ4OkZV3uvaX7wQ/) | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_dclw_2023-07-07_14-57-13/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt) |
| SwAV | Res50 | 256 | 100 | 67.2 | 75.4 | 49.5 | [link](https://tensorboard.dev/experiment/Ipx4Oxl5Qkqm5Sl5kWyKKg) | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_swav_2023-05-25_08-29-14/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt) |
+| VICReg | Res50 | 256 | 100 | 63.0 | 73.7 | 46.3 | [link](https://tensorboard.dev/experiment/qH5uywJbTJSzgCEfxc7yUw) | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_vicreg_2023-09-11_10-53-08/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt) |
*\*We use square root learning rate scaling instead of linear scaling as it yields
-better results for smaller batch sizes. See Appendix B.1 in [SimCLR paper](https://arxiv.org/abs/2002.05709).*
-
-
-
-### ImageNette
-
-| Model | Backbone | Batch Size | Epochs | KNN Top1 |
-|-------------|----------|------------|--------|----------|
-| BarlowTwins | Res18 | 256 | 800 | 0.852 |
-| BYOL | Res18 | 256 | 800 | 0.887 |
-| DCL | Res18 | 256 | 800 | 0.861 |
-| DCLW | Res18 | 256 | 800 | 0.865 |
-| DINO | Res18 | 256 | 800 | 0.888 |
-| FastSiam | Res18 | 256 | 800 | 0.873 |
-| MAE | ViT-S | 256 | 800 | 0.610 |
-| MSN | ViT-S | 256 | 800 | 0.828 |
-| Moco | Res18 | 256 | 800 | 0.874 |
-| NNCLR | Res18 | 256 | 800 | 0.884 |
-| PMSN | ViT-S | 256 | 800 | 0.822 |
-| SimCLR | Res18 | 256 | 800 | 0.889 |
-| SimMIM | ViT-B32 | 256 | 800 | 0.343 |
-| SimSiam | Res18 | 256 | 800 | 0.872 |
-| SwaV | Res18 | 256 | 800 | 0.902 |
-| SwaVQueue | Res18 | 256 | 800 | 0.890 |
-| SMoG | Res18 | 256 | 800 | 0.788 |
-| TiCo | Res18 | 256 | 800 | 0.856 |
-| VICReg | Res18 | 256 | 800 | 0.845 |
-| VICRegL | Res18 | 256 | 800 | 0.778 |
-
-
-### Cifar10
-
-| Model | Backbone | Batch Size | Epochs | KNN Top1 |
-|-------------|----------|------------|--------|----------|
-| BarlowTwins | Res18 | 512 | 800 | 0.859 |
-| BYOL | Res18 | 512 | 800 | 0.910 |
-| DCL | Res18 | 512 | 800 | 0.874 |
-| DCLW | Res18 | 512 | 800 | 0.871 |
-| DINO | Res18 | 512 | 800 | 0.848 |
-| FastSiam | Res18 | 512 | 800 | 0.902 |
-| Moco | Res18 | 512 | 800 | 0.899 |
-| NNCLR | Res18 | 512 | 800 | 0.892 |
-| SimCLR | Res18 | 512 | 800 | 0.879 |
-| SimSiam | Res18 | 512 | 800 | 0.904 |
-| SwaV | Res18 | 512 | 800 | 0.884 |
-| SMoG | Res18 | 512 | 800 | 0.800 |
+better results for smaller batch sizes. See Appendix B.1 in the [SimCLR paper](https://arxiv.org/abs/2002.05709).*
+
+### ImageNet100
+[ImageNet100 benchmarks](https://docs.lightly.ai/self-supervised-learning/getting_started/benchmarks.html#imagenet100)
+
+
+### Imagenette
+
+[Imagenette benchmarks](https://docs.lightly.ai/self-supervised-learning/getting_started/benchmarks.html#imagenette)
+
+
+### CIFAR-10
+
+[CIFAR-10 benchmarks](https://docs.lightly.ai/self-supervised-learning/getting_started/benchmarks.html#cifar-10)
## Terminology
@@ -355,7 +316,7 @@ The terms in bold are explained in more detail in our [documentation](https://do
### Next Steps
-Head to the [documentation](https://docs.lightly.ai) and see the things you can achieve with Lightly!
+Head to the [documentation](https://docs.lightly.ai/self-supervised-learning/) and see the things you can achieve with Lightly!
## Development
@@ -438,6 +399,15 @@ make format
- [Decoupled Contrastive Learning, 2021](https://arxiv.org/abs/2110.06848)
- [solo-learn: A Library of Self-supervised Methods for Visual Representation Learning, 2021](https://www.jmlr.org/papers/volume23/21-1155/21-1155.pdf)
+## Company behind this Open Source Framework
+[Lightly](https://www.lightly.ai) is a spin-off from ETH Zurich that helps companies
+build efficient active learning pipelines to select the most relevant data for their models.
+
+You can find out more about the company and its services by following the links below:
+
+- [Homepage](https://www.lightly.ai)
+- [Web-App](https://app.lightly.ai)
+- [Lightly Solution Documentation (Lightly Worker & API)](https://docs.lightly.ai/)
## BibTeX
If you want to cite the framework feel free to use this:
diff --git a/benchmarks/imagenet/resnet50/barlowtwins.py b/benchmarks/imagenet/resnet50/barlowtwins.py
new file mode 100644
index 000000000..ec2e5f72e
--- /dev/null
+++ b/benchmarks/imagenet/resnet50/barlowtwins.py
@@ -0,0 +1,112 @@
+import copy
+from typing import List, Tuple
+
+import torch
+from pytorch_lightning import LightningModule
+from torch import Tensor
+from torch.nn import Identity
+from torchvision.models import resnet50
+
+from lightly.loss import BarlowTwinsLoss
+from lightly.models.modules import BarlowTwinsProjectionHead
+from lightly.models.utils import get_weight_decay_parameters
+from lightly.transforms import BYOLTransform
+from lightly.utils.benchmarking import OnlineLinearClassifier
+from lightly.utils.lars import LARS
+from lightly.utils.scheduler import CosineWarmupScheduler
+
+
+class BarlowTwins(LightningModule):
+    def __init__(self, batch_size_per_device: int, num_classes: int) -> None:
+        super().__init__()
+        self.save_hyperparameters()
+        self.batch_size_per_device = batch_size_per_device
+
+        resnet = resnet50()
+        resnet.fc = Identity()  # Ignore classification head
+        self.backbone = resnet
+        self.projection_head = BarlowTwinsProjectionHead()
+        self.criterion = BarlowTwinsLoss(lambda_param=5e-3, gather_distributed=True)
+
+        self.online_classifier = OnlineLinearClassifier(num_classes=num_classes)
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self.backbone(x)
+
+    def training_step(
+        self, batch: Tuple[List[Tensor], Tensor, List[str]], batch_idx: int
+    ) -> Tensor:
+        # Forward pass and loss calculation.
+        views, targets = batch[0], batch[1]
+        features = self.forward(torch.cat(views)).flatten(start_dim=1)
+        z = self.projection_head(features)
+        z0, z1 = z.chunk(len(views))
+        loss = self.criterion(z0, z1)
+
+        self.log(
+            "train_loss", loss, prog_bar=True, sync_dist=True, batch_size=len(targets)
+        )
+
+        # Online linear evaluation.
+        cls_loss, cls_log = self.online_classifier.training_step(
+            (features.detach(), targets.repeat(len(views))), batch_idx
+        )
+        self.log_dict(cls_log, sync_dist=True, batch_size=len(targets))
+        return loss + cls_loss
+
+    def validation_step(
+        self, batch: Tuple[Tensor, Tensor, List[str]], batch_idx: int
+    ) -> Tensor:
+        images, targets = batch[0], batch[1]
+        features = self.forward(images).flatten(start_dim=1)
+        cls_loss, cls_log = self.online_classifier.validation_step(
+            (features.detach(), targets), batch_idx
+        )
+        self.log_dict(cls_log, prog_bar=True, sync_dist=True, batch_size=len(targets))
+        return cls_loss
+
+    def configure_optimizers(self):
+        lr_factor = self.batch_size_per_device * self.trainer.world_size / 256
+
+        # Don't use weight decay for batch norm, bias parameters, and classification
+        # head to improve performance.
+        params, params_no_weight_decay = get_weight_decay_parameters(
+            [self.backbone, self.projection_head]
+        )
+        optimizer = LARS(
+            [
+                {"name": "barlowtwins", "params": params},
+                {
+                    "name": "barlowtwins_no_weight_decay",
+                    "params": params_no_weight_decay,
+                    "weight_decay": 0.0,
+                    "lr": 0.0048 * lr_factor,
+                },
+                {
+                    "name": "online_classifier",
+                    "params": self.online_classifier.parameters(),
+                    "weight_decay": 0.0,
+                },
+            ],
+            lr=0.2 * lr_factor,
+            momentum=0.9,
+            weight_decay=1.5e-6,
+        )
+
+        scheduler = {
+            "scheduler": CosineWarmupScheduler(
+                optimizer=optimizer,
+                warmup_epochs=int(
+                    self.trainer.estimated_stepping_batches
+                    / self.trainer.max_epochs
+                    * 10
+                ),
+                max_epochs=int(self.trainer.estimated_stepping_batches),
+            ),
+            "interval": "step",
+        }
+        return [optimizer], [scheduler]
+
+
+# BarlowTwins uses same transform as BYOL.
+transform = BYOLTransform()
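
As a usage sketch (not part of the diff): the module above trains with a plain PyTorch Lightning loop. The dataset path and trainer settings below are assumptions; in this benchmark suite the real entry point is `main.py`:

```python
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader

from lightly.data import LightlyDataset

model = BarlowTwins(batch_size_per_device=256, num_classes=1000)
# LightlyDataset yields (views, label, filename); BYOLTransform produces two views.
dataset = LightlyDataset("datasets/imagenet/train", transform=transform)
dataloader = DataLoader(dataset, batch_size=256, shuffle=True, drop_last=True, num_workers=8)
Trainer(max_epochs=100, accelerator="gpu", devices=1, precision="16-mixed").fit(model, dataloader)
```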
diff --git a/benchmarks/imagenet/resnet50/byol.py b/benchmarks/imagenet/resnet50/byol.py
index bfc2f1080..27d550dc6 100644
--- a/benchmarks/imagenet/resnet50/byol.py
+++ b/benchmarks/imagenet/resnet50/byol.py
@@ -10,7 +10,7 @@
 from lightly.loss import NegativeCosineSimilarity
 from lightly.models.modules import BYOLPredictionHead, BYOLProjectionHead
 from lightly.models.utils import get_weight_decay_parameters, update_momentum
-from lightly.transforms import SimCLRTransform
+from lightly.transforms import BYOLTransform
 from lightly.utils.benchmarking import OnlineLinearClassifier
 from lightly.utils.lars import LARS
 from lightly.utils.scheduler import CosineWarmupScheduler, cosine_schedule
@@ -132,17 +132,19 @@ def configure_optimizers(self):
         scheduler = {
             "scheduler": CosineWarmupScheduler(
                 optimizer=optimizer,
-                warmup_epochs=(
+                warmup_epochs=int(
                     self.trainer.estimated_stepping_batches
                     / self.trainer.max_epochs
                     * 10
                 ),
-                max_epochs=self.trainer.estimated_stepping_batches,
+                max_epochs=int(self.trainer.estimated_stepping_batches),
             ),
             "interval": "step",
         }
         return [optimizer], [scheduler]
 
 
-# BYOL uses same transform as SimCLR.
-transform = SimCLRTransform()
+# BYOL uses a slight modification of the SimCLR transforms.
+# It uses asymmetric augmentation and solarize.
+# Check table 6 in the BYOL paper for more info.
+transform = BYOLTransform()
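
The new comment is worth unpacking: unlike SimCLR, BYOL augments the two views with different parameters (e.g. different blur and solarization probabilities, Table 6 in the BYOL paper). A sketch making that asymmetry explicit, assuming the per-view transforms are exported alongside `BYOLTransform` and that `input_size=224`:

```python
from lightly.transforms import BYOLTransform, BYOLView1Transform, BYOLView2Transform

# Equivalent to BYOLTransform() with default arguments: the two views use
# different gaussian blur / solarization probabilities.
transform = BYOLTransform(
    view_1_transform=BYOLView1Transform(input_size=224),
    view_2_transform=BYOLView2Transform(input_size=224),
)
```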
diff --git a/benchmarks/imagenet/resnet50/dcl.py b/benchmarks/imagenet/resnet50/dcl.py
index 42c66c7c5..f08fdb157 100644
--- a/benchmarks/imagenet/resnet50/dcl.py
+++ b/benchmarks/imagenet/resnet50/dcl.py
@@ -97,12 +97,12 @@ def configure_optimizers(self):
         scheduler = {
             "scheduler": CosineWarmupScheduler(
                 optimizer=optimizer,
-                warmup_epochs=(
+                warmup_epochs=int(
                     self.trainer.estimated_stepping_batches
                     / self.trainer.max_epochs
                     * 10
                 ),
-                max_epochs=self.trainer.estimated_stepping_batches,
+                max_epochs=int(self.trainer.estimated_stepping_batches),
             ),
             "interval": "step",
         }
diff --git a/benchmarks/imagenet/resnet50/dclw.py b/benchmarks/imagenet/resnet50/dclw.py
index bcae95d6e..6f0bb7e54 100644
--- a/benchmarks/imagenet/resnet50/dclw.py
+++ b/benchmarks/imagenet/resnet50/dclw.py
@@ -97,12 +97,12 @@ def configure_optimizers(self):
         scheduler = {
             "scheduler": CosineWarmupScheduler(
                 optimizer=optimizer,
-                warmup_epochs=(
+                warmup_epochs=int(
                     self.trainer.estimated_stepping_batches
                     / self.trainer.max_epochs
                     * 10
                 ),
-                max_epochs=self.trainer.estimated_stepping_batches,
+                max_epochs=int(self.trainer.estimated_stepping_batches),
             ),
             "interval": "step",
         }
diff --git a/benchmarks/imagenet/resnet50/dino.py b/benchmarks/imagenet/resnet50/dino.py
index d7924c9b5..a440ce147 100644
--- a/benchmarks/imagenet/resnet50/dino.py
+++ b/benchmarks/imagenet/resnet50/dino.py
@@ -136,12 +136,12 @@ def configure_optimizers(self):
         scheduler = {
             "scheduler": CosineWarmupScheduler(
                 optimizer=optimizer,
-                warmup_epochs=(
+                warmup_epochs=int(
                     self.trainer.estimated_stepping_batches
                     / self.trainer.max_epochs
                     * 10
                 ),
-                max_epochs=self.trainer.estimated_stepping_batches,
+                max_epochs=int(self.trainer.estimated_stepping_batches),
             ),
             "interval": "step",
         }
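
The recurring `warmup_epochs=int(...)` expression appears in several of these files because the scheduler steps once per optimizer step (`"interval": "step"`), so a 10-epoch warmup must be expressed in steps. A worked example (total step count borrowed from the checkpoint names in the README; 100 epochs assumed):

```python
# estimated_stepping_batches = total optimizer steps over the whole run.
estimated_stepping_batches, max_epochs = 500_400, 100
warmup_steps = int(estimated_stepping_batches / max_epochs * 10)
print(warmup_steps)  # 50040 steps, i.e. exactly 10 epochs of warmup
```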
diff --git a/benchmarks/imagenet/resnet50/finetune_eval.py b/benchmarks/imagenet/resnet50/finetune_eval.py
index 466f73f1c..9db6f0f88 100644
--- a/benchmarks/imagenet/resnet50/finetune_eval.py
+++ b/benchmarks/imagenet/resnet50/finetune_eval.py
@@ -11,6 +11,7 @@
 from lightly.data import LightlyDataset
 from lightly.transforms.utils import IMAGENET_NORMALIZE
 from lightly.utils.benchmarking import LinearClassifier, MetricCallback
+from lightly.utils.dist import print_rank_zero
 from lightly.utils.scheduler import CosineWarmupScheduler
 
 
@@ -63,7 +64,7 @@ def finetune_eval(
     References:
         - [0]: SimCLR, 2020, https://arxiv.org/abs/2002.05709
     """
-    print("Running fine-tune evaluation...")
+    print_rank_zero("Running fine-tune evaluation...")
 
     # Setup training data.
     train_transform = T.Compose(
@@ -81,7 +82,7 @@ def finetune_eval(
         shuffle=True,
         num_workers=num_workers,
         drop_last=True,
-        persistent_workers=True,
+        persistent_workers=False,
     )
 
     # Setup validation data.
@@ -99,7 +100,7 @@ def finetune_eval(
         batch_size=batch_size_per_device,
         shuffle=False,
         num_workers=num_workers,
-        persistent_workers=True,
+        persistent_workers=False,
     )
 
     # Train linear classifier.
@@ -116,6 +117,7 @@ def finetune_eval(
         logger=TensorBoardLogger(save_dir=str(log_dir), name="finetune_eval"),
         precision=precision,
         strategy="ddp_find_unused_parameters_true",
+        num_sanity_val_steps=0,
     )
     classifier = FinetuneLinearClassifier(
         model=model,
@@ -130,4 +132,6 @@ def finetune_eval(
         val_dataloaders=val_dataloader,
     )
     for metric in ["val_top1", "val_top5"]:
-        print(f"max finetune {metric}: {max(metric_callback.val_metrics[metric])}")
+        print_rank_zero(
+            f"max finetune {metric}: {max(metric_callback.val_metrics[metric])}"
+        )
diff --git a/benchmarks/imagenet/resnet50/knn_eval.py b/benchmarks/imagenet/resnet50/knn_eval.py
index 51017c0c1..62d7501a8 100644
--- a/benchmarks/imagenet/resnet50/knn_eval.py
+++ b/benchmarks/imagenet/resnet50/knn_eval.py
@@ -10,6 +10,7 @@
 from lightly.data import LightlyDataset
 from lightly.transforms.utils import IMAGENET_NORMALIZE
 from lightly.utils.benchmarking import KNNClassifier, MetricCallback
+from lightly.utils.dist import print_rank_zero
 
 
 def knn_eval(
@@ -34,7 +35,7 @@ def knn_eval(
     References:
         - [0]: InstDisc, 2018, https://arxiv.org/abs/1805.01978
     """
-    print("Running KNN evaluation...")
+    print_rank_zero("Running KNN evaluation...")
 
     # Setup training data.
     transform = T.Compose(
@@ -81,6 +82,7 @@ def knn_eval(
             metric_callback,
         ],
         strategy="ddp_find_unused_parameters_true",
+        num_sanity_val_steps=0,
     )
     trainer.fit(
         model=classifier,
@@ -88,4 +90,4 @@ def knn_eval(
         val_dataloaders=val_dataloader,
     )
     for metric in ["val_top1", "val_top5"]:
-        print(f"knn {metric}: {max(metric_callback.val_metrics[metric])}")
+        print_rank_zero(f"knn {metric}: {max(metric_callback.val_metrics[metric])}")
diff --git a/benchmarks/imagenet/resnet50/linear_eval.py b/benchmarks/imagenet/resnet50/linear_eval.py
index 21a4aef48..08af6b8c0 100644
--- a/benchmarks/imagenet/resnet50/linear_eval.py
+++ b/benchmarks/imagenet/resnet50/linear_eval.py
@@ -10,6 +10,7 @@
 from lightly.data import LightlyDataset
 from lightly.transforms.utils import IMAGENET_NORMALIZE
 from lightly.utils.benchmarking import LinearClassifier, MetricCallback
+from lightly.utils.dist import print_rank_zero
 
 
 def linear_eval(
@@ -40,7 +41,7 @@ def linear_eval(
     References:
         - [0]: SimCLR, 2020, https://arxiv.org/abs/2002.05709
     """
-    print("Running linear evaluation...")
+    print_rank_zero("Running linear evaluation...")
 
     # Setup training data.
     train_transform = T.Compose(
@@ -58,7 +59,7 @@ def linear_eval(
         shuffle=True,
         num_workers=num_workers,
         drop_last=True,
-        persistent_workers=True,
+        persistent_workers=False,
     )
 
     # Setup validation data.
@@ -76,7 +77,7 @@ def linear_eval(
         batch_size=batch_size_per_device,
         shuffle=False,
         num_workers=num_workers,
-        persistent_workers=True,
+        persistent_workers=False,
     )
 
     # Train linear classifier.
@@ -93,6 +94,7 @@ def linear_eval(
         logger=TensorBoardLogger(save_dir=str(log_dir), name="linear_eval"),
         precision=precision,
         strategy="ddp_find_unused_parameters_true",
+        num_sanity_val_steps=0,
     )
     classifier = LinearClassifier(
         model=model,
@@ -107,4 +109,6 @@ def linear_eval(
         val_dataloaders=val_dataloader,
     )
     for metric in ["val_top1", "val_top5"]:
-        print(f"max linear {metric}: {max(metric_callback.val_metrics[metric])}")
+        print_rank_zero(
+            f"max linear {metric}: {max(metric_callback.val_metrics[metric])}"
+        )
diff --git a/benchmarks/imagenet/resnet50/main.py b/benchmarks/imagenet/resnet50/main.py
index ce1e9e40b..505d4412c 100644
--- a/benchmarks/imagenet/resnet50/main.py
+++ b/benchmarks/imagenet/resnet50/main.py
@@ -3,6 +3,7 @@
 from pathlib import Path
 from typing import Sequence, Union
 
+import barlowtwins
 import byol
 import dcl
 import dclw
@@ -14,6 +15,7 @@
 import simclr
 import swav
 import torch
+import vicreg
 from pytorch_lightning import LightningModule, Trainer
 from pytorch_lightning.callbacks import (
     DeviceStatsMonitor,
@@ -27,6 +29,7 @@
 from lightly.data import LightlyDataset
 from lightly.transforms.utils import IMAGENET_NORMALIZE
 from lightly.utils.benchmarking import MetricCallback
+from lightly.utils.dist import print_rank_zero
 
 parser = ArgumentParser("ImageNet ResNet50 Benchmarks")
 parser.add_argument("--train-dir", type=Path, default="/datasets/imagenet/train")
@@ -38,6 +41,7 @@
 parser.add_argument("--accelerator", type=str, default="gpu")
 parser.add_argument("--devices", type=int, default=1)
 parser.add_argument("--precision", type=str, default="16-mixed")
+parser.add_argument("--ckpt-path", type=Path, default=None)
 parser.add_argument("--compile-model", action="store_true")
 parser.add_argument("--methods", type=str, nargs="+")
 parser.add_argument("--num-classes", type=int, default=1000)
@@ -46,6 +50,10 @@
 parser.add_argument("--skip-finetune-eval", action="store_true")
 
 
 METHODS = {
+    "barlowtwins": {
+        "model": barlowtwins.BarlowTwins,
+        "transform": barlowtwins.transform,
+    },
     "byol": {"model": byol.BYOL, "transform": byol.transform},
     "dcl": {"model": dcl.DCL, "transform": dcl.transform},
     "dclw": {"model": dclw.DCLW, "transform": dclw.transform},
@@ -53,6 +61,7 @@
     "mocov2": {"model": mocov2.MoCoV2, "transform": mocov2.transform},
     "simclr": {"model": simclr.SimCLR, "transform": simclr.transform},
     "swav": {"model": swav.SwAV, "transform": swav.transform},
+    "vicreg": {"model": vicreg.VICReg, "transform": vicreg.transform},
 }
 
 
@@ -72,6 +81,7 @@ def main(
     skip_knn_eval: bool,
     skip_linear_eval: bool,
     skip_finetune_eval: bool,
+    ckpt_path: Union[Path, None],
 ) -> None:
     torch.set_float32_matmul_precision("high")
 
@@ -87,11 +97,13 @@ def main(
     if compile_model and hasattr(torch, "compile"):
         # Compile model if PyTorch supports it.
-        print("Compiling model...")
+        print_rank_zero("Compiling model...")
         model = torch.compile(model)
 
     if epochs <= 0:
-        print("Epochs <= 0, skipping pretraining.")
+        print_rank_zero("Epochs <= 0, skipping pretraining.")
+        if ckpt_path is not None:
+            model.load_state_dict(torch.load(ckpt_path)["state_dict"])
     else:
         pretrain(
             model=model,
@@ -105,10 +117,11 @@ def main(
             accelerator=accelerator,
             devices=devices,
            precision=precision,
+            ckpt_path=ckpt_path,
         )
 
     if skip_knn_eval:
-        print("Skipping KNN eval.")
+        print_rank_zero("Skipping KNN eval.")
     else:
         knn_eval.knn_eval(
             model=model,
@@ -123,7 +136,7 @@ def main(
         )
 
     if skip_linear_eval:
-        print("Skipping linear eval.")
+        print_rank_zero("Skipping linear eval.")
     else:
         linear_eval.linear_eval(
             model=model,
@@ -139,7 +152,7 @@ def main(
         )
 
     if skip_finetune_eval:
-        print("Skipping fine-tune eval.")
+        print_rank_zero("Skipping fine-tune eval.")
     else:
         finetune_eval.finetune_eval(
             model=model,
@@ -167,8 +180,9 @@ def pretrain(
     accelerator: str,
     devices: int,
     precision: str,
+    ckpt_path: Union[Path, None],
 ) -> None:
-    print(f"Running pretraining for {method}...")
+    print_rank_zero(f"Running pretraining for {method}...")
 
     # Setup training data.
     train_transform = METHODS[method]["transform"]
@@ -179,7 +193,7 @@ def pretrain(
         shuffle=True,
         num_workers=num_workers,
         drop_last=True,
-        persistent_workers=True,
+        persistent_workers=False,
     )
 
     # Setup validation data.
@@ -197,7 +211,7 @@ def pretrain(
         batch_size=batch_size_per_device,
         shuffle=False,
         num_workers=num_workers,
-        persistent_workers=True,
+        persistent_workers=False,
     )
 
     # Train model.
@@ -217,14 +231,17 @@ def pretrain(
         precision=precision,
         strategy="ddp_find_unused_parameters_true",
         sync_batchnorm=True,
+        num_sanity_val_steps=0,
     )
+
     trainer.fit(
         model=model,
         train_dataloaders=train_dataloader,
         val_dataloaders=val_dataloader,
+        ckpt_path=ckpt_path,
     )
     for metric in ["val_online_cls_top1", "val_online_cls_top5"]:
-        print(f"max {metric}: {max(metric_callback.val_metrics[metric])}")
+        print_rank_zero(f"max {metric}: {max(metric_callback.val_metrics[metric])}")
 
 
 if __name__ == "__main__":
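
The new `--ckpt-path` flag covers two cases: resuming an interrupted pretraining run (it is forwarded to `trainer.fit`), or, combined with `--epochs 0`, loading pretrained weights and skipping straight to the evaluations. A hedged sketch of the second case (the checkpoint path is invented):

```python
import torch
from vicreg import VICReg  # any entry in METHODS works the same way

model = VICReg(batch_size_per_device=128, num_classes=1000)
# What main() does when epochs <= 0 and a checkpoint path is given:
checkpoint = torch.load("checkpoints/epoch=99-step=500400.ckpt", map_location="cpu")
model.load_state_dict(checkpoint["state_dict"])
```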
diff --git a/benchmarks/imagenet/resnet50/simclr.py b/benchmarks/imagenet/resnet50/simclr.py
index de6958638..6562025ae 100644
--- a/benchmarks/imagenet/resnet50/simclr.py
+++ b/benchmarks/imagenet/resnet50/simclr.py
@@ -96,12 +96,12 @@ def configure_optimizers(self):
         scheduler = {
             "scheduler": CosineWarmupScheduler(
                 optimizer=optimizer,
-                warmup_epochs=(
+                warmup_epochs=int(
                     self.trainer.estimated_stepping_batches
                     / self.trainer.max_epochs
                     * 10
                 ),
-                max_epochs=self.trainer.estimated_stepping_batches,
+                max_epochs=int(self.trainer.estimated_stepping_batches),
             ),
             "interval": "step",
         }
diff --git a/benchmarks/imagenet/resnet50/swav.py b/benchmarks/imagenet/resnet50/swav.py
index e3117e4f8..96eb7a1fb 100644
--- a/benchmarks/imagenet/resnet50/swav.py
+++ b/benchmarks/imagenet/resnet50/swav.py
@@ -156,12 +156,12 @@ def configure_optimizers(self):
         scheduler = {
             "scheduler": CosineWarmupScheduler(
                 optimizer=optimizer,
-                warmup_epochs=(
+                warmup_epochs=int(
                     self.trainer.estimated_stepping_batches
                     / self.trainer.max_epochs
                     * 10
                 ),
-                max_epochs=self.trainer.estimated_stepping_batches,
+                max_epochs=int(self.trainer.estimated_stepping_batches),
                 end_value=0.0006
                 * (self.batch_size_per_device * self.trainer.world_size)
                 / 256,
diff --git a/benchmarks/imagenet/resnet50/vicreg.py b/benchmarks/imagenet/resnet50/vicreg.py
new file mode 100644
index 000000000..0da61677e
--- /dev/null
+++ b/benchmarks/imagenet/resnet50/vicreg.py
@@ -0,0 +1,127 @@
+from typing import List, Tuple
+
+import torch
+from pytorch_lightning import LightningModule
+from torch import Tensor
+from torch.nn import Identity
+from torchvision.models import resnet50
+
+from lightly.loss.vicreg_loss import VICRegLoss
+from lightly.models.modules.heads import VICRegProjectionHead
+from lightly.models.utils import get_weight_decay_parameters
+from lightly.transforms.vicreg_transform import VICRegTransform
+from lightly.utils.benchmarking import OnlineLinearClassifier
+from lightly.utils.lars import LARS
+from lightly.utils.scheduler import CosineWarmupScheduler
+
+
+class VICReg(LightningModule):
+ def __init__(self, batch_size_per_device: int, num_classes: int) -> None:
+ super().__init__()
+ self.save_hyperparameters()
+ self.batch_size_per_device = batch_size_per_device
+
+ resnet = resnet50()
+ resnet.fc = Identity() # Ignore classification head
+ self.backbone = resnet
+ self.projection_head = VICRegProjectionHead(num_layers=2)
+ self.criterion = VICRegLoss()
+
+ self.online_classifier = OnlineLinearClassifier(num_classes=num_classes)
+
+ def forward(self, x: Tensor) -> Tensor:
+ return self.backbone(x)
+
+ def training_step(
+ self, batch: Tuple[List[Tensor], Tensor, List[str]], batch_idx: int
+ ) -> Tensor:
+ views, targets = batch[0], batch[1]
+ features = self.forward(torch.cat(views)).flatten(start_dim=1)
+ z = self.projection_head(features)
+ z_a, z_b = z.chunk(len(views))
+ loss = self.criterion(z_a=z_a, z_b=z_b)
+ self.log(
+ "train_loss", loss, prog_bar=True, sync_dist=True, batch_size=len(targets)
+ )
+
+ # Online linear evaluation.
+ cls_loss, cls_log = self.online_classifier.training_step(
+ (features.detach(), targets.repeat(len(views))), batch_idx
+ )
+
+ self.log_dict(cls_log, sync_dist=True, batch_size=len(targets))
+ return loss + cls_loss
+
+ def validation_step(
+ self, batch: Tuple[Tensor, Tensor, List[str]], batch_idx: int
+ ) -> Tensor:
+ images, targets = batch[0], batch[1]
+ features = self.forward(images).flatten(start_dim=1)
+ cls_loss, cls_log = self.online_classifier.validation_step(
+ (features.detach(), targets), batch_idx
+ )
+ self.log_dict(cls_log, prog_bar=True, sync_dist=True, batch_size=len(targets))
+ return cls_loss
+
+ def configure_optimizers(self):
+ # Don't use weight decay for batch norm, bias parameters, and classification
+ # head to improve performance.
+ params, params_no_weight_decay = get_weight_decay_parameters(
+ [self.backbone, self.projection_head]
+ )
+ global_batch_size = self.batch_size_per_device * self.trainer.world_size
+ base_lr = _get_base_learning_rate(global_batch_size=global_batch_size)
+ optimizer = LARS(
+ [
+ {"name": "vicreg", "params": params},
+ {
+ "name": "vicreg_no_weight_decay",
+ "params": params_no_weight_decay,
+ "weight_decay": 0.0,
+ },
+ {
+ "name": "online_classifier",
+ "params": self.online_classifier.parameters(),
+ "weight_decay": 0.0,
+ },
+ ],
+ # Linear learning rate scaling with a base learning rate of 0.2.
+ # See https://arxiv.org/pdf/2105.04906.pdf for details.
+ lr=base_lr * global_batch_size / 256,
+ momentum=0.9,
+ weight_decay=1e-6,
+ )
+ scheduler = {
+ "scheduler": CosineWarmupScheduler(
+ optimizer=optimizer,
+ warmup_epochs=(
+ self.trainer.estimated_stepping_batches
+ / self.trainer.max_epochs
+ * 10
+ ),
+ max_epochs=self.trainer.estimated_stepping_batches,
+ end_value=0.01, # Scale base learning rate from 0.2 to 0.002.
+ ),
+ "interval": "step",
+ }
+ return [optimizer], [scheduler]
+
+
+# VICReg transform
+transform = VICRegTransform()
+
+
+def _get_base_learning_rate(global_batch_size: int) -> float:
+ """Returns the base learning rate for training 100 epochs with a given batch size.
+
+ This follows section C.4 in https://arxiv.org/pdf/2105.04906.pdf.
+
+ """
+ if global_batch_size == 128:
+ return 0.8
+ elif global_batch_size == 256:
+ return 0.5
+ elif global_batch_size == 512:
+ return 0.4
+ else:
+ return 0.3
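
To make the learning-rate arithmetic above concrete, here is the effective LARS learning rate for a hypothetical 4-GPU run (device count and per-device batch size are assumptions):

```python
batch_size_per_device, world_size = 256, 4
global_batch_size = batch_size_per_device * world_size  # 1024
base_lr = _get_base_learning_rate(global_batch_size=global_batch_size)  # 0.3 (fallback branch)
lr = base_lr * global_batch_size / 256  # linear scaling, section C.4 of the VICReg paper
print(lr)  # 1.2
```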
diff --git a/docs/Makefile b/docs/Makefile
index 95f708852..5900e0fe6 100644
--- a/docs/Makefile
+++ b/docs/Makefile
@@ -60,20 +60,6 @@ download-noplot:
 		unzip $(ZIPOPTS) $(DATADIR)/resources.zip -d $(DOCKERSOURCE);\
 		unzip $(ZIPOPTS) $(DATADIR)/resources.zip -d $(DOCKER_ARCHIVE_SOURCE); \
 
-	# pizza dataset
-	@if [ ! -d $(PLATFORMSOURCE)/pizzas/salami ]; then \
-		wget -N https://storage.googleapis.com/datasets_boris/pizzas.zip -P $(DATADIR);\
-		unzip $(ZIPOPTS) $(DATADIR)/pizzas.zip -d $(PLATFORMSOURCE);\
-		mkdir -p $(PLATFORMSOURCE)/pizzas/margherita;\
-		mkdir -p $(PLATFORMSOURCE)/pizzas/salami;\
-		mv $(PLATFORMSOURCE)/pizzas/margherita*.jpg $(PLATFORMSOURCE)/pizzas/margherita;\
-		mv $(PLATFORMSOURCE)/pizzas/salami*.jpg $(PLATFORMSOURCE)/pizzas/salami;\
-	fi
-
-	# sunflowers dataset
-	@if [ ! -f $(DATADIR)/Sunflowers.zip ]; then \
-		wget -N https://storage.googleapis.com/datasets_boris/Sunflowers.zip -P $(DATADIR);\
-	fi
 
 # Download also the datasets needed for the tutorials
 download: download-noplot
diff --git a/docs/logos/lightly_SSL_logo_crop.png b/docs/logos/lightly_SSL_logo_crop.png
new file mode 100644
index 000000000..62028eaf2
Binary files /dev/null and b/docs/logos/lightly_SSL_logo_crop.png differ
diff --git a/docs/logos/lightly_SSL_logo_crop_white_text.png b/docs/logos/lightly_SSL_logo_crop_white_text.png
new file mode 100644
index 000000000..90733d7d2
Binary files /dev/null and b/docs/logos/lightly_SSL_logo_crop_white_text.png differ
diff --git a/docs/source/_templates/footer.html b/docs/source/_templates/footer.html
index 86060f5ef..217e11623 100644
--- a/docs/source/_templates/footer.html
+++ b/docs/source/_templates/footer.html
@@ -26,7 +26,13 @@
{%- else %}
{% set copyright = copyright|e %}
- © {% trans %}Copyright{% endtrans %} {{ copyright_year }}, {{ copyright }}
+ © {% trans %}Copyright{% endtrans %} {{ copyright_year }}
+ | {{ copyright }}
+ |
+
+ Source Code
+
+ | Lightly Worker Solution documentation
{%- endif %}
{%- endif %}
diff --git a/docs/source/_templates/layout.html b/docs/source/_templates/layout.html
index bcf32e9f5..e39f8042f 100644
--- a/docs/source/_templates/layout.html
+++ b/docs/source/_templates/layout.html
@@ -8,6 +8,36 @@
We need this to override the footer
-->
{%- block content %}
+
+
+
+  Looking to easily do active learning on millions of samples? See our Lightly Worker docs.
+
{% if theme_style_external_links|tobool %}
{% else %}
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 0d649486c..eb2721c57 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -22,10 +22,10 @@
# -- Project information -----------------------------------------------------
project = "lightly"
-copyright_year = "2020"
+copyright_year = "2020-"
copyright = "Lightly AG"
website_url = "https://www.lightly.ai/"
-author = "Philipp Wirth, Igor Susmelj"
+author = "Lightly Team"
# The full version, including alpha/beta/rc tags
release = lightly.__version__
@@ -98,7 +98,7 @@
html_favicon = "favicon.png"
-html_logo = "../logos/lightly_logo_crop_white_text.png"
+html_logo = "../logos/lightly_SSL_logo_crop_white_text.png"
# Exposes variables so that they can be used by django
html_context = {
diff --git a/docs/source/docker/advanced/datapool.rst b/docs/source/docker/advanced/datapool.rst
index 6a1b6badb..389af3336 100644
--- a/docs/source/docker/advanced/datapool.rst
+++ b/docs/source/docker/advanced/datapool.rst
@@ -3,7 +3,7 @@
Datapool
=================
-Lightly has been designed in a way that you can incrementally build up a
+The Lightly Worker has been designed so that you can incrementally build up a
dataset for your project. The software automatically keeps track of the
representations of previously selected samples and uses this information
to pick new samples in order to maximize the quality of the final dataset.
@@ -43,7 +43,7 @@ has the following advantages:
If you want to search all data in your bucket for new samples
instead of only newly added data,
then set :code:`'datasource.process_all': True` in your worker config. This has the
-same effect as creating a new Lightly dataset and running the Lightly Worker from scratch
+same effect as creating a new dataset and running the Lightly Worker from scratch
on the full dataset. We process all data instead of only the newly added ones.
@@ -67,7 +67,7 @@ first time.
|-- passageway1-c1.avi
`-- terrace1-c0.avi
-Let's create a Lightly dataset which uses that bucket (choose your tab - S3, GCS or Azure):
+Let's create a dataset which uses that bucket (choose your tab - S3, GCS or Azure):
 .. tabs::
 
    .. tab:: AWS S3 Datasource
diff --git a/docs/source/docker/advanced/datasource_metadata.rst b/docs/source/docker/advanced/datasource_metadata.rst
index d16b4356d..da7288393 100644
--- a/docs/source/docker/advanced/datasource_metadata.rst
+++ b/docs/source/docker/advanced/datasource_metadata.rst
@@ -3,7 +3,7 @@
Add Metadata to a Datasource
===============================
-Lightly can make use of metadata collected alongside your images or videos. Provided,
+The Lightly Worker can make use of metadata collected alongside your images or videos. Provided
metadata can be used to steer the selection process and to analyze the selected dataset
in the Lightly Platform.
@@ -45,7 +45,7 @@ Metadata Schema
The schema defines the format of the metadata and helps the Lightly Platform to correctly identify
and display different types of metadata.
-You can provide this information to Lightly by adding a `schema.json` to the
+You can provide this information to the Lightly Worker by adding a `schema.json` to the
`.lightly/metadata` directory. The `schema.json` file must contain a list of
configuration entries. Each of the entries is a dictionary with the following keys:
@@ -105,9 +105,9 @@ of the images we have collected. A possible schema could look like this:
Metadata Files
--------------
-Lightly requires a single metadata file per image or video. If an image or video has no corresponding metadata file,
-Lightly assumes the default value from the `schema.json`. If a metadata file is provided for a full video,
-Lightly assumes that the metadata is valid for all frames in that video.
+The Lightly Worker requires a single metadata file per image or video. If an image or video has no corresponding metadata file,
+the Lightly Worker assumes the default value from the `schema.json`. If a metadata file is provided for a full video,
+the Lightly Worker assumes that the metadata is valid for all frames in that video.
To provide metadata for an image or a video, place a metadata file with the same name
as the image or video in the `.lightly/metadata` directory but change the file extension to
@@ -130,8 +130,8 @@ as the image or video in the `.lightly/metadata` directory but change the file e
When working with videos it's also possible to provide metadata on a per-frame basis.
-Then, Lightly requires a metadata file per frame. If a frame has no corresponding metadata file,
-Lightly assumes the default value from the `schema.json`. Lightly uses a naming convention to
+Then, the Lightly Worker requires a metadata file per frame. If a frame has no corresponding metadata file,
+the Lightly Worker assumes the default value from the `schema.json`. The Lightly Worker uses a naming convention to
identify frames: The filename of a frame consists of the video filename, the frame number
(padded to the length of the number of frames in the video), the video format separated
by hyphens. For example, for a video with 200 frames, the frame number will be padded
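
To make the convention concrete, a short sketch for the `terrace1-c0.avi` video from the datapool example (200 frames, so frame numbers are padded to three digits; the frame indices are arbitrary):

```python
video_name, video_format, num_frames = "terrace1-c0", "avi", 200
padding = len(str(num_frames))  # pad to the length of the frame count
for frame_number in (0, 7, 42):
    print(f".lightly/metadata/{video_name}-{frame_number:0{padding}d}-{video_format}.json")
# .lightly/metadata/terrace1-c0-000-avi.json
# .lightly/metadata/terrace1-c0-007-avi.json
# .lightly/metadata/terrace1-c0-042-avi.json
```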
diff --git a/docs/source/docker/advanced/datasource_predictions.rst b/docs/source/docker/advanced/datasource_predictions.rst
index ae2160525..584bc4d93 100644
--- a/docs/source/docker/advanced/datasource_predictions.rst
+++ b/docs/source/docker/advanced/datasource_predictions.rst
@@ -3,9 +3,9 @@
Add Predictions to a Datasource
===============================
-Lightly can not only use images you provided in a datasource, but also predictions of a ML model on your images.
+The Lightly Worker can not only use images you provided in a datasource, but also predictions of an ML model on your images.
They are used for active learning for selecting images based on the objects in them.
-Furthermore, object detection predictions can be used running Lightly on object level.
+Furthermore, object detection predictions can be used when running the Lightly Worker on object level.
By providing the predictions in the datasource,
you have full control over them and they scale well to millions of samples.
Furthermore, if you add new samples to your datasource, you can simultaneously
@@ -62,8 +62,8 @@ and an object detection task). All of the files are explained in the next sectio
Prediction Tasks
----------------
-To let Lightly know what kind of prediction tasks you want to work with, Lightly
-needs to know their names. It's very easy to let Lightly know which tasks exist:
+To let the Lightly Worker know what kind of prediction tasks you want to work with, it
+needs to know their names. It's very easy to let the Lightly Worker know which tasks exist:
simply add a `tasks.json` in your lightly bucket stored at the subdirectory `.lightly/predictions/`.
The `tasks.json` file must include a list of your task names which must match name
@@ -116,7 +116,7 @@ we can specify which subfolders contain relevant predictions in the `tasks.json`
Prediction Schema
-----------------
-For Lightly it's required to store a prediction schema. The schema defines the
+It's required to store a prediction schema. The schema defines the
format of the predictions and helps the Lightly Platform to correctly identify
and display classes. It also helps to prevent errors as all predictions which
are loaded are validated against this schema.
@@ -127,7 +127,7 @@ all the categories and their corresponding ids. For other tasks, such as keypoin
detection, it can be useful to store additional information like which keypoints
are connected with each other by an edge.
-You can provide all this information to Lightly by adding a `schema.json` to the
+You can provide all this information to the Lightly Worker by adding a `schema.json` to the
directory of the respective task. The schema.json file must have a key `categories`
with a corresponding list of categories following the COCO annotation format.
It must also have a key `task_type` indicating the type of the predictions.
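
For illustration, a minimal `schema.json` writer for the weather-classification example used on this page (the task folder name inside `.lightly/predictions/` depends on your `tasks.json` and is an assumption here):

```python
import json

schema = {
    "task_type": "classification",
    "categories": [
        {"id": 0, "name": "sunny"},
        {"id": 1, "name": "clouded"},
        {"id": 2, "name": "rainy"},
    ],
}
# e.g. written to .lightly/predictions/weather-classification/schema.json
with open("schema.json", "w") as f:
    json.dump(schema, f, indent=4)
```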
@@ -167,10 +167,10 @@ The three classes are sunny, clouded, and rainy.
Prediction Files
----------------
-Lightly requires a **single prediction file per image**. The file should be a .json
+The Lightly Worker requires a **single prediction file per image**. The file should be a .json
following the format defined under :ref:`prediction-format` and stored in the subdirectory
`.lightly/predictions/${TASK_NAME}` in the storage bucket the dataset was configured with.
-In order to make sure Lightly can match the predictions to the correct source image,
+In order to make sure the Lightly Worker can match the predictions to the correct source image,
it's necessary to follow the naming convention:
.. code-block:: bash
@@ -189,7 +189,7 @@ it's necessary to follow the naming convention:
Prediction Files for Videos
---------------------------
-When working with videos, Lightly requires a prediction file per frame. Lightly
+When working with videos, the Lightly Worker requires a prediction file per frame. It
uses a naming convention to identify frames: The filename of a frame consists of
the video filename, the video format, and the frame number (padded to the length
of the number of frames in the video) separated by hyphens. For example, for a
@@ -363,7 +363,7 @@ belonging to that category. Optionally, a list of probabilities can be provided
containing a probability for each category, indicating the likeliness that the
segment belongs to that category.
-To kickstart using Lightly with semantic segmentation predictions we created an
+To kickstart using the Lightly Worker with semantic segmentation predictions we created an
example script that takes model predictions and converts them to the correct
format :download:`semantic_segmentation_inference.py`
@@ -403,13 +403,13 @@ following function:
Segmentation models oftentimes output a probability for each pixel and category.
Storing such probabilities can quickly result in large file sizes if the input
-images have a high resolution. To reduce storage requirements, Lightly expects
+images have a high resolution. To reduce storage requirements, the Lightly Worker expects
only a single score or probability per segmentation. If you have scores or
probabilities for each pixel in the image, you have to first aggregate them
into a single score/probability. We recommend to take either the median or mean
score/probability over all pixels within the segmentation mask. The example
below shows how pixelwise segmentation predictions can be converted to the
-format required by Lightly.
+format required by the Lightly Worker.
.. code-block:: python
@@ -522,7 +522,7 @@ Don't forget to change these 2 parameters at the top of the script.
Creating Prediction Files for Videos
-------------------------------------
-Lightly expects one prediction file per frame in a video. Predictions can be
+The Lightly Worker expects one prediction file per frame in a video. Predictions can be
created following the Python example code below. Make sure that `PyAV `_
is installed on your system for it to work correctly.
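
The referenced example code is not part of this diff; the sketch below only shows the per-frame bookkeeping with PyAV, following the naming convention described earlier on this page. The video name is borrowed from the datapool example, and the prediction payload is left empty because the exact content is defined in the :ref:`prediction-format` section:

```python
import json

import av  # PyAV, required to decode the video

video = "passageway1-c1.avi"  # name borrowed from the datapool example
with av.open(video) as container:
    num_frames = sum(1 for _ in container.decode(video=0))

# One prediction file per frame: video filename, video format, and
# zero-padded frame number, separated by hyphens.
padding = len(str(num_frames))
for frame_number in range(num_frames):
    filename = f"passageway1-c1-avi-{frame_number:0{padding}d}.json"
    with open(filename, "w") as f:
        json.dump({"predictions": []}, f)  # real prediction content goes here
```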
diff --git a/docs/source/docker/advanced/load_model_from_checkpoint.rst b/docs/source/docker/advanced/load_model_from_checkpoint.rst
index 09e6f4ed7..58c28c5e4 100644
--- a/docs/source/docker/advanced/load_model_from_checkpoint.rst
+++ b/docs/source/docker/advanced/load_model_from_checkpoint.rst
@@ -3,8 +3,8 @@
Load Model from Checkpoint
==========================
-The Lightly worker can be used to :ref:`train a self-supervised model on your data. `
-Lightly saves the weights of the model after training to a checkpoint file in
+The Lightly Worker can be used to :ref:`train a self-supervised model on your data.`
+The Lightly Worker saves the weights of the model after training to a checkpoint file in
:code:`output_dir/lightly_epoch_X.ckpt`. This checkpoint can then be further
used to, for example, train a classifier model on your dataset. The code below
demonstrates how the checkpoint can be loaded:
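
The code listing referenced at the end of this hunk is not visible in the diff. As a hedged sketch of the idea (the `model.backbone.` key prefix and the torchvision ResNet are assumptions, not the authoritative example from the docs page):

```python
import torch
from torchvision.models import resnet50

checkpoint = torch.load("output_dir/lightly_epoch_X.ckpt", map_location="cpu")
state_dict = checkpoint["state_dict"]

# Keep only the backbone weights and strip the nesting prefix used during
# self-supervised training (the prefix is an assumption).
backbone_state = {
    key.replace("model.backbone.", ""): value
    for key, value in state_dict.items()
    if key.startswith("model.backbone.")
}
backbone = resnet50()
backbone.load_state_dict(backbone_state, strict=False)
```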
diff --git a/docs/source/docker/advanced/object_level.rst b/docs/source/docker/advanced/object_level.rst
index d211ed85a..5b96289ec 100644
--- a/docs/source/docker/advanced/object_level.rst
+++ b/docs/source/docker/advanced/object_level.rst
@@ -2,7 +2,7 @@
Object Level
============
-Lightly does not only work on full images but also on an object level. This
+The Lightly Worker works not only on full images but also on an object level. This
workflow is especially useful for datasets containing small objects or multiple
objects in each image and provides the following benefits over the full image
workflow:
@@ -21,7 +21,7 @@ workflow:
Prerequisites
-------------
-In order to use the object level workflow with Lightly, you will need the
+In order to use the object level workflow with the Lightly Worker, you will need the
following things:
- The installed Lightly Worker (see :ref:`docker-setup`)
@@ -31,13 +31,13 @@ following things:
.. note::
- If you don't have any predictions available, you can use the Lightly pretagging
+ If you don't have any predictions available, you can use the Lightly Worker pretagging
model. See :ref:`Pretagging ` for more information.
Predictions
-----------
-Lightly needs to know which objects to process. This information is provided
+The Lightly Worker needs to know which objects to process. This information is provided
by uploading a set of object predictions to the datasource (see :ref:`docker-datasource-predictions`).
Let's say we are working with a dataset containing different types of vehicles
and used an object detection model to find possible vehicle objects in the
@@ -170,7 +170,7 @@ code to spin up a Lightly Worker
Padding
-------
-Lightly makes it possible to add a padding around your bounding boxes. This allows
+The Lightly Worker makes it possible to add padding around your bounding boxes. This allows
for better visualization of the cropped images in the web-app and can improve the
embeddings of the objects as the embedding model sees the objects in context. To add
padding, simply specify `object_level.padding=X` where `X` is the padding relative
@@ -239,9 +239,9 @@ properties of your dataset and reveal things like:
These hidden biases are hard to find in a dataset if you only rely on full
images or the coarse vehicle type predicted by the object detection model.
-Lightly helps you to identify them quickly and assists you in monitoring and
+The Lightly Worker helps you to identify them quickly and assists you in monitoring and
improving the quality of your dataset. After an initial exploration you can now
-take further steps to enhance the dataset using one of the workflows Lightly
+take further steps to enhance the dataset using one of the workflows the Lightly Worker
provides:
- Select a subset of your data using our :ref:`Sampling Algorithms `
@@ -252,7 +252,7 @@ provides:
Multiple Object Level Runs
--------------------------
You can run multiple object level workflows using the same dataset. To start a
-new run, please select your original full image dataset in the Lightly Web App
+new run, please select your original full image dataset in the Lightly Platform
and schedule a new run from there. If you are running the Lightly Worker from Python or
over the API, you have to set the `dataset_id` configuration option to the id of
the original full image dataset. In both cases make sure that the run is *not*
@@ -261,7 +261,7 @@ started from the crops dataset as this is not supported!
You can control to which crops dataset the newly selected object crops are
uploaded by setting the `object_level.crop_dataset_name` configuration option.
By default this option is not set and if you did not specify it in the first run,
-you can also omit it in future runs. In this case Lightly will automatically
+you can also omit it in future runs. In this case the Lightly Worker will automatically
find the existing crops dataset and add the new crops to it. If you want to
upload the crops to a new dataset or have set a custom crop dataset name in a
previous run, then set the `object_level.crop_dataset_name` option to a new
diff --git a/docs/source/docker/advanced/overview.rst b/docs/source/docker/advanced/overview.rst
index 14e592240..3057c5448 100644
--- a/docs/source/docker/advanced/overview.rst
+++ b/docs/source/docker/advanced/overview.rst
@@ -1,6 +1,6 @@
Advanced
===================================
-Here you learn more advanced usage patterns of Lightly Docker.
+Here you will learn more advanced usage patterns of the Lightly Worker.
.. toctree::
diff --git a/docs/source/docker/configuration/configuration.rst b/docs/source/docker/configuration/configuration.rst
index a8908dfb6..241e4766f 100644
--- a/docs/source/docker/configuration/configuration.rst
+++ b/docs/source/docker/configuration/configuration.rst
@@ -38,7 +38,7 @@ The following are parameters which can be passed to the container:
 token: ''
 
 worker:
-  # If specified, the docker is started as a worker on the Lightly platform.
+  # If specified, the docker is started as a worker on the Lightly Platform.
   worker_id: ''
   # If True, the worker notifies that it is online even though another worker
   # with the same worker_id is already online.
@@ -89,12 +89,12 @@ The following are parameters which can be passed to the container:
   # shortest edge to x or to resize the image to (height, width), use =-1 for no
   # resizing (default). This only affects the output size of the images dumped to
   # the output folder with dump_dataset=True. To change the size of images
-  # uploaded to the lightly platform or your cloud bucket please use the
+  # uploaded to the Lightly Platform or your cloud bucket please use the
   # lightly.resize option instead.
   output_image_size: -1
   output_image_format: 'png'
 
-  # Upload the dataset to the Lightly platform.
+  # Upload the dataset to the Lightly Platform.
   upload_dataset: False
 
   # pretagging
@@ -134,14 +134,14 @@ The following are parameters which can be passed to the container:
   name:
   # If True keeps backup of all previous data pool states.
   keep_history: True
-  # Dataset id from Lightly platform where the datapool should be hosted.
+  # Dataset id from the Lightly Platform where the datapool should be hosted.
   dataset_id:
 
 # datasource
 # By default only new samples in the datasource are processed. Set process_all
 # to True to reprocess all samples in the datasource.
 datasource:
-  # Dataset id from the Lightly platform.
+  # Dataset id from the Lightly Platform.
   dataset_id:
   # Set to True to reprocess all samples in the datasource.
   process_all: False
@@ -192,7 +192,7 @@ The following are parameters which can be passed to the container:
 # optional deterministic unique output subdirectory for run, in place of timestamp
 run_directory:
 
-To get an overview of all possible configuration parameters of Lightly,
+To get an overview of all possible configuration parameters of the Lightly Worker,
 please check out :ref:`ref-cli-config-default`
Choosing the Right Parameters
diff --git a/docs/source/docker/examples/datasets_in_the_wild.rst b/docs/source/docker/examples/datasets_in_the_wild.rst
index cfea4327f..90d2b1c3a 100644
--- a/docs/source/docker/examples/datasets_in_the_wild.rst
+++ b/docs/source/docker/examples/datasets_in_the_wild.rst
@@ -213,7 +213,7 @@ can process the video directly so we require only 6.4 MBytes of storage. This me
* - Metric
- ffmpeg extracted frames
- - Lightly using video
+ - Lightly Worker using video
- Reduction
* - Storage Consumption
- 447 MBytes + 6.4 MBytes
diff --git a/docs/source/docker/getting_started/first_steps.rst b/docs/source/docker/getting_started/first_steps.rst
index 715b55a40..e471bb2a4 100644
--- a/docs/source/docker/getting_started/first_steps.rst
+++ b/docs/source/docker/getting_started/first_steps.rst
@@ -26,8 +26,8 @@ The Lightly Worker follows a train, embed, select workflow:
The Lightly Worker can be easily triggered from your Python code. There are various parameters you can
-configure and we also expose the full configuration of the lightly self-supervised learning framework.
-You can use the Lightly Worker to train a self-supervised model instead of using the Lightly Python framework.
+configure and we also expose the full configuration of the Lightly self-supervised learning framework.
+You can use the Lightly Worker to train a self-supervised model instead of using the Lightly SSL framework.
Using Docker
-------------
@@ -56,8 +56,8 @@ Here, we quickly explain the most important parts of the typical **docker run**
Start the Lightly Worker Docker
--------------------------------
-Before we jump into the details of how to submit jobs, we need to start the Lightly image in
-worker mode (as outlined in :ref:`docker-setup`).
+Before we jump into the details of how to submit jobs, we need to start the
+Lightly Worker docker container in worker mode (as outlined in :ref:`docker-setup`).
**This is how you start your Lightly Worker:**
@@ -115,7 +115,7 @@ make sure to specify the `dataset_id` in the constructor.
INPUT bucket
^^^^^^^^^^^^
-The `INPUT` bucket is where Lightly reads your input data from. You must specify it and you must provide Lightly `LIST` and `READ` access to it.
+The `INPUT` bucket is where the Lightly Worker reads your input data from. You must specify it and provide Lightly `LIST` and `READ` access to it.
LIGHTLY bucket
^^^^^^^^^^^^^^
@@ -129,7 +129,7 @@ The `LIGHTLY` bucket is used for many purposes:
- Saving thumbnails of images for a more responsive Lightly Platform.
- Saving images of cropped out objects, if you use the object-level workflow. See also :ref:`docker-object-level`.
- Saving frames of videos, if your input consists of videos.
-- Providing the relevant filenames file if you want to to run the lightly worker only on a subset of input files: See also :ref:`specifying_relevant_files`.
+- Providing the relevant filenames file if you want to run the Lightly Worker only on a subset of input files: See also :ref:`specifying_relevant_files`.
- Providing predictions for running the object level workflow or as additional information for the selection process. See also :ref:`docker-datasource-predictions`.
- Providing metadata as additional information for the selection process. See also :ref:`docker-datasource-metadata`.
@@ -351,8 +351,9 @@ epochs on the input images before embedding the images and selecting from them.
)
You may not always want to train for exactly 100 epochs with the default settings.
-The Lightly worker is a wrapper around the lightly Python package.
-Hence, for training and embedding the user can access all the settings from the lightly command-line tool.
+The Lightly Worker is a wrapper around the Lightly SSL Python package.
+Hence, for training and embedding the user can access and set all the settings
+known from the Lightly SSL Python package.
Here are some of the most common parameters for the **lightly_config**
you might want to change:
@@ -364,7 +365,7 @@ you might want to change:
.. code-block:: python
:emphasize-lines: 24, 35
- :caption: Accessing the lightly parameters from Python
+ :caption: Setting the Lightly SSL parameters from Python
scheduled_run_id = client.schedule_compute_worker_run(
worker_config={
diff --git a/docs/source/docker/getting_started/hardware_recommendations.rst b/docs/source/docker/getting_started/hardware_recommendations.rst
index 67edbda1e..a5d58b05b 100644
--- a/docs/source/docker/getting_started/hardware_recommendations.rst
+++ b/docs/source/docker/getting_started/hardware_recommendations.rst
@@ -3,7 +3,7 @@
Hardware recommendations
========================
-Lightly worker is usually run on dedicated hardware
+The Lightly Worker is usually run on dedicated hardware
or in the cloud on a compute instance
which is specifically spun up to run Lightly Worker standalone.
Our recommendations on the hardware requirements of this compute instance are
@@ -42,7 +42,7 @@ Finding the compute speed bottleneck
------------------------------------
Usually, the compute speed is limited by one of three potential bottlenecks.
-Different steps of the Lightly worker use these resources to a different extent.
+Different steps of the Lightly Worker use these resources to a different extent.
Thus the bottleneck changes throughout the run. The bottlenecks are:
- data read speed: I/O
diff --git a/docs/source/docker/getting_started/selection.rst b/docs/source/docker/getting_started/selection.rst
index 8df3c7186..12bc742ff 100644
--- a/docs/source/docker/getting_started/selection.rst
+++ b/docs/source/docker/getting_started/selection.rst
@@ -3,7 +3,7 @@
Selection
=========
-Lightly allows you to specify the subset to be selected based on several objectives.
+The Lightly Worker allows you to specify the subset to be selected based on several objectives.
E.g. you can specify that the images in the subset should be visually diverse, be images the model struggles with (active learning),
should only be sharp images, or have a certain distribution of classes, e.g. be 50% from sunny, 30% from cloudy and 20% from rainy weather.
@@ -13,12 +13,12 @@ Each of these objectives is defined by a `strategy`. A strategy consists of two
- The :code:`input` defines which data the objective is defined on. This data is either a scalar number or a vector for each sample in the dataset.
- The :code:`strategy` itself defines the objective to apply on the input data.
-Lightly allows you to specify several objectives at the same time. The algorithms try to fulfil all objectives simultaneously.
+The Lightly Worker allows you to specify several objectives at the same time. The algorithms try to fulfil all objectives simultaneously.
-Lightly's data selection algorithms support four types of input:
+The Lightly Worker's data selection algorithms support four types of input (a sketch of a combined configuration follows this list):
- **Embeddings** computed using `our open source framework for self-supervised learning `_
-- **Lightly metadata** are metadata of images like the sharpness and computed out of the images themselves by Lightly.
+- **Lightly metadata** are image metadata, such as sharpness, computed from the images themselves by the Lightly Worker.
- (Optional) :ref:`Model predictions ` such as classifications, object detections or segmentations
- (Optional) :ref:`Custom metadata ` can be anything you can encode in a json file (from numbers to categorical strings)
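+
+A hedged sketch of a configuration that combines two objectives (the strategy
+types exist in the Lightly Worker, but the keys and threshold values here are
+illustrative only):
+
+.. code-block:: python
+
+    selection_config = {
+        "n_samples": 50,
+        "strategies": [
+            # Objective 1: select visually diverse images based on embeddings.
+            {
+                "input": {"type": "EMBEDDINGS"},
+                "strategy": {"type": "DIVERSITY"},
+            },
+            # Objective 2: prefer sharp images based on Lightly metadata.
+            {
+                "input": {"type": "METADATA", "key": "lightly.sharpness"},
+                "strategy": {"type": "THRESHOLD", "threshold": 20, "operation": "BIGGER"},
+            },
+        ],
+    }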
@@ -96,7 +96,7 @@ The input can be one of the following:
.. tab:: EMBEDDINGS
- The `lightly OSS framework for self supervised learning `_ is used to compute the embeddings.
+ The `Lightly OSS framework for self-supervised learning `_ is used to compute the embeddings.
They are a vector of numbers for each sample.
You can define embeddings as input using:
@@ -213,7 +213,7 @@ The input can be one of the following:
- **Numerical** vs. **Categorical** values
- Not all metadata types can be used in all selection strategies. Lightly differentiates between numerical and categorical metadata.
+ Not all metadata types can be used in all selection strategies. The Lightly Worker differentiates between numerical and categorical metadata.
**Numerical** metadata are numbers (int, float), e.g. `lightly.sharpness` or `weather.temperature`. It is usually real-valued.
@@ -539,7 +539,7 @@ In the next step, all other strategies are applied in parallel.
from "my_weather_classification_task" for one strategy combined with predictions from
"my_object_detection_task" from another strategy.
-The Lightly optimizer tries to fulfil all strategies as good as possible.
+The Lightly Worker optimizer tries to fulfil all strategies as well as possible.
**Potential reasons why your objectives were not satisfied:**
- **Tradeoff between different objectives.**
@@ -558,12 +558,12 @@ The Lightly optimizer tries to fulfil all strategies as good as possible.
Selection on object level
-------------------------
-Lightly supports doing selection on :ref:`docker-object-level`.
+The Lightly Worker supports doing selection on :ref:`docker-object-level`.
While embeddings are fully available, there are some limitations regarding the usage of METADATA and predictions for SCORES and PREDICTIONS as input:
- When using the object level workflow, the object detections used to create the object crops out of the images are available and can be used for both the SCORES and PREDICTIONS input. However, predictions from other tasks are NOT available at the moment.
-- Lightly metadata is generated on the fly for the object crops and can thus be used for selection. However, other metadata is on image level and thus NOT available at the moment.
+- The Lightly Worker generates metadata on the fly for the object crops, so this metadata can be used for selection. However, other metadata is on image level and thus NOT available at the moment.
If your use case would profit from using image-level data for object-level selection, please reach out to us.
diff --git a/docs/source/docker/getting_started/setup.rst b/docs/source/docker/getting_started/setup.rst
index 387a046eb..c91d1f0fd 100644
--- a/docs/source/docker/getting_started/setup.rst
+++ b/docs/source/docker/getting_started/setup.rst
@@ -7,7 +7,7 @@ Setup
Analytics
^^^^^^^^^
-The Lightly worker currently reports usage metrics to our analytics software
+The Lightly Worker currently reports usage metrics to our analytics software
(we use mixpanel) which uses https encrypted GET and POST requests to https://api.mixpanel.com.
The transmitted data includes information about crashes and the number of samples
that have been filtered. However, **the data does not include input / output samples**,
@@ -22,7 +22,7 @@ The licensing and account management is done through the :ref:`ref-authenticatio
obtained from the Lightly Platform (https://app.lightly.ai).
The token will be used to authenticate your account.
-The authentication happens at every run of the worker. Make sure the Lightly worker
+The authentication happens at every run of the worker. Make sure the Lightly Worker
has a working internet connection and has access to https://api.lightly.ai.
@@ -78,9 +78,9 @@ In short, installing the Docker image consists of the following steps:
a :code:`container-credentials.json` file from your account manager.
2. Authenticate your docker account
- To be able to download docker images of Lightly you need to log in with these credentials.
+ To be able to download docker images of the Lightly Worker, you need to log in with these credentials.
- The following command will authenticate yourself to gain access to the Lightly docker images.
+ The following command will authenticate you and grant access to the Lightly Worker docker images.
We assume :code:`container-credentials.json` is in your current directory.
.. code-block:: console
@@ -123,7 +123,7 @@ In short, installing the Docker image consists of the following steps:
Update the Lightly Worker
^^^^^^^^^^^^^^^^^^^^^^^^^
-To update the Lightly worker we simply need to pull the latest docker image.
+To update the Lightly Worker we simply need to pull the latest docker image.
.. code-block:: console
@@ -140,7 +140,7 @@ Don't forget to tag the image again after pulling it.
instead of `latest`. We follow semantic versioning standards.
-Furthermore, we always recommend using the latest version of the lightly pip package
+Furthermore, we always recommend using the latest version of the Lightly SSL Python package
alongside the latest version of the Lightly Worker. You can update the
pip package using the following command.
@@ -153,7 +153,7 @@ pip package using the following command.
Sanity Check
^^^^^^^^^^^^
-**Next**, verify that the Lightly worker is installed correctly by running the following command:
+**Next**, verify that the Lightly Worker is installed correctly by running the following command:
.. code-block:: console
@@ -172,7 +172,7 @@ You should see an output similar to this one:
Register the Lightly Worker
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-**Finally**, start the Lightly worker in waiting mode. In this mode, the worker will long-poll
+**Finally**, start the Lightly Worker in waiting mode. In this mode, the worker will long-poll
the Lightly API for new jobs to process. To do so, a worker first needs to be registered.
diff --git a/docs/source/docker/integration/overview.rst b/docs/source/docker/integration/overview.rst
index ef403acca..81e509a84 100644
--- a/docs/source/docker/integration/overview.rst
+++ b/docs/source/docker/integration/overview.rst
@@ -1,6 +1,6 @@
Integration
===================================
-Here you learn how to integrate the Lightly worker into data pre-processing pipelines.
+Here you learn how to integrate the Lightly Worker into data pre-processing pipelines.
.. toctree::
diff --git a/docs/source/docker/known_issues_faq.rst b/docs/source/docker/known_issues_faq.rst
index 3fbc8ef9c..175f7b311 100644
--- a/docs/source/docker/known_issues_faq.rst
+++ b/docs/source/docker/known_issues_faq.rst
@@ -143,7 +143,7 @@ workers for data fetching :code:`lightly.loader.num_workers` there might be not
To solve this problem we need to reduce the number of workers or
increase the shared memory for the docker runtime.
-Lightly determines the number of CPU cores available and sets the number
+The Lightly Worker determines the number of CPU cores available and sets the number
of workers to the same number. If you have a machine with many cores but not so much
memory (e.g. less than 2 GB of memory per core) it can happen that you run out
of memory and you rather want to reduce
@@ -298,7 +298,7 @@ a section about the `credHelpers` they might overrule the authentication.
The `credHelpers` can overrule the key for certain URLs. This can lead to
permission errors pulling the docker image.
-The Lightly docker images are hosted in the European location. Therefore,
+The Lightly Worker docker images are hosted in the European location. Therefore,
it's important that pulling from the `eu.gcr.io` domain is using
the provided credentials.
@@ -314,7 +314,7 @@ There are two ways to solve the problem:
cat container-credentials.json | docker login -u _json_key --password-stdin https://eu.gcr.io
- You can work with two configs. We recommend creating a dedicated folder
- for the Lightly docker config.
+ for the Lightly Worker docker config.
.. code-block:: console
@@ -324,5 +324,5 @@ There are two ways to solve the problem:
docker --config ~/.docker_lightly/ pull eu.gcr.io/boris-250909/lightly/worker:latest
-Whenever you're pulling a new image (e.g. updating Lightly) you would need to
+Whenever you're pulling a new image (e.g. updating the Lightly Worker) you need to
pass it the corresponding config using the `--config` parameter.
\ No newline at end of file
diff --git a/docs/source/docker/overview.rst b/docs/source/docker/overview.rst
index 0bf10e992..60f523cc9 100644
--- a/docs/source/docker/overview.rst
+++ b/docs/source/docker/overview.rst
@@ -39,7 +39,7 @@ We worked hard to make this happen and are very proud to present you with the fo
* Check for exact duplicates and report them
- * We expose the full lightly OSS framework config
+ * We expose the full Lightly SSL OSS framework config
* Automated reporting of the datasets for each run
diff --git a/docs/source/docker_archive/configuration/configuration.rst b/docs/source/docker_archive/configuration/configuration.rst
index 71d26c1a9..ba0409af9 100644
--- a/docs/source/docker_archive/configuration/configuration.rst
+++ b/docs/source/docker_archive/configuration/configuration.rst
@@ -9,7 +9,7 @@ Configuration
The old workflow described in these docs will not be supported with new Lightly Worker versions above 2.6.
Please switch to our `new documentation page `_ instead.
-As the lightly framework the docker solution can be configured using Hydra.
+Like the Lightly SSL framework, the docker solution can be configured using Hydra.
The example below shows how the `token` parameter can be set when running the docker container.
diff --git a/docs/source/docker_archive/getting_started/first_steps.rst b/docs/source/docker_archive/getting_started/first_steps.rst
index ec6283449..1f5da7b40 100644
--- a/docs/source/docker_archive/getting_started/first_steps.rst
+++ b/docs/source/docker_archive/getting_started/first_steps.rst
@@ -37,7 +37,7 @@ them.
The docker solution can be used as a command-line interface. You run the container, tell it where to find data, and where to store the result. That's it.
-There are various parameters you can pass to the container. We put a lot of effort to also expose the full lightly framework configuration.
+There are various parameters you can pass to the container. We put a lot of effort into also exposing the full Lightly SSL framework configuration.
You could use the docker solution to train a self-supervised model instead of using the Python framework.
Before jumping into the detail let's have a look at some basics.
diff --git a/docs/source/docker_archive/known_issues_faq.rst b/docs/source/docker_archive/known_issues_faq.rst
index 8dcb4939b..350fcfef3 100644
--- a/docs/source/docker_archive/known_issues_faq.rst
+++ b/docs/source/docker_archive/known_issues_faq.rst
@@ -43,7 +43,7 @@ Try to install `nvidia-docker` following the guide
`here `_.
-Shared Memory Error when running Lightly Docker
+Shared Memory Error when running Lightly Worker
-----------------------------------------------
The following error message appears when the docker runtime has not enough
diff --git a/docs/source/docker_archive/overview.rst b/docs/source/docker_archive/overview.rst
index 2d6872898..968332729 100644
--- a/docs/source/docker_archive/overview.rst
+++ b/docs/source/docker_archive/overview.rst
@@ -8,7 +8,7 @@ Docker Archive
Please switch to our `new documentation page `_ instead.
We all know that sometimes when working with ML data we deal with really BIG datasets. The cloud solution is great for exploration, prototyping
-and an easy way to work with lightly. But there is more!
+and an easy way to work with Lightly. But there is more!
.. figure:: images/lightly_docker_overview.png
:align: center
@@ -50,7 +50,7 @@ We worked hard to make this happen and are very proud to present you with the fo
* Check for exact duplicates and report them
- * We expose the full lightly framework config
+ * We expose the full Lightly SSL framework config
* Automated reporting of the datasets for each run
diff --git a/docs/source/getting_started/advanced.rst b/docs/source/getting_started/advanced.rst
index 8a10f6f13..5be80439c 100644
--- a/docs/source/getting_started/advanced.rst
+++ b/docs/source/getting_started/advanced.rst
@@ -3,7 +3,7 @@
Advanced Concepts in Self-Supervised Learning
=============================================
-In this section, we will have a look at some more advanced topics around Lightly.
+In this section, we will have a look at some more advanced topics around Lightly SSL.
Augmentations
-------------
@@ -76,8 +76,8 @@ Some interesting papers regarding invariances in self-supervised learning:
Transforms
^^^^^^^^^^
-Lightly uses `Torchvision transforms `_
-to apply augmentations to images. The Lightly :py:mod:`~lightly.transforms` module
+Lightly SSL uses `Torchvision transforms `_
+to apply augmentations to images. The Lightly SSL :py:mod:`~lightly.transforms` module
exposes transforms for common self-supervised learning methods.
The most important difference compared to transforms for other tasks, such as
@@ -95,9 +95,9 @@ while :ref:`dino` uses two global and multiple, smaller local views per image.
Custom Transforms
^^^^^^^^^^^^^^^^^
-There are three ways how you can customize augmentations in Lightly:
+There are three ways to customize augmentations in Lightly SSL:
-1. Modify the parameters of the :py:mod:`~lightly.transforms` provided by Lightly:
+1. Modify the parameters of the :py:mod:`~lightly.transforms` provided by Lightly SSL:
.. code-block:: python
@@ -171,7 +171,7 @@ Previewing Augmentations
It often can be very useful to understand how the image augmentations we pick affect
the input dataset. We provide a few helper methods that make it very easy to
-preview augmentations using Lightly.
+preview augmentations using Lightly SSL.
.. literalinclude:: code_examples/plot_image_augmentations.py
@@ -212,7 +212,7 @@ our DINO model would see during training.
Models
------
-See the :ref:`models` section for a list of models that are available in Lightly.
+See the :ref:`models` section for a list of models that are available in Lightly SSL.
Do you know a model that should be on this list? Please add an `issue `_
on GitHub :)
@@ -222,14 +222,14 @@ other vision model. When creating a self-supervised learning model you pass it a
backbone. You need to make sure the backbone output dimension matches the input
dimension of the head component for the respective self-supervised model.
-Lightly has a built-in generator for ResNets. However, the model architecture slightly
+Lightly SSL has a built-in generator for ResNets. However, the model architecture slightly
differs from the official ResNet implementation. The difference is in the first few
-layers. Whereas the official ResNet starts with a 7x7 convolution the one from Lightly
+layers. Whereas the official ResNet starts with a 7x7 convolution, the one from Lightly SSL
has a 3x3 convolution.
* The 3x3 convolution variant is more efficient (fewer parameters and faster
processing) and is better suited for small input images (32x32 pixels or 64x64 pixels).
- We recommend using the Lightly variant for cifar10 or running the model on a microcontroller
+ We recommend using the Lightly SSL variant for CIFAR-10 or for running the model on a microcontroller
(see https://github.com/ARM-software/EndpointAI/tree/master/ProofOfConcepts/Vision/OpenMvMaskDefaults)
* However, the 7x7 convolution variant is better suited for larger images
since the number of features is smaller due to the stride and additional
@@ -241,7 +241,7 @@ has a 3x3 convolution.
from torch import nn
- # Create a Lightly ResNet.
+ # Create a Lightly SSL ResNet.
from lightly.models import ResNetGenerator
resnet = ResNetGenerator('resnet-18')
# Ignore the classification layer as we want the features as output.
@@ -267,7 +267,7 @@ has a 3x3 convolution.
resnet_simclr = SimCLR(backbone, hidden_dim=512, out_dim=128)
-You can also use **custom backbones** with Lightly. We provide a
+You can also use **custom backbones** with Lightly SSL. We provide a
`colab notebook to show how you can use torchvision or timm models
`_.
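+
+For instance, a minimal sketch of using a torchvision ResNet-18 as backbone
+(input sizes below are for illustration only):
+
+.. code-block:: python
+
+    import torch
+    import torchvision
+    from torch import nn
+
+    # Use a torchvision ResNet and remove its classification head.
+    resnet = torchvision.models.resnet18()
+    backbone = nn.Sequential(*list(resnet.children())[:-1])
+
+    # ResNet-18 outputs 512-dimensional features, which must match the
+    # input dimension of the self-supervised model head.
+    x = torch.rand(2, 3, 224, 224)
+    features = backbone(x).flatten(start_dim=1)  # shape: (2, 512)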
diff --git a/docs/source/getting_started/benchmarks.rst b/docs/source/getting_started/benchmarks.rst
index 585a6a514..1947e6e9c 100644
--- a/docs/source/getting_started/benchmarks.rst
+++ b/docs/source/getting_started/benchmarks.rst
@@ -1,160 +1,197 @@
Benchmarks
===================================
-We show benchmarks of the different models for self-supervised learning
-and their performance on public datasets.
+Implemented models and their performance on various datasets. Hyperparameters are not tuned for maximum accuracy.
+List of available benchmarks:
-We have benchmarks we regularly update for these datasets:
-
-- `Imagenet`_
-- `Imagenet100`_
-- `ImageNette`_
+- `ImageNet1k`_
+- `ImageNet100`_
+- `Imagenette`_
- `CIFAR-10`_
-ImageNet
---------
+ImageNet1k
+----------
-We use the ImageNet1k ILSVRC2012 split provided here: https://image-net.org/download.php.
+- `Dataset `_
+- `Code `_
-Self-supervised training of a SimCLR model for 100 epochs with total batch size 256
-takes about two days on two GeForce RTX 4090 GPUs. You can reproduce the results with
-the code at `benchmarks/imagenet/resnet50 `_.
+The following experiments have been conducted on a system with 2x4090 GPUs.
+Training a model takes around four days for 100 epochs (35 min per epoch), including kNN, linear probing, and fine-tuning evaluation.
-Evaluation settings are based on these papers:
+Evaluation settings are based on the following papers:
- Linear: `SimCLR `_
- Finetune: `SimCLR `_
-- KNN: `InstDisc `_
-
-See the `benchmarking scripts `_ for details.
-
+- kNN: `InstDisc `_
.. csv-table:: Imagenet benchmark results.
- :header: "Model", "Backbone", "Batch Size", "Epochs", "Linear Top1", "Linear Top5", "Finetune Top1", "Finetune Top5", "KNN Top1", "KNN Top5", "Tensorboard", "Checkpoint"
+ :header: "Model", "Backbone", "Batch Size", "Epochs", "Linear Top1", "Linear Top5", "Finetune Top1", "Finetune Top5", "kNN Top1", "kNN Top5", "Tensorboard", "Checkpoint"
:widths: 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20
+ "BarlowTwins", "Res50", "256", "100", "62.9", "84.3", "72.6", "90.9", "45.6", "73.9", "`link `_", "`link `_"
"BYOL", "Res50", "256", "100", "62.4", "84.7", "74.0", "91.9", "45.6", "74.8", "`link `_", "`link `_"
"DINO", "Res50", "128", "100", "68.2", "87.9", "72.5", "90.8", "49.9", "78.7", "`link `_", "`link `_"
"SimCLR*", "Res50", "256", "100", "63.2", "85.2", "73.9", "91.9", "44.8", "73.9", "`link `_", "`link `_"
"SimCLR* + DCL", "Res50", "256", "100", "65.1", "86.2", "73.5", "91.7", "49.6", "77.5", "`link `_", "`link `_"
"SimCLR* + DCLW", "Res50", "256", "100", "64.5", "86.0", "73.2", "91.5", "48.5", "76.8", "`link `_", "`link `_"
"SwAV", "Res50", "256", "100", "67.2", "88.1", "75.4", "92.7", "49.5", "78.6", "`link `_", "`link `_"
+ "VICReg", "Res50", "256", "100", "63.0", "85.4", "73.7", "91.9", "46.3", "75.2", "`link `_", "`link `_"
-*\*We use square root learning rate scaling instead of linear scaling as it yields better results for smaller batch sizes. See Appendix B.1 in SimCLR paper.*
+*\*We use square root learning rate scaling instead of linear scaling as it yields better results for smaller batch sizes. See Appendix B.1 in the SimCLR paper.*
+Found a missing model? Track the progress of our planned benchmarks on `GitHub `_.
-ImageNette
------------------------------------
+ImageNet100
+-----------
-We use the ImageNette dataset provided here: https://github.com/fastai/imagenette
+- `Dataset `_
+- :download:`Code `
-For our benchmarks we use the 160px version and resize the input images to 128 pixels.
-Training a single model for 800 epochs on a A6000 GPU takes about 3-5 hours.
+Imagenet100 is a subset of the popular ImageNet1k dataset. It consists of 100 classes
+with 1300 training and 50 validation images per class. We train the
+self-supervised models from scratch on the training data. At the end of every
+epoch we embed all training images and use the features for a kNN classifier
+with k=20 on the test set. The reported kNN Top 1 is the highest accuracy
+reached over all epochs. All experiments use the same ResNet-18 backbone and
+the default ImageNet1k training parameters from the respective papers.
+
+The following experiments have been conducted on a system with a single A6000 GPU.
+Training a model takes between 20 and 30 hours, including kNN evaluation.
+
+.. csv-table:: Imagenet100 benchmark results
+ :header: "Model", "Backbone", "Batch Size", "Epochs", "kNN Top 1", "Runtime", "GPU Memory"
+ :widths: 20, 20, 20, 20, 20, 20, 20
+
+ "BarlowTwins", "Res18", "256", "200", "0.465", "1319.3 Min", "11.3 GByte"
+ "BYOL", "Res18", "256", "200", "0.439", "1315.4 Min", "12.9 GByte"
+ "DINO", "Res18", "256", "200", "0.518", "1868.5 Min", "17.4 GByte"
+ "FastSiam", "Res18", "256", "200", "0.559", "1856.2 Min", "22.0 GByte"
+ "Moco", "Res18", "256", "200", "0.560", "1314.2 Min", "13.1 GByte"
+ "NNCLR", "Res18", "256", "200", "0.453", "1198.6 Min", "11.8 GByte"
+ "SimCLR", "Res18", "256", "200", "0.469", "1207.7 Min", "11.3 GByte"
+ "SimSiam", "Res18", "256", "200", "0.534", "1175.0 Min", "11.1 GByte"
+ "SwaV", "Res18", "256", "200", "0.678", "1569.2 Min", "16.9 GByte"
-.. csv-table:: ImageNette benchmark results using kNN evaluation on the test set using 128x128 input resolution.
- :header: "Model", "Batch Size", "Epochs", "KNN Test Accuracy", "Runtime", "GPU Memory"
- :widths: 20, 20, 20, 20, 20, 20
+Imagenette
+----------
+
+- `Dataset `_
+- :download:`Code `
- "BarlowTwins", "256", "800", "0.852", "298.5 Min", "4.0 GByte"
- "BYOL", "256", "800", "0.887", "214.8 Min", "4.3 GByte"
- "DCL", "256", "800", "0.861", "189.1 Min", "3.7 GByte"
- "DCLW", "256", "800", "0.865", "192.2 Min", "3.7 GByte"
- "DINO (Res18)", "256", "800", "0.888", "312.3 Min", "6.6 GByte"
- "FastSiam", "256", "800", "0.873", "299.6 Min", "7.3 GByte"
- "MAE (ViT-S)", "256", "800", "0.610", "248.2 Min", "4.4 GByte"
- "MSN (ViT-S)", "256", "800", "0.828", "515.5 Min", "14.7 GByte"
- "Moco", "256", "800", "0.874", "231.7 Min", "4.3 GByte"
- "NNCLR", "256", "800", "0.884", "212.5 Min", "3.8 GByte"
- "PMSN (ViT-S)", "256", "800", "0.822", "505.8 Min", "14.7 GByte"
- "SimCLR", "256", "800", "0.889", "193.5 Min", "3.7 GByte"
- "SimMIM (ViT-B32)", "256", "800", "0.343", "446.5 Min", "9.7 GByte"
- "SimSiam", "256", "800", "0.872", "206.4 Min", "3.9 GByte"
- "SwaV", "256", "800", "0.902", "283.2 Min", "6.4 GByte"
- "SwaVQueue", "256", "800", "0.890", "282.7 Min", "6.4 GByte"
- "SMoG", "256", "800", "0.788", "232.1 Min", "2.6 GByte"
- "TiCo", "256", "800", "0.856", "177.8 Min", "2.5 GByte"
- "VICReg", "256", "800", "0.845", "205.6 Min", "4.0 GByte"
- "VICRegL", "256", "800", "0.778", "218.7 Min", "4.0 GByte"
+Imagenette is a subset of 10 easily classified classes from ImageNet.
+For our benchmarks we use the 160px version of the Imagenette dataset and
+resize the input images to 128 pixels during training.
+We train the self-supervised models from scratch on the training data. At the end of every
+epoch we embed all training images and use the features for a kNN classifier
+with k=20 on the test set. The reported kNN Top 1 is the highest accuracy
+reached over all epochs. All experiments use the same ResNet-18 backbone and
+the default ImageNet1k training parameters from the respective papers.
+
+The following experiments have been conducted on a system with a single A6000 GPU.
+Training a model takes three to five hours, including kNN evaluation.
+
+
+.. csv-table:: Imagenette benchmark results
+ :header: "Model", "Backbone", "Batch Size", "Epochs", "kNN Top 1", "Runtime", "GPU Memory"
+ :widths: 20, 20, 20, 20, 20, 20, 20
+
+ "BarlowTwins", "Res18", "256", "800", "0.852", "298.5 Min", "4.0 GByte"
+ "BYOL", "Res18", "256", "800", "0.887", "214.8 Min", "4.3 GByte"
+ "DCL", "Res18", "256", "800", "0.861", "189.1 Min", "3.7 GByte"
+ "DCLW", "Res18", "256", "800", "0.865", "192.2 Min", "3.7 GByte"
+ "DINO", "Res18", "256", "800", "0.888", "312.3 Min", "6.6 GByte"
+ "FastSiam", "Res18", "256", "800", "0.873", "299.6 Min", "7.3 GByte"
+ "MAE", "ViT-S", "256", "800", "0.610", "248.2 Min", "4.4 GByte"
+ "MSN", "ViT-S", "256", "800", "0.828", "515.5 Min", "14.7 GByte"
+ "Moco", "Res18", "256", "800", "0.874", "231.7 Min", "4.3 GByte"
+ "NNCLR", "Res18", "256", "800", "0.884", "212.5 Min", "3.8 GByte"
+ "PMSN", "ViT-S", "256", "800", "0.822", "505.8 Min", "14.7 GByte"
+ "SimCLR", "Res18", "256", "800", "0.889", "193.5 Min", "3.7 GByte"
+ "SimMIM", "ViT-B32", "256", "800", "0.343", "446.5 Min", "9.7 GByte"
+ "SimSiam", "Res18", "256", "800", "0.872", "206.4 Min", "3.9 GByte"
+ "SwaV", "Res18", "256", "800", "0.902", "283.2 Min", "6.4 GByte"
+ "SwaVQueue", "Res18", "256", "800", "0.890", "282.7 Min", "6.4 GByte"
+ "SMoG", "Res18", "256", "800", "0.788", "232.1 Min", "2.6 GByte"
+ "TiCo", "Res18", "256", "800", "0.856", "177.8 Min", "2.5 GByte"
+ "VICReg", "Res18", "256", "800", "0.845", "205.6 Min", "4.0 GByte"
+ "VICRegL", "Res18", "256", "800", "0.778", "218.7 Min", "4.0 GByte"
-You can reproduce the benchmarks using the following script:
-:download:`imagenette_benchmark.py `
CIFAR-10
------------------------------------
+--------
+
+- `Dataset `_
+- :download:`Code `
-Cifar10 consists of 50k training images and 10k testing images. We train the
+CIFAR-10 consists of 50k training images and 10k testing images. We train the
self-supervised models from scratch on the training data. At the end of every
epoch we embed all training images and use the features for a kNN classifier
-with k=200 on the test set. The reported kNN test accuracy is the max accuracy
+with k=200 on the test set. The reported kNN Top 1 is the highest accuracy
over all epochs the model reached.
All experiments use the same ResNet-18 backbone and we disable the gaussian blur
augmentation due to the small image sizes.
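+
+The following is a simplified sketch of this kNN evaluation. It uses a plain
+majority vote, whereas the benchmark uses the similarity-weighted kNN
+classifier from InstDisc; features and labels are assumed to be precomputed
+tensors with L2-normalized features.
+
+.. code-block:: python
+
+    import torch
+
+    def knn_top1(train_features, train_labels, test_features, test_labels, k=200):
+        # With L2-normalized features the dot product is the cosine similarity.
+        sim = test_features @ train_features.t()
+        _, idx = sim.topk(k, dim=1)                # k nearest training samples
+        neighbor_labels = train_labels[idx]        # shape: (num_test, k)
+        pred = neighbor_labels.mode(dim=1).values  # majority vote
+        return (pred == test_labels).float().mean().item()
+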
.. note:: The ResNet-18 backbone in this benchmark is slightly different from
the torchvision variant as it starts with a 3x3 convolution and has no
- stride and no `MaxPool2d`. This is a typical variation used for cifar10
+ stride and no `MaxPool2d`. This is a typical variation used for CIFAR-10
benchmarks of SSL methods.
.. role:: raw-html(raw)
:format: html
-.. csv-table:: Cifar10 benchmark results showing kNN test accuracy, runtime and peak GPU memory consumption for different training setups.
- :header: "Model", "Batch Size", "Epochs", "KNN Test Accuracy", "Runtime", "GPU Memory"
- :widths: 20, 20, 20, 30, 20, 20
-
- "BarlowTwins", "128", "200", "0.842", "375.9 Min", "1.7 GByte"
- "BYOL", "128", "200", "0.869", "121.9 Min", "1.6 GByte"
- "DCL", "128", "200", "0.844", "102.2 Min", "1.5 GByte"
- "DCLW", "128", "200", "0.833", "100.4 Min", "1.5 GByte"
- "DINO", "128", "200", "0.840", "120.3 Min", "1.6 GByte"
- "FastSiam", "128", "200", "0.906", "164.0 Min", "2.7 GByte"
- "Moco", "128", "200", "0.838", "128.8 Min", "1.7 GByte"
- "NNCLR", "128", "200", "0.834", "101.5 Min", "1.5 GByte"
- "SimCLR", "128", "200", "0.847", "97.7 Min", "1.5 GByte"
- "SimSiam", "128", "200", "0.819", "97.3 Min", "1.6 GByte"
- "SwaV", "128", "200", "0.812", "99.6 Min", "1.5 GByte"
- "SMoG", "128", "200", "0.743", "192.2 Min", "1.2 GByte"
- "BarlowTwins", "512", "200", "0.819", "153.3 Min", "5.1 GByte"
- "BYOL", "512", "200", "0.868", "108.3 Min", "5.6 GByte"
- "DCL", "512", "200", "0.840", "88.2 Min", "4.9 GByte"
- "DCLW", "512", "200", "0.824", "87.9 Min", "4.9 GByte"
- "DINO", "512", "200", "0.813", "108.6 Min", "5.0 GByte"
- "FastSiam", "512", "200", "0.788", "146.9 Min", "9.5 GByte"
- "Moco (*)", "512", "200", "0.847", "112.2 Min", "5.6 GByte"
- "NNCLR (*)", "512", "200", "0.815", "88.1 Min", "5.0 GByte"
- "SimCLR", "512", "200", "0.848", "87.1 Min", "4.9 GByte"
- "SimSiam", "512", "200", "0.764", "87.8 Min", "5.0 GByte"
- "SwaV", "512", "200", "0.842", "88.7 Min", "4.9 GByte"
- "SMoG", "512", "200", "0.686", "110.0 Min", "3.4 GByte"
- "BarlowTwins", "512", "800", "0.859", "517.5 Min", "7.9 GByte"
- "BYOL", "512", "800", "0.910", "400.9 Min", "5.4 GByte"
- "DCL", "512", "800", "0.874", "334.6 Min", "4.9 GByte"
- "DCLW", "512", "800", "0.871", "333.3 Min", "4.9 GByte"
- "DINO", "512", "800", "0.848", "405.2 Min", "5.0 GByte"
- "FastSiam", "512", "800", "0.902", "582.0 Min", "9.5 GByte"
- "Moco (*)", "512", "800", "0.899", "417.8 Min", "5.4 GByte"
- "NNCLR (*)", "512", "800", "0.892", "335.0 Min", "5.0 GByte"
- "SimCLR", "512", "800", "0.879", "331.1 Min", "4.9 GByte"
- "SimSiam", "512", "800", "0.904", "333.7 Min", "5.1 GByte"
- "SwaV", "512", "800", "0.884", "330.5 Min", "5.0 GByte"
- "SMoG", "512", "800", "0.800", "415.6 Min", "3.2 GByte"
-
-(*): Increased size of memory bank from 4096 to 8192 to avoid too quickly
-changing memory bank due to larger batch size.
+.. csv-table:: CIFAR-10 benchmark results
+ :header: "Model", "Backbone", "Batch Size", "Epochs", "kNN Top 1", "Runtime", "GPU Memory"
+ :widths: 20, 20, 20, 20, 30, 20, 20
+
+ "BarlowTwins", "Res18", "128", "200", "0.842", "375.9 Min", "1.7 GByte"
+ "BYOL", "Res18", "128", "200", "0.869", "121.9 Min", "1.6 GByte"
+ "DCL", "Res18", "128", "200", "0.844", "102.2 Min", "1.5 GByte"
+ "DCLW", "Res18", "128", "200", "0.833", "100.4 Min", "1.5 GByte"
+ "DINO", "Res18", "128", "200", "0.840", "120.3 Min", "1.6 GByte"
+ "FastSiam", "Res18", "128", "200", "0.906", "164.0 Min", "2.7 GByte"
+ "Moco", "Res18", "128", "200", "0.838", "128.8 Min", "1.7 GByte"
+ "NNCLR", "Res18", "128", "200", "0.834", "101.5 Min", "1.5 GByte"
+ "SimCLR", "Res18", "128", "200", "0.847", "97.7 Min", "1.5 GByte"
+ "SimSiam", "Res18", "128", "200", "0.819", "97.3 Min", "1.6 GByte"
+ "SwaV", "Res18", "128", "200", "0.812", "99.6 Min", "1.5 GByte"
+ "SMoG", "Res18", "128", "200", "0.743", "192.2 Min", "1.2 GByte"
+ "BarlowTwins", "Res18", "512", "200", "0.819", "153.3 Min", "5.1 GByte"
+ "BYOL", "Res18", "512", "200", "0.868", "108.3 Min", "5.6 GByte"
+ "DCL", "Res18", "512", "200", "0.840", "88.2 Min", "4.9 GByte"
+ "DCLW", "Res18", "512", "200", "0.824", "87.9 Min", "4.9 GByte"
+ "DINO", "Res18", "512", "200", "0.813", "108.6 Min", "5.0 GByte"
+ "FastSiam", "Res18", "512", "200", "0.788", "146.9 Min", "9.5 GByte"
+ "Moco*", "Res18", "512", "200", "0.847", "112.2 Min", "5.6 GByte"
+ "NNCLR*", "Res18", "512", "200", "0.815", "88.1 Min", "5.0 GByte"
+ "SimCLR", "Res18", "512", "200", "0.848", "87.1 Min", "4.9 GByte"
+ "SimSiam", "Res18", "512", "200", "0.764", "87.8 Min", "5.0 GByte"
+ "SwaV", "Res18", "512", "200", "0.842", "88.7 Min", "4.9 GByte"
+ "SMoG", "Res18", "512", "200", "0.686", "110.0 Min", "3.4 GByte"
+ "BarlowTwins", "Res18", "512", "800", "0.859", "517.5 Min", "7.9 GByte"
+ "BYOL", "Res18", "512", "800", "0.910", "400.9 Min", "5.4 GByte"
+ "DCL", "Res18", "512", "800", "0.874", "334.6 Min", "4.9 GByte"
+ "DCLW", "Res18", "512", "800", "0.871", "333.3 Min", "4.9 GByte"
+ "DINO", "Res18", "512", "800", "0.848", "405.2 Min", "5.0 GByte"
+ "FastSiam", "Res18", "512", "800", "0.902", "582.0 Min", "9.5 GByte"
+ "Moco*", "Res18", "512", "800", "0.899", "417.8 Min", "5.4 GByte"
+ "NNCLR*", "Res18", "512", "800", "0.892", "335.0 Min", "5.0 GByte"
+ "SimCLR", "Res18", "512", "800", "0.879", "331.1 Min", "4.9 GByte"
+ "SimSiam", "Res18", "512", "800", "0.904", "333.7 Min", "5.1 GByte"
+ "SwaV", "Res18", "512", "800", "0.884", "330.5 Min", "5.0 GByte"
+ "SMoG", "Res18", "512", "800", "0.800", "415.6 Min", "3.2 GByte"
+
+*\*Increased size of memory bank from 4096 to 8192 to avoid
+changing the memory bank too quickly due to the larger batch size.*
We make the following observations running the benchmark:
- Self-Supervised models benefit from larger batch sizes and longer training.
-- All models need around 3-4h to complete the 200 epoch benchmark and 11-13h
- for the 800 epoch benchmark.
-- Memory consumption is roughly the same for all models.
-- Some models, like MoCo or SwaV, learn quickly in the beginning and then
- plateau. Other models, like SimSiam or NNCLR, take longer to warm up but then
- catch up when training for 800 epochs. This can also be seen in the
- figure below.
-
+- Training time is roughly the same for all methods (three to four hours for 200 epochs).
+- Memory consumption is roughly the same for all methods.
+- MoCo and SwaV learn quickly in the beginning and then plateau.
+- SimSiam and NNCLR take longer to warm up but catch up when training for 800 epochs.
.. figure:: images/cifar10_benchmark_knn_accuracy_800_epochs.png
:align: center
@@ -166,48 +203,9 @@ We make the following observations running the benchmark:
Interactive plots of the 800 epoch accuracy and training loss are hosted on
`tensorboard `__.
-You can reproduce the benchmarks using the following script:
-:download:`cifar10_benchmark.py `
-
-
-Imagenet100
------------
-
-Imagenet100 is a subset of the popular ImageNet-1k dataset. It consists of 100 classes
-with 1300 training and 50 validation images per class. We train the
-self-supervised models from scratch on the training data. At the end of every
-epoch we embed all training images and use the features for a kNN classifier
-with k=20 on the test set. The reported kNN test accuracy is the max accuracy
-over all epochs the model reached. All experiments use the same ResNet-18 backbone and
-with the default ImageNet-1k training parameters from the respective papers.
-
-
-.. csv-table:: Imagenet100 benchmark results showing kNN test accuracy, runtime and peak GPU memory consumption for different training setups.
- :header: "Model", "Batch Size", "Epochs", "KNN Test Accuracy", "Runtime", "GPU Memory"
- :widths: 20, 20, 20, 20, 20, 20
-
- "BarlowTwins", "256", "200", "0.465", "1319.3 Min", "11.3 GByte"
- "BYOL", "256", "200", "0.439", "1315.4 Min", "12.9 GByte"
- "DINO", "256", "200", "0.518", "1868.5 Min", "17.4 GByte"
- "FastSiam", "256", "200", "0.559", "1856.2 Min", "22.0 GByte"
- "Moco", "256", "200", "0.560", "1314.2 Min", "13.1 GByte"
- "NNCLR", "256", "200", "0.453", "1198.6 Min", "11.8 GByte"
- "SimCLR", "256", "200", "0.469", "1207.7 Min", "11.3 GByte"
- "SimSiam", "256", "200", "0.534", "1175.0 Min", "11.1 GByte"
- "SwaV", "256", "200", "0.678", "1569.2 Min", "16.9 GByte"
-
-You can reproduce the benchmarks using the following script:
-:download:`imagenet100_benchmark.py `
-
Next Steps
----------
-Now that you understand the performance of the different lightly methods how about
-looking into a tutorial to implement your favorite model?
-
-- :ref:`input-structure-label`
-- :ref:`lightly-moco-tutorial-2`
-- :ref:`lightly-simclr-tutorial-3`
-- :ref:`lightly-simsiam-tutorial-4`
-- :ref:`lightly-custom-augmentation-5`
\ No newline at end of file
+Train your own self-supervised model following our :ref:`examples ` or
+check out our :ref:`tutorials `.
\ No newline at end of file
diff --git a/docs/source/getting_started/benchmarks/cifar10_benchmark.py b/docs/source/getting_started/benchmarks/cifar10_benchmark.py
index c1db40af0..c10777f8f 100644
--- a/docs/source/getting_started/benchmarks/cifar10_benchmark.py
+++ b/docs/source/getting_started/benchmarks/cifar10_benchmark.py
@@ -83,6 +83,9 @@
from lightly.models import ResNetGenerator, modules, utils
from lightly.models.modules import heads, memory_bank
from lightly.transforms import (
+ BYOLTransform,
+ BYOLView1Transform,
+ BYOLView2Transform,
DINOTransform,
FastSiamTransform,
SimCLRTransform,
@@ -160,6 +163,12 @@
path_to_train = "/datasets/cifar10/train/"
path_to_test = "/datasets/cifar10/test/"
+# Use BYOL augmentations
+byol_transform = BYOLTransform(
+ view_1_transform=BYOLView1Transform(input_size=32, gaussian_blur=0.0),
+ view_2_transform=BYOLView2Transform(input_size=32, gaussian_blur=0.0),
+)
+
# Use SimCLR augmentations
simclr_transform = SimCLRTransform(
input_size=32,
@@ -228,8 +237,8 @@ def create_dataset_train_ssl(model):
Model class for which to select the transform.
"""
model_to_transform = {
- BarlowTwinsModel: simclr_transform,
- BYOLModel: simclr_transform,
+ BarlowTwinsModel: byol_transform,
+ BYOLModel: byol_transform,
DCL: simclr_transform,
DCLW: simclr_transform,
DINOModel: dino_transform,
diff --git a/docs/source/getting_started/benchmarks/imagenet100_benchmark.py b/docs/source/getting_started/benchmarks/imagenet100_benchmark.py
index e4fa787cb..3dcc1beb2 100644
--- a/docs/source/getting_started/benchmarks/imagenet100_benchmark.py
+++ b/docs/source/getting_started/benchmarks/imagenet100_benchmark.py
@@ -50,6 +50,7 @@
from lightly.models import modules, utils
from lightly.models.modules import heads
from lightly.transforms import (
+ BYOLTransform,
DINOTransform,
FastSiamTransform,
SimCLRTransform,
@@ -108,14 +109,17 @@
path_to_train = "/datasets/imagenet100/train/"
path_to_test = "/datasets/imagenet100/val/"
+# Use BYOL augmentations
+byol_transform = BYOLTransform()
+
# Use SimCLR augmentations
-simclr_transform = SimCLRTransform(input_size=input_size)
+simclr_transform = SimCLRTransform()
# Use SimSiam augmentations
-simsiam_transform = SimSiamTransform(input_size=input_size)
+simsiam_transform = SimSiamTransform()
# Multi crop augmentation for FastSiam
-fast_siam_transform = FastSiamTransform(input_size=input_size)
+fast_siam_transform = FastSiamTransform()
# Multi crop augmentation for SwAV
swav_transform = SwaVTransform()
@@ -155,8 +159,8 @@ def create_dataset_train_ssl(model):
Model class for which to select the transform.
"""
model_to_transform = {
- BarlowTwinsModel: simclr_transform,
- BYOLModel: simclr_transform,
+ BarlowTwinsModel: byol_transform,
+ BYOLModel: byol_transform,
DINOModel: dino_transform,
FastSiamModel: fast_siam_transform,
MocoModel: simclr_transform,
diff --git a/docs/source/getting_started/benchmarks/imagenette_benchmark.py b/docs/source/getting_started/benchmarks/imagenette_benchmark.py
index 1345a7a88..4eb2e073f 100644
--- a/docs/source/getting_started/benchmarks/imagenette_benchmark.py
+++ b/docs/source/getting_started/benchmarks/imagenette_benchmark.py
@@ -89,6 +89,9 @@
from lightly.models import modules, utils
from lightly.models.modules import heads, masked_autoencoder, memory_bank
from lightly.transforms import (
+ BYOLTransform,
+ BYOLView1Transform,
+ BYOLView2Transform,
DINOTransform,
FastSiamTransform,
MAETransform,
@@ -153,6 +156,12 @@
path_to_train = "/datasets/imagenette2-160/train/"
path_to_test = "/datasets/imagenette2-160/val/"
+# Use BYOL augmentations
+byol_transform = BYOLTransform(
+ view_1_transform=BYOLView1Transform(input_size=input_size),
+ view_2_transform=BYOLView2Transform(input_size=input_size),
+)
+
# Use SimCLR augmentations
simclr_transform = SimCLRTransform(
input_size=input_size,
@@ -243,8 +252,8 @@ def create_dataset_train_ssl(model):
Model class for which to select the transform.
"""
model_to_transform = {
- BarlowTwinsModel: simclr_transform,
- BYOLModel: simclr_transform,
+ BarlowTwinsModel: byol_transform,
+ BYOLModel: byol_transform,
DCL: simclr_transform,
DCLW: simclr_transform,
DINOModel: dino_transform,
@@ -260,7 +269,7 @@ def create_dataset_train_ssl(model):
SwaVModel: swav_transform,
SwaVQueueModel: swav_transform,
SMoGModel: smog_transform,
- TiCoModel: simclr_transform,
+ TiCoModel: byol_transform,
VICRegModel: vicreg_transform,
VICRegLModel: vicregl_transform,
}
diff --git a/docs/source/getting_started/command_line_tool.rst b/docs/source/getting_started/command_line_tool.rst
index bf336ff35..93c6d4fcc 100644
--- a/docs/source/getting_started/command_line_tool.rst
+++ b/docs/source/getting_started/command_line_tool.rst
@@ -3,7 +3,7 @@
Command-line tool
=================
-The Lightly framework provides you with a command-line interface (CLI) to train
+The Lightly SSL framework provides you with a command-line interface (CLI) to train
self-supervised models and create embeddings without having to write a single
line of code.
@@ -24,16 +24,16 @@ the CLI.
-Check the installation of lightly
------------------------------------
-To see if the lightly command-line tool was installed correctly, you can run the
-following command which will print the installed lightly version:
+Check the installation of Lightly SSL
+-------------------------------------
+To see if the Lightly SSL command-line tool was installed correctly, you can run the
+following command, which will print the version of the installed Lightly SSL package:
.. code-block:: bash
lightly-version
-If lightly was installed correctly, you should see something like this:
+If Lightly SSL was installed correctly, you should see something like this:
.. code-block:: bash
diff --git a/docs/source/getting_started/distributed_training.rst b/docs/source/getting_started/distributed_training.rst
index e1b140c28..30daefd7b 100644
--- a/docs/source/getting_started/distributed_training.rst
+++ b/docs/source/getting_started/distributed_training.rst
@@ -3,7 +3,7 @@
Distributed Training
====================
-Lightly supports training your model on multiple GPUs using Pytorch Lightning
+Lightly SSL supports training your model on multiple GPUs using PyTorch Lightning
and Distributed Data Parallel (DDP) training. You can find reference
implementations for all our models in the :ref:`models` section.
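+
+As a minimal sketch, assuming `model` is one of the reference implementations
+(a PyTorch Lightning module) and `dataloader` is defined elsewhere, multi-GPU
+DDP training can look like this:
+
+.. code-block:: python
+
+    import pytorch_lightning as pl
+
+    trainer = pl.Trainer(
+        max_epochs=100,
+        devices=2,            # number of GPUs
+        accelerator="gpu",
+        strategy="ddp",       # Distributed Data Parallel
+        sync_batchnorm=True,  # sync batch norm statistics across GPUs
+    )
+    trainer.fit(model, dataloader)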
@@ -12,7 +12,7 @@ Training with multiple gpus is also available from the command line: :ref:`cli-t
For details on distributed training we recommend the following pages:
- `Pytorch Distributed Overview `_
-- `Pytorch Lightning Multi-GPU Training `_
+- `Pytorch Lightning Multi-GPU Training `_
There are different levels of synchronization for distributed training. One can
diff --git a/docs/source/getting_started/install.rst b/docs/source/getting_started/install.rst
index e05f61fc8..1cb8ff670 100644
--- a/docs/source/getting_started/install.rst
+++ b/docs/source/getting_started/install.rst
@@ -4,24 +4,24 @@ Installation
Supported Python versions
-------------------------
-Lightly requires Python 3.6+. We recommend installing Lighlty in a Linux or OSX environment.
+Lightly SSL requires Python 3.6+. We recommend installing Lightly SSL in a Linux or OSX environment.
.. _rst-installing:
-Installing Lightly
-------------------
+Installing Lightly SSL
+----------------------
-You can install Lightly and its dependencies from PyPi with:
+You can install Lightly SSL and its dependencies from PyPI with:
.. code-block:: bash
pip install lightly
-We strongly recommend that you install Lightly in a dedicated virtualenv, to avoid conflicting with your system packages.
+We strongly recommend that you install Lightly SSL in a dedicated virtualenv, to avoid conflicting with your system packages.
Dependencies
------------
-Lightly currently uses `PyTorch `_ as the underlying deep learning framework.
+Lightly SSL currently uses `PyTorch `_ as the underlying deep learning framework.
On top of PyTorch we use `Hydra `_ for managing configurations and
`PyTorch Lightning `_ for training models.
@@ -35,4 +35,4 @@ If you want to work with video files you need to additionally install
Next Steps
------------
-Check out our tutorial: :ref:`lightly-tutorials`
+Start with one of our tutorials: :ref:`input-structure-label`
diff --git a/docs/source/getting_started/lightly_at_a_glance.rst b/docs/source/getting_started/lightly_at_a_glance.rst
index e4b4af2f6..78361da69 100644
--- a/docs/source/getting_started/lightly_at_a_glance.rst
+++ b/docs/source/getting_started/lightly_at_a_glance.rst
@@ -3,14 +3,14 @@
Self-supervised learning
========================
-Lightly is a computer vision framework for training deep learning models using self-supervised learning.
+Lightly SSL is a computer vision framework for training deep learning models using self-supervised learning.
The framework can be used for a wide range of useful applications such as finding the nearest
neighbors, similarity search, transfer learning, or data analytics.
-How Lightly Works
------------------
-The flexible design of Lightly makes it easy to integrate in your Python code. Lightly is built
+How Lightly SSL Works
+---------------------
+The flexible design of Lightly SSL makes it easy to integrate into your Python code. Lightly SSL is built
completely around PyTorch and the different pieces can be put together to fit *your* requirements.
Data and Transformations
@@ -219,12 +219,12 @@ Furthermore, the ResNet backbone can be used for transfer and few-shot learning.
Self-supervised learning does not require labels for a model to be trained on. Lightly,
however, supports the use of additional labels. For example, if you train a model
on a folder 'cats' with subfolders 'Maine Coon', 'Bengal' and 'British Shorthair'
- Lightly automatically returns the enumerated labels as a list.
+ Lightly SSL automatically returns the enumerated labels as a list.
-Lightly in Three Lines
-----------------------------------------
+Lightly SSL in Three Lines
+--------------------------
-Lightly also offers an easy-to-use interface. The following lines show how the package can
+Lightly SSL also offers an easy-to-use interface. The following lines show how the package can
be used to train a model with self-supervision and create embeddings with only three lines
of code.
diff --git a/docs/source/getting_started/main_concepts.rst b/docs/source/getting_started/main_concepts.rst
index a891539b3..8c2c561f0 100644
--- a/docs/source/getting_started/main_concepts.rst
+++ b/docs/source/getting_started/main_concepts.rst
@@ -6,18 +6,18 @@ Main Concepts
Self-Supervised Learning
------------------------
-The figure below shows an overview of the different concepts used by the Lightly package
+The figure below shows an overview of the different concepts used by the Lightly SSL package
and a schema of how they interact. The expressions in **bold** are explained further
below.
.. figure:: images/lightly_overview.png
- :align: center
- :alt: Lightly Overview
+ :align: center
+ :alt: Lightly SSL Overview
- Overview of the different concepts used by the Lightly package and how they interact.
+ Overview of the different concepts used by the Lightly SSL package and how they interact.
* **Dataset**
- In Lightly, datasets are accessed through :py:class:`~lightly.data.dataset.LightlyDataset`.
+ In Lightly SSL, datasets are accessed through :py:class:`~lightly.data.dataset.LightlyDataset`.
You can create a :py:class:`~lightly.data.dataset.LightlyDataset` from a directory of
images or videos, or directly from a `torchvision dataset `_.
You can learn more about this in our tutorial:
@@ -36,7 +36,7 @@ below.
* **Collate Function**
The collate function aggregates the views of multiple images into a single batch.
- You can use the default collate function. Lightly also provides a
+ You can use the default collate function. Lightly SSL also provides a
:py:class:`~lightly.data.multi_view_collate.MultiViewCollate`
* **Dataloader**
@@ -56,7 +56,7 @@ below.
They project the outputs of the backbone, commonly called *embeddings*,
*representations*, or *features*, into a new space in which the loss is calculated.
This has been found to be hugely beneficial instead of directly calculating the loss
- on the embeddings. Lightly provides common :py:mod:`~lightly.models.modules.heads`
+ on the embeddings. Lightly SSL provides common :py:mod:`~lightly.models.modules.heads`
that can be added to any backbone.
* **Model**
@@ -71,26 +71,25 @@ below.
* :ref:`sphx_glr_tutorials_package_tutorial_simsiam_esa.py`
* **Loss**
- The loss function plays a crucial role in self-supervised learning. Lightly provides
+ The loss function plays a crucial role in self-supervised learning. Lightly SSL provides
common loss functions in the :py:mod:`~lightly.loss` module.
* **Optimizer**
- With Lightly, you can use any `PyTorch optimizer `_
+ With Lightly SSL, you can use any `PyTorch optimizer `_
to train your model.
* **Training**
The model can either be trained using a plain `PyTorch training loop `_
or with a dedicated framework such as `PyTorch Lightning `_.
- Lightly lets you choose what is best for you. Check out our :ref:`models ` and
- `tutorials `_
- sections on how to train models with PyTorch or PyTorch Lightning.
+ Lightly SSL lets you choose what is best for you. Check out our :ref:`models ` and
+ :ref:`tutorials ` sections on how to train models with PyTorch
+ or PyTorch Lightning.
* **Image Embeddings**
During the training process, the model learns to create compact embeddings from images.
The embeddings, also often called representations or features, can then be used for
tasks such as identifying similar images or creating a diverse subset from your data:
- * :ref:`lightly-tutorial-sunflowers`
* :ref:`lightly-simsiam-tutorial-4`
* **Pre-Trained Backbone**
diff --git a/docs/source/index.rst b/docs/source/index.rst
index c4e02bf77..81949504a 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -4,36 +4,37 @@
contain the root `toctree` directive.
-.. image:: ../logos/lightly_logo_crop.png
- :width: 600
- :alt: Lightly
+.. image:: ../logos/lightly_SSL_logo_crop.png
+ :width: 600
+ :align: center
+ :alt: Lightly SSL Self-Supervised Learning
Documentation
===================================
.. note:: These pages document the Lightly self-supervised learning library.
- If you are looking for Lightly Worker Solution to easily process millions
- of samples and run powerful active learning algorithms on your data
- please follow
- `Lightly Worker documentation `_.
+ If you are looking for the Lightly Worker Solution with
+ advanced `active learning algorithms `_ and
+ `selection strategies `_ to select the best samples
+ within millions of unlabeled images or video frames stored in your cloud storage or locally,
+ please follow our `Lightly Worker documentation `_.
-Lightly is a computer vision framework for self-supervised learning.
+Lightly SSL is a computer vision framework for self-supervised learning.
-With Lightly you can train deep learning models using self-supervision.
+With Lightly SSL you can train deep learning models using self-supervision.
This means that you don’t require any labels to train a model.
-Lightly has been built to help you understand and work with large unlabeled
+Lightly SSL has been built to help you understand and work with large unlabeled
datasets. It is built on top of PyTorch and therefore fully compatible with
other frameworks such as Fast.ai.
-Lightly
--------
+Lightly AI
+----------
- `Homepage `_
-- `Web-App `_
-- `Documentation `_
-- `Lightly Solution Documentation (Lightly Worker & API) `_
+- `Lightly Worker Solution Documentation `_
+- `Lightly Platform `_
- `Github `_
- `Discord `_ (We have weekly paper sessions!)
@@ -58,8 +59,12 @@ Lightly
:maxdepth: 1
:caption: Tutorials
- tutorials/package.rst
- tutorials/platform.rst
+ tutorials/structure_your_input.rst
+ tutorials/package/tutorial_moco_memory_bank.rst
+ tutorials/package/tutorial_simclr_clothing.rst
+ tutorials/package/tutorial_simsiam_esa.rst
+ tutorials/package/tutorial_custom_augmentations.rst
+ tutorials/package/tutorial_pretrain_detectron2.rst
.. toctree::
:maxdepth: 1
diff --git a/docs/source/lightly.cli.rst b/docs/source/lightly.cli.rst
index 62996119e..44861d86d 100644
--- a/docs/source/lightly.cli.rst
+++ b/docs/source/lightly.cli.rst
@@ -35,7 +35,7 @@ lightly.cli
.config.config.yaml
-------------------
-The default settings for all command line tools in the lightly Python package are stored in a YAML config file.
+The default settings for all command line tools in the Lightly SSL Python package are stored in a YAML config file.
The config file is distributed along with the Python package and can be adapted to fit custom requirements.
The arguments are grouped into namespaces. For example, everything related to the embedding model is grouped under
diff --git a/docs/source/tutorials/package.rst b/docs/source/tutorials/package.rst
deleted file mode 100644
index 1b0e2904f..000000000
--- a/docs/source/tutorials/package.rst
+++ /dev/null
@@ -1,22 +0,0 @@
-.. _lightly-tutorials:
-
-Python Package
-===================================
-
-With the lightly framework you can use the power of self-supervised learning
-for computervision with ease. Here we show you tutorials to help you work with
-the Python library.
-
-Since lightly is built on top of `PyTorch `_
-and `PyTorch Lightning `_
-you might want to have a look at the two frameworks to understand basic concepts.
-
-.. toctree::
- :maxdepth: 1
-
- structure_your_input.rst
- package/tutorial_moco_memory_bank.rst
- package/tutorial_simclr_clothing.rst
- package/tutorial_simsiam_esa.rst
- package/tutorial_custom_augmentations.rst
- package/tutorial_pretrain_detectron2.rst
diff --git a/docs/source/tutorials/platform.rst b/docs/source/tutorials/platform.rst
deleted file mode 100644
index b7b499457..000000000
--- a/docs/source/tutorials/platform.rst
+++ /dev/null
@@ -1,30 +0,0 @@
-.. _platform-tutorials-label:
-
-Platform
-===================================
-
-.. warning::
- **Tutorials are outdated**
-
- These tutorials use a deprecated workflow of the Lightly Solution and will be removed in the future.
- Please refer to the `new documentation and tutorials `_ instead.
-
-Lightly is more than just a framework for self-supervised learning. We built a complete data curation platform on top.
-Use the embeddings generated using the lightly framework and use them to curate your dataset. Collaborate with your friends
-and share the curated data with your favorite data labeling partner.
-
-In this tutorial series, you will learn how to get the most out of the platform.
-
-.. toctree::
- :maxdepth: 1
-
- platform/tutorial_pizza_filter.rst
- platform/tutorial_sunflowers.rst
- platform/tutorial_active_learning.rst
- platform/tutorial_active_learning_detectron2.rst
- platform/tutorial_aquarium_custom_metadata.rst
- platform/tutorial_cropped_objects_metadata.rst
- Tutorial 7: Active Learning with Nvidia TLT
- Tutorial 8: Integration with LabelStudio for Active Learning
- Tutorial 9: Embedded COVID mask detection
- platform/tutorial_label_studio_export.rst
\ No newline at end of file
diff --git a/docs/source/tutorials/structure_your_input.rst b/docs/source/tutorials/structure_your_input.rst
index 7994f7b11..976bba9bf 100644
--- a/docs/source/tutorials/structure_your_input.rst
+++ b/docs/source/tutorials/structure_your_input.rst
@@ -3,23 +3,32 @@
Tutorial 1: Structure Your Input
================================
-The `lightly Python package `_ can process image datasets to generate embeddings
-or to upload data to the `Lightly platform `_. In this tutorial you will learn how to structure
-your image dataset such that it is understood by our framework.
-
-You can also skip this tutorial and jump right into training a model:
+If you are familiar with torch-like image datasets, you can skip this tutorial and
+jump right into training a model:
- :ref:`lightly-moco-tutorial-2`
- :ref:`lightly-simclr-tutorial-3`
+- :ref:`lightly-simsiam-tutorial-4`
+- :ref:`lightly-custom-augmentation-5`
+- :ref:`lightly-detectron-tutorial-6`
+
+If you are looking for a use case that's not covered by the above tutorials, please
+let us know by `creating an issue `_
+for it.
+
Supported File Types
--------------------
+By default, the `Lightly SSL Python package `_
+can process images or videos for self-supervised learning or for generating embeddings.
+You can always write your own torch-like dataset to use other file types.
+
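As a minimal sketch, both kinds of input end up wrapped in a :py:class:`~lightly.data.dataset.LightlyDataset`:

.. code-block:: python

    import torchvision

    from lightly.data import LightlyDataset

    # index all images or videos found under data/
    dataset = LightlyDataset(input_dir="data/")

    # or wrap an existing torch-like dataset
    cifar10 = torchvision.datasets.CIFAR10("datasets/cifar10", download=True)
    dataset = LightlyDataset.from_torch_dataset(cifar10)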
Images
^^^^^^^^^^^^^^^^^^^^^
-Since lightly uses `Pillow `_
-for image loading it also supports all the image formats supported by
+Since Lightly SSL uses `Pillow `_
+for image loading, it also supports all the image formats supported by
`Pillow `_.
- .jpg, .png, .tiff and
@@ -28,14 +37,13 @@ for image loading it also supports all the image formats supported by
Videos
^^^^^^^^^^^^^^^^^^^^^
-To load videos directly lightly uses
+To load videos directly, Lightly SSL uses
`torchvision `_ and
`PyAV `_. The following formats are supported.
- .mov, .mp4 and .avi
-
Image Folder Datasets
---------------------
@@ -46,7 +54,7 @@ Flat Directory Containing Images
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
You can store all images of interest in a single folder without additional hierarchy. In the example below,
-lightly will load all filenames and images in the directory `data/`. Additionally, it will assign all images
+Lightly SSL will load all filenames and images in the directory `data/`. Additionally, it will assign all images
a placeholder label.
.. code-block:: bash
@@ -58,7 +66,7 @@ a placeholder label.
...
+--- img-N.jpg
-For the structure above, lightly will understand the input as follows:
+For the structure above, Lightly SSL will understand the input as follows:
.. code-block:: python
@@ -80,7 +88,7 @@ Directory with Subdirectories Containing Images
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
You can give structure to your input directory by collecting the input images in subdirectories. In this case,
-the filenames loaded by lightly are with respect to the "root directory" `data/`. Furthermore, lightly assigns
+the filenames loaded by Lightly SSL are with respect to the "root directory" `data/`. Furthermore, Lightly SSL assigns
each image a so-called "weak-label" indicating to which subdirectory it belongs.
.. code-block:: bash
@@ -106,7 +114,7 @@ each image a so-called "weak-label" indicating to which subdirectory it belongs.
...
+-- img-N10.jpg
-For the structure above, lightly will understand the input as follows:
+For the structure above, Lightly SSL will understand the input as follows:
.. code-block:: python
@@ -136,9 +144,9 @@ For the structure above, lightly will understand the input as follows:
Video Folder Datasets
---------------------
-The lightly Python package allows you to work `directly` on video data, without having
+The Lightly SSL Python package allows you to work `directly` on video data, without having
to extract the frames first. This can save a lot of disk space as video files are
-typically strongly compressed. Using lightly on video data is as simple as pointing
+typically strongly compressed. Using Lightly SSL on video data is as simple as pointing
the software at an input directory where one or more videos are stored. The package will
automatically detect all video files and index them so that each frame can be accessed.
@@ -154,175 +162,10 @@ An example for an input directory with videos could look like this:
+-- my_video_4.avi
We assign a weak label to each video.
-To upload the three videos from above to the platform, you can use
-.. code-block:: bash
-
- lightly-upload token='123' new_dataset_name='my_video_dataset' input_dir='data/'
-
-All other operations (like training a self-supervised model and embedding the frames individually)
-also work on video data. Give it a try!
.. note::
Randomly accessing video frames is slower compared to accessing the extracted frames on disk. However,
by working directly on video files, one can save a lot of disk space because the frames do not have to
be extracted beforehand.
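A short sketch, assuming the `data/` layout from above: pointing a :py:class:`~lightly.data.dataset.LightlyDataset` at the directory indexes every frame, and each item consists of a frame, its weak label, and a generated frame filename.

.. code-block:: python

    from lightly.data import LightlyDataset

    dataset = LightlyDataset(input_dir="data/")
    print(len(dataset))  # total number of frames across all videos

    # each item is a (frame, weak_label, filename) tuple
    frame, label, filename = dataset[0]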
-
-
-Embedding Files
----------------
-
-Embeddings generated by the lightly Python package are typically stored in a `.csv` file and can then be uploaded to the
-Lightly platform from the command line. If the embeddings were generated with the lightly command-line tool, they have
-the correct format already.
-
-You can also save your own embeddings in a `.csv` file to upload them. In that case, make sure the file meets the format
-requirements: Use the `save_embeddings` function from `lightly.utils.io` to convert your embeddings, weak-labels, and
-filenames to the right shape.
-
-.. code-block:: python
-
- import lightly.utils.io as io
-
- # embeddings:
- # embeddings are stored as an n_samples x dim numpy array
- embeddings = np.array([[0.1, 0.5],
- [0.2, 0.2],
- [0.1, 0.9],
- [0.3, 0.2]])
-
- # weak-labels
- # a list of integers carrying meta-information about the images
- labels = [0, 1, 1, 0]
-
- # filenames
- # list of strings containing the filenames of the images w.r.t the input directory
- filenames = [
- 'weak-label-0/img-1.jpg',
- 'weak-label-1/img-1.jpg',
- 'weak-label-1/img-2.jpg',
- 'weak-label-0/img-2.jpg',
- ]
-
- io.save_embeddings('my_embeddings_file.csv', embeddings, labels, filenames)
-
-The code shown above will produce the following `.csv` file:
-
-.. list-table:: my_embeddings_file.csv
- :widths: 50 50 50 50
- :header-rows: 1
-
- * - filenames
- - embedding_0
- - embedding_1
- - labels
- * - weak-label-0/img-1.jpg
- - 0.1
- - 0.5
- - 0
- * - weak-label-1/img-1.jpg
- - 0.2
- - 0.2
- - 1
- * - weak-label-1/img-2.jpg
- - 0.1
- - 0.9
- - 1
- * - weak-label-0/img-2.jpg
- - 0.3
- - 0.2
- - 0
-
-.. note:: Note that lightly automatically creates "weak" labels for datasets
- with subfolders. Each subfolder corresponds to one weak label.
- The labels are called "weak" since they might not be used for a task
- you want to solve with ML directly but still can be relevant to group
- the data into buckets.
-
-
-Advanced usage of Embeddings
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-In some cases you want to enrich the embeddings with additional information.
-The lightly csv scheme is very simple and can be easily extended.
-For example, you can add your own embeddings to the existing embeddings. This
-could be useful if you have additional meta information about each sample.
-
-.. _lightly-custom-labels:
-
-Add Custom Embeddings
-""""""""""""""""""""""""""""""
-
-To add custom embeddings you need to add more embedding columns to the .csv file.
-Make sure you keep the enumeration of the embeddings in correct order.
-
-
-Here you see an embedding from lightly with a 2-dimensional embedding vector.
-
-.. list-table:: lightly_embeddings.csv
- :widths: 50 50 50 50
- :header-rows: 1
-
- * - filenames
- - embedding_0
- - embedding_1
- - labels
- * - img-1.jpg
- - 0.1
- - 0.5
- - 0
- * - img-2.jpg
- - 0.2
- - 0.2
- - 0
- * - img-3.jpg
- - 0.1
- - 0.9
- - 1
-
-We can now append our embedding vector to the .csv file.
-
-.. list-table:: lightly_with_custom_embeddings.csv
- :widths: 50 50 50 50 50 50
- :header-rows: 1
-
- * - filenames
- - embedding_0
- - embedding_1
- - embedding_2
- - embedding_3
- - labels
- * - img-1.jpg
- - 0.1
- - 0.5
- - 0.2
- - 0.7
- - 0
- * - img-2.jpg
- - 0.2
- - -0.2
- - 1.1
- - -0.4
- - 0
- * - img-3.jpg
- - 0.1
- - 0.9
- - -0.2
- - 0.5
- - 1
-
-.. note:: The embedding columns must be grouped together. You can not have
- another column between two embedding columns.
-
-
-Next Steps
------------------
-
-Now that you understand the various data formats lightly supports you can
-start training a model:
-
-- :ref:`lightly-moco-tutorial-2`
-- :ref:`lightly-simclr-tutorial-3`
-- :ref:`lightly-simsiam-tutorial-4`
-- :ref:`lightly-custom-augmentation-5`
\ No newline at end of file
diff --git a/docs/source/tutorials_source/platform/images/sunflowers_scatter_after_selection.jpg b/docs/source/tutorials_source/platform/images/sunflowers_scatter_after_selection.jpg
deleted file mode 100644
index 7e976c820..000000000
Binary files a/docs/source/tutorials_source/platform/images/sunflowers_scatter_after_selection.jpg and /dev/null differ
diff --git a/docs/source/tutorials_source/platform/images/sunflowers_scatter_before_selection.jpg b/docs/source/tutorials_source/platform/images/sunflowers_scatter_before_selection.jpg
deleted file mode 100644
index 589b47933..000000000
Binary files a/docs/source/tutorials_source/platform/images/sunflowers_scatter_before_selection.jpg and /dev/null differ
diff --git a/docs/source/tutorials_source/platform/tutorial_pizza_filter.py b/docs/source/tutorials_source/platform/tutorial_pizza_filter.py
deleted file mode 100644
index 3500afeed..000000000
--- a/docs/source/tutorials_source/platform/tutorial_pizza_filter.py
+++ /dev/null
@@ -1,208 +0,0 @@
-# -*- coding: utf-8 -*-
-
-"""
-This documentation accompanies the video tutorial: `youtube link `_
-
-##############################################################################
-
-.. _lightly-tutorial-pizza-filter:
-
-Tutorial 1: Curate Pizza Images
-===============================
-
-.. warning::
- **Tutorial is outdated**
-
- This tutorial uses a deprecated workflow of the Lightly Solution and will be removed in the future.
- Please refer to the `new documentation and tutorials `_ instead.
-
-In this tutorial, you will learn how to upload a dataset to the Lightly platform,
-curate the data, and finally use the curated data to train a model.
-
-What you will learn
--------------------
-
-* Create and upload a new dataset
-* Curate a dataset using simple image metrics such as Width, Height, Sharpness, Signal-to-Noise ratio, File Size
-* Download images based on a tag from a dataset
-* Train an image classifier with the filtered dataset
-
-
-Requirements
-------------
-
-You can use your dataset or use the one we provide with this tutorial:
-:download:`pizzas.zip <../../../_data/pizzas.zip>`.
-If you use your dataset, please make sure the images are smaller than
-2048 pixels with width and height, and you use less than 1000 images.
-
-.. note:: For this tutorial, we provide you with a small dataset of pizza images.
- We chose a small dataset because it's easy to ship and train.
-
-Upload the data
----------------
-
-We start by uploading the dataset to the `Lightly Platform `_.
-
-Create a new account if you do not have one yet.
-Go to your user Preferences and copy your API token.
-
-Now install lightly if you haven't already, and upload your dataset.
-
-.. code-block:: console
-
- # install Lightly
- pip3 install lightly
-
- # upload your DATA directory
- lightly-upload token=MY_TOKEN new_dataset_name='NEW_DATASET_NAME' input_dir='DATA/'
-
-
-Filter the dataset using metadata
----------------------------------
-
-Once the dataset is created and the
-images uploaded, you can head to 'Metadata' under the 'Analyze & Filter' menu.
-
-Move the sliders below the histograms to define filter rules for the dataset.
-Once you are satisfied with the filtered dataset, create a new tag using the tag menu
-on the left side.
-
-Download the curated dataset
-----------------------------
-
-We have filtered the dataset and want to download it now to train a model.
-Therefore, click on the download menu on the left.
-
-We can now download the filtered images by clicking on the 'DOWNLOAD IMAGES' button.
-In our case, the images are stored in the 'pizzas' folder. We now have to
-annotate the images. We can do this by moving the individual images to
-subfolders corresponding to the class. E.g. we move salami pizza images to the
-'salami' folder and Margherita pizza images to the 'margherita' folder.
-
-##############################################################################
-
-Training a model using the curated data
----------------------------------------
-
-"""
-
-
-# %%
-# Now we can start training our model using PyTorch Lightning
-# We start by importing the necessary dependencies
-import os
-
-import pytorch_lightning as pl
-import torch
-import torchmetrics
-from torch.utils.data import DataLoader, random_split
-from torchvision import transforms
-from torchvision.datasets import ImageFolder
-from torchvision.models import resnet18
-
-# %%
-# We use a small batch size to make sure we can run the training on all kinds
-# of machines. Feel free to adjust the value to one that works on your machine.
-batch_size = 8
-seed = 42
-
-# %%
-# Set the seed to make the experiment reproducible
-pl.seed_everything(seed)
-
-# %%
-# Let's set up the augmentations for the train and the test data.
-train_transform = transforms.Compose(
- [
- transforms.RandomResizedCrop((224, 224), scale=(0.7, 1.0)),
- transforms.RandomHorizontalFlip(),
- transforms.ToTensor(),
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
- ]
-)
-
-# we don't do any resizing or mirroring for the test data
-test_transform = transforms.Compose(
- [
- transforms.Resize((224, 224)),
- transforms.ToTensor(),
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
- ]
-)
-
-
-# %%
-# We load our data and split it into train/test with a 70/30 ratio.
-
-# Please make sure the data folder contains subfolders for each class
-#
-# pizzas
-# L salami
-# L margherita
-dset = ImageFolder("pizzas", transform=train_transform)
-
-# to use the random_split method we need to obtain the length
-# of the train and test set
-full_len = len(dset)
-train_len = int(full_len * 0.7)
-test_len = int(full_len - train_len)
-dataset_train, dataset_test = random_split(dset, [train_len, test_len])
-dataset_test.transforms = test_transform
-
-print("Training set consists of {} images".format(len(dataset_train)))
-print("Test set consists of {} images".format(len(dataset_test)))
-
-# %%
-# We can create our data loaders to fetch the data from the training and test
-# set and pack them into batches.
-dataloader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True)
-dataloader_test = DataLoader(dataset_test, batch_size=batch_size)
-
-
-# %%
-# PyTorch Lightning allows us to pack the loss as well as the
-# optimizer into a single module.
-class MyModel(pl.LightningModule):
- def __init__(self, num_classes=2):
- super().__init__()
- self.save_hyperparameters()
-
- # load a pretrained resnet from torchvision
- self.model = resnet18(pretrained=True)
-
- # add new linear output layer (transfer learning)
- num_ftrs = self.model.fc.in_features
- self.model.fc = torch.nn.Linear(num_ftrs, 2)
-
- self.accuracy = torchmetrics.Accuracy()
-
- def forward(self, x):
- return self.model(x)
-
- def training_step(self, batch, batch_idx):
- x, y = batch
- y_hat = self(x)
- loss = torch.nn.functional.cross_entropy(y_hat, y)
- self.log("train_loss", loss, prog_bar=True)
- return loss
-
- def validation_step(self, batch, batch_idx):
- x, y = batch
- y_hat = self(x)
- loss = torch.nn.functional.cross_entropy(y_hat, y)
- y_hat = torch.nn.functional.softmax(y_hat, dim=1)
- self.accuracy(y_hat, y)
- self.log("val_loss", loss, on_epoch=True, prog_bar=True)
- self.log("val_acc", self.accuracy.compute(), on_epoch=True, prog_bar=True)
-
- def configure_optimizers(self):
- return torch.optim.SGD(self.model.fc.parameters(), lr=0.001, momentum=0.9)
-
-
-# %%
-# Finally, we can create the model and use the Trainer
-# to train our model.
-model = MyModel()
-trainer = pl.Trainer(max_epochs=4, devices=1)
-trainer.fit(model, dataloader_train, dataloader_test)
diff --git a/docs/source/tutorials_source/platform/tutorial_sunflowers.py b/docs/source/tutorials_source/platform/tutorial_sunflowers.py
deleted file mode 100644
index 41c2c0cd8..000000000
--- a/docs/source/tutorials_source/platform/tutorial_sunflowers.py
+++ /dev/null
@@ -1,136 +0,0 @@
-"""
-
-.. _lightly-tutorial-sunflowers:
-
-Tutorial 2: Diversify the Sunflowers Dataset
-=============================================
-
-.. warning::
- **Tutorial is outdated**
-
- This tutorial uses a deprecated workflow of the Lightly Solution and will be removed in the future.
- Please refer to the `new documentation and tutorials `_ instead.
-
-This tutorial highlights the basic functionality of selecting a subset in the web-app.
-You can use the CORESET selection strategy to choose a diverse subset of your dataset.
-This can be useful many purposes, e.g. for having a good subset of data to label
-or for creating a validation or test dataset that covers the complete sample space.
-Removing duplicate images can also help you in reducing bias and imbalances in your dataset.
-
-What you will learn
---------------------
-
-* Upload images and embeddings to the web-app via the Python package
-* Sample a diverse subset of your original dataset in the web-app
-* Download the filenames of the subset and use it to create a new local dataset folder.
-
-Requirements
--------------
-You can use your own dataset or the one we provide for this tutorial. The dataset
-we provide consists of 734 images of sunflowers. You can
-download it here :download:`Sunflowers.zip <../../../_data/Sunflowers.zip>`.
-
-To use the Lightly platform, we need to upload the dataset with embeddings to it.
-The first step for this is to train a self-supervised embedding model.
-Then, embed your dataset and lastly, upload the dataset and embeddings to the Lightly platform.
-These three steps can be done using a single terminal command from the lightly pip package: lightly-magic
-But first, we need to install lightly from the Python package index.
-
-.. code-block:: bash
-
- # Install lightly as a pip package
- pip install lightly
-
-.. code-block:: bash
-
- # The lightly-magic command first needs the input directory of your dataset.
- # Then it needs the information for how many epochs to train an embedding model on it.
- # If you want to use our pretrained model instead, set trainer.max_epochs=0.
- # Next, the embedding model is used to embed all images in the input directory
- # and saves the embeddings in a csv file. Last, a new dataset with the specified name
- # is created on the Lightly platform.
-
- lightly-magic input_dir="./Sunflowers" trainer.max_epochs=0 token=YOUR_TOKEN
- new_dataset_name="sunflowers_dataset"
-
-.. note::
-
- The lightly-magic command with prefilled parameters is displayed in the web-app when you
- create a new dataset. `Head over there and try it! `_
- For more information on the CLI commands refer to :ref:`lightly-command-line-tool` and :ref:`lightly-at-a-glance`.
-
-Create a Selection
-------------------
-
-Now, you have everything you need to create a selection of your dataset. For this,
-head to the *Embedding* page of your dataset. You should see a two-dimensional
-scatter plot of your embeddings. If you hover over the images, their thumbnails
-will appear. Can you find clusters of similar images?
-
-.. figure:: ../../tutorials_source/platform/images/sunflowers_scatter_before_selection.jpg
- :align: center
- :alt: Alt text
- :figclass: align-center
-
- You should see a two-dimensional scatter plot of your dataset as shown above.
- Hover over an image to view a thumbnail of it.
- There are also features like selecting and browsing some images and creating
- a tag from it.
-
-.. note::
-
- We reduce the dimensionality of the embeddings to 2 dimensions before plotting them.
- You can switch between the PCA, tSNE and UMAP dimensionality reduction methods.
-
-Right above the scatter plot you should see a button "Create Sampling". Click on it to
-create a selection. You will need to configure the following settings:
-
-* **Embedding:** Choose the embedding to use for the selection.
-* **Sampling Strategy:** Choose the selection strategy to use. This will be one of:
-
- * CORESET: Selects samples which are diverse.
- * CORAL: Combines CORESET with uncertainty scores to do active learning.
- * RANDOM: Selects samples uniformly at random.
-* **Stopping Condition:** Indicate how many samples you want to keep.
-* **Name:** Give your selection a name. A new tag will be created under this name.
-
-.. figure:: ../../tutorials_source/platform/images/selection_create_request.png
- :align: center
- :alt: Alt text
- :figclass: align-center
- :figwidth: 400px
-
- Example of a filled out selection request in the web-app.
-
-After confirming your settings, a worker will start processing your request. Once
-it's done, the page switches to the new tag. You can see how the scatter plot now shows
-selected images and discarded images in a different color. Play around with the different selection strategies
-to see differences between the results.
-
-.. figure:: ../../tutorials_source/platform/images/sunflowers_scatter_after_selection.jpg
- :align: center
- :alt: Alt text
- :figclass: align-center
-
- After the selection you can see which samples were selected and which ones were discarded.
- Here, the green dots are part of the new tag while the gray ones are left away. Notice
- how the CORESET selection strategy selects an evenly spaced subset of images.
-
-.. note::
-
- The CORESET selection strategy chooses the samples evenly spaced out in the 32-dimensional space.
- This does not necessarily translate into being evenly spaced out after the dimensionality
- reduction to 2 dimensions.
-
-
-Download selected samples
--------------------------
-
-Now you can use this diverse subset for your machine learning project.
-Just head over to the *Download* tag to see the different download options.
-Apart from downloading the filenames or the images directly, you can also
-use the lightly-download command to copy the files in the subset from your existing
-to a new directory. The CLI command with prefilled arguments is already provided.
-
-
-"""
diff --git a/examples/pytorch/barlowtwins.py b/examples/pytorch/barlowtwins.py
index 8bf4524dc..c6631518d 100644
--- a/examples/pytorch/barlowtwins.py
+++ b/examples/pytorch/barlowtwins.py
@@ -8,7 +8,11 @@
from lightly.loss import BarlowTwinsLoss
from lightly.models.modules import BarlowTwinsProjectionHead
-from lightly.transforms.simclr_transform import SimCLRTransform
+from lightly.transforms.byol_transform import (
+ BYOLTransform,
+ BYOLView1Transform,
+ BYOLView2Transform,
+)
class BarlowTwins(nn.Module):
@@ -30,7 +34,12 @@ def forward(self, x):
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
-transform = SimCLRTransform(input_size=32)
+# BarlowTwins uses BYOL augmentations.
+# We disable resizing and gaussian blur for cifar10.
+transform = BYOLTransform(
+ view_1_transform=BYOLView1Transform(input_size=32, gaussian_blur=0.0),
+ view_2_transform=BYOLView2Transform(input_size=32, gaussian_blur=0.0),
+)
dataset = torchvision.datasets.CIFAR10(
"datasets/cifar10", download=True, transform=transform
)
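As a side note (a sketch, not part of the diff): BYOLTransform is a two-view transform, so each dataset item yields a pair of differently augmented views, which is what BarlowTwinsLoss consumes.

.. code-block:: python

    from PIL import Image

    # calling the transform on a single image returns two views
    image = Image.new("RGB", (32, 32))
    view_1, view_2 = transform(image)
    print(view_1.shape)  # torch.Size([3, 32, 32])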
diff --git a/examples/pytorch/byol.py b/examples/pytorch/byol.py
index cf80fdf11..4abcdf0a0 100644
--- a/examples/pytorch/byol.py
+++ b/examples/pytorch/byol.py
@@ -11,7 +11,11 @@
from lightly.loss import NegativeCosineSimilarity
from lightly.models.modules import BYOLPredictionHead, BYOLProjectionHead
from lightly.models.utils import deactivate_requires_grad, update_momentum
-from lightly.transforms.simclr_transform import SimCLRTransform
+from lightly.transforms.byol_transform import (
+ BYOLTransform,
+ BYOLView1Transform,
+ BYOLView2Transform,
+)
from lightly.utils.scheduler import cosine_schedule
@@ -49,7 +53,11 @@ def forward_momentum(self, x):
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
-transform = SimCLRTransform(input_size=32)
+# We disable resizing and gaussian blur for cifar10.
+transform = BYOLTransform(
+ view_1_transform=BYOLView1Transform(input_size=32, gaussian_blur=0.0),
+ view_2_transform=BYOLView2Transform(input_size=32, gaussian_blur=0.0),
+)
dataset = torchvision.datasets.CIFAR10(
"datasets/cifar10", download=True, transform=transform
)
diff --git a/examples/pytorch/tico.py b/examples/pytorch/tico.py
index bbda91bd1..82274f3f9 100644
--- a/examples/pytorch/tico.py
+++ b/examples/pytorch/tico.py
@@ -11,7 +11,11 @@
from lightly.loss.tico_loss import TiCoLoss
from lightly.models.modules.heads import TiCoProjectionHead
from lightly.models.utils import deactivate_requires_grad, update_momentum
-from lightly.transforms.simclr_transform import SimCLRTransform
+from lightly.transforms.byol_transform import (
+ BYOLTransform,
+ BYOLView1Transform,
+ BYOLView2Transform,
+)
from lightly.utils.scheduler import cosine_schedule
@@ -47,7 +51,12 @@ def forward_momentum(self, x):
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
-transform = SimCLRTransform(input_size=32)
+# TiCo uses BYOL augmentations.
+# We disable resizing and gaussian blur for cifar10.
+transform = BYOLTransform(
+ view_1_transform=BYOLView1Transform(input_size=32, gaussian_blur=0.0),
+ view_2_transform=BYOLView2Transform(input_size=32, gaussian_blur=0.0),
+)
dataset = torchvision.datasets.CIFAR10(
"datasets/cifar10", download=True, transform=transform
)
diff --git a/examples/pytorch/vicreg.py b/examples/pytorch/vicreg.py
index a3221846c..39968cb08 100644
--- a/examples/pytorch/vicreg.py
+++ b/examples/pytorch/vicreg.py
@@ -7,7 +7,7 @@
## The projection head is the same as the Barlow Twins one
from lightly.loss.vicreg_loss import VICRegLoss
-from lightly.models.modules import BarlowTwinsProjectionHead
+from lightly.models.modules.heads import VICRegProjectionHead
from lightly.transforms.vicreg_transform import VICRegTransform
@@ -15,7 +15,12 @@ class VICReg(nn.Module):
def __init__(self, backbone):
super().__init__()
self.backbone = backbone
- self.projection_head = BarlowTwinsProjectionHead(512, 2048, 2048)
+ self.projection_head = VICRegProjectionHead(
+ input_dim=512,
+ hidden_dim=2048,
+ output_dim=2048,
+ num_layers=2,
+ )
def forward(self, x):
x = self.backbone(x).flatten(start_dim=1)
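As a quick sanity check (a sketch using the dimensions from the diff), the new head maps the 512-dimensional ResNet-18 features to 2048-dimensional projections:

.. code-block:: python

    import torch

    from lightly.models.modules.heads import VICRegProjectionHead

    head = VICRegProjectionHead(
        input_dim=512, hidden_dim=2048, output_dim=2048, num_layers=2
    )
    z = head(torch.randn(4, 512))
    assert z.shape == (4, 2048)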
diff --git a/examples/pytorch_lightning/barlowtwins.py b/examples/pytorch_lightning/barlowtwins.py
index 4053fe65d..86d069e5a 100644
--- a/examples/pytorch_lightning/barlowtwins.py
+++ b/examples/pytorch_lightning/barlowtwins.py
@@ -9,7 +9,11 @@
from lightly.loss import BarlowTwinsLoss
from lightly.models.modules import BarlowTwinsProjectionHead
-from lightly.transforms.simclr_transform import SimCLRTransform
+from lightly.transforms.byol_transform import (
+ BYOLTransform,
+ BYOLView1Transform,
+ BYOLView2Transform,
+)
class BarlowTwins(pl.LightningModule):
@@ -39,7 +43,12 @@ def configure_optimizers(self):
model = BarlowTwins()
-transform = SimCLRTransform(input_size=32)
+# BarlowTwins uses BYOL augmentations.
+# We disable resizing and gaussian blur for cifar10.
+transform = BYOLTransform(
+ view_1_transform=BYOLView1Transform(input_size=32, gaussian_blur=0.0),
+ view_2_transform=BYOLView2Transform(input_size=32, gaussian_blur=0.0),
+)
dataset = torchvision.datasets.CIFAR10(
"datasets/cifar10", download=True, transform=transform
)
diff --git a/examples/pytorch_lightning/byol.py b/examples/pytorch_lightning/byol.py
index 2b227d34d..a87feb4df 100644
--- a/examples/pytorch_lightning/byol.py
+++ b/examples/pytorch_lightning/byol.py
@@ -12,7 +12,11 @@
from lightly.loss import NegativeCosineSimilarity
from lightly.models.modules import BYOLPredictionHead, BYOLProjectionHead
from lightly.models.utils import deactivate_requires_grad, update_momentum
-from lightly.transforms.simclr_transform import SimCLRTransform
+from lightly.transforms.byol_transform import (
+ BYOLTransform,
+ BYOLView1Transform,
+ BYOLView2Transform,
+)
from lightly.utils.scheduler import cosine_schedule
@@ -62,7 +66,11 @@ def configure_optimizers(self):
model = BYOL()
-transform = SimCLRTransform(input_size=32)
+# We disable resizing and gaussian blur for cifar10.
+transform = BYOLTransform(
+ view_1_transform=BYOLView1Transform(input_size=32, gaussian_blur=0.0),
+ view_2_transform=BYOLView2Transform(input_size=32, gaussian_blur=0.0),
+)
dataset = torchvision.datasets.CIFAR10(
"datasets/cifar10", download=True, transform=transform
)
diff --git a/examples/pytorch_lightning/tico.py b/examples/pytorch_lightning/tico.py
index b55246549..48f9d0db1 100644
--- a/examples/pytorch_lightning/tico.py
+++ b/examples/pytorch_lightning/tico.py
@@ -8,7 +8,11 @@
from lightly.loss.tico_loss import TiCoLoss
from lightly.models.modules.heads import TiCoProjectionHead
from lightly.models.utils import deactivate_requires_grad, update_momentum
-from lightly.transforms.simclr_transform import SimCLRTransform
+from lightly.transforms.byol_transform import (
+ BYOLTransform,
+ BYOLView1Transform,
+ BYOLView2Transform,
+)
from lightly.utils.scheduler import cosine_schedule
@@ -56,7 +60,12 @@ def configure_optimizers(self):
model = TiCo()
-transform = SimCLRTransform(input_size=32)
+# TiCo uses BYOL augmentations.
+# We disable resizing and gaussian blur for cifar10.
+transform = BYOLTransform(
+ view_1_transform=BYOLView1Transform(input_size=32, gaussian_blur=0.0),
+ view_2_transform=BYOLView2Transform(input_size=32, gaussian_blur=0.0),
+)
dataset = torchvision.datasets.CIFAR10(
"datasets/cifar10", download=True, transform=transform
)
diff --git a/examples/pytorch_lightning/vicreg.py b/examples/pytorch_lightning/vicreg.py
index 5fa5c89fd..2770309e7 100644
--- a/examples/pytorch_lightning/vicreg.py
+++ b/examples/pytorch_lightning/vicreg.py
@@ -10,7 +10,7 @@
from lightly.loss.vicreg_loss import VICRegLoss
## The projection head is the same as the Barlow Twins one
-from lightly.models.modules import BarlowTwinsProjectionHead
+from lightly.models.modules.heads import VICRegProjectionHead
from lightly.transforms.vicreg_transform import VICRegTransform
@@ -19,7 +19,12 @@ def __init__(self):
super().__init__()
resnet = torchvision.models.resnet18()
self.backbone = nn.Sequential(*list(resnet.children())[:-1])
- self.projection_head = BarlowTwinsProjectionHead(512, 2048, 2048)
+ self.projection_head = VICRegProjectionHead(
+ input_dim=512,
+ hidden_dim=2048,
+ output_dim=2048,
+ num_layers=2,
+ )
self.criterion = VICRegLoss()
def forward(self, x):
diff --git a/examples/pytorch_lightning_distributed/barlowtwins.py b/examples/pytorch_lightning_distributed/barlowtwins.py
index c837dde04..876d7e5f2 100644
--- a/examples/pytorch_lightning_distributed/barlowtwins.py
+++ b/examples/pytorch_lightning_distributed/barlowtwins.py
@@ -9,7 +9,11 @@
from lightly.loss import BarlowTwinsLoss
from lightly.models.modules import BarlowTwinsProjectionHead
-from lightly.transforms.simclr_transform import SimCLRTransform
+from lightly.transforms.byol_transform import (
+ BYOLTransform,
+ BYOLView1Transform,
+ BYOLView2Transform,
+)
class BarlowTwins(pl.LightningModule):
@@ -42,7 +46,12 @@ def configure_optimizers(self):
model = BarlowTwins()
-transform = SimCLRTransform(input_size=32)
+# BarlowTwins uses BYOL augmentations.
+# We disable resizing and gaussian blur for cifar10.
+transform = BYOLTransform(
+ view_1_transform=BYOLView1Transform(input_size=32, gaussian_blur=0.0),
+ view_2_transform=BYOLView2Transform(input_size=32, gaussian_blur=0.0),
+)
dataset = torchvision.datasets.CIFAR10(
"datasets/cifar10", download=True, transform=transform
)
diff --git a/examples/pytorch_lightning_distributed/byol.py b/examples/pytorch_lightning_distributed/byol.py
index 5dd49a16f..6fb9d88fe 100644
--- a/examples/pytorch_lightning_distributed/byol.py
+++ b/examples/pytorch_lightning_distributed/byol.py
@@ -13,7 +13,11 @@
from lightly.models.modules import BYOLProjectionHead
from lightly.models.modules.heads import BYOLPredictionHead
from lightly.models.utils import deactivate_requires_grad, update_momentum
-from lightly.transforms.simclr_transform import SimCLRTransform
+from lightly.transforms.byol_transform import (
+ BYOLTransform,
+ BYOLView1Transform,
+ BYOLView2Transform,
+)
from lightly.utils.scheduler import cosine_schedule
@@ -63,7 +67,11 @@ def configure_optimizers(self):
model = BYOL()
-transform = SimCLRTransform(input_size=32)
+# We disable resizing and gaussian blur for cifar10.
+transform = BYOLTransform(
+ view_1_transform=BYOLView1Transform(input_size=32, gaussian_blur=0.0),
+ view_2_transform=BYOLView2Transform(input_size=32, gaussian_blur=0.0),
+)
dataset = torchvision.datasets.CIFAR10(
"datasets/cifar10", download=True, transform=transform
)
diff --git a/examples/pytorch_lightning_distributed/tico.py b/examples/pytorch_lightning_distributed/tico.py
index 0ba938bd5..5ed02bad2 100644
--- a/examples/pytorch_lightning_distributed/tico.py
+++ b/examples/pytorch_lightning_distributed/tico.py
@@ -8,7 +8,11 @@
from lightly.loss.tico_loss import TiCoLoss
from lightly.models.modules.heads import TiCoProjectionHead
from lightly.models.utils import deactivate_requires_grad, update_momentum
-from lightly.transforms.simclr_transform import SimCLRTransform
+from lightly.transforms.byol_transform import (
+ BYOLTransform,
+ BYOLView1Transform,
+ BYOLView2Transform,
+)
from lightly.utils.scheduler import cosine_schedule
@@ -56,7 +60,12 @@ def configure_optimizers(self):
model = TiCo()
-transform = SimCLRTransform(input_size=32)
+# TiCo uses BYOL augmentations.
+# We disable resizing and gaussian blur for cifar10.
+transform = BYOLTransform(
+ view_1_transform=BYOLView1Transform(input_size=32, gaussian_blur=0.0),
+ view_2_transform=BYOLView2Transform(input_size=32, gaussian_blur=0.0),
+)
dataset = torchvision.datasets.CIFAR10(
"datasets/cifar10", download=True, transform=transform
)
diff --git a/examples/pytorch_lightning_distributed/vicreg.py b/examples/pytorch_lightning_distributed/vicreg.py
index a30c726e5..bc8aa38ad 100644
--- a/examples/pytorch_lightning_distributed/vicreg.py
+++ b/examples/pytorch_lightning_distributed/vicreg.py
@@ -10,7 +10,7 @@
from lightly.loss import VICRegLoss
## The projection head is the same as the Barlow Twins one
-from lightly.models.modules import BarlowTwinsProjectionHead
+from lightly.models.modules.heads import VICRegProjectionHead
from lightly.transforms.vicreg_transform import VICRegTransform
@@ -19,7 +19,12 @@ def __init__(self):
super().__init__()
resnet = torchvision.models.resnet18()
self.backbone = nn.Sequential(*list(resnet.children())[:-1])
- self.projection_head = BarlowTwinsProjectionHead(512, 2048, 2048)
+ self.projection_head = VICRegProjectionHead(
+ input_dim=512,
+ hidden_dim=2048,
+ output_dim=2048,
+ num_layers=2,
+ )
# enable gather_distributed to gather features from all gpus
# before calculating the loss
diff --git a/lightly/__init__.py b/lightly/__init__.py
index 50929e040..cd41a557d 100644
--- a/lightly/__init__.py
+++ b/lightly/__init__.py
@@ -75,46 +75,31 @@
# All Rights Reserved
__name__ = "lightly"
-__version__ = "1.4.12"
+__version__ = "1.4.22"
+
import os
+# see if torchvision vision transformer is available
try:
- # See (https://github.com/PyTorchLightning/pytorch-lightning)
- # This variable is injected in the __builtins__ by the build
- # process. It used to enable importing subpackages of skimage when
- # the binaries are not built
- __LIGHTLY_SETUP__
-except NameError:
- __LIGHTLY_SETUP__ = False
-
-
-if __LIGHTLY_SETUP__:
- # setting up lightly
- msg = f"Partial import of {__name__}=={__version__} during build process."
- print(msg)
-else:
- # see if torchvision vision transformer is available
- try:
- import torchvision.models.vision_transformer
-
- _torchvision_vit_available = True
- except (
- RuntimeError, # Different CUDA versions for torch and torchvision
- OSError, # Different CUDA versions for torch and torchvision (old)
- ImportError, # No installation or old version of torchvision
- ):
- _torchvision_vit_available = False
-
- if os.getenv("LIGHTLY_DID_VERSION_CHECK", "False") == "False":
- os.environ["LIGHTLY_DID_VERSION_CHECK"] = "True"
- from multiprocessing import current_process
-
- if current_process().name == "MainProcess":
- from lightly.api.version_checking import is_latest_version
-
- try:
- is_latest_version(current_version=__version__)
- except Exception:
- # Version check should never break the package.
- pass
+ import torchvision.models.vision_transformer
+
+ _torchvision_vit_available = True
+except (
+ RuntimeError, # Different CUDA versions for torch and torchvision
+ OSError, # Different CUDA versions for torch and torchvision (old)
+ ImportError, # No installation or old version of torchvision
+):
+ _torchvision_vit_available = False
+
+
+if os.getenv("LIGHTLY_DID_VERSION_CHECK", "False") == "False":
+ os.environ["LIGHTLY_DID_VERSION_CHECK"] = "True"
+ import multiprocessing
+
+ if multiprocessing.current_process().name == "MainProcess":
+ from lightly.api import _version_checking
+
+ _version_checking.check_is_latest_version_in_background(
+ current_version=__version__
+ )
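A side note on the environment-variable guard above: the check runs at most once per interpreter because the import sets LIGHTLY_DID_VERSION_CHECK. Setting the variable beforehand therefore opts out of the background check entirely (a sketch):

.. code-block:: python

    import os

    # opt out of the background version check before importing lightly
    os.environ["LIGHTLY_DID_VERSION_CHECK"] = "True"

    import lightly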
diff --git a/lightly/api/_version_checking.py b/lightly/api/_version_checking.py
new file mode 100644
index 000000000..3f370732a
--- /dev/null
+++ b/lightly/api/_version_checking.py
@@ -0,0 +1,73 @@
+from threading import Thread
+
+from lightly.api import utils
+from lightly.api.swagger_api_client import LightlySwaggerApiClient
+from lightly.openapi_generated.swagger_client.api import VersioningApi
+from lightly.utils import version_compare
+
+# Default timeout for API version verification requests in seconds.
+DEFAULT_TIMEOUT_SEC = 2
+
+
+def is_latest_version(current_version: str) -> bool:
+ """Returns True if package version is latest released version."""
+ latest_version = get_latest_version(current_version)
+ return version_compare.version_compare(current_version, latest_version) >= 0
+
+
+def is_compatible_version(current_version: str) -> bool:
+ """Returns True if package version is compatible with API."""
+ minimum_version = get_minimum_compatible_version()
+ return version_compare.version_compare(current_version, minimum_version) >= 0
+
+
+def get_latest_version(
+ current_version: str, timeout_sec: float = DEFAULT_TIMEOUT_SEC
+) -> str:
+ """Returns the latest package version."""
+ versioning_api = _get_versioning_api()
+ version_number: str = versioning_api.get_latest_pip_version(
+ current_version=current_version,
+ _request_timeout=timeout_sec,
+ )
+ return version_number
+
+
+def get_minimum_compatible_version(
+ timeout_sec: float = DEFAULT_TIMEOUT_SEC,
+) -> str:
+ """Returns minimum package version that is compatible with the API."""
+ versioning_api = _get_versioning_api()
+ version_number: str = versioning_api.get_minimum_compatible_pip_version(
+ _request_timeout=timeout_sec
+ )
+ return version_number
+
+
+def check_is_latest_version_in_background(current_version: str) -> None:
+ """Checks if the current version is the latest version in a background thread."""
+
+ def _check_version_in_background(current_version: str) -> None:
+ try:
+ is_latest_version(current_version=current_version)
+ except Exception:
+ # Ignore failed check.
+ pass
+
+ thread = Thread(
+ target=_check_version_in_background,
+ kwargs=dict(current_version=current_version),
+ daemon=True,
+ )
+ thread.start()
+
+
+def _get_versioning_api() -> VersioningApi:
+ configuration = utils.get_api_client_configuration(
+ raise_if_no_token_specified=False,
+ )
+ # Set retries to 0 to avoid waiting for retries in case of a timeout.
+ configuration.retries = 0
+ api_client = LightlySwaggerApiClient(configuration=configuration)
+ versioning_api = VersioningApi(api_client=api_client)
+ return versioning_api
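Typical usage mirrors the call in lightly/__init__.py; a minimal sketch (the first two calls block on network requests with the 2 second default timeout):

.. code-block:: python

    from lightly import __version__
    from lightly.api import _version_checking

    _version_checking.is_latest_version(current_version=__version__)
    _version_checking.is_compatible_version(current_version=__version__)

    # fire-and-forget check on a daemon thread
    _version_checking.check_is_latest_version_in_background(
        current_version=__version__
    )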
diff --git a/lightly/api/api_workflow_artifacts.py b/lightly/api/api_workflow_artifacts.py
index 5ea55244b..6bea0447e 100644
--- a/lightly/api/api_workflow_artifacts.py
+++ b/lightly/api/api_workflow_artifacts.py
@@ -1,4 +1,5 @@
import os
+import warnings
from lightly.api import download
from lightly.openapi_generated.swagger_client.models import (
@@ -145,6 +146,10 @@ def download_compute_worker_run_report_json(
) -> None:
"""Download the report in json format from a run.
+ DEPRECATED: This method is deprecated and will be removed in the future. Use
+ download_compute_worker_run_report_v2_json to download the new report_v2.json
+ instead.
+
Args:
run:
Run from which to download the report.
@@ -171,6 +176,13 @@ def download_compute_worker_run_report_json(
>>> client.download_compute_worker_run_report_json(run=run, output_path="report.json")
"""
+ warnings.warn(
+ DeprecationWarning(
+ "This method downloads the deprecated report.json file and will be "
+ "removed in the future. Use download_compute_worker_run_report_v2_json "
+ "to download the new report_v2.json file instead."
+ )
+ )
self._download_compute_worker_run_artifact_by_type(
run=run,
artifact_type=DockerRunArtifactType.REPORT_JSON,
@@ -178,6 +190,47 @@ def download_compute_worker_run_report_json(
timeout=timeout,
)
+ def download_compute_worker_run_report_v2_json(
+ self,
+ run: DockerRunData,
+ output_path: str,
+ timeout: int = 60,
+ ) -> None:
+ """Download the report in json format from a run.
+
+ Args:
+ run:
+ Run from which to download the report.
+ output_path:
+ Path where report will be saved.
+ timeout:
+ Timeout in seconds after which download is interrupted.
+
+ Raises:
+ ArtifactNotExist:
+ If the run has no report artifact or the report has not yet been
+ uploaded.
+
+ Examples:
+ >>> # schedule run
+ >>> scheduled_run_id = client.schedule_compute_worker_run(...)
+ >>>
+ >>> # wait until run completed
+ >>> for run_info in client.compute_worker_run_info_generator(scheduled_run_id=scheduled_run_id):
+ >>> pass
+ >>>
+ >>> # download report
+ >>> run = client.get_compute_worker_run_from_scheduled_run(scheduled_run_id=scheduled_run_id)
+ >>> client.download_compute_worker_run_report_v2_json(run=run, output_path="report_v2.json")
+
+ """
+ self._download_compute_worker_run_artifact_by_type(
+ run=run,
+ artifact_type=DockerRunArtifactType.REPORT_V2_JSON,
+ output_path=output_path,
+ timeout=timeout,
+ )
+
def download_compute_worker_run_log(
self,
run: DockerRunData,
diff --git a/lightly/api/api_workflow_client.py b/lightly/api/api_workflow_client.py
index 18fa190e7..87f1193bc 100644
--- a/lightly/api/api_workflow_client.py
+++ b/lightly/api/api_workflow_client.py
@@ -6,9 +6,10 @@
import requests
from requests import Response
+from urllib3.exceptions import HTTPError
from lightly.__init__ import __version__
-from lightly.api import utils, version_checking
+from lightly.api import _version_checking, utils
from lightly.api.api_workflow_artifacts import _ArtifactsMixin
from lightly.api.api_workflow_collaboration import _CollaborationMixin
from lightly.api.api_workflow_compute_worker import _ComputeWorkerMixin
@@ -23,7 +24,6 @@
from lightly.api.api_workflow_upload_metadata import _UploadCustomMetadataMixin
from lightly.api.swagger_api_client import LightlySwaggerApiClient
from lightly.api.utils import DatasourceType
-from lightly.api.version_checking import LightlyAPITimeoutException
from lightly.openapi_generated.swagger_client.api import (
CollaborationApi,
DatasetsApi,
@@ -93,7 +93,7 @@ def __init__(
creator: str = Creator.USER_PIP,
):
try:
- if not version_checking.is_compatible_version(__version__):
+ if not _version_checking.is_compatible_version(__version__):
warnings.warn(
UserWarning(
(
@@ -104,10 +104,14 @@ def __init__(
)
)
except (
+ # Error if version compare fails.
ValueError,
+ # Any error by API client if status not in [200, 299].
ApiException,
- LightlyAPITimeoutException,
- AttributeError,
+ # Any error by urllib3 from within API client. Happens for failed requests
+ # that are not handled by API client. For example if there is no internet
+ # connection or a timeout.
+ HTTPError,
):
pass
@@ -214,6 +218,8 @@ def get_filenames(self) -> List[str]:
Returns:
Names of files in the current dataset.
+
+ :meta private: # Skip docstring generation
"""
filenames_on_server = self._mappings_api.get_sample_mappings_by_dataset_id(
dataset_id=self.dataset_id, field="fileName"
@@ -243,6 +249,7 @@ def upload_file_with_signed_url(
Returns:
The response of the put request.
+ :meta private: # Skip docstring generation
"""
# check to see if server side encryption for S3 is desired
diff --git a/lightly/api/api_workflow_compute_worker.py b/lightly/api/api_workflow_compute_worker.py
index 06c37c392..f98244256 100644
--- a/lightly/api/api_workflow_compute_worker.py
+++ b/lightly/api/api_workflow_compute_worker.py
@@ -22,10 +22,10 @@
DockerWorkerConfigV3Lightly,
DockerWorkerRegistryEntryData,
DockerWorkerType,
- SelectionConfig,
- SelectionConfigEntry,
- SelectionConfigEntryInput,
- SelectionConfigEntryStrategy,
+ SelectionConfigV3,
+ SelectionConfigV3Entry,
+ SelectionConfigV3EntryInput,
+ SelectionConfigV3EntryStrategy,
TagData,
)
from lightly.openapi_generated.swagger_client.rest import ApiException
@@ -175,7 +175,7 @@ def create_compute_worker_config(
self,
worker_config: Optional[Dict[str, Any]] = None,
lightly_config: Optional[Dict[str, Any]] = None,
- selection_config: Optional[Union[Dict[str, Any], SelectionConfig]] = None,
+ selection_config: Optional[Union[Dict[str, Any], SelectionConfigV3]] = None,
) -> str:
"""Creates a new configuration for a Lightly Worker run.
@@ -207,6 +207,8 @@ def create_compute_worker_config(
>>> config_id = client.create_compute_worker_config(
... selection_config=selection_config,
... )
+
+ :meta private: # Skip docstring generation
"""
if isinstance(selection_config, dict):
selection = selection_config_from_dict(cfg=selection_config)
@@ -267,7 +269,7 @@ def schedule_compute_worker_run(
self,
worker_config: Optional[Dict[str, Any]] = None,
lightly_config: Optional[Dict[str, Any]] = None,
- selection_config: Optional[Union[Dict[str, Any], SelectionConfig]] = None,
+ selection_config: Optional[Union[Dict[str, Any], SelectionConfigV3]] = None,
priority: str = DockerRunScheduledPriority.MID,
runs_on: Optional[List[str]] = None,
) -> str:
@@ -632,17 +634,17 @@ def get_compute_worker_run_tags(self, run_id: str) -> List[TagData]:
return tags_in_dataset
-def selection_config_from_dict(cfg: Dict[str, Any]) -> SelectionConfig:
- """Recursively converts selection config from dict to a SelectionConfig instance."""
+def selection_config_from_dict(cfg: Dict[str, Any]) -> SelectionConfigV3:
+ """Recursively converts selection config from dict to a SelectionConfigV3 instance."""
strategies = []
for entry in cfg.get("strategies", []):
new_entry = copy.deepcopy(entry)
- new_entry["input"] = SelectionConfigEntryInput(**entry["input"])
- new_entry["strategy"] = SelectionConfigEntryStrategy(**entry["strategy"])
- strategies.append(SelectionConfigEntry(**new_entry))
+ new_entry["input"] = SelectionConfigV3EntryInput(**entry["input"])
+ new_entry["strategy"] = SelectionConfigV3EntryStrategy(**entry["strategy"])
+ strategies.append(SelectionConfigV3Entry(**new_entry))
new_cfg = copy.deepcopy(cfg)
new_cfg["strategies"] = strategies
- return SelectionConfig(**new_cfg)
+ return SelectionConfigV3(**new_cfg)
_T = TypeVar("_T")
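For illustration, a dict like the following round-trips through selection_config_from_dict; the exact field values are an assumption based on typical Lightly selection configs:

.. code-block:: python

    cfg = {
        "n_samples": 50,
        "strategies": [
            {
                "input": {"type": "EMBEDDINGS"},
                "strategy": {"type": "DIVERSITY"},
            }
        ],
    }
    selection = selection_config_from_dict(cfg)  # -> SelectionConfigV3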
diff --git a/lightly/api/api_workflow_datasets.py b/lightly/api/api_workflow_datasets.py
index 29d33d63d..2d670e76c 100644
--- a/lightly/api/api_workflow_datasets.py
+++ b/lightly/api/api_workflow_datasets.py
@@ -376,7 +376,7 @@ def create_dataset(
>>> from lightly.api import ApiWorkflowClient
>>> from lightly.openapi_generated.swagger_client.models import DatasetType
>>>
- >>> client = lightly.api.ApiWorkflowClient(token="YOUR_TOKEN")
+ >>> client = ApiWorkflowClient(token="YOUR_TOKEN")
>>> client.create_dataset('your-dataset-name', dataset_type=DatasetType.IMAGES)
>>>
>>> # or to work with videos
diff --git a/lightly/api/api_workflow_datasources.py b/lightly/api/api_workflow_datasources.py
index 90566002c..5c2033766 100644
--- a/lightly/api/api_workflow_datasources.py
+++ b/lightly/api/api_workflow_datasources.py
@@ -1,93 +1,22 @@
import time
import warnings
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, Iterator, List, Optional, Set, Tuple, Union
import tqdm
from lightly.openapi_generated.swagger_client.models import (
DatasourceConfig,
- DatasourceConfigVerifyDataErrors,
DatasourceProcessedUntilTimestampRequest,
DatasourceProcessedUntilTimestampResponse,
DatasourcePurpose,
DatasourceRawSamplesData,
)
+from lightly.openapi_generated.swagger_client.models.datasource_raw_samples_data_row import (
+ DatasourceRawSamplesDataRow,
+)
class _DatasourcesMixin:
- def _download_raw_files(
- self,
- download_function: Union[
- "DatasourcesApi.get_list_of_raw_samples_from_datasource_by_dataset_id",
- "DatasourcesApi.get_list_of_raw_samples_predictions_from_datasource_by_dataset_id",
- "DatasourcesApi.get_list_of_raw_samples_metadata_from_datasource_by_dataset_id",
- ],
- from_: int = 0,
- to: Optional[int] = None,
- relevant_filenames_file_name: Optional[str] = None,
- use_redirected_read_url: bool = False,
- progress_bar: Optional[tqdm.tqdm] = None,
- **kwargs,
- ) -> List[Tuple[str, str]]:
- if to is None:
- to = int(time.time())
- relevant_filenames_kwargs = (
- {"relevant_filenames_file_name": relevant_filenames_file_name}
- if relevant_filenames_file_name
- else dict()
- )
-
- response: DatasourceRawSamplesData = download_function(
- dataset_id=self.dataset_id,
- var_from=from_,
- to=to,
- use_redirected_read_url=use_redirected_read_url,
- **relevant_filenames_kwargs,
- **kwargs,
- )
- cursor = response.cursor
- samples = response.data
- if progress_bar is not None:
- progress_bar.update(len(response.data))
- while response.has_more:
- response: DatasourceRawSamplesData = download_function(
- dataset_id=self.dataset_id,
- cursor=cursor,
- use_redirected_read_url=use_redirected_read_url,
- **relevant_filenames_kwargs,
- **kwargs,
- )
- cursor = response.cursor
- samples.extend(response.data)
- if progress_bar is not None:
- progress_bar.update(len(response.data))
- sample_map = {}
- for idx, s in enumerate(samples):
- if s.file_name.startswith("/"):
- warnings.warn(
- UserWarning(
- f"Absolute file paths like {s.file_name} are not supported"
- f" in relevant filenames file {relevant_filenames_file_name} due to blob storage"
- )
- )
- elif s.file_name.startswith(("./", "../")):
- warnings.warn(
- UserWarning(
- f"Using dot notation ('./', '../') like in {s.file_name} is not supported"
- f" in relevant filenames file {relevant_filenames_file_name} due to blob storage"
- )
- )
- elif s.file_name in sample_map:
- warnings.warn(
- UserWarning(
- f"Duplicate filename {s.file_name} in relevant"
- f" filenames file {relevant_filenames_file_name}"
- )
- )
- else:
- sample_map[s.file_name] = s.read_url
- return [(file_name, read_url) for file_name, read_url in sample_map.items()]
-
def download_raw_samples(
self,
from_: int = 0,
@@ -131,8 +60,10 @@ def download_raw_samples(
>>> client.set_dataset_id_by_name("my-dataset")
>>> client.download_raw_samples()
[('image-1.png', 'https://......'), ('image-2.png', 'https://......')]
+
+ :meta private: # Skip docstring generation
"""
- samples = self._download_raw_files(
+ return self._download_raw_files(
download_function=self._datasources_api.get_list_of_raw_samples_from_datasource_by_dataset_id,
from_=from_,
to=to,
@@ -140,7 +71,6 @@ def download_raw_samples(
use_redirected_read_url=use_redirected_read_url,
progress_bar=progress_bar,
)
- return samples
def download_raw_predictions(
self,
@@ -153,6 +83,36 @@ def download_raw_predictions(
use_redirected_read_url: bool = False,
progress_bar: Optional[tqdm.tqdm] = None,
) -> List[Tuple[str, str]]:
+ """Downloads all prediction filenames and read urls from the datasource.
+
+ See `download_raw_predictions_iter` for details.
+
+ :meta private: # Skip docstring generation
+ """
+ return list(
+ self.download_raw_predictions_iter(
+ task_name=task_name,
+ from_=from_,
+ to=to,
+ relevant_filenames_file_name=relevant_filenames_file_name,
+ run_id=run_id,
+ relevant_filenames_artifact_id=relevant_filenames_artifact_id,
+ use_redirected_read_url=use_redirected_read_url,
+ progress_bar=progress_bar,
+ )
+ )
+
+ def download_raw_predictions_iter(
+ self,
+ task_name: str,
+ from_: int = 0,
+ to: Optional[int] = None,
+ relevant_filenames_file_name: Optional[str] = None,
+ run_id: Optional[str] = None,
+ relevant_filenames_artifact_id: Optional[str] = None,
+ use_redirected_read_url: bool = False,
+ progress_bar: Optional[tqdm.tqdm] = None,
+ ) -> Iterator[Tuple[str, str]]:
"""Downloads prediction filenames and read urls from the datasource.
Only samples with timestamp between `from_` (inclusive) and `to` (inclusive)
@@ -188,7 +148,7 @@ def download_raw_predictions(
retrieved.
Returns:
- A list of (filename, url) tuples where each tuple represents a sample.
+ An iterator of (filename, url) tuples where each tuple represents a sample.
Examples:
>>> client = ApiWorkflowClient(token="MY_AWESOME_TOKEN")
@@ -196,9 +156,11 @@ def download_raw_predictions(
>>> # Already created some Lightly Worker runs with this dataset
>>> task_name = "object-detection"
>>> client.set_dataset_id_by_name("my-dataset")
- >>> client.download_raw_predictions(task_name=task_name)
+ >>> list(client.download_raw_predictions_iter(task_name=task_name))
[('.lightly/predictions/object-detection/image-1.json', 'https://......'),
('.lightly/predictions/object-detection/image-2.json', 'https://......')]
+
+ :meta private: # Skip docstring generation
"""
if run_id is not None and relevant_filenames_artifact_id is None:
raise ValueError(
@@ -217,7 +179,7 @@ def download_raw_predictions(
"relevant_filenames_artifact_id"
] = relevant_filenames_artifact_id
- samples = self._download_raw_files(
+ yield from self._download_raw_files_iter(
download_function=self._datasources_api.get_list_of_raw_samples_predictions_from_datasource_by_dataset_id,
from_=from_,
to=to,
@@ -227,7 +189,6 @@ def download_raw_predictions(
progress_bar=progress_bar,
**relevant_filenames_kwargs,
)
- return samples
def download_raw_metadata(
self,
@@ -241,6 +202,34 @@ def download_raw_metadata(
) -> List[Tuple[str, str]]:
"""Downloads all metadata filenames and read urls from the datasource.
+ See `download_raw_metadata_iter` for details.
+
+ :meta private: # Skip docstring generation
+ """
+ return list(
+ self.download_raw_metadata_iter(
+ from_=from_,
+ to=to,
+ run_id=run_id,
+ relevant_filenames_artifact_id=relevant_filenames_artifact_id,
+ relevant_filenames_file_name=relevant_filenames_file_name,
+ use_redirected_read_url=use_redirected_read_url,
+ progress_bar=progress_bar,
+ )
+ )
+
+ def download_raw_metadata_iter(
+ self,
+ from_: int = 0,
+ to: Optional[int] = None,
+ run_id: Optional[str] = None,
+ relevant_filenames_artifact_id: Optional[str] = None,
+ relevant_filenames_file_name: Optional[str] = None,
+ use_redirected_read_url: bool = False,
+ progress_bar: Optional[tqdm.tqdm] = None,
+ ) -> Iterator[Tuple[str, str]]:
+ """Downloads all metadata filenames and read urls from the datasource.
+
Only samples with timestamp between `from_` (inclusive) and `to` (inclusive)
will be downloaded.
@@ -272,16 +261,18 @@ def download_raw_metadata(
retrieved.
Returns:
- A list of (filename, url) tuples where each tuple represents a sample.
+ An iterator of (filename, url) tuples where each tuple represents a sample.
Examples:
>>> client = ApiWorkflowClient(token="MY_AWESOME_TOKEN")
>>>
>>> # Already created some Lightly Worker runs with this dataset
>>> client.set_dataset_id_by_name("my-dataset")
- >>> client.download_raw_metadata()
+ >>> list(client.download_raw_metadata_iter())
[('.lightly/metadata/object-detection/image-1.json', 'https://......'),
('.lightly/metadata/object-detection/image-2.json', 'https://......')]
+
+ :meta private: # Skip docstring generation
"""
if run_id is not None and relevant_filenames_artifact_id is None:
raise ValueError(
@@ -300,8 +291,8 @@ def download_raw_metadata(
"relevant_filenames_artifact_id"
] = relevant_filenames_artifact_id
- samples = self._download_raw_files(
- self._datasources_api.get_list_of_raw_samples_metadata_from_datasource_by_dataset_id,
+ yield from self._download_raw_files_iter(
+ download_function=self._datasources_api.get_list_of_raw_samples_metadata_from_datasource_by_dataset_id,
from_=from_,
to=to,
relevant_filenames_file_name=relevant_filenames_file_name,
@@ -309,7 +300,6 @@ def download_raw_metadata(
progress_bar=progress_bar,
**relevant_filenames_kwargs,
)
- return samples
def download_new_raw_samples(
self,
@@ -371,6 +361,8 @@ def get_processed_until_timestamp(self) -> int:
>>> client.set_dataset_id_by_name("my-dataset")
>>> client.get_processed_until_timestamp()
1684750513
+
+ :meta private: # Skip docstring generation
"""
response: DatasourceProcessedUntilTimestampResponse = self._datasources_api.get_datasource_processed_until_timestamp_by_dataset_id(
dataset_id=self.dataset_id
@@ -449,7 +441,6 @@ def set_azure_config(
purpose:
Datasource purpose, determines if datasource is read only (INPUT)
or can be written to as well (LIGHTLY, INPUT_OUTPUT).
- The latter is required when Lightly extracts frames from input videos.
"""
# TODO: Use DatasourceConfigAzure once we switch/update the api generator.
@@ -499,7 +490,6 @@ def set_gcs_config(
purpose:
Datasource purpose, determines if datasource is read only (INPUT)
or can be written to as well (LIGHTLY, INPUT_OUTPUT).
- The latter is required when Lightly extracts frames from input videos.
"""
# TODO: Use DatasourceConfigGCS once we switch/update the api generator.
@@ -519,10 +509,12 @@ def set_gcs_config(
def set_local_config(
self,
- resource_path: str,
+ relative_path: str = "",
+ web_server_location: Optional[str] = "http://localhost:3456",
thumbnail_suffix: Optional[
str
] = ".lightly/thumbnails/[filename]_thumb.[extension]",
+ purpose: str = DatasourcePurpose.INPUT_OUTPUT,
) -> None:
"""Sets the local configuration for the datasource of the current dataset.
@@ -530,22 +522,29 @@ def set_local_config(
server in our docs: https://docs.lightly.ai/getting_started/dataset_creation/dataset_creation_local_server.html
Args:
- resource_path:
- Url to your local file server, for example: "http://localhost:1234/path/to/my/data".
+ relative_path:
+ Relative path from the mount root, for example: "path/to/my/data".
+ web_server_location:
+ Location of your local file server. Defaults to "http://localhost:3456".
thumbnail_suffix:
Where to save thumbnails of the images in the dataset, for
example ".lightly/thumbnails/[filename]_thumb.[extension]".
Set to None to disable thumbnails and use the full images from the
datasource instead.
+ purpose:
+ Datasource purpose, determines if datasource is read only (INPUT)
+ or can be written to as well (LIGHTLY, INPUT_OUTPUT).
+
"""
# TODO: Use DatasourceConfigLocal once we switch/update the api generator.
self._datasources_api.update_datasource_by_dataset_id(
datasource_config=DatasourceConfig.from_dict(
{
"type": "LOCAL",
- "fullPath": resource_path,
+ "webServerLocation": web_server_location,
+ "fullPath": relative_path,
"thumbSuffix": thumbnail_suffix,
- "purpose": DatasourcePurpose.INPUT_OUTPUT,
+ "purpose": purpose,
}
),
dataset_id=self.dataset_id,
@@ -584,7 +583,6 @@ def set_s3_config(
purpose:
Datasource purpose, determines if datasource is read only (INPUT)
or can be written to as well (LIGHTLY, INPUT_OUTPUT).
- The latter is required when Lightly extracts frames from input videos.
"""
# TODO: Use DatasourceConfigS3 once we switch/update the api generator.
@@ -636,7 +634,6 @@ def set_s3_delegated_access_config(
purpose:
Datasource purpose, determines if datasource is read only (INPUT)
or can be written to as well (LIGHTLY, INPUT_OUTPUT).
- The latter is required when Lightly extracts frames from input videos.
"""
# TODO: Use DatasourceConfigS3 once we switch/update the api generator.
@@ -685,7 +682,7 @@ def set_obs_config(
purpose:
Datasource purpose, determines if datasource is read only (INPUT)
or can be written to as well (LIGHTLY, INPUT_OUTPUT).
- The latter is required when Lightly extracts frames from input videos.
+
"""
# TODO: Use DatasourceConfigOBS once we switch/update the api generator.
self._datasources_api.update_datasource_by_dataset_id(
@@ -706,7 +703,7 @@ def set_obs_config(
def get_prediction_read_url(
self,
filename: str,
- ):
+ ) -> str:
"""Returns a read-url for .lightly/predictions/{filename}.
Args:
@@ -717,6 +714,7 @@ def get_prediction_read_url(
A read-url to the file. Note that a URL will be returned even if the file does not
exist.
+ :meta private: # Skip docstring generation
"""
return self._datasources_api.get_prediction_file_read_url_from_datasource_by_dataset_id(
dataset_id=self.dataset_id,
@@ -726,7 +724,7 @@ def get_prediction_read_url(
def get_metadata_read_url(
self,
filename: str,
- ):
+ ) -> str:
"""Returns a read-url for .lightly/metadata/{filename}.
Args:
@@ -737,6 +735,7 @@ def get_metadata_read_url(
A read-url to the file. Note that a URL will be returned even if the file does not
exist.
+ :meta private: # Skip docstring generation
"""
return self._datasources_api.get_metadata_file_read_url_from_datasource_by_dataset_id(
dataset_id=self.dataset_id,
@@ -757,6 +756,7 @@ def get_custom_embedding_read_url(
A read-url to the file. Note that a URL will be returned even if the file does not
exist.
+ :meta private: # Skip docstring generation
"""
return self._datasources_api.get_custom_embedding_file_read_url_from_datasource_by_dataset_id(
dataset_id=self.dataset_id,
@@ -790,3 +790,123 @@ def list_datasource_permissions(
return self._datasources_api.verify_datasource_by_dataset_id(
dataset_id=self.dataset_id,
).to_dict()
+
+ def _download_raw_files(
+ self,
+ download_function: Union[
+ "DatasourcesApi.get_list_of_raw_samples_from_datasource_by_dataset_id",
+ "DatasourcesApi.get_list_of_raw_samples_predictions_from_datasource_by_dataset_id",
+ "DatasourcesApi.get_list_of_raw_samples_metadata_from_datasource_by_dataset_id",
+ ],
+ from_: int = 0,
+ to: Optional[int] = None,
+ relevant_filenames_file_name: Optional[str] = None,
+ use_redirected_read_url: bool = False,
+ progress_bar: Optional[tqdm.tqdm] = None,
+ **kwargs,
+ ) -> List[Tuple[str, str]]:
+ return list(
+ self._download_raw_files_iter(
+ download_function=download_function,
+ from_=from_,
+ to=to,
+ relevant_filenames_file_name=relevant_filenames_file_name,
+ use_redirected_read_url=use_redirected_read_url,
+ progress_bar=progress_bar,
+ **kwargs,
+ )
+ )
+
+ def _download_raw_files_iter(
+ self,
+ download_function: Union[
+ "DatasourcesApi.get_list_of_raw_samples_from_datasource_by_dataset_id",
+ "DatasourcesApi.get_list_of_raw_samples_predictions_from_datasource_by_dataset_id",
+ "DatasourcesApi.get_list_of_raw_samples_metadata_from_datasource_by_dataset_id",
+ ],
+ from_: int = 0,
+ to: Optional[int] = None,
+ relevant_filenames_file_name: Optional[str] = None,
+ use_redirected_read_url: bool = False,
+ progress_bar: Optional[tqdm.tqdm] = None,
+ **kwargs,
+ ) -> Iterator[Tuple[str, str]]:
+ if to is None:
+ to = int(time.time())
+ relevant_filenames_kwargs = (
+ {"relevant_filenames_file_name": relevant_filenames_file_name}
+ if relevant_filenames_file_name
+ else dict()
+ )
+
+ listed_filenames = set()
+
+ def get_samples(
+ response: DatasourceRawSamplesData,
+ ) -> Iterator[Tuple[str, str]]:
+ for sample in response.data:
+ if _sample_unseen_and_valid(
+ sample=sample,
+ relevant_filenames_file_name=relevant_filenames_file_name,
+ listed_filenames=listed_filenames,
+ ):
+ listed_filenames.add(sample.file_name)
+ yield sample.file_name, sample.read_url
+ if progress_bar is not None:
+ progress_bar.update(1)
+
+ response: DatasourceRawSamplesData = download_function(
+ dataset_id=self.dataset_id,
+ var_from=from_,
+ to=to,
+ use_redirected_read_url=use_redirected_read_url,
+ **relevant_filenames_kwargs,
+ **kwargs,
+ )
+ yield from get_samples(response=response)
+ while response.has_more:
+ response: DatasourceRawSamplesData = download_function(
+ dataset_id=self.dataset_id,
+ cursor=response.cursor,
+ use_redirected_read_url=use_redirected_read_url,
+ **relevant_filenames_kwargs,
+ **kwargs,
+ )
+ yield from get_samples(response=response)
+
+
+def _sample_unseen_and_valid(
+ sample: DatasourceRawSamplesDataRow,
+ relevant_filenames_file_name: Optional[str],
+ listed_filenames: Set[str],
+) -> bool:
+ # Note: We want to remove these checks eventually. Absolute paths and relative paths
+ # with dot notation should be handled either in the API or the Worker. Duplicate
+ # filenames should be handled in the Worker as handling it in the API would require
+ # too much memory.
+ if sample.file_name.startswith("/"):
+ warnings.warn(
+ UserWarning(
+ f"Absolute file paths like {sample.file_name} are not supported"
+ f" in relevant filenames file {relevant_filenames_file_name} due to blob storage"
+ )
+ )
+ return False
+ elif sample.file_name.startswith(("./", "../")):
+ warnings.warn(
+ UserWarning(
+ f"Using dot notation ('./', '../') like in {sample.file_name} is not supported"
+ f" in relevant filenames file {relevant_filenames_file_name} due to blob storage"
+ )
+ )
+ return False
+ elif sample.file_name in listed_filenames:
+ warnings.warn(
+ UserWarning(
+ f"Duplicate filename {sample.file_name} in relevant"
+ f" filenames file {relevant_filenames_file_name}"
+ )
+ )
+ return False
+ else:
+ return True
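Taken together, this refactor turns pagination into a lazy generator while keeping the list-returning wrappers for backwards compatibility. A minimal consumption sketch (assuming a configured client; the token, dataset, and task names are placeholders):

    from lightly.api import ApiWorkflowClient

    client = ApiWorkflowClient(token="MY_AWESOME_TOKEN")
    client.set_dataset_id_by_name("my-dataset")

    # Pages are fetched on demand via the cursor; _sample_unseen_and_valid
    # drops absolute paths, dot-notation paths, and duplicates with a warning.
    for filename, read_url in client.download_raw_predictions_iter(
        task_name="object-detection"
    ):
        ...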
diff --git a/lightly/api/api_workflow_predictions.py b/lightly/api/api_workflow_predictions.py
index 3ea7b63af..258620091 100644
--- a/lightly/api/api_workflow_predictions.py
+++ b/lightly/api/api_workflow_predictions.py
@@ -44,7 +44,7 @@ def create_or_update_prediction_task_schema(
>>> )
>>> client.create_or_update_prediction_task_schema(schema=schema)
-
+ :meta private: # Skip docstring generation
"""
self._predictions_api.create_or_update_prediction_task_schema_by_dataset_id(
prediction_task_schema=schema,
@@ -71,6 +71,8 @@ def create_or_update_prediction(
prediction_singletons:
Predictions to be uploaded for the designated sample.
+
+ :meta private: # Skip docstring generation
"""
self._predictions_api.create_or_update_prediction_by_sample_id(
prediction_singleton=prediction_singletons,
diff --git a/lightly/api/api_workflow_selection.py b/lightly/api/api_workflow_selection.py
index f013c5881..760ef2d4d 100644
--- a/lightly/api/api_workflow_selection.py
+++ b/lightly/api/api_workflow_selection.py
@@ -39,6 +39,8 @@ def upload_scores(
and score arrays. The length of each score array must match samples
in the designated tag.
query_tag_id: ID of the desired tag.
+
+ :meta private: # Skip docstring generation
"""
# iterate over all available score types and upload them
for score_type, score_values in al_scores.items():
@@ -79,6 +81,7 @@ def selection(
When `initial-tag` does not exist in the dataset.
When the selection task fails.
+ :meta private: # Skip docstring generation
"""
warnings.warn(
diff --git a/lightly/api/api_workflow_tags.py b/lightly/api/api_workflow_tags.py
index c6b4780bc..be70ebe22 100644
--- a/lightly/api/api_workflow_tags.py
+++ b/lightly/api/api_workflow_tags.py
@@ -124,6 +124,8 @@ def get_filenames_in_tag(
>>> tag = client.get_tag_by_name("cool-tag")
>>> client.get_filenames_in_tag(tag_data=tag)
['image-1.png', 'image-2.png']
+
+ :meta private: # Skip docstring generation
"""
if exclude_parent_tag:
diff --git a/lightly/api/api_workflow_upload_embeddings.py b/lightly/api/api_workflow_upload_embeddings.py
index a8b21a169..bf28c1c2f 100644
--- a/lightly/api/api_workflow_upload_embeddings.py
+++ b/lightly/api/api_workflow_upload_embeddings.py
@@ -31,7 +31,10 @@ def _get_csv_reader_from_read_url(self, read_url: str) -> None:
return reader
def set_embedding_id_to_latest(self) -> None:
- """Sets the embedding ID in the API client to the latest embedding ID in the current dataset."""
+ """Sets the embedding ID in the API client to the latest embedding ID in the current dataset.
+
+ :meta private: # Skip docstring generation
+ """
embeddings_on_server: List[
DatasetEmbeddingData
] = self._embeddings_api.get_embeddings_by_dataset_id(
@@ -104,6 +107,7 @@ def upload_embeddings(self, path_to_embeddings_csv: str, name: str) -> None:
The name of the embedding. If an embedding with such a name already exists on the server,
the upload is aborted.
+ :meta private: # Skip docstring generation
"""
io_utils.check_embeddings(
path_to_embeddings_csv, remove_additional_columns=True
@@ -185,6 +189,7 @@ def append_embeddings(self, path_to_embeddings_csv: str, embedding_id: str) -> None:
If the number of columns in the local embeddings file and that of the remote
embeddings file mismatch.
+ :meta private: # Skip docstring generation
"""
# read embedding from API
@@ -254,7 +259,6 @@ def _order_csv_by_filenames(self, path_to_embeddings_csv: str) -> List[str]:
f"The filenames in the embedding file and "
f"the filenames on the server do not align"
)
- io_utils.check_filenames(filenames)
rows_without_header_ordered = self._order_list_by_filenames(
filenames, rows_without_header
diff --git a/lightly/api/api_workflow_upload_metadata.py b/lightly/api/api_workflow_upload_metadata.py
index e5859b089..87aa55f4e 100644
--- a/lightly/api/api_workflow_upload_metadata.py
+++ b/lightly/api/api_workflow_upload_metadata.py
@@ -67,6 +67,7 @@ def index_custom_metadata_by_filename(
If there are no annotations for a filename, the custom metadata
is None instead.
+ :meta private: # Skip docstring generation
"""
# The mapping is filename -> image_id -> custom_metadata
@@ -141,6 +142,7 @@ def upload_custom_metadata(
max_workers:
Maximum number of concurrent threads during upload.
+ :meta private: # Skip docstring generation
"""
self.verify_custom_metadata_format(custom_metadata)
@@ -241,7 +243,7 @@ def create_custom_metadata_config(
>>> [entry],
>>> )
-
+ :meta private: # Skip docstring generation
"""
config_set_request = ConfigurationSetRequest(name=name, configs=configs)
resp = self._metadata_configurations_api.create_meta_data_configuration(
diff --git a/lightly/api/serve.py b/lightly/api/serve.py
new file mode 100644
index 000000000..716088c05
--- /dev/null
+++ b/lightly/api/serve.py
@@ -0,0 +1,68 @@
+from http.server import HTTPServer, SimpleHTTPRequestHandler
+from pathlib import Path
+from typing import Sequence
+
+
+def get_server(
+ paths: Sequence[Path],
+ host: str,
+ port: int,
+):
+ """Returns an HTTP server that serves a local datasource.
+
+ Args:
+ paths:
+ List of paths to serve.
+ host:
+ Host to serve the datasource on.
+ port:
+ Port to serve the datasource on.
+
+ Examples:
+ >>> from lightly.api import serve
+ >>> from pathlib import Path
+ >>> serve.get_server(
+ >>> paths=[Path("/input_mount"), Path("/lightly_mount")],
+ >>> host="localhost",
+ >>> port=3456,
+ >>> )
+
+ """
+
+ class _LocalDatasourceRequestHandler(SimpleHTTPRequestHandler):
+ def translate_path(self, path: str) -> str:
+ return _translate_path(path=path, directories=paths)
+
+ def do_OPTIONS(self) -> None:
+ self.send_response(204)
+ self.send_header("Access-Control-Allow-Origin", "*")
+ self.send_header("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
+ self.end_headers()
+
+ return HTTPServer((host, port), _LocalDatasourceRequestHandler)
+
+
+def _translate_path(path: str, directories: Sequence[Path]) -> str:
+ """Translates a relative path to a file in the local datasource.
+
+ Tries to resolve the relative path to a file in the first directory
+ and serves it if it exists. Otherwise, it tries to resolve the relative
+ path to a file in the second directory and serves it if it exists, etc.
+
+ Args:
+ path:
+ Relative path to a file in the local datasource.
+ directories:
+ List of directories to search for the file.
+
+
+ Returns:
+ Absolute path to the file in the local datasource or an empty string
+ if the file doesn't exist.
+
+ """
+ stripped_path = path.lstrip("/")
+ for directory in directories:
+ if (directory / stripped_path).exists():
+ return str(directory / stripped_path)
+ return "" # Not found.
diff --git a/lightly/api/utils.py b/lightly/api/utils.py
index 5960fa7f6..d80bb2d67 100644
--- a/lightly/api/utils.py
+++ b/lightly/api/utils.py
@@ -24,7 +24,7 @@
RETRY_MAX_RETRIES = 5
-def retry(func, *args, **kwargs):
+def retry(func, *args, **kwargs): # type: ignore
"""Repeats a function until it completes successfully or fails too often.
Args:
diff --git a/lightly/api/version_checking.py b/lightly/api/version_checking.py
deleted file mode 100644
index 175179a5f..000000000
--- a/lightly/api/version_checking.py
+++ /dev/null
@@ -1,79 +0,0 @@
-import signal
-import warnings
-from typing import Tuple
-
-from lightly.api import utils
-from lightly.api.swagger_api_client import LightlySwaggerApiClient
-from lightly.openapi_generated.swagger_client.api import VersioningApi
-from lightly.utils import version_compare
-
-
-class LightlyAPITimeoutException(Exception):
- pass
-
-
-class TimeoutDecorator:
- def __init__(self, seconds):
- self.seconds = seconds
-
- def handle_timeout_method(self, *args, **kwargs):
- raise LightlyAPITimeoutException
-
- def __enter__(self):
- signal.signal(signal.SIGALRM, self.handle_timeout_method)
- signal.alarm(self.seconds)
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- signal.alarm(0)
-
-
-def is_latest_version(current_version: str) -> bool:
- with TimeoutDecorator(1):
- versioning_api = get_versioning_api()
- latest_version: str = versioning_api.get_latest_pip_version(
- current_version=current_version
- )
- return version_compare.version_compare(current_version, latest_version) >= 0
-
-
-def is_compatible_version(current_version: str) -> bool:
- with TimeoutDecorator(1):
- versioning_api = get_versioning_api()
- minimum_version: str = versioning_api.get_minimum_compatible_pip_version()
- return version_compare.version_compare(current_version, minimum_version) >= 0
-
-
-def get_versioning_api() -> VersioningApi:
- configuration = utils.get_api_client_configuration(
- raise_if_no_token_specified=False,
- )
- api_client = LightlySwaggerApiClient(configuration=configuration)
- versioning_api = VersioningApi(api_client=api_client)
- return versioning_api
-
-
-def get_latest_version(current_version: str) -> Tuple[None, str]:
- try:
- versioning_api = get_versioning_api()
- version_number: str = versioning_api.get_latest_pip_version(
- current_version=current_version
- )
- return version_number
- except Exception as e:
- return None
-
-
-def get_minimum_compatible_version():
- versioning_api = get_versioning_api()
- version_number: str = versioning_api.get_minimum_compatible_pip_version()
- return version_number
-
-
-def pretty_print_latest_version(current_version, latest_version, width=70):
- warning = (
- f"You are using lightly version {current_version}. "
- f"There is a newer version of the package available. "
- f"For compatability reasons, please upgrade your current version: "
- f"pip install lightly=={latest_version}"
- )
- warnings.warn(Warning(warning))
diff --git a/lightly/cli/config/lightly-serve.yaml b/lightly/cli/config/lightly-serve.yaml
new file mode 100644
index 000000000..28ee095e1
--- /dev/null
+++ b/lightly/cli/config/lightly-serve.yaml
@@ -0,0 +1,16 @@
+input_mount: '' # Path to the input directory.
+lightly_mount: '' # Path to the lightly directory.
+host: 'localhost' # Hostname for serving the data.
+port: 3456 # Port for serving the data.
+
+
+# Disable Hydra log directories.
+defaults:
+ - _self_
+ - override hydra/hydra_logging: disabled
+ - override hydra/job_logging: disabled
+
+hydra:
+ output_subdir: null
+ run:
+ dir: .
diff --git a/lightly/cli/serve_cli.py b/lightly/cli/serve_cli.py
new file mode 100644
index 000000000..285e61519
--- /dev/null
+++ b/lightly/cli/serve_cli.py
@@ -0,0 +1,48 @@
+import sys
+from pathlib import Path
+
+import hydra
+
+from lightly.api import serve
+from lightly.cli._helpers import fix_hydra_arguments
+
+
+@hydra.main(**fix_hydra_arguments(config_path="config", config_name="lightly-serve"))
+def lightly_serve(cfg):
+ """Use lightly-serve to serve your data for interactive exploration.
+
+ Command-Line Args:
+ input_mount:
+ Path to the input directory.
+ lightly_mount:
+ Path to the Lightly directory.
+ host:
+ Hostname for serving the data (defaults to localhost).
+ port:
+ Port for serving the data (defaults to 3456).
+
+ Examples:
+ >>> lightly-serve input_mount=data/ lightly_mount=lightly/ port=3456
+
+
+ """
+ if not cfg.input_mount:
+ print("Please provide a valid input mount. Use --help for more information.")
+ sys.exit(1)
+
+ if not cfg.lightly_mount:
+ print("Please provide a valid Lightly mount. Use --help for more information.")
+ sys.exit(1)
+
+ httpd = serve.get_server(
+ paths=[Path(cfg.input_mount), Path(cfg.lightly_mount)],
+ host=cfg.host,
+ port=cfg.port,
+ )
+ print(f"Starting server, listening at '{httpd.server_name}:{httpd.server_port}'")
+ print(f"Serving files in '{cfg.input_mount}' and '{cfg.lightly_mount}'")
+ httpd.serve_forever()
+
+
+def entry() -> None:
+ lightly_serve()
diff --git a/lightly/data/dataset.py b/lightly/data/dataset.py
index 4cd1ce87d..da7fd24b1 100644
--- a/lightly/data/dataset.py
+++ b/lightly/data/dataset.py
@@ -9,13 +9,11 @@
import torchvision.datasets as datasets
from PIL import Image
-from torch._C import Value
from torchvision import transforms
from torchvision.datasets.vision import StandardTransform, VisionDataset
from lightly.data._helpers import DatasetFolder, _load_dataset_from_folder
from lightly.data._video import VideoDataset
-from lightly.utils.io import check_filenames
def _get_filename_by_index(dataset, index):
@@ -177,11 +175,6 @@ def is_valid_file(filepath: str):
if index_to_filename is not None:
self.index_to_filename = index_to_filename
- # if created from an input directory with filenames, check if they
- # are valid
- if input_dir:
- check_filenames(self.get_filenames())
-
@classmethod
def from_torch_dataset(cls, dataset, transform=None, index_to_filename=None):
"""Builds a LightlyDataset from a PyTorch (or torchvision) dataset.
diff --git a/lightly/embedding/__init__.py b/lightly/embedding/__init__.py
index f2bb98c2d..3ee7a5480 100644
--- a/lightly/embedding/__init__.py
+++ b/lightly/embedding/__init__.py
@@ -8,5 +8,6 @@
# Copyright (c) 2020. Lightly AG and its affiliates.
# All Rights Reserved
+
from lightly.embedding._base import BaseEmbedding
from lightly.embedding.embedding import SelfSupervisedEmbedding
diff --git a/lightly/embedding/_base.py b/lightly/embedding/_base.py
index 0a0e0e7cb..8d5086ba6 100644
--- a/lightly/embedding/_base.py
+++ b/lightly/embedding/_base.py
@@ -4,18 +4,34 @@
# All Rights Reserved
import copy
import os
+from typing import Any, List, Optional, Sequence, Tuple, Union
import omegaconf
from omegaconf import DictConfig
from pytorch_lightning import LightningModule, Trainer
-
+from pytorch_lightning.callbacks.callback import Callback
+from torch import Tensor
+from torch.nn import Module
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler
+from torch.utils.data import DataLoader
+
+from lightly.data.dataset import LightlyDataset
from lightly.embedding import callbacks
+from lightly.utils.benchmarking import BenchmarkModule
class BaseEmbedding(LightningModule):
"""All trainable embeddings must inherit from BaseEmbedding."""
- def __init__(self, model, criterion, optimizer, dataloader, scheduler=None):
+ def __init__(
+ self,
+ model: BenchmarkModule,
+ criterion: Module,
+ optimizer: Optimizer,
+ dataloader: DataLoader[LightlyDataset],
+ scheduler: Optional[_LRScheduler] = None,
+ ) -> None:
"""Constructor
Args:
@@ -32,30 +48,35 @@ def __init__(self, model, criterion, optimizer, dataloader, scheduler=None):
self.optimizer = optimizer
self.dataloader = dataloader
self.scheduler = scheduler
- self.checkpoint = None
+ self.checkpoint: Optional[str] = None
self.cwd = os.getcwd()
- def forward(self, x0, x1):
- return self.model(x0, x1)
+ def forward(self, x0: Tensor, x1: Tensor) -> Tensor:
+ embedding: Tensor = self.model(x0, x1)
+ return embedding
- def training_step(self, batch, batch_idx):
+ def training_step(
+ self, batch: Tuple[List[Tensor], Tensor, List[str]], batch_idx: int
+ ) -> Tensor:
# get the two image transformations
(x0, x1), _, _ = batch
# forward pass of the transformations
y0, y1 = self(x0, x1)
# calculate loss
- loss = self.criterion(y0, y1)
+ loss: Tensor = self.criterion(y0, y1)
# log loss and return
self.log("loss", loss)
return loss
- def configure_optimizers(self):
+ def configure_optimizers(
+ self,
+ ) -> Union[Optimizer, Tuple[Sequence[Optimizer], Sequence[_LRScheduler]]]:
if self.scheduler is None:
return self.optimizer
else:
return [self.optimizer], [self.scheduler]
- def train_dataloader(self):
+ def train_dataloader(self) -> DataLoader[LightlyDataset]:
return self.dataloader
def train_embedding(
@@ -63,7 +84,7 @@ def train_embedding(
trainer_config: DictConfig,
checkpoint_callback_config: DictConfig,
summary_callback_config: DictConfig,
- ):
+ ) -> None:
"""Train the model on the provided dataset.
Args:
@@ -80,9 +101,9 @@ def train_embedding(
A trained encoder, ready for embedding datasets.
"""
- trainer_callbacks = []
+ trainer_callbacks: List[Callback] = []
- checkpoint_cb = callbacks.create_checkpoint_callback(
+ checkpoint_cb = callbacks.create_checkpoint_callback( # type: ignore[misc]
**checkpoint_callback_config
)
trainer_callbacks.append(checkpoint_cb)
@@ -100,14 +121,13 @@ def train_embedding(
if "weights_summary" in trainer_config_copy:
with omegaconf.open_dict(trainer_config_copy):
del trainer_config_copy["weights_summary"]
-
- trainer = Trainer(**trainer_config_copy, callbacks=trainer_callbacks)
+ trainer = Trainer(**trainer_config_copy, callbacks=trainer_callbacks) # type: ignore[misc]
trainer.fit(self)
if checkpoint_cb.best_model_path != "":
self.checkpoint = os.path.join(self.cwd, checkpoint_cb.best_model_path)
- def embed(self, *args, **kwargs):
+ def embed(self, *args: Any, **kwargs: Any) -> Any:
"""Must be implemented by classes which inherit from BaseEmbedding."""
raise NotImplementedError()
diff --git a/lightly/embedding/callbacks.py b/lightly/embedding/callbacks.py
index af586c23a..b04aa0ae3 100644
--- a/lightly/embedding/callbacks.py
+++ b/lightly/embedding/callbacks.py
@@ -1,4 +1,5 @@
import os
+from typing import Optional
from omegaconf import DictConfig
from pytorch_lightning.callbacks import ModelCheckpoint, ModelSummary
@@ -7,10 +8,10 @@
def create_checkpoint_callback(
- save_last=False,
- save_top_k=0,
- monitor="loss",
- dirpath=None,
+ save_last: bool = False,
+ save_top_k: int = 0,
+ monitor: str = "loss",
+ dirpath: Optional[str] = None,
) -> ModelCheckpoint:
"""Initializes the checkpoint callback.
@@ -44,7 +45,7 @@ def create_summary_callback(
if weights_summary not in [None, "None"]:
return _create_summary_callback_deprecated(weights_summary)
else:
- return _create_summary_callback(**summary_callback_config)
+ return _create_summary_callback(summary_callback_config["max_depth"])
def _create_summary_callback(max_depth: int) -> ModelSummary:
diff --git a/lightly/embedding/embedding.py b/lightly/embedding/embedding.py
index 5439fea3f..9a1c704c8 100644
--- a/lightly/embedding/embedding.py
+++ b/lightly/embedding/embedding.py
@@ -4,14 +4,20 @@
# All Rights Reserved
import time
-from typing import List, Tuple, Union
+from typing import List, Optional, Tuple
import numpy as np
import torch
+from numpy.typing import NDArray
+from torch.nn import Module
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler
+from torch.utils.data import DataLoader
from tqdm import tqdm
-import lightly
+from lightly.data import LightlyDataset
from lightly.embedding._base import BaseEmbedding
+from lightly.utils.benchmarking import BenchmarkModule
from lightly.utils.reordering import sort_items_by_keys
@@ -59,19 +65,21 @@ class SelfSupervisedEmbedding(BaseEmbedding):
def __init__(
self,
- model: torch.nn.Module,
- criterion: torch.nn.Module,
- optimizer: torch.optim.Optimizer,
- dataloader: torch.utils.data.DataLoader,
- scheduler=None,
- ):
+ model: BenchmarkModule,
+ criterion: Module,
+ optimizer: Optimizer,
+ dataloader: DataLoader[LightlyDataset],
+ scheduler: Optional[_LRScheduler] = None,
+ ) -> None:
super(SelfSupervisedEmbedding, self).__init__(
model, criterion, optimizer, dataloader, scheduler
)
def embed(
- self, dataloader: torch.utils.data.DataLoader, device: torch.device = None
- ) -> Tuple[np.ndarray, List[int], List[str]]:
+ self,
+ dataloader: DataLoader[LightlyDataset],
+ device: Optional[torch.device] = None,
+ ) -> Tuple[NDArray[np.float_], List[int], List[str]]:
"""Embeds images in a vector space.
Args:
@@ -99,15 +107,15 @@ def embed(
"""
self.model.eval()
- embeddings, labels, filenames = None, None, []
+ filenames = []
- dataset = dataloader.dataset
+ dataset: LightlyDataset = dataloader.dataset
pbar = tqdm(total=len(dataset), unit="imgs")
efficiency = 0.0
- embeddings = []
- labels = []
+ embeddings: List[NDArray[np.float_]] = []
+ labels: List[int] = []
with torch.no_grad():
start_timepoint = time.time()
for image_batch, label_batch, filename_batch in dataloader:
@@ -125,8 +133,8 @@ def embed(
embedding_batch = self.model.backbone(image_batch)
embedding_batch = embedding_batch.detach().reshape(batch_size, -1)
- embeddings.append(embedding_batch)
- labels.append(label_batch)
+ embeddings.extend(embedding_batch.cpu().numpy())
+ labels.extend(label_batch.cpu().tolist())
finished_timepoint = time.time()
@@ -140,16 +148,14 @@ def embed(
pbar.update(batch_size)
- embeddings = torch.cat(embeddings, 0)
- labels = torch.cat(labels, 0)
-
- embeddings = embeddings.cpu().numpy()
- labels = labels.cpu().numpy()
-
sorted_filenames = dataset.get_filenames()
- sorted_embeddings = sort_items_by_keys(filenames, embeddings, sorted_filenames)
- sorted_labels = sort_items_by_keys(filenames, labels, sorted_filenames)
- embeddings = np.stack(sorted_embeddings)
- labels = np.stack(sorted_labels).tolist()
+ sorted_embeddings = sort_items_by_keys(
+ keys=filenames,
+ items=embeddings,
+ sorted_keys=sorted_filenames,
+ )
+ sorted_labels = sort_items_by_keys(
+ keys=filenames, items=labels, sorted_keys=sorted_filenames
+ )
- return embeddings, labels, sorted_filenames
+ return np.stack(sorted_embeddings), sorted_labels, sorted_filenames
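The return path now defers all reordering to `sort_items_by_keys`. A minimal sketch of the semantics this relies on (assuming the helper reorders `items` so that their `keys` follow `sorted_keys`):

    from lightly.utils.reordering import sort_items_by_keys

    keys = ["img-2.png", "img-0.png", "img-1.png"]
    items = [2, 0, 1]
    sorted_keys = ["img-0.png", "img-1.png", "img-2.png"]

    assert sort_items_by_keys(keys=keys, items=items, sorted_keys=sorted_keys) == [0, 1, 2]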
diff --git a/lightly/loss/__init__.py b/lightly/loss/__init__.py
index 706045fd1..613707abe 100644
--- a/lightly/loss/__init__.py
+++ b/lightly/loss/__init__.py
@@ -2,7 +2,6 @@
# Copyright (c) 2020. Lightly AG and its affiliates.
# All Rights Reserved
-
from lightly.loss.barlow_twins_loss import BarlowTwinsLoss
from lightly.loss.dcl_loss import DCLLoss, DCLWLoss
from lightly.loss.dino_loss import DINOLoss
diff --git a/lightly/loss/barlow_twins_loss.py b/lightly/loss/barlow_twins_loss.py
index 16a005427..54ee08496 100644
--- a/lightly/loss/barlow_twins_loss.py
+++ b/lightly/loss/barlow_twins_loss.py
@@ -1,5 +1,8 @@
+from typing import Tuple
+
import torch
import torch.distributed as dist
+import torch.nn.functional as F
class BarlowTwinsLoss(torch.nn.Module):
@@ -48,17 +51,14 @@ def __init__(self, lambda_param: float = 5e-3, gather_distributed: bool = False)
)
def forward(self, z_a: torch.Tensor, z_b: torch.Tensor) -> torch.Tensor:
- device = z_a.device
-
# normalize repr. along the batch dimension
- z_a_norm = (z_a - z_a.mean(0)) / z_a.std(0) # NxD
- z_b_norm = (z_b - z_b.mean(0)) / z_b.std(0) # NxD
+ z_a_norm, z_b_norm = _normalize(z_a, z_b)
N = z_a.size(0)
- D = z_a.size(1)
# cross-correlation matrix
- c = torch.mm(z_a_norm.T, z_b_norm) / N # DxD
+ c = z_a_norm.T @ z_b_norm
+ c.div_(N)
# sum cross-correlation matrix between multiple gpus
if self.gather_distributed and dist.is_initialized():
@@ -67,10 +67,31 @@ def forward(self, z_a: torch.Tensor, z_b: torch.Tensor) -> torch.Tensor:
c = c / world_size
dist.all_reduce(c)
- # loss
- c_diff = (c - torch.eye(D, device=device)).pow(2) # DxD
- # multiply off-diagonal elems of c_diff by lambda
- c_diff[~torch.eye(D, dtype=bool)] *= self.lambda_param
- loss = c_diff.sum()
+ invariance_loss = torch.diagonal(c).add_(-1).pow_(2).sum()
+ redundancy_reduction_loss = _off_diagonal(c).pow_(2).sum()
+ loss = invariance_loss + self.lambda_param * redundancy_reduction_loss
return loss
+
+
+def _normalize(
+ z_a: torch.Tensor, z_b: torch.Tensor
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Helper function to normalize tensors along the batch dimension."""
+ combined = torch.stack([z_a, z_b], dim=0) # Shape: 2 x N x D
+ normalized = F.batch_norm(
+ combined.flatten(0, 1),
+ running_mean=None,
+ running_var=None,
+ weight=None,
+ bias=None,
+ training=True,
+ ).view_as(combined)
+ return normalized[0], normalized[1]
+
+
+def _off_diagonal(x: torch.Tensor) -> torch.Tensor:
+ # return a flattened view of the off-diagonal elements of a square matrix
+ n, m = x.shape
+ assert n == m
+ return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()
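The flatten-and-reshape trick in `_off_diagonal` selects exactly the off-diagonal entries, in row-major order, without building a boolean mask. A quick check:

    import torch

    x = torch.arange(9.0).view(3, 3)
    off = x.flatten()[:-1].view(2, 4)[:, 1:].flatten()

    mask = ~torch.eye(3, dtype=torch.bool)
    assert torch.equal(off, x[mask])  # Same elements, same order.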
diff --git a/lightly/loss/pmsn_loss.py b/lightly/loss/pmsn_loss.py
index b8a66b9d2..15e97e2cb 100644
--- a/lightly/loss/pmsn_loss.py
+++ b/lightly/loss/pmsn_loss.py
@@ -68,7 +68,9 @@ def regularization_loss(self, mean_anchor_probs: Tensor) -> Tensor:
exponent=self.power_law_exponent,
device=mean_anchor_probs.device,
)
- loss = F.kl_div(input=mean_anchor_probs, target=power_dist, reduction="sum")
+ loss = F.kl_div(
+ input=mean_anchor_probs.log(), target=power_dist, reduction="sum"
+ )
return loss
@@ -139,7 +141,9 @@ def regularization_loss(self, mean_anchor_probs: Tensor) -> Tensor:
target_dist = self.target_distribution(mean_anchor_probs).to(
mean_anchor_probs.device
)
- loss = F.kl_div(input=mean_anchor_probs, target=target_dist, reduction="sum")
+ loss = F.kl_div(
+ input=mean_anchor_probs.log(), target=target_dist, reduction="sum"
+ )
return loss
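The added `.log()` matters because `F.kl_div` expects log-probabilities as `input` and probabilities as `target`, computing KL(target || input). A minimal check of the convention:

    import torch
    import torch.nn.functional as F

    p = torch.softmax(torch.randn(8), dim=0)  # model probabilities
    q = torch.softmax(torch.randn(8), dim=0)  # target probabilities

    kl = F.kl_div(input=p.log(), target=q, reduction="sum")
    manual = (q * (q.log() - p.log())).sum()  # KL(q || p)
    assert torch.allclose(kl, manual)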
diff --git a/lightly/models/_momentum.py b/lightly/models/_momentum.py
index f7cdec507..94192cc24 100644
--- a/lightly/models/_momentum.py
+++ b/lightly/models/_momentum.py
@@ -4,18 +4,23 @@
# All Rights Reserved
import copy
+from typing import Iterable, Tuple
import torch
import torch.nn as nn
+from torch import Tensor
+from torch.nn.parameter import Parameter
-def _deactivate_requires_grad(params):
+def _deactivate_requires_grad(params: Iterable[Parameter]) -> None:
"""Deactivates the requires_grad flag for all parameters."""
for param in params:
param.requires_grad = False
-def _do_momentum_update(prev_params, params, m):
+def _do_momentum_update(
+ prev_params: Iterable[Parameter], params: Iterable[Parameter], m: float
+) -> None:
"""Updates the weights of the previous parameters."""
for prev_param, param in zip(prev_params, params):
prev_param.data = prev_param.data * m + param.data * (1.0 - m)
@@ -42,7 +47,7 @@ class _MomentumEncoderMixin:
>>> # initialize momentum_backbone and momentum_projection_head
>>> self._init_momentum_encoder()
>>>
- >>> def forward(self, x: torch.Tensor):
+ >>> def forward(self, x: Tensor):
>>>
>>> # do the momentum update
>>> self._momentum_update(0.999)
@@ -59,7 +64,7 @@ class _MomentumEncoderMixin:
momentum_backbone: nn.Module
momentum_projection_head: nn.Module
- def _init_momentum_encoder(self):
+ def _init_momentum_encoder(self) -> None:
"""Initializes momentum backbone and a momentum projection head."""
assert self.backbone is not None
assert self.projection_head is not None
@@ -71,7 +76,7 @@ def _init_momentum_encoder(self):
_deactivate_requires_grad(self.momentum_projection_head.parameters())
@torch.no_grad()
- def _momentum_update(self, m: float = 0.999):
+ def _momentum_update(self, m: float = 0.999) -> None:
"""Performs the momentum update for the backbone and projection head."""
_do_momentum_update(
self.momentum_backbone.parameters(),
@@ -85,14 +90,14 @@ def _momentum_update(self, m: float = 0.999):
)
@torch.no_grad()
- def _batch_shuffle(self, batch: torch.Tensor):
+ def _batch_shuffle(self, batch: Tensor) -> Tuple[Tensor, Tensor]:
"""Returns the shuffled batch and the indices to undo."""
batch_size = batch.shape[0]
shuffle = torch.randperm(batch_size, device=batch.device)
return batch[shuffle], shuffle
@torch.no_grad()
- def _batch_unshuffle(self, batch: torch.Tensor, shuffle: torch.Tensor):
+ def _batch_unshuffle(self, batch: Tensor, shuffle: Tensor) -> Tensor:
"""Returns the unshuffled batch."""
unshuffle = torch.argsort(shuffle)
return batch[unshuffle]
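The helpers above implement a standard exponential moving average over parameter pairs. A condensed sketch of what `_deactivate_requires_grad` and `_do_momentum_update` do together:

    import copy

    import torch.nn as nn

    model = nn.Linear(8, 8)
    momentum_model = copy.deepcopy(model)
    for param in momentum_model.parameters():
        param.requires_grad = False  # frozen; updated only via EMA

    m = 0.999
    for prev_param, param in zip(momentum_model.parameters(), model.parameters()):
        prev_param.data = prev_param.data * m + param.data * (1.0 - m)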
diff --git a/lightly/models/batchnorm.py b/lightly/models/batchnorm.py
index e4424bbad..7f05fe48e 100644
--- a/lightly/models/batchnorm.py
+++ b/lightly/models/batchnorm.py
@@ -3,8 +3,13 @@
# Copyright (c) 2020. Lightly AG and its affiliates.
# All Rights Reserved
+from __future__ import annotations
+
+from typing import Any
+
import torch
import torch.nn as nn
+from torch import Tensor
class SplitBatchNorm(nn.BatchNorm2d):
@@ -21,27 +26,30 @@ class SplitBatchNorm(nn.BatchNorm2d):
"""
- def __init__(self, num_features, num_splits, **kw):
+ def __init__(self, num_features: int, num_splits: int, **kw: Any) -> None:
super().__init__(num_features, **kw)
self.num_splits = num_splits
+ # Register buffers
self.register_buffer(
"running_mean", torch.zeros(num_features * self.num_splits)
)
self.register_buffer("running_var", torch.ones(num_features * self.num_splits))
- def train(self, mode=True):
+ def train(self, mode: bool = True) -> SplitBatchNorm:
# lazily collate stats when we are going to use them
if (self.training is True) and (mode is False):
+ assert self.running_mean is not None
self.running_mean = torch.mean(
self.running_mean.view(self.num_splits, self.num_features), dim=0
).repeat(self.num_splits)
+ assert self.running_var is not None
self.running_var = torch.mean(
self.running_var.view(self.num_splits, self.num_features), dim=0
).repeat(self.num_splits)
return super().train(mode)
- def forward(self, input):
+ def forward(self, input: Tensor) -> Tensor:
"""Computes the SplitBatchNorm on the input."""
# get input shape
N, C, H, W = input.shape
@@ -60,10 +68,12 @@ def forward(self, input):
self.eps,
).view(N, C, H, W)
else:
+ # We have to ignore the type errors here, because we know that running_mean
+ # and running_var are not None, but the type checker does not.
result = nn.functional.batch_norm(
input,
- self.running_mean[: self.num_features],
- self.running_var[: self.num_features],
+ self.running_mean[: self.num_features], # type: ignore[index]
+ self.running_var[: self.num_features], # type: ignore[index]
self.weight,
self.bias,
False,
@@ -74,7 +84,7 @@ def forward(self, input):
return result
-def get_norm_layer(num_features: int, num_splits: int, **kw):
+def get_norm_layer(num_features: int, num_splits: int, **kw: Any) -> nn.Module:
"""Utility to switch between BatchNorm2d and SplitBatchNorm."""
if num_splits > 0:
return SplitBatchNorm(num_features, num_splits)
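A minimal usage sketch of `get_norm_layer`; note that the batch size must be divisible by `num_splits`, since statistics are computed per split:

    import torch

    from lightly.models.batchnorm import get_norm_layer

    norm = get_norm_layer(num_features=32, num_splits=2)  # SplitBatchNorm
    x = torch.randn(8, 32, 16, 16)
    out = norm(x)  # Stats computed over two splits of 4 samples each.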
diff --git a/lightly/models/modules/heads.py b/lightly/models/modules/heads.py
index a6ab268d7..541b49a29 100644
--- a/lightly/models/modules/heads.py
+++ b/lightly/models/modules/heads.py
@@ -7,6 +7,7 @@
import torch
import torch.nn as nn
+from torch import Tensor
from lightly.models import utils
@@ -33,10 +34,10 @@ class ProjectionHead(nn.Module):
def __init__(
self, blocks: List[Tuple[int, int, Optional[nn.Module], Optional[nn.Module]]]
- ):
+ ) -> None:
super(ProjectionHead, self).__init__()
- layers = []
+ layers: List[nn.Module] = []
for input_dim, output_dim, batch_norm, non_linearity in blocks:
use_bias = not bool(batch_norm)
layers.append(nn.Linear(input_dim, output_dim, bias=use_bias))
@@ -46,7 +47,7 @@ def __init__(
layers.append(non_linearity)
self.layers = nn.Sequential(*layers)
- def forward(self, x: torch.Tensor):
+ def forward(self, x: Tensor) -> Tensor:
"""Computes one forward pass through the projection head.
Args:
@@ -54,7 +55,8 @@ def forward(self, x: torch.Tensor):
Input of shape bsz x num_ftrs.
"""
- return self.layers(x)
+ projection: Tensor = self.layers(x)
+ return projection
class BarlowTwinsProjectionHead(ProjectionHead):
@@ -324,7 +326,7 @@ class SMoGPrototypes(nn.Module):
def __init__(
self,
- group_features: torch.Tensor,
+ group_features: Tensor,
beta: float,
):
super(SMoGPrototypes, self).__init__()
@@ -332,8 +334,8 @@ def __init__(
self.beta = beta
def forward(
- self, x: torch.Tensor, group_features: torch.Tensor, temperature: float = 0.1
- ) -> torch.Tensor:
+ self, x: Tensor, group_features: Tensor, temperature: float = 0.1
+ ) -> Tensor:
"""Computes the logits for given model outputs and group features.
Args:
@@ -353,7 +355,7 @@ def forward(
logits = torch.mm(x, group_features.t())
return logits / temperature
- def get_updated_group_features(self, x: torch.Tensor) -> None:
+ def get_updated_group_features(self, x: Tensor) -> Tensor:
"""Performs the synchronous momentum update of the group vectors.
Args:
@@ -370,23 +372,23 @@ def get_updated_group_features(self, x: torch.Tensor) -> None:
mask = assignments == assigned_class
group_features[assigned_class] = self.beta * self.group_features[
assigned_class
- ] + (1 - self.beta) * x[mask].mean(axis=0)
+ ] + (1 - self.beta) * x[mask].mean(dim=0)
return group_features
- def set_group_features(self, x: torch.Tensor) -> None:
+ def set_group_features(self, x: Tensor) -> None:
"""Sets the group features and asserts they don't require gradient."""
self.group_features.data = x.to(self.group_features.device)
@torch.no_grad()
- def assign_groups(self, x: torch.Tensor) -> torch.LongTensor:
+ def assign_groups(self, x: Tensor) -> Tensor:
"""Assigns each representation in x to a group based on cosine similarity.
Args:
Tensor of shape bsz x dim.
Returns:
- LongTensor of shape bsz indicating group assignments.
+ Tensor of shape bsz indicating group assignments.
"""
return torch.argmax(self.forward(x, self.group_features), dim=-1)
@@ -524,19 +526,21 @@ def __init__(
)
self.n_steps_frozen_prototypes = n_steps_frozen_prototypes
- def forward(self, x, step=None) -> Union[torch.Tensor, List[torch.Tensor]]:
+ def forward(
+ self, x: Tensor, step: Optional[int] = None
+ ) -> Union[Tensor, List[Tensor]]:
self._freeze_prototypes_if_required(step)
out = []
for layer in self.heads:
out.append(layer(x))
return out[0] if self._is_single_prototype else out
- def normalize(self):
+ def normalize(self) -> None:
"""Normalizes the prototypes so that they are on the unit sphere."""
for layer in self.heads:
utils.normalize_weight(layer.weight)
- def _freeze_prototypes_if_required(self, step):
+ def _freeze_prototypes_if_required(self, step: Optional[int] = None) -> None:
if self.n_steps_frozen_prototypes > 0:
if step is None:
raise ValueError(
@@ -601,22 +605,23 @@ def __init__(
)
self.apply(self._init_weights)
self.freeze_last_layer = freeze_last_layer
- self.last_layer = nn.utils.weight_norm(
- nn.Linear(bottleneck_dim, output_dim, bias=False)
- )
- self.last_layer.weight_g.data.fill_(1)
+ self.last_layer = nn.Linear(bottleneck_dim, output_dim, bias=False)
+ self.last_layer = nn.utils.weight_norm(self.last_layer)
+ # Tell mypy this is ok because fill_ is overloaded.
+ self.last_layer.weight_g.data.fill_(1) # type: ignore
+
# Option to normalize last layer.
if norm_last_layer:
self.last_layer.weight_g.requires_grad = False
- def cancel_last_layer_gradients(self, current_epoch: int):
+ def cancel_last_layer_gradients(self, current_epoch: int) -> None:
"""Cancel last layer gradients to stabilize the training."""
if current_epoch >= self.freeze_last_layer:
return
for param in self.last_layer.parameters():
param.grad = None
- def _init_weights(self, module):
+ def _init_weights(self, module: nn.Module) -> None:
"""Initializes layers with a truncated normal distribution."""
if isinstance(module, nn.Linear):
utils._no_grad_trunc_normal(
@@ -629,7 +634,7 @@ def _init_weights(self, module):
if module.bias is not None:
nn.init.constant_(module.bias, 0)
- def forward(self, x: torch.Tensor) -> torch.Tensor:
+ def forward(self, x: Tensor) -> Tensor:
"""Computes one forward pass through the head."""
x = self.layers(x)
# l2 normalization
@@ -694,6 +699,37 @@ def __init__(
)
+class VICRegProjectionHead(ProjectionHead):
+ """Projection head used for VICReg.
+
+ "The projector network has three linear layers, each with 8192 output
+ units. The first two layers of the projector are followed by a batch
+ normalization layer and rectified linear units." [0]
+
+ [0]: 2022, VICReg, https://arxiv.org/pdf/2105.04906.pdf
+
+ """
+
+ def __init__(
+ self,
+ input_dim: int = 2048,
+ hidden_dim: int = 8192,
+ output_dim: int = 8192,
+ num_layers: int = 3,
+ ):
+ hidden_layers = [
+ (hidden_dim, hidden_dim, nn.BatchNorm1d(hidden_dim), nn.ReLU())
+ for _ in range(num_layers - 2) # Exclude first and last layer.
+ ]
+ super(VICRegProjectionHead, self).__init__(
+ [
+ (input_dim, hidden_dim, nn.BatchNorm1d(hidden_dim), nn.ReLU()),
+ *hidden_layers,
+ (hidden_dim, output_dim, None, None),
+ ]
+ )
+
+
class VicRegLLocalProjectionHead(ProjectionHead):
"""Projection head used for the local head of VICRegL.
diff --git a/lightly/models/modules/ijepa.py b/lightly/models/modules/ijepa.py
index 3eb14a247..47bfca7ad 100644
--- a/lightly/models/modules/ijepa.py
+++ b/lightly/models/modules/ijepa.py
@@ -72,7 +72,7 @@ def __init__(
self.predictor_pos_embed = nn.Parameter(
torch.zeros(1, num_patches, predictor_embed_dim), requires_grad=False
)
- predictor_pos_embed = _get_2d_sincos_pos_embed(
+ predictor_pos_embed = utils.get_2d_sincos_pos_embed(
self.predictor_pos_embed.shape[-1], int(num_patches**0.5), cls_token=False
)
self.predictor_pos_embed.data.copy_(
@@ -431,66 +431,3 @@ def images_to_tokens(
if prepend_class_token:
tokens = utils.prepend_class_token(tokens, self.class_token)
return tokens
-
-
-def _get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
- """
- grid_size: int of the grid height and width
- return:
- pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
- """
- grid_h = np.arange(grid_size, dtype=float)
- grid_w = np.arange(grid_size, dtype=float)
- grid = np.meshgrid(grid_w, grid_h) # here w goes first
- grid = np.stack(grid, axis=0)
-
- grid = grid.reshape([2, 1, grid_size, grid_size])
- pos_embed = _get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
- if cls_token:
- pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
- return pos_embed
-
-
-def _get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
- assert embed_dim % 2 == 0
-
- # use half of dimensions to encode grid_h
- emb_h = _get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
- emb_w = _get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
-
- emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
- return emb
-
-
-def _get_1d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
- """
- grid_size: int of the grid length
- return:
- pos_embed: [grid_size, embed_dim] or [1+grid_size, embed_dim] (w/ or w/o cls_token)
- """
- grid = np.arange(grid_size, dtype=float)
- pos_embed = _get_1d_sincos_pos_embed_from_grid(embed_dim, grid)
- if cls_token:
- pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
- return pos_embed
-
-
-def _get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
- """
- embed_dim: output dimension for each position
- pos: a list of positions to be encoded: size (M,)
- out: (M, D)
- """
- assert embed_dim % 2 == 0
- omega = np.arange(embed_dim // 2, dtype=float)
- omega /= embed_dim / 2.0
- omega = 1.0 / 10000**omega # (D/2,)
-
- pos = pos.reshape(-1) # (M,)
- out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
-
- emb_sin = np.sin(out) # (M, D/2)
- emb_cos = np.cos(out) # (M, D/2)
-
- emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
- return emb
diff --git a/lightly/models/modules/memory_bank.py b/lightly/models/modules/memory_bank.py
index 255a38308..7eba9a4b5 100644
--- a/lightly/models/modules/memory_bank.py
+++ b/lightly/models/modules/memory_bank.py
@@ -3,7 +3,10 @@
# Copyright (c) 2020. Lightly AG and its affiliates.
# All Rights Reserved
+from typing import Optional, Tuple, Union
+
import torch
+from torch import Tensor
from lightly.models import utils
@@ -32,8 +35,8 @@ class MemoryBankModule(torch.nn.Module):
>>> def __init__(self, memory_bank_size: int = 2 ** 16):
>>> super(MyLossFunction, self).__init__(memory_bank_size)
>>>
- >>> def forward(self, output: torch.Tensor,
- >>> labels: torch.Tensor = None):
+ >>> def forward(self, output: Tensor,
+ >>> labels: Tensor = None):
>>>
>>> output, negatives = super(
>>> MyLossFunction, self).forward(output)
@@ -62,7 +65,7 @@ def __init__(self, size: int = 65536, gather_distributed: bool = False):
)
@torch.no_grad()
- def _init_memory_bank(self, dim: int):
+ def _init_memory_bank(self, dim: int) -> None:
"""Initialize the memory bank if it's empty
Args:
@@ -74,12 +77,12 @@ def _init_memory_bank(self, dim: int):
# we could use register buffers like in the moco repo
# https://github.com/facebookresearch/moco but we don't
# want to pollute our checkpoints
- self.bank = torch.randn(dim, self.size).type_as(self.bank)
- self.bank = torch.nn.functional.normalize(self.bank, dim=0)
- self.bank_ptr = torch.zeros(1).type_as(self.bank_ptr)
+ bank: Tensor = torch.randn(dim, self.size).type_as(self.bank)
+ self.bank: Tensor = torch.nn.functional.normalize(bank, dim=0)
+ self.bank_ptr: Tensor = torch.zeros(1).type_as(self.bank_ptr)
@torch.no_grad()
- def _dequeue_and_enqueue(self, batch: torch.Tensor):
+ def _dequeue_and_enqueue(self, batch: Tensor) -> None:
"""Dequeue the oldest batch and add the latest one
Args:
@@ -100,8 +103,11 @@ def _dequeue_and_enqueue(self, batch: torch.Tensor):
self.bank_ptr[0] = ptr + batch_size
def forward(
- self, output: torch.Tensor, labels: torch.Tensor = None, update: bool = False
- ):
+ self,
+ output: Tensor,
+ labels: Optional[Tensor] = None,
+ update: bool = False,
+ ) -> Union[Tuple[Tensor, Optional[Tensor]], Tensor]:
"""Query memory bank for additional negative samples
Args:
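With the new annotations, `forward` returns the output together with a snapshot of the bank (stored as dim x size). A minimal sketch:

    import torch

    from lightly.models.modules.memory_bank import MemoryBankModule

    bank = MemoryBankModule(size=4096)
    batch = torch.nn.functional.normalize(torch.randn(32, 128), dim=1)

    # The first call lazily initializes the bank with dimension 128.
    output, negatives = bank(batch, update=True)
    assert negatives.shape == (128, 4096)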
diff --git a/lightly/models/modules/nn_memory_bank.py b/lightly/models/modules/nn_memory_bank.py
index 624c3cda4..777fcee0a 100644
--- a/lightly/models/modules/nn_memory_bank.py
+++ b/lightly/models/modules/nn_memory_bank.py
@@ -3,7 +3,10 @@
# Copyright (c) 2021. Lightly AG and its affiliates.
# All Rights Reserved
+from typing import Optional
+
import torch
+from torch import Tensor
from lightly.models.modules.memory_bank import MemoryBankModule
@@ -19,8 +22,7 @@ class NNMemoryBankModule(MemoryBankModule):
Attributes:
size:
- Number of keys the memory bank can store. If set to 0,
- memory bank is not used.
+ Number of keys the memory bank can store.
Examples:
>>> model = NNCLR(backbone)
@@ -38,9 +40,15 @@ class NNMemoryBankModule(MemoryBankModule):
"""
def __init__(self, size: int = 2**16):
+ if size <= 0:
+ raise ValueError(f"Memory bank size must be positive, got {size}.")
super(NNMemoryBankModule, self).__init__(size)
- def forward(self, output: torch.Tensor, update: bool = False):
+ def forward( # type: ignore[override] # TODO(Philipp, 11/23): Fix signature to match parent class.
+ self,
+ output: Tensor,
+ update: bool = False,
+ ) -> Tensor:
"""Returns nearest neighbour of output tensor from memory bank
Args:
@@ -50,6 +58,7 @@ def forward(self, output: torch.Tensor, update: bool = False):
"""
output, bank = super(NNMemoryBankModule, self).forward(output, update=update)
+ assert bank is not None
bank = bank.to(output.device).t()
output_normed = torch.nn.functional.normalize(output, dim=1)
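A minimal sketch of the nearest-neighbour replacement performed by the module:

    import torch

    from lightly.models.modules.nn_memory_bank import NNMemoryBankModule

    nn_bank = NNMemoryBankModule(size=2**16)
    projections = torch.randn(16, 256)

    # Each projection is swapped for its most similar bank entry.
    neighbors = nn_bank(projections, update=True)
    assert neighbors.shape == projections.shape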
diff --git a/lightly/models/resnet.py b/lightly/models/resnet.py
index 44c6c5829..6bff53626 100644
--- a/lightly/models/resnet.py
+++ b/lightly/models/resnet.py
@@ -10,12 +10,13 @@
# Copyright (c) 2020. Lightly AG and its affiliates.
# All Rights Reserved
+from __future__ import annotations
from typing import List
-import torch
import torch.nn as nn
import torch.nn.functional as F
+from torch import Tensor
from lightly.models.batchnorm import get_norm_layer
@@ -62,7 +63,7 @@ def __init__(
get_norm_layer(self.expansion * planes, num_splits),
)
- def forward(self, x: torch.Tensor):
+ def forward(self, x: Tensor) -> Tensor:
"""Forward pass through basic ResNet block.
Args:
@@ -73,7 +74,7 @@ def forward(self, x: torch.Tensor):
Tensor of shape bsz x channels x W x H
"""
- out = self.conv1(x)
+ out: Tensor = self.conv1(x)
out = self.bn1(out)
out = F.relu(out)
@@ -132,7 +133,7 @@ def __init__(
get_norm_layer(self.expansion * planes, num_splits),
)
- def forward(self, x):
+ def forward(self, x: Tensor) -> Tensor:
"""Forward pass through bottleneck ResNet block.
Args:
@@ -143,7 +144,7 @@ def forward(self, x):
Tensor of shape bsz x channels x W x H
"""
- out = self.conv1(x)
+ out: Tensor = self.conv1(x)
out = self.bn1(out)
out = F.relu(out)
@@ -179,7 +180,7 @@ class ResNet(nn.Module):
def __init__(
self,
- block: nn.Module = BasicBlock,
+ block: type[BasicBlock] = BasicBlock,
layers: List[int] = [2, 2, 2, 2],
num_classes: int = 10,
width: float = 1.0,
@@ -208,15 +209,22 @@ def __init__(
)
self.linear = nn.Linear(self.base * 8 * block.expansion, num_classes)
- def _make_layer(self, block, planes, layers, stride, num_splits):
- strides = [stride] + [1] * (layers - 1)
+ def _make_layer(
+ self,
+ block: type[BasicBlock],
+ planes: int,
+ num_layers: int,
+ stride: int,
+ num_splits: int,
+ ) -> nn.Sequential:
+ strides = [stride] + [1] * (num_layers - 1)
layers = []
for stride in strides:
layers.append(block(self.in_planes, planes, stride, num_splits))
self.in_planes = planes * block.expansion
return nn.Sequential(*layers)
- def forward(self, x: torch.Tensor):
+ def forward(self, x: Tensor) -> Tensor:
"""Forward pass through ResNet.
Args:
@@ -243,7 +251,7 @@ def ResNetGenerator(
width: float = 1,
num_classes: int = 10,
num_splits: int = 0,
-):
+) -> ResNet:
"""Builds and returns the specified ResNet.
Args:
@@ -286,8 +294,8 @@ def ResNetGenerator(
)
return ResNet(
- **model_params[name],
+ **model_params[name], # type: ignore # Cannot unpack dict to type "ResNet".
width=width,
num_classes=num_classes,
- num_splits=num_splits
+ num_splits=num_splits,
)
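
The generator is unchanged apart from the annotations; a short sketch of typed usage (the "resnet-18" key is an assumed zoo name, and the 32x32 input matches the CIFAR-style stem of this ResNet implementation):

import torch
from lightly.models.resnet import ResNetGenerator

model = ResNetGenerator(name="resnet-18", width=1.0, num_classes=10)
logits = model(torch.randn(2, 3, 32, 32))  # forward is now typed Tensor -> Tensor
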
diff --git a/lightly/models/utils.py b/lightly/models/utils.py
index 7215395a3..bd7c4e26d 100644
--- a/lightly/models/utils.py
+++ b/lightly/models/utils.py
@@ -7,9 +7,11 @@
import warnings
from typing import Iterable, List, Optional, Tuple, Union
+import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
+from numpy.typing import NDArray
from torch.nn import Module
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.parameter import Parameter
@@ -596,3 +598,119 @@ def repeat_interleave_batch(x, B, repeat):
dim=0,
)
return x
+
+
+def get_2d_sincos_pos_embed(
+ embed_dim: int, grid_size: int, cls_token: bool = False
+) -> NDArray[np.float_]:
+ """Returns 2D sin-cos embeddings. Code from [0].
+
+ - [0]: https://github.com/facebookresearch/ijepa
+
+ Args:
+ embed_dim:
+ Embedding dimension.
+ grid_size:
+ Grid height and width. Should usually be set to sqrt(sequence length).
+ cls_token:
+ If True, a positional embedding for the class token is prepended to the returned
+ embeddings.
+
+ Returns:
+ Positional embeddings array with size (grid_size * grid_size, embed_dim) if cls_token is False.
+ If cls_token is True, a (1 + grid_size * grid_size, embed_dim) array is returned.
+ """
+ grid_h = np.arange(grid_size, dtype=float)
+ grid_w = np.arange(grid_size, dtype=float)
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
+ grid = np.stack(grid, axis=0)
+
+ grid = grid.reshape([2, 1, grid_size, grid_size])
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+ if cls_token:
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
+ return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(
+ embed_dim: int, grid: NDArray[np.int_]
+) -> NDArray[np.float_]:
+ """Returns 2D sin-cos embeddings grid. Code from [0].
+
+ - [0]: https://github.com/facebookresearch/ijepa
+
+ Args:
+ embed_dim:
+ Embedding dimension.
+ grid:
+ 2-dimensional grid to embed.
+
+ Returns:
+ Positional embeddings array with size (grid_size * grid_size, embed_dim).
+ """
+ assert embed_dim % 2 == 0
+
+ # use half of dimensions to encode grid_h
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
+
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
+ return emb
+
+
+def get_1d_sincos_pos_embed(
+ embed_dim: int, grid_size: int, cls_token: bool = False
+) -> NDArray[np.float_]:
+ """Returns 1D sin-cos embeddings. Code from [0].
+
+ - [0]: https://github.com/facebookresearch/ijepa
+
+ Args:
+ embed_dim:
+ Embedding dimension.
+ grid_size:
+ Grid height and width. Should usually be set to sqrt(sequence length).
+ cls_token:
+ If True, a positional embedding for the class token is prepended to the returned
+ embeddings.
+
+ Returns:
+ Positional embeddings array with size (grid_size, embed_dim) if cls_token is False.
+ If cls_token is True, a (1 + grid_size, embed_dim) array is returned.
+ """
+ grid = np.arange(grid_size, dtype=float)
+ pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, grid)
+ if cls_token:
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
+ return pos_embed
+
+
+def get_1d_sincos_pos_embed_from_grid(
+ embed_dim: int, pos: NDArray[np.int_]
+) -> NDArray[np.float_]:
+ """Returns 1D sin-cos embeddings grid. Code from [0].
+
+ - [0]: https://github.com/facebookresearch/ijepa
+
+ Args:
+ embed_dim:
+ Embedding dimension.
+ pos:
+ 1-dimensional grid to embed.
+
+ Returns:
+ Positional embeddings array with size (grid_size, embed_dim).
+ """
+ assert embed_dim % 2 == 0
+ omega = np.arange(embed_dim // 2, dtype=float)
+ omega /= embed_dim / 2.0
+ omega = 1.0 / 10000**omega # (D/2,)
+
+ pos = pos.reshape(-1) # (M,)
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
+
+ emb_sin = np.sin(out) # (M, D/2)
+ emb_cos = np.cos(out) # (M, D/2)
+
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
+ return emb
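
All four helpers implement the standard transformer recipe: half of the embedding dimension carries sin(pos * omega_k) and the other half cos(pos * omega_k), with frequencies omega_k = 1 / 10000**(2k / embed_dim). A quick shape check against the docstrings above (grid size and widths are illustrative):

from lightly.models import utils

# 14x14 patch grid at ViT-Base width, with a class-token embedding prepended.
pos_embed = utils.get_2d_sincos_pos_embed(embed_dim=768, grid_size=14, cls_token=True)
assert pos_embed.shape == (1 + 14 * 14, 768)

# The 1D variant embeds a simple sequence of positions.
seq_embed = utils.get_1d_sincos_pos_embed(embed_dim=64, grid_size=10)
assert seq_embed.shape == (10, 64)
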
diff --git a/lightly/models/zoo.py b/lightly/models/zoo.py
index 34d7dbede..34c6e1868 100644
--- a/lightly/models/zoo.py
+++ b/lightly/models/zoo.py
@@ -3,6 +3,8 @@
# Copyright (c) 2020. Lightly AG and its affiliates.
# All Rights Reserved
+from typing import List
+
ZOO = {
"resnet-9/simclr/d16/w0.0625": "https://storage.googleapis.com/models_boris/whattolabel-resnet9-simclr-d16-w0.0625-i-ce0d6bd9.pth",
"resnet-9/simclr/d16/w0.125": "https://storage.googleapis.com/models_boris/whattolabel-resnet9-simclr-d16-w0.125-i-7269c38d.pth",
@@ -13,7 +15,7 @@
}
-def checkpoints():
+def checkpoints() -> List[str]:
"""Returns the Lightly model zoo as a list of checkpoints.
Checkpoints:
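
The annotated helper simply enumerates the URLs of the ZOO mapping; for example:

from lightly.models.zoo import checkpoints

urls = checkpoints()  # List[str] of pretrained checkpoint URLs
print(urls[0])        # first entry of the ZOO mapping shown above
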
diff --git a/lightly/openapi_generated/swagger_client/__init__.py b/lightly/openapi_generated/swagger_client/__init__.py
index a03a7a3b8..8d3328d74 100644
--- a/lightly/openapi_generated/swagger_client/__init__.py
+++ b/lightly/openapi_generated/swagger_client/__init__.py
@@ -50,6 +50,8 @@
# import models into sdk package
from lightly.openapi_generated.swagger_client.models.active_learning_score_create_request import ActiveLearningScoreCreateRequest
from lightly.openapi_generated.swagger_client.models.active_learning_score_data import ActiveLearningScoreData
+from lightly.openapi_generated.swagger_client.models.active_learning_score_types_v2_data import ActiveLearningScoreTypesV2Data
+from lightly.openapi_generated.swagger_client.models.active_learning_score_v2_data import ActiveLearningScoreV2Data
from lightly.openapi_generated.swagger_client.models.api_error_code import ApiErrorCode
from lightly.openapi_generated.swagger_client.models.api_error_response import ApiErrorResponse
from lightly.openapi_generated.swagger_client.models.async_task_data import AsyncTaskData
@@ -75,10 +77,12 @@
from lightly.openapi_generated.swagger_client.models.datasource_config_azure import DatasourceConfigAzure
from lightly.openapi_generated.swagger_client.models.datasource_config_azure_all_of import DatasourceConfigAzureAllOf
from lightly.openapi_generated.swagger_client.models.datasource_config_base import DatasourceConfigBase
+from lightly.openapi_generated.swagger_client.models.datasource_config_base_full_path import DatasourceConfigBaseFullPath
from lightly.openapi_generated.swagger_client.models.datasource_config_gcs import DatasourceConfigGCS
from lightly.openapi_generated.swagger_client.models.datasource_config_gcs_all_of import DatasourceConfigGCSAllOf
from lightly.openapi_generated.swagger_client.models.datasource_config_lightly import DatasourceConfigLIGHTLY
from lightly.openapi_generated.swagger_client.models.datasource_config_local import DatasourceConfigLOCAL
+from lightly.openapi_generated.swagger_client.models.datasource_config_local_all_of import DatasourceConfigLOCALAllOf
from lightly.openapi_generated.swagger_client.models.datasource_config_obs import DatasourceConfigOBS
from lightly.openapi_generated.swagger_client.models.datasource_config_obs_all_of import DatasourceConfigOBSAllOf
from lightly.openapi_generated.swagger_client.models.datasource_config_s3 import DatasourceConfigS3
@@ -213,13 +217,23 @@
from lightly.openapi_generated.swagger_client.models.sampling_method import SamplingMethod
from lightly.openapi_generated.swagger_client.models.sector import Sector
from lightly.openapi_generated.swagger_client.models.selection_config import SelectionConfig
+from lightly.openapi_generated.swagger_client.models.selection_config_all_of import SelectionConfigAllOf
+from lightly.openapi_generated.swagger_client.models.selection_config_base import SelectionConfigBase
from lightly.openapi_generated.swagger_client.models.selection_config_entry import SelectionConfigEntry
from lightly.openapi_generated.swagger_client.models.selection_config_entry_input import SelectionConfigEntryInput
from lightly.openapi_generated.swagger_client.models.selection_config_entry_strategy import SelectionConfigEntryStrategy
+from lightly.openapi_generated.swagger_client.models.selection_config_v3 import SelectionConfigV3
+from lightly.openapi_generated.swagger_client.models.selection_config_v3_all_of import SelectionConfigV3AllOf
+from lightly.openapi_generated.swagger_client.models.selection_config_v3_entry import SelectionConfigV3Entry
+from lightly.openapi_generated.swagger_client.models.selection_config_v3_entry_input import SelectionConfigV3EntryInput
+from lightly.openapi_generated.swagger_client.models.selection_config_v3_entry_strategy import SelectionConfigV3EntryStrategy
+from lightly.openapi_generated.swagger_client.models.selection_config_v3_entry_strategy_all_of import SelectionConfigV3EntryStrategyAllOf
+from lightly.openapi_generated.swagger_client.models.selection_config_v3_entry_strategy_all_of_target_range import SelectionConfigV3EntryStrategyAllOfTargetRange
from lightly.openapi_generated.swagger_client.models.selection_input_predictions_name import SelectionInputPredictionsName
from lightly.openapi_generated.swagger_client.models.selection_input_type import SelectionInputType
from lightly.openapi_generated.swagger_client.models.selection_strategy_threshold_operation import SelectionStrategyThresholdOperation
from lightly.openapi_generated.swagger_client.models.selection_strategy_type import SelectionStrategyType
+from lightly.openapi_generated.swagger_client.models.selection_strategy_type_v3 import SelectionStrategyTypeV3
from lightly.openapi_generated.swagger_client.models.service_account_basic_data import ServiceAccountBasicData
from lightly.openapi_generated.swagger_client.models.set_embeddings_is_processed_flag_by_id_body_request import SetEmbeddingsIsProcessedFlagByIdBodyRequest
from lightly.openapi_generated.swagger_client.models.shared_access_config_create_request import SharedAccessConfigCreateRequest
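The datasources API diff below adds a HEAD-only readUrl endpoint and deprecates get_datasource_by_dataset_id in favour of getDatasourcesByDatasetId. A hedged sketch of the new endpoint, assuming an ApiClient that has been configured and authenticated elsewhere:

from lightly.openapi_generated.swagger_client.api_client import ApiClient
from lightly.openapi_generated.swagger_client.api.datasources_api import DatasourcesApi

api_client = ApiClient()  # assumed to carry credentials; configuration not shown here
datasources_api = DatasourcesApi(api_client)
# Returns a presigned URL that is valid for HEAD requests only, e.g. to
# check that a file exists without downloading it.
head_url = datasources_api.get_head_file_read_url_from_datasource_by_dataset_id(
    dataset_id="5f7c1a2b3c4d5e6f7a8b9c0d",  # hypothetical dataset ObjectId
    file_name="images/frame-0001.png",  # hypothetical file in the datasource
)
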
diff --git a/lightly/openapi_generated/swagger_client/api/datasources_api.py b/lightly/openapi_generated/swagger_client/api/datasources_api.py
index d409937b4..63d82eac1 100644
--- a/lightly/openapi_generated/swagger_client/api/datasources_api.py
+++ b/lightly/openapi_generated/swagger_client/api/datasources_api.py
@@ -54,10 +54,10 @@ def __init__(self, api_client=None):
self.api_client = api_client
@validate_arguments
- def get_custom_embedding_file_read_url_from_datasource_by_dataset_id(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], file_name : Annotated[constr(strict=True, min_length=4), Field(..., description="The name of the csv file within the embeddings folder to get the readUrl for")], **kwargs) -> str: # noqa: E501
+ def get_custom_embedding_file_read_url_from_datasource_by_dataset_id(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], file_name : Annotated[constr(strict=True, min_length=4), Field(..., description="The name of the csv file within the embeddings folder to get the GET readUrl for")], **kwargs) -> str: # noqa: E501
"""get_custom_embedding_file_read_url_from_datasource_by_dataset_id # noqa: E501
- Get the ReadURL of a custom embedding csv file within the embeddings folder (e.g myCustomEmbedding.csv) # noqa: E501
+ Get the GET ReadURL of a custom embedding csv file within the embeddings folder (e.g. myCustomEmbedding.csv) # noqa: E501
This method makes a synchronous HTTP request by default. To make an
asynchronous HTTP request, please pass async_req=True
@@ -66,7 +66,7 @@ def get_custom_embedding_file_read_url_from_datasource_by_dataset_id(self, datas
:param dataset_id: ObjectId of the dataset (required)
:type dataset_id: str
- :param file_name: The name of the csv file within the embeddings folder to get the readUrl for (required)
+ :param file_name: The name of the csv file within the embeddings folder to get the GET readUrl for (required)
:type file_name: str
:param async_req: Whether to execute the request asynchronously.
:type async_req: bool, optional
@@ -85,10 +85,10 @@ def get_custom_embedding_file_read_url_from_datasource_by_dataset_id(self, datas
return self.get_custom_embedding_file_read_url_from_datasource_by_dataset_id_with_http_info(dataset_id, file_name, **kwargs) # noqa: E501
@validate_arguments
- def get_custom_embedding_file_read_url_from_datasource_by_dataset_id_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], file_name : Annotated[constr(strict=True, min_length=4), Field(..., description="The name of the csv file within the embeddings folder to get the readUrl for")], **kwargs) -> ApiResponse: # noqa: E501
+ def get_custom_embedding_file_read_url_from_datasource_by_dataset_id_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], file_name : Annotated[constr(strict=True, min_length=4), Field(..., description="The name of the csv file within the embeddings folder to get the GET readUrl for")], **kwargs) -> ApiResponse: # noqa: E501
"""get_custom_embedding_file_read_url_from_datasource_by_dataset_id # noqa: E501
- Get the ReadURL of a custom embedding csv file within the embeddings folder (e.g myCustomEmbedding.csv) # noqa: E501
+ Get the GET ReadURL of a custom embedding csv file within the embeddings folder (e.g. myCustomEmbedding.csv) # noqa: E501
This method makes a synchronous HTTP request by default. To make an
asynchronous HTTP request, please pass async_req=True
@@ -97,7 +97,7 @@ def get_custom_embedding_file_read_url_from_datasource_by_dataset_id_with_http_i
:param dataset_id: ObjectId of the dataset (required)
:type dataset_id: str
- :param file_name: The name of the csv file within the embeddings folder to get the readUrl for (required)
+ :param file_name: The name of the csv file within the embeddings folder to get the GET readUrl for (required)
:type file_name: str
:param async_req: Whether to execute the request asynchronously.
:type async_req: bool, optional
@@ -209,9 +209,9 @@ def get_custom_embedding_file_read_url_from_datasource_by_dataset_id_with_http_i
@validate_arguments
def get_datasource_by_dataset_id(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], purpose : Annotated[Optional[DatasourcePurpose], Field(description="Which datasource with which purpose we want to get. Defaults to INPUT_OUTPUT")] = None, **kwargs) -> DatasourceConfig: # noqa: E501
- """get_datasource_by_dataset_id # noqa: E501
+ """(Deprecated) get_datasource_by_dataset_id # noqa: E501
- Get the datasource of a dataset # noqa: E501
+ DEPRECATED - use getDatasourcesByDatasetId. Get the datasource of a dataset # noqa: E501
This method makes a synchronous HTTP request by default. To make an
asynchronous HTTP request, please pass async_req=True
@@ -240,9 +240,9 @@ def get_datasource_by_dataset_id(self, dataset_id : Annotated[constr(strict=True
@validate_arguments
def get_datasource_by_dataset_id_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], purpose : Annotated[Optional[DatasourcePurpose], Field(description="Which datasource with which purpose we want to get. Defaults to INPUT_OUTPUT")] = None, **kwargs) -> ApiResponse: # noqa: E501
- """get_datasource_by_dataset_id # noqa: E501
+ """(Deprecated) get_datasource_by_dataset_id # noqa: E501
- Get the datasource of a dataset # noqa: E501
+ DEPRECATED - use getDatasourcesByDatasetId. Get the datasource of a dataset # noqa: E501
This method makes a synchronous HTTP request by default. To make an
asynchronous HTTP request, please pass async_req=True
@@ -278,6 +278,8 @@ def get_datasource_by_dataset_id_with_http_info(self, dataset_id : Annotated[con
:rtype: tuple(DatasourceConfig, status_code(int), headers(HTTPHeaderDict))
"""
+ warnings.warn("GET /v1/datasets/{datasetId}/datasource is deprecated.", DeprecationWarning)
+
_params = locals()
_all_params = [
@@ -647,6 +649,160 @@ def get_datasources_by_dataset_id_with_http_info(self, dataset_id : Annotated[co
collection_formats=_collection_formats,
_request_auth=_params.get('_request_auth'))
+ @validate_arguments
+ def get_head_file_read_url_from_datasource_by_dataset_id(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], file_name : Annotated[constr(strict=True, min_length=1), Field(..., description="The name of the file within the datasource to get a HEAD readUrl or GET readUrl for")], **kwargs) -> str: # noqa: E501
+ """get_head_file_read_url_from_datasource_by_dataset_id # noqa: E501
+
+ Get a HEAD ReadURL of a file within the datasource. Can only be used for HEAD requests, not GET requests. # noqa: E501
+ This method makes a synchronous HTTP request by default. To make an
+ asynchronous HTTP request, please pass async_req=True
+
+ >>> thread = api.get_head_file_read_url_from_datasource_by_dataset_id(dataset_id, file_name, async_req=True)
+ >>> result = thread.get()
+
+ :param dataset_id: ObjectId of the dataset (required)
+ :type dataset_id: str
+ :param file_name: The name of the file within the datasource to get a HEAD readUrl or GET readUrl for (required)
+ :type file_name: str
+ :param async_req: Whether to execute the request asynchronously.
+ :type async_req: bool, optional
+ :param _request_timeout: timeout setting for this request. If one
+ number provided, it will be total request
+ timeout. It can also be a pair (tuple) of
+ (connection, read) timeouts.
+ :return: Returns the result object.
+ If the method is called asynchronously,
+ returns the request thread.
+ :rtype: str
+ """
+ kwargs['_return_http_data_only'] = True
+ if '_preload_content' in kwargs:
+ raise ValueError("Error! Please call the get_head_file_read_url_from_datasource_by_dataset_id_with_http_info method with `_preload_content` instead and obtain raw data from ApiResponse.raw_data")
+ return self.get_head_file_read_url_from_datasource_by_dataset_id_with_http_info(dataset_id, file_name, **kwargs) # noqa: E501
+
+ @validate_arguments
+ def get_head_file_read_url_from_datasource_by_dataset_id_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], file_name : Annotated[constr(strict=True, min_length=1), Field(..., description="The name of the file within the datasource to get a HEAD readUrl or GET readUrl for")], **kwargs) -> ApiResponse: # noqa: E501
+ """get_head_file_read_url_from_datasource_by_dataset_id # noqa: E501
+
+ Get a HEAD ReadURL of a file within the datasource. Can only be used for HEAD requests, not GET requests. # noqa: E501
+ This method makes a synchronous HTTP request by default. To make an
+ asynchronous HTTP request, please pass async_req=True
+
+ >>> thread = api.get_head_file_read_url_from_datasource_by_dataset_id_with_http_info(dataset_id, file_name, async_req=True)
+ >>> result = thread.get()
+
+ :param dataset_id: ObjectId of the dataset (required)
+ :type dataset_id: str
+ :param file_name: The name of the file within the datasource to get a HEAD readUrl or GET readUrl for (required)
+ :type file_name: str
+ :param async_req: Whether to execute the request asynchronously.
+ :type async_req: bool, optional
+ :param _preload_content: if False, the ApiResponse.data will
+ be set to none and raw_data will store the
+ HTTP response body without reading/decoding.
+ Default is True.
+ :type _preload_content: bool, optional
+ :param _return_http_data_only: response data instead of ApiResponse
+ object with status code, headers, etc
+ :type _return_http_data_only: bool, optional
+ :param _request_timeout: timeout setting for this request. If one
+ number provided, it will be total request
+ timeout. It can also be a pair (tuple) of
+ (connection, read) timeouts.
+ :param _request_auth: set to override the auth_settings for a single
+ request; this effectively ignores the authentication
+ in the spec for a single request.
+ :type _request_auth: dict, optional
+ :type _content_type: string, optional: force content-type for the request
+ :return: Returns the result object.
+ If the method is called asynchronously,
+ returns the request thread.
+ :rtype: tuple(str, status_code(int), headers(HTTPHeaderDict))
+ """
+
+ _params = locals()
+
+ _all_params = [
+ 'dataset_id',
+ 'file_name'
+ ]
+ _all_params.extend(
+ [
+ 'async_req',
+ '_return_http_data_only',
+ '_preload_content',
+ '_request_timeout',
+ '_request_auth',
+ '_content_type',
+ '_headers'
+ ]
+ )
+
+ # validate the arguments
+ for _key, _val in _params['kwargs'].items():
+ if _key not in _all_params:
+ raise ApiTypeError(
+ "Got an unexpected keyword argument '%s'"
+ " to method get_head_file_read_url_from_datasource_by_dataset_id" % _key
+ )
+ _params[_key] = _val
+ del _params['kwargs']
+
+ _collection_formats = {}
+
+ # process the path parameters
+ _path_params = {}
+ if _params['dataset_id']:
+ _path_params['datasetId'] = _params['dataset_id']
+
+
+ # process the query parameters
+ _query_params = []
+ if _params.get('file_name') is not None: # noqa: E501
+ _query_params.append((
+ 'fileName',
+ _params['file_name'].value if hasattr(_params['file_name'], 'value') else _params['file_name']
+ ))
+
+ # process the header parameters
+ _header_params = dict(_params.get('_headers', {}))
+ # process the form parameters
+ _form_params = []
+ _files = {}
+ # process the body parameter
+ _body_params = None
+ # set the HTTP header `Accept`
+ _header_params['Accept'] = self.api_client.select_header_accept(
+ ['application/json']) # noqa: E501
+
+ # authentication setting
+ _auth_settings = ['auth0Bearer', 'ApiKeyAuth'] # noqa: E501
+
+ _response_types_map = {
+ '200': "str",
+ '400': "ApiErrorResponse",
+ '401': "ApiErrorResponse",
+ '403': "ApiErrorResponse",
+ '404': "ApiErrorResponse",
+ }
+
+ return self.api_client.call_api(
+ '/v1/datasets/{datasetId}/datasource/fileHEAD', 'GET',
+ _path_params,
+ _query_params,
+ _header_params,
+ body=_body_params,
+ post_params=_form_params,
+ files=_files,
+ response_types_map=_response_types_map,
+ auth_settings=_auth_settings,
+ async_req=_params.get('async_req'),
+ _return_http_data_only=_params.get('_return_http_data_only'), # noqa: E501
+ _preload_content=_params.get('_preload_content', True),
+ _request_timeout=_params.get('_request_timeout'),
+ collection_formats=_collection_formats,
+ _request_auth=_params.get('_request_auth'))
+
@validate_arguments
def get_list_of_raw_samples_from_datasource_by_dataset_id(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], var_from : Annotated[Optional[conint(strict=True, ge=0)], Field(description="Unix timestamp, only samples with a creation date after `from` will be returned. This parameter is ignored if `cursor` is specified. ")] = None, to : Annotated[Optional[conint(strict=True, ge=0)], Field(description="Unix timestamp, only samples with a creation date before `to` will be returned. This parameter is ignored if `cursor` is specified. ")] = None, cursor : Annotated[Optional[StrictStr], Field(description="Cursor from previous request, encodes `from` and `to` parameters. Specify to continue reading samples from the list. ")] = None, use_redirected_read_url : Annotated[Optional[StrictBool], Field(description="By default this is set to false unless a S3DelegatedAccess is configured in which case its always true and this param has no effect. When true this will return RedirectedReadUrls instead of ReadUrls meaning that returned URLs allow for unlimited access to the file ")] = None, relevant_filenames_file_name : Annotated[Optional[constr(strict=True, min_length=4)], Field(description="The name of the file within your datasource which contains a list of relevant filenames to list. See https://docs.lightly.ai/docker/getting_started/first_steps.html#specify-relevant-files for more details ")] = None, **kwargs) -> DatasourceRawSamplesData: # noqa: E501
"""get_list_of_raw_samples_from_datasource_by_dataset_id # noqa: E501
@@ -1297,10 +1453,10 @@ def get_list_of_raw_samples_predictions_from_datasource_by_dataset_id_with_http_
_request_auth=_params.get('_request_auth'))
@validate_arguments
- def get_metadata_file_read_url_from_datasource_by_dataset_id(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], file_name : Annotated[constr(strict=True, min_length=5), Field(..., description="The name of the file within the metadata folder to get the readUrl for")], **kwargs) -> str: # noqa: E501
+ def get_metadata_file_read_url_from_datasource_by_dataset_id(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], file_name : Annotated[constr(strict=True, min_length=5), Field(..., description="The name of the file within the metadata folder to get the GET readUrl for")], **kwargs) -> str: # noqa: E501
"""get_metadata_file_read_url_from_datasource_by_dataset_id # noqa: E501
- Get the ReadURL of a file within the metadata folder (e.g. my_image.json or my_video-099-mp4.json) # noqa: E501
+ Get the GET ReadURL of a file within the metadata folder (e.g. my_image.json or my_video-099-mp4.json) # noqa: E501
This method makes a synchronous HTTP request by default. To make an
asynchronous HTTP request, please pass async_req=True
@@ -1309,7 +1465,7 @@ def get_metadata_file_read_url_from_datasource_by_dataset_id(self, dataset_id :
:param dataset_id: ObjectId of the dataset (required)
:type dataset_id: str
- :param file_name: The name of the file within the metadata folder to get the readUrl for (required)
+ :param file_name: The name of the file within the metadata folder to get the GET readUrl for (required)
:type file_name: str
:param async_req: Whether to execute the request asynchronously.
:type async_req: bool, optional
@@ -1328,10 +1484,10 @@ def get_metadata_file_read_url_from_datasource_by_dataset_id(self, dataset_id :
return self.get_metadata_file_read_url_from_datasource_by_dataset_id_with_http_info(dataset_id, file_name, **kwargs) # noqa: E501
@validate_arguments
- def get_metadata_file_read_url_from_datasource_by_dataset_id_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], file_name : Annotated[constr(strict=True, min_length=5), Field(..., description="The name of the file within the metadata folder to get the readUrl for")], **kwargs) -> ApiResponse: # noqa: E501
+ def get_metadata_file_read_url_from_datasource_by_dataset_id_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], file_name : Annotated[constr(strict=True, min_length=5), Field(..., description="The name of the file within the metadata folder to get the GET readUrl for")], **kwargs) -> ApiResponse: # noqa: E501
"""get_metadata_file_read_url_from_datasource_by_dataset_id # noqa: E501
- Get the ReadURL of a file within the metadata folder (e.g. my_image.json or my_video-099-mp4.json) # noqa: E501
+ Get the GET ReadURL of a file within the metadata folder (e.g. my_image.json or my_video-099-mp4.json) # noqa: E501
This method makes a synchronous HTTP request by default. To make an
asynchronous HTTP request, please pass async_req=True
@@ -1340,7 +1496,7 @@ def get_metadata_file_read_url_from_datasource_by_dataset_id_with_http_info(self
:param dataset_id: ObjectId of the dataset (required)
:type dataset_id: str
- :param file_name: The name of the file within the metadata folder to get the readUrl for (required)
+ :param file_name: The name of the file within the metadata folder to get the GET readUrl for (required)
:type file_name: str
:param async_req: Whether to execute the request asynchronously.
:type async_req: bool, optional
@@ -1451,10 +1607,10 @@ def get_metadata_file_read_url_from_datasource_by_dataset_id_with_http_info(self
_request_auth=_params.get('_request_auth'))
@validate_arguments
- def get_prediction_file_read_url_from_datasource_by_dataset_id(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], file_name : Annotated[constr(strict=True, min_length=4), Field(..., description="The name of the file within the prediction folder to get the readUrl for")], **kwargs) -> str: # noqa: E501
+ def get_prediction_file_read_url_from_datasource_by_dataset_id(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], file_name : Annotated[constr(strict=True, min_length=4), Field(..., description="The name of the file within the prediction folder to get the GET readUrl for")], **kwargs) -> str: # noqa: E501
"""get_prediction_file_read_url_from_datasource_by_dataset_id # noqa: E501
- Get the ReadURL of a file within the predictions folder (e.g tasks.json or my_classification_task/schema.json) # noqa: E501
+ Get the GET ReadURL of a file within the predictions folder (e.g. tasks.json or my_classification_task/schema.json) # noqa: E501
This method makes a synchronous HTTP request by default. To make an
asynchronous HTTP request, please pass async_req=True
@@ -1463,7 +1619,7 @@ def get_prediction_file_read_url_from_datasource_by_dataset_id(self, dataset_id
:param dataset_id: ObjectId of the dataset (required)
:type dataset_id: str
- :param file_name: The name of the file within the prediction folder to get the readUrl for (required)
+ :param file_name: The name of the file within the prediction folder to get the GET readUrl for (required)
:type file_name: str
:param async_req: Whether to execute the request asynchronously.
:type async_req: bool, optional
@@ -1482,10 +1638,10 @@ def get_prediction_file_read_url_from_datasource_by_dataset_id(self, dataset_id
return self.get_prediction_file_read_url_from_datasource_by_dataset_id_with_http_info(dataset_id, file_name, **kwargs) # noqa: E501
@validate_arguments
- def get_prediction_file_read_url_from_datasource_by_dataset_id_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], file_name : Annotated[constr(strict=True, min_length=4), Field(..., description="The name of the file within the prediction folder to get the readUrl for")], **kwargs) -> ApiResponse: # noqa: E501
+ def get_prediction_file_read_url_from_datasource_by_dataset_id_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], file_name : Annotated[constr(strict=True, min_length=4), Field(..., description="The name of the file within the prediction folder to get the GET readUrl for")], **kwargs) -> ApiResponse: # noqa: E501
"""get_prediction_file_read_url_from_datasource_by_dataset_id # noqa: E501
- Get the ReadURL of a file within the predictions folder (e.g tasks.json or my_classification_task/schema.json) # noqa: E501
+ Get the GET ReadURL of a file within the predictions folder (e.g. tasks.json or my_classification_task/schema.json) # noqa: E501
This method makes a synchronous HTTP request by default. To make an
asynchronous HTTP request, please pass async_req=True
@@ -1494,7 +1650,7 @@ def get_prediction_file_read_url_from_datasource_by_dataset_id_with_http_info(se
:param dataset_id: ObjectId of the dataset (required)
:type dataset_id: str
- :param file_name: The name of the file within the prediction folder to get the readUrl for (required)
+ :param file_name: The name of the file within the prediction folder to get the GET readUrl for (required)
:type file_name: str
:param async_req: Whether to execute the request asynchronously.
:type async_req: bool, optional
@@ -1605,7 +1761,7 @@ def get_prediction_file_read_url_from_datasource_by_dataset_id_with_http_info(se
_request_auth=_params.get('_request_auth'))
@validate_arguments
- def get_prediction_file_write_url_from_datasource_by_dataset_id(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], file_name : Annotated[constr(strict=True, min_length=4), Field(..., description="The name of the file within the prediction folder to get the readUrl for")], **kwargs) -> str: # noqa: E501
+ def get_prediction_file_write_url_from_datasource_by_dataset_id(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], file_name : Annotated[constr(strict=True, min_length=4), Field(..., description="The name of the file within the prediction folder to get the WriteURL for")], **kwargs) -> str: # noqa: E501
"""get_prediction_file_write_url_from_datasource_by_dataset_id # noqa: E501
Get the WriteURL of a file within the predictions folder (e.g tasks.json or my_classification_task/schema.json) # noqa: E501
@@ -1617,7 +1773,7 @@ def get_prediction_file_write_url_from_datasource_by_dataset_id(self, dataset_id
:param dataset_id: ObjectId of the dataset (required)
:type dataset_id: str
- :param file_name: The name of the file within the prediction folder to get the readUrl for (required)
+ :param file_name: The name of the file within the prediction folder to get the WriteURL for (required)
:type file_name: str
:param async_req: Whether to execute the request asynchronously.
:type async_req: bool, optional
@@ -1636,7 +1792,7 @@ def get_prediction_file_write_url_from_datasource_by_dataset_id(self, dataset_id
return self.get_prediction_file_write_url_from_datasource_by_dataset_id_with_http_info(dataset_id, file_name, **kwargs) # noqa: E501
@validate_arguments
- def get_prediction_file_write_url_from_datasource_by_dataset_id_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], file_name : Annotated[constr(strict=True, min_length=4), Field(..., description="The name of the file within the prediction folder to get the readUrl for")], **kwargs) -> ApiResponse: # noqa: E501
+ def get_prediction_file_write_url_from_datasource_by_dataset_id_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], file_name : Annotated[constr(strict=True, min_length=4), Field(..., description="The name of the file within the prediction folder to get the WriteURL for")], **kwargs) -> ApiResponse: # noqa: E501
"""get_prediction_file_write_url_from_datasource_by_dataset_id # noqa: E501
Get the WriteURL of a file within the predictions folder (e.g tasks.json or my_classification_task/schema.json) # noqa: E501
@@ -1648,7 +1804,7 @@ def get_prediction_file_write_url_from_datasource_by_dataset_id_with_http_info(s
:param dataset_id: ObjectId of the dataset (required)
:type dataset_id: str
- :param file_name: The name of the file within the prediction folder to get the readUrl for (required)
+ :param file_name: The name of the file within the prediction folder to get the WriteURL for (required)
:type file_name: str
:param async_req: Whether to execute the request asynchronously.
:type async_req: bool, optional
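For context, a minimal sketch of how the signed-URL helpers above might be called. Nothing in the sketch appears in this diff: the `DatasourcesApi`, `ApiClient`, and `Configuration` names follow the usual OpenAPI-generator layout, and the token and dataset id are placeholders.

```python
# Assumed generator conventions; not part of this diff.
from lightly.openapi_generated.swagger_client.api.datasources_api import DatasourcesApi
from lightly.openapi_generated.swagger_client.api_client import ApiClient
from lightly.openapi_generated.swagger_client.configuration import Configuration

configuration = Configuration()
configuration.api_key["ApiKeyAuth"] = "MY_LIGHTLY_TOKEN"  # placeholder token
datasources_api = DatasourcesApi(ApiClient(configuration))

# GET read URL for a metadata file (file_name must be at least 5 characters).
read_url = datasources_api.get_metadata_file_read_url_from_datasource_by_dataset_id(
    dataset_id="0123456789abcdef01234567",  # placeholder ObjectId
    file_name="my_image.json",
)
```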
diff --git a/lightly/openapi_generated/swagger_client/api/docker_api.py b/lightly/openapi_generated/swagger_client/api/docker_api.py
index 633383cbc..107f16037 100644
--- a/lightly/openapi_generated/swagger_client/api/docker_api.py
+++ b/lightly/openapi_generated/swagger_client/api/docker_api.py
@@ -2150,7 +2150,7 @@ def get_docker_run_logs_by_id_with_http_info(self, run_id : Annotated[constr(str
def get_docker_run_report_read_url_by_id(self, run_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the docker run")], **kwargs) -> str: # noqa: E501
"""(Deprecated) get_docker_run_report_read_url_by_id # noqa: E501
- Get the url of a specific docker runs report # noqa: E501
+ DEPRECATED, use getDockerRunArtifactReadUrlById - Get the url of a specific docker run's report # noqa: E501
This method makes a synchronous HTTP request by default. To make an
asynchronous HTTP request, please pass async_req=True
@@ -2179,7 +2179,7 @@ def get_docker_run_report_read_url_by_id(self, run_id : Annotated[constr(strict=
def get_docker_run_report_read_url_by_id_with_http_info(self, run_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the docker run")], **kwargs) -> ApiResponse: # noqa: E501
"""(Deprecated) get_docker_run_report_read_url_by_id # noqa: E501
- Get the url of a specific docker runs report # noqa: E501
+ DEPRECATED, use getDockerRunArtifactReadUrlById - Get the url of a specific docker run's report # noqa: E501
This method makes a synchronous HTTP request by default. To make an
asynchronous HTTP request, please pass async_req=True
@@ -2295,7 +2295,7 @@ def get_docker_run_report_read_url_by_id_with_http_info(self, run_id : Annotated
def get_docker_run_report_write_url_by_id(self, run_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the docker run")], **kwargs) -> str: # noqa: E501
"""(Deprecated) get_docker_run_report_write_url_by_id # noqa: E501
- Get the signed url to upload a report of a docker run # noqa: E501
+ DEPRECATED, use createDockerRunArtifact - Get the signed url to upload a report of a docker run # noqa: E501
This method makes a synchronous HTTP request by default. To make an
asynchronous HTTP request, please pass async_req=True
@@ -2324,7 +2324,7 @@ def get_docker_run_report_write_url_by_id(self, run_id : Annotated[constr(strict
def get_docker_run_report_write_url_by_id_with_http_info(self, run_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the docker run")], **kwargs) -> ApiResponse: # noqa: E501
"""(Deprecated) get_docker_run_report_write_url_by_id # noqa: E501
- Get the signed url to upload a report of a docker run # noqa: E501
+ DEPRECATED, use createDockerRunArtifact - Get the signed url to upload a report of a docker run # noqa: E501
This method makes a synchronous HTTP request by default. To make an
asynchronous HTTP request, please pass async_req=True
@@ -2580,14 +2580,14 @@ def get_docker_run_tags_with_http_info(self, run_id : Annotated[constr(strict=Tr
_request_auth=_params.get('_request_auth'))
@validate_arguments
- def get_docker_runs(self, page_size : Annotated[Optional[conint(strict=True, ge=1)], Field(description="pagination size/limit of the number of samples to return")] = None, page_offset : Annotated[Optional[conint(strict=True, ge=0)], Field(description="pagination offset")] = None, get_assets_of_team : Annotated[Optional[StrictBool], Field(description="if this flag is true, we get the relevant asset of the team of the user rather than the assets of the user")] = None, get_assets_of_team_inclusive_self : Annotated[Optional[StrictBool], Field(description="if this flag is true, we get the relevant asset of the team of the user including the assets of the user")] = None, **kwargs) -> List[DockerRunData]: # noqa: E501
+ def get_docker_runs(self, page_size : Annotated[Optional[conint(strict=True, ge=1)], Field(description="pagination size/limit of the number of samples to return")] = None, page_offset : Annotated[Optional[conint(strict=True, ge=0)], Field(description="pagination offset")] = None, get_assets_of_team : Annotated[Optional[StrictBool], Field(description="if this flag is true, we get the relevant asset of the team of the user rather than the assets of the user")] = None, get_assets_of_team_inclusive_self : Annotated[Optional[StrictBool], Field(description="if this flag is true, we get the relevant asset of the team of the user including the assets of the user")] = None, show_archived : Annotated[Optional[StrictBool], Field(description="if this flag is true, we also get the archived assets")] = None, **kwargs) -> List[DockerRunData]: # noqa: E501
"""get_docker_runs # noqa: E501
Gets all docker runs for a user. # noqa: E501
This method makes a synchronous HTTP request by default. To make an
asynchronous HTTP request, please pass async_req=True
- >>> thread = api.get_docker_runs(page_size, page_offset, get_assets_of_team, get_assets_of_team_inclusive_self, async_req=True)
+ >>> thread = api.get_docker_runs(page_size, page_offset, get_assets_of_team, get_assets_of_team_inclusive_self, show_archived, async_req=True)
>>> result = thread.get()
:param page_size: pagination size/limit of the number of samples to return
@@ -2598,6 +2598,8 @@ def get_docker_runs(self, page_size : Annotated[Optional[conint(strict=True, ge=
:type get_assets_of_team: bool
:param get_assets_of_team_inclusive_self: if this flag is true, we get the relevant asset of the team of the user including the assets of the user
:type get_assets_of_team_inclusive_self: bool
+ :param show_archived: if this flag is true, we also get the archived assets
+ :type show_archived: bool
:param async_req: Whether to execute the request asynchronously.
:type async_req: bool, optional
:param _request_timeout: timeout setting for this request. If one
@@ -2612,17 +2614,17 @@ def get_docker_runs(self, page_size : Annotated[Optional[conint(strict=True, ge=
kwargs['_return_http_data_only'] = True
if '_preload_content' in kwargs:
raise ValueError("Error! Please call the get_docker_runs_with_http_info method with `_preload_content` instead and obtain raw data from ApiResponse.raw_data")
- return self.get_docker_runs_with_http_info(page_size, page_offset, get_assets_of_team, get_assets_of_team_inclusive_self, **kwargs) # noqa: E501
+ return self.get_docker_runs_with_http_info(page_size, page_offset, get_assets_of_team, get_assets_of_team_inclusive_self, show_archived, **kwargs) # noqa: E501
@validate_arguments
- def get_docker_runs_with_http_info(self, page_size : Annotated[Optional[conint(strict=True, ge=1)], Field(description="pagination size/limit of the number of samples to return")] = None, page_offset : Annotated[Optional[conint(strict=True, ge=0)], Field(description="pagination offset")] = None, get_assets_of_team : Annotated[Optional[StrictBool], Field(description="if this flag is true, we get the relevant asset of the team of the user rather than the assets of the user")] = None, get_assets_of_team_inclusive_self : Annotated[Optional[StrictBool], Field(description="if this flag is true, we get the relevant asset of the team of the user including the assets of the user")] = None, **kwargs) -> ApiResponse: # noqa: E501
+ def get_docker_runs_with_http_info(self, page_size : Annotated[Optional[conint(strict=True, ge=1)], Field(description="pagination size/limit of the number of samples to return")] = None, page_offset : Annotated[Optional[conint(strict=True, ge=0)], Field(description="pagination offset")] = None, get_assets_of_team : Annotated[Optional[StrictBool], Field(description="if this flag is true, we get the relevant asset of the team of the user rather than the assets of the user")] = None, get_assets_of_team_inclusive_self : Annotated[Optional[StrictBool], Field(description="if this flag is true, we get the relevant asset of the team of the user including the assets of the user")] = None, show_archived : Annotated[Optional[StrictBool], Field(description="if this flag is true, we also get the archived assets")] = None, **kwargs) -> ApiResponse: # noqa: E501
"""get_docker_runs # noqa: E501
Gets all docker runs for a user. # noqa: E501
This method makes a synchronous HTTP request by default. To make an
asynchronous HTTP request, please pass async_req=True
- >>> thread = api.get_docker_runs_with_http_info(page_size, page_offset, get_assets_of_team, get_assets_of_team_inclusive_self, async_req=True)
+ >>> thread = api.get_docker_runs_with_http_info(page_size, page_offset, get_assets_of_team, get_assets_of_team_inclusive_self, show_archived, async_req=True)
>>> result = thread.get()
:param page_size: pagination size/limit of the number of samples to return
@@ -2633,6 +2635,8 @@ def get_docker_runs_with_http_info(self, page_size : Annotated[Optional[conint(s
:type get_assets_of_team: bool
:param get_assets_of_team_inclusive_self: if this flag is true, we get the relevant asset of the team of the user including the assets of the user
:type get_assets_of_team_inclusive_self: bool
+ :param show_archived: if this flag is true, we also get the archived assets
+ :type show_archived: bool
:param async_req: Whether to execute the request asynchronously.
:type async_req: bool, optional
:param _preload_content: if False, the ApiResponse.data will
@@ -2664,7 +2668,8 @@ def get_docker_runs_with_http_info(self, page_size : Annotated[Optional[conint(s
'page_size',
'page_offset',
'get_assets_of_team',
- 'get_assets_of_team_inclusive_self'
+ 'get_assets_of_team_inclusive_self',
+ 'show_archived'
]
_all_params.extend(
[
@@ -2719,6 +2724,12 @@ def get_docker_runs_with_http_info(self, page_size : Annotated[Optional[conint(s
_params['get_assets_of_team_inclusive_self'].value if hasattr(_params['get_assets_of_team_inclusive_self'], 'value') else _params['get_assets_of_team_inclusive_self']
))
+ if _params.get('show_archived') is not None: # noqa: E501
+ _query_params.append((
+ 'showArchived',
+ _params['show_archived'].value if hasattr(_params['show_archived'], 'value') else _params['show_archived']
+ ))
+
# process the header parameters
_header_params = dict(_params.get('_headers', {}))
# process the form parameters
@@ -2759,20 +2770,22 @@ def get_docker_runs_with_http_info(self, page_size : Annotated[Optional[conint(s
_request_auth=_params.get('_request_auth'))
@validate_arguments
- def get_docker_runs_count(self, get_assets_of_team : Annotated[Optional[StrictBool], Field(description="if this flag is true, we get the relevant asset of the team of the user rather than the assets of the user")] = None, get_assets_of_team_inclusive_self : Annotated[Optional[StrictBool], Field(description="if this flag is true, we get the relevant asset of the team of the user including the assets of the user")] = None, **kwargs) -> str: # noqa: E501
+ def get_docker_runs_count(self, get_assets_of_team : Annotated[Optional[StrictBool], Field(description="if this flag is true, we get the relevant asset of the team of the user rather than the assets of the user")] = None, get_assets_of_team_inclusive_self : Annotated[Optional[StrictBool], Field(description="if this flag is true, we get the relevant asset of the team of the user including the assets of the user")] = None, show_archived : Annotated[Optional[StrictBool], Field(description="if this flag is true, we also get the archived assets")] = None, **kwargs) -> str: # noqa: E501
"""get_docker_runs_count # noqa: E501
Gets the total count of the amount of runs existing for a user # noqa: E501
This method makes a synchronous HTTP request by default. To make an
asynchronous HTTP request, please pass async_req=True
- >>> thread = api.get_docker_runs_count(get_assets_of_team, get_assets_of_team_inclusive_self, async_req=True)
+ >>> thread = api.get_docker_runs_count(get_assets_of_team, get_assets_of_team_inclusive_self, show_archived, async_req=True)
>>> result = thread.get()
:param get_assets_of_team: if this flag is true, we get the relevant asset of the team of the user rather than the assets of the user
:type get_assets_of_team: bool
:param get_assets_of_team_inclusive_self: if this flag is true, we get the relevant asset of the team of the user including the assets of the user
:type get_assets_of_team_inclusive_self: bool
+ :param show_archived: if this flag is true, we also get the archived assets
+ :type show_archived: bool
:param async_req: Whether to execute the request asynchronously.
:type async_req: bool, optional
:param _request_timeout: timeout setting for this request. If one
@@ -2787,23 +2800,25 @@ def get_docker_runs_count(self, get_assets_of_team : Annotated[Optional[StrictBo
kwargs['_return_http_data_only'] = True
if '_preload_content' in kwargs:
raise ValueError("Error! Please call the get_docker_runs_count_with_http_info method with `_preload_content` instead and obtain raw data from ApiResponse.raw_data")
- return self.get_docker_runs_count_with_http_info(get_assets_of_team, get_assets_of_team_inclusive_self, **kwargs) # noqa: E501
+ return self.get_docker_runs_count_with_http_info(get_assets_of_team, get_assets_of_team_inclusive_self, show_archived, **kwargs) # noqa: E501
@validate_arguments
- def get_docker_runs_count_with_http_info(self, get_assets_of_team : Annotated[Optional[StrictBool], Field(description="if this flag is true, we get the relevant asset of the team of the user rather than the assets of the user")] = None, get_assets_of_team_inclusive_self : Annotated[Optional[StrictBool], Field(description="if this flag is true, we get the relevant asset of the team of the user including the assets of the user")] = None, **kwargs) -> ApiResponse: # noqa: E501
+ def get_docker_runs_count_with_http_info(self, get_assets_of_team : Annotated[Optional[StrictBool], Field(description="if this flag is true, we get the relevant asset of the team of the user rather than the assets of the user")] = None, get_assets_of_team_inclusive_self : Annotated[Optional[StrictBool], Field(description="if this flag is true, we get the relevant asset of the team of the user including the assets of the user")] = None, show_archived : Annotated[Optional[StrictBool], Field(description="if this flag is true, we also get the archived assets")] = None, **kwargs) -> ApiResponse: # noqa: E501
"""get_docker_runs_count # noqa: E501
Gets the total count of the amount of runs existing for a user # noqa: E501
This method makes a synchronous HTTP request by default. To make an
asynchronous HTTP request, please pass async_req=True
- >>> thread = api.get_docker_runs_count_with_http_info(get_assets_of_team, get_assets_of_team_inclusive_self, async_req=True)
+ >>> thread = api.get_docker_runs_count_with_http_info(get_assets_of_team, get_assets_of_team_inclusive_self, show_archived, async_req=True)
>>> result = thread.get()
:param get_assets_of_team: if this flag is true, we get the relevant asset of the team of the user rather than the assets of the user
:type get_assets_of_team: bool
:param get_assets_of_team_inclusive_self: if this flag is true, we get the relevant asset of the team of the user including the assets of the user
:type get_assets_of_team_inclusive_self: bool
+ :param show_archived: if this flag is true, we also get the archived assets
+ :type show_archived: bool
:param async_req: Whether to execute the request asynchronously.
:type async_req: bool, optional
:param _preload_content: if False, the ApiResponse.data will
@@ -2833,7 +2848,8 @@ def get_docker_runs_count_with_http_info(self, get_assets_of_team : Annotated[Op
_all_params = [
'get_assets_of_team',
- 'get_assets_of_team_inclusive_self'
+ 'get_assets_of_team_inclusive_self',
+ 'show_archived'
]
_all_params.extend(
[
@@ -2876,6 +2892,12 @@ def get_docker_runs_count_with_http_info(self, get_assets_of_team : Annotated[Op
_params['get_assets_of_team_inclusive_self'].value if hasattr(_params['get_assets_of_team_inclusive_self'], 'value') else _params['get_assets_of_team_inclusive_self']
))
+ if _params.get('show_archived') is not None: # noqa: E501
+ _query_params.append((
+ 'showArchived',
+ _params['show_archived'].value if hasattr(_params['show_archived'], 'value') else _params['show_archived']
+ ))
+
# process the header parameters
_header_params = dict(_params.get('_headers', {}))
# process the form parameters
diff --git a/lightly/openapi_generated/swagger_client/api/embeddings_api.py b/lightly/openapi_generated/swagger_client/api/embeddings_api.py
index 8f710136d..c38036884 100644
--- a/lightly/openapi_generated/swagger_client/api/embeddings_api.py
+++ b/lightly/openapi_generated/swagger_client/api/embeddings_api.py
@@ -199,7 +199,7 @@ def delete_embedding_by_id_with_http_info(self, dataset_id : Annotated[constr(st
def get_embeddings_by_dataset_id(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], **kwargs) -> List[DatasetEmbeddingData]: # noqa: E501
"""get_embeddings_by_dataset_id # noqa: E501
- Get all annotations of a dataset # noqa: E501
+ Get all embeddings of a dataset # noqa: E501
This method makes a synchronous HTTP request by default. To make an
asynchronous HTTP request, please pass async_req=True
@@ -228,7 +228,7 @@ def get_embeddings_by_dataset_id(self, dataset_id : Annotated[constr(strict=True
def get_embeddings_by_dataset_id_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], **kwargs) -> ApiResponse: # noqa: E501
"""get_embeddings_by_dataset_id # noqa: E501
- Get all annotations of a dataset # noqa: E501
+ Get all embeddings of a dataset # noqa: E501
This method makes a synchronous HTTP request by default. To make an
asynchronous HTTP request, please pass async_req=True
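The embeddings_api.py change only corrects a copy-pasted docstring; behaviour is unchanged. For reference, a sketch of the call it documents, with `EmbeddingsApi` assumed from the generator's naming scheme and the same client setup as above:

```python
from lightly.openapi_generated.swagger_client.api.embeddings_api import EmbeddingsApi

embeddings_api = EmbeddingsApi(ApiClient(configuration))

# Returns List[DatasetEmbeddingData] for the dataset -- embeddings, not annotations.
embeddings = embeddings_api.get_embeddings_by_dataset_id(
    dataset_id="0123456789abcdef01234567",  # placeholder ObjectId
)
```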
diff --git a/lightly/openapi_generated/swagger_client/api/predictions_api.py b/lightly/openapi_generated/swagger_client/api/predictions_api.py
index 057cbdcb7..c188f45fc 100644
--- a/lightly/openapi_generated/swagger_client/api/predictions_api.py
+++ b/lightly/openapi_generated/swagger_client/api/predictions_api.py
@@ -50,24 +50,24 @@ def __init__(self, api_client=None):
self.api_client = api_client
@validate_arguments
- def create_or_update_prediction_by_sample_id(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], sample_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the sample")], prediction_uuid_timestamp : Annotated[conint(strict=True, ge=0), Field(..., description="The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. ")], prediction_singleton : conlist(PredictionSingleton), **kwargs) -> CreateEntityResponse: # noqa: E501
+ def create_or_update_prediction_by_sample_id(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], sample_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the sample")], prediction_singleton : conlist(PredictionSingleton), prediction_uuid_timestamp : Annotated[Optional[conint(strict=True, ge=0)], Field(description="Deprecated, currently ignored. The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. ")] = None, **kwargs) -> CreateEntityResponse: # noqa: E501
"""create_or_update_prediction_by_sample_id # noqa: E501
- Create/Update all the prediction singletons for a sampleId in the order/index of them being discovered # noqa: E501
+ Create/Update all the prediction singletons per taskName for a sampleId, in the order/index in which they were discovered # noqa: E501
This method makes a synchronous HTTP request by default. To make an
asynchronous HTTP request, please pass async_req=True
- >>> thread = api.create_or_update_prediction_by_sample_id(dataset_id, sample_id, prediction_uuid_timestamp, prediction_singleton, async_req=True)
+ >>> thread = api.create_or_update_prediction_by_sample_id(dataset_id, sample_id, prediction_singleton, prediction_uuid_timestamp, async_req=True)
>>> result = thread.get()
:param dataset_id: ObjectId of the dataset (required)
:type dataset_id: str
:param sample_id: ObjectId of the sample (required)
:type sample_id: str
- :param prediction_uuid_timestamp: The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. (required)
- :type prediction_uuid_timestamp: int
:param prediction_singleton: (required)
:type prediction_singleton: List[PredictionSingleton]
+ :param prediction_uuid_timestamp: Deprecated, currently ignored. The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset.
+ :type prediction_uuid_timestamp: int
:param async_req: Whether to execute the request asynchronously.
:type async_req: bool, optional
:param _request_timeout: timeout setting for this request. If one
@@ -82,27 +82,27 @@ def create_or_update_prediction_by_sample_id(self, dataset_id : Annotated[constr
kwargs['_return_http_data_only'] = True
if '_preload_content' in kwargs:
raise ValueError("Error! Please call the create_or_update_prediction_by_sample_id_with_http_info method with `_preload_content` instead and obtain raw data from ApiResponse.raw_data")
- return self.create_or_update_prediction_by_sample_id_with_http_info(dataset_id, sample_id, prediction_uuid_timestamp, prediction_singleton, **kwargs) # noqa: E501
+ return self.create_or_update_prediction_by_sample_id_with_http_info(dataset_id, sample_id, prediction_singleton, prediction_uuid_timestamp, **kwargs) # noqa: E501
@validate_arguments
- def create_or_update_prediction_by_sample_id_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], sample_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the sample")], prediction_uuid_timestamp : Annotated[conint(strict=True, ge=0), Field(..., description="The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. ")], prediction_singleton : conlist(PredictionSingleton), **kwargs) -> ApiResponse: # noqa: E501
+ def create_or_update_prediction_by_sample_id_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], sample_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the sample")], prediction_singleton : conlist(PredictionSingleton), prediction_uuid_timestamp : Annotated[Optional[conint(strict=True, ge=0)], Field(description="Deprecated, currently ignored. The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. ")] = None, **kwargs) -> ApiResponse: # noqa: E501
"""create_or_update_prediction_by_sample_id # noqa: E501
- Create/Update all the prediction singletons for a sampleId in the order/index of them being discovered # noqa: E501
+ Create/Update all the prediction singletons per taskName for a sampleId, in the order/index in which they were discovered # noqa: E501
This method makes a synchronous HTTP request by default. To make an
asynchronous HTTP request, please pass async_req=True
- >>> thread = api.create_or_update_prediction_by_sample_id_with_http_info(dataset_id, sample_id, prediction_uuid_timestamp, prediction_singleton, async_req=True)
+ >>> thread = api.create_or_update_prediction_by_sample_id_with_http_info(dataset_id, sample_id, prediction_singleton, prediction_uuid_timestamp, async_req=True)
>>> result = thread.get()
:param dataset_id: ObjectId of the dataset (required)
:type dataset_id: str
:param sample_id: ObjectId of the sample (required)
:type sample_id: str
- :param prediction_uuid_timestamp: The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. (required)
- :type prediction_uuid_timestamp: int
:param prediction_singleton: (required)
:type prediction_singleton: List[PredictionSingleton]
+ :param prediction_uuid_timestamp: Deprecated, currently ignored. The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset.
+ :type prediction_uuid_timestamp: int
:param async_req: Whether to execute the request asynchronously.
:type async_req: bool, optional
:param _preload_content: if False, the ApiResponse.data will
@@ -133,8 +133,8 @@ def create_or_update_prediction_by_sample_id_with_http_info(self, dataset_id : A
_all_params = [
'dataset_id',
'sample_id',
- 'prediction_uuid_timestamp',
- 'prediction_singleton'
+ 'prediction_singleton',
+ 'prediction_uuid_timestamp'
]
_all_params.extend(
[
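Because `prediction_uuid_timestamp` is now optional (and ignored) and trails `prediction_singleton`, positional callers written against the old order would now pass the timestamp where the singleton list is expected. A sketch of the safe keyword-based call; `PredictionsApi` is the assumed generator class name and the ids and singleton list are placeholders:

```python
from lightly.openapi_generated.swagger_client.api.predictions_api import PredictionsApi

predictions_api = PredictionsApi(ApiClient(configuration))

# Keyword arguments keep the call robust against the parameter reordering;
# the deprecated timestamp can simply be omitted.
response = predictions_api.create_or_update_prediction_by_sample_id(
    dataset_id="0123456789abcdef01234567",          # placeholder ObjectId
    sample_id="89abcdef0123456789abcdef",           # placeholder ObjectId
    prediction_singleton=prediction_singletons,     # List[PredictionSingleton], built elsewhere
)
```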
@@ -227,22 +227,22 @@ def create_or_update_prediction_by_sample_id_with_http_info(self, dataset_id : A
_request_auth=_params.get('_request_auth'))
@validate_arguments
- def create_or_update_prediction_task_schema_by_dataset_id(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], prediction_uuid_timestamp : Annotated[conint(strict=True, ge=0), Field(..., description="The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. ")], prediction_task_schema : PredictionTaskSchema, **kwargs) -> CreateEntityResponse: # noqa: E501
+ def create_or_update_prediction_task_schema_by_dataset_id(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], prediction_task_schema : PredictionTaskSchema, prediction_uuid_timestamp : Annotated[Optional[conint(strict=True, ge=0)], Field(description="Deprecated, currently ignored. The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. ")] = None, **kwargs) -> CreateEntityResponse: # noqa: E501
"""create_or_update_prediction_task_schema_by_dataset_id # noqa: E501
Creates/updates a prediction task schema with the task name # noqa: E501
This method makes a synchronous HTTP request by default. To make an
asynchronous HTTP request, please pass async_req=True
- >>> thread = api.create_or_update_prediction_task_schema_by_dataset_id(dataset_id, prediction_uuid_timestamp, prediction_task_schema, async_req=True)
+ >>> thread = api.create_or_update_prediction_task_schema_by_dataset_id(dataset_id, prediction_task_schema, prediction_uuid_timestamp, async_req=True)
>>> result = thread.get()
:param dataset_id: ObjectId of the dataset (required)
:type dataset_id: str
- :param prediction_uuid_timestamp: The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. (required)
- :type prediction_uuid_timestamp: int
:param prediction_task_schema: (required)
:type prediction_task_schema: PredictionTaskSchema
+ :param prediction_uuid_timestamp: Deprecated, currently ignored. The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset.
+ :type prediction_uuid_timestamp: int
:param async_req: Whether to execute the request asynchronously.
:type async_req: bool, optional
:param _request_timeout: timeout setting for this request. If one
@@ -257,25 +257,25 @@ def create_or_update_prediction_task_schema_by_dataset_id(self, dataset_id : Ann
kwargs['_return_http_data_only'] = True
if '_preload_content' in kwargs:
raise ValueError("Error! Please call the create_or_update_prediction_task_schema_by_dataset_id_with_http_info method with `_preload_content` instead and obtain raw data from ApiResponse.raw_data")
- return self.create_or_update_prediction_task_schema_by_dataset_id_with_http_info(dataset_id, prediction_uuid_timestamp, prediction_task_schema, **kwargs) # noqa: E501
+ return self.create_or_update_prediction_task_schema_by_dataset_id_with_http_info(dataset_id, prediction_task_schema, prediction_uuid_timestamp, **kwargs) # noqa: E501
@validate_arguments
- def create_or_update_prediction_task_schema_by_dataset_id_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], prediction_uuid_timestamp : Annotated[conint(strict=True, ge=0), Field(..., description="The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. ")], prediction_task_schema : PredictionTaskSchema, **kwargs) -> ApiResponse: # noqa: E501
+ def create_or_update_prediction_task_schema_by_dataset_id_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], prediction_task_schema : PredictionTaskSchema, prediction_uuid_timestamp : Annotated[Optional[conint(strict=True, ge=0)], Field(description="Deprecated, currently ignored. The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. ")] = None, **kwargs) -> ApiResponse: # noqa: E501
"""create_or_update_prediction_task_schema_by_dataset_id # noqa: E501
Creates/updates a prediction task schema with the task name # noqa: E501
This method makes a synchronous HTTP request by default. To make an
asynchronous HTTP request, please pass async_req=True
- >>> thread = api.create_or_update_prediction_task_schema_by_dataset_id_with_http_info(dataset_id, prediction_uuid_timestamp, prediction_task_schema, async_req=True)
+ >>> thread = api.create_or_update_prediction_task_schema_by_dataset_id_with_http_info(dataset_id, prediction_task_schema, prediction_uuid_timestamp, async_req=True)
>>> result = thread.get()
:param dataset_id: ObjectId of the dataset (required)
:type dataset_id: str
- :param prediction_uuid_timestamp: The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. (required)
- :type prediction_uuid_timestamp: int
:param prediction_task_schema: (required)
:type prediction_task_schema: PredictionTaskSchema
+ :param prediction_uuid_timestamp: Deprecated, currently ignored. The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset.
+ :type prediction_uuid_timestamp: int
:param async_req: Whether to execute the request asynchronously.
:type async_req: bool, optional
:param _preload_content: if False, the ApiResponse.data will
@@ -305,8 +305,8 @@ def create_or_update_prediction_task_schema_by_dataset_id_with_http_info(self, d
_all_params = [
'dataset_id',
- 'prediction_uuid_timestamp',
- 'prediction_task_schema'
+ 'prediction_task_schema',
+ 'prediction_uuid_timestamp'
]
_all_params.extend(
[
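The same reordering applies to the task-schema upload, sketched under the same assumptions and reusing `predictions_api` from the sketch above:

```python
# prediction_task_schema is a PredictionTaskSchema instance built elsewhere.
response = predictions_api.create_or_update_prediction_task_schema_by_dataset_id(
    dataset_id="0123456789abcdef01234567",  # placeholder ObjectId
    prediction_task_schema=prediction_task_schema,
)
```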
@@ -395,21 +395,21 @@ def create_or_update_prediction_task_schema_by_dataset_id_with_http_info(self, d
_request_auth=_params.get('_request_auth'))
@validate_arguments
- def get_prediction_by_sample_id(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], sample_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the sample")], prediction_uuid_timestamp : Annotated[conint(strict=True, ge=0), Field(..., description="The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. ")], **kwargs) -> List[PredictionSingleton]: # noqa: E501
- """get_prediction_by_sample_id # noqa: E501
+ def get_prediction_task_schema_by_task_name(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], task_name : Annotated[constr(strict=True, min_length=1), Field(..., description="The prediction task name for which one wants to list the predictions")], prediction_uuid_timestamp : Annotated[Optional[conint(strict=True, ge=0)], Field(description="Deprecated, currently ignored. The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. ")] = None, **kwargs) -> PredictionTaskSchema: # noqa: E501
+ """get_prediction_task_schema_by_task_name # noqa: E501
- Get all prediction singletons of a specific sample of a dataset # noqa: E501
+ Get the prediction task schema named taskName for a datasetId # noqa: E501
This method makes a synchronous HTTP request by default. To make an
asynchronous HTTP request, please pass async_req=True
- >>> thread = api.get_prediction_by_sample_id(dataset_id, sample_id, prediction_uuid_timestamp, async_req=True)
+ >>> thread = api.get_prediction_task_schema_by_task_name(dataset_id, task_name, prediction_uuid_timestamp, async_req=True)
>>> result = thread.get()
:param dataset_id: ObjectId of the dataset (required)
:type dataset_id: str
- :param sample_id: ObjectId of the sample (required)
- :type sample_id: str
- :param prediction_uuid_timestamp: The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. (required)
+ :param task_name: The prediction task name for which one wants to list the predictions (required)
+ :type task_name: str
+ :param prediction_uuid_timestamp: Deprecated, currently ignored. The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset.
:type prediction_uuid_timestamp: int
:param async_req: Whether to execute the request asynchronously.
:type async_req: bool, optional
@@ -420,29 +420,29 @@ def get_prediction_by_sample_id(self, dataset_id : Annotated[constr(strict=True)
:return: Returns the result object.
If the method is called asynchronously,
returns the request thread.
- :rtype: List[PredictionSingleton]
+ :rtype: PredictionTaskSchema
"""
kwargs['_return_http_data_only'] = True
if '_preload_content' in kwargs:
- raise ValueError("Error! Please call the get_prediction_by_sample_id_with_http_info method with `_preload_content` instead and obtain raw data from ApiResponse.raw_data")
- return self.get_prediction_by_sample_id_with_http_info(dataset_id, sample_id, prediction_uuid_timestamp, **kwargs) # noqa: E501
+ raise ValueError("Error! Please call the get_prediction_task_schema_by_task_name_with_http_info method with `_preload_content` instead and obtain raw data from ApiResponse.raw_data")
+ return self.get_prediction_task_schema_by_task_name_with_http_info(dataset_id, task_name, prediction_uuid_timestamp, **kwargs) # noqa: E501
@validate_arguments
- def get_prediction_by_sample_id_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], sample_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the sample")], prediction_uuid_timestamp : Annotated[conint(strict=True, ge=0), Field(..., description="The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. ")], **kwargs) -> ApiResponse: # noqa: E501
- """get_prediction_by_sample_id # noqa: E501
+ def get_prediction_task_schema_by_task_name_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], task_name : Annotated[constr(strict=True, min_length=1), Field(..., description="The prediction task name for which one wants to list the predictions")], prediction_uuid_timestamp : Annotated[Optional[conint(strict=True, ge=0)], Field(description="Deprecated, currently ignored. The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. ")] = None, **kwargs) -> ApiResponse: # noqa: E501
+ """get_prediction_task_schema_by_task_name # noqa: E501
- Get all prediction singletons of a specific sample of a dataset # noqa: E501
+ Get the prediction task schema named taskName for a datasetId # noqa: E501
This method makes a synchronous HTTP request by default. To make an
asynchronous HTTP request, please pass async_req=True
- >>> thread = api.get_prediction_by_sample_id_with_http_info(dataset_id, sample_id, prediction_uuid_timestamp, async_req=True)
+ >>> thread = api.get_prediction_task_schema_by_task_name_with_http_info(dataset_id, task_name, prediction_uuid_timestamp, async_req=True)
>>> result = thread.get()
:param dataset_id: ObjectId of the dataset (required)
:type dataset_id: str
- :param sample_id: ObjectId of the sample (required)
- :type sample_id: str
- :param prediction_uuid_timestamp: The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. (required)
+ :param task_name: The prediction task name for which one wants to list the predictions (required)
+ :type task_name: str
+ :param prediction_uuid_timestamp: Deprecated, currently ignored. The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset.
:type prediction_uuid_timestamp: int
:param async_req: Whether to execute the request asynchronously.
:type async_req: bool, optional
@@ -466,14 +466,14 @@ def get_prediction_by_sample_id_with_http_info(self, dataset_id : Annotated[cons
:return: Returns the result object.
If the method is called asynchronously,
returns the request thread.
- :rtype: tuple(List[PredictionSingleton], status_code(int), headers(HTTPHeaderDict))
+ :rtype: tuple(PredictionTaskSchema, status_code(int), headers(HTTPHeaderDict))
"""
_params = locals()
_all_params = [
'dataset_id',
- 'sample_id',
+ 'task_name',
'prediction_uuid_timestamp'
]
_all_params.extend(
@@ -493,7 +493,7 @@ def get_prediction_by_sample_id_with_http_info(self, dataset_id : Annotated[cons
if _key not in _all_params:
raise ApiTypeError(
"Got an unexpected keyword argument '%s'"
- " to method get_prediction_by_sample_id" % _key
+ " to method get_prediction_task_schema_by_task_name" % _key
)
_params[_key] = _val
del _params['kwargs']
@@ -505,8 +505,8 @@ def get_prediction_by_sample_id_with_http_info(self, dataset_id : Annotated[cons
if _params['dataset_id']:
_path_params['datasetId'] = _params['dataset_id']
- if _params['sample_id']:
- _path_params['sampleId'] = _params['sample_id']
+ if _params['task_name']:
+ _path_params['taskName'] = _params['task_name']
# process the query parameters
@@ -532,7 +532,7 @@ def get_prediction_by_sample_id_with_http_info(self, dataset_id : Annotated[cons
_auth_settings = ['auth0Bearer', 'ApiKeyAuth'] # noqa: E501
_response_types_map = {
- '200': "List[PredictionSingleton]",
+ '200': "PredictionTaskSchema",
'400': "ApiErrorResponse",
'401': "ApiErrorResponse",
'403': "ApiErrorResponse",
@@ -540,7 +540,7 @@ def get_prediction_by_sample_id_with_http_info(self, dataset_id : Annotated[cons
}
return self.api_client.call_api(
- '/v1/datasets/{datasetId}/predictions/samples/{sampleId}', 'GET',
+ '/v1/datasets/{datasetId}/predictions/tasks/{taskName}', 'GET',
_path_params,
_query_params,
_header_params,
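With the per-sample getter replaced, the schema of a single task is now addressed by name via `/v1/datasets/{datasetId}/predictions/tasks/{taskName}`. A sketch of the renamed call, reusing `predictions_api` from the sketch above:

```python
schema = predictions_api.get_prediction_task_schema_by_task_name(
    dataset_id="0123456789abcdef01234567",  # placeholder ObjectId
    task_name="my_classification_task",     # placeholder task name
)
```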
@@ -557,22 +557,20 @@ def get_prediction_by_sample_id_with_http_info(self, dataset_id : Annotated[cons
_request_auth=_params.get('_request_auth'))
@validate_arguments
- def get_prediction_task_schema_by_task_name(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], prediction_uuid_timestamp : Annotated[conint(strict=True, ge=0), Field(..., description="The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. ")], task_name : Annotated[constr(strict=True, min_length=1), Field(..., description="The prediction task name for which one wants to list the predictions")], **kwargs) -> PredictionTaskSchema: # noqa: E501
- """get_prediction_task_schema_by_task_name # noqa: E501
+ def get_prediction_task_schemas_by_dataset_id(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], prediction_uuid_timestamp : Annotated[Optional[conint(strict=True, ge=0)], Field(description="Deprecated, currently ignored. The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. ")] = None, **kwargs) -> PredictionTaskSchemas: # noqa: E501
+ """get_prediction_task_schemas_by_dataset_id # noqa: E501
- Get a prediction task schemas named taskName for a datasetId # noqa: E501
+ Get a list of all the prediction task schemas for a datasetId at a specific predictionUUIDTimestamp. If no predictionUUIDTimestamp is set, it defaults to the newest. # noqa: E501
This method makes a synchronous HTTP request by default. To make an
asynchronous HTTP request, please pass async_req=True
- >>> thread = api.get_prediction_task_schema_by_task_name(dataset_id, prediction_uuid_timestamp, task_name, async_req=True)
+ >>> thread = api.get_prediction_task_schemas_by_dataset_id(dataset_id, prediction_uuid_timestamp, async_req=True)
>>> result = thread.get()
:param dataset_id: ObjectId of the dataset (required)
:type dataset_id: str
- :param prediction_uuid_timestamp: The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. (required)
+ :param prediction_uuid_timestamp: Deprecated, currently ignored. The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset.
:type prediction_uuid_timestamp: int
- :param task_name: The prediction task name for which one wants to list the predictions (required)
- :type task_name: str
:param async_req: Whether to execute the request asynchronously.
:type async_req: bool, optional
:param _request_timeout: timeout setting for this request. If one
@@ -582,30 +580,28 @@ def get_prediction_task_schema_by_task_name(self, dataset_id : Annotated[constr(
:return: Returns the result object.
If the method is called asynchronously,
returns the request thread.
- :rtype: PredictionTaskSchema
+ :rtype: PredictionTaskSchemas
"""
kwargs['_return_http_data_only'] = True
if '_preload_content' in kwargs:
- raise ValueError("Error! Please call the get_prediction_task_schema_by_task_name_with_http_info method with `_preload_content` instead and obtain raw data from ApiResponse.raw_data")
- return self.get_prediction_task_schema_by_task_name_with_http_info(dataset_id, prediction_uuid_timestamp, task_name, **kwargs) # noqa: E501
+ raise ValueError("Error! Please call the get_prediction_task_schemas_by_dataset_id_with_http_info method with `_preload_content` instead and obtain raw data from ApiResponse.raw_data")
+ return self.get_prediction_task_schemas_by_dataset_id_with_http_info(dataset_id, prediction_uuid_timestamp, **kwargs) # noqa: E501
@validate_arguments
- def get_prediction_task_schema_by_task_name_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], prediction_uuid_timestamp : Annotated[conint(strict=True, ge=0), Field(..., description="The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. ")], task_name : Annotated[constr(strict=True, min_length=1), Field(..., description="The prediction task name for which one wants to list the predictions")], **kwargs) -> ApiResponse: # noqa: E501
- """get_prediction_task_schema_by_task_name # noqa: E501
+ def get_prediction_task_schemas_by_dataset_id_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], prediction_uuid_timestamp : Annotated[Optional[conint(strict=True, ge=0)], Field(description="Deprecated, currently ignored. The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. ")] = None, **kwargs) -> ApiResponse: # noqa: E501
+ """get_prediction_task_schemas_by_dataset_id # noqa: E501
- Get a prediction task schemas named taskName for a datasetId # noqa: E501
+ Get a list of all the prediction task schemas for a datasetId at a specific predictionUUIDTimestamp. If no predictionUUIDTimestamp is set, it defaults to the newest. # noqa: E501
This method makes a synchronous HTTP request by default. To make an
asynchronous HTTP request, please pass async_req=True
- >>> thread = api.get_prediction_task_schema_by_task_name_with_http_info(dataset_id, prediction_uuid_timestamp, task_name, async_req=True)
+ >>> thread = api.get_prediction_task_schemas_by_dataset_id_with_http_info(dataset_id, prediction_uuid_timestamp, async_req=True)
>>> result = thread.get()
:param dataset_id: ObjectId of the dataset (required)
:type dataset_id: str
- :param prediction_uuid_timestamp: The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. (required)
+ :param prediction_uuid_timestamp: Deprecated, currently ignored. The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset.
:type prediction_uuid_timestamp: int
- :param task_name: The prediction task name for which one wants to list the predictions (required)
- :type task_name: str
:param async_req: Whether to execute the request asynchronously.
:type async_req: bool, optional
:param _preload_content: if False, the ApiResponse.data will
@@ -628,15 +624,14 @@ def get_prediction_task_schema_by_task_name_with_http_info(self, dataset_id : An
:return: Returns the result object.
If the method is called asynchronously,
returns the request thread.
- :rtype: tuple(PredictionTaskSchema, status_code(int), headers(HTTPHeaderDict))
+ :rtype: tuple(PredictionTaskSchemas, status_code(int), headers(HTTPHeaderDict))
"""
_params = locals()
_all_params = [
'dataset_id',
- 'prediction_uuid_timestamp',
- 'task_name'
+ 'prediction_uuid_timestamp'
]
_all_params.extend(
[
@@ -655,7 +650,7 @@ def get_prediction_task_schema_by_task_name_with_http_info(self, dataset_id : An
if _key not in _all_params:
raise ApiTypeError(
"Got an unexpected keyword argument '%s'"
- " to method get_prediction_task_schema_by_task_name" % _key
+ " to method get_prediction_task_schemas_by_dataset_id" % _key
)
_params[_key] = _val
del _params['kwargs']
@@ -667,9 +662,6 @@ def get_prediction_task_schema_by_task_name_with_http_info(self, dataset_id : An
if _params['dataset_id']:
_path_params['datasetId'] = _params['dataset_id']
- if _params['task_name']:
- _path_params['taskName'] = _params['task_name']
-
# process the query parameters
_query_params = []
@@ -694,7 +686,7 @@ def get_prediction_task_schema_by_task_name_with_http_info(self, dataset_id : An
_auth_settings = ['auth0Bearer', 'ApiKeyAuth'] # noqa: E501
_response_types_map = {
- '200': "PredictionTaskSchema",
+ '200': "PredictionTaskSchemas",
'400': "ApiErrorResponse",
'401': "ApiErrorResponse",
'403': "ApiErrorResponse",
@@ -702,7 +694,7 @@ def get_prediction_task_schema_by_task_name_with_http_info(self, dataset_id : An
}
return self.api_client.call_api(
- '/v1/datasets/{datasetId}/predictions/tasks/{taskName}', 'GET',
+ '/v1/datasets/{datasetId}/predictions/tasks', 'GET',
_path_params,
_query_params,
_header_params,
@@ -719,20 +711,22 @@ def get_prediction_task_schema_by_task_name_with_http_info(self, dataset_id : An
_request_auth=_params.get('_request_auth'))
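For reference, a minimal caller-side sketch of the renamed endpoint (not part of the generated diff). It assumes the usual swagger-codegen layout, i.e. a PredictionsApi wrapper class in predictions_api.py plus the standard ApiClient/Configuration pair; the token and dataset id are placeholders, and the loop assumes PredictionTaskSchemas exposes its entries as `schemas`.

from lightly.openapi_generated.swagger_client.api.predictions_api import PredictionsApi
from lightly.openapi_generated.swagger_client.api_client import ApiClient
from lightly.openapi_generated.swagger_client.configuration import Configuration

configuration = Configuration()
configuration.api_key["ApiKeyAuth"] = "MY_LIGHTLY_TOKEN"  # placeholder token
predictions_api = PredictionsApi(ApiClient(configuration))

# The per-task endpoint /predictions/tasks/{taskName} is gone; a single
# call now returns every prediction task schema of the dataset.
task_schemas = predictions_api.get_prediction_task_schemas_by_dataset_id(
    dataset_id="0123456789abcdef01234567",  # placeholder ObjectId
)
for schema in task_schemas.schemas:  # attribute name assumed
    print(schema.name)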
@validate_arguments
- def get_prediction_task_schemas_by_dataset_id(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], prediction_uuid_timestamp : Annotated[Optional[conint(strict=True, ge=0)], Field(description="The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. ")] = None, **kwargs) -> PredictionTaskSchemas: # noqa: E501
- """get_prediction_task_schemas_by_dataset_id # noqa: E501
+ def get_predictions_by_dataset_id(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], prediction_uuid_timestamp : Annotated[Optional[conint(strict=True, ge=0)], Field(description="Deprecated, currently ignored. The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. ")] = None, task_name : Annotated[Optional[constr(strict=True, min_length=1)], Field(description="If provided, only gets all prediction singletons of all samples of a dataset that were yielded by a specific prediction task name")] = None, **kwargs) -> List[List]: # noqa: E501
+ """get_predictions_by_dataset_id # noqa: E501
- Get list of all the prediction task schemas for a datasetId at a specific predictionUUIDTimestamp. If no predictionUUIDTimestamp is set, it defaults to the newest # noqa: E501
+ Get all prediction singletons of all samples of a dataset ordered by the sample mapping # noqa: E501
This method makes a synchronous HTTP request by default. To make an
asynchronous HTTP request, please pass async_req=True
- >>> thread = api.get_prediction_task_schemas_by_dataset_id(dataset_id, prediction_uuid_timestamp, async_req=True)
+ >>> thread = api.get_predictions_by_dataset_id(dataset_id, prediction_uuid_timestamp, task_name, async_req=True)
>>> result = thread.get()
:param dataset_id: ObjectId of the dataset (required)
:type dataset_id: str
- :param prediction_uuid_timestamp: The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset.
+ :param prediction_uuid_timestamp: Deprecated, currently ignored. The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset.
:type prediction_uuid_timestamp: int
+ :param task_name: If provided, only gets all prediction singletons of all samples of a dataset that were yielded by a specific prediction task name
+ :type task_name: str
:param async_req: Whether to execute the request asynchronously.
:type async_req: bool, optional
:param _request_timeout: timeout setting for this request. If one
@@ -742,28 +736,30 @@ def get_prediction_task_schemas_by_dataset_id(self, dataset_id : Annotated[const
:return: Returns the result object.
If the method is called asynchronously,
returns the request thread.
- :rtype: PredictionTaskSchemas
+ :rtype: List[List]
"""
kwargs['_return_http_data_only'] = True
if '_preload_content' in kwargs:
- raise ValueError("Error! Please call the get_prediction_task_schemas_by_dataset_id_with_http_info method with `_preload_content` instead and obtain raw data from ApiResponse.raw_data")
- return self.get_prediction_task_schemas_by_dataset_id_with_http_info(dataset_id, prediction_uuid_timestamp, **kwargs) # noqa: E501
+ raise ValueError("Error! Please call the get_predictions_by_dataset_id_with_http_info method with `_preload_content` instead and obtain raw data from ApiResponse.raw_data")
+ return self.get_predictions_by_dataset_id_with_http_info(dataset_id, prediction_uuid_timestamp, task_name, **kwargs) # noqa: E501
@validate_arguments
- def get_prediction_task_schemas_by_dataset_id_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], prediction_uuid_timestamp : Annotated[Optional[conint(strict=True, ge=0)], Field(description="The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. ")] = None, **kwargs) -> ApiResponse: # noqa: E501
- """get_prediction_task_schemas_by_dataset_id # noqa: E501
+ def get_predictions_by_dataset_id_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], prediction_uuid_timestamp : Annotated[Optional[conint(strict=True, ge=0)], Field(description="Deprecated, currently ignored. The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. ")] = None, task_name : Annotated[Optional[constr(strict=True, min_length=1)], Field(description="If provided, only gets all prediction singletons of all samples of a dataset that were yielded by a specific prediction task name")] = None, **kwargs) -> ApiResponse: # noqa: E501
+ """get_predictions_by_dataset_id # noqa: E501
- Get list of all the prediction task schemas for a datasetId at a specific predictionUUIDTimestamp. If no predictionUUIDTimestamp is set, it defaults to the newest # noqa: E501
+ Get all prediction singletons of all samples of a dataset ordered by the sample mapping # noqa: E501
This method makes a synchronous HTTP request by default. To make an
asynchronous HTTP request, please pass async_req=True
- >>> thread = api.get_prediction_task_schemas_by_dataset_id_with_http_info(dataset_id, prediction_uuid_timestamp, async_req=True)
+ >>> thread = api.get_predictions_by_dataset_id_with_http_info(dataset_id, prediction_uuid_timestamp, task_name, async_req=True)
>>> result = thread.get()
:param dataset_id: ObjectId of the dataset (required)
:type dataset_id: str
- :param prediction_uuid_timestamp: The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset.
+ :param prediction_uuid_timestamp: Deprecated, currently ignored. The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset.
:type prediction_uuid_timestamp: int
+ :param task_name: If provided, only gets all prediction singletons of all samples of a dataset that were yielded by a specific prediction task name
+ :type task_name: str
:param async_req: Whether to execute the request asynchronously.
:type async_req: bool, optional
:param _preload_content: if False, the ApiResponse.data will
@@ -786,14 +782,15 @@ def get_prediction_task_schemas_by_dataset_id_with_http_info(self, dataset_id :
:return: Returns the result object.
If the method is called asynchronously,
returns the request thread.
- :rtype: tuple(PredictionTaskSchemas, status_code(int), headers(HTTPHeaderDict))
+ :rtype: tuple(List[List], status_code(int), headers(HTTPHeaderDict))
"""
_params = locals()
_all_params = [
'dataset_id',
- 'prediction_uuid_timestamp'
+ 'prediction_uuid_timestamp',
+ 'task_name'
]
_all_params.extend(
[
@@ -812,7 +809,7 @@ def get_prediction_task_schemas_by_dataset_id_with_http_info(self, dataset_id :
if _key not in _all_params:
raise ApiTypeError(
"Got an unexpected keyword argument '%s'"
- " to method get_prediction_task_schemas_by_dataset_id" % _key
+ " to method get_predictions_by_dataset_id" % _key
)
_params[_key] = _val
del _params['kwargs']
@@ -833,6 +830,12 @@ def get_prediction_task_schemas_by_dataset_id_with_http_info(self, dataset_id :
_params['prediction_uuid_timestamp'].value if hasattr(_params['prediction_uuid_timestamp'], 'value') else _params['prediction_uuid_timestamp']
))
+ if _params.get('task_name') is not None: # noqa: E501
+ _query_params.append((
+ 'taskName',
+ _params['task_name'].value if hasattr(_params['task_name'], 'value') else _params['task_name']
+ ))
+
# process the header parameters
_header_params = dict(_params.get('_headers', {}))
# process the form parameters
@@ -848,7 +851,7 @@ def get_prediction_task_schemas_by_dataset_id_with_http_info(self, dataset_id :
_auth_settings = ['auth0Bearer', 'ApiKeyAuth'] # noqa: E501
_response_types_map = {
- '200': "PredictionTaskSchemas",
+ '200': "List[List]",
'400': "ApiErrorResponse",
'401': "ApiErrorResponse",
'403': "ApiErrorResponse",
@@ -856,7 +859,7 @@ def get_prediction_task_schemas_by_dataset_id_with_http_info(self, dataset_id :
}
return self.api_client.call_api(
- '/v1/datasets/{datasetId}/predictions/tasks', 'GET',
+ '/v1/datasets/{datasetId}/predictions/samples', 'GET',
_path_params,
_query_params,
_header_params,
@@ -873,22 +876,22 @@ def get_prediction_task_schemas_by_dataset_id_with_http_info(self, dataset_id :
_request_auth=_params.get('_request_auth'))
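Likewise, a hedged sketch of the raw-predictions endpoint that now occupies this spot in the file: task_name moved from a required path parameter to an optional query filter. It reuses the client from the sketch above; the task name is a placeholder.

# Returns the prediction singletons of the whole dataset, ordered by the
# sample mapping, as nested lists (one inner list per sample).
predictions = predictions_api.get_predictions_by_dataset_id(
    dataset_id="0123456789abcdef01234567",  # placeholder ObjectId
    task_name="my-classification-task",  # optional; omit to get all tasks
)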
@validate_arguments
- def get_predictions_by_dataset_id(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], prediction_uuid_timestamp : Annotated[conint(strict=True, ge=0), Field(..., description="The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. ")], task_name : Annotated[Optional[constr(strict=True, min_length=1)], Field(description="If provided, only gets all prediction singletons of all samples of a dataset that were yielded by a specific prediction task name")] = None, **kwargs) -> List[List]: # noqa: E501
- """get_predictions_by_dataset_id # noqa: E501
+ def get_predictions_by_sample_id(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], sample_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the sample")], prediction_uuid_timestamp : Annotated[Optional[conint(strict=True, ge=0)], Field(description="Deprecated, currently ignored. The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. ")] = None, **kwargs) -> List[PredictionSingleton]: # noqa: E501
+ """get_predictions_by_sample_id # noqa: E501
- Get all prediction singletons of all samples of a dataset ordered by the sample mapping # noqa: E501
+ Get all prediction singletons of all tasks for a specific sample of a dataset # noqa: E501
This method makes a synchronous HTTP request by default. To make an
asynchronous HTTP request, please pass async_req=True
- >>> thread = api.get_predictions_by_dataset_id(dataset_id, prediction_uuid_timestamp, task_name, async_req=True)
+ >>> thread = api.get_predictions_by_sample_id(dataset_id, sample_id, prediction_uuid_timestamp, async_req=True)
>>> result = thread.get()
:param dataset_id: ObjectId of the dataset (required)
:type dataset_id: str
- :param prediction_uuid_timestamp: The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. (required)
+ :param sample_id: ObjectId of the sample (required)
+ :type sample_id: str
+ :param prediction_uuid_timestamp: Deprecated, currently ignored. The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset.
:type prediction_uuid_timestamp: int
- :param task_name: If provided, only gets all prediction singletons of all samples of a dataset that were yielded by a specific prediction task name
- :type task_name: str
:param async_req: Whether to execute the request asynchronously.
:type async_req: bool, optional
:param _request_timeout: timeout setting for this request. If one
@@ -898,30 +901,30 @@ def get_predictions_by_dataset_id(self, dataset_id : Annotated[constr(strict=Tru
:return: Returns the result object.
If the method is called asynchronously,
returns the request thread.
- :rtype: List[List]
+ :rtype: List[PredictionSingleton]
"""
kwargs['_return_http_data_only'] = True
if '_preload_content' in kwargs:
- raise ValueError("Error! Please call the get_predictions_by_dataset_id_with_http_info method with `_preload_content` instead and obtain raw data from ApiResponse.raw_data")
- return self.get_predictions_by_dataset_id_with_http_info(dataset_id, prediction_uuid_timestamp, task_name, **kwargs) # noqa: E501
+ raise ValueError("Error! Please call the get_predictions_by_sample_id_with_http_info method with `_preload_content` instead and obtain raw data from ApiResponse.raw_data")
+ return self.get_predictions_by_sample_id_with_http_info(dataset_id, sample_id, prediction_uuid_timestamp, **kwargs) # noqa: E501
@validate_arguments
- def get_predictions_by_dataset_id_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], prediction_uuid_timestamp : Annotated[conint(strict=True, ge=0), Field(..., description="The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. ")], task_name : Annotated[Optional[constr(strict=True, min_length=1)], Field(description="If provided, only gets all prediction singletons of all samples of a dataset that were yielded by a specific prediction task name")] = None, **kwargs) -> ApiResponse: # noqa: E501
- """get_predictions_by_dataset_id # noqa: E501
+ def get_predictions_by_sample_id_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], sample_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the sample")], prediction_uuid_timestamp : Annotated[Optional[conint(strict=True, ge=0)], Field(description="Deprecated, currently ignored. The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. ")] = None, **kwargs) -> ApiResponse: # noqa: E501
+ """get_predictions_by_sample_id # noqa: E501
- Get all prediction singletons of all samples of a dataset ordered by the sample mapping # noqa: E501
+ Get all prediction singletons of all tasks for a specific sample of a dataset # noqa: E501
This method makes a synchronous HTTP request by default. To make an
asynchronous HTTP request, please pass async_req=True
- >>> thread = api.get_predictions_by_dataset_id_with_http_info(dataset_id, prediction_uuid_timestamp, task_name, async_req=True)
+ >>> thread = api.get_predictions_by_sample_id_with_http_info(dataset_id, sample_id, prediction_uuid_timestamp, async_req=True)
>>> result = thread.get()
:param dataset_id: ObjectId of the dataset (required)
:type dataset_id: str
- :param prediction_uuid_timestamp: The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. (required)
+ :param sample_id: ObjectId of the sample (required)
+ :type sample_id: str
+ :param prediction_uuid_timestamp: Deprecated, currently ignored. The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset.
:type prediction_uuid_timestamp: int
- :param task_name: If provided, only gets all prediction singletons of all samples of a dataset that were yielded by a specific prediction task name
- :type task_name: str
:param async_req: Whether to execute the request asynchronously.
:type async_req: bool, optional
:param _preload_content: if False, the ApiResponse.data will
@@ -944,15 +947,15 @@ def get_predictions_by_dataset_id_with_http_info(self, dataset_id : Annotated[co
:return: Returns the result object.
If the method is called asynchronously,
returns the request thread.
- :rtype: tuple(List[List], status_code(int), headers(HTTPHeaderDict))
+ :rtype: tuple(List[PredictionSingleton], status_code(int), headers(HTTPHeaderDict))
"""
_params = locals()
_all_params = [
'dataset_id',
- 'prediction_uuid_timestamp',
- 'task_name'
+ 'sample_id',
+ 'prediction_uuid_timestamp'
]
_all_params.extend(
[
@@ -971,7 +974,7 @@ def get_predictions_by_dataset_id_with_http_info(self, dataset_id : Annotated[co
if _key not in _all_params:
raise ApiTypeError(
"Got an unexpected keyword argument '%s'"
- " to method get_predictions_by_dataset_id" % _key
+ " to method get_predictions_by_sample_id" % _key
)
_params[_key] = _val
del _params['kwargs']
@@ -983,6 +986,9 @@ def get_predictions_by_dataset_id_with_http_info(self, dataset_id : Annotated[co
if _params['dataset_id']:
_path_params['datasetId'] = _params['dataset_id']
+ if _params['sample_id']:
+ _path_params['sampleId'] = _params['sample_id']
+
# process the query parameters
_query_params = []
@@ -992,12 +998,6 @@ def get_predictions_by_dataset_id_with_http_info(self, dataset_id : Annotated[co
_params['prediction_uuid_timestamp'].value if hasattr(_params['prediction_uuid_timestamp'], 'value') else _params['prediction_uuid_timestamp']
))
- if _params.get('task_name') is not None: # noqa: E501
- _query_params.append((
- 'taskName',
- _params['task_name'].value if hasattr(_params['task_name'], 'value') else _params['task_name']
- ))
-
# process the header parameters
_header_params = dict(_params.get('_headers', {}))
# process the form parameters
@@ -1013,7 +1013,7 @@ def get_predictions_by_dataset_id_with_http_info(self, dataset_id : Annotated[co
_auth_settings = ['auth0Bearer', 'ApiKeyAuth'] # noqa: E501
_response_types_map = {
- '200': "List[List]",
+ '200': "List[PredictionSingleton]",
'400': "ApiErrorResponse",
'401': "ApiErrorResponse",
'403': "ApiErrorResponse",
@@ -1021,7 +1021,7 @@ def get_predictions_by_dataset_id_with_http_info(self, dataset_id : Annotated[co
}
return self.api_client.call_api(
- '/v1/datasets/{datasetId}/predictions/samples', 'GET',
+ '/v1/datasets/{datasetId}/predictions/samples/{sampleId}', 'GET',
_path_params,
_query_params,
_header_params,
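And the sample-scoped variant added above, which returns the singletons of every task for one sample. A sketch under the same assumptions (client from the first sketch; ids are placeholders; singleton attribute names are assumed from the PredictionSingleton model):

singletons = predictions_api.get_predictions_by_sample_id(
    dataset_id="0123456789abcdef01234567",  # placeholder ObjectId
    sample_id="fedcba98765432100fedcba9",  # placeholder ObjectId
)
for singleton in singletons:
    print(singleton.task_name, singleton.score)  # attribute names assumed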
diff --git a/lightly/openapi_generated/swagger_client/api/samples_api.py b/lightly/openapi_generated/swagger_client/api/samples_api.py
index 53f3ab74e..85be18c03 100644
--- a/lightly/openapi_generated/swagger_client/api/samples_api.py
+++ b/lightly/openapi_generated/swagger_client/api/samples_api.py
@@ -1153,7 +1153,7 @@ def get_sample_image_write_urls_by_id_with_http_info(self, dataset_id : Annotate
_request_auth=_params.get('_request_auth'))
@validate_arguments
- def get_samples_by_dataset_id(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], file_name : Annotated[Optional[StrictStr], Field(description="filter the samples by filename")] = None, sort_by : Annotated[Optional[SampleSortBy], Field(description="sort the samples")] = None, page_size : Annotated[Optional[conint(strict=True, ge=1)], Field(description="pagination size/limit of the number of samples to return")] = None, page_offset : Annotated[Optional[conint(strict=True, ge=0)], Field(description="pagination offset")] = None, **kwargs) -> List[SampleData]: # noqa: E501
+ def get_samples_by_dataset_id(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], file_name : Annotated[Optional[StrictStr], Field(description="DEPRECATED, use without and filter yourself - Filter the samples by filename")] = None, sort_by : Annotated[Optional[SampleSortBy], Field(description="sort the samples")] = None, page_size : Annotated[Optional[conint(strict=True, ge=1)], Field(description="pagination size/limit of the number of samples to return")] = None, page_offset : Annotated[Optional[conint(strict=True, ge=0)], Field(description="pagination offset")] = None, **kwargs) -> List[SampleData]: # noqa: E501
"""get_samples_by_dataset_id # noqa: E501
Get all samples of a dataset # noqa: E501
@@ -1165,7 +1165,7 @@ def get_samples_by_dataset_id(self, dataset_id : Annotated[constr(strict=True),
:param dataset_id: ObjectId of the dataset (required)
:type dataset_id: str
- :param file_name: filter the samples by filename
+ :param file_name: DEPRECATED, use without and filter yourself - Filter the samples by filename
:type file_name: str
:param sort_by: sort the samples
:type sort_by: SampleSortBy
@@ -1190,7 +1190,7 @@ def get_samples_by_dataset_id(self, dataset_id : Annotated[constr(strict=True),
return self.get_samples_by_dataset_id_with_http_info(dataset_id, file_name, sort_by, page_size, page_offset, **kwargs) # noqa: E501
@validate_arguments
- def get_samples_by_dataset_id_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], file_name : Annotated[Optional[StrictStr], Field(description="filter the samples by filename")] = None, sort_by : Annotated[Optional[SampleSortBy], Field(description="sort the samples")] = None, page_size : Annotated[Optional[conint(strict=True, ge=1)], Field(description="pagination size/limit of the number of samples to return")] = None, page_offset : Annotated[Optional[conint(strict=True, ge=0)], Field(description="pagination offset")] = None, **kwargs) -> ApiResponse: # noqa: E501
+ def get_samples_by_dataset_id_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], file_name : Annotated[Optional[StrictStr], Field(description="DEPRECATED, use without and filter yourself - Filter the samples by filename")] = None, sort_by : Annotated[Optional[SampleSortBy], Field(description="sort the samples")] = None, page_size : Annotated[Optional[conint(strict=True, ge=1)], Field(description="pagination size/limit of the number of samples to return")] = None, page_offset : Annotated[Optional[conint(strict=True, ge=0)], Field(description="pagination offset")] = None, **kwargs) -> ApiResponse: # noqa: E501
"""get_samples_by_dataset_id # noqa: E501
Get all samples of a dataset # noqa: E501
@@ -1202,7 +1202,7 @@ def get_samples_by_dataset_id_with_http_info(self, dataset_id : Annotated[constr
:param dataset_id: ObjectId of the dataset (required)
:type dataset_id: str
- :param file_name: filter the samples by filename
+ :param file_name: DEPRECATED, use without and filter yourself - Filter the samples by filename
:type file_name: str
:param sort_by: sort the samples
:type sort_by: SampleSortBy
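Because file_name is now deprecated, callers are expected to page through the samples and filter locally. A minimal sketch, assuming a SamplesApi wrapper and reusing the ApiClient/Configuration objects from the first sketch; the target filename is a placeholder, and page_offset is assumed to be an entry offset:

from lightly.openapi_generated.swagger_client.api.samples_api import SamplesApi

samples_api = SamplesApi(ApiClient(configuration))
matches, offset, page_size = [], 0, 1000
while True:
    batch = samples_api.get_samples_by_dataset_id(
        dataset_id="0123456789abcdef01234567",  # placeholder ObjectId
        page_size=page_size,
        page_offset=offset,  # offset semantics assumed (entry offset)
    )
    # SampleData exposes the filename as file_name.
    matches.extend(s for s in batch if s.file_name == "image_0042.png")
    if len(batch) < page_size:
        break
    offset += page_size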
diff --git a/lightly/openapi_generated/swagger_client/api/scores_api.py b/lightly/openapi_generated/swagger_client/api/scores_api.py
index 7c43e49cf..fadd3efcc 100644
--- a/lightly/openapi_generated/swagger_client/api/scores_api.py
+++ b/lightly/openapi_generated/swagger_client/api/scores_api.py
@@ -20,12 +20,14 @@
from pydantic import validate_arguments, ValidationError
from typing_extensions import Annotated
-from pydantic import Field, constr, validator
+from pydantic import Field, conint, constr, validator
-from typing import List
+from typing import List, Optional
from lightly.openapi_generated.swagger_client.models.active_learning_score_create_request import ActiveLearningScoreCreateRequest
from lightly.openapi_generated.swagger_client.models.active_learning_score_data import ActiveLearningScoreData
+from lightly.openapi_generated.swagger_client.models.active_learning_score_types_v2_data import ActiveLearningScoreTypesV2Data
+from lightly.openapi_generated.swagger_client.models.active_learning_score_v2_data import ActiveLearningScoreV2Data
from lightly.openapi_generated.swagger_client.models.create_entity_response import CreateEntityResponse
from lightly.openapi_generated.swagger_client.models.tag_active_learning_scores_data import TagActiveLearningScoresData
@@ -51,7 +53,7 @@ def __init__(self, api_client=None):
@validate_arguments
def create_or_update_active_learning_score_by_tag_id(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], tag_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the tag")], active_learning_score_create_request : ActiveLearningScoreCreateRequest, **kwargs) -> CreateEntityResponse: # noqa: E501
- """create_or_update_active_learning_score_by_tag_id # noqa: E501
+ """(Deprecated) create_or_update_active_learning_score_by_tag_id # noqa: E501
Create or update active learning score object by tag id # noqa: E501
This method makes a synchronous HTTP request by default. To make an
@@ -84,7 +86,7 @@ def create_or_update_active_learning_score_by_tag_id(self, dataset_id : Annotate
@validate_arguments
def create_or_update_active_learning_score_by_tag_id_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], tag_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the tag")], active_learning_score_create_request : ActiveLearningScoreCreateRequest, **kwargs) -> ApiResponse: # noqa: E501
- """create_or_update_active_learning_score_by_tag_id # noqa: E501
+ """(Deprecated) create_or_update_active_learning_score_by_tag_id # noqa: E501
Create or update active learning score object by tag id # noqa: E501
This method makes a synchronous HTTP request by default. To make an
@@ -124,6 +126,8 @@ def create_or_update_active_learning_score_by_tag_id_with_http_info(self, datase
:rtype: tuple(CreateEntityResponse, status_code(int), headers(HTTPHeaderDict))
"""
+ warnings.warn("POST /v1/datasets/{datasetId}/tags/{tagId}/scores is deprecated.", DeprecationWarning)
+
_params = locals()
_all_params = [
@@ -215,9 +219,189 @@ def create_or_update_active_learning_score_by_tag_id_with_http_info(self, datase
collection_formats=_collection_formats,
_request_auth=_params.get('_request_auth'))
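A note on the deprecation mechanism used throughout this file: the deprecated methods now emit a Python DeprecationWarning before dispatching the request. Since CPython hides DeprecationWarning outside __main__ by default, callers who want to see these messages have to opt in:

import warnings

# Surface the "POST /v1/datasets/{datasetId}/tags/{tagId}/scores is
# deprecated." warnings emitted by the deprecated score endpoints.
warnings.simplefilter("always", DeprecationWarning)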
+ @validate_arguments
+ def create_or_update_active_learning_v2_score_by_dataset_id(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], task_name : Annotated[constr(strict=True, min_length=1), Field(..., description="The prediction task name for which one wants to list the predictions")], active_learning_score_create_request : ActiveLearningScoreCreateRequest, prediction_uuid_timestamp : Annotated[Optional[conint(strict=True, ge=0)], Field(description="Deprecated, currently ignored. The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. ")] = None, **kwargs) -> CreateEntityResponse: # noqa: E501
+ """create_or_update_active_learning_v2_score_by_dataset_id # noqa: E501
+
+ Create or update active learning score object for a dataset, taskName, predictionUUIDTimestamp # noqa: E501
+ This method makes a synchronous HTTP request by default. To make an
+ asynchronous HTTP request, please pass async_req=True
+
+ >>> thread = api.create_or_update_active_learning_v2_score_by_dataset_id(dataset_id, task_name, active_learning_score_create_request, prediction_uuid_timestamp, async_req=True)
+ >>> result = thread.get()
+
+ :param dataset_id: ObjectId of the dataset (required)
+ :type dataset_id: str
+ :param task_name: The prediction task name for which one wants to list the predictions (required)
+ :type task_name: str
+ :param active_learning_score_create_request: (required)
+ :type active_learning_score_create_request: ActiveLearningScoreCreateRequest
+ :param prediction_uuid_timestamp: Deprecated, currently ignored. The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset.
+ :type prediction_uuid_timestamp: int
+ :param async_req: Whether to execute the request asynchronously.
+ :type async_req: bool, optional
+ :param _request_timeout: timeout setting for this request. If one
+ number provided, it will be total request
+ timeout. It can also be a pair (tuple) of
+ (connection, read) timeouts.
+ :return: Returns the result object.
+ If the method is called asynchronously,
+ returns the request thread.
+ :rtype: CreateEntityResponse
+ """
+ kwargs['_return_http_data_only'] = True
+ if '_preload_content' in kwargs:
+ raise ValueError("Error! Please call the create_or_update_active_learning_v2_score_by_dataset_id_with_http_info method with `_preload_content` instead and obtain raw data from ApiResponse.raw_data")
+ return self.create_or_update_active_learning_v2_score_by_dataset_id_with_http_info(dataset_id, task_name, active_learning_score_create_request, prediction_uuid_timestamp, **kwargs) # noqa: E501
+
+ @validate_arguments
+ def create_or_update_active_learning_v2_score_by_dataset_id_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], task_name : Annotated[constr(strict=True, min_length=1), Field(..., description="The prediction task name for which one wants to list the predictions")], active_learning_score_create_request : ActiveLearningScoreCreateRequest, prediction_uuid_timestamp : Annotated[Optional[conint(strict=True, ge=0)], Field(description="Deprecated, currently ignored. The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. ")] = None, **kwargs) -> ApiResponse: # noqa: E501
+ """create_or_update_active_learning_v2_score_by_dataset_id # noqa: E501
+
+ Create or update active learning score object for a dataset, taskName, predictionUUIDTimestamp # noqa: E501
+ This method makes a synchronous HTTP request by default. To make an
+ asynchronous HTTP request, please pass async_req=True
+
+ >>> thread = api.create_or_update_active_learning_v2_score_by_dataset_id_with_http_info(dataset_id, task_name, active_learning_score_create_request, prediction_uuid_timestamp, async_req=True)
+ >>> result = thread.get()
+
+ :param dataset_id: ObjectId of the dataset (required)
+ :type dataset_id: str
+ :param task_name: The prediction task name for which one wants to list the predictions (required)
+ :type task_name: str
+ :param active_learning_score_create_request: (required)
+ :type active_learning_score_create_request: ActiveLearningScoreCreateRequest
+ :param prediction_uuid_timestamp: Deprecated, currently ignored. The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset.
+ :type prediction_uuid_timestamp: int
+ :param async_req: Whether to execute the request asynchronously.
+ :type async_req: bool, optional
+ :param _preload_content: if False, the ApiResponse.data will
+ be set to none and raw_data will store the
+ HTTP response body without reading/decoding.
+ Default is True.
+ :type _preload_content: bool, optional
+ :param _return_http_data_only: response data instead of ApiResponse
+ object with status code, headers, etc
+ :type _return_http_data_only: bool, optional
+ :param _request_timeout: timeout setting for this request. If one
+ number provided, it will be total request
+ timeout. It can also be a pair (tuple) of
+ (connection, read) timeouts.
+ :param _request_auth: set to override the auth_settings for a single
+ request; this effectively ignores the authentication
+ in the spec for a single request.
+ :type _request_auth: dict, optional
+ :type _content_type: string, optional: force content-type for the request
+ :return: Returns the result object.
+ If the method is called asynchronously,
+ returns the request thread.
+ :rtype: tuple(CreateEntityResponse, status_code(int), headers(HTTPHeaderDict))
+ """
+
+ _params = locals()
+
+ _all_params = [
+ 'dataset_id',
+ 'task_name',
+ 'active_learning_score_create_request',
+ 'prediction_uuid_timestamp'
+ ]
+ _all_params.extend(
+ [
+ 'async_req',
+ '_return_http_data_only',
+ '_preload_content',
+ '_request_timeout',
+ '_request_auth',
+ '_content_type',
+ '_headers'
+ ]
+ )
+
+ # validate the arguments
+ for _key, _val in _params['kwargs'].items():
+ if _key not in _all_params:
+ raise ApiTypeError(
+ "Got an unexpected keyword argument '%s'"
+ " to method create_or_update_active_learning_v2_score_by_dataset_id" % _key
+ )
+ _params[_key] = _val
+ del _params['kwargs']
+
+ _collection_formats = {}
+
+ # process the path parameters
+ _path_params = {}
+ if _params['dataset_id']:
+ _path_params['datasetId'] = _params['dataset_id']
+
+
+ # process the query parameters
+ _query_params = []
+ if _params.get('task_name') is not None: # noqa: E501
+ _query_params.append((
+ 'taskName',
+ _params['task_name'].value if hasattr(_params['task_name'], 'value') else _params['task_name']
+ ))
+
+ if _params.get('prediction_uuid_timestamp') is not None: # noqa: E501
+ _query_params.append((
+ 'predictionUUIDTimestamp',
+ _params['prediction_uuid_timestamp'].value if hasattr(_params['prediction_uuid_timestamp'], 'value') else _params['prediction_uuid_timestamp']
+ ))
+
+ # process the header parameters
+ _header_params = dict(_params.get('_headers', {}))
+ # process the form parameters
+ _form_params = []
+ _files = {}
+ # process the body parameter
+ _body_params = None
+ if _params['active_learning_score_create_request'] is not None:
+ _body_params = _params['active_learning_score_create_request']
+
+ # set the HTTP header `Accept`
+ _header_params['Accept'] = self.api_client.select_header_accept(
+ ['application/json']) # noqa: E501
+
+ # set the HTTP header `Content-Type`
+ _content_types_list = _params.get('_content_type',
+ self.api_client.select_header_content_type(
+ ['application/json']))
+ if _content_types_list:
+ _header_params['Content-Type'] = _content_types_list
+
+ # authentication setting
+ _auth_settings = ['auth0Bearer', 'ApiKeyAuth'] # noqa: E501
+
+ _response_types_map = {
+ '201': "CreateEntityResponse",
+ '400': "ApiErrorResponse",
+ '401': "ApiErrorResponse",
+ '403': "ApiErrorResponse",
+ '404': "ApiErrorResponse",
+ }
+
+ return self.api_client.call_api(
+ '/v1/datasets/{datasetId}/predictions/scores', 'POST',
+ _path_params,
+ _query_params,
+ _header_params,
+ body=_body_params,
+ post_params=_form_params,
+ files=_files,
+ response_types_map=_response_types_map,
+ auth_settings=_auth_settings,
+ async_req=_params.get('async_req'),
+ _return_http_data_only=_params.get('_return_http_data_only'), # noqa: E501
+ _preload_content=_params.get('_preload_content', True),
+ _request_timeout=_params.get('_request_timeout'),
+ collection_formats=_collection_formats,
+ _request_auth=_params.get('_request_auth'))
+
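A hedged upload sketch for the new v2 score endpoint. It assumes a ScoresApi wrapper plus the ApiClient/Configuration pair from the earlier sketches, and that ActiveLearningScoreCreateRequest (the request model reused from the v1 endpoint) carries score_type and scores fields; names and ids are placeholders.

from lightly.openapi_generated.swagger_client.api.scores_api import ScoresApi
from lightly.openapi_generated.swagger_client.models.active_learning_score_create_request import ActiveLearningScoreCreateRequest

scores_api = ScoresApi(ApiClient(configuration))
request = ActiveLearningScoreCreateRequest(
    score_type="uncertainty_margin",  # field names assumed
    scores=[0.92, 0.21, 0.56],        # one score per sample
)
response = scores_api.create_or_update_active_learning_v2_score_by_dataset_id(
    dataset_id="0123456789abcdef01234567",  # placeholder ObjectId
    task_name="my-classification-task",     # placeholder task name
    active_learning_score_create_request=request,
)
print(response.id)  # CreateEntityResponse carries the id of the score entity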
@validate_arguments
def get_active_learning_score_by_score_id(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], tag_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the tag")], score_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the scores")], **kwargs) -> ActiveLearningScoreData: # noqa: E501
- """get_active_learning_score_by_score_id # noqa: E501
+ """(Deprecated) get_active_learning_score_by_score_id # noqa: E501
Get active learning score object by id # noqa: E501
This method makes a synchronous HTTP request by default. To make an
@@ -250,7 +434,7 @@ def get_active_learning_score_by_score_id(self, dataset_id : Annotated[constr(st
@validate_arguments
def get_active_learning_score_by_score_id_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], tag_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the tag")], score_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the scores")], **kwargs) -> ApiResponse: # noqa: E501
- """get_active_learning_score_by_score_id # noqa: E501
+ """(Deprecated) get_active_learning_score_by_score_id # noqa: E501
Get active learning score object by id # noqa: E501
This method makes a synchronous HTTP request by default. To make an
@@ -290,6 +474,8 @@ def get_active_learning_score_by_score_id_with_http_info(self, dataset_id : Anno
:rtype: tuple(ActiveLearningScoreData, status_code(int), headers(HTTPHeaderDict))
"""
+ warnings.warn("GET /v1/datasets/{datasetId}/tags/{tagId}/scores/{scoreId} is deprecated.", DeprecationWarning)
+
_params = locals()
_all_params = [
@@ -376,7 +562,7 @@ def get_active_learning_score_by_score_id_with_http_info(self, dataset_id : Anno
@validate_arguments
def get_active_learning_scores_by_tag_id(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], tag_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the tag")], **kwargs) -> List[TagActiveLearningScoresData]: # noqa: E501
- """get_active_learning_scores_by_tag_id # noqa: E501
+ """(Deprecated) get_active_learning_scores_by_tag_id # noqa: E501
Get all scoreIds for the given tag # noqa: E501
This method makes a synchronous HTTP request by default. To make an
@@ -407,7 +593,7 @@ def get_active_learning_scores_by_tag_id(self, dataset_id : Annotated[constr(str
@validate_arguments
def get_active_learning_scores_by_tag_id_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], tag_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the tag")], **kwargs) -> ApiResponse: # noqa: E501
- """get_active_learning_scores_by_tag_id # noqa: E501
+ """(Deprecated) get_active_learning_scores_by_tag_id # noqa: E501
Get all scoreIds for the given tag # noqa: E501
This method makes a synchronous HTTP request by default. To make an
@@ -445,6 +631,8 @@ def get_active_learning_scores_by_tag_id_with_http_info(self, dataset_id : Annot
:rtype: tuple(List[TagActiveLearningScoresData], status_code(int), headers(HTTPHeaderDict))
"""
+ warnings.warn("GET /v1/datasets/{datasetId}/tags/{tagId}/scores is deprecated.", DeprecationWarning)
+
_params = locals()
_all_params = [
@@ -524,3 +712,308 @@ def get_active_learning_scores_by_tag_id_with_http_info(self, dataset_id : Annot
_request_timeout=_params.get('_request_timeout'),
collection_formats=_collection_formats,
_request_auth=_params.get('_request_auth'))
+
+ @validate_arguments
+ def get_active_learning_v2_score_by_dataset_and_score_id(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], score_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the scores")], **kwargs) -> ActiveLearningScoreV2Data: # noqa: E501
+ """get_active_learning_v2_score_by_dataset_and_score_id # noqa: E501
+
+ Get active learning scores by scoreId # noqa: E501
+ This method makes a synchronous HTTP request by default. To make an
+ asynchronous HTTP request, please pass async_req=True
+
+ >>> thread = api.get_active_learning_v2_score_by_dataset_and_score_id(dataset_id, score_id, async_req=True)
+ >>> result = thread.get()
+
+ :param dataset_id: ObjectId of the dataset (required)
+ :type dataset_id: str
+ :param score_id: ObjectId of the scores (required)
+ :type score_id: str
+ :param async_req: Whether to execute the request asynchronously.
+ :type async_req: bool, optional
+ :param _request_timeout: timeout setting for this request. If one
+ number provided, it will be total request
+ timeout. It can also be a pair (tuple) of
+ (connection, read) timeouts.
+ :return: Returns the result object.
+ If the method is called asynchronously,
+ returns the request thread.
+ :rtype: ActiveLearningScoreV2Data
+ """
+ kwargs['_return_http_data_only'] = True
+ if '_preload_content' in kwargs:
+ raise ValueError("Error! Please call the get_active_learning_v2_score_by_dataset_and_score_id_with_http_info method with `_preload_content` instead and obtain raw data from ApiResponse.raw_data")
+ return self.get_active_learning_v2_score_by_dataset_and_score_id_with_http_info(dataset_id, score_id, **kwargs) # noqa: E501
+
+ @validate_arguments
+ def get_active_learning_v2_score_by_dataset_and_score_id_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], score_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the scores")], **kwargs) -> ApiResponse: # noqa: E501
+ """get_active_learning_v2_score_by_dataset_and_score_id # noqa: E501
+
+ Get active learning scores by scoreId # noqa: E501
+ This method makes a synchronous HTTP request by default. To make an
+ asynchronous HTTP request, please pass async_req=True
+
+ >>> thread = api.get_active_learning_v2_score_by_dataset_and_score_id_with_http_info(dataset_id, score_id, async_req=True)
+ >>> result = thread.get()
+
+ :param dataset_id: ObjectId of the dataset (required)
+ :type dataset_id: str
+ :param score_id: ObjectId of the scores (required)
+ :type score_id: str
+ :param async_req: Whether to execute the request asynchronously.
+ :type async_req: bool, optional
+ :param _preload_content: if False, the ApiResponse.data will
+ be set to none and raw_data will store the
+ HTTP response body without reading/decoding.
+ Default is True.
+ :type _preload_content: bool, optional
+ :param _return_http_data_only: response data instead of ApiResponse
+ object with status code, headers, etc
+ :type _return_http_data_only: bool, optional
+ :param _request_timeout: timeout setting for this request. If one
+ number provided, it will be total request
+ timeout. It can also be a pair (tuple) of
+ (connection, read) timeouts.
+ :param _request_auth: set to override the auth_settings for a single
+ request; this effectively ignores the authentication
+ in the spec for a single request.
+ :type _request_auth: dict, optional
+ :type _content_type: string, optional: force content-type for the request
+ :return: Returns the result object.
+ If the method is called asynchronously,
+ returns the request thread.
+ :rtype: tuple(ActiveLearningScoreV2Data, status_code(int), headers(HTTPHeaderDict))
+ """
+
+ _params = locals()
+
+ _all_params = [
+ 'dataset_id',
+ 'score_id'
+ ]
+ _all_params.extend(
+ [
+ 'async_req',
+ '_return_http_data_only',
+ '_preload_content',
+ '_request_timeout',
+ '_request_auth',
+ '_content_type',
+ '_headers'
+ ]
+ )
+
+ # validate the arguments
+ for _key, _val in _params['kwargs'].items():
+ if _key not in _all_params:
+ raise ApiTypeError(
+ "Got an unexpected keyword argument '%s'"
+ " to method get_active_learning_v2_score_by_dataset_and_score_id" % _key
+ )
+ _params[_key] = _val
+ del _params['kwargs']
+
+ _collection_formats = {}
+
+ # process the path parameters
+ _path_params = {}
+ if _params['dataset_id']:
+ _path_params['datasetId'] = _params['dataset_id']
+
+ if _params['score_id']:
+ _path_params['scoreId'] = _params['score_id']
+
+
+ # process the query parameters
+ _query_params = []
+ # process the header parameters
+ _header_params = dict(_params.get('_headers', {}))
+ # process the form parameters
+ _form_params = []
+ _files = {}
+ # process the body parameter
+ _body_params = None
+ # set the HTTP header `Accept`
+ _header_params['Accept'] = self.api_client.select_header_accept(
+ ['application/json']) # noqa: E501
+
+ # authentication setting
+ _auth_settings = ['auth0Bearer', 'ApiKeyAuth'] # noqa: E501
+
+ _response_types_map = {
+ '200': "ActiveLearningScoreV2Data",
+ '400': "ApiErrorResponse",
+ '401': "ApiErrorResponse",
+ '403': "ApiErrorResponse",
+ '404': "ApiErrorResponse",
+ }
+
+ return self.api_client.call_api(
+ '/v1/datasets/{datasetId}/predictions/scores/{scoreId}', 'GET',
+ _path_params,
+ _query_params,
+ _header_params,
+ body=_body_params,
+ post_params=_form_params,
+ files=_files,
+ response_types_map=_response_types_map,
+ auth_settings=_auth_settings,
+ async_req=_params.get('async_req'),
+ _return_http_data_only=_params.get('_return_http_data_only'), # noqa: E501
+ _preload_content=_params.get('_preload_content', True),
+ _request_timeout=_params.get('_request_timeout'),
+ collection_formats=_collection_formats,
+ _request_auth=_params.get('_request_auth'))
+
+ @validate_arguments
+ def get_active_learning_v2_scores_by_dataset_id(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], prediction_uuid_timestamp : Annotated[Optional[conint(strict=True, ge=0)], Field(description="Deprecated, currently ignored. The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. ")] = None, **kwargs) -> List[ActiveLearningScoreTypesV2Data]: # noqa: E501
+ """get_active_learning_v2_scores_by_dataset_id # noqa: E501
+
+ Get all AL score types by datasetId and predictionUUIDTimestamp # noqa: E501
+ This method makes a synchronous HTTP request by default. To make an
+ asynchronous HTTP request, please pass async_req=True
+
+ >>> thread = api.get_active_learning_v2_scores_by_dataset_id(dataset_id, prediction_uuid_timestamp, async_req=True)
+ >>> result = thread.get()
+
+ :param dataset_id: ObjectId of the dataset (required)
+ :type dataset_id: str
+ :param prediction_uuid_timestamp: Deprecated, currently ignored. The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset.
+ :type prediction_uuid_timestamp: int
+ :param async_req: Whether to execute the request asynchronously.
+ :type async_req: bool, optional
+ :param _request_timeout: timeout setting for this request. If one
+ number provided, it will be total request
+ timeout. It can also be a pair (tuple) of
+ (connection, read) timeouts.
+ :return: Returns the result object.
+ If the method is called asynchronously,
+ returns the request thread.
+ :rtype: List[ActiveLearningScoreTypesV2Data]
+ """
+ kwargs['_return_http_data_only'] = True
+ if '_preload_content' in kwargs:
+ raise ValueError("Error! Please call the get_active_learning_v2_scores_by_dataset_id_with_http_info method with `_preload_content` instead and obtain raw data from ApiResponse.raw_data")
+ return self.get_active_learning_v2_scores_by_dataset_id_with_http_info(dataset_id, prediction_uuid_timestamp, **kwargs) # noqa: E501
+
+ @validate_arguments
+ def get_active_learning_v2_scores_by_dataset_id_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], prediction_uuid_timestamp : Annotated[Optional[conint(strict=True, ge=0)], Field(description="Deprecated, currently ignored. The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. ")] = None, **kwargs) -> ApiResponse: # noqa: E501
+ """get_active_learning_v2_scores_by_dataset_id # noqa: E501
+
+ Get all AL score types by datasetId and predictionUUIDTimestamp # noqa: E501
+ This method makes a synchronous HTTP request by default. To make an
+ asynchronous HTTP request, please pass async_req=True
+
+ >>> thread = api.get_active_learning_v2_scores_by_dataset_id_with_http_info(dataset_id, prediction_uuid_timestamp, async_req=True)
+ >>> result = thread.get()
+
+ :param dataset_id: ObjectId of the dataset (required)
+ :type dataset_id: str
+ :param prediction_uuid_timestamp: Deprecated, currently ignored. The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset.
+ :type prediction_uuid_timestamp: int
+ :param async_req: Whether to execute the request asynchronously.
+ :type async_req: bool, optional
+ :param _preload_content: if False, the ApiResponse.data will
+ be set to none and raw_data will store the
+ HTTP response body without reading/decoding.
+ Default is True.
+ :type _preload_content: bool, optional
+ :param _return_http_data_only: response data instead of ApiResponse
+ object with status code, headers, etc
+ :type _return_http_data_only: bool, optional
+ :param _request_timeout: timeout setting for this request. If one
+ number provided, it will be total request
+ timeout. It can also be a pair (tuple) of
+ (connection, read) timeouts.
+ :param _request_auth: set to override the auth_settings for a single
+ request; this effectively ignores the authentication
+ in the spec for a single request.
+ :type _request_auth: dict, optional
+ :type _content_type: string, optional: force content-type for the request
+ :return: Returns the result object.
+ If the method is called asynchronously,
+ returns the request thread.
+ :rtype: tuple(List[ActiveLearningScoreTypesV2Data], status_code(int), headers(HTTPHeaderDict))
+ """
+
+ _params = locals()
+
+ _all_params = [
+ 'dataset_id',
+ 'prediction_uuid_timestamp'
+ ]
+ _all_params.extend(
+ [
+ 'async_req',
+ '_return_http_data_only',
+ '_preload_content',
+ '_request_timeout',
+ '_request_auth',
+ '_content_type',
+ '_headers'
+ ]
+ )
+
+ # validate the arguments
+ for _key, _val in _params['kwargs'].items():
+ if _key not in _all_params:
+ raise ApiTypeError(
+ "Got an unexpected keyword argument '%s'"
+ " to method get_active_learning_v2_scores_by_dataset_id" % _key
+ )
+ _params[_key] = _val
+ del _params['kwargs']
+
+ _collection_formats = {}
+
+ # process the path parameters
+ _path_params = {}
+ if _params['dataset_id']:
+ _path_params['datasetId'] = _params['dataset_id']
+
+
+ # process the query parameters
+ _query_params = []
+ if _params.get('prediction_uuid_timestamp') is not None: # noqa: E501
+ _query_params.append((
+ 'predictionUUIDTimestamp',
+ _params['prediction_uuid_timestamp'].value if hasattr(_params['prediction_uuid_timestamp'], 'value') else _params['prediction_uuid_timestamp']
+ ))
+
+ # process the header parameters
+ _header_params = dict(_params.get('_headers', {}))
+ # process the form parameters
+ _form_params = []
+ _files = {}
+ # process the body parameter
+ _body_params = None
+ # set the HTTP header `Accept`
+ _header_params['Accept'] = self.api_client.select_header_accept(
+ ['application/json']) # noqa: E501
+
+ # authentication setting
+ _auth_settings = ['auth0Bearer', 'ApiKeyAuth'] # noqa: E501
+
+ _response_types_map = {
+ '200': "List[ActiveLearningScoreTypesV2Data]",
+ '400': "ApiErrorResponse",
+ '401': "ApiErrorResponse",
+ '403': "ApiErrorResponse",
+ '404': "ApiErrorResponse",
+ }
+
+ return self.api_client.call_api(
+ '/v1/datasets/{datasetId}/predictions/scores', 'GET',
+ _path_params,
+ _query_params,
+ _header_params,
+ body=_body_params,
+ post_params=_form_params,
+ files=_files,
+ response_types_map=_response_types_map,
+ auth_settings=_auth_settings,
+ async_req=_params.get('async_req'),
+ _return_http_data_only=_params.get('_return_http_data_only'), # noqa: E501
+ _preload_content=_params.get('_preload_content', True),
+ _request_timeout=_params.get('_request_timeout'),
+ collection_formats=_collection_formats,
+ _request_auth=_params.get('_request_auth'))
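The read-side counterpart: list the v2 score types of a dataset, then resolve one to its full score data. Same assumptions as the upload sketch above; the id attribute on ActiveLearningScoreTypesV2Data is an assumption.

score_types = scores_api.get_active_learning_v2_scores_by_dataset_id(
    dataset_id="0123456789abcdef01234567",  # placeholder ObjectId
)
for entry in score_types:
    score = scores_api.get_active_learning_v2_score_by_dataset_and_score_id(
        dataset_id="0123456789abcdef01234567",
        score_id=entry.id,  # attribute name assumed
    )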
diff --git a/lightly/openapi_generated/swagger_client/api/tags_api.py b/lightly/openapi_generated/swagger_client/api/tags_api.py
index 4affcbccd..428c23c98 100644
--- a/lightly/openapi_generated/swagger_client/api/tags_api.py
+++ b/lightly/openapi_generated/swagger_client/api/tags_api.py
@@ -524,9 +524,9 @@ def delete_tag_by_tag_id_with_http_info(self, dataset_id : Annotated[constr(stri
@validate_arguments
def download_zip_of_samples_by_tag_id(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], tag_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the tag")], **kwargs) -> bytearray: # noqa: E501
- """download_zip_of_samples_by_tag_id # noqa: E501
+ """(Deprecated) download_zip_of_samples_by_tag_id # noqa: E501
- Download a zip file of the samples of a tag. Limited to 1000 images # noqa: E501
+ DEPRECATED - Download a zip file of the samples of a tag. Limited to 1000 images # noqa: E501
This method makes a synchronous HTTP request by default. To make an
asynchronous HTTP request, please pass async_req=True
@@ -555,9 +555,9 @@ def download_zip_of_samples_by_tag_id(self, dataset_id : Annotated[constr(strict
@validate_arguments
def download_zip_of_samples_by_tag_id_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], tag_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the tag")], **kwargs) -> ApiResponse: # noqa: E501
- """download_zip_of_samples_by_tag_id # noqa: E501
+ """(Deprecated) download_zip_of_samples_by_tag_id # noqa: E501
- Download a zip file of the samples of a tag. Limited to 1000 images # noqa: E501
+ DEPRECATED - Download a zip file of the samples of a tag. Limited to 1000 images # noqa: E501
This method makes a synchronous HTTP request by default. To make an
asynchronous HTTP request, please pass async_req=True
@@ -593,6 +593,8 @@ def download_zip_of_samples_by_tag_id_with_http_info(self, dataset_id : Annotate
:rtype: tuple(bytearray, status_code(int), headers(HTTPHeaderDict))
"""
+ warnings.warn("GET /v1/datasets/{datasetId}/tags/{tagId}/export/zip is deprecated.", DeprecationWarning)
+
_params = locals()
_all_params = [
@@ -1112,7 +1114,7 @@ def export_tag_to_basic_filenames_and_read_urls_with_http_info(self, dataset_id
def export_tag_to_label_box_data_rows(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], tag_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the tag")], expires_in : Annotated[Optional[StrictInt], Field(description="If defined, the URLs provided will only be valid for the amount of seconds from time of issuance. If not defined, the URLs will be valid indefinitely. ")] = None, access_control : Annotated[Optional[StrictStr], Field(description="which access control name to use")] = None, file_name_format : Optional[FileNameFormat] = None, include_meta_data : Annotated[Optional[StrictBool], Field(description="if true, will also include metadata")] = None, format : Optional[FileOutputFormat] = None, preview_example : Annotated[Optional[StrictBool], Field(description="if true, will generate a preview example of how the structure will look")] = None, page_size : Annotated[Optional[conint(strict=True, ge=1)], Field(description="pagination size/limit of the number of samples to return")] = None, page_offset : Annotated[Optional[conint(strict=True, ge=0)], Field(description="pagination offset")] = None, **kwargs) -> List[LabelBoxDataRow]: # noqa: E501
"""(Deprecated) export_tag_to_label_box_data_rows # noqa: E501
- Deprecated. Please use V4 unless there is a specific need to use the LabelBox V3 API. Export samples of a tag as a json for importing into LabelBox as outlined here; https://docs.labelbox.com/v3/reference/image ```openapi\\+warning The image URLs are special in that the resource can be accessed by anyone in possession of said URL for the time specified by the expiresIn query param ``` # noqa: E501
+ DEPRECATED - Please use V4 unless there is a specific need to use the LabelBox V3 API. Export samples of a tag as a json for importing into LabelBox as outlined here; https://docs.labelbox.com/v3/reference/image ```openapi\\+warning The image URLs are special in that the resource can be accessed by anyone in possession of said URL for the time specified by the expiresIn query param ``` # noqa: E501
This method makes a synchronous HTTP request by default. To make an
asynchronous HTTP request, please pass async_req=True
@@ -1159,7 +1161,7 @@ def export_tag_to_label_box_data_rows(self, dataset_id : Annotated[constr(strict
def export_tag_to_label_box_data_rows_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], tag_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the tag")], expires_in : Annotated[Optional[StrictInt], Field(description="If defined, the URLs provided will only be valid for the amount of seconds from time of issuance. If not defined, the URLs will be valid indefinitely. ")] = None, access_control : Annotated[Optional[StrictStr], Field(description="which access control name to use")] = None, file_name_format : Optional[FileNameFormat] = None, include_meta_data : Annotated[Optional[StrictBool], Field(description="if true, will also include metadata")] = None, format : Optional[FileOutputFormat] = None, preview_example : Annotated[Optional[StrictBool], Field(description="if true, will generate a preview example of how the structure will look")] = None, page_size : Annotated[Optional[conint(strict=True, ge=1)], Field(description="pagination size/limit of the number of samples to return")] = None, page_offset : Annotated[Optional[conint(strict=True, ge=0)], Field(description="pagination offset")] = None, **kwargs) -> ApiResponse: # noqa: E501
"""(Deprecated) export_tag_to_label_box_data_rows # noqa: E501
- Deprecated. Please use V4 unless there is a specific need to use the LabelBox V3 API. Export samples of a tag as a json for importing into LabelBox as outlined here; https://docs.labelbox.com/v3/reference/image ```openapi\\+warning The image URLs are special in that the resource can be accessed by anyone in possession of said URL for the time specified by the expiresIn query param ``` # noqa: E501
+ DEPRECATED - Please use V4 unless there is a specific need to use the LabelBox V3 API. Export samples of a tag as a json for importing into LabelBox as outlined here; https://docs.labelbox.com/v3/reference/image ```openapi\\+warning The image URLs are special in that the resource can be accessed by anyone in possession of said URL for the time specified by the expiresIn query param ``` # noqa: E501
This method makes a synchronous HTTP request by default. To make an
asynchronous HTTP request, please pass async_req=True
@@ -2070,7 +2072,7 @@ def export_tag_to_sama_tasks_with_http_info(self, dataset_id : Annotated[constr(
def get_filenames_by_tag_id(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], tag_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the tag")], **kwargs) -> List[str]: # noqa: E501
"""(Deprecated) get_filenames_by_tag_id # noqa: E501
- Get list of filenames by tag. Deprecated, please use # noqa: E501
+ DEPRECATED, please use exportTagToBasicFilenames - Get list of filenames by tag. # noqa: E501
This method makes a synchronous HTTP request by default. To make an
asynchronous HTTP request, please pass async_req=True
@@ -2101,7 +2103,7 @@ def get_filenames_by_tag_id(self, dataset_id : Annotated[constr(strict=True), Fi
def get_filenames_by_tag_id_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], tag_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the tag")], **kwargs) -> ApiResponse: # noqa: E501
"""(Deprecated) get_filenames_by_tag_id # noqa: E501
- Get list of filenames by tag. Deprecated, please use # noqa: E501
+ DEPRECATED, please use exportTagToBasicFilenames - Get list of filenames by tag. # noqa: E501
This method makes a synchronous HTTP request by default. To make an
asynchronous HTTP request, please pass async_req=True
@@ -2675,7 +2677,7 @@ def perform_tag_arithmetics_with_http_info(self, dataset_id : Annotated[constr(s
def perform_tag_arithmetics_bitmask(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], tag_arithmetics_request : TagArithmeticsRequest, **kwargs) -> TagBitMaskResponse: # noqa: E501
"""(Deprecated) perform_tag_arithmetics_bitmask # noqa: E501
- Performs tag arithmetics to compute a new bitmask out of two existing tags. Does not create a new tag regardless if newTagName is provided # noqa: E501
+ DEPRECATED, use performTagArithmetics - Performs tag arithmetics to compute a new bitmask out of two existing tags. Does not create a new tag regardless if newTagName is provided # noqa: E501
This method makes a synchronous HTTP request by default. To make an
asynchronous HTTP request, please pass async_req=True
@@ -2706,7 +2708,7 @@ def perform_tag_arithmetics_bitmask(self, dataset_id : Annotated[constr(strict=T
def perform_tag_arithmetics_bitmask_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], tag_arithmetics_request : TagArithmeticsRequest, **kwargs) -> ApiResponse: # noqa: E501
"""(Deprecated) perform_tag_arithmetics_bitmask # noqa: E501
- Performs tag arithmetics to compute a new bitmask out of two existing tags. Does not create a new tag regardless if newTagName is provided # noqa: E501
+ DEPRECATED, use performTagArithmetics - Performs tag arithmetics to compute a new bitmask out of two existing tags. Does not create a new tag regardless if newTagName is provided # noqa: E501
This method makes a synchronous HTTP request by default. To make an
asynchronous HTTP request, please pass async_req=True
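
Deprecated tag endpoints now emit a DeprecationWarning before dispatching the request. A small sketch with the standard-library warnings module showing how a caller can surface this; the TagsApi constructor wiring follows the generated pattern, and no live server is assumed, so the request itself is allowed to fail:

import warnings

from lightly.openapi_generated.swagger_client.api.tags_api import TagsApi

tags_api = TagsApi()  # uses the default ApiClient
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    try:
        tags_api.download_zip_of_samples_by_tag_id(
            dataset_id="5f7c9f2e8b4a3d2e1c0b9a87",
            tag_id="5f7c9f2e8b4a3d2e1c0b9a88",
        )
    except Exception:
        pass  # no server behind this sketch; only the warning matters here
assert any(issubclass(w.category, DeprecationWarning) for w in caught)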
diff --git a/lightly/openapi_generated/swagger_client/api/teams_api.py b/lightly/openapi_generated/swagger_client/api/teams_api.py
index 811099f83..7eed9b158 100644
--- a/lightly/openapi_generated/swagger_client/api/teams_api.py
+++ b/lightly/openapi_generated/swagger_client/api/teams_api.py
@@ -51,7 +51,7 @@ def __init__(self, api_client=None):
self.api_client = api_client
@validate_arguments
- def add_team_member(self, team_id : Annotated[constr(strict=True), Field(..., description="id of the team")], create_team_membership_request : CreateTeamMembershipRequest, **kwargs) -> None: # noqa: E501
+ def add_team_member(self, team_id : Annotated[constr(strict=True), Field(..., description="id of the team")], create_team_membership_request : CreateTeamMembershipRequest, **kwargs) -> str: # noqa: E501
"""add_team_member # noqa: E501
Add a team member. One needs to be part of the team to do so. # noqa: E501
@@ -74,7 +74,7 @@ def add_team_member(self, team_id : Annotated[constr(strict=True), Field(..., de
:return: Returns the result object.
If the method is called asynchronously,
returns the request thread.
- :rtype: None
+ :rtype: str
"""
kwargs['_return_http_data_only'] = True
if '_preload_content' in kwargs:
@@ -118,7 +118,7 @@ def add_team_member_with_http_info(self, team_id : Annotated[constr(strict=True)
:return: Returns the result object.
If the method is called asynchronously,
returns the request thread.
- :rtype: None
+ :rtype: tuple(str, status_code(int), headers(HTTPHeaderDict))
"""
_params = locals()
@@ -171,7 +171,7 @@ def add_team_member_with_http_info(self, team_id : Annotated[constr(strict=True)
# set the HTTP header `Accept`
_header_params['Accept'] = self.api_client.select_header_accept(
- ['application/json']) # noqa: E501
+ ['text/plain', 'application/json']) # noqa: E501
# set the HTTP header `Content-Type`
_content_types_list = _params.get('_content_type',
@@ -183,7 +183,13 @@ def add_team_member_with_http_info(self, team_id : Annotated[constr(strict=True)
# authentication setting
_auth_settings = ['auth0Bearer', 'ApiKeyAuth'] # noqa: E501
- _response_types_map = {}
+ _response_types_map = {
+ '200': "str",
+ '400': "ApiErrorResponse",
+ '401': "ApiErrorResponse",
+ '403': "ApiErrorResponse",
+ '404': "ApiErrorResponse",
+ }
return self.api_client.call_api(
'/v1/teams/{teamId}/members', 'POST',
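
add_team_member previously deserialized nothing (rtype None); with the updated Accept header and _response_types_map it now returns the 200 body as a str. A sketch, with the CreateTeamMembershipRequest field names assumed for illustration:

from lightly.openapi_generated.swagger_client.api.teams_api import TeamsApi
from lightly.openapi_generated.swagger_client.models.create_team_membership_request import (
    CreateTeamMembershipRequest,
)

teams_api = TeamsApi()
request = CreateTeamMembershipRequest(
    email="new.member@example.com",  # field names assumed for this sketch
    role="MEMBER",
)
# Now typed as str (e.g. an identifier for the created membership); it was None before.
result: str = teams_api.add_team_member(
    team_id="5f7c9f2e8b4a3d2e1c0b9a87",
    create_team_membership_request=request,
)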
diff --git a/lightly/openapi_generated/swagger_client/models/__init__.py b/lightly/openapi_generated/swagger_client/models/__init__.py
index 1c47427c7..fd20c0edc 100644
--- a/lightly/openapi_generated/swagger_client/models/__init__.py
+++ b/lightly/openapi_generated/swagger_client/models/__init__.py
@@ -17,6 +17,8 @@
# import models into model package
from lightly.openapi_generated.swagger_client.models.active_learning_score_create_request import ActiveLearningScoreCreateRequest
from lightly.openapi_generated.swagger_client.models.active_learning_score_data import ActiveLearningScoreData
+from lightly.openapi_generated.swagger_client.models.active_learning_score_types_v2_data import ActiveLearningScoreTypesV2Data
+from lightly.openapi_generated.swagger_client.models.active_learning_score_v2_data import ActiveLearningScoreV2Data
from lightly.openapi_generated.swagger_client.models.api_error_code import ApiErrorCode
from lightly.openapi_generated.swagger_client.models.api_error_response import ApiErrorResponse
from lightly.openapi_generated.swagger_client.models.async_task_data import AsyncTaskData
@@ -42,10 +44,12 @@
from lightly.openapi_generated.swagger_client.models.datasource_config_azure import DatasourceConfigAzure
from lightly.openapi_generated.swagger_client.models.datasource_config_azure_all_of import DatasourceConfigAzureAllOf
from lightly.openapi_generated.swagger_client.models.datasource_config_base import DatasourceConfigBase
+from lightly.openapi_generated.swagger_client.models.datasource_config_base_full_path import DatasourceConfigBaseFullPath
from lightly.openapi_generated.swagger_client.models.datasource_config_gcs import DatasourceConfigGCS
from lightly.openapi_generated.swagger_client.models.datasource_config_gcs_all_of import DatasourceConfigGCSAllOf
from lightly.openapi_generated.swagger_client.models.datasource_config_lightly import DatasourceConfigLIGHTLY
from lightly.openapi_generated.swagger_client.models.datasource_config_local import DatasourceConfigLOCAL
+from lightly.openapi_generated.swagger_client.models.datasource_config_local_all_of import DatasourceConfigLOCALAllOf
from lightly.openapi_generated.swagger_client.models.datasource_config_obs import DatasourceConfigOBS
from lightly.openapi_generated.swagger_client.models.datasource_config_obs_all_of import DatasourceConfigOBSAllOf
from lightly.openapi_generated.swagger_client.models.datasource_config_s3 import DatasourceConfigS3
@@ -180,13 +184,23 @@
from lightly.openapi_generated.swagger_client.models.sampling_method import SamplingMethod
from lightly.openapi_generated.swagger_client.models.sector import Sector
from lightly.openapi_generated.swagger_client.models.selection_config import SelectionConfig
+from lightly.openapi_generated.swagger_client.models.selection_config_all_of import SelectionConfigAllOf
+from lightly.openapi_generated.swagger_client.models.selection_config_base import SelectionConfigBase
from lightly.openapi_generated.swagger_client.models.selection_config_entry import SelectionConfigEntry
from lightly.openapi_generated.swagger_client.models.selection_config_entry_input import SelectionConfigEntryInput
from lightly.openapi_generated.swagger_client.models.selection_config_entry_strategy import SelectionConfigEntryStrategy
+from lightly.openapi_generated.swagger_client.models.selection_config_v3 import SelectionConfigV3
+from lightly.openapi_generated.swagger_client.models.selection_config_v3_all_of import SelectionConfigV3AllOf
+from lightly.openapi_generated.swagger_client.models.selection_config_v3_entry import SelectionConfigV3Entry
+from lightly.openapi_generated.swagger_client.models.selection_config_v3_entry_input import SelectionConfigV3EntryInput
+from lightly.openapi_generated.swagger_client.models.selection_config_v3_entry_strategy import SelectionConfigV3EntryStrategy
+from lightly.openapi_generated.swagger_client.models.selection_config_v3_entry_strategy_all_of import SelectionConfigV3EntryStrategyAllOf
+from lightly.openapi_generated.swagger_client.models.selection_config_v3_entry_strategy_all_of_target_range import SelectionConfigV3EntryStrategyAllOfTargetRange
from lightly.openapi_generated.swagger_client.models.selection_input_predictions_name import SelectionInputPredictionsName
from lightly.openapi_generated.swagger_client.models.selection_input_type import SelectionInputType
from lightly.openapi_generated.swagger_client.models.selection_strategy_threshold_operation import SelectionStrategyThresholdOperation
from lightly.openapi_generated.swagger_client.models.selection_strategy_type import SelectionStrategyType
+from lightly.openapi_generated.swagger_client.models.selection_strategy_type_v3 import SelectionStrategyTypeV3
from lightly.openapi_generated.swagger_client.models.service_account_basic_data import ServiceAccountBasicData
from lightly.openapi_generated.swagger_client.models.set_embeddings_is_processed_flag_by_id_body_request import SetEmbeddingsIsProcessedFlagByIdBodyRequest
from lightly.openapi_generated.swagger_client.models.shared_access_config_create_request import SharedAccessConfigCreateRequest
diff --git a/lightly/openapi_generated/swagger_client/models/active_learning_score_types_v2_data.py b/lightly/openapi_generated/swagger_client/models/active_learning_score_types_v2_data.py
new file mode 100644
index 000000000..5455a44d3
--- /dev/null
+++ b/lightly/openapi_generated/swagger_client/models/active_learning_score_types_v2_data.py
@@ -0,0 +1,116 @@
+# coding: utf-8
+
+"""
+ Lightly API
+
+ Lightly.ai enables you to do self-supervised learning in an easy and intuitive way. The lightly.ai OpenAPI spec defines how one can interact with our REST API to unleash the full potential of lightly.ai # noqa: E501
+
+ The version of the OpenAPI document: 1.0.0
+ Contact: support@lightly.ai
+ Generated by OpenAPI Generator (https://openapi-generator.tech)
+
+ Do not edit the class manually.
+"""
+
+
+from __future__ import annotations
+import pprint
+import re # noqa: F401
+import json
+
+
+
+from pydantic import Extra, BaseModel, Field, conint, constr, validator
+
+class ActiveLearningScoreTypesV2Data(BaseModel):
+ """
+ ActiveLearningScoreTypesV2Data
+ """
+ id: constr(strict=True) = Field(..., description="MongoDB ObjectId")
+ dataset_id: constr(strict=True) = Field(..., alias="datasetId", description="MongoDB ObjectId")
+ prediction_uuid_timestamp: conint(strict=True, ge=0) = Field(..., alias="predictionUUIDTimestamp", description="unix timestamp in milliseconds")
+ task_name: constr(strict=True, min_length=1) = Field(..., alias="taskName", description="A name which is safe to have as a file/folder name in a file system")
+ score_type: constr(strict=True, min_length=1) = Field(..., alias="scoreType", description="Type of active learning score")
+ created_at: conint(strict=True, ge=0) = Field(..., alias="createdAt", description="unix timestamp in milliseconds")
+ __properties = ["id", "datasetId", "predictionUUIDTimestamp", "taskName", "scoreType", "createdAt"]
+
+ @validator('id')
+ def id_validate_regular_expression(cls, value):
+ """Validates the regular expression"""
+ if not re.match(r"^[a-f0-9]{24}$", value):
+ raise ValueError(r"must validate the regular expression /^[a-f0-9]{24}$/")
+ return value
+
+ @validator('dataset_id')
+ def dataset_id_validate_regular_expression(cls, value):
+ """Validates the regular expression"""
+ if not re.match(r"^[a-f0-9]{24}$", value):
+ raise ValueError(r"must validate the regular expression /^[a-f0-9]{24}$/")
+ return value
+
+ @validator('task_name')
+ def task_name_validate_regular_expression(cls, value):
+ """Validates the regular expression"""
+ if not re.match(r"^[a-zA-Z0-9][a-zA-Z0-9 ._-]+$", value):
+ raise ValueError(r"must validate the regular expression /^[a-zA-Z0-9][a-zA-Z0-9 ._-]+$/")
+ return value
+
+ @validator('score_type')
+ def score_type_validate_regular_expression(cls, value):
+ """Validates the regular expression"""
+ if not re.match(r"^[a-zA-Z0-9_+=,.@:\/-]*$", value):
+ raise ValueError(r"must validate the regular expression /^[a-zA-Z0-9_+=,.@:\/-]*$/")
+ return value
+
+ class Config:
+ """Pydantic configuration"""
+ allow_population_by_field_name = True
+ validate_assignment = True
+ use_enum_values = True
+ extra = Extra.forbid
+
+ def to_str(self, by_alias: bool = False) -> str:
+ """Returns the string representation of the model"""
+ return pprint.pformat(self.dict(by_alias=by_alias))
+
+ def to_json(self, by_alias: bool = False) -> str:
+ """Returns the JSON representation of the model"""
+ return json.dumps(self.to_dict(by_alias=by_alias))
+
+ @classmethod
+ def from_json(cls, json_str: str) -> ActiveLearningScoreTypesV2Data:
+ """Create an instance of ActiveLearningScoreTypesV2Data from a JSON string"""
+ return cls.from_dict(json.loads(json_str))
+
+ def to_dict(self, by_alias: bool = False):
+ """Returns the dictionary representation of the model"""
+ _dict = self.dict(by_alias=by_alias,
+ exclude={
+ },
+ exclude_none=True)
+ return _dict
+
+ @classmethod
+ def from_dict(cls, obj: dict) -> ActiveLearningScoreTypesV2Data:
+ """Create an instance of ActiveLearningScoreTypesV2Data from a dict"""
+ if obj is None:
+ return None
+
+ if not isinstance(obj, dict):
+ return ActiveLearningScoreTypesV2Data.parse_obj(obj)
+
+ # raise errors for additional fields in the input
+ for _key in obj.keys():
+ if _key not in cls.__properties:
+ raise ValueError("Error due to additional fields (not defined in ActiveLearningScoreTypesV2Data) in the input: " + str(obj))
+
+ _obj = ActiveLearningScoreTypesV2Data.parse_obj({
+ "id": obj.get("id"),
+ "dataset_id": obj.get("datasetId"),
+ "prediction_uuid_timestamp": obj.get("predictionUUIDTimestamp"),
+ "task_name": obj.get("taskName"),
+ "score_type": obj.get("scoreType"),
+ "created_at": obj.get("createdAt")
+ })
+ return _obj
+
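
Since the generated models forbid extra fields and validate against the documented regexes, a quick sketch of how the new model parses a camelCase payload and rejects unknown keys:

from lightly.openapi_generated.swagger_client.models.active_learning_score_types_v2_data import (
    ActiveLearningScoreTypesV2Data,
)

payload = {
    "id": "5f7c9f2e8b4a3d2e1c0b9a87",          # 24 hex chars, per the id validator
    "datasetId": "5f7c9f2e8b4a3d2e1c0b9a88",
    "predictionUUIDTimestamp": 1690000000000,  # unix timestamp in milliseconds
    "taskName": "object-detection",            # must match ^[a-zA-Z0-9][a-zA-Z0-9 ._-]+$
    "scoreType": "uncertainty_margin",         # must match the scoreType regex
    "createdAt": 1690000000000,
}
score_types = ActiveLearningScoreTypesV2Data.from_dict(payload)
print(score_types.task_name)  # fields are populated via their camelCase aliases

# Extra keys are rejected because from_dict checks against __properties:
try:
    ActiveLearningScoreTypesV2Data.from_dict({**payload, "unexpected": 1})
except ValueError as err:
    print(err)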
diff --git a/lightly/openapi_generated/swagger_client/models/active_learning_score_v2_data.py b/lightly/openapi_generated/swagger_client/models/active_learning_score_v2_data.py
new file mode 100644
index 000000000..1970b7c58
--- /dev/null
+++ b/lightly/openapi_generated/swagger_client/models/active_learning_score_v2_data.py
@@ -0,0 +1,118 @@
+# coding: utf-8
+
+"""
+ Lightly API
+
+ Lightly.ai enables you to do self-supervised learning in an easy and intuitive way. The lightly.ai OpenAPI spec defines how one can interact with our REST API to unleash the full potential of lightly.ai # noqa: E501
+
+ The version of the OpenAPI document: 1.0.0
+ Contact: support@lightly.ai
+ Generated by OpenAPI Generator (https://openapi-generator.tech)
+
+ Do not edit the class manually.
+"""
+
+
+from __future__ import annotations
+import pprint
+import re # noqa: F401
+import json
+
+
+from typing import List, Union
+from pydantic import Extra, BaseModel, Field, StrictFloat, StrictInt, conint, conlist, constr, validator
+
+class ActiveLearningScoreV2Data(BaseModel):
+ """
+ ActiveLearningScoreV2Data
+ """
+ id: constr(strict=True) = Field(..., description="MongoDB ObjectId")
+ dataset_id: constr(strict=True) = Field(..., alias="datasetId", description="MongoDB ObjectId")
+ prediction_uuid_timestamp: conint(strict=True, ge=0) = Field(..., alias="predictionUUIDTimestamp", description="unix timestamp in milliseconds")
+ task_name: constr(strict=True, min_length=1) = Field(..., alias="taskName", description="A name which is safe to have as a file/folder name in a file system")
+ score_type: constr(strict=True, min_length=1) = Field(..., alias="scoreType", description="Type of active learning score")
+ scores: conlist(Union[StrictFloat, StrictInt], min_items=1) = Field(..., description="Array of active learning scores")
+ created_at: conint(strict=True, ge=0) = Field(..., alias="createdAt", description="unix timestamp in milliseconds")
+ __properties = ["id", "datasetId", "predictionUUIDTimestamp", "taskName", "scoreType", "scores", "createdAt"]
+
+ @validator('id')
+ def id_validate_regular_expression(cls, value):
+ """Validates the regular expression"""
+ if not re.match(r"^[a-f0-9]{24}$", value):
+ raise ValueError(r"must validate the regular expression /^[a-f0-9]{24}$/")
+ return value
+
+ @validator('dataset_id')
+ def dataset_id_validate_regular_expression(cls, value):
+ """Validates the regular expression"""
+ if not re.match(r"^[a-f0-9]{24}$", value):
+ raise ValueError(r"must validate the regular expression /^[a-f0-9]{24}$/")
+ return value
+
+ @validator('task_name')
+ def task_name_validate_regular_expression(cls, value):
+ """Validates the regular expression"""
+ if not re.match(r"^[a-zA-Z0-9][a-zA-Z0-9 ._-]+$", value):
+ raise ValueError(r"must validate the regular expression /^[a-zA-Z0-9][a-zA-Z0-9 ._-]+$/")
+ return value
+
+ @validator('score_type')
+ def score_type_validate_regular_expression(cls, value):
+ """Validates the regular expression"""
+ if not re.match(r"^[a-zA-Z0-9_+=,.@:\/-]*$", value):
+ raise ValueError(r"must validate the regular expression /^[a-zA-Z0-9_+=,.@:\/-]*$/")
+ return value
+
+ class Config:
+ """Pydantic configuration"""
+ allow_population_by_field_name = True
+ validate_assignment = True
+ use_enum_values = True
+ extra = Extra.forbid
+
+ def to_str(self, by_alias: bool = False) -> str:
+ """Returns the string representation of the model"""
+ return pprint.pformat(self.dict(by_alias=by_alias))
+
+ def to_json(self, by_alias: bool = False) -> str:
+ """Returns the JSON representation of the model"""
+ return json.dumps(self.to_dict(by_alias=by_alias))
+
+ @classmethod
+ def from_json(cls, json_str: str) -> ActiveLearningScoreV2Data:
+ """Create an instance of ActiveLearningScoreV2Data from a JSON string"""
+ return cls.from_dict(json.loads(json_str))
+
+ def to_dict(self, by_alias: bool = False):
+ """Returns the dictionary representation of the model"""
+ _dict = self.dict(by_alias=by_alias,
+ exclude={
+ },
+ exclude_none=True)
+ return _dict
+
+ @classmethod
+ def from_dict(cls, obj: dict) -> ActiveLearningScoreV2Data:
+ """Create an instance of ActiveLearningScoreV2Data from a dict"""
+ if obj is None:
+ return None
+
+ if not isinstance(obj, dict):
+ return ActiveLearningScoreV2Data.parse_obj(obj)
+
+ # raise errors for additional fields in the input
+ for _key in obj.keys():
+ if _key not in cls.__properties:
+ raise ValueError("Error due to additional fields (not defined in ActiveLearningScoreV2Data) in the input: " + str(obj))
+
+ _obj = ActiveLearningScoreV2Data.parse_obj({
+ "id": obj.get("id"),
+ "dataset_id": obj.get("datasetId"),
+ "prediction_uuid_timestamp": obj.get("predictionUUIDTimestamp"),
+ "task_name": obj.get("taskName"),
+ "score_type": obj.get("scoreType"),
+ "scores": obj.get("scores"),
+ "created_at": obj.get("createdAt")
+ })
+ return _obj
+
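
The score-data variant adds the scores array (at least one number). A round-trip sketch; note that from_json expects camelCase keys, so serialize with by_alias=True:

from lightly.openapi_generated.swagger_client.models.active_learning_score_v2_data import (
    ActiveLearningScoreV2Data,
)

scores = ActiveLearningScoreV2Data(
    id="5f7c9f2e8b4a3d2e1c0b9a87",
    dataset_id="5f7c9f2e8b4a3d2e1c0b9a88",
    prediction_uuid_timestamp=1690000000000,
    task_name="object-detection",
    score_type="uncertainty_margin",
    scores=[0.1, 0.9, 0.5],          # conlist enforces at least one entry
    created_at=1690000000000,
)
as_json = scores.to_json(by_alias=True)   # camelCase keys, matching __properties
restored = ActiveLearningScoreV2Data.from_json(as_json)
assert restored.scores == scores.scores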
diff --git a/lightly/openapi_generated/swagger_client/models/annotation_data.py b/lightly/openapi_generated/swagger_client/models/annotation_data.py
deleted file mode 100644
index 673e75aec..000000000
--- a/lightly/openapi_generated/swagger_client/models/annotation_data.py
+++ /dev/null
@@ -1,103 +0,0 @@
-# coding: utf-8
-
-"""
- Lightly API
-
- Lightly.ai enables you to do self-supervised learning in an easy and intuitive way. The lightly.ai OpenAPI spec defines how one can interact with our REST API to unleash the full potential of lightly.ai # noqa: E501
-
- The version of the OpenAPI document: 1.0.0
- Contact: support@lightly.ai
- Generated by OpenAPI Generator (https://openapi-generator.tech)
-
- Do not edit the class manually.
-"""
-
-
-from __future__ import annotations
-import pprint
-import re # noqa: F401
-import json
-
-
-from typing import Optional
-from pydantic import Extra, BaseModel, Field, StrictStr, conint
-from lightly.openapi_generated.swagger_client.models.annotation_meta_data import AnnotationMetaData
-from lightly.openapi_generated.swagger_client.models.annotation_offer_data import AnnotationOfferData
-from lightly.openapi_generated.swagger_client.models.annotation_state import AnnotationState
-
-class AnnotationData(BaseModel):
- """
- AnnotationData
- """
- id: StrictStr = Field(..., alias="_id")
- state: AnnotationState = Field(...)
- dataset_id: StrictStr = Field(..., alias="datasetId")
- tag_id: StrictStr = Field(..., alias="tagId")
- partner_id: Optional[StrictStr] = Field(None, alias="partnerId")
- created_at: conint(strict=True, ge=0) = Field(..., alias="createdAt", description="unix timestamp in milliseconds")
- last_modified_at: conint(strict=True, ge=0) = Field(..., alias="lastModifiedAt", description="unix timestamp in milliseconds")
- meta: AnnotationMetaData = Field(...)
- offer: Optional[AnnotationOfferData] = None
- __properties = ["_id", "state", "datasetId", "tagId", "partnerId", "createdAt", "lastModifiedAt", "meta", "offer"]
-
- class Config:
- """Pydantic configuration"""
- allow_population_by_field_name = True
- validate_assignment = True
- use_enum_values = True
- extra = Extra.forbid
-
- def to_str(self, by_alias: bool = False) -> str:
- """Returns the string representation of the model"""
- return pprint.pformat(self.dict(by_alias=by_alias))
-
- def to_json(self, by_alias: bool = False) -> str:
- """Returns the JSON representation of the model"""
- return json.dumps(self.to_dict(by_alias=by_alias))
-
- @classmethod
- def from_json(cls, json_str: str) -> AnnotationData:
- """Create an instance of AnnotationData from a JSON string"""
- return cls.from_dict(json.loads(json_str))
-
- def to_dict(self, by_alias: bool = False):
- """Returns the dictionary representation of the model"""
- _dict = self.dict(by_alias=by_alias,
- exclude={
- },
- exclude_none=True)
- # override the default output from pydantic by calling `to_dict()` of meta
- if self.meta:
- _dict['meta' if by_alias else 'meta'] = self.meta.to_dict(by_alias=by_alias)
- # override the default output from pydantic by calling `to_dict()` of offer
- if self.offer:
- _dict['offer' if by_alias else 'offer'] = self.offer.to_dict(by_alias=by_alias)
- return _dict
-
- @classmethod
- def from_dict(cls, obj: dict) -> AnnotationData:
- """Create an instance of AnnotationData from a dict"""
- if obj is None:
- return None
-
- if not isinstance(obj, dict):
- return AnnotationData.parse_obj(obj)
-
- # raise errors for additional fields in the input
- for _key in obj.keys():
- if _key not in cls.__properties:
- raise ValueError("Error due to additional fields (not defined in AnnotationData) in the input: " + str(obj))
-
- _obj = AnnotationData.parse_obj({
- "id": obj.get("_id"),
- "state": obj.get("state"),
- "dataset_id": obj.get("datasetId"),
- "tag_id": obj.get("tagId"),
- "partner_id": obj.get("partnerId"),
- "created_at": obj.get("createdAt"),
- "last_modified_at": obj.get("lastModifiedAt"),
- "meta": AnnotationMetaData.from_dict(obj.get("meta")) if obj.get("meta") is not None else None,
- "offer": AnnotationOfferData.from_dict(obj.get("offer")) if obj.get("offer") is not None else None
- })
- return _obj
-
diff --git a/lightly/openapi_generated/swagger_client/models/api_error_code.py b/lightly/openapi_generated/swagger_client/models/api_error_code.py
index 4876a203d..6e2b5dd56 100644
--- a/lightly/openapi_generated/swagger_client/models/api_error_code.py
+++ b/lightly/openapi_generated/swagger_client/models/api_error_code.py
@@ -37,6 +37,7 @@ class ApiErrorCode(str, Enum):
UNAUTHORIZED = 'UNAUTHORIZED'
NOT_FOUND = 'NOT_FOUND'
NOT_MODIFIED = 'NOT_MODIFIED'
+ CONFLICT = 'CONFLICT'
MALFORMED_REQUEST = 'MALFORMED_REQUEST'
MALFORMED_RESPONSE = 'MALFORMED_RESPONSE'
PAYLOAD_TOO_LARGE = 'PAYLOAD_TOO_LARGE'
@@ -104,6 +105,7 @@ class ApiErrorCode(str, Enum):
DOCKER_WORKER_SCHEDULE_UPDATE_FAILED = 'DOCKER_WORKER_SCHEDULE_UPDATE_FAILED'
METADATA_CONFIGURATION_UNKNOWN = 'METADATA_CONFIGURATION_UNKNOWN'
CUSTOM_METADATA_AT_MAX_SIZE = 'CUSTOM_METADATA_AT_MAX_SIZE'
+ ONPREM_SUBSCRIPTION_INSUFFICIENT = 'ONPREM_SUBSCRIPTION_INSUFFICIENT'
ACCOUNT_SUBSCRIPTION_INSUFFICIENT = 'ACCOUNT_SUBSCRIPTION_INSUFFICIENT'
TEAM_UNKNOWN = 'TEAM_UNKNOWN'
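
ApiErrorCode subclasses str, so the new CONFLICT and ONPREM_SUBSCRIPTION_INSUFFICIENT values compare directly against plain strings from error payloads:

from lightly.openapi_generated.swagger_client.models.api_error_code import ApiErrorCode

code = ApiErrorCode("CONFLICT")       # lookup by value returns the enum member
assert code is ApiErrorCode.CONFLICT
assert code == "CONFLICT"             # str subclass: plain string comparison works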
diff --git a/lightly/openapi_generated/swagger_client/models/datasource_config_azure.py b/lightly/openapi_generated/swagger_client/models/datasource_config_azure.py
index 9f9bf4dec..6214b0725 100644
--- a/lightly/openapi_generated/swagger_client/models/datasource_config_azure.py
+++ b/lightly/openapi_generated/swagger_client/models/datasource_config_azure.py
@@ -20,16 +20,17 @@
-from pydantic import Extra, BaseModel, Field, constr
+from pydantic import Extra, BaseModel, Field, StrictStr, constr
from lightly.openapi_generated.swagger_client.models.datasource_config_base import DatasourceConfigBase
class DatasourceConfigAzure(DatasourceConfigBase):
"""
DatasourceConfigAzure
"""
+ full_path: StrictStr = Field(..., alias="fullPath", description="path includes the bucket name and the path within the bucket where you have stored your information")
account_name: constr(strict=True, min_length=1) = Field(..., alias="accountName", description="name of the Azure Storage Account")
account_key: constr(strict=True, min_length=1) = Field(..., alias="accountKey", description="key of the Azure Storage Account")
- __properties = ["id", "purpose", "type", "fullPath", "thumbSuffix", "accountName", "accountKey"]
+ __properties = ["id", "purpose", "type", "thumbSuffix", "fullPath", "accountName", "accountKey"]
class Config:
"""Pydantic configuration"""
@@ -77,8 +78,8 @@ def from_dict(cls, obj: dict) -> DatasourceConfigAzure:
"id": obj.get("id"),
"purpose": obj.get("purpose"),
"type": obj.get("type"),
- "full_path": obj.get("fullPath"),
"thumb_suffix": obj.get("thumbSuffix"),
+ "full_path": obj.get("fullPath"),
"account_name": obj.get("accountName"),
"account_key": obj.get("accountKey")
})
diff --git a/lightly/openapi_generated/swagger_client/models/datasource_config_base.py b/lightly/openapi_generated/swagger_client/models/datasource_config_base.py
index d3a5f3983..41c71cce3 100644
--- a/lightly/openapi_generated/swagger_client/models/datasource_config_base.py
+++ b/lightly/openapi_generated/swagger_client/models/datasource_config_base.py
@@ -31,9 +31,8 @@ class DatasourceConfigBase(BaseModel):
id: Optional[constr(strict=True)] = Field(None, description="MongoDB ObjectId")
purpose: DatasourcePurpose = Field(...)
type: StrictStr = Field(...)
- full_path: StrictStr = Field(..., alias="fullPath", description="path includes the bucket name and the path within the bucket where you have stored your information")
thumb_suffix: Optional[StrictStr] = Field(None, alias="thumbSuffix", description="the suffix of where to find the thumbnail image. If none is provided, the full image will be loaded where thumbnails would be loaded otherwise. - [filename]: represents the filename without the extension - [extension]: represents the file's extension (e.g. jpg, png, webp) ")
- __properties = ["id", "purpose", "type", "fullPath", "thumbSuffix"]
+ __properties = ["id", "purpose", "type", "thumbSuffix"]
@validator('id')
def id_validate_regular_expression(cls, value):
diff --git a/lightly/openapi_generated/swagger_client/models/annotation_meta_data.py b/lightly/openapi_generated/swagger_client/models/datasource_config_base_full_path.py
similarity index 67%
rename from lightly/openapi_generated/swagger_client/models/annotation_meta_data.py
rename to lightly/openapi_generated/swagger_client/models/datasource_config_base_full_path.py
index eb0070b1e..eace63453 100644
--- a/lightly/openapi_generated/swagger_client/models/annotation_meta_data.py
+++ b/lightly/openapi_generated/swagger_client/models/datasource_config_base_full_path.py
@@ -19,15 +19,15 @@
import json
-from typing import Optional
-from pydantic import Extra, BaseModel, StrictStr
-class AnnotationMetaData(BaseModel):
+from pydantic import Extra, BaseModel, Field, StrictStr
+
+class DatasourceConfigBaseFullPath(BaseModel):
"""
- AnnotationMetaData
+ DatasourceConfigBaseFullPath
"""
- description: Optional[StrictStr] = None
- __properties = ["description"]
+ full_path: StrictStr = Field(..., alias="fullPath", description="path includes the bucket name and the path within the bucket where you have stored your information")
+ __properties = ["fullPath"]
class Config:
"""Pydantic configuration"""
@@ -45,8 +45,8 @@ def to_json(self, by_alias: bool = False) -> str:
return json.dumps(self.to_dict(by_alias=by_alias))
@classmethod
- def from_json(cls, json_str: str) -> AnnotationMetaData:
- """Create an instance of AnnotationMetaData from a JSON string"""
+ def from_json(cls, json_str: str) -> DatasourceConfigBaseFullPath:
+ """Create an instance of DatasourceConfigBaseFullPath from a JSON string"""
return cls.from_dict(json.loads(json_str))
def to_dict(self, by_alias: bool = False):
@@ -58,21 +58,21 @@ def to_dict(self, by_alias: bool = False):
return _dict
@classmethod
- def from_dict(cls, obj: dict) -> AnnotationMetaData:
- """Create an instance of AnnotationMetaData from a dict"""
+ def from_dict(cls, obj: dict) -> DatasourceConfigBaseFullPath:
+ """Create an instance of DatasourceConfigBaseFullPath from a dict"""
if obj is None:
return None
if not isinstance(obj, dict):
- return AnnotationMetaData.parse_obj(obj)
+ return DatasourceConfigBaseFullPath.parse_obj(obj)
# raise errors for additional fields in the input
for _key in obj.keys():
if _key not in cls.__properties:
- raise ValueError("Error due to additional fields (not defined in AnnotationMetaData) in the input: " + str(obj))
+ raise ValueError("Error due to additional fields (not defined in DatasourceConfigBaseFullPath) in the input: " + str(obj))
- _obj = AnnotationMetaData.parse_obj({
- "description": obj.get("description")
+ _obj = DatasourceConfigBaseFullPath.parse_obj({
+ "full_path": obj.get("fullPath")
})
return _obj
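
The rename turns the old AnnotationMetaData shell into a single-field model exposing fullPath. A quick round-trip sketch:

from lightly.openapi_generated.swagger_client.models.datasource_config_base_full_path import (
    DatasourceConfigBaseFullPath,
)

fp = DatasourceConfigBaseFullPath.from_dict({"fullPath": "my-bucket/datasets/cars"})
assert fp.full_path == "my-bucket/datasets/cars"
assert fp.to_dict(by_alias=True) == {"fullPath": "my-bucket/datasets/cars"}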
diff --git a/lightly/openapi_generated/swagger_client/models/datasource_config_gcs.py b/lightly/openapi_generated/swagger_client/models/datasource_config_gcs.py
index 027ea1fc5..d9b150eb6 100644
--- a/lightly/openapi_generated/swagger_client/models/datasource_config_gcs.py
+++ b/lightly/openapi_generated/swagger_client/models/datasource_config_gcs.py
@@ -27,9 +27,10 @@ class DatasourceConfigGCS(DatasourceConfigBase):
"""
DatasourceConfigGCS
"""
+ full_path: StrictStr = Field(..., alias="fullPath", description="path includes the bucket name and the path within the bucket where you have stored your information")
gcs_project_id: constr(strict=True, min_length=1) = Field(..., alias="gcsProjectId", description="The projectId where you have your bucket configured")
gcs_credentials: StrictStr = Field(..., alias="gcsCredentials", description="this is the content of the credentials JSON file stringified which you downloaded from Google Cloud Platform")
- __properties = ["id", "purpose", "type", "fullPath", "thumbSuffix", "gcsProjectId", "gcsCredentials"]
+ __properties = ["id", "purpose", "type", "thumbSuffix", "fullPath", "gcsProjectId", "gcsCredentials"]
class Config:
"""Pydantic configuration"""
@@ -77,8 +78,8 @@ def from_dict(cls, obj: dict) -> DatasourceConfigGCS:
"id": obj.get("id"),
"purpose": obj.get("purpose"),
"type": obj.get("type"),
- "full_path": obj.get("fullPath"),
"thumb_suffix": obj.get("thumbSuffix"),
+ "full_path": obj.get("fullPath"),
"gcs_project_id": obj.get("gcsProjectId"),
"gcs_credentials": obj.get("gcsCredentials")
})
diff --git a/lightly/openapi_generated/swagger_client/models/datasource_config_lightly.py b/lightly/openapi_generated/swagger_client/models/datasource_config_lightly.py
index 7355ed5d0..2617eae53 100644
--- a/lightly/openapi_generated/swagger_client/models/datasource_config_lightly.py
+++ b/lightly/openapi_generated/swagger_client/models/datasource_config_lightly.py
@@ -20,14 +20,15 @@
-from pydantic import Extra, BaseModel
+from pydantic import Extra, BaseModel, Field, StrictStr
from lightly.openapi_generated.swagger_client.models.datasource_config_base import DatasourceConfigBase
class DatasourceConfigLIGHTLY(DatasourceConfigBase):
"""
DatasourceConfigLIGHTLY
"""
- __properties = ["id", "purpose", "type", "fullPath", "thumbSuffix"]
+ full_path: StrictStr = Field(..., alias="fullPath", description="path includes the bucket name and the path within the bucket where you have stored your information")
+ __properties = ["id", "purpose", "type", "thumbSuffix", "fullPath"]
class Config:
"""Pydantic configuration"""
@@ -75,8 +76,8 @@ def from_dict(cls, obj: dict) -> DatasourceConfigLIGHTLY:
"id": obj.get("id"),
"purpose": obj.get("purpose"),
"type": obj.get("type"),
- "full_path": obj.get("fullPath"),
- "thumb_suffix": obj.get("thumbSuffix")
+ "thumb_suffix": obj.get("thumbSuffix"),
+ "full_path": obj.get("fullPath")
})
return _obj
diff --git a/lightly/openapi_generated/swagger_client/models/datasource_config_local.py b/lightly/openapi_generated/swagger_client/models/datasource_config_local.py
index 6a9170d63..e2f819723 100644
--- a/lightly/openapi_generated/swagger_client/models/datasource_config_local.py
+++ b/lightly/openapi_generated/swagger_client/models/datasource_config_local.py
@@ -19,15 +19,27 @@
import json
-
-from pydantic import Extra, BaseModel
+from typing import Optional
+from pydantic import Extra, BaseModel, Field, StrictStr, constr, validator
from lightly.openapi_generated.swagger_client.models.datasource_config_base import DatasourceConfigBase
class DatasourceConfigLOCAL(DatasourceConfigBase):
"""
DatasourceConfigLOCAL
"""
- __properties = ["id", "purpose", "type", "fullPath", "thumbSuffix"]
+ full_path: StrictStr = Field(..., alias="fullPath", description="Relative path from the mount point. Not allowed to start with \"/\", contain \"://\" or contain \".\" or \"..\" directory parts.")
+ web_server_location: Optional[constr(strict=True)] = Field(None, alias="webServerLocation", description="The location where your local webserver is running, used for viewing images in the webapp when using the local datasource workflow. Defaults to http://localhost:3456 ")
+ __properties = ["id", "purpose", "type", "thumbSuffix", "fullPath", "webServerLocation"]
+
+ @validator('web_server_location')
+ def web_server_location_validate_regular_expression(cls, value):
+ """Validates the regular expression"""
+ if value is None:
+ return value
+
+ if not re.match(r"^https?:\/\/.+$", value):
+ raise ValueError(r"must validate the regular expression /^https?:\/\/.+$/")
+ return value
class Config:
"""Pydantic configuration"""
@@ -75,8 +87,9 @@ def from_dict(cls, obj: dict) -> DatasourceConfigLOCAL:
"id": obj.get("id"),
"purpose": obj.get("purpose"),
"type": obj.get("type"),
+ "thumb_suffix": obj.get("thumbSuffix"),
"full_path": obj.get("fullPath"),
- "thumb_suffix": obj.get("thumbSuffix")
+ "web_server_location": obj.get("webServerLocation")
})
return _obj
diff --git a/lightly/openapi_generated/swagger_client/models/datasource_config_local_all_of.py b/lightly/openapi_generated/swagger_client/models/datasource_config_local_all_of.py
new file mode 100644
index 000000000..247074c37
--- /dev/null
+++ b/lightly/openapi_generated/swagger_client/models/datasource_config_local_all_of.py
@@ -0,0 +1,90 @@
+# coding: utf-8
+
+"""
+ Lightly API
+
+ Lightly.ai enables you to do self-supervised learning in an easy and intuitive way. The lightly.ai OpenAPI spec defines how one can interact with our REST API to unleash the full potential of lightly.ai # noqa: E501
+
+ The version of the OpenAPI document: 1.0.0
+ Contact: support@lightly.ai
+ Generated by OpenAPI Generator (https://openapi-generator.tech)
+
+ Do not edit the class manually.
+"""
+
+
+from __future__ import annotations
+import pprint
+import re # noqa: F401
+import json
+
+
+from typing import Optional
+from pydantic import Extra, BaseModel, Field, StrictStr, constr, validator
+
+class DatasourceConfigLOCALAllOf(BaseModel):
+ """
+ DatasourceConfigLOCALAllOf
+ """
+ full_path: StrictStr = Field(..., alias="fullPath", description="Relative path from the mount point. Not allowed to start with \"/\", contain \"://\" or contain \".\" or \"..\" directory parts.")
+ web_server_location: Optional[constr(strict=True)] = Field(None, alias="webServerLocation", description="The location where your local webserver is running, used for viewing images in the webapp when using the local datasource workflow. Defaults to http://localhost:3456 ")
+ __properties = ["fullPath", "webServerLocation"]
+
+ @validator('web_server_location')
+ def web_server_location_validate_regular_expression(cls, value):
+ """Validates the regular expression"""
+ if value is None:
+ return value
+
+ if not re.match(r"^https?:\/\/.+$", value):
+ raise ValueError(r"must validate the regular expression /^https?:\/\/.+$/")
+ return value
+
+ class Config:
+ """Pydantic configuration"""
+ allow_population_by_field_name = True
+ validate_assignment = True
+ use_enum_values = True
+ extra = Extra.forbid
+
+ def to_str(self, by_alias: bool = False) -> str:
+ """Returns the string representation of the model"""
+ return pprint.pformat(self.dict(by_alias=by_alias))
+
+ def to_json(self, by_alias: bool = False) -> str:
+ """Returns the JSON representation of the model"""
+ return json.dumps(self.to_dict(by_alias=by_alias))
+
+ @classmethod
+ def from_json(cls, json_str: str) -> DatasourceConfigLOCALAllOf:
+ """Create an instance of DatasourceConfigLOCALAllOf from a JSON string"""
+ return cls.from_dict(json.loads(json_str))
+
+ def to_dict(self, by_alias: bool = False):
+ """Returns the dictionary representation of the model"""
+ _dict = self.dict(by_alias=by_alias,
+ exclude={
+ },
+ exclude_none=True)
+ return _dict
+
+ @classmethod
+ def from_dict(cls, obj: dict) -> DatasourceConfigLOCALAllOf:
+ """Create an instance of DatasourceConfigLOCALAllOf from a dict"""
+ if obj is None:
+ return None
+
+ if not isinstance(obj, dict):
+ return DatasourceConfigLOCALAllOf.parse_obj(obj)
+
+ # raise errors for additional fields in the input
+ for _key in obj.keys():
+ if _key not in cls.__properties:
+ raise ValueError("Error due to additional fields (not defined in DatasourceConfigLOCALAllOf) in the input: " + str(obj))
+
+ _obj = DatasourceConfigLOCALAllOf.parse_obj({
+ "full_path": obj.get("fullPath"),
+ "web_server_location": obj.get("webServerLocation")
+ })
+ return _obj
+
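
The new all-of model carries only the LOCAL-specific fields, which makes the webServerLocation validator easy to exercise in isolation. A sketch; pydantic wraps the validator's ValueError in a ValidationError, which is itself a ValueError:

from lightly.openapi_generated.swagger_client.models.datasource_config_local_all_of import (
    DatasourceConfigLOCALAllOf,
)

# Valid: fullPath is a relative path, webServerLocation matches ^https?://.+$
cfg = DatasourceConfigLOCALAllOf(
    full_path="projects/dataset-a",
    web_server_location="http://localhost:3456",
)

# A non-http(s) scheme fails the regular-expression validator:
try:
    DatasourceConfigLOCALAllOf(
        full_path="projects/dataset-a",
        web_server_location="ftp://host",
    )
except ValueError as err:
    print(err)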
diff --git a/lightly/openapi_generated/swagger_client/models/datasource_config_obs.py b/lightly/openapi_generated/swagger_client/models/datasource_config_obs.py
index 3dbafebbf..f50bab55b 100644
--- a/lightly/openapi_generated/swagger_client/models/datasource_config_obs.py
+++ b/lightly/openapi_generated/swagger_client/models/datasource_config_obs.py
@@ -20,17 +20,18 @@
-from pydantic import Extra, BaseModel, Field, constr, validator
+from pydantic import Extra, BaseModel, Field, StrictStr, constr, validator
from lightly.openapi_generated.swagger_client.models.datasource_config_base import DatasourceConfigBase
class DatasourceConfigOBS(DatasourceConfigBase):
"""
DatasourceConfigOBS
"""
+ full_path: StrictStr = Field(..., alias="fullPath", description="path includes the bucket name and the path within the bucket where you have stored your information")
obs_endpoint: constr(strict=True, min_length=1) = Field(..., alias="obsEndpoint", description="The Object Storage Service (OBS) endpoint to use of your S3 compatible cloud storage provider")
obs_access_key_id: constr(strict=True, min_length=1) = Field(..., alias="obsAccessKeyId", description="The Access Key Id of the credential you are providing Lightly to use")
obs_secret_access_key: constr(strict=True, min_length=1) = Field(..., alias="obsSecretAccessKey", description="The Secret Access Key of the credential you are providing Lightly to use")
- __properties = ["id", "purpose", "type", "fullPath", "thumbSuffix", "obsEndpoint", "obsAccessKeyId", "obsSecretAccessKey"]
+ __properties = ["id", "purpose", "type", "thumbSuffix", "fullPath", "obsEndpoint", "obsAccessKeyId", "obsSecretAccessKey"]
@validator('obs_endpoint')
def obs_endpoint_validate_regular_expression(cls, value):
@@ -85,8 +86,8 @@ def from_dict(cls, obj: dict) -> DatasourceConfigOBS:
"id": obj.get("id"),
"purpose": obj.get("purpose"),
"type": obj.get("type"),
- "full_path": obj.get("fullPath"),
"thumb_suffix": obj.get("thumbSuffix"),
+ "full_path": obj.get("fullPath"),
"obs_endpoint": obj.get("obsEndpoint"),
"obs_access_key_id": obj.get("obsAccessKeyId"),
"obs_secret_access_key": obj.get("obsSecretAccessKey")
diff --git a/lightly/openapi_generated/swagger_client/models/datasource_config_s3.py b/lightly/openapi_generated/swagger_client/models/datasource_config_s3.py
index 2d25ebde2..638018727 100644
--- a/lightly/openapi_generated/swagger_client/models/datasource_config_s3.py
+++ b/lightly/openapi_generated/swagger_client/models/datasource_config_s3.py
@@ -20,7 +20,7 @@
from typing import Optional
-from pydantic import Extra, BaseModel, Field, constr, validator
+from pydantic import Extra, BaseModel, Field, StrictStr, constr, validator
from lightly.openapi_generated.swagger_client.models.datasource_config_base import DatasourceConfigBase
from lightly.openapi_generated.swagger_client.models.s3_region import S3Region
@@ -28,11 +28,12 @@ class DatasourceConfigS3(DatasourceConfigBase):
"""
DatasourceConfigS3
"""
+ full_path: StrictStr = Field(..., alias="fullPath", description="path includes the bucket name and the path within the bucket where you have stored your information")
s3_region: S3Region = Field(..., alias="s3Region")
s3_access_key_id: constr(strict=True, min_length=1) = Field(..., alias="s3AccessKeyId", description="The accessKeyId of the credential you are providing Lightly to use")
s3_secret_access_key: constr(strict=True, min_length=1) = Field(..., alias="s3SecretAccessKey", description="The secretAccessKey of the credential you are providing Lightly to use")
s3_server_side_encryption_kms_key: Optional[constr(strict=True, min_length=1)] = Field(None, alias="s3ServerSideEncryptionKMSKey", description="If set, Lightly Worker will automatically set the headers to use server side encryption https://docs.aws.amazon.com/AmazonS3/latest/userguide/UsingKMSEncryption.html with this value as the appropriate KMS key arn. This will encrypt the files created by Lightly (crops, frames, thumbnails) in the S3 bucket. ")
- __properties = ["id", "purpose", "type", "fullPath", "thumbSuffix", "s3Region", "s3AccessKeyId", "s3SecretAccessKey", "s3ServerSideEncryptionKMSKey"]
+ __properties = ["id", "purpose", "type", "thumbSuffix", "fullPath", "s3Region", "s3AccessKeyId", "s3SecretAccessKey", "s3ServerSideEncryptionKMSKey"]
@validator('s3_server_side_encryption_kms_key')
def s3_server_side_encryption_kms_key_validate_regular_expression(cls, value):
@@ -90,8 +91,8 @@ def from_dict(cls, obj: dict) -> DatasourceConfigS3:
"id": obj.get("id"),
"purpose": obj.get("purpose"),
"type": obj.get("type"),
- "full_path": obj.get("fullPath"),
"thumb_suffix": obj.get("thumbSuffix"),
+ "full_path": obj.get("fullPath"),
"s3_region": obj.get("s3Region"),
"s3_access_key_id": obj.get("s3AccessKeyId"),
"s3_secret_access_key": obj.get("s3SecretAccessKey"),
diff --git a/lightly/openapi_generated/swagger_client/models/datasource_config_s3_delegated_access.py b/lightly/openapi_generated/swagger_client/models/datasource_config_s3_delegated_access.py
index a5ef73962..3ea071e56 100644
--- a/lightly/openapi_generated/swagger_client/models/datasource_config_s3_delegated_access.py
+++ b/lightly/openapi_generated/swagger_client/models/datasource_config_s3_delegated_access.py
@@ -20,7 +20,7 @@
from typing import Optional
-from pydantic import Extra, BaseModel, Field, constr, validator
+from pydantic import Extra, BaseModel, Field, StrictStr, constr, validator
from lightly.openapi_generated.swagger_client.models.datasource_config_base import DatasourceConfigBase
from lightly.openapi_generated.swagger_client.models.s3_region import S3Region
@@ -28,11 +28,12 @@ class DatasourceConfigS3DelegatedAccess(DatasourceConfigBase):
"""
DatasourceConfigS3DelegatedAccess
"""
+ full_path: StrictStr = Field(..., alias="fullPath", description="path includes the bucket name and the path within the bucket where you have stored your information")
s3_region: S3Region = Field(..., alias="s3Region")
s3_external_id: constr(strict=True, min_length=10) = Field(..., alias="s3ExternalId", description="The external ID specified when creating the role.")
s3_arn: constr(strict=True, min_length=12) = Field(..., alias="s3ARN", description="The ARN of the role you created")
s3_server_side_encryption_kms_key: Optional[constr(strict=True, min_length=1)] = Field(None, alias="s3ServerSideEncryptionKMSKey", description="If set, Lightly Worker will automatically set the headers to use server side encryption https://docs.aws.amazon.com/AmazonS3/latest/userguide/UsingKMSEncryption.html with this value as the appropriate KMS key arn. This will encrypt the files created by Lightly (crops, frames, thumbnails) in the S3 bucket. ")
- __properties = ["id", "purpose", "type", "fullPath", "thumbSuffix", "s3Region", "s3ExternalId", "s3ARN", "s3ServerSideEncryptionKMSKey"]
+ __properties = ["id", "purpose", "type", "thumbSuffix", "fullPath", "s3Region", "s3ExternalId", "s3ARN", "s3ServerSideEncryptionKMSKey"]
@validator('s3_external_id')
def s3_external_id_validate_regular_expression(cls, value):
@@ -104,8 +105,8 @@ def from_dict(cls, obj: dict) -> DatasourceConfigS3DelegatedAccess:
"id": obj.get("id"),
"purpose": obj.get("purpose"),
"type": obj.get("type"),
- "full_path": obj.get("fullPath"),
"thumb_suffix": obj.get("thumbSuffix"),
+ "full_path": obj.get("fullPath"),
"s3_region": obj.get("s3Region"),
"s3_external_id": obj.get("s3ExternalId"),
"s3_arn": obj.get("s3ARN"),
diff --git a/lightly/openapi_generated/swagger_client/models/docker_run_artifact_type.py b/lightly/openapi_generated/swagger_client/models/docker_run_artifact_type.py
index e4fa545e2..02d827c70 100644
--- a/lightly/openapi_generated/swagger_client/models/docker_run_artifact_type.py
+++ b/lightly/openapi_generated/swagger_client/models/docker_run_artifact_type.py
@@ -36,6 +36,7 @@ class DockerRunArtifactType(str, Enum):
CHECKPOINT = 'CHECKPOINT'
REPORT_PDF = 'REPORT_PDF'
REPORT_JSON = 'REPORT_JSON'
+ REPORT_V2_JSON = 'REPORT_V2_JSON'
CORRUPTNESS_CHECK_INFORMATION = 'CORRUPTNESS_CHECK_INFORMATION'
SEQUENCE_INFORMATION = 'SEQUENCE_INFORMATION'
RELEVANT_FILENAMES = 'RELEVANT_FILENAMES'
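
Like the other generated enums, DockerRunArtifactType is a str subclass, so the new REPORT_V2_JSON value can be matched directly when filtering run artifacts (the artifact attribute name below is assumed from the generated models):

from lightly.openapi_generated.swagger_client.models.docker_run_artifact_type import (
    DockerRunArtifactType,
)

artifact_type = DockerRunArtifactType.REPORT_V2_JSON
assert artifact_type == "REPORT_V2_JSON"
# e.g. [a for a in run.artifacts if a.type == DockerRunArtifactType.REPORT_V2_JSON]
# where 'run' is a DockerRunData instance and 'type' is the assumed attribute name.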
diff --git a/lightly/openapi_generated/swagger_client/models/docker_run_data.py b/lightly/openapi_generated/swagger_client/models/docker_run_data.py
index 4b8004f93..5502a6a2c 100644
--- a/lightly/openapi_generated/swagger_client/models/docker_run_data.py
+++ b/lightly/openapi_generated/swagger_client/models/docker_run_data.py
@@ -20,7 +20,7 @@
from typing import List, Optional
-from pydantic import Extra, BaseModel, Field, StrictStr, conint, conlist, constr, validator
+from pydantic import Extra, BaseModel, Field, StrictBool, StrictStr, conint, conlist, constr, validator
from lightly.openapi_generated.swagger_client.models.docker_run_artifact_data import DockerRunArtifactData
from lightly.openapi_generated.swagger_client.models.docker_run_state import DockerRunState
@@ -32,6 +32,7 @@ class DockerRunData(BaseModel):
user_id: StrictStr = Field(..., alias="userId")
docker_version: StrictStr = Field(..., alias="dockerVersion")
state: DockerRunState = Field(...)
+ archived: Optional[StrictBool] = Field(None, description="if the run is archived")
dataset_id: Optional[constr(strict=True)] = Field(None, alias="datasetId", description="MongoDB ObjectId")
config_id: Optional[constr(strict=True)] = Field(None, alias="configId", description="MongoDB ObjectId")
scheduled_id: Optional[constr(strict=True)] = Field(None, alias="scheduledId", description="MongoDB ObjectId")
@@ -39,7 +40,7 @@ class DockerRunData(BaseModel):
last_modified_at: conint(strict=True, ge=0) = Field(..., alias="lastModifiedAt", description="unix timestamp in milliseconds")
message: Optional[StrictStr] = Field(None, description="last message sent to the docker run")
artifacts: Optional[conlist(DockerRunArtifactData)] = Field(None, description="list of artifacts that were created for a run")
- __properties = ["id", "userId", "dockerVersion", "state", "datasetId", "configId", "scheduledId", "createdAt", "lastModifiedAt", "message", "artifacts"]
+ __properties = ["id", "userId", "dockerVersion", "state", "archived", "datasetId", "configId", "scheduledId", "createdAt", "lastModifiedAt", "message", "artifacts"]
@validator('id')
def id_validate_regular_expression(cls, value):
@@ -132,6 +133,7 @@ def from_dict(cls, obj: dict) -> DockerRunData:
"user_id": obj.get("userId"),
"docker_version": obj.get("dockerVersion"),
"state": obj.get("state"),
+ "archived": obj.get("archived"),
"dataset_id": obj.get("datasetId"),
"config_id": obj.get("configId"),
"scheduled_id": obj.get("scheduledId"),
diff --git a/lightly/openapi_generated/swagger_client/models/docker_worker_config_v3.py b/lightly/openapi_generated/swagger_client/models/docker_worker_config_v3.py
index a59c3f34d..4e5d1d8db 100644
--- a/lightly/openapi_generated/swagger_client/models/docker_worker_config_v3.py
+++ b/lightly/openapi_generated/swagger_client/models/docker_worker_config_v3.py
@@ -24,7 +24,7 @@
from lightly.openapi_generated.swagger_client.models.docker_worker_config_v3_docker import DockerWorkerConfigV3Docker
from lightly.openapi_generated.swagger_client.models.docker_worker_config_v3_lightly import DockerWorkerConfigV3Lightly
from lightly.openapi_generated.swagger_client.models.docker_worker_type import DockerWorkerType
-from lightly.openapi_generated.swagger_client.models.selection_config import SelectionConfig
+from lightly.openapi_generated.swagger_client.models.selection_config_v3 import SelectionConfigV3
class DockerWorkerConfigV3(BaseModel):
"""
@@ -33,7 +33,7 @@ class DockerWorkerConfigV3(BaseModel):
worker_type: DockerWorkerType = Field(..., alias="workerType")
docker: Optional[DockerWorkerConfigV3Docker] = None
lightly: Optional[DockerWorkerConfigV3Lightly] = None
- selection: Optional[SelectionConfig] = None
+ selection: Optional[SelectionConfigV3] = None
__properties = ["workerType", "docker", "lightly", "selection"]
class Config:
@@ -91,7 +91,7 @@ def from_dict(cls, obj: dict) -> DockerWorkerConfigV3:
"worker_type": obj.get("workerType"),
"docker": DockerWorkerConfigV3Docker.from_dict(obj.get("docker")) if obj.get("docker") is not None else None,
"lightly": DockerWorkerConfigV3Lightly.from_dict(obj.get("lightly")) if obj.get("lightly") is not None else None,
- "selection": SelectionConfig.from_dict(obj.get("selection")) if obj.get("selection") is not None else None
+ "selection": SelectionConfigV3.from_dict(obj.get("selection")) if obj.get("selection") is not None else None
})
return _obj
diff --git a/lightly/openapi_generated/swagger_client/models/docker_worker_config_v3_docker.py b/lightly/openapi_generated/swagger_client/models/docker_worker_config_v3_docker.py
index 962263917..7a66eab2e 100644
--- a/lightly/openapi_generated/swagger_client/models/docker_worker_config_v3_docker.py
+++ b/lightly/openapi_generated/swagger_client/models/docker_worker_config_v3_docker.py
@@ -20,7 +20,7 @@
from typing import Optional
-from pydantic import Extra, BaseModel, Field, StrictBool, StrictStr, conint
+from pydantic import Extra, BaseModel, Field, StrictBool, StrictStr, conint, constr, validator
from lightly.openapi_generated.swagger_client.models.docker_worker_config_v3_docker_corruptness_check import DockerWorkerConfigV3DockerCorruptnessCheck
from lightly.openapi_generated.swagger_client.models.docker_worker_config_v3_docker_datasource import DockerWorkerConfigV3DockerDatasource
from lightly.openapi_generated.swagger_client.models.docker_worker_config_v3_docker_training import DockerWorkerConfigV3DockerTraining
@@ -30,6 +30,7 @@ class DockerWorkerConfigV3Docker(BaseModel):
docker run configurations; keys should match the structure of https://github.com/lightly-ai/lightly-core/blob/develop/onprem-docker/lightly_worker/src/lightly_worker/resources/docker/docker.yaml
"""
checkpoint: Optional[StrictStr] = None
+ checkpoint_run_id: Optional[constr(strict=True)] = Field(None, alias="checkpointRunId", description="MongoDB ObjectId")
corruptness_check: Optional[DockerWorkerConfigV3DockerCorruptnessCheck] = Field(None, alias="corruptnessCheck")
datasource: Optional[DockerWorkerConfigV3DockerDatasource] = None
embeddings: Optional[StrictStr] = None
@@ -44,8 +45,18 @@ class DockerWorkerConfigV3Docker(BaseModel):
relevant_filenames_file: Optional[StrictStr] = Field(None, alias="relevantFilenamesFile")
selected_sequence_length: Optional[conint(strict=True, ge=1)] = Field(None, alias="selectedSequenceLength")
upload_report: Optional[StrictBool] = Field(None, alias="uploadReport")
- use_datapool: Optional[StrictBool] = Field(None, alias="useDatapool")
- __properties = ["checkpoint", "corruptnessCheck", "datasource", "embeddings", "enableTraining", "training", "normalizeEmbeddings", "numProcesses", "numThreads", "outputImageFormat", "pretagging", "pretaggingUpload", "relevantFilenamesFile", "selectedSequenceLength", "uploadReport", "useDatapool"]
+ shutdown_when_job_finished: Optional[StrictBool] = Field(None, alias="shutdownWhenJobFinished")
+ __properties = ["checkpoint", "checkpointRunId", "corruptnessCheck", "datasource", "embeddings", "enableTraining", "training", "normalizeEmbeddings", "numProcesses", "numThreads", "outputImageFormat", "pretagging", "pretaggingUpload", "relevantFilenamesFile", "selectedSequenceLength", "uploadReport", "shutdownWhenJobFinished"]
+
+ @validator('checkpoint_run_id')
+ def checkpoint_run_id_validate_regular_expression(cls, value):
+ """Validates the regular expression"""
+ if value is None:
+ return value
+
+ if not re.match(r"^[a-f0-9]{24}$", value):
+ raise ValueError(r"must validate the regular expression /^[a-f0-9]{24}$/")
+ return value
class Config:
"""Pydantic configuration"""
@@ -100,6 +111,7 @@ def from_dict(cls, obj: dict) -> DockerWorkerConfigV3Docker:
_obj = DockerWorkerConfigV3Docker.parse_obj({
"checkpoint": obj.get("checkpoint"),
+ "checkpoint_run_id": obj.get("checkpointRunId"),
"corruptness_check": DockerWorkerConfigV3DockerCorruptnessCheck.from_dict(obj.get("corruptnessCheck")) if obj.get("corruptnessCheck") is not None else None,
"datasource": DockerWorkerConfigV3DockerDatasource.from_dict(obj.get("datasource")) if obj.get("datasource") is not None else None,
"embeddings": obj.get("embeddings"),
@@ -114,7 +126,7 @@ def from_dict(cls, obj: dict) -> DockerWorkerConfigV3Docker:
"relevant_filenames_file": obj.get("relevantFilenamesFile"),
"selected_sequence_length": obj.get("selectedSequenceLength"),
"upload_report": obj.get("uploadReport"),
- "use_datapool": obj.get("useDatapool")
+ "shutdown_when_job_finished": obj.get("shutdownWhenJobFinished")
})
return _obj
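For context, a minimal sketch of how the new `checkpointRunId` validator behaves, assuming the model allows population by field name like its sibling generated models; the ObjectId shown is a made-up placeholder:

```python
from lightly.openapi_generated.swagger_client.models.docker_worker_config_v3_docker import (
    DockerWorkerConfigV3Docker,
)

# Valid: a 24-character lowercase hex string (MongoDB ObjectId format).
config = DockerWorkerConfigV3Docker(
    checkpoint_run_id="64a7f2c3e4b0a1d2c3e4b0a1",  # placeholder id
    shutdown_when_job_finished=True,
)

# Invalid: the regex validator rejects anything that is not 24 hex chars.
try:
    DockerWorkerConfigV3Docker(checkpoint_run_id="not-an-object-id")
except ValueError as err:
    print(err)  # must validate the regular expression /^[a-f0-9]{24}$/
```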
diff --git a/lightly/openapi_generated/swagger_client/models/prediction_singleton_keypoint_detection.py b/lightly/openapi_generated/swagger_client/models/prediction_singleton_keypoint_detection.py
index c0def6930..be66eae4a 100644
--- a/lightly/openapi_generated/swagger_client/models/prediction_singleton_keypoint_detection.py
+++ b/lightly/openapi_generated/swagger_client/models/prediction_singleton_keypoint_detection.py
@@ -28,8 +28,9 @@ class PredictionSingletonKeypointDetection(PredictionSingletonBase):
PredictionSingletonKeypointDetection
"""
keypoints: conlist(Union[confloat(ge=0, strict=True), conint(ge=0, strict=True)], min_items=3) = Field(..., description="[x1, y1, s1, ..., xk, yk, sk] as outlined by https://docs.lightly.ai/docs/prediction-format#keypoint-detection ")
+ bbox: Optional[conlist(Union[confloat(ge=0, strict=True), conint(ge=0, strict=True)], max_items=4, min_items=4)] = Field(None, description="The bbox of where a prediction task yielded a finding. [x, y, width, height]")
probabilities: Optional[conlist(Union[confloat(le=1, ge=0, strict=True), conint(le=1, ge=0, strict=True)])] = Field(None, description="The probabilities of the prediction belonging to each category other than the selected one. The sum of all probabilities should equal 1.")
- __properties = ["type", "taskName", "cropDatasetId", "cropSampleId", "categoryId", "score", "keypoints", "probabilities"]
+ __properties = ["type", "taskName", "cropDatasetId", "cropSampleId", "categoryId", "score", "keypoints", "bbox", "probabilities"]
class Config:
"""Pydantic configuration"""
@@ -81,6 +82,7 @@ def from_dict(cls, obj: dict) -> PredictionSingletonKeypointDetection:
"category_id": obj.get("categoryId"),
"score": obj.get("score"),
"keypoints": obj.get("keypoints"),
+ "bbox": obj.get("bbox"),
"probabilities": obj.get("probabilities")
})
return _obj
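As a sketch, a keypoint prediction payload using the new optional `bbox` field might look like the dict below; the task name and values are hypothetical, and the prediction-format docs linked above remain the authoritative schema:

```python
# Hypothetical keypoint-detection prediction including the new optional bbox.
# keypoints are [x1, y1, s1, ..., xk, yk, sk]; bbox is [x, y, width, height].
prediction = {
    "type": "keypoint-detection",
    "taskName": "my_keypoint_task",  # hypothetical task name
    "categoryId": 0,
    "score": 0.92,
    "keypoints": [100, 120, 0.95, 140, 160, 0.87],
    "bbox": [90, 110, 60, 70],
}
```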
diff --git a/lightly/openapi_generated/swagger_client/models/prediction_singleton_keypoint_detection_all_of.py b/lightly/openapi_generated/swagger_client/models/prediction_singleton_keypoint_detection_all_of.py
index abeadb9dc..d5cfa8795 100644
--- a/lightly/openapi_generated/swagger_client/models/prediction_singleton_keypoint_detection_all_of.py
+++ b/lightly/openapi_generated/swagger_client/models/prediction_singleton_keypoint_detection_all_of.py
@@ -27,8 +27,9 @@ class PredictionSingletonKeypointDetectionAllOf(BaseModel):
PredictionSingletonKeypointDetectionAllOf
"""
keypoints: conlist(Union[confloat(ge=0, strict=True), conint(ge=0, strict=True)], min_items=3) = Field(..., description="[x1, y1, s1, ..., xk, yk, sk] as outlined by https://docs.lightly.ai/docs/prediction-format#keypoint-detection ")
+ bbox: Optional[conlist(Union[confloat(ge=0, strict=True), conint(ge=0, strict=True)], max_items=4, min_items=4)] = Field(None, description="The bbox of where a prediction task yielded a finding. [x, y, width, height]")
probabilities: Optional[conlist(Union[confloat(le=1, ge=0, strict=True), conint(le=1, ge=0, strict=True)])] = Field(None, description="The probabilities of the prediction belonging to each category other than the selected one. The sum of all probabilities should equal 1.")
- __properties = ["keypoints", "probabilities"]
+ __properties = ["keypoints", "bbox", "probabilities"]
class Config:
"""Pydantic configuration"""
@@ -74,6 +75,7 @@ def from_dict(cls, obj: dict) -> PredictionSingletonKeypointDetectionAllOf:
_obj = PredictionSingletonKeypointDetectionAllOf.parse_obj({
"keypoints": obj.get("keypoints"),
+ "bbox": obj.get("bbox"),
"probabilities": obj.get("probabilities")
})
return _obj
diff --git a/lightly/openapi_generated/swagger_client/models/sample_create_request.py b/lightly/openapi_generated/swagger_client/models/sample_create_request.py
index ab3c6c34d..3c292a76a 100644
--- a/lightly/openapi_generated/swagger_client/models/sample_create_request.py
+++ b/lightly/openapi_generated/swagger_client/models/sample_create_request.py
@@ -78,6 +78,16 @@ def to_dict(self, by_alias: bool = False):
if self.custom_meta_data is None and "custom_meta_data" in self.__fields_set__:
_dict['customMetaData' if by_alias else 'custom_meta_data'] = None
+ # set to None if video_frame_data (nullable) is None
+ # and __fields_set__ contains the field
+ if self.video_frame_data is None and "video_frame_data" in self.__fields_set__:
+ _dict['videoFrameData' if by_alias else 'video_frame_data'] = None
+
+ # set to None if crop_data (nullable) is None
+ # and __fields_set__ contains the field
+ if self.crop_data is None and "crop_data" in self.__fields_set__:
+ _dict['cropData' if by_alias else 'crop_data'] = None
+
return _dict
@classmethod
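A minimal sketch of the nullable-field behavior added here, assuming `file_name` is the model's required field as in the upstream spec: fields explicitly set to None are serialized as explicit nulls rather than dropped.

```python
from lightly.openapi_generated.swagger_client.models.sample_create_request import (
    SampleCreateRequest,
)

# Explicitly passing video_frame_data=None puts the field in __fields_set__,
# so to_dict() emits an explicit null instead of omitting the key.
request = SampleCreateRequest(file_name="image.png", video_frame_data=None)
print(request.to_dict(by_alias=True))
# {'fileName': 'image.png', 'videoFrameData': None}

# Leaving the field unset keeps it out of the serialized dict entirely.
print(SampleCreateRequest(file_name="image.png").to_dict(by_alias=True))
# {'fileName': 'image.png'}
```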
diff --git a/lightly/openapi_generated/swagger_client/models/sample_data.py b/lightly/openapi_generated/swagger_client/models/sample_data.py
index 42ff509a7..84d3c39fb 100644
--- a/lightly/openapi_generated/swagger_client/models/sample_data.py
+++ b/lightly/openapi_generated/swagger_client/models/sample_data.py
@@ -102,11 +102,26 @@ def to_dict(self, by_alias: bool = False):
if self.thumb_name is None and "thumb_name" in self.__fields_set__:
_dict['thumbName' if by_alias else 'thumb_name'] = None
+ # set to None if exif (nullable) is None
+ # and __fields_set__ contains the field
+ if self.exif is None and "exif" in self.__fields_set__:
+ _dict['exif' if by_alias else 'exif'] = None
+
# set to None if custom_meta_data (nullable) is None
# and __fields_set__ contains the field
if self.custom_meta_data is None and "custom_meta_data" in self.__fields_set__:
_dict['customMetaData' if by_alias else 'custom_meta_data'] = None
+ # set to None if video_frame_data (nullable) is None
+ # and __fields_set__ contains the field
+ if self.video_frame_data is None and "video_frame_data" in self.__fields_set__:
+ _dict['videoFrameData' if by_alias else 'video_frame_data'] = None
+
+ # set to None if crop_data (nullable) is None
+ # and __fields_set__ contains the field
+ if self.crop_data is None and "crop_data" in self.__fields_set__:
+ _dict['cropData' if by_alias else 'crop_data'] = None
+
return _dict
@classmethod
diff --git a/lightly/openapi_generated/swagger_client/models/sample_data_modes.py b/lightly/openapi_generated/swagger_client/models/sample_data_modes.py
index 426589a33..736754e2e 100644
--- a/lightly/openapi_generated/swagger_client/models/sample_data_modes.py
+++ b/lightly/openapi_generated/swagger_client/models/sample_data_modes.py
@@ -102,11 +102,26 @@ def to_dict(self, by_alias: bool = False):
if self.thumb_name is None and "thumb_name" in self.__fields_set__:
_dict['thumbName' if by_alias else 'thumb_name'] = None
+ # set to None if exif (nullable) is None
+ # and __fields_set__ contains the field
+ if self.exif is None and "exif" in self.__fields_set__:
+ _dict['exif' if by_alias else 'exif'] = None
+
# set to None if custom_meta_data (nullable) is None
# and __fields_set__ contains the field
if self.custom_meta_data is None and "custom_meta_data" in self.__fields_set__:
_dict['customMetaData' if by_alias else 'custom_meta_data'] = None
+ # set to None if video_frame_data (nullable) is None
+ # and __fields_set__ contains the field
+ if self.video_frame_data is None and "video_frame_data" in self.__fields_set__:
+ _dict['videoFrameData' if by_alias else 'video_frame_data'] = None
+
+ # set to None if crop_data (nullable) is None
+ # and __fields_set__ contains the field
+ if self.crop_data is None and "crop_data" in self.__fields_set__:
+ _dict['cropData' if by_alias else 'crop_data'] = None
+
return _dict
@classmethod
diff --git a/lightly/openapi_generated/swagger_client/models/sample_meta_data.py b/lightly/openapi_generated/swagger_client/models/sample_meta_data.py
index 705c8b391..2153a9487 100644
--- a/lightly/openapi_generated/swagger_client/models/sample_meta_data.py
+++ b/lightly/openapi_generated/swagger_client/models/sample_meta_data.py
@@ -66,6 +66,16 @@ def to_dict(self, by_alias: bool = False):
exclude={
},
exclude_none=True)
+ # set to None if custom (nullable) is None
+ # and __fields_set__ contains the field
+ if self.custom is None and "custom" in self.__fields_set__:
+ _dict['custom' if by_alias else 'custom'] = None
+
+ # set to None if dynamic (nullable) is None
+ # and __fields_set__ contains the field
+ if self.dynamic is None and "dynamic" in self.__fields_set__:
+ _dict['dynamic' if by_alias else 'dynamic'] = None
+
return _dict
@classmethod
diff --git a/lightly/openapi_generated/swagger_client/models/selection_config_all_of.py b/lightly/openapi_generated/swagger_client/models/selection_config_all_of.py
new file mode 100644
index 000000000..3d04824fa
--- /dev/null
+++ b/lightly/openapi_generated/swagger_client/models/selection_config_all_of.py
@@ -0,0 +1,86 @@
+# coding: utf-8
+
+"""
+ Lightly API
+
+ Lightly.ai enables you to do self-supervised learning in an easy and intuitive way. The lightly.ai OpenAPI spec defines how one can interact with our REST API to unleash the full potential of lightly.ai # noqa: E501
+
+ The version of the OpenAPI document: 1.0.0
+ Contact: support@lightly.ai
+ Generated by OpenAPI Generator (https://openapi-generator.tech)
+
+ Do not edit the class manually.
+"""
+
+
+from __future__ import annotations
+import pprint
+import re # noqa: F401
+import json
+
+
+from typing import List
+from pydantic import Extra, BaseModel, Field, conlist
+from lightly.openapi_generated.swagger_client.models.selection_config_entry import SelectionConfigEntry
+
+class SelectionConfigAllOf(BaseModel):
+ """
+ SelectionConfigAllOf
+ """
+ strategies: conlist(SelectionConfigEntry, min_items=1) = Field(...)
+ __properties = ["strategies"]
+
+ class Config:
+ """Pydantic configuration"""
+ allow_population_by_field_name = True
+ validate_assignment = True
+ use_enum_values = True
+ extra = Extra.forbid
+
+ def to_str(self, by_alias: bool = False) -> str:
+ """Returns the string representation of the model"""
+ return pprint.pformat(self.dict(by_alias=by_alias))
+
+ def to_json(self, by_alias: bool = False) -> str:
+ """Returns the JSON representation of the model"""
+ return json.dumps(self.to_dict(by_alias=by_alias))
+
+ @classmethod
+ def from_json(cls, json_str: str) -> SelectionConfigAllOf:
+ """Create an instance of SelectionConfigAllOf from a JSON string"""
+ return cls.from_dict(json.loads(json_str))
+
+ def to_dict(self, by_alias: bool = False):
+ """Returns the dictionary representation of the model"""
+ _dict = self.dict(by_alias=by_alias,
+ exclude={
+ },
+ exclude_none=True)
+ # override the default output from pydantic by calling `to_dict()` of each item in strategies (list)
+ _items = []
+ if self.strategies:
+ for _item in self.strategies:
+ if _item:
+ _items.append(_item.to_dict(by_alias=by_alias))
+ _dict['strategies' if by_alias else 'strategies'] = _items
+ return _dict
+
+ @classmethod
+ def from_dict(cls, obj: dict) -> SelectionConfigAllOf:
+ """Create an instance of SelectionConfigAllOf from a dict"""
+ if obj is None:
+ return None
+
+ if not isinstance(obj, dict):
+ return SelectionConfigAllOf.parse_obj(obj)
+
+ # raise errors for additional fields in the input
+ for _key in obj.keys():
+ if _key not in cls.__properties:
+ raise ValueError("Error due to additional fields (not defined in SelectionConfigAllOf) in the input: " + str(obj))
+
+ _obj = SelectionConfigAllOf.parse_obj({
+ "strategies": [SelectionConfigEntry.from_dict(_item) for _item in obj.get("strategies")] if obj.get("strategies") is not None else None
+ })
+ return _obj
+
diff --git a/lightly/openapi_generated/swagger_client/models/annotation_offer_data.py b/lightly/openapi_generated/swagger_client/models/selection_config_base.py
similarity index 66%
rename from lightly/openapi_generated/swagger_client/models/annotation_offer_data.py
rename to lightly/openapi_generated/swagger_client/models/selection_config_base.py
index e12f63aa5..51b358cf1 100644
--- a/lightly/openapi_generated/swagger_client/models/annotation_offer_data.py
+++ b/lightly/openapi_generated/swagger_client/models/selection_config_base.py
@@ -20,15 +20,15 @@
from typing import Optional, Union
-from pydantic import Extra, BaseModel, Field, StrictFloat, StrictInt, conint
+from pydantic import Extra, BaseModel, Field, confloat, conint
-class AnnotationOfferData(BaseModel):
+class SelectionConfigBase(BaseModel):
"""
- AnnotationOfferData
+ SelectionConfigBase
"""
- cost: Optional[Union[StrictFloat, StrictInt]] = None
- completed_by: Optional[conint(strict=True, ge=0)] = Field(None, alias="completedBy", description="unix timestamp in milliseconds")
- __properties = ["cost", "completedBy"]
+ n_samples: Optional[conint(strict=True, ge=-1)] = Field(None, alias="nSamples")
+ proportion_samples: Optional[Union[confloat(le=1.0, ge=0.0, strict=True), conint(le=1, ge=0, strict=True)]] = Field(None, alias="proportionSamples")
+ __properties = ["nSamples", "proportionSamples"]
class Config:
"""Pydantic configuration"""
@@ -46,8 +46,8 @@ def to_json(self, by_alias: bool = False) -> str:
return json.dumps(self.to_dict(by_alias=by_alias))
@classmethod
- def from_json(cls, json_str: str) -> AnnotationOfferData:
- """Create an instance of AnnotationOfferData from a JSON string"""
+ def from_json(cls, json_str: str) -> SelectionConfigBase:
+ """Create an instance of SelectionConfigBase from a JSON string"""
return cls.from_dict(json.loads(json_str))
def to_dict(self, by_alias: bool = False):
@@ -59,22 +59,22 @@ def to_dict(self, by_alias: bool = False):
return _dict
@classmethod
- def from_dict(cls, obj: dict) -> AnnotationOfferData:
- """Create an instance of AnnotationOfferData from a dict"""
+ def from_dict(cls, obj: dict) -> SelectionConfigBase:
+ """Create an instance of SelectionConfigBase from a dict"""
if obj is None:
return None
if not isinstance(obj, dict):
- return AnnotationOfferData.parse_obj(obj)
+ return SelectionConfigBase.parse_obj(obj)
# raise errors for additional fields in the input
for _key in obj.keys():
if _key not in cls.__properties:
- raise ValueError("Error due to additional fields (not defined in AnnotationOfferData) in the input: " + str(obj))
+ raise ValueError("Error due to additional fields (not defined in SelectionConfigBase) in the input: " + str(obj))
- _obj = AnnotationOfferData.parse_obj({
- "cost": obj.get("cost"),
- "completed_by": obj.get("completedBy")
+ _obj = SelectionConfigBase.parse_obj({
+ "n_samples": obj.get("nSamples"),
+ "proportion_samples": obj.get("proportionSamples")
})
return _obj
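A quick sketch of the renamed base model: both fields are optional, so a selection size can be given either as an absolute count or as a proportion of the dataset:

```python
from lightly.openapi_generated.swagger_client.models.selection_config_base import (
    SelectionConfigBase,
)

# Select an absolute number of samples...
by_count = SelectionConfigBase.from_dict({"nSamples": 1000})

# ...or a proportion of the dataset in [0, 1].
by_fraction = SelectionConfigBase.from_dict({"proportionSamples": 0.1})

print(by_count.n_samples, by_fraction.proportion_samples)  # 1000 0.1
```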
diff --git a/lightly/openapi_generated/swagger_client/models/selection_config_v3.py b/lightly/openapi_generated/swagger_client/models/selection_config_v3.py
new file mode 100644
index 000000000..80aed78fb
--- /dev/null
+++ b/lightly/openapi_generated/swagger_client/models/selection_config_v3.py
@@ -0,0 +1,90 @@
+# coding: utf-8
+
+"""
+ Lightly API
+
+ Lightly.ai enables you to do self-supervised learning in an easy and intuitive way. The lightly.ai OpenAPI spec defines how one can interact with our REST API to unleash the full potential of lightly.ai # noqa: E501
+
+ The version of the OpenAPI document: 1.0.0
+ Contact: support@lightly.ai
+ Generated by OpenAPI Generator (https://openapi-generator.tech)
+
+ Do not edit the class manually.
+"""
+
+
+from __future__ import annotations
+import pprint
+import re # noqa: F401
+import json
+
+
+from typing import List, Optional, Union
+from pydantic import Extra, BaseModel, Field, confloat, conint, conlist
+from lightly.openapi_generated.swagger_client.models.selection_config_v3_entry import SelectionConfigV3Entry
+
+class SelectionConfigV3(BaseModel):
+ """
+ SelectionConfigV3
+ """
+ n_samples: Optional[conint(strict=True, ge=-1)] = Field(None, alias="nSamples")
+ proportion_samples: Optional[Union[confloat(le=1.0, ge=0.0, strict=True), conint(le=1, ge=0, strict=True)]] = Field(None, alias="proportionSamples")
+ strategies: conlist(SelectionConfigV3Entry, min_items=1) = Field(...)
+ __properties = ["nSamples", "proportionSamples", "strategies"]
+
+ class Config:
+ """Pydantic configuration"""
+ allow_population_by_field_name = True
+ validate_assignment = True
+ use_enum_values = True
+ extra = Extra.forbid
+
+ def to_str(self, by_alias: bool = False) -> str:
+ """Returns the string representation of the model"""
+ return pprint.pformat(self.dict(by_alias=by_alias))
+
+ def to_json(self, by_alias: bool = False) -> str:
+ """Returns the JSON representation of the model"""
+ return json.dumps(self.to_dict(by_alias=by_alias))
+
+ @classmethod
+ def from_json(cls, json_str: str) -> SelectionConfigV3:
+ """Create an instance of SelectionConfigV3 from a JSON string"""
+ return cls.from_dict(json.loads(json_str))
+
+ def to_dict(self, by_alias: bool = False):
+ """Returns the dictionary representation of the model"""
+ _dict = self.dict(by_alias=by_alias,
+ exclude={
+ },
+ exclude_none=True)
+ # override the default output from pydantic by calling `to_dict()` of each item in strategies (list)
+ _items = []
+ if self.strategies:
+ for _item in self.strategies:
+ if _item:
+ _items.append(_item.to_dict(by_alias=by_alias))
+ _dict['strategies' if by_alias else 'strategies'] = _items
+ return _dict
+
+ @classmethod
+ def from_dict(cls, obj: dict) -> SelectionConfigV3:
+ """Create an instance of SelectionConfigV3 from a dict"""
+ if obj is None:
+ return None
+
+ if not isinstance(obj, dict):
+ return SelectionConfigV3.parse_obj(obj)
+
+ # raise errors for additional fields in the input
+ for _key in obj.keys():
+ if _key not in cls.__properties:
+ raise ValueError("Error due to additional fields (not defined in SelectionConfigV3) in the input: " + str(obj))
+
+ _obj = SelectionConfigV3.parse_obj({
+ "n_samples": obj.get("nSamples"),
+ "proportion_samples": obj.get("proportionSamples"),
+ "strategies": [SelectionConfigV3Entry.from_dict(_item) for _item in obj.get("strategies")] if obj.get("strategies") is not None else None
+ })
+ return _obj
+
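For illustration, a minimal SelectionConfigV3 built from a dict. The strategy type DIVERSITY comes from the SelectionStrategyTypeV3 enum further below; "EMBEDDINGS" is assumed to be a valid SelectionInputType value:

```python
from lightly.openapi_generated.swagger_client.models.selection_config_v3 import (
    SelectionConfigV3,
)

# A single DIVERSITY strategy over embeddings, capped at 500 samples.
config = SelectionConfigV3.from_dict(
    {
        "nSamples": 500,
        "strategies": [
            {
                "input": {"type": "EMBEDDINGS"},
                "strategy": {"type": "DIVERSITY"},
            }
        ],
    }
)
print(config.to_json(by_alias=True))
```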
diff --git a/lightly/openapi_generated/swagger_client/models/selection_config_v3_all_of.py b/lightly/openapi_generated/swagger_client/models/selection_config_v3_all_of.py
new file mode 100644
index 000000000..84770bdd9
--- /dev/null
+++ b/lightly/openapi_generated/swagger_client/models/selection_config_v3_all_of.py
@@ -0,0 +1,86 @@
+# coding: utf-8
+
+"""
+ Lightly API
+
+ Lightly.ai enables you to do self-supervised learning in an easy and intuitive way. The lightly.ai OpenAPI spec defines how one can interact with our REST API to unleash the full potential of lightly.ai # noqa: E501
+
+ The version of the OpenAPI document: 1.0.0
+ Contact: support@lightly.ai
+ Generated by OpenAPI Generator (https://openapi-generator.tech)
+
+ Do not edit the class manually.
+"""
+
+
+from __future__ import annotations
+import pprint
+import re # noqa: F401
+import json
+
+
+from typing import List
+from pydantic import Extra, BaseModel, Field, conlist
+from lightly.openapi_generated.swagger_client.models.selection_config_v3_entry import SelectionConfigV3Entry
+
+class SelectionConfigV3AllOf(BaseModel):
+ """
+ SelectionConfigV3AllOf
+ """
+ strategies: conlist(SelectionConfigV3Entry, min_items=1) = Field(...)
+ __properties = ["strategies"]
+
+ class Config:
+ """Pydantic configuration"""
+ allow_population_by_field_name = True
+ validate_assignment = True
+ use_enum_values = True
+ extra = Extra.forbid
+
+ def to_str(self, by_alias: bool = False) -> str:
+ """Returns the string representation of the model"""
+ return pprint.pformat(self.dict(by_alias=by_alias))
+
+ def to_json(self, by_alias: bool = False) -> str:
+ """Returns the JSON representation of the model"""
+ return json.dumps(self.to_dict(by_alias=by_alias))
+
+ @classmethod
+ def from_json(cls, json_str: str) -> SelectionConfigV3AllOf:
+ """Create an instance of SelectionConfigV3AllOf from a JSON string"""
+ return cls.from_dict(json.loads(json_str))
+
+ def to_dict(self, by_alias: bool = False):
+ """Returns the dictionary representation of the model"""
+ _dict = self.dict(by_alias=by_alias,
+ exclude={
+ },
+ exclude_none=True)
+ # override the default output from pydantic by calling `to_dict()` of each item in strategies (list)
+ _items = []
+ if self.strategies:
+ for _item in self.strategies:
+ if _item:
+ _items.append(_item.to_dict(by_alias=by_alias))
+ _dict['strategies' if by_alias else 'strategies'] = _items
+ return _dict
+
+ @classmethod
+ def from_dict(cls, obj: dict) -> SelectionConfigV3AllOf:
+ """Create an instance of SelectionConfigV3AllOf from a dict"""
+ if obj is None:
+ return None
+
+ if not isinstance(obj, dict):
+ return SelectionConfigV3AllOf.parse_obj(obj)
+
+ # raise errors for additional fields in the input
+ for _key in obj.keys():
+ if _key not in cls.__properties:
+ raise ValueError("Error due to additional fields (not defined in SelectionConfigV3AllOf) in the input: " + str(obj))
+
+ _obj = SelectionConfigV3AllOf.parse_obj({
+ "strategies": [SelectionConfigV3Entry.from_dict(_item) for _item in obj.get("strategies")] if obj.get("strategies") is not None else None
+ })
+ return _obj
+
diff --git a/lightly/openapi_generated/swagger_client/models/selection_config_v3_entry.py b/lightly/openapi_generated/swagger_client/models/selection_config_v3_entry.py
new file mode 100644
index 000000000..834cbd317
--- /dev/null
+++ b/lightly/openapi_generated/swagger_client/models/selection_config_v3_entry.py
@@ -0,0 +1,88 @@
+# coding: utf-8
+
+"""
+ Lightly API
+
+ Lightly.ai enables you to do self-supervised learning in an easy and intuitive way. The lightly.ai OpenAPI spec defines how one can interact with our REST API to unleash the full potential of lightly.ai # noqa: E501
+
+ The version of the OpenAPI document: 1.0.0
+ Contact: support@lightly.ai
+ Generated by OpenAPI Generator (https://openapi-generator.tech)
+
+ Do not edit the class manually.
+"""
+
+
+from __future__ import annotations
+import pprint
+import re # noqa: F401
+import json
+
+
+
+from pydantic import Extra, BaseModel, Field
+from lightly.openapi_generated.swagger_client.models.selection_config_v3_entry_input import SelectionConfigV3EntryInput
+from lightly.openapi_generated.swagger_client.models.selection_config_v3_entry_strategy import SelectionConfigV3EntryStrategy
+
+class SelectionConfigV3Entry(BaseModel):
+ """
+ SelectionConfigV3Entry
+ """
+ input: SelectionConfigV3EntryInput = Field(...)
+ strategy: SelectionConfigV3EntryStrategy = Field(...)
+ __properties = ["input", "strategy"]
+
+ class Config:
+ """Pydantic configuration"""
+ allow_population_by_field_name = True
+ validate_assignment = True
+ use_enum_values = True
+ extra = Extra.forbid
+
+ def to_str(self, by_alias: bool = False) -> str:
+ """Returns the string representation of the model"""
+ return pprint.pformat(self.dict(by_alias=by_alias))
+
+ def to_json(self, by_alias: bool = False) -> str:
+ """Returns the JSON representation of the model"""
+ return json.dumps(self.to_dict(by_alias=by_alias))
+
+ @classmethod
+ def from_json(cls, json_str: str) -> SelectionConfigV3Entry:
+ """Create an instance of SelectionConfigV3Entry from a JSON string"""
+ return cls.from_dict(json.loads(json_str))
+
+ def to_dict(self, by_alias: bool = False):
+ """Returns the dictionary representation of the model"""
+ _dict = self.dict(by_alias=by_alias,
+ exclude={
+ },
+ exclude_none=True)
+ # override the default output from pydantic by calling `to_dict()` of input
+ if self.input:
+ _dict['input' if by_alias else 'input'] = self.input.to_dict(by_alias=by_alias)
+ # override the default output from pydantic by calling `to_dict()` of strategy
+ if self.strategy:
+ _dict['strategy' if by_alias else 'strategy'] = self.strategy.to_dict(by_alias=by_alias)
+ return _dict
+
+ @classmethod
+ def from_dict(cls, obj: dict) -> SelectionConfigV3Entry:
+ """Create an instance of SelectionConfigV3Entry from a dict"""
+ if obj is None:
+ return None
+
+ if not isinstance(obj, dict):
+ return SelectionConfigV3Entry.parse_obj(obj)
+
+ # raise errors for additional fields in the input
+ for _key in obj.keys():
+ if _key not in cls.__properties:
+ raise ValueError("Error due to additional fields (not defined in SelectionConfigV3Entry) in the input: " + str(obj))
+
+ _obj = SelectionConfigV3Entry.parse_obj({
+ "input": SelectionConfigV3EntryInput.from_dict(obj.get("input")) if obj.get("input") is not None else None,
+ "strategy": SelectionConfigV3EntryStrategy.from_dict(obj.get("strategy")) if obj.get("strategy") is not None else None
+ })
+ return _obj
+
diff --git a/lightly/openapi_generated/swagger_client/models/selection_config_v3_entry_input.py b/lightly/openapi_generated/swagger_client/models/selection_config_v3_entry_input.py
new file mode 100644
index 000000000..aac4c59da
--- /dev/null
+++ b/lightly/openapi_generated/swagger_client/models/selection_config_v3_entry_input.py
@@ -0,0 +1,136 @@
+# coding: utf-8
+
+"""
+ Lightly API
+
+ Lightly.ai enables you to do self-supervised learning in an easy and intuitive way. The lightly.ai OpenAPI spec defines how one can interact with our REST API to unleash the full potential of lightly.ai # noqa: E501
+
+ The version of the OpenAPI document: 1.0.0
+ Contact: support@lightly.ai
+ Generated by OpenAPI Generator (https://openapi-generator.tech)
+
+ Do not edit the class manually.
+"""
+
+
+from __future__ import annotations
+import pprint
+import re # noqa: F401
+import json
+
+
+from typing import List, Optional
+from pydantic import Extra, BaseModel, Field, StrictInt, conlist, constr, validator
+from lightly.openapi_generated.swagger_client.models.selection_input_predictions_name import SelectionInputPredictionsName
+from lightly.openapi_generated.swagger_client.models.selection_input_type import SelectionInputType
+
+class SelectionConfigV3EntryInput(BaseModel):
+ """
+ SelectionConfigV3EntryInput
+ """
+ type: SelectionInputType = Field(...)
+    task: Optional[constr(strict=True)] = Field(None, description="Since we sometimes stitch together SelectionInputTask+ActiveLearningScoreType, they need to follow the same specs as ActiveLearningScoreType. However, this can be an empty string due to internal logic. ")
+ score: Optional[constr(strict=True, min_length=1)] = Field(None, description="Type of active learning score")
+ key: Optional[constr(strict=True, min_length=1)] = None
+ name: Optional[SelectionInputPredictionsName] = None
+ dataset_id: Optional[constr(strict=True)] = Field(None, alias="datasetId", description="MongoDB ObjectId")
+ tag_name: Optional[constr(strict=True, min_length=3)] = Field(None, alias="tagName", description="The name of the tag")
+ random_seed: Optional[StrictInt] = Field(None, alias="randomSeed")
+ categories: Optional[conlist(constr(strict=True, min_length=1), min_items=1, unique_items=True)] = None
+ __properties = ["type", "task", "score", "key", "name", "datasetId", "tagName", "randomSeed", "categories"]
+
+ @validator('task')
+ def task_validate_regular_expression(cls, value):
+ """Validates the regular expression"""
+ if value is None:
+ return value
+
+ if not re.match(r"^[a-zA-Z0-9_+=,.@:\/-]*$", value):
+ raise ValueError(r"must validate the regular expression /^[a-zA-Z0-9_+=,.@:\/-]*$/")
+ return value
+
+ @validator('score')
+ def score_validate_regular_expression(cls, value):
+ """Validates the regular expression"""
+ if value is None:
+ return value
+
+ if not re.match(r"^[a-zA-Z0-9_+=,.@:\/-]*$", value):
+ raise ValueError(r"must validate the regular expression /^[a-zA-Z0-9_+=,.@:\/-]*$/")
+ return value
+
+ @validator('dataset_id')
+ def dataset_id_validate_regular_expression(cls, value):
+ """Validates the regular expression"""
+ if value is None:
+ return value
+
+ if not re.match(r"^[a-f0-9]{24}$", value):
+ raise ValueError(r"must validate the regular expression /^[a-f0-9]{24}$/")
+ return value
+
+ @validator('tag_name')
+ def tag_name_validate_regular_expression(cls, value):
+ """Validates the regular expression"""
+ if value is None:
+ return value
+
+ if not re.match(r"^[a-zA-Z0-9][a-zA-Z0-9 .:;=@_-]+$", value):
+ raise ValueError(r"must validate the regular expression /^[a-zA-Z0-9][a-zA-Z0-9 .:;=@_-]+$/")
+ return value
+
+ class Config:
+ """Pydantic configuration"""
+ allow_population_by_field_name = True
+ validate_assignment = True
+ use_enum_values = True
+ extra = Extra.forbid
+
+ def to_str(self, by_alias: bool = False) -> str:
+ """Returns the string representation of the model"""
+ return pprint.pformat(self.dict(by_alias=by_alias))
+
+ def to_json(self, by_alias: bool = False) -> str:
+ """Returns the JSON representation of the model"""
+ return json.dumps(self.to_dict(by_alias=by_alias))
+
+ @classmethod
+ def from_json(cls, json_str: str) -> SelectionConfigV3EntryInput:
+ """Create an instance of SelectionConfigV3EntryInput from a JSON string"""
+ return cls.from_dict(json.loads(json_str))
+
+ def to_dict(self, by_alias: bool = False):
+ """Returns the dictionary representation of the model"""
+ _dict = self.dict(by_alias=by_alias,
+ exclude={
+ },
+ exclude_none=True)
+ return _dict
+
+ @classmethod
+ def from_dict(cls, obj: dict) -> SelectionConfigV3EntryInput:
+ """Create an instance of SelectionConfigV3EntryInput from a dict"""
+ if obj is None:
+ return None
+
+ if not isinstance(obj, dict):
+ return SelectionConfigV3EntryInput.parse_obj(obj)
+
+ # raise errors for additional fields in the input
+ for _key in obj.keys():
+ if _key not in cls.__properties:
+ raise ValueError("Error due to additional fields (not defined in SelectionConfigV3EntryInput) in the input: " + str(obj))
+
+ _obj = SelectionConfigV3EntryInput.parse_obj({
+ "type": obj.get("type"),
+ "task": obj.get("task"),
+ "score": obj.get("score"),
+ "key": obj.get("key"),
+ "name": obj.get("name"),
+ "dataset_id": obj.get("datasetId"),
+ "tag_name": obj.get("tagName"),
+ "random_seed": obj.get("randomSeed"),
+ "categories": obj.get("categories")
+ })
+ return _obj
+
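A short sketch of the regex validators on this input model; "RANDOM" is assumed to be a valid SelectionInputType value, given the randomSeed field added alongside it:

```python
from lightly.openapi_generated.swagger_client.models.selection_config_v3_entry_input import (
    SelectionConfigV3EntryInput,
)

# A random-selection input with a fixed seed (assumed enum value).
entry_input = SelectionConfigV3EntryInput(type="RANDOM", random_seed=42)

# dataset_id must match the MongoDB ObjectId regex ^[a-f0-9]{24}$,
# so anything else is rejected by the validator.
try:
    SelectionConfigV3EntryInput(type="RANDOM", dataset_id="xyz")
except ValueError as err:
    print(err)
```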
diff --git a/lightly/openapi_generated/swagger_client/models/selection_config_v3_entry_strategy.py b/lightly/openapi_generated/swagger_client/models/selection_config_v3_entry_strategy.py
new file mode 100644
index 000000000..a9e00e1ea
--- /dev/null
+++ b/lightly/openapi_generated/swagger_client/models/selection_config_v3_entry_strategy.py
@@ -0,0 +1,102 @@
+# coding: utf-8
+
+"""
+ Lightly API
+
+ Lightly.ai enables you to do self-supervised learning in an easy and intuitive way. The lightly.ai OpenAPI spec defines how one can interact with our REST API to unleash the full potential of lightly.ai # noqa: E501
+
+ The version of the OpenAPI document: 1.0.0
+ Contact: support@lightly.ai
+ Generated by OpenAPI Generator (https://openapi-generator.tech)
+
+ Do not edit the class manually.
+"""
+
+
+from __future__ import annotations
+import pprint
+import re # noqa: F401
+import json
+
+
+from typing import Any, Dict, Optional, Union
+from pydantic import Extra, BaseModel, Field, StrictFloat, StrictInt, confloat, conint
+from lightly.openapi_generated.swagger_client.models.selection_config_v3_entry_strategy_all_of_target_range import SelectionConfigV3EntryStrategyAllOfTargetRange
+from lightly.openapi_generated.swagger_client.models.selection_strategy_threshold_operation import SelectionStrategyThresholdOperation
+from lightly.openapi_generated.swagger_client.models.selection_strategy_type_v3 import SelectionStrategyTypeV3
+
+class SelectionConfigV3EntryStrategy(BaseModel):
+ """
+ SelectionConfigV3EntryStrategy
+ """
+ type: SelectionStrategyTypeV3 = Field(...)
+ stopping_condition_minimum_distance: Optional[Union[StrictFloat, StrictInt]] = None
+ threshold: Optional[Union[StrictFloat, StrictInt]] = None
+ operation: Optional[SelectionStrategyThresholdOperation] = None
+ target: Optional[Dict[str, Any]] = None
+ num_nearest_neighbors: Optional[Union[confloat(ge=2, strict=True), conint(ge=2, strict=True)]] = Field(None, alias="numNearestNeighbors", description="It is the number of nearest datapoints used to compute the typicality of each sample. ")
+ stopping_condition_minimum_typicality: Optional[Union[confloat(gt=0, strict=True), conint(gt=0, strict=True)]] = Field(None, alias="stoppingConditionMinimumTypicality", description="It is the minimal allowed typicality of the selected samples. When the typicality of the selected samples reaches this, the selection stops. It should be a number between 0 and 1. ")
+ strength: Optional[Union[confloat(le=1000000000, ge=-1000000000, strict=True), conint(le=1000000000, ge=-1000000000, strict=True)]] = Field(None, description="The relative strength of this strategy compared to other strategies. The default value is 1.0, which is set in the worker for backwards compatibility. The minimum and maximum values of +-10^9 are used to prevent numerical issues. ")
+ stopping_condition_max_sum: Optional[Union[confloat(ge=0.0, strict=True), conint(ge=0, strict=True)]] = Field(None, alias="stoppingConditionMaxSum", description="When the sum of inputs reaches this, the selection stops. Only compatible with the WEIGHTS strategy. Similar to the stopping_condition_minimum_distance for the DIVERSITY strategy. ")
+ target_range: Optional[SelectionConfigV3EntryStrategyAllOfTargetRange] = Field(None, alias="targetRange")
+ __properties = ["type", "stopping_condition_minimum_distance", "threshold", "operation", "target", "numNearestNeighbors", "stoppingConditionMinimumTypicality", "strength", "stoppingConditionMaxSum", "targetRange"]
+
+ class Config:
+ """Pydantic configuration"""
+ allow_population_by_field_name = True
+ validate_assignment = True
+ use_enum_values = True
+ extra = Extra.forbid
+
+ def to_str(self, by_alias: bool = False) -> str:
+ """Returns the string representation of the model"""
+ return pprint.pformat(self.dict(by_alias=by_alias))
+
+ def to_json(self, by_alias: bool = False) -> str:
+ """Returns the JSON representation of the model"""
+ return json.dumps(self.to_dict(by_alias=by_alias))
+
+ @classmethod
+ def from_json(cls, json_str: str) -> SelectionConfigV3EntryStrategy:
+ """Create an instance of SelectionConfigV3EntryStrategy from a JSON string"""
+ return cls.from_dict(json.loads(json_str))
+
+ def to_dict(self, by_alias: bool = False):
+ """Returns the dictionary representation of the model"""
+ _dict = self.dict(by_alias=by_alias,
+ exclude={
+ },
+ exclude_none=True)
+ # override the default output from pydantic by calling `to_dict()` of target_range
+ if self.target_range:
+ _dict['targetRange' if by_alias else 'target_range'] = self.target_range.to_dict(by_alias=by_alias)
+ return _dict
+
+ @classmethod
+ def from_dict(cls, obj: dict) -> SelectionConfigV3EntryStrategy:
+ """Create an instance of SelectionConfigV3EntryStrategy from a dict"""
+ if obj is None:
+ return None
+
+ if not isinstance(obj, dict):
+ return SelectionConfigV3EntryStrategy.parse_obj(obj)
+
+ # raise errors for additional fields in the input
+ for _key in obj.keys():
+ if _key not in cls.__properties:
+ raise ValueError("Error due to additional fields (not defined in SelectionConfigV3EntryStrategy) in the input: " + str(obj))
+
+ _obj = SelectionConfigV3EntryStrategy.parse_obj({
+ "type": obj.get("type"),
+ "stopping_condition_minimum_distance": obj.get("stopping_condition_minimum_distance"),
+ "threshold": obj.get("threshold"),
+ "operation": obj.get("operation"),
+ "target": obj.get("target"),
+ "num_nearest_neighbors": obj.get("numNearestNeighbors"),
+ "stopping_condition_minimum_typicality": obj.get("stoppingConditionMinimumTypicality"),
+ "strength": obj.get("strength"),
+ "stopping_condition_max_sum": obj.get("stoppingConditionMaxSum"),
+ "target_range": SelectionConfigV3EntryStrategyAllOfTargetRange.from_dict(obj.get("targetRange")) if obj.get("targetRange") is not None else None
+ })
+ return _obj
+
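A sketch of the new targetRange field in use. Per the field descriptions above, targetRange (like stoppingConditionMaxSum) is only compatible with the WEIGHTS strategy type:

```python
from lightly.openapi_generated.swagger_client.models.selection_config_v3_entry_strategy import (
    SelectionConfigV3EntryStrategy,
)

# Select samples whose sum of inputs lands between minSum and maxSum.
strategy = SelectionConfigV3EntryStrategy.from_dict(
    {
        "type": "WEIGHTS",
        "targetRange": {"minSum": 100, "maxSum": 200},
    }
)
print(strategy.to_dict(by_alias=True))
```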
diff --git a/lightly/openapi_generated/swagger_client/models/selection_config_v3_entry_strategy_all_of.py b/lightly/openapi_generated/swagger_client/models/selection_config_v3_entry_strategy_all_of.py
new file mode 100644
index 000000000..30adac0db
--- /dev/null
+++ b/lightly/openapi_generated/swagger_client/models/selection_config_v3_entry_strategy_all_of.py
@@ -0,0 +1,93 @@
+# coding: utf-8
+
+"""
+ Lightly API
+
+ Lightly.ai enables you to do self-supervised learning in an easy and intuitive way. The lightly.ai OpenAPI spec defines how one can interact with our REST API to unleash the full potential of lightly.ai # noqa: E501
+
+ The version of the OpenAPI document: 1.0.0
+ Contact: support@lightly.ai
+ Generated by OpenAPI Generator (https://openapi-generator.tech)
+
+ Do not edit the class manually.
+"""
+
+
+from __future__ import annotations
+import pprint
+import re # noqa: F401
+import json
+
+
+from typing import Optional, Union
+from pydantic import Extra, BaseModel, Field, confloat, conint
+from lightly.openapi_generated.swagger_client.models.selection_config_v3_entry_strategy_all_of_target_range import SelectionConfigV3EntryStrategyAllOfTargetRange
+from lightly.openapi_generated.swagger_client.models.selection_strategy_type_v3 import SelectionStrategyTypeV3
+
+class SelectionConfigV3EntryStrategyAllOf(BaseModel):
+ """
+ SelectionConfigV3EntryStrategyAllOf
+ """
+ type: SelectionStrategyTypeV3 = Field(...)
+ num_nearest_neighbors: Optional[Union[confloat(ge=2, strict=True), conint(ge=2, strict=True)]] = Field(None, alias="numNearestNeighbors", description="It is the number of nearest datapoints used to compute the typicality of each sample. ")
+ stopping_condition_minimum_typicality: Optional[Union[confloat(gt=0, strict=True), conint(gt=0, strict=True)]] = Field(None, alias="stoppingConditionMinimumTypicality", description="It is the minimal allowed typicality of the selected samples. When the typicality of the selected samples reaches this, the selection stops. It should be a number between 0 and 1. ")
+ strength: Optional[Union[confloat(le=1000000000, ge=-1000000000, strict=True), conint(le=1000000000, ge=-1000000000, strict=True)]] = Field(None, description="The relative strength of this strategy compared to other strategies. The default value is 1.0, which is set in the worker for backwards compatibility. The minimum and maximum values of +-10^9 are used to prevent numerical issues. ")
+ stopping_condition_max_sum: Optional[Union[confloat(ge=0.0, strict=True), conint(ge=0, strict=True)]] = Field(None, alias="stoppingConditionMaxSum", description="When the sum of inputs reaches this, the selection stops. Only compatible with the WEIGHTS strategy. Similar to the stopping_condition_minimum_distance for the DIVERSITY strategy. ")
+ target_range: Optional[SelectionConfigV3EntryStrategyAllOfTargetRange] = Field(None, alias="targetRange")
+ __properties = ["type", "numNearestNeighbors", "stoppingConditionMinimumTypicality", "strength", "stoppingConditionMaxSum", "targetRange"]
+
+ class Config:
+ """Pydantic configuration"""
+ allow_population_by_field_name = True
+ validate_assignment = True
+ use_enum_values = True
+ extra = Extra.forbid
+
+ def to_str(self, by_alias: bool = False) -> str:
+ """Returns the string representation of the model"""
+ return pprint.pformat(self.dict(by_alias=by_alias))
+
+ def to_json(self, by_alias: bool = False) -> str:
+ """Returns the JSON representation of the model"""
+ return json.dumps(self.to_dict(by_alias=by_alias))
+
+ @classmethod
+ def from_json(cls, json_str: str) -> SelectionConfigV3EntryStrategyAllOf:
+ """Create an instance of SelectionConfigV3EntryStrategyAllOf from a JSON string"""
+ return cls.from_dict(json.loads(json_str))
+
+ def to_dict(self, by_alias: bool = False):
+ """Returns the dictionary representation of the model"""
+ _dict = self.dict(by_alias=by_alias,
+ exclude={
+ },
+ exclude_none=True)
+ # override the default output from pydantic by calling `to_dict()` of target_range
+ if self.target_range:
+ _dict['targetRange' if by_alias else 'target_range'] = self.target_range.to_dict(by_alias=by_alias)
+ return _dict
+
+ @classmethod
+ def from_dict(cls, obj: dict) -> SelectionConfigV3EntryStrategyAllOf:
+ """Create an instance of SelectionConfigV3EntryStrategyAllOf from a dict"""
+ if obj is None:
+ return None
+
+ if not isinstance(obj, dict):
+ return SelectionConfigV3EntryStrategyAllOf.parse_obj(obj)
+
+ # raise errors for additional fields in the input
+ for _key in obj.keys():
+ if _key not in cls.__properties:
+ raise ValueError("Error due to additional fields (not defined in SelectionConfigV3EntryStrategyAllOf) in the input: " + str(obj))
+
+ _obj = SelectionConfigV3EntryStrategyAllOf.parse_obj({
+ "type": obj.get("type"),
+ "num_nearest_neighbors": obj.get("numNearestNeighbors"),
+ "stopping_condition_minimum_typicality": obj.get("stoppingConditionMinimumTypicality"),
+ "strength": obj.get("strength"),
+ "stopping_condition_max_sum": obj.get("stoppingConditionMaxSum"),
+ "target_range": SelectionConfigV3EntryStrategyAllOfTargetRange.from_dict(obj.get("targetRange")) if obj.get("targetRange") is not None else None
+ })
+ return _obj
+
diff --git a/lightly/openapi_generated/swagger_client/models/selection_config_v3_entry_strategy_all_of_target_range.py b/lightly/openapi_generated/swagger_client/models/selection_config_v3_entry_strategy_all_of_target_range.py
new file mode 100644
index 000000000..73495ae9c
--- /dev/null
+++ b/lightly/openapi_generated/swagger_client/models/selection_config_v3_entry_strategy_all_of_target_range.py
@@ -0,0 +1,80 @@
+# coding: utf-8
+
+"""
+ Lightly API
+
+ Lightly.ai enables you to do self-supervised learning in an easy and intuitive way. The lightly.ai OpenAPI spec defines how one can interact with our REST API to unleash the full potential of lightly.ai # noqa: E501
+
+ The version of the OpenAPI document: 1.0.0
+ Contact: support@lightly.ai
+ Generated by OpenAPI Generator (https://openapi-generator.tech)
+
+ Do not edit the class manually.
+"""
+
+
+from __future__ import annotations
+import pprint
+import re # noqa: F401
+import json
+
+
+from typing import Optional, Union
+from pydantic import Extra, BaseModel, Field, confloat, conint
+
+class SelectionConfigV3EntryStrategyAllOfTargetRange(BaseModel):
+ """
+ If specified, it tries to select samples such that their sum of inputs is >= min_sum and <= max_sum. Only compatible with the WEIGHTS strategy.
+ """
+ min_sum: Optional[Union[confloat(ge=0.0, strict=True), conint(ge=0, strict=True)]] = Field(None, alias="minSum", description="Target minimum sum of inputs. ")
+ max_sum: Optional[Union[confloat(ge=0.0, strict=True), conint(ge=0, strict=True)]] = Field(None, alias="maxSum", description="Target maximum sum of inputs. Must be >= min_sum. ")
+ __properties = ["minSum", "maxSum"]
+
+ class Config:
+ """Pydantic configuration"""
+ allow_population_by_field_name = True
+ validate_assignment = True
+ use_enum_values = True
+ extra = Extra.forbid
+
+ def to_str(self, by_alias: bool = False) -> str:
+ """Returns the string representation of the model"""
+ return pprint.pformat(self.dict(by_alias=by_alias))
+
+ def to_json(self, by_alias: bool = False) -> str:
+ """Returns the JSON representation of the model"""
+ return json.dumps(self.to_dict(by_alias=by_alias))
+
+ @classmethod
+ def from_json(cls, json_str: str) -> SelectionConfigV3EntryStrategyAllOfTargetRange:
+ """Create an instance of SelectionConfigV3EntryStrategyAllOfTargetRange from a JSON string"""
+ return cls.from_dict(json.loads(json_str))
+
+ def to_dict(self, by_alias: bool = False):
+ """Returns the dictionary representation of the model"""
+ _dict = self.dict(by_alias=by_alias,
+ exclude={
+ },
+ exclude_none=True)
+ return _dict
+
+ @classmethod
+ def from_dict(cls, obj: dict) -> SelectionConfigV3EntryStrategyAllOfTargetRange:
+ """Create an instance of SelectionConfigV3EntryStrategyAllOfTargetRange from a dict"""
+ if obj is None:
+ return None
+
+ if not isinstance(obj, dict):
+ return SelectionConfigV3EntryStrategyAllOfTargetRange.parse_obj(obj)
+
+ # raise errors for additional fields in the input
+ for _key in obj.keys():
+ if _key not in cls.__properties:
+ raise ValueError("Error due to additional fields (not defined in SelectionConfigV3EntryStrategyAllOfTargetRange) in the input: " + str(obj))
+
+ _obj = SelectionConfigV3EntryStrategyAllOfTargetRange.parse_obj({
+ "min_sum": obj.get("minSum"),
+ "max_sum": obj.get("maxSum")
+ })
+ return _obj
+
diff --git a/lightly/openapi_generated/swagger_client/models/annotation_state.py b/lightly/openapi_generated/swagger_client/models/selection_strategy_type_v3.py
similarity index 59%
rename from lightly/openapi_generated/swagger_client/models/annotation_state.py
rename to lightly/openapi_generated/swagger_client/models/selection_strategy_type_v3.py
index 48f5388ea..4d1984560 100644
--- a/lightly/openapi_generated/swagger_client/models/annotation_state.py
+++ b/lightly/openapi_generated/swagger_client/models/selection_strategy_type_v3.py
@@ -23,24 +23,24 @@
-class AnnotationState(str, Enum):
+class SelectionStrategyTypeV3(str, Enum):
"""
- AnnotationState
+ SelectionStrategyTypeV3
"""
"""
allowed enum values
"""
- DRAFT = 'DRAFT'
- OFFER_REQUESTED = 'OFFER_REQUESTED'
- OFFER_RETURNED = 'OFFER_RETURNED'
- ACCEPTED = 'ACCEPTED'
- ACTIVE = 'ACTIVE'
- COMPLETED = 'COMPLETED'
+ DIVERSITY = 'DIVERSITY'
+ WEIGHTS = 'WEIGHTS'
+ THRESHOLD = 'THRESHOLD'
+ BALANCE = 'BALANCE'
+ SIMILARITY = 'SIMILARITY'
+ TYPICALITY = 'TYPICALITY'
@classmethod
- def from_json(cls, json_str: str) -> 'AnnotationState':
- """Create an instance of AnnotationState from a JSON string"""
- return AnnotationState(json.loads(json_str))
+ def from_json(cls, json_str: str) -> 'SelectionStrategyTypeV3':
+ """Create an instance of SelectionStrategyTypeV3 from a JSON string"""
+ return SelectionStrategyTypeV3(json.loads(json_str))
diff --git a/lightly/transforms/__init__.py b/lightly/transforms/__init__.py
index 59fafe34d..33c70bdad 100644
--- a/lightly/transforms/__init__.py
+++ b/lightly/transforms/__init__.py
@@ -8,6 +8,11 @@
# Copyright (c) 2020. Lightly AG and its affiliates.
# All Rights Reserved
+from lightly.transforms.byol_transform import (
+ BYOLTransform,
+ BYOLView1Transform,
+ BYOLView2Transform,
+)
from lightly.transforms.dino_transform import DINOTransform, DINOViewTransform
from lightly.transforms.fast_siam_transform import FastSiamTransform
from lightly.transforms.gaussian_blur import GaussianBlur
diff --git a/lightly/transforms/byol_transform.py b/lightly/transforms/byol_transform.py
new file mode 100644
index 000000000..26d9f1360
--- /dev/null
+++ b/lightly/transforms/byol_transform.py
@@ -0,0 +1,177 @@
+from typing import Dict, List, Optional, Tuple, Union
+
+import torchvision.transforms as T
+from PIL.Image import Image
+from torch import Tensor
+
+from lightly.transforms.gaussian_blur import GaussianBlur
+from lightly.transforms.multi_view_transform import MultiViewTransform
+from lightly.transforms.rotation import random_rotation_transform
+from lightly.transforms.solarize import RandomSolarization
+from lightly.transforms.utils import IMAGENET_NORMALIZE
+
+
+class BYOLView1Transform:
+ def __init__(
+ self,
+ input_size: int = 224,
+ cj_prob: float = 0.8,
+ cj_strength: float = 1.0,
+ cj_bright: float = 0.4,
+ cj_contrast: float = 0.4,
+ cj_sat: float = 0.2,
+ cj_hue: float = 0.1,
+ min_scale: float = 0.08,
+ random_gray_scale: float = 0.2,
+ gaussian_blur: float = 1.0,
+ solarization_prob: float = 0.0,
+ kernel_size: Optional[float] = None,
+ sigmas: Tuple[float, float] = (0.1, 2),
+ vf_prob: float = 0.0,
+ hf_prob: float = 0.5,
+ rr_prob: float = 0.0,
+ rr_degrees: Union[None, float, Tuple[float, float]] = None,
+ normalize: Union[None, Dict[str, List[float]]] = IMAGENET_NORMALIZE,
+ ):
+ color_jitter = T.ColorJitter(
+ brightness=cj_strength * cj_bright,
+ contrast=cj_strength * cj_contrast,
+ saturation=cj_strength * cj_sat,
+ hue=cj_strength * cj_hue,
+ )
+
+ transform = [
+ T.RandomResizedCrop(size=input_size, scale=(min_scale, 1.0)),
+ random_rotation_transform(rr_prob=rr_prob, rr_degrees=rr_degrees),
+ T.RandomHorizontalFlip(p=hf_prob),
+ T.RandomVerticalFlip(p=vf_prob),
+ T.RandomApply([color_jitter], p=cj_prob),
+ T.RandomGrayscale(p=random_gray_scale),
+ GaussianBlur(kernel_size=kernel_size, sigmas=sigmas, prob=gaussian_blur),
+ RandomSolarization(prob=solarization_prob),
+ T.ToTensor(),
+ ]
+ if normalize:
+ transform += [T.Normalize(mean=normalize["mean"], std=normalize["std"])]
+ self.transform = T.Compose(transform)
+
+ def __call__(self, image: Union[Tensor, Image]) -> Tensor:
+ """
+ Applies the transforms to the input image.
+
+ Args:
+ image:
+ The input image to apply the transforms to.
+
+ Returns:
+ The transformed image.
+
+ """
+ transformed: Tensor = self.transform(image)
+ return transformed
+
+
+class BYOLView2Transform:
+ def __init__(
+ self,
+ input_size: int = 224,
+ cj_prob: float = 0.8,
+ cj_strength: float = 1.0,
+ cj_bright: float = 0.4,
+ cj_contrast: float = 0.4,
+ cj_sat: float = 0.2,
+ cj_hue: float = 0.1,
+ min_scale: float = 0.08,
+ random_gray_scale: float = 0.2,
+ gaussian_blur: float = 0.1,
+ solarization_prob: float = 0.2,
+ kernel_size: Optional[float] = None,
+ sigmas: Tuple[float, float] = (0.1, 2),
+ vf_prob: float = 0.0,
+ hf_prob: float = 0.5,
+ rr_prob: float = 0.0,
+ rr_degrees: Union[None, float, Tuple[float, float]] = None,
+ normalize: Union[None, Dict[str, List[float]]] = IMAGENET_NORMALIZE,
+ ):
+ color_jitter = T.ColorJitter(
+ brightness=cj_strength * cj_bright,
+ contrast=cj_strength * cj_contrast,
+ saturation=cj_strength * cj_sat,
+ hue=cj_strength * cj_hue,
+ )
+
+ transform = [
+ T.RandomResizedCrop(size=input_size, scale=(min_scale, 1.0)),
+ random_rotation_transform(rr_prob=rr_prob, rr_degrees=rr_degrees),
+ T.RandomHorizontalFlip(p=hf_prob),
+ T.RandomVerticalFlip(p=vf_prob),
+ T.RandomApply([color_jitter], p=cj_prob),
+ T.RandomGrayscale(p=random_gray_scale),
+ GaussianBlur(kernel_size=kernel_size, sigmas=sigmas, prob=gaussian_blur),
+ RandomSolarization(prob=solarization_prob),
+ T.ToTensor(),
+ ]
+ if normalize:
+ transform += [T.Normalize(mean=normalize["mean"], std=normalize["std"])]
+ self.transform = T.Compose(transform)
+
+ def __call__(self, image: Union[Tensor, Image]) -> Tensor:
+ """
+ Applies the transforms to the input image.
+
+ Args:
+ image:
+ The input image to apply the transforms to.
+
+ Returns:
+ The transformed image.
+
+ """
+ transformed: Tensor = self.transform(image)
+ return transformed
+
+
+class BYOLTransform(MultiViewTransform):
+ """Implements the transformations for BYOL[0].
+
+ Input to this transform:
+ PIL Image or Tensor.
+
+ Output of this transform:
+ List of Tensor of length 2.
+
+ Applies the following augmentations by default:
+ - Random resized crop
+ - Random horizontal flip
+ - Color jitter
+ - Random gray scale
+ - Gaussian blur
+ - Solarization
+ - ImageNet normalization
+
+    Note that BYOL uses augmentations similar to those of SimCLR v1 and v2.
+    The main differences are that BYOL applies gaussian blur and solarization
+    asymmetrically between the two views and uses weaker color jitter.
+
+ - [0]: Bootstrap Your Own Latent, 2020, https://arxiv.org/pdf/2006.07733.pdf
+
+ Attributes:
+ view_1_transform: The transform for the first view.
+ view_2_transform: The transform for the second view.
+ """
+
+ def __init__(
+ self,
+ view_1_transform: Union[BYOLView1Transform, None] = None,
+ view_2_transform: Union[BYOLView2Transform, None] = None,
+ ):
+        # Create the default view transforms here rather than as default
+        # arguments to avoid sharing mutable defaults between instances.
+ view_1_transform = view_1_transform or BYOLView1Transform()
+ view_2_transform = view_2_transform or BYOLView2Transform()
+ super().__init__(transforms=[view_1_transform, view_2_transform])
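A short usage sketch of the new transform; the image is a synthetic placeholder:

```python
from PIL import Image

from lightly.transforms.byol_transform import BYOLTransform, BYOLView1Transform

# Default asymmetric views; each view transform can be customized independently.
transform = BYOLTransform(view_1_transform=BYOLView1Transform(input_size=128))

image = Image.new("RGB", (256, 256))  # synthetic placeholder image
view_1, view_2 = transform(image)
print(view_1.shape, view_2.shape)
# torch.Size([3, 128, 128]) torch.Size([3, 224, 224])
```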
diff --git a/lightly/transforms/dino_transform.py b/lightly/transforms/dino_transform.py
index de6a18657..38e719fbd 100644
--- a/lightly/transforms/dino_transform.py
+++ b/lightly/transforms/dino_transform.py
@@ -1,4 +1,4 @@
-from typing import Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union
import PIL
import torchvision.transforms as T
@@ -117,7 +117,7 @@ def __init__(
kernel_scale: Optional[float] = None,
sigmas: Tuple[float, float] = (0.1, 2),
solarization_prob: float = 0.2,
- normalize: Union[None, dict] = IMAGENET_NORMALIZE,
+ normalize: Union[None, Dict[str, List[float]]] = IMAGENET_NORMALIZE,
):
# first global crop
global_transform_0 = DINOViewTransform(
@@ -213,7 +213,7 @@ def __init__(
kernel_scale: Optional[float] = None,
sigmas: Tuple[float, float] = (0.1, 2),
solarization_prob: float = 0.2,
- normalize: Union[None, dict] = IMAGENET_NORMALIZE,
+ normalize: Union[None, Dict[str, List[float]]] = IMAGENET_NORMALIZE,
):
transform = [
T.RandomResizedCrop(
@@ -262,4 +262,5 @@ def __call__(self, image: Union[Tensor, Image]) -> Tensor:
The transformed image.
"""
- return self.transform(image)
+ transformed: Tensor = self.transform(image)
+ return transformed
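The tightened annotation documents the shape the code already expects: a dict with "mean" and "std" lists, matching the default:

```python
from lightly.transforms.utils import IMAGENET_NORMALIZE

print(IMAGENET_NORMALIZE)
# {'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225]}

# Any Dict[str, List[float]] with "mean" and "std" keys works, e.g.:
custom_normalize = {"mean": [0.5, 0.5, 0.5], "std": [0.5, 0.5, 0.5]}
```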
diff --git a/lightly/transforms/fast_siam_transform.py b/lightly/transforms/fast_siam_transform.py
index 7b560591f..c70b77de1 100644
--- a/lightly/transforms/fast_siam_transform.py
+++ b/lightly/transforms/fast_siam_transform.py
@@ -1,4 +1,4 @@
-from typing import Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union
from lightly.transforms.multi_view_transform import MultiViewTransform
from lightly.transforms.simsiam_transform import SimSiamViewTransform
@@ -89,7 +89,7 @@ def __init__(
hf_prob: float = 0.5,
rr_prob: float = 0.0,
rr_degrees: Union[None, float, Tuple[float, float]] = None,
- normalize: Union[None, dict] = IMAGENET_NORMALIZE,
+ normalize: Union[None, Dict[str, List[float]]] = IMAGENET_NORMALIZE,
):
transforms = [
SimSiamViewTransform(
diff --git a/lightly/transforms/gaussian_blur.py b/lightly/transforms/gaussian_blur.py
index 180e6ff08..3f8f9e2d7 100644
--- a/lightly/transforms/gaussian_blur.py
+++ b/lightly/transforms/gaussian_blur.py
@@ -6,6 +6,7 @@
import numpy as np
from PIL import ImageFilter
+from PIL.Image import Image
class GaussianBlur:
@@ -47,7 +48,7 @@ def __init__(
self.prob = prob
self.sigmas = sigmas
- def __call__(self, sample):
+ def __call__(self, sample: Image) -> Image:
"""Blurs the image with a given probability.
Args:
diff --git a/lightly/transforms/ijepa_transform.py b/lightly/transforms/ijepa_transform.py
index 321dba66a..6bd002446 100644
--- a/lightly/transforms/ijepa_transform.py
+++ b/lightly/transforms/ijepa_transform.py
@@ -1,4 +1,4 @@
-from typing import Tuple, Union
+from typing import Dict, List, Tuple, Union
import torchvision.transforms as T
from PIL.Image import Image
@@ -30,7 +30,7 @@ def __init__(
self,
input_size: Union[int, Tuple[int, int]] = 224,
min_scale: float = 0.2,
- normalize: dict = IMAGENET_NORMALIZE,
+ normalize: Dict[str, List[float]] = IMAGENET_NORMALIZE,
):
transforms = [
T.RandomResizedCrop(
@@ -55,4 +55,5 @@ def __call__(self, image: Union[Tensor, Image]) -> Tensor:
The transformed image.
"""
- return self.transform(image)
+ transformed: Tensor = self.transform(image)
+ return transformed
diff --git a/lightly/transforms/image_grid_transform.py b/lightly/transforms/image_grid_transform.py
index 30854c910..1822761a6 100644
--- a/lightly/transforms/image_grid_transform.py
+++ b/lightly/transforms/image_grid_transform.py
@@ -1,4 +1,4 @@
-from typing import List, Tuple, Union
+from typing import List, Sequence, Union
import torchvision.transforms as T
from PIL.Image import Image
@@ -19,7 +19,7 @@ class ImageGridTransform:
grids.
"""
- def __init__(self, transforms):
+ def __init__(self, transforms: Sequence[T.Compose]):
self.transforms = transforms
def __call__(self, image: Union[Tensor, Image]) -> Union[List[Tensor], List[Image]]:
diff --git a/lightly/transforms/jigsaw.py b/lightly/transforms/jigsaw.py
index db9cbc37b..adebb808b 100644
--- a/lightly/transforms/jigsaw.py
+++ b/lightly/transforms/jigsaw.py
@@ -1,10 +1,14 @@
# Copyright (c) 2021. Lightly AG and its affiliates.
# All Rights Reserved
+from typing import List
+
import numpy as np
import torch
-from PIL import Image
-from torchvision import transforms
+from PIL import Image
+from PIL.Image import Image as PILImage
+from torch import Tensor
+from torchvision import transforms as T
class Jigsaw(object):
@@ -34,7 +38,11 @@ class Jigsaw(object):
"""
def __init__(
- self, n_grid=3, img_size=255, crop_size=64, transform=transforms.ToTensor()
+ self,
+ n_grid: int = 3,
+ img_size: int = 255,
+ crop_size: int = 64,
+ transform: T.Compose = T.ToTensor(),
):
self.n_grid = n_grid
self.img_size = img_size
@@ -47,7 +55,7 @@ def __init__(
self.yy = np.reshape(yy * self.grid_size, (n_grid * n_grid,))
self.xx = np.reshape(xx * self.grid_size, (n_grid * n_grid,))
- def __call__(self, img):
+ def __call__(self, img: PILImage) -> Tensor:
"""Performs the Jigsaw augmentation
Args:
img:
@@ -59,7 +67,7 @@ def __call__(self, img):
r_x = np.random.randint(0, self.side + 1, self.n_grid * self.n_grid)
r_y = np.random.randint(0, self.side + 1, self.n_grid * self.n_grid)
img = np.asarray(img, np.uint8)
- crops = []
+ crops: List[PILImage] = []
for i in range(self.n_grid * self.n_grid):
crops.append(
img[
@@ -68,7 +76,9 @@ def __call__(self, img):
:,
]
)
- crops = [Image.fromarray(crop) for crop in crops]
- crops = torch.stack([self.transform(crop) for crop in crops])
- crops = crops[np.random.permutation(self.n_grid**2)]
- return crops
+ crop_images = [Image.fromarray(crop) for crop in crops]
+ crop_tensors: Tensor = torch.stack(
+ [self.transform(crop) for crop in crop_images]
+ )
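+        # Shuffle the crop order so a model must recover the spatial arrangement.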
+ permutation: List[int] = np.random.permutation(self.n_grid**2).tolist()
+ return crop_tensors[permutation]
diff --git a/lightly/transforms/mae_transform.py b/lightly/transforms/mae_transform.py
index 3176f084e..50f9dd9f7 100644
--- a/lightly/transforms/mae_transform.py
+++ b/lightly/transforms/mae_transform.py
@@ -1,10 +1,9 @@
-from typing import List, Tuple, Union
+from typing import Dict, List, Tuple, Union
import torchvision.transforms as T
from PIL.Image import Image
from torch import Tensor
-from lightly.transforms.multi_view_transform import MultiViewTransform
from lightly.transforms.utils import IMAGENET_NORMALIZE
@@ -37,7 +36,7 @@ def __init__(
self,
input_size: Union[int, Tuple[int, int]] = 224,
min_scale: float = 0.2,
- normalize: dict = IMAGENET_NORMALIZE,
+ normalize: Dict[str, List[float]] = IMAGENET_NORMALIZE,
):
transforms = [
T.RandomResizedCrop(
diff --git a/lightly/transforms/moco_transform.py b/lightly/transforms/moco_transform.py
index 8ec8ade55..3f5f728fe 100644
--- a/lightly/transforms/moco_transform.py
+++ b/lightly/transforms/moco_transform.py
@@ -1,4 +1,4 @@
-from typing import Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union
from lightly.transforms.simclr_transform import SimCLRTransform
from lightly.transforms.utils import IMAGENET_NORMALIZE
@@ -83,7 +83,7 @@ def __init__(
hf_prob: float = 0.5,
rr_prob: float = 0.0,
rr_degrees: Union[None, float, Tuple[float, float]] = None,
- normalize: dict = IMAGENET_NORMALIZE,
+ normalize: Dict[str, List[float]] = IMAGENET_NORMALIZE,
):
super().__init__(
input_size=input_size,
diff --git a/lightly/transforms/msn_transform.py b/lightly/transforms/msn_transform.py
index d07130732..b8a78ac1c 100644
--- a/lightly/transforms/msn_transform.py
+++ b/lightly/transforms/msn_transform.py
@@ -1,4 +1,4 @@
-from typing import Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union
import torchvision.transforms as T
from PIL.Image import Image
@@ -96,7 +96,7 @@ def __init__(
random_gray_scale: float = 0.2,
hf_prob: float = 0.5,
vf_prob: float = 0.0,
- normalize: dict = IMAGENET_NORMALIZE,
+ normalize: Dict[str, List[float]] = IMAGENET_NORMALIZE,
):
random_view_transform = MSNViewTransform(
crop_size=random_size,
@@ -150,7 +150,7 @@ def __init__(
random_gray_scale: float = 0.2,
hf_prob: float = 0.5,
vf_prob: float = 0.0,
- normalize: dict = IMAGENET_NORMALIZE,
+ normalize: Dict[str, List[float]] = IMAGENET_NORMALIZE,
):
color_jitter = T.ColorJitter(
brightness=cj_strength * cj_bright,
@@ -183,4 +183,5 @@ def __call__(self, image: Union[Tensor, Image]) -> Tensor:
The transformed image.
"""
- return self.transform(image)
+ transformed: Tensor = self.transform(image)
+ return transformed
diff --git a/lightly/transforms/multi_crop_transform.py b/lightly/transforms/multi_crop_transform.py
index 307a507c6..2e546b22f 100644
--- a/lightly/transforms/multi_crop_transform.py
+++ b/lightly/transforms/multi_crop_transform.py
@@ -34,11 +34,11 @@ class MultiCropTranform(MultiViewTransform):
def __init__(
self,
- crop_sizes: Tuple[int],
- crop_counts: Tuple[int],
- crop_min_scales: Tuple[float],
- crop_max_scales: Tuple[float],
- transforms,
+ crop_sizes: Tuple[int, ...],
+ crop_counts: Tuple[int, ...],
+ crop_min_scales: Tuple[float, ...],
+ crop_max_scales: Tuple[float, ...],
+ transforms: T.Compose,
):
if len(crop_sizes) != len(crop_counts):
raise ValueError(
diff --git a/lightly/transforms/multi_view_transform.py b/lightly/transforms/multi_view_transform.py
index a00f4fc00..62c9f2cb5 100644
--- a/lightly/transforms/multi_view_transform.py
+++ b/lightly/transforms/multi_view_transform.py
@@ -1,7 +1,8 @@
-from typing import List, Union
+from typing import List, Sequence, Union
from PIL.Image import Image
from torch import Tensor
+from torchvision import transforms as T
class MultiViewTransform:
@@ -13,7 +14,7 @@ class MultiViewTransform:
"""
- def __init__(self, transforms):
+ def __init__(self, transforms: Sequence[T.Compose]):
self.transforms = transforms
def __call__(self, image: Union[Tensor, Image]) -> Union[List[Tensor], List[Image]]:
diff --git a/lightly/transforms/pirl_transform.py b/lightly/transforms/pirl_transform.py
index b67e451c7..1d5ec7d57 100644
--- a/lightly/transforms/pirl_transform.py
+++ b/lightly/transforms/pirl_transform.py
@@ -1,8 +1,6 @@
-from typing import Tuple, Union
+from typing import Dict, List, Tuple, Union
import torchvision.transforms as T
-from PIL.Image import Image
-from torch import Tensor
from lightly.transforms.jigsaw import Jigsaw
from lightly.transforms.multi_view_transform import MultiViewTransform
@@ -71,7 +69,7 @@ def __init__(
random_gray_scale: float = 0.2,
hf_prob: float = 0.5,
n_grid: int = 3,
- normalize: Union[None, dict] = IMAGENET_NORMALIZE,
+ normalize: Union[None, Dict[str, List[float]]] = IMAGENET_NORMALIZE,
):
if isinstance(input_size, tuple):
input_size_ = max(input_size)
@@ -79,13 +77,17 @@ def __init__(
input_size_ = input_size
# Cropping and normalisation for non-transformed image
- no_augment = T.Compose(
- [
- T.RandomResizedCrop(size=input_size, scale=(min_scale, 1.0)),
- T.ToTensor(),
- T.Normalize(mean=normalize["mean"], std=normalize["std"]),
- ]
- )
+ transforms_no_augment = [
+ T.RandomResizedCrop(size=input_size, scale=(min_scale, 1.0)),
+ T.ToTensor(),
+ ]
+
+ if normalize is not None:
+ transforms_no_augment.append(
+ T.Normalize(mean=normalize["mean"], std=normalize["std"])
+ )
+
+ no_augment = T.Compose(transforms_no_augment)
color_jitter = T.ColorJitter(
brightness=cj_strength * cj_bright,
@@ -95,21 +97,21 @@ def __init__(
)
# Transform for transformed jigsaw image
- transform = [
+ transforms = [
T.RandomHorizontalFlip(p=hf_prob),
T.RandomApply([color_jitter], p=cj_prob),
T.RandomGrayscale(p=random_gray_scale),
T.ToTensor(),
]
- if normalize:
- transform += [T.Normalize(mean=normalize["mean"], std=normalize["std"])]
+ if normalize is not None:
+ transforms.append(T.Normalize(mean=normalize["mean"], std=normalize["std"]))
jigsaw = Jigsaw(
n_grid=n_grid,
img_size=input_size_,
crop_size=int(input_size_ // n_grid),
- transform=T.Compose(transform),
+ transform=T.Compose(transforms),
)
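+        # The two views are the unaugmented crop and the jigsaw-shuffled image.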
super().__init__([no_augment, jigsaw])
diff --git a/lightly/transforms/py.typed b/lightly/transforms/py.typed
new file mode 100644
index 000000000..e69de29bb
diff --git a/lightly/transforms/random_crop_and_flip_with_grid.py b/lightly/transforms/random_crop_and_flip_with_grid.py
index 6fbb92a33..60a66226f 100644
--- a/lightly/transforms/random_crop_and_flip_with_grid.py
+++ b/lightly/transforms/random_crop_and_flip_with_grid.py
@@ -1,5 +1,5 @@
from dataclasses import dataclass
-from typing import Dict, List, Tuple
+from typing import Tuple
import torch
import torchvision.transforms as T
@@ -28,7 +28,7 @@ class Location:
vertical_flip: bool = False
-class RandomResizedCropWithLocation(T.RandomResizedCrop):
+class RandomResizedCropWithLocation(T.RandomResizedCrop): # type: ignore[misc] # Class cannot subclass "RandomResizedCrop" (has type "Any")
"""
Do a random resized crop and return both the resulting image and the location. See base class.
@@ -59,7 +59,7 @@ def forward(self, img: Image.Image) -> Tuple[Image.Image, Location]:
return img, location
-class RandomHorizontalFlipWithLocation(T.RandomHorizontalFlip):
+class RandomHorizontalFlipWithLocation(T.RandomHorizontalFlip): # type: ignore[misc] # Class cannot subclass "RandomHorizontalFlip" (has type "Any")
"""See base class."""
def forward(
@@ -84,7 +84,7 @@ def forward(
return img, location
-class RandomVerticalFlipWithLocation(T.RandomVerticalFlip):
+class RandomVerticalFlipWithLocation(T.RandomVerticalFlip): # type: ignore[misc] # Class cannot subclass "RandomVerticalFlip" (has type "Any")
"""See base class."""
def forward(
diff --git a/lightly/transforms/simclr_transform.py b/lightly/transforms/simclr_transform.py
index 8c39591c7..d975b0771 100644
--- a/lightly/transforms/simclr_transform.py
+++ b/lightly/transforms/simclr_transform.py
@@ -1,4 +1,4 @@
-from typing import Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union
import torchvision.transforms as T
from PIL.Image import Image
@@ -102,7 +102,7 @@ def __init__(
hf_prob: float = 0.5,
rr_prob: float = 0.0,
rr_degrees: Union[None, float, Tuple[float, float]] = None,
- normalize: Union[None, dict] = IMAGENET_NORMALIZE,
+ normalize: Union[None, Dict[str, List[float]]] = IMAGENET_NORMALIZE,
):
view_transform = SimCLRViewTransform(
input_size=input_size,
@@ -145,7 +145,7 @@ def __init__(
hf_prob: float = 0.5,
rr_prob: float = 0.0,
rr_degrees: Union[None, float, Tuple[float, float]] = None,
- normalize: Union[None, dict] = IMAGENET_NORMALIZE,
+ normalize: Union[None, Dict[str, List[float]]] = IMAGENET_NORMALIZE,
):
color_jitter = T.ColorJitter(
brightness=cj_strength * cj_bright,
@@ -180,4 +180,5 @@ def __call__(self, image: Union[Tensor, Image]) -> Tensor:
The transformed image.
"""
- return self.transform(image)
+ transformed: Tensor = self.transform(image)
+ return transformed
diff --git a/lightly/transforms/simsiam_transform.py b/lightly/transforms/simsiam_transform.py
index 379d55242..4f2607480 100644
--- a/lightly/transforms/simsiam_transform.py
+++ b/lightly/transforms/simsiam_transform.py
@@ -1,4 +1,4 @@
-from typing import Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union
import torchvision.transforms as T
from PIL.Image import Image
@@ -91,7 +91,7 @@ def __init__(
hf_prob: float = 0.5,
rr_prob: float = 0.0,
rr_degrees: Union[None, float, Tuple[float, float]] = None,
- normalize: Union[None, dict] = IMAGENET_NORMALIZE,
+ normalize: Union[None, Dict[str, List[float]]] = IMAGENET_NORMALIZE,
):
view_transform = SimSiamViewTransform(
input_size=input_size,
@@ -134,7 +134,7 @@ def __init__(
hf_prob: float = 0.5,
rr_prob: float = 0.0,
rr_degrees: Union[None, float, Tuple[float, float]] = None,
- normalize: Union[None, dict] = IMAGENET_NORMALIZE,
+ normalize: Union[None, Dict[str, List[float]]] = IMAGENET_NORMALIZE,
):
color_jitter = T.ColorJitter(
brightness=cj_strength * cj_bright,
@@ -169,4 +169,5 @@ def __call__(self, image: Union[Tensor, Image]) -> Tensor:
The transformed image.
"""
- return self.transform(image)
+ transformed: Tensor = self.transform(image)
+ return transformed
diff --git a/lightly/transforms/smog_transform.py b/lightly/transforms/smog_transform.py
index 6f8958377..69d8de64c 100644
--- a/lightly/transforms/smog_transform.py
+++ b/lightly/transforms/smog_transform.py
@@ -1,4 +1,4 @@
-from typing import Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union
import torchvision.transforms as T
from PIL.Image import Image
@@ -88,7 +88,7 @@ def __init__(
cj_sat: float = 0.4,
cj_hue: float = 0.2,
random_gray_scale: float = 0.2,
- normalize: Union[None, dict] = IMAGENET_NORMALIZE,
+ normalize: Union[None, Dict[str, List[float]]] = IMAGENET_NORMALIZE,
):
transforms = []
for i in range(len(crop_sizes)):
@@ -137,7 +137,7 @@ def __init__(
cj_sat: float = 0.4,
cj_hue: float = 0.2,
random_gray_scale: float = 0.2,
- normalize: Union[None, dict] = IMAGENET_NORMALIZE,
+ normalize: Union[None, Dict[str, List[float]]] = IMAGENET_NORMALIZE,
):
color_jitter = T.ColorJitter(
brightness=cj_strength * cj_bright,
@@ -175,4 +175,5 @@ def __call__(self, image: Union[Tensor, Image]) -> Tensor:
The transformed image.
"""
- return self.transform(image)
+ transformed: Tensor = self.transform(image)
+ return transformed
diff --git a/lightly/transforms/solarize.py b/lightly/transforms/solarize.py
index bbd460899..3e640a404 100644
--- a/lightly/transforms/solarize.py
+++ b/lightly/transforms/solarize.py
@@ -3,6 +3,7 @@
import numpy as np
from PIL import ImageOps
+from PIL.Image import Image as PILImage
class RandomSolarization(object):
@@ -22,7 +23,7 @@ def __init__(self, prob: float = 0.5, threshold: int = 128):
self.prob = prob
self.threshold = threshold
- def __call__(self, sample):
+ def __call__(self, sample: PILImage) -> PILImage:
"""Solarizes the given input image
Args:
diff --git a/lightly/transforms/swav_transform.py b/lightly/transforms/swav_transform.py
index 2cbabf94f..f7000f945 100644
--- a/lightly/transforms/swav_transform.py
+++ b/lightly/transforms/swav_transform.py
@@ -1,4 +1,4 @@
-from typing import Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union
import torchvision.transforms as T
from PIL.Image import Image
@@ -96,7 +96,7 @@ def __init__(
gaussian_blur: float = 0.5,
kernel_size: Optional[float] = None,
sigmas: Tuple[float, float] = (0.1, 2),
- normalize: Union[None, dict] = IMAGENET_NORMALIZE,
+ normalize: Union[None, Dict[str, List[float]]] = IMAGENET_NORMALIZE,
):
transforms = SwaVViewTransform(
hf_prob=hf_prob,
@@ -142,7 +142,7 @@ def __init__(
gaussian_blur: float = 0.5,
kernel_size: Optional[float] = None,
sigmas: Tuple[float, float] = (0.1, 2),
- normalize: Union[None, dict] = IMAGENET_NORMALIZE,
+ normalize: Union[None, Dict[str, List[float]]] = IMAGENET_NORMALIZE,
):
color_jitter = T.ColorJitter(
brightness=cj_strength * cj_bright,
@@ -178,4 +178,5 @@ def __call__(self, image: Union[Tensor, Image]) -> Tensor:
The transformed image.
"""
- return self.transform(image)
+ transformed: Tensor = self.transform(image)
+ return transformed
diff --git a/lightly/transforms/vicreg_transform.py b/lightly/transforms/vicreg_transform.py
index e2234f6c4..c7cc0b270 100644
--- a/lightly/transforms/vicreg_transform.py
+++ b/lightly/transforms/vicreg_transform.py
@@ -1,4 +1,4 @@
-from typing import Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union
import torchvision.transforms as T
from PIL.Image import Image
@@ -97,7 +97,7 @@ def __init__(
hf_prob: float = 0.5,
rr_prob: float = 0.0,
rr_degrees: Union[None, float, Tuple[float, float]] = None,
- normalize: Union[None, dict] = IMAGENET_NORMALIZE,
+ normalize: Union[None, Dict[str, List[float]]] = IMAGENET_NORMALIZE,
):
view_transform = VICRegViewTransform(
input_size=input_size,
@@ -142,7 +142,7 @@ def __init__(
hf_prob: float = 0.5,
rr_prob: float = 0.0,
rr_degrees: Union[None, float, Tuple[float, float]] = None,
- normalize: Union[None, dict] = IMAGENET_NORMALIZE,
+ normalize: Union[None, Dict[str, List[float]]] = IMAGENET_NORMALIZE,
):
color_jitter = T.ColorJitter(
brightness=cj_strength * cj_bright,
@@ -178,4 +178,5 @@ def __call__(self, image: Union[Tensor, Image]) -> Tensor:
The transformed image.
"""
- return self.transform(image)
+ transformed: Tensor = self.transform(image)
+ return transformed
diff --git a/lightly/transforms/vicregl_transform.py b/lightly/transforms/vicregl_transform.py
index 70baf139d..335c03c34 100644
--- a/lightly/transforms/vicregl_transform.py
+++ b/lightly/transforms/vicregl_transform.py
@@ -1,4 +1,4 @@
-from typing import Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union
import torchvision.transforms as T
from PIL.Image import Image
@@ -123,7 +123,7 @@ def __init__(
cj_sat: float = 0.4,
cj_hue: float = 0.2,
random_gray_scale: float = 0.2,
- normalize: Union[None, dict] = IMAGENET_NORMALIZE,
+ normalize: Union[None, Dict[str, List[float]]] = IMAGENET_NORMALIZE,
):
global_transform = (
RandomResizedCropAndFlip(
@@ -189,7 +189,7 @@ def __init__(
cj_sat: float = 0.4,
cj_hue: float = 0.2,
random_gray_scale: float = 0.2,
- normalize: Union[None, dict] = IMAGENET_NORMALIZE,
+ normalize: Union[None, Dict[str, List[float]]] = IMAGENET_NORMALIZE,
):
color_jitter = T.ColorJitter(
brightness=cj_strength * cj_bright,
@@ -223,4 +223,5 @@ def __call__(self, image: Union[Tensor, Image]) -> Tensor:
Returns:
The transformed image.
"""
- return self.transform(image)
+ transformed: Tensor = self.transform(image)
+ return transformed
diff --git a/lightly/utils/benchmarking/linear_classifier.py b/lightly/utils/benchmarking/linear_classifier.py
index b47c94560..3fb76c6e6 100644
--- a/lightly/utils/benchmarking/linear_classifier.py
+++ b/lightly/utils/benchmarking/linear_classifier.py
@@ -1,9 +1,9 @@
-from typing import Dict, Tuple
+from typing import Any, Dict, List, Tuple, Union
from pytorch_lightning import LightningModule
from torch import Tensor
from torch.nn import CrossEntropyLoss, Linear, Module
-from torch.optim import SGD
+from torch.optim import SGD, Optimizer
from lightly.models.utils import activate_requires_grad, deactivate_requires_grad
from lightly.utils.benchmarking.topk import mean_topk_accuracy
@@ -94,9 +94,12 @@ def __init__(
def forward(self, images: Tensor) -> Tensor:
features = self.model.forward(images).flatten(start_dim=1)
- return self.classification_head(features)
+ output: Tensor = self.classification_head(features)
+ return output
- def shared_step(self, batch, batch_idx) -> Tuple[Tensor, Dict[int, Tensor]]:
+ def shared_step(
+ self, batch: Tuple[Tensor, ...], batch_idx: int
+ ) -> Tuple[Tensor, Dict[int, Tensor]]:
images, targets = batch[0], batch[1]
predictions = self.forward(images)
loss = self.criterion(predictions, targets)
@@ -104,7 +107,7 @@ def shared_step(self, batch, batch_idx) -> Tuple[Tensor, Dict[int, Tensor]]:
topk = mean_topk_accuracy(predicted_labels, targets, k=self.topk)
return loss, topk
- def training_step(self, batch, batch_idx) -> Tensor:
+ def training_step(self, batch: Tuple[Tensor, ...], batch_idx: int) -> Tensor:
loss, topk = self.shared_step(batch=batch, batch_idx=batch_idx)
batch_size = len(batch[1])
log_dict = {f"train_top{k}": acc for k, acc in topk.items()}
@@ -114,7 +117,7 @@ def training_step(self, batch, batch_idx) -> Tensor:
self.log_dict(log_dict, sync_dist=True, batch_size=batch_size)
return loss
- def validation_step(self, batch, batch_idx) -> Tensor:
+ def validation_step(self, batch: Tuple[Tensor, ...], batch_idx: int) -> Tensor:
loss, topk = self.shared_step(batch=batch, batch_idx=batch_idx)
batch_size = len(batch[1])
log_dict = {f"val_top{k}": acc for k, acc in topk.items()}
@@ -122,7 +125,9 @@ def validation_step(self, batch, batch_idx) -> Tensor:
self.log_dict(log_dict, prog_bar=True, sync_dist=True, batch_size=batch_size)
return loss
- def configure_optimizers(self):
+ def configure_optimizers(
+ self,
+ ) -> Tuple[List[Optimizer], List[Dict[str, Union[Any, str]]]]:
parameters = list(self.classification_head.parameters())
if not self.freeze_model:
parameters += self.model.parameters()
@@ -136,7 +141,7 @@ def configure_optimizers(self):
"scheduler": CosineWarmupScheduler(
optimizer=optimizer,
warmup_epochs=0,
- max_epochs=self.trainer.estimated_stepping_batches,
+ max_epochs=int(self.trainer.estimated_stepping_batches),
),
"interval": "step",
}
diff --git a/lightly/utils/benchmarking/metric_callback.py b/lightly/utils/benchmarking/metric_callback.py
index f632f3d2a..b635b2fbb 100644
--- a/lightly/utils/benchmarking/metric_callback.py
+++ b/lightly/utils/benchmarking/metric_callback.py
@@ -1,9 +1,11 @@
-from typing import Dict, List
+from typing import Dict, List, Union
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import Callback
from torch import Tensor
+MetricValue = Union[Tensor, float]
+
class MetricCallback(Callback):
"""Callback that collects log metrics from the LightningModule and stores them after
@@ -39,7 +41,7 @@ class MetricCallback(Callback):
>>> max_val_acc = max(metric_callback.val_metrics["val_acc"])
"""
- def __init__(self):
+ def __init__(self) -> None:
super().__init__()
self.train_metrics: Dict[str, List[float]] = {}
self.val_metrics: Dict[str, List[float]] = {}
@@ -58,7 +60,6 @@ def _append_metrics(
self, metrics_dict: Dict[str, List[float]], trainer: Trainer
) -> None:
for name, value in trainer.callback_metrics.items():
- # Only store scalar values.
if isinstance(value, float) or (
isinstance(value, Tensor) and value.numel() == 1
):
diff --git a/lightly/utils/bounding_box.py b/lightly/utils/bounding_box.py
index 7f22657a0..322b13b16 100644
--- a/lightly/utils/bounding_box.py
+++ b/lightly/utils/bounding_box.py
@@ -1,5 +1,6 @@
""" Bounding Box Utils """
+from __future__ import annotations
# Copyright (c) 2020. Lightly AG and its affiliates.
# All Rights Reserved
@@ -43,7 +44,7 @@ def __init__(
if clip_values:
- def clip_to_0_1(value):
+ def clip_to_0_1(value: float) -> float:
return min(1, max(0, value))
x0 = clip_to_0_1(x0)
@@ -75,7 +76,7 @@ def clip_to_0_1(value):
self.y1 = y1
@classmethod
- def from_x_y_w_h(cls, x: float, y: float, w: float, h: float):
+ def from_x_y_w_h(cls, x: float, y: float, w: float, h: float) -> BoundingBox:
"""Helper to convert from bounding box format with width and height.
Examples:
@@ -85,7 +86,9 @@ def from_x_y_w_h(cls, x: float, y: float, w: float, h: float):
return cls(x, y, x + w, y + h)
@classmethod
- def from_yolo_label(cls, x_center: float, y_center: float, w: float, h: float):
+ def from_yolo_label(
+ cls, x_center: float, y_center: float, w: float, h: float
+ ) -> BoundingBox:
"""Helper to convert from yolo label format
x_center, y_center, w, h --> x0, y0, x1, y1
@@ -102,16 +105,16 @@ def from_yolo_label(cls, x_center: float, y_center: float, w: float, h: float):
)
@property
- def width(self):
+ def width(self) -> float:
"""Returns the width of the bounding box relative to the image size."""
return self.x1 - self.x0
@property
- def height(self):
+ def height(self) -> float:
"""Returns the height of the bounding box relative to the image size."""
return self.y1 - self.y0
@property
- def area(self):
+ def area(self) -> float:
"""Returns the area of the bounding box relative to the area of the image."""
return self.width * self.height
diff --git a/lightly/utils/cropping/__init__.py b/lightly/utils/cropping/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/lightly/utils/dist.py b/lightly/utils/dist.py
index 64f9ed351..2c20750c6 100644
--- a/lightly/utils/dist.py
+++ b/lightly/utils/dist.py
@@ -68,3 +68,28 @@ def eye_rank(n: int, device: Optional[torch.device] = None) -> torch.Tensor:
diag_mask = torch.zeros((n, n * world_size()), dtype=torch.bool)
diag_mask[(rows, cols)] = True
return diag_mask
+
+
+def rank_zero_only(fn):
+ """Decorator that only runs the function on the process with rank 0.
+
+ Example:
+ >>> @rank_zero_only
+ >>> def print_rank_zero(message: str):
+ >>> print(message)
+ >>>
+ >>> print_rank_zero("Hello from rank 0!")
+
+ """
+
+ def wrapped(*args, **kwargs):
+ if rank() == 0:
+ return fn(*args, **kwargs)
+
+ return wrapped
+
+
+@rank_zero_only
+def print_rank_zero(*args, **kwargs) -> None:
+ """Equivalent to print, but only runs on the process with rank 0."""
+ print(*args, **kwargs)
diff --git a/lightly/utils/embeddings_2d.py b/lightly/utils/embeddings_2d.py
index 0b40a4bac..046f2d874 100644
--- a/lightly/utils/embeddings_2d.py
+++ b/lightly/utils/embeddings_2d.py
@@ -3,7 +3,12 @@
# Copyright (c) 2020. Lightly AG and its affiliates.
# All Rights Reserved
+from __future__ import annotations
+
+from typing import Optional, Tuple
+
import numpy as np
+from numpy.typing import NDArray
class PCA(object):
@@ -18,11 +23,11 @@ class PCA(object):
def __init__(self, n_components: int = 2, eps: float = 1e-10):
self.n_components = n_components
- self.mean = None
- self.w = None
+ self.mean: Optional[NDArray[np.float32]] = None
+ self.w: Optional[NDArray[np.float32]] = None
self.eps = eps
- def fit(self, X: np.ndarray):
+ def fit(self, X: NDArray[np.float32]) -> PCA:
"""Fits PCA to data in X.
Args:
@@ -35,6 +40,7 @@ def fit(self, X: np.ndarray):
"""
X = X.astype(np.float32)
self.mean = X.mean(axis=0)
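+        # The assert narrows self.mean from Optional for mypy.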
+ assert self.mean is not None
X = X - self.mean + self.eps
cov = np.cov(X.T) / X.shape[0]
v, w = np.linalg.eig(cov)
@@ -43,7 +49,7 @@ def fit(self, X: np.ndarray):
self.w = w
return self
- def transform(self, X: np.ndarray):
+ def transform(self, X: NDArray[np.float32]) -> NDArray[np.float32]:
"""Uses PCA to transform data in X.
Args:
@@ -53,13 +59,22 @@ def transform(self, X: np.ndarray):
Returns:
Numpy array of n x p datapoints where p <= d.
+ Raises:
+ ValueError: If PCA was not fitted before.
"""
+ if self.mean is None or self.w is None:
+ raise ValueError("PCA not fitted yet. Call fit() before transform().")
X = X.astype(np.float32)
X = X - self.mean + self.eps
- return X.dot(self.w)[:, : self.n_components]
+ transformed: NDArray[np.float32] = X.dot(self.w)[:, : self.n_components]
+ return np.asarray(transformed)
-def fit_pca(embeddings: np.ndarray, n_components: int = 2, fraction: float = None):
+def fit_pca(
+ embeddings: NDArray[np.float32],
+ n_components: int = 2,
+ fraction: Optional[float] = None,
+) -> PCA:
"""Fits PCA to randomly selected subset of embeddings.
For large datasets, it can be unfeasible to perform PCA on the whole data.
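+
+    Example (an illustrative sketch; the array shapes are assumptions):
+
+        >>> import numpy as np
+        >>> embeddings = np.random.rand(1000, 32).astype(np.float32)
+        >>> pca = fit_pca(embeddings, n_components=2)
+        >>> embeddings_2d = pca.transform(embeddings)  # shape (1000, 2)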
diff --git a/lightly/utils/hipify.py b/lightly/utils/hipify.py
index fd2d7097b..44dc32aed 100644
--- a/lightly/utils/hipify.py
+++ b/lightly/utils/hipify.py
@@ -1,6 +1,6 @@
import copy
import warnings
-from typing import Type
+from typing import Type, Union
class bcolors:
@@ -14,15 +14,19 @@ class bcolors:
UNDERLINE = "\033[4m"
-def _custom_formatwarning(msg, *args, **kwargs):
+def _custom_formatwarning(
+ message: Union[str, Warning],
+ category: Type[Warning],
+ filename: str,
+ lineno: int,
+ line: Union[str, None] = None,
+) -> str:
# ignore everything except the message
- return f"{bcolors.WARNING}{msg}{bcolors.WARNING}\n"
+ return f"{bcolors.WARNING}{message}{bcolors.WARNING}\n"
-def print_as_warning(message: str, warning_class: Type[Warning] = UserWarning):
+def print_as_warning(message: str, warning_class: Type[Warning] = UserWarning) -> None:
old_format = copy.copy(warnings.formatwarning)
-
warnings.formatwarning = _custom_formatwarning
warnings.warn(message, warning_class)
-
warnings.formatwarning = old_format
diff --git a/lightly/utils/io.py b/lightly/utils/io.py
index 00ca63a1c..74240bf6e 100644
--- a/lightly/utils/io.py
+++ b/lightly/utils/io.py
@@ -7,35 +7,13 @@
import json
import re
from itertools import compress
-from typing import Dict, List, Tuple
+from typing import Any, Dict, List, Tuple, Union
import numpy as np
+from numpy.typing import NDArray
-INVALID_FILENAME_CHARACTERS = [","]
-
-def _is_valid_filename(filename: str) -> bool:
- """Returns False if the filename is misformatted."""
- for character in INVALID_FILENAME_CHARACTERS:
- if character in filename:
- return False
- return True
-
-
-def check_filenames(filenames: List[str]):
- """Raises an error if one of the filenames is misformatted
-
- Args:
- filenames:
- A list of string being filenames
-
- """
- invalid_filenames = [f for f in filenames if not _is_valid_filename(f)]
- if len(invalid_filenames) > 0:
- raise ValueError(f"Invalid filename(s): {invalid_filenames}")
-
-
-def check_embeddings(path: str, remove_additional_columns: bool = False):
+def check_embeddings(path: str, remove_additional_columns: bool = False) -> None:
"""Raises an error if the embeddings csv file has not the correct format
Use this check whenever you want to upload an embedding to the Lightly
@@ -118,8 +96,8 @@ def check_embeddings(path: str, remove_additional_columns: bool = False):
def save_embeddings(
- path: str, embeddings: np.ndarray, labels: List[int], filenames: List[str]
-):
+ path: str, embeddings: NDArray[np.float64], labels: List[int], filenames: List[str]
+) -> None:
"""Saves embeddings in a csv file in a Lightly compatible format.
Creates a csv file at the location specified by path and saves embeddings,
@@ -146,7 +124,6 @@ def save_embeddings(
>>> labels,
>>> filenames)
"""
- check_filenames(filenames)
n_embeddings = len(embeddings)
n_filenames = len(filenames)
@@ -167,7 +144,7 @@ def save_embeddings(
writer.writerow([filename] + list(embedding) + [str(label)])
-def load_embeddings(path: str):
+def load_embeddings(path: str) -> Tuple[NDArray[np.float64], List[int], List[str]]:
"""Loads embeddings from a csv file in a Lightly compatible format.
Args:
@@ -202,15 +179,13 @@ def load_embeddings(path: str):
# read embeddings
embeddings.append(row[1:-1])
- check_filenames(filenames)
-
- embeddings = np.array(embeddings).astype(np.float32)
- return embeddings, labels, filenames
+ embedding_array = np.array(embeddings).astype(np.float64)
+ return embedding_array, labels, filenames
def load_embeddings_as_dict(
path: str, embedding_name: str = "default", return_all: bool = False
-):
+) -> Union[Any, Tuple[Any, NDArray[np.float64], List[int], List[str]]]:
"""Loads embeddings from csv and store it in a dictionary for transfer.
Loads embeddings to a dictionary which can be serialized and sent to the
@@ -245,10 +220,13 @@ def load_embeddings_as_dict(
embeddings, labels, filenames = load_embeddings(path)
# build dictionary
- data = {"embeddingName": embedding_name, "embeddings": []}
- for embedding, filename, label in zip(embeddings, filenames, labels):
- item = {"fileName": filename, "value": embedding.tolist(), "label": label}
- data["embeddings"].append(item)
+ data = {
+ "embeddingName": embedding_name,
+ "embeddings": [
+ {"fileName": filename, "value": embedding.tolist(), "label": label}
+ for embedding, filename, label in zip(embeddings, filenames, labels)
+ ],
+ }
# return embeddings along with dictionary
if return_all:
@@ -258,7 +236,10 @@ def load_embeddings_as_dict(
class COCO_ANNOTATION_KEYS:
- """Enum of coco annotation keys complemented with a key for custom metadata."""
+ """Enum of coco annotation keys complemented with a key for custom metadata.
+
+ :meta private: # Skip docstring generation
+ """
# image keys
images: str = "images"
@@ -270,7 +251,9 @@ class COCO_ANNOTATION_KEYS:
custom_metadata_image_id: str = "image_id"
-def format_custom_metadata(custom_metadata: List[Tuple[str, Dict]]):
+def format_custom_metadata(
+ custom_metadata: List[Tuple[str, Any]]
+) -> Dict[str, List[Any]]:
"""Transforms custom metadata into a format which can be handled by Lightly.
Args:
@@ -292,8 +275,9 @@ def format_custom_metadata(custom_metadata: List[Tuple[str, Dict]]):
>>> > 'metadata': [{'image_id': 0, 'number_of_people': 1}, {'image_id': 1, 'number_of_people': 3}]
>>> > }
+ :meta private: # Skip docstring generation
"""
- formatted = {
+ formatted: Dict[str, List[Any]] = {
COCO_ANNOTATION_KEYS.images: [],
COCO_ANNOTATION_KEYS.custom_metadata: [],
}
@@ -315,7 +299,7 @@ def format_custom_metadata(custom_metadata: List[Tuple[str, Dict]]):
return formatted
-def save_custom_metadata(path: str, custom_metadata: List[Tuple[str, Dict]]):
+def save_custom_metadata(path: str, custom_metadata: List[Tuple[str, Any]]) -> None:
"""Saves custom metadata in a .json.
Args:
@@ -324,6 +308,7 @@ def save_custom_metadata(path: str, custom_metadata: List[Tuple[str, Dict]]):
custom_metadata:
List of tuples (filename, metadata) where metadata is a dictionary.
+ :meta private: # Skip docstring generation
"""
formatted = format_custom_metadata(custom_metadata)
with open(path, "w") as f:
@@ -333,7 +318,7 @@ def save_custom_metadata(path: str, custom_metadata: List[Tuple[str, Dict]]):
def save_tasks(
path: str,
tasks: List[str],
-):
+) -> None:
"""Saves a list of prediction task names in the right format.
Args:
@@ -347,7 +332,7 @@ def save_tasks(
json.dump(tasks, f)
-def save_schema(path: str, task_type: str, ids: List[int], names: List[str]):
+def save_schema(path: str, task_type: str, ids: List[int], names: List[str]) -> None:
"""Saves a prediction schema in the right format.
Args:
diff --git a/lightly/utils/lars.py b/lightly/utils/lars.py
index af14ff028..063149d36 100644
--- a/lightly/utils/lars.py
+++ b/lightly/utils/lars.py
@@ -1,5 +1,8 @@
+from typing import Any, Callable, Dict, Optional, Union
+
import torch
-from torch.optim.optimizer import Optimizer, required
+from torch import Tensor
+from torch.optim.optimizer import Optimizer, required # type: ignore[attr-defined]
class LARS(Optimizer):
@@ -65,7 +68,7 @@ class LARS(Optimizer):
def __init__(
self,
- params,
+ params: Any,
lr: float = required,
momentum: float = 0,
dampening: float = 0,
@@ -95,14 +98,14 @@ def __init__(
super().__init__(params, defaults)
- def __setstate__(self, state):
+ def __setstate__(self, state: Dict[str, Any]) -> None:
super().__setstate__(state)
for group in self.param_groups:
group.setdefault("nesterov", False)
@torch.no_grad()
- def step(self, closure=None):
+ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
"""Performs a single optimization step.
Args:
diff --git a/lightly/utils/reordering.py b/lightly/utils/reordering.py
index 22311d8b6..baaf286d3 100644
--- a/lightly/utils/reordering.py
+++ b/lightly/utils/reordering.py
@@ -1,7 +1,12 @@
-from typing import List, Sized
+from typing import List, Sequence, TypeVar
+_K = TypeVar("_K")
+_V = TypeVar("_V")
-def sort_items_by_keys(keys: List[any], items: List[any], sorted_keys: List[any]):
+
+def sort_items_by_keys(
+ keys: Sequence[_K], items: Sequence[_V], sorted_keys: Sequence[_K]
+) -> List[_V]:
"""Sorts the items in the same order as the sorted keys.
Args:
diff --git a/lightly/utils/scheduler.py b/lightly/utils/scheduler.py
index 0f200146b..a698e09ac 100644
--- a/lightly/utils/scheduler.py
+++ b/lightly/utils/scheduler.py
@@ -1,14 +1,19 @@
import warnings
+from typing import Optional
import numpy as np
import torch
def cosine_schedule(
- step: float, max_steps: float, start_value: float, end_value: float
+ step: int,
+ max_steps: int,
+ start_value: float,
+ end_value: float,
+ period: Optional[int] = None,
) -> float:
- """
- Use cosine decay to gradually modify start_value to reach target end_value during iterations.
+ """Use cosine decay to gradually modify start_value to reach target end_value during
+ iterations.
Args:
step:
@@ -19,6 +24,9 @@ def cosine_schedule(
Starting value.
end_value:
Target value.
+ period (optional):
+ The number of steps over which the cosine function completes a full cycle.
+ If not provided, it defaults to max_steps.
Returns:
Cosine decay value.
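+
+    Example (illustrative; the values follow from the cosine decay formula):
+        >>> cosine_schedule(step=0, max_steps=10, start_value=1.0, end_value=0.0)
+        >>> 1.0
+        >>> cosine_schedule(step=2, max_steps=10, start_value=1.0, end_value=0.0, period=4)
+        >>> 0.0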
@@ -28,13 +36,21 @@ def cosine_schedule(
raise ValueError("Current step number can't be negative")
if max_steps < 1:
raise ValueError("Total step number must be >= 1")
- if step > max_steps:
+ if period is None and step > max_steps:
warnings.warn(
f"Current step number {step} exceeds max_steps {max_steps}.",
category=RuntimeWarning,
)
+ if period is not None and period <= 0:
+ raise ValueError("Period must be >= 1")
- if max_steps == 1:
+ decay: float
+ if period is not None: # "cycle" based on period, if provided
+ decay = (
+ end_value
+ - (end_value - start_value) * (np.cos(2 * np.pi * step / period) + 1) / 2
+ )
+ elif max_steps == 1:
# Avoid division by zero
decay = end_value
elif step == max_steps:
@@ -52,8 +68,7 @@ def cosine_schedule(
class CosineWarmupScheduler(torch.optim.lr_scheduler.LambdaLR):
- """
- Cosine warmup scheduler for learning rate.
+ """Cosine warmup scheduler for learning rate.
Args:
optimizer:
@@ -70,22 +85,28 @@ class CosineWarmupScheduler(torch.optim.lr_scheduler.LambdaLR):
Target learning rate scale. Default: 0.001
verbose:
If True, prints a message to stdout for each update. Default: False.
+
+    Note: The `epoch` arguments do not necessarily have to be epochs. Any step or index
+    can be used. The naming follows the PyTorch convention to use `epoch` for the steps
+    in the scheduler.
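+
+    Example (an illustrative sketch; `model` is assumed to be any torch.nn.Module):
+
+        >>> optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
+        >>> scheduler = CosineWarmupScheduler(optimizer, warmup_epochs=10, max_epochs=100)
+        >>> for epoch in range(100):
+        >>>     # ... run one training epoch ...
+        >>>     scheduler.step()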
"""
def __init__(
self,
optimizer: torch.optim.Optimizer,
- warmup_epochs: float,
- max_epochs: float,
- last_epoch: float = -1,
+ warmup_epochs: int,
+ max_epochs: int,
+ last_epoch: int = -1,
start_value: float = 1.0,
end_value: float = 0.001,
+ period: Optional[int] = None,
verbose: bool = False,
) -> None:
self.warmup_epochs = warmup_epochs
self.max_epochs = max_epochs
self.start_value = start_value
self.end_value = end_value
+ self.period = period
super().__init__(
optimizer=optimizer,
lr_lambda=self.scale_lr,
@@ -106,7 +127,15 @@ def scale_lr(self, epoch: int) -> float:
"""
if epoch < self.warmup_epochs:
- return (epoch + 1) / self.warmup_epochs
+ return self.start_value * (epoch + 1) / self.warmup_epochs
+ elif self.period is not None:
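+            # max_steps is unused by cosine_schedule when a period is given,
+            # so 1 is passed as a placeholder.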
+ return cosine_schedule(
+ step=epoch - self.warmup_epochs,
+ max_steps=1,
+ start_value=self.start_value,
+ end_value=self.end_value,
+ period=self.period,
+ )
else:
return cosine_schedule(
step=epoch - self.warmup_epochs,
diff --git a/lightly/utils/version_compare.py b/lightly/utils/version_compare.py
index 898407cfb..05c7093f5 100644
--- a/lightly/utils/version_compare.py
+++ b/lightly/utils/version_compare.py
@@ -4,7 +4,7 @@
# All Rights Reserved
-def version_compare(v0: str, v1: str):
+def version_compare(v0: str, v1: str) -> int:
"""Returns 1 if version of v0 is larger than v1 and -1 otherwise
Use this method to compare Python package versions and see which one is
@@ -16,14 +16,14 @@ def version_compare(v0: str, v1: str):
>>> version_compare('1.2.0', '1.1.2')
>>> 1
"""
- v0 = [int(n) for n in v0.split(".")][::-1]
- v1 = [int(n) for n in v1.split(".")][::-1]
- if len(v0) != 3 or len(v1) != 3:
+ v0_parsed = [int(n) for n in v0.split(".")][::-1]
+ v1_parsed = [int(n) for n in v1.split(".")][::-1]
+ if len(v0_parsed) != 3 or len(v1_parsed) != 3:
raise ValueError(
f"Length of version strings is not 3 (expected pattern `x.y.z`) but is "
- f"{v0} and {v1}."
+ f"{v0_parsed} and {v1_parsed}."
)
- pairs = list(zip(v0, v1))[::-1]
+ pairs = list(zip(v0_parsed, v1_parsed))[::-1]
for x, y in pairs:
if x < y:
return -1
diff --git a/mypy.ini b/mypy.ini
new file mode 100644
index 000000000..a4c18e97f
--- /dev/null
+++ b/mypy.ini
@@ -0,0 +1,219 @@
+# Global options:
+
+[mypy]
+ignore_missing_imports = True
+python_version = 3.10
+warn_unused_configs = True
+strict_equality = True
+
+# Disallow dynamic typing
+disallow_any_decorated = True
+# TODO(Philipp, 09/23): Remove me!
+# disallow_any_explicit = True
+disallow_any_generics = True
+disallow_subclassing_any = True
+
+# Disallow untyped definitions
+disallow_untyped_defs = True
+disallow_incomplete_defs = True
+check_untyped_defs = True
+disallow_untyped_decorators = True
+
+# None and optional handling
+no_implicit_optional = True
+strict_optional = True
+
+# Configuring warnings
+warn_unused_ignores = True
+warn_no_return = True
+warn_return_any = True
+warn_redundant_casts = True
+warn_unreachable = True
+
+# Print format
+show_error_codes = True
+show_error_context = True
+
+# Plugins
+plugins = numpy.typing.mypy_plugin
+
+# Excludes
+# TODO(Philipp, 09/23): Remove these one by one (start with 300 files).
+exclude = (?x)(
+ lightly/cli/version_cli.py |
+ lightly/cli/crop_cli.py |
+ lightly/cli/serve_cli.py |
+ lightly/cli/embed_cli.py |
+ lightly/cli/lightly_cli.py |
+ lightly/cli/download_cli.py |
+ lightly/cli/config/get_config.py |
+ lightly/cli/train_cli.py |
+ lightly/cli/_cli_simclr.py |
+ lightly/cli/_helpers.py |
+ lightly/loss/ntx_ent_loss.py |
+ lightly/loss/vicreg_loss.py |
+ lightly/loss/tico_loss.py |
+ lightly/loss/pmsn_loss.py |
+ lightly/loss/swav_loss.py |
+ lightly/loss/negative_cosine_similarity.py |
+ lightly/loss/hypersphere_loss.py |
+ lightly/loss/msn_loss.py |
+ lightly/loss/dino_loss.py |
+ lightly/loss/sym_neg_cos_sim_loss.py |
+ lightly/loss/vicregl_loss.py |
+ lightly/loss/dcl_loss.py |
+ lightly/loss/regularizer/co2.py |
+ lightly/loss/barlow_twins_loss.py |
+ lightly/data/lightly_subset.py |
+ lightly/data/dataset.py |
+ lightly/data/collate.py |
+ lightly/data/_image.py |
+ lightly/data/_helpers.py |
+ lightly/data/_image_loaders.py |
+ lightly/data/_video.py |
+ lightly/data/_utils.py |
+ lightly/data/multi_view_collate.py |
+ lightly/core.py |
+ lightly/api/api_workflow_compute_worker.py |
+ lightly/api/api_workflow_predictions.py |
+ lightly/api/download.py |
+ lightly/api/api_workflow_export.py |
+ lightly/api/api_workflow_download_dataset.py |
+ lightly/api/bitmask.py |
+ lightly/api/_version_checking.py |
+ lightly/api/serve.py |
+ lightly/api/patch.py |
+ lightly/api/swagger_api_client.py |
+ lightly/api/api_workflow_collaboration.py |
+ lightly/api/utils.py |
+ lightly/api/api_workflow_datasets.py |
+ lightly/api/api_workflow_selection.py |
+ lightly/api/swagger_rest_client.py |
+ lightly/api/api_workflow_datasources.py |
+ lightly/api/api_workflow_upload_embeddings.py |
+ lightly/api/api_workflow_client.py |
+ lightly/api/api_workflow_upload_metadata.py |
+ lightly/api/api_workflow_tags.py |
+ lightly/api/api_workflow_artifacts.py |
+ lightly/utils/cropping/crop_image_by_bounding_boxes.py |
+ lightly/utils/cropping/read_yolo_label_file.py |
+ lightly/utils/debug.py |
+ lightly/utils/dist.py |
+ lightly/utils/benchmarking/benchmark_module.py |
+ lightly/utils/benchmarking/knn_classifier.py |
+ lightly/utils/benchmarking/online_linear_classifier.py |
+ lightly/models/modules/masked_autoencoder.py |
+ lightly/models/modules/ijepa.py |
+ lightly/models/utils.py |
+ tests/cli/test_cli_version.py |
+ tests/cli/test_cli_magic.py |
+ tests/cli/test_cli_crop.py |
+ tests/cli/test_cli_download.py |
+ tests/cli/test_cli_train.py |
+ tests/cli/test_cli_get_lighty_config.py |
+ tests/cli/test_cli_embed.py |
+ tests/UNMOCKED_end2end_tests/delete_datasets_test_unmocked_cli.py |
+ tests/UNMOCKED_end2end_tests/create_custom_metadata_from_input_dir.py |
+ tests/UNMOCKED_end2end_tests/scripts_for_reproducing_problems/test_api_latency.py |
+ tests/loss/test_NegativeCosineSimilarity.py |
+ tests/loss/test_MSNLoss.py |
+ tests/loss/test_DINOLoss.py |
+ tests/loss/test_VICRegLLoss.py |
+ tests/loss/test_CO2Regularizer.py |
+ tests/loss/test_DCLLoss.py |
+ tests/loss/test_barlow_twins_loss.py |
+ tests/loss/test_SymNegCosineSimilarityLoss.py |
+ tests/loss/test_NTXentLoss.py |
+ tests/loss/test_MemoryBank.py |
+ tests/loss/test_TicoLoss.py |
+ tests/loss/test_VICRegLoss.py |
+ tests/loss/test_PMSNLoss.py |
+ tests/loss/test_HyperSphere.py |
+ tests/loss/test_SwaVLoss.py |
+ tests/core/test_Core.py |
+ tests/data/test_multi_view_collate.py |
+ tests/data/test_data_collate.py |
+ tests/data/test_VideoDataset.py |
+ tests/data/test_LightlySubset.py |
+ tests/data/test_LightlyDataset.py |
+ tests/embedding/test_callbacks.py |
+ tests/embedding/test_embedding.py |
+ tests/api/test_serve.py |
+ tests/api/test_swagger_rest_client.py |
+ tests/api/test_rest_parser.py |
+ tests/api/test_utils.py |
+ tests/api/benchmark_video_download.py |
+ tests/api/test_BitMask.py |
+ tests/api/test_patch.py |
+ tests/api/test_download.py |
+ tests/api/test_version_checking.py |
+ tests/api/test_swagger_api_client.py |
+ tests/utils/test_debug.py |
+ tests/utils/benchmarking/test_benchmark_module.py |
+ tests/utils/benchmarking/test_topk.py |
+ tests/utils/benchmarking/test_online_linear_classifier.py |
+ tests/utils/benchmarking/test_knn_classifier.py |
+ tests/utils/benchmarking/test_knn.py |
+ tests/utils/benchmarking/test_linear_classifier.py |
+ tests/utils/benchmarking/test_metric_callback.py |
+ tests/utils/test_dist.py |
+ tests/models/test_ModelsSimSiam.py |
+ tests/models/modules/test_masked_autoencoder.py |
+ tests/models/test_ModelsSimCLR.py |
+ tests/models/test_ModelUtils.py |
+ tests/models/test_ModelsNNCLR.py |
+ tests/models/test_ModelsMoCo.py |
+ tests/models/test_ProjectionHeads.py |
+ tests/models/test_ModelsBYOL.py |
+ tests/conftest.py |
+ tests/api_workflow/test_api_workflow_selection.py |
+ tests/api_workflow/test_api_workflow_datasets.py |
+ tests/api_workflow/mocked_api_workflow_client.py |
+ tests/api_workflow/test_api_workflow_compute_worker.py |
+ tests/api_workflow/test_api_workflow_artifacts.py |
+ tests/api_workflow/test_api_workflow_download_dataset.py |
+ tests/api_workflow/utils.py |
+ tests/api_workflow/test_api_workflow_client.py |
+ tests/api_workflow/test_api_workflow_export.py |
+ tests/api_workflow/test_api_workflow_datasources.py |
+ tests/api_workflow/test_api_workflow_tags.py |
+ tests/api_workflow/test_api_workflow_upload_custom_metadata.py |
+ tests/api_workflow/test_api_workflow_upload_embeddings.py |
+ tests/api_workflow/test_api_workflow_collaboration.py |
+ tests/api_workflow/test_api_workflow_predictions.py |
+ tests/api_workflow/test_api_workflow.py |
+ # Let's not type check deprecated active learning:
+ lightly/active_learning |
+    # Let's not type check deprecated models:
+ lightly/models/simclr.py |
+ lightly/models/moco.py |
+ lightly/models/barlowtwins.py |
+ lightly/models/nnclr.py |
+ lightly/models/simsiam.py |
+ lightly/models/byol.py )
+
+# Ignore imports from untyped modules.
+[mypy-lightly.api.*]
+follow_imports = skip
+
+[mypy-lightly.cli.*]
+follow_imports = skip
+
+[mypy-lightly.data.*]
+follow_imports = skip
+
+[mypy-lightly.loss.*]
+follow_imports = skip
+
+[mypy-lightly.models.*]
+follow_imports = skip
+
+[mypy-lightly.utils.benchmarking.*]
+follow_imports = skip
+
+[mypy-tests.api_workflow.*]
+follow_imports = skip
+
+# Ignore errors in auto generated code.
+[mypy-lightly.openapi_generated.*]
+ignore_errors = True
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index 18cdd7197..904100d3d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,3 +1,6 @@
+# pyproject.toml is currently only used to configure development tools.
+# Configurations for the lightly package are in setup.py.
+
[tool.black]
extend-exclude = "lightly/openapi_generated/.*"
diff --git a/requirements/dev.txt b/requirements/dev.txt
index dc74efe15..d5bc31f7b 100644
--- a/requirements/dev.txt
+++ b/requirements/dev.txt
@@ -22,3 +22,4 @@ torchmetrics
lightning-bolts # for LARS optimizer
black==23.1.0 # frozen version to avoid differences between CI and local dev machines
isort==5.11.5 # frozen version to avoid differences between CI and local dev machines
+mypy==1.4.1 # frozen version to avoid differences between CI and local dev machines
diff --git a/requirements/openapi.txt b/requirements/openapi.txt
index d52723c7a..74ede174a 100644
--- a/requirements/openapi.txt
+++ b/requirements/openapi.txt
@@ -1,4 +1,5 @@
python_dateutil >= 2.5.3
+setuptools >= 21.0.0
urllib3 >= 1.25.3
pydantic >= 1.10.5, < 2
aenum >= 3.1.11
diff --git a/setup.cfg b/setup.cfg
deleted file mode 100644
index a684c9b94..000000000
--- a/setup.cfg
+++ /dev/null
@@ -1,3 +0,0 @@
-[metadata]
-# This includes the license file(s) in the wheel.
-license_files = LICENSE.txt
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 94d96ef1f..7caad8ce9 100644
--- a/setup.py
+++ b/setup.py
@@ -1,36 +1,19 @@
import os
-import sys
+from pathlib import Path
+from typing import List
import setuptools
-try:
- import builtins
-except ImportError:
- import __builtin__ as builtins
+_PATH_ROOT = Path(os.path.dirname(__file__))
-PATH_ROOT = PATH_ROOT = os.path.dirname(__file__)
-builtins.__LIGHTLY_SETUP__ = True
-import lightly
-
-
-def load_description(path_dir=PATH_ROOT, filename="DOCS.md"):
- """Load long description from readme in the path_dir/ directory"""
- with open(os.path.join(path_dir, filename)) as f:
- long_description = f.read()
- return long_description
-
-
-def load_requirements(path_dir=PATH_ROOT, filename="base.txt", comment_char="#"):
- """From pytorch-lightning repo: https://github.com/PyTorchLightning/pytorch-lightning.
- Load requirements from text file in the path_dir/requirements/ directory.
-
- """
- with open(os.path.join(path_dir, "requirements", filename), "r") as file:
+def load_requirements(filename: str, comment_char: str = "#") -> List[str]:
+ """Load requirements from text file in the requirements directory."""
+ with (_PATH_ROOT / "requirements" / filename).open() as file:
lines = [ln.strip() for ln in file.readlines()]
reqs = []
for ln in lines:
- # filer all comments
+ # filter all comments
if comment_char in ln:
ln = ln[: ln.index(comment_char)].strip()
# skip directly installed dependencies
@@ -41,28 +24,44 @@ def load_requirements(path_dir=PATH_ROOT, filename="base.txt", comment_char="#")
return reqs
+def load_version() -> str:
+ """Load version from the lightly/__init__.py file.
+
+ Note: We do not want to get the version by accessing `lightly.__version__` because
+ it would require importing `lightly`. Importing `lightly` in setup.py breaks the
+ installation process as the import has side effects and requires dependencies to be
+ installed. As dependencies are not yet available during installation, the `lightly`
+ import fails.
+ """
+ version_filepath = _PATH_ROOT / "lightly" / "__init__.py"
+ with version_filepath.open() as file:
+ for line in file.readlines():
+ if line.startswith("__version__"):
+ version = line.split("=")[-1].strip().strip('"')
+ return version
+ raise RuntimeError("Unable to find version string in '{version_filepath}'.")
+
+
if __name__ == "__main__":
name = "lightly"
- version = lightly.__version__
- description = lightly.__doc__
-
- author = "Philipp Wirth & Igor Susmelj"
- author_email = "philipp@lightly.ai"
+ version = load_version()
+ author = "Lightly Team"
+ author_email = "team@lightly.ai"
description = "A deep learning package for self-supervised learning"
+ long_description = (_PATH_ROOT / "README.md").read_text()
entry_points = {
"console_scripts": [
"lightly-crop = lightly.cli.crop_cli:entry",
- "lightly-train = lightly.cli.train_cli:entry",
+ "lightly-download = lightly.cli.download_cli:entry",
"lightly-embed = lightly.cli.embed_cli:entry",
"lightly-magic = lightly.cli.lightly_cli:entry",
- "lightly-download = lightly.cli.download_cli:entry",
+ "lightly-serve = lightly.cli.serve_cli:entry",
+ "lightly-train = lightly.cli.train_cli:entry",
"lightly-version = lightly.cli.version_cli:entry",
]
}
- long_description = load_description()
-
python_requires = ">=3.6"
base_requires = load_requirements(filename="base.txt")
openapi_requires = load_requirements(filename="openapi.txt")
@@ -78,28 +77,7 @@ def load_requirements(path_dir=PATH_ROOT, filename="base.txt", comment_char="#")
"all": dev_requires + video_requires,
}
- packages = [
- "lightly",
- "lightly.api",
- "lightly.cli",
- "lightly.cli.config",
- "lightly.data",
- "lightly.embedding",
- "lightly.loss",
- "lightly.loss.regularizer",
- "lightly.models",
- "lightly.models.modules",
- "lightly.transforms",
- "lightly.utils",
- "lightly.utils.benchmarking",
- "lightly.utils.cropping",
- "lightly.active_learning",
- "lightly.active_learning.config",
- "lightly.openapi_generated",
- "lightly.openapi_generated.swagger_client",
- "lightly.openapi_generated.swagger_client.api",
- "lightly.openapi_generated.swagger_client.models",
- ]
+ packages = setuptools.find_packages(include=["lightly*"])
project_urls = {
"Homepage": "https://www.lightly.ai",
@@ -115,8 +93,9 @@ def load_requirements(path_dir=PATH_ROOT, filename="base.txt", comment_char="#")
"Intended Audience :: Education",
"Intended Audience :: Science/Research",
"Topic :: Scientific/Engineering",
- "Topic :: Scientific/Engineering :: Mathematics",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
+ "Topic :: Scientific/Engineering :: Image Processing",
+ "Topic :: Scientific/Engineering :: Mathematics",
"Topic :: Software Development",
"Topic :: Software Development :: Libraries",
"Topic :: Software Development :: Libraries :: Python Modules",
@@ -125,6 +104,8 @@ def load_requirements(path_dir=PATH_ROOT, filename="base.txt", comment_char="#")
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
+ "Programming Language :: Python :: 3.10",
+ "Programming Language :: Python :: 3.11",
"License :: OSI Approved :: MIT License",
]
@@ -136,6 +117,7 @@ def load_requirements(path_dir=PATH_ROOT, filename="base.txt", comment_char="#")
description=description,
entry_points=entry_points,
license="MIT",
+ license_files=["LICENSE.txt"],
long_description=long_description,
long_description_content_type="text/markdown",
setup_requires=setup_requires,
diff --git a/tests/api/test_serve.py b/tests/api/test_serve.py
new file mode 100644
index 000000000..1d056199d
--- /dev/null
+++ b/tests/api/test_serve.py
@@ -0,0 +1,18 @@
+from pathlib import Path
+
+from lightly.api import serve
+
+
+def test__translate_path(tmp_path: Path) -> None:
+ tmp_file = tmp_path / "hello/world.txt"
+ assert serve._translate_path(path="/hello/world.txt", directories=[]) == ""
+ assert serve._translate_path(path="/hello/world.txt", directories=[tmp_path]) == ""
+    tmp_file.parent.mkdir(parents=True, exist_ok=True)
+ tmp_file.touch()
+ assert serve._translate_path(
+ path="/hello/world.txt", directories=[tmp_path]
+ ) == str(tmp_file)
+ assert serve._translate_path(
+ path="/world.txt",
+ directories=[tmp_path / "hi", tmp_path / "hello"],
+ ) == str(tmp_file)
diff --git a/tests/api/test_version_checking.py b/tests/api/test_version_checking.py
index df885c785..5a72d8344 100644
--- a/tests/api/test_version_checking.py
+++ b/tests/api/test_version_checking.py
@@ -1,75 +1,87 @@
-import sys
+import os
import time
-import unittest
-
-import lightly
-from lightly.api.version_checking import (
- LightlyAPITimeoutException,
- get_latest_version,
- get_minimum_compatible_version,
- is_compatible_version,
- is_latest_version,
- pretty_print_latest_version,
-)
-from tests.api_workflow.mocked_api_workflow_client import MockedVersioningApi
-
-
-class TestVersionChecking(unittest.TestCase):
- def setUp(self) -> None:
- lightly.api.version_checking.VersioningApi = MockedVersioningApi
-
- def test_get_latest_version(self):
- get_latest_version("1.2.3")
-
- def test_get_minimum_compatible_version(self):
- get_minimum_compatible_version()
-
- def test_is_latest_version(self) -> None:
- assert is_latest_version("1.2.8")
- assert not is_latest_version("1.2.7")
- assert not is_latest_version("1.1.8")
- assert not is_latest_version("0.2.8")
-
- def test_is_compatible_version(self) -> None:
- assert is_compatible_version("1.2.1")
- assert not is_compatible_version("1.2.0")
- assert not is_compatible_version("1.1.9")
- assert not is_compatible_version("0.2.1")
-
- def test_pretty_print(self):
- pretty_print_latest_version(current_version="curr", latest_version="1.1.1")
-
- def test_version_check_timout_mocked(self):
- """
- We cannot check for other errors as we don't know whether the
- current LIGHTLY_SERVER_URL is
- - unreachable (error in < 1 second)
- - causing a timeout and thus raising a LightlyAPITimeoutException
- - reachable (success in < 1 second
-
- Thus this only checks that the actual lightly.do_version_check()
- with needing >1s internally causes a LightlyAPITimeoutException
- """
- try:
- old_get_versioning_api = lightly.api.version_checking.get_versioning_api
-
- def mocked_get_versioning_api_timeout():
- time.sleep(10)
- print("This line should never be reached, calling sys.exit()")
- sys.exit()
-
- lightly.api.version_checking.get_versioning_api = (
- mocked_get_versioning_api_timeout
- )
-
- start_time = time.time()
-
- with self.assertRaises(LightlyAPITimeoutException):
- is_latest_version(lightly.__version__)
-
- duration = time.time() - start_time
-
- self.assertLess(duration, 1.5)
-
- finally:
- lightly.api.version_checking.get_versioning_api = old_get_versioning_api
+
+import pytest
+from pytest_mock import MockerFixture
+from urllib3.exceptions import MaxRetryError
+
+from lightly.api import _version_checking
+from lightly.openapi_generated.swagger_client.api import VersioningApi
+
+
+# Overwrite the mock_versioning_api fixture from conftest.py that is applied by
+# default to all tests, because here we want to test the real versioning API.
+@pytest.fixture(autouse=True)
+def mock_versioning_api():
+ return
+
+
+@pytest.mark.disable_mock_versioning_api
+def test_is_latest_version(mocker: MockerFixture) -> None:
+ mocker.patch.object(
+ _version_checking.VersioningApi, "get_latest_pip_version", return_value="1.2.8"
+ )
+ assert _version_checking.is_latest_version("1.2.8")
+ assert not _version_checking.is_latest_version("1.2.7")
+ assert not _version_checking.is_latest_version("1.1.8")
+ assert not _version_checking.is_latest_version("0.2.8")
+
+
+def test_is_compatible_version(mocker: MockerFixture) -> None:
+ mocker.patch.object(
+ _version_checking.VersioningApi,
+ "get_minimum_compatible_pip_version",
+ return_value="1.2.8",
+ )
+ assert _version_checking.is_compatible_version("1.2.8")
+ assert not _version_checking.is_compatible_version("1.2.7")
+ assert not _version_checking.is_compatible_version("1.1.8")
+ assert not _version_checking.is_compatible_version("0.2.8")
+
+
+def test_get_latest_version(mocker: MockerFixture) -> None:
+ mocker.patch.object(
+ _version_checking.VersioningApi, "get_latest_pip_version", return_value="1.2.8"
+ )
+ assert _version_checking.get_latest_version("1.2.8") == "1.2.8"
+
+
+def test_get_latest_version__timeout(mocker: MockerFixture) -> None:
+ mocker.patch.dict(os.environ, {"LIGHTLY_SERVER_LOCATION": "invalid-url"})
+ start = time.perf_counter()
+ with pytest.raises(MaxRetryError):
+ # Urllib3 raises a MaxRetryError (connection refused) for invalid URLs.
+ _version_checking.get_latest_version("1.2.8", timeout_sec=0.1)
+ end = time.perf_counter()
+ assert end - start < 0.2 # Give some slack for timeout.
+
+
+def test_get_minimum_compatible_version(mocker: MockerFixture) -> None:
+ mocker.patch.object(
+ _version_checking.VersioningApi,
+ "get_minimum_compatible_pip_version",
+ return_value="1.2.8",
+ )
+
+ assert _version_checking.get_minimum_compatible_version() == "1.2.8"
+
+
+def test_get_minimum_compatible_version__timeout(mocker: MockerFixture) -> None:
+ mocker.patch.dict(os.environ, {"LIGHTLY_SERVER_LOCATION": "invalid-url"})
+ start = time.perf_counter()
+ with pytest.raises(MaxRetryError):
+ # Urllib3 raises a MaxRetryError (connection refused) for invalid URLs.
+ _version_checking.get_minimum_compatible_version(timeout_sec=0.1)
+ end = time.perf_counter()
+ assert end - start < 0.2 # Give some slack for timeout.
+
+
+def test_check_is_latest_version_in_background(mocker: MockerFixture) -> None:
+ spy_is_latest_version = mocker.spy(_version_checking, "is_latest_version")
+ _version_checking.check_is_latest_version_in_background("1.2.8")
+ time.sleep(0.1) # Wait for thread to run.
+ spy_is_latest_version.assert_called_once_with(current_version="1.2.8")
+
+
+def test__get_versioning_api() -> None:
+ assert isinstance(_version_checking._get_versioning_api(), VersioningApi)
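
A sketch of the background-check pattern that test_check_is_latest_version_in_background exercises; the thread body is an assumption, not the library's actual implementation:

import threading


def is_latest_version(current_version: str) -> bool:
    ...  # stands in for lightly.api._version_checking.is_latest_version


def check_is_latest_version_in_background(current_version: str) -> None:
    # A daemon thread keeps the network-bound check from ever blocking package
    # import; it dies with the interpreter if the request hangs.
    thread = threading.Thread(
        target=is_latest_version,
        kwargs={"current_version": current_version},
        daemon=True,
    )
    thread.start()
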
diff --git a/tests/api_workflow/mocked_api_workflow_client.py b/tests/api_workflow/mocked_api_workflow_client.py
index f61d9d75e..700f1522a 100644
--- a/tests/api_workflow/mocked_api_workflow_client.py
+++ b/tests/api_workflow/mocked_api_workflow_client.py
@@ -36,7 +36,9 @@
DatasetData,
DatasetEmbeddingData,
DatasourceConfig,
+ DatasourceConfigAzure,
DatasourceConfigBase,
+ DatasourceConfigLOCAL,
DatasourceProcessedUntilTimestampRequest,
DatasourceProcessedUntilTimestampResponse,
DatasourceRawSamplesData,
@@ -657,11 +659,14 @@ def __init__(self, api_client=None):
self.reset()
def reset(self):
- local_datasource = DatasourceConfigBase(
- type="LOCAL", full_path="", purpose="INPUT_OUTPUT"
+ local_datasource = DatasourceConfigLOCAL(
+ type="LOCAL",
+ full_path="",
+ web_server_location="https://localhost:1234",
+ purpose="INPUT_OUTPUT",
).to_dict()
azure_datasource = DatasourceConfigBase(
- type="AZURE", full_path="", purpose="INPUT_OUTPUT"
+ type="AZURE", purpose="INPUT_OUTPUT"
).to_dict()
self._datasources = {
@@ -982,14 +987,6 @@ def update_scheduled_docker_run_state_by_id(
raise NotImplementedError()
-class MockedVersioningApi(VersioningApi):
- def get_latest_pip_version(self, **kwargs):
- return "1.2.8"
-
- def get_minimum_compatible_pip_version(self, **kwargs):
- return "1.2.1"
-
-
class MockedQuotaApi(QuotaApi):
def get_quota_maximum_dataset_size(self, **kwargs):
return "60000"
@@ -1032,7 +1029,6 @@ class MockedApiWorkflowClient(ApiWorkflowClient):
n_embedding_rows_on_server = N_FILES_ON_SERVER
def __init__(self, *args, **kwargs):
- lightly.api.version_checking.VersioningApi = MockedVersioningApi
ApiWorkflowClient.__init__(self, *args, **kwargs)
self._selection_api = MockedSamplingsApi(api_client=self.api_client)
diff --git a/tests/api_workflow/test_api_workflow_client.py b/tests/api_workflow/test_api_workflow_client.py
index 48920077e..7ed256d19 100644
--- a/tests/api_workflow/test_api_workflow_client.py
+++ b/tests/api_workflow/test_api_workflow_client.py
@@ -85,11 +85,6 @@ def raise_connection_error(*args, **kwargs):
def test_user_agent_header(mocker: MockerFixture) -> None:
mocker.patch.object(lightly.api.api_workflow_client, "__version__", new="VERSION")
- mocker.patch.object(
- lightly.api.api_workflow_client.version_checking,
- "is_compatible_version",
- new=lambda _: True,
- )
mocked_platform = mocker.patch.object(
lightly.api.api_workflow_client, "platform", spec_set=platform
)
diff --git a/tests/api_workflow/test_api_workflow_compute_worker.py b/tests/api_workflow/test_api_workflow_compute_worker.py
index 09efdd8a2..b0eb80625 100644
--- a/tests/api_workflow/test_api_workflow_compute_worker.py
+++ b/tests/api_workflow/test_api_workflow_compute_worker.py
@@ -33,14 +33,14 @@
DockerWorkerConfigV3LightlyLoader,
DockerWorkerState,
DockerWorkerType,
- SelectionConfig,
- SelectionConfigEntry,
- SelectionConfigEntryInput,
- SelectionConfigEntryStrategy,
+ SelectionConfigV3,
+ SelectionConfigV3Entry,
+ SelectionConfigV3EntryInput,
+ SelectionConfigV3EntryStrategy,
SelectionInputPredictionsName,
SelectionInputType,
SelectionStrategyThresholdOperation,
- SelectionStrategyType,
+ SelectionStrategyTypeV3,
TagData,
)
from lightly.openapi_generated.swagger_client.rest import ApiException
@@ -101,17 +101,17 @@ def test_create_compute_worker_config__selection_config_is_class(self) -> None:
"batch_size": 64,
},
},
- selection_config=SelectionConfig(
+ selection_config=SelectionConfigV3(
n_samples=20,
strategies=[
- SelectionConfigEntry(
- input=SelectionConfigEntryInput(
+ SelectionConfigV3Entry(
+ input=SelectionConfigV3EntryInput(
type=SelectionInputType.EMBEDDINGS,
dataset_id=utils.generate_id(),
tag_name="some-tag-name",
),
- strategy=SelectionConfigEntryStrategy(
- type=SelectionStrategyType.SIMILARITY,
+ strategy=SelectionConfigV3EntryStrategy(
+ type=SelectionStrategyTypeV3.SIMILARITY,
),
)
],
@@ -203,44 +203,46 @@ def _check_if_openapi_generated_obj_is_valid(self, obj) -> Any:
return obj_api
def test_selection_config(self):
- selection_config = SelectionConfig(
+ selection_config = SelectionConfigV3(
n_samples=1,
strategies=[
- SelectionConfigEntry(
- input=SelectionConfigEntryInput(type=SelectionInputType.EMBEDDINGS),
- strategy=SelectionConfigEntryStrategy(
- type=SelectionStrategyType.DIVERSITY,
+ SelectionConfigV3Entry(
+ input=SelectionConfigV3EntryInput(
+ type=SelectionInputType.EMBEDDINGS
+ ),
+ strategy=SelectionConfigV3EntryStrategy(
+ type=SelectionStrategyTypeV3.DIVERSITY,
stopping_condition_minimum_distance=-1,
),
),
- SelectionConfigEntry(
- input=SelectionConfigEntryInput(
+ SelectionConfigV3Entry(
+ input=SelectionConfigV3EntryInput(
type=SelectionInputType.SCORES,
task="my-classification-task",
score="uncertainty_margin",
),
- strategy=SelectionConfigEntryStrategy(
- type=SelectionStrategyType.WEIGHTS
+ strategy=SelectionConfigV3EntryStrategy(
+ type=SelectionStrategyTypeV3.WEIGHTS
),
),
- SelectionConfigEntry(
- input=SelectionConfigEntryInput(
+ SelectionConfigV3Entry(
+ input=SelectionConfigV3EntryInput(
type=SelectionInputType.METADATA, key="lightly.sharpness"
),
- strategy=SelectionConfigEntryStrategy(
- type=SelectionStrategyType.THRESHOLD,
+ strategy=SelectionConfigV3EntryStrategy(
+ type=SelectionStrategyTypeV3.THRESHOLD,
threshold=20,
operation=SelectionStrategyThresholdOperation.BIGGER_EQUAL,
),
),
- SelectionConfigEntry(
- input=SelectionConfigEntryInput(
+ SelectionConfigV3Entry(
+ input=SelectionConfigV3EntryInput(
type=SelectionInputType.PREDICTIONS,
task="my_object_detection_task",
name=SelectionInputPredictionsName.CLASS_DISTRIBUTION,
),
- strategy=SelectionConfigEntryStrategy(
- type=SelectionStrategyType.BALANCE,
+ strategy=SelectionConfigV3EntryStrategy(
+ type=SelectionStrategyTypeV3.BALANCE,
target={"Ambulance": 0.2, "Bus": 0.4},
),
),
diff --git a/tests/api_workflow/test_api_workflow_datasources.py b/tests/api_workflow/test_api_workflow_datasources.py
index b83910975..cbaa46bb7 100644
--- a/tests/api_workflow/test_api_workflow_datasources.py
+++ b/tests/api_workflow/test_api_workflow_datasources.py
@@ -1,13 +1,15 @@
import pytest
+import tqdm
from pytest_mock import MockerFixture
-from lightly.api import ApiWorkflowClient
+from lightly.api import ApiWorkflowClient, api_workflow_datasources
from lightly.openapi_generated.swagger_client.models import (
DatasourceConfigAzure,
DatasourceConfigGCS,
DatasourceConfigLOCAL,
DatasourceConfigS3,
DatasourceConfigS3DelegatedAccess,
+ DatasourcePurpose,
DatasourceRawSamplesDataRow,
)
from lightly.openapi_generated.swagger_client.models.datasource_config_verify_data import (
@@ -16,327 +18,627 @@
from lightly.openapi_generated.swagger_client.models.datasource_config_verify_data_errors import (
DatasourceConfigVerifyDataErrors,
)
+from lightly.openapi_generated.swagger_client.models.datasource_processed_until_timestamp_response import (
+ DatasourceProcessedUntilTimestampResponse,
+)
+from lightly.openapi_generated.swagger_client.models.datasource_raw_samples_data import (
+ DatasourceRawSamplesData,
+)
-def test__download_raw_files(mocker: MockerFixture) -> None:
- mock_response_1 = mocker.MagicMock()
- mock_response_1.has_more = True
- mock_response_1.data = [
- DatasourceRawSamplesDataRow(file_name="/file1", read_url="url1"),
- DatasourceRawSamplesDataRow(file_name="file2", read_url="url2"),
- ]
-
- mock_response_2 = mocker.MagicMock()
- mock_response_2.has_more = False
- mock_response_2.data = [
- DatasourceRawSamplesDataRow(file_name="./file3", read_url="url3"),
- DatasourceRawSamplesDataRow(file_name="file2", read_url="url2"),
- ]
-
- mocked_method = mocker.MagicMock(side_effect=[mock_response_1, mock_response_2])
- mocked_pbar = mocker.MagicMock()
- mocker.patch.object(ApiWorkflowClient, "__init__", return_value=None)
- mocked_warning = mocker.patch("warnings.warn")
- client = ApiWorkflowClient()
- client._dataset_id = "dataset-id"
- result = client._download_raw_files(
- download_function=mocked_method,
- progress_bar=mocked_pbar,
- )
- kwargs = mocked_method.call_args[1]
- assert "relevant_filenames_file_name" not in kwargs
- assert mocked_pbar.update.call_count == 2
- assert mocked_warning.call_count == 3
- warning_text = [str(call_args[0][0]) for call_args in mocked_warning.call_args_list]
- assert warning_text == [
- (
- "Absolute file paths like /file1 are not supported"
- " in relevant filenames file None due to blob storage"
- ),
- (
- "Using dot notation ('./', '../') like in ./file3 is not supported"
- " in relevant filenames file None due to blob storage"
- ),
- ("Duplicate filename file2 in relevant filenames file None"),
- ]
- assert len(result) == 1
- assert result[0][0] == "file2"
-
-
-def test_get_prediction_read_url(mocker: MockerFixture) -> None:
- mocker.patch.object(ApiWorkflowClient, "__init__", return_value=None)
- mocked_api = mocker.MagicMock()
- client = ApiWorkflowClient()
- client._dataset_id = "dataset-id"
- client._datasources_api = mocked_api
- client.get_prediction_read_url("test.json")
- mocked_method = (
- mocked_api.get_prediction_file_read_url_from_datasource_by_dataset_id
- )
- mocked_method.assert_called_once_with(
- dataset_id="dataset-id", file_name="test.json"
- )
+class TestDatasourcesMixin:
+ def test_download_raw_samples(self, mocker: MockerFixture) -> None:
+ response = DatasourceRawSamplesData(
+ hasMore=False,
+ cursor="",
+ data=[
+ DatasourceRawSamplesDataRow(fileName="file1", readUrl="url1"),
+ DatasourceRawSamplesDataRow(fileName="file2", readUrl="url2"),
+ ],
+ )
+ client = ApiWorkflowClient(token="abc", dataset_id="dataset-id")
+ mocker.patch.object(
+ client._datasources_api,
+ "get_list_of_raw_samples_from_datasource_by_dataset_id",
+ side_effect=[response],
+ )
+ assert client.download_raw_samples() == [("file1", "url1"), ("file2", "url2")]
+
+ def test_download_raw_predictions(self, mocker: MockerFixture) -> None:
+ response = DatasourceRawSamplesData(
+ hasMore=False,
+ cursor="",
+ data=[
+ DatasourceRawSamplesDataRow(fileName="file1", readUrl="url1"),
+ DatasourceRawSamplesDataRow(fileName="file2", readUrl="url2"),
+ ],
+ )
+ client = ApiWorkflowClient(token="abc", dataset_id="dataset-id")
+ mocker.patch.object(
+ client._datasources_api,
+ "get_list_of_raw_samples_predictions_from_datasource_by_dataset_id",
+ side_effect=[response],
+ )
+ assert client.download_raw_predictions(task_name="task") == [
+ ("file1", "url1"),
+ ("file2", "url2"),
+ ]
+
+ def test_download_raw_predictions_iter(self, mocker: MockerFixture) -> None:
+ response_1 = DatasourceRawSamplesData(
+ hasMore=True,
+ cursor="cursor1",
+ data=[
+ DatasourceRawSamplesDataRow(fileName="file1", readUrl="url1"),
+ DatasourceRawSamplesDataRow(fileName="file2", readUrl="url2"),
+ ],
+ )
+ response_2 = DatasourceRawSamplesData(
+ hasMore=False,
+ cursor="cursor2",
+ data=[
+ DatasourceRawSamplesDataRow(fileName="file3", readUrl="url3"),
+ DatasourceRawSamplesDataRow(fileName="file4", readUrl="url4"),
+ ],
+ )
+ client = ApiWorkflowClient(token="abc", dataset_id="dataset-id")
+ mocker.patch.object(
+ client._datasources_api,
+ "get_list_of_raw_samples_predictions_from_datasource_by_dataset_id",
+ side_effect=[response_1, response_2],
+ )
+ assert list(client.download_raw_predictions_iter(task_name="task")) == [
+ ("file1", "url1"),
+ ("file2", "url2"),
+ ("file3", "url3"),
+ ("file4", "url4"),
+ ]
+ client._datasources_api.get_list_of_raw_samples_predictions_from_datasource_by_dataset_id.assert_has_calls(
+ [
+ mocker.call(
+ dataset_id="dataset-id",
+ task_name="task",
+ var_from=0,
+ to=mocker.ANY,
+ use_redirected_read_url=False,
+ ),
+ mocker.call(
+ dataset_id="dataset-id",
+ task_name="task",
+ cursor="cursor1",
+ use_redirected_read_url=False,
+ ),
+ ]
+ )
+ def test_download_raw_predictions_iter__relevant_filenames_artifact_id(
+ self,
+ mocker: MockerFixture,
+ ) -> None:
+ response = DatasourceRawSamplesData(
+ hasMore=False,
+ cursor="",
+ data=[
+ DatasourceRawSamplesDataRow(fileName="file1", readUrl="url1"),
+ DatasourceRawSamplesDataRow(fileName="file2", readUrl="url2"),
+ ],
+ )
+ client = ApiWorkflowClient(token="abc", dataset_id="dataset-id")
+ mocker.patch.object(
+ client._datasources_api,
+ "get_list_of_raw_samples_predictions_from_datasource_by_dataset_id",
+ side_effect=[response],
+ )
+ assert list(
+ client.download_raw_predictions_iter(
+ task_name="task",
+ run_id="run-id",
+ relevant_filenames_artifact_id="relevant-filenames",
+ )
+ ) == [
+ ("file1", "url1"),
+ ("file2", "url2"),
+ ]
+ client._datasources_api.get_list_of_raw_samples_predictions_from_datasource_by_dataset_id.assert_called_once_with(
+ dataset_id="dataset-id",
+ task_name="task",
+ var_from=0,
+ to=mocker.ANY,
+ relevant_filenames_run_id="run-id",
+ relevant_filenames_artifact_id="relevant-filenames",
+ use_redirected_read_url=False,
+ )
-def test_download_new_raw_samples(mocker: MockerFixture) -> None:
- from_timestamp = 2
- mocker.patch.object(ApiWorkflowClient, "__init__", return_value=None)
- mocker.patch.object(
- ApiWorkflowClient, "get_processed_until_timestamp", return_value=from_timestamp
- )
- current_time = 5
- mocker.patch("time.time", return_value=current_time)
- mocked_download = mocker.patch.object(ApiWorkflowClient, "download_raw_samples")
- mocked_update_timestamp = mocker.patch.object(
- ApiWorkflowClient, "update_processed_until_timestamp"
- )
- client = ApiWorkflowClient()
- client.download_new_raw_samples()
- mocked_download.assert_called_once_with(
- from_=from_timestamp + 1,
- to=current_time,
- relevant_filenames_file_name=None,
- use_redirected_read_url=False,
- )
- mocked_update_timestamp.assert_called_once_with(timestamp=current_time)
+ # should raise ValueError when only run_id is given
+ with pytest.raises(ValueError):
+ next(
+ client.download_raw_predictions_iter(task_name="task", run_id="run-id")
+ )
+
+ # should raise ValueError when only relevant_filenames_artifact_id is given
+ with pytest.raises(ValueError):
+ next(
+ client.download_raw_predictions_iter(
+ task_name="task",
+ relevant_filenames_artifact_id="relevant-filenames",
+ )
+ )
+
+ def test_download_raw_metadata(self, mocker: MockerFixture) -> None:
+ response = DatasourceRawSamplesData(
+ hasMore=False,
+ cursor="",
+ data=[
+ DatasourceRawSamplesDataRow(fileName="file1", readUrl="url1"),
+ DatasourceRawSamplesDataRow(fileName="file2", readUrl="url2"),
+ ],
+ )
+ client = ApiWorkflowClient(token="abc", dataset_id="dataset-id")
+ mocker.patch.object(
+ client._datasources_api,
+ "get_list_of_raw_samples_metadata_from_datasource_by_dataset_id",
+ side_effect=[response],
+ )
+ assert client.download_raw_metadata() == [
+ ("file1", "url1"),
+ ("file2", "url2"),
+ ]
+
+ def test_download_raw_metadata_iter(self, mocker: MockerFixture) -> None:
+ response_1 = DatasourceRawSamplesData(
+ hasMore=True,
+ cursor="cursor1",
+ data=[
+ DatasourceRawSamplesDataRow(fileName="file1", readUrl="url1"),
+ DatasourceRawSamplesDataRow(fileName="file2", readUrl="url2"),
+ ],
+ )
+ response_2 = DatasourceRawSamplesData(
+ hasMore=False,
+ cursor="cursor2",
+ data=[
+ DatasourceRawSamplesDataRow(fileName="file3", readUrl="url3"),
+ DatasourceRawSamplesDataRow(fileName="file4", readUrl="url4"),
+ ],
+ )
+ client = ApiWorkflowClient(token="abc", dataset_id="dataset-id")
+ mocker.patch.object(
+ client._datasources_api,
+ "get_list_of_raw_samples_metadata_from_datasource_by_dataset_id",
+ side_effect=[response_1, response_2],
+ )
+ assert list(client.download_raw_metadata_iter()) == [
+ ("file1", "url1"),
+ ("file2", "url2"),
+ ("file3", "url3"),
+ ("file4", "url4"),
+ ]
+ client._datasources_api.get_list_of_raw_samples_metadata_from_datasource_by_dataset_id.assert_has_calls(
+ [
+ mocker.call(
+ dataset_id="dataset-id",
+ var_from=0,
+ to=mocker.ANY,
+ use_redirected_read_url=False,
+ ),
+ mocker.call(
+ dataset_id="dataset-id",
+ cursor="cursor1",
+ use_redirected_read_url=False,
+ ),
+ ]
+ )
+ def test_download_raw_metadata_iter__relevant_filenames_artifact_id(
+ self, mocker: MockerFixture
+ ) -> None:
+ response = DatasourceRawSamplesData(
+ hasMore=False,
+ cursor="",
+ data=[
+ DatasourceRawSamplesDataRow(fileName="file1", readUrl="url1"),
+ DatasourceRawSamplesDataRow(fileName="file2", readUrl="url2"),
+ ],
+ )
+ client = ApiWorkflowClient(token="abc", dataset_id="dataset-id")
+ mocker.patch.object(
+ client._datasources_api,
+ "get_list_of_raw_samples_metadata_from_datasource_by_dataset_id",
+ side_effect=[response],
+ )
+ assert list(
+ client.download_raw_metadata_iter(
+ run_id="run-id",
+ relevant_filenames_artifact_id="relevant-filenames",
+ )
+ ) == [
+ ("file1", "url1"),
+ ("file2", "url2"),
+ ]
+ client._datasources_api.get_list_of_raw_samples_metadata_from_datasource_by_dataset_id.assert_called_once_with(
+ dataset_id="dataset-id",
+ var_from=0,
+ to=mocker.ANY,
+ relevant_filenames_run_id="run-id",
+ relevant_filenames_artifact_id="relevant-filenames",
+ use_redirected_read_url=False,
+ )
-def test_download_new_raw_samples__from_beginning(mocker: MockerFixture) -> None:
- mocker.patch.object(ApiWorkflowClient, "__init__", return_value=None)
- mocker.patch.object(
- ApiWorkflowClient, "get_processed_until_timestamp", return_value=0
- )
- current_time = 5
- mocker.patch("time.time", return_value=current_time)
- mocked_download = mocker.patch.object(ApiWorkflowClient, "download_raw_samples")
- mocked_update_timestamp = mocker.patch.object(
- ApiWorkflowClient, "update_processed_until_timestamp"
- )
- client = ApiWorkflowClient()
- client.download_new_raw_samples()
- mocked_download.assert_called_once_with(
- from_=0,
- to=current_time,
- relevant_filenames_file_name=None,
- use_redirected_read_url=False,
- )
- mocked_update_timestamp.assert_called_once_with(timestamp=current_time)
-
-
-def test_download_raw_samples_predictions__relevant_filenames_artifact_id(
- mocker: MockerFixture,
-) -> None:
- mock_response = mocker.MagicMock()
- mock_response.has_more = False
- mocker.patch.object(ApiWorkflowClient, "__init__", return_value=None)
- mocked_api = mocker.MagicMock()
- mocked_method = mocker.MagicMock(return_value=mock_response)
- mocked_api.get_list_of_raw_samples_predictions_from_datasource_by_dataset_id = (
- mocked_method
- )
- client = ApiWorkflowClient()
- client._dataset_id = "dataset-id"
- client._datasources_api = mocked_api
- client.download_raw_predictions(
- task_name="task", run_id="foo", relevant_filenames_artifact_id="bar"
- )
- kwargs = mocked_method.call_args[1]
- assert kwargs.get("relevant_filenames_run_id") == "foo"
- assert kwargs.get("relevant_filenames_artifact_id") == "bar"
-
- # should raise ValueError when only run_id is given
- with pytest.raises(ValueError):
- client.download_raw_predictions(task_name="foobar", run_id="foo")
- # should raise ValueError when only relevant_filenames_artifact_id is given
- with pytest.raises(ValueError):
- client.download_raw_predictions(
- task_name="foobar", relevant_filenames_artifact_id="bar"
- )
-
-
-def test_download_raw_samples_metadata__relevant_filenames_artifact_id(
- mocker: MockerFixture,
-) -> None:
- mock_response = mocker.MagicMock()
- mock_response.has_more = False
- mocker.patch.object(ApiWorkflowClient, "__init__", return_value=None)
- mocked_api = mocker.MagicMock()
- mocked_method = mocker.MagicMock(return_value=mock_response)
- mocked_api.get_list_of_raw_samples_metadata_from_datasource_by_dataset_id = (
- mocked_method
- )
- client = ApiWorkflowClient()
- client._dataset_id = "dataset-id"
- client._datasources_api = mocked_api
- client.download_raw_metadata(run_id="foo", relevant_filenames_artifact_id="bar")
- kwargs = mocked_method.call_args[1]
- assert kwargs.get("relevant_filenames_run_id") == "foo"
- assert kwargs.get("relevant_filenames_artifact_id") == "bar"
-
- # should raise ValueError when only run_id is given
- with pytest.raises(ValueError):
- client.download_raw_metadata(run_id="foo")
- # should raise ValueError when only relevant_filenames_artifact_id is given
- with pytest.raises(ValueError):
- client.download_raw_metadata(relevant_filenames_artifact_id="bar")
-
-
-def test_get_processed_until_timestamp(mocker: MockerFixture) -> None:
- mocker.patch.object(ApiWorkflowClient, "__init__", return_value=None)
- mocked_datasources_api = mocker.MagicMock()
- client = ApiWorkflowClient()
- client._dataset_id = "dataset-id"
- client._datasources_api = mocked_datasources_api
- client.get_processed_until_timestamp()
- mocked_method = (
- mocked_datasources_api.get_datasource_processed_until_timestamp_by_dataset_id
- )
- mocked_method.assert_called_once_with(dataset_id="dataset-id")
-
-
-def test_set_azure_config(mocker: MockerFixture) -> None:
- mocker.patch.object(ApiWorkflowClient, "__init__", return_value=None)
- mocked_datasources_api = mocker.MagicMock()
- client = ApiWorkflowClient()
- client._datasources_api = mocked_datasources_api
- client._dataset_id = "dataset-id"
- client.set_azure_config(
- container_name="my-container/name",
- account_name="my-account-name",
- sas_token="my-sas-token",
- thumbnail_suffix=".lightly/thumbnails/[filename]-thumb-[extension]",
- )
- kwargs = mocked_datasources_api.update_datasource_by_dataset_id.call_args[1]
- assert isinstance(
- kwargs["datasource_config"].actual_instance, DatasourceConfigAzure
- )
+ # should raise ValueError when only run_id is given
+ with pytest.raises(ValueError):
+ next(client.download_raw_metadata_iter(run_id="run-id"))
+
+ # should raise ValueError when only relevant_filenames_artifact_id is given
+ with pytest.raises(ValueError):
+ next(
+ client.download_raw_metadata_iter(
+ relevant_filenames_artifact_id="relevant-filenames",
+ )
+ )
+
+ def test_download_new_raw_samples(self, mocker: MockerFixture) -> None:
+ client = ApiWorkflowClient(token="abc", dataset_id="dataset-id")
+ client.get_processed_until_timestamp = mocker.MagicMock(return_value=2)
+ mocker.patch("time.time", return_value=5)
+ mocker.patch.object(client, "download_raw_samples")
+ mocker.patch.object(client, "update_processed_until_timestamp")
+ client.download_new_raw_samples()
+ client.download_raw_samples.assert_called_once_with(
+ from_=2 + 1,
+ to=5,
+ relevant_filenames_file_name=None,
+ use_redirected_read_url=False,
+ )
+ client.update_processed_until_timestamp.assert_called_once_with(timestamp=5)
+
+ def test_download_new_raw_samples__from_beginning(
+ self, mocker: MockerFixture
+ ) -> None:
+ client = ApiWorkflowClient(token="abc", dataset_id="dataset-id")
+ client.get_processed_until_timestamp = mocker.MagicMock(return_value=0)
+ mocker.patch("time.time", return_value=5)
+ mocker.patch.object(client, "download_raw_samples")
+ mocker.patch.object(client, "update_processed_until_timestamp")
+ client.download_new_raw_samples()
+ client.download_raw_samples.assert_called_once_with(
+ from_=0,
+ to=5,
+ relevant_filenames_file_name=None,
+ use_redirected_read_url=False,
+ )
+ client.update_processed_until_timestamp.assert_called_once_with(timestamp=5)
+
+ def test_get_processed_until_timestamp(self, mocker: MockerFixture) -> None:
+ client = ApiWorkflowClient(token="abc", dataset_id="dataset-id")
+ mocker.patch.object(
+ client._datasources_api,
+ "get_datasource_processed_until_timestamp_by_dataset_id",
+ return_value=DatasourceProcessedUntilTimestampResponse(
+ processedUntilTimestamp=5
+ ),
+ )
+ assert client.get_processed_until_timestamp() == 5
+ client._datasources_api.get_datasource_processed_until_timestamp_by_dataset_id.assert_called_once_with(
+ dataset_id="dataset-id"
+ )
+ def test_update_processed_until_timestamp(self, mocker: MockerFixture) -> None:
+ client = ApiWorkflowClient(token="abc", dataset_id="dataset-id")
+ mocker.patch.object(
+ client._datasources_api,
+ "update_datasource_processed_until_timestamp_by_dataset_id",
+ )
+ client.update_processed_until_timestamp(timestamp=10)
+ kwargs = client._datasources_api.update_datasource_processed_until_timestamp_by_dataset_id.call_args[
+ 1
+ ]
+ assert kwargs["dataset_id"] == "dataset-id"
+ assert (
+ kwargs[
+ "datasource_processed_until_timestamp_request"
+ ].processed_until_timestamp
+ == 10
+ )
-def test_set_gcs_config(mocker: MockerFixture) -> None:
- mocker.patch.object(ApiWorkflowClient, "__init__", return_value=None)
- mocked_datasources_api = mocker.MagicMock()
- client = ApiWorkflowClient()
- client._datasources_api = mocked_datasources_api
- client._dataset_id = "dataset-id"
- client.set_gcs_config(
- resource_path="gs://my-bucket/my-dataset",
- project_id="my-project-id",
- credentials="my-credentials",
- thumbnail_suffix=".lightly/thumbnails/[filename]-thumb-[extension]",
- )
- kwargs = mocked_datasources_api.update_datasource_by_dataset_id.call_args[1]
- assert isinstance(kwargs["datasource_config"].actual_instance, DatasourceConfigGCS)
-
-
-def test_set_local_config(mocker: MockerFixture) -> None:
- mocker.patch.object(ApiWorkflowClient, "__init__", return_value=None)
- mocked_datasources_api = mocker.MagicMock()
- client = ApiWorkflowClient()
- client._datasources_api = mocked_datasources_api
- client._dataset_id = "dataset-id"
- client.set_local_config(
- resource_path="http://localhost:1234/path/to/my/data",
- thumbnail_suffix=".lightly/thumbnails/[filename]-thumb-[extension]",
- )
- kwargs = mocked_datasources_api.update_datasource_by_dataset_id.call_args[1]
- assert isinstance(
- kwargs["datasource_config"].actual_instance, DatasourceConfigLOCAL
- )
+ def test_set_azure_config(self, mocker: MockerFixture) -> None:
+ client = ApiWorkflowClient(token="abc", dataset_id="dataset-id")
+ mocker.patch.object(
+ client._datasources_api,
+ "update_datasource_by_dataset_id",
+ )
+ client.set_azure_config(
+ container_name="my-container/name",
+ account_name="my-account-name",
+ sas_token="my-sas-token",
+ thumbnail_suffix=".lightly/thumbnails/[filename]-thumb-[extension]",
+ )
+ kwargs = client._datasources_api.update_datasource_by_dataset_id.call_args[1]
+ assert isinstance(
+ kwargs["datasource_config"].actual_instance, DatasourceConfigAzure
+ )
+ def test_set_gcs_config(self, mocker: MockerFixture) -> None:
+ client = ApiWorkflowClient(token="abc", dataset_id="dataset-id")
+ mocker.patch.object(
+ client._datasources_api,
+ "update_datasource_by_dataset_id",
+ )
+ client.set_gcs_config(
+ resource_path="gs://my-bucket/my-dataset",
+ project_id="my-project-id",
+ credentials="my-credentials",
+ thumbnail_suffix=".lightly/thumbnails/[filename]-thumb-[extension]",
+ )
+ kwargs = client._datasources_api.update_datasource_by_dataset_id.call_args[1]
+ assert isinstance(
+ kwargs["datasource_config"].actual_instance, DatasourceConfigGCS
+ )
-def test_set_s3_config(mocker: MockerFixture) -> None:
- mocker.patch.object(ApiWorkflowClient, "__init__", return_value=None)
- mocked_datasources_api = mocker.MagicMock()
- client = ApiWorkflowClient()
- client._datasources_api = mocked_datasources_api
- client._dataset_id = "dataset-id"
- client.set_s3_config(
- resource_path="s3://my-bucket/my-dataset",
- thumbnail_suffix=".lightly/thumbnails/[filename]-thumb-[extension]",
- region="eu-central-1",
- access_key="my-access-key",
- secret_access_key="my-secret-access-key",
- )
- kwargs = mocked_datasources_api.update_datasource_by_dataset_id.call_args[1]
- assert isinstance(kwargs["datasource_config"].actual_instance, DatasourceConfigS3)
-
-
-def test_set_s3_delegated_access_config(mocker: MockerFixture) -> None:
- mocker.patch.object(ApiWorkflowClient, "__init__", return_value=None)
- mocked_datasources_api = mocker.MagicMock()
- client = ApiWorkflowClient()
- client._datasources_api = mocked_datasources_api
- client._dataset_id = "dataset-id"
- client.set_s3_delegated_access_config(
- resource_path="s3://my-bucket/my-dataset",
- thumbnail_suffix=".lightly/thumbnails/[filename]-thumb-[extension]",
- region="eu-central-1",
- role_arn="arn:aws:iam::000000000000:role.test",
- external_id="my-external-id",
- )
- kwargs = mocked_datasources_api.update_datasource_by_dataset_id.call_args[1]
- assert isinstance(
- kwargs["datasource_config"].actual_instance, DatasourceConfigS3DelegatedAccess
- )
+ def test_set_local_config(self, mocker: MockerFixture) -> None:
+ client = ApiWorkflowClient(token="abc", dataset_id="dataset-id")
+ mocker.patch.object(
+ client._datasources_api,
+ "update_datasource_by_dataset_id",
+ )
+ client.set_local_config(
+ web_server_location="http://localhost:1234",
+ relative_path="path/to/my/data",
+ thumbnail_suffix=".lightly/thumbnails/[filename]-thumb-[extension]",
+ purpose=DatasourcePurpose.INPUT,
+ )
+ kwargs = client._datasources_api.update_datasource_by_dataset_id.call_args[1]
+ datasource_config = kwargs["datasource_config"].actual_instance
+ assert isinstance(datasource_config, DatasourceConfigLOCAL)
+ assert datasource_config.type == "LOCAL"
+ assert datasource_config.web_server_location == "http://localhost:1234"
+ assert datasource_config.full_path == "path/to/my/data"
+ assert (
+ datasource_config.thumb_suffix
+ == ".lightly/thumbnails/[filename]-thumb-[extension]"
+ )
+ assert datasource_config.purpose == DatasourcePurpose.INPUT
+
+ # Test defaults
+ client.set_local_config()
+ kwargs = client._datasources_api.update_datasource_by_dataset_id.call_args[1]
+ datasource_config = kwargs["datasource_config"].actual_instance
+ assert isinstance(datasource_config, DatasourceConfigLOCAL)
+ assert datasource_config.type == "LOCAL"
+ assert datasource_config.web_server_location == "http://localhost:3456"
+ assert datasource_config.full_path == ""
+ assert (
+ datasource_config.thumb_suffix
+ == ".lightly/thumbnails/[filename]_thumb.[extension]"
+ )
+ assert datasource_config.purpose == DatasourcePurpose.INPUT_OUTPUT
+ def test_set_s3_config(self, mocker: MockerFixture) -> None:
+ client = ApiWorkflowClient(token="abc", dataset_id="dataset-id")
+ mocker.patch.object(
+ client._datasources_api,
+ "update_datasource_by_dataset_id",
+ )
+ client.set_s3_config(
+ resource_path="s3://my-bucket/my-dataset",
+ thumbnail_suffix=".lightly/thumbnails/[filename]-thumb-[extension]",
+ region="eu-central-1",
+ access_key="my-access-key",
+ secret_access_key="my-secret-access-key",
+ )
+ kwargs = client._datasources_api.update_datasource_by_dataset_id.call_args[1]
+ assert isinstance(
+ kwargs["datasource_config"].actual_instance, DatasourceConfigS3
+ )
-def test_update_processed_until_timestamp(mocker: MockerFixture) -> None:
- mocker.patch.object(ApiWorkflowClient, "__init__", return_value=None)
- mocked_datasources_api = mocker.MagicMock()
- client = ApiWorkflowClient()
- client._dataset_id = "dataset-id"
- client._datasources_api = mocked_datasources_api
- client.update_processed_until_timestamp(10)
- kwargs = mocked_datasources_api.update_datasource_processed_until_timestamp_by_dataset_id.call_args[
- 1
- ]
- assert kwargs["dataset_id"] == "dataset-id"
- assert (
- kwargs["datasource_processed_until_timestamp_request"].processed_until_timestamp
- == 10
- )
+ def test_set_s3_delegated_access_config(self, mocker: MockerFixture) -> None:
+ client = ApiWorkflowClient(token="abc", dataset_id="dataset-id")
+ mocker.patch.object(
+ client._datasources_api,
+ "update_datasource_by_dataset_id",
+ )
+ client.set_s3_delegated_access_config(
+ resource_path="s3://my-bucket/my-dataset",
+ thumbnail_suffix=".lightly/thumbnails/[filename]-thumb-[extension]",
+ region="eu-central-1",
+ role_arn="arn:aws:iam::000000000000:role.test",
+ external_id="my-external-id",
+ )
+ kwargs = client._datasources_api.update_datasource_by_dataset_id.call_args[1]
+ assert isinstance(
+ kwargs["datasource_config"].actual_instance,
+ DatasourceConfigS3DelegatedAccess,
+ )
+ def test_get_prediction_read_url(self, mocker: MockerFixture) -> None:
+ client = ApiWorkflowClient(token="abc", dataset_id="dataset-id")
+ mocker.patch.object(
+ client._datasources_api,
+ "get_prediction_file_read_url_from_datasource_by_dataset_id",
+ return_value="read-url",
+ )
+ assert client.get_prediction_read_url(filename="test.json") == "read-url"
+ client._datasources_api.get_prediction_file_read_url_from_datasource_by_dataset_id.assert_called_once_with(
+ dataset_id="dataset-id", file_name="test.json"
+ )
-def test_list_datasource_permissions(mocker: MockerFixture) -> None:
- client = ApiWorkflowClient(token="abc")
- client._dataset_id = "dataset-id"
- client._datasources_api.verify_datasource_by_dataset_id = mocker.MagicMock(
- return_value=DatasourceConfigVerifyData(
- canRead=True,
- canWrite=True,
- canList=False,
- canOverwrite=True,
- errors=None,
- ),
- )
- assert client.list_datasource_permissions() == {
- "can_read": True,
- "can_write": True,
- "can_list": False,
- "can_overwrite": True,
- }
-
-
-def test_list_datasource_permissions__error(mocker: MockerFixture) -> None:
- client = ApiWorkflowClient(token="abc")
- client._dataset_id = "dataset-id"
- client._datasources_api.verify_datasource_by_dataset_id = mocker.MagicMock(
- return_value=DatasourceConfigVerifyData(
- canRead=True,
- canWrite=True,
- canList=False,
- canOverwrite=True,
- errors=DatasourceConfigVerifyDataErrors(
- canRead=None, canWrite=None, canList="error message", canOverwrite=None
+ def test_get_custom_embedding_read_url(self, mocker: MockerFixture) -> None:
+ client = ApiWorkflowClient(token="abc", dataset_id="dataset-id")
+ mocker.patch.object(
+ client._datasources_api,
+ "get_custom_embedding_file_read_url_from_datasource_by_dataset_id",
+ return_value="read-url",
+ )
+ assert (
+ client.get_custom_embedding_read_url(filename="embeddings.csv")
+ == "read-url"
+ )
+ client._datasources_api.get_custom_embedding_file_read_url_from_datasource_by_dataset_id.assert_called_once_with(
+ dataset_id="dataset-id", file_name="embeddings.csv"
+ )
+
+ def test_list_datasource_permissions(self, mocker: MockerFixture) -> None:
+ client = ApiWorkflowClient(token="abc", dataset_id="dataset-id")
+ client._datasources_api.verify_datasource_by_dataset_id = mocker.MagicMock(
+ return_value=DatasourceConfigVerifyData(
+ canRead=True,
+ canWrite=True,
+ canList=False,
+ canOverwrite=True,
+ errors=None,
+ ),
+ )
+ assert client.list_datasource_permissions() == {
+ "can_read": True,
+ "can_write": True,
+ "can_list": False,
+ "can_overwrite": True,
+ }
+
+ def test_list_datasource_permissions__error(self, mocker: MockerFixture) -> None:
+ client = ApiWorkflowClient(token="abc", dataset_id="dataset-id")
+ client._datasources_api.verify_datasource_by_dataset_id = mocker.MagicMock(
+ return_value=DatasourceConfigVerifyData(
+ canRead=True,
+ canWrite=True,
+ canList=False,
+ canOverwrite=True,
+ errors=DatasourceConfigVerifyDataErrors(
+ canRead=None,
+ canWrite=None,
+ canList="error message",
+ canOverwrite=None,
+ ),
),
- ),
+ )
+ assert client.list_datasource_permissions() == {
+ "can_read": True,
+ "can_write": True,
+ "can_list": False,
+ "can_overwrite": True,
+ "errors": {
+ "can_list": "error message",
+ },
+ }
+
+ def test__download_raw_files(self, mocker: MockerFixture) -> None:
+ response = DatasourceRawSamplesData(
+ hasMore=False,
+ cursor="",
+ data=[
+ DatasourceRawSamplesDataRow(fileName="file1", readUrl="url1"),
+ DatasourceRawSamplesDataRow(fileName="file2", readUrl="url2"),
+ ],
+ )
+ download_function = mocker.MagicMock(side_effect=[response])
+ client = ApiWorkflowClient(token="abc", dataset_id="dataset-id")
+ assert client._download_raw_files(
+ download_function=download_function,
+ ) == [("file1", "url1"), ("file2", "url2")]
+
+ def test__download_raw_files_iter(self, mocker: MockerFixture) -> None:
+ response_1 = DatasourceRawSamplesData(
+ hasMore=True,
+ cursor="cursor1",
+ data=[
+ DatasourceRawSamplesDataRow(fileName="file1", readUrl="url1"),
+ DatasourceRawSamplesDataRow(fileName="file2", readUrl="url2"),
+ ],
+ )
+ response_2 = DatasourceRawSamplesData(
+ hasMore=False,
+ cursor="cursor2",
+ data=[
+ DatasourceRawSamplesDataRow(fileName="file3", readUrl="url3"),
+ DatasourceRawSamplesDataRow(fileName="file4", readUrl="url4"),
+ ],
+ )
+ download_function = mocker.MagicMock(side_effect=[response_1, response_2])
+ client = ApiWorkflowClient(token="abc", dataset_id="dataset-id")
+ progress_bar = mocker.spy(tqdm, "tqdm")
+ assert list(
+ client._download_raw_files_iter(
+ download_function=download_function,
+ from_=0,
+ to=5,
+ relevant_filenames_file_name="relevant-filenames",
+ use_redirected_read_url=True,
+ progress_bar=progress_bar,
+ foo="bar",
+ )
+ ) == [
+ ("file1", "url1"),
+ ("file2", "url2"),
+ ("file3", "url3"),
+ ("file4", "url4"),
+ ]
+ download_function.assert_has_calls(
+ [
+ mocker.call(
+ dataset_id="dataset-id",
+ var_from=0,
+ to=5,
+ relevant_filenames_file_name="relevant-filenames",
+ use_redirected_read_url=True,
+ foo="bar",
+ ),
+ mocker.call(
+ dataset_id="dataset-id",
+ cursor="cursor1",
+ relevant_filenames_file_name="relevant-filenames",
+ use_redirected_read_url=True,
+ foo="bar",
+ ),
+ ]
+ )
+ assert progress_bar.update.call_count == 4
+
+ def test__download_raw_files_iter__no_relevant_filenames(
+ self, mocker: MockerFixture
+ ) -> None:
+ response = DatasourceRawSamplesData(hasMore=False, cursor="", data=[])
+ download_function = mocker.MagicMock(side_effect=[response])
+ client = ApiWorkflowClient(token="abc", dataset_id="dataset-id")
+ list(client._download_raw_files_iter(download_function=download_function))
+ assert "relevant_filenames_file_name" not in download_function.call_args[1]
+
+ def test__download_raw_files_iter__warning(self, mocker: MockerFixture) -> None:
+ response = DatasourceRawSamplesData(
+ hasMore=False,
+ cursor="",
+ data=[
+ DatasourceRawSamplesDataRow(fileName="/file1", readUrl="url1"),
+ ],
+ )
+ download_function = mocker.MagicMock(side_effect=[response])
+ client = ApiWorkflowClient(token="abc", dataset_id="dataset-id")
+ with pytest.warns(UserWarning, match="Absolute file paths like /file1"):
+ list(client._download_raw_files_iter(download_function=download_function))
+
+
+def test__sample_unseen_and_valid() -> None:
+ with pytest.warns(UserWarning, match="Absolute file paths like /file1"):
+ assert not api_workflow_datasources._sample_unseen_and_valid(
+ sample=DatasourceRawSamplesDataRow(fileName="/file1", readUrl="url1"),
+ relevant_filenames_file_name=None,
+ listed_filenames=set(),
+ )
+
+ with pytest.warns(UserWarning, match="Using dot notation"):
+ assert not api_workflow_datasources._sample_unseen_and_valid(
+ sample=DatasourceRawSamplesDataRow(fileName="./file1", readUrl="url1"),
+ relevant_filenames_file_name=None,
+ listed_filenames=set(),
+ )
+
+ with pytest.warns(UserWarning, match="Duplicate filename file1"):
+ assert not api_workflow_datasources._sample_unseen_and_valid(
+ sample=DatasourceRawSamplesDataRow(fileName="file1", readUrl="url1"),
+ relevant_filenames_file_name=None,
+ listed_filenames={"file1"},
+ )
+
+ assert api_workflow_datasources._sample_unseen_and_valid(
+ sample=DatasourceRawSamplesDataRow(fileName="file1", readUrl="url1"),
+ relevant_filenames_file_name=None,
+ listed_filenames=set(),
)
- assert client.list_datasource_permissions() == {
- "can_read": True,
- "can_write": True,
- "can_list": False,
- "can_overwrite": True,
- "errors": {
- "can_list": "error message",
- },
- }
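
The paginated download tests above all follow the same cursor protocol: the first request addresses a [from, to] window, every follow-up request passes the cursor returned by the previous page, and iteration stops when has_more is false. A simplified sketch of that loop (the real _download_raw_files_iter also handles relevant-filenames warnings and progress bars):

import time
from typing import Any, Callable, Iterator, Tuple


def paginate(
    download_function: Callable[..., Any], **kwargs: Any
) -> Iterator[Tuple[str, str]]:
    # First page: addressed by an explicit timestamp window.
    response = download_function(var_from=0, to=int(time.time()), **kwargs)
    while True:
        for sample in response.data:
            yield sample.file_name, sample.read_url
        if not response.has_more:
            break
        # Follow-up pages are addressed only by the server-provided cursor.
        response = download_function(cursor=response.cursor, **kwargs)
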
diff --git a/tests/api_workflow/test_api_workflow_upload_embeddings.py b/tests/api_workflow/test_api_workflow_upload_embeddings.py
index 4444d11d7..a3bb54c2b 100644
--- a/tests/api_workflow/test_api_workflow_upload_embeddings.py
+++ b/tests/api_workflow/test_api_workflow_upload_embeddings.py
@@ -4,11 +4,7 @@
import numpy as np
from lightly.utils import io as io_utils
-from lightly.utils.io import INVALID_FILENAME_CHARACTERS
-from tests.api_workflow.mocked_api_workflow_client import (
- N_FILES_ON_SERVER,
- MockedApiWorkflowSetup,
-)
+from tests.api_workflow.mocked_api_workflow_client import MockedApiWorkflowSetup
class TestApiWorkflowUploadEmbeddings(MockedApiWorkflowSetup):
@@ -80,15 +76,6 @@ def test_upload_wrong_filenames(self):
with self.assertRaises(ValueError):
self.t_ester_upload_embedding(n_data=n_data, special_name_first_sample=True)
- def test_upload_comma_filenames(self):
- n_data = len(self.api_workflow_client._mappings_api.sample_names)
- for invalid_char in INVALID_FILENAME_CHARACTERS:
- with self.subTest(msg=f"invalid_char: {invalid_char}"):
- with self.assertRaises(ValueError):
- self.t_ester_upload_embedding(
- n_data=n_data, special_char_in_first_filename=invalid_char
- )
-
def test_set_embedding_id_default(self):
self.api_workflow_client.set_embedding_id_to_latest()
embeddings = (
diff --git a/tests/conftest.py b/tests/conftest.py
index 22e256d29..7133ba908 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -41,7 +41,7 @@ def pytest_collection_modifyitems(config, items):
item.add_marker(skip_slow)
-@pytest.fixture(scope="session", autouse=True)
+@pytest.fixture(scope="module", autouse=True)
def mock_versioning_api():
"""Fixture that is applied to all tests and mocks the versioning API.
@@ -56,17 +56,17 @@ def mock_versioning_api():
should be compatible with all future versions.
"""
- def mock_get_latest_pip_version(current_version: str) -> str:
+ def mock_get_latest_pip_version(current_version: str, **kwargs) -> str:
return current_version
# NOTE(guarin, 2/6/23): Cannot use pytest mocker fixture here because it has not
- # a "session" scope and it is not possible to use a fixture that has a tigher scope
+ # a "module" scope and it is not possible to use a fixture that has a tighter scope
# inside a fixture with a wider scope.
with mock.patch(
- "lightly.api.version_checking.VersioningApi.get_latest_pip_version",
+ "lightly.api._version_checking.VersioningApi.get_latest_pip_version",
new=mock_get_latest_pip_version,
), mock.patch(
- "lightly.api.version_checking.VersioningApi.get_minimum_compatible_pip_version",
+ "lightly.api._version_checking.VersioningApi.get_minimum_compatible_pip_version",
return_value="1.0.0",
):
yield
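
The scope change matters because the fixture holds its patches open across its whole scope: once a session-scoped autouse fixture has started, its patch presumably stays active for the rest of the run, so a single module could not opt out. With module scope, one test module can redefine the fixture by name, as tests/api/test_version_checking.py does above. A standalone illustration of the pattern (assumed, not the project's conftest):

import pytest


# tests/conftest.py (default): one patch context per test module.
@pytest.fixture(scope="module", autouse=True)
def mock_versioning_api():
    yield  # patches stay active while the module's tests run


# tests/api/test_version_checking.py (override): same fixture name, no patch.
@pytest.fixture(autouse=True)
def mock_versioning_api():  # noqa: F811 - intentional override
    return
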
diff --git a/tests/data/test_LightlyDataset.py b/tests/data/test_LightlyDataset.py
index c5a1a7c6c..0b4c7b196 100644
--- a/tests/data/test_LightlyDataset.py
+++ b/tests/data/test_LightlyDataset.py
@@ -1,23 +1,18 @@
import os
-import random
-import re
import shutil
import tempfile
import unittest
-import warnings
from typing import List, Tuple
import numpy as np
-import torch
import torchvision
from PIL.Image import Image
from lightly.data import LightlyDataset
from lightly.data._utils import check_images
-from lightly.utils.io import INVALID_FILENAME_CHARACTERS
try:
- import av
+ import av as _
import cv2
from lightly.data._video import VideoDataset
@@ -137,24 +132,6 @@ def test_create_lightly_dataset_from_folder_nosubdir(self):
for i in range(n_tot):
sample, target, fname = dataset[i]
- def test_create_lightly_dataset_with_invalid_char_in_filename(self):
- # create a dataset
- n_tot = 100
- dataset = torchvision.datasets.FakeData(size=n_tot, image_size=(3, 32, 32))
-
- for invalid_char in INVALID_FILENAME_CHARACTERS:
- with self.subTest(msg=f"invalid_char: {invalid_char}"):
- tmp_dir = tempfile.mkdtemp()
- sample_names = [f"img_,_{i}.jpg" for i in range(n_tot)]
- for sample_idx in range(n_tot):
- data = dataset[sample_idx]
- path = os.path.join(tmp_dir, sample_names[sample_idx])
- data[0].save(path)
-
- # create lightly dataset
- with self.assertRaises(ValueError):
- dataset = LightlyDataset(input_dir=tmp_dir)
-
def test_check_images(self):
# create a dataset
tmp_dir = tempfile.mkdtemp()
diff --git a/tests/embedding/test_embedding.py b/tests/embedding/test_embedding.py
index c6c6d1baa..d6f4bcb8e 100644
--- a/tests/embedding/test_embedding.py
+++ b/tests/embedding/test_embedding.py
@@ -66,8 +66,8 @@ def test_embed_correct_order(self):
device=device,
)
- np.testing.assert_allclose(embeddings_1_worker, embeddings_4_worker, rtol=5e-5)
- np.testing.assert_allclose(labels_1_worker, labels_4_worker, rtol=1e-5)
+ np.testing.assert_allclose(embeddings_1_worker, embeddings_4_worker, atol=5e-4)
+ np.testing.assert_allclose(labels_1_worker, labels_4_worker, atol=1e-5)
self.assertListEqual(filenames_1_worker, filenames_4_worker)
self.assertListEqual(filenames_1_worker, dataset.get_filenames())
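
The tolerance switch above is meaningful rather than cosmetic: assert_allclose checks |actual - desired| <= atol + rtol * |desired|, so a purely relative tolerance degenerates to zero for near-zero reference values (common in embedding dimensions), while atol bounds the absolute error directly:

import numpy as np

np.testing.assert_allclose(1e-9, 0.0, atol=5e-4)  # passes: absolute bound
# np.testing.assert_allclose(1e-9, 0.0, rtol=5e-5)  # fails: rtol * |0| == 0
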
diff --git a/tests/loss/test_PMSNLoss.py b/tests/loss/test_PMSNLoss.py
index 4919e965b..1ddc8b84e 100644
--- a/tests/loss/test_PMSNLoss.py
+++ b/tests/loss/test_PMSNLoss.py
@@ -12,7 +12,7 @@
class TestPMSNLoss:
def test_regularization_loss(self) -> None:
criterion = PMSNLoss()
- mean_anchor_probs = torch.Tensor([0.1, 0.3, 0.6]).log()
+ mean_anchor_probs = torch.Tensor([0.1, 0.3, 0.6])
loss = criterion.regularization_loss(mean_anchor_probs=mean_anchor_probs)
norm = 1 / (1**0.25) + 1 / (2**0.25) + 1 / (3**0.25)
t0 = 1 / (1**0.25) / norm
@@ -45,7 +45,7 @@ def test_forward_cuda(self) -> None:
class TestPMSNCustomLoss:
def test_regularization_loss(self) -> None:
criterion = PMSNCustomLoss(target_distribution=_uniform_distribution)
- mean_anchor_probs = torch.Tensor([0.1, 0.3, 0.6]).log()
+ mean_anchor_probs = torch.Tensor([0.1, 0.3, 0.6])
loss = criterion.regularization_loss(mean_anchor_probs=mean_anchor_probs)
expected_loss = (
1
diff --git a/tests/loss/test_barlow_twins_loss.py b/tests/loss/test_barlow_twins_loss.py
index f947ee784..c2853042b 100644
--- a/tests/loss/test_barlow_twins_loss.py
+++ b/tests/loss/test_barlow_twins_loss.py
@@ -1,10 +1,51 @@
import pytest
+import torch
+import torch.nn as nn
from pytest_mock import MockerFixture
from torch import distributed as dist
from lightly.loss.barlow_twins_loss import BarlowTwinsLoss
+class BarlowTwinsLossReference(torch.nn.Module):
+ def __init__(
+ self,
+ projector_dim: int = 8196,
+ lambda_param: float = 5e-3,
+ gather_distributed: bool = False,
+ ):
+ super(BarlowTwinsLossReference, self).__init__()
+ # normalization layer for the representations z1 and z2
+ self.bn = nn.BatchNorm1d(projector_dim, affine=False)
+ self.lambda_param = lambda_param
+ self.gather_distributed = gather_distributed
+
+ def forward(self, z_a: torch.Tensor, z_b: torch.Tensor) -> torch.Tensor:
+ # code from https://github.com/facebookresearch/barlowtwins/blob/main/main.py
+
+ N = z_a.size(0)
+
+ # empirical cross-correlation matrix
+ c = self.bn(z_a).T @ self.bn(z_b)
+
+ # sum the cross-correlation matrix between all gpus
+ c.div_(N)
+ if self.gather_distributed:
+ torch.distributed.all_reduce(c)
+
+ on_diag = torch.diagonal(c).add_(-1).pow_(2).sum()
+ off_diag = off_diagonal(c).pow_(2).sum()
+ loss = on_diag + self.lambda_param * off_diag
+ return loss
+
+
+def off_diagonal(x):
+ # return a flattened view of the off-diagonal elements of a square matrix
+ n, m = x.shape
+ assert n == m
+ return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()
+
+
class TestBarlowTwinsLoss:
def test__gather_distributed(self, mocker: MockerFixture) -> None:
mock_is_available = mocker.patch.object(dist, "is_available", return_value=True)
@@ -20,3 +61,25 @@ def test__gather_distributed_dist_not_available(
with pytest.raises(ValueError):
BarlowTwinsLoss(gather_distributed=True)
mock_is_available.assert_called_once()
+
+ def test__loss_matches_reference_loss(self) -> None:
+ batch_size = 32
+ projector_dim = 8196
+ lambda_param = 5e-3
+ gather_distributed = False
+ loss = BarlowTwinsLoss(
+ lambda_param=lambda_param, gather_distributed=gather_distributed
+ )
+ loss_ref = BarlowTwinsLossReference(
+ projector_dim=projector_dim,
+ lambda_param=lambda_param,
+ gather_distributed=gather_distributed,
+ )
+
+ z_a = torch.randn(batch_size, projector_dim)
+ z_b = torch.randn(batch_size, projector_dim)
+
+ loss_out = loss(z_a, z_b)
+ loss_ref_out = loss_ref(z_a, z_b)
+
+ assert torch.allclose(loss_out, loss_ref_out, rtol=1e-3, atol=1e-3)
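
A worked example of the off_diagonal indexing trick used by the reference loss: flattening an n x n matrix and dropping the last element leaves n*n - 1 values, and viewing them as (n - 1) rows of (n + 1) columns pushes every remaining diagonal entry into column 0, so [:, 1:] keeps exactly the off-diagonal elements:

import torch

x = torch.arange(9).reshape(3, 3)  # diagonal: 0, 4, 8
off = x.flatten()[:-1].view(2, 4)[:, 1:].flatten()
assert off.tolist() == [1, 2, 3, 5, 6, 7]  # all six off-diagonal entries
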
diff --git a/tests/transforms/test_Solarize.py b/tests/transforms/test_Solarize.py
index 4a1cda349..1b595caf3 100644
--- a/tests/transforms/test_Solarize.py
+++ b/tests/transforms/test_Solarize.py
@@ -6,7 +6,7 @@
class TestRandomSolarization(unittest.TestCase):
- def test_on_pil_image(self):
+ def test_on_pil_image(self) -> None:
for w in [32, 64, 128]:
for h in [32, 64, 128]:
solarization = RandomSolarization(0.5)
diff --git a/tests/transforms/test_byol_transform.py b/tests/transforms/test_byol_transform.py
new file mode 100644
index 000000000..ecf72df48
--- /dev/null
+++ b/tests/transforms/test_byol_transform.py
@@ -0,0 +1,26 @@
+from PIL import Image
+
+from lightly.transforms.byol_transform import (
+ BYOLTransform,
+ BYOLView1Transform,
+ BYOLView2Transform,
+)
+
+
+def test_view_on_pil_image() -> None:
+ single_view_transform = BYOLView1Transform(input_size=32)
+ sample = Image.new("RGB", (100, 100))
+ output = single_view_transform(sample)
+ assert output.shape == (3, 32, 32)
+
+
+def test_multi_view_on_pil_image() -> None:
+ multi_view_transform = BYOLTransform(
+ view_1_transform=BYOLView1Transform(input_size=32),
+ view_2_transform=BYOLView2Transform(input_size=32),
+ )
+ sample = Image.new("RGB", (100, 100))
+ output = multi_view_transform(sample)
+ assert len(output) == 2
+ assert output[0].shape == (3, 32, 32)
+ assert output[1].shape == (3, 32, 32)
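
BYOLTransform is effectively a two-view special case of the MultiViewTransform tested at the end of this diff; a sketch of composing the equivalent pipeline by hand (assuming MultiViewTransform accepts a transforms sequence):

from lightly.transforms.byol_transform import BYOLView1Transform, BYOLView2Transform
from lightly.transforms.multi_view_transform import MultiViewTransform

two_view_transform = MultiViewTransform(
    transforms=[BYOLView1Transform(input_size=32), BYOLView2Transform(input_size=32)]
)
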
diff --git a/tests/transforms/test_dino_transform.py b/tests/transforms/test_dino_transform.py
index 4ca06c721..74bfea478 100644
--- a/tests/transforms/test_dino_transform.py
+++ b/tests/transforms/test_dino_transform.py
@@ -3,14 +3,14 @@
from lightly.transforms.dino_transform import DINOTransform, DINOViewTransform
-def test_view_on_pil_image():
+def test_view_on_pil_image() -> None:
single_view_transform = DINOViewTransform(crop_size=32)
sample = Image.new("RGB", (100, 100))
output = single_view_transform(sample)
assert output.shape == (3, 32, 32)
-def test_multi_view_on_pil_image():
+def test_multi_view_on_pil_image() -> None:
multi_view_transform = DINOTransform(global_crop_size=32, local_crop_size=8)
sample = Image.new("RGB", (100, 100))
output = multi_view_transform(sample)
diff --git a/tests/transforms/test_fastsiam_transform.py b/tests/transforms/test_fastsiam_transform.py
index a5f60dfdf..672cf0a41 100644
--- a/tests/transforms/test_fastsiam_transform.py
+++ b/tests/transforms/test_fastsiam_transform.py
@@ -3,7 +3,7 @@
from lightly.transforms.fast_siam_transform import FastSiamTransform
-def test_multi_view_on_pil_image():
+def test_multi_view_on_pil_image() -> None:
multi_view_transform = FastSiamTransform(num_views=3, input_size=32)
sample = Image.new("RGB", (100, 100))
output = multi_view_transform(sample)
diff --git a/tests/transforms/test_GaussianBlur.py b/tests/transforms/test_gaussian_blur.py
similarity index 69%
rename from tests/transforms/test_GaussianBlur.py
rename to tests/transforms/test_gaussian_blur.py
index a778f15e4..fae09c674 100644
--- a/tests/transforms/test_GaussianBlur.py
+++ b/tests/transforms/test_gaussian_blur.py
@@ -2,21 +2,21 @@
from PIL import Image
-from lightly.transforms import GaussianBlur
+from lightly.transforms.gaussian_blur import GaussianBlur
class TestGaussianBlur(unittest.TestCase):
- def test_on_pil_image(self):
+ def test_on_pil_image(self) -> None:
for w in range(1, 100):
for h in range(1, 100):
gaussian_blur = GaussianBlur()
sample = Image.new("RGB", (w, h))
gaussian_blur(sample)
- def test_raise_kernel_size_deprecation(self):
+ def test_raise_kernel_size_deprecation(self) -> None:
gaussian_blur = GaussianBlur(kernel_size=2)
self.assertWarns(DeprecationWarning)
- def test_raise_scale_deprecation(self):
+ def test_raise_scale_deprecation(self) -> None:
gaussian_blur = GaussianBlur(scale=0.1)
self.assertWarns(DeprecationWarning)
diff --git a/tests/transforms/test_Jigsaw.py b/tests/transforms/test_jigsaw.py
similarity index 66%
rename from tests/transforms/test_Jigsaw.py
rename to tests/transforms/test_jigsaw.py
index 964e738c5..6482ee109 100644
--- a/tests/transforms/test_Jigsaw.py
+++ b/tests/transforms/test_jigsaw.py
@@ -2,11 +2,11 @@
from PIL import Image
-from lightly.transforms import Jigsaw
+from lightly.transforms.jigsaw import Jigsaw
class TestJigsaw(unittest.TestCase):
- def test_on_pil_image(self):
+ def test_on_pil_image(self) -> None:
crop = Jigsaw()
sample = Image.new("RGB", (255, 255))
crop(sample)
diff --git a/tests/transforms/test_location_to_NxN_grid.py b/tests/transforms/test_location_to_NxN_grid.py
index 2ec1beb5b..013d4ab8f 100644
--- a/tests/transforms/test_location_to_NxN_grid.py
+++ b/tests/transforms/test_location_to_NxN_grid.py
@@ -3,7 +3,7 @@
import lightly.transforms.random_crop_and_flip_with_grid as test_module
-def test_location_to_NxN_grid():
+def test_location_to_NxN_grid() -> None:
# create a test instance of the Location class
test_location = test_module.Location(
left=10,
diff --git a/tests/transforms/test_mae_transform.py b/tests/transforms/test_mae_transform.py
index aafa11cdf..6f9b928c1 100644
--- a/tests/transforms/test_mae_transform.py
+++ b/tests/transforms/test_mae_transform.py
@@ -3,7 +3,7 @@
from lightly.transforms.mae_transform import MAETransform
-def test_multi_view_on_pil_image():
+def test_multi_view_on_pil_image() -> None:
multi_view_transform = MAETransform(input_size=32)
sample = Image.new("RGB", (100, 100))
output = multi_view_transform(sample)
diff --git a/tests/transforms/test_moco_transform.py b/tests/transforms/test_moco_transform.py
index eef1c651f..aa43a216f 100644
--- a/tests/transforms/test_moco_transform.py
+++ b/tests/transforms/test_moco_transform.py
@@ -3,7 +3,7 @@
from lightly.transforms.moco_transform import MoCoV1Transform, MoCoV2Transform
-def test_moco_v1_multi_view_on_pil_image():
+def test_moco_v1_multi_view_on_pil_image() -> None:
multi_view_transform = MoCoV1Transform(input_size=32)
sample = Image.new("RGB", (100, 100))
output = multi_view_transform(sample)
@@ -12,7 +12,7 @@ def test_moco_v1_multi_view_on_pil_image():
assert output[1].shape == (3, 32, 32)
-def test_moco_v2_multi_view_on_pil_image():
+def test_moco_v2_multi_view_on_pil_image() -> None:
multi_view_transform = MoCoV2Transform(input_size=32)
sample = Image.new("RGB", (100, 100))
output = multi_view_transform(sample)
diff --git a/tests/transforms/test_msn_transform.py b/tests/transforms/test_msn_transform.py
index fd2030bab..4f5be7b53 100644
--- a/tests/transforms/test_msn_transform.py
+++ b/tests/transforms/test_msn_transform.py
@@ -3,14 +3,14 @@
from lightly.transforms.msn_transform import MSNTransform, MSNViewTransform
-def test_view_on_pil_image():
+def test_view_on_pil_image() -> None:
single_view_transform = MSNViewTransform(crop_size=32)
sample = Image.new("RGB", (100, 100))
output = single_view_transform(sample)
assert output.shape == (3, 32, 32)
-def test_multi_view_on_pil_image():
+def test_multi_view_on_pil_image() -> None:
multi_view_transform = MSNTransform(random_size=32, focal_size=8)
sample = Image.new("RGB", (100, 100))
output = multi_view_transform(sample)
diff --git a/tests/transforms/test_multi_view_transform.py b/tests/transforms/test_multi_view_transform.py
index dddd4db3d..27806b92a 100644
--- a/tests/transforms/test_multi_view_transform.py
+++ b/tests/transforms/test_multi_view_transform.py
@@ -6,7 +6,7 @@
from lightly.transforms.multi_view_transform import MultiViewTransform
-def test_multi_view_on_pil_image():
+def test_multi_view_on_pil_image() -> None:
multi_view_transform = MultiViewTransform(
[
T.RandomHorizontalFlip(p=0.1),
diff --git a/tests/transforms/test_pirl_transform.py b/tests/transforms/test_pirl_transform.py
index 5042a1e8f..20c7c8705 100644
--- a/tests/transforms/test_pirl_transform.py
+++ b/tests/transforms/test_pirl_transform.py
@@ -3,7 +3,7 @@
from lightly.transforms.pirl_transform import PIRLTransform
-def test_multi_view_on_pil_image():
+def test_multi_view_on_pil_image() -> None:
multi_view_transform = PIRLTransform(input_size=32)
sample = Image.new("RGB", (100, 100))
output = multi_view_transform(sample)
diff --git a/tests/transforms/test_rotation.py b/tests/transforms/test_rotation.py
index 448858fdd..065b8dd1d 100644
--- a/tests/transforms/test_rotation.py
+++ b/tests/transforms/test_rotation.py
@@ -1,3 +1,5 @@
+from typing import List, Tuple, Union
+
from PIL import Image
from lightly.transforms.rotation import (
@@ -7,20 +9,21 @@
)
-def test_RandomRotate_on_pil_image():
+def test_RandomRotate_on_pil_image() -> None:
random_rotate = RandomRotate()
sample = Image.new("RGB", (100, 100))
random_rotate(sample)
-def test_RandomRotateDegrees_on_pil_image():
- for degrees in [0, 1, 45, (0, 0), (-15, 30)]:
+def test_RandomRotateDegrees_on_pil_image() -> None:
+ all_degrees: List[Union[float, Tuple[float, float]]] = [0, 1, 45, (0, 0), (-15, 30)]
+ for degrees in all_degrees:
random_rotate = RandomRotateDegrees(prob=0.5, degrees=degrees)
sample = Image.new("RGB", (100, 100))
random_rotate(sample)
-def test_random_rotation_transform():
+def test_random_rotation_transform() -> None:
transform = random_rotation_transform(rr_prob=1.0, rr_degrees=None)
assert isinstance(transform, RandomRotate)
transform = random_rotation_transform(rr_prob=1.0, rr_degrees=45)
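
The rotation helpers now carry explicit degree typing: RandomRotateDegrees accepts either a single angle or a (min, max) range, and random_rotation_transform dispatches on rr_degrees, returning RandomRotate when it is None. A minimal sketch, not part of this diff, using only the calls exercised above:

    # Sketch only: the rotation transforms exercised above.
    from PIL import Image

    from lightly.transforms.rotation import (
        RandomRotate,
        RandomRotateDegrees,
        random_rotation_transform,
    )

    image = Image.new("RGB", (100, 100))

    RandomRotate()(image)  # fixed-angle rotation, applied with some probability
    RandomRotateDegrees(prob=0.5, degrees=(-15, 30))(image)  # random angle in range

    transform = random_rotation_transform(rr_prob=1.0, rr_degrees=None)
    assert isinstance(transform, RandomRotate)
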
diff --git a/tests/transforms/test_simclr_transform.py b/tests/transforms/test_simclr_transform.py
index 70fff7ab4..78a9a5cca 100644
--- a/tests/transforms/test_simclr_transform.py
+++ b/tests/transforms/test_simclr_transform.py
@@ -3,14 +3,14 @@
from lightly.transforms.simclr_transform import SimCLRTransform, SimCLRViewTransform
-def test_view_on_pil_image():
+def test_view_on_pil_image() -> None:
single_view_transform = SimCLRViewTransform(input_size=32)
sample = Image.new("RGB", (100, 100))
output = single_view_transform(sample)
assert output.shape == (3, 32, 32)
-def test_multi_view_on_pil_image():
+def test_multi_view_on_pil_image() -> None:
multi_view_transform = SimCLRTransform(input_size=32)
sample = Image.new("RGB", (100, 100))
output = multi_view_transform(sample)
diff --git a/tests/transforms/test_simsiam_transform.py b/tests/transforms/test_simsiam_transform.py
index 2444924ec..39a88721a 100644
--- a/tests/transforms/test_simsiam_transform.py
+++ b/tests/transforms/test_simsiam_transform.py
@@ -3,14 +3,14 @@
from lightly.transforms.simsiam_transform import SimSiamTransform, SimSiamViewTransform
-def test_view_on_pil_image():
+def test_view_on_pil_image() -> None:
single_view_transform = SimSiamViewTransform(input_size=32)
sample = Image.new("RGB", (100, 100))
output = single_view_transform(sample)
assert output.shape == (3, 32, 32)
-def test_multi_view_on_pil_image():
+def test_multi_view_on_pil_image() -> None:
multi_view_transform = SimSiamTransform(input_size=32)
sample = Image.new("RGB", (100, 100))
output = multi_view_transform(sample)
diff --git a/tests/transforms/test_smog_transform.py b/tests/transforms/test_smog_transform.py
index 49fed878b..042d46f9f 100644
--- a/tests/transforms/test_smog_transform.py
+++ b/tests/transforms/test_smog_transform.py
@@ -3,14 +3,14 @@
from lightly.transforms.smog_transform import SMoGTransform, SmoGViewTransform
-def test_view_on_pil_image():
+def test_view_on_pil_image() -> None:
single_view_transform = SmoGViewTransform(crop_size=32)
sample = Image.new("RGB", (100, 100))
output = single_view_transform(sample)
assert output.shape == (3, 32, 32)
-def test_multi_view_on_pil_image():
+def test_multi_view_on_pil_image() -> None:
multi_view_transform = SMoGTransform(crop_sizes=(32, 8))
sample = Image.new("RGB", (100, 100))
output = multi_view_transform(sample)
diff --git a/tests/transforms/test_swav_transform.py b/tests/transforms/test_swav_transform.py
index 05475ce6a..7c2cdd2c0 100644
--- a/tests/transforms/test_swav_transform.py
+++ b/tests/transforms/test_swav_transform.py
@@ -3,14 +3,14 @@
from lightly.transforms.swav_transform import SwaVTransform, SwaVViewTransform
-def test_view_on_pil_image():
+def test_view_on_pil_image() -> None:
single_view_transform = SwaVViewTransform()
sample = Image.new("RGB", (100, 100))
output = single_view_transform(sample)
assert output.shape == (3, 100, 100)
-def test_multi_view_on_pil_image():
+def test_multi_view_on_pil_image() -> None:
multi_view_transform = SwaVTransform(crop_sizes=(32, 8))
sample = Image.new("RGB", (100, 100))
output = multi_view_transform(sample)
diff --git a/tests/transforms/test_vicreg_transform.py b/tests/transforms/test_vicreg_transform.py
index 5a2b0633d..06e710f25 100644
--- a/tests/transforms/test_vicreg_transform.py
+++ b/tests/transforms/test_vicreg_transform.py
@@ -3,14 +3,14 @@
from lightly.transforms.vicreg_transform import VICRegTransform, VICRegViewTransform
-def test_view_on_pil_image():
+def test_view_on_pil_image() -> None:
single_view_transform = VICRegViewTransform(input_size=32)
sample = Image.new("RGB", (100, 100))
output = single_view_transform(sample)
assert output.shape == (3, 32, 32)
-def test_multi_view_on_pil_image():
+def test_multi_view_on_pil_image() -> None:
multi_view_transform = VICRegTransform(input_size=32)
sample = Image.new("RGB", (100, 100))
output = multi_view_transform(sample)
diff --git a/tests/transforms/test_vicregl_transform.py b/tests/transforms/test_vicregl_transform.py
index f4696d067..e697807c4 100644
--- a/tests/transforms/test_vicregl_transform.py
+++ b/tests/transforms/test_vicregl_transform.py
@@ -3,14 +3,14 @@
from lightly.transforms.vicregl_transform import VICRegLTransform, VICRegLViewTransform
-def test_view_on_pil_image():
+def test_view_on_pil_image() -> None:
single_view_transform = VICRegLViewTransform()
sample = Image.new("RGB", (100, 100))
output = single_view_transform(sample)
assert output.shape == (3, 100, 100)
-def test_multi_view_on_pil_image():
+def test_multi_view_on_pil_image() -> None:
multi_view_transform = VICRegLTransform(
global_crop_size=32,
local_crop_size=8,
diff --git a/tests/utils/test_dist.py b/tests/utils/test_dist.py
index 79eb16013..8d3a564a9 100644
--- a/tests/utils/test_dist.py
+++ b/tests/utils/test_dist.py
@@ -2,6 +2,7 @@
from unittest import mock
import torch
+from pytest import CaptureFixture
from lightly.utils import dist
@@ -31,3 +32,25 @@ def test_eye_rank_dist(self):
expected.append(zeros)
expected = torch.cat(expected, dim=1)
self.assertTrue(torch.all(dist.eye_rank(n) == expected))
+
+
+def test_rank_zero_only__rank_0() -> None:
+ @dist.rank_zero_only
+ def fn():
+ return 0
+
+ assert fn() == 0
+
+
+def test_rank_zero_only__rank_1() -> None:
+ @dist.rank_zero_only
+ def fn():
+ return 0
+
+ with mock.patch.object(dist, "rank", lambda: 1):
+ assert fn() is None
+
+
+def test_print_rank_zero(capsys: CaptureFixture[str]) -> None:
+ dist.print_rank_zero("message")
+ assert capsys.readouterr().out == "message\n"
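
The decorator tests above fix the contract of dist.rank_zero_only: the wrapped function executes and returns its value on rank 0 and returns None on every other rank, while dist.print_rank_zero prints a message exactly once per job. A minimal sketch, not part of this diff (the checkpoint function is hypothetical):

    # Sketch only: rank-zero helpers in a multi-process training script.
    from lightly.utils import dist


    @dist.rank_zero_only
    def save_checkpoint(path: str) -> str:
        # Body runs on rank 0 only; other ranks get None back, so file
        # writes and log lines happen exactly once per job.
        return path


    dist.print_rank_zero("starting training")  # printed once, not per rank
    result = save_checkpoint("checkpoint.ckpt")  # None on non-zero ranks
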
diff --git a/tests/utils/test_io.py b/tests/utils/test_io.py
index 988b622f1..04086be5f 100644
--- a/tests/utils/test_io.py
+++ b/tests/utils/test_io.py
@@ -1,62 +1,35 @@
import csv
import json
-import sys
import tempfile
import unittest
+from pathlib import Path
import numpy as np
-from lightly.utils.io import (
- check_embeddings,
- check_filenames,
- save_custom_metadata,
- save_embeddings,
- save_schema,
- save_tasks,
-)
-from tests.api_workflow.mocked_api_workflow_client import (
- MockedApiWorkflowClient,
- MockedApiWorkflowSetup,
-)
-
-
-class TestCLICrop(MockedApiWorkflowSetup):
- def test_save_metadata(self):
+from lightly.utils import io
+from tests.api_workflow.mocked_api_workflow_client import MockedApiWorkflowSetup
+
+
+class TestCLICrop(MockedApiWorkflowSetup): # type: ignore[misc]
+ def test_save_metadata(self) -> None:
metadata = [("filename.jpg", {"random_metadata": 42})]
metadata_filepath = tempfile.mktemp(".json", "metadata")
- save_custom_metadata(metadata_filepath, metadata)
-
- def test_valid_filenames(self):
- valid = "img.png"
- non_valid = "img,1.png"
- filenames_list = [
- ([valid], True),
- ([valid, valid], True),
- ([non_valid], False),
- ([valid, non_valid], False),
- ]
- for filenames, valid in filenames_list:
- with self.subTest(msg=f"filenames:{filenames}"):
- if valid:
- check_filenames(filenames)
- else:
- with self.assertRaises(ValueError):
- check_filenames(filenames)
+ io.save_custom_metadata(metadata_filepath, metadata)
class TestEmbeddingsIO(unittest.TestCase):
- def setUp(self):
+ def setUp(self) -> None:
# correct embedding file as created through lightly
self.embeddings_path = tempfile.mktemp(".csv", "embeddings")
embeddings = np.random.rand(32, 2)
labels = [0 for i in range(len(embeddings))]
filenames = [f"img_{i}.jpg" for i in range(len(embeddings))]
- save_embeddings(self.embeddings_path, embeddings, labels, filenames)
+ io.save_embeddings(self.embeddings_path, embeddings, labels, filenames)
- def test_valid_embeddings(self):
- check_embeddings(self.embeddings_path)
+ def test_valid_embeddings(self) -> None:
+ io.check_embeddings(self.embeddings_path)
- def test_whitespace_in_embeddings(self):
+ def test_whitespace_in_embeddings(self) -> None:
# should fail because there are whitespaces in the header columns
lines = [
"filenames, embedding_0,embedding_1,labels\n",
@@ -65,19 +38,19 @@ def test_whitespace_in_embeddings(self):
with open(self.embeddings_path, "w") as f:
f.writelines(lines)
with self.assertRaises(RuntimeError) as context:
- check_embeddings(self.embeddings_path)
+ io.check_embeddings(self.embeddings_path)
self.assertTrue("must not contain whitespaces" in str(context.exception))
- def test_no_labels_in_embeddings(self):
+ def test_no_labels_in_embeddings(self) -> None:
# should fail because there is no `labels` column in the header
lines = ["filenames,embedding_0,embedding_1\n", "img_1.jpg,0.351,0.1231"]
with open(self.embeddings_path, "w") as f:
f.writelines(lines)
with self.assertRaises(RuntimeError) as context:
- check_embeddings(self.embeddings_path)
+ io.check_embeddings(self.embeddings_path)
self.assertTrue("has no `labels` column" in str(context.exception))
- def test_no_empty_rows_in_embeddings(self):
+ def test_no_empty_rows_in_embeddings(self) -> None:
# should fail because there are empty rows in the embeddings file
lines = [
"filenames,embedding_0,embedding_1,labels\n",
@@ -86,10 +59,10 @@ def test_no_empty_rows_in_embeddings(self):
with open(self.embeddings_path, "w") as f:
f.writelines(lines)
with self.assertRaises(RuntimeError) as context:
- check_embeddings(self.embeddings_path)
+ io.check_embeddings(self.embeddings_path)
self.assertTrue("must not have empty rows" in str(context.exception))
- def test_embeddings_extra_rows(self):
+ def test_embeddings_extra_rows(self) -> None:
rows = [
["filenames", "embedding_0", "embedding_1", "labels", "selected", "masked"],
["image_0.jpg", "3.4", "0.23", "0", "1", "0"],
@@ -99,14 +72,14 @@ def test_embeddings_extra_rows(self):
csv_writer = csv.writer(f)
csv_writer.writerows(rows)
- check_embeddings(self.embeddings_path, remove_additional_columns=True)
+ io.check_embeddings(self.embeddings_path, remove_additional_columns=True)
with open(self.embeddings_path) as csv_file:
csv_reader = csv.reader(csv_file, delimiter=",")
for row_read, row_original in zip(csv_reader, rows):
self.assertListEqual(row_read, row_original[:-2])
- def test_embeddings_extra_rows_special_order(self):
+ def test_embeddings_extra_rows_special_order(self) -> None:
input_rows = [
["filenames", "embedding_0", "embedding_1", "masked", "labels", "selected"],
["image_0.jpg", "3.4", "0.23", "0", "1", "0"],
@@ -121,26 +94,26 @@ def test_embeddings_extra_rows_special_order(self):
csv_writer = csv.writer(f)
csv_writer.writerows(input_rows)
- check_embeddings(self.embeddings_path, remove_additional_columns=True)
+ io.check_embeddings(self.embeddings_path, remove_additional_columns=True)
with open(self.embeddings_path) as csv_file:
csv_reader = csv.reader(csv_file, delimiter=",")
for row_read, row_original in zip(csv_reader, correct_output_rows):
self.assertListEqual(row_read, row_original)
- def test_save_tasks(self):
+ def test_save_tasks(self) -> None:
tasks = [
"task1",
"task2",
"task3",
]
with tempfile.NamedTemporaryFile(suffix=".json") as file:
- save_tasks(file.name, tasks)
+ io.save_tasks(file.name, tasks)
with open(file.name, "r") as f:
loaded = json.load(f)
self.assertListEqual(tasks, loaded)
- def test_save_schema(self):
+ def test_save_schema(self) -> None:
description = "classification"
ids = [1, 2, 3, 4]
names = ["name1", "name2", "name3", "name4"]
@@ -154,16 +127,56 @@ def test_save_schema(self):
],
}
with tempfile.NamedTemporaryFile(suffix=".json") as file:
- save_schema(file.name, description, ids, names)
+ io.save_schema(file.name, description, ids, names)
with open(file.name, "r") as f:
loaded = json.load(f)
self.assertListEqual(sorted(expected_format), sorted(loaded))
- def test_save_schema_different(self):
+ def test_save_schema_different(self) -> None:
with self.assertRaises(ValueError):
- save_schema(
+ io.save_schema(
"name_doesnt_matter",
"description_doesnt_matter",
[1, 2],
["name1"],
)
+
+
+def test_save_and_load_embeddings(tmp_path: Path) -> None:
+ embeddings = np.random.rand(2, 32)
+ labels = [0, 1]
+ filenames = ["img_1.jpg", "img_2.jpg"]
+
+ io.save_embeddings(
+ path=str(tmp_path / "embeddings.csv"),
+ embeddings=embeddings,
+ labels=labels,
+ filenames=filenames,
+ )
+
+ loaded_embeddings, loaded_labels, loaded_filenames = io.load_embeddings(
+ path=str(tmp_path / "embeddings.csv")
+ )
+ assert np.allclose(embeddings, loaded_embeddings)
+ assert labels == loaded_labels
+ assert filenames == loaded_filenames
+
+
+def test_save_and_load_embeddings__filename_with_comma(tmp_path: Path) -> None:
+ embeddings = np.random.rand(4, 32)
+ labels = [0, 1, 2, 3]
+ filenames = ["img,1.jpg", '",img,.jpg', ',"img".jpg', ',"img\n".jpg']
+
+ io.save_embeddings(
+ path=str(tmp_path / "embeddings.csv"),
+ embeddings=embeddings,
+ labels=labels,
+ filenames=filenames,
+ )
+
+ loaded_embeddings, loaded_labels, loaded_filenames = io.load_embeddings(
+ path=str(tmp_path / "embeddings.csv")
+ )
+ assert np.allclose(embeddings, loaded_embeddings)
+ assert labels == loaded_labels
+ assert filenames == loaded_filenames
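
The new round-trip tests document that io.save_embeddings and io.load_embeddings are inverses, with the CSV writer quoting filenames that contain commas, quotes, or newlines. A minimal sketch, not part of this diff (the output path is arbitrary):

    # Sketch only: embedding CSV round trip.
    import numpy as np

    from lightly.utils import io

    embeddings = np.random.rand(2, 32)  # one row per image
    labels = [0, 1]
    filenames = ["img_1.jpg", "img,2.jpg"]  # the comma survives CSV quoting

    io.save_embeddings("embeddings.csv", embeddings, labels, filenames)
    loaded_embeddings, loaded_labels, loaded_filenames = io.load_embeddings(
        "embeddings.csv"
    )
    assert np.allclose(embeddings, loaded_embeddings)
    assert labels == loaded_labels and filenames == loaded_filenames
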
diff --git a/tests/utils/test_scheduler.py b/tests/utils/test_scheduler.py
index 6a0bd569c..7beb6754f 100644
--- a/tests/utils/test_scheduler.py
+++ b/tests/utils/test_scheduler.py
@@ -7,7 +7,7 @@
class TestScheduler(unittest.TestCase):
- def test_cosine_schedule(self):
+ def test_cosine_schedule(self) -> None:
self.assertAlmostEqual(cosine_schedule(1, 10, 0.99, 1.0), 0.99030154, 6)
self.assertAlmostEqual(cosine_schedule(95, 100, 0.7, 2.0), 1.99477063, 6)
self.assertAlmostEqual(cosine_schedule(0, 1, 0.996, 1.0), 1.0, 6)
@@ -23,7 +23,15 @@ def test_cosine_schedule(self):
):
cosine_schedule(11, 10, 0.0, 1.0)
- def test_CosineWarmupScheduler(self):
+ def test_cosine_schedule__period(self) -> None:
+ self.assertAlmostEqual(cosine_schedule(0, 1, 0, 1.0, period=10), 0.0, 6)
+ self.assertAlmostEqual(cosine_schedule(3, 1, 0, 2.0, period=10), 1.30901706, 6)
+ self.assertAlmostEqual(cosine_schedule(10, 1, 0, 1.0, period=10), 0.0, 6)
+ self.assertAlmostEqual(cosine_schedule(15, 1, 0, 1.0, period=10), 1.0, 6)
+ with self.assertRaises(ValueError):
+ cosine_schedule(1, 10, 0.0, 1.0, period=-1)
+
+ def test_CosineWarmupScheduler(self) -> None:
model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(
model.parameters(), lr=1.0, momentum=0.0, weight_decay=0.0
@@ -62,3 +70,23 @@ def test_CosineWarmupScheduler(self):
RuntimeWarning, msg="Current step number 7 exceeds max_steps 6."
):
scheduler.step()
+
+ def test_CosineWarmupScheduler__warmup(self) -> None:
+ model = nn.Linear(10, 1)
+ optimizer = torch.optim.SGD(
+ model.parameters(), lr=1.0, momentum=0.0, weight_decay=0.0
+ )
+ scheduler = CosineWarmupScheduler(
+ optimizer,
+ warmup_epochs=3,
+ max_epochs=6,
+ start_value=2.0,
+ end_value=0.0,
+ )
+ # Linear warmup
+ self.assertAlmostEqual(scheduler.scale_lr(epoch=0), 2.0 * 1 / 3)
+ self.assertAlmostEqual(scheduler.scale_lr(epoch=1), 2.0 * 2 / 3)
+ self.assertAlmostEqual(scheduler.scale_lr(epoch=2), 2.0 * 3 / 3)
+ # Cosine decay
+ self.assertAlmostEqual(scheduler.scale_lr(epoch=3), 2.0 * 3 / 3)
+ self.assertLess(scheduler.scale_lr(epoch=4), 2.0)
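
The period argument added above decouples the cosine wavelength from max_steps, enabling cyclic schedules, and the warmup test documents that CosineWarmupScheduler ramps linearly to start_value before the cosine decay begins. A minimal sketch of both entry points, assuming only the signatures exercised by these tests:

    # Sketch only: the two scheduler entry points tested above.
    import torch
    from torch import nn

    from lightly.utils.scheduler import CosineWarmupScheduler, cosine_schedule

    # Functional form: value moves from 0.996 towards 1.0 over 100 steps.
    momentum = cosine_schedule(10, 100, 0.996, 1.0)

    # With period=10 the schedule repeats every 10 steps regardless of
    # max_steps; at step 15 (mid-cycle) the value reaches end_value.
    cyclic = cosine_schedule(15, 100, 0.0, 1.0, period=10)

    # Optimizer wrapper: 3 linear warmup epochs up to start_value, then
    # cosine decay towards end_value until max_epochs.
    model = nn.Linear(10, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
    scheduler = CosineWarmupScheduler(
        optimizer, warmup_epochs=3, max_epochs=6, start_value=2.0, end_value=0.0
    )
    scheduler.step()  # called once per epoch; rescales the optimizer's lr
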
diff --git a/tests/utils/test_version_compare.py b/tests/utils/test_version_compare.py
index ce39dbb6a..40d516ea0 100644
--- a/tests/utils/test_version_compare.py
+++ b/tests/utils/test_version_compare.py
@@ -4,7 +4,7 @@
class TestVersionCompare(unittest.TestCase):
- def test_valid_versions(self):
+ def test_valid_versions(self) -> None:
# general test of smaller-than comparisons between version numbers
self.assertEqual(version_compare.version_compare("0.1.4", "1.2.0"), -1)
self.assertEqual(version_compare.version_compare("1.1.0", "1.2.0"), -1)
@@ -16,7 +16,7 @@ def test_valid_versions(self):
# test equal
self.assertEqual(version_compare.version_compare("1.2.0", "1.2.0"), 0)
- def test_invalid_versions(self):
+ def test_invalid_versions(self) -> None:
with self.assertRaises(ValueError):
version_compare.version_compare("1.2", "1.1.0")
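
These tests fix the contract of version_compare: a negative result for an older first argument, zero for equal versions, and a ValueError for anything that is not a full three-component version string. A minimal sketch, not part of this diff:

    # Sketch only: semantics of lightly.utils.version_compare.
    from lightly.utils import version_compare

    assert version_compare.version_compare("1.1.0", "1.2.0") == -1  # older first
    assert version_compare.version_compare("1.2.0", "1.2.0") == 0   # equal

    try:
        version_compare.version_compare("1.2", "1.1.0")  # not N.N.N
    except ValueError:
        pass  # only full three-component version strings are accepted
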