Skip to content

Commit

Permalink
improved handling
Browse files Browse the repository at this point in the history
  • Loading branch information
nikolaitennant committed Oct 26, 2024
1 parent 2375c03 commit 295e505
Show file tree
Hide file tree
Showing 28 changed files with 2,003 additions and 184 deletions.
Binary file modified .DS_Store
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
{
"Test": {
"Accuracy": "94.35%",
"Precision": "93.94%",
"Recall": "93.73%",
"F1": "93.82%",
"AUC": "99.48%"
"Accuracy": "48.62%",
"Precision": "45.35%",
"Recall": "42.41%",
"F1": "39.16%",
"AUC": "75.90%"
},
"Baseline": {
"Accuracy": "33.78%",
"Precision": "8.44%",
"Accuracy": "34.83%",
"Precision": "8.71%",
"Recall": "25.00%",
"F1": "12.62%"
"F1": "12.92%"
}
}
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
{
"Test": {
"Accuracy": "94.35%",
"Precision": "93.94%",
"Recall": "93.73%",
"F1": "93.82%",
"AUC": "99.48%"
"Accuracy": "48.62%",
"Precision": "45.35%",
"Recall": "42.41%",
"F1": "39.16%",
"AUC": "75.90%"
},
"Baseline": {
"Accuracy": "33.78%",
"Precision": "8.44%",
"Accuracy": "34.83%",
"Precision": "8.71%",
"Recall": "25.00%",
"F1": "12.62%"
"F1": "12.92%"
}
}
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified Code/__pycache__/Config.cpython-39.pyc
Binary file not shown.
Binary file modified Code/__pycache__/Model.cpython-39.pyc
Binary file not shown.
Binary file modified Code/__pycache__/Utilities.cpython-39.pyc
Binary file not shown.
Binary file modified Code/__pycache__/eda.cpython-39.pyc
Binary file not shown.
Binary file modified Code/__pycache__/interpreter.cpython-39.pyc
Binary file not shown.
Binary file modified Code/__pycache__/pipeline_manager.cpython-39.pyc
Binary file not shown.
Binary file modified Code/__pycache__/preprocess.cpython-39.pyc
Binary file not shown.
Binary file modified Code/__pycache__/visuals.cpython-39.pyc
Binary file not shown.
12 changes: 6 additions & 6 deletions Code/config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# config.py
config_dict = {
"Device": {
"processor": "Other", # Options: 'Other' or 'M' (M is for Mac M1/M2/M3 processors)
"processor": "M", # Options: 'Other' or 'M' (M is for Mac M1/M2/M3 processors)
},
"FileLocations": {
"training_file": "fly_train.h5ad", # Name of the training data file
Expand All @@ -22,8 +22,8 @@
"sex_type": "all", # Options: 'all', 'male', 'female'
},
"Sampling": {
"num_samples":289981, # Number of samples (cells) for training (total = 289981)
"num_variables": 15992, # Number of variables (genes) for training (total = 15992)
"num_samples":2899, # Number of samples (cells) for training (total = 289981)
"num_variables":1590, # Number of variables (genes) for training (total = 15992)
},
"Filtering": {
"include_mixed_sex": False, # Options: True, False
Expand All @@ -50,7 +50,7 @@
"enabled": False, # Options: True, False
},
"ModelManagement": {
"load_model": True, # Options: True, False
"load_model": False, # Options: True, False
},
"Preprocessing": {
"required": True, # Options: True, False
Expand Down Expand Up @@ -82,8 +82,8 @@
"FeatureImportanceAndVisualizations": {
"run_visualization": True, # Options: True, False
"run_interpreter": True, # Options: True, False (SHAP)
"load_SHAP": True, # Options: True to load SHAP values, False to compute them, only works if run_interpreter is True
"reference_size": 5000, # Reference data size for SHAP
"load_SHAP": False, # Options: True to load SHAP values, False to compute them, only works if run_interpreter is True
"reference_size": 100, # Reference data size for SHAP
"save_predictions": False, # Options: True, False; (Model predictions csv file)
},
"DataSplit": {
Expand Down
115 changes: 100 additions & 15 deletions Code/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,54 @@ def compute_or_load_shap_values(self):

return shap_values, squeezed_test_data

# def compute_shap_values(self):
# """
# Compute SHAP values for model interpretation.

# Returns:
# - tuple: A tuple containing the SHAP values and the corresponding SHAP test data.
# """

# # Access the model type from the configuration
# model_type = self.config.DataParameters.GeneralSettings.model_type.lower()

# # Determine the explainer to use based on the model type
# if model_type in ["mlp", "cnn"]:
# # For neural network models, use GradientExplainer
# explainer = shap.GradientExplainer(self.model, self.reference_data)
# elif model_type in ["xgboost", "randomforest"]:
# # For tree-based models, use TreeExplainer
# explainer = shap.TreeExplainer(self.model)
# else:
# # For linear models, use LinearExplainer
# explainer = shap.LinearExplainer(self.model, self.reference_data)

# # Compute SHAP values
# shap_values = explainer.shap_values(self.test_data)

# # Adjust SHAP values and test data shapes if necessary
# if isinstance(shap_values, list):
# squeezed_shap_values = [
# np.squeeze(val, axis=1) if val.ndim > 3 else val for val in shap_values
# ]
# else:
# squeezed_shap_values = (
# np.squeeze(shap_values, axis=1) if shap_values.ndim > 3 else shap_values
# )

# # Convert the SHAP values to a list of arrays for compatibility with the rest of the code
# squeezed_shap_values = [
# squeezed_shap_values[:, :, i] for i in range(squeezed_shap_values.shape[2])
# ]

# squeezed_test_data = (
# self.test_data
# if self.test_data.ndim <= 2
# else np.squeeze(self.test_data, axis=1)
# )

# return squeezed_shap_values, squeezed_test_data

def compute_shap_values(self):
    """
    Compute SHAP values for model interpretation.

    The explainer is selected from the configured model type:
    ``GradientExplainer`` for "mlp"/"cnn", ``TreeExplainer`` for
    "xgboost"/"randomforest", and ``LinearExplainer`` for anything else.

    The explainer output is normalised to a list of arrays:
    - a list returned by the explainer is kept as a list (each element is
      squeezed on axis 1 when it has more than three dimensions);
    - an array with three or more dimensions is split along its third axis;
    - an array with fewer than three dimensions is wrapped in a
      one-element list.

    Returns:
    - tuple: ``(shap_values, test_data)`` where ``shap_values`` is the list
      described above (presumably one array per model output/class --
      confirm against the callers) and ``test_data`` is ``self.test_data``
      with axis 1 squeezed out when it has more than two dimensions.
    """

    print("=== Starting compute_shap_values ===")

    # Access the model type from the configuration
    model_type = self.config.DataParameters.GeneralSettings.model_type.lower()
    print(f"[DEBUG] Model type: {model_type} (type: {type(model_type)})")

    # Determine the explainer to use based on the model type
    if model_type in ["mlp", "cnn"]:
        print("[DEBUG] Using GradientExplainer for neural network models (MLP/CNN).")
        explainer = shap.GradientExplainer(self.model, self.reference_data)
        print(f"[DEBUG] GradientExplainer created: {explainer} (type: {type(explainer)})")
    elif model_type in ["xgboost", "randomforest"]:
        print("[DEBUG] Using TreeExplainer for tree-based models (XGBoost/RandomForest).")
        explainer = shap.TreeExplainer(self.model)
        print(f"[DEBUG] TreeExplainer created: {explainer} (type: {type(explainer)})")
    else:
        print("[DEBUG] Using LinearExplainer for linear models.")
        explainer = shap.LinearExplainer(self.model, self.reference_data)
        print(f"[DEBUG] LinearExplainer created: {explainer} (type: {type(explainer)})")

    # Compute SHAP values
    print("[DEBUG] Computing SHAP values...")
    shap_values = explainer.shap_values(self.test_data)
    print(f"[DEBUG] SHAP values computed: {type(shap_values)}")

    # The explainer may hand back either a list of arrays or a single array;
    # every later step has to treat the two cases separately.
    values_are_list = isinstance(shap_values, list)

    # Inspect SHAP values
    if values_are_list:
        print(f"[DEBUG] SHAP values is a list with {len(shap_values)} elements.")
        for idx, val in enumerate(shap_values):
            print(f"  [DEBUG] shap_values[{idx}] shape: {val.shape}, type: {type(val)}")
    else:
        print(f"[DEBUG] SHAP values shape: {shap_values.shape}, type: {type(shap_values)}")

    # Adjust SHAP values shapes if necessary
    print("[DEBUG] Adjusting SHAP values shapes if necessary...")
    if values_are_list:
        squeezed_shap_values = [
            np.squeeze(val, axis=1) if val.ndim > 3 else val for val in shap_values
        ]
        for idx, val in enumerate(squeezed_shap_values):
            print(f"  [DEBUG] Squeezed shap_values[{idx}] shape: {val.shape}")
    elif shap_values.ndim > 3:
        squeezed_shap_values = np.squeeze(shap_values, axis=1)
        print(f"[DEBUG] Squeezed shap_values shape: {squeezed_shap_values.shape}")
    else:
        squeezed_shap_values = shap_values
        print("[DEBUG] No squeezing applied to shap_values.")

    # Convert the SHAP values to a list of arrays for compatibility with the
    # rest of the code.
    print("[DEBUG] Converting squeezed SHAP values to a list of arrays...")
    if values_are_list:
        # Already a list of arrays: a Python list supports neither
        # ``[:, :, i]`` indexing nor ``.shape``, so it must not be re-sliced.
        print(
            f"[DEBUG] SHAP values already a list with {len(squeezed_shap_values)} "
            "elements; no conversion needed."
        )
    elif squeezed_shap_values.ndim >= 3:
        squeezed_shap_values = [
            squeezed_shap_values[:, :, i] for i in range(squeezed_shap_values.shape[2])
        ]
        print(f"[DEBUG] Converted SHAP values to list with {len(squeezed_shap_values)} elements.")
    else:
        # No third axis to split on (``shape[2]`` would raise IndexError);
        # expose the single array as a one-element list instead.
        squeezed_shap_values = [squeezed_shap_values]
        print(f"[DEBUG] Converted SHAP values to list with {len(squeezed_shap_values)} elements.")

    # Adjust test data shape if necessary
    print("[DEBUG] Adjusting test data shape if necessary...")
    if self.test_data.ndim <= 2:
        squeezed_test_data = self.test_data
        print(f"[DEBUG] Test data has {self.test_data.ndim} dimensions. No squeezing applied.")
    else:
        squeezed_test_data = np.squeeze(self.test_data, axis=1)
        print(f"[DEBUG] Squeezed test data shape: {squeezed_test_data.shape}")

    print("=== Finished compute_shap_values ===")
    print(f"[DEBUG] Returning SHAP values: List of {len(squeezed_shap_values)} arrays.")
    print(f"[DEBUG] Returning test data: shape {squeezed_test_data.shape}, type {type(squeezed_test_data)}")

    return squeezed_shap_values, squeezed_test_data

Expand Down
Loading

0 comments on commit 295e505

Please sign in to comment.