forked from g-walley/cegpy
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathBMA_example.py
More file actions
105 lines (88 loc) · 3.18 KB
/
Copy pathBMA_example.py
File metadata and controls
105 lines (88 loc) · 3.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
# %%
from random import seed
from Bayesian_model_averaging_CEG import *
import pandas as pd
from src.cegpy.trees import event, staged
from src.cegpy.graphs import ceg
import copy
import matplotlib.pyplot as plt
# Load the falls dataset and draw its event tree.
# To test on a smaller dataset, point falls_csv at
# "datasets/Falls_Data_10000.csv" instead.
falls_csv = "datasets/Falls_Data.csv"
data = pd.read_csv(falls_csv)
show_event_tree(data)
# %%
# Create the staged tree from the loaded data.
st = staged.StagedTree(data)
# For any structured missing data use the optional argument: struct_missing_label= ""
# NOTE(review): _create_default_hyperstage is a private (underscore-prefixed)
# method; whether it mutates `st` is not visible from this script, so the
# deep copy below is intentionally kept after this call.
hyperstage = st._create_default_hyperstage()
# Independent copy of the staged tree: the HAC cell works on this copy while
# `st` itself is later handed to Bayesian_model_averaging_CEGs.
HAC_st = copy.deepcopy(st)
#%%
# Run HAC on the copied staged tree (alpha=4, default hyperstage), then build
# the chain event graph from the result and draw it.
HAC_st.calculate_AHC_transitions(alpha=4, hyperstage=hyperstage)
hac_ceg = ceg.ChainEventGraph(HAC_st)
hac_ceg.create_figure()
#%%
# Run BMA with w-hac for each hyperset of the hyperstage.
bma_result = Bayesian_model_averaging_CEGs(
    st,
    K_max=100,
    prior_weight=4,
    hyperstage=hyperstage,
)
# The call yields three values: stagings, their weights and log-likelihoods.
staging_output, model_weights, loglikelihoods = bma_result
#%%
# Plot each hyperset: one pie chart of the model weights for every hyperset
# whose model_weights entry holds more than one element.
# Index-based access is kept on purpose: this script only shows that
# staging_output, model_weights and hyperstage support integer indexing.
for h in range(len(staging_output)):
    # With at most one entry there is nothing to compare, so skip the plot.
    if len(model_weights[h]) <= 1:
        continue
    labels = [str(x) for x in staging_output[h]]
    # Named `weights` rather than `data` so the DataFrame loaded at the top of
    # the script is not overwritten; re-running the staged-tree cell after
    # this one would otherwise receive a flattened weight array.
    weights = model_weights[h].flatten()
    plt.pie(weights, labels=labels)
    plt.title("Normalised Bayes factor ratios for hyperset\n" + str(hyperstage[h]))
    plt.show()
#%%
# Bar chart per hyperset, with bars ordered from the largest model weight to
# the smallest. Hypersets whose model_weights entry holds at most one element
# are skipped.
for h in range(len(staging_output)):
    if len(model_weights[h]) <= 1:
        continue
    labels = [str(x) for x in staging_output[h]]
    # Named `weights` rather than `data` so the DataFrame loaded at the top of
    # the script is not overwritten when cells are re-run.
    weights = model_weights[h].flatten()
    # Indices of the weights in descending order; sorted() is stable, so equal
    # weights keep their original relative order.
    order = sorted(range(len(weights)), key=weights.__getitem__, reverse=True)
    ordered_labs = [labels[i] for i in order]
    ordered_data = [weights[i] for i in order]
    plt.bar(ordered_labs, ordered_data, align='center', alpha=0.5)
    plt.title("Normalised Bayes factor ratios for hyperset\n" + str(hyperstage[h]))
    plt.ylabel('Model weights')
    # Staging labels are long strings, so draw them vertically.
    plt.xticks(rotation=90)
    plt.show()
#%%
# Get the staging for the whole hyperstage.
full_staging_output, full_model_weights = get_full_model_weights(model_weights, staging_output)
# Plot model weights for the well performing models, largest first.
# (The former `y_pos = np.arange(...)` assignment was never used and relied on
# `np` leaking in through a star import, so it has been dropped.)
if len(full_staging_output) > 1:
    sorted_model_weights = sorted(full_model_weights, reverse=True)
else:
    # With at most one model there is nothing to order; pass the weights through.
    sorted_model_weights = full_model_weights
# Bars are labelled by position (M1, M2, ...), not by the staging itself.
labels = ["M" + str(i + 1) for i in range(len(full_staging_output))]
plt.bar(labels, sorted_model_weights, align='center', alpha=0.5)
plt.ylabel('Model weights')
plt.xticks(rotation=90)
plt.show()
#%%
# Calculate the well performing unions and intersections, then flatten each
# mapping's values into a single list.
intersection_staging = get_staging_intersection(staging_output)
union_staging = get_staging_union(staging_output)
intersection_set = [
    member
    for group in intersection_staging.values()
    for member in group
]
union_set = [
    member
    for group in union_staging.values()
    for member in group
]
# %%
# Show the CEG of the intersection set: apply the full transitions to a fresh
# copy of the staged tree, draw the tree, then draw the CEG built from it.
intersection_tree = copy.deepcopy(st)
intersection_tree.calculate_full_transitions(hyperstage=intersection_set)
intersection_tree.create_figure()
intersection_ceg = ceg.ChainEventGraph(intersection_tree)
intersection_ceg.create_figure()
#%%
# Show the CEG of the union set: apply the full transitions to a fresh copy
# of the staged tree, draw the tree, then draw the CEG built from it.
union_tree = copy.deepcopy(st)
union_tree.calculate_full_transitions(hyperstage=union_set)
union_tree.create_figure()
union_ceg = ceg.ChainEventGraph(union_tree)
union_ceg.create_figure()
# %%