ElenaRyumina
commited on
Commit
·
f4944b0
1
Parent(s):
005c0ed
Summary
Browse files- .gitignore +171 -0
- FER_dinamic_LSTM_Aff-Wild2.pth → FER_dinamic_LSTM_Aff-Wild2.pt +2 -2
- FER_dinamic_LSTM_CREMA-D.pth → FER_dinamic_LSTM_CREMA-D.pt +2 -2
- FER_dinamic_LSTM_IEMOCAP.pth → FER_dinamic_LSTM_IEMOCAP.pt +2 -2
- FER_dinamic_LSTM_RAMAS.pth → FER_dinamic_LSTM_RAMAS.pt +2 -2
- FER_dinamic_LSTM_RAVDESS.pt +3 -0
- FER_dinamic_LSTM_RAVDESS.pth +0 -3
- FER_dinamic_LSTM_SAVEE.pt +3 -0
- FER_dinamic_LSTM_SAVEE.pth +0 -3
- FER_static_ResNet50_AffectNet.pt +3 -0
- FER_static_ResNet50_AffectNet.pth +0 -3
- run_webcam.ipynb +165 -4
.gitignore
ADDED
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Compiled source #
|
2 |
+
###################
|
3 |
+
*.com
|
4 |
+
*.class
|
5 |
+
*.dll
|
6 |
+
*.exe
|
7 |
+
*.o
|
8 |
+
*.so
|
9 |
+
*.pyc
|
10 |
+
|
11 |
+
# Packages #
|
12 |
+
############
|
13 |
+
# it's better to unpack these files and commit the raw source
|
14 |
+
# git has its own built in compression methods
|
15 |
+
*.7z
|
16 |
+
*.dmg
|
17 |
+
*.gz
|
18 |
+
*.iso
|
19 |
+
*.rar
|
20 |
+
#*.tar
|
21 |
+
*.zip
|
22 |
+
|
23 |
+
# Logs and databases #
|
24 |
+
######################
|
25 |
+
*.log
|
26 |
+
*.sqlite
|
27 |
+
|
28 |
+
# OS generated files #
|
29 |
+
######################
|
30 |
+
.DS_Store
|
31 |
+
ehthumbs.db
|
32 |
+
Icon
|
33 |
+
Thumbs.db
|
34 |
+
.tmtags
|
35 |
+
.idea
|
36 |
+
.vscode
|
37 |
+
tags
|
38 |
+
vendor.tags
|
39 |
+
tmtagsHistory
|
40 |
+
*.sublime-project
|
41 |
+
*.sublime-workspace
|
42 |
+
.bundle
|
43 |
+
|
44 |
+
# Byte-compiled / optimized / DLL files
|
45 |
+
__pycache__/
|
46 |
+
*.py[cod]
|
47 |
+
*$py.class
|
48 |
+
|
49 |
+
# C extensions
|
50 |
+
*.so
|
51 |
+
|
52 |
+
# Distribution / packaging
|
53 |
+
.Python
|
54 |
+
build/
|
55 |
+
develop-eggs/
|
56 |
+
dist/
|
57 |
+
downloads/
|
58 |
+
eggs/
|
59 |
+
.eggs/
|
60 |
+
lib/
|
61 |
+
lib64/
|
62 |
+
parts/
|
63 |
+
sdist/
|
64 |
+
var/
|
65 |
+
wheels/
|
66 |
+
pip-wheel-metadata/
|
67 |
+
share/python-wheels/
|
68 |
+
*.egg-info/
|
69 |
+
.installed.cfg
|
70 |
+
*.egg
|
71 |
+
MANIFEST
|
72 |
+
node_modules/
|
73 |
+
|
74 |
+
# PyInstaller
|
75 |
+
# Usually these files are written by a python script from a template
|
76 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
77 |
+
*.manifest
|
78 |
+
*.spec
|
79 |
+
|
80 |
+
# Installer logs
|
81 |
+
pip-log.txt
|
82 |
+
pip-delete-this-directory.txt
|
83 |
+
|
84 |
+
# Unit test / coverage reports
|
85 |
+
htmlcov/
|
86 |
+
.tox/
|
87 |
+
.nox/
|
88 |
+
.coverage
|
89 |
+
.coverage.*
|
90 |
+
.cache
|
91 |
+
nosetests.xml
|
92 |
+
coverage.xml
|
93 |
+
*.cover
|
94 |
+
.hypothesis/
|
95 |
+
.pytest_cache/
|
96 |
+
|
97 |
+
# Translations
|
98 |
+
*.mo
|
99 |
+
*.pot
|
100 |
+
|
101 |
+
# Django stuff:
|
102 |
+
*.log
|
103 |
+
local_settings.py
|
104 |
+
db.sqlite3
|
105 |
+
db.sqlite3-journal
|
106 |
+
|
107 |
+
# Flask stuff:
|
108 |
+
instance/
|
109 |
+
.webassets-cache
|
110 |
+
|
111 |
+
# Scrapy stuff:
|
112 |
+
.scrapy
|
113 |
+
|
114 |
+
# Sphinx documentation
|
115 |
+
docs/_build/
|
116 |
+
|
117 |
+
# PyBuilder
|
118 |
+
target/
|
119 |
+
|
120 |
+
# Jupyter Notebook
|
121 |
+
.ipynb_checkpoints
|
122 |
+
|
123 |
+
# IPython
|
124 |
+
profile_default/
|
125 |
+
ipython_config.py
|
126 |
+
|
127 |
+
# pyenv
|
128 |
+
.python-version
|
129 |
+
|
130 |
+
# pipenv
|
131 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
132 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
133 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
134 |
+
# install all needed dependencies.
|
135 |
+
#Pipfile.lock
|
136 |
+
|
137 |
+
# celery beat schedule file
|
138 |
+
celerybeat-schedule
|
139 |
+
|
140 |
+
# SageMath parsed files
|
141 |
+
*.sage.py
|
142 |
+
|
143 |
+
# Environments
|
144 |
+
.env
|
145 |
+
.venv
|
146 |
+
env/
|
147 |
+
venv/
|
148 |
+
ENV/
|
149 |
+
env.bak/
|
150 |
+
venv.bak/
|
151 |
+
|
152 |
+
# Spyder project settings
|
153 |
+
.spyderproject
|
154 |
+
.spyproject
|
155 |
+
|
156 |
+
# Rope project settings
|
157 |
+
.ropeproject
|
158 |
+
|
159 |
+
# mkdocs documentation
|
160 |
+
/site
|
161 |
+
|
162 |
+
# mypy
|
163 |
+
.mypy_cache/
|
164 |
+
.dmypy.json
|
165 |
+
dmypy.json
|
166 |
+
|
167 |
+
# Pyre type checker
|
168 |
+
.pyre/
|
169 |
+
|
170 |
+
# Custom
|
171 |
+
*.mp4
|
FER_dinamic_LSTM_Aff-Wild2.pth → FER_dinamic_LSTM_Aff-Wild2.pt
RENAMED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:21b0cabebe7bd0257ca8aaa991efc7546c6f46fd4d17f759d33abbb859abdacc
|
3 |
+
size 11569812
|
FER_dinamic_LSTM_CREMA-D.pth → FER_dinamic_LSTM_CREMA-D.pt
RENAMED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5486a9b4816fb86c4fcbbbc7e8b6506c9e66fc6db25404ed492da119330b86ee
|
3 |
+
size 11569208
|
FER_dinamic_LSTM_IEMOCAP.pth → FER_dinamic_LSTM_IEMOCAP.pt
RENAMED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0cd1561a72f9de26c315bb857f03e8946635db047e0dbea52bb0276610f19751
|
3 |
+
size 11569208
|
FER_dinamic_LSTM_RAMAS.pth → FER_dinamic_LSTM_RAMAS.pt
RENAMED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ba1a49032311f91580eff67732bbb0a7077f1382c8a65e5d0fca01b1ad09ba37
|
3 |
+
size 11569180
|
FER_dinamic_LSTM_RAVDESS.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0b8eb7e702d4be62bba48dd54addd53698c95fd94ff8293fb53fd8d59ab22248
|
3 |
+
size 11569208
|
FER_dinamic_LSTM_RAVDESS.pth
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:43bc117590334f5f64465d3dd80c894baafe80b83911959cf403cba41d2bbf54
|
3 |
-
size 11590417
|
|
|
|
|
|
|
|
FER_dinamic_LSTM_SAVEE.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:aa651fe5a937837610dea81fdf4e0079e1ebda07f28657007bcbc985faf25fc5
|
3 |
+
size 11569180
|
FER_dinamic_LSTM_SAVEE.pth
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:c7b89232ffa9fefaeaca64f3d8dc6271065f8b6cd56fe5a32e76bc93a8138669
|
3 |
-
size 11590359
|
|
|
|
|
|
|
|
FER_static_ResNet50_AffectNet.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8274190b5be4355bd2f07b59f593fcdb294f9d7c563bfa9ac9e5ea06c10692d2
|
3 |
+
size 98562934
|
FER_static_ResNet50_AffectNet.pth
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:790f76fe4d443953b3b71a6899bdd981742e8e1b954da48483d3eea6c3c717a1
|
3 |
-
size 98631726
|
|
|
|
|
|
|
|
run_webcam.ipynb
CHANGED
@@ -17,10 +17,167 @@
|
|
17 |
"warnings.simplefilter(\"ignore\", UserWarning)\n",
|
18 |
"\n",
|
19 |
"import torch\n",
|
|
|
|
|
20 |
"from PIL import Image\n",
|
21 |
"from torchvision import transforms"
|
22 |
]
|
23 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
{
|
25 |
"cell_type": "markdown",
|
26 |
"id": "fcbcf9fa-a7cc-4d4c-b723-6d7efd49b94b",
|
@@ -177,7 +334,7 @@
|
|
177 |
"source": [
|
178 |
"mp_face_mesh = mp.solutions.face_mesh\n",
|
179 |
"\n",
|
180 |
-
"name_backbone_model = 'FER_static_ResNet50_AffectNet.
|
181 |
"# name_LSTM_model = 'IEMOCAP'\n",
|
182 |
"# name_LSTM_model = 'CREMA-D'\n",
|
183 |
"# name_LSTM_model = 'RAMAS'\n",
|
@@ -186,12 +343,16 @@
|
|
186 |
"name_LSTM_model = 'Aff-Wild2'\n",
|
187 |
"\n",
|
188 |
"# torch\n",
|
189 |
-
"
|
|
|
|
|
190 |
"pth_backbone_model.eval()\n",
|
191 |
"\n",
|
192 |
-
"pth_LSTM_model =
|
|
|
193 |
"pth_LSTM_model.eval()\n",
|
194 |
"\n",
|
|
|
195 |
"DICT_EMO = {0: 'Neutral', 1: 'Happiness', 2: 'Sadness', 3: 'Surprise', 4: 'Fear', 5: 'Disgust', 6: 'Anger'}\n",
|
196 |
"\n",
|
197 |
"cap = cv2.VideoCapture(0)\n",
|
@@ -220,7 +381,7 @@
|
|
220 |
" frame_copy = cv2.cvtColor(frame_copy, cv2.COLOR_BGR2RGB)\n",
|
221 |
" results = face_mesh.process(frame_copy)\n",
|
222 |
" frame_copy.flags.writeable = True\n",
|
223 |
-
"
|
224 |
" if results.multi_face_landmarks:\n",
|
225 |
" for fl in results.multi_face_landmarks:\n",
|
226 |
" startX, startY, endX, endY = get_box(fl, w, h)\n",
|
|
|
17 |
"warnings.simplefilter(\"ignore\", UserWarning)\n",
|
18 |
"\n",
|
19 |
"import torch\n",
|
20 |
+
"import torch.nn as nn\n",
|
21 |
+
"import torch.nn.functional as F\n",
|
22 |
"from PIL import Image\n",
|
23 |
"from torchvision import transforms"
|
24 |
]
|
25 |
},
|
26 |
+
{
|
27 |
+
"cell_type": "markdown",
|
28 |
+
"id": "a0907155",
|
29 |
+
"metadata": {},
|
30 |
+
"source": [
|
31 |
+
"#### Model architectures"
|
32 |
+
]
|
33 |
+
},
|
34 |
+
{
|
35 |
+
"cell_type": "code",
|
36 |
+
"execution_count": null,
|
37 |
+
"id": "f67038e3",
|
38 |
+
"metadata": {},
|
39 |
+
"outputs": [],
|
40 |
+
"source": [
|
41 |
+
"class Bottleneck(nn.Module):\n",
|
42 |
+
" expansion = 4\n",
|
43 |
+
" def __init__(self, in_channels, out_channels, i_downsample=None, stride=1):\n",
|
44 |
+
" super(Bottleneck, self).__init__()\n",
|
45 |
+
" \n",
|
46 |
+
" self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0, bias=False)\n",
|
47 |
+
" self.batch_norm1 = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.99)\n",
|
48 |
+
" \n",
|
49 |
+
" self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding='same', bias=False)\n",
|
50 |
+
" self.batch_norm2 = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.99)\n",
|
51 |
+
" \n",
|
52 |
+
" self.conv3 = nn.Conv2d(out_channels, out_channels*self.expansion, kernel_size=1, stride=1, padding=0, bias=False)\n",
|
53 |
+
" self.batch_norm3 = nn.BatchNorm2d(out_channels*self.expansion, eps=0.001, momentum=0.99)\n",
|
54 |
+
" \n",
|
55 |
+
" self.i_downsample = i_downsample\n",
|
56 |
+
" self.stride = stride\n",
|
57 |
+
" self.relu = nn.ReLU()\n",
|
58 |
+
" \n",
|
59 |
+
" def forward(self, x):\n",
|
60 |
+
" identity = x.clone()\n",
|
61 |
+
" x = self.relu(self.batch_norm1(self.conv1(x)))\n",
|
62 |
+
" \n",
|
63 |
+
" x = self.relu(self.batch_norm2(self.conv2(x)))\n",
|
64 |
+
" \n",
|
65 |
+
" x = self.conv3(x)\n",
|
66 |
+
" x = self.batch_norm3(x)\n",
|
67 |
+
" \n",
|
68 |
+
" #downsample if needed\n",
|
69 |
+
" if self.i_downsample is not None:\n",
|
70 |
+
" identity = self.i_downsample(identity)\n",
|
71 |
+
" #add identity\n",
|
72 |
+
" x+=identity\n",
|
73 |
+
" x=self.relu(x)\n",
|
74 |
+
" \n",
|
75 |
+
" return x\n",
|
76 |
+
"\n",
|
77 |
+
"class Conv2dSame(torch.nn.Conv2d):\n",
|
78 |
+
"\n",
|
79 |
+
" def calc_same_pad(self, i: int, k: int, s: int, d: int) -> int:\n",
|
80 |
+
" return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0)\n",
|
81 |
+
"\n",
|
82 |
+
" def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
|
83 |
+
" ih, iw = x.size()[-2:]\n",
|
84 |
+
"\n",
|
85 |
+
" pad_h = self.calc_same_pad(i=ih, k=self.kernel_size[0], s=self.stride[0], d=self.dilation[0])\n",
|
86 |
+
" pad_w = self.calc_same_pad(i=iw, k=self.kernel_size[1], s=self.stride[1], d=self.dilation[1])\n",
|
87 |
+
"\n",
|
88 |
+
" if pad_h > 0 or pad_w > 0:\n",
|
89 |
+
" x = F.pad(\n",
|
90 |
+
" x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]\n",
|
91 |
+
" )\n",
|
92 |
+
" return F.conv2d(\n",
|
93 |
+
" x,\n",
|
94 |
+
" self.weight,\n",
|
95 |
+
" self.bias,\n",
|
96 |
+
" self.stride,\n",
|
97 |
+
" self.padding,\n",
|
98 |
+
" self.dilation,\n",
|
99 |
+
" self.groups,\n",
|
100 |
+
" )\n",
|
101 |
+
"\n",
|
102 |
+
"class ResNet(nn.Module):\n",
|
103 |
+
" def __init__(self, ResBlock, layer_list, num_classes, num_channels=3):\n",
|
104 |
+
" super(ResNet, self).__init__()\n",
|
105 |
+
" self.in_channels = 64\n",
|
106 |
+
"\n",
|
107 |
+
" self.conv_layer_s2_same = Conv2dSame(num_channels, 64, 7, stride=2, groups=1, bias=False)\n",
|
108 |
+
" self.batch_norm1 = nn.BatchNorm2d(64, eps=0.001, momentum=0.99)\n",
|
109 |
+
" self.relu = nn.ReLU()\n",
|
110 |
+
" self.max_pool = nn.MaxPool2d(kernel_size = 3, stride=2)\n",
|
111 |
+
" \n",
|
112 |
+
" self.layer1 = self._make_layer(ResBlock, layer_list[0], planes=64, stride=1)\n",
|
113 |
+
" self.layer2 = self._make_layer(ResBlock, layer_list[1], planes=128, stride=2)\n",
|
114 |
+
" self.layer3 = self._make_layer(ResBlock, layer_list[2], planes=256, stride=2)\n",
|
115 |
+
" self.layer4 = self._make_layer(ResBlock, layer_list[3], planes=512, stride=2)\n",
|
116 |
+
" \n",
|
117 |
+
" self.avgpool = nn.AdaptiveAvgPool2d((1,1))\n",
|
118 |
+
" self.fc1 = nn.Linear(512*ResBlock.expansion, 512)\n",
|
119 |
+
" self.relu1 = nn.ReLU()\n",
|
120 |
+
" self.fc2 = nn.Linear(512, num_classes)\n",
|
121 |
+
"\n",
|
122 |
+
" def extract_features(self, x):\n",
|
123 |
+
" x = self.relu(self.batch_norm1(self.conv_layer_s2_same(x)))\n",
|
124 |
+
" x = self.max_pool(x)\n",
|
125 |
+
" # print(x.shape)\n",
|
126 |
+
" x = self.layer1(x)\n",
|
127 |
+
" x = self.layer2(x)\n",
|
128 |
+
" x = self.layer3(x)\n",
|
129 |
+
" x = self.layer4(x)\n",
|
130 |
+
" \n",
|
131 |
+
" x = self.avgpool(x)\n",
|
132 |
+
" x = x.reshape(x.shape[0], -1)\n",
|
133 |
+
" x = self.fc1(x)\n",
|
134 |
+
" return x\n",
|
135 |
+
" \n",
|
136 |
+
" def forward(self, x):\n",
|
137 |
+
" x = self.extract_features(x)\n",
|
138 |
+
" x = self.relu1(x)\n",
|
139 |
+
" x = self.fc2(x)\n",
|
140 |
+
" return x\n",
|
141 |
+
" \n",
|
142 |
+
" def _make_layer(self, ResBlock, blocks, planes, stride=1):\n",
|
143 |
+
" ii_downsample = None\n",
|
144 |
+
" layers = []\n",
|
145 |
+
" \n",
|
146 |
+
" if stride != 1 or self.in_channels != planes*ResBlock.expansion:\n",
|
147 |
+
" ii_downsample = nn.Sequential(\n",
|
148 |
+
" nn.Conv2d(self.in_channels, planes*ResBlock.expansion, kernel_size=1, stride=stride, bias=False, padding=0),\n",
|
149 |
+
" nn.BatchNorm2d(planes*ResBlock.expansion, eps=0.001, momentum=0.99)\n",
|
150 |
+
" )\n",
|
151 |
+
" \n",
|
152 |
+
" layers.append(ResBlock(self.in_channels, planes, i_downsample=ii_downsample, stride=stride))\n",
|
153 |
+
" self.in_channels = planes*ResBlock.expansion\n",
|
154 |
+
" \n",
|
155 |
+
" for i in range(blocks-1):\n",
|
156 |
+
" layers.append(ResBlock(self.in_channels, planes))\n",
|
157 |
+
" \n",
|
158 |
+
" return nn.Sequential(*layers)\n",
|
159 |
+
" \n",
|
160 |
+
"def ResNet50(num_classes, channels=3):\n",
|
161 |
+
" return ResNet(Bottleneck, [3,4,6,3], num_classes, channels)\n",
|
162 |
+
"\n",
|
163 |
+
"\n",
|
164 |
+
"class LSTMPyTorch(nn.Module):\n",
|
165 |
+
" def __init__(self):\n",
|
166 |
+
" super(LSTMPyTorch, self).__init__()\n",
|
167 |
+
" \n",
|
168 |
+
" self.lstm1 = nn.LSTM(input_size=512, hidden_size=512, batch_first=True, bidirectional=False)\n",
|
169 |
+
" self.lstm2 = nn.LSTM(input_size=512, hidden_size=256, batch_first=True, bidirectional=False)\n",
|
170 |
+
" self.fc = nn.Linear(256, 7)\n",
|
171 |
+
" self.softmax = nn.Softmax(dim=1)\n",
|
172 |
+
"\n",
|
173 |
+
" def forward(self, x):\n",
|
174 |
+
" x, _ = self.lstm1(x)\n",
|
175 |
+
" x, _ = self.lstm2(x) \n",
|
176 |
+
" x = self.fc(x[:, -1, :])\n",
|
177 |
+
" x = self.softmax(x)\n",
|
178 |
+
" return x"
|
179 |
+
]
|
180 |
+
},
|
181 |
{
|
182 |
"cell_type": "markdown",
|
183 |
"id": "fcbcf9fa-a7cc-4d4c-b723-6d7efd49b94b",
|
|
|
334 |
"source": [
|
335 |
"mp_face_mesh = mp.solutions.face_mesh\n",
|
336 |
"\n",
|
337 |
+
"name_backbone_model = 'FER_static_ResNet50_AffectNet.pt'\n",
|
338 |
"# name_LSTM_model = 'IEMOCAP'\n",
|
339 |
"# name_LSTM_model = 'CREMA-D'\n",
|
340 |
"# name_LSTM_model = 'RAMAS'\n",
|
|
|
343 |
"name_LSTM_model = 'Aff-Wild2'\n",
|
344 |
"\n",
|
345 |
"# torch\n",
|
346 |
+
"\n",
|
347 |
+
"pth_backbone_model = ResNet50(7, channels=3)\n",
|
348 |
+
"pth_backbone_model.load_state_dict(torch.load(name_backbone_model))\n",
|
349 |
"pth_backbone_model.eval()\n",
|
350 |
"\n",
|
351 |
+
"pth_LSTM_model = LSTMPyTorch()\n",
|
352 |
+
"pth_LSTM_model.load_state_dict(torch.load('FER_dinamic_LSTM_{0}.pt'.format(name_LSTM_model)))\n",
|
353 |
"pth_LSTM_model.eval()\n",
|
354 |
"\n",
|
355 |
+
"\n",
|
356 |
"DICT_EMO = {0: 'Neutral', 1: 'Happiness', 2: 'Sadness', 3: 'Surprise', 4: 'Fear', 5: 'Disgust', 6: 'Anger'}\n",
|
357 |
"\n",
|
358 |
"cap = cv2.VideoCapture(0)\n",
|
|
|
381 |
" frame_copy = cv2.cvtColor(frame_copy, cv2.COLOR_BGR2RGB)\n",
|
382 |
" results = face_mesh.process(frame_copy)\n",
|
383 |
" frame_copy.flags.writeable = True\n",
|
384 |
+
"\n",
|
385 |
" if results.multi_face_landmarks:\n",
|
386 |
" for fl in results.multi_face_landmarks:\n",
|
387 |
" startX, startY, endX, endY = get_box(fl, w, h)\n",
|