refactor: move remaining api packages to core (#1731)

* core 1

* api/openai/files fix

* core 2 - core/config

* move over core api.go and tests to the start of core/http

* move over localai specific endpoints to core/http, begin the service/endpoint split there

* refactor big chunk on the plane

* refactor chunk 2 on plane, next step: port and modify changes to request.go

* easy fixes for request.go, major changes not done yet

* lintfix

* json tag lintfix?

* gitignore and .keep files

* strange fix attempt: rename the config dir?
This commit is contained in:
Dave 2024-03-01 10:19:53 -05:00 committed by GitHub
parent 316de82f51
commit 1c312685aa
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
50 changed files with 1440 additions and 1206 deletions

View file

@ -1,43 +0,0 @@
package fiberContext
import (
"fmt"
"strings"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
)
// ModelFromContext returns the model from the context
// If no model is specified, it will take the first available
// Takes a model string as input which should be the one received from the user request.
// It returns the model name resolved from the context and an error if any.
func ModelFromContext(ctx *fiber.Ctx, loader *model.ModelLoader, modelInput string, firstModel bool) (string, error) {
if ctx.Params("model") != "" {
modelInput = ctx.Params("model")
}
// Set model from bearer token, if available
bearer := strings.TrimLeft(ctx.Get("authorization"), "Bearer ")
bearerExists := bearer != "" && loader.ExistsInModelPath(bearer)
// If no model was specified, take the first available
if modelInput == "" && !bearerExists && firstModel {
models, _ := loader.ListModels()
if len(models) > 0 {
modelInput = models[0]
log.Debug().Msgf("No model specified, using: %s", modelInput)
} else {
log.Debug().Msgf("No model specified, returning error")
return "", fmt.Errorf("no model specified")
}
}
// If a model is found in bearer token takes precedence
if bearerExists {
log.Debug().Msgf("Using model from bearer token: %s", bearer)
modelInput = bearer
}
return modelInput, nil
}

View file

@ -1,162 +0,0 @@
package localai
import (
"context"
"fmt"
"strings"
config "github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
"github.com/go-skynet/LocalAI/core/options"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
gopsutil "github.com/shirou/gopsutil/v3/process"
)
type BackendMonitorRequest struct {
Model string `json:"model" yaml:"model"`
}
type BackendMonitorResponse struct {
MemoryInfo *gopsutil.MemoryInfoStat
MemoryPercent float32
CPUPercent float64
}
type BackendMonitor struct {
configLoader *config.ConfigLoader
options *options.Option // Taking options in case we need to inspect ExternalGRPCBackends, though that's out of scope for now, hence the name.
}
func NewBackendMonitor(configLoader *config.ConfigLoader, options *options.Option) BackendMonitor {
return BackendMonitor{
configLoader: configLoader,
options: options,
}
}
func (bm *BackendMonitor) SampleLocalBackendProcess(model string) (*BackendMonitorResponse, error) {
config, exists := bm.configLoader.GetConfig(model)
var backend string
if exists {
backend = config.Model
} else {
// Last ditch effort: use it raw, see if a backend happens to match.
backend = model
}
if !strings.HasSuffix(backend, ".bin") {
backend = fmt.Sprintf("%s.bin", backend)
}
pid, err := bm.options.Loader.GetGRPCPID(backend)
if err != nil {
log.Error().Msgf("model %s : failed to find pid %+v", model, err)
return nil, err
}
// Name is slightly frightening but this does _not_ create a new process, rather it looks up an existing process by PID.
backendProcess, err := gopsutil.NewProcess(int32(pid))
if err != nil {
log.Error().Msgf("model %s [PID %d] : error getting process info %+v", model, pid, err)
return nil, err
}
memInfo, err := backendProcess.MemoryInfo()
if err != nil {
log.Error().Msgf("model %s [PID %d] : error getting memory info %+v", model, pid, err)
return nil, err
}
memPercent, err := backendProcess.MemoryPercent()
if err != nil {
log.Error().Msgf("model %s [PID %d] : error getting memory percent %+v", model, pid, err)
return nil, err
}
cpuPercent, err := backendProcess.CPUPercent()
if err != nil {
log.Error().Msgf("model %s [PID %d] : error getting cpu percent %+v", model, pid, err)
return nil, err
}
return &BackendMonitorResponse{
MemoryInfo: memInfo,
MemoryPercent: memPercent,
CPUPercent: cpuPercent,
}, nil
}
func (bm BackendMonitor) getModelLoaderIDFromCtx(c *fiber.Ctx) (string, error) {
input := new(BackendMonitorRequest)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return "", err
}
config, exists := bm.configLoader.GetConfig(input.Model)
var backendId string
if exists {
backendId = config.Model
} else {
// Last ditch effort: use it raw, see if a backend happens to match.
backendId = input.Model
}
if !strings.HasSuffix(backendId, ".bin") {
backendId = fmt.Sprintf("%s.bin", backendId)
}
return backendId, nil
}
func BackendMonitorEndpoint(bm BackendMonitor) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
backendId, err := bm.getModelLoaderIDFromCtx(c)
if err != nil {
return err
}
model := bm.options.Loader.CheckIsLoaded(backendId)
if model == "" {
return fmt.Errorf("backend %s is not currently loaded", backendId)
}
status, rpcErr := model.GRPC(false, nil).Status(context.TODO())
if rpcErr != nil {
log.Warn().Msgf("backend %s experienced an error retrieving status info: %s", backendId, rpcErr.Error())
val, slbErr := bm.SampleLocalBackendProcess(backendId)
if slbErr != nil {
return fmt.Errorf("backend %s experienced an error retrieving status info via rpc: %s, then failed local node process sample: %s", backendId, rpcErr.Error(), slbErr.Error())
}
return c.JSON(proto.StatusResponse{
State: proto.StatusResponse_ERROR,
Memory: &proto.MemoryUsageData{
Total: val.MemoryInfo.VMS,
Breakdown: map[string]uint64{
"gopsutil-RSS": val.MemoryInfo.RSS,
},
},
})
}
return c.JSON(status)
}
}
func BackendShutdownEndpoint(bm BackendMonitor) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
backendId, err := bm.getModelLoaderIDFromCtx(c)
if err != nil {
return err
}
return bm.options.Loader.ShutdownModel(backendId)
}
}

View file

