Skip to content

Commit

Permalink
source file
Browse files Browse the repository at this point in the history
  • Loading branch information
KibromBerihu committed May 26, 2022
1 parent 34559c3 commit b122808
Show file tree
Hide file tree
Showing 19 changed files with 3,485 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/LFBNet/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from __future__ import absolute_import
from src import *
from . import *

230 changes: 230 additions & 0 deletions src/LFBNet/data_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
"""
"""
import os
import glob
import numpy as np
from numpy import ndarray
import matplotlib.pyplot as plt
import nibabel as nib
from typing import List, Tuple
from numpy.random import seed

# seed random number generator
seed(1)


class DataLoader:
"""
read preprocessed pet and gt MIP data for training
"""

def __init__(self, data_dir: str, ids_to_read: ndarray = None, shuffle=True, training: bool = True):
self.data_dir = data_dir
self.ids_to_read = ids_to_read
self.shuffle = shuffle
self.training = training

def get_batch_of_data(self):
"""
data structure:
-- main directory
------case Name:
-- pet.nii.gz
-- gt.nii.gz
--Given list of training and testing on .text files
-- train.text
-- valid.text
"""

# check directory
self.directory_exist(self.data_dir)

# get all names of the directories under data_dir
case_ids = os.listdir(self.data_dir)

# store batch data
image_batch, ground_truth_batch = [], []

# if there are file in data dir
if not len(case_ids):
raise Exception("No files found in %s" % self.data_dir)

# else continue getting.reading the files
for get_id in list(case_ids):
if str(get_id) in list(self.ids_to_read):
try:
# consider there four images in each folder name get_id:
# e.g. : coronal (gt_1, pet_1) and sagittal (gt_0, pet_0)
current_dir = os.path.join(self.data_dir, str(get_id))
# read sagittal and coronal as independent images
pet_sagittla_coronal, gt_sagittal_coronal = self.get_nii_files_path(current_dir)

# pet, normalization, standardization
if len(pet_sagittla_coronal): # if image is read
pet_sagittla_coronal = self.data_normalization_standardization(pet_sagittla_coronal,
z_score=True,
z_score_include_zeros=False)

gt_sagittal_coronal = self.data_normalization_standardization(gt_sagittal_coronal, threshold=True)

# display or save samples
# self.mip_show(pet=pet_sagittla_coronal, gt=gt_sagittal_coronal, identifier=str(get_id))

# collect all images with case_id
if not bool(len(image_batch)): # if it is empty; first time
image_batch = pet_sagittla_coronal
ground_truth_batch = gt_sagittal_coronal
else:
image_batch = np.concatenate((image_batch, pet_sagittla_coronal), axis=0)
ground_truth_batch = np.concatenate((ground_truth_batch, gt_sagittal_coronal), axis=0)
except:
print('Not read %s' %(str(get_id)))

return [image_batch, ground_truth_batch]

@staticmethod
def directory_exist(dir_check: str = None) -> None:
"""
:param dir_check:
"""
if os.path.exists(dir_check):
# print("The directory %s does exist \n" % dir_check)
pass
else:
raise Exception(
"Please provide the correct path to the processed data ! \n %s not found \n" % (dir_check))

@staticmethod
def mip_show(pet: ndarray = None, gt: ndarray = None, identifier: str = None) -> None:
"""
:param pet:
:param gt:
:param identifier:
:return:
"""
# consider axis 0 for sagittal and axis 1 for coronal views
fig, axs = plt.subplots(1, 4, figsize=(15, 15))
plt.title(str(identifier))
try:
pet = np.squeeze(pet)
gt = np.squeeze(gt)
except:
pass

axs[0].imshow(np.rot90(np.log(pet[0] + 1)))
axs[0].set_title('pet_project_on_axis_0')
axs[1].imshow(np.rot90(np.log(gt[0] + 1)))
axs[1].set_title('gt_project_on_axis_0')
axs[2].imshow(np.rot90(np.log(pet[1] + 1)))
axs[2].set_title('project_on_axis_1')
axs[3].imshow(np.rot90(np.log(gt[1] + 1)))
axs[3].set_title('gt_project_on_axis_1')
plt.show()

