mirror of
https://github.com/mudler/LocalAI.git
synced 2025-05-20 02:24:59 +00:00
Stores to chromem (WIP)
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
parent
2f09aa1b85
commit
a1d5462ad0
16 changed files with 50 additions and 489 deletions
|
@ -21,8 +21,7 @@ service Backend {
|
||||||
rpc Status(HealthMessage) returns (StatusResponse) {}
|
rpc Status(HealthMessage) returns (StatusResponse) {}
|
||||||
|
|
||||||
rpc StoresSet(StoresSetOptions) returns (Result) {}
|
rpc StoresSet(StoresSetOptions) returns (Result) {}
|
||||||
rpc StoresDelete(StoresDeleteOptions) returns (Result) {}
|
rpc StoresReset(StoresResetOptions) returns (Result) {}
|
||||||
rpc StoresGet(StoresGetOptions) returns (StoresGetResult) {}
|
|
||||||
rpc StoresFind(StoresFindOptions) returns (StoresFindResult) {}
|
rpc StoresFind(StoresFindOptions) returns (StoresFindResult) {}
|
||||||
|
|
||||||
rpc Rerank(RerankRequest) returns (RerankResult) {}
|
rpc Rerank(RerankRequest) returns (RerankResult) {}
|
||||||
|
@ -78,19 +77,10 @@ message StoresSetOptions {
|
||||||
repeated StoresValue Values = 2;
|
repeated StoresValue Values = 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
message StoresDeleteOptions {
|
message StoresResetOptions {
|
||||||
repeated StoresKey Keys = 1;
|
repeated StoresKey Keys = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
message StoresGetOptions {
|
|
||||||
repeated StoresKey Keys = 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
message StoresGetResult {
|
|
||||||
repeated StoresKey Keys = 1;
|
|
||||||
repeated StoresValue Values = 2;
|
|
||||||
}
|
|
||||||
|
|
||||||
message StoresFindOptions {
|
message StoresFindOptions {
|
||||||
StoresKey Key = 1;
|
StoresKey Key = 1;
|
||||||
int32 TopK = 2;
|
int32 TopK = 2;
|
||||||
|
|
|
@ -4,101 +4,36 @@ package main
|
||||||
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||||
import (
|
import (
|
||||||
"container/heap"
|
"container/heap"
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
"slices"
|
"runtime"
|
||||||
|
|
||||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||||
|
chromem "github.com/philippgille/chromem-go"
|
||||||
|
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Store struct {
|
type Store struct {
|
||||||
base.SingleThread
|
base.SingleThread
|
||||||
|
*chromem.DB
|
||||||
// The sorted keys
|
*chromem.Collection
|
||||||
keys [][]float32
|
|
||||||
// The sorted values
|
|
||||||
values [][]byte
|
|
||||||
|
|
||||||
// If for every K it holds that ||k||^2 = 1, then we can use the normalized distance functions
|
|
||||||
// TODO: Should we normalize incoming keys if they are not instead?
|
|
||||||
keysAreNormalized bool
|
|
||||||
// The first key decides the length of the keys
|
|
||||||
keyLen int
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: Only used for sorting using Go's builtin implementation. The interfaces are columnar because
|
|
||||||
// that's theoretically best for memory layout and cache locality, but this isn't optimized yet.
|
|
||||||
type Pair struct {
|
|
||||||
Key []float32
|
|
||||||
Value []byte
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewStore() *Store {
|
func NewStore() *Store {
|
||||||
return &Store{
|
return &Store{}
|
||||||
keys: make([][]float32, 0),
|
|
||||||
values: make([][]byte, 0),
|
|
||||||
keysAreNormalized: true,
|
|
||||||
keyLen: -1,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func compareSlices(k1, k2 []float32) int {
|
|
||||||
assert(len(k1) == len(k2), fmt.Sprintf("compareSlices: len(k1) = %d, len(k2) = %d", len(k1), len(k2)))
|
|
||||||
|
|
||||||
return slices.Compare(k1, k2)
|
|
||||||
}
|
|
||||||
|
|
||||||
func hasKey(unsortedSlice [][]float32, target []float32) bool {
|
|
||||||
return slices.ContainsFunc(unsortedSlice, func(k []float32) bool {
|
|
||||||
return compareSlices(k, target) == 0
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func findInSortedSlice(sortedSlice [][]float32, target []float32) (int, bool) {
|
|
||||||
return slices.BinarySearchFunc(sortedSlice, target, func(k, t []float32) int {
|
|
||||||
return compareSlices(k, t)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func isSortedPairs(kvs []Pair) bool {
|
|
||||||
for i := 1; i < len(kvs); i++ {
|
|
||||||
if compareSlices(kvs[i-1].Key, kvs[i].Key) > 0 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func isSortedKeys(keys [][]float32) bool {
|
|
||||||
for i := 1; i < len(keys); i++ {
|
|
||||||
if compareSlices(keys[i-1], keys[i]) > 0 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func sortIntoKeySlicese(keys []*pb.StoresKey) [][]float32 {
|
|
||||||
ks := make([][]float32, len(keys))
|
|
||||||
|
|
||||||
for i, k := range keys {
|
|
||||||
ks[i] = k.Floats
|
|
||||||
}
|
|
||||||
|
|
||||||
slices.SortFunc(ks, compareSlices)
|
|
||||||
|
|
||||||
assert(len(ks) == len(keys), fmt.Sprintf("len(ks) = %d, len(keys) = %d", len(ks), len(keys)))
|
|
||||||
assert(isSortedKeys(ks), "keys are not sorted")
|
|
||||||
|
|
||||||
return ks
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Store) Load(opts *pb.ModelOptions) error {
|
func (s *Store) Load(opts *pb.ModelOptions) error {
|
||||||
|
db := chromem.NewDB()
|
||||||
|
collection, err := db.CreateCollection("all-documents", nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
s.DB = db
|
||||||
|
s.Collection = collection
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -111,156 +46,25 @@ func (s *Store) StoresSet(opts *pb.StoresSetOptions) error {
|
||||||
if len(opts.Keys) != len(opts.Values) {
|
if len(opts.Keys) != len(opts.Values) {
|
||||||
return fmt.Errorf("len(keys) = %d, len(values) = %d", len(opts.Keys), len(opts.Values))
|
return fmt.Errorf("len(keys) = %d, len(values) = %d", len(opts.Keys), len(opts.Values))
|
||||||
}
|
}
|
||||||
|
docs := []chromem.Document{}
|
||||||
if s.keyLen == -1 {
|
|
||||||
s.keyLen = len(opts.Keys[0].Floats)
|
|
||||||
} else {
|
|
||||||
if len(opts.Keys[0].Floats) != s.keyLen {
|
|
||||||
return fmt.Errorf("Try to add key with length %d when existing length is %d", len(opts.Keys[0].Floats), s.keyLen)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
kvs := make([]Pair, len(opts.Keys))
|
|
||||||
|
|
||||||
for i, k := range opts.Keys {
|
for i, k := range opts.Keys {
|
||||||
if s.keysAreNormalized && !isNormalized(k.Floats) {
|
docs = append(docs, chromem.Document{
|
||||||
s.keysAreNormalized = false
|
ID: k.String(),
|
||||||
var sample []float32
|
Content: opts.Values[i].String(),
|
||||||
if len(s.keys) > 5 {
|
})
|
||||||
sample = k.Floats[:5]
|
|
||||||
} else {
|
|
||||||
sample = k.Floats
|
|
||||||
}
|
|
||||||
log.Debug().Msgf("Key is not normalized: %v", sample)
|
|
||||||
}
|
|
||||||
|
|
||||||
kvs[i] = Pair{
|
|
||||||
Key: k.Floats,
|
|
||||||
Value: opts.Values[i].Bytes,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
slices.SortFunc(kvs, func(a, b Pair) int {
|
return s.Collection.AddDocuments(context.Background(), docs, runtime.NumCPU())
|
||||||
return compareSlices(a.Key, b.Key)
|
|
||||||
})
|
|
||||||
|
|
||||||
assert(len(kvs) == len(opts.Keys), fmt.Sprintf("len(kvs) = %d, len(opts.Keys) = %d", len(kvs), len(opts.Keys)))
|
|
||||||
assert(isSortedPairs(kvs), "keys are not sorted")
|
|
||||||
|
|
||||||
l := len(kvs) + len(s.keys)
|
|
||||||
merge_ks := make([][]float32, 0, l)
|
|
||||||
merge_vs := make([][]byte, 0, l)
|
|
||||||
|
|
||||||
i, j := 0, 0
|
|
||||||
for {
|
|
||||||
if i+j >= l {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
if i >= len(kvs) {
|
|
||||||
merge_ks = append(merge_ks, s.keys[j])
|
|
||||||
merge_vs = append(merge_vs, s.values[j])
|
|
||||||
j++
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if j >= len(s.keys) {
|
|
||||||
merge_ks = append(merge_ks, kvs[i].Key)
|
|
||||||
merge_vs = append(merge_vs, kvs[i].Value)
|
|
||||||
i++
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
c := compareSlices(kvs[i].Key, s.keys[j])
|
|
||||||
if c < 0 {
|
|
||||||
merge_ks = append(merge_ks, kvs[i].Key)
|
|
||||||
merge_vs = append(merge_vs, kvs[i].Value)
|
|
||||||
i++
|
|
||||||
} else if c > 0 {
|
|
||||||
merge_ks = append(merge_ks, s.keys[j])
|
|
||||||
merge_vs = append(merge_vs, s.values[j])
|
|
||||||
j++
|
|
||||||
} else {
|
|
||||||
merge_ks = append(merge_ks, kvs[i].Key)
|
|
||||||
merge_vs = append(merge_vs, kvs[i].Value)
|
|
||||||
i++
|
|
||||||
j++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
assert(len(merge_ks) == l, fmt.Sprintf("len(merge_ks) = %d, l = %d", len(merge_ks), l))
|
|
||||||
assert(isSortedKeys(merge_ks), "merge keys are not sorted")
|
|
||||||
|
|
||||||
s.keys = merge_ks
|
|
||||||
s.values = merge_vs
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Store) StoresDelete(opts *pb.StoresDeleteOptions) error {
|
func (s *Store) StoresReset(opts *pb.StoresResetOptions) error {
|
||||||
if len(opts.Keys) == 0 {
|
err := s.DB.DeleteCollection("all-documents")
|
||||||
return fmt.Errorf("no keys to delete")
|
if err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
s.Collection, err = s.CreateCollection("all-documents", nil, nil)
|
||||||
if len(opts.Keys) == 0 {
|
return err
|
||||||
return fmt.Errorf("no keys to add")
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.keyLen == -1 {
|
|
||||||
s.keyLen = len(opts.Keys[0].Floats)
|
|
||||||
} else {
|
|
||||||
if len(opts.Keys[0].Floats) != s.keyLen {
|
|
||||||
return fmt.Errorf("Trying to delete key with length %d when existing length is %d", len(opts.Keys[0].Floats), s.keyLen)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ks := sortIntoKeySlicese(opts.Keys)
|
|
||||||
|
|
||||||
l := len(s.keys) - len(ks)
|
|
||||||
merge_ks := make([][]float32, 0, l)
|
|
||||||
merge_vs := make([][]byte, 0, l)
|
|
||||||
|
|
||||||
tail_ks := s.keys
|
|
||||||
tail_vs := s.values
|
|
||||||
for _, k := range ks {
|
|
||||||
j, found := findInSortedSlice(tail_ks, k)
|
|
||||||
|
|
||||||
if found {
|
|
||||||
merge_ks = append(merge_ks, tail_ks[:j]...)
|
|
||||||
merge_vs = append(merge_vs, tail_vs[:j]...)
|
|
||||||
tail_ks = tail_ks[j+1:]
|
|
||||||
tail_vs = tail_vs[j+1:]
|
|
||||||
} else {
|
|
||||||
assert(!hasKey(s.keys, k), fmt.Sprintf("Key exists, but was not found: t=%d, %v", len(tail_ks), k))
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debug().Msgf("Delete: found = %v, t = %d, j = %d, len(merge_ks) = %d, len(merge_vs) = %d", found, len(tail_ks), j, len(merge_ks), len(merge_vs))
|
|
||||||
}
|
|
||||||
|
|
||||||
merge_ks = append(merge_ks, tail_ks...)
|
|
||||||
merge_vs = append(merge_vs, tail_vs...)
|
|
||||||
|
|
||||||
assert(len(merge_ks) <= len(s.keys), fmt.Sprintf("len(merge_ks) = %d, len(s.keys) = %d", len(merge_ks), len(s.keys)))
|
|
||||||
|
|
||||||
s.keys = merge_ks
|
|
||||||
s.values = merge_vs
|
|
||||||
|
|
||||||
assert(len(s.keys) >= l, fmt.Sprintf("len(s.keys) = %d, l = %d", len(s.keys), l))
|
|
||||||
assert(isSortedKeys(s.keys), "keys are not sorted")
|
|
||||||
assert(func() bool {
|
|
||||||
for _, k := range ks {
|
|
||||||
if _, found := findInSortedSlice(s.keys, k); found {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}(), "Keys to delete still present")
|
|
||||||
|
|
||||||
if len(s.keys) != l {
|
|
||||||
log.Debug().Msgf("Delete: Some keys not found: len(s.keys) = %d, l = %d", len(s.keys), l)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Store) StoresGet(opts *pb.StoresGetOptions) (pb.StoresGetResult, error) {
|
func (s *Store) StoresGet(opts *pb.StoresGetOptions) (pb.StoresGetResult, error) {
|
||||||
|
|
|
@ -1000,7 +1000,7 @@ var _ = Describe("API test", func() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
deleteBody := schema.StoresDelete{
|
deleteBody := schema.StoresReset{
|
||||||
Keys: [][]float32{
|
Keys: [][]float32{
|
||||||
{0.1, 0.2, 0.3},
|
{0.1, 0.2, 0.3},
|
||||||
},
|
},
|
||||||
|
|
|
@ -36,9 +36,9 @@ func StoresSetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfi
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func StoresDeleteEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
func StoresResetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||||
return func(c *fiber.Ctx) error {
|
return func(c *fiber.Ctx) error {
|
||||||
input := new(schema.StoresDelete)
|
input := new(schema.StoresReset)
|
||||||
|
|
||||||
if err := c.BodyParser(input); err != nil {
|
if err := c.BodyParser(input); err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -49,7 +49,7 @@ func StoresDeleteEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationCo
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := store.DeleteCols(c.Context(), sb, input.Keys); err != nil {
|
if _, err := sb.StoresReset(c.Context(), nil); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -57,37 +57,6 @@ func StoresDeleteEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationCo
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func StoresGetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
|
||||||
return func(c *fiber.Ctx) error {
|
|
||||||
input := new(schema.StoresGet)
|
|
||||||
|
|
||||||
if err := c.BodyParser(input); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
sb, err := backend.StoreBackend(sl, appConfig, input.Store)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
keys, vals, err := store.GetCols(c.Context(), sb, input.Keys)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
res := schema.StoresGetResponse{
|
|
||||||
Keys: keys,
|
|
||||||
Values: make([]string, len(vals)),
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, v := range vals {
|
|
||||||
res.Values[i] = string(v)
|
|
||||||
}
|
|
||||||
|
|
||||||
return c.JSON(res)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func StoresFindEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
func StoresFindEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||||
return func(c *fiber.Ctx) error {
|
return func(c *fiber.Ctx) error {
|
||||||
input := new(schema.StoresFind)
|
input := new(schema.StoresFind)
|
||||||
|
|
|
@ -39,8 +39,7 @@ func RegisterLocalAIRoutes(router *fiber.App,
|
||||||
// Stores
|
// Stores
|
||||||
sl := model.NewModelLoader("")
|
sl := model.NewModelLoader("")
|
||||||
router.Post("/stores/set", localai.StoresSetEndpoint(sl, appConfig))
|
router.Post("/stores/set", localai.StoresSetEndpoint(sl, appConfig))
|
||||||
router.Post("/stores/delete", localai.StoresDeleteEndpoint(sl, appConfig))
|
router.Post("/stores/reset", localai.StoresDeleteEndpoint(sl, appConfig))
|
||||||
router.Post("/stores/get", localai.StoresGetEndpoint(sl, appConfig))
|
|
||||||
router.Post("/stores/find", localai.StoresFindEndpoint(sl, appConfig))
|
router.Post("/stores/find", localai.StoresFindEndpoint(sl, appConfig))
|
||||||
|
|
||||||
if !appConfig.DisableMetrics {
|
if !appConfig.DisableMetrics {
|
||||||
|
|
|
@ -47,21 +47,8 @@ type StoresSet struct {
|
||||||
Values []string `json:"values" yaml:"values"`
|
Values []string `json:"values" yaml:"values"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type StoresDelete struct {
|
type StoresReset struct {
|
||||||
Store string `json:"store,omitempty" yaml:"store,omitempty"`
|
Store string `json:"store,omitempty" yaml:"store,omitempty"`
|
||||||
|
|
||||||
Keys [][]float32 `json:"keys"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type StoresGet struct {
|
|
||||||
Store string `json:"store,omitempty" yaml:"store,omitempty"`
|
|
||||||
|
|
||||||
Keys [][]float32 `json:"keys" yaml:"keys"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type StoresGetResponse struct {
|
|
||||||
Keys [][]float32 `json:"keys" yaml:"keys"`
|
|
||||||
Values []string `json:"values" yaml:"values"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type StoresFind struct {
|
type StoresFind struct {
|
||||||
|
|
1
go.mod
1
go.mod
|
@ -93,6 +93,7 @@ require (
|
||||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
||||||
github.com/nikolalohinski/gonja/v2 v2.3.2 // indirect
|
github.com/nikolalohinski/gonja/v2 v2.3.2 // indirect
|
||||||
|
github.com/philippgille/chromem-go v0.7.0 // indirect
|
||||||
github.com/pion/datachannel v1.5.10 // indirect
|
github.com/pion/datachannel v1.5.10 // indirect
|
||||||
github.com/pion/dtls/v2 v2.2.12 // indirect
|
github.com/pion/dtls/v2 v2.2.12 // indirect
|
||||||
github.com/pion/ice/v2 v2.3.37 // indirect
|
github.com/pion/ice/v2 v2.3.37 // indirect
|
||||||
|
|
2
go.sum
2
go.sum
|
@ -611,6 +611,8 @@ github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5 h1:Ii+DKncOVM8Cu1H
|
||||||
github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5/go.mod h1:iIss55rKnNBTvrwdmkUpLnDpZoAHvWaiq5+iMmen4AE=
|
github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5/go.mod h1:iIss55rKnNBTvrwdmkUpLnDpZoAHvWaiq5+iMmen4AE=
|
||||||
github.com/philhofer/fwd v1.1.2 h1:bnDivRJ1EWPjUIRXV5KfORO897HTbpFAQddBdE8t7Gw=
|
github.com/philhofer/fwd v1.1.2 h1:bnDivRJ1EWPjUIRXV5KfORO897HTbpFAQddBdE8t7Gw=
|
||||||
github.com/philhofer/fwd v1.1.2/go.mod h1:qkPdfjR2SIEbspLqpe1tO4n5yICnr2DY7mqEx2tUTP0=
|
github.com/philhofer/fwd v1.1.2/go.mod h1:qkPdfjR2SIEbspLqpe1tO4n5yICnr2DY7mqEx2tUTP0=
|
||||||
|
github.com/philippgille/chromem-go v0.7.0 h1:4jfvfyKymjKNfGxBUhHUcj1kp7B17NL/I1P+vGh1RvY=
|
||||||
|
github.com/philippgille/chromem-go v0.7.0/go.mod h1:hTd+wGEm/fFPQl7ilfCwQXkgEUxceYh86iIdoKMolPo=
|
||||||
github.com/pierrec/lz4/v4 v4.1.2 h1:qvY3YFXRQE/XB8MlLzJH7mSzBs74eA2gg52YTk6jUPM=
|
github.com/pierrec/lz4/v4 v4.1.2 h1:qvY3YFXRQE/XB8MlLzJH7mSzBs74eA2gg52YTk6jUPM=
|
||||||
github.com/pierrec/lz4/v4 v4.1.2/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
|
github.com/pierrec/lz4/v4 v4.1.2/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
|
||||||
github.com/pion/datachannel v1.5.8 h1:ph1P1NsGkazkjrvyMfhRBUAWMxugJjq2HfQifaOoSNo=
|
github.com/pion/datachannel v1.5.8 h1:ph1P1NsGkazkjrvyMfhRBUAWMxugJjq2HfQifaOoSNo=
|
||||||
|
|
|
@ -46,8 +46,7 @@ type Backend interface {
|
||||||
Status(ctx context.Context) (*pb.StatusResponse, error)
|
Status(ctx context.Context) (*pb.StatusResponse, error)
|
||||||
|
|
||||||
StoresSet(ctx context.Context, in *pb.StoresSetOptions, opts ...grpc.CallOption) (*pb.Result, error)
|
StoresSet(ctx context.Context, in *pb.StoresSetOptions, opts ...grpc.CallOption) (*pb.Result, error)
|
||||||
StoresDelete(ctx context.Context, in *pb.StoresDeleteOptions, opts ...grpc.CallOption) (*pb.Result, error)
|
StoresReset(ctx context.Context, in *pb.StoresResetOptions, opts ...grpc.CallOption) (*pb.Result, error)
|
||||||
StoresGet(ctx context.Context, in *pb.StoresGetOptions, opts ...grpc.CallOption) (*pb.StoresGetResult, error)
|
|
||||||
StoresFind(ctx context.Context, in *pb.StoresFindOptions, opts ...grpc.CallOption) (*pb.StoresFindResult, error)
|
StoresFind(ctx context.Context, in *pb.StoresFindOptions, opts ...grpc.CallOption) (*pb.StoresFindResult, error)
|
||||||
|
|
||||||
Rerank(ctx context.Context, in *pb.RerankRequest, opts ...grpc.CallOption) (*pb.RerankResult, error)
|
Rerank(ctx context.Context, in *pb.RerankRequest, opts ...grpc.CallOption) (*pb.RerankResult, error)
|
||||||
|
|
|
@ -80,11 +80,7 @@ func (llm *Base) StoresSet(*pb.StoresSetOptions) error {
|
||||||
return fmt.Errorf("unimplemented")
|
return fmt.Errorf("unimplemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (llm *Base) StoresGet(*pb.StoresGetOptions) (pb.StoresGetResult, error) {
|
func (llm *Base) StoresReset(*pb.StoresResetOptions) error {
|
||||||
return pb.StoresGetResult{}, fmt.Errorf("unimplemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (llm *Base) StoresDelete(*pb.StoresDeleteOptions) error {
|
|
||||||
return fmt.Errorf("unimplemented")
|
return fmt.Errorf("unimplemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -303,7 +303,7 @@ func (c *Client) StoresSet(ctx context.Context, in *pb.StoresSetOptions, opts ..
|
||||||
return client.StoresSet(ctx, in, opts...)
|
return client.StoresSet(ctx, in, opts...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) StoresDelete(ctx context.Context, in *pb.StoresDeleteOptions, opts ...grpc.CallOption) (*pb.Result, error) {
|
func (c *Client) StoreReset(ctx context.Context, in *pb.StoresResetOptions, opts ...grpc.CallOption) (*pb.Result, error) {
|
||||||
if !c.parallel {
|
if !c.parallel {
|
||||||
c.opMutex.Lock()
|
c.opMutex.Lock()
|
||||||
defer c.opMutex.Unlock()
|
defer c.opMutex.Unlock()
|
||||||
|
@ -318,25 +318,7 @@ func (c *Client) StoresDelete(ctx context.Context, in *pb.StoresDeleteOptions, o
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
client := pb.NewBackendClient(conn)
|
client := pb.NewBackendClient(conn)
|
||||||
return client.StoresDelete(ctx, in, opts...)
|
return client.StoresReset(ctx, in, opts...)
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) StoresGet(ctx context.Context, in *pb.StoresGetOptions, opts ...grpc.CallOption) (*pb.StoresGetResult, error) {
|
|
||||||
if !c.parallel {
|
|
||||||
c.opMutex.Lock()
|
|
||||||
defer c.opMutex.Unlock()
|
|
||||||
}
|
|
||||||
c.setBusy(true)
|
|
||||||
defer c.setBusy(false)
|
|
||||||
c.wdMark()
|
|
||||||
defer c.wdUnMark()
|
|
||||||
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer conn.Close()
|
|
||||||
client := pb.NewBackendClient(conn)
|
|
||||||
return client.StoresGet(ctx, in, opts...)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) StoresFind(ctx context.Context, in *pb.StoresFindOptions, opts ...grpc.CallOption) (*pb.StoresFindResult, error) {
|
func (c *Client) StoresFind(ctx context.Context, in *pb.StoresFindOptions, opts ...grpc.CallOption) (*pb.StoresFindResult, error) {
|
||||||
|
|
|
@ -71,12 +71,8 @@ func (e *embedBackend) StoresSet(ctx context.Context, in *pb.StoresSetOptions, o
|
||||||
return e.s.StoresSet(ctx, in)
|
return e.s.StoresSet(ctx, in)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *embedBackend) StoresDelete(ctx context.Context, in *pb.StoresDeleteOptions, opts ...grpc.CallOption) (*pb.Result, error) {
|
func (e *embedBackend) StoresReset(ctx context.Context, in *pb.StoresResetOptions, opts ...grpc.CallOption) (*pb.Result, error) {
|
||||||
return e.s.StoresDelete(ctx, in)
|
return e.s.StoresReset(ctx, in)
|
||||||
}
|
|
||||||
|
|
||||||
func (e *embedBackend) StoresGet(ctx context.Context, in *pb.StoresGetOptions, opts ...grpc.CallOption) (*pb.StoresGetResult, error) {
|
|
||||||
return e.s.StoresGet(ctx, in)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *embedBackend) StoresFind(ctx context.Context, in *pb.StoresFindOptions, opts ...grpc.CallOption) (*pb.StoresFindResult, error) {
|
func (e *embedBackend) StoresFind(ctx context.Context, in *pb.StoresFindOptions, opts ...grpc.CallOption) (*pb.StoresFindResult, error) {
|
||||||
|
|
|
@ -21,8 +21,7 @@ type LLM interface {
|
||||||
Status() (pb.StatusResponse, error)
|
Status() (pb.StatusResponse, error)
|
||||||
|
|
||||||
StoresSet(*pb.StoresSetOptions) error
|
StoresSet(*pb.StoresSetOptions) error
|
||||||
StoresDelete(*pb.StoresDeleteOptions) error
|
StoresReset(*pb.StoresResetOptions) error
|
||||||
StoresGet(*pb.StoresGetOptions) (pb.StoresGetResult, error)
|
|
||||||
StoresFind(*pb.StoresFindOptions) (pb.StoresFindResult, error)
|
StoresFind(*pb.StoresFindOptions) (pb.StoresFindResult, error)
|
||||||
|
|
||||||
VAD(*pb.VADRequest) (pb.VADResponse, error)
|
VAD(*pb.VADRequest) (pb.VADResponse, error)
|
||||||
|
|
|
@ -191,28 +191,16 @@ func (s *server) StoresSet(ctx context.Context, in *pb.StoresSetOptions) (*pb.Re
|
||||||
return &pb.Result{Message: "Set key", Success: true}, nil
|
return &pb.Result{Message: "Set key", Success: true}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *server) StoresDelete(ctx context.Context, in *pb.StoresDeleteOptions) (*pb.Result, error) {
|
func (s *server) StoresReset(ctx context.Context, in *pb.StoresResetOptions) (*pb.Result, error) {
|
||||||
if s.llm.Locking() {
|
if s.llm.Locking() {
|
||||||
s.llm.Lock()
|
s.llm.Lock()
|
||||||
defer s.llm.Unlock()
|
defer s.llm.Unlock()
|
||||||
}
|
}
|
||||||
err := s.llm.StoresDelete(in)
|
err := s.llm.StoresReset(in)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &pb.Result{Message: fmt.Sprintf("Error deleting entry: %s", err.Error()), Success: false}, err
|
return &pb.Result{Message: fmt.Sprintf("Error deleting entry: %s", err.Error()), Success: false}, err
|
||||||
}
|
}
|
||||||
return &pb.Result{Message: "Deleted key", Success: true}, nil
|
return &pb.Result{Message: "Deleted mem db", Success: true}, nil
|
||||||
}
|
|
||||||
|
|
||||||
func (s *server) StoresGet(ctx context.Context, in *pb.StoresGetOptions) (*pb.StoresGetResult, error) {
|
|
||||||
if s.llm.Locking() {
|
|
||||||
s.llm.Lock()
|
|
||||||
defer s.llm.Unlock()
|
|
||||||
}
|
|
||||||
res, err := s.llm.StoresGet(in)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &res, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *server) StoresFind(ctx context.Context, in *pb.StoresFindOptions) (*pb.StoresFindResult, error) {
|
func (s *server) StoresFind(ctx context.Context, in *pb.StoresFindOptions) (*pb.StoresFindResult, error) {
|
||||||
|
|
|
@ -1,155 +0,0 @@
|
||||||
package store
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
|
||||||
"github.com/mudler/LocalAI/pkg/grpc/proto"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Wrapper for the GRPC client so that simple use cases are handled without verbosity
|
|
||||||
|
|
||||||
// SetCols sets multiple key-value pairs in the store
|
|
||||||
// It's in columnar format so that keys[i] is associated with values[i]
|
|
||||||
func SetCols(ctx context.Context, c grpc.Backend, keys [][]float32, values [][]byte) error {
|
|
||||||
protoKeys := make([]*proto.StoresKey, len(keys))
|
|
||||||
for i, k := range keys {
|
|
||||||
protoKeys[i] = &proto.StoresKey{
|
|
||||||
Floats: k,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
protoValues := make([]*proto.StoresValue, len(values))
|
|
||||||
for i, v := range values {
|
|
||||||
protoValues[i] = &proto.StoresValue{
|
|
||||||
Bytes: v,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
setOpts := &proto.StoresSetOptions{
|
|
||||||
Keys: protoKeys,
|
|
||||||
Values: protoValues,
|
|
||||||
}
|
|
||||||
|
|
||||||
res, err := c.StoresSet(ctx, setOpts)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if res.Success {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return fmt.Errorf("failed to set keys: %v", res.Message)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetSingle sets a single key-value pair in the store
|
|
||||||
// Don't call this in a tight loop, instead use SetCols
|
|
||||||
func SetSingle(ctx context.Context, c grpc.Backend, key []float32, value []byte) error {
|
|
||||||
return SetCols(ctx, c, [][]float32{key}, [][]byte{value})
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteCols deletes multiple key-value pairs from the store
|
|
||||||
// It's in columnar format so that keys[i] is associated with values[i]
|
|
||||||
func DeleteCols(ctx context.Context, c grpc.Backend, keys [][]float32) error {
|
|
||||||
protoKeys := make([]*proto.StoresKey, len(keys))
|
|
||||||
for i, k := range keys {
|
|
||||||
protoKeys[i] = &proto.StoresKey{
|
|
||||||
Floats: k,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
deleteOpts := &proto.StoresDeleteOptions{
|
|
||||||
Keys: protoKeys,
|
|
||||||
}
|
|
||||||
|
|
||||||
res, err := c.StoresDelete(ctx, deleteOpts)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if res.Success {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return fmt.Errorf("failed to delete keys: %v", res.Message)
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteSingle deletes a single key-value pair from the store
|
|
||||||
// Don't call this in a tight loop, instead use DeleteCols
|
|
||||||
func DeleteSingle(ctx context.Context, c grpc.Backend, key []float32) error {
|
|
||||||
return DeleteCols(ctx, c, [][]float32{key})
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetCols gets multiple key-value pairs from the store
|
|
||||||
// It's in columnar format so that keys[i] is associated with values[i]
|
|
||||||
// Be warned the keys are sorted and will be returned in a different order than they were input
|
|
||||||
// There is no guarantee as to how the keys are sorted
|
|
||||||
func GetCols(ctx context.Context, c grpc.Backend, keys [][]float32) ([][]float32, [][]byte, error) {
|
|
||||||
protoKeys := make([]*proto.StoresKey, len(keys))
|
|
||||||
for i, k := range keys {
|
|
||||||
protoKeys[i] = &proto.StoresKey{
|
|
||||||
Floats: k,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
getOpts := &proto.StoresGetOptions{
|
|
||||||
Keys: protoKeys,
|
|
||||||
}
|
|
||||||
|
|
||||||
res, err := c.StoresGet(ctx, getOpts)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
ks := make([][]float32, len(res.Keys))
|
|
||||||
for i, k := range res.Keys {
|
|
||||||
ks[i] = k.Floats
|
|
||||||
}
|
|
||||||
vs := make([][]byte, len(res.Values))
|
|
||||||
for i, v := range res.Values {
|
|
||||||
vs[i] = v.Bytes
|
|
||||||
}
|
|
||||||
|
|
||||||
return ks, vs, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetSingle gets a single key-value pair from the store
|
|
||||||
// Don't call this in a tight loop, instead use GetCols
|
|
||||||
func GetSingle(ctx context.Context, c grpc.Backend, key []float32) ([]byte, error) {
|
|
||||||
_, values, err := GetCols(ctx, c, [][]float32{key})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(values) > 0 {
|
|
||||||
return values[0], nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, fmt.Errorf("failed to get key")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Find similar keys to the given key. Returns the keys, values, and similarities
|
|
||||||
func Find(ctx context.Context, c grpc.Backend, key []float32, topk int) ([][]float32, [][]byte, []float32, error) {
|
|
||||||
findOpts := &proto.StoresFindOptions{
|
|
||||||
Key: &proto.StoresKey{
|
|
||||||
Floats: key,
|
|
||||||
},
|
|
||||||
TopK: int32(topk),
|
|
||||||
}
|
|
||||||
|
|
||||||
res, err := c.StoresFind(ctx, findOpts)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
ks := make([][]float32, len(res.Keys))
|
|
||||||
vs := make([][]byte, len(res.Values))
|
|
||||||
|
|
||||||
for i, k := range res.Keys {
|
|
||||||
ks[i] = k.Floats
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, v := range res.Values {
|
|
||||||
vs[i] = v.Bytes
|
|
||||||
}
|
|
||||||
|
|
||||||
return ks, vs, res.Similarities, nil
|
|
||||||
}
|
|
|
@ -70,6 +70,10 @@ var _ = Describe("Integration tests for the stores backend(s) and internal APIs"
|
||||||
})
|
})
|
||||||
|
|
||||||
It("should be able to set a key", func() {
|
It("should be able to set a key", func() {
|
||||||
|
sc.StoresSet(context.Background(), &store.StoresSetOptions{
|
||||||
|
Keys: [][]float32{{0.1, 0.2, 0.3}},
|
||||||
|
Values: [][]byte{[]byte("test")},
|
||||||
|
})
|
||||||
err := store.SetSingle(context.Background(), sc, []float32{0.1, 0.2, 0.3}, []byte("test"))
|
err := store.SetSingle(context.Background(), sc, []float32{0.1, 0.2, 0.3}, []byte("test"))
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
})
|
})
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue