feat: allow to specify a reply prefix (#4931)

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto 2025-03-02 16:07:32 +01:00 committed by GitHub
parent ff85f01459
commit a7b4001b75
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 38 additions and 26 deletions

View file

@ -116,6 +116,11 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
} }
if tokenCallback != nil { if tokenCallback != nil {
if c.TemplateConfig.ReplyPrefix != "" {
tokenCallback(c.TemplateConfig.ReplyPrefix, tokenUsage)
}
ss := "" ss := ""
var partialRune []byte var partialRune []byte
@ -165,8 +170,13 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
tokenUsage.TimingTokenGeneration = reply.TimingTokenGeneration tokenUsage.TimingTokenGeneration = reply.TimingTokenGeneration
tokenUsage.TimingPromptProcessing = reply.TimingPromptProcessing tokenUsage.TimingPromptProcessing = reply.TimingPromptProcessing
response := string(reply.Message)
if c.TemplateConfig.ReplyPrefix != "" {
response = c.TemplateConfig.ReplyPrefix + response
}
return LLMResponse{ return LLMResponse{
Response: string(reply.Message), Response: response,
Usage: tokenUsage, Usage: tokenUsage,
}, err }, err
} }

View file

@ -130,28 +130,28 @@ type LLMConfig struct {
TrimSpace []string `yaml:"trimspace"` TrimSpace []string `yaml:"trimspace"`
TrimSuffix []string `yaml:"trimsuffix"` TrimSuffix []string `yaml:"trimsuffix"`
ContextSize *int `yaml:"context_size"` ContextSize *int `yaml:"context_size"`
NUMA bool `yaml:"numa"` NUMA bool `yaml:"numa"`
LoraAdapter string `yaml:"lora_adapter"` LoraAdapter string `yaml:"lora_adapter"`
LoraBase string `yaml:"lora_base"` LoraBase string `yaml:"lora_base"`
LoraAdapters []string `yaml:"lora_adapters"` LoraAdapters []string `yaml:"lora_adapters"`
LoraScales []float32 `yaml:"lora_scales"` LoraScales []float32 `yaml:"lora_scales"`
LoraScale float32 `yaml:"lora_scale"` LoraScale float32 `yaml:"lora_scale"`
NoMulMatQ bool `yaml:"no_mulmatq"` NoMulMatQ bool `yaml:"no_mulmatq"`
DraftModel string `yaml:"draft_model"` DraftModel string `yaml:"draft_model"`
NDraft int32 `yaml:"n_draft"` NDraft int32 `yaml:"n_draft"`
Quantization string `yaml:"quantization"` Quantization string `yaml:"quantization"`
LoadFormat string `yaml:"load_format"` LoadFormat string `yaml:"load_format"`
GPUMemoryUtilization float32 `yaml:"gpu_memory_utilization"` // vLLM GPUMemoryUtilization float32 `yaml:"gpu_memory_utilization"` // vLLM
TrustRemoteCode bool `yaml:"trust_remote_code"` // vLLM TrustRemoteCode bool `yaml:"trust_remote_code"` // vLLM
EnforceEager bool `yaml:"enforce_eager"` // vLLM EnforceEager bool `yaml:"enforce_eager"` // vLLM
SwapSpace int `yaml:"swap_space"` // vLLM SwapSpace int `yaml:"swap_space"` // vLLM
MaxModelLen int `yaml:"max_model_len"` // vLLM MaxModelLen int `yaml:"max_model_len"` // vLLM
TensorParallelSize int `yaml:"tensor_parallel_size"` // vLLM TensorParallelSize int `yaml:"tensor_parallel_size"` // vLLM
DisableLogStatus bool `yaml:"disable_log_stats"` // vLLM DisableLogStatus bool `yaml:"disable_log_stats"` // vLLM
DType string `yaml:"dtype"` // vLLM DType string `yaml:"dtype"` // vLLM
LimitMMPerPrompt LimitMMPerPrompt `yaml:"limit_mm_per_prompt"` // vLLM LimitMMPerPrompt LimitMMPerPrompt `yaml:"limit_mm_per_prompt"` // vLLM
MMProj string `yaml:"mmproj"` MMProj string `yaml:"mmproj"`
FlashAttention bool `yaml:"flash_attention"` FlashAttention bool `yaml:"flash_attention"`
NoKVOffloading bool `yaml:"no_kv_offloading"` NoKVOffloading bool `yaml:"no_kv_offloading"`
@ -171,9 +171,9 @@ type LLMConfig struct {
// LimitMMPerPrompt is a struct that holds the configuration for the limit-mm-per-prompt config in vLLM // LimitMMPerPrompt is a struct that holds the configuration for the limit-mm-per-prompt config in vLLM
type LimitMMPerPrompt struct { type LimitMMPerPrompt struct {
LimitImagePerPrompt int `yaml:"image"` LimitImagePerPrompt int `yaml:"image"`
LimitVideoPerPrompt int `yaml:"video"` LimitVideoPerPrompt int `yaml:"video"`
LimitAudioPerPrompt int `yaml:"audio"` LimitAudioPerPrompt int `yaml:"audio"`
} }
// AutoGPTQ is a struct that holds the configuration specific to the AutoGPTQ backend // AutoGPTQ is a struct that holds the configuration specific to the AutoGPTQ backend
@ -213,6 +213,8 @@ type TemplateConfig struct {
Multimodal string `yaml:"multimodal"` Multimodal string `yaml:"multimodal"`
JinjaTemplate bool `yaml:"jinja_template"` JinjaTemplate bool `yaml:"jinja_template"`
ReplyPrefix string `yaml:"reply_prefix"`
} }
func (c *BackendConfig) UnmarshalYAML(value *yaml.Node) error { func (c *BackendConfig) UnmarshalYAML(value *yaml.Node) error {