feat(templates): use a single template for multimodals messages (#3892)

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto 2024-10-22 09:34:05 +02:00 committed by GitHub
parent a1d6cc93a8
commit ccc7cb0287
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 140 additions and 29 deletions

View file

@ -7,20 +7,60 @@ import (
"github.com/Masterminds/sprig/v3"
)
func TemplateMultiModal(templateString string, templateID int, text string) (string, error) {
type MultiModalOptions struct {
TotalImages int
TotalAudios int
TotalVideos int
ImagesInMessage int
AudiosInMessage int
VideosInMessage int
}
type MultimodalContent struct {
ID int
}
const DefaultMultiModalTemplate = "{{ range .Audio }}[audio-{{.ID}}]{{end}}{{ range .Images }}[img-{{.ID}}]{{end}}{{ range .Video }}[vid-{{.ID}}]{{end}}{{.Text}}"
func TemplateMultiModal(templateString string, opts MultiModalOptions, text string) (string, error) {
if templateString == "" {
templateString = DefaultMultiModalTemplate
}
// compile the template
tmpl, err := template.New("template").Funcs(sprig.FuncMap()).Parse(templateString)
if err != nil {
return "", err
}
videos := []MultimodalContent{}
for i := 0; i < opts.VideosInMessage; i++ {
videos = append(videos, MultimodalContent{ID: i + (opts.TotalVideos - opts.VideosInMessage)})
}
audios := []MultimodalContent{}
for i := 0; i < opts.AudiosInMessage; i++ {
audios = append(audios, MultimodalContent{ID: i + (opts.TotalAudios - opts.AudiosInMessage)})
}
images := []MultimodalContent{}
for i := 0; i < opts.ImagesInMessage; i++ {
images = append(images, MultimodalContent{ID: i + (opts.TotalImages - opts.ImagesInMessage)})
}
result := bytes.NewBuffer(nil)
// execute the template
err = tmpl.Execute(result, struct {
ID int
Text string
Audio []MultimodalContent
Images []MultimodalContent
Video []MultimodalContent
Text string
}{
ID: templateID,
Text: text,
Audio: audios,
Images: images,
Video: videos,
Text: text,
})
return result.String(), err
}