
Commit 7541f2b

finalized

1 parent 66a0c53 commit 7541f2b

14 files changed: +346 -1452 lines changed

config.yaml (+1 -1)

@@ -1,4 +1,4 @@
-epochs: 4
+epochs: 10
 batch_size: 64
 learning_rate: 0.0001
 logs: logs
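
The config change raises epochs from 4 to 10 and leaves the other hyperparameters untouched. As a minimal sketch of how such a file is typically consumed (assuming PyYAML; the repository's actual loading code is not part of this diff):

# Minimal sketch, assuming config.yaml is read with PyYAML; the repo's real
# loader is not shown in this commit.
import yaml

with open("config.yaml") as f:
    cfg = yaml.safe_load(f)

epochs = cfg["epochs"]                # 10 after this commit (previously 4)
batch_size = cfg["batch_size"]        # 64
learning_rate = cfg["learning_rate"]  # 0.0001
log_dir = cfg["logs"]                 # "logs"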

dataset.ipynb (+1 -147)

@@ -7,7 +7,7 @@
 "outputs": [],
 "source": [
 "from torch.utils.data import Dataset\n",
-"from Datasets import load\n",
+"from src.Datasets import load\n",
 "import matplotlib.pyplot as plt\n",
 "import numpy as np\n",
 "from torch.utils.data import DataLoader\n",
@@ -16,20 +16,6 @@
 "import torch"
 ]
 },
-{
-"cell_type": "code",
-"execution_count": 30,
-"metadata": {},
-"outputs": [],
-"source": [
-"dataset = datasets.MNIST(\n",
-" root=\".\",\n",
-" train=True,\n",
-" download=True,\n",
-" transform=transforms.ToTensor(),\n",
-" )"
-]
-},
 {
 "cell_type": "code",
 "execution_count": 2,
@@ -241,138 +227,6 @@
 "mean, std"
 ]
 },
-{
-"cell_type": "code",
-"execution_count": 14,
-"metadata": {},
-"outputs": [
-{
-"data": {
-"text/plain": [
-[... removed output condensed here: the first sample's 1x28x28 tensor of normalized pixel values together with its label, 6 ...]
-]
-},
-"execution_count": 14,
-"metadata": {},
-"output_type": "execute_result"
-}
-],
-"source": [
-"dataset[0]"
-]
-},
 {
 "cell_type": "code",
 "execution_count": 9,

energy.ipynb (+28 -8)

@@ -2,16 +2,16 @@
 "cells": [
 {
 "cell_type": "code",
-"execution_count": 1,
+"execution_count": 2,
 "metadata": {},
 "outputs": [],
 "source": [
 "import torch\n",
 "import torch.nn as nn\n",
 "import torch.nn.functional as F\n",
-"from Datasets import load\n",
+"from src.Datasets import load\n",
 "from torch.utils.data import DataLoader\n",
-"from model import Conv_Net_Dropout, Conv_Net_FC_Dropout\n",
+"from src.model import Conv_Net_Dropout, Conv_Net_FC_Dropout\n",
 "import matplotlib.pyplot as plt\n",
 "from tqdm import tqdm\n",
 "import seaborn as sns\n",
@@ -23,7 +23,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 2,
+"execution_count": 3,
 "metadata": {},
 "outputs": [
 {
@@ -50,7 +50,7 @@
 ")"
 ]
 },
-"execution_count": 2,
+"execution_count": 3,
 "metadata": {},
 "output_type": "execute_result"
 }
@@ -103,7 +103,7 @@
 "name": "stderr",
 "output_type": "stream",
 "text": [
-"100%|██████████| 79/79 [00:00<00:00, 121.75it/s]\n"
+"100%|██████████| 79/79 [00:00<00:00, 126.62it/s]\n"
 ]
 }
 ],
@@ -126,7 +126,7 @@
 "name": "stderr",
 "output_type": "stream",
 "text": [
-"100%|██████████| 79/79 [00:00<00:00, 92.24it/s]\n"
+"100%|██████████| 79/79 [00:00<00:00, 90.06it/s]\n"
 ]
 }
 ],
@@ -144,6 +144,26 @@
 "cell_type": "code",
 "execution_count": 8,
 "metadata": {},
+"outputs": [
+{
+"data": {
+"text/plain": [
+"((10000, 10), (10000, 10))"
+]
+},
+"execution_count": 8,
+"metadata": {},
+"output_type": "execute_result"
+}
+],
+"source": [
+"cifar_logits.shape, mnist_logits.shape"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 9,
+"metadata": {},
 "outputs": [
 {
 "data": {
@@ -182,7 +202,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 9,
+"execution_count": 10,
 "metadata": {},
 "outputs": [
 {
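
The new cell in energy.ipynb records cifar_logits and mnist_logits as (10000, 10) arrays before comparing them; the downstream computation itself is not part of this hunk. As a hedged sketch of the common energy-score formulation such comparisons are often built on, E(x) = -T * logsumexp(logits / T), with lower energy indicating more in-distribution-like inputs:

# Hedged sketch of an energy score over per-sample logits; the notebook's actual
# computation is not shown in this diff. cifar_logits / mnist_logits are stand-ins
# for the (10000, 10) tensors whose shapes the new cell prints.
import torch

def energy_score(logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
    # E(x) = -T * logsumexp(logits / T); lower values ~ more in-distribution.
    return -temperature * torch.logsumexp(logits / temperature, dim=-1)

cifar_logits = torch.randn(10000, 10)   # placeholder data, shapes taken from the diff
mnist_logits = torch.randn(10000, 10)   # placeholder data

cifar_energy = energy_score(cifar_logits)   # shape: (10000,)
mnist_energy = energy_score(mnist_logits)   # shape: (10000,)

With real logits, plotting the two energy distributions (for example with seaborn, which the notebook already imports) is the usual way to compare in-distribution and out-of-distribution behaviour.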

0 commit comments