feat(template): read jinja templates from gguf files (#4332)

* Read jinja templates as fallback

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Move templating out of model loader

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Test TemplateMessages

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Set role and content from transformers

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Tests: be more flexible

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* More jinja

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Small refactoring and adaptations

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Author: Ettore Di Giacinto <mudler@localai.io>
Date: 2024-12-08 13:50:33 +01:00 (committed by GitHub)
Commit: cea5a0ea42 (parent: f5e1527a5a)
23 changed files with 971 additions and 785 deletions
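To illustrate the new code path (a minimal sketch based on the evaluator and its tests below; the inline template here is illustrative, not taken from the diff): when `TemplateConfig.JinjaTemplate` is set, `TemplateMessages` renders the chat history through gonja instead of Go's text/template.

```go
package main

import (
	"fmt"

	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/core/schema"
	"github.com/mudler/LocalAI/pkg/functions"
	"github.com/mudler/LocalAI/pkg/templates"
)

func main() {
	cfg := &config.BackendConfig{
		TemplateConfig: config.TemplateConfig{
			// Illustrative inline jinja template; in practice this can now
			// come from the chat template embedded in a GGUF file.
			ChatMessage:   "{% for message in messages %}{{ message['role'] }}: {{ message['content'] }}\n{% endfor %}",
			JinjaTemplate: true, // route rendering through gonja instead of text/template
		},
	}
	e := templates.NewEvaluator("") // the argument is the model/templates path
	out := e.TemplateMessages([]schema.Message{
		{Role: "user", StringContent: "Hello"},
	}, cfg, []functions.Function{}, false)
	fmt.Println(out) // user: Hello
}
```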


@@ -9,8 +9,6 @@ import (
"sync"
"time"
"github.com/mudler/LocalAI/pkg/templates"
"github.com/mudler/LocalAI/pkg/utils"
"github.com/rs/zerolog/log"
@@ -23,7 +21,6 @@ type ModelLoader struct {
ModelPath string
mu sync.Mutex
models map[string]*Model
templates *templates.TemplateCache
wd *WatchDog
}
@@ -31,7 +28,6 @@ func NewModelLoader(modelPath string) *ModelLoader {
nml := &ModelLoader{
ModelPath: modelPath,
models: make(map[string]*Model),
templates: templates.NewTemplateCache(modelPath),
}
return nml
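Template evaluation no longer hangs off the ModelLoader. A hypothetical caller migrates to the new `templates.Evaluator` roughly as follows (a sketch assuming an inline Completion template; `renderCompletion` is not from the diff):

```go
package main

import (
	"fmt"

	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/pkg/templates"
)

// renderCompletion shows the post-refactor call shape: the template is picked
// from the backend config rather than passed in by name as before.
func renderCompletion(modelPath string, cfg config.BackendConfig, input string) (string, error) {
	e := templates.NewEvaluator(modelPath)
	return e.EvaluateTemplateForPrompt(templates.CompletionPromptTemplate, cfg, templates.PromptTemplateData{
		Input: input,
	})
}

func main() {
	cfg := config.BackendConfig{
		TemplateConfig: config.TemplateConfig{
			// Inline Go text/template; a "<name>.tmpl" file under modelPath also works.
			Completion: "Complete the following: {{.Input}}",
		},
	}
	out, err := renderCompletion("", cfg, "Once upon a time")
	fmt.Println(out, err)
}
```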


@@ -1,52 +0,0 @@
package model
import (
"fmt"
"github.com/mudler/LocalAI/pkg/functions"
"github.com/mudler/LocalAI/pkg/templates"
)
// Rather than pass an interface{} to the prompt template:
// These are the definitions of all possible variables LocalAI will currently populate for use in a prompt template file
// Please note: Not all of these are populated on every endpoint - your template should either be tested for each endpoint you map it to, or tolerant of zero values.
type PromptTemplateData struct {
SystemPrompt string
SuppressSystemPrompt bool // used by chat specifically to indicate that SystemPrompt above should be _ignored_
Input string
Instruction string
Functions []functions.Function
MessageIndex int
}
type ChatMessageTemplateData struct {
SystemPrompt string
Role string
RoleName string
FunctionName string
Content string
MessageIndex int
Function bool
FunctionCall interface{}
LastMessage bool
}
const (
ChatPromptTemplate templates.TemplateType = iota
ChatMessageTemplate
CompletionPromptTemplate
EditPromptTemplate
FunctionsPromptTemplate
)
func (ml *ModelLoader) EvaluateTemplateForPrompt(templateType templates.TemplateType, templateName string, in PromptTemplateData) (string, error) {
// TODO: should this check be improved?
if templateType == ChatMessageTemplate {
return "", fmt.Errorf("invalid templateType: ChatMessage")
}
return ml.templates.EvaluateTemplate(templateType, templateName, in)
}
func (ml *ModelLoader) EvaluateTemplateForChatMessage(templateName string, messageData ChatMessageTemplateData) (string, error) {
return ml.templates.EvaluateTemplate(ChatMessageTemplate, templateName, messageData)
}


@@ -1,197 +0,0 @@
package model_test
import (
. "github.com/mudler/LocalAI/pkg/model"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
const chatML = `<|im_start|>{{if eq .RoleName "assistant"}}assistant{{else if eq .RoleName "system"}}system{{else if eq .RoleName "tool"}}tool{{else if eq .RoleName "user"}}user{{end}}
{{- if .FunctionCall }}
<tool_call>
{{- else if eq .RoleName "tool" }}
<tool_response>
{{- end }}
{{- if .Content}}
{{.Content }}
{{- end }}
{{- if .FunctionCall}}
{{toJson .FunctionCall}}
{{- end }}
{{- if .FunctionCall }}
</tool_call>
{{- else if eq .RoleName "tool" }}
</tool_response>
{{- end }}<|im_end|>`
const llama3 = `<|start_header_id|>{{if eq .RoleName "assistant"}}assistant{{else if eq .RoleName "system"}}system{{else if eq .RoleName "tool"}}tool{{else if eq .RoleName "user"}}user{{end}}<|end_header_id|>
{{ if .FunctionCall -}}
Function call:
{{ else if eq .RoleName "tool" -}}
Function response:
{{ end -}}
{{ if .Content -}}
{{.Content -}}
{{ else if .FunctionCall -}}
{{ toJson .FunctionCall -}}
{{ end -}}
<|eot_id|>`
var llama3TestMatch map[string]map[string]interface{} = map[string]map[string]interface{}{
"user": {
"template": llama3,
"expected": "<|start_header_id|>user<|end_header_id|>\n\nA long time ago in a galaxy far, far away...<|eot_id|>",
"data": ChatMessageTemplateData{
SystemPrompt: "",
Role: "user",
RoleName: "user",
Content: "A long time ago in a galaxy far, far away...",
FunctionCall: nil,
FunctionName: "",
LastMessage: false,
Function: false,
MessageIndex: 0,
},
},
"assistant": {
"template": llama3,
"expected": "<|start_header_id|>assistant<|end_header_id|>\n\nA long time ago in a galaxy far, far away...<|eot_id|>",
"data": ChatMessageTemplateData{
SystemPrompt: "",
Role: "assistant",
RoleName: "assistant",
Content: "A long time ago in a galaxy far, far away...",
FunctionCall: nil,
FunctionName: "",
LastMessage: false,
Function: false,
MessageIndex: 0,
},
},
"function_call": {
"template": llama3,
"expected": "<|start_header_id|>assistant<|end_header_id|>\n\nFunction call:\n{\"function\":\"test\"}<|eot_id|>",
"data": ChatMessageTemplateData{
SystemPrompt: "",
Role: "assistant",
RoleName: "assistant",
Content: "",
FunctionCall: map[string]string{"function": "test"},
FunctionName: "",
LastMessage: false,
Function: false,
MessageIndex: 0,
},
},
"function_response": {
"template": llama3,
"expected": "<|start_header_id|>tool<|end_header_id|>\n\nFunction response:\nResponse from tool<|eot_id|>",
"data": ChatMessageTemplateData{
SystemPrompt: "",
Role: "tool",
RoleName: "tool",
Content: "Response from tool",
FunctionCall: nil,
FunctionName: "",
LastMessage: false,
Function: false,
MessageIndex: 0,
},
},
}
var chatMLTestMatch map[string]map[string]interface{} = map[string]map[string]interface{}{
"user": {
"template": chatML,
"expected": "<|im_start|>user\nA long time ago in a galaxy far, far away...<|im_end|>",
"data": ChatMessageTemplateData{
SystemPrompt: "",
Role: "user",
RoleName: "user",
Content: "A long time ago in a galaxy far, far away...",
FunctionCall: nil,
FunctionName: "",
LastMessage: false,
Function: false,
MessageIndex: 0,
},
},
"assistant": {
"template": chatML,
"expected": "<|im_start|>assistant\nA long time ago in a galaxy far, far away...<|im_end|>",
"data": ChatMessageTemplateData{
SystemPrompt: "",
Role: "assistant",
RoleName: "assistant",
Content: "A long time ago in a galaxy far, far away...",
FunctionCall: nil,
FunctionName: "",
LastMessage: false,
Function: false,
MessageIndex: 0,
},
},
"function_call": {
"template": chatML,
"expected": "<|im_start|>assistant\n<tool_call>\n{\"function\":\"test\"}\n</tool_call><|im_end|>",
"data": ChatMessageTemplateData{
SystemPrompt: "",
Role: "assistant",
RoleName: "assistant",
Content: "",
FunctionCall: map[string]string{"function": "test"},
FunctionName: "",
LastMessage: false,
Function: false,
MessageIndex: 0,
},
},
"function_response": {
"template": chatML,
"expected": "<|im_start|>tool\n<tool_response>\nResponse from tool\n</tool_response><|im_end|>",
"data": ChatMessageTemplateData{
SystemPrompt: "",
Role: "tool",
RoleName: "tool",
Content: "Response from tool",
FunctionCall: nil,
FunctionName: "",
LastMessage: false,
Function: false,
MessageIndex: 0,
},
},
}
var _ = Describe("Templates", func() {
Context("chat message ChatML", func() {
var modelLoader *ModelLoader
BeforeEach(func() {
modelLoader = NewModelLoader("")
})
for key := range chatMLTestMatch {
foo := chatMLTestMatch[key]
It("renders correctly `"+key+"`", func() {
templated, err := modelLoader.EvaluateTemplateForChatMessage(foo["template"].(string), foo["data"].(ChatMessageTemplateData))
Expect(err).ToNot(HaveOccurred())
Expect(templated).To(Equal(foo["expected"]), templated)
})
}
})
Context("chat message llama3", func() {
var modelLoader *ModelLoader
BeforeEach(func() {
modelLoader = NewModelLoader("")
})
for key := range llama3TestMatch {
foo := llama3TestMatch[key]
It("renders correctly `"+key+"`", func() {
templated, err := modelLoader.EvaluateTemplateForChatMessage(foo["template"].(string), foo["data"].(ChatMessageTemplateData))
Expect(err).ToNot(HaveOccurred())
Expect(templated).To(Equal(foo["expected"]), templated)
})
}
})
})


@@ -11,59 +11,41 @@ import (
"github.com/mudler/LocalAI/pkg/utils"
"github.com/Masterminds/sprig/v3"
"github.com/nikolalohinski/gonja/v2"
"github.com/nikolalohinski/gonja/v2/exec"
)
// Keep this in sync with config.TemplateConfig. Is there a more idiomatic way to accomplish this in Go?
// Technically, order doesn't _really_ matter, but the count must stay in sync, see tests/integration/reflect_test.go
type TemplateType int
type TemplateCache struct {
mu sync.Mutex
templatesPath string
templates map[TemplateType]map[string]*template.Template
type templateCache struct {
mu sync.Mutex
templatesPath string
templates map[TemplateType]map[string]*template.Template
jinjaTemplates map[TemplateType]map[string]*exec.Template
}
func NewTemplateCache(templatesPath string) *TemplateCache {
tc := &TemplateCache{
templatesPath: templatesPath,
templates: make(map[TemplateType]map[string]*template.Template),
func newTemplateCache(templatesPath string) *templateCache {
tc := &templateCache{
templatesPath: templatesPath,
templates: make(map[TemplateType]map[string]*template.Template),
jinjaTemplates: make(map[TemplateType]map[string]*exec.Template),
}
return tc
}
func (tc *TemplateCache) initializeTemplateMapKey(tt TemplateType) {
func (tc *templateCache) initializeTemplateMapKey(tt TemplateType) {
if _, ok := tc.templates[tt]; !ok {
tc.templates[tt] = make(map[string]*template.Template)
}
}
func (tc *TemplateCache) EvaluateTemplate(templateType TemplateType, templateName string, in interface{}) (string, error) {
tc.mu.Lock()
defer tc.mu.Unlock()
tc.initializeTemplateMapKey(templateType)
m, ok := tc.templates[templateType][templateName]
if !ok {
// return "", fmt.Errorf("template not loaded: %s", templateName)
loadErr := tc.loadTemplateIfExists(templateType, templateName)
if loadErr != nil {
return "", loadErr
}
m = tc.templates[templateType][templateName] // ok is not important since we check m on the next line, and we already checked
}
if m == nil {
return "", fmt.Errorf("failed loading a template for %s", templateName)
}
var buf bytes.Buffer
if err := m.Execute(&buf, in); err != nil {
return "", err
}
return buf.String(), nil
func (tc *templateCache) existsInModelPath(s string) bool {
return utils.ExistsInPath(tc.templatesPath, s)
}
func (tc *TemplateCache) loadTemplateIfExists(templateType TemplateType, templateName string) error {
func (tc *templateCache) loadTemplateIfExists(templateType TemplateType, templateName string) error {
// Check if the template was already loaded
if _, ok := tc.templates[templateType][templateName]; ok {
@@ -82,6 +64,51 @@ func (tc *TemplateCache) loadTemplateIfExists(templateType TemplateType, templat
return fmt.Errorf("template file outside path: %s", file)
}
// can either be a file in the system or a string with the template
if tc.existsInModelPath(modelTemplateFile) {
d, err := os.ReadFile(file)
if err != nil {
return err
}
dat = string(d)
} else {
dat = templateName
}
// Parse the template
tmpl, err := template.New("prompt").Funcs(sprig.FuncMap()).Parse(dat)
if err != nil {
return err
}
tc.templates[templateType][templateName] = tmpl
return nil
}
func (tc *templateCache) initializeJinjaTemplateMapKey(tt TemplateType) {
if _, ok := tc.jinjaTemplates[tt]; !ok {
tc.jinjaTemplates[tt] = make(map[string]*exec.Template)
}
}
func (tc *templateCache) loadJinjaTemplateIfExists(templateType TemplateType, templateName string) error {
// Check if the template was already loaded
if _, ok := tc.jinjaTemplates[templateType][templateName]; ok {
return nil
}
// Check if the model path exists
// skip any error here - we run anyway if a template does not exist
modelTemplateFile := fmt.Sprintf("%s.tmpl", templateName)
dat := ""
file := filepath.Join(tc.templatesPath, modelTemplateFile)
// Security check
if err := utils.VerifyPath(modelTemplateFile, tc.templatesPath); err != nil {
return fmt.Errorf("template file outside path: %s", file)
}
// can either be a file in the system or a string with the template
if utils.ExistsInPath(tc.templatesPath, modelTemplateFile) {
d, err := os.ReadFile(file)
@@ -93,12 +120,65 @@ func (tc *TemplateCache) loadTemplateIfExists(templateType TemplateType, templat
dat = templateName
}
// Parse the template
tmpl, err := template.New("prompt").Funcs(sprig.FuncMap()).Parse(dat)
tmpl, err := gonja.FromString(dat)
if err != nil {
return err
}
tc.templates[templateType][templateName] = tmpl
tc.jinjaTemplates[templateType][templateName] = tmpl
return nil
}
func (tc *templateCache) evaluateJinjaTemplate(templateType TemplateType, templateNameOrContent string, in map[string]interface{}) (string, error) {
tc.mu.Lock()
defer tc.mu.Unlock()
tc.initializeJinjaTemplateMapKey(templateType)
m, ok := tc.jinjaTemplates[templateType][templateNameOrContent]
if !ok {
// return "", fmt.Errorf("template not loaded: %s", templateName)
loadErr := tc.loadJinjaTemplateIfExists(templateType, templateNameOrContent)
if loadErr != nil {
return "", loadErr
}
m = tc.jinjaTemplates[templateType][templateNameOrContent] // ok is not important since we check m on the next line, and we already checked
}
if m == nil {
return "", fmt.Errorf("failed loading a template for %s", templateNameOrContent)
}
var buf bytes.Buffer
data := exec.NewContext(in)
if err := m.Execute(&buf, data); err != nil {
return "", err
}
return buf.String(), nil
}
func (tc *templateCache) evaluateTemplate(templateType TemplateType, templateNameOrContent string, in interface{}) (string, error) {
tc.mu.Lock()
defer tc.mu.Unlock()
tc.initializeTemplateMapKey(templateType)
m, ok := tc.templates[templateType][templateNameOrContent]
if !ok {
// return "", fmt.Errorf("template not loaded: %s", templateName)
loadErr := tc.loadTemplateIfExists(templateType, templateNameOrContent)
if loadErr != nil {
return "", loadErr
}
m = tc.templates[templateType][templateNameOrContent] // ok is not important since we check m on the next line, and we already checked
}
if m == nil {
return "", fmt.Errorf("failed loading a template for %s", templateNameOrContent)
}
var buf bytes.Buffer
if err := m.Execute(&buf, in); err != nil {
return "", err
}
return buf.String(), nil
}
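In isolation, the jinja side of the cache boils down to the following gonja calls (a sketch of what `evaluateJinjaTemplate` does, using only APIs visible in this diff): compile with `gonja.FromString`, then execute against an `exec.Context` built from a plain map, where text/template would take a Go struct.

```go
package main

import (
	"bytes"
	"fmt"

	"github.com/nikolalohinski/gonja/v2"
	"github.com/nikolalohinski/gonja/v2/exec"
)

func main() {
	// Compile once; the cache stores the result in jinjaTemplates.
	tmpl, err := gonja.FromString("Hello, {{ name }}!")
	if err != nil {
		panic(err)
	}
	// Execute against a context built from a map, not a struct.
	var buf bytes.Buffer
	if err := tmpl.Execute(&buf, exec.NewContext(map[string]interface{}{"name": "Gopher"})); err != nil {
		panic(err)
	}
	fmt.Println(buf.String()) // Hello, Gopher!
}
```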


@@ -1,73 +0,0 @@
package templates_test
import (
"os"
"path/filepath"
"github.com/mudler/LocalAI/pkg/templates" // Update with your module path
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("TemplateCache", func() {
var (
templateCache *templates.TemplateCache
tempDir string
)
BeforeEach(func() {
var err error
tempDir, err = os.MkdirTemp("", "templates")
Expect(err).NotTo(HaveOccurred())
// Writing example template files
err = os.WriteFile(filepath.Join(tempDir, "example.tmpl"), []byte("Hello, {{.Name}}!"), 0600)
Expect(err).NotTo(HaveOccurred())
err = os.WriteFile(filepath.Join(tempDir, "empty.tmpl"), []byte(""), 0600)
Expect(err).NotTo(HaveOccurred())
templateCache = templates.NewTemplateCache(tempDir)
})
AfterEach(func() {
os.RemoveAll(tempDir) // Clean up
})
Describe("EvaluateTemplate", func() {
Context("when template is loaded successfully", func() {
It("should evaluate the template correctly", func() {
result, err := templateCache.EvaluateTemplate(1, "example", map[string]string{"Name": "Gopher"})
Expect(err).NotTo(HaveOccurred())
Expect(result).To(Equal("Hello, Gopher!"))
})
})
Context("when template isn't a file", func() {
It("should parse from string", func() {
result, err := templateCache.EvaluateTemplate(1, "{{.Name}}", map[string]string{"Name": "Gopher"})
Expect(err).ToNot(HaveOccurred())
Expect(result).To(Equal("Gopher"))
})
})
Context("when template is empty", func() {
It("should return an empty string", func() {
result, err := templateCache.EvaluateTemplate(1, "empty", nil)
Expect(err).NotTo(HaveOccurred())
Expect(result).To(Equal(""))
})
})
})
Describe("concurrency", func() {
It("should handle multiple concurrent accesses", func(done Done) {
go func() {
_, _ = templateCache.EvaluateTemplate(1, "example", map[string]string{"Name": "Gopher"})
}()
go func() {
_, _ = templateCache.EvaluateTemplate(1, "example", map[string]string{"Name": "Gopher"})
}()
close(done)
}, 0.1) // timeout in seconds
})
})

pkg/templates/evaluator.go (new file, +295)

@@ -0,0 +1,295 @@
package templates
import (
"encoding/json"
"fmt"
"strings"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/functions"
"github.com/rs/zerolog/log"
)
// Rather than pass an interface{} to the prompt template:
// These are the definitions of all possible variables LocalAI will currently populate for use in a prompt template file
// Please note: Not all of these are populated on every endpoint - your template should either be tested for each endpoint you map it to, or tolerant of zero values.
type PromptTemplateData struct {
SystemPrompt string
SuppressSystemPrompt bool // used by chat specifically to indicate that SystemPrompt above should be _ignored_
Input string
Instruction string
Functions []functions.Function
MessageIndex int
}
type ChatMessageTemplateData struct {
SystemPrompt string
Role string
RoleName string
FunctionName string
Content string
MessageIndex int
Function bool
FunctionCall interface{}
LastMessage bool
}
const (
ChatPromptTemplate TemplateType = iota
ChatMessageTemplate
CompletionPromptTemplate
EditPromptTemplate
FunctionsPromptTemplate
)
type Evaluator struct {
cache *templateCache
}
func NewEvaluator(modelPath string) *Evaluator {
return &Evaluator{
cache: newTemplateCache(modelPath),
}
}
func (e *Evaluator) EvaluateTemplateForPrompt(templateType TemplateType, config config.BackendConfig, in PromptTemplateData) (string, error) {
template := ""
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
if e.cache.existsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
template = config.Model
}
switch templateType {
case CompletionPromptTemplate:
if config.TemplateConfig.Completion != "" {
template = config.TemplateConfig.Completion
}
case EditPromptTemplate:
if config.TemplateConfig.Edit != "" {
template = config.TemplateConfig.Edit
}
case ChatPromptTemplate:
if config.TemplateConfig.Chat != "" {
template = config.TemplateConfig.Chat
}
case FunctionsPromptTemplate:
if config.TemplateConfig.Functions != "" {
template = config.TemplateConfig.Functions
}
}
if template == "" {
return in.Input, nil
}
if config.TemplateConfig.JinjaTemplate {
return e.evaluateJinjaTemplateForPrompt(templateType, template, in)
}
return e.cache.evaluateTemplate(templateType, template, in)
}
func (e *Evaluator) evaluateTemplateForChatMessage(templateName string, messageData ChatMessageTemplateData) (string, error) {
return e.cache.evaluateTemplate(ChatMessageTemplate, templateName, messageData)
}
func (e *Evaluator) templateJinjaChat(templateName string, messageData []ChatMessageTemplateData, funcs []functions.Function) (string, error) {
conversation := make(map[string]interface{})
messages := make([]map[string]interface{}, 0, len(messageData)) // zero length: entries are appended below, so a non-zero length would leave nil maps at the front
// convert from ChatMessageTemplateData to what the jinja template expects
for _, message := range messageData {
// TODO: this seems to cover minimum text templates. Can be expanded to cover more complex interactions
var data []byte
data, _ = json.Marshal(message.FunctionCall)
messages = append(messages, map[string]interface{}{
"role": message.RoleName,
"content": message.Content,
"tool_call": string(data),
})
}
conversation["messages"] = messages
// if tools are detected, add these
if len(funcs) > 0 {
conversation["tools"] = funcs
}
return e.cache.evaluateJinjaTemplate(ChatMessageTemplate, templateName, conversation)
}
func (e *Evaluator) evaluateJinjaTemplateForPrompt(templateType TemplateType, templateName string, in PromptTemplateData) (string, error) {
conversation := make(map[string]interface{})
conversation["system_prompt"] = in.SystemPrompt
conversation["content"] = in.Input
return e.cache.evaluateJinjaTemplate(templateType, templateName, conversation)
}
func (e *Evaluator) TemplateMessages(messages []schema.Message, config *config.BackendConfig, funcs []functions.Function, shouldUseFn bool) string {
if config.TemplateConfig.JinjaTemplate {
var messageData []ChatMessageTemplateData
for messageIndex, i := range messages {
fcall := i.FunctionCall
if len(i.ToolCalls) > 0 {
fcall = i.ToolCalls
}
messageData = append(messageData, ChatMessageTemplateData{
SystemPrompt: config.SystemPrompt,
Role: config.Roles[i.Role],
RoleName: i.Role,
Content: i.StringContent,
FunctionCall: fcall,
FunctionName: i.Name,
LastMessage: messageIndex == (len(messages) - 1),
Function: config.Grammar != "" && (messageIndex == (len(messages) - 1)),
MessageIndex: messageIndex,
})
}
templatedInput, err := e.templateJinjaChat(config.TemplateConfig.ChatMessage, messageData, funcs)
if err == nil {
return templatedInput
}
}
var predInput string
suppressConfigSystemPrompt := false
mess := []string{}
for messageIndex, i := range messages {
var content string
role := i.Role
// if function call, we might want to customize the role so we can display better that the "assistant called a json action"
// if an "assistant_function_call" role is defined, we use it, otherwise we use the role that is passed by in the request
if (i.FunctionCall != nil || i.ToolCalls != nil) && i.Role == "assistant" {
roleFn := "assistant_function_call"
r := config.Roles[roleFn]
if r != "" {
role = roleFn
}
}
r := config.Roles[role]
contentExists := i.Content != nil && i.StringContent != ""
fcall := i.FunctionCall
if len(i.ToolCalls) > 0 {
fcall = i.ToolCalls
}
// First attempt to populate content via a chat message specific template
if config.TemplateConfig.ChatMessage != "" {
chatMessageData := ChatMessageTemplateData{
SystemPrompt: config.SystemPrompt,
Role: r,
RoleName: role,
Content: i.StringContent,
FunctionCall: fcall,
FunctionName: i.Name,
LastMessage: messageIndex == (len(messages) - 1),
Function: config.Grammar != "" && (messageIndex == (len(messages) - 1)),
MessageIndex: messageIndex,
}
templatedChatMessage, err := e.evaluateTemplateForChatMessage(config.TemplateConfig.ChatMessage, chatMessageData)
if err != nil {
log.Error().Err(err).Interface("message", chatMessageData).Str("template", config.TemplateConfig.ChatMessage).Msg("error processing message with template, skipping")
} else {
if templatedChatMessage == "" {
log.Warn().Msgf("template \"%s\" produced blank output for %+v. Skipping!", config.TemplateConfig.ChatMessage, chatMessageData)
continue // TODO: This continue is here intentionally to skip over the line `mess = append(mess, content)` below, and to prevent the sprintf
}
log.Debug().Msgf("templated message for chat: %s", templatedChatMessage)
content = templatedChatMessage
}
}
marshalAnyRole := func(f any) {
j, err := json.Marshal(f)
if err == nil {
if contentExists {
content += "\n" + fmt.Sprint(r, " ", string(j))
} else {
content = fmt.Sprint(r, " ", string(j))
}
}
}
marshalAny := func(f any) {
j, err := json.Marshal(f)
if err == nil {
if contentExists {
content += "\n" + string(j)
} else {
content = string(j)
}
}
}
// If this model doesn't have such a template, or if that template fails to return a value, template at the message level.
if content == "" {
if r != "" {
if contentExists {
content = fmt.Sprint(r, i.StringContent)
}
if i.FunctionCall != nil {
marshalAnyRole(i.FunctionCall)
}
if i.ToolCalls != nil {
marshalAnyRole(i.ToolCalls)
}
} else {
if contentExists {
content = fmt.Sprint(i.StringContent)
}
if i.FunctionCall != nil {
marshalAny(i.FunctionCall)
}
if i.ToolCalls != nil {
marshalAny(i.ToolCalls)
}
}
// Special Handling: System. We care if it was printed at all, not the r branch, so check separately
if contentExists && role == "system" {
suppressConfigSystemPrompt = true
}
}
mess = append(mess, content)
}
joinCharacter := "\n"
if config.TemplateConfig.JoinChatMessagesByCharacter != nil {
joinCharacter = *config.TemplateConfig.JoinChatMessagesByCharacter
}
predInput = strings.Join(mess, joinCharacter)
log.Debug().Msgf("Prompt (before templating): %s", predInput)
promptTemplate := ChatPromptTemplate
if config.TemplateConfig.Functions != "" && shouldUseFn {
promptTemplate = FunctionsPromptTemplate
}
templatedInput, err := e.EvaluateTemplateForPrompt(promptTemplate, *config, PromptTemplateData{
SystemPrompt: config.SystemPrompt,
SuppressSystemPrompt: suppressConfigSystemPrompt,
Input: predInput,
Functions: funcs,
})
if err == nil {
predInput = templatedInput
log.Debug().Msgf("Template found, input modified to: %s", predInput)
} else {
log.Debug().Msgf("Template failed loading: %s", err.Error())
}
return predInput
}
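For reference, `templateJinjaChat` above hands jinja templates an OpenAI-style context; for a single user message with no tool calls it looks roughly like this (a sketch; the literal values are illustrative):

```go
package main

import "fmt"

func main() {
	// "tool_call" carries json.Marshal(FunctionCall), which is the string
	// "null" when no call is set; "tools" is only added when functions are
	// passed in.
	conversation := map[string]interface{}{
		"messages": []map[string]interface{}{
			{
				"role":      "user",
				"content":   "A long time ago in a galaxy far, far away...",
				"tool_call": "null",
			},
		},
	}
	fmt.Printf("%#v\n", conversation)
}
```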


@@ -0,0 +1,253 @@
package templates_test
import (
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/functions"
. "github.com/mudler/LocalAI/pkg/templates"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
const toolCallJinja = `{{ '<|begin_of_text|>' }}{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{% endif %}{% if system_message is defined %}{{ '<|start_header_id|>system<|end_header_id|>
' + system_message + '<|eot_id|>' }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|start_header_id|>user<|end_header_id|>
' + content + '<|eot_id|><|start_header_id|>assistant<|end_header_id|>
' }}{% elif message['role'] == 'assistant' %}{{ content + '<|eot_id|>' }}{% endif %}{% endfor %}`
const chatML = `<|im_start|>{{if eq .RoleName "assistant"}}assistant{{else if eq .RoleName "system"}}system{{else if eq .RoleName "tool"}}tool{{else if eq .RoleName "user"}}user{{end}}
{{- if .FunctionCall }}
<tool_call>
{{- else if eq .RoleName "tool" }}
<tool_response>
{{- end }}
{{- if .Content}}
{{.Content }}
{{- end }}
{{- if .FunctionCall}}
{{toJson .FunctionCall}}
{{- end }}
{{- if .FunctionCall }}
</tool_call>
{{- else if eq .RoleName "tool" }}
</tool_response>
{{- end }}<|im_end|>`
const llama3 = `<|start_header_id|>{{if eq .RoleName "assistant"}}assistant{{else if eq .RoleName "system"}}system{{else if eq .RoleName "tool"}}tool{{else if eq .RoleName "user"}}user{{end}}<|end_header_id|>
{{ if .FunctionCall -}}
Function call:
{{ else if eq .RoleName "tool" -}}
Function response:
{{ end -}}
{{ if .Content -}}
{{.Content -}}
{{ else if .FunctionCall -}}
{{ toJson .FunctionCall -}}
{{ end -}}
<|eot_id|>`
var llama3TestMatch map[string]map[string]interface{} = map[string]map[string]interface{}{
"user": {
"expected": "<|start_header_id|>user<|end_header_id|>\n\nA long time ago in a galaxy far, far away...<|eot_id|>",
"config": &config.BackendConfig{
TemplateConfig: config.TemplateConfig{
ChatMessage: llama3,
},
},
"functions": []functions.Function{},
"shouldUseFn": false,
"messages": []schema.Message{
{
Role: "user",
StringContent: "A long time ago in a galaxy far, far away...",
},
},
},
"assistant": {
"expected": "<|start_header_id|>assistant<|end_header_id|>\n\nA long time ago in a galaxy far, far away...<|eot_id|>",
"config": &config.BackendConfig{
TemplateConfig: config.TemplateConfig{
ChatMessage: llama3,
},
},
"functions": []functions.Function{},
"messages": []schema.Message{
{
Role: "assistant",
StringContent: "A long time ago in a galaxy far, far away...",
},
},
"shouldUseFn": false,
},
"function_call": {
"expected": "<|start_header_id|>assistant<|end_header_id|>\n\nFunction call:\n{\"function\":\"test\"}<|eot_id|>",
"config": &config.BackendConfig{
TemplateConfig: config.TemplateConfig{
ChatMessage: llama3,
},
},
"functions": []functions.Function{},
"messages": []schema.Message{
{
Role: "assistant",
FunctionCall: map[string]string{"function": "test"},
},
},
"shouldUseFn": false,
},
"function_response": {
"expected": "<|start_header_id|>tool<|end_header_id|>\n\nFunction response:\nResponse from tool<|eot_id|>",
"config": &config.BackendConfig{
TemplateConfig: config.TemplateConfig{
ChatMessage: llama3,
},
},
"functions": []functions.Function{},
"messages": []schema.Message{
{
Role: "tool",
StringContent: "Response from tool",
},
},
"shouldUseFn": false,
},
}
var chatMLTestMatch map[string]map[string]interface{} = map[string]map[string]interface{}{
"user": {
"expected": "<|im_start|>user\nA long time ago in a galaxy far, far away...<|im_end|>",
"config": &config.BackendConfig{
TemplateConfig: config.TemplateConfig{
ChatMessage: chatML,
},
},
"functions": []functions.Function{},
"shouldUseFn": false,
"messages": []schema.Message{
{
Role: "user",
StringContent: "A long time ago in a galaxy far, far away...",
},
},
},
"assistant": {
"expected": "<|im_start|>assistant\nA long time ago in a galaxy far, far away...<|im_end|>",
"config": &config.BackendConfig{
TemplateConfig: config.TemplateConfig{
ChatMessage: chatML,
},
},
"functions": []functions.Function{},
"messages": []schema.Message{
{
Role: "assistant",
StringContent: "A long time ago in a galaxy far, far away...",
},
},
"shouldUseFn": false,
},
"function_call": {
"expected": "<|im_start|>assistant\n<tool_call>\n{\"function\":\"test\"}\n</tool_call><|im_end|>",
"config": &config.BackendConfig{
TemplateConfig: config.TemplateConfig{
ChatMessage: chatML,
},
},
"functions": []functions.Function{
{
Name: "test",
Description: "test",
Parameters: nil,
},
},
"shouldUseFn": true,
"messages": []schema.Message{
{
Role: "assistant",
FunctionCall: map[string]string{"function": "test"},
},
},
},
"function_response": {
"expected": "<|im_start|>tool\n<tool_response>\nResponse from tool\n</tool_response><|im_end|>",
"config": &config.BackendConfig{
TemplateConfig: config.TemplateConfig{
ChatMessage: chatML,
},
},
"functions": []functions.Function{},
"shouldUseFn": false,
"messages": []schema.Message{
{
Role: "tool",
StringContent: "Response from tool",
},
},
},
}
var jinjaTest map[string]map[string]interface{} = map[string]map[string]interface{}{
"user": {
"expected": "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nA long time ago in a galaxy far, far away...<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
"config": &config.BackendConfig{
TemplateConfig: config.TemplateConfig{
ChatMessage: toolCallJinja,
JinjaTemplate: true,
},
},
"functions": []functions.Function{},
"shouldUseFn": false,
"messages": []schema.Message{
{
Role: "user",
StringContent: "A long time ago in a galaxy far, far away...",
},
},
},
}
var _ = Describe("Templates", func() {
Context("chat message ChatML", func() {
var evaluator *Evaluator
BeforeEach(func() {
evaluator = NewEvaluator("")
})
for key := range chatMLTestMatch {
foo := chatMLTestMatch[key]
It("renders correctly `"+key+"`", func() {
templated := evaluator.TemplateMessages(foo["messages"].([]schema.Message), foo["config"].(*config.BackendConfig), foo["functions"].([]functions.Function), foo["shouldUseFn"].(bool))
Expect(templated).To(Equal(foo["expected"]), templated)
})
}
})
Context("chat message llama3", func() {
var evaluator *Evaluator
BeforeEach(func() {
evaluator = NewEvaluator("")
})
for key := range llama3TestMatch {
foo := llama3TestMatch[key]
It("renders correctly `"+key+"`", func() {
templated := evaluator.TemplateMessages(foo["messages"].([]schema.Message), foo["config"].(*config.BackendConfig), foo["functions"].([]functions.Function), foo["shouldUseFn"].(bool))
Expect(templated).To(Equal(foo["expected"]), templated)
})
}
})
Context("chat message jinja", func() {
var evaluator *Evaluator
BeforeEach(func() {
evaluator = NewEvaluator("")
})
for key := range jinjaTest {
foo := jinjaTest[key]
It("renders correctly `"+key+"`", func() {
templated := evaluator.TemplateMessages(foo["messages"].([]schema.Message), foo["config"].(*config.BackendConfig), foo["functions"].([]functions.Function), foo["shouldUseFn"].(bool))
Expect(templated).To(Equal(foo["expected"]), templated)
})
}
})
})