From 6b6c8cdd5f21780cd94f0e7ecc3a220ebad1a3a3 Mon Sep 17 00:00:00 2001 From: lenaxia Date: Sat, 18 May 2024 16:29:10 -0700 Subject: [PATCH] feat(functions): Enable true regex replacement for the regexReplacement option (#2341) * Adding regex capabilities to ParseFunctionCall replacement Signed-off-by: Lenaxia * Adding tests for the regex replace in ParseFunctionCall Signed-off-by: Lenaxia * Fixing tests and adding a test case to validate double quote replacement works Signed-off-by: Lenaxia * Make Regex replacement stable, drop lookaheads Signed-off-by: mudler --------- Signed-off-by: Lenaxia Signed-off-by: mudler Co-authored-by: Lenaxia Co-authored-by: mudler --- pkg/functions/parse.go | 10 +++-- pkg/functions/parse_test.go | 76 ++++++++++++++++++++++++++++++++++++- 2 files changed, 80 insertions(+), 6 deletions(-) diff --git a/pkg/functions/parse.go b/pkg/functions/parse.go index 5327ee6d..ef81242b 100644 --- a/pkg/functions/parse.go +++ b/pkg/functions/parse.go @@ -3,10 +3,10 @@ package functions import ( "encoding/json" "regexp" - "strings" "github.com/go-skynet/LocalAI/pkg/utils" "github.com/rs/zerolog/log" + "gopkg.in/yaml.v2" ) // FunctionsConfig is the configuration for the tool/function call. @@ -44,7 +44,7 @@ type FunctionsConfig struct { GrammarPrefix string `yaml:"grammar_prefix"` // ReplaceResults allow to replace strings in the results before parsing them - ReplaceResults map[string]string `yaml:"replace_results"` + ReplaceResults yaml.MapSlice `yaml:"replace_results"` // FunctionName enable the LLM to return { "name": "function_name", "arguments": { "arg1": "value1", "arg2": "value2" } } // instead of { "function": "function_name", "arguments": { "arg1": "value1", "arg2": "value2" } }. @@ -60,9 +60,11 @@ type FuncCallResults struct { func ParseFunctionCall(llmresult string, functionConfig FunctionsConfig) []FuncCallResults { log.Debug().Msgf("LLM result: %s", llmresult) - for k, v := range functionConfig.ReplaceResults { + for _, item := range functionConfig.ReplaceResults { + k, v := item.Key.(string), item.Value.(string) log.Debug().Msgf("Replacing %s with %s", k, v) - llmresult = strings.ReplaceAll(llmresult, k, v) + re := regexp.MustCompile(k) + llmresult = re.ReplaceAllString(llmresult, v) } log.Debug().Msgf("LLM result(processed): %s", llmresult) diff --git a/pkg/functions/parse_test.go b/pkg/functions/parse_test.go index 03a01239..14e27870 100644 --- a/pkg/functions/parse_test.go +++ b/pkg/functions/parse_test.go @@ -4,6 +4,7 @@ import ( . "github.com/go-skynet/LocalAI/pkg/functions" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "gopkg.in/yaml.v2" ) var _ = Describe("LocalAI function parse tests", func() { @@ -50,6 +51,7 @@ var _ = Describe("LocalAI function parse tests", func() { Expect(results).To(HaveLen(0)) }) }) + Context("when parallel calls are enabled", func() { It("should handle multiple function calls", func() { input := `[{"function": "add", "arguments": {"x": 5, "y": 3}}, {"function": "subtract", "arguments": {"x": 10, "y": 7}}]` @@ -83,7 +85,7 @@ var _ = Describe("LocalAI function parse tests", func() { Expect(results[0].Arguments).To(Equal(`{"x":5,"y":3}`)) }) - It("Should parse the result by matching the JSONRegexMatch", func() { + It("should parse the result by matching the JSONRegexMatch", func() { input := ` {"function": "add", "arguments": {"x": 5, "y": 3}} @@ -97,7 +99,7 @@ var _ = Describe("LocalAI function parse tests", func() { Expect(results[0].Arguments).To(Equal(`{"x":5,"y":3}`)) }) - It("Should parse the result by matching the JSONRegexMatch", func() { + It("should parse the result by matching the JSONRegexMatch", func() { input := ` {"function": "add", "arguments": {"x": 5, "y": 3}} ` @@ -110,4 +112,74 @@ var _ = Describe("LocalAI function parse tests", func() { Expect(results[0].Arguments).To(Equal(`{"x":5,"y":3}`)) }) }) + + Context("when using ReplaceResults to clean up input", func() { + It("should replace text before and after JSON blob", func() { + input := ` +Some text before the JSON +{"function": "add", "arguments": {"x": 5, "y": 3}} +Some text after the JSON +` + + functionConfig.ReplaceResults = yaml.MapSlice{ + {Key: `(?s)^[^{\[]*`, Value: ""}, + {Key: `(?s)[^}\]]*$`, Value: ""}, + } + + results := ParseFunctionCall(input, functionConfig) + Expect(results).To(HaveLen(1)) + Expect(results[0].Name).To(Equal("add")) + Expect(results[0].Arguments).To(Equal(`{"x":5,"y":3}`)) + }) + + It("should replace text before and after array JSON blob", func() { + input := ` +Some text before the JSON +[{"function": "add", "arguments": {"x": 5, "y": 3}}, {"function": "subtract", "arguments": {"x": 10, "y": 7}}] +Some text after the JSON +` + functionConfig.ReplaceResults = yaml.MapSlice{ + {Key: `(?s)^[^{\[]*`, Value: ""}, + {Key: `(?s)[^}\]]*$`, Value: ""}, + } + + results := ParseFunctionCall(input, functionConfig) + Expect(results).To(HaveLen(2)) + Expect(results[0].Name).To(Equal("add")) + Expect(results[0].Arguments).To(Equal(`{"x":5,"y":3}`)) + Expect(results[1].Name).To(Equal("subtract")) + Expect(results[1].Arguments).To(Equal(`{"x":10,"y":7}`)) + }) + + It("should convert single-quoted key-value pairs to double-quoted and escape double quotes within values", func() { + input := ` +Some text before the JSON +{'function': '"add"', 'arguments': {'x': 5, 'z': '"v"', 'y': 'v"value"'}} +Some text after the JSON +` + // Regex to match non-JSON characters before the JSON structure + //reBefore := regexp.MustCompile(`(?s)^.*?(?=\{|\[)`) + // Regex to match non-JSON characters after the JSON structure + //reAfter := regexp.MustCompile(`(?s)(?<=\}|\]).*$`) + + functionConfig.ReplaceResults = yaml.MapSlice{ + {Key: `(?s)^[^{\[]*`, Value: ""}, + {Key: `(?s)[^}\]]*$`, Value: ""}, + // Regex pattern to match single quotes around keys and values + // Step 1: Replace single quotes around keys and values with double quotes + {Key: `'([^']*?)'`, Value: `_DQUOTE_${1}_DQUOTE_`}, + // Step 2: Replace double quotes inside values with placeholders + {Key: `\\"`, Value: `__TEMP_QUOTE__`}, + {Key: `"`, Value: `\"`}, + {Key: `\'`, Value: `'`}, + {Key: `_DQUOTE_`, Value: `"`}, + {Key: `__TEMP_QUOTE__`, Value: `"`}, + } + + results := ParseFunctionCall(input, functionConfig) + Expect(results).To(HaveLen(1)) + Expect(results[0].Name).To(Equal("\"add\"")) + Expect(results[0].Arguments).To(Equal(`{"x":5,"y":"v\"value\"","z":"\"v\""}`)) + }) + }) })