mirror of
https://github.com/mudler/LocalAI.git
synced 2025-05-27 14:05:00 +00:00
refactor: move remaining api packages to core (#1731)
* core 1 * api/openai/files fix * core 2 - core/config * move over core api.go and tests to the start of core/http * move over localai specific endpoints to core/http, begin the service/endpoint split there * refactor big chunk on the plane * refactor chunk 2 on plane, next step: port and modify changes to request.go * easy fixes for request.go, major changes not done yet * lintfix * json tag lintfix? * gitignore and .keep files * strange fix attempt: rename the config dir?
This commit is contained in:
parent
316de82f51
commit
1c312685aa
50 changed files with 1440 additions and 1206 deletions
229
core/http/api.go
229
core/http/api.go
|
@ -3,122 +3,29 @@ package http
|
|||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/go-skynet/LocalAI/api/localai"
|
||||
"github.com/go-skynet/LocalAI/api/openai"
|
||||
config "github.com/go-skynet/LocalAI/core/config"
|
||||
"github.com/go-skynet/LocalAI/core/options"
|
||||
"github.com/go-skynet/LocalAI/core/http/endpoints/localai"
|
||||
"github.com/go-skynet/LocalAI/core/http/endpoints/openai"
|
||||
|
||||
"github.com/go-skynet/LocalAI/core/config"
|
||||
"github.com/go-skynet/LocalAI/core/schema"
|
||||
"github.com/go-skynet/LocalAI/core/services"
|
||||
"github.com/go-skynet/LocalAI/internal"
|
||||
"github.com/go-skynet/LocalAI/metrics"
|
||||
"github.com/go-skynet/LocalAI/pkg/assets"
|
||||
"github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/go-skynet/LocalAI/pkg/startup"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/middleware/cors"
|
||||
"github.com/gofiber/fiber/v2/middleware/logger"
|
||||
"github.com/gofiber/fiber/v2/middleware/recover"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
func Startup(opts ...options.AppOption) (*options.Option, *config.ConfigLoader, error) {
|
||||
options := options.NewOptions(opts...)
|
||||
|
||||
zerolog.SetGlobalLevel(zerolog.InfoLevel)
|
||||
if options.Debug {
|
||||
zerolog.SetGlobalLevel(zerolog.DebugLevel)
|
||||
}
|
||||
|
||||
log.Info().Msgf("Starting LocalAI using %d threads, with models path: %s", options.Threads, options.Loader.ModelPath)
|
||||
log.Info().Msgf("LocalAI version: %s", internal.PrintableVersion())
|
||||
|
||||
startup.PreloadModelsConfigurations(options.ModelLibraryURL, options.Loader.ModelPath, options.ModelsURL...)
|
||||
|
||||
cl := config.NewConfigLoader()
|
||||
if err := cl.LoadConfigs(options.Loader.ModelPath); err != nil {
|
||||
log.Error().Msgf("error loading config files: %s", err.Error())
|
||||
}
|
||||
|
||||
if options.ConfigFile != "" {
|
||||
if err := cl.LoadConfigFile(options.ConfigFile); err != nil {
|
||||
log.Error().Msgf("error loading config file: %s", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
if err := cl.Preload(options.Loader.ModelPath); err != nil {
|
||||
log.Error().Msgf("error downloading models: %s", err.Error())
|
||||
}
|
||||
|
||||
if options.PreloadJSONModels != "" {
|
||||
if err := localai.ApplyGalleryFromString(options.Loader.ModelPath, options.PreloadJSONModels, cl, options.Galleries); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if options.PreloadModelsFromPath != "" {
|
||||
if err := localai.ApplyGalleryFromFile(options.Loader.ModelPath, options.PreloadModelsFromPath, cl, options.Galleries); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if options.Debug {
|
||||
for _, v := range cl.ListConfigs() {
|
||||
cfg, _ := cl.GetConfig(v)
|
||||
log.Debug().Msgf("Model: %s (config: %+v)", v, cfg)
|
||||
}
|
||||
}
|
||||
|
||||
if options.AssetsDestination != "" {
|
||||
// Extract files from the embedded FS
|
||||
err := assets.ExtractFiles(options.BackendAssets, options.AssetsDestination)
|
||||
log.Debug().Msgf("Extracting backend assets files to %s", options.AssetsDestination)
|
||||
if err != nil {
|
||||
log.Warn().Msgf("Failed extracting backend assets files: %s (might be required for some backends to work properly, like gpt4all)", err)
|
||||
}
|
||||
}
|
||||
|
||||
// turn off any process that was started by GRPC if the context is canceled
|
||||
go func() {
|
||||
<-options.Context.Done()
|
||||
log.Debug().Msgf("Context canceled, shutting down")
|
||||
options.Loader.StopAllGRPC()
|
||||
}()
|
||||
|
||||
if options.WatchDog {
|
||||
wd := model.NewWatchDog(
|
||||
options.Loader,
|
||||
options.WatchDogBusyTimeout,
|
||||
options.WatchDogIdleTimeout,
|
||||
options.WatchDogBusy,
|
||||
options.WatchDogIdle)
|
||||
options.Loader.SetWatchDog(wd)
|
||||
go wd.Run()
|
||||
go func() {
|
||||
<-options.Context.Done()
|
||||
log.Debug().Msgf("Context canceled, shutting down")
|
||||
wd.Shutdown()
|
||||
}()
|
||||
}
|
||||
|
||||
return options, cl, nil
|
||||
}
|
||||
|
||||
func App(opts ...options.AppOption) (*fiber.App, error) {
|
||||
|
||||
options, cl, err := Startup(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed basic startup tasks with error %s", err.Error())
|
||||
}
|
||||
|
||||
func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) (*fiber.App, error) {
|
||||
// Return errors as JSON responses
|
||||
app := fiber.New(fiber.Config{
|
||||
BodyLimit: options.UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB
|
||||
DisableStartupMessage: options.DisableMessage,
|
||||
BodyLimit: appConfig.UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB
|
||||
DisableStartupMessage: appConfig.DisableMessage,
|
||||
// Override default error handler
|
||||
ErrorHandler: func(ctx *fiber.Ctx, err error) error {
|
||||
// Status code defaults to 500
|
||||
|
@ -139,7 +46,7 @@ func App(opts ...options.AppOption) (*fiber.App, error) {
|
|||
},
|
||||
})
|
||||
|
||||
if options.Debug {
|
||||
if appConfig.Debug {
|
||||
app.Use(logger.New(logger.Config{
|
||||
Format: "[${ip}]:${port} ${status} - ${method} ${path}\n",
|
||||
}))
|
||||
|
@ -147,17 +54,25 @@ func App(opts ...options.AppOption) (*fiber.App, error) {
|
|||
|
||||
// Default middleware config
|
||||
|
||||
if !options.Debug {
|
||||
if !appConfig.Debug {
|
||||
app.Use(recover.New())
|
||||
}
|
||||
|
||||
if options.Metrics != nil {
|
||||
app.Use(metrics.APIMiddleware(options.Metrics))
|
||||
metricsService, err := services.NewLocalAIMetricsService()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if metricsService != nil {
|
||||
app.Use(localai.LocalAIMetricsAPIMiddleware(metricsService))
|
||||
app.Hooks().OnShutdown(func() error {
|
||||
return metricsService.Shutdown()
|
||||
})
|
||||
}
|
||||
|
||||
// Auth middleware checking if API key is valid. If no API key is set, no auth is required.
|
||||
auth := func(c *fiber.Ctx) error {
|
||||
if len(options.ApiKeys) == 0 {
|
||||
if len(appConfig.ApiKeys) == 0 {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
|
@ -172,10 +87,10 @@ func App(opts ...options.AppOption) (*fiber.App, error) {
|
|||
}
|
||||
|
||||
// Add file keys to options.ApiKeys
|
||||
options.ApiKeys = append(options.ApiKeys, fileKeys...)
|
||||
appConfig.ApiKeys = append(appConfig.ApiKeys, fileKeys...)
|
||||
}
|
||||
|
||||
if len(options.ApiKeys) == 0 {
|
||||
if len(appConfig.ApiKeys) == 0 {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
|
@ -189,7 +104,7 @@ func App(opts ...options.AppOption) (*fiber.App, error) {
|
|||
}
|
||||
|
||||
apiKey := authHeaderParts[1]
|
||||
for _, key := range options.ApiKeys {
|
||||
for _, key := range appConfig.ApiKeys {
|
||||
if apiKey == key {
|
||||
return c.Next()
|
||||
}
|
||||
|
@ -199,20 +114,20 @@ func App(opts ...options.AppOption) (*fiber.App, error) {
|
|||
|
||||
}
|
||||
|
||||
if options.CORS {
|
||||
if appConfig.CORS {
|
||||
var c func(ctx *fiber.Ctx) error
|
||||
if options.CORSAllowOrigins == "" {
|
||||
if appConfig.CORSAllowOrigins == "" {
|
||||
c = cors.New()
|
||||
} else {
|
||||
c = cors.New(cors.Config{AllowOrigins: options.CORSAllowOrigins})
|
||||
c = cors.New(cors.Config{AllowOrigins: appConfig.CORSAllowOrigins})
|
||||
}
|
||||
|
||||
app.Use(c)
|
||||
}
|
||||
|
||||
// LocalAI API endpoints
|
||||
galleryService := localai.NewGalleryService(options.Loader.ModelPath)
|
||||
galleryService.Start(options.Context, cl)
|
||||
galleryService := services.NewGalleryService(appConfig.ModelPath)
|
||||
galleryService.Start(appConfig.Context, cl)
|
||||
|
||||
app.Get("/version", auth, func(c *fiber.Ctx) error {
|
||||
return c.JSON(struct {
|
||||
|
@ -220,69 +135,63 @@ func App(opts ...options.AppOption) (*fiber.App, error) {
|
|||
}{Version: internal.PrintableVersion()})
|
||||
})
|
||||
|
||||
// Make sure directories exists
|
||||
os.MkdirAll(options.ImageDir, 0755)
|
||||
os.MkdirAll(options.AudioDir, 0755)
|
||||
os.MkdirAll(options.UploadDir, 0755)
|
||||
os.MkdirAll(options.Loader.ModelPath, 0755)
|
||||
|
||||
// Load upload json
|
||||
openai.LoadUploadConfig(options.UploadDir)
|
||||
openai.LoadUploadConfig(appConfig.UploadDir)
|
||||
|
||||
modelGalleryService := localai.CreateModelGalleryService(options.Galleries, options.Loader.ModelPath, galleryService)
|
||||
app.Post("/models/apply", auth, modelGalleryService.ApplyModelGalleryEndpoint())
|
||||
app.Get("/models/available", auth, modelGalleryService.ListModelFromGalleryEndpoint())
|
||||
app.Get("/models/galleries", auth, modelGalleryService.ListModelGalleriesEndpoint())
|
||||
app.Post("/models/galleries", auth, modelGalleryService.AddModelGalleryEndpoint())
|
||||
app.Delete("/models/galleries", auth, modelGalleryService.RemoveModelGalleryEndpoint())
|
||||
app.Get("/models/jobs/:uuid", auth, modelGalleryService.GetOpStatusEndpoint())
|
||||
app.Get("/models/jobs", auth, modelGalleryService.GetAllStatusEndpoint())
|
||||
modelGalleryEndpointService := localai.CreateModelGalleryEndpointService(appConfig.Galleries, appConfig.ModelPath, galleryService)
|
||||
app.Post("/models/apply", auth, modelGalleryEndpointService.ApplyModelGalleryEndpoint())
|
||||
app.Get("/models/available", auth, modelGalleryEndpointService.ListModelFromGalleryEndpoint())
|
||||
app.Get("/models/galleries", auth, modelGalleryEndpointService.ListModelGalleriesEndpoint())
|
||||
app.Post("/models/galleries", auth, modelGalleryEndpointService.AddModelGalleryEndpoint())
|
||||
app.Delete("/models/galleries", auth, modelGalleryEndpointService.RemoveModelGalleryEndpoint())
|
||||
app.Get("/models/jobs/:uuid", auth, modelGalleryEndpointService.GetOpStatusEndpoint())
|
||||
app.Get("/models/jobs", auth, modelGalleryEndpointService.GetAllStatusEndpoint())
|
||||
|
||||
// openAI compatible API endpoint
|
||||
|
||||
// chat
|
||||
app.Post("/v1/chat/completions", auth, openai.ChatEndpoint(cl, options))
|
||||
app.Post("/chat/completions", auth, openai.ChatEndpoint(cl, options))
|
||||
app.Post("/v1/chat/completions", auth, openai.ChatEndpoint(cl, ml, appConfig))
|
||||
app.Post("/chat/completions", auth, openai.ChatEndpoint(cl, ml, appConfig))
|
||||
|
||||
// edit
|
||||
app.Post("/v1/edits", auth, openai.EditEndpoint(cl, options))
|
||||
app.Post("/edits", auth, openai.EditEndpoint(cl, options))
|
||||
app.Post("/v1/edits", auth, openai.EditEndpoint(cl, ml, appConfig))
|
||||
app.Post("/edits", auth, openai.EditEndpoint(cl, ml, appConfig))
|
||||
|
||||
// files
|
||||
app.Post("/v1/files", auth, openai.UploadFilesEndpoint(cl, options))
|
||||
app.Post("/files", auth, openai.UploadFilesEndpoint(cl, options))
|
||||
app.Get("/v1/files", auth, openai.ListFilesEndpoint(cl, options))
|
||||
app.Get("/files", auth, openai.ListFilesEndpoint(cl, options))
|
||||
app.Get("/v1/files/:file_id", auth, openai.GetFilesEndpoint(cl, options))
|
||||
app.Get("/files/:file_id", auth, openai.GetFilesEndpoint(cl, options))
|
||||
app.Delete("/v1/files/:file_id", auth, openai.DeleteFilesEndpoint(cl, options))
|
||||
app.Delete("/files/:file_id", auth, openai.DeleteFilesEndpoint(cl, options))
|
||||
app.Get("/v1/files/:file_id/content", auth, openai.GetFilesContentsEndpoint(cl, options))
|
||||
app.Get("/files/:file_id/content", auth, openai.GetFilesContentsEndpoint(cl, options))
|
||||
app.Post("/v1/files", auth, openai.UploadFilesEndpoint(cl, appConfig))
|
||||
app.Post("/files", auth, openai.UploadFilesEndpoint(cl, appConfig))
|
||||
app.Get("/v1/files", auth, openai.ListFilesEndpoint(cl, appConfig))
|
||||
app.Get("/files", auth, openai.ListFilesEndpoint(cl, appConfig))
|
||||
app.Get("/v1/files/:file_id", auth, openai.GetFilesEndpoint(cl, appConfig))
|
||||
app.Get("/files/:file_id", auth, openai.GetFilesEndpoint(cl, appConfig))
|
||||
app.Delete("/v1/files/:file_id", auth, openai.DeleteFilesEndpoint(cl, appConfig))
|
||||
app.Delete("/files/:file_id", auth, openai.DeleteFilesEndpoint(cl, appConfig))
|
||||
app.Get("/v1/files/:file_id/content", auth, openai.GetFilesContentsEndpoint(cl, appConfig))
|
||||
app.Get("/files/:file_id/content", auth, openai.GetFilesContentsEndpoint(cl, appConfig))
|
||||
|
||||
// completion
|
||||
app.Post("/v1/completions", auth, openai.CompletionEndpoint(cl, options))
|
||||
app.Post("/completions", auth, openai.CompletionEndpoint(cl, options))
|
||||
app.Post("/v1/engines/:model/completions", auth, openai.CompletionEndpoint(cl, options))
|
||||
app.Post("/v1/completions", auth, openai.CompletionEndpoint(cl, ml, appConfig))
|
||||
app.Post("/completions", auth, openai.CompletionEndpoint(cl, ml, appConfig))
|
||||
app.Post("/v1/engines/:model/completions", auth, openai.CompletionEndpoint(cl, ml, appConfig))
|
||||
|
||||
// embeddings
|
||||
app.Post("/v1/embeddings", auth, openai.EmbeddingsEndpoint(cl, options))
|
||||
app.Post("/embeddings", auth, openai.EmbeddingsEndpoint(cl, options))
|
||||
app.Post("/v1/engines/:model/embeddings", auth, openai.EmbeddingsEndpoint(cl, options))
|
||||
app.Post("/v1/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, appConfig))
|
||||
app.Post("/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, appConfig))
|
||||
app.Post("/v1/engines/:model/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, appConfig))
|
||||
|
||||
// audio
|
||||
app.Post("/v1/audio/transcriptions", auth, openai.TranscriptEndpoint(cl, options))
|
||||
app.Post("/tts", auth, localai.TTSEndpoint(cl, options))
|
||||
app.Post("/v1/audio/transcriptions", auth, openai.TranscriptEndpoint(cl, ml, appConfig))
|
||||
app.Post("/tts", auth, localai.TTSEndpoint(cl, ml, appConfig))
|
||||
|
||||
// images
|
||||
app.Post("/v1/images/generations", auth, openai.ImageEndpoint(cl, options))
|
||||
app.Post("/v1/images/generations", auth, openai.ImageEndpoint(cl, ml, appConfig))
|
||||
|
||||
if options.ImageDir != "" {
|
||||
app.Static("/generated-images", options.ImageDir)
|
||||
if appConfig.ImageDir != "" {
|
||||
app.Static("/generated-images", appConfig.ImageDir)
|
||||
}
|
||||
|
||||
if options.AudioDir != "" {
|
||||
app.Static("/generated-audio", options.AudioDir)
|
||||
if appConfig.AudioDir != "" {
|
||||
app.Static("/generated-audio", appConfig.AudioDir)
|
||||
}
|
||||
|
||||
ok := func(c *fiber.Ctx) error {
|
||||
|
@ -294,15 +203,15 @@ func App(opts ...options.AppOption) (*fiber.App, error) {
|
|||
app.Get("/readyz", ok)
|
||||
|
||||
// Experimental Backend Statistics Module
|
||||
backendMonitor := localai.NewBackendMonitor(cl, options) // Split out for now
|
||||
backendMonitor := services.NewBackendMonitor(cl, ml, appConfig) // Split out for now
|
||||
app.Get("/backend/monitor", localai.BackendMonitorEndpoint(backendMonitor))
|
||||
app.Post("/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitor))
|
||||
|
||||
// models
|
||||
app.Get("/v1/models", auth, openai.ListModelsEndpoint(options.Loader, cl))
|
||||
app.Get("/models", auth, openai.ListModelsEndpoint(options.Loader, cl))
|
||||
app.Get("/v1/models", auth, openai.ListModelsEndpoint(cl, ml))
|
||||
app.Get("/models", auth, openai.ListModelsEndpoint(cl, ml))
|
||||
|
||||
app.Get("/metrics", metrics.MetricsHandler())
|
||||
app.Get("/metrics", localai.LocalAIMetricsEndpoint())
|
||||
|
||||
return app, nil
|
||||
}
|
||||
|
|
|
@ -13,9 +13,10 @@ import (
|
|||
"path/filepath"
|
||||
"runtime"
|
||||
|
||||
"github.com/go-skynet/LocalAI/core/config"
|
||||
. "github.com/go-skynet/LocalAI/core/http"
|
||||
"github.com/go-skynet/LocalAI/core/options"
|
||||
"github.com/go-skynet/LocalAI/metrics"
|
||||
"github.com/go-skynet/LocalAI/core/startup"
|
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/downloader"
|
||||
"github.com/go-skynet/LocalAI/pkg/gallery"
|
||||
"github.com/go-skynet/LocalAI/pkg/model"
|
||||
|
@ -127,25 +128,33 @@ var backendAssets embed.FS
|
|||
var _ = Describe("API test", func() {
|
||||
|
||||
var app *fiber.App
|
||||
var modelLoader *model.ModelLoader
|
||||
var client *openai.Client
|
||||
var client2 *openaigo.Client
|
||||
var c context.Context
|
||||
var cancel context.CancelFunc
|
||||
var tmpdir string
|
||||
var modelDir string
|
||||
var bcl *config.BackendConfigLoader
|
||||
var ml *model.ModelLoader
|
||||
var applicationConfig *config.ApplicationConfig
|
||||
|
||||
commonOpts := []options.AppOption{
|
||||
options.WithDebug(true),
|
||||
options.WithDisableMessage(true),
|
||||
commonOpts := []config.AppOption{
|
||||
config.WithDebug(true),
|
||||
config.WithDisableMessage(true),
|
||||
}
|
||||
|
||||
Context("API with ephemeral models", func() {
|
||||
BeforeEach(func() {
|
||||
|
||||
BeforeEach(func(sc SpecContext) {
|
||||
var err error
|
||||
tmpdir, err = os.MkdirTemp("", "")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
modelLoader = model.NewModelLoader(tmpdir)
|
||||
modelDir = filepath.Join(tmpdir, "models")
|
||||
backendAssetsDir := filepath.Join(tmpdir, "backend-assets")
|
||||
err = os.Mkdir(backendAssetsDir, 0755)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
c, cancel = context.WithCancel(context.Background())
|
||||
|
||||
g := []gallery.GalleryModel{
|
||||
|
@ -172,16 +181,18 @@ var _ = Describe("API test", func() {
|
|||
},
|
||||
}
|
||||
|
||||
metricsService, err := metrics.SetupMetrics()
|
||||
bcl, ml, applicationConfig, err = startup.Startup(
|
||||
append(commonOpts,
|
||||
config.WithContext(c),
|
||||
config.WithGalleries(galleries),
|
||||
config.WithModelPath(modelDir),
|
||||
config.WithBackendAssets(backendAssets),
|
||||
config.WithBackendAssetsOutput(backendAssetsDir))...)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
app, err = App(
|
||||
append(commonOpts,
|
||||
options.WithMetrics(metricsService),
|
||||
options.WithContext(c),
|
||||
options.WithGalleries(galleries),
|
||||
options.WithModelLoader(modelLoader), options.WithBackendAssets(backendAssets), options.WithBackendAssetsOutput(tmpdir))...)
|
||||
app, err = App(bcl, ml, applicationConfig)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
go app.Listen("127.0.0.1:9090")
|
||||
|
||||
defaultConfig := openai.DefaultConfig("")
|
||||
|
@ -198,15 +209,21 @@ var _ = Describe("API test", func() {
|
|||
}, "2m").ShouldNot(HaveOccurred())
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
AfterEach(func(sc SpecContext) {
|
||||
cancel()
|
||||
app.Shutdown()
|
||||
os.RemoveAll(tmpdir)
|
||||
if app != nil {
|
||||
err := app.Shutdown()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
err := os.RemoveAll(tmpdir)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = os.ReadDir(tmpdir)
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
Context("Applying models", func() {
|
||||
It("applies models from a gallery", func() {
|
||||
|
||||
It("applies models from a gallery", func() {
|
||||
models := getModels("http://127.0.0.1:9090/models/available")
|
||||
Expect(len(models)).To(Equal(2), fmt.Sprint(models))
|
||||
Expect(models[0].Installed).To(BeFalse(), fmt.Sprint(models))
|
||||
|
@ -228,10 +245,10 @@ var _ = Describe("API test", func() {
|
|||
}, "360s", "10s").Should(Equal(true))
|
||||
Expect(resp["message"]).ToNot(ContainSubstring("error"))
|
||||
|
||||
dat, err := os.ReadFile(filepath.Join(tmpdir, "bert2.yaml"))
|
||||
dat, err := os.ReadFile(filepath.Join(modelDir, "bert2.yaml"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
_, err = os.ReadFile(filepath.Join(tmpdir, "foo.yaml"))
|
||||
_, err = os.ReadFile(filepath.Join(modelDir, "foo.yaml"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
content := map[string]interface{}{}
|
||||
|
@ -253,6 +270,7 @@ var _ = Describe("API test", func() {
|
|||
}
|
||||
})
|
||||
It("overrides models", func() {
|
||||
|
||||
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
|
||||
URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml",
|
||||
Name: "bert",
|
||||
|
@ -270,7 +288,7 @@ var _ = Describe("API test", func() {
|
|||
return response["processed"].(bool)
|
||||
}, "360s", "10s").Should(Equal(true))
|
||||
|
||||
dat, err := os.ReadFile(filepath.Join(tmpdir, "bert.yaml"))
|
||||
dat, err := os.ReadFile(filepath.Join(modelDir, "bert.yaml"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
content := map[string]interface{}{}
|
||||
|
@ -294,7 +312,7 @@ var _ = Describe("API test", func() {
|
|||
return response["processed"].(bool)
|
||||
}, "360s", "10s").Should(Equal(true))
|
||||
|
||||
dat, err := os.ReadFile(filepath.Join(tmpdir, "bert.yaml"))
|
||||
dat, err := os.ReadFile(filepath.Join(modelDir, "bert.yaml"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
content := map[string]interface{}{}
|
||||
|
@ -483,8 +501,11 @@ var _ = Describe("API test", func() {
|
|||
var err error
|
||||
tmpdir, err = os.MkdirTemp("", "")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
modelDir = filepath.Join(tmpdir, "models")
|
||||
backendAssetsDir := filepath.Join(tmpdir, "backend-assets")
|
||||
err = os.Mkdir(backendAssetsDir, 0755)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
modelLoader = model.NewModelLoader(tmpdir)
|
||||
c, cancel = context.WithCancel(context.Background())
|
||||
|
||||
galleries := []gallery.Gallery{
|
||||
|
@ -494,21 +515,20 @@ var _ = Describe("API test", func() {
|
|||
},
|
||||
}
|
||||
|
||||
metricsService, err := metrics.SetupMetrics()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
app, err = App(
|
||||
bcl, ml, applicationConfig, err = startup.Startup(
|
||||
append(commonOpts,
|
||||
options.WithContext(c),
|
||||
options.WithMetrics(metricsService),
|
||||
options.WithAudioDir(tmpdir),
|
||||
options.WithImageDir(tmpdir),
|
||||
options.WithGalleries(galleries),
|
||||
options.WithModelLoader(modelLoader),
|
||||
options.WithBackendAssets(backendAssets),
|
||||
options.WithBackendAssetsOutput(tmpdir))...,
|
||||
config.WithContext(c),
|
||||
config.WithAudioDir(tmpdir),
|
||||
config.WithImageDir(tmpdir),
|
||||
config.WithGalleries(galleries),
|
||||
config.WithModelPath(modelDir),
|
||||
config.WithBackendAssets(backendAssets),
|
||||
config.WithBackendAssetsOutput(tmpdir))...,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
app, err = App(bcl, ml, applicationConfig)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
go app.Listen("127.0.0.1:9090")
|
||||
|
||||
defaultConfig := openai.DefaultConfig("")
|
||||
|
@ -527,8 +547,14 @@ var _ = Describe("API test", func() {
|
|||
|
||||
AfterEach(func() {
|
||||
cancel()
|
||||
app.Shutdown()
|
||||
os.RemoveAll(tmpdir)
|
||||
if app != nil {
|
||||
err := app.Shutdown()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
err := os.RemoveAll(tmpdir)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = os.ReadDir(tmpdir)
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
It("installs and is capable to run tts", Label("tts"), func() {
|
||||
if runtime.GOOS != "linux" {
|
||||
|
@ -599,20 +625,20 @@ var _ = Describe("API test", func() {
|
|||
|
||||
Context("API query", func() {
|
||||
BeforeEach(func() {
|
||||
modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH"))
|
||||
modelPath := os.Getenv("MODELS_PATH")
|
||||
c, cancel = context.WithCancel(context.Background())
|
||||
|
||||
metricsService, err := metrics.SetupMetrics()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
var err error
|
||||
|
||||
app, err = App(
|
||||
bcl, ml, applicationConfig, err = startup.Startup(
|
||||
append(commonOpts,
|
||||
options.WithExternalBackend("huggingface", os.Getenv("HUGGINGFACE_GRPC")),
|
||||
options.WithContext(c),
|
||||
options.WithModelLoader(modelLoader),
|
||||
options.WithMetrics(metricsService),
|
||||
config.WithExternalBackend("huggingface", os.Getenv("HUGGINGFACE_GRPC")),
|
||||
config.WithContext(c),
|
||||
config.WithModelPath(modelPath),
|
||||
)...)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
app, err = App(bcl, ml, applicationConfig)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
go app.Listen("127.0.0.1:9090")
|
||||
|
||||
defaultConfig := openai.DefaultConfig("")
|
||||
|
@ -630,7 +656,10 @@ var _ = Describe("API test", func() {
|
|||
})
|
||||
AfterEach(func() {
|
||||
cancel()
|
||||
app.Shutdown()
|
||||
if app != nil {
|
||||
err := app.Shutdown()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
})
|
||||
It("returns the models list", func() {
|
||||
models, err := client.ListModels(context.TODO())
|
||||
|
@ -811,20 +840,20 @@ var _ = Describe("API test", func() {
|
|||
|
||||
Context("Config file", func() {
|
||||
BeforeEach(func() {
|
||||
modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH"))
|
||||
modelPath := os.Getenv("MODELS_PATH")
|
||||
c, cancel = context.WithCancel(context.Background())
|
||||
|
||||
metricsService, err := metrics.SetupMetrics()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
app, err = App(
|
||||
var err error
|
||||
bcl, ml, applicationConfig, err = startup.Startup(
|
||||
append(commonOpts,
|
||||
options.WithContext(c),
|
||||
options.WithMetrics(metricsService),
|
||||
options.WithModelLoader(modelLoader),
|
||||
options.WithConfigFile(os.Getenv("CONFIG_FILE")))...,
|
||||
config.WithContext(c),
|
||||
config.WithModelPath(modelPath),
|
||||
config.WithConfigFile(os.Getenv("CONFIG_FILE")))...,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
app, err = App(bcl, ml, applicationConfig)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
go app.Listen("127.0.0.1:9090")
|
||||
|
||||
defaultConfig := openai.DefaultConfig("")
|
||||
|
@ -840,7 +869,10 @@ var _ = Describe("API test", func() {
|
|||
})
|
||||
AfterEach(func() {
|
||||
cancel()
|
||||
app.Shutdown()
|
||||
if app != nil {
|
||||
err := app.Shutdown()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
})
|
||||
It("can generate chat completions from config file (list1)", func() {
|
||||
resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "list1", Messages: []openai.ChatCompletionMessage{{Role: "user", Content: testPrompt}}})
|
||||
|
|
43
core/http/ctx/fiber.go
Normal file
43
core/http/ctx/fiber.go
Normal file
|
@ -0,0 +1,43 @@
|
|||
package fiberContext
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// ModelFromContext returns the model from the context
|
||||
// If no model is specified, it will take the first available
|
||||
// Takes a model string as input which should be the one received from the user request.
|
||||
// It returns the model name resolved from the context and an error if any.
|
||||
func ModelFromContext(ctx *fiber.Ctx, loader *model.ModelLoader, modelInput string, firstModel bool) (string, error) {
|
||||
if ctx.Params("model") != "" {
|
||||
modelInput = ctx.Params("model")
|
||||
}
|
||||
|
||||
// Set model from bearer token, if available
|
||||
bearer := strings.TrimLeft(ctx.Get("authorization"), "Bearer ")
|
||||
bearerExists := bearer != "" && loader.ExistsInModelPath(bearer)
|
||||
|
||||
// If no model was specified, take the first available
|
||||
if modelInput == "" && !bearerExists && firstModel {
|
||||
models, _ := loader.ListModels()
|
||||
if len(models) > 0 {
|
||||
modelInput = models[0]
|
||||
log.Debug().Msgf("No model specified, using: %s", modelInput)
|
||||
} else {
|
||||
log.Debug().Msgf("No model specified, returning error")
|
||||
return "", fmt.Errorf("no model specified")
|
||||
}
|
||||
}
|
||||
|
||||
// If a model is found in bearer token takes precedence
|
||||
if bearerExists {
|
||||
log.Debug().Msgf("Using model from bearer token: %s", bearer)
|
||||
modelInput = bearer
|
||||
}
|
||||
return modelInput, nil
|
||||
}
|
36
core/http/endpoints/localai/backend_monitor.go
Normal file
36
core/http/endpoints/localai/backend_monitor.go
Normal file
|
@ -0,0 +1,36 @@
|
|||
package localai
|
||||
|
||||
import (
|
||||
"github.com/go-skynet/LocalAI/core/schema"
|
||||
"github.com/go-skynet/LocalAI/core/services"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
func BackendMonitorEndpoint(bm services.BackendMonitor) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
|
||||
input := new(schema.BackendMonitorRequest)
|
||||
// Get input data from the request body
|
||||
if err := c.BodyParser(input); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := bm.CheckAndSample(input.Model)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.JSON(resp)
|
||||
}
|
||||
}
|
||||
|
||||
func BackendShutdownEndpoint(bm services.BackendMonitor) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
input := new(schema.BackendMonitorRequest)
|
||||
// Get input data from the request body
|
||||
if err := c.BodyParser(input); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return bm.ShutdownModel(input.Model)
|
||||
}
|
||||
}
|
146
core/http/endpoints/localai/gallery.go
Normal file
146
core/http/endpoints/localai/gallery.go
Normal file
|
@ -0,0 +1,146 @@
|
|||
package localai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"slices"
|
||||
|
||||
"github.com/go-skynet/LocalAI/core/services"
|
||||
"github.com/go-skynet/LocalAI/pkg/gallery"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/google/uuid"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
type ModelGalleryEndpointService struct {
|
||||
galleries []gallery.Gallery
|
||||
modelPath string
|
||||
galleryApplier *services.GalleryService
|
||||
}
|
||||
|
||||
type GalleryModel struct {
|
||||
ID string `json:"id"`
|
||||
gallery.GalleryModel
|
||||
}
|
||||
|
||||
func CreateModelGalleryEndpointService(galleries []gallery.Gallery, modelPath string, galleryApplier *services.GalleryService) ModelGalleryEndpointService {
|
||||
return ModelGalleryEndpointService{
|
||||
galleries: galleries,
|
||||
modelPath: modelPath,
|
||||
galleryApplier: galleryApplier,
|
||||
}
|
||||
}
|
||||
|
||||
func (mgs *ModelGalleryEndpointService) GetOpStatusEndpoint() func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
status := mgs.galleryApplier.GetStatus(c.Params("uuid"))
|
||||
if status == nil {
|
||||
return fmt.Errorf("could not find any status for ID")
|
||||
}
|
||||
return c.JSON(status)
|
||||
}
|
||||
}
|
||||
|
||||
func (mgs *ModelGalleryEndpointService) GetAllStatusEndpoint() func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
return c.JSON(mgs.galleryApplier.GetAllStatus())
|
||||
}
|
||||
}
|
||||
|
||||
func (mgs *ModelGalleryEndpointService) ApplyModelGalleryEndpoint() func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
input := new(GalleryModel)
|
||||
// Get input data from the request body
|
||||
if err := c.BodyParser(input); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
uuid, err := uuid.NewUUID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
mgs.galleryApplier.C <- gallery.GalleryOp{
|
||||
Req: input.GalleryModel,
|
||||
Id: uuid.String(),
|
||||
GalleryName: input.ID,
|
||||
Galleries: mgs.galleries,
|
||||
}
|
||||
return c.JSON(struct {
|
||||
ID string `json:"uuid"`
|
||||
StatusURL string `json:"status"`
|
||||
}{ID: uuid.String(), StatusURL: c.BaseURL() + "/models/jobs/" + uuid.String()})
|
||||
}
|
||||
}
|
||||
|
||||
func (mgs *ModelGalleryEndpointService) ListModelFromGalleryEndpoint() func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
log.Debug().Msgf("Listing models from galleries: %+v", mgs.galleries)
|
||||
|
||||
models, err := gallery.AvailableGalleryModels(mgs.galleries, mgs.modelPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Debug().Msgf("Models found from galleries: %+v", models)
|
||||
for _, m := range models {
|
||||
log.Debug().Msgf("Model found from galleries: %+v", m)
|
||||
}
|
||||
dat, err := json.Marshal(models)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.Send(dat)
|
||||
}
|
||||
}
|
||||
|
||||
// NOTE: This is different (and much simpler!) than above! This JUST lists the model galleries that have been loaded, not their contents!
|
||||
func (mgs *ModelGalleryEndpointService) ListModelGalleriesEndpoint() func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
log.Debug().Msgf("Listing model galleries %+v", mgs.galleries)
|
||||
dat, err := json.Marshal(mgs.galleries)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.Send(dat)
|
||||
}
|
||||
}
|
||||
|
||||
func (mgs *ModelGalleryEndpointService) AddModelGalleryEndpoint() func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
input := new(gallery.Gallery)
|
||||
// Get input data from the request body
|
||||
if err := c.BodyParser(input); err != nil {
|
||||
return err
|
||||
}
|
||||
if slices.ContainsFunc(mgs.galleries, func(gallery gallery.Gallery) bool {
|
||||
return gallery.Name == input.Name
|
||||
}) {
|
||||
return fmt.Errorf("%s already exists", input.Name)
|
||||
}
|
||||
dat, err := json.Marshal(mgs.galleries)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Debug().Msgf("Adding %+v to gallery list", *input)
|
||||
mgs.galleries = append(mgs.galleries, *input)
|
||||
return c.Send(dat)
|
||||
}
|
||||
}
|
||||
|
||||
func (mgs *ModelGalleryEndpointService) RemoveModelGalleryEndpoint() func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
input := new(gallery.Gallery)
|
||||
// Get input data from the request body
|
||||
if err := c.BodyParser(input); err != nil {
|
||||
return err
|
||||
}
|
||||
if !slices.ContainsFunc(mgs.galleries, func(gallery gallery.Gallery) bool {
|
||||
return gallery.Name == input.Name
|
||||
}) {
|
||||
return fmt.Errorf("%s is not currently registered", input.Name)
|
||||
}
|
||||
mgs.galleries = slices.DeleteFunc(mgs.galleries, func(gallery gallery.Gallery) bool {
|
||||
return gallery.Name == input.Name
|
||||
})
|
||||
return c.Send(nil)
|
||||
}
|
||||
}
|
43
core/http/endpoints/localai/metrics.go
Normal file
43
core/http/endpoints/localai/metrics.go
Normal file
|
@ -0,0 +1,43 @@
|
|||
package localai
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/go-skynet/LocalAI/core/services"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/middleware/adaptor"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
)
|
||||
|
||||
func LocalAIMetricsEndpoint() fiber.Handler {
|
||||
|
||||
return adaptor.HTTPHandler(promhttp.Handler())
|
||||
}
|
||||
|
||||
type apiMiddlewareConfig struct {
|
||||
Filter func(c *fiber.Ctx) bool
|
||||
metricsService *services.LocalAIMetricsService
|
||||
}
|
||||
|
||||
func LocalAIMetricsAPIMiddleware(metrics *services.LocalAIMetricsService) fiber.Handler {
|
||||
cfg := apiMiddlewareConfig{
|
||||
metricsService: metrics,
|
||||
Filter: func(c *fiber.Ctx) bool {
|
||||
return c.Path() == "/metrics"
|
||||
},
|
||||
}
|
||||
|
||||
return func(c *fiber.Ctx) error {
|
||||
if cfg.Filter != nil && cfg.Filter(c) {
|
||||
return c.Next()
|
||||
}
|
||||
path := c.Path()
|
||||
method := c.Method()
|
||||
|
||||
start := time.Now()
|
||||
err := c.Next()
|
||||
elapsed := float64(time.Since(start)) / float64(time.Second)
|
||||
cfg.metricsService.ObserveAPICall(method, path, elapsed)
|
||||
return err
|
||||
}
|
||||
}
|
48
core/http/endpoints/localai/tts.go
Normal file
48
core/http/endpoints/localai/tts.go
Normal file
|
@ -0,0 +1,48 @@
|
|||
package localai
|
||||
|
||||
import (
|
||||
"github.com/go-skynet/LocalAI/core/backend"
|
||||
"github.com/go-skynet/LocalAI/core/config"
|
||||
fiberContext "github.com/go-skynet/LocalAI/core/http/ctx"
|
||||
"github.com/go-skynet/LocalAI/pkg/model"
|
||||
|
||||
"github.com/go-skynet/LocalAI/core/schema"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
|
||||
input := new(schema.TTSRequest)
|
||||
|
||||
// Get input data from the request body
|
||||
if err := c.BodyParser(input); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
modelFile, err := fiberContext.ModelFromContext(c, ml, input.Model, false)
|
||||
if err != nil {
|
||||
modelFile = input.Model
|
||||
log.Warn().Msgf("Model not found in context: %s", input.Model)
|
||||
}
|
||||
cfg, err := config.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath, cl, false, 0, 0, false)
|
||||
if err != nil {
|
||||
modelFile = input.Model
|
||||
log.Warn().Msgf("Model not found in context: %s", input.Model)
|
||||
} else {
|
||||
modelFile = cfg.Model
|
||||
}
|
||||
log.Debug().Msgf("Request for model: %s", modelFile)
|
||||
|
||||
if input.Backend != "" {
|
||||
cfg.Backend = input.Backend
|
||||
}
|
||||
|
||||
filePath, _, err := backend.ModelTTS(cfg.Backend, input.Input, modelFile, ml, appConfig, *cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.Download(filePath)
|
||||
}
|
||||
}
|
609
core/http/endpoints/openai/chat.go
Normal file
609
core/http/endpoints/openai/chat.go
Normal file
|
@ -0,0 +1,609 @@
|
|||
package openai
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-skynet/LocalAI/core/backend"
|
||||
"github.com/go-skynet/LocalAI/core/config"
|
||||
"github.com/go-skynet/LocalAI/core/schema"
|
||||
"github.com/go-skynet/LocalAI/pkg/grammar"
|
||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/google/uuid"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startupOptions *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
emptyMessage := ""
|
||||
id := uuid.New().String()
|
||||
created := int(time.Now().Unix())
|
||||
|
||||
process := func(s string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
|
||||
initialMessage := schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant", Content: &emptyMessage}}},
|
||||
Object: "chat.completion.chunk",
|
||||
}
|
||||
responses <- initialMessage
|
||||
|
||||
ComputeChoices(req, s, config, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
|
||||
resp := schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []schema.Choice{{Delta: &schema.Message{Content: &s}, Index: 0}},
|
||||
Object: "chat.completion.chunk",
|
||||
Usage: schema.OpenAIUsage{
|
||||
PromptTokens: usage.Prompt,
|
||||
CompletionTokens: usage.Completion,
|
||||
TotalTokens: usage.Prompt + usage.Completion,
|
||||
},
|
||||
}
|
||||
|
||||
responses <- resp
|
||||
return true
|
||||
})
|
||||
close(responses)
|
||||
}
|
||||
processTools := func(noAction string, prompt string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
|
||||
result := ""
|
||||
_, tokenUsage, _ := ComputeChoices(req, prompt, config, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
|
||||
result += s
|
||||
// TODO: Change generated BNF grammar to be compliant with the schema so we can
|
||||
// stream the result token by token here.
|
||||
return true
|
||||
})
|
||||
|
||||
results := parseFunctionCall(result, config.FunctionsConfig.ParallelCalls)
|
||||
noActionToRun := len(results) > 0 && results[0].name == noAction
|
||||
|
||||
switch {
|
||||
case noActionToRun:
|
||||
initialMessage := schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant", Content: &emptyMessage}}},
|
||||
Object: "chat.completion.chunk",
|
||||
}
|
||||
responses <- initialMessage
|
||||
|
||||
result, err := handleQuestion(config, req, ml, startupOptions, results[0].arguments, prompt)
|
||||
if err != nil {
|
||||
log.Error().Msgf("error handling question: %s", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
resp := schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []schema.Choice{{Delta: &schema.Message{Content: &result}, Index: 0}},
|
||||
Object: "chat.completion.chunk",
|
||||
Usage: schema.OpenAIUsage{
|
||||
PromptTokens: tokenUsage.Prompt,
|
||||
CompletionTokens: tokenUsage.Completion,
|
||||
TotalTokens: tokenUsage.Prompt + tokenUsage.Completion,
|
||||
},
|
||||
}
|
||||
|
||||
responses <- resp
|
||||
|
||||
default:
|
||||
for i, ss := range results {
|
||||
name, args := ss.name, ss.arguments
|
||||
|
||||
initialMessage := schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []schema.Choice{{
|
||||
Delta: &schema.Message{
|
||||
Role: "assistant",
|
||||
ToolCalls: []schema.ToolCall{
|
||||
{
|
||||
Index: i,
|
||||
ID: id,
|
||||
Type: "function",
|
||||
FunctionCall: schema.FunctionCall{
|
||||
Name: name,
|
||||
},
|
||||
},
|
||||
},
|
||||
}}},
|
||||
Object: "chat.completion.chunk",
|
||||
}
|
||||
responses <- initialMessage
|
||||
|
||||
responses <- schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []schema.Choice{{
|
||||
Delta: &schema.Message{
|
||||
Role: "assistant",
|
||||
ToolCalls: []schema.ToolCall{
|
||||
{
|
||||
Index: i,
|
||||
ID: id,
|
||||
Type: "function",
|
||||
FunctionCall: schema.FunctionCall{
|
||||
Arguments: args,
|
||||
},
|
||||
},
|
||||
},
|
||||
}}},
|
||||
Object: "chat.completion.chunk",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
close(responses)
|
||||
}
|
||||
|
||||
return func(c *fiber.Ctx) error {
|
||||
processFunctions := false
|
||||
funcs := grammar.Functions{}
|
||||
modelFile, input, err := readRequest(c, ml, startupOptions, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
|
||||
config, input, err := mergeRequestWithConfig(modelFile, input, cl, ml, startupOptions.Debug, startupOptions.Threads, startupOptions.ContextSize, startupOptions.F16)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
log.Debug().Msgf("Configuration read: %+v", config)
|
||||
|
||||
// Allow the user to set custom actions via config file
|
||||
// to be "embedded" in each model
|
||||
noActionName := "answer"
|
||||
noActionDescription := "use this action to answer without performing any action"
|
||||
|
||||
if config.FunctionsConfig.NoActionFunctionName != "" {
|
||||
noActionName = config.FunctionsConfig.NoActionFunctionName
|
||||
}
|
||||
if config.FunctionsConfig.NoActionDescriptionName != "" {
|
||||
noActionDescription = config.FunctionsConfig.NoActionDescriptionName
|
||||
}
|
||||
|
||||
if input.ResponseFormat.Type == "json_object" {
|
||||
input.Grammar = grammar.JSONBNF
|
||||
}
|
||||
|
||||
// process functions if we have any defined or if we have a function call string
|
||||
if len(input.Functions) > 0 && config.ShouldUseFunctions() {
|
||||
log.Debug().Msgf("Response needs to process functions")
|
||||
|
||||
processFunctions = true
|
||||
|
||||
noActionGrammar := grammar.Function{
|
||||
Name: noActionName,
|
||||
Description: noActionDescription,
|
||||
Parameters: map[string]interface{}{
|
||||
"properties": map[string]interface{}{
|
||||
"message": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "The message to reply the user with",
|
||||
}},
|
||||
},
|
||||
}
|
||||
|
||||
// Append the no action function
|
||||
funcs = append(funcs, input.Functions...)
|
||||
if !config.FunctionsConfig.DisableNoAction {
|
||||
funcs = append(funcs, noActionGrammar)
|
||||
}
|
||||
|
||||
// Force picking one of the functions by the request
|
||||
if config.FunctionToCall() != "" {
|
||||
funcs = funcs.Select(config.FunctionToCall())
|
||||
}
|
||||
|
||||
// Update input grammar
|
||||
jsStruct := funcs.ToJSONStructure()
|
||||
config.Grammar = jsStruct.Grammar("", config.FunctionsConfig.ParallelCalls)
|
||||
} else if input.JSONFunctionGrammarObject != nil {
|
||||
config.Grammar = input.JSONFunctionGrammarObject.Grammar("", config.FunctionsConfig.ParallelCalls)
|
||||
}
|
||||
|
||||
// functions are not supported in stream mode (yet?)
|
||||
toStream := input.Stream
|
||||
|
||||
log.Debug().Msgf("Parameters: %+v", config)
|
||||
|
||||
var predInput string
|
||||
|
||||
suppressConfigSystemPrompt := false
|
||||
mess := []string{}
|
||||
for messageIndex, i := range input.Messages {
|
||||
var content string
|
||||
role := i.Role
|
||||
|
||||
// if function call, we might want to customize the role so we can display better that the "assistant called a json action"
|
||||
// if an "assistant_function_call" role is defined, we use it, otherwise we use the role that is passed by in the request
|
||||
if i.FunctionCall != nil && i.Role == "assistant" {
|
||||
roleFn := "assistant_function_call"
|
||||
r := config.Roles[roleFn]
|
||||
if r != "" {
|
||||
role = roleFn
|
||||
}
|
||||
}
|
||||
r := config.Roles[role]
|
||||
contentExists := i.Content != nil && i.StringContent != ""
|
||||
|
||||
// First attempt to populate content via a chat message specific template
|
||||
if config.TemplateConfig.ChatMessage != "" {
|
||||
chatMessageData := model.ChatMessageTemplateData{
|
||||
SystemPrompt: config.SystemPrompt,
|
||||
Role: r,
|
||||
RoleName: role,
|
||||
Content: i.StringContent,
|
||||
FunctionName: i.Name,
|
||||
MessageIndex: messageIndex,
|
||||
}
|
||||
templatedChatMessage, err := ml.EvaluateTemplateForChatMessage(config.TemplateConfig.ChatMessage, chatMessageData)
|
||||
if err != nil {
|
||||
log.Error().Msgf("error processing message %+v using template \"%s\": %v. Skipping!", chatMessageData, config.TemplateConfig.ChatMessage, err)
|
||||
} else {
|
||||
if templatedChatMessage == "" {
|
||||
log.Warn().Msgf("template \"%s\" produced blank output for %+v. Skipping!", config.TemplateConfig.ChatMessage, chatMessageData)
|
||||
continue // TODO: This continue is here intentionally to skip over the line `mess = append(mess, content)` below, and to prevent the sprintf
|
||||
}
|
||||
log.Debug().Msgf("templated message for chat: %s", templatedChatMessage)
|
||||
content = templatedChatMessage
|
||||
}
|
||||
}
|
||||
// If this model doesn't have such a template, or if that template fails to return a value, template at the message level.
|
||||
if content == "" {
|
||||
if r != "" {
|
||||
if contentExists {
|
||||
content = fmt.Sprint(r, i.StringContent)
|
||||
}
|
||||
if i.FunctionCall != nil {
|
||||
j, err := json.Marshal(i.FunctionCall)
|
||||
if err == nil {
|
||||
if contentExists {
|
||||
content += "\n" + fmt.Sprint(r, " ", string(j))
|
||||
} else {
|
||||
content = fmt.Sprint(r, " ", string(j))
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if contentExists {
|
||||
content = fmt.Sprint(i.StringContent)
|
||||
}
|
||||
if i.FunctionCall != nil {
|
||||
j, err := json.Marshal(i.FunctionCall)
|
||||
if err == nil {
|
||||
if contentExists {
|
||||
content += "\n" + string(j)
|
||||
} else {
|
||||
content = string(j)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Special Handling: System. We care if it was printed at all, not the r branch, so check seperately
|
||||
if contentExists && role == "system" {
|
||||
suppressConfigSystemPrompt = true
|
||||
}
|
||||
}
|
||||
|
||||
mess = append(mess, content)
|
||||
}
|
||||
|
||||
predInput = strings.Join(mess, "\n")
|
||||
log.Debug().Msgf("Prompt (before templating): %s", predInput)
|
||||
|
||||
if toStream {
|
||||
log.Debug().Msgf("Stream request received")
|
||||
c.Context().SetContentType("text/event-stream")
|
||||
//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
|
||||
// c.Set("Content-Type", "text/event-stream")
|
||||
c.Set("Cache-Control", "no-cache")
|
||||
c.Set("Connection", "keep-alive")
|
||||
c.Set("Transfer-Encoding", "chunked")
|
||||
}
|
||||
|
||||
templateFile := ""
|
||||
|
||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||
if ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
|
||||
templateFile = config.Model
|
||||
}
|
||||
|
||||
if config.TemplateConfig.Chat != "" && !processFunctions {
|
||||
templateFile = config.TemplateConfig.Chat
|
||||
}
|
||||
|
||||
if config.TemplateConfig.Functions != "" && processFunctions {
|
||||
templateFile = config.TemplateConfig.Functions
|
||||
}
|
||||
|
||||
if templateFile != "" {
|
||||
templatedInput, err := ml.EvaluateTemplateForPrompt(model.ChatPromptTemplate, templateFile, model.PromptTemplateData{
|
||||
SystemPrompt: config.SystemPrompt,
|
||||
SuppressSystemPrompt: suppressConfigSystemPrompt,
|
||||
Input: predInput,
|
||||
Functions: funcs,
|
||||
})
|
||||
if err == nil {
|
||||
predInput = templatedInput
|
||||
log.Debug().Msgf("Template found, input modified to: %s", predInput)
|
||||
} else {
|
||||
log.Debug().Msgf("Template failed loading: %s", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Prompt (after templating): %s", predInput)
|
||||
if processFunctions {
|
||||
log.Debug().Msgf("Grammar: %+v", config.Grammar)
|
||||
}
|
||||
|
||||
switch {
|
||||
case toStream:
|
||||
responses := make(chan schema.OpenAIResponse)
|
||||
|
||||
if !processFunctions {
|
||||
go process(predInput, input, config, ml, responses)
|
||||
} else {
|
||||
go processTools(noActionName, predInput, input, config, ml, responses)
|
||||
}
|
||||
|
||||
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
|
||||
usage := &schema.OpenAIUsage{}
|
||||
toolsCalled := false
|
||||
for ev := range responses {
|
||||
usage = &ev.Usage // Copy a pointer to the latest usage chunk so that the stop message can reference it
|
||||
if len(ev.Choices[0].Delta.ToolCalls) > 0 {
|
||||
toolsCalled = true
|
||||
}
|
||||
var buf bytes.Buffer
|
||||
enc := json.NewEncoder(&buf)
|
||||
enc.Encode(ev)
|
||||
log.Debug().Msgf("Sending chunk: %s", buf.String())
|
||||
_, err := fmt.Fprintf(w, "data: %v\n", buf.String())
|
||||
if err != nil {
|
||||
log.Debug().Msgf("Sending chunk failed: %v", err)
|
||||
input.Cancel()
|
||||
break
|
||||
}
|
||||
w.Flush()
|
||||
}
|
||||
|
||||
finishReason := "stop"
|
||||
if toolsCalled {
|
||||
finishReason = "tool_calls"
|
||||
} else if toolsCalled && len(input.Tools) == 0 {
|
||||
finishReason = "function_call"
|
||||
}
|
||||
|
||||
resp := &schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []schema.Choice{
|
||||
{
|
||||
FinishReason: finishReason,
|
||||
Index: 0,
|
||||
Delta: &schema.Message{Content: &emptyMessage},
|
||||
}},
|
||||
Object: "chat.completion.chunk",
|
||||
Usage: *usage,
|
||||
}
|
||||
respData, _ := json.Marshal(resp)
|
||||
|
||||
w.WriteString(fmt.Sprintf("data: %s\n\n", respData))
|
||||
w.WriteString("data: [DONE]\n\n")
|
||||
w.Flush()
|
||||
}))
|
||||
return nil
|
||||
|
||||
// no streaming mode
|
||||
default:
|
||||
result, tokenUsage, err := ComputeChoices(input, predInput, config, startupOptions, ml, func(s string, c *[]schema.Choice) {
|
||||
if !processFunctions {
|
||||
// no function is called, just reply and use stop as finish reason
|
||||
*c = append(*c, schema.Choice{FinishReason: "stop", Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}})
|
||||
return
|
||||
}
|
||||
|
||||
results := parseFunctionCall(s, config.FunctionsConfig.ParallelCalls)
|
||||
noActionsToRun := len(results) > 0 && results[0].name == noActionName
|
||||
|
||||
switch {
|
||||
case noActionsToRun:
|
||||
result, err := handleQuestion(config, input, ml, startupOptions, results[0].arguments, predInput)
|
||||
if err != nil {
|
||||
log.Error().Msgf("error handling question: %s", err.Error())
|
||||
return
|
||||
}
|
||||
*c = append(*c, schema.Choice{
|
||||
Message: &schema.Message{Role: "assistant", Content: &result}})
|
||||
default:
|
||||
toolChoice := schema.Choice{
|
||||
Message: &schema.Message{
|
||||
Role: "assistant",
|
||||
},
|
||||
}
|
||||
|
||||
if len(input.Tools) > 0 {
|
||||
toolChoice.FinishReason = "tool_calls"
|
||||
}
|
||||
|
||||
for _, ss := range results {
|
||||
name, args := ss.name, ss.arguments
|
||||
if len(input.Tools) > 0 {
|
||||
// If we are using tools, we condense the function calls into
|
||||
// a single response choice with all the tools
|
||||
toolChoice.Message.ToolCalls = append(toolChoice.Message.ToolCalls,
|
||||
schema.ToolCall{
|
||||
ID: id,
|
||||
Type: "function",
|
||||
FunctionCall: schema.FunctionCall{
|
||||
Name: name,
|
||||
Arguments: args,
|
||||
},
|
||||
},
|
||||
)
|
||||
} else {
|
||||
// otherwise we return more choices directly
|
||||
*c = append(*c, schema.Choice{
|
||||
FinishReason: "function_call",
|
||||
Message: &schema.Message{
|
||||
Role: "assistant",
|
||||
FunctionCall: map[string]interface{}{
|
||||
"name": name,
|
||||
"arguments": args,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if len(input.Tools) > 0 {
|
||||
// we need to append our result if we are using tools
|
||||
*c = append(*c, toolChoice)
|
||||
}
|
||||
}
|
||||
|
||||
}, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp := &schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: result,
|
||||
Object: "chat.completion",
|
||||
Usage: schema.OpenAIUsage{
|
||||
PromptTokens: tokenUsage.Prompt,
|
||||
CompletionTokens: tokenUsage.Completion,
|
||||
TotalTokens: tokenUsage.Prompt + tokenUsage.Completion,
|
||||
},
|
||||
}
|
||||
respData, _ := json.Marshal(resp)
|
||||
log.Debug().Msgf("Response: %s", respData)
|
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(resp)
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
func handleQuestion(config *config.BackendConfig, input *schema.OpenAIRequest, ml *model.ModelLoader, o *config.ApplicationConfig, args, prompt string) (string, error) {
|
||||
log.Debug().Msgf("nothing to do, computing a reply")
|
||||
|
||||
// If there is a message that the LLM already sends as part of the JSON reply, use it
|
||||
arguments := map[string]interface{}{}
|
||||
json.Unmarshal([]byte(args), &arguments)
|
||||
m, exists := arguments["message"]
|
||||
if exists {
|
||||
switch message := m.(type) {
|
||||
case string:
|
||||
if message != "" {
|
||||
log.Debug().Msgf("Reply received from LLM: %s", message)
|
||||
message = backend.Finetune(*config, prompt, message)
|
||||
log.Debug().Msgf("Reply received from LLM(finetuned): %s", message)
|
||||
|
||||
return message, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug().Msgf("No action received from LLM, without a message, computing a reply")
|
||||
// Otherwise ask the LLM to understand the JSON output and the context, and return a message
|
||||
// Note: This costs (in term of CPU/GPU) another computation
|
||||
config.Grammar = ""
|
||||
images := []string{}
|
||||
for _, m := range input.Messages {
|
||||
images = append(images, m.StringImages...)
|
||||
}
|
||||
|
||||
predFunc, err := backend.ModelInference(input.Context, prompt, images, ml, *config, o, nil)
|
||||
if err != nil {
|
||||
log.Error().Msgf("inference error: %s", err.Error())
|
||||
return "", err
|
||||
}
|
||||
|
||||
prediction, err := predFunc()
|
||||
if err != nil {
|
||||
log.Error().Msgf("inference error: %s", err.Error())
|
||||
return "", err
|
||||
}
|
||||
return backend.Finetune(*config, prompt, prediction.Response), nil
|
||||
}
|
||||
|
||||
type funcCallResults struct {
|
||||
name string
|
||||
arguments string
|
||||
}
|
||||
|
||||
func parseFunctionCall(llmresult string, multipleResults bool) []funcCallResults {
|
||||
results := []funcCallResults{}
|
||||
|
||||
// TODO: use generics to avoid this code duplication
|
||||
if multipleResults {
|
||||
ss := []map[string]interface{}{}
|
||||
s := utils.EscapeNewLines(llmresult)
|
||||
json.Unmarshal([]byte(s), &ss)
|
||||
log.Debug().Msgf("Function return: %s %+v", s, ss)
|
||||
|
||||
for _, s := range ss {
|
||||
func_name, ok := s["function"]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
args, ok := s["arguments"]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
d, _ := json.Marshal(args)
|
||||
funcName, ok := func_name.(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
results = append(results, funcCallResults{name: funcName, arguments: string(d)})
|
||||
}
|
||||
} else {
|
||||
// As we have to change the result before processing, we can't stream the answer token-by-token (yet?)
|
||||
ss := map[string]interface{}{}
|
||||
// This prevent newlines to break JSON parsing for clients
|
||||
s := utils.EscapeNewLines(llmresult)
|
||||
json.Unmarshal([]byte(s), &ss)
|
||||
log.Debug().Msgf("Function return: %s %+v", s, ss)
|
||||
|
||||
// The grammar defines the function name as "function", while OpenAI returns "name"
|
||||
func_name, ok := ss["function"]
|
||||
if !ok {
|
||||
return results
|
||||
}
|
||||
// Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object
|
||||
args, ok := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix)
|
||||
if !ok {
|
||||
return results
|
||||
}
|
||||
d, _ := json.Marshal(args)
|
||||
funcName, ok := func_name.(string)
|
||||
if !ok {
|
||||
return results
|
||||
}
|
||||
results = append(results, funcCallResults{name: funcName, arguments: string(d)})
|
||||
}
|
||||
|
||||
return results
|
||||
}
|
199
core/http/endpoints/openai/completion.go
Normal file
199
core/http/endpoints/openai/completion.go
Normal file
|
@ -0,0 +1,199 @@
|
|||
package openai
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/go-skynet/LocalAI/core/backend"
|
||||
"github.com/go-skynet/LocalAI/core/config"
|
||||
|
||||
"github.com/go-skynet/LocalAI/core/schema"
|
||||
"github.com/go-skynet/LocalAI/pkg/grammar"
|
||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/google/uuid"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// https://platform.openai.com/docs/api-reference/completions
|
||||
func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
id := uuid.New().String()
|
||||
created := int(time.Now().Unix())
|
||||
|
||||
process := func(s string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
|
||||
ComputeChoices(req, s, config, appConfig, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
|
||||
resp := schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []schema.Choice{
|
||||
{
|
||||
Index: 0,
|
||||
Text: s,
|
||||
},
|
||||
},
|
||||
Object: "text_completion",
|
||||
Usage: schema.OpenAIUsage{
|
||||
PromptTokens: usage.Prompt,
|
||||
CompletionTokens: usage.Completion,
|
||||
TotalTokens: usage.Prompt + usage.Completion,
|
||||
},
|
||||
}
|
||||
log.Debug().Msgf("Sending goroutine: %s", s)
|
||||
|
||||
responses <- resp
|
||||
return true
|
||||
})
|
||||
close(responses)
|
||||
}
|
||||
|
||||
return func(c *fiber.Ctx) error {
|
||||
modelFile, input, err := readRequest(c, ml, appConfig, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
|
||||
log.Debug().Msgf("`input`: %+v", input)
|
||||
|
||||
config, input, err := mergeRequestWithConfig(modelFile, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
|
||||
if input.ResponseFormat.Type == "json_object" {
|
||||
input.Grammar = grammar.JSONBNF
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Parameter Config: %+v", config)
|
||||
|
||||
if input.Stream {
|
||||
log.Debug().Msgf("Stream request received")
|
||||
c.Context().SetContentType("text/event-stream")
|
||||
//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
|
||||
//c.Set("Content-Type", "text/event-stream")
|
||||
c.Set("Cache-Control", "no-cache")
|
||||
c.Set("Connection", "keep-alive")
|
||||
c.Set("Transfer-Encoding", "chunked")
|
||||
}
|
||||
|
||||
templateFile := ""
|
||||
|
||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||
if ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
|
||||
templateFile = config.Model
|
||||
}
|
||||
|
||||
if config.TemplateConfig.Completion != "" {
|
||||
templateFile = config.TemplateConfig.Completion
|
||||
}
|
||||
|
||||
if input.Stream {
|
||||
if len(config.PromptStrings) > 1 {
|
||||
return errors.New("cannot handle more than 1 `PromptStrings` when Streaming")
|
||||
}
|
||||
|
||||
predInput := config.PromptStrings[0]
|
||||
|
||||
if templateFile != "" {
|
||||
templatedInput, err := ml.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{
|
||||
Input: predInput,
|
||||
})
|
||||
if err == nil {
|
||||
predInput = templatedInput
|
||||
log.Debug().Msgf("Template found, input modified to: %s", predInput)
|
||||
}
|
||||
}
|
||||
|
||||
responses := make(chan schema.OpenAIResponse)
|
||||
|
||||
go process(predInput, input, config, ml, responses)
|
||||
|
||||
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
|
||||
|
||||
for ev := range responses {
|
||||
var buf bytes.Buffer
|
||||
enc := json.NewEncoder(&buf)
|
||||
enc.Encode(ev)
|
||||
|
||||
log.Debug().Msgf("Sending chunk: %s", buf.String())
|
||||
fmt.Fprintf(w, "data: %v\n", buf.String())
|
||||
w.Flush()
|
||||
}
|
||||
|
||||
resp := &schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []schema.Choice{
|
||||
{
|
||||
Index: 0,
|
||||
FinishReason: "stop",
|
||||
},
|
||||
},
|
||||
Object: "text_completion",
|
||||
}
|
||||
respData, _ := json.Marshal(resp)
|
||||
|
||||
w.WriteString(fmt.Sprintf("data: %s\n\n", respData))
|
||||
w.WriteString("data: [DONE]\n\n")
|
||||
w.Flush()
|
||||
}))
|
||||
return nil
|
||||
}
|
||||
|
||||
var result []schema.Choice
|
||||
|
||||
totalTokenUsage := backend.TokenUsage{}
|
||||
|
||||
for k, i := range config.PromptStrings {
|
||||
if templateFile != "" {
|
||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||
templatedInput, err := ml.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{
|
||||
SystemPrompt: config.SystemPrompt,
|
||||
Input: i,
|
||||
})
|
||||
if err == nil {
|
||||
i = templatedInput
|
||||
log.Debug().Msgf("Template found, input modified to: %s", i)
|
||||
}
|
||||
}
|
||||
|
||||
r, tokenUsage, err := ComputeChoices(
|
||||
input, i, config, appConfig, ml, func(s string, c *[]schema.Choice) {
|
||||
*c = append(*c, schema.Choice{Text: s, FinishReason: "stop", Index: k})
|
||||
}, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
totalTokenUsage.Prompt += tokenUsage.Prompt
|
||||
totalTokenUsage.Completion += tokenUsage.Completion
|
||||
|
||||
result = append(result, r...)
|
||||
}
|
||||
|
||||
resp := &schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: result,
|
||||
Object: "text_completion",
|
||||
Usage: schema.OpenAIUsage{
|
||||
PromptTokens: totalTokenUsage.Prompt,
|
||||
CompletionTokens: totalTokenUsage.Completion,
|
||||
TotalTokens: totalTokenUsage.Prompt + totalTokenUsage.Completion,
|
||||
},
|
||||
}
|
||||
|
||||
jsonResult, _ := json.Marshal(resp)
|
||||
log.Debug().Msgf("Response: %s", jsonResult)
|
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(resp)
|
||||
}
|
||||
}
|
94
core/http/endpoints/openai/edit.go
Normal file
94
core/http/endpoints/openai/edit.go
Normal file
|
@ -0,0 +1,94 @@
|
|||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/go-skynet/LocalAI/core/backend"
|
||||
"github.com/go-skynet/LocalAI/core/config"
|
||||
|
||||
"github.com/go-skynet/LocalAI/core/schema"
|
||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
func EditEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
modelFile, input, err := readRequest(c, ml, appConfig, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
|
||||
config, input, err := mergeRequestWithConfig(modelFile, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Parameter Config: %+v", config)
|
||||
|
||||
templateFile := ""
|
||||
|
||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||
if ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
|
||||
templateFile = config.Model
|
||||
}
|
||||
|
||||
if config.TemplateConfig.Edit != "" {
|
||||
templateFile = config.TemplateConfig.Edit
|
||||
}
|
||||
|
||||
var result []schema.Choice
|
||||
totalTokenUsage := backend.TokenUsage{}
|
||||
|
||||
for _, i := range config.InputStrings {
|
||||
if templateFile != "" {
|
||||
templatedInput, err := ml.EvaluateTemplateForPrompt(model.EditPromptTemplate, templateFile, model.PromptTemplateData{
|
||||
Input: i,
|
||||
Instruction: input.Instruction,
|
||||
SystemPrompt: config.SystemPrompt,
|
||||
})
|
||||
if err == nil {
|
||||
i = templatedInput
|
||||
log.Debug().Msgf("Template found, input modified to: %s", i)
|
||||
}
|
||||
}
|
||||
|
||||
r, tokenUsage, err := ComputeChoices(input, i, config, appConfig, ml, func(s string, c *[]schema.Choice) {
|
||||
*c = append(*c, schema.Choice{Text: s})
|
||||
}, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
totalTokenUsage.Prompt += tokenUsage.Prompt
|
||||
totalTokenUsage.Completion += tokenUsage.Completion
|
||||
|
||||
result = append(result, r...)
|
||||
}
|
||||
|
||||
id := uuid.New().String()
|
||||
created := int(time.Now().Unix())
|
||||
resp := &schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: result,
|
||||
Object: "edit",
|
||||
Usage: schema.OpenAIUsage{
|
||||
PromptTokens: totalTokenUsage.Prompt,
|
||||
CompletionTokens: totalTokenUsage.Completion,
|
||||
TotalTokens: totalTokenUsage.Prompt + totalTokenUsage.Completion,
|
||||
},
|
||||
}
|
||||
|
||||
jsonResult, _ := json.Marshal(resp)
|
||||
log.Debug().Msgf("Response: %s", jsonResult)
|
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(resp)
|
||||
}
|
||||
}
|
79
core/http/endpoints/openai/embeddings.go
Normal file
79
core/http/endpoints/openai/embeddings.go
Normal file
|
@ -0,0 +1,79 @@
|
|||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/go-skynet/LocalAI/core/backend"
|
||||
"github.com/go-skynet/LocalAI/core/config"
|
||||
"github.com/go-skynet/LocalAI/pkg/model"
|
||||
|
||||
"github.com/go-skynet/LocalAI/core/schema"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// https://platform.openai.com/docs/api-reference/embeddings
|
||||
func EmbeddingsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
model, input, err := readRequest(c, ml, appConfig, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
|
||||
config, input, err := mergeRequestWithConfig(model, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Parameter Config: %+v", config)
|
||||
items := []schema.Item{}
|
||||
|
||||
for i, s := range config.InputToken {
|
||||
// get the model function to call for the result
|
||||
embedFn, err := backend.ModelEmbedding("", s, ml, *config, appConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
embeddings, err := embedFn()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
items = append(items, schema.Item{Embedding: embeddings, Index: i, Object: "embedding"})
|
||||
}
|
||||
|
||||
for i, s := range config.InputStrings {
|
||||
// get the model function to call for the result
|
||||
embedFn, err := backend.ModelEmbedding(s, []int{}, ml, *config, appConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
embeddings, err := embedFn()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
items = append(items, schema.Item{Embedding: embeddings, Index: i, Object: "embedding"})
|
||||
}
|
||||
|
||||
id := uuid.New().String()
|
||||
created := int(time.Now().Unix())
|
||||
resp := &schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Data: items,
|
||||
Object: "list",
|
||||
}
|
||||
|
||||
jsonResult, _ := json.Marshal(resp)
|
||||
log.Debug().Msgf("Response: %s", jsonResult)
|
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(resp)
|
||||
}
|
||||
}
|
218
core/http/endpoints/openai/files.go
Normal file
218
core/http/endpoints/openai/files.go
Normal file
|
@ -0,0 +1,218 @@
|
|||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/go-skynet/LocalAI/core/config"
|
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
var uploadedFiles []File
|
||||
|
||||
const uploadedFilesFile = "uploadedFiles.json"
|
||||
|
||||
// File represents the structure of a file object from the OpenAI API.
|
||||
type File struct {
|
||||
ID string `json:"id"` // Unique identifier for the file
|
||||
Object string `json:"object"` // Type of the object (e.g., "file")
|
||||
Bytes int `json:"bytes"` // Size of the file in bytes
|
||||
CreatedAt time.Time `json:"created_at"` // The time at which the file was created
|
||||
Filename string `json:"filename"` // The name of the file
|
||||
Purpose string `json:"purpose"` // The purpose of the file (e.g., "fine-tune", "classifications", etc.)
|
||||
}
|
||||
|
||||
func saveUploadConfig(uploadDir string) {
|
||||
file, err := json.MarshalIndent(uploadedFiles, "", " ")
|
||||
if err != nil {
|
||||
log.Error().Msgf("Failed to JSON marshal the uploadedFiles: %s", err)
|
||||
}
|
||||
|
||||
err = os.WriteFile(filepath.Join(uploadDir, uploadedFilesFile), file, 0644)
|
||||
if err != nil {
|
||||
log.Error().Msgf("Failed to save uploadedFiles to file: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func LoadUploadConfig(uploadPath string) {
|
||||
uploadFilePath := filepath.Join(uploadPath, uploadedFilesFile)
|
||||
|
||||
_, err := os.Stat(uploadFilePath)
|
||||
if os.IsNotExist(err) {
|
||||
log.Debug().Msgf("No uploadedFiles file found at %s", uploadFilePath)
|
||||
return
|
||||
}
|
||||
|
||||
file, err := os.ReadFile(uploadFilePath)
|
||||
if err != nil {
|
||||
log.Error().Msgf("Failed to read file: %s", err)
|
||||
} else {
|
||||
err = json.Unmarshal(file, &uploadedFiles)
|
||||
if err != nil {
|
||||
log.Error().Msgf("Failed to JSON unmarshal the file into uploadedFiles: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// UploadFilesEndpoint https://platform.openai.com/docs/api-reference/files/create
|
||||
func UploadFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
file, err := c.FormFile("file")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check the file size
|
||||
if file.Size > int64(appConfig.UploadLimitMB*1024*1024) {
|
||||
return c.Status(fiber.StatusBadRequest).SendString(fmt.Sprintf("File size %d exceeds upload limit %d", file.Size, appConfig.UploadLimitMB))
|
||||
}
|
||||
|
||||
purpose := c.FormValue("purpose", "") //TODO put in purpose dirs
|
||||
if purpose == "" {
|
||||
return c.Status(fiber.StatusBadRequest).SendString("Purpose is not defined")
|
||||
}
|
||||
|
||||
// Sanitize the filename to prevent directory traversal
|
||||
filename := utils.SanitizeFileName(file.Filename)
|
||||
|
||||
savePath := filepath.Join(appConfig.UploadDir, filename)
|
||||
|
||||
// Check if file already exists
|
||||
if _, err := os.Stat(savePath); !os.IsNotExist(err) {
|
||||
return c.Status(fiber.StatusBadRequest).SendString("File already exists")
|
||||
}
|
||||
|
||||
err = c.SaveFile(file, savePath)
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusInternalServerError).SendString("Failed to save file: " + err.Error())
|
||||
}
|
||||
|
||||
f := File{
|
||||
ID: fmt.Sprintf("file-%d", time.Now().Unix()),
|
||||
Object: "file",
|
||||
Bytes: int(file.Size),
|
||||
CreatedAt: time.Now(),
|
||||
Filename: file.Filename,
|
||||
Purpose: purpose,
|
||||
}
|
||||
|
||||
uploadedFiles = append(uploadedFiles, f)
|
||||
saveUploadConfig(appConfig.UploadDir)
|
||||
return c.Status(fiber.StatusOK).JSON(f)
|
||||
}
|
||||
}
|
||||
|
||||
// ListFilesEndpoint https://platform.openai.com/docs/api-reference/files/list
|
||||
func ListFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
type ListFiles struct {
|
||||
Data []File
|
||||
Object string
|
||||
}
|
||||
|
||||
return func(c *fiber.Ctx) error {
|
||||
var listFiles ListFiles
|
||||
|
||||
purpose := c.Query("purpose")
|
||||
if purpose == "" {
|
||||
listFiles.Data = uploadedFiles
|
||||
} else {
|
||||
for _, f := range uploadedFiles {
|
||||
if purpose == f.Purpose {
|
||||
listFiles.Data = append(listFiles.Data, f)
|
||||
}
|
||||
}
|
||||
}
|
||||
listFiles.Object = "list"
|
||||
return c.Status(fiber.StatusOK).JSON(listFiles)
|
||||
}
|
||||
}
|
||||
|
||||
func getFileFromRequest(c *fiber.Ctx) (*File, error) {
|
||||
id := c.Params("file_id")
|
||||
if id == "" {
|
||||
return nil, fmt.Errorf("file_id parameter is required")
|
||||
}
|
||||
|
||||
for _, f := range uploadedFiles {
|
||||
if id == f.ID {
|
||||
return &f, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unable to find file id %s", id)
|
||||
}
|
||||
|
||||
// GetFilesEndpoint https://platform.openai.com/docs/api-reference/files/retrieve
|
||||
func GetFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
file, err := getFileFromRequest(c)
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusInternalServerError).SendString(err.Error())
|
||||
}
|
||||
|
||||
return c.JSON(file)
|
||||
}
|
||||
}
|
||||
|
||||
// DeleteFilesEndpoint https://platform.openai.com/docs/api-reference/files/delete
|
||||
func DeleteFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
type DeleteStatus struct {
|
||||
Id string
|
||||
Object string
|
||||
Deleted bool
|
||||
}
|
||||
|
||||
return func(c *fiber.Ctx) error {
|
||||
file, err := getFileFromRequest(c)
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusInternalServerError).SendString(err.Error())
|
||||
}
|
||||
|
||||
err = os.Remove(filepath.Join(appConfig.UploadDir, file.Filename))
|
||||
if err != nil {
|
||||
// If the file doesn't exist then we should just continue to remove it
|
||||
if !errors.Is(err, os.ErrNotExist) {
|
||||
return c.Status(fiber.StatusInternalServerError).SendString(fmt.Sprintf("Unable to delete file: %s, %v", file.Filename, err))
|
||||
}
|
||||
}
|
||||
|
||||
// Remove upload from list
|
||||
for i, f := range uploadedFiles {
|
||||
if f.ID == file.ID {
|
||||
uploadedFiles = append(uploadedFiles[:i], uploadedFiles[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
saveUploadConfig(appConfig.UploadDir)
|
||||
return c.JSON(DeleteStatus{
|
||||
Id: file.ID,
|
||||
Object: "file",
|
||||
Deleted: true,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// GetFilesContentsEndpoint https://platform.openai.com/docs/api-reference/files/retrieve-contents
|
||||
func GetFilesContentsEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
file, err := getFileFromRequest(c)
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusInternalServerError).SendString(err.Error())
|
||||
}
|
||||
|
||||
fileContents, err := os.ReadFile(filepath.Join(appConfig.UploadDir, file.Filename))
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusInternalServerError).SendString(err.Error())
|
||||
}
|
||||
|
||||
return c.Send(fileContents)
|
||||
}
|
||||
}
|
287
core/http/endpoints/openai/files_test.go
Normal file
287
core/http/endpoints/openai/files_test.go
Normal file
|
@ -0,0 +1,287 @@
|
|||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/go-skynet/LocalAI/core/config"
|
||||
|
||||
utils2 "github.com/go-skynet/LocalAI/pkg/utils"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"testing"
|
||||
)
|
||||
|
||||
type ListFiles struct {
|
||||
Data []File
|
||||
Object string
|
||||
}
|
||||
|
||||
func startUpApp() (app *fiber.App, option *config.ApplicationConfig, loader *config.BackendConfigLoader) {
|
||||
// Preparing the mocked objects
|
||||
loader = &config.BackendConfigLoader{}
|
||||
|
||||
option = &config.ApplicationConfig{
|
||||
UploadLimitMB: 10,
|
||||
UploadDir: "test_dir",
|
||||
}
|
||||
|
||||
_ = os.RemoveAll(option.UploadDir)
|
||||
|
||||
app = fiber.New(fiber.Config{
|
||||
BodyLimit: 20 * 1024 * 1024, // sets the limit to 20MB.
|
||||
})
|
||||
|
||||
// Create a Test Server
|
||||
app.Post("/files", UploadFilesEndpoint(loader, option))
|
||||
app.Get("/files", ListFilesEndpoint(loader, option))
|
||||
app.Get("/files/:file_id", GetFilesEndpoint(loader, option))
|
||||
app.Delete("/files/:file_id", DeleteFilesEndpoint(loader, option))
|
||||
app.Get("/files/:file_id/content", GetFilesContentsEndpoint(loader, option))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func TestUploadFileExceedSizeLimit(t *testing.T) {
|
||||
// Preparing the mocked objects
|
||||
loader := &config.BackendConfigLoader{}
|
||||
|
||||
option := &config.ApplicationConfig{
|
||||
UploadLimitMB: 10,
|
||||
UploadDir: "test_dir",
|
||||
}
|
||||
|
||||
_ = os.RemoveAll(option.UploadDir)
|
||||
|
||||
app := fiber.New(fiber.Config{
|
||||
BodyLimit: 20 * 1024 * 1024, // sets the limit to 20MB.
|
||||
})
|
||||
|
||||
// Create a Test Server
|
||||
app.Post("/files", UploadFilesEndpoint(loader, option))
|
||||
app.Get("/files", ListFilesEndpoint(loader, option))
|
||||
app.Get("/files/:file_id", GetFilesEndpoint(loader, option))
|
||||
app.Delete("/files/:file_id", DeleteFilesEndpoint(loader, option))
|
||||
app.Get("/files/:file_id/content", GetFilesContentsEndpoint(loader, option))
|
||||
|
||||
t.Run("UploadFilesEndpoint file size exceeds limit", func(t *testing.T) {
|
||||
resp, err := CallFilesUploadEndpoint(t, app, "foo.txt", "file", "fine-tune", 11, option)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, fiber.StatusBadRequest, resp.StatusCode)
|
||||
assert.Contains(t, bodyToString(resp, t), "exceeds upload limit")
|
||||
})
|
||||
t.Run("UploadFilesEndpoint purpose not defined", func(t *testing.T) {
|
||||
resp, _ := CallFilesUploadEndpoint(t, app, "foo.txt", "file", "", 5, option)
|
||||
|
||||
assert.Equal(t, fiber.StatusBadRequest, resp.StatusCode)
|
||||
assert.Contains(t, bodyToString(resp, t), "Purpose is not defined")
|
||||
})
|
||||
t.Run("UploadFilesEndpoint file already exists", func(t *testing.T) {
|
||||
f1 := CallFilesUploadEndpointWithCleanup(t, app, "foo.txt", "file", "fine-tune", 5, option)
|
||||
|
||||
resp, err := CallFilesUploadEndpoint(t, app, "foo.txt", "file", "fine-tune", 5, option)
|
||||
fmt.Println(f1)
|
||||
fmt.Printf("ERror: %v", err)
|
||||
|
||||
assert.Equal(t, fiber.StatusBadRequest, resp.StatusCode)
|
||||
assert.Contains(t, bodyToString(resp, t), "File already exists")
|
||||
})
|
||||
t.Run("UploadFilesEndpoint file uploaded successfully", func(t *testing.T) {
|
||||
file := CallFilesUploadEndpointWithCleanup(t, app, "test.txt", "file", "fine-tune", 5, option)
|
||||
|
||||
// Check if file exists in the disk
|
||||
filePath := filepath.Join(option.UploadDir, utils2.SanitizeFileName("test.txt"))
|
||||
_, err := os.Stat(filePath)
|
||||
|
||||
assert.False(t, os.IsNotExist(err))
|
||||
assert.Equal(t, file.Bytes, 5242880)
|
||||
assert.NotEmpty(t, file.CreatedAt)
|
||||
assert.Equal(t, file.Filename, "test.txt")
|
||||
assert.Equal(t, file.Purpose, "fine-tune")
|
||||
})
|
||||
t.Run("ListFilesEndpoint without purpose parameter", func(t *testing.T) {
|
||||
resp, err := CallListFilesEndpoint(t, app, "")
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 200, resp.StatusCode)
|
||||
|
||||
listFiles := responseToListFile(t, resp)
|
||||
if len(listFiles.Data) != len(uploadedFiles) {
|
||||
t.Errorf("Expected %v files, got %v files", len(uploadedFiles), len(listFiles.Data))
|
||||
}
|
||||
})
|
||||
t.Run("ListFilesEndpoint with valid purpose parameter", func(t *testing.T) {
|
||||
_ = CallFilesUploadEndpointWithCleanup(t, app, "test.txt", "file", "fine-tune", 5, option)
|
||||
|
||||
resp, err := CallListFilesEndpoint(t, app, "fine-tune")
|
||||
assert.NoError(t, err)
|
||||
|
||||
listFiles := responseToListFile(t, resp)
|
||||
if len(listFiles.Data) != 1 {
|
||||
t.Errorf("Expected 1 file, got %v files", len(listFiles.Data))
|
||||
}
|
||||
})
|
||||
t.Run("ListFilesEndpoint with invalid query parameter", func(t *testing.T) {
|
||||
resp, err := CallListFilesEndpoint(t, app, "not-so-fine-tune")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 200, resp.StatusCode)
|
||||
|
||||
listFiles := responseToListFile(t, resp)
|
||||
|
||||
if len(listFiles.Data) != 0 {
|
||||
t.Errorf("Expected 0 file, got %v files", len(listFiles.Data))
|
||||
}
|
||||
})
|
||||
t.Run("GetFilesContentsEndpoint get file content", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/files", nil)
|
||||
resp, _ := app.Test(req)
|
||||
assert.Equal(t, 200, resp.StatusCode)
|
||||
|
||||
var listFiles ListFiles
|
||||
if err := json.Unmarshal(bodyToByteArray(resp, t), &listFiles); err != nil {
|
||||
t.Errorf("Failed to decode response: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(listFiles.Data) != 0 {
|
||||
t.Errorf("Expected 0 file, got %v files", len(listFiles.Data))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func CallListFilesEndpoint(t *testing.T, app *fiber.App, purpose string) (*http.Response, error) {
|
||||
var target string
|
||||
if purpose != "" {
|
||||
target = fmt.Sprintf("/files?purpose=%s", purpose)
|
||||
} else {
|
||||
target = "/files"
|
||||
}
|
||||
req := httptest.NewRequest("GET", target, nil)
|
||||
return app.Test(req)
|
||||
}
|
||||
|
||||
func CallFilesContentEndpoint(t *testing.T, app *fiber.App, fileId string) (*http.Response, error) {
|
||||
request := httptest.NewRequest("GET", "/files?file_id="+fileId, nil)
|
||||
return app.Test(request)
|
||||
}
|
||||
|
||||
func CallFilesUploadEndpoint(t *testing.T, app *fiber.App, fileName, tag, purpose string, fileSize int, appConfig *config.ApplicationConfig) (*http.Response, error) {
|
||||
// Create a file that exceeds the limit
|
||||
file := createTestFile(t, fileName, fileSize, appConfig)
|
||||
|
||||
// Creating a new HTTP Request
|
||||
body, writer := newMultipartFile(file.Name(), tag, purpose)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/files", body)
|
||||
req.Header.Set(fiber.HeaderContentType, writer.FormDataContentType())
|
||||
return app.Test(req)
|
||||
}
|
||||
|
||||
func CallFilesUploadEndpointWithCleanup(t *testing.T, app *fiber.App, fileName, tag, purpose string, fileSize int, appConfig *config.ApplicationConfig) File {
|
||||
// Create a file that exceeds the limit
|
||||
file := createTestFile(t, fileName, fileSize, appConfig)
|
||||
|
||||
// Creating a new HTTP Request
|
||||
body, writer := newMultipartFile(file.Name(), tag, purpose)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/files", body)
|
||||
req.Header.Set(fiber.HeaderContentType, writer.FormDataContentType())
|
||||
resp, err := app.Test(req)
|
||||
assert.NoError(t, err)
|
||||
f := responseToFile(t, resp)
|
||||
|
||||
id := f.ID
|
||||
t.Cleanup(func() {
|
||||
_, err := CallFilesDeleteEndpoint(t, app, id)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
return f
|
||||
|
||||
}
|
||||
|
||||
func CallFilesDeleteEndpoint(t *testing.T, app *fiber.App, fileId string) (*http.Response, error) {
|
||||
target := fmt.Sprintf("/files/%s", fileId)
|
||||
req := httptest.NewRequest(http.MethodDelete, target, nil)
|
||||
return app.Test(req)
|
||||
}
|
||||
|
||||
// Helper to create multi-part file
|
||||
func newMultipartFile(filePath, tag, purpose string) (*strings.Reader, *multipart.Writer) {
|
||||
body := new(strings.Builder)
|
||||
writer := multipart.NewWriter(body)
|
||||
file, _ := os.Open(filePath)
|
||||
defer file.Close()
|
||||
part, _ := writer.CreateFormFile(tag, filepath.Base(filePath))
|
||||
io.Copy(part, file)
|
||||
|
||||
if purpose != "" {
|
||||
_ = writer.WriteField("purpose", purpose)
|
||||
}
|
||||
|
||||
writer.Close()
|
||||
return strings.NewReader(body.String()), writer
|
||||
}
|
||||
|
||||
// Helper to create test files
|
||||
func createTestFile(t *testing.T, name string, sizeMB int, option *config.ApplicationConfig) *os.File {
|
||||
err := os.MkdirAll(option.UploadDir, 0755)
|
||||
if err != nil {
|
||||
|
||||
t.Fatalf("Error MKDIR: %v", err)
|
||||
}
|
||||
|
||||
file, _ := os.Create(name)
|
||||
file.WriteString(strings.Repeat("a", sizeMB*1024*1024)) // sizeMB MB File
|
||||
|
||||
t.Cleanup(func() {
|
||||
os.Remove(name)
|
||||
os.RemoveAll(option.UploadDir)
|
||||
})
|
||||
return file
|
||||
}
|
||||
|
||||
func bodyToString(resp *http.Response, t *testing.T) string {
|
||||
return string(bodyToByteArray(resp, t))
|
||||
}
|
||||
|
||||
func bodyToByteArray(resp *http.Response, t *testing.T) []byte {
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return bodyBytes
|
||||
}
|
||||
|
||||
func responseToFile(t *testing.T, resp *http.Response) File {
|
||||
var file File
|
||||
responseToString := bodyToString(resp, t)
|
||||
|
||||
err := json.NewDecoder(strings.NewReader(responseToString)).Decode(&file)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to decode response: %s", err)
|
||||
}
|
||||
|
||||
return file
|
||||
}
|
||||
|
||||
func responseToListFile(t *testing.T, resp *http.Response) ListFiles {
|
||||
var listFiles ListFiles
|
||||
responseToString := bodyToString(resp, t)
|
||||
|
||||
err := json.NewDecoder(strings.NewReader(responseToString)).Decode(&listFiles)
|
||||
if err != nil {
|
||||
fmt.Printf("Failed to decode response: %s", err)
|
||||
}
|
||||
|
||||
return listFiles
|
||||
}
|
239
core/http/endpoints/openai/image.go
Normal file
239
core/http/endpoints/openai/image.go
Normal file
|
@ -0,0 +1,239 @@
|
|||
package openai
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-skynet/LocalAI/core/config"
|
||||
"github.com/go-skynet/LocalAI/core/schema"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/go-skynet/LocalAI/core/backend"
|
||||
|
||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
func downloadFile(url string) (string, error) {
|
||||
// Get the data
|
||||
resp, err := http.Get(url)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Create the file
|
||||
out, err := os.CreateTemp("", "image")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer out.Close()
|
||||
|
||||
// Write the body to file
|
||||
_, err = io.Copy(out, resp.Body)
|
||||
return out.Name(), err
|
||||
}
|
||||
|
||||
// https://platform.openai.com/docs/api-reference/images/create
|
||||
|
||||
/*
|
||||
*
|
||||
|
||||
curl http://localhost:8080/v1/images/generations \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"prompt": "A cute baby sea otter",
|
||||
"n": 1,
|
||||
"size": "512x512"
|
||||
}'
|
||||
|
||||
*
|
||||
*/
|
||||
func ImageEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
m, input, err := readRequest(c, ml, appConfig, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
|
||||
if m == "" {
|
||||
m = model.StableDiffusionBackend
|
||||
}
|
||||
log.Debug().Msgf("Loading model: %+v", m)
|
||||
|
||||
config, input, err := mergeRequestWithConfig(m, input, cl, ml, appConfig.Debug, 0, 0, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
|
||||
src := ""
|
||||
if input.File != "" {
|
||||
|
||||
fileData := []byte{}
|
||||
// check if input.File is an URL, if so download it and save it
|
||||
// to a temporary file
|
||||
if strings.HasPrefix(input.File, "http://") || strings.HasPrefix(input.File, "https://") {
|
||||
out, err := downloadFile(input.File)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed downloading file:%w", err)
|
||||
}
|
||||
defer os.RemoveAll(out)
|
||||
|
||||
fileData, err = os.ReadFile(out)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading file:%w", err)
|
||||
}
|
||||
|
||||
} else {
|
||||
// base 64 decode the file and write it somewhere
|
||||
// that we will cleanup
|
||||
fileData, err = base64.StdEncoding.DecodeString(input.File)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Create a temporary file
|
||||
outputFile, err := os.CreateTemp(appConfig.ImageDir, "b64")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// write the base64 result
|
||||
writer := bufio.NewWriter(outputFile)
|
||||
_, err = writer.Write(fileData)
|
||||
if err != nil {
|
||||
outputFile.Close()
|
||||
return err
|
||||
}
|
||||
outputFile.Close()
|
||||
src = outputFile.Name()
|
||||
defer os.RemoveAll(src)
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Parameter Config: %+v", config)
|
||||
|
||||
switch config.Backend {
|
||||
case "stablediffusion":
|
||||
config.Backend = model.StableDiffusionBackend
|
||||
case "tinydream":
|
||||
config.Backend = model.TinyDreamBackend
|
||||
case "":
|
||||
config.Backend = model.StableDiffusionBackend
|
||||
}
|
||||
|
||||
sizeParts := strings.Split(input.Size, "x")
|
||||
if len(sizeParts) != 2 {
|
||||
return fmt.Errorf("invalid value for 'size'")
|
||||
}
|
||||
width, err := strconv.Atoi(sizeParts[0])
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid value for 'size'")
|
||||
}
|
||||
height, err := strconv.Atoi(sizeParts[1])
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid value for 'size'")
|
||||
}
|
||||
|
||||
b64JSON := false
|
||||
if input.ResponseFormat.Type == "b64_json" {
|
||||
b64JSON = true
|
||||
}
|
||||
// src and clip_skip
|
||||
var result []schema.Item
|
||||
for _, i := range config.PromptStrings {
|
||||
n := input.N
|
||||
if input.N == 0 {
|
||||
n = 1
|
||||
}
|
||||
for j := 0; j < n; j++ {
|
||||
prompts := strings.Split(i, "|")
|
||||
positive_prompt := prompts[0]
|
||||
negative_prompt := ""
|
||||
if len(prompts) > 1 {
|
||||
negative_prompt = prompts[1]
|
||||
}
|
||||
|
||||
mode := 0
|
||||
step := config.Step
|
||||
if step == 0 {
|
||||
step = 15
|
||||
}
|
||||
|
||||
if input.Mode != 0 {
|
||||
mode = input.Mode
|
||||
}
|
||||
|
||||
if input.Step != 0 {
|
||||
step = input.Step
|
||||
}
|
||||
|
||||
tempDir := ""
|
||||
if !b64JSON {
|
||||
tempDir = appConfig.ImageDir
|
||||
}
|
||||
// Create a temporary file
|
||||
outputFile, err := os.CreateTemp(tempDir, "b64")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
outputFile.Close()
|
||||
output := outputFile.Name() + ".png"
|
||||
// Rename the temporary file
|
||||
err = os.Rename(outputFile.Name(), output)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
baseURL := c.BaseURL()
|
||||
|
||||
fn, err := backend.ImageGeneration(height, width, mode, step, input.Seed, positive_prompt, negative_prompt, src, output, ml, *config, appConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := fn(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
item := &schema.Item{}
|
||||
|
||||
if b64JSON {
|
||||
defer os.RemoveAll(output)
|
||||
data, err := os.ReadFile(output)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
item.B64JSON = base64.StdEncoding.EncodeToString(data)
|
||||
} else {
|
||||
base := filepath.Base(output)
|
||||
item.URL = baseURL + "/generated-images/" + base
|
||||
}
|
||||
|
||||
result = append(result, *item)
|
||||
}
|
||||
}
|
||||
|
||||
id := uuid.New().String()
|
||||
created := int(time.Now().Unix())
|
||||
resp := &schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Data: result,
|
||||
}
|
||||
|
||||
jsonResult, _ := json.Marshal(resp)
|
||||
log.Debug().Msgf("Response: %s", jsonResult)
|
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(resp)
|
||||
}
|
||||
}
|
55
core/http/endpoints/openai/inference.go
Normal file
55
core/http/endpoints/openai/inference.go
Normal file
|
@ -0,0 +1,55 @@
|
|||
package openai
|
||||
|
||||
import (
|
||||
"github.com/go-skynet/LocalAI/core/backend"
|
||||
"github.com/go-skynet/LocalAI/core/config"
|
||||
|
||||
"github.com/go-skynet/LocalAI/core/schema"
|
||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
func ComputeChoices(
|
||||
req *schema.OpenAIRequest,
|
||||
predInput string,
|
||||
config *config.BackendConfig,
|
||||
o *config.ApplicationConfig,
|
||||
loader *model.ModelLoader,
|
||||
cb func(string, *[]schema.Choice),
|
||||
tokenCallback func(string, backend.TokenUsage) bool) ([]schema.Choice, backend.TokenUsage, error) {
|
||||
n := req.N // number of completions to return
|
||||
result := []schema.Choice{}
|
||||
|
||||
if n == 0 {
|
||||
n = 1
|
||||
}
|
||||
|
||||
images := []string{}
|
||||
for _, m := range req.Messages {
|
||||
images = append(images, m.StringImages...)
|
||||
}
|
||||
|
||||
// get the model function to call for the result
|
||||
predFunc, err := backend.ModelInference(req.Context, predInput, images, loader, *config, o, tokenCallback)
|
||||
if err != nil {
|
||||
return result, backend.TokenUsage{}, err
|
||||
}
|
||||
|
||||
tokenUsage := backend.TokenUsage{}
|
||||
|
||||
for i := 0; i < n; i++ {
|
||||
prediction, err := predFunc()
|
||||
if err != nil {
|
||||
return result, backend.TokenUsage{}, err
|
||||
}
|
||||
|
||||
tokenUsage.Prompt += prediction.Usage.Prompt
|
||||
tokenUsage.Completion += prediction.Usage.Completion
|
||||
|
||||
finetunedResponse := backend.Finetune(*config, predInput, prediction.Response)
|
||||
cb(finetunedResponse, &result)
|
||||
|
||||
//result = append(result, Choice{Text: prediction})
|
||||
|
||||
}
|
||||
return result, tokenUsage, err
|
||||
}
|
69
core/http/endpoints/openai/list.go
Normal file
69
core/http/endpoints/openai/list.go
Normal file
|
@ -0,0 +1,69 @@
|
|||
package openai
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
|
||||
"github.com/go-skynet/LocalAI/core/config"
|
||||
"github.com/go-skynet/LocalAI/core/schema"
|
||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
func ListModelsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader) func(ctx *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
models, err := ml.ListModels()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var mm map[string]interface{} = map[string]interface{}{}
|
||||
|
||||
dataModels := []schema.OpenAIModel{}
|
||||
|
||||
var filterFn func(name string) bool
|
||||
filter := c.Query("filter")
|
||||
|
||||
// If filter is not specified, do not filter the list by model name
|
||||
if filter == "" {
|
||||
filterFn = func(_ string) bool { return true }
|
||||
} else {
|
||||
// If filter _IS_ specified, we compile it to a regex which is used to create the filterFn
|
||||
rxp, err := regexp.Compile(filter)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
filterFn = func(name string) bool {
|
||||
return rxp.MatchString(name)
|
||||
}
|
||||
}
|
||||
|
||||
// By default, exclude any loose files that are already referenced by a configuration file.
|
||||
excludeConfigured := c.QueryBool("excludeConfigured", true)
|
||||
|
||||
// Start with the known configurations
|
||||
for _, c := range cl.GetAllBackendConfigs() {
|
||||
if excludeConfigured {
|
||||
mm[c.Model] = nil
|
||||
}
|
||||
|
||||
if filterFn(c.Name) {
|
||||
dataModels = append(dataModels, schema.OpenAIModel{ID: c.Name, Object: "model"})
|
||||
}
|
||||
}
|
||||
|
||||
// Then iterate through the loose files:
|
||||
for _, m := range models {
|
||||
// And only adds them if they shouldn't be skipped.
|
||||
if _, exists := mm[m]; !exists && filterFn(m) {
|
||||
dataModels = append(dataModels, schema.OpenAIModel{ID: m, Object: "model"})
|
||||
}
|
||||
}
|
||||
|
||||
return c.JSON(struct {
|
||||
Object string `json:"object"`
|
||||
Data []schema.OpenAIModel `json:"data"`
|
||||
}{
|
||||
Object: "list",
|
||||
Data: dataModels,
|
||||
})
|
||||
}
|
||||
}
|
281
core/http/endpoints/openai/request.go
Normal file
281
core/http/endpoints/openai/request.go
Normal file
|
@ -0,0 +1,281 @@
|
|||
package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/go-skynet/LocalAI/core/config"
|
||||
fiberContext "github.com/go-skynet/LocalAI/core/http/ctx"
|
||||
"github.com/go-skynet/LocalAI/core/schema"
|
||||
"github.com/go-skynet/LocalAI/pkg/grammar"
|
||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
func readRequest(c *fiber.Ctx, ml *model.ModelLoader, o *config.ApplicationConfig, firstModel bool) (string, *schema.OpenAIRequest, error) {
|
||||
input := new(schema.OpenAIRequest)
|
||||
|
||||
// Get input data from the request body
|
||||
if err := c.BodyParser(input); err != nil {
|
||||
return "", nil, fmt.Errorf("failed parsing request body: %w", err)
|
||||
}
|
||||
|
||||
received, _ := json.Marshal(input)
|
||||
|
||||
ctx, cancel := context.WithCancel(o.Context)
|
||||
input.Context = ctx
|
||||
input.Cancel = cancel
|
||||
|
||||
log.Debug().Msgf("Request received: %s", string(received))
|
||||
|
||||
modelFile, err := fiberContext.ModelFromContext(c, ml, input.Model, firstModel)
|
||||
|
||||
return modelFile, input, err
|
||||
}
|
||||
|
||||
// this function check if the string is an URL, if it's an URL downloads the image in memory
|
||||
// encodes it in base64 and returns the base64 string
|
||||
func getBase64Image(s string) (string, error) {
|
||||
if strings.HasPrefix(s, "http") {
|
||||
// download the image
|
||||
resp, err := http.Get(s)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// read the image data into memory
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// encode the image data in base64
|
||||
encoded := base64.StdEncoding.EncodeToString(data)
|
||||
|
||||
// return the base64 string
|
||||
return encoded, nil
|
||||
}
|
||||
|
||||
// if the string instead is prefixed with "data:image/jpeg;base64,", drop it
|
||||
if strings.HasPrefix(s, "data:image/jpeg;base64,") {
|
||||
return strings.ReplaceAll(s, "data:image/jpeg;base64,", ""), nil
|
||||
}
|
||||
return "", fmt.Errorf("not valid string")
|
||||
}
|
||||
|
||||
func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIRequest) {
|
||||
if input.Echo {
|
||||
config.Echo = input.Echo
|
||||
}
|
||||
if input.TopK != 0 {
|
||||
config.TopK = input.TopK
|
||||
}
|
||||
if input.TopP != 0 {
|
||||
config.TopP = input.TopP
|
||||
}
|
||||
|
||||
if input.Backend != "" {
|
||||
config.Backend = input.Backend
|
||||
}
|
||||
|
||||
if input.ClipSkip != 0 {
|
||||
config.Diffusers.ClipSkip = input.ClipSkip
|
||||
}
|
||||
|
||||
if input.ModelBaseName != "" {
|
||||
config.AutoGPTQ.ModelBaseName = input.ModelBaseName
|
||||
}
|
||||
|
||||
if input.NegativePromptScale != 0 {
|
||||
config.NegativePromptScale = input.NegativePromptScale
|
||||
}
|
||||
|
||||
if input.UseFastTokenizer {
|
||||
config.UseFastTokenizer = input.UseFastTokenizer
|
||||
}
|
||||
|
||||
if input.NegativePrompt != "" {
|
||||
config.NegativePrompt = input.NegativePrompt
|
||||
}
|
||||
|
||||
if input.RopeFreqBase != 0 {
|
||||
config.RopeFreqBase = input.RopeFreqBase
|
||||
}
|
||||
|
||||
if input.RopeFreqScale != 0 {
|
||||
config.RopeFreqScale = input.RopeFreqScale
|
||||
}
|
||||
|
||||
if input.Grammar != "" {
|
||||
config.Grammar = input.Grammar
|
||||
}
|
||||
|
||||
if input.Temperature != 0 {
|
||||
config.Temperature = input.Temperature
|
||||
}
|
||||
|
||||
if input.Maxtokens != 0 {
|
||||
config.Maxtokens = input.Maxtokens
|
||||
}
|
||||
|
||||
switch stop := input.Stop.(type) {
|
||||
case string:
|
||||
if stop != "" {
|
||||
config.StopWords = append(config.StopWords, stop)
|
||||
}
|
||||
case []interface{}:
|
||||
for _, pp := range stop {
|
||||
if s, ok := pp.(string); ok {
|
||||
config.StopWords = append(config.StopWords, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(input.Tools) > 0 {
|
||||
for _, tool := range input.Tools {
|
||||
input.Functions = append(input.Functions, tool.Function)
|
||||
}
|
||||
}
|
||||
|
||||
if input.ToolsChoice != nil {
|
||||
var toolChoice grammar.Tool
|
||||
json.Unmarshal([]byte(input.ToolsChoice.(string)), &toolChoice)
|
||||
input.FunctionCall = map[string]interface{}{
|
||||
"name": toolChoice.Function.Name,
|
||||
}
|
||||
}
|
||||
|
||||
// Decode each request's message content
|
||||
index := 0
|
||||
for i, m := range input.Messages {
|
||||
switch content := m.Content.(type) {
|
||||
case string:
|
||||
input.Messages[i].StringContent = content
|
||||
case []interface{}:
|
||||
dat, _ := json.Marshal(content)
|
||||
c := []schema.Content{}
|
||||
json.Unmarshal(dat, &c)
|
||||
for _, pp := range c {
|
||||
if pp.Type == "text" {
|
||||
input.Messages[i].StringContent = pp.Text
|
||||
} else if pp.Type == "image_url" {
|
||||
// Detect if pp.ImageURL is an URL, if it is download the image and encode it in base64:
|
||||
base64, err := getBase64Image(pp.ImageURL.URL)
|
||||
if err == nil {
|
||||
input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff
|
||||
// set a placeholder for each image
|
||||
input.Messages[i].StringContent = fmt.Sprintf("[img-%d]", index) + input.Messages[i].StringContent
|
||||
index++
|
||||
} else {
|
||||
fmt.Print("Failed encoding image", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if input.RepeatPenalty != 0 {
|
||||
config.RepeatPenalty = input.RepeatPenalty
|
||||
}
|
||||
|
||||
if input.Keep != 0 {
|
||||
config.Keep = input.Keep
|
||||
}
|
||||
|
||||
if input.Batch != 0 {
|
||||
config.Batch = input.Batch
|
||||
}
|
||||
|
||||
if input.F16 {
|
||||
config.F16 = input.F16
|
||||
}
|
||||
|
||||
if input.IgnoreEOS {
|
||||
config.IgnoreEOS = input.IgnoreEOS
|
||||
}
|
||||
|
||||
if input.Seed != 0 {
|
||||
config.Seed = input.Seed
|
||||
}
|
||||
|
||||
if input.Mirostat != 0 {
|
||||
config.LLMConfig.Mirostat = input.Mirostat
|
||||
}
|
||||
|
||||
if input.MirostatETA != 0 {
|
||||
config.LLMConfig.MirostatETA = input.MirostatETA
|
||||
}
|
||||
|
||||
if input.MirostatTAU != 0 {
|
||||
config.LLMConfig.MirostatTAU = input.MirostatTAU
|
||||
}
|
||||
|
||||
if input.TypicalP != 0 {
|
||||
config.TypicalP = input.TypicalP
|
||||
}
|
||||
|
||||
switch inputs := input.Input.(type) {
|
||||
case string:
|
||||
if inputs != "" {
|
||||
config.InputStrings = append(config.InputStrings, inputs)
|
||||
}
|
||||
case []interface{}:
|
||||
for _, pp := range inputs {
|
||||
switch i := pp.(type) {
|
||||
case string:
|
||||
config.InputStrings = append(config.InputStrings, i)
|
||||
case []interface{}:
|
||||
tokens := []int{}
|
||||
for _, ii := range i {
|
||||
tokens = append(tokens, int(ii.(float64)))
|
||||
}
|
||||
config.InputToken = append(config.InputToken, tokens)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Can be either a string or an object
|
||||
switch fnc := input.FunctionCall.(type) {
|
||||
case string:
|
||||
if fnc != "" {
|
||||
config.SetFunctionCallString(fnc)
|
||||
}
|
||||
case map[string]interface{}:
|
||||
var name string
|
||||
n, exists := fnc["name"]
|
||||
if exists {
|
||||
nn, e := n.(string)
|
||||
if e {
|
||||
name = nn
|
||||
}
|
||||
}
|
||||
config.SetFunctionCallNameString(name)
|
||||
}
|
||||
|
||||
switch p := input.Prompt.(type) {
|
||||
case string:
|
||||
config.PromptStrings = append(config.PromptStrings, p)
|
||||
case []interface{}:
|
||||
for _, pp := range p {
|
||||
if s, ok := pp.(string); ok {
|
||||
config.PromptStrings = append(config.PromptStrings, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func mergeRequestWithConfig(modelFile string, input *schema.OpenAIRequest, cm *config.BackendConfigLoader, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*config.BackendConfig, *schema.OpenAIRequest, error) {
|
||||
cfg, err := config.LoadBackendConfigFileByName(modelFile, loader.ModelPath, cm, debug, threads, ctx, f16)
|
||||
|
||||
// Set the parameters for the language model prediction
|
||||
updateRequestConfig(cfg, input)
|
||||
|
||||
return cfg, input, err
|
||||
}
|
71
core/http/endpoints/openai/transcription.go
Normal file
71
core/http/endpoints/openai/transcription.go
Normal file
|
@ -0,0 +1,71 @@
|
|||
package openai
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/go-skynet/LocalAI/core/backend"
|
||||
"github.com/go-skynet/LocalAI/core/config"
|
||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// https://platform.openai.com/docs/api-reference/audio/create
|
||||
func TranscriptEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
m, input, err := readRequest(c, ml, appConfig, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
|
||||
config, input, err := mergeRequestWithConfig(m, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
// retrieve the file data from the request
|
||||
file, err := c.FormFile("file")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
f, err := file.Open()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
dir, err := os.MkdirTemp("", "whisper")
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
dst := filepath.Join(dir, path.Base(file.Filename))
|
||||
dstFile, err := os.Create(dst)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := io.Copy(dstFile, f); err != nil {
|
||||
log.Debug().Msgf("Audio file copying error %+v - %+v - err %+v", file.Filename, dst, err)
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Audio file copied to: %+v", dst)
|
||||
|
||||
tr, err := backend.ModelTranscription(dst, input.Language, ml, *config, appConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Trascribed: %+v", tr)
|
||||
// TODO: handle different outputs here
|
||||
return c.Status(http.StatusOK).JSON(tr)
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue