feat(templates): use a single template for multimodals messages (#3892)

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto 2024-10-22 09:34:05 +02:00 committed by GitHub
parent a1d6cc93a8
commit ccc7cb0287
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 140 additions and 29 deletions

View file

@ -7,20 +7,60 @@ import (
"github.com/Masterminds/sprig/v3"
)
func TemplateMultiModal(templateString string, templateID int, text string) (string, error) {
type MultiModalOptions struct {
TotalImages int
TotalAudios int
TotalVideos int
ImagesInMessage int
AudiosInMessage int
VideosInMessage int
}
type MultimodalContent struct {
ID int
}
const DefaultMultiModalTemplate = "{{ range .Audio }}[audio-{{.ID}}]{{end}}{{ range .Images }}[img-{{.ID}}]{{end}}{{ range .Video }}[vid-{{.ID}}]{{end}}{{.Text}}"
func TemplateMultiModal(templateString string, opts MultiModalOptions, text string) (string, error) {
if templateString == "" {
templateString = DefaultMultiModalTemplate
}
// compile the template
tmpl, err := template.New("template").Funcs(sprig.FuncMap()).Parse(templateString)
if err != nil {
return "", err
}
videos := []MultimodalContent{}
for i := 0; i < opts.VideosInMessage; i++ {
videos = append(videos, MultimodalContent{ID: i + (opts.TotalVideos - opts.VideosInMessage)})
}
audios := []MultimodalContent{}
for i := 0; i < opts.AudiosInMessage; i++ {
audios = append(audios, MultimodalContent{ID: i + (opts.TotalAudios - opts.AudiosInMessage)})
}
images := []MultimodalContent{}
for i := 0; i < opts.ImagesInMessage; i++ {
images = append(images, MultimodalContent{ID: i + (opts.TotalImages - opts.ImagesInMessage)})
}
result := bytes.NewBuffer(nil)
// execute the template
err = tmpl.Execute(result, struct {
ID int
Text string
Audio []MultimodalContent
Images []MultimodalContent
Video []MultimodalContent
Text string
}{
ID: templateID,
Text: text,
Audio: audios,
Images: images,
Video: videos,
Text: text,
})
return result.String(), err
}

View file

@ -11,7 +11,77 @@ import (
var _ = Describe("EvaluateTemplate", func() {
Context("templating simple strings for multimodal chat", func() {
It("should template messages correctly", func() {
result, err := TemplateMultiModal("[img-{{.ID}}]{{.Text}}", 1, "bar")
result, err := TemplateMultiModal("", MultiModalOptions{
TotalImages: 1,
TotalAudios: 0,
TotalVideos: 0,
ImagesInMessage: 1,
AudiosInMessage: 0,
VideosInMessage: 0,
}, "bar")
Expect(err).NotTo(HaveOccurred())
Expect(result).To(Equal("[img-0]bar"))
})
It("should handle messages with more images correctly", func() {
result, err := TemplateMultiModal("", MultiModalOptions{
TotalImages: 2,
TotalAudios: 0,
TotalVideos: 0,
ImagesInMessage: 2,
AudiosInMessage: 0,
VideosInMessage: 0,
}, "bar")
Expect(err).NotTo(HaveOccurred())
Expect(result).To(Equal("[img-0][img-1]bar"))
})
It("should handle messages with more images correctly", func() {
result, err := TemplateMultiModal("", MultiModalOptions{
TotalImages: 4,
TotalAudios: 1,
TotalVideos: 0,
ImagesInMessage: 2,
AudiosInMessage: 1,
VideosInMessage: 0,
}, "bar")
Expect(err).NotTo(HaveOccurred())
Expect(result).To(Equal("[audio-0][img-2][img-3]bar"))
})
It("should handle messages with more images correctly", func() {
result, err := TemplateMultiModal("", MultiModalOptions{
TotalImages: 3,
TotalAudios: 1,
TotalVideos: 0,
ImagesInMessage: 1,
AudiosInMessage: 1,
VideosInMessage: 0,
}, "bar")
Expect(err).NotTo(HaveOccurred())
Expect(result).To(Equal("[audio-0][img-2]bar"))
})
It("should handle messages with more images correctly", func() {
result, err := TemplateMultiModal("", MultiModalOptions{
TotalImages: 0,
TotalAudios: 0,
TotalVideos: 0,
ImagesInMessage: 0,
AudiosInMessage: 0,
VideosInMessage: 0,
}, "bar")
Expect(err).NotTo(HaveOccurred())
Expect(result).To(Equal("bar"))
})
})
Context("templating with custom defaults", func() {
It("should handle messages with more images correctly", func() {
result, err := TemplateMultiModal("{{ range .Audio }}[audio-{{ add1 .ID}}]{{end}}{{ range .Images }}[img-{{ add1 .ID}}]{{end}}{{ range .Video }}[vid-{{ add1 .ID}}]{{end}}{{.Text}}", MultiModalOptions{
TotalImages: 1,
TotalAudios: 0,
TotalVideos: 0,
ImagesInMessage: 1,
AudiosInMessage: 0,
VideosInMessage: 0,
}, "bar")
Expect(err).NotTo(HaveOccurred())
Expect(result).To(Equal("[img-1]bar"))
})