diff --git a/api/api_test.go b/api/api_test.go
index 147774df..2da2a7d7 100644
--- a/api/api_test.go
+++ b/api/api_test.go
@@ -30,10 +30,10 @@ import (
 )
 
 type modelApplyRequest struct {
-	ID        string            `json:"id"`
-	URL       string            `json:"url"`
-	Name      string            `json:"name"`
-	Overrides map[string]string `json:"overrides"`
+	ID        string                 `json:"id"`
+	URL       string                 `json:"url"`
+	Name      string                 `json:"name"`
+	Overrides map[string]interface{} `json:"overrides"`
 }
 
 func getModelStatus(url string) (response map[string]interface{}) {
@@ -243,7 +243,7 @@ var _ = Describe("API test", func() {
 			response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
 				URL:  "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml",
 				Name: "bert",
-				Overrides: map[string]string{
+				Overrides: map[string]interface{}{
 					"backend": "llama",
 				},
 			})
@@ -269,7 +269,7 @@ var _ = Describe("API test", func() {
 			response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
 				URL:       "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml",
 				Name:      "bert",
-				Overrides: map[string]string{},
+				Overrides: map[string]interface{}{},
 			})
 
 			Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response))
@@ -297,7 +297,7 @@ var _ = Describe("API test", func() {
 			response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
 				URL:       "github:go-skynet/model-gallery/openllama_3b.yaml",
 				Name:      "openllama_3b",
-				Overrides: map[string]string{"backend": "llama"},
+				Overrides: map[string]interface{}{"backend": "llama", "mmap": true, "f16": true},
 			})
 
 			Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response))
@@ -366,9 +366,8 @@ var _ = Describe("API test", func() {
 		}
 
 		response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
-			URL:       "github:go-skynet/model-gallery/gpt4all-j.yaml",
-			Name:      "gpt4all-j",
-			Overrides: map[string]string{},
+			URL:  "github:go-skynet/model-gallery/gpt4all-j.yaml",
+			Name: "gpt4all-j",
 		})
 
 		Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response))
diff --git a/pkg/grpc/llm/llama/llama.go b/pkg/grpc/llm/llama/llama.go
index 7d867813..2f85e175 100644
--- a/pkg/grpc/llm/llama/llama.go
+++ b/pkg/grpc/llm/llama/llama.go
@@ -58,6 +58,15 @@ func (llm *LLM) Load(opts *pb.ModelOptions) error {
 }
 
 func buildPredictOptions(opts *pb.PredictOptions) []llama.PredictOption {
+	ropeFreqBase := float32(10000)
+	ropeFreqScale := float32(1)
+
+	if opts.RopeFreqBase != 0 {
+		ropeFreqBase = opts.RopeFreqBase
+	}
+	if opts.RopeFreqScale != 0 {
+		ropeFreqScale = opts.RopeFreqScale
+	}
 	predictOptions := []llama.PredictOption{
 		llama.SetTemperature(opts.Temperature),
 		llama.SetTopP(opts.TopP),
@@ -65,8 +74,8 @@ func buildPredictOptions(opts *pb.PredictOptions) []llama.PredictOption {
 		llama.SetTokens(int(opts.Tokens)),
 		llama.SetThreads(int(opts.Threads)),
 		llama.WithGrammar(opts.Grammar),
-		llama.SetRopeFreqBase(opts.RopeFreqBase),
-		llama.SetRopeFreqScale(opts.RopeFreqScale),
+		llama.SetRopeFreqBase(ropeFreqBase),
+		llama.SetRopeFreqScale(ropeFreqScale),
 		llama.SetNegativePromptScale(opts.NegativePromptScale),
 		llama.SetNegativePrompt(opts.NegativePrompt),
 	}
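
Note on the api_test.go change: the gallery overrides appear to be merged into the model's typed YAML config, whose fields include booleans such as mmap and f16, so a map[string]string cannot express them. A minimal standalone sketch of the difference, assuming only the standard library; the main function and override keys here are illustrative, not part of the patch:

package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	// With map[string]interface{} the request body keeps native JSON types.
	overrides := map[string]interface{}{
		"backend": "llama", // string value
		"mmap":    true,    // boolean value: not expressible in map[string]string
		"f16":     true,
	}
	b, _ := json.Marshal(overrides)
	fmt.Println(string(b)) // {"backend":"llama","f16":true,"mmap":true}
}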
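
Note on the llama.go change: proto3 scalar fields carry no presence information, so an unset RopeFreqBase or RopeFreqScale arrives as 0 and would otherwise be passed through literally. The patch treats 0 as "unset" and falls back to llama.cpp's defaults (rope_freq_base 10000, rope_freq_scale 1). The same fallback pattern as a reusable helper, for illustration only; withDefault is hypothetical and not part of the patch:

// withDefault returns v unless it is the zero value, in which case it
// returns def. Mirrors the fallback logic added to buildPredictOptions.
func withDefault(v, def float32) float32 {
	if v != 0 {
		return v
	}
	return def
}

// Equivalent to the patched code:
//	ropeFreqBase := withDefault(opts.RopeFreqBase, 10000)
//	ropeFreqScale := withDefault(opts.RopeFreqScale, 1)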