mirror of
https://github.com/mudler/LocalAI.git
synced 2025-05-20 02:24:59 +00:00
feat(stores): Vector store backend (#1795)
Add simple vector store backend Signed-off-by: Richard Palethorpe <io@richiejp.com>
This commit is contained in:
parent
4b1ee0c170
commit
643d85d2cc
30 changed files with 3250 additions and 441 deletions
|
@ -15,6 +15,7 @@ import (
|
|||
|
||||
"github.com/go-skynet/LocalAI/core/config"
|
||||
. "github.com/go-skynet/LocalAI/core/http"
|
||||
"github.com/go-skynet/LocalAI/core/schema"
|
||||
"github.com/go-skynet/LocalAI/core/startup"
|
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/downloader"
|
||||
|
@ -122,6 +123,75 @@ func postModelApplyRequest(url string, request modelApplyRequest) (response map[
|
|||
return
|
||||
}
|
||||
|
||||
func postRequestJSON[B any](url string, bodyJson *B) error {
|
||||
payload, err := json.Marshal(bodyJson)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
GinkgoWriter.Printf("POST %s: %s\n", url, string(payload))
|
||||
|
||||
req, err := http.NewRequest("POST", url, bytes.NewBuffer(payload))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 400 {
|
||||
return fmt.Errorf("unexpected status code: %d, body: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func postRequestResponseJSON[B1 any, B2 any](url string, reqJson *B1, respJson *B2) error {
|
||||
payload, err := json.Marshal(reqJson)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
GinkgoWriter.Printf("POST %s: %s\n", url, string(payload))
|
||||
|
||||
req, err := http.NewRequest("POST", url, bytes.NewBuffer(payload))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 400 {
|
||||
return fmt.Errorf("unexpected status code: %d, body: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
return json.Unmarshal(body, respJson)
|
||||
}
|
||||
|
||||
//go:embed backend-assets/*
|
||||
var backendAssets embed.FS
|
||||
|
||||
|
@ -836,6 +906,78 @@ var _ = Describe("API test", func() {
|
|||
Expect(tokens).ToNot(Or(Equal(1), Equal(0)))
|
||||
})
|
||||
})
|
||||
|
||||
// See tests/integration/stores_test
|
||||
Context("Stores", Label("stores"), func() {
|
||||
|
||||
It("sets, gets, finds and deletes entries", func() {
|
||||
ks := [][]float32{
|
||||
{0.1, 0.2, 0.3},
|
||||
{0.4, 0.5, 0.6},
|
||||
{0.7, 0.8, 0.9},
|
||||
}
|
||||
vs := []string{
|
||||
"test1",
|
||||
"test2",
|
||||
"test3",
|
||||
}
|
||||
setBody := schema.StoresSet{
|
||||
Keys: ks,
|
||||
Values: vs,
|
||||
}
|
||||
|
||||
url := "http://127.0.0.1:9090/stores/"
|
||||
err := postRequestJSON(url+"set", &setBody)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
getBody := schema.StoresGet{
|
||||
Keys: ks,
|
||||
}
|
||||
var getRespBody schema.StoresGetResponse
|
||||
err = postRequestResponseJSON(url+"get", &getBody, &getRespBody)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(len(getRespBody.Keys)).To(Equal(len(ks)))
|
||||
|
||||
for i, v := range getRespBody.Keys {
|
||||
if v[0] == 0.1 {
|
||||
Expect(getRespBody.Values[i]).To(Equal("test1"))
|
||||
} else if v[0] == 0.4 {
|
||||
Expect(getRespBody.Values[i]).To(Equal("test2"))
|
||||
} else {
|
||||
Expect(getRespBody.Values[i]).To(Equal("test3"))
|
||||
}
|
||||
}
|
||||
|
||||
deleteBody := schema.StoresDelete{
|
||||
Keys: [][]float32{
|
||||
{0.1, 0.2, 0.3},
|
||||
},
|
||||
}
|
||||
err = postRequestJSON(url+"delete", &deleteBody)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
findBody := schema.StoresFind{
|
||||
Key: []float32{0.1, 0.3, 0.7},
|
||||
Topk: 10,
|
||||
}
|
||||
|
||||
var findRespBody schema.StoresFindResponse
|
||||
err = postRequestResponseJSON(url+"find", &findBody, &findRespBody)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(len(findRespBody.Keys)).To(Equal(2))
|
||||
|
||||
for i, v := range findRespBody.Keys {
|
||||
if v[0] == 0.4 {
|
||||
Expect(findRespBody.Values[i]).To(Equal("test2"))
|
||||
} else {
|
||||
Expect(findRespBody.Values[i]).To(Equal("test3"))
|
||||
}
|
||||
|
||||
Expect(findRespBody.Similarities[i]).To(BeNumerically(">=", -1))
|
||||
Expect(findRespBody.Similarities[i]).To(BeNumerically("<=", 1))
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Context("Config file", func() {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue