Small refactoring and adaptations

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto 2024-12-08 12:02:41 +01:00
parent 614bb5e542
commit f47f344836
6 changed files with 115 additions and 165 deletions

View file

@ -18,7 +18,7 @@ func newApplication(appConfig *config.ApplicationConfig) *Application {
backendLoader: config.NewBackendConfigLoader(appConfig.ModelPath),
modelLoader: model.NewModelLoader(appConfig.ModelPath),
applicationConfig: appConfig,
templatesEvaluator: templates.NewEvaluator(templates.NewTemplateCache(appConfig.ModelPath)),
templatesEvaluator: templates.NewEvaluator(appConfig.ModelPath),
}
}

View file

@ -303,7 +303,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
predInput = evaluator.TemplateMessages(input.Messages, config, funcs, shouldUseFn)
log.Debug().Msgf("Prompt (after templating): %s", predInput)
if shouldUseFn && config.Grammar != "" {
if config.Grammar != "" {
log.Debug().Msgf("Grammar: %+v", config.Grammar)
}
}

View file

@ -20,15 +20,15 @@ import (
// Technically, order doesn't _really_ matter, but the count must stay in sync, see tests/integration/reflect_test.go
type TemplateType int
type TemplateCache struct {
type templateCache struct {
mu sync.Mutex
templatesPath string
templates map[TemplateType]map[string]*template.Template
jinjaTemplates map[TemplateType]map[string]*exec.Template
}
func NewTemplateCache(templatesPath string) *TemplateCache {
tc := &TemplateCache{
func newTemplateCache(templatesPath string) *templateCache {
tc := &templateCache{
templatesPath: templatesPath,
templates: make(map[TemplateType]map[string]*template.Template),
jinjaTemplates: make(map[TemplateType]map[string]*exec.Template),
@ -36,43 +36,16 @@ func NewTemplateCache(templatesPath string) *TemplateCache {
return tc
}
func (tc *TemplateCache) initializeTemplateMapKey(tt TemplateType) {
func (tc *templateCache) initializeTemplateMapKey(tt TemplateType) {
if _, ok := tc.templates[tt]; !ok {
tc.templates[tt] = make(map[string]*template.Template)
}
}
func (tc *TemplateCache) ExistsInModelPath(s string) bool {
func (tc *templateCache) existsInModelPath(s string) bool {
return utils.ExistsInPath(tc.templatesPath, s)
}
func (tc *TemplateCache) EvaluateTemplate(templateType TemplateType, templateNameOrContent string, in interface{}) (string, error) {
tc.mu.Lock()
defer tc.mu.Unlock()
tc.initializeTemplateMapKey(templateType)
m, ok := tc.templates[templateType][templateNameOrContent]
if !ok {
// return "", fmt.Errorf("template not loaded: %s", templateName)
loadErr := tc.loadTemplateIfExists(templateType, templateNameOrContent)
if loadErr != nil {
return "", loadErr
}
m = tc.templates[templateType][templateNameOrContent] // ok is not important since we check m on the next line, and wealready checked
}
if m == nil {
return "", fmt.Errorf("failed loading a template for %s", templateNameOrContent)
}
var buf bytes.Buffer
if err := m.Execute(&buf, in); err != nil {
return "", err
}
return buf.String(), nil
}
func (tc *TemplateCache) loadTemplateIfExists(templateType TemplateType, templateName string) error {
func (tc *templateCache) loadTemplateIfExists(templateType TemplateType, templateName string) error {
// Check if the template was already loaded
if _, ok := tc.templates[templateType][templateName]; ok {
@ -92,7 +65,7 @@ func (tc *TemplateCache) loadTemplateIfExists(templateType TemplateType, templat
}
// can either be a file in the system or a string with the template
if tc.ExistsInModelPath(modelTemplateFile) {
if tc.existsInModelPath(modelTemplateFile) {
d, err := os.ReadFile(file)
if err != nil {
return err
@ -112,41 +85,13 @@ func (tc *TemplateCache) loadTemplateIfExists(templateType TemplateType, templat
return nil
}
func (tc *TemplateCache) initializeJinjaTemplateMapKey(tt TemplateType) {
func (tc *templateCache) initializeJinjaTemplateMapKey(tt TemplateType) {
if _, ok := tc.jinjaTemplates[tt]; !ok {
tc.jinjaTemplates[tt] = make(map[string]*exec.Template)
}
}
func (tc *TemplateCache) EvaluateJinjaTemplate(templateType TemplateType, templateNameOrContent string, in map[string]interface{}) (string, error) {
tc.mu.Lock()
defer tc.mu.Unlock()
tc.initializeJinjaTemplateMapKey(templateType)
m, ok := tc.jinjaTemplates[templateType][templateNameOrContent]
if !ok {
// return "", fmt.Errorf("template not loaded: %s", templateName)
loadErr := tc.loadJinjaTemplateIfExists(templateType, templateNameOrContent)
if loadErr != nil {
return "", loadErr
}
m = tc.jinjaTemplates[templateType][templateNameOrContent] // ok is not important since we check m on the next line, and wealready checked
}
if m == nil {
return "", fmt.Errorf("failed loading a template for %s", templateNameOrContent)
}
var buf bytes.Buffer
data := exec.NewContext(in)
if err := m.Execute(&buf, data); err != nil {
return "", err
}
return buf.String(), nil
}
func (tc *TemplateCache) loadJinjaTemplateIfExists(templateType TemplateType, templateName string) error {
func (tc *templateCache) loadJinjaTemplateIfExists(templateType TemplateType, templateName string) error {
// Check if the template was already loaded
if _, ok := tc.jinjaTemplates[templateType][templateName]; ok {
return nil
@ -183,3 +128,57 @@ func (tc *TemplateCache) loadJinjaTemplateIfExists(templateType TemplateType, te
return nil
}
func (tc *templateCache) evaluateJinjaTemplate(templateType TemplateType, templateNameOrContent string, in map[string]interface{}) (string, error) {
tc.mu.Lock()
defer tc.mu.Unlock()
tc.initializeJinjaTemplateMapKey(templateType)
m, ok := tc.jinjaTemplates[templateType][templateNameOrContent]
if !ok {
// return "", fmt.Errorf("template not loaded: %s", templateName)
loadErr := tc.loadJinjaTemplateIfExists(templateType, templateNameOrContent)
if loadErr != nil {
return "", loadErr
}
m = tc.jinjaTemplates[templateType][templateNameOrContent] // ok is not important since we check m on the next line, and wealready checked
}
if m == nil {
return "", fmt.Errorf("failed loading a template for %s", templateNameOrContent)
}
var buf bytes.Buffer
data := exec.NewContext(in)
if err := m.Execute(&buf, data); err != nil {
return "", err
}
return buf.String(), nil
}
func (tc *templateCache) evaluateTemplate(templateType TemplateType, templateNameOrContent string, in interface{}) (string, error) {
tc.mu.Lock()
defer tc.mu.Unlock()
tc.initializeTemplateMapKey(templateType)
m, ok := tc.templates[templateType][templateNameOrContent]
if !ok {
// return "", fmt.Errorf("template not loaded: %s", templateName)
loadErr := tc.loadTemplateIfExists(templateType, templateNameOrContent)
if loadErr != nil {
return "", loadErr
}
m = tc.templates[templateType][templateNameOrContent] // ok is not important since we check m on the next line, and wealready checked
}
if m == nil {
return "", fmt.Errorf("failed loading a template for %s", templateNameOrContent)
}
var buf bytes.Buffer
if err := m.Execute(&buf, in); err != nil {
return "", err
}
return buf.String(), nil
}

View file

@ -1,89 +0,0 @@
package templates_test
import (
"os"
"path/filepath"
"github.com/mudler/LocalAI/pkg/templates" // Update with your module path
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
const jinjaTemplate = `{{ '<|begin_of_text|>' }}{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{% endif %}{% if system_message is defined %}{{ '<|start_header_id|>system<|end_header_id|>
' + system_message + '<|eot_id|>' }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|start_header_id|>user<|end_header_id|>
' + content + '<|eot_id|><|start_header_id|>assistant<|end_header_id|>
' }}{% elif message['role'] == 'assistant' %}{{ content + '<|eot_id|>' }}{% endif %}{% endfor %}`
var _ = Describe("TemplateCache", func() {
var (
templateCache *templates.TemplateCache
tempDir string
)
BeforeEach(func() {
var err error
tempDir, err = os.MkdirTemp("", "templates")
Expect(err).NotTo(HaveOccurred())
// Writing example template files
err = os.WriteFile(filepath.Join(tempDir, "example.tmpl"), []byte("Hello, {{.Name}}!"), 0600)
Expect(err).NotTo(HaveOccurred())
err = os.WriteFile(filepath.Join(tempDir, "empty.tmpl"), []byte(""), 0600)
Expect(err).NotTo(HaveOccurred())
templateCache = templates.NewTemplateCache(tempDir)
})
AfterEach(func() {
os.RemoveAll(tempDir) // Clean up
})
Describe("EvaluateTemplate", func() {
Context("when template is loaded successfully", func() {
It("should evaluate the template correctly", func() {
result, err := templateCache.EvaluateTemplate(1, "example", map[string]string{"Name": "Gopher"})
Expect(err).NotTo(HaveOccurred())
Expect(result).To(Equal("Hello, Gopher!"))
})
})
Context("when template isn't a file", func() {
It("should parse from string", func() {
result, err := templateCache.EvaluateTemplate(1, "{{.Name}}", map[string]string{"Name": "Gopher"})
Expect(err).ToNot(HaveOccurred())
Expect(result).To(Equal("Gopher"))
})
})
Context("when template is empty", func() {
It("should return an empty string", func() {
result, err := templateCache.EvaluateTemplate(1, "empty", nil)
Expect(err).NotTo(HaveOccurred())
Expect(result).To(Equal(""))
})
})
Context("when template is jinja2", func() {
It("should parse from string", func() {
result, err := templateCache.EvaluateJinjaTemplate(1, jinjaTemplate, map[string]interface{}{"messages": []map[string]interface{}{{"role": "user", "content": "Hello, Gopher!"}}})
Expect(err).ToNot(HaveOccurred())
Expect(result).To(Equal("<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHello, Gopher!<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"))
})
})
})
Describe("concurrency", func() {
It("should handle multiple concurrent accesses", func(done Done) {
go func() {
_, _ = templateCache.EvaluateTemplate(1, "example", map[string]string{"Name": "Gopher"})
}()
go func() {
_, _ = templateCache.EvaluateTemplate(1, "example", map[string]string{"Name": "Gopher"})
}()
close(done)
}, 0.1) // timeout in seconds
})
})

View file

@ -44,12 +44,12 @@ const (
)
type Evaluator struct {
cache *TemplateCache
cache *templateCache
}
func NewEvaluator(cache *TemplateCache) *Evaluator {
func NewEvaluator(modelPath string) *Evaluator {
return &Evaluator{
cache: cache,
cache: newTemplateCache(modelPath),
}
}
@ -57,7 +57,7 @@ func (e *Evaluator) EvaluateTemplateForPrompt(templateType TemplateType, config
template := ""
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
if e.cache.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
if e.cache.existsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
template = config.Model
}
@ -88,11 +88,11 @@ func (e *Evaluator) EvaluateTemplateForPrompt(templateType TemplateType, config
return e.evaluateJinjaTemplateForPrompt(templateType, template, in)
}
return e.cache.EvaluateTemplate(templateType, template, in)
return e.cache.evaluateTemplate(templateType, template, in)
}
func (e *Evaluator) evaluateTemplateForChatMessage(templateName string, messageData ChatMessageTemplateData) (string, error) {
return e.cache.EvaluateTemplate(ChatMessageTemplate, templateName, messageData)
return e.cache.evaluateTemplate(ChatMessageTemplate, templateName, messageData)
}
func (e *Evaluator) templateJinjaChat(templateName string, messageData []ChatMessageTemplateData, funcs []functions.Function) (string, error) {
@ -120,7 +120,7 @@ func (e *Evaluator) templateJinjaChat(templateName string, messageData []ChatMes
conversation["tools"] = funcs
}
return e.cache.EvaluateJinjaTemplate(ChatMessageTemplate, templateName, conversation)
return e.cache.evaluateJinjaTemplate(ChatMessageTemplate, templateName, conversation)
}
func (e *Evaluator) evaluateJinjaTemplateForPrompt(templateType TemplateType, templateName string, in PromptTemplateData) (string, error) {
@ -130,7 +130,7 @@ func (e *Evaluator) evaluateJinjaTemplateForPrompt(templateType TemplateType, te
conversation["system_prompt"] = in.SystemPrompt
conversation["content"] = in.Input
return e.cache.EvaluateJinjaTemplate(templateType, templateName, conversation)
return e.cache.evaluateJinjaTemplate(templateType, templateName, conversation)
}
func (e *Evaluator) TemplateMessages(messages []schema.Message, config *config.BackendConfig, funcs []functions.Function, shouldUseFn bool) string {

View file

@ -10,6 +10,14 @@ import (
. "github.com/onsi/gomega"
)
const toolCallJinja = `{{ '<|begin_of_text|>' }}{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{% endif %}{% if system_message is defined %}{{ '<|start_header_id|>system<|end_header_id|>
' + system_message + '<|eot_id|>' }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|start_header_id|>user<|end_header_id|>
' + content + '<|eot_id|><|start_header_id|>assistant<|end_header_id|>
' }}{% elif message['role'] == 'assistant' %}{{ content + '<|eot_id|>' }}{% endif %}{% endfor %}`
const chatML = `<|im_start|>{{if eq .RoleName "assistant"}}assistant{{else if eq .RoleName "system"}}system{{else if eq .RoleName "tool"}}tool{{else if eq .RoleName "user"}}user{{end}}
{{- if .FunctionCall }}
<tool_call>
@ -183,11 +191,30 @@ var chatMLTestMatch map[string]map[string]interface{} = map[string]map[string]in
},
}
var jinjaTest map[string]map[string]interface{} = map[string]map[string]interface{}{
"user": {
"expected": "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nA long time ago in a galaxy far, far away...<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
"config": &config.BackendConfig{
TemplateConfig: config.TemplateConfig{
ChatMessage: toolCallJinja,
JinjaTemplate: true,
},
},
"functions": []functions.Function{},
"shouldUseFn": false,
"messages": []schema.Message{
{
Role: "user",
StringContent: "A long time ago in a galaxy far, far away...",
},
},
},
}
var _ = Describe("Templates", func() {
Context("chat message ChatML", func() {
var evaluator *Evaluator
BeforeEach(func() {
evaluator = NewEvaluator(NewTemplateCache(""))
evaluator = NewEvaluator("")
})
for key := range chatMLTestMatch {
foo := chatMLTestMatch[key]
@ -200,7 +227,7 @@ var _ = Describe("Templates", func() {
Context("chat message llama3", func() {
var evaluator *Evaluator
BeforeEach(func() {
evaluator = NewEvaluator(NewTemplateCache(""))
evaluator = NewEvaluator("")
})
for key := range llama3TestMatch {
foo := llama3TestMatch[key]
@ -210,4 +237,17 @@ var _ = Describe("Templates", func() {
})
}
})
Context("chat message jinja", func() {
var evaluator *Evaluator
BeforeEach(func() {
evaluator = NewEvaluator("")
})
for key := range jinjaTest {
foo := jinjaTest[key]
It("renders correctly `"+key+"`", func() {
templated := evaluator.TemplateMessages(foo["messages"].([]schema.Message), foo["config"].(*config.BackendConfig), foo["functions"].([]functions.Function), foo["shouldUseFn"].(bool))
Expect(templated).To(Equal(foo["expected"]), templated)
})
}
})
})