draft of the config mapper?

Dave Lee 2023-06-06 17:48:34 -04:00
parent 9a1302ac0c
commit 8cf65018ac
4 changed files with 143 additions and 12 deletions


@@ -7,6 +7,7 @@ import (
"strings"
"sync"
llama "github.com/go-skynet/go-llama.cpp"
"github.com/mitchellh/mapstructure"
"gopkg.in/yaml.v2"
)
@@ -194,19 +195,27 @@ func (cm *ConfigManager) ListConfigs() []ConfigRegistration {
// These functions I'm a bit dubious about. I think there's a better refactoring down in pkg/model
// But to get a minimal test up and running, here we go!
// TODO: non text completion
func (sc SpecificConfig[RequestModel]) ToModelOptions() []llama.ModelOption {
llamaOpts := []llama.ModelOption{}
switch req := sc.GetRequestDefaults().(type) {
case CreateCompletionRequest:
case CreateChatCompletionRequest:
if req.XLocalaiExtensions != nil && req.XLocalaiExtensions.F16 != nil && *(req.XLocalaiExtensions.F16) {
llamaOpts = append(llamaOpts, llama.EnableF16Memory)
}
if req.MaxTokens != nil && *req.MaxTokens > 0 {
llamaOpts = append(llamaOpts, llama.SetContext(*req.MaxTokens)) // todo is this right?
}
// TODO DO MORE!
}
// Code to Port:
// if c.ContextSize != 0 {
// llamaOpts = append(llamaOpts, llama.SetContext(c.ContextSize))
// }
// if c.F16 {
// llamaOpts = append(llamaOpts, llama.EnableF16Memory)
// }
// if c.Embeddings {
// llamaOpts = append(llamaOpts, llama.EnableEmbeddings)
// }
@@ -216,4 +225,102 @@ func (sc SpecificConfig[RequestModel]) ToModelOptions() []llama.ModelOption {
// }
return llamaOpts
}
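
As an aside, a minimal sketch of how the mapped model options might be consumed, assuming go-llama.cpp's New(model string, opts ...ModelOption) (*LLama, error) constructor, that CreateChatCompletionRequest satisfies the RequestModel constraint, and a hypothetical loadFromConfig helper. This is an illustration only, not part of the commit:

// loadFromConfig is a hypothetical helper: it feeds the options mapped by
// ToModelOptions straight into the go-llama.cpp constructor.
func loadFromConfig(modelPath string, sc SpecificConfig[CreateChatCompletionRequest]) (*llama.LLama, error) {
	return llama.New(modelPath, sc.ToModelOptions()...)
}
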
func (sc SpecificConfig[RequestModel]) ToPredictOptions() []llama.PredictOption {
llamaOpts := []llama.PredictOption{}
switch req := sc.GetRequestDefaults().(type) {
case CreateCompletionRequest:
case CreateChatCompletionRequest:
if req.Temperature != nil {
llamaOpts = append(llamaOpts, llama.SetTemperature(float64(*req.Temperature))) // Oh boy. TODO Investigate. This is why I'm doing this.
}
if req.TopP != nil {
llamaOpts = append(llamaOpts, llama.SetTopP(float64(*req.TopP))) // CAST
}
if req.MaxTokens != nil {
llamaOpts = append(llamaOpts, llama.SetTokens(*req.MaxTokens))
}
if req.FrequencyPenalty != nil {
llamaOpts = append(llamaOpts, llama.SetPenalty(float64(*req.FrequencyPenalty))) // CAST
}
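// Stop may be supplied either as a single string (Stop0) or as a list of
// strings (Stop1); try both generated union accessors and forward whichever
// variant decodes.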
if stop0, err := req.Stop.AsCreateChatCompletionRequestStop0(); err == nil {
llamaOpts = append(llamaOpts, llama.SetStopWords(stop0))
}
if stop1, err := req.Stop.AsCreateChatCompletionRequestStop1(); err == nil && len(stop1) > 0 {
llamaOpts = append(llamaOpts, llama.SetStopWords(stop1...))
}
if req.XLocalaiExtensions != nil {
if req.XLocalaiExtensions.TopK != nil {
llamaOpts = append(llamaOpts, llama.SetTopK(*req.XLocalaiExtensions.TopK))
}
if req.XLocalaiExtensions.F16 != nil && *(req.XLocalaiExtensions.F16) {
llamaOpts = append(llamaOpts, llama.EnableF16KV)
}
if req.XLocalaiExtensions.Seed != nil {
llamaOpts = append(llamaOpts, llama.SetSeed(*req.XLocalaiExtensions.Seed))
}
if req.XLocalaiExtensions.IgnoreEos != nil && *(req.XLocalaiExtensions.IgnoreEos) {
llamaOpts = append(llamaOpts, llama.IgnoreEOS)
}
if req.XLocalaiExtensions.Debug != nil && *(req.XLocalaiExtensions.Debug) {
llamaOpts = append(llamaOpts, llama.Debug)
}
if req.XLocalaiExtensions.Mirostat != nil {
llamaOpts = append(llamaOpts, llama.SetMirostat(*req.XLocalaiExtensions.Mirostat))
}
if req.XLocalaiExtensions.MirostatEta != nil {
llamaOpts = append(llamaOpts, llama.SetMirostatETA(*req.XLocalaiExtensions.MirostatEta))
}
if req.XLocalaiExtensions.MirostatTau != nil {
llamaOpts = append(llamaOpts, llama.SetMirostatTAU(*req.XLocalaiExtensions.MirostatTau))
}
if req.XLocalaiExtensions.Keep != nil {
llamaOpts = append(llamaOpts, llama.SetNKeep(*req.XLocalaiExtensions.Keep))
}
if req.XLocalaiExtensions.Batch != nil && *(req.XLocalaiExtensions.Batch) != 0 {
llamaOpts = append(llamaOpts, llama.SetBatch(*req.XLocalaiExtensions.Batch))
}
}
}
// CODE TO PORT
// predictOptions := []llama.PredictOption{
// llama.SetThreads(c.Threads),
// }
// if c.PromptCacheAll {
// predictOptions = append(predictOptions, llama.EnablePromptCacheAll)
// }
// if c.PromptCachePath != "" {
// // Create parent directory
// p := filepath.Join(modelPath, c.PromptCachePath)
// os.MkdirAll(filepath.Dir(p), 0755)
// predictOptions = append(predictOptions, llama.SetPathPromptCache(p))
// }
return llamaOpts
}
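
Similarly, a hedged sketch of running inference with the mapped predict options, assuming go-llama.cpp's Predict(text string, opts ...PredictOption) (string, error) method and the same hypothetical config type as above (illustration only):

// predictFromConfig is a hypothetical helper pairing a loaded model with the
// predict options mapped by ToPredictOptions.
func predictFromConfig(model *llama.LLama, prompt string, sc SpecificConfig[CreateChatCompletionRequest]) (string, error) {
	return model.Predict(prompt, sc.ToPredictOptions()...)
}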


@@ -63,9 +63,9 @@ func combineRequestAndConfig[RequestType any](configManager *ConfigManager, mode
}, nil
}
// func (las *LocalAIServer) loadModel(configStub ConfigStub) {
// }
// CancelFineTune implements StrictServerInterface
func (*LocalAIServer) CancelFineTune(ctx context.Context, request CancelFineTuneRequestObject) (CancelFineTuneResponseObject, error) {

go.mod

@@ -108,3 +108,19 @@ require (
)
replace github.com/deepmap/oapi-codegen v1.12.4 => github.com/dave-gray101/oapi-codegen v0.0.0-20230601175843-6acf0cf32d63
replace github.com/go-skynet/go-llama.cpp => /Users/dave/projects/LocalAI/go-llama
replace github.com/nomic-ai/gpt4all/gpt4all-bindings/golang => /Users/dave/projects/LocalAI/gpt4all/gpt4all-bindings/golang
replace github.com/go-skynet/go-ggml-transformers.cpp => /Users/dave/projects/LocalAI/go-ggml-transformers
replace github.com/donomii/go-rwkv.cpp => /Users/dave/projects/LocalAI/go-rwkv
replace github.com/ggerganov/whisper.cpp => /Users/dave/projects/LocalAI/whisper.cpp
replace github.com/go-skynet/go-bert.cpp => /Users/dave/projects/LocalAI/go-bert
replace github.com/go-skynet/bloomz.cpp => /Users/dave/projects/LocalAI/bloomz
replace github.com/mudler/go-stable-diffusion => /Users/dave/projects/LocalAI/go-stable-diffusion
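
These replace directives hard-code absolute paths on the author's machine, so this go.mod only builds there; anyone else on the branch would presumably repoint them at their own checkouts, for example with a relative path (illustrative only):

replace github.com/go-skynet/go-llama.cpp => ../go-llama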


@@ -16,6 +16,9 @@ components:
seed:
type: integer
nullable: true
debug:
type: boolean
nullable: true
#@overlay/match missing_ok=True
LocalAITextRequestExtension:
allOf:
@@ -72,6 +75,11 @@
#@overlay/match missing_ok=True
x-localai-extensions:
$ref: "#/components/schemas/LocalAITextRequestExtension"
CreateCompletionRequest:
properties:
#@overlay/match missing_ok=True
x-localai-extensions:
$ref: "#/components/schemas/LocalAITextRequestExtension"
CreateImageRequest:
properties:
#@overlay/match missing_ok=True
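
With the overlay applied, CreateCompletionRequest gains the same x-localai-extensions object as the chat endpoint, and the extension schema now carries the nullable debug flag alongside seed. A request body could then look roughly like this (model name and prompt are made up for illustration, and exact field serialization depends on the generated types):

{
  "model": "ggml-model-q4.bin",
  "prompt": "Hello",
  "x-localai-extensions": { "seed": 1234, "debug": true }
}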