mirror of
https://github.com/mudler/LocalAI.git
synced 2025-06-29 22:20:43 +00:00
feat: extend model configuration for llama.cpp
Allow to set `mmap`, `mlock`, `frequency_penalty`, `tfz`, `tensor_split`, `main_gpu`.
This commit is contained in:
parent
694dd4ad9e
commit
3b5df11881
6 changed files with 58 additions and 30 deletions
2
Makefile
2
Makefile
|
@ -3,7 +3,7 @@ GOTEST=$(GOCMD) test
|
|||
GOVET=$(GOCMD) vet
|
||||
BINARY_NAME=local-ai
|
||||
|
||||
GOLLAMA_VERSION?=37ef81d01ae0848575e416e48b41d112ef0d520e
|
||||
GOLLAMA_VERSION?=9a7c56ad810067660cb06b26ccf8a3ac9dd2b21d
|
||||
GPT4ALL_REPO?=https://github.com/go-skynet/gpt4all
|
||||
GPT4ALL_VERSION?=f7498c9
|
||||
GOGGMLTRANSFORMERS_VERSION?=bd765bb6f3b38a63f915f3725e488aad492eedd4
|
||||
|
|
|
@ -33,6 +33,11 @@ type Config struct {
|
|||
MirostatTAU float64 `yaml:"mirostat_tau"`
|
||||
Mirostat int `yaml:"mirostat"`
|
||||
NGPULayers int `yaml:"gpu_layers"`
|
||||
MMap bool `yaml:"mmap"`
|
||||
MMlock bool `yaml:"mmlock"`
|
||||
|
||||
TensorSplit string `yaml:"tensor_split"`
|
||||
MainGPU string `yaml:"main_gpu"`
|
||||
ImageGenerationAssets string `yaml:"asset_dir"`
|
||||
|
||||
PromptCachePath string `yaml:"prompt_cache_path"`
|
||||
|
@ -53,6 +58,12 @@ type ConfigMerger struct {
|
|||
sync.Mutex
|
||||
}
|
||||
|
||||
func defaultConfig(modelFile string) *Config {
|
||||
return &Config{
|
||||
OpenAIRequest: defaultRequest(modelFile),
|
||||
}
|
||||
}
|
||||
|
||||
func NewConfigMerger() *ConfigMerger {
|
||||
return &ConfigMerger{
|
||||
configs: make(map[string]Config),
|
||||
|
@ -308,13 +319,11 @@ func readConfig(modelFile string, input *OpenAIRequest, cm *ConfigMerger, loader
|
|||
var config *Config
|
||||
cfg, exists := cm.GetConfig(modelFile)
|
||||
if !exists {
|
||||
config = &Config{
|
||||
OpenAIRequest: defaultRequest(modelFile),
|
||||
ContextSize: ctx,
|
||||
Threads: threads,
|
||||
F16: f16,
|
||||
Debug: debug,
|
||||
}
|
||||
config = defaultConfig(modelFile)
|
||||
config.ContextSize = ctx
|
||||
config.Threads = threads
|
||||
config.F16 = f16
|
||||
config.Debug = debug
|
||||
} else {
|
||||
config = &cfg
|
||||
}
|
||||
|
|
|
@ -4,8 +4,8 @@ import (
|
|||
"bufio"
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
|
@ -125,6 +125,9 @@ type OpenAIRequest struct {
|
|||
MirostatTAU float64 `json:"mirostat_tau" yaml:"mirostat_tau"`
|
||||
Mirostat int `json:"mirostat" yaml:"mirostat"`
|
||||
|
||||
FrequencyPenalty float64 `json:"frequency_penalty" yaml:"frequency_penalty"`
|
||||
TFZ float64 `json:"tfz" yaml:"tfz"`
|
||||
|
||||
Seed int `json:"seed" yaml:"seed"`
|
||||
|
||||
// Image (not supported by OpenAI)
|
||||
|
@ -191,7 +194,7 @@ func completionEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
|
|||
}
|
||||
|
||||
if input.Stream {
|
||||
if (len(config.PromptStrings) > 1) {
|
||||
if len(config.PromptStrings) > 1 {
|
||||
return errors.New("cannot handle more than 1 `PromptStrings` when `Stream`ing")
|
||||
}
|
||||
|
||||
|
|
|
@ -39,6 +39,12 @@ func defaultLLamaOpts(c Config) []llama.ModelOption {
|
|||
llamaOpts = append(llamaOpts, llama.SetGPULayers(c.NGPULayers))
|
||||
}
|
||||
|
||||
llamaOpts = append(llamaOpts, llama.SetMMap(c.MMap))
|
||||
llamaOpts = append(llamaOpts, llama.SetMainGPU(c.MainGPU))
|
||||
llamaOpts = append(llamaOpts, llama.SetTensorSplit(c.TensorSplit))
|
||||
if c.Batch != 0 {
|
||||
llamaOpts = append(llamaOpts, llama.SetNBatch(c.Batch))
|
||||
}
|
||||
return llamaOpts
|
||||
}
|
||||
|
||||
|
@ -217,6 +223,15 @@ func buildLLamaPredictOptions(c Config, modelPath string) []llama.PredictOption
|
|||
predictOptions = append(predictOptions, llama.SetSeed(c.Seed))
|
||||
}
|
||||
|
||||
//predictOptions = append(predictOptions, llama.SetLogitBias(c.Seed))
|
||||
|
||||
predictOptions = append(predictOptions, llama.SetFrequencyPenalty(c.FrequencyPenalty))
|
||||
predictOptions = append(predictOptions, llama.SetMlock(c.MMlock))
|
||||
predictOptions = append(predictOptions, llama.SetMemoryMap(c.MMap))
|
||||
predictOptions = append(predictOptions, llama.SetPredictionMainGPU(c.MainGPU))
|
||||
predictOptions = append(predictOptions, llama.SetPredictionTensorSplit(c.TensorSplit))
|
||||
predictOptions = append(predictOptions, llama.SetTailFreeSamplingZ(c.TFZ))
|
||||
|
||||
return predictOptions
|
||||
}
|
||||
|
||||
|
|
4
go.mod
4
go.mod
|
@ -9,7 +9,7 @@ require (
|
|||
github.com/go-skynet/bloomz.cpp v0.0.0-20230529155654-1834e77b83fa
|
||||
github.com/go-skynet/go-bert.cpp v0.0.0-20230531070950-0548994371f7
|
||||
github.com/go-skynet/go-ggml-transformers.cpp v0.0.0-20230606131358-bd765bb6f3b3
|
||||
github.com/go-skynet/go-llama.cpp v0.0.0-20230606152241-37ef81d01ae0
|
||||
github.com/go-skynet/go-llama.cpp v0.0.0-20230607000523-9a7c56ad8100
|
||||
github.com/gofiber/fiber/v2 v2.46.0
|
||||
github.com/google/uuid v1.3.0
|
||||
github.com/hashicorp/go-multierror v1.1.1
|
||||
|
@ -17,7 +17,7 @@ require (
|
|||
github.com/mudler/go-stable-diffusion v0.0.0-20230605122230-d89260f598af
|
||||
github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230605194130-266f13aee9d8
|
||||
github.com/onsi/ginkgo/v2 v2.9.7
|
||||
github.com/onsi/gomega v1.27.7
|
||||
github.com/onsi/gomega v1.27.8
|
||||
github.com/otiai10/openaigo v1.1.0
|
||||
github.com/rs/zerolog v1.29.1
|
||||
github.com/sashabaranov/go-openai v1.10.0
|
||||
|
|
1
go.sum
1
go.sum
|
@ -201,6 +201,7 @@ github.com/onsi/ginkgo/v2 v2.9.7 h1:06xGQy5www2oN160RtEZoTvnP2sPhEfePYmCDc2szss=
|
|||
github.com/onsi/ginkgo/v2 v2.9.7/go.mod h1:cxrmXWykAwTwhQsJOPfdIDiJ+l2RYq7U8hFU+M/1uw0=
|
||||
github.com/onsi/gomega v1.27.7 h1:fVih9JD6ogIiHUN6ePK7HJidyEDpWGVB5mzM7cWNXoU=
|
||||
github.com/onsi/gomega v1.27.7/go.mod h1:1p8OOlwo2iUUDsHnOrjE5UKYJ+e3W8eQ3qSlRahPmr4=
|
||||
github.com/onsi/gomega v1.27.8/go.mod h1:2J8vzI/s+2shY9XHRApDkdgPo1TKT7P2u6fXeJKFnNQ=
|
||||
github.com/otiai10/mint v1.5.1 h1:XaPLeE+9vGbuyEHem1JNk3bYc7KKqyI/na0/mLd/Kks=
|
||||
github.com/otiai10/mint v1.5.1/go.mod h1:MJm72SBthJjz8qhefc4z1PYEieWmy8Bku7CjcAqyUSM=
|
||||
github.com/otiai10/openaigo v1.1.0 h1:zRvGBqZUW5PCMgdkJNsPVTBd8tOLCMTipXE5wD2pdTg=
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue