feat: add /edit endpoint (#119)

This commit is contained in:
Ettore Di Giacinto 2023-04-29 09:22:09 +02:00 committed by GitHub
parent d0ceebc5d7
commit 52f4d993c1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 296 additions and 199 deletions

View file

@ -2,6 +2,8 @@ package api
import (
"fmt"
"regexp"
"strings"
"sync"
model "github.com/go-skynet/LocalAI/pkg/model"
@ -186,3 +188,59 @@ func ModelInference(s string, loader *model.ModelLoader, c Config) (func() (stri
return fn()
}, nil
}
func ComputeChoices(predInput string, input *OpenAIRequest, config *Config, loader *model.ModelLoader, cb func(string, *[]Choice)) ([]Choice, error) {
result := []Choice{}
n := input.N
if input.N == 0 {
n = 1
}
// get the model function to call for the result
predFunc, err := ModelInference(predInput, loader, *config)
if err != nil {
return result, err
}
for i := 0; i < n; i++ {
prediction, err := predFunc()
if err != nil {
return result, err
}
prediction = Finetune(*config, predInput, prediction)
cb(prediction, &result)
//result = append(result, Choice{Text: prediction})
}
return result, err
}
var cutstrings map[string]*regexp.Regexp = make(map[string]*regexp.Regexp)
var mu sync.Mutex = sync.Mutex{}
func Finetune(config Config, input, prediction string) string {
if config.Echo {
prediction = input + prediction
}
for _, c := range config.Cutstrings {
mu.Lock()
reg, ok := cutstrings[c]
if !ok {
cutstrings[c] = regexp.MustCompile(c)
reg = cutstrings[c]
}
mu.Unlock()
prediction = reg.ReplaceAllString(prediction, "")
}
for _, c := range config.TrimSpace {
prediction = strings.TrimSpace(strings.TrimPrefix(prediction, c))
}
return prediction
}