feat(llama.cpp): add support for audio input (#5466)

* feat(llama.cpp): add support for audio input

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Adapt tests

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto 2025-05-26 16:06:03 +02:00 committed by GitHub
parent 9650d490d4
commit 88de2ea01a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 40 additions and 10 deletions

View file

@ -133,6 +133,15 @@ json parse_options(bool streaming, const backend::PredictOptions* predict)
});
}
// for each audio in the request, add the audio data
for (int i = 0; i < predict->audios_size(); i++) {
data["audio_data"].push_back(json
{
{"id", i},
{"data", predict->audios(i)},
});
}
data["stop"] = predict->stopprompts();
// data["n_probs"] = predict->nprobs();
//TODO: images,
@ -406,6 +415,16 @@ public:
}
}
const auto &audio_data = data.find("audio_data");
if (audio_data != data.end() && audio_data->is_array())
{
for (const auto &audio : *audio_data)
{
auto decoded_data = base64_decode(audio["data"].get<std::string>());
files.push_back(decoded_data);
}
}
// process files
mtmd::bitmaps bitmaps;
const bool has_mtmd = ctx_server.mctx != nullptr;
@ -416,10 +435,10 @@ public:
for (auto & file : files) {
mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_buf(file.data(), file.size()));
if (!bmp.ptr) {
throw std::runtime_error("Failed to load image");
throw std::runtime_error("Failed to load image/audio");
}
// calculate bitmap hash (for KV caching)
std::string hash = fnv_hash(bmp.data(), bmp.nx()*bmp.ny()*3);
std::string hash = fnv_hash(bmp.data(), bmp.n_bytes());
bmp.set_id(hash.c_str());
bitmaps.entries.push_back(std::move(bmp));
}
@ -588,6 +607,16 @@ public:
}
}
const auto &audio_data = data.find("audio_data");
if (audio_data != data.end() && audio_data->is_array())
{
for (const auto &audio : *audio_data)
{
auto decoded_data = base64_decode(audio["data"].get<std::string>());
files.push_back(decoded_data);
}
}
// process files
mtmd::bitmaps bitmaps;
const bool has_mtmd = ctx_server.mctx != nullptr;
@ -598,10 +627,10 @@ public:
for (auto & file : files) {
mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_buf(file.data(), file.size()));
if (!bmp.ptr) {
throw std::runtime_error("Failed to load image");
throw std::runtime_error("Failed to load image/audio");
}
// calculate bitmap hash (for KV caching)
std::string hash = fnv_hash(bmp.data(), bmp.nx()*bmp.ny()*3);
std::string hash = fnv_hash(bmp.data(), bmp.n_bytes());
bmp.set_id(hash.c_str());
bitmaps.entries.push_back(std::move(bmp));
}

View file

@ -308,7 +308,7 @@ func mergeOpenAIRequestAndBackendConfig(config *config.BackendConfig, input *sch
input.Messages[i].StringVideos = append(input.Messages[i].StringVideos, base64) // TODO: make sure that we only return base64 stuff
vidIndex++
nrOfVideosInMessage++
case "audio_url", "audio":
case "audio_url", "audio", "input_audio":
// Decode content as base64 either if it's an URL or base64 text
base64, err := utils.GetContentURIAsBase64(pp.AudioURL.URL)
if err != nil {

View file

@ -22,7 +22,8 @@ type MultimodalContent struct {
}
// https://github.com/ggml-org/llama.cpp/blob/be1d4a13db26750fac702ceb3af88ae4f39dc9f4/tools/mtmd/mtmd.h#L42
const DefaultMultiModalTemplate = "{{ range .Audio }}[audio-{{.ID}}]{{end}}{{ range .Images }}<__image__>{{end}}{{ range .Video }}[vid-{{.ID}}]{{end}}{{.Text}}"
// from <__image__> to <__media__> https://github.com/ggml-org/llama.cpp/blob/79c137f77677b3c8ee3c60a7da033721b938399a/tools/mtmd/mtmd.cpp#L83
const DefaultMultiModalTemplate = "{{ range .Audio }}<__media__>{{end}}{{ range .Images }}<__media__>{{end}}{{ range .Video }}[vid-{{.ID}}]{{end}}{{.Text}}"
func TemplateMultiModal(templateString string, opts MultiModalOptions, text string) (string, error) {
if templateString == "" {

View file

@ -20,7 +20,7 @@ var _ = Describe("EvaluateTemplate", func() {
VideosInMessage: 0,
}, "bar")
Expect(err).NotTo(HaveOccurred())
Expect(result).To(Equal("<__image__>bar"))
Expect(result).To(Equal("<__media__>bar"))
})
It("should handle messages with more images correctly", func() {
@ -33,7 +33,7 @@ var _ = Describe("EvaluateTemplate", func() {
VideosInMessage: 0,
}, "bar")
Expect(err).NotTo(HaveOccurred())
Expect(result).To(Equal("<__image__><__image__>bar"))
Expect(result).To(Equal("<__media__><__media__>bar"))
})
It("should handle messages with more images correctly", func() {
result, err := TemplateMultiModal("", MultiModalOptions{
@ -45,7 +45,7 @@ var _ = Describe("EvaluateTemplate", func() {
VideosInMessage: 0,
}, "bar")
Expect(err).NotTo(HaveOccurred())
Expect(result).To(Equal("[audio-0]<__image__><__image__>bar"))
Expect(result).To(Equal("<__media__><__media__><__media__>bar"))
})
It("should handle messages with more images correctly", func() {
result, err := TemplateMultiModal("", MultiModalOptions{
@ -57,7 +57,7 @@ var _ = Describe("EvaluateTemplate", func() {
VideosInMessage: 0,
}, "bar")
Expect(err).NotTo(HaveOccurred())
Expect(result).To(Equal("[audio-0]<__image__>bar"))
Expect(result).To(Equal("<__media__><__media__>bar"))
})
It("should handle messages with more images correctly", func() {
result, err := TemplateMultiModal("", MultiModalOptions{