mirror of
https://github.com/mudler/LocalAI.git
synced 2025-06-28 05:35:00 +00:00
Merge branch 'master' into default_miro
This commit is contained in:
commit
ccc82ceb7e
578 changed files with 15978 additions and 15598 deletions
|
@ -39,10 +39,10 @@ type ApplicationConfig struct {
|
|||
OpaqueErrors bool
|
||||
UseSubtleKeyComparison bool
|
||||
DisableApiKeyRequirementForHttpGet bool
|
||||
DisableMetrics bool
|
||||
HttpGetExemptedEndpoints []*regexp.Regexp
|
||||
DisableGalleryEndpoint bool
|
||||
|
||||
ModelLibraryURL string
|
||||
LoadToMemory []string
|
||||
|
||||
Galleries []Gallery
|
||||
|
||||
|
@ -63,6 +63,8 @@ type ApplicationConfig struct {
|
|||
ModelsURL []string
|
||||
|
||||
WatchDogBusyTimeout, WatchDogIdleTimeout time.Duration
|
||||
|
||||
MachineTag string
|
||||
}
|
||||
|
||||
type AppOption func(*ApplicationConfig)
|
||||
|
@ -92,6 +94,12 @@ func WithModelPath(path string) AppOption {
|
|||
}
|
||||
}
|
||||
|
||||
func WithMachineTag(tag string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.MachineTag = tag
|
||||
}
|
||||
}
|
||||
|
||||
func WithCors(b bool) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.CORS = b
|
||||
|
@ -116,12 +124,6 @@ func WithP2PToken(s string) AppOption {
|
|||
}
|
||||
}
|
||||
|
||||
func WithModelLibraryURL(url string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.ModelLibraryURL = url
|
||||
}
|
||||
}
|
||||
|
||||
func WithLibPath(path string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.LibPath = path
|
||||
|
@ -331,6 +333,12 @@ func WithOpaqueErrors(opaque bool) AppOption {
|
|||
}
|
||||
}
|
||||
|
||||
func WithLoadToMemory(models []string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.LoadToMemory = models
|
||||
}
|
||||
}
|
||||
|
||||
func WithSubtleKeyComparison(subtle bool) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.UseSubtleKeyComparison = subtle
|
||||
|
@ -343,6 +351,10 @@ func WithDisableApiKeyRequirementForHttpGet(required bool) AppOption {
|
|||
}
|
||||
}
|
||||
|
||||
var DisableMetricsEndpoint AppOption = func(o *ApplicationConfig) {
|
||||
o.DisableMetrics = true
|
||||
}
|
||||
|
||||
func WithHttpGetExemptedEndpoints(endpoints []string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.HttpGetExemptedEndpoints = []*regexp.Regexp{}
|
||||
|
|
|
@ -3,11 +3,13 @@ package config
|
|||
import (
|
||||
"os"
|
||||
"regexp"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/pkg/downloader"
|
||||
"github.com/mudler/LocalAI/pkg/functions"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -19,21 +21,22 @@ type TTSConfig struct {
|
|||
// Voice wav path or id
|
||||
Voice string `yaml:"voice"`
|
||||
|
||||
// Vall-e-x
|
||||
VallE VallE `yaml:"vall-e"`
|
||||
AudioPath string `yaml:"audio_path"`
|
||||
}
|
||||
|
||||
type BackendConfig struct {
|
||||
schema.PredictionOptions `yaml:"parameters"`
|
||||
Name string `yaml:"name"`
|
||||
|
||||
F16 *bool `yaml:"f16"`
|
||||
Threads *int `yaml:"threads"`
|
||||
Debug *bool `yaml:"debug"`
|
||||
Roles map[string]string `yaml:"roles"`
|
||||
Embeddings *bool `yaml:"embeddings"`
|
||||
Backend string `yaml:"backend"`
|
||||
TemplateConfig TemplateConfig `yaml:"template"`
|
||||
F16 *bool `yaml:"f16"`
|
||||
Threads *int `yaml:"threads"`
|
||||
Debug *bool `yaml:"debug"`
|
||||
Roles map[string]string `yaml:"roles"`
|
||||
Embeddings *bool `yaml:"embeddings"`
|
||||
Backend string `yaml:"backend"`
|
||||
TemplateConfig TemplateConfig `yaml:"template"`
|
||||
KnownUsecaseStrings []string `yaml:"known_usecases"`
|
||||
KnownUsecases *BackendConfigUsecases `yaml:"-"`
|
||||
|
||||
PromptStrings, InputStrings []string `yaml:"-"`
|
||||
InputToken [][]int `yaml:"-"`
|
||||
|
@ -68,6 +71,8 @@ type BackendConfig struct {
|
|||
|
||||
Description string `yaml:"description"`
|
||||
Usage string `yaml:"usage"`
|
||||
|
||||
Options []string `yaml:"options"`
|
||||
}
|
||||
|
||||
type File struct {
|
||||
|
@ -76,10 +81,6 @@ type File struct {
|
|||
URI downloader.URI `yaml:"uri" json:"uri"`
|
||||
}
|
||||
|
||||
type VallE struct {
|
||||
AudioPath string `yaml:"audio_path"`
|
||||
}
|
||||
|
||||
type FeatureFlag map[string]*bool
|
||||
|
||||
func (ff FeatureFlag) Enabled(s string) bool {
|
||||
|
@ -93,16 +94,15 @@ type GRPC struct {
|
|||
}
|
||||
|
||||
type Diffusers struct {
|
||||
CUDA bool `yaml:"cuda"`
|
||||
PipelineType string `yaml:"pipeline_type"`
|
||||
SchedulerType string `yaml:"scheduler_type"`
|
||||
EnableParameters string `yaml:"enable_parameters"` // A list of comma separated parameters to specify
|
||||
CFGScale float32 `yaml:"cfg_scale"` // Classifier-Free Guidance Scale
|
||||
IMG2IMG bool `yaml:"img2img"` // Image to Image Diffuser
|
||||
ClipSkip int `yaml:"clip_skip"` // Skip every N frames
|
||||
ClipModel string `yaml:"clip_model"` // Clip model to use
|
||||
ClipSubFolder string `yaml:"clip_subfolder"` // Subfolder to use for clip model
|
||||
ControlNet string `yaml:"control_net"`
|
||||
CUDA bool `yaml:"cuda"`
|
||||
PipelineType string `yaml:"pipeline_type"`
|
||||
SchedulerType string `yaml:"scheduler_type"`
|
||||
EnableParameters string `yaml:"enable_parameters"` // A list of comma separated parameters to specify
|
||||
IMG2IMG bool `yaml:"img2img"` // Image to Image Diffuser
|
||||
ClipSkip int `yaml:"clip_skip"` // Skip every N frames
|
||||
ClipModel string `yaml:"clip_model"` // Clip model to use
|
||||
ClipSubFolder string `yaml:"clip_subfolder"` // Subfolder to use for clip model
|
||||
ControlNet string `yaml:"control_net"`
|
||||
}
|
||||
|
||||
// LLMConfig is a struct that holds the configuration that are
|
||||
|
@ -130,25 +130,30 @@ type LLMConfig struct {
|
|||
TrimSpace []string `yaml:"trimspace"`
|
||||
TrimSuffix []string `yaml:"trimsuffix"`
|
||||
|
||||
ContextSize *int `yaml:"context_size"`
|
||||
NUMA bool `yaml:"numa"`
|
||||
LoraAdapter string `yaml:"lora_adapter"`
|
||||
LoraBase string `yaml:"lora_base"`
|
||||
LoraScale float32 `yaml:"lora_scale"`
|
||||
NoMulMatQ bool `yaml:"no_mulmatq"`
|
||||
DraftModel string `yaml:"draft_model"`
|
||||
NDraft int32 `yaml:"n_draft"`
|
||||
Quantization string `yaml:"quantization"`
|
||||
GPUMemoryUtilization float32 `yaml:"gpu_memory_utilization"` // vLLM
|
||||
TrustRemoteCode bool `yaml:"trust_remote_code"` // vLLM
|
||||
EnforceEager bool `yaml:"enforce_eager"` // vLLM
|
||||
SwapSpace int `yaml:"swap_space"` // vLLM
|
||||
MaxModelLen int `yaml:"max_model_len"` // vLLM
|
||||
TensorParallelSize int `yaml:"tensor_parallel_size"` // vLLM
|
||||
MMProj string `yaml:"mmproj"`
|
||||
ContextSize *int `yaml:"context_size"`
|
||||
NUMA bool `yaml:"numa"`
|
||||
LoraAdapter string `yaml:"lora_adapter"`
|
||||
LoraBase string `yaml:"lora_base"`
|
||||
LoraAdapters []string `yaml:"lora_adapters"`
|
||||
LoraScales []float32 `yaml:"lora_scales"`
|
||||
LoraScale float32 `yaml:"lora_scale"`
|
||||
NoMulMatQ bool `yaml:"no_mulmatq"`
|
||||
DraftModel string `yaml:"draft_model"`
|
||||
NDraft int32 `yaml:"n_draft"`
|
||||
Quantization string `yaml:"quantization"`
|
||||
LoadFormat string `yaml:"load_format"`
|
||||
GPUMemoryUtilization float32 `yaml:"gpu_memory_utilization"` // vLLM
|
||||
TrustRemoteCode bool `yaml:"trust_remote_code"` // vLLM
|
||||
EnforceEager bool `yaml:"enforce_eager"` // vLLM
|
||||
SwapSpace int `yaml:"swap_space"` // vLLM
|
||||
MaxModelLen int `yaml:"max_model_len"` // vLLM
|
||||
TensorParallelSize int `yaml:"tensor_parallel_size"` // vLLM
|
||||
MMProj string `yaml:"mmproj"`
|
||||
|
||||
FlashAttention bool `yaml:"flash_attention"`
|
||||
NoKVOffloading bool `yaml:"no_kv_offloading"`
|
||||
FlashAttention bool `yaml:"flash_attention"`
|
||||
NoKVOffloading bool `yaml:"no_kv_offloading"`
|
||||
CacheTypeK string `yaml:"cache_type_k"`
|
||||
CacheTypeV string `yaml:"cache_type_v"`
|
||||
|
||||
RopeScaling string `yaml:"rope_scaling"`
|
||||
ModelType string `yaml:"type"`
|
||||
|
@ -157,6 +162,8 @@ type LLMConfig struct {
|
|||
YarnAttnFactor float32 `yaml:"yarn_attn_factor"`
|
||||
YarnBetaFast float32 `yaml:"yarn_beta_fast"`
|
||||
YarnBetaSlow float32 `yaml:"yarn_beta_slow"`
|
||||
|
||||
CFGScale float32 `yaml:"cfg_scale"` // Classifier-Free Guidance Scale
|
||||
}
|
||||
|
||||
// AutoGPTQ is a struct that holds the configuration specific to the AutoGPTQ backend
|
||||
|
@ -192,6 +199,21 @@ type TemplateConfig struct {
|
|||
// JoinChatMessagesByCharacter is a string that will be used to join chat messages together.
|
||||
// It defaults to \n
|
||||
JoinChatMessagesByCharacter *string `yaml:"join_chat_messages_by_character"`
|
||||
|
||||
Multimodal string `yaml:"multimodal"`
|
||||
|
||||
JinjaTemplate bool `yaml:"jinja_template"`
|
||||
}
|
||||
|
||||
func (c *BackendConfig) UnmarshalYAML(value *yaml.Node) error {
|
||||
type BCAlias BackendConfig
|
||||
var aux BCAlias
|
||||
if err := value.Decode(&aux); err != nil {
|
||||
return err
|
||||
}
|
||||
*c = BackendConfig(aux)
|
||||
c.KnownUsecases = GetUsecasesFromYAML(c.KnownUsecaseStrings)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *BackendConfig) SetFunctionCallString(s string) {
|
||||
|
@ -411,3 +433,121 @@ func (c *BackendConfig) Validate() bool {
|
|||
func (c *BackendConfig) HasTemplate() bool {
|
||||
return c.TemplateConfig.Completion != "" || c.TemplateConfig.Edit != "" || c.TemplateConfig.Chat != "" || c.TemplateConfig.ChatMessage != ""
|
||||
}
|
||||
|
||||
type BackendConfigUsecases int
|
||||
|
||||
const (
|
||||
FLAG_ANY BackendConfigUsecases = 0b000000000
|
||||
FLAG_CHAT BackendConfigUsecases = 0b000000001
|
||||
FLAG_COMPLETION BackendConfigUsecases = 0b000000010
|
||||
FLAG_EDIT BackendConfigUsecases = 0b000000100
|
||||
FLAG_EMBEDDINGS BackendConfigUsecases = 0b000001000
|
||||
FLAG_RERANK BackendConfigUsecases = 0b000010000
|
||||
FLAG_IMAGE BackendConfigUsecases = 0b000100000
|
||||
FLAG_TRANSCRIPT BackendConfigUsecases = 0b001000000
|
||||
FLAG_TTS BackendConfigUsecases = 0b010000000
|
||||
FLAG_SOUND_GENERATION BackendConfigUsecases = 0b100000000
|
||||
|
||||
// Common Subsets
|
||||
FLAG_LLM BackendConfigUsecases = FLAG_CHAT & FLAG_COMPLETION & FLAG_EDIT
|
||||
)
|
||||
|
||||
func GetAllBackendConfigUsecases() map[string]BackendConfigUsecases {
|
||||
return map[string]BackendConfigUsecases{
|
||||
"FLAG_ANY": FLAG_ANY,
|
||||
"FLAG_CHAT": FLAG_CHAT,
|
||||
"FLAG_COMPLETION": FLAG_COMPLETION,
|
||||
"FLAG_EDIT": FLAG_EDIT,
|
||||
"FLAG_EMBEDDINGS": FLAG_EMBEDDINGS,
|
||||
"FLAG_RERANK": FLAG_RERANK,
|
||||
"FLAG_IMAGE": FLAG_IMAGE,
|
||||
"FLAG_TRANSCRIPT": FLAG_TRANSCRIPT,
|
||||
"FLAG_TTS": FLAG_TTS,
|
||||
"FLAG_SOUND_GENERATION": FLAG_SOUND_GENERATION,
|
||||
"FLAG_LLM": FLAG_LLM,
|
||||
}
|
||||
}
|
||||
|
||||
func GetUsecasesFromYAML(input []string) *BackendConfigUsecases {
|
||||
if len(input) == 0 {
|
||||
return nil
|
||||
}
|
||||
result := FLAG_ANY
|
||||
flags := GetAllBackendConfigUsecases()
|
||||
for _, str := range input {
|
||||
flag, exists := flags["FLAG_"+strings.ToUpper(str)]
|
||||
if exists {
|
||||
result |= flag
|
||||
}
|
||||
}
|
||||
return &result
|
||||
}
|
||||
|
||||
// HasUsecases examines a BackendConfig and determines which endpoints have a chance of success.
|
||||
func (c *BackendConfig) HasUsecases(u BackendConfigUsecases) bool {
|
||||
if (c.KnownUsecases != nil) && ((u & *c.KnownUsecases) == u) {
|
||||
return true
|
||||
}
|
||||
return c.GuessUsecases(u)
|
||||
}
|
||||
|
||||
// GuessUsecases is a **heuristic based** function, as the backend in question may not be loaded yet, and the config may not record what it's useful at.
|
||||
// In its current state, this function should ideally check for properties of the config like templates, rather than the direct backend name checks for the lower half.
|
||||
// This avoids the maintenance burden of updating this list for each new backend - but unfortunately, that's the best option for some services currently.
|
||||
func (c *BackendConfig) GuessUsecases(u BackendConfigUsecases) bool {
|
||||
if (u & FLAG_CHAT) == FLAG_CHAT {
|
||||
if c.TemplateConfig.Chat == "" && c.TemplateConfig.ChatMessage == "" {
|
||||
return false
|
||||
}
|
||||
}
|
||||
if (u & FLAG_COMPLETION) == FLAG_COMPLETION {
|
||||
if c.TemplateConfig.Completion == "" {
|
||||
return false
|
||||
}
|
||||
}
|
||||
if (u & FLAG_EDIT) == FLAG_EDIT {
|
||||
if c.TemplateConfig.Edit == "" {
|
||||
return false
|
||||
}
|
||||
}
|
||||
if (u & FLAG_EMBEDDINGS) == FLAG_EMBEDDINGS {
|
||||
if c.Embeddings == nil || !*c.Embeddings {
|
||||
return false
|
||||
}
|
||||
}
|
||||
if (u & FLAG_IMAGE) == FLAG_IMAGE {
|
||||
imageBackends := []string{"diffusers", "stablediffusion", "stablediffusion-ggml"}
|
||||
if !slices.Contains(imageBackends, c.Backend) {
|
||||
return false
|
||||
}
|
||||
|
||||
if c.Backend == "diffusers" && c.Diffusers.PipelineType == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
}
|
||||
if (u & FLAG_RERANK) == FLAG_RERANK {
|
||||
if c.Backend != "rerankers" {
|
||||
return false
|
||||
}
|
||||
}
|
||||
if (u & FLAG_TRANSCRIPT) == FLAG_TRANSCRIPT {
|
||||
if c.Backend != "whisper" {
|
||||
return false
|
||||
}
|
||||
}
|
||||
if (u & FLAG_TTS) == FLAG_TTS {
|
||||
ttsBackends := []string{"piper", "transformers-musicgen", "parler-tts"}
|
||||
if !slices.Contains(ttsBackends, c.Backend) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
if (u & FLAG_SOUND_GENERATION) == FLAG_SOUND_GENERATION {
|
||||
if c.Backend != "transformers-musicgen" {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
|
35
core/config/backend_config_filter.go
Normal file
35
core/config/backend_config_filter.go
Normal file
|
@ -0,0 +1,35 @@
|
|||
package config
|
||||
|
||||
import "regexp"
|
||||
|
||||
type BackendConfigFilterFn func(string, *BackendConfig) bool
|
||||
|
||||
func NoFilterFn(_ string, _ *BackendConfig) bool { return true }
|
||||
|
||||
func BuildNameFilterFn(filter string) (BackendConfigFilterFn, error) {
|
||||
if filter == "" {
|
||||
return NoFilterFn, nil
|
||||
}
|
||||
rxp, err := regexp.Compile(filter)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return func(name string, config *BackendConfig) bool {
|
||||
if config != nil {
|
||||
return rxp.MatchString(config.Name)
|
||||
}
|
||||
return rxp.MatchString(name)
|
||||
}, nil
|
||||
}
|
||||
|
||||
func BuildUsecaseFilterFn(usecases BackendConfigUsecases) BackendConfigFilterFn {
|
||||
if usecases == FLAG_ANY {
|
||||
return NoFilterFn
|
||||
}
|
||||
return func(name string, config *BackendConfig) bool {
|
||||
if config == nil {
|
||||
return false // TODO: Potentially make this a param, for now, no known usecase to include
|
||||
}
|
||||
return config.HasUsecases(usecases)
|
||||
}
|
||||
}
|
|
@ -140,7 +140,7 @@ func (bcl *BackendConfigLoader) LoadBackendConfigFileByName(modelName, modelPath
|
|||
}
|
||||
}
|
||||
|
||||
cfg.SetDefaults(opts...)
|
||||
cfg.SetDefaults(append(opts, ModelPath(modelPath))...)
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
@ -201,6 +201,26 @@ func (bcl *BackendConfigLoader) GetAllBackendConfigs() []BackendConfig {
|
|||
return res
|
||||
}
|
||||
|
||||
func (bcl *BackendConfigLoader) GetBackendConfigsByFilter(filter BackendConfigFilterFn) []BackendConfig {
|
||||
bcl.Lock()
|
||||
defer bcl.Unlock()
|
||||
var res []BackendConfig
|
||||
|
||||
if filter == nil {
|
||||
filter = NoFilterFn
|
||||
}
|
||||
|
||||
for n, v := range bcl.configs {
|
||||
if filter(n, &v) {
|
||||
res = append(res, v)
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: I don't think this one needs to Sort on name... but we'll see what breaks.
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
func (bcl *BackendConfigLoader) RemoveBackendConfig(m string) {
|
||||
bcl.Lock()
|
||||
defer bcl.Unlock()
|
||||
|
|
|
@ -19,12 +19,17 @@ var _ = Describe("Test cases for config related functions", func() {
|
|||
`backend: "../foo-bar"
|
||||
name: "foo"
|
||||
parameters:
|
||||
model: "foo-bar"`)
|
||||
model: "foo-bar"
|
||||
known_usecases:
|
||||
- chat
|
||||
- COMPLETION
|
||||
`)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
config, err := readBackendConfigFromFile(tmp.Name())
|
||||
Expect(err).To(BeNil())
|
||||
Expect(config).ToNot(BeNil())
|
||||
Expect(config.Validate()).To(BeFalse())
|
||||
Expect(config.KnownUsecases).ToNot(BeNil())
|
||||
})
|
||||
It("Test Validate", func() {
|
||||
tmp, err := os.CreateTemp("", "config.yaml")
|
||||
|
@ -43,9 +48,9 @@ parameters:
|
|||
Expect(config.Name).To(Equal("bar-baz"))
|
||||
Expect(config.Validate()).To(BeTrue())
|
||||
|
||||
// download https://raw.githubusercontent.com/mudler/LocalAI/master/embedded/models/hermes-2-pro-mistral.yaml
|
||||
// download https://raw.githubusercontent.com/mudler/LocalAI/v2.25.0/embedded/models/hermes-2-pro-mistral.yaml
|
||||
httpClient := http.Client{}
|
||||
resp, err := httpClient.Get("https://raw.githubusercontent.com/mudler/LocalAI/master/embedded/models/hermes-2-pro-mistral.yaml")
|
||||
resp, err := httpClient.Get("https://raw.githubusercontent.com/mudler/LocalAI/v2.25.0/embedded/models/hermes-2-pro-mistral.yaml")
|
||||
Expect(err).To(BeNil())
|
||||
defer resp.Body.Close()
|
||||
tmp, err = os.CreateTemp("", "config.yaml")
|
||||
|
@ -61,4 +66,99 @@ parameters:
|
|||
Expect(config.Validate()).To(BeTrue())
|
||||
})
|
||||
})
|
||||
It("Properly handles backend usecase matching", func() {
|
||||
|
||||
a := BackendConfig{
|
||||
Name: "a",
|
||||
}
|
||||
Expect(a.HasUsecases(FLAG_ANY)).To(BeTrue()) // FLAG_ANY just means the config _exists_ essentially.
|
||||
|
||||
b := BackendConfig{
|
||||
Name: "b",
|
||||
Backend: "stablediffusion",
|
||||
}
|
||||
Expect(b.HasUsecases(FLAG_ANY)).To(BeTrue())
|
||||
Expect(b.HasUsecases(FLAG_IMAGE)).To(BeTrue())
|
||||
Expect(b.HasUsecases(FLAG_CHAT)).To(BeFalse())
|
||||
|
||||
c := BackendConfig{
|
||||
Name: "c",
|
||||
Backend: "llama-cpp",
|
||||
TemplateConfig: TemplateConfig{
|
||||
Chat: "chat",
|
||||
},
|
||||
}
|
||||
Expect(c.HasUsecases(FLAG_ANY)).To(BeTrue())
|
||||
Expect(c.HasUsecases(FLAG_IMAGE)).To(BeFalse())
|
||||
Expect(c.HasUsecases(FLAG_COMPLETION)).To(BeFalse())
|
||||
Expect(c.HasUsecases(FLAG_CHAT)).To(BeTrue())
|
||||
|
||||
d := BackendConfig{
|
||||
Name: "d",
|
||||
Backend: "llama-cpp",
|
||||
TemplateConfig: TemplateConfig{
|
||||
Chat: "chat",
|
||||
Completion: "completion",
|
||||
},
|
||||
}
|
||||
Expect(d.HasUsecases(FLAG_ANY)).To(BeTrue())
|
||||
Expect(d.HasUsecases(FLAG_IMAGE)).To(BeFalse())
|
||||
Expect(d.HasUsecases(FLAG_COMPLETION)).To(BeTrue())
|
||||
Expect(d.HasUsecases(FLAG_CHAT)).To(BeTrue())
|
||||
|
||||
trueValue := true
|
||||
e := BackendConfig{
|
||||
Name: "e",
|
||||
Backend: "llama-cpp",
|
||||
TemplateConfig: TemplateConfig{
|
||||
Completion: "completion",
|
||||
},
|
||||
Embeddings: &trueValue,
|
||||
}
|
||||
|
||||
Expect(e.HasUsecases(FLAG_ANY)).To(BeTrue())
|
||||
Expect(e.HasUsecases(FLAG_IMAGE)).To(BeFalse())
|
||||
Expect(e.HasUsecases(FLAG_COMPLETION)).To(BeTrue())
|
||||
Expect(e.HasUsecases(FLAG_CHAT)).To(BeFalse())
|
||||
Expect(e.HasUsecases(FLAG_EMBEDDINGS)).To(BeTrue())
|
||||
|
||||
f := BackendConfig{
|
||||
Name: "f",
|
||||
Backend: "piper",
|
||||
}
|
||||
Expect(f.HasUsecases(FLAG_ANY)).To(BeTrue())
|
||||
Expect(f.HasUsecases(FLAG_TTS)).To(BeTrue())
|
||||
Expect(f.HasUsecases(FLAG_CHAT)).To(BeFalse())
|
||||
|
||||
g := BackendConfig{
|
||||
Name: "g",
|
||||
Backend: "whisper",
|
||||
}
|
||||
Expect(g.HasUsecases(FLAG_ANY)).To(BeTrue())
|
||||
Expect(g.HasUsecases(FLAG_TRANSCRIPT)).To(BeTrue())
|
||||
Expect(g.HasUsecases(FLAG_TTS)).To(BeFalse())
|
||||
|
||||
h := BackendConfig{
|
||||
Name: "h",
|
||||
Backend: "transformers-musicgen",
|
||||
}
|
||||
Expect(h.HasUsecases(FLAG_ANY)).To(BeTrue())
|
||||
Expect(h.HasUsecases(FLAG_TRANSCRIPT)).To(BeFalse())
|
||||
Expect(h.HasUsecases(FLAG_TTS)).To(BeTrue())
|
||||
Expect(h.HasUsecases(FLAG_SOUND_GENERATION)).To(BeTrue())
|
||||
|
||||
knownUsecases := FLAG_CHAT | FLAG_COMPLETION
|
||||
i := BackendConfig{
|
||||
Name: "i",
|
||||
Backend: "whisper",
|
||||
// Earlier test checks parsing, this just needs to set final values
|
||||
KnownUsecases: &knownUsecases,
|
||||
}
|
||||
Expect(i.HasUsecases(FLAG_ANY)).To(BeTrue())
|
||||
Expect(i.HasUsecases(FLAG_TRANSCRIPT)).To(BeTrue())
|
||||
Expect(i.HasUsecases(FLAG_TTS)).To(BeFalse())
|
||||
Expect(i.HasUsecases(FLAG_COMPLETION)).To(BeTrue())
|
||||
Expect(i.HasUsecases(FLAG_CHAT)).To(BeTrue())
|
||||
|
||||
})
|
||||
})
|
||||
|
|
|
@ -48,5 +48,66 @@ var _ = Describe("Test cases for config related functions", func() {
|
|||
// config should includes whisper-1 models's api.config
|
||||
Expect(loadedModelNames).To(ContainElements("whisper-1"))
|
||||
})
|
||||
|
||||
It("Test new loadconfig", func() {
|
||||
|
||||
bcl := NewBackendConfigLoader(os.Getenv("MODELS_PATH"))
|
||||
err := bcl.LoadBackendConfigsFromPath(os.Getenv("MODELS_PATH"))
|
||||
Expect(err).To(BeNil())
|
||||
configs := bcl.GetAllBackendConfigs()
|
||||
loadedModelNames := []string{}
|
||||
for _, v := range configs {
|
||||
loadedModelNames = append(loadedModelNames, v.Name)
|
||||
}
|
||||
Expect(configs).ToNot(BeNil())
|
||||
totalModels := len(loadedModelNames)
|
||||
|
||||
Expect(loadedModelNames).To(ContainElements("code-search-ada-code-001"))
|
||||
|
||||
// config should includes text-embedding-ada-002 models's api.config
|
||||
Expect(loadedModelNames).To(ContainElements("text-embedding-ada-002"))
|
||||
|
||||
// config should includes rwkv_test models's api.config
|
||||
Expect(loadedModelNames).To(ContainElements("rwkv_test"))
|
||||
|
||||
// config should includes whisper-1 models's api.config
|
||||
Expect(loadedModelNames).To(ContainElements("whisper-1"))
|
||||
|
||||
// create a temp directory and store a temporary model
|
||||
tmpdir, err := os.MkdirTemp("", "test")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer os.RemoveAll(tmpdir)
|
||||
|
||||
// create a temporary model
|
||||
model := `name: "test-model"
|
||||
description: "test model"
|
||||
options:
|
||||
- foo
|
||||
- bar
|
||||
- baz
|
||||
`
|
||||
modelFile := tmpdir + "/test-model.yaml"
|
||||
err = os.WriteFile(modelFile, []byte(model), 0644)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
err = bcl.LoadBackendConfigsFromPath(tmpdir)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
configs = bcl.GetAllBackendConfigs()
|
||||
Expect(len(configs)).ToNot(Equal(totalModels))
|
||||
|
||||
loadedModelNames = []string{}
|
||||
var testModel BackendConfig
|
||||
for _, v := range configs {
|
||||
loadedModelNames = append(loadedModelNames, v.Name)
|
||||
if v.Name == "test-model" {
|
||||
testModel = v
|
||||
}
|
||||
}
|
||||
Expect(loadedModelNames).To(ContainElements("test-model"))
|
||||
Expect(testModel.Description).To(Equal("test model"))
|
||||
Expect(testModel.Options).To(ContainElements("foo", "bar", "baz"))
|
||||
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
|
@ -26,14 +26,14 @@ const (
|
|||
type settingsConfig struct {
|
||||
StopWords []string
|
||||
TemplateConfig TemplateConfig
|
||||
RepeatPenalty float64
|
||||
RepeatPenalty float64
|
||||
}
|
||||
|
||||
// default settings to adopt with a given model family
|
||||
var defaultsSettings map[familyType]settingsConfig = map[familyType]settingsConfig{
|
||||
Gemma: {
|
||||
RepeatPenalty: 1.0,
|
||||
StopWords: []string{"<|im_end|>", "<end_of_turn>", "<start_of_turn>"},
|
||||
StopWords: []string{"<|im_end|>", "<end_of_turn>", "<start_of_turn>"},
|
||||
TemplateConfig: TemplateConfig{
|
||||
Chat: "{{.Input }}\n<start_of_turn>model\n",
|
||||
ChatMessage: "<start_of_turn>{{if eq .RoleName \"assistant\" }}model{{else}}{{ .RoleName }}{{end}}\n{{ if .Content -}}\n{{.Content -}}\n{{ end -}}<end_of_turn>",
|
||||
|
@ -200,6 +200,18 @@ func guessDefaultsFromFile(cfg *BackendConfig, modelPath string) {
|
|||
} else {
|
||||
log.Debug().Any("family", family).Msgf("guessDefaultsFromFile: no template found for family")
|
||||
}
|
||||
|
||||
if cfg.HasTemplate() {
|
||||
return
|
||||
}
|
||||
|
||||
// identify from well known templates first, otherwise use the raw jinja template
|
||||
chatTemplate, found := f.Header.MetadataKV.Get("tokenizer.chat_template")
|
||||
if found {
|
||||
// try to use the jinja template
|
||||
cfg.TemplateConfig.JinjaTemplate = true
|
||||
cfg.TemplateConfig.ChatMessage = chatTemplate.ValueString()
|
||||
}
|
||||
}
|
||||
|
||||
func identifyFamily(f *gguf.GGUFFile) familyType {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue