Mirror of https://github.com/mudler/LocalAI.git
feat(speculative-sampling): allow to specify a draft model in the model config (#1052)
**Description**

This PR fixes #1013. It adds `draft_model` and `n_draft` to the model YAML config in order to load models with speculative sampling. This should be compatible with grammars as well.

Example:

```yaml
backend: llama
context_size: 1024
name: my-model-name
parameters:
  model: foo-bar
n_draft: 16
draft_model: model-name
```

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Parent: 247d85b523
Commit: 8ccf5b2044
12 changed files with 485 additions and 427 deletions
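For background before reading the diff: speculative sampling pairs the main model with a cheaper draft model. The draft proposes `n_draft` tokens, the main model verifies them in a single batched forward pass, and the longest agreed prefix is kept. This commit delegates the actual algorithm to llama.cpp through go-llama's `SpeculativeSampling` binding; the sketch below is only a toy illustration of the greedy acceptance rule, with all identifiers hypothetical and no relation to LocalAI code. (The sampled variant accepts or rejects draft tokens probabilistically so the target distribution is preserved; greedy matching is shown for simplicity.)

```go
package speculative

// model is a hypothetical stand-in for a greedy next-token predictor.
type model interface {
	NextToken(ctx []int) int
}

// decode illustrates greedy speculative decoding: the draft model proposes
// nDraft tokens, the target model re-derives the token at each of those
// positions, and draft tokens are kept only while the two agree. A real
// backend scores the whole proposal in one batched forward pass, which is
// where the speedup comes from; this loop only shows the acceptance rule.
func decode(target, draft model, prompt []int, nDraft, maxTokens int) []int {
	out := append([]int{}, prompt...)
	for len(out)-len(prompt) < maxTokens {
		// Draft phase: cheaply propose nDraft tokens.
		ctx := append([]int{}, out...)
		proposed := make([]int, 0, nDraft)
		for i := 0; i < nDraft; i++ {
			t := draft.NextToken(ctx)
			proposed = append(proposed, t)
			ctx = append(ctx, t)
		}
		if len(proposed) == 0 { // nDraft <= 0: fall back to plain decoding
			out = append(out, target.NextToken(out))
			continue
		}
		// Verify phase: accept draft tokens up to the first disagreement;
		// on a mismatch the target's own token is kept and the rest dropped.
		for _, want := range proposed {
			got := target.NextToken(out)
			out = append(out, got)
			if got != want {
				break
			}
		}
	}
	gen := out[len(prompt):]
	if len(gen) > maxTokens {
		gen = gen[:maxTokens]
	}
	return gen
}
```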
```diff
@@ -14,7 +14,8 @@ import (
 type LLM struct {
     base.SingleThread
 
-    llama *llama.LLama
+    llama      *llama.LLama
+    draftModel *llama.LLama
 }
 
 func (llm *LLM) Load(opts *pb.ModelOptions) error {
@@ -78,7 +79,27 @@ func (llm *LLM) Load(opts *pb.ModelOptions) error {
         llamaOpts = append(llamaOpts, llama.EnabelLowVRAM)
     }
 
+    if opts.DraftModel != "" {
+        // https://github.com/ggerganov/llama.cpp/blob/71ca2fad7d6c0ef95ef9944fb3a1a843e481f314/examples/speculative/speculative.cpp#L40
+        llamaOpts = append(llamaOpts, llama.SetPerplexity(true))
+    }
+
     model, err := llama.New(opts.ModelFile, llamaOpts...)
+
+    if opts.DraftModel != "" {
+        // opts.DraftModel is relative to opts.ModelFile, so we need to get the basepath of opts.ModelFile
+        if !filepath.IsAbs(opts.DraftModel) {
+            dir := filepath.Dir(opts.ModelFile)
+            opts.DraftModel = filepath.Join(dir, opts.DraftModel)
+        }
+
+        draftModel, err := llama.New(opts.DraftModel, llamaOpts...)
+        if err != nil {
+            return err
+        }
+        llm.draftModel = draftModel
+    }
+
     llm.llama = model
 
     return err
@@ -162,6 +183,9 @@ func buildPredictOptions(opts *pb.PredictOptions) []llama.PredictOption {
         predictOptions = append(predictOptions, llama.SetSeed(int(opts.Seed)))
     }
 
+    if opts.NDraft != 0 {
+        predictOptions = append(predictOptions, llama.SetNDraft(int(opts.NDraft)))
+    }
     //predictOptions = append(predictOptions, llama.SetLogitBias(c.Seed))
 
     predictOptions = append(predictOptions, llama.SetFrequencyPenalty(opts.FrequencyPenalty))
@@ -175,6 +199,9 @@ func buildPredictOptions(opts *pb.PredictOptions) []llama.PredictOption {
 }
 
 func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) {
+    if llm.draftModel != nil {
+        return llm.llama.SpeculativeSampling(llm.draftModel, opts.Prompt, buildPredictOptions(opts)...)
+    }
     return llm.llama.Predict(opts.Prompt, buildPredictOptions(opts)...)
 }
 
@@ -187,7 +214,13 @@ func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) erro
     }))
 
     go func() {
-        _, err := llm.llama.Predict(opts.Prompt, predictOptions...)
+        var err error
+        if llm.draftModel != nil {
+            _, err = llm.llama.SpeculativeSampling(llm.draftModel, opts.Prompt, buildPredictOptions(opts)...)
+        } else {
+            _, err = llm.llama.Predict(opts.Prompt, predictOptions...)
+        }
+
         if err != nil {
             fmt.Println("err: ", err)
         }
```
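To see how the pieces fit together, here is a hypothetical caller of the new path. It uses only the `LLM` methods and `pb` fields visible in the diff above (`ModelFile`, `DraftModel`, `Prompt`, `NDraft`); the import paths, model files, and prompt are made up for illustration. Note also that the `SetPerplexity(true)` option appears to mirror the linked llama.cpp speculative example, which requires logits for every token of the drafted batch during verification.

```go
package main

import (
	"fmt"
	"log"

	// Hypothetical import paths; the packages correspond to the "llama"
	// backend package and the generated "pb" package seen in the diff.
	llama "github.com/go-skynet/LocalAI/pkg/grpc/llm/llama"
	pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
)

func main() {
	llm := &llama.LLM{}

	// When DraftModel is not an absolute path, Load resolves it relative
	// to ModelFile, as the second hunk above shows.
	if err := llm.Load(&pb.ModelOptions{
		ModelFile:  "/models/llama-2-13b.gguf", // target model (made-up path)
		DraftModel: "llama-160m.gguf",          // draft model (made-up name)
	}); err != nil {
		log.Fatal(err)
	}

	// With draftModel set, Predict routes through SpeculativeSampling;
	// NDraft caps how many tokens the draft proposes per verification step.
	out, err := llm.Predict(&pb.PredictOptions{
		Prompt: "The capital of Italy is",
		NDraft: 16,
	})
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(out)
}
```

`PredictStream` follows the same routing: as the last hunk shows, the streaming goroutine calls `SpeculativeSampling` instead of `Predict` whenever `draftModel` is set.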