@@ -32,10 +32,9 @@ void forced_align_impl(
32
32
const auto L = targets.size (1 );
33
33
const auto S = 2 * L + 1 ;
34
34
35
- auto alphas_a = new scalar_t [S][2 ]; // scalar_t is just logProbs.dtype()
36
- for (int i = 0 ; i < S; i++) {
37
- alphas_a[i][0 ] = kNegInfinity ;
38
- alphas_a[i][1 ] = kNegInfinity ;
35
+ auto alphas_a = new scalar_t [2 * S]; // scalar_t is just logProbs.dtype()
36
+ for (int i = 0 ; i < 2 * S; i++) {
37
+ alphas_a[i] = kNegInfinity ;
39
38
}
40
39
41
40
auto backPtr_a = new int8_t [T * S];
@@ -64,7 +63,8 @@ void forced_align_impl(
64
63
auto end = (S == 1 ) ? 1 : 2 ;
65
64
for (auto i = start; i < end; i++) {
66
65
auto labelIdx = (i % 2 == 0 ) ? blank : targets_a.index (batchIndex, i / 2 );
67
- alphas_a[i][0 ] = logProbs_a.index (batchIndex,0 ,labelIdx);
66
+ alphas_a[i] = logProbs_a.index (batchIndex,0 ,labelIdx);
67
+
68
68
}
69
69
for (auto t = 1 ; t < T; t++) {
70
70
if (T - t <= L + R) {
@@ -87,18 +87,18 @@ void forced_align_impl(
87
87
auto curIdxOffset = t % 2 ;
88
88
auto prevIdxOffset = (t - 1 ) % 2 ;
89
89
for (auto j = 0 ; j < S; ++j) {
90
- alphas_a[j][curIdxOffset] = -std::numeric_limits<scalar_t >::infinity ();
90
+ alphas_a[curIdxOffset * S + j] = -std::numeric_limits<scalar_t >::infinity (); // alphas_a[curIdxOffset][j]
91
91
}
92
92
if (start == 0 ) {
93
- alphas_a[0 ][ curIdxOffset] =
94
- alphas_a[0 ][ prevIdxOffset] + logProbs_a.index (batchIndex, t, blank);
93
+ alphas_a[curIdxOffset * S ] =
94
+ alphas_a[prevIdxOffset * S ] + logProbs_a.index (batchIndex, t, blank);
95
95
backPtr_a[S * t] = 0 ; // backPtr_a[t][0] = 0
96
96
startloop += 1 ;
97
97
}
98
98
99
99
for (auto i = startloop; i < end; i++) {
100
- auto x0 = alphas_a[i] [prevIdxOffset];
101
- auto x1 = alphas_a[i - 1 ][prevIdxOffset];
100
+ auto x0 = alphas_a[prevIdxOffset * S + i]; // alphas_a [prevIdxOffset][i ];
101
+ auto x1 = alphas_a[prevIdxOffset * S + i - 1 ]; // alphas_a [prevIdxOffset][i - 1 ];
102
102
auto x2 = -std::numeric_limits<scalar_t >::infinity ();
103
103
104
104
auto labelIdx = (i % 2 == 0 ) ? blank : targets_a.index (batchIndex, i / 2 );
@@ -109,7 +109,7 @@ void forced_align_impl(
109
109
// (i != 1) just ensures we don't access targets[i - 2] if its i < 2
110
110
if (i % 2 != 0 && i != 1 &&
111
111
targets_a.index (batchIndex, i / 2 ) != targets_a.index (batchIndex, i / 2 - 1 )) {
112
- x2 = alphas_a[i - 2 ][prevIdxOffset];
112
+ x2 = alphas_a[prevIdxOffset * S + i - 2 ]; // alphas_a [prevIdxOffset][i - 2 ];
113
113
}
114
114
scalar_t result = 0.0 ;
115
115
if (x2 > x1 && x2 > x0) {
@@ -122,11 +122,13 @@ void forced_align_impl(
122
122
result = x0;
123
123
backPtr_a[t * S + i] = 0 ; // backPtr_a[t][i] = 0
124
124
}
125
- alphas_a[i][curIdxOffset] = result + logProbs_a.index (batchIndex, t, labelIdx);
125
+
126
+ alphas_a[curIdxOffset * S + i] = result + logProbs_a.index (batchIndex, t, labelIdx); // alphas_a[curIdxOffset][i]
126
127
}
127
128
}
128
129
auto idx1 = (T - 1 ) % 2 ;
129
- auto ltrIdx = alphas_a[S - 1 ][idx1] > alphas_a[S - 2 ][idx1] ? S - 1 : S - 2 ;
130
+ auto ltrIdx = alphas_a[S * idx1 + S - 1 ] >
131
+ alphas_a[S * idx1 + S - 2 ] ? S - 1 : S - 2 ; // alphas_a[idx1][S - 1], alphas_a[idx1][S - 2]
130
132
delete[] alphas_a;
131
133
// path stores the token index for each time step after force alignment.
132
134
for (auto t = T - 1 ; t > -1 ; t--) {
0 commit comments