diff --git a/keras_hub/api/__init__.py b/keras_hub/api/__init__.py index 2aa98bf3f9..810f8fa921 100644 --- a/keras_hub/api/__init__.py +++ b/keras_hub/api/__init__.py @@ -4,6 +4,7 @@ since your modifications would be overwritten. """ +from keras_hub import export as export from keras_hub import layers as layers from keras_hub import metrics as metrics from keras_hub import models as models diff --git a/keras_hub/api/export/__init__.py b/keras_hub/api/export/__init__.py new file mode 100644 index 0000000000..fccc068e3d --- /dev/null +++ b/keras_hub/api/export/__init__.py @@ -0,0 +1,25 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras_hub.src.export.configs import ( + CausalLMExporterConfig as CausalLMExporterConfig, +) +from keras_hub.src.export.configs import ( + ImageClassifierExporterConfig as ImageClassifierExporterConfig, +) +from keras_hub.src.export.configs import ( + ImageSegmenterExporterConfig as ImageSegmenterExporterConfig, +) +from keras_hub.src.export.configs import ( + ObjectDetectorExporterConfig as ObjectDetectorExporterConfig, +) +from keras_hub.src.export.configs import ( + Seq2SeqLMExporterConfig as Seq2SeqLMExporterConfig, +) +from keras_hub.src.export.configs import ( + TextClassifierExporterConfig as TextClassifierExporterConfig, +) +from keras_hub.src.export.litert import LiteRTExporter as LiteRTExporter diff --git a/keras_hub/src/export/__init__.py b/keras_hub/src/export/__init__.py new file mode 100644 index 0000000000..25d8d27f36 --- /dev/null +++ b/keras_hub/src/export/__init__.py @@ -0,0 +1,9 @@ +# Export base classes and configurations for advanced usage +from keras_hub.src.export.base import KerasHubExporter +from keras_hub.src.export.base import KerasHubExporterConfig +from keras_hub.src.export.configs import CausalLMExporterConfig +from keras_hub.src.export.configs import Seq2SeqLMExporterConfig +from keras_hub.src.export.configs import 
TextClassifierExporterConfig +from keras_hub.src.export.configs import get_exporter_config +from keras_hub.src.export.litert import LiteRTExporter +from keras_hub.src.export.litert import export_litert diff --git a/keras_hub/src/export/base.py b/keras_hub/src/export/base.py new file mode 100644 index 0000000000..da6634c14c --- /dev/null +++ b/keras_hub/src/export/base.py @@ -0,0 +1,120 @@ +"""Base classes for Keras-Hub model exporters. + +This module provides the foundation for exporting Keras-Hub models to various +formats. It defines the abstract base classes that all exporters must implement. +""" + +from abc import ABC +from abc import abstractmethod + + +class KerasHubExporterConfig(ABC): + """Base configuration class for Keras-Hub model exporters. + + This class defines the interface for exporter configurations that specify + how different types of Keras-Hub models should be exported. + """ + + # Model type this exporter handles (e.g., "causal_lm", "text_classifier") + MODEL_TYPE = None + + # Expected input structure for this model type + EXPECTED_INPUTS = [] + + def __init__(self, model, **kwargs): + """Initialize the exporter configuration. + + Args: + model: `keras.Model`. The Keras-Hub model to export. + **kwargs: Additional configuration parameters. + """ + self.model = model + self.config_kwargs = kwargs + self._validate_model() + + def _validate_model(self): + """Validate that the model is compatible with this exporter.""" + if not self._is_model_compatible(): + raise ValueError( + f"Model {self.model.__class__.__name__} is not compatible " + f"with {self.__class__.__name__}" + ) + + @abstractmethod + def _is_model_compatible(self): + """Check if the model is compatible with this exporter. + + Returns: + `bool`. True if compatible, False otherwise + """ + pass + + @abstractmethod + def get_input_signature(self, sequence_length=None): + """Get the input signature for this model type. + + Args: + sequence_length: `int` or `None`. 
Optional sequence length for + input tensors. + + Returns: + `dict`. Dictionary mapping input names to tensor specifications. + """ + pass + + +class KerasHubExporter(ABC): + """Base class for Keras-Hub model exporters. + + This class provides the common interface for exporting Keras-Hub models + to different formats (LiteRT, ONNX, etc.). + """ + + def __init__(self, config, **kwargs): + """Initialize the exporter. + + Args: + config: `KerasHubExporterConfig`. Exporter configuration specifying + model type and parameters. + **kwargs: Additional exporter-specific parameters. + """ + self.config = config + self.model = config.model + self.export_kwargs = kwargs + + @abstractmethod + def export(self, filepath): + """Export the model to the specified filepath. + + Args: + filepath: `str`. Path where to save the exported model. + """ + pass + + def _ensure_model_built(self, param=None): + """Ensure the model is properly built with correct input structure. + + This method builds the model using model.build() with input shapes. + This creates the necessary variables and initializes the model structure + for export without needing dummy data. + + Args: + param: `int` or `None`. Optional parameter for input signature + (e.g., sequence_length for text models, image_size for vision + models). 
+ """ + # Get input signature (returns dict of InputSpec objects) + input_signature = self.config.get_input_signature(param) + + # Extract shapes from InputSpec objects + input_shapes = {} + for name, spec in input_signature.items(): + if hasattr(spec, "shape"): + input_shapes[name] = spec.shape + else: + # Fallback for unexpected formats + input_shapes[name] = spec + + # Build the model using shapes only (no actual data allocation) + # This creates variables and initializes the model structure + self.model.build(input_shape=input_shapes) diff --git a/keras_hub/src/export/base_test.py b/keras_hub/src/export/base_test.py new file mode 100644 index 0000000000..02d22b51fd --- /dev/null +++ b/keras_hub/src/export/base_test.py @@ -0,0 +1,197 @@ +"""Tests for base export classes.""" + +import keras + +from keras_hub.src.export.base import KerasHubExporter +from keras_hub.src.export.base import KerasHubExporterConfig +from keras_hub.src.tests.test_case import TestCase + + +class DummyExporterConfig(KerasHubExporterConfig): + """Dummy configuration for testing.""" + + MODEL_TYPE = "test_model" + EXPECTED_INPUTS = ["input_ids", "attention_mask"] + DEFAULT_SEQUENCE_LENGTH = 128 + + def __init__(self, model, compatible=True, **kwargs): + self.is_compatible = compatible + super().__init__(model, **kwargs) + + def _is_model_compatible(self): + return self.is_compatible + + def get_input_signature(self, sequence_length=None): + seq_len = sequence_length or self.DEFAULT_SEQUENCE_LENGTH + return { + "input_ids": keras.layers.InputSpec( + shape=(None, seq_len), dtype="int32" + ), + "attention_mask": keras.layers.InputSpec( + shape=(None, seq_len), dtype="int32" + ), + } + + +class DummyExporter(KerasHubExporter): + """Dummy exporter for testing.""" + + def __init__(self, config, **kwargs): + super().__init__(config, **kwargs) + self.exported = False + self.export_path = None + + def export(self, filepath): + self.exported = True + self.export_path = filepath + return filepath + + 
+class KerasHubExporterConfigTest(TestCase): + """Tests for KerasHubExporterConfig base class.""" + + def test_init_with_compatible_model(self): + """Test initialization with a compatible model.""" + model = keras.Sequential([keras.layers.Dense(10)]) + config = DummyExporterConfig(model, compatible=True) + + self.assertEqual(config.model, model) + self.assertEqual(config.MODEL_TYPE, "test_model") + self.assertEqual( + config.EXPECTED_INPUTS, ["input_ids", "attention_mask"] + ) + + def test_init_with_incompatible_model_raises_error(self): + """Test that incompatible model raises ValueError.""" + model = keras.Sequential([keras.layers.Dense(10)]) + + with self.assertRaisesRegex(ValueError, "not compatible"): + DummyExporterConfig(model, compatible=False) + + def test_get_input_signature_default_sequence_length(self): + """Test get_input_signature with default sequence length.""" + model = keras.Sequential([keras.layers.Dense(10)]) + config = DummyExporterConfig(model) + + signature = config.get_input_signature() + + self.assertIn("input_ids", signature) + self.assertIn("attention_mask", signature) + self.assertEqual(signature["input_ids"].shape, (None, 128)) + self.assertEqual(signature["attention_mask"].shape, (None, 128)) + + def test_get_input_signature_custom_sequence_length(self): + """Test get_input_signature with custom sequence length.""" + model = keras.Sequential([keras.layers.Dense(10)]) + config = DummyExporterConfig(model) + + signature = config.get_input_signature(sequence_length=256) + + self.assertEqual(signature["input_ids"].shape, (None, 256)) + self.assertEqual(signature["attention_mask"].shape, (None, 256)) + + def test_config_kwargs_stored(self): + """Test that additional kwargs are stored.""" + model = keras.Sequential([keras.layers.Dense(10)]) + config = DummyExporterConfig( + model, custom_param="value", another_param=42 + ) + + self.assertEqual(config.config_kwargs["custom_param"], "value") + 
self.assertEqual(config.config_kwargs["another_param"], 42) + + +class KerasHubExporterTest(TestCase): + """Tests for KerasHubExporter base class.""" + + def test_init_stores_config_and_model(self): + """Test that initialization stores config and model.""" + model = keras.Sequential([keras.layers.Dense(10)]) + config = DummyExporterConfig(model) + exporter = DummyExporter(config, verbose=True, custom_param="test") + + self.assertEqual(exporter.config, config) + self.assertEqual(exporter.model, model) + self.assertEqual(exporter.export_kwargs["verbose"], True) + self.assertEqual(exporter.export_kwargs["custom_param"], "test") + + def test_export_method_called(self): + """Test that export method can be called.""" + model = keras.Sequential([keras.layers.Dense(10)]) + config = DummyExporterConfig(model) + exporter = DummyExporter(config) + + result = exporter.export("/tmp/test_model") + + self.assertTrue(exporter.exported) + self.assertEqual(exporter.export_path, "/tmp/test_model") + self.assertEqual(result, "/tmp/test_model") + + def test_ensure_model_built(self): + """Test _ensure_model_built method.""" + + class TestModel(keras.Model): + def __init__(self): + super().__init__() + self.dense = keras.layers.Dense(10) + + def call(self, inputs): + return self.dense(inputs["input_ids"]) + + model = TestModel() + config = DummyExporterConfig(model) + exporter = DummyExporter(config) + + # Model should not be built initially + self.assertFalse(model.built) + + # Call _ensure_model_built + exporter._ensure_model_built() + + # Model should now be built + self.assertTrue(model.built) + + def test_ensure_model_built_with_custom_param(self): + """Test _ensure_model_built with custom sequence length.""" + + class TestModel(keras.Model): + def __init__(self): + super().__init__() + self.dense = keras.layers.Dense(10) + + def call(self, inputs): + return self.dense(inputs["input_ids"]) + + model = TestModel() + config = DummyExporterConfig(model) + exporter = 
DummyExporter(config) + + # Call with custom sequence length + exporter._ensure_model_built(param=512) + + # Verify model is built + self.assertTrue(model.built) + + def test_ensure_model_built_already_built_model(self): + """Test _ensure_model_built with already built model.""" + + class TestModel(keras.Model): + def __init__(self): + super().__init__() + self.dense = keras.layers.Dense(10) + + def call(self, inputs): + return self.dense(inputs["input_ids"]) + + model = TestModel() + # Pre-build the model + model.build(input_shape={"input_ids": (None, 128)}) + + config = DummyExporterConfig(model) + exporter = DummyExporter(config) + + # Should not raise an error for already built model + exporter._ensure_model_built() + + # Model should still be built + self.assertTrue(model.built) diff --git a/keras_hub/src/export/configs.py b/keras_hub/src/export/configs.py new file mode 100644 index 0000000000..706e25e756 --- /dev/null +++ b/keras_hub/src/export/configs.py @@ -0,0 +1,409 @@ +"""Configuration classes for different Keras-Hub model types. + +This module provides specific configurations for exporting different types +of Keras-Hub models, following the Optimum pattern. +""" + +import keras + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.export.base import KerasHubExporterConfig +from keras_hub.src.models.causal_lm import CausalLM +from keras_hub.src.models.image_classifier import ImageClassifier +from keras_hub.src.models.image_segmenter import ImageSegmenter +from keras_hub.src.models.object_detector import ObjectDetector +from keras_hub.src.models.seq_2_seq_lm import Seq2SeqLM +from keras_hub.src.models.text_classifier import TextClassifier + + +def _get_text_input_signature(model, sequence_length=128): + """Get input signature for text models with token_ids and padding_mask. + + Args: + model: The model instance. + sequence_length: `int`. Sequence length (default: 128). + + Returns: + `dict`. 
Dictionary mapping input names to their specifications + """ + return { + "token_ids": keras.layers.InputSpec( + shape=(None, sequence_length), dtype="int32", name="token_ids" + ), + "padding_mask": keras.layers.InputSpec( + shape=(None, sequence_length), + dtype="int32", + name="padding_mask", + ), + } + + +def _get_seq2seq_input_signature(model, sequence_length=128): + """Get input signature for seq2seq models with encoder/decoder tokens. + + Args: + model: The model instance. + sequence_length: `int`. Sequence length (default: 128). + + Returns: + `dict`. Dictionary mapping input names to their specifications + """ + return { + "encoder_token_ids": keras.layers.InputSpec( + shape=(None, sequence_length), + dtype="int32", + name="encoder_token_ids", + ), + "encoder_padding_mask": keras.layers.InputSpec( + shape=(None, sequence_length), + dtype="int32", + name="encoder_padding_mask", + ), + "decoder_token_ids": keras.layers.InputSpec( + shape=(None, sequence_length), + dtype="int32", + name="decoder_token_ids", + ), + "decoder_padding_mask": keras.layers.InputSpec( + shape=(None, sequence_length), + dtype="int32", + name="decoder_padding_mask", + ), + } + + +def _infer_sequence_length(model, default_length): + """Infer sequence length from model preprocessor or use default. + + Args: + model: The model instance. + default_length: `int`. Default sequence length to use if not found. + + Returns: + `int`. Sequence length from preprocessor or default. + """ + if hasattr(model, "preprocessor") and model.preprocessor: + return getattr( + model.preprocessor, + "sequence_length", + default_length, + ) + return default_length + + +def _infer_image_size(model): + """Infer image size from model preprocessor or inputs. + + Args: + model: The model instance. + + Returns: + `tuple`. Image size as (height, width). + + Raises: + ValueError: If image_size cannot be determined. 
+ """ + image_size = None + + # Get from preprocessor + if hasattr(model, "preprocessor") and model.preprocessor: + if hasattr(model.preprocessor, "image_size"): + image_size = model.preprocessor.image_size + + # Try to infer from model inputs + if image_size is None and hasattr(model, "inputs") and model.inputs: + input_shape = model.inputs[0].shape + if ( + len(input_shape) == 4 + and input_shape[1] is not None + and input_shape[2] is not None + ): + image_size = (input_shape[1], input_shape[2]) + + if image_size is None: + raise ValueError( + "Could not determine image size from model. " + "Model should have a preprocessor with image_size " + "attribute, or model inputs should have concrete shapes." + ) + + if isinstance(image_size, int): + image_size = (image_size, image_size) + + return image_size + + +def _infer_image_dtype(model): + """Infer image dtype from model inputs. + + Args: + model: The model instance. + + Returns: + `str`. Image dtype (defaults to "float32"). + """ + if hasattr(model, "inputs") and model.inputs: + model_dtype = model.inputs[0].dtype + return model_dtype.name if hasattr(model_dtype, "name") else model_dtype + return "float32" + + +@keras_hub_export("keras_hub.export.CausalLMExporterConfig") +class CausalLMExporterConfig(KerasHubExporterConfig): + """Exporter configuration for Causal Language Models (GPT, LLaMA, etc.).""" + + MODEL_TYPE = "causal_lm" + EXPECTED_INPUTS = ["token_ids", "padding_mask"] + DEFAULT_SEQUENCE_LENGTH = 128 + + def _is_model_compatible(self): + """Check if model is a causal language model. + + Returns: + `bool`. True if compatible, False otherwise + """ + return isinstance(self.model, CausalLM) + + def get_input_signature(self, sequence_length=None): + """Get input signature for causal LM models. + + Args: + sequence_length: `int` or `None`. Optional sequence length. + + Returns: + `dict`. 
Dictionary mapping input names to their specifications + """ + if sequence_length is None: + sequence_length = _infer_sequence_length( + self.model, self.DEFAULT_SEQUENCE_LENGTH + ) + + return _get_text_input_signature(self.model, sequence_length) + + +@keras_hub_export("keras_hub.export.TextClassifierExporterConfig") +class TextClassifierExporterConfig(KerasHubExporterConfig): + """Exporter configuration for Text Classification models.""" + + MODEL_TYPE = "text_classifier" + EXPECTED_INPUTS = ["token_ids", "padding_mask"] + DEFAULT_SEQUENCE_LENGTH = 128 + + def _is_model_compatible(self): + """Check if model is a text classifier. + + Returns: + `bool`. True if compatible, False otherwise + """ + return isinstance(self.model, TextClassifier) + + def get_input_signature(self, sequence_length=None): + """Get input signature for text classifier models. + + Args: + sequence_length: `int` or `None`. Optional sequence length. + + Returns: + `dict`. Dictionary mapping input names to their specifications + """ + if sequence_length is None: + sequence_length = _infer_sequence_length( + self.model, self.DEFAULT_SEQUENCE_LENGTH + ) + + return _get_text_input_signature(self.model, sequence_length) + + +@keras_hub_export("keras_hub.export.Seq2SeqLMExporterConfig") +class Seq2SeqLMExporterConfig(KerasHubExporterConfig): + """Exporter configuration for Sequence-to-Sequence Language Models.""" + + MODEL_TYPE = "seq2seq_lm" + EXPECTED_INPUTS = [ + "encoder_token_ids", + "encoder_padding_mask", + "decoder_token_ids", + "decoder_padding_mask", + ] + DEFAULT_SEQUENCE_LENGTH = 128 + + def _is_model_compatible(self): + """Check if model is a seq2seq language model. + + Returns: + `bool`. True if compatible, False otherwise + """ + return isinstance(self.model, Seq2SeqLM) + + def get_input_signature(self, sequence_length=None): + """Get input signature for seq2seq models. + + Args: + sequence_length: `int` or `None`. Optional sequence length. + + Returns: + `dict`. 
Dictionary mapping input names to their specifications + """ + if sequence_length is None: + sequence_length = _infer_sequence_length( + self.model, self.DEFAULT_SEQUENCE_LENGTH + ) + + return _get_seq2seq_input_signature(self.model, sequence_length) + + +@keras_hub_export("keras_hub.export.ImageClassifierExporterConfig") +class ImageClassifierExporterConfig(KerasHubExporterConfig): + """Exporter configuration for Image Classification models.""" + + MODEL_TYPE = "image_classifier" + EXPECTED_INPUTS = ["images"] + + def _is_model_compatible(self): + """Check if model is an image classifier. + Returns: + `bool`. True if compatible, False otherwise + """ + return isinstance(self.model, ImageClassifier) + + def get_input_signature(self, image_size=None): + """Get input signature for image classifier models. + Args: + image_size: `int`, `tuple` or `None`. Optional image size. + Returns: + `dict`. Dictionary mapping input names to their specifications + """ + if image_size is None: + image_size = _infer_image_size(self.model) + elif isinstance(image_size, int): + image_size = (image_size, image_size) + + dtype = _infer_image_dtype(self.model) + + return { + "images": keras.layers.InputSpec( + shape=(None, *image_size, 3), + dtype=dtype, + name="images", + ), + } + + +@keras_hub_export("keras_hub.export.ObjectDetectorExporterConfig") +class ObjectDetectorExporterConfig(KerasHubExporterConfig): + """Exporter configuration for Object Detection models.""" + + MODEL_TYPE = "object_detector" + EXPECTED_INPUTS = ["images", "image_shape"] + + def _is_model_compatible(self): + """Check if model is an object detector. + Returns: + `bool`. True if compatible, False otherwise + """ + return isinstance(self.model, ObjectDetector) + + def get_input_signature(self, image_size=None): + """Get input signature for object detector models. + Args: + image_size: `int`, `tuple` or `None`. Optional image size. + Returns: + `dict`. 
Dictionary mapping input names to their specifications + """ + if image_size is None: + image_size = _infer_image_size(self.model) + elif isinstance(image_size, int): + image_size = (image_size, image_size) + + dtype = _infer_image_dtype(self.model) + + return { + "images": keras.layers.InputSpec( + shape=(None, *image_size, 3), + dtype=dtype, + name="images", + ), + "image_shape": keras.layers.InputSpec( + shape=(None, 2), dtype="int32", name="image_shape" + ), + } + + +@keras_hub_export("keras_hub.export.ImageSegmenterExporterConfig") +class ImageSegmenterExporterConfig(KerasHubExporterConfig): + """Exporter configuration for Image Segmentation models.""" + + MODEL_TYPE = "image_segmenter" + EXPECTED_INPUTS = ["images"] + + def _is_model_compatible(self): + """Check if model is an image segmenter. + Returns: + `bool`. True if compatible, False otherwise + """ + return isinstance(self.model, ImageSegmenter) + + def get_input_signature(self, image_size=None): + """Get input signature for image segmenter models. + Args: + image_size: `int`, `tuple` or `None`. Optional image size. + Returns: + `dict`. Dictionary mapping input names to their specifications + """ + if image_size is None: + image_size = _infer_image_size(self.model) + elif isinstance(image_size, int): + image_size = (image_size, image_size) + + dtype = _infer_image_dtype(self.model) + + return { + "images": keras.layers.InputSpec( + shape=(None, *image_size, 3), + dtype=dtype, + name="images", + ), + } + + +def get_exporter_config(model): + """Get the appropriate exporter configuration for a model instance. + + This function automatically detects the model type and returns the + corresponding exporter configuration. + + Args: + model: A Keras-Hub model instance (e.g., CausalLM, TextClassifier). + + Returns: + An instance of the appropriate KerasHubExporterConfig subclass. + + Raises: + ValueError: If the model type is not supported for export. 
+ """ + # Mapping of model classes to their config classes + # NOTE: Order matters! Seq2SeqLM must be checked before CausalLM + # since Seq2SeqLM is a subclass of CausalLM + _MODEL_TYPE_TO_CONFIG = { + Seq2SeqLM: Seq2SeqLMExporterConfig, + CausalLM: CausalLMExporterConfig, + TextClassifier: TextClassifierExporterConfig, + ImageClassifier: ImageClassifierExporterConfig, + ObjectDetector: ObjectDetectorExporterConfig, + ImageSegmenter: ImageSegmenterExporterConfig, + } + + # Find matching config class + for model_class, config_class in _MODEL_TYPE_TO_CONFIG.items(): + if isinstance(model, model_class): + return config_class(model) + + # Model type not supported + supported_types = ", ".join( + cls.__name__ for cls in _MODEL_TYPE_TO_CONFIG.keys() + ) + raise ValueError( + f"Could not find exporter config for model type " + f"'{model.__class__.__name__}'. " + f"Supported types: {supported_types}" + ) diff --git a/keras_hub/src/export/configs_test.py b/keras_hub/src/export/configs_test.py new file mode 100644 index 0000000000..d618e97c69 --- /dev/null +++ b/keras_hub/src/export/configs_test.py @@ -0,0 +1,279 @@ +"""Tests for export configuration classes.""" + +import keras + +from keras_hub.src.export.configs import CausalLMExporterConfig +from keras_hub.src.export.configs import ImageClassifierExporterConfig +from keras_hub.src.export.configs import ImageSegmenterExporterConfig +from keras_hub.src.export.configs import ObjectDetectorExporterConfig +from keras_hub.src.export.configs import Seq2SeqLMExporterConfig +from keras_hub.src.export.configs import TextClassifierExporterConfig +from keras_hub.src.tests.test_case import TestCase + + +class MockPreprocessor: + """Mock preprocessor for testing.""" + + def __init__(self, sequence_length=None, image_size=None): + if sequence_length is not None: + self.sequence_length = sequence_length + if image_size is not None: + self.image_size = image_size + + +class MockCausalLM(keras.Model): + """Mock Causal LM model for 
testing.""" + + def __init__(self, preprocessor=None): + super().__init__() + self.preprocessor = preprocessor + self.dense = keras.layers.Dense(10) + + def call(self, inputs): + return self.dense(inputs["token_ids"]) + + +class MockTextClassifier(keras.Model): + """Mock Text Classifier model for testing.""" + + def __init__(self, preprocessor=None): + super().__init__() + self.preprocessor = preprocessor + self.dense = keras.layers.Dense(5) + + def call(self, inputs): + return self.dense(inputs["token_ids"]) + + +class MockImageClassifier(keras.Model): + """Mock Image Classifier model for testing.""" + + def __init__(self, preprocessor=None): + super().__init__() + self.preprocessor = preprocessor + self.dense = keras.layers.Dense(1000) + + def call(self, inputs): + return self.dense(inputs) + + +class CausalLMExporterConfigTest(TestCase): + """Tests for CausalLMExporterConfig class.""" + + def test_model_type_and_expected_inputs(self): + """Test MODEL_TYPE and EXPECTED_INPUTS are correctly set.""" + from keras_hub.src.models.causal_lm import CausalLM + + # Create a minimal mock CausalLM + class MockCausalLMForTest(CausalLM): + def __init__(self): + keras.Model.__init__(self) + self.preprocessor = None + + model = MockCausalLMForTest() + config = CausalLMExporterConfig(model) + self.assertEqual(config.MODEL_TYPE, "causal_lm") + self.assertEqual(config.EXPECTED_INPUTS, ["token_ids", "padding_mask"]) + + def test_get_input_signature_default(self): + """Test get_input_signature with default sequence length.""" + from keras_hub.src.models.causal_lm import CausalLM + + class MockCausalLMForTest(CausalLM): + def __init__(self): + keras.Model.__init__(self) + self.preprocessor = None + + model = MockCausalLMForTest() + config = CausalLMExporterConfig(model) + signature = config.get_input_signature() + + self.assertIn("token_ids", signature) + self.assertIn("padding_mask", signature) + self.assertEqual(signature["token_ids"].shape, (None, 128)) + 
self.assertEqual(signature["padding_mask"].shape, (None, 128)) + + def test_get_input_signature_from_preprocessor(self): + """Test get_input_signature infers from preprocessor.""" + from keras_hub.src.models.causal_lm import CausalLM + + class MockCausalLMForTest(CausalLM): + def __init__(self, preprocessor): + keras.Model.__init__(self) + self.preprocessor = preprocessor + + preprocessor = MockPreprocessor(sequence_length=256) + model = MockCausalLMForTest(preprocessor) + config = CausalLMExporterConfig(model) + signature = config.get_input_signature() + + # Should use preprocessor's sequence length + self.assertEqual(signature["token_ids"].shape, (None, 256)) + self.assertEqual(signature["padding_mask"].shape, (None, 256)) + + def test_get_input_signature_custom_length(self): + """Test get_input_signature with custom sequence length.""" + from keras_hub.src.models.causal_lm import CausalLM + + class MockCausalLMForTest(CausalLM): + def __init__(self): + keras.Model.__init__(self) + self.preprocessor = None + + model = MockCausalLMForTest() + config = CausalLMExporterConfig(model) + signature = config.get_input_signature(sequence_length=512) + + # Should use provided sequence length + self.assertEqual(signature["token_ids"].shape, (None, 512)) + self.assertEqual(signature["padding_mask"].shape, (None, 512)) + + +class TextClassifierExporterConfigTest(TestCase): + """Tests for TextClassifierExporterConfig class.""" + + def test_model_type_and_expected_inputs(self): + """Test MODEL_TYPE and EXPECTED_INPUTS are correctly set.""" + from keras_hub.src.models.text_classifier import TextClassifier + + class MockTextClassifierForTest(TextClassifier): + def __init__(self): + keras.Model.__init__(self) + self.preprocessor = None + + model = MockTextClassifierForTest() + config = TextClassifierExporterConfig(model) + self.assertEqual(config.MODEL_TYPE, "text_classifier") + self.assertEqual(config.EXPECTED_INPUTS, ["token_ids", "padding_mask"]) + + def 
test_get_input_signature_default(self): + """Test get_input_signature with default sequence length.""" + from keras_hub.src.models.text_classifier import TextClassifier + + class MockTextClassifierForTest(TextClassifier): + def __init__(self): + keras.Model.__init__(self) + self.preprocessor = None + + model = MockTextClassifierForTest() + config = TextClassifierExporterConfig(model) + signature = config.get_input_signature() + + self.assertIn("token_ids", signature) + self.assertIn("padding_mask", signature) + self.assertEqual(signature["token_ids"].shape, (None, 128)) + + +class ImageClassifierExporterConfigTest(TestCase): + """Tests for ImageClassifierExporterConfig class.""" + + def test_model_type_and_expected_inputs(self): + """Test MODEL_TYPE and EXPECTED_INPUTS are correctly set.""" + from keras_hub.src.models.image_classifier import ImageClassifier + + class MockImageClassifierForTest(ImageClassifier): + def __init__(self): + keras.Model.__init__(self) + self.preprocessor = None + + model = MockImageClassifierForTest() + config = ImageClassifierExporterConfig(model) + self.assertEqual(config.MODEL_TYPE, "image_classifier") + self.assertEqual(config.EXPECTED_INPUTS, ["images"]) + + def test_get_input_signature_with_preprocessor(self): + """Test get_input_signature infers image size from preprocessor.""" + from keras_hub.src.models.image_classifier import ImageClassifier + + class MockImageClassifierForTest(ImageClassifier): + def __init__(self, preprocessor): + keras.Model.__init__(self) + self.preprocessor = preprocessor + + preprocessor = MockPreprocessor(image_size=(224, 224)) + model = MockImageClassifierForTest(preprocessor) + config = ImageClassifierExporterConfig(model) + signature = config.get_input_signature() + + self.assertIn("images", signature) + # Image shape should be (batch, height, width, channels) + expected_shape = (None, 224, 224, 3) + self.assertEqual(signature["images"].shape, expected_shape) + + +class 
Seq2SeqLMExporterConfigTest(TestCase): + """Tests for Seq2SeqLMExporterConfig class.""" + + def test_model_type_and_expected_inputs(self): + """Test MODEL_TYPE and EXPECTED_INPUTS are correctly set.""" + from keras_hub.src.models.seq_2_seq_lm import Seq2SeqLM + + class MockSeq2SeqLMForTest(Seq2SeqLM): + def __init__(self): + keras.Model.__init__(self) + self.preprocessor = None + + model = MockSeq2SeqLMForTest() + config = Seq2SeqLMExporterConfig(model) + self.assertEqual(config.MODEL_TYPE, "seq2seq_lm") + # Seq2Seq models have both encoder and decoder inputs + self.assertIn("encoder_token_ids", config.EXPECTED_INPUTS) + self.assertIn("decoder_token_ids", config.EXPECTED_INPUTS) + + +class ObjectDetectorExporterConfigTest(TestCase): + """Tests for ObjectDetectorExporterConfig class.""" + + def test_model_type_and_expected_inputs(self): + """Test MODEL_TYPE and EXPECTED_INPUTS are correctly set.""" + from keras_hub.src.models.object_detector import ObjectDetector + + class MockObjectDetectorForTest(ObjectDetector): + def __init__(self): + keras.Model.__init__(self) + self.preprocessor = None + + model = MockObjectDetectorForTest() + config = ObjectDetectorExporterConfig(model) + self.assertEqual(config.MODEL_TYPE, "object_detector") + self.assertEqual(config.EXPECTED_INPUTS, ["images", "image_shape"]) + + def test_get_input_signature_with_preprocessor(self): + """Test get_input_signature infers from preprocessor.""" + from keras_hub.src.models.object_detector import ObjectDetector + + class MockObjectDetectorForTest(ObjectDetector): + def __init__(self, preprocessor): + keras.Model.__init__(self) + self.preprocessor = preprocessor + + preprocessor = MockPreprocessor(image_size=(512, 512)) + model = MockObjectDetectorForTest(preprocessor) + config = ObjectDetectorExporterConfig(model) + signature = config.get_input_signature() + + self.assertIn("images", signature) + self.assertIn("image_shape", signature) + # Images shape should be (batch, height, width, channels) + 
self.assertEqual(signature["images"].shape, (None, 512, 512, 3)) + # Image shape is (batch, 2) for (height, width) + self.assertEqual(signature["image_shape"].shape, (None, 2)) + self.assertEqual(signature["image_shape"].dtype, "int32") + + +class ImageSegmenterExporterConfigTest(TestCase): + """Tests for ImageSegmenterExporterConfig class.""" + + def test_model_type_and_expected_inputs(self): + """Test MODEL_TYPE and EXPECTED_INPUTS are correctly set.""" + from keras_hub.src.models.image_segmenter import ImageSegmenter + + class MockImageSegmenterForTest(ImageSegmenter): + def __init__(self): + keras.Model.__init__(self) + self.preprocessor = None + + model = MockImageSegmenterForTest() + config = ImageSegmenterExporterConfig(model) + self.assertEqual(config.MODEL_TYPE, "image_segmenter") + self.assertEqual(config.EXPECTED_INPUTS, ["images"]) diff --git a/keras_hub/src/export/litert.py b/keras_hub/src/export/litert.py new file mode 100644 index 0000000000..1c3f6df5f8 --- /dev/null +++ b/keras_hub/src/export/litert.py @@ -0,0 +1,331 @@ +"""LiteRT exporter for Keras-Hub models. + +This module provides LiteRT export functionality specifically designed for +Keras-Hub models, handling their unique input structures and requirements. 
+""" + +import keras + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.export.base import KerasHubExporter +from keras_hub.src.models.causal_lm import CausalLM +from keras_hub.src.models.image_classifier import ImageClassifier +from keras_hub.src.models.image_segmenter import ImageSegmenter +from keras_hub.src.models.object_detector import ObjectDetector +from keras_hub.src.models.seq_2_seq_lm import Seq2SeqLM +from keras_hub.src.models.text_classifier import TextClassifier + +try: + from keras.src.export.litert import LiteRTExporter as KerasLitertExporter + + KERAS_LITE_RT_AVAILABLE = True +except ImportError: + KERAS_LITE_RT_AVAILABLE = False + KerasLitertExporter = None + + +@keras_hub_export("keras_hub.export.LiteRTExporter") +class LiteRTExporter(KerasHubExporter): + """LiteRT exporter for Keras-Hub models. + + This exporter handles the conversion of Keras-Hub models to TensorFlow Lite + format, properly managing the dictionary input structures that Keras-Hub + models expect. + """ + + def __init__( + self, + config, + max_sequence_length=None, + aot_compile_targets=None, + verbose=None, + **kwargs, + ): + """Initialize the LiteRT exporter. + + Args: + config: `KerasHubExporterConfig`. Exporter configuration. + max_sequence_length: `int` or `None`. Maximum sequence length. + aot_compile_targets: `list` or `None`. AOT compilation targets. + verbose: `bool` or `None`. Whether to print progress. Defaults to + `None`, which will use `True`. + **kwargs: `dict`. Additional arguments passed to exporter. + """ + super().__init__(config, **kwargs) + + if not KERAS_LITE_RT_AVAILABLE: + raise ImportError( + "Keras LiteRT exporter is not available. " + "Make sure you have Keras with LiteRT support installed." 
+ ) + + self.max_sequence_length = max_sequence_length + self.aot_compile_targets = aot_compile_targets + self.verbose = verbose if verbose is not None else True + + def _get_model_adapter_class(self): + """Determine the appropriate adapter class for the model. + + Returns: + `str`. The adapter type to use ("text" or "image"). + + Raises: + ValueError: If the model type is not supported for LiteRT export. + """ + if isinstance(self.model, (CausalLM, TextClassifier, Seq2SeqLM)): + return "text" + elif isinstance( + self.model, (ImageClassifier, ObjectDetector, ImageSegmenter) + ): + return "image" + else: + # For other model types (audio, multimodal, custom, etc.) + raise ValueError( + f"Model type {self.model.__class__.__name__} is not supported " + "for LiteRT export. Currently supported model types are: " + "CausalLM, TextClassifier, Seq2SeqLM, ImageClassifier, " + "ObjectDetector, ImageSegmenter." + ) + + def _get_export_param(self): + """Get the appropriate parameter for export based on model type. + + Returns: + The parameter to use for export (sequence_length for text models, + image_size for image models, or None for other model types). + """ + if isinstance(self.model, (CausalLM, TextClassifier, Seq2SeqLM)): + # For text models, use sequence_length + return self.max_sequence_length + elif isinstance( + self.model, (ImageClassifier, ObjectDetector, ImageSegmenter) + ): + # For image models, get image_size from preprocessor + if hasattr(self.model, "preprocessor") and hasattr( + self.model.preprocessor, "image_size" + ): + return self.model.preprocessor.image_size + else: + return None # Will use default in get_input_signature + else: + # For other model types (audio, multimodal, custom, etc.) + return None + + def export(self, filepath): + """Export the Keras-Hub model to LiteRT format. + + Args: + filepath: `str`. Path where to save the model. If it doesn't end + with '.tflite', the extension will be added automatically. 
+ """ + from keras.src.utils import io_utils + + # Ensure filepath ends with .tflite + if not filepath.endswith(".tflite"): + filepath = filepath + ".tflite" + + if self.verbose: + io_utils.print_msg( + f"Starting LiteRT export for {self.model.__class__.__name__}" + ) + + # Get export parameter based on model type + param = self._get_export_param() + + # Ensure model is built + self._ensure_model_built(param) + + # Get input signature + input_signature = self.config.get_input_signature(param) + + # Get adapter class type for this model + adapter_type = self._get_model_adapter_class() + + # Create a wrapper that adapts the Keras-Hub model to work with Keras + # LiteRT exporter + wrapped_model = self._create_export_wrapper(param, adapter_type) + + # Convert input signature to list format expected by Keras exporter + if isinstance(input_signature, dict): + # Extract specs in the order expected by the model + signature_list = [] + for input_name in self.config.EXPECTED_INPUTS: + if input_name in input_signature: + signature_list.append(input_signature[input_name]) + input_signature = signature_list + + # Create the Keras LiteRT exporter with the wrapped model + keras_exporter = KerasLitertExporter( + wrapped_model, + input_signature=input_signature, + aot_compile_targets=self.aot_compile_targets, + verbose=self.verbose, + **self.export_kwargs, + ) + + try: + # Export using the Keras exporter + keras_exporter.export(filepath) + + if self.verbose: + io_utils.print_msg( + f"Export completed successfully to: {filepath}" + ) + + except Exception as e: + raise RuntimeError(f"LiteRT export failed: {e}") from e + + def _create_export_wrapper(self, param, adapter_type): + """Create a wrapper model that handles the input structure conversion. + + This creates a type-specific adapter that converts between the + list-based inputs that Keras LiteRT exporter provides and the format + expected by Keras-Hub models. 
+ + Args: + param: The parameter for input signature (sequence_length for text + models, image_size for image models, or None for other types). + adapter_type: `str`. The type of adapter to use - "text", "image", + or "base". + """ + + class BaseModelAdapter(keras.Model): + """Base adapter for Keras-Hub models.""" + + def __init__( + self, keras_hub_model, expected_inputs, input_signature + ): + super().__init__() + self.keras_hub_model = keras_hub_model + self.expected_inputs = expected_inputs + self.input_signature = input_signature + + # Create Input layers based on the input signature + self._input_layers = [] + for input_name in expected_inputs: + if input_name in input_signature: + spec = input_signature[input_name] + input_layer = keras.layers.Input( + shape=spec.shape[1:], # Remove batch dimension + dtype=spec.dtype, + name=input_name, + ) + self._input_layers.append(input_layer) + + # Store references to the original model's variables + self._variables = keras_hub_model.variables + self._trainable_variables = keras_hub_model.trainable_variables + self._non_trainable_variables = ( + keras_hub_model.non_trainable_variables + ) + + @property + def variables(self): + return self._variables + + @property + def trainable_variables(self): + return self._trainable_variables + + @property + def non_trainable_variables(self): + return self._non_trainable_variables + + @property + def inputs(self): + """Return the input layers for the Keras exporter to use.""" + return self._input_layers + + def get_config(self): + """Return the configuration of the wrapped model.""" + return self.keras_hub_model.get_config() + + class TextModelAdapter(BaseModelAdapter): + """Adapter for text models (CausalLM, TextClassifier, Seq2SeqLM). + + Text models expect dictionary inputs with keys like 'token_ids' + and 'padding_mask'. 
+ """ + + def call(self, inputs, training=None, mask=None): + """Convert list inputs to dictionary format for text models.""" + if isinstance(inputs, dict): + return self.keras_hub_model(inputs, training=training) + + # Convert to list if needed + if not isinstance(inputs, (list, tuple)): + inputs = [inputs] + + # Map inputs to expected dictionary keys + input_dict = {} + for i, input_name in enumerate(self.expected_inputs): + if i < len(inputs): + input_dict[input_name] = inputs[i] + + return self.keras_hub_model(input_dict, training=training) + + class ImageModelAdapter(BaseModelAdapter): + """Adapter for image models (ImageClassifier, ObjectDetector, + ImageSegmenter). + + Image models typically expect a single tensor input but may also + accept dictionary format with 'images' key. + """ + + def call(self, inputs, training=None, mask=None): + """Convert list inputs to format expected by image models.""" + if isinstance(inputs, dict): + return self.keras_hub_model(inputs, training=training) + + # Convert to list if needed + if not isinstance(inputs, (list, tuple)): + inputs = [inputs] + + # Most image models expect a single tensor input + if len(self.expected_inputs) == 1: + return self.keras_hub_model(inputs[0], training=training) + + # If multiple inputs, use dictionary format + input_dict = {} + for i, input_name in enumerate(self.expected_inputs): + if i < len(inputs): + input_dict[input_name] = inputs[i] + + return self.keras_hub_model(input_dict, training=training) + + # Select the appropriate adapter based on adapter_type + if adapter_type == "text": + adapter_class = TextModelAdapter + elif adapter_type == "image": + adapter_class = ImageModelAdapter + else: + # For other model types (audio, multimodal, custom, etc.) 
+            adapter_class = BaseModelAdapter
+
+        return adapter_class(
+            self.model,
+            self.config.EXPECTED_INPUTS,
+            self.config.get_input_signature(param),
+        )
+
+
+# Convenience function for direct export
+def export_litert(model, filepath, **kwargs):
+    """Export a Keras-Hub model to LiteRT format.
+
+    This is a convenience function that automatically detects the model type
+    and exports it using the appropriate configuration.
+
+    Args:
+        model: `keras.Model`. The Keras-Hub model to export.
+        filepath: `str`. Path to save to; `.tflite` is appended if missing.
+        **kwargs: `dict`. Additional arguments passed to exporter.
+    """
+    from keras_hub.src.export.configs import get_exporter_config
+
+    # Get the appropriate configuration for this model
+    config = get_exporter_config(model)
+
+    # Create and use the LiteRT exporter
+    exporter = LiteRTExporter(config, **kwargs)
+    exporter.export(filepath)
diff --git a/keras_hub/src/export/litert_models_test.py b/keras_hub/src/export/litert_models_test.py
new file mode 100644
index 0000000000..5919b7e717
--- /dev/null
+++ b/keras_hub/src/export/litert_models_test.py
@@ -0,0 +1,372 @@
+"""Tests for LiteRT export with specific production models.
+
+This test suite validates LiteRT export functionality for production
+model presets including CausalLM, ImageClassifier, ObjectDetector,
+and ImageSegmenter models.
+
+Each test validates export correctness by:
+1. Loading a model from preset
+2. Exporting it to LiteRT format
+3. Running numerical verification to ensure exported model produces
+   equivalent outputs
+4. Comparing outputs statistically against predefined thresholds
+
+This ensures that exported models maintain functional correctness and
+numerical stability.
+""" + +import gc + +import keras +import numpy as np +import pytest + +from keras_hub.src.models.gemma3.gemma3_causal_lm import Gemma3CausalLM +from keras_hub.src.models.gpt2.gpt2_causal_lm import GPT2CausalLM +from keras_hub.src.models.image_classifier import ImageClassifier +from keras_hub.src.models.image_segmenter import ImageSegmenter +from keras_hub.src.models.llama3.llama3_causal_lm import Llama3CausalLM +from keras_hub.src.models.object_detector import ObjectDetector +from keras_hub.src.tests.test_case import TestCase + +# Model configurations for testing +CAUSAL_LM_MODELS = [ + { + "preset": "llama3.2_1b", + "model_class": Llama3CausalLM, + "sequence_length": 128, + "test_name": "llama3_2_1b", + "output_thresholds": {"*": {"max": 1e-3, "mean": 1e-5}}, + }, + { + "preset": "gemma3_1b", + "model_class": Gemma3CausalLM, + "sequence_length": 128, + "test_name": "gemma3_1b", + "output_thresholds": {"*": {"max": 1e-3, "mean": 3e-5}}, + }, + { + "preset": "gpt2_base_en", + "model_class": GPT2CausalLM, + "sequence_length": 128, + "test_name": "gpt2_base_en", + "output_thresholds": {"*": {"max": 5e-4, "mean": 5e-5}}, + }, +] + +IMAGE_CLASSIFIER_MODELS = [ + { + "preset": "resnet_50_imagenet", + "test_name": "resnet_50", + "input_range": (0.0, 1.0), + "output_thresholds": {"*": {"max": 5e-5, "mean": 1e-5}}, + }, + { + "preset": "efficientnet_b0_ra_imagenet", + "test_name": "efficientnet_b0", + "input_range": (0.0, 1.0), + "output_thresholds": {"*": {"max": 5e-5, "mean": 1e-5}}, + }, + { + "preset": "densenet_121_imagenet", + "test_name": "densenet_121", + "input_range": (0.0, 1.0), + "output_thresholds": {"*": {"max": 5e-5, "mean": 1e-5}}, + }, + { + "preset": "mobilenet_v3_small_100_imagenet", + "test_name": "mobilenet_v3_small", + "input_range": (0.0, 1.0), + "output_thresholds": {"*": {"max": 5e-5, "mean": 1e-5}}, + }, +] + +OBJECT_DETECTOR_MODELS = [ + { + "preset": "dfine_small_coco", + "test_name": "dfine_small", + "input_range": (0.0, 1.0), + 
"output_thresholds": { + "intermediate_predicted_corners": {"max": 5.0, "mean": 0.05}, + "intermediate_logits": {"max": 5.0, "mean": 0.1}, + "enc_topk_logits": {"max": 5.0, "mean": 0.03}, + "logits": {"max": 2.0, "mean": 0.03}, + "*": {"max": 1.0, "mean": 0.03}, + }, + }, + { + "preset": "dfine_medium_coco", + "test_name": "dfine_medium", + "input_range": (0.0, 1.0), + "output_thresholds": { + "intermediate_predicted_corners": {"max": 50.0, "mean": 0.15}, + "intermediate_logits": {"max": 5.0, "mean": 0.1}, + "enc_topk_logits": {"max": 5.0, "mean": 0.03}, + "logits": {"max": 2.0, "mean": 0.03}, + "*": {"max": 1.0, "mean": 0.03}, + }, + }, + { + "preset": "retinanet_resnet50_fpn_coco", + "test_name": "retinanet_resnet50", + "input_range": (0.0, 1.0), + "output_thresholds": { + "enc_topk_logits": {"max": 5.0, "mean": 0.03}, + "logits": {"max": 2.0, "mean": 0.03}, + "*": {"max": 1.0, "mean": 0.03}, + }, + }, +] + +IMAGE_SEGMENTER_MODELS = [ + { + "preset": "deeplab_v3_plus_resnet50_pascalvoc", + "test_name": "deeplab_v3_plus", + "input_range": (0.0, 1.0), + "output_thresholds": {"*": {"max": 1.0, "mean": 1e-2}}, + }, +] + + +@pytest.mark.skipif( + keras.backend.backend() != "tensorflow", + reason="LiteRT export only supports TensorFlow backend.", +) +@pytest.mark.parametrize( + "model_config", + CAUSAL_LM_MODELS, + ids=lambda x: f"{x['test_name']}-{x['preset']}", +) +def test_causal_lm_litert_export(model_config): + """Test LiteRT export for CausalLM models. + + Validates that the model can be successfully exported to LiteRT format + and produces numerically equivalent outputs. 
+ """ + preset = model_config["preset"] + model_class = model_config["model_class"] + sequence_length = model_config["sequence_length"] + output_thresholds = model_config.get( + "output_thresholds", {"*": {"max": 3e-5, "mean": 3e-6}} + ) + + model = None + try: + # Load model from preset once + model = model_class.from_preset(preset, load_weights=True) + + # Set sequence length before export + model.preprocessor.sequence_length = sequence_length + + # Get vocab_size from the loaded model + vocab_size = model.backbone.vocabulary_size + + # Prepare test inputs with fixed random seed for reproducibility + np.random.seed(42) + input_data = { + "token_ids": np.random.randint( + 1, vocab_size, size=(1, sequence_length), dtype=np.int32 + ), + "padding_mask": np.ones((1, sequence_length), dtype=np.int32), + } + + # Validate LiteRT export with numerical verification + TestCase().run_litert_export_test( + model=model, + input_data=input_data, + expected_output_shape=(1, sequence_length, vocab_size), + comparison_mode="statistical", + output_thresholds=output_thresholds, + ) + + finally: + # Clean up model, free memory + if model is not None: + del model + gc.collect() + + +@pytest.mark.skipif( + keras.backend.backend() != "tensorflow", + reason="LiteRT export only supports TensorFlow backend.", +) +@pytest.mark.parametrize( + "model_config", + IMAGE_CLASSIFIER_MODELS, + ids=lambda x: f"{x['test_name']}-{x['preset']}", +) +def test_image_classifier_litert_export(model_config): + """Test LiteRT export for ImageClassifier models. + + Validates that the model can be successfully exported to LiteRT format + and produces numerically equivalent outputs. 
+ """ + preset = model_config["preset"] + input_range = model_config.get("input_range", (0.0, 1.0)) + output_thresholds = model_config.get( + "output_thresholds", {"*": {"max": 1e-4, "mean": 4e-5}} + ) + + model = None + try: + # Load model once + model = ImageClassifier.from_preset(preset) + + # Get actual image size from model preprocessor or backbone + image_size = getattr(model.preprocessor, "image_size", None) + if image_size is None and hasattr(model.backbone, "image_shape"): + image_shape = model.backbone.image_shape + if isinstance(image_shape, (list, tuple)) and len(image_shape) >= 2: + image_size = tuple(image_shape[:2]) + elif isinstance(image_shape, int): + image_size = (image_shape, image_shape) + + if image_size is None: + raise ValueError(f"Could not determine image size for {preset}") + + input_shape = image_size + (3,) # Add channels + + # Prepare test input + test_image = np.random.uniform( + input_range[0], input_range[1], size=(1,) + input_shape + ).astype(np.float32) + + # Validate LiteRT export with numerical verification + TestCase().run_litert_export_test( + model=model, + input_data=test_image, + expected_output_shape=None, # Output shape varies by model + comparison_mode="statistical", + output_thresholds=output_thresholds, + ) + + finally: + # Clean up model, free memory + if model is not None: + del model + gc.collect() + + +@pytest.mark.skipif( + keras.backend.backend() != "tensorflow", + reason="LiteRT export only supports TensorFlow backend.", +) +@pytest.mark.parametrize( + "model_config", + OBJECT_DETECTOR_MODELS, + ids=lambda x: f"{x['test_name']}-{x['preset']}", +) +def test_object_detector_litert_export(model_config): + """Test LiteRT export for ObjectDetector models. + + Validates that the model can be successfully exported to LiteRT format + and produces numerically equivalent outputs. 
+ """ + preset = model_config["preset"] + input_range = model_config.get("input_range", (0.0, 1.0)) + output_thresholds = model_config.get( + "output_thresholds", {"*": {"max": 1.0, "mean": 0.02}} + ) + + model = None + try: + # Load model once + model = ObjectDetector.from_preset(preset) + + # Get actual image size from model preprocessor or backbone + image_size = getattr(model.preprocessor, "image_size", None) + if image_size is None and hasattr(model.backbone, "image_shape"): + image_shape = model.backbone.image_shape + if isinstance(image_shape, (list, tuple)) and len(image_shape) >= 2: + image_size = tuple(image_shape[:2]) + elif isinstance(image_shape, int): + image_size = (image_shape, image_shape) + + if image_size is None: + raise ValueError(f"Could not determine image size for {preset}") + + # ObjectDetector typically needs images (H, W, 3) and image_shape (H, W) + test_inputs = { + "images": np.random.uniform( + input_range[0], + input_range[1], + size=(1,) + image_size + (3,), + ).astype(np.float32), + "image_shape": np.array([image_size], dtype=np.int32), + } + + # Validate LiteRT export with numerical verification + TestCase().run_litert_export_test( + model=model, + input_data=test_inputs, + expected_output_shape=None, # Output varies by model + comparison_mode="statistical", + output_thresholds=output_thresholds, + ) + + finally: + # Clean up model, free memory + if model is not None: + del model + gc.collect() + + +@pytest.mark.skipif( + keras.backend.backend() != "tensorflow", + reason="LiteRT export only supports TensorFlow backend.", +) +@pytest.mark.parametrize( + "model_config", + IMAGE_SEGMENTER_MODELS, + ids=lambda x: f"{x['test_name']}-{x['preset']}", +) +def test_image_segmenter_litert_export(model_config): + """Test LiteRT export for ImageSegmenter models. + + Validates that the model can be successfully exported to LiteRT format + and produces numerically equivalent outputs. 
+ """ + preset = model_config["preset"] + input_range = model_config.get("input_range", (0.0, 1.0)) + output_thresholds = model_config.get( + "output_thresholds", {"*": {"max": 1.0, "mean": 1e-2}} + ) + + model = None + try: + # Load model once + model = ImageSegmenter.from_preset(preset) + + # Get actual image size from model preprocessor or backbone + image_size = getattr(model.preprocessor, "image_size", None) + if image_size is None and hasattr(model.backbone, "image_shape"): + image_shape = model.backbone.image_shape + if isinstance(image_shape, (list, tuple)) and len(image_shape) >= 2: + image_size = tuple(image_shape[:2]) + elif isinstance(image_shape, int): + image_size = (image_shape, image_shape) + + if image_size is None: + raise ValueError(f"Could not determine image size for {preset}") + + input_shape = image_size + (3,) # Add channels + + # Prepare test input + test_image = np.random.uniform( + input_range[0], input_range[1], size=(1,) + input_shape + ).astype(np.float32) + + # Validate LiteRT export with numerical verification + TestCase().run_litert_export_test( + model=model, + input_data=test_image, + expected_output_shape=None, # Output shape varies by model + comparison_mode="statistical", + output_thresholds=output_thresholds, + ) + + finally: + # Clean up model, free memory + if model is not None: + del model + gc.collect() diff --git a/keras_hub/src/export/litert_test.py b/keras_hub/src/export/litert_test.py new file mode 100644 index 0000000000..47ba7f7b4a --- /dev/null +++ b/keras_hub/src/export/litert_test.py @@ -0,0 +1,448 @@ +"""Tests for LiteRT export functionality.""" + +import os +import shutil +import tempfile + +import keras +import numpy as np +import pytest + +from keras_hub.src.export.litert import LiteRTExporter +from keras_hub.src.tests.test_case import TestCase + +# Lazy import LiteRT interpreter with fallback logic +LITERT_AVAILABLE = False +if keras.backend.backend() == "tensorflow": + try: + from ai_edge_litert.interpreter 
import Interpreter + + LITERT_AVAILABLE = True + except ImportError: + import tensorflow as tf + + Interpreter = tf.lite.Interpreter + + +@pytest.mark.skipif( + keras.backend.backend() != "tensorflow", + reason="LiteRT export only supports TensorFlow backend.", +) +class LiteRTExporterTest(TestCase): + """Tests for LiteRTExporter class.""" + + def setUp(self): + """Set up test fixtures.""" + super().setUp() + self.temp_dir = tempfile.mkdtemp() + + def tearDown(self): + """Clean up test fixtures.""" + super().tearDown() + # Clean up temporary files + if os.path.exists(self.temp_dir): + shutil.rmtree(self.temp_dir) + + def test_exporter_init_without_litert_available(self): + """Test that LiteRTExporter raises error if Keras LiteRT unavailable.""" + # We can't easily test this without mocking, so we'll skip + self.skipTest("Requires mocking KERAS_LITE_RT_AVAILABLE") + + def test_exporter_init_with_parameters(self): + """Test LiteRTExporter initialization with custom parameters.""" + from keras_hub.src.export.configs import CausalLMExporterConfig + from keras_hub.src.models.causal_lm import CausalLM + + # Create a minimal mock model + class MockCausalLM(CausalLM): + def __init__(self): + keras.Model.__init__(self) + self.preprocessor = None + self.dense = keras.layers.Dense(10) + + def call(self, inputs): + return self.dense(inputs["token_ids"]) + + model = MockCausalLM() + config = CausalLMExporterConfig(model) + exporter = LiteRTExporter( + config, + max_sequence_length=256, + verbose=True, + custom_param="test", + ) + + self.assertEqual(exporter.max_sequence_length, 256) + self.assertTrue(exporter.verbose) + self.assertEqual(exporter.export_kwargs["custom_param"], "test") + + +@pytest.mark.skipif( + keras.backend.backend() != "tensorflow", + reason="LiteRT export only supports TensorFlow backend.", +) +class CausalLMExportTest(TestCase): + """Tests for exporting CausalLM models to LiteRT.""" + + def setUp(self): + """Set up test fixtures.""" + super().setUp() + 
self.temp_dir = tempfile.mkdtemp() + + def tearDown(self): + """Clean up test fixtures.""" + super().tearDown() + import shutil + + if os.path.exists(self.temp_dir): + shutil.rmtree(self.temp_dir) + + def test_export_causal_lm_mock(self): + """Test exporting a mock CausalLM model.""" + from keras_hub.src.models.causal_lm import CausalLM + + # Create a minimal mock CausalLM + class SimpleCausalLM(CausalLM): + def __init__(self): + super().__init__() + self.preprocessor = None + self.embedding = keras.layers.Embedding(1000, 64) + self.dense = keras.layers.Dense(1000) + + def call(self, inputs): + if isinstance(inputs, dict): + token_ids = inputs["token_ids"] + else: + token_ids = inputs + x = self.embedding(token_ids) + return self.dense(x) + + model = SimpleCausalLM() + model.build( + input_shape={ + "token_ids": (None, 128), + "padding_mask": (None, 128), + } + ) + + # Export using the model's export method + export_path = os.path.join(self.temp_dir, "test_causal_lm") + model.export(export_path, format="litert") + + # Verify the file was created + tflite_path = export_path + ".tflite" + self.assertTrue(os.path.exists(tflite_path)) + + # Load and verify the exported model + interpreter = Interpreter(model_path=tflite_path) + interpreter.allocate_tensors() + + # Delete the TFLite file after loading to free disk space + if os.path.exists(tflite_path): + os.remove(tflite_path) + + input_details = interpreter.get_input_details() + output_details = interpreter.get_output_details() + + # Verify we have the expected inputs + self.assertEqual(len(input_details), 2) + + # Create test inputs with dtypes from the interpreter + test_token_ids = np.random.randint(0, 1000, (1, 128)).astype( + input_details[0]["dtype"] + ) + test_padding_mask = np.ones((1, 128), dtype=input_details[1]["dtype"]) + + # Set inputs and run inference + interpreter.set_tensor(input_details[0]["index"], test_token_ids) + interpreter.set_tensor(input_details[1]["index"], test_padding_mask) + 
interpreter.invoke() + + # Get output + output = interpreter.get_tensor(output_details[0]["index"]) + self.assertEqual(output.shape[0], 1) # Batch size + self.assertEqual(output.shape[1], 128) # Sequence length + self.assertEqual(output.shape[2], 1000) # Vocab size + + # Clean up interpreter, free memory + del interpreter + import gc + + gc.collect() + + +@pytest.mark.skipif( + keras.backend.backend() != "tensorflow", + reason="LiteRT export only supports TensorFlow backend.", +) +class ImageClassifierExportTest(TestCase): + """Tests for exporting ImageClassifier models to LiteRT.""" + + def setUp(self): + """Set up test fixtures.""" + super().setUp() + self.temp_dir = tempfile.mkdtemp() + + def tearDown(self): + """Clean up test fixtures.""" + super().tearDown() + import shutil + + if os.path.exists(self.temp_dir): + shutil.rmtree(self.temp_dir) + + def test_export_image_classifier_mock(self): + """Test exporting a mock ImageClassifier model.""" + from keras_hub.src.models.backbone import Backbone + from keras_hub.src.models.image_classifier import ImageClassifier + + # Create a minimal mock Backbone + class SimpleBackbone(Backbone): + def __init__(self): + inputs = keras.layers.Input(shape=(224, 224, 3)) + x = keras.layers.Conv2D(32, 3, padding="same")(inputs) + # Don't reduce dimensions - let ImageClassifier handle pooling + outputs = x + super().__init__(inputs=inputs, outputs=outputs) + + # Create ImageClassifier with the mock backbone + backbone = SimpleBackbone() + model = ImageClassifier(backbone=backbone, num_classes=10) + + # Export using the model's export method + export_path = os.path.join(self.temp_dir, "test_image_classifier") + model.export(export_path, format="litert") + + # Verify the file was created + tflite_path = export_path + ".tflite" + self.assertTrue(os.path.exists(tflite_path)) + + # Load and verify the exported model + interpreter = Interpreter(model_path=tflite_path) + interpreter.allocate_tensors() + + # Delete the TFLite file after 
loading to free disk space + if os.path.exists(tflite_path): + os.remove(tflite_path) + + input_details = interpreter.get_input_details() + output_details = interpreter.get_output_details() + + # Verify we have the expected input + self.assertEqual(len(input_details), 1) + + # Create test input with dtype from the interpreter + test_image = np.random.uniform(0.0, 1.0, (1, 224, 224, 3)).astype( + input_details[0]["dtype"] + ) + + # Set input and run inference + interpreter.set_tensor(input_details[0]["index"], test_image) + interpreter.invoke() + + # Get output + output = interpreter.get_tensor(output_details[0]["index"]) + self.assertEqual(output.shape[0], 1) # Batch size + self.assertEqual(output.shape[1], 10) # Number of classes + + # Clean up interpreter, free memory + del interpreter + import gc + + gc.collect() + + +@pytest.mark.skipif( + keras.backend.backend() != "tensorflow", + reason="LiteRT export only supports TensorFlow backend.", +) +class TextClassifierExportTest(TestCase): + """Tests for exporting TextClassifier models to LiteRT.""" + + def setUp(self): + """Set up test fixtures.""" + super().setUp() + self.temp_dir = tempfile.mkdtemp() + + def tearDown(self): + """Clean up test fixtures.""" + super().tearDown() + import shutil + + if os.path.exists(self.temp_dir): + shutil.rmtree(self.temp_dir) + + def test_export_text_classifier_mock(self): + """Test exporting a mock TextClassifier model.""" + from keras_hub.src.models.text_classifier import TextClassifier + + # Create a minimal mock TextClassifier + class SimpleTextClassifier(TextClassifier): + def __init__(self): + super().__init__() + self.preprocessor = None + self.embedding = keras.layers.Embedding(5000, 64) + self.pool = keras.layers.GlobalAveragePooling1D() + self.dense = keras.layers.Dense(5) # 5 classes + + def call(self, inputs): + if isinstance(inputs, dict): + token_ids = inputs["token_ids"] + else: + token_ids = inputs + x = self.embedding(token_ids) + x = self.pool(x) + return 
self.dense(x) + + model = SimpleTextClassifier() + model.build( + input_shape={ + "token_ids": (None, 128), + "padding_mask": (None, 128), + } + ) + + # Export using the model's export method + export_path = os.path.join(self.temp_dir, "test_text_classifier") + model.export(export_path, format="litert") + + # Verify the file was created + tflite_path = export_path + ".tflite" + self.assertTrue(os.path.exists(tflite_path)) + + # Load and verify the exported model + interpreter = Interpreter(model_path=tflite_path) + interpreter.allocate_tensors() + + # Delete the TFLite file after loading to free disk space + if os.path.exists(tflite_path): + os.remove(tflite_path) + + output_details = interpreter.get_output_details() + + # Verify output shape (batch, num_classes) + self.assertEqual(len(output_details), 1) + + # Clean up interpreter, free memory + del interpreter + import gc + + gc.collect() + + +@pytest.mark.skipif( + keras.backend.backend() != "tensorflow", + reason="LiteRT export only supports TensorFlow backend.", +) +class ExportNumericalVerificationTest(TestCase): + """Tests for numerical accuracy of exported models.""" + + def test_simple_model_numerical_accuracy(self): + """Test that exported model produces similar outputs to original.""" + # Create a simple sequential model + model = keras.Sequential( + [ + keras.layers.Dense(10, activation="relu", input_shape=(5,)), + keras.layers.Dense(3, activation="softmax"), + ] + ) + + # Prepare test input + test_input = np.random.random((1, 5)).astype(np.float32) + + # Use standardized test from TestCase + # Note: This assumes the model has an export() method + # If not available, the test will be skipped + if not hasattr(model, "export"): + self.skipTest("model.export() not available") + + self.run_litert_export_test( + cls=keras.Sequential, + init_kwargs={ + "layers": [ + keras.layers.Dense(10, activation="relu", input_shape=(5,)), + keras.layers.Dense(3, activation="softmax"), + ] + }, + input_data=test_input, + 
expected_output_shape=(1, 3), + comparison_mode="strict", + ) + + def test_dict_input_model_numerical_accuracy(self): + """Test numerical accuracy for models with dictionary inputs.""" + + # Define a custom model class for testing + class DictInputModel(keras.Model): + def __init__(self): + super().__init__() + self.concat = keras.layers.Concatenate() + self.dense = keras.layers.Dense(5) + + def call(self, inputs): + x = self.concat([inputs["input1"], inputs["input2"]]) + return self.dense(x) + + # Prepare test inputs + test_inputs = { + "input1": np.random.random((1, 10)).astype(np.float32), + "input2": np.random.random((1, 10)).astype(np.float32), + } + + # Use standardized test from TestCase + self.run_litert_export_test( + cls=DictInputModel, + init_kwargs={}, + input_data=test_inputs, + expected_output_shape=(1, 5), + comparison_mode="strict", + ) + + +@pytest.mark.skipif( + keras.backend.backend() != "tensorflow", + reason="LiteRT export only supports TensorFlow backend.", +) +class ExportErrorHandlingTest(TestCase): + """Tests for error handling in export process.""" + + def setUp(self): + """Set up test fixtures.""" + super().setUp() + self.temp_dir = tempfile.mkdtemp() + + def tearDown(self): + """Clean up test fixtures.""" + super().tearDown() + import shutil + + if os.path.exists(self.temp_dir): + shutil.rmtree(self.temp_dir) + + def test_export_to_invalid_path(self): + """Test that export with invalid path raises appropriate error.""" + if not hasattr(keras.Model, "export"): + self.skipTest("model.export() not available") + + model = keras.Sequential([keras.layers.Dense(10)]) + + # Try to export to a path that doesn't exist and can't be created + invalid_path = "/nonexistent/deeply/nested/path/model" + + with self.assertRaises(Exception): + model.export(invalid_path, format="litert") + + def test_export_unbuilt_model(self): + """Test exporting an unbuilt model.""" + if not hasattr(keras.Model, "export"): + self.skipTest("model.export() not available") + 
+ model = keras.Sequential([keras.layers.Dense(10, input_shape=(5,))]) + + # Model is not built yet (no explicit build() call) + # Export should still work by building the model + export_path = os.path.join(self.temp_dir, "unbuilt_model.tflite") + model.export(export_path, format="litert") + + # Should succeed + self.assertTrue(os.path.exists(export_path)) diff --git a/keras_hub/src/models/__init__.py b/keras_hub/src/models/__init__.py index e69de29bb2..1c02bb93f3 100644 --- a/keras_hub/src/models/__init__.py +++ b/keras_hub/src/models/__init__.py @@ -0,0 +1,4 @@ +"""Keras-Hub models module. + +This module contains all the task and backbone models available in Keras-Hub. +""" diff --git a/keras_hub/src/models/backbone.py b/keras_hub/src/models/backbone.py index 41bccf04b3..3cc9bcda0e 100644 --- a/keras_hub/src/models/backbone.py +++ b/keras_hub/src/models/backbone.py @@ -289,3 +289,113 @@ def export_to_transformers(self, path): ) export_backbone(self, path) + + def _get_save_spec(self, dynamic_batch=True): + """Compatibility shim for Keras/TensorFlow saving utilities. + + TensorFlow's SavedModel / TFLite export paths expect a + `_get_save_spec` method on subclassed models. In some runtime + combinations this method may not be present on the MRO for + our `Backbone` subclass; add a small shim that first delegates to + the superclass, and falls back to constructing simple + `tf.TensorSpec` objects from the functional `inputs` if needed. + + Args: + dynamic_batch: whether to set the batch dimension to `None`. + + Returns: + A TensorSpec, list or dict mirroring the model inputs, or + `None` when specs cannot be inferred. + """ + # Prefer the base implementation if available. + try: + return super()._get_save_spec(dynamic_batch) + except AttributeError: + # Fall back to building specs from `self.inputs`. 
+ try: + from tensorflow import TensorSpec + except (ImportError, ModuleNotFoundError): + return None + + inputs = getattr(self, "inputs", None) + if inputs is None: + return None + + def _make_spec(t): + # t is a tf.Tensor-like object + shape = list(t.shape) + if dynamic_batch and len(shape) > 0: + shape[0] = None + # Convert to tuple for TensorSpec + try: + name = getattr(t, "name", None) + return TensorSpec( + shape=tuple(shape), dtype=t.dtype, name=name + ) + except (ImportError, ModuleNotFoundError): + return None + + # Handle dict/list/single tensor inputs + if isinstance(inputs, dict): + return {k: _make_spec(v) for k, v in inputs.items()} + if isinstance(inputs, (list, tuple)): + return [_make_spec(t) for t in inputs] + return _make_spec(inputs) + + def _trackable_children(self, save_type=None, **kwargs): + """Override to prevent _DictWrapper issues during TensorFlow export. + + This method filters out problematic _DictWrapper objects that cause + TypeError during SavedModel introspection, while preserving all + essential trackable components. + """ + children = super()._trackable_children(save_type, **kwargs) + + # Import _DictWrapper safely + # WARNING: This uses a private TensorFlow API (_DictWrapper from + # tensorflow.python.trackable.data_structures). This API is not + # guaranteed to be stable and may change in future TensorFlow versions. + # If this breaks, we may need to find an alternative approach or pin + # the TensorFlow version more strictly. 
+ try: + from tensorflow.python.trackable.data_structures import _DictWrapper + except ImportError: + return children + + clean_children = {} + for name, child in children.items(): + # Handle _DictWrapper objects + if isinstance(child, _DictWrapper): + try: + # For list-like _DictWrapper (e.g., transformer_layers) + if hasattr(child, "_data") and isinstance( + child._data, list + ): + # Create a clean list of the trackable items + clean_list = [ + item + for item in child._data + if hasattr(item, "_trackable_children") + ] + if clean_list: + clean_children[name] = clean_list + # For dict-like _DictWrapper + elif hasattr(child, "_data") and isinstance( + child._data, dict + ): + clean_dict = { + k: v + for k, v in child._data.items() + if hasattr(v, "_trackable_children") + } + if clean_dict: + clean_children[name] = clean_dict + # Skip if we can't unwrap safely + except (AttributeError, TypeError): + # Skip problematic _DictWrapper objects + continue + else: + # Keep non-_DictWrapper children as-is + clean_children[name] = child + + return clean_children diff --git a/keras_hub/src/models/task.py b/keras_hub/src/models/task.py index d273759b46..8fe39b3940 100644 --- a/keras_hub/src/models/task.py +++ b/keras_hub/src/models/task.py @@ -369,3 +369,64 @@ def add_layer(layer, info): print_fn=print_fn, **kwargs, ) + + def export(self, filepath, format="litert", verbose=False, **kwargs): + """Export the Keras-Hub model to the specified format. + + This method overrides `keras.Model.export()` to provide specialized + handling for Keras-Hub models with dictionary inputs. + + Args: + filepath: `str`. Path where to save the exported model. + format: `str`. Export format. Currently supports "litert" for + TensorFlow Lite export, as well as other formats supported by + the parent `keras.Model.export()` method (e.g., + "tf_saved_model"). + verbose: `bool`. Whether to print verbose output during export. + Defaults to `False`. 
+ **kwargs: Additional arguments passed to the exporter. For LiteRT + export, common options include: + - `max_sequence_length`: Maximum sequence length for text models + - `optimizations`: List of TFLite optimizations (e.g., + `[tf.lite.Optimize.DEFAULT]`) + - `allow_custom_ops`: Whether to allow custom operations + - `enable_select_tf_ops`: Whether to enable TensorFlow Select + ops + + Examples: + + ```python + # Export a text model to TensorFlow Lite + model = keras_hub.models.GemmaCausalLM.from_preset("gemma_2b_en") + model.export("gemma_model.tflite", format="litert") + + # Export with custom sequence length + model.export( + "gemma_model.tflite", + format="litert", + max_sequence_length=512 + ) + + # Export with quantization + import tensorflow as tf + model.export( + "gemma_model_quantized.tflite", + format="litert", + optimizations=[tf.lite.Optimize.DEFAULT] + ) + ``` + """ + if format == "litert": + from keras_hub.src.export.configs import get_exporter_config + from keras_hub.src.export.litert import LiteRTExporter + + # Get the appropriate configuration for this model type + config = get_exporter_config(self) + + # Create and use the LiteRT exporter + kwargs["verbose"] = verbose + exporter = LiteRTExporter(config, **kwargs) + exporter.export(filepath) + else: + # Fall back to parent class (keras.Model) export for other formats + super().export(filepath, format=format, verbose=verbose, **kwargs) diff --git a/keras_hub/src/tests/test_case.py b/keras_hub/src/tests/test_case.py index 633f32cd5b..5d63e41f3e 100644 --- a/keras_hub/src/tests/test_case.py +++ b/keras_hub/src/tests/test_case.py @@ -1,7 +1,9 @@ +import gc import json import os import pathlib import re +import tempfile import keras import numpy as np @@ -433,6 +435,305 @@ def run_model_saving_test( restored_output = restored_model(input_data) self.assertAllClose(model_output, restored_output, atol=atol, rtol=rtol) + def _prepare_litert_inputs(self, input_data, input_details): + """Prepare input data 
for LiteRT interpreter.""" + if isinstance(input_data, dict): + input_values = list(input_data.values()) + litert_input_values = [] + for i, detail in enumerate(input_details): + if i < len(input_values): + converted_value = ops.convert_to_numpy( + input_values[i] + ).astype(detail["dtype"]) + litert_input_values.append(converted_value) + return input_data, litert_input_values + else: + litert_input_values = [ + ops.convert_to_numpy(input_data).astype( + input_details[0]["dtype"] + ) + ] + return input_data, litert_input_values + + def _get_litert_output(self, interpreter, output_details): + """Get output from LiteRT interpreter.""" + if len(output_details) == 1: + return interpreter.get_tensor(output_details[0]["index"]) + else: + litert_output = {} + for detail in output_details: + output_tensor = interpreter.get_tensor(detail["index"]) + litert_output[detail["name"]] = output_tensor + return litert_output + + def _verify_outputs( + self, + keras_output, + litert_output, + output_thresholds, + comparison_mode, + ): + """Verify numerical accuracy between Keras and LiteRT outputs. + + This method uses name-based matching with sorted keys to reliably + map LiteRT outputs to Keras outputs, even when LiteRT generates + generic names like "StatefulPartitionedCall:0". 
This approach: + - Provides better error messages with semantic output names + - Supports per-output threshold configurations + - Is more robust than relying on output ordering + """ + if isinstance(keras_output, dict) and isinstance(litert_output, dict): + # Map LiteRT generic keys to Keras semantic keys if needed + if all( + key.startswith("StatefulPartitionedCall") + for key in litert_output.keys() + ): + litert_keys_sorted = sorted(litert_output.keys()) + keras_keys_sorted = sorted(keras_output.keys()) + if len(litert_keys_sorted) != len(keras_keys_sorted): + self.fail( + f"Different number of outputs:\n" + f"Keras: {len(keras_keys_sorted)} outputs -\n" + f" {keras_keys_sorted}\n" + f"LiteRT: {len(litert_keys_sorted)} outputs -\n" + f" {litert_keys_sorted}" + ) + output_name_mapping = dict( + zip(litert_keys_sorted, keras_keys_sorted) + ) + mapped_litert = { + keras_key: litert_output[litert_key] + for litert_key, keras_key in output_name_mapping.items() + } + litert_output = mapped_litert + + common_keys = set(keras_output.keys()) & set(litert_output.keys()) + if not common_keys: + self.fail( + f"No common keys between Keras and LiteRT outputs.\n" + f"Keras keys: {list(keras_output.keys())}\n" + f"LiteRT keys: {list(litert_output.keys())}" + ) + + for key in sorted(common_keys): + keras_val_np = ops.convert_to_numpy(keras_output[key]) + litert_val = litert_output[key] + output_threshold = output_thresholds.get( + key, output_thresholds.get("*", {"max": 10.0, "mean": 0.1}) + ) + self._compare_outputs( + keras_val_np, + litert_val, + comparison_mode, + key, + output_threshold["max"], + output_threshold["mean"], + ) + elif not isinstance(keras_output, dict) and not isinstance( + litert_output, dict + ): + keras_output_np = ops.convert_to_numpy(keras_output) + output_threshold = output_thresholds.get( + "*", {"max": 10.0, "mean": 0.1} + ) + self._compare_outputs( + keras_output_np, + litert_output, + comparison_mode, + key=None, + 
max_threshold=output_threshold["max"], + mean_threshold=output_threshold["mean"], + ) + else: + keras_type = type(keras_output).__name__ + litert_type = type(litert_output).__name__ + self.fail( + f"Output structure mismatch: Keras returns " + f"{keras_type}, LiteRT returns {litert_type}" + ) + + def run_litert_export_test( + self, + cls=None, + init_kwargs=None, + input_data=None, + expected_output_shape=None, + model=None, + verify_numerical_accuracy=True, + comparison_mode="strict", + output_thresholds=None, + ): + """Export model to LiteRT format and verify numerical accuracy. + + Args: + cls: Model class to test (optional if model is provided) + init_kwargs: Initialization arguments for the model (optional + if model is provided) + input_data: Input data to test with (dict or tensor) + expected_output_shape: Expected output shape from LiteRT inference + model: Pre-created model instance (optional, if provided cls and + init_kwargs are ignored) + verify_numerical_accuracy: Whether to verify numerical accuracy + between Keras and LiteRT outputs. Set to False for preset + models with load_weights=False where outputs are random. + comparison_mode: "strict" (default) or "statistical". + - "strict": All elements must be within default tolerances + (1e-6) + - "statistical": Check mean/max absolute differences against + provided thresholds + output_thresholds: Dict mapping output names to threshold dicts + with "max" and "mean" keys. Use "*" as wildcard for defaults. 
+ Example: {"output1": {"max": 1e-4, "mean": 1e-5}, + "*": {"max": 1e-3, "mean": 1e-4}} + """ + if keras.backend.backend() != "tensorflow": + self.skipTest("LiteRT export only supports TensorFlow backend") + + try: + from ai_edge_litert.interpreter import Interpreter + except ImportError: + import tensorflow as tf + + Interpreter = tf.lite.Interpreter + + if output_thresholds is None: + output_thresholds = {"*": {"max": 10.0, "mean": 0.1}} + + if model is None: + if cls is None or init_kwargs is None: + raise ValueError( + "Either 'model' or 'cls' and 'init_kwargs' must be provided" + ) + model = cls(**init_kwargs) + _ = model(input_data) + + interpreter = None + try: + with tempfile.TemporaryDirectory() as temp_dir: + export_path = os.path.join(temp_dir, "model.tflite") + model.export(export_path, format="litert") + + self.assertTrue(os.path.exists(export_path)) + self.assertGreater(os.path.getsize(export_path), 0) + + interpreter = Interpreter(model_path=export_path) + interpreter.allocate_tensors() + os.remove(export_path) + + input_details = interpreter.get_input_details() + output_details = interpreter.get_output_details() + + keras_input_data, litert_input_values = ( + self._prepare_litert_inputs(input_data, input_details) + ) + + if verify_numerical_accuracy: + keras_output = model(keras_input_data) + + if isinstance(input_data, dict): + for i, detail in enumerate(input_details): + if i < len(litert_input_values): + interpreter.set_tensor( + detail["index"], litert_input_values[i] + ) + else: + interpreter.set_tensor( + input_details[0]["index"], litert_input_values[0] + ) + + interpreter.invoke() + + litert_output = self._get_litert_output( + interpreter, output_details + ) + + if expected_output_shape is not None: + self.assertEqual(litert_output.shape, expected_output_shape) + + if verify_numerical_accuracy: + self._verify_outputs( + keras_output, + litert_output, + output_thresholds, + comparison_mode, + ) + finally: + if interpreter is not None: + del 
interpreter + if model is not None and cls is not None: + del model + gc.collect() + + def _compare_outputs( + self, + keras_val, + litert_val, + comparison_mode, + key=None, + max_threshold=10.0, + mean_threshold=0.1, + ): + """Compare Keras and LiteRT outputs using specified comparison mode. + + Args: + keras_val: Keras model output (numpy array) + litert_val: LiteRT model output (numpy array) + comparison_mode: "strict" or "statistical" + key: Output key name for error messages (optional) + max_threshold: Maximum absolute difference threshold for statistical + mode + mean_threshold: Mean absolute difference threshold for statistical + mode + """ + key_msg = f" for output key '{key}'" if key else "" + + # Check if shapes are compatible for comparison + self.assertEqual( + keras_val.shape, + litert_val.shape, + f"Shape mismatch{key_msg}: Keras shape " + f"{keras_val.shape}, LiteRT shape {litert_val.shape}. " + "Numerical comparison cannot proceed due to incompatible shapes.", + ) + + if comparison_mode == "strict": + # Original strict element-wise comparison with default tolerances + self.assertAllClose( + keras_val, + litert_val, + atol=1e-6, + rtol=1e-6, + msg=f"Mismatch{key_msg}", + ) + elif comparison_mode == "statistical": + # Statistical comparison + + # Calculate element-wise absolute differences + abs_diff = np.abs(keras_val - litert_val) + + # Element-wise statistics + mean_abs_diff = np.mean(abs_diff) + max_abs_diff = np.max(abs_diff) + + # Assert reasonable bounds on statistical differences + self.assertLessEqual( + mean_abs_diff, + mean_threshold, + f"Mean absolute difference too high: {mean_abs_diff:.6e}" + f"{key_msg} (threshold: {mean_threshold})", + ) + self.assertLessEqual( + max_abs_diff, + max_threshold, + f"Max absolute difference too high: {max_abs_diff:.6e}" + f"{key_msg} (threshold: {max_threshold})", + ) + else: + raise ValueError( + f"Unknown comparison_mode: {comparison_mode}. 
Must be " + "'strict' or 'statistical'" + ) + def run_backbone_test( self, cls, diff --git a/requirements-tensorflow-cuda.txt b/requirements-tensorflow-cuda.txt index 94ab86d63f..5b366b8734 100644 --- a/requirements-tensorflow-cuda.txt +++ b/requirements-tensorflow-cuda.txt @@ -11,3 +11,6 @@ torchvision>=0.16.0 jax[cpu] -r requirements-common.txt + +# for litert export feature +ai-edge-litert