Skip to content

Commit 724606a

Browse files
committed
Merge branch 'forced_align_backptr' into forced_align_accessors
2 parents 847b726 + 5fa467d commit 724606a

File tree

1 file changed

+15
-13
lines changed

1 file changed

+15
-13
lines changed

src/libtorchaudio/forced_align/cpu/compute.cpp

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,9 @@ void forced_align_impl(
3232
const auto L = targets.size(1);
3333
const auto S = 2 * L + 1;
3434

35-
auto alphas_a = new scalar_t[S][2]; // scalar_t is just logProbs.dtype()
36-
for (int i = 0; i < S; i++) {
37-
alphas_a[i][0] = kNegInfinity;
38-
alphas_a[i][1] = kNegInfinity;
35+
auto alphas_a = new scalar_t[2 * S]; // scalar_t is just logProbs.dtype()
36+
for (int i = 0; i < 2 * S; i++) {
37+
alphas_a[i] = kNegInfinity;
3938
}
4039

4140
auto backPtr_a = new int8_t[T * S];
@@ -64,7 +63,8 @@ void forced_align_impl(
6463
auto end = (S == 1) ? 1 : 2;
6564
for (auto i = start; i < end; i++) {
6665
auto labelIdx = (i % 2 == 0) ? blank : targets_a.index(batchIndex, i / 2);
67-
alphas_a[i][0] = logProbs_a.index(batchIndex,0,labelIdx);
66+
alphas_a[i] = logProbs_a.index(batchIndex,0,labelIdx);
67+
6868
}
6969
for (auto t = 1; t < T; t++) {
7070
if (T - t <= L + R) {
@@ -87,18 +87,18 @@ void forced_align_impl(
8787
auto curIdxOffset = t % 2;
8888
auto prevIdxOffset = (t - 1) % 2;
8989
for (auto j = 0; j < S; ++j) {
90-
alphas_a[j][curIdxOffset] = -std::numeric_limits<scalar_t>::infinity();
90+
alphas_a[curIdxOffset * S + j] = -std::numeric_limits<scalar_t>::infinity(); // alphas_a[curIdxOffset][j]
9191
}
9292
if (start == 0) {
93-
alphas_a[0][curIdxOffset] =
94-
alphas_a[0][prevIdxOffset] + logProbs_a.index(batchIndex, t, blank);
93+
alphas_a[curIdxOffset * S] =
94+
alphas_a[prevIdxOffset * S] + logProbs_a.index(batchIndex, t, blank);
9595
backPtr_a[S * t] = 0; // backPtr_a[t][0] = 0
9696
startloop += 1;
9797
}
9898

9999
for (auto i = startloop; i < end; i++) {
100-
auto x0 = alphas_a[i][prevIdxOffset];
101-
auto x1 = alphas_a[i - 1][prevIdxOffset];
100+
auto x0 = alphas_a[prevIdxOffset * S + i]; // alphas_a[prevIdxOffset][i];
101+
auto x1 = alphas_a[prevIdxOffset * S + i - 1]; // alphas_a[prevIdxOffset][i - 1];
102102
auto x2 = -std::numeric_limits<scalar_t>::infinity();
103103

104104
auto labelIdx = (i % 2 == 0) ? blank : targets_a.index(batchIndex, i / 2);
@@ -109,7 +109,7 @@ void forced_align_impl(
109109
// (i != 1) just ensures we don't access targets[i - 2] if its i < 2
110110
if (i % 2 != 0 && i != 1 &&
111111
targets_a.index(batchIndex, i / 2) != targets_a.index(batchIndex, i / 2 - 1)) {
112-
x2 = alphas_a[i - 2][prevIdxOffset];
112+
x2 = alphas_a[prevIdxOffset * S + i - 2]; // alphas_a[prevIdxOffset][i - 2];
113113
}
114114
scalar_t result = 0.0;
115115
if (x2 > x1 && x2 > x0) {
@@ -122,11 +122,13 @@ void forced_align_impl(
122122
result = x0;
123123
backPtr_a[t * S + i] = 0; // backPtr_a[t][i] = 0
124124
}
125-
alphas_a[i][curIdxOffset] = result + logProbs_a.index(batchIndex, t, labelIdx);
125+
126+
alphas_a[curIdxOffset * S + i] = result + logProbs_a.index(batchIndex, t, labelIdx); // alphas_a[curIdxOffset][i]
126127
}
127128
}
128129
auto idx1 = (T - 1) % 2;
129-
auto ltrIdx = alphas_a[S - 1][idx1] > alphas_a[S - 2][idx1] ? S - 1 : S - 2;
130+
auto ltrIdx = alphas_a[S * idx1 + S - 1] >
131+
alphas_a[S * idx1 + S - 2] ? S - 1 : S - 2; // alphas_a[idx1][S - 1], alphas_a[idx1][S - 2]
130132
delete[] alphas_a;
131133
// path stores the token index for each time step after force alignment.
132134
for (auto t = T - 1; t > -1; t--) {

0 commit comments

Comments
 (0)