1
- from typing import Union , List
1
+ from typing import List
2
2
3
3
import numpy as np
4
4
from matplotlib import pyplot as plt
5
- from matplotlib .colors import Colormap , to_rgba
6
-
7
- import plotly .graph_objects as go
8
- from plotly .colors import convert_to_RGB_255
9
5
10
6
from reeds .function_libs .visualization import plots_style as ps
11
7
from reeds .function_libs .visualization .utils import nice_s_vals
@@ -380,76 +376,4 @@ def plot_stateOccurence_matrix(data: dict,
380
376
381
377
if (not out_dir is None ):
382
378
fig .savefig (out_dir + '/sampling_maxContrib_matrix.png' , bbox_inches = 'tight' )
383
- plt .close ()
384
-
385
- def plot_state_transitions (state_transitions : np .ndarray , title : str = None , colors : Union [List [str ], Colormap ] = ps .qualitative_tab_map , out_path : str = None ):
386
- """
387
- Make a Sankey plot showing the flows between states.
388
-
389
- Parameters
390
- ----------
391
- state_transitions : np.ndarray
392
- num_states * num_states 2D array containing the number of transitions between states
393
- title: str, optional
394
- printed title of the plot
395
- colors: Union[List[str], Colormap], optional
396
- if you don't like the default colors
397
- out_path: str, optional
398
- path to save the image to. if none, the image is returned as a plotly figure
399
- Returns
400
- -------
401
- None or fig
402
- plotly figure if if was not saved
403
- """
404
- num_states = len (state_transitions )
405
-
406
- if isinstance (colors , Colormap ):
407
- colors = [colors (i ) for i in np .linspace (0 , 1 , num_states )]
408
- elif len (colors ) < num_states :
409
- raise Exception ("Insufficient colors to plot all states" )
410
-
411
- def v_distribute (total_transitions ):
412
- # Vertically distribute nodes in plot based on total number of transitions per state
413
- box_sizes = total_transitions / total_transitions .sum ()
414
- box_vplace = [np .sum (box_sizes [:i ]) + box_sizes [i ]/ 2 for i in range (len (box_sizes ))]
415
- return box_vplace
416
-
417
- y_placements = v_distribute (np .sum (state_transitions , axis = 1 )) + v_distribute (np .sum (state_transitions , axis = 0 ))
418
-
419
- # Convert colors to plotly format and make them transparent
420
- rgba_colors = []
421
- for color in colors :
422
- rgba = to_rgba (color )
423
- rgba_plotly = convert_to_RGB_255 (rgba [:- 1 ])
424
- # Add opacity
425
- rgba_plotly = rgba_plotly + (0.8 ,)
426
- # Make string
427
- rgba_colors .append ("rgba" + str (rgba_plotly ))
428
-
429
- # Indices 0..n-1 are the source and n..2n-1 are the target.
430
- fig = go .Figure (data = [go .Sankey (
431
- node = dict (
432
- pad = 5 ,
433
- thickness = 20 ,
434
- line = dict (color = "black" , width = 2 ),
435
- label = [f"state { i + 1 } " for i in range (num_states )]* 2 ,
436
- color = rgba_colors [:num_states ]* 2 ,
437
- x = [0.1 ]* num_states + [1 ]* num_states ,
438
- y = y_placements
439
- ),
440
- link = dict (
441
- arrowlen = 30 ,
442
- source = np .array ([[i ]* num_states for i in range (num_states )]).flatten (),
443
- target = np .array ([[i for i in range (num_states , 2 * num_states )] for _ in range (num_states )]).flatten (),
444
- value = state_transitions .flatten (),
445
- color = np .array ([[c ]* num_states for c in rgba_colors [:num_states ]]).flatten ()
446
- ),
447
- arrangement = "fixed" ,
448
- )])
449
- fig .update_layout (title_text = title , font_size = 20 , title_x = 0.5 , height = max (600 , num_states * 100 ))
450
-
451
- if out_path :
452
- fig .write_image (out_path )
453
- return None
454
- else :
455
- return fig
379
+ plt .close ()
0 commit comments