Skip to content

Commit c82301c

Browse files
committed
control vector support in cli
1 parent 7ec24b4 commit c82301c

File tree

2 files changed

+73
-0
lines changed

2 files changed

+73
-0
lines changed

common/common.cpp

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -562,6 +562,35 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
562562
break;
563563
}
564564
params.lora_base = argv[i];
565+
} else if (arg == "--control-vector") {
566+
if (++i >= argc) {
567+
invalid_param = true;
568+
break;
569+
}
570+
params.control_vectors.push_back(std::make_tuple(argv[i], 1.0f));
571+
} else if (arg == "--control-vector-scaled") {
572+
if (++i >= argc) {
573+
invalid_param = true;
574+
break;
575+
}
576+
const char * control_vector = argv[i];
577+
if (++i >= argc) {
578+
invalid_param = true;
579+
break;
580+
}
581+
params.control_vectors.push_back(std::make_tuple(control_vector, std::stof(argv[i])));
582+
} else if (arg == "--control-vector-layer-range") {
583+
if (++i >= argc) {
584+
invalid_param = true;
585+
break;
586+
}
587+
int32_t start = std::stoi(argv[i]);
588+
if (++i >= argc) {
589+
invalid_param = true;
590+
break;
591+
}
592+
int32_t end = std::stoi(argv[i]);
593+
params.control_vector_layer_range = std::make_tuple(start, end);
565594
} else if (arg == "--mmproj") {
566595
if (++i >= argc) {
567596
invalid_param = true;
@@ -1087,6 +1116,12 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
10871116
printf(" --lora FNAME apply LoRA adapter (implies --no-mmap)\n");
10881117
printf(" --lora-scaled FNAME S apply LoRA adapter with user defined scaling S (implies --no-mmap)\n");
10891118
printf(" --lora-base FNAME optional model to use as a base for the layers modified by the LoRA adapter\n");
1119+
printf(" --control-vector FNAME\n");
1120+
printf(" add a control vector\n");
1121+
printf(" --control-vector-scaled FNAME S\n");
1122+
printf(" add a control vector with user defined scaling S\n");
1123+
printf(" --control-vector-layer-range START END\n");
1124+
printf(" layer range to apply the control vector(s) to, start and end inclusive\n");
10901125
printf(" -m FNAME, --model FNAME\n");
10911126
printf(" model path (default: %s)\n", params.model.c_str());
10921127
printf(" -md FNAME, --model-draft FNAME\n");
@@ -1351,6 +1386,41 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
13511386
return std::make_tuple(nullptr, nullptr);
13521387
}
13531388

1389+
if (!params.control_vectors.empty()) {
1390+
int32_t layer_start, layer_end;
1391+
std::tie(layer_start, layer_end) = params.control_vector_layer_range;
1392+
1393+
if (layer_start == 0) layer_start = 1;
1394+
if (layer_end == 0) layer_end = 31;
1395+
1396+
struct llama_control_vector * vector = nullptr;
1397+
1398+
for (const auto& t : params.control_vectors) {
1399+
std::string path;
1400+
float strength;
1401+
std::tie(path, strength) = t;
1402+
1403+
fprintf(stderr, "%s: loading control vector from %s\n", __func__, path.c_str());
1404+
struct llama_control_vector * temp = llama_control_vector_load(path.c_str());
1405+
if (temp == nullptr) {
1406+
fprintf(stderr, "%s: error: failed to load control vector from %s\n", __func__, path.c_str());
1407+
llama_free(lctx);
1408+
llama_free_model(model);
1409+
return std::make_tuple(nullptr, nullptr);
1410+
}
1411+
llama_control_vector_scale(temp, strength);
1412+
1413+
if (vector == nullptr) {
1414+
vector = temp;
1415+
} else {
1416+
llama_control_vector_add(vector, temp);
1417+
llama_control_vector_free(temp);
1418+
}
1419+
}
1420+
1421+
llama_apply_control_vector(lctx, vector, layer_start, layer_end);
1422+
}
1423+
13541424
for (unsigned int i = 0; i < params.lora_adapter.size(); ++i) {
13551425
const std::string& lora_adapter = std::get<0>(params.lora_adapter[i]);
13561426
float lora_scale = std::get<1>(params.lora_adapter[i]);

common/common.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,9 @@ struct gpt_params {
102102
std::vector<std::tuple<std::string, float>> lora_adapter; // lora adapter path with user defined scale
103103
std::string lora_base = ""; // base model path for the lora adapter
104104

105+
std::vector<std::tuple<std::string, float>> control_vectors; // control vector with user defined scale
106+
std::tuple<int32_t, int32_t> control_vector_layer_range; // layer range for control vector
107+
105108
int ppl_stride = 0; // stride for perplexity calculations. If left at 0, the pre-existing approach will be used.
106109
int ppl_output_type = 0; // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line
107110
// (which is more convenient to use for plotting)

0 commit comments

Comments
 (0)