diff --git a/pkg/templates/evaluator.go b/pkg/templates/evaluator.go index 7b2089b3..00b1c0a6 100644 --- a/pkg/templates/evaluator.go +++ b/pkg/templates/evaluator.go @@ -85,13 +85,13 @@ func (e *Evaluator) EvaluateTemplateForPrompt(templateType TemplateType, config } if config.TemplateConfig.JinjaTemplate { - return e.EvaluateJinjaTemplateForPrompt(templateType, template, in) + return e.evaluateJinjaTemplateForPrompt(templateType, template, in) } return e.cache.EvaluateTemplate(templateType, template, in) } -func (e *Evaluator) EvaluateTemplateForChatMessage(templateName string, messageData ChatMessageTemplateData) (string, error) { +func (e *Evaluator) evaluateTemplateForChatMessage(templateName string, messageData ChatMessageTemplateData) (string, error) { return e.cache.EvaluateTemplate(ChatMessageTemplate, templateName, messageData) } @@ -120,7 +120,7 @@ func (e *Evaluator) templateJinjaChat(templateName string, messageData []ChatMes return e.cache.EvaluateJinjaTemplate(ChatMessageTemplate, templateName, conversation) } -func (e *Evaluator) EvaluateJinjaTemplateForPrompt(templateType TemplateType, templateName string, in PromptTemplateData) (string, error) { +func (e *Evaluator) evaluateJinjaTemplateForPrompt(templateType TemplateType, templateName string, in PromptTemplateData) (string, error) { conversation := make(map[string]interface{}) @@ -195,7 +195,7 @@ func (e *Evaluator) TemplateMessages(messages []schema.Message, config *config.B Function: config.Grammar != "" && (messageIndex == (len(messages) - 1)), MessageIndex: messageIndex, } - templatedChatMessage, err := e.EvaluateTemplateForChatMessage(config.TemplateConfig.ChatMessage, chatMessageData) + templatedChatMessage, err := e.evaluateTemplateForChatMessage(config.TemplateConfig.ChatMessage, chatMessageData) if err != nil { log.Error().Err(err).Interface("message", chatMessageData).Str("template", config.TemplateConfig.ChatMessage).Msg("error processing message with template, skipping") } else { diff --git a/pkg/templates/evaluator_test.go b/pkg/templates/evaluator_test.go index 06551e4d..ab150577 100644 --- a/pkg/templates/evaluator_test.go +++ b/pkg/templates/evaluator_test.go @@ -1,6 +1,9 @@ package templates_test import ( + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/schema" + "github.com/mudler/LocalAI/pkg/functions" . "github.com/mudler/LocalAI/pkg/templates" . "github.com/onsi/ginkgo/v2" @@ -41,126 +44,141 @@ Function response: var llama3TestMatch map[string]map[string]interface{} = map[string]map[string]interface{}{ "user": { - "template": llama3, "expected": "<|start_header_id|>user<|end_header_id|>\n\nA long time ago in a galaxy far, far away...<|eot_id|>", - "data": ChatMessageTemplateData{ - SystemPrompt: "", - Role: "user", - RoleName: "user", - Content: "A long time ago in a galaxy far, far away...", - FunctionCall: nil, - FunctionName: "", - LastMessage: false, - Function: false, - MessageIndex: 0, + "config": &config.BackendConfig{ + TemplateConfig: config.TemplateConfig{ + ChatMessage: llama3, + }, + }, + "functions": []functions.Function{}, + "shouldUseFn": false, + "messages": []schema.Message{ + { + Role: "user", + StringContent: "A long time ago in a galaxy far, far away...", + }, }, }, "assistant": { - "template": llama3, "expected": "<|start_header_id|>assistant<|end_header_id|>\n\nA long time ago in a galaxy far, far away...<|eot_id|>", - "data": ChatMessageTemplateData{ - SystemPrompt: "", - Role: "assistant", - RoleName: "assistant", - Content: "A long time ago in a galaxy far, far away...", - FunctionCall: nil, - FunctionName: "", - LastMessage: false, - Function: false, - MessageIndex: 0, + "config": &config.BackendConfig{ + TemplateConfig: config.TemplateConfig{ + ChatMessage: llama3, + }, }, + "functions": []functions.Function{}, + "messages": []schema.Message{ + { + Role: "assistant", + StringContent: "A long time ago in a galaxy far, far away...", + }, + }, + "shouldUseFn": false, }, "function_call": { - "template": llama3, + "expected": "<|start_header_id|>assistant<|end_header_id|>\n\nFunction call:\n{\"function\":\"test\"}<|eot_id|>", - "data": ChatMessageTemplateData{ - SystemPrompt: "", - Role: "assistant", - RoleName: "assistant", - Content: "", - FunctionCall: map[string]string{"function": "test"}, - FunctionName: "", - LastMessage: false, - Function: false, - MessageIndex: 0, + "config": &config.BackendConfig{ + TemplateConfig: config.TemplateConfig{ + ChatMessage: llama3, + }, }, + "functions": []functions.Function{}, + "messages": []schema.Message{ + { + Role: "assistant", + FunctionCall: map[string]string{"function": "test"}, + }, + }, + "shouldUseFn": false, }, "function_response": { - "template": llama3, "expected": "<|start_header_id|>tool<|end_header_id|>\n\nFunction response:\nResponse from tool<|eot_id|>", - "data": ChatMessageTemplateData{ - SystemPrompt: "", - Role: "tool", - RoleName: "tool", - Content: "Response from tool", - FunctionCall: nil, - FunctionName: "", - LastMessage: false, - Function: false, - MessageIndex: 0, + "config": &config.BackendConfig{ + TemplateConfig: config.TemplateConfig{ + ChatMessage: llama3, + }, }, + "functions": []functions.Function{}, + "messages": []schema.Message{ + { + Role: "tool", + StringContent: "Response from tool", + }, + }, + "shouldUseFn": false, }, } var chatMLTestMatch map[string]map[string]interface{} = map[string]map[string]interface{}{ "user": { - "template": chatML, "expected": "<|im_start|>user\nA long time ago in a galaxy far, far away...<|im_end|>", - "data": ChatMessageTemplateData{ - SystemPrompt: "", - Role: "user", - RoleName: "user", - Content: "A long time ago in a galaxy far, far away...", - FunctionCall: nil, - FunctionName: "", - LastMessage: false, - Function: false, - MessageIndex: 0, + "config": &config.BackendConfig{ + TemplateConfig: config.TemplateConfig{ + ChatMessage: chatML, + }, + }, + "functions": []functions.Function{}, + "shouldUseFn": false, + "messages": []schema.Message{ + { + Role: "user", + StringContent: "A long time ago in a galaxy far, far away...", + }, }, }, "assistant": { - "template": chatML, "expected": "<|im_start|>assistant\nA long time ago in a galaxy far, far away...<|im_end|>", - "data": ChatMessageTemplateData{ - SystemPrompt: "", - Role: "assistant", - RoleName: "assistant", - Content: "A long time ago in a galaxy far, far away...", - FunctionCall: nil, - FunctionName: "", - LastMessage: false, - Function: false, - MessageIndex: 0, + "config": &config.BackendConfig{ + TemplateConfig: config.TemplateConfig{ + ChatMessage: chatML, + }, }, + "functions": []functions.Function{}, + "messages": []schema.Message{ + { + Role: "assistant", + StringContent: "A long time ago in a galaxy far, far away...", + }, + }, + "shouldUseFn": false, }, "function_call": { - "template": chatML, "expected": "<|im_start|>assistant\n\n{\"function\":\"test\"}\n<|im_end|>", - "data": ChatMessageTemplateData{ - SystemPrompt: "", - Role: "assistant", - RoleName: "assistant", - Content: "", - FunctionCall: map[string]string{"function": "test"}, - FunctionName: "", - LastMessage: false, - Function: false, - MessageIndex: 0, + "config": &config.BackendConfig{ + TemplateConfig: config.TemplateConfig{ + ChatMessage: chatML, + }, + }, + "functions": []functions.Function{ + { + Name: "test", + Description: "test", + Parameters: nil, + }, + }, + "shouldUseFn": true, + "messages": []schema.Message{ + { + Role: "assistant", + FunctionCall: map[string]string{"function": "test"}, + }, }, }, "function_response": { - "template": chatML, "expected": "<|im_start|>tool\n\nResponse from tool\n<|im_end|>", - "data": ChatMessageTemplateData{ - SystemPrompt: "", - Role: "tool", - RoleName: "tool", - Content: "Response from tool", - FunctionCall: nil, - FunctionName: "", - LastMessage: false, - Function: false, - MessageIndex: 0, + "config": &config.BackendConfig{ + TemplateConfig: config.TemplateConfig{ + ChatMessage: chatML, + }, + }, + "functions": []functions.Function{}, + "shouldUseFn": false, + "messages": []schema.Message{ + { + Role: "tool", + StringContent: "Response from tool", + }, }, }, } @@ -174,8 +192,7 @@ var _ = Describe("Templates", func() { for key := range chatMLTestMatch { foo := chatMLTestMatch[key] It("renders correctly `"+key+"`", func() { - templated, err := evaluator.EvaluateTemplateForChatMessage(foo["template"].(string), foo["data"].(ChatMessageTemplateData)) - Expect(err).ToNot(HaveOccurred()) + templated := evaluator.TemplateMessages(foo["messages"].([]schema.Message), foo["config"].(*config.BackendConfig), foo["functions"].([]functions.Function), foo["shouldUseFn"].(bool)) Expect(templated).To(Equal(foo["expected"]), templated) }) } @@ -188,8 +205,7 @@ var _ = Describe("Templates", func() { for key := range llama3TestMatch { foo := llama3TestMatch[key] It("renders correctly `"+key+"`", func() { - templated, err := evaluator.EvaluateTemplateForChatMessage(foo["template"].(string), foo["data"].(ChatMessageTemplateData)) - Expect(err).ToNot(HaveOccurred()) + templated := evaluator.TemplateMessages(foo["messages"].([]schema.Message), foo["config"].(*config.BackendConfig), foo["functions"].([]functions.Function), foo["shouldUseFn"].(bool)) Expect(templated).To(Equal(foo["expected"]), templated) }) }