feat(model-list): be consistent, skip known files from listing (#2760)

fix(model-list): be consistent, skip known files from listing

This changeset does two things:

- Removes the dependency of listing models from the OpenAI schema.
- Tries to reduce confusion between ListModels() in model loader and in
  the service - now there is only one ListModels which is in services
and does not depend anymore on the OpenAI schema
- The OpenAI-schema functions were moved nearby the OpenAI specific
  endpoints that needs the schema
- Drops the ListModel Service structure as there was no real need for
  it.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto 2024-07-10 15:28:39 +02:00 committed by GitHub
parent 28c6daf916
commit 59ef426fbf
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
22 changed files with 97 additions and 70 deletions

View file

@ -28,7 +28,7 @@ func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfi
return err
}
modelFile, err := fiberContext.ModelFromContext(c, ml, input.ModelID, false)
modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.ModelID, false)
if err != nil {
modelFile = input.ModelID
log.Warn().Msgf("Model not found in context: %s", input.ModelID)

View file

@ -28,7 +28,7 @@ func JINARerankEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, a
return err
}
modelFile, err := fiberContext.ModelFromContext(c, ml, input.Model, false)
modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, false)
if err != nil {
modelFile = input.Model
log.Warn().Msgf("Model not found in context: %s", input.Model)

View file

@ -29,7 +29,7 @@ func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfi
return err
}
modelFile, err := fiberContext.ModelFromContext(c, ml, input.Model, false)
modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, false)
if err != nil {
modelFile = input.Model
log.Warn().Msgf("Model not found in context: %s", input.Model)

View file

@ -5,6 +5,7 @@ import (
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/p2p"
"github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/internal"
"github.com/mudler/LocalAI/pkg/model"
)
@ -12,7 +13,7 @@ import (
func WelcomeEndpoint(appConfig *config.ApplicationConfig,
cl *config.BackendConfigLoader, ml *model.ModelLoader, modelStatus func() (map[string]string, map[string]string)) func(*fiber.Ctx) error {
return func(c *fiber.Ctx) error {
models, _ := ml.ListModels()
models, _ := services.ListModels(cl, ml, "", true)
backendConfigs := cl.GetAllBackendConfigs()
galleryConfigs := map[string]*gallery.Config{}

View file

@ -11,6 +11,7 @@ import (
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/services"
model "github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/utils"
"github.com/rs/zerolog/log"
@ -79,7 +80,7 @@ func CreateAssistantEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoad
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Cannot parse JSON"})
}
if !modelExists(ml, request.Model) {
if !modelExists(cl, ml, request.Model) {
log.Warn().Msgf("Model: %s was not found in list of models.", request.Model)
return c.Status(fiber.StatusBadRequest).SendString("Model " + request.Model + " not found")
}
@ -213,9 +214,9 @@ func filterAssistantsAfterID(assistants []Assistant, id string) []Assistant {
return filteredAssistants
}
func modelExists(ml *model.ModelLoader, modelName string) (found bool) {
func modelExists(cl *config.BackendConfigLoader, ml *model.ModelLoader, modelName string) (found bool) {
found = false
models, err := ml.ListModels()
models, err := services.ListModels(cl, ml, "", true)
if err != nil {
return
}

View file

@ -159,7 +159,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
}
return func(c *fiber.Ctx) error {
modelFile, input, err := readRequest(c, ml, startupOptions, true)
modelFile, input, err := readRequest(c, cl, ml, startupOptions, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}

View file

@ -57,7 +57,7 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, a
}
return func(c *fiber.Ctx) error {
modelFile, input, err := readRequest(c, ml, appConfig, true)
modelFile, input, err := readRequest(c, cl, ml, appConfig, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}

View file

@ -18,7 +18,7 @@ import (
func EditEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
modelFile, input, err := readRequest(c, ml, appConfig, true)
modelFile, input, err := readRequest(c, cl, ml, appConfig, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}

View file

@ -23,7 +23,7 @@ import (
// @Router /v1/embeddings [post]
func EmbeddingsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
model, input, err := readRequest(c, ml, appConfig, true)
model, input, err := readRequest(c, cl, ml, appConfig, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}

View file

@ -66,7 +66,7 @@ func downloadFile(url string) (string, error) {
// @Router /v1/images/generations [post]
func ImageEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
m, input, err := readRequest(c, ml, appConfig, false)
m, input, err := readRequest(c, cl, ml, appConfig, false)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}

View file

@ -2,15 +2,17 @@ package openai
import (
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services"
model "github.com/mudler/LocalAI/pkg/model"
)
// ListModelsEndpoint is the OpenAI Models API endpoint https://platform.openai.com/docs/api-reference/models
// @Summary List and describe the various models available in the API.
// @Success 200 {object} schema.ModelsDataResponse "Response"
// @Router /v1/models [get]
func ListModelsEndpoint(lms *services.ListModelsService) func(ctx *fiber.Ctx) error {
func ListModelsEndpoint(bcl *config.BackendConfigLoader, ml *model.ModelLoader) func(ctx *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
// If blank, no filter is applied.
filter := c.Query("filter")
@ -18,7 +20,7 @@ func ListModelsEndpoint(lms *services.ListModelsService) func(ctx *fiber.Ctx) er
// By default, exclude any loose files that are already referenced by a configuration file.
excludeConfigured := c.QueryBool("excludeConfigured", true)
dataModels, err := lms.ListModels(filter, excludeConfigured)
dataModels, err := modelList(bcl, ml, filter, excludeConfigured)
if err != nil {
return err
}
@ -28,3 +30,20 @@ func ListModelsEndpoint(lms *services.ListModelsService) func(ctx *fiber.Ctx) er
})
}
}
func modelList(bcl *config.BackendConfigLoader, ml *model.ModelLoader, filter string, excludeConfigured bool) ([]schema.OpenAIModel, error) {
models, err := services.ListModels(bcl, ml, filter, excludeConfigured)
if err != nil {
return nil, err
}
dataModels := []schema.OpenAIModel{}
// Then iterate through the loose files:
for _, m := range models {
dataModels = append(dataModels, schema.OpenAIModel{ID: m, Object: "model"})
}
return dataModels, nil
}

View file

@ -15,7 +15,7 @@ import (
"github.com/rs/zerolog/log"
)
func readRequest(c *fiber.Ctx, ml *model.ModelLoader, o *config.ApplicationConfig, firstModel bool) (string, *schema.OpenAIRequest, error) {
func readRequest(c *fiber.Ctx, cl *config.BackendConfigLoader, ml *model.ModelLoader, o *config.ApplicationConfig, firstModel bool) (string, *schema.OpenAIRequest, error) {
input := new(schema.OpenAIRequest)
// Get input data from the request body
@ -31,7 +31,7 @@ func readRequest(c *fiber.Ctx, ml *model.ModelLoader, o *config.ApplicationConfi
log.Debug().Msgf("Request received: %s", string(received))
modelFile, err := fiberContext.ModelFromContext(c, ml, input.Model, firstModel)
modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, firstModel)
return modelFile, input, err
}

View file

@ -25,7 +25,7 @@ import (
// @Router /v1/audio/transcriptions [post]
func TranscriptEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
m, input, err := readRequest(c, ml, appConfig, false)
m, input, err := readRequest(c, cl, ml, appConfig, false)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}