diff --git a/api/api.go b/api/api.go
index e164d442..38708218 100644
--- a/api/api.go
+++ b/api/api.go
@@ -61,14 +61,14 @@ func App(configFile string, loader *model.ModelLoader, threads, ctxSize int, f16
 	app.Use(cors.New())
 
 	// openAI compatible API endpoint
-	app.Post("/v1/chat/completions", openAIEndpoint(cm, ChatEndpoint, debug, loader, threads, ctxSize, f16))
-	app.Post("/chat/completions", openAIEndpoint(cm, ChatEndpoint, debug, loader, threads, ctxSize, f16))
+	app.Post("/v1/chat/completions", chatEndpoint(cm, debug, loader, threads, ctxSize, f16))
+	app.Post("/chat/completions", chatEndpoint(cm, debug, loader, threads, ctxSize, f16))
 
-	app.Post("/v1/edits", openAIEndpoint(cm, EditEndpoint, debug, loader, threads, ctxSize, f16))
-	app.Post("/edits", openAIEndpoint(cm, EditEndpoint, debug, loader, threads, ctxSize, f16))
+	app.Post("/v1/edits", editEndpoint(cm, debug, loader, threads, ctxSize, f16))
+	app.Post("/edits", editEndpoint(cm, debug, loader, threads, ctxSize, f16))
 
-	app.Post("/v1/completions", openAIEndpoint(cm, CompletionEndpoint, debug, loader, threads, ctxSize, f16))
-	app.Post("/completions", openAIEndpoint(cm, CompletionEndpoint, debug, loader, threads, ctxSize, f16))
+	app.Post("/v1/completions", completionEndpoint(cm, debug, loader, threads, ctxSize, f16))
+	app.Post("/completions", completionEndpoint(cm, debug, loader, threads, ctxSize, f16))
 
 	app.Get("/v1/models", listModels(loader, cm))
 	app.Get("/models", listModels(loader, cm))
diff --git a/api/openai.go b/api/openai.go
index 4dda0edb..b2b37287 100644
--- a/api/openai.go
+++ b/api/openai.go
@@ -6,7 +6,6 @@ import (
 	"fmt"
 	"os"
 	"path/filepath"
-	"regexp"
 	"strings"
 
 	model "github.com/go-skynet/LocalAI/pkg/model"
@@ -146,19 +145,11 @@ func updateConfig(config *Config, input *OpenAIRequest) {
 	}
 }
 
-type EndpointType uint8
-
-const (
-	ChatEndpoint       EndpointType = iota
-	CompletionEndpoint EndpointType = iota
-	EditEndpoint       EndpointType = iota
-)
-
 func readConfig(cm ConfigMerger, c *fiber.Ctx, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*Config, *OpenAIRequest, error) {
 	input := new(OpenAIRequest)
 	// Get input data from the request body
 	if err := c.BodyParser(input); err != nil {
-		return err
+		return nil, nil, err
 	}
 
 	modelFile := input.Model
@@ -178,7 +169,7 @@ func readConfig(cm ConfigMerger, c *fiber.Ctx, loader *model.ModelLoader, debug
 			log.Debug().Msgf("No model specified, using: %s", modelFile)
 		} else {
 			log.Debug().Msgf("No model specified, returning error")
-			return nil, fmt.Errorf("no model specified")
+			return nil, nil, fmt.Errorf("no model specified")
 		}
 	}
 
@@ -192,7 +183,7 @@ func readConfig(cm ConfigMerger, c *fiber.Ctx, loader *model.ModelLoader, debug
 	modelConfig := filepath.Join(loader.ModelPath, modelFile+".yaml")
 	if _, err := os.Stat(modelConfig); err == nil {
 		if err := cm.LoadConfig(modelConfig); err != nil {
-			return nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error())
+			return nil, nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error())
 		}
 	}
 
@@ -226,6 +217,7 @@ func readConfig(cm ConfigMerger, c *fiber.Ctx, loader *model.ModelLoader, debug
 	return config, input, nil
 }
 
+// https://platform.openai.com/docs/api-reference/completions
 func completionEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error {
 	return func(c *fiber.Ctx) error {
 		config, input, err := readConfig(cm, c, loader, debug, threads, ctx, f16)
@@ -251,32 +243,13 @@ func completionEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader,
 			log.Debug().Msgf("Template found, input modified to: %s", predInput)
 		}
 
-		result := []Choice{}
-
-		n := input.N
-
-		if input.N == 0 {
-			n = 1
-		}
-
-		// get the model function to call for the result
-		predFunc, err := ModelInference(predInput, loader, *config)
+		result, err := ComputeChoices(predInput, input, config, loader, func(s string, c *[]Choice) {
+			*c = append(*c, Choice{Text: s})
+		})
 		if err != nil {
 			return err
 		}
 
-		for i := 0; i < n; i++ {
-			prediction, err := predFunc()
-			if err != nil {
-				return err
-			}
-
-			prediction = Finetune(*config, predInput, prediction)
-
-			result = append(result, Choice{Text: prediction})
-
-		}
-
 		resp := &OpenAIResponse{
 			Model:   input.Model, // we have to return what the user sent here, due to OpenAI spec.
 			Choices: result,
@@ -291,18 +264,29 @@ func completionEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader,
 	}
 }
 
-// https://platform.openai.com/docs/api-reference/completions
-func openAIEndpoint(cm ConfigMerger, endpointType EndpointType, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error {
+func chatEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error {
 	return func(c *fiber.Ctx) error {
-
 		config, input, err := readConfig(cm, c, loader, debug, threads, ctx, f16)
 		if err != nil {
 			return fmt.Errorf("failed reading parameters from request:%w", err)
 		}
 
-		chat := endpointType == ChatEndpoint
-		completion := endpointType == CompletionEndpoint
-		edit := endpointType == EditEndpoint
+		log.Debug().Msgf("Parameter Config: %+v", config)
+
+		predInput := input.Prompt
+
+		mess := []string{}
+		for _, i := range input.Messages {
+			r := config.Roles[i.Role]
+			if r == "" {
+				r = i.Role
+			}
+
+			content := fmt.Sprint(r, " ", i.Content)
+			mess = append(mess, content)
+		}
+
+		predInput = strings.Join(mess, "\n")
 
 		if input.Stream {
 			log.Debug().Msgf("Stream request received")
@@ -313,133 +297,42 @@ func openAIEndpoint(cm ConfigMerger, endpointType EndpointType, debug bool, load
 			c.Set("Transfer-Encoding", "chunked")
 		}
 
-		log.Debug().Msgf("Parameter Config: %+v", config)
-
-		predInput := input.Prompt
-		if chat {
-			mess := []string{}
-			for _, i := range input.Messages {
-				r := config.Roles[i.Role]
-				if r == "" {
-					r = i.Role
-				}
-
-				content := fmt.Sprint(r, " ", i.Content)
-				mess = append(mess, content)
-			}
-
-			predInput = strings.Join(mess, "\n")
-		}
-
 		templateFile := config.Model
 
-		switch {
-		case config.TemplateConfig.Chat != "" && chat:
+		if config.TemplateConfig.Chat != "" {
 			templateFile = config.TemplateConfig.Chat
-		case config.TemplateConfig.Completion != "" && completion:
-			templateFile = config.TemplateConfig.Completion
-		case config.TemplateConfig.Edit != "" && edit:
-			templateFile = config.TemplateConfig.Edit
 		}
 
-		if edit {
-			e := ""
-			if config.TemplateConfig.Edit == "" {
-				e = ".edit"
-			}
-			// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
-			templatedInput, err := loader.TemplatePrefix(templateFile+e, struct {
-				Input       string
-				Instruction string
-			}{Input: input.Input, Instruction: input.Instruction})
-			if err == nil {
-				predInput = templatedInput
-				log.Debug().Msgf("Template found, input modified to: %s", predInput)
-			}
-		} else {
-			// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
-			templatedInput, err := loader.TemplatePrefix(templateFile, struct {
-				Input string
-			}{Input: predInput})
-			if err == nil {
-				predInput = templatedInput
log.Debug().Msgf("Template found, input modified to: %s", predInput) - } + // A model can have a "file.bin.tmpl" file associated with a prompt template prefix + templatedInput, err := loader.TemplatePrefix(templateFile, struct { + Input string + }{Input: predInput}) + if err == nil { + predInput = templatedInput + log.Debug().Msgf("Template found, input modified to: %s", predInput) } - result := []Choice{} - - n := input.N - - if input.N == 0 { - n = 1 - } - - // get the model function to call for the result - predFunc, err := ModelInference(predInput, loader, *config) + result, err := ComputeChoices(predInput, input, config, loader, func(s string, c *[]Choice) { + if input.Stream { + *c = append(*c, Choice{Delta: &Message{Role: "assistant", Content: s}}) + } else { + *c = append(*c, Choice{Message: &Message{Role: "assistant", Content: s}}) + } + }) if err != nil { return err } - finetunePrediction := func(prediction string) string { - if config.Echo { - prediction = predInput + prediction - } - - for _, c := range config.Cutstrings { - mu.Lock() - reg, ok := cutstrings[c] - if !ok { - cutstrings[c] = regexp.MustCompile(c) - reg = cutstrings[c] - } - mu.Unlock() - prediction = reg.ReplaceAllString(prediction, "") - } - - for _, c := range config.TrimSpace { - prediction = strings.TrimSpace(strings.TrimPrefix(prediction, c)) - } - return prediction - } - - for i := 0; i < n; i++ { - prediction, err := predFunc() - if err != nil { - return err - } - - prediction = finetunePrediction(prediction) - - if chat { - if input.Stream { - result = append(result, Choice{Delta: &Message{Role: "assistant", Content: prediction}}) - } else { - result = append(result, Choice{Message: &Message{Role: "assistant", Content: prediction}}) - } - } else { - result = append(result, Choice{Text: prediction}) - } - } - resp := &OpenAIResponse{ Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. 
 			Choices: result,
+			Object:  "chat.completion",
 		}
 
-		if input.Stream && chat {
-			resp.Object = "chat.completion.chunk"
-		} else if chat {
-			resp.Object = "chat.completion"
-		} else if edit {
-			resp.Object = "edit"
-		} else {
-			resp.Object = "text_completion"
-		}
-
-		jsonResult, _ := json.Marshal(resp)
-		log.Debug().Msgf("Response: %s", jsonResult)
-
 		if input.Stream {
+			resp.Object = "chat.completion.chunk"
+			jsonResult, _ := json.Marshal(resp)
+			log.Debug().Msgf("Response: %s", jsonResult)
 			log.Debug().Msgf("Handling stream request")
 			c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
 				fmt.Fprintf(w, "event: data\n")
@@ -464,10 +357,57 @@ func openAIEndpoint(cm ConfigMerger, endpointType EndpointType, debug bool, load
 			//	w.Flush()
 			}))
 			return nil
-		} else {
-			// Return the prediction in the response body
-			return c.JSON(resp)
 		}
+
+		// Return the prediction in the response body
+		return c.JSON(resp)
+	}
+}
+
+func editEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error {
+	return func(c *fiber.Ctx) error {
+		config, input, err := readConfig(cm, c, loader, debug, threads, ctx, f16)
+		if err != nil {
+			return fmt.Errorf("failed reading parameters from request:%w", err)
+		}
+
+		log.Debug().Msgf("Parameter Config: %+v", config)
+
+		predInput := input.Input
+		templateFile := config.Model
+
+		if config.TemplateConfig.Edit != "" {
+			templateFile = config.TemplateConfig.Edit
+		}
+
+		// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
+		templatedInput, err := loader.TemplatePrefix(templateFile, struct {
+			Input       string
+			Instruction string
+		}{Input: predInput, Instruction: input.Instruction})
+		if err == nil {
+			predInput = templatedInput
+			log.Debug().Msgf("Template found, input modified to: %s", predInput)
+		}
+
+		result, err := ComputeChoices(predInput, input, config, loader, func(s string, c *[]Choice) {
+			*c = append(*c, Choice{Text: s})
+		})
+		if err != nil {
+			return err
+		}
+
+		resp := &OpenAIResponse{
+			Model:   input.Model, // we have to return what the user sent here, due to OpenAI spec.
+			Choices: result,
+			Object:  "edit",
+		}
+
+		jsonResult, _ := json.Marshal(resp)
+		log.Debug().Msgf("Response: %s", jsonResult)
+
+		// Return the prediction in the response body
+		return c.JSON(resp)
 	}
 }
diff --git a/api/prediction.go b/api/prediction.go
index 7d94d6cb..65cfce95 100644
--- a/api/prediction.go
+++ b/api/prediction.go
@@ -189,6 +189,36 @@ func ModelInference(s string, loader *model.ModelLoader, c Config) (func() (stri
 	}, nil
 }
 
+// ComputeChoices runs the model prediction input.N times (defaulting to one)
+// and lets cb map each finetuned prediction into a Choice.
+func ComputeChoices(predInput string, input *OpenAIRequest, config *Config, loader *model.ModelLoader, cb func(string, *[]Choice)) ([]Choice, error) {
+	result := []Choice{}
+
+	n := input.N
+
+	if input.N == 0 {
+		n = 1
+	}
+
+	// get the model function to call for the result
+	predFunc, err := ModelInference(predInput, loader, *config)
+	if err != nil {
+		return result, err
+	}
+
+	for i := 0; i < n; i++ {
+		prediction, err := predFunc()
+		if err != nil {
+			return result, err
+		}
+
+		prediction = Finetune(*config, predInput, prediction)
+		cb(prediction, &result)
+	}
+
+	return result, nil
+}
+
 var cutstrings map[string]*regexp.Regexp = make(map[string]*regexp.Regexp)
 var mu sync.Mutex = sync.Mutex{}
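
Review note: the diff above collapses three near-identical sampling loops into the single `ComputeChoices` helper in api/prediction.go; each endpoint now only supplies a callback that wraps a raw prediction string into the right `Choice` shape. The standalone sketch below illustrates that callback-accumulator pattern in isolation; the stub `Choice`/`Message` types and the `predict` function are placeholders for this example, not part of the PR (the real definitions live in api/openai.go and api/prediction.go):

```go
package main

import "fmt"

// Stub stand-ins for the api package's types (illustrative only).
type Message struct {
	Role    string
	Content string
}

type Choice struct {
	Text    string
	Message *Message
}

// computeChoices mirrors the shape of the new ComputeChoices helper: run the
// prediction n times (n == 0 means one) and let cb decide how each result is
// wrapped into a Choice.
func computeChoices(n int, predict func() (string, error), cb func(string, *[]Choice)) ([]Choice, error) {
	if n == 0 {
		n = 1
	}
	result := []Choice{}
	for i := 0; i < n; i++ {
		s, err := predict()
		if err != nil {
			return result, err
		}
		cb(s, &result)
	}
	return result, nil
}

func main() {
	predict := func() (string, error) { return "hello", nil } // hypothetical model call

	// Completion/edit style: plain-text choices.
	completions, _ := computeChoices(2, predict, func(s string, c *[]Choice) {
		*c = append(*c, Choice{Text: s})
	})

	// Chat style: wrap the prediction in an assistant message.
	chats, _ := computeChoices(1, predict, func(s string, c *[]Choice) {
		*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: s}})
	})

	fmt.Println(len(completions), chats[0].Message.Content)
}
```

Passing the result slice into the callback, rather than having it return a `Choice`, is what lets `chatEndpoint` emit either `Delta` or `Message` choices from the same shared loop.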
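Since only the handlers changed and the registered routes did not, existing clients should be unaffected. A minimal smoke test against a local instance, as a sketch: the address, model name, and JSON field names follow the OpenAI-compatible conventions these endpoints mirror and are assumptions, not taken from this diff:

```go
package main

import (
	"bytes"
	"fmt"
	"io"
	"net/http"
)

func main() {
	base := "http://localhost:8080" // assumed LocalAI address

	// One request body per refactored endpoint; "ggml-model.bin" is a placeholder.
	requests := map[string]string{
		"/v1/chat/completions": `{"model":"ggml-model.bin","messages":[{"role":"user","content":"hi"}]}`,
		"/v1/edits":            `{"model":"ggml-model.bin","instruction":"fix the spelling","input":"helo"}`,
		"/v1/completions":      `{"model":"ggml-model.bin","prompt":"Once upon a time"}`,
	}

	for path, body := range requests {
		resp, err := http.Post(base+path, "application/json", bytes.NewBufferString(body))
		if err != nil {
			fmt.Println(path, "error:", err)
			continue
		}
		out, _ := io.ReadAll(resp.Body)
		resp.Body.Close()
		fmt.Println(path, "->", string(out))
	}
}
```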