diff --git a/.gitignore b/.gitignore index f8ceb1560a1df..7d16c0567884f 100644 --- a/.gitignore +++ b/.gitignore @@ -32,6 +32,7 @@ .swiftpm .vs/ .vscode/ +.history/ nppBackup diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 852352383bdbe..f2b67544602b0 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -4213,7 +4213,7 @@ int main(int argc, char ** argv) { throw std::runtime_error("prompt must be a string"); } - if (oaicompat && has_mtmd) { + if (has_mtmd) { // multimodal std::string prompt_str = prompt.get(); mtmd_input_text inp_txt = { @@ -4330,9 +4330,68 @@ int main(int argc, char ** argv) { } }; - const auto handle_completions = [&handle_completions_impl](const httplib::Request & req, httplib::Response & res) { + const auto handle_completions = [&ctx_server, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) { json data = json::parse(req.body); - std::vector files; // dummy + json medias = json_value(data, "medias", json::array()); + auto & opt = ctx_server.oai_parser_opt; + std::vector files; + + if (medias.is_array()) { + for (auto & m : medias) { + std::string type = json_value(m, "type", std::string()); + std::string data = json_value(m, "data", std::string()); + if (type.empty() || data.empty()) { + continue; + } + if (type == "image_url" || type == "image" || type == "img") { + if (!opt.allow_image) { + throw std::runtime_error("image input is not supported - hint: if this is unexpected, you may need to provide the mmproj"); + } + if (string_starts_with(data, "http")) { + // download remote image + common_remote_params params; + params.headers.push_back("User-Agent: llama.cpp/" + build_info); + params.max_size = 1024 * 1024 * 10; // 10MB + params.timeout = 10; // seconds + SRV_INF("downloading image from '%s'\n", data.c_str()); + auto res = common_remote_get_content(data, params); + if (200 <= res.first && res.first < 300) { + SRV_INF("downloaded %ld bytes\n", res.second.size()); + raw_buffer buf; + buf.insert(buf.end(), res.second.begin(), res.second.end()); + files.push_back(buf); + } else { + throw std::runtime_error("Failed to download image"); + } + } else { + // try to decode base64 image + std::vector parts = string_split(data, /*separator*/ ','); + if (parts.size() != 2) { + throw std::runtime_error("Invalid image_url.url value"); + } else if (!string_starts_with(parts[0], "data:image/")) { + throw std::runtime_error("Invalid image_url.url format: " + parts[0]); + } else if (!string_ends_with(parts[0], "base64")) { + throw std::runtime_error("image_url.url must be base64 encoded"); + } else { + auto base64_data = parts[1]; + auto decoded_data = base64_decode(base64_data); + files.push_back(decoded_data); + } + } + } else if (type == "input_audio" || type == "audio") { + if (!opt.allow_audio) { + throw std::runtime_error("audio input is not supported - hint: if this is unexpected, you may need to provide the mmproj"); + } + std::string format = json_value(m, "format", std::string()); + // while we also support flac, we don't allow it here so we matches the OAI spec + if (format != "wav" && format != "mp3") { + throw std::runtime_error("input_audio.format must be either 'wav' or 'mp3'"); + } + auto decoded_data = base64_decode(data); // expected to be base64 encoded + files.push_back(decoded_data); + } + } + } handle_completions_impl( SERVER_TASK_TYPE_COMPLETION, data,