|
| 1 | +#include "../../utils.hpp" |
| 2 | +#include "infinicore/common/hash.hpp" |
| 3 | +#include "infinicore/ops/common/cache.hpp" |
| 4 | +#include "infinicore/ops/random_sample.hpp" |
| 5 | +#include <infiniop.h> |
| 6 | + |
| 7 | +namespace infinicore::op::random_sample_impl::infiniop_backend { |
| 8 | + /* Thread-local cache of RandomSample descriptors keyed by a size_t hash. The callback below destroys a non-null descriptor; presumably OpCache invokes it when an entry is dropped (TODO confirm in cache.hpp). */ |
| 9 | +thread_local common::OpCache<size_t, infiniopRandomSampleDescriptor_t> caches( |
| 10 | + 100, // capacity |
| 11 | + [](infiniopRandomSampleDescriptor_t &desc) { |
| 12 | + if (desc != nullptr) { |
| 13 | + INFINICORE_CHECK_ERROR(infiniopDestroyRandomSampleDescriptor(desc)); |
| 14 | + desc = nullptr; /* reset so a repeated call on the same reference skips the destroy */ |
| 15 | + } |
| 16 | + }); |
| 17 | + /* Runs random sampling through the infiniop backend: looks up (or creates and caches) a RandomSample descriptor for the indices/logits pair on the current device, allocates the workspace the descriptor reports, and calls infiniopRandomSample on the current stream. random_val, topp, topk and temperature are forwarded unchanged. */ |
| 18 | +static void calculate( |
| 19 | + Tensor indices, |
| 20 | + Tensor logits, |
| 21 | + float random_val, |
| 22 | + float topp, |
| 23 | + int topk, |
| 24 | + float temperature) { |
| 25 | + // Cache key: hash_combine over indices (the result tensor) and logits; one descriptor is cached per key, per device |
| 26 | + size_t seed = hash_combine(indices, logits); |
| 27 | + /* Pick the cache that belongs to the currently active device (type + index). */ |
| 28 | + auto device_type = context::getDevice().getType(); |
| 29 | + auto device_index = context::getDevice().getIndex(); |
| 30 | + |
| 31 | + auto &cache = caches.getCache(device_type, device_index); |
| 32 | + /* Cache lookup: an empty result means no descriptor is stored under this key yet. */ |
| 33 | + auto desc_opt = cache.get(seed); |
| 34 | + infiniopRandomSampleDescriptor_t desc = nullptr; |
| 35 | + /* Miss: create a descriptor from both tensor descriptors and store it. Hit: reuse the stored one. */ |
| 36 | + if (!desc_opt) { |
| 37 | + INFINICORE_CHECK_ERROR(infiniopCreateRandomSampleDescriptor( |
| 38 | + context::getInfiniopHandle(), &desc, |
| 39 | + indices->desc(), logits->desc())); |
| 40 | + cache.put(seed, desc); |
| 41 | + } else { |
| 42 | + desc = *desc_opt; |
| 43 | + } |
| 44 | + /* Ask the descriptor how much scratch memory it needs, then allocate exactly that amount. */ |
| 45 | + size_t workspace_size = 0; |
| 46 | + INFINICORE_CHECK_ERROR(infiniopGetRandomSampleWorkspaceSize(desc, &workspace_size)); |
| 47 | + std::shared_ptr<Memory> workspace = context::allocateMemory(workspace_size); |
| 48 | + /* Launch the op on the current stream; presumably the sampled result is written into indices (TODO confirm against infiniop.h). */ |
| 49 | + INFINICORE_CHECK_ERROR(infiniopRandomSample( |
| 50 | + desc, |
| 51 | + workspace->data(), workspace_size, |
| 52 | + indices->data(), logits->data(), |
| 53 | + random_val, topp, topk, temperature, |
| 54 | + context::getStream())); |
| 55 | +} |
| 56 | + |
| 57 | +} // namespace infinicore::op::random_sample_impl::infiniop_backend |
| 58 | + |
| 59 | +namespace infinicore::op { |
| 60 | + /* Registration hook: the lambda runs during static initialization and registers the infiniop calculate backend with the RandomSample dispatcher via registerAll. `registered` only stores the lambda's `true` result. The meaning of the `false` argument is not visible here (TODO confirm in the dispatcher). */ |
| 61 | +static bool registered = []() { |
| 62 | + RandomSample::dispatcher().registerAll(&random_sample_impl::infiniop_backend::calculate, false); |
| 63 | + return true; |
| 64 | +}(); |
| 65 | + |
| 66 | +} // namespace infinicore::op |
0 commit comments