Refactor ConfigLocalPaths into ConfigLocalSettings, adding Backend and Threads settings

This commit is contained in:
Dave Lee 2023-06-06 18:17:46 -04:00
parent 6e3cbe3be8
commit 8fc4b6cded
3 changed files with 23 additions and 19 deletions

View file

@ -17,14 +17,16 @@ type ConfigRegistration struct {
Model string `yaml:"model" json:"model" mapstructure:"model"` Model string `yaml:"model" json:"model" mapstructure:"model"`
} }
type ConfigLocalPaths struct { type ConfigLocalSettings struct {
Model string `yaml:"model" mapstructure:"model"` ModelPath string `yaml:"model" mapstructure:"model"`
Template string `yaml:"template" mapstructure:"template"` TemplatePath string `yaml:"template" mapstructure:"template"`
Backend string `yaml:"backend" mapstructure:"backend"`
Threads int `yaml:"threads" mapstructure:"threads"`
} }
type ConfigStub struct { type ConfigStub struct {
Registration ConfigRegistration `yaml:"registration" mapstructure:"registration"` Registration ConfigRegistration `yaml:"registration" mapstructure:"registration"`
LocalPaths ConfigLocalPaths `yaml:"local_paths" mapstructure:"local_paths"` LocalSettings ConfigLocalSettings `yaml:"local_paths" mapstructure:"local_paths"`
} }
type SpecificConfig[RequestModel any] struct { type SpecificConfig[RequestModel any] struct {
@ -34,7 +36,7 @@ type SpecificConfig[RequestModel any] struct {
type Config interface { type Config interface {
GetRequestDefaults() interface{} GetRequestDefaults() interface{}
GetLocalPaths() ConfigLocalPaths GetLocalSettings() ConfigLocalSettings
GetRegistration() ConfigRegistration GetRegistration() ConfigRegistration
} }
@ -42,8 +44,8 @@ func (cs ConfigStub) GetRequestDefaults() interface{} {
return nil return nil
} }
func (cs ConfigStub) GetLocalPaths() ConfigLocalPaths { func (cs ConfigStub) GetLocalSettings() ConfigLocalSettings {
return cs.LocalPaths return cs.LocalSettings
} }
func (cs ConfigStub) GetRegistration() ConfigRegistration { func (cs ConfigStub) GetRegistration() ConfigRegistration {
@ -58,8 +60,8 @@ func (sc SpecificConfig[RequestModel]) GetRequest() RequestModel {
return sc.RequestDefaults return sc.RequestDefaults
} }
func (sc SpecificConfig[RequestModel]) GetLocalPaths() ConfigLocalPaths { func (sc SpecificConfig[RequestModel]) GetLocalSettings() ConfigLocalSettings {
return sc.LocalPaths return sc.LocalSettings
} }
func (sc SpecificConfig[RequestModel]) GetRegistration() ConfigRegistration { func (sc SpecificConfig[RequestModel]) GetRegistration() ConfigRegistration {
@ -228,9 +230,14 @@ func (sc SpecificConfig[RequestModel]) ToModelOptions() []llama.ModelOption {
} }
func (sc SpecificConfig[RequestModel]) ToPredictOptions() []llama.PredictOption { func (sc SpecificConfig[RequestModel]) ToPredictOptions() []llama.PredictOption {
llamaOpts := []llama.PredictOption{} llamaOpts := []llama.PredictOption{
llama.SetThreads(sc.GetLocalSettings().Threads),
}
switch req := sc.GetRequestDefaults().(type) { switch req := sc.GetRequestDefaults().(type) {
// TODO Refactor this when we get to p2 and add image / audio
// I expect that it'll be worth pulling out the base case first, and doing fancy fallthrough things.
// Text Requests:
case CreateCompletionRequest: case CreateCompletionRequest:
case CreateChatCompletionRequest: case CreateChatCompletionRequest:
@ -306,10 +313,7 @@ func (sc SpecificConfig[RequestModel]) ToPredictOptions() []llama.PredictOption
// CODE TO PORT // CODE TO PORT
// predictOptions := []llama.PredictOption{ // SKIPPING PROMPT CACHE FOR PASS ONE, TODO READ ABOUT IT
// llama.SetThreads(c.Threads),
// }
// if c.PromptCacheAll { // if c.PromptCacheAll {
// predictOptions = append(predictOptions, llama.EnablePromptCacheAll) // predictOptions = append(predictOptions, llama.EnablePromptCacheAll)

View file

@ -56,8 +56,8 @@ func combineRequestAndConfig[RequestType any](configManager *ConfigManager, mode
return &SpecificConfig[RequestType]{ return &SpecificConfig[RequestType]{
ConfigStub: ConfigStub{ ConfigStub: ConfigStub{
Registration: config.GetRegistration(), Registration: config.GetRegistration(),
LocalPaths: config.GetLocalPaths(), LocalSettings: config.GetLocalSettings(),
}, },
RequestDefaults: request, RequestDefaults: request,
}, nil }, nil

View file

@ -155,7 +155,7 @@ It uses llama.cpp, ggml and gpt4all as backend with golang c bindings.
testField, exists := v2ConfigManager.GetConfig(reg) testField, exists := v2ConfigManager.GetConfig(reg)
if exists { if exists {
log.Log().Msgf("!! %s: %s", testField.GetRegistration().Endpoint, testField.GetLocalPaths().Model) log.Log().Msgf("!! %s: %s", testField.GetRegistration().Endpoint, testField.GetLocalSettings().ModelPath)
} }
} }