diff --git a/tensorflow/core/framework/embedding/hbm_dram_storage.h b/tensorflow/core/framework/embedding/hbm_dram_storage.h index 15f6271fb4f..5106266e82c 100644 --- a/tensorflow/core/framework/embedding/hbm_dram_storage.h +++ b/tensorflow/core/framework/embedding/hbm_dram_storage.h @@ -49,6 +49,7 @@ class HbmDramStorage : public MultiTierStorage { ~HbmDramStorage() override { MultiTierStorage::DeleteFromEvictionManager(); + //delete restore_cache_; delete hbm_; delete dram_; delete dram_feat_desc_; @@ -227,7 +228,7 @@ class HbmDramStorage : public MultiTierStorage { } void BatchEviction() override { - constexpr int EvictionSize = 10000; + constexpr int EvictionSize = 5000; K evic_ids[EvictionSize]; if (!MultiTierStorage::ready_eviction_) { return; @@ -287,16 +288,18 @@ class HbmDramStorage : public MultiTierStorage { partition_id, partition_num, is_incr, reset_version, reader); + restore_cache_.reset(CacheFactory::Create(CacheStrategy::LFU, "ads")); restorer.RestoreCkpt(emb_config, device); int64 num_of_hbm_ids = std::min(MultiTierStorage::cache_capacity_, - (int64)MultiTierStorage::cache_->size()); + (int64)restore_cache_->size()); + if (num_of_hbm_ids > 0) { K* hbm_ids = new K[num_of_hbm_ids]; int64* hbm_freqs = new int64[num_of_hbm_ids]; int64* hbm_versions = nullptr; - MultiTierStorage::cache_->get_cached_ids(hbm_ids, num_of_hbm_ids, + restore_cache_->get_cached_ids(hbm_ids, num_of_hbm_ids, hbm_versions, hbm_freqs); ImportToHbm(hbm_ids, num_of_hbm_ids, value_len, emb_config.emb_index); MultiTierStorage::cache_thread_pool_->Schedule( @@ -329,10 +332,10 @@ class HbmDramStorage : public MultiTierStorage { Status s = filter->Restore(key_num, bucket_num, partition_id, partition_num, value_len, is_filter, true/*to_dram*/, is_incr, restore_buff); - - MultiTierStorage::cache_->update((K*)restore_buff.key_buffer, key_num, - (int64*)restore_buff.version_buffer, - (int64*)restore_buff.freq_buffer); + + restore_cache_->update((K*)restore_buff.key_buffer, key_num, + (int64*)restore_buff.version_buffer, + (int64*)restore_buff.freq_buffer); return s; } @@ -574,6 +577,7 @@ class HbmDramStorage : public MultiTierStorage { DramStorage* dram_ = nullptr; FeatureDescriptor* hbm_feat_desc_ = nullptr; FeatureDescriptor* dram_feat_desc_ = nullptr; + std::unique_ptr> restore_cache_ = nullptr; Allocator* gpu_alloc_; const int copyback_flag_offset_bits_ = 60; };