Commit 11a1fba
feat: improve how device switch is handled between the metric device and the input tensors device (#3043)
* refactor: remove outdated code and issue a warning if two tensors are on separate devices.
* feat: prioritize computation on GPU devices over CPUs
If either one of the metric device or the update input device
is a GPU, this commit will put the other one on GPU.
* fix: use a temp var that will be moved with y_pred
The comparison with self._device was not possible because it
can be created with `torch.device("cuda")` which is not equal
to `torch.device("cuda:0")` which is the device of a tensor
created with `torch.device("cuda")`. This change will have
a bigger performance hit when self._kernel is not on the same
device as y_pred as it will need to be moved onto y_pred's
device every time update() is called.
* test: add metric and y_pred with different devices test
* feat: move self._kernel directly and issue a warning only when not all y_pred tensors are on the same device
* feat: adapt test to new behaviour
* feat: keep the accumulation on the same device as self._kernel
* feat: move accumulation along side self._kernel
* feat: allow different channel number
* style: format using the run_code_style script
* style: add line brak to conform to E501
* fix: use torch.empty to avoid type incompatibility between None and Tensor with mypy
* feat: only operate on self._kernel, keep the accumulation on user's selected device
* test: add variable channel test and factorize the code
* refactor: remove redundant line between init and reset
* refactor: elif comparison and replace RuntimeWarning by UserWarning
Co-authored-by: vfdev <[email protected]>
* refactor: set _kernel in __init__ and manually format to pass E501
* test: adapt test to new UserWarning
* test: remove skips
* refactor: use None instead of torch.empty
* style: reorder imports
* refactor: rename channel to nb_channel
* Fixed failing test_distrib_accumulator_device
---------
Co-authored-by: vfdev <[email protected]>1 parent 86c2a1d commit 11a1fba
2 files changed
+119
-18
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1 | | - | |
| 1 | + | |
| 2 | + | |
2 | 3 | | |
3 | 4 | | |
4 | 5 | | |
| |||
102 | 103 | | |
103 | 104 | | |
104 | 105 | | |
105 | | - | |
| 106 | + | |
| 107 | + | |
106 | 108 | | |
107 | 109 | | |
108 | 110 | | |
| |||
155 | 157 | | |
156 | 158 | | |
157 | 159 | | |
158 | | - | |
159 | | - | |
160 | | - | |
| 160 | + | |
| 161 | + | |
| 162 | + | |
| 163 | + | |
| 164 | + | |
| 165 | + | |
| 166 | + | |
| 167 | + | |
| 168 | + | |
| 169 | + | |
| 170 | + | |
| 171 | + | |
| 172 | + | |
| 173 | + | |
| 174 | + | |
| 175 | + | |
161 | 176 | | |
162 | 177 | | |
163 | 178 | | |
| |||
166 | 181 | | |
167 | 182 | | |
168 | 183 | | |
169 | | - | |
| 184 | + | |
170 | 185 | | |
171 | 186 | | |
172 | 187 | | |
| |||
184 | 199 | | |
185 | 200 | | |
186 | 201 | | |
187 | | - | |
| 202 | + | |
188 | 203 | | |
189 | 204 | | |
190 | 205 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
| 1 | + | |
| 2 | + | |
1 | 3 | | |
2 | 4 | | |
3 | 5 | | |
| |||
70 | 72 | | |
71 | 73 | | |
72 | 74 | | |
73 | | - | |
74 | | - | |
75 | | - | |
76 | | - | |
| 75 | + | |
| 76 | + | |
77 | 77 | | |
78 | 78 | | |
| 79 | + | |
| 80 | + | |
| 81 | + | |
| 82 | + | |
| 83 | + | |
| 84 | + | |
| 85 | + | |
| 86 | + | |
| 87 | + | |
| 88 | + | |
| 89 | + | |
| 90 | + | |
| 91 | + | |
| 92 | + | |
| 93 | + | |
| 94 | + | |
| 95 | + | |
| 96 | + | |
| 97 | + | |
| 98 | + | |
| 99 | + | |
| 100 | + | |
| 101 | + | |
79 | 102 | | |
80 | | - | |
81 | | - | |
| 103 | + | |
| 104 | + | |
82 | 105 | | |
83 | 106 | | |
84 | 107 | | |
85 | 108 | | |
86 | 109 | | |
87 | 110 | | |
88 | | - | |
89 | | - | |
| 111 | + | |
| 112 | + | |
| 113 | + | |
| 114 | + | |
| 115 | + | |
90 | 116 | | |
91 | | - | |
| 117 | + | |
92 | 118 | | |
93 | 119 | | |
94 | 120 | | |
| |||
102 | 128 | | |
103 | 129 | | |
104 | 130 | | |
| 131 | + | |
| 132 | + | |
| 133 | + | |
| 134 | + | |
| 135 | + | |
| 136 | + | |
| 137 | + | |
| 138 | + | |
| 139 | + | |
| 140 | + | |
| 141 | + | |
| 142 | + | |
| 143 | + | |
| 144 | + | |
| 145 | + | |
| 146 | + | |
| 147 | + | |
| 148 | + | |
| 149 | + | |
| 150 | + | |
| 151 | + | |
| 152 | + | |
| 153 | + | |
| 154 | + | |
| 155 | + | |
| 156 | + | |
| 157 | + | |
| 158 | + | |
| 159 | + | |
| 160 | + | |
| 161 | + | |
| 162 | + | |
| 163 | + | |
| 164 | + | |
| 165 | + | |
| 166 | + | |
| 167 | + | |
105 | 168 | | |
106 | 169 | | |
107 | 170 | | |
| |||
128 | 191 | | |
129 | 192 | | |
130 | 193 | | |
| 194 | + | |
| 195 | + | |
| 196 | + | |
| 197 | + | |
| 198 | + | |
| 199 | + | |
| 200 | + | |
| 201 | + | |
| 202 | + | |
| 203 | + | |
| 204 | + | |
| 205 | + | |
| 206 | + | |
| 207 | + | |
| 208 | + | |
131 | 209 | | |
132 | 210 | | |
133 | 211 | | |
| |||
136 | 214 | | |
137 | 215 | | |
138 | 216 | | |
139 | | - | |
| 217 | + | |
| 218 | + | |
| 219 | + | |
| 220 | + | |
| 221 | + | |
| 222 | + | |
140 | 223 | | |
141 | 224 | | |
142 | 225 | | |
| |||
213 | 296 | | |
214 | 297 | | |
215 | 298 | | |
216 | | - | |
| 299 | + | |
| 300 | + | |
| 301 | + | |
| 302 | + | |
217 | 303 | | |
218 | 304 | | |
219 | 305 | | |
| |||
0 commit comments