Skip to content

Commit 2574024

Browse files
author
lexasub
committed
train: add simple loading of already-tokenized data from a Parquet dataset
1 parent bee2842 commit 2574024

File tree

5 files changed

+83
-2
lines changed

5 files changed

+83
-2
lines changed

CMakeLists.txt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ option(LLAMA_BUILD_SERVER "llama: build server example" ${LLAMA_STANDALONE})
8484
# 3rd party libs
8585
option(LLAMA_CURL "llama: use libcurl to download model from an URL" ON)
8686
option(LLAMA_LLGUIDANCE "llama-common: include LLGuidance library for structured output in common utils" OFF)
87+
option(LLAMA_PARQUET "Enable Parquet dataset support via Arrow/Parquet C++" OFF)
8788

8889
# Required for relocatable CMake package
8990
include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/build-info.cmake)
@@ -173,6 +174,12 @@ if (MINGW)
173174
add_compile_definitions(_WIN32_WINNT=${GGML_WIN_VER})
174175
endif()
175176

177+
if(LLAMA_PARQUET)
178+
find_package(Arrow REQUIRED)
179+
find_package(Parquet REQUIRED)
180+
add_definitions(-DLLAMA_PARQUET)
181+
endif()
182+
176183
#
177184
# build the library
178185
#

examples/training/finetune.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include "common.h"
33
#include "log.h"
44
#include "llama.h"
5+
#include "../../src/parquet_dataset.h"
56

67
#include <cmath>
78
#include <cstdio>
@@ -57,7 +58,17 @@ int main(int argc, char ** argv) {
5758

5859
constexpr float val_split = 0.05f;
5960

60-
std::vector<llama_token> tokens = common_tokenize(ctx.get(), params.prompt, true);
61+
#ifndef LLAMA_PARQUET
62+
std::vector<llama_token> tokens = common_tokenize(ctx.get(), params.prompt, true); //load from text file
63+
#else
64+
auto tokens = load_parquet_dataset("test.parquet" /*params.parquet_path, params.tokens_column*/ ,"tokens");
65+
if (tokens.empty()) {
66+
//LOG_ERR("No tokens in %s, or column %s not found/invalid", params.parquet_path.c_str(), params.tokens_column.c_str());
67+
return 1;
68+
}
69+
LOG_INF("Loaded %zu tokens from Parquet", tokens.size());
70+
#endif
71+
6172
ggml_opt_dataset_t dataset = common_opt_dataset_init(ctx.get(), tokens, llama_n_ctx(ctx.get())/2);
6273

6374
struct ggml_opt_optimizer_params optimizer_params = ggml_opt_get_default_optimizer_params(nullptr);

src/CMakeLists.txt

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ add_library(llama
3232
llama-quant.cpp
3333
llama-sampling.cpp
3434
llama-vocab.cpp
35+
parquet_dataset.cpp
3536
unicode-data.cpp
3637
unicode.cpp
3738
unicode.h
@@ -41,7 +42,12 @@ target_include_directories(llama PRIVATE .)
4142
target_include_directories(llama PUBLIC ../include)
4243
target_compile_features (llama PRIVATE cxx_std_17) # don't bump
4344

44-
target_link_libraries(llama PUBLIC ggml)
45+
46+
if(LLAMA_PARQUET)
47+
target_link_libraries(llama PUBLIC ggml Arrow::arrow_shared Parquet::parquet_shared)
48+
else()
49+
target_link_libraries(llama PUBLIC ggml)
50+
endif()
4551

4652
if (BUILD_SHARED_LIBS)
4753
set_target_properties(llama PROPERTIES POSITION_INDEPENDENT_CODE ON)

src/parquet_dataset.cpp

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
#ifdef LLAMA_PARQUET
#include "parquet_dataset.h"
#include <arrow/api.h>
#include <arrow/io/file.h>
#include <parquet/arrow/reader.h>
#include "llama-impl.h"

// Load a tokenized dataset from a Parquet file.
//
// `path`   - filesystem path of the Parquet file to read.
// `column` - name of a column of type list<int32>; each int32 element is
//            interpreted as one llama_token.
//
// Returns all rows' token lists concatenated into a single stream, or an
// empty vector on any error (unreadable file, missing or mistyped column),
// so callers can detect failure with `tokens.empty()` instead of catching
// Parquet exceptions.
std::vector<llama_token> load_parquet_dataset(const std::string &path, const std::string &column) {
    arrow::MemoryPool *pool = arrow::default_memory_pool();

    std::shared_ptr<arrow::io::RandomAccessFile> infile;
    std::unique_ptr<parquet::arrow::FileReader>  reader;
    std::shared_ptr<arrow::Table>                table;
    try {
        PARQUET_ASSIGN_OR_THROW(infile, arrow::io::ReadableFile::Open(path));
        // PARQUET_ASSIGN_OR_THROW unwraps the arrow::Result straight into the
        // target variable; no intermediate Result/ValueUnsafe() is needed.
        PARQUET_ASSIGN_OR_THROW(reader, parquet::arrow::OpenFile(infile, pool));
        PARQUET_THROW_NOT_OK(reader->ReadTable(&table));
    } catch (const parquet::ParquetException &e) {
        // Surface I/O and format failures as an empty result instead of an
        // uncaught exception crossing the library boundary.
        LLAMA_LOG_ERROR("Failed to read Parquet file '%s': %s", path.c_str(), e.what());
        return {};
    }

    auto field = table->schema()->GetFieldByName(column);
    if (!field || !field->type()->Equals(arrow::list(arrow::int32()))) {
        LLAMA_LOG_ERROR("Parquet column '%s' missing or not list<int32>", column.c_str());
        return {};
    }

    auto col = table->GetColumnByName(column);
    std::vector<llama_token> tokens;
    for (int chunk = 0; chunk < col->num_chunks(); ++chunk) {
        auto list_arr   = std::static_pointer_cast<arrow::ListArray>(col->chunk(chunk));
        auto values_arr = std::static_pointer_cast<arrow::Int32Array>(list_arr->values());
        const int64_t values_length = values_arr->length();
        for (int64_t i = 0; i < list_arr->length(); ++i) {
            if (list_arr->IsNull(i)) {
                continue; // a null row has no meaningful offsets - skip it
            }
            // value_offset() accounts for the array's slice offset, unlike
            // indexing raw_value_offsets() directly, which is wrong for
            // sliced arrays.
            int64_t start = list_arr->value_offset(i);
            int64_t end   = list_arr->value_offset(i + 1);
            // Clamp defensively against malformed offset data.
            if (start < 0) start = 0;
            if (end > values_length) end = values_length;
            for (int64_t j = start; j < end; ++j) {
                tokens.push_back(static_cast<llama_token>(values_arr->Value(j)));
            }
        }
    }
    return tokens;
}
#endif // LLAMA_PARQUET

src/parquet_dataset.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
#ifndef PARQUET_DATASET_H
#define PARQUET_DATASET_H
#include <string>
#include <vector>
#include "llama.h"

#ifdef LLAMA_PARQUET
// Read the list<int32> column named `column` from the Parquet file at `path`
// and return all rows' values concatenated into a single token stream.
// Returns an empty vector when the column is missing or not of type
// list<int32>. Only declared when built with LLAMA_PARQUET (Arrow/Parquet
// support enabled in CMake).
std::vector<llama_token> load_parquet_dataset(const std::string &path, const std::string &column);
#endif
#endif // PARQUET_DATASET_H

0 commit comments

Comments
 (0)