feat(functions): support models with no grammar, add tests (#2068)

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto 2024-04-18 22:43:12 +02:00 committed by GitHub
parent 13012cfa70
commit bbea62b907
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 255 additions and 119 deletions

View file

@ -1,4 +1,4 @@
package grammar
package functions
import (
"encoding/json"

View file

@ -1,4 +1,4 @@
package grammar
package functions
import (
"testing"

View file

@ -1,7 +1,7 @@
package grammar_test
package functions_test
import (
. "github.com/go-skynet/LocalAI/pkg/grammar"
. "github.com/go-skynet/LocalAI/pkg/functions"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)

View file

@ -1,4 +1,4 @@
package grammar
package functions
// a golang port of https://github.com/ggerganov/llama.cpp/pull/1887

View file

@ -1,9 +1,9 @@
package grammar_test
package functions_test
import (
"strings"
. "github.com/go-skynet/LocalAI/pkg/grammar"
. "github.com/go-skynet/LocalAI/pkg/functions"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)

108
pkg/functions/parse.go Normal file
View file

@ -0,0 +1,108 @@
package functions
import (
"encoding/json"
"regexp"
"github.com/go-skynet/LocalAI/pkg/utils"
"github.com/rs/zerolog/log"
)
// FunctionsConfig groups the function-calling options of a model. Each field
// is mapped to the YAML key given in its struct tag.
type FunctionsConfig struct {
	// DisableNoAction maps to "disable_no_action".
	// NOTE(review): not consumed in this file; presumably switches off the
	// implicit "no action" function - confirm against the callers.
	DisableNoAction bool `yaml:"disable_no_action"`

	// NoActionFunctionName maps to "no_action_function_name".
	// NOTE(review): presumably overrides the name of the "no action"
	// function - confirm against the callers.
	NoActionFunctionName string `yaml:"no_action_function_name"`

	// NoActionDescriptionName maps to "no_action_description_name".
	// NOTE(review): presumably overrides the description of the "no action"
	// function - confirm against the callers.
	NoActionDescriptionName string `yaml:"no_action_description_name"`

	// ParallelCalls makes ParseFunctionCall decode the response as a JSON
	// array of calls instead of a single JSON object.
	ParallelCalls bool `yaml:"parallel_calls"`

	// NoGrammar makes ParseFunctionCall skip JSON decoding and extract the
	// call from the raw response with ResponseRegex instead.
	NoGrammar bool `yaml:"no_grammar"`

	// ResponseRegex is the pattern applied to the raw response when NoGrammar
	// is set. ParseFunctionCall reads its named capture groups "function" and
	// "arguments".
	ResponseRegex string `yaml:"response_regex"`
}
// FuncCallResults is one function call extracted from an LLM response by
// ParseFunctionCall.
type FuncCallResults struct {
	// Name is the name of the function to call.
	Name string

	// Arguments holds the call arguments as a string: the JSON encoding of the
	// "arguments" value when a grammar is used, or the text captured by the
	// "arguments" regex group otherwise.
	Arguments string
}
// ParseFunctionCall extracts the function calls contained in the raw LLM
// response llmresult, according to functionConfig.
//
// Two modes are supported:
//   - functionConfig.NoGrammar set: the response is matched against
//     functionConfig.ResponseRegex and the named capture groups "function" and
//     "arguments" are returned verbatim. At most one call is returned.
//   - otherwise: the response is decoded as JSON, either a single object or,
//     when functionConfig.ParallelCalls is set, an array of objects. Each
//     object must carry a string "function" and an "arguments" value; the
//     latter is re-encoded to a JSON string.
//
// The returned slice is never nil; it is empty when no call could be
// extracted. An invalid regular expression or malformed JSON is logged rather
// than causing a panic.
func ParseFunctionCall(llmresult string, functionConfig FunctionsConfig) []FuncCallResults {
	// Start from an empty (non-nil) slice so "no call found" is a zero-length result.
	results := []FuncCallResults{}

	// if no grammar is used, we have to extract function and arguments from the result
	if functionConfig.NoGrammar {
		// The response is a free-form string: named capture groups are used to
		// pull out the function name and the arguments. This expects the LLM to
		// be stable and return correctly formatted output.
		// TODO: optimize this and pre-compile it
		//
		// The pattern comes from user configuration, so it must not be compiled
		// with MustCompile: a typo in the config would panic at request time.
		respRegex, err := regexp.Compile(functionConfig.ResponseRegex)
		if err != nil {
			log.Error().Err(err).Msgf("invalid response_regex %q, cannot extract function call", functionConfig.ResponseRegex)
			return results
		}

		match := respRegex.FindStringSubmatch(llmresult)
		captures := make(map[string]string)
		for i, name := range respRegex.SubexpNames() {
			// Index 0 is the whole match; unnamed groups have an empty name.
			if i != 0 && name != "" && len(match) > i {
				captures[name] = match[i]
			}
		}

		// TODO: open point about multiple results and/or mixed with chat messages
		// This is not handled as for now, we only expect one function call per response
		functionName := captures["function"]
		if functionName == "" {
			return results
		}
		return append(results, FuncCallResults{Name: functionName, Arguments: captures["arguments"]})
	}

	// with grammars
	// As we have to change the result before processing, we can't stream the answer token-by-token (yet?)
	// Escaping prevents newlines from breaking JSON parsing for clients.
	s := utils.EscapeNewLines(llmresult)

	if functionConfig.ParallelCalls {
		calls := []map[string]interface{}{}
		if err := json.Unmarshal([]byte(s), &calls); err != nil {
			// Best effort: keep going with whatever could be decoded.
			log.Warn().Err(err).Msg("function call response is not a valid JSON array of calls")
		}
		log.Debug().Msgf("Function return: %s %+v", s, calls)

		for _, call := range calls {
			if res, ok := functionCallFromObject(call); ok {
				results = append(results, res)
			}
		}
		return results
	}

	call := map[string]interface{}{}
	if err := json.Unmarshal([]byte(s), &call); err != nil {
		// Best effort: an undecodable response simply yields no call below.
		log.Warn().Err(err).Msg("function call response is not a valid JSON object")
	}
	log.Debug().Msgf("Function return: %s %+v", s, call)

	if res, ok := functionCallFromObject(call); ok {
		results = append(results, res)
	}
	return results
}

// functionCallFromObject converts one decoded grammar object into a
// FuncCallResults. It reports false when the object has no "function" key, no
// "arguments" key, a non-string function name, or arguments that cannot be
// re-encoded as JSON.
func functionCallFromObject(obj map[string]interface{}) (FuncCallResults, bool) {
	// The grammar defines the function name as "function", while OpenAI returns "name"
	rawName, ok := obj["function"]
	if !ok {
		return FuncCallResults{}, false
	}
	// While here arguments is a decoded JSON value, OpenAI actually wants a
	// stringified object, hence the re-encoding below (TODO: fix).
	rawArgs, ok := obj["arguments"]
	if !ok {
		return FuncCallResults{}, false
	}
	name, ok := rawName.(string)
	if !ok {
		return FuncCallResults{}, false
	}
	encoded, err := json.Marshal(rawArgs)
	if err != nil {
		log.Warn().Err(err).Msgf("cannot encode arguments of function %q", name)
		return FuncCallResults{}, false
	}
	return FuncCallResults{Name: name, Arguments: string(encoded)}, true
}

View file

@ -0,0 +1,85 @@
package functions_test
import (
. "github.com/go-skynet/LocalAI/pkg/functions"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
// Specs for ParseFunctionCall: JSON output produced with a grammar (single and
// parallel calls), regex extraction when no grammar is used, and inputs that
// contain no function call.
var _ = Describe("LocalAI function parse tests", func() {
	var cfg FunctionsConfig

	// Every spec starts from the same baseline and flips only what it needs.
	BeforeEach(func() {
		cfg = FunctionsConfig{
			ParallelCalls: false,
			NoGrammar:     false,
			ResponseRegex: `(?P<function>\w+)\s*\((?P<arguments>.*)\)`,
		}
	})

	Context("when using grammars and single result expected", func() {
		It("should parse the function name and arguments correctly", func() {
			cfg.ParallelCalls = false
			cfg.NoGrammar = false

			calls := ParseFunctionCall(`{"function": "add", "arguments": {"x": 5, "y": 3}}`, cfg)

			// Arguments come back re-encoded, hence without the original spacing.
			Expect(calls).To(Equal([]FuncCallResults{
				{Name: "add", Arguments: `{"x":5,"y":3}`},
			}))
		})
	})

	Context("when not using grammars and regex is needed", func() {
		It("should extract function name and arguments from the regex", func() {
			cfg.NoGrammar = true

			calls := ParseFunctionCall(`add({"x":5,"y":3})`, cfg)

			Expect(calls).To(Equal([]FuncCallResults{
				{Name: "add", Arguments: `{"x":5,"y":3}`},
			}))
		})
	})

	Context("when having invalid input", func() {
		It("returns no results when there is no input", func() {
			// Exercise the regex path first, then the grammar path.
			for _, noGrammar := range []bool{true, false} {
				cfg.NoGrammar = noGrammar
				Expect(ParseFunctionCall("", cfg)).To(HaveLen(0))
			}
		})

		It("returns no results when is invalid", func() {
			// Exercise the regex path first, then the grammar path.
			for _, noGrammar := range []bool{true, false} {
				cfg.NoGrammar = noGrammar
				Expect(ParseFunctionCall("invalid input", cfg)).To(HaveLen(0))
			}
		})
	})

	Context("when parallel calls are enabled", func() {
		It("should handle multiple function calls", func() {
			cfg.ParallelCalls = true
			cfg.NoGrammar = false

			calls := ParseFunctionCall(
				`[{"function": "add", "arguments": {"x": 5, "y": 3}}, {"function": "subtract", "arguments": {"x": 10, "y": 7}}]`,
				cfg,
			)

			// Calls are expected in the order they appear in the response.
			Expect(calls).To(Equal([]FuncCallResults{
				{Name: "add", Arguments: `{"x":5,"y":3}`},
				{Name: "subtract", Arguments: `{"x":10,"y":7}`},
			}))
		})
	})
})

View file

@ -11,7 +11,7 @@ import (
"text/template"
"github.com/Masterminds/sprig/v3"
grammar "github.com/go-skynet/LocalAI/pkg/grammar"
"github.com/go-skynet/LocalAI/pkg/functions"
"github.com/go-skynet/LocalAI/pkg/grpc"
process "github.com/mudler/go-processmanager"
"github.com/rs/zerolog/log"
@ -25,7 +25,7 @@ type PromptTemplateData struct {
SuppressSystemPrompt bool // used by chat specifically to indicate that SystemPrompt above should be _ignored_
Input string
Instruction string
Functions []grammar.Function
Functions []functions.Function
MessageIndex int
}