feat(llama2): add template for chat messages (#782)

Co-authored-by: Aman Karmani <aman@tmm1.net>

Lays some of the groundwork for LLAMA2 compatibility as well as other future models with complex prompting schemes.

Started small refactoring in pkg/model/loader.go regarding template loading. Currently still a part of ModelLoader, but should be easy to add template loading for situations other than overall prompt templates and the new chat-specific per-message templates
Adds support for new chat-endpoint-specific, per-message templates as an alternative to the existing Role: XYZ sprintf method.
Includes a temporary prompt template as an example, since I have a few questions before we merge in the model-gallery side changes (see )
Minor debug logging changes.
This commit is contained in:
Dave 2023-07-22 11:31:39 -04:00 committed by GitHub
parent 5ee186b8e5
commit c6bf67f446
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 237 additions and 123 deletions

View file

@ -43,12 +43,12 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
return func(c *fiber.Ctx) error {
processFunctions := false
funcs := grammar.Functions{}
model, input, err := readInput(c, o.Loader, true)
modelFile, input, err := readInput(c, o.Loader, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
config, input, err := readConfig(model, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16)
config, input, err := readConfig(modelFile, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
@ -110,9 +110,10 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
var predInput string
mess := []string{}
for _, i := range input.Messages {
for messageIndex, i := range input.Messages {
var content string
role := i.Role
// if function call, we might want to customize the role so we can display better that the "assistant called a json action"
// if an "assistant_function_call" role is defined, we use it, otherwise we use the role that is passed by in the request
if i.FunctionCall != nil && i.Role == "assistant" {
@ -124,31 +125,55 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
}
r := config.Roles[role]
contentExists := i.Content != nil && *i.Content != ""
if r != "" {
if contentExists {
content = fmt.Sprint(r, " ", *i.Content)
// First attempt to populate content via a chat message specific template
if config.TemplateConfig.ChatMessage != "" {
chatMessageData := model.ChatMessageTemplateData{
SystemPrompt: config.SystemPrompt,
Role: r,
RoleName: role,
Content: *i.Content,
MessageIndex: messageIndex,
}
if i.FunctionCall != nil {
j, err := json.Marshal(i.FunctionCall)
if err == nil {
if contentExists {
content += "\n" + fmt.Sprint(r, " ", string(j))
} else {
content = fmt.Sprint(r, " ", string(j))
templatedChatMessage, err := o.Loader.EvaluateTemplateForChatMessage(config.TemplateConfig.ChatMessage, chatMessageData)
if err != nil {
log.Error().Msgf("error processing message %+v using template \"%s\": %v. Skipping!", chatMessageData, config.TemplateConfig.ChatMessage, err)
} else {
if templatedChatMessage == "" {
log.Warn().Msgf("template \"%s\" produced blank output for %+v. Skipping!", config.TemplateConfig.ChatMessage, chatMessageData)
continue // TODO: This continue is here intentionally to skip over the line `mess = append(mess, content)` below, and to prevent the sprintf
}
log.Debug().Msgf("templated message for chat: %s", templatedChatMessage)
content = templatedChatMessage
}
}
// If this model doesn't have such a template, or if
if content == "" {
if r != "" {
if contentExists {
content = fmt.Sprint(r, " ", *i.Content)
}
if i.FunctionCall != nil {
j, err := json.Marshal(i.FunctionCall)
if err == nil {
if contentExists {
content += "\n" + fmt.Sprint(r, " ", string(j))
} else {
content = fmt.Sprint(r, " ", string(j))
}
}
}
}
} else {
if contentExists {
content = fmt.Sprint(*i.Content)
}
if i.FunctionCall != nil {
j, err := json.Marshal(i.FunctionCall)
if err == nil {
if contentExists {
content += "\n" + string(j)
} else {
content = string(j)
} else {
if contentExists {
content = fmt.Sprint(*i.Content)
}
if i.FunctionCall != nil {
j, err := json.Marshal(i.FunctionCall)
if err == nil {
if contentExists {
content += "\n" + string(j)
} else {
content = string(j)
}
}
}
}
@ -181,10 +206,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
}
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
templatedInput, err := o.Loader.TemplatePrefix(templateFile, struct {
Input string
Functions []grammar.Function
}{
templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.ChatPromptTemplate, templateFile, model.PromptTemplateData{
Input: predInput,
Functions: funcs,
})