refactor: move remaining api packages to core (#1731)

* core 1

* api/openai/files fix

* core 2 - core/config

* move over core api.go and tests to the start of core/http

* move over localai specific endpoints to core/http, begin the service/endpoint split there

* refactor big chunk on the plane

* refactor chunk 2 on plane, next step: port and modify changes to request.go

* easy fixes for request.go, major changes not done yet

* lintfix

* json tag lintfix?

* gitignore and .keep files

* strange fix attempt: rename the config dir?
This commit is contained in:
Dave 2024-03-01 10:19:53 -05:00 committed by GitHub
parent 316de82f51
commit 1c312685aa
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
50 changed files with 1440 additions and 1206 deletions

View file

@ -0,0 +1,609 @@
package openai
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/go-skynet/LocalAI/core/backend"
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/go-skynet/LocalAI/pkg/grammar"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/utils"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/rs/zerolog/log"
"github.com/valyala/fasthttp"
)
func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startupOptions *config.ApplicationConfig) func(c *fiber.Ctx) error {
emptyMessage := ""
id := uuid.New().String()
created := int(time.Now().Unix())
process := func(s string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
initialMessage := schema.OpenAIResponse{
ID: id,
Created: created,
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant", Content: &emptyMessage}}},
Object: "chat.completion.chunk",
}
responses <- initialMessage
ComputeChoices(req, s, config, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
resp := schema.OpenAIResponse{
ID: id,
Created: created,
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{{Delta: &schema.Message{Content: &s}, Index: 0}},
Object: "chat.completion.chunk",
Usage: schema.OpenAIUsage{
PromptTokens: usage.Prompt,
CompletionTokens: usage.Completion,
TotalTokens: usage.Prompt + usage.Completion,
},
}
responses <- resp
return true
})
close(responses)
}
processTools := func(noAction string, prompt string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
result := ""
_, tokenUsage, _ := ComputeChoices(req, prompt, config, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
result += s
// TODO: Change generated BNF grammar to be compliant with the schema so we can
// stream the result token by token here.
return true
})
results := parseFunctionCall(result, config.FunctionsConfig.ParallelCalls)
noActionToRun := len(results) > 0 && results[0].name == noAction
switch {
case noActionToRun:
initialMessage := schema.OpenAIResponse{
ID: id,
Created: created,
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant", Content: &emptyMessage}}},
Object: "chat.completion.chunk",
}
responses <- initialMessage
result, err := handleQuestion(config, req, ml, startupOptions, results[0].arguments, prompt)
if err != nil {
log.Error().Msgf("error handling question: %s", err.Error())
return
}
resp := schema.OpenAIResponse{
ID: id,
Created: created,
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{{Delta: &schema.Message{Content: &result}, Index: 0}},
Object: "chat.completion.chunk",
Usage: schema.OpenAIUsage{
PromptTokens: tokenUsage.Prompt,
CompletionTokens: tokenUsage.Completion,
TotalTokens: tokenUsage.Prompt + tokenUsage.Completion,
},
}
responses <- resp
default:
for i, ss := range results {
name, args := ss.name, ss.arguments
initialMessage := schema.OpenAIResponse{
ID: id,
Created: created,
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{{
Delta: &schema.Message{
Role: "assistant",
ToolCalls: []schema.ToolCall{
{
Index: i,
ID: id,
Type: "function",
FunctionCall: schema.FunctionCall{
Name: name,
},
},
},
}}},
Object: "chat.completion.chunk",
}
responses <- initialMessage
responses <- schema.OpenAIResponse{
ID: id,
Created: created,
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{{
Delta: &schema.Message{
Role: "assistant",
ToolCalls: []schema.ToolCall{
{
Index: i,
ID: id,
Type: "function",
FunctionCall: schema.FunctionCall{
Arguments: args,
},
},
},
}}},
Object: "chat.completion.chunk",
}
}
}
close(responses)
}
return func(c *fiber.Ctx) error {
processFunctions := false
funcs := grammar.Functions{}
modelFile, input, err := readRequest(c, ml, startupOptions, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
config, input, err := mergeRequestWithConfig(modelFile, input, cl, ml, startupOptions.Debug, startupOptions.Threads, startupOptions.ContextSize, startupOptions.F16)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
log.Debug().Msgf("Configuration read: %+v", config)
// Allow the user to set custom actions via config file
// to be "embedded" in each model
noActionName := "answer"
noActionDescription := "use this action to answer without performing any action"
if config.FunctionsConfig.NoActionFunctionName != "" {
noActionName = config.FunctionsConfig.NoActionFunctionName
}
if config.FunctionsConfig.NoActionDescriptionName != "" {
noActionDescription = config.FunctionsConfig.NoActionDescriptionName
}
if input.ResponseFormat.Type == "json_object" {
input.Grammar = grammar.JSONBNF
}
// process functions if we have any defined or if we have a function call string
if len(input.Functions) > 0 && config.ShouldUseFunctions() {
log.Debug().Msgf("Response needs to process functions")
processFunctions = true
noActionGrammar := grammar.Function{
Name: noActionName,
Description: noActionDescription,
Parameters: map[string]interface{}{
"properties": map[string]interface{}{
"message": map[string]interface{}{
"type": "string",
"description": "The message to reply the user with",
}},
},
}
// Append the no action function
funcs = append(funcs, input.Functions...)
if !config.FunctionsConfig.DisableNoAction {
funcs = append(funcs, noActionGrammar)
}
// Force picking one of the functions by the request
if config.FunctionToCall() != "" {
funcs = funcs.Select(config.FunctionToCall())
}
// Update input grammar
jsStruct := funcs.ToJSONStructure()
config.Grammar = jsStruct.Grammar("", config.FunctionsConfig.ParallelCalls)
} else if input.JSONFunctionGrammarObject != nil {
config.Grammar = input.JSONFunctionGrammarObject.Grammar("", config.FunctionsConfig.ParallelCalls)
}
// functions are not supported in stream mode (yet?)
toStream := input.Stream
log.Debug().Msgf("Parameters: %+v", config)
var predInput string
suppressConfigSystemPrompt := false
mess := []string{}
for messageIndex, i := range input.Messages {
var content string
role := i.Role
// if function call, we might want to customize the role so we can display better that the "assistant called a json action"
// if an "assistant_function_call" role is defined, we use it, otherwise we use the role that is passed by in the request
if i.FunctionCall != nil && i.Role == "assistant" {
roleFn := "assistant_function_call"
r := config.Roles[roleFn]
if r != "" {
role = roleFn
}
}
r := config.Roles[role]
contentExists := i.Content != nil && i.StringContent != ""
// First attempt to populate content via a chat message specific template
if config.TemplateConfig.ChatMessage != "" {
chatMessageData := model.ChatMessageTemplateData{
SystemPrompt: config.SystemPrompt,
Role: r,
RoleName: role,
Content: i.StringContent,
FunctionName: i.Name,
MessageIndex: messageIndex,
}
templatedChatMessage, err := ml.EvaluateTemplateForChatMessage(config.TemplateConfig.ChatMessage, chatMessageData)
if err != nil {
log.Error().Msgf("error processing message %+v using template \"%s\": %v. Skipping!", chatMessageData, config.TemplateConfig.ChatMessage, err)
} else {
if templatedChatMessage == "" {
log.Warn().Msgf("template \"%s\" produced blank output for %+v. Skipping!", config.TemplateConfig.ChatMessage, chatMessageData)
continue // TODO: This continue is here intentionally to skip over the line `mess = append(mess, content)` below, and to prevent the sprintf
}
log.Debug().Msgf("templated message for chat: %s", templatedChatMessage)
content = templatedChatMessage
}
}
// If this model doesn't have such a template, or if that template fails to return a value, template at the message level.
if content == "" {
if r != "" {
if contentExists {
content = fmt.Sprint(r, i.StringContent)
}
if i.FunctionCall != nil {
j, err := json.Marshal(i.FunctionCall)
if err == nil {
if contentExists {
content += "\n" + fmt.Sprint(r, " ", string(j))
} else {
content = fmt.Sprint(r, " ", string(j))
}
}
}
} else {
if contentExists {
content = fmt.Sprint(i.StringContent)
}
if i.FunctionCall != nil {
j, err := json.Marshal(i.FunctionCall)
if err == nil {
if contentExists {
content += "\n" + string(j)
} else {
content = string(j)
}
}
}
}
// Special Handling: System. We care if it was printed at all, not the r branch, so check seperately
if contentExists && role == "system" {
suppressConfigSystemPrompt = true
}
}
mess = append(mess, content)
}
predInput = strings.Join(mess, "\n")
log.Debug().Msgf("Prompt (before templating): %s", predInput)
if toStream {
log.Debug().Msgf("Stream request received")
c.Context().SetContentType("text/event-stream")
//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
// c.Set("Content-Type", "text/event-stream")
c.Set("Cache-Control", "no-cache")
c.Set("Connection", "keep-alive")
c.Set("Transfer-Encoding", "chunked")
}
templateFile := ""
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
if ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
templateFile = config.Model
}
if config.TemplateConfig.Chat != "" && !processFunctions {
templateFile = config.TemplateConfig.Chat
}
if config.TemplateConfig.Functions != "" && processFunctions {
templateFile = config.TemplateConfig.Functions
}
if templateFile != "" {
templatedInput, err := ml.EvaluateTemplateForPrompt(model.ChatPromptTemplate, templateFile, model.PromptTemplateData{
SystemPrompt: config.SystemPrompt,
SuppressSystemPrompt: suppressConfigSystemPrompt,
Input: predInput,
Functions: funcs,
})
if err == nil {
predInput = templatedInput
log.Debug().Msgf("Template found, input modified to: %s", predInput)
} else {
log.Debug().Msgf("Template failed loading: %s", err.Error())
}
}
log.Debug().Msgf("Prompt (after templating): %s", predInput)
if processFunctions {
log.Debug().Msgf("Grammar: %+v", config.Grammar)
}
switch {
case toStream:
responses := make(chan schema.OpenAIResponse)
if !processFunctions {
go process(predInput, input, config, ml, responses)
} else {
go processTools(noActionName, predInput, input, config, ml, responses)
}
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
usage := &schema.OpenAIUsage{}
toolsCalled := false
for ev := range responses {
usage = &ev.Usage // Copy a pointer to the latest usage chunk so that the stop message can reference it
if len(ev.Choices[0].Delta.ToolCalls) > 0 {
toolsCalled = true
}
var buf bytes.Buffer
enc := json.NewEncoder(&buf)
enc.Encode(ev)
log.Debug().Msgf("Sending chunk: %s", buf.String())
_, err := fmt.Fprintf(w, "data: %v\n", buf.String())
if err != nil {
log.Debug().Msgf("Sending chunk failed: %v", err)
input.Cancel()
break
}
w.Flush()
}
finishReason := "stop"
if toolsCalled {
finishReason = "tool_calls"
} else if toolsCalled && len(input.Tools) == 0 {
finishReason = "function_call"
}
resp := &schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{
{
FinishReason: finishReason,
Index: 0,
Delta: &schema.Message{Content: &emptyMessage},
}},
Object: "chat.completion.chunk",
Usage: *usage,
}
respData, _ := json.Marshal(resp)
w.WriteString(fmt.Sprintf("data: %s\n\n", respData))
w.WriteString("data: [DONE]\n\n")
w.Flush()
}))
return nil
// no streaming mode
default:
result, tokenUsage, err := ComputeChoices(input, predInput, config, startupOptions, ml, func(s string, c *[]schema.Choice) {
if !processFunctions {
// no function is called, just reply and use stop as finish reason
*c = append(*c, schema.Choice{FinishReason: "stop", Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}})
return
}
results := parseFunctionCall(s, config.FunctionsConfig.ParallelCalls)
noActionsToRun := len(results) > 0 && results[0].name == noActionName
switch {
case noActionsToRun:
result, err := handleQuestion(config, input, ml, startupOptions, results[0].arguments, predInput)
if err != nil {
log.Error().Msgf("error handling question: %s", err.Error())
return
}
*c = append(*c, schema.Choice{
Message: &schema.Message{Role: "assistant", Content: &result}})
default:
toolChoice := schema.Choice{
Message: &schema.Message{
Role: "assistant",
},
}
if len(input.Tools) > 0 {
toolChoice.FinishReason = "tool_calls"
}
for _, ss := range results {
name, args := ss.name, ss.arguments
if len(input.Tools) > 0 {
// If we are using tools, we condense the function calls into
// a single response choice with all the tools
toolChoice.Message.ToolCalls = append(toolChoice.Message.ToolCalls,
schema.ToolCall{
ID: id,
Type: "function",
FunctionCall: schema.FunctionCall{
Name: name,
Arguments: args,
},
},
)
} else {
// otherwise we return more choices directly
*c = append(*c, schema.Choice{
FinishReason: "function_call",
Message: &schema.Message{
Role: "assistant",
FunctionCall: map[string]interface{}{
"name": name,
"arguments": args,
},
},
})
}
}
if len(input.Tools) > 0 {
// we need to append our result if we are using tools
*c = append(*c, toolChoice)
}
}
}, nil)
if err != nil {
return err
}
resp := &schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: result,
Object: "chat.completion",
Usage: schema.OpenAIUsage{
PromptTokens: tokenUsage.Prompt,
CompletionTokens: tokenUsage.Completion,
TotalTokens: tokenUsage.Prompt + tokenUsage.Completion,
},
}
respData, _ := json.Marshal(resp)
log.Debug().Msgf("Response: %s", respData)
// Return the prediction in the response body
return c.JSON(resp)
}
}
}
func handleQuestion(config *config.BackendConfig, input *schema.OpenAIRequest, ml *model.ModelLoader, o *config.ApplicationConfig, args, prompt string) (string, error) {
log.Debug().Msgf("nothing to do, computing a reply")
// If there is a message that the LLM already sends as part of the JSON reply, use it
arguments := map[string]interface{}{}
json.Unmarshal([]byte(args), &arguments)
m, exists := arguments["message"]
if exists {
switch message := m.(type) {
case string:
if message != "" {
log.Debug().Msgf("Reply received from LLM: %s", message)
message = backend.Finetune(*config, prompt, message)
log.Debug().Msgf("Reply received from LLM(finetuned): %s", message)
return message, nil
}
}
}
log.Debug().Msgf("No action received from LLM, without a message, computing a reply")
// Otherwise ask the LLM to understand the JSON output and the context, and return a message
// Note: This costs (in term of CPU/GPU) another computation
config.Grammar = ""
images := []string{}
for _, m := range input.Messages {
images = append(images, m.StringImages...)
}
predFunc, err := backend.ModelInference(input.Context, prompt, images, ml, *config, o, nil)
if err != nil {
log.Error().Msgf("inference error: %s", err.Error())
return "", err
}
prediction, err := predFunc()
if err != nil {
log.Error().Msgf("inference error: %s", err.Error())
return "", err
}
return backend.Finetune(*config, prompt, prediction.Response), nil
}
type funcCallResults struct {
name string
arguments string
}
func parseFunctionCall(llmresult string, multipleResults bool) []funcCallResults {
results := []funcCallResults{}
// TODO: use generics to avoid this code duplication
if multipleResults {
ss := []map[string]interface{}{}
s := utils.EscapeNewLines(llmresult)
json.Unmarshal([]byte(s), &ss)
log.Debug().Msgf("Function return: %s %+v", s, ss)
for _, s := range ss {
func_name, ok := s["function"]
if !ok {
continue
}
args, ok := s["arguments"]
if !ok {
continue
}
d, _ := json.Marshal(args)
funcName, ok := func_name.(string)
if !ok {
continue
}
results = append(results, funcCallResults{name: funcName, arguments: string(d)})
}
} else {
// As we have to change the result before processing, we can't stream the answer token-by-token (yet?)
ss := map[string]interface{}{}
// This prevent newlines to break JSON parsing for clients
s := utils.EscapeNewLines(llmresult)
json.Unmarshal([]byte(s), &ss)
log.Debug().Msgf("Function return: %s %+v", s, ss)
// The grammar defines the function name as "function", while OpenAI returns "name"
func_name, ok := ss["function"]
if !ok {
return results
}
// Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object
args, ok := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix)
if !ok {
return results
}
d, _ := json.Marshal(args)
funcName, ok := func_name.(string)
if !ok {
return results
}
results = append(results, funcCallResults{name: funcName, arguments: string(d)})
}
return results
}

