diff --git a/pkg/templates/evaluator.go b/pkg/templates/evaluator.go
index 7b2089b3..00b1c0a6 100644
--- a/pkg/templates/evaluator.go
+++ b/pkg/templates/evaluator.go
@@ -85,13 +85,13 @@ func (e *Evaluator) EvaluateTemplateForPrompt(templateType TemplateType, config
}
if config.TemplateConfig.JinjaTemplate {
- return e.EvaluateJinjaTemplateForPrompt(templateType, template, in)
+ return e.evaluateJinjaTemplateForPrompt(templateType, template, in)
}
return e.cache.EvaluateTemplate(templateType, template, in)
}
-func (e *Evaluator) EvaluateTemplateForChatMessage(templateName string, messageData ChatMessageTemplateData) (string, error) {
+func (e *Evaluator) evaluateTemplateForChatMessage(templateName string, messageData ChatMessageTemplateData) (string, error) {
return e.cache.EvaluateTemplate(ChatMessageTemplate, templateName, messageData)
}
@@ -120,7 +120,7 @@ func (e *Evaluator) templateJinjaChat(templateName string, messageData []ChatMes
return e.cache.EvaluateJinjaTemplate(ChatMessageTemplate, templateName, conversation)
}
-func (e *Evaluator) EvaluateJinjaTemplateForPrompt(templateType TemplateType, templateName string, in PromptTemplateData) (string, error) {
+func (e *Evaluator) evaluateJinjaTemplateForPrompt(templateType TemplateType, templateName string, in PromptTemplateData) (string, error) {
conversation := make(map[string]interface{})
@@ -195,7 +195,7 @@ func (e *Evaluator) TemplateMessages(messages []schema.Message, config *config.B
Function: config.Grammar != "" && (messageIndex == (len(messages) - 1)),
MessageIndex: messageIndex,
}
- templatedChatMessage, err := e.EvaluateTemplateForChatMessage(config.TemplateConfig.ChatMessage, chatMessageData)
+ templatedChatMessage, err := e.evaluateTemplateForChatMessage(config.TemplateConfig.ChatMessage, chatMessageData)
if err != nil {
log.Error().Err(err).Interface("message", chatMessageData).Str("template", config.TemplateConfig.ChatMessage).Msg("error processing message with template, skipping")
} else {
diff --git a/pkg/templates/evaluator_test.go b/pkg/templates/evaluator_test.go
index 06551e4d..ab150577 100644
--- a/pkg/templates/evaluator_test.go
+++ b/pkg/templates/evaluator_test.go
@@ -1,6 +1,9 @@
package templates_test
import (
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/core/schema"
+ "github.com/mudler/LocalAI/pkg/functions"
. "github.com/mudler/LocalAI/pkg/templates"
. "github.com/onsi/ginkgo/v2"
@@ -41,126 +44,141 @@ Function response:
var llama3TestMatch map[string]map[string]interface{} = map[string]map[string]interface{}{
"user": {
- "template": llama3,
"expected": "<|start_header_id|>user<|end_header_id|>\n\nA long time ago in a galaxy far, far away...<|eot_id|>",
- "data": ChatMessageTemplateData{
- SystemPrompt: "",
- Role: "user",
- RoleName: "user",
- Content: "A long time ago in a galaxy far, far away...",
- FunctionCall: nil,
- FunctionName: "",
- LastMessage: false,
- Function: false,
- MessageIndex: 0,
+ "config": &config.BackendConfig{
+ TemplateConfig: config.TemplateConfig{
+ ChatMessage: llama3,
+ },
+ },
+ "functions": []functions.Function{},
+ "shouldUseFn": false,
+ "messages": []schema.Message{
+ {
+ Role: "user",
+ StringContent: "A long time ago in a galaxy far, far away...",
+ },
},
},
"assistant": {
- "template": llama3,
"expected": "<|start_header_id|>assistant<|end_header_id|>\n\nA long time ago in a galaxy far, far away...<|eot_id|>",
- "data": ChatMessageTemplateData{
- SystemPrompt: "",
- Role: "assistant",
- RoleName: "assistant",
- Content: "A long time ago in a galaxy far, far away...",
- FunctionCall: nil,
- FunctionName: "",
- LastMessage: false,
- Function: false,
- MessageIndex: 0,
+ "config": &config.BackendConfig{
+ TemplateConfig: config.TemplateConfig{
+ ChatMessage: llama3,
+ },
},
+ "functions": []functions.Function{},
+ "messages": []schema.Message{
+ {
+ Role: "assistant",
+ StringContent: "A long time ago in a galaxy far, far away...",
+ },
+ },
+ "shouldUseFn": false,
},
"function_call": {
- "template": llama3,
+
"expected": "<|start_header_id|>assistant<|end_header_id|>\n\nFunction call:\n{\"function\":\"test\"}<|eot_id|>",
- "data": ChatMessageTemplateData{
- SystemPrompt: "",
- Role: "assistant",
- RoleName: "assistant",
- Content: "",
- FunctionCall: map[string]string{"function": "test"},
- FunctionName: "",
- LastMessage: false,
- Function: false,
- MessageIndex: 0,
+ "config": &config.BackendConfig{
+ TemplateConfig: config.TemplateConfig{
+ ChatMessage: llama3,
+ },
},
+ "functions": []functions.Function{},
+ "messages": []schema.Message{
+ {
+ Role: "assistant",
+ FunctionCall: map[string]string{"function": "test"},
+ },
+ },
+ "shouldUseFn": false,
},
"function_response": {
- "template": llama3,
"expected": "<|start_header_id|>tool<|end_header_id|>\n\nFunction response:\nResponse from tool<|eot_id|>",
- "data": ChatMessageTemplateData{
- SystemPrompt: "",
- Role: "tool",
- RoleName: "tool",
- Content: "Response from tool",
- FunctionCall: nil,
- FunctionName: "",
- LastMessage: false,
- Function: false,
- MessageIndex: 0,
+ "config": &config.BackendConfig{
+ TemplateConfig: config.TemplateConfig{
+ ChatMessage: llama3,
+ },
},
+ "functions": []functions.Function{},
+ "messages": []schema.Message{
+ {
+ Role: "tool",
+ StringContent: "Response from tool",
+ },
+ },
+ "shouldUseFn": false,
},
}
var chatMLTestMatch map[string]map[string]interface{} = map[string]map[string]interface{}{
"user": {
- "template": chatML,
"expected": "<|im_start|>user\nA long time ago in a galaxy far, far away...<|im_end|>",
- "data": ChatMessageTemplateData{
- SystemPrompt: "",
- Role: "user",
- RoleName: "user",
- Content: "A long time ago in a galaxy far, far away...",
- FunctionCall: nil,
- FunctionName: "",
- LastMessage: false,
- Function: false,
- MessageIndex: 0,
+ "config": &config.BackendConfig{
+ TemplateConfig: config.TemplateConfig{
+ ChatMessage: chatML,
+ },
+ },
+ "functions": []functions.Function{},
+ "shouldUseFn": false,
+ "messages": []schema.Message{
+ {
+ Role: "user",
+ StringContent: "A long time ago in a galaxy far, far away...",
+ },
},
},
"assistant": {
- "template": chatML,
"expected": "<|im_start|>assistant\nA long time ago in a galaxy far, far away...<|im_end|>",
- "data": ChatMessageTemplateData{
- SystemPrompt: "",
- Role: "assistant",
- RoleName: "assistant",
- Content: "A long time ago in a galaxy far, far away...",
- FunctionCall: nil,
- FunctionName: "",
- LastMessage: false,
- Function: false,
- MessageIndex: 0,
+ "config": &config.BackendConfig{
+ TemplateConfig: config.TemplateConfig{
+ ChatMessage: chatML,
+ },
},
+ "functions": []functions.Function{},
+ "messages": []schema.Message{
+ {
+ Role: "assistant",
+ StringContent: "A long time ago in a galaxy far, far away...",
+ },
+ },
+ "shouldUseFn": false,
},
"function_call": {
- "template": chatML,
"expected": "<|im_start|>assistant\n\n{\"function\":\"test\"}\n<|im_end|>",
- "data": ChatMessageTemplateData{
- SystemPrompt: "",
- Role: "assistant",
- RoleName: "assistant",
- Content: "",
- FunctionCall: map[string]string{"function": "test"},
- FunctionName: "",
- LastMessage: false,
- Function: false,
- MessageIndex: 0,
+ "config": &config.BackendConfig{
+ TemplateConfig: config.TemplateConfig{
+ ChatMessage: chatML,
+ },
+ },
+ "functions": []functions.Function{
+ {
+ Name: "test",
+ Description: "test",
+ Parameters: nil,
+ },
+ },
+ "shouldUseFn": true,
+ "messages": []schema.Message{
+ {
+ Role: "assistant",
+ FunctionCall: map[string]string{"function": "test"},
+ },
},
},
"function_response": {
- "template": chatML,
"expected": "<|im_start|>tool\n\nResponse from tool\n<|im_end|>",
- "data": ChatMessageTemplateData{
- SystemPrompt: "",
- Role: "tool",
- RoleName: "tool",
- Content: "Response from tool",
- FunctionCall: nil,
- FunctionName: "",
- LastMessage: false,
- Function: false,
- MessageIndex: 0,
+ "config": &config.BackendConfig{
+ TemplateConfig: config.TemplateConfig{
+ ChatMessage: chatML,
+ },
+ },
+ "functions": []functions.Function{},
+ "shouldUseFn": false,
+ "messages": []schema.Message{
+ {
+ Role: "tool",
+ StringContent: "Response from tool",
+ },
},
},
}
@@ -174,8 +192,7 @@ var _ = Describe("Templates", func() {
for key := range chatMLTestMatch {
foo := chatMLTestMatch[key]
It("renders correctly `"+key+"`", func() {
- templated, err := evaluator.EvaluateTemplateForChatMessage(foo["template"].(string), foo["data"].(ChatMessageTemplateData))
- Expect(err).ToNot(HaveOccurred())
+ templated := evaluator.TemplateMessages(foo["messages"].([]schema.Message), foo["config"].(*config.BackendConfig), foo["functions"].([]functions.Function), foo["shouldUseFn"].(bool))
Expect(templated).To(Equal(foo["expected"]), templated)
})
}
@@ -188,8 +205,7 @@ var _ = Describe("Templates", func() {
for key := range llama3TestMatch {
foo := llama3TestMatch[key]
It("renders correctly `"+key+"`", func() {
- templated, err := evaluator.EvaluateTemplateForChatMessage(foo["template"].(string), foo["data"].(ChatMessageTemplateData))
- Expect(err).ToNot(HaveOccurred())
+ templated := evaluator.TemplateMessages(foo["messages"].([]schema.Message), foo["config"].(*config.BackendConfig), foo["functions"].([]functions.Function), foo["shouldUseFn"].(bool))
Expect(templated).To(Equal(foo["expected"]), templated)
})
}