squash past, centralize request middleware PR

Signed-off-by: Dave Lee <dave@gray101.com>

Author: Dave Lee
Date:   2025-02-05 14:14:10 -05:00
Parent: 28a1310890
Commit: c1f30ba3a9
Signature: no known key found for this signature in database

55 changed files with 481,027 additions and 821 deletions

File diff suppressed because it is too large.


@@ -0,0 +1,22 @@
meta {
name: vad test too few
type: http
seq: 1
}
post {
url: {{PROTOCOL}}{{HOST}}:{{PORT}}/vad
body: json
auth: none
}
headers {
Content-Type: application/json
}
body:json {
{
"model": "silero-vad",
"audio": []
}
}

aio/cpu/vad.yaml (new file)

@@ -0,0 +1,8 @@
backend: silero-vad
name: silero-vad
parameters:
model: silero-vad.onnx
download_files:
- filename: silero-vad.onnx
uri: https://huggingface.co/onnx-community/silero-vad/resolve/main/onnx/model.onnx
sha256: a4a068cd6cf1ea8355b84327595838ca748ec29a25bc91fc82e6c299ccdc5808
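
The three vad.yaml profiles in this commit all register the same silero-vad model, and the Bruno request at the top exercises it over HTTP. A minimal client-side sketch of that call follows; the /vad path and the "model"/"audio" fields come from the test and schema in this commit, while the base URL and the JSON field names of the response are assumptions for illustration.

// vad_client.go: minimal sketch of a /vad request against a local instance.
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

type vadRequest struct {
	Model string    `json:"model"`
	Audio []float32 `json:"audio"`
}

type vadSegment struct {
	Start float32 `json:"start"`
	End   float32 `json:"end"`
}

type vadResponse struct {
	Segments []vadSegment `json:"segments"`
}

func main() {
	// A real request carries PCM samples; the empty slice mirrors the
	// "vad test too few" case above and should yield no usable segments.
	body, _ := json.Marshal(vadRequest{Model: "silero-vad", Audio: []float32{}})

	resp, err := http.Post("http://localhost:8080/vad", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	var out vadResponse
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		panic(err)
	}
	fmt.Printf("status=%d segments=%v\n", resp.StatusCode, out.Segments)
}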

@@ -129,7 +129,7 @@ detect_gpu
detect_gpu_size
PROFILE="${PROFILE:-$GPU_SIZE}" # default to cpu
-export MODELS="${MODELS:-/aio/${PROFILE}/embeddings.yaml,/aio/${PROFILE}/rerank.yaml,/aio/${PROFILE}/text-to-speech.yaml,/aio/${PROFILE}/image-gen.yaml,/aio/${PROFILE}/text-to-text.yaml,/aio/${PROFILE}/speech-to-text.yaml,/aio/${PROFILE}/vision.yaml}"
+export MODELS="${MODELS:-/aio/${PROFILE}/embeddings.yaml,/aio/${PROFILE}/rerank.yaml,/aio/${PROFILE}/text-to-speech.yaml,/aio/${PROFILE}/image-gen.yaml,/aio/${PROFILE}/text-to-text.yaml,/aio/${PROFILE}/speech-to-text.yaml,/aio/${PROFILE}/vad.yaml,/aio/${PROFILE}/vision.yaml}"
check_vars

aio/gpu-8g/vad.yaml (new file)

@@ -0,0 +1,8 @@
backend: silero-vad
name: silero-vad
parameters:
model: silero-vad.onnx
download_files:
- filename: silero-vad.onnx
uri: https://huggingface.co/onnx-community/silero-vad/resolve/main/onnx/model.onnx
sha256: a4a068cd6cf1ea8355b84327595838ca748ec29a25bc91fc82e6c299ccdc5808

aio/intel/vad.yaml (new file)

@@ -0,0 +1,8 @@
backend: silero-vad
name: silero-vad
parameters:
model: silero-vad.onnx
download_files:
- filename: silero-vad.onnx
uri: https://huggingface.co/onnx-community/silero-vad/resolve/main/onnx/model.onnx
sha256: a4a068cd6cf1ea8355b84327595838ca748ec29a25bc91fc82e6c299ccdc5808

@@ -145,13 +145,7 @@ func New(opts ...config.AppOption) (*Application, error) {
if options.LoadToMemory != nil {
for _, m := range options.LoadToMemory {
-cfg, err := application.BackendLoader().LoadBackendConfigFileByName(m, options.ModelPath,
-config.LoadOptionDebug(options.Debug),
-config.LoadOptionThreads(options.Threads),
-config.LoadOptionContextSize(options.ContextSize),
-config.LoadOptionF16(options.F16),
-config.ModelPath(options.ModelPath),
-)
+cfg, err := application.BackendLoader().LoadBackendConfigFileByNameDefaultOptions(m, options)
if err != nil {
return nil, err
}

@@ -33,7 +33,7 @@ type TokenUsage struct {
TimingTokenGeneration float64
}
-func ModelInference(ctx context.Context, s string, messages []schema.Message, images, videos, audios []string, loader *model.ModelLoader, c config.BackendConfig, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) {
+func ModelInference(ctx context.Context, s string, messages []schema.Message, images, videos, audios []string, loader *model.ModelLoader, c *config.BackendConfig, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) {
modelFile := c.Model
// Check if the modelFile exists, if it doesn't try to load it from the gallery
@@ -48,7 +48,7 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
}
}
-opts := ModelOptions(c, o)
+opts := ModelOptions(*c, o)
inferenceModel, err := loader.Load(opts...)
if err != nil {
return nil, err
@@ -84,7 +84,7 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
// in GRPC, the backend is supposed to answer to 1 single token if stream is not supported
fn := func() (LLMResponse, error) {
-opts := gRPCPredictOpts(c, loader.ModelPath)
+opts := gRPCPredictOpts(*c, loader.ModelPath)
opts.Prompt = s
opts.Messages = protoMessages
opts.UseTokenizerTemplate = c.TemplateConfig.UseTokenizerTemplate

@@ -9,10 +9,10 @@ import (
model "github.com/mudler/LocalAI/pkg/model"
)
-func Rerank(modelFile string, request *proto.RerankRequest, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig config.BackendConfig) (*proto.RerankResult, error) {
-opts := ModelOptions(backendConfig, appConfig, model.WithModel(modelFile))
+func Rerank(request *proto.RerankRequest, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig config.BackendConfig) (*proto.RerankResult, error) {
+opts := ModelOptions(backendConfig, appConfig)
rerankModel, err := loader.Load(opts...)
if err != nil {
return nil, err
}

@@ -13,7 +13,6 @@ import (
)
func SoundGeneration(
-modelFile string,
text string,
duration *float32,
temperature *float32,
@@ -25,8 +24,9 @@ func SoundGeneration(
backendConfig config.BackendConfig,
) (string, *proto.Result, error) {
-opts := ModelOptions(backendConfig, appConfig, model.WithModel(modelFile))
+opts := ModelOptions(backendConfig, appConfig)
soundGenModel, err := loader.Load(opts...)
if err != nil {
return "", nil, err
}
@@ -44,7 +44,7 @@ func SoundGeneration(
res, err := soundGenModel.SoundGeneration(context.Background(), &proto.SoundGenerationRequest{
Text: text,
-Model: modelFile,
+Model: backendConfig.Model,
Dst: filePath,
Sample: doSample,
Duration: duration,

@@ -4,19 +4,23 @@ import (
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/grpc"
-model "github.com/mudler/LocalAI/pkg/model"
+"github.com/mudler/LocalAI/pkg/model"
)
func ModelTokenize(s string, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (schema.TokenizeResponse, error) {
-modelFile := backendConfig.Model
var inferenceModel grpc.Backend
var err error
-opts := ModelOptions(backendConfig, appConfig, model.WithModel(modelFile))
-inferenceModel, err = loader.Load(opts...)
+opts := ModelOptions(backendConfig, appConfig)
+// TODO: looks weird, seems to be a correct merge?
+if backendConfig.Backend == "" {
+inferenceModel, err = loader.Load(opts...)
+} else {
+opts = append(opts, model.WithBackendString(backendConfig.Backend))
+inferenceModel, err = loader.Load(opts...)
+}
if err != nil {
return schema.TokenizeResponse{}, err
}

@@ -47,7 +47,7 @@ func ModelTranscription(audio, language string, translate bool, ml *model.ModelL
tks = append(tks, int(t))
}
tr.Segments = append(tr.Segments,
-schema.Segment{
+schema.TranscriptionSegment{
Text: s.Text,
Id: int(s.Id),
Start: time.Duration(s.Start),

@@ -14,28 +14,22 @@ import (
)
func ModelTTS(
-backend,
text,
-modelFile,
voice,
language string,
loader *model.ModelLoader,
appConfig *config.ApplicationConfig,
backendConfig config.BackendConfig,
) (string, *proto.Result, error) {
-bb := backend
-if bb == "" {
-bb = model.PiperBackend
-}
-opts := ModelOptions(backendConfig, appConfig, model.WithBackendString(bb), model.WithModel(modelFile))
+opts := ModelOptions(backendConfig, appConfig, model.WithDefaultBackendString(model.PiperBackend))
ttsModel, err := loader.Load(opts...)
if err != nil {
return "", nil, err
}
if ttsModel == nil {
-return "", nil, fmt.Errorf("could not load piper model")
+return "", nil, fmt.Errorf("could not load tts model %q", backendConfig.Model)
}
if err := os.MkdirAll(appConfig.AudioDir, 0750); err != nil {
@@ -45,22 +39,21 @@ func ModelTTS(
fileName := utils.GenerateUniqueFileName(appConfig.AudioDir, "tts", ".wav")
filePath := filepath.Join(appConfig.AudioDir, fileName)
-// If the model file is not empty, we pass it joined with the model path
+// We join the model name to the model path here. This seems to only be done for TTS and is HIGHLY suspect.
+// This should be addressed in a follow up PR soon.
+// Copying it over nearly verbatim, as TTS backends are not functional without this.
modelPath := ""
-if modelFile != "" {
-// If the model file is not empty, we pass it joined with the model path
-// Checking first that it exists and is not outside ModelPath
-// TODO: we should actually first check if the modelFile is looking like
-// a FS path
-mp := filepath.Join(loader.ModelPath, modelFile)
-if _, err := os.Stat(mp); err == nil {
-if err := utils.VerifyPath(mp, appConfig.ModelPath); err != nil {
-return "", nil, err
-}
-modelPath = mp
-} else {
-modelPath = modelFile
-}
+// Checking first that it exists and is not outside ModelPath
+// TODO: we should actually first check if the modelFile is looking like
+// a FS path
+mp := filepath.Join(loader.ModelPath, backendConfig.Model)
+if _, err := os.Stat(mp); err == nil {
+if err := utils.VerifyPath(mp, appConfig.ModelPath); err != nil {
+return "", nil, err
+}
+modelPath = mp
+} else {
+modelPath = backendConfig.Model // skip this step if it fails?????
}
res, err := ttsModel.TTS(context.Background(), &proto.TTSRequest{

core/backend/vad.go (new file)

@@ -0,0 +1,38 @@
package backend
import (
"context"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/grpc/proto"
"github.com/mudler/LocalAI/pkg/model"
)
func VAD(request *schema.VADRequest,
ctx context.Context,
ml *model.ModelLoader,
appConfig *config.ApplicationConfig,
backendConfig config.BackendConfig) (*schema.VADResponse, error) {
opts := ModelOptions(backendConfig, appConfig)
vadModel, err := ml.Load(opts...)
if err != nil {
return nil, err
}
req := proto.VADRequest{
Audio: request.Audio,
}
resp, err := vadModel.VAD(ctx, &req)
if err != nil {
return nil, err
}
segments := []schema.VADSegment{}
for _, s := range resp.Segments {
segments = append(segments, schema.VADSegment{Start: s.Start, End: s.End})
}
return &schema.VADResponse{
Segments: segments,
}, nil
}

@@ -86,13 +86,14 @@ func (t *SoundGenerationCMD) Run(ctx *cliContext.Context) error {
options := config.BackendConfig{}
options.SetDefaults()
options.Backend = t.Backend
+options.Model = t.Model
var inputFile *string
if t.InputFile != "" {
inputFile = &t.InputFile
}
-filePath, _, err := backend.SoundGeneration(t.Model, text,
+filePath, _, err := backend.SoundGeneration(text,
parseToFloat32Ptr(t.Duration), parseToFloat32Ptr(t.Temperature), &t.DoSample,
inputFile, parseToInt32Ptr(t.InputFileSampleDivisor), ml, opts, options)

@@ -52,8 +52,10 @@ func (t *TTSCMD) Run(ctx *cliContext.Context) error {
options := config.BackendConfig{}
options.SetDefaults()
+options.Backend = t.Backend
+options.Model = t.Model
-filePath, _, err := backend.ModelTTS(t.Backend, text, t.Model, t.Voice, t.Language, ml, opts, options)
+filePath, _, err := backend.ModelTTS(text, t.Voice, t.Language, ml, opts, options)
if err != nil {
return err
}

@@ -436,19 +436,21 @@ func (c *BackendConfig) HasTemplate() bool {
type BackendConfigUsecases int
const (
-FLAG_ANY BackendConfigUsecases = 0b000000000
-FLAG_CHAT BackendConfigUsecases = 0b000000001
-FLAG_COMPLETION BackendConfigUsecases = 0b000000010
-FLAG_EDIT BackendConfigUsecases = 0b000000100
-FLAG_EMBEDDINGS BackendConfigUsecases = 0b000001000
-FLAG_RERANK BackendConfigUsecases = 0b000010000
-FLAG_IMAGE BackendConfigUsecases = 0b000100000
-FLAG_TRANSCRIPT BackendConfigUsecases = 0b001000000
-FLAG_TTS BackendConfigUsecases = 0b010000000
-FLAG_SOUND_GENERATION BackendConfigUsecases = 0b100000000
+FLAG_ANY BackendConfigUsecases = 0b00000000000
+FLAG_CHAT BackendConfigUsecases = 0b00000000001
+FLAG_COMPLETION BackendConfigUsecases = 0b00000000010
+FLAG_EDIT BackendConfigUsecases = 0b00000000100
+FLAG_EMBEDDINGS BackendConfigUsecases = 0b00000001000
+FLAG_RERANK BackendConfigUsecases = 0b00000010000
+FLAG_IMAGE BackendConfigUsecases = 0b00000100000
+FLAG_TRANSCRIPT BackendConfigUsecases = 0b00001000000
+FLAG_TTS BackendConfigUsecases = 0b00010000000
+FLAG_SOUND_GENERATION BackendConfigUsecases = 0b00100000000
+FLAG_TOKENIZE BackendConfigUsecases = 0b01000000000
+FLAG_VAD BackendConfigUsecases = 0b10000000000
// Common Subsets
-FLAG_LLM BackendConfigUsecases = FLAG_CHAT & FLAG_COMPLETION & FLAG_EDIT
+FLAG_LLM BackendConfigUsecases = FLAG_CHAT | FLAG_COMPLETION | FLAG_EDIT
)
func GetAllBackendConfigUsecases() map[string]BackendConfigUsecases {
@@ -463,6 +465,8 @@ func GetAllBackendConfigUsecases() map[string]BackendConfigUsecases {
"FLAG_TRANSCRIPT": FLAG_TRANSCRIPT,
"FLAG_TTS": FLAG_TTS,
"FLAG_SOUND_GENERATION": FLAG_SOUND_GENERATION,
+"FLAG_TOKENIZE": FLAG_TOKENIZE,
+"FLAG_VAD": FLAG_VAD,
"FLAG_LLM": FLAG_LLM,
}
}
@@ -548,5 +552,18 @@ func (c *BackendConfig) GuessUsecases(u BackendConfigUsecases) bool {
}
}
+if (u & FLAG_TOKENIZE) == FLAG_TOKENIZE {
+tokenizeCapableBackends := []string{"llama.cpp", "rwkv"}
+if !slices.Contains(tokenizeCapableBackends, c.Backend) {
+return false
+}
+}
+if (u & FLAG_VAD) == FLAG_VAD {
+if c.Backend != "silero-vad" {
+return false
+}
+}
return true
}
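
The usecase constants above widen from 9 to 11 bits to make room for FLAG_TOKENIZE and FLAG_VAD, and FLAG_LLM switches from & to | (with &, the three flags share no bits, so the subset collapsed to zero and matched everything). A standalone sketch of the same bitmask pattern, using shortened names rather than the LocalAI types themselves:

// usecase_flags.go: illustrative sketch of the bitmask usecase pattern.
package main

import "fmt"

type Usecases int

const (
	FLAG_CHAT       Usecases = 0b001
	FLAG_COMPLETION Usecases = 0b010
	FLAG_EDIT       Usecases = 0b100

	// With '&' this constant would collapse to 0 (the flags share no bits)
	// and then match every config; '|' combines the three bits as intended.
	FLAG_LLM Usecases = FLAG_CHAT | FLAG_COMPLETION | FLAG_EDIT
)

// has reports whether every bit in want is set in u, the same
// "(u & FLAG_X) == FLAG_X" test used by GuessUsecases.
func has(u, want Usecases) bool {
	return u&want == want
}

func main() {
	caps := FLAG_CHAT | FLAG_COMPLETION
	fmt.Println(has(caps, FLAG_CHAT)) // true
	fmt.Println(has(caps, FLAG_LLM))  // false: FLAG_EDIT is missing
}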

@@ -81,10 +81,10 @@ func readMultipleBackendConfigsFromFile(file string, opts ...ConfigLoaderOption)
c := &[]*BackendConfig{}
f, err := os.ReadFile(file)
if err != nil {
-return nil, fmt.Errorf("cannot read config file: %w", err)
+return nil, fmt.Errorf("readMultipleBackendConfigsFromFile cannot read config file %q: %w", file, err)
}
if err := yaml.Unmarshal(f, c); err != nil {
-return nil, fmt.Errorf("cannot unmarshal config file: %w", err)
+return nil, fmt.Errorf("readMultipleBackendConfigsFromFile cannot unmarshal config file %q: %w", file, err)
}
for _, cc := range *c {
@@ -101,10 +101,10 @@ func readBackendConfigFromFile(file string, opts ...ConfigLoaderOption) (*Backen
c := &BackendConfig{}
f, err := os.ReadFile(file)
if err != nil {
-return nil, fmt.Errorf("cannot read config file: %w", err)
+return nil, fmt.Errorf("readBackendConfigFromFile cannot read config file %q: %w", file, err)
}
if err := yaml.Unmarshal(f, c); err != nil {
-return nil, fmt.Errorf("cannot unmarshal config file: %w", err)
+return nil, fmt.Errorf("readBackendConfigFromFile cannot unmarshal config file %q: %w", file, err)
}
c.SetDefaults(opts...)
@@ -117,7 +117,9 @@ func (bcl *BackendConfigLoader) LoadBackendConfigFileByName(modelName, modelPath
// Load a config file if present after the model name
cfg := &BackendConfig{
PredictionOptions: schema.PredictionOptions{
-Model: modelName,
+BasicModelRequest: schema.BasicModelRequest{
+Model: modelName,
+},
},
}
@@ -145,6 +147,15 @@ func (bcl *BackendConfigLoader) LoadBackendConfigFileByName(modelName, modelPath
return cfg, nil
}
+func (bcl *BackendConfigLoader) LoadBackendConfigFileByNameDefaultOptions(modelName string, appConfig *ApplicationConfig) (*BackendConfig, error) {
+return bcl.LoadBackendConfigFileByName(modelName, appConfig.ModelPath,
+LoadOptionDebug(appConfig.Debug),
+LoadOptionThreads(appConfig.Threads),
+LoadOptionContextSize(appConfig.ContextSize),
+LoadOptionF16(appConfig.F16),
+ModelPath(appConfig.ModelPath))
+}
+
// This format is currently only used when reading a single file at startup, passed in via ApplicationConfig.ConfigFile
func (bcl *BackendConfigLoader) LoadMultipleBackendConfigsSingleFile(file string, opts ...ConfigLoaderOption) error {
bcl.Lock()
@@ -167,7 +178,7 @@ func (bcl *BackendConfigLoader) LoadBackendConfig(file string, opts ...ConfigLoa
defer bcl.Unlock()
c, err := readBackendConfigFromFile(file, opts...)
if err != nil {
-return fmt.Errorf("cannot read config file: %w", err)
+return fmt.Errorf("LoadBackendConfig cannot read config file %q: %w", file, err)
}
if c.Validate() {
@@ -324,9 +335,10 @@ func (bcl *BackendConfigLoader) Preload(modelPath string) error {
func (bcl *BackendConfigLoader) LoadBackendConfigsFromPath(path string, opts ...ConfigLoaderOption) error {
bcl.Lock()
defer bcl.Unlock()
entries, err := os.ReadDir(path)
if err != nil {
-return fmt.Errorf("cannot read directory '%s': %w", path, err)
+return fmt.Errorf("LoadBackendConfigsFromPath cannot read directory '%s': %w", path, err)
}
files := make([]fs.FileInfo, 0, len(entries))
for _, entry := range entries {
@@ -344,13 +356,13 @@ func (bcl *BackendConfigLoader) LoadBackendConfigsFromPath(path string, opts ...
}
c, err := readBackendConfigFromFile(filepath.Join(path, file.Name()), opts...)
if err != nil {
-log.Error().Err(err).Msgf("cannot read config file: %s", file.Name())
+log.Error().Err(err).Str("File Name", file.Name()).Msgf("LoadBackendConfigsFromPath cannot read config file")
continue
}
if c.Validate() {
bcl.configs[c.Name] = *c
} else {
-log.Error().Err(err).Msgf("config is not valid")
+log.Error().Err(err).Str("Name", c.Name).Msgf("config is not valid")
}
}
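
The loader errors above now name the calling function and quote the file path, but they keep the %w verb, so callers can still match the underlying cause. A standard-library-only illustration (not LocalAI code) of why that matters:

// wraperr.go: the longer, prefixed messages still unwrap to the root cause via %w.
package main

import (
	"errors"
	"fmt"
	"io/fs"
	"os"
)

func readConfig(file string) error {
	if _, err := os.ReadFile(file); err != nil {
		// Same shape as the loader errors: caller name, quoted path, wrapped cause.
		return fmt.Errorf("readConfig cannot read config file %q: %w", file, err)
	}
	return nil
}

func main() {
	err := readConfig("/does/not/exist.yaml")
	fmt.Println(err)                            // full, prefixed message
	fmt.Println(errors.Is(err, fs.ErrNotExist)) // true: the cause survives wrapping
}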

@@ -161,10 +161,11 @@ func guessDefaultsFromFile(cfg *BackendConfig, modelPath string) {
}
// We try to guess only if we don't have a template defined already
-f, err := gguf.ParseGGUFFile(filepath.Join(modelPath, cfg.ModelFileName()))
+guessPath := filepath.Join(modelPath, cfg.ModelFileName())
+f, err := gguf.ParseGGUFFile(guessPath)
if err != nil {
// Only valid for gguf files
-log.Debug().Msgf("guessDefaultsFromFile: %s", "not a GGUF file")
+log.Debug().Str("filePath", guessPath).Msg("guessDefaultsFromFile: not a GGUF file")
return
}

@@ -130,7 +130,6 @@ func API(application *application.Application) (*fiber.App, error) {
return metricsService.Shutdown()
})
}
-}
// Health Checks should always be exempt from auth, so register these first
routes.HealthRoutes(router)
@@ -167,13 +166,15 @@ func API(application *application.Application) (*fiber.App, error) {
galleryService := services.NewGalleryService(application.ApplicationConfig())
galleryService.Start(application.ApplicationConfig().Context, application.BackendLoader())
-routes.RegisterElevenLabsRoutes(router, application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())
-routes.RegisterLocalAIRoutes(router, application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig(), galleryService)
-routes.RegisterOpenAIRoutes(router, application)
+requestExtractor := middleware.NewRequestExtractor(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())
+routes.RegisterElevenLabsRoutes(router, requestExtractor, application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())
+routes.RegisterLocalAIRoutes(router, requestExtractor, application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig(), galleryService)
+routes.RegisterOpenAIRoutes(router, requestExtractor, application)
if !application.ApplicationConfig().DisableWebUI {
routes.RegisterUIRoutes(router, application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig(), galleryService)
}
-routes.RegisterJINARoutes(router, application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())
+routes.RegisterJINARoutes(router, requestExtractor, application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())
httpFS := http.FS(embedDirStatic)
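
The route registration above threads a new RequestExtractor through the ElevenLabs, LocalAI, OpenAI and JINA routes, and the endpoint changes later in this commit read the parsed request and the resolved *BackendConfig back out of c.Locals. A minimal Fiber sketch of that hand-off follows; the Locals keys echo CONTEXT_LOCALS_KEY_LOCALAI_REQUEST and CONTEXT_LOCALS_KEY_MODEL_CONFIG from the diff, while the extractor body and the request/config types here are simplified stand-ins, not the actual middleware implementation.

// locals_pattern.go: middleware parses once, handlers type-assert from Locals.
package main

import (
	"github.com/gofiber/fiber/v2"
)

const (
	keyRequest     = "LOCALAI_REQUEST"
	keyModelConfig = "MODEL_CONFIG"
)

type ttsRequest struct {
	Model string `json:"model"`
	Input string `json:"input"`
}

type backendConfig struct {
	Name    string
	Backend string
}

func main() {
	app := fiber.New()

	// Middleware: parse the body once, resolve a config, stash both in Locals.
	extract := func(c *fiber.Ctx) error {
		input := new(ttsRequest)
		if err := c.BodyParser(input); err != nil {
			return fiber.ErrBadRequest
		}
		c.Locals(keyRequest, input)
		c.Locals(keyModelConfig, &backendConfig{Name: input.Model, Backend: "piper"})
		return c.Next()
	}

	// Handler: type-assert what the middleware stored, as the endpoints in this commit do.
	app.Post("/tts", extract, func(c *fiber.Ctx) error {
		input, ok := c.Locals(keyRequest).(*ttsRequest)
		if !ok || input.Model == "" {
			return fiber.ErrBadRequest
		}
		cfg, ok := c.Locals(keyModelConfig).(*backendConfig)
		if !ok || cfg == nil {
			return fiber.ErrBadRequest
		}
		return c.JSON(fiber.Map{"model": cfg.Name, "backend": cfg.Backend})
	})

	_ = app.Listen(":8080")
}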

Deleted file (47 lines):

@@ -1,47 +0,0 @@
package fiberContext
import (
"fmt"
"strings"
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/pkg/model"
"github.com/rs/zerolog/log"
)
// ModelFromContext returns the model from the context
// If no model is specified, it will take the first available
// Takes a model string as input which should be the one received from the user request.
// It returns the model name resolved from the context and an error if any.
func ModelFromContext(ctx *fiber.Ctx, cl *config.BackendConfigLoader, loader *model.ModelLoader, modelInput string, firstModel bool) (string, error) {
if ctx.Params("model") != "" {
modelInput = ctx.Params("model")
}
if ctx.Query("model") != "" {
modelInput = ctx.Query("model")
}
// Set model from bearer token, if available
bearer := strings.TrimLeft(ctx.Get("authorization"), "Bear ") // Reduced duplicate characters of Bearer
bearerExists := bearer != "" && loader.ExistsInModelPath(bearer)
// If no model was specified, take the first available
if modelInput == "" && !bearerExists && firstModel {
models, _ := services.ListModels(cl, loader, config.NoFilterFn, services.SKIP_IF_CONFIGURED)
if len(models) > 0 {
modelInput = models[0]
log.Debug().Msgf("No model specified, using: %s", modelInput)
} else {
log.Debug().Msgf("No model specified, returning error")
return "", fmt.Errorf("no model specified")
}
}
// If a model is found in bearer token takes precedence
if bearerExists {
log.Debug().Msgf("Using model from bearer token: %s", bearer)
modelInput = bearer
}
return modelInput, nil
}

@@ -4,7 +4,7 @@ import (
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
-fiberContext "github.com/mudler/LocalAI/core/http/ctx"
+"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/model"
"github.com/rs/zerolog/log"
@@ -17,45 +17,21 @@ import (
// @Router /v1/sound-generation [post]
func SoundGenerationEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
-input := new(schema.ElevenLabsSoundGenerationRequest)
-
-// Get input data from the request body
-if err := c.BodyParser(input); err != nil {
-return err
+input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.ElevenLabsSoundGenerationRequest)
+if !ok || input.ModelID == "" {
+return fiber.ErrBadRequest
}
-modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.ModelID, false)
-if err != nil {
-modelFile = input.ModelID
-log.Warn().Str("ModelID", input.ModelID).Msg("Model not found in context")
+cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
+if !ok || cfg == nil {
+return fiber.ErrBadRequest
}
-cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
-config.LoadOptionDebug(appConfig.Debug),
-config.LoadOptionThreads(appConfig.Threads),
-config.LoadOptionContextSize(appConfig.ContextSize),
-config.LoadOptionF16(appConfig.F16),
-)
-if err != nil {
-modelFile = input.ModelID
-log.Warn().Str("Request ModelID", input.ModelID).Err(err).Msg("error during LoadBackendConfigFileByName, using request ModelID")
-} else {
-if input.ModelID != "" {
-modelFile = input.ModelID
-} else {
-modelFile = cfg.Model
-}
-}
log.Debug().Str("modelFile", "modelFile").Str("backend", cfg.Backend).Msg("Sound Generation Request about to be sent to backend")
-if input.Duration != nil {
-log.Debug().Float32("duration", *input.Duration).Msg("duration set")
-}
-if input.Temperature != nil {
-log.Debug().Float32("temperature", *input.Temperature).Msg("temperature set")
-}
// TODO: Support uploading files?
-filePath, _, err := backend.SoundGeneration(modelFile, input.Text, input.Duration, input.Temperature, input.DoSample, nil, nil, ml, appConfig, *cfg)
+filePath, _, err := backend.SoundGeneration(input.Text, input.Duration, input.Temperature, input.DoSample, nil, nil, ml, appConfig, *cfg)
if err != nil {
return err
}

@@ -3,7 +3,7 @@ package elevenlabs
import (
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
-fiberContext "github.com/mudler/LocalAI/core/http/ctx"
+"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
@@ -20,39 +20,21 @@ import (
func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
-input := new(schema.ElevenLabsTTSRequest)
voiceID := c.Params("voice-id")
-// Get input data from the request body
-if err := c.BodyParser(input); err != nil {
-return err
+input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.ElevenLabsTTSRequest)
+if !ok || input.ModelID == "" {
+return fiber.ErrBadRequest
}
-modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.ModelID, false)
-if err != nil {
-modelFile = input.ModelID
-log.Warn().Msgf("Model not found in context: %s", input.ModelID)
+cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
+if !ok || cfg == nil {
+return fiber.ErrBadRequest
}
-cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
-config.LoadOptionDebug(appConfig.Debug),
-config.LoadOptionThreads(appConfig.Threads),
-config.LoadOptionContextSize(appConfig.ContextSize),
-config.LoadOptionF16(appConfig.F16),
-)
-if err != nil {
-modelFile = input.ModelID
-log.Warn().Msgf("Model not found in context: %s", input.ModelID)
-} else {
-if input.ModelID != "" {
-modelFile = input.ModelID
-} else {
-modelFile = cfg.Model
-}
-}
-log.Debug().Msgf("Request for model: %s", modelFile)
+log.Debug().Str("modelName", input.ModelID).Msg("elevenlabs TTS request recieved")
-filePath, _, err := backend.ModelTTS(cfg.Backend, input.Text, modelFile, "", voiceID, ml, appConfig, *cfg)
+filePath, _, err := backend.ModelTTS(input.Text, voiceID, input.LanguageCode, ml, appConfig, *cfg)
if err != nil {
return err
}

@@ -3,9 +3,9 @@ package jina
import (
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
+"github.com/mudler/LocalAI/core/http/middleware"
"github.com/gofiber/fiber/v2"
-fiberContext "github.com/mudler/LocalAI/core/http/ctx"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/grpc/proto"
"github.com/mudler/LocalAI/pkg/model"
@@ -19,58 +19,32 @@ import (
// @Router /v1/rerank [post]
func JINARerankEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
-req := new(schema.JINARerankRequest)
-if err := c.BodyParser(req); err != nil {
-return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
-"error": "Cannot parse JSON",
-})
+input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.JINARerankRequest)
+if !ok || input.Model == "" {
+return fiber.ErrBadRequest
}
-input := new(schema.TTSRequest)
-
-// Get input data from the request body
-if err := c.BodyParser(input); err != nil {
-return err
+cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
+if !ok || cfg == nil {
+return fiber.ErrBadRequest
}
-modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, false)
-if err != nil {
-modelFile = input.Model
-log.Warn().Msgf("Model not found in context: %s", input.Model)
-}
-
-cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
-config.LoadOptionDebug(appConfig.Debug),
-config.LoadOptionThreads(appConfig.Threads),
-config.LoadOptionContextSize(appConfig.ContextSize),
-config.LoadOptionF16(appConfig.F16),
-)
-if err != nil {
-modelFile = input.Model
-log.Warn().Msgf("Model not found in context: %s", input.Model)
-} else {
-modelFile = cfg.Model
-}
-
-log.Debug().Msgf("Request for model: %s", modelFile)
-
-if input.Backend != "" {
-cfg.Backend = input.Backend
-}
+log.Debug().Str("model", input.Model).Msg("JINA Rerank Request recieved")
request := &proto.RerankRequest{
-Query: req.Query,
-TopN: int32(req.TopN),
-Documents: req.Documents,
+Query: input.Query,
+TopN: int32(input.TopN),
+Documents: input.Documents,
}
-results, err := backend.Rerank(modelFile, request, ml, appConfig, *cfg)
+results, err := backend.Rerank(request, ml, appConfig, *cfg)
if err != nil {
return err
}
response := &schema.JINARerankResponse{
-Model: req.Model,
+Model: input.Model,
}
for _, r := range results.Results {

@@ -4,13 +4,15 @@ import (
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
-fiberContext "github.com/mudler/LocalAI/core/http/ctx"
+"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema"
"github.com/rs/zerolog/log"
"github.com/mudler/LocalAI/pkg/model"
)
+// TODO: This is not yet in use. Needs middleware rework, since it is not referenced.
// TokenMetricsEndpoint is an endpoint to get TokensProcessed Per Second for Active SlotID
//
// @Summary Get TokenMetrics for Active Slot.
@@ -29,18 +31,13 @@ func TokenMetricsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader,
return err
}
-modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, false)
-if err != nil {
+modelFile, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
+if !ok || modelFile != "" {
modelFile = input.Model
log.Warn().Msgf("Model not found in context: %s", input.Model)
}
-cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
-config.LoadOptionDebug(appConfig.Debug),
-config.LoadOptionThreads(appConfig.Threads),
-config.LoadOptionContextSize(appConfig.ContextSize),
-config.LoadOptionF16(appConfig.F16),
-)
+cfg, err := cl.LoadBackendConfigFileByNameDefaultOptions(modelFile, appConfig)
if err != nil {
log.Err(err)

@@ -4,10 +4,9 @@ import (
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
-fiberContext "github.com/mudler/LocalAI/core/http/ctx"
+"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/model"
-"github.com/rs/zerolog/log"
)
// TokenizeEndpoint exposes a REST API to tokenize the content
@@ -16,42 +15,21 @@ import (
// @Success 200 {object} schema.TokenizeResponse "Response"
// @Router /v1/tokenize [post]
func TokenizeEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
-return func(c *fiber.Ctx) error {
-
-input := new(schema.TokenizeRequest)
-
-// Get input data from the request body
-if err := c.BodyParser(input); err != nil {
-return err
+return func(ctx *fiber.Ctx) error {
+input, ok := ctx.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.TokenizeRequest)
+if !ok || input.Model == "" {
+return fiber.ErrBadRequest
}
-modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, false)
-if err != nil {
-modelFile = input.Model
-log.Warn().Msgf("Model not found in context: %s", input.Model)
+cfg, ok := ctx.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
+if !ok || cfg == nil {
+return fiber.ErrBadRequest
}
-cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
-config.LoadOptionDebug(appConfig.Debug),
-config.LoadOptionThreads(appConfig.Threads),
-config.LoadOptionContextSize(appConfig.ContextSize),
-config.LoadOptionF16(appConfig.F16),
-)
-if err != nil {
-log.Err(err)
-modelFile = input.Model
-log.Warn().Msgf("Model not found in context: %s", input.Model)
-} else {
-modelFile = cfg.Model
-}
-
-log.Debug().Msgf("Request for model: %s", modelFile)
-
tokenResponse, err := backend.ModelTokenize(input.Content, ml, *cfg, appConfig)
if err != nil {
return err
}
-return c.JSON(tokenResponse)
+return ctx.JSON(tokenResponse)
}
}

@@ -3,7 +3,7 @@ package localai
import (
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
-fiberContext "github.com/mudler/LocalAI/core/http/ctx"
+"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
@@ -24,37 +24,24 @@ import (
// @Router /tts [post]
func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
-input := new(schema.TTSRequest)
-
-// Get input data from the request body
-if err := c.BodyParser(input); err != nil {
-return err
+input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.TTSRequest)
+if !ok || input.Model == "" {
+return fiber.ErrBadRequest
}
-modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, false)
-if err != nil {
-modelFile = input.Model
-log.Warn().Msgf("Model not found in context: %s", input.Model)
+cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
+if !ok || cfg == nil {
+return fiber.ErrBadRequest
}
-cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
-config.LoadOptionDebug(appConfig.Debug),
-config.LoadOptionThreads(appConfig.Threads),
-config.LoadOptionContextSize(appConfig.ContextSize),
-config.LoadOptionF16(appConfig.F16),
-)
-if err != nil {
-log.Err(err)
-modelFile = input.Model
-log.Warn().Msgf("Model not found in context: %s", input.Model)
-} else {
-modelFile = cfg.Model
-}
-
-log.Debug().Msgf("Request for model: %s", modelFile)
+log.Debug().Str("model", input.Model).Msg("LocalAI TTS Request recieved")
-if input.Backend != "" {
-cfg.Backend = input.Backend
-}
+if cfg.Backend == "" {
+if input.Backend != "" {
+cfg.Backend = input.Backend
+} else {
+cfg.Backend = model.PiperBackend
+}
+}
if input.Language != "" {
@@ -65,7 +52,7 @@ func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfi
cfg.Voice = input.Voice
}
-filePath, _, err := backend.ModelTTS(cfg.Backend, input.Input, modelFile, cfg.Voice, cfg.Language, ml, appConfig, *cfg)
+filePath, _, err := backend.ModelTTS(input.Input, cfg.Voice, cfg.Language, ml, appConfig, *cfg)
if err != nil {
return err
}

@@ -4,9 +4,8 @@ import (
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
-fiberContext "github.com/mudler/LocalAI/core/http/ctx"
+"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema"
-"github.com/mudler/LocalAI/pkg/grpc/proto"
"github.com/mudler/LocalAI/pkg/model"
"github.com/rs/zerolog/log"
)
@@ -19,45 +18,20 @@ import (
// @Router /vad [post]
func VADEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
-input := new(schema.VADRequest)
-
-// Get input data from the request body
-if err := c.BodyParser(input); err != nil {
-return err
+input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.VADRequest)
+if !ok || input.Model == "" {
+return fiber.ErrBadRequest
}
-modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, false)
-if err != nil {
-modelFile = input.Model
-log.Warn().Msgf("Model not found in context: %s", input.Model)
+cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
+if !ok || cfg == nil {
+return fiber.ErrBadRequest
}
-cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
-config.LoadOptionDebug(appConfig.Debug),
-config.LoadOptionThreads(appConfig.Threads),
-config.LoadOptionContextSize(appConfig.ContextSize),
-config.LoadOptionF16(appConfig.F16),
-)
-if err != nil {
-log.Err(err)
-modelFile = input.Model
-log.Warn().Msgf("Model not found in context: %s", input.Model)
-} else {
-modelFile = cfg.Model
-}
-
-log.Debug().Msgf("Request for model: %s", modelFile)
-
-opts := backend.ModelOptions(*cfg, appConfig, model.WithBackendString(cfg.Backend), model.WithModel(modelFile))
-vadModel, err := ml.Load(opts...)
-if err != nil {
-return err
-}
-req := proto.VADRequest{
-Audio: input.Audio,
-}
-resp, err := vadModel.VAD(c.Context(), &req)
+log.Debug().Str("model", input.Model).Msg("LocalAI VAD Request recieved")
+
+resp, err := backend.VAD(input, c.Context(), ml, appConfig, *cfg)
if err != nil {
return err
}

@@ -5,18 +5,19 @@ import (
"bytes"
"encoding/json"
"fmt"
-"strings"
"time"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
+"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/functions"
+"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/templates"
-model "github.com/mudler/LocalAI/pkg/model"
"github.com/rs/zerolog/log"
"github.com/valyala/fasthttp"
)
@@ -174,26 +175,20 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
textContentToReturn = ""
id = uuid.New().String()
created = int(time.Now().Unix())
-// Set CorrelationID
-correlationID := c.Get("X-Correlation-ID")
-if len(strings.TrimSpace(correlationID)) == 0 {
-correlationID = id
+input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
+if !ok || input.Model == "" {
+return fiber.ErrBadRequest
}
-c.Set("X-Correlation-ID", correlationID)
-// Opt-in extra usage flag
extraUsage := c.Get("Extra-Usage", "") != ""
-modelFile, input, err := readRequest(c, cl, ml, startupOptions, true)
-if err != nil {
-return fmt.Errorf("failed reading parameters from request:%w", err)
+config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
+if !ok || config == nil {
+return fiber.ErrBadRequest
}
-config, input, err := mergeRequestWithConfig(modelFile, input, cl, ml, startupOptions.Debug, startupOptions.Threads, startupOptions.ContextSize, startupOptions.F16)
-if err != nil {
-return fmt.Errorf("failed reading parameters from request:%w", err)
-}
-log.Debug().Msgf("Configuration read: %+v", config)
+log.Debug().Msgf("Chat endpoint configuration read: %+v", config)
funcs := input.Functions
shouldUseFn := len(input.Functions) > 0 && config.ShouldUseFunctions()
@@ -539,7 +534,7 @@ func handleQuestion(config *config.BackendConfig, input *schema.OpenAIRequest, m
audios = append(audios, m.StringAudios...)
}
-predFunc, err := backend.ModelInference(input.Context, prompt, input.Messages, images, videos, audios, ml, *config, o, nil)
+predFunc, err := backend.ModelInference(input.Context, prompt, input.Messages, images, videos, audios, ml, config, o, nil)
if err != nil {
log.Error().Err(err).Msg("model inference failed")
return "", err

@@ -10,12 +10,13 @@ import (
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
+"github.com/mudler/LocalAI/core/http/middleware"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/functions"
-model "github.com/mudler/LocalAI/pkg/model"
+"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/templates"
"github.com/rs/zerolog/log"
"github.com/valyala/fasthttp"
@@ -26,11 +27,11 @@ import (
// @Param request body schema.OpenAIRequest true "query params"
// @Success 200 {object} schema.OpenAIResponse "Response"
// @Router /v1/completions [post]
func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
-id := uuid.New().String()
created := int(time.Now().Unix())
-process := func(s string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) {
+process := func(id string, s string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) {
ComputeChoices(req, s, config, appConfig, loader, func(s string, c *[]schema.Choice) {}, func(s string, tokenUsage backend.TokenUsage) bool {
usage := schema.OpenAIUsage{
PromptTokens: tokenUsage.Prompt,
@@ -63,22 +64,18 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, e
}
return func(c *fiber.Ctx) error {
-// Add Correlation
-c.Set("X-Correlation-ID", id)
-// Opt-in extra usage flag
+// Handle Correlation
+id := c.Get("X-Correlation-ID", uuid.New().String())
extraUsage := c.Get("Extra-Usage", "") != ""
-modelFile, input, err := readRequest(c, cl, ml, appConfig, true)
-if err != nil {
-return fmt.Errorf("failed reading parameters from request:%w", err)
+input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
+if !ok || input.Model == "" {
+return fiber.ErrBadRequest
}
-log.Debug().Msgf("`input`: %+v", input)
-
-config, input, err := mergeRequestWithConfig(modelFile, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16)
-if err != nil {
-return fmt.Errorf("failed reading parameters from request:%w", err)
+config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
+if !ok || config == nil {
+return fiber.ErrBadRequest
}
if config.ResponseFormatMap != nil {
@@ -122,7 +119,7 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, e
responses := make(chan schema.OpenAIResponse)
-go process(predInput, input, config, ml, responses, extraUsage)
+go process(id, predInput, input, config, ml, responses, extraUsage)
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {

@@ -2,16 +2,17 @@ package openai
import (
"encoding/json"
-"fmt"
"time"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
+"github.com/mudler/LocalAI/core/http/middleware"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/mudler/LocalAI/core/schema"
-model "github.com/mudler/LocalAI/pkg/model"
+"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/templates"
"github.com/rs/zerolog/log"
@@ -25,20 +26,21 @@ import (
func EditEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
+input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
+if !ok || input.Model == "" {
+return fiber.ErrBadRequest
+}
// Opt-in extra usage flag
extraUsage := c.Get("Extra-Usage", "") != ""
-modelFile, input, err := readRequest(c, cl, ml, appConfig, true)
-if err != nil {
-return fmt.Errorf("failed reading parameters from request:%w", err)
+config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
+if !ok || config == nil {
+return fiber.ErrBadRequest
}
-config, input, err := mergeRequestWithConfig(modelFile, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16)
-if err != nil {
-return fmt.Errorf("failed reading parameters from request:%w", err)
-}
-log.Debug().Msgf("Parameter Config: %+v", config)
+log.Debug().Msgf("Edit Endpoint Input : %+v", input)
+log.Debug().Msgf("Edit Endpoint Config: %+v", *config)
var result []schema.Choice
totalTokenUsage := backend.TokenUsage{}

@@ -2,11 +2,11 @@ package openai
import (
"encoding/json"
-"fmt"
"time"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
+"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/pkg/model"
"github.com/google/uuid"
@@ -23,14 +23,14 @@ import (
// @Router /v1/embeddings [post]
func EmbeddingsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
-model, input, err := readRequest(c, cl, ml, appConfig, true)
-if err != nil {
-return fmt.Errorf("failed reading parameters from request:%w", err)
+input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
+if !ok || input.Model == "" {
+return fiber.ErrBadRequest
}
-config, input, err := mergeRequestWithConfig(model, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16)
-if err != nil {
-return fmt.Errorf("failed reading parameters from request:%w", err)
+config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
+if !ok || config == nil {
+return fiber.ErrBadRequest
}
log.Debug().Msgf("Parameter Config: %+v", config)

@@ -15,6 +15,7 @@ import (
"github.com/google/uuid"
"github.com/mudler/LocalAI/core/config"
+"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/backend"
@@ -66,25 +67,23 @@ func downloadFile(url string) (string, error) {
// @Router /v1/images/generations [post]
func ImageEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
-m, input, err := readRequest(c, cl, ml, appConfig, false)
-if err != nil {
-return fmt.Errorf("failed reading parameters from request:%w", err)
+input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
+if !ok || input.Model == "" {
+log.Error().Msg("Image Endpoint - Invalid Input")
+return fiber.ErrBadRequest
}
-if m == "" {
-m = "stablediffusion"
-}
-log.Debug().Msgf("Loading model: %+v", m)
-
-config, input, err := mergeRequestWithConfig(m, input, cl, ml, appConfig.Debug, 0, 0, false)
-if err != nil {
-return fmt.Errorf("failed reading parameters from request:%w", err)
+config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
+if !ok || config == nil {
+log.Error().Msg("Image Endpoint - Invalid Config")
+return fiber.ErrBadRequest
}
src := ""
if input.File != "" {
fileData := []byte{}
+var err error
// check if input.File is an URL, if so download it and save it
// to a temporary file
if strings.HasPrefix(input.File, "http://") || strings.HasPrefix(input.File, "https://") {

@@ -37,7 +37,7 @@ func ComputeChoices(
}
// get the model function to call for the result
-predFunc, err := backend.ModelInference(req.Context, predInput, req.Messages, images, videos, audios, loader, *config, o, tokenCallback)
+predFunc, err := backend.ModelInference(req.Context, predInput, req.Messages, images, videos, audios, loader, config, o, tokenCallback)
if err != nil {
return result, backend.TokenUsage{}, err
}


@@ -1,7 +1,6 @@
 package openai
 import (
-	"fmt"
 	"io"
 	"net/http"
 	"os"
@@ -10,6 +9,8 @@ import (
 	"github.com/mudler/LocalAI/core/backend"
 	"github.com/mudler/LocalAI/core/config"
+	"github.com/mudler/LocalAI/core/http/middleware"
+	"github.com/mudler/LocalAI/core/schema"
 	model "github.com/mudler/LocalAI/pkg/model"
 	"github.com/gofiber/fiber/v2"
@@ -25,15 +26,16 @@ import (
 // @Router /v1/audio/transcriptions [post]
 func TranscriptEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
 	return func(c *fiber.Ctx) error {
-		m, input, err := readRequest(c, cl, ml, appConfig, false)
-		if err != nil {
-			return fmt.Errorf("failed reading parameters from request:%w", err)
-		}
-		config, input, err := mergeRequestWithConfig(m, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16)
-		if err != nil {
-			return fmt.Errorf("failed reading parameters from request: %w", err)
-		}
+		input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
+		if !ok || input.Model == "" {
+			return fiber.ErrBadRequest
+		}
+		config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
+		if !ok || config == nil {
+			return fiber.ErrBadRequest
+		}
 		// retrieve the file data from the request
 		file, err := c.FormFile("file")
 		if err != nil {


@ -1,326 +1,450 @@
package openai package middleware
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"strconv" "strconv"
"strings"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/config"
fiberContext "github.com/mudler/LocalAI/core/http/ctx" "github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/pkg/functions" "github.com/mudler/LocalAI/pkg/functions"
"github.com/mudler/LocalAI/pkg/model" "github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/templates" "github.com/mudler/LocalAI/pkg/templates"
"github.com/mudler/LocalAI/pkg/utils" "github.com/mudler/LocalAI/pkg/utils"
"github.com/rs/zerolog/log"
) "github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
type correlationIDKeyType string )
// CorrelationIDKey to track request across process boundary type correlationIDKeyType string
const CorrelationIDKey correlationIDKeyType = "correlationID"
// CorrelationIDKey to track request across process boundary
func readRequest(c *fiber.Ctx, cl *config.BackendConfigLoader, ml *model.ModelLoader, o *config.ApplicationConfig, firstModel bool) (string, *schema.OpenAIRequest, error) { const CorrelationIDKey correlationIDKeyType = "correlationID"
input := new(schema.OpenAIRequest)
type RequestExtractor struct {
// Get input data from the request body backendConfigLoader *config.BackendConfigLoader
if err := c.BodyParser(input); err != nil { modelLoader *model.ModelLoader
return "", nil, fmt.Errorf("failed parsing request body: %w", err) applicationConfig *config.ApplicationConfig
} }
received, _ := json.Marshal(input) func NewRequestExtractor(backendConfigLoader *config.BackendConfigLoader, modelLoader *model.ModelLoader, applicationConfig *config.ApplicationConfig) *RequestExtractor {
// Extract or generate the correlation ID return &RequestExtractor{
correlationID := c.Get("X-Correlation-ID", uuid.New().String()) backendConfigLoader: backendConfigLoader,
modelLoader: modelLoader,
ctx, cancel := context.WithCancel(o.Context) applicationConfig: applicationConfig,
// Add the correlation ID to the new context }
ctxWithCorrelationID := context.WithValue(ctx, CorrelationIDKey, correlationID) }
input.Context = ctxWithCorrelationID const CONTEXT_LOCALS_KEY_MODEL_NAME = "MODEL_NAME"
input.Cancel = cancel const CONTEXT_LOCALS_KEY_LOCALAI_REQUEST = "LOCALAI_REQUEST"
const CONTEXT_LOCALS_KEY_MODEL_CONFIG = "MODEL_CONFIG"
log.Debug().Msgf("Request received: %s", string(received))
// TODO: Refactor to not return error if unchanged
modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, firstModel) func (re *RequestExtractor) setModelNameFromRequest(ctx *fiber.Ctx) {
model, ok := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
return modelFile, input, err if ok && model != "" {
} return
}
func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIRequest) { model = ctx.Params("model")
if input.Echo {
config.Echo = input.Echo if (model == "") && ctx.Query("model") != "" {
} model = ctx.Query("model")
if input.TopK != nil { }
config.TopK = input.TopK
} if model == "" {
if input.TopP != nil { // Set model from bearer token, if available
config.TopP = input.TopP bearer := strings.TrimLeft(ctx.Get("authorization"), "Bear ") // "Bearer " => "Bear" to please go-staticcheck. It looks dumb but we might as well take free performance on something called for nearly every request.
} if bearer != "" {
exists, err := services.CheckIfModelExists(re.backendConfigLoader, re.modelLoader, bearer, services.ALWAYS_INCLUDE)
if input.Backend != "" { if err == nil && exists {
config.Backend = input.Backend model = bearer
} }
}
if input.ClipSkip != 0 { }
config.Diffusers.ClipSkip = input.ClipSkip
} ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME, model)
}
if input.ModelBaseName != "" {
config.AutoGPTQ.ModelBaseName = input.ModelBaseName func (re *RequestExtractor) BuildConstantDefaultModelNameMiddleware(defaultModelName string) fiber.Handler {
} return func(ctx *fiber.Ctx) error {
re.setModelNameFromRequest(ctx)
if input.NegativePromptScale != 0 { localModelName, ok := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
config.NegativePromptScale = input.NegativePromptScale if !ok || localModelName == "" {
} ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME, defaultModelName)
log.Debug().Str("defaultModelName", defaultModelName).Msg("context local model name not found, setting to default")
if input.UseFastTokenizer { }
config.UseFastTokenizer = input.UseFastTokenizer return ctx.Next()
} }
}
if input.NegativePrompt != "" {
config.NegativePrompt = input.NegativePrompt func (re *RequestExtractor) BuildFilteredFirstAvailableDefaultModel(filterFn config.BackendConfigFilterFn) fiber.Handler {
} return func(ctx *fiber.Ctx) error {
re.setModelNameFromRequest(ctx)
if input.RopeFreqBase != 0 { localModelName := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
config.RopeFreqBase = input.RopeFreqBase if localModelName != "" { // Don't overwrite existing values
} return ctx.Next()
}
if input.RopeFreqScale != 0 {
config.RopeFreqScale = input.RopeFreqScale modelNames, err := services.ListModels(re.backendConfigLoader, re.modelLoader, filterFn, services.SKIP_IF_CONFIGURED)
} if err != nil {
log.Error().Err(err).Msg("non-fatal error calling ListModels during SetDefaultModelNameToFirstAvailable()")
if input.Grammar != "" { return ctx.Next()
config.Grammar = input.Grammar }
}
if len(modelNames) == 0 {
if input.Temperature != nil { log.Warn().Msg("SetDefaultModelNameToFirstAvailable used with no matching models installed")
config.Temperature = input.Temperature // This is non-fatal - making it so was breaking the case of direct installation of raw models
} // return errors.New("this endpoint requires at least one model to be installed")
return ctx.Next()
if input.Maxtokens != nil { }
config.Maxtokens = input.Maxtokens
} ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME, modelNames[0])
log.Debug().Str("first model name", modelNames[0]).Msg("context local model name not found, setting to the first model")
if input.ResponseFormat != nil { return ctx.Next()
switch responseFormat := input.ResponseFormat.(type) { }
case string: }
config.ResponseFormat = responseFormat
case map[string]interface{}: // TODO: If context and cancel above belong on all methods, move that part of above into here!
config.ResponseFormatMap = responseFormat // Otherwise, it's in its own method below for now
} func (re *RequestExtractor) SetModelAndConfig(initializer func() schema.LocalAIRequest) fiber.Handler {
} return func(ctx *fiber.Ctx) error {
input := initializer()
switch stop := input.Stop.(type) { if input == nil {
case string: return fmt.Errorf("unable to initialize body")
if stop != "" { }
config.StopWords = append(config.StopWords, stop) if err := ctx.BodyParser(input); err != nil {
} return fmt.Errorf("failed parsing request body: %w", err)
case []interface{}: }
for _, pp := range stop {
if s, ok := pp.(string); ok { // If this request doesn't have an associated model name, fetch it from earlier in the middleware chain
config.StopWords = append(config.StopWords, s) if input.ModelName(nil) == "" {
} localModelName, ok := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
} if ok && localModelName != "" {
} log.Debug().Str("context localModelName", localModelName).Msg("overriding empty model name in request body with value found earlier in middleware chain")
input.ModelName(&localModelName)
if len(input.Tools) > 0 { }
for _, tool := range input.Tools { }
input.Functions = append(input.Functions, tool.Function)
} cfg, err := re.backendConfigLoader.LoadBackendConfigFileByNameDefaultOptions(input.ModelName(nil), re.applicationConfig)
}
if err != nil {
if input.ToolsChoice != nil { log.Err(err)
var toolChoice functions.Tool log.Warn().Msgf("Model Configuration File not found for %q", input.ModelName(nil))
} else if cfg.Model == "" && input.ModelName(nil) != "" {
switch content := input.ToolsChoice.(type) { log.Debug().Str("input.ModelName", input.ModelName(nil)).Msg("config does not include model, using input")
case string: cfg.Model = input.ModelName(nil)
_ = json.Unmarshal([]byte(content), &toolChoice) }
case map[string]interface{}:
dat, _ := json.Marshal(content) ctx.Locals(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input)
_ = json.Unmarshal(dat, &toolChoice) ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg)
}
input.FunctionCall = map[string]interface{}{ return ctx.Next()
"name": toolChoice.Function.Name, }
} }
}
func (re *RequestExtractor) SetOpenAIRequest(ctx *fiber.Ctx) error {
// Decode each request's message content input, ok := ctx.Locals(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
imgIndex, vidIndex, audioIndex := 0, 0, 0 if !ok || input.Model == "" {
for i, m := range input.Messages { return fiber.ErrBadRequest
nrOfImgsInMessage := 0 }
nrOfVideosInMessage := 0
nrOfAudiosInMessage := 0 cfg, ok := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
if !ok || cfg == nil {
switch content := m.Content.(type) { return fiber.ErrBadRequest
case string: }
input.Messages[i].StringContent = content
case []interface{}: // Extract or generate the correlation ID
dat, _ := json.Marshal(content) correlationID := ctx.Get("X-Correlation-ID", uuid.New().String())
c := []schema.Content{} ctx.Set("X-Correlation-ID", correlationID)
json.Unmarshal(dat, &c)
c1, cancel := context.WithCancel(re.applicationConfig.Context)
textContent := "" // Add the correlation ID to the new context
// we will template this at the end ctxWithCorrelationID := context.WithValue(c1, CorrelationIDKey, correlationID)
CONTENT: input.Context = ctxWithCorrelationID
for _, pp := range c { input.Cancel = cancel
switch pp.Type {
case "text": err := mergeOpenAIRequestAndBackendConfig(cfg, input)
textContent += pp.Text if err != nil {
//input.Messages[i].StringContent = pp.Text return err
case "video", "video_url": }
// Decode content as base64 either if it's an URL or base64 text
base64, err := utils.GetContentURIAsBase64(pp.VideoURL.URL) if cfg.Model == "" {
if err != nil { log.Debug().Str("input.Model", input.Model).Msg("replacing empty cfg.Model with input value")
log.Error().Msgf("Failed encoding video: %s", err) cfg.Model = input.Model
continue CONTENT }
}
input.Messages[i].StringVideos = append(input.Messages[i].StringVideos, base64) // TODO: make sure that we only return base64 stuff ctx.Locals(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input)
vidIndex++ ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg)
nrOfVideosInMessage++
case "audio_url", "audio": return ctx.Next()
// Decode content as base64 either if it's an URL or base64 text }
base64, err := utils.GetContentURIAsBase64(pp.AudioURL.URL)
if err != nil { func mergeOpenAIRequestAndBackendConfig(config *config.BackendConfig, input *schema.OpenAIRequest) error {
log.Error().Msgf("Failed encoding image: %s", err) if input.Echo {
continue CONTENT config.Echo = input.Echo
} }
input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, base64) // TODO: make sure that we only return base64 stuff if input.TopK != nil {
audioIndex++ config.TopK = input.TopK
nrOfAudiosInMessage++ }
case "image_url", "image": if input.TopP != nil {
// Decode content as base64 either if it's an URL or base64 text config.TopP = input.TopP
base64, err := utils.GetContentURIAsBase64(pp.ImageURL.URL) }
if err != nil {
log.Error().Msgf("Failed encoding image: %s", err) if input.Backend != "" {
continue CONTENT config.Backend = input.Backend
} }
input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff if input.ClipSkip != 0 {
config.Diffusers.ClipSkip = input.ClipSkip
imgIndex++ }
nrOfImgsInMessage++
} if input.ModelBaseName != "" {
} config.AutoGPTQ.ModelBaseName = input.ModelBaseName
}
input.Messages[i].StringContent, _ = templates.TemplateMultiModal(config.TemplateConfig.Multimodal, templates.MultiModalOptions{
TotalImages: imgIndex, if input.NegativePromptScale != 0 {
TotalVideos: vidIndex, config.NegativePromptScale = input.NegativePromptScale
TotalAudios: audioIndex, }
ImagesInMessage: nrOfImgsInMessage,
VideosInMessage: nrOfVideosInMessage, if input.UseFastTokenizer {
AudiosInMessage: nrOfAudiosInMessage, config.UseFastTokenizer = input.UseFastTokenizer
}, textContent) }
}
} if input.NegativePrompt != "" {
config.NegativePrompt = input.NegativePrompt
if input.RepeatPenalty != 0 { }
config.RepeatPenalty = input.RepeatPenalty
} if input.RopeFreqBase != 0 {
config.RopeFreqBase = input.RopeFreqBase
if input.FrequencyPenalty != 0 { }
config.FrequencyPenalty = input.FrequencyPenalty
} if input.RopeFreqScale != 0 {
config.RopeFreqScale = input.RopeFreqScale
if input.PresencePenalty != 0 { }
config.PresencePenalty = input.PresencePenalty
} if input.Grammar != "" {
config.Grammar = input.Grammar
if input.Keep != 0 { }
config.Keep = input.Keep
} if input.Temperature != nil {
config.Temperature = input.Temperature
if input.Batch != 0 { }
config.Batch = input.Batch
} if input.Maxtokens != nil {
config.Maxtokens = input.Maxtokens
if input.IgnoreEOS { }
config.IgnoreEOS = input.IgnoreEOS
} if input.ResponseFormat != nil {
switch responseFormat := input.ResponseFormat.(type) {
if input.Seed != nil { case string:
config.Seed = input.Seed config.ResponseFormat = responseFormat
} case map[string]interface{}:
config.ResponseFormatMap = responseFormat
if input.TypicalP != nil { }
config.TypicalP = input.TypicalP }
}
switch stop := input.Stop.(type) {
switch inputs := input.Input.(type) { case string:
case string: if stop != "" {
if inputs != "" { config.StopWords = append(config.StopWords, stop)
config.InputStrings = append(config.InputStrings, inputs) }
} case []interface{}:
case []interface{}: for _, pp := range stop {
for _, pp := range inputs { if s, ok := pp.(string); ok {
switch i := pp.(type) { config.StopWords = append(config.StopWords, s)
case string: }
config.InputStrings = append(config.InputStrings, i) }
case []interface{}: }
tokens := []int{}
for _, ii := range i { if len(input.Tools) > 0 {
tokens = append(tokens, int(ii.(float64))) for _, tool := range input.Tools {
} input.Functions = append(input.Functions, tool.Function)
config.InputToken = append(config.InputToken, tokens) }
} }
}
} if input.ToolsChoice != nil {
var toolChoice functions.Tool
// Can be either a string or an object
switch fnc := input.FunctionCall.(type) { switch content := input.ToolsChoice.(type) {
case string: case string:
if fnc != "" { _ = json.Unmarshal([]byte(content), &toolChoice)
config.SetFunctionCallString(fnc) case map[string]interface{}:
} dat, _ := json.Marshal(content)
case map[string]interface{}: _ = json.Unmarshal(dat, &toolChoice)
var name string }
n, exists := fnc["name"] input.FunctionCall = map[string]interface{}{
if exists { "name": toolChoice.Function.Name,
nn, e := n.(string) }
if e { }
name = nn
} // Decode each request's message content
} imgIndex, vidIndex, audioIndex := 0, 0, 0
config.SetFunctionCallNameString(name) for i, m := range input.Messages {
} nrOfImgsInMessage := 0
nrOfVideosInMessage := 0
switch p := input.Prompt.(type) { nrOfAudiosInMessage := 0
case string:
config.PromptStrings = append(config.PromptStrings, p) switch content := m.Content.(type) {
case []interface{}: case string:
for _, pp := range p { input.Messages[i].StringContent = content
if s, ok := pp.(string); ok { case []interface{}:
config.PromptStrings = append(config.PromptStrings, s) dat, _ := json.Marshal(content)
} c := []schema.Content{}
} json.Unmarshal(dat, &c)
}
textContent := ""
// If a quality was defined as number, convert it to step // we will template this at the end
if input.Quality != "" {
q, err := strconv.Atoi(input.Quality) CONTENT:
if err == nil { for _, pp := range c {
config.Step = q switch pp.Type {
} case "text":
} textContent += pp.Text
} //input.Messages[i].StringContent = pp.Text
case "video", "video_url":
func mergeRequestWithConfig(modelFile string, input *schema.OpenAIRequest, cm *config.BackendConfigLoader, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*config.BackendConfig, *schema.OpenAIRequest, error) { // Decode content as base64 either if it's an URL or base64 text
cfg, err := cm.LoadBackendConfigFileByName(modelFile, loader.ModelPath, base64, err := utils.GetContentURIAsBase64(pp.VideoURL.URL)
config.LoadOptionDebug(debug), if err != nil {
config.LoadOptionThreads(threads), log.Error().Msgf("Failed encoding video: %s", err)
config.LoadOptionContextSize(ctx), continue CONTENT
config.LoadOptionF16(f16), }
) input.Messages[i].StringVideos = append(input.Messages[i].StringVideos, base64) // TODO: make sure that we only return base64 stuff
vidIndex++
// Set the parameters for the language model prediction nrOfVideosInMessage++
updateRequestConfig(cfg, input) case "audio_url", "audio":
// Decode content as base64 either if it's an URL or base64 text
if !cfg.Validate() { base64, err := utils.GetContentURIAsBase64(pp.AudioURL.URL)
return nil, nil, fmt.Errorf("failed to validate config") if err != nil {
} log.Error().Msgf("Failed encoding image: %s", err)
continue CONTENT
return cfg, input, err }
} input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, base64) // TODO: make sure that we only return base64 stuff
audioIndex++
nrOfAudiosInMessage++
case "image_url", "image":
// Decode content as base64 either if it's an URL or base64 text
base64, err := utils.GetContentURIAsBase64(pp.ImageURL.URL)
if err != nil {
log.Error().Msgf("Failed encoding image: %s", err)
continue CONTENT
}
input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff
imgIndex++
nrOfImgsInMessage++
}
}
input.Messages[i].StringContent, _ = templates.TemplateMultiModal(config.TemplateConfig.Multimodal, templates.MultiModalOptions{
TotalImages: imgIndex,
TotalVideos: vidIndex,
TotalAudios: audioIndex,
ImagesInMessage: nrOfImgsInMessage,
VideosInMessage: nrOfVideosInMessage,
AudiosInMessage: nrOfAudiosInMessage,
}, textContent)
}
}
if input.RepeatPenalty != 0 {
config.RepeatPenalty = input.RepeatPenalty
}
if input.FrequencyPenalty != 0 {
config.FrequencyPenalty = input.FrequencyPenalty
}
if input.PresencePenalty != 0 {
config.PresencePenalty = input.PresencePenalty
}
if input.Keep != 0 {
config.Keep = input.Keep
}
if input.Batch != 0 {
config.Batch = input.Batch
}
if input.IgnoreEOS {
config.IgnoreEOS = input.IgnoreEOS
}
if input.Seed != nil {
config.Seed = input.Seed
}
if input.TypicalP != nil {
config.TypicalP = input.TypicalP
}
log.Debug().Str("input.Input", fmt.Sprintf("%+v", input.Input))
switch inputs := input.Input.(type) {
case string:
if inputs != "" {
config.InputStrings = append(config.InputStrings, inputs)
}
case []interface{}:
for _, pp := range inputs {
switch i := pp.(type) {
case string:
config.InputStrings = append(config.InputStrings, i)
case []interface{}:
tokens := []int{}
for _, ii := range i {
tokens = append(tokens, int(ii.(float64)))
}
config.InputToken = append(config.InputToken, tokens)
}
}
}
// Can be either a string or an object
switch fnc := input.FunctionCall.(type) {
case string:
if fnc != "" {
config.SetFunctionCallString(fnc)
}
case map[string]interface{}:
var name string
n, exists := fnc["name"]
if exists {
nn, e := n.(string)
if e {
name = nn
}
}
config.SetFunctionCallNameString(name)
}
switch p := input.Prompt.(type) {
case string:
config.PromptStrings = append(config.PromptStrings, p)
case []interface{}:
for _, pp := range p {
if s, ok := pp.(string); ok {
config.PromptStrings = append(config.PromptStrings, s)
}
}
}
// If a quality was defined as number, convert it to step
if input.Quality != "" {
q, err := strconv.Atoi(input.Quality)
if err == nil {
config.Step = q
}
}
if config.Validate() {
return nil
}
return fmt.Errorf("unable to validate configuration after merging")
}
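For orientation, here is a minimal sketch of how the new RequestExtractor is meant to be composed on a route. The handler names, filter flags, and endpoint constructors are taken from this commit; the surrounding variables (app, cl, ml, appConfig, application) stand in for the usual fiber app, loaders, and application object and are illustrative only.

// Illustrative wiring only; the real registrations live in core/http/routes/*.go below.
re := middleware.NewRequestExtractor(cl, ml, appConfig)
chatChain := []fiber.Handler{
	// pick a default model that supports chat when the path, query, body, or bearer token named none
	re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_CHAT)),
	// parse the body into the request type and stash the request plus its BackendConfig in the fiber locals
	re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
	// OpenAI-specific merge: correlation ID, request context/cancel, sampling parameters
	re.SetOpenAIRequest,
	openai.ChatEndpoint(application.BackendLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig()),
}
app.Post("/v1/chat/completions", chatChain...)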


@@ -4,17 +4,26 @@ import (
 	"github.com/gofiber/fiber/v2"
 	"github.com/mudler/LocalAI/core/config"
 	"github.com/mudler/LocalAI/core/http/endpoints/elevenlabs"
+	"github.com/mudler/LocalAI/core/http/middleware"
+	"github.com/mudler/LocalAI/core/schema"
 	"github.com/mudler/LocalAI/pkg/model"
 )
 func RegisterElevenLabsRoutes(app *fiber.App,
+	re *middleware.RequestExtractor,
 	cl *config.BackendConfigLoader,
 	ml *model.ModelLoader,
 	appConfig *config.ApplicationConfig) {
 	// Elevenlabs
-	app.Post("/v1/text-to-speech/:voice-id", elevenlabs.TTSEndpoint(cl, ml, appConfig))
+	app.Post("/v1/text-to-speech/:voice-id",
+		re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TTS)),
+		re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.ElevenLabsTTSRequest) }),
+		elevenlabs.TTSEndpoint(cl, ml, appConfig))
-	app.Post("/v1/sound-generation", elevenlabs.SoundGenerationEndpoint(cl, ml, appConfig))
+	app.Post("/v1/sound-generation",
+		re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_SOUND_GENERATION)),
+		re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.ElevenLabsSoundGenerationRequest) }),
+		elevenlabs.SoundGenerationEndpoint(cl, ml, appConfig))
 }


@@ -3,16 +3,22 @@ package routes
 import (
 	"github.com/mudler/LocalAI/core/config"
 	"github.com/mudler/LocalAI/core/http/endpoints/jina"
+	"github.com/mudler/LocalAI/core/http/middleware"
+	"github.com/mudler/LocalAI/core/schema"
 	"github.com/gofiber/fiber/v2"
 	"github.com/mudler/LocalAI/pkg/model"
 )
 func RegisterJINARoutes(app *fiber.App,
+	re *middleware.RequestExtractor,
 	cl *config.BackendConfigLoader,
 	ml *model.ModelLoader,
 	appConfig *config.ApplicationConfig) {
 	// POST endpoint to mimic the reranking
-	app.Post("/v1/rerank", jina.JINARerankEndpoint(cl, ml, appConfig))
+	app.Post("/v1/rerank",
+		re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_RERANK)),
+		re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.JINARerankRequest) }),
+		jina.JINARerankEndpoint(cl, ml, appConfig))
 }


@@ -5,13 +5,16 @@ import (
 	"github.com/gofiber/swagger"
 	"github.com/mudler/LocalAI/core/config"
 	"github.com/mudler/LocalAI/core/http/endpoints/localai"
+	"github.com/mudler/LocalAI/core/http/middleware"
 	"github.com/mudler/LocalAI/core/p2p"
+	"github.com/mudler/LocalAI/core/schema"
 	"github.com/mudler/LocalAI/core/services"
 	"github.com/mudler/LocalAI/internal"
 	"github.com/mudler/LocalAI/pkg/model"
 )
 func RegisterLocalAIRoutes(router *fiber.App,
+	requestExtractor *middleware.RequestExtractor,
 	cl *config.BackendConfigLoader,
 	ml *model.ModelLoader,
 	appConfig *config.ApplicationConfig,
@@ -33,8 +36,18 @@ func RegisterLocalAIRoutes(router *fiber.App,
 		router.Get("/models/jobs", modelGalleryEndpointService.GetAllStatusEndpoint())
 	}
-	router.Post("/tts", localai.TTSEndpoint(cl, ml, appConfig))
-	router.Post("/vad", localai.VADEndpoint(cl, ml, appConfig))
+	router.Post("/tts",
+		requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TTS)),
+		requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.TTSRequest) }),
+		localai.TTSEndpoint(cl, ml, appConfig))
+	vadChain := []fiber.Handler{
+		requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_VAD)),
+		requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.VADRequest) }),
+		localai.VADEndpoint(cl, ml, appConfig),
+	}
+	router.Post("/vad", vadChain...)
+	router.Post("/v1/vad", vadChain...)
 	// Stores
 	sl := model.NewModelLoader("")
@@ -47,10 +60,14 @@ func RegisterLocalAIRoutes(router *fiber.App,
 		router.Get("/metrics", localai.LocalAIMetricsEndpoint())
 	}
-	// Experimental Backend Statistics Module
+	// Backend Statistics Module
+	// TODO: Should these use standard middlewares? Refactor later, they are extremely simple.
 	backendMonitorService := services.NewBackendMonitorService(ml, cl, appConfig) // Split out for now
 	router.Get("/backend/monitor", localai.BackendMonitorEndpoint(backendMonitorService))
 	router.Post("/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitorService))
+	// The v1/* urls are exactly the same as above - makes local e2e testing easier if they are registered.
+	router.Get("/v1/backend/monitor", localai.BackendMonitorEndpoint(backendMonitorService))
+	router.Post("/v1/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitorService))
 	// p2p
 	if p2p.IsP2PEnabled() {
@@ -67,6 +84,9 @@ func RegisterLocalAIRoutes(router *fiber.App,
 	router.Get("/system", localai.SystemInformations(ml, appConfig))
 	// misc
-	router.Post("/v1/tokenize", localai.TokenizeEndpoint(cl, ml, appConfig))
+	router.Post("/v1/tokenize",
+		requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TOKENIZE)),
+		requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.TokenizeRequest) }),
+		localai.TokenizeEndpoint(cl, ml, appConfig))
 }


@@ -3,51 +3,50 @@ package routes
 import (
 	"github.com/gofiber/fiber/v2"
 	"github.com/mudler/LocalAI/core/application"
+	"github.com/mudler/LocalAI/core/config"
 	"github.com/mudler/LocalAI/core/http/endpoints/localai"
 	"github.com/mudler/LocalAI/core/http/endpoints/openai"
+	"github.com/mudler/LocalAI/core/http/middleware"
+	"github.com/mudler/LocalAI/core/schema"
 )
 func RegisterOpenAIRoutes(app *fiber.App,
+	re *middleware.RequestExtractor,
 	application *application.Application) {
 	// openAI compatible API endpoint
 	// chat
-	app.Post("/v1/chat/completions",
-		openai.ChatEndpoint(
-			application.BackendLoader(),
-			application.ModelLoader(),
-			application.TemplatesEvaluator(),
-			application.ApplicationConfig(),
-		),
-	)
-	app.Post("/chat/completions",
-		openai.ChatEndpoint(
-			application.BackendLoader(),
-			application.ModelLoader(),
-			application.TemplatesEvaluator(),
-			application.ApplicationConfig(),
-		),
-	)
+	chatChain := []fiber.Handler{
+		re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_CHAT)),
+		re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
+		re.SetOpenAIRequest,
+		openai.ChatEndpoint(application.BackendLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig()),
+	}
+	app.Post("/v1/chat/completions", chatChain...)
+	app.Post("/chat/completions", chatChain...)
 	// edit
-	app.Post("/v1/edits",
-		openai.EditEndpoint(
-			application.BackendLoader(),
-			application.ModelLoader(),
-			application.TemplatesEvaluator(),
-			application.ApplicationConfig(),
-		),
-	)
-	app.Post("/edits",
-		openai.EditEndpoint(
-			application.BackendLoader(),
-			application.ModelLoader(),
-			application.TemplatesEvaluator(),
-			application.ApplicationConfig(),
-		),
-	)
+	editChain := []fiber.Handler{
+		re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_EDIT)),
+		re.BuildConstantDefaultModelNameMiddleware("gpt-4o"),
+		re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
+		re.SetOpenAIRequest,
+		openai.EditEndpoint(application.BackendLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig()),
+	}
+	app.Post("/v1/edits", editChain...)
+	app.Post("/edits", editChain...)
+	// completion
+	completionChain := []fiber.Handler{
+		re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_COMPLETION)),
+		re.BuildConstantDefaultModelNameMiddleware("gpt-4o"),
+		re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
+		re.SetOpenAIRequest,
+		openai.CompletionEndpoint(application.BackendLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig()),
+	}
+	app.Post("/v1/completions", completionChain...)
+	app.Post("/completions", completionChain...)
+	app.Post("/v1/engines/:model/completions", completionChain...)
 	// assistant
 	app.Get("/v1/assistants", openai.ListAssistantsEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
@@ -81,45 +80,37 @@ func RegisterOpenAIRoutes(app *fiber.App,
 	app.Get("/v1/files/:file_id/content", openai.GetFilesContentsEndpoint(application.BackendLoader(), application.ApplicationConfig()))
 	app.Get("/files/:file_id/content", openai.GetFilesContentsEndpoint(application.BackendLoader(), application.ApplicationConfig()))
-	// completion
-	app.Post("/v1/completions",
-		openai.CompletionEndpoint(
-			application.BackendLoader(),
-			application.ModelLoader(),
-			application.TemplatesEvaluator(),
-			application.ApplicationConfig(),
-		),
-	)
-	app.Post("/completions",
-		openai.CompletionEndpoint(
-			application.BackendLoader(),
-			application.ModelLoader(),
-			application.TemplatesEvaluator(),
-			application.ApplicationConfig(),
-		),
-	)
-	app.Post("/v1/engines/:model/completions",
-		openai.CompletionEndpoint(
-			application.BackendLoader(),
-			application.ModelLoader(),
-			application.TemplatesEvaluator(),
-			application.ApplicationConfig(),
-		),
-	)
 	// embeddings
-	app.Post("/v1/embeddings", openai.EmbeddingsEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
-	app.Post("/embeddings", openai.EmbeddingsEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
-	app.Post("/v1/engines/:model/embeddings", openai.EmbeddingsEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
+	embeddingChain := []fiber.Handler{
+		re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_EMBEDDINGS)),
+		re.BuildConstantDefaultModelNameMiddleware("gpt-4o"),
+		re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
+		re.SetOpenAIRequest,
+		openai.EmbeddingsEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()),
+	}
+	app.Post("/v1/embeddings", embeddingChain...)
+	app.Post("/embeddings", embeddingChain...)
+	app.Post("/v1/engines/:model/embeddings", embeddingChain...)
 	// audio
-	app.Post("/v1/audio/transcriptions", openai.TranscriptEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
-	app.Post("/v1/audio/speech", localai.TTSEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
+	app.Post("/v1/audio/transcriptions",
+		re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TRANSCRIPT)),
+		re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
+		re.SetOpenAIRequest,
+		openai.TranscriptEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()),
+	)
+	app.Post("/v1/audio/speech",
+		re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TTS)),
+		re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.TTSRequest) }),
+		localai.TTSEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
 	// images
-	app.Post("/v1/images/generations", openai.ImageEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
+	app.Post("/v1/images/generations",
+		re.BuildConstantDefaultModelNameMiddleware("stablediffusion"),
+		re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
+		re.SetOpenAIRequest,
+		openai.ImageEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
 	if application.ApplicationConfig().ImageDir != "" {
 		app.Static("/generated-images", application.ApplicationConfig().ImageDir)


@@ -1,8 +1,9 @@
 package schema
 type ElevenLabsTTSRequest struct {
 	Text    string `json:"text" yaml:"text"`
 	ModelID string `json:"model_id" yaml:"model_id"`
+	LanguageCode string `json:"language_code" yaml:"language_code"`
 }
 type ElevenLabsSoundGenerationRequest struct {
@@ -12,3 +13,17 @@ type ElevenLabsSoundGenerationRequest struct {
 	Temperature *float32 `json:"prompt_influence,omitempty" yaml:"prompt_influence,omitempty"`
 	DoSample    *bool    `json:"do_sample,omitempty" yaml:"do_sample,omitempty"`
 }
+func (elttsr *ElevenLabsTTSRequest) ModelName(s *string) string {
+	if s != nil {
+		elttsr.ModelID = *s
+	}
+	return elttsr.ModelID
+}
+func (elsgr *ElevenLabsSoundGenerationRequest) ModelName(s *string) string {
+	if s != nil {
+		elsgr.ModelID = *s
+	}
+	return elsgr.ModelID
+}


@@ -2,10 +2,11 @@ package schema
 // RerankRequest defines the structure of the request payload
 type JINARerankRequest struct {
-	Model string `json:"model"`
+	BasicModelRequest
 	Query     string   `json:"query"`
 	Documents []string `json:"documents"`
 	TopN      int      `json:"top_n"`
+	Backend   string   `json:"backend"`
 }
 // DocumentResult represents a single document result


@@ -6,11 +6,11 @@ import (
 )
 type BackendMonitorRequest struct {
-	Model string `json:"model" yaml:"model"`
+	BasicModelRequest
 }
 type TokenMetricsRequest struct {
-	Model string `json:"model" yaml:"model"`
+	BasicModelRequest
 }
 type BackendMonitorResponse struct {
@@ -26,7 +26,7 @@ type GalleryResponse struct {
 // @Description TTS request body
 type TTSRequest struct {
-	Model   string `json:"model" yaml:"model"` // model name or full path
+	BasicModelRequest
 	Input   string `json:"input" yaml:"input"` // text input
 	Voice   string `json:"voice" yaml:"voice"` // voice audio file or speaker id
 	Backend string `json:"backend" yaml:"backend"`
@@ -36,10 +36,19 @@ type TTSRequest struct {
 // @Description VAD request body
 type VADRequest struct {
-	Model string `json:"model" yaml:"model"` // model name or full path
+	BasicModelRequest
 	Audio []float32 `json:"audio" yaml:"audio"` // model name or full path
 }
+type VADSegment struct {
+	Start float32 `json:"start" yaml:"start"`
+	End   float32 `json:"end" yaml:"end"`
+}
+type VADResponse struct {
+	Segments []VADSegment `json:"segments" yaml:"segments"`
+}
 type StoresSet struct {
 	Store string `json:"store,omitempty" yaml:"store,omitempty"`
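To make the new VAD types concrete, here is a small client-side sketch of the request and response shapes for the /vad endpoint registered in this commit. The model name matches the AIO silero-vad configuration; the sample values and the apiEndpoint variable are illustrative only.

// Illustrative only: build a VAD request, post it, and decode the detected segments.
req := schema.VADRequest{
	BasicModelRequest: schema.BasicModelRequest{Model: "silero-vad"},
	Audio:             []float32{0.0, 0.01, -0.02}, // raw audio samples as float32
}
payload, _ := json.Marshal(req)
resp, err := http.Post(apiEndpoint+"/vad", "application/json", bytes.NewReader(payload))
if err == nil {
	defer resp.Body.Close()
	var out schema.VADResponse
	_ = json.NewDecoder(resp.Body).Decode(&out)
	// out.Segments carries the detected speech segments with Start/End offsets
}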


@@ -3,7 +3,7 @@ package schema
 type PredictionOptions struct {
 	// Also part of the OpenAI official spec
-	Model string `json:"model" yaml:"model"`
+	BasicModelRequest `yaml:",inline"`
 	// Also part of the OpenAI official spec
 	Language string `json:"language"`

core/schema/request.go Normal file

@ -0,0 +1,22 @@
package schema
// This file and type represent a generic request to LocalAI - as opposed to requests to LocalAI-specific endpoints, which live in localai.go
type LocalAIRequest interface {
ModelName(*string) string
}
type BasicModelRequest struct {
Model string `json:"model" yaml:"model"`
// TODO: Should this also include the following fields from the OpenAI side of the world?
// If so, changes should be made to core/http/middleware/request.go to match
// Context context.Context `json:"-"`
// Cancel context.CancelFunc `json:"-"`
}
func (bmr *BasicModelRequest) ModelName(s *string) string {
if s != nil {
bmr.Model = *s
}
return bmr.Model
}
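A short sketch of how this is meant to be used: ModelName doubles as a getter (nil argument) and a setter (non-nil argument), which is what the new SetModelAndConfig middleware relies on. The FooRequest type below is hypothetical; any request struct that embeds BasicModelRequest satisfies LocalAIRequest.

// Hypothetical endpoint-specific request type; embedding BasicModelRequest
// is enough to implement the LocalAIRequest interface.
type FooRequest struct {
	BasicModelRequest
	Prompt string `json:"prompt"`
}

func exampleUsage() {
	var req LocalAIRequest = &FooRequest{}
	name := "silero-vad"
	req.ModelName(&name)   // a non-nil pointer sets the model name
	_ = req.ModelName(nil) // nil just reads it back: "silero-vad"
}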


@@ -1,8 +1,8 @@
 package schema
 type TokenizeRequest struct {
+	BasicModelRequest
 	Content string `json:"content"`
-	Model   string `json:"model"`
 }
 type TokenizeResponse struct {


@@ -2,7 +2,7 @@ package schema
 import "time"
-type Segment struct {
+type TranscriptionSegment struct {
 	Id    int           `json:"id"`
 	Start time.Duration `json:"start"`
 	End   time.Duration `json:"end"`
@@ -11,6 +11,6 @@ type Segment struct {
 }
 type TranscriptionResult struct {
-	Segments []Segment `json:"segments"`
+	Segments []TranscriptionSegment `json:"segments"`
 	Text     string    `json:"text"`
 }


@@ -49,3 +49,15 @@ func ListModels(bcl *config.BackendConfigLoader, ml *model.ModelLoader, filter c
 	return dataModels, nil
 }
+
+func CheckIfModelExists(bcl *config.BackendConfigLoader, ml *model.ModelLoader, modelName string, looseFilePolicy LooseFilePolicy) (bool, error) {
+	filter, err := config.BuildNameFilterFn(modelName)
+	if err != nil {
+		return false, err
+	}
+	models, err := ListModels(bcl, ml, filter, looseFilePolicy)
+	if err != nil {
+		return false, err
+	}
+	return (len(models) > 0), nil
+}
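A hedged usage sketch: this is essentially how the new request middleware treats a bearer token as a candidate model name. The cl and ml variables stand for the usual config and model loaders, and the model name is illustrative.

exists, err := services.CheckIfModelExists(cl, ml, "silero-vad", services.ALWAYS_INCLUDE)
if err == nil && exists {
	// safe to route the request to that model
}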

go.mod

@@ -239,7 +239,7 @@ require (
 	github.com/moby/sys/sequential v0.5.0 // indirect
 	github.com/moby/term v0.5.0 // indirect
 	github.com/mr-tron/base58 v1.2.0 // indirect
-	github.com/mudler/go-piper v0.0.0-20241022074816-3854e0221ffb
+	github.com/mudler/go-piper v0.0.0-20241023091659-2494246fd9fc
 	github.com/mudler/water v0.0.0-20221010214108-8c7313014ce0 // indirect
 	github.com/muesli/reflow v0.3.0 // indirect
 	github.com/muesli/termenv v0.15.2 // indirect

go.sum

@@ -524,6 +524,8 @@ github.com/mudler/edgevpn v0.29.0 h1:SEkVyjXL6P8szUZFlL8W1EYBxvFsEIFvXlXcRfGrXYU
 github.com/mudler/edgevpn v0.29.0/go.mod h1:+kSy9b44eo97PnJ3fOnTkcTgxNXdgJBcd2bopx4leto=
 github.com/mudler/go-piper v0.0.0-20241022074816-3854e0221ffb h1:5qcuxQEpAqeV4ftV5nUt3/hB/RoTXq3MaaauOAedyXo=
 github.com/mudler/go-piper v0.0.0-20241022074816-3854e0221ffb/go.mod h1:O7SwdSWMilAWhBZMK9N9Y/oBDyMMzshE3ju8Xkexwig=
+github.com/mudler/go-piper v0.0.0-20241023091659-2494246fd9fc h1:RxwneJl1VgvikiX28EkpdAyL4yQVnJMrbquKospjHyA=
+github.com/mudler/go-piper v0.0.0-20241023091659-2494246fd9fc/go.mod h1:O7SwdSWMilAWhBZMK9N9Y/oBDyMMzshE3ju8Xkexwig=
 github.com/mudler/go-processmanager v0.0.0-20240820160718-8b802d3ecf82 h1:FVT07EI8njvsD4tC2Hw8Xhactp5AWhsQWD4oTeQuSAU=
 github.com/mudler/go-processmanager v0.0.0-20240820160718-8b802d3ecf82/go.mod h1:Urp7LG5jylKoDq0663qeBh0pINGcRl35nXdKx82PSoU=
 github.com/mudler/go-stable-diffusion v0.0.0-20240429204715-4a3cd6aeae6f h1:cxtMSRkUfy+mjIQ3yMrU0txwQ4It913NEN4m1H8WWgo=


@@ -460,7 +460,7 @@ func (ml *ModelLoader) ListAvailableBackends(assetdir string) ([]string, error)
 func (ml *ModelLoader) backendLoader(opts ...Option) (client grpc.Backend, err error) {
 	o := NewOptions(opts...)
-	log.Info().Msgf("Loading model '%s' with backend %s", o.modelID, o.backendString)
+	log.Info().Str("modelID", o.modelID).Str("backend", o.backendString).Str("o.model", o.model).Msg("BackendLoader starting")
 	backend := strings.ToLower(o.backendString)
 	if realBackend, exists := Aliases[backend]; exists {


@@ -56,6 +56,14 @@ func WithBackendString(backend string) Option {
 	}
 }
+
+func WithDefaultBackendString(backend string) Option {
+	return func(o *Options) {
+		if o.backendString == "" {
+			o.backendString = backend
+		}
+	}
+}
+
 func WithModel(modelFile string) Option {
 	return func(o *Options) {
 		o.model = modelFile
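A minimal sketch of how the new option differs from WithBackendString: the default only fills the backend when nothing else has set one. Option names are from this commit; the values are illustrative.

opts := []model.Option{
	model.WithModel("silero-vad.onnx"),
	model.WithDefaultBackendString("silero-vad"), // an explicit WithBackendString wins whether it appears before or after
}
_ = opts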


@@ -43,7 +43,7 @@ var _ = BeforeSuite(func() {
 		apiEndpoint = "http://localhost:" + apiPort + "/v1" // So that other tests can reference this value safely.
 		defaultConfig.BaseURL = apiEndpoint
 	} else {
-		fmt.Println("Default ", apiEndpoint)
+		GinkgoWriter.Printf("docker apiEndpoint set from env: %q\n", apiEndpoint)
 		defaultConfig = openai.DefaultConfig(apiKey)
 		defaultConfig.BaseURL = apiEndpoint
 	}
@@ -95,10 +95,11 @@ func startDockerImage() {
 		PortBindings: map[docker.Port][]docker.PortBinding{
 			"8080/tcp": []docker.PortBinding{{HostPort: apiPort}},
 		},
-		Env:    []string{"MODELS_PATH=/models", "DEBUG=true", "THREADS=" + fmt.Sprint(proc)},
+		Env:    []string{"MODELS_PATH=/models", "DEBUG=true", "THREADS=" + fmt.Sprint(proc), "LOCALAI_SINGLE_ACTIVE_BACKEND=true"},
 		Mounts: []string{md + ":/models"},
 	}
+	GinkgoWriter.Printf("Launching Docker Container %q\n%+v\n", containerImageTag, options)
 	r, err := pool.RunWithOptions(options)
 	Expect(err).To(Not(HaveOccurred()))


@@ -121,14 +121,13 @@ var _ = Describe("E2E test", func() {
 	Context("images", func() {
 		It("correctly", func() {
-			resp, err := client.CreateImage(context.TODO(),
-				openai.ImageRequest{
-					Prompt:  "test",
-					Quality: "1",
-					Size:    openai.CreateImageSize256x256,
-				},
-			)
-			Expect(err).ToNot(HaveOccurred())
+			req := openai.ImageRequest{
+				Prompt:  "test",
+				Quality: "1",
+				Size:    openai.CreateImageSize512x512,
+			}
+			resp, err := client.CreateImage(context.TODO(), req)
+			Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("error sending image request %+v", req))
 			Expect(len(resp.Data)).To(Equal(1), fmt.Sprint(resp))
 			Expect(resp.Data[0].URL).To(ContainSubstring("png"), fmt.Sprint(resp.Data[0].URL))
 		})
@@ -232,13 +231,42 @@ var _ = Describe("E2E test", func() {
 			Expect(resp.Text).To(ContainSubstring("This is the"), fmt.Sprint(resp.Text))
 		})
 	})
+	Context("vad", func() {
+		It("correctly", func() {
+			modelName := "silero-vad"
+			req := schema.VADRequest{
+				BasicModelRequest: schema.BasicModelRequest{
+					Model: modelName,
+				},
+				Audio: SampleVADAudio, // Use hardcoded sample data for now.
+			}
+			serialized, err := json.Marshal(req)
+			Expect(err).To(BeNil())
+			Expect(serialized).ToNot(BeNil())
+			vadEndpoint := apiEndpoint + "/vad"
+			resp, err := http.Post(vadEndpoint, "application/json", bytes.NewReader(serialized))
+			Expect(err).To(BeNil())
+			Expect(resp).ToNot(BeNil())
+			body, err := io.ReadAll(resp.Body)
+			Expect(err).ToNot(HaveOccurred())
+			Expect(resp.StatusCode).To(Equal(200))
+			deserializedResponse := schema.VADResponse{}
+			err = json.Unmarshal(body, &deserializedResponse)
+			Expect(err).To(BeNil())
+			Expect(deserializedResponse).ToNot(BeZero())
+			Expect(deserializedResponse.Segments).ToNot(BeZero())
+		})
+	})
 	Context("reranker", func() {
 		It("correctly", func() {
 			modelName := "jina-reranker-v1-base-en"
 			req := schema.JINARerankRequest{
-				Model: modelName,
+				BasicModelRequest: schema.BasicModelRequest{
+					Model: modelName,
+				},
 				Query: "Organic skincare products for sensitive skin",
 				Documents: []string{
 					"Eco-friendly kitchenware for modern homes",

File diff suppressed because it is too large