Description
Because `with_device()` in TensorFlow.jl expects 1-based device indices while `DeviceList()` returns zero-indexed device names, the output of `DeviceList()` cannot be passed directly to `with_device()`.
My current workaround is to do something like the following:
```julia
using TensorFlow

function get_device(sess, device_type)
    # Find the first device with the given device type (e.g. `"XLA_GPU"`)
    devices = collect(TensorFlow.DeviceList(sess))
    device = first(filter(x -> x.device_type == device_type, devices))

    # Fix up this device name so that it is 1-indexed, as TensorFlow.jl requires
    inc_number(x::Number) = x + 1
    inc_number(x) = x
    function fixup_device!(d::TensorFlow.Device)
        d.parts[:] .= [TensorFlow.DevicePart(p.kind, inc_number(p.index)) for p in d.parts]
        return d
    end
    fixup_device!(d) = d
    return fixup_device!(TensorFlow.Device(device.name))
end
```

Personally, I would prefer that `with_device()` use zero-indexed device names. On systems with multiple devices (e.g. multiple GPUs), it is an unnecessary mental burden to always remember that `/job:job/replica:1/task:1/device:CPU:1` in TensorFlow.jl does not refer to the same device as `/job:job/replica:1/task:1/device:CPU:1` anywhere else in the TensorFlow ecosystem. Regardless of which convention is chosen, TensorFlow.jl should be consistent, so that the output of one of its functions can be fed directly into another.
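To make the off-by-one mismatch concrete, here is a minimal string-level sketch of the same index shift. The helper `shift_device_indices` is hypothetical (it is not part of TensorFlow.jl), and it assumes that every trailing integer in a `/`-separated device name part (replica, task, device) is an index that should be shifted.

```julia
# Hypothetical helper (not part of TensorFlow.jl): shift every trailing numeric
# index in a device name such as "/job:localhost/replica:0/task:0/device:XLA_GPU:0"
# by `delta`. delta = +1 converts TensorFlow's zero-based names into the one-based
# form that TensorFlow.jl's `with_device()` expects; delta = -1 converts back.
function shift_device_indices(name::AbstractString, delta::Integer)
    parts = split(name, '/'; keepempty = false)
    shifted = map(parts) do part
        fields = split(part, ':')
        idx = tryparse(Int, fields[end])
        # Leave parts without a trailing integer (e.g. "job:localhost") unchanged
        idx === nothing && return String(part)
        return join(vcat(fields[1:end-1], string(idx + delta)), ':')
    end
    return "/" * join(shifted, '/')
end

zero_based = "/job:localhost/replica:0/task:0/device:XLA_GPU:0"
one_based = shift_device_indices(zero_based, 1)
println(one_based)  # prints /job:localhost/replica:1/task:1/device:XLA_GPU:1
```

Working on the string avoids mutating the `parts` field of `TensorFlow.Device` directly; the shifted string could then be wrapped in `TensorFlow.Device(...)` as the workaround does. The trade-off is that this sketch assumes the naming layout above and would shift any other trailing integer as well.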