* Add cross feature interaction layer
* Add unit tests
* Add unit tests
* Small change
* Fix doc-strings
* Clean up doc-string example
* Address comments
* Restore init cloning
* Add missing __init__.py file
Commit 163707b (1 parent: 71810b8)
Showing 7 changed files with 341 additions and 0 deletions.
Empty file.
keras_rs/src/layers/modeling/feature_cross.py
@@ -0,0 +1,204 @@
from typing import Any, Optional, Text, Union

import keras
from keras import ops

from keras_rs.src import types
from keras_rs.src.api_export import keras_rs_export
from keras_rs.src.utils.keras_utils import clone_initializer


@keras_rs_export("keras_rs.layers.FeatureCross")
class FeatureCross(keras.layers.Layer):
"""FeatureCross layer in Deep & Cross Network (DCN). | ||
A layer that creates explicit and bounded-degree feature interactions | ||
efficiently. The `call` method accepts two inputs: `x0` contains the | ||
original features; the second input `xi` is the output of the previous | ||
`FeatureCross` layer in the stack, i.e., the i-th `FeatureCross` layer. | ||
For the first `FeatureCross` layer in the stack, `x0 = xi`. | ||
The output is `x_{i+1} = x0 .* (W * x_i + bias + diag_scale * x_i) + x_i`, | ||
where .* denotes element-wise multiplication. W could be a full-rank | ||
matrix, or a low-rank matrix `U*V` to reduce the computational cost, and | ||
`diag_scale` increases the diagonal of W to improve training stability ( | ||
especially for the low-rank case). | ||
Args: | ||
projection_dim: int. Dimension for down-projecting the input to reduce | ||
computational cost. If `None` (default), the full matrix, `W` | ||
(with shape `(input_dim, input_dim)`) is used. Otherwise, a low-rank | ||
matrix `W = U*V` will be used, where `U` is of shape | ||
`(input_dim, projection_dim)` and `V` is of shape | ||
`(projection_dim, input_dim)`. `projection_dim` need to be smaller | ||
than `input_dim//2` to improve the model efficiency. In practice, | ||
we've observed that `projection_dim = input_dim//4` consistently | ||
preserved the accuracy of a full-rank version. | ||
diag_scale: non-negative float. Used to increase the diagonal of the | ||
kernel W by `diag_scale`, i.e., `W + diag_scale * I`, where I is the | ||
identity matrix. Defaults to `None`. | ||
use_bias: bool. Whether to add a bias term for this layer. Defaults to | ||
`True`. | ||
pre_activation: string or `keras.activations`. Activation applied to | ||
output matrix of the layer, before multiplication with the input. | ||
Can be used to control the scale of the layer's outputs and | ||
improve stability. Defaults to `None`. | ||
kernel_initializer: string or `keras.initializers` initializer. | ||
Initializer to use for the kernel matrix. Defaults to | ||
`"glorot_uniform"`. | ||
bias_initializer: string or `keras.initializers` initializer. | ||
Initializer to use for the bias vector. Defaults to `"ones"`. | ||
kernel_regularizer: string or `keras.regularizer` regularizer. | ||
Regularizer to use for the kernel matrix. | ||
bias_regularizer: string or `keras.regularizer` regularizer. | ||
Regularizer to use for the bias vector. | ||
Example: | ||
```python | ||
# after embedding layer in a functional model | ||
input = keras.Input(shape=(), name='indices', dtype="int64") | ||
x0 = keras.layers.Embedding(input_dim=32, output_dim=6)(x0) | ||
x1 = FeatureCross()(x0, x0) | ||
x2 = FeatureCross()(x0, x1) | ||
logits = keras.layers.Dense(units=10)(x2) | ||
model = keras.Model(input, logits) | ||
``` | ||
References: | ||
- [R. Wang et al.](https://arxiv.org/abs/2008.13535) | ||
- [R. Wang et al.](https://arxiv.org/abs/1708.05123) | ||
""" | ||

    def __init__(
        self,
        projection_dim: Optional[int] = None,
        diag_scale: Optional[float] = 0.0,
        use_bias: bool = True,
        pre_activation: Optional[Union[str, keras.layers.Activation]] = None,
        kernel_initializer: Union[
            Text, keras.initializers.Initializer
        ] = "glorot_uniform",
        bias_initializer: Union[Text, keras.initializers.Initializer] = "zeros",
        kernel_regularizer: Union[
            Text, None, keras.regularizers.Regularizer
        ] = None,
        bias_regularizer: Union[
            Text, None, keras.regularizers.Regularizer
        ] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)

        # Passed args.
        self.projection_dim = projection_dim
        self.diag_scale = diag_scale
        self.use_bias = use_bias
        self.pre_activation = keras.activations.get(pre_activation)
        self.kernel_initializer = keras.initializers.get(kernel_initializer)
        self.bias_initializer = keras.initializers.get(bias_initializer)
        self.kernel_regularizer = keras.regularizers.get(kernel_regularizer)
        self.bias_regularizer = keras.regularizers.get(bias_regularizer)

        # Other args.
        self.supports_masking = True

        if self.diag_scale is not None and self.diag_scale < 0.0:
            raise ValueError(
                "`diag_scale` should be non-negative. Received: "
                f"`diag_scale={self.diag_scale}`"
            )

    def build(self, input_shape: types.TensorShape) -> None:
        last_dim = input_shape[-1]

        if self.projection_dim is not None:
            self.down_proj_dense = keras.layers.Dense(
                units=self.projection_dim,
                use_bias=False,
                kernel_initializer=clone_initializer(self.kernel_initializer),
                kernel_regularizer=self.kernel_regularizer,
                dtype=self.dtype_policy,
            )

        self.dense = keras.layers.Dense(
            units=last_dim,
            activation=self.pre_activation,
            use_bias=self.use_bias,
            kernel_initializer=clone_initializer(self.kernel_initializer),
            bias_initializer=clone_initializer(self.bias_initializer),
            kernel_regularizer=self.kernel_regularizer,
            bias_regularizer=self.bias_regularizer,
            dtype=self.dtype_policy,
        )

        self.built = True

    def call(
        self, x0: types.Tensor, x: Optional[types.Tensor] = None
    ) -> types.Tensor:
        """Forward pass of the cross layer.

        Args:
            x0: a Tensor. The input to the cross layer. N-rank tensor
                with shape `(batch_size, ..., input_dim)`.
            x: a Tensor. Optional. If provided, the layer will compute
                crosses between `x0` and `x`. Otherwise, the layer will
                compute crosses between `x0` and itself. Should have the
                same shape as `x0`.

        Returns:
            Tensor of crosses, with the same shape as `x0`.
        """
        if x is None:
            x = x0

        if x0.shape != x.shape:
            raise ValueError(
                "`x0` and `x` should have the same shape. Received: "
                f"`x.shape` = {x.shape}, `x0.shape` = {x0.shape}"
            )

        # Project to a lower dimension.
        if self.projection_dim is None:
            output = x
        else:
            output = self.down_proj_dense(x)

        output = self.dense(output)

        output = ops.cast(output, self.compute_dtype)

        if self.diag_scale:
            output = ops.add(output, ops.multiply(self.diag_scale, x))

        return ops.add(ops.multiply(x0, output), x)

    def get_config(self) -> dict[str, Any]:
        config: dict[str, Any] = super().get_config()

        config.update(
            {
                "projection_dim": self.projection_dim,
                "diag_scale": self.diag_scale,
                "use_bias": self.use_bias,
                "pre_activation": keras.activations.serialize(
                    self.pre_activation
                ),
                "kernel_initializer": keras.initializers.serialize(
                    self.kernel_initializer
                ),
                "bias_initializer": keras.initializers.serialize(
                    self.bias_initializer
                ),
                "kernel_regularizer": keras.regularizers.serialize(
                    self.kernel_regularizer
                ),
                "bias_regularizer": keras.regularizers.serialize(
                    self.bias_regularizer
                ),
            }
        )

        return config
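As a sanity check on the formula in the class docstring, here is a minimal sketch, assuming `keras_rs` is installed (the `keras_rs.layers.FeatureCross` import path follows from the `keras_rs_export` decorator above). With an all-ones kernel and a zero bias, `W @ x` collapses to the sum of the features of `x`, which makes the expected output easy to compute by hand:

```python
import numpy as np
from keras import ops

from keras_rs.layers import FeatureCross

x0 = ops.array([[0.1, 0.2, 0.3]])  # original features
x = ops.array([[0.4, 0.5, 0.6]])   # output of a previous cross layer

# All-ones kernel, zero bias: W @ x sums the features of x (here, 1.5).
layer = FeatureCross(kernel_initializer="ones")
output = ops.convert_to_numpy(layer(x0, x))

# Manual computation of x_{i+1} = x0 .* (W x + bias) + x.
x0_np, x_np = ops.convert_to_numpy(x0), ops.convert_to_numpy(x)
expected = x0_np * x_np.sum(axis=-1, keepdims=True) + x_np

np.testing.assert_allclose(output, expected, rtol=1e-6)  # [[0.55, 0.8, 1.05]]

# The low-rank variant factors the (3, 3) kernel into (3, 1) and (1, 3)
# matrices; with all-ones weights it produces the same result here.
low_rank = FeatureCross(projection_dim=1, kernel_initializer="ones")
np.testing.assert_allclose(
    ops.convert_to_numpy(low_rank(x0, x)), expected, rtol=1e-6
)
```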
keras_rs/src/layers/modeling/feature_cross_test.py
@@ -0,0 +1,96 @@
import keras
from absl.testing import parameterized
from keras import ops
from keras.layers import deserialize
from keras.layers import serialize

from keras_rs.src import testing
from keras_rs.src.layers.modeling.feature_cross import FeatureCross


class FeatureCrossTest(testing.TestCase, parameterized.TestCase):
    def setUp(self):
        self.x0 = ops.array([[0.1, 0.2, 0.3]], dtype="float32")
        self.x = ops.array([[0.4, 0.5, 0.6]], dtype="float32")
        self.exp_output = ops.array([[0.55, 0.8, 1.05]])

        self.one_inp_exp_output = ops.array([[0.16, 0.32, 0.48]])

    def test_full_layer(self):
        layer = FeatureCross(projection_dim=None, kernel_initializer="ones")
        output = layer(self.x0, self.x)

        # Test output.
        self.assertAllClose(self.exp_output, output)

        # Test which weights have been initialised and their shapes:
        # kernel and bias terms corresponding to the dense layer.
        self.assertLen(layer.weights, 2, msg="Unexpected number of `weights`")
        self.assertEqual(layer.weights[0].shape, (3, 3))
        self.assertEqual(layer.weights[1].shape, (3,))

    def test_low_rank_layer(self):
        layer = FeatureCross(projection_dim=1, kernel_initializer="ones")
        output = layer(self.x0, self.x)

        # Test output.
        self.assertAllClose(self.exp_output, output)

        # Test which weights have been initialised and their shapes:
        # a kernel term corresponding to the down-projection layer, and
        # kernel and bias terms corresponding to the dense layer.
        self.assertLen(layer.weights, 3, msg="Unexpected number of `weights`")
        self.assertEqual(layer.weights[0].shape, (3, 1))
        self.assertEqual(layer.weights[1].shape, (1, 3))
        self.assertEqual(layer.weights[2].shape, (3,))

    def test_one_input(self):
        layer = FeatureCross(projection_dim=None, kernel_initializer="ones")
        output = layer(self.x0)
        self.assertAllClose(self.one_inp_exp_output, output)

    def test_invalid_input_shapes(self):
        x0 = ops.ones((12, 5))
        x = ops.ones((12, 7))

        layer = FeatureCross()

        with self.assertRaises(ValueError):
            layer(x0, x)

    def test_invalid_diag_scale(self):
        with self.assertRaises(ValueError):
            FeatureCross(diag_scale=-1.0)

    def test_serialization(self):
        layer = FeatureCross(projection_dim=None, pre_activation="swish")
        restored = deserialize(serialize(layer))
        self.assertDictEqual(layer.get_config(), restored.get_config())

    def test_diag_scale(self):
        layer = FeatureCross(
            projection_dim=None, diag_scale=1.0, kernel_initializer="ones"
        )
        output = layer(self.x0, self.x)

        self.assertAllClose(ops.array([[0.59, 0.9, 1.23]]), output)

    def test_pre_activation(self):
        layer = FeatureCross(projection_dim=None, pre_activation=ops.zeros_like)
        output = layer(self.x0, self.x)

        self.assertAllClose(self.x, output)

    def test_model_saving(self):
        def get_model():
            x0 = keras.layers.Input(shape=(3,))
            x1 = FeatureCross(projection_dim=None)(x0, x0)
            x2 = FeatureCross(projection_dim=None)(x0, x1)
            logits = keras.layers.Dense(units=1)(x2)
            model = keras.Model(x0, logits)
            return model

        self.run_model_saving_test(
            model=get_model(),
            input_data=self.x0,
        )
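The hard-coded expectations in these tests follow directly from the formula in the layer's docstring. For instance, the `test_diag_scale` value can be derived in plain Python:

```python
x0 = [0.1, 0.2, 0.3]
x = [0.4, 0.5, 0.6]
s = sum(x)  # W @ x with an all-ones 3x3 kernel sums the features: 1.5

# x_{i+1} = x0 .* (W x + bias + diag_scale * x) + x,
# with bias = 0 and diag_scale = 1.
expected = [a * (s + b) + b for a, b in zip(x0, x)]
print(expected)  # ~[0.59, 0.9, 1.23]
```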
Empty file.
keras_rs/src/utils/keras_utils.py
@@ -0,0 +1,18 @@
from typing import Text, Union

import keras


def clone_initializer(
    initializer: Union[Text, keras.initializers.Initializer],
) -> keras.initializers.Initializer:
    """Clones an initializer to ensure a new seed.

    As of TensorFlow 2.10, we need to clone user-passed initializers when
    invoking them twice, to avoid creating the same randomized
    initialization.
    """
    # If we get a string or dict, just return it: we cannot and should not
    # clone it.
    if not isinstance(initializer, keras.initializers.Initializer):
        return initializer
    config = initializer.get_config()
    return initializer.__class__.from_config(config)
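For illustration, a short, hedged sketch of how this helper is typically used (standalone usage; the two `Dense` layers are hypothetical). Handing the same unseeded initializer instance to several weights can reproduce the same random draw for each of them; cloning gives each consumer its own instance, which picks up a fresh seed, per the rationale in the docstring above:

```python
import keras

from keras_rs.src.utils.keras_utils import clone_initializer

shared = keras.initializers.GlorotUniform()

# Each Dense layer receives its own clone of the initializer, so the two
# kernels are not forced into identical random values by a shared instance.
dense_a = keras.layers.Dense(4, kernel_initializer=clone_initializer(shared))
dense_b = keras.layers.Dense(4, kernel_initializer=clone_initializer(shared))

# Non-Initializer inputs (e.g. strings) are returned unchanged.
assert clone_initializer("glorot_uniform") == "glorot_uniform"
```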