diff --git a/backend/go/stores/store.go b/backend/go/stores/store.go index a4849b57..c8788a9c 100644 --- a/backend/go/stores/store.go +++ b/backend/go/stores/store.go @@ -311,12 +311,16 @@ func (s *Store) StoresGet(opts *pb.StoresGetOptions) (pb.StoresGetResult, error) } func isNormalized(k []float32) bool { - var sum float32 + var sum float64 + for _, v := range k { - sum += v + v64 := float64(v) + sum += v64*v64 } - return sum == 1.0 + s := math.Sqrt(sum) + + return s >= 0.99 && s <= 1.01 } // TODO: This we could replace with handwritten SIMD code @@ -328,7 +332,7 @@ func normalizedCosineSimilarity(k1, k2 []float32) float32 { dot += k1[i] * k2[i] } - assert(dot >= -1 && dot <= 1, fmt.Sprintf("dot = %f", dot)) + assert(dot >= -1.01 && dot <= 1.01, fmt.Sprintf("dot = %f", dot)) // 2.0 * (1.0 - dot) would be the Euclidean distance return dot @@ -418,7 +422,7 @@ func cosineSimilarity(k1, k2 []float32, mag1 float64) float32 { sim := float32(dot / (mag1 * math.Sqrt(mag2))) - assert(sim >= -1 && sim <= 1, fmt.Sprintf("sim = %f", sim)) + assert(sim >= -1.01 && sim <= 1.01, fmt.Sprintf("sim = %f", sim)) return sim }