Commit 1eb155e

enable hot-word boosting (#3297)
* enable hot-word boosting
* more consistent ordering of CLI arguments
* progress on review
* use map instead of set for hot-words, move string logic to client.cc
* typo bug
* pointer things?
* use map for hotwords, better string splitting
* add the boost, not multiply
* cleaning up
* cleaning whitespace
* remove <set> inclusion
* change typo set-->map
* rename boost_coefficient to boost X-DeepSpeech: NOBUILD
* add hot_words to python bindings
* missing hot_words
* include map in swigwrapper.i
* add Map template to swigwrapper.i
* emacs intermediate file
* map things
* map-->unordered_map
* typu
* typu
* use dict() not None
* error out if hot_words without scorer
* two new functions: remove hot-word and clear all hot-words
* starting to work on better error messages X-DeepSpeech: NOBUILD
* better error handling + .Net ERR codes
* allow for negative boosts:)
* adding TC test for hot-words
* add hot-words to python client, make TC test hot-words everywhere
* only run TC tests for C++ and Python
* fully expose API in python bindings
* expose API in Java (thanks spectie!)
* expose API in dotnet (thanks spectie!)
* expose API in javascript (thanks spectie!)
* java lol
* typo in javascript
* commenting
* java error codes from swig
* java docs from SWIG
* java and dotnet issues
* add hotword test to android tests
* dotnet fixes from carlos
* add DS_BINARY_PREFIX to tc-asserts.sh for hotwords command
* make sure lm is on android for hotword test
* path to android model + nit
* path
* path
1 parent d466fb0 commit 1eb155e

25 files changed (+400, -11 lines)
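
In practice the feature is driven either through the C++ client's new --hot_words flag or through the language bindings. The following is a minimal, illustrative sketch using the Python package; it is not code from the commit, the file paths are placeholders, and the method names addHotWord, eraseHotWord and clearHotWords are assumed to be the bindings generated by this change.

    # Minimal sketch (paths are placeholders, binding method names assumed).
    import wave
    import numpy as np
    from deepspeech import Model

    ds = Model("deepspeech.pbmm")
    ds.enableExternalScorer("kenlm.scorer")   # hot-words require an external scorer
    ds.addHotWord("firefox", 7.5)             # positive boost favours the word
    ds.addHotWord("waterfox", -5.0)           # negative boosts are allowed and penalise it

    with wave.open("audio.wav", "rb") as fin:
        audio = np.frombuffer(fin.readframes(fin.getnframes()), np.int16)

    print(ds.stt(audio))
    ds.clearHotWords()                        # drop all boosts again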

native_client/args.h (+9, -1)
@@ -38,6 +38,8 @@ int json_candidate_transcripts = 3;
 
 int stream_size = 0;
 
+char* hot_words = NULL;
+
 void PrintHelp(const char* bin)
 {
     std::cout <<
@@ -56,6 +58,7 @@ void PrintHelp(const char* bin)
     "\t--json\t\t\t\tExtended output, shows word timings as JSON\n"
     "\t--candidate_transcripts NUMBER\tNumber of candidate transcripts to include in JSON output\n"
     "\t--stream size\t\t\tRun in stream mode, output intermediate results\n"
+    "\t--hot_words\t\t\tHot-words and their boosts. Word:Boost pairs are comma-separated\n"
     "\t--help\t\t\t\tShow help\n"
     "\t--version\t\t\tPrint version and exits\n";
   char* version = DS_Version();
@@ -66,7 +69,7 @@ void PrintHelp(const char* bin)
 
 bool ProcessArgs(int argc, char** argv)
 {
-  const char* const short_opts = "m:l:a:b:c:d:tejs:vh";
+  const char* const short_opts = "m:l:a:b:c:d:tejs:w:vh";
   const option long_opts[] = {
     {"model", required_argument, nullptr, 'm'},
     {"scorer", required_argument, nullptr, 'l'},
@@ -79,6 +82,7 @@ bool ProcessArgs(int argc, char** argv)
     {"json", no_argument, nullptr, 'j'},
     {"candidate_transcripts", required_argument, nullptr, 150},
     {"stream", required_argument, nullptr, 's'},
+    {"hot_words", required_argument, nullptr, 'w'},
     {"version", no_argument, nullptr, 'v'},
     {"help", no_argument, nullptr, 'h'},
     {nullptr, no_argument, nullptr, 0}
@@ -144,6 +148,10 @@ bool ProcessArgs(int argc, char** argv)
        has_versions = true;
        break;
 
+      case 'w':
+        hot_words = optarg;
+        break;
+
      case 'h': // -h or --help
      case '?': // Unrecognized option
      default:
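
The new option is registered both as the long flag --hot_words and the short flag -w, and takes a single string of comma-separated Word:Boost pairs, for example --hot_words "friday:8.5,monday:-2.0" (the words and values here are illustrative only). Parsing of that string into individual DS_AddHotWord calls happens in client.cc below.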

native_client/client.cc (+33)
@@ -390,6 +390,22 @@ ProcessFile(ModelState* context, const char* path, bool show_times)
   }
 }
 
+std::vector<std::string>
+SplitStringOnDelim(std::string in_string, std::string delim)
+{
+  std::vector<std::string> out_vector;
+  char * tmp_str = new char[in_string.size() + 1];
+  std::copy(in_string.begin(), in_string.end(), tmp_str);
+  tmp_str[in_string.size()] = '\0';
+  const char* token = strtok(tmp_str, delim.c_str());
+  while( token != NULL ) {
+    out_vector.push_back(token);
+    token = strtok(NULL, delim.c_str());
+  }
+  delete[] tmp_str;
+  return out_vector;
+}
+
 int
 main(int argc, char **argv)
 {
@@ -432,6 +448,23 @@ main(int argc, char **argv)
   }
   // sphinx-doc: c_ref_model_stop
 
+  if (hot_words) {
+    std::vector<std::string> hot_words_ = SplitStringOnDelim(hot_words, ",");
+    for ( std::string hot_word_ : hot_words_ ) {
+      std::vector<std::string> pair_ = SplitStringOnDelim(hot_word_, ":");
+      const char* word = (pair_[0]).c_str();
+      // the strtof function will return 0 in case of non numeric characters
+      // so, check the boost string before we turn it into a float
+      bool boost_is_valid = (pair_[1].find_first_not_of("-.0123456789") == std::string::npos);
+      float boost = strtof((pair_[1]).c_str(),0);
+      status = DS_AddHotWord(ctx, word, boost);
+      if (status != 0 || !boost_is_valid) {
+        fprintf(stderr, "Could not enable hot-word.\n");
+        return 1;
+      }
+    }
+  }
+
 #ifndef NO_SOX
   // Initialise SOX
   assert(sox_init() == SOX_SUCCESS);
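
For readers who want to pre-validate a --hot_words string outside the C++ client, the parsing and validation above can be mirrored in a few lines of Python. This is an illustrative sketch only, not code from the commit, and the helper name parse_hot_words is made up.

    # Illustrative sketch: mirrors SplitStringOnDelim plus the
    # find_first_not_of("-.0123456789") check done before DS_AddHotWord.
    def parse_hot_words(arg):
        """Parse 'word:boost,word:boost' into {word: float(boost)}."""
        boosts = {}
        for pair in arg.split(","):
            word, _, boost = pair.partition(":")
            if not boost or set(boost) - set("-.0123456789"):
                raise ValueError("Could not enable hot-word: " + pair)
            boosts[word] = float(boost)
        return boosts

    print(parse_hot_words("friday:8.5,monday:-2.0"))  # {'friday': 8.5, 'monday': -2.0}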

native_client/ctcdecode/__init__.py (+8, -2)
@@ -96,6 +96,7 @@ def ctc_beam_search_decoder(probs_seq,
                             cutoff_prob=1.0,
                             cutoff_top_n=40,
                             scorer=None,
+                            hot_words=dict(),
                             num_results=1):
     """Wrapper for the CTC Beam Search Decoder.
 
@@ -116,6 +117,8 @@ def ctc_beam_search_decoder(probs_seq,
     :param scorer: External scorer for partially decoded sentence, e.g. word
                    count or language model.
     :type scorer: Scorer
+    :param hot_words: Map of words (keys) to their assigned boosts (values)
+    :type hot_words: map{string:float}
     :param num_results: Number of beams to return.
     :type num_results: int
     :return: List of tuples of confidence and sentence as decoding
@@ -124,7 +127,7 @@ def ctc_beam_search_decoder(probs_seq,
     """
     beam_results = swigwrapper.ctc_beam_search_decoder(
         probs_seq, alphabet, beam_size, cutoff_prob, cutoff_top_n,
-        scorer, num_results)
+        scorer, hot_words, num_results)
     beam_results = [(res.confidence, alphabet.Decode(res.tokens)) for res in beam_results]
     return beam_results
 
@@ -137,6 +140,7 @@ def ctc_beam_search_decoder_batch(probs_seq,
                                   cutoff_prob=1.0,
                                   cutoff_top_n=40,
                                   scorer=None,
+                                  hot_words=dict(),
                                   num_results=1):
     """Wrapper for the batched CTC beam search decoder.
 
@@ -161,13 +165,15 @@ def ctc_beam_search_decoder_batch(probs_seq,
     :param scorer: External scorer for partially decoded sentence, e.g. word
                    count or language model.
     :type scorer: Scorer
+    :param hot_words: Map of words (keys) to their assigned boosts (values)
+    :type hot_words: map{string:float}
     :param num_results: Number of beams to return.
     :type num_results: int
     :return: List of tuples of confidence and sentence as decoding
              results, in descending order of the confidence.
     :rtype: list
     """
-    batch_beam_results = swigwrapper.ctc_beam_search_decoder_batch(probs_seq, seq_lengths, alphabet, beam_size, num_processes, cutoff_prob, cutoff_top_n, scorer, num_results)
+    batch_beam_results = swigwrapper.ctc_beam_search_decoder_batch(probs_seq, seq_lengths, alphabet, beam_size, num_processes, cutoff_prob, cutoff_top_n, scorer, hot_words, num_results)
     batch_beam_results = [
         [(res.confidence, alphabet.Decode(res.tokens)) for res in beam_results]
         for beam_results in batch_beam_results
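
Because hot_words is threaded through the SWIG wrapper, the decoder package accepts a plain Python dict. The sketch below shows a direct call; it assumes probs is a (time_steps x num_classes) probability matrix produced by the acoustic model and that Alphabet and Scorer are constructed as in the training/evaluation code, so treat the constructor arguments and paths as placeholders rather than this commit's API documentation.

    # Sketch only: direct use of the ds_ctcdecoder package with hot-words.
    from ds_ctcdecoder import Alphabet, Scorer, ctc_beam_search_decoder

    alphabet = Alphabet("alphabet.txt")                    # placeholder path
    scorer = Scorer(0.93, 1.18, "kenlm.scorer", alphabet)  # alpha, beta, scorer path (assumed)
    hot_words = {"friday": 8.5, "monday": -2.0}            # dict converts to the SWIG map

    # `probs` is assumed: softmax output of shape (time_steps, len(alphabet) + 1).
    results = ctc_beam_search_decoder(probs, alphabet, beam_size=1024,
                                      scorer=scorer, hot_words=hot_words,
                                      num_results=5)
    for confidence, text in results:
        print(confidence, text)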

native_client/ctcdecode/ctc_beam_search_decoder.cpp (+24, -4)
@@ -4,7 +4,7 @@
 #include <cmath>
 #include <iostream>
 #include <limits>
-#include <map>
+#include <unordered_map>
 #include <utility>
 
 #include "decoder_utils.h"
@@ -18,7 +18,8 @@ DecoderState::init(const Alphabet& alphabet,
                    size_t beam_size,
                    double cutoff_prob,
                    size_t cutoff_top_n,
-                   std::shared_ptr<Scorer> ext_scorer)
+                   std::shared_ptr<Scorer> ext_scorer,
+                   std::unordered_map<std::string, float> hot_words)
 {
   // assign special ids
   abs_time_step_ = 0;
@@ -29,6 +30,7 @@ DecoderState::init(const Alphabet& alphabet,
   cutoff_prob_ = cutoff_prob;
   cutoff_top_n_ = cutoff_top_n;
   ext_scorer_ = ext_scorer;
+  hot_words_ = hot_words;
   start_expanding_ = false;
 
   // init prefixes' root
@@ -160,8 +162,23 @@ DecoderState::next(const double *probs,
          float score = 0.0;
          std::vector<std::string> ngram;
          ngram = ext_scorer_->make_ngram(prefix_to_score);
+
+          float hot_boost = 0.0;
+          if (!hot_words_.empty()) {
+            std::unordered_map<std::string, float>::iterator iter;
+            // increase prob of prefix for every word
+            // that matches a word in the hot-words list
+            for (std::string word : ngram) {
+              iter = hot_words_.find(word);
+              if ( iter != hot_words_.end() ) {
+                // increase the log_cond_prob(prefix|LM)
+                hot_boost += iter->second;
+              }
+            }
+          }
+
          bool bos = ngram.size() < ext_scorer_->get_max_order();
-          score = ext_scorer_->get_log_cond_prob(ngram, bos) * ext_scorer_->alpha;
+          score = ( ext_scorer_->get_log_cond_prob(ngram, bos) + hot_boost ) * ext_scorer_->alpha;
          log_p += score;
          log_p += ext_scorer_->beta;
        }
@@ -256,11 +273,12 @@ std::vector<Output> ctc_beam_search_decoder(
     double cutoff_prob,
     size_t cutoff_top_n,
     std::shared_ptr<Scorer> ext_scorer,
+    std::unordered_map<std::string, float> hot_words,
     size_t num_results)
 {
   VALID_CHECK_EQ(alphabet.GetSize()+1, class_dim, "Number of output classes in acoustic model does not match number of labels in the alphabet file. Alphabet file must be the same one that was used to train the acoustic model.");
   DecoderState state;
-  state.init(alphabet, beam_size, cutoff_prob, cutoff_top_n, ext_scorer);
+  state.init(alphabet, beam_size, cutoff_prob, cutoff_top_n, ext_scorer, hot_words);
   state.next(probs, time_dim, class_dim);
   return state.decode(num_results);
 }
@@ -279,6 +297,7 @@ ctc_beam_search_decoder_batch(
     double cutoff_prob,
     size_t cutoff_top_n,
     std::shared_ptr<Scorer> ext_scorer,
+    std::unordered_map<std::string, float> hot_words,
     size_t num_results)
 {
   VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!");
@@ -298,6 +317,7 @@ ctc_beam_search_decoder_batch(
                    cutoff_prob,
                    cutoff_top_n,
                    ext_scorer,
+                   hot_words,
                    num_results));
   }
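
The decoding change is purely additive in log-space: the accumulated hot_boost is added to the scorer's log conditional probability before the alpha weight is applied, so each matched hot-word shifts the prefix score by roughly alpha times its boost. A toy calculation with made-up numbers:

    # Toy numbers only, to show the effect of the additive boost on a prefix score.
    alpha, beta = 0.93, 1.18          # scorer weights (illustrative)
    lm_log_prob = -6.2                # log P(ngram | LM) from the scorer (illustrative)
    hot_boost = 8.5                   # boost of one matched hot-word in the n-gram

    without_boost = lm_log_prob * alpha + beta
    with_boost = (lm_log_prob + hot_boost) * alpha + beta
    print(without_boost, with_boost)  # the boosted prefix gains alpha * hot_boost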

native_client/ctcdecode/ctc_beam_search_decoder.h (+9, -1)
@@ -22,6 +22,7 @@ class DecoderState {
   std::vector<PathTrie*> prefixes_;
   std::unique_ptr<PathTrie> prefix_root_;
   TimestepTreeNode timestep_tree_root_{nullptr, 0};
+  std::unordered_map<std::string, float> hot_words_;
 
 public:
   DecoderState() = default;
@@ -48,7 +49,8 @@ class DecoderState {
            size_t beam_size,
            double cutoff_prob,
            size_t cutoff_top_n,
-           std::shared_ptr<Scorer> ext_scorer);
+           std::shared_ptr<Scorer> ext_scorer,
+           std::unordered_map<std::string, float> hot_words);
 
   /* Send data to the decoder
    *
@@ -88,6 +90,8 @@ class DecoderState {
 *     ext_scorer: External scorer to evaluate a prefix, which consists of
 *                 n-gram language model scoring and word insertion term.
 *                 Default null, decoding the input sample without scorer.
+*     hot_words: A map of hot-words and their corresponding boosts
+*                The hot-word is a string and the boost is a float.
 *     num_results: Number of beams to return.
 * Return:
 *     A vector where each element is a pair of score and decoding result,
@@ -103,6 +107,7 @@ std::vector<Output> ctc_beam_search_decoder(
     double cutoff_prob,
     size_t cutoff_top_n,
     std::shared_ptr<Scorer> ext_scorer,
+    std::unordered_map<std::string, float> hot_words,
     size_t num_results=1);
 
 /* CTC Beam Search Decoder for batch data
@@ -117,6 +122,8 @@ std::vector<Output> ctc_beam_search_decoder(
 *     ext_scorer: External scorer to evaluate a prefix, which consists of
 *                 n-gram language model scoring and word insertion term.
 *                 Default null, decoding the input sample without scorer.
+*     hot_words: A map of hot-words and their corresponding boosts
+*                The hot-word is a string and the boost is a float.
 *     num_results: Number of beams to return.
 * Return:
 *     A 2-D vector where each element is a vector of beam search decoding
@@ -136,6 +143,7 @@ ctc_beam_search_decoder_batch(
     double cutoff_prob,
     size_t cutoff_top_n,
     std::shared_ptr<Scorer> ext_scorer,
+    std::unordered_map<std::string, float> hot_words,
     size_t num_results=1);
 
 #endif // CTC_BEAM_SEARCH_DECODER_H_

native_client/ctcdecode/swigwrapper.i (+2)
@@ -11,6 +11,7 @@
 %include <std_string.i>
 %include <std_vector.i>
 %include <std_shared_ptr.i>
+%include <std_unordered_map.i>
 %include "numpy.i"
 
 %init %{
@@ -22,6 +23,7 @@ namespace std {
     %template(UnsignedIntVector) vector<unsigned int>;
     %template(OutputVector) vector<Output>;
     %template(OutputVectorVector) vector<vector<Output>>;
+    %template(Map) unordered_map<string, float>;
 }
 
 %shared_ptr(Scorer);
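
Instantiating std::unordered_map<string, float> through %template(Map) means the generated Python module gains a dict-like Map proxy, and plain Python dicts are converted automatically where the wrapper expects one. A quick, illustrative check, with the module layout assumed rather than taken from this commit:

    # Illustrative: the SWIG-generated proxy behaves like a mutable mapping.
    from ds_ctcdecoder import swigwrapper

    hot_words = swigwrapper.Map()      # proxy for std::unordered_map<std::string, float>
    hot_words["firefox"] = 7.5
    print(len(hot_words), hot_words["firefox"])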

native_client/deepspeech.cc (+49, -1)
@@ -342,6 +342,53 @@ DS_EnableExternalScorer(ModelState* aCtx,
   return DS_ERR_OK;
 }
 
+int
+DS_AddHotWord(ModelState* aCtx,
+              const char* word,
+              float boost)
+{
+  if (aCtx->scorer_) {
+    const int size_before = aCtx->hot_words_.size();
+    aCtx->hot_words_.insert( std::pair<std::string,float> (word, boost) );
+    const int size_after = aCtx->hot_words_.size();
+    if (size_before == size_after) {
+      return DS_ERR_FAIL_INSERT_HOTWORD;
+    }
+    return DS_ERR_OK;
+  }
+  return DS_ERR_SCORER_NOT_ENABLED;
+}
+
+int
+DS_EraseHotWord(ModelState* aCtx,
+                const char* word)
+{
+  if (aCtx->scorer_) {
+    const int size_before = aCtx->hot_words_.size();
+    int err = aCtx->hot_words_.erase(word);
+    const int size_after = aCtx->hot_words_.size();
+    if (size_before == size_after) {
+      return DS_ERR_FAIL_ERASE_HOTWORD;
+    }
+    return DS_ERR_OK;
+  }
+  return DS_ERR_SCORER_NOT_ENABLED;
+}
+
+int
+DS_ClearHotWords(ModelState* aCtx)
+{
+  if (aCtx->scorer_) {
+    aCtx->hot_words_.clear();
+    const int size_after = aCtx->hot_words_.size();
+    if (size_after != 0) {
+      return DS_ERR_FAIL_CLEAR_HOTWORD;
+    }
+    return DS_ERR_OK;
+  }
+  return DS_ERR_SCORER_NOT_ENABLED;
+}
+
 int
 DS_DisableExternalScorer(ModelState* aCtx)
 {
@@ -390,7 +437,8 @@ DS_CreateStream(ModelState* aCtx,
                              aCtx->beam_width_,
                              cutoff_prob,
                              cutoff_top_n,
-                             aCtx->scorer_);
+                             aCtx->scorer_,
+                             aCtx->hot_words_);
 
   *retval = ctx.release();
   return DS_ERR_OK;
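
The three new C API entry points all require an external scorer: without one they return DS_ERR_SCORER_NOT_ENABLED, and insert/erase failures map to DS_ERR_FAIL_INSERT_HOTWORD and DS_ERR_FAIL_ERASE_HOTWORD. The sketch below illustrates that contract from Python; it assumes the binding raises RuntimeError on non-zero error codes, as other Model methods do, and uses placeholder file paths.

    # Sketch only: order of operations for the hot-word API (paths are placeholders).
    from deepspeech import Model

    ds = Model("deepspeech.pbmm")
    try:
        ds.addHotWord("friday", 8.5)       # no scorer yet -> scorer-not-enabled error
    except RuntimeError as err:
        print("expected failure:", err)

    ds.enableExternalScorer("kenlm.scorer")
    ds.addHotWord("friday", 8.5)           # succeeds once a scorer is enabled
    ds.eraseHotWord("friday")              # would fail if the word had not been added
    ds.clearHotWords()                     # removes every remaining boost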
