add batch embeddings

This commit is contained in:
jmorganca
2024-04-14 20:53:20 -04:00
parent 8e30eb26bd
commit ad7e641815
8 changed files with 243 additions and 72 deletions

View File

@@ -3209,54 +3209,27 @@ int main(int argc, char **argv) {
return res.set_content(data.dump(), "application/json; charset=utf-8");
});
svr.Post("/embedding", [&llama](const httplib::Request &req, httplib::Response &res)
svr.Post("/embeddings", [&llama](const httplib::Request &req, httplib::Response &res)
{
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
const json body = json::parse(req.body);
json prompt;
if (body.count("content") != 0)
{
prompt = body["content"];
}
else
{
prompt = "";
const int id = llama.queue_tasks.get_new_id();
llama.queue_results.add_waiting_task_id(id);
llama.request_completion(id, {{"prompt", body["contents"]}}, false, true, -1);
task_result recv = llama.queue_results.recv(id);
llama.queue_results.remove_waiting_task_id(id);
json embeddings = json::array();
for (auto & elem : recv.result_json["results"]) {
embeddings.push_back(json_value(elem, "embedding", json::array()));
}
json image_data;
if (body.count("image_data") != 0) {
image_data = body["image_data"];
}
else
{
image_data = "";
}
// create and queue the task
const int task_id = llama.queue_tasks.get_new_id();
llama.queue_results.add_waiting_task_id(task_id);
llama.request_completion(task_id, { {"prompt", prompt}, { "n_predict", 0}, {"image_data", image_data} }, false, true, -1);
// get the result
task_result result = llama.queue_results.recv(task_id);
llama.queue_results.remove_waiting_task_id(task_id);
// send the result
return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
json result = json{{"embeddings", embeddings}};
return res.set_content(result.dump(), "application/json; charset=utf-8");
});
// GG: if I put the main loop inside a thread, it crashes on the first request when build in Debug!?
// "Bus error: 10" - this is on macOS, it does not crash on Linux
//std::thread t2([&]()
/*{
bool running = true;
while (running)
{
running = llama.update_slots();
}
}*/
//);
if (sparams.n_threads_http < 1) {
// +2 threads for monitoring endpoints
sparams.n_threads_http = std::max(params.n_parallel + 2, (int32_t) std::thread::hardware_concurrency() - 1);