mirror of
https://github.com/mudler/LocalAI.git
synced 2025-05-28 06:25:00 +00:00
feat(grammar): add llama3.1 schema (#3015)
* wip Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * get rid of panics Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * expose it properly from the config Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Simplify Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * forgot to commit Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Remove focus on test Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Small fixups Signed-off-by: Ettore Di Giacinto <mudler@localai.io> --------- Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
parent
fee52942eb
commit
2169c3497d
14 changed files with 609 additions and 148 deletions
58
pkg/functions/grammars/bnf_rules.go
Normal file
58
pkg/functions/grammars/bnf_rules.go
Normal file
|
@ -0,0 +1,58 @@
|
|||
package grammars
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"regexp"
|
||||
)
|
||||
|
||||
// Shared building blocks for the grammar converters in this package.
var (
|
||||
// PRIMITIVE_RULES maps a schema primitive type name to the body of the
// grammar rule that matches a value of that type. The "freestring" entry
// is not a JSON schema type: it is registered explicitly by Grammar.
PRIMITIVE_RULES = map[string]string{
|
||||
"boolean": `("true" | "false") space`,
|
||||
"number": `("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space`,
|
||||
"integer": `("-"? ([0-9] | [1-9] [0-9]*)) space`,
|
||||
"string": `"\"" (
|
||||
[^"\\] |
|
||||
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
|
||||
)* "\"" space`,
|
||||
// TODO: we shouldn't forbid \" and \\ or all unicode and have this branch here,
|
||||
// however, if we don't have it, the grammar will be ambiguous and
|
||||
// empirically results are way worse.
|
||||
"freestring": `(
|
||||
[^\x00] |
|
||||
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
|
||||
)* space`,
|
||||
"null": `"null" space`,
|
||||
}
|
||||
|
||||
// INVALID_RULE_CHARS_RE matches every run of characters other than ASCII
// letters, digits and '-'; addRule replaces each run with a single '-'.
INVALID_RULE_CHARS_RE = regexp.MustCompile(`[^a-zA-Z0-9-]+`)
|
||||
// GRAMMAR_LITERAL_ESCAPE_RE matches the characters (CR, LF, double quote)
// that formatLiteral rewrites before embedding a literal in a rule.
GRAMMAR_LITERAL_ESCAPE_RE = regexp.MustCompile(`[\r\n"]`)
|
||||
// GRAMMAR_LITERAL_ESCAPES gives the backslash-escaped replacement for each
// character matched by GRAMMAR_LITERAL_ESCAPE_RE.
GRAMMAR_LITERAL_ESCAPES = map[string]string{
|
||||
"\r": `\r`,
|
||||
"\n": `\n`,
|
||||
`"`: `\"`,
|
||||
}
|
||||
)
|
||||
|
||||
// Fixed rule fragments used when assembling a grammar.
const (
|
||||
// SPACE_RULE is the body of the "space" rule (an optional single space);
// the converter constructors seed their rule set with it.
SPACE_RULE = `" "?`
|
||||
|
||||
// arrayNewLines defines the "arr" rule: a bracketed, comma-separated list
// of realvalue items with a newline after "[" and after every comma.
// NOTE(review): the code choosing between arrayNewLines and array is not
// in this file — presumably driven by a grammar option; confirm.
arrayNewLines = `arr ::=
|
||||
"[\n" (
|
||||
realvalue
|
||||
(",\n" realvalue)*
|
||||
)? "]"`
|
||||
|
||||
// array defines the same "arr" rule without any newlines between items.
array = `arr ::=
|
||||
"[" (
|
||||
realvalue
|
||||
("," realvalue)*
|
||||
)? "]"`
|
||||
)
|
||||
|
||||
// jsonString encodes v with encoding/json and returns the encoded document as
// a string. A marshalling failure is returned unchanged together with an
// empty string.
func jsonString(v interface{}) (string, error) {
	encoded, marshalErr := json.Marshal(v)
	if marshalErr == nil {
		return string(encoded), nil
	}
	return "", marshalErr
}
|
25
pkg/functions/grammars/grammars_suite_test.go
Normal file
25
pkg/functions/grammars/grammars_suite_test.go
Normal file
|
@ -0,0 +1,25 @@
|
|||
package grammars_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/mudler/LocalAI/pkg/functions"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// TestGrammar is the `go test` entry point for this package's spec suite: it
// registers Fail as the assertion failure handler and then runs every
// registered spec under the suite name "Grammar test suite".
func TestGrammar(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "Grammar test suite")
|
||||
}
|
||||
|
||||
func createFunction(field1 string, field2 string, name string, properties map[string]interface{}) map[string]interface{} {
|
||||
property := map[string]interface{}{}
|
||||
property[field1] = FunctionName{Const: name}
|
||||
property[field2] = Argument{
|
||||
Type: "object",
|
||||
Properties: properties,
|
||||
}
|
||||
return property
|
||||
}
|
220
pkg/functions/grammars/json_schema.go
Normal file
220
pkg/functions/grammars/json_schema.go
Normal file
|
@ -0,0 +1,220 @@
|
|||
package grammars
|
||||
|
||||
// a golang port of https://github.com/ggerganov/llama.cpp/pull/1887
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// JSONSchemaConverter turns a JSON schema into grammar rules. Rules are
// accumulated in rules as the schema is visited.
type JSONSchemaConverter struct {
|
||||
// propOrder maps a property name to its position in the ordering string
// given to NewJSONSchemaConverter; visit consults it when sorting object
// properties.
propOrder map[string]int
|
||||
// rules holds every rule generated so far, keyed by rule name.
rules Rules
|
||||
}
|
||||
|
||||
func NewJSONSchemaConverter(propOrder string) *JSONSchemaConverter {
|
||||
propOrderSlice := strings.Split(propOrder, ",")
|
||||
propOrderMap := make(map[string]int)
|
||||
for idx, name := range propOrderSlice {
|
||||
propOrderMap[name] = idx
|
||||
}
|
||||
|
||||
rules := make(map[string]string)
|
||||
rules["space"] = SPACE_RULE
|
||||
|
||||
return &JSONSchemaConverter{
|
||||
propOrder: propOrderMap,
|
||||
rules: rules,
|
||||
}
|
||||
}
|
||||
|
||||
func (sc *JSONSchemaConverter) formatLiteral(literal interface{}) (string, error) {
|
||||
jLiteral, err := jsonString(literal)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
escaped := GRAMMAR_LITERAL_ESCAPE_RE.ReplaceAllStringFunc(jLiteral, func(match string) string {
|
||||
return GRAMMAR_LITERAL_ESCAPES[match]
|
||||
})
|
||||
return fmt.Sprintf(`"%s"`, escaped), nil
|
||||
}
|
||||
|
||||
func (sc *JSONSchemaConverter) addRule(name, rule string) string {
|
||||
escName := INVALID_RULE_CHARS_RE.ReplaceAllString(name, "-")
|
||||
key := escName
|
||||
if existingRule, ok := sc.rules[escName]; ok && existingRule != rule {
|
||||
i := 0
|
||||
for {
|
||||
key = fmt.Sprintf("%s%d", escName, i)
|
||||
if _, ok := sc.rules[key]; !ok {
|
||||
break
|
||||
}
|
||||
i++
|
||||
}
|
||||
}
|
||||
sc.rules[key] = rule
|
||||
return key
|
||||
}
|
||||
|
||||
func (sc *JSONSchemaConverter) visit(schema map[string]interface{}, name string, rootSchema map[string]interface{}) (string, error) {
|
||||
st, existType := schema["type"]
|
||||
var schemaType string
|
||||
if existType {
|
||||
schemaType = st.(string)
|
||||
}
|
||||
ruleName := name
|
||||
if name == "" {
|
||||
ruleName = "root"
|
||||
}
|
||||
_, oneOfExists := schema["oneOf"]
|
||||
_, anyOfExists := schema["anyOf"]
|
||||
if oneOfExists || anyOfExists {
|
||||
var alternatives []string
|
||||
oneOfSchemas, oneOfExists := schema["oneOf"].([]interface{})
|
||||
anyOfSchemas, anyOfExists := schema["anyOf"].([]interface{})
|
||||
|
||||
if oneOfExists {
|
||||
for i, altSchema := range oneOfSchemas {
|
||||
alternative, err := sc.visit(altSchema.(map[string]interface{}), fmt.Sprintf("%s-%d", ruleName, i), rootSchema)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
alternatives = append(alternatives, alternative)
|
||||
}
|
||||
} else if anyOfExists {
|
||||
for i, altSchema := range anyOfSchemas {
|
||||
alternative, err := sc.visit(altSchema.(map[string]interface{}), fmt.Sprintf("%s-%d", ruleName, i), rootSchema)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
alternatives = append(alternatives, alternative)
|
||||
}
|
||||
}
|
||||
|
||||
rule := strings.Join(alternatives, " | ")
|
||||
return sc.addRule(ruleName, rule), nil
|
||||
} else if ref, exists := schema["$ref"].(string); exists {
|
||||
referencedSchema, err := sc.resolveReference(ref, rootSchema)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return sc.visit(referencedSchema, name, rootSchema)
|
||||
} else if constVal, exists := schema["const"]; exists {
|
||||
literal, err := sc.formatLiteral((constVal))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return sc.addRule(ruleName, literal), nil
|
||||
} else if enumVals, exists := schema["enum"].([]interface{}); exists {
|
||||
var enumRules []string
|
||||
for _, enumVal := range enumVals {
|
||||
enumRule, err := sc.formatLiteral(enumVal)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
enumRules = append(enumRules, enumRule)
|
||||
}
|
||||
rule := strings.Join(enumRules, " | ")
|
||||
return sc.addRule(ruleName, rule), nil
|
||||
} else if properties, exists := schema["properties"].(map[string]interface{}); schemaType == "object" && exists {
|
||||
propOrder := sc.propOrder
|
||||
var propPairs []struct {
|
||||
propName string
|
||||
propSchema map[string]interface{}
|
||||
}
|
||||
|
||||
for propName, propSchema := range properties {
|
||||
propPairs = append(propPairs, struct {
|
||||
propName string
|
||||
propSchema map[string]interface{}
|
||||
}{propName: propName, propSchema: propSchema.(map[string]interface{})})
|
||||
}
|
||||
|
||||
sort.Slice(propPairs, func(i, j int) bool {
|
||||
iOrder := propOrder[propPairs[i].propName]
|
||||
jOrder := propOrder[propPairs[j].propName]
|
||||
if iOrder != 0 && jOrder != 0 {
|
||||
return iOrder < jOrder
|
||||
}
|
||||
return propPairs[i].propName < propPairs[j].propName
|
||||
})
|
||||
|
||||
var rule strings.Builder
|
||||
rule.WriteString(`"{" space`)
|
||||
|
||||
for i, propPair := range propPairs {
|
||||
propName := propPair.propName
|
||||
propSchema := propPair.propSchema
|
||||
propRuleName, err := sc.visit(propSchema, fmt.Sprintf("%s-%s", ruleName, propName), rootSchema)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
lPropName, err := sc.formatLiteral(propName)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if i > 0 {
|
||||
rule.WriteString(` "," space`)
|
||||
}
|
||||
|
||||
rule.WriteString(fmt.Sprintf(` %s space ":" space %s`, lPropName, propRuleName))
|
||||
}
|
||||
|
||||
rule.WriteString(` "}" space`)
|
||||
return sc.addRule(ruleName, rule.String()), nil
|
||||
} else if items, exists := schema["items"].(map[string]interface{}); schemaType == "array" && exists {
|
||||
itemRuleName, err := sc.visit(items, fmt.Sprintf("%s-item", ruleName), rootSchema)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
rule := fmt.Sprintf(`"[" space (%s ("," space %s)*)? "]" space`, itemRuleName, itemRuleName)
|
||||
return sc.addRule(ruleName, rule), nil
|
||||
} else {
|
||||
primitiveRule, exists := PRIMITIVE_RULES[schemaType]
|
||||
if !exists {
|
||||
return "", fmt.Errorf("unrecognized schema: %v", schema)
|
||||
}
|
||||
if ruleName == "root" {
|
||||
schemaType = "root"
|
||||
}
|
||||
return sc.addRule(schemaType, primitiveRule), nil
|
||||
}
|
||||
}
|
||||
func (sc *JSONSchemaConverter) resolveReference(ref string, rootSchema map[string]interface{}) (map[string]interface{}, error) {
|
||||
if !strings.HasPrefix(ref, "#/$defs/") {
|
||||
return nil, fmt.Errorf("invalid reference format: %s", ref)
|
||||
}
|
||||
|
||||
defKey := strings.TrimPrefix(ref, "#/$defs/")
|
||||
definitions, exists := rootSchema["$defs"].(map[string]interface{})
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("no definitions found in the schema: %s", rootSchema)
|
||||
}
|
||||
|
||||
def, exists := definitions[defKey].(map[string]interface{})
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("definition not found: %s %+v", defKey, definitions)
|
||||
}
|
||||
|
||||
return def, nil
|
||||
}
|
||||
|
||||
func (sc *JSONSchemaConverter) Grammar(schema map[string]interface{}, options ...func(*GrammarOption)) (string, error) {
|
||||
sc.addRule("freestring", PRIMITIVE_RULES["freestring"])
|
||||
_, err := sc.visit(schema, "", schema)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return sc.rules.ToGrammar(options...), nil
|
||||
}
|
||||
|
||||
func (sc *JSONSchemaConverter) GrammarFromBytes(b []byte, options ...func(*GrammarOption)) (string, error) {
|
||||
var schema map[string]interface{}
|
||||
err := json.Unmarshal(b, &schema)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return sc.Grammar(schema, options...)
|
||||
}
|
446
pkg/functions/grammars/json_schema_test.go
Normal file
446
pkg/functions/grammars/json_schema_test.go
Normal file
|
@ -0,0 +1,446 @@
|
|||
package grammars_test
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
. "github.com/mudler/LocalAI/pkg/functions"
|
||||
. "github.com/mudler/LocalAI/pkg/functions/grammars"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var testFunctions = []Item{
|
||||
{
|
||||
Type: "object",
|
||||
Properties: createFunction(
|
||||
"function",
|
||||
"arguments",
|
||||
"create_event",
|
||||
map[string]interface{}{
|
||||
"title": map[string]string{"type": "string"},
|
||||
"date": map[string]string{"type": "string"},
|
||||
"time": map[string]string{"type": "string"},
|
||||
},
|
||||
),
|
||||
},
|
||||
{
|
||||
Type: "object",
|
||||
Properties: createFunction(
|
||||
"function",
|
||||
"arguments",
|
||||
"search",
|
||||
map[string]interface{}{
|
||||
"query": map[string]string{"type": "string"},
|
||||
}),
|
||||
},
|
||||
}
|
||||
|
||||
var testFunctionsName = []Item{
|
||||
{
|
||||
Type: "object",
|
||||
Properties: createFunction(
|
||||
"name",
|
||||
"arguments",
|
||||
"create_event",
|
||||
map[string]interface{}{
|
||||
"title": map[string]string{"type": "string"},
|
||||
"date": map[string]string{"type": "string"},
|
||||
"time": map[string]string{"type": "string"},
|
||||
},
|
||||
),
|
||||
},
|
||||
{
|
||||
Type: "object",
|
||||
Properties: createFunction(
|
||||
"name",
|
||||
"arguments",
|
||||
"search",
|
||||
map[string]interface{}{
|
||||
"query": map[string]string{"type": "string"},
|
||||
}),
|
||||
},
|
||||
}
|
||||
|
||||
// rootResult returns the grammar text expected for testFunctionsName when the
// array and realvalue rules are present, with s spliced in as the body of the
// "root" rule. The specs compare it against the generated grammar line by
// line, so the order of the rules here is irrelevant.
func rootResult(s string) string {
|
||||
return `root-0-name ::= "\"create_event\""
|
||||
freestring ::= (
|
||||
[^"\\] |
|
||||
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
|
||||
)* space
|
||||
root-0 ::= "{" space "\"arguments\"" space ":" space root-0-arguments "," space "\"name\"" space ":" space root-0-name "}" space
|
||||
root-1-arguments ::= "{" space "\"query\"" space ":" space string "}" space
|
||||
realvalue ::= root-0 | root-1
|
||||
root ::= ` + s + `
|
||||
space ::= " "?
|
||||
root-0-arguments ::= "{" space "\"date\"" space ":" space string "," space "\"time\"" space ":" space string "," space "\"title\"" space ":" space string "}" space
|
||||
root-1 ::= "{" space "\"arguments\"" space ":" space root-1-arguments "," space "\"name\"" space ":" space root-1-name "}" space
|
||||
string ::= "\"" (
|
||||
[^"\\] |
|
||||
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
|
||||
)* "\"" space
|
||||
arr ::=
|
||||
"[\n" (
|
||||
realvalue
|
||||
(",\n" realvalue)*
|
||||
)? "]"
|
||||
root-1-name ::= "\"search\""`
|
||||
}
|
||||
|
||||
// Schema inputs and expected grammar texts shared by the specs below. The
// expected texts are compared with the generated grammar line by line.
const (
|
||||
// testInput1 is a JSON schema offering two tool-call objects through
// "oneOf"; the tool name is stored under the "function" key.
testInput1 = `
|
||||
{
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"function": {"const": "create_event"},
|
||||
"arguments": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"title": {"type": "string"},
|
||||
"date": {"type": "string"},
|
||||
"time": {"type": "string"}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"function": {"const": "search"},
|
||||
"arguments": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string"}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
// inputResult1 is the grammar expected for testInput1 (and for
// testFunctions) when no grammar option is set.
inputResult1 = `root-0-function ::= "\"create_event\""
|
||||
freestring ::= (
|
||||
[^"\\] |
|
||||
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
|
||||
)* space
|
||||
root-0 ::= "{" space "\"arguments\"" space ":" space root-0-arguments "," space "\"function\"" space ":" space root-0-function "}" space
|
||||
root-1-arguments ::= "{" space "\"query\"" space ":" space string "}" space
|
||||
root ::= root-0 | root-1
|
||||
space ::= " "?
|
||||
root-0-arguments ::= "{" space "\"date\"" space ":" space string "," space "\"time\"" space ":" space string "," space "\"title\"" space ":" space string "}" space
|
||||
root-1 ::= "{" space "\"arguments\"" space ":" space root-1-arguments "," space "\"function\"" space ":" space root-1-function "}" space
|
||||
string ::= "\"" (
|
||||
[^"\\] |
|
||||
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
|
||||
)* "\"" space
|
||||
root-1-function ::= "\"search\""`
|
||||
|
||||
// inputResult2 is the grammar expected for testFunctions when
// EnableMaybeArray is set; the specs append the mixedstring rule to it.
inputResult2 = `root-0-function ::= "\"create_event\""
|
||||
freestring ::= (
|
||||
[^"\\] |
|
||||
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
|
||||
)* space
|
||||
root-0 ::= "{" space "\"arguments\"" space ":" space root-0-arguments "," space "\"function\"" space ":" space root-0-function "}" space
|
||||
root-1-arguments ::= "{" space "\"query\"" space ":" space string "}" space
|
||||
realvalue ::= root-0 | root-1
|
||||
root ::= arr | realvalue
|
||||
space ::= " "?
|
||||
root-0-arguments ::= "{" space "\"date\"" space ":" space string "," space "\"time\"" space ":" space string "," space "\"title\"" space ":" space string "}" space
|
||||
root-1 ::= "{" space "\"arguments\"" space ":" space root-1-arguments "," space "\"function\"" space ":" space root-1-function "}" space
|
||||
string ::= "\"" (
|
||||
[^"\\] |
|
||||
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
|
||||
)* "\"" space
|
||||
arr ::=
|
||||
"[\n" (
|
||||
realvalue
|
||||
(",\n" realvalue)*
|
||||
)? "]"
|
||||
root-1-function ::= "\"search\""`
|
||||
|
||||
// testInput2 is testInput1 with the tool name stored under the "name" key
// instead of "function".
testInput2 = `
|
||||
{
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"const": "create_event"},
|
||||
"arguments": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"title": {"type": "string"},
|
||||
"date": {"type": "string"},
|
||||
"time": {"type": "string"}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"const": "search"},
|
||||
"arguments": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string"}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
// inputResult3 is the grammar expected for testInput2 when no grammar
// option is set.
inputResult3 = `root-0-name ::= "\"create_event\""
|
||||
freestring ::= (
|
||||
[^"\\] |
|
||||
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
|
||||
)* space
|
||||
root-0 ::= "{" space "\"arguments\"" space ":" space root-0-arguments "," space "\"name\"" space ":" space root-0-name "}" space
|
||||
root-1-arguments ::= "{" space "\"query\"" space ":" space string "}" space
|
||||
root ::= root-0 | root-1
|
||||
space ::= " "?
|
||||
root-0-arguments ::= "{" space "\"date\"" space ":" space string "," space "\"time\"" space ":" space string "," space "\"title\"" space ":" space string "}" space
|
||||
root-1 ::= "{" space "\"arguments\"" space ":" space root-1-arguments "," space "\"name\"" space ":" space root-1-name "}" space
|
||||
string ::= "\"" (
|
||||
[^"\\] |
|
||||
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
|
||||
)* "\"" space
|
||||
root-1-name ::= "\"search\""`
|
||||
|
||||
// inputResult4 is the grammar expected for testFunctionsName when
// EnableMaybeArray is set; the specs append the mixedstring rule to it.
inputResult4 = `root-0-name ::= "\"create_event\""
|
||||
freestring ::= (
|
||||
[^"\\] |
|
||||
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
|
||||
)* space
|
||||
root-0 ::= "{" space "\"arguments\"" space ":" space root-0-arguments "," space "\"name\"" space ":" space root-0-name "}" space
|
||||
root-1-arguments ::= "{" space "\"query\"" space ":" space string "}" space
|
||||
realvalue ::= root-0 | root-1
|
||||
root ::= arr | realvalue
|
||||
space ::= " "?
|
||||
root-0-arguments ::= "{" space "\"date\"" space ":" space string "," space "\"time\"" space ":" space string "," space "\"title\"" space ":" space string "}" space
|
||||
root-1 ::= "{" space "\"arguments\"" space ":" space root-1-arguments "," space "\"name\"" space ":" space root-1-name "}" space
|
||||
string ::= "\"" (
|
||||
[^"\\] |
|
||||
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
|
||||
)* "\"" space
|
||||
arr ::=
|
||||
"[\n" (
|
||||
realvalue
|
||||
(",\n" realvalue)*
|
||||
)? "]"
|
||||
root-1-name ::= "\"search\""`
|
||||
)
|
||||
|
||||
var _ = Describe("JSON schema grammar tests", func() {
|
||||
Context("JSON", func() {
|
||||
It("generates a valid grammar from JSON schema", func() {
|
||||
grammar, err := NewJSONSchemaConverter("").GrammarFromBytes([]byte(testInput1))
|
||||
Expect(err).To(BeNil())
|
||||
results := strings.Split(inputResult1, "\n")
|
||||
for _, r := range results {
|
||||
if r != "" {
|
||||
Expect(grammar).To(ContainSubstring(r))
|
||||
}
|
||||
}
|
||||
Expect(len(results)).To(Equal(len(strings.Split(grammar, "\n"))))
|
||||
})
|
||||
It("generates a valid grammar from JSON schema", func() {
|
||||
grammar, err := NewJSONSchemaConverter("").GrammarFromBytes([]byte(testInput2))
|
||||
Expect(err).To(BeNil())
|
||||
results := strings.Split(inputResult3, "\n")
|
||||
for _, r := range results {
|
||||
if r != "" {
|
||||
Expect(grammar).To(ContainSubstring(r))
|
||||
}
|
||||
}
|
||||
Expect(len(results)).To(Equal(len(strings.Split(grammar, "\n"))))
|
||||
})
|
||||
It("generates a valid grammar from JSON Objects", func() {
|
||||
|
||||
structuredGrammar := JSONFunctionStructure{
|
||||
OneOf: testFunctions}
|
||||
|
||||
grammar, err := structuredGrammar.Grammar()
|
||||
Expect(err).To(BeNil())
|
||||
results := strings.Split(inputResult1, "\n")
|
||||
for _, r := range results {
|
||||
if r != "" {
|
||||
Expect(grammar).To(ContainSubstring(r))
|
||||
}
|
||||
}
|
||||
Expect(len(results)).To(Equal(len(strings.Split(grammar, "\n"))))
|
||||
})
|
||||
|
||||
It("generates a valid grammar from JSON Objects for multiple function return", func() {
|
||||
structuredGrammar := JSONFunctionStructure{
|
||||
OneOf: testFunctions}
|
||||
|
||||
grammar, err := structuredGrammar.Grammar(EnableMaybeArray)
|
||||
Expect(err).To(BeNil())
|
||||
results := strings.Split(
|
||||
strings.Join([]string{
|
||||
inputResult2,
|
||||
"mixedstring ::= freestring | freestring arr | freestring realvalue"}, "\n"),
|
||||
"\n")
|
||||
for _, r := range results {
|
||||
if r != "" {
|
||||
Expect(grammar).To(ContainSubstring(r))
|
||||
}
|
||||
}
|
||||
Expect(len(results)).To(Equal(len(strings.Split(grammar, "\n"))), grammar)
|
||||
})
|
||||
|
||||
It("generates a valid grammar from JSON Objects for multiple function return", func() {
|
||||
structuredGrammar := JSONFunctionStructure{
|
||||
OneOf: testFunctionsName}
|
||||
|
||||
grammar, err := structuredGrammar.Grammar(EnableMaybeArray)
|
||||
Expect(err).To(BeNil())
|
||||
results := strings.Split(
|
||||
strings.Join([]string{
|
||||
inputResult4,
|
||||
"mixedstring ::= freestring | freestring arr | freestring realvalue"}, "\n"),
|
||||
"\n")
|
||||
for _, r := range results {
|
||||
if r != "" {
|
||||
Expect(grammar).To(ContainSubstring(r))
|
||||
}
|
||||
}
|
||||
Expect(len(results)).To(Equal(len(strings.Split(grammar, "\n"))), grammar)
|
||||
})
|
||||
|
||||
It("generates a valid grammar from JSON Objects for multiple function return with a suffix and array", func() {
|
||||
structuredGrammar := JSONFunctionStructure{
|
||||
OneOf: testFunctionsName}
|
||||
|
||||
grammar, err := structuredGrammar.Grammar(
|
||||
SetPrefix("suffix"),
|
||||
EnableMaybeArray,
|
||||
)
|
||||
Expect(err).To(BeNil())
|
||||
results := strings.Split(
|
||||
strings.Join([]string{
|
||||
rootResult(`"suffix" arr | realvalue`),
|
||||
"mixedstring ::= freestring | freestring arr | freestring realvalue"}, "\n"),
|
||||
"\n")
|
||||
for _, r := range results {
|
||||
if r != "" {
|
||||
Expect(grammar).To(ContainSubstring(r))
|
||||
}
|
||||
}
|
||||
Expect(len(results)).To(Equal(len(strings.Split(grammar, "\n"))), grammar)
|
||||
})
|
||||
It("generates a valid grammar from JSON Objects with a suffix", func() {
|
||||
structuredGrammar := JSONFunctionStructure{
|
||||
OneOf: testFunctionsName}
|
||||
|
||||
grammar, err := structuredGrammar.Grammar(SetPrefix("suffix"))
|
||||
Expect(err).To(BeNil())
|
||||
results := strings.Split(
|
||||
strings.Join([]string{
|
||||
rootResult(`"suffix" realvalue`),
|
||||
"mixedstring ::= freestring | freestring realvalue"}, "\n"),
|
||||
"\n")
|
||||
for _, r := range results {
|
||||
if r != "" {
|
||||
Expect(grammar).To(ContainSubstring(r))
|
||||
}
|
||||
}
|
||||
Expect(len(results)).To(Equal(len(strings.Split(grammar, "\n"))), grammar)
|
||||
})
|
||||
It("generates a valid grammar from JSON Objects with a suffix and could return string", func() {
|
||||
structuredGrammar := JSONFunctionStructure{
|
||||
OneOf: testFunctionsName}
|
||||
|
||||
grammar, err := structuredGrammar.Grammar(SetPrefix("suffix"), EnableMaybeString)
|
||||
Expect(err).To(BeNil())
|
||||
results := strings.Split(
|
||||
strings.Join([]string{
|
||||
rootResult(`( "suffix" realvalue | mixedstring )`),
|
||||
"mixedstring ::= freestring | freestring realvalue"}, "\n"),
|
||||
"\n")
|
||||
for _, r := range results {
|
||||
if r != "" {
|
||||
Expect(grammar).To(ContainSubstring(r))
|
||||
}
|
||||
}
|
||||
Expect(len(results)).To(Equal(len(strings.Split(grammar, "\n"))), grammar)
|
||||
})
|
||||
It("generates a valid grammar from JSON Objects with a suffix that could return text or an array of tools", func() {
|
||||
structuredGrammar := JSONFunctionStructure{
|
||||
OneOf: testFunctionsName}
|
||||
|
||||
grammar, err := structuredGrammar.Grammar(SetPrefix("suffix"), EnableMaybeString, EnableMaybeArray)
|
||||
Expect(err).To(BeNil())
|
||||
results := strings.Split(
|
||||
strings.Join([]string{
|
||||
rootResult(`( "suffix" (arr | realvalue) | mixedstring )`),
|
||||
"mixedstring ::= freestring | freestring arr | freestring realvalue"}, "\n"),
|
||||
"\n")
|
||||
|
||||
for _, r := range results {
|
||||
if r != "" {
|
||||
Expect(grammar).To(ContainSubstring(r))
|
||||
}
|
||||
}
|
||||
Expect(len(results)).To(Equal(len(strings.Split(grammar, "\n"))), grammar)
|
||||
})
|
||||
|
||||
It("generates a valid grammar from JSON Objects without a suffix that could return text or an array of tools or just string", func() {
|
||||
structuredGrammar := JSONFunctionStructure{
|
||||
OneOf: testFunctionsName}
|
||||
|
||||
grammar, err := structuredGrammar.Grammar(EnableMaybeString, EnableMaybeArray)
|
||||
Expect(err).To(BeNil())
|
||||
results := strings.Split(
|
||||
strings.Join([]string{
|
||||
rootResult(`mixedstring | arr | realvalue`),
|
||||
"mixedstring ::= freestring | freestring arr | freestring realvalue"}, "\n"),
|
||||
"\n")
|
||||
for _, r := range results {
|
||||
if r != "" {
|
||||
Expect(grammar).To(ContainSubstring(r))
|
||||
}
|
||||
}
|
||||
Expect(len(results)).To(Equal(len(strings.Split(grammar, "\n"))), grammar)
|
||||
})
|
||||
|
||||
It("generates a valid grammar from JSON Objects without a suffix that could return text or an array of tools or just string. Disables mixedstring", func() {
|
||||
structuredGrammar := JSONFunctionStructure{
|
||||
OneOf: testFunctionsName}
|
||||
|
||||
grammar, err := structuredGrammar.Grammar(EnableMaybeString, EnableMaybeArray, NoMixedFreeString)
|
||||
Expect(err).To(BeNil())
|
||||
results := strings.Split(
|
||||
strings.Join([]string{
|
||||
rootResult(`freestring | arr | realvalue`),
|
||||
"mixedstring ::= freestring | freestring arr | freestring realvalue"}, "\n"),
|
||||
"\n")
|
||||
for _, r := range results {
|
||||
if r != "" {
|
||||
Expect(grammar).To(ContainSubstring(r))
|
||||
}
|
||||
}
|
||||
Expect(len(results)).To(Equal(len(strings.Split(grammar, "\n"))), grammar)
|
||||
})
|
||||
|
||||
It("generates parallel tools without newlines in JSON", func() {
|
||||
structuredGrammar := JSONFunctionStructure{
|
||||
OneOf: testFunctionsName}
|
||||
content := `arr ::=
|
||||
"[" (
|
||||
realvalue
|
||||
("," realvalue)*
|
||||
)? "]"`
|
||||
grammar, err := structuredGrammar.Grammar(EnableMaybeString, EnableMaybeArray, DisableParallelNewLines)
|
||||
Expect(err).To(BeNil())
|
||||
results := strings.Split(content, "\n")
|
||||
for _, r := range results {
|
||||
if r != "" {
|
||||
Expect(grammar).To(ContainSubstring(r))
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
281
pkg/functions/grammars/llama31_schema.go
Normal file
281
pkg/functions/grammars/llama31_schema.go
Normal file
|
@ -0,0 +1,281 @@
|
|||
package grammars
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// LLama31SchemaConverter turns a JSON schema into grammar rules, wrapping
// top-level function objects in <function=...> ... </function> markers.
type LLama31SchemaConverter struct {
|
||||
// fnName is the property that holds the function name in a function
// object; NewLLama31SchemaConverter defaults it to "name".
fnName string
|
||||
// rules holds every rule generated so far, keyed by rule name.
rules Rules
|
||||
}
|
||||
|
||||
func NewLLama31SchemaConverter(fnName string) *LLama31SchemaConverter {
|
||||
rules := make(map[string]string)
|
||||
rules["space"] = SPACE_RULE
|
||||
if fnName == "" {
|
||||
fnName = "name"
|
||||
}
|
||||
|
||||
return &LLama31SchemaConverter{
|
||||
rules: rules,
|
||||
fnName: fnName,
|
||||
}
|
||||
}
|
||||
|
||||
// GRAMMAR_LITERAL_ESCAPESLlama gives the backslash-escaped replacement for
// each character matched by GRAMMAR_LITERAL_ESCAPE_RELlama. Unlike
// GRAMMAR_LITERAL_ESCAPES it has no entry for the double quote.
var GRAMMAR_LITERAL_ESCAPESLlama = map[string]string{
|
||||
"\r": `\r`,
|
||||
"\n": `\n`,
|
||||
}
|
||||
|
||||
// GRAMMAR_LITERAL_ESCAPE_RELlama matches the characters (CR and LF) that
// LLama31SchemaConverter.formatLiteral rewrites.
var GRAMMAR_LITERAL_ESCAPE_RELlama = regexp.MustCompile(`[\r\n]`)
|
||||
|
||||
func (sc *LLama31SchemaConverter) formatLiteral(literal interface{}) (string, error) {
|
||||
jLiteral, err := jsonString(literal)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
escaped := GRAMMAR_LITERAL_ESCAPE_RELlama.ReplaceAllStringFunc(jLiteral, func(match string) string {
|
||||
return GRAMMAR_LITERAL_ESCAPESLlama[match]
|
||||
})
|
||||
return escaped, nil
|
||||
}
|
||||
|
||||
func (sc *LLama31SchemaConverter) formatLiteralQuoted(literal interface{}) (string, error) {
|
||||
jLiteral, err := jsonString(literal)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
escaped := GRAMMAR_LITERAL_ESCAPE_RE.ReplaceAllStringFunc(jLiteral, func(match string) string {
|
||||
return GRAMMAR_LITERAL_ESCAPES[match]
|
||||
})
|
||||
return fmt.Sprintf(`"%s"`, escaped), nil
|
||||
}
|
||||
|
||||
func (sc *LLama31SchemaConverter) addRule(name, rule string) string {
|
||||
escName := INVALID_RULE_CHARS_RE.ReplaceAllString(name, "-")
|
||||
key := escName
|
||||
if existingRule, ok := sc.rules[escName]; ok && existingRule != rule {
|
||||
i := 0
|
||||
for {
|
||||
key = fmt.Sprintf("%s%d", escName, i)
|
||||
if _, ok := sc.rules[key]; !ok {
|
||||
break
|
||||
}
|
||||
i++
|
||||
}
|
||||
}
|
||||
sc.rules[key] = rule
|
||||
return key
|
||||
}
|
||||
|
||||
func (sc *LLama31SchemaConverter) visit(schema map[string]interface{}, name string, rootSchema map[string]interface{}) (string, error) {
|
||||
st, existType := schema["type"]
|
||||
var schemaType string
|
||||
if existType {
|
||||
schemaType = st.(string)
|
||||
}
|
||||
ruleName := name
|
||||
if name == "" {
|
||||
ruleName = "root"
|
||||
}
|
||||
_, oneOfExists := schema["oneOf"]
|
||||
_, anyOfExists := schema["anyOf"]
|
||||
if oneOfExists || anyOfExists {
|
||||
var alternatives []string
|
||||
oneOfSchemas, oneOfExists := schema["oneOf"].([]interface{})
|
||||
anyOfSchemas, anyOfExists := schema["anyOf"].([]interface{})
|
||||
|
||||
if oneOfExists {
|
||||
for i, altSchema := range oneOfSchemas {
|
||||
alternative, err := sc.visit(altSchema.(map[string]interface{}), fmt.Sprintf("%s-%d", ruleName, i), rootSchema)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
alternatives = append(alternatives, alternative)
|
||||
}
|
||||
} else if anyOfExists {
|
||||
for i, altSchema := range anyOfSchemas {
|
||||
alternative, err := sc.visit(altSchema.(map[string]interface{}), fmt.Sprintf("%s-%d", ruleName, i), rootSchema)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
alternatives = append(alternatives, alternative)
|
||||
}
|
||||
}
|
||||
|
||||
rule := strings.Join(alternatives, " | ")
|
||||
return sc.addRule(ruleName, rule), nil
|
||||
} else if ref, exists := schema["$ref"].(string); exists {
|
||||
referencedSchema, err := sc.resolveReference(ref, rootSchema)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return sc.visit(referencedSchema, name, rootSchema)
|
||||
} else if constVal, exists := schema["const"]; exists {
|
||||
|
||||
literal, err := sc.formatLiteral((constVal))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return sc.addRule(ruleName, literal), nil
|
||||
} else if enumVals, exists := schema["enum"].([]interface{}); exists {
|
||||
var enumRules []string
|
||||
for _, enumVal := range enumVals {
|
||||
enumRule, err := sc.formatLiteralQuoted(enumVal)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
enumRules = append(enumRules, enumRule)
|
||||
}
|
||||
rule := strings.Join(enumRules, " | ")
|
||||
return sc.addRule(ruleName, rule), nil
|
||||
} else if properties, exists := schema["properties"].(map[string]interface{}); schemaType == "object" && exists {
|
||||
baseProperty := false
|
||||
depth := strings.Split(name, "-")
|
||||
if len(depth) == 2 {
|
||||
baseProperty = true
|
||||
}
|
||||
type propData []struct {
|
||||
propName string
|
||||
propSchema map[string]interface{}
|
||||
}
|
||||
var propPairs propData
|
||||
|
||||
for propName, propSchema := range properties {
|
||||
propPairs = append(propPairs, struct {
|
||||
propName string
|
||||
propSchema map[string]interface{}
|
||||
}{propName: propName, propSchema: propSchema.(map[string]interface{})})
|
||||
}
|
||||
|
||||
sort.Slice(propPairs, func(i, j int) bool {
|
||||
return propPairs[i].propName < propPairs[j].propName
|
||||
})
|
||||
|
||||
var rule strings.Builder
|
||||
if baseProperty {
|
||||
rule.WriteString(`"<function="`)
|
||||
} else {
|
||||
rule.WriteString(`"{" space`)
|
||||
}
|
||||
|
||||
if baseProperty {
|
||||
|
||||
namePair := propData{}
|
||||
for i, propPair := range propPairs {
|
||||
propName := propPair.propName
|
||||
if propName == sc.fnName {
|
||||
namePair = append(namePair, propPair)
|
||||
// remove namePair from propPairs
|
||||
propPairs = append(propPairs[:i], propPairs[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
if len(namePair) == 0 {
|
||||
return "", fmt.Errorf("no function name found in the schema: %s", schema)
|
||||
}
|
||||
|
||||
propRuleName, err := sc.visit(namePair[0].propSchema, fmt.Sprintf("%s-%s", ruleName, sc.fnName), rootSchema)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
rule.WriteString(fmt.Sprintf(` %s ">{" `, propRuleName))
|
||||
|
||||
for _, propPair := range propPairs {
|
||||
propName := propPair.propName
|
||||
propSchema := propPair.propSchema
|
||||
propRuleName, err := sc.visit(propSchema, fmt.Sprintf("%s-%s", ruleName, propName), rootSchema)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
rule.WriteString(propRuleName)
|
||||
}
|
||||
|
||||
rule.WriteString(` "}</function>"`)
|
||||
|
||||
} else {
|
||||
for i, propPair := range propPairs {
|
||||
propName := propPair.propName
|
||||
propSchema := propPair.propSchema
|
||||
propRuleName, err := sc.visit(propSchema, fmt.Sprintf("%s-%s", ruleName, propName), rootSchema)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
lPropName, err := sc.formatLiteralQuoted(propName)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if i > 0 {
|
||||
rule.WriteString(` "," space`)
|
||||
}
|
||||
|
||||
rule.WriteString(fmt.Sprintf(` %s space ":" space %s`, lPropName, propRuleName))
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if !baseProperty {
|
||||
rule.WriteString(` "}" space`)
|
||||
}
|
||||
|
||||
return sc.addRule(ruleName, rule.String()), nil
|
||||
} else if items, exists := schema["items"].(map[string]interface{}); schemaType == "array" && exists {
|
||||
itemRuleName, err := sc.visit(items, fmt.Sprintf("%s-item", ruleName), rootSchema)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
rule := fmt.Sprintf(`"[" space (%s ("," space %s)*)? "]" space`, itemRuleName, itemRuleName)
|
||||
return sc.addRule(ruleName, rule), nil
|
||||
} else {
|
||||
primitiveRule, exists := PRIMITIVE_RULES[schemaType]
|
||||
if !exists {
|
||||
return "", fmt.Errorf("unrecognized schema: %v", schema)
|
||||
}
|
||||
if ruleName == "root" {
|
||||
schemaType = "root"
|
||||
}
|
||||
return sc.addRule(schemaType, primitiveRule), nil
|
||||
}
|
||||
}
|
||||
func (sc *LLama31SchemaConverter) resolveReference(ref string, rootSchema map[string]interface{}) (map[string]interface{}, error) {
|
||||
if !strings.HasPrefix(ref, "#/$defs/") {
|
||||
return nil, fmt.Errorf("invalid reference format: %s", ref)
|
||||
}
|
||||
|
||||
defKey := strings.TrimPrefix(ref, "#/$defs/")
|
||||
definitions, exists := rootSchema["$defs"].(map[string]interface{})
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("no definitions found in the schema: %s", rootSchema)
|
||||
}
|
||||
|
||||
def, exists := definitions[defKey].(map[string]interface{})
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("definition not found: %s %+v", defKey, definitions)
|
||||
}
|
||||
|
||||
return def, nil
|
||||
}
|
||||
|
||||
func (sc *LLama31SchemaConverter) Grammar(schema map[string]interface{}, options ...func(*GrammarOption)) (string, error) {
|
||||
sc.addRule("freestring", PRIMITIVE_RULES["freestring"])
|
||||
_, err := sc.visit(schema, "", schema)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return sc.rules.ToGrammar(options...), nil
|
||||
}
|
||||
|
||||
func (sc *LLama31SchemaConverter) GrammarFromBytes(b []byte, options ...func(*GrammarOption)) (string, error) {
|
||||
var schema map[string]interface{}
|
||||
err := json.Unmarshal(b, &schema)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return sc.Grammar(schema, options...)
|
||||
}
|
76
pkg/functions/grammars/llama31_schema_test.go
Normal file
76
pkg/functions/grammars/llama31_schema_test.go
Normal file
|
@ -0,0 +1,76 @@
|
|||
package grammars_test
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
. "github.com/mudler/LocalAI/pkg/functions/grammars"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// Fixtures for the Llama 3.1 schema converter test below.
const (
|
||||
// testllama31Input1 is a JSON schema with two "oneOf" alternatives. Each
// alternative is an object whose "function" property is a const function name
// ("create_event" / "search") and whose "arguments" property is an object of
// string properties.
testllama31Input1 = `
|
||||
{
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"function": {"const": "create_event"},
|
||||
"arguments": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"title": {"type": "string"},
|
||||
"date": {"type": "string"},
|
||||
"time": {"type": "string"}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"function": {"const": "search"},
|
||||
"arguments": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string"}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}`
|
||||
// testllama31inputResult1 is the grammar expected for testllama31Input1. The
// test matches it line by line (substring match plus an equal line count), so
// the order of the rules below is not significant.
//
// Shape of a call the grammar is meant to accept:
// <function=example_function_name>{{"example_name": "example_value"}}</function>
|
||||
testllama31inputResult1 = `root-0-function ::= "create_event"
|
||||
freestring ::= (
|
||||
[^"\\] |
|
||||
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
|
||||
)* space
|
||||
root-0 ::= "<function=" root-0-function ">{" root-0-arguments "}</function>"
|
||||
root-1-arguments ::= "{" space "\"query\"" space ":" space string "}" space
|
||||
root ::= root-0 | root-1
|
||||
space ::= " "?
|
||||
root-0-arguments ::= "{" space "\"date\"" space ":" space string "," space "\"time\"" space ":" space string "," space "\"title\"" space ":" space string "}" space
|
||||
root-1 ::= "<function=" root-1-function ">{" root-1-arguments "}</function>"
|
||||
string ::= "\"" (
|
||||
[^"\\] |
|
||||
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
|
||||
)* "\"" space
|
||||
root-1-function ::= "search"`
|
||||
)
|
||||
|
||||
var _ = Describe("JSON schema grammar tests", func() {
|
||||
Context("JSON", func() {
|
||||
It("generates a valid grammar from JSON schema", func() {
|
||||
grammar, err := NewLLama31SchemaConverter("function").GrammarFromBytes([]byte(testllama31Input1))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
results := strings.Split(testllama31inputResult1, "\n")
|
||||
for _, r := range results {
|
||||
if r != "" {
|
||||
Expect(grammar).To(ContainSubstring(r))
|
||||
}
|
||||
}
|
||||
Expect(len(results)).To(Equal(len(strings.Split(grammar, "\n"))))
|
||||
})
|
||||
})
|
||||
})
|
65
pkg/functions/grammars/options.go
Normal file
65
pkg/functions/grammars/options.go
Normal file
|
@ -0,0 +1,65 @@
|
|||
package grammars
|
||||
|
||||
// GrammarOption collects the settings that functional options (see
// EnableMaybeArray, SetPrefix, ...) write and that Rules.ToGrammar reads when
// rendering a grammar.
type GrammarOption struct {
|
||||
// PropOrder is set by SetPropOrder. It is not read by the rendering code in
// this file; presumably it controls property ordering in a schema converter —
// TODO confirm.
PropOrder string
|
||||
// Prefix, when non-empty, is emitted (with newlines escaped) as a quoted
// literal in front of the root value.
Prefix string
|
||||
// MaybeArray lets the root rule match "arr" in addition to the single value.
MaybeArray bool
|
||||
// DisableParallelNewLines makes ToGrammar append the "array" rule instead of
// the "arrayNewLines" rule.
DisableParallelNewLines bool
|
||||
// MaybeString lets the root rule also match a free-form string rule.
MaybeString bool
|
||||
// NoMixedFreeString selects "freestring" instead of "mixedstring" as that
// free-form string rule.
NoMixedFreeString bool
|
||||
// ExpectStringsAfterJSON selects the "mixedstring" definitions that allow a
// freestring after the value as well as before it.
ExpectStringsAfterJSON bool
|
||||
|
||||
// FunctionName is set by WithFunctionName; its consumer is outside this file.
FunctionName string
|
||||
// SchemaType is set by WithSchemaType; its consumer is outside this file.
SchemaType SchemaConverterType
|
||||
}
|
||||
|
||||
func (o *GrammarOption) Apply(options ...func(*GrammarOption)) {
|
||||
for _, l := range options {
|
||||
l(o)
|
||||
}
|
||||
}
|
||||
|
||||
var EnableMaybeArray = func(o *GrammarOption) {
|
||||
o.MaybeArray = true
|
||||
}
|
||||
|
||||
var DisableParallelNewLines = func(o *GrammarOption) {
|
||||
o.DisableParallelNewLines = true
|
||||
}
|
||||
|
||||
var EnableMaybeString = func(o *GrammarOption) {
|
||||
o.MaybeString = true
|
||||
}
|
||||
|
||||
var NoMixedFreeString func(*GrammarOption) = func(o *GrammarOption) {
|
||||
o.NoMixedFreeString = true
|
||||
}
|
||||
|
||||
// ExpectStringsAfterJSON enables mixed string suffix
|
||||
var ExpectStringsAfterJSON func(*GrammarOption) = func(o *GrammarOption) {
|
||||
o.ExpectStringsAfterJSON = true
|
||||
}
|
||||
|
||||
func SetPrefix(suffix string) func(*GrammarOption) {
|
||||
return func(o *GrammarOption) {
|
||||
o.Prefix = suffix
|
||||
}
|
||||
}
|
||||
|
||||
func SetPropOrder(order string) func(*GrammarOption) {
|
||||
return func(o *GrammarOption) {
|
||||
o.PropOrder = order
|
||||
}
|
||||
}
|
||||
|
||||
func WithSchemaType(schemaType SchemaConverterType) func(*GrammarOption) {
|
||||
return func(o *GrammarOption) {
|
||||
o.SchemaType = schemaType
|
||||
}
|
||||
}
|
||||
|
||||
func WithFunctionName(name string) func(*GrammarOption) {
|
||||
return func(o *GrammarOption) {
|
||||
o.FunctionName = name
|
||||
}
|
||||
}
|
93
pkg/functions/grammars/rules.go
Normal file
93
pkg/functions/grammars/rules.go
Normal file
|
@ -0,0 +1,93 @@
|
|||
package grammars
|
||||
|
||||
import (
	"fmt"
	"sort"
	"strings"

	"github.com/mudler/LocalAI/pkg/utils"
)
|
||||
|
||||
// Rules maps a grammar rule name to the text of its body. ToGrammar renders
// each entry as a "name ::= body" line.
type Rules map[string]string
|
||||
|
||||
func (rules Rules) ToGrammar(options ...func(*GrammarOption)) string {
|
||||
grammarOpts := &GrammarOption{}
|
||||
grammarOpts.Apply(options...)
|
||||
|
||||
prefix := grammarOpts.Prefix
|
||||
maybeArray := grammarOpts.MaybeArray
|
||||
disableParallelNewLines := grammarOpts.DisableParallelNewLines
|
||||
maybeString := grammarOpts.MaybeString
|
||||
noMixedFreeString := grammarOpts.NoMixedFreeString
|
||||
|
||||
var lines []string
|
||||
|
||||
swapRoot := maybeArray || maybeString || prefix != ""
|
||||
|
||||
// write down the computed rules.
|
||||
// if maybeArray is true, we need to add the array rule and slightly tweak the root rule
|
||||
for name, rule := range rules {
|
||||
if swapRoot && name == "root" {
|
||||
name = "realvalue"
|
||||
}
|
||||
lines = append(lines, fmt.Sprintf("%s ::= %s", name, rule))
|
||||
}
|
||||
|
||||
if !swapRoot {
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
newRoot := "realvalue"
|
||||
if maybeArray {
|
||||
newRoot = "arr | realvalue"
|
||||
}
|
||||
|
||||
freestringRule := "mixedstring"
|
||||
if noMixedFreeString {
|
||||
freestringRule = "freestring"
|
||||
}
|
||||
|
||||
if prefix != "" {
|
||||
// quote newlines in suffix
|
||||
prefix = utils.EscapeNewLines(prefix)
|
||||
|
||||
if maybeArray && maybeString {
|
||||
newRoot = "(" + newRoot + ")"
|
||||
}
|
||||
|
||||
if maybeString {
|
||||
//newRoot = "( (\"" + suffix + "\" " + newRoot + ") | freestring ) "
|
||||
newRoot = "( \"" + prefix + "\" " + newRoot + " | " + freestringRule + " ) "
|
||||
} else {
|
||||
newRoot = "\"" + prefix + "\" " + "" + newRoot + ""
|
||||
}
|
||||
} else if maybeString {
|
||||
if maybeArray {
|
||||
// newRoot = "(" + newRoot + ")"
|
||||
}
|
||||
|
||||
newRoot = freestringRule + " | " + newRoot
|
||||
}
|
||||
|
||||
lines = append(lines, fmt.Sprintf("%s ::= %s", "root", newRoot))
|
||||
if disableParallelNewLines {
|
||||
lines = append(lines, array)
|
||||
} else {
|
||||
lines = append(lines, arrayNewLines)
|
||||
}
|
||||
|
||||
if maybeArray {
|
||||
if grammarOpts.ExpectStringsAfterJSON {
|
||||
lines = append(lines, `mixedstring ::= freestring | freestring arr freestring | (freestring realvalue freestring)* | realvalue | arr`)
|
||||
} else {
|
||||
lines = append(lines, `mixedstring ::= freestring | freestring arr | freestring realvalue | realvalue | arr`)
|
||||
}
|
||||
} else {
|
||||
if grammarOpts.ExpectStringsAfterJSON {
|
||||
lines = append(lines, `mixedstring ::= freestring | (freestring realvalue freestring)* | realvalue`)
|
||||
} else {
|
||||
lines = append(lines, `mixedstring ::= freestring | freestring realvalue | realvalue`)
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
33
pkg/functions/grammars/types.go
Normal file
33
pkg/functions/grammars/types.go
Normal file
|
@ -0,0 +1,33 @@
|
|||
package grammars
|
||||
|
||||
// SchemaConverterType identifies a schema converter. Its String method and
// NewType translate between a value and its textual name.
type SchemaConverterType int
|
||||
|
||||
// Known schema converter kinds.
const (
|
||||
// JSONSchema is the zero value and the kind NewType falls back to for
// unrecognized names.
JSONSchema SchemaConverterType = iota
|
||||
// LLama31Schema is the kind named by LlamaType ("llama3.1").
LLama31Schema
|
||||
)
|
||||
|
||||
// Textual names of the schema converter kinds, as returned by
// SchemaConverterType.String and accepted by NewType.
const (
|
||||
// LlamaType is the name of LLama31Schema.
LlamaType string = "llama3.1"
|
||||
// JSONType is the name of JSONSchema.
JSONType string = "json"
|
||||
)
|
||||
|
||||
func (s SchemaConverterType) String() string {
|
||||
switch s {
|
||||
case JSONSchema:
|
||||
return JSONType
|
||||
case LLama31Schema:
|
||||
return LlamaType
|
||||
}
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
func NewType(t string) SchemaConverterType {
|
||||
switch t {
|
||||
case JSONType:
|
||||
return JSONSchema
|
||||
case LlamaType:
|
||||
return LLama31Schema
|
||||
}
|
||||
return JSONSchema
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue