Skip to content

Commit

Permalink
use a class for sketch index
Browse files Browse the repository at this point in the history
  • Loading branch information
mahmudhera committed Nov 16, 2024
1 parent 36a255a commit b8239a1
Show file tree
Hide file tree
Showing 8 changed files with 157 additions and 21 deletions.
6 changes: 3 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ SRC_DIR = src/cpp
BIN_DIR = src/yacht

# Source files
SRC_FILES = $(SRC_DIR)/yacht_train_core.cpp $(SRC_DIR)/utils.cpp $(SRC_DIR)/compute_similarity.cpp
SRC_FILES = $(SRC_DIR)/yacht_train_core.cpp $(SRC_DIR)/utils.cpp $(SRC_DIR)/compute_similarity.cpp $(SRC_DIR)/MultiSketchIndex.cpp

# Object files
OBJ_FILES = $(SRC_FILES:.cpp=.o)
Expand All @@ -20,10 +20,10 @@ TARGET2 = $(BIN_DIR)/run_compute_similarity
all: $(TARGET1) $(TARGET2)

$(TARGET1): $(OBJ_FILES)
$(CXX) $(CXXFLAGS) $(SRC_DIR)/yacht_train_core.cpp $(SRC_DIR)/utils.cpp -o $(TARGET1)
$(CXX) $(CXXFLAGS) $(SRC_DIR)/yacht_train_core.cpp $(SRC_DIR)/utils.cpp $(SRC_DIR)/MultiSketchIndex.cpp -o $(TARGET1)

$(TARGET2): $(OBJ_FILES)
$(CXX) $(CXXFLAGS) $(SRC_DIR)/compute_similarity.cpp $(SRC_DIR)/utils.cpp -o $(TARGET2)
$(CXX) $(CXXFLAGS) $(SRC_DIR)/compute_similarity.cpp $(SRC_DIR)/utils.cpp $(SRC_DIR)/MultiSketchIndex.cpp -o $(TARGET2)

%.o: %.cpp
$(CXX) $(CXXFLAGS) -c $< -o $@
Expand Down
47 changes: 47 additions & 0 deletions src/cpp/MultiSketchIndex.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#include "MultiSketchIndex.h"

// Default constructor: the index starts out empty; the backing map
// default-initializes itself, so there is nothing to do here.
MultiSketchIndex::MultiSketchIndex() = default;

// Destructor: no manual cleanup is needed; the backing map releases its own
// storage when the object is destroyed.
MultiSketchIndex::~MultiSketchIndex() = default;


/**
 * @brief Record that a hash value appears in every sketch of a list.
 *
 * If the hash value has no entry yet, the entry becomes exactly
 * sketch_indices. Otherwise sketch_indices is appended, in order, to the
 * list already stored. Duplicate indices are not removed.
 *
 * @param hash_value The hash value to add.
 * @param sketch_indices Indices of the sketches in which this hash value appears.
 */
void MultiSketchIndex::add_hash(hash_t hash_value, std::vector<int> sketch_indices) {
    // operator[] creates an empty list when the hash is new, so a single
    // lookup serves both the "new hash" and the "existing hash" path.
    std::vector<int>& stored = multi_sketch_index[hash_value];

    if (stored.empty()) {
        // Nothing to preserve: take over the by-value argument's buffer
        // instead of copying it element by element.
        stored.swap(sketch_indices);
        return;
    }

    stored.insert(stored.end(), sketch_indices.begin(), sketch_indices.end());
}


/**
 * @brief Record that a hash value appears in one sketch.
 *
 * Appends sketch_index to the list stored for hash_value, creating the list
 * first if the hash value has not been seen before. Duplicate indices are
 * not removed.
 *
 * @param hash_value The hash value to add.
 * @param sketch_index The index of the sketch in which this hash value appears.
 */
void MultiSketchIndex::add_hash(hash_t hash_value, int sketch_index) {
    // operator[] default-constructs an empty vector for a missing key, so no
    // separate existence check is needed.
    multi_sketch_index[hash_value].push_back(sketch_index);
}






/**
 * @brief Get the sketch indices recorded for a hash value.
 *
 * @param hash_value The hash value to look up.
 * @return const std::vector<int>& The list stored in the index for
 *         hash_value, or a reference to a shared, always-empty vector when
 *         the hash value is not in the index.
 */
const std::vector<int>& MultiSketchIndex::get_sketch_indices(hash_t hash_value) {
    // The function returns a reference, so the "not found" result must
    // outlive the call; a temporary would be destroyed at the return
    // statement and leave the caller with a dangling reference.
    static const std::vector<int> no_sketches;

    const auto entry = multi_sketch_index.find(hash_value);
    if (entry == multi_sketch_index.end()) {
        return no_sketches;
    }
    return entry->second;
}



