train: add simple loading already tokenized data from parquet dataset #14522

Open · wants to merge 1 commit into master
7 changes: 7 additions & 0 deletions CMakeLists.txt
@@ -84,6 +84,7 @@ option(LLAMA_BUILD_SERVER "llama: build server example" ${LLAMA_STANDALONE})
# 3rd party libs
option(LLAMA_CURL "llama: use libcurl to download model from an URL" ON)
option(LLAMA_LLGUIDANCE "llama-common: include LLGuidance library for structured output in common utils" OFF)
option(LLAMA_PARQUET "Enable Parquet dataset support via Arrow/Parquet C++" OFF)

# Required for relocatable CMake package
include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/build-info.cmake)
@@ -173,6 +174,12 @@ if (MINGW)
add_compile_definitions(_WIN32_WINNT=${GGML_WIN_VER})
endif()

if (LLAMA_PARQUET)
find_package(Arrow REQUIRED)
find_package(Parquet REQUIRED)
add_compile_definitions(LLAMA_PARQUET)
endif()

#
# build the library
#
48 changes: 36 additions & 12 deletions common/arg.cpp
@@ -1470,14 +1470,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params) {
params.ctx_shift = false;
}
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_PERPLEXITY}).set_env("LLAMA_ARG_NO_CONTEXT_SHIFT"));
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_FINETUNE}).set_env("LLAMA_ARG_NO_CONTEXT_SHIFT"));
add_opt(common_arg(
{"--chunks"}, "N",
string_format("max number of chunks to process (default: %d, -1 = all)", params.n_chunks),
[](common_params & params, int value) {
params.n_chunks = value;
}
).set_examples({LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_RETRIEVAL}));
).set_examples({LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_FINETUNE, LLAMA_EXAMPLE_RETRIEVAL}));
add_opt(common_arg(
{"-fa", "--flash-attn"},
string_format("enable Flash Attention (default: %s)", params.flash_attn ? "enabled" : "disabled"),
@@ -2115,70 +2115,70 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params) {
params.hellaswag = true;
}
).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
).set_examples({LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_FINETUNE}));
add_opt(common_arg(
{"--hellaswag-tasks"}, "N",
string_format("number of tasks to use when computing the HellaSwag score (default: %zu)", params.hellaswag_tasks),
[](common_params & params, int value) {
params.hellaswag_tasks = value;
}
).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
).set_examples({LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_FINETUNE}));
add_opt(common_arg(
{"--winogrande"},
"compute Winogrande score over random tasks from datafile supplied with -f",
[](common_params & params) {
params.winogrande = true;
}
).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
).set_examples({LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_FINETUNE}));
add_opt(common_arg(
{"--winogrande-tasks"}, "N",
string_format("number of tasks to use when computing the Winogrande score (default: %zu)", params.winogrande_tasks),
[](common_params & params, int value) {
params.winogrande_tasks = value;
}
).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
).set_examples({LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_FINETUNE}));
add_opt(common_arg(
{"--multiple-choice"},
"compute multiple choice score over random tasks from datafile supplied with -f",
[](common_params & params) {
params.multiple_choice = true;
}
).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
).set_examples({LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_FINETUNE}));
add_opt(common_arg(
{"--multiple-choice-tasks"}, "N",
string_format("number of tasks to use when computing the multiple choice score (default: %zu)", params.multiple_choice_tasks),
[](common_params & params, int value) {
params.multiple_choice_tasks = value;
}
).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
).set_examples({LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_FINETUNE}));
add_opt(common_arg(
{"--kl-divergence"},
"computes KL-divergence to logits provided via --kl-divergence-base",
[](common_params & params) {
params.kl_divergence = true;
}
).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
).set_examples({LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_FINETUNE}));
add_opt(common_arg(
{"--save-all-logits", "--kl-divergence-base"}, "FNAME",
"set logits file",
[](common_params & params, const std::string & value) {
params.logits_file = value;
}
).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
).set_examples({LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_FINETUNE}));
add_opt(common_arg(
{"--ppl-stride"}, "N",
string_format("stride for perplexity calculation (default: %d)", params.ppl_stride),
[](common_params & params, int value) {
params.ppl_stride = value;
}
).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
).set_examples({LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_FINETUNE}));
add_opt(common_arg(
{"--ppl-output-type"}, "<0|1>",
string_format("output type for perplexity calculation (default: %d)", params.ppl_output_type),
[](common_params & params, int value) {
params.ppl_output_type = value;
}
).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
).set_examples({LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_FINETUNE}));
add_opt(common_arg(
{"-dt", "--defrag-thold"}, "N",
string_format("KV cache defragmentation threshold (default: %.1f, < 0 - disabled)", (double)params.defrag_thold),
@@ -3415,6 +3415,30 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.n_cache_reuse = 256;
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
#ifdef LLAMA_PARQUET
add_opt(common_arg(
{"--dataset-format"}, "FORMAT",
"dataset format: text or parquet (parquet requires LLAMA_PARQUET)",
[](common_params & params, const std::string & format) {
params.dataset_format = format; // TODO: replace the string with an enum class
}
).set_examples({LLAMA_EXAMPLE_FINETUNE}));

add_opt(common_arg(
{"--parquet-path"}, "FNAME",
"path to the Parquet dataset file",
[](common_params & params, const std::string & filepath) { // TODO: support reading a directory of Parquet files
params.parquet_path = filepath;
}
).set_examples({LLAMA_EXAMPLE_FINETUNE}));

add_opt(common_arg(
{"--tokens-column"}, "NAME",
"name of the tokens column (list<int32>) in the Parquet file",
[](common_params & params, const std::string & column) {
params.tokens_column = column;
}
).set_examples({LLAMA_EXAMPLE_FINETUNE}));
#endif
return ctx_arg;
}
4 changes: 4 additions & 0 deletions common/common.h
@@ -83,6 +83,7 @@ enum llama_example {
LLAMA_EXAMPLE_TTS,
LLAMA_EXAMPLE_FINETUNE,

LLAMA_EXAMPLE_COUNT,
};

enum common_sampler_type {
@@ -282,6 +283,9 @@ struct common_params {
std::string lookup_cache_static = ""; // path of static ngram cache file for lookup decoding // NOLINT
std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding // NOLINT
std::string logits_file = ""; // file for saving *all* logits // NOLINT
std::string dataset_format = "text"; // dataset format: "text" | "parquet"
std::string parquet_path = ""; // path to the Parquet dataset file
std::string tokens_column = "tokens"; // name of the list<int32> tokens column

std::vector<std::string> in_files; // all input files
std::vector<std::string> antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts)
11 changes: 11 additions & 0 deletions examples/training/README.md
@@ -8,10 +8,21 @@ Finetuning of Stories 260K and LLaMA 3.2 1b seems to work with 24 GB of memory.

Proof of concept:

Loading training data from a plain text file:

``` sh
export model_name=llama_3.2-1b && export quantization=f32
./build/bin/llama-finetune --file wikitext-2-raw/wiki.test.raw -ngl 999 --model models/${model_name}-${quantization}.gguf -c 512 -b 512 -ub 512
./build/bin/llama-perplexity --file wikitext-2-raw/wiki.test.raw -ngl 999 --model finetuned-model.gguf
```

Loading pre-tokenized data from a Parquet file (without batching):

You need to install the Arrow/Parquet C++ libraries and build with `LLAMA_PARQUET=ON`:

``` sh
cmake -B build -DLLAMA_PARQUET=ON && cmake --build build --config Release
export model_name=llama_3.2-1b && export quantization=f32
./build/bin/llama-finetune -ngl 999 --dataset-format parquet --parquet-path parquet.parquet --tokens-column tokens --model models/${model_name}-${quantization}.gguf -c 512 -b 512 -ub 512
```
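
The loader expects one row per training sequence, with the token ids stored in a `list<int32>` column (`tokens` by default, configurable via `--tokens-column`). As a minimal sketch, a compatible file can be produced with the Arrow C++ API; the output name `parquet.parquet` and the single hardcoded sequence below are purely illustrative, and the program must be linked against the Arrow and Parquet libraries:

``` cpp
#include <arrow/api.h>
#include <arrow/io/file.h>
#include <parquet/arrow/writer.h>
#include <parquet/exception.h>

int main() {
    // builder for a list<int32> column; each list entry is one training sequence
    arrow::ListBuilder list_builder(arrow::default_memory_pool(),
                                    std::make_shared<arrow::Int32Builder>());
    auto * value_builder = static_cast<arrow::Int32Builder *>(list_builder.value_builder());

    PARQUET_THROW_NOT_OK(list_builder.Append());
    PARQUET_THROW_NOT_OK(value_builder->AppendValues({1, 15043, 3186, 2})); // token ids are illustrative

    std::shared_ptr<arrow::Array> tokens;
    PARQUET_THROW_NOT_OK(list_builder.Finish(&tokens));

    auto schema = arrow::schema({arrow::field("tokens", arrow::list(arrow::int32()))});
    auto table  = arrow::Table::Make(schema, {tokens});

    std::shared_ptr<arrow::io::FileOutputStream> out;
    PARQUET_ASSIGN_OR_THROW(out, arrow::io::FileOutputStream::Open("parquet.parquet"));
    PARQUET_THROW_NOT_OK(parquet::arrow::WriteTable(*table, arrow::default_memory_pool(), out, 1024));
    return 0;
}
```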
The perplexity value of the finetuned model should be lower after training on the test set for 2 epochs.
21 changes: 19 additions & 2 deletions examples/training/finetune.cpp
@@ -2,6 +2,7 @@
#include "common.h"
#include "log.h"
#include "llama.h"
#include "../../src/parquet_dataset.h"

#include <cmath>
#include <cstdio>
@@ -18,7 +19,7 @@ int main(int argc, char ** argv) {

params.escape = false;

if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PERPLEXITY)) {
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_FINETUNE)) {
return 1;
}

@@ -57,7 +58,23 @@

constexpr float val_split = 0.05f;

std::vector<llama_token> tokens = common_tokenize(ctx.get(), params.prompt, true);
std::vector<llama_token> tokens;
if (params.dataset_format == "text") {
tokens = common_tokenize(ctx.get(), params.prompt, true); // tokenize the plain text file loaded via --file
#ifdef LLAMA_PARQUET
} else if (params.dataset_format == "parquet") {
tokens = load_parquet_dataset(params.parquet_path, params.tokens_column);
if (tokens.empty()) {
LOG_ERR("no tokens in %s, or column '%s' not found/invalid\n", params.parquet_path.c_str(), params.tokens_column.c_str());
return 1;
}
LOG_INF("loaded %zu tokens from Parquet\n", tokens.size());
#endif
} else {
LOG_ERR("unknown dataset format: %s\n", params.dataset_format.c_str());
return 1;
}

ggml_opt_dataset_t dataset = common_opt_dataset_init(ctx.get(), tokens, llama_n_ctx(ctx.get())/2);

struct ggml_opt_optimizer_params optimizer_params = ggml_opt_get_default_optimizer_params(nullptr);
8 changes: 7 additions & 1 deletion src/CMakeLists.txt
@@ -32,6 +32,7 @@ add_library(llama
llama-quant.cpp
llama-sampling.cpp
llama-vocab.cpp
parquet_dataset.cpp
unicode-data.cpp
unicode.cpp
unicode.h
@@ -41,7 +42,12 @@ target_include_directories(llama PRIVATE .)
target_include_directories(llama PUBLIC ../include)
target_compile_features (llama PRIVATE cxx_std_17) # don't bump

target_link_libraries(llama PUBLIC ggml)

target_link_libraries(llama PUBLIC ggml)

if (LLAMA_PARQUET)
target_link_libraries(llama PUBLIC Arrow::arrow_shared Parquet::parquet_shared)
endif()

if (BUILD_SHARED_LIBS)
set_target_properties(llama PROPERTIES POSITION_INDEPENDENT_CODE ON)
47 changes: 47 additions & 0 deletions src/parquet_dataset.cpp
@@ -0,0 +1,47 @@
#ifdef LLAMA_PARQUET
#include "parquet_dataset.h"
#include <arrow/api.h>
#include <arrow/io/file.h>
#include <parquet/arrow/reader.h>
#include <parquet/exception.h> // PARQUET_ASSIGN_OR_THROW / PARQUET_THROW_NOT_OK
#include "llama-impl.h"

std::vector<llama_token> load_parquet_dataset(const std::string &path, const std::string &column) {
arrow::MemoryPool *pool = arrow::default_memory_pool();
std::shared_ptr<arrow::io::RandomAccessFile> infile;
PARQUET_ASSIGN_OR_THROW(infile, arrow::io::ReadableFile::Open(path));
std::unique_ptr<parquet::arrow::FileReader> reader;
PARQUET_ASSIGN_OR_THROW(reader, parquet::arrow::OpenFile(infile, pool));

std::shared_ptr<arrow::Table> table;
PARQUET_THROW_NOT_OK(reader->ReadTable(&table));

auto field = table->schema()->GetFieldByName(column);
if (!field || !field->type()->Equals(arrow::list(arrow::int32()))) {
LLAMA_LOG_ERROR("Parquet column '%s' missing or not list<int32>\n", column.c_str());
return {};
}

auto col = table->GetColumnByName(column);
std::vector<llama_token> tokens;
for (int chunk = 0; chunk < col->num_chunks(); ++chunk) {
auto list_arr = std::static_pointer_cast<arrow::ListArray>(col->chunk(chunk));
auto values_arr = std::static_pointer_cast<arrow::Int32Array>(list_arr->values());
// raw offsets are int32_t for arrow::ListArray (int64_t only for LargeListArray);
// the offsets array has list_arr->length() + 1 entries
const auto * offsets = list_arr->raw_value_offsets();
const int64_t values_length = values_arr->length();
for (int64_t i = 0; i < list_arr->length(); ++i) {
if (list_arr->IsNull(i)) {
continue; // skip null rows
}
int64_t start = offsets[i];
int64_t end = offsets[i + 1];
// clamp to the values array to guard against malformed offsets
if (start < 0) start = 0;
if (end > values_length) end = values_length;
for (int64_t j = start; j < end; ++j) {
tokens.push_back(static_cast<llama_token>(values_arr->Value(j)));
}
}
}
return tokens;
}
#endif // LLAMA_PARQUET
10 changes: 10 additions & 0 deletions src/parquet_dataset.h
@@ -0,0 +1,10 @@
#ifndef PARQUET_DATASET_H
#define PARQUET_DATASET_H
#include <string>
#include <vector>
#include "llama.h"

#ifdef LLAMA_PARQUET
std::vector<llama_token> load_parquet_dataset(const std::string &path, const std::string &column);
#endif
#endif // PARQUET_DATASET_H