mirror of
https://github.com/mudler/LocalAI.git
synced 2025-05-20 10:35:01 +00:00
feat(functions): support models with no grammar, add tests (#2068)
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
parent
13012cfa70
commit
bbea62b907
13 changed files with 255 additions and 119 deletions
|
@ -11,9 +11,8 @@ import (
|
|||
"github.com/go-skynet/LocalAI/core/backend"
|
||||
"github.com/go-skynet/LocalAI/core/config"
|
||||
"github.com/go-skynet/LocalAI/core/schema"
|
||||
"github.com/go-skynet/LocalAI/pkg/grammar"
|
||||
"github.com/go-skynet/LocalAI/pkg/functions"
|
||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/google/uuid"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
@ -68,8 +67,8 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
|
|||
return true
|
||||
})
|
||||
|
||||
results := parseFunctionCall(result, config.FunctionsConfig.ParallelCalls)
|
||||
noActionToRun := len(results) > 0 && results[0].name == noAction
|
||||
results := functions.ParseFunctionCall(result, config.FunctionsConfig)
|
||||
noActionToRun := len(results) > 0 && results[0].Name == noAction || len(results) == 0
|
||||
|
||||
switch {
|
||||
case noActionToRun:
|
||||
|
@ -82,7 +81,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
|
|||
}
|
||||
responses <- initialMessage
|
||||
|
||||
result, err := handleQuestion(config, req, ml, startupOptions, results[0].arguments, prompt)
|
||||
result, err := handleQuestion(config, req, ml, startupOptions, results, prompt)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("error handling question")
|
||||
return
|
||||
|
@ -105,7 +104,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
|
|||
|
||||
default:
|
||||
for i, ss := range results {
|
||||
name, args := ss.name, ss.arguments
|
||||
name, args := ss.Name, ss.Arguments
|
||||
|
||||
initialMessage := schema.OpenAIResponse{
|
||||
ID: id,
|
||||
|
@ -156,8 +155,6 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
|
|||
}
|
||||
|
||||
return func(c *fiber.Ctx) error {
|
||||
processFunctions := false
|
||||
funcs := grammar.Functions{}
|
||||
modelFile, input, err := readRequest(c, ml, startupOptions, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
|
@ -169,6 +166,9 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
|
|||
}
|
||||
log.Debug().Msgf("Configuration read: %+v", config)
|
||||
|
||||
funcs := input.Functions
|
||||
shouldUseFn := len(input.Functions) > 0 && config.ShouldUseFunctions()
|
||||
|
||||
// Allow the user to set custom actions via config file
|
||||
// to be "embedded" in each model
|
||||
noActionName := "answer"
|
||||
|
@ -182,18 +182,18 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
|
|||
}
|
||||
|
||||
if input.ResponseFormat.Type == "json_object" {
|
||||
input.Grammar = grammar.JSONBNF
|
||||
input.Grammar = functions.JSONBNF
|
||||
}
|
||||
|
||||
config.Grammar = input.Grammar
|
||||
|
||||
// process functions if we have any defined or if we have a function call string
|
||||
if len(input.Functions) > 0 && config.ShouldUseFunctions() {
|
||||
if shouldUseFn {
|
||||
log.Debug().Msgf("Response needs to process functions")
|
||||
}
|
||||
|
||||
processFunctions = true
|
||||
|
||||
noActionGrammar := grammar.Function{
|
||||
switch {
|
||||
case !config.FunctionsConfig.NoGrammar && shouldUseFn:
|
||||
noActionGrammar := functions.Function{
|
||||
Name: noActionName,
|
||||
Description: noActionDescription,
|
||||
Parameters: map[string]interface{}{
|
||||
|
@ -206,7 +206,6 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
|
|||
}
|
||||
|
||||
// Append the no action function
|
||||
funcs = append(funcs, input.Functions...)
|
||||
if !config.FunctionsConfig.DisableNoAction {
|
||||
funcs = append(funcs, noActionGrammar)
|
||||
}
|
||||
|
@ -219,10 +218,17 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
|
|||
// Update input grammar
|
||||
jsStruct := funcs.ToJSONStructure()
|
||||
config.Grammar = jsStruct.Grammar("", config.FunctionsConfig.ParallelCalls)
|
||||
} else if input.JSONFunctionGrammarObject != nil {
|
||||
case input.JSONFunctionGrammarObject != nil:
|
||||
config.Grammar = input.JSONFunctionGrammarObject.Grammar("", config.FunctionsConfig.ParallelCalls)
|
||||
default:
|
||||
// Force picking one of the functions by the request
|
||||
if config.FunctionToCall() != "" {
|
||||
funcs = funcs.Select(config.FunctionToCall())
|
||||
}
|
||||
}
|
||||
|
||||
// process functions if we have any defined or if we have a function call string
|
||||
|
||||
// functions are not supported in stream mode (yet?)
|
||||
toStream := input.Stream
|
||||
|
||||
|
@ -232,8 +238,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
|
|||
|
||||
// If we are using the tokenizer template, we don't need to process the messages
|
||||
// unless we are processing functions
|
||||
if !config.TemplateConfig.UseTokenizerTemplate || processFunctions {
|
||||
|
||||
if !config.TemplateConfig.UseTokenizerTemplate || shouldUseFn {
|
||||
suppressConfigSystemPrompt := false
|
||||
mess := []string{}
|
||||
for messageIndex, i := range input.Messages {
|
||||
|
@ -346,11 +351,11 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
|
|||
templateFile = config.Model
|
||||
}
|
||||
|
||||
if config.TemplateConfig.Chat != "" && !processFunctions {
|
||||
if config.TemplateConfig.Chat != "" && !shouldUseFn {
|
||||
templateFile = config.TemplateConfig.Chat
|
||||
}
|
||||
|
||||
if config.TemplateConfig.Functions != "" && processFunctions {
|
||||
if config.TemplateConfig.Functions != "" && shouldUseFn {
|
||||
templateFile = config.TemplateConfig.Functions
|
||||
}
|
||||
|
||||
|
@ -370,7 +375,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
|
|||
}
|
||||
|
||||
log.Debug().Msgf("Prompt (after templating): %s", predInput)
|
||||
if processFunctions {
|
||||
if shouldUseFn && config.Grammar != "" {
|
||||
log.Debug().Msgf("Grammar: %+v", config.Grammar)
|
||||
}
|
||||
}
|
||||
|
@ -388,7 +393,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
|
|||
|
||||
responses := make(chan schema.OpenAIResponse)
|
||||
|
||||
if !processFunctions {
|
||||
if !shouldUseFn {
|
||||
go process(predInput, input, config, ml, responses)
|
||||
} else {
|
||||
go processTools(noActionName, predInput, input, config, ml, responses)
|
||||
|
@ -446,18 +451,18 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
|
|||
// no streaming mode
|
||||
default:
|
||||
result, tokenUsage, err := ComputeChoices(input, predInput, config, startupOptions, ml, func(s string, c *[]schema.Choice) {
|
||||
if !processFunctions {
|
||||
if !shouldUseFn {
|
||||
// no function is called, just reply and use stop as finish reason
|
||||
*c = append(*c, schema.Choice{FinishReason: "stop", Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}})
|
||||
return
|
||||
}
|
||||
|
||||
results := parseFunctionCall(s, config.FunctionsConfig.ParallelCalls)
|
||||
noActionsToRun := len(results) > 0 && results[0].name == noActionName
|
||||
results := functions.ParseFunctionCall(s, config.FunctionsConfig)
|
||||
noActionsToRun := len(results) > 0 && results[0].Name == noActionName || len(results) == 0
|
||||
|
||||
switch {
|
||||
case noActionsToRun:
|
||||
result, err := handleQuestion(config, input, ml, startupOptions, results[0].arguments, predInput)
|
||||
result, err := handleQuestion(config, input, ml, startupOptions, results, predInput)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("error handling question")
|
||||
return
|
||||
|
@ -476,7 +481,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
|
|||
}
|
||||
|
||||
for _, ss := range results {
|
||||
name, args := ss.name, ss.arguments
|
||||
name, args := ss.Name, ss.Arguments
|
||||
if len(input.Tools) > 0 {
|
||||
// If we are using tools, we condense the function calls into
|
||||
// a single response choice with all the tools
|
||||
|
@ -534,16 +539,20 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
|
|||
// Return the prediction in the response body
|
||||
return c.JSON(resp)
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
func handleQuestion(config *config.BackendConfig, input *schema.OpenAIRequest, ml *model.ModelLoader, o *config.ApplicationConfig, args, prompt string) (string, error) {
|
||||
func handleQuestion(config *config.BackendConfig, input *schema.OpenAIRequest, ml *model.ModelLoader, o *config.ApplicationConfig, funcResults []functions.FuncCallResults, prompt string) (string, error) {
|
||||
log.Debug().Msgf("nothing to do, computing a reply")
|
||||
|
||||
arg := ""
|
||||
if len(funcResults) > 0 {
|
||||
arg = funcResults[0].Arguments
|
||||
}
|
||||
// If there is a message that the LLM already sends as part of the JSON reply, use it
|
||||
arguments := map[string]interface{}{}
|
||||
json.Unmarshal([]byte(args), &arguments)
|
||||
if err := json.Unmarshal([]byte(arg), &arguments); err != nil {
|
||||
log.Debug().Msg("handleQuestion: function result did not contain a valid JSON object")
|
||||
}
|
||||
m, exists := arguments["message"]
|
||||
if exists {
|
||||
switch message := m.(type) {
|
||||
|
@ -580,63 +589,3 @@ func handleQuestion(config *config.BackendConfig, input *schema.OpenAIRequest, m
|
|||
}
|
||||
return backend.Finetune(*config, prompt, prediction.Response), nil
|
||||
}
|
||||
|
||||
type funcCallResults struct {
|
||||
name string
|
||||
arguments string
|
||||
}
|
||||
|
||||
func parseFunctionCall(llmresult string, multipleResults bool) []funcCallResults {
|
||||
results := []funcCallResults{}
|
||||
|
||||
// TODO: use generics to avoid this code duplication
|
||||
if multipleResults {
|
||||
ss := []map[string]interface{}{}
|
||||
s := utils.EscapeNewLines(llmresult)
|
||||
json.Unmarshal([]byte(s), &ss)
|
||||
log.Debug().Msgf("Function return: %s %+v", s, ss)
|
||||
|
||||
for _, s := range ss {
|
||||
func_name, ok := s["function"]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
args, ok := s["arguments"]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
d, _ := json.Marshal(args)
|
||||
funcName, ok := func_name.(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
results = append(results, funcCallResults{name: funcName, arguments: string(d)})
|
||||
}
|
||||
} else {
|
||||
// As we have to change the result before processing, we can't stream the answer token-by-token (yet?)
|
||||
ss := map[string]interface{}{}
|
||||
// This prevent newlines to break JSON parsing for clients
|
||||
s := utils.EscapeNewLines(llmresult)
|
||||
json.Unmarshal([]byte(s), &ss)
|
||||
log.Debug().Msgf("Function return: %s %+v", s, ss)
|
||||
|
||||
// The grammar defines the function name as "function", while OpenAI returns "name"
|
||||
func_name, ok := ss["function"]
|
||||
if !ok {
|
||||
return results
|
||||
}
|
||||
// Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object
|
||||
args, ok := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix)
|
||||
if !ok {
|
||||
return results
|
||||
}
|
||||
d, _ := json.Marshal(args)
|
||||
funcName, ok := func_name.(string)
|
||||
if !ok {
|
||||
return results
|
||||
}
|
||||
results = append(results, funcCallResults{name: funcName, arguments: string(d)})
|
||||
}
|
||||
|
||||
return results
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue