-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdemimgp_sine.m
161 lines (136 loc) · 4.15 KB
/
demimgp_sine.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
% Make the toolboxes used below visible: the IM-GP code plus the bundled
% GPML and Netlab copies (paths are relative to the current directory).
addpath misc_toolbox/;
addpath misc_toolbox/gpml/;
addpath misc_toolbox/netlab/;
% Seed both legacy generators so the noisy training targets are reproducible.
rand('state', 1724);
randn('state', 1724);
% Synthetic 1-D problem: y = sin(alpha*pi*x^beta) on [0, 2]; with beta > 1
% the curve oscillates faster as x grows.
alpha = 1;
beta = 2.5;
target = @(x) sin(alpha*pi*x.^beta);
s2 = 0.01;        % variance of the additive Gaussian observation noise
n = 150;          % number of (noisy) training points
n_test = 500;     % number of noise-free test points
X = linspace(0,2,n)';
Y = target(X) + sqrt(s2)*randn(n,1);
Xtest = linspace(0,2,n_test)';
Ytest = target(Xtest);
% Option vector passed to imgpTrain. Slots set to 1 switch a feature on,
% slot 11 holds the iteration count, and every slot not listed stays 0:
%    1  display lower bound              9  learn W0
%    2  learn kernel hyperparameters    10  label re-ordering
%    4  learn target noise              11  no. of iterations (30)
%    6  learn delta                     15  use Kmeans for the initialization
%    8  learn nu0                       16  non-zero mean GPs
options = zeros(1, 16);
options([1 2 4 6 8 9 10 15 16]) = 1;
options(11) = 30;
% threshold (presumably the truncation level, i.e. max no. of components;
% the loops further down run over c = 1:C)
C = 20;
% --- Infinite mixture of GP experts (IM-GP) ---
% Train with a single SE-ARD covariance, C components and initial noise s2,
% timing the fit in CPU seconds.
t_start = cputime;
[model, vardist, lb] = imgpTrain(X, Y, {'covSEard'}, C, s2, 0., options);
fprintf(1, 'training of the IM-GP completed in %f s.\n', cputime - t_start);
fprintf(1, 'im-gp delta = %g\n', vardist.delta);
fprintf(1, 'im-gp noise = %g\n', model.Likelihood.sigma2);
% Predictive mean/variance on the test grid, plus per-component outputs
% (omega and ypc are used by the component panels below).
[yp, sig2, omega, ypc] = imgpPredict(model, vardist, Xtest);
sig = sqrt(sig2);
% Optional dump of the per-point assignments:
%disp('Assignments:');
%vardist.gamma
% One line per component: its nu0 followed by its W0 entries.
disp('nu0:');
for comp = 1:C
  fprintf(1, '%f %f\n', vardist.nu0(comp), vardist.W0(:,:,comp));