View file

@ -0,0 +1,199 @@
package openai
import (
"bufio"
"bytes"
"encoding/json"
"errors"
"fmt"
"time"
"github.com/go-skynet/LocalAI/core/backend"
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/go-skynet/LocalAI/pkg/grammar"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/rs/zerolog/log"
"github.com/valyala/fasthttp"
)
// https://platform.openai.com/docs/api-reference/completions
func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
id := uuid.New().String()
created := int(time.Now().Unix())
process := func(s string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
ComputeChoices(req, s, config, appConfig, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
resp := schema.OpenAIResponse{
ID: id,
Created: created,
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{
{
Index: 0,
Text: s,
},
},
Object: "text_completion",
Usage: schema.OpenAIUsage{
PromptTokens: usage.Prompt,
CompletionTokens: usage.Completion,
TotalTokens: usage.Prompt + usage.Completion,
},
}
log.Debug().Msgf("Sending goroutine: %s", s)
responses <- resp
return true
})
close(responses)
}
return func(c *fiber.Ctx) error {
modelFile, input, err := readRequest(c, ml, appConfig, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
log.Debug().Msgf("`input`: %+v", input)
config, input, err := mergeRequestWithConfig(modelFile, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
if input.ResponseFormat.Type == "json_object" {
input.Grammar = grammar.JSONBNF
}
log.Debug().Msgf("Parameter Config: %+v", config)
if input.Stream {
log.Debug().Msgf("Stream request received")
c.Context().SetContentType("text/event-stream")
//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
//c.Set("Content-Type", "text/event-stream")
c.Set("Cache-Control", "no-cache")
c.Set("Connection", "keep-alive")
c.Set("Transfer-Encoding", "chunked")
}
templateFile := ""
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
if ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
templateFile = config.Model
}
if config.TemplateConfig.Completion != "" {
templateFile = config.TemplateConfig.Completion
}
if input.Stream {
if len(config.PromptStrings) > 1 {
return errors.New("cannot handle more than 1 `PromptStrings` when Streaming")
}
predInput := config.PromptStrings[0]
if templateFile != "" {
templatedInput, err := ml.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{
Input: predInput,
})
if err == nil {
predInput = templatedInput
log.Debug().Msgf("Template found, input modified to: %s", predInput)
}
}
responses := make(chan schema.OpenAIResponse)
go process(predInput, input, config, ml, responses)
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
for ev := range responses {
var buf bytes.Buffer
enc := json.NewEncoder(&buf)
enc.Encode(ev)
log.Debug().Msgf("Sending chunk: %s", buf.String())
fmt.Fprintf(w, "data: %v\n", buf.String())
w.Flush()
}
resp := &schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{
{
Index: 0,
FinishReason: "stop",
},
},
Object: "text_completion",
}
respData, _ := json.Marshal(resp)
w.WriteString(fmt.Sprintf("data: %s\n\n", respData))
w.WriteString("data: [DONE]\n\n")
w.Flush()
}))
return nil
}
var result []schema.Choice
totalTokenUsage := backend.TokenUsage{}
for k, i := range config.PromptStrings {
if templateFile != "" {
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
templatedInput, err := ml.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{
SystemPrompt: config.SystemPrompt,
Input: i,
})
if err == nil {
i = templatedInput
log.Debug().Msgf("Template found, input modified to: %s", i)
}
}
r, tokenUsage, err := ComputeChoices(
input, i, config, appConfig, ml, func(s string, c *[]schema.Choice) {
*c = append(*c, schema.Choice{Text: s, FinishReason: "stop", Index: k})
}, nil)
if err != nil {
return err
}
totalTokenUsage.Prompt += tokenUsage.Prompt
totalTokenUsage.Completion += tokenUsage.Completion
result = append(result, r...)
}
resp := &schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: result,
Object: "text_completion",
Usage: schema.OpenAIUsage{
PromptTokens: totalTokenUsage.Prompt,
CompletionTokens: totalTokenUsage.Completion,
TotalTokens: totalTokenUsage.Prompt + totalTokenUsage.Completion,
},
}
jsonResult, _ := json.Marshal(resp)
log.Debug().Msgf("Response: %s", jsonResult)
// Return the prediction in the response body
return c.JSON(resp)
}
}

