Skip to content

Commit

Permalink
Add get_config to BruteForceRetrieval layer. (#15)
Browse files Browse the repository at this point in the history
For consistency with other layers. Note that this does not serialize candidates.

Also removed explicit `name` argument in `__init__`.
  • Loading branch information
hertschuh authored Jan 28, 2025
1 parent 1b5c0e4 commit 5164294
Showing 1 changed file with 15 additions and 3 deletions.
18 changes: 15 additions & 3 deletions keras_rs/src/layers/retrieval/brute_force_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ class BruteForceRetrieval(keras.layers.Layer):
The identifiers for the candidates can be specified as a tensor. If not
provided, the IDs used are simply the candidate indices.
Note that the serialization of this layer does not preserve the candidates
and only saves the `k` and `return_scores` arguments. One has to call
`update_candidates` after deserializing the layers.
Args:
candidate_embeddings: The candidate embeddings. If `None`,
candidates must be provided using `update_candidates` before
Expand All @@ -32,7 +36,6 @@ class BruteForceRetrieval(keras.layers.Layer):
return_scores: When `True`, this layer returns a tuple with the top
scores and the top identifiers. When `False`, this layer returns
a single tensor with the top identifiers.
name: Name of the layer.
**kwargs: Args to pass to the base class.
Example:
Expand All @@ -55,10 +58,9 @@ def __init__(
candidate_ids: Optional[types.Tensor] = None,
k: int = 10,
return_scores: bool = True,
name: Optional[str] = None,
**kwargs: Any,
) -> None:
super().__init__(name=name, **kwargs)
super().__init__(**kwargs)
self.candidate_embeddings = None
self.candidate_ids = None
self.k = k
Expand Down Expand Up @@ -183,3 +185,13 @@ def compute_score(
return keras.ops.matmul(
query_embedding, keras.ops.transpose(candidate_embedding)
)

def get_config(self) -> dict[str, Any]:
config: dict[str, Any] = super().get_config()
config.update(
{
"k": self.k,
"return_scores": self.compute_score,
}
)
return config

0 comments on commit 5164294

Please sign in to comment.