-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
34559c3
commit b122808
Showing
19 changed files
with
3,485 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from __future__ import absolute_import | ||
from src import * | ||
from . import * | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,230 @@ | ||
""" | ||
""" | ||
import os | ||
import glob | ||
import numpy as np | ||
from numpy import ndarray | ||
import matplotlib.pyplot as plt | ||
import nibabel as nib | ||
from typing import List, Tuple | ||
from numpy.random import seed | ||
|
||
# seed random number generator | ||
seed(1) | ||
|
||
|
||
class DataLoader: | ||
""" | ||
read preprocessed pet and gt MIP data for training | ||
""" | ||
|
||
def __init__(self, data_dir: str, ids_to_read: ndarray = None, shuffle=True, training: bool = True): | ||
self.data_dir = data_dir | ||
self.ids_to_read = ids_to_read | ||
self.shuffle = shuffle | ||
self.training = training | ||
|
||
def get_batch_of_data(self): | ||
""" | ||
data structure: | ||
-- main directory | ||
------case Name: | ||
-- pet.nii.gz | ||
-- gt.nii.gz | ||
--Given list of training and testing on .text files | ||
-- train.text | ||
-- valid.text | ||
""" | ||
|
||
# check directory | ||
self.directory_exist(self.data_dir) | ||
|
||
# get all names of the directories under data_dir | ||
case_ids = os.listdir(self.data_dir) | ||
|
||
# store batch data | ||
image_batch, ground_truth_batch = [], [] | ||
|
||
# if there are file in data dir | ||
if not len(case_ids): | ||
raise Exception("No files found in %s" % self.data_dir) | ||
|
||
# else continue getting.reading the files | ||
for get_id in list(case_ids): | ||
if str(get_id) in list(self.ids_to_read): | ||
try: | ||
# consider there four images in each folder name get_id: | ||
# e.g. : coronal (gt_1, pet_1) and sagittal (gt_0, pet_0) | ||
current_dir = os.path.join(self.data_dir, str(get_id)) | ||
# read sagittal and coronal as independent images | ||
pet_sagittla_coronal, gt_sagittal_coronal = self.get_nii_files_path(current_dir) | ||
|
||
# pet, normalization, standardization | ||
if len(pet_sagittla_coronal): # if image is read | ||
pet_sagittla_coronal = self.data_normalization_standardization(pet_sagittla_coronal, | ||
z_score=True, | ||
z_score_include_zeros=False) | ||
|
||
gt_sagittal_coronal = self.data_normalization_standardization(gt_sagittal_coronal, threshold=True) | ||
|
||
# display or save samples | ||
# self.mip_show(pet=pet_sagittla_coronal, gt=gt_sagittal_coronal, identifier=str(get_id)) | ||
|
||
# collect all images with case_id | ||
if not bool(len(image_batch)): # if it is empty; first time | ||
image_batch = pet_sagittla_coronal | ||
ground_truth_batch = gt_sagittal_coronal | ||
else: | ||
image_batch = np.concatenate((image_batch, pet_sagittla_coronal), axis=0) | ||
ground_truth_batch = np.concatenate((ground_truth_batch, gt_sagittal_coronal), axis=0) | ||
except: | ||
print('Not read %s' %(str(get_id))) | ||
|
||
return [image_batch, ground_truth_batch] | ||
|
||
@staticmethod | ||
def directory_exist(dir_check: str = None) -> None: | ||
""" | ||
:param dir_check: | ||
""" | ||
if os.path.exists(dir_check): | ||
# print("The directory %s does exist \n" % dir_check) | ||
pass | ||
else: | ||
raise Exception( | ||
"Please provide the correct path to the processed data ! \n %s not found \n" % (dir_check)) | ||
|
||
@staticmethod | ||
def mip_show(pet: ndarray = None, gt: ndarray = None, identifier: str = None) -> None: | ||
""" | ||
:param pet: | ||
:param gt: | ||
:param identifier: | ||
:return: | ||
""" | ||
# consider axis 0 for sagittal and axis 1 for coronal views | ||
fig, axs = plt.subplots(1, 4, figsize=(15, 15)) | ||
plt.title(str(identifier)) | ||
try: | ||
pet = np.squeeze(pet) | ||
gt = np.squeeze(gt) | ||
except: | ||
pass | ||
|
||
axs[0].imshow(np.rot90(np.log(pet[0] + 1))) | ||
axs[0].set_title('pet_project_on_axis_0') | ||
axs[1].imshow(np.rot90(np.log(gt[0] + 1))) | ||
axs[1].set_title('gt_project_on_axis_0') | ||
axs[2].imshow(np.rot90(np.log(pet[1] + 1))) | ||
axs[2].set_title('project_on_axis_1') | ||
axs[3].imshow(np.rot90(np.log(gt[1] + 1))) | ||
axs[3].set_title('gt_project_on_axis_1') | ||
plt.show() | ||
|
||
@staticmethod | ||
def get_nii_files_path(data_directory: str) -> List[ndarray]: | ||
""" | ||
read .nii or .nii.gz files from a given folder of path data_directory | ||
:param data_directory: | ||
:return: | ||
""" | ||
# more than one .nii or .nii.gz is found in the folder the first will be returned | ||
types = ('/*.nii', '/*.nii.gz') # the tuple of file types | ||
nii_paths = [] | ||
for files in types: | ||
nii_paths.extend([i for i in glob.glob(str(data_directory) + files)]) | ||
|
||
pet, gt = [], [] | ||
if not len(nii_paths): # if no file exists that ends wtih .nii.gz or .nii | ||
# raise Exception("No .nii or .nii.gz found in %s dirctory" % data_directory) | ||
pass | ||
else: | ||
# assuming the folder contains coronal mips: pet_1, gt_1, and sagittal mips: pet_0, gt_0, | ||
pet_saggital, pet_coronal, gt_saggital, gt_coronal = [], [], [], [] | ||
for path in list(nii_paths): | ||
# get the base name: means the file name | ||
identifier_base_name = str(os.path.basename(path)).split('.')[0] | ||
if "pet_0" == str(identifier_base_name): | ||
pet_saggital = np.asanyarray(nib.load(path).dataobj) | ||
pet_saggital = np.expand_dims(pet_saggital, axis=0) | ||
|
||
elif "pet_1" == str(identifier_base_name): | ||
pet_coronal = np.asanyarray(nib.load(path).dataobj) | ||
pet_coronal = np.expand_dims(pet_coronal, axis=0) | ||
|
||
if "gt_0" == str(identifier_base_name): | ||
gt_saggital = np.asanyarray(nib.load(path).dataobj) | ||
gt_saggital = np.expand_dims(gt_saggital, axis=0) | ||
|
||
elif "gt_1" == str(identifier_base_name): | ||
gt_coronal = np.asanyarray(nib.load(path).dataobj) | ||
gt_coronal = np.expand_dims(gt_coronal, axis=0) | ||
|
||
# concatenate coronal and sagita images | ||
# show | ||
pet = np.concatenate((pet_saggital, pet_coronal), axis=0) | ||
gt = np.concatenate((gt_saggital, gt_coronal), axis=0) | ||
return [pet, gt] | ||
|
||
@staticmethod | ||
def z_score(image: ndarray, include_zeros: bool = False): | ||
""" | ||
:param image: | ||
:param include_zeros: | ||
:return: | ||
""" | ||
# include zeros | ||
if include_zeros: | ||
image = (image - np.mean(image)) / (np.std(image) + 1e-8) | ||
else: | ||
# Don't include zeros | ||
means = np.true_divide(image.sum(), (image != 0).sum()) | ||
stds = np.nanstd(np.where(np.isclose(image, 0), np.nan, image)) | ||
image = (image - means) / (stds + 1e-8) | ||
return image | ||
|
||
def data_normalization_standardization(self, data: ndarray, threshold: bool = False, z_score: bool = False, | ||
z_score_include_zeros: bool = False, | ||
min_max_scale: bool = False, log_transform: bool = False) -> List[ndarray]: | ||
""" | ||
Data normalization and standardization function | ||
:param data: | ||
:param threshold: | ||
:param z_score: | ||
:param z_score_include_zeros: | ||
:param min_max_scale: | ||
:param log_transform: | ||
:return: | ||
""" | ||
|
||
if not isinstance(data, List): | ||
data = np.array(data) | ||
|
||
# groundtruh > 0 is 1 and <=0 is 0 | ||
if threshold: | ||
data[data > 0] = 1 | ||
|
||
if z_score: | ||
data = self.z_score(data, include_zeros=z_score_include_zeros) | ||
|
||
if min_max_scale: | ||
data = (data - min(data)) / (max(data) - min(data)) | ||
|
||
if log_transform: | ||
data = np.log(data + 1) | ||
|
||
return data | ||
|
||
|
||
if __name__ == '__main__': | ||
# for Example | ||
print("data_loader for preprocessed coronal and sagittal MIPs, pet, and gt") | ||
data_dir = "../data/vienna_default_MIP_dir/" | ||
ids_to_read = os.listdir(data_dir) | ||
|
||
data_loader = DataLoader(data_dir=data_dir, ids_to_read=ids_to_read) | ||
loaded_data = data_loader.get_batch_of_data() | ||
print(np.array(loaded_data).shape) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from __future__ import absolute_import | ||
from . import * | ||
from .losses import * |
Oops, something went wrong.