feat: add grammar and functions call support

This commit is contained in:
mudler 2023-07-02 11:13:51 +02:00
parent a6839fd238
commit f09ddd2983
7 changed files with 571 additions and 9 deletions

50
pkg/grammar/functions.go Normal file
View file

@ -0,0 +1,50 @@
package grammar
import (
"encoding/json"
)
// Function describes a single callable function: its name, a free-form
// description, and a parameters document.
type Function struct {
// Name is the function identifier; ToJSONStructure emits it as the constant
// "function" value and Select matches against it.
Name string `json:"name"`
// Description is free-form text about the function.
Description string `json:"description"`
// Parameters is a generic JSON object; ToJSONStructure reads its
// "properties" entry to build the arguments schema.
// NOTE(review): presumably a JSON-schema style object — confirm with callers.
Parameters map[string]interface{} `json:"parameters"`
}
// Functions is a list of Function definitions.
type Functions []Function
// ToJSONStructure converts the function list into a JSONStructure whose OneOf
// holds one object schema per function: a constant "function" name plus an
// "arguments" object built from the function's Parameters["properties"].
//
// The arguments' Properties map is never nil, so a function declared without
// parameters is emitted with an empty properties object instead of a JSON
// null.
func (f Functions) ToJSONStructure() JSONStructure {
	js := JSONStructure{}
	for _, function := range f {
		js.OneOf = append(js.OneOf, Item{
			Type: "object",
			Properties: Properties{
				Function: FunctionName{Const: function.Name},
				Arguments: Argument{
					Type:       "object",
					Properties: propertiesToMap(function.Parameters["properties"]),
				},
			},
		})
	}
	return js
}

// propertiesToMap round-trips properties through JSON so that any map-like
// value is normalized to map[string]interface{}. A missing (nil) value, or one
// that cannot be encoded or is not a JSON object, yields an empty non-nil map.
func propertiesToMap(properties interface{}) map[string]interface{} {
	dat, err := json.Marshal(properties)
	if err != nil {
		return map[string]interface{}{}
	}
	var prop map[string]interface{}
	// Unmarshalling JSON null sets the map to nil without reporting an error,
	// so the nil check is required in addition to the error check.
	if err := json.Unmarshal(dat, &prop); err != nil || prop == nil {
		return map[string]interface{}{}
	}
	return prop
}
// Select returns a Functions list holding only the first function whose Name
// equals name, or a nil list when no function matches.
func (f Functions) Select(name string) Functions {
	for i := range f {
		if f[i].Name != name {
			continue
		}
		return Functions{f[i]}
	}
	return nil
}

View file

@ -0,0 +1,13 @@
package grammar
import (
"testing"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
// TestGrammar is the `go test` entry point for this package's specs: it
// registers Fail as the assertion fail handler and runs the specs under the
// "Grammar test suite" description.
func TestGrammar(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "Grammar test suite")
}

222
pkg/grammar/json_schema.go Normal file
View file

