# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
import torch.multiprocessing as mp

from vissl.losses.barlow_twins_loss import BarlowTwinsCriterion
from vissl.losses.simclr_info_nce_loss import SimclrInfoNCECriterion
from vissl.utils.test_utils import gpu_test, init_distributed_on_file, with_temp_files


class TestSimClrCriterionOnGpu(unittest.TestCase):
"""
Specific tests on SimCLR going further than just doing a forward pass
"""
@staticmethod
def worker_fn(gpu_id: int, world_size: int, batch_size: int, sync_file: str):
init_distributed_on_file(
world_size=world_size, gpu_id=gpu_id, sync_file=sync_file
)
embeddings = torch.full(
size=(batch_size, 3), fill_value=float(gpu_id), requires_grad=True
).cuda(gpu_id)
gathered = SimclrInfoNCECriterion.gather_embeddings(embeddings)
if world_size == 1:
assert gathered.equal(
torch.tensor(
[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], device=f"cuda:{gpu_id}"
)
)
if world_size == 2:
assert gathered.equal(
torch.tensor(
[
[0.0, 0.0, 0.0],
[0.0, 0.0, 0.0],
[1.0, 1.0, 1.0],
[1.0, 1.0, 1.0],
],
device=f"cuda:{gpu_id}",
)
)
assert gathered.requires_grad
    @gpu_test(gpu_count=1)
    def test_gather_embeddings_world_size_1(self):
        with with_temp_files(count=1) as sync_file:
            WORLD_SIZE = 1
            BATCH_SIZE = 2
            mp.spawn(
                self.worker_fn,
                args=(WORLD_SIZE, BATCH_SIZE, sync_file),
                nprocs=WORLD_SIZE,
            )

    @gpu_test(gpu_count=2)
    def test_gather_embeddings_world_size_2(self):
        with with_temp_files(count=1) as sync_file:
            WORLD_SIZE = 2
            BATCH_SIZE = 2
            mp.spawn(
                self.worker_fn,
                args=(WORLD_SIZE, BATCH_SIZE, sync_file),
                nprocs=WORLD_SIZE,
            )
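

# For context: a minimal sketch of what a gradient-preserving all-gather such
# as SimclrInfoNCECriterion.gather_embeddings is presumed to do (the actual
# vissl implementation may differ). torch.distributed.all_gather returns
# detached tensors, so the local rank's shard is swapped back in to keep the
# autograd edge that the requires_grad assertion above checks for.
def _sketch_gather_embeddings(embeddings: torch.Tensor) -> torch.Tensor:
    import torch.distributed as dist

    world_size = dist.get_world_size()
    if world_size == 1:
        return embeddings
    gathered = [torch.zeros_like(embeddings) for _ in range(world_size)]
    dist.all_gather(gathered, embeddings.detach())
    # Re-insert the differentiable local shard; the other shards stay
    # detached, which is fine since their gradients live on their own ranks
    gathered[dist.get_rank()] = embeddings
    return torch.cat(gathered, dim=0)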


class TestBarlowTwinsCriterionOnGpu(unittest.TestCase):
    """
    Specific tests on Barlow Twins going further than just doing a forward pass
    """

    @staticmethod
    def worker_fn(gpu_id: int, world_size: int, batch_size: int, sync_file: str):
        init_distributed_on_file(
            world_size=world_size, gpu_id=gpu_id, sync_file=sync_file
        )
        EMBEDDING_DIM = 128
        criterion = BarlowTwinsCriterion(
            lambda_=0.0051, scale_loss=0.024, embedding_dim=EMBEDDING_DIM
        )
        embeddings = torch.randn(
            (batch_size, EMBEDDING_DIM), dtype=torch.float32, requires_grad=True
        ).cuda()
        # The point of this worker is the backward pass: it exercises the
        # cross-rank synchronisation of the loss, not just its forward value
        criterion(embeddings).backward()
    @gpu_test(gpu_count=1)
    def test_backward_world_size_1(self):
        with with_temp_files(count=1) as sync_file:
            WORLD_SIZE = 1
            BATCH_SIZE = 2
            mp.spawn(
                self.worker_fn,
                args=(WORLD_SIZE, BATCH_SIZE, sync_file),
                nprocs=WORLD_SIZE,
            )

    @gpu_test(gpu_count=2)
    def test_backward_world_size_2(self):
        with with_temp_files(count=1) as sync_file:
            WORLD_SIZE = 2
            BATCH_SIZE = 2
            mp.spawn(
                self.worker_fn,
                args=(WORLD_SIZE, BATCH_SIZE, sync_file),
                nprocs=WORLD_SIZE,
            )
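

# For reference: a minimal single-GPU sketch of the Barlow Twins objective the
# criterion above is assumed to implement (Zbontar et al., 2021), written here
# only to document what the backward pass flows through. The actual
# BarlowTwinsCriterion (batch-norm details, cross-rank sync, how the two views
# are split out of the batch) may differ.
def _sketch_barlow_twins_loss(
    z_a: torch.Tensor,
    z_b: torch.Tensor,
    lambda_: float = 0.0051,
    scale_loss: float = 0.024,
) -> torch.Tensor:
    batch_size = z_a.shape[0]
    # Normalise each view along the batch dimension
    z_a = (z_a - z_a.mean(dim=0)) / (z_a.std(dim=0) + 1e-5)
    z_b = (z_b - z_b.mean(dim=0)) / (z_b.std(dim=0) + 1e-5)
    # Empirical cross-correlation matrix, embedding_dim x embedding_dim
    c = (z_a.T @ z_b) / batch_size
    # Invariance term: push the diagonal towards 1
    on_diag = (torch.diagonal(c) - 1).pow(2).sum()
    # Redundancy-reduction term: push off-diagonal entries towards 0
    off_diag = c.pow(2).sum() - torch.diagonal(c).pow(2).sum()
    return scale_loss * (on_diag + lambda_ * off_diag)


# Convenience entry point so the file can be run directly; the upstream repo
# may rely on its own test runner instead.
if __name__ == "__main__":
    unittest.main()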