From 39c9e1f9554237a25ccfd321d228594445048f32 Mon Sep 17 00:00:00 2001
From: krishnaduttPanchagnula
Date: Tue, 30 May 2023 12:59:44 +0530
Subject: [PATCH] updated code to include streaming

---
 api/openai.go | 96 ++++++++++++++++++++++++++++++++++-----------------
 1 file changed, 65 insertions(+), 31 deletions(-)

diff --git a/api/openai.go b/api/openai.go
index 73e4fd81..d6073bff 100644
--- a/api/openai.go
+++ b/api/openai.go
@@ -157,12 +157,77 @@ func completionEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
 
 	log.Debug().Msgf("Parameter Config: %+v", config)
 
+	if input.Stream {
+		log.Debug().Msgf("Stream request received")
+		c.Context().SetContentType("text/event-stream")
+		//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
+		// c.Set("Content-Type", "text/event-stream")
+		c.Set("Cache-Control", "no-cache")
+		c.Set("Connection", "keep-alive")
+		c.Set("Transfer-Encoding", "chunked")
+	}
+
 	templateFile := config.Model
 
 	if config.TemplateConfig.Completion != "" {
 		templateFile = config.TemplateConfig.Completion
 	}
 
+	if input.Stream {
+		responses := make(chan OpenAIResponse)
+
+		go func() {
+			defer close(responses)
+			for _, i := range config.PromptStrings {
+				// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
+				templatedInput, err := o.loader.TemplatePrefix(templateFile, struct{ Input string }{Input: i})
+				if err == nil {
+					i = templatedInput
+					log.Debug().Msgf("Template found, input modified to: %s", i)
+				}
+
+				r, err := ComputeChoices(i, input, config, o.loader, nil, nil)
+				if err != nil {
+					log.Error().Err(err).Msg("Error computing choices")
+					continue
+				}
+
+				resp := OpenAIResponse{
+					Model:   input.Model,
+					Choices: r,
+					Object:  "text_completion",
+				}
+				responses <- resp
+
+			}
+		}()
+
+		c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
+			for ev := range responses {
+				var buf bytes.Buffer
+				enc := json.NewEncoder(&buf)
+				enc.Encode(ev)
+
+				log.Debug().Msgf("Sending chunk: %s", buf.String())
+				fmt.Fprintf(w, "data: %v\n", buf.String())
+				w.Flush()
+			}
+
+			resp := &OpenAIResponse{
+				Model:   input.Model, // we have to return what the user sent here, due to OpenAI spec.
+				Choices: []Choice{{FinishReason: "stop"}},
+			}
+			respData, _ := json.Marshal(resp)
+
+			w.WriteString(fmt.Sprintf("data: %s\n\n", respData))
+			w.WriteString("data: [DONE]\n\n")
+			w.Flush()
+
+		}))
+
+		return nil
+	}
+
 	var result []Choice
 	for _, i := range config.PromptStrings {
 		// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
@@ -184,37 +249,6 @@ func completionEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
 		result = append(result, r...)
 	}
 
-	if input.Stream {
-		responses := make(chan OpenAIResponse)
-
-		go func() {
-			defer close(responses)
-			for _, r := range result {
-				responses <- OpenAIResponse{
-					Model:   input.Model,
-					Choices: []Choice{r},
-					Object:  "text_completion",
-				}
-			}
-		}()
-
-		c.Context().SetContentType("text/event-stream")
-		c.Set("Cache-Control", "no-cache")
-		c.Set("Connection", "keep-alive")
-		c.Set("Transfer-Encoding", "chunked")
-
-		c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
-			for ev := range responses {
-				var buf bytes.Buffer
-				enc := json.NewEncoder(&buf)
-				enc.Encode(ev)
-
-				log.Debug().Msgf("Sending chunk: %s", buf.String())
-				fmt.Fprintf(w, "data: %v\n", buf.String())
-				w.Flush()
-			}
-		}))
-	}
 	resp := &OpenAIResponse{
 		Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
 		Choices: result,
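Review note: before this patch, the `input.Stream` branch sat *after* the `ComputeChoices` loop, so every choice was fully computed before the first SSE chunk left the server; the "stream" merely replayed finished results. The patch moves the branch ahead of the blocking loop: a producer goroutine pushes one `OpenAIResponse` onto a channel as soon as `ComputeChoices` returns for each prompt, a `fasthttp.StreamWriter` drains the channel into the response body, and the handler closes the stream with a `"stop"` chunk plus the `data: [DONE]` sentinel the OpenAI API expects. (Granularity is still one chunk per prompt, not per token, since `ComputeChoices` itself blocks.) The sketch below is a minimal, self-contained reduction of that producer/consumer pattern; the `/stream` route, the `chunk` type, and the fake token loop are illustrative stand-ins, not code from the patch.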
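```go
package main

import (
	"bufio"
	"encoding/json"
	"fmt"

	"github.com/gofiber/fiber/v2"
	"github.com/valyala/fasthttp"
)

// chunk mirrors the shape of the streamed response; the field names
// here are illustrative, not taken from the patch.
type chunk struct {
	Model  string `json:"model"`
	Text   string `json:"text"`
	Object string `json:"object"`
}

func main() {
	app := fiber.New()

	app.Post("/stream", func(c *fiber.Ctx) error {
		// Same header setup as the patch: declare an SSE response
		// before handing the body off to a stream writer.
		c.Context().SetContentType("text/event-stream")
		c.Set("Cache-Control", "no-cache")
		c.Set("Connection", "keep-alive")
		c.Set("Transfer-Encoding", "chunked")

		// Producer goroutine: emit results as they become available
		// and close the channel when done, as the patched
		// completionEndpoint does per prompt.
		results := make(chan chunk)
		go func() {
			defer close(results)
			for i := 0; i < 3; i++ {
				results <- chunk{
					Model:  "demo",
					Text:   fmt.Sprintf("token-%d", i),
					Object: "text_completion",
				}
			}
		}()

		// Consumer: drain the channel into the response body,
		// flushing after every event so chunks reach the client
		// immediately instead of sitting in the buffer.
		c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
			for ev := range results {
				data, _ := json.Marshal(ev)
				fmt.Fprintf(w, "data: %s\n\n", data)
				w.Flush()
			}
			// End-of-stream sentinel, matching the OpenAI wire format.
			w.WriteString("data: [DONE]\n\n")
			w.Flush()
		}))
		return nil
	})

	if err := app.Listen(":8080"); err != nil {
		panic(err)
	}
}
```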
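Two framing details in the patch are easy to misread. `json.NewEncoder(&buf).Encode(ev)` appends a trailing newline to the encoded JSON, so `fmt.Fprintf(w, "data: %v\n", buf.String())` still ends each event with the blank line SSE requires; the final stop message is built with `json.Marshal` (no trailing newline) and therefore spells out `\n\n` explicitly. And the `w.Flush()` after every event is what actually pushes each chunk onto the wire; without it, events would accumulate in the `bufio.Writer` until the buffer filled.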
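For completeness, a rough client-side sketch of consuming such a stream, line by line. It assumes the demo server above is listening on localhost:8080; it is not taken from the repository.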
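```go
package main

import (
	"bufio"
	"fmt"
	"net/http"
	"strings"
)

func main() {
	// Assumes the sketch server above is running locally.
	resp, err := http.Post("http://localhost:8080/stream", "application/json", nil)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	scanner := bufio.NewScanner(resp.Body)
	for scanner.Scan() {
		line := scanner.Text()
		if !strings.HasPrefix(line, "data: ") {
			continue // skip the blank separator lines between events
		}
		payload := strings.TrimPrefix(line, "data: ")
		if payload == "[DONE]" {
			break // server signalled end of stream
		}
		fmt.Println("chunk:", payload)
	}
}
```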