feat(stores): Vector store backend (#1795)

Add simple vector store backend

Signed-off-by: Richard Palethorpe <io@richiejp.com>
This commit is contained in:
Richard Palethorpe 2024-03-22 20:14:04 +00:00 committed by GitHub
parent 4b1ee0c170
commit 643d85d2cc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
30 changed files with 3250 additions and 441 deletions

View file

@ -15,6 +15,7 @@ import (
"github.com/go-skynet/LocalAI/core/config"
. "github.com/go-skynet/LocalAI/core/http"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/go-skynet/LocalAI/core/startup"
"github.com/go-skynet/LocalAI/pkg/downloader"
@ -122,6 +123,75 @@ func postModelApplyRequest(url string, request modelApplyRequest) (response map[
return
}
func postRequestJSON[B any](url string, bodyJson *B) error {
payload, err := json.Marshal(bodyJson)
if err != nil {
return err
}
GinkgoWriter.Printf("POST %s: %s\n", url, string(payload))
req, err := http.NewRequest("POST", url, bytes.NewBuffer(payload))
if err != nil {
return err
}
req.Header.Set("Content-Type", "application/json")
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return err
}
if resp.StatusCode < 200 || resp.StatusCode >= 400 {
return fmt.Errorf("unexpected status code: %d, body: %s", resp.StatusCode, string(body))
}
return nil
}
func postRequestResponseJSON[B1 any, B2 any](url string, reqJson *B1, respJson *B2) error {
payload, err := json.Marshal(reqJson)
if err != nil {
return err
}
GinkgoWriter.Printf("POST %s: %s\n", url, string(payload))
req, err := http.NewRequest("POST", url, bytes.NewBuffer(payload))
if err != nil {
return err
}
req.Header.Set("Content-Type", "application/json")
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return err
}
if resp.StatusCode < 200 || resp.StatusCode >= 400 {
return fmt.Errorf("unexpected status code: %d, body: %s", resp.StatusCode, string(body))
}
return json.Unmarshal(body, respJson)
}
//go:embed backend-assets/*
var backendAssets embed.FS
@ -836,6 +906,78 @@ var _ = Describe("API test", func() {
Expect(tokens).ToNot(Or(Equal(1), Equal(0)))
})
})
// See tests/integration/stores_test
Context("Stores", Label("stores"), func() {
It("sets, gets, finds and deletes entries", func() {
ks := [][]float32{
{0.1, 0.2, 0.3},
{0.4, 0.5, 0.6},
{0.7, 0.8, 0.9},
}
vs := []string{
"test1",
"test2",
"test3",
}
setBody := schema.StoresSet{
Keys: ks,
Values: vs,
}
url := "http://127.0.0.1:9090/stores/"
err := postRequestJSON(url+"set", &setBody)
Expect(err).ToNot(HaveOccurred())
getBody := schema.StoresGet{
Keys: ks,
}
var getRespBody schema.StoresGetResponse
err = postRequestResponseJSON(url+"get", &getBody, &getRespBody)
Expect(err).ToNot(HaveOccurred())
Expect(len(getRespBody.Keys)).To(Equal(len(ks)))
for i, v := range getRespBody.Keys {
if v[0] == 0.1 {
Expect(getRespBody.Values[i]).To(Equal("test1"))
} else if v[0] == 0.4 {
Expect(getRespBody.Values[i]).To(Equal("test2"))
} else {
Expect(getRespBody.Values[i]).To(Equal("test3"))
}
}
deleteBody := schema.StoresDelete{
Keys: [][]float32{
{0.1, 0.2, 0.3},
},
}
err = postRequestJSON(url+"delete", &deleteBody)
Expect(err).ToNot(HaveOccurred())
findBody := schema.StoresFind{
Key: []float32{0.1, 0.3, 0.7},
Topk: 10,
}
var findRespBody schema.StoresFindResponse
err = postRequestResponseJSON(url+"find", &findBody, &findRespBody)
Expect(err).ToNot(HaveOccurred())
Expect(len(findRespBody.Keys)).To(Equal(2))
for i, v := range findRespBody.Keys {
if v[0] == 0.4 {
Expect(findRespBody.Values[i]).To(Equal("test2"))
} else {
Expect(findRespBody.Values[i]).To(Equal("test3"))
}
Expect(findRespBody.Similarities[i]).To(BeNumerically(">=", -1))
Expect(findRespBody.Similarities[i]).To(BeNumerically("<=", 1))
}
})
})
})
Context("Config file", func() {