mirror of
https://github.com/mudler/LocalAI.git
synced 2025-05-20 10:35:01 +00:00
feat: various refactorings
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
parent
f2f1d7fe72
commit
5dcfdbe51d
28 changed files with 2130 additions and 1933 deletions
108
api/api.go
108
api/api.go
|
@ -3,8 +3,13 @@ package api
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
|
||||||
|
config "github.com/go-skynet/LocalAI/api/config"
|
||||||
|
"github.com/go-skynet/LocalAI/api/localai"
|
||||||
|
"github.com/go-skynet/LocalAI/api/openai"
|
||||||
|
"github.com/go-skynet/LocalAI/api/options"
|
||||||
"github.com/go-skynet/LocalAI/internal"
|
"github.com/go-skynet/LocalAI/internal"
|
||||||
"github.com/go-skynet/LocalAI/pkg/assets"
|
"github.com/go-skynet/LocalAI/pkg/assets"
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
"github.com/gofiber/fiber/v2/middleware/cors"
|
"github.com/gofiber/fiber/v2/middleware/cors"
|
||||||
"github.com/gofiber/fiber/v2/middleware/logger"
|
"github.com/gofiber/fiber/v2/middleware/logger"
|
||||||
|
@ -13,18 +18,18 @@ import (
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
func App(opts ...AppOption) (*fiber.App, error) {
|
func App(opts ...options.AppOption) (*fiber.App, error) {
|
||||||
options := newOptions(opts...)
|
options := options.NewOptions(opts...)
|
||||||
|
|
||||||
zerolog.SetGlobalLevel(zerolog.InfoLevel)
|
zerolog.SetGlobalLevel(zerolog.InfoLevel)
|
||||||
if options.debug {
|
if options.Debug {
|
||||||
zerolog.SetGlobalLevel(zerolog.DebugLevel)
|
zerolog.SetGlobalLevel(zerolog.DebugLevel)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return errors as JSON responses
|
// Return errors as JSON responses
|
||||||
app := fiber.New(fiber.Config{
|
app := fiber.New(fiber.Config{
|
||||||
BodyLimit: options.uploadLimitMB * 1024 * 1024, // this is the default limit of 4MB
|
BodyLimit: options.UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB
|
||||||
DisableStartupMessage: options.disableMessage,
|
DisableStartupMessage: options.DisableMessage,
|
||||||
// Override default error handler
|
// Override default error handler
|
||||||
ErrorHandler: func(ctx *fiber.Ctx, err error) error {
|
ErrorHandler: func(ctx *fiber.Ctx, err error) error {
|
||||||
// Status code defaults to 500
|
// Status code defaults to 500
|
||||||
|
@ -38,44 +43,44 @@ func App(opts ...AppOption) (*fiber.App, error) {
|
||||||
|
|
||||||
// Send custom error page
|
// Send custom error page
|
||||||
return ctx.Status(code).JSON(
|
return ctx.Status(code).JSON(
|
||||||
ErrorResponse{
|
openai.ErrorResponse{
|
||||||
Error: &APIError{Message: err.Error(), Code: code},
|
Error: &openai.APIError{Message: err.Error(), Code: code},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
if options.debug {
|
if options.Debug {
|
||||||
app.Use(logger.New(logger.Config{
|
app.Use(logger.New(logger.Config{
|
||||||
Format: "[${ip}]:${port} ${status} - ${method} ${path}\n",
|
Format: "[${ip}]:${port} ${status} - ${method} ${path}\n",
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Info().Msgf("Starting LocalAI using %d threads, with models path: %s", options.threads, options.loader.ModelPath)
|
log.Info().Msgf("Starting LocalAI using %d threads, with models path: %s", options.Threads, options.Loader.ModelPath)
|
||||||
log.Info().Msgf("LocalAI version: %s", internal.PrintableVersion())
|
log.Info().Msgf("LocalAI version: %s", internal.PrintableVersion())
|
||||||
|
|
||||||
cm := NewConfigMerger()
|
cm := config.NewConfigLoader()
|
||||||
if err := cm.LoadConfigs(options.loader.ModelPath); err != nil {
|
if err := cm.LoadConfigs(options.Loader.ModelPath); err != nil {
|
||||||
log.Error().Msgf("error loading config files: %s", err.Error())
|
log.Error().Msgf("error loading config files: %s", err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
if options.configFile != "" {
|
if options.ConfigFile != "" {
|
||||||
if err := cm.LoadConfigFile(options.configFile); err != nil {
|
if err := cm.LoadConfigFile(options.ConfigFile); err != nil {
|
||||||
log.Error().Msgf("error loading config file: %s", err.Error())
|
log.Error().Msgf("error loading config file: %s", err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if options.debug {
|
if options.Debug {
|
||||||
for _, v := range cm.ListConfigs() {
|
for _, v := range cm.ListConfigs() {
|
||||||
cfg, _ := cm.GetConfig(v)
|
cfg, _ := cm.GetConfig(v)
|
||||||
log.Debug().Msgf("Model: %s (config: %+v)", v, cfg)
|
log.Debug().Msgf("Model: %s (config: %+v)", v, cfg)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if options.assetsDestination != "" {
|
if options.AssetsDestination != "" {
|
||||||
// Extract files from the embedded FS
|
// Extract files from the embedded FS
|
||||||
err := assets.ExtractFiles(options.backendAssets, options.assetsDestination)
|
err := assets.ExtractFiles(options.BackendAssets, options.AssetsDestination)
|
||||||
log.Debug().Msgf("Extracting backend assets files to %s", options.assetsDestination)
|
log.Debug().Msgf("Extracting backend assets files to %s", options.AssetsDestination)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warn().Msgf("Failed extracting backend assets files: %s (might be required for some backends to work properly, like gpt4all)", err)
|
log.Warn().Msgf("Failed extracting backend assets files: %s (might be required for some backends to work properly, like gpt4all)", err)
|
||||||
}
|
}
|
||||||
|
@ -84,31 +89,32 @@ func App(opts ...AppOption) (*fiber.App, error) {
|
||||||
// Default middleware config
|
// Default middleware config
|
||||||
app.Use(recover.New())
|
app.Use(recover.New())
|
||||||
|
|
||||||
if options.preloadJSONModels != "" {
|
if options.PreloadJSONModels != "" {
|
||||||
if err := ApplyGalleryFromString(options.loader.ModelPath, options.preloadJSONModels, cm, options.galleries); err != nil {
|
if err := localai.ApplyGalleryFromString(options.Loader.ModelPath, options.PreloadJSONModels, cm, options.Galleries); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if options.preloadModelsFromPath != "" {
|
if options.PreloadModelsFromPath != "" {
|
||||||
if err := ApplyGalleryFromFile(options.loader.ModelPath, options.preloadModelsFromPath, cm, options.galleries); err != nil {
|
if err := localai.ApplyGalleryFromFile(options.Loader.ModelPath, options.PreloadModelsFromPath, cm, options.Galleries); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if options.cors {
|
if options.CORS {
|
||||||
if options.corsAllowOrigins == "" {
|
var c func(ctx *fiber.Ctx) error
|
||||||
app.Use(cors.New())
|
if options.CORSAllowOrigins == "" {
|
||||||
|
c = cors.New()
|
||||||
} else {
|
} else {
|
||||||
app.Use(cors.New(cors.Config{
|
c = cors.New(cors.Config{AllowOrigins: options.CORSAllowOrigins})
|
||||||
AllowOrigins: options.corsAllowOrigins,
|
|
||||||
}))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
app.Use(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
// LocalAI API endpoints
|
// LocalAI API endpoints
|
||||||
applier := newGalleryApplier(options.loader.ModelPath)
|
galleryService := localai.NewGalleryService(options.Loader.ModelPath)
|
||||||
applier.start(options.context, cm)
|
galleryService.Start(options.Context, cm)
|
||||||
|
|
||||||
app.Get("/version", func(c *fiber.Ctx) error {
|
app.Get("/version", func(c *fiber.Ctx) error {
|
||||||
return c.JSON(struct {
|
return c.JSON(struct {
|
||||||
|
@ -116,43 +122,43 @@ func App(opts ...AppOption) (*fiber.App, error) {
|
||||||
}{Version: internal.PrintableVersion()})
|
}{Version: internal.PrintableVersion()})
|
||||||
})
|
})
|
||||||
|
|
||||||
app.Post("/models/apply", applyModelGallery(options.loader.ModelPath, cm, applier.C, options.galleries))
|
app.Post("/models/apply", localai.ApplyModelGalleryEndpoint(options.Loader.ModelPath, cm, galleryService.C, options.Galleries))
|
||||||
app.Get("/models/available", listModelFromGallery(options.galleries, options.loader.ModelPath))
|
app.Get("/models/available", localai.ListModelFromGalleryEndpoint(options.Galleries, options.Loader.ModelPath))
|
||||||
app.Get("/models/jobs/:uuid", getOpStatus(applier))
|
app.Get("/models/jobs/:uuid", localai.GetOpStatusEndpoint(galleryService))
|
||||||
|
|
||||||
// openAI compatible API endpoint
|
// openAI compatible API endpoint
|
||||||
|
|
||||||
// chat
|
// chat
|
||||||
app.Post("/v1/chat/completions", chatEndpoint(cm, options))
|
app.Post("/v1/chat/completions", openai.ChatEndpoint(cm, options))
|
||||||
app.Post("/chat/completions", chatEndpoint(cm, options))
|
app.Post("/chat/completions", openai.ChatEndpoint(cm, options))
|
||||||
|
|
||||||
// edit
|
// edit
|
||||||
app.Post("/v1/edits", editEndpoint(cm, options))
|
app.Post("/v1/edits", openai.EditEndpoint(cm, options))
|
||||||
app.Post("/edits", editEndpoint(cm, options))
|
app.Post("/edits", openai.EditEndpoint(cm, options))
|
||||||
|
|
||||||
// completion
|
// completion
|
||||||
app.Post("/v1/completions", completionEndpoint(cm, options))
|
app.Post("/v1/completions", openai.CompletionEndpoint(cm, options))
|
||||||
app.Post("/completions", completionEndpoint(cm, options))
|
app.Post("/completions", openai.CompletionEndpoint(cm, options))
|
||||||
app.Post("/v1/engines/:model/completions", completionEndpoint(cm, options))
|
app.Post("/v1/engines/:model/completions", openai.CompletionEndpoint(cm, options))
|
||||||
|
|
||||||
// embeddings
|
// embeddings
|
||||||
app.Post("/v1/embeddings", embeddingsEndpoint(cm, options))
|
app.Post("/v1/embeddings", openai.EmbeddingsEndpoint(cm, options))
|
||||||
app.Post("/embeddings", embeddingsEndpoint(cm, options))
|
app.Post("/embeddings", openai.EmbeddingsEndpoint(cm, options))
|
||||||
app.Post("/v1/engines/:model/embeddings", embeddingsEndpoint(cm, options))
|
app.Post("/v1/engines/:model/embeddings", openai.EmbeddingsEndpoint(cm, options))
|
||||||
|
|
||||||
// audio
|
// audio
|
||||||
app.Post("/v1/audio/transcriptions", transcriptEndpoint(cm, options))
|
app.Post("/v1/audio/transcriptions", openai.TranscriptEndpoint(cm, options))
|
||||||
app.Post("/tts", ttsEndpoint(cm, options))
|
app.Post("/tts", localai.TTSEndpoint(cm, options))
|
||||||
|
|
||||||
// images
|
// images
|
||||||
app.Post("/v1/images/generations", imageEndpoint(cm, options))
|
app.Post("/v1/images/generations", openai.ImageEndpoint(cm, options))
|
||||||
|
|
||||||
if options.imageDir != "" {
|
if options.ImageDir != "" {
|
||||||
app.Static("/generated-images", options.imageDir)
|
app.Static("/generated-images", options.ImageDir)
|
||||||
}
|
}
|
||||||
|
|
||||||
if options.audioDir != "" {
|
if options.AudioDir != "" {
|
||||||
app.Static("/generated-audio", options.audioDir)
|
app.Static("/generated-audio", options.AudioDir)
|
||||||
}
|
}
|
||||||
|
|
||||||
ok := func(c *fiber.Ctx) error {
|
ok := func(c *fiber.Ctx) error {
|
||||||
|
@ -164,8 +170,8 @@ func App(opts ...AppOption) (*fiber.App, error) {
|
||||||
app.Get("/readyz", ok)
|
app.Get("/readyz", ok)
|
||||||
|
|
||||||
// models
|
// models
|
||||||
app.Get("/v1/models", listModels(options.loader, cm))
|
app.Get("/v1/models", openai.ListModelsEndpoint(options.Loader, cm))
|
||||||
app.Get("/models", listModels(options.loader, cm))
|
app.Get("/models", openai.ListModelsEndpoint(options.Loader, cm))
|
||||||
|
|
||||||
return app, nil
|
return app, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,6 +13,7 @@ import (
|
||||||
"runtime"
|
"runtime"
|
||||||
|
|
||||||
. "github.com/go-skynet/LocalAI/api"
|
. "github.com/go-skynet/LocalAI/api"
|
||||||
|
"github.com/go-skynet/LocalAI/api/options"
|
||||||
"github.com/go-skynet/LocalAI/pkg/gallery"
|
"github.com/go-skynet/LocalAI/pkg/gallery"
|
||||||
"github.com/go-skynet/LocalAI/pkg/model"
|
"github.com/go-skynet/LocalAI/pkg/model"
|
||||||
"github.com/go-skynet/LocalAI/pkg/utils"
|
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||||
|
@ -154,9 +155,10 @@ var _ = Describe("API test", func() {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
app, err = App(WithContext(c),
|
app, err = App(
|
||||||
WithGalleries(galleries),
|
options.WithContext(c),
|
||||||
WithModelLoader(modelLoader), WithBackendAssets(backendAssets), WithBackendAssetsOutput(tmpdir))
|
options.WithGalleries(galleries),
|
||||||
|
options.WithModelLoader(modelLoader), options.WithBackendAssets(backendAssets), options.WithBackendAssetsOutput(tmpdir))
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
go app.Listen("127.0.0.1:9090")
|
go app.Listen("127.0.0.1:9090")
|
||||||
|
|
||||||
|
@ -342,7 +344,7 @@ var _ = Describe("API test", func() {
|
||||||
c, cancel = context.WithCancel(context.Background())
|
c, cancel = context.WithCancel(context.Background())
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
app, err = App(WithContext(c), WithModelLoader(modelLoader))
|
app, err = App(options.WithContext(c), options.WithModelLoader(modelLoader))
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
go app.Listen("127.0.0.1:9090")
|
go app.Listen("127.0.0.1:9090")
|
||||||
|
|
||||||
|
@ -462,7 +464,7 @@ var _ = Describe("API test", func() {
|
||||||
c, cancel = context.WithCancel(context.Background())
|
c, cancel = context.WithCancel(context.Background())
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
app, err = App(WithContext(c), WithModelLoader(modelLoader), WithConfigFile(os.Getenv("CONFIG_FILE")))
|
app, err = App(options.WithContext(c), options.WithModelLoader(modelLoader), options.WithConfigFile(os.Getenv("CONFIG_FILE")))
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
go app.Listen("127.0.0.1:9090")
|
go app.Listen("127.0.0.1:9090")
|
||||||
|
|
||||||
|
|
107
api/backend/embeddings.go
Normal file
107
api/backend/embeddings.go
Normal file
|
@ -0,0 +1,107 @@
|
||||||
|
package backend
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
config "github.com/go-skynet/LocalAI/api/config"
|
||||||
|
"github.com/go-skynet/LocalAI/api/options"
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/grpc"
|
||||||
|
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||||
|
bert "github.com/go-skynet/go-bert.cpp"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.Config, o *options.Option) (func() ([]float32, error), error) {
|
||||||
|
if !c.Embeddings {
|
||||||
|
return nil, fmt.Errorf("endpoint disabled for this model by API configuration")
|
||||||
|
}
|
||||||
|
|
||||||
|
modelFile := c.Model
|
||||||
|
|
||||||
|
grpcOpts := gRPCModelOpts(c)
|
||||||
|
|
||||||
|
var inferenceModel interface{}
|
||||||
|
var err error
|
||||||
|
|
||||||
|
opts := []model.Option{
|
||||||
|
model.WithLoadGRPCOpts(grpcOpts),
|
||||||
|
model.WithThreads(uint32(c.Threads)),
|
||||||
|
model.WithAssetDir(o.AssetsDestination),
|
||||||
|
model.WithModelFile(modelFile),
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.Backend == "" {
|
||||||
|
inferenceModel, err = loader.GreedyLoader(opts...)
|
||||||
|
} else {
|
||||||
|
opts = append(opts, model.WithBackendString(c.Backend))
|
||||||
|
inferenceModel, err = loader.BackendLoader(opts...)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var fn func() ([]float32, error)
|
||||||
|
switch model := inferenceModel.(type) {
|
||||||
|
case *grpc.Client:
|
||||||
|
fn = func() ([]float32, error) {
|
||||||
|
predictOptions := gRPCPredictOpts(c, loader.ModelPath)
|
||||||
|
if len(tokens) > 0 {
|
||||||
|
embeds := []int32{}
|
||||||
|
|
||||||
|
for _, t := range tokens {
|
||||||
|
embeds = append(embeds, int32(t))
|
||||||
|
}
|
||||||
|
predictOptions.EmbeddingTokens = embeds
|
||||||
|
|
||||||
|
res, err := model.Embeddings(context.TODO(), predictOptions)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return res.Embeddings, nil
|
||||||
|
}
|
||||||
|
predictOptions.Embeddings = s
|
||||||
|
|
||||||
|
res, err := model.Embeddings(context.TODO(), predictOptions)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return res.Embeddings, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// bert embeddings
|
||||||
|
case *bert.Bert:
|
||||||
|
fn = func() ([]float32, error) {
|
||||||
|
if len(tokens) > 0 {
|
||||||
|
return model.TokenEmbeddings(tokens, bert.SetThreads(c.Threads))
|
||||||
|
}
|
||||||
|
return model.Embeddings(s, bert.SetThreads(c.Threads))
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
fn = func() ([]float32, error) {
|
||||||
|
return nil, fmt.Errorf("embeddings not supported by the backend")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return func() ([]float32, error) {
|
||||||
|
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
|
||||||
|
l := Lock(modelFile)
|
||||||
|
defer l.Unlock()
|
||||||
|
|
||||||
|
embeds, err := fn()
|
||||||
|
if err != nil {
|
||||||
|
return embeds, err
|
||||||
|
}
|
||||||
|
// Remove trailing 0s
|
||||||
|
for i := len(embeds) - 1; i >= 0; i-- {
|
||||||
|
if embeds[i] == 0.0 {
|
||||||
|
embeds = embeds[:i]
|
||||||
|
} else {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return embeds, nil
|
||||||
|
}, nil
|
||||||
|
}
|
56
api/backend/image.go
Normal file
56
api/backend/image.go
Normal file
|
@ -0,0 +1,56 @@
|
||||||
|
package backend
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
config "github.com/go-skynet/LocalAI/api/config"
|
||||||
|
"github.com/go-skynet/LocalAI/api/options"
|
||||||
|
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/stablediffusion"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, dst string, loader *model.ModelLoader, c config.Config, o *options.Option) (func() error, error) {
|
||||||
|
if c.Backend != model.StableDiffusionBackend {
|
||||||
|
return nil, fmt.Errorf("endpoint only working with stablediffusion models")
|
||||||
|
}
|
||||||
|
|
||||||
|
inferenceModel, err := loader.BackendLoader(
|
||||||
|
model.WithBackendString(c.Backend),
|
||||||
|
model.WithAssetDir(o.AssetsDestination),
|
||||||
|
model.WithThreads(uint32(c.Threads)),
|
||||||
|
model.WithModelFile(c.ImageGenerationAssets),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var fn func() error
|
||||||
|
switch model := inferenceModel.(type) {
|
||||||
|
case *stablediffusion.StableDiffusion:
|
||||||
|
fn = func() error {
|
||||||
|
return model.GenerateImage(height, width, mode, step, seed, positive_prompt, negative_prompt, dst)
|
||||||
|
}
|
||||||
|
|
||||||
|
default:
|
||||||
|
fn = func() error {
|
||||||
|
return fmt.Errorf("creation of images not supported by the backend")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return func() error {
|
||||||
|
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
|
||||||
|
mutexMap.Lock()
|
||||||
|
l, ok := mutexes[c.Backend]
|
||||||
|
if !ok {
|
||||||
|
m := &sync.Mutex{}
|
||||||
|
mutexes[c.Backend] = m
|
||||||
|
l = m
|
||||||
|
}
|
||||||
|
mutexMap.Unlock()
|
||||||
|
l.Lock()
|
||||||
|
defer l.Unlock()
|
||||||
|
|
||||||
|
return fn()
|
||||||
|
}, nil
|
||||||
|
}
|
160
api/backend/llm.go
Normal file
160
api/backend/llm.go
Normal file
|
@ -0,0 +1,160 @@
|
||||||
|
package backend
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/donomii/go-rwkv.cpp"
|
||||||
|
config "github.com/go-skynet/LocalAI/api/config"
|
||||||
|
"github.com/go-skynet/LocalAI/api/options"
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/grpc"
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/langchain"
|
||||||
|
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||||
|
"github.com/go-skynet/bloomz.cpp"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ModelInference(s string, loader *model.ModelLoader, c config.Config, o *options.Option, tokenCallback func(string) bool) (func() (string, error), error) {
|
||||||
|
supportStreams := false
|
||||||
|
modelFile := c.Model
|
||||||
|
|
||||||
|
grpcOpts := gRPCModelOpts(c)
|
||||||
|
|
||||||
|
var inferenceModel interface{}
|
||||||
|
var err error
|
||||||
|
|
||||||
|
opts := []model.Option{
|
||||||
|
model.WithLoadGRPCOpts(grpcOpts),
|
||||||
|
model.WithThreads(uint32(c.Threads)), // GPT4all uses this
|
||||||
|
model.WithAssetDir(o.AssetsDestination),
|
||||||
|
model.WithModelFile(modelFile),
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.Backend == "" {
|
||||||
|
inferenceModel, err = loader.GreedyLoader(opts...)
|
||||||
|
} else {
|
||||||
|
opts = append(opts, model.WithBackendString(c.Backend))
|
||||||
|
inferenceModel, err = loader.BackendLoader(opts...)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var fn func() (string, error)
|
||||||
|
|
||||||
|
switch model := inferenceModel.(type) {
|
||||||
|
case *rwkv.RwkvState:
|
||||||
|
supportStreams = true
|
||||||
|
|
||||||
|
fn = func() (string, error) {
|
||||||
|
stopWord := "\n"
|
||||||
|
if len(c.StopWords) > 0 {
|
||||||
|
stopWord = c.StopWords[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := model.ProcessInput(s); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
response := model.GenerateResponse(c.Maxtokens, stopWord, float32(c.Temperature), float32(c.TopP), tokenCallback)
|
||||||
|
|
||||||
|
return response, nil
|
||||||
|
}
|
||||||
|
case *bloomz.Bloomz:
|
||||||
|
fn = func() (string, error) {
|
||||||
|
// Generate the prediction using the language model
|
||||||
|
predictOptions := []bloomz.PredictOption{
|
||||||
|
bloomz.SetTemperature(c.Temperature),
|
||||||
|
bloomz.SetTopP(c.TopP),
|
||||||
|
bloomz.SetTopK(c.TopK),
|
||||||
|
bloomz.SetTokens(c.Maxtokens),
|
||||||
|
bloomz.SetThreads(c.Threads),
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.Seed != 0 {
|
||||||
|
predictOptions = append(predictOptions, bloomz.SetSeed(c.Seed))
|
||||||
|
}
|
||||||
|
|
||||||
|
return model.Predict(
|
||||||
|
s,
|
||||||
|
predictOptions...,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
case *grpc.Client:
|
||||||
|
// in GRPC, the backend is supposed to answer to 1 single token if stream is not supported
|
||||||
|
supportStreams = true
|
||||||
|
fn = func() (string, error) {
|
||||||
|
|
||||||
|
opts := gRPCPredictOpts(c, loader.ModelPath)
|
||||||
|
opts.Prompt = s
|
||||||
|
if tokenCallback != nil {
|
||||||
|
ss := ""
|
||||||
|
err := model.PredictStream(context.TODO(), opts, func(s string) {
|
||||||
|
tokenCallback(s)
|
||||||
|
ss += s
|
||||||
|
})
|
||||||
|
return ss, err
|
||||||
|
} else {
|
||||||
|
reply, err := model.Predict(context.TODO(), opts)
|
||||||
|
return reply.Message, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case *langchain.HuggingFace:
|
||||||
|
fn = func() (string, error) {
|
||||||
|
|
||||||
|
// Generate the prediction using the language model
|
||||||
|
predictOptions := []langchain.PredictOption{
|
||||||
|
langchain.SetModel(c.Model),
|
||||||
|
langchain.SetMaxTokens(c.Maxtokens),
|
||||||
|
langchain.SetTemperature(c.Temperature),
|
||||||
|
langchain.SetStopWords(c.StopWords),
|
||||||
|
}
|
||||||
|
|
||||||
|
pred, er := model.PredictHuggingFace(s, predictOptions...)
|
||||||
|
if er != nil {
|
||||||
|
return "", er
|
||||||
|
}
|
||||||
|
return pred.Completion, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return func() (string, error) {
|
||||||
|
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
|
||||||
|
l := Lock(modelFile)
|
||||||
|
defer l.Unlock()
|
||||||
|
|
||||||
|
res, err := fn()
|
||||||
|
if tokenCallback != nil && !supportStreams {
|
||||||
|
tokenCallback(res)
|
||||||
|
}
|
||||||
|
return res, err
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var cutstrings map[string]*regexp.Regexp = make(map[string]*regexp.Regexp)
|
||||||
|
var mu sync.Mutex = sync.Mutex{}
|
||||||
|
|
||||||
|
func Finetune(config config.Config, input, prediction string) string {
|
||||||
|
if config.Echo {
|
||||||
|
prediction = input + prediction
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, c := range config.Cutstrings {
|
||||||
|
mu.Lock()
|
||||||
|
reg, ok := cutstrings[c]
|
||||||
|
if !ok {
|
||||||
|
cutstrings[c] = regexp.MustCompile(c)
|
||||||
|
reg = cutstrings[c]
|
||||||
|
}
|
||||||
|
mu.Unlock()
|
||||||
|
prediction = reg.ReplaceAllString(prediction, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, c := range config.TrimSpace {
|
||||||
|
prediction = strings.TrimSpace(strings.TrimPrefix(prediction, c))
|
||||||
|
}
|
||||||
|
return prediction
|
||||||
|
|
||||||
|
}
|
22
api/backend/lock.go
Normal file
22
api/backend/lock.go
Normal file
|
@ -0,0 +1,22 @@
|
||||||
|
package backend
|
||||||
|
|
||||||
|
import "sync"
|
||||||
|
|
||||||
|
// mutex still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
|
||||||
|
var mutexMap sync.Mutex
|
||||||
|
var mutexes map[string]*sync.Mutex = make(map[string]*sync.Mutex)
|
||||||
|
|
||||||
|
func Lock(s string) *sync.Mutex {
|
||||||
|
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
|
||||||
|
mutexMap.Lock()
|
||||||
|
l, ok := mutexes[s]
|
||||||
|
if !ok {
|
||||||
|
m := &sync.Mutex{}
|
||||||
|
mutexes[s] = m
|
||||||
|
l = m
|
||||||
|
}
|
||||||
|
mutexMap.Unlock()
|
||||||
|
l.Lock()
|
||||||
|
|
||||||
|
return l
|
||||||
|
}
|
98
api/backend/options.go
Normal file
98
api/backend/options.go
Normal file
|
@ -0,0 +1,98 @@
|
||||||
|
package backend
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||||
|
|
||||||
|
config "github.com/go-skynet/LocalAI/api/config"
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/langchain"
|
||||||
|
"github.com/go-skynet/bloomz.cpp"
|
||||||
|
)
|
||||||
|
|
||||||
|
func langchainOptions(c config.Config) []langchain.PredictOption {
|
||||||
|
return []langchain.PredictOption{
|
||||||
|
langchain.SetModel(c.Model),
|
||||||
|
langchain.SetMaxTokens(c.Maxtokens),
|
||||||
|
langchain.SetTemperature(c.Temperature),
|
||||||
|
langchain.SetStopWords(c.StopWords),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func bloomzOptions(c config.Config) []bloomz.PredictOption {
|
||||||
|
// Generate the prediction using the language model
|
||||||
|
predictOptions := []bloomz.PredictOption{
|
||||||
|
bloomz.SetTemperature(c.Temperature),
|
||||||
|
bloomz.SetTopP(c.TopP),
|
||||||
|
bloomz.SetTopK(c.TopK),
|
||||||
|
bloomz.SetTokens(c.Maxtokens),
|
||||||
|
bloomz.SetThreads(c.Threads),
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.Seed != 0 {
|
||||||
|
predictOptions = append(predictOptions, bloomz.SetSeed(c.Seed))
|
||||||
|
}
|
||||||
|
return predictOptions
|
||||||
|
}
|
||||||
|
func gRPCModelOpts(c config.Config) *pb.ModelOptions {
|
||||||
|
b := 512
|
||||||
|
if c.Batch != 0 {
|
||||||
|
b = c.Batch
|
||||||
|
}
|
||||||
|
return &pb.ModelOptions{
|
||||||
|
ContextSize: int32(c.ContextSize),
|
||||||
|
Seed: int32(c.Seed),
|
||||||
|
NBatch: int32(b),
|
||||||
|
F16Memory: c.F16,
|
||||||
|
MLock: c.MMlock,
|
||||||
|
NUMA: c.NUMA,
|
||||||
|
Embeddings: c.Embeddings,
|
||||||
|
LowVRAM: c.LowVRAM,
|
||||||
|
NGPULayers: int32(c.NGPULayers),
|
||||||
|
MMap: c.MMap,
|
||||||
|
MainGPU: c.MainGPU,
|
||||||
|
Threads: int32(c.Threads),
|
||||||
|
TensorSplit: c.TensorSplit,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func gRPCPredictOpts(c config.Config, modelPath string) *pb.PredictOptions {
|
||||||
|
promptCachePath := ""
|
||||||
|
if c.PromptCachePath != "" {
|
||||||
|
p := filepath.Join(modelPath, c.PromptCachePath)
|
||||||
|
os.MkdirAll(filepath.Dir(p), 0755)
|
||||||
|
promptCachePath = p
|
||||||
|
}
|
||||||
|
return &pb.PredictOptions{
|
||||||
|
Temperature: float32(c.Temperature),
|
||||||
|
TopP: float32(c.TopP),
|
||||||
|
TopK: int32(c.TopK),
|
||||||
|
Tokens: int32(c.Maxtokens),
|
||||||
|
Threads: int32(c.Threads),
|
||||||
|
PromptCacheAll: c.PromptCacheAll,
|
||||||
|
PromptCacheRO: c.PromptCacheRO,
|
||||||
|
PromptCachePath: promptCachePath,
|
||||||
|
F16KV: c.F16,
|
||||||
|
DebugMode: c.Debug,
|
||||||
|
Grammar: c.Grammar,
|
||||||
|
|
||||||
|
Mirostat: int32(c.Mirostat),
|
||||||
|
MirostatETA: float32(c.MirostatETA),
|
||||||
|
MirostatTAU: float32(c.MirostatTAU),
|
||||||
|
Debug: c.Debug,
|
||||||
|
StopPrompts: c.StopWords,
|
||||||
|
Repeat: int32(c.RepeatPenalty),
|
||||||
|
NKeep: int32(c.Keep),
|
||||||
|
Batch: int32(c.Batch),
|
||||||
|
IgnoreEOS: c.IgnoreEOS,
|
||||||
|
Seed: int32(c.Seed),
|
||||||
|
FrequencyPenalty: float32(c.FrequencyPenalty),
|
||||||
|
MLock: c.MMlock,
|
||||||
|
MMap: c.MMap,
|
||||||
|
MainGPU: c.MainGPU,
|
||||||
|
TensorSplit: c.TensorSplit,
|
||||||
|
TailFreeSamplingZ: float32(c.TFZ),
|
||||||
|
TypicalP: float32(c.TypicalP),
|
||||||
|
}
|
||||||
|
}
|
401
api/config.go
401
api/config.go
|
@ -1,401 +0,0 @@
|
||||||
package api
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"io/fs"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
|
||||||
"github.com/gofiber/fiber/v2"
|
|
||||||
"github.com/rs/zerolog/log"
|
|
||||||
"gopkg.in/yaml.v3"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Config struct {
|
|
||||||
OpenAIRequest `yaml:"parameters"`
|
|
||||||
Name string `yaml:"name"`
|
|
||||||
StopWords []string `yaml:"stopwords"`
|
|
||||||
Cutstrings []string `yaml:"cutstrings"`
|
|
||||||
TrimSpace []string `yaml:"trimspace"`
|
|
||||||
ContextSize int `yaml:"context_size"`
|
|
||||||
F16 bool `yaml:"f16"`
|
|
||||||
NUMA bool `yaml:"numa"`
|
|
||||||
Threads int `yaml:"threads"`
|
|
||||||
Debug bool `yaml:"debug"`
|
|
||||||
Roles map[string]string `yaml:"roles"`
|
|
||||||
Embeddings bool `yaml:"embeddings"`
|
|
||||||
Backend string `yaml:"backend"`
|
|
||||||
TemplateConfig TemplateConfig `yaml:"template"`
|
|
||||||
MirostatETA float64 `yaml:"mirostat_eta"`
|
|
||||||
MirostatTAU float64 `yaml:"mirostat_tau"`
|
|
||||||
Mirostat int `yaml:"mirostat"`
|
|
||||||
NGPULayers int `yaml:"gpu_layers"`
|
|
||||||
MMap bool `yaml:"mmap"`
|
|
||||||
MMlock bool `yaml:"mmlock"`
|
|
||||||
LowVRAM bool `yaml:"low_vram"`
|
|
||||||
|
|
||||||
TensorSplit string `yaml:"tensor_split"`
|
|
||||||
MainGPU string `yaml:"main_gpu"`
|
|
||||||
ImageGenerationAssets string `yaml:"asset_dir"`
|
|
||||||
|
|
||||||
PromptCachePath string `yaml:"prompt_cache_path"`
|
|
||||||
PromptCacheAll bool `yaml:"prompt_cache_all"`
|
|
||||||
PromptCacheRO bool `yaml:"prompt_cache_ro"`
|
|
||||||
|
|
||||||
Grammar string `yaml:"grammar"`
|
|
||||||
|
|
||||||
FunctionsConfig Functions `yaml:"function"`
|
|
||||||
|
|
||||||
PromptStrings, InputStrings []string
|
|
||||||
InputToken [][]int
|
|
||||||
functionCallString, functionCallNameString string
|
|
||||||
}
|
|
||||||
|
|
||||||
type Functions struct {
|
|
||||||
DisableNoAction bool `yaml:"disable_no_action"`
|
|
||||||
NoActionFunctionName string `yaml:"no_action_function_name"`
|
|
||||||
NoActionDescriptionName string `yaml:"no_action_description_name"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type TemplateConfig struct {
|
|
||||||
Completion string `yaml:"completion"`
|
|
||||||
Functions string `yaml:"function"`
|
|
||||||
Chat string `yaml:"chat"`
|
|
||||||
Edit string `yaml:"edit"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ConfigMerger struct {
|
|
||||||
configs map[string]Config
|
|
||||||
sync.Mutex
|
|
||||||
}
|
|
||||||
|
|
||||||
func defaultConfig(modelFile string) *Config {
|
|
||||||
return &Config{
|
|
||||||
OpenAIRequest: defaultRequest(modelFile),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewConfigMerger() *ConfigMerger {
|
|
||||||
return &ConfigMerger{
|
|
||||||
configs: make(map[string]Config),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
func ReadConfigFile(file string) ([]*Config, error) {
|
|
||||||
c := &[]*Config{}
|
|
||||||
f, err := os.ReadFile(file)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("cannot read config file: %w", err)
|
|
||||||
}
|
|
||||||
if err := yaml.Unmarshal(f, c); err != nil {
|
|
||||||
return nil, fmt.Errorf("cannot unmarshal config file: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return *c, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func ReadConfig(file string) (*Config, error) {
|
|
||||||
c := &Config{}
|
|
||||||
f, err := os.ReadFile(file)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("cannot read config file: %w", err)
|
|
||||||
}
|
|
||||||
if err := yaml.Unmarshal(f, c); err != nil {
|
|
||||||
return nil, fmt.Errorf("cannot unmarshal config file: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return c, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cm *ConfigMerger) LoadConfigFile(file string) error {
|
|
||||||
cm.Lock()
|
|
||||||
defer cm.Unlock()
|
|
||||||
c, err := ReadConfigFile(file)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("cannot load config file: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, cc := range c {
|
|
||||||
cm.configs[cc.Name] = *cc
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cm *ConfigMerger) LoadConfig(file string) error {
|
|
||||||
cm.Lock()
|
|
||||||
defer cm.Unlock()
|
|
||||||
c, err := ReadConfig(file)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("cannot read config file: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
cm.configs[c.Name] = *c
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cm *ConfigMerger) GetConfig(m string) (Config, bool) {
|
|
||||||
cm.Lock()
|
|
||||||
defer cm.Unlock()
|
|
||||||
v, exists := cm.configs[m]
|
|
||||||
return v, exists
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cm *ConfigMerger) ListConfigs() []string {
|
|
||||||
cm.Lock()
|
|
||||||
defer cm.Unlock()
|
|
||||||
var res []string
|
|
||||||
for k := range cm.configs {
|
|
||||||
res = append(res, k)
|
|
||||||
}
|
|
||||||
return res
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cm *ConfigMerger) LoadConfigs(path string) error {
|
|
||||||
cm.Lock()
|
|
||||||
defer cm.Unlock()
|
|
||||||
entries, err := os.ReadDir(path)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
files := make([]fs.FileInfo, 0, len(entries))
|
|
||||||
for _, entry := range entries {
|
|
||||||
info, err := entry.Info()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
files = append(files, info)
|
|
||||||
}
|
|
||||||
for _, file := range files {
|
|
||||||
// Skip templates, YAML and .keep files
|
|
||||||
if !strings.Contains(file.Name(), ".yaml") {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
c, err := ReadConfig(filepath.Join(path, file.Name()))
|
|
||||||
if err == nil {
|
|
||||||
cm.configs[c.Name] = *c
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func updateConfig(config *Config, input *OpenAIRequest) {
|
|
||||||
if input.Echo {
|
|
||||||
config.Echo = input.Echo
|
|
||||||
}
|
|
||||||
if input.TopK != 0 {
|
|
||||||
config.TopK = input.TopK
|
|
||||||
}
|
|
||||||
if input.TopP != 0 {
|
|
||||||
config.TopP = input.TopP
|
|
||||||
}
|
|
||||||
|
|
||||||
if input.Grammar != "" {
|
|
||||||
config.Grammar = input.Grammar
|
|
||||||
}
|
|
||||||
|
|
||||||
if input.Temperature != 0 {
|
|
||||||
config.Temperature = input.Temperature
|
|
||||||
}
|
|
||||||
|
|
||||||
if input.Maxtokens != 0 {
|
|
||||||
config.Maxtokens = input.Maxtokens
|
|
||||||
}
|
|
||||||
|
|
||||||
switch stop := input.Stop.(type) {
|
|
||||||
case string:
|
|
||||||
if stop != "" {
|
|
||||||
config.StopWords = append(config.StopWords, stop)
|
|
||||||
}
|
|
||||||
case []interface{}:
|
|
||||||
for _, pp := range stop {
|
|
||||||
if s, ok := pp.(string); ok {
|
|
||||||
config.StopWords = append(config.StopWords, s)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if input.RepeatPenalty != 0 {
|
|
||||||
config.RepeatPenalty = input.RepeatPenalty
|
|
||||||
}
|
|
||||||
|
|
||||||
if input.Keep != 0 {
|
|
||||||
config.Keep = input.Keep
|
|
||||||
}
|
|
||||||
|
|
||||||
if input.Batch != 0 {
|
|
||||||
config.Batch = input.Batch
|
|
||||||
}
|
|
||||||
|
|
||||||
if input.F16 {
|
|
||||||
config.F16 = input.F16
|
|
||||||
}
|
|
||||||
|
|
||||||
if input.IgnoreEOS {
|
|
||||||
config.IgnoreEOS = input.IgnoreEOS
|
|
||||||
}
|
|
||||||
|
|
||||||
if input.Seed != 0 {
|
|
||||||
config.Seed = input.Seed
|
|
||||||
}
|
|
||||||
|
|
||||||
if input.Mirostat != 0 {
|
|
||||||
config.Mirostat = input.Mirostat
|
|
||||||
}
|
|
||||||
|
|
||||||
if input.MirostatETA != 0 {
|
|
||||||
config.MirostatETA = input.MirostatETA
|
|
||||||
}
|
|
||||||
|
|
||||||
if input.MirostatTAU != 0 {
|
|
||||||
config.MirostatTAU = input.MirostatTAU
|
|
||||||
}
|
|
||||||
|
|
||||||
if input.TypicalP != 0 {
|
|
||||||
config.TypicalP = input.TypicalP
|
|
||||||
}
|
|
||||||
|
|
||||||
switch inputs := input.Input.(type) {
|
|
||||||
case string:
|
|
||||||
if inputs != "" {
|
|
||||||
config.InputStrings = append(config.InputStrings, inputs)
|
|
||||||
}
|
|
||||||
case []interface{}:
|
|
||||||
for _, pp := range inputs {
|
|
||||||
switch i := pp.(type) {
|
|
||||||
case string:
|
|
||||||
config.InputStrings = append(config.InputStrings, i)
|
|
||||||
case []interface{}:
|
|
||||||
tokens := []int{}
|
|
||||||
for _, ii := range i {
|
|
||||||
tokens = append(tokens, int(ii.(float64)))
|
|
||||||
}
|
|
||||||
config.InputToken = append(config.InputToken, tokens)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Can be either a string or an object
|
|
||||||
switch fnc := input.FunctionCall.(type) {
|
|
||||||
case string:
|
|
||||||
if fnc != "" {
|
|
||||||
config.functionCallString = fnc
|
|
||||||
}
|
|
||||||
case map[string]interface{}:
|
|
||||||
var name string
|
|
||||||
n, exists := fnc["name"]
|
|
||||||
if exists {
|
|
||||||
nn, e := n.(string)
|
|
||||||
if e {
|
|
||||||
name = nn
|
|
||||||
}
|
|
||||||
}
|
|
||||||
config.functionCallNameString = name
|
|
||||||
}
|
|
||||||
|
|
||||||
switch p := input.Prompt.(type) {
|
|
||||||
case string:
|
|
||||||
config.PromptStrings = append(config.PromptStrings, p)
|
|
||||||
case []interface{}:
|
|
||||||
for _, pp := range p {
|
|
||||||
if s, ok := pp.(string); ok {
|
|
||||||
config.PromptStrings = append(config.PromptStrings, s)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
func readInput(c *fiber.Ctx, loader *model.ModelLoader, randomModel bool) (string, *OpenAIRequest, error) {
|
|
||||||
input := new(OpenAIRequest)
|
|
||||||
// Get input data from the request body
|
|
||||||
if err := c.BodyParser(input); err != nil {
|
|
||||||
return "", nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
modelFile := input.Model
|
|
||||||
|
|
||||||
if c.Params("model") != "" {
|
|
||||||
modelFile = c.Params("model")
|
|
||||||
}
|
|
||||||
|
|
||||||
received, _ := json.Marshal(input)
|
|
||||||
|
|
||||||
log.Debug().Msgf("Request received: %s", string(received))
|
|
||||||
|
|
||||||
// Set model from bearer token, if available
|
|
||||||
bearer := strings.TrimLeft(c.Get("authorization"), "Bearer ")
|
|
||||||
bearerExists := bearer != "" && loader.ExistsInModelPath(bearer)
|
|
||||||
|
|
||||||
// If no model was specified, take the first available
|
|
||||||
if modelFile == "" && !bearerExists && randomModel {
|
|
||||||
models, _ := loader.ListModels()
|
|
||||||
if len(models) > 0 {
|
|
||||||
modelFile = models[0]
|
|
||||||
log.Debug().Msgf("No model specified, using: %s", modelFile)
|
|
||||||
} else {
|
|
||||||
log.Debug().Msgf("No model specified, returning error")
|
|
||||||
return "", nil, fmt.Errorf("no model specified")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// If a model is found in bearer token takes precedence
|
|
||||||
if bearerExists {
|
|
||||||
log.Debug().Msgf("Using model from bearer token: %s", bearer)
|
|
||||||
modelFile = bearer
|
|
||||||
}
|
|
||||||
return modelFile, input, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func readConfig(modelFile string, input *OpenAIRequest, cm *ConfigMerger, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*Config, *OpenAIRequest, error) {
|
|
||||||
// Load a config file if present after the model name
|
|
||||||
modelConfig := filepath.Join(loader.ModelPath, modelFile+".yaml")
|
|
||||||
|
|
||||||
var config *Config
|
|
||||||
|
|
||||||
defaults := func() {
|
|
||||||
config = defaultConfig(modelFile)
|
|
||||||
config.ContextSize = ctx
|
|
||||||
config.Threads = threads
|
|
||||||
config.F16 = f16
|
|
||||||
config.Debug = debug
|
|
||||||
}
|
|
||||||
|
|
||||||
cfg, exists := cm.GetConfig(modelFile)
|
|
||||||
if !exists {
|
|
||||||
if _, err := os.Stat(modelConfig); err == nil {
|
|
||||||
if err := cm.LoadConfig(modelConfig); err != nil {
|
|
||||||
return nil, nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error())
|
|
||||||
}
|
|
||||||
cfg, exists = cm.GetConfig(modelFile)
|
|
||||||
if exists {
|
|
||||||
config = &cfg
|
|
||||||
} else {
|
|
||||||
defaults()
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
defaults()
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
config = &cfg
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set the parameters for the language model prediction
|
|
||||||
updateConfig(config, input)
|
|
||||||
|
|
||||||
// Don't allow 0 as setting
|
|
||||||
if config.Threads == 0 {
|
|
||||||
if threads != 0 {
|
|
||||||
config.Threads = threads
|
|
||||||
} else {
|
|
||||||
config.Threads = 4
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Enforce debug flag if passed from CLI
|
|
||||||
if debug {
|
|
||||||
config.Debug = true
|
|
||||||
}
|
|
||||||
|
|
||||||
return config, input, nil
|
|
||||||
}
|
|
209
api/config/config.go
Normal file
209
api/config/config.go
Normal file
|
@ -0,0 +1,209 @@
|
||||||
|
package api_config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io/fs"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Config struct {
|
||||||
|
PredictionOptions `yaml:"parameters"`
|
||||||
|
Name string `yaml:"name"`
|
||||||
|
StopWords []string `yaml:"stopwords"`
|
||||||
|
Cutstrings []string `yaml:"cutstrings"`
|
||||||
|
TrimSpace []string `yaml:"trimspace"`
|
||||||
|
ContextSize int `yaml:"context_size"`
|
||||||
|
F16 bool `yaml:"f16"`
|
||||||
|
NUMA bool `yaml:"numa"`
|
||||||
|
Threads int `yaml:"threads"`
|
||||||
|
Debug bool `yaml:"debug"`
|
||||||
|
Roles map[string]string `yaml:"roles"`
|
||||||
|
Embeddings bool `yaml:"embeddings"`
|
||||||
|
Backend string `yaml:"backend"`
|
||||||
|
TemplateConfig TemplateConfig `yaml:"template"`
|
||||||
|
MirostatETA float64 `yaml:"mirostat_eta"`
|
||||||
|
MirostatTAU float64 `yaml:"mirostat_tau"`
|
||||||
|
Mirostat int `yaml:"mirostat"`
|
||||||
|
NGPULayers int `yaml:"gpu_layers"`
|
||||||
|
MMap bool `yaml:"mmap"`
|
||||||
|
MMlock bool `yaml:"mmlock"`
|
||||||
|
LowVRAM bool `yaml:"low_vram"`
|
||||||
|
|
||||||
|
TensorSplit string `yaml:"tensor_split"`
|
||||||
|
MainGPU string `yaml:"main_gpu"`
|
||||||
|
ImageGenerationAssets string `yaml:"asset_dir"`
|
||||||
|
|
||||||
|
PromptCachePath string `yaml:"prompt_cache_path"`
|
||||||
|
PromptCacheAll bool `yaml:"prompt_cache_all"`
|
||||||
|
PromptCacheRO bool `yaml:"prompt_cache_ro"`
|
||||||
|
|
||||||
|
Grammar string `yaml:"grammar"`
|
||||||
|
|
||||||
|
PromptStrings, InputStrings []string
|
||||||
|
InputToken [][]int
|
||||||
|
functionCallString, functionCallNameString string
|
||||||
|
|
||||||
|
FunctionsConfig Functions `yaml:"function"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Functions struct {
|
||||||
|
DisableNoAction bool `yaml:"disable_no_action"`
|
||||||
|
NoActionFunctionName string `yaml:"no_action_function_name"`
|
||||||
|
NoActionDescriptionName string `yaml:"no_action_description_name"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type TemplateConfig struct {
|
||||||
|
Completion string `yaml:"completion"`
|
||||||
|
Functions string `yaml:"function"`
|
||||||
|
Chat string `yaml:"chat"`
|
||||||
|
Edit string `yaml:"edit"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ConfigLoader struct {
|
||||||
|
configs map[string]Config
|
||||||
|
sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Config) SetFunctionCallString(s string) {
|
||||||
|
c.functionCallString = s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Config) SetFunctionCallNameString(s string) {
|
||||||
|
c.functionCallNameString = s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Config) ShouldUseFunctions() bool {
|
||||||
|
return ((c.functionCallString != "none" || c.functionCallString == "") || c.ShouldCallSpecificFunction())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Config) ShouldCallSpecificFunction() bool {
|
||||||
|
return len(c.functionCallNameString) > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Config) FunctionToCall() string {
|
||||||
|
return c.functionCallNameString
|
||||||
|
}
|
||||||
|
|
||||||
|
func defaultPredictOptions(modelFile string) PredictionOptions {
|
||||||
|
return PredictionOptions{
|
||||||
|
TopP: 0.7,
|
||||||
|
TopK: 80,
|
||||||
|
Maxtokens: 512,
|
||||||
|
Temperature: 0.9,
|
||||||
|
Model: modelFile,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func DefaultConfig(modelFile string) *Config {
|
||||||
|
return &Config{
|
||||||
|
PredictionOptions: defaultPredictOptions(modelFile),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewConfigLoader() *ConfigLoader {
|
||||||
|
return &ConfigLoader{
|
||||||
|
configs: make(map[string]Config),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
func ReadConfigFile(file string) ([]*Config, error) {
|
||||||
|
c := &[]*Config{}
|
||||||
|
f, err := os.ReadFile(file)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("cannot read config file: %w", err)
|
||||||
|
}
|
||||||
|
if err := yaml.Unmarshal(f, c); err != nil {
|
||||||
|
return nil, fmt.Errorf("cannot unmarshal config file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return *c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ReadConfig(file string) (*Config, error) {
|
||||||
|
c := &Config{}
|
||||||
|
f, err := os.ReadFile(file)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("cannot read config file: %w", err)
|
||||||
|
}
|
||||||
|
if err := yaml.Unmarshal(f, c); err != nil {
|
||||||
|
return nil, fmt.Errorf("cannot unmarshal config file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cm *ConfigLoader) LoadConfigFile(file string) error {
|
||||||
|
cm.Lock()
|
||||||
|
defer cm.Unlock()
|
||||||
|
c, err := ReadConfigFile(file)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("cannot load config file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, cc := range c {
|
||||||
|
cm.configs[cc.Name] = *cc
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cm *ConfigLoader) LoadConfig(file string) error {
|
||||||
|
cm.Lock()
|
||||||
|
defer cm.Unlock()
|
||||||
|
c, err := ReadConfig(file)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("cannot read config file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cm.configs[c.Name] = *c
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cm *ConfigLoader) GetConfig(m string) (Config, bool) {
|
||||||
|
cm.Lock()
|
||||||
|
defer cm.Unlock()
|
||||||
|
v, exists := cm.configs[m]
|
||||||
|
return v, exists
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cm *ConfigLoader) ListConfigs() []string {
|
||||||
|
cm.Lock()
|
||||||
|
defer cm.Unlock()
|
||||||
|
var res []string
|
||||||
|
for k := range cm.configs {
|
||||||
|
res = append(res, k)
|
||||||
|
}
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cm *ConfigLoader) LoadConfigs(path string) error {
|
||||||
|
cm.Lock()
|
||||||
|
defer cm.Unlock()
|
||||||
|
entries, err := os.ReadDir(path)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
files := make([]fs.FileInfo, 0, len(entries))
|
||||||
|
for _, entry := range entries {
|
||||||
|
info, err := entry.Info()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
files = append(files, info)
|
||||||
|
}
|
||||||
|
for _, file := range files {
|
||||||
|
// Skip templates, YAML and .keep files
|
||||||
|
if !strings.Contains(file.Name(), ".yaml") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
c, err := ReadConfig(filepath.Join(path, file.Name()))
|
||||||
|
if err == nil {
|
||||||
|
cm.configs[c.Name] = *c
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -1,8 +1,10 @@
|
||||||
package api
|
package api_config_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
|
. "github.com/go-skynet/LocalAI/api/config"
|
||||||
|
"github.com/go-skynet/LocalAI/api/options"
|
||||||
"github.com/go-skynet/LocalAI/pkg/model"
|
"github.com/go-skynet/LocalAI/pkg/model"
|
||||||
. "github.com/onsi/ginkgo/v2"
|
. "github.com/onsi/ginkgo/v2"
|
||||||
. "github.com/onsi/gomega"
|
. "github.com/onsi/gomega"
|
||||||
|
@ -26,29 +28,29 @@ var _ = Describe("Test cases for config related functions", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("Test LoadConfigs", func() {
|
It("Test LoadConfigs", func() {
|
||||||
cm := NewConfigMerger()
|
cm := NewConfigLoader()
|
||||||
options := newOptions()
|
opts := options.NewOptions()
|
||||||
modelLoader := model.NewModelLoader(os.Getenv("MODELS_PATH"))
|
modelLoader := model.NewModelLoader(os.Getenv("MODELS_PATH"))
|
||||||
WithModelLoader(modelLoader)(options)
|
options.WithModelLoader(modelLoader)(opts)
|
||||||
|
|
||||||
err := cm.LoadConfigs(options.loader.ModelPath)
|
err := cm.LoadConfigs(opts.Loader.ModelPath)
|
||||||
Expect(err).To(BeNil())
|
Expect(err).To(BeNil())
|
||||||
Expect(cm.configs).ToNot(BeNil())
|
Expect(cm.ListConfigs()).ToNot(BeNil())
|
||||||
|
|
||||||
// config should includes gpt4all models's api.config
|
// config should includes gpt4all models's api.config
|
||||||
Expect(cm.configs).To(HaveKey("gpt4all"))
|
Expect(cm.ListConfigs()).To(ContainElements("gpt4all"))
|
||||||
|
|
||||||
// config should includes gpt2 models's api.config
|
// config should includes gpt2 models's api.config
|
||||||
Expect(cm.configs).To(HaveKey("gpt4all-2"))
|
Expect(cm.ListConfigs()).To(ContainElements("gpt4all-2"))
|
||||||
|
|
||||||
// config should includes text-embedding-ada-002 models's api.config
|
// config should includes text-embedding-ada-002 models's api.config
|
||||||
Expect(cm.configs).To(HaveKey("text-embedding-ada-002"))
|
Expect(cm.ListConfigs()).To(ContainElements("text-embedding-ada-002"))
|
||||||
|
|
||||||
// config should includes rwkv_test models's api.config
|
// config should includes rwkv_test models's api.config
|
||||||
Expect(cm.configs).To(HaveKey("rwkv_test"))
|
Expect(cm.ListConfigs()).To(ContainElements("rwkv_test"))
|
||||||
|
|
||||||
// config should includes whisper-1 models's api.config
|
// config should includes whisper-1 models's api.config
|
||||||
Expect(cm.configs).To(HaveKey("whisper-1"))
|
Expect(cm.ListConfigs()).To(ContainElements("whisper-1"))
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
})
|
})
|
37
api/config/prediction.go
Normal file
37
api/config/prediction.go
Normal file
|
@ -0,0 +1,37 @@
|
||||||
|
package api_config
|
||||||
|
|
||||||
|
type PredictionOptions struct {
|
||||||
|
|
||||||
|
// Also part of the OpenAI official spec
|
||||||
|
Model string `json:"model" yaml:"model"`
|
||||||
|
|
||||||
|
// Also part of the OpenAI official spec
|
||||||
|
Language string `json:"language"`
|
||||||
|
|
||||||
|
// Also part of the OpenAI official spec. use it for returning multiple results
|
||||||
|
N int `json:"n"`
|
||||||
|
|
||||||
|
// Common options between all the API calls, part of the OpenAI spec
|
||||||
|
TopP float64 `json:"top_p" yaml:"top_p"`
|
||||||
|
TopK int `json:"top_k" yaml:"top_k"`
|
||||||
|
Temperature float64 `json:"temperature" yaml:"temperature"`
|
||||||
|
Maxtokens int `json:"max_tokens" yaml:"max_tokens"`
|
||||||
|
Echo bool `json:"echo"`
|
||||||
|
|
||||||
|
// Custom parameters - not present in the OpenAI API
|
||||||
|
Batch int `json:"batch" yaml:"batch"`
|
||||||
|
F16 bool `json:"f16" yaml:"f16"`
|
||||||
|
IgnoreEOS bool `json:"ignore_eos" yaml:"ignore_eos"`
|
||||||
|
RepeatPenalty float64 `json:"repeat_penalty" yaml:"repeat_penalty"`
|
||||||
|
Keep int `json:"n_keep" yaml:"n_keep"`
|
||||||
|
|
||||||
|
MirostatETA float64 `json:"mirostat_eta" yaml:"mirostat_eta"`
|
||||||
|
MirostatTAU float64 `json:"mirostat_tau" yaml:"mirostat_tau"`
|
||||||
|
Mirostat int `json:"mirostat" yaml:"mirostat"`
|
||||||
|
|
||||||
|
FrequencyPenalty float64 `json:"frequency_penalty" yaml:"frequency_penalty"`
|
||||||
|
TFZ float64 `json:"tfz" yaml:"tfz"`
|
||||||
|
|
||||||
|
TypicalP float64 `json:"typical_p" yaml:"typical_p"`
|
||||||
|
Seed int `json:"seed" yaml:"seed"`
|
||||||
|
}
|
|
@ -1,4 +1,4 @@
|
||||||
package api
|
package localai
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
@ -9,6 +9,7 @@ import (
|
||||||
|
|
||||||
json "github.com/json-iterator/go"
|
json "github.com/json-iterator/go"
|
||||||
|
|
||||||
|
config "github.com/go-skynet/LocalAI/api/config"
|
||||||
"github.com/go-skynet/LocalAI/pkg/gallery"
|
"github.com/go-skynet/LocalAI/pkg/gallery"
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
@ -38,7 +39,7 @@ type galleryApplier struct {
|
||||||
statuses map[string]*galleryOpStatus
|
statuses map[string]*galleryOpStatus
|
||||||
}
|
}
|
||||||
|
|
||||||
func newGalleryApplier(modelPath string) *galleryApplier {
|
func NewGalleryService(modelPath string) *galleryApplier {
|
||||||
return &galleryApplier{
|
return &galleryApplier{
|
||||||
modelPath: modelPath,
|
modelPath: modelPath,
|
||||||
C: make(chan galleryOp),
|
C: make(chan galleryOp),
|
||||||
|
@ -47,7 +48,7 @@ func newGalleryApplier(modelPath string) *galleryApplier {
|
||||||
}
|
}
|
||||||
|
|
||||||
// prepareModel applies a
|
// prepareModel applies a
|
||||||
func prepareModel(modelPath string, req gallery.GalleryModel, cm *ConfigMerger, downloadStatus func(string, string, string, float64)) error {
|
func prepareModel(modelPath string, req gallery.GalleryModel, cm *config.ConfigLoader, downloadStatus func(string, string, string, float64)) error {
|
||||||
|
|
||||||
config, err := gallery.GetGalleryConfigFromURL(req.URL)
|
config, err := gallery.GetGalleryConfigFromURL(req.URL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -72,7 +73,7 @@ func (g *galleryApplier) getStatus(s string) *galleryOpStatus {
|
||||||
return g.statuses[s]
|
return g.statuses[s]
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *galleryApplier) start(c context.Context, cm *ConfigMerger) {
|
func (g *galleryApplier) Start(c context.Context, cm *config.ConfigLoader) {
|
||||||
go func() {
|
go func() {
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
|
@ -148,7 +149,7 @@ type galleryModel struct {
|
||||||
ID string `json:"id"`
|
ID string `json:"id"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func ApplyGalleryFromFile(modelPath, s string, cm *ConfigMerger, galleries []gallery.Gallery) error {
|
func ApplyGalleryFromFile(modelPath, s string, cm *config.ConfigLoader, galleries []gallery.Gallery) error {
|
||||||
dat, err := os.ReadFile(s)
|
dat, err := os.ReadFile(s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -156,7 +157,7 @@ func ApplyGalleryFromFile(modelPath, s string, cm *ConfigMerger, galleries []gal
|
||||||
return ApplyGalleryFromString(modelPath, string(dat), cm, galleries)
|
return ApplyGalleryFromString(modelPath, string(dat), cm, galleries)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ApplyGalleryFromString(modelPath, s string, cm *ConfigMerger, galleries []gallery.Gallery) error {
|
func ApplyGalleryFromString(modelPath, s string, cm *config.ConfigLoader, galleries []gallery.Gallery) error {
|
||||||
var requests []galleryModel
|
var requests []galleryModel
|
||||||
err := json.Unmarshal([]byte(s), &requests)
|
err := json.Unmarshal([]byte(s), &requests)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -174,7 +175,9 @@ func ApplyGalleryFromString(modelPath, s string, cm *ConfigMerger, galleries []g
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func getOpStatus(g *galleryApplier) func(c *fiber.Ctx) error {
|
/// Endpoints
|
||||||
|
|
||||||
|
func GetOpStatusEndpoint(g *galleryApplier) func(c *fiber.Ctx) error {
|
||||||
return func(c *fiber.Ctx) error {
|
return func(c *fiber.Ctx) error {
|
||||||
|
|
||||||
status := g.getStatus(c.Params("uuid"))
|
status := g.getStatus(c.Params("uuid"))
|
||||||
|
@ -191,7 +194,7 @@ type GalleryModel struct {
|
||||||
gallery.GalleryModel
|
gallery.GalleryModel
|
||||||
}
|
}
|
||||||
|
|
||||||
func applyModelGallery(modelPath string, cm *ConfigMerger, g chan galleryOp, galleries []gallery.Gallery) func(c *fiber.Ctx) error {
|
func ApplyModelGalleryEndpoint(modelPath string, cm *config.ConfigLoader, g chan galleryOp, galleries []gallery.Gallery) func(c *fiber.Ctx) error {
|
||||||
return func(c *fiber.Ctx) error {
|
return func(c *fiber.Ctx) error {
|
||||||
input := new(GalleryModel)
|
input := new(GalleryModel)
|
||||||
// Get input data from the request body
|
// Get input data from the request body
|
||||||
|
@ -216,7 +219,7 @@ func applyModelGallery(modelPath string, cm *ConfigMerger, g chan galleryOp, gal
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func listModelFromGallery(galleries []gallery.Gallery, basePath string) func(c *fiber.Ctx) error {
|
func ListModelFromGalleryEndpoint(galleries []gallery.Gallery, basePath string) func(c *fiber.Ctx) error {
|
||||||
return func(c *fiber.Ctx) error {
|
return func(c *fiber.Ctx) error {
|
||||||
log.Debug().Msgf("Listing models from galleries: %+v", galleries)
|
log.Debug().Msgf("Listing models from galleries: %+v", galleries)
|
||||||
|
|
|
@ -1,10 +1,13 @@
|
||||||
package api
|
package localai
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
|
||||||
|
config "github.com/go-skynet/LocalAI/api/config"
|
||||||
|
|
||||||
|
"github.com/go-skynet/LocalAI/api/options"
|
||||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||||
"github.com/go-skynet/LocalAI/pkg/tts"
|
"github.com/go-skynet/LocalAI/pkg/tts"
|
||||||
"github.com/go-skynet/LocalAI/pkg/utils"
|
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||||
|
@ -32,7 +35,7 @@ func generateUniqueFileName(dir, baseName, ext string) string {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func ttsEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
|
func TTSEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
|
||||||
return func(c *fiber.Ctx) error {
|
return func(c *fiber.Ctx) error {
|
||||||
|
|
||||||
input := new(TTSRequest)
|
input := new(TTSRequest)
|
||||||
|
@ -41,10 +44,10 @@ func ttsEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
piperModel, err := o.loader.BackendLoader(
|
piperModel, err := o.Loader.BackendLoader(
|
||||||
model.WithBackendString(model.PiperBackend),
|
model.WithBackendString(model.PiperBackend),
|
||||||
model.WithModelFile(input.Model),
|
model.WithModelFile(input.Model),
|
||||||
model.WithAssetDir(o.assetsDestination))
|
model.WithAssetDir(o.AssetsDestination))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -58,16 +61,16 @@ func ttsEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
|
||||||
return fmt.Errorf("loader returned non-piper object %+v", w)
|
return fmt.Errorf("loader returned non-piper object %+v", w)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := os.MkdirAll(o.audioDir, 0755); err != nil {
|
if err := os.MkdirAll(o.AudioDir, 0755); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
fileName := generateUniqueFileName(o.audioDir, "piper", ".wav")
|
fileName := generateUniqueFileName(o.AudioDir, "piper", ".wav")
|
||||||
filePath := filepath.Join(o.audioDir, fileName)
|
filePath := filepath.Join(o.AudioDir, fileName)
|
||||||
|
|
||||||
modelPath := filepath.Join(o.loader.ModelPath, input.Model)
|
modelPath := filepath.Join(o.Loader.ModelPath, input.Model)
|
||||||
|
|
||||||
if err := utils.VerifyPath(modelPath, o.loader.ModelPath); err != nil {
|
if err := utils.VerifyPath(modelPath, o.Loader.ModelPath); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
973
api/openai.go
973
api/openai.go
|
@ -1,973 +0,0 @@
|
||||||
package api
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"bytes"
|
|
||||||
"encoding/base64"
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"io/ioutil"
|
|
||||||
"net/http"
|
|
||||||
"os"
|
|
||||||
"path"
|
|
||||||
"path/filepath"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
|
|
||||||
"github.com/go-skynet/LocalAI/pkg/grammar"
|
|
||||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
|
||||||
whisperutil "github.com/go-skynet/LocalAI/pkg/whisper"
|
|
||||||
"github.com/gofiber/fiber/v2"
|
|
||||||
"github.com/rs/zerolog/log"
|
|
||||||
"github.com/valyala/fasthttp"
|
|
||||||
)
|
|
||||||
|
|
||||||
// APIError provides error information returned by the OpenAI API.
|
|
||||||
type APIError struct {
|
|
||||||
Code any `json:"code,omitempty"`
|
|
||||||
Message string `json:"message"`
|
|
||||||
Param *string `json:"param,omitempty"`
|
|
||||||
Type string `json:"type"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ErrorResponse struct {
|
|
||||||
Error *APIError `json:"error,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type OpenAIUsage struct {
|
|
||||||
PromptTokens int `json:"prompt_tokens"`
|
|
||||||
CompletionTokens int `json:"completion_tokens"`
|
|
||||||
TotalTokens int `json:"total_tokens"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type Item struct {
|
|
||||||
Embedding []float32 `json:"embedding"`
|
|
||||||
Index int `json:"index"`
|
|
||||||
Object string `json:"object,omitempty"`
|
|
||||||
|
|
||||||
// Images
|
|
||||||
URL string `json:"url,omitempty"`
|
|
||||||
B64JSON string `json:"b64_json,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type OpenAIResponse struct {
|
|
||||||
Created int `json:"created,omitempty"`
|
|
||||||
Object string `json:"object,omitempty"`
|
|
||||||
ID string `json:"id,omitempty"`
|
|
||||||
Model string `json:"model,omitempty"`
|
|
||||||
Choices []Choice `json:"choices,omitempty"`
|
|
||||||
Data []Item `json:"data,omitempty"`
|
|
||||||
|
|
||||||
Usage OpenAIUsage `json:"usage"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type Choice struct {
|
|
||||||
Index int `json:"index,omitempty"`
|
|
||||||
FinishReason string `json:"finish_reason,omitempty"`
|
|
||||||
Message *Message `json:"message,omitempty"`
|
|
||||||
Delta *Message `json:"delta,omitempty"`
|
|
||||||
Text string `json:"text,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type Message struct {
|
|
||||||
// The message role
|
|
||||||
Role string `json:"role,omitempty" yaml:"role"`
|
|
||||||
// The message content
|
|
||||||
Content *string `json:"content" yaml:"content"`
|
|
||||||
// A result of a function call
|
|
||||||
FunctionCall interface{} `json:"function_call,omitempty" yaml:"function_call,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type OpenAIModel struct {
|
|
||||||
ID string `json:"id"`
|
|
||||||
Object string `json:"object"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type OpenAIRequest struct {
|
|
||||||
Model string `json:"model" yaml:"model"`
|
|
||||||
|
|
||||||
// whisper
|
|
||||||
File string `json:"file" validate:"required"`
|
|
||||||
Language string `json:"language"`
|
|
||||||
//whisper/image
|
|
||||||
ResponseFormat string `json:"response_format"`
|
|
||||||
// image
|
|
||||||
Size string `json:"size"`
|
|
||||||
// Prompt is read only by completion/image API calls
|
|
||||||
Prompt interface{} `json:"prompt" yaml:"prompt"`
|
|
||||||
|
|
||||||
// Edit endpoint
|
|
||||||
Instruction string `json:"instruction" yaml:"instruction"`
|
|
||||||
Input interface{} `json:"input" yaml:"input"`
|
|
||||||
|
|
||||||
Stop interface{} `json:"stop" yaml:"stop"`
|
|
||||||
|
|
||||||
// Messages is read only by chat/completion API calls
|
|
||||||
Messages []Message `json:"messages" yaml:"messages"`
|
|
||||||
|
|
||||||
// A list of available functions to call
|
|
||||||
Functions []grammar.Function `json:"functions" yaml:"functions"`
|
|
||||||
FunctionCall interface{} `json:"function_call" yaml:"function_call"` // might be a string or an object
|
|
||||||
|
|
||||||
Stream bool `json:"stream"`
|
|
||||||
Echo bool `json:"echo"`
|
|
||||||
// Common options between all the API calls
|
|
||||||
TopP float64 `json:"top_p" yaml:"top_p"`
|
|
||||||
TopK int `json:"top_k" yaml:"top_k"`
|
|
||||||
Temperature float64 `json:"temperature" yaml:"temperature"`
|
|
||||||
Maxtokens int `json:"max_tokens" yaml:"max_tokens"`
|
|
||||||
|
|
||||||
N int `json:"n"`
|
|
||||||
|
|
||||||
// Custom parameters - not present in the OpenAI API
|
|
||||||
Batch int `json:"batch" yaml:"batch"`
|
|
||||||
F16 bool `json:"f16" yaml:"f16"`
|
|
||||||
IgnoreEOS bool `json:"ignore_eos" yaml:"ignore_eos"`
|
|
||||||
RepeatPenalty float64 `json:"repeat_penalty" yaml:"repeat_penalty"`
|
|
||||||
Keep int `json:"n_keep" yaml:"n_keep"`
|
|
||||||
|
|
||||||
MirostatETA float64 `json:"mirostat_eta" yaml:"mirostat_eta"`
|
|
||||||
MirostatTAU float64 `json:"mirostat_tau" yaml:"mirostat_tau"`
|
|
||||||
Mirostat int `json:"mirostat" yaml:"mirostat"`
|
|
||||||
|
|
||||||
FrequencyPenalty float64 `json:"frequency_penalty" yaml:"frequency_penalty"`
|
|
||||||
TFZ float64 `json:"tfz" yaml:"tfz"`
|
|
||||||
|
|
||||||
Seed int `json:"seed" yaml:"seed"`
|
|
||||||
|
|
||||||
// Image (not supported by OpenAI)
|
|
||||||
Mode int `json:"mode"`
|
|
||||||
Step int `json:"step"`
|
|
||||||
|
|
||||||
// A grammar to constrain the LLM output
|
|
||||||
Grammar string `json:"grammar" yaml:"grammar"`
|
|
||||||
// A grammar object
|
|
||||||
JSONFunctionGrammarObject *grammar.JSONFunctionStructure `json:"grammar_json_functions" yaml:"grammar_json_functions"`
|
|
||||||
|
|
||||||
TypicalP float64 `json:"typical_p" yaml:"typical_p"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func defaultRequest(modelFile string) OpenAIRequest {
|
|
||||||
return OpenAIRequest{
|
|
||||||
TopP: 0.7,
|
|
||||||
TopK: 80,
|
|
||||||
Maxtokens: 512,
|
|
||||||
Temperature: 0.9,
|
|
||||||
Model: modelFile,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// https://platform.openai.com/docs/api-reference/completions
|
|
||||||
func completionEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
|
|
||||||
process := func(s string, req *OpenAIRequest, config *Config, loader *model.ModelLoader, responses chan OpenAIResponse) {
|
|
||||||
ComputeChoices(s, req, config, o, loader, func(s string, c *[]Choice) {}, func(s string) bool {
|
|
||||||
resp := OpenAIResponse{
|
|
||||||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
|
||||||
Choices: []Choice{
|
|
||||||
{
|
|
||||||
Index: 0,
|
|
||||||
Text: s,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
Object: "text_completion",
|
|
||||||
}
|
|
||||||
log.Debug().Msgf("Sending goroutine: %s", s)
|
|
||||||
|
|
||||||
responses <- resp
|
|
||||||
return true
|
|
||||||
})
|
|
||||||
close(responses)
|
|
||||||
}
|
|
||||||
|
|
||||||
return func(c *fiber.Ctx) error {
|
|
||||||
model, input, err := readInput(c, o.loader, true)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debug().Msgf("`input`: %+v", input)
|
|
||||||
|
|
||||||
config, input, err := readConfig(model, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debug().Msgf("Parameter Config: %+v", config)
|
|
||||||
|
|
||||||
if input.Stream {
|
|
||||||
log.Debug().Msgf("Stream request received")
|
|
||||||
c.Context().SetContentType("text/event-stream")
|
|
||||||
//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
|
|
||||||
//c.Set("Content-Type", "text/event-stream")
|
|
||||||
c.Set("Cache-Control", "no-cache")
|
|
||||||
c.Set("Connection", "keep-alive")
|
|
||||||
c.Set("Transfer-Encoding", "chunked")
|
|
||||||
}
|
|
||||||
|
|
||||||
templateFile := config.Model
|
|
||||||
|
|
||||||
if config.TemplateConfig.Completion != "" {
|
|
||||||
templateFile = config.TemplateConfig.Completion
|
|
||||||
}
|
|
||||||
|
|
||||||
if input.Stream {
|
|
||||||
if len(config.PromptStrings) > 1 {
|
|
||||||
return errors.New("cannot handle more than 1 `PromptStrings` when Streaming")
|
|
||||||
}
|
|
||||||
|
|
||||||
predInput := config.PromptStrings[0]
|
|
||||||
|
|
||||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
|
||||||
templatedInput, err := o.loader.TemplatePrefix(templateFile, struct {
|
|
||||||
Input string
|
|
||||||
}{
|
|
||||||
Input: predInput,
|
|
||||||
})
|
|
||||||
if err == nil {
|
|
||||||
predInput = templatedInput
|
|
||||||
log.Debug().Msgf("Template found, input modified to: %s", predInput)
|
|
||||||
}
|
|
||||||
|
|
||||||
responses := make(chan OpenAIResponse)
|
|
||||||
|
|
||||||
go process(predInput, input, config, o.loader, responses)
|
|
||||||
|
|
||||||
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
|
|
||||||
|
|
||||||
for ev := range responses {
|
|
||||||
var buf bytes.Buffer
|
|
||||||
enc := json.NewEncoder(&buf)
|
|
||||||
enc.Encode(ev)
|
|
||||||
|
|
||||||
log.Debug().Msgf("Sending chunk: %s", buf.String())
|
|
||||||
fmt.Fprintf(w, "data: %v\n", buf.String())
|
|
||||||
w.Flush()
|
|
||||||
}
|
|
||||||
|
|
||||||
resp := &OpenAIResponse{
|
|
||||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
|
||||||
Choices: []Choice{
|
|
||||||
{
|
|
||||||
Index: 0,
|
|
||||||
FinishReason: "stop",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
Object: "text_completion",
|
|
||||||
}
|
|
||||||
respData, _ := json.Marshal(resp)
|
|
||||||
|
|
||||||
w.WriteString(fmt.Sprintf("data: %s\n\n", respData))
|
|
||||||
w.WriteString("data: [DONE]\n\n")
|
|
||||||
w.Flush()
|
|
||||||
}))
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var result []Choice
|
|
||||||
for _, i := range config.PromptStrings {
|
|
||||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
|
||||||
templatedInput, err := o.loader.TemplatePrefix(templateFile, struct {
|
|
||||||
Input string
|
|
||||||
}{
|
|
||||||
Input: i,
|
|
||||||
})
|
|
||||||
if err == nil {
|
|
||||||
i = templatedInput
|
|
||||||
log.Debug().Msgf("Template found, input modified to: %s", i)
|
|
||||||
}
|
|
||||||
|
|
||||||
r, err := ComputeChoices(i, input, config, o, o.loader, func(s string, c *[]Choice) {
|
|
||||||
*c = append(*c, Choice{Text: s})
|
|
||||||
}, nil)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
result = append(result, r...)
|
|
||||||
}
|
|
||||||
|
|
||||||
resp := &OpenAIResponse{
|
|
||||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
|
||||||
Choices: result,
|
|
||||||
Object: "text_completion",
|
|
||||||
}
|
|
||||||
|
|
||||||
jsonResult, _ := json.Marshal(resp)
|
|
||||||
log.Debug().Msgf("Response: %s", jsonResult)
|
|
||||||
|
|
||||||
// Return the prediction in the response body
|
|
||||||
return c.JSON(resp)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// https://platform.openai.com/docs/api-reference/embeddings
|
|
||||||
func embeddingsEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
|
|
||||||
return func(c *fiber.Ctx) error {
|
|
||||||
model, input, err := readInput(c, o.loader, true)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
config, input, err := readConfig(model, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debug().Msgf("Parameter Config: %+v", config)
|
|
||||||
items := []Item{}
|
|
||||||
|
|
||||||
for i, s := range config.InputToken {
|
|
||||||
// get the model function to call for the result
|
|
||||||
embedFn, err := ModelEmbedding("", s, o.loader, *config, o)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
embeddings, err := embedFn()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
items = append(items, Item{Embedding: embeddings, Index: i, Object: "embedding"})
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, s := range config.InputStrings {
|
|
||||||
// get the model function to call for the result
|
|
||||||
embedFn, err := ModelEmbedding(s, []int{}, o.loader, *config, o)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
embeddings, err := embedFn()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
items = append(items, Item{Embedding: embeddings, Index: i, Object: "embedding"})
|
|
||||||
}
|
|
||||||
|
|
||||||
resp := &OpenAIResponse{
|
|
||||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
|
||||||
Data: items,
|
|
||||||
Object: "list",
|
|
||||||
}
|
|
||||||
|
|
||||||
jsonResult, _ := json.Marshal(resp)
|
|
||||||
log.Debug().Msgf("Response: %s", jsonResult)
|
|
||||||
|
|
||||||
// Return the prediction in the response body
|
|
||||||
return c.JSON(resp)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func isEOS(s string) bool {
|
|
||||||
if s == "<|endoftext|>" {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
func chatEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
|
|
||||||
|
|
||||||
process := func(s string, req *OpenAIRequest, config *Config, loader *model.ModelLoader, responses chan OpenAIResponse) {
|
|
||||||
initialMessage := OpenAIResponse{
|
|
||||||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
|
||||||
Choices: []Choice{{Delta: &Message{Role: "assistant"}}},
|
|
||||||
Object: "chat.completion.chunk",
|
|
||||||
}
|
|
||||||
responses <- initialMessage
|
|
||||||
|
|
||||||
ComputeChoices(s, req, config, o, loader, func(s string, c *[]Choice) {}, func(s string) bool {
|
|
||||||
resp := OpenAIResponse{
|
|
||||||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
|
||||||
Choices: []Choice{{Delta: &Message{Content: &s}, Index: 0}},
|
|
||||||
Object: "chat.completion.chunk",
|
|
||||||
}
|
|
||||||
log.Debug().Msgf("Sending goroutine: %s", s)
|
|
||||||
|
|
||||||
if s != "" && !isEOS(s) {
|
|
||||||
responses <- resp
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
})
|
|
||||||
close(responses)
|
|
||||||
}
|
|
||||||
return func(c *fiber.Ctx) error {
|
|
||||||
processFunctions := false
|
|
||||||
funcs := grammar.Functions{}
|
|
||||||
model, input, err := readInput(c, o.loader, true)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
config, input, err := readConfig(model, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
|
||||||
}
|
|
||||||
log.Debug().Msgf("Configuration read: %+v", config)
|
|
||||||
|
|
||||||
// Allow the user to set custom actions via config file
|
|
||||||
// to be "embedded" in each model
|
|
||||||
noActionName := "answer"
|
|
||||||
noActionDescription := "use this action to answer without performing any action"
|
|
||||||
|
|
||||||
if config.FunctionsConfig.NoActionFunctionName != "" {
|
|
||||||
noActionName = config.FunctionsConfig.NoActionFunctionName
|
|
||||||
}
|
|
||||||
if config.FunctionsConfig.NoActionDescriptionName != "" {
|
|
||||||
noActionDescription = config.FunctionsConfig.NoActionDescriptionName
|
|
||||||
}
|
|
||||||
|
|
||||||
// process functions if we have any defined or if we have a function call string
|
|
||||||
if len(input.Functions) > 0 &&
|
|
||||||
((config.functionCallString != "none" || config.functionCallString == "") || len(config.functionCallNameString) > 0) {
|
|
||||||
log.Debug().Msgf("Response needs to process functions")
|
|
||||||
|
|
||||||
processFunctions = true
|
|
||||||
|
|
||||||
noActionGrammar := grammar.Function{
|
|
||||||
Name: noActionName,
|
|
||||||
Description: noActionDescription,
|
|
||||||
Parameters: map[string]interface{}{
|
|
||||||
"properties": map[string]interface{}{
|
|
||||||
"message": map[string]interface{}{
|
|
||||||
"type": "string",
|
|
||||||
"description": "The message to reply the user with",
|
|
||||||
}},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// Append the no action function
|
|
||||||
funcs = append(funcs, input.Functions...)
|
|
||||||
if !config.FunctionsConfig.DisableNoAction {
|
|
||||||
funcs = append(funcs, noActionGrammar)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Force picking one of the functions by the request
|
|
||||||
if config.functionCallNameString != "" {
|
|
||||||
funcs = funcs.Select(config.functionCallNameString)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update input grammar
|
|
||||||
jsStruct := funcs.ToJSONStructure()
|
|
||||||
config.Grammar = jsStruct.Grammar("")
|
|
||||||
} else if input.JSONFunctionGrammarObject != nil {
|
|
||||||
config.Grammar = input.JSONFunctionGrammarObject.Grammar("")
|
|
||||||
}
|
|
||||||
|
|
||||||
// functions are not supported in stream mode (yet?)
|
|
||||||
toStream := input.Stream && !processFunctions
|
|
||||||
|
|
||||||
log.Debug().Msgf("Parameters: %+v", config)
|
|
||||||
|
|
||||||
var predInput string
|
|
||||||
|
|
||||||
mess := []string{}
|
|
||||||
for _, i := range input.Messages {
|
|
||||||
var content string
|
|
||||||
role := i.Role
|
|
||||||
// if function call, we might want to customize the role so we can display better that the "assistant called a json action"
|
|
||||||
// if an "assistant_function_call" role is defined, we use it, otherwise we use the role that is passed by in the request
|
|
||||||
if i.FunctionCall != nil && i.Role == "assistant" {
|
|
||||||
roleFn := "assistant_function_call"
|
|
||||||
r := config.Roles[roleFn]
|
|
||||||
if r != "" {
|
|
||||||
role = roleFn
|
|
||||||
}
|
|
||||||
}
|
|
||||||
r := config.Roles[role]
|
|
||||||
contentExists := i.Content != nil && *i.Content != ""
|
|
||||||
if r != "" {
|
|
||||||
if contentExists {
|
|
||||||
content = fmt.Sprint(r, " ", *i.Content)
|
|
||||||
}
|
|
||||||
if i.FunctionCall != nil {
|
|
||||||
j, err := json.Marshal(i.FunctionCall)
|
|
||||||
if err == nil {
|
|
||||||
if contentExists {
|
|
||||||
content += "\n" + fmt.Sprint(r, " ", string(j))
|
|
||||||
} else {
|
|
||||||
content = fmt.Sprint(r, " ", string(j))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if contentExists {
|
|
||||||
content = fmt.Sprint(*i.Content)
|
|
||||||
}
|
|
||||||
if i.FunctionCall != nil {
|
|
||||||
j, err := json.Marshal(i.FunctionCall)
|
|
||||||
if err == nil {
|
|
||||||
if contentExists {
|
|
||||||
content += "\n" + string(j)
|
|
||||||
} else {
|
|
||||||
content = string(j)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
mess = append(mess, content)
|
|
||||||
}
|
|
||||||
|
|
||||||
predInput = strings.Join(mess, "\n")
|
|
||||||
log.Debug().Msgf("Prompt (before templating): %s", predInput)
|
|
||||||
|
|
||||||
if toStream {
|
|
||||||
log.Debug().Msgf("Stream request received")
|
|
||||||
c.Context().SetContentType("text/event-stream")
|
|
||||||
//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
|
|
||||||
// c.Set("Content-Type", "text/event-stream")
|
|
||||||
c.Set("Cache-Control", "no-cache")
|
|
||||||
c.Set("Connection", "keep-alive")
|
|
||||||
c.Set("Transfer-Encoding", "chunked")
|
|
||||||
}
|
|
||||||
|
|
||||||
templateFile := config.Model
|
|
||||||
|
|
||||||
if config.TemplateConfig.Chat != "" && !processFunctions {
|
|
||||||
templateFile = config.TemplateConfig.Chat
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.TemplateConfig.Functions != "" && processFunctions {
|
|
||||||
templateFile = config.TemplateConfig.Functions
|
|
||||||
}
|
|
||||||
|
|
||||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
|
||||||
templatedInput, err := o.loader.TemplatePrefix(templateFile, struct {
|
|
||||||
Input string
|
|
||||||
Functions []grammar.Function
|
|
||||||
}{
|
|
||||||
Input: predInput,
|
|
||||||
Functions: funcs,
|
|
||||||
})
|
|
||||||
if err == nil {
|
|
||||||
predInput = templatedInput
|
|
||||||
log.Debug().Msgf("Template found, input modified to: %s", predInput)
|
|
||||||
} else {
|
|
||||||
log.Debug().Msgf("Template failed loading: %s", err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debug().Msgf("Prompt (after templating): %s", predInput)
|
|
||||||
if processFunctions {
|
|
||||||
log.Debug().Msgf("Grammar: %+v", config.Grammar)
|
|
||||||
}
|
|
||||||
|
|
||||||
if toStream {
|
|
||||||
responses := make(chan OpenAIResponse)
|
|
||||||
|
|
||||||
go process(predInput, input, config, o.loader, responses)
|
|
||||||
|
|
||||||
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
|
|
||||||
|
|
||||||
for ev := range responses {
|
|
||||||
var buf bytes.Buffer
|
|
||||||
enc := json.NewEncoder(&buf)
|
|
||||||
enc.Encode(ev)
|
|
||||||
|
|
||||||
log.Debug().Msgf("Sending chunk: %s", buf.String())
|
|
||||||
fmt.Fprintf(w, "data: %v\n", buf.String())
|
|
||||||
w.Flush()
|
|
||||||
}
|
|
||||||
|
|
||||||
resp := &OpenAIResponse{
|
|
||||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
|
||||||
Choices: []Choice{
|
|
||||||
{
|
|
||||||
FinishReason: "stop",
|
|
||||||
Index: 0,
|
|
||||||
Delta: &Message{},
|
|
||||||
}},
|
|
||||||
Object: "chat.completion.chunk",
|
|
||||||
}
|
|
||||||
respData, _ := json.Marshal(resp)
|
|
||||||
|
|
||||||
w.WriteString(fmt.Sprintf("data: %s\n\n", respData))
|
|
||||||
w.WriteString("data: [DONE]\n\n")
|
|
||||||
w.Flush()
|
|
||||||
}))
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
result, err := ComputeChoices(predInput, input, config, o, o.loader, func(s string, c *[]Choice) {
|
|
||||||
if processFunctions {
|
|
||||||
// As we have to change the result before processing, we can't stream the answer (yet?)
|
|
||||||
ss := map[string]interface{}{}
|
|
||||||
json.Unmarshal([]byte(s), &ss)
|
|
||||||
log.Debug().Msgf("Function return: %s %+v", s, ss)
|
|
||||||
|
|
||||||
// The grammar defines the function name as "function", while OpenAI returns "name"
|
|
||||||
func_name := ss["function"]
|
|
||||||
// Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object
|
|
||||||
args := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix)
|
|
||||||
d, _ := json.Marshal(args)
|
|
||||||
|
|
||||||
ss["arguments"] = string(d)
|
|
||||||
ss["name"] = func_name
|
|
||||||
|
|
||||||
// if do nothing, reply with a message
|
|
||||||
if func_name == noActionName {
|
|
||||||
log.Debug().Msgf("nothing to do, computing a reply")
|
|
||||||
|
|
||||||
// If there is a message that the LLM already sends as part of the JSON reply, use it
|
|
||||||
arguments := map[string]interface{}{}
|
|
||||||
json.Unmarshal([]byte(d), &arguments)
|
|
||||||
m, exists := arguments["message"]
|
|
||||||
if exists {
|
|
||||||
switch message := m.(type) {
|
|
||||||
case string:
|
|
||||||
if message != "" {
|
|
||||||
log.Debug().Msgf("Reply received from LLM: %s", message)
|
|
||||||
message = Finetune(*config, predInput, message)
|
|
||||||
log.Debug().Msgf("Reply received from LLM(finetuned): %s", message)
|
|
||||||
|
|
||||||
*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &message}})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debug().Msgf("No action received from LLM, without a message, computing a reply")
|
|
||||||
// Otherwise ask the LLM to understand the JSON output and the context, and return a message
|
|
||||||
// Note: This costs (in term of CPU) another computation
|
|
||||||
config.Grammar = ""
|
|
||||||
predFunc, err := ModelInference(predInput, o.loader, *config, o, nil)
|
|
||||||
if err != nil {
|
|
||||||
log.Error().Msgf("inference error: %s", err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
prediction, err := predFunc()
|
|
||||||
if err != nil {
|
|
||||||
log.Error().Msgf("inference error: %s", err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
prediction = Finetune(*config, predInput, prediction)
|
|
||||||
*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &prediction}})
|
|
||||||
} else {
|
|
||||||
// otherwise reply with the function call
|
|
||||||
*c = append(*c, Choice{
|
|
||||||
FinishReason: "function_call",
|
|
||||||
Message: &Message{Role: "assistant", FunctionCall: ss},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &s}})
|
|
||||||
}, nil)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
resp := &OpenAIResponse{
|
|
||||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
|
||||||
Choices: result,
|
|
||||||
Object: "chat.completion",
|
|
||||||
}
|
|
||||||
respData, _ := json.Marshal(resp)
|
|
||||||
log.Debug().Msgf("Response: %s", respData)
|
|
||||||
|
|
||||||
// Return the prediction in the response body
|
|
||||||
return c.JSON(resp)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func editEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
|
|
||||||
return func(c *fiber.Ctx) error {
|
|
||||||
model, input, err := readInput(c, o.loader, true)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
config, input, err := readConfig(model, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debug().Msgf("Parameter Config: %+v", config)
|
|
||||||
|
|
||||||
templateFile := config.Model
|
|
||||||
|
|
||||||
if config.TemplateConfig.Edit != "" {
|
|
||||||
templateFile = config.TemplateConfig.Edit
|
|
||||||
}
|
|
||||||
|
|
||||||
var result []Choice
|
|
||||||
for _, i := range config.InputStrings {
|
|
||||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
|
||||||
templatedInput, err := o.loader.TemplatePrefix(templateFile, struct {
|
|
||||||
Input string
|
|
||||||
Instruction string
|
|
||||||
}{Input: i})
|
|
||||||
if err == nil {
|
|
||||||
i = templatedInput
|
|
||||||
log.Debug().Msgf("Template found, input modified to: %s", i)
|
|
||||||
}
|
|
||||||
|
|
||||||
r, err := ComputeChoices(i, input, config, o, o.loader, func(s string, c *[]Choice) {
|
|
||||||
*c = append(*c, Choice{Text: s})
|
|
||||||
}, nil)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
result = append(result, r...)
|
|
||||||
}
|
|
||||||
|
|
||||||
resp := &OpenAIResponse{
|
|
||||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
|
||||||
Choices: result,
|
|
||||||
Object: "edit",
|
|
||||||
}
|
|
||||||
|
|
||||||
jsonResult, _ := json.Marshal(resp)
|
|
||||||
log.Debug().Msgf("Response: %s", jsonResult)
|
|
||||||
|
|
||||||
// Return the prediction in the response body
|
|
||||||
return c.JSON(resp)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// https://platform.openai.com/docs/api-reference/images/create
|
|
||||||
|
|
||||||
/*
|
|
||||||
*
|
|
||||||
|
|
||||||
curl http://localhost:8080/v1/images/generations \
|
|
||||||
-H "Content-Type: application/json" \
|
|
||||||
-d '{
|
|
||||||
"prompt": "A cute baby sea otter",
|
|
||||||
"n": 1,
|
|
||||||
"size": "512x512"
|
|
||||||
}'
|
|
||||||
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
func imageEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
|
|
||||||
return func(c *fiber.Ctx) error {
|
|
||||||
m, input, err := readInput(c, o.loader, false)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if m == "" {
|
|
||||||
m = model.StableDiffusionBackend
|
|
||||||
}
|
|
||||||
log.Debug().Msgf("Loading model: %+v", m)
|
|
||||||
|
|
||||||
config, input, err := readConfig(m, input, cm, o.loader, o.debug, 0, 0, false)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debug().Msgf("Parameter Config: %+v", config)
|
|
||||||
|
|
||||||
// XXX: Only stablediffusion is supported for now
|
|
||||||
if config.Backend == "" {
|
|
||||||
config.Backend = model.StableDiffusionBackend
|
|
||||||
}
|
|
||||||
|
|
||||||
sizeParts := strings.Split(input.Size, "x")
|
|
||||||
if len(sizeParts) != 2 {
|
|
||||||
return fmt.Errorf("Invalid value for 'size'")
|
|
||||||
}
|
|
||||||
width, err := strconv.Atoi(sizeParts[0])
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("Invalid value for 'size'")
|
|
||||||
}
|
|
||||||
height, err := strconv.Atoi(sizeParts[1])
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("Invalid value for 'size'")
|
|
||||||
}
|
|
||||||
|
|
||||||
b64JSON := false
|
|
||||||
if input.ResponseFormat == "b64_json" {
|
|
||||||
b64JSON = true
|
|
||||||
}
|
|
||||||
|
|
||||||
var result []Item
|
|
||||||
for _, i := range config.PromptStrings {
|
|
||||||
n := input.N
|
|
||||||
if input.N == 0 {
|
|
||||||
n = 1
|
|
||||||
}
|
|
||||||
for j := 0; j < n; j++ {
|
|
||||||
prompts := strings.Split(i, "|")
|
|
||||||
positive_prompt := prompts[0]
|
|
||||||
negative_prompt := ""
|
|
||||||
if len(prompts) > 1 {
|
|
||||||
negative_prompt = prompts[1]
|
|
||||||
}
|
|
||||||
|
|
||||||
mode := 0
|
|
||||||
step := 15
|
|
||||||
|
|
||||||
if input.Mode != 0 {
|
|
||||||
mode = input.Mode
|
|
||||||
}
|
|
||||||
|
|
||||||
if input.Step != 0 {
|
|
||||||
step = input.Step
|
|
||||||
}
|
|
||||||
|
|
||||||
tempDir := ""
|
|
||||||
if !b64JSON {
|
|
||||||
tempDir = o.imageDir
|
|
||||||
}
|
|
||||||
// Create a temporary file
|
|
||||||
outputFile, err := ioutil.TempFile(tempDir, "b64")
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
outputFile.Close()
|
|
||||||
output := outputFile.Name() + ".png"
|
|
||||||
// Rename the temporary file
|
|
||||||
err = os.Rename(outputFile.Name(), output)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
baseURL := c.BaseURL()
|
|
||||||
|
|
||||||
fn, err := ImageGeneration(height, width, mode, step, input.Seed, positive_prompt, negative_prompt, output, o.loader, *config, o)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := fn(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
item := &Item{}
|
|
||||||
|
|
||||||
if b64JSON {
|
|
||||||
defer os.RemoveAll(output)
|
|
||||||
data, err := os.ReadFile(output)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
item.B64JSON = base64.StdEncoding.EncodeToString(data)
|
|
||||||
} else {
|
|
||||||
base := filepath.Base(output)
|
|
||||||
item.URL = baseURL + "/generated-images/" + base
|
|
||||||
}
|
|
||||||
|
|
||||||
result = append(result, *item)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
resp := &OpenAIResponse{
|
|
||||||
Data: result,
|
|
||||||
}
|
|
||||||
|
|
||||||
jsonResult, _ := json.Marshal(resp)
|
|
||||||
log.Debug().Msgf("Response: %s", jsonResult)
|
|
||||||
|
|
||||||
// Return the prediction in the response body
|
|
||||||
return c.JSON(resp)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// https://platform.openai.com/docs/api-reference/audio/create
|
|
||||||
func transcriptEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
|
|
||||||
return func(c *fiber.Ctx) error {
|
|
||||||
m, input, err := readInput(c, o.loader, false)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
config, input, err := readConfig(m, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
|
||||||
}
|
|
||||||
// retrieve the file data from the request
|
|
||||||
file, err := c.FormFile("file")
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
f, err := file.Open()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer f.Close()
|
|
||||||
|
|
||||||
dir, err := os.MkdirTemp("", "whisper")
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer os.RemoveAll(dir)
|
|
||||||
|
|
||||||
dst := filepath.Join(dir, path.Base(file.Filename))
|
|
||||||
dstFile, err := os.Create(dst)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := io.Copy(dstFile, f); err != nil {
|
|
||||||
log.Debug().Msgf("Audio file copying error %+v - %+v - err %+v", file.Filename, dst, err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debug().Msgf("Audio file copied to: %+v", dst)
|
|
||||||
|
|
||||||
whisperModel, err := o.loader.BackendLoader(
|
|
||||||
model.WithBackendString(model.WhisperBackend),
|
|
||||||
model.WithModelFile(config.Model),
|
|
||||||
model.WithThreads(uint32(config.Threads)),
|
|
||||||
model.WithAssetDir(o.assetsDestination))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if whisperModel == nil {
|
|
||||||
return fmt.Errorf("could not load whisper model")
|
|
||||||
}
|
|
||||||
|
|
||||||
w, ok := whisperModel.(whisper.Model)
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("loader returned non-whisper object")
|
|
||||||
}
|
|
||||||
|
|
||||||
tr, err := whisperutil.Transcript(w, dst, input.Language, uint(config.Threads))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debug().Msgf("Trascribed: %+v", tr)
|
|
||||||
// TODO: handle different outputs here
|
|
||||||
return c.Status(http.StatusOK).JSON(tr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func listModels(loader *model.ModelLoader, cm *ConfigMerger) func(ctx *fiber.Ctx) error {
|
|
||||||
return func(c *fiber.Ctx) error {
|
|
||||||
models, err := loader.ListModels()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
var mm map[string]interface{} = map[string]interface{}{}
|
|
||||||
|
|
||||||
dataModels := []OpenAIModel{}
|
|
||||||
for _, m := range models {
|
|
||||||
mm[m] = nil
|
|
||||||
dataModels = append(dataModels, OpenAIModel{ID: m, Object: "model"})
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, k := range cm.ListConfigs() {
|
|
||||||
if _, exists := mm[k]; !exists {
|
|
||||||
dataModels = append(dataModels, OpenAIModel{ID: k, Object: "model"})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return c.JSON(struct {
|
|
||||||
Object string `json:"object"`
|
|
||||||
Data []OpenAIModel `json:"data"`
|
|
||||||
}{
|
|
||||||
Object: "list",
|
|
||||||
Data: dataModels,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
105
api/openai/api.go
Normal file
105
api/openai/api.go
Normal file
|
@ -0,0 +1,105 @@
|
||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
config "github.com/go-skynet/LocalAI/api/config"
|
||||||
|
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/grammar"
|
||||||
|
)
|
||||||
|
|
||||||
|
// APIError provides error information returned by the OpenAI API.
|
||||||
|
type APIError struct {
|
||||||
|
Code any `json:"code,omitempty"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
Param *string `json:"param,omitempty"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ErrorResponse struct {
|
||||||
|
Error *APIError `json:"error,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type OpenAIUsage struct {
|
||||||
|
PromptTokens int `json:"prompt_tokens"`
|
||||||
|
CompletionTokens int `json:"completion_tokens"`
|
||||||
|
TotalTokens int `json:"total_tokens"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Item struct {
|
||||||
|
Embedding []float32 `json:"embedding"`
|
||||||
|
Index int `json:"index"`
|
||||||
|
Object string `json:"object,omitempty"`
|
||||||
|
|
||||||
|
// Images
|
||||||
|
URL string `json:"url,omitempty"`
|
||||||
|
B64JSON string `json:"b64_json,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type OpenAIResponse struct {
|
||||||
|
Created int `json:"created,omitempty"`
|
||||||
|
Object string `json:"object,omitempty"`
|
||||||
|
ID string `json:"id,omitempty"`
|
||||||
|
Model string `json:"model,omitempty"`
|
||||||
|
Choices []Choice `json:"choices,omitempty"`
|
||||||
|
Data []Item `json:"data,omitempty"`
|
||||||
|
|
||||||
|
Usage OpenAIUsage `json:"usage"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Choice struct {
|
||||||
|
Index int `json:"index,omitempty"`
|
||||||
|
FinishReason string `json:"finish_reason,omitempty"`
|
||||||
|
Message *Message `json:"message,omitempty"`
|
||||||
|
Delta *Message `json:"delta,omitempty"`
|
||||||
|
Text string `json:"text,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Message struct {
|
||||||
|
// The message role
|
||||||
|
Role string `json:"role,omitempty" yaml:"role"`
|
||||||
|
// The message content
|
||||||
|
Content *string `json:"content" yaml:"content"`
|
||||||
|
// A result of a function call
|
||||||
|
FunctionCall interface{} `json:"function_call,omitempty" yaml:"function_call,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type OpenAIModel struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Object string `json:"object"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type OpenAIRequest struct {
|
||||||
|
config.PredictionOptions
|
||||||
|
|
||||||
|
// whisper
|
||||||
|
File string `json:"file" validate:"required"`
|
||||||
|
//whisper/image
|
||||||
|
ResponseFormat string `json:"response_format"`
|
||||||
|
// image
|
||||||
|
Size string `json:"size"`
|
||||||
|
// Prompt is read only by completion/image API calls
|
||||||
|
Prompt interface{} `json:"prompt" yaml:"prompt"`
|
||||||
|
|
||||||
|
// Edit endpoint
|
||||||
|
Instruction string `json:"instruction" yaml:"instruction"`
|
||||||
|
Input interface{} `json:"input" yaml:"input"`
|
||||||
|
|
||||||
|
Stop interface{} `json:"stop" yaml:"stop"`
|
||||||
|
|
||||||
|
// Messages is read only by chat/completion API calls
|
||||||
|
Messages []Message `json:"messages" yaml:"messages"`
|
||||||
|
|
||||||
|
// A list of available functions to call
|
||||||
|
Functions []grammar.Function `json:"functions" yaml:"functions"`
|
||||||
|
FunctionCall interface{} `json:"function_call" yaml:"function_call"` // might be a string or an object
|
||||||
|
|
||||||
|
Stream bool `json:"stream"`
|
||||||
|
|
||||||
|
// Image (not supported by OpenAI)
|
||||||
|
Mode int `json:"mode"`
|
||||||
|
Step int `json:"step"`
|
||||||
|
|
||||||
|
// A grammar to constrain the LLM output
|
||||||
|
Grammar string `json:"grammar" yaml:"grammar"`
|
||||||
|
|
||||||
|
JSONFunctionGrammarObject *grammar.JSONFunctionStructure `json:"grammar_json_functions" yaml:"grammar_json_functions"`
|
||||||
|
}
|
320
api/openai/chat.go
Normal file
320
api/openai/chat.go
Normal file
|
@ -0,0 +1,320 @@
|
||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/go-skynet/LocalAI/api/backend"
|
||||||
|
config "github.com/go-skynet/LocalAI/api/config"
|
||||||
|
"github.com/go-skynet/LocalAI/api/options"
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/grammar"
|
||||||
|
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||||
|
"github.com/gofiber/fiber/v2"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
|
||||||
|
process := func(s string, req *OpenAIRequest, config *config.Config, loader *model.ModelLoader, responses chan OpenAIResponse) {
|
||||||
|
initialMessage := OpenAIResponse{
|
||||||
|
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||||
|
Choices: []Choice{{Delta: &Message{Role: "assistant"}}},
|
||||||
|
Object: "chat.completion.chunk",
|
||||||
|
}
|
||||||
|
responses <- initialMessage
|
||||||
|
|
||||||
|
ComputeChoices(s, req.N, config, o, loader, func(s string, c *[]Choice) {}, func(s string) bool {
|
||||||
|
resp := OpenAIResponse{
|
||||||
|
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||||
|
Choices: []Choice{{Delta: &Message{Content: &s}, Index: 0}},
|
||||||
|
Object: "chat.completion.chunk",
|
||||||
|
}
|
||||||
|
|
||||||
|
responses <- resp
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
close(responses)
|
||||||
|
}
|
||||||
|
return func(c *fiber.Ctx) error {
|
||||||
|
processFunctions := false
|
||||||
|
funcs := grammar.Functions{}
|
||||||
|
model, input, err := readInput(c, o.Loader, true)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
config, input, err := readConfig(model, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||||
|
}
|
||||||
|
log.Debug().Msgf("Configuration read: %+v", config)
|
||||||
|
|
||||||
|
// Allow the user to set custom actions via config file
|
||||||
|
// to be "embedded" in each model
|
||||||
|
noActionName := "answer"
|
||||||
|
noActionDescription := "use this action to answer without performing any action"
|
||||||
|
|
||||||
|
if config.FunctionsConfig.NoActionFunctionName != "" {
|
||||||
|
noActionName = config.FunctionsConfig.NoActionFunctionName
|
||||||
|
}
|
||||||
|
if config.FunctionsConfig.NoActionDescriptionName != "" {
|
||||||
|
noActionDescription = config.FunctionsConfig.NoActionDescriptionName
|
||||||
|
}
|
||||||
|
|
||||||
|
// process functions if we have any defined or if we have a function call string
|
||||||
|
if len(input.Functions) > 0 && config.ShouldUseFunctions() {
|
||||||
|
log.Debug().Msgf("Response needs to process functions")
|
||||||
|
|
||||||
|
processFunctions = true
|
||||||
|
|
||||||
|
noActionGrammar := grammar.Function{
|
||||||
|
Name: noActionName,
|
||||||
|
Description: noActionDescription,
|
||||||
|
Parameters: map[string]interface{}{
|
||||||
|
"properties": map[string]interface{}{
|
||||||
|
"message": map[string]interface{}{
|
||||||
|
"type": "string",
|
||||||
|
"description": "The message to reply the user with",
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Append the no action function
|
||||||
|
funcs = append(funcs, input.Functions...)
|
||||||
|
if !config.FunctionsConfig.DisableNoAction {
|
||||||
|
funcs = append(funcs, noActionGrammar)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Force picking one of the functions by the request
|
||||||
|
if config.FunctionToCall() != "" {
|
||||||
|
funcs = funcs.Select(config.FunctionToCall())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update input grammar
|
||||||
|
jsStruct := funcs.ToJSONStructure()
|
||||||
|
config.Grammar = jsStruct.Grammar("")
|
||||||
|
} else if input.JSONFunctionGrammarObject != nil {
|
||||||
|
config.Grammar = input.JSONFunctionGrammarObject.Grammar("")
|
||||||
|
}
|
||||||
|
|
||||||
|
// functions are not supported in stream mode (yet?)
|
||||||
|
toStream := input.Stream && !processFunctions
|
||||||
|
|
||||||
|
log.Debug().Msgf("Parameters: %+v", config)
|
||||||
|
|
||||||
|
var predInput string
|
||||||
|
|
||||||
|
mess := []string{}
|
||||||
|
for _, i := range input.Messages {
|
||||||
|
var content string
|
||||||
|
role := i.Role
|
||||||
|
// if function call, we might want to customize the role so we can display better that the "assistant called a json action"
|
||||||
|
// if an "assistant_function_call" role is defined, we use it, otherwise we use the role that is passed by in the request
|
||||||
|
if i.FunctionCall != nil && i.Role == "assistant" {
|
||||||
|
roleFn := "assistant_function_call"
|
||||||
|
r := config.Roles[roleFn]
|
||||||
|
if r != "" {
|
||||||
|
role = roleFn
|
||||||
|
}
|
||||||
|
}
|
||||||
|
r := config.Roles[role]
|
||||||
|
contentExists := i.Content != nil && *i.Content != ""
|
||||||
|
if r != "" {
|
||||||
|
if contentExists {
|
||||||
|
content = fmt.Sprint(r, " ", *i.Content)
|
||||||
|
}
|
||||||
|
if i.FunctionCall != nil {
|
||||||
|
j, err := json.Marshal(i.FunctionCall)
|
||||||
|
if err == nil {
|
||||||
|
if contentExists {
|
||||||
|
content += "\n" + fmt.Sprint(r, " ", string(j))
|
||||||
|
} else {
|
||||||
|
content = fmt.Sprint(r, " ", string(j))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if contentExists {
|
||||||
|
content = fmt.Sprint(*i.Content)
|
||||||
|
}
|
||||||
|
if i.FunctionCall != nil {
|
||||||
|
j, err := json.Marshal(i.FunctionCall)
|
||||||
|
if err == nil {
|
||||||
|
if contentExists {
|
||||||
|
content += "\n" + string(j)
|
||||||
|
} else {
|
||||||
|
content = string(j)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
mess = append(mess, content)
|
||||||
|
}
|
||||||
|
|
||||||
|
predInput = strings.Join(mess, "\n")
|
||||||
|
log.Debug().Msgf("Prompt (before templating): %s", predInput)
|
||||||
|
|
||||||
|
if toStream {
|
||||||
|
log.Debug().Msgf("Stream request received")
|
||||||
|
c.Context().SetContentType("text/event-stream")
|
||||||
|
//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
|
||||||
|
// c.Set("Content-Type", "text/event-stream")
|
||||||
|
c.Set("Cache-Control", "no-cache")
|
||||||
|
c.Set("Connection", "keep-alive")
|
||||||
|
c.Set("Transfer-Encoding", "chunked")
|
||||||
|
}
|
||||||
|
|
||||||
|
templateFile := config.Model
|
||||||
|
|
||||||
|
if config.TemplateConfig.Chat != "" && !processFunctions {
|
||||||
|
templateFile = config.TemplateConfig.Chat
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.TemplateConfig.Functions != "" && processFunctions {
|
||||||
|
templateFile = config.TemplateConfig.Functions
|
||||||
|
}
|
||||||
|
|
||||||
|
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||||
|
templatedInput, err := o.Loader.TemplatePrefix(templateFile, struct {
|
||||||
|
Input string
|
||||||
|
Functions []grammar.Function
|
||||||
|
}{
|
||||||
|
Input: predInput,
|
||||||
|
Functions: funcs,
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
predInput = templatedInput
|
||||||
|
log.Debug().Msgf("Template found, input modified to: %s", predInput)
|
||||||
|
} else {
|
||||||
|
log.Debug().Msgf("Template failed loading: %s", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug().Msgf("Prompt (after templating): %s", predInput)
|
||||||
|
if processFunctions {
|
||||||
|
log.Debug().Msgf("Grammar: %+v", config.Grammar)
|
||||||
|
}
|
||||||
|
|
||||||
|
if toStream {
|
||||||
|
responses := make(chan OpenAIResponse)
|
||||||
|
|
||||||
|
go process(predInput, input, config, o.Loader, responses)
|
||||||
|
|
||||||
|
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
|
||||||
|
|
||||||
|
for ev := range responses {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
enc := json.NewEncoder(&buf)
|
||||||
|
enc.Encode(ev)
|
||||||
|
|
||||||
|
log.Debug().Msgf("Sending chunk: %s", buf.String())
|
||||||
|
fmt.Fprintf(w, "data: %v\n", buf.String())
|
||||||
|
w.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := &OpenAIResponse{
|
||||||
|
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||||
|
Choices: []Choice{
|
||||||
|
{
|
||||||
|
FinishReason: "stop",
|
||||||
|
Index: 0,
|
||||||
|
Delta: &Message{},
|
||||||
|
}},
|
||||||
|
Object: "chat.completion.chunk",
|
||||||
|
}
|
||||||
|
respData, _ := json.Marshal(resp)
|
||||||
|
|
||||||
|
w.WriteString(fmt.Sprintf("data: %s\n\n", respData))
|
||||||
|
w.WriteString("data: [DONE]\n\n")
|
||||||
|
w.Flush()
|
||||||
|
}))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := ComputeChoices(predInput, input.N, config, o, o.Loader, func(s string, c *[]Choice) {
|
||||||
|
if processFunctions {
|
||||||
|
// As we have to change the result before processing, we can't stream the answer (yet?)
|
||||||
|
ss := map[string]interface{}{}
|
||||||
|
json.Unmarshal([]byte(s), &ss)
|
||||||
|
log.Debug().Msgf("Function return: %s %+v", s, ss)
|
||||||
|
|
||||||
|
// The grammar defines the function name as "function", while OpenAI returns "name"
|
||||||
|
func_name := ss["function"]
|
||||||
|
// Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object
|
||||||
|
args := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix)
|
||||||
|
d, _ := json.Marshal(args)
|
||||||
|
|
||||||
|
ss["arguments"] = string(d)
|
||||||
|
ss["name"] = func_name
|
||||||
|
|
||||||
|
// if do nothing, reply with a message
|
||||||
|
if func_name == noActionName {
|
||||||
|
log.Debug().Msgf("nothing to do, computing a reply")
|
||||||
|
|
||||||
|
// If there is a message that the LLM already sends as part of the JSON reply, use it
|
||||||
|
arguments := map[string]interface{}{}
|
||||||
|
json.Unmarshal([]byte(d), &arguments)
|
||||||
|
m, exists := arguments["message"]
|
||||||
|
if exists {
|
||||||
|
switch message := m.(type) {
|
||||||
|
case string:
|
||||||
|
if message != "" {
|
||||||
|
log.Debug().Msgf("Reply received from LLM: %s", message)
|
||||||
|
message = backend.Finetune(*config, predInput, message)
|
||||||
|
log.Debug().Msgf("Reply received from LLM(finetuned): %s", message)
|
||||||
|
|
||||||
|
*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &message}})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug().Msgf("No action received from LLM, without a message, computing a reply")
|
||||||
|
// Otherwise ask the LLM to understand the JSON output and the context, and return a message
|
||||||
|
// Note: This costs (in term of CPU) another computation
|
||||||
|
config.Grammar = ""
|
||||||
|
predFunc, err := backend.ModelInference(predInput, o.Loader, *config, o, nil)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Msgf("inference error: %s", err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
prediction, err := predFunc()
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Msgf("inference error: %s", err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
prediction = backend.Finetune(*config, predInput, prediction)
|
||||||
|
*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &prediction}})
|
||||||
|
} else {
|
||||||
|
// otherwise reply with the function call
|
||||||
|
*c = append(*c, Choice{
|
||||||
|
FinishReason: "function_call",
|
||||||
|
Message: &Message{Role: "assistant", FunctionCall: ss},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &s}})
|
||||||
|
}, nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := &OpenAIResponse{
|
||||||
|
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||||
|
Choices: result,
|
||||||
|
Object: "chat.completion",
|
||||||
|
}
|
||||||
|
respData, _ := json.Marshal(resp)
|
||||||
|
log.Debug().Msgf("Response: %s", respData)
|
||||||
|
|
||||||
|
// Return the prediction in the response body
|
||||||
|
return c.JSON(resp)
|
||||||
|
}
|
||||||
|
}
|
159
api/openai/completion.go
Normal file
159
api/openai/completion.go
Normal file
|
@ -0,0 +1,159 @@
|
||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
config "github.com/go-skynet/LocalAI/api/config"
|
||||||
|
"github.com/go-skynet/LocalAI/api/options"
|
||||||
|
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||||
|
"github.com/gofiber/fiber/v2"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
|
)
|
||||||
|
|
||||||
|
// https://platform.openai.com/docs/api-reference/completions
|
||||||
|
func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
|
||||||
|
process := func(s string, req *OpenAIRequest, config *config.Config, loader *model.ModelLoader, responses chan OpenAIResponse) {
|
||||||
|
ComputeChoices(s, req.N, config, o, loader, func(s string, c *[]Choice) {}, func(s string) bool {
|
||||||
|
resp := OpenAIResponse{
|
||||||
|
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||||
|
Choices: []Choice{
|
||||||
|
{
|
||||||
|
Index: 0,
|
||||||
|
Text: s,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Object: "text_completion",
|
||||||
|
}
|
||||||
|
log.Debug().Msgf("Sending goroutine: %s", s)
|
||||||
|
|
||||||
|
responses <- resp
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
close(responses)
|
||||||
|
}
|
||||||
|
|
||||||
|
return func(c *fiber.Ctx) error {
|
||||||
|
model, input, err := readInput(c, o.Loader, true)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug().Msgf("`input`: %+v", input)
|
||||||
|
|
||||||
|
config, input, err := readConfig(model, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug().Msgf("Parameter Config: %+v", config)
|
||||||
|
|
||||||
|
if input.Stream {
|
||||||
|
log.Debug().Msgf("Stream request received")
|
||||||
|
c.Context().SetContentType("text/event-stream")
|
||||||
|
//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
|
||||||
|
//c.Set("Content-Type", "text/event-stream")
|
||||||
|
c.Set("Cache-Control", "no-cache")
|
||||||
|
c.Set("Connection", "keep-alive")
|
||||||
|
c.Set("Transfer-Encoding", "chunked")
|
||||||
|
}
|
||||||
|
|
||||||
|
templateFile := config.Model
|
||||||
|
|
||||||
|
if config.TemplateConfig.Completion != "" {
|
||||||
|
templateFile = config.TemplateConfig.Completion
|
||||||
|
}
|
||||||
|
|
||||||
|
if input.Stream {
|
||||||
|
if len(config.PromptStrings) > 1 {
|
||||||
|
return errors.New("cannot handle more than 1 `PromptStrings` when Streaming")
|
||||||
|
}
|
||||||
|
|
||||||
|
predInput := config.PromptStrings[0]
|
||||||
|
|
||||||
|
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||||
|
templatedInput, err := o.Loader.TemplatePrefix(templateFile, struct {
|
||||||
|
Input string
|
||||||
|
}{
|
||||||
|
Input: predInput,
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
predInput = templatedInput
|
||||||
|
log.Debug().Msgf("Template found, input modified to: %s", predInput)
|
||||||
|
}
|
||||||
|
|
||||||
|
responses := make(chan OpenAIResponse)
|
||||||
|
|
||||||
|
go process(predInput, input, config, o.Loader, responses)
|
||||||
|
|
||||||
|
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
|
||||||
|
|
||||||
|
for ev := range responses {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
enc := json.NewEncoder(&buf)
|
||||||
|
enc.Encode(ev)
|
||||||
|
|
||||||
|
log.Debug().Msgf("Sending chunk: %s", buf.String())
|
||||||
|
fmt.Fprintf(w, "data: %v\n", buf.String())
|
||||||
|
w.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := &OpenAIResponse{
|
||||||
|
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||||
|
Choices: []Choice{
|
||||||
|
{
|
||||||
|
Index: 0,
|
||||||
|
FinishReason: "stop",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Object: "text_completion",
|
||||||
|
}
|
||||||
|
respData, _ := json.Marshal(resp)
|
||||||
|
|
||||||
|
w.WriteString(fmt.Sprintf("data: %s\n\n", respData))
|
||||||
|
w.WriteString("data: [DONE]\n\n")
|
||||||
|
w.Flush()
|
||||||
|
}))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var result []Choice
|
||||||
|
for _, i := range config.PromptStrings {
|
||||||
|
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||||
|
templatedInput, err := o.Loader.TemplatePrefix(templateFile, struct {
|
||||||
|
Input string
|
||||||
|
}{
|
||||||
|
Input: i,
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
i = templatedInput
|
||||||
|
log.Debug().Msgf("Template found, input modified to: %s", i)
|
||||||
|
}
|
||||||
|
|
||||||
|
r, err := ComputeChoices(i, input.N, config, o, o.Loader, func(s string, c *[]Choice) {
|
||||||
|
*c = append(*c, Choice{Text: s})
|
||||||
|
}, nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
result = append(result, r...)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := &OpenAIResponse{
|
||||||
|
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||||
|
Choices: result,
|
||||||
|
Object: "text_completion",
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonResult, _ := json.Marshal(resp)
|
||||||
|
log.Debug().Msgf("Response: %s", jsonResult)
|
||||||
|
|
||||||
|
// Return the prediction in the response body
|
||||||
|
return c.JSON(resp)
|
||||||
|
}
|
||||||
|
}
|
67
api/openai/edit.go
Normal file
67
api/openai/edit.go
Normal file
|
@ -0,0 +1,67 @@
|
||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
config "github.com/go-skynet/LocalAI/api/config"
|
||||||
|
"github.com/go-skynet/LocalAI/api/options"
|
||||||
|
"github.com/gofiber/fiber/v2"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
func EditEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
|
||||||
|
return func(c *fiber.Ctx) error {
|
||||||
|
model, input, err := readInput(c, o.Loader, true)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
config, input, err := readConfig(model, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug().Msgf("Parameter Config: %+v", config)
|
||||||
|
|
||||||
|
templateFile := config.Model
|
||||||
|
|
||||||
|
if config.TemplateConfig.Edit != "" {
|
||||||
|
templateFile = config.TemplateConfig.Edit
|
||||||
|
}
|
||||||
|
|
||||||
|
var result []Choice
|
||||||
|
for _, i := range config.InputStrings {
|
||||||
|
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||||
|
templatedInput, err := o.Loader.TemplatePrefix(templateFile, struct {
|
||||||
|
Input string
|
||||||
|
Instruction string
|
||||||
|
}{Input: i})
|
||||||
|
if err == nil {
|
||||||
|
i = templatedInput
|
||||||
|
log.Debug().Msgf("Template found, input modified to: %s", i)
|
||||||
|
}
|
||||||
|
|
||||||
|
r, err := ComputeChoices(i, input.N, config, o, o.Loader, func(s string, c *[]Choice) {
|
||||||
|
*c = append(*c, Choice{Text: s})
|
||||||
|
}, nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
result = append(result, r...)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := &OpenAIResponse{
|
||||||
|
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||||
|
Choices: result,
|
||||||
|
Object: "edit",
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonResult, _ := json.Marshal(resp)
|
||||||
|
log.Debug().Msgf("Response: %s", jsonResult)
|
||||||
|
|
||||||
|
// Return the prediction in the response body
|
||||||
|
return c.JSON(resp)
|
||||||
|
}
|
||||||
|
}
|
70
api/openai/embeddings.go
Normal file
70
api/openai/embeddings.go
Normal file
|
@ -0,0 +1,70 @@
|
||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/go-skynet/LocalAI/api/backend"
|
||||||
|
config "github.com/go-skynet/LocalAI/api/config"
|
||||||
|
"github.com/go-skynet/LocalAI/api/options"
|
||||||
|
"github.com/gofiber/fiber/v2"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
// https://platform.openai.com/docs/api-reference/embeddings
|
||||||
|
func EmbeddingsEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
|
||||||
|
return func(c *fiber.Ctx) error {
|
||||||
|
model, input, err := readInput(c, o.Loader, true)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
config, input, err := readConfig(model, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug().Msgf("Parameter Config: %+v", config)
|
||||||
|
items := []Item{}
|
||||||
|
|
||||||
|
for i, s := range config.InputToken {
|
||||||
|
// get the model function to call for the result
|
||||||
|
embedFn, err := backend.ModelEmbedding("", s, o.Loader, *config, o)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
embeddings, err := embedFn()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
items = append(items, Item{Embedding: embeddings, Index: i, Object: "embedding"})
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, s := range config.InputStrings {
|
||||||
|
// get the model function to call for the result
|
||||||
|
embedFn, err := backend.ModelEmbedding(s, []int{}, o.Loader, *config, o)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
embeddings, err := embedFn()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
items = append(items, Item{Embedding: embeddings, Index: i, Object: "embedding"})
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := &OpenAIResponse{
|
||||||
|
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||||
|
Data: items,
|
||||||
|
Object: "list",
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonResult, _ := json.Marshal(resp)
|
||||||
|
log.Debug().Msgf("Response: %s", jsonResult)
|
||||||
|
|
||||||
|
// Return the prediction in the response body
|
||||||
|
return c.JSON(resp)
|
||||||
|
}
|
||||||
|
}
|
158
api/openai/image.go
Normal file
158
api/openai/image.go
Normal file
|
@ -0,0 +1,158 @@
|
||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/go-skynet/LocalAI/api/backend"
|
||||||
|
config "github.com/go-skynet/LocalAI/api/config"
|
||||||
|
"github.com/go-skynet/LocalAI/api/options"
|
||||||
|
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||||
|
"github.com/gofiber/fiber/v2"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
// https://platform.openai.com/docs/api-reference/images/create
|
||||||
|
|
||||||
|
/*
|
||||||
|
*
|
||||||
|
|
||||||
|
curl http://localhost:8080/v1/images/generations \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{
|
||||||
|
"prompt": "A cute baby sea otter",
|
||||||
|
"n": 1,
|
||||||
|
"size": "512x512"
|
||||||
|
}'
|
||||||
|
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
func ImageEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
|
||||||
|
return func(c *fiber.Ctx) error {
|
||||||
|
m, input, err := readInput(c, o.Loader, false)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if m == "" {
|
||||||
|
m = model.StableDiffusionBackend
|
||||||
|
}
|
||||||
|
log.Debug().Msgf("Loading model: %+v", m)
|
||||||
|
|
||||||
|
config, input, err := readConfig(m, input, cm, o.Loader, o.Debug, 0, 0, false)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug().Msgf("Parameter Config: %+v", config)
|
||||||
|
|
||||||
|
// XXX: Only stablediffusion is supported for now
|
||||||
|
if config.Backend == "" {
|
||||||
|
config.Backend = model.StableDiffusionBackend
|
||||||
|
}
|
||||||
|
|
||||||
|
sizeParts := strings.Split(input.Size, "x")
|
||||||
|
if len(sizeParts) != 2 {
|
||||||
|
return fmt.Errorf("Invalid value for 'size'")
|
||||||
|
}
|
||||||
|
width, err := strconv.Atoi(sizeParts[0])
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Invalid value for 'size'")
|
||||||
|
}
|
||||||
|
height, err := strconv.Atoi(sizeParts[1])
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Invalid value for 'size'")
|
||||||
|
}
|
||||||
|
|
||||||
|
b64JSON := false
|
||||||
|
if input.ResponseFormat == "b64_json" {
|
||||||
|
b64JSON = true
|
||||||
|
}
|
||||||
|
|
||||||
|
var result []Item
|
||||||
|
for _, i := range config.PromptStrings {
|
||||||
|
n := input.N
|
||||||
|
if input.N == 0 {
|
||||||
|
n = 1
|
||||||
|
}
|
||||||
|
for j := 0; j < n; j++ {
|
||||||
|
prompts := strings.Split(i, "|")
|
||||||
|
positive_prompt := prompts[0]
|
||||||
|
negative_prompt := ""
|
||||||
|
if len(prompts) > 1 {
|
||||||
|
negative_prompt = prompts[1]
|
||||||
|
}
|
||||||
|
|
||||||
|
mode := 0
|
||||||
|
step := 15
|
||||||
|
|
||||||
|
if input.Mode != 0 {
|
||||||
|
mode = input.Mode
|
||||||
|
}
|
||||||
|
|
||||||
|
if input.Step != 0 {
|
||||||
|
step = input.Step
|
||||||
|
}
|
||||||
|
|
||||||
|
tempDir := ""
|
||||||
|
if !b64JSON {
|
||||||
|
tempDir = o.ImageDir
|
||||||
|
}
|
||||||
|
// Create a temporary file
|
||||||
|
outputFile, err := ioutil.TempFile(tempDir, "b64")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
outputFile.Close()
|
||||||
|
output := outputFile.Name() + ".png"
|
||||||
|
// Rename the temporary file
|
||||||
|
err = os.Rename(outputFile.Name(), output)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
baseURL := c.BaseURL()
|
||||||
|
|
||||||
|
fn, err := backend.ImageGeneration(height, width, mode, step, input.Seed, positive_prompt, negative_prompt, output, o.Loader, *config, o)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := fn(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
item := &Item{}
|
||||||
|
|
||||||
|
if b64JSON {
|
||||||
|
defer os.RemoveAll(output)
|
||||||
|
data, err := os.ReadFile(output)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
item.B64JSON = base64.StdEncoding.EncodeToString(data)
|
||||||
|
} else {
|
||||||
|
base := filepath.Base(output)
|
||||||
|
item.URL = baseURL + "/generated-images/" + base
|
||||||
|
}
|
||||||
|
|
||||||
|
result = append(result, *item)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := &OpenAIResponse{
|
||||||
|
Data: result,
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonResult, _ := json.Marshal(resp)
|
||||||
|
log.Debug().Msgf("Response: %s", jsonResult)
|
||||||
|
|
||||||
|
// Return the prediction in the response body
|
||||||
|
return c.JSON(resp)
|
||||||
|
}
|
||||||
|
}
|
36
api/openai/inference.go
Normal file
36
api/openai/inference.go
Normal file
|
@ -0,0 +1,36 @@
|
||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/go-skynet/LocalAI/api/backend"
|
||||||
|
config "github.com/go-skynet/LocalAI/api/config"
|
||||||
|
"github.com/go-skynet/LocalAI/api/options"
|
||||||
|
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ComputeChoices(predInput string, n int, config *config.Config, o *options.Option, loader *model.ModelLoader, cb func(string, *[]Choice), tokenCallback func(string) bool) ([]Choice, error) {
|
||||||
|
result := []Choice{}
|
||||||
|
|
||||||
|
if n == 0 {
|
||||||
|
n = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
// get the model function to call for the result
|
||||||
|
predFunc, err := backend.ModelInference(predInput, loader, *config, o, tokenCallback)
|
||||||
|
if err != nil {
|
||||||
|
return result, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
prediction, err := predFunc()
|
||||||
|
if err != nil {
|
||||||
|
return result, err
|
||||||
|
}
|
||||||
|
|
||||||
|
prediction = backend.Finetune(*config, predInput, prediction)
|
||||||
|
cb(prediction, &result)
|
||||||
|
|
||||||
|
//result = append(result, Choice{Text: prediction})
|
||||||
|
|
||||||
|
}
|
||||||
|
return result, err
|
||||||
|
}
|
37
api/openai/list.go
Normal file
37
api/openai/list.go
Normal file
|
@ -0,0 +1,37 @@
|
||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
config "github.com/go-skynet/LocalAI/api/config"
|
||||||
|
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||||
|
"github.com/gofiber/fiber/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ListModelsEndpoint(loader *model.ModelLoader, cm *config.ConfigLoader) func(ctx *fiber.Ctx) error {
|
||||||
|
return func(c *fiber.Ctx) error {
|
||||||
|
models, err := loader.ListModels()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
var mm map[string]interface{} = map[string]interface{}{}
|
||||||
|
|
||||||
|
dataModels := []OpenAIModel{}
|
||||||
|
for _, m := range models {
|
||||||
|
mm[m] = nil
|
||||||
|
dataModels = append(dataModels, OpenAIModel{ID: m, Object: "model"})
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, k := range cm.ListConfigs() {
|
||||||
|
if _, exists := mm[k]; !exists {
|
||||||
|
dataModels = append(dataModels, OpenAIModel{ID: k, Object: "model"})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.JSON(struct {
|
||||||
|
Object string `json:"object"`
|
||||||
|
Data []OpenAIModel `json:"data"`
|
||||||
|
}{
|
||||||
|
Object: "list",
|
||||||
|
Data: dataModels,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
234
api/openai/request.go
Normal file
234
api/openai/request.go
Normal file
|
@ -0,0 +1,234 @@
|
||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
config "github.com/go-skynet/LocalAI/api/config"
|
||||||
|
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||||
|
"github.com/gofiber/fiber/v2"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
func readInput(c *fiber.Ctx, loader *model.ModelLoader, randomModel bool) (string, *OpenAIRequest, error) {
|
||||||
|
input := new(OpenAIRequest)
|
||||||
|
// Get input data from the request body
|
||||||
|
if err := c.BodyParser(input); err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
modelFile := input.Model
|
||||||
|
|
||||||
|
if c.Params("model") != "" {
|
||||||
|
modelFile = c.Params("model")
|
||||||
|
}
|
||||||
|
|
||||||
|
received, _ := json.Marshal(input)
|
||||||
|
|
||||||
|
log.Debug().Msgf("Request received: %s", string(received))
|
||||||
|
|
||||||
|
// Set model from bearer token, if available
|
||||||
|
bearer := strings.TrimLeft(c.Get("authorization"), "Bearer ")
|
||||||
|
bearerExists := bearer != "" && loader.ExistsInModelPath(bearer)
|
||||||
|
|
||||||
|
// If no model was specified, take the first available
|
||||||
|
if modelFile == "" && !bearerExists && randomModel {
|
||||||
|
models, _ := loader.ListModels()
|
||||||
|
if len(models) > 0 {
|
||||||
|
modelFile = models[0]
|
||||||
|
log.Debug().Msgf("No model specified, using: %s", modelFile)
|
||||||
|
} else {
|
||||||
|
log.Debug().Msgf("No model specified, returning error")
|
||||||
|
return "", nil, fmt.Errorf("no model specified")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If a model is found in bearer token takes precedence
|
||||||
|
if bearerExists {
|
||||||
|
log.Debug().Msgf("Using model from bearer token: %s", bearer)
|
||||||
|
modelFile = bearer
|
||||||
|
}
|
||||||
|
return modelFile, input, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func updateConfig(config *config.Config, input *OpenAIRequest) {
|
||||||
|
if input.Echo {
|
||||||
|
config.Echo = input.Echo
|
||||||
|
}
|
||||||
|
if input.TopK != 0 {
|
||||||
|
config.TopK = input.TopK
|
||||||
|
}
|
||||||
|
if input.TopP != 0 {
|
||||||
|
config.TopP = input.TopP
|
||||||
|
}
|
||||||
|
|
||||||
|
if input.Grammar != "" {
|
||||||
|
config.Grammar = input.Grammar
|
||||||
|
}
|
||||||
|
|
||||||
|
if input.Temperature != 0 {
|
||||||
|
config.Temperature = input.Temperature
|
||||||
|
}
|
||||||
|
|
||||||
|
if input.Maxtokens != 0 {
|
||||||
|
config.Maxtokens = input.Maxtokens
|
||||||
|
}
|
||||||
|
|
||||||
|
switch stop := input.Stop.(type) {
|
||||||
|
case string:
|
||||||
|
if stop != "" {
|
||||||
|
config.StopWords = append(config.StopWords, stop)
|
||||||
|
}
|
||||||
|
case []interface{}:
|
||||||
|
for _, pp := range stop {
|
||||||
|
if s, ok := pp.(string); ok {
|
||||||
|
config.StopWords = append(config.StopWords, s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if input.RepeatPenalty != 0 {
|
||||||
|
config.RepeatPenalty = input.RepeatPenalty
|
||||||
|
}
|
||||||
|
|
||||||
|
if input.Keep != 0 {
|
||||||
|
config.Keep = input.Keep
|
||||||
|
}
|
||||||
|
|
||||||
|
if input.Batch != 0 {
|
||||||
|
config.Batch = input.Batch
|
||||||
|
}
|
||||||
|
|
||||||
|
if input.F16 {
|
||||||
|
config.F16 = input.F16
|
||||||
|
}
|
||||||
|
|
||||||
|
if input.IgnoreEOS {
|
||||||
|
config.IgnoreEOS = input.IgnoreEOS
|
||||||
|
}
|
||||||
|
|
||||||
|
if input.Seed != 0 {
|
||||||
|
config.Seed = input.Seed
|
||||||
|
}
|
||||||
|
|
||||||
|
if input.Mirostat != 0 {
|
||||||
|
config.Mirostat = input.Mirostat
|
||||||
|
}
|
||||||
|
|
||||||
|
if input.MirostatETA != 0 {
|
||||||
|
config.MirostatETA = input.MirostatETA
|
||||||
|
}
|
||||||
|
|
||||||
|
if input.MirostatTAU != 0 {
|
||||||
|
config.MirostatTAU = input.MirostatTAU
|
||||||
|
}
|
||||||
|
|
||||||
|
if input.TypicalP != 0 {
|
||||||
|
config.TypicalP = input.TypicalP
|
||||||
|
}
|
||||||
|
|
||||||
|
switch inputs := input.Input.(type) {
|
||||||
|
case string:
|
||||||
|
if inputs != "" {
|
||||||
|
config.InputStrings = append(config.InputStrings, inputs)
|
||||||
|
}
|
||||||
|
case []interface{}:
|
||||||
|
for _, pp := range inputs {
|
||||||
|
switch i := pp.(type) {
|
||||||
|
case string:
|
||||||
|
config.InputStrings = append(config.InputStrings, i)
|
||||||
|
case []interface{}:
|
||||||
|
tokens := []int{}
|
||||||
|
for _, ii := range i {
|
||||||
|
tokens = append(tokens, int(ii.(float64)))
|
||||||
|
}
|
||||||
|
config.InputToken = append(config.InputToken, tokens)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Can be either a string or an object
|
||||||
|
switch fnc := input.FunctionCall.(type) {
|
||||||
|
case string:
|
||||||
|
if fnc != "" {
|
||||||
|
config.SetFunctionCallString(fnc)
|
||||||
|
}
|
||||||
|
case map[string]interface{}:
|
||||||
|
var name string
|
||||||
|
n, exists := fnc["name"]
|
||||||
|
if exists {
|
||||||
|
nn, e := n.(string)
|
||||||
|
if !e {
|
||||||
|
name = nn
|
||||||
|
}
|
||||||
|
}
|
||||||
|
config.SetFunctionCallNameString(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch p := input.Prompt.(type) {
|
||||||
|
case string:
|
||||||
|
config.PromptStrings = append(config.PromptStrings, p)
|
||||||
|
case []interface{}:
|
||||||
|
for _, pp := range p {
|
||||||
|
if s, ok := pp.(string); ok {
|
||||||
|
config.PromptStrings = append(config.PromptStrings, s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func readConfig(modelFile string, input *OpenAIRequest, cm *config.ConfigLoader, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*config.Config, *OpenAIRequest, error) {
|
||||||
|
// Load a config file if present after the model name
|
||||||
|
modelConfig := filepath.Join(loader.ModelPath, modelFile+".yaml")
|
||||||
|
|
||||||
|
var cfg *config.Config
|
||||||
|
|
||||||
|
defaults := func() {
|
||||||
|
cfg = config.DefaultConfig(modelFile)
|
||||||
|
cfg.ContextSize = ctx
|
||||||
|
cfg.Threads = threads
|
||||||
|
cfg.F16 = f16
|
||||||
|
cfg.Debug = debug
|
||||||
|
}
|
||||||
|
|
||||||
|
cfgExisting, exists := cm.GetConfig(modelFile)
|
||||||
|
if !exists {
|
||||||
|
if _, err := os.Stat(modelConfig); err == nil {
|
||||||
|
if err := cm.LoadConfig(modelConfig); err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error())
|
||||||
|
}
|
||||||
|
cfgExisting, exists = cm.GetConfig(modelFile)
|
||||||
|
if exists {
|
||||||
|
cfg = &cfgExisting
|
||||||
|
} else {
|
||||||
|
defaults()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
defaults()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
cfg = &cfgExisting
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set the parameters for the language model prediction
|
||||||
|
updateConfig(cfg, input)
|
||||||
|
|
||||||
|
// Don't allow 0 as setting
|
||||||
|
if cfg.Threads == 0 {
|
||||||
|
if threads != 0 {
|
||||||
|
cfg.Threads = threads
|
||||||
|
} else {
|
||||||
|
cfg.Threads = 4
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Enforce debug flag if passed from CLI
|
||||||
|
if debug {
|
||||||
|
cfg.Debug = true
|
||||||
|
}
|
||||||
|
|
||||||
|
return cfg, input, nil
|
||||||
|
}
|
91
api/openai/transcription.go
Normal file
91
api/openai/transcription.go
Normal file
|
@ -0,0 +1,91 @@
|
||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"path"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
"github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
|
||||||
|
config "github.com/go-skynet/LocalAI/api/config"
|
||||||
|
"github.com/go-skynet/LocalAI/api/options"
|
||||||
|
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||||
|
whisperutil "github.com/go-skynet/LocalAI/pkg/whisper"
|
||||||
|
|
||||||
|
"github.com/gofiber/fiber/v2"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
// https://platform.openai.com/docs/api-reference/audio/create
|
||||||
|
func TranscriptEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
|
||||||
|
return func(c *fiber.Ctx) error {
|
||||||
|
m, input, err := readInput(c, o.Loader, false)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
config, input, err := readConfig(m, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||||
|
}
|
||||||
|
// retrieve the file data from the request
|
||||||
|
file, err := c.FormFile("file")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
f, err := file.Open()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
dir, err := os.MkdirTemp("", "whisper")
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer os.RemoveAll(dir)
|
||||||
|
|
||||||
|
dst := filepath.Join(dir, path.Base(file.Filename))
|
||||||
|
dstFile, err := os.Create(dst)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := io.Copy(dstFile, f); err != nil {
|
||||||
|
log.Debug().Msgf("Audio file copying error %+v - %+v - err %+v", file.Filename, dst, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug().Msgf("Audio file copied to: %+v", dst)
|
||||||
|
|
||||||
|
whisperModel, err := o.Loader.BackendLoader(
|
||||||
|
model.WithBackendString(model.WhisperBackend),
|
||||||
|
model.WithModelFile(config.Model),
|
||||||
|
model.WithThreads(uint32(config.Threads)),
|
||||||
|
model.WithAssetDir(o.AssetsDestination))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if whisperModel == nil {
|
||||||
|
return fmt.Errorf("could not load whisper model")
|
||||||
|
}
|
||||||
|
|
||||||
|
w, ok := whisperModel.(whisper.Model)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("loader returned non-whisper object")
|
||||||
|
}
|
||||||
|
|
||||||
|
tr, err := whisperutil.Transcript(w, dst, input.Language, uint(config.Threads))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug().Msgf("Trascribed: %+v", tr)
|
||||||
|
// TODO: handle different outputs here
|
||||||
|
return c.Status(http.StatusOK).JSON(fiber.Map{"text": tr})
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,4 +1,4 @@
|
||||||
package api
|
package options
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
@ -11,35 +11,35 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type Option struct {
|
type Option struct {
|
||||||
context context.Context
|
Context context.Context
|
||||||
configFile string
|
ConfigFile string
|
||||||
loader *model.ModelLoader
|
Loader *model.ModelLoader
|
||||||
uploadLimitMB, threads, ctxSize int
|
UploadLimitMB, Threads, ContextSize int
|
||||||
f16 bool
|
F16 bool
|
||||||
debug, disableMessage bool
|
Debug, DisableMessage bool
|
||||||
imageDir string
|
ImageDir string
|
||||||
audioDir string
|
AudioDir string
|
||||||
cors bool
|
CORS bool
|
||||||
preloadJSONModels string
|
PreloadJSONModels string
|
||||||
preloadModelsFromPath string
|
PreloadModelsFromPath string
|
||||||
corsAllowOrigins string
|
CORSAllowOrigins string
|
||||||
|
|
||||||
galleries []gallery.Gallery
|
Galleries []gallery.Gallery
|
||||||
|
|
||||||
backendAssets embed.FS
|
BackendAssets embed.FS
|
||||||
assetsDestination string
|
AssetsDestination string
|
||||||
}
|
}
|
||||||
|
|
||||||
type AppOption func(*Option)
|
type AppOption func(*Option)
|
||||||
|
|
||||||
func newOptions(o ...AppOption) *Option {
|
func NewOptions(o ...AppOption) *Option {
|
||||||
opt := &Option{
|
opt := &Option{
|
||||||
context: context.Background(),
|
Context: context.Background(),
|
||||||
uploadLimitMB: 15,
|
UploadLimitMB: 15,
|
||||||
threads: 1,
|
Threads: 1,
|
||||||
ctxSize: 512,
|
ContextSize: 512,
|
||||||
debug: true,
|
Debug: true,
|
||||||
disableMessage: true,
|
DisableMessage: true,
|
||||||
}
|
}
|
||||||
for _, oo := range o {
|
for _, oo := range o {
|
||||||
oo(opt)
|
oo(opt)
|
||||||
|
@ -49,25 +49,25 @@ func newOptions(o ...AppOption) *Option {
|
||||||
|
|
||||||
func WithCors(b bool) AppOption {
|
func WithCors(b bool) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *Option) {
|
||||||
o.cors = b
|
o.CORS = b
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithCorsAllowOrigins(b string) AppOption {
|
func WithCorsAllowOrigins(b string) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *Option) {
|
||||||
o.corsAllowOrigins = b
|
o.CORSAllowOrigins = b
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithBackendAssetsOutput(out string) AppOption {
|
func WithBackendAssetsOutput(out string) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *Option) {
|
||||||
o.assetsDestination = out
|
o.AssetsDestination = out
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithBackendAssets(f embed.FS) AppOption {
|
func WithBackendAssets(f embed.FS) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *Option) {
|
||||||
o.backendAssets = f
|
o.BackendAssets = f
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -81,89 +81,89 @@ func WithStringGalleries(galls string) AppOption {
|
||||||
if err := json.Unmarshal([]byte(galls), &galleries); err != nil {
|
if err := json.Unmarshal([]byte(galls), &galleries); err != nil {
|
||||||
log.Error().Msgf("failed loading galleries: %s", err.Error())
|
log.Error().Msgf("failed loading galleries: %s", err.Error())
|
||||||
}
|
}
|
||||||
o.galleries = append(o.galleries, galleries...)
|
o.Galleries = append(o.Galleries, galleries...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithGalleries(galleries []gallery.Gallery) AppOption {
|
func WithGalleries(galleries []gallery.Gallery) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *Option) {
|
||||||
o.galleries = append(o.galleries, galleries...)
|
o.Galleries = append(o.Galleries, galleries...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithContext(ctx context.Context) AppOption {
|
func WithContext(ctx context.Context) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *Option) {
|
||||||
o.context = ctx
|
o.Context = ctx
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithYAMLConfigPreload(configFile string) AppOption {
|
func WithYAMLConfigPreload(configFile string) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *Option) {
|
||||||
o.preloadModelsFromPath = configFile
|
o.PreloadModelsFromPath = configFile
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithJSONStringPreload(configFile string) AppOption {
|
func WithJSONStringPreload(configFile string) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *Option) {
|
||||||
o.preloadJSONModels = configFile
|
o.PreloadJSONModels = configFile
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
func WithConfigFile(configFile string) AppOption {
|
func WithConfigFile(configFile string) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *Option) {
|
||||||
o.configFile = configFile
|
o.ConfigFile = configFile
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithModelLoader(loader *model.ModelLoader) AppOption {
|
func WithModelLoader(loader *model.ModelLoader) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *Option) {
|
||||||
o.loader = loader
|
o.Loader = loader
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithUploadLimitMB(limit int) AppOption {
|
func WithUploadLimitMB(limit int) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *Option) {
|
||||||
o.uploadLimitMB = limit
|
o.UploadLimitMB = limit
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithThreads(threads int) AppOption {
|
func WithThreads(threads int) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *Option) {
|
||||||
o.threads = threads
|
o.Threads = threads
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithContextSize(ctxSize int) AppOption {
|
func WithContextSize(ctxSize int) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *Option) {
|
||||||
o.ctxSize = ctxSize
|
o.ContextSize = ctxSize
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithF16(f16 bool) AppOption {
|
func WithF16(f16 bool) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *Option) {
|
||||||
o.f16 = f16
|
o.F16 = f16
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithDebug(debug bool) AppOption {
|
func WithDebug(debug bool) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *Option) {
|
||||||
o.debug = debug
|
o.Debug = debug
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithDisableMessage(disableMessage bool) AppOption {
|
func WithDisableMessage(disableMessage bool) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *Option) {
|
||||||
o.disableMessage = disableMessage
|
o.DisableMessage = disableMessage
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithAudioDir(audioDir string) AppOption {
|
func WithAudioDir(audioDir string) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *Option) {
|
||||||
o.audioDir = audioDir
|
o.AudioDir = audioDir
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithImageDir(imageDir string) AppOption {
|
func WithImageDir(imageDir string) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *Option) {
|
||||||
o.imageDir = imageDir
|
o.ImageDir = imageDir
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -1,415 +0,0 @@
|
||||||
package api
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"regexp"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/donomii/go-rwkv.cpp"
|
|
||||||
"github.com/go-skynet/LocalAI/pkg/grpc"
|
|
||||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
|
||||||
"github.com/go-skynet/LocalAI/pkg/langchain"
|
|
||||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
|
||||||
"github.com/go-skynet/LocalAI/pkg/stablediffusion"
|
|
||||||
"github.com/go-skynet/bloomz.cpp"
|
|
||||||
bert "github.com/go-skynet/go-bert.cpp"
|
|
||||||
)
|
|
||||||
|
|
||||||
// mutex still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
|
|
||||||
var mutexMap sync.Mutex
|
|
||||||
var mutexes map[string]*sync.Mutex = make(map[string]*sync.Mutex)
|
|
||||||
|
|
||||||
func gRPCModelOpts(c Config) *pb.ModelOptions {
|
|
||||||
b := 512
|
|
||||||
if c.Batch != 0 {
|
|
||||||
b = c.Batch
|
|
||||||
}
|
|
||||||
return &pb.ModelOptions{
|
|
||||||
ContextSize: int32(c.ContextSize),
|
|
||||||
Seed: int32(c.Seed),
|
|
||||||
NBatch: int32(b),
|
|
||||||
F16Memory: c.F16,
|
|
||||||
MLock: c.MMlock,
|
|
||||||
NUMA: c.NUMA,
|
|
||||||
Embeddings: c.Embeddings,
|
|
||||||
LowVRAM: c.LowVRAM,
|
|
||||||
NGPULayers: int32(c.NGPULayers),
|
|
||||||
MMap: c.MMap,
|
|
||||||
MainGPU: c.MainGPU,
|
|
||||||
Threads: int32(c.Threads),
|
|
||||||
TensorSplit: c.TensorSplit,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func gRPCPredictOpts(c Config, modelPath string) *pb.PredictOptions {
|
|
||||||
promptCachePath := ""
|
|
||||||
if c.PromptCachePath != "" {
|
|
||||||
p := filepath.Join(modelPath, c.PromptCachePath)
|
|
||||||
os.MkdirAll(filepath.Dir(p), 0755)
|
|
||||||
promptCachePath = p
|
|
||||||
}
|
|
||||||
return &pb.PredictOptions{
|
|
||||||
Temperature: float32(c.Temperature),
|
|
||||||
TopP: float32(c.TopP),
|
|
||||||
TopK: int32(c.TopK),
|
|
||||||
Tokens: int32(c.Maxtokens),
|
|
||||||
Threads: int32(c.Threads),
|
|
||||||
PromptCacheAll: c.PromptCacheAll,
|
|
||||||
PromptCacheRO: c.PromptCacheRO,
|
|
||||||
PromptCachePath: promptCachePath,
|
|
||||||
F16KV: c.F16,
|
|
||||||
DebugMode: c.Debug,
|
|
||||||
Grammar: c.Grammar,
|
|
||||||
|
|
||||||
Mirostat: int32(c.Mirostat),
|
|
||||||
MirostatETA: float32(c.MirostatETA),
|
|
||||||
MirostatTAU: float32(c.MirostatTAU),
|
|
||||||
Debug: c.Debug,
|
|
||||||
StopPrompts: c.StopWords,
|
|
||||||
Repeat: int32(c.RepeatPenalty),
|
|
||||||
NKeep: int32(c.Keep),
|
|
||||||
Batch: int32(c.Batch),
|
|
||||||
IgnoreEOS: c.IgnoreEOS,
|
|
||||||
Seed: int32(c.Seed),
|
|
||||||
FrequencyPenalty: float32(c.FrequencyPenalty),
|
|
||||||
MLock: c.MMlock,
|
|
||||||
MMap: c.MMap,
|
|
||||||
MainGPU: c.MainGPU,
|
|
||||||
TensorSplit: c.TensorSplit,
|
|
||||||
TailFreeSamplingZ: float32(c.TFZ),
|
|
||||||
TypicalP: float32(c.TypicalP),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, dst string, loader *model.ModelLoader, c Config, o *Option) (func() error, error) {
|
|
||||||
if c.Backend != model.StableDiffusionBackend {
|
|
||||||
return nil, fmt.Errorf("endpoint only working with stablediffusion models")
|
|
||||||
}
|
|
||||||
|
|
||||||
inferenceModel, err := loader.BackendLoader(
|
|
||||||
model.WithBackendString(c.Backend),
|
|
||||||
model.WithAssetDir(o.assetsDestination),
|
|
||||||
model.WithThreads(uint32(c.Threads)),
|
|
||||||
model.WithModelFile(c.ImageGenerationAssets),
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var fn func() error
|
|
||||||
switch model := inferenceModel.(type) {
|
|
||||||
case *stablediffusion.StableDiffusion:
|
|
||||||
fn = func() error {
|
|
||||||
return model.GenerateImage(height, width, mode, step, seed, positive_prompt, negative_prompt, dst)
|
|
||||||
}
|
|
||||||
|
|
||||||
default:
|
|
||||||
fn = func() error {
|
|
||||||
return fmt.Errorf("creation of images not supported by the backend")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return func() error {
|
|
||||||
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
|
|
||||||
mutexMap.Lock()
|
|
||||||
l, ok := mutexes[c.Backend]
|
|
||||||
if !ok {
|
|
||||||
m := &sync.Mutex{}
|
|
||||||
mutexes[c.Backend] = m
|
|
||||||
l = m
|
|
||||||
}
|
|
||||||
mutexMap.Unlock()
|
|
||||||
l.Lock()
|
|
||||||
defer l.Unlock()
|
|
||||||
|
|
||||||
return fn()
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c Config, o *Option) (func() ([]float32, error), error) {
|
|
||||||
if !c.Embeddings {
|
|
||||||
return nil, fmt.Errorf("endpoint disabled for this model by API configuration")
|
|
||||||
}
|
|
||||||
|
|
||||||
modelFile := c.Model
|
|
||||||
|
|
||||||
grpcOpts := gRPCModelOpts(c)
|
|
||||||
|
|
||||||
var inferenceModel interface{}
|
|
||||||
var err error
|
|
||||||
|
|
||||||
opts := []model.Option{
|
|
||||||
model.WithLoadGRPCOpts(grpcOpts),
|
|
||||||
model.WithThreads(uint32(c.Threads)),
|
|
||||||
model.WithAssetDir(o.assetsDestination),
|
|
||||||
model.WithModelFile(modelFile),
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.Backend == "" {
|
|
||||||
inferenceModel, err = loader.GreedyLoader(opts...)
|
|
||||||
} else {
|
|
||||||
opts = append(opts, model.WithBackendString(c.Backend))
|
|
||||||
inferenceModel, err = loader.BackendLoader(opts...)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var fn func() ([]float32, error)
|
|
||||||
switch model := inferenceModel.(type) {
|
|
||||||
case *grpc.Client:
|
|
||||||
fn = func() ([]float32, error) {
|
|
||||||
predictOptions := gRPCPredictOpts(c, loader.ModelPath)
|
|
||||||
if len(tokens) > 0 {
|
|
||||||
embeds := []int32{}
|
|
||||||
|
|
||||||
for _, t := range tokens {
|
|
||||||
embeds = append(embeds, int32(t))
|
|
||||||
}
|
|
||||||
predictOptions.EmbeddingTokens = embeds
|
|
||||||
|
|
||||||
res, err := model.Embeddings(context.TODO(), predictOptions)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return res.Embeddings, nil
|
|
||||||
}
|
|
||||||
predictOptions.Embeddings = s
|
|
||||||
|
|
||||||
res, err := model.Embeddings(context.TODO(), predictOptions)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return res.Embeddings, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// bert embeddings
|
|
||||||
case *bert.Bert:
|
|
||||||
fn = func() ([]float32, error) {
|
|
||||||
if len(tokens) > 0 {
|
|
||||||
return model.TokenEmbeddings(tokens, bert.SetThreads(c.Threads))
|
|
||||||
}
|
|
||||||
return model.Embeddings(s, bert.SetThreads(c.Threads))
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
fn = func() ([]float32, error) {
|
|
||||||
return nil, fmt.Errorf("embeddings not supported by the backend")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return func() ([]float32, error) {
|
|
||||||
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
|
|
||||||
mutexMap.Lock()
|
|
||||||
l, ok := mutexes[modelFile]
|
|
||||||
if !ok {
|
|
||||||
m := &sync.Mutex{}
|
|
||||||
mutexes[modelFile] = m
|
|
||||||
l = m
|
|
||||||
}
|
|
||||||
mutexMap.Unlock()
|
|
||||||
l.Lock()
|
|
||||||
defer l.Unlock()
|
|
||||||
|
|
||||||
embeds, err := fn()
|
|
||||||
if err != nil {
|
|
||||||
return embeds, err
|
|
||||||
}
|
|
||||||
// Remove trailing 0s
|
|
||||||
for i := len(embeds) - 1; i >= 0; i-- {
|
|
||||||
if embeds[i] == 0.0 {
|
|
||||||
embeds = embeds[:i]
|
|
||||||
} else {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return embeds, nil
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func ModelInference(s string, loader *model.ModelLoader, c Config, o *Option, tokenCallback func(string) bool) (func() (string, error), error) {
|
|
||||||
supportStreams := false
|
|
||||||
modelFile := c.Model
|
|
||||||
|
|
||||||
grpcOpts := gRPCModelOpts(c)
|
|
||||||
|
|
||||||
var inferenceModel interface{}
|
|
||||||
var err error
|
|
||||||
|
|
||||||
opts := []model.Option{
|
|
||||||
model.WithLoadGRPCOpts(grpcOpts),
|
|
||||||
model.WithThreads(uint32(c.Threads)), // GPT4all uses this
|
|
||||||
model.WithAssetDir(o.assetsDestination),
|
|
||||||
model.WithModelFile(modelFile),
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.Backend == "" {
|
|
||||||
inferenceModel, err = loader.GreedyLoader(opts...)
|
|
||||||
} else {
|
|
||||||
opts = append(opts, model.WithBackendString(c.Backend))
|
|
||||||
inferenceModel, err = loader.BackendLoader(opts...)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var fn func() (string, error)
|
|
||||||
|
|
||||||
switch model := inferenceModel.(type) {
|
|
||||||
case *rwkv.RwkvState:
|
|
||||||
supportStreams = true
|
|
||||||
|
|
||||||
fn = func() (string, error) {
|
|
||||||
stopWord := "\n"
|
|
||||||
if len(c.StopWords) > 0 {
|
|
||||||
stopWord = c.StopWords[0]
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := model.ProcessInput(s); err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
response := model.GenerateResponse(c.Maxtokens, stopWord, float32(c.Temperature), float32(c.TopP), tokenCallback)
|
|
||||||
|
|
||||||
return response, nil
|
|
||||||
}
|
|
||||||
case *bloomz.Bloomz:
|
|
||||||
fn = func() (string, error) {
|
|
||||||
// Generate the prediction using the language model
|
|
||||||
predictOptions := []bloomz.PredictOption{
|
|
||||||
bloomz.SetTemperature(c.Temperature),
|
|
||||||
bloomz.SetTopP(c.TopP),
|
|
||||||
bloomz.SetTopK(c.TopK),
|
|
||||||
bloomz.SetTokens(c.Maxtokens),
|
|
||||||
bloomz.SetThreads(c.Threads),
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.Seed != 0 {
|
|
||||||
predictOptions = append(predictOptions, bloomz.SetSeed(c.Seed))
|
|
||||||
}
|
|
||||||
|
|
||||||
return model.Predict(
|
|
||||||
s,
|
|
||||||
predictOptions...,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
case *grpc.Client:
|
|
||||||
// in GRPC, the backend is supposed to answer to 1 single token if stream is not supported
|
|
||||||
supportStreams = true
|
|
||||||
fn = func() (string, error) {
|
|
||||||
|
|
||||||
opts := gRPCPredictOpts(c, loader.ModelPath)
|
|
||||||
opts.Prompt = s
|
|
||||||
if tokenCallback != nil {
|
|
||||||
ss := ""
|
|
||||||
err := model.PredictStream(context.TODO(), opts, func(s string) {
|
|
||||||
tokenCallback(s)
|
|
||||||
ss += s
|
|
||||||
})
|
|
||||||
return ss, err
|
|
||||||
} else {
|
|
||||||
reply, err := model.Predict(context.TODO(), opts)
|
|
||||||
return reply.Message, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case *langchain.HuggingFace:
|
|
||||||
fn = func() (string, error) {
|
|
||||||
|
|
||||||
// Generate the prediction using the language model
|
|
||||||
predictOptions := []langchain.PredictOption{
|
|
||||||
langchain.SetModel(c.Model),
|
|
||||||
langchain.SetMaxTokens(c.Maxtokens),
|
|
||||||
langchain.SetTemperature(c.Temperature),
|
|
||||||
langchain.SetStopWords(c.StopWords),
|
|
||||||
}
|
|
||||||
|
|
||||||
pred, er := model.PredictHuggingFace(s, predictOptions...)
|
|
||||||
if er != nil {
|
|
||||||
return "", er
|
|
||||||
}
|
|
||||||
return pred.Completion, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return func() (string, error) {
|
|
||||||
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
|
|
||||||
mutexMap.Lock()
|
|
||||||
l, ok := mutexes[modelFile]
|
|
||||||
if !ok {
|
|
||||||
m := &sync.Mutex{}
|
|
||||||
mutexes[modelFile] = m
|
|
||||||
l = m
|
|
||||||
}
|
|
||||||
mutexMap.Unlock()
|
|
||||||
l.Lock()
|
|
||||||
defer l.Unlock()
|
|
||||||
|
|
||||||
res, err := fn()
|
|
||||||
if tokenCallback != nil && !supportStreams {
|
|
||||||
tokenCallback(res)
|
|
||||||
}
|
|
||||||
return res, err
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func ComputeChoices(predInput string, input *OpenAIRequest, config *Config, o *Option, loader *model.ModelLoader, cb func(string, *[]Choice), tokenCallback func(string) bool) ([]Choice, error) {
|
|
||||||
result := []Choice{}
|
|
||||||
|
|
||||||
n := input.N
|
|
||||||
|
|
||||||
if input.N == 0 {
|
|
||||||
n = 1
|
|
||||||
}
|
|
||||||
|
|
||||||
// get the model function to call for the result
|
|
||||||
predFunc, err := ModelInference(predInput, loader, *config, o, tokenCallback)
|
|
||||||
if err != nil {
|
|
||||||
return result, err
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := 0; i < n; i++ {
|
|
||||||
prediction, err := predFunc()
|
|
||||||
if err != nil {
|
|
||||||
return result, err
|
|
||||||
}
|
|
||||||
|
|
||||||
prediction = Finetune(*config, predInput, prediction)
|
|
||||||
cb(prediction, &result)
|
|
||||||
|
|
||||||
//result = append(result, Choice{Text: prediction})
|
|
||||||
|
|
||||||
}
|
|
||||||
return result, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var cutstrings map[string]*regexp.Regexp = make(map[string]*regexp.Regexp)
|
|
||||||
var mu sync.Mutex = sync.Mutex{}
|
|
||||||
|
|
||||||
func Finetune(config Config, input, prediction string) string {
|
|
||||||
if config.Echo {
|
|
||||||
prediction = input + prediction
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, c := range config.Cutstrings {
|
|
||||||
mu.Lock()
|
|
||||||
reg, ok := cutstrings[c]
|
|
||||||
if !ok {
|
|
||||||
cutstrings[c] = regexp.MustCompile(c)
|
|
||||||
reg = cutstrings[c]
|
|
||||||
}
|
|
||||||
mu.Unlock()
|
|
||||||
prediction = reg.ReplaceAllString(prediction, "")
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, c := range config.TrimSpace {
|
|
||||||
prediction = strings.TrimSpace(strings.TrimPrefix(prediction, c))
|
|
||||||
}
|
|
||||||
return prediction
|
|
||||||
|
|
||||||
}
|
|
35
main.go
35
main.go
|
@ -5,6 +5,7 @@ import (
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
|
||||||
api "github.com/go-skynet/LocalAI/api"
|
api "github.com/go-skynet/LocalAI/api"
|
||||||
|
"github.com/go-skynet/LocalAI/api/options"
|
||||||
"github.com/go-skynet/LocalAI/internal"
|
"github.com/go-skynet/LocalAI/internal"
|
||||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
|
@ -129,23 +130,23 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit
|
||||||
Copyright: "Ettore Di Giacinto",
|
Copyright: "Ettore Di Giacinto",
|
||||||
Action: func(ctx *cli.Context) error {
|
Action: func(ctx *cli.Context) error {
|
||||||
app, err := api.App(
|
app, err := api.App(
|
||||||
api.WithConfigFile(ctx.String("config-file")),
|
options.WithConfigFile(ctx.String("config-file")),
|
||||||
api.WithJSONStringPreload(ctx.String("preload-models")),
|
options.WithJSONStringPreload(ctx.String("preload-models")),
|
||||||
api.WithYAMLConfigPreload(ctx.String("preload-models-config")),
|
options.WithYAMLConfigPreload(ctx.String("preload-models-config")),
|
||||||
api.WithModelLoader(model.NewModelLoader(ctx.String("models-path"))),
|
options.WithModelLoader(model.NewModelLoader(ctx.String("models-path"))),
|
||||||
api.WithContextSize(ctx.Int("context-size")),
|
options.WithContextSize(ctx.Int("context-size")),
|
||||||
api.WithDebug(ctx.Bool("debug")),
|
options.WithDebug(ctx.Bool("debug")),
|
||||||
api.WithImageDir(ctx.String("image-path")),
|
options.WithImageDir(ctx.String("image-path")),
|
||||||
api.WithAudioDir(ctx.String("audio-path")),
|
options.WithAudioDir(ctx.String("audio-path")),
|
||||||
api.WithF16(ctx.Bool("f16")),
|
options.WithF16(ctx.Bool("f16")),
|
||||||
api.WithStringGalleries(ctx.String("galleries")),
|
options.WithStringGalleries(ctx.String("galleries")),
|
||||||
api.WithDisableMessage(false),
|
options.WithDisableMessage(false),
|
||||||
api.WithCors(ctx.Bool("cors")),
|
options.WithCors(ctx.Bool("cors")),
|
||||||
api.WithCorsAllowOrigins(ctx.String("cors-allow-origins")),
|
options.WithCorsAllowOrigins(ctx.String("cors-allow-origins")),
|
||||||
api.WithThreads(ctx.Int("threads")),
|
options.WithThreads(ctx.Int("threads")),
|
||||||
api.WithBackendAssets(backendAssets),
|
options.WithBackendAssets(backendAssets),
|
||||||
api.WithBackendAssetsOutput(ctx.String("backend-assets-path")),
|
options.WithBackendAssetsOutput(ctx.String("backend-assets-path")),
|
||||||
api.WithUploadLimitMB(ctx.Int("upload-limit")))
|
options.WithUploadLimitMB(ctx.Int("upload-limit")))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -126,6 +126,9 @@ func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) {
|
||||||
predictOptions := buildPredictOptions(opts)
|
predictOptions := buildPredictOptions(opts)
|
||||||
|
|
||||||
predictOptions = append(predictOptions, ggllm.SetTokenCallback(func(token string) bool {
|
predictOptions = append(predictOptions, ggllm.SetTokenCallback(func(token string) bool {
|
||||||
|
if token == "<|endoftext|>" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
results <- token
|
results <- token
|
||||||
return true
|
return true
|
||||||
}))
|
}))
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue