From 9e448dc519985c8b538da355e2d04640e2bc5e39 Mon Sep 17 00:00:00 2001 From: h-westmacott Date: Thu, 19 Oct 2023 10:46:00 +0100 Subject: [PATCH 1/3] initial commit of semi-supervised code --- .../semi_supervised_experiment_config.py | 48 ++++ .../configs/semi_supervised_train_config.py | 111 ++++++++ cellulus/criterions/__init__.py | 117 +++++++++ cellulus/criterions/stardist_loss.py | 62 +++++ cellulus/datasets/__init__.py | 6 + cellulus/datasets/zarr_dataset.py | 90 ++++++- cellulus/train_semi_supervised.py | 247 ++++++++++++++++++ 7 files changed, 667 insertions(+), 14 deletions(-) create mode 100644 cellulus/configs/semi_supervised_experiment_config.py create mode 100644 cellulus/configs/semi_supervised_train_config.py create mode 100644 cellulus/criterions/stardist_loss.py create mode 100644 cellulus/train_semi_supervised.py diff --git a/cellulus/configs/semi_supervised_experiment_config.py b/cellulus/configs/semi_supervised_experiment_config.py new file mode 100644 index 0000000..ad881e6 --- /dev/null +++ b/cellulus/configs/semi_supervised_experiment_config.py @@ -0,0 +1,48 @@ +import attrs +from attrs.validators import instance_of + +from .inference_config import InferenceConfig +from .model_config import ModelConfig +from .semi_supervised_train_config import SemiSupervisedTrainConfig +from .utils import to_config + + +@attrs.define +class SemiSupervisedExperimentConfig: + """Top-level config for a semi-supervised experiment (containing training and prediction). + + Parameters: + + experiment_name: + + A unique name for the experiment. + + object_size: + + A rough estimate of the size of objects in the image, given in + world units. The "patch size" of the network will be chosen based + on this estimate. + + model_config: + + The model configuration. + + semi_sup_train_config: + + Configuration object for training the semi-supervised model. + + inference_config: + + Configuration object for prediction. + """ + + experiment_name: str = attrs.field(validator=instance_of(str)) + object_size: float = attrs.field(validator=instance_of(float)) + + model_config: ModelConfig = attrs.field(converter=to_config(ModelConfig)) + semi_sup_train_config: SemiSupervisedTrainConfig = attrs.field( + default=None, converter=to_config(SemiSupervisedTrainConfig) + ) + inference_config: InferenceConfig = attrs.field( + default=None, converter=to_config(InferenceConfig) + ) \ No newline at end of file diff --git a/cellulus/configs/semi_supervised_train_config.py b/cellulus/configs/semi_supervised_train_config.py new file mode 100644 index 0000000..8934434 --- /dev/null +++ b/cellulus/configs/semi_supervised_train_config.py @@ -0,0 +1,111 @@ +from typing import List + +import attrs +from attrs.validators import instance_of + +from .dataset_config import DatasetConfig +from .utils import to_config + + +@attrs.define +class SemiSupervisedTrainConfig: + """Train configuration. + + Parameters: + + raw_data_config: + + Configuration object for the raw training data. + + pseudo_data_config: + + Configuration object for the pseudo-ground-truth labels. + + supervised_data_config: + + Configuration object for the ground-truth labels/annotations. + + crop_size: + + The size of the crops - specified as a tuple of pixels - + extracted from the raw images, used during training. + + batch_size: + + The number of samples to use per batch. + + max_iterations: + + The maximum number of iterations to train for. + + initial_learning_rate (default = 4e-5): + + Initial learning rate of the optimizer. 
+ + temperature (default = 10): + + Factor used to scale the gaussian function and control the rate of damping. + + regularizer_weight (default = 1e-5): + + The weight of the L2 regularizer on the object-centric embeddings. + + reduce_mean (default = True): + + If True, the loss contribution is averaged across all pairs of patches. + + density (default = 0.2) + + Determines the fraction of patches to sample per crop, during training. + + kappa (default = 10.0): + + Neighborhood radius to extract patches from + + save_model_every (default = 1e3): + + The model weights are saved every few iterations. + + save_snapshot_every (default = 1e3): + + The zarr snapshot is saved every few iterations. + + num_workers (default = 8): + + The number of sub-processes to use for data-loading. + + control_point_spacing (default = 64): + + The distance in pixels between control points used for elastic + deformation of the raw data during training. + + control_point_jitter (default = 2.0): + + How much to jitter the control points for elastic deformation + of the raw data during training, given as the standard deviation of + a normal distribution with zero mean. + + + """ + + raw_data_config: DatasetConfig = attrs.field(converter=to_config(DatasetConfig)) + pseudo_data_config: DatasetConfig = attrs.field(converter=to_config(DatasetConfig)) + supervised_data_config: DatasetConfig = attrs.field(converter=to_config(DatasetConfig)) + + crop_size: List = attrs.field(default=[252, 252], validator=instance_of(List)) + batch_size: int = attrs.field(default=8, validator=instance_of(int)) + max_iterations: int = attrs.field(default=100_000, validator=instance_of(int)) + initial_learning_rate: float = attrs.field( + default=4e-5, validator=instance_of(float) + ) + density: float = attrs.field(default=0.2, validator=instance_of(float)) + kappa: float = attrs.field(default=10.0, validator=instance_of(float)) + temperature: float = attrs.field(default=10.0, validator=instance_of(float)) + regularizer_weight: float = attrs.field(default=1e-5, validator=instance_of(float)) + reduce_mean: bool = attrs.field(default=True, validator=instance_of(bool)) + save_model_every: int = attrs.field(default=1_000, validator=instance_of(int)) + save_snapshot_every: int = attrs.field(default=1_000, validator=instance_of(int)) + num_workers: int = attrs.field(default=8, validator=instance_of(int)) + + control_point_spacing: int = attrs.field(default=64, validator=instance_of(int)) + control_point_jitter: float = attrs.field(default=2.0, validator=instance_of(float)) diff --git a/cellulus/criterions/__init__.py b/cellulus/criterions/__init__.py index ea57d4f..0cb6088 100644 --- a/cellulus/criterions/__init__.py +++ b/cellulus/criterions/__init__.py @@ -1,4 +1,8 @@ from cellulus.criterions.oce_loss import OCELoss +import numpy as np +import stardist +from inferno.io.transform import Transform +import gunpowder as gp def get_loss( @@ -19,3 +23,116 @@ def get_loss( reduce_mean, device, ) + +class TransformStardist(gp.BatchFilter): + def __init__(self,array): + self.array = array + def prepare(self, request): + + # the requested ROI for array + # expects (17,x,y) + roi = request[self.array].roi + + self.stardist_shape = roi.get_shape() + self.stardist_roi = roi + print('roi = ',roi) + + # 1. compute the context + # context = gp.Coordinate((self.truncate,)*roi.dims()) * self.sigma + + # 2. 
enlarge the requested ROI by the context + # roi.__offset = [0,0,0,0] + # context_roi = roi.set_shape([1,1,stardist_shape[1],stardist_shape[2]]) + # context_roi = gp.Roi((0,0,0,0),(1,1,self.stardist_shape[1],self.stardist_shape[2])) + roi = gp.Roi((0,0,0,0),(1,1,self.stardist_shape[1],self.stardist_shape[2])) + print('roi =',roi) + + # create a new request with our dependencies + deps = gp.BatchRequest() + deps[self.array] = roi + print('deps created') + # return the request + return deps + + def process(self, batch, request): + self.data_shape = data.shape + data = batch[self.array].data + # import numpy as np + print(self.array, data.shape, np.unique(data)) + temp = stardist_transform(data) + print(temp.shape, np.unique(temp)) + batch[self.array].data = temp + + +def stardist_transform(gt, n_rays=16, fill_label_holes=False): + + if len(gt.shape)>2: + gt = np.squeeze(gt) + + if np.any(gt - gt.astype(np.uint16)): + mapping={v:k for k,v in enumerate(np.unique(gt))} + u,inv = np.unique(gt,return_inverse = True) + Y1 = np.array([mapping[x] for x in u])[inv].reshape(gt.shape) + gt = Y1.astype(np.uint16) + + + if fill_label_holes: + gt = stardist.fill_label_holes(gt) + + dist = stardist.geometry.star_dist(gt, n_rays = n_rays) + dist_mask = stardist.utils.edt_prob(gt.astype(int)) + + if gt.min() < 0: + # ignore label found + ignore_mask = gt < 0 + print(gt.shape, dist.shape) + dist[ignore_mask] = 0 + dist_mask[ignore_mask] = -1 + + dist_mask = dist_mask[None] + dist = np.transpose(dist, (2, 0, 1)) + + # dist_mask = torch.tensor(dist_mask) + # dist = torch.tensor(dist) + mask_and_dist = np.concatenate([dist_mask, dist], axis=0) + + # mask_and_dist = torch.cat([dist_mask, dist], axis=0) + return mask_and_dist + + +class StardistTf(Transform): + """Convert segmentation to stardist""" + + def __init__(self, n_rays=16, fill_label_holes=False): + super().__init__() + self.n_rays = n_rays + self.fill_label_holes = fill_label_holes + + def tensor_function(self, gt): + + if np.any(gt-gt.astype(np.uint16)): + mapping={v:k for k,v in enumerate(np.unique(gt))} + u,inv = np.unique(gt,return_inverse = True) + Y1 = np.array([mapping[x] for x in u])[inv].reshape(gt.shape) + gt = Y1.astype(np.uint16) + # gt = measure.label(gt) + if self.fill_label_holes: + gt = stardist.fill_label_holes(gt) + # import pdb + # pdb.set_trace() + # print('gt.type',gt.type()) + dist = stardist.geometry.star_dist(gt, n_rays=self.n_rays) + dist_mask = stardist.utils.edt_prob(gt) + + if gt.min() < 0: + # ignore label found + ignore_mask = gt < 0 + print(gt.shape, dist.shape) + dist[ignore_mask] = 0 + dist_mask[ignore_mask] = -1 + + dist_mask = dist_mask[None] + dist = np.transpose(dist, (2, 0, 1)) + + mask_and_dist = np.concatenate([dist_mask, dist], axis=0) + return mask_and_dist \ No newline at end of file diff --git a/cellulus/criterions/stardist_loss.py b/cellulus/criterions/stardist_loss.py new file mode 100644 index 0000000..7fb1f88 --- /dev/null +++ b/cellulus/criterions/stardist_loss.py @@ -0,0 +1,62 @@ +import numpy as np +import torch +import torch.nn as nn +from torch.nn import functional as F + +class StardistLoss(nn.Module): + """Loss for stardist predictions combines BCE loss for probabilities + with MAE (L1) loss for distances + + Args: + weight: Distance loss weight. 
Total loss will be bce_loss + weight * l1_loss + """ + + def __init__(self, weight=1.): + + super().__init__() + self.weight = weight + + def forward(self, prediction, target, mask=None): + # Predicted distances errors are weighted by object prob + if target.shape!=prediction.shape: + prediction = prediction.squeeze(1) + + target_prob = target[:, :1] + predicted_prob = prediction[:, :1] + target_dist = target[:, 1:] + predicted_dist = prediction[:, 1:] + + if mask is not None: + target_prob = mask * target_prob + # do not train foreground prediction when mask is supplied + predicted_prob = predicted_prob.detach() + + l1loss_pp = F.l1_loss(predicted_dist, + target_dist, + reduction='none') + + ignore_mask_provided = target_prob.min() < 0 + if ignore_mask_provided: + # ignore label was supplied + ignore_mask = target_prob >= 0. + # add one to avoid division by zero + imsum = ignore_mask.sum() + if imsum == 0: + print("WARNING: Batch with only ignorelabel encountered!") + return 0*l1loss_pp.sum() + + l1loss = ((target_prob * ignore_mask) * l1loss_pp).sum() / imsum + + bceloss = F.binary_cross_entropy_with_logits(predicted_prob[ignore_mask], + target_prob[ignore_mask].float(), + reduction='sum') / imsum + return self.weight * l1loss + bceloss + + # weight predictions by target probs + l1loss = (target_prob * l1loss_pp).mean() + + bceloss = F.binary_cross_entropy_with_logits(predicted_prob, + target_prob.float(), + reduction='mean') + + return (self.weight * l1loss) + bceloss \ No newline at end of file diff --git a/cellulus/datasets/__init__.py b/cellulus/datasets/__init__.py index 2b1f464..599fc72 100644 --- a/cellulus/datasets/__init__.py +++ b/cellulus/datasets/__init__.py @@ -10,10 +10,16 @@ def get_dataset( crop_size: Tuple[int, ...], control_point_spacing: int, control_point_jitter: float, + semi_supervised: bool = False, + supervised_dataset_config: DatasetConfig = None, + pseudo_dataset_config: DatasetConfig = None, ) -> ZarrDataset: return ZarrDataset( dataset_config=dataset_config, crop_size=crop_size, control_point_spacing=control_point_spacing, control_point_jitter=control_point_jitter, + semi_supervised = semi_supervised, + supervised_dataset_config = supervised_dataset_config, + pseudo_dataset_config = pseudo_dataset_config, ) diff --git a/cellulus/datasets/zarr_dataset.py b/cellulus/datasets/zarr_dataset.py index 5453076..a7e162b 100644 --- a/cellulus/datasets/zarr_dataset.py +++ b/cellulus/datasets/zarr_dataset.py @@ -8,6 +8,9 @@ from .meta_data import DatasetMetaData +from cellulus.criterions import stardist_transform +import numpy as np + class ZarrDataset(IterableDataset): # type: ignore def __init__( @@ -16,7 +19,10 @@ def __init__( crop_size: Tuple[int, ...], control_point_spacing: int, control_point_jitter: float, - ): + semi_supervised: bool = False, + supervised_dataset_config: DatasetConfig = None, + pseudo_dataset_config: DatasetConfig = None, + ): """A dataset that serves random samples from a zarr container. 
Args: @@ -55,6 +61,14 @@ def __init__( self.crop_size = crop_size self.control_point_spacing = control_point_spacing self.control_point_jitter = control_point_jitter + self.semi_supervised = semi_supervised + if supervised_dataset_config != None: + self.supervised_dataset_config = supervised_dataset_config + self.pseudo_dataset_config = pseudo_dataset_config + else: + self.supervised_dataset_config = None + self.pseudo_dataset_config = None + self.__read_meta_data() assert len(crop_size) == self.num_spatial_dims, ( @@ -70,6 +84,8 @@ def __iter__(self): def __setup_pipeline(self): self.raw = gp.ArrayKey("RAW") + self.pseudo = gp.ArrayKey("PSEUDO") + self.supervised = gp.ArrayKey("SUPERVISED") # treat all dimensions as spatial, with a voxel size of 1 raw_spec = gp.ArraySpec(voxel_size=(1,) * self.num_dims, interpolatable=True) @@ -77,22 +93,39 @@ def __setup_pipeline(self): # spatial_dims = tuple(range(self.num_dims - self.num_spatial_dims, # self.num_dims)) - self.pipeline = ( - gp.ZarrSource( + if self.supervised_dataset_config != None: + source_node = gp.ZarrSource( + self.dataset_config.container_path, + {self.raw: self.dataset_config.dataset_name, + self.pseudo: self.pseudo_dataset_config.dataset_name, + self.supervised: self.supervised_dataset_config.dataset_name}, + array_specs={self.raw: raw_spec, + self.pseudo: raw_spec, + self.supervised: raw_spec}, + ) + else: + source_node = gp.ZarrSource( self.dataset_config.container_path, {self.raw: self.dataset_config.dataset_name}, array_specs={self.raw: raw_spec}, ) + + # Elastic augmentation is incompatible with labels, because the images get + # interpolated. If Elastic augmentation is required, the self-supervised + # training type needs to be switched from combined labels to separate. + # This is because stardist representations survive the Elastic Augment. + self.pipeline = ( + source_node + gp.RandomLocation() - + gp.ElasticAugment( - control_point_spacing=(self.control_point_spacing,) - * self.num_spatial_dims, - jitter_sigma=(self.control_point_jitter,) * self.num_spatial_dims, - rotation_interval=(0, math.pi / 2), - scale_interval=(0.9, 1.1), - subsample=4, - spatial_dims=self.num_spatial_dims, - ) + # + gp.ElasticAugment( + # control_point_spacing=(self.control_point_spacing,) + # * self.num_spatial_dims, + # jitter_sigma=(self.control_point_jitter,) * self.num_spatial_dims, + # rotation_interval=(0, math.pi / 2), + # scale_interval=(0.9, 1.1), + # subsample=4, + # spatial_dims=self.num_spatial_dims, + # ) # + gp.SimpleAugment(mirror_only=spatial_dims, transpose_only=spatial_dims) ) @@ -108,9 +141,38 @@ def __yield_sample(self): (0,) * self.num_dims, (1, self.num_channels, *self.crop_size) ) ) + if self.supervised_dataset_config != None: + # if we have a supervised dataset config, we must be training a semi-supervised + # model. 
Therefore we need to add requests to our gp pipeline for pseudo- and + # GT-annotations + request[self.pseudo] = gp.ArraySpec( + roi=gp.Roi( + (0,) * self.num_dims, (1, self.num_channels, *self.crop_size) + ) + ) + request[self.supervised] = gp.ArraySpec( + roi=gp.Roi( + (0,) * self.num_dims, (1, self.num_channels, *self.crop_size) + ) + ) sample = self.pipeline.request_batch(request) - yield sample[self.raw].data[0] + + if self.semi_supervised: + # If we are training a semi-supervised model, our dataset class needs to + # return all of the below datasets + transformed_pseudo = stardist_transform(sample[self.pseudo].data[0]) + transformed_supervised = stardist_transform(sample[self.supervised].data[0]) + yield {'raw':sample[self.raw].data[0], + 'pseudo_stardist':transformed_pseudo, + 'supervised_stardist':transformed_supervised, + 'pseudo_labels':sample[self.pseudo].data[0], + 'supervised_labels':sample[self.supervised].data[0]} + + else: + # if we are training a self-supervised model (i.e. standard cellulus), + # we need to return just the raw image data. + yield sample[self.raw].data[0] def __read_meta_data(self): meta_data = DatasetMetaData.from_dataset_config(self.dataset_config) @@ -127,4 +189,4 @@ def get_num_channels(self): return self.num_channels def get_num_spatial_dims(self): - return self.num_spatial_dims + return self.num_spatial_dims \ No newline at end of file diff --git a/cellulus/train_semi_supervised.py b/cellulus/train_semi_supervised.py new file mode 100644 index 0000000..5aadaab --- /dev/null +++ b/cellulus/train_semi_supervised.py @@ -0,0 +1,247 @@ +import os + +import torch +import zarr +from tqdm import tqdm + +from cellulus.datasets import get_dataset +from cellulus.models import get_model +from cellulus.criterions.stardist_loss import StardistLoss +from cellulus.utils import get_logger +import torch.nn as nn +import numpy as np + +def semisupervised_train(semi_sup_exp_config): + print(semi_sup_exp_config) + + if not os.path.exists("semi_supervised_models"): + os.makedirs("semi_supervised_models") + + model_config = semi_sup_exp_config.model_config + + semi_sup_train_config = semi_sup_exp_config.semi_sup_train_config + + raw_dataset = get_dataset( + dataset_config = semi_sup_train_config.raw_data_config, + crop_size = tuple(semi_sup_train_config.crop_size), + control_point_spacing = semi_sup_train_config.control_point_spacing, + control_point_jitter = semi_sup_train_config.control_point_jitter, + pseudo_dataset_config = semi_sup_train_config.pseudo_data_config, + supervised_dataset_config = semi_sup_train_config.supervised_data_config, + semi_supervised = True + ) + + raw_dataloader = torch.utils.data.DataLoader( + dataset=raw_dataset, + batch_size=semi_sup_train_config.batch_size, + drop_last=True, + num_workers=semi_sup_train_config.num_workers, + pin_memory=True, + ) + + model = get_model( + in_channels=raw_dataset.get_num_channels(), + out_channels=17, + num_fmaps=model_config.num_fmaps, + fmap_inc_factor=model_config.fmap_inc_factor, + features_in_last_layer=model_config.features_in_last_layer, + downsampling_factors=[ + tuple(factor) for factor in model_config.downsampling_factors + ], + num_spatial_dims=raw_dataset.get_num_spatial_dims(), + ) + + if torch.cuda.is_available(): + model = model.cuda() + + + + + criterion = StardistLoss() + + # set optimizer + optimizer = torch.optim.Adam( + model.parameters(), + lr=semi_sup_train_config.initial_learning_rate, + ) + + def lambda_(iteration): + return pow((1 - ((iteration) / 
semi_sup_train_config.max_iterations)), 0.9) + + # set logger + logger = get_logger(keys=["train"], title="loss") + + # resume training + start_iteration = 0 + lowest_loss = 1e7 + + if model_config.checkpoint is None: + pass + else: + print(f"Resuming model from {model_config.checkpoint}") + state = torch.load(model_config.checkpoint) + start_iteration = state["iteration"] + 1 + lowest_loss = state["lowest_loss"] + model.load_state_dict(state["model_state_dict"], strict=True) + optimizer.load_state_dict(state["optim_state_dict"]) + logger.data = state["logger_data"] + + # call `train_iteration` + for iteration, batch in tqdm( + zip( + range(start_iteration, semi_sup_train_config.max_iterations), + raw_dataloader + ) + ): + scheduler = torch.optim.lr_scheduler.LambdaLR( + optimizer, lr_lambda=lambda_, last_epoch=iteration - 1 + ) + + train_loss, prediction = train_iteration( + batch, + model=model, + criterion=criterion, + optimizer=optimizer, + ) + scheduler.step() + logger.add(key="train", value=train_loss.cpu().detach().numpy()) + logger.write() + logger.plot() + + if (iteration + 1) % semi_sup_train_config.save_model_every == 0: + is_lowest = train_loss < lowest_loss + lowest_loss = min(train_loss, lowest_loss) + state = { + "iteration": iteration, + "lowest_loss": lowest_loss, + "model_state_dict": model.state_dict(), + "optim_state_dict": optimizer.state_dict(), + "logger_data": logger.data, + } + save_model(state, iteration, is_lowest) + + if (iteration + 1) % semi_sup_train_config.save_snapshot_every == 0: + save_snapshot( + batch, + prediction, + iteration, + ) + +def train_iteration( + batch, + model, + criterion, + optimizer, + alpha = 0.5 +): + model.train() + + def unqiue_values_to_unique_ints(array): + if np.any(array - array.astype(np.int16)): + mapping={v:k for k,v in enumerate(np.unique(array))} + u,inv = np.unique(array,return_inverse = True) + Y1 = np.array([mapping[x] for x in u])[inv].reshape(array.shape) + array = Y1.astype(np.int16) + return array + + if torch.cuda.is_available(): + batch['raw'] = batch['raw'].to("cuda") + batch['pseudo_stardist'] = batch['pseudo_stardist'].to("cuda") + batch['supervised_stardist'] = batch['supervised_stardist'].to("cuda") + batch['pseudo_labels'] = batch['pseudo_labels'].to("cuda") + batch['supervised_labels'] = batch['supervised_labels'].to("cuda") + batch['supervised_labels'] = torch.tensor(unqiue_values_to_unique_ints(batch['supervised_labels'].cpu().detach().numpy())).to("cuda") + + prediction = model(batch['raw']) + + + def combine_GT_and_pseudo_labels(gt_labels, pseudo_labels): + if gt_labels.shape != pseudo_labels.shape: + print('labelled images are different sizes') + + combined_labels = np.zeros(gt_labels.shape) + combined_labels[:] = gt_labels.cpu().detach().numpy()[:] + + for pseudo_label_value in np.unique(pseudo_labels.cpu().detach().numpy()): + if pseudo_label_value == 0.0: + pass + # for each pseudo label, check it does not intersect with a GT label. + # if it doesn't intersect, add it to the combined labels. + # if it does intersect, leave it or add to ignore mask? 
+ + # all of the indicies that have this given label value + # this_pseudo_label = np.where(np.any(pseudo_labels==pseudo_label_value)) + + if np.any(gt_labels.cpu().detach().numpy()[pseudo_labels.cpu().detach().numpy()==pseudo_label_value]): + pass + else: + combined_labels = combined_labels + pseudo_labels.cpu().detach().numpy()*(pseudo_labels.cpu().detach().numpy()==pseudo_label_value) + return combined_labels + + gt_labels = batch['supervised_labels'][:,:,:prediction.shape[2],:prediction.shape[3]] + pseudo_labels = batch['pseudo_labels'][:,:,:prediction.shape[2],:prediction.shape[3]] + combined_labels = combine_GT_and_pseudo_labels(gt_labels, pseudo_labels) + + use_combined_loss = True + + if use_combined_loss: + from cellulus.criterions import stardist_transform + + combined_stardist = torch.tensor(stardist_transform(combined_labels)).cuda().unsqueeze(0) + + loss = criterion(prediction, combined_stardist) + else: + supervised_loss = criterion(prediction, batch['supervised_stardist'][:,:,:prediction.shape[2],:prediction.shape[3]]) + semisupervised_loss = criterion(prediction, batch['pseudo_stardist'][:,:,:prediction.shape[2],:prediction.shape[3]]) + loss = (alpha*supervised_loss) + ((1-alpha)*semisupervised_loss) + + + optimizer.zero_grad() + loss.backward() + optimizer.step() + return loss, prediction + + +def save_model(state, iteration, is_lowest=False): + file_name = os.path.join("semi_supervised_models", str(iteration).zfill(6) + ".pth") + torch.save(state, file_name) + print(f"Checkpoint saved at iteration {iteration}") + if is_lowest: + file_name = os.path.join("semi_supervised_models", "best_loss.pth") + torch.save(state, file_name) + + +def save_snapshot(batch, prediction, iteration): + num_spatial_dims = len(batch['raw'].shape) - 2 + + axis_names = ["s", "c"] + ["t", "z", "y", "x"][-num_spatial_dims:] + prediction_offset = tuple( + (a - b) / 2 + for a, b in zip( + batch['raw'].shape[-num_spatial_dims:], prediction.shape[-num_spatial_dims:] + ) + ) + f = zarr.open("semi_supervised_snapshots.zarr", "a") + f[f"{iteration}/raw"] = batch['raw'].detach().cpu().numpy() + f[f"{iteration}/raw"].attrs["axis_names"] = axis_names + + f[f"{iteration}/pseudo_stardist"] = batch['pseudo_stardist'].detach().cpu().numpy() + f[f"{iteration}/pseudo_stardist"].attrs["axis_names"] = axis_names + + f[f"{iteration}/pseudo_labels"] = batch['pseudo_labels'].detach().cpu().numpy() + f[f"{iteration}/pseudo_labels"].attrs["axis_names"] = axis_names + + f[f"{iteration}/supervised_stardist"] = batch['supervised_stardist'].detach().cpu().numpy() + f[f"{iteration}/supervised_stardist"].attrs["supervised"] = axis_names + + f[f"{iteration}/supervised_labels"] = batch['supervised_labels'].detach().cpu().numpy() + f[f"{iteration}/supervised_labels"].attrs["supervised"] = axis_names + + f[f"{iteration}/prediction"] = prediction.detach().cpu().numpy() + f[f"{iteration}/prediction"].attrs["axis_names"] = axis_names + f[f"{iteration}/prediction"].attrs["offset"] = prediction_offset + + print(f"Snapshot saved at iteration {iteration}") + + + From ec8158f6915c4cf1840fb32f2c67e5aa6357ba3a Mon Sep 17 00:00:00 2001 From: h-westmacott Date: Thu, 19 Oct 2023 11:37:32 +0100 Subject: [PATCH 2/3] Move Elastic augmentation to apply to unsupervised cellulus --- cellulus/datasets/zarr_dataset.py | 41 ++++++++++++++++++++++--------- 1 file changed, 29 insertions(+), 12 deletions(-) diff --git a/cellulus/datasets/zarr_dataset.py b/cellulus/datasets/zarr_dataset.py index a7e162b..79d168d 100644 --- 
a/cellulus/datasets/zarr_dataset.py +++ b/cellulus/datasets/zarr_dataset.py @@ -103,18 +103,13 @@ def __setup_pipeline(self): self.pseudo: raw_spec, self.supervised: raw_spec}, ) - else: - source_node = gp.ZarrSource( - self.dataset_config.container_path, - {self.raw: self.dataset_config.dataset_name}, - array_specs={self.raw: raw_spec}, - ) - - # Elastic augmentation is incompatible with labels, because the images get - # interpolated. If Elastic augmentation is required, the self-supervised - # training type needs to be switched from combined labels to separate. - # This is because stardist representations survive the Elastic Augment. - self.pipeline = ( + + # Elastic augmentation is incompatible with labels, because the images get + # interpolated. If Elastic augmentation is required, the self-supervised + # training type needs to be switched from combined labels to separate. + # This is because stardist representations survive the Elastic Augment. + + self.pipeline = ( source_node + gp.RandomLocation() # + gp.ElasticAugment( @@ -129,6 +124,28 @@ def __setup_pipeline(self): # + gp.SimpleAugment(mirror_only=spatial_dims, transpose_only=spatial_dims) ) + else: + source_node = gp.ZarrSource( + self.dataset_config.container_path, + {self.raw: self.dataset_config.dataset_name}, + array_specs={self.raw: raw_spec}, + ) + + self.pipeline = ( + source_node + + gp.RandomLocation() + + gp.ElasticAugment( + control_point_spacing=(self.control_point_spacing,) + * self.num_spatial_dims, + jitter_sigma=(self.control_point_jitter,) * self.num_spatial_dims, + rotation_interval=(0, math.pi / 2), + scale_interval=(0.9, 1.1), + subsample=4, + spatial_dims=self.num_spatial_dims, + ) + # + gp.SimpleAugment(mirror_only=spatial_dims, transpose_only=spatial_dims) + ) + def __yield_sample(self): """An infinite generator of crops.""" From a9901db4bff05ad01cbb0bc34078b1050debe9c5 Mon Sep 17 00:00:00 2001 From: h-westmacott Date: Thu, 19 Oct 2023 13:09:58 +0100 Subject: [PATCH 3/3] checked with ruuf, mypy and balck --- .../semi_supervised_experiment_config.py | 5 +- .../configs/semi_supervised_train_config.py | 6 +- cellulus/criterions/__init__.py | 111 +++++++-------- cellulus/criterions/stardist_loss.py | 37 ++--- cellulus/datasets/__init__.py | 6 +- cellulus/datasets/zarr_dataset.py | 104 +++++++------- cellulus/train_semi_supervised.py | 129 ++++++++++-------- 7 files changed, 215 insertions(+), 183 deletions(-) diff --git a/cellulus/configs/semi_supervised_experiment_config.py b/cellulus/configs/semi_supervised_experiment_config.py index ad881e6..29df761 100644 --- a/cellulus/configs/semi_supervised_experiment_config.py +++ b/cellulus/configs/semi_supervised_experiment_config.py @@ -9,7 +9,8 @@ @attrs.define class SemiSupervisedExperimentConfig: - """Top-level config for a semi-supervised experiment (containing training and prediction). + """Top-level config for a semi-supervised experiment + (containing training and prediction). 
Parameters: @@ -45,4 +46,4 @@ class SemiSupervisedExperimentConfig: ) inference_config: InferenceConfig = attrs.field( default=None, converter=to_config(InferenceConfig) - ) \ No newline at end of file + ) diff --git a/cellulus/configs/semi_supervised_train_config.py b/cellulus/configs/semi_supervised_train_config.py index 8934434..3cf3b14 100644 --- a/cellulus/configs/semi_supervised_train_config.py +++ b/cellulus/configs/semi_supervised_train_config.py @@ -85,12 +85,14 @@ class SemiSupervisedTrainConfig: of the raw data during training, given as the standard deviation of a normal distribution with zero mean. - + """ raw_data_config: DatasetConfig = attrs.field(converter=to_config(DatasetConfig)) pseudo_data_config: DatasetConfig = attrs.field(converter=to_config(DatasetConfig)) - supervised_data_config: DatasetConfig = attrs.field(converter=to_config(DatasetConfig)) + supervised_data_config: DatasetConfig = attrs.field( + converter=to_config(DatasetConfig) + ) crop_size: List = attrs.field(default=[252, 252], validator=instance_of(List)) batch_size: int = attrs.field(default=8, validator=instance_of(int)) diff --git a/cellulus/criterions/__init__.py b/cellulus/criterions/__init__.py index 0cb6088..1aa3987 100644 --- a/cellulus/criterions/__init__.py +++ b/cellulus/criterions/__init__.py @@ -1,8 +1,9 @@ -from cellulus.criterions.oce_loss import OCELoss +import gunpowder as gp import numpy as np import stardist from inferno.io.transform import Transform -import gunpowder as gp + +from cellulus.criterions.oce_loss import OCELoss def get_loss( @@ -24,64 +25,65 @@ def get_loss( device, ) + class TransformStardist(gp.BatchFilter): - def __init__(self,array): - self.array = array - def prepare(self, request): - - # the requested ROI for array - # expects (17,x,y) - roi = request[self.array].roi - - self.stardist_shape = roi.get_shape() - self.stardist_roi = roi - print('roi = ',roi) - - # 1. compute the context - # context = gp.Coordinate((self.truncate,)*roi.dims()) * self.sigma - - # 2. enlarge the requested ROI by the context - # roi.__offset = [0,0,0,0] - # context_roi = roi.set_shape([1,1,stardist_shape[1],stardist_shape[2]]) - # context_roi = gp.Roi((0,0,0,0),(1,1,self.stardist_shape[1],self.stardist_shape[2])) - roi = gp.Roi((0,0,0,0),(1,1,self.stardist_shape[1],self.stardist_shape[2])) - print('roi =',roi) - - # create a new request with our dependencies - deps = gp.BatchRequest() - deps[self.array] = roi - print('deps created') - # return the request - return deps - - def process(self, batch, request): - self.data_shape = data.shape - data = batch[self.array].data - # import numpy as np - print(self.array, data.shape, np.unique(data)) - temp = stardist_transform(data) - print(temp.shape, np.unique(temp)) - batch[self.array].data = temp + def __init__(self, array): + self.array = array + + def prepare(self, request): + # the requested ROI for array + # expects (17,x,y) + roi = request[self.array].roi + + self.stardist_shape = roi.get_shape() + self.stardist_roi = roi + print("roi = ", roi) + + # 1. compute the context + # context = gp.Coordinate((self.truncate,)*roi.dims()) * self.sigma + + # 2. 
enlarge the requested ROI by the context + # roi.__offset = [0,0,0,0] + # context_roi = roi.set_shape([1,1,stardist_shape[1],stardist_shape[2]]) + # context_roi = gp.Roi((0,0,0,0),(1,1,self.stardist_shape[1],self.stardist_shape[2])) # noqa: E501 + roi = gp.Roi( + (0, 0, 0, 0), (1, 1, self.stardist_shape[1], self.stardist_shape[2]) + ) + print("roi =", roi) + + # create a new request with our dependencies + deps = gp.BatchRequest() + deps[self.array] = roi + print("deps created") + # return the request + return deps + + def process(self, batch, request): + data = batch[self.array].data + self.data_shape = data.shape + # import numpy as np + print(self.array, data.shape, np.unique(data)) + temp = stardist_transform(data) + print(temp.shape, np.unique(temp)) + batch[self.array].data = temp def stardist_transform(gt, n_rays=16, fill_label_holes=False): - - if len(gt.shape)>2: - gt = np.squeeze(gt) + if len(gt.shape) > 2: + gt = np.squeeze(gt) if np.any(gt - gt.astype(np.uint16)): - mapping={v:k for k,v in enumerate(np.unique(gt))} - u,inv = np.unique(gt,return_inverse = True) - Y1 = np.array([mapping[x] for x in u])[inv].reshape(gt.shape) - gt = Y1.astype(np.uint16) - + mapping = {v: k for k, v in enumerate(np.unique(gt))} + u, inv = np.unique(gt, return_inverse=True) + Y1 = np.array([mapping[x] for x in u])[inv].reshape(gt.shape) + gt = Y1.astype(np.uint16) if fill_label_holes: gt = stardist.fill_label_holes(gt) - dist = stardist.geometry.star_dist(gt, n_rays = n_rays) + dist = stardist.geometry.star_dist(gt, n_rays=n_rays) dist_mask = stardist.utils.edt_prob(gt.astype(int)) - + if gt.min() < 0: # ignore label found ignore_mask = gt < 0 @@ -109,10 +111,9 @@ def __init__(self, n_rays=16, fill_label_holes=False): self.fill_label_holes = fill_label_holes def tensor_function(self, gt): - - if np.any(gt-gt.astype(np.uint16)): - mapping={v:k for k,v in enumerate(np.unique(gt))} - u,inv = np.unique(gt,return_inverse = True) + if np.any(gt - gt.astype(np.uint16)): + mapping = {v: k for k, v in enumerate(np.unique(gt))} + u, inv = np.unique(gt, return_inverse=True) Y1 = np.array([mapping[x] for x in u])[inv].reshape(gt.shape) gt = Y1.astype(np.uint16) # gt = measure.label(gt) @@ -123,7 +124,7 @@ def tensor_function(self, gt): # print('gt.type',gt.type()) dist = stardist.geometry.star_dist(gt, n_rays=self.n_rays) dist_mask = stardist.utils.edt_prob(gt) - + if gt.min() < 0: # ignore label found ignore_mask = gt < 0 @@ -135,4 +136,4 @@ def tensor_function(self, gt): dist = np.transpose(dist, (2, 0, 1)) mask_and_dist = np.concatenate([dist_mask, dist], axis=0) - return mask_and_dist \ No newline at end of file + return mask_and_dist diff --git a/cellulus/criterions/stardist_loss.py b/cellulus/criterions/stardist_loss.py index 7fb1f88..33858d1 100644 --- a/cellulus/criterions/stardist_loss.py +++ b/cellulus/criterions/stardist_loss.py @@ -1,8 +1,7 @@ -import numpy as np -import torch import torch.nn as nn from torch.nn import functional as F + class StardistLoss(nn.Module): """Loss for stardist predictions combines BCE loss for probabilities with MAE (L1) loss for distances @@ -11,14 +10,13 @@ class StardistLoss(nn.Module): weight: Distance loss weight. 
Total loss will be bce_loss + weight * l1_loss """ - def __init__(self, weight=1.): - + def __init__(self, weight=1.0): super().__init__() self.weight = weight def forward(self, prediction, target, mask=None): # Predicted distances errors are weighted by object prob - if target.shape!=prediction.shape: + if target.shape != prediction.shape: prediction = prediction.squeeze(1) target_prob = target[:, :1] @@ -31,32 +29,35 @@ def forward(self, prediction, target, mask=None): # do not train foreground prediction when mask is supplied predicted_prob = predicted_prob.detach() - l1loss_pp = F.l1_loss(predicted_dist, - target_dist, - reduction='none') - + l1loss_pp = F.l1_loss(predicted_dist, target_dist, reduction="none") + ignore_mask_provided = target_prob.min() < 0 if ignore_mask_provided: # ignore label was supplied - ignore_mask = target_prob >= 0. + ignore_mask = target_prob >= 0.0 # add one to avoid division by zero imsum = ignore_mask.sum() if imsum == 0: print("WARNING: Batch with only ignorelabel encountered!") - return 0*l1loss_pp.sum() + return 0 * l1loss_pp.sum() l1loss = ((target_prob * ignore_mask) * l1loss_pp).sum() / imsum - bceloss = F.binary_cross_entropy_with_logits(predicted_prob[ignore_mask], - target_prob[ignore_mask].float(), - reduction='sum') / imsum + bceloss = ( + F.binary_cross_entropy_with_logits( + predicted_prob[ignore_mask], + target_prob[ignore_mask].float(), + reduction="sum", + ) + / imsum + ) return self.weight * l1loss + bceloss # weight predictions by target probs l1loss = (target_prob * l1loss_pp).mean() - bceloss = F.binary_cross_entropy_with_logits(predicted_prob, - target_prob.float(), - reduction='mean') + bceloss = F.binary_cross_entropy_with_logits( + predicted_prob, target_prob.float(), reduction="mean" + ) - return (self.weight * l1loss) + bceloss \ No newline at end of file + return (self.weight * l1loss) + bceloss diff --git a/cellulus/datasets/__init__.py b/cellulus/datasets/__init__.py index 599fc72..c4bcfa5 100644 --- a/cellulus/datasets/__init__.py +++ b/cellulus/datasets/__init__.py @@ -19,7 +19,7 @@ def get_dataset( crop_size=crop_size, control_point_spacing=control_point_spacing, control_point_jitter=control_point_jitter, - semi_supervised = semi_supervised, - supervised_dataset_config = supervised_dataset_config, - pseudo_dataset_config = pseudo_dataset_config, + semi_supervised=semi_supervised, + supervised_dataset_config=supervised_dataset_config, + pseudo_dataset_config=pseudo_dataset_config, ) diff --git a/cellulus/datasets/zarr_dataset.py b/cellulus/datasets/zarr_dataset.py index 79d168d..18a8455 100644 --- a/cellulus/datasets/zarr_dataset.py +++ b/cellulus/datasets/zarr_dataset.py @@ -5,12 +5,10 @@ from torch.utils.data import IterableDataset from cellulus.configs import DatasetConfig +from cellulus.criterions import stardist_transform from .meta_data import DatasetMetaData -from cellulus.criterions import stardist_transform -import numpy as np - class ZarrDataset(IterableDataset): # type: ignore def __init__( @@ -22,7 +20,7 @@ def __init__( semi_supervised: bool = False, supervised_dataset_config: DatasetConfig = None, pseudo_dataset_config: DatasetConfig = None, - ): + ): """A dataset that serves random samples from a zarr container. 
Args: @@ -62,13 +60,13 @@ def __init__( self.control_point_spacing = control_point_spacing self.control_point_jitter = control_point_jitter self.semi_supervised = semi_supervised - if supervised_dataset_config != None: + if supervised_dataset_config is not None: self.supervised_dataset_config = supervised_dataset_config self.pseudo_dataset_config = pseudo_dataset_config else: self.supervised_dataset_config = None self.pseudo_dataset_config = None - + self.__read_meta_data() assert len(crop_size) == self.num_spatial_dims, ( @@ -93,36 +91,40 @@ def __setup_pipeline(self): # spatial_dims = tuple(range(self.num_dims - self.num_spatial_dims, # self.num_dims)) - if self.supervised_dataset_config != None: + if self.supervised_dataset_config is not None: source_node = gp.ZarrSource( self.dataset_config.container_path, - {self.raw: self.dataset_config.dataset_name, - self.pseudo: self.pseudo_dataset_config.dataset_name, - self.supervised: self.supervised_dataset_config.dataset_name}, - array_specs={self.raw: raw_spec, - self.pseudo: raw_spec, - self.supervised: raw_spec}, + { + self.raw: self.dataset_config.dataset_name, + self.pseudo: self.pseudo_dataset_config.dataset_name, + self.supervised: self.supervised_dataset_config.dataset_name, + }, + array_specs={ + self.raw: raw_spec, + self.pseudo: raw_spec, + self.supervised: raw_spec, + }, ) - # Elastic augmentation is incompatible with labels, because the images get + # Elastic augmentation is incompatible with labels, because the images get # interpolated. If Elastic augmentation is required, the self-supervised # training type needs to be switched from combined labels to separate. # This is because stardist representations survive the Elastic Augment. self.pipeline = ( - source_node - + gp.RandomLocation() - # + gp.ElasticAugment( - # control_point_spacing=(self.control_point_spacing,) - # * self.num_spatial_dims, - # jitter_sigma=(self.control_point_jitter,) * self.num_spatial_dims, - # rotation_interval=(0, math.pi / 2), - # scale_interval=(0.9, 1.1), - # subsample=4, - # spatial_dims=self.num_spatial_dims, - # ) - # + gp.SimpleAugment(mirror_only=spatial_dims, transpose_only=spatial_dims) - ) + source_node + + gp.RandomLocation() + # + gp.ElasticAugment( + # control_point_spacing=(self.control_point_spacing,) + # * self.num_spatial_dims, + # jitter_sigma=(self.control_point_jitter,) * self.num_spatial_dims, + # rotation_interval=(0, math.pi / 2), + # scale_interval=(0.9, 1.1), + # subsample=4, + # spatial_dims=self.num_spatial_dims, + # ) + # + gp.SimpleAugment(mirror_only=spatial_dims, transpose_only=spatial_dims) # noqa: E501 + ) else: source_node = gp.ZarrSource( @@ -130,7 +132,7 @@ def __setup_pipeline(self): {self.raw: self.dataset_config.dataset_name}, array_specs={self.raw: raw_spec}, ) - + self.pipeline = ( source_node + gp.RandomLocation() @@ -143,7 +145,7 @@ def __setup_pipeline(self): subsample=4, spatial_dims=self.num_spatial_dims, ) - # + gp.SimpleAugment(mirror_only=spatial_dims, transpose_only=spatial_dims) + # + gp.SimpleAugment(mirror_only=spatial_dims, transpose_only=spatial_dims) # noqa: E501 ) def __yield_sample(self): @@ -158,37 +160,43 @@ def __yield_sample(self): (0,) * self.num_dims, (1, self.num_channels, *self.crop_size) ) ) - if self.supervised_dataset_config != None: - # if we have a supervised dataset config, we must be training a semi-supervised - # model. 
Therefore we need to add requests to our gp pipeline for pseudo- and - # GT-annotations + if self.supervised_dataset_config is not None: + # if we have a supervised dataset config, we must be training a semi + # -supervised model. Therefore we need to add requests to our gp + # pipeline for pseudo- and GT-annotations request[self.pseudo] = gp.ArraySpec( roi=gp.Roi( - (0,) * self.num_dims, (1, self.num_channels, *self.crop_size) + (0,) * self.num_dims, + (1, self.num_channels, *self.crop_size), ) ) request[self.supervised] = gp.ArraySpec( roi=gp.Roi( - (0,) * self.num_dims, (1, self.num_channels, *self.crop_size) + (0,) * self.num_dims, + (1, self.num_channels, *self.crop_size), ) ) sample = self.pipeline.request_batch(request) - if self.semi_supervised: - # If we are training a semi-supervised model, our dataset class needs to - # return all of the below datasets + if self.semi_supervised: + # If we are training a semi-supervised model, our dataset class + # needs to return all of the below datasets transformed_pseudo = stardist_transform(sample[self.pseudo].data[0]) - transformed_supervised = stardist_transform(sample[self.supervised].data[0]) - yield {'raw':sample[self.raw].data[0], - 'pseudo_stardist':transformed_pseudo, - 'supervised_stardist':transformed_supervised, - 'pseudo_labels':sample[self.pseudo].data[0], - 'supervised_labels':sample[self.supervised].data[0]} - + transformed_supervised = stardist_transform( + sample[self.supervised].data[0] + ) + yield { + "raw": sample[self.raw].data[0], + "pseudo_stardist": transformed_pseudo, + "supervised_stardist": transformed_supervised, + "pseudo_labels": sample[self.pseudo].data[0], + "supervised_labels": sample[self.supervised].data[0], + } + else: - # if we are training a self-supervised model (i.e. standard cellulus), - # we need to return just the raw image data. + # if we are training a self-supervised model (i.e. standard + # cellulus), we need to return just the raw image data. 
yield sample[self.raw].data[0] def __read_meta_data(self): @@ -206,4 +214,4 @@ def get_num_channels(self): return self.num_channels def get_num_spatial_dims(self): - return self.num_spatial_dims \ No newline at end of file + return self.num_spatial_dims diff --git a/cellulus/train_semi_supervised.py b/cellulus/train_semi_supervised.py index 5aadaab..3185486 100644 --- a/cellulus/train_semi_supervised.py +++ b/cellulus/train_semi_supervised.py @@ -1,15 +1,15 @@ import os +import numpy as np import torch import zarr from tqdm import tqdm +from cellulus.criterions.stardist_loss import StardistLoss from cellulus.datasets import get_dataset from cellulus.models import get_model -from cellulus.criterions.stardist_loss import StardistLoss from cellulus.utils import get_logger -import torch.nn as nn -import numpy as np + def semisupervised_train(semi_sup_exp_config): print(semi_sup_exp_config) @@ -22,13 +22,13 @@ def semisupervised_train(semi_sup_exp_config): semi_sup_train_config = semi_sup_exp_config.semi_sup_train_config raw_dataset = get_dataset( - dataset_config = semi_sup_train_config.raw_data_config, - crop_size = tuple(semi_sup_train_config.crop_size), - control_point_spacing = semi_sup_train_config.control_point_spacing, - control_point_jitter = semi_sup_train_config.control_point_jitter, - pseudo_dataset_config = semi_sup_train_config.pseudo_data_config, - supervised_dataset_config = semi_sup_train_config.supervised_data_config, - semi_supervised = True + dataset_config=semi_sup_train_config.raw_data_config, + crop_size=tuple(semi_sup_train_config.crop_size), + control_point_spacing=semi_sup_train_config.control_point_spacing, + control_point_jitter=semi_sup_train_config.control_point_jitter, + pseudo_dataset_config=semi_sup_train_config.pseudo_data_config, + supervised_dataset_config=semi_sup_train_config.supervised_data_config, + semi_supervised=True, ) raw_dataloader = torch.utils.data.DataLoader( @@ -54,9 +54,6 @@ def semisupervised_train(semi_sup_exp_config): if torch.cuda.is_available(): model = model.cuda() - - - criterion = StardistLoss() # set optimizer @@ -89,8 +86,7 @@ def lambda_(iteration): # call `train_iteration` for iteration, batch in tqdm( zip( - range(start_iteration, semi_sup_train_config.max_iterations), - raw_dataloader + range(start_iteration, semi_sup_train_config.max_iterations), raw_dataloader ) ): scheduler = torch.optim.lr_scheduler.LambdaLR( @@ -127,37 +123,35 @@ def lambda_(iteration): iteration, ) -def train_iteration( - batch, - model, - criterion, - optimizer, - alpha = 0.5 -): + +def train_iteration(batch, model, criterion, optimizer, alpha=0.5): model.train() def unqiue_values_to_unique_ints(array): if np.any(array - array.astype(np.int16)): - mapping={v:k for k,v in enumerate(np.unique(array))} - u,inv = np.unique(array,return_inverse = True) + mapping = {v: k for k, v in enumerate(np.unique(array))} + u, inv = np.unique(array, return_inverse=True) Y1 = np.array([mapping[x] for x in u])[inv].reshape(array.shape) array = Y1.astype(np.int16) return array if torch.cuda.is_available(): - batch['raw'] = batch['raw'].to("cuda") - batch['pseudo_stardist'] = batch['pseudo_stardist'].to("cuda") - batch['supervised_stardist'] = batch['supervised_stardist'].to("cuda") - batch['pseudo_labels'] = batch['pseudo_labels'].to("cuda") - batch['supervised_labels'] = batch['supervised_labels'].to("cuda") - batch['supervised_labels'] = torch.tensor(unqiue_values_to_unique_ints(batch['supervised_labels'].cpu().detach().numpy())).to("cuda") + batch["raw"] = 
batch["raw"].to("cuda") + batch["pseudo_stardist"] = batch["pseudo_stardist"].to("cuda") + batch["supervised_stardist"] = batch["supervised_stardist"].to("cuda") + batch["pseudo_labels"] = batch["pseudo_labels"].to("cuda") + batch["supervised_labels"] = batch["supervised_labels"].to("cuda") + batch["supervised_labels"] = torch.tensor( + unqiue_values_to_unique_ints( + batch["supervised_labels"].cpu().detach().numpy() + ) + ).to("cuda") - prediction = model(batch['raw']) - + prediction = model(batch["raw"]) def combine_GT_and_pseudo_labels(gt_labels, pseudo_labels): if gt_labels.shape != pseudo_labels.shape: - print('labelled images are different sizes') + print("labelled images are different sizes") combined_labels = np.zeros(gt_labels.shape) combined_labels[:] = gt_labels.cpu().detach().numpy()[:] @@ -168,34 +162,58 @@ def combine_GT_and_pseudo_labels(gt_labels, pseudo_labels): # for each pseudo label, check it does not intersect with a GT label. # if it doesn't intersect, add it to the combined labels. # if it does intersect, leave it or add to ignore mask? - + # all of the indicies that have this given label value # this_pseudo_label = np.where(np.any(pseudo_labels==pseudo_label_value)) - if np.any(gt_labels.cpu().detach().numpy()[pseudo_labels.cpu().detach().numpy()==pseudo_label_value]): + if np.any( + gt_labels.cpu() + .detach() + .numpy()[pseudo_labels.cpu().detach().numpy() == pseudo_label_value] + ): pass else: - combined_labels = combined_labels + pseudo_labels.cpu().detach().numpy()*(pseudo_labels.cpu().detach().numpy()==pseudo_label_value) + combined_labels = ( + combined_labels + + pseudo_labels.cpu().detach().numpy() + * (pseudo_labels.cpu().detach().numpy() == pseudo_label_value) + ) return combined_labels - gt_labels = batch['supervised_labels'][:,:,:prediction.shape[2],:prediction.shape[3]] - pseudo_labels = batch['pseudo_labels'][:,:,:prediction.shape[2],:prediction.shape[3]] - combined_labels = combine_GT_and_pseudo_labels(gt_labels, pseudo_labels) - use_combined_loss = True if use_combined_loss: from cellulus.criterions import stardist_transform - combined_stardist = torch.tensor(stardist_transform(combined_labels)).cuda().unsqueeze(0) + gt_labels = batch["supervised_labels"][ + :, :, : prediction.shape[2], : prediction.shape[3] + ] + pseudo_labels = batch["pseudo_labels"][ + :, :, : prediction.shape[2], : prediction.shape[3] + ] + combined_labels = combine_GT_and_pseudo_labels(gt_labels, pseudo_labels) + + combined_stardist = ( + torch.tensor(stardist_transform(combined_labels)).cuda().unsqueeze(0) + ) loss = criterion(prediction, combined_stardist) + else: - supervised_loss = criterion(prediction, batch['supervised_stardist'][:,:,:prediction.shape[2],:prediction.shape[3]]) - semisupervised_loss = criterion(prediction, batch['pseudo_stardist'][:,:,:prediction.shape[2],:prediction.shape[3]]) - loss = (alpha*supervised_loss) + ((1-alpha)*semisupervised_loss) + supervised_loss = criterion( + prediction, + batch["supervised_stardist"][ + :, :, : prediction.shape[2], : prediction.shape[3] + ], + ) + semisupervised_loss = criterion( + prediction, + batch["pseudo_stardist"][ + :, :, : prediction.shape[2], : prediction.shape[3] + ], + ) + loss = (alpha * supervised_loss) + ((1 - alpha) * semisupervised_loss) - optimizer.zero_grad() loss.backward() optimizer.step() @@ -212,29 +230,33 @@ def save_model(state, iteration, is_lowest=False): def save_snapshot(batch, prediction, iteration): - num_spatial_dims = len(batch['raw'].shape) - 2 + num_spatial_dims = len(batch["raw"].shape) 
- 2 axis_names = ["s", "c"] + ["t", "z", "y", "x"][-num_spatial_dims:] prediction_offset = tuple( (a - b) / 2 for a, b in zip( - batch['raw'].shape[-num_spatial_dims:], prediction.shape[-num_spatial_dims:] + batch["raw"].shape[-num_spatial_dims:], prediction.shape[-num_spatial_dims:] ) ) f = zarr.open("semi_supervised_snapshots.zarr", "a") - f[f"{iteration}/raw"] = batch['raw'].detach().cpu().numpy() + f[f"{iteration}/raw"] = batch["raw"].detach().cpu().numpy() f[f"{iteration}/raw"].attrs["axis_names"] = axis_names - f[f"{iteration}/pseudo_stardist"] = batch['pseudo_stardist'].detach().cpu().numpy() + f[f"{iteration}/pseudo_stardist"] = batch["pseudo_stardist"].detach().cpu().numpy() f[f"{iteration}/pseudo_stardist"].attrs["axis_names"] = axis_names - f[f"{iteration}/pseudo_labels"] = batch['pseudo_labels'].detach().cpu().numpy() + f[f"{iteration}/pseudo_labels"] = batch["pseudo_labels"].detach().cpu().numpy() f[f"{iteration}/pseudo_labels"].attrs["axis_names"] = axis_names - f[f"{iteration}/supervised_stardist"] = batch['supervised_stardist'].detach().cpu().numpy() + f[f"{iteration}/supervised_stardist"] = ( + batch["supervised_stardist"].detach().cpu().numpy() + ) f[f"{iteration}/supervised_stardist"].attrs["supervised"] = axis_names - f[f"{iteration}/supervised_labels"] = batch['supervised_labels'].detach().cpu().numpy() + f[f"{iteration}/supervised_labels"] = ( + batch["supervised_labels"].detach().cpu().numpy() + ) f[f"{iteration}/supervised_labels"].attrs["supervised"] = axis_names f[f"{iteration}/prediction"] = prediction.detach().cpu().numpy() @@ -242,6 +264,3 @@ def save_snapshot(batch, prediction, iteration): f[f"{iteration}/prediction"].attrs["offset"] = prediction_offset print(f"Snapshot saved at iteration {iteration}") - - -
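
For reference, a minimal, self-contained sketch of how the pieces introduced in this series fit together at training time: stardist_transform converts an integer label image into a (1 + n_rays)-channel target (one object-probability channel followed by 16 ray distances, matching n_rays=16 and the model's out_channels=17), and the separate-loss branch of train_iteration (taken when use_combined_loss is False) blends a supervised and a pseudo-label StardistLoss term with weight alpha. The toy label image, tensor shapes, the random pseudo target and the value of alpha below are illustrative placeholders, not values taken from this patch.

    import numpy as np
    import torch

    from cellulus.criterions import stardist_transform
    from cellulus.criterions.stardist_loss import StardistLoss

    # Toy ground-truth label image with a single square "cell" (placeholder data).
    gt_labels = np.zeros((64, 64), dtype=np.uint16)
    gt_labels[10:30, 10:30] = 1

    # (17, 64, 64): channel 0 = object probability, channels 1..16 = ray distances.
    supervised_target = torch.from_numpy(stardist_transform(gt_labels, n_rays=16))
    supervised_target = supervised_target[None].float()  # (1, 17, 64, 64)

    # Stand-ins for the pseudo-label target and the 17-channel model output.
    pseudo_target = torch.rand_like(supervised_target)
    prediction = torch.randn(1, 17, 64, 64, requires_grad=True)

    # Blend the two StardistLoss terms as in the non-combined branch of
    # train_iteration(): alpha * supervised + (1 - alpha) * pseudo.
    criterion = StardistLoss(weight=1.0)
    alpha = 0.5
    supervised_loss = criterion(prediction, supervised_target)
    pseudo_loss = criterion(prediction, pseudo_target)
    loss = alpha * supervised_loss + (1 - alpha) * pseudo_loss
    loss.backward()

With use_combined_loss = True (as hardcoded in this patch), the same criterion is instead applied once to stardist_transform of the ground-truth labels merged with the non-overlapping pseudo labels.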