feat(aio): add tests, update model definitions (#1880)

This commit is contained in:
Ettore Di Giacinto 2024-03-22 21:13:11 +01:00 committed by GitHub
parent 3bec467a91
commit 4b1ee0c170
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 461 additions and 40 deletions

View file

@ -0,0 +1,97 @@
package e2e_test
import (
"context"
"fmt"
"os"
"runtime"
"testing"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/ory/dockertest/v3"
"github.com/ory/dockertest/v3/docker"
"github.com/sashabaranov/go-openai"
)
var pool *dockertest.Pool
var resource *dockertest.Resource
var client *openai.Client
var containerImage = os.Getenv("LOCALAI_IMAGE")
var containerImageTag = os.Getenv("LOCALAI_IMAGE_TAG")
var modelsDir = os.Getenv("LOCALAI_MODELS_DIR")
var apiPort = os.Getenv("LOCALAI_API_PORT")
func TestLocalAI(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "LocalAI E2E test suite")
}
var _ = BeforeSuite(func() {
if containerImage == "" {
Fail("LOCALAI_IMAGE is not set")
}
if containerImageTag == "" {
Fail("LOCALAI_IMAGE_TAG is not set")
}
if apiPort == "" {
apiPort = "8080"
}
p, err := dockertest.NewPool("")
Expect(err).To(Not(HaveOccurred()))
Expect(p.Client.Ping()).To(Succeed())
pool = p
// get cwd
cwd, err := os.Getwd()
Expect(err).To(Not(HaveOccurred()))
md := cwd + "/models"
if modelsDir != "" {
md = modelsDir
}
proc := runtime.NumCPU()
options := &dockertest.RunOptions{
Repository: containerImage,
Tag: containerImageTag,
// Cmd: []string{"server", "/data"},
PortBindings: map[docker.Port][]docker.PortBinding{
"8080/tcp": []docker.PortBinding{{HostPort: apiPort}},
},
Env: []string{"MODELS_PATH=/models", "DEBUG=true", "THREADS=" + fmt.Sprint(proc)},
Mounts: []string{md + ":/models"},
}
r, err := pool.RunWithOptions(options)
Expect(err).To(Not(HaveOccurred()))
resource = r
defaultConfig := openai.DefaultConfig("")
defaultConfig.BaseURL = "http://localhost:" + apiPort + "/v1"
// Wait for API to be ready
client = openai.NewClientWithConfig(defaultConfig)
Eventually(func() error {
_, err := client.ListModels(context.TODO())
return err
}, "20m").ShouldNot(HaveOccurred())
})
var _ = AfterSuite(func() {
Expect(pool.Purge(resource)).To(Succeed())
//dat, err := os.ReadFile(resource.Container.LogPath)
//Expect(err).To(Not(HaveOccurred()))
//Expect(string(dat)).To(ContainSubstring("GRPC Service Ready"))
//fmt.Println(string(dat))
})
var _ = AfterEach(func() {
//Expect(dbClient.Clear()).To(Succeed())
})

152
tests/e2e-aio/e2e_test.go Normal file
View file

@ -0,0 +1,152 @@
package e2e_test
import (
"context"
"fmt"
"io"
"net/http"
"os"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/sashabaranov/go-openai"
)
var _ = Describe("E2E test", func() {
Context("Generating", func() {
BeforeEach(func() {
//
})
// Check that the GPU was used
AfterEach(func() {
//
})
Context("text", func() {
It("correctly", func() {
model := "gpt-4"
resp, err := client.CreateChatCompletion(context.TODO(),
openai.ChatCompletionRequest{
Model: model, Messages: []openai.ChatCompletionMessage{
{
Role: "user",
Content: "How much is 2+2?",
},
}})
Expect(err).ToNot(HaveOccurred())
Expect(len(resp.Choices)).To(Equal(1), fmt.Sprint(resp))
Expect(resp.Choices[0].Message.Content).To(Or(ContainSubstring("4"), ContainSubstring("four")), fmt.Sprint(resp.Choices[0].Message.Content))
})
})
Context("images", func() {
It("correctly", func() {
resp, err := client.CreateImage(context.TODO(),
openai.ImageRequest{
Prompt: "test",
Size: openai.CreateImageSize512x512,
//ResponseFormat: openai.CreateImageResponseFormatURL,
},
)
Expect(err).ToNot(HaveOccurred())
Expect(len(resp.Data)).To(Equal(1), fmt.Sprint(resp))
Expect(resp.Data[0].URL).To(ContainSubstring("http://localhost:8080"), fmt.Sprint(resp.Data[0].URL))
})
})
Context("embeddings", func() {
It("correctly", func() {
resp, err := client.CreateEmbeddings(context.TODO(),
openai.EmbeddingRequestStrings{
Input: []string{"doc"},
Model: openai.AdaEmbeddingV2,
},
)
Expect(err).ToNot(HaveOccurred())
Expect(len(resp.Data)).To(Equal(1), fmt.Sprint(resp))
Expect(resp.Data[0].Embedding).ToNot(BeEmpty())
})
})
Context("vision", func() {
It("correctly", func() {
model := "gpt-4-vision-preview"
resp, err := client.CreateChatCompletion(context.TODO(),
openai.ChatCompletionRequest{
Model: model, Messages: []openai.ChatCompletionMessage{
{
Role: "user",
MultiContent: []openai.ChatMessagePart{
{
Type: openai.ChatMessagePartTypeText,
Text: "What is in the image?",
},
{
Type: openai.ChatMessagePartTypeImageURL,
ImageURL: &openai.ChatMessageImageURL{
URL: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg",
Detail: openai.ImageURLDetailLow,
},
},
},
},
}})
Expect(err).ToNot(HaveOccurred())
Expect(len(resp.Choices)).To(Equal(1), fmt.Sprint(resp))
Expect(resp.Choices[0].Message.Content).To(Or(ContainSubstring("wooden"), ContainSubstring("grass")), fmt.Sprint(resp.Choices[0].Message.Content))
})
})
Context("text to audio", func() {
It("correctly", func() {
res, err := client.CreateSpeech(context.Background(), openai.CreateSpeechRequest{
Model: openai.TTSModel1,
Input: "Hello!",
Voice: openai.VoiceAlloy,
})
Expect(err).ToNot(HaveOccurred())
defer res.Close()
_, err = io.ReadAll(res)
Expect(err).ToNot(HaveOccurred())
})
})
Context("audio to text", func() {
It("correctly", func() {
downloadURL := "https://cdn.openai.com/whisper/draft-20220913a/micro-machines.wav"
file, err := downloadHttpFile(downloadURL)
Expect(err).ToNot(HaveOccurred())
req := openai.AudioRequest{
Model: openai.Whisper1,
FilePath: file,
}
resp, err := client.CreateTranscription(context.Background(), req)
Expect(err).ToNot(HaveOccurred())
Expect(resp.Text).To(ContainSubstring("This is the"), fmt.Sprint(resp.Text))
})
})
})
})
func downloadHttpFile(url string) (string, error) {
resp, err := http.Get(url)
if err != nil {
return "", err
}
defer resp.Body.Close()
tmpfile, err := os.CreateTemp("", "example")
if err != nil {
return "", err
}
defer tmpfile.Close()
_, err = io.Copy(tmpfile, resp.Body)
if err != nil {
return "", err
}
return tmpfile.Name(), nil
}