feat(speculative-sampling): allow specifying a draft model in the model config (#1052)

**Description**

This PR fixes #1013.

It adds `draft_model` and `n_draft` to the model YAML config so that
models can be loaded with speculative sampling. This should also be
compatible with grammars.

example:

```yaml
backend: llama
context_size: 1024
name: my-model-name
parameters:
  model: foo-bar
n_draft: 16
draft_model: model-name
```

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto 2023-09-14 17:44:16 +02:00 committed by GitHub
parent 247d85b523
commit 8ccf5b2044
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 485 additions and 427 deletions

View file

@ -14,7 +14,8 @@ import (
// LLM holds a loaded llama model and, optionally, a second draft model.
// It embeds base.SingleThread.
type LLM struct {
base.SingleThread
// NOTE(review): the llama field line appears twice below — presumably the
// removed/added pair of a diff whose +/- markers were stripped; confirm
// against the real file before relying on this listing.
llama *llama.LLama
llama *llama.LLama
// draftModel is assigned by Load when opts.DraftModel is non-empty; when it
// is non-nil, Predict calls SpeculativeSampling instead of the plain Predict.
draftModel *llama.LLama
}
func (llm *LLM) Load(opts *pb.ModelOptions) error {
@ -78,7 +79,27 @@ func (llm *LLM) Load(opts *pb.ModelOptions) error {
llamaOpts = append(llamaOpts, llama.EnabelLowVRAM)
}
if opts.DraftModel != "" {
// https://github.com/ggerganov/llama.cpp/blob/71ca2fad7d6c0ef95ef9944fb3a1a843e481f314/examples/speculative/speculative.cpp#L40
llamaOpts = append(llamaOpts, llama.SetPerplexity(true))
}
model, err := llama.New(opts.ModelFile, llamaOpts...)
if opts.DraftModel != "" {
// opts.DraftModel is relative to opts.ModelFile, so we need to get the basepath of opts.ModelFile
if !filepath.IsAbs(opts.DraftModel) {
dir := filepath.Dir(opts.ModelFile)
opts.DraftModel = filepath.Join(dir, opts.DraftModel)
}
draftModel, err := llama.New(opts.DraftModel, llamaOpts...)
if err != nil {
return err
}
llm.draftModel = draftModel
}
llm.llama = model
return err
@ -162,6 +183,9 @@ func buildPredictOptions(opts *pb.PredictOptions) []llama.PredictOption {
predictOptions = append(predictOptions, llama.SetSeed(int(opts.Seed)))
}
if opts.NDraft != 0 {
predictOptions = append(predictOptions, llama.SetNDraft(int(opts.NDraft)))
}
//predictOptions = append(predictOptions, llama.SetLogitBias(c.Seed))
predictOptions = append(predictOptions, llama.SetFrequencyPenalty(opts.FrequencyPenalty))
@ -175,6 +199,9 @@ func buildPredictOptions(opts *pb.PredictOptions) []llama.PredictOption {
}
// Predict runs a single, non-streaming completion for opts.Prompt and returns
// the generated text. Prediction options are derived from opts via
// buildPredictOptions. When a draft model has been loaded, generation goes
// through SpeculativeSampling with that draft model; otherwise the main
// model's regular Predict is used.
func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) {
	predictOptions := buildPredictOptions(opts)

	// No draft model configured: use the regular prediction path.
	if llm.draftModel == nil {
		return llm.llama.Predict(opts.Prompt, predictOptions...)
	}

	return llm.llama.SpeculativeSampling(llm.draftModel, opts.Prompt, predictOptions...)
}
@ -187,7 +214,13 @@ func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) erro
}))
go func() {
_, err := llm.llama.Predict(opts.Prompt, predictOptions...)
var err error
if llm.draftModel != nil {
_, err = llm.llama.SpeculativeSampling(llm.draftModel, opts.Prompt, buildPredictOptions(opts)...)
} else {
_, err = llm.llama.Predict(opts.Prompt, predictOptions...)
}
if err != nil {
fmt.Println("err: ", err)
}