Mirror of https://github.com/mudler/LocalAI.git, synced 2025-05-22 03:24:59 +00:00
Commit 9fb581739b (parent 48aca246e3)

Allow templating of model prompt inputs

2 changed files with 59 additions and 8 deletions
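With this change, each model can ship an optional prompt template: before prediction, the API joins the request messages into a single input string and renders it through a Go text/template, falling back to the raw input when no template is registered for the model. As an illustration of what such a template file might contain (the Alpaca-style wording below is an assumption, not part of the commit; only the {{.Input}} placeholder is dictated by the code), a hypothetical gpt4all.bin.tmpl next to the model weights could read:

    Below is an instruction that describes a task. Write a response that appropriately completes the request.

    ### Instruction:
    {{.Input}}

    ### Response: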
api.go | 12 ++++++++++--

@@ -103,10 +103,18 @@ func api(defaultModel *llama.LLama, loader *ModelLoader, listenAddr string, thre
 			mess = append(mess, i.Content)
 		}
 
-		fmt.Println("Received", input, input.Model)
+		predInput := strings.Join(mess, "\n")
+
+		templatedInput, err := loader.TemplatePrefix(input.Model, struct {
+			Input string
+		}{Input: predInput})
+		if err == nil {
+			predInput = templatedInput
+		}
+
 		// Generate the prediction using the language model
 		prediction, err := model.Predict(
-			strings.Join(mess, "\n"),
+			templatedInput,
 			llama.SetTemperature(temperature),
 			llama.SetTopP(topP),
 			llama.SetTopK(topK),
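The api.go side builds on plain text/template execution: the joined messages are injected as .Input. A minimal, runnable sketch of that mechanic (the template wording and the promptTmpl name are mine; only the anonymous-struct shape matches the commit):

    package main

    import (
    	"bytes"
    	"fmt"
    	"text/template"
    )

    func main() {
    	// Stand-in for a model's .tmpl file contents; only the {{.Input}}
    	// placeholder is dictated by the commit, the wording is illustrative.
    	promptTmpl := template.Must(template.New("prompt").Parse(
    		"### Instruction:\n{{.Input}}\n\n### Response:\n"))

    	var buf bytes.Buffer
    	// Same anonymous-struct shape that api.go passes to TemplatePrefix.
    	if err := promptTmpl.Execute(&buf, struct{ Input string }{Input: "Say hello"}); err != nil {
    		panic(err)
    	}
    	fmt.Print(buf.String())
    }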
model_loader.go

@@ -1,30 +1,55 @@
 package main
 
 import (
+	"bytes"
 	"fmt"
 	"os"
 	"path/filepath"
 	"sync"
+	"text/template"
 
 	llama "github.com/go-skynet/go-llama.cpp"
 )
 
 type ModelLoader struct {
 	modelPath        string
 	mu               sync.Mutex
 	models           map[string]*llama.LLama
+	promptsTemplates map[string]*template.Template
 }
 
 func NewModelLoader(modelPath string) *ModelLoader {
-	return &ModelLoader{modelPath: modelPath, models: make(map[string]*llama.LLama)}
+	return &ModelLoader{modelPath: modelPath, models: make(map[string]*llama.LLama), promptsTemplates: make(map[string]*template.Template)}
 }
 
-func (ml *ModelLoader) LoadModel(s string, opts ...llama.ModelOption) (*llama.LLama, error) {
+func (ml *ModelLoader) TemplatePrefix(modelName string, in interface{}) (string, error) {
+	ml.mu.Lock()
+	defer ml.mu.Unlock()
+
+	m, ok := ml.promptsTemplates[modelName]
+	if !ok {
+		// try to find a s.bin
+		modelBin := fmt.Sprintf("%s.bin", modelName)
+		m, ok = ml.promptsTemplates[modelBin]
+		if !ok {
+			return "", fmt.Errorf("no prompt template available")
+		}
+	}
+
+	var buf bytes.Buffer
+
+	if err := m.Execute(&buf, in); err != nil {
+		return "", err
+	}
+	return buf.String(), nil
+}
+
+func (ml *ModelLoader) LoadModel(modelName string, opts ...llama.ModelOption) (*llama.LLama, error) {
 	ml.mu.Lock()
 	defer ml.mu.Unlock()
 
 	// Check if we already have a loaded model
-	modelFile := filepath.Join(ml.modelPath, s)
+	modelFile := filepath.Join(ml.modelPath, modelName)
 
 	if m, ok := ml.models[modelFile]; ok {
 		return m, nil
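TemplatePrefix resolves templates in two steps: it first tries the model name exactly as given, then retries with a .bin suffix before giving up with "no prompt template available". A hedged caller-side sketch, assuming loader and predInput from the api.go context above and a hypothetical model name:

    // Checks promptsTemplates["wizardlm"], then promptsTemplates["wizardlm.bin"].
    out, err := loader.TemplatePrefix("wizardlm", struct{ Input string }{Input: predInput})
    if err != nil {
    	out = predInput // no template under either key: keep the raw prompt
    }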
@@ -47,6 +72,24 @@ func (ml *ModelLoader) LoadModel(s string, opts ...llama.ModelOption) (*llama.LL
 		return nil, err
 	}
 
+	// If there is a prompt template, load it
+
+	modelTemplateFile := fmt.Sprintf("%s.tmpl", modelFile)
+	// Check if the model path exists
+	if _, err := os.Stat(modelTemplateFile); err == nil {
+		dat, err := os.ReadFile(modelTemplateFile)
+		if err != nil {
+			return nil, err
+		}
+
+		// Parse the template
+		tmpl, err := template.New("prompt").Parse(string(dat))
+		if err != nil {
+			return nil, err
+		}
+		ml.promptsTemplates[modelFile] = tmpl
+	}
+
 	ml.models[modelFile] = model
 	return model, err
 }
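On disk, the convention introduced here is a sibling file named after the full model filename plus a .tmpl suffix, parsed once at load time and cached in promptsTemplates; if the file is absent, LoadModel proceeds exactly as before. A hypothetical models directory (filenames are examples, not mandated by the commit):

    models/
    ├── gpt4all.bin         model weights, loaded as before
    └── gpt4all.bin.tmpl    optional prompt template, parsed on load

A broken template cannot silently degrade a loaded model: a parse error aborts LoadModel, while an absent file simply leaves promptsTemplates unpopulated, so TemplatePrefix returns an error and the raw prompt is used.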