@ -0,0 +1,222 @@
package grammar
// a golang port of https://github.com/ggerganov/llama.cpp/pull/1887
import (
"encoding/json"
"fmt"
"regexp"
"sort"
"strings"
)
var (
// SPACE_RULE is the rule body registered under the "space" rule name by
// NewJSONSchemaConverter: an optional single space.
SPACE_RULE = `" "?`
// PRIMITIVE_RULES maps a JSON schema primitive type name to the grammar rule
// body emitted for it by visit.
PRIMITIVE_RULES = map[string]string{
"boolean": `("true" | "false") space`,
"number": `[0-9]+ space`, // TODO complete
"string": `"\"" [ \t!#-\[\]-~]* "\"" space`, // TODO complete
"null": `"null" space`,
}
// INVALID_RULE_CHARS_RE matches runs of characters other than ASCII letters,
// digits and "-"; addRule replaces each run with a single "-".
INVALID_RULE_CHARS_RE = regexp.MustCompile(`[^a-zA-Z0-9-]+`)
// GRAMMAR_LITERAL_ESCAPE_RE matches the characters (carriage return, line
// feed, double quote) that formatLiteral escapes inside a grammar literal.
GRAMMAR_LITERAL_ESCAPE_RE = regexp.MustCompile(`[\r\n"]`)
// GRAMMAR_LITERAL_ESCAPES maps each character matched by
// GRAMMAR_LITERAL_ESCAPE_RE to its escaped replacement text.
GRAMMAR_LITERAL_ESCAPES = map[string]string{
"\r": `\r`,
"\n": `\n`,
`"`: `\"`,
}
)
// JSONSchemaConverter accumulates grammar rules while walking a JSON schema.
type JSONSchemaConverter struct {
// propOrder maps a property name to its position in the comma-separated
// order string given to NewJSONSchemaConverter; visit consults it when
// sorting object properties.
propOrder map[string]int
// rules maps a rule name to its rule body; formatGrammar renders each entry
// as "name ::= body".
rules map[string]string
}
// NewJSONSchemaConverter returns a converter whose rule set is seeded with the
// "space" rule.
//
// propOrder is a comma-separated list of property names; each name is recorded
// with its position in the list. Whitespace around a name is ignored and blank
// entries are skipped, so an empty propOrder records no ordering at all. When
// a name is repeated, its last position wins.
func NewJSONSchemaConverter(propOrder string) *JSONSchemaConverter {
	propOrderMap := make(map[string]int)
	for idx, name := range strings.Split(propOrder, ",") {
		name = strings.TrimSpace(name)
		if name == "" {
			// strings.Split("", ",") yields one empty element; it must not be
			// registered as a property called "".
			continue
		}
		propOrderMap[name] = idx
	}

	return &JSONSchemaConverter{
		propOrder: propOrderMap,
		rules:     map[string]string{"space": SPACE_RULE},
	}
}
// formatLiteral renders literal as a quoted grammar literal. The value is
// first JSON-encoded via jsonString; every character matched by
// GRAMMAR_LITERAL_ESCAPE_RE in that encoding is replaced with its entry in
// GRAMMAR_LITERAL_ESCAPES, and the result is wrapped in double quotes.
func (sc *JSONSchemaConverter) formatLiteral(literal interface{}) string {
	encoded := jsonString(literal)
	escapeChar := func(ch string) string {
		return GRAMMAR_LITERAL_ESCAPES[ch]
	}
	body := GRAMMAR_LITERAL_ESCAPE_RE.ReplaceAllStringFunc(encoded, escapeChar)
	return `"` + body + `"`
}
// addRule stores rule under a sanitized form of name and returns the key that
// was actually used. Runs of characters matched by INVALID_RULE_CHARS_RE are
// replaced with "-". If the sanitized name is free, or already holds an
// identical rule, it is used as-is; otherwise the first unused key of the form
// <sanitized><n> (n = 0, 1, 2, ...) is taken.
func (sc *JSONSchemaConverter) addRule(name, rule string) string {
	sanitized := INVALID_RULE_CHARS_RE.ReplaceAllString(name, "-")

	current, taken := sc.rules[sanitized]
	if !taken || current == rule {
		sc.rules[sanitized] = rule
		return sanitized
	}

	// The plain name holds a different rule: probe numbered suffixes.
	for n := 0; ; n++ {
		candidate := fmt.Sprintf("%s%d", sanitized, n)
		if _, used := sc.rules[candidate]; !used {
			sc.rules[candidate] = rule
			return candidate
		}
	}
}
// formatGrammar renders every collected rule as a "name ::= body" line and
// joins the lines with newlines. Rules are emitted in ascending rule-name
// order so that the same rule set always produces the same text (map iteration
// order is unspecified in Go).
func (sc *JSONSchemaConverter) formatGrammar() string {
	names := make([]string, 0, len(sc.rules))
	for name := range sc.rules {
		names = append(names, name)
	}
	sort.Strings(names)

	lines := make([]string, 0, len(names))
	for _, name := range names {
		lines = append(lines, fmt.Sprintf("%s ::= %s", name, sc.rules[name]))
	}
	return strings.Join(lines, "\n")
}
// visit converts schema into one or more grammar rules, registers them through
// addRule and returns the name of the rule that represents schema. An empty
// name denotes the top-level schema and is registered as "root".
//
// Schema forms are tried in this order: oneOf/anyOf alternatives, const, enum,
// object with properties, array with items, and finally the primitive types
// listed in PRIMITIVE_RULES.
//
// Object properties are emitted in a deterministic order: properties listed in
// the converter's propOrder come first, in the listed order, followed by the
// remaining properties sorted by name.
//
// visit panics when the schema matches none of the supported forms, or when a
// nested alternative or property schema is not a JSON object.
func (sc *JSONSchemaConverter) visit(schema map[string]interface{}, name string) string {
	// A missing or non-string "type" leaves schemaType empty; such a schema
	// can still match the oneOf/anyOf/const/enum forms below.
	schemaType, _ := schema["type"].(string)

	ruleName := name
	if ruleName == "" {
		ruleName = "root"
	}

	_, hasOneOf := schema["oneOf"]
	_, hasAnyOf := schema["anyOf"]
	if hasOneOf || hasAnyOf {
		// oneOf takes precedence; anyOf is only used when oneOf is not a list.
		altSchemas, ok := schema["oneOf"].([]interface{})
		if !ok {
			altSchemas, _ = schema["anyOf"].([]interface{})
		}
		alternatives := make([]string, 0, len(altSchemas))
		for i, altSchema := range altSchemas {
			alternative := sc.visit(altSchema.(map[string]interface{}), fmt.Sprintf("%s-%d", ruleName, i))
			alternatives = append(alternatives, alternative)
		}
		return sc.addRule(ruleName, strings.Join(alternatives, " | "))
	}

	if constVal, exists := schema["const"]; exists {
		return sc.addRule(ruleName, sc.formatLiteral(constVal))
	}

	if enumVals, exists := schema["enum"].([]interface{}); exists {
		enumRules := make([]string, 0, len(enumVals))
		for _, enumVal := range enumVals {
			enumRules = append(enumRules, sc.formatLiteral(enumVal))
		}
		return sc.addRule(ruleName, strings.Join(enumRules, " | "))
	}

	if properties, exists := schema["properties"].(map[string]interface{}); exists && schemaType == "object" {
		type propPair struct {
			name   string
			schema map[string]interface{}
		}
		pairs := make([]propPair, 0, len(properties))
		for propName, propSchema := range properties {
			pairs = append(pairs, propPair{name: propName, schema: propSchema.(map[string]interface{})})
		}
		sort.Slice(pairs, func(i, j int) bool {
			// Presence is checked with the comma-ok form: position 0 is a
			// valid order and must not be mistaken for "not listed".
			iOrder, iListed := sc.propOrder[pairs[i].name]
			jOrder, jListed := sc.propOrder[pairs[j].name]
			switch {
			case iListed && jListed:
				return iOrder < jOrder
			case iListed != jListed:
				// Exactly one of the two is listed: it goes first.
				return iListed
			default:
				return pairs[i].name < pairs[j].name
			}
		})

		var rule strings.Builder
		rule.WriteString(`"{" space`)
		for i, pair := range pairs {
			// Visit before writing so nested rules are registered in
			// property order.
			propRuleName := sc.visit(pair.schema, fmt.Sprintf("%s-%s", ruleName, pair.name))
			if i > 0 {
				rule.WriteString(` "," space`)
			}
			fmt.Fprintf(&rule, ` %s space ":" space %s`, sc.formatLiteral(pair.name), propRuleName)
		}
		rule.WriteString(` "}" space`)
		return sc.addRule(ruleName, rule.String())
	}

	if items, exists := schema["items"].(map[string]interface{}); exists && schemaType == "array" {
		itemRuleName := sc.visit(items, fmt.Sprintf("%s-item", ruleName))
		rule := fmt.Sprintf(`"[" space (%s ("," space %s)*)? "]" space`, itemRuleName, itemRuleName)
		return sc.addRule(ruleName, rule)
	}

	primitiveRule, exists := PRIMITIVE_RULES[schemaType]
	if !exists {
		panic(fmt.Sprintf("Unrecognized schema: %v", schema))
	}
	// Primitive rules are shared: they are registered under the type name,
	// not under ruleName.
	return sc.addRule(schemaType, primitiveRule)
}
// Grammar walks schema as the top-level schema, adding its rules to the
// converter, and returns the formatted text of all rules collected so far.
func (sc *JSONSchemaConverter) Grammar(schema map[string]interface{}) string {
	// An empty name makes visit register the schema under the "root" rule.
	const topLevel = ""
	_ = sc.visit(schema, topLevel)
	grammar := sc.formatGrammar()
	return grammar
}
// GrammarFromBytes decodes b as a JSON object schema and returns the grammar
// generated for it by Grammar.
//
// It panics with an "invalid JSON schema" message when b cannot be decoded
// into a JSON object. The decoding error used to be discarded, which left a
// nil schema and a panic further down with an unrelated message.
func (sc *JSONSchemaConverter) GrammarFromBytes(b []byte) string {
	var schema map[string]interface{}
	if err := json.Unmarshal(b, &schema); err != nil {
		panic(fmt.Sprintf("invalid JSON schema: %v", err))
	}
	return sc.Grammar(schema)
}
// jsonString returns the JSON encoding of v as a string, or the empty string
// when v cannot be encoded.
func jsonString(v interface{}) string {
	encoded, err := json.Marshal(v)
	if err != nil {
		return ""
	}
	return string(encoded)
}
// FunctionName is the schema fragment for a function's name: it marshals to a
// JSON object with a single "const" key.
type FunctionName struct {
// Const is the exact function name the schema allows.
Const string `json:"const"`
}
// Properties is the "properties" object of an Item: a "function" name schema
// and an "arguments" schema.
type Properties struct {
// Function is the schema for the "function" key.
Function FunctionName `json:"function"`
// Arguments is the schema for the "arguments" key.
Arguments Argument `json:"arguments"`
}
// Argument is the schema for a function's "arguments" value.
type Argument struct {
// Type is the schema type; ToJSONStructure always sets it to "object".
Type string `json:"type"`
// Properties maps each argument name to its schema.
Properties map[string]interface{} `json:"properties"`
}
// Item is one alternative schema inside a JSONStructure.
type Item struct {
// Type is the schema type; ToJSONStructure always sets it to "object".
Type string `json:"type"`
// Properties holds the "function" and "arguments" schemas of this
// alternative.
Properties Properties `json:"properties"`
}
// JSONStructure is a schema made of alternatives, marshalled under the "oneOf"
// and/or "anyOf" keys. Empty lists are omitted from the JSON output.
type JSONStructure struct {
// OneOf lists the alternatives emitted under "oneOf".
OneOf []Item `json:"oneOf,omitempty"`
// AnyOf lists the alternatives emitted under "anyOf".
AnyOf []Item `json:"anyOf,omitempty"`
}
// Grammar marshals the structure to JSON and returns the grammar that a fresh
// JSONSchemaConverter, configured with propOrder, generates from it.
func (j JSONStructure) Grammar(propOrder string) string {
	// The marshalling error is ignored, as before; the encoded bytes are
	// handed to the converter unchanged.
	schema, _ := json.Marshal(j)
	converter := NewJSONSchemaConverter(propOrder)
	return converter.GrammarFromBytes(schema)
}

