feat: cancel stream generation if client disappears (#792)

This commit is contained in:
Aman Gupta Karmani 2023-07-24 14:10:54 -07:00 committed by GitHub
parent 72e3e236de
commit 12fe0932c4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 37 additions and 21 deletions

View file

@ -28,7 +28,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
}
responses <- initialMessage
ComputeChoices(s, req.N, config, o, loader, func(s string, c *[]Choice) {}, func(s string) bool {
ComputeChoices(req, s, config, o, loader, func(s string, c *[]Choice) {}, func(s string) bool {
resp := OpenAIResponse{
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []Choice{{Delta: &Message{Content: &s}, Index: 0}},
@ -43,7 +43,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
return func(c *fiber.Ctx) error {
processFunctions := false
funcs := grammar.Functions{}
modelFile, input, err := readInput(c, o.Loader, true)
modelFile, input, err := readInput(c, o, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
@ -235,7 +235,12 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
enc.Encode(ev)
log.Debug().Msgf("Sending chunk: %s", buf.String())
fmt.Fprintf(w, "data: %v\n", buf.String())
_, err := fmt.Fprintf(w, "data: %v\n", buf.String())
if err != nil {
log.Debug().Msgf("Sending chunk failed: %v", err)
input.Cancel()
break
}
w.Flush()
}
@ -258,7 +263,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
return nil
}
result, err := ComputeChoices(predInput, input.N, config, o, o.Loader, func(s string, c *[]Choice) {
result, err := ComputeChoices(input, predInput, config, o, o.Loader, func(s string, c *[]Choice) {
if processFunctions {
// As we have to change the result before processing, we can't stream the answer (yet?)
ss := map[string]interface{}{}
@ -300,7 +305,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
// Otherwise ask the LLM to understand the JSON output and the context, and return a message
// Note: This costs (in term of CPU) another computation
config.Grammar = ""
predFunc, err := backend.ModelInference(predInput, o.Loader, *config, o, nil)
predFunc, err := backend.ModelInference(input.Context, predInput, o.Loader, *config, o, nil)
if err != nil {
log.Error().Msgf("inference error: %s", err.Error())
return