@@ -112,12 +112,19 @@ def plot_interactive_spectra(
112112        A list of spectra, where each spectrum is a list of lists of float values, each 
113113        corresponding to the transmission of a single wavelength. 
114114    wavelengths : list of float 
115-         A list of wavelength values corresponding to the x-axis of the plot. 
115+         A list of wavelength values corresponding to the x-axis of the plot, in um . 
116116    vlines : list of float, optional 
117-         A list of x-values where vertical lines should be drawn. Defaults to an empty list. 
117+         A list of x-values where vertical lines should be drawn, in um . Defaults to an empty list. 
118118    hlines : list of float, optional 
119119        A list of y-values where horizontal lines should be drawn. Defaults to an empty list. 
120120    """ 
121+ 
122+     hlines  =  hlines  or  []
123+     
124+     # Convert wavelengths to nm 
125+     wavelengths  =  [wl * 1e3  for  wl  in  wavelengths ]
126+     vlines  =  [wl * 1e3  for  wl  in  vlines ] if  vlines  else  []
127+ 
121128    if  isinstance (spectra , dict ):
122129        port_keys  =  []
123130        for  key  in  spectra :
@@ -137,10 +144,6 @@ def plot_interactive_spectra(
137144
138145    elif  spectrum_labels  is  None :
139146        spectrum_labels  =  [f"Spectrum { i }   for  i  in  range (len (spectra ))]
140-     if  vlines  is  None :
141-         vlines  =  []
142-     if  hlines  is  None :
143-         hlines  =  []
144147
145148    if  isinstance (spectra , dict ):
146149        spectra  =  list (spectra .values ())
@@ -149,9 +152,16 @@ def plot_interactive_spectra(
149152    all_vals  =  [val  for  spec  in  spectra  for  iteration  in  spec  for  val  in  iteration ]
150153    y_min  =  min (all_vals )
151154    y_max  =  max (all_vals )
152-     if  hlines :
153-         y_min  =  min (hlines  +  [y_min ]) *  0.95 
154-         y_max  =  max (hlines  +  [y_max ]) *  1.05 
155+ 
156+     # dB scale 
157+     if  y_max  <=  0 :
158+         y_max  =  0 
159+         db  =  True 
160+     else :
161+         db  =  False 
162+         if  hlines :
163+             y_min  =  min (hlines  +  [y_min ]) *  0.95 
164+             y_max  =  max (hlines  +  [y_max ]) *  1.05 
155165
156166    # Create hlines and vlines 
157167    shapes  =  []
@@ -187,8 +197,8 @@ def plot_interactive_spectra(
187197
188198    # Create the layout 
189199    fig .update_layout (
190-         xaxis_title = "Wavelength" ,
191-         yaxis_title = "Transmission" ,
200+         xaxis_title = "Wavelength (nm) " ,
201+         yaxis_title = "Transmission "    +   "(dB)"   if   db   else   "(linear) "
192202        shapes = shapes ,
193203        sliders = sliders ,
194204        yaxis = dict (range = [y_min , y_max ]),
@@ -453,6 +463,8 @@ def print_statements(
453463
454464
455465def  _str_units_to_float (str_units : str ) ->  Optional [float ]:
466+     """Returns the numeric value of a string with units in micrometers, e.g. '1550 nm' -> 1.55""" 
467+     """Return None if the string is not a valid unit.""" 
456468    unit_conversions  =  {
457469        "nm" : 1e-3 ,
458470        "um" : 1 ,
@@ -465,11 +477,11 @@ def _str_units_to_float(str_units: str) -> Optional[float]:
465477    return  float (numeric_value  *  unit_conversions [unit ]) if  unit  in  unit_conversions  and  numeric_value  is  not None  else  None 
466478
467479
468- def  get_wavelengths_to_plot (statements : StatementDictionary , num_samples : int  =  100 ) ->  Tuple [List [float ], List [float ]]:
480+ def  get_wavelengths_to_plot (statements : StatementDictionary , num_samples : int  =  1000 ) ->  Tuple [List [float ], List [float ]]:
469481    """ 
470482    Get the wavelengths to plot based on the statements. 
471483
472-     Returns a list of wavelengths to plot the spectra and a list of vertical lines to plot on top the spectra. 
484+     Returns a list of wavelengths to plot the spectra and a list of vertical lines to plot on top the spectra, in um . 
473485    """ 
474486
475487    min_wl  =  float ("inf" )
@@ -511,8 +523,8 @@ def update_wavelengths(mapping: Dict[str, Optional[Computation]], min_wl: float,
511523        min_wl  =  min (min_wl , min (vlines ))
512524        max_wl  =  max (max_wl , max (vlines ))
513525    if  min_wl  >=  max_wl :
514-         avg_wl  =  sum (vlines ) /  len (vlines ) if  vlines  else  1.55 
515-         min_wl , max_wl  =  avg_wl  -  0.01 , avg_wl  +  0.01 
526+         avg_wl  =  sum (vlines ) /  len (vlines ) if  vlines  else  _str_units_to_float ( "1550 nm" ) 
527+         min_wl , max_wl  =  avg_wl  -  _str_units_to_float ( "10 nm" ) , avg_wl  +  _str_units_to_float ( "10 nm" ) 
516528    else :
517529        range_size  =  max_wl  -  min_wl 
518530        min_wl  -=  0.2  *  range_size 
0 commit comments