feat(functions): support models with no grammar, add tests (#2068)

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto 2024-04-18 22:43:12 +02:00 committed by GitHub
parent 13012cfa70
commit bbea62b907
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 255 additions and 119 deletions

View file

@ -11,9 +11,8 @@ import (
"github.com/go-skynet/LocalAI/core/backend"
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/go-skynet/LocalAI/pkg/grammar"
"github.com/go-skynet/LocalAI/pkg/functions"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/utils"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/rs/zerolog/log"
@ -68,8 +67,8 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
return true
})
results := parseFunctionCall(result, config.FunctionsConfig.ParallelCalls)
noActionToRun := len(results) > 0 && results[0].name == noAction
results := functions.ParseFunctionCall(result, config.FunctionsConfig)
noActionToRun := len(results) > 0 && results[0].Name == noAction || len(results) == 0
switch {
case noActionToRun:
@ -82,7 +81,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
}
responses <- initialMessage
result, err := handleQuestion(config, req, ml, startupOptions, results[0].arguments, prompt)
result, err := handleQuestion(config, req, ml, startupOptions, results, prompt)
if err != nil {
log.Error().Err(err).Msg("error handling question")
return
@ -105,7 +104,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
default:
for i, ss := range results {
name, args := ss.name, ss.arguments
name, args := ss.Name, ss.Arguments
initialMessage := schema.OpenAIResponse{
ID: id,
@ -156,8 +155,6 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
}
return func(c *fiber.Ctx) error {
processFunctions := false
funcs := grammar.Functions{}
modelFile, input, err := readRequest(c, ml, startupOptions, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
@ -169,6 +166,9 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
}
log.Debug().Msgf("Configuration read: %+v", config)
funcs := input.Functions
shouldUseFn := len(input.Functions) > 0 && config.ShouldUseFunctions()
// Allow the user to set custom actions via config file
// to be "embedded" in each model
noActionName := "answer"
@ -182,18 +182,18 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
}
if input.ResponseFormat.Type == "json_object" {
input.Grammar = grammar.JSONBNF
input.Grammar = functions.JSONBNF
}
config.Grammar = input.Grammar
// process functions if we have any defined or if we have a function call string
if len(input.Functions) > 0 && config.ShouldUseFunctions() {
if shouldUseFn {
log.Debug().Msgf("Response needs to process functions")
}
processFunctions = true
noActionGrammar := grammar.Function{
switch {
case !config.FunctionsConfig.NoGrammar && shouldUseFn:
noActionGrammar := functions.Function{
Name: noActionName,
Description: noActionDescription,
Parameters: map[string]interface{}{
@ -206,7 +206,6 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
}
// Append the no action function
funcs = append(funcs, input.Functions...)
if !config.FunctionsConfig.DisableNoAction {
funcs = append(funcs, noActionGrammar)
}
@ -219,10 +218,17 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
// Update input grammar
jsStruct := funcs.ToJSONStructure()
config.Grammar = jsStruct.Grammar("", config.FunctionsConfig.ParallelCalls)
} else if input.JSONFunctionGrammarObject != nil {
case input.JSONFunctionGrammarObject != nil:
config.Grammar = input.JSONFunctionGrammarObject.Grammar("", config.FunctionsConfig.ParallelCalls)
default:
// Force picking one of the functions by the request
if config.FunctionToCall() != "" {
funcs = funcs.Select(config.FunctionToCall())
}
}
// process functions if we have any defined or if we have a function call string
// functions are not supported in stream mode (yet?)
toStream := input.Stream
@ -232,8 +238,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
// If we are using the tokenizer template, we don't need to process the messages
// unless we are processing functions
if !config.TemplateConfig.UseTokenizerTemplate || processFunctions {
if !config.TemplateConfig.UseTokenizerTemplate || shouldUseFn {
suppressConfigSystemPrompt := false
mess := []string{}
for messageIndex, i := range input.Messages {
@ -346,11 +351,11 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
templateFile = config.Model
}
if config.TemplateConfig.Chat != "" && !processFunctions {
if config.TemplateConfig.Chat != "" && !shouldUseFn {
templateFile = config.TemplateConfig.Chat
}
if config.TemplateConfig.Functions != "" && processFunctions {
if config.TemplateConfig.Functions != "" && shouldUseFn {
templateFile = config.TemplateConfig.Functions
}
@ -370,7 +375,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
}
log.Debug().Msgf("Prompt (after templating): %s", predInput)
if processFunctions {
if shouldUseFn && config.Grammar != "" {
log.Debug().Msgf("Grammar: %+v", config.Grammar)
}
}
@ -388,7 +393,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
responses := make(chan schema.OpenAIResponse)
if !processFunctions {
if !shouldUseFn {
go process(predInput, input, config, ml, responses)
} else {
go processTools(noActionName, predInput, input, config, ml, responses)
@ -446,18 +451,18 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
// no streaming mode
default:
result, tokenUsage, err := ComputeChoices(input, predInput, config, startupOptions, ml, func(s string, c *[]schema.Choice) {
if !processFunctions {
if !shouldUseFn {
// no function is called, just reply and use stop as finish reason
*c = append(*c, schema.Choice{FinishReason: "stop", Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}})
return
}
results := parseFunctionCall(s, config.FunctionsConfig.ParallelCalls)
noActionsToRun := len(results) > 0 && results[0].name == noActionName
results := functions.ParseFunctionCall(s, config.FunctionsConfig)
noActionsToRun := len(results) > 0 && results[0].Name == noActionName || len(results) == 0
switch {
case noActionsToRun:
result, err := handleQuestion(config, input, ml, startupOptions, results[0].arguments, predInput)
result, err := handleQuestion(config, input, ml, startupOptions, results, predInput)
if err != nil {
log.Error().Err(err).Msg("error handling question")
return
@ -476,7 +481,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
}
for _, ss := range results {
name, args := ss.name, ss.arguments
name, args := ss.Name, ss.Arguments
if len(input.Tools) > 0 {
// If we are using tools, we condense the function calls into
// a single response choice with all the tools
@ -534,16 +539,20 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
// Return the prediction in the response body
return c.JSON(resp)
}
}
}
func handleQuestion(config *config.BackendConfig, input *schema.OpenAIRequest, ml *model.ModelLoader, o *config.ApplicationConfig, args, prompt string) (string, error) {
func handleQuestion(config *config.BackendConfig, input *schema.OpenAIRequest, ml *model.ModelLoader, o *config.ApplicationConfig, funcResults []functions.FuncCallResults, prompt string) (string, error) {
log.Debug().Msgf("nothing to do, computing a reply")
arg := ""
if len(funcResults) > 0 {
arg = funcResults[0].Arguments
}
// If there is a message that the LLM already sends as part of the JSON reply, use it
arguments := map[string]interface{}{}
json.Unmarshal([]byte(args), &arguments)
if err := json.Unmarshal([]byte(arg), &arguments); err != nil {
log.Debug().Msg("handleQuestion: function result did not contain a valid JSON object")
}
m, exists := arguments["message"]
if exists {
switch message := m.(type) {
@ -580,63 +589,3 @@ func handleQuestion(config *config.BackendConfig, input *schema.OpenAIRequest, m
}
return backend.Finetune(*config, prompt, prediction.Response), nil
}
type funcCallResults struct {
name string
arguments string
}
func parseFunctionCall(llmresult string, multipleResults bool) []funcCallResults {
results := []funcCallResults{}
// TODO: use generics to avoid this code duplication
if multipleResults {
ss := []map[string]interface{}{}
s := utils.EscapeNewLines(llmresult)
json.Unmarshal([]byte(s), &ss)
log.Debug().Msgf("Function return: %s %+v", s, ss)
for _, s := range ss {
func_name, ok := s["function"]
if !ok {
continue
}
args, ok := s["arguments"]
if !ok {
continue
}
d, _ := json.Marshal(args)
funcName, ok := func_name.(string)
if !ok {
continue
}
results = append(results, funcCallResults{name: funcName, arguments: string(d)})
}
} else {
// As we have to change the result before processing, we can't stream the answer token-by-token (yet?)
ss := map[string]interface{}{}
// This prevent newlines to break JSON parsing for clients
s := utils.EscapeNewLines(llmresult)
json.Unmarshal([]byte(s), &ss)
log.Debug().Msgf("Function return: %s %+v", s, ss)
// The grammar defines the function name as "function", while OpenAI returns "name"
func_name, ok := ss["function"]
if !ok {
return results
}
// Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object
args, ok := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix)
if !ok {
return results
}
d, _ := json.Marshal(args)
funcName, ok := func_name.(string)
if !ok {
return results
}
results = append(results, funcCallResults{name: funcName, arguments: string(d)})
}
return results
}

View file

@ -12,7 +12,7 @@ import (
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/go-skynet/LocalAI/pkg/grammar"
"github.com/go-skynet/LocalAI/pkg/functions"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
@ -70,7 +70,7 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, a
}
if input.ResponseFormat.Type == "json_object" {
input.Grammar = grammar.JSONBNF
input.Grammar = functions.JSONBNF
}
config.Grammar = input.Grammar

View file

@ -12,7 +12,7 @@ import (
"github.com/go-skynet/LocalAI/core/config"
fiberContext "github.com/go-skynet/LocalAI/core/http/ctx"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/go-skynet/LocalAI/pkg/grammar"
"github.com/go-skynet/LocalAI/pkg/functions"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
@ -145,7 +145,7 @@ func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIReque
}
if input.ToolsChoice != nil {
var toolChoice grammar.Tool
var toolChoice functions.Tool
switch content := input.ToolsChoice.(type) {
case string: