mirror of
https://github.com/ollama/ollama.git
synced 2026-04-24 01:35:49 +02:00
use float32
This commit is contained in:
@@ -2,18 +2,18 @@ package format
|
||||
|
||||
import "math"
|
||||
|
||||
func Normalize(vec []float64) []float64 {
|
||||
func Normalize(vec []float32) []float32 {
|
||||
var sum float64
|
||||
for _, v := range vec {
|
||||
sum += v * v
|
||||
sum += float64(v * v)
|
||||
}
|
||||
|
||||
sum = math.Sqrt(sum)
|
||||
|
||||
var norm float64
|
||||
var norm float32
|
||||
|
||||
if sum > 0 {
|
||||
norm = 1.0 / sum
|
||||
norm = float32(1.0 / sum)
|
||||
} else {
|
||||
norm = 0.0
|
||||
}
|
||||
|
||||
@@ -7,21 +7,21 @@ import (
|
||||
|
||||
func TestNormalize(t *testing.T) {
|
||||
type testCase struct {
|
||||
input []float64
|
||||
input []float32
|
||||
}
|
||||
|
||||
testCases := []testCase{
|
||||
{input: []float64{1}},
|
||||
{input: []float64{0, 1, 2, 3}},
|
||||
{input: []float64{0.1, 0.2, 0.3}},
|
||||
{input: []float64{-0.1, 0.2, 0.3, -0.4}},
|
||||
{input: []float64{0, 0, 0}},
|
||||
{input: []float32{1}},
|
||||
{input: []float32{0, 1, 2, 3}},
|
||||
{input: []float32{0.1, 0.2, 0.3}},
|
||||
{input: []float32{-0.1, 0.2, 0.3, -0.4}},
|
||||
{input: []float32{0, 0, 0}},
|
||||
}
|
||||
|
||||
assertNorm := func(vec []float64) (res bool) {
|
||||
assertNorm := func(vec []float32) (res bool) {
|
||||
sum := 0.0
|
||||
for _, v := range vec {
|
||||
sum += v * v
|
||||
sum += float64(v * v)
|
||||
}
|
||||
if math.Abs(sum-1) > 1e-6 {
|
||||
return sum == 0
|
||||
|
||||
Reference in New Issue
Block a user