@ -1,326 +0,0 @@
package localai
import (
"context"
"fmt"
"os"
"slices"
"strings"
"sync"
json "github.com/json-iterator/go"
"gopkg.in/yaml.v3"
config "github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/pkg/gallery"
"github.com/go-skynet/LocalAI/pkg/utils"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/rs/zerolog/log"
)
type galleryOp struct {
req gallery.GalleryModel
id string
galleries []gallery.Gallery
galleryName string
}
type galleryOpStatus struct {
FileName string `json:"file_name"`
Error error `json:"error"`
Processed bool `json:"processed"`
Message string `json:"message"`
Progress float64 `json:"progress"`
TotalFileSize string `json:"file_size"`
DownloadedFileSize string `json:"downloaded_size"`
}
type galleryApplier struct {
modelPath string
sync.Mutex
C chan galleryOp
statuses map[string]*galleryOpStatus
}
func NewGalleryService(modelPath string) *galleryApplier {
return &galleryApplier{
modelPath: modelPath,
C: make(chan galleryOp),
statuses: make(map[string]*galleryOpStatus),
}
}
func prepareModel(modelPath string, req gallery.GalleryModel, cm *config.ConfigLoader, downloadStatus func(string, string, string, float64)) error {
config, err := gallery.GetGalleryConfigFromURL(req.URL)
if err != nil {
return err
}
config.Files = append(config.Files, req.AdditionalFiles...)
return gallery.InstallModel(modelPath, req.Name, &config, req.Overrides, downloadStatus)
}
func (g *galleryApplier) updateStatus(s string, op *galleryOpStatus) {
g.Lock()
defer g.Unlock()
g.statuses[s] = op
}
func (g *galleryApplier) getStatus(s string) *galleryOpStatus {
g.Lock()
defer g.Unlock()
return g.statuses[s]
}
func (g *galleryApplier) getAllStatus() map[string]*galleryOpStatus {
g.Lock()
defer g.Unlock()
return g.statuses
}
func (g *galleryApplier) Start(c context.Context, cm *config.ConfigLoader) {
go func() {
for {
select {
case <-c.Done():
return
case op := <-g.C:
utils.ResetDownloadTimers()
g.updateStatus(op.id, &galleryOpStatus{Message: "processing", Progress: 0})
// updates the status with an error
updateError := func(e error) {
g.updateStatus(op.id, &galleryOpStatus{Error: e, Processed: true, Message: "error: " + e.Error()})
}
// displayDownload displays the download progress
progressCallback := func(fileName string, current string, total string, percentage float64) {
g.updateStatus(op.id, &galleryOpStatus{Message: "processing", FileName: fileName, Progress: percentage, TotalFileSize: total, DownloadedFileSize: current})
utils.DisplayDownloadFunction(fileName, current, total, percentage)
}
var err error
// if the request contains a gallery name, we apply the gallery from the gallery list
if op.galleryName != "" {
if strings.Contains(op.galleryName, "@") {
err = gallery.InstallModelFromGallery(op.galleries, op.galleryName, g.modelPath, op.req, progressCallback)
} else {
err = gallery.InstallModelFromGalleryByName(op.galleries, op.galleryName, g.modelPath, op.req, progressCallback)
}
} else {
err = prepareModel(g.modelPath, op.req, cm, progressCallback)
}
if err != nil {
updateError(err)
continue
}
// Reload models
err = cm.LoadConfigs(g.modelPath)
if err != nil {
updateError(err)
continue
}
err = cm.Preload(g.modelPath)
if err != nil {
updateError(err)
continue
}
g.updateStatus(op.id, &galleryOpStatus{Processed: true, Message: "completed", Progress: 100})
}
}
}()
}
type galleryModel struct {
gallery.GalleryModel `yaml:",inline"` // https://github.com/go-yaml/yaml/issues/63
ID string `json:"id"`
}
func processRequests(modelPath, s string, cm *config.ConfigLoader, galleries []gallery.Gallery, requests []galleryModel) error {
var err error
for _, r := range requests {
utils.ResetDownloadTimers()
if r.ID == "" {
err = prepareModel(modelPath, r.GalleryModel, cm, utils.DisplayDownloadFunction)
} else {
if strings.Contains(r.ID, "@") {
err = gallery.InstallModelFromGallery(
galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction)
} else {
err = gallery.InstallModelFromGalleryByName(
galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction)
}
}
}
return err
}
func ApplyGalleryFromFile(modelPath, s string, cm *config.ConfigLoader, galleries []gallery.Gallery) error {
dat, err := os.ReadFile(s)
if err != nil {
return err
}
var requests []galleryModel
if err := yaml.Unmarshal(dat, &requests); err != nil {
return err
}
return processRequests(modelPath, s, cm, galleries, requests)
}
func ApplyGalleryFromString(modelPath, s string, cm *config.ConfigLoader, galleries []gallery.Gallery) error {
var requests []galleryModel
err := json.Unmarshal([]byte(s), &requests)
if err != nil {
return err
}
return processRequests(modelPath, s, cm, galleries, requests)
}
/// Endpoint Service
type ModelGalleryService struct {
galleries []gallery.Gallery
modelPath string
galleryApplier *galleryApplier
}
type GalleryModel struct {
ID string `json:"id"`
gallery.GalleryModel
}
func CreateModelGalleryService(galleries []gallery.Gallery, modelPath string, galleryApplier *galleryApplier) ModelGalleryService {
return ModelGalleryService{
galleries: galleries,
modelPath: modelPath,
galleryApplier: galleryApplier,
}
}
func (mgs *ModelGalleryService) GetOpStatusEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
status := mgs.galleryApplier.getStatus(c.Params("uuid"))
if status == nil {
return fmt.Errorf("could not find any status for ID")
}
return c.JSON(status)
}
}
func (mgs *ModelGalleryService) GetAllStatusEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
return c.JSON(mgs.galleryApplier.getAllStatus())
}
}
func (mgs *ModelGalleryService) ApplyModelGalleryEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(GalleryModel)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}
uuid, err := uuid.NewUUID()
if err != nil {
return err
}
mgs.galleryApplier.C <- galleryOp{
req: input.GalleryModel,
id: uuid.String(),
galleryName: input.ID,
galleries: mgs.galleries,
}
return c.JSON(struct {
ID string `json:"uuid"`
StatusURL string `json:"status"`
}{ID: uuid.String(), StatusURL: c.BaseURL() + "/models/jobs/" + uuid.String()})
}
}
func (mgs *ModelGalleryService) ListModelFromGalleryEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
log.Debug().Msgf("Listing models from galleries: %+v", mgs.galleries)
models, err := gallery.AvailableGalleryModels(mgs.galleries, mgs.modelPath)
if err != nil {
return err
}
log.Debug().Msgf("Models found from galleries: %+v", models)
for _, m := range models {
log.Debug().Msgf("Model found from galleries: %+v", m)
}
dat, err := json.Marshal(models)
if err != nil {
return err
}
return c.Send(dat)
}
}
// NOTE: This is different (and much simpler!) than above! This JUST lists the model galleries that have been loaded, not their contents!
func (mgs *ModelGalleryService) ListModelGalleriesEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
log.Debug().Msgf("Listing model galleries %+v", mgs.galleries)
dat, err := json.Marshal(mgs.galleries)
if err != nil {
return err
}
return c.Send(dat)
}
}
func (mgs *ModelGalleryService) AddModelGalleryEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(gallery.Gallery)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}
if slices.ContainsFunc(mgs.galleries, func(gallery gallery.Gallery) bool {
return gallery.Name == input.Name
}) {
return fmt.Errorf("%s already exists", input.Name)
}
dat, err := json.Marshal(mgs.galleries)
if err != nil {
return err
}
log.Debug().Msgf("Adding %+v to gallery list", *input)
mgs.galleries = append(mgs.galleries, *input)
return c.Send(dat)
}
}
func (mgs *ModelGalleryService) RemoveModelGalleryEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(gallery.Gallery)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}
if !slices.ContainsFunc(mgs.galleries, func(gallery gallery.Gallery) bool {
return gallery.Name == input.Name
}) {
return fmt.Errorf("%s is not currently registered", input.Name)
}
mgs.galleries = slices.DeleteFunc(mgs.galleries, func(gallery gallery.Gallery) bool {
return gallery.Name == input.Name
})
return c.Send(nil)
}
}

View file

@ -1,53 +0,0 @@
package localai
import (
fiberContext "github.com/go-skynet/LocalAI/api/ctx"
"github.com/go-skynet/LocalAI/core/backend"
config "github.com/go-skynet/LocalAI/core/config"
"github.com/rs/zerolog/log"
"github.com/go-skynet/LocalAI/core/options"
"github.com/gofiber/fiber/v2"
)
type TTSRequest struct {
Model string `json:"model" yaml:"model"`
Input string `json:"input" yaml:"input"`
Backend string `json:"backend" yaml:"backend"`
}
func TTSEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(TTSRequest)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}
modelFile, err := fiberContext.ModelFromContext(c, o.Loader, input.Model, false)
if err != nil {
modelFile = input.Model
log.Warn().Msgf("Model not found in context: %s", input.Model)
}
cfg, err := config.Load(modelFile, o.Loader.ModelPath, cm, false, 0, 0, false)
if err != nil {
modelFile = input.Model
log.Warn().Msgf("Model not found in context: %s", input.Model)
} else {
modelFile = cfg.Model
}
log.Debug().Msgf("Request for model: %s", modelFile)
if input.Backend != "" {
cfg.Backend = input.Backend
}
filePath, _, err := backend.ModelTTS(cfg.Backend, input.Input, modelFile, o.Loader, o, *cfg)
if err != nil {
return err
}
return c.Download(filePath)
}
}

View file

