Skip to content

Commit

Permalink
removed unet name
Browse files Browse the repository at this point in the history
  • Loading branch information
rhewett committed Aug 2, 2021
1 parent 598bd7d commit 7d63d84
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 6 deletions.
6 changes: 3 additions & 3 deletions src/distdl_unet/unet_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@ def __init__(self, levels, in_channels, base_channels, out_channels, **level_kwa
self.level_kwargs = level_kwargs

self.input_map = self.assemble_input_map()
self.unet = self.assemble_munet()
self.cycle = self.assemble_cycle()
self.output_map = self.assemble_output_map()

def assemble_input_map(self):
raise NotImplementedError()

def assemble_munet(self):
def assemble_cycle(self):
raise NotImplementedError()

def assemble_output_map(self):
Expand All @@ -30,7 +30,7 @@ def assemble_output_map(self):
def forward(self, input):

x_f = self.input_map(input)
y_f = self.unet(x_f)
y_f = self.cycle(x_f)
output = self.output_map(y_f)

return output
Expand Down
2 changes: 1 addition & 1 deletion src/distdl_unet/unet_classic.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def assemble_input_map(self):
acti = torch.nn.ReLU(inplace=True)
return torch.nn.Sequential(conv, norm, acti)

def assemble_munet(self):
def assemble_cycle(self):
return ClassicalUNetLevel(self.feature_dimension,
self.levels, 0, 0, self.base_channels,
**self.level_kwargs)
Expand Down
2 changes: 1 addition & 1 deletion src/distdl_unet/unet_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def assemble_input_map(self):
acti = torch.nn.ReLU(inplace=_relu_inplace)
return torch.nn.Sequential(conv, norm, acti)

def assemble_munet(self):
def assemble_cycle(self):
return DistributedUNetLevel(self.P,
self.levels, 0, 0, self.base_channels, **self.level_kwargs)

Expand Down
2 changes: 1 addition & 1 deletion src/distdl_unet/unet_dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def assemble_input_map(self):

return DummyLayer(self.feature_dimension, self.in_channels, self.base_channels, 0, "S")

def assemble_munet(self):
def assemble_cycle(self):
return DummyMuNetLevel(self.feature_dimension,
self.levels, 0, 0, self.base_channels, **self.level_kwargs)

Expand Down

0 comments on commit 7d63d84

Please sign in to comment.