Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -342,9 +342,12 @@ struct common_params_speculative_ngram_map {
uint16_t min_hits = 1; // minimum hits at ngram/mgram lookup for mgram to be proposed
};

struct common_params_speculative_ngram_cache {
struct common_params_speculative_ngram_cache : common_params_speculative_ngram_map {
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is probably the wrong way of going about this, but I am curious if the same concept of m-gram speculative tokens can be applied in the ngram-cache implemetantion

std::string lookup_cache_static; // path of static ngram cache file for lookup decoding
std::string lookup_cache_dynamic; // path of dynamic ngram cache file for lookup decoding

bool save_lookup_cache_static = false; // whether or not we should save the static ngram cache file // NOLINT
bool save_lookup_cache_dynamic = false; // whether or not we should save the dynamic ngram cache file // NOLINT
};

struct common_params_speculative {
Expand Down
35 changes: 24 additions & 11 deletions common/speculative.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -809,6 +809,9 @@ struct common_speculative_state_ngram_cache : public common_speculative_state {
bool save_dynamic;
bool save_static;

const std::string path_static;
const std::string path_dynamic;

common_ngram_cache ngram_cache_context;
common_ngram_cache ngram_cache_dynamic;
common_ngram_cache ngram_cache_static;
Expand All @@ -817,15 +820,17 @@ struct common_speculative_state_ngram_cache : public common_speculative_state {

common_speculative_state_ngram_cache(
const enum common_speculative_type type,
uint16_t n_draft,
const std::string & path_static,
const std::string & path_dynamic,
uint16_t n_draft,
bool save_dynamic,
bool save_static)
: common_speculative_state(type)
, n_draft(n_draft)
, save_dynamic(save_dynamic)
, save_static(save_static)
, path_static(path_static)
, path_dynamic(path_dynamic)
{
if (!path_static.empty()) {
try {
Expand All @@ -846,6 +851,15 @@ struct common_speculative_state_ngram_cache : public common_speculative_state {
}
}

~common_speculative_state_ngram_cache() override {
if (save_static) {
common_ngram_cache_save(ngram_cache_static, path_static);
}
if (save_dynamic) {
common_ngram_cache_save(ngram_cache_dynamic, path_dynamic);
}
}

void begin(const llama_tokens & prompt) override {
GGML_UNUSED(prompt);
}
Expand Down Expand Up @@ -922,16 +936,15 @@ static common_ngram_map get_common_ngram_map(
return common_ngram_map(size_key, size_value, key_only, min_hits);
}

static common_speculative_state_ngram_cache create_state_ngram_cache(
const std::string & path_static, const std::string & path_dynamic,
const common_speculative_config & config) {
uint16_t n_draft = 8; // TODO get from config?

// TODO bool param in common/common.h to set save_static/save_dynamic?
bool save_static = false;
bool save_dynamic = false;
static common_speculative_state_ngram_cache create_state_ngram_cache(const common_speculative_config & config) {

common_speculative_state_ngram_cache state(config.type, path_static, path_dynamic, n_draft, save_static, save_dynamic);
common_speculative_state_ngram_cache state(
config.type,
config.params.ngram_cache.size_n,
config.params.ngram_cache.lookup_cache_static,
config.params.ngram_cache.lookup_cache_dynamic,
config.params.ngram_cache.save_lookup_cache_static,
config.params.ngram_cache.save_lookup_cache_dynamic);

return state;
}
Expand Down Expand Up @@ -1089,7 +1102,7 @@ common_speculative * common_speculative_init(
break;
}
case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE: {
auto state = create_state_ngram_cache(params.ngram_cache.lookup_cache_static, params.ngram_cache.lookup_cache_dynamic, config);
auto state = create_state_ngram_cache(config);
impls.push_back(std::make_unique<common_speculative_state_ngram_cache>(state));
break;
}
Expand Down