Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions include/infinicore/ops/random_sample.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#pragma once

#include "../device.hpp"
#include "common/op.hpp"

#include "infinicore/tensor.hpp"

namespace infinicore::op {

// Operator front-end for random sampling. It exposes the function-pointer
// signature that device backends implement and the dispatcher used to look
// up the implementation for a device type.
class RandomSample {
public:
// Backend signature, in order: indices, logits, random_val, topp, topk, temperature.
using schema = void (*)(Tensor, Tensor, float, float, int, float);
// Looks up the backend for the current device type and invokes it with these arguments.
static void execute(Tensor indices, Tensor logits, float random_val, float topp, int topk, float temperature);
// Returns the dispatcher that stores the registered `schema` implementations.
static common::OpDispatcher<schema> &dispatcher();
};

// Out-of-place API: allocates the result tensor, runs the sampling op, and returns it.
Tensor random_sample(Tensor logits, float random_val, float topp, int topk, float temperature);
// In-place API: runs the sampling op with the caller-provided `indices` tensor as the output.
void random_sample_(Tensor indices, Tensor logits, float random_val, float topp, int topk, float temperature);

} // namespace infinicore::op
36 changes: 35 additions & 1 deletion python/infinicore/nn/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from infinicore.lib import _infinicore
from infinicore.tensor import Tensor

__all__ = ["causal_softmax", "rms_norm", "silu", "swiglu"]
__all__ = ["causal_softmax", "random_sample", "rms_norm", "silu", "swiglu"]


def causal_softmax(input: Tensor, out=None) -> Tensor:
Expand Down Expand Up @@ -65,3 +65,37 @@ def swiglu(input: Tensor, other: Tensor, *, out=None):
_infinicore.swiglu_(out._underlying, input._underlying, other._underlying)

return out


def random_sample(
    logits: Tensor,
    random_val: float,
    topp: float,
    topk: int,
    temperature: float,
    *,
    out=None,
) -> Tensor:
    r"""Sample an index from ``logits`` with nucleus (top-p) / top-k filtering.

    Args:
        logits: Tensor of logits to sample from.
        random_val: Random value forwarded to the backend sampler.
        topp: Nucleus (top-p) threshold forwarded to the backend.
        topk: Top-k cutoff forwarded to the backend.
        temperature: Sampling temperature forwarded to the backend.
        out: Optional pre-allocated tensor that receives the result. When
            given, it is filled by the in-place backend call and returned.

    Returns:
        ``out`` when it was supplied, otherwise a new ``Tensor`` wrapping the
        backend's result.
    """
    # The scalar sampling parameters are identical for both backend entry points.
    sampling_args = (random_val, topp, topk, temperature)

    if out is not None:
        # In-place variant: the backend writes into the caller's tensor.
        _infinicore.random_sample_(
            out._underlying,
            logits._underlying,
            *sampling_args,
        )
        return out

    # Out-of-place variant: the backend allocates and returns the result.
    return Tensor(_infinicore.random_sample(logits._underlying, *sampling_args))
38 changes: 38 additions & 0 deletions src/infinicore/ops/random_sample/random_sample.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#include "infinicore/ops/random_sample.hpp"

namespace infinicore::op {

// Returns the single dispatcher instance shared by all callers of this op.
// The instance is a function-local static, so it is constructed on first use;
// backends that register themselves from their own static initializers
// therefore always see a fully constructed dispatcher.
common::OpDispatcher<RandomSample::schema> &RandomSample::dispatcher() {
    static common::OpDispatcher<RandomSample::schema> dispatcher_;
    return dispatcher_;
}

// Dispatches the sampling call to the backend registered for the type of the
// device that is current in the context, forwarding every argument unchanged.
void RandomSample::execute(Tensor indices, Tensor logits, float random_val, float topp, int topk, float temperature) {
    auto &table = dispatcher();
    // The backend is chosen from the context's current device, not from the tensors.
    auto kernel = table.lookup(context::getDevice().getType());
    kernel(indices, logits, random_val, topp, topk, temperature);
}

// Out-of-place sampling: creates the output tensor, delegates to the in-place
// variant, and returns the filled tensor.
Tensor random_sample(Tensor logits, float random_val, float topp, int topk, float temperature) {
    // Empty shape + I32: the output is created on the same device as the logits.
    auto sampled = Tensor::empty({}, DataType::I32, logits->device());
    random_sample_(sampled, logits, random_val, topp, topk, temperature);
    return sampled;
}

// In-place sampling: forwards all arguments to RandomSample::execute, with the
// caller-provided `indices` tensor serving as the output.
void random_sample_(Tensor indices, Tensor logits, float random_val, float topp, int topk, float temperature) {
    RandomSample::execute(indices, logits, random_val, topp, topk, temperature);
}

} // namespace infinicore::op
66 changes: 66 additions & 0 deletions src/infinicore/ops/random_sample/random_sample_infiniop.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
#include "../../utils.hpp"
#include "infinicore/common/hash.hpp"
#include "infinicore/ops/common/cache.hpp"
#include "infinicore/ops/random_sample.hpp"
#include <infiniop.h>

