diff --git a/apiv2/config.go b/apiv2/config.go index f2892e5d..e40957ec 100644 --- a/apiv2/config.go +++ b/apiv2/config.go @@ -17,14 +17,16 @@ type ConfigRegistration struct { Model string `yaml:"model" json:"model" mapstructure:"model"` } -type ConfigLocalPaths struct { - Model string `yaml:"model" mapstructure:"model"` - Template string `yaml:"template" mapstructure:"template"` +type ConfigLocalSettings struct { + ModelPath string `yaml:"model" mapstructure:"model"` + TemplatePath string `yaml:"template" mapstructure:"template"` + Backend string `yaml:"backend" mapstructure:"backend"` + Threads int `yaml:"threads" mapstructure:"threads"` } type ConfigStub struct { - Registration ConfigRegistration `yaml:"registration" mapstructure:"registration"` - LocalPaths ConfigLocalPaths `yaml:"local_paths" mapstructure:"local_paths"` + Registration ConfigRegistration `yaml:"registration" mapstructure:"registration"` + LocalSettings ConfigLocalSettings `yaml:"local_paths" mapstructure:"local_paths"` } type SpecificConfig[RequestModel any] struct { @@ -34,7 +36,7 @@ type SpecificConfig[RequestModel any] struct { type Config interface { GetRequestDefaults() interface{} - GetLocalPaths() ConfigLocalPaths + GetLocalSettings() ConfigLocalSettings GetRegistration() ConfigRegistration } @@ -42,8 +44,8 @@ func (cs ConfigStub) GetRequestDefaults() interface{} { return nil } -func (cs ConfigStub) GetLocalPaths() ConfigLocalPaths { - return cs.LocalPaths +func (cs ConfigStub) GetLocalSettings() ConfigLocalSettings { + return cs.LocalSettings } func (cs ConfigStub) GetRegistration() ConfigRegistration { @@ -58,8 +60,8 @@ func (sc SpecificConfig[RequestModel]) GetRequest() RequestModel { return sc.RequestDefaults } -func (sc SpecificConfig[RequestModel]) GetLocalPaths() ConfigLocalPaths { - return sc.LocalPaths +func (sc SpecificConfig[RequestModel]) GetLocalSettings() ConfigLocalSettings { + return sc.LocalSettings } func (sc SpecificConfig[RequestModel]) GetRegistration() ConfigRegistration { @@ -228,9 +230,14 @@ func (sc SpecificConfig[RequestModel]) ToModelOptions() []llama.ModelOption { } func (sc SpecificConfig[RequestModel]) ToPredictOptions() []llama.PredictOption { - llamaOpts := []llama.PredictOption{} + llamaOpts := []llama.PredictOption{ + llama.SetThreads(sc.GetLocalSettings().Threads), + } switch req := sc.GetRequestDefaults().(type) { + // TODO Refactor this when we get to p2 and add image / audio + // I expect that it'll be worth pulling out the base case first, and doing fancy fallthrough things. + // Text Requests: case CreateCompletionRequest: case CreateChatCompletionRequest: @@ -306,10 +313,7 @@ func (sc SpecificConfig[RequestModel]) ToPredictOptions() []llama.PredictOption // CODE TO PORT - // predictOptions := []llama.PredictOption{ - - // llama.SetThreads(c.Threads), - // } + // SKIPPING PROMPT CACHE FOR PASS ONE, TODO READ ABOUT IT // if c.PromptCacheAll { // predictOptions = append(predictOptions, llama.EnablePromptCacheAll) diff --git a/apiv2/localai.go b/apiv2/localai.go index 4a29d0a5..29c9131e 100644 --- a/apiv2/localai.go +++ b/apiv2/localai.go @@ -56,8 +56,8 @@ func combineRequestAndConfig[RequestType any](configManager *ConfigManager, mode return &SpecificConfig[RequestType]{ ConfigStub: ConfigStub{ - Registration: config.GetRegistration(), - LocalPaths: config.GetLocalPaths(), + Registration: config.GetRegistration(), + LocalSettings: config.GetLocalSettings(), }, RequestDefaults: request, }, nil diff --git a/main.go b/main.go index f5a46534..15356b09 100644 --- a/main.go +++ b/main.go @@ -155,7 +155,7 @@ It uses llama.cpp, ggml and gpt4all as backend with golang c bindings. testField, exists := v2ConfigManager.GetConfig(reg) if exists { - log.Log().Msgf("!! %s: %s", testField.GetRegistration().Endpoint, testField.GetLocalPaths().Model) + log.Log().Msgf("!! %s: %s", testField.GetRegistration().Endpoint, testField.GetLocalSettings().ModelPath) } } @@ -164,7 +164,7 @@ It uses llama.cpp, ggml and gpt4all as backend with golang c bindings. log.Log().Msgf("NEW v2 test: %+v", v2Server) } - + app, err := api.App( api.WithConfigFile(ctx.String("config-file")), api.WithJSONStringPreload(ctx.String("preload-models")),