tests: add gpt4all tests

This commit is contained in:
mudler 2023-06-05 00:07:15 +02:00
parent 564818f52c
commit f3916d3790
2 changed files with 29 additions and 1 deletions

View file

@ -232,6 +232,7 @@ test-models/testmodel:
cp tests/models_fixtures/* test-models cp tests/models_fixtures/* test-models
test: prepare test-models/testmodel test: prepare test-models/testmodel
cp -r backend-assets api
cp tests/models_fixtures/* test-models cp tests/models_fixtures/* test-models
C_INCLUDE_PATH=${C_INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --flake-attempts 5 -v -r ./api ./pkg C_INCLUDE_PATH=${C_INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --flake-attempts 5 -v -r ./api ./pkg

View file

@ -3,6 +3,7 @@ package api_test
import ( import (
"bytes" "bytes"
"context" "context"
"embed"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
@ -95,6 +96,9 @@ func postModelApplyRequest(url string, request modelApplyRequest) (response map[
return return
} }
//go:embed backend-assets/*
var backendAssets embed.FS
var _ = Describe("API test", func() { var _ = Describe("API test", func() {
var app *fiber.App var app *fiber.App
@ -114,7 +118,7 @@ var _ = Describe("API test", func() {
modelLoader = model.NewModelLoader(tmpdir) modelLoader = model.NewModelLoader(tmpdir)
c, cancel = context.WithCancel(context.Background()) c, cancel = context.WithCancel(context.Background())
app, err = App(WithContext(c), WithModelLoader(modelLoader)) app, err = App(WithContext(c), WithModelLoader(modelLoader), WithBackendAssets(backendAssets), WithBackendAssetsOutput(tmpdir))
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
go app.Listen("127.0.0.1:9090") go app.Listen("127.0.0.1:9090")
@ -191,6 +195,29 @@ var _ = Describe("API test", func() {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(content["backend"]).To(Equal("bert-embeddings")) Expect(content["backend"]).To(Equal("bert-embeddings"))
}) })
It("runs gpt4all", func() {
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
URL: "github:go-skynet/model-gallery/gpt4all-j.yaml",
Name: "gpt4all-j",
Overrides: map[string]string{},
})
Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response))
uuid := response["uuid"].(string)
Eventually(func() bool {
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
fmt.Println(response)
return response["processed"].(bool)
}, "360s").Should(Equal(true))
resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "gpt4all-j", Messages: []openai.ChatCompletionMessage{openai.ChatCompletionMessage{Role: "user", Content: "How are you?"}}})
Expect(err).ToNot(HaveOccurred())
Expect(len(resp.Choices)).To(Equal(1))
Expect(resp.Choices[0].Message.Content).To(ContainSubstring("well"))
})
}) })
}) })