mirror of
https://github.com/mudler/LocalAI.git
synced 2025-06-30 06:30:43 +00:00
Small refactoring and adaptations
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
parent
614bb5e542
commit
f47f344836
6 changed files with 115 additions and 165 deletions
|
@ -18,7 +18,7 @@ func newApplication(appConfig *config.ApplicationConfig) *Application {
|
|||
backendLoader: config.NewBackendConfigLoader(appConfig.ModelPath),
|
||||
modelLoader: model.NewModelLoader(appConfig.ModelPath),
|
||||
applicationConfig: appConfig,
|
||||
templatesEvaluator: templates.NewEvaluator(templates.NewTemplateCache(appConfig.ModelPath)),
|
||||
templatesEvaluator: templates.NewEvaluator(appConfig.ModelPath),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -303,7 +303,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
|
|||
predInput = evaluator.TemplateMessages(input.Messages, config, funcs, shouldUseFn)
|
||||
|
||||
log.Debug().Msgf("Prompt (after templating): %s", predInput)
|
||||
if shouldUseFn && config.Grammar != "" {
|
||||
if config.Grammar != "" {
|
||||
log.Debug().Msgf("Grammar: %+v", config.Grammar)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -20,15 +20,15 @@ import (
|
|||
// Technically, order doesn't _really_ matter, but the count must stay in sync, see tests/integration/reflect_test.go
|
||||
type TemplateType int
|
||||
|
||||
type TemplateCache struct {
|
||||
type templateCache struct {
|
||||
mu sync.Mutex
|
||||
templatesPath string
|
||||
templates map[TemplateType]map[string]*template.Template
|
||||
jinjaTemplates map[TemplateType]map[string]*exec.Template
|
||||
}
|
||||
|
||||
func NewTemplateCache(templatesPath string) *TemplateCache {
|
||||
tc := &TemplateCache{
|
||||
func newTemplateCache(templatesPath string) *templateCache {
|
||||
tc := &templateCache{
|
||||
templatesPath: templatesPath,
|
||||
templates: make(map[TemplateType]map[string]*template.Template),
|
||||
jinjaTemplates: make(map[TemplateType]map[string]*exec.Template),
|
||||
|
@ -36,43 +36,16 @@ func NewTemplateCache(templatesPath string) *TemplateCache {
|
|||
return tc
|
||||
}
|
||||
|
||||
func (tc *TemplateCache) initializeTemplateMapKey(tt TemplateType) {
|
||||
func (tc *templateCache) initializeTemplateMapKey(tt TemplateType) {
|
||||
if _, ok := tc.templates[tt]; !ok {
|
||||
tc.templates[tt] = make(map[string]*template.Template)
|
||||
}
|
||||
}
|
||||
|
||||
func (tc *TemplateCache) ExistsInModelPath(s string) bool {
|
||||
func (tc *templateCache) existsInModelPath(s string) bool {
|
||||
return utils.ExistsInPath(tc.templatesPath, s)
|
||||
}
|
||||
|
||||
func (tc *TemplateCache) EvaluateTemplate(templateType TemplateType, templateNameOrContent string, in interface{}) (string, error) {
|
||||
tc.mu.Lock()
|
||||
defer tc.mu.Unlock()
|
||||
|
||||
tc.initializeTemplateMapKey(templateType)
|
||||
m, ok := tc.templates[templateType][templateNameOrContent]
|
||||
if !ok {
|
||||
// return "", fmt.Errorf("template not loaded: %s", templateName)
|
||||
loadErr := tc.loadTemplateIfExists(templateType, templateNameOrContent)
|
||||
if loadErr != nil {
|
||||
return "", loadErr
|
||||
}
|
||||
m = tc.templates[templateType][templateNameOrContent] // ok is not important since we check m on the next line, and wealready checked
|
||||
}
|
||||
if m == nil {
|
||||
return "", fmt.Errorf("failed loading a template for %s", templateNameOrContent)
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
|
||||
if err := m.Execute(&buf, in); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
func (tc *TemplateCache) loadTemplateIfExists(templateType TemplateType, templateName string) error {
|
||||
func (tc *templateCache) loadTemplateIfExists(templateType TemplateType, templateName string) error {
|
||||
|
||||
// Check if the template was already loaded
|
||||
if _, ok := tc.templates[templateType][templateName]; ok {
|
||||
|
@ -92,7 +65,7 @@ func (tc *TemplateCache) loadTemplateIfExists(templateType TemplateType, templat
|
|||
}
|
||||
|
||||
// can either be a file in the system or a string with the template
|
||||
if tc.ExistsInModelPath(modelTemplateFile) {
|
||||
if tc.existsInModelPath(modelTemplateFile) {
|
||||
d, err := os.ReadFile(file)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -112,41 +85,13 @@ func (tc *TemplateCache) loadTemplateIfExists(templateType TemplateType, templat
|
|||
return nil
|
||||
}
|
||||
|
||||
func (tc *TemplateCache) initializeJinjaTemplateMapKey(tt TemplateType) {
|
||||
func (tc *templateCache) initializeJinjaTemplateMapKey(tt TemplateType) {
|
||||
if _, ok := tc.jinjaTemplates[tt]; !ok {
|
||||
tc.jinjaTemplates[tt] = make(map[string]*exec.Template)
|
||||
}
|
||||
}
|
||||
|
||||
func (tc *TemplateCache) EvaluateJinjaTemplate(templateType TemplateType, templateNameOrContent string, in map[string]interface{}) (string, error) {
|
||||
tc.mu.Lock()
|
||||
defer tc.mu.Unlock()
|
||||
|
||||
tc.initializeJinjaTemplateMapKey(templateType)
|
||||
m, ok := tc.jinjaTemplates[templateType][templateNameOrContent]
|
||||
if !ok {
|
||||
// return "", fmt.Errorf("template not loaded: %s", templateName)
|
||||
loadErr := tc.loadJinjaTemplateIfExists(templateType, templateNameOrContent)
|
||||
if loadErr != nil {
|
||||
return "", loadErr
|
||||
}
|
||||
m = tc.jinjaTemplates[templateType][templateNameOrContent] // ok is not important since we check m on the next line, and wealready checked
|
||||
}
|
||||
if m == nil {
|
||||
return "", fmt.Errorf("failed loading a template for %s", templateNameOrContent)
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
|
||||
data := exec.NewContext(in)
|
||||
|
||||
if err := m.Execute(&buf, data); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
func (tc *TemplateCache) loadJinjaTemplateIfExists(templateType TemplateType, templateName string) error {
|
||||
func (tc *templateCache) loadJinjaTemplateIfExists(templateType TemplateType, templateName string) error {
|
||||
// Check if the template was already loaded
|
||||
if _, ok := tc.jinjaTemplates[templateType][templateName]; ok {
|
||||
return nil
|
||||
|
@ -183,3 +128,57 @@ func (tc *TemplateCache) loadJinjaTemplateIfExists(templateType TemplateType, te
|
|||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tc *templateCache) evaluateJinjaTemplate(templateType TemplateType, templateNameOrContent string, in map[string]interface{}) (string, error) {
|
||||
tc.mu.Lock()
|
||||
defer tc.mu.Unlock()
|
||||
|
||||
tc.initializeJinjaTemplateMapKey(templateType)
|
||||
m, ok := tc.jinjaTemplates[templateType][templateNameOrContent]
|
||||
if !ok {
|
||||
// return "", fmt.Errorf("template not loaded: %s", templateName)
|
||||
loadErr := tc.loadJinjaTemplateIfExists(templateType, templateNameOrContent)
|
||||
if loadErr != nil {
|
||||
return "", loadErr
|
||||
}
|
||||
m = tc.jinjaTemplates[templateType][templateNameOrContent] // ok is not important since we check m on the next line, and wealready checked
|
||||
}
|
||||
if m == nil {
|
||||
return "", fmt.Errorf("failed loading a template for %s", templateNameOrContent)
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
|
||||
data := exec.NewContext(in)
|
||||
|
||||
if err := m.Execute(&buf, data); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
func (tc *templateCache) evaluateTemplate(templateType TemplateType, templateNameOrContent string, in interface{}) (string, error) {
|
||||
tc.mu.Lock()
|
||||
defer tc.mu.Unlock()
|
||||
|
||||
tc.initializeTemplateMapKey(templateType)
|
||||
m, ok := tc.templates[templateType][templateNameOrContent]
|
||||
if !ok {
|
||||
// return "", fmt.Errorf("template not loaded: %s", templateName)
|
||||
loadErr := tc.loadTemplateIfExists(templateType, templateNameOrContent)
|
||||
if loadErr != nil {
|
||||
return "", loadErr
|
||||
}
|
||||
m = tc.templates[templateType][templateNameOrContent] // ok is not important since we check m on the next line, and wealready checked
|
||||
}
|
||||
if m == nil {
|
||||
return "", fmt.Errorf("failed loading a template for %s", templateNameOrContent)
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
|
||||
if err := m.Execute(&buf, in); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
|
|
@ -1,89 +0,0 @@
|
|||
package templates_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/templates" // Update with your module path
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
const jinjaTemplate = `{{ '<|begin_of_text|>' }}{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{% endif %}{% if system_message is defined %}{{ '<|start_header_id|>system<|end_header_id|>
|
||||
|
||||
' + system_message + '<|eot_id|>' }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|start_header_id|>user<|end_header_id|>
|
||||
|
||||
' + content + '<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
' }}{% elif message['role'] == 'assistant' %}{{ content + '<|eot_id|>' }}{% endif %}{% endfor %}`
|
||||
|
||||
var _ = Describe("TemplateCache", func() {
|
||||
var (
|
||||
templateCache *templates.TemplateCache
|
||||
tempDir string
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
var err error
|
||||
tempDir, err = os.MkdirTemp("", "templates")
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
// Writing example template files
|
||||
err = os.WriteFile(filepath.Join(tempDir, "example.tmpl"), []byte("Hello, {{.Name}}!"), 0600)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
err = os.WriteFile(filepath.Join(tempDir, "empty.tmpl"), []byte(""), 0600)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
templateCache = templates.NewTemplateCache(tempDir)
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
os.RemoveAll(tempDir) // Clean up
|
||||
})
|
||||
|
||||
Describe("EvaluateTemplate", func() {
|
||||
Context("when template is loaded successfully", func() {
|
||||
It("should evaluate the template correctly", func() {
|
||||
result, err := templateCache.EvaluateTemplate(1, "example", map[string]string{"Name": "Gopher"})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(result).To(Equal("Hello, Gopher!"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("when template isn't a file", func() {
|
||||
It("should parse from string", func() {
|
||||
result, err := templateCache.EvaluateTemplate(1, "{{.Name}}", map[string]string{"Name": "Gopher"})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(result).To(Equal("Gopher"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("when template is empty", func() {
|
||||
It("should return an empty string", func() {
|
||||
result, err := templateCache.EvaluateTemplate(1, "empty", nil)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(result).To(Equal(""))
|
||||
})
|
||||
})
|
||||
|
||||
Context("when template is jinja2", func() {
|
||||
It("should parse from string", func() {
|
||||
result, err := templateCache.EvaluateJinjaTemplate(1, jinjaTemplate, map[string]interface{}{"messages": []map[string]interface{}{{"role": "user", "content": "Hello, Gopher!"}}})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(result).To(Equal("<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHello, Gopher!<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Describe("concurrency", func() {
|
||||
It("should handle multiple concurrent accesses", func(done Done) {
|
||||
go func() {
|
||||
_, _ = templateCache.EvaluateTemplate(1, "example", map[string]string{"Name": "Gopher"})
|
||||
}()
|
||||
go func() {
|
||||
_, _ = templateCache.EvaluateTemplate(1, "example", map[string]string{"Name": "Gopher"})
|
||||
}()
|
||||
close(done)
|
||||
}, 0.1) // timeout in seconds
|
||||
})
|
||||
})
|
|
@ -44,12 +44,12 @@ const (
|
|||
)
|
||||
|
||||
type Evaluator struct {
|
||||
cache *TemplateCache
|
||||
cache *templateCache
|
||||
}
|
||||
|
||||
func NewEvaluator(cache *TemplateCache) *Evaluator {
|
||||
func NewEvaluator(modelPath string) *Evaluator {
|
||||
return &Evaluator{
|
||||
cache: cache,
|
||||
cache: newTemplateCache(modelPath),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -57,7 +57,7 @@ func (e *Evaluator) EvaluateTemplateForPrompt(templateType TemplateType, config
|
|||
template := ""
|
||||
|
||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||
if e.cache.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
|
||||
if e.cache.existsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
|
||||
template = config.Model
|
||||
}
|
||||
|
||||
|
@ -88,11 +88,11 @@ func (e *Evaluator) EvaluateTemplateForPrompt(templateType TemplateType, config
|
|||
return e.evaluateJinjaTemplateForPrompt(templateType, template, in)
|
||||
}
|
||||
|
||||
return e.cache.EvaluateTemplate(templateType, template, in)
|
||||
return e.cache.evaluateTemplate(templateType, template, in)
|
||||
}
|
||||
|
||||
func (e *Evaluator) evaluateTemplateForChatMessage(templateName string, messageData ChatMessageTemplateData) (string, error) {
|
||||
return e.cache.EvaluateTemplate(ChatMessageTemplate, templateName, messageData)
|
||||
return e.cache.evaluateTemplate(ChatMessageTemplate, templateName, messageData)
|
||||
}
|
||||
|
||||
func (e *Evaluator) templateJinjaChat(templateName string, messageData []ChatMessageTemplateData, funcs []functions.Function) (string, error) {
|
||||
|
@ -120,7 +120,7 @@ func (e *Evaluator) templateJinjaChat(templateName string, messageData []ChatMes
|
|||
conversation["tools"] = funcs
|
||||
}
|
||||
|
||||
return e.cache.EvaluateJinjaTemplate(ChatMessageTemplate, templateName, conversation)
|
||||
return e.cache.evaluateJinjaTemplate(ChatMessageTemplate, templateName, conversation)
|
||||
}
|
||||
|
||||
func (e *Evaluator) evaluateJinjaTemplateForPrompt(templateType TemplateType, templateName string, in PromptTemplateData) (string, error) {
|
||||
|
@ -130,7 +130,7 @@ func (e *Evaluator) evaluateJinjaTemplateForPrompt(templateType TemplateType, te
|
|||
conversation["system_prompt"] = in.SystemPrompt
|
||||
conversation["content"] = in.Input
|
||||
|
||||
return e.cache.EvaluateJinjaTemplate(templateType, templateName, conversation)
|
||||
return e.cache.evaluateJinjaTemplate(templateType, templateName, conversation)
|
||||
}
|
||||
|
||||
func (e *Evaluator) TemplateMessages(messages []schema.Message, config *config.BackendConfig, funcs []functions.Function, shouldUseFn bool) string {
|
||||
|
|
|
@ -10,6 +10,14 @@ import (
|
|||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
const toolCallJinja = `{{ '<|begin_of_text|>' }}{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{% endif %}{% if system_message is defined %}{{ '<|start_header_id|>system<|end_header_id|>
|
||||
|
||||
' + system_message + '<|eot_id|>' }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|start_header_id|>user<|end_header_id|>
|
||||
|
||||
' + content + '<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
' }}{% elif message['role'] == 'assistant' %}{{ content + '<|eot_id|>' }}{% endif %}{% endfor %}`
|
||||
|
||||
const chatML = `<|im_start|>{{if eq .RoleName "assistant"}}assistant{{else if eq .RoleName "system"}}system{{else if eq .RoleName "tool"}}tool{{else if eq .RoleName "user"}}user{{end}}
|
||||
{{- if .FunctionCall }}
|
||||
<tool_call>
|
||||
|
@ -183,11 +191,30 @@ var chatMLTestMatch map[string]map[string]interface{} = map[string]map[string]in
|
|||
},
|
||||
}
|
||||
|
||||
var jinjaTest map[string]map[string]interface{} = map[string]map[string]interface{}{
|
||||
"user": {
|
||||
"expected": "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nA long time ago in a galaxy far, far away...<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
|
||||
"config": &config.BackendConfig{
|
||||
TemplateConfig: config.TemplateConfig{
|
||||
ChatMessage: toolCallJinja,
|
||||
JinjaTemplate: true,
|
||||
},
|
||||
},
|
||||
"functions": []functions.Function{},
|
||||
"shouldUseFn": false,
|
||||
"messages": []schema.Message{
|
||||
{
|
||||
Role: "user",
|
||||
StringContent: "A long time ago in a galaxy far, far away...",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
var _ = Describe("Templates", func() {
|
||||
Context("chat message ChatML", func() {
|
||||
var evaluator *Evaluator
|
||||
BeforeEach(func() {
|
||||
evaluator = NewEvaluator(NewTemplateCache(""))
|
||||
evaluator = NewEvaluator("")
|
||||
})
|
||||
for key := range chatMLTestMatch {
|
||||
foo := chatMLTestMatch[key]
|
||||
|
@ -200,7 +227,7 @@ var _ = Describe("Templates", func() {
|
|||
Context("chat message llama3", func() {
|
||||
var evaluator *Evaluator
|
||||
BeforeEach(func() {
|
||||
evaluator = NewEvaluator(NewTemplateCache(""))
|
||||
evaluator = NewEvaluator("")
|
||||
})
|
||||
for key := range llama3TestMatch {
|
||||
foo := llama3TestMatch[key]
|
||||
|
@ -210,4 +237,17 @@ var _ = Describe("Templates", func() {
|
|||
})
|
||||
}
|
||||
})
|
||||
Context("chat message jinja", func() {
|
||||
var evaluator *Evaluator
|
||||
BeforeEach(func() {
|
||||
evaluator = NewEvaluator("")
|
||||
})
|
||||
for key := range jinjaTest {
|
||||
foo := jinjaTest[key]
|
||||
It("renders correctly `"+key+"`", func() {
|
||||
templated := evaluator.TemplateMessages(foo["messages"].([]schema.Message), foo["config"].(*config.BackendConfig), foo["functions"].([]functions.Function), foo["shouldUseFn"].(bool))
|
||||
Expect(templated).To(Equal(foo["expected"]), templated)
|
||||
})
|
||||
}
|
||||
})
|
||||
})
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue