squash past, centralize request middleware PR

Signed-off-by: Dave Lee <dave@gray101.com>
This commit is contained in:
Dave Lee 2025-02-05 14:14:10 -05:00
parent 28a1310890
commit c1f30ba3a9
No known key found for this signature in database
55 changed files with 481027 additions and 821 deletions

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,22 @@
meta {
name: vad test too few
type: http
seq: 1
}
post {
url: {{PROTOCOL}}{{HOST}}:{{PORT}}/vad
body: json
auth: none
}
headers {
Content-Type: application/json
}
body:json {
{
"model": "silero-vad",
"audio": []
}
}

8
aio/cpu/vad.yaml Normal file
View file

@ -0,0 +1,8 @@
backend: silero-vad
name: silero-vad
parameters:
model: silero-vad.onnx
download_files:
- filename: silero-vad.onnx
uri: https://huggingface.co/onnx-community/silero-vad/resolve/main/onnx/model.onnx
sha256: a4a068cd6cf1ea8355b84327595838ca748ec29a25bc91fc82e6c299ccdc5808

View file

@ -129,7 +129,7 @@ detect_gpu
detect_gpu_size
PROFILE="${PROFILE:-$GPU_SIZE}" # default to cpu
export MODELS="${MODELS:-/aio/${PROFILE}/embeddings.yaml,/aio/${PROFILE}/rerank.yaml,/aio/${PROFILE}/text-to-speech.yaml,/aio/${PROFILE}/image-gen.yaml,/aio/${PROFILE}/text-to-text.yaml,/aio/${PROFILE}/speech-to-text.yaml,/aio/${PROFILE}/vision.yaml}"
export MODELS="${MODELS:-/aio/${PROFILE}/embeddings.yaml,/aio/${PROFILE}/rerank.yaml,/aio/${PROFILE}/text-to-speech.yaml,/aio/${PROFILE}/image-gen.yaml,/aio/${PROFILE}/text-to-text.yaml,/aio/${PROFILE}/speech-to-text.yaml,/aio/${PROFILE}/vad.yaml,/aio/${PROFILE}/vision.yaml}"
check_vars

8
aio/gpu-8g/vad.yaml Normal file
View file

@ -0,0 +1,8 @@
backend: silero-vad
name: silero-vad
parameters:
model: silero-vad.onnx
download_files:
- filename: silero-vad.onnx
uri: https://huggingface.co/onnx-community/silero-vad/resolve/main/onnx/model.onnx
sha256: a4a068cd6cf1ea8355b84327595838ca748ec29a25bc91fc82e6c299ccdc5808

8
aio/intel/vad.yaml Normal file
View file

@ -0,0 +1,8 @@
backend: silero-vad
name: silero-vad
parameters:
model: silero-vad.onnx
download_files:
- filename: silero-vad.onnx
uri: https://huggingface.co/onnx-community/silero-vad/resolve/main/onnx/model.onnx
sha256: a4a068cd6cf1ea8355b84327595838ca748ec29a25bc91fc82e6c299ccdc5808

View file

@ -145,13 +145,7 @@ func New(opts ...config.AppOption) (*Application, error) {
if options.LoadToMemory != nil {
for _, m := range options.LoadToMemory {
cfg, err := application.BackendLoader().LoadBackendConfigFileByName(m, options.ModelPath,
config.LoadOptionDebug(options.Debug),
config.LoadOptionThreads(options.Threads),
config.LoadOptionContextSize(options.ContextSize),
config.LoadOptionF16(options.F16),
config.ModelPath(options.ModelPath),
)
cfg, err := application.BackendLoader().LoadBackendConfigFileByNameDefaultOptions(m, options)
if err != nil {
return nil, err
}

View file

@ -33,7 +33,7 @@ type TokenUsage struct {
TimingTokenGeneration float64
}
func ModelInference(ctx context.Context, s string, messages []schema.Message, images, videos, audios []string, loader *model.ModelLoader, c config.BackendConfig, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) {
func ModelInference(ctx context.Context, s string, messages []schema.Message, images, videos, audios []string, loader *model.ModelLoader, c *config.BackendConfig, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) {
modelFile := c.Model
// Check if the modelFile exists, if it doesn't try to load it from the gallery
@ -48,7 +48,7 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
}
}
opts := ModelOptions(c, o)
opts := ModelOptions(*c, o)
inferenceModel, err := loader.Load(opts...)
if err != nil {
return nil, err
@ -84,7 +84,7 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
// in GRPC, the backend is supposed to answer to 1 single token if stream is not supported
fn := func() (LLMResponse, error) {
opts := gRPCPredictOpts(c, loader.ModelPath)
opts := gRPCPredictOpts(*c, loader.ModelPath)
opts.Prompt = s
opts.Messages = protoMessages
opts.UseTokenizerTemplate = c.TemplateConfig.UseTokenizerTemplate

View file

@ -9,10 +9,10 @@ import (
model "github.com/mudler/LocalAI/pkg/model"
)
func Rerank(modelFile string, request *proto.RerankRequest, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig config.BackendConfig) (*proto.RerankResult, error) {
opts := ModelOptions(backendConfig, appConfig, model.WithModel(modelFile))
func Rerank(request *proto.RerankRequest, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig config.BackendConfig) (*proto.RerankResult, error) {
opts := ModelOptions(backendConfig, appConfig)
rerankModel, err := loader.Load(opts...)
if err != nil {
return nil, err
}

View file

@ -13,7 +13,6 @@ import (
)
func SoundGeneration(
modelFile string,
text string,
duration *float32,
temperature *float32,
@ -25,8 +24,9 @@ func SoundGeneration(
backendConfig config.BackendConfig,
) (string, *proto.Result, error) {
opts := ModelOptions(backendConfig, appConfig, model.WithModel(modelFile))
opts := ModelOptions(backendConfig, appConfig)
soundGenModel, err := loader.Load(opts...)
if err != nil {
return "", nil, err
}
@ -44,7 +44,7 @@ func SoundGeneration(
res, err := soundGenModel.SoundGeneration(context.Background(), &proto.SoundGenerationRequest{
Text: text,
Model: modelFile,
Model: backendConfig.Model,
Dst: filePath,
Sample: doSample,
Duration: duration,

View file

@ -4,19 +4,23 @@ import (
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/grpc"
model "github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/model"
)
func ModelTokenize(s string, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (schema.TokenizeResponse, error) {
modelFile := backendConfig.Model
var inferenceModel grpc.Backend
var err error
opts := ModelOptions(backendConfig, appConfig, model.WithModel(modelFile))
opts := ModelOptions(backendConfig, appConfig)
// TODO: looks weird, seems to be a correct merge?
if backendConfig.Backend == "" {
inferenceModel, err = loader.Load(opts...)
} else {
opts = append(opts, model.WithBackendString(backendConfig.Backend))
inferenceModel, err = loader.Load(opts...)
}
if err != nil {
return schema.TokenizeResponse{}, err
}

View file

@ -47,7 +47,7 @@ func ModelTranscription(audio, language string, translate bool, ml *model.ModelL
tks = append(tks, int(t))
}
tr.Segments = append(tr.Segments,
schema.Segment{
schema.TranscriptionSegment{
Text: s.Text,
Id: int(s.Id),
Start: time.Duration(s.Start),

View file

@ -14,28 +14,22 @@ import (
)
func ModelTTS(
backend,
text,
modelFile,
voice,
language string,
loader *model.ModelLoader,
appConfig *config.ApplicationConfig,
backendConfig config.BackendConfig,
) (string, *proto.Result, error) {
bb := backend
if bb == "" {
bb = model.PiperBackend
}
opts := ModelOptions(backendConfig, appConfig, model.WithBackendString(bb), model.WithModel(modelFile))
opts := ModelOptions(backendConfig, appConfig, model.WithDefaultBackendString(model.PiperBackend))
ttsModel, err := loader.Load(opts...)
if err != nil {
return "", nil, err
}
if ttsModel == nil {
return "", nil, fmt.Errorf("could not load piper model")
return "", nil, fmt.Errorf("could not load tts model %q", backendConfig.Model)
}
if err := os.MkdirAll(appConfig.AudioDir, 0750); err != nil {
@ -45,22 +39,21 @@ func ModelTTS(
fileName := utils.GenerateUniqueFileName(appConfig.AudioDir, "tts", ".wav")
filePath := filepath.Join(appConfig.AudioDir, fileName)
// If the model file is not empty, we pass it joined with the model path
// We join the model name to the model path here. This seems to only be done for TTS and is HIGHLY suspect.
// This should be addressed in a follow up PR soon.
// Copying it over nearly verbatim, as TTS backends are not functional without this.
modelPath := ""
if modelFile != "" {
// If the model file is not empty, we pass it joined with the model path
// Checking first that it exists and is not outside ModelPath
// TODO: we should actually first check if the modelFile is looking like
// a FS path
mp := filepath.Join(loader.ModelPath, modelFile)
mp := filepath.Join(loader.ModelPath, backendConfig.Model)
if _, err := os.Stat(mp); err == nil {
if err := utils.VerifyPath(mp, appConfig.ModelPath); err != nil {
return "", nil, err
}
modelPath = mp
} else {
modelPath = modelFile
}
modelPath = backendConfig.Model // skip this step if it fails?????
}
res, err := ttsModel.TTS(context.Background(), &proto.TTSRequest{

38
core/backend/vad.go Normal file
View file

@ -0,0 +1,38 @@
package backend
import (
"context"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/grpc/proto"
"github.com/mudler/LocalAI/pkg/model"
)
func VAD(request *schema.VADRequest,
ctx context.Context,
ml *model.ModelLoader,
appConfig *config.ApplicationConfig,
backendConfig config.BackendConfig) (*schema.VADResponse, error) {
opts := ModelOptions(backendConfig, appConfig)
vadModel, err := ml.Load(opts...)
if err != nil {
return nil, err
}
req := proto.VADRequest{
Audio: request.Audio,
}
resp, err := vadModel.VAD(ctx, &req)
if err != nil {
return nil, err
}
segments := []schema.VADSegment{}
for _, s := range resp.Segments {
segments = append(segments, schema.VADSegment{Start: s.Start, End: s.End})
}
return &schema.VADResponse{
Segments: segments,
}, nil
}

View file

@ -86,13 +86,14 @@ func (t *SoundGenerationCMD) Run(ctx *cliContext.Context) error {
options := config.BackendConfig{}
options.SetDefaults()
options.Backend = t.Backend
options.Model = t.Model
var inputFile *string
if t.InputFile != "" {
inputFile = &t.InputFile
}
filePath, _, err := backend.SoundGeneration(t.Model, text,
filePath, _, err := backend.SoundGeneration(text,
parseToFloat32Ptr(t.Duration), parseToFloat32Ptr(t.Temperature), &t.DoSample,
inputFile, parseToInt32Ptr(t.InputFileSampleDivisor), ml, opts, options)

View file

@ -52,8 +52,10 @@ func (t *TTSCMD) Run(ctx *cliContext.Context) error {
options := config.BackendConfig{}
options.SetDefaults()
options.Backend = t.Backend
options.Model = t.Model
filePath, _, err := backend.ModelTTS(t.Backend, text, t.Model, t.Voice, t.Language, ml, opts, options)
filePath, _, err := backend.ModelTTS(text, t.Voice, t.Language, ml, opts, options)
if err != nil {
return err
}

View file

@ -436,19 +436,21 @@ func (c *BackendConfig) HasTemplate() bool {
type BackendConfigUsecases int
const (
FLAG_ANY BackendConfigUsecases = 0b000000000
FLAG_CHAT BackendConfigUsecases = 0b000000001
FLAG_COMPLETION BackendConfigUsecases = 0b000000010
FLAG_EDIT BackendConfigUsecases = 0b000000100
FLAG_EMBEDDINGS BackendConfigUsecases = 0b000001000
FLAG_RERANK BackendConfigUsecases = 0b000010000
FLAG_IMAGE BackendConfigUsecases = 0b000100000
FLAG_TRANSCRIPT BackendConfigUsecases = 0b001000000
FLAG_TTS BackendConfigUsecases = 0b010000000
FLAG_SOUND_GENERATION BackendConfigUsecases = 0b100000000
FLAG_ANY BackendConfigUsecases = 0b00000000000
FLAG_CHAT BackendConfigUsecases = 0b00000000001
FLAG_COMPLETION BackendConfigUsecases = 0b00000000010
FLAG_EDIT BackendConfigUsecases = 0b00000000100
FLAG_EMBEDDINGS BackendConfigUsecases = 0b00000001000
FLAG_RERANK BackendConfigUsecases = 0b00000010000
FLAG_IMAGE BackendConfigUsecases = 0b00000100000
FLAG_TRANSCRIPT BackendConfigUsecases = 0b00001000000
FLAG_TTS BackendConfigUsecases = 0b00010000000
FLAG_SOUND_GENERATION BackendConfigUsecases = 0b00100000000
FLAG_TOKENIZE BackendConfigUsecases = 0b01000000000
FLAG_VAD BackendConfigUsecases = 0b10000000000
// Common Subsets
FLAG_LLM BackendConfigUsecases = FLAG_CHAT & FLAG_COMPLETION & FLAG_EDIT
FLAG_LLM BackendConfigUsecases = FLAG_CHAT | FLAG_COMPLETION | FLAG_EDIT
)
func GetAllBackendConfigUsecases() map[string]BackendConfigUsecases {
@ -463,6 +465,8 @@ func GetAllBackendConfigUsecases() map[string]BackendConfigUsecases {
"FLAG_TRANSCRIPT": FLAG_TRANSCRIPT,
"FLAG_TTS": FLAG_TTS,
"FLAG_SOUND_GENERATION": FLAG_SOUND_GENERATION,
"FLAG_TOKENIZE": FLAG_TOKENIZE,
"FLAG_VAD": FLAG_VAD,
"FLAG_LLM": FLAG_LLM,
}
}
@ -548,5 +552,18 @@ func (c *BackendConfig) GuessUsecases(u BackendConfigUsecases) bool {
}
}
if (u & FLAG_TOKENIZE) == FLAG_TOKENIZE {
tokenizeCapableBackends := []string{"llama.cpp", "rwkv"}
if !slices.Contains(tokenizeCapableBackends, c.Backend) {
return false
}
}
if (u & FLAG_VAD) == FLAG_VAD {
if c.Backend != "silero-vad" {
return false
}
}
return true
}

View file

@ -81,10 +81,10 @@ func readMultipleBackendConfigsFromFile(file string, opts ...ConfigLoaderOption)
c := &[]*BackendConfig{}
f, err := os.ReadFile(file)
if err != nil {
return nil, fmt.Errorf("cannot read config file: %w", err)
return nil, fmt.Errorf("readMultipleBackendConfigsFromFile cannot read config file %q: %w", file, err)
}
if err := yaml.Unmarshal(f, c); err != nil {
return nil, fmt.Errorf("cannot unmarshal config file: %w", err)
return nil, fmt.Errorf("readMultipleBackendConfigsFromFile cannot unmarshal config file %q: %w", file, err)
}
for _, cc := range *c {
@ -101,10 +101,10 @@ func readBackendConfigFromFile(file string, opts ...ConfigLoaderOption) (*Backen
c := &BackendConfig{}
f, err := os.ReadFile(file)
if err != nil {
return nil, fmt.Errorf("cannot read config file: %w", err)
return nil, fmt.Errorf("readBackendConfigFromFile cannot read config file %q: %w", file, err)
}
if err := yaml.Unmarshal(f, c); err != nil {
return nil, fmt.Errorf("cannot unmarshal config file: %w", err)
return nil, fmt.Errorf("readBackendConfigFromFile cannot unmarshal config file %q: %w", file, err)
}
c.SetDefaults(opts...)
@ -117,8 +117,10 @@ func (bcl *BackendConfigLoader) LoadBackendConfigFileByName(modelName, modelPath
// Load a config file if present after the model name
cfg := &BackendConfig{
PredictionOptions: schema.PredictionOptions{
BasicModelRequest: schema.BasicModelRequest{
Model: modelName,
},
},
}
cfgExisting, exists := bcl.GetBackendConfig(modelName)
@ -145,6 +147,15 @@ func (bcl *BackendConfigLoader) LoadBackendConfigFileByName(modelName, modelPath
return cfg, nil
}
func (bcl *BackendConfigLoader) LoadBackendConfigFileByNameDefaultOptions(modelName string, appConfig *ApplicationConfig) (*BackendConfig, error) {
return bcl.LoadBackendConfigFileByName(modelName, appConfig.ModelPath,
LoadOptionDebug(appConfig.Debug),
LoadOptionThreads(appConfig.Threads),
LoadOptionContextSize(appConfig.ContextSize),
LoadOptionF16(appConfig.F16),
ModelPath(appConfig.ModelPath))
}
// This format is currently only used when reading a single file at startup, passed in via ApplicationConfig.ConfigFile
func (bcl *BackendConfigLoader) LoadMultipleBackendConfigsSingleFile(file string, opts ...ConfigLoaderOption) error {
bcl.Lock()
@ -167,7 +178,7 @@ func (bcl *BackendConfigLoader) LoadBackendConfig(file string, opts ...ConfigLoa
defer bcl.Unlock()
c, err := readBackendConfigFromFile(file, opts...)
if err != nil {
return fmt.Errorf("cannot read config file: %w", err)
return fmt.Errorf("LoadBackendConfig cannot read config file %q: %w", file, err)
}
if c.Validate() {
@ -324,9 +335,10 @@ func (bcl *BackendConfigLoader) Preload(modelPath string) error {
func (bcl *BackendConfigLoader) LoadBackendConfigsFromPath(path string, opts ...ConfigLoaderOption) error {
bcl.Lock()
defer bcl.Unlock()
entries, err := os.ReadDir(path)
if err != nil {
return fmt.Errorf("cannot read directory '%s': %w", path, err)
return fmt.Errorf("LoadBackendConfigsFromPath cannot read directory '%s': %w", path, err)
}
files := make([]fs.FileInfo, 0, len(entries))
for _, entry := range entries {
@ -344,13 +356,13 @@ func (bcl *BackendConfigLoader) LoadBackendConfigsFromPath(path string, opts ...
}
c, err := readBackendConfigFromFile(filepath.Join(path, file.Name()), opts...)
if err != nil {
log.Error().Err(err).Msgf("cannot read config file: %s", file.Name())
log.Error().Err(err).Str("File Name", file.Name()).Msgf("LoadBackendConfigsFromPath cannot read config file")
continue
}
if c.Validate() {
bcl.configs[c.Name] = *c
} else {
log.Error().Err(err).Msgf("config is not valid")
log.Error().Err(err).Str("Name", c.Name).Msgf("config is not valid")
}
}

View file

@ -161,10 +161,11 @@ func guessDefaultsFromFile(cfg *BackendConfig, modelPath string) {
}
// We try to guess only if we don't have a template defined already
f, err := gguf.ParseGGUFFile(filepath.Join(modelPath, cfg.ModelFileName()))
guessPath := filepath.Join(modelPath, cfg.ModelFileName())
f, err := gguf.ParseGGUFFile(guessPath)
if err != nil {
// Only valid for gguf files
log.Debug().Msgf("guessDefaultsFromFile: %s", "not a GGUF file")
log.Debug().Str("filePath", guessPath).Msg("guessDefaultsFromFile: not a GGUF file")
return
}

View file

@ -130,7 +130,6 @@ func API(application *application.Application) (*fiber.App, error) {
return metricsService.Shutdown()
})
}
}
// Health Checks should always be exempt from auth, so register these first
routes.HealthRoutes(router)
@ -167,13 +166,15 @@ func API(application *application.Application) (*fiber.App, error) {
galleryService := services.NewGalleryService(application.ApplicationConfig())
galleryService.Start(application.ApplicationConfig().Context, application.BackendLoader())
routes.RegisterElevenLabsRoutes(router, application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())
routes.RegisterLocalAIRoutes(router, application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig(), galleryService)
routes.RegisterOpenAIRoutes(router, application)
requestExtractor := middleware.NewRequestExtractor(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())
routes.RegisterElevenLabsRoutes(router, requestExtractor, application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())
routes.RegisterLocalAIRoutes(router, requestExtractor, application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig(), galleryService)
routes.RegisterOpenAIRoutes(router, requestExtractor, application)
if !application.ApplicationConfig().DisableWebUI {
routes.RegisterUIRoutes(router, application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig(), galleryService)
}
routes.RegisterJINARoutes(router, application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())
routes.RegisterJINARoutes(router, requestExtractor, application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())
httpFS := http.FS(embedDirStatic)

View file

@ -1,47 +0,0 @@
package fiberContext
import (
"fmt"
"strings"
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/pkg/model"
"github.com/rs/zerolog/log"
)
// ModelFromContext returns the model from the context
// If no model is specified, it will take the first available
// Takes a model string as input which should be the one received from the user request.
// It returns the model name resolved from the context and an error if any.
func ModelFromContext(ctx *fiber.Ctx, cl *config.BackendConfigLoader, loader *model.ModelLoader, modelInput string, firstModel bool) (string, error) {
if ctx.Params("model") != "" {
modelInput = ctx.Params("model")
}
if ctx.Query("model") != "" {
modelInput = ctx.Query("model")
}
// Set model from bearer token, if available
bearer := strings.TrimLeft(ctx.Get("authorization"), "Bear ") // Reduced duplicate characters of Bearer
bearerExists := bearer != "" && loader.ExistsInModelPath(bearer)
// If no model was specified, take the first available
if modelInput == "" && !bearerExists && firstModel {
models, _ := services.ListModels(cl, loader, config.NoFilterFn, services.SKIP_IF_CONFIGURED)
if len(models) > 0 {
modelInput = models[0]
log.Debug().Msgf("No model specified, using: %s", modelInput)
} else {
log.Debug().Msgf("No model specified, returning error")
return "", fmt.Errorf("no model specified")
}
}
// If a model is found in bearer token takes precedence
if bearerExists {
log.Debug().Msgf("Using model from bearer token: %s", bearer)
modelInput = bearer
}
return modelInput, nil
}

View file

@ -4,7 +4,7 @@ import (
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
fiberContext "github.com/mudler/LocalAI/core/http/ctx"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/model"
"github.com/rs/zerolog/log"
@ -17,45 +17,21 @@ import (
// @Router /v1/sound-generation [post]
func SoundGenerationEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(schema.ElevenLabsSoundGenerationRequest)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.ElevenLabsSoundGenerationRequest)
if !ok || input.ModelID == "" {
return fiber.ErrBadRequest
}
modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.ModelID, false)
if err != nil {
modelFile = input.ModelID
log.Warn().Str("ModelID", input.ModelID).Msg("Model not found in context")
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
if !ok || cfg == nil {
return fiber.ErrBadRequest
}
cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
config.LoadOptionDebug(appConfig.Debug),
config.LoadOptionThreads(appConfig.Threads),
config.LoadOptionContextSize(appConfig.ContextSize),
config.LoadOptionF16(appConfig.F16),
)
if err != nil {
modelFile = input.ModelID
log.Warn().Str("Request ModelID", input.ModelID).Err(err).Msg("error during LoadBackendConfigFileByName, using request ModelID")
} else {
if input.ModelID != "" {
modelFile = input.ModelID
} else {
modelFile = cfg.Model
}
}
log.Debug().Str("modelFile", "modelFile").Str("backend", cfg.Backend).Msg("Sound Generation Request about to be sent to backend")
if input.Duration != nil {
log.Debug().Float32("duration", *input.Duration).Msg("duration set")
}
if input.Temperature != nil {
log.Debug().Float32("temperature", *input.Temperature).Msg("temperature set")
}
// TODO: Support uploading files?
filePath, _, err := backend.SoundGeneration(modelFile, input.Text, input.Duration, input.Temperature, input.DoSample, nil, nil, ml, appConfig, *cfg)
filePath, _, err := backend.SoundGeneration(input.Text, input.Duration, input.Temperature, input.DoSample, nil, nil, ml, appConfig, *cfg)
if err != nil {
return err
}

View file

@ -3,7 +3,7 @@ package elevenlabs
import (
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
fiberContext "github.com/mudler/LocalAI/core/http/ctx"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
@ -20,39 +20,21 @@ import (
func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(schema.ElevenLabsTTSRequest)
voiceID := c.Params("voice-id")
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.ElevenLabsTTSRequest)
if !ok || input.ModelID == "" {
return fiber.ErrBadRequest
}
modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.ModelID, false)
if err != nil {
modelFile = input.ModelID
log.Warn().Msgf("Model not found in context: %s", input.ModelID)
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
if !ok || cfg == nil {
return fiber.ErrBadRequest
}
cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
config.LoadOptionDebug(appConfig.Debug),
config.LoadOptionThreads(appConfig.Threads),
config.LoadOptionContextSize(appConfig.ContextSize),
config.LoadOptionF16(appConfig.F16),
)
if err != nil {
modelFile = input.ModelID
log.Warn().Msgf("Model not found in context: %s", input.ModelID)
} else {
if input.ModelID != "" {
modelFile = input.ModelID
} else {
modelFile = cfg.Model
}
}
log.Debug().Msgf("Request for model: %s", modelFile)
log.Debug().Str("modelName", input.ModelID).Msg("elevenlabs TTS request recieved")
filePath, _, err := backend.ModelTTS(cfg.Backend, input.Text, modelFile, "", voiceID, ml, appConfig, *cfg)
filePath, _, err := backend.ModelTTS(input.Text, voiceID, input.LanguageCode, ml, appConfig, *cfg)
if err != nil {
return err
}

View file

@ -3,9 +3,9 @@ package jina
import (
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/gofiber/fiber/v2"
fiberContext "github.com/mudler/LocalAI/core/http/ctx"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/grpc/proto"
"github.com/mudler/LocalAI/pkg/model"
@ -19,58 +19,32 @@ import (
// @Router /v1/rerank [post]
func JINARerankEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
req := new(schema.JINARerankRequest)
if err := c.BodyParser(req); err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
"error": "Cannot parse JSON",
})
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.JINARerankRequest)
if !ok || input.Model == "" {
return fiber.ErrBadRequest
}
input := new(schema.TTSRequest)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
if !ok || cfg == nil {
return fiber.ErrBadRequest
}
modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, false)
if err != nil {
modelFile = input.Model
log.Warn().Msgf("Model not found in context: %s", input.Model)
}
cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
config.LoadOptionDebug(appConfig.Debug),
config.LoadOptionThreads(appConfig.Threads),
config.LoadOptionContextSize(appConfig.ContextSize),
config.LoadOptionF16(appConfig.F16),
)
if err != nil {
modelFile = input.Model
log.Warn().Msgf("Model not found in context: %s", input.Model)
} else {
modelFile = cfg.Model
}
log.Debug().Msgf("Request for model: %s", modelFile)
if input.Backend != "" {
cfg.Backend = input.Backend
}
log.Debug().Str("model", input.Model).Msg("JINA Rerank Request recieved")
request := &proto.RerankRequest{
Query: req.Query,
TopN: int32(req.TopN),
Documents: req.Documents,
Query: input.Query,
TopN: int32(input.TopN),
Documents: input.Documents,
}
results, err := backend.Rerank(modelFile, request, ml, appConfig, *cfg)
results, err := backend.Rerank(request, ml, appConfig, *cfg)
if err != nil {
return err
}
response := &schema.JINARerankResponse{
Model: req.Model,
Model: input.Model,
}
for _, r := range results.Results {

View file

@ -4,13 +4,15 @@ import (
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
fiberContext "github.com/mudler/LocalAI/core/http/ctx"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema"
"github.com/rs/zerolog/log"
"github.com/mudler/LocalAI/pkg/model"
)
// TODO: This is not yet in use. Needs middleware rework, since it is not referenced.
// TokenMetricsEndpoint is an endpoint to get TokensProcessed Per Second for Active SlotID
//
// @Summary Get TokenMetrics for Active Slot.
@ -29,18 +31,13 @@ func TokenMetricsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader,
return err
}
modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, false)
if err != nil {
modelFile, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
if !ok || modelFile != "" {
modelFile = input.Model
log.Warn().Msgf("Model not found in context: %s", input.Model)
}
cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
config.LoadOptionDebug(appConfig.Debug),
config.LoadOptionThreads(appConfig.Threads),
config.LoadOptionContextSize(appConfig.ContextSize),
config.LoadOptionF16(appConfig.F16),
)
cfg, err := cl.LoadBackendConfigFileByNameDefaultOptions(modelFile, appConfig)
if err != nil {
log.Err(err)

View file

@ -4,10 +4,9 @@ import (
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
fiberContext "github.com/mudler/LocalAI/core/http/ctx"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/model"
"github.com/rs/zerolog/log"
)
// TokenizeEndpoint exposes a REST API to tokenize the content
@ -16,42 +15,21 @@ import (
// @Success 200 {object} schema.TokenizeResponse "Response"
// @Router /v1/tokenize [post]
func TokenizeEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(schema.TokenizeRequest)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
return func(ctx *fiber.Ctx) error {
input, ok := ctx.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.TokenizeRequest)
if !ok || input.Model == "" {
return fiber.ErrBadRequest
}
modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, false)
if err != nil {
modelFile = input.Model
log.Warn().Msgf("Model not found in context: %s", input.Model)
cfg, ok := ctx.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
if !ok || cfg == nil {
return fiber.ErrBadRequest
}
cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
config.LoadOptionDebug(appConfig.Debug),
config.LoadOptionThreads(appConfig.Threads),
config.LoadOptionContextSize(appConfig.ContextSize),
config.LoadOptionF16(appConfig.F16),
)
if err != nil {
log.Err(err)
modelFile = input.Model
log.Warn().Msgf("Model not found in context: %s", input.Model)
} else {
modelFile = cfg.Model
}
log.Debug().Msgf("Request for model: %s", modelFile)
tokenResponse, err := backend.ModelTokenize(input.Content, ml, *cfg, appConfig)
if err != nil {
return err
}
return c.JSON(tokenResponse)
return ctx.JSON(tokenResponse)
}
}

View file

@ -3,7 +3,7 @@ package localai
import (
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
fiberContext "github.com/mudler/LocalAI/core/http/ctx"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
@ -24,37 +24,24 @@ import (
// @Router /tts [post]
func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(schema.TTSRequest)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.TTSRequest)
if !ok || input.Model == "" {
return fiber.ErrBadRequest
}
modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, false)
if err != nil {
modelFile = input.Model
log.Warn().Msgf("Model not found in context: %s", input.Model)
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
if !ok || cfg == nil {
return fiber.ErrBadRequest
}
cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
config.LoadOptionDebug(appConfig.Debug),
config.LoadOptionThreads(appConfig.Threads),
config.LoadOptionContextSize(appConfig.ContextSize),
config.LoadOptionF16(appConfig.F16),
)
if err != nil {
log.Err(err)
modelFile = input.Model
log.Warn().Msgf("Model not found in context: %s", input.Model)
} else {
modelFile = cfg.Model
}
log.Debug().Msgf("Request for model: %s", modelFile)
log.Debug().Str("model", input.Model).Msg("LocalAI TTS Request recieved")
if cfg.Backend == "" {
if input.Backend != "" {
cfg.Backend = input.Backend
} else {
cfg.Backend = model.PiperBackend
}
}
if input.Language != "" {
@ -65,7 +52,7 @@ func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfi
cfg.Voice = input.Voice
}
filePath, _, err := backend.ModelTTS(cfg.Backend, input.Input, modelFile, cfg.Voice, cfg.Language, ml, appConfig, *cfg)
filePath, _, err := backend.ModelTTS(input.Input, cfg.Voice, cfg.Language, ml, appConfig, *cfg)
if err != nil {
return err
}

View file

@ -4,9 +4,8 @@ import (
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
fiberContext "github.com/mudler/LocalAI/core/http/ctx"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/grpc/proto"
"github.com/mudler/LocalAI/pkg/model"
"github.com/rs/zerolog/log"
)
@ -19,45 +18,20 @@ import (
// @Router /vad [post]
func VADEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(schema.VADRequest)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.VADRequest)
if !ok || input.Model == "" {
return fiber.ErrBadRequest
}
modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, false)
if err != nil {
modelFile = input.Model
log.Warn().Msgf("Model not found in context: %s", input.Model)
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
if !ok || cfg == nil {
return fiber.ErrBadRequest
}
cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
config.LoadOptionDebug(appConfig.Debug),
config.LoadOptionThreads(appConfig.Threads),
config.LoadOptionContextSize(appConfig.ContextSize),
config.LoadOptionF16(appConfig.F16),
)
log.Debug().Str("model", input.Model).Msg("LocalAI VAD Request recieved")
if err != nil {
log.Err(err)
modelFile = input.Model
log.Warn().Msgf("Model not found in context: %s", input.Model)
} else {
modelFile = cfg.Model
}
log.Debug().Msgf("Request for model: %s", modelFile)
resp, err := backend.VAD(input, c.Context(), ml, appConfig, *cfg)
opts := backend.ModelOptions(*cfg, appConfig, model.WithBackendString(cfg.Backend), model.WithModel(modelFile))
vadModel, err := ml.Load(opts...)
if err != nil {
return err
}
req := proto.VADRequest{
Audio: input.Audio,
}
resp, err := vadModel.VAD(c.Context(), &req)
if err != nil {
return err
}

View file

@ -5,18 +5,19 @@ import (
"bytes"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/functions"
"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/templates"
model "github.com/mudler/LocalAI/pkg/model"
"github.com/rs/zerolog/log"
"github.com/valyala/fasthttp"
)
@ -174,26 +175,20 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
textContentToReturn = ""
id = uuid.New().String()
created = int(time.Now().Unix())
// Set CorrelationID
correlationID := c.Get("X-Correlation-ID")
if len(strings.TrimSpace(correlationID)) == 0 {
correlationID = id
}
c.Set("X-Correlation-ID", correlationID)
// Opt-in extra usage flag
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
if !ok || input.Model == "" {
return fiber.ErrBadRequest
}
extraUsage := c.Get("Extra-Usage", "") != ""
modelFile, input, err := readRequest(c, cl, ml, startupOptions, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
if !ok || config == nil {
return fiber.ErrBadRequest
}
config, input, err := mergeRequestWithConfig(modelFile, input, cl, ml, startupOptions.Debug, startupOptions.Threads, startupOptions.ContextSize, startupOptions.F16)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
log.Debug().Msgf("Configuration read: %+v", config)
log.Debug().Msgf("Chat endpoint configuration read: %+v", config)
funcs := input.Functions
shouldUseFn := len(input.Functions) > 0 && config.ShouldUseFunctions()
@ -539,7 +534,7 @@ func handleQuestion(config *config.BackendConfig, input *schema.OpenAIRequest, m
audios = append(audios, m.StringAudios...)
}
predFunc, err := backend.ModelInference(input.Context, prompt, input.Messages, images, videos, audios, ml, *config, o, nil)
predFunc, err := backend.ModelInference(input.Context, prompt, input.Messages, images, videos, audios, ml, config, o, nil)
if err != nil {
log.Error().Err(err).Msg("model inference failed")
return "", err

View file

@ -10,12 +10,13 @@ import (
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/functions"
model "github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/templates"
"github.com/rs/zerolog/log"
"github.com/valyala/fasthttp"
@ -26,11 +27,11 @@ import (
// @Param request body schema.OpenAIRequest true "query params"
// @Success 200 {object} schema.OpenAIResponse "Response"
// @Router /v1/completions [post]
func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
id := uuid.New().String()
created := int(time.Now().Unix())
process := func(s string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) {
process := func(id string, s string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) {
ComputeChoices(req, s, config, appConfig, loader, func(s string, c *[]schema.Choice) {}, func(s string, tokenUsage backend.TokenUsage) bool {
usage := schema.OpenAIUsage{
PromptTokens: tokenUsage.Prompt,
@ -63,22 +64,18 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, e
}
return func(c *fiber.Ctx) error {
// Add Correlation
c.Set("X-Correlation-ID", id)
// Opt-in extra usage flag
// Handle Correlation
id := c.Get("X-Correlation-ID", uuid.New().String())
extraUsage := c.Get("Extra-Usage", "") != ""
modelFile, input, err := readRequest(c, cl, ml, appConfig, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
if !ok || input.Model == "" {
return fiber.ErrBadRequest
}
log.Debug().Msgf("`input`: %+v", input)
config, input, err := mergeRequestWithConfig(modelFile, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
if !ok || config == nil {
return fiber.ErrBadRequest
}
if config.ResponseFormatMap != nil {
@ -122,7 +119,7 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, e
responses := make(chan schema.OpenAIResponse)
go process(predInput, input, config, ml, responses, extraUsage)
go process(id, predInput, input, config, ml, responses, extraUsage)
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {

View file

@ -2,16 +2,17 @@ package openai
import (
"encoding/json"
"fmt"
"time"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/mudler/LocalAI/core/schema"
model "github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/templates"
"github.com/rs/zerolog/log"
@ -25,20 +26,21 @@ import (
func EditEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
if !ok || input.Model == "" {
return fiber.ErrBadRequest
}
// Opt-in extra usage flag
extraUsage := c.Get("Extra-Usage", "") != ""
modelFile, input, err := readRequest(c, cl, ml, appConfig, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
if !ok || config == nil {
return fiber.ErrBadRequest
}
config, input, err := mergeRequestWithConfig(modelFile, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
log.Debug().Msgf("Parameter Config: %+v", config)
log.Debug().Msgf("Edit Endpoint Input : %+v", input)
log.Debug().Msgf("Edit Endpoint Config: %+v", *config)
var result []schema.Choice
totalTokenUsage := backend.TokenUsage{}

View file

@ -2,11 +2,11 @@ package openai
import (
"encoding/json"
"fmt"
"time"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/pkg/model"
"github.com/google/uuid"
@ -23,14 +23,14 @@ import (
// @Router /v1/embeddings [post]
func EmbeddingsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
model, input, err := readRequest(c, cl, ml, appConfig, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
if !ok || input.Model == "" {
return fiber.ErrBadRequest
}
config, input, err := mergeRequestWithConfig(model, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
if !ok || config == nil {
return fiber.ErrBadRequest
}
log.Debug().Msgf("Parameter Config: %+v", config)

View file

@ -15,6 +15,7 @@ import (
"github.com/google/uuid"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/backend"
@ -66,25 +67,23 @@ func downloadFile(url string) (string, error) {
// @Router /v1/images/generations [post]
func ImageEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
m, input, err := readRequest(c, cl, ml, appConfig, false)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
if !ok || input.Model == "" {
log.Error().Msg("Image Endpoint - Invalid Input")
return fiber.ErrBadRequest
}
if m == "" {
m = "stablediffusion"
}
log.Debug().Msgf("Loading model: %+v", m)
config, input, err := mergeRequestWithConfig(m, input, cl, ml, appConfig.Debug, 0, 0, false)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
if !ok || config == nil {
log.Error().Msg("Image Endpoint - Invalid Config")
return fiber.ErrBadRequest
}
src := ""
if input.File != "" {
fileData := []byte{}
var err error
// check if input.File is an URL, if so download it and save it
// to a temporary file
if strings.HasPrefix(input.File, "http://") || strings.HasPrefix(input.File, "https://") {

View file

@ -37,7 +37,7 @@ func ComputeChoices(
}
// get the model function to call for the result
predFunc, err := backend.ModelInference(req.Context, predInput, req.Messages, images, videos, audios, loader, *config, o, tokenCallback)
predFunc, err := backend.ModelInference(req.Context, predInput, req.Messages, images, videos, audios, loader, config, o, tokenCallback)
if err != nil {
return result, backend.TokenUsage{}, err
}

View file

@ -1,7 +1,6 @@
package openai
import (
"fmt"
"io"
"net/http"
"os"
@ -10,6 +9,8 @@ import (
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema"
model "github.com/mudler/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
@ -25,15 +26,16 @@ import (
// @Router /v1/audio/transcriptions [post]
func TranscriptEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
m, input, err := readRequest(c, cl, ml, appConfig, false)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
if !ok || input.Model == "" {
return fiber.ErrBadRequest
}
config, input, err := mergeRequestWithConfig(m, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16)
if err != nil {
return fmt.Errorf("failed reading parameters from request: %w", err)
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
if !ok || config == nil {
return fiber.ErrBadRequest
}
// retrieve the file data from the request
file, err := c.FormFile("file")
if err != nil {

View file

@ -1,20 +1,22 @@
package openai
package middleware
import (
"context"
"encoding/json"
"fmt"
"strconv"
"strings"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/mudler/LocalAI/core/config"
fiberContext "github.com/mudler/LocalAI/core/http/ctx"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/pkg/functions"
"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/templates"
"github.com/mudler/LocalAI/pkg/utils"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
)
@ -23,33 +25,166 @@ type correlationIDKeyType string
// CorrelationIDKey to track request across process boundary
const CorrelationIDKey correlationIDKeyType = "correlationID"
func readRequest(c *fiber.Ctx, cl *config.BackendConfigLoader, ml *model.ModelLoader, o *config.ApplicationConfig, firstModel bool) (string, *schema.OpenAIRequest, error) {
input := new(schema.OpenAIRequest)
type RequestExtractor struct {
backendConfigLoader *config.BackendConfigLoader
modelLoader *model.ModelLoader
applicationConfig *config.ApplicationConfig
}
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return "", nil, fmt.Errorf("failed parsing request body: %w", err)
func NewRequestExtractor(backendConfigLoader *config.BackendConfigLoader, modelLoader *model.ModelLoader, applicationConfig *config.ApplicationConfig) *RequestExtractor {
return &RequestExtractor{
backendConfigLoader: backendConfigLoader,
modelLoader: modelLoader,
applicationConfig: applicationConfig,
}
}
const CONTEXT_LOCALS_KEY_MODEL_NAME = "MODEL_NAME"
const CONTEXT_LOCALS_KEY_LOCALAI_REQUEST = "LOCALAI_REQUEST"
const CONTEXT_LOCALS_KEY_MODEL_CONFIG = "MODEL_CONFIG"
// TODO: Refactor to not return error if unchanged
func (re *RequestExtractor) setModelNameFromRequest(ctx *fiber.Ctx) {
model, ok := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
if ok && model != "" {
return
}
model = ctx.Params("model")
if (model == "") && ctx.Query("model") != "" {
model = ctx.Query("model")
}
received, _ := json.Marshal(input)
// Extract or generate the correlation ID
correlationID := c.Get("X-Correlation-ID", uuid.New().String())
if model == "" {
// Set model from bearer token, if available
bearer := strings.TrimLeft(ctx.Get("authorization"), "Bear ") // "Bearer " => "Bear" to please go-staticcheck. It looks dumb but we might as well take free performance on something called for nearly every request.
if bearer != "" {
exists, err := services.CheckIfModelExists(re.backendConfigLoader, re.modelLoader, bearer, services.ALWAYS_INCLUDE)
if err == nil && exists {
model = bearer
}
}
}
ctx, cancel := context.WithCancel(o.Context)
ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME, model)
}
func (re *RequestExtractor) BuildConstantDefaultModelNameMiddleware(defaultModelName string) fiber.Handler {
return func(ctx *fiber.Ctx) error {
re.setModelNameFromRequest(ctx)
localModelName, ok := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
if !ok || localModelName == "" {
ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME, defaultModelName)
log.Debug().Str("defaultModelName", defaultModelName).Msg("context local model name not found, setting to default")
}
return ctx.Next()
}
}
func (re *RequestExtractor) BuildFilteredFirstAvailableDefaultModel(filterFn config.BackendConfigFilterFn) fiber.Handler {
return func(ctx *fiber.Ctx) error {
re.setModelNameFromRequest(ctx)
localModelName := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
if localModelName != "" { // Don't overwrite existing values
return ctx.Next()
}
modelNames, err := services.ListModels(re.backendConfigLoader, re.modelLoader, filterFn, services.SKIP_IF_CONFIGURED)
if err != nil {
log.Error().Err(err).Msg("non-fatal error calling ListModels during SetDefaultModelNameToFirstAvailable()")
return ctx.Next()
}
if len(modelNames) == 0 {
log.Warn().Msg("SetDefaultModelNameToFirstAvailable used with no matching models installed")
// This is non-fatal - making it so was breaking the case of direct installation of raw models
// return errors.New("this endpoint requires at least one model to be installed")
return ctx.Next()
}
ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME, modelNames[0])
log.Debug().Str("first model name", modelNames[0]).Msg("context local model name not found, setting to the first model")
return ctx.Next()
}
}
// TODO: If context and cancel above belong on all methods, move that part of above into here!
// Otherwise, it's in its own method below for now
func (re *RequestExtractor) SetModelAndConfig(initializer func() schema.LocalAIRequest) fiber.Handler {
return func(ctx *fiber.Ctx) error {
input := initializer()
if input == nil {
return fmt.Errorf("unable to initialize body")
}
if err := ctx.BodyParser(input); err != nil {
return fmt.Errorf("failed parsing request body: %w", err)
}
// If this request doesn't have an associated model name, fetch it from earlier in the middleware chain
if input.ModelName(nil) == "" {
localModelName, ok := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
if ok && localModelName != "" {
log.Debug().Str("context localModelName", localModelName).Msg("overriding empty model name in request body with value found earlier in middleware chain")
input.ModelName(&localModelName)
}
}
cfg, err := re.backendConfigLoader.LoadBackendConfigFileByNameDefaultOptions(input.ModelName(nil), re.applicationConfig)
if err != nil {
log.Err(err)
log.Warn().Msgf("Model Configuration File not found for %q", input.ModelName(nil))
} else if cfg.Model == "" && input.ModelName(nil) != "" {
log.Debug().Str("input.ModelName", input.ModelName(nil)).Msg("config does not include model, using input")
cfg.Model = input.ModelName(nil)
}
ctx.Locals(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input)
ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg)
return ctx.Next()
}
}
func (re *RequestExtractor) SetOpenAIRequest(ctx *fiber.Ctx) error {
input, ok := ctx.Locals(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
if !ok || input.Model == "" {
return fiber.ErrBadRequest
}
cfg, ok := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
if !ok || cfg == nil {
return fiber.ErrBadRequest
}
// Extract or generate the correlation ID
correlationID := ctx.Get("X-Correlation-ID", uuid.New().String())
ctx.Set("X-Correlation-ID", correlationID)
c1, cancel := context.WithCancel(re.applicationConfig.Context)
// Add the correlation ID to the new context
ctxWithCorrelationID := context.WithValue(ctx, CorrelationIDKey, correlationID)
ctxWithCorrelationID := context.WithValue(c1, CorrelationIDKey, correlationID)
input.Context = ctxWithCorrelationID
input.Cancel = cancel
log.Debug().Msgf("Request received: %s", string(received))
err := mergeOpenAIRequestAndBackendConfig(cfg, input)
if err != nil {
return err
}
modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, firstModel)
if cfg.Model == "" {
log.Debug().Str("input.Model", input.Model).Msg("replacing empty cfg.Model with input value")
cfg.Model = input.Model
}
return modelFile, input, err
ctx.Locals(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input)
ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg)
return ctx.Next()
}
func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIRequest) {
func mergeOpenAIRequestAndBackendConfig(config *config.BackendConfig, input *schema.OpenAIRequest) error {
if input.Echo {
config.Echo = input.Echo
}
@ -249,6 +384,8 @@ func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIReque
config.TypicalP = input.TypicalP
}
log.Debug().Str("input.Input", fmt.Sprintf("%+v", input.Input))
switch inputs := input.Input.(type) {
case string:
if inputs != "" {
@ -305,22 +442,9 @@ func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIReque
config.Step = q
}
}
}
func mergeRequestWithConfig(modelFile string, input *schema.OpenAIRequest, cm *config.BackendConfigLoader, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*config.BackendConfig, *schema.OpenAIRequest, error) {
cfg, err := cm.LoadBackendConfigFileByName(modelFile, loader.ModelPath,
config.LoadOptionDebug(debug),
config.LoadOptionThreads(threads),
config.LoadOptionContextSize(ctx),
config.LoadOptionF16(f16),
)
// Set the parameters for the language model prediction
updateRequestConfig(cfg, input)
if !cfg.Validate() {
return nil, nil, fmt.Errorf("failed to validate config")
if config.Validate() {
return nil
}
return cfg, input, err
return fmt.Errorf("unable to validate configuration after merging")
}

View file

@ -4,17 +4,26 @@ import (
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/endpoints/elevenlabs"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/model"
)
func RegisterElevenLabsRoutes(app *fiber.App,
re *middleware.RequestExtractor,
cl *config.BackendConfigLoader,
ml *model.ModelLoader,
appConfig *config.ApplicationConfig) {
// Elevenlabs
app.Post("/v1/text-to-speech/:voice-id", elevenlabs.TTSEndpoint(cl, ml, appConfig))
app.Post("/v1/text-to-speech/:voice-id",
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TTS)),
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.ElevenLabsTTSRequest) }),
elevenlabs.TTSEndpoint(cl, ml, appConfig))
app.Post("/v1/sound-generation", elevenlabs.SoundGenerationEndpoint(cl, ml, appConfig))
app.Post("/v1/sound-generation",
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_SOUND_GENERATION)),
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.ElevenLabsSoundGenerationRequest) }),
elevenlabs.SoundGenerationEndpoint(cl, ml, appConfig))
}

View file

@ -3,16 +3,22 @@ package routes
import (
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/endpoints/jina"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema"
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/pkg/model"
)
func RegisterJINARoutes(app *fiber.App,
re *middleware.RequestExtractor,
cl *config.BackendConfigLoader,
ml *model.ModelLoader,
appConfig *config.ApplicationConfig) {
// POST endpoint to mimic the reranking
app.Post("/v1/rerank", jina.JINARerankEndpoint(cl, ml, appConfig))
app.Post("/v1/rerank",
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_RERANK)),
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.JINARerankRequest) }),
jina.JINARerankEndpoint(cl, ml, appConfig))
}

View file

@ -5,13 +5,16 @@ import (
"github.com/gofiber/swagger"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/endpoints/localai"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/p2p"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/internal"
"github.com/mudler/LocalAI/pkg/model"
)
func RegisterLocalAIRoutes(router *fiber.App,
requestExtractor *middleware.RequestExtractor,
cl *config.BackendConfigLoader,
ml *model.ModelLoader,
appConfig *config.ApplicationConfig,
@ -33,8 +36,18 @@ func RegisterLocalAIRoutes(router *fiber.App,
router.Get("/models/jobs", modelGalleryEndpointService.GetAllStatusEndpoint())
}
router.Post("/tts", localai.TTSEndpoint(cl, ml, appConfig))
router.Post("/vad", localai.VADEndpoint(cl, ml, appConfig))
router.Post("/tts",
requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TTS)),
requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.TTSRequest) }),
localai.TTSEndpoint(cl, ml, appConfig))
vadChain := []fiber.Handler{
requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_VAD)),
requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.VADRequest) }),
localai.VADEndpoint(cl, ml, appConfig),
}
router.Post("/vad", vadChain...)
router.Post("/v1/vad", vadChain...)
// Stores
sl := model.NewModelLoader("")
@ -47,10 +60,14 @@ func RegisterLocalAIRoutes(router *fiber.App,
router.Get("/metrics", localai.LocalAIMetricsEndpoint())
}
// Experimental Backend Statistics Module
// Backend Statistics Module
// TODO: Should these use standard middlewares? Refactor later, they are extremely simple.
backendMonitorService := services.NewBackendMonitorService(ml, cl, appConfig) // Split out for now
router.Get("/backend/monitor", localai.BackendMonitorEndpoint(backendMonitorService))
router.Post("/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitorService))
// The v1/* urls are exactly the same as above - makes local e2e testing easier if they are registered.
router.Get("/v1/backend/monitor", localai.BackendMonitorEndpoint(backendMonitorService))
router.Post("/v1/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitorService))
// p2p
if p2p.IsP2PEnabled() {
@ -67,6 +84,9 @@ func RegisterLocalAIRoutes(router *fiber.App,
router.Get("/system", localai.SystemInformations(ml, appConfig))
// misc
router.Post("/v1/tokenize", localai.TokenizeEndpoint(cl, ml, appConfig))
router.Post("/v1/tokenize",
requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TOKENIZE)),
requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.TokenizeRequest) }),
localai.TokenizeEndpoint(cl, ml, appConfig))
}

View file

@ -3,51 +3,50 @@ package routes
import (
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/application"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/endpoints/localai"
"github.com/mudler/LocalAI/core/http/endpoints/openai"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema"
)
func RegisterOpenAIRoutes(app *fiber.App,
re *middleware.RequestExtractor,
application *application.Application) {
// openAI compatible API endpoint
// chat
app.Post("/v1/chat/completions",
openai.ChatEndpoint(
application.BackendLoader(),
application.ModelLoader(),
application.TemplatesEvaluator(),
application.ApplicationConfig(),
),
)
app.Post("/chat/completions",
openai.ChatEndpoint(
application.BackendLoader(),
application.ModelLoader(),
application.TemplatesEvaluator(),
application.ApplicationConfig(),
),
)
chatChain := []fiber.Handler{
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_CHAT)),
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
re.SetOpenAIRequest,
openai.ChatEndpoint(application.BackendLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig()),
}
app.Post("/v1/chat/completions", chatChain...)
app.Post("/chat/completions", chatChain...)
// edit
app.Post("/v1/edits",
openai.EditEndpoint(
application.BackendLoader(),
application.ModelLoader(),
application.TemplatesEvaluator(),
application.ApplicationConfig(),
),
)
editChain := []fiber.Handler{
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_EDIT)),
re.BuildConstantDefaultModelNameMiddleware("gpt-4o"),
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
re.SetOpenAIRequest,
openai.EditEndpoint(application.BackendLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig()),
}
app.Post("/v1/edits", editChain...)
app.Post("/edits", editChain...)
app.Post("/edits",
openai.EditEndpoint(
application.BackendLoader(),
application.ModelLoader(),
application.TemplatesEvaluator(),
application.ApplicationConfig(),
),
)
// completion
completionChain := []fiber.Handler{
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_COMPLETION)),
re.BuildConstantDefaultModelNameMiddleware("gpt-4o"),
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
re.SetOpenAIRequest,
openai.CompletionEndpoint(application.BackendLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig()),
}
app.Post("/v1/completions", completionChain...)
app.Post("/completions", completionChain...)
app.Post("/v1/engines/:model/completions", completionChain...)
// assistant
app.Get("/v1/assistants", openai.ListAssistantsEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
@ -81,45 +80,37 @@ func RegisterOpenAIRoutes(app *fiber.App,
app.Get("/v1/files/:file_id/content", openai.GetFilesContentsEndpoint(application.BackendLoader(), application.ApplicationConfig()))
app.Get("/files/:file_id/content", openai.GetFilesContentsEndpoint(application.BackendLoader(), application.ApplicationConfig()))
// completion
app.Post("/v1/completions",
openai.CompletionEndpoint(
application.BackendLoader(),
application.ModelLoader(),
application.TemplatesEvaluator(),
application.ApplicationConfig(),
),
)
app.Post("/completions",
openai.CompletionEndpoint(
application.BackendLoader(),
application.ModelLoader(),
application.TemplatesEvaluator(),
application.ApplicationConfig(),
),
)
app.Post("/v1/engines/:model/completions",
openai.CompletionEndpoint(
application.BackendLoader(),
application.ModelLoader(),
application.TemplatesEvaluator(),
application.ApplicationConfig(),
),
)
// embeddings
app.Post("/v1/embeddings", openai.EmbeddingsEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
app.Post("/embeddings", openai.EmbeddingsEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
app.Post("/v1/engines/:model/embeddings", openai.EmbeddingsEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
embeddingChain := []fiber.Handler{
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_EMBEDDINGS)),
re.BuildConstantDefaultModelNameMiddleware("gpt-4o"),
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
re.SetOpenAIRequest,
openai.EmbeddingsEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()),
}
app.Post("/v1/embeddings", embeddingChain...)
app.Post("/embeddings", embeddingChain...)
app.Post("/v1/engines/:model/embeddings", embeddingChain...)
// audio
app.Post("/v1/audio/transcriptions", openai.TranscriptEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
app.Post("/v1/audio/speech", localai.TTSEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
app.Post("/v1/audio/transcriptions",
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TRANSCRIPT)),
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
re.SetOpenAIRequest,
openai.TranscriptEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()),
)
app.Post("/v1/audio/speech",
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TTS)),
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.TTSRequest) }),
localai.TTSEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
// images
app.Post("/v1/images/generations", openai.ImageEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
app.Post("/v1/images/generations",
re.BuildConstantDefaultModelNameMiddleware("stablediffusion"),
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
re.SetOpenAIRequest,
openai.ImageEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
if application.ApplicationConfig().ImageDir != "" {
app.Static("/generated-images", application.ApplicationConfig().ImageDir)

View file

@ -3,6 +3,7 @@ package schema
type ElevenLabsTTSRequest struct {
Text string `json:"text" yaml:"text"`
ModelID string `json:"model_id" yaml:"model_id"`
LanguageCode string `json:"language_code" yaml:"language_code"`
}
type ElevenLabsSoundGenerationRequest struct {
@ -12,3 +13,17 @@ type ElevenLabsSoundGenerationRequest struct {
Temperature *float32 `json:"prompt_influence,omitempty" yaml:"prompt_influence,omitempty"`
DoSample *bool `json:"do_sample,omitempty" yaml:"do_sample,omitempty"`
}
func (elttsr *ElevenLabsTTSRequest) ModelName(s *string) string {
if s != nil {
elttsr.ModelID = *s
}
return elttsr.ModelID
}
func (elsgr *ElevenLabsSoundGenerationRequest) ModelName(s *string) string {
if s != nil {
elsgr.ModelID = *s
}
return elsgr.ModelID
}

View file

@ -2,10 +2,11 @@ package schema
// RerankRequest defines the structure of the request payload
type JINARerankRequest struct {
Model string `json:"model"`
BasicModelRequest
Query string `json:"query"`
Documents []string `json:"documents"`
TopN int `json:"top_n"`
Backend string `json:"backend"`
}
// DocumentResult represents a single document result

View file

@ -6,11 +6,11 @@ import (
)
type BackendMonitorRequest struct {
Model string `json:"model" yaml:"model"`
BasicModelRequest
}
type TokenMetricsRequest struct {
Model string `json:"model" yaml:"model"`
BasicModelRequest
}
type BackendMonitorResponse struct {
@ -26,7 +26,7 @@ type GalleryResponse struct {
// @Description TTS request body
type TTSRequest struct {
Model string `json:"model" yaml:"model"` // model name or full path
BasicModelRequest
Input string `json:"input" yaml:"input"` // text input
Voice string `json:"voice" yaml:"voice"` // voice audio file or speaker id
Backend string `json:"backend" yaml:"backend"`
@ -36,10 +36,19 @@ type TTSRequest struct {
// @Description VAD request body
type VADRequest struct {
Model string `json:"model" yaml:"model"` // model name or full path
BasicModelRequest
Audio []float32 `json:"audio" yaml:"audio"` // model name or full path
}
type VADSegment struct {
Start float32 `json:"start" yaml:"start"`
End float32 `json:"end" yaml:"end"`
}
type VADResponse struct {
Segments []VADSegment `json:"segments" yaml:"segments"`
}
type StoresSet struct {
Store string `json:"store,omitempty" yaml:"store,omitempty"`

View file

@ -3,7 +3,7 @@ package schema
type PredictionOptions struct {
// Also part of the OpenAI official spec
Model string `json:"model" yaml:"model"`
BasicModelRequest `yaml:",inline"`
// Also part of the OpenAI official spec
Language string `json:"language"`

22
core/schema/request.go Normal file
View file

@ -0,0 +1,22 @@
package schema
// This file and type represent a generic request to LocalAI - as opposed to requests to LocalAI-specific endpoints, which live in localai.go
type LocalAIRequest interface {
ModelName(*string) string
}
type BasicModelRequest struct {
Model string `json:"model" yaml:"model"`
// TODO: Should this also include the following fields from the OpenAI side of the world?
// If so, changes should be made to core/http/middleware/request.go to match
// Context context.Context `json:"-"`
// Cancel context.CancelFunc `json:"-"`
}
func (bmr *BasicModelRequest) ModelName(s *string) string {
if s != nil {
bmr.Model = *s
}
return bmr.Model
}

View file

@ -1,8 +1,8 @@
package schema
type TokenizeRequest struct {
BasicModelRequest
Content string `json:"content"`
Model string `json:"model"`
}
type TokenizeResponse struct {

View file

@ -2,7 +2,7 @@ package schema
import "time"
type Segment struct {
type TranscriptionSegment struct {
Id int `json:"id"`
Start time.Duration `json:"start"`
End time.Duration `json:"end"`
@ -11,6 +11,6 @@ type Segment struct {
}
type TranscriptionResult struct {
Segments []Segment `json:"segments"`
Segments []TranscriptionSegment `json:"segments"`
Text string `json:"text"`
}

View file

@ -49,3 +49,15 @@ func ListModels(bcl *config.BackendConfigLoader, ml *model.ModelLoader, filter c
return dataModels, nil
}
func CheckIfModelExists(bcl *config.BackendConfigLoader, ml *model.ModelLoader, modelName string, looseFilePolicy LooseFilePolicy) (bool, error) {
filter, err := config.BuildNameFilterFn(modelName)
if err != nil {
return false, err
}
models, err := ListModels(bcl, ml, filter, looseFilePolicy)
if err != nil {
return false, err
}
return (len(models) > 0), nil
}

2
go.mod
View file

@ -239,7 +239,7 @@ require (
github.com/moby/sys/sequential v0.5.0 // indirect
github.com/moby/term v0.5.0 // indirect
github.com/mr-tron/base58 v1.2.0 // indirect
github.com/mudler/go-piper v0.0.0-20241022074816-3854e0221ffb
github.com/mudler/go-piper v0.0.0-20241023091659-2494246fd9fc
github.com/mudler/water v0.0.0-20221010214108-8c7313014ce0 // indirect
github.com/muesli/reflow v0.3.0 // indirect
github.com/muesli/termenv v0.15.2 // indirect

2
go.sum
View file

@ -524,6 +524,8 @@ github.com/mudler/edgevpn v0.29.0 h1:SEkVyjXL6P8szUZFlL8W1EYBxvFsEIFvXlXcRfGrXYU
github.com/mudler/edgevpn v0.29.0/go.mod h1:+kSy9b44eo97PnJ3fOnTkcTgxNXdgJBcd2bopx4leto=
github.com/mudler/go-piper v0.0.0-20241022074816-3854e0221ffb h1:5qcuxQEpAqeV4ftV5nUt3/hB/RoTXq3MaaauOAedyXo=
github.com/mudler/go-piper v0.0.0-20241022074816-3854e0221ffb/go.mod h1:O7SwdSWMilAWhBZMK9N9Y/oBDyMMzshE3ju8Xkexwig=
github.com/mudler/go-piper v0.0.0-20241023091659-2494246fd9fc h1:RxwneJl1VgvikiX28EkpdAyL4yQVnJMrbquKospjHyA=
github.com/mudler/go-piper v0.0.0-20241023091659-2494246fd9fc/go.mod h1:O7SwdSWMilAWhBZMK9N9Y/oBDyMMzshE3ju8Xkexwig=
github.com/mudler/go-processmanager v0.0.0-20240820160718-8b802d3ecf82 h1:FVT07EI8njvsD4tC2Hw8Xhactp5AWhsQWD4oTeQuSAU=
github.com/mudler/go-processmanager v0.0.0-20240820160718-8b802d3ecf82/go.mod h1:Urp7LG5jylKoDq0663qeBh0pINGcRl35nXdKx82PSoU=
github.com/mudler/go-stable-diffusion v0.0.0-20240429204715-4a3cd6aeae6f h1:cxtMSRkUfy+mjIQ3yMrU0txwQ4It913NEN4m1H8WWgo=

View file

@ -460,7 +460,7 @@ func (ml *ModelLoader) ListAvailableBackends(assetdir string) ([]string, error)
func (ml *ModelLoader) backendLoader(opts ...Option) (client grpc.Backend, err error) {
o := NewOptions(opts...)
log.Info().Msgf("Loading model '%s' with backend %s", o.modelID, o.backendString)
log.Info().Str("modelID", o.modelID).Str("backend", o.backendString).Str("o.model", o.model).Msg("BackendLoader starting")
backend := strings.ToLower(o.backendString)
if realBackend, exists := Aliases[backend]; exists {

View file

@ -56,6 +56,14 @@ func WithBackendString(backend string) Option {
}
}
func WithDefaultBackendString(backend string) Option {
return func(o *Options) {
if o.backendString == "" {
o.backendString = backend
}
}
}
func WithModel(modelFile string) Option {
return func(o *Options) {
o.model = modelFile

View file

@ -43,7 +43,7 @@ var _ = BeforeSuite(func() {
apiEndpoint = "http://localhost:" + apiPort + "/v1" // So that other tests can reference this value safely.
defaultConfig.BaseURL = apiEndpoint
} else {
fmt.Println("Default ", apiEndpoint)
GinkgoWriter.Printf("docker apiEndpoint set from env: %q\n", apiEndpoint)
defaultConfig = openai.DefaultConfig(apiKey)
defaultConfig.BaseURL = apiEndpoint
}
@ -95,10 +95,11 @@ func startDockerImage() {
PortBindings: map[docker.Port][]docker.PortBinding{
"8080/tcp": []docker.PortBinding{{HostPort: apiPort}},
},
Env: []string{"MODELS_PATH=/models", "DEBUG=true", "THREADS=" + fmt.Sprint(proc)},
Env: []string{"MODELS_PATH=/models", "DEBUG=true", "THREADS=" + fmt.Sprint(proc), "LOCALAI_SINGLE_ACTIVE_BACKEND=true"},
Mounts: []string{md + ":/models"},
}
GinkgoWriter.Printf("Launching Docker Container %q\n%+v\n", containerImageTag, options)
r, err := pool.RunWithOptions(options)
Expect(err).To(Not(HaveOccurred()))

View file

@ -121,14 +121,13 @@ var _ = Describe("E2E test", func() {
Context("images", func() {
It("correctly", func() {
resp, err := client.CreateImage(context.TODO(),
openai.ImageRequest{
req := openai.ImageRequest{
Prompt: "test",
Quality: "1",
Size: openai.CreateImageSize256x256,
},
)
Expect(err).ToNot(HaveOccurred())
Size: openai.CreateImageSize512x512,
}
resp, err := client.CreateImage(context.TODO(), req)
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("error sending image request %+v", req))
Expect(len(resp.Data)).To(Equal(1), fmt.Sprint(resp))
Expect(resp.Data[0].URL).To(ContainSubstring("png"), fmt.Sprint(resp.Data[0].URL))
})
@ -232,13 +231,42 @@ var _ = Describe("E2E test", func() {
Expect(resp.Text).To(ContainSubstring("This is the"), fmt.Sprint(resp.Text))
})
})
Context("vad", func() {
It("correctly", func() {
modelName := "silero-vad"
req := schema.VADRequest{
BasicModelRequest: schema.BasicModelRequest{
Model: modelName,
},
Audio: SampleVADAudio, // Use hardcoded sample data for now.
}
serialized, err := json.Marshal(req)
Expect(err).To(BeNil())
Expect(serialized).ToNot(BeNil())
vadEndpoint := apiEndpoint + "/vad"
resp, err := http.Post(vadEndpoint, "application/json", bytes.NewReader(serialized))
Expect(err).To(BeNil())
Expect(resp).ToNot(BeNil())
body, err := io.ReadAll(resp.Body)
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(200))
deserializedResponse := schema.VADResponse{}
err = json.Unmarshal(body, &deserializedResponse)
Expect(err).To(BeNil())
Expect(deserializedResponse).ToNot(BeZero())
Expect(deserializedResponse.Segments).ToNot(BeZero())
})
})
Context("reranker", func() {
It("correctly", func() {
modelName := "jina-reranker-v1-base-en"
req := schema.JINARerankRequest{
BasicModelRequest: schema.BasicModelRequest{
Model: modelName,
},
Query: "Organic skincare products for sensitive skin",
Documents: []string{
"Eco-friendly kitchenware for modern homes",

File diff suppressed because it is too large Load diff