@@ -452,17 +452,17 @@ def print_statements(
452
452
print ("\n -----------------------------------\n " )
453
453
454
454
455
- def _str_units_to_float (str_units : str ) -> float :
455
+ def _str_units_to_float (str_units : str ) -> Optional [ float ] :
456
456
unit_conversions = {
457
457
"nm" : 1e-3 ,
458
458
"um" : 1 ,
459
459
"mm" : 1e3 ,
460
460
"m" : 1e6 ,
461
461
}
462
462
match = re .match (r"([\d\.]+)\s*([a-zA-Z]+)" , str_units )
463
- numeric_value = float (match .group (1 ) if match else 1.55 )
464
- unit = match .group (2 ) if match else "um"
465
- return float (numeric_value * unit_conversions [unit ])
463
+ numeric_value = float (match .group (1 )) if match else None
464
+ unit = match .group (2 ) if match else None
465
+ return float (numeric_value * unit_conversions [unit ]) if unit in unit_conversions and numeric_value is not None else None
466
466
467
467
468
468
def get_wavelengths_to_plot (statements : StatementDictionary , num_samples : int = 100 ) -> Tuple [List [float ], List [float ]]:
@@ -484,16 +484,19 @@ def update_wavelengths(mapping: Dict[str, Optional[Computation]], min_wl: float,
484
484
vlines = vlines | {
485
485
_str_units_to_float (wl )
486
486
for wl in (comp .arguments ["wavelengths" ] if isinstance (comp .arguments ["wavelengths" ], list ) else [])
487
- if isinstance (wl , str )
487
+ if isinstance (wl , str ) and _str_units_to_float ( wl ) is not None
488
488
}
489
489
if "wavelength_range" in comp .arguments :
490
490
if (
491
491
isinstance (comp .arguments ["wavelength_range" ], list )
492
492
and len (comp .arguments ["wavelength_range" ]) == 2
493
493
and all (isinstance (wl , str ) for wl in comp .arguments ["wavelength_range" ])
494
494
):
495
- min_wl = min (min_wl , _str_units_to_float (comp .arguments ["wavelength_range" ][0 ]))
496
- max_wl = max (max_wl , _str_units_to_float (comp .arguments ["wavelength_range" ][1 ]))
495
+ mi = _str_units_to_float (comp .arguments ["wavelength_range" ][0 ])
496
+ ma = _str_units_to_float (comp .arguments ["wavelength_range" ][1 ])
497
+ if mi is not None and ma is not None :
498
+ min_wl = min (min_wl , mi )
499
+ max_wl = max (max_wl , ma )
497
500
return min_wl , max_wl , vlines
498
501
499
502
for cost_stmt in statements .cost_functions or []:
@@ -508,8 +511,8 @@ def update_wavelengths(mapping: Dict[str, Optional[Computation]], min_wl: float,
508
511
min_wl = min (min_wl , min (vlines ))
509
512
max_wl = max (max_wl , max (vlines ))
510
513
if min_wl >= max_wl :
511
- avg_wl = sum (vlines ) / len (vlines ) if vlines else 1550
512
- min_wl , max_wl = avg_wl - 10 , avg_wl + 10
514
+ avg_wl = sum (vlines ) / len (vlines ) if vlines else 1.55
515
+ min_wl , max_wl = avg_wl - 0.01 , avg_wl + 0.01
513
516
else :
514
517
range_size = max_wl - min_wl
515
518
min_wl -= 0.2 * range_size
0 commit comments