@ -1,610 +0,0 @@
package openai
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/go-skynet/LocalAI/core/backend"
config "github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/options"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/go-skynet/LocalAI/pkg/grammar"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/utils"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/rs/zerolog/log"
"github.com/valyala/fasthttp"
)
func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
emptyMessage := ""
id := uuid.New().String()
created := int(time.Now().Unix())
process := func(s string, req *schema.OpenAIRequest, config *config.Config, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
initialMessage := schema.OpenAIResponse{
ID: id,
Created: created,
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant", Content: &emptyMessage}}},
Object: "chat.completion.chunk",
}
responses <- initialMessage
ComputeChoices(req, s, config, o, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
resp := schema.OpenAIResponse{
ID: id,
Created: created,
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{{Delta: &schema.Message{Content: &s}, Index: 0}},
Object: "chat.completion.chunk",
Usage: schema.OpenAIUsage{
PromptTokens: usage.Prompt,
CompletionTokens: usage.Completion,
TotalTokens: usage.Prompt + usage.Completion,
},
}
responses <- resp
return true
})
close(responses)
}
processTools := func(noAction string, prompt string, req *schema.OpenAIRequest, config *config.Config, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
result := ""
_, tokenUsage, _ := ComputeChoices(req, prompt, config, o, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
result += s
// TODO: Change generated BNF grammar to be compliant with the schema so we can
// stream the result token by token here.
return true
})
results := parseFunctionCall(result, config.FunctionsConfig.ParallelCalls)
noActionToRun := len(results) > 0 && results[0].name == noAction
switch {
case noActionToRun:
initialMessage := schema.OpenAIResponse{
ID: id,
Created: created,
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant", Content: &emptyMessage}}},
Object: "chat.completion.chunk",
}
responses <- initialMessage
result, err := handleQuestion(config, req, o, results[0].arguments, prompt)
if err != nil {
log.Error().Msgf("error handling question: %s", err.Error())
return
}
resp := schema.OpenAIResponse{
ID: id,
Created: created,
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{{Delta: &schema.Message{Content: &result}, Index: 0}},
Object: "chat.completion.chunk",
Usage: schema.OpenAIUsage{
PromptTokens: tokenUsage.Prompt,
CompletionTokens: tokenUsage.Completion,
TotalTokens: tokenUsage.Prompt + tokenUsage.Completion,
},
}
responses <- resp
default:
for i, ss := range results {
name, args := ss.name, ss.arguments
initialMessage := schema.OpenAIResponse{
ID: id,
Created: created,
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{{
Delta: &schema.Message{
Role: "assistant",
ToolCalls: []schema.ToolCall{
{
Index: i,
ID: id,
Type: "function",
FunctionCall: schema.FunctionCall{
Name: name,
},
},
},
}}},
Object: "chat.completion.chunk",
}
responses <- initialMessage
responses <- schema.OpenAIResponse{
ID: id,
Created: created,
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{{
Delta: &schema.Message{
Role: "assistant",
ToolCalls: []schema.ToolCall{
{
Index: i,
ID: id,
Type: "function",
FunctionCall: schema.FunctionCall{
Arguments: args,
},
},
},
}}},
Object: "chat.completion.chunk",
}
}
}
close(responses)
}
return func(c *fiber.Ctx) error {
processFunctions := false
funcs := grammar.Functions{}
modelFile, input, err := readRequest(c, o, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
config, input, err := mergeRequestWithConfig(modelFile, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
log.Debug().Msgf("Configuration read: %+v", config)
// Allow the user to set custom actions via config file
// to be "embedded" in each model
noActionName := "answer"
noActionDescription := "use this action to answer without performing any action"
if config.FunctionsConfig.NoActionFunctionName != "" {
noActionName = config.FunctionsConfig.NoActionFunctionName
}
if config.FunctionsConfig.NoActionDescriptionName != "" {
noActionDescription = config.FunctionsConfig.NoActionDescriptionName
}
if input.ResponseFormat.Type == "json_object" {
input.Grammar = grammar.JSONBNF
}
// process functions if we have any defined or if we have a function call string
if len(input.Functions) > 0 && config.ShouldUseFunctions() {
log.Debug().Msgf("Response needs to process functions")
processFunctions = true
noActionGrammar := grammar.Function{
Name: noActionName,
Description: noActionDescription,
Parameters: map[string]interface{}{
"properties": map[string]interface{}{
"message": map[string]interface{}{
"type": "string",
"description": "The message to reply the user with",
}},
},
}
// Append the no action function
funcs = append(funcs, input.Functions...)
if !config.FunctionsConfig.DisableNoAction {
funcs = append(funcs, noActionGrammar)
}
// Force picking one of the functions by the request
if config.FunctionToCall() != "" {
funcs = funcs.Select(config.FunctionToCall())
}
// Update input grammar
jsStruct := funcs.ToJSONStructure()
config.Grammar = jsStruct.Grammar("", config.FunctionsConfig.ParallelCalls)
} else if input.JSONFunctionGrammarObject != nil {
config.Grammar = input.JSONFunctionGrammarObject.Grammar("", config.FunctionsConfig.ParallelCalls)
}
// functions are not supported in stream mode (yet?)
toStream := input.Stream
log.Debug().Msgf("Parameters: %+v", config)
var predInput string
suppressConfigSystemPrompt := false
mess := []string{}
for messageIndex, i := range input.Messages {
var content string
role := i.Role
// if function call, we might want to customize the role so we can display better that the "assistant called a json action"
// if an "assistant_function_call" role is defined, we use it, otherwise we use the role that is passed by in the request
if i.FunctionCall != nil && i.Role == "assistant" {
roleFn := "assistant_function_call"
r := config.Roles[roleFn]
if r != "" {
role = roleFn
}
}
r := config.Roles[role]
contentExists := i.Content != nil && i.StringContent != ""
// First attempt to populate content via a chat message specific template
if config.TemplateConfig.ChatMessage != "" {
chatMessageData := model.ChatMessageTemplateData{
SystemPrompt: config.SystemPrompt,
Role: r,
RoleName: role,
Content: i.StringContent,
FunctionName: i.Name,
MessageIndex: messageIndex,
}
templatedChatMessage, err := o.Loader.EvaluateTemplateForChatMessage(config.TemplateConfig.ChatMessage, chatMessageData)
if err != nil {
log.Error().Msgf("error processing message %+v using template \"%s\": %v. Skipping!", chatMessageData, config.TemplateConfig.ChatMessage, err)
} else {
if templatedChatMessage == "" {
log.Warn().Msgf("template \"%s\" produced blank output for %+v. Skipping!", config.TemplateConfig.ChatMessage, chatMessageData)
continue // TODO: This continue is here intentionally to skip over the line `mess = append(mess, content)` below, and to prevent the sprintf
}
log.Debug().Msgf("templated message for chat: %s", templatedChatMessage)
content = templatedChatMessage
}
}
// If this model doesn't have such a template, or if that template fails to return a value, template at the message level.
if content == "" {
if r != "" {
if contentExists {
content = fmt.Sprint(r, i.StringContent)
}
if i.FunctionCall != nil {
j, err := json.Marshal(i.FunctionCall)
if err == nil {
if contentExists {
content += "\n" + fmt.Sprint(r, " ", string(j))
} else {
content = fmt.Sprint(r, " ", string(j))
}
}
}
} else {
if contentExists {
content = fmt.Sprint(i.StringContent)
}
if i.FunctionCall != nil {
j, err := json.Marshal(i.FunctionCall)
if err == nil {
if contentExists {
content += "\n" + string(j)
} else {
content = string(j)
}
}
}
}
// Special Handling: System. We care if it was printed at all, not the r branch, so check seperately
if contentExists && role == "system" {
suppressConfigSystemPrompt = true
}
}
mess = append(mess, content)
}
predInput = strings.Join(mess, "\n")
log.Debug().Msgf("Prompt (before templating): %s", predInput)
if toStream {
log.Debug().Msgf("Stream request received")
c.Context().SetContentType("text/event-stream")
//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
// c.Set("Content-Type", "text/event-stream")
c.Set("Cache-Control", "no-cache")
c.Set("Connection", "keep-alive")
c.Set("Transfer-Encoding", "chunked")
}
templateFile := ""
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
if o.Loader.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
templateFile = config.Model
}
if config.TemplateConfig.Chat != "" && !processFunctions {
templateFile = config.TemplateConfig.Chat
}
if config.TemplateConfig.Functions != "" && processFunctions {
templateFile = config.TemplateConfig.Functions
}
if templateFile != "" {
templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.ChatPromptTemplate, templateFile, model.PromptTemplateData{
SystemPrompt: config.SystemPrompt,
SuppressSystemPrompt: suppressConfigSystemPrompt,
Input: predInput,
Functions: funcs,
})
if err == nil {
predInput = templatedInput
log.Debug().Msgf("Template found, input modified to: %s", predInput)
} else {
log.Debug().Msgf("Template failed loading: %s", err.Error())
}
}
log.Debug().Msgf("Prompt (after templating): %s", predInput)
if processFunctions {
log.Debug().Msgf("Grammar: %+v", config.Grammar)
}
switch {
case toStream:
responses := make(chan schema.OpenAIResponse)
if !processFunctions {
go process(predInput, input, config, o.Loader, responses)
} else {
go processTools(noActionName, predInput, input, config, o.Loader, responses)
}
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
usage := &schema.OpenAIUsage{}
toolsCalled := false
for ev := range responses {
usage = &ev.Usage // Copy a pointer to the latest usage chunk so that the stop message can reference it
if len(ev.Choices[0].Delta.ToolCalls) > 0 {
toolsCalled = true
}
var buf bytes.Buffer
enc := json.NewEncoder(&buf)
enc.Encode(ev)
log.Debug().Msgf("Sending chunk: %s", buf.String())
_, err := fmt.Fprintf(w, "data: %v\n", buf.String())
if err != nil {
log.Debug().Msgf("Sending chunk failed: %v", err)
input.Cancel()
break
}
w.Flush()
}
finishReason := "stop"
if toolsCalled {
finishReason = "tool_calls"
} else if toolsCalled && len(input.Tools) == 0 {
finishReason = "function_call"
}
resp := &schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{
{
FinishReason: finishReason,
Index: 0,
Delta: &schema.Message{Content: &emptyMessage},
}},
Object: "chat.completion.chunk",
Usage: *usage,
}
respData, _ := json.Marshal(resp)
w.WriteString(fmt.Sprintf("data: %s\n\n", respData))
w.WriteString("data: [DONE]\n\n")
w.Flush()
}))
return nil
// no streaming mode
default:
result, tokenUsage, err := ComputeChoices(input, predInput, config, o, o.Loader, func(s string, c *[]schema.Choice) {
if !processFunctions {
// no function is called, just reply and use stop as finish reason
*c = append(*c, schema.Choice{FinishReason: "stop", Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}})
return
}
results := parseFunctionCall(s, config.FunctionsConfig.ParallelCalls)
noActionsToRun := len(results) > 0 && results[0].name == noActionName
switch {
case noActionsToRun:
result, err := handleQuestion(config, input, o, results[0].arguments, predInput)
if err != nil {
log.Error().Msgf("error handling question: %s", err.Error())
return
}
*c = append(*c, schema.Choice{
Message: &schema.Message{Role: "assistant", Content: &result}})
default:
toolChoice := schema.Choice{
Message: &schema.Message{
Role: "assistant",
},
}
if len(input.Tools) > 0 {
toolChoice.FinishReason = "tool_calls"
}
for _, ss := range results {
name, args := ss.name, ss.arguments
if len(input.Tools) > 0 {
// If we are using tools, we condense the function calls into
// a single response choice with all the tools
toolChoice.Message.ToolCalls = append(toolChoice.Message.ToolCalls,
schema.ToolCall{
ID: id,
Type: "function",
FunctionCall: schema.FunctionCall{
Name: name,
Arguments: args,
},
},
)
} else {
// otherwise we return more choices directly
*c = append(*c, schema.Choice{
FinishReason: "function_call",
Message: &schema.Message{
Role: "assistant",
FunctionCall: map[string]interface{}{
"name": name,
"arguments": args,
},
},
})
}
}
if len(input.Tools) > 0 {
// we need to append our result if we are using tools
*c = append(*c, toolChoice)
}
}
}, nil)
if err != nil {
return err
}
resp := &schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: result,
Object: "chat.completion",
Usage: schema.OpenAIUsage{
PromptTokens: tokenUsage.Prompt,
CompletionTokens: tokenUsage.Completion,
TotalTokens: tokenUsage.Prompt + tokenUsage.Completion,
},
}
respData, _ := json.Marshal(resp)
log.Debug().Msgf("Response: %s", respData)
// Return the prediction in the response body
return c.JSON(resp)
}
}
}
func handleQuestion(config *config.Config, input *schema.OpenAIRequest, o *options.Option, args, prompt string) (string, error) {
log.Debug().Msgf("nothing to do, computing a reply")
// If there is a message that the LLM already sends as part of the JSON reply, use it
arguments := map[string]interface{}{}
json.Unmarshal([]byte(args), &arguments)
m, exists := arguments["message"]
if exists {
switch message := m.(type) {
case string:
if message != "" {
log.Debug().Msgf("Reply received from LLM: %s", message)
message = backend.Finetune(*config, prompt, message)
log.Debug().Msgf("Reply received from LLM(finetuned): %s", message)
return message, nil
}
}
}
log.Debug().Msgf("No action received from LLM, without a message, computing a reply")
// Otherwise ask the LLM to understand the JSON output and the context, and return a message
// Note: This costs (in term of CPU/GPU) another computation
config.Grammar = ""
images := []string{}
for _, m := range input.Messages {
images = append(images, m.StringImages...)
}
predFunc, err := backend.ModelInference(input.Context, prompt, images, o.Loader, *config, o, nil)
if err != nil {
log.Error().Msgf("inference error: %s", err.Error())
return "", err
}
prediction, err := predFunc()
if err != nil {
log.Error().Msgf("inference error: %s", err.Error())
return "", err
}
return backend.Finetune(*config, prompt, prediction.Response), nil
}
type funcCallResults struct {
name string
arguments string
}
func parseFunctionCall(llmresult string, multipleResults bool) []funcCallResults {
results := []funcCallResults{}
// TODO: use generics to avoid this code duplication
if multipleResults {
ss := []map[string]interface{}{}
s := utils.EscapeNewLines(llmresult)
json.Unmarshal([]byte(s), &ss)
log.Debug().Msgf("Function return: %s %+v", s, ss)
for _, s := range ss {
func_name, ok := s["function"]
if !ok {
continue
}
args, ok := s["arguments"]
if !ok {
continue
}
d, _ := json.Marshal(args)
funcName, ok := func_name.(string)
if !ok {
continue
}
results = append(results, funcCallResults{name: funcName, arguments: string(d)})
}
} else {
// As we have to change the result before processing, we can't stream the answer token-by-token (yet?)
ss := map[string]interface{}{}
// This prevent newlines to break JSON parsing for clients
s := utils.EscapeNewLines(llmresult)
json.Unmarshal([]byte(s), &ss)
log.Debug().Msgf("Function return: %s %+v", s, ss)
// The grammar defines the function name as "function", while OpenAI returns "name"
func_name, ok := ss["function"]
if !ok {
return results
}
// Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object
args, ok := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix)
if !ok {
return results
}
d, _ := json.Marshal(args)
funcName, ok := func_name.(string)
if !ok {
return results
}
results = append(results, funcCallResults{name: funcName, arguments: string(d)})
}
return results
}

