Skip to content

Commit

Permalink
Improved documentation of autoencoder example
Browse files Browse the repository at this point in the history
  • Loading branch information
Pseudomanifold committed Feb 17, 2022
1 parent dd2cf51 commit 0e6c220
Showing 1 changed file with 13 additions and 1 deletion.
14 changes: 13 additions & 1 deletion torch_topological/examples/autoencoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,11 @@ def forward(self, x):


class TopologicalAutoencoder(torch.nn.Module):
"""Wrapper for a topologically-regularised autoencoder."""
"""Wrapper for a topologically-regularised autoencoder.
This class uses another autoencoder model and imbues it with an
additional topology-based loss term.
"""
def __init__(self, model, lam=1.0):
super().__init__()

Expand All @@ -94,6 +98,9 @@ def forward(self, x):


if __name__ == '__main__':
# We first have to create a data set. This follows the original
# publication by Moor et al. by introducing a simple 'manifold'
# data set consisting of multiple spheres.
n_spheres = 11
data_set = Spheres(n_spheres=n_spheres)

Expand All @@ -104,6 +111,10 @@ def forward(self, x):
drop_last=True
)

# Let's set up the two models that we are training. Note that in
# a real application, you would have a more complicated training
# setup, potentially with early stopping etc. This training loop
# is merely to be seen as a proof of concept.
model = LinearAutoencoder(input_dim=data_set.dimension)
topo_model = TopologicalAutoencoder(model, lam=10)

Expand All @@ -125,6 +136,7 @@ def forward(self, x):

progress.set_postfix(loss=loss.item())

# Evaluate the autoencoder on a new instance of the data set.
data_set = Spheres(
train=False,
n_samples=2000,
Expand Down

0 comments on commit 0e6c220

Please sign in to comment.