diff --git a/api/api.go b/api/api.go
index 85cbef21..e164d442 100644
--- a/api/api.go
+++ b/api/api.go
@@ -61,11 +61,14 @@ func App(configFile string, loader *model.ModelLoader, threads, ctxSize int, f16
 	app.Use(cors.New())
 
 	// openAI compatible API endpoint
-	app.Post("/v1/chat/completions", openAIEndpoint(cm, true, debug, loader, threads, ctxSize, f16))
-	app.Post("/chat/completions", openAIEndpoint(cm, true, debug, loader, threads, ctxSize, f16))
+	app.Post("/v1/chat/completions", openAIEndpoint(cm, ChatEndpoint, debug, loader, threads, ctxSize, f16))
+	app.Post("/chat/completions", openAIEndpoint(cm, ChatEndpoint, debug, loader, threads, ctxSize, f16))
 
-	app.Post("/v1/completions", openAIEndpoint(cm, false, debug, loader, threads, ctxSize, f16))
-	app.Post("/completions", openAIEndpoint(cm, false, debug, loader, threads, ctxSize, f16))
+	app.Post("/v1/edits", openAIEndpoint(cm, EditEndpoint, debug, loader, threads, ctxSize, f16))
+	app.Post("/edits", openAIEndpoint(cm, EditEndpoint, debug, loader, threads, ctxSize, f16))
+
+	app.Post("/v1/completions", openAIEndpoint(cm, CompletionEndpoint, debug, loader, threads, ctxSize, f16))
+	app.Post("/completions", openAIEndpoint(cm, CompletionEndpoint, debug, loader, threads, ctxSize, f16))
 
 	app.Get("/v1/models", listModels(loader, cm))
 	app.Get("/models", listModels(loader, cm))
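Aside (not part of the patch): a minimal client sketch for the new edit route. The host, port, and model name are assumptions for illustration; the request fields mirror the `OpenAIRequest` additions further down, and the response fields follow the OpenAI shape the handler emits.

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

func main() {
	// Request body uses the new "instruction"/"input" fields added below.
	body, _ := json.Marshal(map[string]string{
		"model":       "gpt4all-j", // hypothetical model name
		"instruction": "Fix the spelling mistakes",
		"input":       "What day of the wek is it?",
	})

	resp, err := http.Post("http://localhost:8080/v1/edits", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	var out struct {
		Object  string `json:"object"` // "edit" for this endpoint
		Choices []struct {
			Text string `json:"text"`
		} `json:"choices"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		panic(err)
	}
	fmt.Println(out.Object, out.Choices)
}
```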
diff --git a/api/config.go b/api/config.go
index 848f25ca..ea4b335b 100644
--- a/api/config.go
+++ b/api/config.go
@@ -27,6 +27,7 @@ type Config struct {
 type TemplateConfig struct {
 	Completion string `yaml:"completion"`
 	Chat       string `yaml:"chat"`
+	Edit       string `yaml:"edit"`
 }
 
 type ConfigMerger map[string]Config
diff --git a/api/openai.go b/api/openai.go
index 3cb9b599..4dda0edb 100644
--- a/api/openai.go
+++ b/api/openai.go
@@ -8,7 +8,6 @@ import (
 	"path/filepath"
 	"regexp"
 	"strings"
-	"sync"
 
 	model "github.com/go-skynet/LocalAI/pkg/model"
 	"github.com/gofiber/fiber/v2"
@@ -60,6 +59,10 @@ type OpenAIRequest struct {
 	// Prompt is read only by completion API calls
 	Prompt string `json:"prompt" yaml:"prompt"`
 
+	// Edit endpoint
+	Instruction string `json:"instruction" yaml:"instruction"`
+	Input       string `json:"input" yaml:"input"`
+
 	Stop string `json:"stop" yaml:"stop"`
 
 	// Messages is read only by chat/completion API calls
@@ -143,19 +146,164 @@ func updateConfig(config *Config, input *OpenAIRequest) {
 	}
 }
 
-var cutstrings map[string]*regexp.Regexp = make(map[string]*regexp.Regexp)
-var mu sync.Mutex = sync.Mutex{}
+type EndpointType uint8
 
-// https://platform.openai.com/docs/api-reference/completions
-func openAIEndpoint(cm ConfigMerger, chat, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error {
+const (
+	ChatEndpoint       EndpointType = iota
+	CompletionEndpoint EndpointType = iota
+	EditEndpoint       EndpointType = iota
+)
+
+func readConfig(cm ConfigMerger, c *fiber.Ctx, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*Config, *OpenAIRequest, error) {
+	input := new(OpenAIRequest)
+	// Get input data from the request body
+	if err := c.BodyParser(input); err != nil {
+		return nil, nil, err
+	}
+
+	modelFile := input.Model
+	received, _ := json.Marshal(input)
+
+	log.Debug().Msgf("Request received: %s", string(received))
+
+	// Set model from bearer token, if available
+	bearer := strings.TrimLeft(c.Get("authorization"), "Bearer ")
+	bearerExists := bearer != "" && loader.ExistsInModelPath(bearer)
+
+	// If no model was specified, take the first available
+	if modelFile == "" && !bearerExists {
+		models, _ := loader.ListModels()
+		if len(models) > 0 {
+			modelFile = models[0]
+			log.Debug().Msgf("No model specified, using: %s", modelFile)
+		} else {
+			log.Debug().Msgf("No model specified, returning error")
+			return nil, nil, fmt.Errorf("no model specified")
+		}
+	}
+
+	// A model found in the bearer token takes precedence
+	if bearerExists {
+		log.Debug().Msgf("Using model from bearer token: %s", bearer)
+		modelFile = bearer
+	}
+
+	// Load a config file if present after the model name
+	modelConfig := filepath.Join(loader.ModelPath, modelFile+".yaml")
+	if _, err := os.Stat(modelConfig); err == nil {
+		if err := cm.LoadConfig(modelConfig); err != nil {
+			return nil, nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error())
+		}
+	}
+
+	var config *Config
+	cfg, exists := cm[modelFile]
+	if !exists {
+		config = &Config{
+			OpenAIRequest: defaultRequest(modelFile),
+		}
+	} else {
+		config = &cfg
+	}
+
+	// Set the parameters for the language model prediction
+	updateConfig(config, input)
+
+	if threads != 0 {
+		config.Threads = threads
+	}
+	if ctx != 0 {
+		config.ContextSize = ctx
+	}
+	if f16 {
+		config.F16 = true
+	}
+
+	if debug {
+		config.Debug = true
+	}
+
+	return config, input, nil
+}
+
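Aside (not part of the patch): the model-resolution order in `readConfig` is easy to miss in the branching above. The sketch below restates the same precedence in isolation, with hypothetical inputs: a bearer token that names a model present in the model path wins, then the request's `model` field, then the first model found on disk.

```go
package main

import "fmt"

// resolveModel mirrors readConfig's selection logic in isolation.
// bearerOnDisk stands in for loader.ExistsInModelPath(bearer).
func resolveModel(bearer string, bearerOnDisk bool, requested string, available []string) (string, error) {
	bearerOK := bearer != "" && bearerOnDisk
	if requested == "" && !bearerOK {
		if len(available) == 0 {
			return "", fmt.Errorf("no model specified")
		}
		requested = available[0] // fall back to the first model found
	}
	if bearerOK {
		return bearer, nil // bearer token overrides the request body
	}
	return requested, nil
}

func main() {
	m, _ := resolveModel("", false, "", []string{"gpt4all-j.bin"})
	fmt.Println(m) // gpt4all-j.bin
}
```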
+func completionEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error {
 	return func(c *fiber.Ctx) error {
+		config, input, err := readConfig(cm, c, loader, debug, threads, ctx, f16)
+		if err != nil {
+			return fmt.Errorf("failed reading parameters from request: %w", err)
+		}
 
-		input := new(OpenAIRequest)
-		// Get input data from the request body
-		if err := c.BodyParser(input); err != nil {
+		log.Debug().Msgf("Parameter Config: %+v", config)
+
+		predInput := input.Prompt
+		templateFile := config.Model
+
+		if config.TemplateConfig.Completion != "" {
+			templateFile = config.TemplateConfig.Completion
+		}
+
+		// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
+		templatedInput, err := loader.TemplatePrefix(templateFile, struct {
+			Input string
+		}{Input: predInput})
+		if err == nil {
+			predInput = templatedInput
+			log.Debug().Msgf("Template found, input modified to: %s", predInput)
+		}
+
+		result := []Choice{}
+
+		n := input.N
+
+		if input.N == 0 {
+			n = 1
+		}
+
+		// get the model function to call for the result
+		predFunc, err := ModelInference(predInput, loader, *config)
+		if err != nil {
 			return err
 		}
 
+		for i := 0; i < n; i++ {
+			prediction, err := predFunc()
+			if err != nil {
+				return err
+			}
+
+			prediction = Finetune(*config, predInput, prediction)
+
+			result = append(result, Choice{Text: prediction})
+
+		}
+
+		resp := &OpenAIResponse{
+			Model:   input.Model, // we have to return what the user sent here, due to OpenAI spec.
+			Choices: result,
+			Object:  "text_completion",
+		}
+
+		jsonResult, _ := json.Marshal(resp)
+		log.Debug().Msgf("Response: %s", jsonResult)
+
+		// Return the prediction in the response body
+		return c.JSON(resp)
+	}
+}
+
+// https://platform.openai.com/docs/api-reference/completions
+func openAIEndpoint(cm ConfigMerger, endpointType EndpointType, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error {
+	return func(c *fiber.Ctx) error {
+
+		config, input, err := readConfig(cm, c, loader, debug, threads, ctx, f16)
+		if err != nil {
+			return fmt.Errorf("failed reading parameters from request: %w", err)
+		}
+
+		chat := endpointType == ChatEndpoint
+		completion := endpointType == CompletionEndpoint
+		edit := endpointType == EditEndpoint
+
 		if input.Stream {
 			log.Debug().Msgf("Stream request received")
 			//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
@@ -165,68 +313,6 @@ func openAIEndpoint(cm ConfigMerger, chat, debug bool, loader *model.ModelLoader
 			c.Set("Transfer-Encoding", "chunked")
 		}
 
-		modelFile := input.Model
-		received, _ := json.Marshal(input)
-
-		log.Debug().Msgf("Request received: %s", string(received))
-
-		// Set model from bearer token, if available
-		bearer := strings.TrimLeft(c.Get("authorization"), "Bearer ")
-		bearerExists := bearer != "" && loader.ExistsInModelPath(bearer)
-
-		// If no model was specified, take the first available
-		if modelFile == "" && !bearerExists {
-			models, _ := loader.ListModels()
-			if len(models) > 0 {
-				modelFile = models[0]
-				log.Debug().Msgf("No model specified, using: %s", modelFile)
-			} else {
-				log.Debug().Msgf("No model specified, returning error")
-				return fmt.Errorf("no model specified")
-			}
-		}
-
-		// If a model is found in bearer token takes precedence
-		if bearerExists {
-			log.Debug().Msgf("Using model from bearer token: %s", bearer)
-			modelFile = bearer
-		}
-
-		// Load a config file if present after the model name
-		modelConfig := filepath.Join(loader.ModelPath, modelFile+".yaml")
-		if _, err := os.Stat(modelConfig); err == nil {
-			if err := cm.LoadConfig(modelConfig); err != nil {
-				return fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error())
-			}
-		}
-
-		var config *Config
-		cfg, exists := cm[modelFile]
-		if !exists {
-			config = &Config{
-				OpenAIRequest: defaultRequest(modelFile),
-			}
-		} else {
-			config = &cfg
-		}
-
-		// Set the parameters for the language model prediction
-		updateConfig(config, input)
-
-		if threads != 0 {
-			config.Threads = threads
-		}
-		if ctx != 0 {
-			config.ContextSize = ctx
-		}
-		if f16 {
-			config.F16 = true
-		}
-
-		if debug {
-			config.Debug = true
-		}
-
 		log.Debug().Msgf("Parameter Config: %+v", config)
 
 		predInput := input.Prompt
@@ -246,21 +332,39 @@ func openAIEndpoint(cm ConfigMerger, chat, debug bool, loader *model.ModelLoader
 		}
 
 		templateFile := config.Model
-		if config.TemplateConfig.Chat != "" && chat {
+
+		switch {
+		case config.TemplateConfig.Chat != "" && chat:
 			templateFile = config.TemplateConfig.Chat
-		}
-
-		if config.TemplateConfig.Completion != "" && !chat {
+		case config.TemplateConfig.Completion != "" && completion:
 			templateFile = config.TemplateConfig.Completion
+		case config.TemplateConfig.Edit != "" && edit:
+			templateFile = config.TemplateConfig.Edit
 		}
 
-		// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
-		templatedInput, err := loader.TemplatePrefix(templateFile, struct {
-			Input string
-		}{Input: predInput})
-		if err == nil {
-			predInput = templatedInput
-			log.Debug().Msgf("Template found, input modified to: %s", predInput)
-		}
+		if edit {
+			e := ""
+			if config.TemplateConfig.Edit == "" {
+				e = ".edit"
+			}
+			// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
+			templatedInput, err := loader.TemplatePrefix(templateFile+e, struct {
+				Input       string
+				Instruction string
+			}{Input: input.Input, Instruction: input.Instruction})
+			if err == nil {
+				predInput = templatedInput
+				log.Debug().Msgf("Template found, input modified to: %s", predInput)
+			}
+		} else {
+			// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
+			templatedInput, err := loader.TemplatePrefix(templateFile, struct {
+				Input string
+			}{Input: predInput})
+			if err == nil {
+				predInput = templatedInput
+				log.Debug().Msgf("Template found, input modified to: %s", predInput)
+			}
+		}
 
 		result := []Choice{}
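Aside (not part of the patch): the edit branch above passes both `.Input` and `.Instruction` to the template, and falls back to a `<templateFile>.edit` template name when no `edit` template is configured. A sketch of what a hypothetical edit template and its rendering could look like, assuming `TemplatePrefix` renders Go `text/template` syntax; the file name and wording are illustrative only.

```go
package main

import (
	"os"
	"text/template"
)

func main() {
	// Hypothetical contents of a "wizardlm.edit.tmpl" file; the field names
	// match the struct passed to TemplatePrefix in the edit branch above.
	const editTmpl = `Below is an instruction that describes a task, paired with an input. Write a response that appropriately completes the request.

### Instruction: {{.Instruction}}

### Input: {{.Input}}

### Response:`

	t := template.Must(template.New("edit").Parse(editTmpl))
	_ = t.Execute(os.Stdout, struct {
		Input       string
		Instruction string
	}{
		Input:       "What day of the wek is it?",
		Instruction: "Fix the spelling mistakes",
	})
}
```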
@@ -326,6 +430,8 @@ func openAIEndpoint(cm ConfigMerger, chat, debug bool, loader *model.ModelLoader
 			resp.Object = "chat.completion.chunk"
 		} else if chat {
 			resp.Object = "chat.completion"
+		} else if edit {
+			resp.Object = "edit"
 		} else {
 			resp.Object = "text_completion"
 		}
diff --git a/api/prediction.go b/api/prediction.go
index dfa8b603..7d94d6cb 100644
--- a/api/prediction.go
+++ b/api/prediction.go
@@ -2,6 +2,8 @@ package api
 
 import (
 	"fmt"
+	"regexp"
+	"strings"
 	"sync"
 
 	model "github.com/go-skynet/LocalAI/pkg/model"
@@ -186,3 +188,29 @@ func ModelInference(s string, loader *model.ModelLoader, c Config) (func() (stri
 		return fn()
 	}, nil
 }
+
+var cutstrings map[string]*regexp.Regexp = make(map[string]*regexp.Regexp)
+var mu sync.Mutex = sync.Mutex{}
+
+func Finetune(config Config, input, prediction string) string {
+	if config.Echo {
+		prediction = input + prediction
+	}
+
+	for _, c := range config.Cutstrings {
+		mu.Lock()
+		reg, ok := cutstrings[c]
+		if !ok {
+			cutstrings[c] = regexp.MustCompile(c)
+			reg = cutstrings[c]
+		}
+		mu.Unlock()
+		prediction = reg.ReplaceAllString(prediction, "")
+	}
+
+	for _, c := range config.TrimSpace {
+		prediction = strings.TrimSpace(strings.TrimPrefix(prediction, c))
+	}
+	return prediction
+}
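Aside (not part of the patch): `Finetune` is now the shared post-processing step for all three endpoints. The standalone sketch below replays the same pipeline with hypothetical config values to make the order of operations concrete: echo first, then cutstring regexes, then prefix and whitespace trimming. The real code additionally caches compiled regexes in the `cutstrings` map under a mutex; the sketch compiles inline for brevity.

```go
package main

import (
	"fmt"
	"regexp"
	"strings"
)

func main() {
	input := "Say hello."
	prediction := " ### Response: hello world\n"

	// config.Echo: prepend the original prompt to the prediction.
	echo := true
	// config.Cutstrings: regexes whose matches are removed from the output.
	cutstrings := []string{`###\s*Response:`}
	// config.TrimSpace: prefixes to strip, with surrounding whitespace trimmed.
	trimSpace := []string{"\n"}

	if echo {
		prediction = input + prediction
	}
	for _, c := range cutstrings {
		prediction = regexp.MustCompile(c).ReplaceAllString(prediction, "")
	}
	for _, c := range trimSpace {
		prediction = strings.TrimSpace(strings.TrimPrefix(prediction, c))
	}

	fmt.Printf("%q\n", prediction) // "Say hello.  hello world"
}
```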