@@ -27,6 +27,45 @@ std::shared_ptr<State> createState(int task_id, llamaCPP *instance) {
 
 // --------------------------------------------
 
+#include <ctime>
+#include <json/json.h>
+#include <string>
+#include <vector>
+
+std::string create_embedding_payload(const std::vector<float> &embedding,
+                                     int prompt_tokens) {
+  Json::Value root;
+
+  root["object"] = "list";
+
+  Json::Value dataArray(Json::arrayValue);
+  Json::Value dataItem;
+
+  dataItem["object"] = "embedding";
+
+  Json::Value embeddingArray(Json::arrayValue);
+  for (const auto &value : embedding) {
+    embeddingArray.append(value);
+  }
+  dataItem["embedding"] = embeddingArray;
+  dataItem["index"] = 0;
+
+  dataArray.append(dataItem);
+  root["data"] = dataArray;
+
+  root["model"] = "_";
+
+  Json::Value usage;
+  usage["prompt_tokens"] = prompt_tokens;
+  usage["total_tokens"] = prompt_tokens; // Assuming total tokens equals
+                                         // prompt tokens in this context
+  root["usage"] = usage;
+
+  Json::StreamWriterBuilder writer;
+  writer["indentation"] = ""; // Compact output
+  return Json::writeString(writer, root);
+}
+
 std::string create_full_return_json(const std::string &id,
                                     const std::string &model,
                                     const std::string &content,
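As a sanity check on the new helper, here is a minimal standalone sketch of calling it; the three-float vector and the token count of 5 are invented test values, and the forward declaration stands in for linking against the file above:

// Hypothetical driver for create_embedding_payload (invented test values).
#include <iostream>
#include <string>
#include <vector>

std::string create_embedding_payload(const std::vector<float> &embedding,
                                     int prompt_tokens);

int main() {
  std::vector<float> embedding = {0.1f, -0.2f, 0.3f};
  // Prints an OpenAI-style body: an "object":"list" wrapper, one "data"
  // item holding the embedding array at index 0, and a "usage" block with
  // prompt_tokens == total_tokens == 5. Note jsoncpp may print the floats
  // with full double precision.
  std::cout << create_embedding_payload(embedding, 5) << std::endl;
  return 0;
}
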
@@ -245,17 +284,18 @@ void llamaCPP::embedding(
   const auto &jsonBody = req->getJsonObject();
 
   json prompt;
-  if (jsonBody->isMember("content") != 0) {
-    prompt = (*jsonBody)["content"].asString();
+  if (jsonBody->isMember("input") != 0) {
+    prompt = (*jsonBody)["input"].asString();
   } else {
     prompt = "";
   }
   const int task_id = llama.request_completion(
       {{"prompt", prompt}, {"n_predict", 0}}, false, true);
   task_result result = llama.next_result(task_id);
-  std::string embeddingResp = result.result_json.dump();
+  std::vector<float> embedding_result = result.result_json["embedding"];
   auto resp = nitro_utils::nitroHttpResponse();
-  resp->setBody(embeddingResp);
+  std::string embedding_resp = create_embedding_payload(embedding_result, 0);
+  resp->setBody(embedding_resp);
   resp->setContentTypeString("application/json");
   callback(resp);
   return;
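One caveat in the hunk above: the call hard-codes prompt_tokens to 0, so the usage block in the response always reports zero tokens. If the llama.cpp task result carries a token count, the handler could forward it instead; the field name tokens_evaluated below is an assumption about that API, not something this commit confirms:

// Sketch only: "tokens_evaluated" is assumed to exist on result.result_json
// (an nlohmann::json object); value() falls back to 0 if it is missing.
int prompt_tokens = result.result_json.value("tokens_evaluated", 0);
std::string embedding_resp =
    create_embedding_payload(embedding_result, prompt_tokens);
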
@@ -363,7 +403,7 @@ void llamaCPP::loadModel(
   llama.initialize();
 
   Json::Value jsonResp;
-  jsonResp["message"] = "Failed to load model";
+  jsonResp["message"] = "Model loaded successfully";
   model_loaded = true;
   auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp);
 