View file

@ -0,0 +1,94 @@
package openai
import (
"encoding/json"
"fmt"
"time"
"github.com/go-skynet/LocalAI/core/backend"
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/schema"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/rs/zerolog/log"
)
func EditEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
modelFile, input, err := readRequest(c, ml, appConfig, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
config, input, err := mergeRequestWithConfig(modelFile, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
log.Debug().Msgf("Parameter Config: %+v", config)
templateFile := ""
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
if ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
templateFile = config.Model
}
if config.TemplateConfig.Edit != "" {
templateFile = config.TemplateConfig.Edit
}
var result []schema.Choice
totalTokenUsage := backend.TokenUsage{}
for _, i := range config.InputStrings {
if templateFile != "" {
templatedInput, err := ml.EvaluateTemplateForPrompt(model.EditPromptTemplate, templateFile, model.PromptTemplateData{
Input: i,
Instruction: input.Instruction,
SystemPrompt: config.SystemPrompt,
})
if err == nil {
i = templatedInput
log.Debug().Msgf("Template found, input modified to: %s", i)
}
}
r, tokenUsage, err := ComputeChoices(input, i, config, appConfig, ml, func(s string, c *[]schema.Choice) {
*c = append(*c, schema.Choice{Text: s})
}, nil)
if err != nil {
return err
}
totalTokenUsage.Prompt += tokenUsage.Prompt
totalTokenUsage.Completion += tokenUsage.Completion
result = append(result, r...)
}
id := uuid.New().String()
created := int(time.Now().Unix())
resp := &schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: result,
Object: "edit",
Usage: schema.OpenAIUsage{
PromptTokens: totalTokenUsage.Prompt,
CompletionTokens: totalTokenUsage.Completion,
TotalTokens: totalTokenUsage.Prompt + totalTokenUsage.Completion,
},
}
jsonResult, _ := json.Marshal(resp)
log.Debug().Msgf("Response: %s", jsonResult)
// Return the prediction in the response body
return c.JSON(resp)
}
}

