Test TemplateMessages

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto 2024-12-07 22:10:13 +01:00
parent 91465797d3
commit d611d16ac4
2 changed files with 112 additions and 96 deletions

View file

@ -85,13 +85,13 @@ func (e *Evaluator) EvaluateTemplateForPrompt(templateType TemplateType, config
} }
if config.TemplateConfig.JinjaTemplate { if config.TemplateConfig.JinjaTemplate {
return e.EvaluateJinjaTemplateForPrompt(templateType, template, in) return e.evaluateJinjaTemplateForPrompt(templateType, template, in)
} }
return e.cache.EvaluateTemplate(templateType, template, in) return e.cache.EvaluateTemplate(templateType, template, in)
} }
func (e *Evaluator) EvaluateTemplateForChatMessage(templateName string, messageData ChatMessageTemplateData) (string, error) { func (e *Evaluator) evaluateTemplateForChatMessage(templateName string, messageData ChatMessageTemplateData) (string, error) {
return e.cache.EvaluateTemplate(ChatMessageTemplate, templateName, messageData) return e.cache.EvaluateTemplate(ChatMessageTemplate, templateName, messageData)
} }
@ -120,7 +120,7 @@ func (e *Evaluator) templateJinjaChat(templateName string, messageData []ChatMes
return e.cache.EvaluateJinjaTemplate(ChatMessageTemplate, templateName, conversation) return e.cache.EvaluateJinjaTemplate(ChatMessageTemplate, templateName, conversation)
} }
func (e *Evaluator) EvaluateJinjaTemplateForPrompt(templateType TemplateType, templateName string, in PromptTemplateData) (string, error) { func (e *Evaluator) evaluateJinjaTemplateForPrompt(templateType TemplateType, templateName string, in PromptTemplateData) (string, error) {
conversation := make(map[string]interface{}) conversation := make(map[string]interface{})
@ -195,7 +195,7 @@ func (e *Evaluator) TemplateMessages(messages []schema.Message, config *config.B
Function: config.Grammar != "" && (messageIndex == (len(messages) - 1)), Function: config.Grammar != "" && (messageIndex == (len(messages) - 1)),
MessageIndex: messageIndex, MessageIndex: messageIndex,
} }
templatedChatMessage, err := e.EvaluateTemplateForChatMessage(config.TemplateConfig.ChatMessage, chatMessageData) templatedChatMessage, err := e.evaluateTemplateForChatMessage(config.TemplateConfig.ChatMessage, chatMessageData)
if err != nil { if err != nil {
log.Error().Err(err).Interface("message", chatMessageData).Str("template", config.TemplateConfig.ChatMessage).Msg("error processing message with template, skipping") log.Error().Err(err).Interface("message", chatMessageData).Str("template", config.TemplateConfig.ChatMessage).Msg("error processing message with template, skipping")
} else { } else {

View file

@ -1,6 +1,9 @@
package templates_test package templates_test
import ( import (
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/functions"
. "github.com/mudler/LocalAI/pkg/templates" . "github.com/mudler/LocalAI/pkg/templates"
. "github.com/onsi/ginkgo/v2" . "github.com/onsi/ginkgo/v2"
@ -41,126 +44,141 @@ Function response:
var llama3TestMatch map[string]map[string]interface{} = map[string]map[string]interface{}{ var llama3TestMatch map[string]map[string]interface{} = map[string]map[string]interface{}{
"user": { "user": {
"template": llama3,
"expected": "<|start_header_id|>user<|end_header_id|>\n\nA long time ago in a galaxy far, far away...<|eot_id|>", "expected": "<|start_header_id|>user<|end_header_id|>\n\nA long time ago in a galaxy far, far away...<|eot_id|>",
"data": ChatMessageTemplateData{ "config": &config.BackendConfig{
SystemPrompt: "", TemplateConfig: config.TemplateConfig{
Role: "user", ChatMessage: llama3,
RoleName: "user", },
Content: "A long time ago in a galaxy far, far away...", },
FunctionCall: nil, "functions": []functions.Function{},
FunctionName: "", "shouldUseFn": false,
LastMessage: false, "messages": []schema.Message{
Function: false, {
MessageIndex: 0, Role: "user",
StringContent: "A long time ago in a galaxy far, far away...",
},
}, },
}, },
"assistant": { "assistant": {
"template": llama3,
"expected": "<|start_header_id|>assistant<|end_header_id|>\n\nA long time ago in a galaxy far, far away...<|eot_id|>", "expected": "<|start_header_id|>assistant<|end_header_id|>\n\nA long time ago in a galaxy far, far away...<|eot_id|>",
"data": ChatMessageTemplateData{ "config": &config.BackendConfig{
SystemPrompt: "", TemplateConfig: config.TemplateConfig{
Role: "assistant", ChatMessage: llama3,
RoleName: "assistant", },
Content: "A long time ago in a galaxy far, far away...",
FunctionCall: nil,
FunctionName: "",
LastMessage: false,
Function: false,
MessageIndex: 0,
}, },
"functions": []functions.Function{},
"messages": []schema.Message{
{
Role: "assistant",
StringContent: "A long time ago in a galaxy far, far away...",
},
},
"shouldUseFn": false,
}, },
"function_call": { "function_call": {
"template": llama3,
"expected": "<|start_header_id|>assistant<|end_header_id|>\n\nFunction call:\n{\"function\":\"test\"}<|eot_id|>", "expected": "<|start_header_id|>assistant<|end_header_id|>\n\nFunction call:\n{\"function\":\"test\"}<|eot_id|>",
"data": ChatMessageTemplateData{ "config": &config.BackendConfig{
SystemPrompt: "", TemplateConfig: config.TemplateConfig{
Role: "assistant", ChatMessage: llama3,
RoleName: "assistant", },
Content: "",
FunctionCall: map[string]string{"function": "test"},
FunctionName: "",
LastMessage: false,
Function: false,
MessageIndex: 0,
}, },
"functions": []functions.Function{},
"messages": []schema.Message{
{
Role: "assistant",
FunctionCall: map[string]string{"function": "test"},
},
},
"shouldUseFn": false,
}, },
"function_response": { "function_response": {
"template": llama3,
"expected": "<|start_header_id|>tool<|end_header_id|>\n\nFunction response:\nResponse from tool<|eot_id|>", "expected": "<|start_header_id|>tool<|end_header_id|>\n\nFunction response:\nResponse from tool<|eot_id|>",
"data": ChatMessageTemplateData{ "config": &config.BackendConfig{
SystemPrompt: "", TemplateConfig: config.TemplateConfig{
Role: "tool", ChatMessage: llama3,
RoleName: "tool", },
Content: "Response from tool",
FunctionCall: nil,
FunctionName: "",
LastMessage: false,
Function: false,
MessageIndex: 0,
}, },
"functions": []functions.Function{},
"messages": []schema.Message{
{
Role: "tool",
StringContent: "Response from tool",
},
},
"shouldUseFn": false,
}, },
} }
var chatMLTestMatch map[string]map[string]interface{} = map[string]map[string]interface{}{ var chatMLTestMatch map[string]map[string]interface{} = map[string]map[string]interface{}{
"user": { "user": {
"template": chatML,
"expected": "<|im_start|>user\nA long time ago in a galaxy far, far away...<|im_end|>", "expected": "<|im_start|>user\nA long time ago in a galaxy far, far away...<|im_end|>",
"data": ChatMessageTemplateData{ "config": &config.BackendConfig{
SystemPrompt: "", TemplateConfig: config.TemplateConfig{
Role: "user", ChatMessage: chatML,
RoleName: "user", },
Content: "A long time ago in a galaxy far, far away...", },
FunctionCall: nil, "functions": []functions.Function{},
FunctionName: "", "shouldUseFn": false,
LastMessage: false, "messages": []schema.Message{
Function: false, {
MessageIndex: 0, Role: "user",
StringContent: "A long time ago in a galaxy far, far away...",
},
}, },
}, },
"assistant": { "assistant": {
"template": chatML,
"expected": "<|im_start|>assistant\nA long time ago in a galaxy far, far away...<|im_end|>", "expected": "<|im_start|>assistant\nA long time ago in a galaxy far, far away...<|im_end|>",
"data": ChatMessageTemplateData{ "config": &config.BackendConfig{
SystemPrompt: "", TemplateConfig: config.TemplateConfig{
Role: "assistant", ChatMessage: chatML,
RoleName: "assistant", },
Content: "A long time ago in a galaxy far, far away...",
FunctionCall: nil,
FunctionName: "",
LastMessage: false,
Function: false,
MessageIndex: 0,
}, },
"functions": []functions.Function{},
"messages": []schema.Message{
{
Role: "assistant",
StringContent: "A long time ago in a galaxy far, far away...",
},
},
"shouldUseFn": false,
}, },
"function_call": { "function_call": {
"template": chatML,
"expected": "<|im_start|>assistant\n<tool_call>\n{\"function\":\"test\"}\n</tool_call><|im_end|>", "expected": "<|im_start|>assistant\n<tool_call>\n{\"function\":\"test\"}\n</tool_call><|im_end|>",
"data": ChatMessageTemplateData{ "config": &config.BackendConfig{
SystemPrompt: "", TemplateConfig: config.TemplateConfig{
Role: "assistant", ChatMessage: chatML,
RoleName: "assistant", },
Content: "", },
FunctionCall: map[string]string{"function": "test"}, "functions": []functions.Function{
FunctionName: "", {
LastMessage: false, Name: "test",
Function: false, Description: "test",
MessageIndex: 0, Parameters: nil,
},
},
"shouldUseFn": true,
"messages": []schema.Message{
{
Role: "assistant",
FunctionCall: map[string]string{"function": "test"},
},
}, },
}, },
"function_response": { "function_response": {
"template": chatML,
"expected": "<|im_start|>tool\n<tool_response>\nResponse from tool\n</tool_response><|im_end|>", "expected": "<|im_start|>tool\n<tool_response>\nResponse from tool\n</tool_response><|im_end|>",
"data": ChatMessageTemplateData{ "config": &config.BackendConfig{
SystemPrompt: "", TemplateConfig: config.TemplateConfig{
Role: "tool", ChatMessage: chatML,
RoleName: "tool", },
Content: "Response from tool", },
FunctionCall: nil, "functions": []functions.Function{},
FunctionName: "", "shouldUseFn": false,
LastMessage: false, "messages": []schema.Message{
Function: false, {
MessageIndex: 0, Role: "tool",
StringContent: "Response from tool",
},
}, },
}, },
} }
@ -174,8 +192,7 @@ var _ = Describe("Templates", func() {
for key := range chatMLTestMatch { for key := range chatMLTestMatch {
foo := chatMLTestMatch[key] foo := chatMLTestMatch[key]
It("renders correctly `"+key+"`", func() { It("renders correctly `"+key+"`", func() {
templated, err := evaluator.EvaluateTemplateForChatMessage(foo["template"].(string), foo["data"].(ChatMessageTemplateData)) templated := evaluator.TemplateMessages(foo["messages"].([]schema.Message), foo["config"].(*config.BackendConfig), foo["functions"].([]functions.Function), foo["shouldUseFn"].(bool))
Expect(err).ToNot(HaveOccurred())
Expect(templated).To(Equal(foo["expected"]), templated) Expect(templated).To(Equal(foo["expected"]), templated)
}) })
} }
@ -188,8 +205,7 @@ var _ = Describe("Templates", func() {
for key := range llama3TestMatch { for key := range llama3TestMatch {
foo := llama3TestMatch[key] foo := llama3TestMatch[key]
It("renders correctly `"+key+"`", func() { It("renders correctly `"+key+"`", func() {
templated, err := evaluator.EvaluateTemplateForChatMessage(foo["template"].(string), foo["data"].(ChatMessageTemplateData)) templated := evaluator.TemplateMessages(foo["messages"].([]schema.Message), foo["config"].(*config.BackendConfig), foo["functions"].([]functions.Function), foo["shouldUseFn"].(bool))
Expect(err).ToNot(HaveOccurred())
Expect(templated).To(Equal(foo["expected"]), templated) Expect(templated).To(Equal(foo["expected"]), templated)
}) })
} }