From 493828b9d744b840bf0f2bfaa973d975e151f0ba Mon Sep 17 00:00:00 2001
From: samm81
Date: Fri, 2 Jun 2023 13:52:18 -0400
Subject: [PATCH] fix(completionEndpoint): don't remove existing functionality

---
 api/openai.go | 50 ++++++++++++++++++++++++++++++++------------------
 1 file changed, 32 insertions(+), 18 deletions(-)

diff --git a/api/openai.go b/api/openai.go
index 73881e4c..cb935101 100644
--- a/api/openai.go
+++ b/api/openai.go
@@ -147,7 +147,7 @@ func completionEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
 	process := func(s string, req *OpenAIRequest, config *Config, loader *model.ModelLoader, responses chan OpenAIResponse) {
 		ComputeChoices(s, req, config, loader, func(s string, c *[]Choice) {}, func(s string) bool {
 			resp := OpenAIResponse{
-				Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
+				Model:   req.Model, // we have to return what the user sent here, due to OpenAI spec.
 				Choices: []Choice{{Text: s}},
 				Object:  "text_completion",
 			}
@@ -178,7 +178,7 @@ func completionEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
 		log.Debug().Msgf("Stream request received")
 		c.Context().SetContentType("text/event-stream")
 		//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
-		// c.Set("Content-Type", "text/event-stream")
+		//c.Set("Content-Type", "text/event-stream")
 		c.Set("Cache-Control", "no-cache")
 		c.Set("Connection", "keep-alive")
 		c.Set("Transfer-Encoding", "chunked")
@@ -190,22 +190,22 @@ func completionEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
 		templateFile = config.TemplateConfig.Completion
 	}

-	predInput := config.PromptStrings[0]
-
-	// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
-	templatedInput, err := o.loader.TemplatePrefix(templateFile, struct {
-		Input string
-	}{Input: predInput})
-	if err == nil {
-		predInput = templatedInput
-		log.Debug().Msgf("Template found, input modified to: %s", predInput)
-	}
-
 	if input.Stream {
 		if (len(config.PromptStrings) > 1) {
 			return errors.New("cannot handle more than 1 `PromptStrings` when `Stream`ing")
 		}

+		predInput := config.PromptStrings[0]
+
+		// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
+		templatedInput, err := o.loader.TemplatePrefix(templateFile, struct {
+			Input string
+		}{Input: predInput})
+		if err == nil {
+			predInput = templatedInput
+			log.Debug().Msgf("Template found, input modified to: %s", predInput)
+		}
+
 		responses := make(chan OpenAIResponse)

 		go process(predInput, input, config, o.loader, responses)
@@ -223,7 +223,7 @@ func completionEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
 			}

 			resp := &OpenAIResponse{
-				Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
+				Model:   input.Model, // we have to return what the user sent here, due to OpenAI spec.
 				Choices: []Choice{{FinishReason: "stop"}},
 			}
 			respData, _ := json.Marshal(resp)
@@ -235,11 +235,25 @@ func completionEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
 		return nil
 	}

-	result, err := ComputeChoices(predInput, input, config, o.loader, func(s string, c *[]Choice) {
+	var result []Choice
+	for _, i := range config.PromptStrings {
+		// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
+		templatedInput, err := o.loader.TemplatePrefix(templateFile, struct {
+			Input string
+		}{Input: i})
+		if err == nil {
+			i = templatedInput
+			log.Debug().Msgf("Template found, input modified to: %s", i)
+		}
+
+		r, err := ComputeChoices(i, input, config, o.loader, func(s string, c *[]Choice) {
 		*c = append(*c, Choice{Text: s})
-	}, nil)
-	if err != nil {
-		return err
+		}, nil)
+		if err != nil {
+			return err
+		}
+
+		result = append(result, r...)
 	}

 	resp := &OpenAIResponse{
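
Note: the non-streaming behavior this patch restores (templating and completing every entry in PromptStrings, rather than only the first) can be sketched as the standalone program below. templatePrefix and computeChoices are simplified stand-ins invented for illustration, not the real o.loader.TemplatePrefix and ComputeChoices from the LocalAI codebase.

package main

import (
	"fmt"
	"strings"
)

// Choice mirrors the simplified response choice used in the patch.
type Choice struct {
	Text string
}

// templatePrefix is a stand-in for o.loader.TemplatePrefix: it applies a
// prompt template when one exists and returns an error otherwise. The
// handler treats that error as "no template, keep the raw prompt".
func templatePrefix(templateFile, input string) (string, error) {
	if templateFile == "" {
		return "", fmt.Errorf("no template found")
	}
	return strings.ReplaceAll(templateFile, "{{.Input}}", input), nil
}

// computeChoices is a stand-in for ComputeChoices: one prompt in, a slice
// of completion choices out.
func computeChoices(prompt string) ([]Choice, error) {
	return []Choice{{Text: "completion for: " + prompt}}, nil
}

func main() {
	promptStrings := []string{"first prompt", "second prompt"}
	templateFile := "### Instruction:\n{{.Input}}\n### Response:"

	// The restored non-streaming path: template and complete *every*
	// prompt, accumulating all choices, instead of only PromptStrings[0].
	var result []Choice
	for _, i := range promptStrings {
		if templated, err := templatePrefix(templateFile, i); err == nil {
			i = templated
		}
		r, err := computeChoices(i)
		if err != nil {
			panic(err) // the real handler returns the error to fiber instead
		}
		result = append(result, r...)
	}
	fmt.Printf("collected %d choices\n", len(result))
}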