View file

@ -1,199 +0,0 @@
package openai
import (
"bufio"
"bytes"
"encoding/json"
"errors"
"fmt"
"time"
"github.com/go-skynet/LocalAI/core/backend"
config "github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/options"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/go-skynet/LocalAI/pkg/grammar"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/rs/zerolog/log"
"github.com/valyala/fasthttp"
)
// https://platform.openai.com/docs/api-reference/completions
func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
id := uuid.New().String()
created := int(time.Now().Unix())
process := func(s string, req *schema.OpenAIRequest, config *config.Config, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
ComputeChoices(req, s, config, o, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
resp := schema.OpenAIResponse{
ID: id,
Created: created,
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{
{
Index: 0,
Text: s,
},
},
Object: "text_completion",
Usage: schema.OpenAIUsage{
PromptTokens: usage.Prompt,
CompletionTokens: usage.Completion,
TotalTokens: usage.Prompt + usage.Completion,
},
}
log.Debug().Msgf("Sending goroutine: %s", s)
responses <- resp
return true
})
close(responses)
}
return func(c *fiber.Ctx) error {
modelFile, input, err := readRequest(c, o, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
log.Debug().Msgf("`input`: %+v", input)
config, input, err := mergeRequestWithConfig(modelFile, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
if input.ResponseFormat.Type == "json_object" {
input.Grammar = grammar.JSONBNF
}
log.Debug().Msgf("Parameter Config: %+v", config)
if input.Stream {
log.Debug().Msgf("Stream request received")
c.Context().SetContentType("text/event-stream")
//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
//c.Set("Content-Type", "text/event-stream")
c.Set("Cache-Control", "no-cache")
c.Set("Connection", "keep-alive")
c.Set("Transfer-Encoding", "chunked")
}
templateFile := ""
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
if o.Loader.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
templateFile = config.Model
}
if config.TemplateConfig.Completion != "" {
templateFile = config.TemplateConfig.Completion
}
if input.Stream {
if len(config.PromptStrings) > 1 {
return errors.New("cannot handle more than 1 `PromptStrings` when Streaming")
}
predInput := config.PromptStrings[0]
if templateFile != "" {
templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{
Input: predInput,
})
if err == nil {
predInput = templatedInput
log.Debug().Msgf("Template found, input modified to: %s", predInput)
}
}
responses := make(chan schema.OpenAIResponse)
go process(predInput, input, config, o.Loader, responses)
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
for ev := range responses {
var buf bytes.Buffer
enc := json.NewEncoder(&buf)
enc.Encode(ev)
log.Debug().Msgf("Sending chunk: %s", buf.String())
fmt.Fprintf(w, "data: %v\n", buf.String())
w.Flush()
}
resp := &schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{
{
Index: 0,
FinishReason: "stop",
},
},
Object: "text_completion",
}
respData, _ := json.Marshal(resp)
w.WriteString(fmt.Sprintf("data: %s\n\n", respData))
w.WriteString("data: [DONE]\n\n")
w.Flush()
}))
return nil
}
var result []schema.Choice
totalTokenUsage := backend.TokenUsage{}
for k, i := range config.PromptStrings {
if templateFile != "" {
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{
SystemPrompt: config.SystemPrompt,
Input: i,
})
if err == nil {
i = templatedInput
log.Debug().Msgf("Template found, input modified to: %s", i)
}
}
r, tokenUsage, err := ComputeChoices(
input, i, config, o, o.Loader, func(s string, c *[]schema.Choice) {
*c = append(*c, schema.Choice{Text: s, FinishReason: "stop", Index: k})
}, nil)
if err != nil {
return err
}
totalTokenUsage.Prompt += tokenUsage.Prompt
totalTokenUsage.Completion += tokenUsage.Completion
result = append(result, r...)
}
resp := &schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: result,
Object: "text_completion",
Usage: schema.OpenAIUsage{
PromptTokens: totalTokenUsage.Prompt,
CompletionTokens: totalTokenUsage.Completion,
TotalTokens: totalTokenUsage.Prompt + totalTokenUsage.Completion,
},
}
jsonResult, _ := json.Marshal(resp)
log.Debug().Msgf("Response: %s", jsonResult)
// Return the prediction in the response body
return c.JSON(resp)
}
}

