feat(stores): Vector store backend (#1795)

Add simple vector store backend

Signed-off-by: Richard Palethorpe <io@richiejp.com>
This commit is contained in:
Richard Palethorpe 2024-03-22 20:14:04 +00:00 committed by GitHub
parent 4b1ee0c170
commit 643d85d2cc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
30 changed files with 3250 additions and 441 deletions

View file

@ -172,6 +172,13 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi
// Elevenlabs
app.Post("/v1/text-to-speech/:voice-id", auth, elevenlabs.TTSEndpoint(cl, ml, appConfig))
// Stores
sl := model.NewModelLoader("")
app.Post("/stores/set", auth, localai.StoresSetEndpoint(sl, appConfig))
app.Post("/stores/delete", auth, localai.StoresDeleteEndpoint(sl, appConfig))
app.Post("/stores/get", auth, localai.StoresGetEndpoint(sl, appConfig))
app.Post("/stores/find", auth, localai.StoresFindEndpoint(sl, appConfig))
// openAI compatible API endpoint
// chat

View file

@ -15,6 +15,7 @@ import (
"github.com/go-skynet/LocalAI/core/config"
. "github.com/go-skynet/LocalAI/core/http"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/go-skynet/LocalAI/core/startup"
"github.com/go-skynet/LocalAI/pkg/downloader"
@ -122,6 +123,75 @@ func postModelApplyRequest(url string, request modelApplyRequest) (response map[
return
}
func postRequestJSON[B any](url string, bodyJson *B) error {
payload, err := json.Marshal(bodyJson)
if err != nil {
return err
}
GinkgoWriter.Printf("POST %s: %s\n", url, string(payload))
req, err := http.NewRequest("POST", url, bytes.NewBuffer(payload))
if err != nil {
return err
}
req.Header.Set("Content-Type", "application/json")
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return err
}
if resp.StatusCode < 200 || resp.StatusCode >= 400 {
return fmt.Errorf("unexpected status code: %d, body: %s", resp.StatusCode, string(body))
}
return nil
}
func postRequestResponseJSON[B1 any, B2 any](url string, reqJson *B1, respJson *B2) error {
payload, err := json.Marshal(reqJson)
if err != nil {
return err
}
GinkgoWriter.Printf("POST %s: %s\n", url, string(payload))
req, err := http.NewRequest("POST", url, bytes.NewBuffer(payload))
if err != nil {
return err
}
req.Header.Set("Content-Type", "application/json")
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return err
}
if resp.StatusCode < 200 || resp.StatusCode >= 400 {
return fmt.Errorf("unexpected status code: %d, body: %s", resp.StatusCode, string(body))
}
return json.Unmarshal(body, respJson)
}
//go:embed backend-assets/*
var backendAssets embed.FS
@ -836,6 +906,78 @@ var _ = Describe("API test", func() {
Expect(tokens).ToNot(Or(Equal(1), Equal(0)))
})
})
// See tests/integration/stores_test
Context("Stores", Label("stores"), func() {
It("sets, gets, finds and deletes entries", func() {
ks := [][]float32{
{0.1, 0.2, 0.3},
{0.4, 0.5, 0.6},
{0.7, 0.8, 0.9},
}
vs := []string{
"test1",
"test2",
"test3",
}
setBody := schema.StoresSet{
Keys: ks,
Values: vs,
}
url := "http://127.0.0.1:9090/stores/"
err := postRequestJSON(url+"set", &setBody)
Expect(err).ToNot(HaveOccurred())
getBody := schema.StoresGet{
Keys: ks,
}
var getRespBody schema.StoresGetResponse
err = postRequestResponseJSON(url+"get", &getBody, &getRespBody)
Expect(err).ToNot(HaveOccurred())
Expect(len(getRespBody.Keys)).To(Equal(len(ks)))
for i, v := range getRespBody.Keys {
if v[0] == 0.1 {
Expect(getRespBody.Values[i]).To(Equal("test1"))
} else if v[0] == 0.4 {
Expect(getRespBody.Values[i]).To(Equal("test2"))
} else {
Expect(getRespBody.Values[i]).To(Equal("test3"))
}
}
deleteBody := schema.StoresDelete{
Keys: [][]float32{
{0.1, 0.2, 0.3},
},
}
err = postRequestJSON(url+"delete", &deleteBody)
Expect(err).ToNot(HaveOccurred())
findBody := schema.StoresFind{
Key: []float32{0.1, 0.3, 0.7},
Topk: 10,
}
var findRespBody schema.StoresFindResponse
err = postRequestResponseJSON(url+"find", &findBody, &findRespBody)
Expect(err).ToNot(HaveOccurred())
Expect(len(findRespBody.Keys)).To(Equal(2))
for i, v := range findRespBody.Keys {
if v[0] == 0.4 {
Expect(findRespBody.Values[i]).To(Equal("test2"))
} else {
Expect(findRespBody.Values[i]).To(Equal("test3"))
}
Expect(findRespBody.Similarities[i]).To(BeNumerically(">=", -1))
Expect(findRespBody.Similarities[i]).To(BeNumerically("<=", 1))
}
})
})
})
Context("Config file", func() {

View file

@ -0,0 +1,121 @@
package localai
import (
"github.com/go-skynet/LocalAI/core/backend"
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/store"
"github.com/gofiber/fiber/v2"
)
func StoresSetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(schema.StoresSet)
if err := c.BodyParser(input); err != nil {
return err
}
sb, err := backend.StoreBackend(sl, appConfig, input.Store)
if err != nil {
return err
}
vals := make([][]byte, len(input.Values))
for i, v := range input.Values {
vals[i] = []byte(v)
}
err = store.SetCols(c.Context(), sb, input.Keys, vals)
if err != nil {
return err
}
return c.Send(nil)
}
}
func StoresDeleteEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(schema.StoresDelete)
if err := c.BodyParser(input); err != nil {
return err
}
sb, err := backend.StoreBackend(sl, appConfig, input.Store)
if err != nil {
return err
}
if err := store.DeleteCols(c.Context(), sb, input.Keys); err != nil {
return err
}
return c.Send(nil)
}
}
func StoresGetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(schema.StoresGet)
if err := c.BodyParser(input); err != nil {
return err
}
sb, err := backend.StoreBackend(sl, appConfig, input.Store)
if err != nil {
return err
}
keys, vals, err := store.GetCols(c.Context(), sb, input.Keys)
if err != nil {
return err
}
res := schema.StoresGetResponse{
Keys: keys,
Values: make([]string, len(vals)),
}
for i, v := range vals {
res.Values[i] = string(v)
}
return c.JSON(res)
}
}
func StoresFindEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(schema.StoresFind)
if err := c.BodyParser(input); err != nil {
return err
}
sb, err := backend.StoreBackend(sl, appConfig, input.Store)
if err != nil {
return err
}
keys, vals, similarities, err := store.Find(c.Context(), sb, input.Key, input.Topk)
if err != nil {
return err
}
res := schema.StoresFindResponse{
Keys: keys,
Values: make([]string, len(vals)),
Similarities: similarities,
}
for i, v := range vals {
res.Values[i] = string(v)
}
return c.JSON(res)
}
}