diff --git a/keras_rs/api/layers/__init__.py b/keras_rs/api/layers/__init__.py
index ea781a1..12b717c 100644
--- a/keras_rs/api/layers/__init__.py
+++ b/keras_rs/api/layers/__init__.py
@@ -4,6 +4,7 @@
 since your modifications would be overwritten.
 """
 
+from keras_rs.src.layers.modeling.feature_cross import FeatureCross
 from keras_rs.src.layers.retrieval.brute_force_retrieval import (
     BruteForceRetrieval,
 )
diff --git a/keras_rs/src/layers/modeling/__init__.py b/keras_rs/src/layers/modeling/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/keras_rs/src/layers/modeling/feature_cross.py b/keras_rs/src/layers/modeling/feature_cross.py
new file mode 100644
index 0000000..b1f64a1
--- /dev/null
+++ b/keras_rs/src/layers/modeling/feature_cross.py
@@ -0,0 +1,204 @@
+from typing import Any, Optional, Union
+
+import keras
+from keras import ops
+
+from keras_rs.src import types
+from keras_rs.src.api_export import keras_rs_export
+from keras_rs.src.utils.keras_utils import clone_initializer
+
+
+@keras_rs_export("keras_rs.layers.FeatureCross")
+class FeatureCross(keras.layers.Layer):
+    """FeatureCross layer in Deep & Cross Network (DCN).
+
+    A layer that creates explicit and bounded-degree feature interactions
+    efficiently. The `call` method accepts two inputs: `x0` contains the
+    original features; the second input `xi` is the output of the previous
+    `FeatureCross` layer in the stack, i.e., the output of the i-th
+    `FeatureCross` layer. For the first `FeatureCross` layer in the stack,
+    `x0 = xi`.
+
+    The output is `x_{i+1} = x0 .* (W * x_i + bias + diag_scale * x_i) + x_i`,
+    where `.*` denotes element-wise multiplication. `W` can be a full-rank
+    matrix, or a low-rank matrix `U*V` to reduce the computational cost, and
+    `diag_scale` increases the diagonal of `W` to improve training stability
+    (especially for the low-rank case).
+
+    Args:
+        projection_dim: int. Dimension for down-projecting the input to reduce
+            computational cost. If `None` (default), the full matrix, `W`
+            (with shape `(input_dim, input_dim)`), is used. Otherwise, a
+            low-rank matrix `W = U*V` will be used, where `U` is of shape
+            `(input_dim, projection_dim)` and `V` is of shape
+            `(projection_dim, input_dim)`. `projection_dim` needs to be
+            smaller than `input_dim // 2` to improve model efficiency. In
+            practice, we've observed that `projection_dim = input_dim // 4`
+            consistently preserves the accuracy of the full-rank version.
+        diag_scale: non-negative float. Used to increase the diagonal of the
+            kernel `W` by `diag_scale`, i.e., `W + diag_scale * I`, where `I`
+            is the identity matrix. Defaults to `0.0`.
+        use_bias: bool. Whether to add a bias term for this layer. Defaults to
+            `True`.
+        pre_activation: string or `keras.activations`. Activation applied to
+            the output matrix of the layer, before multiplication with the
+            input. Can be used to control the scale of the layer's outputs
+            and improve stability. Defaults to `None`.
+        kernel_initializer: string or `keras.initializers` initializer.
+            Initializer to use for the kernel matrix. Defaults to
+            `"glorot_uniform"`.
+        bias_initializer: string or `keras.initializers` initializer.
+            Initializer to use for the bias vector. Defaults to `"zeros"`.
+        kernel_regularizer: string or `keras.regularizers` regularizer.
+            Regularizer to use for the kernel matrix.
+        bias_regularizer: string or `keras.regularizers` regularizer.
+            Regularizer to use for the bias vector.
+
+    Example:
+
+    ```python
+    # After an embedding layer in a functional model:
+    input = keras.Input(shape=(), name='indices', dtype="int64")
+    x0 = keras.layers.Embedding(input_dim=32, output_dim=6)(input)
+    x1 = FeatureCross()(x0, x0)
+    x2 = FeatureCross()(x0, x1)
+    logits = keras.layers.Dense(units=10)(x2)
+    model = keras.Model(input, logits)
+    ```
+
+    References:
+    - [R. Wang et al.](https://arxiv.org/abs/2008.13535)
+    - [R. Wang et al.](https://arxiv.org/abs/1708.05123)
+    """
+
+    def __init__(
+        self,
+        projection_dim: Optional[int] = None,
+        diag_scale: Optional[float] = 0.0,
+        use_bias: bool = True,
+        pre_activation: Optional[Union[str, keras.layers.Activation]] = None,
+        kernel_initializer: Union[
+            str, keras.initializers.Initializer
+        ] = "glorot_uniform",
+        bias_initializer: Union[str, keras.initializers.Initializer] = "zeros",
+        kernel_regularizer: Union[
+            str, None, keras.regularizers.Regularizer
+        ] = None,
+        bias_regularizer: Union[
+            str, None, keras.regularizers.Regularizer
+        ] = None,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(**kwargs)
+
+        # Passed args.
+        self.projection_dim = projection_dim
+        self.diag_scale = diag_scale
+        self.use_bias = use_bias
+        self.pre_activation = keras.activations.get(pre_activation)
+        self.kernel_initializer = keras.initializers.get(kernel_initializer)
+        self.bias_initializer = keras.initializers.get(bias_initializer)
+        self.kernel_regularizer = keras.regularizers.get(kernel_regularizer)
+        self.bias_regularizer = keras.regularizers.get(bias_regularizer)
+
+        # Other args.
+        self.supports_masking = True
+
+        if self.diag_scale is not None and self.diag_scale < 0.0:
+            raise ValueError(
+                "`diag_scale` should be non-negative. Received: "
+                f"`diag_scale={self.diag_scale}`"
+            )
+
+    def build(self, input_shape: types.TensorShape) -> None:
+        last_dim = input_shape[-1]
+
+        if self.projection_dim is not None:
+            self.down_proj_dense = keras.layers.Dense(
+                units=self.projection_dim,
+                use_bias=False,
+                kernel_initializer=clone_initializer(self.kernel_initializer),
+                kernel_regularizer=self.kernel_regularizer,
+                dtype=self.dtype_policy,
+            )
+
+        self.dense = keras.layers.Dense(
+            units=last_dim,
+            activation=self.pre_activation,
+            use_bias=self.use_bias,
+            kernel_initializer=clone_initializer(self.kernel_initializer),
+            bias_initializer=clone_initializer(self.bias_initializer),
+            kernel_regularizer=self.kernel_regularizer,
+            bias_regularizer=self.bias_regularizer,
+            dtype=self.dtype_policy,
+        )
+
+        self.built = True
+
+    def call(
+        self, x0: types.Tensor, x: Optional[types.Tensor] = None
+    ) -> types.Tensor:
+        """Forward pass of the cross layer.
+
+        Args:
+            x0: a Tensor. The input to the cross layer. Rank-N tensor
+                with shape `(batch_size, ..., input_dim)`.
+            x: a Tensor. Optional. If provided, the layer will compute
+                crosses between `x0` and `x`. Otherwise, the layer will
+                compute crosses between `x0` and itself. Should have the
+                same shape as `x0`.
+
+        Returns:
+            Tensor of crosses, with the same shape as `x0`.
+        """
+
+        if x is None:
+            x = x0
+
+        if x0.shape != x.shape:
+            raise ValueError(
+                "`x0` and `x` should have the same shape. Received: "
+                f"`x.shape` = {x.shape}, `x0.shape` = {x0.shape}"
+            )
+
+        # Project to a lower dimension if `projection_dim` is set.
+ if self.projection_dim is None: + output = x + else: + output = self.down_proj_dense(x) + + output = self.dense(output) + + output = ops.cast(output, self.compute_dtype) + + if self.diag_scale: + output = ops.add(output, ops.multiply(self.diag_scale, x)) + + return ops.add(ops.multiply(x0, output), x) + + def get_config(self) -> dict[str, Any]: + config: dict[str, Any] = super().get_config() + + config.update( + { + "projection_dim": self.projection_dim, + "diag_scale": self.diag_scale, + "use_bias": self.use_bias, + "pre_activation": keras.activations.serialize( + self.pre_activation + ), + "kernel_initializer": keras.initializers.serialize( + self.kernel_initializer + ), + "bias_initializer": keras.initializers.serialize( + self.bias_initializer + ), + "kernel_regularizer": keras.regularizers.serialize( + self.kernel_regularizer + ), + "bias_regularizer": keras.regularizers.serialize( + self.bias_regularizer + ), + } + ) + + return config diff --git a/keras_rs/src/layers/modeling/feature_cross_test.py b/keras_rs/src/layers/modeling/feature_cross_test.py new file mode 100644 index 0000000..5825d3c --- /dev/null +++ b/keras_rs/src/layers/modeling/feature_cross_test.py @@ -0,0 +1,96 @@ +import keras +from absl.testing import parameterized +from keras import ops +from keras.layers import deserialize +from keras.layers import serialize + +from keras_rs.src import testing +from keras_rs.src.layers.modeling.feature_cross import FeatureCross + + +class FeatureCrossTest(testing.TestCase, parameterized.TestCase): + def setUp(self): + self.x0 = ops.array([[0.1, 0.2, 0.3]], dtype="float32") + self.x = ops.array([[0.4, 0.5, 0.6]], dtype="float32") + self.exp_output = ops.array([[0.55, 0.8, 1.05]]) + + self.one_inp_exp_output = ops.array([[0.16, 0.32, 0.48]]) + + def test_full_layer(self): + layer = FeatureCross(projection_dim=None, kernel_initializer="ones") + output = layer(self.x0, self.x) + + # Test output. + self.assertAllClose(self.exp_output, output) + + # Test which layers have been initialised and their shapes. + # Kernel, bias terms corresponding to dense layer. + self.assertLen(layer.weights, 2, msg="Unexpected number of `weights`") + self.assertEqual(layer.weights[0].shape, (3, 3)) + self.assertEqual(layer.weights[1].shape, (3,)) + + def test_low_rank_layer(self): + layer = FeatureCross(projection_dim=1, kernel_initializer="ones") + output = layer(self.x0, self.x) + + # Test output. + self.assertAllClose(self.exp_output, output) + + # Test which layers have been initialised and their shapes. + # Kernel term corresponding to down projection layer, and kernel, + # bias terms corresponding to dense layer. 
+        self.assertLen(layer.weights, 3, msg="Unexpected number of `weights`")
+        self.assertEqual(layer.weights[0].shape, (3, 1))
+        self.assertEqual(layer.weights[1].shape, (1, 3))
+        self.assertEqual(layer.weights[2].shape, (3,))
+
+    def test_one_input(self):
+        layer = FeatureCross(projection_dim=None, kernel_initializer="ones")
+        output = layer(self.x0)
+        self.assertAllClose(self.one_inp_exp_output, output)
+
+    def test_invalid_input_shapes(self):
+        x0 = ops.ones((12, 5))
+        x = ops.ones((12, 7))
+
+        layer = FeatureCross()
+
+        with self.assertRaises(ValueError):
+            layer(x0, x)
+
+    def test_invalid_diag_scale(self):
+        with self.assertRaises(ValueError):
+            FeatureCross(diag_scale=-1.0)
+
+    def test_serialization(self):
+        layer = FeatureCross(projection_dim=None, pre_activation="swish")
+        restored = deserialize(serialize(layer))
+        self.assertDictEqual(layer.get_config(), restored.get_config())
+
+    def test_diag_scale(self):
+        layer = FeatureCross(
+            projection_dim=None, diag_scale=1.0, kernel_initializer="ones"
+        )
+        output = layer(self.x0, self.x)
+
+        self.assertAllClose(ops.array([[0.59, 0.9, 1.23]]), output)
+
+    def test_pre_activation(self):
+        layer = FeatureCross(projection_dim=None, pre_activation=ops.zeros_like)
+        output = layer(self.x0, self.x)
+
+        self.assertAllClose(self.x, output)
+
+    def test_model_saving(self):
+        def get_model():
+            x0 = keras.layers.Input(shape=(3,))
+            x1 = FeatureCross(projection_dim=None)(x0, x0)
+            x2 = FeatureCross(projection_dim=None)(x0, x1)
+            logits = keras.layers.Dense(units=1)(x2)
+            model = keras.Model(x0, logits)
+            return model
+
+        self.run_model_saving_test(
+            model=get_model(),
+            input_data=self.x0,
+        )
diff --git a/keras_rs/src/testing/test_case.py b/keras_rs/src/testing/test_case.py
index d398538..7021c5b 100644
--- a/keras_rs/src/testing/test_case.py
+++ b/keras_rs/src/testing/test_case.py
@@ -1,4 +1,7 @@
+import os
+import tempfile
 import unittest
+from typing import Any
 
 import keras
 import numpy as np
@@ -54,3 +57,22 @@
         if not isinstance(desired, np.ndarray):
             desired = keras.ops.convert_to_numpy(desired)
         np.testing.assert_array_equal(actual, desired, err_msg=msg)
+
+    def run_model_saving_test(
+        self,
+        model: Any,
+        input_data: Any,
+        atol: float = 1e-6,
+        rtol: float = 1e-6,
+    ) -> None:
+        """Save and load a model from disk and assert output is unchanged."""
+        model_output = model(input_data)
+
+        with tempfile.TemporaryDirectory() as temp_dir:
+            path = os.path.join(temp_dir, "model.keras")
+            model.save(path)
+            restored_model = keras.models.load_model(path)
+
+        # Check that the output matches.
+        restored_output = restored_model(input_data)
+        self.assertAllClose(model_output, restored_output, atol=atol, rtol=rtol)
diff --git a/keras_rs/src/utils/__init__.py b/keras_rs/src/utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/keras_rs/src/utils/keras_utils.py b/keras_rs/src/utils/keras_utils.py
new file mode 100644
index 0000000..fd6d104
--- /dev/null
+++ b/keras_rs/src/utils/keras_utils.py
@@ -0,0 +1,18 @@
+from typing import Union
+
+import keras
+
+
+def clone_initializer(
+    initializer: Union[str, keras.initializers.Initializer],
+) -> Union[str, keras.initializers.Initializer]:
+    """Clones an initializer to ensure a new seed.
+
+    As of TensorFlow 2.10, we need to clone user-passed initializers when
+    using them more than once, to avoid identical randomized initializations.
+    """
+    # If we get a string or dict, return it as-is; we cannot and should not clone.
+ if not isinstance(initializer, keras.initializers.Initializer): + return initializer + config = initializer.get_config() + return initializer.__class__.from_config(config)
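
To make the cross formula concrete, here is a minimal NumPy hand-check (an illustrative sketch, not part of the patch) that reproduces the expected values hard-coded in `feature_cross_test.py`. With `kernel_initializer="ones"` and the default `"zeros"` bias, the dense layer reduces to a row sum:

```python
import numpy as np

# Constants from FeatureCrossTest.setUp.
x0 = np.array([0.1, 0.2, 0.3])
x = np.array([0.4, 0.5, 0.6])
W = np.ones((3, 3))  # kernel_initializer="ones"
b = np.zeros(3)      # bias_initializer="zeros"

# x_{i+1} = x0 .* (W * x + b) + x -> exp_output in test_full_layer.
print(x0 * (x @ W + b) + x)  # [0.55 0.8  1.05]

# Single-input case (x = x0) -> one_inp_exp_output in test_one_input.
print(x0 * (x0 @ W + b) + x0)  # [0.16 0.32 0.48]

# With diag_scale=1.0 -> expected output in test_diag_scale.
print(x0 * (x @ W + b + 1.0 * x) + x)  # [0.59 0.9  1.23]
```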
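
And a short usage sketch of the stacked, low-rank configuration the docstring recommends. The import path follows the `keras_rs.layers.FeatureCross` export above; the feature dimension of 32 and the two-layer stack are illustrative assumptions, not taken from the patch:

```python
import keras

from keras_rs.layers import FeatureCross

embedding_dim = 32  # illustrative feature dimension

inputs = keras.Input(shape=(embedding_dim,))
# projection_dim = input_dim // 4 is the low-rank setting the docstring
# reports as consistently matching full-rank accuracy.
x1 = FeatureCross(projection_dim=embedding_dim // 4)(inputs, inputs)
x2 = FeatureCross(projection_dim=embedding_dim // 4)(inputs, x1)
logits = keras.layers.Dense(units=1)(x2)
model = keras.Model(inputs, logits)
```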