Custom NNs must take feature & embed dims as args
ziatdinovmax committed Sep 6, 2021
1 parent add0a41 commit d6c33e2
Showing 2 changed files with 11 additions and 7 deletions.
8 changes: 5 additions & 3 deletions atomai/models/dklgp/dklgpr.py
```diff
@@ -82,7 +82,8 @@ def fit(self, X: Union[torch.Tensor, np.ndarray],
         Keyword Args:
             feature_extractor:
-                (Optional) Custom neural network for feature extractor
+                (Optional) Custom neural network for feature extractor.
+                Must take input/feature dims and embedding dims as its arguments.
             freeze_weights:
                 Freezes weights of feature extractor, that is, they are not
                 passed to the optimizer. Used for transfer learning.
```
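
The requirement added here means a custom extractor must be constructible from the feature and embedding dimensions alone. A minimal sketch of a compliant module (the class name, hidden width, and layer choices are illustrative, not part of atomai):

```python
import torch
import torch.nn as nn


class CustomFeatureExtractor(nn.Module):
    """Maps input_dim-dimensional features into an embed_dim-dimensional space."""

    def __init__(self, input_dim: int, embed_dim: int) -> None:
        super().__init__()
        # The hidden width of 64 is an arbitrary illustrative choice
        self.net = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.Linear(64, embed_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)
```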
```diff
@@ -98,7 +99,7 @@ def fit_ensemble(self, X: Union[torch.Tensor, np.ndarray],
                      **kwargs: Union[Type[torch.nn.Module], bool, float]
                      ) -> None:
         """
-        Initializes and trains a deep kernel GP model
+        Initializes and trains an ensemble of deep kernel GP models
         Args:
             X: Input training data (aka features) of N x input_dim dimensions
```
```diff
@@ -108,7 +109,8 @@ def fit_ensemble(self, X: Union[torch.Tensor, np.ndarray],
         Keyword Args:
             feature_extractor:
-                (Optional) Custom neural network for feature extractor
+                (Optional) Custom neural network for feature extractor.
+                Must take input/feature dims and embedding dims as its arguments.
             freeze_weights:
                 Freezes weights of feature extractor, that is, they are not
                 passed to the optimizer. Used for transfer learning.
```
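
With such a class defined, fit and fit_ensemble receive it uninstantiated through the feature_extractor keyword. A hedged usage sketch (X_train and y_train are placeholder arrays; keyword names follow the docstrings above):

```python
import atomai as aoi
import numpy as np

X_train = np.random.randn(200, 16)   # N x input_dim placeholder data
y_train = np.random.randn(200)       # scalar target per observation

model = aoi.models.dklGPR(indim=X_train.shape[1], embedim=2)
model.fit(X_train, y_train, training_cycles=100,
          feature_extractor=CustomFeatureExtractor)
```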
10 changes: 6 additions & 4 deletions atomai/trainers/gptrainer.py
```diff
@@ -93,6 +93,8 @@ def compile_multi_model_trainer(self,
         the number of neural networks is equal to the number of Gaussian
         processes. For example, if the outputs are spectra of length 128,
         one will have 128 neural networks and 128 GPs trained in parallel.
+        It can also be used for training an ensemble of models for the same
+        scalar output.
         """
         if self.correlated_output:
             raise NotImplementedError(
```
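
The ensemble use case mentioned in the added lines corresponds to fit_ensemble on the model side. A hedged sketch for a scalar output (the n_models keyword is an assumption about the ensemble API, not confirmed by this diff):

```python
# Several independent DKL-GP models trained on the same scalar target;
# n_models is assumed here and may be named differently in the actual API.
model.fit_ensemble(X_train, y_train, n_models=5, training_cycles=100,
                   feature_extractor=CustomFeatureExtractor)
```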
```diff
@@ -160,7 +162,8 @@ def compile_trainer(self, X: Union[torch.Tensor, np.ndarray],
         Keyword Args:
             feature_extractor:
-                (Optional) Custom neural network for feature extractor
+                (Optional) Custom neural network for feature extractor.
+                Must take input/feature dims and embedding dims as its arguments.
             grid_size:
                 Grid size for structured kernel interpolation (Default: 50)
             freeze_weights:
```
```diff
@@ -174,9 +177,8 @@
                 "use compile_multi_model_trainer(*args, **kwargs)")
         X, y = self.set_data(X, y)
         input_dim, embedim = self.dimdict["input_dim"], self.dimdict["embedim"]
-        feature_extractor = kwargs.get("feature_extractor")
-        if feature_extractor is None:
-            feature_extractor = fcFeatureExtractor(input_dim, embedim)
+        feature_net = kwargs.get("feature_extractor", fcFeatureExtractor)
+        feature_extractor = feature_net(input_dim, embedim)
         freeze_weights = kwargs.get("freeze_weights", False)
         if freeze_weights:
             for p in feature_extractor.parameters():
```
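
This last hunk is the substance of the commit: the feature_extractor keyword is now resolved to a callable and invoked with (input_dim, embedim), so callers pass the class itself rather than a prebuilt instance. A before/after sketch using the illustrative class from above:

```python
# Before this commit, a ready-made module instance was accepted:
# model.fit(X_train, y_train,
#           feature_extractor=CustomFeatureExtractor(16, 2))

# After this commit, pass the class; the trainer instantiates it
# internally as feature_extractor(input_dim, embedim):
model.fit(X_train, y_train, feature_extractor=CustomFeatureExtractor)
```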
