Skip to content

Commit 0b6ff52

Browse files
committed
Extract accessor definition into separate PR
1 parent bdd9c72 commit 0b6ff52

File tree

4 files changed

+100
-0
lines changed

4 files changed

+100
-0
lines changed

src/libtorchaudio/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ set(
66
lfilter.cpp
77
overdrive.cpp
88
utils.cpp
9+
accessor_tests.cpp
910
)
1011

1112
set(

src/libtorchaudio/accessor.h

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
#pragma once
2+
3+
#include <torch/csrc/stable/tensor.h>
4+
#include <type_traits>
5+
#include <cstdarg>
6+
7+
using torch::stable::Tensor;
8+
9+
template<unsigned int k, typename T, bool IsConst = true>
10+
class Accessor {
11+
int64_t strides[k];
12+
T *data;
13+
14+
public:
15+
using tensor_type = typename std::conditional<IsConst, const Tensor&, Tensor&>::type;
16+
17+
Accessor(tensor_type tensor) {
18+
data = (T*)tensor.template data_ptr();
19+
for (unsigned int i = 0; i < k; i++) {
20+
strides[i] = tensor.stride(i);
21+
}
22+
}
23+
24+
T index(...) {
25+
va_list args;
26+
va_start(args, k);
27+
int64_t ix = 0;
28+
for (unsigned int i = 0; i < k; i++) {
29+
ix += strides[i] * va_arg(args, int);
30+
}
31+
va_end(args);
32+
return data[ix];
33+
}
34+
35+
template<bool C = IsConst>
36+
typename std::enable_if<!C, void>::type set_index(T value, ...) {
37+
va_list args;
38+
va_start(args, value);
39+
int64_t ix = 0;
40+
for (unsigned int i = 0; i < k; i++) {
41+
ix += strides[i] * va_arg(args, int);
42+
}
43+
va_end(args);
44+
data[ix] = value;
45+
}
46+
};
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
#include <libtorchaudio/accessor.h>
2+
#include <cstdint>
3+
#include <torch/torch.h>
4+
#include <torch/csrc/stable/tensor.h>
5+
#include <torch/csrc/stable/library.h>
6+
7+
namespace torchaudio {
8+
9+
namespace accessor_tests {
10+
11+
using namespace std;
12+
using torch::stable::Tensor;
13+
14+
bool test_accessor(const Tensor tensor) {
15+
int64_t* data_ptr = (int64_t*)tensor.data_ptr();
16+
auto accessor = Accessor<3, int64_t>(tensor);
17+
for (unsigned int i = 0; i < tensor.size(0); i++) {
18+
for (unsigned int j = 0; j < tensor.size(1); j++) {
19+
for (unsigned int k = 0; k < tensor.size(2); k++) {
20+
auto check = *(data_ptr++) == accessor.index(i, j, k);
21+
if (!check) {
22+
return false;
23+
}
24+
}
25+
}
26+
}
27+
return true;
28+
}
29+
30+
void boxed_test_accessor(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
31+
Tensor t1(to<AtenTensorHandle>(stack[0]));
32+
auto result = test_accessor(std::move(t1));
33+
stack[0] = from(result);
34+
}
35+
36+
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
37+
m.def(
38+
"_test_accessor(Tensor log_probs) -> bool");
39+
}
40+
41+
STABLE_TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
42+
m.impl("torchaudio::_test_accessor", &boxed_test_accessor);
43+
}
44+
45+
}
46+
}
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
import torch
2+
from torchaudio._extension import _IS_TORCHAUDIO_EXT_AVAILABLE
3+
4+
if _IS_TORCHAUDIO_EXT_AVAILABLE:
5+
def test_accessor():
6+
tensor = torch.randint(1000, (5,4,3))
7+
assert torch.ops.torchaudio._test_accessor(tensor)

0 commit comments

Comments
 (0)