groundwork: ListModels Filtering Upgrade (#2773)

* seperate the filtering from the middleware changes

---------

Signed-off-by: Dave Lee <dave@gray101.com>
This commit is contained in:
Dave 2024-10-01 14:55:46 -04:00 committed by GitHub
parent f84b55d1ef
commit 307a835199
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 387 additions and 81 deletions

View file

@ -3,11 +3,13 @@ package config
import (
"os"
"regexp"
"slices"
"strings"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/downloader"
"github.com/mudler/LocalAI/pkg/functions"
"gopkg.in/yaml.v3"
)
const (
@ -27,13 +29,15 @@ type BackendConfig struct {
schema.PredictionOptions `yaml:"parameters"`
Name string `yaml:"name"`
F16 *bool `yaml:"f16"`
Threads *int `yaml:"threads"`
Debug *bool `yaml:"debug"`
Roles map[string]string `yaml:"roles"`
Embeddings *bool `yaml:"embeddings"`
Backend string `yaml:"backend"`
TemplateConfig TemplateConfig `yaml:"template"`
F16 *bool `yaml:"f16"`
Threads *int `yaml:"threads"`
Debug *bool `yaml:"debug"`
Roles map[string]string `yaml:"roles"`
Embeddings *bool `yaml:"embeddings"`
Backend string `yaml:"backend"`
TemplateConfig TemplateConfig `yaml:"template"`
KnownUsecaseStrings []string `yaml:"known_usecases"`
KnownUsecases *BackendConfigUsecases `yaml:"-"`
PromptStrings, InputStrings []string `yaml:"-"`
InputToken [][]int `yaml:"-"`
@ -194,6 +198,17 @@ type TemplateConfig struct {
JoinChatMessagesByCharacter *string `yaml:"join_chat_messages_by_character"`
}
func (c *BackendConfig) UnmarshalYAML(value *yaml.Node) error {
type BCAlias BackendConfig
var aux BCAlias
if err := value.Decode(&aux); err != nil {
return err
}
*c = BackendConfig(aux)
c.KnownUsecases = GetUsecasesFromYAML(c.KnownUsecaseStrings)
return nil
}
func (c *BackendConfig) SetFunctionCallString(s string) {
c.functionCallString = s
}
@ -410,3 +425,121 @@ func (c *BackendConfig) Validate() bool {
func (c *BackendConfig) HasTemplate() bool {
return c.TemplateConfig.Completion != "" || c.TemplateConfig.Edit != "" || c.TemplateConfig.Chat != "" || c.TemplateConfig.ChatMessage != ""
}
type BackendConfigUsecases int
const (
FLAG_ANY BackendConfigUsecases = 0b000000000
FLAG_CHAT BackendConfigUsecases = 0b000000001
FLAG_COMPLETION BackendConfigUsecases = 0b000000010
FLAG_EDIT BackendConfigUsecases = 0b000000100
FLAG_EMBEDDINGS BackendConfigUsecases = 0b000001000
FLAG_RERANK BackendConfigUsecases = 0b000010000
FLAG_IMAGE BackendConfigUsecases = 0b000100000
FLAG_TRANSCRIPT BackendConfigUsecases = 0b001000000
FLAG_TTS BackendConfigUsecases = 0b010000000
FLAG_SOUND_GENERATION BackendConfigUsecases = 0b100000000
// Common Subsets
FLAG_LLM BackendConfigUsecases = FLAG_CHAT & FLAG_COMPLETION & FLAG_EDIT
)
func GetAllBackendConfigUsecases() map[string]BackendConfigUsecases {
return map[string]BackendConfigUsecases{
"FLAG_ANY": FLAG_ANY,
"FLAG_CHAT": FLAG_CHAT,
"FLAG_COMPLETION": FLAG_COMPLETION,
"FLAG_EDIT": FLAG_EDIT,
"FLAG_EMBEDDINGS": FLAG_EMBEDDINGS,
"FLAG_RERANK": FLAG_RERANK,
"FLAG_IMAGE": FLAG_IMAGE,
"FLAG_TRANSCRIPT": FLAG_TRANSCRIPT,
"FLAG_TTS": FLAG_TTS,
"FLAG_SOUND_GENERATION": FLAG_SOUND_GENERATION,
"FLAG_LLM": FLAG_LLM,
}
}
func GetUsecasesFromYAML(input []string) *BackendConfigUsecases {
if len(input) == 0 {
return nil
}
result := FLAG_ANY
flags := GetAllBackendConfigUsecases()
for _, str := range input {
flag, exists := flags["FLAG_"+strings.ToUpper(str)]
if exists {
result |= flag
}
}
return &result
}
// HasUsecases examines a BackendConfig and determines which endpoints have a chance of success.
func (c *BackendConfig) HasUsecases(u BackendConfigUsecases) bool {
if (c.KnownUsecases != nil) && ((u & *c.KnownUsecases) == u) {
return true
}
return c.GuessUsecases(u)
}
// GuessUsecases is a **heuristic based** function, as the backend in question may not be loaded yet, and the config may not record what it's useful at.
// In its current state, this function should ideally check for properties of the config like templates, rather than the direct backend name checks for the lower half.
// This avoids the maintenance burden of updating this list for each new backend - but unfortunately, that's the best option for some services currently.
func (c *BackendConfig) GuessUsecases(u BackendConfigUsecases) bool {
if (u & FLAG_CHAT) == FLAG_CHAT {
if c.TemplateConfig.Chat == "" && c.TemplateConfig.ChatMessage == "" {
return false
}
}
if (u & FLAG_COMPLETION) == FLAG_COMPLETION {
if c.TemplateConfig.Completion == "" {
return false
}
}
if (u & FLAG_EDIT) == FLAG_EDIT {
if c.TemplateConfig.Edit == "" {
return false
}
}
if (u & FLAG_EMBEDDINGS) == FLAG_EMBEDDINGS {
if c.Embeddings == nil || !*c.Embeddings {
return false
}
}
if (u & FLAG_IMAGE) == FLAG_IMAGE {
imageBackends := []string{"diffusers", "tinydream", "stablediffusion"}
if !slices.Contains(imageBackends, c.Backend) {
return false
}
if c.Backend == "diffusers" && c.Diffusers.PipelineType == "" {
return false
}
}
if (u & FLAG_RERANK) == FLAG_RERANK {
if c.Backend != "rerankers" {
return false
}
}
if (u & FLAG_TRANSCRIPT) == FLAG_TRANSCRIPT {
if c.Backend != "whisper" {
return false
}
}
if (u & FLAG_TTS) == FLAG_TTS {
ttsBackends := []string{"piper", "transformers-musicgen", "parler-tts"}
if !slices.Contains(ttsBackends, c.Backend) {
return false
}
}
if (u & FLAG_SOUND_GENERATION) == FLAG_SOUND_GENERATION {
if c.Backend != "transformers-musicgen" {
return false
}
}
return true
}

View file

@ -0,0 +1,35 @@
package config
import "regexp"
type BackendConfigFilterFn func(string, *BackendConfig) bool
func NoFilterFn(_ string, _ *BackendConfig) bool { return true }
func BuildNameFilterFn(filter string) (BackendConfigFilterFn, error) {
if filter == "" {
return NoFilterFn, nil
}
rxp, err := regexp.Compile(filter)
if err != nil {
return nil, err
}
return func(name string, config *BackendConfig) bool {
if config != nil {
return rxp.MatchString(config.Name)
}
return rxp.MatchString(name)
}, nil
}
func BuildUsecaseFilterFn(usecases BackendConfigUsecases) BackendConfigFilterFn {
if usecases == FLAG_ANY {
return NoFilterFn
}
return func(name string, config *BackendConfig) bool {
if config == nil {
return false // TODO: Potentially make this a param, for now, no known usecase to include
}
return config.HasUsecases(usecases)
}
}

View file

@ -201,6 +201,26 @@ func (bcl *BackendConfigLoader) GetAllBackendConfigs() []BackendConfig {
return res
}
func (bcl *BackendConfigLoader) GetBackendConfigsByFilter(filter BackendConfigFilterFn) []BackendConfig {
bcl.Lock()
defer bcl.Unlock()
var res []BackendConfig
if filter == nil {
filter = NoFilterFn
}
for n, v := range bcl.configs {
if filter(n, &v) {
res = append(res, v)
}
}
// TODO: I don't think this one needs to Sort on name... but we'll see what breaks.
return res
}
func (bcl *BackendConfigLoader) RemoveBackendConfig(m string) {
bcl.Lock()
defer bcl.Unlock()

View file

@ -19,12 +19,17 @@ var _ = Describe("Test cases for config related functions", func() {
`backend: "../foo-bar"
name: "foo"
parameters:
model: "foo-bar"`)
model: "foo-bar"
known_usecases:
- chat
- COMPLETION
`)
Expect(err).ToNot(HaveOccurred())
config, err := readBackendConfigFromFile(tmp.Name())
Expect(err).To(BeNil())
Expect(config).ToNot(BeNil())
Expect(config.Validate()).To(BeFalse())
Expect(config.KnownUsecases).ToNot(BeNil())
})
It("Test Validate", func() {
tmp, err := os.CreateTemp("", "config.yaml")
@ -61,4 +66,99 @@ parameters:
Expect(config.Validate()).To(BeTrue())
})
})
It("Properly handles backend usecase matching", func() {
a := BackendConfig{
Name: "a",
}
Expect(a.HasUsecases(FLAG_ANY)).To(BeTrue()) // FLAG_ANY just means the config _exists_ essentially.
b := BackendConfig{
Name: "b",
Backend: "stablediffusion",
}
Expect(b.HasUsecases(FLAG_ANY)).To(BeTrue())
Expect(b.HasUsecases(FLAG_IMAGE)).To(BeTrue())
Expect(b.HasUsecases(FLAG_CHAT)).To(BeFalse())
c := BackendConfig{
Name: "c",
Backend: "llama-cpp",
TemplateConfig: TemplateConfig{
Chat: "chat",
},
}
Expect(c.HasUsecases(FLAG_ANY)).To(BeTrue())
Expect(c.HasUsecases(FLAG_IMAGE)).To(BeFalse())
Expect(c.HasUsecases(FLAG_COMPLETION)).To(BeFalse())
Expect(c.HasUsecases(FLAG_CHAT)).To(BeTrue())
d := BackendConfig{
Name: "d",
Backend: "llama-cpp",
TemplateConfig: TemplateConfig{
Chat: "chat",
Completion: "completion",
},
}
Expect(d.HasUsecases(FLAG_ANY)).To(BeTrue())
Expect(d.HasUsecases(FLAG_IMAGE)).To(BeFalse())
Expect(d.HasUsecases(FLAG_COMPLETION)).To(BeTrue())
Expect(d.HasUsecases(FLAG_CHAT)).To(BeTrue())
trueValue := true
e := BackendConfig{
Name: "e",
Backend: "llama-cpp",
TemplateConfig: TemplateConfig{
Completion: "completion",
},
Embeddings: &trueValue,
}
Expect(e.HasUsecases(FLAG_ANY)).To(BeTrue())
Expect(e.HasUsecases(FLAG_IMAGE)).To(BeFalse())
Expect(e.HasUsecases(FLAG_COMPLETION)).To(BeTrue())
Expect(e.HasUsecases(FLAG_CHAT)).To(BeFalse())
Expect(e.HasUsecases(FLAG_EMBEDDINGS)).To(BeTrue())
f := BackendConfig{
Name: "f",
Backend: "piper",
}
Expect(f.HasUsecases(FLAG_ANY)).To(BeTrue())
Expect(f.HasUsecases(FLAG_TTS)).To(BeTrue())
Expect(f.HasUsecases(FLAG_CHAT)).To(BeFalse())
g := BackendConfig{
Name: "g",
Backend: "whisper",
}
Expect(g.HasUsecases(FLAG_ANY)).To(BeTrue())
Expect(g.HasUsecases(FLAG_TRANSCRIPT)).To(BeTrue())
Expect(g.HasUsecases(FLAG_TTS)).To(BeFalse())
h := BackendConfig{
Name: "h",
Backend: "transformers-musicgen",
}
Expect(h.HasUsecases(FLAG_ANY)).To(BeTrue())
Expect(h.HasUsecases(FLAG_TRANSCRIPT)).To(BeFalse())
Expect(h.HasUsecases(FLAG_TTS)).To(BeTrue())
Expect(h.HasUsecases(FLAG_SOUND_GENERATION)).To(BeTrue())
knownUsecases := FLAG_CHAT | FLAG_COMPLETION
i := BackendConfig{
Name: "i",
Backend: "whisper",
// Earlier test checks parsing, this just needs to set final values
KnownUsecases: &knownUsecases,
}
Expect(i.HasUsecases(FLAG_ANY)).To(BeTrue())
Expect(i.HasUsecases(FLAG_TRANSCRIPT)).To(BeTrue())
Expect(i.HasUsecases(FLAG_TTS)).To(BeFalse())
Expect(i.HasUsecases(FLAG_COMPLETION)).To(BeTrue())
Expect(i.HasUsecases(FLAG_CHAT)).To(BeTrue())
})
})