name: Discord Release Notification

on:
  release:
    types: [published]

jobs:
  notify-discord:
    runs-on: ubuntu-latest
    steps:
      - name: Send Notification to Discord
        env:
          DISCORD_WEBHOOK: ${{ secrets.DISCORD_WEBHOOK }}
          # Release fields are free-form text. They are handed to the script as
          # environment variables instead of being expanded directly inside the
          # shell code, so quotes or shell metacharacters in a release
          # description cannot break (or inject commands into) the script.
          RELEASE_BODY: ${{ github.event.release.body }}
          RELEASE_NAME: ${{ github.event.release.name }}
          RELEASE_TAG: ${{ github.event.release.tag_name }}
          RELEASE_URL: ${{ github.event.release.html_url }}
        # We truncate the description at the models section (starting with ### Models)
        # to keep the message short.
        # The whole payload is assembled with jq so that every field (not only the
        # description) is escaped and the request body is always valid JSON.
        run: |
          DESCRIPTION=$(printf '%s' "$RELEASE_BODY" | awk '/### Models/{exit}1')
          jq -n -a \
            --arg tag "$RELEASE_TAG" \
            --arg title "$RELEASE_NAME" \
            --arg url "$RELEASE_URL" \
            --arg description "$DESCRIPTION" \
            '{
              username: "Lightly",
              avatar_url: "https://avatars.githubusercontent.com/u/50146475",
              content: ("Lightly " + $tag + " has been released!"),
              embeds: [
                {
                  title: $title,
                  url: $url,
                  color: 5814783,
                  description: $description
                }
              ]
            }' |
            curl -H "Content-Type: application/json" \
              -X POST \
              -d @- \
              "${DISCORD_WEBHOOK}"
run: | + make type-check diff --git a/.gitignore b/.gitignore index 031ae8795..83d4d1fa2 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,7 @@ lightning_logs/ **lightning_logs/ **/__MACOSX datasets/ +dist/ docs/source/tutorials/package/* docs/source/tutorials/platform/* docs/source/tutorials_source/platform/data diff --git a/Makefile b/Makefile index ead2c748d..fcb2ab17b 100644 --- a/Makefile +++ b/Makefile @@ -63,8 +63,15 @@ test: test-fast: pytest tests -# run format checks and tests -all-checks: format-check test +## check typing +type-check: + mypy lightly tests + +## run format checks +static-checks: format-check type-check + +## run format checks and tests +all-checks: static-checks test ## build source and wheel package dist: clean diff --git a/README.md b/README.md index 7938b21e8..321b09125 100644 --- a/README.md +++ b/README.md @@ -1,29 +1,31 @@ -![Lightly Logo](docs/logos/lightly_logo_crop.png) +![Lightly SSL self-supervised learning Logo](docs/logos/lightly_SSL_logo_crop.png) ![GitHub](https://img.shields.io/github/license/lightly-ai/lightly) ![Unit Tests](https://github.com/lightly-ai/lightly/workflows/Unit%20Tests/badge.svg) [![PyPI](https://img.shields.io/pypi/v/lightly)](https://pypi.org/project/lightly/) -[![Downloads](https://pepy.tech/badge/lightly)](https://pepy.tech/project/lightly) +[![Downloads](https://static.pepy.tech/badge/lightly)](https://pepy.tech/project/lightly) [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) +[![Discord](https://img.shields.io/discord/752876370337726585?logo=discord&logoColor=white&label=discord&color=7289da)](https://discord.gg/xvNJW94) -Lightly is a computer vision framework for self-supervised learning. -> We, at [Lightly](https://www.lightly.ai), are passionate engineers who want to make deep learning more efficient. 
That's why - together with our community - we want to popularize the use of self-supervised methods to understand and curate raw image data. Our solution can be applied before any data annotation step and the learned representations can be used to visualize and analyze datasets. This allows to select the best set of samples for model training through advanced filtering. +Lightly SSL is a computer vision framework for self-supervised learning. -- [Homepage](https://www.lightly.ai) -- [Web-App](https://app.lightly.ai) - [Documentation](https://docs.lightly.ai/self-supervised-learning/) -- [Lightly Solution Documentation (Lightly Worker & API)](https://docs.lightly.ai/) - [Github](https://github.com/lightly-ai/lightly) - [Discord](https://discord.gg/xvNJW94) (We have weekly paper sessions!) +We've also built a whole platform on top, with additional features for active learning +and [data curation](https://docs.lightly.ai/docs/what-is-lightly). If you're interested in the +Lightly Worker Solution to easily process millions of samples and run [powerful algorithms](https://docs.lightly.ai/docs/selection) +on your data, check out [lightly.ai](https://www.lightly.ai). It's free to get started! + ## Features -Lightly offers features like +This self-supervised learning framework offers the following features: -- Modular framework which exposes low-level building blocks such as loss functions and +- Modular framework, which exposes low-level building blocks such as loss functions and model heads. - Easy to use and written in a PyTorch like style. - Supports custom backbone models for self-supervised pre-training. @@ -66,17 +68,6 @@ Want to jump to the tutorials and see Lightly in action? 
- [Use Lightly with Custom Augmentations](https://docs.lightly.ai/self-supervised-learning/tutorials/package/tutorial_custom_augmentations.html) - [Pre-train a Detectron2 Backbone with Lightly](https://docs.lightly.ai/self-supervised-learning/tutorials/package/tutorial_pretrain_detectron2.html) -Tutorials for the Lightly Solution (Lightly Worker & API): - -- [General Docs of Lightly Solution](https://docs.lightly.ai) -- [Active Learning Using YOLOv7 and Comma10k](https://docs.lightly.ai/docs/active-learning-yolov7) -- [Active Learning for Driveable Area Segmentation Using Cityscapes](https://docs.lightly.ai/docs/active-learning-for-driveable-area-segmentation-using-cityscapes) -- [Active Learning for Transactions of Images](https://docs.lightly.ai/docs/active-learning-for-transactions-of-images) -- [Improving YOLOv8 using Active Learning on Videos](https://docs.lightly.ai/docs/active-learning-yolov8-video) -- [Assertion-based Active Learning with YOLOv8](https://docs.lightly.ai/docs/assertion-based-active-learning-tutorial) -- and more ... - - Community and partner projects: - [On-Device Deep Learning with Lightly on an ARM microcontroller](https://github.com/ARM-software/EndpointAI/tree/master/ProofOfConcepts/Vision/OpenMvMaskDefaults) @@ -105,9 +96,6 @@ pip3 install lightly We strongly recommend that you install Lightly in a dedicated virtualenv, to avoid conflicting with your system packages. -If you only want to install the API client without torch and torchvision dependencies -follow the docs on [how to install the Lightly Python Client](https://docs.lightly.ai/docs/install-lightly#install-the-lightly-python-client). - ### Lightly in Action @@ -274,75 +262,48 @@ We provide multi-GPU training examples with distributed gather and synchronized ## Benchmarks Implemented models and their performance on various datasets. Hyperparameters are not -tuned for maximum accuracy. For detailed results and more info about the benchmarks click +tuned for maximum accuracy. 
For detailed results and more information about the benchmarks click [here](https://docs.lightly.ai/self-supervised-learning/getting_started/benchmarks.html). -### Imagenet +### ImageNet1k + +[ImageNet1k benchmarks](https://docs.lightly.ai/self-supervised-learning/getting_started/benchmarks.html#imagenet1k) -> **Note**: Evaluation settings are based on these papers: -> * Linear: [SimCLR](https://arxiv.org/abs/2002.05709) -> * Finetune: [SimCLR](https://arxiv.org/abs/2002.05709) -> * KNN: [InstDisc](https://arxiv.org/abs/1805.01978) -> -> See the [benchmarking scripts](./benchmarks/imagenet/resnet50/) for details. +**Note**: Evaluation settings are based on these papers: + * Linear: [SimCLR](https://arxiv.org/abs/2002.05709) + * Finetune: [SimCLR](https://arxiv.org/abs/2002.05709) + * KNN: [InstDisc](https://arxiv.org/abs/1805.01978) + +See the [benchmarking scripts](./benchmarks/imagenet/resnet50/) for details. -| Model | Backbone | Batch Size | Epochs | Linear Top1 | Finetune Top1 | KNN Top1 | Tensorboard | Checkpoint | + +| Model | Backbone | Batch Size | Epochs | Linear Top1 | Finetune Top1 | kNN Top1 | Tensorboard | Checkpoint | |----------------|----------|------------|--------|-------------|---------------|----------|-------------|------------| +| BarlowTwins | Res50 | 256 | 100 | 62.9 | 72.6 | 45.6 | [link](https://tensorboard.dev/experiment/NxyNRiQsQjWZ82I9b0PvKg/) | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_barlowtwins_2023-08-18_00-11-03/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt) | | BYOL | Res50 | 256 | 100 | 62.4 | 74.0 | 45.6 | [link](https://tensorboard.dev/experiment/Z0iG2JLaTJe5nuBD7DK1bg) | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_byol_2023-07-10_10-37-32/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt) | | DINO | Res50 | 128 | 100 | 68.2 | 72.5 | 49.9 | [link](https://tensorboard.dev/experiment/DvKHX9sNSWWqDrRksllPLA) | 
[link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_dino_2023-06-06_13-59-48/pretrain/version_0/checkpoints/epoch%3D99-step%3D1000900.ckpt) | | SimCLR* | Res50 | 256 | 100 | 63.2 | 73.9 | 44.8 | [link](https://tensorboard.dev/experiment/Ugol97adQdezgcVibDYMMA) | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_simclr_2023-06-22_09-11-13/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt) | | SimCLR* + DCL | Res50 | 256 | 100 | 65.1 | 73.5 | 49.6 | [link](https://tensorboard.dev/experiment/k4ZonZ77QzmBkc0lXswQlg/) | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_dcl_2023-07-04_16-51-40/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt) | | SimCLR* + DCLW | Res50 | 256 | 100 | 64.5 | 73.2 | 48.5 | [link](https://tensorboard.dev/experiment/TrALnpwFQ4OkZV3uvaX7wQ/) | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_dclw_2023-07-07_14-57-13/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt) | | SwAV | Res50 | 256 | 100 | 67.2 | 75.4 | 49.5 | [link](https://tensorboard.dev/experiment/Ipx4Oxl5Qkqm5Sl5kWyKKg) | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_swav_2023-05-25_08-29-14/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt) | +| VICReg | Res50 | 256 | 100 | 63.0 | 73.7 | 46.3 | [link](https://tensorboard.dev/experiment/qH5uywJbTJSzgCEfxc7yUw) | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_vicreg_2023-09-11_10-53-08/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt) | *\*We use square root learning rate scaling instead of linear scaling as it yields -better results for smaller batch sizes. 
See Appendix B.1 in [SimCLR paper](https://arxiv.org/abs/2002.05709).* - - - -### ImageNette - -| Model | Backbone | Batch Size | Epochs | KNN Top1 | -|-------------|----------|------------|--------|----------| -| BarlowTwins | Res18 | 256 | 800 | 0.852 | -| BYOL | Res18 | 256 | 800 | 0.887 | -| DCL | Res18 | 256 | 800 | 0.861 | -| DCLW | Res18 | 256 | 800 | 0.865 | -| DINO | Res18 | 256 | 800 | 0.888 | -| FastSiam | Res18 | 256 | 800 | 0.873 | -| MAE | ViT-S | 256 | 800 | 0.610 | -| MSN | ViT-S | 256 | 800 | 0.828 | -| Moco | Res18 | 256 | 800 | 0.874 | -| NNCLR | Res18 | 256 | 800 | 0.884 | -| PMSN | ViT-S | 256 | 800 | 0.822 | -| SimCLR | Res18 | 256 | 800 | 0.889 | -| SimMIM | ViT-B32 | 256 | 800 | 0.343 | -| SimSiam | Res18 | 256 | 800 | 0.872 | -| SwaV | Res18 | 256 | 800 | 0.902 | -| SwaVQueue | Res18 | 256 | 800 | 0.890 | -| SMoG | Res18 | 256 | 800 | 0.788 | -| TiCo | Res18 | 256 | 800 | 0.856 | -| VICReg | Res18 | 256 | 800 | 0.845 | -| VICRegL | Res18 | 256 | 800 | 0.778 | - - -### Cifar10 - -| Model | Backbone | Batch Size | Epochs | KNN Top1 | -|-------------|----------|------------|--------|----------| -| BarlowTwins | Res18 | 512 | 800 | 0.859 | -| BYOL | Res18 | 512 | 800 | 0.910 | -| DCL | Res18 | 512 | 800 | 0.874 | -| DCLW | Res18 | 512 | 800 | 0.871 | -| DINO | Res18 | 512 | 800 | 0.848 | -| FastSiam | Res18 | 512 | 800 | 0.902 | -| Moco | Res18 | 512 | 800 | 0.899 | -| NNCLR | Res18 | 512 | 800 | 0.892 | -| SimCLR | Res18 | 512 | 800 | 0.879 | -| SimSiam | Res18 | 512 | 800 | 0.904 | -| SwaV | Res18 | 512 | 800 | 0.884 | -| SMoG | Res18 | 512 | 800 | 0.800 | +better results for smaller batch sizes. 
See Appendix B.1 in the [SimCLR paper](https://arxiv.org/abs/2002.05709).* + +### ImageNet100 +[ImageNet100 benchmarks](https://docs.lightly.ai/self-supervised-learning/getting_started/benchmarks.html#imagenet100) + + +### Imagenette + +[Imagenette benchmarks](https://docs.lightly.ai/self-supervised-learning/getting_started/benchmarks.html#imagenette) + + +### CIFAR-10 + +[CIFAR-10 benchmarks](https://docs.lightly.ai/self-supervised-learning/getting_started/benchmarks.html#cifar-10) ## Terminology @@ -355,7 +316,7 @@ The terms in bold are explained in more detail in our [documentation](https://do ### Next Steps -Head to the [documentation](https://docs.lightly.ai) and see the things you can achieve with Lightly! +Head to the [documentation](https://docs.lightly.ai/self-supervised-learning/) and see the things you can achieve with Lightly! ## Development @@ -438,6 +399,15 @@ make format - [Decoupled Contrastive Learning, 2021](https://arxiv.org/abs/2110.06848) - [solo-learn: A Library of Self-supervised Methods for Visual Representation Learning, 2021](https://www.jmlr.org/papers/volume23/21-1155/21-1155.pdf) +## Company behind this Open Source Framework +[Lightly](https://www.lightly.ai) is a spin-off from ETH Zurich that helps companies +build efficient active learning pipelines to select the most relevant data for their models. 
class BarlowTwins(LightningModule):
    """Barlow Twins pre-training module used by the ImageNet ResNet-50 benchmark.

    The module combines a ResNet-50 backbone (classification layer replaced by
    an identity), a ``BarlowTwinsProjectionHead``, the ``BarlowTwinsLoss`` and
    an ``OnlineLinearClassifier`` that is trained on detached backbone features
    alongside the self-supervised objective.

    Args:
        batch_size_per_device: Batch size on a single device. Together with
            ``trainer.world_size`` it determines the learning rate scaling.
        num_classes: Number of classes for the online linear classifier.
    """

    def __init__(self, batch_size_per_device: int, num_classes: int) -> None:
        super().__init__()
        self.save_hyperparameters()
        self.batch_size_per_device = batch_size_per_device

        # ResNet-50 without its classification head: the final fully connected
        # layer is swapped for an identity so the backbone returns features.
        self.backbone = resnet50()
        self.backbone.fc = Identity()
        self.projection_head = BarlowTwinsProjectionHead()
        self.criterion = BarlowTwinsLoss(lambda_param=5e-3, gather_distributed=True)

        self.online_classifier = OnlineLinearClassifier(num_classes=num_classes)

    def forward(self, x: Tensor) -> Tensor:
        """Return the backbone output for ``x``."""
        return self.backbone(x)

    def training_step(
        self, batch: Tuple[List[Tensor], Tensor, List[str]], batch_idx: int
    ) -> Tensor:
        """Compute the Barlow Twins loss plus the online classifier loss.

        Args:
            batch: Tuple whose first entry is the list of augmented views and
                whose second entry holds the targets; further entries are
                not used.
            batch_idx: Index of the batch, forwarded to the online classifier.

        Returns:
            Sum of the self-supervised loss and the online classifier loss.
        """
        views = batch[0]
        targets = batch[1]
        num_views = len(views)
        num_targets = len(targets)

        # All views go through the backbone in a single concatenated pass and
        # are split back into one chunk per view after the projection head.
        stacked_views = torch.cat(views)
        features = self.forward(stacked_views).flatten(start_dim=1)
        projections = self.projection_head(features)
        z0, z1 = projections.chunk(num_views)
        loss = self.criterion(z0, z1)

        self.log(
            "train_loss", loss, prog_bar=True, sync_dist=True, batch_size=num_targets
        )

        # Online linear evaluation. Features are detached so the classifier
        # loss does not backpropagate into the backbone; targets are repeated
        # once per view to line up with the concatenated features.
        classifier_batch = (features.detach(), targets.repeat(num_views))
        cls_loss, cls_log = self.online_classifier.training_step(
            classifier_batch, batch_idx
        )
        self.log_dict(cls_log, sync_dist=True, batch_size=num_targets)
        return loss + cls_loss

    def validation_step(
        self, batch: Tuple[Tensor, Tensor, List[str]], batch_idx: int
    ) -> Tensor:
        """Run the online classifier on validation features.

        Args:
            batch: Tuple whose first entry holds the images and whose second
                entry holds the targets; further entries are not used.
            batch_idx: Index of the batch, forwarded to the online classifier.

        Returns:
            The validation loss reported by the online classifier.
        """
        images = batch[0]
        targets = batch[1]
        features = self.forward(images).flatten(start_dim=1)
        classifier_batch = (features.detach(), targets)
        cls_loss, cls_log = self.online_classifier.validation_step(
            classifier_batch, batch_idx
        )
        self.log_dict(cls_log, prog_bar=True, sync_dist=True, batch_size=len(targets))
        return cls_loss

    def configure_optimizers(self):
        """Create the LARS optimizer and a per-step cosine warmup schedule.

        Returns:
            A ``([optimizer], [scheduler_config])`` pair.
        """
        # Learning rates are scaled linearly with the global batch size,
        # relative to a reference batch size of 256.
        lr_factor = self.batch_size_per_device * self.trainer.world_size / 256

        # Don't use weight decay for batch norm, bias parameters, and
        # classification head to improve performance.
        params, params_no_weight_decay = get_weight_decay_parameters(
            [self.backbone, self.projection_head]
        )
        parameter_groups = [
            {"name": "barlowtwins", "params": params},
            {
                "name": "barlowtwins_no_weight_decay",
                "params": params_no_weight_decay,
                "weight_decay": 0.0,
                "lr": 0.0048 * lr_factor,
            },
            {
                "name": "online_classifier",
                "params": self.online_classifier.parameters(),
                "weight_decay": 0.0,
            },
        ]
        optimizer = LARS(
            parameter_groups,
            lr=0.2 * lr_factor,
            momentum=0.9,
            weight_decay=1.5e-6,
        )

        # The scheduler is stepped once per training step ("interval": "step"),
        # so both durations are expressed in steps: the warmup covers the
        # number of steps that corresponds to 10 epochs.
        warmup_steps = int(
            self.trainer.estimated_stepping_batches / self.trainer.max_epochs * 10
        )
        total_steps = int(self.trainer.estimated_stepping_batches)
        lr_schedule = CosineWarmupScheduler(
            optimizer=optimizer,
            warmup_epochs=warmup_steps,
            max_epochs=total_steps,
        )
        scheduler = {"scheduler": lr_schedule, "interval": "step"}
        return [optimizer], [scheduler]


# BarlowTwins uses same transform as BYOL.
transform = BYOLTransform()
same transform as SimCLR. -transform = SimCLRTransform() +# BYOL uses a slight modification of the SimCLR transforms. +# It uses asymmetric augmentation and solarize. +# Check table 6 in the BYOL paper for more info. +transform = BYOLTransform() diff --git a/benchmarks/imagenet/resnet50/dcl.py b/benchmarks/imagenet/resnet50/dcl.py index 42c66c7c5..f08fdb157 100644 --- a/benchmarks/imagenet/resnet50/dcl.py +++ b/benchmarks/imagenet/resnet50/dcl.py @@ -97,12 +97,12 @@ def configure_optimizers(self): scheduler = { "scheduler": CosineWarmupScheduler( optimizer=optimizer, - warmup_epochs=( + warmup_epochs=int( self.trainer.estimated_stepping_batches / self.trainer.max_epochs * 10 ), - max_epochs=self.trainer.estimated_stepping_batches, + max_epochs=int(self.trainer.estimated_stepping_batches), ), "interval": "step", } diff --git a/benchmarks/imagenet/resnet50/dclw.py b/benchmarks/imagenet/resnet50/dclw.py index bcae95d6e..6f0bb7e54 100644 --- a/benchmarks/imagenet/resnet50/dclw.py +++ b/benchmarks/imagenet/resnet50/dclw.py @@ -97,12 +97,12 @@ def configure_optimizers(self): scheduler = { "scheduler": CosineWarmupScheduler( optimizer=optimizer, - warmup_epochs=( + warmup_epochs=int( self.trainer.estimated_stepping_batches / self.trainer.max_epochs * 10 ), - max_epochs=self.trainer.estimated_stepping_batches, + max_epochs=int(self.trainer.estimated_stepping_batches), ), "interval": "step", } diff --git a/benchmarks/imagenet/resnet50/dino.py b/benchmarks/imagenet/resnet50/dino.py index d7924c9b5..a440ce147 100644 --- a/benchmarks/imagenet/resnet50/dino.py +++ b/benchmarks/imagenet/resnet50/dino.py @@ -136,12 +136,12 @@ def configure_optimizers(self): scheduler = { "scheduler": CosineWarmupScheduler( optimizer=optimizer, - warmup_epochs=( + warmup_epochs=int( self.trainer.estimated_stepping_batches / self.trainer.max_epochs * 10 ), - max_epochs=self.trainer.estimated_stepping_batches, + max_epochs=int(self.trainer.estimated_stepping_batches), ), "interval": "step", } diff 
--git a/benchmarks/imagenet/resnet50/finetune_eval.py b/benchmarks/imagenet/resnet50/finetune_eval.py index 466f73f1c..9db6f0f88 100644 --- a/benchmarks/imagenet/resnet50/finetune_eval.py +++ b/benchmarks/imagenet/resnet50/finetune_eval.py @@ -11,6 +11,7 @@ from lightly.data import LightlyDataset from lightly.transforms.utils import IMAGENET_NORMALIZE from lightly.utils.benchmarking import LinearClassifier, MetricCallback +from lightly.utils.dist import print_rank_zero from lightly.utils.scheduler import CosineWarmupScheduler @@ -63,7 +64,7 @@ def finetune_eval( References: - [0]: SimCLR, 2020, https://arxiv.org/abs/2002.05709 """ - print("Running fine-tune evaluation...") + print_rank_zero("Running fine-tune evaluation...") # Setup training data. train_transform = T.Compose( @@ -81,7 +82,7 @@ def finetune_eval( shuffle=True, num_workers=num_workers, drop_last=True, - persistent_workers=True, + persistent_workers=False, ) # Setup validation data. @@ -99,7 +100,7 @@ def finetune_eval( batch_size=batch_size_per_device, shuffle=False, num_workers=num_workers, - persistent_workers=True, + persistent_workers=False, ) # Train linear classifier. 
@@ -116,6 +117,7 @@ def finetune_eval( logger=TensorBoardLogger(save_dir=str(log_dir), name="finetune_eval"), precision=precision, strategy="ddp_find_unused_parameters_true", + num_sanity_val_steps=0, ) classifier = FinetuneLinearClassifier( model=model, @@ -130,4 +132,6 @@ def finetune_eval( val_dataloaders=val_dataloader, ) for metric in ["val_top1", "val_top5"]: - print(f"max finetune {metric}: {max(metric_callback.val_metrics[metric])}") + print_rank_zero( + f"max finetune {metric}: {max(metric_callback.val_metrics[metric])}" + ) diff --git a/benchmarks/imagenet/resnet50/knn_eval.py b/benchmarks/imagenet/resnet50/knn_eval.py index 51017c0c1..62d7501a8 100644 --- a/benchmarks/imagenet/resnet50/knn_eval.py +++ b/benchmarks/imagenet/resnet50/knn_eval.py @@ -10,6 +10,7 @@ from lightly.data import LightlyDataset from lightly.transforms.utils import IMAGENET_NORMALIZE from lightly.utils.benchmarking import KNNClassifier, MetricCallback +from lightly.utils.dist import print_rank_zero def knn_eval( @@ -34,7 +35,7 @@ def knn_eval( References: - [0]: InstDict, 2018, https://arxiv.org/abs/1805.01978 """ - print("Running KNN evaluation...") + print_rank_zero("Running KNN evaluation...") # Setup training data. 
transform = T.Compose( @@ -81,6 +82,7 @@ def knn_eval( metric_callback, ], strategy="ddp_find_unused_parameters_true", + num_sanity_val_steps=0, ) trainer.fit( model=classifier, @@ -88,4 +90,4 @@ def knn_eval( val_dataloaders=val_dataloader, ) for metric in ["val_top1", "val_top5"]: - print(f"knn {metric}: {max(metric_callback.val_metrics[metric])}") + print_rank_zero(f"knn {metric}: {max(metric_callback.val_metrics[metric])}") diff --git a/benchmarks/imagenet/resnet50/linear_eval.py b/benchmarks/imagenet/resnet50/linear_eval.py index 21a4aef48..08af6b8c0 100644 --- a/benchmarks/imagenet/resnet50/linear_eval.py +++ b/benchmarks/imagenet/resnet50/linear_eval.py @@ -10,6 +10,7 @@ from lightly.data import LightlyDataset from lightly.transforms.utils import IMAGENET_NORMALIZE from lightly.utils.benchmarking import LinearClassifier, MetricCallback +from lightly.utils.dist import print_rank_zero def linear_eval( @@ -40,7 +41,7 @@ def linear_eval( References: - [0]: SimCLR, 2020, https://arxiv.org/abs/2002.05709 """ - print("Running linear evaluation...") + print_rank_zero("Running linear evaluation...") # Setup training data. train_transform = T.Compose( @@ -58,7 +59,7 @@ def linear_eval( shuffle=True, num_workers=num_workers, drop_last=True, - persistent_workers=True, + persistent_workers=False, ) # Setup validation data. @@ -76,7 +77,7 @@ def linear_eval( batch_size=batch_size_per_device, shuffle=False, num_workers=num_workers, - persistent_workers=True, + persistent_workers=False, ) # Train linear classifier. 
@@ -93,6 +94,7 @@ def linear_eval( logger=TensorBoardLogger(save_dir=str(log_dir), name="linear_eval"), precision=precision, strategy="ddp_find_unused_parameters_true", + num_sanity_val_steps=0, ) classifier = LinearClassifier( model=model, @@ -107,4 +109,6 @@ def linear_eval( val_dataloaders=val_dataloader, ) for metric in ["val_top1", "val_top5"]: - print(f"max linear {metric}: {max(metric_callback.val_metrics[metric])}") + print_rank_zero( + f"max linear {metric}: {max(metric_callback.val_metrics[metric])}" + ) diff --git a/benchmarks/imagenet/resnet50/main.py b/benchmarks/imagenet/resnet50/main.py index ce1e9e40b..505d4412c 100644 --- a/benchmarks/imagenet/resnet50/main.py +++ b/benchmarks/imagenet/resnet50/main.py @@ -3,6 +3,7 @@ from pathlib import Path from typing import Sequence, Union +import barlowtwins import byol import dcl import dclw @@ -14,6 +15,7 @@ import simclr import swav import torch +import vicreg from pytorch_lightning import LightningModule, Trainer from pytorch_lightning.callbacks import ( DeviceStatsMonitor, @@ -27,6 +29,7 @@ from lightly.data import LightlyDataset from lightly.transforms.utils import IMAGENET_NORMALIZE from lightly.utils.benchmarking import MetricCallback +from lightly.utils.dist import print_rank_zero parser = ArgumentParser("ImageNet ResNet50 Benchmarks") parser.add_argument("--train-dir", type=Path, default="/datasets/imagenet/train") @@ -38,6 +41,7 @@ parser.add_argument("--accelerator", type=str, default="gpu") parser.add_argument("--devices", type=int, default=1) parser.add_argument("--precision", type=str, default="16-mixed") +parser.add_argument("--ckpt-path", type=Path, default=None) parser.add_argument("--compile-model", action="store_true") parser.add_argument("--methods", type=str, nargs="+") parser.add_argument("--num-classes", type=int, default=1000) @@ -46,6 +50,10 @@ parser.add_argument("--skip-finetune-eval", action="store_true") METHODS = { + "barlowtwins": { + "model": barlowtwins.BarlowTwins, + 
"transform": barlowtwins.transform, + }, "byol": {"model": byol.BYOL, "transform": byol.transform}, "dcl": {"model": dcl.DCL, "transform": dcl.transform}, "dclw": {"model": dclw.DCLW, "transform": dclw.transform}, @@ -53,6 +61,7 @@ "mocov2": {"model": mocov2.MoCoV2, "transform": mocov2.transform}, "simclr": {"model": simclr.SimCLR, "transform": simclr.transform}, "swav": {"model": swav.SwAV, "transform": swav.transform}, + "vicreg": {"model": vicreg.VICReg, "transform": vicreg.transform}, } @@ -72,6 +81,7 @@ def main( skip_knn_eval: bool, skip_linear_eval: bool, skip_finetune_eval: bool, + ckpt_path: Union[Path, None], ) -> None: torch.set_float32_matmul_precision("high") @@ -87,11 +97,13 @@ def main( if compile_model and hasattr(torch, "compile"): # Compile model if PyTorch supports it. - print("Compiling model...") + print_rank_zero("Compiling model...") model = torch.compile(model) if epochs <= 0: - print("Epochs <= 0, skipping pretraining.") + print_rank_zero("Epochs <= 0, skipping pretraining.") + if ckpt_path is not None: + model.load_state_dict(torch.load(ckpt_path)["state_dict"]) else: pretrain( model=model, @@ -105,10 +117,11 @@ def main( accelerator=accelerator, devices=devices, precision=precision, + ckpt_path=ckpt_path, ) if skip_knn_eval: - print("Skipping KNN eval.") + print_rank_zero("Skipping KNN eval.") else: knn_eval.knn_eval( model=model, @@ -123,7 +136,7 @@ def main( ) if skip_linear_eval: - print("Skipping linear eval.") + print_rank_zero("Skipping linear eval.") else: linear_eval.linear_eval( model=model, @@ -139,7 +152,7 @@ def main( ) if skip_finetune_eval: - print("Skipping fine-tune eval.") + print_rank_zero("Skipping fine-tune eval.") else: finetune_eval.finetune_eval( model=model, @@ -167,8 +180,9 @@ def pretrain( accelerator: str, devices: int, precision: str, + ckpt_path: Union[Path, None], ) -> None: - print(f"Running pretraining for {method}...") + print_rank_zero(f"Running pretraining for {method}...") # Setup training data. 
train_transform = METHODS[method]["transform"] @@ -179,7 +193,7 @@ def pretrain( shuffle=True, num_workers=num_workers, drop_last=True, - persistent_workers=True, + persistent_workers=False, ) # Setup validation data. @@ -197,7 +211,7 @@ def pretrain( batch_size=batch_size_per_device, shuffle=False, num_workers=num_workers, - persistent_workers=True, + persistent_workers=False, ) # Train model. @@ -217,14 +231,17 @@ def pretrain( precision=precision, strategy="ddp_find_unused_parameters_true", sync_batchnorm=True, + num_sanity_val_steps=0, ) + trainer.fit( model=model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader, + ckpt_path=ckpt_path, ) for metric in ["val_online_cls_top1", "val_online_cls_top5"]: - print(f"max {metric}: {max(metric_callback.val_metrics[metric])}") + print_rank_zero(f"max {metric}: {max(metric_callback.val_metrics[metric])}") if __name__ == "__main__": diff --git a/benchmarks/imagenet/resnet50/simclr.py b/benchmarks/imagenet/resnet50/simclr.py index de6958638..6562025ae 100644 --- a/benchmarks/imagenet/resnet50/simclr.py +++ b/benchmarks/imagenet/resnet50/simclr.py @@ -96,12 +96,12 @@ def configure_optimizers(self): scheduler = { "scheduler": CosineWarmupScheduler( optimizer=optimizer, - warmup_epochs=( + warmup_epochs=int( self.trainer.estimated_stepping_batches / self.trainer.max_epochs * 10 ), - max_epochs=self.trainer.estimated_stepping_batches, + max_epochs=int(self.trainer.estimated_stepping_batches), ), "interval": "step", } diff --git a/benchmarks/imagenet/resnet50/swav.py b/benchmarks/imagenet/resnet50/swav.py index e3117e4f8..96eb7a1fb 100644 --- a/benchmarks/imagenet/resnet50/swav.py +++ b/benchmarks/imagenet/resnet50/swav.py @@ -156,12 +156,12 @@ def configure_optimizers(self): scheduler = { "scheduler": CosineWarmupScheduler( optimizer=optimizer, - warmup_epochs=( + warmup_epochs=int( self.trainer.estimated_stepping_batches / self.trainer.max_epochs * 10 ), - max_epochs=self.trainer.estimated_stepping_batches, + 
max_epochs=int(self.trainer.estimated_stepping_batches), end_value=0.0006 * (self.batch_size_per_device * self.trainer.world_size) / 256, diff --git a/benchmarks/imagenet/resnet50/vicreg.py b/benchmarks/imagenet/resnet50/vicreg.py new file mode 100644 index 000000000..0da61677e --- /dev/null +++ b/benchmarks/imagenet/resnet50/vicreg.py @@ -0,0 +1,127 @@ +from typing import List, Tuple + +import torch +from pytorch_lightning import LightningModule +from torch import Tensor +from torch.nn import Identity +from torchvision.models import resnet50 + +from lightly.loss.vicreg_loss import VICRegLoss +from lightly.models.modules.heads import VICRegProjectionHead +from lightly.models.utils import get_weight_decay_parameters +from lightly.transforms.vicreg_transform import VICRegTransform +from lightly.utils.benchmarking import OnlineLinearClassifier +from lightly.utils.lars import LARS +from lightly.utils.scheduler import CosineWarmupScheduler + + +class VICReg(LightningModule): + def __init__(self, batch_size_per_device: int, num_classes: int) -> None: + super().__init__() + self.save_hyperparameters() + self.batch_size_per_device = batch_size_per_device + + resnet = resnet50() + resnet.fc = Identity() # Ignore classification head + self.backbone = resnet + self.projection_head = VICRegProjectionHead(num_layers=2) + self.criterion = VICRegLoss() + + self.online_classifier = OnlineLinearClassifier(num_classes=num_classes) + + def forward(self, x: Tensor) -> Tensor: + return self.backbone(x) + + def training_step( + self, batch: Tuple[List[Tensor], Tensor, List[str]], batch_idx: int + ) -> Tensor: + views, targets = batch[0], batch[1] + features = self.forward(torch.cat(views)).flatten(start_dim=1) + z = self.projection_head(features) + z_a, z_b = z.chunk(len(views)) + loss = self.criterion(z_a=z_a, z_b=z_b) + self.log( + "train_loss", loss, prog_bar=True, sync_dist=True, batch_size=len(targets) + ) + + # Online linear evaluation. 
+ cls_loss, cls_log = self.online_classifier.training_step( + (features.detach(), targets.repeat(len(views))), batch_idx + ) + + self.log_dict(cls_log, sync_dist=True, batch_size=len(targets)) + return loss + cls_loss + + def validation_step( + self, batch: Tuple[Tensor, Tensor, List[str]], batch_idx: int + ) -> Tensor: + images, targets = batch[0], batch[1] + features = self.forward(images).flatten(start_dim=1) + cls_loss, cls_log = self.online_classifier.validation_step( + (features.detach(), targets), batch_idx + ) + self.log_dict(cls_log, prog_bar=True, sync_dist=True, batch_size=len(targets)) + return cls_loss + + def configure_optimizers(self): + # Don't use weight decay for batch norm, bias parameters, and classification + # head to improve performance. + params, params_no_weight_decay = get_weight_decay_parameters( + [self.backbone, self.projection_head] + ) + global_batch_size = self.batch_size_per_device * self.trainer.world_size + base_lr = _get_base_learning_rate(global_batch_size=global_batch_size) + optimizer = LARS( + [ + {"name": "vicreg", "params": params}, + { + "name": "vicreg_no_weight_decay", + "params": params_no_weight_decay, + "weight_decay": 0.0, + }, + { + "name": "online_classifier", + "params": self.online_classifier.parameters(), + "weight_decay": 0.0, + }, + ], + # Linear learning rate scaling with a base learning rate of 0.2. + # See https://arxiv.org/pdf/2105.04906.pdf for details. + lr=base_lr * global_batch_size / 256, + momentum=0.9, + weight_decay=1e-6, + ) + scheduler = { + "scheduler": CosineWarmupScheduler( + optimizer=optimizer, + warmup_epochs=int( + self.trainer.estimated_stepping_batches + / self.trainer.max_epochs + * 10 + ), + max_epochs=int(self.trainer.estimated_stepping_batches), + end_value=0.01, # Scale base learning rate from 0.2 to 0.002.
+ ), + "interval": "step", + } + return [optimizer], [scheduler] + + +# VICReg transform +transform = VICRegTransform() + + +def _get_base_learning_rate(global_batch_size: int) -> float: + """Returns the base learning rate for training 100 epochs with a given batch size. + + This follows section C.4 in https://arxiv.org/pdf/2105.04906.pdf. + + """ + if global_batch_size == 128: + return 0.8 + elif global_batch_size == 256: + return 0.5 + elif global_batch_size == 512: + return 0.4 + else: + return 0.3 diff --git a/docs/Makefile b/docs/Makefile index 95f708852..5900e0fe6 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -60,20 +60,6 @@ download-noplot: unzip $(ZIPOPTS) $(DATADIR)/resources.zip -d $(DOCKERSOURCE);\ unzip $(ZIPOPTS) $(DATADIR)/resources.zip -d $(DOCKER_ARCHIVE_SOURCE); \ - # pizza dataset - @if [ ! -d $(PLATFORMSOURCE)/pizzas/salami ]; then \ - wget -N https://storage.googleapis.com/datasets_boris/pizzas.zip -P $(DATADIR);\ - unzip $(ZIPOPTS) $(DATADIR)/pizzas.zip -d $(PLATFORMSOURCE);\ - mkdir -p $(PLATFORMSOURCE)/pizzas/margherita;\ - mkdir -p $(PLATFORMSOURCE)/pizzas/salami;\ - mv $(PLATFORMSOURCE)/pizzas/margherita*.jpg $(PLATFORMSOURCE)/pizzas/margherita;\ - mv $(PLATFORMSOURCE)/pizzas/salami*.jpg $(PLATFORMSOURCE)/pizzas/salami;\ - fi - - # sunflowers dataset - @if [ ! 
-f $(DATADIR)/Sunflowers.zip ]; then \ - wget -N https://storage.googleapis.com/datasets_boris/Sunflowers.zip -P $(DATADIR);\ - fi # Download also the datasets needed for the tutorials download: download-noplot diff --git a/docs/logos/lightly_SSL_logo_crop.png b/docs/logos/lightly_SSL_logo_crop.png new file mode 100644 index 000000000..62028eaf2 Binary files /dev/null and b/docs/logos/lightly_SSL_logo_crop.png differ diff --git a/docs/logos/lightly_SSL_logo_crop_white_text.png b/docs/logos/lightly_SSL_logo_crop_white_text.png new file mode 100644 index 000000000..90733d7d2 Binary files /dev/null and b/docs/logos/lightly_SSL_logo_crop_white_text.png differ diff --git a/docs/source/_templates/footer.html b/docs/source/_templates/footer.html index 86060f5ef..217e11623 100644 --- a/docs/source/_templates/footer.html +++ b/docs/source/_templates/footer.html @@ -26,7 +26,13 @@ {%- else %} {% set copyright = copyright|e %} - © {% trans %}Copyright{% endtrans %} {{ copyright_year }}, {{ copyright }} + © {% trans %}Copyright{% endtrans %} {{ copyright_year }} +  | {{ copyright }} +  |  + Lightly SSL source code + Source Code + +  | Lightly Worker Solution documentation {%- endif %} {%- endif %} diff --git a/docs/source/_templates/layout.html b/docs/source/_templates/layout.html index bcf32e9f5..e39f8042f 100644 --- a/docs/source/_templates/layout.html +++ b/docs/source/_templates/layout.html @@ -8,6 +8,36 @@ We need this to override the footer --> {%- block content %} + +
+ Looking to easily do active learning on millions of samples? See our Lightly Worker docs. +
{% if theme_style_external_links|tobool %} -Check the installation of lightly ------------------------------------ -To see if the lightly command-line tool was installed correctly, you can run the -following command which will print the installed lightly version: +Check the installation of Lightly SSL +------------------------------------- +To see if the Lightly SSL command-line tool was installed correctly, you can run the +following command which will print the version of the installed Lightly SSL package: .. code-block:: bash lightly-version -If lightly was installed correctly, you should see something like this: +If Lightly SSL was installed correctly, you should see something like this: .. code-block:: bash diff --git a/docs/source/getting_started/distributed_training.rst b/docs/source/getting_started/distributed_training.rst index e1b140c28..30daefd7b 100644 --- a/docs/source/getting_started/distributed_training.rst +++ b/docs/source/getting_started/distributed_training.rst @@ -3,7 +3,7 @@ Distributed Training ==================== -Lightly supports training your model on multiple GPUs using Pytorch Lightning +Lightly SSL supports training your model on multiple GPUs using Pytorch Lightning and Distributed Data Parallel (DDP) training. You can find reference implementations for all our models in the :ref:`models` section. @@ -12,7 +12,7 @@ Training with multiple gpus is also available from the command line: :ref:`cli-t For details on distributed training we recommend the following pages: - `Pytorch Distributed Overview `_ -- `Pytorch Lightning Multi-GPU Training `_ +- `Pytorch Lightning Multi-GPU Training `_ There are different levels of synchronization for distributed training. 
One can diff --git a/docs/source/getting_started/install.rst b/docs/source/getting_started/install.rst index e05f61fc8..1cb8ff670 100644 --- a/docs/source/getting_started/install.rst +++ b/docs/source/getting_started/install.rst @@ -4,24 +4,24 @@ Installation Supported Python versions ------------------------- -Lightly requires Python 3.6+. We recommend installing Lighlty in a Linux or OSX environment. +Lightly SSL requires Python 3.6+. We recommend installing Lightly SSL in a Linux or OSX environment. .. _rst-installing: -Installing Lightly ------------------- +Installing Lightly SSL +---------------------- -You can install Lightly and its dependencies from PyPi with: +You can install Lightly SSL and its dependencies from PyPi with: .. code-block:: bash pip install lightly -We strongly recommend that you install Lightly in a dedicated virtualenv, to avoid conflicting with your system packages. +We strongly recommend that you install Lightly SSL in a dedicated virtualenv, to avoid conflicting with your system packages. Dependencies ------------ -Lightly currently uses `PyTorch `_ as the underlying deep learning framework. +Lightly SSL currently uses `PyTorch `_ as the underlying deep learning framework. On top of PyTorch we use `Hydra `_ for managing configurations and `PyTorch Lightning `_ for training models. 
@@ -35,4 +35,4 @@ If you want to work with video files you need to additionally install Next Steps ------------ -Check out our tutorial: :ref:`lightly-tutorials` +Start with one of our tutorials: :ref:`input-structure-label` diff --git a/docs/source/getting_started/lightly_at_a_glance.rst b/docs/source/getting_started/lightly_at_a_glance.rst index e4b4af2f6..78361da69 100644 --- a/docs/source/getting_started/lightly_at_a_glance.rst +++ b/docs/source/getting_started/lightly_at_a_glance.rst @@ -3,14 +3,14 @@ Self-supervised learning ======================== -Lightly is a computer vision framework for training deep learning models using self-supervised learning. +Lightly SSL is a computer vision framework for training deep learning models using self-supervised learning. The framework can be used for a wide range of useful applications such as finding the nearest neighbors, similarity search, transfer learning, or data analytics. -How Lightly Works ------------------ -The flexible design of Lightly makes it easy to integrate in your Python code. Lightly is built +How Lightly SSL Works +--------------------- +The flexible design of Lightly SSL makes it easy to integrate in your Python code. Lightly SSL is built completely around PyTorch and the different pieces can be put together to fit *your* requirements. Data and Transformations @@ -219,12 +219,12 @@ Furthermore, the ResNet backbone can be used for transfer and few-shot learning. Self-supervised learning does not require labels for a model to be trained on. Lightly, however, supports the use of additional labels. For example, if you train a model on a folder 'cats' with subfolders 'Maine Coon', 'Bengal' and 'British Shorthair' - Lightly automatically returns the enumerated labels as a list. + Lightly SSL automatically returns the enumerated labels as a list. 
-Lightly in Three Lines ----------------------------------------- +Lightly SSL in Three Lines +-------------------------- -Lightly also offers an easy-to-use interface. The following lines show how the package can +Lightly SSL also offers an easy-to-use interface. The following lines show how the package can be used to train a model with self-supervision and create embeddings with only three lines of code. diff --git a/docs/source/getting_started/main_concepts.rst b/docs/source/getting_started/main_concepts.rst index a891539b3..8c2c561f0 100644 --- a/docs/source/getting_started/main_concepts.rst +++ b/docs/source/getting_started/main_concepts.rst @@ -6,18 +6,18 @@ Main Concepts Self-Supervised Learning ------------------------ -The figure below shows an overview of the different concepts used by the Lightly package +The figure below shows an overview of the different concepts used by the Lightly SSL package and a schema of how they interact. The expressions in **bold** are explained further below. .. figure:: images/lightly_overview.png - :align: center - :alt: Lightly Overview + :align: center + :alt: Lightly SSL Overview - Overview of the different concepts used by the Lightly package and how they interact. + Overview of the different concepts used by the Lightly SSL package and how they interact. * **Dataset** - In Lightly, datasets are accessed through :py:class:`~lightly.data.dataset.LightlyDataset`. + In Lightly SSL, datasets are accessed through :py:class:`~lightly.data.dataset.LightlyDataset`. You can create a :py:class:`~lightly.data.dataset.LightlyDataset` from a directory of images or videos, or directly from a `torchvision dataset `_. You can learn more about this in our tutorial: @@ -36,7 +36,7 @@ below. * **Collate Function** The collate function aggregates the views of multiple images into a single batch. - You can use the default collate function. Lightly also provides a + You can use the default collate function. 
Lightly SSL also provides a :py:class:`~lightly.data.multi_view_collate.MultiViewCollate` * **Dataloader** @@ -56,7 +56,7 @@ below. They project the outputs of the backbone, commonly called *embeddings*, *representations*, or *features*, into a new space in which the loss is calculated. This has been found to be hugely beneficial instead of directly calculating the loss - on the embeddings. Lightly provides common :py:mod:`~lightly.models.modules.heads` + on the embeddings. Lightly SSL provides common :py:mod:`~lightly.models.modules.heads` that can be added to any backbone. * **Model** @@ -71,26 +71,25 @@ below. * :ref:`sphx_glr_tutorials_package_tutorial_simsiam_esa.py` * **Loss** - The loss function plays a crucial role in self-supervised learning. Lightly provides + The loss function plays a crucial role in self-supervised learning. Lightly SSL provides common loss functions in the :py:mod:`~lightly.loss` module. * **Optimizer** - With Lightly, you can use any `PyTorch optimizer `_ + With Lightly SSL, you can use any `PyTorch optimizer `_ to train your model. * **Training** The model can either be trained using a plain `PyTorch training loop `_ or with a dedicated framework such as `PyTorch Lightning `_. - Lightly lets you choose what is best for you. Check out our :ref:`models ` and - `tutorials `_ - sections on how to train models with PyTorch or PyTorch Lightning. + Lightly SSL lets you choose what is best for you. Check out our :ref:`models ` and + :ref:`tutorials ` sections on how to train models with PyTorch + or PyTorch Lightning. * **Image Embeddings** During the training process, the model learns to create compact embeddings from images. 
The embeddings, also often called representations or features, can then be used for tasks such as identifying similar images or creating a diverse subset from your data: - * :ref:`lightly-tutorial-sunflowers` * :ref:`lightly-simsiam-tutorial-4` * **Pre-Trained Backbone** diff --git a/docs/source/index.rst b/docs/source/index.rst index c4e02bf77..81949504a 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -4,36 +4,37 @@ contain the root `toctree` directive. -.. image:: ../logos/lightly_logo_crop.png - :width: 600 - :alt: Lightly +.. image:: ../logos/lightly_SSL_logo_crop.png + :width: 600 + :align: center + :alt: Lightly SSL Self-Supervised Learning Documentation =================================== .. note:: These pages document the Lightly self-supervised learning library. - If you are looking for Lightly Worker Solution to easily process millions - of samples and run powerful active learning algorithms on your data - please follow - `Lightly Worker documentation `_. + If you are looking for the Lightly Worker Solution with + advanced `active learning algorithms `_ and + `selection strategies `_ to select the best samples + within millions of unlabeled images or video frames stored in your cloud storage or locally, + please follow our `Lightly Worker documentation `_. -Lightly is a computer vision framework for self-supervised learning. +Lightly SSL is a computer vision framework for self-supervised learning. -With Lightly you can train deep learning models using self-supervision. +With Lightly SSL you can train deep learning models using self-supervision. This means, that you don’t require any labels to train a model. -Lightly has been built to help you understand and work with large unlabeled +Lightly SSL has been built to help you understand and work with large unlabeled datasets. It is built on top of PyTorch and therefore fully compatible with other frameworks such as Fast.ai. 
-Lightly -------- +Lightly AI +---------- - `Homepage `_ -- `Web-App `_ -- `Documentation `_ -- `Lightly Solution Documentation (Lightly Worker & API) `_ +- `Lightly Worker Solution Documentation `_ +- `Lightly Platform `_ - `Github `_ - `Discord `_ (We have weekly paper sessions!) @@ -58,8 +59,12 @@ Lightly :maxdepth: 1 :caption: Tutorials - tutorials/package.rst - tutorials/platform.rst + tutorials/structure_your_input.rst + tutorials/package/tutorial_moco_memory_bank.rst + tutorials/package/tutorial_simclr_clothing.rst + tutorials/package/tutorial_simsiam_esa.rst + tutorials/package/tutorial_custom_augmentations.rst + tutorials/package/tutorial_pretrain_detectron2.rst .. toctree:: :maxdepth: 1 diff --git a/docs/source/lightly.cli.rst b/docs/source/lightly.cli.rst index 62996119e..44861d86d 100644 --- a/docs/source/lightly.cli.rst +++ b/docs/source/lightly.cli.rst @@ -35,7 +35,7 @@ lightly.cli .config.config.yaml ------------------- -The default settings for all command line tools in the lightly Python package are stored in a YAML config file. +The default settings for all command line tools in the Lightly SSL Python package are stored in a YAML config file. The config file is distributed along with the Python package and can be adapted to fit custom requirements. The arguments are grouped into namespaces. For example, everything related to the embedding model is grouped under diff --git a/docs/source/tutorials/package.rst b/docs/source/tutorials/package.rst deleted file mode 100644 index 1b0e2904f..000000000 --- a/docs/source/tutorials/package.rst +++ /dev/null @@ -1,22 +0,0 @@ -.. _lightly-tutorials: - -Python Package -=================================== - -With the lightly framework you can use the power of self-supervised learning -for computervision with ease. Here we show you tutorials to help you work with -the Python library. 
- -Since lightly is built on top of `PyTorch `_ -and `PyTorch Lightning `_ -you might want to have a look at the two frameworks to understand basic concepts. - -.. toctree:: - :maxdepth: 1 - - structure_your_input.rst - package/tutorial_moco_memory_bank.rst - package/tutorial_simclr_clothing.rst - package/tutorial_simsiam_esa.rst - package/tutorial_custom_augmentations.rst - package/tutorial_pretrain_detectron2.rst diff --git a/docs/source/tutorials/platform.rst b/docs/source/tutorials/platform.rst deleted file mode 100644 index b7b499457..000000000 --- a/docs/source/tutorials/platform.rst +++ /dev/null @@ -1,30 +0,0 @@ -.. _platform-tutorials-label: - -Platform -=================================== - -.. warning:: - **Tutorials are outdated** - - These tutorials use a deprecated workflow of the Lightly Solution and will be removed in the future. - Please refer to the `new documentation and tutorials `_ instead. - -Lightly is more than just a framework for self-supervised learning. We built a complete data curation platform on top. -Use the embeddings generated using the lightly framework and use them to curate your dataset. Collaborate with your friends -and share the curated data with your favorite data labeling partner. - -In this tutorial series, you will learn how to get the most out of the platform. - -.. 
toctree:: - :maxdepth: 1 - - platform/tutorial_pizza_filter.rst - platform/tutorial_sunflowers.rst - platform/tutorial_active_learning.rst - platform/tutorial_active_learning_detectron2.rst - platform/tutorial_aquarium_custom_metadata.rst - platform/tutorial_cropped_objects_metadata.rst - Tutorial 7: Active Learning with Nvidia TLT - Tutorial 8: Integration with LabelStudio for Active Learning - Tutorial 9: Embedded COVID mask detection - platform/tutorial_label_studio_export.rst \ No newline at end of file diff --git a/docs/source/tutorials/structure_your_input.rst b/docs/source/tutorials/structure_your_input.rst index 7994f7b11..976bba9bf 100644 --- a/docs/source/tutorials/structure_your_input.rst +++ b/docs/source/tutorials/structure_your_input.rst @@ -3,23 +3,32 @@ Tutorial 1: Structure Your Input ================================ -The `lightly Python package `_ can process image datasets to generate embeddings -or to upload data to the `Lightly platform `_. In this tutorial you will learn how to structure -your image dataset such that it is understood by our framework. - -You can also skip this tutorial and jump right into training a model: +If you are familiar with torch-like image datasets, you can skip this tutorial and +jump right into training a model: - :ref:`lightly-moco-tutorial-2` - :ref:`lightly-simclr-tutorial-3` +- :ref:`lightly-simsiam-tutorial-4` +- :ref:`lightly-custom-augmentation-5` +- :ref:`lightly-detectron-tutorial-6` + +If you are looking for a use case that's not covered by the above tutorials, please +let us know by `creating an issue `_ +for it. + Supported File Types -------------------- +By default, the `Lightly SSL Python package `_ +can process images or videos for self-supervised learning or for generating embeddings. +You can always write your own torch-like dataset to use other file types.
+ Images ^^^^^^^^^^^^^^^^^^^^^ -Since lightly uses `Pillow `_ -for image loading it also supports all the image formats supported by +Since Lightly SSL uses `Pillow `_ +for image loading, it also supports all the image formats supported by `Pillow `_. - .jpg, .png, .tiff and @@ -28,14 +37,13 @@ for image loading it also supports all the image formats supported by Videos ^^^^^^^^^^^^^^^^^^^^^ -To load videos directly lightly uses +To load videos directly, Lightly SSL uses `torchvision `_ and `PyAV `_. The following formats are supported. - .mov, .mp4 and .avi - Image Folder Datasets --------------------- @@ -46,7 +54,7 @@ Flat Directory Containing Images ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ You can store all images of interest in a single folder without additional hierarchy. For example below, -lightly will load all filenames and images in the directory `data/`. Additionally, it will assign all images +Lightly SSL will load all filenames and images in the directory `data/`. Additionally, it will assign all images a placeholder label. .. code-block:: bash @@ -58,7 +66,7 @@ a placeholder label. ... +--- img-N.jpg -For the structure above, lightly will understand the input as follows: +For the structure above, Lightly SSL will understand the input as follows: .. code-block:: python @@ -80,7 +88,7 @@ Directory with Subdirectories Containing Images ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ You can give structure to your input directory by collecting the input images in subdirectories. In this case, -the filenames loaded by lightly are with respect to the "root directory" `data/`. Furthermore, lightly assigns +the filenames loaded by Lightly SSL are with respect to the "root directory" `data/`. Furthermore, Lightly SSL assigns each image a so-called "weak-label" indicating to which subdirectory it belongs. .. code-block:: bash @@ -106,7 +114,7 @@ each image a so-called "weak-label" indicating to which subdirectory it belongs. ... 
+-- img-N10.jpg -For the structure above, lightly will understand the input as follows: +For the structure above, Lightly SSL will understand the input as follows: .. code-block:: python @@ -136,9 +144,9 @@ For the structure above, lightly will understand the input as follows: Video Folder Datasets --------------------- -The lightly Python package allows you to work `directly` on video data, without having +The Lightly SSL Python package allows you to work `directly` on video data, without having to exctract the frames first. This can save a lot of disk space as video files are -typically strongly compressed. Using lightly on video data is as simple as pointing +typically strongly compressed. Using Lightly SSL on video data is as simple as pointing the software at an input directory where one or more videos are stored. The package will automatically detect all video files and index them so that each frame can be accessed. @@ -154,175 +162,10 @@ An example for an input directory with videos could look like this: +-- my_video_4.avi We assign a weak label to each video. -To upload the three videos from above to the platform, you can use -.. code-block:: bash - - lightly-upload token='123' new_dataset_name='my_video_dataset' input_dir='data/' - -All other operations (like training a self-supervised model and embedding the frames individually) -also work on video data. Give it a try! .. note:: Randomly accessing video frames is slower compared to accessing the extracted frames on disk. However, by working directly on video files, one can save a lot of disk space because the frames do not have to be extracted beforehand. - - -Embedding Files ---------------- - -Embeddings generated by the lightly Python package are typically stored in a `.csv` file and can then be uploaded to the -Lightly platform from the command line. If the embeddings were generated with the lightly command-line tool, they have -the correct format already. 
- -You can also save your own embeddings in a `.csv` file to upload them. In that case, make sure the file meets the format -requirements: Use the `save_embeddings` function from `lightly.utils.io` to convert your embeddings, weak-labels, and -filenames to the right shape. - -.. code-block:: python - - import lightly.utils.io as io - - # embeddings: - # embeddings are stored as an n_samples x dim numpy array - embeddings = np.array([[0.1, 0.5], - [0.2, 0.2], - [0.1, 0.9], - [0.3, 0.2]]) - - # weak-labels - # a list of integers carrying meta-information about the images - labels = [0, 1, 1, 0] - - # filenames - # list of strings containing the filenames of the images w.r.t the input directory - filenames = [ - 'weak-label-0/img-1.jpg', - 'weak-label-1/img-1.jpg', - 'weak-label-1/img-2.jpg', - 'weak-label-0/img-2.jpg', - ] - - io.save_embeddings('my_embeddings_file.csv', embeddings, labels, filenames) - -The code shown above will produce the following `.csv` file: - -.. list-table:: my_embeddings_file.csv - :widths: 50 50 50 50 - :header-rows: 1 - - * - filenames - - embedding_0 - - embedding_1 - - labels - * - weak-label-0/img-1.jpg - - 0.1 - - 0.5 - - 0 - * - weak-label-1/img-1.jpg - - 0.2 - - 0.2 - - 1 - * - weak-label-1/img-2.jpg - - 0.1 - - 0.9 - - 1 - * - weak-label-0/img-2.jpg - - 0.3 - - 0.2 - - 0 - -.. note:: Note that lightly automatically creates "weak" labels for datasets - with subfolders. Each subfolder corresponds to one weak label. - The labels are called "weak" since they might not be used for a task - you want to solve with ML directly but still can be relevant to group - the data into buckets. - - -Advanced usage of Embeddings -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -In some cases you want to enrich the embeddings with additional information. -The lightly csv scheme is very simple and can be easily extended. -For example, you can add your own embeddings to the existing embeddings. 
This -could be useful if you have additional meta information about each sample. - -.. _lightly-custom-labels: - -Add Custom Embeddings -"""""""""""""""""""""""""""""" - -To add custom embeddings you need to add more embedding columns to the .csv file. -Make sure you keep the enumeration of the embeddings in correct order. - - -Here you see an embedding from lightly with a 2-dimensional embedding vector. - -.. list-table:: lightly_embeddings.csv - :widths: 50 50 50 50 - :header-rows: 1 - - * - filenames - - embedding_0 - - embedding_1 - - labels - * - img-1.jpg - - 0.1 - - 0.5 - - 0 - * - img-2.jpg - - 0.2 - - 0.2 - - 0 - * - img-3.jpg - - 0.1 - - 0.9 - - 1 - -We can now append our embedding vector to the .csv file. - -.. list-table:: lightly_with_custom_embeddings.csv - :widths: 50 50 50 50 50 50 - :header-rows: 1 - - * - filenames - - embedding_0 - - embedding_1 - - embedding_2 - - embedding_3 - - labels - * - img-1.jpg - - 0.1 - - 0.5 - - 0.2 - - 0.7 - - 0 - * - img-2.jpg - - 0.2 - - -0.2 - - 1.1 - - -0.4 - - 0 - * - img-3.jpg - - 0.1 - - 0.9 - - -0.2 - - 0.5 - - 1 - -.. note:: The embedding columns must be grouped together. You can not have - another column between two embedding columns. 
- - -Next Steps ------------------ - -Now that you understand the various data formats lightly supports you can -start training a model: - -- :ref:`lightly-moco-tutorial-2` -- :ref:`lightly-simclr-tutorial-3` -- :ref:`lightly-simsiam-tutorial-4` -- :ref:`lightly-custom-augmentation-5` \ No newline at end of file diff --git a/docs/source/tutorials_source/platform/images/sunflowers_scatter_after_selection.jpg b/docs/source/tutorials_source/platform/images/sunflowers_scatter_after_selection.jpg deleted file mode 100644 index 7e976c820..000000000 Binary files a/docs/source/tutorials_source/platform/images/sunflowers_scatter_after_selection.jpg and /dev/null differ diff --git a/docs/source/tutorials_source/platform/images/sunflowers_scatter_before_selection.jpg b/docs/source/tutorials_source/platform/images/sunflowers_scatter_before_selection.jpg deleted file mode 100644 index 589b47933..000000000 Binary files a/docs/source/tutorials_source/platform/images/sunflowers_scatter_before_selection.jpg and /dev/null differ diff --git a/docs/source/tutorials_source/platform/tutorial_pizza_filter.py b/docs/source/tutorials_source/platform/tutorial_pizza_filter.py deleted file mode 100644 index 3500afeed..000000000 --- a/docs/source/tutorials_source/platform/tutorial_pizza_filter.py +++ /dev/null @@ -1,208 +0,0 @@ -# -*- coding: utf-8 -*- - -""" -This documentation accompanies the video tutorial: `youtube link `_ - -############################################################################## - -.. _lightly-tutorial-pizza-filter: - -Tutorial 1: Curate Pizza Images -=============================== - -.. warning:: - **Tutorial is outdated** - - This tutorial uses a deprecated workflow of the Lightly Solution and will be removed in the future. - Please refer to the `new documentation and tutorials `_ instead. - -In this tutorial, you will learn how to upload a dataset to the Lightly platform, -curate the data, and finally use the curated data to train a model. 
- -What you will learn -------------------- - -* Create and upload a new dataset -* Curate a dataset using simple image metrics such as Width, Height, Sharpness, Signal-to-Noise ratio, File Size -* Download images based on a tag from a dataset -* Train an image classifier with the filtered dataset - - -Requirements ------------- - -You can use your dataset or use the one we provide with this tutorial: -:download:`pizzas.zip <../../../_data/pizzas.zip>`. -If you use your dataset, please make sure the images are smaller than -2048 pixels with width and height, and you use less than 1000 images. - -.. note:: For this tutorial, we provide you with a small dataset of pizza images. - We chose a small dataset because it's easy to ship and train. - -Upload the data ---------------- - -We start by uploading the dataset to the `Lightly Platform `_. - -Create a new account if you do not have one yet. -Go to your user Preferences and copy your API token. - -Now install lightly if you haven't already, and upload your dataset. - -.. code-block:: console - - # install Lightly - pip3 install lightly - - # upload your DATA directory - lightly-upload token=MY_TOKEN new_dataset_name='NEW_DATASET_NAME' input_dir='DATA/' - - -Filter the dataset using metadata ---------------------------------- - -Once the dataset is created and the -images uploaded, you can head to 'Metadata' under the 'Analyze & Filter' menu. - -Move the sliders below the histograms to define filter rules for the dataset. -Once you are satisfied with the filtered dataset, create a new tag using the tag menu -on the left side. - -Download the curated dataset ----------------------------- - -We have filtered the dataset and want to download it now to train a model. -Therefore, click on the download menu on the left. - -We can now download the filtered images by clicking on the 'DOWNLOAD IMAGES' button. -In our case, the images are stored in the 'pizzas' folder. We now have to -annotate the images. 
We can do this by moving the individual images to -subfolders corresponding to the class. E.g. we move salami pizza images to the -'salami' folder and Margherita pizza images to the 'margherita' folder. - -############################################################################## - -Training a model using the curated data ---------------------------------------- - -""" - - -# %% -# Now we can start training our model using PyTorch Lightning -# We start by importing the necessary dependencies -import os - -import pytorch_lightning as pl -import torch -import torchmetrics -from torch.utils.data import DataLoader, random_split -from torchvision import transforms -from torchvision.datasets import ImageFolder -from torchvision.models import resnet18 - -# %% -# We use a small batch size to make sure we can run the training on all kinds -# of machines. Feel free to adjust the value to one that works on your machine. -batch_size = 8 -seed = 42 - -# %% -# Set the seed to make the experiment reproducible -pl.seed_everything(seed) - -# %% -# Let's set up the augmentations for the train and the test data. -train_transform = transforms.Compose( - [ - transforms.RandomResizedCrop((224, 224), scale=(0.7, 1.0)), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), - ] -) - -# we don't do any resizing or mirroring for the test data -test_transform = transforms.Compose( - [ - transforms.Resize((224, 224)), - transforms.ToTensor(), - transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), - ] -) - - -# %% -# We load our data and split it into train/test with a 70/30 ratio. 
- -# Please make sure the data folder contains subfolders for each class -# -# pizzas -# L salami -# L margherita -dset = ImageFolder("pizzas", transform=train_transform) - -# to use the random_split method we need to obtain the length -# of the train and test set -full_len = len(dset) -train_len = int(full_len * 0.7) -test_len = int(full_len - train_len) -dataset_train, dataset_test = random_split(dset, [train_len, test_len]) -dataset_test.transforms = test_transform - -print("Training set consists of {} images".format(len(dataset_train))) -print("Test set consists of {} images".format(len(dataset_test))) - -# %% -# We can create our data loaders to fetch the data from the training and test -# set and pack them into batches. -dataloader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True) -dataloader_test = DataLoader(dataset_test, batch_size=batch_size) - - -# %% -# PyTorch Lightning allows us to pack the loss as well as the -# optimizer into a single module. -class MyModel(pl.LightningModule): - def __init__(self, num_classes=2): - super().__init__() - self.save_hyperparameters() - - # load a pretrained resnet from torchvision - self.model = resnet18(pretrained=True) - - # add new linear output layer (transfer learning) - num_ftrs = self.model.fc.in_features - self.model.fc = torch.nn.Linear(num_ftrs, 2) - - self.accuracy = torchmetrics.Accuracy() - - def forward(self, x): - return self.model(x) - - def training_step(self, batch, batch_idx): - x, y = batch - y_hat = self(x) - loss = torch.nn.functional.cross_entropy(y_hat, y) - self.log("train_loss", loss, prog_bar=True) - return loss - - def validation_step(self, batch, batch_idx): - x, y = batch - y_hat = self(x) - loss = torch.nn.functional.cross_entropy(y_hat, y) - y_hat = torch.nn.functional.softmax(y_hat, dim=1) - self.accuracy(y_hat, y) - self.log("val_loss", loss, on_epoch=True, prog_bar=True) - self.log("val_acc", self.accuracy.compute(), on_epoch=True, prog_bar=True) - - def 
configure_optimizers(self): - return torch.optim.SGD(self.model.fc.parameters(), lr=0.001, momentum=0.9) - - -# %% -# Finally, we can create the model and use the Trainer -# to train our model. -model = MyModel() -trainer = pl.Trainer(max_epochs=4, devices=1) -trainer.fit(model, dataloader_train, dataloader_test) diff --git a/docs/source/tutorials_source/platform/tutorial_sunflowers.py b/docs/source/tutorials_source/platform/tutorial_sunflowers.py deleted file mode 100644 index 41c2c0cd8..000000000 --- a/docs/source/tutorials_source/platform/tutorial_sunflowers.py +++ /dev/null @@ -1,136 +0,0 @@ -""" - -.. _lightly-tutorial-sunflowers: - -Tutorial 2: Diversify the Sunflowers Dataset -============================================= - -.. warning:: - **Tutorial is outdated** - - This tutorial uses a deprecated workflow of the Lightly Solution and will be removed in the future. - Please refer to the `new documentation and tutorials `_ instead. - -This tutorial highlights the basic functionality of selecting a subset in the web-app. -You can use the CORESET selection strategy to choose a diverse subset of your dataset. -This can be useful many purposes, e.g. for having a good subset of data to label -or for creating a validation or test dataset that covers the complete sample space. -Removing duplicate images can also help you in reducing bias and imbalances in your dataset. - -What you will learn --------------------- - -* Upload images and embeddings to the web-app via the Python package -* Sample a diverse subset of your original dataset in the web-app -* Download the filenames of the subset and use it to create a new local dataset folder. - -Requirements -------------- -You can use your own dataset or the one we provide for this tutorial. The dataset -we provide consists of 734 images of sunflowers. You can -download it here :download:`Sunflowers.zip <../../../_data/Sunflowers.zip>`. - -To use the Lightly platform, we need to upload the dataset with embeddings to it. 
-The first step for this is to train a self-supervised embedding model. -Then, embed your dataset and lastly, upload the dataset and embeddings to the Lightly platform. -These three steps can be done using a single terminal command from the lightly pip package: lightly-magic -But first, we need to install lightly from the Python package index. - -.. code-block:: bash - - # Install lightly as a pip package - pip install lightly - -.. code-block:: bash - - # The lightly-magic command first needs the input directory of your dataset. - # Then it needs the information for how many epochs to train an embedding model on it. - # If you want to use our pretrained model instead, set trainer.max_epochs=0. - # Next, the embedding model is used to embed all images in the input directory - # and saves the embeddings in a csv file. Last, a new dataset with the specified name - # is created on the Lightly platform. - - lightly-magic input_dir="./Sunflowers" trainer.max_epochs=0 token=YOUR_TOKEN - new_dataset_name="sunflowers_dataset" - -.. note:: - - The lightly-magic command with prefilled parameters is displayed in the web-app when you - create a new dataset. `Head over there and try it! `_ - For more information on the CLI commands refer to :ref:`lightly-command-line-tool` and :ref:`lightly-at-a-glance`. - -Create a Selection ------------------- - -Now, you have everything you need to create a selection of your dataset. For this, -head to the *Embedding* page of your dataset. You should see a two-dimensional -scatter plot of your embeddings. If you hover over the images, their thumbnails -will appear. Can you find clusters of similar images? - -.. figure:: ../../tutorials_source/platform/images/sunflowers_scatter_before_selection.jpg - :align: center - :alt: Alt text - :figclass: align-center - - You should see a two-dimensional scatter plot of your dataset as shown above. - Hover over an image to view a thumbnail of it. 
- There are also features like selecting and browsing some images and creating - a tag from it. - -.. note:: - - We reduce the dimensionality of the embeddings to 2 dimensions before plotting them. - You can switch between the PCA, tSNE and UMAP dimensionality reduction methods. - -Right above the scatter plot you should see a button "Create Sampling". Click on it to -create a selection. You will need to configure the following settings: - -* **Embedding:** Choose the embedding to use for the selection. -* **Sampling Strategy:** Choose the selection strategy to use. This will be one of: - - * CORESET: Selects samples which are diverse. - * CORAL: Combines CORESET with uncertainty scores to do active learning. - * RANDOM: Selects samples uniformly at random. -* **Stopping Condition:** Indicate how many samples you want to keep. -* **Name:** Give your selection a name. A new tag will be created under this name. - -.. figure:: ../../tutorials_source/platform/images/selection_create_request.png - :align: center - :alt: Alt text - :figclass: align-center - :figwidth: 400px - - Example of a filled out selection request in the web-app. - -After confirming your settings, a worker will start processing your request. Once -it's done, the page switches to the new tag. You can see how the scatter plot now shows -selected images and discarded images in a different color. Play around with the different selection strategies -to see differences between the results. - -.. figure:: ../../tutorials_source/platform/images/sunflowers_scatter_after_selection.jpg - :align: center - :alt: Alt text - :figclass: align-center - - After the selection you can see which samples were selected and which ones were discarded. - Here, the green dots are part of the new tag while the gray ones are left away. Notice - how the CORESET selection strategy selects an evenly spaced subset of images. - -.. 
note:: - - The CORESET selection strategy chooses the samples evenly spaced out in the 32-dimensional space. - This does not necessarily translate into being evenly spaced out after the dimensionality - reduction to 2 dimensions. - - -Download selected samples -------------------------- - -Now you can use this diverse subset for your machine learning project. -Just head over to the *Download* tag to see the different download options. -Apart from downloading the filenames or the images directly, you can also -use the lightly-download command to copy the files in the subset from your existing -to a new directory. The CLI command with prefilled arguments is already provided. - - -""" diff --git a/examples/pytorch/barlowtwins.py b/examples/pytorch/barlowtwins.py index 8bf4524dc..c6631518d 100644 --- a/examples/pytorch/barlowtwins.py +++ b/examples/pytorch/barlowtwins.py @@ -8,7 +8,11 @@ from lightly.loss import BarlowTwinsLoss from lightly.models.modules import BarlowTwinsProjectionHead -from lightly.transforms.simclr_transform import SimCLRTransform +from lightly.transforms.byol_transform import ( + BYOLTransform, + BYOLView1Transform, + BYOLView2Transform, +) class BarlowTwins(nn.Module): @@ -30,7 +34,12 @@ def forward(self, x): device = "cuda" if torch.cuda.is_available() else "cpu" model.to(device) -transform = SimCLRTransform(input_size=32) +# BarlowTwins uses BYOL augmentations. +# We disable resizing and gaussian blur for cifar10. 
+transform = BYOLTransform( + view_1_transform=BYOLView1Transform(input_size=32, gaussian_blur=0.0), + view_2_transform=BYOLView2Transform(input_size=32, gaussian_blur=0.0), +) dataset = torchvision.datasets.CIFAR10( "datasets/cifar10", download=True, transform=transform ) diff --git a/examples/pytorch/byol.py b/examples/pytorch/byol.py index cf80fdf11..4abcdf0a0 100644 --- a/examples/pytorch/byol.py +++ b/examples/pytorch/byol.py @@ -11,7 +11,11 @@ from lightly.loss import NegativeCosineSimilarity from lightly.models.modules import BYOLPredictionHead, BYOLProjectionHead from lightly.models.utils import deactivate_requires_grad, update_momentum -from lightly.transforms.simclr_transform import SimCLRTransform +from lightly.transforms.byol_transform import ( + BYOLTransform, + BYOLView1Transform, + BYOLView2Transform, +) from lightly.utils.scheduler import cosine_schedule @@ -49,7 +53,11 @@ def forward_momentum(self, x): device = "cuda" if torch.cuda.is_available() else "cpu" model.to(device) -transform = SimCLRTransform(input_size=32) +# We disable resizing and gaussian blur for cifar10. 
+transform = BYOLTransform( + view_1_transform=BYOLView1Transform(input_size=32, gaussian_blur=0.0), + view_2_transform=BYOLView2Transform(input_size=32, gaussian_blur=0.0), +) dataset = torchvision.datasets.CIFAR10( "datasets/cifar10", download=True, transform=transform ) diff --git a/examples/pytorch/tico.py b/examples/pytorch/tico.py index bbda91bd1..82274f3f9 100644 --- a/examples/pytorch/tico.py +++ b/examples/pytorch/tico.py @@ -11,7 +11,11 @@ from lightly.loss.tico_loss import TiCoLoss from lightly.models.modules.heads import TiCoProjectionHead from lightly.models.utils import deactivate_requires_grad, update_momentum -from lightly.transforms.simclr_transform import SimCLRTransform +from lightly.transforms.byol_transform import ( + BYOLTransform, + BYOLView1Transform, + BYOLView2Transform, +) from lightly.utils.scheduler import cosine_schedule @@ -47,7 +51,12 @@ def forward_momentum(self, x): device = "cuda" if torch.cuda.is_available() else "cpu" model.to(device) -transform = SimCLRTransform(input_size=32) +# TiCo uses BYOL augmentations. +# We disable resizing and gaussian blur for cifar10. 
+transform = BYOLTransform( + view_1_transform=BYOLView1Transform(input_size=32, gaussian_blur=0.0), + view_2_transform=BYOLView2Transform(input_size=32, gaussian_blur=0.0), +) dataset = torchvision.datasets.CIFAR10( "datasets/cifar10", download=True, transform=transform ) diff --git a/examples/pytorch/vicreg.py b/examples/pytorch/vicreg.py index a3221846c..39968cb08 100644 --- a/examples/pytorch/vicreg.py +++ b/examples/pytorch/vicreg.py @@ -7,7 +7,7 @@ ## The projection head is the same as the Barlow Twins one from lightly.loss.vicreg_loss import VICRegLoss -from lightly.models.modules import BarlowTwinsProjectionHead +from lightly.models.modules.heads import VICRegProjectionHead from lightly.transforms.vicreg_transform import VICRegTransform @@ -15,7 +15,12 @@ class VICReg(nn.Module): def __init__(self, backbone): super().__init__() self.backbone = backbone - self.projection_head = BarlowTwinsProjectionHead(512, 2048, 2048) + self.projection_head = VICRegProjectionHead( + input_dim=512, + hidden_dim=2048, + output_dim=2048, + num_layers=2, + ) def forward(self, x): x = self.backbone(x).flatten(start_dim=1) diff --git a/examples/pytorch_lightning/barlowtwins.py b/examples/pytorch_lightning/barlowtwins.py index 4053fe65d..86d069e5a 100644 --- a/examples/pytorch_lightning/barlowtwins.py +++ b/examples/pytorch_lightning/barlowtwins.py @@ -9,7 +9,11 @@ from lightly.loss import BarlowTwinsLoss from lightly.models.modules import BarlowTwinsProjectionHead -from lightly.transforms.simclr_transform import SimCLRTransform +from lightly.transforms.byol_transform import ( + BYOLTransform, + BYOLView1Transform, + BYOLView2Transform, +) class BarlowTwins(pl.LightningModule): @@ -39,7 +43,12 @@ def configure_optimizers(self): model = BarlowTwins() -transform = SimCLRTransform(input_size=32) +# BarlowTwins uses BYOL augmentations. +# We disable resizing and gaussian blur for cifar10. 
+transform = BYOLTransform( + view_1_transform=BYOLView1Transform(input_size=32, gaussian_blur=0.0), + view_2_transform=BYOLView2Transform(input_size=32, gaussian_blur=0.0), +) dataset = torchvision.datasets.CIFAR10( "datasets/cifar10", download=True, transform=transform ) diff --git a/examples/pytorch_lightning/byol.py b/examples/pytorch_lightning/byol.py index 2b227d34d..a87feb4df 100644 --- a/examples/pytorch_lightning/byol.py +++ b/examples/pytorch_lightning/byol.py @@ -12,7 +12,11 @@ from lightly.loss import NegativeCosineSimilarity from lightly.models.modules import BYOLPredictionHead, BYOLProjectionHead from lightly.models.utils import deactivate_requires_grad, update_momentum -from lightly.transforms.simclr_transform import SimCLRTransform +from lightly.transforms.byol_transform import ( + BYOLTransform, + BYOLView1Transform, + BYOLView2Transform, +) from lightly.utils.scheduler import cosine_schedule @@ -62,7 +66,11 @@ def configure_optimizers(self): model = BYOL() -transform = SimCLRTransform(input_size=32) +# We disable resizing and gaussian blur for cifar10. 
+transform = BYOLTransform( + view_1_transform=BYOLView1Transform(input_size=32, gaussian_blur=0.0), + view_2_transform=BYOLView2Transform(input_size=32, gaussian_blur=0.0), +) dataset = torchvision.datasets.CIFAR10( "datasets/cifar10", download=True, transform=transform ) diff --git a/examples/pytorch_lightning/tico.py b/examples/pytorch_lightning/tico.py index b55246549..48f9d0db1 100644 --- a/examples/pytorch_lightning/tico.py +++ b/examples/pytorch_lightning/tico.py @@ -8,7 +8,11 @@ from lightly.loss.tico_loss import TiCoLoss from lightly.models.modules.heads import TiCoProjectionHead from lightly.models.utils import deactivate_requires_grad, update_momentum -from lightly.transforms.simclr_transform import SimCLRTransform +from lightly.transforms.byol_transform import ( + BYOLTransform, + BYOLView1Transform, + BYOLView2Transform, +) from lightly.utils.scheduler import cosine_schedule @@ -56,7 +60,12 @@ def configure_optimizers(self): model = TiCo() -transform = SimCLRTransform(input_size=32) +# TiCo uses BYOL augmentations. +# We disable resizing and gaussian blur for cifar10. 
+transform = BYOLTransform( + view_1_transform=BYOLView1Transform(input_size=32, gaussian_blur=0.0), + view_2_transform=BYOLView2Transform(input_size=32, gaussian_blur=0.0), +) dataset = torchvision.datasets.CIFAR10( "datasets/cifar10", download=True, transform=transform ) diff --git a/examples/pytorch_lightning/vicreg.py b/examples/pytorch_lightning/vicreg.py index 5fa5c89fd..2770309e7 100644 --- a/examples/pytorch_lightning/vicreg.py +++ b/examples/pytorch_lightning/vicreg.py @@ -10,7 +10,7 @@ from lightly.loss.vicreg_loss import VICRegLoss ## The projection head is the same as the Barlow Twins one -from lightly.models.modules import BarlowTwinsProjectionHead +from lightly.models.modules.heads import VICRegProjectionHead from lightly.transforms.vicreg_transform import VICRegTransform @@ -19,7 +19,12 @@ def __init__(self): super().__init__() resnet = torchvision.models.resnet18() self.backbone = nn.Sequential(*list(resnet.children())[:-1]) - self.projection_head = BarlowTwinsProjectionHead(512, 2048, 2048) + self.projection_head = VICRegProjectionHead( + input_dim=512, + hidden_dim=2048, + output_dim=2048, + num_layers=2, + ) self.criterion = VICRegLoss() def forward(self, x): diff --git a/examples/pytorch_lightning_distributed/barlowtwins.py b/examples/pytorch_lightning_distributed/barlowtwins.py index c837dde04..876d7e5f2 100644 --- a/examples/pytorch_lightning_distributed/barlowtwins.py +++ b/examples/pytorch_lightning_distributed/barlowtwins.py @@ -9,7 +9,11 @@ from lightly.loss import BarlowTwinsLoss from lightly.models.modules import BarlowTwinsProjectionHead -from lightly.transforms.simclr_transform import SimCLRTransform +from lightly.transforms.byol_transform import ( + BYOLTransform, + BYOLView1Transform, + BYOLView2Transform, +) class BarlowTwins(pl.LightningModule): @@ -42,7 +46,12 @@ def configure_optimizers(self): model = BarlowTwins() -transform = SimCLRTransform(input_size=32) +# BarlowTwins uses BYOL augmentations. 
+# We disable resizing and gaussian blur for cifar10. +transform = BYOLTransform( + view_1_transform=BYOLView1Transform(input_size=32, gaussian_blur=0.0), + view_2_transform=BYOLView2Transform(input_size=32, gaussian_blur=0.0), +) dataset = torchvision.datasets.CIFAR10( "datasets/cifar10", download=True, transform=transform ) diff --git a/examples/pytorch_lightning_distributed/byol.py b/examples/pytorch_lightning_distributed/byol.py index 5dd49a16f..6fb9d88fe 100644 --- a/examples/pytorch_lightning_distributed/byol.py +++ b/examples/pytorch_lightning_distributed/byol.py @@ -13,7 +13,11 @@ from lightly.models.modules import BYOLProjectionHead from lightly.models.modules.heads import BYOLPredictionHead from lightly.models.utils import deactivate_requires_grad, update_momentum -from lightly.transforms.simclr_transform import SimCLRTransform +from lightly.transforms.byol_transform import ( + BYOLTransform, + BYOLView1Transform, + BYOLView2Transform, +) from lightly.utils.scheduler import cosine_schedule @@ -63,7 +67,11 @@ def configure_optimizers(self): model = BYOL() -transform = SimCLRTransform(input_size=32) +# We disable resizing and gaussian blur for cifar10. 
+transform = BYOLTransform( + view_1_transform=BYOLView1Transform(input_size=32, gaussian_blur=0.0), + view_2_transform=BYOLView2Transform(input_size=32, gaussian_blur=0.0), +) dataset = torchvision.datasets.CIFAR10( "datasets/cifar10", download=True, transform=transform ) diff --git a/examples/pytorch_lightning_distributed/tico.py b/examples/pytorch_lightning_distributed/tico.py index 0ba938bd5..5ed02bad2 100644 --- a/examples/pytorch_lightning_distributed/tico.py +++ b/examples/pytorch_lightning_distributed/tico.py @@ -8,7 +8,11 @@ from lightly.loss.tico_loss import TiCoLoss from lightly.models.modules.heads import TiCoProjectionHead from lightly.models.utils import deactivate_requires_grad, update_momentum -from lightly.transforms.simclr_transform import SimCLRTransform +from lightly.transforms.byol_transform import ( + BYOLTransform, + BYOLView1Transform, + BYOLView2Transform, +) from lightly.utils.scheduler import cosine_schedule @@ -56,7 +60,12 @@ def configure_optimizers(self): model = TiCo() -transform = SimCLRTransform(input_size=32) +# TiCo uses BYOL augmentations. +# We disable resizing and gaussian blur for cifar10. 
+transform = BYOLTransform( + view_1_transform=BYOLView1Transform(input_size=32, gaussian_blur=0.0), + view_2_transform=BYOLView2Transform(input_size=32, gaussian_blur=0.0), +) dataset = torchvision.datasets.CIFAR10( "datasets/cifar10", download=True, transform=transform ) diff --git a/examples/pytorch_lightning_distributed/vicreg.py b/examples/pytorch_lightning_distributed/vicreg.py index a30c726e5..bc8aa38ad 100644 --- a/examples/pytorch_lightning_distributed/vicreg.py +++ b/examples/pytorch_lightning_distributed/vicreg.py @@ -10,7 +10,7 @@ from lightly.loss import VICRegLoss ## The projection head is the same as the Barlow Twins one -from lightly.models.modules import BarlowTwinsProjectionHead +from lightly.models.modules.heads import VICRegProjectionHead from lightly.transforms.vicreg_transform import VICRegTransform @@ -19,7 +19,12 @@ def __init__(self): super().__init__() resnet = torchvision.models.resnet18() self.backbone = nn.Sequential(*list(resnet.children())[:-1]) - self.projection_head = BarlowTwinsProjectionHead(512, 2048, 2048) + self.projection_head = VICRegProjectionHead( + input_dim=512, + hidden_dim=2048, + output_dim=2048, + num_layers=2, + ) # enable gather_distributed to gather features from all gpus # before calculating the loss diff --git a/lightly/__init__.py b/lightly/__init__.py index 50929e040..cd41a557d 100644 --- a/lightly/__init__.py +++ b/lightly/__init__.py @@ -75,46 +75,31 @@ # All Rights Reserved __name__ = "lightly" -__version__ = "1.4.12" +__version__ = "1.4.22" + import os +# see if torchvision vision transformer is available try: - # See (https://github.com/PyTorchLightning/pytorch-lightning) - # This variable is injected in the __builtins__ by the build - # process. 
It used to enable importing subpackages of skimage when - # the binaries are not built - __LIGHTLY_SETUP__ -except NameError: - __LIGHTLY_SETUP__ = False - - -if __LIGHTLY_SETUP__: - # setting up lightly - msg = f"Partial import of {__name__}=={__version__} during build process." - print(msg) -else: - # see if torchvision vision transformer is available - try: - import torchvision.models.vision_transformer - - _torchvision_vit_available = True - except ( - RuntimeError, # Different CUDA versions for torch and torchvision - OSError, # Different CUDA versions for torch and torchvision (old) - ImportError, # No installation or old version of torchvision - ): - _torchvision_vit_available = False - - if os.getenv("LIGHTLY_DID_VERSION_CHECK", "False") == "False": - os.environ["LIGHTLY_DID_VERSION_CHECK"] = "True" - from multiprocessing import current_process - - if current_process().name == "MainProcess": - from lightly.api.version_checking import is_latest_version - - try: - is_latest_version(current_version=__version__) - except Exception: - # Version check should never break the package. 
- pass + import torchvision.models.vision_transformer + + _torchvision_vit_available = True +except ( + RuntimeError, # Different CUDA versions for torch and torchvision + OSError, # Different CUDA versions for torch and torchvision (old) + ImportError, # No installation or old version of torchvision +): + _torchvision_vit_available = False + + +if os.getenv("LIGHTLY_DID_VERSION_CHECK", "False") == "False": + os.environ["LIGHTLY_DID_VERSION_CHECK"] = "True" + import multiprocessing + + if multiprocessing.current_process().name == "MainProcess": + from lightly.api import _version_checking + + _version_checking.check_is_latest_version_in_background( + current_version=__version__ + ) diff --git a/lightly/api/_version_checking.py b/lightly/api/_version_checking.py new file mode 100644 index 000000000..3f370732a --- /dev/null +++ b/lightly/api/_version_checking.py @@ -0,0 +1,73 @@ +from threading import Thread + +from lightly.api import utils +from lightly.api.swagger_api_client import LightlySwaggerApiClient +from lightly.openapi_generated.swagger_client.api import VersioningApi +from lightly.utils import version_compare + +# Default timeout for API version verification requests in seconds. 
+DEFAULT_TIMEOUT_SEC = 2 + + +def is_latest_version(current_version: str) -> bool: + """Returns True if package version is latest released version.""" + latest_version = get_latest_version(current_version) + return version_compare.version_compare(current_version, latest_version) >= 0 + + +def is_compatible_version(current_version: str) -> bool: + """Returns True if package version is compatible with API.""" + minimum_version = get_minimum_compatible_version() + return version_compare.version_compare(current_version, minimum_version) >= 0 + + +def get_latest_version( + current_version: str, timeout_sec: float = DEFAULT_TIMEOUT_SEC +) -> str: + """Returns the latest package version.""" + versioning_api = _get_versioning_api() + version_number: str = versioning_api.get_latest_pip_version( + current_version=current_version, + _request_timeout=timeout_sec, + ) + return version_number + + +def get_minimum_compatible_version( + timeout_sec: float = DEFAULT_TIMEOUT_SEC, +) -> str: + """Returns minimum package version that is compatible with the API.""" + versioning_api = _get_versioning_api() + version_number: str = versioning_api.get_minimum_compatible_pip_version( + _request_timeout=timeout_sec + ) + return version_number + + +def check_is_latest_version_in_background(current_version: str) -> None: + """Checks if the current version is the latest version in a background thread.""" + + def _check_version_in_background(current_version: str) -> None: + try: + is_latest_version(current_version=current_version) + except Exception: + # Ignore failed check. + pass + + thread = Thread( + target=_check_version_in_background, + kwargs=dict(current_version=current_version), + daemon=True, + ) + thread.start() + + +def _get_versioning_api() -> VersioningApi: + configuration = utils.get_api_client_configuration( + raise_if_no_token_specified=False, + ) + # Set retries to 0 to avoid waiting for retries in case of a timeout. 
+ configuration.retries = 0 + api_client = LightlySwaggerApiClient(configuration=configuration) + versioning_api = VersioningApi(api_client=api_client) + return versioning_api diff --git a/lightly/api/api_workflow_artifacts.py b/lightly/api/api_workflow_artifacts.py index 5ea55244b..6bea0447e 100644 --- a/lightly/api/api_workflow_artifacts.py +++ b/lightly/api/api_workflow_artifacts.py @@ -1,4 +1,5 @@ import os +import warnings from lightly.api import download from lightly.openapi_generated.swagger_client.models import ( @@ -145,6 +146,10 @@ def download_compute_worker_run_report_json( ) -> None: """Download the report in json format from a run. + DEPRECATED: This method is deprecated and will be removed in the future. Use + download_compute_worker_run_report_v2_json to download the new report_v2.json + instead. + Args: run: Run from which to download the report. @@ -171,6 +176,13 @@ def download_compute_worker_run_report_json( >>> client.download_compute_worker_run_report_json(run=run, output_path="report.json") """ + warnings.warn( + DeprecationWarning( + "This method downloads the deprecated report.json file and will be " + "removed in the future. Use download_compute_worker_run_report_v2_json " + "to download the new report_v2.json file instead." + ) + ) self._download_compute_worker_run_artifact_by_type( run=run, artifact_type=DockerRunArtifactType.REPORT_JSON, @@ -178,6 +190,47 @@ def download_compute_worker_run_report_json( timeout=timeout, ) + def download_compute_worker_run_report_v2_json( + self, + run: DockerRunData, + output_path: str, + timeout: int = 60, + ) -> None: + """Download the report in json format from a run. + + Args: + run: + Run from which to download the report. + output_path: + Path where report will be saved. + timeout: + Timeout in seconds after which download is interrupted. + + Raises: + ArtifactNotExist: + If the run has no report artifact or the report has not yet been + uploaded. 
+ + Examples: + >>> # schedule run + >>> scheduled_run_id = client.schedule_compute_worker_run(...) + >>> + >>> # wait until run completed + >>> for run_info in client.compute_worker_run_info_generator(scheduled_run_id=scheduled_run_id): + >>> pass + >>> + >>> # download checkpoint + >>> run = client.get_compute_worker_run_from_scheduled_run(scheduled_run_id=scheduled_run_id) + >>> client.download_compute_worker_run_report_v2_json(run=run, output_path="report_v2.json") + + """ + self._download_compute_worker_run_artifact_by_type( + run=run, + artifact_type=DockerRunArtifactType.REPORT_V2_JSON, + output_path=output_path, + timeout=timeout, + ) + def download_compute_worker_run_log( self, run: DockerRunData, diff --git a/lightly/api/api_workflow_client.py b/lightly/api/api_workflow_client.py index 18fa190e7..87f1193bc 100644 --- a/lightly/api/api_workflow_client.py +++ b/lightly/api/api_workflow_client.py @@ -6,9 +6,10 @@ import requests from requests import Response +from urllib3.exceptions import HTTPError from lightly.__init__ import __version__ -from lightly.api import utils, version_checking +from lightly.api import _version_checking, utils from lightly.api.api_workflow_artifacts import _ArtifactsMixin from lightly.api.api_workflow_collaboration import _CollaborationMixin from lightly.api.api_workflow_compute_worker import _ComputeWorkerMixin @@ -23,7 +24,6 @@ from lightly.api.api_workflow_upload_metadata import _UploadCustomMetadataMixin from lightly.api.swagger_api_client import LightlySwaggerApiClient from lightly.api.utils import DatasourceType -from lightly.api.version_checking import LightlyAPITimeoutException from lightly.openapi_generated.swagger_client.api import ( CollaborationApi, DatasetsApi, @@ -93,7 +93,7 @@ def __init__( creator: str = Creator.USER_PIP, ): try: - if not version_checking.is_compatible_version(__version__): + if not _version_checking.is_compatible_version(__version__): warnings.warn( UserWarning( ( @@ -104,10 +104,14 @@ def 
__init__( ) ) except ( + # Error if version compare fails. ValueError, + # Any error by API client if status not in [200, 299]. ApiException, - LightlyAPITimeoutException, - AttributeError, + # Any error by urllib3 from within API client. Happens for failed requests + # that are not handled by API client. For example if there is no internet + # connection or a timeout. + HTTPError, ): pass @@ -214,6 +218,8 @@ def get_filenames(self) -> List[str]: Returns: Names of files in the current dataset. + + :meta private: # Skip docstring generation """ filenames_on_server = self._mappings_api.get_sample_mappings_by_dataset_id( dataset_id=self.dataset_id, field="fileName" @@ -243,6 +249,7 @@ def upload_file_with_signed_url( Returns: The response of the put request. + :meta private: # Skip docstring generation """ # check to see if server side encryption for S3 is desired diff --git a/lightly/api/api_workflow_compute_worker.py b/lightly/api/api_workflow_compute_worker.py index 06c37c392..f98244256 100644 --- a/lightly/api/api_workflow_compute_worker.py +++ b/lightly/api/api_workflow_compute_worker.py @@ -22,10 +22,10 @@ DockerWorkerConfigV3Lightly, DockerWorkerRegistryEntryData, DockerWorkerType, - SelectionConfig, - SelectionConfigEntry, - SelectionConfigEntryInput, - SelectionConfigEntryStrategy, + SelectionConfigV3, + SelectionConfigV3Entry, + SelectionConfigV3EntryInput, + SelectionConfigV3EntryStrategy, TagData, ) from lightly.openapi_generated.swagger_client.rest import ApiException @@ -175,7 +175,7 @@ def create_compute_worker_config( self, worker_config: Optional[Dict[str, Any]] = None, lightly_config: Optional[Dict[str, Any]] = None, - selection_config: Optional[Union[Dict[str, Any], SelectionConfig]] = None, + selection_config: Optional[Union[Dict[str, Any], SelectionConfigV3]] = None, ) -> str: """Creates a new configuration for a Lightly Worker run. @@ -207,6 +207,8 @@ def create_compute_worker_config( >>> config_id = client.create_compute_worker_config( ... 
selection_config=selection_config, ... ) + + :meta private: # Skip docstring generation """ if isinstance(selection_config, dict): selection = selection_config_from_dict(cfg=selection_config) @@ -267,7 +269,7 @@ def schedule_compute_worker_run( self, worker_config: Optional[Dict[str, Any]] = None, lightly_config: Optional[Dict[str, Any]] = None, - selection_config: Optional[Union[Dict[str, Any], SelectionConfig]] = None, + selection_config: Optional[Union[Dict[str, Any], SelectionConfigV3]] = None, priority: str = DockerRunScheduledPriority.MID, runs_on: Optional[List[str]] = None, ) -> str: @@ -632,17 +634,17 @@ def get_compute_worker_run_tags(self, run_id: str) -> List[TagData]: return tags_in_dataset -def selection_config_from_dict(cfg: Dict[str, Any]) -> SelectionConfig: - """Recursively converts selection config from dict to a SelectionConfig instance.""" +def selection_config_from_dict(cfg: Dict[str, Any]) -> SelectionConfigV3: + """Recursively converts selection config from dict to a SelectionConfigV3 instance.""" strategies = [] for entry in cfg.get("strategies", []): new_entry = copy.deepcopy(entry) - new_entry["input"] = SelectionConfigEntryInput(**entry["input"]) - new_entry["strategy"] = SelectionConfigEntryStrategy(**entry["strategy"]) - strategies.append(SelectionConfigEntry(**new_entry)) + new_entry["input"] = SelectionConfigV3EntryInput(**entry["input"]) + new_entry["strategy"] = SelectionConfigV3EntryStrategy(**entry["strategy"]) + strategies.append(SelectionConfigV3Entry(**new_entry)) new_cfg = copy.deepcopy(cfg) new_cfg["strategies"] = strategies - return SelectionConfig(**new_cfg) + return SelectionConfigV3(**new_cfg) _T = TypeVar("_T") diff --git a/lightly/api/api_workflow_datasets.py b/lightly/api/api_workflow_datasets.py index 29d33d63d..2d670e76c 100644 --- a/lightly/api/api_workflow_datasets.py +++ b/lightly/api/api_workflow_datasets.py @@ -376,7 +376,7 @@ def create_dataset( >>> from lightly.api import ApiWorkflowClient >>> from 
lightly.openapi_generated.swagger_client.models import DatasetType >>> - >>> client = lightly.api.ApiWorkflowClient(token="YOUR_TOKEN") + >>> client = ApiWorkflowClient(token="YOUR_TOKEN") >>> client.create_dataset('your-dataset-name', dataset_type=DatasetType.IMAGES) >>> >>> # or to work with videos diff --git a/lightly/api/api_workflow_datasources.py b/lightly/api/api_workflow_datasources.py index 90566002c..5c2033766 100644 --- a/lightly/api/api_workflow_datasources.py +++ b/lightly/api/api_workflow_datasources.py @@ -1,93 +1,22 @@ import time import warnings -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, Iterator, List, Optional, Set, Tuple, Union import tqdm from lightly.openapi_generated.swagger_client.models import ( DatasourceConfig, - DatasourceConfigVerifyDataErrors, DatasourceProcessedUntilTimestampRequest, DatasourceProcessedUntilTimestampResponse, DatasourcePurpose, DatasourceRawSamplesData, ) +from lightly.openapi_generated.swagger_client.models.datasource_raw_samples_data_row import ( + DatasourceRawSamplesDataRow, +) class _DatasourcesMixin: - def _download_raw_files( - self, - download_function: Union[ - "DatasourcesApi.get_list_of_raw_samples_from_datasource_by_dataset_id", - "DatasourcesApi.get_list_of_raw_samples_predictions_from_datasource_by_dataset_id", - "DatasourcesApi.get_list_of_raw_samples_metadata_from_datasource_by_dataset_id", - ], - from_: int = 0, - to: Optional[int] = None, - relevant_filenames_file_name: Optional[str] = None, - use_redirected_read_url: bool = False, - progress_bar: Optional[tqdm.tqdm] = None, - **kwargs, - ) -> List[Tuple[str, str]]: - if to is None: - to = int(time.time()) - relevant_filenames_kwargs = ( - {"relevant_filenames_file_name": relevant_filenames_file_name} - if relevant_filenames_file_name - else dict() - ) - - response: DatasourceRawSamplesData = download_function( - dataset_id=self.dataset_id, - var_from=from_, - to=to, - 
use_redirected_read_url=use_redirected_read_url, - **relevant_filenames_kwargs, - **kwargs, - ) - cursor = response.cursor - samples = response.data - if progress_bar is not None: - progress_bar.update(len(response.data)) - while response.has_more: - response: DatasourceRawSamplesData = download_function( - dataset_id=self.dataset_id, - cursor=cursor, - use_redirected_read_url=use_redirected_read_url, - **relevant_filenames_kwargs, - **kwargs, - ) - cursor = response.cursor - samples.extend(response.data) - if progress_bar is not None: - progress_bar.update(len(response.data)) - sample_map = {} - for idx, s in enumerate(samples): - if s.file_name.startswith("/"): - warnings.warn( - UserWarning( - f"Absolute file paths like {s.file_name} are not supported" - f" in relevant filenames file {relevant_filenames_file_name} due to blob storage" - ) - ) - elif s.file_name.startswith(("./", "../")): - warnings.warn( - UserWarning( - f"Using dot notation ('./', '../') like in {s.file_name} is not supported" - f" in relevant filenames file {relevant_filenames_file_name} due to blob storage" - ) - ) - elif s.file_name in sample_map: - warnings.warn( - UserWarning( - f"Duplicate filename {s.file_name} in relevant" - f" filenames file {relevant_filenames_file_name}" - ) - ) - else: - sample_map[s.file_name] = s.read_url - return [(file_name, read_url) for file_name, read_url in sample_map.items()] - def download_raw_samples( self, from_: int = 0, @@ -131,8 +60,10 @@ def download_raw_samples( >>> client.set_dataset_id_by_name("my-dataset") >>> client.download_raw_samples() [('image-1.png', 'https://......'), ('image-2.png', 'https://......')] + + :meta private: # Skip docstring generation """ - samples = self._download_raw_files( + return self._download_raw_files( download_function=self._datasources_api.get_list_of_raw_samples_from_datasource_by_dataset_id, from_=from_, to=to, @@ -140,7 +71,6 @@ def download_raw_samples( use_redirected_read_url=use_redirected_read_url, 
progress_bar=progress_bar, ) - return samples def download_raw_predictions( self, @@ -153,6 +83,36 @@ def download_raw_predictions( use_redirected_read_url: bool = False, progress_bar: Optional[tqdm.tqdm] = None, ) -> List[Tuple[str, str]]: + """Downloads all prediction filenames and read urls from the datasource. + + See `download_raw_predictions_iter` for details. + + :meta private: # Skip docstring generation + """ + return list( + self.download_raw_predictions_iter( + task_name=task_name, + from_=from_, + to=to, + relevant_filenames_file_name=relevant_filenames_file_name, + run_id=run_id, + relevant_filenames_artifact_id=relevant_filenames_artifact_id, + use_redirected_read_url=use_redirected_read_url, + progress_bar=progress_bar, + ) + ) + + def download_raw_predictions_iter( + self, + task_name: str, + from_: int = 0, + to: Optional[int] = None, + relevant_filenames_file_name: Optional[str] = None, + run_id: Optional[str] = None, + relevant_filenames_artifact_id: Optional[str] = None, + use_redirected_read_url: bool = False, + progress_bar: Optional[tqdm.tqdm] = None, + ) -> Iterator[Tuple[str, str]]: """Downloads prediction filenames and read urls from the datasource. Only samples with timestamp between `from_` (inclusive) and `to` (inclusive) @@ -188,7 +148,7 @@ def download_raw_predictions( retrieved. Returns: - A list of (filename, url) tuples where each tuple represents a sample. + An iterator of (filename, url) tuples where each tuple represents a sample. 
Examples: >>> client = ApiWorkflowClient(token="MY_AWESOME_TOKEN") @@ -196,9 +156,11 @@ def download_raw_predictions( >>> # Already created some Lightly Worker runs with this dataset >>> task_name = "object-detection" >>> client.set_dataset_id_by_name("my-dataset") - >>> client.download_raw_predictions(task_name=task_name) + >>> list(client.download_raw_predictions(task_name=task_name)) [('.lightly/predictions/object-detection/image-1.json', 'https://......'), ('.lightly/predictions/object-detection/image-2.json', 'https://......')] + + :meta private: # Skip docstring generation """ if run_id is not None and relevant_filenames_artifact_id is None: raise ValueError( @@ -217,7 +179,7 @@ def download_raw_predictions( "relevant_filenames_artifact_id" ] = relevant_filenames_artifact_id - samples = self._download_raw_files( + yield from self._download_raw_files_iter( download_function=self._datasources_api.get_list_of_raw_samples_predictions_from_datasource_by_dataset_id, from_=from_, to=to, @@ -227,7 +189,6 @@ def download_raw_predictions( progress_bar=progress_bar, **relevant_filenames_kwargs, ) - return samples def download_raw_metadata( self, @@ -241,6 +202,34 @@ def download_raw_metadata( ) -> List[Tuple[str, str]]: """Downloads all metadata filenames and read urls from the datasource. + See `download_raw_metadata_iter` for details. 
+ + :meta private: # Skip docstring generation + """ + return list( + self.download_raw_metadata_iter( + from_=from_, + to=to, + run_id=run_id, + relevant_filenames_artifact_id=relevant_filenames_artifact_id, + relevant_filenames_file_name=relevant_filenames_file_name, + use_redirected_read_url=use_redirected_read_url, + progress_bar=progress_bar, + ) + ) + + def download_raw_metadata_iter( + self, + from_: int = 0, + to: Optional[int] = None, + run_id: Optional[str] = None, + relevant_filenames_artifact_id: Optional[str] = None, + relevant_filenames_file_name: Optional[str] = None, + use_redirected_read_url: bool = False, + progress_bar: Optional[tqdm.tqdm] = None, + ) -> Iterator[Tuple[str, str]]: + """Downloads all metadata filenames and read urls from the datasource. + Only samples with timestamp between `from_` (inclusive) and `to` (inclusive) will be downloaded. @@ -272,16 +261,18 @@ def download_raw_metadata( retrieved. Returns: - A list of (filename, url) tuples where each tuple represents a sample. + An iterator of (filename, url) tuples where each tuple represents a sample. 
Examples: >>> client = ApiWorkflowClient(token="MY_AWESOME_TOKEN") >>> >>> # Already created some Lightly Worker runs with this dataset >>> client.set_dataset_id_by_name("my-dataset") - >>> client.download_raw_metadata() + >>> list(client.download_raw_metadata_iter()) [('.lightly/metadata/object-detection/image-1.json', 'https://......'), ('.lightly/metadata/object-detection/image-2.json', 'https://......')] + + :meta private: # Skip docstring generation """ if run_id is not None and relevant_filenames_artifact_id is None: raise ValueError( @@ -300,8 +291,8 @@ def download_raw_metadata( "relevant_filenames_artifact_id" ] = relevant_filenames_artifact_id - samples = self._download_raw_files( - self._datasources_api.get_list_of_raw_samples_metadata_from_datasource_by_dataset_id, + yield from self._download_raw_files_iter( + download_function=self._datasources_api.get_list_of_raw_samples_metadata_from_datasource_by_dataset_id, from_=from_, to=to, relevant_filenames_file_name=relevant_filenames_file_name, @@ -309,7 +300,6 @@ def download_raw_metadata( progress_bar=progress_bar, **relevant_filenames_kwargs, ) - return samples def download_new_raw_samples( self, @@ -371,6 +361,8 @@ def get_processed_until_timestamp(self) -> int: >>> client.set_dataset_id_by_name("my-dataset") >>> client.get_processed_until_timestamp() 1684750513 + + :meta private: # Skip docstring generation """ response: DatasourceProcessedUntilTimestampResponse = self._datasources_api.get_datasource_processed_until_timestamp_by_dataset_id( dataset_id=self.dataset_id @@ -449,7 +441,6 @@ def set_azure_config( purpose: Datasource purpose, determines if datasource is read only (INPUT) or can be written to as well (LIGHTLY, INPUT_OUTPUT). - The latter is required when Lightly extracts frames from input videos. """ # TODO: Use DatasourceConfigAzure once we switch/update the api generator. 
@@ -499,7 +490,6 @@ def set_gcs_config( purpose: Datasource purpose, determines if datasource is read only (INPUT) or can be written to as well (LIGHTLY, INPUT_OUTPUT). - The latter is required when Lightly extracts frames from input videos. """ # TODO: Use DatasourceConfigGCS once we switch/update the api generator. @@ -519,10 +509,12 @@ def set_gcs_config( def set_local_config( self, - resource_path: str, + relative_path: str = "", + web_server_location: Optional[str] = "http://localhost:3456", thumbnail_suffix: Optional[ str ] = ".lightly/thumbnails/[filename]_thumb.[extension]", + purpose: str = DatasourcePurpose.INPUT_OUTPUT, ) -> None: """Sets the local configuration for the datasource of the current dataset. @@ -530,22 +522,29 @@ def set_local_config( server in our docs: https://docs.lightly.ai/getting_started/dataset_creation/dataset_creation_local_server.html Args: - resource_path: - Url to your local file server, for example: "http://localhost:1234/path/to/my/data". + relative_path: + Relative path from the mount root, for example: "path/to/my/data". + web_server_location: + Location of your local file server. Defaults to "http://localhost:3456". thumbnail_suffix: Where to save thumbnails of the images in the dataset, for example ".lightly/thumbnails/[filename]_thumb.[extension]". Set to None to disable thumbnails and use the full images from the datasource instead. + purpose: + Datasource purpose, determines if datasource is read only (INPUT) + or can be written to as well (LIGHTLY, INPUT_OUTPUT). + """ # TODO: Use DatasourceConfigLocal once we switch/update the api generator. 
self._datasources_api.update_datasource_by_dataset_id( datasource_config=DatasourceConfig.from_dict( { "type": "LOCAL", - "fullPath": resource_path, + "webServerLocation": web_server_location, + "fullPath": relative_path, "thumbSuffix": thumbnail_suffix, - "purpose": DatasourcePurpose.INPUT_OUTPUT, + "purpose": purpose, } ), dataset_id=self.dataset_id, @@ -584,7 +583,6 @@ def set_s3_config( purpose: Datasource purpose, determines if datasource is read only (INPUT) or can be written to as well (LIGHTLY, INPUT_OUTPUT). - The latter is required when Lightly extracts frames from input videos. """ # TODO: Use DatasourceConfigS3 once we switch/update the api generator. @@ -636,7 +634,6 @@ def set_s3_delegated_access_config( purpose: Datasource purpose, determines if datasource is read only (INPUT) or can be written to as well (LIGHTLY, INPUT_OUTPUT). - The latter is required when Lightly extracts frames from input videos. """ # TODO: Use DatasourceConfigS3 once we switch/update the api generator. @@ -685,7 +682,7 @@ def set_obs_config( purpose: Datasource purpose, determines if datasource is read only (INPUT) or can be written to as well (LIGHTLY, INPUT_OUTPUT). - The latter is required when Lightly extracts frames from input videos. + """ # TODO: Use DatasourceConfigOBS once we switch/update the api generator. self._datasources_api.update_datasource_by_dataset_id( @@ -706,7 +703,7 @@ def set_obs_config( def get_prediction_read_url( self, filename: str, - ): + ) -> str: """Returns a read-url for .lightly/predictions/{filename}. Args: @@ -717,6 +714,7 @@ def get_prediction_read_url( A read-url to the file. Note that a URL will be returned even if the file does not exist. 
+ :meta private: # Skip docstring generation """ return self._datasources_api.get_prediction_file_read_url_from_datasource_by_dataset_id( dataset_id=self.dataset_id, @@ -726,7 +724,7 @@ def get_prediction_read_url( def get_metadata_read_url( self, filename: str, - ): + ) -> str: """Returns a read-url for .lightly/metadata/{filename}. Args: @@ -737,6 +735,7 @@ def get_metadata_read_url( A read-url to the file. Note that a URL will be returned even if the file does not exist. + :meta private: # Skip docstring generation """ return self._datasources_api.get_metadata_file_read_url_from_datasource_by_dataset_id( dataset_id=self.dataset_id, @@ -757,6 +756,7 @@ def get_custom_embedding_read_url( A read-url to the file. Note that a URL will be returned even if the file does not exist. + :meta private: # Skip docstring generation """ return self._datasources_api.get_custom_embedding_file_read_url_from_datasource_by_dataset_id( dataset_id=self.dataset_id, @@ -790,3 +790,123 @@ def list_datasource_permissions( return self._datasources_api.verify_datasource_by_dataset_id( dataset_id=self.dataset_id, ).to_dict() + + def _download_raw_files( + self, + download_function: Union[ + "DatasourcesApi.get_list_of_raw_samples_from_datasource_by_dataset_id", + "DatasourcesApi.get_list_of_raw_samples_predictions_from_datasource_by_dataset_id", + "DatasourcesApi.get_list_of_raw_samples_metadata_from_datasource_by_dataset_id", + ], + from_: int = 0, + to: Optional[int] = None, + relevant_filenames_file_name: Optional[str] = None, + use_redirected_read_url: bool = False, + progress_bar: Optional[tqdm.tqdm] = None, + **kwargs, + ) -> List[Tuple[str, str]]: + return list( + self._download_raw_files_iter( + download_function=download_function, + from_=from_, + to=to, + relevant_filenames_file_name=relevant_filenames_file_name, + use_redirected_read_url=use_redirected_read_url, + progress_bar=progress_bar, + **kwargs, + ) + ) + + def _download_raw_files_iter( + self, + download_function: Union[ 
+ "DatasourcesApi.get_list_of_raw_samples_from_datasource_by_dataset_id", + "DatasourcesApi.get_list_of_raw_samples_predictions_from_datasource_by_dataset_id", + "DatasourcesApi.get_list_of_raw_samples_metadata_from_datasource_by_dataset_id", + ], + from_: int = 0, + to: Optional[int] = None, + relevant_filenames_file_name: Optional[str] = None, + use_redirected_read_url: bool = False, + progress_bar: Optional[tqdm.tqdm] = None, + **kwargs, + ) -> Iterator[Tuple[str, str]]: + if to is None: + to = int(time.time()) + relevant_filenames_kwargs = ( + {"relevant_filenames_file_name": relevant_filenames_file_name} + if relevant_filenames_file_name + else dict() + ) + + listed_filenames = set() + + def get_samples( + response: DatasourceRawSamplesData, + ) -> Iterator[Tuple[str, str]]: + for sample in response.data: + if _sample_unseen_and_valid( + sample=sample, + relevant_filenames_file_name=relevant_filenames_file_name, + listed_filenames=listed_filenames, + ): + listed_filenames.add(sample.file_name) + yield sample.file_name, sample.read_url + if progress_bar is not None: + progress_bar.update(1) + + response: DatasourceRawSamplesData = download_function( + dataset_id=self.dataset_id, + var_from=from_, + to=to, + use_redirected_read_url=use_redirected_read_url, + **relevant_filenames_kwargs, + **kwargs, + ) + yield from get_samples(response=response) + while response.has_more: + response: DatasourceRawSamplesData = download_function( + dataset_id=self.dataset_id, + cursor=response.cursor, + use_redirected_read_url=use_redirected_read_url, + **relevant_filenames_kwargs, + **kwargs, + ) + yield from get_samples(response=response) + + +def _sample_unseen_and_valid( + sample: DatasourceRawSamplesDataRow, + relevant_filenames_file_name: Optional[str], + listed_filenames: Set[str], +) -> bool: + # Note: We want to remove these checks eventually. Absolute paths and relative paths + # with dot notation should be handled either in the API or the Worker. 
Duplicate + # filenames should be handled in the Worker as handling it in the API would require + # too much memory. + if sample.file_name.startswith("/"): + warnings.warn( + UserWarning( + f"Absolute file paths like {sample.file_name} are not supported" + f" in relevant filenames file {relevant_filenames_file_name} due to blob storage" + ) + ) + return False + elif sample.file_name.startswith(("./", "../")): + warnings.warn( + UserWarning( + f"Using dot notation ('./', '../') like in {sample.file_name} is not supported" + f" in relevant filenames file {relevant_filenames_file_name} due to blob storage" + ) + ) + return False + elif sample.file_name in listed_filenames: + warnings.warn( + UserWarning( + f"Duplicate filename {sample.file_name} in relevant" + f" filenames file {relevant_filenames_file_name}" + ) + ) + return False + else: + return True diff --git a/lightly/api/api_workflow_predictions.py b/lightly/api/api_workflow_predictions.py index 3ea7b63af..258620091 100644 --- a/lightly/api/api_workflow_predictions.py +++ b/lightly/api/api_workflow_predictions.py @@ -44,7 +44,7 @@ def create_or_update_prediction_task_schema( >>> ) >>> client.create_or_update_prediction_task_schema(schema=schema) - + :meta private: # Skip docstring generation """ self._predictions_api.create_or_update_prediction_task_schema_by_dataset_id( prediction_task_schema=schema, @@ -71,6 +71,8 @@ def create_or_update_prediction( prediction_singletons: Predictions to be uploaded for the designated sample. + + :meta private: # Skip docstring generation """ self._predictions_api.create_or_update_prediction_by_sample_id( prediction_singleton=prediction_singletons, diff --git a/lightly/api/api_workflow_selection.py b/lightly/api/api_workflow_selection.py index f013c5881..760ef2d4d 100644 --- a/lightly/api/api_workflow_selection.py +++ b/lightly/api/api_workflow_selection.py @@ -39,6 +39,8 @@ def upload_scores( and score arrays. 
The length of each score array must match samples in the designated tag. query_tag_id: ID of the desired tag. + + :meta private: # Skip docstring generation """ # iterate over all available score types and upload them for score_type, score_values in al_scores.items(): @@ -79,6 +81,7 @@ def selection( When `initial-tag` does not exist in the dataset. When the selection task fails. + :meta private: # Skip docstring generation """ warnings.warn( diff --git a/lightly/api/api_workflow_tags.py b/lightly/api/api_workflow_tags.py index c6b4780bc..be70ebe22 100644 --- a/lightly/api/api_workflow_tags.py +++ b/lightly/api/api_workflow_tags.py @@ -124,6 +124,8 @@ def get_filenames_in_tag( >>> tag = client.get_tag_by_name("cool-tag") >>> client.get_filenames_in_tag(tag_data=tag) ['image-1.png', 'image-2.png'] + + :meta private: # Skip docstring generation """ if exclude_parent_tag: diff --git a/lightly/api/api_workflow_upload_embeddings.py b/lightly/api/api_workflow_upload_embeddings.py index a8b21a169..bf28c1c2f 100644 --- a/lightly/api/api_workflow_upload_embeddings.py +++ b/lightly/api/api_workflow_upload_embeddings.py @@ -31,7 +31,10 @@ def _get_csv_reader_from_read_url(self, read_url: str) -> None: return reader def set_embedding_id_to_latest(self) -> None: - """Sets the embedding ID in the API client to the latest embedding ID in the current dataset.""" + """Sets the embedding ID in the API client to the latest embedding ID in the current dataset. + + :meta private: # Skip docstring generation + """ embeddings_on_server: List[ DatasetEmbeddingData ] = self._embeddings_api.get_embeddings_by_dataset_id( @@ -104,6 +107,7 @@ def upload_embeddings(self, path_to_embeddings_csv: str, name: str) -> None: The name of the embedding. If an embedding with such a name already exists on the server, the upload is aborted. 
+ :meta private: # Skip docstring generation """ io_utils.check_embeddings( path_to_embeddings_csv, remove_additional_columns=True @@ -185,6 +189,7 @@ def append_embeddings(self, path_to_embeddings_csv: str, embedding_id: str) -> N If the number of columns in the local embeddings file and that of the remote embeddings file mismatch. + :meta private: # Skip docstring generation """ # read embedding from API @@ -254,7 +259,6 @@ def _order_csv_by_filenames(self, path_to_embeddings_csv: str) -> List[str]: f"The filenames in the embedding file and " f"the filenames on the server do not align" ) - io_utils.check_filenames(filenames) rows_without_header_ordered = self._order_list_by_filenames( filenames, rows_without_header diff --git a/lightly/api/api_workflow_upload_metadata.py b/lightly/api/api_workflow_upload_metadata.py index e5859b089..87aa55f4e 100644 --- a/lightly/api/api_workflow_upload_metadata.py +++ b/lightly/api/api_workflow_upload_metadata.py @@ -67,6 +67,7 @@ def index_custom_metadata_by_filename( If there are no annotations for a filename, the custom metadata is None instead. + :meta private: # Skip docstring generation """ # The mapping is filename -> image_id -> custom_metadata @@ -141,6 +142,7 @@ def upload_custom_metadata( max_workers: Maximum number of concurrent threads during upload. 
+ :meta private: # Skip docstring generation """ self.verify_custom_metadata_format(custom_metadata) @@ -241,7 +243,7 @@ def create_custom_metadata_config( >>> [entry], >>> ) - + :meta private: # Skip docstring generation """ config_set_request = ConfigurationSetRequest(name=name, configs=configs) resp = self._metadata_configurations_api.create_meta_data_configuration( diff --git a/lightly/api/serve.py b/lightly/api/serve.py new file mode 100644 index 000000000..716088c05 --- /dev/null +++ b/lightly/api/serve.py @@ -0,0 +1,68 @@ +from http.server import HTTPServer, SimpleHTTPRequestHandler +from pathlib import Path +from typing import Sequence + + +def get_server( + paths: Sequence[Path], + host: str, + port: int, +): + """Returns an HTTP server that serves a local datasource. + + Args: + paths: + List of paths to serve. + host: + Host to serve the datasource on. + port: + Port to serve the datasource on. + + Examples: + >>> from lightly.api import serve + >>> from pathlib import Path + >>> serve( + >>> paths=[Path("/input_mount), Path("/lightly_mount)], + >>> host="localhost", + >>> port=3456, + >>> ) + + """ + + class _LocalDatasourceRequestHandler(SimpleHTTPRequestHandler): + def translate_path(self, path: str) -> str: + return _translate_path(path=path, directories=paths) + + def do_OPTIONS(self) -> None: + self.send_response(204) + self.send_header("Access-Control-Allow-Origin", "*") + self.send_header("Access-Control-Allow-Methods", "GET, POST, OPTIONS") + self.end_headers() + + return HTTPServer((host, port), _LocalDatasourceRequestHandler) + + +def _translate_path(path: str, directories: Sequence[Path]) -> str: + """Translates a relative path to a file in the local datasource. + + Tries to resolve the relative path to a file in the first directory + and serves it if it exists. Otherwise, it tries to resolve the relative + path to a file in the second directory and serves it if it exists, etc. 
+ + Args: + path: + Relative path to a file in the local datasource. + directories: + List of directories to search for the file. + + + Returns: + Absolute path to the file in the local datasource or an empty string + if the file doesn't exist. + + """ + stripped_path = path.lstrip("/") + for directory in directories: + if (directory / stripped_path).exists(): + return str(directory / stripped_path) + return "" # Not found. diff --git a/lightly/api/utils.py b/lightly/api/utils.py index 5960fa7f6..d80bb2d67 100644 --- a/lightly/api/utils.py +++ b/lightly/api/utils.py @@ -24,7 +24,7 @@ RETRY_MAX_RETRIES = 5 -def retry(func, *args, **kwargs): +def retry(func, *args, **kwargs): # type: ignore """Repeats a function until it completes successfully or fails too often. Args: diff --git a/lightly/api/version_checking.py b/lightly/api/version_checking.py deleted file mode 100644 index 175179a5f..000000000 --- a/lightly/api/version_checking.py +++ /dev/null @@ -1,79 +0,0 @@ -import signal -import warnings -from typing import Tuple - -from lightly.api import utils -from lightly.api.swagger_api_client import LightlySwaggerApiClient -from lightly.openapi_generated.swagger_client.api import VersioningApi -from lightly.utils import version_compare - - -class LightlyAPITimeoutException(Exception): - pass - - -class TimeoutDecorator: - def __init__(self, seconds): - self.seconds = seconds - - def handle_timeout_method(self, *args, **kwargs): - raise LightlyAPITimeoutException - - def __enter__(self): - signal.signal(signal.SIGALRM, self.handle_timeout_method) - signal.alarm(self.seconds) - - def __exit__(self, exc_type, exc_val, exc_tb): - signal.alarm(0) - - -def is_latest_version(current_version: str) -> bool: - with TimeoutDecorator(1): - versioning_api = get_versioning_api() - latest_version: str = versioning_api.get_latest_pip_version( - current_version=current_version - ) - return version_compare.version_compare(current_version, latest_version) >= 0 - - -def 
is_compatible_version(current_version: str) -> bool: - with TimeoutDecorator(1): - versioning_api = get_versioning_api() - minimum_version: str = versioning_api.get_minimum_compatible_pip_version() - return version_compare.version_compare(current_version, minimum_version) >= 0 - - -def get_versioning_api() -> VersioningApi: - configuration = utils.get_api_client_configuration( - raise_if_no_token_specified=False, - ) - api_client = LightlySwaggerApiClient(configuration=configuration) - versioning_api = VersioningApi(api_client=api_client) - return versioning_api - - -def get_latest_version(current_version: str) -> Tuple[None, str]: - try: - versioning_api = get_versioning_api() - version_number: str = versioning_api.get_latest_pip_version( - current_version=current_version - ) - return version_number - except Exception as e: - return None - - -def get_minimum_compatible_version(): - versioning_api = get_versioning_api() - version_number: str = versioning_api.get_minimum_compatible_pip_version() - return version_number - - -def pretty_print_latest_version(current_version, latest_version, width=70): - warning = ( - f"You are using lightly version {current_version}. " - f"There is a newer version of the package available. " - f"For compatability reasons, please upgrade your current version: " - f"pip install lightly=={latest_version}" - ) - warnings.warn(Warning(warning)) diff --git a/lightly/cli/config/lightly-serve.yaml b/lightly/cli/config/lightly-serve.yaml new file mode 100644 index 000000000..28ee095e1 --- /dev/null +++ b/lightly/cli/config/lightly-serve.yaml @@ -0,0 +1,16 @@ +input_mount: '' # Path to the input directory. +lightly_mount: '' # Path to the lightly directory. +host: 'localhost' # Hostname for serving the data. +port: 3456 # Port for serving the data. + + +# Disable Hydra log directories. +defaults: + - _self_ + - override hydra/hydra_logging: disabled + - override hydra/job_logging: disabled + +hydra: + output_subdir: null + run: + dir: . 
diff --git a/lightly/cli/serve_cli.py b/lightly/cli/serve_cli.py new file mode 100644 index 000000000..285e61519 --- /dev/null +++ b/lightly/cli/serve_cli.py @@ -0,0 +1,48 @@ +import sys +from pathlib import Path + +import hydra + +from lightly.api import serve +from lightly.cli._helpers import fix_hydra_arguments + + +@hydra.main(**fix_hydra_arguments(config_path="config", config_name="lightly-serve")) +def lightly_serve(cfg): + """Use lightly-serve to serve your data for interactive exploration. + + Command-Line Args: + input_mount: + Path to the input directory. + lightly_mount: + Path to the Lightly directory. + host: + Hostname for serving the data (defaults to localhost). + port: + Port for serving the data (defaults to 3456). + + Examples: + >>> lightly-serve input_mount=data/ lightly_mount=lightly/ port=3456 + + + """ + if not cfg.input_mount: + print("Please provide a valid input mount. Use --help for more information.") + sys.exit(1) + + if not cfg.lightly_mount: + print("Please provide a valid Lightly mount. 
Use --help for more information.") + sys.exit(1) + + httpd = serve.get_server( + paths=[Path(cfg.input_mount), Path(cfg.lightly_mount)], + host=cfg.host, + port=cfg.port, + ) + print(f"Starting server, listening at '{httpd.server_name}:{httpd.server_port}'") + print(f"Serving files in '{cfg.input_mount}' and '{cfg.lightly_mount}'") + httpd.serve_forever() + + +def entry() -> None: + lightly_serve() diff --git a/lightly/data/dataset.py b/lightly/data/dataset.py index 4cd1ce87d..da7fd24b1 100644 --- a/lightly/data/dataset.py +++ b/lightly/data/dataset.py @@ -9,13 +9,11 @@ import torchvision.datasets as datasets from PIL import Image -from torch._C import Value from torchvision import transforms from torchvision.datasets.vision import StandardTransform, VisionDataset from lightly.data._helpers import DatasetFolder, _load_dataset_from_folder from lightly.data._video import VideoDataset -from lightly.utils.io import check_filenames def _get_filename_by_index(dataset, index): @@ -177,11 +175,6 @@ def is_valid_file(filepath: str): if index_to_filename is not None: self.index_to_filename = index_to_filename - # if created from an input directory with filenames, check if they - # are valid - if input_dir: - check_filenames(self.get_filenames()) - @classmethod def from_torch_dataset(cls, dataset, transform=None, index_to_filename=None): """Builds a LightlyDataset from a PyTorch (or torchvision) dataset. diff --git a/lightly/embedding/__init__.py b/lightly/embedding/__init__.py index f2bb98c2d..3ee7a5480 100644 --- a/lightly/embedding/__init__.py +++ b/lightly/embedding/__init__.py @@ -8,5 +8,6 @@ # Copyright (c) 2020. Lightly AG and its affiliates. 
# All Rights Reserved + from lightly.embedding._base import BaseEmbedding from lightly.embedding.embedding import SelfSupervisedEmbedding diff --git a/lightly/embedding/_base.py b/lightly/embedding/_base.py index 0a0e0e7cb..8d5086ba6 100644 --- a/lightly/embedding/_base.py +++ b/lightly/embedding/_base.py @@ -4,18 +4,34 @@ # All Rights Reserved import copy import os +from typing import Any, List, Optional, Sequence, Tuple, Union import omegaconf from omegaconf import DictConfig from pytorch_lightning import LightningModule, Trainer - +from pytorch_lightning.callbacks.callback import Callback +from torch import Tensor +from torch.nn import Module +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler +from torch.utils.data import DataLoader + +from lightly.data.dataset import LightlyDataset from lightly.embedding import callbacks +from lightly.utils.benchmarking import BenchmarkModule class BaseEmbedding(LightningModule): """All trainable embeddings must inherit from BaseEmbedding.""" - def __init__(self, model, criterion, optimizer, dataloader, scheduler=None): + def __init__( + self, + model: BenchmarkModule, + criterion: Module, + optimizer: Optimizer, + dataloader: DataLoader[LightlyDataset], + scheduler: Optional[_LRScheduler] = None, + ) -> None: """Constructor Args: @@ -32,30 +48,35 @@ def __init__(self, model, criterion, optimizer, dataloader, scheduler=None): self.optimizer = optimizer self.dataloader = dataloader self.scheduler = scheduler - self.checkpoint = None + self.checkpoint: Optional[str] = None self.cwd = os.getcwd() - def forward(self, x0, x1): - return self.model(x0, x1) + def forward(self, x0: Tensor, x1: Tensor) -> Tensor: + embedding: Tensor = self.model(x0, x1) + return embedding - def training_step(self, batch, batch_idx): + def training_step( + self, batch: Tuple[List[Tensor], Tensor, List[str]], batch_idx: int + ) -> Tensor: # get the two image transformations (x0, x1), _, _ = batch # forward pass of the 
transformations y0, y1 = self(x0, x1) # calculate loss - loss = self.criterion(y0, y1) + loss: Tensor = self.criterion(y0, y1) # log loss and return self.log("loss", loss) return loss - def configure_optimizers(self): + def configure_optimizers( + self, + ) -> Union[Optimizer, Tuple[Sequence[Optimizer], Sequence[_LRScheduler]]]: if self.scheduler is None: return self.optimizer else: return [self.optimizer], [self.scheduler] - def train_dataloader(self): + def train_dataloader(self) -> DataLoader[LightlyDataset]: return self.dataloader def train_embedding( @@ -63,7 +84,7 @@ def train_embedding( trainer_config: DictConfig, checkpoint_callback_config: DictConfig, summary_callback_config: DictConfig, - ): + ) -> None: """Train the model on the provided dataset. Args: @@ -80,9 +101,9 @@ def train_embedding( A trained encoder, ready for embedding datasets. """ - trainer_callbacks = [] + trainer_callbacks: List[Callback] = [] - checkpoint_cb = callbacks.create_checkpoint_callback( + checkpoint_cb = callbacks.create_checkpoint_callback( # type: ignore[misc] **checkpoint_callback_config ) trainer_callbacks.append(checkpoint_cb) @@ -100,14 +121,13 @@ def train_embedding( if "weights_summary" in trainer_config_copy: with omegaconf.open_dict(trainer_config_copy): del trainer_config_copy["weights_summary"] - - trainer = Trainer(**trainer_config_copy, callbacks=trainer_callbacks) + trainer = Trainer(**trainer_config_copy, callbacks=trainer_callbacks) # type: ignore[misc] trainer.fit(self) if checkpoint_cb.best_model_path != "": self.checkpoint = os.path.join(self.cwd, checkpoint_cb.best_model_path) - def embed(self, *args, **kwargs): + def embed(self, *args: Any, **kwargs: Any) -> Any: """Must be implemented by classes which inherit from BaseEmbedding.""" raise NotImplementedError() diff --git a/lightly/embedding/callbacks.py b/lightly/embedding/callbacks.py index af586c23a..b04aa0ae3 100644 --- a/lightly/embedding/callbacks.py +++ b/lightly/embedding/callbacks.py @@ -1,4 +1,5 
@@ import os +from typing import Optional from omegaconf import DictConfig from pytorch_lightning.callbacks import ModelCheckpoint, ModelSummary @@ -7,10 +8,10 @@ def create_checkpoint_callback( - save_last=False, - save_top_k=0, - monitor="loss", - dirpath=None, + save_last: bool = False, + save_top_k: int = 0, + monitor: str = "loss", + dirpath: Optional[str] = None, ) -> ModelCheckpoint: """Initializes the checkpoint callback. @@ -44,7 +45,7 @@ def create_summary_callback( if weights_summary not in [None, "None"]: return _create_summary_callback_deprecated(weights_summary) else: - return _create_summary_callback(**summary_callback_config) + return _create_summary_callback(summary_callback_config["max_depth"]) def _create_summary_callback(max_depth: int) -> ModelSummary: diff --git a/lightly/embedding/embedding.py b/lightly/embedding/embedding.py index 5439fea3f..9a1c704c8 100644 --- a/lightly/embedding/embedding.py +++ b/lightly/embedding/embedding.py @@ -4,14 +4,20 @@ # All Rights Reserved import time -from typing import List, Tuple, Union +from typing import List, Optional, Tuple import numpy as np import torch +from numpy.typing import NDArray +from torch.nn import Module +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler +from torch.utils.data import DataLoader from tqdm import tqdm -import lightly +from lightly.data import LightlyDataset from lightly.embedding._base import BaseEmbedding +from lightly.utils.benchmarking import BenchmarkModule from lightly.utils.reordering import sort_items_by_keys @@ -59,19 +65,21 @@ class SelfSupervisedEmbedding(BaseEmbedding): def __init__( self, - model: torch.nn.Module, - criterion: torch.nn.Module, - optimizer: torch.optim.Optimizer, - dataloader: torch.utils.data.DataLoader, - scheduler=None, - ): + model: BenchmarkModule, + criterion: Module, + optimizer: Optimizer, + dataloader: DataLoader[LightlyDataset], + scheduler: Optional[_LRScheduler] = None, + ) -> None: 
super(SelfSupervisedEmbedding, self).__init__( model, criterion, optimizer, dataloader, scheduler ) def embed( - self, dataloader: torch.utils.data.DataLoader, device: torch.device = None - ) -> Tuple[np.ndarray, List[int], List[str]]: + self, + dataloader: DataLoader[LightlyDataset], + device: Optional[torch.device] = None, + ) -> Tuple[NDArray[np.float_], List[int], List[str]]: """Embeds images in a vector space. Args: @@ -99,15 +107,15 @@ def embed( """ self.model.eval() - embeddings, labels, filenames = None, None, [] + filenames = [] - dataset = dataloader.dataset + dataset: LightlyDataset = dataloader.dataset pbar = tqdm(total=len(dataset), unit="imgs") efficiency = 0.0 - embeddings = [] - labels = [] + embeddings: List[NDArray[np.float_]] = [] + labels: List[int] = [] with torch.no_grad(): start_timepoint = time.time() for image_batch, label_batch, filename_batch in dataloader: @@ -125,8 +133,8 @@ def embed( embedding_batch = self.model.backbone(image_batch) embedding_batch = embedding_batch.detach().reshape(batch_size, -1) - embeddings.append(embedding_batch) - labels.append(label_batch) + embeddings.extend(embedding_batch.cpu().numpy()) + labels.extend(label_batch.cpu().tolist()) finished_timepoint = time.time() @@ -140,16 +148,14 @@ def embed( pbar.update(batch_size) - embeddings = torch.cat(embeddings, 0) - labels = torch.cat(labels, 0) - - embeddings = embeddings.cpu().numpy() - labels = labels.cpu().numpy() - sorted_filenames = dataset.get_filenames() - sorted_embeddings = sort_items_by_keys(filenames, embeddings, sorted_filenames) - sorted_labels = sort_items_by_keys(filenames, labels, sorted_filenames) - embeddings = np.stack(sorted_embeddings) - labels = np.stack(sorted_labels).tolist() + sorted_embeddings = sort_items_by_keys( + keys=filenames, + items=embeddings, + sorted_keys=sorted_filenames, + ) + sorted_labels = sort_items_by_keys( + keys=filenames, items=labels, sorted_keys=sorted_filenames + ) - return embeddings, labels, sorted_filenames + 
return np.stack(sorted_embeddings), sorted_labels, sorted_filenames diff --git a/lightly/loss/__init__.py b/lightly/loss/__init__.py index 706045fd1..613707abe 100644 --- a/lightly/loss/__init__.py +++ b/lightly/loss/__init__.py @@ -2,7 +2,6 @@ # Copyright (c) 2020. Lightly AG and its affiliates. # All Rights Reserved - from lightly.loss.barlow_twins_loss import BarlowTwinsLoss from lightly.loss.dcl_loss import DCLLoss, DCLWLoss from lightly.loss.dino_loss import DINOLoss diff --git a/lightly/loss/barlow_twins_loss.py b/lightly/loss/barlow_twins_loss.py index 16a005427..54ee08496 100644 --- a/lightly/loss/barlow_twins_loss.py +++ b/lightly/loss/barlow_twins_loss.py @@ -1,5 +1,8 @@ +from typing import Tuple + import torch import torch.distributed as dist +import torch.nn.functional as F class BarlowTwinsLoss(torch.nn.Module): @@ -48,17 +51,14 @@ def __init__(self, lambda_param: float = 5e-3, gather_distributed: bool = False) ) def forward(self, z_a: torch.Tensor, z_b: torch.Tensor) -> torch.Tensor: - device = z_a.device - # normalize repr. 
along the batch dimension - z_a_norm = (z_a - z_a.mean(0)) / z_a.std(0) # NxD - z_b_norm = (z_b - z_b.mean(0)) / z_b.std(0) # NxD + z_a_norm, z_b_norm = _normalize(z_a, z_b) N = z_a.size(0) - D = z_a.size(1) # cross-correlation matrix - c = torch.mm(z_a_norm.T, z_b_norm) / N # DxD + c = z_a_norm.T @ z_b_norm + c.div_(N) # sum cross-correlation matrix between multiple gpus if self.gather_distributed and dist.is_initialized(): @@ -67,10 +67,31 @@ def forward(self, z_a: torch.Tensor, z_b: torch.Tensor) -> torch.Tensor: c = c / world_size dist.all_reduce(c) - # loss - c_diff = (c - torch.eye(D, device=device)).pow(2) # DxD - # multiply off-diagonal elems of c_diff by lambda - c_diff[~torch.eye(D, dtype=bool)] *= self.lambda_param - loss = c_diff.sum() + invariance_loss = torch.diagonal(c).add_(-1).pow_(2).sum() + redundancy_reduction_loss = _off_diagonal(c).pow_(2).sum() + loss = invariance_loss + self.lambda_param * redundancy_reduction_loss return loss + + +def _normalize( + z_a: torch.Tensor, z_b: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + """Helper function to normalize tensors along the batch dimension.""" + combined = torch.stack([z_a, z_b], dim=0) # Shape: 2 x N x D + normalized = F.batch_norm( + combined.flatten(0, 1), + running_mean=None, + running_var=None, + weight=None, + bias=None, + training=True, + ).view_as(combined) + return normalized[0], normalized[1] + + +def _off_diagonal(x): + # return a flattened view of the off-diagonal elements of a square matrix + n, m = x.shape + assert n == m + return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten() diff --git a/lightly/loss/pmsn_loss.py b/lightly/loss/pmsn_loss.py index b8a66b9d2..15e97e2cb 100644 --- a/lightly/loss/pmsn_loss.py +++ b/lightly/loss/pmsn_loss.py @@ -68,7 +68,9 @@ def regularization_loss(self, mean_anchor_probs: Tensor) -> Tensor: exponent=self.power_law_exponent, device=mean_anchor_probs.device, ) - loss = F.kl_div(input=mean_anchor_probs, target=power_dist, reduction="sum") 
+ loss = F.kl_div( + input=mean_anchor_probs.log(), target=power_dist, reduction="sum" + ) return loss @@ -139,7 +141,9 @@ def regularization_loss(self, mean_anchor_probs: Tensor) -> Tensor: target_dist = self.target_distribution(mean_anchor_probs).to( mean_anchor_probs.device ) - loss = F.kl_div(input=mean_anchor_probs, target=target_dist, reduction="sum") + loss = F.kl_div( + input=mean_anchor_probs.log(), target=target_dist, reduction="sum" + ) return loss diff --git a/lightly/models/_momentum.py b/lightly/models/_momentum.py index f7cdec507..94192cc24 100644 --- a/lightly/models/_momentum.py +++ b/lightly/models/_momentum.py @@ -4,18 +4,23 @@ # All Rights Reserved import copy +from typing import Iterable, Tuple import torch import torch.nn as nn +from torch import Tensor +from torch.nn.parameter import Parameter -def _deactivate_requires_grad(params): +def _deactivate_requires_grad(params: Iterable[Parameter]) -> None: """Deactivates the requires_grad flag for all parameters.""" for param in params: param.requires_grad = False -def _do_momentum_update(prev_params, params, m): +def _do_momentum_update( + prev_params: Iterable[Parameter], params: Iterable[Parameter], m: float +) -> None: """Updates the weights of the previous parameters.""" for prev_param, param in zip(prev_params, params): prev_param.data = prev_param.data * m + param.data * (1.0 - m) @@ -42,7 +47,7 @@ class _MomentumEncoderMixin: >>> # initialize momentum_backbone and momentum_projection_head >>> self._init_momentum_encoder() >>> - >>> def forward(self, x: torch.Tensor): + >>> def forward(self, x: Tensor): >>> >>> # do the momentum update >>> self._momentum_update(0.999) @@ -59,7 +64,7 @@ class _MomentumEncoderMixin: momentum_backbone: nn.Module momentum_projection_head: nn.Module - def _init_momentum_encoder(self): + def _init_momentum_encoder(self) -> None: """Initializes momentum backbone and a momentum projection head.""" assert self.backbone is not None assert self.projection_head is not 
None @@ -71,7 +76,7 @@ def _init_momentum_encoder(self): _deactivate_requires_grad(self.momentum_projection_head.parameters()) @torch.no_grad() - def _momentum_update(self, m: float = 0.999): + def _momentum_update(self, m: float = 0.999) -> None: """Performs the momentum update for the backbone and projection head.""" _do_momentum_update( self.momentum_backbone.parameters(), @@ -85,14 +90,14 @@ def _momentum_update(self, m: float = 0.999): ) @torch.no_grad() - def _batch_shuffle(self, batch: torch.Tensor): + def _batch_shuffle(self, batch: Tensor) -> Tuple[Tensor, Tensor]: """Returns the shuffled batch and the indices to undo.""" batch_size = batch.shape[0] shuffle = torch.randperm(batch_size, device=batch.device) return batch[shuffle], shuffle @torch.no_grad() - def _batch_unshuffle(self, batch: torch.Tensor, shuffle: torch.Tensor): + def _batch_unshuffle(self, batch: Tensor, shuffle: Tensor) -> Tensor: """Returns the unshuffled batch.""" unshuffle = torch.argsort(shuffle) return batch[unshuffle] diff --git a/lightly/models/batchnorm.py b/lightly/models/batchnorm.py index e4424bbad..7f05fe48e 100644 --- a/lightly/models/batchnorm.py +++ b/lightly/models/batchnorm.py @@ -3,8 +3,13 @@ # Copyright (c) 2020. Lightly AG and its affiliates. 
# All Rights Reserved +from __future__ import annotations + +from typing import Any + import torch import torch.nn as nn +from torch import Tensor class SplitBatchNorm(nn.BatchNorm2d): @@ -21,27 +26,30 @@ class SplitBatchNorm(nn.BatchNorm2d): """ - def __init__(self, num_features, num_splits, **kw): + def __init__(self, num_features: int, num_splits: int, **kw: Any) -> None: super().__init__(num_features, **kw) self.num_splits = num_splits + # Register buffers self.register_buffer( "running_mean", torch.zeros(num_features * self.num_splits) ) self.register_buffer("running_var", torch.ones(num_features * self.num_splits)) - def train(self, mode=True): + def train(self, mode: bool = True) -> SplitBatchNorm: # lazily collate stats when we are going to use them if (self.training is True) and (mode is False): + assert self.running_mean is not None self.running_mean = torch.mean( self.running_mean.view(self.num_splits, self.num_features), dim=0 ).repeat(self.num_splits) + assert self.running_var is not None self.running_var = torch.mean( self.running_var.view(self.num_splits, self.num_features), dim=0 ).repeat(self.num_splits) return super().train(mode) - def forward(self, input): + def forward(self, input: Tensor) -> Tensor: """Computes the SplitBatchNorm on the input.""" # get input shape N, C, H, W = input.shape @@ -60,10 +68,12 @@ def forward(self, input): self.eps, ).view(N, C, H, W) else: + # We have to ignore the type errors here, because we know that running_mean + # and running_var are not None, but the type checker does not. 
result = nn.functional.batch_norm( input, - self.running_mean[: self.num_features], - self.running_var[: self.num_features], + self.running_mean[: self.num_features], # type: ignore[index] + self.running_var[: self.num_features], # type: ignore[index] self.weight, self.bias, False, @@ -74,7 +84,7 @@ def forward(self, input): return result -def get_norm_layer(num_features: int, num_splits: int, **kw): +def get_norm_layer(num_features: int, num_splits: int, **kw: Any) -> nn.Module: """Utility to switch between BatchNorm2d and SplitBatchNorm.""" if num_splits > 0: return SplitBatchNorm(num_features, num_splits) diff --git a/lightly/models/modules/heads.py b/lightly/models/modules/heads.py index a6ab268d7..541b49a29 100644 --- a/lightly/models/modules/heads.py +++ b/lightly/models/modules/heads.py @@ -7,6 +7,7 @@ import torch import torch.nn as nn +from torch import Tensor from lightly.models import utils @@ -33,10 +34,10 @@ class ProjectionHead(nn.Module): def __init__( self, blocks: List[Tuple[int, int, Optional[nn.Module], Optional[nn.Module]]] - ): + ) -> None: super(ProjectionHead, self).__init__() - layers = [] + layers: List[nn.Module] = [] for input_dim, output_dim, batch_norm, non_linearity in blocks: use_bias = not bool(batch_norm) layers.append(nn.Linear(input_dim, output_dim, bias=use_bias)) @@ -46,7 +47,7 @@ def __init__( layers.append(non_linearity) self.layers = nn.Sequential(*layers) - def forward(self, x: torch.Tensor): + def forward(self, x: Tensor) -> Tensor: """Computes one forward pass through the projection head. Args: @@ -54,7 +55,8 @@ def forward(self, x: torch.Tensor): Input of shape bsz x num_ftrs. 
""" - return self.layers(x) + projection: Tensor = self.layers(x) + return projection class BarlowTwinsProjectionHead(ProjectionHead): @@ -324,7 +326,7 @@ class SMoGPrototypes(nn.Module): def __init__( self, - group_features: torch.Tensor, + group_features: Tensor, beta: float, ): super(SMoGPrototypes, self).__init__() @@ -332,8 +334,8 @@ def __init__( self.beta = beta def forward( - self, x: torch.Tensor, group_features: torch.Tensor, temperature: float = 0.1 - ) -> torch.Tensor: + self, x: Tensor, group_features: Tensor, temperature: float = 0.1 + ) -> Tensor: """Computes the logits for given model outputs and group features. Args: @@ -353,7 +355,7 @@ def forward( logits = torch.mm(x, group_features.t()) return logits / temperature - def get_updated_group_features(self, x: torch.Tensor) -> None: + def get_updated_group_features(self, x: Tensor) -> Tensor: """Performs the synchronous momentum update of the group vectors. Args: @@ -370,23 +372,23 @@ def get_updated_group_features(self, x: torch.Tensor) -> None: mask = assignments == assigned_class group_features[assigned_class] = self.beta * self.group_features[ assigned_class - ] + (1 - self.beta) * x[mask].mean(axis=0) + ] + (1 - self.beta) * x[mask].mean(dim=0) return group_features - def set_group_features(self, x: torch.Tensor) -> None: + def set_group_features(self, x: Tensor) -> None: """Sets the group features and asserts they don't require gradient.""" self.group_features.data = x.to(self.group_features.device) @torch.no_grad() - def assign_groups(self, x: torch.Tensor) -> torch.LongTensor: + def assign_groups(self, x: Tensor) -> Tensor: """Assigns each representation in x to a group based on cosine similarity. Args: Tensor of shape bsz x dim. Returns: - LongTensor of shape bsz indicating group assignments. + Tensor of shape bsz indicating group assignments. 
""" return torch.argmax(self.forward(x, self.group_features), dim=-1) @@ -524,19 +526,21 @@ def __init__( ) self.n_steps_frozen_prototypes = n_steps_frozen_prototypes - def forward(self, x, step=None) -> Union[torch.Tensor, List[torch.Tensor]]: + def forward( + self, x: Tensor, step: Optional[int] = None + ) -> Union[Tensor, List[Tensor]]: self._freeze_prototypes_if_required(step) out = [] for layer in self.heads: out.append(layer(x)) return out[0] if self._is_single_prototype else out - def normalize(self): + def normalize(self) -> None: """Normalizes the prototypes so that they are on the unit sphere.""" for layer in self.heads: utils.normalize_weight(layer.weight) - def _freeze_prototypes_if_required(self, step): + def _freeze_prototypes_if_required(self, step: Optional[int] = None) -> None: if self.n_steps_frozen_prototypes > 0: if step is None: raise ValueError( @@ -601,22 +605,23 @@ def __init__( ) self.apply(self._init_weights) self.freeze_last_layer = freeze_last_layer - self.last_layer = nn.utils.weight_norm( - nn.Linear(bottleneck_dim, output_dim, bias=False) - ) - self.last_layer.weight_g.data.fill_(1) + self.last_layer = nn.Linear(bottleneck_dim, output_dim, bias=False) + self.last_layer = nn.utils.weight_norm(self.last_layer) + # Tell mypy this is ok because fill_ is overloaded. + self.last_layer.weight_g.data.fill_(1) # type: ignore + # Option to normalize last layer. 
if norm_last_layer: self.last_layer.weight_g.requires_grad = False - def cancel_last_layer_gradients(self, current_epoch: int): + def cancel_last_layer_gradients(self, current_epoch: int) -> None: """Cancel last layer gradients to stabilize the training.""" if current_epoch >= self.freeze_last_layer: return for param in self.last_layer.parameters(): param.grad = None - def _init_weights(self, module): + def _init_weights(self, module: nn.Module) -> None: """Initializes layers with a truncated normal distribution.""" if isinstance(module, nn.Linear): utils._no_grad_trunc_normal( @@ -629,7 +634,7 @@ def _init_weights(self, module): if module.bias is not None: nn.init.constant_(module.bias, 0) - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: Tensor) -> Tensor: """Computes one forward pass through the head.""" x = self.layers(x) # l2 normalization @@ -694,6 +699,37 @@ def __init__( ) +class VICRegProjectionHead(ProjectionHead): + """Projection head used for VICReg. + + "The projector network has three linear layers, each with 8192 output + units. The first two layers of the projector are followed by a batch + normalization layer and rectified linear units." [0] + + [0]: 2022, VICReg, https://arxiv.org/pdf/2105.04906.pdf + + """ + + def __init__( + self, + input_dim: int = 2048, + hidden_dim: int = 8192, + output_dim: int = 8192, + num_layers: int = 3, + ): + hidden_layers = [ + (hidden_dim, hidden_dim, nn.BatchNorm1d(hidden_dim), nn.ReLU()) + for _ in range(num_layers - 2) # Exclude first and last layer. + ] + super(VICRegProjectionHead, self).__init__( + [ + (input_dim, hidden_dim, nn.BatchNorm1d(hidden_dim), nn.ReLU()), + *hidden_layers, + (hidden_dim, output_dim, None, None), + ] + ) + + class VicRegLLocalProjectionHead(ProjectionHead): """Projection head used for the local head of VICRegL. 
diff --git a/lightly/models/modules/ijepa.py b/lightly/models/modules/ijepa.py index 3eb14a247..47bfca7ad 100644 --- a/lightly/models/modules/ijepa.py +++ b/lightly/models/modules/ijepa.py @@ -72,7 +72,7 @@ def __init__( self.predictor_pos_embed = nn.Parameter( torch.zeros(1, num_patches, predictor_embed_dim), requires_grad=False ) - predictor_pos_embed = _get_2d_sincos_pos_embed( + predictor_pos_embed = utils.get_2d_sincos_pos_embed( self.predictor_pos_embed.shape[-1], int(num_patches**0.5), cls_token=False ) self.predictor_pos_embed.data.copy_( @@ -431,66 +431,3 @@ def images_to_tokens( if prepend_class_token: tokens = utils.prepend_class_token(tokens, self.class_token) return tokens - - -def _get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): - """ - grid_size: int of the grid height and width - return: - pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) - """ - grid_h = np.arange(grid_size, dtype=float) - grid_w = np.arange(grid_size, dtype=float) - grid = np.meshgrid(grid_w, grid_h) # here w goes first - grid = np.stack(grid, axis=0) - - grid = grid.reshape([2, 1, grid_size, grid_size]) - pos_embed = _get_2d_sincos_pos_embed_from_grid(embed_dim, grid) - if cls_token: - pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) - return pos_embed - - -def _get_2d_sincos_pos_embed_from_grid(embed_dim, grid): - assert embed_dim % 2 == 0 - - # use half of dimensions to encode grid_h - emb_h = _get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) - emb_w = _get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) - - emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) - return emb - - -def _get_1d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): - """ - grid_size: int of the grid length - return: - pos_embed: [grid_size, embed_dim] or [1+grid_size, embed_dim] (w/ or w/o cls_token) - """ - grid = np.arange(grid_size, dtype=float) - 
pos_embed = _get_1d_sincos_pos_embed_from_grid(embed_dim, grid) - if cls_token: - pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) - return pos_embed - - -def _get_1d_sincos_pos_embed_from_grid(embed_dim, pos): - """ - embed_dim: output dimension for each position - pos: a list of positions to be encoded: size (M,) - out: (M, D) - """ - assert embed_dim % 2 == 0 - omega = np.arange(embed_dim // 2, dtype=float) - omega /= embed_dim / 2.0 - omega = 1.0 / 10000**omega # (D/2,) - - pos = pos.reshape(-1) # (M,) - out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product - - emb_sin = np.sin(out) # (M, D/2) - emb_cos = np.cos(out) # (M, D/2) - - emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) - return emb diff --git a/lightly/models/modules/memory_bank.py b/lightly/models/modules/memory_bank.py index 255a38308..7eba9a4b5 100644 --- a/lightly/models/modules/memory_bank.py +++ b/lightly/models/modules/memory_bank.py @@ -3,7 +3,10 @@ # Copyright (c) 2020. Lightly AG and its affiliates. 
# All Rights Reserved +from typing import Optional, Tuple, Union + import torch +from torch import Tensor from lightly.models import utils @@ -32,8 +35,8 @@ class MemoryBankModule(torch.nn.Module): >>> def __init__(self, memory_bank_size: int = 2 ** 16): >>> super(MyLossFunction, self).__init__(memory_bank_size) >>> - >>> def forward(self, output: torch.Tensor, - >>> labels: torch.Tensor = None): + >>> def forward(self, output: Tensor, + >>> labels: Tensor = None): >>> >>> output, negatives = super( >>> MyLossFunction, self).forward(output) @@ -62,7 +65,7 @@ def __init__(self, size: int = 65536, gather_distributed: bool = False): ) @torch.no_grad() - def _init_memory_bank(self, dim: int): + def _init_memory_bank(self, dim: int) -> None: """Initialize the memory bank if it's empty Args: @@ -74,12 +77,12 @@ def _init_memory_bank(self, dim: int): # we could use register buffers like in the moco repo # https://github.com/facebookresearch/moco but we don't # want to pollute our checkpoints - self.bank = torch.randn(dim, self.size).type_as(self.bank) - self.bank = torch.nn.functional.normalize(self.bank, dim=0) - self.bank_ptr = torch.zeros(1).type_as(self.bank_ptr) + bank: Tensor = torch.randn(dim, self.size).type_as(self.bank) + self.bank: Tensor = torch.nn.functional.normalize(bank, dim=0) + self.bank_ptr: Tensor = torch.zeros(1).type_as(self.bank_ptr) @torch.no_grad() - def _dequeue_and_enqueue(self, batch: torch.Tensor): + def _dequeue_and_enqueue(self, batch: Tensor) -> None: """Dequeue the oldest batch and add the latest one Args: @@ -100,8 +103,11 @@ def _dequeue_and_enqueue(self, batch: torch.Tensor): self.bank_ptr[0] = ptr + batch_size def forward( - self, output: torch.Tensor, labels: torch.Tensor = None, update: bool = False - ): + self, + output: Tensor, + labels: Optional[Tensor] = None, + update: bool = False, + ) -> Union[Tuple[Tensor, Optional[Tensor]], Tensor]: """Query memory bank for additional negative samples Args: diff --git 
a/lightly/models/modules/nn_memory_bank.py b/lightly/models/modules/nn_memory_bank.py index 624c3cda4..777fcee0a 100644 --- a/lightly/models/modules/nn_memory_bank.py +++ b/lightly/models/modules/nn_memory_bank.py @@ -3,7 +3,10 @@ # Copyright (c) 2021. Lightly AG and its affiliates. # All Rights Reserved +from typing import Optional + import torch +from torch import Tensor from lightly.models.modules.memory_bank import MemoryBankModule @@ -19,8 +22,7 @@ class NNMemoryBankModule(MemoryBankModule): Attributes: size: - Number of keys the memory bank can store. If set to 0, - memory bank is not used. + Number of keys the memory bank can store. Examples: >>> model = NNCLR(backbone) @@ -38,9 +40,15 @@ class NNMemoryBankModule(MemoryBankModule): """ def __init__(self, size: int = 2**16): + if size <= 0: + raise ValueError(f"Memory bank size must be positive, got {size}.") super(NNMemoryBankModule, self).__init__(size) - def forward(self, output: torch.Tensor, update: bool = False): + def forward( # type: ignore[override] # TODO(Philipp, 11/23): Fix signature to match parent class. + self, + output: Tensor, + update: bool = False, + ) -> Tensor: """Returns nearest neighbour of output tensor from memory bank Args: @@ -50,6 +58,7 @@ def forward(self, output: torch.Tensor, update: bool = False): """ output, bank = super(NNMemoryBankModule, self).forward(output, update=update) + assert bank is not None bank = bank.to(output.device).t() output_normed = torch.nn.functional.normalize(output, dim=1) diff --git a/lightly/models/resnet.py b/lightly/models/resnet.py index 44c6c5829..6bff53626 100644 --- a/lightly/models/resnet.py +++ b/lightly/models/resnet.py @@ -10,12 +10,13 @@ # Copyright (c) 2020. Lightly AG and its affiliates. 
# All Rights Reserved +from __future__ import annotations from typing import List -import torch import torch.nn as nn import torch.nn.functional as F +from torch import Tensor from lightly.models.batchnorm import get_norm_layer @@ -62,7 +63,7 @@ def __init__( get_norm_layer(self.expansion * planes, num_splits), ) - def forward(self, x: torch.Tensor): + def forward(self, x: Tensor) -> Tensor: """Forward pass through basic ResNet block. Args: @@ -73,7 +74,7 @@ def forward(self, x: torch.Tensor): Tensor of shape bsz x channels x W x H """ - out = self.conv1(x) + out: Tensor = self.conv1(x) out = self.bn1(out) out = F.relu(out) @@ -132,7 +133,7 @@ def __init__( get_norm_layer(self.expansion * planes, num_splits), ) - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: """Forward pass through bottleneck ResNet block. Args: @@ -143,7 +144,7 @@ def forward(self, x): Tensor of shape bsz x channels x W x H """ - out = self.conv1(x) + out: Tensor = self.conv1(x) out = self.bn1(out) out = F.relu(out) @@ -179,7 +180,7 @@ class ResNet(nn.Module): def __init__( self, - block: nn.Module = BasicBlock, + block: type[BasicBlock] = BasicBlock, layers: List[int] = [2, 2, 2, 2], num_classes: int = 10, width: float = 1.0, @@ -208,15 +209,22 @@ def __init__( ) self.linear = nn.Linear(self.base * 8 * block.expansion, num_classes) - def _make_layer(self, block, planes, layers, stride, num_splits): - strides = [stride] + [1] * (layers - 1) + def _make_layer( + self, + block: type[BasicBlock], + planes: int, + num_layers: int, + stride: int, + num_splits: int, + ) -> nn.Sequential: + strides = [stride] + [1] * (num_layers - 1) layers = [] for stride in strides: layers.append(block(self.in_planes, planes, stride, num_splits)) self.in_planes = planes * block.expansion return nn.Sequential(*layers) - def forward(self, x: torch.Tensor): + def forward(self, x: Tensor) -> Tensor: """Forward pass through ResNet. 
Args: @@ -243,7 +251,7 @@ def ResNetGenerator( width: float = 1, num_classes: int = 10, num_splits: int = 0, -): +) -> ResNet: """Builds and returns the specified ResNet. Args: @@ -286,8 +294,8 @@ def ResNetGenerator( ) return ResNet( - **model_params[name], + **model_params[name], # type: ignore # Cannot unpack dict to type "ResNet". width=width, num_classes=num_classes, - num_splits=num_splits + num_splits=num_splits, ) diff --git a/lightly/models/utils.py b/lightly/models/utils.py index 7215395a3..bd7c4e26d 100644 --- a/lightly/models/utils.py +++ b/lightly/models/utils.py @@ -7,9 +7,11 @@ import warnings from typing import Iterable, List, Optional, Tuple, Union +import numpy as np import torch import torch.distributed as dist import torch.nn as nn +from numpy.typing import NDArray from torch.nn import Module from torch.nn.modules.batchnorm import _BatchNorm from torch.nn.parameter import Parameter @@ -596,3 +598,119 @@ def repeat_interleave_batch(x, B, repeat): dim=0, ) return x + + +def get_2d_sincos_pos_embed( + embed_dim: int, grid_size: int, cls_token: bool = False +) -> NDArray[np.float_]: + """Returns 2D sin-cos embeddings. Code from [0]. + + - [0]: https://github.com/facebookresearch/ijepa + + Args: + embed_dim: + Embedding dimension. + grid_size: + Grid height and width. Should usually be set to sqrt(sequence length). + cls_token: + If True, a positional embedding for the class token is prepended to the returned + embeddings. + + Returns: + Positional embeddings array with size (grid_size * grid_size, embed_dim) if cls_token is False. + If cls_token is True, a (1 + grid_size * grid_size, embed_dim) array is returned. 
+ """ + grid_h = np.arange(grid_size, dtype=float) + grid_w = np.arange(grid_size, dtype=float) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid( + embed_dim: int, grid: NDArray[np.int_] +) -> NDArray[np.float_]: + """Returns 2D sin-cos embeddings grid. Code from [0]. + + - [0]: https://github.com/facebookresearch/ijepa + + Args: + embed_dim: + Embedding dimension. + grid: + 2-dimensional grid to embed. + + Returns: + Positional embeddings array with size (grid_size * grid_size, embed_dim). + """ + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed( + embed_dim: int, grid_size: int, cls_token: bool = False +) -> NDArray[np.float_]: + """Returns 1D sin-cos embeddings. Code from [0]. + + - [0]: https://github.com/facebookresearch/ijepa + + Args: + embed_dim: + Embedding dimension. + grid_size: + Grid length. Should usually be set to the sequence length. + cls_token: + If True, a positional embedding for the class token is prepended to the returned + embeddings. + + Returns: + Positional embeddings array with size (grid_size, embed_dim) if cls_token is False. + If cls_token is True, a (1 + grid_size, embed_dim) array is returned.
+ """ + grid = np.arange(grid_size, dtype=float) + pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_1d_sincos_pos_embed_from_grid( + embed_dim: int, pos: NDArray[np.int_] +) -> NDArray[np.float_]: + """Returns 1D sin-cos embeddings grid. Code from [0]. + + - [0]: https://github.com/facebookresearch/ijepa + + Args: + embed_dim: + Embedding dimension. + pos: + Positions to embed; flattened to shape (M,) before encoding. + + Returns: + Positional embeddings array with size (M, embed_dim). + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=float) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb diff --git a/lightly/models/zoo.py b/lightly/models/zoo.py index 34d7dbede..34c6e1868 100644 --- a/lightly/models/zoo.py +++ b/lightly/models/zoo.py @@ -3,6 +3,8 @@ # Copyright (c) 2020. Lightly AG and its affiliates. # All Rights Reserved +from typing import List + ZOO = { "resnet-9/simclr/d16/w0.0625": "https://storage.googleapis.com/models_boris/whattolabel-resnet9-simclr-d16-w0.0625-i-ce0d6bd9.pth", "resnet-9/simclr/d16/w0.125": "https://storage.googleapis.com/models_boris/whattolabel-resnet9-simclr-d16-w0.125-i-7269c38d.pth", @@ -13,7 +15,7 @@ } -def checkpoints(): +def checkpoints() -> List[str]: """Returns the Lightly model zoo as a list of checkpoints.
Checkpoints: diff --git a/lightly/openapi_generated/swagger_client/__init__.py b/lightly/openapi_generated/swagger_client/__init__.py index a03a7a3b8..8d3328d74 100644 --- a/lightly/openapi_generated/swagger_client/__init__.py +++ b/lightly/openapi_generated/swagger_client/__init__.py @@ -50,6 +50,8 @@ # import models into sdk package from lightly.openapi_generated.swagger_client.models.active_learning_score_create_request import ActiveLearningScoreCreateRequest from lightly.openapi_generated.swagger_client.models.active_learning_score_data import ActiveLearningScoreData +from lightly.openapi_generated.swagger_client.models.active_learning_score_types_v2_data import ActiveLearningScoreTypesV2Data +from lightly.openapi_generated.swagger_client.models.active_learning_score_v2_data import ActiveLearningScoreV2Data from lightly.openapi_generated.swagger_client.models.api_error_code import ApiErrorCode from lightly.openapi_generated.swagger_client.models.api_error_response import ApiErrorResponse from lightly.openapi_generated.swagger_client.models.async_task_data import AsyncTaskData @@ -75,10 +77,12 @@ from lightly.openapi_generated.swagger_client.models.datasource_config_azure import DatasourceConfigAzure from lightly.openapi_generated.swagger_client.models.datasource_config_azure_all_of import DatasourceConfigAzureAllOf from lightly.openapi_generated.swagger_client.models.datasource_config_base import DatasourceConfigBase +from lightly.openapi_generated.swagger_client.models.datasource_config_base_full_path import DatasourceConfigBaseFullPath from lightly.openapi_generated.swagger_client.models.datasource_config_gcs import DatasourceConfigGCS from lightly.openapi_generated.swagger_client.models.datasource_config_gcs_all_of import DatasourceConfigGCSAllOf from lightly.openapi_generated.swagger_client.models.datasource_config_lightly import DatasourceConfigLIGHTLY from lightly.openapi_generated.swagger_client.models.datasource_config_local import DatasourceConfigLOCAL 
+from lightly.openapi_generated.swagger_client.models.datasource_config_local_all_of import DatasourceConfigLOCALAllOf from lightly.openapi_generated.swagger_client.models.datasource_config_obs import DatasourceConfigOBS from lightly.openapi_generated.swagger_client.models.datasource_config_obs_all_of import DatasourceConfigOBSAllOf from lightly.openapi_generated.swagger_client.models.datasource_config_s3 import DatasourceConfigS3 @@ -213,13 +217,23 @@ from lightly.openapi_generated.swagger_client.models.sampling_method import SamplingMethod from lightly.openapi_generated.swagger_client.models.sector import Sector from lightly.openapi_generated.swagger_client.models.selection_config import SelectionConfig +from lightly.openapi_generated.swagger_client.models.selection_config_all_of import SelectionConfigAllOf +from lightly.openapi_generated.swagger_client.models.selection_config_base import SelectionConfigBase from lightly.openapi_generated.swagger_client.models.selection_config_entry import SelectionConfigEntry from lightly.openapi_generated.swagger_client.models.selection_config_entry_input import SelectionConfigEntryInput from lightly.openapi_generated.swagger_client.models.selection_config_entry_strategy import SelectionConfigEntryStrategy +from lightly.openapi_generated.swagger_client.models.selection_config_v3 import SelectionConfigV3 +from lightly.openapi_generated.swagger_client.models.selection_config_v3_all_of import SelectionConfigV3AllOf +from lightly.openapi_generated.swagger_client.models.selection_config_v3_entry import SelectionConfigV3Entry +from lightly.openapi_generated.swagger_client.models.selection_config_v3_entry_input import SelectionConfigV3EntryInput +from lightly.openapi_generated.swagger_client.models.selection_config_v3_entry_strategy import SelectionConfigV3EntryStrategy +from lightly.openapi_generated.swagger_client.models.selection_config_v3_entry_strategy_all_of import SelectionConfigV3EntryStrategyAllOf +from 
lightly.openapi_generated.swagger_client.models.selection_config_v3_entry_strategy_all_of_target_range import SelectionConfigV3EntryStrategyAllOfTargetRange from lightly.openapi_generated.swagger_client.models.selection_input_predictions_name import SelectionInputPredictionsName from lightly.openapi_generated.swagger_client.models.selection_input_type import SelectionInputType from lightly.openapi_generated.swagger_client.models.selection_strategy_threshold_operation import SelectionStrategyThresholdOperation from lightly.openapi_generated.swagger_client.models.selection_strategy_type import SelectionStrategyType +from lightly.openapi_generated.swagger_client.models.selection_strategy_type_v3 import SelectionStrategyTypeV3 from lightly.openapi_generated.swagger_client.models.service_account_basic_data import ServiceAccountBasicData from lightly.openapi_generated.swagger_client.models.set_embeddings_is_processed_flag_by_id_body_request import SetEmbeddingsIsProcessedFlagByIdBodyRequest from lightly.openapi_generated.swagger_client.models.shared_access_config_create_request import SharedAccessConfigCreateRequest diff --git a/lightly/openapi_generated/swagger_client/__init__.py-e b/lightly/openapi_generated/swagger_client/__init__.py-e new file mode 100644 index 000000000..c4f85819a --- /dev/null +++ b/lightly/openapi_generated/swagger_client/__init__.py-e @@ -0,0 +1,275 @@ +# coding: utf-8 + +# flake8: noqa + +""" + Lightly API + + Lightly.ai enables you to do self-supervised learning in an easy and intuitive way. The lightly.ai OpenAPI spec defines how one can interact with our REST API to unleash the full potential of lightly.ai # noqa: E501 + + The version of the OpenAPI document: 1.0.0 + Contact: support@lightly.ai + Generated by OpenAPI Generator (https://openapi-generator.tech) + + Do not edit the class manually. 
+""" + + +__version__ = "1.0.0" + +# import apis into sdk package +from lightly.openapi_generated.swagger_client.api.collaboration_api import CollaborationApi +from lightly.openapi_generated.swagger_client.api.datasets_api import DatasetsApi +from lightly.openapi_generated.swagger_client.api.datasources_api import DatasourcesApi +from lightly.openapi_generated.swagger_client.api.docker_api import DockerApi +from lightly.openapi_generated.swagger_client.api.embeddings_api import EmbeddingsApi +from lightly.openapi_generated.swagger_client.api.embeddings2d_api import Embeddings2dApi +from lightly.openapi_generated.swagger_client.api.jobs_api import JobsApi +from lightly.openapi_generated.swagger_client.api.mappings_api import MappingsApi +from lightly.openapi_generated.swagger_client.api.meta_data_configurations_api import MetaDataConfigurationsApi +from lightly.openapi_generated.swagger_client.api.predictions_api import PredictionsApi +from lightly.openapi_generated.swagger_client.api.profiles_api import ProfilesApi +from lightly.openapi_generated.swagger_client.api.quota_api import QuotaApi +from lightly.openapi_generated.swagger_client.api.samples_api import SamplesApi +from lightly.openapi_generated.swagger_client.api.samplings_api import SamplingsApi +from lightly.openapi_generated.swagger_client.api.scores_api import ScoresApi +from lightly.openapi_generated.swagger_client.api.tags_api import TagsApi +from lightly.openapi_generated.swagger_client.api.teams_api import TeamsApi +from lightly.openapi_generated.swagger_client.api.versioning_api import VersioningApi + +# import ApiClient +from lightly.openapi_generated.swagger_client.api_response import ApiResponse +from lightly.openapi_generated.swagger_client.api_client import ApiClient +from lightly.openapi_generated.swagger_client.configuration import Configuration +from lightly.openapi_generated.swagger_client.exceptions import OpenApiException +from lightly.openapi_generated.swagger_client.exceptions import 
ApiTypeError +from lightly.openapi_generated.swagger_client.exceptions import ApiValueError +from lightly.openapi_generated.swagger_client.exceptions import ApiKeyError +from lightly.openapi_generated.swagger_client.exceptions import ApiAttributeError +from lightly.openapi_generated.swagger_client.exceptions import ApiException + +# import models into sdk package +from lightly.openapi_generated.swagger_client.models.active_learning_score_create_request import ActiveLearningScoreCreateRequest +from lightly.openapi_generated.swagger_client.models.active_learning_score_data import ActiveLearningScoreData +from lightly.openapi_generated.swagger_client.models.active_learning_score_types_v2_data import ActiveLearningScoreTypesV2Data +from lightly.openapi_generated.swagger_client.models.active_learning_score_v2_data import ActiveLearningScoreV2Data +from lightly.openapi_generated.swagger_client.models.api_error_code import ApiErrorCode +from lightly.openapi_generated.swagger_client.models.api_error_response import ApiErrorResponse +from lightly.openapi_generated.swagger_client.models.async_task_data import AsyncTaskData +from lightly.openapi_generated.swagger_client.models.configuration_data import ConfigurationData +from lightly.openapi_generated.swagger_client.models.configuration_entry import ConfigurationEntry +from lightly.openapi_generated.swagger_client.models.configuration_set_request import ConfigurationSetRequest +from lightly.openapi_generated.swagger_client.models.configuration_value_data_type import ConfigurationValueDataType +from lightly.openapi_generated.swagger_client.models.create_cf_bucket_activity_request import CreateCFBucketActivityRequest +from lightly.openapi_generated.swagger_client.models.create_docker_worker_registry_entry_request import CreateDockerWorkerRegistryEntryRequest +from lightly.openapi_generated.swagger_client.models.create_entity_response import CreateEntityResponse +from 
lightly.openapi_generated.swagger_client.models.create_sample_with_write_urls_response import CreateSampleWithWriteUrlsResponse +from lightly.openapi_generated.swagger_client.models.create_team_membership_request import CreateTeamMembershipRequest +from lightly.openapi_generated.swagger_client.models.creator import Creator +from lightly.openapi_generated.swagger_client.models.crop_data import CropData +from lightly.openapi_generated.swagger_client.models.dataset_create_request import DatasetCreateRequest +from lightly.openapi_generated.swagger_client.models.dataset_creator import DatasetCreator +from lightly.openapi_generated.swagger_client.models.dataset_data import DatasetData +from lightly.openapi_generated.swagger_client.models.dataset_data_enriched import DatasetDataEnriched +from lightly.openapi_generated.swagger_client.models.dataset_embedding_data import DatasetEmbeddingData +from lightly.openapi_generated.swagger_client.models.dataset_type import DatasetType +from lightly.openapi_generated.swagger_client.models.dataset_update_request import DatasetUpdateRequest +from lightly.openapi_generated.swagger_client.models.datasource_config import DatasourceConfig +from lightly.openapi_generated.swagger_client.models.datasource_config_azure import DatasourceConfigAzure +from lightly.openapi_generated.swagger_client.models.datasource_config_azure_all_of import DatasourceConfigAzureAllOf +from lightly.openapi_generated.swagger_client.models.datasource_config_base import DatasourceConfigBase +from lightly.openapi_generated.swagger_client.models.datasource_config_base_full_path import DatasourceConfigBaseFullPath +from lightly.openapi_generated.swagger_client.models.datasource_config_gcs import DatasourceConfigGCS +from lightly.openapi_generated.swagger_client.models.datasource_config_gcs_all_of import DatasourceConfigGCSAllOf +from lightly.openapi_generated.swagger_client.models.datasource_config_lightly import DatasourceConfigLIGHTLY +from 
lightly.openapi_generated.swagger_client.models.datasource_config_local import DatasourceConfigLOCAL +from lightly.openapi_generated.swagger_client.models.datasource_config_local_all_of import DatasourceConfigLOCALAllOf +from lightly.openapi_generated.swagger_client.models.datasource_config_obs import DatasourceConfigOBS +from lightly.openapi_generated.swagger_client.models.datasource_config_obs_all_of import DatasourceConfigOBSAllOf +from lightly.openapi_generated.swagger_client.models.datasource_config_s3 import DatasourceConfigS3 +from lightly.openapi_generated.swagger_client.models.datasource_config_s3_all_of import DatasourceConfigS3AllOf +from lightly.openapi_generated.swagger_client.models.datasource_config_s3_delegated_access import DatasourceConfigS3DelegatedAccess +from lightly.openapi_generated.swagger_client.models.datasource_config_s3_delegated_access_all_of import DatasourceConfigS3DelegatedAccessAllOf +from lightly.openapi_generated.swagger_client.models.datasource_config_verify_data import DatasourceConfigVerifyData +from lightly.openapi_generated.swagger_client.models.datasource_config_verify_data_errors import DatasourceConfigVerifyDataErrors +from lightly.openapi_generated.swagger_client.models.datasource_processed_until_timestamp_request import DatasourceProcessedUntilTimestampRequest +from lightly.openapi_generated.swagger_client.models.datasource_processed_until_timestamp_response import DatasourceProcessedUntilTimestampResponse +from lightly.openapi_generated.swagger_client.models.datasource_purpose import DatasourcePurpose +from lightly.openapi_generated.swagger_client.models.datasource_raw_samples_data import DatasourceRawSamplesData +from lightly.openapi_generated.swagger_client.models.datasource_raw_samples_data_row import DatasourceRawSamplesDataRow +from lightly.openapi_generated.swagger_client.models.datasource_raw_samples_metadata_data import DatasourceRawSamplesMetadataData +from 
lightly.openapi_generated.swagger_client.models.datasource_raw_samples_metadata_data_row import DatasourceRawSamplesMetadataDataRow +from lightly.openapi_generated.swagger_client.models.datasource_raw_samples_predictions_data import DatasourceRawSamplesPredictionsData +from lightly.openapi_generated.swagger_client.models.datasource_raw_samples_predictions_data_row import DatasourceRawSamplesPredictionsDataRow +from lightly.openapi_generated.swagger_client.models.dimensionality_reduction_method import DimensionalityReductionMethod +from lightly.openapi_generated.swagger_client.models.docker_license_information import DockerLicenseInformation +from lightly.openapi_generated.swagger_client.models.docker_run_artifact_create_request import DockerRunArtifactCreateRequest +from lightly.openapi_generated.swagger_client.models.docker_run_artifact_created_data import DockerRunArtifactCreatedData +from lightly.openapi_generated.swagger_client.models.docker_run_artifact_data import DockerRunArtifactData +from lightly.openapi_generated.swagger_client.models.docker_run_artifact_storage_location import DockerRunArtifactStorageLocation +from lightly.openapi_generated.swagger_client.models.docker_run_artifact_type import DockerRunArtifactType +from lightly.openapi_generated.swagger_client.models.docker_run_create_request import DockerRunCreateRequest +from lightly.openapi_generated.swagger_client.models.docker_run_data import DockerRunData +from lightly.openapi_generated.swagger_client.models.docker_run_log_data import DockerRunLogData +from lightly.openapi_generated.swagger_client.models.docker_run_log_entry_data import DockerRunLogEntryData +from lightly.openapi_generated.swagger_client.models.docker_run_log_level import DockerRunLogLevel +from lightly.openapi_generated.swagger_client.models.docker_run_scheduled_create_request import DockerRunScheduledCreateRequest +from lightly.openapi_generated.swagger_client.models.docker_run_scheduled_data import DockerRunScheduledData +from 
lightly.openapi_generated.swagger_client.models.docker_run_scheduled_priority import DockerRunScheduledPriority +from lightly.openapi_generated.swagger_client.models.docker_run_scheduled_state import DockerRunScheduledState +from lightly.openapi_generated.swagger_client.models.docker_run_scheduled_update_request import DockerRunScheduledUpdateRequest +from lightly.openapi_generated.swagger_client.models.docker_run_state import DockerRunState +from lightly.openapi_generated.swagger_client.models.docker_run_update_request import DockerRunUpdateRequest +from lightly.openapi_generated.swagger_client.models.docker_task_description import DockerTaskDescription +from lightly.openapi_generated.swagger_client.models.docker_user_stats import DockerUserStats +from lightly.openapi_generated.swagger_client.models.docker_worker_config import DockerWorkerConfig +from lightly.openapi_generated.swagger_client.models.docker_worker_config_create_request import DockerWorkerConfigCreateRequest +from lightly.openapi_generated.swagger_client.models.docker_worker_config_data import DockerWorkerConfigData +from lightly.openapi_generated.swagger_client.models.docker_worker_config_v2 import DockerWorkerConfigV2 +from lightly.openapi_generated.swagger_client.models.docker_worker_config_v2_create_request import DockerWorkerConfigV2CreateRequest +from lightly.openapi_generated.swagger_client.models.docker_worker_config_v2_data import DockerWorkerConfigV2Data +from lightly.openapi_generated.swagger_client.models.docker_worker_config_v2_docker import DockerWorkerConfigV2Docker +from lightly.openapi_generated.swagger_client.models.docker_worker_config_v2_docker_object_level import DockerWorkerConfigV2DockerObjectLevel +from lightly.openapi_generated.swagger_client.models.docker_worker_config_v2_docker_stopping_condition import DockerWorkerConfigV2DockerStoppingCondition +from lightly.openapi_generated.swagger_client.models.docker_worker_config_v2_lightly import DockerWorkerConfigV2Lightly +from 
lightly.openapi_generated.swagger_client.models.docker_worker_config_v2_lightly_collate import DockerWorkerConfigV2LightlyCollate +from lightly.openapi_generated.swagger_client.models.docker_worker_config_v2_lightly_model import DockerWorkerConfigV2LightlyModel +from lightly.openapi_generated.swagger_client.models.docker_worker_config_v2_lightly_trainer import DockerWorkerConfigV2LightlyTrainer +from lightly.openapi_generated.swagger_client.models.docker_worker_config_v3 import DockerWorkerConfigV3 +from lightly.openapi_generated.swagger_client.models.docker_worker_config_v3_create_request import DockerWorkerConfigV3CreateRequest +from lightly.openapi_generated.swagger_client.models.docker_worker_config_v3_data import DockerWorkerConfigV3Data +from lightly.openapi_generated.swagger_client.models.docker_worker_config_v3_docker import DockerWorkerConfigV3Docker +from lightly.openapi_generated.swagger_client.models.docker_worker_config_v3_docker_corruptness_check import DockerWorkerConfigV3DockerCorruptnessCheck +from lightly.openapi_generated.swagger_client.models.docker_worker_config_v3_docker_datasource import DockerWorkerConfigV3DockerDatasource +from lightly.openapi_generated.swagger_client.models.docker_worker_config_v3_docker_training import DockerWorkerConfigV3DockerTraining +from lightly.openapi_generated.swagger_client.models.docker_worker_config_v3_lightly import DockerWorkerConfigV3Lightly +from lightly.openapi_generated.swagger_client.models.docker_worker_config_v3_lightly_checkpoint_callback import DockerWorkerConfigV3LightlyCheckpointCallback +from lightly.openapi_generated.swagger_client.models.docker_worker_config_v3_lightly_collate import DockerWorkerConfigV3LightlyCollate +from lightly.openapi_generated.swagger_client.models.docker_worker_config_v3_lightly_criterion import DockerWorkerConfigV3LightlyCriterion +from lightly.openapi_generated.swagger_client.models.docker_worker_config_v3_lightly_loader import DockerWorkerConfigV3LightlyLoader +from 
lightly.openapi_generated.swagger_client.models.docker_worker_config_v3_lightly_model import DockerWorkerConfigV3LightlyModel +from lightly.openapi_generated.swagger_client.models.docker_worker_config_v3_lightly_optimizer import DockerWorkerConfigV3LightlyOptimizer +from lightly.openapi_generated.swagger_client.models.docker_worker_config_v3_lightly_trainer import DockerWorkerConfigV3LightlyTrainer +from lightly.openapi_generated.swagger_client.models.docker_worker_registry_entry_data import DockerWorkerRegistryEntryData +from lightly.openapi_generated.swagger_client.models.docker_worker_state import DockerWorkerState +from lightly.openapi_generated.swagger_client.models.docker_worker_type import DockerWorkerType +from lightly.openapi_generated.swagger_client.models.embedding2d_create_request import Embedding2dCreateRequest +from lightly.openapi_generated.swagger_client.models.embedding2d_data import Embedding2dData +from lightly.openapi_generated.swagger_client.models.embedding_data import EmbeddingData +from lightly.openapi_generated.swagger_client.models.file_name_format import FileNameFormat +from lightly.openapi_generated.swagger_client.models.file_output_format import FileOutputFormat +from lightly.openapi_generated.swagger_client.models.filename_and_read_url import FilenameAndReadUrl +from lightly.openapi_generated.swagger_client.models.image_type import ImageType +from lightly.openapi_generated.swagger_client.models.initial_tag_create_request import InitialTagCreateRequest +from lightly.openapi_generated.swagger_client.models.job_result_type import JobResultType +from lightly.openapi_generated.swagger_client.models.job_state import JobState +from lightly.openapi_generated.swagger_client.models.job_status_data import JobStatusData +from lightly.openapi_generated.swagger_client.models.job_status_data_result import JobStatusDataResult +from lightly.openapi_generated.swagger_client.models.job_status_meta import JobStatusMeta +from 
lightly.openapi_generated.swagger_client.models.job_status_upload_method import JobStatusUploadMethod +from lightly.openapi_generated.swagger_client.models.jobs_data import JobsData +from lightly.openapi_generated.swagger_client.models.label_box_data_row import LabelBoxDataRow +from lightly.openapi_generated.swagger_client.models.label_box_v4_data_row import LabelBoxV4DataRow +from lightly.openapi_generated.swagger_client.models.label_studio_task import LabelStudioTask +from lightly.openapi_generated.swagger_client.models.label_studio_task_data import LabelStudioTaskData +from lightly.openapi_generated.swagger_client.models.lightly_docker_selection_method import LightlyDockerSelectionMethod +from lightly.openapi_generated.swagger_client.models.lightly_model_v2 import LightlyModelV2 +from lightly.openapi_generated.swagger_client.models.lightly_model_v3 import LightlyModelV3 +from lightly.openapi_generated.swagger_client.models.lightly_trainer_precision_v2 import LightlyTrainerPrecisionV2 +from lightly.openapi_generated.swagger_client.models.lightly_trainer_precision_v3 import LightlyTrainerPrecisionV3 +from lightly.openapi_generated.swagger_client.models.prediction_singleton import PredictionSingleton +from lightly.openapi_generated.swagger_client.models.prediction_singleton_base import PredictionSingletonBase +from lightly.openapi_generated.swagger_client.models.prediction_singleton_classification import PredictionSingletonClassification +from lightly.openapi_generated.swagger_client.models.prediction_singleton_classification_all_of import PredictionSingletonClassificationAllOf +from lightly.openapi_generated.swagger_client.models.prediction_singleton_instance_segmentation import PredictionSingletonInstanceSegmentation +from lightly.openapi_generated.swagger_client.models.prediction_singleton_instance_segmentation_all_of import PredictionSingletonInstanceSegmentationAllOf +from lightly.openapi_generated.swagger_client.models.prediction_singleton_keypoint_detection 
import PredictionSingletonKeypointDetection +from lightly.openapi_generated.swagger_client.models.prediction_singleton_keypoint_detection_all_of import PredictionSingletonKeypointDetectionAllOf +from lightly.openapi_generated.swagger_client.models.prediction_singleton_object_detection import PredictionSingletonObjectDetection +from lightly.openapi_generated.swagger_client.models.prediction_singleton_object_detection_all_of import PredictionSingletonObjectDetectionAllOf +from lightly.openapi_generated.swagger_client.models.prediction_singleton_semantic_segmentation import PredictionSingletonSemanticSegmentation +from lightly.openapi_generated.swagger_client.models.prediction_singleton_semantic_segmentation_all_of import PredictionSingletonSemanticSegmentationAllOf +from lightly.openapi_generated.swagger_client.models.prediction_task_schema import PredictionTaskSchema +from lightly.openapi_generated.swagger_client.models.prediction_task_schema_base import PredictionTaskSchemaBase +from lightly.openapi_generated.swagger_client.models.prediction_task_schema_category import PredictionTaskSchemaCategory +from lightly.openapi_generated.swagger_client.models.prediction_task_schema_category_keypoints import PredictionTaskSchemaCategoryKeypoints +from lightly.openapi_generated.swagger_client.models.prediction_task_schema_category_keypoints_all_of import PredictionTaskSchemaCategoryKeypointsAllOf +from lightly.openapi_generated.swagger_client.models.prediction_task_schema_keypoint import PredictionTaskSchemaKeypoint +from lightly.openapi_generated.swagger_client.models.prediction_task_schema_keypoint_all_of import PredictionTaskSchemaKeypointAllOf +from lightly.openapi_generated.swagger_client.models.prediction_task_schema_simple import PredictionTaskSchemaSimple +from lightly.openapi_generated.swagger_client.models.prediction_task_schema_simple_all_of import PredictionTaskSchemaSimpleAllOf +from lightly.openapi_generated.swagger_client.models.prediction_task_schemas import 
PredictionTaskSchemas +from lightly.openapi_generated.swagger_client.models.profile_basic_data import ProfileBasicData +from lightly.openapi_generated.swagger_client.models.profile_me_data import ProfileMeData +from lightly.openapi_generated.swagger_client.models.profile_me_data_settings import ProfileMeDataSettings +from lightly.openapi_generated.swagger_client.models.questionnaire_data import QuestionnaireData +from lightly.openapi_generated.swagger_client.models.s3_region import S3Region +from lightly.openapi_generated.swagger_client.models.sama_task import SamaTask +from lightly.openapi_generated.swagger_client.models.sama_task_data import SamaTaskData +from lightly.openapi_generated.swagger_client.models.sample_create_request import SampleCreateRequest +from lightly.openapi_generated.swagger_client.models.sample_data import SampleData +from lightly.openapi_generated.swagger_client.models.sample_data_modes import SampleDataModes +from lightly.openapi_generated.swagger_client.models.sample_meta_data import SampleMetaData +from lightly.openapi_generated.swagger_client.models.sample_partial_mode import SamplePartialMode +from lightly.openapi_generated.swagger_client.models.sample_sort_by import SampleSortBy +from lightly.openapi_generated.swagger_client.models.sample_type import SampleType +from lightly.openapi_generated.swagger_client.models.sample_update_request import SampleUpdateRequest +from lightly.openapi_generated.swagger_client.models.sample_write_urls import SampleWriteUrls +from lightly.openapi_generated.swagger_client.models.sampling_config import SamplingConfig +from lightly.openapi_generated.swagger_client.models.sampling_config_stopping_condition import SamplingConfigStoppingCondition +from lightly.openapi_generated.swagger_client.models.sampling_create_request import SamplingCreateRequest +from lightly.openapi_generated.swagger_client.models.sampling_method import SamplingMethod +from lightly.openapi_generated.swagger_client.models.sector import 
Sector +from lightly.openapi_generated.swagger_client.models.selection_config import SelectionConfig +from lightly.openapi_generated.swagger_client.models.selection_config_all_of import SelectionConfigAllOf +from lightly.openapi_generated.swagger_client.models.selection_config_base import SelectionConfigBase +from lightly.openapi_generated.swagger_client.models.selection_config_entry import SelectionConfigEntry +from lightly.openapi_generated.swagger_client.models.selection_config_entry_input import SelectionConfigEntryInput +from lightly.openapi_generated.swagger_client.models.selection_config_entry_strategy import SelectionConfigEntryStrategy +from lightly.openapi_generated.swagger_client.models.selection_config_v3 import SelectionConfigV3 +from lightly.openapi_generated.swagger_client.models.selection_config_v3_all_of import SelectionConfigV3AllOf +from lightly.openapi_generated.swagger_client.models.selection_config_v3_entry import SelectionConfigV3Entry +from lightly.openapi_generated.swagger_client.models.selection_config_v3_entry_input import SelectionConfigV3EntryInput +from lightly.openapi_generated.swagger_client.models.selection_config_v3_entry_strategy import SelectionConfigV3EntryStrategy +from lightly.openapi_generated.swagger_client.models.selection_config_v3_entry_strategy_all_of import SelectionConfigV3EntryStrategyAllOf +from lightly.openapi_generated.swagger_client.models.selection_config_v3_entry_strategy_all_of_target_range import SelectionConfigV3EntryStrategyAllOfTargetRange +from lightly.openapi_generated.swagger_client.models.selection_input_predictions_name import SelectionInputPredictionsName +from lightly.openapi_generated.swagger_client.models.selection_input_type import SelectionInputType +from lightly.openapi_generated.swagger_client.models.selection_strategy_threshold_operation import SelectionStrategyThresholdOperation +from lightly.openapi_generated.swagger_client.models.selection_strategy_type import SelectionStrategyType +from 
lightly.openapi_generated.swagger_client.models.service_account_basic_data import ServiceAccountBasicData +from lightly.openapi_generated.swagger_client.models.set_embeddings_is_processed_flag_by_id_body_request import SetEmbeddingsIsProcessedFlagByIdBodyRequest +from lightly.openapi_generated.swagger_client.models.shared_access_config_create_request import SharedAccessConfigCreateRequest +from lightly.openapi_generated.swagger_client.models.shared_access_config_data import SharedAccessConfigData +from lightly.openapi_generated.swagger_client.models.shared_access_type import SharedAccessType +from lightly.openapi_generated.swagger_client.models.tag_active_learning_scores_data import TagActiveLearningScoresData +from lightly.openapi_generated.swagger_client.models.tag_arithmetics_operation import TagArithmeticsOperation +from lightly.openapi_generated.swagger_client.models.tag_arithmetics_request import TagArithmeticsRequest +from lightly.openapi_generated.swagger_client.models.tag_arithmetics_response import TagArithmeticsResponse +from lightly.openapi_generated.swagger_client.models.tag_bit_mask_response import TagBitMaskResponse +from lightly.openapi_generated.swagger_client.models.tag_change_data import TagChangeData +from lightly.openapi_generated.swagger_client.models.tag_change_data_arithmetics import TagChangeDataArithmetics +from lightly.openapi_generated.swagger_client.models.tag_change_data_initial import TagChangeDataInitial +from lightly.openapi_generated.swagger_client.models.tag_change_data_metadata import TagChangeDataMetadata +from lightly.openapi_generated.swagger_client.models.tag_change_data_operation_method import TagChangeDataOperationMethod +from lightly.openapi_generated.swagger_client.models.tag_change_data_rename import TagChangeDataRename +from lightly.openapi_generated.swagger_client.models.tag_change_data_sampler import TagChangeDataSampler +from lightly.openapi_generated.swagger_client.models.tag_change_data_samples import 
TagChangeDataSamples +from lightly.openapi_generated.swagger_client.models.tag_change_data_scatterplot import TagChangeDataScatterplot +from lightly.openapi_generated.swagger_client.models.tag_change_data_upsize import TagChangeDataUpsize +from lightly.openapi_generated.swagger_client.models.tag_change_entry import TagChangeEntry +from lightly.openapi_generated.swagger_client.models.tag_create_request import TagCreateRequest +from lightly.openapi_generated.swagger_client.models.tag_creator import TagCreator +from lightly.openapi_generated.swagger_client.models.tag_data import TagData +from lightly.openapi_generated.swagger_client.models.tag_update_request import TagUpdateRequest +from lightly.openapi_generated.swagger_client.models.tag_upsize_request import TagUpsizeRequest +from lightly.openapi_generated.swagger_client.models.task_type import TaskType +from lightly.openapi_generated.swagger_client.models.team_basic_data import TeamBasicData +from lightly.openapi_generated.swagger_client.models.team_data import TeamData +from lightly.openapi_generated.swagger_client.models.team_role import TeamRole +from lightly.openapi_generated.swagger_client.models.trigger2d_embedding_job_request import Trigger2dEmbeddingJobRequest +from lightly.openapi_generated.swagger_client.models.update_docker_worker_registry_entry_request import UpdateDockerWorkerRegistryEntryRequest +from lightly.openapi_generated.swagger_client.models.update_team_membership_request import UpdateTeamMembershipRequest +from lightly.openapi_generated.swagger_client.models.user_type import UserType +from lightly.openapi_generated.swagger_client.models.video_frame_data import VideoFrameData +from lightly.openapi_generated.swagger_client.models.write_csv_url_data import WriteCSVUrlData diff --git a/lightly/openapi_generated/swagger_client/api/datasources_api.py b/lightly/openapi_generated/swagger_client/api/datasources_api.py index d409937b4..63d82eac1 100644 --- 
a/lightly/openapi_generated/swagger_client/api/datasources_api.py +++ b/lightly/openapi_generated/swagger_client/api/datasources_api.py @@ -54,10 +54,10 @@ def __init__(self, api_client=None): self.api_client = api_client @validate_arguments - def get_custom_embedding_file_read_url_from_datasource_by_dataset_id(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], file_name : Annotated[constr(strict=True, min_length=4), Field(..., description="The name of the csv file within the embeddings folder to get the readUrl for")], **kwargs) -> str: # noqa: E501 + def get_custom_embedding_file_read_url_from_datasource_by_dataset_id(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], file_name : Annotated[constr(strict=True, min_length=4), Field(..., description="The name of the csv file within the embeddings folder to get the GET readUrl for")], **kwargs) -> str: # noqa: E501 """get_custom_embedding_file_read_url_from_datasource_by_dataset_id # noqa: E501 - Get the ReadURL of a custom embedding csv file within the embeddings folder (e.g myCustomEmbedding.csv) # noqa: E501 + Get the GET ReadURL of a custom embedding csv file within the embeddings folder (e.g myCustomEmbedding.csv) # noqa: E501 This method makes a synchronous HTTP request by default. To make an asynchronous HTTP request, please pass async_req=True @@ -66,7 +66,7 @@ def get_custom_embedding_file_read_url_from_datasource_by_dataset_id(self, datas :param dataset_id: ObjectId of the dataset (required) :type dataset_id: str - :param file_name: The name of the csv file within the embeddings folder to get the readUrl for (required) + :param file_name: The name of the csv file within the embeddings folder to get the GET readUrl for (required) :type file_name: str :param async_req: Whether to execute the request asynchronously. 
:type async_req: bool, optional @@ -85,10 +85,10 @@ def get_custom_embedding_file_read_url_from_datasource_by_dataset_id(self, datas return self.get_custom_embedding_file_read_url_from_datasource_by_dataset_id_with_http_info(dataset_id, file_name, **kwargs) # noqa: E501 @validate_arguments - def get_custom_embedding_file_read_url_from_datasource_by_dataset_id_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], file_name : Annotated[constr(strict=True, min_length=4), Field(..., description="The name of the csv file within the embeddings folder to get the readUrl for")], **kwargs) -> ApiResponse: # noqa: E501 + def get_custom_embedding_file_read_url_from_datasource_by_dataset_id_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], file_name : Annotated[constr(strict=True, min_length=4), Field(..., description="The name of the csv file within the embeddings folder to get the GET readUrl for")], **kwargs) -> ApiResponse: # noqa: E501 """get_custom_embedding_file_read_url_from_datasource_by_dataset_id # noqa: E501 - Get the ReadURL of a custom embedding csv file within the embeddings folder (e.g myCustomEmbedding.csv) # noqa: E501 + Get the GET ReadURL of a custom embedding csv file within the embeddings folder (e.g myCustomEmbedding.csv) # noqa: E501 This method makes a synchronous HTTP request by default. To make an asynchronous HTTP request, please pass async_req=True @@ -97,7 +97,7 @@ def get_custom_embedding_file_read_url_from_datasource_by_dataset_id_with_http_i :param dataset_id: ObjectId of the dataset (required) :type dataset_id: str - :param file_name: The name of the csv file within the embeddings folder to get the readUrl for (required) + :param file_name: The name of the csv file within the embeddings folder to get the GET readUrl for (required) :type file_name: str :param async_req: Whether to execute the request asynchronously. 
:type async_req: bool, optional @@ -209,9 +209,9 @@ def get_custom_embedding_file_read_url_from_datasource_by_dataset_id_with_http_i @validate_arguments def get_datasource_by_dataset_id(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], purpose : Annotated[Optional[DatasourcePurpose], Field(description="Which datasource with which purpose we want to get. Defaults to INPUT_OUTPUT")] = None, **kwargs) -> DatasourceConfig: # noqa: E501 - """get_datasource_by_dataset_id # noqa: E501 + """(Deprecated) get_datasource_by_dataset_id # noqa: E501 - Get the datasource of a dataset # noqa: E501 + DEPRECATED - use getDatasourcesByDatasetId. Get the datasource of a dataset # noqa: E501 This method makes a synchronous HTTP request by default. To make an asynchronous HTTP request, please pass async_req=True @@ -240,9 +240,9 @@ def get_datasource_by_dataset_id(self, dataset_id : Annotated[constr(strict=True @validate_arguments def get_datasource_by_dataset_id_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], purpose : Annotated[Optional[DatasourcePurpose], Field(description="Which datasource with which purpose we want to get. Defaults to INPUT_OUTPUT")] = None, **kwargs) -> ApiResponse: # noqa: E501 - """get_datasource_by_dataset_id # noqa: E501 + """(Deprecated) get_datasource_by_dataset_id # noqa: E501 - Get the datasource of a dataset # noqa: E501 + DEPRECATED - use getDatasourcesByDatasetId. Get the datasource of a dataset # noqa: E501 This method makes a synchronous HTTP request by default. 
To make an asynchronous HTTP request, please pass async_req=True @@ -278,6 +278,8 @@ def get_datasource_by_dataset_id_with_http_info(self, dataset_id : Annotated[con :rtype: tuple(DatasourceConfig, status_code(int), headers(HTTPHeaderDict)) """ + warnings.warn("GET /v1/datasets/{datasetId}/datasource is deprecated.", DeprecationWarning) + _params = locals() _all_params = [ @@ -647,6 +649,160 @@ def get_datasources_by_dataset_id_with_http_info(self, dataset_id : Annotated[co collection_formats=_collection_formats, _request_auth=_params.get('_request_auth')) + @validate_arguments + def get_head_file_read_url_from_datasource_by_dataset_id(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], file_name : Annotated[constr(strict=True, min_length=1), Field(..., description="The name of the file within the the datasource to get a HEAD readUrl or GET readURL")], **kwargs) -> str: # noqa: E501 + """get_head_file_read_url_from_datasource_by_dataset_id # noqa: E501 + + Get a HEAD ReadURL of a file within datasources. Can only be used for HEAD request, no GET requests. # noqa: E501 + This method makes a synchronous HTTP request by default. To make an + asynchronous HTTP request, please pass async_req=True + + >>> thread = api.get_head_file_read_url_from_datasource_by_dataset_id(dataset_id, file_name, async_req=True) + >>> result = thread.get() + + :param dataset_id: ObjectId of the dataset (required) + :type dataset_id: str + :param file_name: The name of the file within the datasource to get a HEAD readUrl or GET readURL (required) + :type file_name: str + :param async_req: Whether to execute the request asynchronously. + :type async_req: bool, optional + :param _request_timeout: timeout setting for this request. If one + number provided, it will be total request + timeout. It can also be a pair (tuple) of + (connection, read) timeouts. + :return: Returns the result object.
+ If the method is called asynchronously, + returns the request thread. + :rtype: str + """ + kwargs['_return_http_data_only'] = True + if '_preload_content' in kwargs: + raise ValueError("Error! Please call the get_head_file_read_url_from_datasource_by_dataset_id_with_http_info method with `_preload_content` instead and obtain raw data from ApiResponse.raw_data") + return self.get_head_file_read_url_from_datasource_by_dataset_id_with_http_info(dataset_id, file_name, **kwargs) # noqa: E501 + + @validate_arguments + def get_head_file_read_url_from_datasource_by_dataset_id_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], file_name : Annotated[constr(strict=True, min_length=1), Field(..., description="The name of the file within the the datasource to get a HEAD readUrl or GET readURL")], **kwargs) -> ApiResponse: # noqa: E501 + """get_head_file_read_url_from_datasource_by_dataset_id # noqa: E501 + + Get a HEAD ReadURL of a file within datasources. Can only be used for HEAD request, no GET requests. # noqa: E501 + This method makes a synchronous HTTP request by default. To make an + asynchronous HTTP request, please pass async_req=True + + >>> thread = api.get_head_file_read_url_from_datasource_by_dataset_id_with_http_info(dataset_id, file_name, async_req=True) + >>> result = thread.get() + + :param dataset_id: ObjectId of the dataset (required) + :type dataset_id: str + :param file_name: The name of the file within the datasource to get a HEAD readUrl or GET readURL (required) + :type file_name: str + :param async_req: Whether to execute the request asynchronously. + :type async_req: bool, optional + :param _preload_content: if False, the ApiResponse.data will + be set to none and raw_data will store the + HTTP response body without reading/decoding. + Default is True.
+ :type _preload_content: bool, optional + :param _return_http_data_only: response data instead of ApiResponse + object with status code, headers, etc + :type _return_http_data_only: bool, optional + :param _request_timeout: timeout setting for this request. If one + number provided, it will be total request + timeout. It can also be a pair (tuple) of + (connection, read) timeouts. + :param _request_auth: set to override the auth_settings for a single + request; this effectively ignores the authentication + in the spec for a single request. + :type _request_auth: dict, optional + :type _content_type: string, optional: force content-type for the request + :return: Returns the result object. + If the method is called asynchronously, + returns the request thread. + :rtype: tuple(str, status_code(int), headers(HTTPHeaderDict)) + """ + + _params = locals() + + _all_params = [ + 'dataset_id', + 'file_name' + ] + _all_params.extend( + [ + 'async_req', + '_return_http_data_only', + '_preload_content', + '_request_timeout', + '_request_auth', + '_content_type', + '_headers' + ] + ) + + # validate the arguments + for _key, _val in _params['kwargs'].items(): + if _key not in _all_params: + raise ApiTypeError( + "Got an unexpected keyword argument '%s'" + " to method get_head_file_read_url_from_datasource_by_dataset_id" % _key + ) + _params[_key] = _val + del _params['kwargs'] + + _collection_formats = {} + + # process the path parameters + _path_params = {} + if _params['dataset_id']: + _path_params['datasetId'] = _params['dataset_id'] + + + # process the query parameters + _query_params = [] + if _params.get('file_name') is not None: # noqa: E501 + _query_params.append(( + 'fileName', + _params['file_name'].value if hasattr(_params['file_name'], 'value') else _params['file_name'] + )) + + # process the header parameters + _header_params = dict(_params.get('_headers', {})) + # process the form parameters + _form_params = [] + _files = {} + # process the body parameter +
_body_params = None + # set the HTTP header `Accept` + _header_params['Accept'] = self.api_client.select_header_accept( + ['application/json']) # noqa: E501 + + # authentication setting + _auth_settings = ['auth0Bearer', 'ApiKeyAuth'] # noqa: E501 + + _response_types_map = { + '200': "str", + '400': "ApiErrorResponse", + '401': "ApiErrorResponse", + '403': "ApiErrorResponse", + '404': "ApiErrorResponse", + } + + return self.api_client.call_api( + '/v1/datasets/{datasetId}/datasource/fileHEAD', 'GET', + _path_params, + _query_params, + _header_params, + body=_body_params, + post_params=_form_params, + files=_files, + response_types_map=_response_types_map, + auth_settings=_auth_settings, + async_req=_params.get('async_req'), + _return_http_data_only=_params.get('_return_http_data_only'), # noqa: E501 + _preload_content=_params.get('_preload_content', True), + _request_timeout=_params.get('_request_timeout'), + collection_formats=_collection_formats, + _request_auth=_params.get('_request_auth')) + @validate_arguments def get_list_of_raw_samples_from_datasource_by_dataset_id(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], var_from : Annotated[Optional[conint(strict=True, ge=0)], Field(description="Unix timestamp, only samples with a creation date after `from` will be returned. This parameter is ignored if `cursor` is specified. ")] = None, to : Annotated[Optional[conint(strict=True, ge=0)], Field(description="Unix timestamp, only samples with a creation date before `to` will be returned. This parameter is ignored if `cursor` is specified. ")] = None, cursor : Annotated[Optional[StrictStr], Field(description="Cursor from previous request, encodes `from` and `to` parameters. Specify to continue reading samples from the list. 
")] = None, use_redirected_read_url : Annotated[Optional[StrictBool], Field(description="By default this is set to false unless a S3DelegatedAccess is configured in which case its always true and this param has no effect. When true this will return RedirectedReadUrls instead of ReadUrls meaning that returned URLs allow for unlimited access to the file ")] = None, relevant_filenames_file_name : Annotated[Optional[constr(strict=True, min_length=4)], Field(description="The name of the file within your datasource which contains a list of relevant filenames to list. See https://docs.lightly.ai/docker/getting_started/first_steps.html#specify-relevant-files for more details ")] = None, **kwargs) -> DatasourceRawSamplesData: # noqa: E501 """get_list_of_raw_samples_from_datasource_by_dataset_id # noqa: E501 @@ -1297,10 +1453,10 @@ def get_list_of_raw_samples_predictions_from_datasource_by_dataset_id_with_http_ _request_auth=_params.get('_request_auth')) @validate_arguments - def get_metadata_file_read_url_from_datasource_by_dataset_id(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], file_name : Annotated[constr(strict=True, min_length=5), Field(..., description="The name of the file within the metadata folder to get the readUrl for")], **kwargs) -> str: # noqa: E501 + def get_metadata_file_read_url_from_datasource_by_dataset_id(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], file_name : Annotated[constr(strict=True, min_length=5), Field(..., description="The name of the file within the metadata folder to get the GET readUrl for")], **kwargs) -> str: # noqa: E501 """get_metadata_file_read_url_from_datasource_by_dataset_id # noqa: E501 - Get the ReadURL of a file within the metadata folder (e.g. my_image.json or my_video-099-mp4.json) # noqa: E501 + Get the GET ReadURL of a file within the metadata folder (e.g. 
my_image.json or my_video-099-mp4.json) # noqa: E501 This method makes a synchronous HTTP request by default. To make an asynchronous HTTP request, please pass async_req=True @@ -1309,7 +1465,7 @@ def get_metadata_file_read_url_from_datasource_by_dataset_id(self, dataset_id : :param dataset_id: ObjectId of the dataset (required) :type dataset_id: str - :param file_name: The name of the file within the metadata folder to get the readUrl for (required) + :param file_name: The name of the file within the metadata folder to get the GET readUrl for (required) :type file_name: str :param async_req: Whether to execute the request asynchronously. :type async_req: bool, optional @@ -1328,10 +1484,10 @@ def get_metadata_file_read_url_from_datasource_by_dataset_id(self, dataset_id : return self.get_metadata_file_read_url_from_datasource_by_dataset_id_with_http_info(dataset_id, file_name, **kwargs) # noqa: E501 @validate_arguments - def get_metadata_file_read_url_from_datasource_by_dataset_id_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], file_name : Annotated[constr(strict=True, min_length=5), Field(..., description="The name of the file within the metadata folder to get the readUrl for")], **kwargs) -> ApiResponse: # noqa: E501 + def get_metadata_file_read_url_from_datasource_by_dataset_id_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], file_name : Annotated[constr(strict=True, min_length=5), Field(..., description="The name of the file within the metadata folder to get the GET readUrl for")], **kwargs) -> ApiResponse: # noqa: E501 """get_metadata_file_read_url_from_datasource_by_dataset_id # noqa: E501 - Get the ReadURL of a file within the metadata folder (e.g. my_image.json or my_video-099-mp4.json) # noqa: E501 + Get the GET ReadURL of a file within the metadata folder (e.g. 
my_image.json or my_video-099-mp4.json) # noqa: E501 This method makes a synchronous HTTP request by default. To make an asynchronous HTTP request, please pass async_req=True @@ -1340,7 +1496,7 @@ def get_metadata_file_read_url_from_datasource_by_dataset_id_with_http_info(self :param dataset_id: ObjectId of the dataset (required) :type dataset_id: str - :param file_name: The name of the file within the metadata folder to get the readUrl for (required) + :param file_name: The name of the file within the metadata folder to get the GET readUrl for (required) :type file_name: str :param async_req: Whether to execute the request asynchronously. :type async_req: bool, optional @@ -1451,10 +1607,10 @@ def get_metadata_file_read_url_from_datasource_by_dataset_id_with_http_info(self _request_auth=_params.get('_request_auth')) @validate_arguments - def get_prediction_file_read_url_from_datasource_by_dataset_id(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], file_name : Annotated[constr(strict=True, min_length=4), Field(..., description="The name of the file within the prediction folder to get the readUrl for")], **kwargs) -> str: # noqa: E501 + def get_prediction_file_read_url_from_datasource_by_dataset_id(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], file_name : Annotated[constr(strict=True, min_length=4), Field(..., description="The name of the file within the prediction folder to get the GET readUrl for")], **kwargs) -> str: # noqa: E501 """get_prediction_file_read_url_from_datasource_by_dataset_id # noqa: E501 - Get the ReadURL of a file within the predictions folder (e.g tasks.json or my_classification_task/schema.json) # noqa: E501 + Get the GET ReadURL of a file within the predictions folder (e.g tasks.json or my_classification_task/schema.json) # noqa: E501 This method makes a synchronous HTTP request by default. 
To make an asynchronous HTTP request, please pass async_req=True @@ -1463,7 +1619,7 @@ def get_prediction_file_read_url_from_datasource_by_dataset_id(self, dataset_id :param dataset_id: ObjectId of the dataset (required) :type dataset_id: str - :param file_name: The name of the file within the prediction folder to get the readUrl for (required) + :param file_name: The name of the file within the prediction folder to get the GET readUrl for (required) :type file_name: str :param async_req: Whether to execute the request asynchronously. :type async_req: bool, optional @@ -1482,10 +1638,10 @@ def get_prediction_file_read_url_from_datasource_by_dataset_id(self, dataset_id return self.get_prediction_file_read_url_from_datasource_by_dataset_id_with_http_info(dataset_id, file_name, **kwargs) # noqa: E501 @validate_arguments - def get_prediction_file_read_url_from_datasource_by_dataset_id_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], file_name : Annotated[constr(strict=True, min_length=4), Field(..., description="The name of the file within the prediction folder to get the readUrl for")], **kwargs) -> ApiResponse: # noqa: E501 + def get_prediction_file_read_url_from_datasource_by_dataset_id_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], file_name : Annotated[constr(strict=True, min_length=4), Field(..., description="The name of the file within the prediction folder to get the GET readUrl for")], **kwargs) -> ApiResponse: # noqa: E501 """get_prediction_file_read_url_from_datasource_by_dataset_id # noqa: E501 - Get the ReadURL of a file within the predictions folder (e.g tasks.json or my_classification_task/schema.json) # noqa: E501 + Get the GET ReadURL of a file within the predictions folder (e.g tasks.json or my_classification_task/schema.json) # noqa: E501 This method makes a synchronous HTTP request by default. 
To make an asynchronous HTTP request, please pass async_req=True @@ -1494,7 +1650,7 @@ def get_prediction_file_read_url_from_datasource_by_dataset_id_with_http_info(se :param dataset_id: ObjectId of the dataset (required) :type dataset_id: str - :param file_name: The name of the file within the prediction folder to get the readUrl for (required) + :param file_name: The name of the file within the prediction folder to get the GET readUrl for (required) :type file_name: str :param async_req: Whether to execute the request asynchronously. :type async_req: bool, optional @@ -1605,7 +1761,7 @@ def get_prediction_file_read_url_from_datasource_by_dataset_id_with_http_info(se _request_auth=_params.get('_request_auth')) @validate_arguments - def get_prediction_file_write_url_from_datasource_by_dataset_id(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], file_name : Annotated[constr(strict=True, min_length=4), Field(..., description="The name of the file within the prediction folder to get the readUrl for")], **kwargs) -> str: # noqa: E501 + def get_prediction_file_write_url_from_datasource_by_dataset_id(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], file_name : Annotated[constr(strict=True, min_length=4), Field(..., description="The name of the file within the prediction folder to get the GET readUrl for")], **kwargs) -> str: # noqa: E501 """get_prediction_file_write_url_from_datasource_by_dataset_id # noqa: E501 Get the WriteURL of a file within the predictions folder (e.g tasks.json or my_classification_task/schema.json) # noqa: E501 @@ -1617,7 +1773,7 @@ def get_prediction_file_write_url_from_datasource_by_dataset_id(self, dataset_id :param dataset_id: ObjectId of the dataset (required) :type dataset_id: str - :param file_name: The name of the file within the prediction folder to get the readUrl for (required) + :param file_name: The name of the file within the 
prediction folder to get the GET readUrl for (required) :type file_name: str :param async_req: Whether to execute the request asynchronously. :type async_req: bool, optional @@ -1636,7 +1792,7 @@ def get_prediction_file_write_url_from_datasource_by_dataset_id(self, dataset_id return self.get_prediction_file_write_url_from_datasource_by_dataset_id_with_http_info(dataset_id, file_name, **kwargs) # noqa: E501 @validate_arguments - def get_prediction_file_write_url_from_datasource_by_dataset_id_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], file_name : Annotated[constr(strict=True, min_length=4), Field(..., description="The name of the file within the prediction folder to get the readUrl for")], **kwargs) -> ApiResponse: # noqa: E501 + def get_prediction_file_write_url_from_datasource_by_dataset_id_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], file_name : Annotated[constr(strict=True, min_length=4), Field(..., description="The name of the file within the prediction folder to get the GET readUrl for")], **kwargs) -> ApiResponse: # noqa: E501 """get_prediction_file_write_url_from_datasource_by_dataset_id # noqa: E501 Get the WriteURL of a file within the predictions folder (e.g tasks.json or my_classification_task/schema.json) # noqa: E501 @@ -1648,7 +1804,7 @@ def get_prediction_file_write_url_from_datasource_by_dataset_id_with_http_info(s :param dataset_id: ObjectId of the dataset (required) :type dataset_id: str - :param file_name: The name of the file within the prediction folder to get the readUrl for (required) + :param file_name: The name of the file within the prediction folder to get the GET readUrl for (required) :type file_name: str :param async_req: Whether to execute the request asynchronously. 
:type async_req: bool, optional diff --git a/lightly/openapi_generated/swagger_client/api/docker_api.py b/lightly/openapi_generated/swagger_client/api/docker_api.py index 633383cbc..107f16037 100644 --- a/lightly/openapi_generated/swagger_client/api/docker_api.py +++ b/lightly/openapi_generated/swagger_client/api/docker_api.py @@ -2150,7 +2150,7 @@ def get_docker_run_logs_by_id_with_http_info(self, run_id : Annotated[constr(str def get_docker_run_report_read_url_by_id(self, run_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the docker run")], **kwargs) -> str: # noqa: E501 """(Deprecated) get_docker_run_report_read_url_by_id # noqa: E501 - Get the url of a specific docker runs report # noqa: E501 + DEPRECATED, use getDockerRunArtifactReadUrlById - Get the url of a specific docker runs report # noqa: E501 This method makes a synchronous HTTP request by default. To make an asynchronous HTTP request, please pass async_req=True @@ -2179,7 +2179,7 @@ def get_docker_run_report_read_url_by_id(self, run_id : Annotated[constr(strict= def get_docker_run_report_read_url_by_id_with_http_info(self, run_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the docker run")], **kwargs) -> ApiResponse: # noqa: E501 """(Deprecated) get_docker_run_report_read_url_by_id # noqa: E501 - Get the url of a specific docker runs report # noqa: E501 + DEPRECATED, use getDockerRunArtifactReadUrlById - Get the url of a specific docker runs report # noqa: E501 This method makes a synchronous HTTP request by default. 
To make an asynchronous HTTP request, please pass async_req=True @@ -2295,7 +2295,7 @@ def get_docker_run_report_read_url_by_id_with_http_info(self, run_id : Annotated def get_docker_run_report_write_url_by_id(self, run_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the docker run")], **kwargs) -> str: # noqa: E501 """(Deprecated) get_docker_run_report_write_url_by_id # noqa: E501 - Get the signed url to upload a report of a docker run # noqa: E501 + DEPRECATED, use createDockerRunArtifact - Get the signed url to upload a report of a docker run # noqa: E501 This method makes a synchronous HTTP request by default. To make an asynchronous HTTP request, please pass async_req=True @@ -2324,7 +2324,7 @@ def get_docker_run_report_write_url_by_id(self, run_id : Annotated[constr(strict def get_docker_run_report_write_url_by_id_with_http_info(self, run_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the docker run")], **kwargs) -> ApiResponse: # noqa: E501 """(Deprecated) get_docker_run_report_write_url_by_id # noqa: E501 - Get the signed url to upload a report of a docker run # noqa: E501 + DEPRECATED, use createDockerRunArtifact - Get the signed url to upload a report of a docker run # noqa: E501 This method makes a synchronous HTTP request by default. 
To make an asynchronous HTTP request, please pass async_req=True @@ -2580,14 +2580,14 @@ def get_docker_run_tags_with_http_info(self, run_id : Annotated[constr(strict=Tr _request_auth=_params.get('_request_auth')) @validate_arguments - def get_docker_runs(self, page_size : Annotated[Optional[conint(strict=True, ge=1)], Field(description="pagination size/limit of the number of samples to return")] = None, page_offset : Annotated[Optional[conint(strict=True, ge=0)], Field(description="pagination offset")] = None, get_assets_of_team : Annotated[Optional[StrictBool], Field(description="if this flag is true, we get the relevant asset of the team of the user rather than the assets of the user")] = None, get_assets_of_team_inclusive_self : Annotated[Optional[StrictBool], Field(description="if this flag is true, we get the relevant asset of the team of the user including the assets of the user")] = None, **kwargs) -> List[DockerRunData]: # noqa: E501 + def get_docker_runs(self, page_size : Annotated[Optional[conint(strict=True, ge=1)], Field(description="pagination size/limit of the number of samples to return")] = None, page_offset : Annotated[Optional[conint(strict=True, ge=0)], Field(description="pagination offset")] = None, get_assets_of_team : Annotated[Optional[StrictBool], Field(description="if this flag is true, we get the relevant asset of the team of the user rather than the assets of the user")] = None, get_assets_of_team_inclusive_self : Annotated[Optional[StrictBool], Field(description="if this flag is true, we get the relevant asset of the team of the user including the assets of the user")] = None, show_archived : Annotated[Optional[StrictBool], Field(description="if this flag is true, we also get the archived assets")] = None, **kwargs) -> List[DockerRunData]: # noqa: E501 """get_docker_runs # noqa: E501 Gets all docker runs for a user. # noqa: E501 This method makes a synchronous HTTP request by default. 
To make an asynchronous HTTP request, please pass async_req=True - >>> thread = api.get_docker_runs(page_size, page_offset, get_assets_of_team, get_assets_of_team_inclusive_self, async_req=True) + >>> thread = api.get_docker_runs(page_size, page_offset, get_assets_of_team, get_assets_of_team_inclusive_self, show_archived, async_req=True) >>> result = thread.get() :param page_size: pagination size/limit of the number of samples to return @@ -2598,6 +2598,8 @@ def get_docker_runs(self, page_size : Annotated[Optional[conint(strict=True, ge= :type get_assets_of_team: bool :param get_assets_of_team_inclusive_self: if this flag is true, we get the relevant asset of the team of the user including the assets of the user :type get_assets_of_team_inclusive_self: bool + :param show_archived: if this flag is true, we also get the archived assets + :type show_archived: bool :param async_req: Whether to execute the request asynchronously. :type async_req: bool, optional :param _request_timeout: timeout setting for this request. If one @@ -2612,17 +2614,17 @@ def get_docker_runs(self, page_size : Annotated[Optional[conint(strict=True, ge= kwargs['_return_http_data_only'] = True if '_preload_content' in kwargs: raise ValueError("Error! 
Please call the get_docker_runs_with_http_info method with `_preload_content` instead and obtain raw data from ApiResponse.raw_data") - return self.get_docker_runs_with_http_info(page_size, page_offset, get_assets_of_team, get_assets_of_team_inclusive_self, **kwargs) # noqa: E501 + return self.get_docker_runs_with_http_info(page_size, page_offset, get_assets_of_team, get_assets_of_team_inclusive_self, show_archived, **kwargs) # noqa: E501 @validate_arguments - def get_docker_runs_with_http_info(self, page_size : Annotated[Optional[conint(strict=True, ge=1)], Field(description="pagination size/limit of the number of samples to return")] = None, page_offset : Annotated[Optional[conint(strict=True, ge=0)], Field(description="pagination offset")] = None, get_assets_of_team : Annotated[Optional[StrictBool], Field(description="if this flag is true, we get the relevant asset of the team of the user rather than the assets of the user")] = None, get_assets_of_team_inclusive_self : Annotated[Optional[StrictBool], Field(description="if this flag is true, we get the relevant asset of the team of the user including the assets of the user")] = None, **kwargs) -> ApiResponse: # noqa: E501 + def get_docker_runs_with_http_info(self, page_size : Annotated[Optional[conint(strict=True, ge=1)], Field(description="pagination size/limit of the number of samples to return")] = None, page_offset : Annotated[Optional[conint(strict=True, ge=0)], Field(description="pagination offset")] = None, get_assets_of_team : Annotated[Optional[StrictBool], Field(description="if this flag is true, we get the relevant asset of the team of the user rather than the assets of the user")] = None, get_assets_of_team_inclusive_self : Annotated[Optional[StrictBool], Field(description="if this flag is true, we get the relevant asset of the team of the user including the assets of the user")] = None, show_archived : Annotated[Optional[StrictBool], Field(description="if this flag is true, we also get the archived 
assets")] = None, **kwargs) -> ApiResponse: # noqa: E501 """get_docker_runs # noqa: E501 Gets all docker runs for a user. # noqa: E501 This method makes a synchronous HTTP request by default. To make an asynchronous HTTP request, please pass async_req=True - >>> thread = api.get_docker_runs_with_http_info(page_size, page_offset, get_assets_of_team, get_assets_of_team_inclusive_self, async_req=True) + >>> thread = api.get_docker_runs_with_http_info(page_size, page_offset, get_assets_of_team, get_assets_of_team_inclusive_self, show_archived, async_req=True) >>> result = thread.get() :param page_size: pagination size/limit of the number of samples to return @@ -2633,6 +2635,8 @@ def get_docker_runs_with_http_info(self, page_size : Annotated[Optional[conint(s :type get_assets_of_team: bool :param get_assets_of_team_inclusive_self: if this flag is true, we get the relevant asset of the team of the user including the assets of the user :type get_assets_of_team_inclusive_self: bool + :param show_archived: if this flag is true, we also get the archived assets + :type show_archived: bool :param async_req: Whether to execute the request asynchronously. 
:type async_req: bool, optional :param _preload_content: if False, the ApiResponse.data will @@ -2664,7 +2668,8 @@ def get_docker_runs_with_http_info(self, page_size : Annotated[Optional[conint(s 'page_size', 'page_offset', 'get_assets_of_team', - 'get_assets_of_team_inclusive_self' + 'get_assets_of_team_inclusive_self', + 'show_archived' ] _all_params.extend( [ @@ -2719,6 +2724,12 @@ def get_docker_runs_with_http_info(self, page_size : Annotated[Optional[conint(s _params['get_assets_of_team_inclusive_self'].value if hasattr(_params['get_assets_of_team_inclusive_self'], 'value') else _params['get_assets_of_team_inclusive_self'] )) + if _params.get('show_archived') is not None: # noqa: E501 + _query_params.append(( + 'showArchived', + _params['show_archived'].value if hasattr(_params['show_archived'], 'value') else _params['show_archived'] + )) + # process the header parameters _header_params = dict(_params.get('_headers', {})) # process the form parameters @@ -2759,20 +2770,22 @@ def get_docker_runs_with_http_info(self, page_size : Annotated[Optional[conint(s _request_auth=_params.get('_request_auth')) @validate_arguments - def get_docker_runs_count(self, get_assets_of_team : Annotated[Optional[StrictBool], Field(description="if this flag is true, we get the relevant asset of the team of the user rather than the assets of the user")] = None, get_assets_of_team_inclusive_self : Annotated[Optional[StrictBool], Field(description="if this flag is true, we get the relevant asset of the team of the user including the assets of the user")] = None, **kwargs) -> str: # noqa: E501 + def get_docker_runs_count(self, get_assets_of_team : Annotated[Optional[StrictBool], Field(description="if this flag is true, we get the relevant asset of the team of the user rather than the assets of the user")] = None, get_assets_of_team_inclusive_self : Annotated[Optional[StrictBool], Field(description="if this flag is true, we get the relevant asset of the team of the user including the 
assets of the user")] = None, show_archived : Annotated[Optional[StrictBool], Field(description="if this flag is true, we also get the archived assets")] = None, **kwargs) -> str: # noqa: E501 """get_docker_runs_count # noqa: E501 Gets the total count of the amount of runs existing for a user # noqa: E501 This method makes a synchronous HTTP request by default. To make an asynchronous HTTP request, please pass async_req=True - >>> thread = api.get_docker_runs_count(get_assets_of_team, get_assets_of_team_inclusive_self, async_req=True) + >>> thread = api.get_docker_runs_count(get_assets_of_team, get_assets_of_team_inclusive_self, show_archived, async_req=True) >>> result = thread.get() :param get_assets_of_team: if this flag is true, we get the relevant asset of the team of the user rather than the assets of the user :type get_assets_of_team: bool :param get_assets_of_team_inclusive_self: if this flag is true, we get the relevant asset of the team of the user including the assets of the user :type get_assets_of_team_inclusive_self: bool + :param show_archived: if this flag is true, we also get the archived assets + :type show_archived: bool :param async_req: Whether to execute the request asynchronously. :type async_req: bool, optional :param _request_timeout: timeout setting for this request. If one @@ -2787,23 +2800,25 @@ def get_docker_runs_count(self, get_assets_of_team : Annotated[Optional[StrictBo kwargs['_return_http_data_only'] = True if '_preload_content' in kwargs: raise ValueError("Error! 
Please call the get_docker_runs_count_with_http_info method with `_preload_content` instead and obtain raw data from ApiResponse.raw_data") - return self.get_docker_runs_count_with_http_info(get_assets_of_team, get_assets_of_team_inclusive_self, **kwargs) # noqa: E501 + return self.get_docker_runs_count_with_http_info(get_assets_of_team, get_assets_of_team_inclusive_self, show_archived, **kwargs) # noqa: E501 @validate_arguments - def get_docker_runs_count_with_http_info(self, get_assets_of_team : Annotated[Optional[StrictBool], Field(description="if this flag is true, we get the relevant asset of the team of the user rather than the assets of the user")] = None, get_assets_of_team_inclusive_self : Annotated[Optional[StrictBool], Field(description="if this flag is true, we get the relevant asset of the team of the user including the assets of the user")] = None, **kwargs) -> ApiResponse: # noqa: E501 + def get_docker_runs_count_with_http_info(self, get_assets_of_team : Annotated[Optional[StrictBool], Field(description="if this flag is true, we get the relevant asset of the team of the user rather than the assets of the user")] = None, get_assets_of_team_inclusive_self : Annotated[Optional[StrictBool], Field(description="if this flag is true, we get the relevant asset of the team of the user including the assets of the user")] = None, show_archived : Annotated[Optional[StrictBool], Field(description="if this flag is true, we also get the archived assets")] = None, **kwargs) -> ApiResponse: # noqa: E501 """get_docker_runs_count # noqa: E501 Gets the total count of the amount of runs existing for a user # noqa: E501 This method makes a synchronous HTTP request by default. 
To make an asynchronous HTTP request, please pass async_req=True - >>> thread = api.get_docker_runs_count_with_http_info(get_assets_of_team, get_assets_of_team_inclusive_self, async_req=True) + >>> thread = api.get_docker_runs_count_with_http_info(get_assets_of_team, get_assets_of_team_inclusive_self, show_archived, async_req=True) >>> result = thread.get() :param get_assets_of_team: if this flag is true, we get the relevant asset of the team of the user rather than the assets of the user :type get_assets_of_team: bool :param get_assets_of_team_inclusive_self: if this flag is true, we get the relevant asset of the team of the user including the assets of the user :type get_assets_of_team_inclusive_self: bool + :param show_archived: if this flag is true, we also get the archived assets + :type show_archived: bool :param async_req: Whether to execute the request asynchronously. :type async_req: bool, optional :param _preload_content: if False, the ApiResponse.data will @@ -2833,7 +2848,8 @@ def get_docker_runs_count_with_http_info(self, get_assets_of_team : Annotated[Op _all_params = [ 'get_assets_of_team', - 'get_assets_of_team_inclusive_self' + 'get_assets_of_team_inclusive_self', + 'show_archived' ] _all_params.extend( [ @@ -2876,6 +2892,12 @@ def get_docker_runs_count_with_http_info(self, get_assets_of_team : Annotated[Op _params['get_assets_of_team_inclusive_self'].value if hasattr(_params['get_assets_of_team_inclusive_self'], 'value') else _params['get_assets_of_team_inclusive_self'] )) + if _params.get('show_archived') is not None: # noqa: E501 + _query_params.append(( + 'showArchived', + _params['show_archived'].value if hasattr(_params['show_archived'], 'value') else _params['show_archived'] + )) + # process the header parameters _header_params = dict(_params.get('_headers', {})) # process the form parameters diff --git a/lightly/openapi_generated/swagger_client/api/embeddings_api.py b/lightly/openapi_generated/swagger_client/api/embeddings_api.py index 
8f710136d..c38036884 100644 --- a/lightly/openapi_generated/swagger_client/api/embeddings_api.py +++ b/lightly/openapi_generated/swagger_client/api/embeddings_api.py @@ -199,7 +199,7 @@ def delete_embedding_by_id_with_http_info(self, dataset_id : Annotated[constr(st def get_embeddings_by_dataset_id(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], **kwargs) -> List[DatasetEmbeddingData]: # noqa: E501 """get_embeddings_by_dataset_id # noqa: E501 - Get all annotations of a dataset # noqa: E501 + Get all embeddings of a dataset # noqa: E501 This method makes a synchronous HTTP request by default. To make an asynchronous HTTP request, please pass async_req=True @@ -228,7 +228,7 @@ def get_embeddings_by_dataset_id(self, dataset_id : Annotated[constr(strict=True def get_embeddings_by_dataset_id_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], **kwargs) -> ApiResponse: # noqa: E501 """get_embeddings_by_dataset_id # noqa: E501 - Get all annotations of a dataset # noqa: E501 + Get all embeddings of a dataset # noqa: E501 This method makes a synchronous HTTP request by default. 
To make an asynchronous HTTP request, please pass async_req=True diff --git a/lightly/openapi_generated/swagger_client/api/predictions_api.py b/lightly/openapi_generated/swagger_client/api/predictions_api.py index 057cbdcb7..c188f45fc 100644 --- a/lightly/openapi_generated/swagger_client/api/predictions_api.py +++ b/lightly/openapi_generated/swagger_client/api/predictions_api.py @@ -50,24 +50,24 @@ def __init__(self, api_client=None): self.api_client = api_client @validate_arguments - def create_or_update_prediction_by_sample_id(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], sample_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the sample")], prediction_uuid_timestamp : Annotated[conint(strict=True, ge=0), Field(..., description="The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. ")], prediction_singleton : conlist(PredictionSingleton), **kwargs) -> CreateEntityResponse: # noqa: E501 + def create_or_update_prediction_by_sample_id(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], sample_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the sample")], prediction_singleton : conlist(PredictionSingleton), prediction_uuid_timestamp : Annotated[Optional[conint(strict=True, ge=0)], Field(description="Deprecated, currently ignored. The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. 
")] = None, **kwargs) -> CreateEntityResponse: # noqa: E501 """create_or_update_prediction_by_sample_id # noqa: E501 - Create/Update all the prediction singletons for a sampleId in the order/index of them being discovered # noqa: E501 + Create/Update all the prediction singletons per taskName for a sampleId in the order/index of them being discovered # noqa: E501 This method makes a synchronous HTTP request by default. To make an asynchronous HTTP request, please pass async_req=True - >>> thread = api.create_or_update_prediction_by_sample_id(dataset_id, sample_id, prediction_uuid_timestamp, prediction_singleton, async_req=True) + >>> thread = api.create_or_update_prediction_by_sample_id(dataset_id, sample_id, prediction_singleton, prediction_uuid_timestamp, async_req=True) >>> result = thread.get() :param dataset_id: ObjectId of the dataset (required) :type dataset_id: str :param sample_id: ObjectId of the sample (required) :type sample_id: str - :param prediction_uuid_timestamp: The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. (required) - :type prediction_uuid_timestamp: int :param prediction_singleton: (required) :type prediction_singleton: List[PredictionSingleton] + :param prediction_uuid_timestamp: Deprecated, currently ignored. The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. + :type prediction_uuid_timestamp: int :param async_req: Whether to execute the request asynchronously. :type async_req: bool, optional :param _request_timeout: timeout setting for this request. 
If one @@ -82,27 +82,27 @@ def create_or_update_prediction_by_sample_id(self, dataset_id : Annotated[constr kwargs['_return_http_data_only'] = True if '_preload_content' in kwargs: raise ValueError("Error! Please call the create_or_update_prediction_by_sample_id_with_http_info method with `_preload_content` instead and obtain raw data from ApiResponse.raw_data") - return self.create_or_update_prediction_by_sample_id_with_http_info(dataset_id, sample_id, prediction_uuid_timestamp, prediction_singleton, **kwargs) # noqa: E501 + return self.create_or_update_prediction_by_sample_id_with_http_info(dataset_id, sample_id, prediction_singleton, prediction_uuid_timestamp, **kwargs) # noqa: E501 @validate_arguments - def create_or_update_prediction_by_sample_id_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], sample_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the sample")], prediction_uuid_timestamp : Annotated[conint(strict=True, ge=0), Field(..., description="The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. ")], prediction_singleton : conlist(PredictionSingleton), **kwargs) -> ApiResponse: # noqa: E501 + def create_or_update_prediction_by_sample_id_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], sample_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the sample")], prediction_singleton : conlist(PredictionSingleton), prediction_uuid_timestamp : Annotated[Optional[conint(strict=True, ge=0)], Field(description="Deprecated, currently ignored. The timestamp of when the actual predictions were created. This is used as a peg to version predictions. 
E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. ")] = None, **kwargs) -> ApiResponse: # noqa: E501 """create_or_update_prediction_by_sample_id # noqa: E501 - Create/Update all the prediction singletons for a sampleId in the order/index of them being discovered # noqa: E501 + Create/Update all the prediction singletons per taskName for a sampleId in the order/index of them being discovered # noqa: E501 This method makes a synchronous HTTP request by default. To make an asynchronous HTTP request, please pass async_req=True - >>> thread = api.create_or_update_prediction_by_sample_id_with_http_info(dataset_id, sample_id, prediction_uuid_timestamp, prediction_singleton, async_req=True) + >>> thread = api.create_or_update_prediction_by_sample_id_with_http_info(dataset_id, sample_id, prediction_singleton, prediction_uuid_timestamp, async_req=True) >>> result = thread.get() :param dataset_id: ObjectId of the dataset (required) :type dataset_id: str :param sample_id: ObjectId of the sample (required) :type sample_id: str - :param prediction_uuid_timestamp: The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. (required) - :type prediction_uuid_timestamp: int :param prediction_singleton: (required) :type prediction_singleton: List[PredictionSingleton] + :param prediction_uuid_timestamp: Deprecated, currently ignored. The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. 
+ :type prediction_uuid_timestamp: int :param async_req: Whether to execute the request asynchronously. :type async_req: bool, optional :param _preload_content: if False, the ApiResponse.data will @@ -133,8 +133,8 @@ def create_or_update_prediction_by_sample_id_with_http_info(self, dataset_id : A _all_params = [ 'dataset_id', 'sample_id', - 'prediction_uuid_timestamp', - 'prediction_singleton' + 'prediction_singleton', + 'prediction_uuid_timestamp' ] _all_params.extend( [ @@ -227,22 +227,22 @@ def create_or_update_prediction_by_sample_id_with_http_info(self, dataset_id : A _request_auth=_params.get('_request_auth')) @validate_arguments - def create_or_update_prediction_task_schema_by_dataset_id(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], prediction_uuid_timestamp : Annotated[conint(strict=True, ge=0), Field(..., description="The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. ")], prediction_task_schema : PredictionTaskSchema, **kwargs) -> CreateEntityResponse: # noqa: E501 + def create_or_update_prediction_task_schema_by_dataset_id(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], prediction_task_schema : PredictionTaskSchema, prediction_uuid_timestamp : Annotated[Optional[conint(strict=True, ge=0)], Field(description="Deprecated, currently ignored. The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. 
")] = None, **kwargs) -> CreateEntityResponse: # noqa: E501 """create_or_update_prediction_task_schema_by_dataset_id # noqa: E501 Creates/updates a prediction task schema with the task name # noqa: E501 This method makes a synchronous HTTP request by default. To make an asynchronous HTTP request, please pass async_req=True - >>> thread = api.create_or_update_prediction_task_schema_by_dataset_id(dataset_id, prediction_uuid_timestamp, prediction_task_schema, async_req=True) + >>> thread = api.create_or_update_prediction_task_schema_by_dataset_id(dataset_id, prediction_task_schema, prediction_uuid_timestamp, async_req=True) >>> result = thread.get() :param dataset_id: ObjectId of the dataset (required) :type dataset_id: str - :param prediction_uuid_timestamp: The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. (required) - :type prediction_uuid_timestamp: int :param prediction_task_schema: (required) :type prediction_task_schema: PredictionTaskSchema + :param prediction_uuid_timestamp: Deprecated, currently ignored. The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. + :type prediction_uuid_timestamp: int :param async_req: Whether to execute the request asynchronously. :type async_req: bool, optional :param _request_timeout: timeout setting for this request. If one @@ -257,25 +257,25 @@ def create_or_update_prediction_task_schema_by_dataset_id(self, dataset_id : Ann kwargs['_return_http_data_only'] = True if '_preload_content' in kwargs: raise ValueError("Error! 
Please call the create_or_update_prediction_task_schema_by_dataset_id_with_http_info method with `_preload_content` instead and obtain raw data from ApiResponse.raw_data") - return self.create_or_update_prediction_task_schema_by_dataset_id_with_http_info(dataset_id, prediction_uuid_timestamp, prediction_task_schema, **kwargs) # noqa: E501 + return self.create_or_update_prediction_task_schema_by_dataset_id_with_http_info(dataset_id, prediction_task_schema, prediction_uuid_timestamp, **kwargs) # noqa: E501 @validate_arguments - def create_or_update_prediction_task_schema_by_dataset_id_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], prediction_uuid_timestamp : Annotated[conint(strict=True, ge=0), Field(..., description="The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. ")], prediction_task_schema : PredictionTaskSchema, **kwargs) -> ApiResponse: # noqa: E501 + def create_or_update_prediction_task_schema_by_dataset_id_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], prediction_task_schema : PredictionTaskSchema, prediction_uuid_timestamp : Annotated[Optional[conint(strict=True, ge=0)], Field(description="Deprecated, currently ignored. The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. 
")] = None, **kwargs) -> ApiResponse: # noqa: E501 """create_or_update_prediction_task_schema_by_dataset_id # noqa: E501 Creates/updates a prediction task schema with the task name # noqa: E501 This method makes a synchronous HTTP request by default. To make an asynchronous HTTP request, please pass async_req=True - >>> thread = api.create_or_update_prediction_task_schema_by_dataset_id_with_http_info(dataset_id, prediction_uuid_timestamp, prediction_task_schema, async_req=True) + >>> thread = api.create_or_update_prediction_task_schema_by_dataset_id_with_http_info(dataset_id, prediction_task_schema, prediction_uuid_timestamp, async_req=True) >>> result = thread.get() :param dataset_id: ObjectId of the dataset (required) :type dataset_id: str - :param prediction_uuid_timestamp: The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. (required) - :type prediction_uuid_timestamp: int :param prediction_task_schema: (required) :type prediction_task_schema: PredictionTaskSchema + :param prediction_uuid_timestamp: Deprecated, currently ignored. The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. + :type prediction_uuid_timestamp: int :param async_req: Whether to execute the request asynchronously. 
:type async_req: bool, optional :param _preload_content: if False, the ApiResponse.data will @@ -305,8 +305,8 @@ def create_or_update_prediction_task_schema_by_dataset_id_with_http_info(self, d _all_params = [ 'dataset_id', - 'prediction_uuid_timestamp', - 'prediction_task_schema' + 'prediction_task_schema', + 'prediction_uuid_timestamp' ] _all_params.extend( [ @@ -395,21 +395,21 @@ def create_or_update_prediction_task_schema_by_dataset_id_with_http_info(self, d _request_auth=_params.get('_request_auth')) @validate_arguments - def get_prediction_by_sample_id(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], sample_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the sample")], prediction_uuid_timestamp : Annotated[conint(strict=True, ge=0), Field(..., description="The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. ")], **kwargs) -> List[PredictionSingleton]: # noqa: E501 - """get_prediction_by_sample_id # noqa: E501 + def get_prediction_task_schema_by_task_name(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], task_name : Annotated[constr(strict=True, min_length=1), Field(..., description="The prediction task name for which one wants to list the predictions")], prediction_uuid_timestamp : Annotated[Optional[conint(strict=True, ge=0)], Field(description="Deprecated, currently ignored. The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. 
")] = None, **kwargs) -> PredictionTaskSchema: # noqa: E501 + """get_prediction_task_schema_by_task_name # noqa: E501 - Get all prediction singletons of a specific sample of a dataset # noqa: E501 + Get a prediction task schemas named taskName for a datasetId # noqa: E501 This method makes a synchronous HTTP request by default. To make an asynchronous HTTP request, please pass async_req=True - >>> thread = api.get_prediction_by_sample_id(dataset_id, sample_id, prediction_uuid_timestamp, async_req=True) + >>> thread = api.get_prediction_task_schema_by_task_name(dataset_id, task_name, prediction_uuid_timestamp, async_req=True) >>> result = thread.get() :param dataset_id: ObjectId of the dataset (required) :type dataset_id: str - :param sample_id: ObjectId of the sample (required) - :type sample_id: str - :param prediction_uuid_timestamp: The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. (required) + :param task_name: The prediction task name for which one wants to list the predictions (required) + :type task_name: str + :param prediction_uuid_timestamp: Deprecated, currently ignored. The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. :type prediction_uuid_timestamp: int :param async_req: Whether to execute the request asynchronously. :type async_req: bool, optional @@ -420,29 +420,29 @@ def get_prediction_by_sample_id(self, dataset_id : Annotated[constr(strict=True) :return: Returns the result object. If the method is called asynchronously, returns the request thread. 
- :rtype: List[PredictionSingleton] + :rtype: PredictionTaskSchema """ kwargs['_return_http_data_only'] = True if '_preload_content' in kwargs: - raise ValueError("Error! Please call the get_prediction_by_sample_id_with_http_info method with `_preload_content` instead and obtain raw data from ApiResponse.raw_data") - return self.get_prediction_by_sample_id_with_http_info(dataset_id, sample_id, prediction_uuid_timestamp, **kwargs) # noqa: E501 + raise ValueError("Error! Please call the get_prediction_task_schema_by_task_name_with_http_info method with `_preload_content` instead and obtain raw data from ApiResponse.raw_data") + return self.get_prediction_task_schema_by_task_name_with_http_info(dataset_id, task_name, prediction_uuid_timestamp, **kwargs) # noqa: E501 @validate_arguments - def get_prediction_by_sample_id_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], sample_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the sample")], prediction_uuid_timestamp : Annotated[conint(strict=True, ge=0), Field(..., description="The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. ")], **kwargs) -> ApiResponse: # noqa: E501 - """get_prediction_by_sample_id # noqa: E501 + def get_prediction_task_schema_by_task_name_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], task_name : Annotated[constr(strict=True, min_length=1), Field(..., description="The prediction task name for which one wants to list the predictions")], prediction_uuid_timestamp : Annotated[Optional[conint(strict=True, ge=0)], Field(description="Deprecated, currently ignored. The timestamp of when the actual predictions were created. 
This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. ")] = None, **kwargs) -> ApiResponse: # noqa: E501 + """get_prediction_task_schema_by_task_name # noqa: E501 - Get all prediction singletons of a specific sample of a dataset # noqa: E501 + Get a prediction task schemas named taskName for a datasetId # noqa: E501 This method makes a synchronous HTTP request by default. To make an asynchronous HTTP request, please pass async_req=True - >>> thread = api.get_prediction_by_sample_id_with_http_info(dataset_id, sample_id, prediction_uuid_timestamp, async_req=True) + >>> thread = api.get_prediction_task_schema_by_task_name_with_http_info(dataset_id, task_name, prediction_uuid_timestamp, async_req=True) >>> result = thread.get() :param dataset_id: ObjectId of the dataset (required) :type dataset_id: str - :param sample_id: ObjectId of the sample (required) - :type sample_id: str - :param prediction_uuid_timestamp: The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. (required) + :param task_name: The prediction task name for which one wants to list the predictions (required) + :type task_name: str + :param prediction_uuid_timestamp: Deprecated, currently ignored. The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. :type prediction_uuid_timestamp: int :param async_req: Whether to execute the request asynchronously. 
:type async_req: bool, optional @@ -466,14 +466,14 @@ def get_prediction_by_sample_id_with_http_info(self, dataset_id : Annotated[cons :return: Returns the result object. If the method is called asynchronously, returns the request thread. - :rtype: tuple(List[PredictionSingleton], status_code(int), headers(HTTPHeaderDict)) + :rtype: tuple(PredictionTaskSchema, status_code(int), headers(HTTPHeaderDict)) """ _params = locals() _all_params = [ 'dataset_id', - 'sample_id', + 'task_name', 'prediction_uuid_timestamp' ] _all_params.extend( @@ -493,7 +493,7 @@ def get_prediction_by_sample_id_with_http_info(self, dataset_id : Annotated[cons if _key not in _all_params: raise ApiTypeError( "Got an unexpected keyword argument '%s'" - " to method get_prediction_by_sample_id" % _key + " to method get_prediction_task_schema_by_task_name" % _key ) _params[_key] = _val del _params['kwargs'] @@ -505,8 +505,8 @@ def get_prediction_by_sample_id_with_http_info(self, dataset_id : Annotated[cons if _params['dataset_id']: _path_params['datasetId'] = _params['dataset_id'] - if _params['sample_id']: - _path_params['sampleId'] = _params['sample_id'] + if _params['task_name']: + _path_params['taskName'] = _params['task_name'] # process the query parameters @@ -532,7 +532,7 @@ def get_prediction_by_sample_id_with_http_info(self, dataset_id : Annotated[cons _auth_settings = ['auth0Bearer', 'ApiKeyAuth'] # noqa: E501 _response_types_map = { - '200': "List[PredictionSingleton]", + '200': "PredictionTaskSchema", '400': "ApiErrorResponse", '401': "ApiErrorResponse", '403': "ApiErrorResponse", @@ -540,7 +540,7 @@ def get_prediction_by_sample_id_with_http_info(self, dataset_id : Annotated[cons } return self.api_client.call_api( - '/v1/datasets/{datasetId}/predictions/samples/{sampleId}', 'GET', + '/v1/datasets/{datasetId}/predictions/tasks/{taskName}', 'GET', _path_params, _query_params, _header_params, @@ -557,22 +557,20 @@ def get_prediction_by_sample_id_with_http_info(self, dataset_id : 
Annotated[cons _request_auth=_params.get('_request_auth')) @validate_arguments - def get_prediction_task_schema_by_task_name(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], prediction_uuid_timestamp : Annotated[conint(strict=True, ge=0), Field(..., description="The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. ")], task_name : Annotated[constr(strict=True, min_length=1), Field(..., description="The prediction task name for which one wants to list the predictions")], **kwargs) -> PredictionTaskSchema: # noqa: E501 - """get_prediction_task_schema_by_task_name # noqa: E501 + def get_prediction_task_schemas_by_dataset_id(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], prediction_uuid_timestamp : Annotated[Optional[conint(strict=True, ge=0)], Field(description="Deprecated, currently ignored. The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. ")] = None, **kwargs) -> PredictionTaskSchemas: # noqa: E501 + """get_prediction_task_schemas_by_dataset_id # noqa: E501 - Get a prediction task schemas named taskName for a datasetId # noqa: E501 + Get list of all the prediction task schemas for a datasetId at a specific predictionUUIDTimestamp. If no predictionUUIDTimestamp is set, it defaults to the newest # noqa: E501 This method makes a synchronous HTTP request by default. 
To make an asynchronous HTTP request, please pass async_req=True - >>> thread = api.get_prediction_task_schema_by_task_name(dataset_id, prediction_uuid_timestamp, task_name, async_req=True) + >>> thread = api.get_prediction_task_schemas_by_dataset_id(dataset_id, prediction_uuid_timestamp, async_req=True) >>> result = thread.get() :param dataset_id: ObjectId of the dataset (required) :type dataset_id: str - :param prediction_uuid_timestamp: The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. (required) + :param prediction_uuid_timestamp: Deprecated, currently ignored. The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. :type prediction_uuid_timestamp: int - :param task_name: The prediction task name for which one wants to list the predictions (required) - :type task_name: str :param async_req: Whether to execute the request asynchronously. :type async_req: bool, optional :param _request_timeout: timeout setting for this request. If one @@ -582,30 +580,28 @@ def get_prediction_task_schema_by_task_name(self, dataset_id : Annotated[constr( :return: Returns the result object. If the method is called asynchronously, returns the request thread. - :rtype: PredictionTaskSchema + :rtype: PredictionTaskSchemas """ kwargs['_return_http_data_only'] = True if '_preload_content' in kwargs: - raise ValueError("Error! 
Please call the get_prediction_task_schema_by_task_name_with_http_info method with `_preload_content` instead and obtain raw data from ApiResponse.raw_data") - return self.get_prediction_task_schema_by_task_name_with_http_info(dataset_id, prediction_uuid_timestamp, task_name, **kwargs) # noqa: E501 + raise ValueError("Error! Please call the get_prediction_task_schemas_by_dataset_id_with_http_info method with `_preload_content` instead and obtain raw data from ApiResponse.raw_data") + return self.get_prediction_task_schemas_by_dataset_id_with_http_info(dataset_id, prediction_uuid_timestamp, **kwargs) # noqa: E501 @validate_arguments - def get_prediction_task_schema_by_task_name_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], prediction_uuid_timestamp : Annotated[conint(strict=True, ge=0), Field(..., description="The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. ")], task_name : Annotated[constr(strict=True, min_length=1), Field(..., description="The prediction task name for which one wants to list the predictions")], **kwargs) -> ApiResponse: # noqa: E501 - """get_prediction_task_schema_by_task_name # noqa: E501 + def get_prediction_task_schemas_by_dataset_id_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], prediction_uuid_timestamp : Annotated[Optional[conint(strict=True, ge=0)], Field(description="Deprecated, currently ignored. The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. 
")] = None, **kwargs) -> ApiResponse: # noqa: E501 + """get_prediction_task_schemas_by_dataset_id # noqa: E501 - Get a prediction task schemas named taskName for a datasetId # noqa: E501 + Get list of all the prediction task schemas for a datasetId at a specific predictionUUIDTimestamp. If no predictionUUIDTimestamp is set, it defaults to the newest # noqa: E501 This method makes a synchronous HTTP request by default. To make an asynchronous HTTP request, please pass async_req=True - >>> thread = api.get_prediction_task_schema_by_task_name_with_http_info(dataset_id, prediction_uuid_timestamp, task_name, async_req=True) + >>> thread = api.get_prediction_task_schemas_by_dataset_id_with_http_info(dataset_id, prediction_uuid_timestamp, async_req=True) >>> result = thread.get() :param dataset_id: ObjectId of the dataset (required) :type dataset_id: str - :param prediction_uuid_timestamp: The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. (required) + :param prediction_uuid_timestamp: Deprecated, currently ignored. The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. :type prediction_uuid_timestamp: int - :param task_name: The prediction task name for which one wants to list the predictions (required) - :type task_name: str :param async_req: Whether to execute the request asynchronously. :type async_req: bool, optional :param _preload_content: if False, the ApiResponse.data will @@ -628,15 +624,14 @@ def get_prediction_task_schema_by_task_name_with_http_info(self, dataset_id : An :return: Returns the result object. 
If the method is called asynchronously, returns the request thread. - :rtype: tuple(PredictionTaskSchema, status_code(int), headers(HTTPHeaderDict)) + :rtype: tuple(PredictionTaskSchemas, status_code(int), headers(HTTPHeaderDict)) """ _params = locals() _all_params = [ 'dataset_id', - 'prediction_uuid_timestamp', - 'task_name' + 'prediction_uuid_timestamp' ] _all_params.extend( [ @@ -655,7 +650,7 @@ def get_prediction_task_schema_by_task_name_with_http_info(self, dataset_id : An if _key not in _all_params: raise ApiTypeError( "Got an unexpected keyword argument '%s'" - " to method get_prediction_task_schema_by_task_name" % _key + " to method get_prediction_task_schemas_by_dataset_id" % _key ) _params[_key] = _val del _params['kwargs'] @@ -667,9 +662,6 @@ def get_prediction_task_schema_by_task_name_with_http_info(self, dataset_id : An if _params['dataset_id']: _path_params['datasetId'] = _params['dataset_id'] - if _params['task_name']: - _path_params['taskName'] = _params['task_name'] - # process the query parameters _query_params = [] @@ -694,7 +686,7 @@ def get_prediction_task_schema_by_task_name_with_http_info(self, dataset_id : An _auth_settings = ['auth0Bearer', 'ApiKeyAuth'] # noqa: E501 _response_types_map = { - '200': "PredictionTaskSchema", + '200': "PredictionTaskSchemas", '400': "ApiErrorResponse", '401': "ApiErrorResponse", '403': "ApiErrorResponse", @@ -702,7 +694,7 @@ def get_prediction_task_schema_by_task_name_with_http_info(self, dataset_id : An } return self.api_client.call_api( - '/v1/datasets/{datasetId}/predictions/tasks/{taskName}', 'GET', + '/v1/datasets/{datasetId}/predictions/tasks', 'GET', _path_params, _query_params, _header_params, @@ -719,20 +711,22 @@ def get_prediction_task_schema_by_task_name_with_http_info(self, dataset_id : An _request_auth=_params.get('_request_auth')) @validate_arguments - def get_prediction_task_schemas_by_dataset_id(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the 
dataset")], prediction_uuid_timestamp : Annotated[Optional[conint(strict=True, ge=0)], Field(description="The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. ")] = None, **kwargs) -> PredictionTaskSchemas: # noqa: E501 - """get_prediction_task_schemas_by_dataset_id # noqa: E501 + def get_predictions_by_dataset_id(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], prediction_uuid_timestamp : Annotated[Optional[conint(strict=True, ge=0)], Field(description="Deprecated, currently ignored. The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. ")] = None, task_name : Annotated[Optional[constr(strict=True, min_length=1)], Field(description="If provided, only gets all prediction singletons of all samples of a dataset that were yielded by a specific prediction task name")] = None, **kwargs) -> List[List]: # noqa: E501 + """get_predictions_by_dataset_id # noqa: E501 - Get list of all the prediction task schemas for a datasetId at a specific predictionUUIDTimestamp. If no predictionUUIDTimestamp is set, it defaults to the newest # noqa: E501 + Get all prediction singletons of all samples of a dataset ordered by the sample mapping # noqa: E501 This method makes a synchronous HTTP request by default. 
To make an asynchronous HTTP request, please pass async_req=True - >>> thread = api.get_prediction_task_schemas_by_dataset_id(dataset_id, prediction_uuid_timestamp, async_req=True) + >>> thread = api.get_predictions_by_dataset_id(dataset_id, prediction_uuid_timestamp, task_name, async_req=True) >>> result = thread.get() :param dataset_id: ObjectId of the dataset (required) :type dataset_id: str - :param prediction_uuid_timestamp: The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. + :param prediction_uuid_timestamp: Deprecated, currently ignored. The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. :type prediction_uuid_timestamp: int + :param task_name: If provided, only gets all prediction singletons of all samples of a dataset that were yielded by a specific prediction task name + :type task_name: str :param async_req: Whether to execute the request asynchronously. :type async_req: bool, optional :param _request_timeout: timeout setting for this request. If one @@ -742,28 +736,30 @@ def get_prediction_task_schemas_by_dataset_id(self, dataset_id : Annotated[const :return: Returns the result object. If the method is called asynchronously, returns the request thread. - :rtype: PredictionTaskSchemas + :rtype: List[List] """ kwargs['_return_http_data_only'] = True if '_preload_content' in kwargs: - raise ValueError("Error! 
Please call the get_prediction_task_schemas_by_dataset_id_with_http_info method with `_preload_content` instead and obtain raw data from ApiResponse.raw_data") - return self.get_prediction_task_schemas_by_dataset_id_with_http_info(dataset_id, prediction_uuid_timestamp, **kwargs) # noqa: E501 + raise ValueError("Error! Please call the get_predictions_by_dataset_id_with_http_info method with `_preload_content` instead and obtain raw data from ApiResponse.raw_data") + return self.get_predictions_by_dataset_id_with_http_info(dataset_id, prediction_uuid_timestamp, task_name, **kwargs) # noqa: E501 @validate_arguments - def get_prediction_task_schemas_by_dataset_id_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], prediction_uuid_timestamp : Annotated[Optional[conint(strict=True, ge=0)], Field(description="The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. ")] = None, **kwargs) -> ApiResponse: # noqa: E501 - """get_prediction_task_schemas_by_dataset_id # noqa: E501 + def get_predictions_by_dataset_id_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], prediction_uuid_timestamp : Annotated[Optional[conint(strict=True, ge=0)], Field(description="Deprecated, currently ignored. The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. 
")] = None, task_name : Annotated[Optional[constr(strict=True, min_length=1)], Field(description="If provided, only gets all prediction singletons of all samples of a dataset that were yielded by a specific prediction task name")] = None, **kwargs) -> ApiResponse: # noqa: E501 + """get_predictions_by_dataset_id # noqa: E501 - Get list of all the prediction task schemas for a datasetId at a specific predictionUUIDTimestamp. If no predictionUUIDTimestamp is set, it defaults to the newest # noqa: E501 + Get all prediction singletons of all samples of a dataset ordered by the sample mapping # noqa: E501 This method makes a synchronous HTTP request by default. To make an asynchronous HTTP request, please pass async_req=True - >>> thread = api.get_prediction_task_schemas_by_dataset_id_with_http_info(dataset_id, prediction_uuid_timestamp, async_req=True) + >>> thread = api.get_predictions_by_dataset_id_with_http_info(dataset_id, prediction_uuid_timestamp, task_name, async_req=True) >>> result = thread.get() :param dataset_id: ObjectId of the dataset (required) :type dataset_id: str - :param prediction_uuid_timestamp: The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. + :param prediction_uuid_timestamp: Deprecated, currently ignored. The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. 
:type prediction_uuid_timestamp: int + :param task_name: If provided, only gets all prediction singletons of all samples of a dataset that were yielded by a specific prediction task name + :type task_name: str :param async_req: Whether to execute the request asynchronously. :type async_req: bool, optional :param _preload_content: if False, the ApiResponse.data will @@ -786,14 +782,15 @@ def get_prediction_task_schemas_by_dataset_id_with_http_info(self, dataset_id : :return: Returns the result object. If the method is called asynchronously, returns the request thread. - :rtype: tuple(PredictionTaskSchemas, status_code(int), headers(HTTPHeaderDict)) + :rtype: tuple(List[List], status_code(int), headers(HTTPHeaderDict)) """ _params = locals() _all_params = [ 'dataset_id', - 'prediction_uuid_timestamp' + 'prediction_uuid_timestamp', + 'task_name' ] _all_params.extend( [ @@ -812,7 +809,7 @@ def get_prediction_task_schemas_by_dataset_id_with_http_info(self, dataset_id : if _key not in _all_params: raise ApiTypeError( "Got an unexpected keyword argument '%s'" - " to method get_prediction_task_schemas_by_dataset_id" % _key + " to method get_predictions_by_dataset_id" % _key ) _params[_key] = _val del _params['kwargs'] @@ -833,6 +830,12 @@ def get_prediction_task_schemas_by_dataset_id_with_http_info(self, dataset_id : _params['prediction_uuid_timestamp'].value if hasattr(_params['prediction_uuid_timestamp'], 'value') else _params['prediction_uuid_timestamp'] )) + if _params.get('task_name') is not None: # noqa: E501 + _query_params.append(( + 'taskName', + _params['task_name'].value if hasattr(_params['task_name'], 'value') else _params['task_name'] + )) + # process the header parameters _header_params = dict(_params.get('_headers', {})) # process the form parameters @@ -848,7 +851,7 @@ def get_prediction_task_schemas_by_dataset_id_with_http_info(self, dataset_id : _auth_settings = ['auth0Bearer', 'ApiKeyAuth'] # noqa: E501 _response_types_map = { - '200': 
"PredictionTaskSchemas", + '200': "List[List]", '400': "ApiErrorResponse", '401': "ApiErrorResponse", '403': "ApiErrorResponse", @@ -856,7 +859,7 @@ def get_prediction_task_schemas_by_dataset_id_with_http_info(self, dataset_id : } return self.api_client.call_api( - '/v1/datasets/{datasetId}/predictions/tasks', 'GET', + '/v1/datasets/{datasetId}/predictions/samples', 'GET', _path_params, _query_params, _header_params, @@ -873,22 +876,22 @@ def get_prediction_task_schemas_by_dataset_id_with_http_info(self, dataset_id : _request_auth=_params.get('_request_auth')) @validate_arguments - def get_predictions_by_dataset_id(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], prediction_uuid_timestamp : Annotated[conint(strict=True, ge=0), Field(..., description="The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. ")], task_name : Annotated[Optional[constr(strict=True, min_length=1)], Field(description="If provided, only gets all prediction singletons of all samples of a dataset that were yielded by a specific prediction task name")] = None, **kwargs) -> List[List]: # noqa: E501 - """get_predictions_by_dataset_id # noqa: E501 + def get_predictions_by_sample_id(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], sample_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the sample")], prediction_uuid_timestamp : Annotated[Optional[conint(strict=True, ge=0)], Field(description="Deprecated, currently ignored. The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. 
One can then upload the new predictions to the same dataset. ")] = None, **kwargs) -> List[PredictionSingleton]: # noqa: E501 + """get_predictions_by_sample_id # noqa: E501 - Get all prediction singletons of all samples of a dataset ordered by the sample mapping # noqa: E501 + Get all prediction singletons of all tasks for a specific sample of a dataset # noqa: E501 This method makes a synchronous HTTP request by default. To make an asynchronous HTTP request, please pass async_req=True - >>> thread = api.get_predictions_by_dataset_id(dataset_id, prediction_uuid_timestamp, task_name, async_req=True) + >>> thread = api.get_predictions_by_sample_id(dataset_id, sample_id, prediction_uuid_timestamp, async_req=True) >>> result = thread.get() :param dataset_id: ObjectId of the dataset (required) :type dataset_id: str - :param prediction_uuid_timestamp: The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. (required) + :param sample_id: ObjectId of the sample (required) + :type sample_id: str + :param prediction_uuid_timestamp: Deprecated, currently ignored. The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. :type prediction_uuid_timestamp: int - :param task_name: If provided, only gets all prediction singletons of all samples of a dataset that were yielded by a specific prediction task name - :type task_name: str :param async_req: Whether to execute the request asynchronously. :type async_req: bool, optional :param _request_timeout: timeout setting for this request. 
If one @@ -898,30 +901,30 @@ def get_predictions_by_dataset_id(self, dataset_id : Annotated[constr(strict=Tru :return: Returns the result object. If the method is called asynchronously, returns the request thread. - :rtype: List[List] + :rtype: List[PredictionSingleton] """ kwargs['_return_http_data_only'] = True if '_preload_content' in kwargs: - raise ValueError("Error! Please call the get_predictions_by_dataset_id_with_http_info method with `_preload_content` instead and obtain raw data from ApiResponse.raw_data") - return self.get_predictions_by_dataset_id_with_http_info(dataset_id, prediction_uuid_timestamp, task_name, **kwargs) # noqa: E501 + raise ValueError("Error! Please call the get_predictions_by_sample_id_with_http_info method with `_preload_content` instead and obtain raw data from ApiResponse.raw_data") + return self.get_predictions_by_sample_id_with_http_info(dataset_id, sample_id, prediction_uuid_timestamp, **kwargs) # noqa: E501 @validate_arguments - def get_predictions_by_dataset_id_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], prediction_uuid_timestamp : Annotated[conint(strict=True, ge=0), Field(..., description="The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. 
")], task_name : Annotated[Optional[constr(strict=True, min_length=1)], Field(description="If provided, only gets all prediction singletons of all samples of a dataset that were yielded by a specific prediction task name")] = None, **kwargs) -> ApiResponse: # noqa: E501 - """get_predictions_by_dataset_id # noqa: E501 + def get_predictions_by_sample_id_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], sample_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the sample")], prediction_uuid_timestamp : Annotated[Optional[conint(strict=True, ge=0)], Field(description="Deprecated, currently ignored. The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. ")] = None, **kwargs) -> ApiResponse: # noqa: E501 + """get_predictions_by_sample_id # noqa: E501 - Get all prediction singletons of all samples of a dataset ordered by the sample mapping # noqa: E501 + Get all prediction singletons of all tasks for a specific sample of a dataset # noqa: E501 This method makes a synchronous HTTP request by default. To make an asynchronous HTTP request, please pass async_req=True - >>> thread = api.get_predictions_by_dataset_id_with_http_info(dataset_id, prediction_uuid_timestamp, task_name, async_req=True) + >>> thread = api.get_predictions_by_sample_id_with_http_info(dataset_id, sample_id, prediction_uuid_timestamp, async_req=True) >>> result = thread.get() :param dataset_id: ObjectId of the dataset (required) :type dataset_id: str - :param prediction_uuid_timestamp: The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. 
One can then upload the new predictions to the same dataset. (required) + :param sample_id: ObjectId of the sample (required) + :type sample_id: str + :param prediction_uuid_timestamp: Deprecated, currently ignored. The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. :type prediction_uuid_timestamp: int - :param task_name: If provided, only gets all prediction singletons of all samples of a dataset that were yielded by a specific prediction task name - :type task_name: str :param async_req: Whether to execute the request asynchronously. :type async_req: bool, optional :param _preload_content: if False, the ApiResponse.data will @@ -944,15 +947,15 @@ def get_predictions_by_dataset_id_with_http_info(self, dataset_id : Annotated[co :return: Returns the result object. If the method is called asynchronously, returns the request thread. 
- :rtype: tuple(List[List], status_code(int), headers(HTTPHeaderDict)) + :rtype: tuple(List[PredictionSingleton], status_code(int), headers(HTTPHeaderDict)) """ _params = locals() _all_params = [ 'dataset_id', - 'prediction_uuid_timestamp', - 'task_name' + 'sample_id', + 'prediction_uuid_timestamp' ] _all_params.extend( [ @@ -971,7 +974,7 @@ def get_predictions_by_dataset_id_with_http_info(self, dataset_id : Annotated[co if _key not in _all_params: raise ApiTypeError( "Got an unexpected keyword argument '%s'" - " to method get_predictions_by_dataset_id" % _key + " to method get_predictions_by_sample_id" % _key ) _params[_key] = _val del _params['kwargs'] @@ -983,6 +986,9 @@ def get_predictions_by_dataset_id_with_http_info(self, dataset_id : Annotated[co if _params['dataset_id']: _path_params['datasetId'] = _params['dataset_id'] + if _params['sample_id']: + _path_params['sampleId'] = _params['sample_id'] + # process the query parameters _query_params = [] @@ -992,12 +998,6 @@ def get_predictions_by_dataset_id_with_http_info(self, dataset_id : Annotated[co _params['prediction_uuid_timestamp'].value if hasattr(_params['prediction_uuid_timestamp'], 'value') else _params['prediction_uuid_timestamp'] )) - if _params.get('task_name') is not None: # noqa: E501 - _query_params.append(( - 'taskName', - _params['task_name'].value if hasattr(_params['task_name'], 'value') else _params['task_name'] - )) - # process the header parameters _header_params = dict(_params.get('_headers', {})) # process the form parameters @@ -1013,7 +1013,7 @@ def get_predictions_by_dataset_id_with_http_info(self, dataset_id : Annotated[co _auth_settings = ['auth0Bearer', 'ApiKeyAuth'] # noqa: E501 _response_types_map = { - '200': "List[List]", + '200': "List[PredictionSingleton]", '400': "ApiErrorResponse", '401': "ApiErrorResponse", '403': "ApiErrorResponse", @@ -1021,7 +1021,7 @@ def get_predictions_by_dataset_id_with_http_info(self, dataset_id : Annotated[co } return self.api_client.call_api( - 
'/v1/datasets/{datasetId}/predictions/samples', 'GET', + '/v1/datasets/{datasetId}/predictions/samples/{sampleId}', 'GET', _path_params, _query_params, _header_params, diff --git a/lightly/openapi_generated/swagger_client/api/samples_api.py b/lightly/openapi_generated/swagger_client/api/samples_api.py index 53f3ab74e..85be18c03 100644 --- a/lightly/openapi_generated/swagger_client/api/samples_api.py +++ b/lightly/openapi_generated/swagger_client/api/samples_api.py @@ -1153,7 +1153,7 @@ def get_sample_image_write_urls_by_id_with_http_info(self, dataset_id : Annotate _request_auth=_params.get('_request_auth')) @validate_arguments - def get_samples_by_dataset_id(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], file_name : Annotated[Optional[StrictStr], Field(description="filter the samples by filename")] = None, sort_by : Annotated[Optional[SampleSortBy], Field(description="sort the samples")] = None, page_size : Annotated[Optional[conint(strict=True, ge=1)], Field(description="pagination size/limit of the number of samples to return")] = None, page_offset : Annotated[Optional[conint(strict=True, ge=0)], Field(description="pagination offset")] = None, **kwargs) -> List[SampleData]: # noqa: E501 + def get_samples_by_dataset_id(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], file_name : Annotated[Optional[StrictStr], Field(description="DEPRECATED, use without and filter yourself - Filter the samples by filename")] = None, sort_by : Annotated[Optional[SampleSortBy], Field(description="sort the samples")] = None, page_size : Annotated[Optional[conint(strict=True, ge=1)], Field(description="pagination size/limit of the number of samples to return")] = None, page_offset : Annotated[Optional[conint(strict=True, ge=0)], Field(description="pagination offset")] = None, **kwargs) -> List[SampleData]: # noqa: E501 """get_samples_by_dataset_id # noqa: E501 Get all samples of 
a dataset # noqa: E501 @@ -1165,7 +1165,7 @@ def get_samples_by_dataset_id(self, dataset_id : Annotated[constr(strict=True), :param dataset_id: ObjectId of the dataset (required) :type dataset_id: str - :param file_name: filter the samples by filename + :param file_name: DEPRECATED, use without and filter yourself - Filter the samples by filename :type file_name: str :param sort_by: sort the samples :type sort_by: SampleSortBy @@ -1190,7 +1190,7 @@ def get_samples_by_dataset_id(self, dataset_id : Annotated[constr(strict=True), return self.get_samples_by_dataset_id_with_http_info(dataset_id, file_name, sort_by, page_size, page_offset, **kwargs) # noqa: E501 @validate_arguments - def get_samples_by_dataset_id_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], file_name : Annotated[Optional[StrictStr], Field(description="filter the samples by filename")] = None, sort_by : Annotated[Optional[SampleSortBy], Field(description="sort the samples")] = None, page_size : Annotated[Optional[conint(strict=True, ge=1)], Field(description="pagination size/limit of the number of samples to return")] = None, page_offset : Annotated[Optional[conint(strict=True, ge=0)], Field(description="pagination offset")] = None, **kwargs) -> ApiResponse: # noqa: E501 + def get_samples_by_dataset_id_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], file_name : Annotated[Optional[StrictStr], Field(description="DEPRECATED, use without and filter yourself - Filter the samples by filename")] = None, sort_by : Annotated[Optional[SampleSortBy], Field(description="sort the samples")] = None, page_size : Annotated[Optional[conint(strict=True, ge=1)], Field(description="pagination size/limit of the number of samples to return")] = None, page_offset : Annotated[Optional[conint(strict=True, ge=0)], Field(description="pagination offset")] = None, **kwargs) -> ApiResponse: # 
noqa: E501 """get_samples_by_dataset_id # noqa: E501 Get all samples of a dataset # noqa: E501 @@ -1202,7 +1202,7 @@ def get_samples_by_dataset_id_with_http_info(self, dataset_id : Annotated[constr :param dataset_id: ObjectId of the dataset (required) :type dataset_id: str - :param file_name: filter the samples by filename + :param file_name: DEPRECATED, use without and filter yourself - Filter the samples by filename :type file_name: str :param sort_by: sort the samples :type sort_by: SampleSortBy diff --git a/lightly/openapi_generated/swagger_client/api/scores_api.py b/lightly/openapi_generated/swagger_client/api/scores_api.py index 7c43e49cf..fadd3efcc 100644 --- a/lightly/openapi_generated/swagger_client/api/scores_api.py +++ b/lightly/openapi_generated/swagger_client/api/scores_api.py @@ -20,12 +20,14 @@ from pydantic import validate_arguments, ValidationError from typing_extensions import Annotated -from pydantic import Field, constr, validator +from pydantic import Field, conint, constr, validator -from typing import List +from typing import List, Optional from lightly.openapi_generated.swagger_client.models.active_learning_score_create_request import ActiveLearningScoreCreateRequest from lightly.openapi_generated.swagger_client.models.active_learning_score_data import ActiveLearningScoreData +from lightly.openapi_generated.swagger_client.models.active_learning_score_types_v2_data import ActiveLearningScoreTypesV2Data +from lightly.openapi_generated.swagger_client.models.active_learning_score_v2_data import ActiveLearningScoreV2Data from lightly.openapi_generated.swagger_client.models.create_entity_response import CreateEntityResponse from lightly.openapi_generated.swagger_client.models.tag_active_learning_scores_data import TagActiveLearningScoresData @@ -51,7 +53,7 @@ def __init__(self, api_client=None): @validate_arguments def create_or_update_active_learning_score_by_tag_id(self, dataset_id : Annotated[constr(strict=True), Field(..., 
description="ObjectId of the dataset")], tag_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the tag")], active_learning_score_create_request : ActiveLearningScoreCreateRequest, **kwargs) -> CreateEntityResponse: # noqa: E501 - """create_or_update_active_learning_score_by_tag_id # noqa: E501 + """(Deprecated) create_or_update_active_learning_score_by_tag_id # noqa: E501 Create or update active learning score object by tag id # noqa: E501 This method makes a synchronous HTTP request by default. To make an @@ -84,7 +86,7 @@ def create_or_update_active_learning_score_by_tag_id(self, dataset_id : Annotate @validate_arguments def create_or_update_active_learning_score_by_tag_id_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], tag_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the tag")], active_learning_score_create_request : ActiveLearningScoreCreateRequest, **kwargs) -> ApiResponse: # noqa: E501 - """create_or_update_active_learning_score_by_tag_id # noqa: E501 + """(Deprecated) create_or_update_active_learning_score_by_tag_id # noqa: E501 Create or update active learning score object by tag id # noqa: E501 This method makes a synchronous HTTP request by default. 
To make an @@ -124,6 +126,8 @@ def create_or_update_active_learning_score_by_tag_id_with_http_info(self, datase :rtype: tuple(CreateEntityResponse, status_code(int), headers(HTTPHeaderDict)) """ + warnings.warn("POST /v1/datasets/{datasetId}/tags/{tagId}/scores is deprecated.", DeprecationWarning) + _params = locals() _all_params = [ @@ -215,9 +219,189 @@ def create_or_update_active_learning_score_by_tag_id_with_http_info(self, datase collection_formats=_collection_formats, _request_auth=_params.get('_request_auth')) + @validate_arguments + def create_or_update_active_learning_v2_score_by_dataset_id(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], task_name : Annotated[constr(strict=True, min_length=1), Field(..., description="The prediction task name for which one wants to list the predictions")], active_learning_score_create_request : ActiveLearningScoreCreateRequest, prediction_uuid_timestamp : Annotated[Optional[conint(strict=True, ge=0)], Field(description="Deprecated, currently ignored. The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. ")] = None, **kwargs) -> CreateEntityResponse: # noqa: E501 + """create_or_update_active_learning_v2_score_by_dataset_id # noqa: E501 + + Create or update active learning score object for a dataset, taskName, predictionUUIDTimestamp # noqa: E501 + This method makes a synchronous HTTP request by default. 
To make an + asynchronous HTTP request, please pass async_req=True + + >>> thread = api.create_or_update_active_learning_v2_score_by_dataset_id(dataset_id, task_name, active_learning_score_create_request, prediction_uuid_timestamp, async_req=True) + >>> result = thread.get() + + :param dataset_id: ObjectId of the dataset (required) + :type dataset_id: str + :param task_name: The prediction task name for which one wants to list the predictions (required) + :type task_name: str + :param active_learning_score_create_request: (required) + :type active_learning_score_create_request: ActiveLearningScoreCreateRequest + :param prediction_uuid_timestamp: Deprecated, currently ignored. The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. + :type prediction_uuid_timestamp: int + :param async_req: Whether to execute the request asynchronously. + :type async_req: bool, optional + :param _request_timeout: timeout setting for this request. If one + number provided, it will be total request + timeout. It can also be a pair (tuple) of + (connection, read) timeouts. + :return: Returns the result object. + If the method is called asynchronously, + returns the request thread. + :rtype: CreateEntityResponse + """ + kwargs['_return_http_data_only'] = True + if '_preload_content' in kwargs: + raise ValueError("Error! 
Please call the create_or_update_active_learning_v2_score_by_dataset_id_with_http_info method with `_preload_content` instead and obtain raw data from ApiResponse.raw_data") + return self.create_or_update_active_learning_v2_score_by_dataset_id_with_http_info(dataset_id, task_name, active_learning_score_create_request, prediction_uuid_timestamp, **kwargs) # noqa: E501 + + @validate_arguments + def create_or_update_active_learning_v2_score_by_dataset_id_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], task_name : Annotated[constr(strict=True, min_length=1), Field(..., description="The prediction task name for which one wants to list the predictions")], active_learning_score_create_request : ActiveLearningScoreCreateRequest, prediction_uuid_timestamp : Annotated[Optional[conint(strict=True, ge=0)], Field(description="Deprecated, currently ignored. The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. ")] = None, **kwargs) -> ApiResponse: # noqa: E501 + """create_or_update_active_learning_v2_score_by_dataset_id # noqa: E501 + + Create or update active learning score object for a dataset, taskName, predictionUUIDTimestamp # noqa: E501 + This method makes a synchronous HTTP request by default. 
To make an + asynchronous HTTP request, please pass async_req=True + + >>> thread = api.create_or_update_active_learning_v2_score_by_dataset_id_with_http_info(dataset_id, task_name, active_learning_score_create_request, prediction_uuid_timestamp, async_req=True) + >>> result = thread.get() + + :param dataset_id: ObjectId of the dataset (required) + :type dataset_id: str + :param task_name: The prediction task name for which one wants to list the predictions (required) + :type task_name: str + :param active_learning_score_create_request: (required) + :type active_learning_score_create_request: ActiveLearningScoreCreateRequest + :param prediction_uuid_timestamp: Deprecated, currently ignored. The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. + :type prediction_uuid_timestamp: int + :param async_req: Whether to execute the request asynchronously. + :type async_req: bool, optional + :param _preload_content: if False, the ApiResponse.data will + be set to none and raw_data will store the + HTTP response body without reading/decoding. + Default is True. + :type _preload_content: bool, optional + :param _return_http_data_only: response data instead of ApiResponse + object with status code, headers, etc + :type _return_http_data_only: bool, optional + :param _request_timeout: timeout setting for this request. If one + number provided, it will be total request + timeout. It can also be a pair (tuple) of + (connection, read) timeouts. + :param _request_auth: set to override the auth_settings for an a single + request; this effectively ignores the authentication + in the spec for a single request. + :type _request_auth: dict, optional + :type _content_type: string, optional: force content-type for the request + :return: Returns the result object. 
+ If the method is called asynchronously, + returns the request thread. + :rtype: tuple(CreateEntityResponse, status_code(int), headers(HTTPHeaderDict)) + """ + + _params = locals() + + _all_params = [ + 'dataset_id', + 'task_name', + 'active_learning_score_create_request', + 'prediction_uuid_timestamp' + ] + _all_params.extend( + [ + 'async_req', + '_return_http_data_only', + '_preload_content', + '_request_timeout', + '_request_auth', + '_content_type', + '_headers' + ] + ) + + # validate the arguments + for _key, _val in _params['kwargs'].items(): + if _key not in _all_params: + raise ApiTypeError( + "Got an unexpected keyword argument '%s'" + " to method create_or_update_active_learning_v2_score_by_dataset_id" % _key + ) + _params[_key] = _val + del _params['kwargs'] + + _collection_formats = {} + + # process the path parameters + _path_params = {} + if _params['dataset_id']: + _path_params['datasetId'] = _params['dataset_id'] + + + # process the query parameters + _query_params = [] + if _params.get('task_name') is not None: # noqa: E501 + _query_params.append(( + 'taskName', + _params['task_name'].value if hasattr(_params['task_name'], 'value') else _params['task_name'] + )) + + if _params.get('prediction_uuid_timestamp') is not None: # noqa: E501 + _query_params.append(( + 'predictionUUIDTimestamp', + _params['prediction_uuid_timestamp'].value if hasattr(_params['prediction_uuid_timestamp'], 'value') else _params['prediction_uuid_timestamp'] + )) + + # process the header parameters + _header_params = dict(_params.get('_headers', {})) + # process the form parameters + _form_params = [] + _files = {} + # process the body parameter + _body_params = None + if _params['active_learning_score_create_request'] is not None: + _body_params = _params['active_learning_score_create_request'] + + # set the HTTP header `Accept` + _header_params['Accept'] = self.api_client.select_header_accept( + ['application/json']) # noqa: E501 + + # set the HTTP header `Content-Type` + 
_content_types_list = _params.get('_content_type', + self.api_client.select_header_content_type( + ['application/json'])) + if _content_types_list: + _header_params['Content-Type'] = _content_types_list + + # authentication setting + _auth_settings = ['auth0Bearer', 'ApiKeyAuth'] # noqa: E501 + + _response_types_map = { + '201': "CreateEntityResponse", + '400': "ApiErrorResponse", + '401': "ApiErrorResponse", + '403': "ApiErrorResponse", + '404': "ApiErrorResponse", + } + + return self.api_client.call_api( + '/v1/datasets/{datasetId}/predictions/scores', 'POST', + _path_params, + _query_params, + _header_params, + body=_body_params, + post_params=_form_params, + files=_files, + response_types_map=_response_types_map, + auth_settings=_auth_settings, + async_req=_params.get('async_req'), + _return_http_data_only=_params.get('_return_http_data_only'), # noqa: E501 + _preload_content=_params.get('_preload_content', True), + _request_timeout=_params.get('_request_timeout'), + collection_formats=_collection_formats, + _request_auth=_params.get('_request_auth')) + @validate_arguments def get_active_learning_score_by_score_id(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], tag_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the tag")], score_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the scores")], **kwargs) -> ActiveLearningScoreData: # noqa: E501 - """get_active_learning_score_by_score_id # noqa: E501 + """(Deprecated) get_active_learning_score_by_score_id # noqa: E501 Get active learning score object by id # noqa: E501 This method makes a synchronous HTTP request by default. 
To make an @@ -250,7 +434,7 @@ def get_active_learning_score_by_score_id(self, dataset_id : Annotated[constr(st @validate_arguments def get_active_learning_score_by_score_id_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], tag_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the tag")], score_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the scores")], **kwargs) -> ApiResponse: # noqa: E501 - """get_active_learning_score_by_score_id # noqa: E501 + """(Deprecated) get_active_learning_score_by_score_id # noqa: E501 Get active learning score object by id # noqa: E501 This method makes a synchronous HTTP request by default. To make an @@ -290,6 +474,8 @@ def get_active_learning_score_by_score_id_with_http_info(self, dataset_id : Anno :rtype: tuple(ActiveLearningScoreData, status_code(int), headers(HTTPHeaderDict)) """ + warnings.warn("GET /v1/datasets/{datasetId}/tags/{tagId}/scores/{scoreId} is deprecated.", DeprecationWarning) + _params = locals() _all_params = [ @@ -376,7 +562,7 @@ def get_active_learning_score_by_score_id_with_http_info(self, dataset_id : Anno @validate_arguments def get_active_learning_scores_by_tag_id(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], tag_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the tag")], **kwargs) -> List[TagActiveLearningScoresData]: # noqa: E501 - """get_active_learning_scores_by_tag_id # noqa: E501 + """(Deprecated) get_active_learning_scores_by_tag_id # noqa: E501 Get all scoreIds for the given tag # noqa: E501 This method makes a synchronous HTTP request by default. 
To make an @@ -407,7 +593,7 @@ def get_active_learning_scores_by_tag_id(self, dataset_id : Annotated[constr(str @validate_arguments def get_active_learning_scores_by_tag_id_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], tag_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the tag")], **kwargs) -> ApiResponse: # noqa: E501 - """get_active_learning_scores_by_tag_id # noqa: E501 + """(Deprecated) get_active_learning_scores_by_tag_id # noqa: E501 Get all scoreIds for the given tag # noqa: E501 This method makes a synchronous HTTP request by default. To make an @@ -445,6 +631,8 @@ def get_active_learning_scores_by_tag_id_with_http_info(self, dataset_id : Annot :rtype: tuple(List[TagActiveLearningScoresData], status_code(int), headers(HTTPHeaderDict)) """ + warnings.warn("GET /v1/datasets/{datasetId}/tags/{tagId}/scores is deprecated.", DeprecationWarning) + _params = locals() _all_params = [ @@ -524,3 +712,308 @@ def get_active_learning_scores_by_tag_id_with_http_info(self, dataset_id : Annot _request_timeout=_params.get('_request_timeout'), collection_formats=_collection_formats, _request_auth=_params.get('_request_auth')) + + @validate_arguments + def get_active_learning_v2_score_by_dataset_and_score_id(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], score_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the scores")], **kwargs) -> ActiveLearningScoreV2Data: # noqa: E501 + """get_active_learning_v2_score_by_dataset_and_score_id # noqa: E501 + + Get active learning scores by scoreId # noqa: E501 + This method makes a synchronous HTTP request by default. 
To make an + asynchronous HTTP request, please pass async_req=True + + >>> thread = api.get_active_learning_v2_score_by_dataset_and_score_id(dataset_id, score_id, async_req=True) + >>> result = thread.get() + + :param dataset_id: ObjectId of the dataset (required) + :type dataset_id: str + :param score_id: ObjectId of the scores (required) + :type score_id: str + :param async_req: Whether to execute the request asynchronously. + :type async_req: bool, optional + :param _request_timeout: timeout setting for this request. If one + number provided, it will be total request + timeout. It can also be a pair (tuple) of + (connection, read) timeouts. + :return: Returns the result object. + If the method is called asynchronously, + returns the request thread. + :rtype: ActiveLearningScoreV2Data + """ + kwargs['_return_http_data_only'] = True + if '_preload_content' in kwargs: + raise ValueError("Error! Please call the get_active_learning_v2_score_by_dataset_and_score_id_with_http_info method with `_preload_content` instead and obtain raw data from ApiResponse.raw_data") + return self.get_active_learning_v2_score_by_dataset_and_score_id_with_http_info(dataset_id, score_id, **kwargs) # noqa: E501 + + @validate_arguments + def get_active_learning_v2_score_by_dataset_and_score_id_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], score_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the scores")], **kwargs) -> ApiResponse: # noqa: E501 + """get_active_learning_v2_score_by_dataset_and_score_id # noqa: E501 + + Get active learning scores by scoreId # noqa: E501 + This method makes a synchronous HTTP request by default. 
To make an + asynchronous HTTP request, please pass async_req=True + + >>> thread = api.get_active_learning_v2_score_by_dataset_and_score_id_with_http_info(dataset_id, score_id, async_req=True) + >>> result = thread.get() + + :param dataset_id: ObjectId of the dataset (required) + :type dataset_id: str + :param score_id: ObjectId of the scores (required) + :type score_id: str + :param async_req: Whether to execute the request asynchronously. + :type async_req: bool, optional + :param _preload_content: if False, the ApiResponse.data will + be set to none and raw_data will store the + HTTP response body without reading/decoding. + Default is True. + :type _preload_content: bool, optional + :param _return_http_data_only: response data instead of ApiResponse + object with status code, headers, etc + :type _return_http_data_only: bool, optional + :param _request_timeout: timeout setting for this request. If one + number provided, it will be total request + timeout. It can also be a pair (tuple) of + (connection, read) timeouts. + :param _request_auth: set to override the auth_settings for an a single + request; this effectively ignores the authentication + in the spec for a single request. + :type _request_auth: dict, optional + :type _content_type: string, optional: force content-type for the request + :return: Returns the result object. + If the method is called asynchronously, + returns the request thread. 
+ :rtype: tuple(ActiveLearningScoreV2Data, status_code(int), headers(HTTPHeaderDict)) + """ + + _params = locals() + + _all_params = [ + 'dataset_id', + 'score_id' + ] + _all_params.extend( + [ + 'async_req', + '_return_http_data_only', + '_preload_content', + '_request_timeout', + '_request_auth', + '_content_type', + '_headers' + ] + ) + + # validate the arguments + for _key, _val in _params['kwargs'].items(): + if _key not in _all_params: + raise ApiTypeError( + "Got an unexpected keyword argument '%s'" + " to method get_active_learning_v2_score_by_dataset_and_score_id" % _key + ) + _params[_key] = _val + del _params['kwargs'] + + _collection_formats = {} + + # process the path parameters + _path_params = {} + if _params['dataset_id']: + _path_params['datasetId'] = _params['dataset_id'] + + if _params['score_id']: + _path_params['scoreId'] = _params['score_id'] + + + # process the query parameters + _query_params = [] + # process the header parameters + _header_params = dict(_params.get('_headers', {})) + # process the form parameters + _form_params = [] + _files = {} + # process the body parameter + _body_params = None + # set the HTTP header `Accept` + _header_params['Accept'] = self.api_client.select_header_accept( + ['application/json']) # noqa: E501 + + # authentication setting + _auth_settings = ['auth0Bearer', 'ApiKeyAuth'] # noqa: E501 + + _response_types_map = { + '200': "ActiveLearningScoreV2Data", + '400': "ApiErrorResponse", + '401': "ApiErrorResponse", + '403': "ApiErrorResponse", + '404': "ApiErrorResponse", + } + + return self.api_client.call_api( + '/v1/datasets/{datasetId}/predictions/scores/{scoreId}', 'GET', + _path_params, + _query_params, + _header_params, + body=_body_params, + post_params=_form_params, + files=_files, + response_types_map=_response_types_map, + auth_settings=_auth_settings, + async_req=_params.get('async_req'), + _return_http_data_only=_params.get('_return_http_data_only'), # noqa: E501 + 
_preload_content=_params.get('_preload_content', True), + _request_timeout=_params.get('_request_timeout'), + collection_formats=_collection_formats, + _request_auth=_params.get('_request_auth')) + + @validate_arguments + def get_active_learning_v2_scores_by_dataset_id(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], prediction_uuid_timestamp : Annotated[Optional[conint(strict=True, ge=0)], Field(description="Deprecated, currently ignored. The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. ")] = None, **kwargs) -> List[ActiveLearningScoreTypesV2Data]: # noqa: E501 + """get_active_learning_v2_scores_by_dataset_id # noqa: E501 + + Get all AL score types by datasetId and predictionUUIDTimestamp # noqa: E501 + This method makes a synchronous HTTP request by default. To make an + asynchronous HTTP request, please pass async_req=True + + >>> thread = api.get_active_learning_v2_scores_by_dataset_id(dataset_id, prediction_uuid_timestamp, async_req=True) + >>> result = thread.get() + + :param dataset_id: ObjectId of the dataset (required) + :type dataset_id: str + :param prediction_uuid_timestamp: Deprecated, currently ignored. The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. + :type prediction_uuid_timestamp: int + :param async_req: Whether to execute the request asynchronously. + :type async_req: bool, optional + :param _request_timeout: timeout setting for this request. If one + number provided, it will be total request + timeout. 
It can also be a pair (tuple) of + (connection, read) timeouts. + :return: Returns the result object. + If the method is called asynchronously, + returns the request thread. + :rtype: List[ActiveLearningScoreTypesV2Data] + """ + kwargs['_return_http_data_only'] = True + if '_preload_content' in kwargs: + raise ValueError("Error! Please call the get_active_learning_v2_scores_by_dataset_id_with_http_info method with `_preload_content` instead and obtain raw data from ApiResponse.raw_data") + return self.get_active_learning_v2_scores_by_dataset_id_with_http_info(dataset_id, prediction_uuid_timestamp, **kwargs) # noqa: E501 + + @validate_arguments + def get_active_learning_v2_scores_by_dataset_id_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], prediction_uuid_timestamp : Annotated[Optional[conint(strict=True, ge=0)], Field(description="Deprecated, currently ignored. The timestamp of when the actual predictions were created. This is used as a peg to version predictions. E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. ")] = None, **kwargs) -> ApiResponse: # noqa: E501 + """get_active_learning_v2_scores_by_dataset_id # noqa: E501 + + Get all AL score types by datasetId and predictionUUIDTimestamp # noqa: E501 + This method makes a synchronous HTTP request by default. To make an + asynchronous HTTP request, please pass async_req=True + + >>> thread = api.get_active_learning_v2_scores_by_dataset_id_with_http_info(dataset_id, prediction_uuid_timestamp, async_req=True) + >>> result = thread.get() + + :param dataset_id: ObjectId of the dataset (required) + :type dataset_id: str + :param prediction_uuid_timestamp: Deprecated, currently ignored. The timestamp of when the actual predictions were created. This is used as a peg to version predictions. 
E.g one could upload predictions on day 1 and then create new predictions with an improved model on day 30. One can then upload the new predictions to the same dataset. + :type prediction_uuid_timestamp: int + :param async_req: Whether to execute the request asynchronously. + :type async_req: bool, optional + :param _preload_content: if False, the ApiResponse.data will + be set to none and raw_data will store the + HTTP response body without reading/decoding. + Default is True. + :type _preload_content: bool, optional + :param _return_http_data_only: response data instead of ApiResponse + object with status code, headers, etc + :type _return_http_data_only: bool, optional + :param _request_timeout: timeout setting for this request. If one + number provided, it will be total request + timeout. It can also be a pair (tuple) of + (connection, read) timeouts. + :param _request_auth: set to override the auth_settings for an a single + request; this effectively ignores the authentication + in the spec for a single request. + :type _request_auth: dict, optional + :type _content_type: string, optional: force content-type for the request + :return: Returns the result object. + If the method is called asynchronously, + returns the request thread. 
+ :rtype: tuple(List[ActiveLearningScoreTypesV2Data], status_code(int), headers(HTTPHeaderDict)) + """ + + _params = locals() + + _all_params = [ + 'dataset_id', + 'prediction_uuid_timestamp' + ] + _all_params.extend( + [ + 'async_req', + '_return_http_data_only', + '_preload_content', + '_request_timeout', + '_request_auth', + '_content_type', + '_headers' + ] + ) + + # validate the arguments + for _key, _val in _params['kwargs'].items(): + if _key not in _all_params: + raise ApiTypeError( + "Got an unexpected keyword argument '%s'" + " to method get_active_learning_v2_scores_by_dataset_id" % _key + ) + _params[_key] = _val + del _params['kwargs'] + + _collection_formats = {} + + # process the path parameters + _path_params = {} + if _params['dataset_id']: + _path_params['datasetId'] = _params['dataset_id'] + + + # process the query parameters + _query_params = [] + if _params.get('prediction_uuid_timestamp') is not None: # noqa: E501 + _query_params.append(( + 'predictionUUIDTimestamp', + _params['prediction_uuid_timestamp'].value if hasattr(_params['prediction_uuid_timestamp'], 'value') else _params['prediction_uuid_timestamp'] + )) + + # process the header parameters + _header_params = dict(_params.get('_headers', {})) + # process the form parameters + _form_params = [] + _files = {} + # process the body parameter + _body_params = None + # set the HTTP header `Accept` + _header_params['Accept'] = self.api_client.select_header_accept( + ['application/json']) # noqa: E501 + + # authentication setting + _auth_settings = ['auth0Bearer', 'ApiKeyAuth'] # noqa: E501 + + _response_types_map = { + '200': "List[ActiveLearningScoreTypesV2Data]", + '400': "ApiErrorResponse", + '401': "ApiErrorResponse", + '403': "ApiErrorResponse", + '404': "ApiErrorResponse", + } + + return self.api_client.call_api( + '/v1/datasets/{datasetId}/predictions/scores', 'GET', + _path_params, + _query_params, + _header_params, + body=_body_params, + post_params=_form_params, + files=_files, + 
response_types_map=_response_types_map, + auth_settings=_auth_settings, + async_req=_params.get('async_req'), + _return_http_data_only=_params.get('_return_http_data_only'), # noqa: E501 + _preload_content=_params.get('_preload_content', True), + _request_timeout=_params.get('_request_timeout'), + collection_formats=_collection_formats, + _request_auth=_params.get('_request_auth')) diff --git a/lightly/openapi_generated/swagger_client/api/tags_api.py b/lightly/openapi_generated/swagger_client/api/tags_api.py index 4affcbccd..428c23c98 100644 --- a/lightly/openapi_generated/swagger_client/api/tags_api.py +++ b/lightly/openapi_generated/swagger_client/api/tags_api.py @@ -524,9 +524,9 @@ def delete_tag_by_tag_id_with_http_info(self, dataset_id : Annotated[constr(stri @validate_arguments def download_zip_of_samples_by_tag_id(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], tag_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the tag")], **kwargs) -> bytearray: # noqa: E501 - """download_zip_of_samples_by_tag_id # noqa: E501 + """(Deprecated) download_zip_of_samples_by_tag_id # noqa: E501 - Download a zip file of the samples of a tag. Limited to 1000 images # noqa: E501 + DEPRECATED - Download a zip file of the samples of a tag. Limited to 1000 images # noqa: E501 This method makes a synchronous HTTP request by default. 
To make an asynchronous HTTP request, please pass async_req=True @@ -555,9 +555,9 @@ def download_zip_of_samples_by_tag_id(self, dataset_id : Annotated[constr(strict @validate_arguments def download_zip_of_samples_by_tag_id_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], tag_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the tag")], **kwargs) -> ApiResponse: # noqa: E501 - """download_zip_of_samples_by_tag_id # noqa: E501 + """(Deprecated) download_zip_of_samples_by_tag_id # noqa: E501 - Download a zip file of the samples of a tag. Limited to 1000 images # noqa: E501 + DEPRECATED - Download a zip file of the samples of a tag. Limited to 1000 images # noqa: E501 This method makes a synchronous HTTP request by default. To make an asynchronous HTTP request, please pass async_req=True @@ -593,6 +593,8 @@ def download_zip_of_samples_by_tag_id_with_http_info(self, dataset_id : Annotate :rtype: tuple(bytearray, status_code(int), headers(HTTPHeaderDict)) """ + warnings.warn("GET /v1/datasets/{datasetId}/tags/{tagId}/export/zip is deprecated.", DeprecationWarning) + _params = locals() _all_params = [ @@ -1112,7 +1114,7 @@ def export_tag_to_basic_filenames_and_read_urls_with_http_info(self, dataset_id def export_tag_to_label_box_data_rows(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], tag_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the tag")], expires_in : Annotated[Optional[StrictInt], Field(description="If defined, the URLs provided will only be valid for amount of seconds from time of issuence. If not defined, the URls will be valid indefinitely. 
")] = None, access_control : Annotated[Optional[StrictStr], Field(description="which access control name to be used")] = None, file_name_format : Optional[FileNameFormat] = None, include_meta_data : Annotated[Optional[StrictBool], Field(description="if true, will also include metadata")] = None, format : Optional[FileOutputFormat] = None, preview_example : Annotated[Optional[StrictBool], Field(description="if true, will generate a preview example of how the structure will look")] = None, page_size : Annotated[Optional[conint(strict=True, ge=1)], Field(description="pagination size/limit of the number of samples to return")] = None, page_offset : Annotated[Optional[conint(strict=True, ge=0)], Field(description="pagination offset")] = None, **kwargs) -> List[LabelBoxDataRow]: # noqa: E501 """(Deprecated) export_tag_to_label_box_data_rows # noqa: E501 - Deprecated. Please use V4 unless there is a specific need to use the LabelBox V3 API. Export samples of a tag as a json for importing into LabelBox as outlined here; https://docs.labelbox.com/v3/reference/image ```openapi\\+warning The image URLs are special in that the resource can be accessed by anyone in possession of said URL for the time specified by the expiresIn query param ``` # noqa: E501 + DEPRECATED - Please use V4 unless there is a specific need to use the LabelBox V3 API. Export samples of a tag as a json for importing into LabelBox as outlined here; https://docs.labelbox.com/v3/reference/image ```openapi\\+warning The image URLs are special in that the resource can be accessed by anyone in possession of said URL for the time specified by the expiresIn query param ``` # noqa: E501 This method makes a synchronous HTTP request by default. 
To make an asynchronous HTTP request, please pass async_req=True @@ -1159,7 +1161,7 @@ def export_tag_to_label_box_data_rows(self, dataset_id : Annotated[constr(strict def export_tag_to_label_box_data_rows_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], tag_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the tag")], expires_in : Annotated[Optional[StrictInt], Field(description="If defined, the URLs provided will only be valid for amount of seconds from time of issuence. If not defined, the URls will be valid indefinitely. ")] = None, access_control : Annotated[Optional[StrictStr], Field(description="which access control name to be used")] = None, file_name_format : Optional[FileNameFormat] = None, include_meta_data : Annotated[Optional[StrictBool], Field(description="if true, will also include metadata")] = None, format : Optional[FileOutputFormat] = None, preview_example : Annotated[Optional[StrictBool], Field(description="if true, will generate a preview example of how the structure will look")] = None, page_size : Annotated[Optional[conint(strict=True, ge=1)], Field(description="pagination size/limit of the number of samples to return")] = None, page_offset : Annotated[Optional[conint(strict=True, ge=0)], Field(description="pagination offset")] = None, **kwargs) -> ApiResponse: # noqa: E501 """(Deprecated) export_tag_to_label_box_data_rows # noqa: E501 - Deprecated. Please use V4 unless there is a specific need to use the LabelBox V3 API. Export samples of a tag as a json for importing into LabelBox as outlined here; https://docs.labelbox.com/v3/reference/image ```openapi\\+warning The image URLs are special in that the resource can be accessed by anyone in possession of said URL for the time specified by the expiresIn query param ``` # noqa: E501 + DEPRECATED - Please use V4 unless there is a specific need to use the LabelBox V3 API. 
Export samples of a tag as a json for importing into LabelBox as outlined here; https://docs.labelbox.com/v3/reference/image ```openapi\\+warning The image URLs are special in that the resource can be accessed by anyone in possession of said URL for the time specified by the expiresIn query param ``` # noqa: E501 This method makes a synchronous HTTP request by default. To make an asynchronous HTTP request, please pass async_req=True @@ -2070,7 +2072,7 @@ def export_tag_to_sama_tasks_with_http_info(self, dataset_id : Annotated[constr( def get_filenames_by_tag_id(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], tag_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the tag")], **kwargs) -> List[str]: # noqa: E501 """(Deprecated) get_filenames_by_tag_id # noqa: E501 - Get list of filenames by tag. Deprecated, please use # noqa: E501 + DEPRECATED, please use exportTagToBasicFilenames - Get list of filenames by tag. # noqa: E501 This method makes a synchronous HTTP request by default. To make an asynchronous HTTP request, please pass async_req=True @@ -2101,7 +2103,7 @@ def get_filenames_by_tag_id(self, dataset_id : Annotated[constr(strict=True), Fi def get_filenames_by_tag_id_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], tag_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the tag")], **kwargs) -> ApiResponse: # noqa: E501 """(Deprecated) get_filenames_by_tag_id # noqa: E501 - Get list of filenames by tag. Deprecated, please use # noqa: E501 + DEPRECATED, please use exportTagToBasicFilenames - Get list of filenames by tag. # noqa: E501 This method makes a synchronous HTTP request by default. 
To make an asynchronous HTTP request, please pass async_req=True @@ -2675,7 +2677,7 @@ def perform_tag_arithmetics_with_http_info(self, dataset_id : Annotated[constr(s def perform_tag_arithmetics_bitmask(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], tag_arithmetics_request : TagArithmeticsRequest, **kwargs) -> TagBitMaskResponse: # noqa: E501 """(Deprecated) perform_tag_arithmetics_bitmask # noqa: E501 - Performs tag arithmetics to compute a new bitmask out of two existing tags. Does not create a new tag regardless if newTagName is provided # noqa: E501 + DEPRECATED, use performTagArithmetics - Performs tag arithmetics to compute a new bitmask out of two existing tags. Does not create a new tag regardless if newTagName is provided # noqa: E501 This method makes a synchronous HTTP request by default. To make an asynchronous HTTP request, please pass async_req=True @@ -2706,7 +2708,7 @@ def perform_tag_arithmetics_bitmask(self, dataset_id : Annotated[constr(strict=T def perform_tag_arithmetics_bitmask_with_http_info(self, dataset_id : Annotated[constr(strict=True), Field(..., description="ObjectId of the dataset")], tag_arithmetics_request : TagArithmeticsRequest, **kwargs) -> ApiResponse: # noqa: E501 """(Deprecated) perform_tag_arithmetics_bitmask # noqa: E501 - Performs tag arithmetics to compute a new bitmask out of two existing tags. Does not create a new tag regardless if newTagName is provided # noqa: E501 + DEPRECATED, use performTagArithmetics - Performs tag arithmetics to compute a new bitmask out of two existing tags. Does not create a new tag regardless if newTagName is provided # noqa: E501 This method makes a synchronous HTTP request by default. 
To make an asynchronous HTTP request, please pass async_req=True diff --git a/lightly/openapi_generated/swagger_client/api/teams_api.py b/lightly/openapi_generated/swagger_client/api/teams_api.py index 811099f83..7eed9b158 100644 --- a/lightly/openapi_generated/swagger_client/api/teams_api.py +++ b/lightly/openapi_generated/swagger_client/api/teams_api.py @@ -51,7 +51,7 @@ def __init__(self, api_client=None): self.api_client = api_client @validate_arguments - def add_team_member(self, team_id : Annotated[constr(strict=True), Field(..., description="id of the team")], create_team_membership_request : CreateTeamMembershipRequest, **kwargs) -> None: # noqa: E501 + def add_team_member(self, team_id : Annotated[constr(strict=True), Field(..., description="id of the team")], create_team_membership_request : CreateTeamMembershipRequest, **kwargs) -> str: # noqa: E501 """add_team_member # noqa: E501 Add a team member. One needs to be part of the team to do so. # noqa: E501 @@ -74,7 +74,7 @@ def add_team_member(self, team_id : Annotated[constr(strict=True), Field(..., de :return: Returns the result object. If the method is called asynchronously, returns the request thread. - :rtype: None + :rtype: str """ kwargs['_return_http_data_only'] = True if '_preload_content' in kwargs: @@ -118,7 +118,7 @@ def add_team_member_with_http_info(self, team_id : Annotated[constr(strict=True) :return: Returns the result object. If the method is called asynchronously, returns the request thread. 
- :rtype: None + :rtype: tuple(str, status_code(int), headers(HTTPHeaderDict)) """ _params = locals() @@ -171,7 +171,7 @@ def add_team_member_with_http_info(self, team_id : Annotated[constr(strict=True) # set the HTTP header `Accept` _header_params['Accept'] = self.api_client.select_header_accept( - ['application/json']) # noqa: E501 + ['text/plain', 'application/json']) # noqa: E501 # set the HTTP header `Content-Type` _content_types_list = _params.get('_content_type', @@ -183,7 +183,13 @@ def add_team_member_with_http_info(self, team_id : Annotated[constr(strict=True) # authentication setting _auth_settings = ['auth0Bearer', 'ApiKeyAuth'] # noqa: E501 - _response_types_map = {} + _response_types_map = { + '200': "str", + '400': "ApiErrorResponse", + '401': "ApiErrorResponse", + '403': "ApiErrorResponse", + '404': "ApiErrorResponse", + } return self.api_client.call_api( '/v1/teams/{teamId}/members', 'POST', diff --git a/lightly/openapi_generated/swagger_client/models/__init__.py b/lightly/openapi_generated/swagger_client/models/__init__.py index 1c47427c7..fd20c0edc 100644 --- a/lightly/openapi_generated/swagger_client/models/__init__.py +++ b/lightly/openapi_generated/swagger_client/models/__init__.py @@ -17,6 +17,8 @@ # import models into model package from lightly.openapi_generated.swagger_client.models.active_learning_score_create_request import ActiveLearningScoreCreateRequest from lightly.openapi_generated.swagger_client.models.active_learning_score_data import ActiveLearningScoreData +from lightly.openapi_generated.swagger_client.models.active_learning_score_types_v2_data import ActiveLearningScoreTypesV2Data +from lightly.openapi_generated.swagger_client.models.active_learning_score_v2_data import ActiveLearningScoreV2Data from lightly.openapi_generated.swagger_client.models.api_error_code import ApiErrorCode from lightly.openapi_generated.swagger_client.models.api_error_response import ApiErrorResponse from 
lightly.openapi_generated.swagger_client.models.async_task_data import AsyncTaskData @@ -42,10 +44,12 @@ from lightly.openapi_generated.swagger_client.models.datasource_config_azure import DatasourceConfigAzure from lightly.openapi_generated.swagger_client.models.datasource_config_azure_all_of import DatasourceConfigAzureAllOf from lightly.openapi_generated.swagger_client.models.datasource_config_base import DatasourceConfigBase +from lightly.openapi_generated.swagger_client.models.datasource_config_base_full_path import DatasourceConfigBaseFullPath from lightly.openapi_generated.swagger_client.models.datasource_config_gcs import DatasourceConfigGCS from lightly.openapi_generated.swagger_client.models.datasource_config_gcs_all_of import DatasourceConfigGCSAllOf from lightly.openapi_generated.swagger_client.models.datasource_config_lightly import DatasourceConfigLIGHTLY from lightly.openapi_generated.swagger_client.models.datasource_config_local import DatasourceConfigLOCAL +from lightly.openapi_generated.swagger_client.models.datasource_config_local_all_of import DatasourceConfigLOCALAllOf from lightly.openapi_generated.swagger_client.models.datasource_config_obs import DatasourceConfigOBS from lightly.openapi_generated.swagger_client.models.datasource_config_obs_all_of import DatasourceConfigOBSAllOf from lightly.openapi_generated.swagger_client.models.datasource_config_s3 import DatasourceConfigS3 @@ -180,13 +184,23 @@ from lightly.openapi_generated.swagger_client.models.sampling_method import SamplingMethod from lightly.openapi_generated.swagger_client.models.sector import Sector from lightly.openapi_generated.swagger_client.models.selection_config import SelectionConfig +from lightly.openapi_generated.swagger_client.models.selection_config_all_of import SelectionConfigAllOf +from lightly.openapi_generated.swagger_client.models.selection_config_base import SelectionConfigBase from lightly.openapi_generated.swagger_client.models.selection_config_entry import 
SelectionConfigEntry from lightly.openapi_generated.swagger_client.models.selection_config_entry_input import SelectionConfigEntryInput from lightly.openapi_generated.swagger_client.models.selection_config_entry_strategy import SelectionConfigEntryStrategy +from lightly.openapi_generated.swagger_client.models.selection_config_v3 import SelectionConfigV3 +from lightly.openapi_generated.swagger_client.models.selection_config_v3_all_of import SelectionConfigV3AllOf +from lightly.openapi_generated.swagger_client.models.selection_config_v3_entry import SelectionConfigV3Entry +from lightly.openapi_generated.swagger_client.models.selection_config_v3_entry_input import SelectionConfigV3EntryInput +from lightly.openapi_generated.swagger_client.models.selection_config_v3_entry_strategy import SelectionConfigV3EntryStrategy +from lightly.openapi_generated.swagger_client.models.selection_config_v3_entry_strategy_all_of import SelectionConfigV3EntryStrategyAllOf +from lightly.openapi_generated.swagger_client.models.selection_config_v3_entry_strategy_all_of_target_range import SelectionConfigV3EntryStrategyAllOfTargetRange from lightly.openapi_generated.swagger_client.models.selection_input_predictions_name import SelectionInputPredictionsName from lightly.openapi_generated.swagger_client.models.selection_input_type import SelectionInputType from lightly.openapi_generated.swagger_client.models.selection_strategy_threshold_operation import SelectionStrategyThresholdOperation from lightly.openapi_generated.swagger_client.models.selection_strategy_type import SelectionStrategyType +from lightly.openapi_generated.swagger_client.models.selection_strategy_type_v3 import SelectionStrategyTypeV3 from lightly.openapi_generated.swagger_client.models.service_account_basic_data import ServiceAccountBasicData from lightly.openapi_generated.swagger_client.models.set_embeddings_is_processed_flag_by_id_body_request import SetEmbeddingsIsProcessedFlagByIdBodyRequest from 
class ActiveLearningScoreTypesV2Data(BaseModel):
    """
    ActiveLearningScoreTypesV2Data

    Pydantic model with the fields ``id``, ``dataset_id``,
    ``prediction_uuid_timestamp``, ``task_name``, ``score_type`` and
    ``created_at``. The JSON representation uses the camelCase aliases
    declared on the fields; unknown fields are rejected.
    """
    id: constr(strict=True) = Field(..., description="MongoDB ObjectId")
    dataset_id: constr(strict=True) = Field(..., alias="datasetId", description="MongoDB ObjectId")
    prediction_uuid_timestamp: conint(strict=True, ge=0) = Field(..., alias="predictionUUIDTimestamp", description="unix timestamp in milliseconds")
    task_name: constr(strict=True, min_length=1) = Field(..., alias="taskName", description="A name which is safe to have as a file/folder name in a file system")
    score_type: constr(strict=True, min_length=1) = Field(..., alias="scoreType", description="Type of active learning score")
    created_at: conint(strict=True, ge=0) = Field(..., alias="createdAt", description="unix timestamp in milliseconds")
    # JSON (alias) keys accepted by from_dict(); anything else is rejected there.
    __properties = ["id", "datasetId", "predictionUUIDTimestamp", "taskName", "scoreType", "createdAt"]

    # NOTE: the validators below use re.fullmatch instead of re.match. With
    # re.match, the trailing `$` also matches just before a final newline, so
    # a value such as "<24 hex chars>\n" would wrongly pass validation.

    @validator('id')
    def id_validate_regular_expression(cls, value):
        """Validates the regular expression"""
        if not re.fullmatch(r"^[a-f0-9]{24}$", value):
            raise ValueError(r"must validate the regular expression /^[a-f0-9]{24}$/")
        return value

    @validator('dataset_id')
    def dataset_id_validate_regular_expression(cls, value):
        """Validates the regular expression"""
        if not re.fullmatch(r"^[a-f0-9]{24}$", value):
            raise ValueError(r"must validate the regular expression /^[a-f0-9]{24}$/")
        return value

    @validator('task_name')
    def task_name_validate_regular_expression(cls, value):
        """Validates the regular expression"""
        if not re.fullmatch(r"^[a-zA-Z0-9][a-zA-Z0-9 ._-]+$", value):
            raise ValueError(r"must validate the regular expression /^[a-zA-Z0-9][a-zA-Z0-9 ._-]+$/")
        return value

    @validator('score_type')
    def score_type_validate_regular_expression(cls, value):
        """Validates the regular expression"""
        if not re.fullmatch(r"^[a-zA-Z0-9_+=,.@:\/-]*$", value):
            raise ValueError(r"must validate the regular expression /^[a-zA-Z0-9_+=,.@:\/-]*$/")
        return value

    class Config:
        """Pydantic configuration"""
        allow_population_by_field_name = True
        validate_assignment = True
        use_enum_values = True
        extra = Extra.forbid

    def to_str(self, by_alias: bool = False) -> str:
        """Returns the string representation of the model"""
        return pprint.pformat(self.dict(by_alias=by_alias))

    def to_json(self, by_alias: bool = False) -> str:
        """Returns the JSON representation of the model"""
        return json.dumps(self.to_dict(by_alias=by_alias))

    @classmethod
    def from_json(cls, json_str: str) -> ActiveLearningScoreTypesV2Data:
        """Create an instance of ActiveLearningScoreTypesV2Data from a JSON string"""
        return cls.from_dict(json.loads(json_str))

    def to_dict(self, by_alias: bool = False):
        """Returns the dictionary representation of the model

        Fields whose value is None are omitted from the result.
        """
        _dict = self.dict(by_alias=by_alias,
                          exclude={
                          },
                          exclude_none=True)
        return _dict

    @classmethod
    def from_dict(cls, obj: dict) -> ActiveLearningScoreTypesV2Data:
        """Create an instance of ActiveLearningScoreTypesV2Data from a dict

        Returns None when ``obj`` is None. A non-dict ``obj`` is handed to
        ``parse_obj`` unchanged. A dict must use the JSON (alias) keys; any
        other key raises ValueError.
        """
        if obj is None:
            return None

        if not isinstance(obj, dict):
            return ActiveLearningScoreTypesV2Data.parse_obj(obj)

        # raise errors for additional fields in the input
        for _key in obj.keys():
            if _key not in cls.__properties:
                raise ValueError("Error due to additional fields (not defined in ActiveLearningScoreTypesV2Data) in the input: " + str(obj))

        _obj = ActiveLearningScoreTypesV2Data.parse_obj({
            "id": obj.get("id"),
            "dataset_id": obj.get("datasetId"),
            "prediction_uuid_timestamp": obj.get("predictionUUIDTimestamp"),
            "task_name": obj.get("taskName"),
            "score_type": obj.get("scoreType"),
            "created_at": obj.get("createdAt")
        })
        return _obj
class ActiveLearningScoreV2Data(BaseModel):
    """
    ActiveLearningScoreV2Data

    Pydantic model with the fields ``id``, ``dataset_id``,
    ``prediction_uuid_timestamp``, ``task_name``, ``score_type``, ``scores``
    and ``created_at``. The JSON representation uses the camelCase aliases
    declared on the fields; unknown fields are rejected.
    """
    id: constr(strict=True) = Field(..., description="MongoDB ObjectId")
    dataset_id: constr(strict=True) = Field(..., alias="datasetId", description="MongoDB ObjectId")
    prediction_uuid_timestamp: conint(strict=True, ge=0) = Field(..., alias="predictionUUIDTimestamp", description="unix timestamp in milliseconds")
    task_name: constr(strict=True, min_length=1) = Field(..., alias="taskName", description="A name which is safe to have as a file/folder name in a file system")
    score_type: constr(strict=True, min_length=1) = Field(..., alias="scoreType", description="Type of active learning score")
    scores: conlist(Union[StrictFloat, StrictInt], min_items=1) = Field(..., description="Array of active learning scores")
    created_at: conint(strict=True, ge=0) = Field(..., alias="createdAt", description="unix timestamp in milliseconds")
    # JSON (alias) keys accepted by from_dict(); anything else is rejected there.
    __properties = ["id", "datasetId", "predictionUUIDTimestamp", "taskName", "scoreType", "scores", "createdAt"]

    # NOTE: the validators below use re.fullmatch instead of re.match. With
    # re.match, the trailing `$` also matches just before a final newline, so
    # a value such as "<24 hex chars>\n" would wrongly pass validation.

    @validator('id')
    def id_validate_regular_expression(cls, value):
        """Validates the regular expression"""
        if not re.fullmatch(r"^[a-f0-9]{24}$", value):
            raise ValueError(r"must validate the regular expression /^[a-f0-9]{24}$/")
        return value

    @validator('dataset_id')
    def dataset_id_validate_regular_expression(cls, value):
        """Validates the regular expression"""
        if not re.fullmatch(r"^[a-f0-9]{24}$", value):
            raise ValueError(r"must validate the regular expression /^[a-f0-9]{24}$/")
        return value

    @validator('task_name')
    def task_name_validate_regular_expression(cls, value):
        """Validates the regular expression"""
        if not re.fullmatch(r"^[a-zA-Z0-9][a-zA-Z0-9 ._-]+$", value):
            raise ValueError(r"must validate the regular expression /^[a-zA-Z0-9][a-zA-Z0-9 ._-]+$/")
        return value

    @validator('score_type')
    def score_type_validate_regular_expression(cls, value):
        """Validates the regular expression"""
        if not re.fullmatch(r"^[a-zA-Z0-9_+=,.@:\/-]*$", value):
            raise ValueError(r"must validate the regular expression /^[a-zA-Z0-9_+=,.@:\/-]*$/")
        return value

    class Config:
        """Pydantic configuration"""
        allow_population_by_field_name = True
        validate_assignment = True
        use_enum_values = True
        extra = Extra.forbid

    def to_str(self, by_alias: bool = False) -> str:
        """Returns the string representation of the model"""
        return pprint.pformat(self.dict(by_alias=by_alias))

    def to_json(self, by_alias: bool = False) -> str:
        """Returns the JSON representation of the model"""
        return json.dumps(self.to_dict(by_alias=by_alias))

    @classmethod
    def from_json(cls, json_str: str) -> ActiveLearningScoreV2Data:
        """Create an instance of ActiveLearningScoreV2Data from a JSON string"""
        return cls.from_dict(json.loads(json_str))

    def to_dict(self, by_alias: bool = False):
        """Returns the dictionary representation of the model

        Fields whose value is None are omitted from the result.
        """
        _dict = self.dict(by_alias=by_alias,
                          exclude={
                          },
                          exclude_none=True)
        return _dict

    @classmethod
    def from_dict(cls, obj: dict) -> ActiveLearningScoreV2Data:
        """Create an instance of ActiveLearningScoreV2Data from a dict

        Returns None when ``obj`` is None. A non-dict ``obj`` is handed to
        ``parse_obj`` unchanged. A dict must use the JSON (alias) keys; any
        other key raises ValueError.
        """
        if obj is None:
            return None

        if not isinstance(obj, dict):
            return ActiveLearningScoreV2Data.parse_obj(obj)

        # raise errors for additional fields in the input
        for _key in obj.keys():
            if _key not in cls.__properties:
                raise ValueError("Error due to additional fields (not defined in ActiveLearningScoreV2Data) in the input: " + str(obj))

        _obj = ActiveLearningScoreV2Data.parse_obj({
            "id": obj.get("id"),
            "dataset_id": obj.get("datasetId"),
            "prediction_uuid_timestamp": obj.get("predictionUUIDTimestamp"),
            "task_name": obj.get("taskName"),
            "score_type": obj.get("scoreType"),
            "scores": obj.get("scores"),
            "created_at": obj.get("createdAt")
        })
        return _obj
- offer: Optional[AnnotationOfferData] = None - __properties = ["_id", "state", "datasetId", "tagId", "partnerId", "createdAt", "lastModifiedAt", "meta", "offer"] - - class Config: - """Pydantic configuration""" - allow_population_by_field_name = True - validate_assignment = True - use_enum_values = True - extra = Extra.forbid - - def to_str(self, by_alias: bool = False) -> str: - """Returns the string representation of the model""" - return pprint.pformat(self.dict(by_alias=by_alias)) - - def to_json(self, by_alias: bool = False) -> str: - """Returns the JSON representation of the model""" - return json.dumps(self.to_dict(by_alias=by_alias)) - - @classmethod - def from_json(cls, json_str: str) -> AnnotationData: - """Create an instance of AnnotationData from a JSON string""" - return cls.from_dict(json.loads(json_str)) - - def to_dict(self, by_alias: bool = False): - """Returns the dictionary representation of the model""" - _dict = self.dict(by_alias=by_alias, - exclude={ - }, - exclude_none=True) - # override the default output from pydantic by calling `to_dict()` of meta - if self.meta: - _dict['meta' if by_alias else 'meta'] = self.meta.to_dict(by_alias=by_alias) - # override the default output from pydantic by calling `to_dict()` of offer - if self.offer: - _dict['offer' if by_alias else 'offer'] = self.offer.to_dict(by_alias=by_alias) - return _dict - - @classmethod - def from_dict(cls, obj: dict) -> AnnotationData: - """Create an instance of AnnotationData from a dict""" - if obj is None: - return None - - if not isinstance(obj, dict): - return AnnotationData.parse_obj(obj) - - # raise errors for additional fields in the input - for _key in obj.keys(): - if _key not in cls.__properties: - raise ValueError("Error due to additional fields (not defined in AnnotationData) in the input: " + str(obj)) - - _obj = AnnotationData.parse_obj({ - "id": obj.get("_id"), - "state": obj.get("state"), - "dataset_id": obj.get("datasetId"), - "tag_id": obj.get("tagId"), - 
"partner_id": obj.get("partnerId"), - "created_at": obj.get("createdAt"), - "last_modified_at": obj.get("lastModifiedAt"), - "meta": AnnotationMetaData.from_dict(obj.get("meta")) if obj.get("meta") is not None else None, - "offer": AnnotationOfferData.from_dict(obj.get("offer")) if obj.get("offer") is not None else None - }) - return _obj - diff --git a/lightly/openapi_generated/swagger_client/models/api_error_code.py b/lightly/openapi_generated/swagger_client/models/api_error_code.py index 4876a203d..6e2b5dd56 100644 --- a/lightly/openapi_generated/swagger_client/models/api_error_code.py +++ b/lightly/openapi_generated/swagger_client/models/api_error_code.py @@ -37,6 +37,7 @@ class ApiErrorCode(str, Enum): UNAUTHORIZED = 'UNAUTHORIZED' NOT_FOUND = 'NOT_FOUND' NOT_MODIFIED = 'NOT_MODIFIED' + CONFLICT = 'CONFLICT' MALFORMED_REQUEST = 'MALFORMED_REQUEST' MALFORMED_RESPONSE = 'MALFORMED_RESPONSE' PAYLOAD_TOO_LARGE = 'PAYLOAD_TOO_LARGE' @@ -104,6 +105,7 @@ class ApiErrorCode(str, Enum): DOCKER_WORKER_SCHEDULE_UPDATE_FAILED = 'DOCKER_WORKER_SCHEDULE_UPDATE_FAILED' METADATA_CONFIGURATION_UNKNOWN = 'METADATA_CONFIGURATION_UNKNOWN' CUSTOM_METADATA_AT_MAX_SIZE = 'CUSTOM_METADATA_AT_MAX_SIZE' + ONPREM_SUBSCRIPTION_INSUFFICIENT = 'ONPREM_SUBSCRIPTION_INSUFFICIENT' ACCOUNT_SUBSCRIPTION_INSUFFICIENT = 'ACCOUNT_SUBSCRIPTION_INSUFFICIENT' TEAM_UNKNOWN = 'TEAM_UNKNOWN' diff --git a/lightly/openapi_generated/swagger_client/models/datasource_config_azure.py b/lightly/openapi_generated/swagger_client/models/datasource_config_azure.py index 9f9bf4dec..6214b0725 100644 --- a/lightly/openapi_generated/swagger_client/models/datasource_config_azure.py +++ b/lightly/openapi_generated/swagger_client/models/datasource_config_azure.py @@ -20,16 +20,17 @@ -from pydantic import Extra, BaseModel, Field, constr +from pydantic import Extra, BaseModel, Field, StrictStr, constr from lightly.openapi_generated.swagger_client.models.datasource_config_base import DatasourceConfigBase class 
DatasourceConfigAzure(DatasourceConfigBase): """ DatasourceConfigAzure """ + full_path: StrictStr = Field(..., alias="fullPath", description="path includes the bucket name and the path within the bucket where you have stored your information") account_name: constr(strict=True, min_length=1) = Field(..., alias="accountName", description="name of the Azure Storage Account") account_key: constr(strict=True, min_length=1) = Field(..., alias="accountKey", description="key of the Azure Storage Account") - __properties = ["id", "purpose", "type", "fullPath", "thumbSuffix", "accountName", "accountKey"] + __properties = ["id", "purpose", "type", "thumbSuffix", "fullPath", "accountName", "accountKey"] class Config: """Pydantic configuration""" @@ -77,8 +78,8 @@ def from_dict(cls, obj: dict) -> DatasourceConfigAzure: "id": obj.get("id"), "purpose": obj.get("purpose"), "type": obj.get("type"), - "full_path": obj.get("fullPath"), "thumb_suffix": obj.get("thumbSuffix"), + "full_path": obj.get("fullPath"), "account_name": obj.get("accountName"), "account_key": obj.get("accountKey") }) diff --git a/lightly/openapi_generated/swagger_client/models/datasource_config_base.py b/lightly/openapi_generated/swagger_client/models/datasource_config_base.py index d3a5f3983..41c71cce3 100644 --- a/lightly/openapi_generated/swagger_client/models/datasource_config_base.py +++ b/lightly/openapi_generated/swagger_client/models/datasource_config_base.py @@ -31,9 +31,8 @@ class DatasourceConfigBase(BaseModel): id: Optional[constr(strict=True)] = Field(None, description="MongoDB ObjectId") purpose: DatasourcePurpose = Field(...) type: StrictStr = Field(...) - full_path: StrictStr = Field(..., alias="fullPath", description="path includes the bucket name and the path within the bucket where you have stored your information") thumb_suffix: Optional[StrictStr] = Field(None, alias="thumbSuffix", description="the suffix of where to find the thumbnail image. 
If none is provided, the full image will be loaded where thumbnails would be loaded otherwise. - [filename]: represents the filename without the extension - [extension]: represents the files extension (e.g jpg, png, webp) ") - __properties = ["id", "purpose", "type", "fullPath", "thumbSuffix"] + __properties = ["id", "purpose", "type", "thumbSuffix"] @validator('id') def id_validate_regular_expression(cls, value): diff --git a/lightly/openapi_generated/swagger_client/models/annotation_meta_data.py b/lightly/openapi_generated/swagger_client/models/datasource_config_base_full_path.py similarity index 67% rename from lightly/openapi_generated/swagger_client/models/annotation_meta_data.py rename to lightly/openapi_generated/swagger_client/models/datasource_config_base_full_path.py index eb0070b1e..eace63453 100644 --- a/lightly/openapi_generated/swagger_client/models/annotation_meta_data.py +++ b/lightly/openapi_generated/swagger_client/models/datasource_config_base_full_path.py @@ -19,15 +19,15 @@ import json -from typing import Optional -from pydantic import Extra, BaseModel, StrictStr -class AnnotationMetaData(BaseModel): +from pydantic import Extra, BaseModel, Field, StrictStr + +class DatasourceConfigBaseFullPath(BaseModel): """ - AnnotationMetaData + DatasourceConfigBaseFullPath """ - description: Optional[StrictStr] = None - __properties = ["description"] + full_path: StrictStr = Field(..., alias="fullPath", description="path includes the bucket name and the path within the bucket where you have stored your information") + __properties = ["fullPath"] class Config: """Pydantic configuration""" @@ -45,8 +45,8 @@ def to_json(self, by_alias: bool = False) -> str: return json.dumps(self.to_dict(by_alias=by_alias)) @classmethod - def from_json(cls, json_str: str) -> AnnotationMetaData: - """Create an instance of AnnotationMetaData from a JSON string""" + def from_json(cls, json_str: str) -> DatasourceConfigBaseFullPath: + """Create an instance of 
DatasourceConfigBaseFullPath from a JSON string""" return cls.from_dict(json.loads(json_str)) def to_dict(self, by_alias: bool = False): @@ -58,21 +58,21 @@ def to_dict(self, by_alias: bool = False): return _dict @classmethod - def from_dict(cls, obj: dict) -> AnnotationMetaData: - """Create an instance of AnnotationMetaData from a dict""" + def from_dict(cls, obj: dict) -> DatasourceConfigBaseFullPath: + """Create an instance of DatasourceConfigBaseFullPath from a dict""" if obj is None: return None if not isinstance(obj, dict): - return AnnotationMetaData.parse_obj(obj) + return DatasourceConfigBaseFullPath.parse_obj(obj) # raise errors for additional fields in the input for _key in obj.keys(): if _key not in cls.__properties: - raise ValueError("Error due to additional fields (not defined in AnnotationMetaData) in the input: " + str(obj)) + raise ValueError("Error due to additional fields (not defined in DatasourceConfigBaseFullPath) in the input: " + str(obj)) - _obj = AnnotationMetaData.parse_obj({ - "description": obj.get("description") + _obj = DatasourceConfigBaseFullPath.parse_obj({ + "full_path": obj.get("fullPath") }) return _obj diff --git a/lightly/openapi_generated/swagger_client/models/datasource_config_gcs.py b/lightly/openapi_generated/swagger_client/models/datasource_config_gcs.py index 027ea1fc5..d9b150eb6 100644 --- a/lightly/openapi_generated/swagger_client/models/datasource_config_gcs.py +++ b/lightly/openapi_generated/swagger_client/models/datasource_config_gcs.py @@ -27,9 +27,10 @@ class DatasourceConfigGCS(DatasourceConfigBase): """ DatasourceConfigGCS """ + full_path: StrictStr = Field(..., alias="fullPath", description="path includes the bucket name and the path within the bucket where you have stored your information") gcs_project_id: constr(strict=True, min_length=1) = Field(..., alias="gcsProjectId", description="The projectId where you have your bucket configured") gcs_credentials: StrictStr = Field(..., alias="gcsCredentials", 
description="this is the content of the credentials JSON file stringified which you downloaded from Google Cloud Platform") - __properties = ["id", "purpose", "type", "fullPath", "thumbSuffix", "gcsProjectId", "gcsCredentials"] + __properties = ["id", "purpose", "type", "thumbSuffix", "fullPath", "gcsProjectId", "gcsCredentials"] class Config: """Pydantic configuration""" @@ -77,8 +78,8 @@ def from_dict(cls, obj: dict) -> DatasourceConfigGCS: "id": obj.get("id"), "purpose": obj.get("purpose"), "type": obj.get("type"), - "full_path": obj.get("fullPath"), "thumb_suffix": obj.get("thumbSuffix"), + "full_path": obj.get("fullPath"), "gcs_project_id": obj.get("gcsProjectId"), "gcs_credentials": obj.get("gcsCredentials") }) diff --git a/lightly/openapi_generated/swagger_client/models/datasource_config_lightly.py b/lightly/openapi_generated/swagger_client/models/datasource_config_lightly.py index 7355ed5d0..2617eae53 100644 --- a/lightly/openapi_generated/swagger_client/models/datasource_config_lightly.py +++ b/lightly/openapi_generated/swagger_client/models/datasource_config_lightly.py @@ -20,14 +20,15 @@ -from pydantic import Extra, BaseModel +from pydantic import Extra, BaseModel, Field, StrictStr from lightly.openapi_generated.swagger_client.models.datasource_config_base import DatasourceConfigBase class DatasourceConfigLIGHTLY(DatasourceConfigBase): """ DatasourceConfigLIGHTLY """ - __properties = ["id", "purpose", "type", "fullPath", "thumbSuffix"] + full_path: StrictStr = Field(..., alias="fullPath", description="path includes the bucket name and the path within the bucket where you have stored your information") + __properties = ["id", "purpose", "type", "thumbSuffix", "fullPath"] class Config: """Pydantic configuration""" @@ -75,8 +76,8 @@ def from_dict(cls, obj: dict) -> DatasourceConfigLIGHTLY: "id": obj.get("id"), "purpose": obj.get("purpose"), "type": obj.get("type"), - "full_path": obj.get("fullPath"), - "thumb_suffix": obj.get("thumbSuffix") + 
"thumb_suffix": obj.get("thumbSuffix"), + "full_path": obj.get("fullPath") }) return _obj diff --git a/lightly/openapi_generated/swagger_client/models/datasource_config_local.py b/lightly/openapi_generated/swagger_client/models/datasource_config_local.py index 6a9170d63..e2f819723 100644 --- a/lightly/openapi_generated/swagger_client/models/datasource_config_local.py +++ b/lightly/openapi_generated/swagger_client/models/datasource_config_local.py @@ -19,15 +19,27 @@ import json - -from pydantic import Extra, BaseModel +from typing import Optional +from pydantic import Extra, BaseModel, Field, StrictStr, constr, validator from lightly.openapi_generated.swagger_client.models.datasource_config_base import DatasourceConfigBase class DatasourceConfigLOCAL(DatasourceConfigBase): """ DatasourceConfigLOCAL """ - __properties = ["id", "purpose", "type", "fullPath", "thumbSuffix"] + full_path: StrictStr = Field(..., alias="fullPath", description="Relative path from the mount point. Not allowed to start with \"/\", contain \"://\" or contain \".\" or \"..\" directory parts.") + web_server_location: Optional[constr(strict=True)] = Field(None, alias="webServerLocation", description="The webserver location where your local webserver is running to use for viewing images in the webapp when using the local datasource workflow. 
Defaults to http://localhost:3456 ") + __properties = ["id", "purpose", "type", "thumbSuffix", "fullPath", "webServerLocation"] + + @validator('web_server_location') + def web_server_location_validate_regular_expression(cls, value): + """Validates the regular expression""" + if value is None: + return value + + if not re.match(r"^https?:\/\/.+$", value): + raise ValueError(r"must validate the regular expression /^https?:\/\/.+$/") + return value class Config: """Pydantic configuration""" @@ -75,8 +87,9 @@ def from_dict(cls, obj: dict) -> DatasourceConfigLOCAL: "id": obj.get("id"), "purpose": obj.get("purpose"), "type": obj.get("type"), + "thumb_suffix": obj.get("thumbSuffix"), "full_path": obj.get("fullPath"), - "thumb_suffix": obj.get("thumbSuffix") + "web_server_location": obj.get("webServerLocation") }) return _obj diff --git a/lightly/openapi_generated/swagger_client/models/datasource_config_local_all_of.py b/lightly/openapi_generated/swagger_client/models/datasource_config_local_all_of.py new file mode 100644 index 000000000..247074c37 --- /dev/null +++ b/lightly/openapi_generated/swagger_client/models/datasource_config_local_all_of.py @@ -0,0 +1,90 @@ +# coding: utf-8 + +""" + Lightly API + + Lightly.ai enables you to do self-supervised learning in an easy and intuitive way. The lightly.ai OpenAPI spec defines how one can interact with our REST API to unleash the full potential of lightly.ai # noqa: E501 + + The version of the OpenAPI document: 1.0.0 + Contact: support@lightly.ai + Generated by OpenAPI Generator (https://openapi-generator.tech) + + Do not edit the class manually. 
+""" + + +from __future__ import annotations +import pprint +import re # noqa: F401 +import json + + +from typing import Optional +from pydantic import Extra, BaseModel, Field, StrictStr, constr, validator + +class DatasourceConfigLOCALAllOf(BaseModel): + """ + DatasourceConfigLOCALAllOf + """ + full_path: StrictStr = Field(..., alias="fullPath", description="Relative path from the mount point. Not allowed to start with \"/\", contain \"://\" or contain \".\" or \"..\" directory parts.") + web_server_location: Optional[constr(strict=True)] = Field(None, alias="webServerLocation", description="The webserver location where your local webserver is running to use for viewing images in the webapp when using the local datasource workflow. Defaults to http://localhost:3456 ") + __properties = ["fullPath", "webServerLocation"] + + @validator('web_server_location') + def web_server_location_validate_regular_expression(cls, value): + """Validates the regular expression""" + if value is None: + return value + + if not re.match(r"^https?:\/\/.+$", value): + raise ValueError(r"must validate the regular expression /^https?:\/\/.+$/") + return value + + class Config: + """Pydantic configuration""" + allow_population_by_field_name = True + validate_assignment = True + use_enum_values = True + extra = Extra.forbid + + def to_str(self, by_alias: bool = False) -> str: + """Returns the string representation of the model""" + return pprint.pformat(self.dict(by_alias=by_alias)) + + def to_json(self, by_alias: bool = False) -> str: + """Returns the JSON representation of the model""" + return json.dumps(self.to_dict(by_alias=by_alias)) + + @classmethod + def from_json(cls, json_str: str) -> DatasourceConfigLOCALAllOf: + """Create an instance of DatasourceConfigLOCALAllOf from a JSON string""" + return cls.from_dict(json.loads(json_str)) + + def to_dict(self, by_alias: bool = False): + """Returns the dictionary representation of the model""" + _dict = self.dict(by_alias=by_alias, + 
exclude={ + }, + exclude_none=True) + return _dict + + @classmethod + def from_dict(cls, obj: dict) -> DatasourceConfigLOCALAllOf: + """Create an instance of DatasourceConfigLOCALAllOf from a dict""" + if obj is None: + return None + + if not isinstance(obj, dict): + return DatasourceConfigLOCALAllOf.parse_obj(obj) + + # raise errors for additional fields in the input + for _key in obj.keys(): + if _key not in cls.__properties: + raise ValueError("Error due to additional fields (not defined in DatasourceConfigLOCALAllOf) in the input: " + str(obj)) + + _obj = DatasourceConfigLOCALAllOf.parse_obj({ + "full_path": obj.get("fullPath"), + "web_server_location": obj.get("webServerLocation") + }) + return _obj + diff --git a/lightly/openapi_generated/swagger_client/models/datasource_config_obs.py b/lightly/openapi_generated/swagger_client/models/datasource_config_obs.py index 3dbafebbf..f50bab55b 100644 --- a/lightly/openapi_generated/swagger_client/models/datasource_config_obs.py +++ b/lightly/openapi_generated/swagger_client/models/datasource_config_obs.py @@ -20,17 +20,18 @@ -from pydantic import Extra, BaseModel, Field, constr, validator +from pydantic import Extra, BaseModel, Field, StrictStr, constr, validator from lightly.openapi_generated.swagger_client.models.datasource_config_base import DatasourceConfigBase class DatasourceConfigOBS(DatasourceConfigBase): """ DatasourceConfigOBS """ + full_path: StrictStr = Field(..., alias="fullPath", description="path includes the bucket name and the path within the bucket where you have stored your information") obs_endpoint: constr(strict=True, min_length=1) = Field(..., alias="obsEndpoint", description="The Object Storage Service (OBS) endpoint to use of your S3 compatible cloud storage provider") obs_access_key_id: constr(strict=True, min_length=1) = Field(..., alias="obsAccessKeyId", description="The Access Key Id of the credential you are providing Lightly to use") obs_secret_access_key: constr(strict=True, 
min_length=1) = Field(..., alias="obsSecretAccessKey", description="The Secret Access Key of the credential you are providing Lightly to use") - __properties = ["id", "purpose", "type", "fullPath", "thumbSuffix", "obsEndpoint", "obsAccessKeyId", "obsSecretAccessKey"] + __properties = ["id", "purpose", "type", "thumbSuffix", "fullPath", "obsEndpoint", "obsAccessKeyId", "obsSecretAccessKey"] @validator('obs_endpoint') def obs_endpoint_validate_regular_expression(cls, value): @@ -85,8 +86,8 @@ def from_dict(cls, obj: dict) -> DatasourceConfigOBS: "id": obj.get("id"), "purpose": obj.get("purpose"), "type": obj.get("type"), - "full_path": obj.get("fullPath"), "thumb_suffix": obj.get("thumbSuffix"), + "full_path": obj.get("fullPath"), "obs_endpoint": obj.get("obsEndpoint"), "obs_access_key_id": obj.get("obsAccessKeyId"), "obs_secret_access_key": obj.get("obsSecretAccessKey") diff --git a/lightly/openapi_generated/swagger_client/models/datasource_config_s3.py b/lightly/openapi_generated/swagger_client/models/datasource_config_s3.py index 2d25ebde2..638018727 100644 --- a/lightly/openapi_generated/swagger_client/models/datasource_config_s3.py +++ b/lightly/openapi_generated/swagger_client/models/datasource_config_s3.py @@ -20,7 +20,7 @@ from typing import Optional -from pydantic import Extra, BaseModel, Field, constr, validator +from pydantic import Extra, BaseModel, Field, StrictStr, constr, validator from lightly.openapi_generated.swagger_client.models.datasource_config_base import DatasourceConfigBase from lightly.openapi_generated.swagger_client.models.s3_region import S3Region @@ -28,11 +28,12 @@ class DatasourceConfigS3(DatasourceConfigBase): """ DatasourceConfigS3 """ + full_path: StrictStr = Field(..., alias="fullPath", description="path includes the bucket name and the path within the bucket where you have stored your information") s3_region: S3Region = Field(..., alias="s3Region") s3_access_key_id: constr(strict=True, min_length=1) = Field(..., 
alias="s3AccessKeyId", description="The accessKeyId of the credential you are providing Lightly to use") s3_secret_access_key: constr(strict=True, min_length=1) = Field(..., alias="s3SecretAccessKey", description="The secretAccessKey of the credential you are providing Lightly to use") s3_server_side_encryption_kms_key: Optional[constr(strict=True, min_length=1)] = Field(None, alias="s3ServerSideEncryptionKMSKey", description="If set, Lightly Worker will automatically set the headers to use server side encryption https://docs.aws.amazon.com/AmazonS3/latest/userguide/UsingKMSEncryption.html with this value as the appropriate KMS key arn. This will encrypt the files created by Lightly (crops, frames, thumbnails) in the S3 bucket. ") - __properties = ["id", "purpose", "type", "fullPath", "thumbSuffix", "s3Region", "s3AccessKeyId", "s3SecretAccessKey", "s3ServerSideEncryptionKMSKey"] + __properties = ["id", "purpose", "type", "thumbSuffix", "fullPath", "s3Region", "s3AccessKeyId", "s3SecretAccessKey", "s3ServerSideEncryptionKMSKey"] @validator('s3_server_side_encryption_kms_key') def s3_server_side_encryption_kms_key_validate_regular_expression(cls, value): @@ -90,8 +91,8 @@ def from_dict(cls, obj: dict) -> DatasourceConfigS3: "id": obj.get("id"), "purpose": obj.get("purpose"), "type": obj.get("type"), - "full_path": obj.get("fullPath"), "thumb_suffix": obj.get("thumbSuffix"), + "full_path": obj.get("fullPath"), "s3_region": obj.get("s3Region"), "s3_access_key_id": obj.get("s3AccessKeyId"), "s3_secret_access_key": obj.get("s3SecretAccessKey"), diff --git a/lightly/openapi_generated/swagger_client/models/datasource_config_s3_delegated_access.py b/lightly/openapi_generated/swagger_client/models/datasource_config_s3_delegated_access.py index a5ef73962..3ea071e56 100644 --- a/lightly/openapi_generated/swagger_client/models/datasource_config_s3_delegated_access.py +++ b/lightly/openapi_generated/swagger_client/models/datasource_config_s3_delegated_access.py @@ -20,7 +20,7 
@@ from typing import Optional -from pydantic import Extra, BaseModel, Field, constr, validator +from pydantic import Extra, BaseModel, Field, StrictStr, constr, validator from lightly.openapi_generated.swagger_client.models.datasource_config_base import DatasourceConfigBase from lightly.openapi_generated.swagger_client.models.s3_region import S3Region @@ -28,11 +28,12 @@ class DatasourceConfigS3DelegatedAccess(DatasourceConfigBase): """ DatasourceConfigS3DelegatedAccess """ + full_path: StrictStr = Field(..., alias="fullPath", description="path includes the bucket name and the path within the bucket where you have stored your information") s3_region: S3Region = Field(..., alias="s3Region") s3_external_id: constr(strict=True, min_length=10) = Field(..., alias="s3ExternalId", description="The external ID specified when creating the role.") s3_arn: constr(strict=True, min_length=12) = Field(..., alias="s3ARN", description="The ARN of the role you created") s3_server_side_encryption_kms_key: Optional[constr(strict=True, min_length=1)] = Field(None, alias="s3ServerSideEncryptionKMSKey", description="If set, Lightly Worker will automatically set the headers to use server side encryption https://docs.aws.amazon.com/AmazonS3/latest/userguide/UsingKMSEncryption.html with this value as the appropriate KMS key arn. This will encrypt the files created by Lightly (crops, frames, thumbnails) in the S3 bucket. 
") - __properties = ["id", "purpose", "type", "fullPath", "thumbSuffix", "s3Region", "s3ExternalId", "s3ARN", "s3ServerSideEncryptionKMSKey"] + __properties = ["id", "purpose", "type", "thumbSuffix", "fullPath", "s3Region", "s3ExternalId", "s3ARN", "s3ServerSideEncryptionKMSKey"] @validator('s3_external_id') def s3_external_id_validate_regular_expression(cls, value): @@ -104,8 +105,8 @@ def from_dict(cls, obj: dict) -> DatasourceConfigS3DelegatedAccess: "id": obj.get("id"), "purpose": obj.get("purpose"), "type": obj.get("type"), - "full_path": obj.get("fullPath"), "thumb_suffix": obj.get("thumbSuffix"), + "full_path": obj.get("fullPath"), "s3_region": obj.get("s3Region"), "s3_external_id": obj.get("s3ExternalId"), "s3_arn": obj.get("s3ARN"), diff --git a/lightly/openapi_generated/swagger_client/models/docker_run_artifact_type.py b/lightly/openapi_generated/swagger_client/models/docker_run_artifact_type.py index e4fa545e2..02d827c70 100644 --- a/lightly/openapi_generated/swagger_client/models/docker_run_artifact_type.py +++ b/lightly/openapi_generated/swagger_client/models/docker_run_artifact_type.py @@ -36,6 +36,7 @@ class DockerRunArtifactType(str, Enum): CHECKPOINT = 'CHECKPOINT' REPORT_PDF = 'REPORT_PDF' REPORT_JSON = 'REPORT_JSON' + REPORT_V2_JSON = 'REPORT_V2_JSON' CORRUPTNESS_CHECK_INFORMATION = 'CORRUPTNESS_CHECK_INFORMATION' SEQUENCE_INFORMATION = 'SEQUENCE_INFORMATION' RELEVANT_FILENAMES = 'RELEVANT_FILENAMES' diff --git a/lightly/openapi_generated/swagger_client/models/docker_run_data.py b/lightly/openapi_generated/swagger_client/models/docker_run_data.py index 4b8004f93..5502a6a2c 100644 --- a/lightly/openapi_generated/swagger_client/models/docker_run_data.py +++ b/lightly/openapi_generated/swagger_client/models/docker_run_data.py @@ -20,7 +20,7 @@ from typing import List, Optional -from pydantic import Extra, BaseModel, Field, StrictStr, conint, conlist, constr, validator +from pydantic import Extra, BaseModel, Field, StrictBool, StrictStr, conint, 
conlist, constr, validator from lightly.openapi_generated.swagger_client.models.docker_run_artifact_data import DockerRunArtifactData from lightly.openapi_generated.swagger_client.models.docker_run_state import DockerRunState @@ -32,6 +32,7 @@ class DockerRunData(BaseModel): user_id: StrictStr = Field(..., alias="userId") docker_version: StrictStr = Field(..., alias="dockerVersion") state: DockerRunState = Field(...) + archived: Optional[StrictBool] = Field(None, description="if the run is archived") dataset_id: Optional[constr(strict=True)] = Field(None, alias="datasetId", description="MongoDB ObjectId") config_id: Optional[constr(strict=True)] = Field(None, alias="configId", description="MongoDB ObjectId") scheduled_id: Optional[constr(strict=True)] = Field(None, alias="scheduledId", description="MongoDB ObjectId") @@ -39,7 +40,7 @@ class DockerRunData(BaseModel): last_modified_at: conint(strict=True, ge=0) = Field(..., alias="lastModifiedAt", description="unix timestamp in milliseconds") message: Optional[StrictStr] = Field(None, description="last message sent to the docker run") artifacts: Optional[conlist(DockerRunArtifactData)] = Field(None, description="list of artifacts that were created for a run") - __properties = ["id", "userId", "dockerVersion", "state", "datasetId", "configId", "scheduledId", "createdAt", "lastModifiedAt", "message", "artifacts"] + __properties = ["id", "userId", "dockerVersion", "state", "archived", "datasetId", "configId", "scheduledId", "createdAt", "lastModifiedAt", "message", "artifacts"] @validator('id') def id_validate_regular_expression(cls, value): @@ -132,6 +133,7 @@ def from_dict(cls, obj: dict) -> DockerRunData: "user_id": obj.get("userId"), "docker_version": obj.get("dockerVersion"), "state": obj.get("state"), + "archived": obj.get("archived"), "dataset_id": obj.get("datasetId"), "config_id": obj.get("configId"), "scheduled_id": obj.get("scheduledId"), diff --git 
a/lightly/openapi_generated/swagger_client/models/docker_worker_config_v3.py b/lightly/openapi_generated/swagger_client/models/docker_worker_config_v3.py index a59c3f34d..4e5d1d8db 100644 --- a/lightly/openapi_generated/swagger_client/models/docker_worker_config_v3.py +++ b/lightly/openapi_generated/swagger_client/models/docker_worker_config_v3.py @@ -24,7 +24,7 @@ from lightly.openapi_generated.swagger_client.models.docker_worker_config_v3_docker import DockerWorkerConfigV3Docker from lightly.openapi_generated.swagger_client.models.docker_worker_config_v3_lightly import DockerWorkerConfigV3Lightly from lightly.openapi_generated.swagger_client.models.docker_worker_type import DockerWorkerType -from lightly.openapi_generated.swagger_client.models.selection_config import SelectionConfig +from lightly.openapi_generated.swagger_client.models.selection_config_v3 import SelectionConfigV3 class DockerWorkerConfigV3(BaseModel): """ @@ -33,7 +33,7 @@ class DockerWorkerConfigV3(BaseModel): worker_type: DockerWorkerType = Field(..., alias="workerType") docker: Optional[DockerWorkerConfigV3Docker] = None lightly: Optional[DockerWorkerConfigV3Lightly] = None - selection: Optional[SelectionConfig] = None + selection: Optional[SelectionConfigV3] = None __properties = ["workerType", "docker", "lightly", "selection"] class Config: @@ -91,7 +91,7 @@ def from_dict(cls, obj: dict) -> DockerWorkerConfigV3: "worker_type": obj.get("workerType"), "docker": DockerWorkerConfigV3Docker.from_dict(obj.get("docker")) if obj.get("docker") is not None else None, "lightly": DockerWorkerConfigV3Lightly.from_dict(obj.get("lightly")) if obj.get("lightly") is not None else None, - "selection": SelectionConfig.from_dict(obj.get("selection")) if obj.get("selection") is not None else None + "selection": SelectionConfigV3.from_dict(obj.get("selection")) if obj.get("selection") is not None else None }) return _obj diff --git 
a/lightly/openapi_generated/swagger_client/models/docker_worker_config_v3_docker.py b/lightly/openapi_generated/swagger_client/models/docker_worker_config_v3_docker.py index 962263917..7a66eab2e 100644 --- a/lightly/openapi_generated/swagger_client/models/docker_worker_config_v3_docker.py +++ b/lightly/openapi_generated/swagger_client/models/docker_worker_config_v3_docker.py @@ -20,7 +20,7 @@ from typing import Optional -from pydantic import Extra, BaseModel, Field, StrictBool, StrictStr, conint +from pydantic import Extra, BaseModel, Field, StrictBool, StrictStr, conint, constr, validator from lightly.openapi_generated.swagger_client.models.docker_worker_config_v3_docker_corruptness_check import DockerWorkerConfigV3DockerCorruptnessCheck from lightly.openapi_generated.swagger_client.models.docker_worker_config_v3_docker_datasource import DockerWorkerConfigV3DockerDatasource from lightly.openapi_generated.swagger_client.models.docker_worker_config_v3_docker_training import DockerWorkerConfigV3DockerTraining @@ -30,6 +30,7 @@ class DockerWorkerConfigV3Docker(BaseModel): docker run configurations, keys should match the structure of https://github.com/lightly-ai/lightly-core/blob/develop/onprem-docker/lightly_worker/src/lightly_worker/resources/docker/docker.yaml """ checkpoint: Optional[StrictStr] = None + checkpoint_run_id: Optional[constr(strict=True)] = Field(None, alias="checkpointRunId", description="MongoDB ObjectId") corruptness_check: Optional[DockerWorkerConfigV3DockerCorruptnessCheck] = Field(None, alias="corruptnessCheck") datasource: Optional[DockerWorkerConfigV3DockerDatasource] = None embeddings: Optional[StrictStr] = None @@ -44,8 +45,18 @@ class DockerWorkerConfigV3Docker(BaseModel): relevant_filenames_file: Optional[StrictStr] = Field(None, alias="relevantFilenamesFile") selected_sequence_length: Optional[conint(strict=True, ge=1)] = Field(None, alias="selectedSequenceLength") upload_report: Optional[StrictBool] = Field(None, alias="uploadReport") - 
use_datapool: Optional[StrictBool] = Field(None, alias="useDatapool") - __properties = ["checkpoint", "corruptnessCheck", "datasource", "embeddings", "enableTraining", "training", "normalizeEmbeddings", "numProcesses", "numThreads", "outputImageFormat", "pretagging", "pretaggingUpload", "relevantFilenamesFile", "selectedSequenceLength", "uploadReport", "useDatapool"] + shutdown_when_job_finished: Optional[StrictBool] = Field(None, alias="shutdownWhenJobFinished") + __properties = ["checkpoint", "checkpointRunId", "corruptnessCheck", "datasource", "embeddings", "enableTraining", "training", "normalizeEmbeddings", "numProcesses", "numThreads", "outputImageFormat", "pretagging", "pretaggingUpload", "relevantFilenamesFile", "selectedSequenceLength", "uploadReport", "shutdownWhenJobFinished"] + + @validator('checkpoint_run_id') + def checkpoint_run_id_validate_regular_expression(cls, value): + """Validates the regular expression""" + if value is None: + return value + + if not re.match(r"^[a-f0-9]{24}$", value): + raise ValueError(r"must validate the regular expression /^[a-f0-9]{24}$/") + return value class Config: """Pydantic configuration""" @@ -100,6 +111,7 @@ def from_dict(cls, obj: dict) -> DockerWorkerConfigV3Docker: _obj = DockerWorkerConfigV3Docker.parse_obj({ "checkpoint": obj.get("checkpoint"), + "checkpoint_run_id": obj.get("checkpointRunId"), "corruptness_check": DockerWorkerConfigV3DockerCorruptnessCheck.from_dict(obj.get("corruptnessCheck")) if obj.get("corruptnessCheck") is not None else None, "datasource": DockerWorkerConfigV3DockerDatasource.from_dict(obj.get("datasource")) if obj.get("datasource") is not None else None, "embeddings": obj.get("embeddings"), @@ -114,7 +126,7 @@ def from_dict(cls, obj: dict) -> DockerWorkerConfigV3Docker: "relevant_filenames_file": obj.get("relevantFilenamesFile"), "selected_sequence_length": obj.get("selectedSequenceLength"), "upload_report": obj.get("uploadReport"), - "use_datapool": obj.get("useDatapool") + 
"shutdown_when_job_finished": obj.get("shutdownWhenJobFinished") }) return _obj diff --git a/lightly/openapi_generated/swagger_client/models/prediction_singleton_keypoint_detection.py b/lightly/openapi_generated/swagger_client/models/prediction_singleton_keypoint_detection.py index c0def6930..be66eae4a 100644 --- a/lightly/openapi_generated/swagger_client/models/prediction_singleton_keypoint_detection.py +++ b/lightly/openapi_generated/swagger_client/models/prediction_singleton_keypoint_detection.py @@ -28,8 +28,9 @@ class PredictionSingletonKeypointDetection(PredictionSingletonBase): PredictionSingletonKeypointDetection """ keypoints: conlist(Union[confloat(ge=0, strict=True), conint(ge=0, strict=True)], min_items=3) = Field(..., description="[x1, y2, s1, ..., xk, yk, sk] as outlined by https://docs.lightly.ai/docs/prediction-format#keypoint-detection ") + bbox: Optional[conlist(Union[confloat(ge=0, strict=True), conint(ge=0, strict=True)], max_items=4, min_items=4)] = Field(None, description="The bbox of where a prediction task yielded a finding. [x, y, width, height]") probabilities: Optional[conlist(Union[confloat(le=1, ge=0, strict=True), conint(le=1, ge=0, strict=True)])] = Field(None, description="The probabilities of it being a certain category other than the one which was selected. 
The sum of all probabilities should equal 1.") - __properties = ["type", "taskName", "cropDatasetId", "cropSampleId", "categoryId", "score", "keypoints", "probabilities"] + __properties = ["type", "taskName", "cropDatasetId", "cropSampleId", "categoryId", "score", "keypoints", "bbox", "probabilities"] class Config: """Pydantic configuration""" @@ -81,6 +82,7 @@ def from_dict(cls, obj: dict) -> PredictionSingletonKeypointDetection: "category_id": obj.get("categoryId"), "score": obj.get("score"), "keypoints": obj.get("keypoints"), + "bbox": obj.get("bbox"), "probabilities": obj.get("probabilities") }) return _obj diff --git a/lightly/openapi_generated/swagger_client/models/prediction_singleton_keypoint_detection_all_of.py b/lightly/openapi_generated/swagger_client/models/prediction_singleton_keypoint_detection_all_of.py index abeadb9dc..d5cfa8795 100644 --- a/lightly/openapi_generated/swagger_client/models/prediction_singleton_keypoint_detection_all_of.py +++ b/lightly/openapi_generated/swagger_client/models/prediction_singleton_keypoint_detection_all_of.py @@ -27,8 +27,9 @@ class PredictionSingletonKeypointDetectionAllOf(BaseModel): PredictionSingletonKeypointDetectionAllOf """ keypoints: conlist(Union[confloat(ge=0, strict=True), conint(ge=0, strict=True)], min_items=3) = Field(..., description="[x1, y2, s1, ..., xk, yk, sk] as outlined by https://docs.lightly.ai/docs/prediction-format#keypoint-detection ") + bbox: Optional[conlist(Union[confloat(ge=0, strict=True), conint(ge=0, strict=True)], max_items=4, min_items=4)] = Field(None, description="The bbox of where a prediction task yielded a finding. [x, y, width, height]") probabilities: Optional[conlist(Union[confloat(le=1, ge=0, strict=True), conint(le=1, ge=0, strict=True)])] = Field(None, description="The probabilities of it being a certain category other than the one which was selected. 
The sum of all probabilities should equal 1.") - __properties = ["keypoints", "probabilities"] + __properties = ["keypoints", "bbox", "probabilities"] class Config: """Pydantic configuration""" @@ -74,6 +75,7 @@ def from_dict(cls, obj: dict) -> PredictionSingletonKeypointDetectionAllOf: _obj = PredictionSingletonKeypointDetectionAllOf.parse_obj({ "keypoints": obj.get("keypoints"), + "bbox": obj.get("bbox"), "probabilities": obj.get("probabilities") }) return _obj diff --git a/lightly/openapi_generated/swagger_client/models/sample_create_request.py b/lightly/openapi_generated/swagger_client/models/sample_create_request.py index ab3c6c34d..3c292a76a 100644 --- a/lightly/openapi_generated/swagger_client/models/sample_create_request.py +++ b/lightly/openapi_generated/swagger_client/models/sample_create_request.py @@ -78,6 +78,16 @@ def to_dict(self, by_alias: bool = False): if self.custom_meta_data is None and "custom_meta_data" in self.__fields_set__: _dict['customMetaData' if by_alias else 'custom_meta_data'] = None + # set to None if video_frame_data (nullable) is None + # and __fields_set__ contains the field + if self.video_frame_data is None and "video_frame_data" in self.__fields_set__: + _dict['videoFrameData' if by_alias else 'video_frame_data'] = None + + # set to None if crop_data (nullable) is None + # and __fields_set__ contains the field + if self.crop_data is None and "crop_data" in self.__fields_set__: + _dict['cropData' if by_alias else 'crop_data'] = None + return _dict @classmethod diff --git a/lightly/openapi_generated/swagger_client/models/sample_data.py b/lightly/openapi_generated/swagger_client/models/sample_data.py index 42ff509a7..84d3c39fb 100644 --- a/lightly/openapi_generated/swagger_client/models/sample_data.py +++ b/lightly/openapi_generated/swagger_client/models/sample_data.py @@ -102,11 +102,26 @@ def to_dict(self, by_alias: bool = False): if self.thumb_name is None and "thumb_name" in self.__fields_set__: _dict['thumbName' if by_alias 
else 'thumb_name'] = None + # set to None if exif (nullable) is None + # and __fields_set__ contains the field + if self.exif is None and "exif" in self.__fields_set__: + _dict['exif' if by_alias else 'exif'] = None + # set to None if custom_meta_data (nullable) is None # and __fields_set__ contains the field if self.custom_meta_data is None and "custom_meta_data" in self.__fields_set__: _dict['customMetaData' if by_alias else 'custom_meta_data'] = None + # set to None if video_frame_data (nullable) is None + # and __fields_set__ contains the field + if self.video_frame_data is None and "video_frame_data" in self.__fields_set__: + _dict['videoFrameData' if by_alias else 'video_frame_data'] = None + + # set to None if crop_data (nullable) is None + # and __fields_set__ contains the field + if self.crop_data is None and "crop_data" in self.__fields_set__: + _dict['cropData' if by_alias else 'crop_data'] = None + return _dict @classmethod diff --git a/lightly/openapi_generated/swagger_client/models/sample_data_modes.py b/lightly/openapi_generated/swagger_client/models/sample_data_modes.py index 426589a33..736754e2e 100644 --- a/lightly/openapi_generated/swagger_client/models/sample_data_modes.py +++ b/lightly/openapi_generated/swagger_client/models/sample_data_modes.py @@ -102,11 +102,26 @@ def to_dict(self, by_alias: bool = False): if self.thumb_name is None and "thumb_name" in self.__fields_set__: _dict['thumbName' if by_alias else 'thumb_name'] = None + # set to None if exif (nullable) is None + # and __fields_set__ contains the field + if self.exif is None and "exif" in self.__fields_set__: + _dict['exif' if by_alias else 'exif'] = None + # set to None if custom_meta_data (nullable) is None # and __fields_set__ contains the field if self.custom_meta_data is None and "custom_meta_data" in self.__fields_set__: _dict['customMetaData' if by_alias else 'custom_meta_data'] = None + # set to None if video_frame_data (nullable) is None + # and __fields_set__ contains the 
field + if self.video_frame_data is None and "video_frame_data" in self.__fields_set__: + _dict['videoFrameData' if by_alias else 'video_frame_data'] = None + + # set to None if crop_data (nullable) is None + # and __fields_set__ contains the field + if self.crop_data is None and "crop_data" in self.__fields_set__: + _dict['cropData' if by_alias else 'crop_data'] = None + return _dict @classmethod diff --git a/lightly/openapi_generated/swagger_client/models/sample_meta_data.py b/lightly/openapi_generated/swagger_client/models/sample_meta_data.py index 705c8b391..2153a9487 100644 --- a/lightly/openapi_generated/swagger_client/models/sample_meta_data.py +++ b/lightly/openapi_generated/swagger_client/models/sample_meta_data.py @@ -66,6 +66,16 @@ def to_dict(self, by_alias: bool = False): exclude={ }, exclude_none=True) + # set to None if custom (nullable) is None + # and __fields_set__ contains the field + if self.custom is None and "custom" in self.__fields_set__: + _dict['custom' if by_alias else 'custom'] = None + + # set to None if dynamic (nullable) is None + # and __fields_set__ contains the field + if self.dynamic is None and "dynamic" in self.__fields_set__: + _dict['dynamic' if by_alias else 'dynamic'] = None + return _dict @classmethod diff --git a/lightly/openapi_generated/swagger_client/models/selection_config_all_of.py b/lightly/openapi_generated/swagger_client/models/selection_config_all_of.py new file mode 100644 index 000000000..3d04824fa --- /dev/null +++ b/lightly/openapi_generated/swagger_client/models/selection_config_all_of.py @@ -0,0 +1,86 @@ +# coding: utf-8 + +""" + Lightly API + + Lightly.ai enables you to do self-supervised learning in an easy and intuitive way. 
The lightly.ai OpenAPI spec defines how one can interact with our REST API to unleash the full potential of lightly.ai # noqa: E501 + + The version of the OpenAPI document: 1.0.0 + Contact: support@lightly.ai + Generated by OpenAPI Generator (https://openapi-generator.tech) + + Do not edit the class manually. +""" + + +from __future__ import annotations +import pprint +import re # noqa: F401 +import json + + +from typing import List +from pydantic import Extra, BaseModel, Field, conlist +from lightly.openapi_generated.swagger_client.models.selection_config_entry import SelectionConfigEntry + +class SelectionConfigAllOf(BaseModel): + """ + SelectionConfigAllOf + """ + strategies: conlist(SelectionConfigEntry, min_items=1) = Field(...) + __properties = ["strategies"] + + class Config: + """Pydantic configuration""" + allow_population_by_field_name = True + validate_assignment = True + use_enum_values = True + extra = Extra.forbid + + def to_str(self, by_alias: bool = False) -> str: + """Returns the string representation of the model""" + return pprint.pformat(self.dict(by_alias=by_alias)) + + def to_json(self, by_alias: bool = False) -> str: + """Returns the JSON representation of the model""" + return json.dumps(self.to_dict(by_alias=by_alias)) + + @classmethod + def from_json(cls, json_str: str) -> SelectionConfigAllOf: + """Create an instance of SelectionConfigAllOf from a JSON string""" + return cls.from_dict(json.loads(json_str)) + + def to_dict(self, by_alias: bool = False): + """Returns the dictionary representation of the model""" + _dict = self.dict(by_alias=by_alias, + exclude={ + }, + exclude_none=True) + # override the default output from pydantic by calling `to_dict()` of each item in strategies (list) + _items = [] + if self.strategies: + for _item in self.strategies: + if _item: + _items.append(_item.to_dict(by_alias=by_alias)) + _dict['strategies' if by_alias else 'strategies'] = _items + return _dict + + @classmethod + def from_dict(cls, obj: dict) 
-> SelectionConfigAllOf: + """Create an instance of SelectionConfigAllOf from a dict""" + if obj is None: + return None + + if not isinstance(obj, dict): + return SelectionConfigAllOf.parse_obj(obj) + + # raise errors for additional fields in the input + for _key in obj.keys(): + if _key not in cls.__properties: + raise ValueError("Error due to additional fields (not defined in SelectionConfigAllOf) in the input: " + str(obj)) + + _obj = SelectionConfigAllOf.parse_obj({ + "strategies": [SelectionConfigEntry.from_dict(_item) for _item in obj.get("strategies")] if obj.get("strategies") is not None else None + }) + return _obj + diff --git a/lightly/openapi_generated/swagger_client/models/annotation_offer_data.py b/lightly/openapi_generated/swagger_client/models/selection_config_base.py similarity index 66% rename from lightly/openapi_generated/swagger_client/models/annotation_offer_data.py rename to lightly/openapi_generated/swagger_client/models/selection_config_base.py index e12f63aa5..51b358cf1 100644 --- a/lightly/openapi_generated/swagger_client/models/annotation_offer_data.py +++ b/lightly/openapi_generated/swagger_client/models/selection_config_base.py @@ -20,15 +20,15 @@ from typing import Optional, Union -from pydantic import Extra, BaseModel, Field, StrictFloat, StrictInt, conint +from pydantic import Extra, BaseModel, Field, confloat, conint -class AnnotationOfferData(BaseModel): +class SelectionConfigBase(BaseModel): """ - AnnotationOfferData + SelectionConfigBase """ - cost: Optional[Union[StrictFloat, StrictInt]] = None - completed_by: Optional[conint(strict=True, ge=0)] = Field(None, alias="completedBy", description="unix timestamp in milliseconds") - __properties = ["cost", "completedBy"] + n_samples: Optional[conint(strict=True, ge=-1)] = Field(None, alias="nSamples") + proportion_samples: Optional[Union[confloat(le=1.0, ge=0.0, strict=True), conint(le=1, ge=0, strict=True)]] = Field(None, alias="proportionSamples") + __properties = ["nSamples", 
"proportionSamples"] class Config: """Pydantic configuration""" @@ -46,8 +46,8 @@ def to_json(self, by_alias: bool = False) -> str: return json.dumps(self.to_dict(by_alias=by_alias)) @classmethod - def from_json(cls, json_str: str) -> AnnotationOfferData: - """Create an instance of AnnotationOfferData from a JSON string""" + def from_json(cls, json_str: str) -> SelectionConfigBase: + """Create an instance of SelectionConfigBase from a JSON string""" return cls.from_dict(json.loads(json_str)) def to_dict(self, by_alias: bool = False): @@ -59,22 +59,22 @@ def to_dict(self, by_alias: bool = False): return _dict @classmethod - def from_dict(cls, obj: dict) -> AnnotationOfferData: - """Create an instance of AnnotationOfferData from a dict""" + def from_dict(cls, obj: dict) -> SelectionConfigBase: + """Create an instance of SelectionConfigBase from a dict""" if obj is None: return None if not isinstance(obj, dict): - return AnnotationOfferData.parse_obj(obj) + return SelectionConfigBase.parse_obj(obj) # raise errors for additional fields in the input for _key in obj.keys(): if _key not in cls.__properties: - raise ValueError("Error due to additional fields (not defined in AnnotationOfferData) in the input: " + str(obj)) + raise ValueError("Error due to additional fields (not defined in SelectionConfigBase) in the input: " + str(obj)) - _obj = AnnotationOfferData.parse_obj({ - "cost": obj.get("cost"), - "completed_by": obj.get("completedBy") + _obj = SelectionConfigBase.parse_obj({ + "n_samples": obj.get("nSamples"), + "proportion_samples": obj.get("proportionSamples") }) return _obj diff --git a/lightly/openapi_generated/swagger_client/models/selection_config_v3.py b/lightly/openapi_generated/swagger_client/models/selection_config_v3.py new file mode 100644 index 000000000..80aed78fb --- /dev/null +++ b/lightly/openapi_generated/swagger_client/models/selection_config_v3.py @@ -0,0 +1,90 @@ +# coding: utf-8 + +""" + Lightly API + + Lightly.ai enables you to do 
self-supervised learning in an easy and intuitive way. The lightly.ai OpenAPI spec defines how one can interact with our REST API to unleash the full potential of lightly.ai # noqa: E501 + + The version of the OpenAPI document: 1.0.0 + Contact: support@lightly.ai + Generated by OpenAPI Generator (https://openapi-generator.tech) + + Do not edit the class manually. +""" + + +from __future__ import annotations +import pprint +import re # noqa: F401 +import json + + +from typing import List, Optional, Union +from pydantic import Extra, BaseModel, Field, confloat, conint, conlist +from lightly.openapi_generated.swagger_client.models.selection_config_v3_entry import SelectionConfigV3Entry + +class SelectionConfigV3(BaseModel): + """ + SelectionConfigV3 + """ + n_samples: Optional[conint(strict=True, ge=-1)] = Field(None, alias="nSamples") + proportion_samples: Optional[Union[confloat(le=1.0, ge=0.0, strict=True), conint(le=1, ge=0, strict=True)]] = Field(None, alias="proportionSamples") + strategies: conlist(SelectionConfigV3Entry, min_items=1) = Field(...) 
+ __properties = ["nSamples", "proportionSamples", "strategies"] + + class Config: + """Pydantic configuration""" + allow_population_by_field_name = True + validate_assignment = True + use_enum_values = True + extra = Extra.forbid + + def to_str(self, by_alias: bool = False) -> str: + """Returns the string representation of the model""" + return pprint.pformat(self.dict(by_alias=by_alias)) + + def to_json(self, by_alias: bool = False) -> str: + """Returns the JSON representation of the model""" + return json.dumps(self.to_dict(by_alias=by_alias)) + + @classmethod + def from_json(cls, json_str: str) -> SelectionConfigV3: + """Create an instance of SelectionConfigV3 from a JSON string""" + return cls.from_dict(json.loads(json_str)) + + def to_dict(self, by_alias: bool = False): + """Returns the dictionary representation of the model""" + _dict = self.dict(by_alias=by_alias, + exclude={ + }, + exclude_none=True) + # override the default output from pydantic by calling `to_dict()` of each item in strategies (list) + _items = [] + if self.strategies: + for _item in self.strategies: + if _item: + _items.append(_item.to_dict(by_alias=by_alias)) + _dict['strategies' if by_alias else 'strategies'] = _items + return _dict + + @classmethod + def from_dict(cls, obj: dict) -> SelectionConfigV3: + """Create an instance of SelectionConfigV3 from a dict""" + if obj is None: + return None + + if not isinstance(obj, dict): + return SelectionConfigV3.parse_obj(obj) + + # raise errors for additional fields in the input + for _key in obj.keys(): + if _key not in cls.__properties: + raise ValueError("Error due to additional fields (not defined in SelectionConfigV3) in the input: " + str(obj)) + + _obj = SelectionConfigV3.parse_obj({ + "n_samples": obj.get("nSamples"), + "proportion_samples": obj.get("proportionSamples"), + "strategies": [SelectionConfigV3Entry.from_dict(_item) for _item in obj.get("strategies")] if obj.get("strategies") is not None else None + }) + return _obj + diff 
--git a/lightly/openapi_generated/swagger_client/models/selection_config_v3_all_of.py b/lightly/openapi_generated/swagger_client/models/selection_config_v3_all_of.py new file mode 100644 index 000000000..84770bdd9 --- /dev/null +++ b/lightly/openapi_generated/swagger_client/models/selection_config_v3_all_of.py @@ -0,0 +1,86 @@ +# coding: utf-8 + +""" + Lightly API + + Lightly.ai enables you to do self-supervised learning in an easy and intuitive way. The lightly.ai OpenAPI spec defines how one can interact with our REST API to unleash the full potential of lightly.ai # noqa: E501 + + The version of the OpenAPI document: 1.0.0 + Contact: support@lightly.ai + Generated by OpenAPI Generator (https://openapi-generator.tech) + + Do not edit the class manually. +""" + + +from __future__ import annotations +import pprint +import re # noqa: F401 +import json + + +from typing import List +from pydantic import Extra, BaseModel, Field, conlist +from lightly.openapi_generated.swagger_client.models.selection_config_v3_entry import SelectionConfigV3Entry + +class SelectionConfigV3AllOf(BaseModel): + """ + SelectionConfigV3AllOf + """ + strategies: conlist(SelectionConfigV3Entry, min_items=1) = Field(...) 
+ __properties = ["strategies"] + + class Config: + """Pydantic configuration""" + allow_population_by_field_name = True + validate_assignment = True + use_enum_values = True + extra = Extra.forbid + + def to_str(self, by_alias: bool = False) -> str: + """Returns the string representation of the model""" + return pprint.pformat(self.dict(by_alias=by_alias)) + + def to_json(self, by_alias: bool = False) -> str: + """Returns the JSON representation of the model""" + return json.dumps(self.to_dict(by_alias=by_alias)) + + @classmethod + def from_json(cls, json_str: str) -> SelectionConfigV3AllOf: + """Create an instance of SelectionConfigV3AllOf from a JSON string""" + return cls.from_dict(json.loads(json_str)) + + def to_dict(self, by_alias: bool = False): + """Returns the dictionary representation of the model""" + _dict = self.dict(by_alias=by_alias, + exclude={ + }, + exclude_none=True) + # override the default output from pydantic by calling `to_dict()` of each item in strategies (list) + _items = [] + if self.strategies: + for _item in self.strategies: + if _item: + _items.append(_item.to_dict(by_alias=by_alias)) + _dict['strategies' if by_alias else 'strategies'] = _items + return _dict + + @classmethod + def from_dict(cls, obj: dict) -> SelectionConfigV3AllOf: + """Create an instance of SelectionConfigV3AllOf from a dict""" + if obj is None: + return None + + if not isinstance(obj, dict): + return SelectionConfigV3AllOf.parse_obj(obj) + + # raise errors for additional fields in the input + for _key in obj.keys(): + if _key not in cls.__properties: + raise ValueError("Error due to additional fields (not defined in SelectionConfigV3AllOf) in the input: " + str(obj)) + + _obj = SelectionConfigV3AllOf.parse_obj({ + "strategies": [SelectionConfigV3Entry.from_dict(_item) for _item in obj.get("strategies")] if obj.get("strategies") is not None else None + }) + return _obj + diff --git a/lightly/openapi_generated/swagger_client/models/selection_config_v3_entry.py 
b/lightly/openapi_generated/swagger_client/models/selection_config_v3_entry.py new file mode 100644 index 000000000..834cbd317 --- /dev/null +++ b/lightly/openapi_generated/swagger_client/models/selection_config_v3_entry.py @@ -0,0 +1,88 @@ +# coding: utf-8 + +""" + Lightly API + + Lightly.ai enables you to do self-supervised learning in an easy and intuitive way. The lightly.ai OpenAPI spec defines how one can interact with our REST API to unleash the full potential of lightly.ai # noqa: E501 + + The version of the OpenAPI document: 1.0.0 + Contact: support@lightly.ai + Generated by OpenAPI Generator (https://openapi-generator.tech) + + Do not edit the class manually. +""" + + +from __future__ import annotations +import pprint +import re # noqa: F401 +import json + + + +from pydantic import Extra, BaseModel, Field +from lightly.openapi_generated.swagger_client.models.selection_config_v3_entry_input import SelectionConfigV3EntryInput +from lightly.openapi_generated.swagger_client.models.selection_config_v3_entry_strategy import SelectionConfigV3EntryStrategy + +class SelectionConfigV3Entry(BaseModel): + """ + SelectionConfigV3Entry + """ + input: SelectionConfigV3EntryInput = Field(...) + strategy: SelectionConfigV3EntryStrategy = Field(...) 
+ __properties = ["input", "strategy"] + + class Config: + """Pydantic configuration""" + allow_population_by_field_name = True + validate_assignment = True + use_enum_values = True + extra = Extra.forbid + + def to_str(self, by_alias: bool = False) -> str: + """Returns the string representation of the model""" + return pprint.pformat(self.dict(by_alias=by_alias)) + + def to_json(self, by_alias: bool = False) -> str: + """Returns the JSON representation of the model""" + return json.dumps(self.to_dict(by_alias=by_alias)) + + @classmethod + def from_json(cls, json_str: str) -> SelectionConfigV3Entry: + """Create an instance of SelectionConfigV3Entry from a JSON string""" + return cls.from_dict(json.loads(json_str)) + + def to_dict(self, by_alias: bool = False): + """Returns the dictionary representation of the model""" + _dict = self.dict(by_alias=by_alias, + exclude={ + }, + exclude_none=True) + # override the default output from pydantic by calling `to_dict()` of input + if self.input: + _dict['input' if by_alias else 'input'] = self.input.to_dict(by_alias=by_alias) + # override the default output from pydantic by calling `to_dict()` of strategy + if self.strategy: + _dict['strategy' if by_alias else 'strategy'] = self.strategy.to_dict(by_alias=by_alias) + return _dict + + @classmethod + def from_dict(cls, obj: dict) -> SelectionConfigV3Entry: + """Create an instance of SelectionConfigV3Entry from a dict""" + if obj is None: + return None + + if not isinstance(obj, dict): + return SelectionConfigV3Entry.parse_obj(obj) + + # raise errors for additional fields in the input + for _key in obj.keys(): + if _key not in cls.__properties: + raise ValueError("Error due to additional fields (not defined in SelectionConfigV3Entry) in the input: " + str(obj)) + + _obj = SelectionConfigV3Entry.parse_obj({ + "input": SelectionConfigV3EntryInput.from_dict(obj.get("input")) if obj.get("input") is not None else None, + "strategy": 
SelectionConfigV3EntryStrategy.from_dict(obj.get("strategy")) if obj.get("strategy") is not None else None + }) + return _obj + diff --git a/lightly/openapi_generated/swagger_client/models/selection_config_v3_entry_input.py b/lightly/openapi_generated/swagger_client/models/selection_config_v3_entry_input.py new file mode 100644 index 000000000..aac4c59da --- /dev/null +++ b/lightly/openapi_generated/swagger_client/models/selection_config_v3_entry_input.py @@ -0,0 +1,136 @@ +# coding: utf-8 + +""" + Lightly API + + Lightly.ai enables you to do self-supervised learning in an easy and intuitive way. The lightly.ai OpenAPI spec defines how one can interact with our REST API to unleash the full potential of lightly.ai # noqa: E501 + + The version of the OpenAPI document: 1.0.0 + Contact: support@lightly.ai + Generated by OpenAPI Generator (https://openapi-generator.tech) + + Do not edit the class manually. +""" + + +from __future__ import annotations +import pprint +import re # noqa: F401 +import json + + +from typing import List, Optional +from pydantic import Extra, BaseModel, Field, StrictInt, conlist, constr, validator +from lightly.openapi_generated.swagger_client.models.selection_input_predictions_name import SelectionInputPredictionsName +from lightly.openapi_generated.swagger_client.models.selection_input_type import SelectionInputType + +class SelectionConfigV3EntryInput(BaseModel): + """ + SelectionConfigV3EntryInput + """ + type: SelectionInputType = Field(...) + task: Optional[constr(strict=True)] = Field(None, description="Since we sometimes stitch together SelectionInputTask+ActiveLearningScoreType, they need to follow the same specs of ActiveLearningScoreType. However, this can be an empty string due to internal logic. 
") + score: Optional[constr(strict=True, min_length=1)] = Field(None, description="Type of active learning score") + key: Optional[constr(strict=True, min_length=1)] = None + name: Optional[SelectionInputPredictionsName] = None + dataset_id: Optional[constr(strict=True)] = Field(None, alias="datasetId", description="MongoDB ObjectId") + tag_name: Optional[constr(strict=True, min_length=3)] = Field(None, alias="tagName", description="The name of the tag") + random_seed: Optional[StrictInt] = Field(None, alias="randomSeed") + categories: Optional[conlist(constr(strict=True, min_length=1), min_items=1, unique_items=True)] = None + __properties = ["type", "task", "score", "key", "name", "datasetId", "tagName", "randomSeed", "categories"] + + @validator('task') + def task_validate_regular_expression(cls, value): + """Validates the regular expression""" + if value is None: + return value + + if not re.match(r"^[a-zA-Z0-9_+=,.@:\/-]*$", value): + raise ValueError(r"must validate the regular expression /^[a-zA-Z0-9_+=,.@:\/-]*$/") + return value + + @validator('score') + def score_validate_regular_expression(cls, value): + """Validates the regular expression""" + if value is None: + return value + + if not re.match(r"^[a-zA-Z0-9_+=,.@:\/-]*$", value): + raise ValueError(r"must validate the regular expression /^[a-zA-Z0-9_+=,.@:\/-]*$/") + return value + + @validator('dataset_id') + def dataset_id_validate_regular_expression(cls, value): + """Validates the regular expression""" + if value is None: + return value + + if not re.match(r"^[a-f0-9]{24}$", value): + raise ValueError(r"must validate the regular expression /^[a-f0-9]{24}$/") + return value + + @validator('tag_name') + def tag_name_validate_regular_expression(cls, value): + """Validates the regular expression""" + if value is None: + return value + + if not re.match(r"^[a-zA-Z0-9][a-zA-Z0-9 .:;=@_-]+$", value): + raise ValueError(r"must validate the regular expression /^[a-zA-Z0-9][a-zA-Z0-9 .:;=@_-]+$/") + return 
value + + class Config: + """Pydantic configuration""" + allow_population_by_field_name = True + validate_assignment = True + use_enum_values = True + extra = Extra.forbid + + def to_str(self, by_alias: bool = False) -> str: + """Returns the string representation of the model""" + return pprint.pformat(self.dict(by_alias=by_alias)) + + def to_json(self, by_alias: bool = False) -> str: + """Returns the JSON representation of the model""" + return json.dumps(self.to_dict(by_alias=by_alias)) + + @classmethod + def from_json(cls, json_str: str) -> SelectionConfigV3EntryInput: + """Create an instance of SelectionConfigV3EntryInput from a JSON string""" + return cls.from_dict(json.loads(json_str)) + + def to_dict(self, by_alias: bool = False): + """Returns the dictionary representation of the model""" + _dict = self.dict(by_alias=by_alias, + exclude={ + }, + exclude_none=True) + return _dict + + @classmethod + def from_dict(cls, obj: dict) -> SelectionConfigV3EntryInput: + """Create an instance of SelectionConfigV3EntryInput from a dict""" + if obj is None: + return None + + if not isinstance(obj, dict): + return SelectionConfigV3EntryInput.parse_obj(obj) + + # raise errors for additional fields in the input + for _key in obj.keys(): + if _key not in cls.__properties: + raise ValueError("Error due to additional fields (not defined in SelectionConfigV3EntryInput) in the input: " + str(obj)) + + _obj = SelectionConfigV3EntryInput.parse_obj({ + "type": obj.get("type"), + "task": obj.get("task"), + "score": obj.get("score"), + "key": obj.get("key"), + "name": obj.get("name"), + "dataset_id": obj.get("datasetId"), + "tag_name": obj.get("tagName"), + "random_seed": obj.get("randomSeed"), + "categories": obj.get("categories") + }) + return _obj + diff --git a/lightly/openapi_generated/swagger_client/models/selection_config_v3_entry_strategy.py b/lightly/openapi_generated/swagger_client/models/selection_config_v3_entry_strategy.py new file mode 100644 index 000000000..a9e00e1ea 
--- /dev/null +++ b/lightly/openapi_generated/swagger_client/models/selection_config_v3_entry_strategy.py @@ -0,0 +1,102 @@ +# coding: utf-8 + +""" + Lightly API + + Lightly.ai enables you to do self-supervised learning in an easy and intuitive way. The lightly.ai OpenAPI spec defines how one can interact with our REST API to unleash the full potential of lightly.ai # noqa: E501 + + The version of the OpenAPI document: 1.0.0 + Contact: support@lightly.ai + Generated by OpenAPI Generator (https://openapi-generator.tech) + + Do not edit the class manually. +""" + + +from __future__ import annotations +import pprint +import re # noqa: F401 +import json + + +from typing import Any, Dict, Optional, Union +from pydantic import Extra, BaseModel, Field, StrictFloat, StrictInt, confloat, conint +from lightly.openapi_generated.swagger_client.models.selection_config_v3_entry_strategy_all_of_target_range import SelectionConfigV3EntryStrategyAllOfTargetRange +from lightly.openapi_generated.swagger_client.models.selection_strategy_threshold_operation import SelectionStrategyThresholdOperation +from lightly.openapi_generated.swagger_client.models.selection_strategy_type_v3 import SelectionStrategyTypeV3 + +class SelectionConfigV3EntryStrategy(BaseModel): + """ + SelectionConfigV3EntryStrategy + """ + type: SelectionStrategyTypeV3 = Field(...) + stopping_condition_minimum_distance: Optional[Union[StrictFloat, StrictInt]] = None + threshold: Optional[Union[StrictFloat, StrictInt]] = None + operation: Optional[SelectionStrategyThresholdOperation] = None + target: Optional[Dict[str, Any]] = None + num_nearest_neighbors: Optional[Union[confloat(ge=2, strict=True), conint(ge=2, strict=True)]] = Field(None, alias="numNearestNeighbors", description="It is the number of nearest datapoints used to compute the typicality of each sample. 
") + stopping_condition_minimum_typicality: Optional[Union[confloat(gt=0, strict=True), conint(gt=0, strict=True)]] = Field(None, alias="stoppingConditionMinimumTypicality", description="It is the minimal allowed typicality of the selected samples. When the typicality of the selected samples reaches this, the selection stops. It should be a number between 0 and 1. ") + strength: Optional[Union[confloat(le=1000000000, ge=-1000000000, strict=True), conint(le=1000000000, ge=-1000000000, strict=True)]] = Field(None, description="The relative strength of this strategy compared to other strategies. The default value is 1.0, which is set in the worker for backwards compatibility. The minimum and maximum values of +-10^9 are used to prevent numerical issues. ") + stopping_condition_max_sum: Optional[Union[confloat(ge=0.0, strict=True), conint(ge=0, strict=True)]] = Field(None, alias="stoppingConditionMaxSum", description="When the sum of inputs reaches this, the selection stops. Only compatible with the WEIGHTS strategy. Similar to the stopping_condition_minimum_distance for the DIVERSITY strategy. 
") + target_range: Optional[SelectionConfigV3EntryStrategyAllOfTargetRange] = Field(None, alias="targetRange") + __properties = ["type", "stopping_condition_minimum_distance", "threshold", "operation", "target", "numNearestNeighbors", "stoppingConditionMinimumTypicality", "strength", "stoppingConditionMaxSum", "targetRange"] + + class Config: + """Pydantic configuration""" + allow_population_by_field_name = True + validate_assignment = True + use_enum_values = True + extra = Extra.forbid + + def to_str(self, by_alias: bool = False) -> str: + """Returns the string representation of the model""" + return pprint.pformat(self.dict(by_alias=by_alias)) + + def to_json(self, by_alias: bool = False) -> str: + """Returns the JSON representation of the model""" + return json.dumps(self.to_dict(by_alias=by_alias)) + + @classmethod + def from_json(cls, json_str: str) -> SelectionConfigV3EntryStrategy: + """Create an instance of SelectionConfigV3EntryStrategy from a JSON string""" + return cls.from_dict(json.loads(json_str)) + + def to_dict(self, by_alias: bool = False): + """Returns the dictionary representation of the model""" + _dict = self.dict(by_alias=by_alias, + exclude={ + }, + exclude_none=True) + # override the default output from pydantic by calling `to_dict()` of target_range + if self.target_range: + _dict['targetRange' if by_alias else 'target_range'] = self.target_range.to_dict(by_alias=by_alias) + return _dict + + @classmethod + def from_dict(cls, obj: dict) -> SelectionConfigV3EntryStrategy: + """Create an instance of SelectionConfigV3EntryStrategy from a dict""" + if obj is None: + return None + + if not isinstance(obj, dict): + return SelectionConfigV3EntryStrategy.parse_obj(obj) + + # raise errors for additional fields in the input + for _key in obj.keys(): + if _key not in cls.__properties: + raise ValueError("Error due to additional fields (not defined in SelectionConfigV3EntryStrategy) in the input: " + str(obj)) + + _obj = 
SelectionConfigV3EntryStrategy.parse_obj({ + "type": obj.get("type"), + "stopping_condition_minimum_distance": obj.get("stopping_condition_minimum_distance"), + "threshold": obj.get("threshold"), + "operation": obj.get("operation"), + "target": obj.get("target"), + "num_nearest_neighbors": obj.get("numNearestNeighbors"), + "stopping_condition_minimum_typicality": obj.get("stoppingConditionMinimumTypicality"), + "strength": obj.get("strength"), + "stopping_condition_max_sum": obj.get("stoppingConditionMaxSum"), + "target_range": SelectionConfigV3EntryStrategyAllOfTargetRange.from_dict(obj.get("targetRange")) if obj.get("targetRange") is not None else None + }) + return _obj + diff --git a/lightly/openapi_generated/swagger_client/models/selection_config_v3_entry_strategy_all_of.py b/lightly/openapi_generated/swagger_client/models/selection_config_v3_entry_strategy_all_of.py new file mode 100644 index 000000000..30adac0db --- /dev/null +++ b/lightly/openapi_generated/swagger_client/models/selection_config_v3_entry_strategy_all_of.py @@ -0,0 +1,93 @@ +# coding: utf-8 + +""" + Lightly API + + Lightly.ai enables you to do self-supervised learning in an easy and intuitive way. The lightly.ai OpenAPI spec defines how one can interact with our REST API to unleash the full potential of lightly.ai # noqa: E501 + + The version of the OpenAPI document: 1.0.0 + Contact: support@lightly.ai + Generated by OpenAPI Generator (https://openapi-generator.tech) + + Do not edit the class manually. 
+""" + + +from __future__ import annotations +import pprint +import re # noqa: F401 +import json + + +from typing import Optional, Union +from pydantic import Extra, BaseModel, Field, confloat, conint +from lightly.openapi_generated.swagger_client.models.selection_config_v3_entry_strategy_all_of_target_range import SelectionConfigV3EntryStrategyAllOfTargetRange +from lightly.openapi_generated.swagger_client.models.selection_strategy_type_v3 import SelectionStrategyTypeV3 + +class SelectionConfigV3EntryStrategyAllOf(BaseModel): + """ + SelectionConfigV3EntryStrategyAllOf + """ + type: SelectionStrategyTypeV3 = Field(...) + num_nearest_neighbors: Optional[Union[confloat(ge=2, strict=True), conint(ge=2, strict=True)]] = Field(None, alias="numNearestNeighbors", description="It is the number of nearest datapoints used to compute the typicality of each sample. ") + stopping_condition_minimum_typicality: Optional[Union[confloat(gt=0, strict=True), conint(gt=0, strict=True)]] = Field(None, alias="stoppingConditionMinimumTypicality", description="It is the minimal allowed typicality of the selected samples. When the typicality of the selected samples reaches this, the selection stops. It should be a number between 0 and 1. ") + strength: Optional[Union[confloat(le=1000000000, ge=-1000000000, strict=True), conint(le=1000000000, ge=-1000000000, strict=True)]] = Field(None, description="The relative strength of this strategy compared to other strategies. The default value is 1.0, which is set in the worker for backwards compatibility. The minimum and maximum values of +-10^9 are used to prevent numerical issues. ") + stopping_condition_max_sum: Optional[Union[confloat(ge=0.0, strict=True), conint(ge=0, strict=True)]] = Field(None, alias="stoppingConditionMaxSum", description="When the sum of inputs reaches this, the selection stops. Only compatible with the WEIGHTS strategy. Similar to the stopping_condition_minimum_distance for the DIVERSITY strategy. 
") + target_range: Optional[SelectionConfigV3EntryStrategyAllOfTargetRange] = Field(None, alias="targetRange") + __properties = ["type", "numNearestNeighbors", "stoppingConditionMinimumTypicality", "strength", "stoppingConditionMaxSum", "targetRange"] + + class Config: + """Pydantic configuration""" + allow_population_by_field_name = True + validate_assignment = True + use_enum_values = True + extra = Extra.forbid + + def to_str(self, by_alias: bool = False) -> str: + """Returns the string representation of the model""" + return pprint.pformat(self.dict(by_alias=by_alias)) + + def to_json(self, by_alias: bool = False) -> str: + """Returns the JSON representation of the model""" + return json.dumps(self.to_dict(by_alias=by_alias)) + + @classmethod + def from_json(cls, json_str: str) -> SelectionConfigV3EntryStrategyAllOf: + """Create an instance of SelectionConfigV3EntryStrategyAllOf from a JSON string""" + return cls.from_dict(json.loads(json_str)) + + def to_dict(self, by_alias: bool = False): + """Returns the dictionary representation of the model""" + _dict = self.dict(by_alias=by_alias, + exclude={ + }, + exclude_none=True) + # override the default output from pydantic by calling `to_dict()` of target_range + if self.target_range: + _dict['targetRange' if by_alias else 'target_range'] = self.target_range.to_dict(by_alias=by_alias) + return _dict + + @classmethod + def from_dict(cls, obj: dict) -> SelectionConfigV3EntryStrategyAllOf: + """Create an instance of SelectionConfigV3EntryStrategyAllOf from a dict""" + if obj is None: + return None + + if not isinstance(obj, dict): + return SelectionConfigV3EntryStrategyAllOf.parse_obj(obj) + + # raise errors for additional fields in the input + for _key in obj.keys(): + if _key not in cls.__properties: + raise ValueError("Error due to additional fields (not defined in SelectionConfigV3EntryStrategyAllOf) in the input: " + str(obj)) + + _obj = SelectionConfigV3EntryStrategyAllOf.parse_obj({ + "type": obj.get("type"), + 
"num_nearest_neighbors": obj.get("numNearestNeighbors"), + "stopping_condition_minimum_typicality": obj.get("stoppingConditionMinimumTypicality"), + "strength": obj.get("strength"), + "stopping_condition_max_sum": obj.get("stoppingConditionMaxSum"), + "target_range": SelectionConfigV3EntryStrategyAllOfTargetRange.from_dict(obj.get("targetRange")) if obj.get("targetRange") is not None else None + }) + return _obj + diff --git a/lightly/openapi_generated/swagger_client/models/selection_config_v3_entry_strategy_all_of_target_range.py b/lightly/openapi_generated/swagger_client/models/selection_config_v3_entry_strategy_all_of_target_range.py new file mode 100644 index 000000000..73495ae9c --- /dev/null +++ b/lightly/openapi_generated/swagger_client/models/selection_config_v3_entry_strategy_all_of_target_range.py @@ -0,0 +1,80 @@ +# coding: utf-8 + +""" + Lightly API + + Lightly.ai enables you to do self-supervised learning in an easy and intuitive way. The lightly.ai OpenAPI spec defines how one can interact with our REST API to unleash the full potential of lightly.ai # noqa: E501 + + The version of the OpenAPI document: 1.0.0 + Contact: support@lightly.ai + Generated by OpenAPI Generator (https://openapi-generator.tech) + + Do not edit the class manually. +""" + + +from __future__ import annotations +import pprint +import re # noqa: F401 +import json + + +from typing import Optional, Union +from pydantic import Extra, BaseModel, Field, confloat, conint + +class SelectionConfigV3EntryStrategyAllOfTargetRange(BaseModel): + """ + If specified, it tries to select samples such that their sum of inputs is >= min_sum and <= max_sum. Only compatible with the WEIGHTS strategy. + """ + min_sum: Optional[Union[confloat(ge=0.0, strict=True), conint(ge=0, strict=True)]] = Field(None, alias="minSum", description="Target minimum sum of inputs. 
") + max_sum: Optional[Union[confloat(ge=0.0, strict=True), conint(ge=0, strict=True)]] = Field(None, alias="maxSum", description="Target maximum sum of inputs. Must be >= min_sum. ") + __properties = ["minSum", "maxSum"] + + class Config: + """Pydantic configuration""" + allow_population_by_field_name = True + validate_assignment = True + use_enum_values = True + extra = Extra.forbid + + def to_str(self, by_alias: bool = False) -> str: + """Returns the string representation of the model""" + return pprint.pformat(self.dict(by_alias=by_alias)) + + def to_json(self, by_alias: bool = False) -> str: + """Returns the JSON representation of the model""" + return json.dumps(self.to_dict(by_alias=by_alias)) + + @classmethod + def from_json(cls, json_str: str) -> SelectionConfigV3EntryStrategyAllOfTargetRange: + """Create an instance of SelectionConfigV3EntryStrategyAllOfTargetRange from a JSON string""" + return cls.from_dict(json.loads(json_str)) + + def to_dict(self, by_alias: bool = False): + """Returns the dictionary representation of the model""" + _dict = self.dict(by_alias=by_alias, + exclude={ + }, + exclude_none=True) + return _dict + + @classmethod + def from_dict(cls, obj: dict) -> SelectionConfigV3EntryStrategyAllOfTargetRange: + """Create an instance of SelectionConfigV3EntryStrategyAllOfTargetRange from a dict""" + if obj is None: + return None + + if not isinstance(obj, dict): + return SelectionConfigV3EntryStrategyAllOfTargetRange.parse_obj(obj) + + # raise errors for additional fields in the input + for _key in obj.keys(): + if _key not in cls.__properties: + raise ValueError("Error due to additional fields (not defined in SelectionConfigV3EntryStrategyAllOfTargetRange) in the input: " + str(obj)) + + _obj = SelectionConfigV3EntryStrategyAllOfTargetRange.parse_obj({ + "min_sum": obj.get("minSum"), + "max_sum": obj.get("maxSum") + }) + return _obj + diff --git a/lightly/openapi_generated/swagger_client/models/annotation_state.py 
b/lightly/openapi_generated/swagger_client/models/selection_strategy_type_v3.py similarity index 59% rename from lightly/openapi_generated/swagger_client/models/annotation_state.py rename to lightly/openapi_generated/swagger_client/models/selection_strategy_type_v3.py index 48f5388ea..4d1984560 100644 --- a/lightly/openapi_generated/swagger_client/models/annotation_state.py +++ b/lightly/openapi_generated/swagger_client/models/selection_strategy_type_v3.py @@ -23,24 +23,24 @@ -class AnnotationState(str, Enum): +class SelectionStrategyTypeV3(str, Enum): """ - AnnotationState + SelectionStrategyTypeV3 """ """ allowed enum values """ - DRAFT = 'DRAFT' - OFFER_REQUESTED = 'OFFER_REQUESTED' - OFFER_RETURNED = 'OFFER_RETURNED' - ACCEPTED = 'ACCEPTED' - ACTIVE = 'ACTIVE' - COMPLETED = 'COMPLETED' + DIVERSITY = 'DIVERSITY' + WEIGHTS = 'WEIGHTS' + THRESHOLD = 'THRESHOLD' + BALANCE = 'BALANCE' + SIMILARITY = 'SIMILARITY' + TYPICALITY = 'TYPICALITY' @classmethod - def from_json(cls, json_str: str) -> 'AnnotationState': - """Create an instance of AnnotationState from a JSON string""" - return AnnotationState(json.loads(json_str)) + def from_json(cls, json_str: str) -> 'SelectionStrategyTypeV3': + """Create an instance of SelectionStrategyTypeV3 from a JSON string""" + return SelectionStrategyTypeV3(json.loads(json_str)) diff --git a/lightly/transforms/__init__.py b/lightly/transforms/__init__.py index 59fafe34d..33c70bdad 100644 --- a/lightly/transforms/__init__.py +++ b/lightly/transforms/__init__.py @@ -8,6 +8,11 @@ # Copyright (c) 2020. Lightly AG and its affiliates. 
# All Rights Reserved +from lightly.transforms.byol_transform import ( + BYOLTransform, + BYOLView1Transform, + BYOLView2Transform, +) from lightly.transforms.dino_transform import DINOTransform, DINOViewTransform from lightly.transforms.fast_siam_transform import FastSiamTransform from lightly.transforms.gaussian_blur import GaussianBlur diff --git a/lightly/transforms/byol_transform.py b/lightly/transforms/byol_transform.py new file mode 100644 index 000000000..26d9f1360 --- /dev/null +++ b/lightly/transforms/byol_transform.py @@ -0,0 +1,177 @@ +from typing import Dict, List, Optional, Tuple, Union + +import torchvision.transforms as T +from PIL.Image import Image +from torch import Tensor + +from lightly.transforms.gaussian_blur import GaussianBlur +from lightly.transforms.multi_view_transform import MultiViewTransform +from lightly.transforms.rotation import random_rotation_transform +from lightly.transforms.solarize import RandomSolarization +from lightly.transforms.utils import IMAGENET_NORMALIZE + + +class BYOLView1Transform: + def __init__( + self, + input_size: int = 224, + cj_prob: float = 0.8, + cj_strength: float = 1.0, + cj_bright: float = 0.4, + cj_contrast: float = 0.4, + cj_sat: float = 0.2, + cj_hue: float = 0.1, + min_scale: float = 0.08, + random_gray_scale: float = 0.2, + gaussian_blur: float = 1.0, + solarization_prob: float = 0.0, + kernel_size: Optional[float] = None, + sigmas: Tuple[float, float] = (0.1, 2), + vf_prob: float = 0.0, + hf_prob: float = 0.5, + rr_prob: float = 0.0, + rr_degrees: Union[None, float, Tuple[float, float]] = None, + normalize: Union[None, Dict[str, List[float]]] = IMAGENET_NORMALIZE, + ): + color_jitter = T.ColorJitter( + brightness=cj_strength * cj_bright, + contrast=cj_strength * cj_contrast, + saturation=cj_strength * cj_sat, + hue=cj_strength * cj_hue, + ) + + transform = [ + T.RandomResizedCrop(size=input_size, scale=(min_scale, 1.0)), + random_rotation_transform(rr_prob=rr_prob, rr_degrees=rr_degrees), + 
T.RandomHorizontalFlip(p=hf_prob), + T.RandomVerticalFlip(p=vf_prob), + T.RandomApply([color_jitter], p=cj_prob), + T.RandomGrayscale(p=random_gray_scale), + GaussianBlur(kernel_size=kernel_size, sigmas=sigmas, prob=gaussian_blur), + RandomSolarization(prob=solarization_prob), + T.ToTensor(), + ] + if normalize: + transform += [T.Normalize(mean=normalize["mean"], std=normalize["std"])] + self.transform = T.Compose(transform) + + def __call__(self, image: Union[Tensor, Image]) -> Tensor: + """ + Applies the transforms to the input image. + + Args: + image: + The input image to apply the transforms to. + + Returns: + The transformed image. + + """ + transformed: Tensor = self.transform(image) + return transformed + + +class BYOLView2Transform: + def __init__( + self, + input_size: int = 224, + cj_prob: float = 0.8, + cj_strength: float = 1.0, + cj_bright: float = 0.4, + cj_contrast: float = 0.4, + cj_sat: float = 0.2, + cj_hue: float = 0.1, + min_scale: float = 0.08, + random_gray_scale: float = 0.2, + gaussian_blur: float = 0.1, + solarization_prob: float = 0.2, + kernel_size: Optional[float] = None, + sigmas: Tuple[float, float] = (0.1, 2), + vf_prob: float = 0.0, + hf_prob: float = 0.5, + rr_prob: float = 0.0, + rr_degrees: Union[None, float, Tuple[float, float]] = None, + normalize: Union[None, Dict[str, List[float]]] = IMAGENET_NORMALIZE, + ): + color_jitter = T.ColorJitter( + brightness=cj_strength * cj_bright, + contrast=cj_strength * cj_contrast, + saturation=cj_strength * cj_sat, + hue=cj_strength * cj_hue, + ) + + transform = [ + T.RandomResizedCrop(size=input_size, scale=(min_scale, 1.0)), + random_rotation_transform(rr_prob=rr_prob, rr_degrees=rr_degrees), + T.RandomHorizontalFlip(p=hf_prob), + T.RandomVerticalFlip(p=vf_prob), + T.RandomApply([color_jitter], p=cj_prob), + T.RandomGrayscale(p=random_gray_scale), + GaussianBlur(kernel_size=kernel_size, sigmas=sigmas, prob=gaussian_blur), + RandomSolarization(prob=solarization_prob), + T.ToTensor(), + ] + if 
normalize: + transform += [T.Normalize(mean=normalize["mean"], std=normalize["std"])] + self.transform = T.Compose(transform) + + def __call__(self, image: Union[Tensor, Image]) -> Tensor: + """ + Applies the transforms to the input image. + + Args: + image: + The input image to apply the transforms to. + + Returns: + The transformed image. + + """ + transformed: Tensor = self.transform(image) + return transformed + + +class BYOLTransform(MultiViewTransform): + """Implements the transformations for BYOL[0]. + + Input to this transform: + PIL Image or Tensor. + + Output of this transform: + List of Tensor of length 2. + + Applies the following augmentations by default: + - Random resized crop + - Random horizontal flip + - Color jitter + - Random gray scale + - Gaussian blur + - Solarization + - ImageNet normalization + + Note that SimCLR v1 and v2 use similar augmentations. In detail, BYOL has + asymmetric gaussian blur and solarization. Furthermore, BYOL has weaker + color jitter compared to SimCLR. + + - [0]: Bootstrap Your Own Latent, 2020, https://arxiv.org/pdf/2006.07733.pdf + + Input to this transform: + PIL Image or Tensor. + + Output of this transform: + List of [tensor, tensor]. + + Attributes: + view_1_transform: The transform for the first view. + view_2_transform: The transform for the second view. 
+ """ + + def __init__( + self, + view_1_transform: Union[BYOLView1Transform, None] = None, + view_2_transform: Union[BYOLView2Transform, None] = None, + ): + # We need to initialize the transforms here + view_1_transform = view_1_transform or BYOLView1Transform() + view_2_transform = view_2_transform or BYOLView2Transform() + super().__init__(transforms=[view_1_transform, view_2_transform]) diff --git a/lightly/transforms/dino_transform.py b/lightly/transforms/dino_transform.py index de6a18657..38e719fbd 100644 --- a/lightly/transforms/dino_transform.py +++ b/lightly/transforms/dino_transform.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import PIL import torchvision.transforms as T @@ -117,7 +117,7 @@ def __init__( kernel_scale: Optional[float] = None, sigmas: Tuple[float, float] = (0.1, 2), solarization_prob: float = 0.2, - normalize: Union[None, dict] = IMAGENET_NORMALIZE, + normalize: Union[None, Dict[str, List[float]]] = IMAGENET_NORMALIZE, ): # first global crop global_transform_0 = DINOViewTransform( @@ -213,7 +213,7 @@ def __init__( kernel_scale: Optional[float] = None, sigmas: Tuple[float, float] = (0.1, 2), solarization_prob: float = 0.2, - normalize: Union[None, dict] = IMAGENET_NORMALIZE, + normalize: Union[None, Dict[str, List[float]]] = IMAGENET_NORMALIZE, ): transform = [ T.RandomResizedCrop( @@ -262,4 +262,5 @@ def __call__(self, image: Union[Tensor, Image]) -> Tensor: The transformed image. 
""" - return self.transform(image) + transformed: Tensor = self.transform(image) + return transformed diff --git a/lightly/transforms/fast_siam_transform.py b/lightly/transforms/fast_siam_transform.py index 7b560591f..c70b77de1 100644 --- a/lightly/transforms/fast_siam_transform.py +++ b/lightly/transforms/fast_siam_transform.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union from lightly.transforms.multi_view_transform import MultiViewTransform from lightly.transforms.simsiam_transform import SimSiamViewTransform @@ -89,7 +89,7 @@ def __init__( hf_prob: float = 0.5, rr_prob: float = 0.0, rr_degrees: Union[None, float, Tuple[float, float]] = None, - normalize: Union[None, dict] = IMAGENET_NORMALIZE, + normalize: Union[None, Dict[str, List[float]]] = IMAGENET_NORMALIZE, ): transforms = [ SimSiamViewTransform( diff --git a/lightly/transforms/gaussian_blur.py b/lightly/transforms/gaussian_blur.py index 180e6ff08..3f8f9e2d7 100644 --- a/lightly/transforms/gaussian_blur.py +++ b/lightly/transforms/gaussian_blur.py @@ -6,6 +6,7 @@ import numpy as np from PIL import ImageFilter +from PIL.Image import Image class GaussianBlur: @@ -47,7 +48,7 @@ def __init__( self.prob = prob self.sigmas = sigmas - def __call__(self, sample): + def __call__(self, sample: Image) -> Image: """Blurs the image with a given probability. 
Args: diff --git a/lightly/transforms/ijepa_transform.py b/lightly/transforms/ijepa_transform.py index 321dba66a..6bd002446 100644 --- a/lightly/transforms/ijepa_transform.py +++ b/lightly/transforms/ijepa_transform.py @@ -1,4 +1,4 @@ -from typing import Tuple, Union +from typing import Dict, List, Tuple, Union import torchvision.transforms as T from PIL.Image import Image @@ -30,7 +30,7 @@ def __init__( self, input_size: Union[int, Tuple[int, int]] = 224, min_scale: float = 0.2, - normalize: dict = IMAGENET_NORMALIZE, + normalize: Dict[str, List[float]] = IMAGENET_NORMALIZE, ): transforms = [ T.RandomResizedCrop( @@ -55,4 +55,5 @@ def __call__(self, image: Union[Tensor, Image]) -> Tensor: The transformed image. """ - return self.transform(image) + transformed: Tensor = self.transform(image) + return transformed diff --git a/lightly/transforms/image_grid_transform.py b/lightly/transforms/image_grid_transform.py index 30854c910..1822761a6 100644 --- a/lightly/transforms/image_grid_transform.py +++ b/lightly/transforms/image_grid_transform.py @@ -1,4 +1,4 @@ -from typing import List, Tuple, Union +from typing import List, Sequence, Union import torchvision.transforms as T from PIL.Image import Image @@ -19,7 +19,7 @@ class ImageGridTransform: grids. """ - def __init__(self, transforms): + def __init__(self, transforms: Sequence[T.Compose]): self.transforms = transforms def __call__(self, image: Union[Tensor, Image]) -> Union[List[Tensor], List[Image]]: diff --git a/lightly/transforms/jigsaw.py b/lightly/transforms/jigsaw.py index db9cbc37b..adebb808b 100644 --- a/lightly/transforms/jigsaw.py +++ b/lightly/transforms/jigsaw.py @@ -1,10 +1,14 @@ # Copyright (c) 2021. Lightly AG and its affiliates. 
# All Rights Reserved +from typing import List + import numpy as np import torch -from PIL import Image -from torchvision import transforms +from PIL import Image as Image +from PIL.Image import Image as PILImage +from torch import Tensor +from torchvision import transforms as T class Jigsaw(object): @@ -34,7 +38,11 @@ class Jigsaw(object): """ def __init__( - self, n_grid=3, img_size=255, crop_size=64, transform=transforms.ToTensor() + self, + n_grid: int = 3, + img_size: int = 255, + crop_size: int = 64, + transform: T.Compose = T.ToTensor(), ): self.n_grid = n_grid self.img_size = img_size @@ -47,7 +55,7 @@ def __init__( self.yy = np.reshape(yy * self.grid_size, (n_grid * n_grid,)) self.xx = np.reshape(xx * self.grid_size, (n_grid * n_grid,)) - def __call__(self, img): + def __call__(self, img: PILImage) -> Tensor: """Performs the Jigsaw augmentation Args: img: @@ -59,7 +67,7 @@ def __call__(self, img): r_x = np.random.randint(0, self.side + 1, self.n_grid * self.n_grid) r_y = np.random.randint(0, self.side + 1, self.n_grid * self.n_grid) img = np.asarray(img, np.uint8) - crops = [] + crops: List[PILImage] = [] for i in range(self.n_grid * self.n_grid): crops.append( img[ @@ -68,7 +76,9 @@ def __call__(self, img): :, ] ) - crops = [Image.fromarray(crop) for crop in crops] - crops = torch.stack([self.transform(crop) for crop in crops]) - crops = crops[np.random.permutation(self.n_grid**2)] - return crops + crop_images = [Image.fromarray(crop) for crop in crops] + crop_tensors: Tensor = torch.stack( + [self.transform(crop) for crop in crop_images] + ) + permutation: List[int] = np.random.permutation(self.n_grid**2).tolist() + return crop_tensors[permutation] diff --git a/lightly/transforms/mae_transform.py b/lightly/transforms/mae_transform.py index 3176f084e..50f9dd9f7 100644 --- a/lightly/transforms/mae_transform.py +++ b/lightly/transforms/mae_transform.py @@ -1,10 +1,9 @@ -from typing import List, Tuple, Union +from typing import Dict, List, Tuple, Union 
import torchvision.transforms as T from PIL.Image import Image from torch import Tensor -from lightly.transforms.multi_view_transform import MultiViewTransform from lightly.transforms.utils import IMAGENET_NORMALIZE @@ -37,7 +36,7 @@ def __init__( self, input_size: Union[int, Tuple[int, int]] = 224, min_scale: float = 0.2, - normalize: dict = IMAGENET_NORMALIZE, + normalize: Dict[str, List[float]] = IMAGENET_NORMALIZE, ): transforms = [ T.RandomResizedCrop( diff --git a/lightly/transforms/moco_transform.py b/lightly/transforms/moco_transform.py index 8ec8ade55..3f5f728fe 100644 --- a/lightly/transforms/moco_transform.py +++ b/lightly/transforms/moco_transform.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union from lightly.transforms.simclr_transform import SimCLRTransform from lightly.transforms.utils import IMAGENET_NORMALIZE @@ -83,7 +83,7 @@ def __init__( hf_prob: float = 0.5, rr_prob: float = 0.0, rr_degrees: Union[None, float, Tuple[float, float]] = None, - normalize: dict = IMAGENET_NORMALIZE, + normalize: Dict[str, List[float]] = IMAGENET_NORMALIZE, ): super().__init__( input_size=input_size, diff --git a/lightly/transforms/msn_transform.py b/lightly/transforms/msn_transform.py index d07130732..b8a78ac1c 100644 --- a/lightly/transforms/msn_transform.py +++ b/lightly/transforms/msn_transform.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import torchvision.transforms as T from PIL.Image import Image @@ -96,7 +96,7 @@ def __init__( random_gray_scale: float = 0.2, hf_prob: float = 0.5, vf_prob: float = 0.0, - normalize: dict = IMAGENET_NORMALIZE, + normalize: Dict[str, List[float]] = IMAGENET_NORMALIZE, ): random_view_transform = MSNViewTransform( crop_size=random_size, @@ -150,7 +150,7 @@ def __init__( random_gray_scale: float = 0.2, hf_prob: float = 0.5, vf_prob: float = 0.0, - normalize: dict = IMAGENET_NORMALIZE, + 
normalize: Dict[str, List[float]] = IMAGENET_NORMALIZE, ): color_jitter = T.ColorJitter( brightness=cj_strength * cj_bright, @@ -183,4 +183,5 @@ def __call__(self, image: Union[Tensor, Image]) -> Tensor: The transformed image. """ - return self.transform(image) + transformed: Tensor = self.transform(image) + return transformed diff --git a/lightly/transforms/multi_crop_transform.py b/lightly/transforms/multi_crop_transform.py index 307a507c6..2e546b22f 100644 --- a/lightly/transforms/multi_crop_transform.py +++ b/lightly/transforms/multi_crop_transform.py @@ -34,11 +34,11 @@ class MultiCropTranform(MultiViewTransform): def __init__( self, - crop_sizes: Tuple[int], - crop_counts: Tuple[int], - crop_min_scales: Tuple[float], - crop_max_scales: Tuple[float], - transforms, + crop_sizes: Tuple[int, ...], + crop_counts: Tuple[int, ...], + crop_min_scales: Tuple[float, ...], + crop_max_scales: Tuple[float, ...], + transforms: T.Compose, ): if len(crop_sizes) != len(crop_counts): raise ValueError( diff --git a/lightly/transforms/multi_view_transform.py b/lightly/transforms/multi_view_transform.py index a00f4fc00..62c9f2cb5 100644 --- a/lightly/transforms/multi_view_transform.py +++ b/lightly/transforms/multi_view_transform.py @@ -1,7 +1,8 @@ -from typing import List, Union +from typing import List, Sequence, Union from PIL.Image import Image from torch import Tensor +from torchvision import transforms as T class MultiViewTransform: @@ -13,7 +14,7 @@ class MultiViewTransform: """ - def __init__(self, transforms): + def __init__(self, transforms: Sequence[T.Compose]): self.transforms = transforms def __call__(self, image: Union[Tensor, Image]) -> Union[List[Tensor], List[Image]]: diff --git a/lightly/transforms/pirl_transform.py b/lightly/transforms/pirl_transform.py index b67e451c7..1d5ec7d57 100644 --- a/lightly/transforms/pirl_transform.py +++ b/lightly/transforms/pirl_transform.py @@ -1,8 +1,6 @@ -from typing import Tuple, Union +from typing import Dict, List, Tuple, 
Union import torchvision.transforms as T -from PIL.Image import Image -from torch import Tensor from lightly.transforms.jigsaw import Jigsaw from lightly.transforms.multi_view_transform import MultiViewTransform @@ -71,7 +69,7 @@ def __init__( random_gray_scale: float = 0.2, hf_prob: float = 0.5, n_grid: int = 3, - normalize: Union[None, dict] = IMAGENET_NORMALIZE, + normalize: Union[None, Dict[str, List[float]]] = IMAGENET_NORMALIZE, ): if isinstance(input_size, tuple): input_size_ = max(input_size) @@ -79,13 +77,17 @@ def __init__( input_size_ = input_size # Cropping and normalisation for non-transformed image - no_augment = T.Compose( - [ - T.RandomResizedCrop(size=input_size, scale=(min_scale, 1.0)), - T.ToTensor(), - T.Normalize(mean=normalize["mean"], std=normalize["std"]), - ] - ) + transforms_no_augment = [ + T.RandomResizedCrop(size=input_size, scale=(min_scale, 1.0)), + T.ToTensor(), + ] + + if normalize is not None: + transforms_no_augment.append( + T.Normalize(mean=normalize["mean"], std=normalize["std"]) + ) + + no_augment = T.Compose(transforms_no_augment) color_jitter = T.ColorJitter( brightness=cj_strength * cj_bright, @@ -95,21 +97,21 @@ def __init__( ) # Transform for transformed jigsaw image - transform = [ + transforms = [ T.RandomHorizontalFlip(p=hf_prob), T.RandomApply([color_jitter], p=cj_prob), T.RandomGrayscale(p=random_gray_scale), T.ToTensor(), ] - if normalize: - transform += [T.Normalize(mean=normalize["mean"], std=normalize["std"])] + if normalize is not None: + transforms.append(T.Normalize(mean=normalize["mean"], std=normalize["std"])) jigsaw = Jigsaw( n_grid=n_grid, img_size=input_size_, crop_size=int(input_size_ // n_grid), - transform=T.Compose(transform), + transform=T.Compose(transforms), ) super().__init__([no_augment, jigsaw]) diff --git a/lightly/transforms/py.typed b/lightly/transforms/py.typed new file mode 100644 index 000000000..e69de29bb diff --git a/lightly/transforms/random_crop_and_flip_with_grid.py 
b/lightly/transforms/random_crop_and_flip_with_grid.py index 6fbb92a33..60a66226f 100644 --- a/lightly/transforms/random_crop_and_flip_with_grid.py +++ b/lightly/transforms/random_crop_and_flip_with_grid.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Dict, List, Tuple +from typing import Tuple import torch import torchvision.transforms as T @@ -28,7 +28,7 @@ class Location: vertical_flip: bool = False -class RandomResizedCropWithLocation(T.RandomResizedCrop): +class RandomResizedCropWithLocation(T.RandomResizedCrop): # type: ignore[misc] # Class cannot subclass "RandomResizedCrop" (has type "Any") """ Do a random resized crop and return both the resulting image and the location. See base class. @@ -59,7 +59,7 @@ def forward(self, img: Image.Image) -> Tuple[Image.Image, Location]: return img, location -class RandomHorizontalFlipWithLocation(T.RandomHorizontalFlip): +class RandomHorizontalFlipWithLocation(T.RandomHorizontalFlip): # type: ignore[misc] # Class cannot subclass "RandomHorizontalFlip" (has type "Any") """See base class.""" def forward( @@ -84,7 +84,7 @@ def forward( return img, location -class RandomVerticalFlipWithLocation(T.RandomVerticalFlip): +class RandomVerticalFlipWithLocation(T.RandomVerticalFlip): # type: ignore[misc] # Class cannot subclass "RandomVerticalFlip" (has type "Any") """See base class.""" def forward( diff --git a/lightly/transforms/simclr_transform.py b/lightly/transforms/simclr_transform.py index 8c39591c7..d975b0771 100644 --- a/lightly/transforms/simclr_transform.py +++ b/lightly/transforms/simclr_transform.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import torchvision.transforms as T from PIL.Image import Image @@ -102,7 +102,7 @@ def __init__( hf_prob: float = 0.5, rr_prob: float = 0.0, rr_degrees: Union[None, float, Tuple[float, float]] = None, - normalize: Union[None, dict] = IMAGENET_NORMALIZE, + normalize: Union[None, Dict[str, 
List[float]]] = IMAGENET_NORMALIZE, ): view_transform = SimCLRViewTransform( input_size=input_size, @@ -145,7 +145,7 @@ def __init__( hf_prob: float = 0.5, rr_prob: float = 0.0, rr_degrees: Union[None, float, Tuple[float, float]] = None, - normalize: Union[None, dict] = IMAGENET_NORMALIZE, + normalize: Union[None, Dict[str, List[float]]] = IMAGENET_NORMALIZE, ): color_jitter = T.ColorJitter( brightness=cj_strength * cj_bright, @@ -180,4 +180,5 @@ def __call__(self, image: Union[Tensor, Image]) -> Tensor: The transformed image. """ - return self.transform(image) + transformed: Tensor = self.transform(image) + return transformed diff --git a/lightly/transforms/simsiam_transform.py b/lightly/transforms/simsiam_transform.py index 379d55242..4f2607480 100644 --- a/lightly/transforms/simsiam_transform.py +++ b/lightly/transforms/simsiam_transform.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import torchvision.transforms as T from PIL.Image import Image @@ -91,7 +91,7 @@ def __init__( hf_prob: float = 0.5, rr_prob: float = 0.0, rr_degrees: Union[None, float, Tuple[float, float]] = None, - normalize: Union[None, dict] = IMAGENET_NORMALIZE, + normalize: Union[None, Dict[str, List[float]]] = IMAGENET_NORMALIZE, ): view_transform = SimSiamViewTransform( input_size=input_size, @@ -134,7 +134,7 @@ def __init__( hf_prob: float = 0.5, rr_prob: float = 0.0, rr_degrees: Union[None, float, Tuple[float, float]] = None, - normalize: Union[None, dict] = IMAGENET_NORMALIZE, + normalize: Union[None, Dict[str, List[float]]] = IMAGENET_NORMALIZE, ): color_jitter = T.ColorJitter( brightness=cj_strength * cj_bright, @@ -169,4 +169,5 @@ def __call__(self, image: Union[Tensor, Image]) -> Tensor: The transformed image. 
""" - return self.transform(image) + transformed: Tensor = self.transform(image) + return transformed diff --git a/lightly/transforms/smog_transform.py b/lightly/transforms/smog_transform.py index 6f8958377..69d8de64c 100644 --- a/lightly/transforms/smog_transform.py +++ b/lightly/transforms/smog_transform.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import torchvision.transforms as T from PIL.Image import Image @@ -88,7 +88,7 @@ def __init__( cj_sat: float = 0.4, cj_hue: float = 0.2, random_gray_scale: float = 0.2, - normalize: Union[None, dict] = IMAGENET_NORMALIZE, + normalize: Union[None, Dict[str, List[float]]] = IMAGENET_NORMALIZE, ): transforms = [] for i in range(len(crop_sizes)): @@ -137,7 +137,7 @@ def __init__( cj_sat: float = 0.4, cj_hue: float = 0.2, random_gray_scale: float = 0.2, - normalize: Union[None, dict] = IMAGENET_NORMALIZE, + normalize: Union[None, Dict[str, List[float]]] = IMAGENET_NORMALIZE, ): color_jitter = T.ColorJitter( brightness=cj_strength * cj_bright, @@ -175,4 +175,5 @@ def __call__(self, image: Union[Tensor, Image]) -> Tensor: The transformed image. 
""" - return self.transform(image) + transformed: Tensor = self.transform(image) + return transformed diff --git a/lightly/transforms/solarize.py b/lightly/transforms/solarize.py index bbd460899..3e640a404 100644 --- a/lightly/transforms/solarize.py +++ b/lightly/transforms/solarize.py @@ -3,6 +3,7 @@ import numpy as np from PIL import ImageOps +from PIL.Image import Image as PILImage class RandomSolarization(object): @@ -22,7 +23,7 @@ def __init__(self, prob: float = 0.5, threshold: int = 128): self.prob = prob self.threshold = threshold - def __call__(self, sample): + def __call__(self, sample: PILImage) -> PILImage: """Solarizes the given input image Args: diff --git a/lightly/transforms/swav_transform.py b/lightly/transforms/swav_transform.py index 2cbabf94f..f7000f945 100644 --- a/lightly/transforms/swav_transform.py +++ b/lightly/transforms/swav_transform.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import torchvision.transforms as T from PIL.Image import Image @@ -96,7 +96,7 @@ def __init__( gaussian_blur: float = 0.5, kernel_size: Optional[float] = None, sigmas: Tuple[float, float] = (0.1, 2), - normalize: Union[None, dict] = IMAGENET_NORMALIZE, + normalize: Union[None, Dict[str, List[float]]] = IMAGENET_NORMALIZE, ): transforms = SwaVViewTransform( hf_prob=hf_prob, @@ -142,7 +142,7 @@ def __init__( gaussian_blur: float = 0.5, kernel_size: Optional[float] = None, sigmas: Tuple[float, float] = (0.1, 2), - normalize: Union[None, dict] = IMAGENET_NORMALIZE, + normalize: Union[None, Dict[str, List[float]]] = IMAGENET_NORMALIZE, ): color_jitter = T.ColorJitter( brightness=cj_strength * cj_bright, @@ -178,4 +178,5 @@ def __call__(self, image: Union[Tensor, Image]) -> Tensor: The transformed image. 
""" - return self.transform(image) + transformed: Tensor = self.transform(image) + return transformed diff --git a/lightly/transforms/vicreg_transform.py b/lightly/transforms/vicreg_transform.py index e2234f6c4..c7cc0b270 100644 --- a/lightly/transforms/vicreg_transform.py +++ b/lightly/transforms/vicreg_transform.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import torchvision.transforms as T from PIL.Image import Image @@ -97,7 +97,7 @@ def __init__( hf_prob: float = 0.5, rr_prob: float = 0.0, rr_degrees: Union[None, float, Tuple[float, float]] = None, - normalize: Union[None, dict] = IMAGENET_NORMALIZE, + normalize: Union[None, Dict[str, List[float]]] = IMAGENET_NORMALIZE, ): view_transform = VICRegViewTransform( input_size=input_size, @@ -142,7 +142,7 @@ def __init__( hf_prob: float = 0.5, rr_prob: float = 0.0, rr_degrees: Union[None, float, Tuple[float, float]] = None, - normalize: Union[None, dict] = IMAGENET_NORMALIZE, + normalize: Union[None, Dict[str, List[float]]] = IMAGENET_NORMALIZE, ): color_jitter = T.ColorJitter( brightness=cj_strength * cj_bright, @@ -178,4 +178,5 @@ def __call__(self, image: Union[Tensor, Image]) -> Tensor: The transformed image. 
""" - return self.transform(image) + transformed: Tensor = self.transform(image) + return transformed diff --git a/lightly/transforms/vicregl_transform.py b/lightly/transforms/vicregl_transform.py index 70baf139d..335c03c34 100644 --- a/lightly/transforms/vicregl_transform.py +++ b/lightly/transforms/vicregl_transform.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import torchvision.transforms as T from PIL.Image import Image @@ -123,7 +123,7 @@ def __init__( cj_sat: float = 0.4, cj_hue: float = 0.2, random_gray_scale: float = 0.2, - normalize: Union[None, dict] = IMAGENET_NORMALIZE, + normalize: Union[None, Dict[str, List[float]]] = IMAGENET_NORMALIZE, ): global_transform = ( RandomResizedCropAndFlip( @@ -189,7 +189,7 @@ def __init__( cj_sat: float = 0.4, cj_hue: float = 0.2, random_gray_scale: float = 0.2, - normalize: Union[None, dict] = IMAGENET_NORMALIZE, + normalize: Union[None, Dict[str, List[float]]] = IMAGENET_NORMALIZE, ): color_jitter = T.ColorJitter( brightness=cj_strength * cj_bright, @@ -223,4 +223,5 @@ def __call__(self, image: Union[Tensor, Image]) -> Tensor: Returns: The transformed image. 
""" - return self.transform(image) + transformed: Tensor = self.transform(image) + return transformed diff --git a/lightly/utils/benchmarking/linear_classifier.py b/lightly/utils/benchmarking/linear_classifier.py index b47c94560..3fb76c6e6 100644 --- a/lightly/utils/benchmarking/linear_classifier.py +++ b/lightly/utils/benchmarking/linear_classifier.py @@ -1,9 +1,9 @@ -from typing import Dict, Tuple +from typing import Any, Dict, List, Tuple, Union from pytorch_lightning import LightningModule from torch import Tensor from torch.nn import CrossEntropyLoss, Linear, Module -from torch.optim import SGD +from torch.optim import SGD, Optimizer from lightly.models.utils import activate_requires_grad, deactivate_requires_grad from lightly.utils.benchmarking.topk import mean_topk_accuracy @@ -94,9 +94,12 @@ def __init__( def forward(self, images: Tensor) -> Tensor: features = self.model.forward(images).flatten(start_dim=1) - return self.classification_head(features) + output: Tensor = self.classification_head(features) + return output - def shared_step(self, batch, batch_idx) -> Tuple[Tensor, Dict[int, Tensor]]: + def shared_step( + self, batch: Tuple[Tensor, ...], batch_idx: int + ) -> Tuple[Tensor, Dict[int, Tensor]]: images, targets = batch[0], batch[1] predictions = self.forward(images) loss = self.criterion(predictions, targets) @@ -104,7 +107,7 @@ def shared_step(self, batch, batch_idx) -> Tuple[Tensor, Dict[int, Tensor]]: topk = mean_topk_accuracy(predicted_labels, targets, k=self.topk) return loss, topk - def training_step(self, batch, batch_idx) -> Tensor: + def training_step(self, batch: Tuple[Tensor, ...], batch_idx: int) -> Tensor: loss, topk = self.shared_step(batch=batch, batch_idx=batch_idx) batch_size = len(batch[1]) log_dict = {f"train_top{k}": acc for k, acc in topk.items()} @@ -114,7 +117,7 @@ def training_step(self, batch, batch_idx) -> Tensor: self.log_dict(log_dict, sync_dist=True, batch_size=batch_size) return loss - def validation_step(self, batch, 
batch_idx) -> Tensor: + def validation_step(self, batch: Tuple[Tensor, ...], batch_idx: int) -> Tensor: loss, topk = self.shared_step(batch=batch, batch_idx=batch_idx) batch_size = len(batch[1]) log_dict = {f"val_top{k}": acc for k, acc in topk.items()} @@ -122,7 +125,9 @@ def validation_step(self, batch, batch_idx) -> Tensor: self.log_dict(log_dict, prog_bar=True, sync_dist=True, batch_size=batch_size) return loss - def configure_optimizers(self): + def configure_optimizers( + self, + ) -> Tuple[List[Optimizer], List[Dict[str, Union[Any, str]]]]: parameters = list(self.classification_head.parameters()) if not self.freeze_model: parameters += self.model.parameters() @@ -136,7 +141,7 @@ def configure_optimizers(self): "scheduler": CosineWarmupScheduler( optimizer=optimizer, warmup_epochs=0, - max_epochs=self.trainer.estimated_stepping_batches, + max_epochs=int(self.trainer.estimated_stepping_batches), ), "interval": "step", } diff --git a/lightly/utils/benchmarking/metric_callback.py b/lightly/utils/benchmarking/metric_callback.py index f632f3d2a..b635b2fbb 100644 --- a/lightly/utils/benchmarking/metric_callback.py +++ b/lightly/utils/benchmarking/metric_callback.py @@ -1,9 +1,11 @@ -from typing import Dict, List +from typing import Dict, List, Union from pytorch_lightning import LightningModule, Trainer from pytorch_lightning.callbacks import Callback from torch import Tensor +MetricValue = Union[Tensor, float] + class MetricCallback(Callback): """Callback that collects log metrics from the LightningModule and stores them after @@ -39,7 +41,7 @@ class MetricCallback(Callback): >>> max_val_acc = max(metric_callback.val_metrics["val_acc"]) """ - def __init__(self): + def __init__(self) -> None: super().__init__() self.train_metrics: Dict[str, List[float]] = {} self.val_metrics: Dict[str, List[float]] = {} @@ -58,7 +60,6 @@ def _append_metrics( self, metrics_dict: Dict[str, List[float]], trainer: Trainer ) -> None: for name, value in trainer.callback_metrics.items(): 
- # Only store scalar values. if isinstance(value, float) or ( isinstance(value, Tensor) and value.numel() == 1 ): diff --git a/lightly/utils/bounding_box.py b/lightly/utils/bounding_box.py index 7f22657a0..322b13b16 100644 --- a/lightly/utils/bounding_box.py +++ b/lightly/utils/bounding_box.py @@ -1,5 +1,6 @@ """ Bounding Box Utils """ +from __future__ import annotations # Copyright (c) 2020. Lightly AG and its affiliates. # All Rights Reserved @@ -43,7 +44,7 @@ def __init__( if clip_values: - def clip_to_0_1(value): + def clip_to_0_1(value: float) -> float: return min(1, max(0, value)) x0 = clip_to_0_1(x0) @@ -75,7 +76,7 @@ def clip_to_0_1(value): self.y1 = y1 @classmethod - def from_x_y_w_h(cls, x: float, y: float, w: float, h: float): + def from_x_y_w_h(cls, x: float, y: float, w: float, h: float) -> BoundingBox: """Helper to convert from bounding box format with width and height. Examples: @@ -85,7 +86,9 @@ def from_x_y_w_h(cls, x: float, y: float, w: float, h: float): return cls(x, y, x + w, y + h) @classmethod - def from_yolo_label(cls, x_center: float, y_center: float, w: float, h: float): + def from_yolo_label( + cls, x_center: float, y_center: float, w: float, h: float + ) -> BoundingBox: """Helper to convert from yolo label format x_center, y_center, w, h --> x0, y0, x1, y1 @@ -102,16 +105,16 @@ def from_yolo_label(cls, x_center: float, y_center: float, w: float, h: float): ) @property - def width(self): + def width(self) -> float: """Returns the width of the bounding box relative to the image size.""" return self.x1 - self.x0 @property - def height(self): + def height(self) -> float: """Returns the height of the bounding box relative to the image size.""" return self.y1 - self.y0 @property - def area(self): + def area(self) -> float: """Returns the area of the bounding box relative to the area of the image.""" return self.width * self.height diff --git a/lightly/utils/cropping/__init__.py b/lightly/utils/cropping/__init__.py new file mode 100644 index 
000000000..e69de29bb diff --git a/lightly/utils/dist.py b/lightly/utils/dist.py index 64f9ed351..2c20750c6 100644 --- a/lightly/utils/dist.py +++ b/lightly/utils/dist.py @@ -68,3 +68,28 @@ def eye_rank(n: int, device: Optional[torch.device] = None) -> torch.Tensor: diag_mask = torch.zeros((n, n * world_size()), dtype=torch.bool) diag_mask[(rows, cols)] = True return diag_mask + + +def rank_zero_only(fn): + """Decorator that only runs the function on the process with rank 0. + + Example: + >>> @rank_zero_only + >>> def print_rank_zero(message: str): + >>> print(message) + >>> + >>> print_rank_zero("Hello from rank 0!") + + """ + + def wrapped(*args, **kwargs): + if rank() == 0: + return fn(*args, **kwargs) + + return wrapped + + +@rank_zero_only +def print_rank_zero(*args, **kwargs) -> None: + """Equivalent to print, but only runs on the process with rank 0.""" + print(*args, **kwargs) diff --git a/lightly/utils/embeddings_2d.py b/lightly/utils/embeddings_2d.py index 0b40a4bac..046f2d874 100644 --- a/lightly/utils/embeddings_2d.py +++ b/lightly/utils/embeddings_2d.py @@ -3,7 +3,12 @@ # Copyright (c) 2020. Lightly AG and its affiliates. # All Rights Reserved +from __future__ import annotations + +from typing import Optional, Tuple + import numpy as np +from numpy.typing import NDArray class PCA(object): @@ -18,11 +23,11 @@ class PCA(object): def __init__(self, n_components: int = 2, eps: float = 1e-10): self.n_components = n_components - self.mean = None - self.w = None + self.mean: Optional[NDArray[np.float32]] = None + self.w: Optional[NDArray[np.float32]] = None self.eps = eps - def fit(self, X: np.ndarray): + def fit(self, X: NDArray[np.float32]) -> PCA: """Fits PCA to data in X. 
Args: @@ -35,6 +40,7 @@ def fit(self, X: np.ndarray): """ X = X.astype(np.float32) self.mean = X.mean(axis=0) + assert self.mean is not None X = X - self.mean + self.eps cov = np.cov(X.T) / X.shape[0] v, w = np.linalg.eig(cov) @@ -43,7 +49,7 @@ def fit(self, X: np.ndarray): self.w = w return self - def transform(self, X: np.ndarray): + def transform(self, X: NDArray[np.float32]) -> NDArray[np.float32]: """Uses PCA to transform data in X. Args: @@ -53,13 +59,22 @@ def transform(self, X: np.ndarray): Returns: Numpy array of n x p datapoints where p <= d. + Raises: + ValueError: If PCA was not fitted before. """ + if self.mean is None or self.w is None: + raise ValueError("PCA not fitted yet. Call fit() before transform().") X = X.astype(np.float32) X = X - self.mean + self.eps - return X.dot(self.w)[:, : self.n_components] + transformed: NDArray[np.float32] = X.dot(self.w)[:, : self.n_components] + return np.asarray(transformed) -def fit_pca(embeddings: np.ndarray, n_components: int = 2, fraction: float = None): +def fit_pca( + embeddings: NDArray[np.float32], + n_components: int = 2, + fraction: Optional[float] = None, +) -> PCA: """Fits PCA to randomly selected subset of embeddings. For large datasets, it can be unfeasible to perform PCA on the whole data. 
diff --git a/lightly/utils/hipify.py b/lightly/utils/hipify.py index fd2d7097b..44dc32aed 100644 --- a/lightly/utils/hipify.py +++ b/lightly/utils/hipify.py @@ -1,6 +1,6 @@ import copy import warnings -from typing import Type +from typing import Type, Union class bcolors: @@ -14,15 +14,19 @@ class bcolors: UNDERLINE = "\033[4m" -def _custom_formatwarning(msg, *args, **kwargs): +def _custom_formatwarning( + message: Union[str, Warning], + category: Type[Warning], + filename: str, + lineno: int, + line: Union[str, None] = None, +) -> str: # ignore everything except the message - return f"{bcolors.WARNING}{msg}{bcolors.WARNING}\n" + return f"{bcolors.WARNING}{message}{bcolors.WARNING}\n" -def print_as_warning(message: str, warning_class: Type[Warning] = UserWarning): +def print_as_warning(message: str, warning_class: Type[Warning] = UserWarning) -> None: old_format = copy.copy(warnings.formatwarning) - warnings.formatwarning = _custom_formatwarning warnings.warn(message, warning_class) - warnings.formatwarning = old_format diff --git a/lightly/utils/io.py b/lightly/utils/io.py index 00ca63a1c..74240bf6e 100644 --- a/lightly/utils/io.py +++ b/lightly/utils/io.py @@ -7,35 +7,13 @@ import json import re from itertools import compress -from typing import Dict, List, Tuple +from typing import Any, Dict, List, Tuple, Union import numpy as np +from numpy.typing import NDArray -INVALID_FILENAME_CHARACTERS = [","] - -def _is_valid_filename(filename: str) -> bool: - """Returns False if the filename is misformatted.""" - for character in INVALID_FILENAME_CHARACTERS: - if character in filename: - return False - return True - - -def check_filenames(filenames: List[str]): - """Raises an error if one of the filenames is misformatted - - Args: - filenames: - A list of string being filenames - - """ - invalid_filenames = [f for f in filenames if not _is_valid_filename(f)] - if len(invalid_filenames) > 0: - raise ValueError(f"Invalid filename(s): {invalid_filenames}") - - -def 
check_embeddings(path: str, remove_additional_columns: bool = False): +def check_embeddings(path: str, remove_additional_columns: bool = False) -> None: """Raises an error if the embeddings csv file has not the correct format Use this check whenever you want to upload an embedding to the Lightly @@ -118,8 +96,8 @@ def check_embeddings(path: str, remove_additional_columns: bool = False): def save_embeddings( - path: str, embeddings: np.ndarray, labels: List[int], filenames: List[str] -): + path: str, embeddings: NDArray[np.float64], labels: List[int], filenames: List[str] +) -> None: """Saves embeddings in a csv file in a Lightly compatible format. Creates a csv file at the location specified by path and saves embeddings, @@ -146,7 +124,6 @@ def save_embeddings( >>> labels, >>> filenames) """ - check_filenames(filenames) n_embeddings = len(embeddings) n_filenames = len(filenames) @@ -167,7 +144,7 @@ def save_embeddings( writer.writerow([filename] + list(embedding) + [str(label)]) -def load_embeddings(path: str): +def load_embeddings(path: str) -> Tuple[NDArray[np.float64], List[int], List[str]]: """Loads embeddings from a csv file in a Lightly compatible format. Args: @@ -202,15 +179,13 @@ def load_embeddings(path: str): # read embeddings embeddings.append(row[1:-1]) - check_filenames(filenames) - - embeddings = np.array(embeddings).astype(np.float32) - return embeddings, labels, filenames + embedding_array = np.array(embeddings).astype(np.float64) + return embedding_array, labels, filenames def load_embeddings_as_dict( path: str, embedding_name: str = "default", return_all: bool = False -): +) -> Union[Any, Tuple[Any, NDArray[np.float64], List[int], List[str]]]: """Loads embeddings from csv and store it in a dictionary for transfer. 
Loads embeddings to a dictionary which can be serialized and sent to the @@ -245,10 +220,13 @@ def load_embeddings_as_dict( embeddings, labels, filenames = load_embeddings(path) # build dictionary - data = {"embeddingName": embedding_name, "embeddings": []} - for embedding, filename, label in zip(embeddings, filenames, labels): - item = {"fileName": filename, "value": embedding.tolist(), "label": label} - data["embeddings"].append(item) + data = { + "embeddingName": embedding_name, + "embeddings": [ + {"fileName": filename, "value": embedding.tolist(), "label": label} + for embedding, filename, label in zip(embeddings, filenames, labels) + ], + } # return embeddings along with dictionary if return_all: @@ -258,7 +236,10 @@ def load_embeddings_as_dict( class COCO_ANNOTATION_KEYS: - """Enum of coco annotation keys complemented with a key for custom metadata.""" + """Enum of coco annotation keys complemented with a key for custom metadata. + + :meta private: # Skip docstring generation + """ # image keys images: str = "images" @@ -270,7 +251,9 @@ class COCO_ANNOTATION_KEYS: custom_metadata_image_id: str = "image_id" -def format_custom_metadata(custom_metadata: List[Tuple[str, Dict]]): +def format_custom_metadata( + custom_metadata: List[Tuple[str, Any]] +) -> Dict[str, List[Any]]: """Transforms custom metadata into a format which can be handled by Lightly. 
Args: @@ -292,8 +275,9 @@ def format_custom_metadata(custom_metadata: List[Tuple[str, Dict]]): >>> > 'metadata': [{'image_id': 0, 'number_of_people': 1}, {'image_id': 1, 'number_of_people': 3}] >>> > } + :meta private: # Skip docstring generation """ - formatted = { + formatted: Dict[str, List[Any]] = { COCO_ANNOTATION_KEYS.images: [], COCO_ANNOTATION_KEYS.custom_metadata: [], } @@ -315,7 +299,7 @@ def format_custom_metadata(custom_metadata: List[Tuple[str, Dict]]): return formatted -def save_custom_metadata(path: str, custom_metadata: List[Tuple[str, Dict]]): +def save_custom_metadata(path: str, custom_metadata: List[Tuple[str, Any]]) -> None: """Saves custom metadata in a .json. Args: @@ -324,6 +308,7 @@ def save_custom_metadata(path: str, custom_metadata: List[Tuple[str, Dict]]): custom_metadata: List of tuples (filename, metadata) where metadata is a dictionary. + :meta private: # Skip docstring generation """ formatted = format_custom_metadata(custom_metadata) with open(path, "w") as f: @@ -333,7 +318,7 @@ def save_custom_metadata(path: str, custom_metadata: List[Tuple[str, Dict]]): def save_tasks( path: str, tasks: List[str], -): +) -> None: """Saves a list of prediction task names in the right format. Args: @@ -347,7 +332,7 @@ def save_tasks( json.dump(tasks, f) -def save_schema(path: str, task_type: str, ids: List[int], names: List[str]): +def save_schema(path: str, task_type: str, ids: List[int], names: List[str]) -> None: """Saves a prediction schema in the right format. 
Args: diff --git a/lightly/utils/lars.py b/lightly/utils/lars.py index af14ff028..063149d36 100644 --- a/lightly/utils/lars.py +++ b/lightly/utils/lars.py @@ -1,5 +1,8 @@ +from typing import Any, Callable, Dict, Optional, Union + import torch -from torch.optim.optimizer import Optimizer, required +from torch import Tensor +from torch.optim.optimizer import Optimizer, required # type: ignore[attr-defined] class LARS(Optimizer): @@ -65,7 +68,7 @@ class LARS(Optimizer): def __init__( self, - params, + params: Any, lr: float = required, momentum: float = 0, dampening: float = 0, @@ -95,14 +98,14 @@ def __init__( super().__init__(params, defaults) - def __setstate__(self, state): + def __setstate__(self, state: Dict[str, Any]) -> None: super().__setstate__(state) for group in self.param_groups: group.setdefault("nesterov", False) @torch.no_grad() - def step(self, closure=None): + def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]: """Performs a single optimization step. Args: diff --git a/lightly/utils/reordering.py b/lightly/utils/reordering.py index 22311d8b6..baaf286d3 100644 --- a/lightly/utils/reordering.py +++ b/lightly/utils/reordering.py @@ -1,7 +1,12 @@ -from typing import List, Sized +from typing import List, Sequence, TypeVar +_K = TypeVar("_K") +_V = TypeVar("_V") -def sort_items_by_keys(keys: List[any], items: List[any], sorted_keys: List[any]): + +def sort_items_by_keys( + keys: Sequence[_K], items: Sequence[_V], sorted_keys: Sequence[_K] +) -> List[_V]: """Sorts the items in the same order as the sorted keys. 
Args: diff --git a/lightly/utils/scheduler.py b/lightly/utils/scheduler.py index 0f200146b..a698e09ac 100644 --- a/lightly/utils/scheduler.py +++ b/lightly/utils/scheduler.py @@ -1,14 +1,19 @@ import warnings +from typing import Optional import numpy as np import torch def cosine_schedule( - step: float, max_steps: float, start_value: float, end_value: float + step: int, + max_steps: int, + start_value: float, + end_value: float, + period: Optional[int] = None, ) -> float: - """ - Use cosine decay to gradually modify start_value to reach target end_value during iterations. + """Use cosine decay to gradually modify start_value to reach target end_value during + iterations. Args: step: @@ -19,6 +24,9 @@ def cosine_schedule( Starting value. end_value: Target value. + period (optional): + The number of steps over which the cosine function completes a full cycle. + If not provided, it defaults to max_steps. Returns: Cosine decay value. @@ -28,13 +36,21 @@ def cosine_schedule( raise ValueError("Current step number can't be negative") if max_steps < 1: raise ValueError("Total step number must be >= 1") - if step > max_steps: + if period is None and step > max_steps: warnings.warn( f"Current step number {step} exceeds max_steps {max_steps}.", category=RuntimeWarning, ) + if period is not None and period <= 0: + raise ValueError("Period must be >= 1") - if max_steps == 1: + decay: float + if period is not None: # "cycle" based on period, if provided + decay = ( + end_value + - (end_value - start_value) * (np.cos(2 * np.pi * step / period) + 1) / 2 + ) + elif max_steps == 1: # Avoid division by zero decay = end_value elif step == max_steps: @@ -52,8 +68,7 @@ def cosine_schedule( class CosineWarmupScheduler(torch.optim.lr_scheduler.LambdaLR): - """ - Cosine warmup scheduler for learning rate. + """Cosine warmup scheduler for learning rate. Args: optimizer: @@ -70,22 +85,28 @@ class CosineWarmupScheduler(torch.optim.lr_scheduler.LambdaLR): Target learning rate scale. 
Default: 0.001 verbose: If True, prints a message to stdout for each update. Default: False. + + Note: The `epoch` arguments do not necessarily have to be epochs. Any step or index + can be used. The naming follows the Pytorch convention to use `epoch` for the steps + in the scheduler. """ def __init__( self, optimizer: torch.optim.Optimizer, - warmup_epochs: float, - max_epochs: float, - last_epoch: float = -1, + warmup_epochs: int, + max_epochs: int, + last_epoch: int = -1, start_value: float = 1.0, end_value: float = 0.001, + period: Optional[int] = None, verbose: bool = False, ) -> None: self.warmup_epochs = warmup_epochs self.max_epochs = max_epochs self.start_value = start_value self.end_value = end_value + self.period = period super().__init__( optimizer=optimizer, lr_lambda=self.scale_lr, @@ -106,7 +127,15 @@ def scale_lr(self, epoch: int) -> float: """ if epoch < self.warmup_epochs: - return (epoch + 1) / self.warmup_epochs + return self.start_value * (epoch + 1) / self.warmup_epochs + elif self.period is not None: + return cosine_schedule( + step=epoch - self.warmup_epochs, + max_steps=1, + start_value=self.start_value, + end_value=self.end_value, + period=self.period, + ) else: return cosine_schedule( step=epoch - self.warmup_epochs, diff --git a/lightly/utils/version_compare.py b/lightly/utils/version_compare.py index 898407cfb..05c7093f5 100644 --- a/lightly/utils/version_compare.py +++ b/lightly/utils/version_compare.py @@ -4,7 +4,7 @@ # All Rights Reserved -def version_compare(v0: str, v1: str): +def version_compare(v0: str, v1: str) -> int: """Returns 1 if version of v0 is larger than v1 and -1 otherwise Use this method to compare Python package versions and see which one is @@ -16,14 +16,14 @@ def version_compare(v0: str, v1: str): >>> version_compare('1.2.0', '1.1.2') >>> 1 """ - v0 = [int(n) for n in v0.split(".")][::-1] - v1 = [int(n) for n in v1.split(".")][::-1] - if len(v0) != 3 or len(v1) != 3: + v0_parsed = [int(n) for n in 
v0.split(".")][::-1] + v1_parsed = [int(n) for n in v1.split(".")][::-1] + if len(v0_parsed) != 3 or len(v1_parsed) != 3: raise ValueError( f"Length of version strings is not 3 (expected pattern `x.y.z`) but is " - f"{v0} and {v1}." + f"{v0_parsed} and {v1_parsed}." ) - pairs = list(zip(v0, v1))[::-1] + pairs = list(zip(v0_parsed, v1_parsed))[::-1] for x, y in pairs: if x < y: return -1 diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 000000000..a4c18e97f --- /dev/null +++ b/mypy.ini @@ -0,0 +1,219 @@ +# Global options: + +[mypy] +ignore_missing_imports = True +python_version = 3.10 +warn_unused_configs = True +strict_equality = True + +# Disallow dynamic typing +disallow_any_decorated = True +# TODO(Philipp, 09/23): Remove me! +# disallow_any_explicit = True +disallow_any_generics = True +disallow_subclassing_any = True + +# Disallow untyped definitions +disallow_untyped_defs = True +disallow_incomplete_defs = True +check_untyped_defs = True +disallow_untyped_decorators = True + +# None and optional handling +no_implicit_optional = True +strict_optional = True + +# Configuring warnings +warn_unused_ignores = True +warn_no_return = True +warn_return_any = True +warn_redundant_casts = True +warn_unreachable = True + +# Print format +show_error_codes = True +show_error_context = True + +# Plugins +plugins = numpy.typing.mypy_plugin + +# Excludes +# TODO(Philipp, 09/23): Remove these one by one (start with 300 files). 
+exclude = (?x)( + lightly/cli/version_cli.py | + lightly/cli/crop_cli.py | + lightly/cli/serve_cli.py | + lightly/cli/embed_cli.py | + lightly/cli/lightly_cli.py | + lightly/cli/download_cli.py | + lightly/cli/config/get_config.py | + lightly/cli/train_cli.py | + lightly/cli/_cli_simclr.py | + lightly/cli/_helpers.py | + lightly/loss/ntx_ent_loss.py | + lightly/loss/vicreg_loss.py | + lightly/loss/tico_loss.py | + lightly/loss/pmsn_loss.py | + lightly/loss/swav_loss.py | + lightly/loss/negative_cosine_similarity.py | + lightly/loss/hypersphere_loss.py | + lightly/loss/msn_loss.py | + lightly/loss/dino_loss.py | + lightly/loss/sym_neg_cos_sim_loss.py | + lightly/loss/vicregl_loss.py | + lightly/loss/dcl_loss.py | + lightly/loss/regularizer/co2.py | + lightly/loss/barlow_twins_loss.py | + lightly/data/lightly_subset.py | + lightly/data/dataset.py | + lightly/data/collate.py | + lightly/data/_image.py | + lightly/data/_helpers.py | + lightly/data/_image_loaders.py | + lightly/data/_video.py | + lightly/data/_utils.py | + lightly/data/multi_view_collate.py | + lightly/core.py | + lightly/api/api_workflow_compute_worker.py | + lightly/api/api_workflow_predictions.py | + lightly/api/download.py | + lightly/api/api_workflow_export.py | + lightly/api/api_workflow_download_dataset.py | + lightly/api/bitmask.py | + lightly/api/_version_checking.py | + lightly/api/serve.py | + lightly/api/patch.py | + lightly/api/swagger_api_client.py | + lightly/api/api_workflow_collaboration.py | + lightly/api/utils.py | + lightly/api/api_workflow_datasets.py | + lightly/api/api_workflow_selection.py | + lightly/api/swagger_rest_client.py | + lightly/api/api_workflow_datasources.py | + lightly/api/api_workflow_upload_embeddings.py | + lightly/api/api_workflow_client.py | + lightly/api/api_workflow_upload_metadata.py | + lightly/api/api_workflow_tags.py | + lightly/api/api_workflow_artifacts.py | + lightly/utils/cropping/crop_image_by_bounding_boxes.py | + 
lightly/utils/cropping/read_yolo_label_file.py | + lightly/utils/debug.py | + lightly/utils/dist.py | + lightly/utils/benchmarking/benchmark_module.py | + lightly/utils/benchmarking/knn_classifier.py | + lightly/utils/benchmarking/online_linear_classifier.py | + lightly/models/modules/masked_autoencoder.py | + lightly/models/modules/ijepa.py | + lightly/models/utils.py | + tests/cli/test_cli_version.py | + tests/cli/test_cli_magic.py | + tests/cli/test_cli_crop.py | + tests/cli/test_cli_download.py | + tests/cli/test_cli_train.py | + tests/cli/test_cli_get_lighty_config.py | + tests/cli/test_cli_embed.py | + tests/UNMOCKED_end2end_tests/delete_datasets_test_unmocked_cli.py | + tests/UNMOCKED_end2end_tests/create_custom_metadata_from_input_dir.py | + tests/UNMOCKED_end2end_tests/scripts_for_reproducing_problems/test_api_latency.py | + tests/loss/test_NegativeCosineSimilarity.py | + tests/loss/test_MSNLoss.py | + tests/loss/test_DINOLoss.py | + tests/loss/test_VICRegLLoss.py | + tests/loss/test_CO2Regularizer.py | + tests/loss/test_DCLLoss.py | + tests/loss/test_barlow_twins_loss.py | + tests/loss/test_SymNegCosineSimilarityLoss.py | + tests/loss/test_NTXentLoss.py | + tests/loss/test_MemoryBank.py | + tests/loss/test_TicoLoss.py | + tests/loss/test_VICRegLoss.py | + tests/loss/test_PMSNLoss.py | + tests/loss/test_HyperSphere.py | + tests/loss/test_SwaVLoss.py | + tests/core/test_Core.py | + tests/data/test_multi_view_collate.py | + tests/data/test_data_collate.py | + tests/data/test_VideoDataset.py | + tests/data/test_LightlySubset.py | + tests/data/test_LightlyDataset.py | + tests/embedding/test_callbacks.py | + tests/embedding/test_embedding.py | + tests/api/test_serve.py | + tests/api/test_swagger_rest_client.py | + tests/api/test_rest_parser.py | + tests/api/test_utils.py | + tests/api/benchmark_video_download.py | + tests/api/test_BitMask.py | + tests/api/test_patch.py | + tests/api/test_download.py | + tests/api/test_version_checking.py | + 
tests/api/test_swagger_api_client.py | + tests/utils/test_debug.py | + tests/utils/benchmarking/test_benchmark_module.py | + tests/utils/benchmarking/test_topk.py | + tests/utils/benchmarking/test_online_linear_classifier.py | + tests/utils/benchmarking/test_knn_classifier.py | + tests/utils/benchmarking/test_knn.py | + tests/utils/benchmarking/test_linear_classifier.py | + tests/utils/benchmarking/test_metric_callback.py | + tests/utils/test_dist.py | + tests/models/test_ModelsSimSiam.py | + tests/models/modules/test_masked_autoencoder.py | + tests/models/test_ModelsSimCLR.py | + tests/models/test_ModelUtils.py | + tests/models/test_ModelsNNCLR.py | + tests/models/test_ModelsMoCo.py | + tests/models/test_ProjectionHeads.py | + tests/models/test_ModelsBYOL.py | + tests/conftest.py | + tests/api_workflow/test_api_workflow_selection.py | + tests/api_workflow/test_api_workflow_datasets.py | + tests/api_workflow/mocked_api_workflow_client.py | + tests/api_workflow/test_api_workflow_compute_worker.py | + tests/api_workflow/test_api_workflow_artifacts.py | + tests/api_workflow/test_api_workflow_download_dataset.py | + tests/api_workflow/utils.py | + tests/api_workflow/test_api_workflow_client.py | + tests/api_workflow/test_api_workflow_export.py | + tests/api_workflow/test_api_workflow_datasources.py | + tests/api_workflow/test_api_workflow_tags.py | + tests/api_workflow/test_api_workflow_upload_custom_metadata.py | + tests/api_workflow/test_api_workflow_upload_embeddings.py | + tests/api_workflow/test_api_workflow_collaboration.py | + tests/api_workflow/test_api_workflow_predictions.py | + tests/api_workflow/test_api_workflow.py | + # Let's not type check deprecated active learning: + lightly/active_learning | + # Let's not type deprecated models: + lightly/models/simclr.py | + lightly/models/moco.py | + lightly/models/barlowtwins.py | + lightly/models/nnclr.py | + lightly/models/simsiam.py | + lightly/models/byol.py ) + +# Ignore imports from untyped modules. 
+[mypy-lightly.api.*] +follow_imports = skip + +[mypy-lightly.cli.*] +follow_imports = skip + +[mypy-lightly.data.*] +follow_imports = skip + +[mypy-lightly.loss.*] +follow_imports = skip + +[mypy-lightly.models.*] +follow_imports = skip + +[mypy-lightly.utils.benchmarking.*] +follow_imports = skip + +[mypy-tests.api_workflow.*] +follow_imports = skip + +# Ignore errors in auto generated code. +[mypy-lightly.openapi_generated.*] +ignore_errors = True \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 18cdd7197..904100d3d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,6 @@ +# pyproject.toml is currently only used to configure development tools. +# Configurations for the lightly package are in setup.py. + [tool.black] extend-exclude = "lightly/openapi_generated/.*" diff --git a/requirements/dev.txt b/requirements/dev.txt index dc74efe15..d5bc31f7b 100644 --- a/requirements/dev.txt +++ b/requirements/dev.txt @@ -22,3 +22,4 @@ torchmetrics lightning-bolts # for LARS optimizer black==23.1.0 # frozen version to avoid differences between CI and local dev machines isort==5.11.5 # frozen version to avoid differences between CI and local dev machines +mypy==1.4.1 # frozen version to avoid differences between CI and local dev machines diff --git a/requirements/openapi.txt b/requirements/openapi.txt index d52723c7a..74ede174a 100644 --- a/requirements/openapi.txt +++ b/requirements/openapi.txt @@ -1,4 +1,5 @@ python_dateutil >= 2.5.3 +setuptools >= 21.0.0 urllib3 >= 1.25.3 pydantic >= 1.10.5, < 2 aenum >= 3.1.11 diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index a684c9b94..000000000 --- a/setup.cfg +++ /dev/null @@ -1,3 +0,0 @@ -[metadata] -# This includes the license file(s) in the wheel. 
-license_files = LICENSE.txt \ No newline at end of file diff --git a/setup.py b/setup.py index 94d96ef1f..7caad8ce9 100644 --- a/setup.py +++ b/setup.py @@ -1,36 +1,19 @@ import os -import sys +from pathlib import Path +from typing import List import setuptools -try: - import builtins -except ImportError: - import __builtin__ as builtins +_PATH_ROOT = Path(os.path.dirname(__file__)) -PATH_ROOT = PATH_ROOT = os.path.dirname(__file__) -builtins.__LIGHTLY_SETUP__ = True -import lightly - - -def load_description(path_dir=PATH_ROOT, filename="DOCS.md"): - """Load long description from readme in the path_dir/ directory""" - with open(os.path.join(path_dir, filename)) as f: - long_description = f.read() - return long_description - - -def load_requirements(path_dir=PATH_ROOT, filename="base.txt", comment_char="#"): - """From pytorch-lightning repo: https://github.com/PyTorchLightning/pytorch-lightning. - Load requirements from text file in the path_dir/requirements/ directory. - - """ - with open(os.path.join(path_dir, "requirements", filename), "r") as file: +def load_requirements(filename: str, comment_char: str = "#") -> List[str]: + """Load requirements from text file in the requirements directory.""" + with (_PATH_ROOT / "requirements" / filename).open() as file: lines = [ln.strip() for ln in file.readlines()] reqs = [] for ln in lines: - # filer all comments + # filter all comments if comment_char in ln: ln = ln[: ln.index(comment_char)].strip() # skip directly installed dependencies @@ -41,28 +24,44 @@ def load_requirements(path_dir=PATH_ROOT, filename="base.txt", comment_char="#") return reqs +def load_version() -> str: + """Load version from the lightly/__init__.py file. + + Note: We do not want to get the version by accessing `lightly.__version__` because + it would require importing `lightly`. Importing `lightly` in setup.py breaks the + installation process as the import has side effects and requires dependencies to be + installed. 
As dependencies are not yet available during installation, the `lightly` + import fails. + """ + version_filepath = _PATH_ROOT / "lightly" / "__init__.py" + with version_filepath.open() as file: + for line in file.readlines(): + if line.startswith("__version__"): + version = line.split("=")[-1].strip().strip('"') + return version + raise RuntimeError("Unable to find version string in '{version_filepath}'.") + + if __name__ == "__main__": name = "lightly" - version = lightly.__version__ - description = lightly.__doc__ - - author = "Philipp Wirth & Igor Susmelj" - author_email = "philipp@lightly.ai" + version = load_version() + author = "Lightly Team" + author_email = "team@lightly.ai" description = "A deep learning package for self-supervised learning" + long_description = (_PATH_ROOT / "README.md").read_text() entry_points = { "console_scripts": [ "lightly-crop = lightly.cli.crop_cli:entry", - "lightly-train = lightly.cli.train_cli:entry", + "lightly-download = lightly.cli.download_cli:entry", "lightly-embed = lightly.cli.embed_cli:entry", "lightly-magic = lightly.cli.lightly_cli:entry", - "lightly-download = lightly.cli.download_cli:entry", + "lightly-serve = lightly.cli.serve_cli:entry", + "lightly-train = lightly.cli.train_cli:entry", "lightly-version = lightly.cli.version_cli:entry", ] } - long_description = load_description() - python_requires = ">=3.6" base_requires = load_requirements(filename="base.txt") openapi_requires = load_requirements(filename="openapi.txt") @@ -78,28 +77,7 @@ def load_requirements(path_dir=PATH_ROOT, filename="base.txt", comment_char="#") "all": dev_requires + video_requires, } - packages = [ - "lightly", - "lightly.api", - "lightly.cli", - "lightly.cli.config", - "lightly.data", - "lightly.embedding", - "lightly.loss", - "lightly.loss.regularizer", - "lightly.models", - "lightly.models.modules", - "lightly.transforms", - "lightly.utils", - "lightly.utils.benchmarking", - "lightly.utils.cropping", - "lightly.active_learning", - 
"lightly.active_learning.config", - "lightly.openapi_generated", - "lightly.openapi_generated.swagger_client", - "lightly.openapi_generated.swagger_client.api", - "lightly.openapi_generated.swagger_client.models", - ] + packages = setuptools.find_packages(include=["lightly*"]) project_urls = { "Homepage": "https://www.lightly.ai", @@ -115,8 +93,9 @@ def load_requirements(path_dir=PATH_ROOT, filename="base.txt", comment_char="#") "Intended Audience :: Education", "Intended Audience :: Science/Research", "Topic :: Scientific/Engineering", - "Topic :: Scientific/Engineering :: Mathematics", "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Scientific/Engineering :: Image Processing", + "Topic :: Scientific/Engineering :: Mathematics", "Topic :: Software Development", "Topic :: Software Development :: Libraries", "Topic :: Software Development :: Libraries :: Python Modules", @@ -125,6 +104,8 @@ def load_requirements(path_dir=PATH_ROOT, filename="base.txt", comment_char="#") "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", "License :: OSI Approved :: MIT License", ] @@ -136,6 +117,7 @@ def load_requirements(path_dir=PATH_ROOT, filename="base.txt", comment_char="#") description=description, entry_points=entry_points, license="MIT", + license_files=["LICENSE.txt"], long_description=long_description, long_description_content_type="text/markdown", setup_requires=setup_requires, diff --git a/tests/api/test_serve.py b/tests/api/test_serve.py new file mode 100644 index 000000000..1d056199d --- /dev/null +++ b/tests/api/test_serve.py @@ -0,0 +1,18 @@ +from pathlib import Path + +from lightly.api import serve + + +def test__translate_path(tmp_path: Path) -> None: + tmp_file = tmp_path / "hello/world.txt" + assert serve._translate_path(path="/hello/world.txt", directories=[]) == "" + assert 
serve._translate_path(path="/hello/world.txt", directories=[tmp_path]) == "" + tmp_file.mkdir(parents=True, exist_ok=True) + tmp_file.touch() + assert serve._translate_path( + path="/hello/world.txt", directories=[tmp_path] + ) == str(tmp_file) + assert serve._translate_path( + path="/world.txt", + directories=[tmp_path / "hi", tmp_path / "hello"], + ) == str(tmp_file) diff --git a/tests/api/test_version_checking.py b/tests/api/test_version_checking.py index df885c785..5a72d8344 100644 --- a/tests/api/test_version_checking.py +++ b/tests/api/test_version_checking.py @@ -1,75 +1,87 @@ -import sys +import os import time -import unittest - -import lightly -from lightly.api.version_checking import ( - LightlyAPITimeoutException, - get_latest_version, - get_minimum_compatible_version, - is_compatible_version, - is_latest_version, - pretty_print_latest_version, -) -from tests.api_workflow.mocked_api_workflow_client import MockedVersioningApi - - -class TestVersionChecking(unittest.TestCase): - def setUp(self) -> None: - lightly.api.version_checking.VersioningApi = MockedVersioningApi - - def test_get_latest_version(self): - get_latest_version("1.2.3") - - def test_get_minimum_compatible_version(self): - get_minimum_compatible_version() - - def test_is_latest_version(self) -> None: - assert is_latest_version("1.2.8") - assert not is_latest_version("1.2.7") - assert not is_latest_version("1.1.8") - assert not is_latest_version("0.2.8") - - def test_is_compatible_version(self) -> None: - assert is_compatible_version("1.2.1") - assert not is_compatible_version("1.2.0") - assert not is_compatible_version("1.1.9") - assert not is_compatible_version("0.2.1") - - def test_pretty_print(self): - pretty_print_latest_version(current_version="curr", latest_version="1.1.1") - - def test_version_check_timout_mocked(self): - """ - We cannot check for other errors as we don't know whether the - current LIGHTLY_SERVER_URL is - - unreachable (error in < 1 second) - - causing a timeout and 
thus raising a LightlyAPITimeoutException - - reachable (success in < 1 second - - Thus this only checks that the actual lightly.do_version_check() - with needing >1s internally causes a LightlyAPITimeoutException - """ - try: - old_get_versioning_api = lightly.api.version_checking.get_versioning_api - - def mocked_get_versioning_api_timeout(): - time.sleep(10) - print("This line should never be reached, calling sys.exit()") - sys.exit() - - lightly.api.version_checking.get_versioning_api = ( - mocked_get_versioning_api_timeout - ) - - start_time = time.time() - - with self.assertRaises(LightlyAPITimeoutException): - is_latest_version(lightly.__version__) - - duration = time.time() - start_time - - self.assertLess(duration, 1.5) - - finally: - lightly.api.version_checking.get_versioning_api = old_get_versioning_api + +import pytest +from pytest_mock import MockerFixture +from urllib3.exceptions import MaxRetryError + +from lightly.api import _version_checking +from lightly.openapi_generated.swagger_client.api import VersioningApi + + +# Overwrite the mock_versioning_api fixture from conftest.py that is applied by default +# for all tests as we want to test the functionality of the versioning api. 
+@pytest.fixture(autouse=True) +def mock_versioning_api(): + return + + +@pytest.mark.disable_mock_versioning_api +def test_is_latest_version(mocker: MockerFixture) -> None: + mocker.patch.object( + _version_checking.VersioningApi, "get_latest_pip_version", return_value="1.2.8" + ) + assert _version_checking.is_latest_version("1.2.8") + assert not _version_checking.is_latest_version("1.2.7") + assert not _version_checking.is_latest_version("1.1.8") + assert not _version_checking.is_latest_version("0.2.8") + + +def test_is_compatible_version(mocker: MockerFixture) -> None: + mocker.patch.object( + _version_checking.VersioningApi, + "get_minimum_compatible_pip_version", + return_value="1.2.8", + ) + assert _version_checking.is_compatible_version("1.2.8") + assert not _version_checking.is_compatible_version("1.2.7") + assert not _version_checking.is_compatible_version("1.1.8") + assert not _version_checking.is_compatible_version("0.2.8") + + +def test_get_latest_version(mocker: MockerFixture) -> None: + mocker.patch.object( + _version_checking.VersioningApi, "get_latest_pip_version", return_value="1.2.8" + ) + assert _version_checking.get_latest_version("1.2.8") == "1.2.8" + + +def test_get_latest_version__timeout(mocker: MockerFixture) -> None: + mocker.patch.dict(os.environ, {"LIGHTLY_SERVER_LOCATION": "invalid-url"}) + start = time.perf_counter() + with pytest.raises(MaxRetryError): + # Urllib3 raises a timeout error (connection refused) for invalid URLs. + _version_checking.get_latest_version("1.2.8", timeout_sec=0.1) + end = time.perf_counter() + assert end - start < 0.2 # Give some slack for timeout. 
+ + +def test_get_minimum_compatible_version(mocker: MockerFixture) -> None: + mocker.patch.object( + _version_checking.VersioningApi, + "get_minimum_compatible_pip_version", + return_value="1.2.8", + ) + + assert _version_checking.get_minimum_compatible_version() == "1.2.8" + + +def test_get_minimum_compatible_version__timeout(mocker: MockerFixture) -> None: + mocker.patch.dict(os.environ, {"LIGHTLY_SERVER_LOCATION": "invalid-url"}) + start = time.perf_counter() + with pytest.raises(MaxRetryError): + # Urllib3 raises a timeout error (connection refused) for invalid URLs. + _version_checking.get_minimum_compatible_version(timeout_sec=0.1) + end = time.perf_counter() + assert end - start < 0.2 # Give some slack for timeout. + + +def test_check_is_latest_version_in_background(mocker: MockerFixture) -> None: + spy_is_latest_version = mocker.spy(_version_checking, "is_latest_version") + _version_checking.check_is_latest_version_in_background("1.2.8") + time.sleep(0.1) # Wait for thread to run. 
+ spy_is_latest_version.assert_called_once_with(current_version="1.2.8") + + +def test__get_versioning_api() -> None: + assert isinstance(_version_checking._get_versioning_api(), VersioningApi) diff --git a/tests/api_workflow/mocked_api_workflow_client.py b/tests/api_workflow/mocked_api_workflow_client.py index f61d9d75e..700f1522a 100644 --- a/tests/api_workflow/mocked_api_workflow_client.py +++ b/tests/api_workflow/mocked_api_workflow_client.py @@ -36,7 +36,9 @@ DatasetData, DatasetEmbeddingData, DatasourceConfig, + DatasourceConfigAzure, DatasourceConfigBase, + DatasourceConfigLOCAL, DatasourceProcessedUntilTimestampRequest, DatasourceProcessedUntilTimestampResponse, DatasourceRawSamplesData, @@ -657,11 +659,14 @@ def __init__(self, api_client=None): self.reset() def reset(self): - local_datasource = DatasourceConfigBase( - type="LOCAL", full_path="", purpose="INPUT_OUTPUT" + local_datasource = DatasourceConfigLOCAL( + type="LOCAL", + full_path="", + web_server_location="https://localhost:1234", + purpose="INPUT_OUTPUT", ).to_dict() azure_datasource = DatasourceConfigBase( - type="AZURE", full_path="", purpose="INPUT_OUTPUT" + type="AZURE", purpose="INPUT_OUTPUT" ).to_dict() self._datasources = { @@ -982,14 +987,6 @@ def update_scheduled_docker_run_state_by_id( raise NotImplementedError() -class MockedVersioningApi(VersioningApi): - def get_latest_pip_version(self, **kwargs): - return "1.2.8" - - def get_minimum_compatible_pip_version(self, **kwargs): - return "1.2.1" - - class MockedQuotaApi(QuotaApi): def get_quota_maximum_dataset_size(self, **kwargs): return "60000" @@ -1032,7 +1029,6 @@ class MockedApiWorkflowClient(ApiWorkflowClient): n_embedding_rows_on_server = N_FILES_ON_SERVER def __init__(self, *args, **kwargs): - lightly.api.version_checking.VersioningApi = MockedVersioningApi ApiWorkflowClient.__init__(self, *args, **kwargs) self._selection_api = MockedSamplingsApi(api_client=self.api_client) diff --git 
a/tests/api_workflow/test_api_workflow_client.py b/tests/api_workflow/test_api_workflow_client.py index 48920077e..7ed256d19 100644 --- a/tests/api_workflow/test_api_workflow_client.py +++ b/tests/api_workflow/test_api_workflow_client.py @@ -85,11 +85,6 @@ def raise_connection_error(*args, **kwargs): def test_user_agent_header(mocker: MockerFixture) -> None: mocker.patch.object(lightly.api.api_workflow_client, "__version__", new="VERSION") - mocker.patch.object( - lightly.api.api_workflow_client.version_checking, - "is_compatible_version", - new=lambda _: True, - ) mocked_platform = mocker.patch.object( lightly.api.api_workflow_client, "platform", spec_set=platform ) diff --git a/tests/api_workflow/test_api_workflow_compute_worker.py b/tests/api_workflow/test_api_workflow_compute_worker.py index 09efdd8a2..b0eb80625 100644 --- a/tests/api_workflow/test_api_workflow_compute_worker.py +++ b/tests/api_workflow/test_api_workflow_compute_worker.py @@ -33,14 +33,14 @@ DockerWorkerConfigV3LightlyLoader, DockerWorkerState, DockerWorkerType, - SelectionConfig, - SelectionConfigEntry, - SelectionConfigEntryInput, - SelectionConfigEntryStrategy, + SelectionConfigV3, + SelectionConfigV3Entry, + SelectionConfigV3EntryInput, + SelectionConfigV3EntryStrategy, SelectionInputPredictionsName, SelectionInputType, SelectionStrategyThresholdOperation, - SelectionStrategyType, + SelectionStrategyTypeV3, TagData, ) from lightly.openapi_generated.swagger_client.rest import ApiException @@ -101,17 +101,17 @@ def test_create_compute_worker_config__selection_config_is_class(self) -> None: "batch_size": 64, }, }, - selection_config=SelectionConfig( + selection_config=SelectionConfigV3( n_samples=20, strategies=[ - SelectionConfigEntry( - input=SelectionConfigEntryInput( + SelectionConfigV3Entry( + input=SelectionConfigV3EntryInput( type=SelectionInputType.EMBEDDINGS, dataset_id=utils.generate_id(), tag_name="some-tag-name", ), - strategy=SelectionConfigEntryStrategy( - 
type=SelectionStrategyType.SIMILARITY, + strategy=SelectionConfigV3EntryStrategy( + type=SelectionStrategyTypeV3.SIMILARITY, ), ) ], @@ -203,44 +203,46 @@ def _check_if_openapi_generated_obj_is_valid(self, obj) -> Any: return obj_api def test_selection_config(self): - selection_config = SelectionConfig( + selection_config = SelectionConfigV3( n_samples=1, strategies=[ - SelectionConfigEntry( - input=SelectionConfigEntryInput(type=SelectionInputType.EMBEDDINGS), - strategy=SelectionConfigEntryStrategy( - type=SelectionStrategyType.DIVERSITY, + SelectionConfigV3Entry( + input=SelectionConfigV3EntryInput( + type=SelectionInputType.EMBEDDINGS + ), + strategy=SelectionConfigV3EntryStrategy( + type=SelectionStrategyTypeV3.DIVERSITY, stopping_condition_minimum_distance=-1, ), ), - SelectionConfigEntry( - input=SelectionConfigEntryInput( + SelectionConfigV3Entry( + input=SelectionConfigV3EntryInput( type=SelectionInputType.SCORES, task="my-classification-task", score="uncertainty_margin", ), - strategy=SelectionConfigEntryStrategy( - type=SelectionStrategyType.WEIGHTS + strategy=SelectionConfigV3EntryStrategy( + type=SelectionStrategyTypeV3.WEIGHTS ), ), - SelectionConfigEntry( - input=SelectionConfigEntryInput( + SelectionConfigV3Entry( + input=SelectionConfigV3EntryInput( type=SelectionInputType.METADATA, key="lightly.sharpness" ), - strategy=SelectionConfigEntryStrategy( - type=SelectionStrategyType.THRESHOLD, + strategy=SelectionConfigV3EntryStrategy( + type=SelectionStrategyTypeV3.THRESHOLD, threshold=20, operation=SelectionStrategyThresholdOperation.BIGGER_EQUAL, ), ), - SelectionConfigEntry( - input=SelectionConfigEntryInput( + SelectionConfigV3Entry( + input=SelectionConfigV3EntryInput( type=SelectionInputType.PREDICTIONS, task="my_object_detection_task", name=SelectionInputPredictionsName.CLASS_DISTRIBUTION, ), - strategy=SelectionConfigEntryStrategy( - type=SelectionStrategyType.BALANCE, + strategy=SelectionConfigV3EntryStrategy( + 
type=SelectionStrategyTypeV3.BALANCE, target={"Ambulance": 0.2, "Bus": 0.4}, ), ), diff --git a/tests/api_workflow/test_api_workflow_datasources.py b/tests/api_workflow/test_api_workflow_datasources.py index b83910975..cbaa46bb7 100644 --- a/tests/api_workflow/test_api_workflow_datasources.py +++ b/tests/api_workflow/test_api_workflow_datasources.py @@ -1,13 +1,15 @@ import pytest +import tqdm from pytest_mock import MockerFixture -from lightly.api import ApiWorkflowClient +from lightly.api import ApiWorkflowClient, api_workflow_datasources from lightly.openapi_generated.swagger_client.models import ( DatasourceConfigAzure, DatasourceConfigGCS, DatasourceConfigLOCAL, DatasourceConfigS3, DatasourceConfigS3DelegatedAccess, + DatasourcePurpose, DatasourceRawSamplesDataRow, ) from lightly.openapi_generated.swagger_client.models.datasource_config_verify_data import ( @@ -16,327 +18,627 @@ from lightly.openapi_generated.swagger_client.models.datasource_config_verify_data_errors import ( DatasourceConfigVerifyDataErrors, ) +from lightly.openapi_generated.swagger_client.models.datasource_processed_until_timestamp_response import ( + DatasourceProcessedUntilTimestampResponse, +) +from lightly.openapi_generated.swagger_client.models.datasource_raw_samples_data import ( + DatasourceRawSamplesData, +) -def test__download_raw_files(mocker: MockerFixture) -> None: - mock_response_1 = mocker.MagicMock() - mock_response_1.has_more = True - mock_response_1.data = [ - DatasourceRawSamplesDataRow(file_name="/file1", read_url="url1"), - DatasourceRawSamplesDataRow(file_name="file2", read_url="url2"), - ] - - mock_response_2 = mocker.MagicMock() - mock_response_2.has_more = False - mock_response_2.data = [ - DatasourceRawSamplesDataRow(file_name="./file3", read_url="url3"), - DatasourceRawSamplesDataRow(file_name="file2", read_url="url2"), - ] - - mocked_method = mocker.MagicMock(side_effect=[mock_response_1, mock_response_2]) - mocked_pbar = mocker.MagicMock() - 
mocker.patch.object(ApiWorkflowClient, "__init__", return_value=None) - mocked_warning = mocker.patch("warnings.warn") - client = ApiWorkflowClient() - client._dataset_id = "dataset-id" - result = client._download_raw_files( - download_function=mocked_method, - progress_bar=mocked_pbar, - ) - kwargs = mocked_method.call_args[1] - assert "relevant_filenames_file_name" not in kwargs - assert mocked_pbar.update.call_count == 2 - assert mocked_warning.call_count == 3 - warning_text = [str(call_args[0][0]) for call_args in mocked_warning.call_args_list] - assert warning_text == [ - ( - "Absolute file paths like /file1 are not supported" - " in relevant filenames file None due to blob storage" - ), - ( - "Using dot notation ('./', '../') like in ./file3 is not supported" - " in relevant filenames file None due to blob storage" - ), - ("Duplicate filename file2 in relevant filenames file None"), - ] - assert len(result) == 1 - assert result[0][0] == "file2" - - -def test_get_prediction_read_url(mocker: MockerFixture) -> None: - mocker.patch.object(ApiWorkflowClient, "__init__", return_value=None) - mocked_api = mocker.MagicMock() - client = ApiWorkflowClient() - client._dataset_id = "dataset-id" - client._datasources_api = mocked_api - client.get_prediction_read_url("test.json") - mocked_method = ( - mocked_api.get_prediction_file_read_url_from_datasource_by_dataset_id - ) - mocked_method.assert_called_once_with( - dataset_id="dataset-id", file_name="test.json" - ) +class TestDatasourcesMixin: + def test_download_raw_samples(self, mocker: MockerFixture) -> None: + response = DatasourceRawSamplesData( + hasMore=False, + cursor="", + data=[ + DatasourceRawSamplesDataRow(fileName="file1", readUrl="url1"), + DatasourceRawSamplesDataRow(fileName="file2", readUrl="url2"), + ], + ) + client = ApiWorkflowClient(token="abc", dataset_id="dataset-id") + mocker.patch.object( + client._datasources_api, + "get_list_of_raw_samples_from_datasource_by_dataset_id", + 
side_effect=[response], + ) + assert client.download_raw_samples() == [("file1", "url1"), ("file2", "url2")] + + def test_download_raw_predictions(self, mocker: MockerFixture) -> None: + response = DatasourceRawSamplesData( + hasMore=False, + cursor="", + data=[ + DatasourceRawSamplesDataRow(fileName="file1", readUrl="url1"), + DatasourceRawSamplesDataRow(fileName="file2", readUrl="url2"), + ], + ) + client = ApiWorkflowClient(token="abc", dataset_id="dataset-id") + mocker.patch.object( + client._datasources_api, + "get_list_of_raw_samples_predictions_from_datasource_by_dataset_id", + side_effect=[response], + ) + assert client.download_raw_predictions(task_name="task") == [ + ("file1", "url1"), + ("file2", "url2"), + ] + + def test_download_raw_predictions_iter(self, mocker: MockerFixture) -> None: + response_1 = DatasourceRawSamplesData( + hasMore=True, + cursor="cursor1", + data=[ + DatasourceRawSamplesDataRow(fileName="file1", readUrl="url1"), + DatasourceRawSamplesDataRow(fileName="file2", readUrl="url2"), + ], + ) + response_2 = DatasourceRawSamplesData( + hasMore=False, + cursor="cursor2", + data=[ + DatasourceRawSamplesDataRow(fileName="file3", readUrl="url3"), + DatasourceRawSamplesDataRow(fileName="file4", readUrl="url4"), + ], + ) + client = ApiWorkflowClient(token="abc", dataset_id="dataset-id") + mocker.patch.object( + client._datasources_api, + "get_list_of_raw_samples_predictions_from_datasource_by_dataset_id", + side_effect=[response_1, response_2], + ) + assert list(client.download_raw_predictions_iter(task_name="task")) == [ + ("file1", "url1"), + ("file2", "url2"), + ("file3", "url3"), + ("file4", "url4"), + ] + client._datasources_api.get_list_of_raw_samples_predictions_from_datasource_by_dataset_id.assert_has_calls( + [ + mocker.call( + dataset_id="dataset-id", + task_name="task", + var_from=0, + to=mocker.ANY, + use_redirected_read_url=False, + ), + mocker.call( + dataset_id="dataset-id", + task_name="task", + cursor="cursor1", + 
use_redirected_read_url=False, + ), + ] + ) + def test_download_raw_predictions_iter__relevant_filenames_artifact_id( + self, + mocker: MockerFixture, + ) -> None: + response = DatasourceRawSamplesData( + hasMore=False, + cursor="", + data=[ + DatasourceRawSamplesDataRow(fileName="file1", readUrl="url1"), + DatasourceRawSamplesDataRow(fileName="file2", readUrl="url2"), + ], + ) + client = ApiWorkflowClient(token="abc", dataset_id="dataset-id") + mocker.patch.object( + client._datasources_api, + "get_list_of_raw_samples_predictions_from_datasource_by_dataset_id", + side_effect=[response], + ) + assert list( + client.download_raw_predictions_iter( + task_name="task", + run_id="run-id", + relevant_filenames_artifact_id="relevant-filenames", + ) + ) == [ + ("file1", "url1"), + ("file2", "url2"), + ] + client._datasources_api.get_list_of_raw_samples_predictions_from_datasource_by_dataset_id.assert_called_once_with( + dataset_id="dataset-id", + task_name="task", + var_from=0, + to=mocker.ANY, + relevant_filenames_run_id="run-id", + relevant_filenames_artifact_id="relevant-filenames", + use_redirected_read_url=False, + ) -def test_download_new_raw_samples(mocker: MockerFixture) -> None: - from_timestamp = 2 - mocker.patch.object(ApiWorkflowClient, "__init__", return_value=None) - mocker.patch.object( - ApiWorkflowClient, "get_processed_until_timestamp", return_value=from_timestamp - ) - current_time = 5 - mocker.patch("time.time", return_value=current_time) - mocked_download = mocker.patch.object(ApiWorkflowClient, "download_raw_samples") - mocked_update_timestamp = mocker.patch.object( - ApiWorkflowClient, "update_processed_until_timestamp" - ) - client = ApiWorkflowClient() - client.download_new_raw_samples() - mocked_download.assert_called_once_with( - from_=from_timestamp + 1, - to=current_time, - relevant_filenames_file_name=None, - use_redirected_read_url=False, - ) - mocked_update_timestamp.assert_called_once_with(timestamp=current_time) + # should raise ValueError 
when only run_id is given + with pytest.raises(ValueError): + next( + client.download_raw_predictions_iter(task_name="task", run_id="run-id") + ) + + # should raise ValueError when only relevant_filenames_artifact_id is given + with pytest.raises(ValueError): + next( + client.download_raw_predictions_iter( + task_name="task", + relevant_filenames_artifact_id="relevant-filenames", + ) + ) + + def test_download_raw_metadata(self, mocker: MockerFixture) -> None: + response = DatasourceRawSamplesData( + hasMore=False, + cursor="", + data=[ + DatasourceRawSamplesDataRow(fileName="file1", readUrl="url1"), + DatasourceRawSamplesDataRow(fileName="file2", readUrl="url2"), + ], + ) + client = ApiWorkflowClient(token="abc", dataset_id="dataset-id") + mocker.patch.object( + client._datasources_api, + "get_list_of_raw_samples_metadata_from_datasource_by_dataset_id", + side_effect=[response], + ) + assert client.download_raw_metadata() == [ + ("file1", "url1"), + ("file2", "url2"), + ] + + def test_download_raw_metadata_iter(self, mocker: MockerFixture) -> None: + response_1 = DatasourceRawSamplesData( + hasMore=True, + cursor="cursor1", + data=[ + DatasourceRawSamplesDataRow(fileName="file1", readUrl="url1"), + DatasourceRawSamplesDataRow(fileName="file2", readUrl="url2"), + ], + ) + response_2 = DatasourceRawSamplesData( + hasMore=False, + cursor="cursor2", + data=[ + DatasourceRawSamplesDataRow(fileName="file3", readUrl="url3"), + DatasourceRawSamplesDataRow(fileName="file4", readUrl="url4"), + ], + ) + client = ApiWorkflowClient(token="abc", dataset_id="dataset-id") + mocker.patch.object( + client._datasources_api, + "get_list_of_raw_samples_metadata_from_datasource_by_dataset_id", + side_effect=[response_1, response_2], + ) + assert list(client.download_raw_metadata_iter()) == [ + ("file1", "url1"), + ("file2", "url2"), + ("file3", "url3"), + ("file4", "url4"), + ] + client._datasources_api.get_list_of_raw_samples_metadata_from_datasource_by_dataset_id.assert_has_calls( + [ 
+ mocker.call( + dataset_id="dataset-id", + var_from=0, + to=mocker.ANY, + use_redirected_read_url=False, + ), + mocker.call( + dataset_id="dataset-id", + cursor="cursor1", + use_redirected_read_url=False, + ), + ] + ) + def test_download_raw_metadata_iter__relevant_filenames_artifact_id( + self, mocker: MockerFixture + ) -> None: + response = DatasourceRawSamplesData( + hasMore=False, + cursor="", + data=[ + DatasourceRawSamplesDataRow(fileName="file1", readUrl="url1"), + DatasourceRawSamplesDataRow(fileName="file2", readUrl="url2"), + ], + ) + client = ApiWorkflowClient(token="abc", dataset_id="dataset-id") + mocker.patch.object( + client._datasources_api, + "get_list_of_raw_samples_metadata_from_datasource_by_dataset_id", + side_effect=[response], + ) + assert list( + client.download_raw_metadata_iter( + run_id="run-id", + relevant_filenames_artifact_id="relevant-filenames", + ) + ) == [ + ("file1", "url1"), + ("file2", "url2"), + ] + client._datasources_api.get_list_of_raw_samples_metadata_from_datasource_by_dataset_id.assert_called_once_with( + dataset_id="dataset-id", + var_from=0, + to=mocker.ANY, + relevant_filenames_run_id="run-id", + relevant_filenames_artifact_id="relevant-filenames", + use_redirected_read_url=False, + ) -def test_download_new_raw_samples__from_beginning(mocker: MockerFixture) -> None: - mocker.patch.object(ApiWorkflowClient, "__init__", return_value=None) - mocker.patch.object( - ApiWorkflowClient, "get_processed_until_timestamp", return_value=0 - ) - current_time = 5 - mocker.patch("time.time", return_value=current_time) - mocked_download = mocker.patch.object(ApiWorkflowClient, "download_raw_samples") - mocked_update_timestamp = mocker.patch.object( - ApiWorkflowClient, "update_processed_until_timestamp" - ) - client = ApiWorkflowClient() - client.download_new_raw_samples() - mocked_download.assert_called_once_with( - from_=0, - to=current_time, - relevant_filenames_file_name=None, - use_redirected_read_url=False, - ) - 
mocked_update_timestamp.assert_called_once_with(timestamp=current_time) - - -def test_download_raw_samples_predictions__relevant_filenames_artifact_id( - mocker: MockerFixture, -) -> None: - mock_response = mocker.MagicMock() - mock_response.has_more = False - mocker.patch.object(ApiWorkflowClient, "__init__", return_value=None) - mocked_api = mocker.MagicMock() - mocked_method = mocker.MagicMock(return_value=mock_response) - mocked_api.get_list_of_raw_samples_predictions_from_datasource_by_dataset_id = ( - mocked_method - ) - client = ApiWorkflowClient() - client._dataset_id = "dataset-id" - client._datasources_api = mocked_api - client.download_raw_predictions( - task_name="task", run_id="foo", relevant_filenames_artifact_id="bar" - ) - kwargs = mocked_method.call_args[1] - assert kwargs.get("relevant_filenames_run_id") == "foo" - assert kwargs.get("relevant_filenames_artifact_id") == "bar" - - # should raise ValueError when only run_id is given - with pytest.raises(ValueError): - client.download_raw_predictions(task_name="foobar", run_id="foo") - # should raise ValueError when only relevant_filenames_artifact_id is given - with pytest.raises(ValueError): - client.download_raw_predictions( - task_name="foobar", relevant_filenames_artifact_id="bar" - ) - - -def test_download_raw_samples_metadata__relevant_filenames_artifact_id( - mocker: MockerFixture, -) -> None: - mock_response = mocker.MagicMock() - mock_response.has_more = False - mocker.patch.object(ApiWorkflowClient, "__init__", return_value=None) - mocked_api = mocker.MagicMock() - mocked_method = mocker.MagicMock(return_value=mock_response) - mocked_api.get_list_of_raw_samples_metadata_from_datasource_by_dataset_id = ( - mocked_method - ) - client = ApiWorkflowClient() - client._dataset_id = "dataset-id" - client._datasources_api = mocked_api - client.download_raw_metadata(run_id="foo", relevant_filenames_artifact_id="bar") - kwargs = mocked_method.call_args[1] - assert 
kwargs.get("relevant_filenames_run_id") == "foo" - assert kwargs.get("relevant_filenames_artifact_id") == "bar" - - # should raise ValueError when only run_id is given - with pytest.raises(ValueError): - client.download_raw_metadata(run_id="foo") - # should raise ValueError when only relevant_filenames_artifact_id is given - with pytest.raises(ValueError): - client.download_raw_metadata(relevant_filenames_artifact_id="bar") - - -def test_get_processed_until_timestamp(mocker: MockerFixture) -> None: - mocker.patch.object(ApiWorkflowClient, "__init__", return_value=None) - mocked_datasources_api = mocker.MagicMock() - client = ApiWorkflowClient() - client._dataset_id = "dataset-id" - client._datasources_api = mocked_datasources_api - client.get_processed_until_timestamp() - mocked_method = ( - mocked_datasources_api.get_datasource_processed_until_timestamp_by_dataset_id - ) - mocked_method.assert_called_once_with(dataset_id="dataset-id") - - -def test_set_azure_config(mocker: MockerFixture) -> None: - mocker.patch.object(ApiWorkflowClient, "__init__", return_value=None) - mocked_datasources_api = mocker.MagicMock() - client = ApiWorkflowClient() - client._datasources_api = mocked_datasources_api - client._dataset_id = "dataset-id" - client.set_azure_config( - container_name="my-container/name", - account_name="my-account-name", - sas_token="my-sas-token", - thumbnail_suffix=".lightly/thumbnails/[filename]-thumb-[extension]", - ) - kwargs = mocked_datasources_api.update_datasource_by_dataset_id.call_args[1] - assert isinstance( - kwargs["datasource_config"].actual_instance, DatasourceConfigAzure - ) + # should raise ValueError when only run_id is given + with pytest.raises(ValueError): + next(client.download_raw_metadata_iter(run_id="run-id")) + + # should raise ValueError when only relevant_filenames_artifact_id is given + with pytest.raises(ValueError): + next( + client.download_raw_metadata_iter( + relevant_filenames_artifact_id="relevant-filenames", + ) + ) + + 
def test_download_new_raw_samples(self, mocker: MockerFixture) -> None: + client = ApiWorkflowClient(token="abc", dataset_id="dataset-id") + client.get_processed_until_timestamp = mocker.MagicMock(return_value=2) + mocker.patch("time.time", return_value=5) + mocker.patch.object(client, "download_raw_samples") + mocker.patch.object(client, "update_processed_until_timestamp") + client.download_new_raw_samples() + client.download_raw_samples.assert_called_once_with( + from_=2 + 1, + to=5, + relevant_filenames_file_name=None, + use_redirected_read_url=False, + ) + client.update_processed_until_timestamp.assert_called_once_with(timestamp=5) + + def test_download_new_raw_samples__from_beginning( + self, mocker: MockerFixture + ) -> None: + client = ApiWorkflowClient(token="abc", dataset_id="dataset-id") + client.get_processed_until_timestamp = mocker.MagicMock(return_value=2) + mocker.patch("time.time", return_value=5) + mocker.patch.object(client, "download_raw_samples") + mocker.patch.object(client, "update_processed_until_timestamp") + client.download_new_raw_samples() + client.download_raw_samples.assert_called_once_with( + from_=3, + to=5, + relevant_filenames_file_name=None, + use_redirected_read_url=False, + ) + client.update_processed_until_timestamp.assert_called_once_with(timestamp=5) + + def test_get_processed_until_timestamp(self, mocker: MockerFixture) -> None: + client = ApiWorkflowClient(token="abc", dataset_id="dataset-id") + mocker.patch.object( + client._datasources_api, + "get_datasource_processed_until_timestamp_by_dataset_id", + return_value=DatasourceProcessedUntilTimestampResponse( + processedUntilTimestamp=5 + ), + ) + assert client.get_processed_until_timestamp() == 5 + client._datasources_api.get_datasource_processed_until_timestamp_by_dataset_id.assert_called_once_with( + dataset_id="dataset-id" + ) + def test_update_processed_until_timestamp(self, mocker: MockerFixture) -> None: + client = ApiWorkflowClient(token="abc", 
dataset_id="dataset-id") + mocker.patch.object( + client._datasources_api, + "update_datasource_processed_until_timestamp_by_dataset_id", + ) + client.update_processed_until_timestamp(timestamp=10) + kwargs = client._datasources_api.update_datasource_processed_until_timestamp_by_dataset_id.call_args[ + 1 + ] + assert kwargs["dataset_id"] == "dataset-id" + assert ( + kwargs[ + "datasource_processed_until_timestamp_request" + ].processed_until_timestamp + == 10 + ) -def test_set_gcs_config(mocker: MockerFixture) -> None: - mocker.patch.object(ApiWorkflowClient, "__init__", return_value=None) - mocked_datasources_api = mocker.MagicMock() - client = ApiWorkflowClient() - client._datasources_api = mocked_datasources_api - client._dataset_id = "dataset-id" - client.set_gcs_config( - resource_path="gs://my-bucket/my-dataset", - project_id="my-project-id", - credentials="my-credentials", - thumbnail_suffix=".lightly/thumbnails/[filename]-thumb-[extension]", - ) - kwargs = mocked_datasources_api.update_datasource_by_dataset_id.call_args[1] - assert isinstance(kwargs["datasource_config"].actual_instance, DatasourceConfigGCS) - - -def test_set_local_config(mocker: MockerFixture) -> None: - mocker.patch.object(ApiWorkflowClient, "__init__", return_value=None) - mocked_datasources_api = mocker.MagicMock() - client = ApiWorkflowClient() - client._datasources_api = mocked_datasources_api - client._dataset_id = "dataset-id" - client.set_local_config( - resource_path="http://localhost:1234/path/to/my/data", - thumbnail_suffix=".lightly/thumbnails/[filename]-thumb-[extension]", - ) - kwargs = mocked_datasources_api.update_datasource_by_dataset_id.call_args[1] - assert isinstance( - kwargs["datasource_config"].actual_instance, DatasourceConfigLOCAL - ) + def test_set_azure_config(self, mocker: MockerFixture) -> None: + client = ApiWorkflowClient(token="abc", dataset_id="dataset-id") + mocker.patch.object( + client._datasources_api, + "update_datasource_by_dataset_id", + ) + 
client.set_azure_config( + container_name="my-container/name", + account_name="my-account-name", + sas_token="my-sas-token", + thumbnail_suffix=".lightly/thumbnails/[filename]-thumb-[extension]", + ) + kwargs = client._datasources_api.update_datasource_by_dataset_id.call_args[1] + assert isinstance( + kwargs["datasource_config"].actual_instance, DatasourceConfigAzure + ) + def test_set_gcs_config(self, mocker: MockerFixture) -> None: + client = ApiWorkflowClient(token="abc", dataset_id="dataset-id") + mocker.patch.object( + client._datasources_api, + "update_datasource_by_dataset_id", + ) + client.set_gcs_config( + resource_path="gs://my-bucket/my-dataset", + project_id="my-project-id", + credentials="my-credentials", + thumbnail_suffix=".lightly/thumbnails/[filename]-thumb-[extension]", + ) + kwargs = client._datasources_api.update_datasource_by_dataset_id.call_args[1] + assert isinstance( + kwargs["datasource_config"].actual_instance, DatasourceConfigGCS + ) -def test_set_s3_config(mocker: MockerFixture) -> None: - mocker.patch.object(ApiWorkflowClient, "__init__", return_value=None) - mocked_datasources_api = mocker.MagicMock() - client = ApiWorkflowClient() - client._datasources_api = mocked_datasources_api - client._dataset_id = "dataset-id" - client.set_s3_config( - resource_path="s3://my-bucket/my-dataset", - thumbnail_suffix=".lightly/thumbnails/[filename]-thumb-[extension]", - region="eu-central-1", - access_key="my-access-key", - secret_access_key="my-secret-access-key", - ) - kwargs = mocked_datasources_api.update_datasource_by_dataset_id.call_args[1] - assert isinstance(kwargs["datasource_config"].actual_instance, DatasourceConfigS3) - - -def test_set_s3_delegated_access_config(mocker: MockerFixture) -> None: - mocker.patch.object(ApiWorkflowClient, "__init__", return_value=None) - mocked_datasources_api = mocker.MagicMock() - client = ApiWorkflowClient() - client._datasources_api = mocked_datasources_api - client._dataset_id = "dataset-id" - 
client.set_s3_delegated_access_config( - resource_path="s3://my-bucket/my-dataset", - thumbnail_suffix=".lightly/thumbnails/[filename]-thumb-[extension]", - region="eu-central-1", - role_arn="arn:aws:iam::000000000000:role.test", - external_id="my-external-id", - ) - kwargs = mocked_datasources_api.update_datasource_by_dataset_id.call_args[1] - assert isinstance( - kwargs["datasource_config"].actual_instance, DatasourceConfigS3DelegatedAccess - ) + def test_set_local_config(self, mocker: MockerFixture) -> None: + client = ApiWorkflowClient(token="abc", dataset_id="dataset-id") + mocker.patch.object( + client._datasources_api, + "update_datasource_by_dataset_id", + ) + client.set_local_config( + web_server_location="http://localhost:1234", + relative_path="path/to/my/data", + thumbnail_suffix=".lightly/thumbnails/[filename]-thumb-[extension]", + purpose=DatasourcePurpose.INPUT, + ) + kwargs = client._datasources_api.update_datasource_by_dataset_id.call_args[1] + datasource_config = kwargs["datasource_config"].actual_instance + assert isinstance(datasource_config, DatasourceConfigLOCAL) + assert datasource_config.type == "LOCAL" + assert datasource_config.web_server_location == "http://localhost:1234" + assert datasource_config.full_path == "path/to/my/data" + assert ( + datasource_config.thumb_suffix + == ".lightly/thumbnails/[filename]-thumb-[extension]" + ) + assert datasource_config.purpose == DatasourcePurpose.INPUT + + # Test defaults + client.set_local_config() + kwargs = client._datasources_api.update_datasource_by_dataset_id.call_args[1] + datasource_config = kwargs["datasource_config"].actual_instance + assert isinstance(datasource_config, DatasourceConfigLOCAL) + assert datasource_config.type == "LOCAL" + assert datasource_config.web_server_location == "http://localhost:3456" + assert datasource_config.full_path == "" + assert ( + datasource_config.thumb_suffix + == ".lightly/thumbnails/[filename]_thumb.[extension]" + ) + assert datasource_config.purpose 
== DatasourcePurpose.INPUT_OUTPUT + def test_set_s3_config(self, mocker: MockerFixture) -> None: + client = ApiWorkflowClient(token="abc", dataset_id="dataset-id") + mocker.patch.object( + client._datasources_api, + "update_datasource_by_dataset_id", + ) + client.set_s3_config( + resource_path="s3://my-bucket/my-dataset", + thumbnail_suffix=".lightly/thumbnails/[filename]-thumb-[extension]", + region="eu-central-1", + access_key="my-access-key", + secret_access_key="my-secret-access-key", + ) + kwargs = client._datasources_api.update_datasource_by_dataset_id.call_args[1] + assert isinstance( + kwargs["datasource_config"].actual_instance, DatasourceConfigS3 + ) -def test_update_processed_until_timestamp(mocker: MockerFixture) -> None: - mocker.patch.object(ApiWorkflowClient, "__init__", return_value=None) - mocked_datasources_api = mocker.MagicMock() - client = ApiWorkflowClient() - client._dataset_id = "dataset-id" - client._datasources_api = mocked_datasources_api - client.update_processed_until_timestamp(10) - kwargs = mocked_datasources_api.update_datasource_processed_until_timestamp_by_dataset_id.call_args[ - 1 - ] - assert kwargs["dataset_id"] == "dataset-id" - assert ( - kwargs["datasource_processed_until_timestamp_request"].processed_until_timestamp - == 10 - ) + def test_set_s3_delegated_access_config(self, mocker: MockerFixture) -> None: + client = ApiWorkflowClient(token="abc", dataset_id="dataset-id") + mocker.patch.object( + client._datasources_api, + "update_datasource_by_dataset_id", + ) + client.set_s3_delegated_access_config( + resource_path="s3://my-bucket/my-dataset", + thumbnail_suffix=".lightly/thumbnails/[filename]-thumb-[extension]", + region="eu-central-1", + role_arn="arn:aws:iam::000000000000:role.test", + external_id="my-external-id", + ) + kwargs = client._datasources_api.update_datasource_by_dataset_id.call_args[1] + assert isinstance( + kwargs["datasource_config"].actual_instance, + DatasourceConfigS3DelegatedAccess, + ) + def 
test_get_prediction_read_url(self, mocker: MockerFixture) -> None: + client = ApiWorkflowClient(token="abc", dataset_id="dataset-id") + mocker.patch.object( + client._datasources_api, + "get_prediction_file_read_url_from_datasource_by_dataset_id", + return_value="read-url", + ) + assert client.get_prediction_read_url(filename="test.json") == "read-url" + client._datasources_api.get_prediction_file_read_url_from_datasource_by_dataset_id.assert_called_once_with( + dataset_id="dataset-id", file_name="test.json" + ) -def test_list_datasource_permissions(mocker: MockerFixture) -> None: - client = ApiWorkflowClient(token="abc") - client._dataset_id = "dataset-id" - client._datasources_api.verify_datasource_by_dataset_id = mocker.MagicMock( - return_value=DatasourceConfigVerifyData( - canRead=True, - canWrite=True, - canList=False, - canOverwrite=True, - errors=None, - ), - ) - assert client.list_datasource_permissions() == { - "can_read": True, - "can_write": True, - "can_list": False, - "can_overwrite": True, - } - - -def test_list_datasource_permissions__error(mocker: MockerFixture) -> None: - client = ApiWorkflowClient(token="abc") - client._dataset_id = "dataset-id" - client._datasources_api.verify_datasource_by_dataset_id = mocker.MagicMock( - return_value=DatasourceConfigVerifyData( - canRead=True, - canWrite=True, - canList=False, - canOverwrite=True, - errors=DatasourceConfigVerifyDataErrors( - canRead=None, canWrite=None, canList="error message", canOverwrite=None + def test_get_custom_embedding_read_url(self, mocker: MockerFixture) -> None: + client = ApiWorkflowClient(token="abc", dataset_id="dataset-id") + mocker.patch.object( + client._datasources_api, + "get_custom_embedding_file_read_url_from_datasource_by_dataset_id", + return_value="read-url", + ) + assert ( + client.get_custom_embedding_read_url(filename="embeddings.csv") + == "read-url" + ) + 
client._datasources_api.get_custom_embedding_file_read_url_from_datasource_by_dataset_id.assert_called_once_with( + dataset_id="dataset-id", file_name="embeddings.csv" + ) + + def test_list_datasource_permissions(self, mocker: MockerFixture) -> None: + client = ApiWorkflowClient(token="abc", dataset_id="dataset-id") + client._datasources_api.verify_datasource_by_dataset_id = mocker.MagicMock( + return_value=DatasourceConfigVerifyData( + canRead=True, + canWrite=True, + canList=False, + canOverwrite=True, + errors=None, + ), + ) + assert client.list_datasource_permissions() == { + "can_read": True, + "can_write": True, + "can_list": False, + "can_overwrite": True, + } + + def test_list_datasource_permissions__error(self, mocker: MockerFixture) -> None: + client = ApiWorkflowClient(token="abc", dataset_id="dataset-id") + client._datasources_api.verify_datasource_by_dataset_id = mocker.MagicMock( + return_value=DatasourceConfigVerifyData( + canRead=True, + canWrite=True, + canList=False, + canOverwrite=True, + errors=DatasourceConfigVerifyDataErrors( + canRead=None, + canWrite=None, + canList="error message", + canOverwrite=None, + ), ), - ), + ) + assert client.list_datasource_permissions() == { + "can_read": True, + "can_write": True, + "can_list": False, + "can_overwrite": True, + "errors": { + "can_list": "error message", + }, + } + + def test__download_raw_files(self, mocker: MockerFixture) -> None: + response = DatasourceRawSamplesData( + hasMore=False, + cursor="", + data=[ + DatasourceRawSamplesDataRow(fileName="file1", readUrl="url1"), + DatasourceRawSamplesDataRow(fileName="file2", readUrl="url2"), + ], + ) + download_function = mocker.MagicMock(side_effect=[response]) + client = ApiWorkflowClient(token="abc", dataset_id="dataset-id") + assert client._download_raw_files( + download_function=download_function, + ) == [("file1", "url1"), ("file2", "url2")] + + def test__download_raw_files_iter(self, mocker: MockerFixture) -> None: + response_1 = 
DatasourceRawSamplesData( + hasMore=True, + cursor="cursor1", + data=[ + DatasourceRawSamplesDataRow(fileName="file1", readUrl="url1"), + DatasourceRawSamplesDataRow(fileName="file2", readUrl="url2"), + ], + ) + response_2 = DatasourceRawSamplesData( + hasMore=False, + cursor="cursor2", + data=[ + DatasourceRawSamplesDataRow(fileName="file3", readUrl="url3"), + DatasourceRawSamplesDataRow(fileName="file4", readUrl="url4"), + ], + ) + download_function = mocker.MagicMock(side_effect=[response_1, response_2]) + client = ApiWorkflowClient(token="abc", dataset_id="dataset-id") + progress_bar = mocker.spy(tqdm, "tqdm") + assert list( + client._download_raw_files_iter( + download_function=download_function, + from_=0, + to=5, + relevant_filenames_file_name="relevant-filenames", + use_redirected_read_url=True, + progress_bar=progress_bar, + foo="bar", + ) + ) == [ + ("file1", "url1"), + ("file2", "url2"), + ("file3", "url3"), + ("file4", "url4"), + ] + download_function.assert_has_calls( + [ + mocker.call( + dataset_id="dataset-id", + var_from=0, + to=5, + relevant_filenames_file_name="relevant-filenames", + use_redirected_read_url=True, + foo="bar", + ), + mocker.call( + dataset_id="dataset-id", + cursor="cursor1", + relevant_filenames_file_name="relevant-filenames", + use_redirected_read_url=True, + foo="bar", + ), + ] + ) + assert progress_bar.update.call_count == 4 + + def test__download_raw_files_iter__no_relevant_filenames( + self, mocker: MockerFixture + ) -> None: + response = DatasourceRawSamplesData(hasMore=False, cursor="", data=[]) + download_function = mocker.MagicMock(side_effect=[response]) + client = ApiWorkflowClient(token="abc", dataset_id="dataset-id") + list(client._download_raw_files_iter(download_function=download_function)) + assert "relevant_filenames_file_name" not in download_function.call_args[1] + + def test__download_raw_files_iter__warning(self, mocker: MockerFixture) -> None: + response = DatasourceRawSamplesData( + hasMore=False, + 
cursor="", + data=[ + DatasourceRawSamplesDataRow(fileName="/file1", readUrl="url1"), + ], + ) + download_function = mocker.MagicMock(side_effect=[response]) + client = ApiWorkflowClient(token="abc", dataset_id="dataset-id") + with pytest.warns(UserWarning, match="Absolute file paths like /file1"): + list(client._download_raw_files_iter(download_function=download_function)) + + +def test__sample_unseen_and_valid() -> None: + with pytest.warns(UserWarning, match="Absolute file paths like /file1"): + assert not api_workflow_datasources._sample_unseen_and_valid( + sample=DatasourceRawSamplesDataRow(fileName="/file1", readUrl="url1"), + relevant_filenames_file_name=None, + listed_filenames=set(), + ) + + with pytest.warns(UserWarning, match="Using dot notation"): + assert not api_workflow_datasources._sample_unseen_and_valid( + sample=DatasourceRawSamplesDataRow(fileName="./file1", readUrl="url1"), + relevant_filenames_file_name=None, + listed_filenames=set(), + ) + + with pytest.warns(UserWarning, match="Duplicate filename file1"): + assert not api_workflow_datasources._sample_unseen_and_valid( + sample=DatasourceRawSamplesDataRow(fileName="file1", readUrl="url1"), + relevant_filenames_file_name=None, + listed_filenames={"file1"}, + ) + + assert api_workflow_datasources._sample_unseen_and_valid( + sample=DatasourceRawSamplesDataRow(fileName="file1", readUrl="url1"), + relevant_filenames_file_name=None, + listed_filenames=set(), ) - assert client.list_datasource_permissions() == { - "can_read": True, - "can_write": True, - "can_list": False, - "can_overwrite": True, - "errors": { - "can_list": "error message", - }, - } diff --git a/tests/api_workflow/test_api_workflow_upload_embeddings.py b/tests/api_workflow/test_api_workflow_upload_embeddings.py index 4444d11d7..a3bb54c2b 100644 --- a/tests/api_workflow/test_api_workflow_upload_embeddings.py +++ b/tests/api_workflow/test_api_workflow_upload_embeddings.py @@ -4,11 +4,7 @@ import numpy as np from lightly.utils import 
io as io_utils -from lightly.utils.io import INVALID_FILENAME_CHARACTERS -from tests.api_workflow.mocked_api_workflow_client import ( - N_FILES_ON_SERVER, - MockedApiWorkflowSetup, -) +from tests.api_workflow.mocked_api_workflow_client import MockedApiWorkflowSetup class TestApiWorkflowUploadEmbeddings(MockedApiWorkflowSetup): @@ -80,15 +76,6 @@ def test_upload_wrong_filenames(self): with self.assertRaises(ValueError): self.t_ester_upload_embedding(n_data=n_data, special_name_first_sample=True) - def test_upload_comma_filenames(self): - n_data = len(self.api_workflow_client._mappings_api.sample_names) - for invalid_char in INVALID_FILENAME_CHARACTERS: - with self.subTest(msg=f"invalid_char: {invalid_char}"): - with self.assertRaises(ValueError): - self.t_ester_upload_embedding( - n_data=n_data, special_char_in_first_filename=invalid_char - ) - def test_set_embedding_id_default(self): self.api_workflow_client.set_embedding_id_to_latest() embeddings = ( diff --git a/tests/conftest.py b/tests/conftest.py index 22e256d29..7133ba908 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -41,7 +41,7 @@ def pytest_collection_modifyitems(config, items): item.add_marker(skip_slow) -@pytest.fixture(scope="session", autouse=True) +@pytest.fixture(scope="module", autouse=True) def mock_versioning_api(): """Fixture that is applied to all tests and mocks the versioning API. @@ -56,17 +56,17 @@ def mock_versioning_api(): should be compatible with all future versions. """ - def mock_get_latest_pip_version(current_version: str) -> str: + def mock_get_latest_pip_version(current_version: str, **kwargs) -> str: return current_version # NOTE(guarin, 2/6/23): Cannot use pytest mocker fixture here because it has not - # a "session" scope and it is not possible to use a fixture that has a tigher scope + # a "module" scope and it is not possible to use a fixture that has a tighter scope # inside a fixture with a wider scope. 
with mock.patch( - "lightly.api.version_checking.VersioningApi.get_latest_pip_version", + "lightly.api._version_checking.VersioningApi.get_latest_pip_version", new=mock_get_latest_pip_version, ), mock.patch( - "lightly.api.version_checking.VersioningApi.get_minimum_compatible_pip_version", + "lightly.api._version_checking.VersioningApi.get_minimum_compatible_pip_version", return_value="1.0.0", ): yield diff --git a/tests/data/test_LightlyDataset.py b/tests/data/test_LightlyDataset.py index c5a1a7c6c..0b4c7b196 100644 --- a/tests/data/test_LightlyDataset.py +++ b/tests/data/test_LightlyDataset.py @@ -1,23 +1,18 @@ import os -import random -import re import shutil import tempfile import unittest -import warnings from typing import List, Tuple import numpy as np -import torch import torchvision from PIL.Image import Image from lightly.data import LightlyDataset from lightly.data._utils import check_images -from lightly.utils.io import INVALID_FILENAME_CHARACTERS try: - import av + import av as _ import cv2 from lightly.data._video import VideoDataset @@ -137,24 +132,6 @@ def test_create_lightly_dataset_from_folder_nosubdir(self): for i in range(n_tot): sample, target, fname = dataset[i] - def test_create_lightly_dataset_with_invalid_char_in_filename(self): - # create a dataset - n_tot = 100 - dataset = torchvision.datasets.FakeData(size=n_tot, image_size=(3, 32, 32)) - - for invalid_char in INVALID_FILENAME_CHARACTERS: - with self.subTest(msg=f"invalid_char: {invalid_char}"): - tmp_dir = tempfile.mkdtemp() - sample_names = [f"img_,_{i}.jpg" for i in range(n_tot)] - for sample_idx in range(n_tot): - data = dataset[sample_idx] - path = os.path.join(tmp_dir, sample_names[sample_idx]) - data[0].save(path) - - # create lightly dataset - with self.assertRaises(ValueError): - dataset = LightlyDataset(input_dir=tmp_dir) - def test_check_images(self): # create a dataset tmp_dir = tempfile.mkdtemp() diff --git a/tests/embedding/test_embedding.py 
b/tests/embedding/test_embedding.py index c6c6d1baa..d6f4bcb8e 100644 --- a/tests/embedding/test_embedding.py +++ b/tests/embedding/test_embedding.py @@ -66,8 +66,8 @@ def test_embed_correct_order(self): device=device, ) - np.testing.assert_allclose(embeddings_1_worker, embeddings_4_worker, rtol=5e-5) - np.testing.assert_allclose(labels_1_worker, labels_4_worker, rtol=1e-5) + np.testing.assert_allclose(embeddings_1_worker, embeddings_4_worker, atol=5e-4) + np.testing.assert_allclose(labels_1_worker, labels_4_worker, atol=1e-5) self.assertListEqual(filenames_1_worker, filenames_4_worker) self.assertListEqual(filenames_1_worker, dataset.get_filenames()) diff --git a/tests/loss/test_PMSNLoss.py b/tests/loss/test_PMSNLoss.py index 4919e965b..1ddc8b84e 100644 --- a/tests/loss/test_PMSNLoss.py +++ b/tests/loss/test_PMSNLoss.py @@ -12,7 +12,7 @@ class TestPMSNLoss: def test_regularization_loss(self) -> None: criterion = PMSNLoss() - mean_anchor_probs = torch.Tensor([0.1, 0.3, 0.6]).log() + mean_anchor_probs = torch.Tensor([0.1, 0.3, 0.6]) loss = criterion.regularization_loss(mean_anchor_probs=mean_anchor_probs) norm = 1 / (1**0.25) + 1 / (2**0.25) + 1 / (3**0.25) t0 = 1 / (1**0.25) / norm @@ -45,7 +45,7 @@ def test_forward_cuda(self) -> None: class TestPMSNCustomLoss: def test_regularization_loss(self) -> None: criterion = PMSNCustomLoss(target_distribution=_uniform_distribution) - mean_anchor_probs = torch.Tensor([0.1, 0.3, 0.6]).log() + mean_anchor_probs = torch.Tensor([0.1, 0.3, 0.6]) loss = criterion.regularization_loss(mean_anchor_probs=mean_anchor_probs) expected_loss = ( 1 diff --git a/tests/loss/test_barlow_twins_loss.py b/tests/loss/test_barlow_twins_loss.py index f947ee784..c2853042b 100644 --- a/tests/loss/test_barlow_twins_loss.py +++ b/tests/loss/test_barlow_twins_loss.py @@ -1,10 +1,51 @@ import pytest +import torch +import torch.nn as nn from pytest_mock import MockerFixture from torch import distributed as dist from lightly.loss.barlow_twins_loss import 
BarlowTwinsLoss +class BarlowTwinsLossReference(torch.nn.Module): + def __init__( + self, + projector_dim: int = 8196, + lambda_param: float = 5e-3, + gather_distributed: bool = False, + ): + super(BarlowTwinsLossReference, self).__init__() + # normalization layer for the representations z1 and z2 + self.bn = nn.BatchNorm1d(projector_dim, affine=False) + self.lambda_param = lambda_param + self.gather_distributed = gather_distributed + + def forward(self, z_a: torch.Tensor, z_b: torch.Tensor) -> torch.Tensor: + # code from https://github.com/facebookresearch/barlowtwins/blob/main/main.py + + N = z_a.size(0) + + # empirical cross-correlation matrix + c = self.bn(z_a).T @ self.bn(z_b) + + # sum the cross-correlation matrix between all gpus + c.div_(N) + if self.gather_distributed: + torch.distributed.all_reduce(c) + + on_diag = torch.diagonal(c).add_(-1).pow_(2).sum() + off_diag = off_diagonal(c).pow_(2).sum() + loss = on_diag + self.lambda_param * off_diag + return loss + + +def off_diagonal(x): + # return a flattened view of the off-diagonal elements of a square matrix + n, m = x.shape + assert n == m + return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten() + + class TestBarlowTwinsLoss: def test__gather_distributed(self, mocker: MockerFixture) -> None: mock_is_available = mocker.patch.object(dist, "is_available", return_value=True) @@ -20,3 +61,25 @@ def test__gather_distributed_dist_not_available( with pytest.raises(ValueError): BarlowTwinsLoss(gather_distributed=True) mock_is_available.assert_called_once() + + def test__loss_matches_reference_loss(self) -> None: + batch_size = 32 + projector_dim = 8196 + lambda_param = 5e-3 + gather_distributed = False + loss = BarlowTwinsLoss( + lambda_param=lambda_param, gather_distributed=gather_distributed + ) + loss_ref = BarlowTwinsLossReference( + projector_dim=projector_dim, + lambda_param=lambda_param, + gather_distributed=gather_distributed, + ) + + z_a = torch.randn(batch_size, projector_dim) + z_b = 
torch.randn(batch_size, projector_dim) + + loss_out = loss(z_a, z_b) + loss_ref_out = loss_ref(z_a, z_b) + + assert torch.allclose(loss_out, loss_ref_out, rtol=1e-3, atol=1e-3) diff --git a/tests/transforms/test_Solarize.py b/tests/transforms/test_Solarize.py index 4a1cda349..1b595caf3 100644 --- a/tests/transforms/test_Solarize.py +++ b/tests/transforms/test_Solarize.py @@ -6,7 +6,7 @@ class TestRandomSolarization(unittest.TestCase): - def test_on_pil_image(self): + def test_on_pil_image(self) -> None: for w in [32, 64, 128]: for h in [32, 64, 128]: solarization = RandomSolarization(0.5) diff --git a/tests/transforms/test_byol_transform.py b/tests/transforms/test_byol_transform.py new file mode 100644 index 000000000..ecf72df48 --- /dev/null +++ b/tests/transforms/test_byol_transform.py @@ -0,0 +1,26 @@ +from PIL import Image + +from lightly.transforms.byol_transform import ( + BYOLTransform, + BYOLView1Transform, + BYOLView2Transform, +) + + +def test_view_on_pil_image() -> None: + single_view_transform = BYOLView1Transform(input_size=32) + sample = Image.new("RGB", (100, 100)) + output = single_view_transform(sample) + assert output.shape == (3, 32, 32) + + +def test_multi_view_on_pil_image() -> None: + multi_view_transform = BYOLTransform( + view_1_transform=BYOLView1Transform(input_size=32), + view_2_transform=BYOLView2Transform(input_size=32), + ) + sample = Image.new("RGB", (100, 100)) + output = multi_view_transform(sample) + assert len(output) == 2 + assert output[0].shape == (3, 32, 32) + assert output[1].shape == (3, 32, 32) diff --git a/tests/transforms/test_dino_transform.py b/tests/transforms/test_dino_transform.py index 4ca06c721..74bfea478 100644 --- a/tests/transforms/test_dino_transform.py +++ b/tests/transforms/test_dino_transform.py @@ -3,14 +3,14 @@ from lightly.transforms.dino_transform import DINOTransform, DINOViewTransform -def test_view_on_pil_image(): +def test_view_on_pil_image() -> None: single_view_transform = 
DINOViewTransform(crop_size=32) sample = Image.new("RGB", (100, 100)) output = single_view_transform(sample) assert output.shape == (3, 32, 32) -def test_multi_view_on_pil_image(): +def test_multi_view_on_pil_image() -> None: multi_view_transform = DINOTransform(global_crop_size=32, local_crop_size=8) sample = Image.new("RGB", (100, 100)) output = multi_view_transform(sample) diff --git a/tests/transforms/test_fastsiam_transform.py b/tests/transforms/test_fastsiam_transform.py index a5f60dfdf..672cf0a41 100644 --- a/tests/transforms/test_fastsiam_transform.py +++ b/tests/transforms/test_fastsiam_transform.py @@ -3,7 +3,7 @@ from lightly.transforms.fast_siam_transform import FastSiamTransform -def test_multi_view_on_pil_image(): +def test_multi_view_on_pil_image() -> None: multi_view_transform = FastSiamTransform(num_views=3, input_size=32) sample = Image.new("RGB", (100, 100)) output = multi_view_transform(sample) diff --git a/tests/transforms/test_GaussianBlur.py b/tests/transforms/test_gaussian_blur.py similarity index 69% rename from tests/transforms/test_GaussianBlur.py rename to tests/transforms/test_gaussian_blur.py index a778f15e4..fae09c674 100644 --- a/tests/transforms/test_GaussianBlur.py +++ b/tests/transforms/test_gaussian_blur.py @@ -2,21 +2,21 @@ from PIL import Image -from lightly.transforms import GaussianBlur +from lightly.transforms.gaussian_blur import GaussianBlur class TestGaussianBlur(unittest.TestCase): - def test_on_pil_image(self): + def test_on_pil_image(self) -> None: for w in range(1, 100): for h in range(1, 100): gaussian_blur = GaussianBlur() sample = Image.new("RGB", (w, h)) gaussian_blur(sample) - def test_raise_kernel_size_deprecation(self): + def test_raise_kernel_size_deprecation(self) -> None: gaussian_blur = GaussianBlur(kernel_size=2) self.assertWarns(DeprecationWarning) - def test_raise_scale_deprecation(self): + def test_raise_scale_deprecation(self) -> None: gaussian_blur = GaussianBlur(scale=0.1) 
self.assertWarns(DeprecationWarning) diff --git a/tests/transforms/test_Jigsaw.py b/tests/transforms/test_jigsaw.py similarity index 66% rename from tests/transforms/test_Jigsaw.py rename to tests/transforms/test_jigsaw.py index 964e738c5..6482ee109 100644 --- a/tests/transforms/test_Jigsaw.py +++ b/tests/transforms/test_jigsaw.py @@ -2,11 +2,11 @@ from PIL import Image -from lightly.transforms import Jigsaw +from lightly.transforms.jigsaw import Jigsaw class TestJigsaw(unittest.TestCase): - def test_on_pil_image(self): + def test_on_pil_image(self) -> None: crop = Jigsaw() sample = Image.new("RGB", (255, 255)) crop(sample) diff --git a/tests/transforms/test_location_to_NxN_grid.py b/tests/transforms/test_location_to_NxN_grid.py index 2ec1beb5b..013d4ab8f 100644 --- a/tests/transforms/test_location_to_NxN_grid.py +++ b/tests/transforms/test_location_to_NxN_grid.py @@ -3,7 +3,7 @@ import lightly.transforms.random_crop_and_flip_with_grid as test_module -def test_location_to_NxN_grid(): +def test_location_to_NxN_grid() -> None: # create a test instance of the Location class test_location = test_module.Location( left=10, diff --git a/tests/transforms/test_mae_transform.py b/tests/transforms/test_mae_transform.py index aafa11cdf..6f9b928c1 100644 --- a/tests/transforms/test_mae_transform.py +++ b/tests/transforms/test_mae_transform.py @@ -3,7 +3,7 @@ from lightly.transforms.mae_transform import MAETransform -def test_multi_view_on_pil_image(): +def test_multi_view_on_pil_image() -> None: multi_view_transform = MAETransform(input_size=32) sample = Image.new("RGB", (100, 100)) output = multi_view_transform(sample) diff --git a/tests/transforms/test_moco_transform.py b/tests/transforms/test_moco_transform.py index eef1c651f..aa43a216f 100644 --- a/tests/transforms/test_moco_transform.py +++ b/tests/transforms/test_moco_transform.py @@ -3,7 +3,7 @@ from lightly.transforms.moco_transform import MoCoV1Transform, MoCoV2Transform -def test_moco_v1_multi_view_on_pil_image(): 
+def test_moco_v1_multi_view_on_pil_image() -> None: multi_view_transform = MoCoV1Transform(input_size=32) sample = Image.new("RGB", (100, 100)) output = multi_view_transform(sample) @@ -12,7 +12,7 @@ def test_moco_v1_multi_view_on_pil_image(): assert output[1].shape == (3, 32, 32) -def test_moco_v2_multi_view_on_pil_image(): +def test_moco_v2_multi_view_on_pil_image() -> None: multi_view_transform = MoCoV2Transform(input_size=32) sample = Image.new("RGB", (100, 100)) output = multi_view_transform(sample) diff --git a/tests/transforms/test_msn_transform.py b/tests/transforms/test_msn_transform.py index fd2030bab..4f5be7b53 100644 --- a/tests/transforms/test_msn_transform.py +++ b/tests/transforms/test_msn_transform.py @@ -3,14 +3,14 @@ from lightly.transforms.msn_transform import MSNTransform, MSNViewTransform -def test_view_on_pil_image(): +def test_view_on_pil_image() -> None: single_view_transform = MSNViewTransform(crop_size=32) sample = Image.new("RGB", (100, 100)) output = single_view_transform(sample) assert output.shape == (3, 32, 32) -def test_multi_view_on_pil_image(): +def test_multi_view_on_pil_image() -> None: multi_view_transform = MSNTransform(random_size=32, focal_size=8) sample = Image.new("RGB", (100, 100)) output = multi_view_transform(sample) diff --git a/tests/transforms/test_multi_view_transform.py b/tests/transforms/test_multi_view_transform.py index dddd4db3d..27806b92a 100644 --- a/tests/transforms/test_multi_view_transform.py +++ b/tests/transforms/test_multi_view_transform.py @@ -6,7 +6,7 @@ from lightly.transforms.multi_view_transform import MultiViewTransform -def test_multi_view_on_pil_image(): +def test_multi_view_on_pil_image() -> None: multi_view_transform = MultiViewTransform( [ T.RandomHorizontalFlip(p=0.1), diff --git a/tests/transforms/test_pirl_transform.py b/tests/transforms/test_pirl_transform.py index 5042a1e8f..20c7c8705 100644 --- a/tests/transforms/test_pirl_transform.py +++ b/tests/transforms/test_pirl_transform.py @@ 
-3,7 +3,7 @@ from lightly.transforms.pirl_transform import PIRLTransform -def test_multi_view_on_pil_image(): +def test_multi_view_on_pil_image() -> None: multi_view_transform = PIRLTransform(input_size=32) sample = Image.new("RGB", (100, 100)) output = multi_view_transform(sample) diff --git a/tests/transforms/test_rotation.py b/tests/transforms/test_rotation.py index 448858fdd..065b8dd1d 100644 --- a/tests/transforms/test_rotation.py +++ b/tests/transforms/test_rotation.py @@ -1,3 +1,5 @@ +from typing import List, Tuple, Union + from PIL import Image from lightly.transforms.rotation import ( @@ -7,20 +9,21 @@ ) -def test_RandomRotate_on_pil_image(): +def test_RandomRotate_on_pil_image() -> None: random_rotate = RandomRotate() sample = Image.new("RGB", (100, 100)) random_rotate(sample) -def test_RandomRotateDegrees_on_pil_image(): - for degrees in [0, 1, 45, (0, 0), (-15, 30)]: +def test_RandomRotateDegrees_on_pil_image() -> None: + all_degrees: List[Union[float, Tuple[float, float]]] = [0, 1, 45, (0, 0), (-15, 30)] + for degrees in all_degrees: random_rotate = RandomRotateDegrees(prob=0.5, degrees=degrees) sample = Image.new("RGB", (100, 100)) random_rotate(sample) -def test_random_rotation_transform(): +def test_random_rotation_transform() -> None: transform = random_rotation_transform(rr_prob=1.0, rr_degrees=None) assert isinstance(transform, RandomRotate) transform = random_rotation_transform(rr_prob=1.0, rr_degrees=45) diff --git a/tests/transforms/test_simclr_transform.py b/tests/transforms/test_simclr_transform.py index 70fff7ab4..78a9a5cca 100644 --- a/tests/transforms/test_simclr_transform.py +++ b/tests/transforms/test_simclr_transform.py @@ -3,14 +3,14 @@ from lightly.transforms.simclr_transform import SimCLRTransform, SimCLRViewTransform -def test_view_on_pil_image(): +def test_view_on_pil_image() -> None: single_view_transform = SimCLRViewTransform(input_size=32) sample = Image.new("RGB", (100, 100)) output = single_view_transform(sample) assert 
output.shape == (3, 32, 32) -def test_multi_view_on_pil_image(): +def test_multi_view_on_pil_image() -> None: multi_view_transform = SimCLRTransform(input_size=32) sample = Image.new("RGB", (100, 100)) output = multi_view_transform(sample) diff --git a/tests/transforms/test_simsiam_transform.py b/tests/transforms/test_simsiam_transform.py index 2444924ec..39a88721a 100644 --- a/tests/transforms/test_simsiam_transform.py +++ b/tests/transforms/test_simsiam_transform.py @@ -3,14 +3,14 @@ from lightly.transforms.simsiam_transform import SimSiamTransform, SimSiamViewTransform -def test_view_on_pil_image(): +def test_view_on_pil_image() -> None: single_view_transform = SimSiamViewTransform(input_size=32) sample = Image.new("RGB", (100, 100)) output = single_view_transform(sample) assert output.shape == (3, 32, 32) -def test_multi_view_on_pil_image(): +def test_multi_view_on_pil_image() -> None: multi_view_transform = SimSiamTransform(input_size=32) sample = Image.new("RGB", (100, 100)) output = multi_view_transform(sample) diff --git a/tests/transforms/test_smog_transform.py b/tests/transforms/test_smog_transform.py index 49fed878b..042d46f9f 100644 --- a/tests/transforms/test_smog_transform.py +++ b/tests/transforms/test_smog_transform.py @@ -3,14 +3,14 @@ from lightly.transforms.smog_transform import SMoGTransform, SmoGViewTransform -def test_view_on_pil_image(): +def test_view_on_pil_image() -> None: single_view_transform = SmoGViewTransform(crop_size=32) sample = Image.new("RGB", (100, 100)) output = single_view_transform(sample) assert output.shape == (3, 32, 32) -def test_multi_view_on_pil_image(): +def test_multi_view_on_pil_image() -> None: multi_view_transform = SMoGTransform(crop_sizes=(32, 8)) sample = Image.new("RGB", (100, 100)) output = multi_view_transform(sample) diff --git a/tests/transforms/test_swav_transform.py b/tests/transforms/test_swav_transform.py index 05475ce6a..7c2cdd2c0 100644 --- a/tests/transforms/test_swav_transform.py +++ 
b/tests/transforms/test_swav_transform.py @@ -3,14 +3,14 @@ from lightly.transforms.swav_transform import SwaVTransform, SwaVViewTransform -def test_view_on_pil_image(): +def test_view_on_pil_image() -> None: single_view_transform = SwaVViewTransform() sample = Image.new("RGB", (100, 100)) output = single_view_transform(sample) assert output.shape == (3, 100, 100) -def test_multi_view_on_pil_image(): +def test_multi_view_on_pil_image() -> None: multi_view_transform = SwaVTransform(crop_sizes=(32, 8)) sample = Image.new("RGB", (100, 100)) output = multi_view_transform(sample) diff --git a/tests/transforms/test_vicreg_transform.py b/tests/transforms/test_vicreg_transform.py index 5a2b0633d..06e710f25 100644 --- a/tests/transforms/test_vicreg_transform.py +++ b/tests/transforms/test_vicreg_transform.py @@ -3,14 +3,14 @@ from lightly.transforms.vicreg_transform import VICRegTransform, VICRegViewTransform -def test_view_on_pil_image(): +def test_view_on_pil_image() -> None: single_view_transform = VICRegViewTransform(input_size=32) sample = Image.new("RGB", (100, 100)) output = single_view_transform(sample) assert output.shape == (3, 32, 32) -def test_multi_view_on_pil_image(): +def test_multi_view_on_pil_image() -> None: multi_view_transform = VICRegTransform(input_size=32) sample = Image.new("RGB", (100, 100)) output = multi_view_transform(sample) diff --git a/tests/transforms/test_vicregl_transform.py b/tests/transforms/test_vicregl_transform.py index f4696d067..e697807c4 100644 --- a/tests/transforms/test_vicregl_transform.py +++ b/tests/transforms/test_vicregl_transform.py @@ -3,14 +3,14 @@ from lightly.transforms.vicregl_transform import VICRegLTransform, VICRegLViewTransform -def test_view_on_pil_image(): +def test_view_on_pil_image() -> None: single_view_transform = VICRegLViewTransform() sample = Image.new("RGB", (100, 100)) output = single_view_transform(sample) assert output.shape == (3, 100, 100) -def test_multi_view_on_pil_image(): +def 
test_multi_view_on_pil_image() -> None: multi_view_transform = VICRegLTransform( global_crop_size=32, local_crop_size=8, diff --git a/tests/utils/test_dist.py b/tests/utils/test_dist.py index 79eb16013..8d3a564a9 100644 --- a/tests/utils/test_dist.py +++ b/tests/utils/test_dist.py @@ -2,6 +2,7 @@ from unittest import mock import torch +from pytest import CaptureFixture from lightly.utils import dist @@ -31,3 +32,25 @@ def test_eye_rank_dist(self): expected.append(zeros) expected = torch.cat(expected, dim=1) self.assertTrue(torch.all(dist.eye_rank(n) == expected)) + + +def test_rank_zero_only__rank_0() -> None: + @dist.rank_zero_only + def fn(): + return 0 + + assert fn() == 0 + + +def test_rank_zero_only__rank_1() -> None: + @dist.rank_zero_only + def fn(): + return 0 + + with mock.patch.object(dist, "rank", lambda: 1): + assert fn() is None + + +def test_print_rank_zero(capsys: CaptureFixture[str]) -> None: + dist.print_rank_zero("message") + assert capsys.readouterr().out == "message\n" diff --git a/tests/utils/test_io.py b/tests/utils/test_io.py index 988b622f1..04086be5f 100644 --- a/tests/utils/test_io.py +++ b/tests/utils/test_io.py @@ -1,62 +1,35 @@ import csv import json -import sys import tempfile import unittest +from pathlib import Path import numpy as np -from lightly.utils.io import ( - check_embeddings, - check_filenames, - save_custom_metadata, - save_embeddings, - save_schema, - save_tasks, -) -from tests.api_workflow.mocked_api_workflow_client import ( - MockedApiWorkflowClient, - MockedApiWorkflowSetup, -) - - -class TestCLICrop(MockedApiWorkflowSetup): - def test_save_metadata(self): +from lightly.utils import io +from tests.api_workflow.mocked_api_workflow_client import MockedApiWorkflowSetup + + +class TestCLICrop(MockedApiWorkflowSetup): # type: ignore[misc] + def test_save_metadata(self) -> None: metadata = [("filename.jpg", {"random_metadata": 42})] metadata_filepath = tempfile.mktemp(".json", "metadata") - 
save_custom_metadata(metadata_filepath, metadata) - - def test_valid_filenames(self): - valid = "img.png" - non_valid = "img,1.png" - filenames_list = [ - ([valid], True), - ([valid, valid], True), - ([non_valid], False), - ([valid, non_valid], False), - ] - for filenames, valid in filenames_list: - with self.subTest(msg=f"filenames:{filenames}"): - if valid: - check_filenames(filenames) - else: - with self.assertRaises(ValueError): - check_filenames(filenames) + io.save_custom_metadata(metadata_filepath, metadata) class TestEmbeddingsIO(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: # correct embedding file as created through lightly self.embeddings_path = tempfile.mktemp(".csv", "embeddings") embeddings = np.random.rand(32, 2) labels = [0 for i in range(len(embeddings))] filenames = [f"img_{i}.jpg" for i in range(len(embeddings))] - save_embeddings(self.embeddings_path, embeddings, labels, filenames) + io.save_embeddings(self.embeddings_path, embeddings, labels, filenames) - def test_valid_embeddings(self): - check_embeddings(self.embeddings_path) + def test_valid_embeddings(self) -> None: + io.check_embeddings(self.embeddings_path) - def test_whitespace_in_embeddings(self): + def test_whitespace_in_embeddings(self) -> None: # should fail because there whitespaces in the header columns lines = [ "filenames, embedding_0,embedding_1,labels\n", @@ -65,19 +38,19 @@ def test_whitespace_in_embeddings(self): with open(self.embeddings_path, "w") as f: f.writelines(lines) with self.assertRaises(RuntimeError) as context: - check_embeddings(self.embeddings_path) + io.check_embeddings(self.embeddings_path) self.assertTrue("must not contain whitespaces" in str(context.exception)) - def test_no_labels_in_embeddings(self): + def test_no_labels_in_embeddings(self) -> None: # should fail because there is no `labels` column in the header lines = ["filenames,embedding_0,embedding_1\n", "img_1.jpg,0.351,0.1231"] with open(self.embeddings_path, "w") as f: 
f.writelines(lines) with self.assertRaises(RuntimeError) as context: - check_embeddings(self.embeddings_path) + io.check_embeddings(self.embeddings_path) self.assertTrue("has no `labels` column" in str(context.exception)) - def test_no_empty_rows_in_embeddings(self): + def test_no_empty_rows_in_embeddings(self) -> None: # should fail because there are empty rows in the embeddings file lines = [ "filenames,embedding_0,embedding_1,labels\n", @@ -86,10 +59,10 @@ def test_no_empty_rows_in_embeddings(self): with open(self.embeddings_path, "w") as f: f.writelines(lines) with self.assertRaises(RuntimeError) as context: - check_embeddings(self.embeddings_path) + io.check_embeddings(self.embeddings_path) self.assertTrue("must not have empty rows" in str(context.exception)) - def test_embeddings_extra_rows(self): + def test_embeddings_extra_rows(self) -> None: rows = [ ["filenames", "embedding_0", "embedding_1", "labels", "selected", "masked"], ["image_0.jpg", "3.4", "0.23", "0", "1", "0"], @@ -99,14 +72,14 @@ def test_embeddings_extra_rows(self): csv_writer = csv.writer(f) csv_writer.writerows(rows) - check_embeddings(self.embeddings_path, remove_additional_columns=True) + io.check_embeddings(self.embeddings_path, remove_additional_columns=True) with open(self.embeddings_path) as csv_file: csv_reader = csv.reader(csv_file, delimiter=",") for row_read, row_original in zip(csv_reader, rows): self.assertListEqual(row_read, row_original[:-2]) - def test_embeddings_extra_rows_special_order(self): + def test_embeddings_extra_rows_special_order(self) -> None: input_rows = [ ["filenames", "embedding_0", "embedding_1", "masked", "labels", "selected"], ["image_0.jpg", "3.4", "0.23", "0", "1", "0"], @@ -121,26 +94,26 @@ def test_embeddings_extra_rows_special_order(self): csv_writer = csv.writer(f) csv_writer.writerows(input_rows) - check_embeddings(self.embeddings_path, remove_additional_columns=True) + io.check_embeddings(self.embeddings_path, remove_additional_columns=True) with 
open(self.embeddings_path) as csv_file: csv_reader = csv.reader(csv_file, delimiter=",") for row_read, row_original in zip(csv_reader, correct_output_rows): self.assertListEqual(row_read, row_original) - def test_save_tasks(self): + def test_save_tasks(self) -> None: tasks = [ "task1", "task2", "task3", ] with tempfile.NamedTemporaryFile(suffix=".json") as file: - save_tasks(file.name, tasks) + io.save_tasks(file.name, tasks) with open(file.name, "r") as f: loaded = json.load(f) self.assertListEqual(tasks, loaded) - def test_save_schema(self): + def test_save_schema(self) -> None: description = "classification" ids = [1, 2, 3, 4] names = ["name1", "name2", "name3", "name4"] @@ -154,16 +127,56 @@ def test_save_schema(self): ], } with tempfile.NamedTemporaryFile(suffix=".json") as file: - save_schema(file.name, description, ids, names) + io.save_schema(file.name, description, ids, names) with open(file.name, "r") as f: loaded = json.load(f) self.assertListEqual(sorted(expected_format), sorted(loaded)) - def test_save_schema_different(self): + def test_save_schema_different(self) -> None: with self.assertRaises(ValueError): - save_schema( + io.save_schema( "name_doesnt_matter", "description_doesnt_matter", [1, 2], ["name1"], ) + + +def test_save_and_load_embeddings(tmp_path: Path) -> None: + embeddings = np.random.rand(2, 32) + labels = [0, 1] + filenames = ["img_1.jpg", "img_2.jpg"] + + io.save_embeddings( + path=str(tmp_path / "embeddings.csv"), + embeddings=embeddings, + labels=labels, + filenames=filenames, + ) + + loaded_embeddings, loaded_labels, loaded_filenames = io.load_embeddings( + path=str(tmp_path / "embeddings.csv") + ) + assert np.allclose(embeddings, loaded_embeddings) + assert labels == loaded_labels + assert filenames == loaded_filenames + + +def test_save_and_load_embeddings__filename_with_comma(tmp_path: Path) -> None: + embeddings = np.random.rand(4, 32) + labels = [0, 1, 2, 3] + filenames = ["img,1.jpg", '",img,.jpg', ',"img".jpg', 
',"img\n".jpg'] + + io.save_embeddings( + path=str(tmp_path / "embeddings.csv"), + embeddings=embeddings, + labels=labels, + filenames=filenames, + ) + + loaded_embeddings, loaded_labels, loaded_filenames = io.load_embeddings( + path=str(tmp_path / "embeddings.csv") + ) + assert np.allclose(embeddings, loaded_embeddings) + assert labels == loaded_labels + assert filenames == loaded_filenames diff --git a/tests/utils/test_scheduler.py b/tests/utils/test_scheduler.py index 6a0bd569c..7beb6754f 100644 --- a/tests/utils/test_scheduler.py +++ b/tests/utils/test_scheduler.py @@ -7,7 +7,7 @@ class TestScheduler(unittest.TestCase): - def test_cosine_schedule(self): + def test_cosine_schedule(self) -> None: self.assertAlmostEqual(cosine_schedule(1, 10, 0.99, 1.0), 0.99030154, 6) self.assertAlmostEqual(cosine_schedule(95, 100, 0.7, 2.0), 1.99477063, 6) self.assertAlmostEqual(cosine_schedule(0, 1, 0.996, 1.0), 1.0, 6) @@ -23,7 +23,15 @@ def test_cosine_schedule(self): ): cosine_schedule(11, 10, 0.0, 1.0) - def test_CosineWarmupScheduler(self): + def test_cosine_schedule__period(self) -> None: + self.assertAlmostEqual(cosine_schedule(0, 1, 0, 1.0, period=10), 0.0, 6) + self.assertAlmostEqual(cosine_schedule(3, 1, 0, 2.0, period=10), 1.30901706, 6) + self.assertAlmostEqual(cosine_schedule(10, 1, 0, 1.0, period=10), 0.0, 6) + self.assertAlmostEqual(cosine_schedule(15, 1, 0, 1.0, period=10), 1.0, 6) + with self.assertRaises(ValueError): + cosine_schedule(1, 10, 0.0, 1.0, period=-1) + + def test_CosineWarmupScheduler(self) -> None: model = nn.Linear(10, 1) optimizer = torch.optim.SGD( model.parameters(), lr=1.0, momentum=0.0, weight_decay=0.0 @@ -62,3 +70,23 @@ def test_CosineWarmupScheduler(self): RuntimeWarning, msg="Current step number 7 exceeds max_steps 6." 
): scheduler.step() + + def test_CosineWarmupScheduler__warmup(self) -> None: + model = nn.Linear(10, 1) + optimizer = torch.optim.SGD( + model.parameters(), lr=1.0, momentum=0.0, weight_decay=0.0 + ) + scheduler = CosineWarmupScheduler( + optimizer, + warmup_epochs=3, + max_epochs=6, + start_value=2.0, + end_value=0.0, + ) + # Linear warmup + self.assertAlmostEqual(scheduler.scale_lr(epoch=0), 2.0 * 1 / 3) + self.assertAlmostEqual(scheduler.scale_lr(epoch=1), 2.0 * 2 / 3) + self.assertAlmostEqual(scheduler.scale_lr(epoch=2), 2.0 * 3 / 3) + # Cosine decay + self.assertAlmostEqual(scheduler.scale_lr(epoch=3), 2.0 * 3 / 3) + self.assertLess(scheduler.scale_lr(epoch=4), 2.0) diff --git a/tests/utils/test_version_compare.py b/tests/utils/test_version_compare.py index ce39dbb6a..40d516ea0 100644 --- a/tests/utils/test_version_compare.py +++ b/tests/utils/test_version_compare.py @@ -4,7 +4,7 @@ class TestVersionCompare(unittest.TestCase): - def test_valid_versions(self): + def test_valid_versions(self) -> None: # general test of smaller than version numbers self.assertEqual(version_compare.version_compare("0.1.4", "1.2.0"), -1) self.assertEqual(version_compare.version_compare("1.1.0", "1.2.0"), -1) @@ -16,7 +16,7 @@ def test_valid_versions(self): # test equal self.assertEqual(version_compare.version_compare("1.2.0", "1.2.0"), 0) - def test_invalid_versions(self): + def test_invalid_versions(self) -> None: with self.assertRaises(ValueError): version_compare.version_compare("1.2", "1.1.0")