Skip to content

Commit 90deba8

Browse files
authored
Merge pull request #55 from Axiomatic-AI/better-plots
Better plots
2 parents 57f2e4d + 5396edd commit 90deba8

File tree

1 file changed

+27
-15
lines changed

1 file changed

+27
-15
lines changed

src/axiomatic/pic_helpers.py

+27-15
Original file line numberDiff line numberDiff line change
@@ -112,12 +112,19 @@ def plot_interactive_spectra(
112112
A list of spectra, where each spectrum is a list of lists of float values, each
113113
corresponding to the transmission of a single wavelength.
114114
wavelengths : list of float
115-
A list of wavelength values corresponding to the x-axis of the plot.
115+
A list of wavelength values corresponding to the x-axis of the plot, in um.
116116
vlines : list of float, optional
117-
A list of x-values where vertical lines should be drawn. Defaults to an empty list.
117+
A list of x-values where vertical lines should be drawn, in um. Defaults to an empty list.
118118
hlines : list of float, optional
119119
A list of y-values where horizontal lines should be drawn. Defaults to an empty list.
120120
"""
121+
122+
hlines = hlines or []
123+
124+
# Convert wavelengths to nm
125+
wavelengths = [wl*1e3 for wl in wavelengths]
126+
vlines = [wl*1e3 for wl in vlines] if vlines else []
127+
121128
if isinstance(spectra, dict):
122129
port_keys = []
123130
for key in spectra:
@@ -137,10 +144,6 @@ def plot_interactive_spectra(
137144

138145
elif spectrum_labels is None:
139146
spectrum_labels = [f"Spectrum {i}" for i in range(len(spectra))]
140-
if vlines is None:
141-
vlines = []
142-
if hlines is None:
143-
hlines = []
144147

145148
if isinstance(spectra, dict):
146149
spectra = list(spectra.values())
@@ -149,9 +152,16 @@ def plot_interactive_spectra(
149152
all_vals = [val for spec in spectra for iteration in spec for val in iteration]
150153
y_min = min(all_vals)
151154
y_max = max(all_vals)
152-
if hlines:
153-
y_min = min(hlines + [y_min]) * 0.95
154-
y_max = max(hlines + [y_max]) * 1.05
155+
156+
# dB scale
157+
if y_max <= 0:
158+
y_max = 0
159+
db = True
160+
else:
161+
db = False
162+
if hlines:
163+
y_min = min(hlines + [y_min]) * 0.95
164+
y_max = max(hlines + [y_max]) * 1.05
155165

156166
# Create hlines and vlines
157167
shapes = []
@@ -187,8 +197,8 @@ def plot_interactive_spectra(
187197

188198
# Create the layout
189199
fig.update_layout(
190-
xaxis_title="Wavelength",
191-
yaxis_title="Transmission",
200+
xaxis_title="Wavelength (nm)",
201+
yaxis_title="Transmission " + "(dB)" if db else "(linear)",
192202
shapes=shapes,
193203
sliders=sliders,
194204
yaxis=dict(range=[y_min, y_max]),
@@ -453,6 +463,8 @@ def print_statements(
453463

454464

455465
def _str_units_to_float(str_units: str) -> Optional[float]:
466+
"""Returns the numeric value of a string with units in micrometers, e.g. '1550 nm' -> 1.55"""
467+
"""Return None if the string is not a valid unit."""
456468
unit_conversions = {
457469
"nm": 1e-3,
458470
"um": 1,
@@ -465,11 +477,11 @@ def _str_units_to_float(str_units: str) -> Optional[float]:
465477
return float(numeric_value * unit_conversions[unit]) if unit in unit_conversions and numeric_value is not None else None
466478

467479

468-
def get_wavelengths_to_plot(statements: StatementDictionary, num_samples: int = 100) -> Tuple[List[float], List[float]]:
480+
def get_wavelengths_to_plot(statements: StatementDictionary, num_samples: int = 1000) -> Tuple[List[float], List[float]]:
469481
"""
470482
Get the wavelengths to plot based on the statements.
471483
472-
Returns a list of wavelengths to plot the spectra and a list of vertical lines to plot on top the spectra.
484+
Returns a list of wavelengths to plot the spectra and a list of vertical lines to plot on top the spectra, in um.
473485
"""
474486

475487
min_wl = float("inf")
@@ -511,8 +523,8 @@ def update_wavelengths(mapping: Dict[str, Optional[Computation]], min_wl: float,
511523
min_wl = min(min_wl, min(vlines))
512524
max_wl = max(max_wl, max(vlines))
513525
if min_wl >= max_wl:
514-
avg_wl = sum(vlines) / len(vlines) if vlines else 1.55
515-
min_wl, max_wl = avg_wl - 0.01, avg_wl + 0.01
526+
avg_wl = sum(vlines) / len(vlines) if vlines else _str_units_to_float("1550 nm")
527+
min_wl, max_wl = avg_wl - _str_units_to_float("10 nm"), avg_wl + _str_units_to_float("10 nm")
516528
else:
517529
range_size = max_wl - min_wl
518530
min_wl -= 0.2 * range_size

0 commit comments

Comments
 (0)