View file

@ -1,94 +0,0 @@
package openai
import (
"encoding/json"
"fmt"
"time"
"github.com/go-skynet/LocalAI/core/backend"
config "github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/options"
"github.com/go-skynet/LocalAI/core/schema"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/rs/zerolog/log"
)
func EditEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
modelFile, input, err := readRequest(c, o, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
config, input, err := mergeRequestWithConfig(modelFile, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
log.Debug().Msgf("Parameter Config: %+v", config)
templateFile := ""
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
if o.Loader.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
templateFile = config.Model
}
if config.TemplateConfig.Edit != "" {
templateFile = config.TemplateConfig.Edit
}
var result []schema.Choice
totalTokenUsage := backend.TokenUsage{}
for _, i := range config.InputStrings {
if templateFile != "" {
templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.EditPromptTemplate, templateFile, model.PromptTemplateData{
Input: i,
Instruction: input.Instruction,
SystemPrompt: config.SystemPrompt,
})
if err == nil {
i = templatedInput
log.Debug().Msgf("Template found, input modified to: %s", i)
}
}
r, tokenUsage, err := ComputeChoices(input, i, config, o, o.Loader, func(s string, c *[]schema.Choice) {
*c = append(*c, schema.Choice{Text: s})
}, nil)
if err != nil {
return err
}
totalTokenUsage.Prompt += tokenUsage.Prompt
totalTokenUsage.Completion += tokenUsage.Completion
result = append(result, r...)
}
id := uuid.New().String()
created := int(time.Now().Unix())
resp := &schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: result,
Object: "edit",
Usage: schema.OpenAIUsage{
PromptTokens: totalTokenUsage.Prompt,
CompletionTokens: totalTokenUsage.Completion,
TotalTokens: totalTokenUsage.Prompt + totalTokenUsage.Completion,
},
}
jsonResult, _ := json.Marshal(resp)
log.Debug().Msgf("Response: %s", jsonResult)
// Return the prediction in the response body
return c.JSON(resp)
}
}

View file

@ -1,78 +0,0 @@
package openai
import (
"encoding/json"
"fmt"
"time"
"github.com/go-skynet/LocalAI/core/backend"
config "github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/google/uuid"
"github.com/go-skynet/LocalAI/core/options"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
)
// https://platform.openai.com/docs/api-reference/embeddings
func EmbeddingsEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
model, input, err := readRequest(c, o, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
config, input, err := mergeRequestWithConfig(model, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
log.Debug().Msgf("Parameter Config: %+v", config)
items := []schema.Item{}
for i, s := range config.InputToken {
// get the model function to call for the result
embedFn, err := backend.ModelEmbedding("", s, o.Loader, *config, o)
if err != nil {
return err
}
embeddings, err := embedFn()
if err != nil {
return err
}
items = append(items, schema.Item{Embedding: embeddings, Index: i, Object: "embedding"})
}
for i, s := range config.InputStrings {
// get the model function to call for the result
embedFn, err := backend.ModelEmbedding(s, []int{}, o.Loader, *config, o)
if err != nil {
return err
}
embeddings, err := embedFn()
if err != nil {
return err
}
items = append(items, schema.Item{Embedding: embeddings, Index: i, Object: "embedding"})
}
id := uuid.New().String()
created := int(time.Now().Unix())
resp := &schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Data: items,
Object: "list",
}
jsonResult, _ := json.Marshal(resp)
log.Debug().Msgf("Response: %s", jsonResult)
// Return the prediction in the response body
return c.JSON(resp)
}
}

View file

@ -1,218 +0,0 @@
package openai
import (
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
"time"
config "github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/options"
"github.com/go-skynet/LocalAI/pkg/utils"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
)
var uploadedFiles []File
const uploadedFilesFile = "uploadedFiles.json"
// File represents the structure of a file object from the OpenAI API.
type File struct {
ID string `json:"id"` // Unique identifier for the file
Object string `json:"object"` // Type of the object (e.g., "file")
Bytes int `json:"bytes"` // Size of the file in bytes
CreatedAt time.Time `json:"created_at"` // The time at which the file was created
Filename string `json:"filename"` // The name of the file
Purpose string `json:"purpose"` // The purpose of the file (e.g., "fine-tune", "classifications", etc.)
}
func saveUploadConfig(uploadDir string) {
file, err := json.MarshalIndent(uploadedFiles, "", " ")
if err != nil {
log.Error().Msgf("Failed to JSON marshal the uploadedFiles: %s", err)
}
err = os.WriteFile(filepath.Join(uploadDir, uploadedFilesFile), file, 0644)
if err != nil {
log.Error().Msgf("Failed to save uploadedFiles to file: %s", err)
}
}
func LoadUploadConfig(uploadPath string) {
uploadFilePath := filepath.Join(uploadPath, uploadedFilesFile)
_, err := os.Stat(uploadFilePath)
if os.IsNotExist(err) {
log.Debug().Msgf("No uploadedFiles file found at %s", uploadFilePath)
return
}
file, err := os.ReadFile(uploadFilePath)
if err != nil {
log.Error().Msgf("Failed to read file: %s", err)
} else {
err = json.Unmarshal(file, &uploadedFiles)
if err != nil {
log.Error().Msgf("Failed to JSON unmarshal the file into uploadedFiles: %s", err)
}
}
}
// UploadFilesEndpoint https://platform.openai.com/docs/api-reference/files/create
func UploadFilesEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
file, err := c.FormFile("file")
if err != nil {
return err
}
// Check the file size
if file.Size > int64(o.UploadLimitMB*1024*1024) {
return c.Status(fiber.StatusBadRequest).SendString(fmt.Sprintf("File size %d exceeds upload limit %d", file.Size, o.UploadLimitMB))
}
purpose := c.FormValue("purpose", "") //TODO put in purpose dirs
if purpose == "" {
return c.Status(fiber.StatusBadRequest).SendString("Purpose is not defined")
}
// Sanitize the filename to prevent directory traversal
filename := utils.SanitizeFileName(file.Filename)
savePath := filepath.Join(o.UploadDir, filename)
// Check if file already exists
if _, err := os.Stat(savePath); !os.IsNotExist(err) {
return c.Status(fiber.StatusBadRequest).SendString("File already exists")
}
err = c.SaveFile(file, savePath)
if err != nil {
return c.Status(fiber.StatusInternalServerError).SendString("Failed to save file: " + err.Error())
}
f := File{
ID: fmt.Sprintf("file-%d", time.Now().Unix()),
Object: "file",
Bytes: int(file.Size),
CreatedAt: time.Now(),
Filename: file.Filename,
Purpose: purpose,
}
uploadedFiles = append(uploadedFiles, f)
saveUploadConfig(o.UploadDir)
return c.Status(fiber.StatusOK).JSON(f)
}
}
// ListFilesEndpoint https://platform.openai.com/docs/api-reference/files/list
func ListFilesEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
type ListFiles struct {
Data []File
Object string
}
return func(c *fiber.Ctx) error {
var listFiles ListFiles
purpose := c.Query("purpose")
if purpose == "" {
listFiles.Data = uploadedFiles
} else {
for _, f := range uploadedFiles {
if purpose == f.Purpose {
listFiles.Data = append(listFiles.Data, f)
}
}
}
listFiles.Object = "list"
return c.Status(fiber.StatusOK).JSON(listFiles)
}
}
func getFileFromRequest(c *fiber.Ctx) (*File, error) {
id := c.Params("file_id")
if id == "" {
return nil, fmt.Errorf("file_id parameter is required")
}
for _, f := range uploadedFiles {
if id == f.ID {
return &f, nil
}
}
return nil, fmt.Errorf("unable to find file id %s", id)
}
// GetFilesEndpoint https://platform.openai.com/docs/api-reference/files/retrieve
func GetFilesEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
file, err := getFileFromRequest(c)
if err != nil {
return c.Status(fiber.StatusInternalServerError).SendString(err.Error())
}
return c.JSON(file)
}
}
// DeleteFilesEndpoint https://platform.openai.com/docs/api-reference/files/delete
func DeleteFilesEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
type DeleteStatus struct {
Id string
Object string
Deleted bool
}
return func(c *fiber.Ctx) error {
file, err := getFileFromRequest(c)
if err != nil {
return c.Status(fiber.StatusInternalServerError).SendString(err.Error())
}
err = os.Remove(filepath.Join(o.UploadDir, file.Filename))
if err != nil {
// If the file doesn't exist then we should just continue to remove it
if !errors.Is(err, os.ErrNotExist) {
return c.Status(fiber.StatusInternalServerError).SendString(fmt.Sprintf("Unable to delete file: %s, %v", file.Filename, err))
}
}
// Remove upload from list
for i, f := range uploadedFiles {
if f.ID == file.ID {
uploadedFiles = append(uploadedFiles[:i], uploadedFiles[i+1:]...)
break
}
}
saveUploadConfig(o.UploadDir)
return c.JSON(DeleteStatus{
Id: file.ID,
Object: "file",
Deleted: true,
})
}
}
// GetFilesContentsEndpoint https://platform.openai.com/docs/api-reference/files/retrieve-contents
func GetFilesContentsEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
file, err := getFileFromRequest(c)
if err != nil {
return c.Status(fiber.StatusInternalServerError).SendString(err.Error())
}
fileContents, err := os.ReadFile(filepath.Join(o.UploadDir, file.Filename))
if err != nil {
return c.Status(fiber.StatusInternalServerError).SendString(err.Error())
}
return c.Send(fileContents)
}
}

View file

@ -1,287 +0,0 @@
package openai
import (
"encoding/json"
"fmt"
"io"
"mime/multipart"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
config "github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/options"
utils2 "github.com/go-skynet/LocalAI/pkg/utils"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/assert"
"testing"
)
type ListFiles struct {
Data []File
Object string
}
func startUpApp() (app *fiber.App, option *options.Option, loader *config.ConfigLoader) {
// Preparing the mocked objects
loader = &config.ConfigLoader{}
option = &options.Option{
UploadLimitMB: 10,
UploadDir: "test_dir",
}
_ = os.RemoveAll(option.UploadDir)
app = fiber.New(fiber.Config{
BodyLimit: 20 * 1024 * 1024, // sets the limit to 20MB.
})
// Create a Test Server
app.Post("/files", UploadFilesEndpoint(loader, option))
app.Get("/files", ListFilesEndpoint(loader, option))
app.Get("/files/:file_id", GetFilesEndpoint(loader, option))
app.Delete("/files/:file_id", DeleteFilesEndpoint(loader, option))
app.Get("/files/:file_id/content", GetFilesContentsEndpoint(loader, option))
return
}
func TestUploadFileExceedSizeLimit(t *testing.T) {
// Preparing the mocked objects
loader := &config.ConfigLoader{}
option := &options.Option{
UploadLimitMB: 10,
UploadDir: "test_dir",
}
_ = os.RemoveAll(option.UploadDir)
app := fiber.New(fiber.Config{
BodyLimit: 20 * 1024 * 1024, // sets the limit to 20MB.
})
// Create a Test Server
app.Post("/files", UploadFilesEndpoint(loader, option))
app.Get("/files", ListFilesEndpoint(loader, option))
app.Get("/files/:file_id", GetFilesEndpoint(loader, option))
app.Delete("/files/:file_id", DeleteFilesEndpoint(loader, option))
app.Get("/files/:file_id/content", GetFilesContentsEndpoint(loader, option))
t.Run("UploadFilesEndpoint file size exceeds limit", func(t *testing.T) {
resp, err := CallFilesUploadEndpoint(t, app, "foo.txt", "file", "fine-tune", 11, option)
assert.NoError(t, err)
assert.Equal(t, fiber.StatusBadRequest, resp.StatusCode)
assert.Contains(t, bodyToString(resp, t), "exceeds upload limit")
})
t.Run("UploadFilesEndpoint purpose not defined", func(t *testing.T) {
resp, _ := CallFilesUploadEndpoint(t, app, "foo.txt", "file", "", 5, option)
assert.Equal(t, fiber.StatusBadRequest, resp.StatusCode)
assert.Contains(t, bodyToString(resp, t), "Purpose is not defined")
})
t.Run("UploadFilesEndpoint file already exists", func(t *testing.T) {
f1 := CallFilesUploadEndpointWithCleanup(t, app, "foo.txt", "file", "fine-tune", 5, option)
resp, err := CallFilesUploadEndpoint(t, app, "foo.txt", "file", "fine-tune", 5, option)
fmt.Println(f1)
fmt.Printf("ERror: %v", err)
assert.Equal(t, fiber.StatusBadRequest, resp.StatusCode)
assert.Contains(t, bodyToString(resp, t), "File already exists")
})
t.Run("UploadFilesEndpoint file uploaded successfully", func(t *testing.T) {
file := CallFilesUploadEndpointWithCleanup(t, app, "test.txt", "file", "fine-tune", 5, option)
// Check if file exists in the disk
filePath := filepath.Join(option.UploadDir, utils2.SanitizeFileName("test.txt"))
_, err := os.Stat(filePath)
assert.False(t, os.IsNotExist(err))
assert.Equal(t, file.Bytes, 5242880)
assert.NotEmpty(t, file.CreatedAt)
assert.Equal(t, file.Filename, "test.txt")
assert.Equal(t, file.Purpose, "fine-tune")
})
t.Run("ListFilesEndpoint without purpose parameter", func(t *testing.T) {
resp, err := CallListFilesEndpoint(t, app, "")
assert.NoError(t, err)
assert.Equal(t, 200, resp.StatusCode)
listFiles := responseToListFile(t, resp)
if len(listFiles.Data) != len(uploadedFiles) {
t.Errorf("Expected %v files, got %v files", len(uploadedFiles), len(listFiles.Data))
}
})
t.Run("ListFilesEndpoint with valid purpose parameter", func(t *testing.T) {
_ = CallFilesUploadEndpointWithCleanup(t, app, "test.txt", "file", "fine-tune", 5, option)
resp, err := CallListFilesEndpoint(t, app, "fine-tune")
assert.NoError(t, err)
listFiles := responseToListFile(t, resp)
if len(listFiles.Data) != 1 {
t.Errorf("Expected 1 file, got %v files", len(listFiles.Data))
}
})
t.Run("ListFilesEndpoint with invalid query parameter", func(t *testing.T) {
resp, err := CallListFilesEndpoint(t, app, "not-so-fine-tune")
assert.NoError(t, err)
assert.Equal(t, 200, resp.StatusCode)
listFiles := responseToListFile(t, resp)
if len(listFiles.Data) != 0 {
t.Errorf("Expected 0 file, got %v files", len(listFiles.Data))
}
})
t.Run("GetFilesContentsEndpoint get file content", func(t *testing.T) {
req := httptest.NewRequest("GET", "/files", nil)
resp, _ := app.Test(req)
assert.Equal(t, 200, resp.StatusCode)
var listFiles ListFiles
if err := json.Unmarshal(bodyToByteArray(resp, t), &listFiles); err != nil {
t.Errorf("Failed to decode response: %v", err)
return
}
if len(listFiles.Data) != 0 {
t.Errorf("Expected 0 file, got %v files", len(listFiles.Data))
}
})
}
func CallListFilesEndpoint(t *testing.T, app *fiber.App, purpose string) (*http.Response, error) {
var target string
if purpose != "" {
target = fmt.Sprintf("/files?purpose=%s", purpose)
} else {
target = "/files"
}
req := httptest.NewRequest("GET", target, nil)
return app.Test(req)
}
func CallFilesContentEndpoint(t *testing.T, app *fiber.App, fileId string) (*http.Response, error) {
request := httptest.NewRequest("GET", "/files?file_id="+fileId, nil)
return app.Test(request)
}
func CallFilesUploadEndpoint(t *testing.T, app *fiber.App, fileName, tag, purpose string, fileSize int, o *options.Option) (*http.Response, error) {
// Create a file that exceeds the limit
file := createTestFile(t, fileName, fileSize, o)
// Creating a new HTTP Request
body, writer := newMultipartFile(file.Name(), tag, purpose)
req := httptest.NewRequest(http.MethodPost, "/files", body)
req.Header.Set(fiber.HeaderContentType, writer.FormDataContentType())
return app.Test(req)
}
func CallFilesUploadEndpointWithCleanup(t *testing.T, app *fiber.App, fileName, tag, purpose string, fileSize int, o *options.Option) File {
// Create a file that exceeds the limit
file := createTestFile(t, fileName, fileSize, o)
// Creating a new HTTP Request
body, writer := newMultipartFile(file.Name(), tag, purpose)
req := httptest.NewRequest(http.MethodPost, "/files", body)
req.Header.Set(fiber.HeaderContentType, writer.FormDataContentType())
resp, err := app.Test(req)
assert.NoError(t, err)
f := responseToFile(t, resp)
id := f.ID
t.Cleanup(func() {
_, err := CallFilesDeleteEndpoint(t, app, id)
assert.NoError(t, err)
})
return f
}
func CallFilesDeleteEndpoint(t *testing.T, app *fiber.App, fileId string) (*http.Response, error) {
target := fmt.Sprintf("/files/%s", fileId)
req := httptest.NewRequest(http.MethodDelete, target, nil)
return app.Test(req)
}
// Helper to create multi-part file
func newMultipartFile(filePath, tag, purpose string) (*strings.Reader, *multipart.Writer) {
body := new(strings.Builder)
writer := multipart.NewWriter(body)
file, _ := os.Open(filePath)
defer file.Close()
part, _ := writer.CreateFormFile(tag, filepath.Base(filePath))
io.Copy(part, file)
if purpose != "" {
_ = writer.WriteField("purpose", purpose)
}
writer.Close()
return strings.NewReader(body.String()), writer
}
// Helper to create test files
func createTestFile(t *testing.T, name string, sizeMB int, option *options.Option) *os.File {
err := os.MkdirAll(option.UploadDir, 0755)
if err != nil {
t.Fatalf("Error MKDIR: %v", err)
}
file, _ := os.Create(name)
file.WriteString(strings.Repeat("a", sizeMB*1024*1024)) // sizeMB MB File
t.Cleanup(func() {
os.Remove(name)
os.RemoveAll(option.UploadDir)
})
return file
}
func bodyToString(resp *http.Response, t *testing.T) string {
return string(bodyToByteArray(resp, t))
}
func bodyToByteArray(resp *http.Response, t *testing.T) []byte {
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
return bodyBytes
}
func responseToFile(t *testing.T, resp *http.Response) File {
var file File
responseToString := bodyToString(resp, t)
err := json.NewDecoder(strings.NewReader(responseToString)).Decode(&file)
if err != nil {
t.Errorf("Failed to decode response: %s", err)
}
return file
}
func responseToListFile(t *testing.T, resp *http.Response) ListFiles {
var listFiles ListFiles
responseToString := bodyToString(resp, t)
err := json.NewDecoder(strings.NewReader(responseToString)).Decode(&listFiles)
if err != nil {
fmt.Printf("Failed to decode response: %s", err)
}
return listFiles
}

