diff --git a/backend/cpp/llama/grpc-server.cpp b/backend/cpp/llama/grpc-server.cpp index be277bfa..bac46095 100644 --- a/backend/cpp/llama/grpc-server.cpp +++ b/backend/cpp/llama/grpc-server.cpp @@ -133,6 +133,15 @@ json parse_options(bool streaming, const backend::PredictOptions* predict) }); } + // for each audio in the request, add the audio data + for (int i = 0; i < predict->audios_size(); i++) { + data["audio_data"].push_back(json + { + {"id", i}, + {"data", predict->audios(i)}, + }); + } + data["stop"] = predict->stopprompts(); // data["n_probs"] = predict->nprobs(); //TODO: images, @@ -406,6 +415,16 @@ public: } } + const auto &audio_data = data.find("audio_data"); + if (audio_data != data.end() && audio_data->is_array()) + { + for (const auto &audio : *audio_data) + { + auto decoded_data = base64_decode(audio["data"].get()); + files.push_back(decoded_data); + } + } + // process files mtmd::bitmaps bitmaps; const bool has_mtmd = ctx_server.mctx != nullptr; @@ -416,10 +435,10 @@ public: for (auto & file : files) { mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_buf(file.data(), file.size())); if (!bmp.ptr) { - throw std::runtime_error("Failed to load image"); + throw std::runtime_error("Failed to load image/audio"); } // calculate bitmap hash (for KV caching) - std::string hash = fnv_hash(bmp.data(), bmp.nx()*bmp.ny()*3); + std::string hash = fnv_hash(bmp.data(), bmp.n_bytes()); bmp.set_id(hash.c_str()); bitmaps.entries.push_back(std::move(bmp)); } @@ -588,6 +607,16 @@ public: } } + const auto &audio_data = data.find("audio_data"); + if (audio_data != data.end() && audio_data->is_array()) + { + for (const auto &audio : *audio_data) + { + auto decoded_data = base64_decode(audio["data"].get()); + files.push_back(decoded_data); + } + } + // process files mtmd::bitmaps bitmaps; const bool has_mtmd = ctx_server.mctx != nullptr; @@ -598,10 +627,10 @@ public: for (auto & file : files) { mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_buf(file.data(), file.size())); if (!bmp.ptr) { - throw std::runtime_error("Failed to load image"); + throw std::runtime_error("Failed to load image/audio"); } // calculate bitmap hash (for KV caching) - std::string hash = fnv_hash(bmp.data(), bmp.nx()*bmp.ny()*3); + std::string hash = fnv_hash(bmp.data(), bmp.n_bytes()); bmp.set_id(hash.c_str()); bitmaps.entries.push_back(std::move(bmp)); } diff --git a/core/http/middleware/request.go b/core/http/middleware/request.go index b6934a82..09f6b6ee 100644 --- a/core/http/middleware/request.go +++ b/core/http/middleware/request.go @@ -308,7 +308,7 @@ func mergeOpenAIRequestAndBackendConfig(config *config.BackendConfig, input *sch input.Messages[i].StringVideos = append(input.Messages[i].StringVideos, base64) // TODO: make sure that we only return base64 stuff vidIndex++ nrOfVideosInMessage++ - case "audio_url", "audio": + case "audio_url", "audio", "input_audio": // Decode content as base64 either if it's an URL or base64 text base64, err := utils.GetContentURIAsBase64(pp.AudioURL.URL) if err != nil { diff --git a/pkg/templates/multimodal.go b/pkg/templates/multimodal.go index 1436b85d..bc8bad7e 100644 --- a/pkg/templates/multimodal.go +++ b/pkg/templates/multimodal.go @@ -22,7 +22,8 @@ type MultimodalContent struct { } // https://github.com/ggml-org/llama.cpp/blob/be1d4a13db26750fac702ceb3af88ae4f39dc9f4/tools/mtmd/mtmd.h#L42 -const DefaultMultiModalTemplate = "{{ range .Audio }}[audio-{{.ID}}]{{end}}{{ range .Images }}<__image__>{{end}}{{ range .Video }}[vid-{{.ID}}]{{end}}{{.Text}}" +// from <__image__> to <__media__> https://github.com/ggml-org/llama.cpp/blob/79c137f77677b3c8ee3c60a7da033721b938399a/tools/mtmd/mtmd.cpp#L83 +const DefaultMultiModalTemplate = "{{ range .Audio }}<__media__>{{end}}{{ range .Images }}<__media__>{{end}}{{ range .Video }}[vid-{{.ID}}]{{end}}{{.Text}}" func TemplateMultiModal(templateString string, opts MultiModalOptions, text string) (string, error) { if templateString == "" { diff --git a/pkg/templates/multimodal_test.go b/pkg/templates/multimodal_test.go index d0918697..adea6b52 100644 --- a/pkg/templates/multimodal_test.go +++ b/pkg/templates/multimodal_test.go @@ -20,7 +20,7 @@ var _ = Describe("EvaluateTemplate", func() { VideosInMessage: 0, }, "bar") Expect(err).NotTo(HaveOccurred()) - Expect(result).To(Equal("<__image__>bar")) + Expect(result).To(Equal("<__media__>bar")) }) It("should handle messages with more images correctly", func() { @@ -33,7 +33,7 @@ var _ = Describe("EvaluateTemplate", func() { VideosInMessage: 0, }, "bar") Expect(err).NotTo(HaveOccurred()) - Expect(result).To(Equal("<__image__><__image__>bar")) + Expect(result).To(Equal("<__media__><__media__>bar")) }) It("should handle messages with more images correctly", func() { result, err := TemplateMultiModal("", MultiModalOptions{ @@ -45,7 +45,7 @@ var _ = Describe("EvaluateTemplate", func() { VideosInMessage: 0, }, "bar") Expect(err).NotTo(HaveOccurred()) - Expect(result).To(Equal("[audio-0]<__image__><__image__>bar")) + Expect(result).To(Equal("<__media__><__media__><__media__>bar")) }) It("should handle messages with more images correctly", func() { result, err := TemplateMultiModal("", MultiModalOptions{ @@ -57,7 +57,7 @@ var _ = Describe("EvaluateTemplate", func() { VideosInMessage: 0, }, "bar") Expect(err).NotTo(HaveOccurred()) - Expect(result).To(Equal("[audio-0]<__image__>bar")) + Expect(result).To(Equal("<__media__><__media__>bar")) }) It("should handle messages with more images correctly", func() { result, err := TemplateMultiModal("", MultiModalOptions{