end
% --- Baseline: a single ("vanilla") GP, SE-ARD kernel plus white noise ---
ycovfunc = {'covSum', {'covSEard', 'covNoise'}};
% Initial log-hyperparameters for 1-D inputs (GPML convention):
% [log lengthscale, log signal std, log noise std].
logtheta = [ log(0.5), log(1), 0.5*log(s2) ];
% checkgrad('gpr_fn', logtheta(:), 1e-4, ycovfunc, X, Y);
% NOTE(review): the positive length argument (5) caps minimize at 5 line
% searches, so the baseline hyperparameters may not be fully converged.
[logtheta fX] = minimize(logtheta(:), 'gpr', 5, ycovfunc, X, Y);
sig2vanilla = exp(2*logtheta(3));   % learned noise variance
fprintf(1,'vanilla-gp noise = %g\n', sig2vanilla);
logtheta = logtheta(1:2);           % keep only the SE-ARD part
exp(logtheta)
% Standard GP prediction. Use the same training data (X, Y) and the same
% learned noise level that the hyperparameters were fitted with, so that the
% predictive mean and the predictive variance are mutually consistent.
K = feval('covSEard', logtheta, X);
[Kss, Kstar] = feval('covSEard', logtheta, X, Xtest);
Lc = chol(K + sig2vanilla*eye(size(X, 1)), 'lower');
V = (Lc'\(Lc\Kstar))';              % n_test x n: Kstar' * inv(K + sn2*I)
yvanilla = V*Y;
% Latent variance plus observation noise, so the band is comparable to noisy data.
sig2vanilla = Kss - sum(V.*Kstar', 2) + sig2vanilla;
sigvanilla = sqrt(sig2vanilla);
% --- Figure with four stacked panels comparing the models ---
% Create the figure, then size it with set() so this works in MATLAB as well as Octave.
fh1 = figure(1);
set(fh1, 'position', [0, 0, 900, 1000]); % [left, bottom, width, height]
dots_size = 10; % 14
mean_size = 3; % 10
sig_size = 2; % 3
% Panel 1: IM-GP predictive mean with a +/- 2 std band.
subplot(4,1,1);
hold on
plot(X, Y, '.', 'markersize', dots_size, 'color','black');
xlabel('Input')
ylabel('Target')
plot(Xtest,Ytest, 'color', 'black','linewidth', mean_size);
plot(Xtest, yp, '-b','linewidth', mean_size);
% Lower band first, then upper, to match the order of the legend labels.
plot(Xtest, yp-(2*sig), '-b','linewidth', sig_size);
plot(Xtest, yp+(2*sig), '-b','linewidth', sig_size);
legend('data', 'function', 'PYP-GP', 'PYP-GP lower bound', 'PYP-GP upper bound');
hold off
axis([0 2 -1.5 2.5]);
% Title text describing the setup (the title() call itself is disabled).
t = ['N=' num2str(n) ', C=' num2str(C) ', target noise=' num2str(s2) ];
if options(4)
  t = [ t ' (learned)' ];
else
  t = [ t ' (frozen, i.e. not learned)' ];
end
%title(t);
% Panel 2: vanilla GP predictive mean with a +/- 2 std band.
subplot(4,1,2);
hold on
plot(X, Y, '.', 'markersize', dots_size, 'color','black');
xlabel('Input')
ylabel('Target')
plot(Xtest, Ytest, 'color', 'black','linewidth', mean_size);
plot(Xtest, yvanilla, '-r','linewidth', mean_size);
% Lower band first, then upper, to match the order of the legend labels.
plot(Xtest, yvanilla-(2*sigvanilla), '-r','linewidth', sig_size);
plot(Xtest, yvanilla+(2*sigvanilla), '-r','linewidth', sig_size);
legend('data', 'function', 'Vanilla-GP', 'Vanilla-GP lower bound', 'Vanilla-GP upper bound');
hold off
axis([0 2 -1.5 2.5]);
% Panel 3: mean curve ypc(:,c) of every component whose omega(c) exceeds 0.1
% (omega is presumably the predicted mixing weight), labelled by omega(c).
subplot(4,1,3);
hold all
plot(X, Y, '.', 'markersize', dots_size, 'color','black');
xlabel('Input')
ylabel('Target')
plot(Xtest, Ytest, 'color', 'black','linewidth', mean_size);
t = { 'samples' 'true curve' };
k = 3;
% Colours are chosen with col{mod(c,4)+1}, so entries 1-4 are the ones in use;
% they must be distinct or two components would be drawn in the same colour.
col = { 'r', 'g', 'b', 'm', 'c' };
for c = 1:C
  if omega(c) > 0.1
    plot(Xtest, ypc(:,c), 'linewidth', mean_size, 'color', col{mod(c,4)+1});
    t{k} = num2str(omega(c));
    k = k + 1;
    fprintf(1,'plotting component c = %g: 2*sig = %f, W = %g, nu = %f, G = %f\n', ...
      c, 2*sqrt(inv(vardist.W(:,:,c))/vardist.nu(c)), vardist.W(:,:,c), vardist.nu(c), vardist.G(:,:,c));
  end
end
legend(t);
axis([0 2 -1.5 4]);
hold off
% Panel 4: for the same components as panel 3 (omega(c) > 0.1), the netlab
% gauss() curve centred at vardist.g(:,c) with covariance inv(W)/nu on Xtest.
subplot(4,1,4);
hold all
gate_labels = {};
for c = 1:C
  if omega(c) > 0.1
    plot(Xtest, gauss(vardist.g(:,c), inv(vardist.W(:,:,c))/vardist.nu(c), Xtest), 'linewidth', mean_size, 'color', col{mod(c,4)+1});
    gate_labels{end+1} = num2str(omega(c));
  end
end
% Label each curve with its omega so it can be matched with panel 3;
% skip the legend when no component passed the threshold.
if ~isempty(gate_labels)
  legend(gate_labels);
end
xlabel('Input')
axis([0 2 -1.5 4]);
hold off
% Root-mean-square error of each model's predictive mean against the
% noise-free test curve.
rms_imgp = norm(Ytest - yp)/sqrt(n_test);
rms_vanilla = norm(Ytest - yvanilla)/sqrt(n_test);
fprintf(1,'rms on im-gp: %f\n', rms_imgp);
fprintf(1,'rms on vanilla gp: %f\n', rms_vanilla);
% print -dpng a.png -S2304,640
% Keep the figure on screen until a key is pressed, then close it and
% wipe the workspace.
disp(' ')
disp('Press any key to end.')
pause
close(fh1);
clear all;