move normalization to go

This commit is contained in:
Roy Han
2024-07-01 14:10:58 -07:00
parent 9c32b6b9ed
commit aee25acb5b
7 changed files with 94 additions and 44 deletions

View File

@@ -3191,21 +3191,11 @@ int main(int argc, char **argv) {
responses = std::vector<json>(1, result.result_json);
}
json embeddings = json::array();
if (body["normalize"]) {
for (auto & elem : responses) {
std::vector<float> embedding = elem.at("embedding").get<std::vector<float>>();
embedding = normalize_vector(embedding, embedding.size());
embeddings.push_back(embedding);
}
} else {
for (auto & elem : responses) {
embeddings.push_back(elem.at("embedding"));
}
for (auto & elem : responses) {
embeddings.push_back(elem.at("embedding"));
}
// send the result
json result = json{{"embedding", embeddings}};
// log result
return res.set_content(result.dump(), "application/json; charset=utf-8");
} else {
// return error

View File

@@ -657,19 +657,19 @@ static json probs_vector_to_json(const llama_context *ctx, const std::vector<com
return out;
}
// normalize a vector
static std::vector<float> normalize_vector(const std::vector<float>& vec, int size) {
double sum = 0.0;
for (float value : vec) {
sum += value * value;
}
sum = std::sqrt(sum);
// // normalize a vector
// static std::vector<float> normalize_vector(const std::vector<float>& vec, int size) {
// double sum = 0.0;
// for (float value : vec) {
// sum += value * value;
// }
// sum = std::sqrt(sum);
const float norm = sum > 0.0 ? 1.0f / sum : 0.0f;
// const float norm = sum > 0.0 ? 1.0f / sum : 0.0f;
std::vector<float> normalized_vec(size);
for (int i = 0; i < size; i++) {
normalized_vec[i] = vec[i] * norm;
}
return normalized_vec;
}
// std::vector<float> normalized_vec(size);
// for (int i = 0; i < size; i++) {
// normalized_vec[i] = vec[i] * norm;
// }
// return normalized_vec;
// }