Add FeatureCross Layer (#13)
* Add cross feature interaction layer

* Add unit tests

* Add unit tests

* Small change

* Fix doc-strings

* Clean up doc-string example

* Address comments

* Restore init cloning

* Add missing __init__.py file
abheesht17 authored Jan 27, 2025
1 parent 71810b8 commit 163707b
Showing 7 changed files with 341 additions and 0 deletions.
1 change: 1 addition & 0 deletions keras_rs/api/layers/__init__.py
@@ -4,6 +4,7 @@
since your modifications would be overwritten.
"""

from keras_rs.src.layers.modeling.feature_cross import FeatureCross
from keras_rs.src.layers.retrieval.brute_force_retrieval import (
BruteForceRetrieval,
)
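This export makes the layer part of the public API. A minimal usage sketch, assuming the generated `keras_rs/api` package is surfaced as the top-level `keras_rs` namespace (the `keras_rs_export` decorator below registers the layer as `keras_rs.layers.FeatureCross`):

```python
# Sketch of the public import path; the exact packaging of the generated
# api modules is an assumption here.
import keras_rs

layer = keras_rs.layers.FeatureCross(projection_dim=None)
```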
Empty file added keras_rs/src/layers/modeling/__init__.py
Empty file.
204 changes: 204 additions & 0 deletions keras_rs/src/layers/modeling/feature_cross.py
@@ -0,0 +1,204 @@
from typing import Any, Optional, Text, Union

import keras
from keras import ops

from keras_rs.src import types
from keras_rs.src.api_export import keras_rs_export
from keras_rs.src.utils.keras_utils import clone_initializer


@keras_rs_export("keras_rs.layers.FeatureCross")
class FeatureCross(keras.layers.Layer):
"""FeatureCross layer in Deep & Cross Network (DCN).
A layer that creates explicit and bounded-degree feature interactions
efficiently. The `call` method accepts two inputs: `x0` contains the
original features; the second input `xi` is the output of the previous
`FeatureCross` layer in the stack, i.e., the i-th `FeatureCross` layer.
For the first `FeatureCross` layer in the stack, `x0 = xi`.
The output is `x_{i+1} = x0 .* (W * x_i + bias + diag_scale * x_i) + x_i`,
where .* denotes element-wise multiplication. W could be a full-rank
matrix, or a low-rank matrix `U*V` to reduce the computational cost, and
`diag_scale` increases the diagonal of W to improve training stability (
especially for the low-rank case).
Args:
projection_dim: int. Dimension for down-projecting the input to reduce
computational cost. If `None` (default), the full matrix, `W`
(with shape `(input_dim, input_dim)`) is used. Otherwise, a low-rank
matrix `W = U*V` will be used, where `U` is of shape
`(input_dim, projection_dim)` and `V` is of shape
`(projection_dim, input_dim)`. `projection_dim` need to be smaller
than `input_dim//2` to improve the model efficiency. In practice,
we've observed that `projection_dim = input_dim//4` consistently
preserved the accuracy of a full-rank version.
diag_scale: non-negative float. Used to increase the diagonal of the
kernel W by `diag_scale`, i.e., `W + diag_scale * I`, where I is the
identity matrix. Defaults to `None`.
use_bias: bool. Whether to add a bias term for this layer. Defaults to
`True`.
pre_activation: string or `keras.activations`. Activation applied to
output matrix of the layer, before multiplication with the input.
Can be used to control the scale of the layer's outputs and
improve stability. Defaults to `None`.
kernel_initializer: string or `keras.initializers` initializer.
Initializer to use for the kernel matrix. Defaults to
`"glorot_uniform"`.
bias_initializer: string or `keras.initializers` initializer.
Initializer to use for the bias vector. Defaults to `"ones"`.
kernel_regularizer: string or `keras.regularizer` regularizer.
Regularizer to use for the kernel matrix.
bias_regularizer: string or `keras.regularizer` regularizer.
Regularizer to use for the bias vector.
Example:
```python
# after embedding layer in a functional model
input = keras.Input(shape=(), name='indices', dtype="int64")
x0 = keras.layers.Embedding(input_dim=32, output_dim=6)(x0)
x1 = FeatureCross()(x0, x0)
x2 = FeatureCross()(x0, x1)
logits = keras.layers.Dense(units=10)(x2)
model = keras.Model(input, logits)
```
References:
- [R. Wang et al.](https://arxiv.org/abs/2008.13535)
- [R. Wang et al.](https://arxiv.org/abs/1708.05123)
"""

def __init__(
self,
projection_dim: Optional[int] = None,
diag_scale: Optional[float] = 0.0,
use_bias: bool = True,
pre_activation: Optional[Union[str, keras.layers.Activation]] = None,
kernel_initializer: Union[
Text, keras.initializers.Initializer
] = "glorot_uniform",
bias_initializer: Union[Text, keras.initializers.Initializer] = "zeros",
kernel_regularizer: Union[
Text, None, keras.regularizers.Regularizer
] = None,
bias_regularizer: Union[
Text, None, keras.regularizers.Regularizer
] = None,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)

# Passed args.
self.projection_dim = projection_dim
self.diag_scale = diag_scale
self.use_bias = use_bias
self.pre_activation = keras.activations.get(pre_activation)
self.kernel_initializer = keras.initializers.get(kernel_initializer)
self.bias_initializer = keras.initializers.get(bias_initializer)
self.kernel_regularizer = keras.regularizers.get(kernel_regularizer)
self.bias_regularizer = keras.regularizers.get(bias_regularizer)

# Other args.
self.supports_masking = True

if self.diag_scale is not None and self.diag_scale < 0.0:
raise ValueError(
"`diag_scale` should be non-negative. Received: "
f"`diag_scale={self.diag_scale}`"
)

def build(self, input_shape: types.TensorShape) -> None:
last_dim = input_shape[-1]

if self.projection_dim is not None:
self.down_proj_dense = keras.layers.Dense(
units=self.projection_dim,
use_bias=False,
kernel_initializer=clone_initializer(self.kernel_initializer),
kernel_regularizer=self.kernel_regularizer,
dtype=self.dtype_policy,
)

self.dense = keras.layers.Dense(
units=last_dim,
activation=self.pre_activation,
use_bias=self.use_bias,
kernel_initializer=clone_initializer(self.kernel_initializer),
bias_initializer=clone_initializer(self.bias_initializer),
kernel_regularizer=self.kernel_regularizer,
bias_regularizer=self.bias_regularizer,
dtype=self.dtype_policy,
)

self.built = True

def call(
self, x0: types.Tensor, x: Optional[types.Tensor] = None
) -> types.Tensor:
"""Forward pass of the cross layer.
Args:
x0: a Tensor. The input to the cross layer. N-rank tensor
with shape `(batch_size, ..., input_dim)`.
x: a Tensor. Optional. If provided, the layer will compute
crosses between x0 and x. Otherwise, the layer will
compute crosses between x0 and itself. Should have the same
shape as `x0`.
Returns:
Tensor of crosses, with the same shape as `x0`.
"""

if x is None:
x = x0

if x0.shape != x.shape:
raise ValueError(
"`x0` and `x` should have the same shape. Received: "
f"`x.shape` = {x.shape}, `x0.shape` = {x0.shape}"
)

# Project to a lower dimension.
if self.projection_dim is None:
output = x
else:
output = self.down_proj_dense(x)

output = self.dense(output)

output = ops.cast(output, self.compute_dtype)

if self.diag_scale:
output = ops.add(output, ops.multiply(self.diag_scale, x))

return ops.add(ops.multiply(x0, output), x)

def get_config(self) -> dict[str, Any]:
config: dict[str, Any] = super().get_config()

config.update(
{
"projection_dim": self.projection_dim,
"diag_scale": self.diag_scale,
"use_bias": self.use_bias,
"pre_activation": keras.activations.serialize(
self.pre_activation
),
"kernel_initializer": keras.initializers.serialize(
self.kernel_initializer
),
"bias_initializer": keras.initializers.serialize(
self.bias_initializer
),
"kernel_regularizer": keras.regularizers.serialize(
self.kernel_regularizer
),
"bias_regularizer": keras.regularizers.serialize(
self.bias_regularizer
),
}
)

return config
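The arithmetic in `call` can be checked by hand. Below is a minimal NumPy sketch of the full-rank path (no projection, `diag_scale` unset), assuming an all-ones kernel and zero bias so the numbers line up with the unit tests that follow:

```python
import numpy as np

# Full-rank cross: x_{i+1} = x0 .* (W @ x + b) + x.
x0 = np.array([0.1, 0.2, 0.3])
x = np.array([0.4, 0.5, 0.6])
W = np.ones((3, 3))  # kernel_initializer="ones", as in the tests
b = np.zeros(3)      # bias_initializer="zeros", the layer default

out = x0 * (W @ x + b) + x
print(out)  # [0.55 0.8  1.05] -- matches `exp_output` in the tests below
```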
96 changes: 96 additions & 0 deletions keras_rs/src/layers/modeling/feature_cross_test.py
@@ -0,0 +1,96 @@
import keras
from absl.testing import parameterized
from keras import ops
from keras.layers import deserialize
from keras.layers import serialize

from keras_rs.src import testing
from keras_rs.src.layers.modeling.feature_cross import FeatureCross


class FeatureCrossTest(testing.TestCase, parameterized.TestCase):
def setUp(self):
self.x0 = ops.array([[0.1, 0.2, 0.3]], dtype="float32")
self.x = ops.array([[0.4, 0.5, 0.6]], dtype="float32")
self.exp_output = ops.array([[0.55, 0.8, 1.05]])

self.one_inp_exp_output = ops.array([[0.16, 0.32, 0.48]])

def test_full_layer(self):
layer = FeatureCross(projection_dim=None, kernel_initializer="ones")
output = layer(self.x0, self.x)

# Test output.
self.assertAllClose(self.exp_output, output)

# Test which layers have been initialised and their shapes.
# Kernel, bias terms corresponding to dense layer.
self.assertLen(layer.weights, 2, msg="Unexpected number of `weights`")
self.assertEqual(layer.weights[0].shape, (3, 3))
self.assertEqual(layer.weights[1].shape, (3,))

def test_low_rank_layer(self):
layer = FeatureCross(projection_dim=1, kernel_initializer="ones")
output = layer(self.x0, self.x)

# Test output.
self.assertAllClose(self.exp_output, output)

# Test which layers have been initialised and their shapes.
# Kernel term corresponding to down projection layer, and kernel,
# bias terms corresponding to dense layer.
self.assertLen(layer.weights, 3, msg="Unexpected number of `weights`")
self.assertEqual(layer.weights[0].shape, (3, 1))
self.assertEqual(layer.weights[1].shape, (1, 3))
self.assertEqual(layer.weights[2].shape, (3,))

def test_one_input(self):
layer = FeatureCross(projection_dim=None, kernel_initializer="ones")
output = layer(self.x0)
self.assertAllClose(self.one_inp_exp_output, output)

def test_invalid_input_shapes(self):
x0 = ops.ones((12, 5))
x = ops.ones((12, 7))

layer = FeatureCross()

with self.assertRaises(ValueError):
layer(x0, x)

def test_invalid_diag_scale(self):
with self.assertRaises(ValueError):
FeatureCross(diag_scale=-1.0)

def test_serialization(self):
        layer = FeatureCross(projection_dim=None, pre_activation="swish")
        restored = deserialize(serialize(layer))
        self.assertDictEqual(layer.get_config(), restored.get_config())

def test_diag_scale(self):
layer = FeatureCross(
projection_dim=None, diag_scale=1.0, kernel_initializer="ones"
)
output = layer(self.x0, self.x)

self.assertAllClose(ops.array([[0.59, 0.9, 1.23]]), output)

def test_pre_activation(self):
layer = FeatureCross(projection_dim=None, pre_activation=ops.zeros_like)
output = layer(self.x0, self.x)

self.assertAllClose(self.x, output)

def test_model_saving(self):
def get_model():
x0 = keras.layers.Input(shape=(3,))
x1 = FeatureCross(projection_dim=None)(x0, x0)
x2 = FeatureCross(projection_dim=None)(x0, x1)
logits = keras.layers.Dense(units=1)(x2)
model = keras.Model(x0, logits)
return model

self.run_model_saving_test(
model=get_model(),
input_data=self.x0,
)
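One detail worth noting in these tests: the full-rank and low-rank cases assert the same `exp_output`. With all-ones kernels, the low-rank factorization `W = U @ V` reproduces the full-rank sums exactly, as a quick NumPy check shows:

```python
import numpy as np

x0 = np.array([0.1, 0.2, 0.3])
x = np.array([0.4, 0.5, 0.6])

# Low-rank path: down-project with U, then expand with V (both all ones).
U = np.ones((3, 1))  # down_proj_dense kernel: (input_dim, projection_dim)
V = np.ones((1, 3))  # dense kernel: (projection_dim, input_dim)

out = x0 * (x @ U @ V) + x
print(out)  # [0.55 0.8  1.05] -- identical to the full-rank case
```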
22 changes: 22 additions & 0 deletions keras_rs/src/testing/test_case.py
@@ -1,4 +1,7 @@
import os
import tempfile
import unittest
from typing import Any

import keras
import numpy as np
@@ -54,3 +57,22 @@ def assertAllEqual(
if not isinstance(desired, np.ndarray):
desired = keras.ops.convert_to_numpy(desired)
np.testing.assert_array_equal(actual, desired, err_msg=msg)

def run_model_saving_test(
self,
model: Any,
input_data: Any,
atol: float = 1e-6,
rtol: float = 1e-6,
) -> None:
"""Save and load a model from disk and assert output is unchanged."""
model_output = model(input_data)

with tempfile.TemporaryDirectory() as temp_dir:
path = os.path.join(temp_dir, "model.keras")
model.save(path, save_format="keras_v3")
restored_model = keras.models.load_model(path)

            # Check that output matches.
restored_output = restored_model(input_data)
self.assertAllClose(model_output, restored_output, atol=atol, rtol=rtol)
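A sketch of how a test would call this helper, following the same pattern as `test_model_saving` in the layer tests above (the model and shapes here are illustrative):

```python
# Hypothetical test inside a subclass of keras_rs.src.testing.TestCase.
def test_model_saving(self):
    inputs = keras.Input(shape=(3,))
    outputs = keras.layers.Dense(units=1)(inputs)
    model = keras.Model(inputs, outputs)

    self.run_model_saving_test(
        model=model,
        input_data=keras.ops.ones((2, 3)),
    )
```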
Empty file added keras_rs/src/utils/__init__.py
Empty file.
18 changes: 18 additions & 0 deletions keras_rs/src/utils/keras_utils.py
@@ -0,0 +1,18 @@
from typing import Text, Union

import keras


def clone_initializer(
initializer: Union[Text, keras.initializers.Initializer],
) -> keras.initializers.Initializer:
"""Clones an initializer to ensure a new seed.
As of tensorflow 2.10, we need to clone user passed initializers when
invoking them twice to avoid creating the same randomized initialization.
"""
# If we get a string or dict, just return as we cannot and should not clone.
if not isinstance(initializer, keras.initializers.Initializer):
return initializer
config = initializer.get_config()
return initializer.__class__.from_config(config)
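The `build` method of `FeatureCross` shows the intended pattern: each weight-creating sub-layer receives its own clone, so two layers never invoke the same randomized initializer instance. A minimal sketch, assuming an unseeded initializer:

```python
import keras

from keras_rs.src.utils.keras_utils import clone_initializer

initializer = keras.initializers.GlorotUniform()

# Each Dense layer gets a fresh clone of the user-passed initializer,
# mirroring how FeatureCross.build() wires its two kernels.
dense_a = keras.layers.Dense(4, kernel_initializer=clone_initializer(initializer))
dense_b = keras.layers.Dense(4, kernel_initializer=clone_initializer(initializer))
```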
