Skip to content

Commit 6144e45

Browse files
committed
Add safetensors support
So we can load these natively, just like GGUF.

Signed-off-by: Eric Curtin <[email protected]>
1 parent 03914c7 commit 6144e45

20 files changed

+3317
-24
lines changed

common/arg.cpp

Lines changed: 86 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -224,30 +224,99 @@ static handle_model_result common_params_handle_model(
224224
if (model.hf_file.empty()) {
225225
if (model.path.empty()) {
226226
auto auto_detected = common_get_hf_file(model.hf_repo, bearer_token, offline);
227-
if (auto_detected.repo.empty() || auto_detected.ggufFile.empty()) {
227+
if (auto_detected.repo.empty()) {
228228
exit(1); // built without CURL, error message already printed
229229
}
230+
230231
model.hf_repo = auto_detected.repo;
231-
model.hf_file = auto_detected.ggufFile;
232-
if (!auto_detected.mmprojFile.empty()) {
233-
result.found_mmproj = true;
234-
result.mmproj.hf_repo = model.hf_repo;
235-
result.mmproj.hf_file = auto_detected.mmprojFile;
232+
233+
// Handle safetensors format
234+
if (auto_detected.is_safetensors) {
235+
LOG_INF("%s: detected safetensors format for %s\n", __func__, model.hf_repo.c_str());
236+
237+
// Create a directory for the safetensors files
238+
std::string dir_name = model.hf_repo;
239+
string_replace_all(dir_name, "/", "_");
240+
model.path = fs_get_cache_directory() + "/" + dir_name;
241+
242+
// Create directory if it doesn't exist
243+
std::filesystem::create_directories(model.path);
244+
245+
// Download required files: config.json, tokenizer.json, tokenizer_config.json, and .safetensors files
246+
std::string model_endpoint = get_model_endpoint();
247+
std::vector<std::pair<std::string, std::string>> files_to_download;
248+
249+
// Required config files
250+
files_to_download.push_back({
251+
model_endpoint + model.hf_repo + "/resolve/main/config.json",
252+
model.path + "/config.json"
253+
});
254+
files_to_download.push_back({
255+
model_endpoint + model.hf_repo + "/resolve/main/tokenizer.json",
256+
model.path + "/tokenizer.json"
257+
});
258+
files_to_download.push_back({
259+
model_endpoint + model.hf_repo + "/resolve/main/tokenizer_config.json",
260+
model.path + "/tokenizer_config.json"
261+
});
262+
263+
// Safetensors files
264+
for (const auto & st_file : auto_detected.safetensors_files) {
265+
files_to_download.push_back({
266+
model_endpoint + model.hf_repo + "/resolve/main/" + st_file,
267+
model.path + "/" + st_file
268+
});
269+
}
270+
271+
// Download all files
272+
LOG_INF("%s: downloading %zu files for safetensors model...\n", __func__, files_to_download.size());
273+
for (const auto & [url, path] : files_to_download) {
274+
bool ok = common_download_file_single(url, path, bearer_token, offline);
275+
if (!ok) {
276+
LOG_ERR("error: failed to download file from %s\n", url.c_str());
277+
exit(1);
278+
}
279+
}
280+
281+
LOG_INF("%s: safetensors model downloaded to %s\n", __func__, model.path.c_str());
282+
} else {
283+
// Handle GGUF format (existing logic)
284+
if (auto_detected.ggufFile.empty()) {
285+
exit(1); // no GGUF file found
286+
}
287+
model.hf_file = auto_detected.ggufFile;
288+
if (!auto_detected.mmprojFile.empty()) {
289+
result.found_mmproj = true;
290+
result.mmproj.hf_repo = model.hf_repo;
291+
result.mmproj.hf_file = auto_detected.mmprojFile;
292+
}
293+
294+
std::string model_endpoint = get_model_endpoint();
295+
model.url = model_endpoint + model.hf_repo + "/resolve/main/" + model.hf_file;
296+
// make sure model path is present (for caching purposes)
297+
if (model.path.empty()) {
298+
// this is to avoid different repo having same file name, or same file name in different subdirs
299+
std::string filename = model.hf_repo + "_" + model.hf_file;
300+
// to make sure we don't have any slashes in the filename
301+
string_replace_all(filename, "/", "_");
302+
model.path = fs_get_cache_file(filename);
303+
}
236304
}
237305
} else {
238306
model.hf_file = model.path;
239307
}
240-
}
241-
242-
std::string model_endpoint = get_model_endpoint();
243-
model.url = model_endpoint + model.hf_repo + "/resolve/main/" + model.hf_file;
244-
// make sure model path is present (for caching purposes)
245-
if (model.path.empty()) {
246-
// this is to avoid different repo having same file name, or same file name in different subdirs
247-
std::string filename = model.hf_repo + "_" + model.hf_file;
248-
// to make sure we don't have any slashes in the filename
249-
string_replace_all(filename, "/", "_");
250-
model.path = fs_get_cache_file(filename);
308+
} else {
309+
// User specified hf_file explicitly - use GGUF download path
310+
std::string model_endpoint = get_model_endpoint();
311+
model.url = model_endpoint + model.hf_repo + "/resolve/main/" + model.hf_file;
312+
// make sure model path is present (for caching purposes)
313+
if (model.path.empty()) {
314+
// this is to avoid different repo having same file name, or same file name in different subdirs
315+
std::string filename = model.hf_repo + "_" + model.hf_file;
316+
// to make sure we don't have any slashes in the filename
317+
string_replace_all(filename, "/", "_");
318+
model.path = fs_get_cache_file(filename);
319+
}
251320
}
252321

253322
} else if (!model.url.empty()) {

common/download.cpp

Lines changed: 83 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -715,10 +715,10 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string
715715

716716
#if defined(LLAMA_USE_CURL) || defined(LLAMA_USE_HTTPLIB)
717717

718-
static bool common_download_file_single(const std::string & url,
719-
const std::string & path,
720-
const std::string & bearer_token,
721-
bool offline) {
718+
bool common_download_file_single(const std::string & url,
719+
const std::string & path,
720+
const std::string & bearer_token,
721+
bool offline) {
722722
if (!offline) {
723723
return common_download_file_single_online(url, path, bearer_token);
724724
}
@@ -897,16 +897,93 @@ common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag, cons
897897
}
898898
} else if (res_code == 401) {
899899
throw std::runtime_error("error: model is private or does not exist; if you are accessing a gated model, please provide a valid HF token");
900+
} else if (res_code == 400) {
901+
// 400 typically means "not a GGUF repo" - we'll check for safetensors below
902+
LOG_INF("%s: manifest endpoint returned 400 (not a GGUF repo), will check for safetensors...\n", __func__);
900903
} else {
901904
throw std::runtime_error(string_format("error from HF API, response code: %ld, data: %s", res_code, res_str.c_str()));
902905
}
903906

904907
// check response
905908
if (ggufFile.empty()) {
906-
throw std::runtime_error("error: model does not have ggufFile");
909+
// No GGUF found - try to detect safetensors format
910+
LOG_INF("%s: no GGUF file found, checking for safetensors format...\n", __func__);
911+
912+
// Query HF API to list files in the repo
913+
std::string files_url = get_model_endpoint() + "api/models/" + hf_repo + "/tree/main";
914+
915+
common_remote_params files_params;
916+
files_params.headers = headers;
917+
918+
long files_res_code = 0;
919+
std::string files_res_str;
920+
921+
if (!offline) {
922+
try {
923+
auto files_res = common_remote_get_content(files_url, files_params);
924+
files_res_code = files_res.first;
925+
files_res_str = std::string(files_res.second.data(), files_res.second.size());
926+
} catch (const std::exception & e) {
927+
throw std::runtime_error("error: model does not have ggufFile and failed to check for safetensors: " + std::string(e.what()));
928+
}
929+
} else {
930+
throw std::runtime_error("error: model does not have ggufFile (offline mode, cannot check for safetensors)");
931+
}
932+
933+
if (files_res_code != 200) {
934+
throw std::runtime_error("error: model does not have ggufFile");
935+
}
936+
937+
// Parse the files list
938+
std::vector<std::string> safetensors_files;
939+
bool has_config = false;
940+
bool has_tokenizer = false;
941+
942+
try {
943+
auto files_json = json::parse(files_res_str);
944+
945+
for (const auto & file : files_json) {
946+
if (file.contains("path")) {
947+
std::string path = file["path"].get<std::string>();
948+
949+
if (path == "config.json") {
950+
has_config = true;
951+
} else if (path == "tokenizer.json") {
952+
has_tokenizer = true;
953+
} else {
954+
// Check for .safetensors extension
955+
const std::string suffix = ".safetensors";
956+
if (path.size() >= suffix.size() &&
957+
path.compare(path.size() - suffix.size(), suffix.size(), suffix) == 0) {
958+
safetensors_files.push_back(path);
959+
}
960+
}
961+
}
962+
}
963+
} catch (const std::exception & e) {
964+
throw std::runtime_error("error: model does not have ggufFile and failed to parse file list: " + std::string(e.what()));
965+
}
966+
967+
// Check if we have the required safetensors files
968+
if (!has_config || !has_tokenizer || safetensors_files.empty()) {
969+
throw std::runtime_error("error: model does not have ggufFile or valid safetensors format");
970+
}
971+
972+
LOG_INF("%s: detected safetensors format with %zu tensor files\n", __func__, safetensors_files.size());
973+
974+
common_hf_file_res result;
975+
result.repo = hf_repo;
976+
result.is_safetensors = true;
977+
result.safetensors_files = safetensors_files;
978+
return result;
907979
}
908980

909-
return { hf_repo, ggufFile, mmprojFile };
981+
common_hf_file_res result;
982+
result.repo = hf_repo;
983+
result.ggufFile = ggufFile;
984+
result.mmprojFile = mmprojFile;
985+
result.is_safetensors = false;
986+
return result;
910987
}
911988

912989
//

common/download.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#pragma once
22

33
#include <string>
4+
#include <vector>
45

56
struct common_params_model;
67

@@ -23,6 +24,10 @@ struct common_hf_file_res {
2324
std::string repo; // repo name with ":tag" removed
2425
std::string ggufFile;
2526
std::string mmprojFile;
27+
28+
// Safetensors support
29+
bool is_safetensors = false; // true if model is in safetensors format
30+
std::vector<std::string> safetensors_files; // list of .safetensors files to download
2631
};
2732

2833
/**
@@ -41,6 +46,13 @@ common_hf_file_res common_get_hf_file(
4146
const std::string & bearer_token,
4247
bool offline);
4348

49+
// download a single file (no GGUF validation)
50+
bool common_download_file_single(
51+
const std::string & url,
52+
const std::string & path,
53+
const std::string & bearer_token,
54+
bool offline);
55+
4456
// returns true if download succeeded
4557
bool common_download_model(
4658
const common_params_model & model,

src/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,11 @@ add_library(llama
3232
llama-quant.cpp
3333
llama-sampling.cpp
3434
llama-vocab.cpp
35+
llama-safetensors.cpp
36+
llama-hf-config.cpp
37+
llama-safetensors-loader.cpp
38+
llama-safetensors-types.cpp
39+
llama-model-from-safetensors.cpp
3540
unicode-data.cpp
3641
unicode.cpp
3742
unicode.h

0 commit comments

Comments (0)