This repository was archived by the owner on Jun 28, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsliders.py
85 lines (64 loc) · 2.46 KB
/
sliders.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.collections import QuadMesh
from matplotlib.widgets import Slider, Button
from CustomScaler import CustomScaler
from interpolators.rbf import RbfInterpolator
from simulation_list import SimulationList
simlist = SimulationList.jsonlines_load(Path("rsmc_dataset.jsonl"))
resolution = 100
data = simlist.X
values = simlist.Y_mass_fraction
scaler = CustomScaler()
scaler.fit(data)
scaled_data = scaler.transform_data(data)
interpolator = RbfInterpolator(scaled_data, values)
alpharange = np.linspace(-0.5, 60.5, resolution)
vrange = np.linspace(0.5, 5.5, resolution)
grid_alpha, grid_v = np.meshgrid(alpharange, vrange)
fig, ax = plt.subplots()
plt.subplots_adjust(bottom=0.35)
t = np.arange(0.0, 1.0, 0.001)
mcode_default, gamma_default, wt_default, wp_default = [24.0, 1, 15.0, 15.0]
datagrid = np.zeros_like(grid_alpha)
mesh = plt.pcolormesh(grid_alpha, grid_v, datagrid, cmap="Blues", vmin=0, vmax=1, shading="nearest") # type:QuadMesh
plt.colorbar()
# axcolor = 'lightgoldenrodyellow'
ax_mcode = plt.axes([0.25, 0.1, 0.65, 0.03])
ax_gamma = plt.axes([0.25, 0.15, 0.65, 0.03])
ax_wt = plt.axes([0.25, 0.20, 0.65, 0.03])
ax_wp = plt.axes([0.25, 0.25, 0.65, 0.03])
buttonax = plt.axes([0.8, 0.025, 0.1, 0.04])
button = Button(buttonax, 'Update', hovercolor='0.975')
# thetext = ax.text(-10, 0, "hello", fontsize=12) #type:Text
s_mcode = Slider(ax_mcode, 'mcode', 21, 25, valinit=mcode_default)
s_gamma = Slider(ax_gamma, 'gamma', 0.1, 1, valinit=gamma_default)
s_wt = Slider(ax_wt, 'wt', 1e-5, 1e-3, valinit=wt_default)
s_wp = Slider(ax_wp, 'wp', 1e-5, 1e-3, valinit=wp_default)
def update(val):
print("start updating")
# thetext.set_text("updating")
# fig.canvas.draw()
mcode = s_mcode.val
gamma = s_gamma.val
wt = s_wt.val
wp = s_wp.val
parameters = [grid_alpha, grid_v, 10 ** mcode, gamma, wt, wp]
scaled_parameters = list(scaler.transform_parameters(parameters))
datagrid = interpolator.interpolate(*scaled_parameters)
print(datagrid)
print(np.isnan(datagrid).sum() / (resolution * resolution))
if not isinstance(datagrid, np.ndarray):
return False
mesh.set_array(datagrid.ravel())
print("finished updating")
# thetext.set_text("finished")
fig.canvas.draw_idle()
update(None)
# s_gamma.on_changed(update)
# s_mcode.on_changed(update)
# s_wp.on_changed(update)
# s_wt.on_changed(update)
button.on_clicked(update)
plt.show()