whisper: add tests and allow setting upload size (#237)

Ettore Di Giacinto 2023-05-12 10:04:20 +02:00 committed by GitHub
parent 5115b2faa3
commit fd1df4e971
7 changed files with 53 additions and 21 deletions


@@ -12,7 +12,7 @@ import (
"github.com/rs/zerolog/log"
)
-func App(configFile string, loader *model.ModelLoader, threads, ctxSize int, f16 bool, debug, disableMessage bool) *fiber.App {
+func App(configFile string, loader *model.ModelLoader, uploadLimitMB, threads, ctxSize int, f16 bool, debug, disableMessage bool) *fiber.App {
zerolog.SetGlobalLevel(zerolog.InfoLevel)
if debug {
zerolog.SetGlobalLevel(zerolog.DebugLevel)
@@ -20,6 +20,7 @@ func App(configFile string, loader *model.ModelLoader, threads, ctxSize int, f16
// Return errors as JSON responses
app := fiber.New(fiber.Config{
+BodyLimit: uploadLimitMB * 1024 * 1024, // overrides Fiber's default 4MB body limit
DisableStartupMessage: disableMessage,
// Override default error handler
ErrorHandler: func(ctx *fiber.Ctx, err error) error {
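
The BodyLimit field is what enforces the new upload cap: Fiber rejects any request whose body exceeds the configured byte count with 413 Request Entity Too Large before a handler runs. A minimal standalone sketch of the same configuration, assuming Fiber v2 (the route, address, and the value 15 are illustrative, not part of this commit):

package main

import (
	"log"

	"github.com/gofiber/fiber/v2"
)

func main() {
	uploadLimitMB := 15 // same value the updated tests pass to App
	app := fiber.New(fiber.Config{
		// BodyLimit is in bytes; Fiber's default is 4 * 1024 * 1024 (4MB).
		BodyLimit: uploadLimitMB * 1024 * 1024,
	})
	// A request body above the limit never reaches this handler;
	// Fiber answers 413 on its own.
	app.Post("/upload", func(c *fiber.Ctx) error {
		return c.SendString("received")
	})
	log.Fatal(app.Listen("127.0.0.1:9090"))
}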


@@ -3,6 +3,7 @@ package api_test
import (
"context"
"os"
"path/filepath"
. "github.com/go-skynet/LocalAI/api"
"github.com/go-skynet/LocalAI/pkg/model"
@@ -23,7 +24,7 @@ var _ = Describe("API test", func() {
Context("API query", func() {
BeforeEach(func() {
modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH"))
app = App("", modelLoader, 1, 512, false, true, true)
app = App("", modelLoader, 15, 1, 512, false, true, true)
go app.Listen("127.0.0.1:9090")
defaultConfig := openai.DefaultConfig("")
@@ -45,7 +46,7 @@ var _ = Describe("API test", func() {
It("returns the models list", func() {
models, err := client.ListModels(context.TODO())
Expect(err).ToNot(HaveOccurred())
-Expect(len(models.Models)).To(Equal(3))
+Expect(len(models.Models)).To(Equal(4))
Expect(models.Models[0].ID).To(Equal("testmodel"))
})
It("can generate completions", func() {
@@ -81,13 +82,23 @@ var _ = Describe("API test", func() {
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("error, status code: 500, message: could not load model - all backends returned error: 10 errors occurred:"))
})
PIt("transcribes audio", func() {
resp, err := client.CreateTranscription(
context.Background(),
openai.AudioRequest{
Model: openai.Whisper1,
FilePath: filepath.Join(os.Getenv("TEST_DIR"), "audio.wav"),
},
)
Expect(err).ToNot(HaveOccurred())
Expect(resp.Text).To(ContainSubstring("This is the Micro Machine Man presenting"))
})
})
Context("Config file", func() {
BeforeEach(func() {
modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH"))
-app = App(os.Getenv("CONFIG_FILE"), modelLoader, 1, 512, false, true, true)
+app = App(os.Getenv("CONFIG_FILE"), modelLoader, 5, 1, 512, false, true, true)
go app.Listen("127.0.0.1:9090")
defaultConfig := openai.DefaultConfig("")
@@ -108,7 +119,7 @@ var _ = Describe("API test", func() {
models, err := client.ListModels(context.TODO())
Expect(err).ToNot(HaveOccurred())
-Expect(len(models.Models)).To(Equal(5))
+Expect(len(models.Models)).To(Equal(6))
Expect(models.Models[0].ID).To(Equal("testmodel"))
})
It("can generate chat completions from config file", func() {
@@ -134,5 +145,6 @@ var _ = Describe("API test", func() {
Expect(len(resp.Choices)).To(Equal(1))
Expect(resp.Choices[0].Text).ToNot(BeEmpty())
})
})
})
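
The new test drives the endpoint through go-openai's CreateTranscription helper (marked PIt, i.e. pending, in Ginkgo); the same request can be made with nothing but the standard library. A hedged sketch, assuming a LocalAI server on 127.0.0.1:9090, the OpenAI-compatible route /v1/audio/transcriptions, and the multipart layout ("file" and "model" fields) that the handler's c.FormFile("file") read below implies:

package main

import (
	"bytes"
	"fmt"
	"io"
	"mime/multipart"
	"net/http"
	"os"
)

func main() {
	var body bytes.Buffer
	w := multipart.NewWriter(&body)

	// Attach the audio under the "file" field the handler reads.
	fw, err := w.CreateFormFile("file", "audio.wav")
	if err != nil {
		panic(err)
	}
	f, err := os.Open("audio.wav") // path is illustrative
	if err != nil {
		panic(err)
	}
	defer f.Close()
	io.Copy(fw, f)

	// Model name is illustrative; it must match a configured whisper model.
	w.WriteField("model", "whisper-1")
	w.Close()

	resp, err := http.Post("http://127.0.0.1:9090/v1/audio/transcriptions",
		w.FormDataContentType(), &body)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	out, _ := io.ReadAll(resp.Body)
	// A body larger than uploadLimitMB would come back as 413 instead.
	fmt.Println(resp.Status, string(out))
}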


@@ -409,14 +409,13 @@ func transcriptEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader,
// retrieve the file data from the request
file, err := c.FormFile("file")
if err != nil {
-return c.Status(http.StatusBadRequest).JSON(fiber.Map{"error": err.Error()})
+return err
}
f, err := file.Open()
if err != nil {
-return c.Status(http.StatusBadRequest).JSON(fiber.Map{"error": err.Error()})
+return err
}
defer f.Close()
log.Debug().Msgf("Audio file: %+v", file)
dir, err := os.MkdirTemp("", "whisper")
@@ -428,26 +427,33 @@ func transcriptEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader,
dst := filepath.Join(dir, path.Base(file.Filename))
dstFile, err := os.Create(dst)
if err != nil {
-return c.Status(http.StatusBadRequest).JSON(fiber.Map{"error": err.Error()})
+return err
}
if _, err := io.Copy(dstFile, f); err != nil {
log.Debug().Msgf("Audio file %+v - %+v - err %+v", file.Filename, dst, err)
log.Debug().Msgf("Audio file copying error %+v - %+v - err %+v", file.Filename, dst, err)
return err
}
log.Debug().Msgf("Audio file copied to: %+v", dst)
whisperModel, err := loader.BackendLoader("whisper", config.Model, []llama.ModelOption{}, uint32(config.Threads))
whisperModel, err := loader.BackendLoader(model.WhisperBackend, config.Model, []llama.ModelOption{}, uint32(config.Threads))
if err != nil {
-return c.Status(http.StatusBadRequest).JSON(fiber.Map{"error": err.Error()})
+return err
}
-w := whisperModel.(whisper.Model)
if whisperModel == nil {
return fmt.Errorf("could not load whisper model")
}
-tr, err := whisperutil.Transcript(w, dst, input.Language)
+w, ok := whisperModel.(whisper.Model)
+if !ok {
+return fmt.Errorf("loader returned non-whisper object")
+}
+tr, err := whisperutil.Transcript(w, dst, input.Language, uint(config.Threads))
if err != nil {
-return c.Status(http.StatusBadRequest).JSON(fiber.Map{"error": err.Error()})
+return err
}
log.Debug().Msgf("Trascribed: %+v", tr)