View file

@ -1,239 +0,0 @@
package openai
import (
"bufio"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/google/uuid"
"github.com/go-skynet/LocalAI/core/backend"
config "github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/options"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
)
func downloadFile(url string) (string, error) {
// Get the data
resp, err := http.Get(url)
if err != nil {
return "", err
}
defer resp.Body.Close()
// Create the file
out, err := os.CreateTemp("", "image")
if err != nil {
return "", err
}
defer out.Close()
// Write the body to file
_, err = io.Copy(out, resp.Body)
return out.Name(), err
}
// https://platform.openai.com/docs/api-reference/images/create
/*
*
curl http://localhost:8080/v1/images/generations \
-H "Content-Type: application/json" \
-d '{
"prompt": "A cute baby sea otter",
"n": 1,
"size": "512x512"
}'
*
*/
func ImageEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
m, input, err := readRequest(c, o, false)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
if m == "" {
m = model.StableDiffusionBackend
}
log.Debug().Msgf("Loading model: %+v", m)
config, input, err := mergeRequestWithConfig(m, input, cm, o.Loader, o.Debug, 0, 0, false)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
src := ""
if input.File != "" {
fileData := []byte{}
// check if input.File is an URL, if so download it and save it
// to a temporary file
if strings.HasPrefix(input.File, "http://") || strings.HasPrefix(input.File, "https://") {
out, err := downloadFile(input.File)
if err != nil {
return fmt.Errorf("failed downloading file:%w", err)
}
defer os.RemoveAll(out)
fileData, err = os.ReadFile(out)
if err != nil {
return fmt.Errorf("failed reading file:%w", err)
}
} else {
// base 64 decode the file and write it somewhere
// that we will cleanup
fileData, err = base64.StdEncoding.DecodeString(input.File)
if err != nil {
return err
}
}
// Create a temporary file
outputFile, err := os.CreateTemp(o.ImageDir, "b64")
if err != nil {
return err
}
// write the base64 result
writer := bufio.NewWriter(outputFile)
_, err = writer.Write(fileData)
if err != nil {
outputFile.Close()
return err
}
outputFile.Close()
src = outputFile.Name()
defer os.RemoveAll(src)
}
log.Debug().Msgf("Parameter Config: %+v", config)
switch config.Backend {
case "stablediffusion":
config.Backend = model.StableDiffusionBackend
case "tinydream":
config.Backend = model.TinyDreamBackend
case "":
config.Backend = model.StableDiffusionBackend
}
sizeParts := strings.Split(input.Size, "x")
if len(sizeParts) != 2 {
return fmt.Errorf("Invalid value for 'size'")
}
width, err := strconv.Atoi(sizeParts[0])
if err != nil {
return fmt.Errorf("Invalid value for 'size'")
}
height, err := strconv.Atoi(sizeParts[1])
if err != nil {
return fmt.Errorf("Invalid value for 'size'")
}
b64JSON := false
if input.ResponseFormat.Type == "b64_json" {
b64JSON = true
}
// src and clip_skip
var result []schema.Item
for _, i := range config.PromptStrings {
n := input.N
if input.N == 0 {
n = 1
}
for j := 0; j < n; j++ {
prompts := strings.Split(i, "|")
positive_prompt := prompts[0]
negative_prompt := ""
if len(prompts) > 1 {
negative_prompt = prompts[1]
}
mode := 0
step := config.Step
if step == 0 {
step = 15
}
if input.Mode != 0 {
mode = input.Mode
}
if input.Step != 0 {
step = input.Step
}
tempDir := ""
if !b64JSON {
tempDir = o.ImageDir
}
// Create a temporary file
outputFile, err := os.CreateTemp(tempDir, "b64")
if err != nil {
return err
}
outputFile.Close()
output := outputFile.Name() + ".png"
// Rename the temporary file
err = os.Rename(outputFile.Name(), output)
if err != nil {
return err
}
baseURL := c.BaseURL()
fn, err := backend.ImageGeneration(height, width, mode, step, input.Seed, positive_prompt, negative_prompt, src, output, o.Loader, *config, o)
if err != nil {
return err
}
if err := fn(); err != nil {
return err
}
item := &schema.Item{}
if b64JSON {
defer os.RemoveAll(output)
data, err := os.ReadFile(output)
if err != nil {
return err
}
item.B64JSON = base64.StdEncoding.EncodeToString(data)
} else {
base := filepath.Base(output)
item.URL = baseURL + "/generated-images/" + base
}
result = append(result, *item)
}
}
id := uuid.New().String()
created := int(time.Now().Unix())
resp := &schema.OpenAIResponse{
ID: id,
Created: created,
Data: result,
}
jsonResult, _ := json.Marshal(resp)
log.Debug().Msgf("Response: %s", jsonResult)
// Return the prediction in the response body
return c.JSON(resp)
}
}

View file

@ -1,55 +0,0 @@
package openai
import (
"github.com/go-skynet/LocalAI/core/backend"
config "github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/options"
"github.com/go-skynet/LocalAI/core/schema"
model "github.com/go-skynet/LocalAI/pkg/model"
)
func ComputeChoices(
req *schema.OpenAIRequest,
predInput string,
config *config.Config,
o *options.Option,
loader *model.ModelLoader,
cb func(string, *[]schema.Choice),
tokenCallback func(string, backend.TokenUsage) bool) ([]schema.Choice, backend.TokenUsage, error) {
n := req.N // number of completions to return
result := []schema.Choice{}
if n == 0 {
n = 1
}
images := []string{}
for _, m := range req.Messages {
images = append(images, m.StringImages...)
}
// get the model function to call for the result
predFunc, err := backend.ModelInference(req.Context, predInput, images, loader, *config, o, tokenCallback)
if err != nil {
return result, backend.TokenUsage{}, err
}
tokenUsage := backend.TokenUsage{}
for i := 0; i < n; i++ {
prediction, err := predFunc()
if err != nil {
return result, backend.TokenUsage{}, err
}
tokenUsage.Prompt += prediction.Usage.Prompt
tokenUsage.Completion += prediction.Usage.Completion
finetunedResponse := backend.Finetune(*config, predInput, prediction.Response)
cb(finetunedResponse, &result)
//result = append(result, Choice{Text: prediction})
}
return result, tokenUsage, err
}

View file

@ -1,69 +0,0 @@
package openai
import (
"regexp"
config "github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/schema"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
)
func ListModelsEndpoint(loader *model.ModelLoader, cm *config.ConfigLoader) func(ctx *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
models, err := loader.ListModels()
if err != nil {
return err
}
var mm map[string]interface{} = map[string]interface{}{}
dataModels := []schema.OpenAIModel{}
var filterFn func(name string) bool
filter := c.Query("filter")
// If filter is not specified, do not filter the list by model name
if filter == "" {
filterFn = func(_ string) bool { return true }
} else {
// If filter _IS_ specified, we compile it to a regex which is used to create the filterFn
rxp, err := regexp.Compile(filter)
if err != nil {
return err
}
filterFn = func(name string) bool {
return rxp.MatchString(name)
}
}
// By default, exclude any loose files that are already referenced by a configuration file.
excludeConfigured := c.QueryBool("excludeConfigured", true)
// Start with the known configurations
for _, c := range cm.GetAllConfigs() {
if excludeConfigured {
mm[c.Model] = nil
}
if filterFn(c.Name) {
dataModels = append(dataModels, schema.OpenAIModel{ID: c.Name, Object: "model"})
}
}
// Then iterate through the loose files:
for _, m := range models {
// And only adds them if they shouldn't be skipped.
if _, exists := mm[m]; !exists && filterFn(m) {
dataModels = append(dataModels, schema.OpenAIModel{ID: m, Object: "model"})
}
}
return c.JSON(struct {
Object string `json:"object"`
Data []schema.OpenAIModel `json:"data"`
}{
Object: "list",
Data: dataModels,
})
}
}

