Skip to content

Commit 6dae552

Browse files
authored
[Embedding] Refine header file of embedding variable. (#978)
Signed-off-by: chenbangduo.cbd <[email protected]>
1 parent 186afd0 commit 6dae552

File tree

4 files changed

+6
-5
lines changed

4 files changed

+6
-5
lines changed

tensorflow/core/framework/embedding/embedding_var.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ limitations under the License.
3434
#include "tensorflow/core/framework/embedding/gpu_hash_map_kv.h"
3535
#include "tensorflow/core/framework/embedding/embedding_config.h"
3636
#include "tensorflow/core/framework/embedding/storage.h"
37-
#include "tensorflow/core/framework/embedding/storage_factory.h"
3837
#include "tensorflow/core/framework/typed_allocator.h"
3938

4039
namespace tensorflow {

tensorflow/core/kernels/kv_variable_ops.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ limitations under the License.
2323
#include "tensorflow/core/framework/embedding/cache.h"
2424
#include "tensorflow/core/framework/embedding/config.pb.h"
2525
#include "tensorflow/core/framework/embedding/embedding_var.h"
26+
#include "tensorflow/core/framework/embedding/storage_factory.h"
2627
#include "tensorflow/core/framework/op_kernel.h"
2728
#include "tensorflow/core/framework/register_types.h"
2829
#include "tensorflow/core/framework/resource_mgr.h"

tensorflow/core/kernels/kv_variable_restore_ops.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ limitations under the License.
2525
#include "tensorflow/core/framework/embedding/cache.h"
2626
#include "tensorflow/core/framework/embedding/config.pb.h"
2727
#include "tensorflow/core/framework/embedding/embedding_var.h"
28+
#include "tensorflow/core/framework/embedding/storage_factory.h"
2829
#include "tensorflow/core/framework/op_kernel.h"
2930
#include "tensorflow/core/framework/register_types.h"
3031
#include "tensorflow/core/framework/resource_mgr.h"

tensorflow/core/kernels/training_ali_ops.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ class KvSparseApplyAdagradGPUOp : public OpKernel {
236236
T** dev_a = dev_v + task_size;
237237
CHECK(dev_a);
238238
CHECK(dev_v);
239-
DeviceMemoryBase dev_v_ptr(dev_v, sizeof(T*) * task_size * 2);
239+
se::DeviceMemoryBase dev_v_ptr(dev_v, sizeof(T*) * task_size * 2);
240240
stream->ThenMemcpy(&dev_v_ptr, v, sizeof(T*) * task_size * 2);
241241

242242
int block_size = 128;
@@ -1606,7 +1606,7 @@ class KvSparseApplyAdamGPUOp : public OpKernel {
16061606
CHECK(dev_m_ptr);
16071607
CHECK(dev_v_ptr);
16081608

1609-
DeviceMemoryBase dst_ptr(dev_var_ptr, sizeof(T*) * task_size * 3);
1609+
se::DeviceMemoryBase dst_ptr(dev_var_ptr, sizeof(T*) * task_size * 3);
16101610
stream->ThenMemcpy(&dst_ptr, var_ptr, sizeof(T*) * task_size * 3);
16111611

16121612
int block_size = 128;
@@ -2579,7 +2579,7 @@ class KvSparseApplyAdamAsyncGPUOp : public OpKernel {
25792579
CHECK(dev_m_ptr);
25802580
CHECK(dev_v_ptr);
25812581

2582-
DeviceMemoryBase dst_ptr(dev_var_ptr, sizeof(T*) * task_size * 3);
2582+
se::DeviceMemoryBase dst_ptr(dev_var_ptr, sizeof(T*) * task_size * 3);
25832583
stream->ThenMemcpy(&dst_ptr, var_ptr, sizeof(T*) * task_size * 3);
25842584

25852585
int block_size = 128;
@@ -3236,7 +3236,7 @@ class KvSparseApplyAdamWGPUOp : public OpKernel {
32363236
CHECK(dev_m_ptr);
32373237
CHECK(dev_v_ptr);
32383238

3239-
DeviceMemoryBase dst_ptr(dev_var_ptr, sizeof(T*) * task_size * 3);
3239+
se::DeviceMemoryBase dst_ptr(dev_var_ptr, sizeof(T*) * task_size * 3);
32403240
stream->ThenMemcpy(&dst_ptr, var_ptr, sizeof(T*) * task_size * 3);
32413241

32423242
int block_size = 128;

0 commit comments

Comments
 (0)