diff --git a/core/application.go b/core/application.go
index 78a7af9e..e4efbdd0 100644
--- a/core/application.go
+++ b/core/application.go
@@ -28,7 +28,6 @@ type Application struct {
 	// LocalAI System Services
 	BackendMonitorService *services.BackendMonitorService
 	GalleryService        *services.GalleryService
-	ListModelsService     *services.ListModelsService
 	LocalAIMetricsService *services.LocalAIMetricsService
 	// OpenAIService *services.OpenAIService
 }
diff --git a/core/http/ctx/fiber.go b/core/http/ctx/fiber.go
index d298b290..94059847 100644
--- a/core/http/ctx/fiber.go
+++ b/core/http/ctx/fiber.go
@@ -5,6 +5,8 @@ import (
 	"strings"

 	"github.com/gofiber/fiber/v2"
+	"github.com/mudler/LocalAI/core/config"
+	"github.com/mudler/LocalAI/core/services"
 	"github.com/mudler/LocalAI/pkg/model"
 	"github.com/rs/zerolog/log"
 )
@@ -13,7 +15,7 @@ import (
 // If no model is specified, it will take the first available
 // Takes a model string as input which should be the one received from the user request.
 // It returns the model name resolved from the context and an error if any.
-func ModelFromContext(ctx *fiber.Ctx, loader *model.ModelLoader, modelInput string, firstModel bool) (string, error) {
+func ModelFromContext(ctx *fiber.Ctx, cl *config.BackendConfigLoader, loader *model.ModelLoader, modelInput string, firstModel bool) (string, error) {
 	if ctx.Params("model") != "" {
 		modelInput = ctx.Params("model")
 	}
@@ -24,7 +26,7 @@ func ModelFromContext(ctx *fiber.Ctx, loader *model.ModelLoader, modelInput stri

 	// If no model was specified, take the first available
 	if modelInput == "" && !bearerExists && firstModel {
-		models, _ := loader.ListModels()
+		models, _ := services.ListModels(cl, loader, "", true)
 		if len(models) > 0 {
 			modelInput = models[0]
 			log.Debug().Msgf("No model specified, using: %s", modelInput)
diff --git a/core/http/endpoints/elevenlabs/tts.go b/core/http/endpoints/elevenlabs/tts.go
index 12da7b9b..bb6901be 100644
--- a/core/http/endpoints/elevenlabs/tts.go
+++ b/core/http/endpoints/elevenlabs/tts.go
@@ -28,7 +28,7 @@ func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfi
 		return err
 	}

-	modelFile, err := fiberContext.ModelFromContext(c, ml, input.ModelID, false)
+	modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.ModelID, false)
 	if err != nil {
 		modelFile = input.ModelID
 		log.Warn().Msgf("Model not found in context: %s", input.ModelID)
diff --git a/core/http/endpoints/jina/rerank.go b/core/http/endpoints/jina/rerank.go
index 383dcc5e..ddeee745 100644
--- a/core/http/endpoints/jina/rerank.go
+++ b/core/http/endpoints/jina/rerank.go
@@ -28,7 +28,7 @@ func JINARerankEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, a
 		return err
 	}

-	modelFile, err := fiberContext.ModelFromContext(c, ml, input.Model, false)
+	modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, false)
 	if err != nil {
 		modelFile = input.Model
 		log.Warn().Msgf("Model not found in context: %s", input.Model)
diff --git a/core/http/endpoints/localai/tts.go b/core/http/endpoints/localai/tts.go
index 3ae2eea5..ca3f58bd 100644
--- a/core/http/endpoints/localai/tts.go
+++ b/core/http/endpoints/localai/tts.go
@@ -29,7 +29,7 @@ func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfi
 		return err
 	}

-	modelFile, err := fiberContext.ModelFromContext(c, ml, input.Model, false)
+	modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, false)
 	if err != nil {
 		modelFile = input.Model
 		log.Warn().Msgf("Model not found in context: %s", input.Model)
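Every caller of `ModelFromContext` now has to hand in the `BackendConfigLoader` alongside the `ModelLoader`, because the default-model fallback goes through `services.ListModels` rather than the loader alone. A minimal sketch of a call site after this change — the handler and route are hypothetical, only the signatures are taken from the diff:

```go
package example

import (
	"github.com/gofiber/fiber/v2"

	"github.com/mudler/LocalAI/core/config"
	fiberContext "github.com/mudler/LocalAI/core/http/ctx"
	"github.com/mudler/LocalAI/pkg/model"
)

func resolveModelHandler(cl *config.BackendConfigLoader, ml *model.ModelLoader) fiber.Handler {
	return func(c *fiber.Ctx) error {
		// firstModel=true: when the request names no model, fall back to the
		// first entry returned by services.ListModels (configured models first).
		modelFile, err := fiberContext.ModelFromContext(c, cl, ml, c.Query("model"), true)
		if err != nil {
			return err
		}
		return c.SendString(modelFile)
	}
}
```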
diff --git a/core/http/endpoints/localai/welcome.go b/core/http/endpoints/localai/welcome.go
index 34a2d975..b9c7a573 100644
--- a/core/http/endpoints/localai/welcome.go
+++ b/core/http/endpoints/localai/welcome.go
@@ -5,6 +5,7 @@ import (
 	"github.com/mudler/LocalAI/core/config"
 	"github.com/mudler/LocalAI/core/gallery"
 	"github.com/mudler/LocalAI/core/p2p"
+	"github.com/mudler/LocalAI/core/services"
 	"github.com/mudler/LocalAI/internal"
 	"github.com/mudler/LocalAI/pkg/model"
 )
@@ -12,7 +13,7 @@ import (
 func WelcomeEndpoint(appConfig *config.ApplicationConfig,
 	cl *config.BackendConfigLoader, ml *model.ModelLoader, modelStatus func() (map[string]string, map[string]string)) func(*fiber.Ctx) error {
 	return func(c *fiber.Ctx) error {
-		models, _ := ml.ListModels()
+		models, _ := services.ListModels(cl, ml, "", true)
 		backendConfigs := cl.GetAllBackendConfigs()

 		galleryConfigs := map[string]*gallery.Config{}
diff --git a/core/http/endpoints/openai/assistant.go b/core/http/endpoints/openai/assistant.go
index 4882eeaf..ba2ebcde 100644
--- a/core/http/endpoints/openai/assistant.go
+++ b/core/http/endpoints/openai/assistant.go
@@ -11,6 +11,7 @@ import (

 	"github.com/gofiber/fiber/v2"
 	"github.com/mudler/LocalAI/core/config"
+	"github.com/mudler/LocalAI/core/services"
 	model "github.com/mudler/LocalAI/pkg/model"
 	"github.com/mudler/LocalAI/pkg/utils"
 	"github.com/rs/zerolog/log"
@@ -79,7 +80,7 @@ func CreateAssistantEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoad
 			return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Cannot parse JSON"})
 		}

-		if !modelExists(ml, request.Model) {
+		if !modelExists(cl, ml, request.Model) {
 			log.Warn().Msgf("Model: %s was not found in list of models.", request.Model)
 			return c.Status(fiber.StatusBadRequest).SendString("Model " + request.Model + " not found")
 		}
@@ -213,9 +214,9 @@ func filterAssistantsAfterID(assistants []Assistant, id string) []Assistant {
 	return filteredAssistants
 }

-func modelExists(ml *model.ModelLoader, modelName string) (found bool) {
+func modelExists(cl *config.BackendConfigLoader, ml *model.ModelLoader, modelName string) (found bool) {
 	found = false
-	models, err := ml.ListModels()
+	models, err := services.ListModels(cl, ml, "", true)
 	if err != nil {
 		return
 	}
diff --git a/core/http/endpoints/openai/chat.go b/core/http/endpoints/openai/chat.go
index 1317ee07..763e3f69 100644
--- a/core/http/endpoints/openai/chat.go
+++ b/core/http/endpoints/openai/chat.go
@@ -159,7 +159,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
 	}

 	return func(c *fiber.Ctx) error {
-		modelFile, input, err := readRequest(c, ml, startupOptions, true)
+		modelFile, input, err := readRequest(c, cl, ml, startupOptions, true)
 		if err != nil {
 			return fmt.Errorf("failed reading parameters from request:%w", err)
 		}
diff --git a/core/http/endpoints/openai/completion.go b/core/http/endpoints/openai/completion.go
index 5eedfaf3..b087cc5f 100644
--- a/core/http/endpoints/openai/completion.go
+++ b/core/http/endpoints/openai/completion.go
@@ -57,7 +57,7 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, a
 	}

 	return func(c *fiber.Ctx) error {
-		modelFile, input, err := readRequest(c, ml, appConfig, true)
+		modelFile, input, err := readRequest(c, cl, ml, appConfig, true)
 		if err != nil {
 			return fmt.Errorf("failed reading parameters from request:%w", err)
 		}
diff --git a/core/http/endpoints/openai/edit.go b/core/http/endpoints/openai/edit.go
index a5af12c2..bb43ac3b 100644
--- a/core/http/endpoints/openai/edit.go
+++ b/core/http/endpoints/openai/edit.go
@@ -18,7 +18,7 @@ import (
 func EditEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
 	return func(c *fiber.Ctx) error {
-		modelFile, input, err := readRequest(c, ml, appConfig, true)
+		modelFile, input, err := readRequest(c, cl, ml, appConfig, true)
 		if err != nil {
 			return fmt.Errorf("failed reading parameters from request:%w", err)
 		}
diff --git a/core/http/endpoints/openai/embeddings.go b/core/http/endpoints/openai/embeddings.go
index de7ea1c6..e247d84e 100644
--- a/core/http/endpoints/openai/embeddings.go
+++ b/core/http/endpoints/openai/embeddings.go
@@ -23,7 +23,7 @@ import (
 // @Router /v1/embeddings [post]
 func EmbeddingsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
 	return func(c *fiber.Ctx) error {
-		model, input, err := readRequest(c, ml, appConfig, true)
+		model, input, err := readRequest(c, cl, ml, appConfig, true)
 		if err != nil {
 			return fmt.Errorf("failed reading parameters from request:%w", err)
 		}
diff --git a/core/http/endpoints/openai/image.go b/core/http/endpoints/openai/image.go
index 27c11f53..6c76ba84 100644
--- a/core/http/endpoints/openai/image.go
+++ b/core/http/endpoints/openai/image.go
@@ -66,7 +66,7 @@ func downloadFile(url string) (string, error) {
 // @Router /v1/images/generations [post]
 func ImageEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
 	return func(c *fiber.Ctx) error {
-		m, input, err := readRequest(c, ml, appConfig, false)
+		m, input, err := readRequest(c, cl, ml, appConfig, false)
 		if err != nil {
 			return fmt.Errorf("failed reading parameters from request:%w", err)
 		}
diff --git a/core/http/endpoints/openai/list.go b/core/http/endpoints/openai/list.go
index ba6bd1d7..d446b100 100644
--- a/core/http/endpoints/openai/list.go
+++ b/core/http/endpoints/openai/list.go
@@ -2,15 +2,17 @@ package openai

 import (
 	"github.com/gofiber/fiber/v2"
+	"github.com/mudler/LocalAI/core/config"
 	"github.com/mudler/LocalAI/core/schema"
 	"github.com/mudler/LocalAI/core/services"
+	model "github.com/mudler/LocalAI/pkg/model"
 )

 // ListModelsEndpoint is the OpenAI Models API endpoint https://platform.openai.com/docs/api-reference/models
 // @Summary List and describe the various models available in the API.
 // @Success 200 {object} schema.ModelsDataResponse "Response"
 // @Router /v1/models [get]
-func ListModelsEndpoint(lms *services.ListModelsService) func(ctx *fiber.Ctx) error {
+func ListModelsEndpoint(bcl *config.BackendConfigLoader, ml *model.ModelLoader) func(ctx *fiber.Ctx) error {
 	return func(c *fiber.Ctx) error {
 		// If blank, no filter is applied.
 		filter := c.Query("filter")
@@ -18,7 +20,7 @@ func ListModelsEndpoint(lms *services.ListModelsService) func(ctx *fiber.Ctx) er
 		// By default, exclude any loose files that are already referenced by a configuration file.
 		excludeConfigured := c.QueryBool("excludeConfigured", true)

-		dataModels, err := lms.ListModels(filter, excludeConfigured)
+		dataModels, err := modelList(bcl, ml, filter, excludeConfigured)
 		if err != nil {
 			return err
 		}
@@ -28,3 +30,20 @@ func ListModelsEndpoint(lms *services.ListModelsService) func(ctx *fiber.Ctx) er
 		})
 	}
 }
+
+func modelList(bcl *config.BackendConfigLoader, ml *model.ModelLoader, filter string, excludeConfigured bool) ([]schema.OpenAIModel, error) {
+
+	models, err := services.ListModels(bcl, ml, filter, excludeConfigured)
+	if err != nil {
+		return nil, err
+	}
+
+	dataModels := []schema.OpenAIModel{}
+
+	// Then iterate through the loose files:
+	for _, m := range models {
+		dataModels = append(dataModels, schema.OpenAIModel{ID: m, Object: "model"})
+	}
+
+	return dataModels, nil
+}
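`ListModelsEndpoint` is now built straight from the two loaders, and the OpenAI wire format is restored locally by `modelList`. A sketch of an equivalent handler, assuming the `ModelsDataResponse`/`OpenAIModel` field shapes implied by the swagger annotation above (the model names in the comment are invented):

```go
package example

import (
	"github.com/gofiber/fiber/v2"

	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/core/schema"
	"github.com/mudler/LocalAI/core/services"
	model "github.com/mudler/LocalAI/pkg/model"
)

func listModelsHandler(bcl *config.BackendConfigLoader, ml *model.ModelLoader) fiber.Handler {
	return func(c *fiber.Ctx) error {
		// Plain strings come back from the service, e.g. ["phi-2", "loose-file.gguf"].
		names, err := services.ListModels(bcl, ml, c.Query("filter"), c.QueryBool("excludeConfigured", true))
		if err != nil {
			return err
		}
		// Re-wrap them as OpenAI model objects only at the HTTP boundary.
		data := make([]schema.OpenAIModel, 0, len(names))
		for _, name := range names {
			data = append(data, schema.OpenAIModel{ID: name, Object: "model"})
		}
		return c.JSON(schema.ModelsDataResponse{Object: "list", Data: data})
	}
}
```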
diff --git a/core/http/endpoints/openai/request.go b/core/http/endpoints/openai/request.go
index 009de4a0..a99ebea2 100644
--- a/core/http/endpoints/openai/request.go
+++ b/core/http/endpoints/openai/request.go
@@ -15,7 +15,7 @@ import (
 	"github.com/rs/zerolog/log"
 )

-func readRequest(c *fiber.Ctx, ml *model.ModelLoader, o *config.ApplicationConfig, firstModel bool) (string, *schema.OpenAIRequest, error) {
+func readRequest(c *fiber.Ctx, cl *config.BackendConfigLoader, ml *model.ModelLoader, o *config.ApplicationConfig, firstModel bool) (string, *schema.OpenAIRequest, error) {
 	input := new(schema.OpenAIRequest)

 	// Get input data from the request body
@@ -31,7 +31,7 @@ func readRequest(c *fiber.Ctx, ml *model.ModelLoader, o *config.ApplicationConfi

 	log.Debug().Msgf("Request received: %s", string(received))

-	modelFile, err := fiberContext.ModelFromContext(c, ml, input.Model, firstModel)
+	modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, firstModel)

 	return modelFile, input, err
 }
diff --git a/core/http/endpoints/openai/transcription.go b/core/http/endpoints/openai/transcription.go
index c8e447f7..4e23f804 100644
--- a/core/http/endpoints/openai/transcription.go
+++ b/core/http/endpoints/openai/transcription.go
@@ -25,7 +25,7 @@ import (
 // @Router /v1/audio/transcriptions [post]
 func TranscriptEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
 	return func(c *fiber.Ctx) error {
-		m, input, err := readRequest(c, ml, appConfig, false)
+		m, input, err := readRequest(c, cl, ml, appConfig, false)
 		if err != nil {
 			return fmt.Errorf("failed reading parameters from request:%w", err)
 		}
diff --git a/core/http/routes/openai.go b/core/http/routes/openai.go
index cb454f33..e190bc6d 100644
--- a/core/http/routes/openai.go
+++ b/core/http/routes/openai.go
@@ -5,7 +5,6 @@ import (
 	"github.com/mudler/LocalAI/core/config"
 	"github.com/mudler/LocalAI/core/http/endpoints/localai"
 	"github.com/mudler/LocalAI/core/http/endpoints/openai"
-	"github.com/mudler/LocalAI/core/services"
 	"github.com/mudler/LocalAI/pkg/model"
 )
@@ -81,8 +80,7 @@ func RegisterOpenAIRoutes(app *fiber.App,
 		app.Static("/generated-audio", appConfig.AudioDir)
 	}

-	// models
-	tmpLMS := services.NewListModelsService(ml, cl, appConfig) // TODO: once createApplication() is fully in use, reference the central instance.
-	app.Get("/v1/models", auth, openai.ListModelsEndpoint(tmpLMS))
-	app.Get("/models", auth, openai.ListModelsEndpoint(tmpLMS))
+	// List models
+	app.Get("/v1/models", auth, openai.ListModelsEndpoint(cl, ml))
+	app.Get("/models", auth, openai.ListModelsEndpoint(cl, ml))
 }
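Both routes now share the same stateless handler, so there is no per-route service instance to keep in sync. A hypothetical client-side check of the two query parameters the endpoint accepts; the base URL assumes LocalAI's default port and the filter value is invented:

```go
package main

import (
	"fmt"
	"io"
	"net/http"
)

func main() {
	for _, url := range []string{
		// filter is treated as a regular expression by services.ListModels.
		"http://localhost:8080/v1/models?filter=llama.*",
		// excludeConfigured=false also lists loose files a config already references.
		"http://localhost:8080/models?excludeConfigured=false",
	} {
		resp, err := http.Get(url)
		if err != nil {
			panic(err)
		}
		body, _ := io.ReadAll(resp.Body)
		resp.Body.Close()
		fmt.Println(string(body))
	}
}
```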
- app.Get("/v1/models", auth, openai.ListModelsEndpoint(tmpLMS)) - app.Get("/models", auth, openai.ListModelsEndpoint(tmpLMS)) + // List models + app.Get("/v1/models", auth, openai.ListModelsEndpoint(cl, ml)) + app.Get("/models", auth, openai.ListModelsEndpoint(cl, ml)) } diff --git a/core/http/routes/ui.go b/core/http/routes/ui.go index 51742b81..33706944 100644 --- a/core/http/routes/ui.go +++ b/core/http/routes/ui.go @@ -27,7 +27,6 @@ func RegisterUIRoutes(app *fiber.App, appConfig *config.ApplicationConfig, galleryService *services.GalleryService, auth func(*fiber.Ctx) error) { - tmpLMS := services.NewListModelsService(ml, cl, appConfig) // TODO: once createApplication() is fully in use, reference the central instance. // keeps the state of models that are being installed from the UI var processingModels = xsync.NewSyncedMap[string, string]() @@ -270,7 +269,7 @@ func RegisterUIRoutes(app *fiber.App, // Show the Chat page app.Get("/chat/:model", auth, func(c *fiber.Ctx) error { - backendConfigs, _ := tmpLMS.ListModels("", true) + backendConfigs, _ := services.ListModels(cl, ml, "", true) summary := fiber.Map{ "Title": "LocalAI - Chat with " + c.Params("model"), @@ -285,7 +284,7 @@ func RegisterUIRoutes(app *fiber.App, }) app.Get("/talk/", auth, func(c *fiber.Ctx) error { - backendConfigs, _ := tmpLMS.ListModels("", true) + backendConfigs, _ := services.ListModels(cl, ml, "", true) if len(backendConfigs) == 0 { // If no model is available redirect to the index which suggests how to install models @@ -295,7 +294,7 @@ func RegisterUIRoutes(app *fiber.App, summary := fiber.Map{ "Title": "LocalAI - Talk", "ModelsConfig": backendConfigs, - "Model": backendConfigs[0].ID, + "Model": backendConfigs[0], "IsP2PEnabled": p2p.IsP2PEnabled(), "Version": internal.PrintableVersion(), } @@ -306,7 +305,7 @@ func RegisterUIRoutes(app *fiber.App, app.Get("/chat/", auth, func(c *fiber.Ctx) error { - backendConfigs, _ := tmpLMS.ListModels("", true) + backendConfigs, _ := services.ListModels(cl, ml, "", true) if len(backendConfigs) == 0 { // If no model is available redirect to the index which suggests how to install models @@ -314,9 +313,9 @@ func RegisterUIRoutes(app *fiber.App, } summary := fiber.Map{ - "Title": "LocalAI - Chat with " + backendConfigs[0].ID, + "Title": "LocalAI - Chat with " + backendConfigs[0], "ModelsConfig": backendConfigs, - "Model": backendConfigs[0].ID, + "Model": backendConfigs[0], "Version": internal.PrintableVersion(), "IsP2PEnabled": p2p.IsP2PEnabled(), } diff --git a/core/http/views/chat.html b/core/http/views/chat.html index 79c39570..67d40bfd 100644 --- a/core/http/views/chat.html +++ b/core/http/views/chat.html @@ -100,10 +100,10 @@ SOFTWARE. {{ $model:=.Model}} {{ range .ModelsConfig }} - {{ if eq .ID $model }} - + {{ if eq . 
$model }} + {{ else }} - + {{ end }} {{ end }} diff --git a/core/http/views/talk.html b/core/http/views/talk.html index afb494e9..dc25d125 100644 --- a/core/http/views/talk.html +++ b/core/http/views/talk.html @@ -62,7 +62,7 @@ {{ range .ModelsConfig }} - + {{ end }} @@ -76,7 +76,7 @@ {{ range .ModelsConfig }} - + {{ end }} @@ -89,7 +89,7 @@ > {{ range .ModelsConfig }} - + {{ end }} diff --git a/core/services/list_models.go b/core/services/list_models.go index 82503252..4b578e25 100644 --- a/core/services/list_models.go +++ b/core/services/list_models.go @@ -4,34 +4,19 @@ import ( "regexp" "github.com/mudler/LocalAI/core/config" - "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/pkg/model" ) -type ListModelsService struct { - bcl *config.BackendConfigLoader - ml *model.ModelLoader - appConfig *config.ApplicationConfig -} +func ListModels(bcl *config.BackendConfigLoader, ml *model.ModelLoader, filter string, excludeConfigured bool) ([]string, error) { -func NewListModelsService(ml *model.ModelLoader, bcl *config.BackendConfigLoader, appConfig *config.ApplicationConfig) *ListModelsService { - return &ListModelsService{ - bcl: bcl, - ml: ml, - appConfig: appConfig, - } -} - -func (lms *ListModelsService) ListModels(filter string, excludeConfigured bool) ([]schema.OpenAIModel, error) { - - models, err := lms.ml.ListModels() + models, err := ml.ListFilesInModelPath() if err != nil { return nil, err } var mm map[string]interface{} = map[string]interface{}{} - dataModels := []schema.OpenAIModel{} + dataModels := []string{} var filterFn func(name string) bool @@ -50,13 +35,13 @@ func (lms *ListModelsService) ListModels(filter string, excludeConfigured bool) } // Start with the known configurations - for _, c := range lms.bcl.GetAllBackendConfigs() { + for _, c := range bcl.GetAllBackendConfigs() { if excludeConfigured { mm[c.Model] = nil } if filterFn(c.Name) { - dataModels = append(dataModels, schema.OpenAIModel{ID: c.Name, Object: "model"}) + dataModels = append(dataModels, c.Name) } } @@ -64,7 +49,7 @@ func (lms *ListModelsService) ListModels(filter string, excludeConfigured bool) for _, m := range models { // And only adds them if they shouldn't be skipped. if _, exists := mm[m]; !exists && filterFn(m) { - dataModels = append(dataModels, schema.OpenAIModel{ID: m, Object: "model"}) + dataModels = append(dataModels, m) } } diff --git a/core/startup/startup.go b/core/startup/startup.go index 66111b59..55f930a4 100644 --- a/core/startup/startup.go +++ b/core/startup/startup.go @@ -195,7 +195,6 @@ func createApplication(appConfig *config.ApplicationConfig) *core.Application { app.BackendMonitorService = services.NewBackendMonitorService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig) app.GalleryService = services.NewGalleryService(app.ApplicationConfig) - app.ListModelsService = services.NewListModelsService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig) // app.OpenAIService = services.NewOpenAIService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig, app.LLMBackendService) app.LocalAIMetricsService, err = services.NewLocalAIMetricsService() diff --git a/pkg/model/loader.go b/pkg/model/loader.go index faaacdd4..6acc19f6 100644 --- a/pkg/model/loader.go +++ b/pkg/model/loader.go @@ -30,7 +30,6 @@ type PromptTemplateData struct { MessageIndex int } -// TODO: Ask mudler about FunctionCall stuff being useful at the message level? 
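With the struct and constructor gone, listing models is a single call that threads the two loaders explicitly, and the unused `appConfig` dependency disappears. A short sketch of direct use; the filter string is invented, and the ordering note follows from the service code above:

```go
package example

import (
	"fmt"

	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/core/services"
	model "github.com/mudler/LocalAI/pkg/model"
)

func printLlamaModels(bcl *config.BackendConfigLoader, ml *model.ModelLoader) error {
	// The filter is compiled as a regular expression (note the regexp import
	// in list_models.go); an empty filter matches every model.
	names, err := services.ListModels(bcl, ml, "llama", true)
	if err != nil {
		return err
	}
	for _, name := range names {
		// Configured names come first, then loose files no config claims.
		fmt.Println(name)
	}
	return nil
}
```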
diff --git a/pkg/model/loader.go b/pkg/model/loader.go
index faaacdd4..6acc19f6 100644
--- a/pkg/model/loader.go
+++ b/pkg/model/loader.go
@@ -30,7 +30,6 @@ type PromptTemplateData struct {
 	MessageIndex int
 }

-// TODO: Ask mudler about FunctionCall stuff being useful at the message level?
 type ChatMessageTemplateData struct {
 	SystemPrompt string
 	Role         string
@@ -87,22 +86,47 @@ func (ml *ModelLoader) ExistsInModelPath(s string) bool {
 	return utils.ExistsInPath(ml.ModelPath, s)
 }

-func (ml *ModelLoader) ListModels() ([]string, error) {
+var knownFilesToSkip []string = []string{
+	"MODEL_CARD",
+	"README",
+	"README.md",
+}
+
+var knownModelsNameSuffixToSkip []string = []string{
+	".tmpl",
+	".keep",
+	".yaml",
+	".yml",
+	".json",
+	".DS_Store",
+	".",
+}
+
+func (ml *ModelLoader) ListFilesInModelPath() ([]string, error) {
 	files, err := os.ReadDir(ml.ModelPath)
 	if err != nil {
 		return []string{}, err
 	}

 	models := []string{}
+FILE:
 	for _, file := range files {
-		// Skip templates, YAML, .keep, .json, and .DS_Store files - TODO: as this list grows, is there a more efficient method?
-		if strings.HasSuffix(file.Name(), ".tmpl") ||
-			strings.HasSuffix(file.Name(), ".keep") ||
-			strings.HasSuffix(file.Name(), ".yaml") ||
-			strings.HasSuffix(file.Name(), ".yml") ||
-			strings.HasSuffix(file.Name(), ".json") ||
-			strings.HasSuffix(file.Name(), ".DS_Store") ||
-			strings.HasPrefix(file.Name(), ".") {
+
+		for _, skip := range knownFilesToSkip {
+			if strings.EqualFold(file.Name(), skip) {
+				continue FILE
+			}
+		}
+
+		// Skip templates, YAML, .keep, .json, and .DS_Store files
+		for _, skip := range knownModelsNameSuffixToSkip {
+			if strings.HasSuffix(file.Name(), skip) {
+				continue FILE
+			}
+		}
+
+		// Skip directories
+		if file.IsDir() {
 			continue
 		}
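The renamed `ListFilesInModelPath` makes the skip rules data-driven: exact names are matched case-insensitively, suffixes as before, and directories are now excluded outright. A sketch against a hypothetical model directory:

```go
package example

import (
	"fmt"

	model "github.com/mudler/LocalAI/pkg/model"
)

// Hypothetical directory contents and the rule that decides each entry:
//
//	luna-ai-llama2.q5_K_M.gguf  -> listed (no rule matches)
//	luna-ai-llama2.yaml         -> skipped (".yaml" suffix)
//	chat.tmpl                   -> skipped (".tmpl" suffix)
//	readme.md                   -> skipped (matches "README.md" case-insensitively)
//	.DS_Store                   -> skipped (".DS_Store" suffix)
//	backups/                    -> skipped (directory)
func printLooseFiles(ml *model.ModelLoader) error {
	files, err := ml.ListFilesInModelPath()
	if err != nil {
		return err
	}
	fmt.Println(files) // e.g. [luna-ai-llama2.q5_K_M.gguf]
	return nil
}
```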