feat: Centralized Request Processing middleware (#3847)

* squash past, centralize request middleware PR

Signed-off-by: Dave Lee <dave@gray101.com>

* migrate bruno request files to examples repo

Signed-off-by: Dave Lee <dave@gray101.com>

* fix

Signed-off-by: Dave Lee <dave@gray101.com>

* Update tests/e2e-aio/e2e_test.go

Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com>

---------

Signed-off-by: Dave Lee <dave@gray101.com>
Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
Co-authored-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
Committed by Dave on 2025-02-10 06:06:16 -05:00 (via GitHub)
parent c330360785
commit 3cddf24747
53 changed files with 240975 additions and 821 deletions
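Before the per-file diffs, here is a minimal sketch of the pattern this PR converges on, assuming simplified stand-in types and locals keys modeled on the new middleware package (the real RequestExtractor also resolves the model from path, query, and bearer token, and loads the config via the BackendConfigLoader):

package main

import (
	"log"

	"github.com/gofiber/fiber/v2"
)

// Hypothetical, simplified stand-ins for the real schema.TTSRequest and
// config.BackendConfig types that appear in the diffs below.
type TTSRequest struct {
	Model string `json:"model"`
	Input string `json:"input"`
}

type BackendConfig struct {
	Backend string
	Model   string
}

// Keys mirror middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST and
// middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG from this PR.
const (
	localsKeyRequest     = "LOCALAI_REQUEST"
	localsKeyModelConfig = "MODEL_CONFIG"
)

// The extractor parses the body and resolves the backend config once,
// storing both in c.Locals for the handler that runs after it.
func extractRequest(c *fiber.Ctx) error {
	input := new(TTSRequest)
	if err := c.BodyParser(input); err != nil {
		return fiber.ErrBadRequest
	}
	// The real middleware resolves this from the config loader.
	cfg := &BackendConfig{Backend: "piper", Model: input.Model}
	c.Locals(localsKeyRequest, input)
	c.Locals(localsKeyModelConfig, cfg)
	return c.Next()
}

// Endpoints shrink to type-asserting what the middleware prepared.
func ttsEndpoint(c *fiber.Ctx) error {
	input, ok := c.Locals(localsKeyRequest).(*TTSRequest)
	if !ok || input.Model == "" {
		return fiber.ErrBadRequest
	}
	cfg, ok := c.Locals(localsKeyModelConfig).(*BackendConfig)
	if !ok || cfg == nil {
		return fiber.ErrBadRequest
	}
	return c.SendString(cfg.Backend + " would synthesize: " + input.Input)
}

func main() {
	app := fiber.New()
	app.Post("/tts", extractRequest, ttsEndpoint)
	log.Fatal(app.Listen(":8080"))
}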

aio/cpu/vad.yaml (new file, 8 additions)

@@ -0,0 +1,8 @@
+backend: silero-vad
+name: silero-vad
+parameters:
+  model: silero-vad.onnx
+download_files:
+  - filename: silero-vad.onnx
+    uri: https://huggingface.co/onnx-community/silero-vad/resolve/main/onnx/model.onnx
+    sha256: a4a068cd6cf1ea8355b84327595838ca748ec29a25bc91fc82e6c299ccdc5808

@@ -129,7 +129,7 @@ detect_gpu
 detect_gpu_size
 PROFILE="${PROFILE:-$GPU_SIZE}" # default to cpu
-export MODELS="${MODELS:-/aio/${PROFILE}/embeddings.yaml,/aio/${PROFILE}/rerank.yaml,/aio/${PROFILE}/text-to-speech.yaml,/aio/${PROFILE}/image-gen.yaml,/aio/${PROFILE}/text-to-text.yaml,/aio/${PROFILE}/speech-to-text.yaml,/aio/${PROFILE}/vision.yaml}"
+export MODELS="${MODELS:-/aio/${PROFILE}/embeddings.yaml,/aio/${PROFILE}/rerank.yaml,/aio/${PROFILE}/text-to-speech.yaml,/aio/${PROFILE}/image-gen.yaml,/aio/${PROFILE}/text-to-text.yaml,/aio/${PROFILE}/speech-to-text.yaml,/aio/${PROFILE}/vad.yaml,/aio/${PROFILE}/vision.yaml}"
 check_vars

aio/gpu-8g/vad.yaml (new file, 8 additions)

@@ -0,0 +1,8 @@
+backend: silero-vad
+name: silero-vad
+parameters:
+  model: silero-vad.onnx
+download_files:
+  - filename: silero-vad.onnx
+    uri: https://huggingface.co/onnx-community/silero-vad/resolve/main/onnx/model.onnx
+    sha256: a4a068cd6cf1ea8355b84327595838ca748ec29a25bc91fc82e6c299ccdc5808

aio/intel/vad.yaml (new file, 8 additions)

@@ -0,0 +1,8 @@
+backend: silero-vad
+name: silero-vad
+parameters:
+  model: silero-vad.onnx
+download_files:
+  - filename: silero-vad.onnx
+    uri: https://huggingface.co/onnx-community/silero-vad/resolve/main/onnx/model.onnx
+    sha256: a4a068cd6cf1ea8355b84327595838ca748ec29a25bc91fc82e6c299ccdc5808

@@ -145,13 +145,7 @@ func New(opts ...config.AppOption) (*Application, error) {
 	if options.LoadToMemory != nil {
 		for _, m := range options.LoadToMemory {
-			cfg, err := application.BackendLoader().LoadBackendConfigFileByName(m, options.ModelPath,
-				config.LoadOptionDebug(options.Debug),
-				config.LoadOptionThreads(options.Threads),
-				config.LoadOptionContextSize(options.ContextSize),
-				config.LoadOptionF16(options.F16),
-				config.ModelPath(options.ModelPath),
-			)
+			cfg, err := application.BackendLoader().LoadBackendConfigFileByNameDefaultOptions(m, options)
 			if err != nil {
 				return nil, err
 			}

@@ -33,7 +33,7 @@ type TokenUsage struct {
 	TimingTokenGeneration float64
 }

-func ModelInference(ctx context.Context, s string, messages []schema.Message, images, videos, audios []string, loader *model.ModelLoader, c config.BackendConfig, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) {
+func ModelInference(ctx context.Context, s string, messages []schema.Message, images, videos, audios []string, loader *model.ModelLoader, c *config.BackendConfig, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) {
 	modelFile := c.Model

 	// Check if the modelFile exists, if it doesn't try to load it from the gallery
@@ -48,7 +48,7 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
 		}
 	}

-	opts := ModelOptions(c, o)
+	opts := ModelOptions(*c, o)
 	inferenceModel, err := loader.Load(opts...)
 	if err != nil {
 		return nil, err
@@ -84,7 +84,7 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
 	// in GRPC, the backend is supposed to answer to 1 single token if stream is not supported
 	fn := func() (LLMResponse, error) {
-		opts := gRPCPredictOpts(c, loader.ModelPath)
+		opts := gRPCPredictOpts(*c, loader.ModelPath)
 		opts.Prompt = s
 		opts.Messages = protoMessages
 		opts.UseTokenizerTemplate = c.TemplateConfig.UseTokenizerTemplate

@@ -9,10 +9,10 @@ import (
 	model "github.com/mudler/LocalAI/pkg/model"
 )

-func Rerank(modelFile string, request *proto.RerankRequest, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig config.BackendConfig) (*proto.RerankResult, error) {
-	opts := ModelOptions(backendConfig, appConfig, model.WithModel(modelFile))
+func Rerank(request *proto.RerankRequest, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig config.BackendConfig) (*proto.RerankResult, error) {
+	opts := ModelOptions(backendConfig, appConfig)
 	rerankModel, err := loader.Load(opts...)
 	if err != nil {
 		return nil, err
 	}

@@ -13,7 +13,6 @@ import (
 )

 func SoundGeneration(
-	modelFile string,
 	text string,
 	duration *float32,
 	temperature *float32,
@@ -25,8 +24,9 @@ func SoundGeneration(
 	backendConfig config.BackendConfig,
 ) (string, *proto.Result, error) {

-	opts := ModelOptions(backendConfig, appConfig, model.WithModel(modelFile))
+	opts := ModelOptions(backendConfig, appConfig)
 	soundGenModel, err := loader.Load(opts...)
 	if err != nil {
 		return "", nil, err
 	}
@@ -44,7 +44,7 @@ func SoundGeneration(
 	res, err := soundGenModel.SoundGeneration(context.Background(), &proto.SoundGenerationRequest{
 		Text:     text,
-		Model:    modelFile,
+		Model:    backendConfig.Model,
 		Dst:      filePath,
 		Sample:   doSample,
 		Duration: duration,

@@ -4,19 +4,17 @@ import (
 	"github.com/mudler/LocalAI/core/config"
 	"github.com/mudler/LocalAI/core/schema"
 	"github.com/mudler/LocalAI/pkg/grpc"
-	model "github.com/mudler/LocalAI/pkg/model"
+	"github.com/mudler/LocalAI/pkg/model"
 )

 func ModelTokenize(s string, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (schema.TokenizeResponse, error) {
-	modelFile := backendConfig.Model
-
 	var inferenceModel grpc.Backend
 	var err error

-	opts := ModelOptions(backendConfig, appConfig, model.WithModel(modelFile))
+	opts := ModelOptions(backendConfig, appConfig)
 	inferenceModel, err = loader.Load(opts...)
 	if err != nil {
 		return schema.TokenizeResponse{}, err
 	}

@@ -47,7 +47,7 @@ func ModelTranscription(audio, language string, translate bool, ml *model.ModelL
 			tks = append(tks, int(t))
 		}
 		tr.Segments = append(tr.Segments,
-			schema.Segment{
+			schema.TranscriptionSegment{
 				Text:  s.Text,
 				Id:    int(s.Id),
 				Start: time.Duration(s.Start),

@@ -14,28 +14,22 @@ import (
 )

 func ModelTTS(
-	backend,
 	text,
-	modelFile,
 	voice,
 	language string,
 	loader *model.ModelLoader,
 	appConfig *config.ApplicationConfig,
 	backendConfig config.BackendConfig,
 ) (string, *proto.Result, error) {
-	bb := backend
-	if bb == "" {
-		bb = model.PiperBackend
-	}
-	opts := ModelOptions(backendConfig, appConfig, model.WithBackendString(bb), model.WithModel(modelFile))
+	opts := ModelOptions(backendConfig, appConfig, model.WithDefaultBackendString(model.PiperBackend))
 	ttsModel, err := loader.Load(opts...)
 	if err != nil {
 		return "", nil, err
 	}

 	if ttsModel == nil {
-		return "", nil, fmt.Errorf("could not load piper model")
+		return "", nil, fmt.Errorf("could not load tts model %q", backendConfig.Model)
 	}

 	if err := os.MkdirAll(appConfig.AudioDir, 0750); err != nil {
@@ -45,22 +39,21 @@ func ModelTTS(
 	fileName := utils.GenerateUniqueFileName(appConfig.AudioDir, "tts", ".wav")
 	filePath := filepath.Join(appConfig.AudioDir, fileName)

-	// If the model file is not empty, we pass it joined with the model path
+	// We join the model name to the model path here. This seems to only be done for TTS and is HIGHLY suspect.
+	// This should be addressed in a follow up PR soon.
+	// Copying it over nearly verbatim, as TTS backends are not functional without this.
 	modelPath := ""
-	if modelFile != "" {
-		// If the model file is not empty, we pass it joined with the model path
-		// Checking first that it exists and is not outside ModelPath
-		// TODO: we should actually first check if the modelFile is looking like
-		// a FS path
-		mp := filepath.Join(loader.ModelPath, modelFile)
-		if _, err := os.Stat(mp); err == nil {
-			if err := utils.VerifyPath(mp, appConfig.ModelPath); err != nil {
-				return "", nil, err
-			}
-			modelPath = mp
-		} else {
-			modelPath = modelFile
-		}
+	// Checking first that it exists and is not outside ModelPath
+	// TODO: we should actually first check if the modelFile is looking like
+	// a FS path
+	mp := filepath.Join(loader.ModelPath, backendConfig.Model)
+	if _, err := os.Stat(mp); err == nil {
+		if err := utils.VerifyPath(mp, appConfig.ModelPath); err != nil {
+			return "", nil, err
+		}
+		modelPath = mp
+	} else {
+		modelPath = backendConfig.Model // skip this step if it fails?????
 	}

 	res, err := ttsModel.TTS(context.Background(), &proto.TTSRequest{

core/backend/vad.go (new file, 38 additions)

@@ -0,0 +1,38 @@
+package backend
+
+import (
+	"context"
+
+	"github.com/mudler/LocalAI/core/config"
+	"github.com/mudler/LocalAI/core/schema"
+	"github.com/mudler/LocalAI/pkg/grpc/proto"
+	"github.com/mudler/LocalAI/pkg/model"
+)
+
+func VAD(request *schema.VADRequest,
+	ctx context.Context,
+	ml *model.ModelLoader,
+	appConfig *config.ApplicationConfig,
+	backendConfig config.BackendConfig) (*schema.VADResponse, error) {
+	opts := ModelOptions(backendConfig, appConfig)
+	vadModel, err := ml.Load(opts...)
+	if err != nil {
+		return nil, err
+	}
+	req := proto.VADRequest{
+		Audio: request.Audio,
+	}
+	resp, err := vadModel.VAD(ctx, &req)
+	if err != nil {
+		return nil, err
+	}
+
+	segments := []schema.VADSegment{}
+	for _, s := range resp.Segments {
+		segments = append(segments, schema.VADSegment{Start: s.Start, End: s.End})
+	}
+
+	return &schema.VADResponse{
+		Segments: segments,
+	}, nil
+}

@@ -86,13 +86,14 @@ func (t *SoundGenerationCMD) Run(ctx *cliContext.Context) error {
 	options := config.BackendConfig{}
 	options.SetDefaults()
 	options.Backend = t.Backend
+	options.Model = t.Model

 	var inputFile *string
 	if t.InputFile != "" {
 		inputFile = &t.InputFile
 	}

-	filePath, _, err := backend.SoundGeneration(t.Model, text,
+	filePath, _, err := backend.SoundGeneration(text,
 		parseToFloat32Ptr(t.Duration), parseToFloat32Ptr(t.Temperature), &t.DoSample,
 		inputFile, parseToInt32Ptr(t.InputFileSampleDivisor), ml, opts, options)

@@ -52,8 +52,10 @@ func (t *TTSCMD) Run(ctx *cliContext.Context) error {
 	options := config.BackendConfig{}
 	options.SetDefaults()
+	options.Backend = t.Backend
+	options.Model = t.Model

-	filePath, _, err := backend.ModelTTS(t.Backend, text, t.Model, t.Voice, t.Language, ml, opts, options)
+	filePath, _, err := backend.ModelTTS(text, t.Voice, t.Language, ml, opts, options)
 	if err != nil {
 		return err
 	}

@@ -437,19 +437,21 @@ func (c *BackendConfig) HasTemplate() bool {
 type BackendConfigUsecases int

 const (
-	FLAG_ANY              BackendConfigUsecases = 0b000000000
-	FLAG_CHAT             BackendConfigUsecases = 0b000000001
-	FLAG_COMPLETION       BackendConfigUsecases = 0b000000010
-	FLAG_EDIT             BackendConfigUsecases = 0b000000100
-	FLAG_EMBEDDINGS       BackendConfigUsecases = 0b000001000
-	FLAG_RERANK           BackendConfigUsecases = 0b000010000
-	FLAG_IMAGE            BackendConfigUsecases = 0b000100000
-	FLAG_TRANSCRIPT       BackendConfigUsecases = 0b001000000
-	FLAG_TTS              BackendConfigUsecases = 0b010000000
-	FLAG_SOUND_GENERATION BackendConfigUsecases = 0b100000000
+	FLAG_ANY              BackendConfigUsecases = 0b00000000000
+	FLAG_CHAT             BackendConfigUsecases = 0b00000000001
+	FLAG_COMPLETION       BackendConfigUsecases = 0b00000000010
+	FLAG_EDIT             BackendConfigUsecases = 0b00000000100
+	FLAG_EMBEDDINGS       BackendConfigUsecases = 0b00000001000
+	FLAG_RERANK           BackendConfigUsecases = 0b00000010000
+	FLAG_IMAGE            BackendConfigUsecases = 0b00000100000
+	FLAG_TRANSCRIPT       BackendConfigUsecases = 0b00001000000
+	FLAG_TTS              BackendConfigUsecases = 0b00010000000
+	FLAG_SOUND_GENERATION BackendConfigUsecases = 0b00100000000
+	FLAG_TOKENIZE         BackendConfigUsecases = 0b01000000000
+	FLAG_VAD              BackendConfigUsecases = 0b10000000000

 	// Common Subsets
-	FLAG_LLM BackendConfigUsecases = FLAG_CHAT & FLAG_COMPLETION & FLAG_EDIT
+	FLAG_LLM BackendConfigUsecases = FLAG_CHAT | FLAG_COMPLETION | FLAG_EDIT
 )

 func GetAllBackendConfigUsecases() map[string]BackendConfigUsecases {
@@ -464,6 +466,8 @@ func GetAllBackendConfigUsecases() map[string]BackendConfigUsecases {
 		"FLAG_TRANSCRIPT":       FLAG_TRANSCRIPT,
 		"FLAG_TTS":              FLAG_TTS,
 		"FLAG_SOUND_GENERATION": FLAG_SOUND_GENERATION,
+		"FLAG_TOKENIZE":         FLAG_TOKENIZE,
+		"FLAG_VAD":              FLAG_VAD,
 		"FLAG_LLM":              FLAG_LLM,
 	}
 }
@@ -549,5 +553,18 @@ func (c *BackendConfig) GuessUsecases(u BackendConfigUsecases) bool {
 		}
 	}
+	if (u & FLAG_TOKENIZE) == FLAG_TOKENIZE {
+		tokenizeCapableBackends := []string{"llama.cpp", "rwkv"}
+		if !slices.Contains(tokenizeCapableBackends, c.Backend) {
+			return false
+		}
+	}
+
+	if (u & FLAG_VAD) == FLAG_VAD {
+		if c.Backend != "silero-vad" {
+			return false
+		}
+	}

 	return true
 }
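The FLAG_LLM change above fixes a real bitmask bug: AND-ing disjoint single-bit flags always yields zero, so the old subset silently equaled FLAG_ANY. A self-contained illustration, trimmed to three bits with values copied from the diff:

package main

import "fmt"

// Copies of the low three flag bits from the diff above.
const (
	FLAG_CHAT       = 0b001
	FLAG_COMPLETION = 0b010
	FLAG_EDIT       = 0b100
)

func main() {
	// Old definition: AND of disjoint bits is always zero, so the "LLM"
	// subset matched everything-or-nothing depending on the caller's check.
	fmt.Printf("%03b\n", FLAG_CHAT&FLAG_COMPLETION&FLAG_EDIT) // 000
	// Corrected definition: OR builds the intended union of usecases.
	fmt.Printf("%03b\n", FLAG_CHAT|FLAG_COMPLETION|FLAG_EDIT) // 111
}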

@@ -81,10 +81,10 @@ func readMultipleBackendConfigsFromFile(file string, opts ...ConfigLoaderOption)
 	c := &[]*BackendConfig{}
 	f, err := os.ReadFile(file)
 	if err != nil {
-		return nil, fmt.Errorf("cannot read config file: %w", err)
+		return nil, fmt.Errorf("readMultipleBackendConfigsFromFile cannot read config file %q: %w", file, err)
 	}
 	if err := yaml.Unmarshal(f, c); err != nil {
-		return nil, fmt.Errorf("cannot unmarshal config file: %w", err)
+		return nil, fmt.Errorf("readMultipleBackendConfigsFromFile cannot unmarshal config file %q: %w", file, err)
 	}

 	for _, cc := range *c {
@@ -101,10 +101,10 @@ func readBackendConfigFromFile(file string, opts ...ConfigLoaderOption) (*Backen
 	c := &BackendConfig{}
 	f, err := os.ReadFile(file)
 	if err != nil {
-		return nil, fmt.Errorf("cannot read config file: %w", err)
+		return nil, fmt.Errorf("readBackendConfigFromFile cannot read config file %q: %w", file, err)
 	}
 	if err := yaml.Unmarshal(f, c); err != nil {
-		return nil, fmt.Errorf("cannot unmarshal config file: %w", err)
+		return nil, fmt.Errorf("readBackendConfigFromFile cannot unmarshal config file %q: %w", file, err)
 	}

 	c.SetDefaults(opts...)
@@ -117,8 +117,10 @@ func (bcl *BackendConfigLoader) LoadBackendConfigFileByName(modelName, modelPath
 	// Load a config file if present after the model name
 	cfg := &BackendConfig{
 		PredictionOptions: schema.PredictionOptions{
+			BasicModelRequest: schema.BasicModelRequest{
 				Model: modelName,
+			},
 		},
 	}

 	cfgExisting, exists := bcl.GetBackendConfig(modelName)
@@ -145,6 +147,15 @@ func (bcl *BackendConfigLoader) LoadBackendConfigFileByName(modelName, modelPath
 	return cfg, nil
 }

+func (bcl *BackendConfigLoader) LoadBackendConfigFileByNameDefaultOptions(modelName string, appConfig *ApplicationConfig) (*BackendConfig, error) {
+	return bcl.LoadBackendConfigFileByName(modelName, appConfig.ModelPath,
+		LoadOptionDebug(appConfig.Debug),
+		LoadOptionThreads(appConfig.Threads),
+		LoadOptionContextSize(appConfig.ContextSize),
+		LoadOptionF16(appConfig.F16),
+		ModelPath(appConfig.ModelPath))
+}
+
 // This format is currently only used when reading a single file at startup, passed in via ApplicationConfig.ConfigFile
 func (bcl *BackendConfigLoader) LoadMultipleBackendConfigsSingleFile(file string, opts ...ConfigLoaderOption) error {
 	bcl.Lock()
@@ -167,7 +178,7 @@ func (bcl *BackendConfigLoader) LoadBackendConfig(file string, opts ...ConfigLoa
 	defer bcl.Unlock()
 	c, err := readBackendConfigFromFile(file, opts...)
 	if err != nil {
-		return fmt.Errorf("cannot read config file: %w", err)
+		return fmt.Errorf("LoadBackendConfig cannot read config file %q: %w", file, err)
 	}

 	if c.Validate() {
@@ -324,9 +335,10 @@ func (bcl *BackendConfigLoader) Preload(modelPath string) error {
 func (bcl *BackendConfigLoader) LoadBackendConfigsFromPath(path string, opts ...ConfigLoaderOption) error {
 	bcl.Lock()
 	defer bcl.Unlock()
+
 	entries, err := os.ReadDir(path)
 	if err != nil {
-		return fmt.Errorf("cannot read directory '%s': %w", path, err)
+		return fmt.Errorf("LoadBackendConfigsFromPath cannot read directory '%s': %w", path, err)
 	}
 	files := make([]fs.FileInfo, 0, len(entries))
 	for _, entry := range entries {
@@ -344,13 +356,13 @@ func (bcl *BackendConfigLoader) LoadBackendConfigsFromPath(path string, opts ...
 		}
 		c, err := readBackendConfigFromFile(filepath.Join(path, file.Name()), opts...)
 		if err != nil {
-			log.Error().Err(err).Msgf("cannot read config file: %s", file.Name())
+			log.Error().Err(err).Str("File Name", file.Name()).Msgf("LoadBackendConfigsFromPath cannot read config file")
 			continue
 		}
 		if c.Validate() {
 			bcl.configs[c.Name] = *c
 		} else {
-			log.Error().Err(err).Msgf("config is not valid")
+			log.Error().Err(err).Str("Name", c.Name).Msgf("config is not valid")
 		}
 	}

@@ -161,10 +161,11 @@ func guessDefaultsFromFile(cfg *BackendConfig, modelPath string) {
 	}

 	// We try to guess only if we don't have a template defined already
-	f, err := gguf.ParseGGUFFile(filepath.Join(modelPath, cfg.ModelFileName()))
+	guessPath := filepath.Join(modelPath, cfg.ModelFileName())
+	f, err := gguf.ParseGGUFFile(guessPath)
 	if err != nil {
 		// Only valid for gguf files
-		log.Debug().Msgf("guessDefaultsFromFile: %s", "not a GGUF file")
+		log.Debug().Str("filePath", guessPath).Msg("guessDefaultsFromFile: not a GGUF file")
 		return
 	}

@@ -130,7 +130,6 @@ func API(application *application.Application) (*fiber.App, error) {
 			return metricsService.Shutdown()
 		})
 	}
-
 	}
 	// Health Checks should always be exempt from auth, so register these first
 	routes.HealthRoutes(router)
@@ -167,13 +166,15 @@ func API(application *application.Application) (*fiber.App, error) {
 	galleryService := services.NewGalleryService(application.ApplicationConfig())
 	galleryService.Start(application.ApplicationConfig().Context, application.BackendLoader())

-	routes.RegisterElevenLabsRoutes(router, application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())
-	routes.RegisterLocalAIRoutes(router, application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig(), galleryService)
-	routes.RegisterOpenAIRoutes(router, application)
+	requestExtractor := middleware.NewRequestExtractor(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())
+
+	routes.RegisterElevenLabsRoutes(router, requestExtractor, application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())
+	routes.RegisterLocalAIRoutes(router, requestExtractor, application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig(), galleryService)
+	routes.RegisterOpenAIRoutes(router, requestExtractor, application)

 	if !application.ApplicationConfig().DisableWebUI {
 		routes.RegisterUIRoutes(router, application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig(), galleryService)
 	}
-	routes.RegisterJINARoutes(router, application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())
+	routes.RegisterJINARoutes(router, requestExtractor, application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())

 	httpFS := http.FS(embedDirStatic)

(deleted file, 47 deletions)

@@ -1,47 +0,0 @@
-package fiberContext
-
-import (
-	"fmt"
-	"strings"
-
-	"github.com/gofiber/fiber/v2"
-	"github.com/mudler/LocalAI/core/config"
-	"github.com/mudler/LocalAI/core/services"
-	"github.com/mudler/LocalAI/pkg/model"
-	"github.com/rs/zerolog/log"
-)
-
-// ModelFromContext returns the model from the context
-// If no model is specified, it will take the first available
-// Takes a model string as input which should be the one received from the user request.
-// It returns the model name resolved from the context and an error if any.
-func ModelFromContext(ctx *fiber.Ctx, cl *config.BackendConfigLoader, loader *model.ModelLoader, modelInput string, firstModel bool) (string, error) {
-	if ctx.Params("model") != "" {
-		modelInput = ctx.Params("model")
-	}
-	if ctx.Query("model") != "" {
-		modelInput = ctx.Query("model")
-	}
-	// Set model from bearer token, if available
-	bearer := strings.TrimLeft(ctx.Get("authorization"), "Bear ") // Reduced duplicate characters of Bearer
-	bearerExists := bearer != "" && loader.ExistsInModelPath(bearer)
-
-	// If no model was specified, take the first available
-	if modelInput == "" && !bearerExists && firstModel {
-		models, _ := services.ListModels(cl, loader, config.NoFilterFn, services.SKIP_IF_CONFIGURED)
-		if len(models) > 0 {
-			modelInput = models[0]
-			log.Debug().Msgf("No model specified, using: %s", modelInput)
-		} else {
-			log.Debug().Msgf("No model specified, returning error")
-			return "", fmt.Errorf("no model specified")
-		}
-	}
-
-	// If a model is found in bearer token takes precedence
-	if bearerExists {
-		log.Debug().Msgf("Using model from bearer token: %s", bearer)
-		modelInput = bearer
-	}
-	return modelInput, nil
-}
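One detail worth noting in the deleted helper: strings.TrimLeft takes a cutset, not a prefix, which is what the "Reduced duplicate characters of Bearer" comment was working around. A small illustration with made-up token values:

package main

import (
	"fmt"
	"strings"
)

func main() {
	// TrimLeft strips any leading run of the characters {'B','e','a','r',' '},
	// not the literal prefix "Bearer ". For many model names this is harmless:
	fmt.Println(strings.TrimLeft("Bearer my-model", "Bear ")) // my-model
	// ...but it also eats leading characters that happen to be in the cutset:
	fmt.Println(strings.TrimLeft("Bearer ara-model", "Bear ")) // -model
	// TrimPrefix is the prefix-safe alternative:
	fmt.Println(strings.TrimPrefix("Bearer ara-model", "Bearer ")) // ara-model
}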

@@ -4,7 +4,7 @@ import (
 	"github.com/gofiber/fiber/v2"
 	"github.com/mudler/LocalAI/core/backend"
 	"github.com/mudler/LocalAI/core/config"
-	fiberContext "github.com/mudler/LocalAI/core/http/ctx"
+	"github.com/mudler/LocalAI/core/http/middleware"
 	"github.com/mudler/LocalAI/core/schema"
 	"github.com/mudler/LocalAI/pkg/model"
 	"github.com/rs/zerolog/log"
@@ -17,45 +17,21 @@ import (
 // @Router /v1/sound-generation [post]
 func SoundGenerationEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
 	return func(c *fiber.Ctx) error {
-		input := new(schema.ElevenLabsSoundGenerationRequest)
-
-		// Get input data from the request body
-		if err := c.BodyParser(input); err != nil {
-			return err
+		input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.ElevenLabsSoundGenerationRequest)
+		if !ok || input.ModelID == "" {
+			return fiber.ErrBadRequest
 		}

-		modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.ModelID, false)
-		if err != nil {
-			modelFile = input.ModelID
-			log.Warn().Str("ModelID", input.ModelID).Msg("Model not found in context")
-		}
-
-		cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
-			config.LoadOptionDebug(appConfig.Debug),
-			config.LoadOptionThreads(appConfig.Threads),
-			config.LoadOptionContextSize(appConfig.ContextSize),
-			config.LoadOptionF16(appConfig.F16),
-		)
-		if err != nil {
-			modelFile = input.ModelID
-			log.Warn().Str("Request ModelID", input.ModelID).Err(err).Msg("error during LoadBackendConfigFileByName, using request ModelID")
-		} else {
-			if input.ModelID != "" {
-				modelFile = input.ModelID
-			} else {
-				modelFile = cfg.Model
-			}
+		cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
+		if !ok || cfg == nil {
+			return fiber.ErrBadRequest
 		}

 		log.Debug().Str("modelFile", "modelFile").Str("backend", cfg.Backend).Msg("Sound Generation Request about to be sent to backend")

-		if input.Duration != nil {
-			log.Debug().Float32("duration", *input.Duration).Msg("duration set")
-		}
-		if input.Temperature != nil {
-			log.Debug().Float32("temperature", *input.Temperature).Msg("temperature set")
-		}
-
 		// TODO: Support uploading files?
-		filePath, _, err := backend.SoundGeneration(modelFile, input.Text, input.Duration, input.Temperature, input.DoSample, nil, nil, ml, appConfig, *cfg)
+		filePath, _, err := backend.SoundGeneration(input.Text, input.Duration, input.Temperature, input.DoSample, nil, nil, ml, appConfig, *cfg)
 		if err != nil {
 			return err
 		}

@@ -3,7 +3,7 @@ package elevenlabs
 import (
 	"github.com/mudler/LocalAI/core/backend"
 	"github.com/mudler/LocalAI/core/config"
-	fiberContext "github.com/mudler/LocalAI/core/http/ctx"
+	"github.com/mudler/LocalAI/core/http/middleware"
 	"github.com/mudler/LocalAI/pkg/model"

 	"github.com/gofiber/fiber/v2"
@@ -20,39 +20,21 @@ import (
 func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
 	return func(c *fiber.Ctx) error {
-		input := new(schema.ElevenLabsTTSRequest)
 		voiceID := c.Params("voice-id")

-		// Get input data from the request body
-		if err := c.BodyParser(input); err != nil {
-			return err
+		input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.ElevenLabsTTSRequest)
+		if !ok || input.ModelID == "" {
+			return fiber.ErrBadRequest
 		}

-		modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.ModelID, false)
-		if err != nil {
-			modelFile = input.ModelID
-			log.Warn().Msgf("Model not found in context: %s", input.ModelID)
-		}
-
-		cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
-			config.LoadOptionDebug(appConfig.Debug),
-			config.LoadOptionThreads(appConfig.Threads),
-			config.LoadOptionContextSize(appConfig.ContextSize),
-			config.LoadOptionF16(appConfig.F16),
-		)
-		if err != nil {
-			modelFile = input.ModelID
-			log.Warn().Msgf("Model not found in context: %s", input.ModelID)
-		} else {
-			if input.ModelID != "" {
-				modelFile = input.ModelID
-			} else {
-				modelFile = cfg.Model
-			}
+		cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
+		if !ok || cfg == nil {
+			return fiber.ErrBadRequest
 		}

-		log.Debug().Msgf("Request for model: %s", modelFile)
+		log.Debug().Str("modelName", input.ModelID).Msg("elevenlabs TTS request recieved")

-		filePath, _, err := backend.ModelTTS(cfg.Backend, input.Text, modelFile, "", voiceID, ml, appConfig, *cfg)
+		filePath, _, err := backend.ModelTTS(input.Text, voiceID, input.LanguageCode, ml, appConfig, *cfg)
 		if err != nil {
 			return err
 		}

@@ -3,9 +3,9 @@ package jina
 import (
 	"github.com/mudler/LocalAI/core/backend"
 	"github.com/mudler/LocalAI/core/config"
+	"github.com/mudler/LocalAI/core/http/middleware"

 	"github.com/gofiber/fiber/v2"
-	fiberContext "github.com/mudler/LocalAI/core/http/ctx"
 	"github.com/mudler/LocalAI/core/schema"
 	"github.com/mudler/LocalAI/pkg/grpc/proto"
 	"github.com/mudler/LocalAI/pkg/model"
@@ -19,58 +19,32 @@ import (
 // @Router /v1/rerank [post]
 func JINARerankEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
 	return func(c *fiber.Ctx) error {
-		req := new(schema.JINARerankRequest)
-		if err := c.BodyParser(req); err != nil {
-			return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
-				"error": "Cannot parse JSON",
-			})
+		input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.JINARerankRequest)
+		if !ok || input.Model == "" {
+			return fiber.ErrBadRequest
 		}

-		input := new(schema.TTSRequest)
-
-		// Get input data from the request body
-		if err := c.BodyParser(input); err != nil {
-			return err
+		cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
+		if !ok || cfg == nil {
+			return fiber.ErrBadRequest
 		}

-		modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, false)
-		if err != nil {
-			modelFile = input.Model
-			log.Warn().Msgf("Model not found in context: %s", input.Model)
-		}
-
-		cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
-			config.LoadOptionDebug(appConfig.Debug),
-			config.LoadOptionThreads(appConfig.Threads),
-			config.LoadOptionContextSize(appConfig.ContextSize),
-			config.LoadOptionF16(appConfig.F16),
-		)
-		if err != nil {
-			modelFile = input.Model
-			log.Warn().Msgf("Model not found in context: %s", input.Model)
-		} else {
-			modelFile = cfg.Model
-		}
-		log.Debug().Msgf("Request for model: %s", modelFile)
-
-		if input.Backend != "" {
-			cfg.Backend = input.Backend
-		}
+		log.Debug().Str("model", input.Model).Msg("JINA Rerank Request recieved")

 		request := &proto.RerankRequest{
-			Query:     req.Query,
-			TopN:      int32(req.TopN),
-			Documents: req.Documents,
+			Query:     input.Query,
+			TopN:      int32(input.TopN),
+			Documents: input.Documents,
 		}

-		results, err := backend.Rerank(modelFile, request, ml, appConfig, *cfg)
+		results, err := backend.Rerank(request, ml, appConfig, *cfg)
 		if err != nil {
 			return err
 		}

 		response := &schema.JINARerankResponse{
-			Model: req.Model,
+			Model: input.Model,
 		}

 		for _, r := range results.Results {

@@ -4,13 +4,15 @@ import (
 	"github.com/gofiber/fiber/v2"
 	"github.com/mudler/LocalAI/core/backend"
 	"github.com/mudler/LocalAI/core/config"
-	fiberContext "github.com/mudler/LocalAI/core/http/ctx"
+	"github.com/mudler/LocalAI/core/http/middleware"
 	"github.com/mudler/LocalAI/core/schema"
 	"github.com/rs/zerolog/log"

 	"github.com/mudler/LocalAI/pkg/model"
 )

+// TODO: This is not yet in use. Needs middleware rework, since it is not referenced.
+
 // TokenMetricsEndpoint is an endpoint to get TokensProcessed Per Second for Active SlotID
 //
 // @Summary Get TokenMetrics for Active Slot.
@@ -29,18 +31,13 @@ func TokenMetricsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader,
 			return err
 		}

-		modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, false)
-		if err != nil {
+		modelFile, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
+		if !ok || modelFile != "" {
 			modelFile = input.Model
 			log.Warn().Msgf("Model not found in context: %s", input.Model)
 		}

-		cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
-			config.LoadOptionDebug(appConfig.Debug),
-			config.LoadOptionThreads(appConfig.Threads),
-			config.LoadOptionContextSize(appConfig.ContextSize),
-			config.LoadOptionF16(appConfig.F16),
-		)
+		cfg, err := cl.LoadBackendConfigFileByNameDefaultOptions(modelFile, appConfig)

 		if err != nil {
 			log.Err(err)

@@ -4,10 +4,9 @@ import (
 	"github.com/gofiber/fiber/v2"
 	"github.com/mudler/LocalAI/core/backend"
 	"github.com/mudler/LocalAI/core/config"
-	fiberContext "github.com/mudler/LocalAI/core/http/ctx"
+	"github.com/mudler/LocalAI/core/http/middleware"
 	"github.com/mudler/LocalAI/core/schema"
 	"github.com/mudler/LocalAI/pkg/model"
-	"github.com/rs/zerolog/log"
 )

 // TokenizeEndpoint exposes a REST API to tokenize the content
@@ -16,42 +15,21 @@ import (
 // @Success 200 {object} schema.TokenizeResponse "Response"
 // @Router /v1/tokenize [post]
 func TokenizeEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
-	return func(c *fiber.Ctx) error {
-
-		input := new(schema.TokenizeRequest)
-
-		// Get input data from the request body
-		if err := c.BodyParser(input); err != nil {
-			return err
+	return func(ctx *fiber.Ctx) error {
+		input, ok := ctx.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.TokenizeRequest)
+		if !ok || input.Model == "" {
+			return fiber.ErrBadRequest
 		}

-		modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, false)
-		if err != nil {
-			modelFile = input.Model
-			log.Warn().Msgf("Model not found in context: %s", input.Model)
-		}
-
-		cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
-			config.LoadOptionDebug(appConfig.Debug),
-			config.LoadOptionThreads(appConfig.Threads),
-			config.LoadOptionContextSize(appConfig.ContextSize),
-			config.LoadOptionF16(appConfig.F16),
-		)
-		if err != nil {
-			log.Err(err)
-			modelFile = input.Model
-			log.Warn().Msgf("Model not found in context: %s", input.Model)
-		} else {
-			modelFile = cfg.Model
+		cfg, ok := ctx.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
+		if !ok || cfg == nil {
+			return fiber.ErrBadRequest
 		}
-		log.Debug().Msgf("Request for model: %s", modelFile)

 		tokenResponse, err := backend.ModelTokenize(input.Content, ml, *cfg, appConfig)
 		if err != nil {
 			return err
 		}

-		return c.JSON(tokenResponse)
+		return ctx.JSON(tokenResponse)
 	}
 }

@@ -3,7 +3,7 @@ package localai
 import (
 	"github.com/mudler/LocalAI/core/backend"
 	"github.com/mudler/LocalAI/core/config"
-	fiberContext "github.com/mudler/LocalAI/core/http/ctx"
+	"github.com/mudler/LocalAI/core/http/middleware"
 	"github.com/mudler/LocalAI/pkg/model"

 	"github.com/gofiber/fiber/v2"
@@ -24,37 +24,24 @@ import (
 // @Router /tts [post]
 func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
 	return func(c *fiber.Ctx) error {
-		input := new(schema.TTSRequest)
-
-		// Get input data from the request body
-		if err := c.BodyParser(input); err != nil {
-			return err
+		input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.TTSRequest)
+		if !ok || input.Model == "" {
+			return fiber.ErrBadRequest
 		}

-		modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, false)
-		if err != nil {
-			modelFile = input.Model
-			log.Warn().Msgf("Model not found in context: %s", input.Model)
-		}
-
-		cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
-			config.LoadOptionDebug(appConfig.Debug),
-			config.LoadOptionThreads(appConfig.Threads),
-			config.LoadOptionContextSize(appConfig.ContextSize),
-			config.LoadOptionF16(appConfig.F16),
-		)
-		if err != nil {
-			log.Err(err)
-			modelFile = input.Model
-			log.Warn().Msgf("Model not found in context: %s", input.Model)
-		} else {
-			modelFile = cfg.Model
+		cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
+		if !ok || cfg == nil {
+			return fiber.ErrBadRequest
 		}
-		log.Debug().Msgf("Request for model: %s", modelFile)

-		if cfg.Backend == "" {
-			if input.Backend != "" {
-				cfg.Backend = input.Backend
-			} else {
-				cfg.Backend = model.PiperBackend
-			}
+		log.Debug().Str("model", input.Model).Msg("LocalAI TTS Request recieved")
+
+		if input.Backend != "" {
+			cfg.Backend = input.Backend
 		}

 		if input.Language != "" {
@@ -65,7 +52,7 @@ func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfi
 			cfg.Voice = input.Voice
 		}

-		filePath, _, err := backend.ModelTTS(cfg.Backend, input.Input, modelFile, cfg.Voice, cfg.Language, ml, appConfig, *cfg)
+		filePath, _, err := backend.ModelTTS(input.Input, cfg.Voice, cfg.Language, ml, appConfig, *cfg)
 		if err != nil {
 			return err
 		}

@@ -4,9 +4,8 @@ import (
 	"github.com/gofiber/fiber/v2"
 	"github.com/mudler/LocalAI/core/backend"
 	"github.com/mudler/LocalAI/core/config"
-	fiberContext "github.com/mudler/LocalAI/core/http/ctx"
+	"github.com/mudler/LocalAI/core/http/middleware"
 	"github.com/mudler/LocalAI/core/schema"
-	"github.com/mudler/LocalAI/pkg/grpc/proto"
 	"github.com/mudler/LocalAI/pkg/model"
 	"github.com/rs/zerolog/log"
 )
@@ -19,45 +18,20 @@ import (
 // @Router /vad [post]
 func VADEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
 	return func(c *fiber.Ctx) error {
-		input := new(schema.VADRequest)
-
-		// Get input data from the request body
-		if err := c.BodyParser(input); err != nil {
-			return err
+		input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.VADRequest)
+		if !ok || input.Model == "" {
+			return fiber.ErrBadRequest
 		}

-		modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, false)
-		if err != nil {
-			modelFile = input.Model
-			log.Warn().Msgf("Model not found in context: %s", input.Model)
-		}
-
-		cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
-			config.LoadOptionDebug(appConfig.Debug),
-			config.LoadOptionThreads(appConfig.Threads),
-			config.LoadOptionContextSize(appConfig.ContextSize),
-			config.LoadOptionF16(appConfig.F16),
-		)
-		if err != nil {
-			log.Err(err)
-			modelFile = input.Model
-			log.Warn().Msgf("Model not found in context: %s", input.Model)
-		} else {
-			modelFile = cfg.Model
+		cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
+		if !ok || cfg == nil {
+			return fiber.ErrBadRequest
 		}
-		log.Debug().Msgf("Request for model: %s", modelFile)

-		opts := backend.ModelOptions(*cfg, appConfig, model.WithBackendString(cfg.Backend), model.WithModel(modelFile))
-		vadModel, err := ml.Load(opts...)
-		if err != nil {
-			return err
-		}
-		req := proto.VADRequest{
-			Audio: input.Audio,
-		}
-		resp, err := vadModel.VAD(c.Context(), &req)
+		log.Debug().Str("model", input.Model).Msg("LocalAI VAD Request recieved")
+
+		resp, err := backend.VAD(input, c.Context(), ml, appConfig, *cfg)
 		if err != nil {
 			return err
 		}
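For context, a hypothetical client for the /vad route above; the JSON field names and audio encoding are assumptions inferred from schema.VADRequest and schema.VADSegment in this diff, not a documented contract:

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

// Assumed request/response shapes; field names are inferred, not confirmed.
type vadRequest struct {
	Model string    `json:"model"`
	Audio []float32 `json:"audio"`
}

type vadSegment struct {
	Start float32 `json:"start"`
	End   float32 `json:"end"`
}

type vadResponse struct {
	Segments []vadSegment `json:"segments"`
}

func main() {
	body, _ := json.Marshal(vadRequest{
		Model: "silero-vad",
		Audio: make([]float32, 16000), // one second of silence at 16 kHz
	})
	resp, err := http.Post("http://localhost:8080/vad", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	var out vadResponse
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		panic(err)
	}
	fmt.Printf("detected %d speech segments\n", len(out.Segments))
}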

@@ -5,18 +5,19 @@ import (
 	"bytes"
 	"encoding/json"
 	"fmt"
-	"strings"
 	"time"

 	"github.com/gofiber/fiber/v2"
 	"github.com/google/uuid"
 	"github.com/mudler/LocalAI/core/backend"
 	"github.com/mudler/LocalAI/core/config"
+	"github.com/mudler/LocalAI/core/http/middleware"
 	"github.com/mudler/LocalAI/core/schema"
 	"github.com/mudler/LocalAI/pkg/functions"
+	"github.com/mudler/LocalAI/pkg/model"
 	"github.com/mudler/LocalAI/pkg/templates"
-	model "github.com/mudler/LocalAI/pkg/model"

 	"github.com/rs/zerolog/log"
 	"github.com/valyala/fasthttp"
 )
@@ -174,26 +175,20 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
 		textContentToReturn = ""
 		id = uuid.New().String()
 		created = int(time.Now().Unix())
-		// Set CorrelationID
-		correlationID := c.Get("X-Correlation-ID")
-		if len(strings.TrimSpace(correlationID)) == 0 {
-			correlationID = id
-		}
-		c.Set("X-Correlation-ID", correlationID)

-		// Opt-in extra usage flag
+		input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
+		if !ok || input.Model == "" {
+			return fiber.ErrBadRequest
+		}
 		extraUsage := c.Get("Extra-Usage", "") != ""

-		modelFile, input, err := readRequest(c, cl, ml, startupOptions, true)
-		if err != nil {
-			return fmt.Errorf("failed reading parameters from request:%w", err)
+		config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
+		if !ok || config == nil {
+			return fiber.ErrBadRequest
 		}

-		config, input, err := mergeRequestWithConfig(modelFile, input, cl, ml, startupOptions.Debug, startupOptions.Threads, startupOptions.ContextSize, startupOptions.F16)
-		if err != nil {
-			return fmt.Errorf("failed reading parameters from request:%w", err)
-		}
-		log.Debug().Msgf("Configuration read: %+v", config)
+		log.Debug().Msgf("Chat endpoint configuration read: %+v", config)

 		funcs := input.Functions
 		shouldUseFn := len(input.Functions) > 0 && config.ShouldUseFunctions()
@@ -543,7 +538,7 @@ func handleQuestion(config *config.BackendConfig, input *schema.OpenAIRequest, m
 		audios = append(audios, m.StringAudios...)
 	}

-	predFunc, err := backend.ModelInference(input.Context, prompt, input.Messages, images, videos, audios, ml, *config, o, nil)
+	predFunc, err := backend.ModelInference(input.Context, prompt, input.Messages, images, videos, audios, ml, config, o, nil)
 	if err != nil {
 		log.Error().Err(err).Msg("model inference failed")
 		return "", err

@@ -10,12 +10,13 @@ import (
 	"github.com/mudler/LocalAI/core/backend"
 	"github.com/mudler/LocalAI/core/config"
+	"github.com/mudler/LocalAI/core/http/middleware"

 	"github.com/gofiber/fiber/v2"
 	"github.com/google/uuid"
 	"github.com/mudler/LocalAI/core/schema"
 	"github.com/mudler/LocalAI/pkg/functions"
-	model "github.com/mudler/LocalAI/pkg/model"
+	"github.com/mudler/LocalAI/pkg/model"
 	"github.com/mudler/LocalAI/pkg/templates"

 	"github.com/rs/zerolog/log"
 	"github.com/valyala/fasthttp"
@@ -26,11 +27,11 @@ import (
 // @Param request body schema.OpenAIRequest true "query params"
 // @Success 200 {object} schema.OpenAIResponse "Response"
 // @Router /v1/completions [post]
 func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
-	id := uuid.New().String()
 	created := int(time.Now().Unix())

-	process := func(s string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) {
+	process := func(id string, s string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) {
 		ComputeChoices(req, s, config, appConfig, loader, func(s string, c *[]schema.Choice) {}, func(s string, tokenUsage backend.TokenUsage) bool {
 			usage := schema.OpenAIUsage{
 				PromptTokens: tokenUsage.Prompt,
@@ -63,22 +64,18 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, e
 	}

 	return func(c *fiber.Ctx) error {
-		// Add Correlation
-		c.Set("X-Correlation-ID", id)
+		// Handle Correlation
+		id := c.Get("X-Correlation-ID", uuid.New().String())

-		// Opt-in extra usage flag
 		extraUsage := c.Get("Extra-Usage", "") != ""

-		modelFile, input, err := readRequest(c, cl, ml, appConfig, true)
-		if err != nil {
-			return fmt.Errorf("failed reading parameters from request:%w", err)
+		input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
+		if !ok || input.Model == "" {
+			return fiber.ErrBadRequest
 		}

-		log.Debug().Msgf("`input`: %+v", input)
-
-		config, input, err := mergeRequestWithConfig(modelFile, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16)
-		if err != nil {
-			return fmt.Errorf("failed reading parameters from request:%w", err)
+		config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
+		if !ok || config == nil {
+			return fiber.ErrBadRequest
 		}

 		if config.ResponseFormatMap != nil {
@@ -122,7 +119,7 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, e
 			responses := make(chan schema.OpenAIResponse)

-			go process(predInput, input, config, ml, responses, extraUsage)
+			go process(id, predInput, input, config, ml, responses, extraUsage)

 			c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {

@@ -2,16 +2,17 @@ package openai

 import (
 	"encoding/json"
-	"fmt"
 	"time"

 	"github.com/mudler/LocalAI/core/backend"
 	"github.com/mudler/LocalAI/core/config"
+	"github.com/mudler/LocalAI/core/http/middleware"

 	"github.com/gofiber/fiber/v2"
 	"github.com/google/uuid"
 	"github.com/mudler/LocalAI/core/schema"
-	model "github.com/mudler/LocalAI/pkg/model"
+	"github.com/mudler/LocalAI/pkg/model"
 	"github.com/mudler/LocalAI/pkg/templates"

 	"github.com/rs/zerolog/log"
@@ -25,20 +26,21 @@ import (
 func EditEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
 	return func(c *fiber.Ctx) error {
+		input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
+		if !ok || input.Model == "" {
+			return fiber.ErrBadRequest
+		}
+
 		// Opt-in extra usage flag
 		extraUsage := c.Get("Extra-Usage", "") != ""

-		modelFile, input, err := readRequest(c, cl, ml, appConfig, true)
-		if err != nil {
-			return fmt.Errorf("failed reading parameters from request:%w", err)
+		config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
+		if !ok || config == nil {
+			return fiber.ErrBadRequest
 		}

-		config, input, err := mergeRequestWithConfig(modelFile, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16)
-		if err != nil {
-			return fmt.Errorf("failed reading parameters from request:%w", err)
-		}
-
-		log.Debug().Msgf("Parameter Config: %+v", config)
+		log.Debug().Msgf("Edit Endpoint Input : %+v", input)
+		log.Debug().Msgf("Edit Endpoint Config: %+v", *config)

 		var result []schema.Choice
 		totalTokenUsage := backend.TokenUsage{}

@@ -2,11 +2,11 @@ package openai

 import (
 	"encoding/json"
-	"fmt"
 	"time"

 	"github.com/mudler/LocalAI/core/backend"
 	"github.com/mudler/LocalAI/core/config"
+	"github.com/mudler/LocalAI/core/http/middleware"
 	"github.com/mudler/LocalAI/pkg/model"

 	"github.com/google/uuid"
@@ -23,14 +23,14 @@ import (
 // @Router /v1/embeddings [post]
 func EmbeddingsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
 	return func(c *fiber.Ctx) error {
-		model, input, err := readRequest(c, cl, ml, appConfig, true)
-		if err != nil {
-			return fmt.Errorf("failed reading parameters from request:%w", err)
+		input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
+		if !ok || input.Model == "" {
+			return fiber.ErrBadRequest
 		}

-		config, input, err := mergeRequestWithConfig(model, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16)
-		if err != nil {
-			return fmt.Errorf("failed reading parameters from request:%w", err)
+		config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
+		if !ok || config == nil {
+			return fiber.ErrBadRequest
 		}

 		log.Debug().Msgf("Parameter Config: %+v", config)

@@ -15,6 +15,7 @@ import (
 	"github.com/google/uuid"
 	"github.com/mudler/LocalAI/core/config"
+	"github.com/mudler/LocalAI/core/http/middleware"
 	"github.com/mudler/LocalAI/core/schema"

 	"github.com/mudler/LocalAI/core/backend"
@@ -66,25 +67,23 @@ func downloadFile(url string) (string, error) {
 // @Router /v1/images/generations [post]
 func ImageEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
 	return func(c *fiber.Ctx) error {
-		m, input, err := readRequest(c, cl, ml, appConfig, false)
-		if err != nil {
-			return fmt.Errorf("failed reading parameters from request:%w", err)
+		input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
+		if !ok || input.Model == "" {
+			log.Error().Msg("Image Endpoint - Invalid Input")
+			return fiber.ErrBadRequest
 		}

-		if m == "" {
-			m = "stablediffusion"
-		}
-		log.Debug().Msgf("Loading model: %+v", m)
-
-		config, input, err := mergeRequestWithConfig(m, input, cl, ml, appConfig.Debug, 0, 0, false)
-		if err != nil {
-			return fmt.Errorf("failed reading parameters from request:%w", err)
+		config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
+		if !ok || config == nil {
+			log.Error().Msg("Image Endpoint - Invalid Config")
+			return fiber.ErrBadRequest
 		}

 		src := ""
 		if input.File != "" {

 			fileData := []byte{}
+			var err error
 			// check if input.File is an URL, if so download it and save it
 			// to a temporary file
 			if strings.HasPrefix(input.File, "http://") || strings.HasPrefix(input.File, "https://") {

View file

@@ -37,7 +37,7 @@ func ComputeChoices(
 	}
 	// get the model function to call for the result
-	predFunc, err := backend.ModelInference(req.Context, predInput, req.Messages, images, videos, audios, loader, *config, o, tokenCallback)
+	predFunc, err := backend.ModelInference(req.Context, predInput, req.Messages, images, videos, audios, loader, config, o, tokenCallback)
 	if err != nil {
 		return result, backend.TokenUsage{}, err
 	}

View file

@@ -1,7 +1,6 @@
 package openai
 import (
-	"fmt"
 	"io"
 	"net/http"
 	"os"
@@ -10,6 +9,8 @@ import (
 	"github.com/mudler/LocalAI/core/backend"
 	"github.com/mudler/LocalAI/core/config"
+	"github.com/mudler/LocalAI/core/http/middleware"
+	"github.com/mudler/LocalAI/core/schema"
 	model "github.com/mudler/LocalAI/pkg/model"
 	"github.com/gofiber/fiber/v2"
@@ -25,15 +26,16 @@ import (
 // @Router /v1/audio/transcriptions [post]
 func TranscriptEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
 	return func(c *fiber.Ctx) error {
-		m, input, err := readRequest(c, cl, ml, appConfig, false)
-		if err != nil {
-			return fmt.Errorf("failed reading parameters from request:%w", err)
-		}
-		config, input, err := mergeRequestWithConfig(m, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16)
-		if err != nil {
-			return fmt.Errorf("failed reading parameters from request: %w", err)
-		}
+		input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
+		if !ok || input.Model == "" {
+			return fiber.ErrBadRequest
+		}
+		config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
+		if !ok || config == nil {
+			return fiber.ErrBadRequest
+		}
 		// retrieve the file data from the request
 		file, err := c.FormFile("file")
 		if err != nil {

View file

@@ -1,20 +1,22 @@
-package openai
+package middleware

 import (
 	"context"
 	"encoding/json"
 	"fmt"
 	"strconv"
+	"strings"

+	"github.com/gofiber/fiber/v2"
 	"github.com/google/uuid"
 	"github.com/mudler/LocalAI/core/config"
-	fiberContext "github.com/mudler/LocalAI/core/http/ctx"
 	"github.com/mudler/LocalAI/core/schema"
+	"github.com/mudler/LocalAI/core/services"
 	"github.com/mudler/LocalAI/pkg/functions"
 	"github.com/mudler/LocalAI/pkg/model"
 	"github.com/mudler/LocalAI/pkg/templates"
 	"github.com/mudler/LocalAI/pkg/utils"

-	"github.com/gofiber/fiber/v2"
 	"github.com/rs/zerolog/log"
 )
@@ -23,33 +25,166 @@ type correlationIDKeyType string
 // CorrelationIDKey to track request across process boundary
 const CorrelationIDKey correlationIDKeyType = "correlationID"

-func readRequest(c *fiber.Ctx, cl *config.BackendConfigLoader, ml *model.ModelLoader, o *config.ApplicationConfig, firstModel bool) (string, *schema.OpenAIRequest, error) {
-	input := new(schema.OpenAIRequest)
-
-	// Get input data from the request body
-	if err := c.BodyParser(input); err != nil {
-		return "", nil, fmt.Errorf("failed parsing request body: %w", err)
-	}
-
-	received, _ := json.Marshal(input)
-	// Extract or generate the correlation ID
-	correlationID := c.Get("X-Correlation-ID", uuid.New().String())
-
-	ctx, cancel := context.WithCancel(o.Context)
-	// Add the correlation ID to the new context
-	ctxWithCorrelationID := context.WithValue(ctx, CorrelationIDKey, correlationID)
-
-	input.Context = ctxWithCorrelationID
-	input.Cancel = cancel
-
-	log.Debug().Msgf("Request received: %s", string(received))
-
-	modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, firstModel)
-
-	return modelFile, input, err
-}
+type RequestExtractor struct {
+	backendConfigLoader *config.BackendConfigLoader
+	modelLoader         *model.ModelLoader
+	applicationConfig   *config.ApplicationConfig
+}
+
+func NewRequestExtractor(backendConfigLoader *config.BackendConfigLoader, modelLoader *model.ModelLoader, applicationConfig *config.ApplicationConfig) *RequestExtractor {
+	return &RequestExtractor{
+		backendConfigLoader: backendConfigLoader,
+		modelLoader:         modelLoader,
+		applicationConfig:   applicationConfig,
+	}
+}
+
+const CONTEXT_LOCALS_KEY_MODEL_NAME = "MODEL_NAME"
+const CONTEXT_LOCALS_KEY_LOCALAI_REQUEST = "LOCALAI_REQUEST"
+const CONTEXT_LOCALS_KEY_MODEL_CONFIG = "MODEL_CONFIG"
+
+// TODO: Refactor to not return error if unchanged
+func (re *RequestExtractor) setModelNameFromRequest(ctx *fiber.Ctx) {
+	model, ok := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
+	if ok && model != "" {
+		return
+	}
+	model = ctx.Params("model")
+
+	if (model == "") && ctx.Query("model") != "" {
+		model = ctx.Query("model")
+	}
+
+	if model == "" {
+		// Set model from bearer token, if available
+		bearer := strings.TrimLeft(ctx.Get("authorization"), "Bear ") // "Bearer " => "Bear" to please go-staticcheck. It looks dumb but we might as well take free performance on something called for nearly every request.
+		if bearer != "" {
+			exists, err := services.CheckIfModelExists(re.backendConfigLoader, re.modelLoader, bearer, services.ALWAYS_INCLUDE)
+			if err == nil && exists {
+				model = bearer
+			}
+		}
+	}
+
+	ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME, model)
+}
+
+func (re *RequestExtractor) BuildConstantDefaultModelNameMiddleware(defaultModelName string) fiber.Handler {
+	return func(ctx *fiber.Ctx) error {
+		re.setModelNameFromRequest(ctx)
+		localModelName, ok := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
+		if !ok || localModelName == "" {
+			ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME, defaultModelName)
+			log.Debug().Str("defaultModelName", defaultModelName).Msg("context local model name not found, setting to default")
+		}
+		return ctx.Next()
+	}
+}
+
+func (re *RequestExtractor) BuildFilteredFirstAvailableDefaultModel(filterFn config.BackendConfigFilterFn) fiber.Handler {
+	return func(ctx *fiber.Ctx) error {
+		re.setModelNameFromRequest(ctx)
+		localModelName := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
+		if localModelName != "" { // Don't overwrite existing values
+			return ctx.Next()
+		}
+
+		modelNames, err := services.ListModels(re.backendConfigLoader, re.modelLoader, filterFn, services.SKIP_IF_CONFIGURED)
+		if err != nil {
+			log.Error().Err(err).Msg("non-fatal error calling ListModels during SetDefaultModelNameToFirstAvailable()")
+			return ctx.Next()
+		}
+
+		if len(modelNames) == 0 {
+			log.Warn().Msg("SetDefaultModelNameToFirstAvailable used with no matching models installed")
+			// This is non-fatal - making it so was breaking the case of direct installation of raw models
+			// return errors.New("this endpoint requires at least one model to be installed")
+			return ctx.Next()
+		}
+
+		ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME, modelNames[0])
+		log.Debug().Str("first model name", modelNames[0]).Msg("context local model name not found, setting to the first model")
+		return ctx.Next()
+	}
+}
+
+// TODO: If context and cancel above belong on all methods, move that part of above into here!
+// Otherwise, it's in its own method below for now
+func (re *RequestExtractor) SetModelAndConfig(initializer func() schema.LocalAIRequest) fiber.Handler {
+	return func(ctx *fiber.Ctx) error {
+		input := initializer()
+		if input == nil {
+			return fmt.Errorf("unable to initialize body")
+		}
+		if err := ctx.BodyParser(input); err != nil {
+			return fmt.Errorf("failed parsing request body: %w", err)
+		}
+
+		// If this request doesn't have an associated model name, fetch it from earlier in the middleware chain
+		if input.ModelName(nil) == "" {
+			localModelName, ok := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
+			if ok && localModelName != "" {
+				log.Debug().Str("context localModelName", localModelName).Msg("overriding empty model name in request body with value found earlier in middleware chain")
+				input.ModelName(&localModelName)
+			}
+		}
+
+		cfg, err := re.backendConfigLoader.LoadBackendConfigFileByNameDefaultOptions(input.ModelName(nil), re.applicationConfig)
+
+		if err != nil {
+			log.Err(err)
+			log.Warn().Msgf("Model Configuration File not found for %q", input.ModelName(nil))
+		} else if cfg.Model == "" && input.ModelName(nil) != "" {
+			log.Debug().Str("input.ModelName", input.ModelName(nil)).Msg("config does not include model, using input")
+			cfg.Model = input.ModelName(nil)
+		}
+
+		ctx.Locals(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input)
+		ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg)
+
+		return ctx.Next()
+	}
+}
+
+func (re *RequestExtractor) SetOpenAIRequest(ctx *fiber.Ctx) error {
+	input, ok := ctx.Locals(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
+	if !ok || input.Model == "" {
+		return fiber.ErrBadRequest
+	}
+
+	cfg, ok := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
+	if !ok || cfg == nil {
+		return fiber.ErrBadRequest
+	}
+
+	// Extract or generate the correlation ID
+	correlationID := ctx.Get("X-Correlation-ID", uuid.New().String())
+	ctx.Set("X-Correlation-ID", correlationID)
+
+	c1, cancel := context.WithCancel(re.applicationConfig.Context)
+	// Add the correlation ID to the new context
+	ctxWithCorrelationID := context.WithValue(c1, CorrelationIDKey, correlationID)
+
+	input.Context = ctxWithCorrelationID
+	input.Cancel = cancel
+
+	err := mergeOpenAIRequestAndBackendConfig(cfg, input)
+	if err != nil {
+		return err
+	}
+
+	if cfg.Model == "" {
+		log.Debug().Str("input.Model", input.Model).Msg("replacing empty cfg.Model with input value")
+		cfg.Model = input.Model
+	}
+
+	ctx.Locals(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input)
+	ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg)
+
+	return ctx.Next()
+}

-func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIRequest) {
+func mergeOpenAIRequestAndBackendConfig(config *config.BackendConfig, input *schema.OpenAIRequest) error {
 	if input.Echo {
 		config.Echo = input.Echo
 	}
@@ -249,6 +384,8 @@ func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIReque
 		config.TypicalP = input.TypicalP
 	}

+	log.Debug().Str("input.Input", fmt.Sprintf("%+v", input.Input))
+
 	switch inputs := input.Input.(type) {
 	case string:
 		if inputs != "" {
@@ -305,22 +442,9 @@ func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIReque
 			config.Step = q
 		}
 	}
-}
-
-func mergeRequestWithConfig(modelFile string, input *schema.OpenAIRequest, cm *config.BackendConfigLoader, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*config.BackendConfig, *schema.OpenAIRequest, error) {
-	cfg, err := cm.LoadBackendConfigFileByName(modelFile, loader.ModelPath,
-		config.LoadOptionDebug(debug),
-		config.LoadOptionThreads(threads),
-		config.LoadOptionContextSize(ctx),
-		config.LoadOptionF16(f16),
-	)
-
-	// Set the parameters for the language model prediction
-	updateRequestConfig(cfg, input)
-
-	if !cfg.Validate() {
-		return nil, nil, fmt.Errorf("failed to validate config")
+	if config.Validate() {
+		return nil
 	}
-
-	return cfg, input, err
+	return fmt.Errorf("unable to validate configuration after merging")
 }
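Reviewer note: a minimal sketch of the composition order this file is built around, assuming the route registrations later in this diff; the /v1/example path and registerExample name are illustrative only. SetOpenAIRequest is needed only for OpenAI-shaped bodies that want the merged config and correlation-ID context.

package main

import (
	"github.com/gofiber/fiber/v2"

	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/core/http/middleware"
	"github.com/mudler/LocalAI/core/schema"
)

// registerExample wires a handler behind the extractor stages in order:
// default the model name, parse the body and load its backend config, then
// merge OpenAI fields and attach the correlation-ID context.
func registerExample(app *fiber.App, re *middleware.RequestExtractor, handler fiber.Handler) {
	app.Post("/v1/example",
		re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_CHAT)),
		re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
		re.SetOpenAIRequest,
		handler)
}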

View file

@@ -4,17 +4,26 @@ import (
 	"github.com/gofiber/fiber/v2"
 	"github.com/mudler/LocalAI/core/config"
 	"github.com/mudler/LocalAI/core/http/endpoints/elevenlabs"
+	"github.com/mudler/LocalAI/core/http/middleware"
+	"github.com/mudler/LocalAI/core/schema"
 	"github.com/mudler/LocalAI/pkg/model"
 )

 func RegisterElevenLabsRoutes(app *fiber.App,
+	re *middleware.RequestExtractor,
 	cl *config.BackendConfigLoader,
 	ml *model.ModelLoader,
 	appConfig *config.ApplicationConfig) {

 	// Elevenlabs
-	app.Post("/v1/text-to-speech/:voice-id", elevenlabs.TTSEndpoint(cl, ml, appConfig))
+	app.Post("/v1/text-to-speech/:voice-id",
+		re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TTS)),
+		re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.ElevenLabsTTSRequest) }),
+		elevenlabs.TTSEndpoint(cl, ml, appConfig))

-	app.Post("/v1/sound-generation", elevenlabs.SoundGenerationEndpoint(cl, ml, appConfig))
+	app.Post("/v1/sound-generation",
+		re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_SOUND_GENERATION)),
+		re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.ElevenLabsSoundGenerationRequest) }),
+		elevenlabs.SoundGenerationEndpoint(cl, ml, appConfig))
 }

View file

@@ -3,16 +3,22 @@ package routes
 import (
 	"github.com/mudler/LocalAI/core/config"
 	"github.com/mudler/LocalAI/core/http/endpoints/jina"
+	"github.com/mudler/LocalAI/core/http/middleware"
+	"github.com/mudler/LocalAI/core/schema"

 	"github.com/gofiber/fiber/v2"
 	"github.com/mudler/LocalAI/pkg/model"
 )

 func RegisterJINARoutes(app *fiber.App,
+	re *middleware.RequestExtractor,
 	cl *config.BackendConfigLoader,
 	ml *model.ModelLoader,
 	appConfig *config.ApplicationConfig) {

 	// POST endpoint to mimic the reranking
-	app.Post("/v1/rerank", jina.JINARerankEndpoint(cl, ml, appConfig))
+	app.Post("/v1/rerank",
+		re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_RERANK)),
+		re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.JINARerankRequest) }),
+		jina.JINARerankEndpoint(cl, ml, appConfig))
 }

View file

@@ -5,13 +5,16 @@ import (
 	"github.com/gofiber/swagger"
 	"github.com/mudler/LocalAI/core/config"
 	"github.com/mudler/LocalAI/core/http/endpoints/localai"
+	"github.com/mudler/LocalAI/core/http/middleware"
 	"github.com/mudler/LocalAI/core/p2p"
+	"github.com/mudler/LocalAI/core/schema"
 	"github.com/mudler/LocalAI/core/services"
 	"github.com/mudler/LocalAI/internal"
 	"github.com/mudler/LocalAI/pkg/model"
 )

 func RegisterLocalAIRoutes(router *fiber.App,
+	requestExtractor *middleware.RequestExtractor,
 	cl *config.BackendConfigLoader,
 	ml *model.ModelLoader,
 	appConfig *config.ApplicationConfig,
@@ -33,8 +36,18 @@ func RegisterLocalAIRoutes(router *fiber.App,
 		router.Get("/models/jobs", modelGalleryEndpointService.GetAllStatusEndpoint())
 	}

-	router.Post("/tts", localai.TTSEndpoint(cl, ml, appConfig))
-	router.Post("/vad", localai.VADEndpoint(cl, ml, appConfig))
+	router.Post("/tts",
+		requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TTS)),
+		requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.TTSRequest) }),
+		localai.TTSEndpoint(cl, ml, appConfig))
+
+	vadChain := []fiber.Handler{
+		requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_VAD)),
+		requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.VADRequest) }),
+		localai.VADEndpoint(cl, ml, appConfig),
+	}
+	router.Post("/vad", vadChain...)
+	router.Post("/v1/vad", vadChain...)

 	// Stores
 	sl := model.NewModelLoader("")
@@ -47,10 +60,14 @@ func RegisterLocalAIRoutes(router *fiber.App,
 		router.Get("/metrics", localai.LocalAIMetricsEndpoint())
 	}

-	// Experimental Backend Statistics Module
+	// Backend Statistics Module
+	// TODO: Should these use standard middlewares? Refactor later, they are extremely simple.
 	backendMonitorService := services.NewBackendMonitorService(ml, cl, appConfig) // Split out for now
 	router.Get("/backend/monitor", localai.BackendMonitorEndpoint(backendMonitorService))
 	router.Post("/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitorService))
+	// The v1/* urls are exactly the same as above - makes local e2e testing easier if they are registered.
+	router.Get("/v1/backend/monitor", localai.BackendMonitorEndpoint(backendMonitorService))
+	router.Post("/v1/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitorService))

 	// p2p
 	if p2p.IsP2PEnabled() {
@@ -67,6 +84,9 @@ func RegisterLocalAIRoutes(router *fiber.App,
 	router.Get("/system", localai.SystemInformations(ml, appConfig))

 	// misc
-	router.Post("/v1/tokenize", localai.TokenizeEndpoint(cl, ml, appConfig))
+	router.Post("/v1/tokenize",
+		requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TOKENIZE)),
+		requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.TokenizeRequest) }),
+		localai.TokenizeEndpoint(cl, ml, appConfig))
 }

View file

@@ -3,51 +3,50 @@ package routes
 import (
 	"github.com/gofiber/fiber/v2"
 	"github.com/mudler/LocalAI/core/application"
+	"github.com/mudler/LocalAI/core/config"
 	"github.com/mudler/LocalAI/core/http/endpoints/localai"
 	"github.com/mudler/LocalAI/core/http/endpoints/openai"
+	"github.com/mudler/LocalAI/core/http/middleware"
+	"github.com/mudler/LocalAI/core/schema"
 )

 func RegisterOpenAIRoutes(app *fiber.App,
+	re *middleware.RequestExtractor,
 	application *application.Application) {
 	// openAI compatible API endpoint

 	// chat
-	app.Post("/v1/chat/completions",
-		openai.ChatEndpoint(
-			application.BackendLoader(),
-			application.ModelLoader(),
-			application.TemplatesEvaluator(),
-			application.ApplicationConfig(),
-		),
-	)
-
-	app.Post("/chat/completions",
-		openai.ChatEndpoint(
-			application.BackendLoader(),
-			application.ModelLoader(),
-			application.TemplatesEvaluator(),
-			application.ApplicationConfig(),
-		),
-	)
+	chatChain := []fiber.Handler{
+		re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_CHAT)),
+		re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
+		re.SetOpenAIRequest,
+		openai.ChatEndpoint(application.BackendLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig()),
+	}
+	app.Post("/v1/chat/completions", chatChain...)
+	app.Post("/chat/completions", chatChain...)

 	// edit
-	app.Post("/v1/edits",
-		openai.EditEndpoint(
-			application.BackendLoader(),
-			application.ModelLoader(),
-			application.TemplatesEvaluator(),
-			application.ApplicationConfig(),
-		),
-	)
-
-	app.Post("/edits",
-		openai.EditEndpoint(
-			application.BackendLoader(),
-			application.ModelLoader(),
-			application.TemplatesEvaluator(),
-			application.ApplicationConfig(),
-		),
-	)
+	editChain := []fiber.Handler{
+		re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_EDIT)),
+		re.BuildConstantDefaultModelNameMiddleware("gpt-4o"),
+		re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
+		re.SetOpenAIRequest,
+		openai.EditEndpoint(application.BackendLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig()),
+	}
+	app.Post("/v1/edits", editChain...)
+	app.Post("/edits", editChain...)
+
+	// completion
+	completionChain := []fiber.Handler{
+		re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_COMPLETION)),
+		re.BuildConstantDefaultModelNameMiddleware("gpt-4o"),
+		re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
+		re.SetOpenAIRequest,
+		openai.CompletionEndpoint(application.BackendLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig()),
+	}
+	app.Post("/v1/completions", completionChain...)
+	app.Post("/completions", completionChain...)
+	app.Post("/v1/engines/:model/completions", completionChain...)

 	// assistant
 	app.Get("/v1/assistants", openai.ListAssistantsEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
@@ -81,45 +80,37 @@ func RegisterOpenAIRoutes(app *fiber.App,
 	app.Get("/v1/files/:file_id/content", openai.GetFilesContentsEndpoint(application.BackendLoader(), application.ApplicationConfig()))
 	app.Get("/files/:file_id/content", openai.GetFilesContentsEndpoint(application.BackendLoader(), application.ApplicationConfig()))

-	// completion
-	app.Post("/v1/completions",
-		openai.CompletionEndpoint(
-			application.BackendLoader(),
-			application.ModelLoader(),
-			application.TemplatesEvaluator(),
-			application.ApplicationConfig(),
-		),
-	)
-
-	app.Post("/completions",
-		openai.CompletionEndpoint(
-			application.BackendLoader(),
-			application.ModelLoader(),
-			application.TemplatesEvaluator(),
-			application.ApplicationConfig(),
-		),
-	)
-
-	app.Post("/v1/engines/:model/completions",
-		openai.CompletionEndpoint(
-			application.BackendLoader(),
-			application.ModelLoader(),
-			application.TemplatesEvaluator(),
-			application.ApplicationConfig(),
-		),
-	)
-
 	// embeddings
-	app.Post("/v1/embeddings", openai.EmbeddingsEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
-	app.Post("/embeddings", openai.EmbeddingsEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
-	app.Post("/v1/engines/:model/embeddings", openai.EmbeddingsEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
+	embeddingChain := []fiber.Handler{
+		re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_EMBEDDINGS)),
+		re.BuildConstantDefaultModelNameMiddleware("gpt-4o"),
+		re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
+		re.SetOpenAIRequest,
+		openai.EmbeddingsEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()),
+	}
+	app.Post("/v1/embeddings", embeddingChain...)
+	app.Post("/embeddings", embeddingChain...)
+	app.Post("/v1/engines/:model/embeddings", embeddingChain...)

 	// audio
-	app.Post("/v1/audio/transcriptions", openai.TranscriptEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
-	app.Post("/v1/audio/speech", localai.TTSEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
+	app.Post("/v1/audio/transcriptions",
+		re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TRANSCRIPT)),
+		re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
+		re.SetOpenAIRequest,
+		openai.TranscriptEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()),
+	)
+
+	app.Post("/v1/audio/speech",
+		re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TTS)),
+		re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.TTSRequest) }),
+		localai.TTSEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))

 	// images
-	app.Post("/v1/images/generations", openai.ImageEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
+	app.Post("/v1/images/generations",
+		re.BuildConstantDefaultModelNameMiddleware("stablediffusion"),
+		re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
+		re.SetOpenAIRequest,
+		openai.ImageEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))

 	if application.ApplicationConfig().ImageDir != "" {
 		app.Static("/generated-images", application.ApplicationConfig().ImageDir)
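Reviewer note: a hedged sketch of the observable effect of BuildConstantDefaultModelNameMiddleware on the images route (host/port assume a default local install; the request body is illustrative):

package main

import (
	"bytes"
	"net/http"
)

// A body that omits "model" should be served by the constant default
// ("stablediffusion"); an explicit "model" in the body still wins, because
// SetModelAndConfig only fills in an empty model name.
func main() {
	body := []byte(`{"prompt": "test", "size": "256x256"}`) // no "model" field
	resp, err := http.Post("http://localhost:8080/v1/images/generations", "application/json", bytes.NewReader(body))
	if err == nil {
		resp.Body.Close()
	}
}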

View file

@@ -3,6 +3,7 @@ package schema
 type ElevenLabsTTSRequest struct {
 	Text         string `json:"text" yaml:"text"`
 	ModelID      string `json:"model_id" yaml:"model_id"`
+	LanguageCode string `json:"language_code" yaml:"language_code"`
 }

 type ElevenLabsSoundGenerationRequest struct {
@@ -12,3 +13,17 @@ type ElevenLabsSoundGenerationRequest struct {
 	Temperature *float32 `json:"prompt_influence,omitempty" yaml:"prompt_influence,omitempty"`
 	DoSample    *bool    `json:"do_sample,omitempty" yaml:"do_sample,omitempty"`
 }
+
+func (elttsr *ElevenLabsTTSRequest) ModelName(s *string) string {
+	if s != nil {
+		elttsr.ModelID = *s
+	}
+	return elttsr.ModelID
+}
+
+func (elsgr *ElevenLabsSoundGenerationRequest) ModelName(s *string) string {
+	if s != nil {
+		elsgr.ModelID = *s
+	}
+	return elsgr.ModelID
+}

View file

@@ -2,10 +2,11 @@ package schema

 // RerankRequest defines the structure of the request payload
 type JINARerankRequest struct {
-	Model     string   `json:"model"`
+	BasicModelRequest
 	Query     string   `json:"query"`
 	Documents []string `json:"documents"`
 	TopN      int      `json:"top_n"`
+	Backend   string   `json:"backend"`
 }

 // DocumentResult represents a single document result

View file

@@ -6,11 +6,11 @@ import (
 )

 type BackendMonitorRequest struct {
-	Model string `json:"model" yaml:"model"`
+	BasicModelRequest
 }

 type TokenMetricsRequest struct {
-	Model string `json:"model" yaml:"model"`
+	BasicModelRequest
 }

 type BackendMonitorResponse struct {
@@ -26,7 +26,7 @@ type GalleryResponse struct {
 // @Description TTS request body
 type TTSRequest struct {
-	Model   string `json:"model" yaml:"model"` // model name or full path
+	BasicModelRequest
 	Input   string `json:"input" yaml:"input"` // text input
 	Voice   string `json:"voice" yaml:"voice"` // voice audio file or speaker id
 	Backend string `json:"backend" yaml:"backend"`
@@ -36,10 +36,19 @@ type TTSRequest struct {
 // @Description VAD request body
 type VADRequest struct {
-	Model string    `json:"model" yaml:"model"` // model name or full path
+	BasicModelRequest
 	Audio []float32 `json:"audio" yaml:"audio"` // model name or full path
 }

+type VADSegment struct {
+	Start float32 `json:"start" yaml:"start"`
+	End   float32 `json:"end" yaml:"end"`
+}
+
+type VADResponse struct {
+	Segments []VADSegment `json:"segments" yaml:"segments"`
+}
+
 type StoresSet struct {
 	Store string `json:"store,omitempty" yaml:"store,omitempty"`
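Reviewer note: a client-side sketch of the new VAD round-trip using these types. Host/port and the silence placeholder are illustrative; the e2e test later in this diff exercises the same path with real sample data.

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"

	"github.com/mudler/LocalAI/core/schema"
)

func main() {
	req := schema.VADRequest{
		BasicModelRequest: schema.BasicModelRequest{Model: "silero-vad"},
		Audio:             make([]float32, 16000), // placeholder: ~1s of silence at 16kHz
	}
	payload, err := json.Marshal(req)
	if err != nil {
		panic(err)
	}
	// POST to the newly registered /v1/vad alias
	resp, err := http.Post("http://localhost:8080/v1/vad", "application/json", bytes.NewReader(payload))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	var out schema.VADResponse
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		panic(err)
	}
	fmt.Printf("detected segments: %+v\n", out.Segments)
}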

View file

@@ -3,7 +3,7 @@ package schema
 type PredictionOptions struct {

 	// Also part of the OpenAI official spec
-	Model string `json:"model" yaml:"model"`
+	BasicModelRequest `yaml:",inline"`

 	// Also part of the OpenAI official spec
 	Language string `json:"language"`

22
core/schema/request.go Normal file
View file

@@ -0,0 +1,22 @@
+package schema
+
+// This file and type represent a generic request to LocalAI - as opposed to requests to LocalAI-specific endpoints, which live in localai.go
+
+type LocalAIRequest interface {
+	ModelName(*string) string
+}
+
+type BasicModelRequest struct {
+	Model string `json:"model" yaml:"model"`
+	// TODO: Should this also include the following fields from the OpenAI side of the world?
+	// If so, changes should be made to core/http/middleware/request.go to match
+	// Context context.Context    `json:"-"`
+	// Cancel  context.CancelFunc `json:"-"`
+}
+
+func (bmr *BasicModelRequest) ModelName(s *string) string {
+	if s != nil {
+		bmr.Model = *s
+	}
+	return bmr.Model
+}
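Reviewer note: ModelName(*string) string doubles as getter and setter, and because BasicModelRequest is embedded, the wire format is unchanged (the struct still marshals as {"model": "..."}). A short sketch of the semantics; values are illustrative:

package main

import (
	"fmt"

	"github.com/mudler/LocalAI/core/schema"
)

func main() {
	var req schema.LocalAIRequest = &schema.BasicModelRequest{Model: "original"}
	fmt.Println(req.ModelName(nil)) // nil reads: prints "original"

	override := "replacement"
	req.ModelName(&override)        // non-nil writes, then returns the new value
	fmt.Println(req.ModelName(nil)) // prints "replacement"
}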

View file

@@ -1,8 +1,8 @@
 package schema

 type TokenizeRequest struct {
+	BasicModelRequest
 	Content string `json:"content"`
-	Model   string `json:"model"`
 }

 type TokenizeResponse struct {

View file

@@ -2,7 +2,7 @@ package schema

 import "time"

-type Segment struct {
+type TranscriptionSegment struct {
 	Id    int           `json:"id"`
 	Start time.Duration `json:"start"`
 	End   time.Duration `json:"end"`
@@ -11,6 +11,6 @@ type Segment struct {
 }

 type TranscriptionResult struct {
-	Segments []Segment `json:"segments"`
+	Segments []TranscriptionSegment `json:"segments"`
 	Text     string    `json:"text"`
 }

View file

@@ -49,3 +49,15 @@ func ListModels(bcl *config.BackendConfigLoader, ml *model.ModelLoader, filter c

 	return dataModels, nil
 }
+
+func CheckIfModelExists(bcl *config.BackendConfigLoader, ml *model.ModelLoader, modelName string, looseFilePolicy LooseFilePolicy) (bool, error) {
+	filter, err := config.BuildNameFilterFn(modelName)
+	if err != nil {
+		return false, err
+	}
+	models, err := ListModels(bcl, ml, filter, looseFilePolicy)
+	if err != nil {
+		return false, err
+	}
+	return (len(models) > 0), nil
+}
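Reviewer note: a sketch of the intended call pattern, mirroring the bearer-token lookup in the new middleware; resolveModel is a hypothetical wrapper, not part of this commit:

package services_test

import (
	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/core/services"
	"github.com/mudler/LocalAI/pkg/model"
)

// resolveModel trusts a candidate name only if it maps to a known model;
// otherwise it returns "" so later defaulting middleware can take over.
func resolveModel(bcl *config.BackendConfigLoader, ml *model.ModelLoader, candidate string) string {
	exists, err := services.CheckIfModelExists(bcl, ml, candidate, services.ALWAYS_INCLUDE)
	if err != nil || !exists {
		return ""
	}
	return candidate
}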

2
go.mod
View file

@@ -239,7 +239,7 @@ require (
 	github.com/moby/sys/sequential v0.5.0 // indirect
 	github.com/moby/term v0.5.0 // indirect
 	github.com/mr-tron/base58 v1.2.0 // indirect
-	github.com/mudler/go-piper v0.0.0-20241022074816-3854e0221ffb
+	github.com/mudler/go-piper v0.0.0-20241023091659-2494246fd9fc
 	github.com/mudler/water v0.0.0-20221010214108-8c7313014ce0 // indirect
 	github.com/muesli/reflow v0.3.0 // indirect
 	github.com/muesli/termenv v0.15.2 // indirect

2
go.sum
View file

@@ -524,6 +524,8 @@ github.com/mudler/edgevpn v0.29.0 h1:SEkVyjXL6P8szUZFlL8W1EYBxvFsEIFvXlXcRfGrXYU
 github.com/mudler/edgevpn v0.29.0/go.mod h1:+kSy9b44eo97PnJ3fOnTkcTgxNXdgJBcd2bopx4leto=
 github.com/mudler/go-piper v0.0.0-20241022074816-3854e0221ffb h1:5qcuxQEpAqeV4ftV5nUt3/hB/RoTXq3MaaauOAedyXo=
 github.com/mudler/go-piper v0.0.0-20241022074816-3854e0221ffb/go.mod h1:O7SwdSWMilAWhBZMK9N9Y/oBDyMMzshE3ju8Xkexwig=
+github.com/mudler/go-piper v0.0.0-20241023091659-2494246fd9fc h1:RxwneJl1VgvikiX28EkpdAyL4yQVnJMrbquKospjHyA=
+github.com/mudler/go-piper v0.0.0-20241023091659-2494246fd9fc/go.mod h1:O7SwdSWMilAWhBZMK9N9Y/oBDyMMzshE3ju8Xkexwig=
 github.com/mudler/go-processmanager v0.0.0-20240820160718-8b802d3ecf82 h1:FVT07EI8njvsD4tC2Hw8Xhactp5AWhsQWD4oTeQuSAU=
 github.com/mudler/go-processmanager v0.0.0-20240820160718-8b802d3ecf82/go.mod h1:Urp7LG5jylKoDq0663qeBh0pINGcRl35nXdKx82PSoU=
 github.com/mudler/go-stable-diffusion v0.0.0-20240429204715-4a3cd6aeae6f h1:cxtMSRkUfy+mjIQ3yMrU0txwQ4It913NEN4m1H8WWgo=
View file

@@ -458,7 +458,7 @@ func (ml *ModelLoader) ListAvailableBackends(assetdir string) ([]string, error)
 func (ml *ModelLoader) backendLoader(opts ...Option) (client grpc.Backend, err error) {
 	o := NewOptions(opts...)

-	log.Info().Msgf("Loading model '%s' with backend %s", o.modelID, o.backendString)
+	log.Info().Str("modelID", o.modelID).Str("backend", o.backendString).Str("o.model", o.model).Msg("BackendLoader starting")

 	backend := strings.ToLower(o.backendString)
 	if realBackend, exists := Aliases[backend]; exists {

View file

@@ -56,6 +56,14 @@ func WithBackendString(backend string) Option {
 	}
 }

+func WithDefaultBackendString(backend string) Option {
+	return func(o *Options) {
+		if o.backendString == "" {
+			o.backendString = backend
+		}
+	}
+}
+
 func WithModel(modelFile string) Option {
 	return func(o *Options) {
 		o.model = modelFile
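Reviewer note: WithDefaultBackendString never overwrites a non-empty value, so an explicit WithBackendString wins regardless of where it appears in the option list. A hypothetical sketch (buildOpts and the "silero-vad" default are illustrative):

package main

import "github.com/mudler/LocalAI/pkg/model"

// buildOpts applies a fallback backend first; any explicit user choice
// appended later still takes effect.
func buildOpts(userBackend string) []model.Option {
	opts := []model.Option{model.WithDefaultBackendString("silero-vad")}
	if userBackend != "" {
		opts = append(opts, model.WithBackendString(userBackend))
	}
	return opts
}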

View file

@@ -43,7 +43,7 @@ var _ = BeforeSuite(func() {
 		apiEndpoint = "http://localhost:" + apiPort + "/v1" // So that other tests can reference this value safely.
 		defaultConfig.BaseURL = apiEndpoint
 	} else {
-		fmt.Println("Default ", apiEndpoint)
+		GinkgoWriter.Printf("docker apiEndpoint set from env: %q\n", apiEndpoint)
 		defaultConfig = openai.DefaultConfig(apiKey)
 		defaultConfig.BaseURL = apiEndpoint
 	}
@@ -95,10 +95,11 @@ func startDockerImage() {
 		PortBindings: map[docker.Port][]docker.PortBinding{
 			"8080/tcp": []docker.PortBinding{{HostPort: apiPort}},
 		},
-		Env:    []string{"MODELS_PATH=/models", "DEBUG=true", "THREADS=" + fmt.Sprint(proc)},
+		Env:    []string{"MODELS_PATH=/models", "DEBUG=true", "THREADS=" + fmt.Sprint(proc), "LOCALAI_SINGLE_ACTIVE_BACKEND=true"},
 		Mounts: []string{md + ":/models"},
 	}

+	GinkgoWriter.Printf("Launching Docker Container %q\n%+v\n", containerImageTag, options)
 	r, err := pool.RunWithOptions(options)
 	Expect(err).To(Not(HaveOccurred()))

View file

@@ -121,14 +121,13 @@ var _ = Describe("E2E test", func() {
 	Context("images", func() {
 		It("correctly", func() {
-			resp, err := client.CreateImage(context.TODO(),
-				openai.ImageRequest{
-					Prompt:  "test",
-					Quality: "1",
-					Size:    openai.CreateImageSize256x256,
-				},
-			)
-			Expect(err).ToNot(HaveOccurred())
+			req := openai.ImageRequest{
+				Prompt:  "test",
+				Quality: "1",
+				Size:    openai.CreateImageSize256x256,
+			}
+			resp, err := client.CreateImage(context.TODO(), req)
+			Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("error sending image request %+v", req))
 			Expect(len(resp.Data)).To(Equal(1), fmt.Sprint(resp))
 			Expect(resp.Data[0].URL).To(ContainSubstring("png"), fmt.Sprint(resp.Data[0].URL))
 		})
@@ -232,13 +231,42 @@ var _ = Describe("E2E test", func() {
 			Expect(resp.Text).To(ContainSubstring("This is the"), fmt.Sprint(resp.Text))
 		})
 	})
+	Context("vad", func() {
+		It("correctly", func() {
+			modelName := "silero-vad"
+			req := schema.VADRequest{
+				BasicModelRequest: schema.BasicModelRequest{
+					Model: modelName,
+				},
+				Audio: SampleVADAudio, // Use hardcoded sample data for now.
+			}
+			serialized, err := json.Marshal(req)
+			Expect(err).To(BeNil())
+			Expect(serialized).ToNot(BeNil())
+
+			vadEndpoint := apiEndpoint + "/vad"
+			resp, err := http.Post(vadEndpoint, "application/json", bytes.NewReader(serialized))
+			Expect(err).To(BeNil())
+			Expect(resp).ToNot(BeNil())
+			body, err := io.ReadAll(resp.Body)
+			Expect(err).ToNot(HaveOccurred())
+			Expect(resp.StatusCode).To(Equal(200))
+
+			deserializedResponse := schema.VADResponse{}
+			err = json.Unmarshal(body, &deserializedResponse)
+			Expect(err).To(BeNil())
+			Expect(deserializedResponse).ToNot(BeZero())
+			Expect(deserializedResponse.Segments).ToNot(BeZero())
+		})
+	})
 	Context("reranker", func() {
 		It("correctly", func() {
 			modelName := "jina-reranker-v1-base-en"
 			req := schema.JINARerankRequest{
+				BasicModelRequest: schema.BasicModelRequest{
 					Model: modelName,
+				},
 				Query: "Organic skincare products for sensitive skin",
 				Documents: []string{
 					"Eco-friendly kitchenware for modern homes",
File diff suppressed because it is too large