mirror of
https://github.com/mudler/LocalAI.git
synced 2025-05-20 10:35:01 +00:00
feat: allow to specify a reply prefix (#4931)
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
parent
ff85f01459
commit
a7b4001b75
2 changed files with 38 additions and 26 deletions
|
@ -116,6 +116,11 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
|
||||||
}
|
}
|
||||||
|
|
||||||
if tokenCallback != nil {
|
if tokenCallback != nil {
|
||||||
|
|
||||||
|
if c.TemplateConfig.ReplyPrefix != "" {
|
||||||
|
tokenCallback(c.TemplateConfig.ReplyPrefix, tokenUsage)
|
||||||
|
}
|
||||||
|
|
||||||
ss := ""
|
ss := ""
|
||||||
|
|
||||||
var partialRune []byte
|
var partialRune []byte
|
||||||
|
@ -165,8 +170,13 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
|
||||||
tokenUsage.TimingTokenGeneration = reply.TimingTokenGeneration
|
tokenUsage.TimingTokenGeneration = reply.TimingTokenGeneration
|
||||||
tokenUsage.TimingPromptProcessing = reply.TimingPromptProcessing
|
tokenUsage.TimingPromptProcessing = reply.TimingPromptProcessing
|
||||||
|
|
||||||
|
response := string(reply.Message)
|
||||||
|
if c.TemplateConfig.ReplyPrefix != "" {
|
||||||
|
response = c.TemplateConfig.ReplyPrefix + response
|
||||||
|
}
|
||||||
|
|
||||||
return LLMResponse{
|
return LLMResponse{
|
||||||
Response: string(reply.Message),
|
Response: response,
|
||||||
Usage: tokenUsage,
|
Usage: tokenUsage,
|
||||||
}, err
|
}, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -130,28 +130,28 @@ type LLMConfig struct {
|
||||||
TrimSpace []string `yaml:"trimspace"`
|
TrimSpace []string `yaml:"trimspace"`
|
||||||
TrimSuffix []string `yaml:"trimsuffix"`
|
TrimSuffix []string `yaml:"trimsuffix"`
|
||||||
|
|
||||||
ContextSize *int `yaml:"context_size"`
|
ContextSize *int `yaml:"context_size"`
|
||||||
NUMA bool `yaml:"numa"`
|
NUMA bool `yaml:"numa"`
|
||||||
LoraAdapter string `yaml:"lora_adapter"`
|
LoraAdapter string `yaml:"lora_adapter"`
|
||||||
LoraBase string `yaml:"lora_base"`
|
LoraBase string `yaml:"lora_base"`
|
||||||
LoraAdapters []string `yaml:"lora_adapters"`
|
LoraAdapters []string `yaml:"lora_adapters"`
|
||||||
LoraScales []float32 `yaml:"lora_scales"`
|
LoraScales []float32 `yaml:"lora_scales"`
|
||||||
LoraScale float32 `yaml:"lora_scale"`
|
LoraScale float32 `yaml:"lora_scale"`
|
||||||
NoMulMatQ bool `yaml:"no_mulmatq"`
|
NoMulMatQ bool `yaml:"no_mulmatq"`
|
||||||
DraftModel string `yaml:"draft_model"`
|
DraftModel string `yaml:"draft_model"`
|
||||||
NDraft int32 `yaml:"n_draft"`
|
NDraft int32 `yaml:"n_draft"`
|
||||||
Quantization string `yaml:"quantization"`
|
Quantization string `yaml:"quantization"`
|
||||||
LoadFormat string `yaml:"load_format"`
|
LoadFormat string `yaml:"load_format"`
|
||||||
GPUMemoryUtilization float32 `yaml:"gpu_memory_utilization"` // vLLM
|
GPUMemoryUtilization float32 `yaml:"gpu_memory_utilization"` // vLLM
|
||||||
TrustRemoteCode bool `yaml:"trust_remote_code"` // vLLM
|
TrustRemoteCode bool `yaml:"trust_remote_code"` // vLLM
|
||||||
EnforceEager bool `yaml:"enforce_eager"` // vLLM
|
EnforceEager bool `yaml:"enforce_eager"` // vLLM
|
||||||
SwapSpace int `yaml:"swap_space"` // vLLM
|
SwapSpace int `yaml:"swap_space"` // vLLM
|
||||||
MaxModelLen int `yaml:"max_model_len"` // vLLM
|
MaxModelLen int `yaml:"max_model_len"` // vLLM
|
||||||
TensorParallelSize int `yaml:"tensor_parallel_size"` // vLLM
|
TensorParallelSize int `yaml:"tensor_parallel_size"` // vLLM
|
||||||
DisableLogStatus bool `yaml:"disable_log_stats"` // vLLM
|
DisableLogStatus bool `yaml:"disable_log_stats"` // vLLM
|
||||||
DType string `yaml:"dtype"` // vLLM
|
DType string `yaml:"dtype"` // vLLM
|
||||||
LimitMMPerPrompt LimitMMPerPrompt `yaml:"limit_mm_per_prompt"` // vLLM
|
LimitMMPerPrompt LimitMMPerPrompt `yaml:"limit_mm_per_prompt"` // vLLM
|
||||||
MMProj string `yaml:"mmproj"`
|
MMProj string `yaml:"mmproj"`
|
||||||
|
|
||||||
FlashAttention bool `yaml:"flash_attention"`
|
FlashAttention bool `yaml:"flash_attention"`
|
||||||
NoKVOffloading bool `yaml:"no_kv_offloading"`
|
NoKVOffloading bool `yaml:"no_kv_offloading"`
|
||||||
|
@ -171,9 +171,9 @@ type LLMConfig struct {
|
||||||
|
|
||||||
// LimitMMPerPrompt is a struct that holds the configuration for the limit-mm-per-prompt config in vLLM
|
// LimitMMPerPrompt is a struct that holds the configuration for the limit-mm-per-prompt config in vLLM
|
||||||
type LimitMMPerPrompt struct {
|
type LimitMMPerPrompt struct {
|
||||||
LimitImagePerPrompt int `yaml:"image"`
|
LimitImagePerPrompt int `yaml:"image"`
|
||||||
LimitVideoPerPrompt int `yaml:"video"`
|
LimitVideoPerPrompt int `yaml:"video"`
|
||||||
LimitAudioPerPrompt int `yaml:"audio"`
|
LimitAudioPerPrompt int `yaml:"audio"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// AutoGPTQ is a struct that holds the configuration specific to the AutoGPTQ backend
|
// AutoGPTQ is a struct that holds the configuration specific to the AutoGPTQ backend
|
||||||
|
@ -213,6 +213,8 @@ type TemplateConfig struct {
|
||||||
Multimodal string `yaml:"multimodal"`
|
Multimodal string `yaml:"multimodal"`
|
||||||
|
|
||||||
JinjaTemplate bool `yaml:"jinja_template"`
|
JinjaTemplate bool `yaml:"jinja_template"`
|
||||||
|
|
||||||
|
ReplyPrefix string `yaml:"reply_prefix"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *BackendConfig) UnmarshalYAML(value *yaml.Node) error {
|
func (c *BackendConfig) UnmarshalYAML(value *yaml.Node) error {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue