diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml
index 861925a7d99..cc6b7195fe2 100644
--- a/docs/source/_toctree.yml
+++ b/docs/source/_toctree.yml
@@ -88,6 +88,8 @@
       title: Load document data
     - local: document_dataset
       title: Create a document dataset
+    - local: nifti_dataset
+      title: Create a medical imaging dataset
     title: "Vision"
 - sections:
   - local: nlp_load
diff --git a/docs/source/nifti_dataset.mdx b/docs/source/nifti_dataset.mdx
new file mode 100644
index 00000000000..2770460fbaf
--- /dev/null
+++ b/docs/source/nifti_dataset.mdx
@@ -0,0 +1,130 @@
# Create a NIfTI dataset

This page shows how to create and share a dataset of medical images in NIfTI format (`.nii` / `.nii.gz`) with the `datasets` library.

You can share a dataset with your team or with anyone in the community by creating a dataset repository on the Hugging Face Hub:

```py
from datasets import load_dataset

dataset = load_dataset("<username>/my_nifti_dataset")
```

There are two common ways to create a NIfTI dataset:

- Create a dataset from local NIfTI files in Python and upload it with [`Dataset.push_to_hub`].
- Organize the files in a folder-based convention (one file per example) and load them directly with [`load_dataset`], as shown after the metadata example below.

> [!TIP]
> You can control access to your dataset by requiring users to share their contact information first. Check out the [Gated datasets](https://huggingface.co/docs/hub/datasets-gated) guide for more information.

## Local files

If you already have a list of paths to NIfTI files, the easiest workflow is to create a `Dataset` from that list and cast the column to the `Nifti` feature.

```py
from datasets import Dataset, Nifti

# simple example: create a dataset from file paths
files = ["/path/to/scan_001.nii.gz", "/path/to/scan_002.nii.gz"]
ds = Dataset.from_dict({"nifti": files}).cast_column("nifti", Nifti())

# ds[0]["nifti"] is a nibabel.Nifti1Image object when decode=True (the default),
# or a dict {'bytes': None, 'path': '...'} when decode=False
```

The `Nifti` feature supports a `decode` parameter. When `decode=True` (the default), it loads the NIfTI file into a `nibabel.nifti1.Nifti1Image` object, and you can access the image data as a numpy array with `img.get_fdata()`. When `decode=False`, it returns a dict with the file path and bytes.

```py
from datasets import Dataset, Nifti

ds = Dataset.from_dict({"nifti": ["/path/to/scan.nii.gz"]}).cast_column("nifti", Nifti(decode=True))
img = ds[0]["nifti"]  # instance of nibabel.nifti1.Nifti1Image
arr = img.get_fdata()
```

After preparing the dataset, you can push it to the Hub:

```py
ds.push_to_hub("<username>/my_nifti_dataset")
```

This creates a dataset repository containing your NIfTI dataset with a `data/` folder of parquet shards.

## Folder conventions and metadata

If you organize your dataset in folders, the splits (train/validation/test) are created automatically from a structure like:

```
dataset/train/scan_0001.nii
dataset/train/scan_0002.nii
dataset/validation/scan_1001.nii
dataset/test/scan_2001.nii
```

If you have labels or other metadata, provide a `metadata.csv`, `metadata.jsonl`, or `metadata.parquet` file in the folder so files can be linked to metadata rows. The metadata must contain a `file_name` (or `*_file_name`) field with the path of each NIfTI file relative to the metadata file.

Example `metadata.csv`:

```csv
file_name,patient_id,age,diagnosis
scan_0001.nii.gz,P001,45,healthy
scan_0002.nii.gz,P002,59,disease_x
```
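Once the folder is organized this way, you can load it with the `niftifolder` builder registered by this PR. A minimal sketch (the `data_dir` path is a placeholder): splits are inferred from the directory names, and the metadata columns are joined to each example:

```py
from datasets import load_dataset

# the NIfTI files are discovered by their .nii / .nii.gz extensions,
# and the metadata.csv columns are attached to the matching examples
ds = load_dataset("niftifolder", data_dir="/path/to/dataset")
example = ds["train"][0]
print(example["nifti"].shape, example["patient_id"], example["diagnosis"])
```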
The `Nifti` feature works with zipped datasets too: each archive can contain NIfTI files along with a metadata file, which is useful when uploading large datasets as archives. Your dataset structure can also mix compressed and uncompressed files:

```
dataset/train/scan_0001.nii.gz
dataset/train/scan_0002.nii
dataset/validation/scan_1001.nii.gz
dataset/test/scan_2001.nii
```

## Converting to PyTorch tensors

Use the [`~Dataset.set_transform`] function to apply a transformation on-the-fly to batches of the dataset:

```py
import torch

def transform_to_pytorch(batch):
    # set_transform passes batches, so batch["nifti"] is a list of nibabel images
    batch["nifti_torch"] = [torch.tensor(img.get_fdata()) for img in batch["nifti"]]
    return batch

ds.set_transform(transform_to_pytorch)
```

Accessing elements now (e.g. `ds[0]`) yields torch tensors under the `"nifti_torch"` key.

## Usage of Nifti1Image

NIfTI is a format for storing the results of 3-dimensional (or even 4-dimensional) brain scans: three spatial dimensions (x, y, z) and optionally a time dimension (t). Voxel positions are only defined relative to the scanner grid, so the header also stores an affine transformation that lifts them to real-world coordinates.

You can visualize NIfTI files, for instance with `matplotlib`:

```python
import matplotlib.pyplot as plt
from datasets import load_dataset

def show_slices(slices):
    """Display a row of image slices."""
    fig, axes = plt.subplots(1, len(slices))
    for i, slice_ in enumerate(slices):
        axes[i].imshow(slice_.T, cmap="gray", origin="lower")

nifti_ds = load_dataset("<username>/my_nifti_dataset", split="train")
for example in nifti_ds:
    nifti_img = example["nifti"].get_fdata()
    show_slices([nifti_img[:, :, 16], nifti_img[26, :, :], nifti_img[:, 30, :]])
    plt.show()
```
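To relate voxel indices to world coordinates, you can inspect the affine stored in the header. A minimal sketch (the file path is a placeholder) using `nibabel.affines.apply_affine`:

```python
import nibabel as nib
from nibabel.affines import apply_affine

img = nib.load("/path/to/scan_0001.nii.gz")
print(img.shape)   # voxel grid dimensions, e.g. (64, 64, 35)
print(img.affine)  # 4x4 matrix mapping voxel indices to world coordinates
# world-space position (in mm) of the voxel at index (10, 20, 15)
print(apply_affine(img.affine, (10, 20, 15)))
```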
For further reading, we refer to the [nibabel documentation](https://nipy.org/nibabel/index.html) and especially [this nibabel tutorial](https://nipy.org/nibabel/coordinate_systems.html) on coordinate systems.
diff --git a/docs/source/package_reference/loading_methods.mdx b/docs/source/package_reference/loading_methods.mdx
index 786679636e7..4792d1b88f7 100644
--- a/docs/source/package_reference/loading_methods.mdx
+++ b/docs/source/package_reference/loading_methods.mdx
@@ -103,6 +103,12 @@ load_dataset("csv", data_dir="path/to/data/dir", sep="\t")

 [[autodoc]] datasets.packaged_modules.pdffolder.PdfFolder

+### Nifti
+
+[[autodoc]] datasets.packaged_modules.niftifolder.NiftiFolderConfig
+
+[[autodoc]] datasets.packaged_modules.niftifolder.NiftiFolder
+
 ### WebDataset

 [[autodoc]] datasets.packaged_modules.webdataset.WebDataset
diff --git a/docs/source/package_reference/main_classes.mdx b/docs/source/package_reference/main_classes.mdx
index 299dd765d13..84e651f9171 100644
--- a/docs/source/package_reference/main_classes.mdx
+++ b/docs/source/package_reference/main_classes.mdx
@@ -271,6 +271,10 @@ Dictionary with split names as keys ('train', 'test' for example), and `IterableDataset` objects as values.

 [[autodoc]] datasets.Pdf

+### Nifti
+
+[[autodoc]] datasets.Nifti
+
 ## Filesystems

 [[autodoc]] datasets.filesystems.is_remote_filesystem
diff --git a/setup.py b/setup.py
index 8dc56ef518d..06eee6717c8 100644
--- a/setup.py
+++ b/setup.py
@@ -186,6 +186,7 @@
     "polars[timezone]>=0.20.0",
     "Pillow>=9.4.0",  # When PIL.Image.ExifTags was introduced
     "torchcodec>=0.7.0",  # minimum version to get windows support
+    "nibabel>=5.3.2",
 ]

 NUMPY2_INCOMPATIBLE_LIBRARIES = [
@@ -207,6 +208,8 @@
 PDFS_REQUIRE = ["pdfplumber>=0.11.4"]

+NIBABEL_REQUIRE = ["nibabel>=5.3.2"]
+
 EXTRAS_REQUIRE = {
     "audio": AUDIO_REQUIRE,
     "vision": VISION_REQUIRE,
@@ -224,6 +227,7 @@
     "benchmarks": BENCHMARKS_REQUIRE,
     "docs": DOCS_REQUIRE,
     "pdfs": PDFS_REQUIRE,
+    "nibabel": NIBABEL_REQUIRE,
 }

 setup(
diff --git a/src/datasets/config.py b/src/datasets/config.py
index 5e61e7bc015..3d3f12b008d 100644
--- a/src/datasets/config.py
+++ b/src/datasets/config.py
@@ -139,6 +139,7 @@
 TORCHCODEC_AVAILABLE = importlib.util.find_spec("torchcodec") is not None
 TORCHVISION_AVAILABLE = importlib.util.find_spec("torchvision") is not None
 PDFPLUMBER_AVAILABLE = importlib.util.find_spec("pdfplumber") is not None
+NIBABEL_AVAILABLE = importlib.util.find_spec("nibabel") is not None

 # Optional compression tools
 RARFILE_AVAILABLE = importlib.util.find_spec("rarfile") is not None
diff --git a/src/datasets/features/__init__.py b/src/datasets/features/__init__.py
index 36133ce5e5a..40a3568039a 100644
--- a/src/datasets/features/__init__.py
+++ b/src/datasets/features/__init__.py
@@ -15,10 +15,12 @@
     "TranslationVariableLanguages",
     "Video",
     "Pdf",
+    "Nifti",
 ]

 from .audio import Audio
 from .features import Array2D, Array3D, Array4D, Array5D, ClassLabel, Features, LargeList, List, Sequence, Value
 from .image import Image
+from .nifti import Nifti
 from .pdf import Pdf
 from .translation import Translation, TranslationVariableLanguages
 from .video import Video
diff --git a/src/datasets/features/features.py b/src/datasets/features/features.py
index 54d84ef33e2..88259767ae0 100644
--- a/src/datasets/features/features.py
+++ b/src/datasets/features/features.py
@@ -42,6 +42,7 @@
 from ..utils.py_utils import asdict, first_non_null_value, zip_dict
 from .audio import Audio
 from .image import Image, encode_pil_image
+from .nifti import Nifti
 from .pdf import Pdf, encode_pdfplumber_pdf
 from .translation import Translation, TranslationVariableLanguages
 from .video import Video
@@ -1270,6 +1271,7 @@ def __repr__(self):
     Image,
     Video,
     Pdf,
+    Nifti,
 ]
@@ -1428,6 +1430,7 @@ def decode_nested_example(schema, obj, token_per_repo_id: Optional[dict[str, Union]]):
     Image.__name__: Image,
     Video.__name__: Video,
     Pdf.__name__: Pdf,
+    Nifti.__name__: Nifti,
 }
@@ -1761,6 +1764,9 @@ class Features(dict):
     - [`Pdf`] feature to store the absolute path to a PDF file, a `pdfplumber.pdf.PDF` object or a dictionary
       with the relative path to a PDF file ("path" key) and its bytes content ("bytes" key). This feature loads
       the PDF lazily with a PDF reader.
+    - [`Nifti`] feature to store the absolute path to a NIfTI neuroimaging file, a `nibabel.Nifti1Image` object
+      or a dictionary with the relative path to a NIfTI file ("path" key) and its bytes content ("bytes" key).
+      This feature loads the NIfTI file lazily with nibabel.
     - [`Translation`] or [`TranslationVariableLanguages`] feature specific to Machine Translation.
     """
diff --git a/src/datasets/features/nifti.py b/src/datasets/features/nifti.py
new file mode 100644
index 00000000000..bac91e2af4b
--- /dev/null
+++ b/src/datasets/features/nifti.py
@@ -0,0 +1,243 @@
import os
from dataclasses import dataclass, field
from io import BytesIO
from pathlib import Path
from typing import TYPE_CHECKING, Any, ClassVar, Dict, Optional, Union

import pyarrow as pa

from .. import config
from ..download.download_config import DownloadConfig
from ..table import array_cast
from ..utils.file_utils import is_local_path, xopen
from ..utils.py_utils import string_to_dict


if TYPE_CHECKING:
    import nibabel as nib

    from .features import FeatureType


@dataclass
class Nifti:
    """
    **Experimental.**
    Nifti [`Feature`] to read NIfTI neuroimaging files.

    Input: The Nifti feature accepts as input:
    - A `str`: Absolute path to the NIfTI file (i.e. random access is allowed).
    - A `pathlib.Path`: Path to the NIfTI file (i.e. random access is allowed).
    - A `dict` with the keys:
        - `path`: String with relative path of the NIfTI file in a dataset repository.
        - `bytes`: Bytes of the NIfTI file.
      This is useful for archived files with sequential access.
    - A `nibabel` image object (e.g. `nibabel.nifti1.Nifti1Image`).

    Args:
        decode (`bool`, defaults to `True`):
            Whether to decode the NIfTI data. If `False`, examples are returned as dictionaries with
            "path" and "bytes" fields instead of nibabel images, and `decode_example` raises an error.

    Examples:

    ```py
    >>> from datasets import Dataset, Nifti
    >>> ds = Dataset.from_dict({"nifti": ["path/to/file.nii.gz"]}).cast_column("nifti", Nifti())
    >>> ds.features["nifti"]
    Nifti(decode=True, id=None)
    >>> ds[0]["nifti"]
    <nibabel.nifti1.Nifti1Image object at 0x...>
    >>> ds = ds.cast_column("nifti", Nifti(decode=False))
    >>> ds[0]["nifti"]
    {'bytes': None,
     'path': 'path/to/file.nii.gz'}
    ```
    """

    decode: bool = True
    id: Optional[str] = field(default=None, repr=False)

    # Automatically constructed
    dtype: ClassVar[str] = "nibabel.nifti1.Nifti1Image"
    pa_type: ClassVar[Any] = pa.struct({"bytes": pa.binary(), "path": pa.string()})
    _type: str = field(default="Nifti", init=False, repr=False)

    def __call__(self):
        return self.pa_type

    def encode_example(self, value: Union[str, bytes, bytearray, dict, "nib.Nifti1Image"]) -> dict:
        """Encode example into a format for Arrow.

        Args:
            value (`str`, `bytes`, `nibabel.Nifti1Image` or `dict`):
                Data passed as input to Nifti feature.

        Returns:
            `dict` with "path" and "bytes" fields
        """
        if config.NIBABEL_AVAILABLE:
            import nibabel as nib
        else:
            nib = None

        if isinstance(value, str):
            return {"path": value, "bytes": None}
        elif isinstance(value, Path):
            return {"path": str(value.absolute()), "bytes": None}
        elif isinstance(value, (bytes, bytearray)):
            return {"path": None, "bytes": value}
        elif nib is not None and isinstance(value, nib.spatialimages.SpatialImage):
            # nibabel image object - try to get the file path or serialize to bytes
            return encode_nibabel_image(value)
        elif isinstance(value, dict):
            if value.get("path") is not None and os.path.isfile(value["path"]):
                # we set "bytes": None to not duplicate the data if they're already available locally
                return {"bytes": None, "path": value.get("path")}
            elif value.get("bytes") is not None or value.get("path") is not None:
                # store the NIfTI bytes, and path is used to infer the format using the file extension
                return {"bytes": value.get("bytes"), "path": value.get("path")}
            else:
                raise ValueError(
                    f"A nifti sample should have one of 'path' or 'bytes' but they are missing or None in {value}."
                )
        else:
            raise ValueError(
                f"A nifti sample should be a string, bytes, Path, nibabel image, or dict, but got {type(value)}."
            )
    def decode_example(self, value: dict, token_per_repo_id=None) -> "nib.nifti1.Nifti1Image":
        """Decode example NIfTI file into a nibabel image object.

        Args:
            value (`dict`):
                A dictionary with keys:

                - `path`: String with absolute or relative NIfTI file path.
                - `bytes`: The bytes of the NIfTI file.
            token_per_repo_id (`dict`, *optional*):
                To access and decode NIfTI files from private repositories on the Hub, you can pass
                a dictionary repo_id (`str`) -> token (`bool` or `str`).

        Returns:
            A `nibabel.nifti1.Nifti1Image` object.
        """
        if not self.decode:
            raise NotImplementedError("Decoding is disabled for this feature. Please use Nifti(decode=True) instead.")

        if config.NIBABEL_AVAILABLE:
            import nibabel as nib
        else:
            raise ImportError("To support decoding NIfTI files, please install 'nibabel'.")

        if token_per_repo_id is None:
            token_per_repo_id = {}

        path, bytes_ = value["path"], value["bytes"]
        if bytes_ is None:
            if path is None:
                raise ValueError(f"A nifti should have one of 'path' or 'bytes' but both are None in {value}.")
            if is_local_path(path):
                return nib.load(path)
            # remote file: resolve the token, read the raw bytes, and decode them below
            # (nib.load expects a filename, so we go through the bytes-decoding path instead)
            source_url = path.split("::")[-1]
            pattern = (
                config.HUB_DATASETS_URL
                if source_url.startswith(config.HF_ENDPOINT)
                else config.HUB_DATASETS_HFFS_URL
            )
            try:
                repo_id = string_to_dict(source_url, pattern)["repo_id"]
                token = token_per_repo_id.get(repo_id)
            except ValueError:
                token = None
            download_config = DownloadConfig(token=token)
            with xopen(path, "rb", download_config=download_config) as f:
                bytes_ = f.read()

        import gzip

        if bytes_[:2] == b"\x1f\x8b":
            # gzip magic number, see https://stackoverflow.com/a/76055284/9534390
            # or "Magic number" on https://en.wikipedia.org/wiki/Gzip
            bytes_ = gzip.decompress(bytes_)

        bio = BytesIO(bytes_)
        fh = nib.FileHolder(fileobj=bio)
        return nib.Nifti1Image.from_file_map({"header": fh, "image": fh})

    def flatten(self) -> Union["FeatureType", Dict[str, "FeatureType"]]:
        """If in the decodable state, return the feature itself, otherwise flatten the feature into a dictionary."""
        from .features import Value

        return (
            self
            if self.decode
            else {
                "bytes": Value("binary"),
                "path": Value("string"),
            }
        )

    def cast_storage(self, storage: Union[pa.StringArray, pa.StructArray, pa.BinaryArray]) -> pa.StructArray:
        """Cast an Arrow array to the Nifti arrow storage type.

        The Arrow types that can be converted to the Nifti pyarrow storage type are:

        - `pa.string()` - it must contain the "path" data
        - `pa.binary()` - it must contain the NIfTI bytes
        - `pa.struct({"bytes": pa.binary()})`
        - `pa.struct({"path": pa.string()})`
        - `pa.struct({"bytes": pa.binary(), "path": pa.string()})` - order doesn't matter

        Args:
            storage (`Union[pa.StringArray, pa.StructArray, pa.BinaryArray]`):
                PyArrow array to cast.

        Returns:
            `pa.StructArray`: Array in the Nifti arrow storage type, that is
            `pa.struct({"bytes": pa.binary(), "path": pa.string()})`.
        """
+ """ + if pa.types.is_string(storage.type): + bytes_array = pa.array([None] * len(storage), type=pa.binary()) + storage = pa.StructArray.from_arrays([bytes_array, storage], ["bytes", "path"], mask=storage.is_null()) + elif pa.types.is_binary(storage.type): + path_array = pa.array([None] * len(storage), type=pa.string()) + storage = pa.StructArray.from_arrays([storage, path_array], ["bytes", "path"], mask=storage.is_null()) + elif pa.types.is_struct(storage.type): + if storage.type.get_field_index("bytes") >= 0: + bytes_array = storage.field("bytes") + else: + bytes_array = pa.array([None] * len(storage), type=pa.binary()) + if storage.type.get_field_index("path") >= 0: + path_array = storage.field("path") + else: + path_array = pa.array([None] * len(storage), type=pa.string()) + storage = pa.StructArray.from_arrays([bytes_array, path_array], ["bytes", "path"], mask=storage.is_null()) + return array_cast(storage, self.pa_type) + + +def encode_nibabel_image(img: "nib.Nifti1Image") -> dict[str, Optional[Union[str, bytes]]]: + """ + Encode a nibabel image object into a dictionary. + + If the image has an associated file path, returns the path. Otherwise, serializes + the image content into bytes. + + Args: + img: A nibabel image object (e.g., Nifti1Image). + + Returns: + dict: A dictionary with "path" or "bytes" field. + """ + if hasattr(img, "file_map") and img.file_map is not None: + filename = img.file_map["image"].filename + return {"path": filename, "bytes": None} + + bytes_data = img.to_bytes() + return {"path": None, "bytes": bytes_data} diff --git a/src/datasets/packaged_modules/__init__.py b/src/datasets/packaged_modules/__init__.py index 515ff147b29..9d076df44b7 100644 --- a/src/datasets/packaged_modules/__init__.py +++ b/src/datasets/packaged_modules/__init__.py @@ -11,6 +11,7 @@ from .hdf5 import hdf5 from .imagefolder import imagefolder from .json import json +from .niftifolder import niftifolder from .pandas import pandas from .parquet import parquet from .pdffolder import pdffolder @@ -46,6 +47,7 @@ def _hash_python_lines(lines: list[str]) -> str: "audiofolder": (audiofolder.__name__, _hash_python_lines(inspect.getsource(audiofolder).splitlines())), "videofolder": (videofolder.__name__, _hash_python_lines(inspect.getsource(videofolder).splitlines())), "pdffolder": (pdffolder.__name__, _hash_python_lines(inspect.getsource(pdffolder).splitlines())), + "niftifolder": (niftifolder.__name__, _hash_python_lines(inspect.getsource(niftifolder).splitlines())), "webdataset": (webdataset.__name__, _hash_python_lines(inspect.getsource(webdataset).splitlines())), "xml": (xml.__name__, _hash_python_lines(inspect.getsource(xml).splitlines())), "hdf5": (hdf5.__name__, _hash_python_lines(inspect.getsource(hdf5).splitlines())), @@ -89,6 +91,8 @@ def _hash_python_lines(lines: list[str]) -> str: _EXTENSION_TO_MODULE.update({ext.upper(): ("videofolder", {}) for ext in videofolder.VideoFolder.EXTENSIONS}) _EXTENSION_TO_MODULE.update({ext: ("pdffolder", {}) for ext in pdffolder.PdfFolder.EXTENSIONS}) _EXTENSION_TO_MODULE.update({ext.upper(): ("pdffolder", {}) for ext in pdffolder.PdfFolder.EXTENSIONS}) +_EXTENSION_TO_MODULE.update({ext: ("niftifolder", {}) for ext in niftifolder.NiftiFolder.EXTENSIONS}) +_EXTENSION_TO_MODULE.update({ext.upper(): ("niftifolder", {}) for ext in niftifolder.NiftiFolder.EXTENSIONS}) # Used to filter data files based on extensions given a module name _MODULE_TO_EXTENSIONS: dict[str, list[str]] = {} @@ -106,3 +110,4 @@ def _hash_python_lines(lines: list[str]) -> str: 
 _MODULE_TO_METADATA_FILE_NAMES["audiofolder"] = imagefolder.ImageFolder.METADATA_FILENAMES
 _MODULE_TO_METADATA_FILE_NAMES["videofolder"] = imagefolder.ImageFolder.METADATA_FILENAMES
 _MODULE_TO_METADATA_FILE_NAMES["pdffolder"] = imagefolder.ImageFolder.METADATA_FILENAMES
+_MODULE_TO_METADATA_FILE_NAMES["niftifolder"] = imagefolder.ImageFolder.METADATA_FILENAMES
diff --git a/src/datasets/packaged_modules/niftifolder/__init__.py b/src/datasets/packaged_modules/niftifolder/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/src/datasets/packaged_modules/niftifolder/niftifolder.py b/src/datasets/packaged_modules/niftifolder/niftifolder.py
new file mode 100644
index 00000000000..c6d039419f1
--- /dev/null
+++ b/src/datasets/packaged_modules/niftifolder/niftifolder.py
@@ -0,0 +1,23 @@
import datasets

from ..folder_based_builder import folder_based_builder


logger = datasets.utils.logging.get_logger(__name__)


class NiftiFolderConfig(folder_based_builder.FolderBasedBuilderConfig):
    """BuilderConfig for NiftiFolder."""

    drop_labels: bool = None
    drop_metadata: bool = None

    def __post_init__(self):
        super().__post_init__()


class NiftiFolder(folder_based_builder.FolderBasedBuilder):
    BASE_FEATURE = datasets.Nifti
    BASE_COLUMN_NAME = "nifti"
    BUILDER_CONFIG_CLASS = NiftiFolderConfig
    EXTENSIONS: list[str] = [".nii", ".nii.gz"]
diff --git a/tests/features/data/test_nifti.nii b/tests/features/data/test_nifti.nii
new file mode 100644
index 00000000000..c1d560658c6
Binary files /dev/null and b/tests/features/data/test_nifti.nii differ
diff --git a/tests/features/data/test_nifti.nii.gz b/tests/features/data/test_nifti.nii.gz
new file mode 100644
index 00000000000..d5683901665
Binary files /dev/null and b/tests/features/data/test_nifti.nii.gz differ
diff --git a/tests/features/test_nifti.py b/tests/features/test_nifti.py
new file mode 100644
index 00000000000..077a7519431
--- /dev/null
+++ b/tests/features/test_nifti.py
@@ -0,0 +1,91 @@
# test NIfTI files taken from: https://github.com/yarikoptic/nitest-balls1/blob/2cd07d86e2cc2d3c612d5d4d659daccd7a58f126/NIFTI/T1.nii.gz

from pathlib import Path

import pytest

from datasets import Dataset, Features, Nifti
from datasets.features.nifti import encode_nibabel_image

from ..utils import require_nibabel


@require_nibabel
@pytest.mark.parametrize("nifti_file", ["test_nifti.nii", "test_nifti.nii.gz"])
@pytest.mark.parametrize(
    "build_example",
    [
        lambda nifti_path: nifti_path,
        lambda nifti_path: Path(nifti_path),
        lambda nifti_path: open(nifti_path, "rb").read(),
        lambda nifti_path: {"path": nifti_path},
        lambda nifti_path: {"path": nifti_path, "bytes": None},
        lambda nifti_path: {"path": nifti_path, "bytes": open(nifti_path, "rb").read()},
        lambda nifti_path: {"path": None, "bytes": open(nifti_path, "rb").read()},
        lambda nifti_path: {"bytes": open(nifti_path, "rb").read()},
    ],
)
def test_nifti_feature_encode_example(shared_datadir, nifti_file, build_example):
    import nibabel

    nifti_path = str(shared_datadir / nifti_file)
    nifti = Nifti()
    encoded_example = nifti.encode_example(build_example(nifti_path))
    assert isinstance(encoded_example, dict)
    assert encoded_example.keys() == {"bytes", "path"}
    assert encoded_example["bytes"] is not None or encoded_example["path"] is not None
    decoded_example = nifti.decode_example(encoded_example)
    assert isinstance(decoded_example, nibabel.nifti1.Nifti1Image)


@require_nibabel
@pytest.mark.parametrize("nifti_file", ["test_nifti.nii",
"test_nifti.nii.gz"]) +def test_dataset_with_nifti_feature(shared_datadir, nifti_file): + import nibabel + + nifti_path = str(shared_datadir / nifti_file) + data = {"nifti": [nifti_path]} + features = Features({"nifti": Nifti()}) + dset = Dataset.from_dict(data, features=features) + item = dset[0] + assert item.keys() == {"nifti"} + assert isinstance(item["nifti"], nibabel.nifti1.Nifti1Image) + batch = dset[:1] + assert len(batch) == 1 + assert batch.keys() == {"nifti"} + assert isinstance(batch["nifti"], list) and all( + isinstance(item, nibabel.nifti1.Nifti1Image) for item in batch["nifti"] + ) + column = dset["nifti"] + assert len(column) == 1 + assert all(isinstance(item, nibabel.nifti1.Nifti1Image) for item in column) + + # from bytes + with open(nifti_path, "rb") as f: + data = {"nifti": [f.read()]} + dset = Dataset.from_dict(data, features=features) + item = dset[0] + assert item.keys() == {"nifti"} + assert isinstance(item["nifti"], nibabel.nifti1.Nifti1Image) + + +@require_nibabel +def test_encode_nibabel_image(shared_datadir): + import nibabel + + nifti_path = str(shared_datadir / "test_nifti.nii") + img = nibabel.load(nifti_path) + encoded_example = encode_nibabel_image(img) + nifti = Nifti() + assert isinstance(encoded_example, dict) + assert encoded_example.keys() == {"bytes", "path"} + assert encoded_example["path"] is not None and encoded_example["bytes"] is None + decoded_example = nifti.decode_example(encoded_example) + assert isinstance(decoded_example, nibabel.nifti1.Nifti1Image) + + # test bytes only + img.file_map = None + encoded_example_bytes = encode_nibabel_image(img) + assert isinstance(encoded_example_bytes, dict) + assert encoded_example_bytes["bytes"] is not None and encoded_example_bytes["path"] is None + # this cannot be converted back from bytes (yet) diff --git a/tests/utils.py b/tests/utils.py index 166bd4789c2..b796641a290 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -209,6 +209,18 @@ def require_pdfplumber(test_case): return test_case +def require_nibabel(test_case): + """ + Decorator marking a test that requires nibabel. + + These tests are skipped when nibabel isn't installed. + + """ + if not config.NIBABEL_AVAILABLE: + test_case = unittest.skip("test requires nibabel")(test_case) + return test_case + + def require_transformers(test_case): """ Decorator marking a test that requires transformers.