mirror of
https://github.com/mudler/LocalAI.git
synced 2025-05-20 10:35:01 +00:00
feat: cancel stream generation if client disappears (#792)
This commit is contained in:
parent
72e3e236de
commit
12fe0932c4
12 changed files with 37 additions and 21 deletions
|
@ -1,6 +1,7 @@
|
|||
package backend
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
@ -14,7 +15,7 @@ import (
|
|||
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||
)
|
||||
|
||||
func ModelInference(s string, loader *model.ModelLoader, c config.Config, o *options.Option, tokenCallback func(string) bool) (func() (string, error), error) {
|
||||
func ModelInference(ctx context.Context, s string, loader *model.ModelLoader, c config.Config, o *options.Option, tokenCallback func(string) bool) (func() (string, error), error) {
|
||||
modelFile := c.Model
|
||||
|
||||
grpcOpts := gRPCModelOpts(c)
|
||||
|
@ -66,13 +67,13 @@ func ModelInference(s string, loader *model.ModelLoader, c config.Config, o *opt
|
|||
opts.Prompt = s
|
||||
if tokenCallback != nil {
|
||||
ss := ""
|
||||
err := inferenceModel.PredictStream(o.Context, opts, func(s string) {
|
||||
err := inferenceModel.PredictStream(ctx, opts, func(s string) {
|
||||
tokenCallback(s)
|
||||
ss += s
|
||||
})
|
||||
return ss, err
|
||||
} else {
|
||||
reply, err := inferenceModel.Predict(o.Context, opts)
|
||||
reply, err := inferenceModel.Predict(ctx, opts)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue