From 8cf65018ac17d1007c66707435323d5b676e91df Mon Sep 17 00:00:00 2001
From: Dave Lee
Date: Tue, 6 Jun 2023 17:48:34 -0400
Subject: [PATCH] draft of the config mapper?

---
 apiv2/config.go                           | 125 ++++++++++++++++++++--
 apiv2/localai.go                          |   6 +-
 go.mod                                    |  16 +++
 openai-openapi/localai_model_patches.yaml |   8 ++
 4 files changed, 143 insertions(+), 12 deletions(-)

diff --git a/apiv2/config.go b/apiv2/config.go
index d083bd03..f2892e5d 100644
--- a/apiv2/config.go
+++ b/apiv2/config.go
@@ -7,6 +7,7 @@ import (
 	"strings"
 	"sync"
 
+	llama "github.com/go-skynet/go-llama.cpp"
 	"github.com/mitchellh/mapstructure"
 	"gopkg.in/yaml.v2"
 )
@@ -194,19 +195,27 @@ func (cm *ConfigManager) ListConfigs() []ConfigRegistration {
 
 // These functions I'm a bit dubious about. I think there's a better refactoring down in pkg/model
 // But to get a minimal test up and running, here we go!
-
+// TODO: non text completion
 func (sc SpecificConfig[RequestModel]) ToModelOptions() []llama.ModelOption {
 	llamaOpts := []llama.ModelOption{}
-	// Code to Port:
+	switch req := sc.GetRequestDefaults().(type) {
+	case CreateCompletionRequest:
+	case CreateChatCompletionRequest:
+		if req.XLocalaiExtensions != nil && req.XLocalaiExtensions.F16 != nil && *(req.XLocalaiExtensions.F16) {
+			llamaOpts = append(llamaOpts, llama.EnableF16Memory)
+		}
+
+		if req.MaxTokens != nil && *req.MaxTokens > 0 {
+			llamaOpts = append(llamaOpts, llama.SetContext(*req.MaxTokens)) // todo is this right?
+		}
+
+		// TODO DO MORE!
+
+	}
+	// Code to Port:
-
-	// if c.ContextSize != 0 {
-	// 	llamaOpts = append(llamaOpts, llama.SetContext(c.ContextSize))
-	// }
-	// if c.F16 {
-	// 	llamaOpts = append(llamaOpts, llama.EnableF16Memory)
-	// }
 	// if c.Embeddings {
 	// 	llamaOpts = append(llamaOpts, llama.EnableEmbeddings)
 	// }
@@ -216,4 +225,102 @@ func (sc SpecificConfig[RequestModel]) ToModelOptions() []llama.ModelOption {
 	// }
 
 	return llamaOpts
-}
\ No newline at end of file
+}
+
+func (sc SpecificConfig[RequestModel]) ToPredictOptions() []llama.PredictOption {
+	llamaOpts := []llama.PredictOption{}
+
+	switch req := sc.GetRequestDefaults().(type) {
+	case CreateCompletionRequest:
+	case CreateChatCompletionRequest:
+
+		if req.Temperature != nil {
+			llamaOpts = append(llamaOpts, llama.SetTemperature(float64(*req.Temperature))) // Oh boy. TODO Investigate. This is why I'm doing this.
+		}
+
+		if req.TopP != nil {
+			llamaOpts = append(llamaOpts, llama.SetTopP(float64(*req.TopP))) // CAST
+		}
+
+		if req.MaxTokens != nil {
+			llamaOpts = append(llamaOpts, llama.SetTokens(*req.MaxTokens))
+		}
+
+		if req.FrequencyPenalty != nil {
+			llamaOpts = append(llamaOpts, llama.SetPenalty(float64(*req.FrequencyPenalty))) // CAST
+		}
+
+		if stop0, err := req.Stop.AsCreateChatCompletionRequestStop0(); err == nil {
+			llamaOpts = append(llamaOpts, llama.SetStopWords(stop0))
+		}
+
+		if stop1, err := req.Stop.AsCreateChatCompletionRequestStop1(); err == nil && len(stop1) > 0 {
+			llamaOpts = append(llamaOpts, llama.SetStopWords(stop1...))
+		}
+
+		if req.XLocalaiExtensions != nil {
+
+			if req.XLocalaiExtensions.TopK != nil {
+				llamaOpts = append(llamaOpts, llama.SetTopK(*req.XLocalaiExtensions.TopK))
+			}
+
+			if req.XLocalaiExtensions.F16 != nil && *(req.XLocalaiExtensions.F16) {
+				llamaOpts = append(llamaOpts, llama.EnableF16KV)
+			}
+
+			if req.XLocalaiExtensions.Seed != nil {
+				llamaOpts = append(llamaOpts, llama.SetSeed(*req.XLocalaiExtensions.Seed))
+			}
+
+			if req.XLocalaiExtensions.IgnoreEos != nil && *(req.XLocalaiExtensions.IgnoreEos) {
+				llamaOpts = append(llamaOpts, llama.IgnoreEOS)
+			}
+
+			if req.XLocalaiExtensions.Debug != nil && *(req.XLocalaiExtensions.Debug) {
+				llamaOpts = append(llamaOpts, llama.Debug)
+			}
+
+			if req.XLocalaiExtensions.Mirostat != nil {
+				llamaOpts = append(llamaOpts, llama.SetMirostat(*req.XLocalaiExtensions.Mirostat))
+			}
+
+			if req.XLocalaiExtensions.MirostatEta != nil {
+				llamaOpts = append(llamaOpts, llama.SetMirostatETA(*req.XLocalaiExtensions.MirostatEta))
+			}
+
+			if req.XLocalaiExtensions.MirostatTau != nil {
+				llamaOpts = append(llamaOpts, llama.SetMirostatTAU(*req.XLocalaiExtensions.MirostatTau))
+			}
+
+			if req.XLocalaiExtensions.Keep != nil {
+				llamaOpts = append(llamaOpts, llama.SetNKeep(*req.XLocalaiExtensions.Keep))
+			}
+
+			if req.XLocalaiExtensions.Batch != nil && *(req.XLocalaiExtensions.Batch) != 0 {
+				llamaOpts = append(llamaOpts, llama.SetBatch(*req.XLocalaiExtensions.Batch))
+			}
+
+		}
+
+	}
+
+	// CODE TO PORT
+
+	// predictOptions := []llama.PredictOption{
+
+	// 	llama.SetThreads(c.Threads),
+	// }
+
+	// if c.PromptCacheAll {
+	// 	predictOptions = append(predictOptions, llama.EnablePromptCacheAll)
+	// }
+
+	// if c.PromptCachePath != "" {
+	// 	// Create parent directory
+	// 	p := filepath.Join(modelPath, c.PromptCachePath)
+	// 	os.MkdirAll(filepath.Dir(p), 0755)
+	// 	predictOptions = append(predictOptions, llama.SetPathPromptCache(p))
+	// }
+
+	return llamaOpts
+}
diff --git a/apiv2/localai.go b/apiv2/localai.go
index 3627f6b6..4a29d0a5 100644
--- a/apiv2/localai.go
+++ b/apiv2/localai.go
@@ -63,9 +63,9 @@ func combineRequestAndConfig[RequestType any](configManager *ConfigManager, mode
 	}, nil
 }
 
-func (las *LocalAIServer) loadModel(configStub ConfigStub) {
-
-}
+// func (las *LocalAIServer) loadModel(configStub ConfigStub) {
+
+// }
 
 // CancelFineTune implements StrictServerInterface
 func (*LocalAIServer) CancelFineTune(ctx context.Context, request CancelFineTuneRequestObject) (CancelFineTuneResponseObject, error) {
diff --git a/go.mod b/go.mod
index aa5717df..f06eac08 100644
--- a/go.mod
+++ b/go.mod
@@ -108,3 +108,19 @@ require (
 )
 
 replace github.com/deepmap/oapi-codegen v1.12.4 => github.com/dave-gray101/oapi-codegen v0.0.0-20230601175843-6acf0cf32d63
+
+replace github.com/go-skynet/go-llama.cpp => /Users/dave/projects/LocalAI/go-llama
+
+replace github.com/nomic-ai/gpt4all/gpt4all-bindings/golang => /Users/dave/projects/LocalAI/gpt4all/gpt4all-bindings/golang
+
+replace github.com/go-skynet/go-ggml-transformers.cpp => /Users/dave/projects/LocalAI/go-ggml-transformers
+
+replace github.com/donomii/go-rwkv.cpp => /Users/dave/projects/LocalAI/go-rwkv
+
+replace github.com/ggerganov/whisper.cpp => /Users/dave/projects/LocalAI/whisper.cpp
+
+replace github.com/go-skynet/go-bert.cpp => /Users/dave/projects/LocalAI/go-bert
+
+replace github.com/go-skynet/bloomz.cpp => /Users/dave/projects/LocalAI/bloomz
+
+replace github.com/mudler/go-stable-diffusion => /Users/dave/projects/LocalAI/go-stable-diffusion
diff --git a/openai-openapi/localai_model_patches.yaml b/openai-openapi/localai_model_patches.yaml
index bb5defd9..6d8771b2 100644
--- a/openai-openapi/localai_model_patches.yaml
+++ b/openai-openapi/localai_model_patches.yaml
@@ -16,6 +16,9 @@ components:
       seed:
         type: integer
         nullable: true
+      debug:
+        type: boolean
+        nullable: true
     #@overlay/match missing_ok=True
     LocalAITextRequestExtension:
       allOf:
@@ -72,6 +75,11 @@ components:
         #@overlay/match missing_ok=True
         x-localai-extensions:
          $ref: "#/components/schemas/LocalAITextRequestExtension"
+    CreateCompletionRequest:
+      properties:
+        #@overlay/match missing_ok=True
+        x-localai-extensions:
+          $ref: "#/components/schemas/LocalAITextRequestExtension"
     CreateImageRequest:
       properties:
        #@overlay/match missing_ok=True
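
For context, a rough sketch of how the two mappers above (ToModelOptions / ToPredictOptions) would plausibly be consumed once a request has been merged with its registered config. This helper is hypothetical and not part of the patch; the model path and prompt arguments are placeholders, and only llama.New and Predict from the go-llama.cpp bindings referenced in go.mod are assumed.

package apiv2

import (
	llama "github.com/go-skynet/go-llama.cpp"
)

// predictWithConfig is a hypothetical helper (not in this patch) illustrating the
// intended split: ToModelOptions applies at model load time, ToPredictOptions at
// inference time.
func predictWithConfig(modelPath string, prompt string, sc SpecificConfig[CreateChatCompletionRequest]) (string, error) {
	// Load the model with the request-derived load options (f16 memory, context size, ...).
	model, err := llama.New(modelPath, sc.ToModelOptions()...)
	if err != nil {
		return "", err
	}

	// Run the prompt with the request-derived sampling options
	// (temperature, top_p, stop words, x-localai-extensions, ...).
	return model.Predict(prompt, sc.ToPredictOptions()...)
}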