mirror of
https://github.com/mudler/LocalAI.git
synced 2025-05-20 02:24:59 +00:00
feat(template): read jinja templates from gguf files (#4332)
* Read jinja templates as fallback Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Move templating out of model loader Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Test TemplateMessages Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Set role and content from transformers Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Tests: be more flexible Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * More jinja Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Small refactoring and adaptations Signed-off-by: Ettore Di Giacinto <mudler@localai.io> --------- Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
parent
f5e1527a5a
commit
cea5a0ea42
23 changed files with 971 additions and 785 deletions
|
@ -9,8 +9,6 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/templates"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
|
@ -23,7 +21,6 @@ type ModelLoader struct {
|
|||
ModelPath string
|
||||
mu sync.Mutex
|
||||
models map[string]*Model
|
||||
templates *templates.TemplateCache
|
||||
wd *WatchDog
|
||||
}
|
||||
|
||||
|
@ -31,7 +28,6 @@ func NewModelLoader(modelPath string) *ModelLoader {
|
|||
nml := &ModelLoader{
|
||||
ModelPath: modelPath,
|
||||
models: make(map[string]*Model),
|
||||
templates: templates.NewTemplateCache(modelPath),
|
||||
}
|
||||
|
||||
return nml
|
||||
|
|
|
@ -1,52 +0,0 @@
|
|||
package model
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/functions"
|
||||
"github.com/mudler/LocalAI/pkg/templates"
|
||||
)
|
||||
|
||||
// Rather than pass an interface{} to the prompt template:
|
||||
// These are the definitions of all possible variables LocalAI will currently populate for use in a prompt template file
|
||||
// Please note: Not all of these are populated on every endpoint - your template should either be tested for each endpoint you map it to, or tolerant of zero values.
|
||||
type PromptTemplateData struct {
|
||||
SystemPrompt string
|
||||
SuppressSystemPrompt bool // used by chat specifically to indicate that SystemPrompt above should be _ignored_
|
||||
Input string
|
||||
Instruction string
|
||||
Functions []functions.Function
|
||||
MessageIndex int
|
||||
}
|
||||
|
||||
type ChatMessageTemplateData struct {
|
||||
SystemPrompt string
|
||||
Role string
|
||||
RoleName string
|
||||
FunctionName string
|
||||
Content string
|
||||
MessageIndex int
|
||||
Function bool
|
||||
FunctionCall interface{}
|
||||
LastMessage bool
|
||||
}
|
||||
|
||||
const (
|
||||
ChatPromptTemplate templates.TemplateType = iota
|
||||
ChatMessageTemplate
|
||||
CompletionPromptTemplate
|
||||
EditPromptTemplate
|
||||
FunctionsPromptTemplate
|
||||
)
|
||||
|
||||
func (ml *ModelLoader) EvaluateTemplateForPrompt(templateType templates.TemplateType, templateName string, in PromptTemplateData) (string, error) {
|
||||
// TODO: should this check be improved?
|
||||
if templateType == ChatMessageTemplate {
|
||||
return "", fmt.Errorf("invalid templateType: ChatMessage")
|
||||
}
|
||||
return ml.templates.EvaluateTemplate(templateType, templateName, in)
|
||||
}
|
||||
|
||||
func (ml *ModelLoader) EvaluateTemplateForChatMessage(templateName string, messageData ChatMessageTemplateData) (string, error) {
|
||||
return ml.templates.EvaluateTemplate(ChatMessageTemplate, templateName, messageData)
|
||||
}
|
|
@ -1,197 +0,0 @@
|
|||
package model_test
|
||||
|
||||
import (
|
||||
. "github.com/mudler/LocalAI/pkg/model"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
const chatML = `<|im_start|>{{if eq .RoleName "assistant"}}assistant{{else if eq .RoleName "system"}}system{{else if eq .RoleName "tool"}}tool{{else if eq .RoleName "user"}}user{{end}}
|
||||
{{- if .FunctionCall }}
|
||||
<tool_call>
|
||||
{{- else if eq .RoleName "tool" }}
|
||||
<tool_response>
|
||||
{{- end }}
|
||||
{{- if .Content}}
|
||||
{{.Content }}
|
||||
{{- end }}
|
||||
{{- if .FunctionCall}}
|
||||
{{toJson .FunctionCall}}
|
||||
{{- end }}
|
||||
{{- if .FunctionCall }}
|
||||
</tool_call>
|
||||
{{- else if eq .RoleName "tool" }}
|
||||
</tool_response>
|
||||
{{- end }}<|im_end|>`
|
||||
|
||||
const llama3 = `<|start_header_id|>{{if eq .RoleName "assistant"}}assistant{{else if eq .RoleName "system"}}system{{else if eq .RoleName "tool"}}tool{{else if eq .RoleName "user"}}user{{end}}<|end_header_id|>
|
||||
|
||||
{{ if .FunctionCall -}}
|
||||
Function call:
|
||||
{{ else if eq .RoleName "tool" -}}
|
||||
Function response:
|
||||
{{ end -}}
|
||||
{{ if .Content -}}
|
||||
{{.Content -}}
|
||||
{{ else if .FunctionCall -}}
|
||||
{{ toJson .FunctionCall -}}
|
||||
{{ end -}}
|
||||
<|eot_id|>`
|
||||
|
||||
var llama3TestMatch map[string]map[string]interface{} = map[string]map[string]interface{}{
|
||||
"user": {
|
||||
"template": llama3,
|
||||
"expected": "<|start_header_id|>user<|end_header_id|>\n\nA long time ago in a galaxy far, far away...<|eot_id|>",
|
||||
"data": ChatMessageTemplateData{
|
||||
SystemPrompt: "",
|
||||
Role: "user",
|
||||
RoleName: "user",
|
||||
Content: "A long time ago in a galaxy far, far away...",
|
||||
FunctionCall: nil,
|
||||
FunctionName: "",
|
||||
LastMessage: false,
|
||||
Function: false,
|
||||
MessageIndex: 0,
|
||||
},
|
||||
},
|
||||
"assistant": {
|
||||
"template": llama3,
|
||||
"expected": "<|start_header_id|>assistant<|end_header_id|>\n\nA long time ago in a galaxy far, far away...<|eot_id|>",
|
||||
"data": ChatMessageTemplateData{
|
||||
SystemPrompt: "",
|
||||
Role: "assistant",
|
||||
RoleName: "assistant",
|
||||
Content: "A long time ago in a galaxy far, far away...",
|
||||
FunctionCall: nil,
|
||||
FunctionName: "",
|
||||
LastMessage: false,
|
||||
Function: false,
|
||||
MessageIndex: 0,
|
||||
},
|
||||
},
|
||||
"function_call": {
|
||||
"template": llama3,
|
||||
"expected": "<|start_header_id|>assistant<|end_header_id|>\n\nFunction call:\n{\"function\":\"test\"}<|eot_id|>",
|
||||
"data": ChatMessageTemplateData{
|
||||
SystemPrompt: "",
|
||||
Role: "assistant",
|
||||
RoleName: "assistant",
|
||||
Content: "",
|
||||
FunctionCall: map[string]string{"function": "test"},
|
||||
FunctionName: "",
|
||||
LastMessage: false,
|
||||
Function: false,
|
||||
MessageIndex: 0,
|
||||
},
|
||||
},
|
||||
"function_response": {
|
||||
"template": llama3,
|
||||
"expected": "<|start_header_id|>tool<|end_header_id|>\n\nFunction response:\nResponse from tool<|eot_id|>",
|
||||
"data": ChatMessageTemplateData{
|
||||
SystemPrompt: "",
|
||||
Role: "tool",
|
||||
RoleName: "tool",
|
||||
Content: "Response from tool",
|
||||
FunctionCall: nil,
|
||||
FunctionName: "",
|
||||
LastMessage: false,
|
||||
Function: false,
|
||||
MessageIndex: 0,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
var chatMLTestMatch map[string]map[string]interface{} = map[string]map[string]interface{}{
|
||||
"user": {
|
||||
"template": chatML,
|
||||
"expected": "<|im_start|>user\nA long time ago in a galaxy far, far away...<|im_end|>",
|
||||
"data": ChatMessageTemplateData{
|
||||
SystemPrompt: "",
|
||||
Role: "user",
|
||||
RoleName: "user",
|
||||
Content: "A long time ago in a galaxy far, far away...",
|
||||
FunctionCall: nil,
|
||||
FunctionName: "",
|
||||
LastMessage: false,
|
||||
Function: false,
|
||||
MessageIndex: 0,
|
||||
},
|
||||
},
|
||||
"assistant": {
|
||||
"template": chatML,
|
||||
"expected": "<|im_start|>assistant\nA long time ago in a galaxy far, far away...<|im_end|>",
|
||||
"data": ChatMessageTemplateData{
|
||||
SystemPrompt: "",
|
||||
Role: "assistant",
|
||||
RoleName: "assistant",
|
||||
Content: "A long time ago in a galaxy far, far away...",
|
||||
FunctionCall: nil,
|
||||
FunctionName: "",
|
||||
LastMessage: false,
|
||||
Function: false,
|
||||
MessageIndex: 0,
|
||||
},
|
||||
},
|
||||
"function_call": {
|
||||
"template": chatML,
|
||||
"expected": "<|im_start|>assistant\n<tool_call>\n{\"function\":\"test\"}\n</tool_call><|im_end|>",
|
||||
"data": ChatMessageTemplateData{
|
||||
SystemPrompt: "",
|
||||
Role: "assistant",
|
||||
RoleName: "assistant",
|
||||
Content: "",
|
||||
FunctionCall: map[string]string{"function": "test"},
|
||||
FunctionName: "",
|
||||
LastMessage: false,
|
||||
Function: false,
|
||||
MessageIndex: 0,
|
||||
},
|
||||
},
|
||||
"function_response": {
|
||||
"template": chatML,
|
||||
"expected": "<|im_start|>tool\n<tool_response>\nResponse from tool\n</tool_response><|im_end|>",
|
||||
"data": ChatMessageTemplateData{
|
||||
SystemPrompt: "",
|
||||
Role: "tool",
|
||||
RoleName: "tool",
|
||||
Content: "Response from tool",
|
||||
FunctionCall: nil,
|
||||
FunctionName: "",
|
||||
LastMessage: false,
|
||||
Function: false,
|
||||
MessageIndex: 0,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
var _ = Describe("Templates", func() {
|
||||
Context("chat message ChatML", func() {
|
||||
var modelLoader *ModelLoader
|
||||
BeforeEach(func() {
|
||||
modelLoader = NewModelLoader("")
|
||||
})
|
||||
for key := range chatMLTestMatch {
|
||||
foo := chatMLTestMatch[key]
|
||||
It("renders correctly `"+key+"`", func() {
|
||||
templated, err := modelLoader.EvaluateTemplateForChatMessage(foo["template"].(string), foo["data"].(ChatMessageTemplateData))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(templated).To(Equal(foo["expected"]), templated)
|
||||
})
|
||||
}
|
||||
})
|
||||
Context("chat message llama3", func() {
|
||||
var modelLoader *ModelLoader
|
||||
BeforeEach(func() {
|
||||
modelLoader = NewModelLoader("")
|
||||
})
|
||||
for key := range llama3TestMatch {
|
||||
foo := llama3TestMatch[key]
|
||||
It("renders correctly `"+key+"`", func() {
|
||||
templated, err := modelLoader.EvaluateTemplateForChatMessage(foo["template"].(string), foo["data"].(ChatMessageTemplateData))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(templated).To(Equal(foo["expected"]), templated)
|
||||
})
|
||||
}
|
||||
})
|
||||
})
|
|
@ -11,59 +11,41 @@ import (
|
|||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
|
||||
"github.com/Masterminds/sprig/v3"
|
||||
|
||||
"github.com/nikolalohinski/gonja/v2"
|
||||
"github.com/nikolalohinski/gonja/v2/exec"
|
||||
)
|
||||
|
||||
// Keep this in sync with config.TemplateConfig. Is there a more idiomatic way to accomplish this in go?
|
||||
// Technically, order doesn't _really_ matter, but the count must stay in sync, see tests/integration/reflect_test.go
|
||||
type TemplateType int
|
||||
|
||||
type TemplateCache struct {
|
||||
mu sync.Mutex
|
||||
templatesPath string
|
||||
templates map[TemplateType]map[string]*template.Template
|
||||
type templateCache struct {
|
||||
mu sync.Mutex
|
||||
templatesPath string
|
||||
templates map[TemplateType]map[string]*template.Template
|
||||
jinjaTemplates map[TemplateType]map[string]*exec.Template
|
||||
}
|
||||
|
||||
func NewTemplateCache(templatesPath string) *TemplateCache {
|
||||
tc := &TemplateCache{
|
||||
templatesPath: templatesPath,
|
||||
templates: make(map[TemplateType]map[string]*template.Template),
|
||||
func newTemplateCache(templatesPath string) *templateCache {
|
||||
tc := &templateCache{
|
||||
templatesPath: templatesPath,
|
||||
templates: make(map[TemplateType]map[string]*template.Template),
|
||||
jinjaTemplates: make(map[TemplateType]map[string]*exec.Template),
|
||||
}
|
||||
return tc
|
||||
}
|
||||
|
||||
func (tc *TemplateCache) initializeTemplateMapKey(tt TemplateType) {
|
||||
func (tc *templateCache) initializeTemplateMapKey(tt TemplateType) {
|
||||
if _, ok := tc.templates[tt]; !ok {
|
||||
tc.templates[tt] = make(map[string]*template.Template)
|
||||
}
|
||||
}
|
||||
|
||||
func (tc *TemplateCache) EvaluateTemplate(templateType TemplateType, templateName string, in interface{}) (string, error) {
|
||||
tc.mu.Lock()
|
||||
defer tc.mu.Unlock()
|
||||
|
||||
tc.initializeTemplateMapKey(templateType)
|
||||
m, ok := tc.templates[templateType][templateName]
|
||||
if !ok {
|
||||
// return "", fmt.Errorf("template not loaded: %s", templateName)
|
||||
loadErr := tc.loadTemplateIfExists(templateType, templateName)
|
||||
if loadErr != nil {
|
||||
return "", loadErr
|
||||
}
|
||||
m = tc.templates[templateType][templateName] // ok is not important since we check m on the next line, and wealready checked
|
||||
}
|
||||
if m == nil {
|
||||
return "", fmt.Errorf("failed loading a template for %s", templateName)
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
|
||||
if err := m.Execute(&buf, in); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return buf.String(), nil
|
||||
func (tc *templateCache) existsInModelPath(s string) bool {
|
||||
return utils.ExistsInPath(tc.templatesPath, s)
|
||||
}
|
||||
|
||||
func (tc *TemplateCache) loadTemplateIfExists(templateType TemplateType, templateName string) error {
|
||||
func (tc *templateCache) loadTemplateIfExists(templateType TemplateType, templateName string) error {
|
||||
|
||||
// Check if the template was already loaded
|
||||
if _, ok := tc.templates[templateType][templateName]; ok {
|
||||
|
@ -82,6 +64,51 @@ func (tc *TemplateCache) loadTemplateIfExists(templateType TemplateType, templat
|
|||
return fmt.Errorf("template file outside path: %s", file)
|
||||
}
|
||||
|
||||
// can either be a file in the system or a string with the template
|
||||
if tc.existsInModelPath(modelTemplateFile) {
|
||||
d, err := os.ReadFile(file)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dat = string(d)
|
||||
} else {
|
||||
dat = templateName
|
||||
}
|
||||
|
||||
// Parse the template
|
||||
tmpl, err := template.New("prompt").Funcs(sprig.FuncMap()).Parse(dat)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tc.templates[templateType][templateName] = tmpl
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tc *templateCache) initializeJinjaTemplateMapKey(tt TemplateType) {
|
||||
if _, ok := tc.jinjaTemplates[tt]; !ok {
|
||||
tc.jinjaTemplates[tt] = make(map[string]*exec.Template)
|
||||
}
|
||||
}
|
||||
|
||||
func (tc *templateCache) loadJinjaTemplateIfExists(templateType TemplateType, templateName string) error {
|
||||
// Check if the template was already loaded
|
||||
if _, ok := tc.jinjaTemplates[templateType][templateName]; ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if the model path exists
|
||||
// skip any error here - we run anyway if a template does not exist
|
||||
modelTemplateFile := fmt.Sprintf("%s.tmpl", templateName)
|
||||
|
||||
dat := ""
|
||||
file := filepath.Join(tc.templatesPath, modelTemplateFile)
|
||||
|
||||
// Security check
|
||||
if err := utils.VerifyPath(modelTemplateFile, tc.templatesPath); err != nil {
|
||||
return fmt.Errorf("template file outside path: %s", file)
|
||||
}
|
||||
|
||||
// can either be a file in the system or a string with the template
|
||||
if utils.ExistsInPath(tc.templatesPath, modelTemplateFile) {
|
||||
d, err := os.ReadFile(file)
|
||||
|
@ -93,12 +120,65 @@ func (tc *TemplateCache) loadTemplateIfExists(templateType TemplateType, templat
|
|||
dat = templateName
|
||||
}
|
||||
|
||||
// Parse the template
|
||||
tmpl, err := template.New("prompt").Funcs(sprig.FuncMap()).Parse(dat)
|
||||
tmpl, err := gonja.FromString(dat)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tc.templates[templateType][templateName] = tmpl
|
||||
tc.jinjaTemplates[templateType][templateName] = tmpl
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tc *templateCache) evaluateJinjaTemplate(templateType TemplateType, templateNameOrContent string, in map[string]interface{}) (string, error) {
|
||||
tc.mu.Lock()
|
||||
defer tc.mu.Unlock()
|
||||
|
||||
tc.initializeJinjaTemplateMapKey(templateType)
|
||||
m, ok := tc.jinjaTemplates[templateType][templateNameOrContent]
|
||||
if !ok {
|
||||
// return "", fmt.Errorf("template not loaded: %s", templateName)
|
||||
loadErr := tc.loadJinjaTemplateIfExists(templateType, templateNameOrContent)
|
||||
if loadErr != nil {
|
||||
return "", loadErr
|
||||
}
|
||||
m = tc.jinjaTemplates[templateType][templateNameOrContent] // ok is not important since we check m on the next line, and wealready checked
|
||||
}
|
||||
if m == nil {
|
||||
return "", fmt.Errorf("failed loading a template for %s", templateNameOrContent)
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
|
||||
data := exec.NewContext(in)
|
||||
|
||||
if err := m.Execute(&buf, data); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
func (tc *templateCache) evaluateTemplate(templateType TemplateType, templateNameOrContent string, in interface{}) (string, error) {
|
||||
tc.mu.Lock()
|
||||
defer tc.mu.Unlock()
|
||||
|
||||
tc.initializeTemplateMapKey(templateType)
|
||||
m, ok := tc.templates[templateType][templateNameOrContent]
|
||||
if !ok {
|
||||
// return "", fmt.Errorf("template not loaded: %s", templateName)
|
||||
loadErr := tc.loadTemplateIfExists(templateType, templateNameOrContent)
|
||||
if loadErr != nil {
|
||||
return "", loadErr
|
||||
}
|
||||
m = tc.templates[templateType][templateNameOrContent] // ok is not important since we check m on the next line, and wealready checked
|
||||
}
|
||||
if m == nil {
|
||||
return "", fmt.Errorf("failed loading a template for %s", templateNameOrContent)
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
|
||||
if err := m.Execute(&buf, in); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
|
|
@ -1,73 +0,0 @@
|
|||
package templates_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/templates" // Update with your module path
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("TemplateCache", func() {
|
||||
var (
|
||||
templateCache *templates.TemplateCache
|
||||
tempDir string
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
var err error
|
||||
tempDir, err = os.MkdirTemp("", "templates")
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
// Writing example template files
|
||||
err = os.WriteFile(filepath.Join(tempDir, "example.tmpl"), []byte("Hello, {{.Name}}!"), 0600)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
err = os.WriteFile(filepath.Join(tempDir, "empty.tmpl"), []byte(""), 0600)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
templateCache = templates.NewTemplateCache(tempDir)
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
os.RemoveAll(tempDir) // Clean up
|
||||
})
|
||||
|
||||
Describe("EvaluateTemplate", func() {
|
||||
Context("when template is loaded successfully", func() {
|
||||
It("should evaluate the template correctly", func() {
|
||||
result, err := templateCache.EvaluateTemplate(1, "example", map[string]string{"Name": "Gopher"})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(result).To(Equal("Hello, Gopher!"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("when template isn't a file", func() {
|
||||
It("should parse from string", func() {
|
||||
result, err := templateCache.EvaluateTemplate(1, "{{.Name}}", map[string]string{"Name": "Gopher"})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(result).To(Equal("Gopher"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("when template is empty", func() {
|
||||
It("should return an empty string", func() {
|
||||
result, err := templateCache.EvaluateTemplate(1, "empty", nil)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(result).To(Equal(""))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Describe("concurrency", func() {
|
||||
It("should handle multiple concurrent accesses", func(done Done) {
|
||||
go func() {
|
||||
_, _ = templateCache.EvaluateTemplate(1, "example", map[string]string{"Name": "Gopher"})
|
||||
}()
|
||||
go func() {
|
||||
_, _ = templateCache.EvaluateTemplate(1, "example", map[string]string{"Name": "Gopher"})
|
||||
}()
|
||||
close(done)
|
||||
}, 0.1) // timeout in seconds
|
||||
})
|
||||
})
|
295
pkg/templates/evaluator.go
Normal file
295
pkg/templates/evaluator.go
Normal file
|
@ -0,0 +1,295 @@
|
|||
package templates
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/pkg/functions"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// Rather than pass an interface{} to the prompt template:
|
||||
// These are the definitions of all possible variables LocalAI will currently populate for use in a prompt template file
|
||||
// Please note: Not all of these are populated on every endpoint - your template should either be tested for each endpoint you map it to, or tolerant of zero values.
|
||||
type PromptTemplateData struct {
|
||||
SystemPrompt string
|
||||
SuppressSystemPrompt bool // used by chat specifically to indicate that SystemPrompt above should be _ignored_
|
||||
Input string
|
||||
Instruction string
|
||||
Functions []functions.Function
|
||||
MessageIndex int
|
||||
}
|
||||
|
||||
type ChatMessageTemplateData struct {
|
||||
SystemPrompt string
|
||||
Role string
|
||||
RoleName string
|
||||
FunctionName string
|
||||
Content string
|
||||
MessageIndex int
|
||||
Function bool
|
||||
FunctionCall interface{}
|
||||
LastMessage bool
|
||||
}
|
||||
|
||||
const (
|
||||
ChatPromptTemplate TemplateType = iota
|
||||
ChatMessageTemplate
|
||||
CompletionPromptTemplate
|
||||
EditPromptTemplate
|
||||
FunctionsPromptTemplate
|
||||
)
|
||||
|
||||
type Evaluator struct {
|
||||
cache *templateCache
|
||||
}
|
||||
|
||||
func NewEvaluator(modelPath string) *Evaluator {
|
||||
return &Evaluator{
|
||||
cache: newTemplateCache(modelPath),
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Evaluator) EvaluateTemplateForPrompt(templateType TemplateType, config config.BackendConfig, in PromptTemplateData) (string, error) {
|
||||
template := ""
|
||||
|
||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||
if e.cache.existsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
|
||||
template = config.Model
|
||||
}
|
||||
|
||||
switch templateType {
|
||||
case CompletionPromptTemplate:
|
||||
if config.TemplateConfig.Completion != "" {
|
||||
template = config.TemplateConfig.Completion
|
||||
}
|
||||
case EditPromptTemplate:
|
||||
if config.TemplateConfig.Edit != "" {
|
||||
template = config.TemplateConfig.Edit
|
||||
}
|
||||
case ChatPromptTemplate:
|
||||
if config.TemplateConfig.Chat != "" {
|
||||
template = config.TemplateConfig.Chat
|
||||
}
|
||||
case FunctionsPromptTemplate:
|
||||
if config.TemplateConfig.Functions != "" {
|
||||
template = config.TemplateConfig.Functions
|
||||
}
|
||||
}
|
||||
|
||||
if template == "" {
|
||||
return in.Input, nil
|
||||
}
|
||||
|
||||
if config.TemplateConfig.JinjaTemplate {
|
||||
return e.evaluateJinjaTemplateForPrompt(templateType, template, in)
|
||||
}
|
||||
|
||||
return e.cache.evaluateTemplate(templateType, template, in)
|
||||
}
|
||||
|
||||
func (e *Evaluator) evaluateTemplateForChatMessage(templateName string, messageData ChatMessageTemplateData) (string, error) {
|
||||
return e.cache.evaluateTemplate(ChatMessageTemplate, templateName, messageData)
|
||||
}
|
||||
|
||||
func (e *Evaluator) templateJinjaChat(templateName string, messageData []ChatMessageTemplateData, funcs []functions.Function) (string, error) {
|
||||
|
||||
conversation := make(map[string]interface{})
|
||||
messages := make([]map[string]interface{}, len(messageData))
|
||||
|
||||
// convert from ChatMessageTemplateData to what the jinja template expects
|
||||
|
||||
for _, message := range messageData {
|
||||
// TODO: this seems to cover minimum text templates. Can be expanded to cover more complex interactions
|
||||
var data []byte
|
||||
data, _ = json.Marshal(message.FunctionCall)
|
||||
messages = append(messages, map[string]interface{}{
|
||||
"role": message.RoleName,
|
||||
"content": message.Content,
|
||||
"tool_call": string(data),
|
||||
})
|
||||
}
|
||||
|
||||
conversation["messages"] = messages
|
||||
|
||||
// if tools are detected, add these
|
||||
if len(funcs) > 0 {
|
||||
conversation["tools"] = funcs
|
||||
}
|
||||
|
||||
return e.cache.evaluateJinjaTemplate(ChatMessageTemplate, templateName, conversation)
|
||||
}
|
||||
|
||||
func (e *Evaluator) evaluateJinjaTemplateForPrompt(templateType TemplateType, templateName string, in PromptTemplateData) (string, error) {
|
||||
|
||||
conversation := make(map[string]interface{})
|
||||
|
||||
conversation["system_prompt"] = in.SystemPrompt
|
||||
conversation["content"] = in.Input
|
||||
|
||||
return e.cache.evaluateJinjaTemplate(templateType, templateName, conversation)
|
||||
}
|
||||
|
||||
func (e *Evaluator) TemplateMessages(messages []schema.Message, config *config.BackendConfig, funcs []functions.Function, shouldUseFn bool) string {
|
||||
|
||||
if config.TemplateConfig.JinjaTemplate {
|
||||
var messageData []ChatMessageTemplateData
|
||||
for messageIndex, i := range messages {
|
||||
fcall := i.FunctionCall
|
||||
if len(i.ToolCalls) > 0 {
|
||||
fcall = i.ToolCalls
|
||||
}
|
||||
messageData = append(messageData, ChatMessageTemplateData{
|
||||
SystemPrompt: config.SystemPrompt,
|
||||
Role: config.Roles[i.Role],
|
||||
RoleName: i.Role,
|
||||
Content: i.StringContent,
|
||||
FunctionCall: fcall,
|
||||
FunctionName: i.Name,
|
||||
LastMessage: messageIndex == (len(messages) - 1),
|
||||
Function: config.Grammar != "" && (messageIndex == (len(messages) - 1)),
|
||||
MessageIndex: messageIndex,
|
||||
})
|
||||
}
|
||||
|
||||
templatedInput, err := e.templateJinjaChat(config.TemplateConfig.ChatMessage, messageData, funcs)
|
||||
if err == nil {
|
||||
return templatedInput
|
||||
}
|
||||
}
|
||||
|
||||
var predInput string
|
||||
suppressConfigSystemPrompt := false
|
||||
mess := []string{}
|
||||
for messageIndex, i := range messages {
|
||||
var content string
|
||||
role := i.Role
|
||||
|
||||
// if function call, we might want to customize the role so we can display better that the "assistant called a json action"
|
||||
// if an "assistant_function_call" role is defined, we use it, otherwise we use the role that is passed by in the request
|
||||
if (i.FunctionCall != nil || i.ToolCalls != nil) && i.Role == "assistant" {
|
||||
roleFn := "assistant_function_call"
|
||||
r := config.Roles[roleFn]
|
||||
if r != "" {
|
||||
role = roleFn
|
||||
}
|
||||
}
|
||||
r := config.Roles[role]
|
||||
contentExists := i.Content != nil && i.StringContent != ""
|
||||
|
||||
fcall := i.FunctionCall
|
||||
if len(i.ToolCalls) > 0 {
|
||||
fcall = i.ToolCalls
|
||||
}
|
||||
|
||||
// First attempt to populate content via a chat message specific template
|
||||
if config.TemplateConfig.ChatMessage != "" {
|
||||
chatMessageData := ChatMessageTemplateData{
|
||||
SystemPrompt: config.SystemPrompt,
|
||||
Role: r,
|
||||
RoleName: role,
|
||||
Content: i.StringContent,
|
||||
FunctionCall: fcall,
|
||||
FunctionName: i.Name,
|
||||
LastMessage: messageIndex == (len(messages) - 1),
|
||||
Function: config.Grammar != "" && (messageIndex == (len(messages) - 1)),
|
||||
MessageIndex: messageIndex,
|
||||
}
|
||||
templatedChatMessage, err := e.evaluateTemplateForChatMessage(config.TemplateConfig.ChatMessage, chatMessageData)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Interface("message", chatMessageData).Str("template", config.TemplateConfig.ChatMessage).Msg("error processing message with template, skipping")
|
||||
} else {
|
||||
if templatedChatMessage == "" {
|
||||
log.Warn().Msgf("template \"%s\" produced blank output for %+v. Skipping!", config.TemplateConfig.ChatMessage, chatMessageData)
|
||||
continue // TODO: This continue is here intentionally to skip over the line `mess = append(mess, content)` below, and to prevent the sprintf
|
||||
}
|
||||
log.Debug().Msgf("templated message for chat: %s", templatedChatMessage)
|
||||
content = templatedChatMessage
|
||||
}
|
||||
}
|
||||
|
||||
marshalAnyRole := func(f any) {
|
||||
j, err := json.Marshal(f)
|
||||
if err == nil {
|
||||
if contentExists {
|
||||
content += "\n" + fmt.Sprint(r, " ", string(j))
|
||||
} else {
|
||||
content = fmt.Sprint(r, " ", string(j))
|
||||
}
|
||||
}
|
||||
}
|
||||
marshalAny := func(f any) {
|
||||
j, err := json.Marshal(f)
|
||||
if err == nil {
|
||||
if contentExists {
|
||||
content += "\n" + string(j)
|
||||
} else {
|
||||
content = string(j)
|
||||
}
|
||||
}
|
||||
}
|
||||
// If this model doesn't have such a template, or if that template fails to return a value, template at the message level.
|
||||
if content == "" {
|
||||
if r != "" {
|
||||
if contentExists {
|
||||
content = fmt.Sprint(r, i.StringContent)
|
||||
}
|
||||
|
||||
if i.FunctionCall != nil {
|
||||
marshalAnyRole(i.FunctionCall)
|
||||
}
|
||||
if i.ToolCalls != nil {
|
||||
marshalAnyRole(i.ToolCalls)
|
||||
}
|
||||
} else {
|
||||
if contentExists {
|
||||
content = fmt.Sprint(i.StringContent)
|
||||
}
|
||||
if i.FunctionCall != nil {
|
||||
marshalAny(i.FunctionCall)
|
||||
}
|
||||
if i.ToolCalls != nil {
|
||||
marshalAny(i.ToolCalls)
|
||||
}
|
||||
}
|
||||
// Special Handling: System. We care if it was printed at all, not the r branch, so check seperately
|
||||
if contentExists && role == "system" {
|
||||
suppressConfigSystemPrompt = true
|
||||
}
|
||||
}
|
||||
|
||||
mess = append(mess, content)
|
||||
}
|
||||
|
||||
joinCharacter := "\n"
|
||||
if config.TemplateConfig.JoinChatMessagesByCharacter != nil {
|
||||
joinCharacter = *config.TemplateConfig.JoinChatMessagesByCharacter
|
||||
}
|
||||
|
||||
predInput = strings.Join(mess, joinCharacter)
|
||||
log.Debug().Msgf("Prompt (before templating): %s", predInput)
|
||||
|
||||
promptTemplate := ChatPromptTemplate
|
||||
|
||||
if config.TemplateConfig.Functions != "" && shouldUseFn {
|
||||
promptTemplate = FunctionsPromptTemplate
|
||||
}
|
||||
|
||||
templatedInput, err := e.EvaluateTemplateForPrompt(promptTemplate, *config, PromptTemplateData{
|
||||
SystemPrompt: config.SystemPrompt,
|
||||
SuppressSystemPrompt: suppressConfigSystemPrompt,
|
||||
Input: predInput,
|
||||
Functions: funcs,
|
||||
})
|
||||
if err == nil {
|
||||
predInput = templatedInput
|
||||
log.Debug().Msgf("Template found, input modified to: %s", predInput)
|
||||
} else {
|
||||
log.Debug().Msgf("Template failed loading: %s", err.Error())
|
||||
}
|
||||
|
||||
return predInput
|
||||
}
|
253
pkg/templates/evaluator_test.go
Normal file
253
pkg/templates/evaluator_test.go
Normal file
|
@ -0,0 +1,253 @@
|
|||
package templates_test
|
||||
|
||||
import (
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/pkg/functions"
|
||||
. "github.com/mudler/LocalAI/pkg/templates"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
const toolCallJinja = `{{ '<|begin_of_text|>' }}{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{% endif %}{% if system_message is defined %}{{ '<|start_header_id|>system<|end_header_id|>
|
||||
|
||||
' + system_message + '<|eot_id|>' }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|start_header_id|>user<|end_header_id|>
|
||||
|
||||
' + content + '<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
' }}{% elif message['role'] == 'assistant' %}{{ content + '<|eot_id|>' }}{% endif %}{% endfor %}`
|
||||
|
||||
const chatML = `<|im_start|>{{if eq .RoleName "assistant"}}assistant{{else if eq .RoleName "system"}}system{{else if eq .RoleName "tool"}}tool{{else if eq .RoleName "user"}}user{{end}}
|
||||
{{- if .FunctionCall }}
|
||||
<tool_call>
|
||||
{{- else if eq .RoleName "tool" }}
|
||||
<tool_response>
|
||||
{{- end }}
|
||||
{{- if .Content}}
|
||||
{{.Content }}
|
||||
{{- end }}
|
||||
{{- if .FunctionCall}}
|
||||
{{toJson .FunctionCall}}
|
||||
{{- end }}
|
||||
{{- if .FunctionCall }}
|
||||
</tool_call>
|
||||
{{- else if eq .RoleName "tool" }}
|
||||
</tool_response>
|
||||
{{- end }}<|im_end|>`
|
||||
|
||||
const llama3 = `<|start_header_id|>{{if eq .RoleName "assistant"}}assistant{{else if eq .RoleName "system"}}system{{else if eq .RoleName "tool"}}tool{{else if eq .RoleName "user"}}user{{end}}<|end_header_id|>
|
||||
|
||||
{{ if .FunctionCall -}}
|
||||
Function call:
|
||||
{{ else if eq .RoleName "tool" -}}
|
||||
Function response:
|
||||
{{ end -}}
|
||||
{{ if .Content -}}
|
||||
{{.Content -}}
|
||||
{{ else if .FunctionCall -}}
|
||||
{{ toJson .FunctionCall -}}
|
||||
{{ end -}}
|
||||
<|eot_id|>`
|
||||
|
||||
var llama3TestMatch map[string]map[string]interface{} = map[string]map[string]interface{}{
|
||||
"user": {
|
||||
"expected": "<|start_header_id|>user<|end_header_id|>\n\nA long time ago in a galaxy far, far away...<|eot_id|>",
|
||||
"config": &config.BackendConfig{
|
||||
TemplateConfig: config.TemplateConfig{
|
||||
ChatMessage: llama3,
|
||||
},
|
||||
},
|
||||
"functions": []functions.Function{},
|
||||
"shouldUseFn": false,
|
||||
"messages": []schema.Message{
|
||||
{
|
||||
Role: "user",
|
||||
StringContent: "A long time ago in a galaxy far, far away...",
|
||||
},
|
||||
},
|
||||
},
|
||||
"assistant": {
|
||||
"expected": "<|start_header_id|>assistant<|end_header_id|>\n\nA long time ago in a galaxy far, far away...<|eot_id|>",
|
||||
"config": &config.BackendConfig{
|
||||
TemplateConfig: config.TemplateConfig{
|
||||
ChatMessage: llama3,
|
||||
},
|
||||
},
|
||||
"functions": []functions.Function{},
|
||||
"messages": []schema.Message{
|
||||
{
|
||||
Role: "assistant",
|
||||
StringContent: "A long time ago in a galaxy far, far away...",
|
||||
},
|
||||
},
|
||||
"shouldUseFn": false,
|
||||
},
|
||||
"function_call": {
|
||||
|
||||
"expected": "<|start_header_id|>assistant<|end_header_id|>\n\nFunction call:\n{\"function\":\"test\"}<|eot_id|>",
|
||||
"config": &config.BackendConfig{
|
||||
TemplateConfig: config.TemplateConfig{
|
||||
ChatMessage: llama3,
|
||||
},
|
||||
},
|
||||
"functions": []functions.Function{},
|
||||
"messages": []schema.Message{
|
||||
{
|
||||
Role: "assistant",
|
||||
FunctionCall: map[string]string{"function": "test"},
|
||||
},
|
||||
},
|
||||
"shouldUseFn": false,
|
||||
},
|
||||
"function_response": {
|
||||
"expected": "<|start_header_id|>tool<|end_header_id|>\n\nFunction response:\nResponse from tool<|eot_id|>",
|
||||
"config": &config.BackendConfig{
|
||||
TemplateConfig: config.TemplateConfig{
|
||||
ChatMessage: llama3,
|
||||
},
|
||||
},
|
||||
"functions": []functions.Function{},
|
||||
"messages": []schema.Message{
|
||||
{
|
||||
Role: "tool",
|
||||
StringContent: "Response from tool",
|
||||
},
|
||||
},
|
||||
"shouldUseFn": false,
|
||||
},
|
||||
}
|
||||
|
||||
var chatMLTestMatch map[string]map[string]interface{} = map[string]map[string]interface{}{
|
||||
"user": {
|
||||
"expected": "<|im_start|>user\nA long time ago in a galaxy far, far away...<|im_end|>",
|
||||
"config": &config.BackendConfig{
|
||||
TemplateConfig: config.TemplateConfig{
|
||||
ChatMessage: chatML,
|
||||
},
|
||||
},
|
||||
"functions": []functions.Function{},
|
||||
"shouldUseFn": false,
|
||||
"messages": []schema.Message{
|
||||
{
|
||||
Role: "user",
|
||||
StringContent: "A long time ago in a galaxy far, far away...",
|
||||
},
|
||||
},
|
||||
},
|
||||
"assistant": {
|
||||
"expected": "<|im_start|>assistant\nA long time ago in a galaxy far, far away...<|im_end|>",
|
||||
"config": &config.BackendConfig{
|
||||
TemplateConfig: config.TemplateConfig{
|
||||
ChatMessage: chatML,
|
||||
},
|
||||
},
|
||||
"functions": []functions.Function{},
|
||||
"messages": []schema.Message{
|
||||
{
|
||||
Role: "assistant",
|
||||
StringContent: "A long time ago in a galaxy far, far away...",
|
||||
},
|
||||
},
|
||||
"shouldUseFn": false,
|
||||
},
|
||||
"function_call": {
|
||||
"expected": "<|im_start|>assistant\n<tool_call>\n{\"function\":\"test\"}\n</tool_call><|im_end|>",
|
||||
"config": &config.BackendConfig{
|
||||
TemplateConfig: config.TemplateConfig{
|
||||
ChatMessage: chatML,
|
||||
},
|
||||
},
|
||||
"functions": []functions.Function{
|
||||
{
|
||||
Name: "test",
|
||||
Description: "test",
|
||||
Parameters: nil,
|
||||
},
|
||||
},
|
||||
"shouldUseFn": true,
|
||||
"messages": []schema.Message{
|
||||
{
|
||||
Role: "assistant",
|
||||
FunctionCall: map[string]string{"function": "test"},
|
||||
},
|
||||
},
|
||||
},
|
||||
"function_response": {
|
||||
"expected": "<|im_start|>tool\n<tool_response>\nResponse from tool\n</tool_response><|im_end|>",
|
||||
"config": &config.BackendConfig{
|
||||
TemplateConfig: config.TemplateConfig{
|
||||
ChatMessage: chatML,
|
||||
},
|
||||
},
|
||||
"functions": []functions.Function{},
|
||||
"shouldUseFn": false,
|
||||
"messages": []schema.Message{
|
||||
{
|
||||
Role: "tool",
|
||||
StringContent: "Response from tool",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
var jinjaTest map[string]map[string]interface{} = map[string]map[string]interface{}{
|
||||
"user": {
|
||||
"expected": "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nA long time ago in a galaxy far, far away...<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
|
||||
"config": &config.BackendConfig{
|
||||
TemplateConfig: config.TemplateConfig{
|
||||
ChatMessage: toolCallJinja,
|
||||
JinjaTemplate: true,
|
||||
},
|
||||
},
|
||||
"functions": []functions.Function{},
|
||||
"shouldUseFn": false,
|
||||
"messages": []schema.Message{
|
||||
{
|
||||
Role: "user",
|
||||
StringContent: "A long time ago in a galaxy far, far away...",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
var _ = Describe("Templates", func() {
|
||||
Context("chat message ChatML", func() {
|
||||
var evaluator *Evaluator
|
||||
BeforeEach(func() {
|
||||
evaluator = NewEvaluator("")
|
||||
})
|
||||
for key := range chatMLTestMatch {
|
||||
foo := chatMLTestMatch[key]
|
||||
It("renders correctly `"+key+"`", func() {
|
||||
templated := evaluator.TemplateMessages(foo["messages"].([]schema.Message), foo["config"].(*config.BackendConfig), foo["functions"].([]functions.Function), foo["shouldUseFn"].(bool))
|
||||
Expect(templated).To(Equal(foo["expected"]), templated)
|
||||
})
|
||||
}
|
||||
})
|
||||
Context("chat message llama3", func() {
|
||||
var evaluator *Evaluator
|
||||
BeforeEach(func() {
|
||||
evaluator = NewEvaluator("")
|
||||
})
|
||||
for key := range llama3TestMatch {
|
||||
foo := llama3TestMatch[key]
|
||||
It("renders correctly `"+key+"`", func() {
|
||||
templated := evaluator.TemplateMessages(foo["messages"].([]schema.Message), foo["config"].(*config.BackendConfig), foo["functions"].([]functions.Function), foo["shouldUseFn"].(bool))
|
||||
Expect(templated).To(Equal(foo["expected"]), templated)
|
||||
})
|
||||
}
|
||||
})
|
||||
Context("chat message jinja", func() {
|
||||
var evaluator *Evaluator
|
||||
BeforeEach(func() {
|
||||
evaluator = NewEvaluator("")
|
||||
})
|
||||
for key := range jinjaTest {
|
||||
foo := jinjaTest[key]
|
||||
It("renders correctly `"+key+"`", func() {
|
||||
templated := evaluator.TemplateMessages(foo["messages"].([]schema.Message), foo["config"].(*config.BackendConfig), foo["functions"].([]functions.Function), foo["shouldUseFn"].(bool))
|
||||
Expect(templated).To(Equal(foo["expected"]), templated)
|
||||
})
|
||||
}
|
||||
})
|
||||
})
|
Loading…
Add table
Add a link
Reference in a new issue