View file

@ -0,0 +1,79 @@
package openai
import (
"encoding/json"
"fmt"
"time"
"github.com/go-skynet/LocalAI/core/backend"
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/google/uuid"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
)
// https://platform.openai.com/docs/api-reference/embeddings
func EmbeddingsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
model, input, err := readRequest(c, ml, appConfig, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
config, input, err := mergeRequestWithConfig(model, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
log.Debug().Msgf("Parameter Config: %+v", config)
items := []schema.Item{}
for i, s := range config.InputToken {
// get the model function to call for the result
embedFn, err := backend.ModelEmbedding("", s, ml, *config, appConfig)
if err != nil {
return err
}
embeddings, err := embedFn()
if err != nil {
return err
}
items = append(items, schema.Item{Embedding: embeddings, Index: i, Object: "embedding"})
}
for i, s := range config.InputStrings {
// get the model function to call for the result
embedFn, err := backend.ModelEmbedding(s, []int{}, ml, *config, appConfig)
if err != nil {
return err
}
embeddings, err := embedFn()
if err != nil {
return err
}
items = append(items, schema.Item{Embedding: embeddings, Index: i, Object: "embedding"})
}
id := uuid.New().String()
created := int(time.Now().Unix())
resp := &schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Data: items,
Object: "list",
}
jsonResult, _ := json.Marshal(resp)
log.Debug().Msgf("Response: %s", jsonResult)
// Return the prediction in the response body
return c.JSON(resp)
}
}

View file

@ -0,0 +1,218 @@
package openai
import (
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
"time"
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/pkg/utils"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
)
var uploadedFiles []File
const uploadedFilesFile = "uploadedFiles.json"
// File represents the structure of a file object from the OpenAI API.
type File struct {
ID string `json:"id"` // Unique identifier for the file
Object string `json:"object"` // Type of the object (e.g., "file")
Bytes int `json:"bytes"` // Size of the file in bytes
CreatedAt time.Time `json:"created_at"` // The time at which the file was created
Filename string `json:"filename"` // The name of the file
Purpose string `json:"purpose"` // The purpose of the file (e.g., "fine-tune", "classifications", etc.)
}
func saveUploadConfig(uploadDir string) {
file, err := json.MarshalIndent(uploadedFiles, "", " ")
if err != nil {
log.Error().Msgf("Failed to JSON marshal the uploadedFiles: %s", err)
}
err = os.WriteFile(filepath.Join(uploadDir, uploadedFilesFile), file, 0644)
if err != nil {
log.Error().Msgf("Failed to save uploadedFiles to file: %s", err)
}
}
func LoadUploadConfig(uploadPath string) {
uploadFilePath := filepath.Join(uploadPath, uploadedFilesFile)
_, err := os.Stat(uploadFilePath)
if os.IsNotExist(err) {
log.Debug().Msgf("No uploadedFiles file found at %s", uploadFilePath)
return
}
file, err := os.ReadFile(uploadFilePath)
if err != nil {
log.Error().Msgf("Failed to read file: %s", err)
} else {
err = json.Unmarshal(file, &uploadedFiles)
if err != nil {
log.Error().Msgf("Failed to JSON unmarshal the file into uploadedFiles: %s", err)
}
}
}
// UploadFilesEndpoint https://platform.openai.com/docs/api-reference/files/create
func UploadFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
file, err := c.FormFile("file")
if err != nil {
return err
}
// Check the file size
if file.Size > int64(appConfig.UploadLimitMB*1024*1024) {
return c.Status(fiber.StatusBadRequest).SendString(fmt.Sprintf("File size %d exceeds upload limit %d", file.Size, appConfig.UploadLimitMB))
}
purpose := c.FormValue("purpose", "") //TODO put in purpose dirs
if purpose == "" {
return c.Status(fiber.StatusBadRequest).SendString("Purpose is not defined")
}
// Sanitize the filename to prevent directory traversal
filename := utils.SanitizeFileName(file.Filename)
savePath := filepath.Join(appConfig.UploadDir, filename)
// Check if file already exists
if _, err := os.Stat(savePath); !os.IsNotExist(err) {
return c.Status(fiber.StatusBadRequest).SendString("File already exists")
}
err = c.SaveFile(file, savePath)
if err != nil {
return c.Status(fiber.StatusInternalServerError).SendString("Failed to save file: " + err.Error())
}
f := File{
ID: fmt.Sprintf("file-%d", time.Now().Unix()),
Object: "file",
Bytes: int(file.Size),
CreatedAt: time.Now(),
Filename: file.Filename,
Purpose: purpose,
}
uploadedFiles = append(uploadedFiles, f)
saveUploadConfig(appConfig.UploadDir)
return c.Status(fiber.StatusOK).JSON(f)
}
}
// ListFilesEndpoint https://platform.openai.com/docs/api-reference/files/list
func ListFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
type ListFiles struct {
Data []File
Object string
}
return func(c *fiber.Ctx) error {
var listFiles ListFiles
purpose := c.Query("purpose")
if purpose == "" {
listFiles.Data = uploadedFiles
} else {
for _, f := range uploadedFiles {
if purpose == f.Purpose {
listFiles.Data = append(listFiles.Data, f)
}
}
}
listFiles.Object = "list"
return c.Status(fiber.StatusOK).JSON(listFiles)
}
}
func getFileFromRequest(c *fiber.Ctx) (*File, error) {
id := c.Params("file_id")
if id == "" {
return nil, fmt.Errorf("file_id parameter is required")
}
for _, f := range uploadedFiles {
if id == f.ID {
return &f, nil
}
}
return nil, fmt.Errorf("unable to find file id %s", id)
}
// GetFilesEndpoint https://platform.openai.com/docs/api-reference/files/retrieve
func GetFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
file, err := getFileFromRequest(c)
if err != nil {
return c.Status(fiber.StatusInternalServerError).SendString(err.Error())
}
return c.JSON(file)
}
}
// DeleteFilesEndpoint https://platform.openai.com/docs/api-reference/files/delete
func DeleteFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
type DeleteStatus struct {
Id string
Object string
Deleted bool
}
return func(c *fiber.Ctx) error {
file, err := getFileFromRequest(c)
if err != nil {
return c.Status(fiber.StatusInternalServerError).SendString(err.Error())
}
err = os.Remove(filepath.Join(appConfig.UploadDir, file.Filename))
if err != nil {
// If the file doesn't exist then we should just continue to remove it
if !errors.Is(err, os.ErrNotExist) {
return c.Status(fiber.StatusInternalServerError).SendString(fmt.Sprintf("Unable to delete file: %s, %v", file.Filename, err))
}
}
// Remove upload from list
for i, f := range uploadedFiles {
if f.ID == file.ID {
uploadedFiles = append(uploadedFiles[:i], uploadedFiles[i+1:]...)
break
}
}
saveUploadConfig(appConfig.UploadDir)
return c.JSON(DeleteStatus{
Id: file.ID,
Object: "file",
Deleted: true,
})
}
}
// GetFilesContentsEndpoint https://platform.openai.com/docs/api-reference/files/retrieve-contents
func GetFilesContentsEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
file, err := getFileFromRequest(c)
if err != nil {
return c.Status(fiber.StatusInternalServerError).SendString(err.Error())
}
fileContents, err := os.ReadFile(filepath.Join(appConfig.UploadDir, file.Filename))
if err != nil {
return c.Status(fiber.StatusInternalServerError).SendString(err.Error())
}
return c.Send(fileContents)
}
}

