diff --git a/docs/source/en/model_doc/fuyu.md b/docs/source/en/model_doc/fuyu.md
index 34202b022f7e..57f0de1eb244 100644
--- a/docs/source/en/model_doc/fuyu.md
+++ b/docs/source/en/model_doc/fuyu.md
@@ -75,11 +75,11 @@ A processor requires an image_processor and a tokenizer. Hence, inputs can be lo
 from PIL import Image
 from transformers import AutoTokenizer
 from transformers.models.fuyu.processing_fuyu import FuyuProcessor
-from transformers.models.fuyu.image_processing_fuyu import FuyuImageProcessor
+from transformers.models.fuyu.image_processing_fuyu_fast import FuyuImageProcessorFast
 
 
 tokenizer = AutoTokenizer.from_pretrained('adept-hf-collab/fuyu-8b')
-image_processor = FuyuImageProcessor()
+image_processor = FuyuImageProcessorFast()
 
 
 processor = FuyuProcessor(image_processor=image_processor, tokenizer=tokenizer)
@@ -118,6 +118,11 @@ The `LlamaTokenizer` is used as it is a standard wrapper around sentencepiece.
 [[autodoc]] FuyuImageProcessor
     - __call__
 
+## FuyuImageProcessorFast
+
+[[autodoc]] FuyuImageProcessorFast
+    - __call__
+
 ## FuyuProcessor
 
 [[autodoc]] FuyuProcessor
diff --git a/src/transformers/image_processing_utils_fast.py b/src/transformers/image_processing_utils_fast.py
index a145754d3209..b203d65ad7b4 100644
--- a/src/transformers/image_processing_utils_fast.py
+++ b/src/transformers/image_processing_utils_fast.py
@@ -227,6 +227,7 @@ def pad(
         padding_mode: Optional[str] = "constant",
         return_mask: bool = False,
         disable_grouping: Optional[bool] = False,
+        is_nested: Optional[bool] = False,
         **kwargs,
     ) -> Union[tuple["torch.Tensor", "torch.Tensor"], "torch.Tensor"]:
         """
@@ -257,7 +258,9 @@
         else:
             pad_size = get_max_height_width(images)
 
-        grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
+        grouped_images, grouped_images_index = group_images_by_shape(
+            images, disable_grouping=disable_grouping, is_nested=is_nested
+        )
         processed_images_grouped = {}
         processed_masks_grouped = {}
         for shape, stacked_images in grouped_images.items():
@@ -280,9 +283,9 @@
                 stacked_masks[..., : image_size[0], : image_size[1]] = 1
                 processed_masks_grouped[shape] = stacked_masks
 
-        processed_images = reorder_images(processed_images_grouped, grouped_images_index)
+        processed_images = reorder_images(processed_images_grouped, grouped_images_index, is_nested=is_nested)
         if return_mask:
-            processed_masks = reorder_images(processed_masks_grouped, grouped_images_index)
+            processed_masks = reorder_images(processed_masks_grouped, grouped_images_index, is_nested=is_nested)
             return processed_images, processed_masks
 
         return processed_images
diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py
index 60af0f869bad..a262d69c438d 100644
--- a/src/transformers/models/auto/image_processing_auto.py
+++ b/src/transformers/models/auto/image_processing_auto.py
@@ -98,7 +98,7 @@
         ("eomt", ("EomtImageProcessor", "EomtImageProcessorFast")),
         ("flava", ("FlavaImageProcessor", "FlavaImageProcessorFast")),
         ("focalnet", ("BitImageProcessor", "BitImageProcessorFast")),
-        ("fuyu", ("FuyuImageProcessor", None)),
+        ("fuyu", ("FuyuImageProcessor", "FuyuImageProcessorFast")),
         ("gemma3", ("Gemma3ImageProcessor", "Gemma3ImageProcessorFast")),
         ("gemma3n", ("SiglipImageProcessor", "SiglipImageProcessorFast")),
         ("git", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
diff --git a/src/transformers/models/fuyu/__init__.py b/src/transformers/models/fuyu/__init__.py
index c2a7d252010e..eca3cf7c411b 100644
---
a/src/transformers/models/fuyu/__init__.py +++ b/src/transformers/models/fuyu/__init__.py @@ -20,6 +20,7 @@ if TYPE_CHECKING: from .configuration_fuyu import * from .image_processing_fuyu import * + from .image_processing_fuyu_fast import * from .modeling_fuyu import * from .processing_fuyu import * else: diff --git a/src/transformers/models/fuyu/image_processing_fuyu.py b/src/transformers/models/fuyu/image_processing_fuyu.py index 76c3b6130653..e86352af1bf5 100644 --- a/src/transformers/models/fuyu/image_processing_fuyu.py +++ b/src/transformers/models/fuyu/image_processing_fuyu.py @@ -29,6 +29,7 @@ ChannelDimension, ImageInput, PILImageResampling, + SizeDict, get_image_size, infer_channel_dimension_format, is_scaled_image, @@ -37,6 +38,7 @@ to_numpy_array, validate_preprocess_arguments, ) +from ...processing_utils import ImagesKwargs from ...utils import ( TensorType, filter_out_non_signature_kwargs, @@ -70,6 +72,21 @@ def make_list_of_list_of_images( raise ValueError("images must be a list of list of images or a list of images or an image.") +class FuyuImagesKwargs(ImagesKwargs, total=False): + r""" + patch_size (`dict[str, int]`, *optional*, defaults to `{"height": 30, "width": 30}`): + Dictionary in the format `{"height": int, "width": int}` specifying the size of the patches. + padding_value (`float`, *optional*, defaults to 1.0): + The value to pad the image with. + padding_mode (`str`, *optional*, defaults to "constant"): + The padding mode to use when padding the image. + """ + + patch_size: Optional[SizeDict] + padding_value: float + padding_mode: str + + class FuyuBatchFeature(BatchFeature): """ BatchFeature class for Fuyu image processor and processor. @@ -232,6 +249,7 @@ class FuyuImageProcessor(BaseImageProcessor): "image_patch_indices_per_batch", "image_patch_indices_per_subsequence", ] + valid_kwargs = FuyuImagesKwargs def __init__( self, diff --git a/src/transformers/models/fuyu/image_processing_fuyu_fast.py b/src/transformers/models/fuyu/image_processing_fuyu_fast.py new file mode 100644 index 000000000000..4c9c2802e8df --- /dev/null +++ b/src/transformers/models/fuyu/image_processing_fuyu_fast.py @@ -0,0 +1,382 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Fast Image processor class for Fuyu.""" + +import math +from typing import Optional, Union + +import torch + +from ...image_processing_utils import get_size_dict +from ...image_processing_utils_fast import ( + BaseImageProcessorFast, + group_images_by_shape, + reorder_images, +) +from ...image_utils import ( + ImageInput, + PILImageResampling, + SizeDict, +) +from ...utils import ( + TensorType, + auto_docstring, + is_torchvision_available, + logging, + requires_backends, +) +from .image_processing_fuyu import FuyuBatchFeature, FuyuImagesKwargs, make_list_of_list_of_images + + +if is_torchvision_available(): + from torchvision.transforms.v2 import functional as F + + +logger = logging.get_logger(__name__) + + +@auto_docstring +class FuyuImageProcessorFast(BaseImageProcessorFast): + do_resize = True + size = {"height": 1080, "width": 1920} + resample = PILImageResampling.BILINEAR + do_pad = True + padding_value = 1.0 + padding_mode = "constant" + do_normalize = True + image_mean = 0.5 + image_std = 0.5 + do_rescale = True + rescale_factor = 1 / 255 + model_input_names = [ + "images", + "image_input_ids", + "image_patches", + "image_patch_indices_per_batch", + "image_patch_indices_per_subsequence", + ] + valid_kwargs = FuyuImagesKwargs + + def _prepare_images_structure( + self, + images: ImageInput, + expected_ndims: int = 3, + ) -> ImageInput: + images = self.fetch_images(images) + return make_list_of_list_of_images(images) + + def resize( + self, + image: torch.Tensor, + size: SizeDict, + interpolation: Optional["F.InterpolationMode"] = None, + antialias: bool = True, + **kwargs, + ) -> torch.Tensor: + """ + Resize an image to fit within `(size["height"], size["width"])` while maintaining aspect ratio. + Only resizes if the image is larger than the target size. + Args: + image (`torch.Tensor`): + Image to resize. + size (`SizeDict`): + Dictionary in the format `{"height": int, "width": int}` specifying the max size of the output image. + interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`): + `InterpolationMode` filter to use when resizing the image e.g. `InterpolationMode.BILINEAR`. + antialias (`bool`, *optional*, defaults to `True`): + Whether to apply antialiasing when resizing. 
+ """ + interpolation = interpolation if interpolation is not None else F.InterpolationMode.BILINEAR + image_height, image_width = image.shape[-2:] + target_height, target_width = size.height, size.width + # Only resize if image is larger than target + if image_width <= target_width and image_height <= target_height: + return image + # Calculate optimal scale factor to fit within target size + height_scale_factor = target_height / image_height + width_scale_factor = target_width / image_width + optimal_scale_factor = min(height_scale_factor, width_scale_factor) + + new_height = int(image_height * optimal_scale_factor) + new_width = int(image_width * optimal_scale_factor) + + return super().resize( + image, SizeDict(height=new_height, width=new_width), interpolation=interpolation, antialias=antialias + ) + + def _preprocess( + self, + images: list["torch.Tensor"], + do_resize: bool, + size: SizeDict, + interpolation: Optional["F.InterpolationMode"], + do_rescale: bool, + rescale_factor: float, + do_normalize: bool, + image_mean: Optional[Union[float, list[float]]], + image_std: Optional[Union[float, list[float]]], + do_pad: Optional[bool], + padding_value: Optional[float], + padding_mode: Optional[str], + disable_grouping: Optional[bool], + return_tensors: Optional[Union[str, TensorType]], + **kwargs, + ) -> FuyuBatchFeature: + # Group images by size for batched resizing + original_image_sizes = [batch_image[0].shape[-2:] for batch_image in images if batch_image] + grouped_images, grouped_images_index = group_images_by_shape( + images, disable_grouping=disable_grouping, is_nested=True + ) + resized_images_grouped = {} + for shape, stacked_images in grouped_images.items(): + if do_resize: + stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation) + resized_images_grouped[shape] = stacked_images + resized_images = reorder_images(resized_images_grouped, grouped_images_index, is_nested=True) + + image_sizes = [batch_image[0].shape[-2:] for batch_image in resized_images if batch_image] + image_unpadded_heights = [[image_size[0]] for image_size in image_sizes] + image_unpadded_widths = [[image_size[1]] for image_size in image_sizes] + image_scale_factors = [ + [resized_size[0] / original_size[0]] + for original_size, resized_size in zip(original_image_sizes, image_sizes) + ] + if do_pad: + resized_images = self.pad( + resized_images, + pad_size=size, + fill_value=padding_value, + padding_mode=padding_mode, + disable_grouping=disable_grouping, + is_nested=True, + ) + # Group images by size for further processing + # Needed in case do_resize is False, or resize returns images with different sizes + grouped_images, grouped_images_index = group_images_by_shape( + resized_images, disable_grouping=disable_grouping, is_nested=True + ) + processed_images_grouped = {} + for shape, stacked_images in grouped_images.items(): + # Fused rescale and normalize + stacked_images = self.rescale_and_normalize( + stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std + ) + processed_images_grouped[shape] = stacked_images + processed_images = reorder_images(processed_images_grouped, grouped_images_index, is_nested=True) + + return FuyuBatchFeature( + data={ + "images": processed_images, + "image_unpadded_heights": image_unpadded_heights, + "image_unpadded_widths": image_unpadded_widths, + "image_scale_factors": image_scale_factors, + }, + tensor_type=return_tensors, + ) + + def get_num_patches(self, image_height: int, image_width: int, patch_size: 
Optional[SizeDict] = None) -> int: + """ + Calculate number of patches required to encode an image. + Args: + image_height (`int`): + Height of the image. + image_width (`int`): + Width of the image. + patch_size (`SizeDict`, *optional*): + Dictionary in the format `{"height": int, "width": int}` specifying the size of the patches. + """ + if patch_size is None: + patch_size = SizeDict(**self.patch_size) + patch_height, patch_width = patch_size.height, patch_size.width + if image_height % patch_height != 0: + raise ValueError(f"{image_height=} must be divisible by {patch_height}") + if image_width % patch_width != 0: + raise ValueError(f"{image_width=} must be divisible by {patch_width}") + num_patches_per_dim_h = image_height // patch_height + num_patches_per_dim_w = image_width // patch_width + num_patches = num_patches_per_dim_h * num_patches_per_dim_w + return num_patches + + def patchify_image(self, image: torch.Tensor, patch_size: Optional[SizeDict] = None) -> torch.Tensor: + """ + Convert an image into a tensor of patches using PyTorch's unfold operation. + Args: + image (`torch.Tensor`): + Image to convert. Shape: [batch, channels, height, width] + patch_size (`SizeDict`, *optional*): + Dictionary in the format `{"height": int, "width": int}` specifying the size of the patches. + """ + requires_backends(self, ["torch"]) + if patch_size is None: + patch_size = SizeDict(**self.patch_size) + patch_height, patch_width = patch_size.height, patch_size.width + batch_size, channels, _, _ = image.shape + # Use unfold to extract patches + unfolded_along_height = image.unfold(2, patch_height, patch_height) + patches = unfolded_along_height.unfold(3, patch_width, patch_width) + patches = patches.contiguous() + # Reshape to [batch, num_patches, channels * patch_h * patch_w] + patches = patches.view(batch_size, channels, -1, patch_height, patch_width) + patches = patches.permute(0, 2, 3, 4, 1) + patches = patches.reshape(batch_size, -1, channels * patch_height * patch_width) + return patches + + def preprocess_with_tokenizer_info( + self, + image_input: torch.Tensor, + image_present: torch.Tensor, + image_unpadded_h: torch.Tensor, + image_unpadded_w: torch.Tensor, + image_placeholder_id: int, + image_newline_id: int, + variable_sized: bool, + patch_size: Optional[dict[str, int]] = None, + ) -> FuyuBatchFeature: + """ + Process images for model input. In particular, variable-sized images are handled here. + + Args: + image_input (`torch.Tensor` of shape [batch_size, subsequence_size, num_channels, height, width]): + Tensor of images padded to model input size. + image_present (`torch.Tensor` of shape [batch_size, subsequence_size, num_images]): + Tensor of 1s and 0s indicating whether an image is present. + image_unpadded_h (`torch.Tensor` of shape [batch_size, subsequence_size]): + Tensor of unpadded image heights. + image_unpadded_w (`torch.Tensor` of shape [batch_size, subsequence_size]): + Tensor of unpadded image widths. + image_placeholder_id (int): + The id of the image placeholder token. Comes from an associated tokenizer. + image_newline_id (int): + The id of the image newline token. Comes from an associated tokenizer. + variable_sized (bool): + Whether to process images as variable-sized. + patch_size (`dict[str, int]`, *optional*): + Size of the patches. 
+ """ + requires_backends(self, ["torch"]) + + if patch_size is None: + patch_size = SizeDict(**self.patch_size) + else: + patch_size = SizeDict(**patch_size) + patch_height, patch_width = patch_size.height, patch_size.width + # Only images that are present + images: list[list[torch.Tensor]] = [] + batch_image_patches: list[list[torch.Tensor]] = [] + # Image input ids for every subsequence, including ones with no image present + batch_image_input_ids: list[list[torch.Tensor]] = [] + for batch_index in range(image_input.shape[0]): + image_input_ids = [] + image_patches = [] + for subseq_index in range(image_input.shape[1]): + if image_present[batch_index, subseq_index]: + image = image_input[batch_index, subseq_index] + image_height, image_width = image.shape[1], image.shape[2] + if variable_sized: + # Calculate new dimensions based on unpadded size + # The min() is required here due to floating point issues + new_h = min( + image_height, + math.ceil(image_unpadded_h[batch_index, subseq_index] / patch_height) * patch_height, + ) + new_w = min( + image_width, + math.ceil(image_unpadded_w[batch_index, subseq_index] / patch_width) * patch_width, + ) + image = image[:, :new_h, :new_w] + image_height, image_width = new_h, new_w + num_patches = self.get_num_patches( + image_height=image_height, image_width=image_width, patch_size=patch_size + ) + # Create tensor of placeholder IDs + tensor_of_image_ids = torch.full( + [num_patches], image_placeholder_id, dtype=torch.int32, device=image_input.device + ) + # Patchify the image + patches = self.patchify_image(image=image.unsqueeze(0), patch_size=patch_size).squeeze(0) + assert num_patches == patches.shape[0] + if variable_sized: + # Terminate each line with newline ID + tensor_of_image_ids = tensor_of_image_ids.reshape(-1, image_width // patch_width) + newline_ids = torch.full( + [tensor_of_image_ids.shape[0], 1], + image_newline_id, + dtype=torch.int32, + device=image_input.device, + ) + tensor_of_image_ids = torch.cat([tensor_of_image_ids, newline_ids], dim=1) + tensor_of_image_ids = tensor_of_image_ids.reshape(-1) + images.append([image]) + image_input_ids.append(tensor_of_image_ids) + image_patches.append(patches) + else: + image_input_ids.append(torch.tensor([], dtype=torch.int32, device=image_input.device)) + batch_image_input_ids.append(image_input_ids) + batch_image_patches.append(image_patches) + # Create image patch indices + image_patch_indices_per_batch: list[list[torch.Tensor]] = [] + image_patch_indices_per_subsequence: list[list[torch.Tensor]] = [] + + for sample_image_input_ids in batch_image_input_ids: + index_offset = 0 + per_batch_indices = [] + per_subsequence_indices = [] + for subseq_image_input_ids in sample_image_input_ids: + # Indices of image patches + patches_mask = subseq_image_input_ids == image_placeholder_id + num_patches = torch.count_nonzero(patches_mask) + indices = torch.arange(num_patches, dtype=torch.int64, device=subseq_image_input_ids.device).type_as( + subseq_image_input_ids + ) + # Place those indices in the image input ids token stream, with -1 representing non-index tokens + indices_in_stream_per_batch = torch.full_like(subseq_image_input_ids, -1) + indices_in_stream_per_subsequence = torch.full_like(subseq_image_input_ids, -1) + patches_inds = torch.nonzero(patches_mask, as_tuple=True)[0] + + indices_in_stream_per_batch[patches_inds] = indices + index_offset + indices_in_stream_per_subsequence[patches_inds] = indices + + per_batch_indices.append(indices_in_stream_per_batch) + 
per_subsequence_indices.append(indices_in_stream_per_subsequence) + index_offset += num_patches + + image_patch_indices_per_batch.append(per_batch_indices) + image_patch_indices_per_subsequence.append(per_subsequence_indices) + return FuyuBatchFeature( + data={ + "images": images, + "image_input_ids": batch_image_input_ids, + "image_patches": batch_image_patches, + "image_patch_indices_per_batch": image_patch_indices_per_batch, + "image_patch_indices_per_subsequence": image_patch_indices_per_subsequence, + } + ) + + def _further_process_kwargs( + self, + patch_size: Optional[dict[str, int]] = None, + **kwargs, + ) -> dict: + """ + Process Fuyu-specific kwargs before validation. + """ + kwargs = super()._further_process_kwargs(**kwargs) + if patch_size is not None: + patch_size = SizeDict(**get_size_dict(patch_size, param_name="patch_size")) + kwargs["patch_size"] = patch_size + return kwargs + + +__all__ = ["FuyuImageProcessorFast"] diff --git a/tests/models/fuyu/test_image_processing_fuyu.py b/tests/models/fuyu/test_image_processing_fuyu.py index fd9fea1f741a..24b19b01a029 100644 --- a/tests/models/fuyu/test_image_processing_fuyu.py +++ b/tests/models/fuyu/test_image_processing_fuyu.py @@ -1,63 +1,466 @@ +import io import unittest +import httpx import numpy as np +import pytest +from packaging import version -from transformers import is_torch_available, is_vision_available +from transformers.image_utils import SizeDict from transformers.testing_utils import ( require_torch, + require_torch_accelerator, require_torchvision, require_vision, + slow, + torch_device, ) +from transformers.utils import is_torch_available, is_vision_available + +from ...test_image_processing_common import ImageProcessingTestMixin if is_torch_available() and is_vision_available(): import torch - from transformers import FuyuImageProcessor + from transformers import FuyuImageProcessor, FuyuImageProcessorFast if is_vision_available(): from PIL import Image +class FuyuImageProcessingTester: + def __init__( + self, + parent, + batch_size=3, + num_channels=3, + image_size=18, + min_resolution=30, + max_resolution=400, + do_resize=True, + size=None, + do_pad=True, + do_normalize=True, + image_mean=[0.5, 0.5, 0.5], + image_std=[0.5, 0.5, 0.5], + do_rescale=True, + rescale_factor=1 / 255, + patch_size=None, + ): + size = size if size is not None else {"height": 180, "width": 360} + patch_size = patch_size if patch_size is not None else {"height": 30, "width": 30} + self.parent = parent + self.batch_size = batch_size + self.num_channels = num_channels + self.image_size = image_size + self.min_resolution = 30 + self.max_resolution = 360 + self.do_resize = do_resize + self.size = size + self.do_pad = do_pad + self.do_normalize = do_normalize + self.image_mean = image_mean + self.image_std = image_std + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.patch_size = patch_size + + def prepare_image_processor_dict(self): + return { + "do_resize": self.do_resize, + "size": self.size, + "do_pad": self.do_pad, + "do_normalize": self.do_normalize, + "image_mean": self.image_mean, + "image_std": self.image_std, + "do_rescale": self.do_rescale, + "rescale_factor": self.rescale_factor, + "patch_size": self.patch_size, + } + + def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False): + """Prepares a batch of images for testing""" + if equal_resolution: + image_inputs = [ + np.random.randint( + 0, 256, (self.num_channels, self.max_resolution, self.max_resolution), dtype=np.uint8 + ) 
+ for _ in range(self.batch_size) + ] + else: + heights = [ + h - (h % 30) for h in np.random.randint(self.min_resolution, self.max_resolution, self.batch_size) + ] + widths = [ + w - (w % 30) for w in np.random.randint(self.min_resolution, self.max_resolution, self.batch_size) + ] + + image_inputs = [ + np.random.randint(0, 256, (self.num_channels, height, width), dtype=np.uint8) + for height, width in zip(heights, widths) + ] + + if not numpify and not torchify: + image_inputs = [Image.fromarray(np.moveaxis(img, 0, -1)) for img in image_inputs] + + if torchify: + image_inputs = [torch.from_numpy(img) for img in image_inputs] + + return image_inputs + + def expected_output_image_shape(self, images): + return self.num_channels, self.size["height"], self.size["width"] + + @require_torch @require_vision @require_torchvision -class TestFuyuImageProcessor(unittest.TestCase): +class FuyuImageProcessorTest(ImageProcessingTestMixin, unittest.TestCase): + image_processing_class = FuyuImageProcessor + fast_image_processing_class = FuyuImageProcessorFast + + # Skip tests that expect pixel_values output + test_cast_dtype = None + def setUp(self): - self.size = {"height": 160, "width": 320} - self.processor = FuyuImageProcessor(size=self.size, padding_value=1.0) - self.batch_size = 3 - self.channels = 3 - self.height = 300 - self.width = 300 + self.image_processor_tester = FuyuImageProcessingTester(self) + self.image_processor_dict = self.image_processor_tester.prepare_image_processor_dict() - self.image_input = torch.rand(self.batch_size, self.channels, self.height, self.width) + # Initialize image_processor_list (from ImageProcessingTestMixin) + image_processor_list = [] + if self.test_slow_image_processor and self.image_processing_class: + image_processor_list.append(self.image_processing_class) + if self.test_fast_image_processor and self.fast_image_processing_class: + image_processor_list.append(self.fast_image_processing_class) + self.image_processor_list = image_processor_list - self.image_patch_dim_h = 30 - self.image_patch_dim_w = 30 - self.sample_image = np.zeros((450, 210, 3), dtype=np.uint8) - self.sample_image_pil = Image.fromarray(self.sample_image) + def test_call_pil(self): + """Override to handle Fuyu's custom output structure""" + for image_processing_class in self.image_processor_list: + image_processing = image_processing_class(**self.image_processor_dict) + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False) + for image in image_inputs: + self.assertIsInstance(image, Image.Image) - def test_patches(self): - expected_num_patches = self.processor.get_num_patches(image_height=self.height, image_width=self.width) + encoded_images = image_processing(image_inputs[0], return_tensors="pt") + self.assertIn("images", encoded_images) + self.assertEqual(len(encoded_images.images), 1) + + encoded_images = image_processing(image_inputs, return_tensors="pt") + self.assertIn("images", encoded_images) + self.assertEqual(len(encoded_images.images), self.image_processor_tester.batch_size) + + def test_call_numpy(self): + """Override to handle Fuyu's custom output structure""" + for image_processing_class in self.image_processor_list: + image_processing = image_processing_class(**self.image_processor_dict) + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True) + for image in image_inputs: + self.assertIsInstance(image, np.ndarray) + + encoded_images = image_processing(image_inputs[0], return_tensors="pt") + 
self.assertIn("images", encoded_images) + self.assertEqual(len(encoded_images.images), 1) + + encoded_images = image_processing(image_inputs, return_tensors="pt") + self.assertIn("images", encoded_images) + self.assertEqual(len(encoded_images.images), self.image_processor_tester.batch_size) + + def test_call_pytorch(self): + """Override to handle Fuyu's custom output structure""" + for image_processing_class in self.image_processor_list: + image_processing = image_processing_class(**self.image_processor_dict) + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True) + + for image in image_inputs: + self.assertIsInstance(image, torch.Tensor) + + encoded_images = image_processing(image_inputs[0], return_tensors="pt") + self.assertIn("images", encoded_images) + self.assertEqual(len(encoded_images.images), 1) - patches_final = self.processor.patchify_image(image=self.image_input) - assert patches_final.shape[1] == expected_num_patches, ( - f"Expected {expected_num_patches} patches, got {patches_final.shape[1]}." + encoded_images = image_processing(image_inputs, return_tensors="pt") + self.assertIn("images", encoded_images) + self.assertEqual(len(encoded_images.images), self.image_processor_tester.batch_size) + + def test_call_numpy_4_channels(self): + """Skip this test as Fuyu doesn't support arbitrary channels""" + self.skipTest("Fuyu processor is designed for 3-channel RGB images") + + def test_slow_fast_equivalence(self): + """Override to handle Fuyu's custom output structure""" + if not self.test_slow_image_processor or not self.test_fast_image_processor: + self.skipTest(reason="Skipping slow/fast equivalence test") + + if self.image_processing_class is None or self.fast_image_processing_class is None: + self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined") + dummy_image = Image.open( + io.BytesIO( + httpx.get("http://images.cocodataset.org/val2017/000000039769.jpg", follow_redirects=True).content + ) ) + image_processor_slow = self.image_processing_class(**self.image_processor_dict) + image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict) + + encoding_slow = image_processor_slow(dummy_image, return_tensors="pt") + encoding_fast = image_processor_fast(dummy_image, return_tensors="pt") + + self._assert_slow_fast_tensors_equivalence(encoding_slow.images[0][0], encoding_fast.images[0][0]) + + def test_slow_fast_equivalence_batched(self): + """Override to handle Fuyu's custom output structure""" + if not self.test_slow_image_processor or not self.test_fast_image_processor: + self.skipTest(reason="Skipping slow/fast equivalence test") + + if self.image_processing_class is None or self.fast_image_processing_class is None: + self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined") + + dummy_images = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True) + image_processor_slow = self.image_processing_class(**self.image_processor_dict) + image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict) + + encoding_slow = image_processor_slow(dummy_images, return_tensors="pt") + encoding_fast = image_processor_fast(dummy_images, return_tensors="pt") + + # Compare each image tensor + for slow_img, fast_img in zip(encoding_slow.images, encoding_fast.images): + self._assert_slow_fast_tensors_equivalence(slow_img[0], fast_img[0]) + + @slow + @require_torch_accelerator + 
@require_vision + @pytest.mark.torch_compile_test + def test_can_compile_fast_image_processor(self): + if self.fast_image_processing_class is None: + self.skipTest("Skipping compilation test as fast image processor is not defined") + if version.parse(torch.__version__) < version.parse("2.3"): + self.skipTest(reason="This test requires torch >= 2.3 to run.") + + torch.compiler.reset() + input_image = torch.randint(0, 255, (3, 224, 224), dtype=torch.uint8) + image_processor = self.fast_image_processing_class(**self.image_processor_dict) + output_eager = image_processor(input_image, device=torch_device, return_tensors="pt") + + image_processor = torch.compile(image_processor, mode="reduce-overhead") + output_compiled = image_processor(input_image, device=torch_device, return_tensors="pt") + self._assert_slow_fast_tensors_equivalence( + output_eager.images[0][0], output_compiled.images[0][0], atol=1e-4, rtol=1e-4, mean_atol=1e-5 + ) + + def test_image_processor_properties(self): + for image_processing_class in self.image_processor_list: + image_processor = image_processing_class(**self.image_processor_dict) + self.assertTrue(hasattr(image_processor, "do_resize")) + self.assertTrue(hasattr(image_processor, "size")) + self.assertTrue(hasattr(image_processor, "do_pad")) + self.assertTrue(hasattr(image_processor, "do_normalize")) + self.assertTrue(hasattr(image_processor, "image_mean")) + self.assertTrue(hasattr(image_processor, "image_std")) + self.assertTrue(hasattr(image_processor, "do_rescale")) + self.assertTrue(hasattr(image_processor, "rescale_factor")) + self.assertTrue(hasattr(image_processor, "patch_size")) + + def test_patches(self): + """Test that patchify_image produces the expected number of patches.""" + for image_processing_class in self.image_processor_list: + image_processor = image_processing_class(**self.image_processor_dict) + batch_size = 3 + channels = 3 + height = 300 + width = 300 + image_input = torch.rand(batch_size, channels, height, width) + + expected_num_patches = image_processor.get_num_patches(image_height=height, image_width=width) + patches_final = image_processor.patchify_image(image=image_input) + + self.assertEqual(patches_final.shape[1], expected_num_patches) + + def test_patches_match_slow_fast(self): + """Test that fast processor produces same patches as slow processor.""" + if not self.test_slow_image_processor or not self.test_fast_image_processor: + self.skipTest(reason="Skipping slow/fast patch equivalence test") + + if self.image_processing_class is None or self.fast_image_processing_class is None: + self.skipTest( + reason="Skipping slow/fast patch equivalence test as one of the image processors is not defined" + ) + + batch_size = 3 + channels = 3 + height = 300 + width = 300 + image_input = torch.rand(batch_size, channels, height, width) + + processor_slow = self.image_processing_class(**self.image_processor_dict) + processor_fast = self.fast_image_processing_class(**self.image_processor_dict) + + patches_fast = processor_fast.patchify_image(image=image_input) + patches_slow = processor_slow.patchify_image(image=image_input) + + self.assertEqual(patches_fast.shape, patches_slow.shape) + torch.testing.assert_close(patches_fast, patches_slow, rtol=1e-4, atol=1e-4) def test_scale_to_target_aspect_ratio(self): - # (h:450, w:210) fitting (160, 320) -> (160, 210*160/450) - scaled_image = self.processor.resize(self.sample_image, size=self.size) - self.assertEqual(scaled_image.shape[0], 160) - self.assertEqual(scaled_image.shape[1], 74) + """Test that 
resize maintains aspect ratio correctly.""" + sample_image = np.zeros((450, 210, 3), dtype=np.uint8) + + if self.test_slow_image_processor and self.image_processing_class: + image_processor = self.image_processing_class(**self.image_processor_dict) + scaled_image = image_processor.resize(sample_image, size=self.image_processor_dict["size"]) + self.assertEqual(scaled_image.shape[0], 180) + self.assertEqual(scaled_image.shape[1], 84) + + if self.test_fast_image_processor and self.fast_image_processing_class: + image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict) + sample_tensor = torch.from_numpy(sample_image).permute(2, 0, 1).float() + + size_dict = SizeDict( + height=self.image_processor_dict["size"]["height"], width=self.image_processor_dict["size"]["width"] + ) + scaled_image = image_processor_fast.resize(sample_tensor, size=size_dict) + + self.assertEqual(scaled_image.shape[1], 180) + self.assertEqual(scaled_image.shape[2], 84) def test_apply_transformation_numpy(self): - transformed_image = self.processor.preprocess(self.sample_image).images[0][0] - self.assertEqual(transformed_image.shape[1], 160) - self.assertEqual(transformed_image.shape[2], 320) + """Test preprocessing with numpy input.""" + sample_image = np.zeros((450, 210, 3), dtype=np.uint8) + + for image_processing_class in self.image_processor_list: + image_processor = image_processing_class(**self.image_processor_dict) + transformed_image = image_processor.preprocess(sample_image).images[0][0] + self.assertEqual(transformed_image.shape[1], 180) + self.assertEqual(transformed_image.shape[2], 360) def test_apply_transformation_pil(self): - transformed_image = self.processor.preprocess(self.sample_image_pil).images[0][0] - self.assertEqual(transformed_image.shape[1], 160) - self.assertEqual(transformed_image.shape[2], 320) + """Test preprocessing with PIL input.""" + sample_image = np.zeros((450, 210, 3), dtype=np.uint8) + sample_image_pil = Image.fromarray(sample_image) + + for image_processing_class in self.image_processor_list: + image_processor = image_processing_class(**self.image_processor_dict) + transformed_image = image_processor.preprocess(sample_image_pil).images[0][0] + self.assertEqual(transformed_image.shape[1], 180) + self.assertEqual(transformed_image.shape[2], 360) + + def test_preprocess_output_structure(self): + """Test that preprocess returns correct output structure.""" + sample_image = np.zeros((450, 210, 3), dtype=np.uint8) + + for image_processing_class in self.image_processor_list: + image_processor = image_processing_class(**self.image_processor_dict) + result = image_processor.preprocess(sample_image) + + self.assertIn("images", result) + self.assertIn("image_unpadded_heights", result) + self.assertIn("image_unpadded_widths", result) + self.assertIn("image_scale_factors", result) + + self.assertEqual(len(result.images), 1) + self.assertEqual(len(result.images[0]), 1) + self.assertEqual(len(result.image_unpadded_heights), 1) + self.assertEqual(len(result.image_unpadded_widths), 1) + self.assertEqual(len(result.image_scale_factors), 1) + + def test_batch_processing(self): + """Test processing multiple images.""" + sample_image = np.zeros((450, 210, 3), dtype=np.uint8) + sample_image_pil = Image.fromarray(sample_image) + images = [sample_image, sample_image_pil] + + for image_processing_class in self.image_processor_list: + image_processor = image_processing_class(**self.image_processor_dict) + result = image_processor.preprocess(images) + + 
self.assertEqual(len(result.images), 2) + for img in result.images: + self.assertEqual(len(img), 1) + if hasattr(img[0], "shape"): + if len(img[0].shape) == 3: + self.assertEqual(img[0].shape[1], 180) + self.assertEqual(img[0].shape[2], 360) + + def test_pad_image_fast(self): + """Test that padding works correctly for fast processor.""" + if not self.test_fast_image_processor or self.fast_image_processing_class is None: + self.skipTest(reason="Fast processor not available") + + from transformers.image_utils import SizeDict + + image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict) + + small_image = torch.rand(3, 100, 100) + size_dict = SizeDict(height=180, width=360) + + padded = image_processor_fast.pad([small_image], pad_size=size_dict, fill_value=1.0)[0] + self.assertEqual(padded.shape[1], 180) + self.assertEqual(padded.shape[2], 360) + + self.assertTrue(torch.allclose(padded[:, 100:, :], torch.ones_like(padded[:, 100:, :]))) + self.assertTrue(torch.allclose(padded[:, :, 100:], torch.ones_like(padded[:, :, 100:]))) + + def test_preprocess_with_tokenizer_info(self): + """Test preprocess_with_tokenizer_info functionality.""" + batch_size = 2 + subseq_size = 1 + channels = 3 + image_input = torch.rand(batch_size, subseq_size, channels, 180, 360) + image_present = torch.ones(batch_size, subseq_size, dtype=torch.bool) + image_unpadded_h = torch.tensor([[180], [180]]) + image_unpadded_w = torch.tensor([[360], [360]]) + + for image_processing_class in self.image_processor_list: + image_processor = image_processing_class(**self.image_processor_dict) + + result = image_processor.preprocess_with_tokenizer_info( + image_input=image_input, + image_present=image_present, + image_unpadded_h=image_unpadded_h, + image_unpadded_w=image_unpadded_w, + image_placeholder_id=100, + image_newline_id=101, + variable_sized=True, + ) + + # Check output structure + self.assertIn("images", result) + self.assertIn("image_input_ids", result) + self.assertIn("image_patches", result) + self.assertIn("image_patch_indices_per_batch", result) + self.assertIn("image_patch_indices_per_subsequence", result) + + # Check batch structure + self.assertEqual(len(result.images), batch_size) + self.assertEqual(len(result.image_input_ids), batch_size) + self.assertEqual(len(result.image_patches), batch_size) + + def test_device_handling_fast(self): + """Test that fast processor can handle device placement.""" + if not self.test_fast_image_processor or self.fast_image_processing_class is None: + self.skipTest(reason="Fast processor not available") + + sample_image = np.zeros((450, 210, 3), dtype=np.uint8) + image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict) + + if torch.cuda.is_available(): + result_cuda = image_processor_fast.preprocess(sample_image, device="cuda") + self.assertEqual(result_cuda.images[0][0].device.type, "cuda") + + result_cpu = image_processor_fast.preprocess(sample_image, device="cpu") + self.assertEqual(result_cpu.images[0][0].device.type, "cpu") + + def test_do_not_resize_if_smaller(self): + """Test that images smaller than target size are not resized.""" + if not self.test_fast_image_processor or self.fast_image_processing_class is None: + self.skipTest(reason="Fast processor not available") + + image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict) + + small_image = torch.rand(3, 100, 150) + size_dict = SizeDict(height=180, width=360) + + resized = image_processor_fast.resize(small_image, size=size_dict) + + 
self.assertEqual(resized.shape[1], 100) + self.assertEqual(resized.shape[2], 150)
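
As a quick sanity check, the sketch below exercises the new fast processor end to end with its default configuration. It is not part of the diff and assumes a source tree with this change applied (plus torch and torchvision installed); the printed fields mirror the `FuyuBatchFeature` structure already returned by the slow `FuyuImageProcessor`.

```python
# Minimal usage sketch (assumes this PR is applied; not part of the diff itself).
import torch

from transformers import FuyuImageProcessorFast

# Defaults: size={"height": 1080, "width": 1920}, do_resize=True, do_pad=True
image_processor = FuyuImageProcessorFast()

# Dummy 450x210 RGB image; it is smaller than the target size, so it is only padded, not resized.
image = torch.randint(0, 256, (3, 450, 210), dtype=torch.uint8)
outputs = image_processor(image, return_tensors="pt")

# FuyuBatchFeature keeps per-image nested lists instead of a single pixel_values tensor.
print(outputs.images[0][0].shape)      # padded tensor, expected torch.Size([3, 1080, 1920])
print(outputs.image_unpadded_heights)  # height before padding (450)
print(outputs.image_unpadded_widths)   # width before padding (210)
print(outputs.image_scale_factors)     # resize scale factor (1.0 here, since no resize happened)
```

Loading through `AutoImageProcessor.from_pretrained(..., use_fast=True)` should resolve to the same class once the updated `image_processing_auto.py` mapping is in place.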