
Commit ecd6f7a

Luke Iwanski authored and benoitsteiner committed
OpenCL improvements
- Bumps the Eigen version
- Refactors op registration
- Introduces a workaround for the Const op, needed because CUDA uses pointers while OpenCL uses buffers/accessors
- Extends memory types to cover DEVICE_SYCL as well
- Introduces a GetSYCLDevice() method that picks from Eigen's list of supported devices, giving GPU devices the highest priority (blacklisted devices are excluded)
- Renames ::internal::Transpose to tensorflow::internal::Transpose to avoid a reported compilation error
- Re-introduces the fix for a buggy string replacement (-c -> --include) that caused a lot of compilation warnings
- Adds sycl_runtime to bazel's ARRAY_DEPS
- Replicates TF_CALL_GPU_PROXY_TYPES for SYCL
1 parent 184dfd9 commit ecd6f7a

File tree: 92 files changed (808 additions, 837 deletions)

Note: large commits have some content hidden by default; only a subset of the 92 changed files is shown below.
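Most of the visible changes follow one pattern: SYCL-specific code is guarded with TENSORFLOW_USE_SYCL, and kernels are registered for DEVICE_SYCL next to the existing DEVICE_CPU/DEVICE_GPU registrations. A minimal sketch of that pattern ("SomeOp"/SomeOpKernel are placeholder names, not from this commit):

    #ifdef TENSORFLOW_USE_SYCL
    // Hypothetical registration: makes SomeOpKernel available on the SYCL
    // device, mirroring the CPU/GPU registrations elsewhere in the file.
    REGISTER_KERNEL_BUILDER(Name("SomeOp").Device(DEVICE_SYCL), SomeOpKernel);
    #endif  // TENSORFLOW_USE_SYCL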

tensorflow/core/common_runtime/direct_session_test.cc

Lines changed: 1 addition & 2 deletions
@@ -871,8 +871,6 @@ class BlockingOp : public OpKernel {
 REGISTER_KERNEL_BUILDER(Name("BlockingOp").Device(DEVICE_CPU), BlockingOp);
 REGISTER_OP("BlockingOp").Input("x: float").Output("y: float").Doc("");
 
-REGISTER_KERNEL_BUILDER(Name("BlockingOp").Device(DEVICE_SYCL), BlockingOp);
-
 static void TestSessionInterOpThreadsImpl(bool use_function_lib) {
   FunctionDefLibrary library_graph_def;
   if (use_function_lib) {
@@ -910,6 +908,7 @@ static void TestSessionInterOpThreadsImpl(bool use_function_lib) {
       ->set_opt_level(OptimizerOptions_Level_L0);
   (*options.config.mutable_device_count())["CPU"] = 2;
   (*options.config.mutable_device_count())["GPU"] = 0;
+  (*options.config.mutable_device_count())["SYCL"] = 0;
 
   options.config.add_session_inter_op_thread_pool();
   auto* p = options.config.add_session_inter_op_thread_pool();

tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc

Lines changed: 11 additions & 1 deletion
@@ -138,7 +138,8 @@ TEST(DirectSessionWithTrackingAllocTest, CostModelWarmup) {
   DirectSession* ds = static_cast<DirectSession*>(session.get());
   CostModelManager::CostModelMap cost_models;
   ds->ExportCostModels(&cost_models);
-  CHECK_EQ(cost_models.size(), 1);
+  ASSERT_GE(2, cost_models.size());
+  ASSERT_LE(1, cost_models.size());
   const CostModel* cm = (*cost_models.begin()).second;
   EXPECT_EQ(measure_steps, cm->GetUpdateTimes());
 }
@@ -155,10 +156,16 @@ static void TestHWAccelerator(bool enableHWTrace) {
   test::FillValues<float>(&x_tensor, {1, 1});
   Node* x = test::graph::Constant(&graph, x_tensor);
   x->set_assigned_device_name("/job:localhost/replica:0/task:0/gpu:0");
+#ifdef TENSORFLOW_USE_SYCL
+  x->set_assigned_device_name("/job:localhost/replica:0/task:0/device:SYCL:0");
+#endif // TENSORFLOW_USE_SYCL
 
   // y = A * x
   Node* y = test::graph::Matmul(&graph, a, x, false, false);
   y->set_assigned_device_name("/job:localhost/replica:0/task:0/gpu:0");
+#ifdef TENSORFLOW_USE_SYCL
+  y->set_assigned_device_name("/job:localhost/replica:0/task:0/device:SYCL:0");
+#endif // TENSORFLOW_USE_SYCL
 
   Node* y_neg = test::graph::Unary(&graph, "Neg", y);
   y_neg->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:0");
@@ -169,6 +176,9 @@ static void TestHWAccelerator(bool enableHWTrace) {
   SessionOptions options;
   (*options.config.mutable_device_count())["CPU"] = 1;
   (*options.config.mutable_device_count())["GPU"] = 1;
+#ifdef TENSORFLOW_USE_SYCL
+  (*options.config.mutable_device_count())["SYCL"] = 1;
+#endif // TENSORFLOW_USE_SYCL
   options.config.set_allow_soft_placement(true);
   options.config.mutable_graph_options()->set_build_cost_model(1);
   std::unique_ptr<Session> session(NewSession(options));

tensorflow/core/common_runtime/memory_types.cc

Lines changed: 3 additions & 3 deletions
@@ -45,12 +45,12 @@ struct EndpointEq {
 static Status ProcessMemoryTypes(
     DeviceType device_type, const Graph* g,
     std::function<Status(const Edge*, MemoryType, MemoryType)> fn) {
-  if (device_type != DEVICE_GPU) {
-    // On non-GPU devices, HOST_MEMORY and DEVICE_MEMORY are always
+  if (device_type != DEVICE_GPU && device_type != DEVICE_SYCL) {
+    // On non-GPU and non-SYCL devices, HOST_MEMORY and DEVICE_MEMORY are always
     // compatible.
     return Status::OK();
   }
-  // For GPU device, HOST_MEMORY and DEVICE_MEMORY is not
+  // For GPU and SYCL device, HOST_MEMORY and DEVICE_MEMORY is not
   // compatible. I.e., a conversion/transfer must be done.
   //
   // {node id, slot id} -> memory type.
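To make the invariant concrete: for DEVICE_GPU and DEVICE_SYCL, every edge whose endpoints disagree on memory type needs a host/device transfer. A hypothetical caller of ProcessMemoryTypes could count such edges like this (illustrative only; the function is internal to this file):

    // Sketch: count edges that would require a host<->device transfer.
    int mismatches = 0;
    Status s = ProcessMemoryTypes(
        DeviceType(DEVICE_SYCL), graph,
        [&mismatches](const Edge* e, MemoryType src, MemoryType dst) {
          if (src != dst) ++mismatches;  // conversion/transfer needed here
          return Status::OK();
        });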

tensorflow/core/common_runtime/memory_types_test.cc

Lines changed: 18 additions & 0 deletions
@@ -34,6 +34,9 @@ TEST(MemoryTypeChecker, Int32OK) {
   // There is a kernel for adding two int32s on host memory.
   TF_EXPECT_OK(ValidateMemoryTypes(DEVICE_GPU, g));
 #endif // GOOGLE_CUDA
+#ifdef TENSORFLOW_USE_SYCL
+  TF_EXPECT_OK(ValidateMemoryTypes(DEVICE_SYCL, g));
+#endif // TENSORFLOW_USE_SYCL
   delete g;
 }
 
@@ -53,6 +56,15 @@ TEST(MemoryTypeChecker, Int32NotOk) {
   TF_EXPECT_OK(EnsureMemoryTypes(DEVICE_GPU, "/gpu:0", g));
   TF_EXPECT_OK(ValidateMemoryTypes(DEVICE_GPU, g));
 #endif // GOOGLE_CUDA
+#ifdef TENSORFLOW_USE_SYCL
+  // There is no kernel for casting int32/host memory to float/device
+  // memory.
+  EXPECT_TRUE(errors::IsInternal(ValidateMemoryTypes(DEVICE_SYCL, g)));
+
+  // But we can insert _HostSend/_HostRecv to ensure the invariant.
+  TF_EXPECT_OK(EnsureMemoryTypes(DEVICE_SYCL, "/device:SYCL:0", g));
+  TF_EXPECT_OK(ValidateMemoryTypes(DEVICE_SYCL, g));
+#endif // TENSORFLOW_USE_SYCL
   delete g;
 }
 
@@ -74,6 +86,12 @@ TEST(MemoryTypeChecker, MemoryTypeForOutput) {
   // int Switch's output on GPU has HOST_MEMORY constraint.
   EXPECT_EQ(memory_type, HOST_MEMORY);
 #endif // GOOGLE_CUDA
+#ifdef TENSORFLOW_USE_SYCL
+  auto si = test::graph::Switch(g, test::graph::Constant(g, vi), pred);
+  TF_EXPECT_OK(MemoryTypeForOutput(DEVICE_SYCL, g, si, 0, &memory_type));
+  // int Switch's output on GPU has HOST_MEMORY constraint.
+  EXPECT_EQ(memory_type, HOST_MEMORY);
+#endif // TENSORFLOW_USE_SYCL
   delete g;
 }
 

tensorflow/core/common_runtime/sycl/sycl_device_factory.cc

Lines changed: 3 additions & 1 deletion
@@ -18,6 +18,8 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/device_factory.h"
 #include "tensorflow/core/common_runtime/sycl/sycl_device.h"
 
+#include "tensorflow/core/common_runtime/sycl/sycl_util.h"
+
 namespace tensorflow {
 
 class SYCLDeviceFactory : public DeviceFactory {
@@ -34,7 +36,7 @@ class SYCLDeviceFactory : public DeviceFactory {
       devices->push_back(
           new SYCLDevice(options, name, Bytes(256 << 20), DeviceLocality(),
                          SYCLDevice::GetShortDeviceDescription(),
-                         cl::sycl::gpu_selector(), cpu_allocator()));
+                         GetSYCLDevice(), cpu_allocator()));
     }
     return Status::OK();
   }
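The switch from cl::sycl::gpu_selector() to GetSYCLDevice() matters because a selector that matches no device makes construction throw, aborting device registration on GPU-less machines. A sketch of the old failure mode, assuming standard SYCL 1.2 selector behaviour:

    try {
      // Old behaviour: constructing from gpu_selector throws when no
      // OpenCL GPU is present, so no SYCL device gets registered at all.
      cl::sycl::device d{cl::sycl::gpu_selector{}};
    } catch (const cl::sycl::exception& e) {
      // GetSYCLDevice() (added in sycl_util.h below) instead falls back
      // to an OpenCL CPU device before giving up.
    }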

tensorflow/core/common_runtime/sycl/sycl_util.h

Lines changed: 22 additions & 0 deletions
@@ -30,6 +30,28 @@ namespace tensorflow {
 }
 
 inline void* GetBase(Tensor* dst) { return DMAHelper::base(dst); }
+
+inline cl::sycl::device GetSYCLDevice() {
+  // Obtain list of supported devices from Eigen
+  for (const auto& device : Eigen::get_sycl_supported_devices()) {
+    if (device.is_gpu()) {
+      // returns first found GPU
+      return device;
+    }
+  }
+
+  // Currently Intel GPU is not supported
+  LOG(WARNING) << "No OpenCL GPU found that is supported by ComputeCpp, trying OpenCL CPU";
+
+  for (const auto& device : Eigen::get_sycl_supported_devices()) {
+    if (device.is_cpu()) {
+      // returns first found CPU
+      return device;
+    }
+  }
+  // Currently Intel GPU is not supported
+  LOG(FATAL) << "No OpenCL GPU nor CPU found that is supported by ComputeCpp";
+}
 }
 
 #endif // TENSORFLOW_CORE_COMMON_RUNTIME_SYCL_SYCL_UTIL_H_
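A typical use of the new helper (hypothetical snippet, not part of the commit) is to hand the chosen device to a standard SYCL queue and log what was picked:

    // GPU if one is supported, otherwise an OpenCL CPU device.
    cl::sycl::device d = GetSYCLDevice();
    cl::sycl::queue q(d);  // queue bound to the selected device
    LOG(INFO) << "Using OpenCL device: "
              << d.get_info<cl::sycl::info::device::name>();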

tensorflow/core/debug/debug_gateway.cc

Lines changed: 1 addition & 1 deletion
@@ -84,7 +84,7 @@ void DebugGateway::CopyTensor(const string& node_name, const int output_slot,
   // Determine if the tensor is on device (GPU) or host (CPU).
   // The second part of the check is necessary because even an OpKernel on
   // GPU may have output tensors allocated on CPU.
-  if (device->name().find("gpu:") != string::npos &&
+  if ((device->name().find("gpu:") != string::npos || device->name().find("SYCL:") != string::npos) &&
       !ctx->output_alloc_attr(output_slot).on_host()) {
     // GPU tensors: Copy it to host (CPU).
     DeviceContext* device_ctxt = ctx->op_device_context();
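The condition is now long enough that a named predicate might read better; a refactoring sketch (IsDeviceTensor is a hypothetical helper, not in the commit):

    // Hypothetical helper: true if the output tensor lives in device memory
    // on a GPU or SYCL device rather than on the host.
    static bool IsDeviceTensor(const Device* device, OpKernelContext* ctx,
                               int output_slot) {
      const string& name = device->name();
      return (name.find("gpu:") != string::npos ||
              name.find("SYCL:") != string::npos) &&
             !ctx->output_alloc_attr(output_slot).on_host();
    }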

tensorflow/core/debug/debug_gateway_test.cc

Lines changed: 16 additions & 2 deletions
@@ -45,6 +45,8 @@ class SessionDebugMinusAXTest : public ::testing::Test {
 
 #if GOOGLE_CUDA
 const string kDeviceName = "/job:localhost/replica:0/task:0/gpu:0";
+#elif defined(TENSORFLOW_USE_SYCL)
+const string kDeviceName = "/job:localhost/replica:0/task:0/device:SYCL:0";
 #else
 const string kDeviceName = "/job:localhost/replica:0/task:0/cpu:0";
 #endif
@@ -302,6 +304,8 @@ TEST_F(SessionDebugMinusAXTest, RunSimpleNetworkWithTwoDebugNodesInserted) {
   // through RunMetadata, given whether GPU is involved.
 #if GOOGLE_CUDA
   ASSERT_EQ(2, run_metadata.partition_graphs().size());
+#elif defined(TENSORFLOW_USE_SYCL)
+  ASSERT_EQ(2, run_metadata.partition_graphs().size());
 #else
   ASSERT_EQ(1, run_metadata.partition_graphs().size());
 #endif
@@ -336,7 +340,7 @@ TEST_F(SessionDebugMinusAXTest, RunSimpleNetworkWithTwoDebugNodesInserted) {
   ASSERT_EQ(1, debug_nan_count_tensor_vals[0].scalar<int64>()());
 }
 
-#ifndef GOOGLE_CUDA
+#if !defined(GOOGLE_CUDA) && !defined(TENSORFLOW_USE_SYCL)
 // TODO(cais): Reinstate the following test for concurrent debugged runs on
 // a GPU once the root cause of the ~0.5% flakiness has been addressed.
 // (b/34081273)
@@ -499,6 +503,8 @@ class SessionDebugOutputSlotWithoutOngoingEdgeTest : public ::testing::Test {
 
 #if GOOGLE_CUDA
 const string kDeviceName = "/job:localhost/replica:0/task:0/gpu:0";
+#elif defined(TENSORFLOW_USE_SYCL)
+const string kDeviceName = "/job:localhost/replica:0/task:0/device:SYCL:0";
 #else
 const string kDeviceName = "/job:localhost/replica:0/task:0/cpu:0";
 #endif
@@ -599,6 +605,8 @@ class SessionDebugVariableTest : public ::testing::Test {
 
 #if GOOGLE_CUDA
 const string kDeviceName = "/job:localhost/replica:0/task:0/gpu:0";
+#elif defined(TENSORFLOW_USE_SYCL)
+const string kDeviceName = "/job:localhost/replica:0/task:0/device:SYCL:0";
 #else
 const string kDeviceName = "/job:localhost/replica:0/task:0/cpu:0";
 #endif
@@ -818,6 +826,8 @@ TEST_F(SessionDebugVariableTest, VariableAssignWithDebugOps) {
 
 #if GOOGLE_CUDA
   ASSERT_EQ(2, run_metadata.partition_graphs().size());
+#elif defined(TENSORFLOW_USE_SYCL)
+  ASSERT_EQ(2, run_metadata.partition_graphs().size());
 #else
   ASSERT_EQ(1, run_metadata.partition_graphs().size());
 #endif
@@ -855,13 +865,17 @@ TEST_F(SessionDebugVariableTest, VariableAssignWithDebugOps) {
   ASSERT_EQ(2, debug_nan_count_tensor_vals[0].scalar<int64>()());
 }
 
-#if GOOGLE_CUDA
+#if defined(GOOGLE_CUDA) || defined(TENSORFLOW_USE_SYCL)
 class SessionDebugGPUSwitchTest : public ::testing::Test {
  public:
   void Initialize() {
     Graph graph(OpRegistry::Global());
 
+#ifdef GOOGLE_CUDA
     const string kDeviceName = "/job:localhost/replica:0/task:0/gpu:0";
+#elif TENSORFLOW_USE_SYCL
+    const string kDeviceName = "/job:localhost/replica:0/task:0/device:SYCL:0";
+#endif
 
     Tensor vb(DT_BOOL, TensorShape({}));
     vb.scalar<bool>()() = true;

tensorflow/core/framework/op_kernel.cc

Lines changed: 2 additions & 2 deletions
@@ -94,9 +94,9 @@ OpKernel::OpKernel(OpKernelConstruction* context)
   OP_REQUIRES_OK(context, CheckOpDeprecation(context->op_def(),
                                              context->graph_def_version()));
 
-  // Kernels executing on GPU tie very few resources on the CPU where the
+  // Kernels executing on GPU/SYCL tie very few resources on the CPU where the
   // scheduler runs: we consider them as inexpensive.
-  expensive_ = context->device_type() != DeviceType(DEVICE_GPU);
+  expensive_ = context->device_type() != DeviceType(DEVICE_GPU) && context->device_type() != DeviceType(DEVICE_SYCL);
 }
 
 OpKernel::~OpKernel() {}

tensorflow/core/graph/testlib.cc

Lines changed: 4 additions & 0 deletions
@@ -36,6 +36,10 @@ namespace tensorflow {
 REGISTER_KERNEL_BUILDER(Name("HostConst").Device(DEVICE_CPU), HostConstantOp);
 REGISTER_KERNEL_BUILDER(
     Name("HostConst").Device(DEVICE_GPU).HostMemory("output"), HostConstantOp);
+#ifdef TENSORFLOW_USE_SYCL
+REGISTER_KERNEL_BUILDER(
+    Name("HostConst").Device(DEVICE_SYCL).HostMemory("output"), HostConstantOp);
+#endif // TENSORFLOW_USE_SYCL
 
 // Register the HostConst Op
 // Returns a constant tensor on the host. Useful for writing C++ tests
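The commit message also mentions replicating TF_CALL_GPU_PROXY_TYPES for SYCL; that change sits in one of the files hidden from this view. A sketch of the shape such a macro takes, with the exact type list being an assumption rather than a quote from the commit:

    // Hypothetical: instantiate kernels for the proxy types used by
    // DEVICE_SYCL, mirroring TF_CALL_GPU_PROXY_TYPES. Type list may differ.
    #define TF_CALL_SYCL_PROXY_TYPES(m) \
      TF_CALL_double(m) TF_CALL_float(m) TF_CALL_int32(m)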