View file

@ -0,0 +1,113 @@
package grammar_test
import (
"strings"
. "github.com/go-skynet/LocalAI/pkg/grammar"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
const (
// testInput1 is a JSON schema with two "oneOf" alternatives, each describing
// a function call: a constant "function" name and an "arguments" object
// whose properties are all strings.
testInput1 = `
{
"oneOf": [
{
"type": "object",
"properties": {
"function": {"const": "create_event"},
"arguments": {
"type": "object",
"properties": {
"title": {"type": "string"},
"date": {"type": "string"},
"time": {"type": "string"}
}
}
}
},
{
"type": "object",
"properties": {
"function": {"const": "search"},
"arguments": {
"type": "object",
"properties": {
"query": {"type": "string"}
}
}
}
}
]
}`
// inputResult1 lists the grammar rules expected for testInput1, one per line.
// The specs check each line with ContainSubstring and compare line counts,
// so the order of the lines here is not significant.
inputResult1 = `root-0-function ::= "\"create_event\""
root-0 ::= "{" space "\"arguments\"" space ":" space root-0-arguments "," space "\"function\"" space ":" space root-0-function "}" space
root-1-arguments ::= "{" space "\"query\"" space ":" space string "}" space
root ::= root-0 | root-1
space ::= " "?
root-0-arguments ::= "{" space "\"date\"" space ":" space string "," space "\"time\"" space ":" space string "," space "\"title\"" space ":" space string "}" space
root-1 ::= "{" space "\"arguments\"" space ":" space root-1-arguments "," space "\"function\"" space ":" space root-1-function "}" space
string ::= "\"" [ \t!#-\[\]-~]* "\"" space
root-1-function ::= "\"search\""`
)
// Specs for grammar generation: the same expected rule set (inputResult1) must
// be produced both from a raw JSON schema and from the typed JSONStructure.
var _ = Describe("JSON schema grammar tests", func() {
	// expectMatchesResult asserts that grammar contains every non-empty line
	// of inputResult1 and has the same number of lines.
	expectMatchesResult := func(grammar string) {
		expectedLines := strings.Split(inputResult1, "\n")
		for _, line := range expectedLines {
			if line == "" {
				continue
			}
			Expect(grammar).To(ContainSubstring(line))
		}
		Expect(len(expectedLines)).To(Equal(len(strings.Split(grammar, "\n"))))
	}

	// functionItem builds the schema alternative for a function whose
	// arguments are all of type string.
	functionItem := func(name string, argNames ...string) Item {
		props := make(map[string]interface{}, len(argNames))
		for _, argName := range argNames {
			props[argName] = map[string]string{"type": "string"}
		}
		return Item{
			Type: "object",
			Properties: Properties{
				Function: FunctionName{Const: name},
				Arguments: Argument{ // this is OpenAI's parameter
					Type:       "object",
					Properties: props,
				},
			},
		}
	}

	Context("JSON", func() {
		It("generates a valid grammar from JSON schema", func() {
			converter := NewJSONSchemaConverter("")
			expectMatchesResult(converter.GrammarFromBytes([]byte(testInput1)))
		})

		It("generates a valid grammar from JSON Objects", func() {
			structuredGrammar := JSONStructure{
				OneOf: []Item{
					functionItem("create_event", "title", "date", "time"),
					functionItem("search", "query"),
				},
			}
			expectMatchesResult(structuredGrammar.Grammar(""))
		})
	})
})