-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgenerate_lhs_scatterplot_matrix.py
More file actions
191 lines (155 loc) · 6.54 KB
/
generate_lhs_scatterplot_matrix.py
File metadata and controls
191 lines (155 loc) · 6.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
#!/usr/bin/env python3
"""
Generate scatterplot matrix showing pairwise ingredient interactions for v6 LHS design.
Creates a figure with all pairwise combinations of the 6 varied components.
Author: Claude Code
Date: 2026-01-21
"""
import json
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
# Filesystem locations used by this script.
# BASE_DIR is the repository root; the other two paths are resolved under it.
BASE_DIR = Path(
    "/Users/marcin/Documents/VIMSS/ontology/KG-Hub/KG-Microbe/MicroGrowAgents/MicroGrowAgents"
)
# Directory that receives the generated PNG/PDF figures.
OUTPUT_DIR = BASE_DIR.joinpath("data/designs/MP_latinhypercube/plate_designs_v6")
# JSON file holding the complete v6 formulations that get plotted.
INPUT_FILE = BASE_DIR.joinpath(
    "data/designs/MP_latinhypercube/plate_designs_v6/complete_formulations_v6.json"
)
# Pairs whose |r| exceeds this value are highlighted in the figure and listed
# in the console summary.
_HIGH_CORR_THRESHOLD = 0.3


def _unique_conditions(formulations):
    """Return the first formulation seen for each 'v6_condition_number'."""
    first_seen = {}
    for entry in formulations:
        # setdefault keeps the earliest entry, so later replicates are dropped.
        first_seen.setdefault(entry['v6_condition_number'], entry)
    return list(first_seen.values())


def _concentration_frame(unique_formulations):
    """Build a DataFrame with one row per condition and one column per varied component."""
    rows = []
    for entry in unique_formulations:
        varied = entry['varied_components']
        rows.append({
            'Total_Phosphate_mM': varied['Total_Phosphate']['concentration'],
            'NH4SO4_mM': varied['(NH₄)₂SO₄']['concentration'],
            'CoCl2_uM': varied['CoCl₂·6H₂O']['concentration'],
            'Succinate_mM': varied['Succinate']['concentration'],
            'Methanol_mM': varied['Methanol']['concentration'],
            'PQQ_uM': varied['PQQ']['concentration'] / 1000.0,  # Convert nM to µM
        })
    return pd.DataFrame(rows)


