feat(functions): simplify parsing, read functions as list (#2340)

Signed-off-by: mudler <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto 2024-05-18 09:35:28 +02:00 committed by GitHub
parent 9ab8f8f5e0
commit 02f1b477df
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 56 additions and 132 deletions

View file

@ -2,7 +2,6 @@ package functions
import ( import (
"encoding/json" "encoding/json"
"fmt"
"regexp" "regexp"
"strings" "strings"
@ -68,9 +67,6 @@ func ParseFunctionCall(llmresult string, functionConfig FunctionsConfig) []FuncC
log.Debug().Msgf("LLM result(processed): %s", llmresult) log.Debug().Msgf("LLM result(processed): %s", llmresult)
multipleResults := functionConfig.ParallelCalls
useGrammars := !functionConfig.NoGrammar
functionNameKey := "function" functionNameKey := "function"
if functionConfig.FunctionName { if functionConfig.FunctionName {
functionNameKey = "name" functionNameKey = "name"
@ -78,38 +74,51 @@ func ParseFunctionCall(llmresult string, functionConfig FunctionsConfig) []FuncC
results := []FuncCallResults{} results := []FuncCallResults{}
returnResult := func(s string) (name, arguments string, e error) { returnResult := func(s string) (result []FuncCallResults, e error) {
// As we have to change the result before processing, we can't stream the answer token-by-token (yet?) // As we have to change the result before processing, we can't stream the answer token-by-token (yet?)
ss := map[string]interface{}{} var ss []map[string]interface{}
// This prevent newlines to break JSON parsing for clients result = make([]FuncCallResults, 0)
s = utils.EscapeNewLines(s) s = utils.EscapeNewLines(s)
err := json.Unmarshal([]byte(s), &ss) err := json.Unmarshal([]byte(s), &ss)
if err != nil {
// If the LLM result is a single object, try unmarshaling it into a single map
var singleObj map[string]interface{}
err = json.Unmarshal([]byte(s), &singleObj)
if err != nil { if err != nil {
log.Warn().Err(err).Str("escapedLLMResult", s).Msg("unable to unmarshal llm result") log.Warn().Err(err).Str("escapedLLMResult", s).Msg("unable to unmarshal llm result")
} else {
ss = []map[string]interface{}{singleObj}
} }
}
log.Debug().Msgf("Function return: %s %+v", s, ss) log.Debug().Msgf("Function return: %s %+v", s, ss)
for _, s := range ss {
// The grammar defines the function name as "function", while OpenAI returns "name" // The grammar defines the function name as "function", while OpenAI returns "name"
func_name, ok := ss[functionNameKey] func_name, ok := s[functionNameKey]
if !ok { if !ok {
return "", "", fmt.Errorf("unable to find function name in result") continue
//return result, fmt.Errorf("unable to find function name in result")
} }
// Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object // Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object
args, ok := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix) args, ok := s["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix)
if !ok { if !ok {
return "", "", fmt.Errorf("unable to find arguments in result") continue
//return result, fmt.Errorf("unable to find arguments in result")
} }
d, _ := json.Marshal(args) d, _ := json.Marshal(args)
funcName, ok := func_name.(string) funcName, ok := func_name.(string)
if !ok { if !ok {
return "", "", fmt.Errorf("unable to cast function name to string") continue
//return result, fmt.Errorf("unable to cast function name to string")
} }
return funcName, string(d), nil result = append(result, FuncCallResults{Name: funcName, Arguments: string(d)})
}
return result, nil
} }
// if no grammar is used, we have to extract function and arguments from the result
if !useGrammars {
// the response is a string that we have to parse // the response is a string that we have to parse
result := make(map[string]string) result := make(map[string]string)
@ -131,9 +140,8 @@ func ParseFunctionCall(llmresult string, functionConfig FunctionsConfig) []FuncC
if functionName == "" { if functionName == "" {
return results return results
} }
results = append(results, FuncCallResults{Name: result[functionNameKey], Arguments: result["arguments"]})
} else if functionConfig.JSONRegexMatch != "" { } else if functionConfig.JSONRegexMatch != "" {
//re := regexp.MustCompile(`(?s)<tool_call>(.*?)</tool_call>`)
//m:= re.FindStringSubmatch(`<tool_call>{ foo barr }</tool_call>`)
// We use a regex to extract the JSON object from the response // We use a regex to extract the JSON object from the response
var respRegex = regexp.MustCompile(functionConfig.JSONRegexMatch) var respRegex = regexp.MustCompile(functionConfig.JSONRegexMatch)
@ -142,60 +150,9 @@ func ParseFunctionCall(llmresult string, functionConfig FunctionsConfig) []FuncC
return results return results
} }
funcName, args, err := returnResult(match[1]) results, _ = returnResult(match[1])
if err != nil {
return results
}
return append(results, FuncCallResults{Name: funcName, Arguments: args})
} else { } else {
results, _ = returnResult(llmresult)
funcName, args, err := returnResult(llmresult)
if err != nil {
return results
}
return append(results, FuncCallResults{Name: funcName, Arguments: args})
}
return append(results, FuncCallResults{Name: result[functionNameKey], Arguments: result["arguments"]})
}
// with grammars
// TODO: use generics to avoid this code duplication
if multipleResults {
ss := []map[string]interface{}{}
s := utils.EscapeNewLines(llmresult)
err := json.Unmarshal([]byte(s), &ss)
if err != nil {
log.Warn().Err(err).Str("escapedLLMResult", s).Msg("multiple results: unable to unmarshal llm result")
}
log.Debug().Msgf("Function return: %s %+v", s, ss)
for _, s := range ss {
func_name, ok := s[functionNameKey]
if !ok {
continue
}
args, ok := s["arguments"]
if !ok {
continue
}
d, _ := json.Marshal(args)
funcName, ok := func_name.(string)
if !ok {
continue
}
results = append(results, FuncCallResults{Name: funcName, Arguments: string(d)})
}
} else {
funcName, args, err := returnResult(llmresult)
if err != nil {
return results
}
results = append(results, FuncCallResults{Name: funcName, Arguments: args})
} }
return results return results

View file

@ -11,18 +11,12 @@ var _ = Describe("LocalAI function parse tests", func() {
BeforeEach(func() { BeforeEach(func() {
// Default configuration setup // Default configuration setup
functionConfig = FunctionsConfig{ functionConfig = FunctionsConfig{}
ParallelCalls: false,
NoGrammar: false,
ResponseRegex: `(?P<function>\w+)\s*\((?P<arguments>.*)\)`,
}
}) })
Context("when using grammars and single result expected", func() { Context("when using grammars and single result expected", func() {
It("should parse the function name and arguments correctly", func() { It("should parse the function name and arguments correctly", func() {
input := `{"function": "add", "arguments": {"x": 5, "y": 3}}` input := `{"function": "add", "arguments": {"x": 5, "y": 3}}`
functionConfig.ParallelCalls = false
functionConfig.NoGrammar = false
results := ParseFunctionCall(input, functionConfig) results := ParseFunctionCall(input, functionConfig)
Expect(results).To(HaveLen(1)) Expect(results).To(HaveLen(1))
@ -34,7 +28,7 @@ var _ = Describe("LocalAI function parse tests", func() {
Context("when not using grammars and regex is needed", func() { Context("when not using grammars and regex is needed", func() {
It("should extract function name and arguments from the regex", func() { It("should extract function name and arguments from the regex", func() {
input := `add({"x":5,"y":3})` input := `add({"x":5,"y":3})`
functionConfig.NoGrammar = true functionConfig.ResponseRegex = `(?P<function>\w+)\s*\((?P<arguments>.*)\)`
results := ParseFunctionCall(input, functionConfig) results := ParseFunctionCall(input, functionConfig)
Expect(results).To(HaveLen(1)) Expect(results).To(HaveLen(1))
@ -46,33 +40,19 @@ var _ = Describe("LocalAI function parse tests", func() {
Context("when having invalid input", func() { Context("when having invalid input", func() {
It("returns no results when there is no input", func() { It("returns no results when there is no input", func() {
input := "" input := ""
functionConfig.NoGrammar = true
results := ParseFunctionCall(input, functionConfig) results := ParseFunctionCall(input, functionConfig)
Expect(results).To(HaveLen(0)) Expect(results).To(HaveLen(0))
functionConfig.NoGrammar = false
results = ParseFunctionCall(input, functionConfig)
Expect(results).To(HaveLen(0))
}) })
It("returns no results when is invalid", func() { It("returns no results when is invalid", func() {
input := "invalid input" input := "invalid input"
functionConfig.NoGrammar = true
results := ParseFunctionCall(input, functionConfig) results := ParseFunctionCall(input, functionConfig)
Expect(results).To(HaveLen(0)) Expect(results).To(HaveLen(0))
functionConfig.NoGrammar = false
results = ParseFunctionCall(input, functionConfig)
Expect(results).To(HaveLen(0))
}) })
}) })
Context("when parallel calls are enabled", func() { Context("when parallel calls are enabled", func() {
It("should handle multiple function calls", func() { It("should handle multiple function calls", func() {
input := `[{"function": "add", "arguments": {"x": 5, "y": 3}}, {"function": "subtract", "arguments": {"x": 10, "y": 7}}]` input := `[{"function": "add", "arguments": {"x": 5, "y": 3}}, {"function": "subtract", "arguments": {"x": 10, "y": 7}}]`
functionConfig.ParallelCalls = true
functionConfig.NoGrammar = false
results := ParseFunctionCall(input, functionConfig) results := ParseFunctionCall(input, functionConfig)
Expect(results).To(HaveLen(2)) Expect(results).To(HaveLen(2))
@ -86,9 +66,6 @@ var _ = Describe("LocalAI function parse tests", func() {
Context("without grammars and without regex", func() { Context("without grammars and without regex", func() {
It("should parse the function name and arguments correctly with the name key", func() { It("should parse the function name and arguments correctly with the name key", func() {
input := `{"name": "add", "arguments": {"x": 5, "y": 3}}` input := `{"name": "add", "arguments": {"x": 5, "y": 3}}`
functionConfig.ParallelCalls = false
functionConfig.NoGrammar = true
functionConfig.ResponseRegex = ""
functionConfig.FunctionName = true functionConfig.FunctionName = true
results := ParseFunctionCall(input, functionConfig) results := ParseFunctionCall(input, functionConfig)
@ -99,10 +76,6 @@ var _ = Describe("LocalAI function parse tests", func() {
It("should parse the function name and arguments correctly with the function key", func() { It("should parse the function name and arguments correctly with the function key", func() {
input := `{"function": "add", "arguments": {"x": 5, "y": 3}}` input := `{"function": "add", "arguments": {"x": 5, "y": 3}}`
functionConfig.ParallelCalls = false
functionConfig.NoGrammar = true
functionConfig.ResponseRegex = ""
functionConfig.FunctionName = false
results := ParseFunctionCall(input, functionConfig) results := ParseFunctionCall(input, functionConfig)
Expect(results).To(HaveLen(1)) Expect(results).To(HaveLen(1))
@ -115,11 +88,8 @@ var _ = Describe("LocalAI function parse tests", func() {
<tool_call> <tool_call>
{"function": "add", "arguments": {"x": 5, "y": 3}} {"function": "add", "arguments": {"x": 5, "y": 3}}
</tool_call>` </tool_call>`
functionConfig.ParallelCalls = false
functionConfig.NoGrammar = true
functionConfig.JSONRegexMatch = `(?s)<tool_call>(.*?)</tool_call>` functionConfig.JSONRegexMatch = `(?s)<tool_call>(.*?)</tool_call>`
functionConfig.ResponseRegex = ""
functionConfig.FunctionName = false
results := ParseFunctionCall(input, functionConfig) results := ParseFunctionCall(input, functionConfig)
Expect(results).To(HaveLen(1)) Expect(results).To(HaveLen(1))
@ -131,11 +101,8 @@ var _ = Describe("LocalAI function parse tests", func() {
input := ` input := `
{"function": "add", "arguments": {"x": 5, "y": 3}} {"function": "add", "arguments": {"x": 5, "y": 3}}
</tool_call>` </tool_call>`
functionConfig.ParallelCalls = false
functionConfig.NoGrammar = true
functionConfig.JSONRegexMatch = `(?s)(.*?)</tool_call>` functionConfig.JSONRegexMatch = `(?s)(.*?)</tool_call>`
functionConfig.ResponseRegex = ""
functionConfig.FunctionName = false
results := ParseFunctionCall(input, functionConfig) results := ParseFunctionCall(input, functionConfig)
Expect(results).To(HaveLen(1)) Expect(results).To(HaveLen(1))