View file

@ -1,280 +0,0 @@
package openai
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"strings"
fiberContext "github.com/go-skynet/LocalAI/api/ctx"
config "github.com/go-skynet/LocalAI/core/config"
options "github.com/go-skynet/LocalAI/core/options"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/go-skynet/LocalAI/pkg/grammar"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
)
func readRequest(c *fiber.Ctx, o *options.Option, firstModel bool) (string, *schema.OpenAIRequest, error) {
input := new(schema.OpenAIRequest)
ctx, cancel := context.WithCancel(o.Context)
input.Context = ctx
input.Cancel = cancel
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return "", nil, fmt.Errorf("failed parsing request body: %w", err)
}
received, _ := json.Marshal(input)
log.Debug().Msgf("Request received: %s", string(received))
modelFile, err := fiberContext.ModelFromContext(c, o.Loader, input.Model, firstModel)
return modelFile, input, err
}
// this function check if the string is an URL, if it's an URL downloads the image in memory
// encodes it in base64 and returns the base64 string
func getBase64Image(s string) (string, error) {
if strings.HasPrefix(s, "http") {
// download the image
resp, err := http.Get(s)
if err != nil {
return "", err
}
defer resp.Body.Close()
// read the image data into memory
data, err := ioutil.ReadAll(resp.Body)
if err != nil {
return "", err
}
// encode the image data in base64
encoded := base64.StdEncoding.EncodeToString(data)
// return the base64 string
return encoded, nil
}
// if the string instead is prefixed with "data:image/jpeg;base64,", drop it
if strings.HasPrefix(s, "data:image/jpeg;base64,") {
return strings.ReplaceAll(s, "data:image/jpeg;base64,", ""), nil
}
return "", fmt.Errorf("not valid string")
}
func updateRequestConfig(config *config.Config, input *schema.OpenAIRequest) {
if input.Echo {
config.Echo = input.Echo
}
if input.TopK != 0 {
config.TopK = input.TopK
}
if input.TopP != 0 {
config.TopP = input.TopP
}
if input.Backend != "" {
config.Backend = input.Backend
}
if input.ClipSkip != 0 {
config.Diffusers.ClipSkip = input.ClipSkip
}
if input.ModelBaseName != "" {
config.AutoGPTQ.ModelBaseName = input.ModelBaseName
}
if input.NegativePromptScale != 0 {
config.NegativePromptScale = input.NegativePromptScale
}
if input.UseFastTokenizer {
config.UseFastTokenizer = input.UseFastTokenizer
}
if input.NegativePrompt != "" {
config.NegativePrompt = input.NegativePrompt
}
if input.RopeFreqBase != 0 {
config.RopeFreqBase = input.RopeFreqBase
}
if input.RopeFreqScale != 0 {
config.RopeFreqScale = input.RopeFreqScale
}
if input.Grammar != "" {
config.Grammar = input.Grammar
}
if input.Temperature != 0 {
config.Temperature = input.Temperature
}
if input.Maxtokens != 0 {
config.Maxtokens = input.Maxtokens
}
switch stop := input.Stop.(type) {
case string:
if stop != "" {
config.StopWords = append(config.StopWords, stop)
}
case []interface{}:
for _, pp := range stop {
if s, ok := pp.(string); ok {
config.StopWords = append(config.StopWords, s)
}
}
}
if len(input.Tools) > 0 {
for _, tool := range input.Tools {
input.Functions = append(input.Functions, tool.Function)
}
}
if input.ToolsChoice != nil {
var toolChoice grammar.Tool
json.Unmarshal([]byte(input.ToolsChoice.(string)), &toolChoice)
input.FunctionCall = map[string]interface{}{
"name": toolChoice.Function.Name,
}
}
// Decode each request's message content
index := 0
for i, m := range input.Messages {
switch content := m.Content.(type) {
case string:
input.Messages[i].StringContent = content
case []interface{}:
dat, _ := json.Marshal(content)
c := []schema.Content{}
json.Unmarshal(dat, &c)
for _, pp := range c {
if pp.Type == "text" {
input.Messages[i].StringContent = pp.Text
} else if pp.Type == "image_url" {
// Detect if pp.ImageURL is an URL, if it is download the image and encode it in base64:
base64, err := getBase64Image(pp.ImageURL.URL)
if err == nil {
input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff
// set a placeholder for each image
input.Messages[i].StringContent = fmt.Sprintf("[img-%d]", index) + input.Messages[i].StringContent
index++
} else {
fmt.Print("Failed encoding image", err)
}
}
}
}
}
if input.RepeatPenalty != 0 {
config.RepeatPenalty = input.RepeatPenalty
}
if input.Keep != 0 {
config.Keep = input.Keep
}
if input.Batch != 0 {
config.Batch = input.Batch
}
if input.F16 {
config.F16 = input.F16
}
if input.IgnoreEOS {
config.IgnoreEOS = input.IgnoreEOS
}
if input.Seed != 0 {
config.Seed = input.Seed
}
if input.Mirostat != 0 {
config.LLMConfig.Mirostat = input.Mirostat
}
if input.MirostatETA != 0 {
config.LLMConfig.MirostatETA = input.MirostatETA
}
if input.MirostatTAU != 0 {
config.LLMConfig.MirostatTAU = input.MirostatTAU
}
if input.TypicalP != 0 {
config.TypicalP = input.TypicalP
}
switch inputs := input.Input.(type) {
case string:
if inputs != "" {
config.InputStrings = append(config.InputStrings, inputs)
}
case []interface{}:
for _, pp := range inputs {
switch i := pp.(type) {
case string:
config.InputStrings = append(config.InputStrings, i)
case []interface{}:
tokens := []int{}
for _, ii := range i {
tokens = append(tokens, int(ii.(float64)))
}
config.InputToken = append(config.InputToken, tokens)
}
}
}
// Can be either a string or an object
switch fnc := input.FunctionCall.(type) {
case string:
if fnc != "" {
config.SetFunctionCallString(fnc)
}
case map[string]interface{}:
var name string
n, exists := fnc["name"]
if exists {
nn, e := n.(string)
if e {
name = nn
}
}
config.SetFunctionCallNameString(name)
}
switch p := input.Prompt.(type) {
case string:
config.PromptStrings = append(config.PromptStrings, p)
case []interface{}:
for _, pp := range p {
if s, ok := pp.(string); ok {
config.PromptStrings = append(config.PromptStrings, s)
}
}
}
}
func mergeRequestWithConfig(modelFile string, input *schema.OpenAIRequest, cm *config.ConfigLoader, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*config.Config, *schema.OpenAIRequest, error) {
cfg, err := config.Load(modelFile, loader.ModelPath, cm, debug, threads, ctx, f16)
// Set the parameters for the language model prediction
updateRequestConfig(cfg, input)
return cfg, input, err
}

View file

@ -1,71 +0,0 @@
package openai
import (
"fmt"
"io"
"net/http"
"os"
"path"
"path/filepath"
"github.com/go-skynet/LocalAI/core/backend"
config "github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/options"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
)
// https://platform.openai.com/docs/api-reference/audio/create
func TranscriptEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
m, input, err := readRequest(c, o, false)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
config, input, err := mergeRequestWithConfig(m, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
// retrieve the file data from the request
file, err := c.FormFile("file")
if err != nil {
return err
}
f, err := file.Open()
if err != nil {
return err
}
defer f.Close()
dir, err := os.MkdirTemp("", "whisper")
if err != nil {
return err
}
defer os.RemoveAll(dir)
dst := filepath.Join(dir, path.Base(file.Filename))
dstFile, err := os.Create(dst)
if err != nil {
return err
}
if _, err := io.Copy(dstFile, f); err != nil {
log.Debug().Msgf("Audio file copying error %+v - %+v - err %+v", file.Filename, dst, err)
return err
}
log.Debug().Msgf("Audio file copied to: %+v", dst)
tr, err := backend.ModelTranscription(dst, input.Language, o.Loader, *config, o)
if err != nil {
return err
}
log.Debug().Msgf("Trascribed: %+v", tr)
// TODO: handle different outputs here
return c.Status(http.StatusOK).JSON(tr)
}
}