diff --git a/common/common.h b/common/common.h
index a564b3b8c2b..4e43ccac193 100644
--- a/common/common.h
+++ b/common/common.h
@@ -342,9 +342,12 @@ struct common_params_speculative_ngram_map {
     uint16_t min_hits = 1; // minimum hits at ngram/mgram lookup for mgram to be proposed
 };
 
-struct common_params_speculative_ngram_cache {
+struct common_params_speculative_ngram_cache : common_params_speculative_ngram_map {
     std::string lookup_cache_static; // path of static ngram cache file for lookup decoding
     std::string lookup_cache_dynamic; // path of dynamic ngram cache file for lookup decoding
+
+    bool save_lookup_cache_static = false; // whether or not we should save the static ngram cache file // NOLINT
+    bool save_lookup_cache_dynamic = false; // whether or not we should save the dynamic ngram cache file // NOLINT
 };
 
 struct common_params_speculative {
diff --git a/common/speculative.cpp b/common/speculative.cpp
index bbf88fa6e71..74a59a562e2 100644
--- a/common/speculative.cpp
+++ b/common/speculative.cpp
@@ -809,6 +809,9 @@ struct common_speculative_state_ngram_cache : public common_speculative_state {
     bool save_dynamic;
     bool save_static;
 
+    const std::string path_static; // kept so the destructor knows where to write the static cache
+    const std::string path_dynamic; // kept so the destructor knows where to write the dynamic cache
+
     common_ngram_cache ngram_cache_context;
     common_ngram_cache ngram_cache_dynamic;
     common_ngram_cache ngram_cache_static;
@@ -817,15 +820,17 @@ struct common_speculative_state_ngram_cache : public common_speculative_state {
 
     common_speculative_state_ngram_cache(
             const enum common_speculative_type type,
+            uint16_t n_draft,
             const std::string & path_static,
             const std::string & path_dynamic,
-            uint16_t n_draft,
             bool save_dynamic,
             bool save_static)
         : common_speculative_state(type)
         , n_draft(n_draft)
         , save_dynamic(save_dynamic)
         , save_static(save_static)
+        , path_static(path_static)
+        , path_dynamic(path_dynamic)
     {
         if (!path_static.empty()) {
             try {
@@ -846,6 +851,15 @@ struct common_speculative_state_ngram_cache : public common_speculative_state {
         }
     }
 
+    ~common_speculative_state_ngram_cache() override { // writes each cache back to its path when the matching save flag is set
+        if (save_static) {
+            common_ngram_cache_save(ngram_cache_static, path_static);
+        }
+        if (save_dynamic) {
+            common_ngram_cache_save(ngram_cache_dynamic, path_dynamic);
+        }
+    }
+
     void begin(const llama_tokens & prompt) override {
         GGML_UNUSED(prompt);
     }
@@ -922,16 +936,15 @@ static common_ngram_map get_common_ngram_map(
     return common_ngram_map(size_key, size_value, key_only, min_hits);
 }
 
-static common_speculative_state_ngram_cache create_state_ngram_cache(
-        const std::string & path_static, const std::string & path_dynamic,
-        const common_speculative_config & config) {
-    uint16_t n_draft = 8; // TODO get from config?
-
-    // TODO bool param in common/common.h to set save_static/save_dynamic?
-    bool save_static = false;
-    bool save_dynamic = false;
+static common_speculative_state_ngram_cache create_state_ngram_cache(const common_speculative_config & config) {
 
-    common_speculative_state_ngram_cache state(config.type, path_static, path_dynamic, n_draft, save_static, save_dynamic);
+    common_speculative_state_ngram_cache state(
+            config.type,
+            config.params.ngram_cache.size_n, // used as n_draft; NOTE(review): previously hard-coded to 8 - confirm size_n is the intended draft length
+            config.params.ngram_cache.lookup_cache_static,
+            config.params.ngram_cache.lookup_cache_dynamic,
+            config.params.ngram_cache.save_lookup_cache_dynamic, // the constructor takes save_dynamic first ...
+            config.params.ngram_cache.save_lookup_cache_static); // ... and save_static second
 
     return state;
 }
@@ -1089,7 +1102,7 @@ common_speculative * common_speculative_init(
                 break;
             }
             case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE: {
-                auto state = create_state_ngram_cache(params.ngram_cache.lookup_cache_static, params.ngram_cache.lookup_cache_dynamic, config);
+                auto state = create_state_ngram_cache(config); // NOTE(review): 'state' is copied into make_unique below, so the saving destructor also runs on this local - confirm intended
                 impls.push_back(std::make_unique<common_speculative_state_ngram_cache>(state));
                 break;
             }