mirror of
https://github.com/ollama/ollama.git
synced 2026-04-23 01:05:47 +02:00
add batch embeddings
This commit is contained in:
55
llm/ext_server/server.cpp
vendored
55
llm/ext_server/server.cpp
vendored
@@ -3209,54 +3209,27 @@ int main(int argc, char **argv) {
|
||||
return res.set_content(data.dump(), "application/json; charset=utf-8");
|
||||
});
|
||||
|
||||
svr.Post("/embedding", [&llama](const httplib::Request &req, httplib::Response &res)
|
||||
svr.Post("/embeddings", [&llama](const httplib::Request &req, httplib::Response &res)
|
||||
{
|
||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||
const json body = json::parse(req.body);
|
||||
json prompt;
|
||||
if (body.count("content") != 0)
|
||||
{
|
||||
prompt = body["content"];
|
||||
}
|
||||
else
|
||||
{
|
||||
prompt = "";
|
||||
|
||||
const int id = llama.queue_tasks.get_new_id();
|
||||
llama.queue_results.add_waiting_task_id(id);
|
||||
llama.request_completion(id, {{"prompt", body["contents"]}}, false, true, -1);
|
||||
|
||||
task_result recv = llama.queue_results.recv(id);
|
||||
llama.queue_results.remove_waiting_task_id(id);
|
||||
|
||||
json embeddings = json::array();
|
||||
for (auto & elem : recv.result_json["results"]) {
|
||||
embeddings.push_back(json_value(elem, "embedding", json::array()));
|
||||
}
|
||||
|
||||
json image_data;
|
||||
if (body.count("image_data") != 0) {
|
||||
image_data = body["image_data"];
|
||||
}
|
||||
else
|
||||
{
|
||||
image_data = "";
|
||||
}
|
||||
|
||||
// create and queue the task
|
||||
const int task_id = llama.queue_tasks.get_new_id();
|
||||
llama.queue_results.add_waiting_task_id(task_id);
|
||||
llama.request_completion(task_id, { {"prompt", prompt}, { "n_predict", 0}, {"image_data", image_data} }, false, true, -1);
|
||||
|
||||
// get the result
|
||||
task_result result = llama.queue_results.recv(task_id);
|
||||
llama.queue_results.remove_waiting_task_id(task_id);
|
||||
|
||||
// send the result
|
||||
return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
|
||||
json result = json{{"embeddings", embeddings}};
|
||||
return res.set_content(result.dump(), "application/json; charset=utf-8");
|
||||
});
|
||||
|
||||
// GG: if I put the main loop inside a thread, it crashes on the first request when built in Debug!?
|
||||
// "Bus error: 10" - this is on macOS, it does not crash on Linux
|
||||
//std::thread t2([&]()
|
||||
/*{
|
||||
bool running = true;
|
||||
while (running)
|
||||
{
|
||||
running = llama.update_slots();
|
||||
}
|
||||
}*/
|
||||
//);
|
||||
|
||||
if (sparams.n_threads_http < 1) {
|
||||
// +2 threads for monitoring endpoints
|
||||
sparams.n_threads_http = std::max(params.n_parallel + 2, (int32_t) std::thread::hardware_concurrency() - 1);
|
||||
|
||||
Reference in New Issue
Block a user