View file

@ -0,0 +1,287 @@
package openai
import (
"encoding/json"
"fmt"
"io"
"mime/multipart"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"github.com/go-skynet/LocalAI/core/config"
utils2 "github.com/go-skynet/LocalAI/pkg/utils"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/assert"
"testing"
)
type ListFiles struct {
Data []File
Object string
}
func startUpApp() (app *fiber.App, option *config.ApplicationConfig, loader *config.BackendConfigLoader) {
// Preparing the mocked objects
loader = &config.BackendConfigLoader{}
option = &config.ApplicationConfig{
UploadLimitMB: 10,
UploadDir: "test_dir",
}
_ = os.RemoveAll(option.UploadDir)
app = fiber.New(fiber.Config{
BodyLimit: 20 * 1024 * 1024, // sets the limit to 20MB.
})
// Create a Test Server
app.Post("/files", UploadFilesEndpoint(loader, option))
app.Get("/files", ListFilesEndpoint(loader, option))
app.Get("/files/:file_id", GetFilesEndpoint(loader, option))
app.Delete("/files/:file_id", DeleteFilesEndpoint(loader, option))
app.Get("/files/:file_id/content", GetFilesContentsEndpoint(loader, option))
return
}
func TestUploadFileExceedSizeLimit(t *testing.T) {
// Preparing the mocked objects
loader := &config.BackendConfigLoader{}
option := &config.ApplicationConfig{
UploadLimitMB: 10,
UploadDir: "test_dir",
}
_ = os.RemoveAll(option.UploadDir)
app := fiber.New(fiber.Config{
BodyLimit: 20 * 1024 * 1024, // sets the limit to 20MB.
})
// Create a Test Server
app.Post("/files", UploadFilesEndpoint(loader, option))
app.Get("/files", ListFilesEndpoint(loader, option))
app.Get("/files/:file_id", GetFilesEndpoint(loader, option))
app.Delete("/files/:file_id", DeleteFilesEndpoint(loader, option))
app.Get("/files/:file_id/content", GetFilesContentsEndpoint(loader, option))
t.Run("UploadFilesEndpoint file size exceeds limit", func(t *testing.T) {
resp, err := CallFilesUploadEndpoint(t, app, "foo.txt", "file", "fine-tune", 11, option)
assert.NoError(t, err)
assert.Equal(t, fiber.StatusBadRequest, resp.StatusCode)
assert.Contains(t, bodyToString(resp, t), "exceeds upload limit")
})
t.Run("UploadFilesEndpoint purpose not defined", func(t *testing.T) {
resp, _ := CallFilesUploadEndpoint(t, app, "foo.txt", "file", "", 5, option)
assert.Equal(t, fiber.StatusBadRequest, resp.StatusCode)
assert.Contains(t, bodyToString(resp, t), "Purpose is not defined")
})
t.Run("UploadFilesEndpoint file already exists", func(t *testing.T) {
f1 := CallFilesUploadEndpointWithCleanup(t, app, "foo.txt", "file", "fine-tune", 5, option)
resp, err := CallFilesUploadEndpoint(t, app, "foo.txt", "file", "fine-tune", 5, option)
fmt.Println(f1)
fmt.Printf("ERror: %v", err)
assert.Equal(t, fiber.StatusBadRequest, resp.StatusCode)
assert.Contains(t, bodyToString(resp, t), "File already exists")
})
t.Run("UploadFilesEndpoint file uploaded successfully", func(t *testing.T) {
file := CallFilesUploadEndpointWithCleanup(t, app, "test.txt", "file", "fine-tune", 5, option)
// Check if file exists in the disk
filePath := filepath.Join(option.UploadDir, utils2.SanitizeFileName("test.txt"))
_, err := os.Stat(filePath)
assert.False(t, os.IsNotExist(err))
assert.Equal(t, file.Bytes, 5242880)
assert.NotEmpty(t, file.CreatedAt)
assert.Equal(t, file.Filename, "test.txt")
assert.Equal(t, file.Purpose, "fine-tune")
})
t.Run("ListFilesEndpoint without purpose parameter", func(t *testing.T) {
resp, err := CallListFilesEndpoint(t, app, "")
assert.NoError(t, err)
assert.Equal(t, 200, resp.StatusCode)
listFiles := responseToListFile(t, resp)
if len(listFiles.Data) != len(uploadedFiles) {
t.Errorf("Expected %v files, got %v files", len(uploadedFiles), len(listFiles.Data))
}
})
t.Run("ListFilesEndpoint with valid purpose parameter", func(t *testing.T) {
_ = CallFilesUploadEndpointWithCleanup(t, app, "test.txt", "file", "fine-tune", 5, option)
resp, err := CallListFilesEndpoint(t, app, "fine-tune")
assert.NoError(t, err)
listFiles := responseToListFile(t, resp)
if len(listFiles.Data) != 1 {
t.Errorf("Expected 1 file, got %v files", len(listFiles.Data))
}
})
t.Run("ListFilesEndpoint with invalid query parameter", func(t *testing.T) {
resp, err := CallListFilesEndpoint(t, app, "not-so-fine-tune")
assert.NoError(t, err)
assert.Equal(t, 200, resp.StatusCode)
listFiles := responseToListFile(t, resp)
if len(listFiles.Data) != 0 {
t.Errorf("Expected 0 file, got %v files", len(listFiles.Data))
}
})
t.Run("GetFilesContentsEndpoint get file content", func(t *testing.T) {
req := httptest.NewRequest("GET", "/files", nil)
resp, _ := app.Test(req)
assert.Equal(t, 200, resp.StatusCode)
var listFiles ListFiles
if err := json.Unmarshal(bodyToByteArray(resp, t), &listFiles); err != nil {
t.Errorf("Failed to decode response: %v", err)
return
}
if len(listFiles.Data) != 0 {
t.Errorf("Expected 0 file, got %v files", len(listFiles.Data))
}
})
}
func CallListFilesEndpoint(t *testing.T, app *fiber.App, purpose string) (*http.Response, error) {
var target string
if purpose != "" {
target = fmt.Sprintf("/files?purpose=%s", purpose)
} else {
target = "/files"
}
req := httptest.NewRequest("GET", target, nil)
return app.Test(req)
}
func CallFilesContentEndpoint(t *testing.T, app *fiber.App, fileId string) (*http.Response, error) {
request := httptest.NewRequest("GET", "/files?file_id="+fileId, nil)
return app.Test(request)
}
func CallFilesUploadEndpoint(t *testing.T, app *fiber.App, fileName, tag, purpose string, fileSize int, appConfig *config.ApplicationConfig) (*http.Response, error) {
// Create a file that exceeds the limit
file := createTestFile(t, fileName, fileSize, appConfig)
// Creating a new HTTP Request
body, writer := newMultipartFile(file.Name(), tag, purpose)
req := httptest.NewRequest(http.MethodPost, "/files", body)
req.Header.Set(fiber.HeaderContentType, writer.FormDataContentType())
return app.Test(req)
}
func CallFilesUploadEndpointWithCleanup(t *testing.T, app *fiber.App, fileName, tag, purpose string, fileSize int, appConfig *config.ApplicationConfig) File {
// Create a file that exceeds the limit
file := createTestFile(t, fileName, fileSize, appConfig)
// Creating a new HTTP Request
body, writer := newMultipartFile(file.Name(), tag, purpose)
req := httptest.NewRequest(http.MethodPost, "/files", body)
req.Header.Set(fiber.HeaderContentType, writer.FormDataContentType())
resp, err := app.Test(req)
assert.NoError(t, err)
f := responseToFile(t, resp)
id := f.ID
t.Cleanup(func() {
_, err := CallFilesDeleteEndpoint(t, app, id)
assert.NoError(t, err)
})
return f
}
func CallFilesDeleteEndpoint(t *testing.T, app *fiber.App, fileId string) (*http.Response, error) {
target := fmt.Sprintf("/files/%s", fileId)
req := httptest.NewRequest(http.MethodDelete, target, nil)
return app.Test(req)
}
// Helper to create multi-part file
func newMultipartFile(filePath, tag, purpose string) (*strings.Reader, *multipart.Writer) {
body := new(strings.Builder)
writer := multipart.NewWriter(body)
file, _ := os.Open(filePath)
defer file.Close()
part, _ := writer.CreateFormFile(tag, filepath.Base(filePath))
io.Copy(part, file)
if purpose != "" {
_ = writer.WriteField("purpose", purpose)
}
writer.Close()
return strings.NewReader(body.String()), writer
}
// Helper to create test files
func createTestFile(t *testing.T, name string, sizeMB int, option *config.ApplicationConfig) *os.File {
err := os.MkdirAll(option.UploadDir, 0755)
if err != nil {
t.Fatalf("Error MKDIR: %v", err)
}
file, _ := os.Create(name)
file.WriteString(strings.Repeat("a", sizeMB*1024*1024)) // sizeMB MB File
t.Cleanup(func() {
os.Remove(name)
os.RemoveAll(option.UploadDir)
})
return file
}
func bodyToString(resp *http.Response, t *testing.T) string {
return string(bodyToByteArray(resp, t))
}
func bodyToByteArray(resp *http.Response, t *testing.T) []byte {
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
return bodyBytes
}
func responseToFile(t *testing.T, resp *http.Response) File {
var file File
responseToString := bodyToString(resp, t)
err := json.NewDecoder(strings.NewReader(responseToString)).Decode(&file)
if err != nil {
t.Errorf("Failed to decode response: %s", err)
}
return file
}
func responseToListFile(t *testing.T, resp *http.Response) ListFiles {
var listFiles ListFiles
responseToString := bodyToString(resp, t)
err := json.NewDecoder(strings.NewReader(responseToString)).Decode(&listFiles)
if err != nil {
fmt.Printf("Failed to decode response: %s", err)
}
return listFiles
}

