#
# This file is part of jetflows.
#
# Copyright (C) 2014, Henry O. Jacobs ([email protected]), Stefan Sommer ([email protected])
# https://github.com/nefan/jetflows.git
#
# jetflows is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# jetflows is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with jetflows.  If not, see <http://www.gnu.org/licenses/>.
#
| 21 | +""" |
| 22 | +Perform mean and covariance estimation using supplied similarity measure using the mpps. |
| 23 | +""" |
| 24 | + |
import numpy as np
import mpp
from scipy.optimize import minimize, fmin_bfgs, fmin_cg, fmin_l_bfgs_b, root
from scipy.optimize import check_grad
from scipy.optimize import approx_fprime
from scipy import linalg
import matplotlib.pyplot as plt
import itertools
import logging
from functools import partial

import dill
from pathos.multiprocessing import ProcessingPool
from pathos.multiprocessing import cpu_count
P = ProcessingPool(max(1, cpu_count() / 2))  # half the cores, but at least one worker

N_t = 100
t_span = np.linspace(0., 2., N_t)
N = None
DIM = None
rank = None
weights = None
ps = None

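# The optimization variable m used below is one flat vector with layout
#   m[0:N*DIM]                : Nsamples*x0  (mean point, stored scaled)
#   m[N*DIM:N*DIM+N*DIM*rank] : Nsamples*Xa0 (frame / covariance square root, scaled)
#   m[N*DIM+N*DIM*rank:]      : one (N*DIM+N*DIM*rank)-block of Log variables
#                               (xi0,xia0) per sample
# x0 and Xa0 are recovered by multiplying with 1./Nsamples.
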
def getf():

    # copy module globals into locals so the closure carries them into the
    # worker processes
    _N = N
    _DIM = DIM
    _rank = rank
    _SIGMA = mpp.SIGMA
    _ps = ps
    _weights = weights

    def f(m, full=False, onlyidx=None):

        # setup: push the configuration into the mpp module
        N = _N
        DIM = _DIM
        rank = _rank
        SIGMA = _SIGMA
        mpp.DIM = DIM
        mpp.N = N
        mpp.SIGMA = SIGMA
        mpp.rank = rank
        mpp.gaussian.N = N
        mpp.gaussian.DIM = DIM
        mpp.gaussian.SIGMA = mpp.SIGMA
        ps = _ps
        weights = _weights

        def MPPLogf((idx,m,lambdag)):
            #print "computing Log for sample %d" % idx

            # setup (repeated since this may run in a separate worker process)
            N = _N
            DIM = _DIM
            rank = _rank
            SIGMA = _SIGMA
            mpp.DIM = DIM
            mpp.N = N
            mpp.SIGMA = SIGMA
            mpp.rank = rank
            mpp.gaussian.N = N
            mpp.gaussian.DIM = DIM
            mpp.gaussian.SIGMA = mpp.SIGMA
            ps = _ps
            weights = _weights

            # input: recover mean point, frame, and this sample's Log variables
            mpp.lambdag = lambdag
            Nsamples = (m.shape[0]-(N*DIM+N*DIM*rank))/(N*DIM+N*DIM*rank)
            x0 = (1./Nsamples)*m[0:N*DIM]
            Xa0 = (1./Nsamples)*m[N*DIM:N*DIM+N*DIM*rank]
            Logsamples = m[N*DIM+N*DIM*rank:].reshape([-1,N*DIM+N*DIM*rank])
            xi0 = Logsamples[idx,0:N*DIM]
            xia0 = Logsamples[idx,N*DIM:N*DIM+N*DIM*rank]

            # flow: shoot from (x0,Xa0,xi0,xia0) and compare the endpoint with sample idx
            state0 = mpp.weinstein_darboux_to_state( x0, Xa0, xi0, xia0 )
            x0,Xa0,xi0,xia0 = mpp.state_to_weinstein_darboux( state0 )
            (t_span, y_span) = mpp.integrate(state0)
            stateT = y_span[-1]
            xT,XaT,xiT,xiaT = mpp.state_to_weinstein_darboux( stateT )

            v0 = ps[idx,:,:]-xT.reshape([N,DIM])
            res = np.einsum('ia,ia',v0,v0) # squared match term; 1./N ??
            #logging.info('match term after flow: ' + str(res))

            EH = mpp.Hamiltonian(x0,Xa0,xi0,xia0) # path energy from Hamiltonian
            #logging.info('Hamiltonian: ' + str(EH))

            #print "computed Log for sample %d (lambdag %g)" % (idx,mpp.lambdag)

            return weights[0]*EH+weights[1]*res


        # compute the per-sample distances (in parallel below)
        Nsamples = (m.shape[0]-(N*DIM+N*DIM*rank))/(N*DIM+N*DIM*rank)

        # determine lambdag
        #logging.info("determining lambdag...")
        x0 = (1./Nsamples)*m[0:N*DIM]
        Xa0 = (1./Nsamples)*m[N*DIM:N*DIM+N*DIM*rank].reshape([N*DIM,rank])
        (gsharp,_,_,g,_) = mpp.gs(x0)
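        # lambdag is fixed by a volume condition: solve det(W^T g W) = det(g^#)
        # for W = Xa0 Xa0^T + lambdag*g^#, via the root search below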
        def detlambdag(lg):
            delta1 = np.eye(rank)
            W = np.einsum('ab,ka,lb->kl',delta1,Xa0,Xa0)+lg*gsharp
            W2 = np.einsum('ba,bi,ij->aj',W,g,W)
            detgsharp = np.linalg.det(gsharp)
            detW2 = np.linalg.det(W2)
            #print "detlambdag: lg %g, detW2 %g, detgsharp %g" % (lg,detW2,detgsharp)
            return detW2-detgsharp

        if rank > 0:
            reslambdag = root(detlambdag,1.)
            assert(reslambdag.success)
            lambdag = reslambdag.x
        else:
            lambdag = 1.

        # run logs
        #logging.info("performing shots...")
        input_args = zip(*(xrange(Nsamples), itertools.cycle((m,)), itertools.cycle((lambdag,)),))
        if onlyidx is None:
            sol = P.imap(MPPLogf, input_args)
            Logs = np.array(list(sol))
            # serial fallback:
            #Logs = np.empty(Nsamples)
            #for i in range(Nsamples):
            #    Logs[i] = MPPLogf(input_args[i])
        else:
            sampleid = onlyidx/(N*DIM+N*DIM*rank)
            logging.info("only idx %d, sample %d...",onlyidx,sampleid)
            Logs = np.zeros(Nsamples)
            Logs[sampleid] = MPPLogf(input_args[sampleid])

        res = (1./Nsamples)*np.sum(Logs)

        ## debug output
        if not full and onlyidx is None:
            #print "f x0: %s, Xa: %s, res %g" % (x0,Xa0,res,)
            print "f res %g" % (res,)

        if not full:
            return res
        else:
            return (res,(1./Nsamples)*Logs)

    return f

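# A minimal usage sketch (assuming the module globals N, DIM, rank, weights and
# ps have been set as in est() below): the returned f evaluates the objective
# on the flat parameter vector described above.
#
#   f = getf()
#   m0 = np.hstack( (Nsamples*x0, Nsamples*Xa0.flatten(),
#                    np.zeros(Nsamples*(N*DIM+N*DIM*rank))) )
#   E = f(m0)                  # scalar objective
#   E,Logs = f(m0, full=True)  # objective and scaled per-sample terms
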
#def constr(m):
#    # constraints on frame
#    Nsamples = (m.shape[0]-(N*DIM+N*DIM*rank))/(N*DIM+N*DIM*rank)
#    x0 = (1./Nsamples)*m[0:N*DIM]
#    Xa0 = (1./Nsamples)*m[N*DIM:N*DIM+N*DIM*rank].reshape([N*DIM,N*DIM])
#    (_,_,_,gx0,_) = mpp.gs(x0)
#    Xa02inner = np.einsum('ba,bi,ij->aj',Xa0,gx0,Xa0)
#    detXa02 = np.linalg.det(Xa02inner)
#
#    res = -np.sum(np.abs([1-detXa02]))
#    print "constr res: %s" % res
#
#    return res

def err_func_gradient(p):

    logging.info("gradient...")

    f = getf()

    (fp,Logs) = f(p,full=True)
    #lsingle_grad_point = partial(single_grad_point, fp)

    _N = N
    _DIM = DIM
    _rank = rank
    _SIGMA = mpp.SIGMA
    _ps = ps
    _weights = weights
    lambdag = None

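    # single_grad_point approximates one partial derivative of f by a forward
    # difference (f(p+eps*e_idx)-f(p))/eps. For an entry belonging to a single
    # sample's Log variables, only that sample's term changes, so f is
    # re-evaluated with onlyidx set to skip all other samples.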
    def single_grad_point((idx,px,Logs)):
        # setup
        N = _N
        DIM = _DIM
        rank = _rank

        p = px.copy()
        epsilon = 1e-6
        p[idx] += epsilon
        if idx < N*DIM+N*DIM*rank:
            d1 = f(p)
            return (d1-fp)/(epsilon)
        else:
            onlyidx = idx-(N*DIM+N*DIM*rank)
            sampleid = onlyidx/(N*DIM+N*DIM*rank)
            d1 = f(p, onlyidx=onlyidx)
            return (d1-Logs[sampleid])/(epsilon)
        # central-difference alternative:
        #p[idx] -= 2*epsilon
        #d2 = err_func(p)
        #return (d1-d2)/(2*epsilon)

    Nsamples = (p.shape[0]-(N*DIM+N*DIM*rank))/(N*DIM+N*DIM*rank)
    res = np.zeros(p.shape)
    # two cases: the shared x,Xa entries are handled serially (each perturbation
    # affects every sample's term), the per-sample Log entries in parallel
    r0 = (0,N*DIM+N*DIM*rank)
    for i in range(r0[0],r0[1]): # run this serially
        res[i] = single_grad_point((i,p,None), )

    r1 = (N*DIM+N*DIM*rank,p.size)
    assert((r1[1]-r1[0])==Nsamples*(N*DIM+N*DIM*rank))
    input_args = zip(*(xrange(r1[0],r1[1]), itertools.cycle((p,)), itertools.cycle((Logs,))))
    sol = P.imap(single_grad_point, input_args)
    res[r1[0]:r1[1]] = np.array(list(sol))
    # serial fallback:
    #for i in range(r1[1]-r1[0]):
    #    res[r1[0]+i] = single_grad_point(input_args[i])

    return res


def est(_ps,SIGMA,_weights,_rank,maxIter=150, visualize=False, visualizeIterations=False, x0=None, Xa0=None):
    """
    Perform mean/cov estimation using the supplied similarity measure, mpps,
    and scipy's optimizer. No analytic derivatives are used; the gradient is
    approximated by finite differences (err_func_gradient).

    weights determines the split between the path energy (weights[0]) and the
    match term (weights[1])
    """

| 253 | + |
| 254 | + # number of samples |
| 255 | + global ps |
| 256 | + ps = _ps |
| 257 | + Nsamples = ps.shape[0] |
| 258 | + |
| 259 | + # set flow parameters |
| 260 | + global DIM,N,rank |
| 261 | + mpp.DIM = DIM = ps.shape[2] |
| 262 | + mpp.N = N = ps.shape[1] |
| 263 | + mpp.SIGMA = SIGMA |
| 264 | + mpp.rank = rank = _rank |
| 265 | + mpp.lambdag = 1. |
| 266 | + mpp.init() |
| 267 | + global weights |
| 268 | + weights = _weights |
| 269 | + |
| 270 | + logging.info("Estimation parameters: rank %d, N %d, Nsamples %d, weights %s, SIGMA %g, maxIter %d, visualize %s, visualizeIterations %s",mpp.rank,N,Nsamples,weights,SIGMA,maxIter,visualize,visualizeIterations) |
| 271 | + |
| 272 | + |
| 273 | + if x0 == None: |
| 274 | + # initial point |
| 275 | + x0 = np.mean(ps,0).flatten() |
| 276 | + if Xa0 == None: |
| 277 | + # initial frame |
| 278 | + pss = ps.reshape([Nsamples,N*DIM]) |
| 279 | + (eigv,eigV) = np.linalg.eig(1./(Nsamples-1)*np.dot(pss.T,pss)) |
| 280 | + inds = eigv>1e-4 |
| 281 | + assert(np.sum(inds) >= rank) |
| 282 | + FrPCA = np.einsum('ij,j->ij',eigV[:,inds],np.sqrt(eigv[inds])) |
| 283 | + Xa0 = FrPCA.reshape([N*DIM,np.sum(inds)])[:,0:rank] |
| 284 | + |
| 285 | + logging.info("initial point/frame, x0: %s, Xa0: %s",x0,Xa0) |
| 286 | + |
    initval = np.hstack( (Nsamples*x0,Nsamples*Xa0.flatten(),np.zeros((Nsamples,N*DIM+N*DIM*rank)).flatten(),) ).astype('double')
    tol = 1e-4
    # constraints are currently disabled; use e.g. COBYLA for constrained optimization
    if maxIter > 0:
        f = getf()
        #logging.info("checking gradient...")
        #findiffgrad1 = approx_fprime(initval,f,1e-7)
        #findiffgrad2 = err_func_gradient(initval)
        #logging.info("gradient difference: %g",np.linalg.norm(findiffgrad1-findiffgrad2,np.inf))
        logging.info("running optimizer...")
        res = minimize(f, initval, method='CG',
                tol=tol,
                # constraints={'type': 'ineq', 'fun': constr},
                options={'disp': True, 'maxiter': maxIter},
                jac=err_func_gradient
                )

        if not res.success:
            print "mean/covar optimization failed:\n%s" % res

        mu = (1./Nsamples)*res.x[0:N*DIM]
        SigmaSQRT = (1./Nsamples)*res.x[N*DIM:N*DIM+N*DIM*rank]
        Logyis = res.x[N*DIM+N*DIM*rank:].reshape([Nsamples,N*DIM+N*DIM*rank])
    else:
        logging.info("not running optimizer (maxIter 0)")

        # determine lambdag
        logging.info("determining lambdag...")
        m = initval
        x0 = (1./Nsamples)*m[0:N*DIM]
        Xa0 = (1./Nsamples)*m[N*DIM:N*DIM+N*DIM*rank].reshape([N*DIM,rank])
        (gsharp,_,_,g,_) = mpp.gs(x0)
        def detlambdag(lg):
            delta1 = np.eye(rank)
            W = np.einsum('ab,ka,lb->kl',delta1,Xa0,Xa0)+lg*gsharp
            W2 = np.einsum('ba,bi,ij->aj',W,g,W)
            detgsharp = np.linalg.det(gsharp)
            detW2 = np.linalg.det(W2)
            #print "detlambdag: lg %g, detW2 %g, detgsharp %g" % (lg,detW2,detgsharp)
            return detW2-detgsharp

        if rank > 0:
            reslambdag = root(detlambdag,1.)
            assert(reslambdag.success)
            mpp.lambdag = reslambdag.x
        else:
            mpp.lambdag = 1.

        mu = (1./Nsamples)*initval[0:N*DIM]
        SigmaSQRT = (1./Nsamples)*initval[N*DIM:N*DIM+N*DIM*rank]
        Logyis = initval[N*DIM+N*DIM*rank:].reshape([Nsamples,N*DIM+N*DIM*rank])

    mu = mu.reshape([N,DIM])
    SigmaSQRT = SigmaSQRT.reshape([N*DIM,rank])
    print "mu: %s,\nSigmaSQRT: %s" % (mu,SigmaSQRT)
    if maxIter > 0: # res only exists when the optimizer ran
        print "diff %s" % np.linalg.norm(initval-res.x,np.inf)

    return (mu,SigmaSQRT,mpp.lambdag,Logyis)


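# Example usage (hypothetical data): estimate mean point, covariance square
# root, and metric regularization from Nsamples configurations of N landmarks
# in DIM dimensions:
#
#   ps = np.random.randn(10,4,2)  # shape [Nsamples,N,DIM]
#   (mu,SigmaSQRT,lambdag,Logyis) = est(ps, 1., [1.,1.], 2)  # SIGMA, weights, rank
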
def genStateData(fstate, sim):
    logging.info("generating state data for optimization result")

    mpp.DIM = DIM = sim['DIM']
    mpp.N = N = sim['N']
    mpp.SIGMA = SIGMA = sim['SIGMA']
    mpp.init()

    (t_span, y_span) = mpp.integrate( fstate )

    # save result
    np.save('output/est_final_fstate',fstate)
    np.save('output/est_setup',[N,DIM,SIGMA])
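
# Example (hypothetical values): regenerate and save flow data for a previously
# computed optimization result fstate:
#
#   sim = {'DIM': 2, 'N': 4, 'SIGMA': 1.}
#   genStateData(fstate, sim)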