@staticmethod
def get_nii_files_path(data_directory: str) -> List[ndarray]:
"""
read .nii or .nii.gz files from a given folder of path data_directory
:param data_directory:
:return:
"""
# more than one .nii or .nii.gz is found in the folder the first will be returned
types = ('/*.nii', '/*.nii.gz') # the tuple of file types
nii_paths = []
for files in types:
nii_paths.extend([i for i in glob.glob(str(data_directory) + files)])

pet, gt = [], []
if not len(nii_paths): # if no file exists that ends wtih .nii.gz or .nii
# raise Exception("No .nii or .nii.gz found in %s dirctory" % data_directory)
pass
else:
# assuming the folder contains coronal mips: pet_1, gt_1, and sagittal mips: pet_0, gt_0,
pet_saggital, pet_coronal, gt_saggital, gt_coronal = [], [], [], []
for path in list(nii_paths):
# get the base name: means the file name
identifier_base_name = str(os.path.basename(path)).split('.')[0]
if "pet_0" == str(identifier_base_name):
pet_saggital = np.asanyarray(nib.load(path).dataobj)
pet_saggital = np.expand_dims(pet_saggital, axis=0)

elif "pet_1" == str(identifier_base_name):
pet_coronal = np.asanyarray(nib.load(path).dataobj)
pet_coronal = np.expand_dims(pet_coronal, axis=0)

if "gt_0" == str(identifier_base_name):
gt_saggital = np.asanyarray(nib.load(path).dataobj)
gt_saggital = np.expand_dims(gt_saggital, axis=0)

elif "gt_1" == str(identifier_base_name):
gt_coronal = np.asanyarray(nib.load(path).dataobj)
gt_coronal = np.expand_dims(gt_coronal, axis=0)

# concatenate coronal and sagita images
# show
pet = np.concatenate((pet_saggital, pet_coronal), axis=0)
gt = np.concatenate((gt_saggital, gt_coronal), axis=0)
return [pet, gt]

@staticmethod
def z_score(image: ndarray, include_zeros: bool = False):
"""
:param image:
:param include_zeros:
:return:
"""
# include zeros
if include_zeros:
image = (image - np.mean(image)) / (np.std(image) + 1e-8)
else:
# Don't include zeros
means = np.true_divide(image.sum(), (image != 0).sum())
stds = np.nanstd(np.where(np.isclose(image, 0), np.nan, image))
image = (image - means) / (stds + 1e-8)
return image

def data_normalization_standardization(self, data: ndarray, threshold: bool = False, z_score: bool = False,
z_score_include_zeros: bool = False,
min_max_scale: bool = False, log_transform: bool = False) -> List[ndarray]:
"""
Data normalization and standardization function
:param data:
:param threshold:
:param z_score:
:param z_score_include_zeros:
:param min_max_scale:
:param log_transform:
:return:
"""

if not isinstance(data, List):
data = np.array(data)

# groundtruh > 0 is 1 and <=0 is 0
if threshold:
data[data > 0] = 1

if z_score:
data = self.z_score(data, include_zeros=z_score_include_zeros)

if min_max_scale:
data = (data - min(data)) / (max(data) - min(data))

if log_transform:
data = np.log(data + 1)

return data


if __name__ == '__main__':
# for Example
print("data_loader for preprocessed coronal and sagittal MIPs, pet, and gt")
data_dir = "../data/vienna_default_MIP_dir/"
ids_to_read = os.listdir(data_dir)

data_loader = DataLoader(data_dir=data_dir, ids_to_read=ids_to_read)
loaded_data = data_loader.get_batch_of_data()
print(np.array(loaded_data).shape)
3 changes: 3 additions & 0 deletions src/LFBNet/losses/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from __future__ import absolute_import
from . import *
from .losses import *
Loading

0 comments on commit b122808

Please sign in to comment.