-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathutils.py
executable file
·292 lines (257 loc) · 11.4 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
# Mingshuang Luo,
# Zengwei Yao)
#
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
from collections import defaultdict
from pathlib import Path
from typing import Dict, Iterable, List, Optional, TextIO, Tuple, Union
import kaldialign
Pathlike = Union[str, Path]
def store_transcripts(
filename: Pathlike, texts: Iterable[Tuple[str, str, str]]
) -> None:
"""Save predicted results and reference transcripts to a file.
Args:
filename:
File to save the results to.
texts:
An iterable of tuples. The first element is the cur_id, the second is
the reference transcript and the third element is the predicted result.
Returns:
Return None.
"""
with open(filename, "w") as f:
for cut_id, ref, hyp in texts:
print(f"{cut_id}:\tref={ref}", file=f)
print(f"{cut_id}:\thyp={hyp}", file=f)
def write_error_stats(
f: TextIO,
test_set_name: str,
results: List[Tuple[str, str]],
enable_log: bool = True,
) -> float:
"""Write statistics based on predicted results and reference transcripts.
It will write the following to the given file:
- WER
- number of insertions, deletions, substitutions, corrects and total
reference words. For example::
Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606
reference words (2337 correct)
- The difference between the reference transcript and predicted result.
An instance is given below::
THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES
The above example shows that the reference word is `EDISON`,
but it is predicted to `ADDISON` (a substitution error).
Another example is::
FOR THE FIRST DAY (SIR->*) I THINK
The reference word `SIR` is missing in the predicted
results (a deletion error).
results:
An iterable of tuples. The first element is the cur_id, the second is
the reference transcript and the third element is the predicted result.
enable_log:
If True, also print detailed WER to the console.
Otherwise, it is written only to the given file.
Returns:
Return None.
"""
subs: Dict[Tuple[str, str], int] = defaultdict(int)
ins: Dict[str, int] = defaultdict(int)
dels: Dict[str, int] = defaultdict(int)
# `words` stores counts per word, as follows:
# corr, ref_sub, hyp_sub, ins, dels
words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
num_corr = 0
ERR = "*"
for cut_id, ref, hyp in results:
ali = kaldialign.align(ref, hyp, ERR)
for ref_word, hyp_word in ali:
if ref_word == ERR:
ins[hyp_word] += 1
words[hyp_word][3] += 1
elif hyp_word == ERR:
dels[ref_word] += 1
words[ref_word][4] += 1
elif hyp_word != ref_word:
subs[(ref_word, hyp_word)] += 1
words[ref_word][1] += 1
words[hyp_word][2] += 1
else:
words[ref_word][0] += 1
num_corr += 1
ref_len = sum([len(r) for _, r, _ in results])
sub_errs = sum(subs.values())
ins_errs = sum(ins.values())
del_errs = sum(dels.values())
tot_errs = sub_errs + ins_errs + del_errs
tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)
if enable_log:
logging.info(
f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} "
f"[{tot_errs} / {ref_len}, {ins_errs} ins, "
f"{del_errs} del, {sub_errs} sub ]"
)
print(f"%WER = {tot_err_rate}", file=f)
print(
f"Errors: {ins_errs} insertions, {del_errs} deletions, "
f"{sub_errs} substitutions, over {ref_len} reference "
f"words ({num_corr} correct)",
file=f,
)
print(
"Search below for sections starting with PER-UTT DETAILS:, "
"SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:",
file=f,
)
print("", file=f)
print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f)
for cut_id, ref, hyp in results:
ali = kaldialign.align(ref, hyp, ERR)
combine_successive_errors = True
if combine_successive_errors:
ali = [[[x], [y]] for x, y in ali]
for i in range(len(ali) - 1):
if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]:
ali[i + 1][0] = ali[i][0] + ali[i + 1][0]
ali[i + 1][1] = ali[i][1] + ali[i + 1][1]
ali[i] = [[], []]
ali = [
[
list(filter(lambda a: a != ERR, x)),
list(filter(lambda a: a != ERR, y)),
]
for x, y in ali
]
ali = list(filter(lambda x: x != [[], []], ali))
ali = [
[
ERR if x == [] else " ".join(x),
ERR if y == [] else " ".join(y),
]
for x, y in ali
]
print(
f"{cut_id}:\t"
+ " ".join(
(
ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})"
for ref_word, hyp_word in ali
)
),
file=f,
)
print("", file=f)
print("SUBSTITUTIONS: count ref -> hyp", file=f)
for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True):
print(f"{count} {ref} -> {hyp}", file=f)
print("", file=f)
print("DELETIONS: count ref", file=f)
for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True):
print(f"{count} {ref}", file=f)
print("", file=f)
print("INSERTIONS: count hyp", file=f)
for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True):
print(f"{count} {hyp}", file=f)
print("", file=f)
print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f)
for _, word, counts in sorted(
[(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
):
(corr, ref_sub, hyp_sub, ins, dels) = counts
tot_errs = ref_sub + hyp_sub + ins + dels
ref_count = corr + ref_sub + dels
hyp_count = corr + hyp_sub + ins
print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
return float(tot_err_rate)
def write_triton_stats(stats, summary_file):
with open(summary_file, "w") as summary_f:
model_stats = stats["model_stats"]
# write a note, the log is from triton_client.get_inference_statistics(), to better human readability
summary_f.write(
"The log is parsing from triton_client.get_inference_statistics(), to better human readability. \n"
)
summary_f.write("To learn more about the log, please refer to: \n")
summary_f.write(
"1. https://github.com/triton-inference-server/server/blob/main/docs/user_guide/metrics.md \n"
)
summary_f.write(
"2. https://github.com/triton-inference-server/server/issues/5374 \n\n"
)
summary_f.write(
"To better improve throughput, we always would like let requests wait in the queue for a while, and then execute them with a larger batch size. \n"
)
summary_f.write(
"However, there is a trade-off between the increased queue time and the increased batch size. \n"
)
summary_f.write(
"You may change 'max_queue_delay_microseconds' and 'preferred_batch_size' in the model configuration file to achieve this. \n"
)
summary_f.write(
"See https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md#delayed-batching for more details. \n\n"
)
for model_state in model_stats:
if "last_inference" not in model_state:
continue
summary_f.write(f"model name is {model_state['name']} \n")
model_inference_stats = model_state["inference_stats"]
total_queue_time_s = int(model_inference_stats["queue"]["ns"]) / 1e9
total_infer_time_s = int(model_inference_stats["compute_infer"]["ns"]) / 1e9
total_input_time_s = int(model_inference_stats["compute_input"]["ns"]) / 1e9
total_output_time_s = (
int(model_inference_stats["compute_output"]["ns"]) / 1e9
)
summary_f.write(
f"queue time {total_queue_time_s:<5.2f} s, compute infer time {total_infer_time_s:<5.2f} s, compute input time {total_input_time_s:<5.2f} s, compute output time {total_output_time_s:<5.2f} s \n" # noqa
)
model_batch_stats = model_state["batch_stats"]
for batch in model_batch_stats:
batch_size = int(batch["batch_size"])
compute_input = batch["compute_input"]
compute_output = batch["compute_output"]
compute_infer = batch["compute_infer"]
batch_count = int(compute_infer["count"])
assert (
compute_infer["count"]
== compute_output["count"]
== compute_input["count"]
)
compute_infer_time_ms = int(compute_infer["ns"]) / 1e6
compute_input_time_ms = int(compute_input["ns"]) / 1e6
compute_output_time_ms = int(compute_output["ns"]) / 1e6
summary_f.write(
f"execuate inference with batch_size {batch_size:<2} total {batch_count:<5} times, total_infer_time {compute_infer_time_ms:<9.2f} ms, avg_infer_time {compute_infer_time_ms:<9.2f}/{batch_count:<5}={compute_infer_time_ms/batch_count:.2f} ms, avg_infer_time_per_sample {compute_infer_time_ms:<9.2f}/{batch_count:<5}/{batch_size}={compute_infer_time_ms/batch_count/batch_size:.2f} ms \n" # noqa
)
# summary_f.write(
# f"input {compute_input_time_ms:<9.2f} ms, avg {compute_input_time_ms/batch_count:.2f} ms, " # noqa
# )
# summary_f.write(
# f"output {compute_output_time_ms:<9.2f} ms, avg {compute_output_time_ms/batch_count:.2f} ms \n" # noqa
# )
def download_and_extract(
target_path: str,
url: str = "https://huggingface.co/csukuangfj/aishell-test-dev-manifests/resolve/main/data_aishell.tar.gz",
):
filename = url.split("/")[-1]
# Download file using wget if it doesn't exist
command = f"wget -nc {url}"
os.system(command)
# Extract file using tar
command = f"tar -xf {filename} -C {target_path}"
os.system(command)
# Delete downloaded file
os.remove(filename)