fix(completionEndpoint): don't remove existing functionality

This commit is contained in:
samm81 2023-06-02 13:52:18 -04:00
parent cd6a0b1ff8
commit 493828b9d7
No known key found for this signature in database

View file

@ -147,7 +147,7 @@ func completionEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
process := func(s string, req *OpenAIRequest, config *Config, loader *model.ModelLoader, responses chan OpenAIResponse) {
ComputeChoices(s, req, config, loader, func(s string, c *[]Choice) {}, func(s string) bool {
resp := OpenAIResponse{
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []Choice{{Text: s}},
Object: "text_completion",
}
@ -178,7 +178,7 @@ func completionEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
log.Debug().Msgf("Stream request received")
c.Context().SetContentType("text/event-stream")
//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
// c.Set("Content-Type", "text/event-stream")
//c.Set("Content-Type", "text/event-stream")
c.Set("Cache-Control", "no-cache")
c.Set("Connection", "keep-alive")
c.Set("Transfer-Encoding", "chunked")
@ -190,22 +190,22 @@ func completionEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
templateFile = config.TemplateConfig.Completion
}
predInput := config.PromptStrings[0]
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
templatedInput, err := o.loader.TemplatePrefix(templateFile, struct {
Input string
}{Input: predInput})
if err == nil {
predInput = templatedInput
log.Debug().Msgf("Template found, input modified to: %s", predInput)
}
if input.Stream {
if (len(config.PromptStrings) > 1) {
return errors.New("cannot handle more than 1 `PromptStrings` when `Stream`ing")
}
predInput := config.PromptStrings[0]
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
templatedInput, err := o.loader.TemplatePrefix(templateFile, struct {
Input string
}{Input: predInput})
if err == nil {
predInput = templatedInput
log.Debug().Msgf("Template found, input modified to: %s", predInput)
}
responses := make(chan OpenAIResponse)
go process(predInput, input, config, o.loader, responses)
@ -223,7 +223,7 @@ func completionEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
}
resp := &OpenAIResponse{
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []Choice{{FinishReason: "stop"}},
}
respData, _ := json.Marshal(resp)
@ -235,11 +235,25 @@ func completionEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
return nil
}
result, err := ComputeChoices(predInput, input, config, o.loader, func(s string, c *[]Choice) {
var result []Choice
for _, i := range config.PromptStrings {
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
templatedInput, err := o.loader.TemplatePrefix(templateFile, struct {
Input string
}{Input: i})
if err == nil {
i = templatedInput
log.Debug().Msgf("Template found, input modified to: %s", i)
}
r, err := ComputeChoices(i, input, config, o.loader, func(s string, c *[]Choice) {
*c = append(*c, Choice{Text: s})
}, nil)
if err != nil {
return err
}, nil)
if err != nil {
return err
}
result = append(result, r...)
}
resp := &OpenAIResponse{