Skip to content

Commit e93e6dd

Browse files
authored
Merge pull request #54 from Axiomatic-AI/default-range
Improve get wls to plot
2 parents ecba8f3 + ba6d090 commit e93e6dd

File tree

1 file changed

+12
-9
lines changed

1 file changed

+12
-9
lines changed

src/axiomatic/pic_helpers.py

+12-9
Original file line numberDiff line numberDiff line change
@@ -452,17 +452,17 @@ def print_statements(
452452
print("\n-----------------------------------\n")
453453

454454

455-
def _str_units_to_float(str_units: str) -> float:
455+
def _str_units_to_float(str_units: str) -> Optional[float]:
456456
unit_conversions = {
457457
"nm": 1e-3,
458458
"um": 1,
459459
"mm": 1e3,
460460
"m": 1e6,
461461
}
462462
match = re.match(r"([\d\.]+)\s*([a-zA-Z]+)", str_units)
463-
numeric_value = float(match.group(1) if match else 1.55)
464-
unit = match.group(2) if match else "um"
465-
return float(numeric_value * unit_conversions[unit])
463+
numeric_value = float(match.group(1)) if match else None
464+
unit = match.group(2) if match else None
465+
return float(numeric_value * unit_conversions[unit]) if unit in unit_conversions and numeric_value is not None else None
466466

467467

468468
def get_wavelengths_to_plot(statements: StatementDictionary, num_samples: int = 100) -> Tuple[List[float], List[float]]:
@@ -484,16 +484,19 @@ def update_wavelengths(mapping: Dict[str, Optional[Computation]], min_wl: float,
484484
vlines = vlines | {
485485
_str_units_to_float(wl)
486486
for wl in (comp.arguments["wavelengths"] if isinstance(comp.arguments["wavelengths"], list) else [])
487-
if isinstance(wl, str)
487+
if isinstance(wl, str) and _str_units_to_float(wl) is not None
488488
}
489489
if "wavelength_range" in comp.arguments:
490490
if (
491491
isinstance(comp.arguments["wavelength_range"], list)
492492
and len(comp.arguments["wavelength_range"]) == 2
493493
and all(isinstance(wl, str) for wl in comp.arguments["wavelength_range"])
494494
):
495-
min_wl = min(min_wl, _str_units_to_float(comp.arguments["wavelength_range"][0]))
496-
max_wl = max(max_wl, _str_units_to_float(comp.arguments["wavelength_range"][1]))
495+
mi = _str_units_to_float(comp.arguments["wavelength_range"][0])
496+
ma = _str_units_to_float(comp.arguments["wavelength_range"][1])
497+
if mi is not None and ma is not None:
498+
min_wl = min(min_wl, mi)
499+
max_wl = max(max_wl, ma)
497500
return min_wl, max_wl, vlines
498501

499502
for cost_stmt in statements.cost_functions or []:
@@ -508,8 +511,8 @@ def update_wavelengths(mapping: Dict[str, Optional[Computation]], min_wl: float,
508511
min_wl = min(min_wl, min(vlines))
509512
max_wl = max(max_wl, max(vlines))
510513
if min_wl >= max_wl:
511-
avg_wl = sum(vlines) / len(vlines) if vlines else 1550
512-
min_wl, max_wl = avg_wl - 10, avg_wl + 10
514+
avg_wl = sum(vlines) / len(vlines) if vlines else 1.55
515+
min_wl, max_wl = avg_wl - 0.01, avg_wl + 0.01
513516
else:
514517
range_size = max_wl - min_wl
515518
min_wl -= 0.2 * range_size

0 commit comments

Comments
 (0)