69 changes: 69 additions & 0 deletions src/cpp/MultiSketchIndex.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
#ifndef SKETCH_H
#define SKETCH_H

#include <iostream>
#include <unordered_map>
#include <vector>


#ifndef HASH_T
#define HASH_T
typedef unsigned long long int hash_t;
#endif


/**
* @brief MultiSketchIndex class, which is used to store the index of many sketches.
*
*/
class MultiSketchIndex {
public:
MultiSketchIndex();
~MultiSketchIndex();

/**
* @brief Add a hash value to the index.
*
* @param hash_value The hash value to add.
* @param sketch_index The index of the sketch in which this hash value appears.
*/
void add_hash(hash_t hash_value, int sketch_index);



/**
* @brief Add a hash value to the index.
*
* @param hash_value The hash value to add.
* @param sketch_indices Indices of the sketches in which this hash value appears.
*/
void add_hash(hash_t hash_value, std::vector<int> sketch_indices);



/**
* @brief Get the sketch indices for a hash value.
*
* @param hash_value The hash value to get the sketch indices for.
* @return const std::vector<int>& The sketch indices in which the hash value appears.
*/
const std::vector<int>& get_sketch_indices(hash_t hash_value);


/**
* @brief Check if a hash value exists in the index.
*
* @param hash_value The hash value to check.
* @return true If the hash value exists in the index.
* @return false If the hash value does not exist in the index.
*/
bool hash_exists(hash_t hash_value) {
return multi_sketch_index.find(hash_value) != multi_sketch_index.end();
}


private:
std::unordered_map<hash_t, std::vector<int>> multi_sketch_index;

};

