Skip to content

Commit eecb408

Browse files
committedOct 12, 2023
try to make tests work by removing plotly from env
1 parent d93e0fd commit eecb408

File tree

3 files changed

+83
-79
lines changed

3 files changed

+83
-79
lines changed
 

‎devtools/conda-envs/test_env.yaml

-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ dependencies:
1919
- pandas
2020
- pytables
2121
- matplotlib
22-
- plotly
2322

2423

2524
# Pip-only installs

‎reeds/function_libs/visualization/sampling_plots.py

+2-78
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,7 @@
1-
from typing import Union, List
1+
from typing import List
22

33
import numpy as np
44
from matplotlib import pyplot as plt
5-
from matplotlib.colors import Colormap, to_rgba
6-
7-
import plotly.graph_objects as go
8-
from plotly.colors import convert_to_RGB_255
95

106
from reeds.function_libs.visualization import plots_style as ps
117
from reeds.function_libs.visualization.utils import nice_s_vals
@@ -380,76 +376,4 @@ def plot_stateOccurence_matrix(data: dict,
380376

381377
if (not out_dir is None):
382378
fig.savefig(out_dir + '/sampling_maxContrib_matrix.png', bbox_inches='tight')
383-
plt.close()
384-
385-
def plot_state_transitions(state_transitions: np.ndarray, title: str = None, colors: Union[List[str], Colormap] = ps.qualitative_tab_map, out_path: str = None):
386-
"""
387-
Make a Sankey plot showing the flows between states.
388-
389-
Parameters
390-
----------
391-
state_transitions : np.ndarray
392-
num_states * num_states 2D array containing the number of transitions between states
393-
title: str, optional
394-
printed title of the plot
395-
colors: Union[List[str], Colormap], optional
396-
if you don't like the default colors
397-
out_path: str, optional
398-
path to save the image to. if none, the image is returned as a plotly figure
399-
Returns
400-
-------
401-
None or fig
402-
plotly figure if if was not saved
403-
"""
404-
num_states = len(state_transitions)
405-
406-
if isinstance(colors, Colormap):
407-
colors = [colors(i) for i in np.linspace(0, 1, num_states)]
408-
elif len(colors) < num_states:
409-
raise Exception("Insufficient colors to plot all states")
410-
411-
def v_distribute(total_transitions):
412-
# Vertically distribute nodes in plot based on total number of transitions per state
413-
box_sizes = total_transitions / total_transitions.sum()
414-
box_vplace = [np.sum(box_sizes[:i]) + box_sizes[i]/2 for i in range(len(box_sizes))]
415-
return box_vplace
416-
417-
y_placements = v_distribute(np.sum(state_transitions, axis=1)) + v_distribute(np.sum(state_transitions, axis=0))
418-
419-
# Convert colors to plotly format and make them transparent
420-
rgba_colors = []
421-
for color in colors:
422-
rgba = to_rgba(color)
423-
rgba_plotly = convert_to_RGB_255(rgba[:-1])
424-
# Add opacity
425-
rgba_plotly = rgba_plotly + (0.8,)
426-
# Make string
427-
rgba_colors.append("rgba" + str(rgba_plotly))
428-
429-
# Indices 0..n-1 are the source and n..2n-1 are the target.
430-
fig = go.Figure(data=[go.Sankey(
431-
node = dict(
432-
pad = 5,
433-
thickness = 20,
434-
line = dict(color = "black", width = 2),
435-
label = [f"state {i+1}" for i in range(num_states)]*2,
436-
color = rgba_colors[:num_states]*2,
437-
x = [0.1]*num_states + [1]*num_states,
438-
y = y_placements
439-
),
440-
link = dict(
441-
arrowlen = 30,
442-
source = np.array([[i]*num_states for i in range(num_states)]).flatten(),
443-
target = np.array([[i for i in range(num_states, 2*num_states)] for _ in range(num_states)]).flatten(),
444-
value = state_transitions.flatten(),
445-
color = np.array([[c]*num_states for c in rgba_colors[:num_states]]).flatten()
446-
),
447-
arrangement="fixed",
448-
)])
449-
fig.update_layout(title_text=title, font_size=20, title_x=0.5, height=max(600, num_states*100))
450-
451-
if out_path:
452-
fig.write_image(out_path)
453-
return None
454-
else:
455-
return fig
379+
plt.close()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
from typing import Union, List
2+
import numpy as np
3+
4+
from matplotlib.colors import Colormap, to_rgba
5+
import plotly.graph_objects as go
6+
from plotly.colors import convert_to_RGB_255
7+
8+
from reeds.function_libs.visualization import plots_style as ps
9+
10+
11+
def plot_state_transitions(state_transitions: np.ndarray, title: str = None, colors: Union[List[str], Colormap] = ps.qualitative_tab_map, out_path: str = None):
12+
"""
13+
Make a Sankey plot showing the flows between states.
14+
15+
Parameters
16+
----------
17+
state_transitions : np.ndarray
18+
num_states * num_states 2D array containing the number of transitions between states
19+
title: str, optional
20+
printed title of the plot
21+
colors: Union[List[str], Colormap], optional
22+
if you don't like the default colors
23+
out_path: str, optional
24+
path to save the image to. if none, the image is returned as a plotly figure
25+
Returns
26+
-------
27+
None or fig
28+
plotly figure if if was not saved
29+
"""
30+
num_states = len(state_transitions)
31+
32+
if isinstance(colors, Colormap):
33+
colors = [colors(i) for i in np.linspace(0, 1, num_states)]
34+
elif len(colors) < num_states:
35+
raise Exception("Insufficient colors to plot all states")
36+
37+
def v_distribute(total_transitions):
38+
# Vertically distribute nodes in plot based on total number of transitions per state
39+
box_sizes = total_transitions / total_transitions.sum()
40+
box_vplace = [np.sum(box_sizes[:i]) + box_sizes[i]/2 for i in range(len(box_sizes))]
41+
return box_vplace
42+
43+
y_placements = v_distribute(np.sum(state_transitions, axis=1)) + v_distribute(np.sum(state_transitions, axis=0))
44+
45+
# Convert colors to plotly format and make them transparent
46+
rgba_colors = []
47+
for color in colors:
48+
rgba = to_rgba(color)
49+
rgba_plotly = convert_to_RGB_255(rgba[:-1])
50+
# Add opacity
51+
rgba_plotly = rgba_plotly + (0.8,)
52+
# Make string
53+
rgba_colors.append("rgba" + str(rgba_plotly))
54+
55+
# Indices 0..n-1 are the source and n..2n-1 are the target.
56+
fig = go.Figure(data=[go.Sankey(
57+
node = dict(
58+
pad = 5,
59+
thickness = 20,
60+
line = dict(color = "black", width = 2),
61+
label = [f"state {i+1}" for i in range(num_states)]*2,
62+
color = rgba_colors[:num_states]*2,
63+
x = [0.1]*num_states + [1]*num_states,
64+
y = y_placements
65+
),
66+
link = dict(
67+
arrowlen = 30,
68+
source = np.array([[i]*num_states for i in range(num_states)]).flatten(),
69+
target = np.array([[i for i in range(num_states, 2*num_states)] for _ in range(num_states)]).flatten(),
70+
value = state_transitions.flatten(),
71+
color = np.array([[c]*num_states for c in rgba_colors[:num_states]]).flatten()
72+
),
73+
arrangement="fixed",
74+
)])
75+
fig.update_layout(title_text=title, font_size=20, title_x=0.5, height=max(600, num_states*100))
76+
77+
if out_path:
78+
fig.write_image(out_path)
79+
return None
80+
else:
81+
return fig

0 commit comments

Comments
 (0)
Please sign in to comment.