Skip to content

Better plots #55

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 5, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 27 additions & 15 deletions src/axiomatic/pic_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,12 +112,19 @@ def plot_interactive_spectra(
A list of spectra, where each spectrum is a list of lists of float values, each
corresponding to the transmission of a single wavelength.
wavelengths : list of float
A list of wavelength values corresponding to the x-axis of the plot.
A list of wavelength values corresponding to the x-axis of the plot, in um.
vlines : list of float, optional
A list of x-values where vertical lines should be drawn. Defaults to an empty list.
A list of x-values where vertical lines should be drawn, in um. Defaults to an empty list.
hlines : list of float, optional
A list of y-values where horizontal lines should be drawn. Defaults to an empty list.
"""

hlines = hlines or []

# Convert wavelengths to nm
wavelengths = [wl*1e3 for wl in wavelengths]
vlines = [wl*1e3 for wl in vlines] if vlines else []

if isinstance(spectra, dict):
port_keys = []
for key in spectra:
Expand All @@ -137,10 +144,6 @@ def plot_interactive_spectra(

elif spectrum_labels is None:
spectrum_labels = [f"Spectrum {i}" for i in range(len(spectra))]
if vlines is None:
vlines = []
if hlines is None:
hlines = []

if isinstance(spectra, dict):
spectra = list(spectra.values())
Expand All @@ -149,9 +152,16 @@ def plot_interactive_spectra(
all_vals = [val for spec in spectra for iteration in spec for val in iteration]
y_min = min(all_vals)
y_max = max(all_vals)
if hlines:
y_min = min(hlines + [y_min]) * 0.95
y_max = max(hlines + [y_max]) * 1.05

# dB scale
if y_max <= 0:
y_max = 0
db = True
else:
db = False
if hlines:
y_min = min(hlines + [y_min]) * 0.95
y_max = max(hlines + [y_max]) * 1.05

# Create hlines and vlines
shapes = []
Expand Down Expand Up @@ -187,8 +197,8 @@ def plot_interactive_spectra(

# Create the layout
fig.update_layout(
xaxis_title="Wavelength",
yaxis_title="Transmission",
xaxis_title="Wavelength (nm)",
yaxis_title="Transmission " + "(dB)" if db else "(linear)",
shapes=shapes,
sliders=sliders,
yaxis=dict(range=[y_min, y_max]),
Expand Down Expand Up @@ -453,6 +463,8 @@ def print_statements(


def _str_units_to_float(str_units: str) -> Optional[float]:
"""Returns the numeric value of a string with units in micrometers, e.g. '1550 nm' -> 1.55"""
"""Return None if the string is not a valid unit."""
unit_conversions = {
"nm": 1e-3,
"um": 1,
Expand All @@ -465,11 +477,11 @@ def _str_units_to_float(str_units: str) -> Optional[float]:
return float(numeric_value * unit_conversions[unit]) if unit in unit_conversions and numeric_value is not None else None


def get_wavelengths_to_plot(statements: StatementDictionary, num_samples: int = 100) -> Tuple[List[float], List[float]]:
def get_wavelengths_to_plot(statements: StatementDictionary, num_samples: int = 1000) -> Tuple[List[float], List[float]]:
"""
Get the wavelengths to plot based on the statements.

Returns a list of wavelengths to plot the spectra and a list of vertical lines to plot on top the spectra.
Returns a list of wavelengths to plot the spectra and a list of vertical lines to plot on top the spectra, in um.
"""

min_wl = float("inf")
Expand Down Expand Up @@ -511,8 +523,8 @@ def update_wavelengths(mapping: Dict[str, Optional[Computation]], min_wl: float,
min_wl = min(min_wl, min(vlines))
max_wl = max(max_wl, max(vlines))
if min_wl >= max_wl:
avg_wl = sum(vlines) / len(vlines) if vlines else 1.55
min_wl, max_wl = avg_wl - 0.01, avg_wl + 0.01
avg_wl = sum(vlines) / len(vlines) if vlines else _str_units_to_float("1550 nm")
min_wl, max_wl = avg_wl - _str_units_to_float("10 nm"), avg_wl + _str_units_to_float("10 nm")
else:
range_size = max_wl - min_wl
min_wl -= 0.2 * range_size
Expand Down