mirror of
https://github.com/mudler/LocalAI.git
synced 2025-05-20 10:35:01 +00:00
feat: move other backends to grpc
This finally makes everything more consistent Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
parent
5dcfdbe51d
commit
1d0ed95a54
54 changed files with 3171 additions and 1712 deletions
233
api/api_test.go
233
api/api_test.go
|
@ -5,7 +5,9 @@ import (
|
|||
"context"
|
||||
"embed"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"os"
|
||||
|
@ -24,6 +26,7 @@ import (
|
|||
|
||||
openaigo "github.com/otiai10/openaigo"
|
||||
"github.com/sashabaranov/go-openai"
|
||||
"github.com/sashabaranov/go-openai/jsonschema"
|
||||
)
|
||||
|
||||
type modelApplyRequest struct {
|
||||
|
@ -203,7 +206,7 @@ var _ = Describe("API test", func() {
|
|||
fmt.Println(response)
|
||||
resp = response
|
||||
return response["processed"].(bool)
|
||||
}, "360s").Should(Equal(true))
|
||||
}, "360s", "10s").Should(Equal(true))
|
||||
Expect(resp["message"]).ToNot(ContainSubstring("error"))
|
||||
|
||||
dat, err := os.ReadFile(filepath.Join(tmpdir, "bert2.yaml"))
|
||||
|
@ -245,9 +248,8 @@ var _ = Describe("API test", func() {
|
|||
|
||||
Eventually(func() bool {
|
||||
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
|
||||
fmt.Println(response)
|
||||
return response["processed"].(bool)
|
||||
}, "360s").Should(Equal(true))
|
||||
}, "360s", "10s").Should(Equal(true))
|
||||
|
||||
dat, err := os.ReadFile(filepath.Join(tmpdir, "bert.yaml"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -270,9 +272,8 @@ var _ = Describe("API test", func() {
|
|||
|
||||
Eventually(func() bool {
|
||||
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
|
||||
fmt.Println(response)
|
||||
return response["processed"].(bool)
|
||||
}, "360s").Should(Equal(true))
|
||||
}, "360s", "10s").Should(Equal(true))
|
||||
|
||||
dat, err := os.ReadFile(filepath.Join(tmpdir, "bert.yaml"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -299,14 +300,58 @@ var _ = Describe("API test", func() {
|
|||
|
||||
Eventually(func() bool {
|
||||
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
|
||||
fmt.Println(response)
|
||||
return response["processed"].(bool)
|
||||
}, "360s").Should(Equal(true))
|
||||
}, "360s", "10s").Should(Equal(true))
|
||||
|
||||
By("testing completion")
|
||||
resp, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "openllama_3b", Prompt: "Count up to five: one, two, three, four, "})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(len(resp.Choices)).To(Equal(1))
|
||||
Expect(resp.Choices[0].Text).To(ContainSubstring("five"))
|
||||
|
||||
By("testing functions")
|
||||
resp2, err := client.CreateChatCompletion(
|
||||
context.TODO(),
|
||||
openai.ChatCompletionRequest{
|
||||
Model: "openllama_3b",
|
||||
Messages: []openai.ChatCompletionMessage{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "What is the weather like in San Francisco (celsius)?",
|
||||
},
|
||||
},
|
||||
Functions: []openai.FunctionDefinition{
|
||||
openai.FunctionDefinition{
|
||||
Name: "get_current_weather",
|
||||
Description: "Get the current weather",
|
||||
Parameters: jsonschema.Definition{
|
||||
Type: jsonschema.Object,
|
||||
Properties: map[string]jsonschema.Definition{
|
||||
"location": {
|
||||
Type: jsonschema.String,
|
||||
Description: "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"unit": {
|
||||
Type: jsonschema.String,
|
||||
Enum: []string{"celcius", "fahrenheit"},
|
||||
},
|
||||
},
|
||||
Required: []string{"location"},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(len(resp2.Choices)).To(Equal(1))
|
||||
Expect(resp2.Choices[0].Message.FunctionCall).ToNot(BeNil())
|
||||
Expect(resp2.Choices[0].Message.FunctionCall.Name).To(Equal("get_current_weather"), resp2.Choices[0].Message.FunctionCall.Name)
|
||||
|
||||
var res map[string]string
|
||||
err = json.Unmarshal([]byte(resp2.Choices[0].Message.FunctionCall.Arguments), &res)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(res["location"]).To(Equal("San Francisco"), fmt.Sprint(res))
|
||||
Expect(res["unit"]).To(Equal("celcius"), fmt.Sprint(res))
|
||||
Expect(string(resp2.Choices[0].FinishReason)).To(Equal("function_call"), fmt.Sprint(resp2.Choices[0].FinishReason))
|
||||
})
|
||||
|
||||
It("runs gpt4all", Label("gpt4all"), func() {
|
||||
|
@ -326,15 +371,126 @@ var _ = Describe("API test", func() {
|
|||
|
||||
Eventually(func() bool {
|
||||
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
|
||||
fmt.Println(response)
|
||||
return response["processed"].(bool)
|
||||
}, "360s").Should(Equal(true))
|
||||
}, "360s", "10s").Should(Equal(true))
|
||||
|
||||
resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "gpt4all-j", Messages: []openai.ChatCompletionMessage{openai.ChatCompletionMessage{Role: "user", Content: "How are you?"}}})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(len(resp.Choices)).To(Equal(1))
|
||||
Expect(resp.Choices[0].Message.Content).To(ContainSubstring("well"))
|
||||
})
|
||||
|
||||
})
|
||||
})
|
||||
|
||||
Context("Model gallery", func() {
|
||||
BeforeEach(func() {
|
||||
var err error
|
||||
tmpdir, err = os.MkdirTemp("", "")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
modelLoader = model.NewModelLoader(tmpdir)
|
||||
c, cancel = context.WithCancel(context.Background())
|
||||
|
||||
galleries := []gallery.Gallery{
|
||||
{
|
||||
Name: "model-gallery",
|
||||
URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/index.yaml",
|
||||
},
|
||||
}
|
||||
|
||||
app, err = App(
|
||||
options.WithContext(c),
|
||||
options.WithAudioDir(tmpdir),
|
||||
options.WithImageDir(tmpdir),
|
||||
options.WithGalleries(galleries),
|
||||
options.WithModelLoader(modelLoader),
|
||||
options.WithBackendAssets(backendAssets),
|
||||
options.WithBackendAssetsOutput(tmpdir),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
go app.Listen("127.0.0.1:9090")
|
||||
|
||||
defaultConfig := openai.DefaultConfig("")
|
||||
defaultConfig.BaseURL = "http://127.0.0.1:9090/v1"
|
||||
|
||||
client2 = openaigo.NewClient("")
|
||||
client2.BaseURL = defaultConfig.BaseURL
|
||||
|
||||
// Wait for API to be ready
|
||||
client = openai.NewClientWithConfig(defaultConfig)
|
||||
Eventually(func() error {
|
||||
_, err := client.ListModels(context.TODO())
|
||||
return err
|
||||
}, "2m").ShouldNot(HaveOccurred())
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
cancel()
|
||||
app.Shutdown()
|
||||
os.RemoveAll(tmpdir)
|
||||
})
|
||||
It("installs and is capable to run tts", Label("tts"), func() {
|
||||
if runtime.GOOS != "linux" {
|
||||
Skip("test supported only on linux")
|
||||
}
|
||||
|
||||
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
|
||||
ID: "model-gallery@voice-en-us-kathleen-low",
|
||||
})
|
||||
|
||||
Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response))
|
||||
|
||||
uuid := response["uuid"].(string)
|
||||
|
||||
Eventually(func() bool {
|
||||
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
|
||||
fmt.Println(response)
|
||||
return response["processed"].(bool)
|
||||
}, "360s", "10s").Should(Equal(true))
|
||||
|
||||
// An HTTP Post to the /tts endpoint should return a wav audio file
|
||||
resp, err := http.Post("http://127.0.0.1:9090/tts", "application/json", bytes.NewBuffer([]byte(`{"input": "Hello world", "model": "en-us-kathleen-low.onnx"}`)))
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprint(resp))
|
||||
dat, err := io.ReadAll(resp.Body)
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprint(resp))
|
||||
|
||||
Expect(resp.StatusCode).To(Equal(200), fmt.Sprint(string(dat)))
|
||||
Expect(resp.Header.Get("Content-Type")).To(Equal("audio/x-wav"))
|
||||
})
|
||||
It("installs and is capable to generate images", Label("stablediffusion"), func() {
|
||||
if runtime.GOOS != "linux" {
|
||||
Skip("test supported only on linux")
|
||||
}
|
||||
|
||||
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
|
||||
ID: "model-gallery@stablediffusion",
|
||||
})
|
||||
|
||||
Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response))
|
||||
|
||||
uuid := response["uuid"].(string)
|
||||
|
||||
Eventually(func() bool {
|
||||
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
|
||||
fmt.Println(response)
|
||||
return response["processed"].(bool)
|
||||
}, "360s", "10s").Should(Equal(true))
|
||||
|
||||
resp, err := http.Post(
|
||||
"http://127.0.0.1:9090/v1/images/generations",
|
||||
"application/json",
|
||||
bytes.NewBuffer([]byte(`{
|
||||
"prompt": "floating hair, portrait, ((loli)), ((one girl)), cute face, hidden hands, asymmetrical bangs, beautiful detailed eyes, eye shadow, hair ornament, ribbons, bowties, buttons, pleated skirt, (((masterpiece))), ((best quality)), colorful|((part of the head)), ((((mutated hands and fingers)))), deformed, blurry, bad anatomy, disfigured, poorly drawn face, mutation, mutated, extra limb, ugly, poorly drawn hands, missing limb, blurry, floating limbs, disconnected limbs, malformed hands, blur, out of focus, long neck, long body, Octane renderer, lowres, bad anatomy, bad hands, text",
|
||||
"mode": 2, "seed":9000,
|
||||
"size": "256x256", "n":2}`)))
|
||||
// The response should contain an URL
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprint(resp))
|
||||
dat, err := io.ReadAll(resp.Body)
|
||||
Expect(err).ToNot(HaveOccurred(), string(dat))
|
||||
Expect(string(dat)).To(ContainSubstring("http://127.0.0.1:9090/"), string(dat))
|
||||
Expect(string(dat)).To(ContainSubstring(".png"), string(dat))
|
||||
|
||||
})
|
||||
})
|
||||
|
||||
|
@ -401,7 +557,7 @@ var _ = Describe("API test", func() {
|
|||
It("returns errors", func() {
|
||||
_, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "foomodel", Prompt: "abcdedfghikl"})
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("error, status code: 500, message: could not load model - all backends returned error: 11 errors occurred:"))
|
||||
Expect(err.Error()).To(ContainSubstring("error, status code: 500, message: could not load model - all backends returned error: 12 errors occurred:"))
|
||||
})
|
||||
It("transcribes audio", func() {
|
||||
if runtime.GOOS != "linux" {
|
||||
|
@ -446,14 +602,67 @@ var _ = Describe("API test", func() {
|
|||
})
|
||||
|
||||
Context("backends", func() {
|
||||
It("runs rwkv", func() {
|
||||
It("runs rwkv completion", func() {
|
||||
if runtime.GOOS != "linux" {
|
||||
Skip("test supported only on linux")
|
||||
}
|
||||
resp, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "rwkv_test", Prompt: "Count up to five: one, two, three, four,"})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(len(resp.Choices) > 0).To(BeTrue())
|
||||
Expect(resp.Choices[0].Text).To(Equal(" five."))
|
||||
Expect(resp.Choices[0].Text).To(ContainSubstring("five"))
|
||||
|
||||
stream, err := client.CreateCompletionStream(context.TODO(), openai.CompletionRequest{
|
||||
Model: "rwkv_test", Prompt: "Count up to five: one, two, three, four,", Stream: true,
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer stream.Close()
|
||||
|
||||
tokens := 0
|
||||
text := ""
|
||||
for {
|
||||
response, err := stream.Recv()
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
text += response.Choices[0].Text
|
||||
tokens++
|
||||
}
|
||||
Expect(text).ToNot(BeEmpty())
|
||||
Expect(text).To(ContainSubstring("five"))
|
||||
Expect(tokens).ToNot(Or(Equal(1), Equal(0)))
|
||||
})
|
||||
It("runs rwkv chat completion", func() {
|
||||
if runtime.GOOS != "linux" {
|
||||
Skip("test supported only on linux")
|
||||
}
|
||||
resp, err := client.CreateChatCompletion(context.TODO(),
|
||||
openai.ChatCompletionRequest{Model: "rwkv_test", Messages: []openai.ChatCompletionMessage{{Content: "Can you count up to five?", Role: "user"}}})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(len(resp.Choices) > 0).To(BeTrue())
|
||||
Expect(resp.Choices[0].Message.Content).To(Or(ContainSubstring("Sure"), ContainSubstring("five")))
|
||||
|
||||
stream, err := client.CreateChatCompletionStream(context.TODO(), openai.ChatCompletionRequest{Model: "rwkv_test", Messages: []openai.ChatCompletionMessage{{Content: "Can you count up to five?", Role: "user"}}})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer stream.Close()
|
||||
|
||||
tokens := 0
|
||||
text := ""
|
||||
for {
|
||||
response, err := stream.Recv()
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
text += response.Choices[0].Delta.Content
|
||||
tokens++
|
||||
}
|
||||
Expect(text).ToNot(BeEmpty())
|
||||
Expect(text).To(Or(ContainSubstring("Sure"), ContainSubstring("five")))
|
||||
|
||||
Expect(tokens).ToNot(Or(Equal(1), Equal(0)))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue