1
+ import rsatoolbox as rsa
2
+ import numpy as np
3
+ import PcmPy as pcm
4
+ import scipy .spatial .distance as sd
5
+ import pandas as pd
6
+ import seaborn as sb
7
+ import matplotlib .pyplot as plt
8
+
9
+ # Simulate data
10
+
11
+ def crossval_sim (n_part = 2 ,n_cond = 4 ,n_sim = 10 ,n_channel = 200 ,sigma = 'iid' ):
12
+ """ Simulate data from a model with 4 conditions
13
+ Use different noise covriances across trials for each
14
+ Use square eucledian or cross-validated square Eucldian distancen
15
+ Low numbers of partitions to emphasize increase in variance.
16
+ """
17
+
18
+ cond_vec ,part_vec = pcm .sim .make_design (n_cond ,n_part )
19
+ true_dist = np .array ([2 ,1 ,0 ,3 ,2 ,1 ])
20
+ dist_type = np .array ([1 ,2 ,3 ,4 ,5 ,6 ])
21
+ D = sd .squareform (true_dist )
22
+ H = pcm .matrix .centering (n_cond )
23
+ G = - 0.5 * H @ D @ H
24
+ M = pcm .model .FixedModel ('fixed' ,G )
25
+ if (sigma == 'iid' ):
26
+ Sigma = np .kron (np .eye (n_part ),np .eye (n_cond ))
27
+ elif (sigma == 'neigh' ):
28
+ A = [[1 ,0.8 ,0 ,0 ],[0.8 ,1 ,0 ,0.5 ],[0 ,0 ,1 ,0 ],[0 ,0.5 ,0 ,1 ]]
29
+ Sigma = np .kron (np .eye (n_part ),A )
30
+ data = pcm .sim .make_dataset (M ,[],cond_vec ,
31
+ n_sim = n_sim ,
32
+ noise = 4 ,
33
+ n_channel = n_channel ,
34
+ noise_cov_trial = Sigma )
35
+ Z = pcm .matrix .indicator (cond_vec )
36
+
37
+ D_simp = np .zeros ((n_sim ,n_cond * (n_cond - 1 )// 2 ))
38
+ D_cross = np .zeros ((n_sim ,n_cond * (n_cond - 1 )// 2 ))
39
+ for i in range (n_sim ):
40
+ mean_act = np .linalg .pinv (Z ) @ data [i ].measurements
41
+ D_simp [i ] = sd .pdist (mean_act )** 2 / n_channel
42
+ G_cross ,_ = pcm .est_G_crossval (data [i ].measurements ,cond_vec ,part_vec )
43
+ D_cross [i ,:] = sd .squareform (pcm .G_to_dist (G_cross ))
44
+
45
+
46
+ # model, theta, cond_vec, n_channel=30, n_sim=1,
47
+ # signal=1, noise=1, signal_cov_channel=None,
48
+ # noise_cov_channel=None, noise_cov_trial=None,
49
+ # use_exact_signal=False, use_same_signal=False)
50
+ T = pd .DataFrame ({'Simp' :D_simp .flatten (),
51
+ 'Cross' :D_cross .flatten (),
52
+ 'True' :np .tile (true_dist ,n_sim ),
53
+ 'dist_type' :np .tile (dist_type ,n_sim )})
54
+ return (T )
55
+
56
+ def plot_panel (T ):
57
+ sb .violinplot (data = T ,x = 'True' ,y = 'Simp' )
58
+ sb .violinplot (data = T ,x = 'True' ,y = 'Cross' )
59
+ sb .despine ()
60
+ plt .plot ([0 ,3 ],[0 ,3 ],'k--' )
61
+ plt .xlabel ('True distance' )
62
+ plt .ylabel ('Estimated distance' )
63
+ ax = plt .gca ()
64
+ ax .set_ylim ([- 1 ,10 ])
65
+
66
+
67
+ if __name__ == "__main__" :
68
+ T1 = crossval_sim (sigma = 'iid' ,n_sim = 100 )
69
+ T2 = crossval_sim (sigma = 'neigh' ,n_sim = 100 )
70
+ plt .figure ()
71
+ plt .subplot (1 ,2 ,1 )
72
+ plot_panel (T1 )
73
+ plt .subplot (1 ,2 ,2 )
74
+ plot_panel (T2 )
75
+
76
+ pass
0 commit comments