Skip to content

Commit 1e2e5f4

Browse files
committed
Refactor model initialization so the initialization scale is configurable, allowing a higher scale for implicit feedback
1 parent b620968 commit 1e2e5f4

File tree

3 files changed

+24
-12
lines changed

3 files changed

+24
-12
lines changed

src/lenskit/flexmf/_base.py

+5-8
Original file line number | Diff line number | Diff line change
@@ -141,17 +141,14 @@ def prepare_data(
141141
"""
142142
raise NotImplementedError()
143143

144-
def create_model(self, context: FlexMFTrainingContext, data: FlexMFTrainingData) -> FlexMFModel:
144+
@abstractmethod
145+
def create_model(
146+
self, context: FlexMFTrainingContext, data: FlexMFTrainingData
147+
) -> FlexMFModel: # pragma: nocover
145148
"""
146149
Prepare the model for training.
147150
"""
148-
return FlexMFModel(
149-
self.config.embedding_size,
150-
data.n_users,
151-
data.n_items,
152-
context.torch_rng,
153-
sparse=self.config.reg_method != "AdamW",
154-
)
151+
raise NotImplementedError()
155152

156153
def create_optimizer(self, context: FlexMFTrainingContext) -> torch.optim.Optimizer:
157154
"""

src/lenskit/flexmf/_explicit.py

+14
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@
1212
from torch.nn import functional as F
1313

1414
from lenskit.data import Dataset
15+
from lenskit.flexmf._model import FlexMFModel
1516
from lenskit.training import TrainingOptions
1617

1718
from ._base import FlexMFConfigBase, FlexMFScorerBase
@@ -74,6 +75,19 @@ def prepare_data(
7475
fields={"ratings": rm_values},
7576
).to(context.device)
7677

78+
def create_model(self, context: FlexMFTrainingContext, data: FlexMFTrainingData) -> FlexMFModel:
79+
"""
80+
Prepare the model for training.
81+
"""
82+
return FlexMFModel(
83+
self.config.embedding_size,
84+
data.n_users,
85+
data.n_items,
86+
context.torch_rng,
87+
sparse=self.config.reg_method != "AdamW",
88+
init_scale=0.1,
89+
)
90+
7791
def train_batch(
7892
self, context: FlexMFTrainingContext, batch: FlexMFTrainingBatch, opt: torch.optim.Optimizer
7993
) -> float:

src/lenskit/flexmf/_model.py

+5-4
Original file line number | Diff line number | Diff line change
@@ -43,6 +43,7 @@ def __init__(
4343
n_users: int,
4444
n_items: int,
4545
rng: torch.Generator,
46+
init_scale: float = 1.0,
4647
user_bias: bool = True,
4748
item_bias: bool = True,
4849
sparse: bool = False,
@@ -64,12 +65,12 @@ def __init__(
6465

6566
# initialize all values from a normal distribution with standard deviation init_scale
6667
if self.u_bias is not None:
67-
nn.init.normal_(self.u_bias.weight, std=0.05, generator=rng)
68+
nn.init.normal_(self.u_bias.weight, std=init_scale, generator=rng)
6869
if self.i_bias is not None:
69-
nn.init.normal_(self.i_bias.weight, std=0.05, generator=rng)
70+
nn.init.normal_(self.i_bias.weight, std=init_scale, generator=rng)
7071

71-
nn.init.normal_(self.u_embed.weight, std=0.05, generator=rng)
72-
nn.init.normal_(self.i_embed.weight, std=0.05, generator=rng)
72+
nn.init.normal_(self.u_embed.weight, std=init_scale, generator=rng)
73+
nn.init.normal_(self.i_embed.weight, std=init_scale, generator=rng)
7374

7475
@property
7576
def device(self):

0 commit comments

Comments
 (0)