View file

@ -0,0 +1,239 @@
package openai
import (
"bufio"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/google/uuid"
"github.com/go-skynet/LocalAI/core/backend"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
)
func downloadFile(url string) (string, error) {
// Get the data
resp, err := http.Get(url)
if err != nil {
return "", err
}
defer resp.Body.Close()
// Create the file
out, err := os.CreateTemp("", "image")
if err != nil {
return "", err
}
defer out.Close()
// Write the body to file
_, err = io.Copy(out, resp.Body)
return out.Name(), err
}
// https://platform.openai.com/docs/api-reference/images/create
/*
*
curl http://localhost:8080/v1/images/generations \
-H "Content-Type: application/json" \
-d '{
"prompt": "A cute baby sea otter",
"n": 1,
"size": "512x512"
}'
*
*/
func ImageEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
m, input, err := readRequest(c, ml, appConfig, false)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
if m == "" {
m = model.StableDiffusionBackend
}
log.Debug().Msgf("Loading model: %+v", m)
config, input, err := mergeRequestWithConfig(m, input, cl, ml, appConfig.Debug, 0, 0, false)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
src := ""
if input.File != "" {
fileData := []byte{}
// check if input.File is an URL, if so download it and save it
// to a temporary file
if strings.HasPrefix(input.File, "http://") || strings.HasPrefix(input.File, "https://") {
out, err := downloadFile(input.File)
if err != nil {
return fmt.Errorf("failed downloading file:%w", err)
}
defer os.RemoveAll(out)
fileData, err = os.ReadFile(out)
if err != nil {
return fmt.Errorf("failed reading file:%w", err)
}
} else {
// base 64 decode the file and write it somewhere
// that we will cleanup
fileData, err = base64.StdEncoding.DecodeString(input.File)
if err != nil {
return err
}
}
// Create a temporary file
outputFile, err := os.CreateTemp(appConfig.ImageDir, "b64")
if err != nil {
return err
}
// write the base64 result
writer := bufio.NewWriter(outputFile)
_, err = writer.Write(fileData)
if err != nil {
outputFile.Close()
return err
}
outputFile.Close()
src = outputFile.Name()
defer os.RemoveAll(src)
}
log.Debug().Msgf("Parameter Config: %+v", config)
switch config.Backend {
case "stablediffusion":
config.Backend = model.StableDiffusionBackend
case "tinydream":
config.Backend = model.TinyDreamBackend
case "":
config.Backend = model.StableDiffusionBackend
}
sizeParts := strings.Split(input.Size, "x")
if len(sizeParts) != 2 {
return fmt.Errorf("invalid value for 'size'")
}
width, err := strconv.Atoi(sizeParts[0])
if err != nil {
return fmt.Errorf("invalid value for 'size'")
}
height, err := strconv.Atoi(sizeParts[1])
if err != nil {
return fmt.Errorf("invalid value for 'size'")
}
b64JSON := false
if input.ResponseFormat.Type == "b64_json" {
b64JSON = true
}
// src and clip_skip
var result []schema.Item
for _, i := range config.PromptStrings {
n := input.N
if input.N == 0 {
n = 1
}
for j := 0; j < n; j++ {
prompts := strings.Split(i, "|")
positive_prompt := prompts[0]
negative_prompt := ""
if len(prompts) > 1 {
negative_prompt = prompts[1]
}
mode := 0
step := config.Step
if step == 0 {
step = 15
}
if input.Mode != 0 {
mode = input.Mode
}
if input.Step != 0 {
step = input.Step
}
tempDir := ""
if !b64JSON {
tempDir = appConfig.ImageDir
}
// Create a temporary file
outputFile, err := os.CreateTemp(tempDir, "b64")
if err != nil {
return err
}
outputFile.Close()
output := outputFile.Name() + ".png"
// Rename the temporary file
err = os.Rename(outputFile.Name(), output)
if err != nil {
return err
}
baseURL := c.BaseURL()
fn, err := backend.ImageGeneration(height, width, mode, step, input.Seed, positive_prompt, negative_prompt, src, output, ml, *config, appConfig)
if err != nil {
return err
}
if err := fn(); err != nil {
return err
}
item := &schema.Item{}
if b64JSON {
defer os.RemoveAll(output)
data, err := os.ReadFile(output)
if err != nil {
return err
}
item.B64JSON = base64.StdEncoding.EncodeToString(data)
} else {
base := filepath.Base(output)
item.URL = baseURL + "/generated-images/" + base
}
result = append(result, *item)
}
}
id := uuid.New().String()
created := int(time.Now().Unix())
resp := &schema.OpenAIResponse{
ID: id,
Created: created,
Data: result,
}
jsonResult, _ := json.Marshal(resp)
log.Debug().Msgf("Response: %s", jsonResult)
// Return the prediction in the response body
return c.JSON(resp)
}
}