def _plot_scatter_matrix(df, labels, n_conditions):
    """Draw the pairwise grid (histograms on the diagonal) and return the figure.

    Args:
        df: DataFrame whose columns are the components to compare.
        labels: Mapping from column name to the axis label to display.
        n_conditions: Number of unique conditions, shown in the title.
    """
    sns.set_style("white")
    sns.set_context("paper", font_scale=0.9)

    fig = plt.figure(figsize=(16, 16))
    columns = list(df.columns)
    n_components = len(columns)

    for i, col_y in enumerate(columns):
        for j, col_x in enumerate(columns):
            ax = fig.add_subplot(n_components, n_components, i * n_components + j + 1)

            if i == j:
                # Diagonal: histogram of the single component.
                ax.hist(df[col_x], bins=15, color='steelblue', alpha=0.7, edgecolor='black')
                ax.set_ylabel('Count', fontsize=8)

                # Annotate mean (dashed line) and range (text box).
                mean_val = df[col_x].mean()
                min_val = df[col_x].min()
                max_val = df[col_x].max()
                ax.axvline(mean_val, color='red', linestyle='--', linewidth=1, alpha=0.7)
                ax.text(0.98, 0.95, f'Mean: {mean_val:.2f}\nRange: {min_val:.2f}-{max_val:.2f}',
                        transform=ax.transAxes, ha='right', va='top', fontsize=7,
                        bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
            else:
                # Off-diagonal: scatter plot of the two components.
                ax.scatter(df[col_x], df[col_y], alpha=0.6, s=20, color='steelblue',
                           edgecolor='black', linewidth=0.5)

                # Annotate the correlation; emphasize it when above threshold.
                corr = df[col_x].corr(df[col_y])
                is_high = abs(corr) > _HIGH_CORR_THRESHOLD
                ax.text(0.05, 0.95, f'r = {corr:.3f}', transform=ax.transAxes,
                        fontsize=8, va='top',
                        color='red' if is_high else 'gray',
                        weight='bold' if is_high else 'normal')

            # Only the bottom row carries x labels/tick labels.
            if i == n_components - 1:
                ax.set_xlabel(labels[col_x], fontsize=9)
            else:
                ax.set_xticklabels([])

            # Only the left column carries y labels/tick labels; the top-left
            # cell is a histogram and keeps its 'Count' label.
            if j == 0:
                if i != j:
                    ax.set_ylabel(labels[col_y], fontsize=9)
            else:
                ax.set_yticklabels([])

            ax.tick_params(labelsize=7)
            ax.grid(True, alpha=0.3, linestyle=':', linewidth=0.5)

    fig.suptitle('Latin Hypercube Sampling: Pairwise Component Interactions (v6)\n' +
                 f'{n_conditions} Unique Conditions, {n_components} Varied Components',
                 fontsize=14, fontweight='bold', y=0.995)

    # Leave a thin strip at the top so the super title is not overlapped.
    fig.tight_layout(rect=[0, 0, 1, 0.99])
    return fig


def _print_summary(df, labels):
    """Print per-component statistics, the correlation matrix and flagged pairs."""
    columns = list(df.columns)

    print("\n" + "=" * 80)
    print("SUMMARY STATISTICS")
    print("=" * 80)
    print("\nComponent Ranges:")
    for col in columns:
        # Axis labels contain a newline before the unit; flatten for console output.
        flat_label = labels[col].replace("\n", " ")
        print(f"\n{flat_label}:")
        print(f"  Min:  {df[col].min():.3f}")
        print(f"  Max:  {df[col].max():.3f}")
        print(f"  Mean: {df[col].mean():.3f}")
        print(f"  Std:  {df[col].std():.3f}")

    print("\n" + "=" * 80)
    print("CORRELATION MATRIX")
    print("=" * 80)
    corr_matrix = df.corr()
    print("\n", corr_matrix.round(3))

    print("\n" + "=" * 80)
    print(f"HIGH CORRELATIONS (|r| > {_HIGH_CORR_THRESHOLD}):")
    print("=" * 80)
    high_corr = False
    # Visit each unordered pair once (upper triangle of the matrix).
    for i, col1 in enumerate(columns):
        for col2 in columns[i + 1:]:
            r = corr_matrix.loc[col1, col2]
            if abs(r) > _HIGH_CORR_THRESHOLD:
                high_corr = True
                print(f"  {col1} vs {col2}: r = {r:.3f}")
    if not high_corr:
        print("  ✓ No high correlations found - good space-filling design!")


def main():
    """Load the v6 formulations, plot the pairwise scatterplot matrix and report statistics.

    Reads INPUT_FILE, keeps one formulation per 'v6_condition_number', writes
    'lhs_scatterplot_matrix_v6.png' and '.pdf' into OUTPUT_DIR, and prints
    summary statistics and correlations to stdout.

    Raises:
        ValueError: If the input file contains no formulations.
    """
    print("=" * 80)
    print("Generating LHS Scatterplot Matrix for v6")
    print("=" * 80)

    # Load formulations. The component keys contain non-ASCII characters
    # (e.g. '(NH₄)₂SO₄'), so decode explicitly as UTF-8 instead of relying on
    # the platform's default encoding.
    print("\n[1/3] Loading v6 formulations...")
    with open(INPUT_FILE, 'r', encoding='utf-8') as handle:
        formulations = json.load(handle)
    print(f"✓ Loaded {len(formulations)} formulations")

    # Extract unique conditions (remove replicates)
    unique_formulations = _unique_conditions(formulations)
    print(f"✓ Extracted {len(unique_formulations)} unique conditions (excluding replicates)")
    if not unique_formulations:
        # Without data the script would save an empty figure and report a
        # "good" design; fail loudly instead.
        raise ValueError(f"No formulations found in {INPUT_FILE}")

    # Extract concentrations
    print("\n[2/3] Extracting component concentrations...")
    df = _concentration_frame(unique_formulations)
    print(f"✓ Extracted {len(df.columns)} components × {len(df)} conditions")

    # Component labels with units
    labels = {
        'Total_Phosphate_mM': 'Total Phosphate\n(mM)',
        'NH4SO4_mM': '(NH₄)₂SO₄\n(mM)',
        'CoCl2_uM': 'CoCl₂\n(µM)',
        'Succinate_mM': 'Succinate\n(mM)',
        'Methanol_mM': 'Methanol\n(mM)',
        'PQQ_uM': 'PQQ\n(µM)',
    }

    # Create scatterplot matrix
    print("\n[3/3] Creating scatterplot matrix...")
    fig = _plot_scatter_matrix(df, labels, len(unique_formulations))

    # Save figure
    output_file = OUTPUT_DIR / "lhs_scatterplot_matrix_v6.png"
    fig.savefig(output_file, dpi=300, bbox_inches='tight')
    print(f"✓ Saved scatterplot matrix to {output_file.name}")

    # Also save as PDF for high quality
    output_pdf = OUTPUT_DIR / "lhs_scatterplot_matrix_v6.pdf"
    fig.savefig(output_pdf, bbox_inches='tight')
    print(f"✓ Saved PDF version to {output_pdf.name}")

    # Close this specific figure rather than whatever pyplot considers current.
    plt.close(fig)

    _print_summary(df, labels)

    print("\n✓ Scatterplot matrix generation complete")
# Run the generator only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()