diff --git a/khmer/_cpy_smallcountgraph.hh b/khmer/_cpy_smallcountgraph.hh index 40c18f339c..acfa2f5551 100644 --- a/khmer/_cpy_smallcountgraph.hh +++ b/khmer/_cpy_smallcountgraph.hh @@ -46,40 +46,7 @@ typedef struct { static void khmer_smallcountgraph_dealloc(khmer_KSmallCountgraph_Object * obj); -static -PyObject * -smallcount_get_raw_tables(khmer_KSmallCountgraph_Object * self, PyObject * args) -{ - SmallCountgraph * countgraph = self->countgraph; - - khmer::Byte ** table_ptrs = countgraph->get_raw_tables(); - std::vector sizes = countgraph->get_tablesizes(); - - PyObject * raw_tables = PyList_New(sizes.size()); - for (unsigned int i=0; inodegraph; - - khmer::Byte ** table_ptrs = countgraph->get_raw_tables(); - std::vector sizes = countgraph->get_tablesizes(); - - PyObject * raw_tables = PyList_New(sizes.size()); - for (unsigned int i=0; i sizes) : Hashgraph(ksize, new ByteStorage(sizes)) { } ; + + // get access to raw tables. + Byte ** get_raw_tables() + { + return ((ByteStorage*)store)->get_raw_tables(); + } }; // Hashgraph-derived class with NibbleStorage. diff --git a/lib/hashtable.hh b/lib/hashtable.hh index e5cc4a1d5e..e35b6f6df6 100644 --- a/lib/hashtable.hh +++ b/lib/hashtable.hh @@ -277,12 +277,6 @@ public: void get_kmer_counts(const std::string &s, std::vector &counts) const; - // get access to raw tables. - Byte ** get_raw_tables() - { - return store->get_raw_tables(); - } - // find the minimum k-mer count in the given sequence BoundedCounterType get_min_count(const std::string &s); diff --git a/lib/khmer.hh b/lib/khmer.hh index 0dcefe3bd3..943fb7eb3a 100644 --- a/lib/khmer.hh +++ b/lib/khmer.hh @@ -69,6 +69,7 @@ private:\ className(const className&);\ const className& operator=(const className&) +#include #include #include #include @@ -121,6 +122,8 @@ typedef unsigned short int BoundedCounterType; // A single-byte type. typedef unsigned char Byte; +using AtomicByte = std::atomic; + typedef void (*CallbackFn)(const char * info, void * callback_data, unsigned long long n_reads, diff --git a/lib/storage.cc b/lib/storage.cc index 0523796980..5fdc9cbe6a 100644 --- a/lib/storage.cc +++ b/lib/storage.cc @@ -872,7 +872,7 @@ void NibbleStorage::load(std::string infilename, WordLength& ksize) _n_tables = (unsigned int) save_n_tables; _occupied_bins = save_occupied_bins; - _counts = new Byte*[_n_tables]; + _counts = new AtomicByte*[_n_tables]; for (unsigned int i = 0; i < _n_tables; i++) { _counts[i] = NULL; } @@ -887,7 +887,7 @@ void NibbleStorage::load(std::string infilename, WordLength& ksize) tablesize = save_tablesize; _tablesizes.push_back(tablesize); - _counts[i] = new Byte[tablebytes]; + _counts[i] = new AtomicByte[tablebytes]; unsigned long long loaded = 0; while (loaded != tablebytes) { diff --git a/lib/storage.hh b/lib/storage.hh index 1a1bc4787b..6a67f96a38 100644 --- a/lib/storage.hh +++ b/lib/storage.hh @@ -38,10 +38,6 @@ Contact: khmer-project@idyll.org #ifndef STORAGE_HH #define STORAGE_HH -#include -#include -#include -using MuxGuard = std::lock_guard; namespace khmer { @@ -67,7 +63,6 @@ public: virtual BoundedCounterType test_and_set_bits( HashIntoType khash ) = 0; virtual void add(HashIntoType khash) = 0; virtual const BoundedCounterType get_count(HashIntoType khash) const = 0; - virtual Byte ** get_raw_tables() = 0; void set_use_bigcount(bool b); bool get_use_bigcount(); @@ -214,13 +209,6 @@ public: return 1; } - // Writing to the tables outside of defined methods has undefined behavior! - // As such, this should only be used to return read-only interfaces - Byte ** get_raw_tables() - { - return _counts; - } - void update_from(const BitStorage&); }; @@ -246,9 +234,8 @@ protected: size_t _n_tables; uint64_t _occupied_bins; uint64_t _n_unique_kmers; - std::array mutexes; static constexpr uint8_t _max_count{15}; - Byte ** _counts; + AtomicByte ** _counts; // Compute index into the table, this retrieves the correct byte // which you then need to select the correct nibble from @@ -271,8 +258,6 @@ public: NibbleStorage(std::vector& tablesizes) : _tablesizes{tablesizes}, _occupied_bins{0}, _n_unique_kmers{0} { - // to allow more than 32 tables increase the size of mutex pool - assert(_n_tables <= 32); _allocate_counters(); } @@ -293,13 +278,13 @@ public: { _n_tables = _tablesizes.size(); - _counts = new Byte*[_n_tables]; + _counts = new AtomicByte*[_n_tables]; for (size_t i = 0; i < _n_tables; i++) { const uint64_t tablesize = _tablesizes[i]; const uint64_t tablebytes = tablesize / 2 + 1; - _counts[i] = new Byte[tablebytes]; + _counts[i] = new AtomicByte[tablebytes]; memset(_counts[i], 0, tablebytes); } } @@ -317,12 +302,12 @@ public: bool is_new_kmer = false; for (unsigned int i = 0; i < _n_tables; i++) { - MuxGuard g(mutexes[i]); - Byte* const table(_counts[i]); + AtomicByte* const table(_counts[i]); const uint64_t idx = _table_index(khash, _tablesizes[i]); const uint8_t mask = _mask(khash, _tablesizes[i]); const uint8_t shift = _shift(khash, _tablesizes[i]); - const uint8_t current_count = (table[idx] & mask) >> shift; + uint8_t current_tbl = table[idx]; + uint8_t current_count = (current_tbl & mask) >> shift; if (!is_new_kmer) { if (current_count == 0) { @@ -342,8 +327,25 @@ public: } // increase count, no checking for overflow - const uint8_t new_count = (current_count + 1) << shift; - table[idx] = (table[idx] & ~mask) | (new_count & mask); + // current_tbl and new_tbl are the current and new bit packed values + // for the idx'th byte of the table. + // compare_exchange_weak will update the value of table[idx] if + // current_tbl is the current value (hasn't been changed by a + // different thread) if they differ the value actually stored + // in table[idx] is written to current_tbl so this is a + // compare-and-swap loop + uint8_t new_count = (current_count + 1) << shift; + uint8_t new_tbl = (current_tbl & ~mask) | (new_count & mask); + + while(!table[idx].compare_exchange_weak(current_tbl, new_tbl)) { + current_count = (current_tbl & mask) >> shift; + new_count = (current_count + 1); + if (new_count > _max_count) { + break; + } + new_count <<= shift; + new_tbl = (current_tbl & ~mask) | (new_count & mask); + } } if (is_new_kmer) { @@ -358,7 +360,7 @@ public: // get the minimum count across all tables for (unsigned int i = 0; i < _n_tables; i++) { - const Byte* table(_counts[i]); + const AtomicByte* table(_counts[i]); const uint64_t idx = _table_index(khash, _tablesizes[i]); const uint8_t mask = _mask(khash, _tablesizes[i]); const uint8_t shift = _shift(khash, _tablesizes[i]); @@ -391,10 +393,6 @@ public: void save(std::string outfilename, WordLength ksize); void load(std::string infilename, WordLength& ksize); - Byte ** get_raw_tables() - { - return _counts; - } }; diff --git a/tests/test_countgraph.py b/tests/test_countgraph.py index 87a0dee1db..a73f306b04 100644 --- a/tests/test_countgraph.py +++ b/tests/test_countgraph.py @@ -209,14 +209,11 @@ def test_get_raw_tables(): def test_get_raw_tables_smallcountgraph(): - # for the same number of entries a SmallCountgraph uses ~half the memory - # of a normal Countgraph + # smallcountgraphs store individual counts packed into a byte, the raw + # tables probably do not give users what they expect (something that can be + # given to numpy.frombuffer) ht = khmer.SmallCountgraph(20, 1e5, 4) - tables = ht.get_raw_tables() - - for size, table in zip(ht.hashsizes(), tables): - assert isinstance(table, memoryview) - assert size // 2 + 1 == len(table) + assert not hasattr(ht, 'get_raw_tables') def test_get_raw_tables_view(): @@ -229,18 +226,6 @@ def test_get_raw_tables_view(): assert sum(tab.tolist()) == 1 -def test_get_raw_tables_view_smallcountgraph(): - ht = khmer.SmallCountgraph(4, 1e5, 4) - tables = ht.get_raw_tables() - for tab in tables: - assert sum(tab.tolist()) == 0 - ht.consume('AAAA') - # the actual count is 1 but stored in the first 4bits of a Byte - # and so becomes 16 - for tab in tables: - assert sum(tab.tolist()) == int('00010000', 2) - - @pytest.mark.huge def test_toobig(): try: diff --git a/tests/test_nodegraph.py b/tests/test_nodegraph.py index 78a3fef410..b27acf8f5f 100644 --- a/tests/test_nodegraph.py +++ b/tests/test_nodegraph.py @@ -553,15 +553,11 @@ def test_extract_unique_paths_4(): def test_get_raw_tables(): + # nodegraphs store individual bits packed into a byte, the raw tables + # probably do not give users what they expect (something that can be + # given to numpy.frombuffer) kh = khmer.Nodegraph(10, 1e6, 4) - kh.consume('ATGGAGAGAC') - kh.consume('AGTGGCGATG') - kh.consume('ATAGACAGGA') - tables = kh.get_raw_tables() - - for size, table in zip(kh.hashsizes(), tables): - assert isinstance(table, memoryview) - assert size == len(table) + assert not hasattr(kh, 'get_raw_tables') def test_simple_median():