View file

@ -0,0 +1,55 @@
package openai
import (
"github.com/go-skynet/LocalAI/core/backend"
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/schema"
model "github.com/go-skynet/LocalAI/pkg/model"
)
func ComputeChoices(
req *schema.OpenAIRequest,
predInput string,
config *config.BackendConfig,
o *config.ApplicationConfig,
loader *model.ModelLoader,
cb func(string, *[]schema.Choice),
tokenCallback func(string, backend.TokenUsage) bool) ([]schema.Choice, backend.TokenUsage, error) {
n := req.N // number of completions to return
result := []schema.Choice{}
if n == 0 {
n = 1
}
images := []string{}
for _, m := range req.Messages {
images = append(images, m.StringImages...)
}
// get the model function to call for the result
predFunc, err := backend.ModelInference(req.Context, predInput, images, loader, *config, o, tokenCallback)
if err != nil {
return result, backend.TokenUsage{}, err
}
tokenUsage := backend.TokenUsage{}
for i := 0; i < n; i++ {
prediction, err := predFunc()
if err != nil {
return result, backend.TokenUsage{}, err
}
tokenUsage.Prompt += prediction.Usage.Prompt
tokenUsage.Completion += prediction.Usage.Completion
finetunedResponse := backend.Finetune(*config, predInput, prediction.Response)
cb(finetunedResponse, &result)
//result = append(result, Choice{Text: prediction})
}
return result, tokenUsage, err
}

View file

@ -0,0 +1,69 @@
package openai
import (
"regexp"
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/schema"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
)
func ListModelsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader) func(ctx *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
models, err := ml.ListModels()
if err != nil {
return err
}
var mm map[string]interface{} = map[string]interface{}{}
dataModels := []schema.OpenAIModel{}
var filterFn func(name string) bool
filter := c.Query("filter")
// If filter is not specified, do not filter the list by model name
if filter == "" {
filterFn = func(_ string) bool { return true }
} else {
// If filter _IS_ specified, we compile it to a regex which is used to create the filterFn
rxp, err := regexp.Compile(filter)
if err != nil {
return err
}
filterFn = func(name string) bool {
return rxp.MatchString(name)
}
}
// By default, exclude any loose files that are already referenced by a configuration file.
excludeConfigured := c.QueryBool("excludeConfigured", true)
// Start with the known configurations
for _, c := range cl.GetAllBackendConfigs() {
if excludeConfigured {
mm[c.Model] = nil
}
if filterFn(c.Name) {
dataModels = append(dataModels, schema.OpenAIModel{ID: c.Name, Object: "model"})
}
}
// Then iterate through the loose files:
for _, m := range models {
// And only adds them if they shouldn't be skipped.
if _, exists := mm[m]; !exists && filterFn(m) {
dataModels = append(dataModels, schema.OpenAIModel{ID: m, Object: "model"})
}
}
return c.JSON(struct {
Object string `json:"object"`
Data []schema.OpenAIModel `json:"data"`
}{
Object: "list",
Data: dataModels,
})
}
}

View file

