From 648ffdf449d1cc114bb0d8fb4ceb03b6ba3de4ca Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Fri, 4 Oct 2024 18:32:29 +0200 Subject: [PATCH] feat(multimodal): allow to template placeholders (#3728) feat(multimodal): allow to template image placeholders Signed-off-by: Ettore Di Giacinto --- core/config/backend_config.go | 4 ++++ core/http/endpoints/openai/request.go | 21 ++++++++++++++++++--- pkg/model/initializers.go | 2 +- pkg/templates/multimodal.go | 24 ++++++++++++++++++++++++ pkg/templates/multimodal_test.go | 19 +++++++++++++++++++ 5 files changed, 66 insertions(+), 4 deletions(-) create mode 100644 pkg/templates/multimodal.go create mode 100644 pkg/templates/multimodal_test.go diff --git a/core/config/backend_config.go b/core/config/backend_config.go index 8db94f7c..79e134d8 100644 --- a/core/config/backend_config.go +++ b/core/config/backend_config.go @@ -196,6 +196,10 @@ type TemplateConfig struct { // JoinChatMessagesByCharacter is a string that will be used to join chat messages together. 
// It defaults to \n JoinChatMessagesByCharacter *string `yaml:"join_chat_messages_by_character"` + + Video string `yaml:"video"` + Image string `yaml:"image"` + Audio string `yaml:"audio"` } func (c *BackendConfig) UnmarshalYAML(value *yaml.Node) error { diff --git a/core/http/endpoints/openai/request.go b/core/http/endpoints/openai/request.go index d6182a39..a418433e 100644 --- a/core/http/endpoints/openai/request.go +++ b/core/http/endpoints/openai/request.go @@ -12,6 +12,7 @@ import ( "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/pkg/functions" "github.com/mudler/LocalAI/pkg/model" + "github.com/mudler/LocalAI/pkg/templates" "github.com/mudler/LocalAI/pkg/utils" "github.com/rs/zerolog/log" ) @@ -168,8 +169,13 @@ func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIReque continue CONTENT } input.Messages[i].StringVideos = append(input.Messages[i].StringVideos, base64) // TODO: make sure that we only return base64 stuff + + t := "[vid-{{.ID}}]{{.Text}}" + if config.TemplateConfig.Video != "" { + t = config.TemplateConfig.Video + } // set a placeholder for each image - input.Messages[i].StringContent = fmt.Sprintf("[vid-%d]", vidIndex) + input.Messages[i].StringContent + input.Messages[i].StringContent, _ = templates.TemplateMultiModal(t, vidIndex, input.Messages[i].StringContent) vidIndex++ case "audio_url", "audio": // Decode content as base64 either if it's an URL or base64 text @@ -180,7 +186,11 @@ func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIReque } input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, base64) // TODO: make sure that we only return base64 stuff // set a placeholder for each image - input.Messages[i].StringContent = fmt.Sprintf("[audio-%d]", audioIndex) + input.Messages[i].StringContent + t := "[audio-{{.ID}}]{{.Text}}" + if config.TemplateConfig.Audio != "" { + t = config.TemplateConfig.Audio + } + input.Messages[i].StringContent, _ = 
// TemplateMultiModal renders the placeholder template used to reference a
// multimodal attachment inside a chat message.
//
// templateString is parsed as a text/template and executed with two fields:
//
//   - .ID:   templateID, the numeric identifier of the attachment
//   - .Text: text, the message content the placeholder is combined with
//
// For example, "[img-{{.ID}}]{{.Text}}" with templateID 1 and text "bar"
// renders "[img-1]bar".
//
// It returns the rendered string. If the template cannot be parsed or
// executed, it returns an empty string and the underlying error; partially
// rendered output is never returned.
func TemplateMultiModal(templateString string, templateID int, text string) (string, error) {
	tmpl, err := template.New("template").Parse(templateString)
	if err != nil {
		return "", err
	}

	data := struct {
		ID   int
		Text string
	}{
		ID:   templateID,
		Text: text,
	}

	var rendered bytes.Buffer
	if err := tmpl.Execute(&rendered, data); err != nil {
		// Execute may already have written part of the output before failing;
		// discard it so callers never receive a truncated placeholder.
		return "", err
	}
	return rendered.String(), nil
}
diff --git a/pkg/templates/multimodal_test.go b/pkg/templates/multimodal_test.go new file mode 100644 index 00000000..d1a8bd5b --- /dev/null +++ b/pkg/templates/multimodal_test.go @@ -0,0 +1,19 @@ +package templates_test + +import ( + . "github.com/mudler/LocalAI/pkg/templates" // Update with your module path + + // Update with your module path + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("EvaluateTemplate", func() { + Context("templating simple strings for multimodal chat", func() { + It("should template messages correctly", func() { + result, err := TemplateMultiModal("[img-{{.ID}}]{{.Text}}", 1, "bar") + Expect(err).NotTo(HaveOccurred()) + Expect(result).To(Equal("[img-1]bar")) + }) + }) +})