diff --git a/deeptrack/features.py b/deeptrack/features.py index a4bf2c70..6bc3fcfc 100644 --- a/deeptrack/features.py +++ b/deeptrack/features.py @@ -218,7 +218,7 @@ def propagate_data_to_dependencies( "OneOfDict", "LoadImage", # TODO ***MG*** "SampleToMasks", # TODO ***MG*** - "AsType", # TODO ***MG*** + "AsType", "ChannelFirst2d", "Upscale", # TODO ***AL*** "NonOverlapping", # TODO ***AL*** @@ -7751,9 +7751,9 @@ def _process_and_get( class AsType(Feature): """Convert the data type of images. - This feature changes the data type (`dtype`) of input images to a specified - type. The accepted types are the same as those used by NumPy arrays, such - as `float64`, `int32`, `uint16`, `int16`, `uint8`, and `int8`. + This feature changes the data type (`dtype`) of input images to a specified + type. The accepted types are standard NumPy or PyTorch data types (e.g., + `"float64"`, `"int32"`, `"uint8"`, `"int8"`, and `"torch.float32"`). Parameters ---------- @@ -7776,7 +7776,7 @@ class AsType(Feature): >>> >>> input_image = np.array([1.5, 2.5, 3.5]) - Apply an AsType feature to convert to `int32`: + Apply an AsType feature to convert to "`int32"`: >>> astype_feature = dt.AsType(dtype="int32") >>> output_image = astype_feature.get(input_image, dtype="int32") >>> output_image @@ -7833,7 +7833,39 @@ def get( """ - return image.astype(dtype) + if apc.is_torch_array(image): + # Mapping from string to torch dtype + torch_dtypes = { + "float64": torch.float64, + "double": torch.float64, + "float32": torch.float32, + "float": torch.float32, + "float16": torch.float16, + "half": torch.float16, + "int64": torch.int64, + "int32": torch.int32, + "int16": torch.int16, + "int8": torch.int8, + "uint8": torch.uint8, + "bool": torch.bool, + "complex64": torch.complex64, + "complex128": torch.complex128, + } + + # Ensure `"torch.float32"` and `"float32"` are treated the same by + # removing the `torch.` prefix if present + dtype_str = str(dtype).replace("torch.", "") + torch_dtype = torch_dtypes.get(dtype_str) + + if torch_dtype is None: + raise ValueError( + f"Unsupported dtype for torch.Tensor: {dtype}" + ) + + return image.to(dtype=torch_dtype) + + else: + return image.astype(dtype) class ChannelFirst2d(Feature): # DEPRECATED diff --git a/deeptrack/tests/test_features.py b/deeptrack/tests/test_features.py index 591547a6..d46aa101 100644 --- a/deeptrack/tests/test_features.py +++ b/deeptrack/tests/test_features.py @@ -1949,11 +1949,48 @@ def test_AsType(self): np.all(output_image == np.array([1, 2, 3], dtype=dtype)) ) - # Test for Image. - #TODO + ### Test with PyTorch tensor (if available) + if TORCH_AVAILABLE: + input_image_torch = torch.tensor([1.5, 2.5, 3.5]) + + data_types_torch = [ + "float64", + "int32", + "int16", + "uint8", + "int8", + "torch.float64", + "torch.int32", + ] - # Test for PyTorch tensors. - #TODO + torch_dtypes_map = { + "float64": torch.float64, + "int32": torch.int32, + "int16": torch.int16, + "uint8": torch.uint8, + "int8": torch.int8, + "torch.float64": torch.float64, + "torch.int32": torch.int32, + } + + for dtype in data_types_torch: + astype_feature = features.AsType(dtype=dtype) + output_image = astype_feature.get( + input_image_torch, dtype=dtype + ) + expected_dtype = torch_dtypes_map[dtype] + self.assertEqual(output_image.dtype, expected_dtype) + + # Additional check for specific behavior of integers. + if expected_dtype in [ + torch.int8, + torch.int16, + torch.int32, + torch.uint8, + ]: + # Verify that fractional parts are truncated + expected = torch.tensor([1, 2, 3], dtype=expected_dtype) + self.assertTrue(torch.equal(output_image, expected)) def test_ChannelFirst2d(self):