diff --git a/nibabel/filebasedimages.py b/nibabel/filebasedimages.py index 21fe754edf..b08f5e74d4 100644 --- a/nibabel/filebasedimages.py +++ b/nibabel/filebasedimages.py @@ -10,6 +10,7 @@ import io from copy import deepcopy +from urllib import request from .fileholders import FileHolder from .filename_parser import (types_filenames, TypesFilenamesError, splitext_addext) @@ -488,7 +489,7 @@ def path_maybe_image(klass, filename, sniff=None, sniff_max=1024): class SerializableImage(FileBasedImage): """ - Abstract image class for (de)serializing images to/from byte strings. + Abstract image class for (de)serializing images to/from byte streams/strings. The class doesn't define any image properties. @@ -501,6 +502,7 @@ class SerializableImage(FileBasedImage): classmethods: * from_bytes(bytestring) - make instance by deserializing a byte string + * from_url(url) - make instance by fetching and deserializing a URL Loading from byte strings should provide round-trip equivalence: @@ -538,7 +540,43 @@ class SerializableImage(FileBasedImage): """ @classmethod - def from_bytes(klass, bytestring): + def _filemap_from_iobase(klass, io_obj: io.IOBase): + """For single-file image types, make a file map with the correct key""" + if len(klass.files_types) > 1: + raise NotImplementedError( + "(de)serialization is undefined for multi-file images" + ) + return klass.make_file_map({klass.files_types[0][0]: io_obj}) + + @classmethod + def from_stream(klass, io_obj: io.IOBase): + """Load image from readable IO stream + + Convert to BytesIO to enable seeking, if input stream is not seekable + + Parameters + ---------- + io_obj : IOBase object + Readable stream + """ + if not io_obj.seekable(): + io_obj = io.BytesIO(io_obj.read()) + return klass.from_file_map(klass._filemap_from_iobase(io_obj)) + + def to_stream(self, io_obj: io.IOBase, **kwargs): + """Save image to writable IO stream + + Parameters + ---------- + io_obj : IOBase object + Writable stream + \*\*kwargs : keyword arguments + Keyword arguments that may be passed to ``img.to_file_map()`` + """ + self.to_file_map(self._filemap_from_iobase(io_obj), **kwargs) + + @classmethod + def from_bytes(klass, bytestring: bytes): """ Construct image from a byte string Class method @@ -548,13 +586,9 @@ def from_bytes(klass, bytestring): bstring : bytes Byte string containing the on-disk representation of an image """ - if len(klass.files_types) > 1: - raise NotImplementedError("from_bytes is undefined for multi-file images") - bio = io.BytesIO(bytestring) - file_map = klass.make_file_map({'image': bio, 'header': bio}) - return klass.from_file_map(file_map) + return klass.from_stream(io.BytesIO(bytestring)) - def to_bytes(self, **kwargs): + def to_bytes(self, **kwargs) -> bytes: r""" Return a ``bytes`` object with the contents of the file that would be written if the image were saved. @@ -568,9 +602,22 @@ def to_bytes(self, **kwargs): bytes Serialized image """ - if len(self.__class__.files_types) > 1: - raise NotImplementedError("to_bytes() is undefined for multi-file images") bio = io.BytesIO() - file_map = self.make_file_map({'image': bio, 'header': bio}) - self.to_file_map(file_map, **kwargs) + self.to_stream(bio, **kwargs) return bio.getvalue() + + @classmethod + def from_url(klass, url, timeout=5): + """Retrieve and load an image from a URL + + Class method + + Parameters + ---------- + url : str or urllib.request.Request object + URL of file to retrieve + timeout : float, optional + Time (in seconds) to wait for a response + """ + response = request.urlopen(url, timeout=timeout) + return klass.from_stream(response) diff --git a/nibabel/tests/test_filebasedimages.py b/nibabel/tests/test_filebasedimages.py index efac76a65a..d01440eb65 100644 --- a/nibabel/tests/test_filebasedimages.py +++ b/nibabel/tests/test_filebasedimages.py @@ -5,6 +5,7 @@ import warnings import numpy as np +import pytest from ..filebasedimages import FileBasedHeader, FileBasedImage, SerializableImage @@ -127,3 +128,24 @@ def __init__(self, seq=None): hdr4 = H.from_header(None) assert isinstance(hdr4, H) assert hdr4.a_list == [] + + +class MultipartNumpyImage(FBNumpyImage): + # We won't actually try to write these out, just need to test an edge case + files_types = (('header', '.hdr'), ('image', '.npy')) + + +class SerializableMPNumpyImage(MultipartNumpyImage, SerializableImage): + pass + + +def test_multifile_stream_failure(): + shape = (2, 3, 4) + arr = np.arange(np.prod(shape), dtype=np.float32).reshape(shape) + img = SerializableMPNumpyImage(arr) + with pytest.raises(NotImplementedError): + img.to_bytes() + img = SerializableNumpyImage(arr) + bstr = img.to_bytes() + with pytest.raises(NotImplementedError): + SerializableMPNumpyImage.from_bytes(bstr) diff --git a/nibabel/tests/test_image_api.py b/nibabel/tests/test_image_api.py index afc04d709a..e4287988d7 100644 --- a/nibabel/tests/test_image_api.py +++ b/nibabel/tests/test_image_api.py @@ -26,6 +26,7 @@ import warnings from functools import partial from itertools import product +import io import pathlib import numpy as np @@ -523,34 +524,41 @@ def validate_affine_deprecated(self, imaker, params): img.get_affine() -class SerializeMixin(object): - def validate_to_bytes(self, imaker, params): +class SerializeMixin: + def validate_to_from_stream(self, imaker, params): img = imaker() - serialized = img.to_bytes() - with InTemporaryDirectory(): - fname = 'img' + self.standard_extension - img.to_filename(fname) - with open(fname, 'rb') as fobj: - file_contents = fobj.read() - assert serialized == file_contents + klass = getattr(self, 'klass', img.__class__) + stream = io.BytesIO() + img.to_stream(stream) - def validate_from_bytes(self, imaker, params): + rt_img = klass.from_stream(stream) + assert self._header_eq(img.header, rt_img.header) + assert np.array_equal(img.get_fdata(), rt_img.get_fdata()) + + def validate_file_stream_equivalence(self, imaker, params): img = imaker() klass = getattr(self, 'klass', img.__class__) with InTemporaryDirectory(): fname = 'img' + self.standard_extension img.to_filename(fname) - all_images = list(getattr(self, 'example_images', [])) + [{'fname': fname}] - for img_params in all_images: - img_a = klass.from_filename(img_params['fname']) - with open(img_params['fname'], 'rb') as fobj: - img_b = klass.from_bytes(fobj.read()) + with open("stream", "wb") as fobj: + img.to_stream(fobj) - assert self._header_eq(img_a.header, img_b.header) + # Check that writing gets us the same thing + contents1 = pathlib.Path(fname).read_bytes() + contents2 = pathlib.Path("stream").read_bytes() + assert contents1 == contents2 + + # Check that reading gets us the same thing + img_a = klass.from_filename(fname) + with open(fname, "rb") as fobj: + img_b = klass.from_stream(fobj) + # This needs to happen while the filehandle is open assert np.array_equal(img_a.get_fdata(), img_b.get_fdata()) - del img_a - del img_b + assert self._header_eq(img_a.header, img_b.header) + del img_a + del img_b def validate_to_from_bytes(self, imaker, params): img = imaker() @@ -572,6 +580,45 @@ def validate_to_from_bytes(self, imaker, params): del img_a del img_b + @pytest.fixture(autouse=True) + def setup(self, httpserver, tmp_path): + """Make pytest fixtures available to validate functions""" + self.httpserver = httpserver + self.tmp_path = tmp_path + + def validate_from_url(self, imaker, params): + server = self.httpserver + + img = imaker() + img_bytes = img.to_bytes() + + server.expect_oneshot_request("/img").respond_with_data(img_bytes) + url = server.url_for("/img") + assert url.startswith("http://") # Check we'll trigger an HTTP handler + rt_img = img.__class__.from_url(url) + + assert rt_img.to_bytes() == img_bytes + assert self._header_eq(img.header, rt_img.header) + assert np.array_equal(img.get_fdata(), rt_img.get_fdata()) + del img + del rt_img + + def validate_from_file_url(self, imaker, params): + tmp_path = self.tmp_path + + img = imaker() + import uuid + fname = tmp_path / f'img-{uuid.uuid4()}{self.standard_extension}' + img.to_filename(fname) + + rt_img = img.__class__.from_url(f"file:///{fname}") + + assert self._header_eq(img.header, rt_img.header) + assert np.array_equal(img.get_fdata(), rt_img.get_fdata()) + del img + del rt_img + + @staticmethod def _header_eq(header_a, header_b): """ Header equality check that can be overridden by a subclass of this test @@ -583,7 +630,6 @@ def _header_eq(header_a, header_b): return header_a == header_b - class LoadImageAPI(GenericImageAPI, DataInterfaceMixin, AffineMixin, diff --git a/setup.cfg b/setup.cfg index 4defb7eb14..47a7317088 100644 --- a/setup.cfg +++ b/setup.cfg @@ -61,6 +61,7 @@ test = pytest !=5.3.4 pytest-cov pytest-doctestplus + pytest-httpserver zstd = pyzstd >= 0.14.3 all =