feat: use tokenizer.apply_chat_template() in vLLM (#1990)

Use tokenizer.apply_chat_template() in vLLM

Signed-off-by: Ludovic LEROUX <ludovic@inpher.io>
This commit is contained in:
Ludovic Leroux 2024-04-11 13:20:22 -04:00 committed by GitHub
parent cbda06fb96
commit 12c0d9443e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
34 changed files with 3088 additions and 989 deletions

View file

@ -230,112 +230,154 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
var predInput string
suppressConfigSystemPrompt := false
mess := []string{}
for messageIndex, i := range input.Messages {
var content string
role := i.Role
// If we are using the tokenizer template, we don't need to process the messages
// unless we are processing functions
if !config.TemplateConfig.UseTokenizerTemplate || processFunctions {
// if function call, we might want to customize the role so we can display better that the "assistant called a json action"
// if an "assistant_function_call" role is defined, we use it, otherwise we use the role that is passed by in the request
if (i.FunctionCall != nil || i.ToolCalls != nil) && i.Role == "assistant" {
roleFn := "assistant_function_call"
r := config.Roles[roleFn]
if r != "" {
role = roleFn
}
}
r := config.Roles[role]
contentExists := i.Content != nil && i.StringContent != ""
suppressConfigSystemPrompt := false
mess := []string{}
for messageIndex, i := range input.Messages {
var content string
role := i.Role
fcall := i.FunctionCall
if len(i.ToolCalls) > 0 {
fcall = i.ToolCalls
}
// First attempt to populate content via a chat message specific template
if config.TemplateConfig.ChatMessage != "" {
chatMessageData := model.ChatMessageTemplateData{
SystemPrompt: config.SystemPrompt,
Role: r,
RoleName: role,
Content: i.StringContent,
FunctionCall: fcall,
FunctionName: i.Name,
LastMessage: messageIndex == (len(input.Messages) - 1),
Function: config.Grammar != "" && (messageIndex == (len(input.Messages) - 1)),
MessageIndex: messageIndex,
}
templatedChatMessage, err := ml.EvaluateTemplateForChatMessage(config.TemplateConfig.ChatMessage, chatMessageData)
if err != nil {
log.Error().Err(err).Interface("message", chatMessageData).Str("template", config.TemplateConfig.ChatMessage).Msg("error processing message with template, skipping")
} else {
if templatedChatMessage == "" {
log.Warn().Msgf("template \"%s\" produced blank output for %+v. Skipping!", config.TemplateConfig.ChatMessage, chatMessageData)
continue // TODO: This continue is here intentionally to skip over the line `mess = append(mess, content)` below, and to prevent the sprintf
// if function call, we might want to customize the role so we can display better that the "assistant called a json action"
// if an "assistant_function_call" role is defined, we use it, otherwise we use the role that is passed by in the request
if (i.FunctionCall != nil || i.ToolCalls != nil) && i.Role == "assistant" {
roleFn := "assistant_function_call"
r := config.Roles[roleFn]
if r != "" {
role = roleFn
}
log.Debug().Msgf("templated message for chat: %s", templatedChatMessage)
content = templatedChatMessage
}
}
r := config.Roles[role]
contentExists := i.Content != nil && i.StringContent != ""
marshalAnyRole := func(f any) {
j, err := json.Marshal(f)
if err == nil {
if contentExists {
content += "\n" + fmt.Sprint(r, " ", string(j))
fcall := i.FunctionCall
if len(i.ToolCalls) > 0 {
fcall = i.ToolCalls
}
// First attempt to populate content via a chat message specific template
if config.TemplateConfig.ChatMessage != "" {
chatMessageData := model.ChatMessageTemplateData{
SystemPrompt: config.SystemPrompt,
Role: r,
RoleName: role,
Content: i.StringContent,
FunctionCall: fcall,
FunctionName: i.Name,
LastMessage: messageIndex == (len(input.Messages) - 1),
Function: config.Grammar != "" && (messageIndex == (len(input.Messages) - 1)),
MessageIndex: messageIndex,
}
templatedChatMessage, err := ml.EvaluateTemplateForChatMessage(config.TemplateConfig.ChatMessage, chatMessageData)
if err != nil {
log.Error().Err(err).Interface("message", chatMessageData).Str("template", config.TemplateConfig.ChatMessage).Msg("error processing message with template, skipping")
} else {
content = fmt.Sprint(r, " ", string(j))
if templatedChatMessage == "" {
log.Warn().Msgf("template \"%s\" produced blank output for %+v. Skipping!", config.TemplateConfig.ChatMessage, chatMessageData)
continue // TODO: This continue is here intentionally to skip over the line `mess = append(mess, content)` below, and to prevent the sprintf
}
log.Debug().Msgf("templated message for chat: %s", templatedChatMessage)
content = templatedChatMessage
}
}
}
marshalAny := func(f any) {
j, err := json.Marshal(f)
if err == nil {
if contentExists {
content += "\n" + string(j)
marshalAnyRole := func(f any) {
j, err := json.Marshal(f)
if err == nil {
if contentExists {
content += "\n" + fmt.Sprint(r, " ", string(j))
} else {
content = fmt.Sprint(r, " ", string(j))
}
}
}
marshalAny := func(f any) {
j, err := json.Marshal(f)
if err == nil {
if contentExists {
content += "\n" + string(j)
} else {
content = string(j)
}
}
}
// If this model doesn't have such a template, or if that template fails to return a value, template at the message level.
if content == "" {
if r != "" {
if contentExists {
content = fmt.Sprint(r, i.StringContent)
}
if i.FunctionCall != nil {
marshalAnyRole(i.FunctionCall)
}
if i.ToolCalls != nil {
marshalAnyRole(i.ToolCalls)
}
} else {
content = string(j)
if contentExists {
content = fmt.Sprint(i.StringContent)
}
if i.FunctionCall != nil {
marshalAny(i.FunctionCall)
}
if i.ToolCalls != nil {
marshalAny(i.ToolCalls)
}
}
// Special Handling: System. We care if it was printed at all, not the r branch, so check seperately
if contentExists && role == "system" {
suppressConfigSystemPrompt = true
}
}
}
// If this model doesn't have such a template, or if that template fails to return a value, template at the message level.
if content == "" {
if r != "" {
if contentExists {
content = fmt.Sprint(r, i.StringContent)
}
if i.FunctionCall != nil {
marshalAnyRole(i.FunctionCall)
}
if i.ToolCalls != nil {
marshalAnyRole(i.ToolCalls)
}
mess = append(mess, content)
}
predInput = strings.Join(mess, "\n")
log.Debug().Msgf("Prompt (before templating): %s", predInput)
templateFile := ""
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
if ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
templateFile = config.Model
}
if config.TemplateConfig.Chat != "" && !processFunctions {
templateFile = config.TemplateConfig.Chat
}
if config.TemplateConfig.Functions != "" && processFunctions {
templateFile = config.TemplateConfig.Functions
}
if templateFile != "" {
templatedInput, err := ml.EvaluateTemplateForPrompt(model.ChatPromptTemplate, templateFile, model.PromptTemplateData{
SystemPrompt: config.SystemPrompt,
SuppressSystemPrompt: suppressConfigSystemPrompt,
Input: predInput,
Functions: funcs,
})
if err == nil {
predInput = templatedInput
log.Debug().Msgf("Template found, input modified to: %s", predInput)
} else {
if contentExists {
content = fmt.Sprint(i.StringContent)
}
if i.FunctionCall != nil {
marshalAny(i.FunctionCall)
}
if i.ToolCalls != nil {
marshalAny(i.ToolCalls)
}
}
// Special Handling: System. We care if it was printed at all, not the r branch, so check seperately
if contentExists && role == "system" {
suppressConfigSystemPrompt = true
log.Debug().Msgf("Template failed loading: %s", err.Error())
}
}
mess = append(mess, content)
log.Debug().Msgf("Prompt (after templating): %s", predInput)
if processFunctions {
log.Debug().Msgf("Grammar: %+v", config.Grammar)
}
}
predInput = strings.Join(mess, "\n")
log.Debug().Msgf("Prompt (before templating): %s", predInput)
switch {
case toStream:
if toStream {
log.Debug().Msgf("Stream request received")
c.Context().SetContentType("text/event-stream")
//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
@ -343,45 +385,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
c.Set("Cache-Control", "no-cache")
c.Set("Connection", "keep-alive")
c.Set("Transfer-Encoding", "chunked")
}
templateFile := ""
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
if ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
templateFile = config.Model
}
if config.TemplateConfig.Chat != "" && !processFunctions {
templateFile = config.TemplateConfig.Chat
}
if config.TemplateConfig.Functions != "" && processFunctions {
templateFile = config.TemplateConfig.Functions
}
if templateFile != "" {
templatedInput, err := ml.EvaluateTemplateForPrompt(model.ChatPromptTemplate, templateFile, model.PromptTemplateData{
SystemPrompt: config.SystemPrompt,
SuppressSystemPrompt: suppressConfigSystemPrompt,
Input: predInput,
Functions: funcs,
})
if err == nil {
predInput = templatedInput
log.Debug().Msgf("Template found, input modified to: %s", predInput)
} else {
log.Debug().Msgf("Template failed loading: %s", err.Error())
}
}
log.Debug().Msgf("Prompt (after templating): %s", predInput)
if processFunctions {
log.Debug().Msgf("Grammar: %+v", config.Grammar)
}
switch {
case toStream:
responses := make(chan schema.OpenAIResponse)
if !processFunctions {
@ -563,7 +567,7 @@ func handleQuestion(config *config.BackendConfig, input *schema.OpenAIRequest, m
images = append(images, m.StringImages...)
}
predFunc, err := backend.ModelInference(input.Context, prompt, images, ml, *config, o, nil)
predFunc, err := backend.ModelInference(input.Context, prompt, input.Messages, images, ml, *config, o, nil)
if err != nil {
log.Error().Err(err).Msg("model inference failed")
return "", err