225 changes: 166 additions & 59 deletions sycl/source/detail/scheduler/commands.cpp
@@ -2303,14 +2303,14 @@ ur_mem_flags_t AccessModeToUr(access::mode AccessorMode) {
}
}

// Sets arguments for a given kernel and device based on the argument type.
// Refactored from SetKernelParamsAndLaunch to allow it to be used in the graphs
// extension.
static void SetArgBasedOnType(
adapter_impl &Adapter, ur_kernel_handle_t Kernel,
// Gets UR argument struct for a given kernel and device based on the argument
// type. Refactored from SetKernelParamsAndLaunch to allow it to be used in
// the graphs extension (LaunchWithArgs for graphs is planned future work).
static void GetUrArgsBasedOnType(
device_image_impl *DeviceImageImpl,
const std::function<void *(Requirement *Req)> &getMemAllocationFunc,
context_impl &ContextImpl, detail::ArgDesc &Arg, size_t NextTrueIndex) {
context_impl &ContextImpl, detail::ArgDesc &Arg, size_t NextTrueIndex,
std::vector<ur_exp_kernel_arg_properties_t> &UrArgs) {
switch (Arg.MType) {
case kernel_param_kind_t::kind_dynamic_work_group_memory:
break;
@@ -2330,52 +2330,61 @@ static void SetArgBasedOnType(
getMemAllocationFunc
? reinterpret_cast<ur_mem_handle_t>(getMemAllocationFunc(Req))
: nullptr;
ur_kernel_arg_mem_obj_properties_t MemObjData{};
MemObjData.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
MemObjData.memoryAccess = AccessModeToUr(Req->MAccessMode);
Adapter.call<UrApiKind::urKernelSetArgMemObj>(Kernel, NextTrueIndex,
&MemObjData, MemArg);
ur_exp_kernel_arg_value_t Value = {};
Value.memObjTuple = {MemArg, AccessModeToUr(Req->MAccessMode)};
UrArgs.push_back({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr,
UR_EXP_KERNEL_ARG_TYPE_MEM_OBJ,
static_cast<uint32_t>(NextTrueIndex), sizeof(MemArg),
Value});
break;
}
case kernel_param_kind_t::kind_std_layout: {
ur_exp_kernel_arg_type_t Type;
if (Arg.MPtr) {
Adapter.call<UrApiKind::urKernelSetArgValue>(
Kernel, NextTrueIndex, Arg.MSize, nullptr, Arg.MPtr);
Type = UR_EXP_KERNEL_ARG_TYPE_VALUE;
} else {
Adapter.call<UrApiKind::urKernelSetArgLocal>(Kernel, NextTrueIndex,
Arg.MSize, nullptr);
Type = UR_EXP_KERNEL_ARG_TYPE_LOCAL;
}
ur_exp_kernel_arg_value_t Value = {};
Value.value = {Arg.MPtr};
UrArgs.push_back({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr,
Type, static_cast<uint32_t>(NextTrueIndex),
static_cast<size_t>(Arg.MSize), Value});

break;
}
case kernel_param_kind_t::kind_sampler: {
sampler *SamplerPtr = (sampler *)Arg.MPtr;
ur_sampler_handle_t Sampler =
(ur_sampler_handle_t)detail::getSyclObjImpl(*SamplerPtr)
->getOrCreateSampler(ContextImpl);
Adapter.call<UrApiKind::urKernelSetArgSampler>(Kernel, NextTrueIndex,
nullptr, Sampler);
ur_exp_kernel_arg_value_t Value = {};
Value.sampler = (ur_sampler_handle_t)detail::getSyclObjImpl(*SamplerPtr)
->getOrCreateSampler(ContextImpl);
UrArgs.push_back({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr,
UR_EXP_KERNEL_ARG_TYPE_SAMPLER,
static_cast<uint32_t>(NextTrueIndex),
sizeof(ur_sampler_handle_t), Value});
Comment on lines +2361 to +2364 (Contributor):
IMO, that's too much copy-paste... I think either

auto AddArg = [](<something>) {
};
switch () {
...
case foo: {
  ....
  AddArg(<something>);
}
}

or maybe wrap the whole switch into a lambda returning something would be a way to address that.
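Concretely, the first suggestion might look roughly like the sketch below. This is illustrative only and not code from this PR: the AddArg name, the trimmed-down signature, and the assumption that the experimental UR types (ur_exp_kernel_arg_properties_t, ur_exp_kernel_arg_value_t) come from <ur_api.h> are all assumptions made for the example; the field order in the push_back mirrors the aggregate initializers used in the diff above.

// Illustration only: AddArg, the reduced signature and the <ur_api.h>
// include are assumptions for this sketch, not code in this PR.
#include <ur_api.h>

#include <cstddef>
#include <cstdint>
#include <vector>

static void GetUrArgsBasedOnTypeSketch(
    std::vector<ur_exp_kernel_arg_properties_t> &UrArgs,
    size_t NextTrueIndex /* Arg, DeviceImageImpl, etc. elided */) {
  // One helper fills in the boilerplate fields; each switch case then only
  // supplies the argument type, size and value.
  auto AddArg = [&](ur_exp_kernel_arg_type_t Type, size_t Size,
                    ur_exp_kernel_arg_value_t Value) {
    UrArgs.push_back({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr,
                      Type, static_cast<uint32_t>(NextTrueIndex), Size,
                      Value});
  };

  // switch (Arg.MType) {
  // case kernel_param_kind_t::kind_pointer: {
  //   ur_exp_kernel_arg_value_t Value = {};
  //   Value.pointer = *static_cast<void *const *>(Arg.MPtr);
  //   AddArg(UR_EXP_KERNEL_ARG_TYPE_POINTER, sizeof(Value.pointer), Value);
  //   break;
  // }
  // ...remaining cases as in the diff above...
  // }
  (void)AddArg; // silence unused-variable warnings in this fragment
}

The second suggestion (wrapping the whole switch in a value-returning lambda) would avoid the out-parameter but otherwise follows the same shape.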

break;
}
case kernel_param_kind_t::kind_pointer: {
// We need to de-reference this to get the actual USM allocation - that's the
ur_exp_kernel_arg_value_t Value = {};
// We need to de-reference to get the actual USM allocation - that's the
// pointer UR is expecting.
const void *Ptr = *static_cast<const void *const *>(Arg.MPtr);
Adapter.call<UrApiKind::urKernelSetArgPointer>(Kernel, NextTrueIndex,
nullptr, Ptr);
Value.pointer = *static_cast<void *const *>(Arg.MPtr);
UrArgs.push_back({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr,
UR_EXP_KERNEL_ARG_TYPE_POINTER,
static_cast<uint32_t>(NextTrueIndex), sizeof(Arg.MPtr),
Value});
break;
}
case kernel_param_kind_t::kind_specialization_constants_buffer: {
assert(DeviceImageImpl != nullptr);
ur_mem_handle_t SpecConstsBuffer =
DeviceImageImpl->get_spec_const_buffer_ref();

ur_kernel_arg_mem_obj_properties_t MemObjProps{};
MemObjProps.pNext = nullptr;
MemObjProps.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
MemObjProps.memoryAccess = UR_MEM_FLAG_READ_ONLY;
Adapter.call<UrApiKind::urKernelSetArgMemObj>(
Kernel, NextTrueIndex, &MemObjProps, SpecConstsBuffer);
ur_exp_kernel_arg_value_t Value = {};
Value.memObjTuple = {SpecConstsBuffer, UR_MEM_FLAG_READ_ONLY};
UrArgs.push_back({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr,
UR_EXP_KERNEL_ARG_TYPE_MEM_OBJ,
static_cast<uint32_t>(NextTrueIndex),
sizeof(SpecConstsBuffer), Value});
break;
}
case kernel_param_kind_t::kind_invalid:
@@ -2404,22 +2413,33 @@ static ur_result_t SetKernelParamsAndLaunch(
DeviceImageImpl ? DeviceImageImpl->get_spec_const_blob_ref() : Empty);
}

// just a performance optimization - avoid heap allocations
static thread_local std::vector<ur_exp_kernel_arg_properties_t> UrArgs;
UrArgs.reserve(Args.size());
UrArgs.clear();

if (KernelFuncPtr && !DeviceKernelInfo.HasSpecialCaptures) {
auto setFunc = [&Adapter, Kernel,
KernelFuncPtr](const detail::kernel_param_desc_t &ParamDesc,
auto setFunc = [KernelFuncPtr](const detail::kernel_param_desc_t &ParamDesc,
size_t NextTrueIndex) {
const void *ArgPtr = (const char *)KernelFuncPtr + ParamDesc.offset;
switch (ParamDesc.kind) {
case kernel_param_kind_t::kind_std_layout: {
int Size = ParamDesc.info;
Adapter.call<UrApiKind::urKernelSetArgValue>(Kernel, NextTrueIndex,
Size, nullptr, ArgPtr);
ur_exp_kernel_arg_value_t Value = {};
Value.value = ArgPtr;
UrArgs.push_back({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr,
UR_EXP_KERNEL_ARG_TYPE_VALUE,
static_cast<uint32_t>(NextTrueIndex),
static_cast<size_t>(Size), Value});
break;
}
case kernel_param_kind_t::kind_pointer: {
const void *Ptr = *static_cast<const void *const *>(ArgPtr);
Adapter.call<UrApiKind::urKernelSetArgPointer>(Kernel, NextTrueIndex,
nullptr, Ptr);
ur_exp_kernel_arg_value_t Value = {};
Value.pointer = *static_cast<const void *const *>(ArgPtr);
UrArgs.push_back({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr,
UR_EXP_KERNEL_ARG_TYPE_POINTER,
static_cast<uint32_t>(NextTrueIndex),
sizeof(Value.pointer), Value});
break;
}
default:
@@ -2429,23 +2449,28 @@ static ur_result_t SetKernelParamsAndLaunch(
applyFuncOnFilteredArgs(EliminatedArgMask, DeviceKernelInfo.NumParams,
DeviceKernelInfo.ParamDescGetter, setFunc);
} else {
auto setFunc = [&Adapter, Kernel, &DeviceImageImpl, &getMemAllocationFunc,
auto setFunc = [&DeviceImageImpl, &getMemAllocationFunc,
&Queue](detail::ArgDesc &Arg, size_t NextTrueIndex) {
SetArgBasedOnType(Adapter, Kernel, DeviceImageImpl, getMemAllocationFunc,
Queue.getContextImpl(), Arg, NextTrueIndex);
GetUrArgsBasedOnType(DeviceImageImpl, getMemAllocationFunc,
Queue.getContextImpl(), Arg, NextTrueIndex, UrArgs);
};
applyFuncOnFilteredArgs(EliminatedArgMask, Args, setFunc);
}

const std::optional<int> &ImplicitLocalArg =
DeviceKernelInfo.getImplicitLocalArgPos();
std::optional<int> ImplicitLocalArg =
ProgramManager::getInstance().kernelImplicitLocalArgPos(
DeviceKernelInfo.Name);
// Set the implicit local memory buffer to support
// get_work_group_scratch_memory. This is for backends not supporting
// CUDA-style local memory setting. Note that we may have -1 as a position;
// this indicates the buffer is actually unused and was elided.
if (ImplicitLocalArg.has_value() && ImplicitLocalArg.value() != -1) {
Adapter.call<UrApiKind::urKernelSetArgLocal>(
Kernel, ImplicitLocalArg.value(), WorkGroupMemorySize, nullptr);
UrArgs.push_back({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES,
nullptr,
UR_EXP_KERNEL_ARG_TYPE_LOCAL,
static_cast<uint32_t>(ImplicitLocalArg.value()),
WorkGroupMemorySize,
{nullptr}});
}

adjustNDRangePerKernel(NDRDesc, Kernel, Queue.getDeviceImpl());
@@ -2468,16 +2493,14 @@ static ur_result_t SetKernelParamsAndLaunch(
/* pPropSizeRet = */ nullptr);

const bool EnforcedLocalSize =
(RequiredWGSize[0] != 0 &&
(NDRDesc.Dims < 2 || RequiredWGSize[1] != 0) &&
(NDRDesc.Dims < 3 || RequiredWGSize[2] != 0));
(RequiredWGSize[0] != 0 || RequiredWGSize[1] != 0 ||
RequiredWGSize[2] != 0);
if (EnforcedLocalSize)
LocalSize = RequiredWGSize;
}

const bool HasOffset = NDRDesc.GlobalOffset[0] != 0 &&
(NDRDesc.Dims < 2 || NDRDesc.GlobalOffset[1] != 0) &&
(NDRDesc.Dims < 3 || NDRDesc.GlobalOffset[2] != 0);
const bool HasOffset = NDRDesc.GlobalOffset[0] != 0 ||
NDRDesc.GlobalOffset[1] != 0 ||
NDRDesc.GlobalOffset[2] != 0;

std::vector<ur_kernel_launch_property_t> property_list;

@@ -2505,20 +2528,104 @@ static ur_result_t SetKernelParamsAndLaunch(
{{WorkGroupMemorySize}}});
}
ur_event_handle_t UREvent = nullptr;
ur_result_t Error = Adapter.call_nocheck<UrApiKind::urEnqueueKernelLaunch>(
Queue.getHandleRef(), Kernel, NDRDesc.Dims,
HasOffset ? &NDRDesc.GlobalOffset[0] : nullptr, &NDRDesc.GlobalSize[0],
LocalSize, property_list.size(),
property_list.empty() ? nullptr : property_list.data(), RawEvents.size(),
RawEvents.empty() ? nullptr : &RawEvents[0],
OutEventImpl ? &UREvent : nullptr);
ur_result_t Error =
Adapter.call_nocheck<UrApiKind::urEnqueueKernelLaunchWithArgsExp>(
Queue.getHandleRef(), Kernel, NDRDesc.Dims,
HasOffset ? &NDRDesc.GlobalOffset[0] : nullptr,
&NDRDesc.GlobalSize[0], LocalSize, UrArgs.size(), UrArgs.data(),
property_list.size(),
property_list.empty() ? nullptr : property_list.data(),
RawEvents.size(), RawEvents.empty() ? nullptr : &RawEvents[0],
OutEventImpl ? &UREvent : nullptr);
if (Error == UR_RESULT_SUCCESS && OutEventImpl) {
OutEventImpl->setHandle(UREvent);
}

return Error;
}

// Sets arguments for a given kernel and device based on the argument type.
// This is a legacy path which the graphs extension still uses.
static void SetArgBasedOnType(
adapter_impl &Adapter, ur_kernel_handle_t Kernel,
device_image_impl *DeviceImageImpl,
const std::function<void *(Requirement *Req)> &getMemAllocationFunc,
context_impl &ContextImpl, detail::ArgDesc &Arg, size_t NextTrueIndex) {
switch (Arg.MType) {
case kernel_param_kind_t::kind_dynamic_work_group_memory:
break;
case kernel_param_kind_t::kind_work_group_memory:
break;
case kernel_param_kind_t::kind_stream:
break;
case kernel_param_kind_t::kind_dynamic_accessor:
case kernel_param_kind_t::kind_accessor: {
Requirement *Req = (Requirement *)(Arg.MPtr);

// getMemAllocationFunc is nullptr when there are no requirements. However,
// we may pass default constructed accessors to a command, which don't add
// requirements. In such case, getMemAllocationFunc is nullptr, but it's a
// valid case, so we need to properly handle it.
ur_mem_handle_t MemArg =
getMemAllocationFunc
? reinterpret_cast<ur_mem_handle_t>(getMemAllocationFunc(Req))
: nullptr;
ur_kernel_arg_mem_obj_properties_t MemObjData{};
MemObjData.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
MemObjData.memoryAccess = AccessModeToUr(Req->MAccessMode);
Adapter.call<UrApiKind::urKernelSetArgMemObj>(Kernel, NextTrueIndex,
&MemObjData, MemArg);
break;
}
case kernel_param_kind_t::kind_std_layout: {
if (Arg.MPtr) {
Adapter.call<UrApiKind::urKernelSetArgValue>(
Kernel, NextTrueIndex, Arg.MSize, nullptr, Arg.MPtr);
} else {
Adapter.call<UrApiKind::urKernelSetArgLocal>(Kernel, NextTrueIndex,
Arg.MSize, nullptr);
}

break;
}
case kernel_param_kind_t::kind_sampler: {
sampler *SamplerPtr = (sampler *)Arg.MPtr;
ur_sampler_handle_t Sampler =
(ur_sampler_handle_t)detail::getSyclObjImpl(*SamplerPtr)
->getOrCreateSampler(ContextImpl);
Adapter.call<UrApiKind::urKernelSetArgSampler>(Kernel, NextTrueIndex,
nullptr, Sampler);
break;
}
case kernel_param_kind_t::kind_pointer: {
// We need to de-reference this to get the actual USM allocation - that's the
// pointer UR is expecting.
const void *Ptr = *static_cast<const void *const *>(Arg.MPtr);
Adapter.call<UrApiKind::urKernelSetArgPointer>(Kernel, NextTrueIndex,
nullptr, Ptr);
break;
}
case kernel_param_kind_t::kind_specialization_constants_buffer: {
assert(DeviceImageImpl != nullptr);
ur_mem_handle_t SpecConstsBuffer =
DeviceImageImpl->get_spec_const_buffer_ref();

ur_kernel_arg_mem_obj_properties_t MemObjProps{};
MemObjProps.pNext = nullptr;
MemObjProps.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
MemObjProps.memoryAccess = UR_MEM_FLAG_READ_ONLY;
Adapter.call<UrApiKind::urKernelSetArgMemObj>(
Kernel, NextTrueIndex, &MemObjProps, SpecConstsBuffer);
break;
}
case kernel_param_kind_t::kind_invalid:
throw sycl::exception(sycl::make_error_code(sycl::errc::runtime),
"Invalid kernel param kind " +
codeToString(UR_RESULT_ERROR_INVALID_VALUE));
break;
}
}

static std::tuple<ur_kernel_handle_t, device_image_impl *,
const KernelArgMask *>
getCGKernelInfo(const CGExecKernel &CommandGroup, context_impl &ContextImpl,
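Taken together, the change replaces the per-argument urKernelSetArgValue / urKernelSetArgPointer / urKernelSetArgMemObj / urKernelSetArgLocal / urKernelSetArgSampler calls with one vector of ur_exp_kernel_arg_properties_t that is handed to urEnqueueKernelLaunchWithArgsExp. The minimal sketch below shows that flow end to end; it is not code from this PR: the free-function wrapper, the placeholder handles, and the parameter order of urEnqueueKernelLaunchWithArgsExp are inferred from the call site in the diff above, and the <ur_api.h> include is an assumption.

// Sketch under the assumptions stated above; not part of this PR.
#include <ur_api.h>

#include <cstdint>
#include <vector>

ur_result_t LaunchWithPackedArgs(ur_queue_handle_t Queue,
                                 ur_kernel_handle_t Kernel,
                                 const void *ScalarArg, size_t ScalarSize,
                                 void *UsmPtr, size_t GlobalSize) {
  std::vector<ur_exp_kernel_arg_properties_t> UrArgs;

  // A by-value (std_layout) argument at kernel argument index 0.
  ur_exp_kernel_arg_value_t V0 = {};
  V0.value = ScalarArg;
  UrArgs.push_back({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr,
                    UR_EXP_KERNEL_ARG_TYPE_VALUE, /*index*/ 0u, ScalarSize,
                    V0});

  // A USM pointer argument at kernel argument index 1.
  ur_exp_kernel_arg_value_t V1 = {};
  V1.pointer = UsmPtr;
  UrArgs.push_back({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr,
                    UR_EXP_KERNEL_ARG_TYPE_POINTER, /*index*/ 1u,
                    sizeof(V1.pointer), V1});

  size_t GlobalWorkSize[3] = {GlobalSize, 1, 1};
  ur_event_handle_t Event = nullptr;

  // One call both sets every argument and enqueues the launch, so no
  // separate urKernelSetArg* calls are needed on this path.
  return urEnqueueKernelLaunchWithArgsExp(
      Queue, Kernel, /*workDim=*/1, /*pGlobalWorkOffset=*/nullptr,
      GlobalWorkSize, /*pLocalWorkSize=*/nullptr, UrArgs.size(),
      UrArgs.data(), /*numPropsInLaunchPropList=*/0,
      /*launchPropList=*/nullptr, /*numEventsInWaitList=*/0,
      /*phEventWaitList=*/nullptr, &Event);
}

This single-call shape is also what the comment on GetUrArgsBasedOnType refers to when it notes that a LaunchWithArgs path for the graphs extension is planned future work.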
4 changes: 2 additions & 2 deletions sycl/test-e2e/Adapters/level_zero/batch_barrier.cpp
@@ -24,7 +24,7 @@ int main(int argc, char *argv[]) {
queue q;

submit_kernel(q); // starts a batch
// CHECK: ---> urEnqueueKernelLaunch
// CHECK: ---> urEnqueueKernelLaunchWithArgsExp
// CHECK-NOT: zeCommandQueueExecuteCommandLists

// Initializing Level Zero driver is required if this test is linked
@@ -42,7 +42,7 @@ int main(int argc, char *argv[]) {
// CHECK-NOT: zeCommandQueueExecuteCommandLists

submit_kernel(q);
// CHECK: ---> urEnqueueKernelLaunch
// CHECK: ---> urEnqueueKernelLaunchWithArgsExp
// CHECK-NOT: zeCommandQueueExecuteCommandLists

// interop should close the batch