mirror of
https://github.com/mudler/LocalAI.git
synced 2025-05-27 22:15:00 +00:00
feat(template): read jinja templates from gguf files (#4332)
* Read jinja templates as fallback Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Move templating out of model loader Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Test TemplateMessages Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Set role and content from transformers Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Tests: be more flexible Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * More jinja Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Small refactoring and adaptations Signed-off-by: Ettore Di Giacinto <mudler@localai.io> --------- Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
parent
f5e1527a5a
commit
cea5a0ea42
23 changed files with 971 additions and 785 deletions
|
@ -11,59 +11,41 @@ import (
|
|||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
|
||||
"github.com/Masterminds/sprig/v3"
|
||||
|
||||
"github.com/nikolalohinski/gonja/v2"
|
||||
"github.com/nikolalohinski/gonja/v2/exec"
|
||||
)
|
||||
|
||||
// Keep this in sync with config.TemplateConfig. Is there a more idiomatic way to accomplish this in go?
|
||||
// Technically, order doesn't _really_ matter, but the count must stay in sync, see tests/integration/reflect_test.go
|
||||
type TemplateType int
|
||||
|
||||
type TemplateCache struct {
|
||||
mu sync.Mutex
|
||||
templatesPath string
|
||||
templates map[TemplateType]map[string]*template.Template
|
||||
type templateCache struct {
|
||||
mu sync.Mutex
|
||||
templatesPath string
|
||||
templates map[TemplateType]map[string]*template.Template
|
||||
jinjaTemplates map[TemplateType]map[string]*exec.Template
|
||||
}
|
||||
|
||||
func NewTemplateCache(templatesPath string) *TemplateCache {
|
||||
tc := &TemplateCache{
|
||||
templatesPath: templatesPath,
|
||||
templates: make(map[TemplateType]map[string]*template.Template),
|
||||
func newTemplateCache(templatesPath string) *templateCache {
|
||||
tc := &templateCache{
|
||||
templatesPath: templatesPath,
|
||||
templates: make(map[TemplateType]map[string]*template.Template),
|
||||
jinjaTemplates: make(map[TemplateType]map[string]*exec.Template),
|
||||
}
|
||||
return tc
|
||||
}
|
||||
|
||||
func (tc *TemplateCache) initializeTemplateMapKey(tt TemplateType) {
|
||||
func (tc *templateCache) initializeTemplateMapKey(tt TemplateType) {
|
||||
if _, ok := tc.templates[tt]; !ok {
|
||||
tc.templates[tt] = make(map[string]*template.Template)
|
||||
}
|
||||
}
|
||||
|
||||
func (tc *TemplateCache) EvaluateTemplate(templateType TemplateType, templateName string, in interface{}) (string, error) {
|
||||
tc.mu.Lock()
|
||||
defer tc.mu.Unlock()
|
||||
|
||||
tc.initializeTemplateMapKey(templateType)
|
||||
m, ok := tc.templates[templateType][templateName]
|
||||
if !ok {
|
||||
// return "", fmt.Errorf("template not loaded: %s", templateName)
|
||||
loadErr := tc.loadTemplateIfExists(templateType, templateName)
|
||||
if loadErr != nil {
|
||||
return "", loadErr
|
||||
}
|
||||
m = tc.templates[templateType][templateName] // ok is not important since we check m on the next line, and wealready checked
|
||||
}
|
||||
if m == nil {
|
||||
return "", fmt.Errorf("failed loading a template for %s", templateName)
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
|
||||
if err := m.Execute(&buf, in); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return buf.String(), nil
|
||||
func (tc *templateCache) existsInModelPath(s string) bool {
|
||||
return utils.ExistsInPath(tc.templatesPath, s)
|
||||
}
|
||||
|
||||
func (tc *TemplateCache) loadTemplateIfExists(templateType TemplateType, templateName string) error {
|
||||
func (tc *templateCache) loadTemplateIfExists(templateType TemplateType, templateName string) error {
|
||||
|
||||
// Check if the template was already loaded
|
||||
if _, ok := tc.templates[templateType][templateName]; ok {
|
||||
|
@ -82,6 +64,51 @@ func (tc *TemplateCache) loadTemplateIfExists(templateType TemplateType, templat
|
|||
return fmt.Errorf("template file outside path: %s", file)
|
||||
}
|
||||
|
||||
// can either be a file in the system or a string with the template
|
||||
if tc.existsInModelPath(modelTemplateFile) {
|
||||
d, err := os.ReadFile(file)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dat = string(d)
|
||||
} else {
|
||||
dat = templateName
|
||||
}
|
||||
|
||||
// Parse the template
|
||||
tmpl, err := template.New("prompt").Funcs(sprig.FuncMap()).Parse(dat)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tc.templates[templateType][templateName] = tmpl
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tc *templateCache) initializeJinjaTemplateMapKey(tt TemplateType) {
|
||||
if _, ok := tc.jinjaTemplates[tt]; !ok {
|
||||
tc.jinjaTemplates[tt] = make(map[string]*exec.Template)
|
||||
}
|
||||
}
|
||||
|
||||
func (tc *templateCache) loadJinjaTemplateIfExists(templateType TemplateType, templateName string) error {
|
||||
// Check if the template was already loaded
|
||||
if _, ok := tc.jinjaTemplates[templateType][templateName]; ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if the model path exists
|
||||
// skip any error here - we run anyway if a template does not exist
|
||||
modelTemplateFile := fmt.Sprintf("%s.tmpl", templateName)
|
||||
|
||||
dat := ""
|
||||
file := filepath.Join(tc.templatesPath, modelTemplateFile)
|
||||
|
||||
// Security check
|
||||
if err := utils.VerifyPath(modelTemplateFile, tc.templatesPath); err != nil {
|
||||
return fmt.Errorf("template file outside path: %s", file)
|
||||
}
|
||||
|
||||
// can either be a file in the system or a string with the template
|
||||
if utils.ExistsInPath(tc.templatesPath, modelTemplateFile) {
|
||||
d, err := os.ReadFile(file)
|
||||
|
@ -93,12 +120,65 @@ func (tc *TemplateCache) loadTemplateIfExists(templateType TemplateType, templat
|
|||
dat = templateName
|
||||
}
|
||||
|
||||
// Parse the template
|
||||
tmpl, err := template.New("prompt").Funcs(sprig.FuncMap()).Parse(dat)
|
||||
tmpl, err := gonja.FromString(dat)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tc.templates[templateType][templateName] = tmpl
|
||||
tc.jinjaTemplates[templateType][templateName] = tmpl
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tc *templateCache) evaluateJinjaTemplate(templateType TemplateType, templateNameOrContent string, in map[string]interface{}) (string, error) {
|
||||
tc.mu.Lock()
|
||||
defer tc.mu.Unlock()
|
||||
|
||||
tc.initializeJinjaTemplateMapKey(templateType)
|
||||
m, ok := tc.jinjaTemplates[templateType][templateNameOrContent]
|
||||
if !ok {
|
||||
// return "", fmt.Errorf("template not loaded: %s", templateName)
|
||||
loadErr := tc.loadJinjaTemplateIfExists(templateType, templateNameOrContent)
|
||||
if loadErr != nil {
|
||||
return "", loadErr
|
||||
}
|
||||
m = tc.jinjaTemplates[templateType][templateNameOrContent] // ok is not important since we check m on the next line, and wealready checked
|
||||
}
|
||||
if m == nil {
|
||||
return "", fmt.Errorf("failed loading a template for %s", templateNameOrContent)
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
|
||||
data := exec.NewContext(in)
|
||||
|
||||
if err := m.Execute(&buf, data); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
func (tc *templateCache) evaluateTemplate(templateType TemplateType, templateNameOrContent string, in interface{}) (string, error) {
|
||||
tc.mu.Lock()
|
||||
defer tc.mu.Unlock()
|
||||
|
||||
tc.initializeTemplateMapKey(templateType)
|
||||
m, ok := tc.templates[templateType][templateNameOrContent]
|
||||
if !ok {
|
||||
// return "", fmt.Errorf("template not loaded: %s", templateName)
|
||||
loadErr := tc.loadTemplateIfExists(templateType, templateNameOrContent)
|
||||
if loadErr != nil {
|
||||
return "", loadErr
|
||||
}
|
||||
m = tc.templates[templateType][templateNameOrContent] // ok is not important since we check m on the next line, and wealready checked
|
||||
}
|
||||
if m == nil {
|
||||
return "", fmt.Errorf("failed loading a template for %s", templateNameOrContent)
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
|
||||
if err := m.Execute(&buf, in); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue