Skip to content

Commit 9449f2e

Browse files
zhuyuegongchensu
authored andcommitted
Add random_sample python interface and tests.
1 parent 943b38e commit 9449f2e

File tree

7 files changed

+412
-0
lines changed

7 files changed

+412
-0
lines changed
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
#pragma once
2+
3+
#include "../device.hpp"
4+
#include "common/op.hpp"
5+
6+
#include "infinicore/tensor.hpp"
7+
8+
namespace infinicore::op {
9+
10+
class RandomSample {
11+
public:
12+
using schema = void (*)(Tensor, Tensor, float, float, int, float);
13+
static void execute(Tensor indices, Tensor logits, float random_val, float topp, int topk, float temperature);
14+
static common::OpDispatcher<schema> &dispatcher();
15+
};
16+
17+
// Out-of-place API
18+
Tensor random_sample(Tensor logits, float random_val, float topp, int topk, float temperature);
19+
// In-place API
20+
void random_sample_(Tensor indices, Tensor logits, float random_val, float topp, int topk, float temperature);
21+
22+
} // namespace infinicore::op
23+
24+
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
from infinicore.lib import _infinicore
2+
from infinicore.tensor import Tensor
3+
4+
5+
def random_sample(logits, random_val, topp, topk, temperature, *, out=None):
6+
if out is None:
7+
return Tensor(
8+
_infinicore.random_sample(
9+
logits._underlying, random_val, topp, topk, temperature
10+
)
11+
)
12+
13+
_infinicore.random_sample_(
14+
out._underlying,
15+
logits._underlying,
16+
random_val,
17+
topp,
18+
topk,
19+
temperature,
20+
)
21+
22+
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
#include "infinicore/ops/random_sample.hpp"
2+
3+
namespace infinicore::op {
4+
5+
common::OpDispatcher<RandomSample::schema> &RandomSample::dispatcher() {
6+
static common::OpDispatcher<RandomSample::schema> dispatcher_;
7+
return dispatcher_;
8+
};
9+
10+
void RandomSample::execute(
11+
Tensor indices, Tensor logits,
12+
float random_val, float topp, int topk, float temperature) {
13+
dispatcher().lookup(context::getDevice().getType())(
14+
indices, logits, random_val, topp, topk, temperature);
15+
}
16+
17+
Tensor random_sample(
18+
Tensor logits,
19+
float random_val,
20+
float topp,
21+
int topk,
22+
float temperature) {
23+
auto indices = Tensor::empty({}, DataType::I32, logits->device());
24+
random_sample_(indices, logits, random_val, topp, topk, temperature);
25+
return indices;
26+
}
27+
28+
void random_sample_(
29+
Tensor indices,
30+
Tensor logits,
31+
float random_val,
32+
float topp,
33+
int topk,
34+
float temperature) {
35+
RandomSample::execute(indices, logits, random_val, topp, topk, temperature);
36+
}
37+
38+
} // namespace infinicore::op
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
#include "../../utils.hpp"
2+
#include "infinicore/common/hash.hpp"
3+
#include "infinicore/ops/common/cache.hpp"
4+
#include "infinicore/ops/random_sample.hpp"
5+
#include <infiniop.h>
6+
7+
namespace infinicore::op::random_sample_impl::infiniop_backend {
8+
9+
thread_local common::OpCache<size_t, infiniopRandomSampleDescriptor_t> caches(
10+
100, // capacity
11+
[](infiniopRandomSampleDescriptor_t &desc) {
12+
if (desc != nullptr) {
13+
INFINICORE_CHECK_ERROR(infiniopDestroyRandomSampleDescriptor(desc));
14+
desc = nullptr;
15+
}
16+
});
17+
18+
static void calculate(
19+
Tensor indices,
20+
Tensor logits,
21+
float random_val,
22+
float topp,
23+
int topk,
24+
float temperature) {
25+
// cache per (result desc + logits desc) on device
26+
size_t seed = hash_combine(indices, logits);
27+
28+
auto device_type = context::getDevice().getType();
29+
auto device_index = context::getDevice().getIndex();
30+
31+
auto &cache = caches.getCache(device_type, device_index);
32+
33+
auto desc_opt = cache.get(seed);
34+
infiniopRandomSampleDescriptor_t desc = nullptr;
35+
36+
if (!desc_opt) {
37+
INFINICORE_CHECK_ERROR(infiniopCreateRandomSampleDescriptor(
38+
context::getInfiniopHandle(), &desc,
39+
indices->desc(), logits->desc()));
40+
cache.put(seed, desc);
41+
} else {
42+
desc = *desc_opt;
43+
}
44+
45+
size_t workspace_size = 0;
46+
INFINICORE_CHECK_ERROR(infiniopGetRandomSampleWorkspaceSize(desc, &workspace_size));
47+
std::shared_ptr<Memory> workspace = context::allocateMemory(workspace_size);
48+
49+
INFINICORE_CHECK_ERROR(infiniopRandomSample(
50+
desc,
51+
workspace->data(), workspace_size,
52+
indices->data(), logits->data(),
53+
random_val, topp, topk, temperature,
54+
context::getStream()));
55+
}
56+
57+
} // namespace infinicore::op::random_sample_impl::infiniop_backend
58+
59+
namespace infinicore::op {
60+
61+
static bool registered = []() {
62+
RandomSample::dispatcher().registerAll(&random_sample_impl::infiniop_backend::calculate, false);
63+
return true;
64+
}();
65+
66+
} // namespace infinicore::op

src/infinicore/pybind11/ops.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include "ops/causal_softmax.hpp"
88
#include "ops/gemm.hpp"
99
#include "ops/matmul.hpp"
10+
#include "ops/random_sample.hpp"
1011
#include "ops/rearrange.hpp"
1112
#include "ops/rms_norm.hpp"
1213
#include "ops/silu.hpp"
@@ -21,6 +22,7 @@ inline void bind(py::module &m) {
2122
bind_attention(m);
2223
bind_causal_softmax(m);
2324
bind_gemm(m);
25+
bind_random_sample(m);
2426
bind_matmul(m);
2527
bind_rearrange(m);
2628
bind_rms_norm(m);
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
#pragma once
2+
3+
#include <pybind11/pybind11.h>
4+
5+
#include "infinicore/ops/random_sample.hpp"
6+
7+
namespace py = pybind11;
8+
9+
namespace infinicore::ops {
10+
11+
inline void bind_random_sample(py::module &m) {
12+
m.def("random_sample",
13+
&op::random_sample,
14+
py::arg("logits"),
15+
py::arg("random_val"),
16+
py::arg("topp"),
17+
py::arg("topk"),
18+
py::arg("temperature"),
19+
R"doc(Random sampling: returns an int32 scalar index.)doc");
20+
21+
m.def("random_sample_",
22+
&op::random_sample_,
23+
py::arg("indices"),
24+
py::arg("logits"),
25+
py::arg("random_val"),
26+
py::arg("topp"),
27+
py::arg("topk"),
28+
py::arg("temperature"),
29+
R"doc(In-place random sampling into provided int32 scalar tensor.)doc");
30+
}
31+
32+
} // namespace infinicore::ops

0 commit comments

Comments
 (0)