#endif
Binary file added src/cpp/MultiSketchIndex.o
Binary file not shown.
10 changes: 7 additions & 3 deletions src/cpp/compute_similarity.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,11 @@ struct Arguments {


typedef Arguments Arguments;

#ifndef HASH_T
#define HASH_T
typedef unsigned long long int hash_t;
#endif


void parse_arguments(int argc, char *argv[], Arguments &arguments) {
Expand Down Expand Up @@ -186,14 +190,14 @@ int main(int argc, char** argv) {

// compute the index from the target sketches
cout << "Building index from target sketches..." << endl;
unordered_map<hash_t, vector<int>> hash_index_target;
compute_index_from_sketches(sketches_target, hash_index_target, args.number_of_threads);
MultiSketchIndex target_sketches_index;
compute_index_from_sketches(sketches_target, target_sketches_index, args.number_of_threads);

// compute the similarity matrix
cout << "Computing similarity matrix..." << endl;
vector<vector<int>> similars;
compute_intersection_matrix(sketches_query, sketches_target,
hash_index_target,
target_sketches_index,
args.output_directory, similars,
args.containment_threshold,
args.num_of_passes,
Expand Down
24 changes: 14 additions & 10 deletions src/cpp/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@ void compute_index_from_sketches_one_chunk( int sketch_index_start, int sketch_i



void compute_index_from_sketches(std::vector<std::vector<hash_t>>& sketches, std::unordered_map<hash_t, std::vector<int>>& hash_index, const int num_threads) {
void compute_index_from_sketches(std::vector<std::vector<hash_t>>& sketches,
MultiSketchIndex& multi_sketch_index,
const int num_threads) {

// create mutexes
int num_unordered_maps = 100000;
Expand Down Expand Up @@ -97,7 +99,8 @@ void compute_index_from_sketches(std::vector<std::vector<hash_t>>& sketches, std
for (int i = 0; i < num_unordered_maps; i++) {
for (auto it = hash_index_chunks[i].begin(); it != hash_index_chunks[i].end(); it++) {
hash_t hash_value = it->first;
hash_index[hash_value] = it->second;
std::vector<int> sketch_indices = it->second;
multi_sketch_index.add_hash(hash_value, sketch_indices);
}
}

Expand Down Expand Up @@ -195,7 +198,7 @@ void compute_intersection_matrix_by_sketches(int query_sketch_start_index, int q
int pass_id, int negative_offset,
const std::vector<std::vector<hash_t>>& sketches_query,
const std::vector<std::vector<hash_t>>& sketches_ref,
const std::unordered_map<hash_t, std::vector<int>>& hash_index_ref,
MultiSketchIndex& multi_sketch_index_ref,
int** intersectionMatrix,
double containment_threshold,
std::vector<std::vector<int>>& similars) {
Expand All @@ -207,11 +210,12 @@ void compute_intersection_matrix_by_sketches(int query_sketch_start_index, int q
for (uint i = query_sketch_start_index; i < query_sketch_end_index; i++) {
for (int j = 0; j < sketches_query[i].size(); j++) {
hash_t hash = sketches_query[i][j];
if (hash_index_ref.find(hash) != hash_index_ref.end()) {
std::vector<int> ref_sketch_indices = hash_index_ref.at(hash);
for (uint k = 0; k < ref_sketch_indices.size(); k++) {
intersectionMatrix[i-negative_offset][ref_sketch_indices[k]]++;
}
if (!multi_sketch_index_ref.hash_exists(hash)) {
continue;
}
std::vector<int> ref_sketch_indices = multi_sketch_index_ref.get_sketch_indices(hash);
for (uint k = 0; k < ref_sketch_indices.size(); k++) {
intersectionMatrix[i-negative_offset][ref_sketch_indices[k]]++;
}
}
}
Expand Down Expand Up @@ -265,7 +269,7 @@ void compute_intersection_matrix_by_sketches(int query_sketch_start_index, int q

void compute_intersection_matrix(const std::vector<std::vector<hash_t>>& sketches_query,
const std::vector<std::vector<hash_t>>& sketches_ref,
const std::unordered_map<hash_t, std::vector<int>>& hash_index_ref,
MultiSketchIndex& multi_sketch_index_ref,
const std::string& out_dir,
std::vector<std::vector<int>>& similars,
double containment_threshold,
Expand Down Expand Up @@ -310,7 +314,7 @@ void compute_intersection_matrix(const std::vector<std::vector<hash_t>>& sketche
start_query_index_this_thread, end_query_index_this_thread,
i, out_dir, pass_id, negative_offset,
std::ref(sketches_query), std::ref(sketches_ref),
std::ref(hash_index_ref), intersectionMatrix,
std::ref(multi_sketch_index_ref), intersectionMatrix,
containment_threshold,
std::ref(similars)));
}
Expand Down
15 changes: 13 additions & 2 deletions src/cpp/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,18 @@
#include <random>

#include "json.hpp"
#include "MultiSketchIndex.h"

using json = nlohmann::json;

#ifndef HASH_T
#define HASH_T
typedef unsigned long long int hash_t;
#endif





/**
* @brief Read the min-hashes from a FMH sketch file
Expand All @@ -49,7 +57,7 @@ std::vector<hash_t> read_min_hashes(const std::string& sketch_path);
* @param num_threads The number of threads to use
*/
void compute_index_from_sketches(std::vector<std::vector<hash_t>>& sketches,
std::unordered_map<hash_t, std::vector<int>>& hash_index,
MultiSketchIndex& multi_sketch_index,
int num_threads);


Expand Down Expand Up @@ -111,10 +119,13 @@ void show_empty_sketches(const std::vector<int>&);
*/
void compute_intersection_matrix(const std::vector<std::vector<hash_t>>& sketches_query,
const std::vector<std::vector<hash_t>>& sketches_ref,
const std::unordered_map<hash_t, std::vector<int>>& hash_index_ref,
MultiSketchIndex& multi_sketch_index_ref,
const std::string& out_dir,
std::vector<std::vector<int>>& similars,
double containment_threshold,
const int num_passes, const int num_threads);




#endif
7 changes: 4 additions & 3 deletions src/cpp/yacht_train_core.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "argparse.hpp"
#include "json.hpp"
#include "utils.h"
#include "MultiSketchIndex.h"

#include <iostream>
#include <vector>
Expand Down Expand Up @@ -183,7 +184,7 @@ int main(int argc, char *argv[]) {

std::vector<std::string> sketch_paths;
vector<vector<hash_t>> sketches;
unordered_map<hash_t, vector<int>> hash_index;
MultiSketchIndex ref_sketches_index;
mutex mutex_count_empty_sketch;
vector<int> empty_sketch_ids;
int ** intersectionMatrix;
Expand Down Expand Up @@ -231,7 +232,7 @@ int main(int argc, char *argv[]) {
// ****************************************************************
auto index_build_start = chrono::high_resolution_clock::now();
cout << "Building index from sketches..." << endl;
compute_index_from_sketches(sketches, hash_index, arguments.number_of_threads);
compute_index_from_sketches(sketches, ref_sketches_index, arguments.number_of_threads);
auto index_build_end = chrono::high_resolution_clock::now();
auto index_build_duration = chrono::duration_cast<chrono::milliseconds>(index_build_end - index_build_start);
cout << "Time taken to build index: " << index_build_duration.count() << " milliseconds" << endl;
Expand All @@ -243,7 +244,7 @@ int main(int argc, char *argv[]) {
// **********************************************************************
auto mat_computation_start = chrono::high_resolution_clock::now();
cout << "Computing intersection matrix..." << endl;
compute_intersection_matrix(sketches, sketches, hash_index,
compute_intersection_matrix(sketches, sketches, ref_sketches_index,
arguments.working_directory, similars,
arguments.containment_threshold, arguments.num_of_passes,
arguments.number_of_threads);
Expand Down

0 comments on commit b8239a1

Please sign in to comment.