namespace infinicore::op::random_sample_impl::infiniop_backend {

// Per-thread cache of infiniop random-sample descriptors, keyed by a size_t hash.
// The second constructor argument is a cleanup callback that destroys a non-null
// descriptor and resets the handle to nullptr.
// NOTE(review): presumably OpCache calls it when an entry is evicted or the cache
// is torn down — confirm against the OpCache implementation.
thread_local common::OpCache<size_t, infiniopRandomSampleDescriptor_t> caches(
100, // capacity
[](infiniopRandomSampleDescriptor_t &desc) {
if (desc != nullptr) {
INFINICORE_CHECK_ERROR(infiniopDestroyRandomSampleDescriptor(desc));
desc = nullptr;
}
});

// infiniop-backed implementation of the random-sample op.
//
// Reuses a cached descriptor for the (indices, logits) pair when one exists,
// otherwise creates and caches a new one; then allocates the workspace the
// descriptor asks for and launches the kernel on the context's stream.
static void calculate(Tensor indices, Tensor logits, float random_val, float topp, int topk, float temperature) {
    // Cache key: combined hash of the result tensor and the logits tensor.
    size_t cache_key = hash_combine(indices, logits);

    // Descriptor caches are kept separately per (device type, device index).
    auto device_type = context::getDevice().getType();
    auto device_index = context::getDevice().getIndex();
    auto &device_cache = caches.getCache(device_type, device_index);

    auto cached = device_cache.get(cache_key);
    infiniopRandomSampleDescriptor_t descriptor = nullptr;
    if (cached) {
        // Cache hit: reuse the existing descriptor.
        descriptor = *cached;
    } else {
        // Cache miss: build a descriptor from the tensor descriptors and remember it.
        INFINICORE_CHECK_ERROR(infiniopCreateRandomSampleDescriptor(
            context::getInfiniopHandle(), &descriptor,
            indices->desc(), logits->desc()));
        device_cache.put(cache_key, descriptor);
    }

    // Ask the descriptor how much scratch memory the kernel needs and allocate it.
    size_t workspace_bytes = 0;
    INFINICORE_CHECK_ERROR(infiniopGetRandomSampleWorkspaceSize(descriptor, &workspace_bytes));
    std::shared_ptr<Memory> scratch = context::allocateMemory(workspace_bytes);

    INFINICORE_CHECK_ERROR(infiniopRandomSample(
        descriptor,
        scratch->data(), workspace_bytes,
        indices->data(), logits->data(),
        random_val, topp, topk, temperature,
        context::getStream()));
}

} // namespace infinicore::op::random_sample_impl::infiniop_backend

namespace infinicore::op {

// Registers the infiniop implementation with the op's dispatcher during static
// initialization; the bool only exists to give the lambda a place to run.
static bool registered = [] {
    auto &dispatch_table = RandomSample::dispatcher();
    // NOTE(review): the trailing `false` is presumably an "override existing" flag — confirm.
    dispatch_table.registerAll(&random_sample_impl::infiniop_backend::calculate, false);
    return true;
}();

} // namespace infinicore::op
2 changes: 2 additions & 0 deletions src/infinicore/pybind11/ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "ops/causal_softmax.hpp"
#include "ops/matmul.hpp"
#include "ops/mul.hpp"
#include "ops/random_sample.hpp"
#include "ops/rearrange.hpp"
#include "ops/rms_norm.hpp"
#include "ops/silu.hpp"
Expand All @@ -20,6 +21,7 @@ inline void bind(py::module &m) {
bind_add(m);
bind_attention(m);
bind_causal_softmax(m);
bind_random_sample(m);
bind_matmul(m);
bind_mul(m);
bind_rearrange(m);
Expand Down
32 changes: 32 additions & 0 deletions src/infinicore/pybind11/ops/random_sample.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#pragma once

#include <pybind11/pybind11.h>

#include "infinicore/ops/random_sample.hpp"

namespace py = pybind11;

namespace infinicore::ops {

// Exposes the out-of-place (`random_sample`) and in-place (`random_sample_`)
// sampling functions on the given Python module, with keyword-argument names
// matching the C++ parameter names.
inline void bind_random_sample(py::module &m) {
    const char *const out_of_place_doc = R"doc(Random sampling: returns an int32 scalar index.)doc";
    const char *const in_place_doc = R"doc(In-place random sampling into provided int32 scalar tensor.)doc";

    // Out-of-place: the result tensor is created by the C++ side.
    m.def("random_sample", &op::random_sample,
          py::arg("logits"),
          py::arg("random_val"), py::arg("topp"), py::arg("topk"), py::arg("temperature"),
          out_of_place_doc);

    // In-place: the caller supplies the `indices` tensor that receives the result.
    m.def("random_sample_", &op::random_sample_,
          py::arg("indices"), py::arg("logits"),
          py::arg("random_val"), py::arg("topp"), py::arg("topk"), py::arg("temperature"),
          in_place_doc);
}

} // namespace infinicore::ops
Loading