Usage Features (#863)

Dave 2023-08-18 15:23:14 -04:00 committed by GitHub
parent 2bacd0180d
commit 8cb1061c11
40 changed files with 1222 additions and 317 deletions


@@ -15,7 +15,17 @@ import (
 	"github.com/go-skynet/LocalAI/pkg/utils"
 )
 
-func ModelInference(ctx context.Context, s string, loader *model.ModelLoader, c config.Config, o *options.Option, tokenCallback func(string) bool) (func() (string, error), error) {
+type LLMResponse struct {
+	Response string // should this be []byte?
+	Usage    TokenUsage
+}
+
+type TokenUsage struct {
+	Prompt     int
+	Completion int
+}
+
+func ModelInference(ctx context.Context, s string, loader *model.ModelLoader, c config.Config, o *options.Option, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) {
 	modelFile := c.Model
 	grpcOpts := gRPCModelOpts(c)
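
For orientation, here is a minimal, self-contained sketch of the new contract from a caller's point of view. The two structs are copied from the hunk above; runInference is a hypothetical stand-in for the func() (LLMResponse, error) closure that ModelInference now returns, so the loader/config/options plumbing is omitted.

package main

import "fmt"

// Copied from the diff above.
type TokenUsage struct {
	Prompt     int
	Completion int
}

type LLMResponse struct {
	Response string
	Usage    TokenUsage
}

func main() {
	// The streaming callback now receives a running TokenUsage alongside
	// each token; the bool return keeps the meaning it had under the old
	// func(string) bool signature.
	tokenCallback := func(token string, usage TokenUsage) bool {
		fmt.Printf("token=%q completion_tokens_so_far=%d\n", token, usage.Completion)
		return true
	}

	// Hypothetical stand-in for the closure ModelInference returns; a real
	// caller would obtain it from ModelInference(ctx, s, loader, c, o, tokenCallback)
	// and the callback would fire as the backend streams.
	runInference := func() (LLMResponse, error) {
		usage := TokenUsage{Prompt: 3}
		for _, tok := range []string{"Hello", ", ", "world"} {
			usage.Completion++
			tokenCallback(tok, usage)
		}
		return LLMResponse{Response: "Hello, world", Usage: usage}, nil
	}

	resp, err := runInference()
	if err != nil {
		panic(err)
	}
	fmt.Printf("reply=%q prompt=%d completion=%d\n",
		resp.Response, resp.Usage.Prompt, resp.Usage.Completion)
}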
@@ -70,40 +80,56 @@ func ModelInference(ctx context.Context, s string, loader *model.ModelLoader, c
 	}
 
 	// in GRPC, the backend is supposed to answer to 1 single token if stream is not supported
-	fn := func() (string, error) {
+	fn := func() (LLMResponse, error) {
 		opts := gRPCPredictOpts(c, loader.ModelPath)
 		opts.Prompt = s
+
+		tokenUsage := TokenUsage{}
+
+		// check the per-model feature flag for usage, since tokenCallback may have a cost, but default to on.
+		if !c.FeatureFlag["usage"] {
+			userTokenCallback := tokenCallback
+			if userTokenCallback == nil {
+				userTokenCallback = func(token string, usage TokenUsage) bool {
+					return true
+				}
+			}
+
+			promptInfo, pErr := inferenceModel.TokenizeString(ctx, opts)
+			if pErr == nil && promptInfo.Length > 0 {
+				tokenUsage.Prompt = int(promptInfo.Length)
+			}
+
+			tokenCallback = func(token string, usage TokenUsage) bool {
+				tokenUsage.Completion++
+				return userTokenCallback(token, tokenUsage)
+			}
+		}
+
 		if tokenCallback != nil {
 			ss := ""
 			err := inferenceModel.PredictStream(ctx, opts, func(s []byte) {
-				tokenCallback(string(s))
+				tokenCallback(string(s), tokenUsage)
 				ss += string(s)
 			})
-			return ss, err
+			return LLMResponse{
+				Response: ss,
+				Usage:    tokenUsage,
+			}, err
 		} else {
+			// TODO: Is the chicken bit the only way to get here? is that acceptable?
 			reply, err := inferenceModel.Predict(ctx, opts)
 			if err != nil {
-				return "", err
+				return LLMResponse{}, err
 			}
-			return string(reply.Message), err
+			return LLMResponse{
+				Response: string(reply.Message),
+				Usage:    tokenUsage,
+			}, err
 		}
 	}
 
-	return func() (string, error) {
-		// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
-		mutexMap.Lock()
-		l, ok := mutexes[modelFile]
-		if !ok {
-			m := &sync.Mutex{}
-			mutexes[modelFile] = m
-			l = m
-		}
-		mutexMap.Unlock()
-		l.Lock()
-		defer l.Unlock()
-		return fn()
-	}, nil
+	return fn, nil
 }
 
 var cutstrings map[string]*regexp.Regexp = make(map[string]*regexp.Regexp)
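
The core of the second hunk is the decoration of a possibly-nil tokenCallback so that every streamed token bumps a completion counter, while TokenizeString supplies the prompt count up front. Below is a self-contained sketch of that pattern, assuming nothing from LocalAI: wrapForUsage is a hypothetical helper (the commit inlines the same logic), the fixed Prompt: 3 stands in for inferenceModel.TokenizeString, and the string slice stands in for the chunks inferenceModel.PredictStream delivers.

package main

import "fmt"

type TokenUsage struct {
	Prompt     int
	Completion int
}

// wrapForUsage decorates a possibly-nil user callback so every streamed token
// bumps the shared completion count, mirroring the pattern in the hunk above.
func wrapForUsage(user func(string, TokenUsage) bool, usage *TokenUsage) func(string, TokenUsage) bool {
	if user == nil {
		// Substitute a no-op so the wrapper always has something to delegate to.
		user = func(string, TokenUsage) bool { return true }
	}
	return func(token string, _ TokenUsage) bool {
		usage.Completion++
		return user(token, *usage)
	}
}

func main() {
	tokenUsage := TokenUsage{Prompt: 3} // stand-in for inferenceModel.TokenizeString

	tokenCallback := wrapForUsage(nil, &tokenUsage)

	// Stand-in for the byte chunks inferenceModel.PredictStream delivers.
	ss := ""
	for _, tok := range []string{"Hel", "lo", "!"} {
		tokenCallback(tok, tokenUsage)
		ss += tok
	}

	fmt.Printf("response=%q usage=%+v\n", ss, tokenUsage) // usage={Prompt:3 Completion:3}
}

Note that the wrapper increments the captured tokenUsage rather than its usage parameter, so the user callback always observes the running total regardless of what value the stream handler passes in.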
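
Most of the commit's 40 changed files are not shown here, but the point of threading TokenUsage out of the backend is presumably to surface it in OpenAI-compatible API responses. Purely as an illustration of that mapping, and assuming that wiring, TokenUsage translates naturally into an OpenAI-style usage object (the json field names come from the OpenAI API spec, not from this diff):

package main

import (
	"encoding/json"
	"fmt"
)

type TokenUsage struct {
	Prompt     int
	Completion int
}

// OpenAI-style usage block; the json tags follow the OpenAI API spec,
// not anything shown in this hunk.
type OpenAIUsage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

func toOpenAIUsage(u TokenUsage) OpenAIUsage {
	return OpenAIUsage{
		PromptTokens:     u.Prompt,
		CompletionTokens: u.Completion,
		TotalTokens:      u.Prompt + u.Completion,
	}
}

func main() {
	b, _ := json.Marshal(toOpenAIUsage(TokenUsage{Prompt: 3, Completion: 3}))
	fmt.Println(string(b)) // {"prompt_tokens":3,"completion_tokens":3,"total_tokens":6}
}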