@ -0,0 +1,281 @@
package openai
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"github.com/go-skynet/LocalAI/core/config"
fiberContext "github.com/go-skynet/LocalAI/core/http/ctx"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/go-skynet/LocalAI/pkg/grammar"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
)
func readRequest(c *fiber.Ctx, ml *model.ModelLoader, o *config.ApplicationConfig, firstModel bool) (string, *schema.OpenAIRequest, error) {
input := new(schema.OpenAIRequest)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return "", nil, fmt.Errorf("failed parsing request body: %w", err)
}
received, _ := json.Marshal(input)
ctx, cancel := context.WithCancel(o.Context)
input.Context = ctx
input.Cancel = cancel
log.Debug().Msgf("Request received: %s", string(received))
modelFile, err := fiberContext.ModelFromContext(c, ml, input.Model, firstModel)
return modelFile, input, err
}
// this function check if the string is an URL, if it's an URL downloads the image in memory
// encodes it in base64 and returns the base64 string
func getBase64Image(s string) (string, error) {
if strings.HasPrefix(s, "http") {
// download the image
resp, err := http.Get(s)
if err != nil {
return "", err
}
defer resp.Body.Close()
// read the image data into memory
data, err := io.ReadAll(resp.Body)
if err != nil {
return "", err
}
// encode the image data in base64
encoded := base64.StdEncoding.EncodeToString(data)
// return the base64 string
return encoded, nil
}
// if the string instead is prefixed with "data:image/jpeg;base64,", drop it
if strings.HasPrefix(s, "data:image/jpeg;base64,") {
return strings.ReplaceAll(s, "data:image/jpeg;base64,", ""), nil
}
return "", fmt.Errorf("not valid string")
}
func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIRequest) {
if input.Echo {
config.Echo = input.Echo
}
if input.TopK != 0 {
config.TopK = input.TopK
}
if input.TopP != 0 {
config.TopP = input.TopP
}
if input.Backend != "" {
config.Backend = input.Backend
}
if input.ClipSkip != 0 {
config.Diffusers.ClipSkip = input.ClipSkip
}
if input.ModelBaseName != "" {
config.AutoGPTQ.ModelBaseName = input.ModelBaseName
}
if input.NegativePromptScale != 0 {
config.NegativePromptScale = input.NegativePromptScale
}
if input.UseFastTokenizer {
config.UseFastTokenizer = input.UseFastTokenizer
}
if input.NegativePrompt != "" {
config.NegativePrompt = input.NegativePrompt
}
if input.RopeFreqBase != 0 {
config.RopeFreqBase = input.RopeFreqBase
}
if input.RopeFreqScale != 0 {
config.RopeFreqScale = input.RopeFreqScale
}
if input.Grammar != "" {
config.Grammar = input.Grammar
}
if input.Temperature != 0 {
config.Temperature = input.Temperature
}
if input.Maxtokens != 0 {
config.Maxtokens = input.Maxtokens
}
switch stop := input.Stop.(type) {
case string:
if stop != "" {
config.StopWords = append(config.StopWords, stop)
}
case []interface{}:
for _, pp := range stop {
if s, ok := pp.(string); ok {
config.StopWords = append(config.StopWords, s)
}
}
}
if len(input.Tools) > 0 {
for _, tool := range input.Tools {
input.Functions = append(input.Functions, tool.Function)
}
}
if input.ToolsChoice != nil {
var toolChoice grammar.Tool
json.Unmarshal([]byte(input.ToolsChoice.(string)), &toolChoice)
input.FunctionCall = map[string]interface{}{
"name": toolChoice.Function.Name,
}
}
// Decode each request's message content
index := 0
for i, m := range input.Messages {
switch content := m.Content.(type) {
case string:
input.Messages[i].StringContent = content
case []interface{}:
dat, _ := json.Marshal(content)
c := []schema.Content{}
json.Unmarshal(dat, &c)
for _, pp := range c {
if pp.Type == "text" {
input.Messages[i].StringContent = pp.Text
} else if pp.Type == "image_url" {
// Detect if pp.ImageURL is an URL, if it is download the image and encode it in base64:
base64, err := getBase64Image(pp.ImageURL.URL)
if err == nil {
input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff
// set a placeholder for each image
input.Messages[i].StringContent = fmt.Sprintf("[img-%d]", index) + input.Messages[i].StringContent
index++
} else {
fmt.Print("Failed encoding image", err)
}
}
}
}
}
if input.RepeatPenalty != 0 {
config.RepeatPenalty = input.RepeatPenalty
}
if input.Keep != 0 {
config.Keep = input.Keep
}
if input.Batch != 0 {
config.Batch = input.Batch
}
if input.F16 {
config.F16 = input.F16
}
if input.IgnoreEOS {
config.IgnoreEOS = input.IgnoreEOS
}
if input.Seed != 0 {
config.Seed = input.Seed
}
if input.Mirostat != 0 {
config.LLMConfig.Mirostat = input.Mirostat
}
if input.MirostatETA != 0 {
config.LLMConfig.MirostatETA = input.MirostatETA
}
if input.MirostatTAU != 0 {
config.LLMConfig.MirostatTAU = input.MirostatTAU
}
if input.TypicalP != 0 {
config.TypicalP = input.TypicalP
}
switch inputs := input.Input.(type) {
case string:
if inputs != "" {
config.InputStrings = append(config.InputStrings, inputs)
}
case []interface{}:
for _, pp := range inputs {
switch i := pp.(type) {
case string:
config.InputStrings = append(config.InputStrings, i)
case []interface{}:
tokens := []int{}
for _, ii := range i {
tokens = append(tokens, int(ii.(float64)))
}
config.InputToken = append(config.InputToken, tokens)
}
}
}
// Can be either a string or an object
switch fnc := input.FunctionCall.(type) {
case string:
if fnc != "" {
config.SetFunctionCallString(fnc)
}
case map[string]interface{}:
var name string
n, exists := fnc["name"]
if exists {
nn, e := n.(string)
if e {
name = nn
}
}
config.SetFunctionCallNameString(name)
}
switch p := input.Prompt.(type) {
case string:
config.PromptStrings = append(config.PromptStrings, p)
case []interface{}:
for _, pp := range p {
if s, ok := pp.(string); ok {
config.PromptStrings = append(config.PromptStrings, s)
}
}
}
}
func mergeRequestWithConfig(modelFile string, input *schema.OpenAIRequest, cm *config.BackendConfigLoader, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*config.BackendConfig, *schema.OpenAIRequest, error) {
cfg, err := config.LoadBackendConfigFileByName(modelFile, loader.ModelPath, cm, debug, threads, ctx, f16)
// Set the parameters for the language model prediction
updateRequestConfig(cfg, input)
return cfg, input, err
}

View file

@ -0,0 +1,71 @@
package openai
import (
"fmt"
"io"
"net/http"
"os"
"path"
"path/filepath"
"github.com/go-skynet/LocalAI/core/backend"
"github.com/go-skynet/LocalAI/core/config"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
)
// https://platform.openai.com/docs/api-reference/audio/create
func TranscriptEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
m, input, err := readRequest(c, ml, appConfig, false)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
config, input, err := mergeRequestWithConfig(m, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
// retrieve the file data from the request
file, err := c.FormFile("file")
if err != nil {
return err
}
f, err := file.Open()
if err != nil {
return err
}
defer f.Close()
dir, err := os.MkdirTemp("", "whisper")
if err != nil {
return err
}
defer os.RemoveAll(dir)
dst := filepath.Join(dir, path.Base(file.Filename))
dstFile, err := os.Create(dst)
if err != nil {
return err
}
if _, err := io.Copy(dstFile, f); err != nil {
log.Debug().Msgf("Audio file copying error %+v - %+v - err %+v", file.Filename, dst, err)
return err
}
log.Debug().Msgf("Audio file copied to: %+v", dst)
tr, err := backend.ModelTranscription(dst, input.Language, ml, *config, appConfig)
if err != nil {
return err
}
log.Debug().Msgf("Trascribed: %+v", tr)
// TODO: handle different outputs here
return c.Status(http.StatusOK).JSON(tr)
}
}