Skip to content

Commit 8762f54

Browse files
committed
many important updates and fixes
1 parent 0089ac5 commit 8762f54

35 files changed

+2652
-1105
lines changed

Create Animations.ipynb

+117
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 1,
6+
"metadata": {
7+
"collapsed": true
8+
},
9+
"outputs": [],
10+
"source": [
11+
"import matplotlib.pyplot as plt\n",
12+
"import h5py\n",
13+
"import pickle\n",
14+
"from matplotlib.colors import hsv_to_rgb\n",
15+
"import numpy as np\n",
16+
"from dae import ex\n",
17+
"import os.path"
18+
]
19+
},
20+
{
21+
"cell_type": "code",
22+
"execution_count": null,
23+
"metadata": {
24+
"collapsed": false,
25+
"scrolled": true
26+
},
27+
"outputs": [],
28+
"source": [
29+
"# run a longer (30 iterations) evaluation for visualization\n",
30+
"datasets = {\n",
31+
" 'bars': 12, \n",
32+
" 'corners': 5,\n",
33+
" 'shapes': 3,\n",
34+
" 'multi_mnist': 3,\n",
35+
" 'mnist_shape': 2,\n",
36+
" 'simple_superpos':2\n",
37+
"}\n",
38+
"nr_iters = 30\n",
39+
"nrows = 10\n",
40+
"ncols = 12\n",
41+
"\n",
42+
"\n",
43+
"\n",
44+
"for ds, k in datasets.items():\n",
45+
" results_filename = 'Results/{}_{}_{}.pickle'.format(ds, nr_iters, k)\n",
46+
" animation_dir = 'animations/{}'.format(ds)\n",
47+
" if not os.path.exists(animation_dir):\n",
48+
" os.makedirs(animation_dir)\n",
49+
" \n",
50+
" ex.run_command('evaluate', config_updates={\n",
51+
" 'dataset.name': ds,\n",
52+
" 'net_filename': 'Networks/best_{}_dae.h5'.format(ds),\n",
53+
" 'em.k': k,\n",
54+
" 'em.nr_iters': 30,\n",
55+
" 'em.dump_results': results_filename,\n",
56+
" 'seed': 42}) \n",
57+
" \n",
58+
" with h5py.File('/home/greff/Datasets/{}.h5'.format(ds)) as f:\n",
59+
" true_groups = f['test']['groups'][:]\n",
60+
" with open(results_filename, 'rb') as f:\n",
61+
" scores, likelihoods, results = pickle.load(f)\n",
62+
" \n",
63+
" if results.shape[-1] != 3:\n",
64+
" nr_colors = results.shape[-1]\n",
65+
" hsv_colors = np.ones((nr_colors, 3))\n",
66+
" hsv_colors[:, 0] = (np.linspace(0, 1, nr_colors, endpoint=False) + 2/3) % 1.0\n",
67+
" color_conv = hsv_to_rgb(hsv_colors)\n",
68+
" results = results.reshape(-1, nr_colors).dot(color_conv).reshape(results.shape[:-1] + (3,))\n",
69+
" \n",
70+
" for it in range(nr_iters+1):\n",
71+
" fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols, nrows))\n",
72+
" for r in range(nrows):\n",
73+
" for c in range(ncols):\n",
74+
" axes[r, c].imshow(results[ncols*r + c, it, 0, :, :, 0, :], interpolation='nearest')\n",
75+
" axes[r, c].set_xticks([])\n",
76+
" axes[r, c].set_yticks([])\n",
77+
" plt.subplots_adjust(wspace=0, hspace=0)\n",
78+
" fig.savefig(animation_dir + '/img_{:02d}.png'.format(it), bbox_inches='tight', pad_inches=0, dpi=72.26)"
79+
]
80+
},
81+
{
82+
"cell_type": "code",
83+
"execution_count": 6,
84+
"metadata": {
85+
"collapsed": false
86+
},
87+
"outputs": [],
88+
"source": [
89+
"from subprocess import call\n",
90+
"\n",
91+
"for ds in datasets:\n",
92+
" call(['convert', '-delay', '10', '-loop', '0', 'animations/{}/*.png'.format(ds), 'animations/{}.gif'.format(ds)])\n"
93+
]
94+
}
95+
],
96+
"metadata": {
97+
"kernelspec": {
98+
"display_name": "Python 3",
99+
"language": "python",
100+
"name": "python3"
101+
},
102+
"language_info": {
103+
"codemirror_mode": {
104+
"name": "ipython",
105+
"version": 3
106+
},
107+
"file_extension": ".py",
108+
"mimetype": "text/x-python",
109+
"name": "python",
110+
"nbconvert_exporter": "python",
111+
"pygments_lexer": "ipython3",
112+
"version": "3.4.3"
113+
}
114+
},
115+
"nbformat": 4,
116+
"nbformat_minor": 0
117+
}

0 commit comments

Comments
 (0)