From 1d6afbd65d24b46c74f71f4b593f359efb54bae3 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Sun, 2 Feb 2025 13:25:03 +0100 Subject: [PATCH] feat(llama.cpp): Add support to grammar triggers (#4733) Signed-off-by: Ettore Di Giacinto --- backend/backend.proto | 7 +++++++ backend/cpp/llama/grpc-server.cpp | 20 ++++++++++++++++++++ core/backend/options.go | 10 ++++++++++ pkg/functions/parse.go | 10 +++++++++- 4 files changed, 46 insertions(+), 1 deletion(-) diff --git a/backend/backend.proto b/backend/backend.proto index fea4214f..bd75adc5 100644 --- a/backend/backend.proto +++ b/backend/backend.proto @@ -163,6 +163,11 @@ message Reply { double timing_token_generation = 5; } +message GrammarTrigger { + string word = 1; + bool at_start = 2; +} + message ModelOptions { string Model = 1; int32 ContextSize = 2; @@ -247,6 +252,8 @@ message ModelOptions { string CacheTypeKey = 63; string CacheTypeValue = 64; + + repeated GrammarTrigger GrammarTriggers = 65; } message Result { diff --git a/backend/cpp/llama/grpc-server.cpp b/backend/cpp/llama/grpc-server.cpp index 9aeb34db..1e9a3551 100644 --- a/backend/cpp/llama/grpc-server.cpp +++ b/backend/cpp/llama/grpc-server.cpp @@ -468,6 +468,9 @@ struct llama_server_context bool add_bos_token = true; bool has_eos_token = true; + bool grammar_lazy = false; + std::vector grammar_trigger_words; + int32_t n_ctx; // total context for all clients / slots // system prompt @@ -706,6 +709,8 @@ struct llama_server_context slot->sparams.grammar = json_value(data, "grammar", default_sparams.grammar); slot->sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs); slot->sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep); + slot->sparams.grammar_trigger_words = grammar_trigger_words; + slot->sparams.grammar_lazy = grammar_lazy; if (slot->n_predict > 0 && slot->params.n_predict > slot->n_predict) { // Might be better to reject the request with a 400 ? @@ -2374,6 +2379,21 @@ static void params_parse(const backend::ModelOptions* request, if ( request->ropefreqscale() != 0.0f ) { params.rope_freq_scale = request->ropefreqscale(); } + + if (request->grammartriggers_size() > 0) { + LOG_INFO("configuring grammar triggers", {}); + llama.grammar_lazy = true; + for (int i = 0; i < request->grammartriggers_size(); i++) { + common_grammar_trigger trigger; + trigger.word = request->grammartriggers(i).word(); + trigger.at_start = request->grammartriggers(i).at_start(); + llama.grammar_trigger_words.push_back(trigger); + LOG_INFO("grammar trigger", { + { "word", trigger.word }, + { "at_start", trigger.at_start } + }); + } + } } diff --git a/core/backend/options.go b/core/backend/options.go index 92a42893..3201142d 100644 --- a/core/backend/options.go +++ b/core/backend/options.go @@ -118,9 +118,19 @@ func grpcModelOpts(c config.BackendConfig) *pb.ModelOptions { nGPULayers = *c.NGPULayers } + triggers := make([]*pb.GrammarTrigger, 0) + for _, t := range c.FunctionsConfig.GrammarConfig.GrammarTriggers { + triggers = append(triggers, &pb.GrammarTrigger{ + Word: t.Word, + AtStart: t.AtStart, + }) + + } + return &pb.ModelOptions{ CUDA: c.CUDA || c.Diffusers.CUDA, SchedulerType: c.Diffusers.SchedulerType, + GrammarTriggers: triggers, PipelineType: c.Diffusers.PipelineType, CFGScale: c.CFGScale, LoraAdapter: c.LoraAdapter, diff --git a/pkg/functions/parse.go b/pkg/functions/parse.go index 50cbb27b..30338ffd 100644 --- a/pkg/functions/parse.go +++ b/pkg/functions/parse.go @@ -47,6 +47,14 @@ type GrammarConfig struct { // SchemaType can be configured to use a specific schema type to force the grammar // available : json, llama3.1 SchemaType string `yaml:"schema_type"` + + GrammarTriggers []GrammarTrigger `yaml:"triggers"` +} + +type GrammarTrigger struct { + // Trigger is the string that triggers the grammar + Word string `yaml:"word"` + AtStart bool `yaml:"at_start"` } // FunctionsConfig is the configuration for the tool/function call. @@ -361,6 +369,6 @@ func ParseFunctionCallArgs(functionArguments string, functionConfig FunctionsCon } jsonBytes, _ := json.Marshal(args) - + return string(jsonBytes) }