Skip to content

Commit 67e73ba

Browse files
authored
[Offload] Refactor device/platform info queries (#146345)
This makes several small changes to how the platform and device info queries are handled: * ReturnHelper has been replaced with InfoWriter which is more explicit in how it is invoked. * InfoWriter consumes `llvm::Expected` rather than values directly, and will early exit if it returns an error. * As a result of the above, `GetInfoString` now correctly returns errors rather than empty strings. * The host device now has its own dedicated "getInfo" function rather than being checked in multiple places.
1 parent 619f7af commit 67e73ba

File tree

2 files changed

+90
-50
lines changed

2 files changed

+90
-50
lines changed

offload/liboffload/src/Helpers.hpp

Lines changed: 28 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -61,39 +61,41 @@ llvm::Error getInfoArray(size_t array_length, size_t ParamValueSize,
6161
array_length * sizeof(T), memcpy);
6262
}
6363

64-
template <>
65-
inline llvm::Error
66-
getInfo<const char *>(size_t ParamValueSize, void *ParamValue,
67-
size_t *ParamValueSizeRet, const char *Value) {
68-
return getInfoArray(strlen(Value) + 1, ParamValueSize, ParamValue,
69-
ParamValueSizeRet, Value);
64+
llvm::Error getInfoString(size_t ParamValueSize, void *ParamValue,
65+
size_t *ParamValueSizeRet, llvm::StringRef Value) {
66+
return getInfoArray(Value.size() + 1, ParamValueSize, ParamValue,
67+
ParamValueSizeRet, Value.data());
7068
}
7169

72-
class ReturnHelper {
70+
class InfoWriter {
7371
public:
74-
ReturnHelper(size_t ParamValueSize, void *ParamValue,
75-
size_t *ParamValueSizeRet)
76-
: ParamValueSize(ParamValueSize), ParamValue(ParamValue),
77-
ParamValueSizeRet(ParamValueSizeRet) {}
72+
InfoWriter(size_t Size, void *Target, size_t *SizeRet)
73+
: Size(Size), Target(Target), SizeRet(SizeRet) {};
74+
InfoWriter() = delete;
75+
InfoWriter(InfoWriter &) = delete;
76+
~InfoWriter() = default;
7877

79-
// A version where in/out info size is represented by a single pointer
80-
// to a value which is updated on return
81-
ReturnHelper(size_t *ParamValueSize, void *ParamValue)
82-
: ParamValueSize(*ParamValueSize), ParamValue(ParamValue),
83-
ParamValueSizeRet(ParamValueSize) {}
78+
template <typename T> llvm::Error write(llvm::Expected<T> &&Val) {
79+
if (Val)
80+
return getInfo(Size, Target, SizeRet, *Val);
81+
return Val.takeError();
82+
}
8483

85-
// Scalar return Value
86-
template <class T> llvm::Error operator()(const T &t) {
87-
return getInfo(ParamValueSize, ParamValue, ParamValueSizeRet, t);
84+
template <typename T>
85+
llvm::Error writeArray(llvm::Expected<T> &&Val, size_t Elems) {
86+
if (Val)
87+
return getInfoArray(Elems, Size, Target, SizeRet, *Val);
88+
return Val.takeError();
8889
}
8990

90-
// Array return Value
91-
template <class T> llvm::Error operator()(const T *t, size_t s) {
92-
return getInfoArray(s, ParamValueSize, ParamValue, ParamValueSizeRet, t);
91+
llvm::Error writeString(llvm::Expected<llvm::StringRef> &&Val) {
92+
if (Val)
93+
return getInfoString(Size, Target, SizeRet, *Val);
94+
return Val.takeError();
9395
}
9496

95-
protected:
96-
size_t ParamValueSize;
97-
void *ParamValue;
98-
size_t *ParamValueSizeRet;
97+
private:
98+
size_t Size;
99+
void *Target;
100+
size_t *SizeRet;
99101
};

offload/liboffload/src/OffloadImpl.cpp

Lines changed: 62 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
#include "OffloadImpl.hpp"
1515
#include "Helpers.hpp"
16+
#include "OffloadPrint.hpp"
1617
#include "PluginManager.h"
1718
#include "llvm/Support/FormatVariadic.h"
1819
#include <OffloadAPI.h>
@@ -234,23 +235,22 @@ Error olShutDown_impl() {
234235
Error olGetPlatformInfoImplDetail(ol_platform_handle_t Platform,
235236
ol_platform_info_t PropName, size_t PropSize,
236237
void *PropValue, size_t *PropSizeRet) {
237-
ReturnHelper ReturnValue(PropSize, PropValue, PropSizeRet);
238+
InfoWriter Info(PropSize, PropValue, PropSizeRet);
238239
bool IsHost = Platform->BackendType == OL_PLATFORM_BACKEND_HOST;
239240

240241
switch (PropName) {
241242
case OL_PLATFORM_INFO_NAME:
242-
return ReturnValue(IsHost ? "Host" : Platform->Plugin->getName());
243+
return Info.writeString(IsHost ? "Host" : Platform->Plugin->getName());
243244
case OL_PLATFORM_INFO_VENDOR_NAME:
244245
// TODO: Implement this
245-
return ReturnValue("Unknown platform vendor");
246+
return Info.writeString("Unknown platform vendor");
246247
case OL_PLATFORM_INFO_VERSION: {
247-
return ReturnValue(formatv("v{0}.{1}.{2}", OL_VERSION_MAJOR,
248-
OL_VERSION_MINOR, OL_VERSION_PATCH)
249-
.str()
250-
.c_str());
248+
return Info.writeString(formatv("v{0}.{1}.{2}", OL_VERSION_MAJOR,
249+
OL_VERSION_MINOR, OL_VERSION_PATCH)
250+
.str());
251251
}
252252
case OL_PLATFORM_INFO_BACKEND: {
253-
return ReturnValue(Platform->BackendType);
253+
return Info.write<ol_platform_backend_t>(Platform->BackendType);
254254
}
255255
default:
256256
return createOffloadError(ErrorCode::INVALID_ENUMERATION,
@@ -277,36 +277,68 @@ Error olGetPlatformInfoSize_impl(ol_platform_handle_t Platform,
277277
Error olGetDeviceInfoImplDetail(ol_device_handle_t Device,
278278
ol_device_info_t PropName, size_t PropSize,
279279
void *PropValue, size_t *PropSizeRet) {
280+
assert(Device != OffloadContext::get().HostDevice());
281+
InfoWriter Info(PropSize, PropValue, PropSizeRet);
280282

281-
ReturnHelper ReturnValue(PropSize, PropValue, PropSizeRet);
283+
auto makeError = [&](ErrorCode Code, StringRef Err) {
284+
std::string ErrBuffer;
285+
llvm::raw_string_ostream(ErrBuffer) << PropName << ": " << Err;
286+
return Plugin::error(ErrorCode::UNIMPLEMENTED, ErrBuffer.c_str());
287+
};
282288

283289
// Find the info if it exists under any of the given names
284-
auto GetInfoString = [&](std::vector<std::string> Names) {
285-
if (Device == OffloadContext::get().HostDevice())
286-
return "Host";
287-
288-
for (auto Name : Names) {
289-
if (auto Entry = Device->Info.get(Name))
290+
auto getInfoString =
291+
[&](std::vector<std::string> Names) -> llvm::Expected<const char *> {
292+
for (auto &Name : Names) {
293+
if (auto Entry = Device->Info.get(Name)) {
294+
if (!std::holds_alternative<std::string>((*Entry)->Value))
295+
return makeError(ErrorCode::BACKEND_FAILURE,
296+
"plugin returned incorrect type");
290297
return std::get<std::string>((*Entry)->Value).c_str();
298+
}
291299
}
292300

293-
return "";
301+
return makeError(ErrorCode::UNIMPLEMENTED,
302+
"plugin did not provide a response for this information");
294303
};
295304

296305
switch (PropName) {
297306
case OL_DEVICE_INFO_PLATFORM:
298-
return ReturnValue(Device->Platform);
307+
return Info.write<void *>(Device->Platform);
308+
case OL_DEVICE_INFO_TYPE:
309+
return Info.write<ol_device_type_t>(OL_DEVICE_TYPE_GPU);
310+
case OL_DEVICE_INFO_NAME:
311+
return Info.writeString(getInfoString({"Device Name"}));
312+
case OL_DEVICE_INFO_VENDOR:
313+
return Info.writeString(getInfoString({"Vendor Name"}));
314+
case OL_DEVICE_INFO_DRIVER_VERSION:
315+
return Info.writeString(
316+
getInfoString({"CUDA Driver Version", "HSA Runtime Version"}));
317+
default:
318+
return createOffloadError(ErrorCode::INVALID_ENUMERATION,
319+
"getDeviceInfo enum '%i' is invalid", PropName);
320+
}
321+
322+
return Error::success();
323+
}
324+
325+
Error olGetDeviceInfoImplDetailHost(ol_device_handle_t Device,
326+
ol_device_info_t PropName, size_t PropSize,
327+
void *PropValue, size_t *PropSizeRet) {
328+
assert(Device == OffloadContext::get().HostDevice());
329+
InfoWriter Info(PropSize, PropValue, PropSizeRet);
330+
331+
switch (PropName) {
332+
case OL_DEVICE_INFO_PLATFORM:
333+
return Info.write<void *>(Device->Platform);
299334
case OL_DEVICE_INFO_TYPE:
300-
return Device == OffloadContext::get().HostDevice()
301-
? ReturnValue(OL_DEVICE_TYPE_HOST)
302-
: ReturnValue(OL_DEVICE_TYPE_GPU);
335+
return Info.write<ol_device_type_t>(OL_DEVICE_TYPE_HOST);
303336
case OL_DEVICE_INFO_NAME:
304-
return ReturnValue(GetInfoString({"Device Name"}));
337+
return Info.writeString("Virtual Host Device");
305338
case OL_DEVICE_INFO_VENDOR:
306-
return ReturnValue(GetInfoString({"Vendor Name"}));
339+
return Info.writeString("Liboffload");
307340
case OL_DEVICE_INFO_DRIVER_VERSION:
308-
return ReturnValue(
309-
GetInfoString({"CUDA Driver Version", "HSA Runtime Version"}));
341+
return Info.writeString(LLVM_VERSION_STRING);
310342
default:
311343
return createOffloadError(ErrorCode::INVALID_ENUMERATION,
312344
"getDeviceInfo enum '%i' is invalid", PropName);
@@ -317,12 +349,18 @@ Error olGetDeviceInfoImplDetail(ol_device_handle_t Device,
317349

318350
Error olGetDeviceInfo_impl(ol_device_handle_t Device, ol_device_info_t PropName,
319351
size_t PropSize, void *PropValue) {
352+
if (Device == OffloadContext::get().HostDevice())
353+
return olGetDeviceInfoImplDetailHost(Device, PropName, PropSize, PropValue,
354+
nullptr);
320355
return olGetDeviceInfoImplDetail(Device, PropName, PropSize, PropValue,
321356
nullptr);
322357
}
323358

324359
Error olGetDeviceInfoSize_impl(ol_device_handle_t Device,
325360
ol_device_info_t PropName, size_t *PropSizeRet) {
361+
if (Device == OffloadContext::get().HostDevice())
362+
return olGetDeviceInfoImplDetailHost(Device, PropName, 0, nullptr,
363+
PropSizeRet);
326364
return olGetDeviceInfoImplDetail(Device, PropName, 0, nullptr, PropSizeRet);
327365
}
328366

0 commit comments

Comments
 (0)