diff --git a/.gitignore b/.gitignore
index ef79da47..8efbf696 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,6 +6,8 @@ dist/
 *.egg-info/
 site/
 venv/
+.vscode
+.devcontainer
 .ipynb_checkpoints
 examples/notebooks/dataset
 examples/notebooks/CIFAR10_Dataset
diff --git a/src/pytorch_metric_learning/losses/margin_loss.py b/src/pytorch_metric_learning/losses/margin_loss.py
index af834846..be203e39 100644
--- a/src/pytorch_metric_learning/losses/margin_loss.py
+++ b/src/pytorch_metric_learning/losses/margin_loss.py
@@ -36,8 +36,11 @@ def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels):
         if len(anchor_idx) == 0:
             return self.zero_losses()

+        # Move beta to the embeddings' device first; indexing it errors out
+        # if self.beta is on the CPU while labels are on CUDA
+        self.beta.data = c_f.to_device(
+            self.beta.data, device=embeddings.device, dtype=embeddings.dtype
+        )
         beta = self.beta if len(self.beta) == 1 else self.beta[labels[anchor_idx]]
-        beta = c_f.to_device(beta, device=embeddings.device, dtype=embeddings.dtype)

         mat = self.distance(embeddings, ref_emb)
diff --git a/src/pytorch_metric_learning/utils/accuracy_calculator.py b/src/pytorch_metric_learning/utils/accuracy_calculator.py
index 77c792ff..4ddd93e6 100644
--- a/src/pytorch_metric_learning/utils/accuracy_calculator.py
+++ b/src/pytorch_metric_learning/utils/accuracy_calculator.py
@@ -470,7 +470,10 @@ def get_accuracy(
         )

         knn_distances, knn_indices = self.knn_func(
-            query, num_k, reference, ref_includes_query
+            query,
+            reference,
+            num_k,
+            ref_includes_query,  # argument order changed to match the faiss signature
         )
         knn_labels = reference_labels[knn_indices]
diff --git a/src/pytorch_metric_learning/utils/common_functions.py b/src/pytorch_metric_learning/utils/common_functions.py
index cb95ebf7..1366f817 100644
--- a/src/pytorch_metric_learning/utils/common_functions.py
+++ b/src/pytorch_metric_learning/utils/common_functions.py
@@ -3,10 +3,12 @@
 import logging
 import os
 import re
+from typing import List, Tuple, Union

 import numpy as np
 import scipy.stats
 import torch
+from torch import nn

 LOGGER_NAME = "PML"
 LOGGER = logging.getLogger(LOGGER_NAME)
@@ -394,13 +396,13 @@ def check_shapes(embeddings, labels):


 def assert_distance_type(obj, distance_type=None, **kwargs):
+    obj_name = obj.__class__.__name__
     if distance_type is not None:
         if is_list_or_tuple(distance_type):
             distance_type_str = ", ".join(x.__name__ for x in distance_type)
             distance_type_str = "one of " + distance_type_str
         else:
             distance_type_str = distance_type.__name__
-        obj_name = obj.__class__.__name__
         assert isinstance(
             obj.distance, distance_type
         ), "{} requires the distance metric to be {}".format(
@@ -459,13 +461,42 @@ def to_dtype(x, tensor=None, dtype=None):
     return x


-def to_device(x, tensor=None, device=None, dtype=None):
+def to_device(
+    x: Union[torch.Tensor, nn.Parameter, List, Tuple],
+    tensor=None,
+    device=None,
+    dtype: Union[torch.dtype, List, Tuple] = None,
+):
     dv = device if device is not None else tensor.device
-    if x.device != dv:
-        x = x.to(dv)
-    if dtype is not None:
-        x = to_dtype(x, dtype=dtype)
-    return x
+    if not is_list_or_tuple(x):
+        x = [x]
+    # TODO: decide whether the default should cast to x.dtype or tensor.dtype.
+    # For now each tensor keeps its own dtype (a no-op cast). Note the default
+    # is computed after x is wrapped in a list, since a list has no .dtype.
+    dt = dtype if dtype is not None else [xt.dtype for xt in x]
+
+    if is_list_or_tuple(dt):
+        if len(dt) != len(x):
+            raise RuntimeError(
+                f"dtype has length {len(dt)}, but it must be a single dtype "
+                f"or have the same length as x"
+            )
+        xd = [to_dtype(xt.to(dv), tensor=tensor, dtype=d) for xt, d in zip(x, dt)]
+    else:
+        xd = [to_dtype(xt.to(dv), tensor=tensor, dtype=dt) for xt in x]
+
+    if len(xd) == 1:
+        xd = xd[0]
+    return xd
+
+
+def check_multiple_gpus(gpus):
+    if gpus is not None:
+        if not isinstance(gpus, (list, tuple)):
+            raise TypeError("gpus must be a list or tuple")
+        if len(gpus) < 1:
+            raise ValueError("gpus must have length greater than 0")


 def set_ref_emb(embeddings, labels, ref_emb, ref_labels):
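Not part of the patch: a minimal sketch of how the list-aware `to_device` is expected to behave, assuming only the signature shown above.

```python
import torch
from pytorch_metric_learning.utils import common_functions as c_f

a = torch.ones(3)                      # float32
b = torch.zeros(3, dtype=torch.int64)  # int64

# A single tensor: moved to the target device, optionally cast
out = c_f.to_device(a, device=torch.device("cpu"), dtype=torch.float16)
assert out.dtype == torch.float16

# A list of tensors with one dtype per element
out_a, out_b = c_f.to_device(
    [a, b], device=torch.device("cpu"), dtype=[torch.float64, torch.int32]
)
assert out_a.dtype == torch.float64 and out_b.dtype == torch.int32

# Mismatched lengths should raise: two tensors but three dtypes
try:
    c_f.to_device([a, b], device=torch.device("cpu"), dtype=[torch.float16] * 3)
except RuntimeError:
    pass  # expected
```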
diff --git a/src/pytorch_metric_learning/utils/inference.py b/src/pytorch_metric_learning/utils/inference.py
index 2adeb6ae..8b182a95 100644
--- a/src/pytorch_metric_learning/utils/inference.py
+++ b/src/pytorch_metric_learning/utils/inference.py
@@ -1,7 +1,7 @@
 import numpy as np
 import torch

-from ..distances import BatchedDistance, CosineSimilarity
+from ..distances import BatchedDistance
 from . import common_functions as c_f

 try:
@@ -11,32 +11,60 @@
     pass


-class MatchFinder:
-    def __init__(self, distance, threshold=None):
-        self.distance = distance
-        self.threshold = threshold
+def mask_reshape_knn_idx(x, matches_self_idx):
+    return x[~matches_self_idx].view(x.shape[0], -1)
+
+
+def return_results(D, I, ref_includes_query):
+    if ref_includes_query:
+        self_idx = torch.arange(len(I), device=I.device)
+        matches_self_idx = I == self_idx.unsqueeze(1)
+        row_has_match = torch.any(matches_self_idx, dim=1)
+        # If every row has a match, then masking will work
+        if not torch.all(row_has_match):
+            # For rows that don't contain the self index,
+            # remove the Nth value by setting matches_self_idx[N] to True
+            matches_self_idx[~row_has_match, -1] = True
+        I = mask_reshape_knn_idx(I, matches_self_idx)
+        D = mask_reshape_knn_idx(D, matches_self_idx)
+    return D, I

-    def operate_on_emb(self, input_func, query_emb, ref_emb=None, *args, **kwargs):
-        if ref_emb is None:
-            ref_emb = query_emb
-        return input_func(query_emb, ref_emb, *args, **kwargs)

+def get_topk(distances, indices, k, get_largest):
+    def fn(mat, s, e):
+        D, I = torch.topk(mat, k, largest=get_largest, dim=1)
+        distances[s:e] = D
+        indices[s:e] = I
+
+    return fn
+
+
+class CustomKNN:
+    def __init__(self, distance, batch_size=None, threshold=None):
+        if batch_size:
+            self.distance = BatchedDistance(distance, batch_size=batch_size)
+        else:
+            self.distance = distance
+        self.threshold = threshold

     # for a batch of queries
+    @torch.no_grad()
     def get_matching_pairs(
         self, query_emb, ref_emb=None, threshold=None, return_tuples=False
     ):
-        with torch.no_grad():
-            threshold = threshold if threshold is not None else self.threshold
-            return self.operate_on_emb(
-                self._get_matching_pairs, query_emb, ref_emb, threshold, return_tuples
-            )
+        threshold = threshold if threshold is not None else self.threshold
+        ref_emb = ref_emb if ref_emb is not None else query_emb
+        return self._get_matching_pairs(query_emb, ref_emb, threshold, return_tuples)

     def _get_matching_pairs(self, query_emb, ref_emb, threshold, return_tuples):
         mat = self.distance(query_emb, ref_emb)
         matches = mat >= threshold if self.distance.is_inverted else mat <= threshold
-        matches = matches.cpu().numpy()
         if return_tuples:
-            return list(zip(*np.where(matches)))
+            # No need to convert to numpy: torch.where gives the same indices
+            return list(zip(*torch.where(matches)))
         return matches

     # where x and y are already matched pairs
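Not part of the patch: a minimal sketch of the threshold semantics of `CustomKNN.get_matching_pairs`, assuming the class as patched above.

```python
import torch
from pytorch_metric_learning.distances import CosineSimilarity
from pytorch_metric_learning.utils.inference import CustomKNN

knn = CustomKNN(distance=CosineSimilarity(), threshold=0.9)
emb = torch.nn.functional.normalize(torch.randn(8, 16), dim=1)

# CosineSimilarity is inverted (higher = more similar), so pairs with
# similarity >= 0.9 count as matches; the diagonal always matches itself
pairs = knn.get_matching_pairs(emb, return_tuples=True)  # list of index pairs
mask = knn.get_matching_pairs(emb)  # boolean [8, 8] tensor instead of tuples
```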
@@ -51,29 +79,41 @@ def is_match(self, x, y, threshold=None):
             return output.detach().item()
         return output.cpu().numpy()

+    def __call__(self, query, reference, k, ref_includes_query=False):
+        # argument order matches the faiss signature
+        if ref_includes_query:
+            k = k + 1
+        get_largest = self.distance.is_inverted
+        if isinstance(self.distance, BatchedDistance):
+            device = query.device
+            distances = torch.zeros(len(query), k, device=device)
+            indices = torch.zeros(len(query), k, device=device, dtype=torch.long)
+            self.distance.iter_fn = get_topk(distances, indices, k, get_largest)
+            self.distance(query, reference)
+        else:
+            mat = self.distance(query, reference)
+            distances, indices = torch.topk(mat, k, largest=get_largest, dim=1)
+        return return_results(distances, indices, ref_includes_query)
+

 class InferenceModel:
     def __init__(
         self,
         trunk,
         embedder=None,
-        match_finder=None,
         normalize_embeddings=True,
-        knn_func=None,
+        knn_func: CustomKNN = None,
         data_device=None,
         dtype=None,
     ):
         self.trunk = trunk
         self.embedder = torch.nn.Identity() if embedder is None else embedder
-        self.match_finder = (
-            MatchFinder(distance=CosineSimilarity(), threshold=0.9)
-            if match_finder is None
-            else match_finder
-        )
-        self.knn_func = (
-            FaissKNN(reset_before=False, reset_after=False)
-            if knn_func is None
-            else knn_func
+
+        # Only apply the default threshold when the caller didn't set one
+        if knn_func is not None and knn_func.threshold is None:
+            knn_func.threshold = 0.9
+        self.knn_func = FaissKNN(
+            reset_before=False, reset_after=False, knn_func=knn_func
         )
         self.normalize_embeddings = normalize_embeddings
         self.data_device = (
@@ -107,7 +147,9 @@ def call_knn(self, func, inputs, batch_size):

     def get_nearest_neighbors(self, query, k):
         query_emb = self.get_embeddings(query)
-        return self.knn_func(query_emb, k)
+        # k is now a keyword argument, matching the faiss signature
+        return self.knn_func(query_emb, k=k)

     def get_embeddings(self, x):
         x = self.process_if_list(x)
@@ -127,7 +169,7 @@ def get_matches(self, query, ref=None, threshold=None, return_tuples=False):
         ref_emb = query_emb
         if ref is not None:
             ref_emb = self.get_embeddings(ref)
-        return self.match_finder.get_matching_pairs(
+        return self.knn_func.get_matching_pairs(
             query_emb, ref_emb, threshold, return_tuples
         )

@@ -135,7 +177,7 @@ def is_match(self, x, y, threshold=None):
     def is_match(self, x, y, threshold=None):
         x = self.get_embeddings(x)
         y = self.get_embeddings(y)
-        return self.match_finder.is_match(x, y, threshold)
+        return self.knn_func.is_match(x, y, threshold)

     def save_knn_func(self, filename):
         self.knn_func.save(filename)
@@ -151,7 +193,12 @@ def process_if_list(self, x):

 class FaissKNN:
     def __init__(
-        self, reset_before=True, reset_after=True, index_init_fn=None, gpus=None
+        self,
+        reset_before=True,
+        reset_after=True,
+        index_init_fn=None,
+        knn_func=None,
+        gpus=None,
     ):
         self.reset()
         self.reset_before = reset_before
@@ -159,18 +206,15 @@ def __init__(
         self.index_init_fn = (
             faiss.IndexFlatL2 if index_init_fn is None else index_init_fn
         )
-        if gpus is not None:
-            if not isinstance(gpus, (list, tuple)):
-                raise TypeError("gpus must be a list")
-            if len(gpus) < 1:
-                raise ValueError("gpus must have length greater than 0")
+        self.knn_func = knn_func
+        c_f.check_multiple_gpus(gpus)
         self.gpus = gpus

     def __call__(
         self,
         query,
-        k,
         reference=None,
+        k=1,  # argument order and default changed to match the faiss signature
         ref_includes_query=False,
     ):
         if ref_includes_query:
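Not part of the patch: a sketch of the net effect of the signature change, assuming faiss is installed and the classes behave as patched above.

```python
import torch
from pytorch_metric_learning.distances import LpDistance
from pytorch_metric_learning.utils.inference import CustomKNN, FaissKNN

query = torch.randn(100, 32)

# Old order was (query, k, reference); the patched order is (query, reference, k)
knn = FaissKNN()
distances, indices = knn(query, reference=query, k=10, ref_includes_query=True)

# A CustomKNN can now be dropped into FaissKNN and is called in the same order
knn_custom = FaissKNN(knn_func=CustomKNN(LpDistance()))
distances, indices = knn_custom(query, reference=query, k=10)
```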
@@ -186,13 +230,17 @@ def __call__(
                 raise ValueError(
                     "self.index is None. It needs to be initialized before being used."
                 )
-        distances, indices = try_gpu(
-            self.index,
-            query,
-            reference,
-            k,
-            is_cuda,
-            self.gpus,
+        distances, indices = (
+            try_gpu(
+                self.index,
+                query,
+                reference,
+                k,
+                is_cuda,
+                self.gpus,
+            )
+            if self.knn_func is None
+            else self.knn_func(query, reference, k)
         )
         distances = c_f.to_device(distances, device=device)
         indices = c_f.to_device(indices, device=device)
@@ -216,6 +264,26 @@ def load(self, filename):
     def reset(self):
         self.index = None

+    def get_matching_pairs(
+        self, query_emb, ref_emb=None, threshold=None, return_tuples=False
+    ):
+        try:
+            return self.knn_func.get_matching_pairs(
+                query_emb, ref_emb, threshold, return_tuples
+            )
+        except AttributeError:
+            raise RuntimeError(
+                "No suitable match finder provided. self.knn_func must implement "
+                "the get_matching_pairs method"
+            )
+
+    def is_match(self, x, y, threshold=None):
+        try:
+            return self.knn_func.is_match(x, y, threshold)
+        except AttributeError:
+            raise RuntimeError(
+                "No suitable match finder provided. self.knn_func must implement "
+                "the is_match method"
+            )
+

 class FaissKMeans:
     def __init__(self, **kwargs):
@@ -224,7 +292,7 @@ def __init__(self, **kwargs):
     def __call__(self, x, nmb_clusters):
         device = x.device
         x = c_f.to_numpy(x).astype(np.float32)
-        n_data, d = x.shape
+        _, d = x.shape
         c_f.LOGGER.info("running k-means clustering with k=%d" % nmb_clusters)
         c_f.LOGGER.info("embedding dimensionality is %d" % d)

@@ -232,13 +300,22 @@ def __call__(self, x, nmb_clusters):
         kmeans = faiss.Kmeans(d, nmb_clusters, **self.kwargs)
         kmeans.train(x)
         _, idxs = kmeans.index.search(x, 1)
-        return torch.tensor([int(n[0]) for n in idxs], dtype=int, device=device)
+        return torch.tensor([int(n[0]) for n in idxs], dtype=torch.long, device=device)


 def add_to_index_and_search(index, query, reference, k):
-    if reference is not None:
-        index.add(reference.float().cpu())
-    return index.search(query.float().cpu(), k)
+    # A single-GPU faiss index can search GPU tensors directly (issue #491);
+    # otherwise fall back to adding and searching on the CPU
+    index_on_single_gpu = faiss.get_num_gpus() == 1 and isinstance(
+        index, faiss.GpuIndex
+    )
+    device_query = query.float()
+    device_ref = reference.float() if reference is not None else None
+    if not index_on_single_gpu:
+        device_query = device_query.cpu()
+        device_ref = device_ref.cpu() if device_ref is not None else None
+
+    if device_ref is not None:
+        index.add(device_ref)
+    return index.search(device_query, k)


 def convert_to_gpu_index(index, gpus):
@@ -260,8 +337,8 @@ def try_gpu(index, query, reference, k, is_cuda, gpus):
     gpu_index = None
     gpus_are_available = faiss.get_num_gpus() > 0
     gpu_condition = (is_cuda or (gpus is not None)) and gpus_are_available
+    # torch.version.cuda is None on CPU-only builds, so guard before parsing it
+    cuda_version = torch.version.cuda
+    max_k_for_gpu = 1024 if cuda_version and float(cuda_version) < 9.5 else 2048
     if gpu_condition:
-        max_k_for_gpu = 1024 if float(torch.version.cuda) < 9.5 else 2048
         if k <= max_k_for_gpu:
             gpu_index = convert_to_gpu_index(index, gpus)
             try:
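Not part of the patch: a sketch making the delegation concrete. `FaissKNN` only forwards matching calls to its `knn_func`, so a `CustomKNN` (or anything implementing `get_matching_pairs`/`is_match`) must be supplied for them to work; this assumes faiss is installed and the error handling as patched above.

```python
import torch
from pytorch_metric_learning.distances import CosineSimilarity
from pytorch_metric_learning.utils.inference import CustomKNN, FaissKNN

emb = torch.nn.functional.normalize(torch.randn(8, 16), dim=1)

knn = FaissKNN(knn_func=CustomKNN(CosineSimilarity(), threshold=0.9))
mask = knn.get_matching_pairs(emb, emb)  # forwarded to the CustomKNN

FaissKNN().get_matching_pairs(emb, emb)  # raises: no match finder provided
```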
@@ -283,54 +360,3 @@ def run_pca(x, output_dimensionality):
     mat.train(x)
     assert mat.is_trained
     return c_f.to_device(torch.from_numpy(mat.apply_py(x)), device=device)
-
-
-def mask_reshape_knn_idx(x, matches_self_idx):
-    return x[~matches_self_idx].view(x.shape[0], -1)
-
-
-def return_results(D, I, ref_includes_query):
-    if ref_includes_query:
-        self_idx = torch.arange(len(I), device=I.device)
-        matches_self_idx = I == self_idx.unsqueeze(1)
-        row_has_match = torch.any(matches_self_idx, dim=1)
-        # If every row has a match, then masking will work
-        if not torch.all(row_has_match):
-            # For rows that don't contain the self index
-            # Remove the Nth value by setting matches_self_idx[N] to True
-            matches_self_idx[~row_has_match, -1] = True
-        I = mask_reshape_knn_idx(I, matches_self_idx)
-        D = mask_reshape_knn_idx(D, matches_self_idx)
-    return D, I
-
-
-def get_topk(distances, indices, k, get_largest):
-    def fn(mat, s, e):
-        D, I = torch.topk(mat, k, largest=get_largest, dim=1)
-        distances[s:e] = D
-        indices[s:e] = I
-
-    return fn
-
-
-class CustomKNN:
-    def __init__(self, distance, batch_size=None):
-        if batch_size:
-            self.distance = BatchedDistance(distance, batch_size=batch_size)
-        else:
-            self.distance = distance
-
-    def __call__(self, query, k, reference, ref_includes_query=False):
-        if ref_includes_query:
-            k = k + 1
-        get_largest = self.distance.is_inverted
-        if isinstance(self.distance, BatchedDistance):
-            device = query.device
-            distances = torch.zeros(len(query), k, device=device)
-            indices = torch.zeros(len(query), k, device=device, dtype=torch.long)
-            self.distance.iter_fn = get_topk(distances, indices, k, get_largest)
-            self.distance(query, reference)
-        else:
-            mat = self.distance(query, reference)
-            distances, indices = torch.topk(mat, k, largest=get_largest, dim=1)
-        return return_results(distances, indices, ref_includes_query)
diff --git a/tests/utils/test_calculate_accuracies_large_k.py b/tests/utils/test_calculate_accuracies_large_k.py
index b90d24cd..826aaa45 100644
--- a/tests/utils/test_calculate_accuracies_large_k.py
+++ b/tests/utils/test_calculate_accuracies_large_k.py
@@ -83,14 +83,18 @@ def evaluate(self, encs, labels, max_k=None, ecfss=False):
         torch_encs = torch.from_numpy(encs)
         k = len(encs) - 1 if ecfss else len(encs)
         knn_func = inference.FaissKNN()
-        _, all_indices = knn_func(torch_encs, k, torch_encs, ecfss)
+        _, all_indices = knn_func(
+            torch_encs, torch_encs, k, ecfss
+        )  # argument order changed to match the faiss signature
         if max_k is None:
             max_k = k
             indices = all_indices
         else:
             if max_k == "max_bin_count":
                 max_k = int(max(np.bincount(labels))) - int(ecfss)
-            _, indices = knn_func(torch_encs, max_k, torch_encs, ecfss)
+            _, indices = knn_func(
+                torch_encs, torch_encs, max_k, ecfss
+            )  # argument order changed to match the faiss signature
         # let's use the most simple mAP implementation
         # of course this can be computed much faster using cumsum, etc.
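Not part of the patch: a sketch mirroring the updated test call sites, assuming faiss is installed. It shows the self-match stripping that `ref_includes_query` enables.

```python
import numpy as np
import torch
from pytorch_metric_learning.utils import inference

encs = np.random.randn(10, 8).astype(np.float32)
torch_encs = torch.from_numpy(encs)
knn_func = inference.FaissKNN()

# New ordering: (query, reference, k, ref_includes_query)
_, all_indices = knn_func(torch_encs, torch_encs, 5, True)

# ref_includes_query=True searches for k+1 neighbors internally and strips
# each row's self-match, so each row holds 5 true neighbors
assert all_indices.shape == (10, 5)
```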