local settings refactor for threads and backend

This commit is contained in:
Dave Lee 2023-06-06 18:17:46 -04:00
parent 6e3cbe3be8
commit 8fc4b6cded
3 changed files with 23 additions and 19 deletions

View file

@ -17,14 +17,16 @@ type ConfigRegistration struct {
Model string `yaml:"model" json:"model" mapstructure:"model"`
}
type ConfigLocalPaths struct {
Model string `yaml:"model" mapstructure:"model"`
Template string `yaml:"template" mapstructure:"template"`
type ConfigLocalSettings struct {
ModelPath string `yaml:"model" mapstructure:"model"`
TemplatePath string `yaml:"template" mapstructure:"template"`
Backend string `yaml:"backend" mapstructure:"backend"`
Threads int `yaml:"threads" mapstructure:"threads"`
}
type ConfigStub struct {
Registration ConfigRegistration `yaml:"registration" mapstructure:"registration"`
LocalPaths ConfigLocalPaths `yaml:"local_paths" mapstructure:"local_paths"`
Registration ConfigRegistration `yaml:"registration" mapstructure:"registration"`
LocalSettings ConfigLocalSettings `yaml:"local_paths" mapstructure:"local_paths"`
}
type SpecificConfig[RequestModel any] struct {
@ -34,7 +36,7 @@ type SpecificConfig[RequestModel any] struct {
type Config interface {
GetRequestDefaults() interface{}
GetLocalPaths() ConfigLocalPaths
GetLocalSettings() ConfigLocalSettings
GetRegistration() ConfigRegistration
}
@ -42,8 +44,8 @@ func (cs ConfigStub) GetRequestDefaults() interface{} {
return nil
}
func (cs ConfigStub) GetLocalPaths() ConfigLocalPaths {
return cs.LocalPaths
func (cs ConfigStub) GetLocalSettings() ConfigLocalSettings {
return cs.LocalSettings
}
func (cs ConfigStub) GetRegistration() ConfigRegistration {
@ -58,8 +60,8 @@ func (sc SpecificConfig[RequestModel]) GetRequest() RequestModel {
return sc.RequestDefaults
}
func (sc SpecificConfig[RequestModel]) GetLocalPaths() ConfigLocalPaths {
return sc.LocalPaths
func (sc SpecificConfig[RequestModel]) GetLocalSettings() ConfigLocalSettings {
return sc.LocalSettings
}
func (sc SpecificConfig[RequestModel]) GetRegistration() ConfigRegistration {
@ -228,9 +230,14 @@ func (sc SpecificConfig[RequestModel]) ToModelOptions() []llama.ModelOption {
}
func (sc SpecificConfig[RequestModel]) ToPredictOptions() []llama.PredictOption {
llamaOpts := []llama.PredictOption{}
llamaOpts := []llama.PredictOption{
llama.SetThreads(sc.GetLocalSettings().Threads),
}
switch req := sc.GetRequestDefaults().(type) {
// TODO Refactor this when we get to p2 and add image / audio
// I expect that it'll be worth pulling out the base case first, and doing fancy fallthrough things.
// Text Requests:
case CreateCompletionRequest:
case CreateChatCompletionRequest:
@ -306,10 +313,7 @@ func (sc SpecificConfig[RequestModel]) ToPredictOptions() []llama.PredictOption
// CODE TO PORT
// predictOptions := []llama.PredictOption{
// llama.SetThreads(c.Threads),
// }
// SKIPPING PROMPT CACHE FOR PASS ONE, TODO READ ABOUT IT
// if c.PromptCacheAll {
// predictOptions = append(predictOptions, llama.EnablePromptCacheAll)

View file

@ -56,8 +56,8 @@ func combineRequestAndConfig[RequestType any](configManager *ConfigManager, mode
return &SpecificConfig[RequestType]{
ConfigStub: ConfigStub{
Registration: config.GetRegistration(),
LocalPaths: config.GetLocalPaths(),
Registration: config.GetRegistration(),
LocalSettings: config.GetLocalSettings(),
},
RequestDefaults: request,
}, nil

View file

@ -155,7 +155,7 @@ It uses llama.cpp, ggml and gpt4all as backend with golang c bindings.
testField, exists := v2ConfigManager.GetConfig(reg)
if exists {
log.Log().Msgf("!! %s: %s", testField.GetRegistration().Endpoint, testField.GetLocalPaths().Model)
log.Log().Msgf("!! %s: %s", testField.GetRegistration().Endpoint, testField.GetLocalSettings().ModelPath)
}
}
@ -164,7 +164,7 @@ It uses llama.cpp, ggml and gpt4all as backend with golang c bindings.
log.Log().Msgf("NEW v2 test: %+v", v2Server)
}
app, err := api.App(
api.WithConfigFile(ctx.String("config-file")),
api.WithJSONStringPreload(ctx.String("preload-models")),