Revert "[Refactor]: Core/API Split" (#1550)

Revert "[Refactor]: Core/API Split (#1506)"

This reverts commit ab7b4d5ee9.
This commit is contained in:
Ettore Di Giacinto 2024-01-05 12:04:46 -05:00 committed by GitHub
parent ab7b4d5ee9
commit db926896bd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
77 changed files with 3132 additions and 3456 deletions

View file

@ -1,144 +0,0 @@
package backend
import (
"fmt"
"time"
"github.com/go-skynet/LocalAI/core/services"
"github.com/go-skynet/LocalAI/pkg/grpc"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/schema"
"github.com/google/uuid"
"github.com/rs/zerolog/log"
)
func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c schema.Config, o *schema.StartupOptions) (func() ([]float32, error), error) {
if !c.Embeddings {
return nil, fmt.Errorf("endpoint disabled for this model by API configuration")
}
modelFile := c.Model
grpcOpts := gRPCModelOpts(c)
var inferenceModel interface{}
var err error
opts := modelOpts(c, o, []model.Option{
model.WithLoadGRPCLoadModelOpts(grpcOpts),
model.WithThreads(uint32(c.Threads)),
model.WithAssetDir(o.AssetsDestination),
model.WithModel(modelFile),
model.WithContext(o.Context),
model.WithExternalBackends(o.ExternalGRPCBackends, false),
})
if c.Backend == "" {
inferenceModel, err = loader.GreedyLoader(opts...)
} else {
opts = append(opts, model.WithBackendString(c.Backend))
inferenceModel, err = loader.BackendLoader(opts...)
}
if err != nil {
return nil, err
}
var fn func() ([]float32, error)
switch model := inferenceModel.(type) {
case *grpc.Client:
fn = func() ([]float32, error) {
predictOptions := gRPCPredictOpts(c, loader.ModelPath)
if len(tokens) > 0 {
embeds := []int32{}
for _, t := range tokens {
embeds = append(embeds, int32(t))
}
predictOptions.EmbeddingTokens = embeds
res, err := model.Embeddings(o.Context, predictOptions)
if err != nil {
return nil, err
}
return res.Embeddings, nil
}
predictOptions.Embeddings = s
res, err := model.Embeddings(o.Context, predictOptions)
if err != nil {
return nil, err
}
return res.Embeddings, nil
}
default:
fn = func() ([]float32, error) {
return nil, fmt.Errorf("embeddings not supported by the backend")
}
}
return func() ([]float32, error) {
embeds, err := fn()
if err != nil {
return embeds, err
}
// Remove trailing 0s
for i := len(embeds) - 1; i >= 0; i-- {
if embeds[i] == 0.0 {
embeds = embeds[:i]
} else {
break
}
}
return embeds, nil
}, nil
}
func EmbeddingOpenAIRequest(modelName string, input *schema.OpenAIRequest, cl *services.ConfigLoader, ml *model.ModelLoader, startupOptions *schema.StartupOptions) (*schema.OpenAIResponse, error) {
config, input, err := ReadConfigFromFileAndCombineWithOpenAIRequest(modelName, input, cl, startupOptions)
if err != nil {
return nil, fmt.Errorf("failed reading parameters from request:%w", err)
}
log.Debug().Msgf("Parameter Config: %+v", config)
items := []schema.Item{}
for i, s := range config.InputToken {
// get the model function to call for the result
embedFn, err := ModelEmbedding("", s, ml, *config, startupOptions)
if err != nil {
return nil, err
}
embeddings, err := embedFn()
if err != nil {
return nil, err
}
items = append(items, schema.Item{Embedding: embeddings, Index: i, Object: "embedding"})
}
for i, s := range config.InputStrings {
// get the model function to call for the result
embedFn, err := ModelEmbedding(s, []int{}, ml, *config, startupOptions)
if err != nil {
return nil, err
}
embeddings, err := embedFn()
if err != nil {
return nil, err
}
items = append(items, schema.Item{Embedding: embeddings, Index: i, Object: "embedding"})
}
id := uuid.New().String()
created := int(time.Now().Unix())
return &schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Data: items,
Object: "list",
}, nil
}

View file

@ -1,210 +0,0 @@
package backend
import (
"encoding/base64"
"fmt"
"os"
"path"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/go-skynet/LocalAI/core/services"
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/schema"
"github.com/go-skynet/LocalAI/pkg/utils"
"github.com/google/uuid"
"github.com/rs/zerolog/log"
)
func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, c schema.Config, o *schema.StartupOptions) (func() error, error) {
opts := modelOpts(c, o, []model.Option{
model.WithBackendString(c.Backend),
model.WithAssetDir(o.AssetsDestination),
model.WithThreads(uint32(c.Threads)),
model.WithContext(o.Context),
model.WithModel(c.Model),
model.WithLoadGRPCLoadModelOpts(&proto.ModelOptions{
CUDA: c.CUDA || c.Diffusers.CUDA,
SchedulerType: c.Diffusers.SchedulerType,
PipelineType: c.Diffusers.PipelineType,
CFGScale: c.Diffusers.CFGScale,
LoraAdapter: c.LoraAdapter,
LoraScale: c.LoraScale,
LoraBase: c.LoraBase,
IMG2IMG: c.Diffusers.IMG2IMG,
CLIPModel: c.Diffusers.ClipModel,
CLIPSubfolder: c.Diffusers.ClipSubFolder,
CLIPSkip: int32(c.Diffusers.ClipSkip),
ControlNet: c.Diffusers.ControlNet,
}),
model.WithExternalBackends(o.ExternalGRPCBackends, false),
})
inferenceModel, err := loader.BackendLoader(
opts...,
)
if err != nil {
return nil, err
}
fn := func() error {
_, err := inferenceModel.GenerateImage(
o.Context,
&proto.GenerateImageRequest{
Height: int32(height),
Width: int32(width),
Mode: int32(mode),
Step: int32(step),
Seed: int32(seed),
CLIPSkip: int32(c.Diffusers.ClipSkip),
PositivePrompt: positive_prompt,
NegativePrompt: negative_prompt,
Dst: dst,
Src: src,
EnableParameters: c.Diffusers.EnableParameters,
})
return err
}
return fn, nil
}
func ImageGenerationOpenAIRequest(modelName string, input *schema.OpenAIRequest, cl *services.ConfigLoader, ml *model.ModelLoader, startupOptions *schema.StartupOptions) (*schema.OpenAIResponse, error) {
id := uuid.New().String()
created := int(time.Now().Unix())
if modelName == "" {
modelName = model.StableDiffusionBackend
}
log.Debug().Msgf("Loading model: %+v", modelName)
config, input, err := ReadConfigFromFileAndCombineWithOpenAIRequest(modelName, input, cl, startupOptions)
if err != nil {
return nil, fmt.Errorf("failed reading parameters from request: %w", err)
}
src := ""
if input.File != "" {
if strings.HasPrefix(input.File, "http://") || strings.HasPrefix(input.File, "https://") {
src, err = utils.CreateTempFileFromUrl(input.File, "", "image-src")
if err != nil {
return nil, fmt.Errorf("failed downloading file:%w", err)
}
} else {
src, err = utils.CreateTempFileFromBase64(input.File, "", "base64-image-src")
if err != nil {
return nil, fmt.Errorf("error creating temporary image source file: %w", err)
}
}
}
log.Debug().Msgf("Parameter Config: %+v", config)
switch config.Backend {
case "stablediffusion":
config.Backend = model.StableDiffusionBackend
case "tinydream":
config.Backend = model.TinyDreamBackend
case "":
config.Backend = model.StableDiffusionBackend
}
sizeParts := strings.Split(input.Size, "x")
if len(sizeParts) != 2 {
return nil, fmt.Errorf("invalid value for 'size'")
}
width, err := strconv.Atoi(sizeParts[0])
if err != nil {
return nil, fmt.Errorf("invalid value for 'size'")
}
height, err := strconv.Atoi(sizeParts[1])
if err != nil {
return nil, fmt.Errorf("invalid value for 'size'")
}
b64JSON := false
if input.ResponseFormat.Type == "b64_json" {
b64JSON = true
}
// src and clip_skip
var result []schema.Item
for _, i := range config.PromptStrings {
n := input.N
if input.N == 0 {
n = 1
}
for j := 0; j < n; j++ {
prompts := strings.Split(i, "|")
positive_prompt := prompts[0]
negative_prompt := ""
if len(prompts) > 1 {
negative_prompt = prompts[1]
}
mode := 0
step := config.Step
if step == 0 {
step = 15
}
if input.Mode != 0 {
mode = input.Mode
}
if input.Step != 0 {
step = input.Step
}
tempDir := ""
if !b64JSON {
tempDir = startupOptions.ImageDir
}
// Create a temporary file
outputFile, err := os.CreateTemp(tempDir, "b64")
if err != nil {
return nil, err
}
outputFile.Close()
output := outputFile.Name() + ".png"
// Rename the temporary file
err = os.Rename(outputFile.Name(), output)
if err != nil {
return nil, err
}
fn, err := ImageGeneration(height, width, mode, step, input.Seed, positive_prompt, negative_prompt, src, output, ml, *config, startupOptions)
if err != nil {
return nil, err
}
if err := fn(); err != nil {
return nil, err
}
item := &schema.Item{}
if b64JSON {
defer os.RemoveAll(output)
data, err := os.ReadFile(output)
if err != nil {
return nil, err
}
item.B64JSON = base64.StdEncoding.EncodeToString(data)
} else {
base := filepath.Base(output)
item.URL = path.Join(startupOptions.ImageDir, base)
}
result = append(result, *item)
}
}
return &schema.OpenAIResponse{
ID: id,
Created: created,
Data: result,
}, nil
}

View file

@ -1,861 +0,0 @@
package backend
import (
"context"
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
"regexp"
"strings"
"sync"
"time"
"unicode/utf8"
"github.com/go-skynet/LocalAI/core/services"
"github.com/go-skynet/LocalAI/pkg/gallery"
"github.com/go-skynet/LocalAI/pkg/grammar"
"github.com/go-skynet/LocalAI/pkg/grpc"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/schema"
"github.com/go-skynet/LocalAI/pkg/utils"
"github.com/google/uuid"
"github.com/rs/zerolog/log"
)
////////// TYPES //////////////
type LLMResponse struct {
Response string // should this be []byte?
Usage TokenUsage
}
// TODO: Test removing this and using the variant in pkg/schema someday?
type TokenUsage struct {
Prompt int
Completion int
}
type TemplateConfigBindingFn func(*schema.Config) *string
// type LLMStreamProcessor func(s string, req *schema.OpenAIRequest, config *schema.Config, loader *model.ModelLoader, responses chan schema.OpenAIResponse)
/////// CONSTS ///////////
const DEFAULT_NO_ACTION_NAME = "answer"
const DEFAULT_NO_ACTION_DESCRIPTION = "use this action to answer without performing any action"
////// INFERENCE /////////
func ModelInference(ctx context.Context, s string, images []string, loader *model.ModelLoader, c schema.Config, o *schema.StartupOptions, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) {
modelFile := c.Model
grpcOpts := gRPCModelOpts(c)
var inferenceModel *grpc.Client
var err error
opts := modelOpts(c, o, []model.Option{
model.WithLoadGRPCLoadModelOpts(grpcOpts),
model.WithThreads(uint32(c.Threads)), // some models uses this to allocate threads during startup
model.WithAssetDir(o.AssetsDestination),
model.WithModel(modelFile),
model.WithContext(o.Context),
model.WithExternalBackends(o.ExternalGRPCBackends, false),
})
if c.Backend != "" {
opts = append(opts, model.WithBackendString(c.Backend))
}
// Check if the modelFile exists, if it doesn't try to load it from the gallery
if o.AutoloadGalleries { // experimental
if _, err := os.Stat(modelFile); os.IsNotExist(err) {
utils.ResetDownloadTimers()
// if we failed to load the model, we try to download it
err := gallery.InstallModelFromGalleryByName(o.Galleries, modelFile, loader.ModelPath, gallery.GalleryModel{}, utils.DisplayDownloadFunction)
if err != nil {
return nil, err
}
}
}
if c.Backend == "" {
inferenceModel, err = loader.GreedyLoader(opts...)
} else {
inferenceModel, err = loader.BackendLoader(opts...)
}
if err != nil {
return nil, err
}
// in GRPC, the backend is supposed to answer to 1 single token if stream is not supported
fn := func() (LLMResponse, error) {
opts := gRPCPredictOpts(c, loader.ModelPath)
opts.Prompt = s
opts.Images = images
tokenUsage := TokenUsage{}
// check the per-model feature flag for usage, since tokenCallback may have a cost.
// Defaults to off as for now it is still experimental
if c.FeatureFlag.Enabled("usage") {
userTokenCallback := tokenCallback
if userTokenCallback == nil {
userTokenCallback = func(token string, usage TokenUsage) bool {
return true
}
}
promptInfo, pErr := inferenceModel.TokenizeString(ctx, opts)
if pErr == nil && promptInfo.Length > 0 {
tokenUsage.Prompt = int(promptInfo.Length)
}
tokenCallback = func(token string, usage TokenUsage) bool {
tokenUsage.Completion++
return userTokenCallback(token, tokenUsage)
}
}
if tokenCallback != nil {
ss := ""
var partialRune []byte
err := inferenceModel.PredictStream(ctx, opts, func(chars []byte) {
partialRune = append(partialRune, chars...)
for len(partialRune) > 0 {
r, size := utf8.DecodeRune(partialRune)
if r == utf8.RuneError {
// incomplete rune, wait for more bytes
break
}
tokenCallback(string(r), tokenUsage)
ss += string(r)
partialRune = partialRune[size:]
}
})
return LLMResponse{
Response: ss,
Usage: tokenUsage,
}, err
} else {
// TODO: Is the chicken bit the only way to get here? is that acceptable?
reply, err := inferenceModel.Predict(ctx, opts)
if err != nil {
return LLMResponse{}, err
}
return LLMResponse{
Response: string(reply.Message),
Usage: tokenUsage,
}, err
}
}
return fn, nil
}
var cutstrings map[string]*regexp.Regexp = make(map[string]*regexp.Regexp)
var mu sync.Mutex = sync.Mutex{}
func Finetune(config schema.Config, input, prediction string) string {
if config.Echo {
prediction = input + prediction
}
for _, c := range config.Cutstrings {
mu.Lock()
reg, ok := cutstrings[c]
if !ok {
cutstrings[c] = regexp.MustCompile(c)
reg = cutstrings[c]
}
mu.Unlock()
prediction = reg.ReplaceAllString(prediction, "")
}
for _, c := range config.TrimSpace {
prediction = strings.TrimSpace(strings.TrimPrefix(prediction, c))
}
for _, c := range config.TrimSuffix {
prediction = strings.TrimSpace(strings.TrimSuffix(prediction, c))
}
return prediction
}
////// CONFIG AND REQUEST HANDLING ///////////////
func ReadConfigFromFileAndCombineWithOpenAIRequest(modelFile string, input *schema.OpenAIRequest, cm *services.ConfigLoader, startupOptions *schema.StartupOptions) (*schema.Config, *schema.OpenAIRequest, error) {
// Load a config file if present after the model name
modelConfig := filepath.Join(startupOptions.ModelPath, modelFile+".yaml")
var cfg *schema.Config
defaults := func() {
cfg = schema.DefaultConfig(modelFile)
cfg.ContextSize = startupOptions.ContextSize
cfg.Threads = startupOptions.Threads
cfg.F16 = startupOptions.F16
cfg.Debug = startupOptions.Debug
}
cfgExisting, exists := cm.GetConfig(modelFile)
if !exists {
if _, err := os.Stat(modelConfig); err == nil {
if err := cm.LoadConfig(modelConfig); err != nil {
return nil, nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error())
}
cfgExisting, exists = cm.GetConfig(modelFile)
if exists {
cfg = &cfgExisting
} else {
defaults()
}
} else {
defaults()
}
} else {
cfg = &cfgExisting
}
// Set the parameters for the language model prediction
schema.UpdateConfigFromOpenAIRequest(cfg, input)
// Don't allow 0 as setting
if cfg.Threads == 0 {
if startupOptions.Threads != 0 {
cfg.Threads = startupOptions.Threads
} else {
cfg.Threads = 4
}
}
// Enforce debug flag if passed from CLI
if startupOptions.Debug {
cfg.Debug = true
}
return cfg, input, nil
}
func ComputeChoices(
req *schema.OpenAIRequest,
predInput string,
config *schema.Config,
o *schema.StartupOptions,
loader *model.ModelLoader,
cb func(string, *[]schema.Choice),
tokenCallback func(string, TokenUsage) bool) ([]schema.Choice, TokenUsage, error) {
n := req.N // number of completions to return
result := []schema.Choice{}
if n == 0 {
n = 1
}
images := []string{}
for _, m := range req.Messages {
images = append(images, m.StringImages...)
}
// get the model function to call for the result
predFunc, err := ModelInference(req.Context, predInput, images, loader, *config, o, tokenCallback)
if err != nil {
return result, TokenUsage{}, err
}
tokenUsage := TokenUsage{}
for i := 0; i < n; i++ {
prediction, err := predFunc()
if err != nil {
return result, TokenUsage{}, err
}
tokenUsage.Prompt += prediction.Usage.Prompt
tokenUsage.Completion += prediction.Usage.Completion
finetunedResponse := Finetune(*config, predInput, prediction.Response)
cb(finetunedResponse, &result)
//result = append(result, Choice{Text: prediction})
}
return result, tokenUsage, err
}
// TODO: No functions???? Commonize with prepareChatGenerationOpenAIRequest below?
func prepareGenerationOpenAIRequest(bindingFn TemplateConfigBindingFn, modelName string, input *schema.OpenAIRequest, cl *services.ConfigLoader, ml *model.ModelLoader, startupOptions *schema.StartupOptions) (*schema.Config, error) {
config, input, err := ReadConfigFromFileAndCombineWithOpenAIRequest(modelName, input, cl, startupOptions)
if err != nil {
return nil, fmt.Errorf("failed reading parameters from request:%w", err)
}
if input.ResponseFormat.Type == "json_object" {
input.Grammar = grammar.JSONBNF
}
log.Debug().Msgf("Parameter Config: %+v", config)
configTemplate := bindingFn(config)
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
if (*configTemplate == "") && (ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model))) {
*configTemplate = config.Model
}
if *configTemplate == "" {
return nil, fmt.Errorf(("failed to find templateConfig"))
}
return config, nil
}
////////// SPECIFIC REQUESTS //////////////
// TODO: For round one of the refactor, give each of the three primary text endpoints their own function?
// SEMITODO: During a merge, edit/completion were semi-combined - but remain nominally split
// Can cleanup into a common form later if possible easier if they are all here for now
// If they remain different, extract each of these named segments to a seperate file
func prepareChatGenerationOpenAIRequest(modelName string, input *schema.OpenAIRequest, cl *services.ConfigLoader, ml *model.ModelLoader, startupOptions *schema.StartupOptions) (*schema.Config, string, bool, error) {
// IMPORTANT DEFS
funcs := grammar.Functions{}
// The Basic Begining
config, input, err := ReadConfigFromFileAndCombineWithOpenAIRequest(modelName, input, cl, startupOptions)
if err != nil {
return nil, "", false, fmt.Errorf("failed reading parameters from request:%w", err)
}
log.Debug().Msgf("Configuration read: %+v", config)
// Special Input/Config Handling
// Allow the user to set custom actions via config file
// to be "embedded" in each model - but if they are missing, use defaults.
if config.FunctionsConfig.NoActionFunctionName == "" {
config.FunctionsConfig.NoActionFunctionName = DEFAULT_NO_ACTION_NAME
}
if config.FunctionsConfig.NoActionDescriptionName == "" {
config.FunctionsConfig.NoActionDescriptionName = DEFAULT_NO_ACTION_DESCRIPTION
}
if input.ResponseFormat.Type == "json_object" {
input.Grammar = grammar.JSONBNF
}
processFunctions := len(input.Functions) > 0 && config.ShouldUseFunctions()
if processFunctions {
log.Debug().Msgf("Response needs to process functions")
noActionGrammar := grammar.Function{
Name: config.FunctionsConfig.NoActionFunctionName,
Description: config.FunctionsConfig.NoActionDescriptionName,
Parameters: map[string]interface{}{
"properties": map[string]interface{}{
"message": map[string]interface{}{
"type": "string",
"description": "The message to reply the user with",
}},
},
}
// Append the no action function
funcs = append(funcs, input.Functions...)
if !config.FunctionsConfig.DisableNoAction {
funcs = append(funcs, noActionGrammar)
}
// Force picking one of the functions by the request
if config.FunctionToCall() != "" {
funcs = funcs.Select(config.FunctionToCall())
}
// Update input grammar
jsStruct := funcs.ToJSONStructure()
config.Grammar = jsStruct.Grammar("")
} else if input.JSONFunctionGrammarObject != nil {
config.Grammar = input.JSONFunctionGrammarObject.Grammar("")
}
log.Debug().Msgf("Parameters: %+v", config)
var predInput string
suppressConfigSystemPrompt := false
mess := []string{}
for messageIndex, i := range input.Messages {
var content string
role := i.Role
// if function call, we might want to customize the role so we can display better that the "assistant called a json action"
// if an "assistant_function_call" role is defined, we use it, otherwise we use the role that is passed by in the request
if i.FunctionCall != nil && i.Role == "assistant" {
roleFn := "assistant_function_call"
r := config.Roles[roleFn]
if r != "" {
role = roleFn
}
}
r := config.Roles[role]
contentExists := i.Content != nil && i.StringContent != ""
// First attempt to populate content via a chat message specific template
if config.TemplateConfig.ChatMessage != "" {
chatMessageData := model.ChatMessageTemplateData{
SystemPrompt: config.SystemPrompt,
Role: r,
RoleName: role,
Content: i.StringContent,
MessageIndex: messageIndex,
}
templatedChatMessage, err := ml.EvaluateTemplateForChatMessage(config.TemplateConfig.ChatMessage, chatMessageData)
if err != nil {
log.Error().Msgf("error processing message %+v using template \"%s\": %v. Skipping!", chatMessageData, config.TemplateConfig.ChatMessage, err)
} else {
if templatedChatMessage == "" {
log.Warn().Msgf("template \"%s\" produced blank output for %+v. Skipping!", config.TemplateConfig.ChatMessage, chatMessageData)
continue // TODO: This continue is here intentionally to skip over the line `mess = append(mess, content)` below, and to prevent the sprintf
}
log.Debug().Msgf("templated message for chat: %s", templatedChatMessage)
content = templatedChatMessage
}
}
// If this model doesn't have such a template, or if that template fails to return a value, template at the message level.
if content == "" {
if r != "" {
if contentExists {
content = fmt.Sprint(r, i.StringContent)
}
if i.FunctionCall != nil {
j, err := json.Marshal(i.FunctionCall)
if err == nil {
if contentExists {
content += "\n" + fmt.Sprint(r, " ", string(j))
} else {
content = fmt.Sprint(r, " ", string(j))
}
}
}
} else {
if contentExists {
content = fmt.Sprint(i.StringContent)
}
if i.FunctionCall != nil {
j, err := json.Marshal(i.FunctionCall)
if err == nil {
if contentExists {
content += "\n" + string(j)
} else {
content = string(j)
}
}
}
}
// Special Handling: System. We care if it was printed at all, not the r branch, so check seperately
if contentExists && role == "system" {
suppressConfigSystemPrompt = true
}
}
mess = append(mess, content)
}
predInput = strings.Join(mess, "\n")
log.Debug().Msgf("Prompt (before templating): %s", predInput)
templateFile := ""
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
if ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
templateFile = config.Model
}
if config.TemplateConfig.Chat != "" && !processFunctions {
templateFile = config.TemplateConfig.Chat
}
if config.TemplateConfig.Functions != "" && processFunctions {
templateFile = config.TemplateConfig.Functions
}
if templateFile != "" {
templatedInput, err := ml.EvaluateTemplateForPrompt(model.ChatPromptTemplate, templateFile, model.PromptTemplateData{
SystemPrompt: config.SystemPrompt,
SuppressSystemPrompt: suppressConfigSystemPrompt,
Input: predInput,
Functions: funcs,
})
if err == nil {
predInput = templatedInput
log.Debug().Msgf("Template found, input modified to: %s", predInput)
} else {
log.Debug().Msgf("Template failed loading: %s", err.Error())
}
}
log.Debug().Msgf("Prompt (after templating): %s", predInput)
if processFunctions {
log.Debug().Msgf("Grammar: %+v", config.Grammar)
}
return config, predInput, processFunctions, nil
}
func EditGenerationOpenAIRequest(modelName string, input *schema.OpenAIRequest, cl *services.ConfigLoader, ml *model.ModelLoader, startupOptions *schema.StartupOptions) (*schema.OpenAIResponse, error) {
id := uuid.New().String()
created := int(time.Now().Unix())
binding := func(config *schema.Config) *string {
return &config.TemplateConfig.Edit
}
config, err := prepareGenerationOpenAIRequest(binding, modelName, input, cl, ml, startupOptions)
if err != nil {
return nil, err
}
var result []schema.Choice
totalTokenUsage := TokenUsage{}
for _, i := range config.InputStrings {
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
templatedInput, err := ml.EvaluateTemplateForPrompt(model.EditPromptTemplate, config.TemplateConfig.Edit, model.PromptTemplateData{
Input: i,
Instruction: input.Instruction,
SystemPrompt: config.SystemPrompt,
})
if err == nil {
i = templatedInput
log.Debug().Msgf("Template found, input modified to: %s", i)
}
r, tokenUsage, err := ComputeChoices(input, i, config, startupOptions, ml, func(s string, c *[]schema.Choice) {
*c = append(*c, schema.Choice{Text: s})
}, nil)
if err != nil {
return nil, err
}
totalTokenUsage.Prompt += tokenUsage.Prompt
totalTokenUsage.Completion += tokenUsage.Completion
result = append(result, r...)
}
return &schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: result,
Object: "edit",
Usage: schema.OpenAIUsage{
PromptTokens: totalTokenUsage.Prompt,
CompletionTokens: totalTokenUsage.Completion,
TotalTokens: totalTokenUsage.Prompt + totalTokenUsage.Completion,
},
}, nil
}
func ChatGenerationOpenAIRequest(modelName string, input *schema.OpenAIRequest, cl *services.ConfigLoader, ml *model.ModelLoader, startupOptions *schema.StartupOptions) (*schema.OpenAIResponse, error) {
// DEFS
id := uuid.New().String()
created := int(time.Now().Unix())
// Prepare
config, predInput, processFunctions, err := prepareChatGenerationOpenAIRequest(modelName, input, cl, ml, startupOptions)
if err != nil {
return nil, err
}
result, tokenUsage, err := ComputeChoices(input, predInput, config, startupOptions, ml, func(s string, c *[]schema.Choice) {
if processFunctions {
// As we have to change the result before processing, we can't stream the answer (yet?)
ss := map[string]interface{}{}
// This prevent newlines to break JSON parsing for clients
s = utils.EscapeNewLines(s)
json.Unmarshal([]byte(s), &ss)
log.Debug().Msgf("Function return: %s %+v", s, ss)
// The grammar defines the function name as "function", while OpenAI returns "name"
func_name := ss["function"]
// Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object
args := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix)
d, _ := json.Marshal(args)
ss["arguments"] = string(d)
ss["name"] = func_name
// if do nothing, reply with a message
if func_name == config.FunctionsConfig.NoActionFunctionName {
log.Debug().Msgf("nothing to do, computing a reply")
// If there is a message that the LLM already sends as part of the JSON reply, use it
arguments := map[string]interface{}{}
json.Unmarshal([]byte(d), &arguments)
m, exists := arguments["message"]
if exists {
switch message := m.(type) {
case string:
if message != "" {
log.Debug().Msgf("Reply received from LLM: %s", message)
message = Finetune(*config, predInput, message)
log.Debug().Msgf("Reply received from LLM(finetuned): %s", message)
*c = append(*c, schema.Choice{Message: &schema.Message{Role: "assistant", Content: &message}})
return
}
}
}
log.Debug().Msgf("No action received from LLM, without a message, computing a reply")
// Otherwise ask the LLM to understand the JSON output and the context, and return a message
// Note: This costs (in term of CPU) another computation
config.Grammar = ""
images := []string{}
for _, m := range input.Messages {
images = append(images, m.StringImages...)
}
predFunc, err := ModelInference(input.Context, predInput, images, ml, *config, startupOptions, nil)
if err != nil {
log.Error().Msgf("inference error: %s", err.Error())
return
}
prediction, err := predFunc()
if err != nil {
log.Error().Msgf("inference error: %s", err.Error())
return
}
fineTunedResponse := Finetune(*config, predInput, prediction.Response)
*c = append(*c, schema.Choice{Message: &schema.Message{Role: "assistant", Content: &fineTunedResponse}})
} else {
// otherwise reply with the function call
*c = append(*c, schema.Choice{
FinishReason: "function_call",
Message: &schema.Message{Role: "assistant", FunctionCall: ss},
})
}
return
}
*c = append(*c, schema.Choice{FinishReason: "stop", Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}})
}, nil)
if err != nil {
return nil, err
}
return &schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: result,
Object: "chat.completion",
Usage: schema.OpenAIUsage{
PromptTokens: tokenUsage.Prompt,
CompletionTokens: tokenUsage.Completion,
TotalTokens: tokenUsage.Prompt + tokenUsage.Completion,
},
}, nil
}
func CompletionGenerationOpenAIRequest(modelName string, input *schema.OpenAIRequest, cl *services.ConfigLoader, ml *model.ModelLoader, startupOptions *schema.StartupOptions) (*schema.OpenAIResponse, error) {
// Prepare
id := uuid.New().String()
created := int(time.Now().Unix())
binding := func(config *schema.Config) *string {
return &config.TemplateConfig.Completion
}
config, err := prepareGenerationOpenAIRequest(binding, modelName, input, cl, ml, startupOptions)
if err != nil {
return nil, err
}
var result []schema.Choice
totalTokenUsage := TokenUsage{}
for k, i := range config.PromptStrings {
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
templatedInput, err := ml.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, config.TemplateConfig.Completion, model.PromptTemplateData{
SystemPrompt: config.SystemPrompt,
Input: i,
})
if err == nil {
i = templatedInput
log.Debug().Msgf("Template found, input modified to: %s", i)
}
r, tokenUsage, err := ComputeChoices(
input, i, config, startupOptions, ml, func(s string, c *[]schema.Choice) {
*c = append(*c, schema.Choice{Text: s, FinishReason: "stop", Index: k})
}, nil)
if err != nil {
return nil, err
}
totalTokenUsage.Prompt += tokenUsage.Prompt
totalTokenUsage.Completion += tokenUsage.Completion
result = append(result, r...)
}
return &schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: result,
Object: "text_completion",
Usage: schema.OpenAIUsage{
PromptTokens: totalTokenUsage.Prompt,
CompletionTokens: totalTokenUsage.Completion,
TotalTokens: totalTokenUsage.Prompt + totalTokenUsage.Completion,
},
}, nil
}
func StreamingChatGenerationOpenAIRequest(modelName string, input *schema.OpenAIRequest, cl *services.ConfigLoader, ml *model.ModelLoader, startupOptions *schema.StartupOptions) (chan schema.OpenAIResponse, error) {
// DEFS
emptyMessage := ""
id := uuid.New().String()
created := int(time.Now().Unix())
// Prepare
config, predInput, processFunctions, err := prepareChatGenerationOpenAIRequest(modelName, input, cl, ml, startupOptions)
if err != nil {
return nil, err
}
if processFunctions {
// TODO: unused variable means I did something wrong. investigate once stable
log.Debug().Msgf("StreamingChatGenerationOpenAIRequest with processFunctions=true for %s?", config.Name)
}
processor := func(s string, req *schema.OpenAIRequest, config *schema.Config, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
initialMessage := schema.OpenAIResponse{
ID: id,
Created: created,
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant", Content: &emptyMessage}}},
Object: "chat.completion.chunk",
}
responses <- initialMessage
ComputeChoices(req, s, config, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage TokenUsage) bool {
resp := schema.OpenAIResponse{
ID: id,
Created: created,
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{{Delta: &schema.Message{Content: &s}, Index: 0}},
Object: "chat.completion.chunk",
Usage: schema.OpenAIUsage{
PromptTokens: usage.Prompt,
CompletionTokens: usage.Completion,
TotalTokens: usage.Prompt + usage.Completion,
},
}
responses <- resp
return true
})
close(responses)
}
log.Trace().Msg("StreamingChatGenerationOpenAIRequest :: About to create response channel")
responses := make(chan schema.OpenAIResponse)
log.Trace().Msg("StreamingChatGenerationOpenAIRequest :: About to start processor goroutine")
go processor(predInput, input, config, ml, responses)
log.Trace().Msg("StreamingChatGenerationOpenAIRequest :: DONE! successfully returning to caller!")
return responses, nil
}
func StreamingCompletionGenerationOpenAIRequest(modelName string, input *schema.OpenAIRequest, cl *services.ConfigLoader, ml *model.ModelLoader, startupOptions *schema.StartupOptions) (chan schema.OpenAIResponse, error) {
// DEFS
id := uuid.New().String()
created := int(time.Now().Unix())
binding := func(config *schema.Config) *string {
return &config.TemplateConfig.Completion
}
// Prepare
config, err := prepareGenerationOpenAIRequest(binding, modelName, input, cl, ml, startupOptions)
if err != nil {
return nil, err
}
processor := func(s string, req *schema.OpenAIRequest, config *schema.Config, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
ComputeChoices(req, s, config, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage TokenUsage) bool {
resp := schema.OpenAIResponse{
ID: id,
Created: created,
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{
{
Index: 0,
Text: s,
},
},
Object: "text_completion",
Usage: schema.OpenAIUsage{
PromptTokens: usage.Prompt,
CompletionTokens: usage.Completion,
TotalTokens: usage.Prompt + usage.Completion,
},
}
log.Debug().Msgf("Sending goroutine: %s", s)
responses <- resp
return true
})
close(responses)
}
if len(config.PromptStrings) > 1 {
return nil, errors.New("cannot handle more than 1 `PromptStrings` when Streaming")
}
predInput := config.PromptStrings[0]
//A model can have a "file.bin.tmpl" file associated with a prompt template prefix
templatedInput, err := ml.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, config.TemplateConfig.Completion, model.PromptTemplateData{
Input: predInput,
})
if err == nil {
predInput = templatedInput
log.Debug().Msgf("Template found, input modified to: %s", predInput)
}
log.Trace().Msg("StreamingCompletionGenerationOpenAIRequest :: About to create response channel")
responses := make(chan schema.OpenAIResponse)
log.Trace().Msg("StreamingCompletionGenerationOpenAIRequest :: About to start processor goroutine")
go processor(predInput, input, config, ml, responses)
log.Trace().Msg("StreamingCompletionGenerationOpenAIRequest :: DONE! successfully returning to caller!")
return responses, nil
}

View file

@ -1,125 +0,0 @@
package backend
import (
"os"
"path/filepath"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/schema"
)
func modelOpts(c schema.Config, o *schema.StartupOptions, opts []model.Option) []model.Option {
if o.SingleBackend {
opts = append(opts, model.WithSingleActiveBackend())
}
if o.ParallelBackendRequests {
opts = append(opts, model.EnableParallelRequests)
}
if c.GRPC.Attempts != 0 {
opts = append(opts, model.WithGRPCAttempts(c.GRPC.Attempts))
}
if c.GRPC.AttemptsSleepTime != 0 {
opts = append(opts, model.WithGRPCAttemptsDelay(c.GRPC.AttemptsSleepTime))
}
for k, v := range o.ExternalGRPCBackends {
opts = append(opts, model.WithExternalBackend(k, v))
}
return opts
}
func gRPCModelOpts(c schema.Config) *pb.ModelOptions {
b := 512
if c.Batch != 0 {
b = c.Batch
}
return &pb.ModelOptions{
ContextSize: int32(c.ContextSize),
Seed: int32(c.Seed),
NBatch: int32(b),
NoMulMatQ: c.NoMulMatQ,
CUDA: c.CUDA, // diffusers, transformers
DraftModel: c.DraftModel,
AudioPath: c.VallE.AudioPath,
Quantization: c.Quantization,
MMProj: c.MMProj,
YarnExtFactor: c.YarnExtFactor,
YarnAttnFactor: c.YarnAttnFactor,
YarnBetaFast: c.YarnBetaFast,
YarnBetaSlow: c.YarnBetaSlow,
LoraAdapter: c.LoraAdapter,
LoraBase: c.LoraBase,
LoraScale: c.LoraScale,
NGQA: c.NGQA,
RMSNormEps: c.RMSNormEps,
F16Memory: c.F16,
MLock: c.MMlock,
RopeFreqBase: c.RopeFreqBase,
RopeFreqScale: c.RopeFreqScale,
NUMA: c.NUMA,
Embeddings: c.Embeddings,
LowVRAM: c.LowVRAM,
NGPULayers: int32(c.NGPULayers),
MMap: c.MMap,
MainGPU: c.MainGPU,
Threads: int32(c.Threads),
TensorSplit: c.TensorSplit,
// AutoGPTQ
ModelBaseName: c.AutoGPTQ.ModelBaseName,
Device: c.AutoGPTQ.Device,
UseTriton: c.AutoGPTQ.Triton,
UseFastTokenizer: c.AutoGPTQ.UseFastTokenizer,
// RWKV
Tokenizer: c.Tokenizer,
}
}
func gRPCPredictOpts(c schema.Config, modelPath string) *pb.PredictOptions {
promptCachePath := ""
if c.PromptCachePath != "" {
p := filepath.Join(modelPath, c.PromptCachePath)
os.MkdirAll(filepath.Dir(p), 0755)
promptCachePath = p
}
return &pb.PredictOptions{
Temperature: float32(c.Temperature),
TopP: float32(c.TopP),
NDraft: c.NDraft,
TopK: int32(c.TopK),
Tokens: int32(c.Maxtokens),
Threads: int32(c.Threads),
PromptCacheAll: c.PromptCacheAll,
PromptCacheRO: c.PromptCacheRO,
PromptCachePath: promptCachePath,
F16KV: c.F16,
DebugMode: c.Debug,
Grammar: c.Grammar,
NegativePromptScale: c.NegativePromptScale,
RopeFreqBase: c.RopeFreqBase,
RopeFreqScale: c.RopeFreqScale,
NegativePrompt: c.NegativePrompt,
Mirostat: int32(c.LLMConfig.Mirostat),
MirostatETA: float32(c.LLMConfig.MirostatETA),
MirostatTAU: float32(c.LLMConfig.MirostatTAU),
Debug: c.Debug,
StopPrompts: c.StopWords,
Repeat: int32(c.RepeatPenalty),
NKeep: int32(c.Keep),
Batch: int32(c.Batch),
IgnoreEOS: c.IgnoreEOS,
Seed: int32(c.Seed),
FrequencyPenalty: float32(c.FrequencyPenalty),
MLock: c.MMlock,
MMap: c.MMap,
MainGPU: c.MainGPU,
TensorSplit: c.TensorSplit,
TailFreeSamplingZ: float32(c.TFZ),
TypicalP: float32(c.TypicalP),
}
}

View file

@ -1,52 +0,0 @@
package backend
import (
"context"
"fmt"
"github.com/go-skynet/LocalAI/core/services"
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/schema"
)
func ModelTranscription(audio, language string, loader *model.ModelLoader, c schema.Config, o *schema.StartupOptions) (*schema.WhisperResult, error) {
opts := modelOpts(c, o, []model.Option{
model.WithBackendString(model.WhisperBackend),
model.WithModel(c.Model),
model.WithContext(o.Context),
model.WithThreads(uint32(c.Threads)),
model.WithAssetDir(o.AssetsDestination),
model.WithExternalBackends(o.ExternalGRPCBackends, false),
})
whisperModel, err := loader.BackendLoader(opts...)
if err != nil {
return nil, err
}
if whisperModel == nil {
return nil, fmt.Errorf("could not load whisper model")
}
return whisperModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{
Dst: audio,
Language: language,
Threads: uint32(c.Threads),
})
}
func TranscriptionOpenAIRequest(modelName string, input *schema.OpenAIRequest, audioFilePath string, cl *services.ConfigLoader, ml *model.ModelLoader, startupOptions *schema.StartupOptions) (*schema.WhisperResult, error) {
config, input, err := ReadConfigFromFileAndCombineWithOpenAIRequest(modelName, input, cl, startupOptions)
if err != nil {
return nil, fmt.Errorf("failed reading parameters from request:%w", err)
}
tr, err := ModelTranscription(audioFilePath, input.Language, ml, *config, startupOptions)
if err != nil {
return nil, err
}
return tr, nil
}

View file

@ -1,79 +0,0 @@
package backend
import (
"context"
"fmt"
"os"
"path/filepath"
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/schema"
"github.com/go-skynet/LocalAI/pkg/utils"
)
func generateUniqueFileName(dir, baseName, ext string) string {
counter := 1
fileName := baseName + ext
for {
filePath := filepath.Join(dir, fileName)
_, err := os.Stat(filePath)
if os.IsNotExist(err) {
return fileName
}
counter++
fileName = fmt.Sprintf("%s_%d%s", baseName, counter, ext)
}
}
func ModelTTS(backend, text, modelFile string, loader *model.ModelLoader, o *schema.StartupOptions) (string, *proto.Result, error) {
bb := backend
if bb == "" {
bb = model.PiperBackend
}
opts := modelOpts(schema.Config{}, o, []model.Option{
model.WithBackendString(bb),
model.WithModel(modelFile),
model.WithContext(o.Context),
model.WithAssetDir(o.AssetsDestination),
model.WithExternalBackends(o.ExternalGRPCBackends, false),
})
piperModel, err := loader.BackendLoader(opts...)
if err != nil {
return "", nil, err
}
if piperModel == nil {
return "", nil, fmt.Errorf("could not load piper model")
}
if err := os.MkdirAll(o.AudioDir, 0755); err != nil {
return "", nil, fmt.Errorf("failed creating audio directory: %s", err)
}
fileName := generateUniqueFileName(o.AudioDir, "piper", ".wav")
filePath := filepath.Join(o.AudioDir, fileName)
// If the model file is not empty, we pass it joined with the model path
modelPath := ""
if modelFile != "" {
if bb != model.TransformersMusicGen {
modelPath = filepath.Join(o.ModelPath, modelFile)
if err := utils.VerifyPath(modelPath, o.ModelPath); err != nil {
return "", nil, err
}
} else {
modelPath = modelFile
}
}
res, err := piperModel.TTS(context.Background(), &proto.TTSRequest{
Text: text,
Model: modelPath,
Dst: filePath,
})
return filePath, res, err
}

View file

@ -1,169 +0,0 @@
package http
import (
"errors"
"strings"
"github.com/go-skynet/LocalAI/core/http/endpoints/localai"
"github.com/go-skynet/LocalAI/core/http/endpoints/openai"
"github.com/go-skynet/LocalAI/core/services"
"github.com/go-skynet/LocalAI/internal"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/schema"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/cors"
"github.com/gofiber/fiber/v2/middleware/logger"
"github.com/gofiber/fiber/v2/middleware/recover"
)
func App(cl *services.ConfigLoader, ml *model.ModelLoader, options *schema.StartupOptions) (*fiber.App, error) {
// Return errors as JSON responses
app := fiber.New(fiber.Config{
BodyLimit: options.UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB
DisableStartupMessage: options.DisableMessage,
// Override default error handler
ErrorHandler: func(ctx *fiber.Ctx, err error) error {
// Status code defaults to 500
code := fiber.StatusInternalServerError
// Retrieve the custom status code if it's a *fiber.Error
var e *fiber.Error
if errors.As(err, &e) {
code = e.Code
}
// Send custom error page
return ctx.Status(code).JSON(
schema.ErrorResponse{
Error: &schema.APIError{Message: err.Error(), Code: code},
},
)
},
})
if options.Debug {
app.Use(logger.New(logger.Config{
Format: "[${ip}]:${port} ${status} - ${method} ${path}\n",
}))
}
// Default middleware config
app.Use(recover.New())
if options.Metrics != nil {
app.Use(localai.MetricsAPIMiddleware(options.Metrics))
}
// Auth middleware checking if API key is valid. If no API key is set, no auth is required.
auth := func(c *fiber.Ctx) error {
if len(options.ApiKeys) == 0 {
return c.Next()
}
authHeader := c.Get("Authorization")
if authHeader == "" {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Authorization header missing"})
}
authHeaderParts := strings.Split(authHeader, " ")
if len(authHeaderParts) != 2 || authHeaderParts[0] != "Bearer" {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid Authorization header format"})
}
apiKey := authHeaderParts[1]
for _, key := range options.ApiKeys {
if apiKey == key {
return c.Next()
}
}
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid API key"})
}
if options.CORS {
var c func(ctx *fiber.Ctx) error
if options.CORSAllowOrigins == "" {
c = cors.New()
} else {
c = cors.New(cors.Config{AllowOrigins: options.CORSAllowOrigins})
}
app.Use(c)
}
// LocalAI API endpoints
galleryService := services.NewGalleryApplier(options.ModelPath)
galleryService.Start(options.Context, cl)
app.Get("/version", auth, func(c *fiber.Ctx) error {
return c.JSON(struct {
Version string `json:"version"`
}{Version: internal.PrintableVersion()})
})
modelGalleryService := localai.CreateModelGalleryEndpointService(options.Galleries, options.ModelPath, galleryService)
app.Post("/models/apply", auth, modelGalleryService.ApplyModelGalleryEndpoint())
app.Get("/models/available", auth, modelGalleryService.ListModelFromGalleryEndpoint())
app.Get("/models/galleries", auth, modelGalleryService.ListModelGalleriesEndpoint())
app.Post("/models/galleries", auth, modelGalleryService.AddModelGalleryEndpoint())
app.Delete("/models/galleries", auth, modelGalleryService.RemoveModelGalleryEndpoint())
app.Get("/models/jobs/:uuid", auth, modelGalleryService.GetOpStatusEndpoint())
app.Get("/models/jobs", auth, modelGalleryService.GetAllStatusEndpoint())
// openAI compatible API endpoint
// chat
app.Post("/v1/chat/completions", auth, openai.ChatEndpoint(cl, ml, options))
app.Post("/chat/completions", auth, openai.ChatEndpoint(cl, ml, options))
// edit
app.Post("/v1/edits", auth, openai.EditEndpoint(cl, ml, options))
app.Post("/edits", auth, openai.EditEndpoint(cl, ml, options))
// completion
app.Post("/v1/completions", auth, openai.CompletionEndpoint(cl, ml, options))
app.Post("/completions", auth, openai.CompletionEndpoint(cl, ml, options))
app.Post("/v1/engines/:model/completions", auth, openai.CompletionEndpoint(cl, ml, options))
// embeddings
app.Post("/v1/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, options))
app.Post("/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, options))
app.Post("/v1/engines/:model/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, options))
// audio
app.Post("/v1/audio/transcriptions", auth, openai.TranscriptEndpoint(cl, ml, options))
app.Post("/tts", auth, localai.TTSEndpoint(cl, ml, options))
// images
app.Post("/v1/images/generations", auth, openai.ImageEndpoint(cl, ml, options))
if options.ImageDir != "" {
app.Static("/generated-images", options.ImageDir)
}
if options.AudioDir != "" {
app.Static("/generated-audio", options.AudioDir)
}
ok := func(c *fiber.Ctx) error {
return c.SendStatus(200)
}
// Kubernetes health checks
app.Get("/healthz", ok)
app.Get("/readyz", ok)
app.Get("/metrics", localai.MetricsHandler())
backendMonitor := services.NewBackendMonitor(cl, ml, options)
app.Get("/backend/monitor", localai.BackendMonitorEndpoint(backendMonitor))
app.Post("/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitor))
// model listing
app.Get("/v1/models", auth, openai.ListModelsEndpoint(cl, ml))
app.Get("/models", auth, openai.ListModelsEndpoint(cl, ml))
return app, nil
}

View file

@ -1,867 +0,0 @@
package http_test
import (
"bytes"
"context"
"embed"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"runtime"
server "github.com/go-skynet/LocalAI/core/http"
"github.com/go-skynet/LocalAI/core/services"
"github.com/go-skynet/LocalAI/core/startup"
"github.com/go-skynet/LocalAI/pkg/gallery"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/schema"
"github.com/go-skynet/LocalAI/pkg/utils"
"github.com/gofiber/fiber/v2"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"gopkg.in/yaml.v3"
openaigo "github.com/otiai10/openaigo"
"github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/jsonschema"
)
type modelApplyRequest struct {
ID string `json:"id"`
URL string `json:"url"`
Name string `json:"name"`
Overrides map[string]interface{} `json:"overrides"`
}
func getModelStatus(url string) (response map[string]interface{}) {
// Create the HTTP request
resp, err := http.Get(url)
if err != nil {
fmt.Println("Error creating request:", err)
return
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
fmt.Println("Error reading response body:", err)
return
}
// Unmarshal the response into a map[string]interface{}
err = json.Unmarshal(body, &response)
if err != nil {
fmt.Println("Error unmarshaling JSON response:", err)
return
}
return
}
func getModels(url string) (response []gallery.GalleryModel) {
utils.GetURI(url, func(url string, i []byte) error {
// Unmarshal YAML data into a struct
return json.Unmarshal(i, &response)
})
return
}
func postModelApplyRequest(url string, request modelApplyRequest) (response map[string]interface{}) {
//url := "http://localhost:AI/models/apply"
// Create the request payload
payload, err := json.Marshal(request)
if err != nil {
fmt.Println("Error marshaling JSON:", err)
return
}
// Create the HTTP request
req, err := http.NewRequest("POST", url, bytes.NewBuffer(payload))
if err != nil {
fmt.Println("Error creating request:", err)
return
}
req.Header.Set("Content-Type", "application/json")
// Make the request
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
fmt.Println("Error making request:", err)
return
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
fmt.Println("Error reading response body:", err)
return
}
// Unmarshal the response into a map[string]interface{}
err = json.Unmarshal(body, &response)
if err != nil {
fmt.Println("Error unmarshaling JSON response:", err)
return
}
return
}
//go:embed backend-assets/*
var backendAssets embed.FS
var _ = Describe("API test", func() {
var app *fiber.App
var client *openai.Client
var client2 *openaigo.Client
var c context.Context
var cancel context.CancelFunc
var tmpdir string
commonOpts := []schema.AppOption{
schema.WithDebug(true),
schema.WithDisableMessage(true),
}
Context("API with ephemeral models", func() {
BeforeEach(func() {
var err error
tmpdir, err = os.MkdirTemp("", "")
Expect(err).ToNot(HaveOccurred())
c, cancel = context.WithCancel(context.Background())
g := []gallery.GalleryModel{
{
Name: "bert",
URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml",
},
{
Name: "bert2",
URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml",
Overrides: map[string]interface{}{"foo": "bar"},
AdditionalFiles: []gallery.File{{Filename: "foo.yaml", URI: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml"}},
},
}
out, err := yaml.Marshal(g)
Expect(err).ToNot(HaveOccurred())
err = os.WriteFile(filepath.Join(tmpdir, "gallery_simple.yaml"), out, 0644)
Expect(err).ToNot(HaveOccurred())
galleries := []gallery.Gallery{
{
Name: "test",
URL: "file://" + filepath.Join(tmpdir, "gallery_simple.yaml"),
},
}
metricsService, err := services.SetupMetrics()
Expect(err).ToNot(HaveOccurred())
cl, ml, options, err := startup.Startup(
append(commonOpts,
schema.WithMetrics(metricsService),
schema.WithContext(c),
schema.WithGalleries(galleries),
schema.WithModelPath(tmpdir),
schema.WithBackendAssets(backendAssets),
schema.WithBackendAssetsOutput(tmpdir))...)
Expect(err).ToNot(HaveOccurred())
app, err = server.App(cl, ml, options)
Expect(err).ToNot(HaveOccurred())
go app.Listen("127.0.0.1:9090")
defaultConfig := openai.DefaultConfig("")
defaultConfig.BaseURL = "http://127.0.0.1:9090/v1"
client2 = openaigo.NewClient("")
client2.BaseURL = defaultConfig.BaseURL
// Wait for API to be ready
client = openai.NewClientWithConfig(defaultConfig)
Eventually(func() error {
_, err := client.ListModels(context.TODO())
return err
}, "2m").ShouldNot(HaveOccurred())
})
AfterEach(func() {
cancel()
app.Shutdown()
os.RemoveAll(tmpdir)
})
Context("Applying models", func() {
It("applies models from a gallery", func() {
models := getModels("http://127.0.0.1:9090/models/available")
Expect(len(models)).To(Equal(2), fmt.Sprint(models))
Expect(models[0].Installed).To(BeFalse(), fmt.Sprint(models))
Expect(models[1].Installed).To(BeFalse(), fmt.Sprint(models))
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
ID: "test@bert2",
})
Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response))
uuid := response["uuid"].(string)
resp := map[string]interface{}{}
Eventually(func() bool {
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
fmt.Println(response)
resp = response
return response["processed"].(bool)
}, "360s", "10s").Should(Equal(true))
Expect(resp["message"]).ToNot(ContainSubstring("error"))
dat, err := os.ReadFile(filepath.Join(tmpdir, "bert2.yaml"))
Expect(err).ToNot(HaveOccurred())
_, err = os.ReadFile(filepath.Join(tmpdir, "foo.yaml"))
Expect(err).ToNot(HaveOccurred())
content := map[string]interface{}{}
err = yaml.Unmarshal(dat, &content)
Expect(err).ToNot(HaveOccurred())
Expect(content["backend"]).To(Equal("bert-embeddings"))
Expect(content["foo"]).To(Equal("bar"))
models = getModels("http://127.0.0.1:9090/models/available")
Expect(len(models)).To(Equal(2), fmt.Sprint(models))
Expect(models[0].Name).To(Or(Equal("bert"), Equal("bert2")))
Expect(models[1].Name).To(Or(Equal("bert"), Equal("bert2")))
for _, m := range models {
if m.Name == "bert2" {
Expect(m.Installed).To(BeTrue())
} else {
Expect(m.Installed).To(BeFalse())
}
}
})
It("overrides models", func() {
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml",
Name: "bert",
Overrides: map[string]interface{}{
"backend": "llama",
},
})
Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response))
uuid := response["uuid"].(string)
Eventually(func() bool {
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
return response["processed"].(bool)
}, "360s", "10s").Should(Equal(true))
dat, err := os.ReadFile(filepath.Join(tmpdir, "bert.yaml"))
Expect(err).ToNot(HaveOccurred())
content := map[string]interface{}{}
err = yaml.Unmarshal(dat, &content)
Expect(err).ToNot(HaveOccurred())
Expect(content["backend"]).To(Equal("llama"))
})
It("apply models without overrides", func() {
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml",
Name: "bert",
Overrides: map[string]interface{}{},
})
Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response))
uuid := response["uuid"].(string)
Eventually(func() bool {
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
return response["processed"].(bool)
}, "360s", "10s").Should(Equal(true))
dat, err := os.ReadFile(filepath.Join(tmpdir, "bert.yaml"))
Expect(err).ToNot(HaveOccurred())
content := map[string]interface{}{}
err = yaml.Unmarshal(dat, &content)
Expect(err).ToNot(HaveOccurred())
Expect(content["backend"]).To(Equal("bert-embeddings"))
})
It("runs openllama(llama-ggml backend)", Label("llama"), func() {
if runtime.GOOS != "linux" {
Skip("test supported only on linux")
}
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
URL: "github:go-skynet/model-gallery/openllama_3b.yaml",
Name: "openllama_3b",
Overrides: map[string]interface{}{"backend": "llama-ggml", "mmap": true, "f16": true, "context_size": 128},
})
Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response))
uuid := response["uuid"].(string)
Eventually(func() bool {
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
return response["processed"].(bool)
}, "360s", "10s").Should(Equal(true))
By("testing completion")
resp, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "openllama_3b", Prompt: "Count up to five: one, two, three, four, "})
Expect(err).ToNot(HaveOccurred())
Expect(len(resp.Choices)).To(Equal(1))
Expect(resp.Choices[0].Text).To(ContainSubstring("five"))
By("testing functions")
resp2, err := client.CreateChatCompletion(
context.TODO(),
openai.ChatCompletionRequest{
Model: "openllama_3b",
Messages: []openai.ChatCompletionMessage{
{
Role: "user",
Content: "What is the weather like in San Francisco (celsius)?",
},
},
Functions: []openai.FunctionDefinition{
openai.FunctionDefinition{
Name: "get_current_weather",
Description: "Get the current weather",
Parameters: jsonschema.Definition{
Type: jsonschema.Object,
Properties: map[string]jsonschema.Definition{
"location": {
Type: jsonschema.String,
Description: "The city and state, e.g. San Francisco, CA",
},
"unit": {
Type: jsonschema.String,
Enum: []string{"celcius", "fahrenheit"},
},
},
Required: []string{"location"},
},
},
},
})
Expect(err).ToNot(HaveOccurred())
Expect(len(resp2.Choices)).To(Equal(1))
Expect(resp2.Choices[0].Message.FunctionCall).ToNot(BeNil())
Expect(resp2.Choices[0].Message.FunctionCall.Name).To(Equal("get_current_weather"), resp2.Choices[0].Message.FunctionCall.Name)
var res map[string]string
err = json.Unmarshal([]byte(resp2.Choices[0].Message.FunctionCall.Arguments), &res)
Expect(err).ToNot(HaveOccurred())
Expect(res["location"]).To(Equal("San Francisco, California, United States"), fmt.Sprint(res))
Expect(res["unit"]).To(Equal("celcius"), fmt.Sprint(res))
Expect(string(resp2.Choices[0].FinishReason)).To(Equal("function_call"), fmt.Sprint(resp2.Choices[0].FinishReason))
})
It("runs openllama gguf(llama-cpp)", Label("llama-gguf"), func() {
if runtime.GOOS != "linux" {
Skip("test supported only on linux")
}
modelName := "codellama"
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
URL: "github:go-skynet/model-gallery/codellama-7b-instruct.yaml",
Name: modelName,
Overrides: map[string]interface{}{"backend": "llama", "mmap": true, "f16": true, "context_size": 128},
})
Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response))
uuid := response["uuid"].(string)
Eventually(func() bool {
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
return response["processed"].(bool)
}, "360s", "10s").Should(Equal(true))
By("testing chat")
resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: modelName, Messages: []openai.ChatCompletionMessage{
{
Role: "user",
Content: "How much is 2+2?",
},
}})
Expect(err).ToNot(HaveOccurred())
Expect(len(resp.Choices)).To(Equal(1))
Expect(resp.Choices[0].Message.Content).To(Or(ContainSubstring("4"), ContainSubstring("four")))
By("testing functions")
resp2, err := client.CreateChatCompletion(
context.TODO(),
openai.ChatCompletionRequest{
Model: modelName,
Messages: []openai.ChatCompletionMessage{
{
Role: "user",
Content: "What is the weather like in San Francisco (celsius)?",
},
},
Functions: []openai.FunctionDefinition{
openai.FunctionDefinition{
Name: "get_current_weather",
Description: "Get the current weather",
Parameters: jsonschema.Definition{
Type: jsonschema.Object,
Properties: map[string]jsonschema.Definition{
"location": {
Type: jsonschema.String,
Description: "The city and state, e.g. San Francisco, CA",
},
"unit": {
Type: jsonschema.String,
Enum: []string{"celcius", "fahrenheit"},
},
},
Required: []string{"location"},
},
},
},
})
Expect(err).ToNot(HaveOccurred())
Expect(len(resp2.Choices)).To(Equal(1))
Expect(resp2.Choices[0].Message.FunctionCall).ToNot(BeNil())
Expect(resp2.Choices[0].Message.FunctionCall.Name).To(Equal("get_current_weather"), resp2.Choices[0].Message.FunctionCall.Name)
var res map[string]string
err = json.Unmarshal([]byte(resp2.Choices[0].Message.FunctionCall.Arguments), &res)
Expect(err).ToNot(HaveOccurred())
Expect(res["location"]).To(Equal("San Francisco"), fmt.Sprint(res))
Expect(res["unit"]).To(Equal("celcius"), fmt.Sprint(res))
Expect(string(resp2.Choices[0].FinishReason)).To(Equal("function_call"), fmt.Sprint(resp2.Choices[0].FinishReason))
})
It("runs gpt4all", Label("gpt4all"), func() {
if runtime.GOOS != "linux" {
Skip("test supported only on linux")
}
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
URL: "github:go-skynet/model-gallery/gpt4all-j.yaml",
Name: "gpt4all-j",
})
Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response))
uuid := response["uuid"].(string)
Eventually(func() bool {
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
return response["processed"].(bool)
}, "960s", "10s").Should(Equal(true))
resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "gpt4all-j", Messages: []openai.ChatCompletionMessage{openai.ChatCompletionMessage{Role: "user", Content: "How are you?"}}})
Expect(err).ToNot(HaveOccurred())
Expect(len(resp.Choices)).To(Equal(1))
Expect(resp.Choices[0].Message.Content).To(ContainSubstring("well"))
})
})
})
Context("Model gallery", func() {
BeforeEach(func() {
var err error
tmpdir, err = os.MkdirTemp("", "")
Expect(err).ToNot(HaveOccurred())
c, cancel = context.WithCancel(context.Background())
galleries := []gallery.Gallery{
{
Name: "model-gallery",
URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/index.yaml",
},
}
metricsService, err := services.SetupMetrics()
Expect(err).ToNot(HaveOccurred())
cl, ml, options, err := startup.Startup(
append(commonOpts,
schema.WithContext(c),
schema.WithMetrics(metricsService),
schema.WithAudioDir(tmpdir),
schema.WithImageDir(tmpdir),
schema.WithGalleries(galleries),
schema.WithModelPath(tmpdir),
schema.WithBackendAssets(backendAssets),
schema.WithBackendAssetsOutput(tmpdir))...,
)
Expect(err).ToNot(HaveOccurred())
app, err = server.App(cl, ml, options)
go app.Listen("127.0.0.1:9090")
defaultConfig := openai.DefaultConfig("")
defaultConfig.BaseURL = "http://127.0.0.1:9090/v1"
client2 = openaigo.NewClient("")
client2.BaseURL = defaultConfig.BaseURL
// Wait for API to be ready
client = openai.NewClientWithConfig(defaultConfig)
Eventually(func() error {
_, err := client.ListModels(context.TODO())
return err
}, "2m").ShouldNot(HaveOccurred())
})
AfterEach(func() {
cancel()
app.Shutdown()
os.RemoveAll(tmpdir)
})
It("installs and is capable to run tts", Label("tts"), func() {
if runtime.GOOS != "linux" {
Skip("test supported only on linux")
}
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
ID: "model-gallery@voice-en-us-kathleen-low",
})
Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response))
uuid := response["uuid"].(string)
Eventually(func() bool {
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
fmt.Println(response)
return response["processed"].(bool)
}, "360s", "10s").Should(Equal(true))
// An HTTP Post to the /tts endpoint should return a wav audio file
resp, err := http.Post("http://127.0.0.1:9090/tts", "application/json", bytes.NewBuffer([]byte(`{"input": "Hello world", "model": "en-us-kathleen-low.onnx"}`)))
Expect(err).ToNot(HaveOccurred(), fmt.Sprint(resp))
dat, err := io.ReadAll(resp.Body)
Expect(err).ToNot(HaveOccurred(), fmt.Sprint(resp))
Expect(resp.StatusCode).To(Equal(200), fmt.Sprint(string(dat)))
Expect(resp.Header.Get("Content-Type")).To(Equal("audio/x-wav"))
})
It("installs and is capable to generate images", Label("stablediffusion"), func() {
if runtime.GOOS != "linux" {
Skip("test supported only on linux")
}
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
ID: "model-gallery@stablediffusion",
Overrides: map[string]interface{}{
"parameters": map[string]interface{}{"model": "stablediffusion_assets"},
},
})
Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response))
uuid := response["uuid"].(string)
Eventually(func() bool {
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
fmt.Println(response)
return response["processed"].(bool)
}, "360s", "10s").Should(Equal(true))
resp, err := http.Post(
"http://127.0.0.1:9090/v1/images/generations",
"application/json",
bytes.NewBuffer([]byte(`{
"prompt": "floating hair, portrait, ((loli)), ((one girl)), cute face, hidden hands, asymmetrical bangs, beautiful detailed eyes, eye shadow, hair ornament, ribbons, bowties, buttons, pleated skirt, (((masterpiece))), ((best quality)), colorful|((part of the head)), ((((mutated hands and fingers)))), deformed, blurry, bad anatomy, disfigured, poorly drawn face, mutation, mutated, extra limb, ugly, poorly drawn hands, missing limb, blurry, floating limbs, disconnected limbs, malformed hands, blur, out of focus, long neck, long body, Octane renderer, lowres, bad anatomy, bad hands, text",
"mode": 2, "seed":9000,
"size": "256x256", "n":2}`)))
// The response should contain an URL
Expect(err).ToNot(HaveOccurred(), fmt.Sprint(resp))
dat, err := io.ReadAll(resp.Body)
Expect(err).ToNot(HaveOccurred(), string(dat))
Expect(string(dat)).To(ContainSubstring("http://127.0.0.1:9090/"), string(dat))
Expect(string(dat)).To(ContainSubstring(".png"), string(dat))
})
})
Context("API query", func() {
BeforeEach(func() {
c, cancel = context.WithCancel(context.Background())
metricsService, err := services.SetupMetrics()
Expect(err).ToNot(HaveOccurred())
cl, ml, options, err := startup.Startup(
append(commonOpts,
schema.WithExternalBackend("huggingface", os.Getenv("HUGGINGFACE_GRPC")),
schema.WithContext(c),
schema.WithModelPath(os.Getenv("MODELS_PATH")),
schema.WithMetrics(metricsService),
)...)
Expect(err).ToNot(HaveOccurred())
app, err = server.App(cl, ml, options)
Expect(err).ToNot(HaveOccurred())
go app.Listen("127.0.0.1:9090")
defaultConfig := openai.DefaultConfig("")
defaultConfig.BaseURL = "http://127.0.0.1:9090/v1"
client2 = openaigo.NewClient("")
client2.BaseURL = defaultConfig.BaseURL
// Wait for API to be ready
client = openai.NewClientWithConfig(defaultConfig)
Eventually(func() error {
_, err := client.ListModels(context.TODO())
return err
}, "2m").ShouldNot(HaveOccurred())
})
AfterEach(func() {
cancel()
app.Shutdown()
})
It("returns the models list", func() {
models, err := client.ListModels(context.TODO())
Expect(err).ToNot(HaveOccurred())
Expect(len(models.Models)).To(Equal(6)) // If "config.yaml" should be included, this should be 8?
})
It("can generate completions", func() {
resp, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "testmodel", Prompt: "abcdedfghikl"})
Expect(err).ToNot(HaveOccurred())
Expect(len(resp.Choices)).To(Equal(1))
Expect(resp.Choices[0].Text).ToNot(BeEmpty())
})
It("can generate chat completions ", func() {
resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "testmodel", Messages: []openai.ChatCompletionMessage{openai.ChatCompletionMessage{Role: "user", Content: "abcdedfghikl"}}})
Expect(err).ToNot(HaveOccurred())
Expect(len(resp.Choices)).To(Equal(1))
Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty())
})
It("can generate completions from model configs", func() {
resp, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "gpt4all", Prompt: "abcdedfghikl"})
Expect(err).ToNot(HaveOccurred())
Expect(len(resp.Choices)).To(Equal(1))
Expect(resp.Choices[0].Text).ToNot(BeEmpty())
})
It("can generate chat completions from model configs", func() {
resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "gpt4all-2", Messages: []openai.ChatCompletionMessage{openai.ChatCompletionMessage{Role: "user", Content: "abcdedfghikl"}}})
Expect(err).ToNot(HaveOccurred())
Expect(len(resp.Choices)).To(Equal(1))
Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty())
})
It("returns errors", func() {
backends := len(model.AutoLoadBackends) + 1 // +1 for huggingface
_, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "foomodel", Prompt: "abcdedfghikl"})
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring(fmt.Sprintf("error, status code: 500, message: could not load model - all backends returned error: %d errors occurred:", backends)))
})
It("transcribes audio", func() {
if runtime.GOOS != "linux" {
Skip("test supported only on linux")
}
resp, err := client.CreateTranscription(
context.Background(),
openai.AudioRequest{
Model: openai.Whisper1,
FilePath: filepath.Join(os.Getenv("TEST_DIR"), "audio.wav"),
},
)
Expect(err).ToNot(HaveOccurred())
Expect(resp.Text).To(ContainSubstring("This is the Micro Machine Man presenting"))
})
It("calculate embeddings", func() {
if runtime.GOOS != "linux" {
Skip("test supported only on linux")
}
resp, err := client.CreateEmbeddings(
context.Background(),
openai.EmbeddingRequest{
Model: openai.AdaEmbeddingV2,
Input: []string{"sun", "cat"},
},
)
Expect(err).ToNot(HaveOccurred(), err)
Expect(len(resp.Data[0].Embedding)).To(BeNumerically("==", 384))
Expect(len(resp.Data[1].Embedding)).To(BeNumerically("==", 384))
sunEmbedding := resp.Data[0].Embedding
resp2, err := client.CreateEmbeddings(
context.Background(),
openai.EmbeddingRequest{
Model: openai.AdaEmbeddingV2,
Input: []string{"sun"},
},
)
Expect(err).ToNot(HaveOccurred())
Expect(resp2.Data[0].Embedding).To(Equal(sunEmbedding))
})
Context("External gRPC calls", func() {
It("calculate embeddings with sentencetransformers", func() {
if runtime.GOOS != "linux" {
Skip("test supported only on linux")
}
resp, err := client.CreateEmbeddings(
context.Background(),
openai.EmbeddingRequest{
Model: openai.AdaCodeSearchCode,
Input: []string{"sun", "cat"},
},
)
Expect(err).ToNot(HaveOccurred())
Expect(len(resp.Data[0].Embedding)).To(BeNumerically("==", 384))
Expect(len(resp.Data[1].Embedding)).To(BeNumerically("==", 384))
sunEmbedding := resp.Data[0].Embedding
resp2, err := client.CreateEmbeddings(
context.Background(),
openai.EmbeddingRequest{
Model: openai.AdaCodeSearchCode,
Input: []string{"sun"},
},
)
Expect(err).ToNot(HaveOccurred())
Expect(resp2.Data[0].Embedding).To(Equal(sunEmbedding))
Expect(resp2.Data[0].Embedding).ToNot(Equal(resp.Data[1].Embedding))
})
})
Context("backends", func() {
It("runs rwkv completion", func() {
if runtime.GOOS != "linux" {
Skip("test supported only on linux")
}
resp, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "rwkv_test", Prompt: "Count up to five: one, two, three, four,"})
Expect(err).ToNot(HaveOccurred())
Expect(len(resp.Choices) > 0).To(BeTrue())
Expect(resp.Choices[0].Text).To(ContainSubstring("five"))
stream, err := client.CreateCompletionStream(context.TODO(), openai.CompletionRequest{
Model: "rwkv_test", Prompt: "Count up to five: one, two, three, four,", Stream: true,
})
Expect(err).ToNot(HaveOccurred())
defer stream.Close()
tokens := 0
text := ""
for {
response, err := stream.Recv()
if errors.Is(err, io.EOF) {
break
}
Expect(err).ToNot(HaveOccurred())
text += response.Choices[0].Text
tokens++
}
Expect(text).ToNot(BeEmpty())
Expect(text).To(ContainSubstring("five"))
Expect(tokens).ToNot(Or(Equal(1), Equal(0)))
})
It("runs rwkv chat completion", func() {
if runtime.GOOS != "linux" {
Skip("test supported only on linux")
}
resp, err := client.CreateChatCompletion(context.TODO(),
openai.ChatCompletionRequest{Model: "rwkv_test", Messages: []openai.ChatCompletionMessage{{Content: "Can you count up to five?", Role: "user"}}})
Expect(err).ToNot(HaveOccurred())
Expect(len(resp.Choices) > 0).To(BeTrue())
Expect(resp.Choices[0].Message.Content).To(Or(ContainSubstring("Sure"), ContainSubstring("five")))
stream, err := client.CreateChatCompletionStream(context.TODO(), openai.ChatCompletionRequest{Model: "rwkv_test", Messages: []openai.ChatCompletionMessage{{Content: "Can you count up to five?", Role: "user"}}})
Expect(err).ToNot(HaveOccurred())
defer stream.Close()
tokens := 0
text := ""
for {
response, err := stream.Recv()
if errors.Is(err, io.EOF) {
break
}
Expect(err).ToNot(HaveOccurred())
text += response.Choices[0].Delta.Content
tokens++
}
Expect(text).ToNot(BeEmpty())
Expect(text).To(Or(ContainSubstring("Sure"), ContainSubstring("five")))
Expect(tokens).ToNot(Or(Equal(1), Equal(0)))
})
})
})
Context("Config file", func() {
BeforeEach(func() {
c, cancel = context.WithCancel(context.Background())
metricsService, err := services.SetupMetrics()
Expect(err).ToNot(HaveOccurred())
cl, ml, options, err := startup.Startup(
append(commonOpts,
schema.WithContext(c),
schema.WithMetrics(metricsService),
schema.WithModelPath(os.Getenv("MODELS_PATH")),
schema.WithConfigFile(os.Getenv("CONFIG_FILE")))...,
)
Expect(err).ToNot(HaveOccurred())
app, err = server.App(cl, ml, options)
Expect(err).ToNot(HaveOccurred())
go app.Listen("127.0.0.1:9090")
defaultConfig := openai.DefaultConfig("")
defaultConfig.BaseURL = "http://127.0.0.1:9090/v1"
client2 = openaigo.NewClient("")
client2.BaseURL = defaultConfig.BaseURL
// Wait for API to be ready
client = openai.NewClientWithConfig(defaultConfig)
Eventually(func() error {
_, err := client.ListModels(context.TODO())
return err
}, "2m").ShouldNot(HaveOccurred())
})
AfterEach(func() {
cancel()
app.Shutdown()
})
It("can generate chat completions from config file (list1)", func() {
resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "list1", Messages: []openai.ChatCompletionMessage{{Role: "user", Content: "abcdedfghikl"}}})
Expect(err).ToNot(HaveOccurred())
Expect(len(resp.Choices)).To(Equal(1))
Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty())
})
It("can generate chat completions from config file (list2)", func() {
resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "list2", Messages: []openai.ChatCompletionMessage{{Role: "user", Content: "abcdedfghikl"}}})
Expect(err).ToNot(HaveOccurred())
Expect(len(resp.Choices)).To(Equal(1))
Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty())
})
It("can generate edit completions from config file", func() {
request := openaigo.EditCreateRequestBody{
Model: "list2",
Instruction: "foo",
Input: "bar",
}
resp, err := client2.CreateEdit(context.Background(), request)
Expect(err).ToNot(HaveOccurred())
Expect(len(resp.Choices)).To(Equal(1))
Expect(resp.Choices[0].Text).ToNot(BeEmpty())
})
})
})

View file

@ -1,13 +0,0 @@
package http_test
import (
"testing"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
func TestLocalAI(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "LocalAI test suite")
}

View file

@ -1,34 +0,0 @@
package localai
import (
"github.com/go-skynet/LocalAI/core/services"
"github.com/go-skynet/LocalAI/pkg/schema"
"github.com/gofiber/fiber/v2"
)
func BackendMonitorEndpoint(bm *services.BackendMonitor) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(schema.BackendMonitorRequest)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}
resp, err := bm.CheckAndSample(input.Model)
if err != nil {
return err
}
return c.JSON(resp)
}
}
func BackendShutdownEndpoint(bm *services.BackendMonitor) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(schema.BackendMonitorRequest)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}
return bm.ShutdownModel(input.Model)
}
}

View file

@ -1,148 +0,0 @@
package localai
import (
"encoding/json"
"fmt"
"slices"
"github.com/go-skynet/LocalAI/core/services"
"github.com/go-skynet/LocalAI/pkg/gallery"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/rs/zerolog/log"
)
/// Endpoint Service
type ModelGalleryEndpointService struct {
galleries []gallery.Gallery
modelPath string
galleryApplier *services.GalleryApplier
}
type GalleryModel struct {
ID string `json:"id"`
gallery.GalleryModel
}
func CreateModelGalleryEndpointService(galleries []gallery.Gallery, modelPath string, galleryApplier *services.GalleryApplier) ModelGalleryEndpointService {
return ModelGalleryEndpointService{
galleries: galleries,
modelPath: modelPath,
galleryApplier: galleryApplier,
}
}
func (mgs *ModelGalleryEndpointService) GetOpStatusEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
status := mgs.galleryApplier.GetStatus(c.Params("uuid"))
if status == nil {
return fmt.Errorf("could not find any status for ID")
}
return c.JSON(status)
}
}
func (mgs *ModelGalleryEndpointService) GetAllStatusEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
return c.JSON(mgs.galleryApplier.GetAllStatus())
}
}
func (mgs *ModelGalleryEndpointService) ApplyModelGalleryEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(GalleryModel)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}
uuid, err := uuid.NewUUID()
if err != nil {
return err
}
mgs.galleryApplier.C <- gallery.GalleryOp{
Req: input.GalleryModel,
Id: uuid.String(),
GalleryName: input.ID,
Galleries: mgs.galleries,
}
return c.JSON(struct {
ID string `json:"uuid"`
StatusURL string `json:"status"`
}{ID: uuid.String(), StatusURL: c.BaseURL() + "/models/jobs/" + uuid.String()})
}
}
func (mgs *ModelGalleryEndpointService) ListModelFromGalleryEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
log.Debug().Msgf("Listing models from galleries: %+v", mgs.galleries)
models, err := gallery.AvailableGalleryModels(mgs.galleries, mgs.modelPath)
if err != nil {
return err
}
log.Debug().Msgf("Models found from galleries: %+v", models)
for _, m := range models {
log.Debug().Msgf("Model found from galleries: %+v", m)
}
dat, err := json.Marshal(models)
if err != nil {
return err
}
return c.Send(dat)
}
}
// NOTE: This is different (and much simpler!) than above! This JUST lists the model galleries that have been loaded, not their contents!
func (mgs *ModelGalleryEndpointService) ListModelGalleriesEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
log.Debug().Msgf("Listing model galleries %+v", mgs.galleries)
dat, err := json.Marshal(mgs.galleries)
if err != nil {
return err
}
return c.Send(dat)
}
}
func (mgs *ModelGalleryEndpointService) AddModelGalleryEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(gallery.Gallery)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}
if slices.ContainsFunc(mgs.galleries, func(gallery gallery.Gallery) bool {
return gallery.Name == input.Name
}) {
return fmt.Errorf("%s already exists", input.Name)
}
dat, err := json.Marshal(mgs.galleries)
if err != nil {
return err
}
log.Debug().Msgf("Adding %+v to gallery list", *input)
mgs.galleries = append(mgs.galleries, *input)
return c.Send(dat)
}
}
func (mgs *ModelGalleryEndpointService) RemoveModelGalleryEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(gallery.Gallery)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}
if !slices.ContainsFunc(mgs.galleries, func(gallery gallery.Gallery) bool {
return gallery.Name == input.Name
}) {
return fmt.Errorf("%s is not currently registered", input.Name)
}
mgs.galleries = slices.DeleteFunc(mgs.galleries, func(gallery gallery.Gallery) bool {
return gallery.Name == input.Name
})
return c.Send(nil)
}
}

View file

@ -1,42 +0,0 @@
package localai
import (
"time"
"github.com/go-skynet/LocalAI/pkg/schema"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/adaptor"
"github.com/prometheus/client_golang/prometheus/promhttp"
)
func MetricsHandler() fiber.Handler {
return adaptor.HTTPHandler(promhttp.Handler())
}
type apiMiddlewareConfig struct {
Filter func(c *fiber.Ctx) bool
metrics *schema.LocalAIMetrics
}
func MetricsAPIMiddleware(metrics *schema.LocalAIMetrics) fiber.Handler {
cfg := apiMiddlewareConfig{
metrics: metrics,
Filter: func(c *fiber.Ctx) bool {
return c.Path() == "/metrics"
},
}
return func(c *fiber.Ctx) error {
if cfg.Filter != nil && cfg.Filter(c) {
return c.Next()
}
path := c.Path()
method := c.Method()
start := time.Now()
err := c.Next()
elapsed := float64(time.Since(start)) / float64(time.Second)
cfg.metrics.ObserveAPICall(method, path, elapsed)
return err
}
}

View file

@ -1,25 +0,0 @@
package localai
import (
"github.com/go-skynet/LocalAI/core/backend"
"github.com/go-skynet/LocalAI/core/services"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/schema"
"github.com/gofiber/fiber/v2"
)
func TTSEndpoint(cl *services.ConfigLoader, ml *model.ModelLoader, so *schema.StartupOptions) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(schema.TTSRequest)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}
filePath, _, err := backend.ModelTTS(input.Backend, input.Input, input.Model, ml, so)
if err != nil {
return err
}
return c.Download(filePath)
}
}

View file

@ -1,97 +0,0 @@
package openai
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"github.com/go-skynet/LocalAI/core/backend"
"github.com/go-skynet/LocalAI/core/services"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/schema"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
"github.com/valyala/fasthttp"
)
func ChatEndpoint(cl *services.ConfigLoader, ml *model.ModelLoader, startupOptions *schema.StartupOptions) func(c *fiber.Ctx) error {
emptyMessage := ""
return func(c *fiber.Ctx) error {
modelName, input, err := readInput(c, startupOptions, ml, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
// The scary comment I feel like I forgot about along the way:
//
// functions are not supported in stream mode (yet?)
//
if input.Stream {
log.Debug().Msgf("Stream request received")
c.Context().SetContentType("text/event-stream")
//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
// c.Set("Content-Type", "text/event-stream")
c.Set("Cache-Control", "no-cache")
c.Set("Connection", "keep-alive")
c.Set("Transfer-Encoding", "chunked")
responses, err := backend.StreamingChatGenerationOpenAIRequest(modelName, input, cl, ml, startupOptions)
if err != nil {
return fmt.Errorf("failed establishing streaming chat request :%w", err)
}
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
usage := &schema.OpenAIUsage{}
id := ""
created := 0
for ev := range responses {
usage = &ev.Usage // Copy a pointer to the latest usage chunk so that the stop message can reference it
id = ev.ID
created = ev.Created // Similarly, grab the ID and created from any / the last response so we can use it for the stop
var buf bytes.Buffer
enc := json.NewEncoder(&buf)
enc.Encode(ev)
log.Debug().Msgf("Sending chunk: %s", buf.String())
_, err := fmt.Fprintf(w, "data: %v\n", buf.String())
if err != nil {
log.Debug().Msgf("Sending chunk failed: %v", err)
input.Cancel()
break
}
w.Flush()
}
resp := &schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{
{
FinishReason: "stop",
Index: 0,
Delta: &schema.Message{Content: &emptyMessage},
}},
Object: "chat.completion.chunk",
Usage: *usage,
}
respData, _ := json.Marshal(resp)
w.WriteString(fmt.Sprintf("data: %s\n\n", respData))
w.WriteString("data: [DONE]\n\n")
w.Flush()
}))
return nil
}
//////////////////////////////////////////
resp, err := backend.ChatGenerationOpenAIRequest(modelName, input, cl, ml, startupOptions)
if err != nil {
return fmt.Errorf("error generating chat request: +%w", err)
}
respData, _ := json.Marshal(resp) // TODO this is only used for the debug log and costs performance. monitor this?
log.Debug().Msgf("Response: %s", respData)
return c.JSON(resp)
}
}

View file

@ -1,91 +0,0 @@
package openai
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"time"
"github.com/go-skynet/LocalAI/core/backend"
"github.com/go-skynet/LocalAI/core/services"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/schema"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/rs/zerolog/log"
"github.com/valyala/fasthttp"
)
// https://platform.openai.com/docs/api-reference/completions
func CompletionEndpoint(cl *services.ConfigLoader, ml *model.ModelLoader, so *schema.StartupOptions) func(c *fiber.Ctx) error {
id := uuid.New().String()
created := int(time.Now().Unix())
return func(c *fiber.Ctx) error {
modelName, input, err := readInput(c, so, ml, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
log.Debug().Msgf("`input`: %+v", input)
if input.Stream {
log.Debug().Msgf("Stream request received")
c.Context().SetContentType("text/event-stream")
//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
//c.Set("Content-Type", "text/event-stream")
c.Set("Cache-Control", "no-cache")
c.Set("Connection", "keep-alive")
c.Set("Transfer-Encoding", "chunked")
responses, err := backend.StreamingCompletionGenerationOpenAIRequest(modelName, input, cl, ml, so)
if err != nil {
return fmt.Errorf("failed establishing streaming completion request :%w", err)
}
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
for ev := range responses {
var buf bytes.Buffer
enc := json.NewEncoder(&buf)
enc.Encode(ev)
log.Debug().Msgf("Sending chunk: %s", buf.String())
fmt.Fprintf(w, "data: %v\n", buf.String())
w.Flush()
}
resp := &schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{
{
Index: 0,
FinishReason: "stop",
},
},
Object: "text_completion",
}
respData, _ := json.Marshal(resp)
w.WriteString(fmt.Sprintf("data: %s\n\n", respData))
w.WriteString("data: [DONE]\n\n")
w.Flush()
}))
return nil
}
///////////
resp, err := backend.CompletionGenerationOpenAIRequest(modelName, input, cl, ml, so)
if err != nil {
return fmt.Errorf("error generating completion request: +%w", err)
}
jsonResult, _ := json.Marshal(resp)
log.Debug().Msgf("Response: %s", jsonResult)
// Return the prediction in the response body
return c.JSON(resp)
}
}

View file

@ -1,34 +0,0 @@
package openai
import (
"encoding/json"
"fmt"
"github.com/go-skynet/LocalAI/core/backend"
"github.com/go-skynet/LocalAI/core/services"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/schema"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
)
func EditEndpoint(cl *services.ConfigLoader, ml *model.ModelLoader, so *schema.StartupOptions) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
modelFile, input, err := readInput(c, so, ml, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
resp, err := backend.EditGenerationOpenAIRequest(modelFile, input, cl, ml, so)
if err != nil {
return err
}
jsonResult, _ := json.Marshal(resp)
log.Debug().Msgf("Response: %s", jsonResult)
// Return the prediction in the response body
return c.JSON(resp)
}
}

View file

@ -1,35 +0,0 @@
package openai
import (
"encoding/json"
"fmt"
"github.com/go-skynet/LocalAI/core/backend"
"github.com/go-skynet/LocalAI/core/services"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/schema"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
)
// https://platform.openai.com/docs/api-reference/embeddings
func EmbeddingsEndpoint(cl *services.ConfigLoader, ml *model.ModelLoader, so *schema.StartupOptions) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
modelFile, input, err := readInput(c, so, ml, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
resp, err := backend.EmbeddingOpenAIRequest(modelFile, input, cl, ml, so)
if err != nil {
return err
}
jsonResult, _ := json.Marshal(resp)
log.Debug().Msgf("Response: %s", jsonResult)
// Return the prediction in the response body
return c.JSON(resp)
}
}

View file

@ -1,48 +0,0 @@
package openai
import (
"encoding/json"
"fmt"
"github.com/go-skynet/LocalAI/core/backend"
"github.com/go-skynet/LocalAI/core/services"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/schema"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
)
// https://platform.openai.com/docs/api-reference/images/create
/*
*
curl http://localhost:8080/v1/images/generations \
-H "Content-Type: application/json" \
-d '{
"prompt": "A cute baby sea otter",
"n": 1,
"size": "512x512"
}'
*
*/
func ImageEndpoint(cl *services.ConfigLoader, ml *model.ModelLoader, so *schema.StartupOptions) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
modelName, input, err := readInput(c, so, ml, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
resp, err := backend.ImageGenerationOpenAIRequest(modelName, input, cl, ml, so)
if err != nil {
return fmt.Errorf("error generating image request: +%w", err)
}
jsonResult, _ := json.Marshal(resp)
log.Debug().Msgf("Response: %s", jsonResult)
// Return the prediction in the response body
return c.JSON(resp)
}
}

View file

@ -1,69 +0,0 @@
package openai
import (
"regexp"
"github.com/go-skynet/LocalAI/core/services"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/schema"
"github.com/gofiber/fiber/v2"
)
func ListModelsEndpoint(cl *services.ConfigLoader, ml *model.ModelLoader) func(ctx *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
models, err := ml.ListModels()
if err != nil {
return err
}
var mm map[string]interface{} = map[string]interface{}{}
openAIModels := []schema.OpenAIModel{}
var filterFn func(name string) bool
filter := c.Query("filter")
// If filter is not specified, do not filter the list by model name
if filter == "" {
filterFn = func(_ string) bool { return true }
} else {
// If filter _IS_ specified, we compile it to a regex which is used to create the filterFn
rxp, err := regexp.Compile(filter)
if err != nil {
return err
}
filterFn = func(name string) bool {
return rxp.MatchString(name)
}
}
// By default, exclude any loose files that are already referenced by a configuration file.
excludeConfigured := c.QueryBool("excludeConfigured", true)
// Start with the known configurations
for _, c := range cl.GetAllConfigs() {
if excludeConfigured {
mm[c.Model] = nil
}
if filterFn(c.Name) {
openAIModels = append(openAIModels, schema.OpenAIModel{ID: c.Name, Object: "model"})
}
}
// Then iterate through the loose files:
for _, m := range models {
// And only adds them if they shouldn't be skipped.
if _, exists := mm[m]; !exists && filterFn(m) {
openAIModels = append(openAIModels, schema.OpenAIModel{ID: m, Object: "model"})
}
}
return c.JSON(struct {
Object string `json:"object"`
Data []schema.OpenAIModel `json:"data"`
}{
Object: "list",
Data: openAIModels,
})
}
}

View file

@ -1,57 +0,0 @@
package openai
import (
"context"
"encoding/json"
"fmt"
"strings"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/schema"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
)
func readInput(c *fiber.Ctx, o *schema.StartupOptions, ml *model.ModelLoader, randomModel bool) (string, *schema.OpenAIRequest, error) {
input := new(schema.OpenAIRequest)
ctx, cancel := context.WithCancel(o.Context)
input.Context = ctx
input.Cancel = cancel
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return "", nil, fmt.Errorf("failed parsing request body: %w", err)
}
modelFile := input.Model
if c.Params("model") != "" {
modelFile = c.Params("model")
}
received, _ := json.Marshal(input)
log.Debug().Msgf("Request received: %s", string(received))
// Set model from bearer token, if available
bearer := strings.TrimLeft(c.Get("authorization"), "Bearer ")
bearerExists := bearer != "" && ml.ExistsInModelPath(bearer)
// If no model was specified, take the first available
if modelFile == "" && !bearerExists && randomModel {
models, _ := ml.ListModels()
if len(models) > 0 {
modelFile = models[0]
log.Debug().Msgf("No model specified, using: %s", modelFile)
} else {
log.Debug().Msgf("No model specified, returning error")
return "", nil, fmt.Errorf("no model specified")
}
}
// If a model is found in bearer token takes precedence
if bearerExists {
log.Debug().Msgf("Using model from bearer token: %s", bearer)
modelFile = bearer
}
return modelFile, input, nil
}

View file

@ -1,49 +0,0 @@
package openai
import (
"fmt"
"net/http"
"os"
"path"
"github.com/go-skynet/LocalAI/core/backend"
"github.com/go-skynet/LocalAI/core/services"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/schema"
"github.com/go-skynet/LocalAI/pkg/utils"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
)
// https://platform.openai.com/docs/api-reference/audio/create
func TranscriptEndpoint(cl *services.ConfigLoader, ml *model.ModelLoader, so *schema.StartupOptions) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
modelName, input, err := readInput(c, so, ml, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
// retrieve the file data from the request
file, err := c.FormFile("file")
if err != nil {
return err
}
dst, err := utils.CreateTempFileFromMultipartFile(file, "", "transcription") // 3rd param formerly whisper
if err != nil {
return err
}
log.Debug().Msgf("Audio file copied to: %+v", dst)
defer os.RemoveAll(path.Dir(dst))
tr, err := backend.TranscriptionOpenAIRequest(modelName, input, dst, cl, ml, so)
if err != nil {
return fmt.Errorf("error generating transcription request: +%w", err)
}
log.Debug().Msgf("Trascribed: %+v", tr)
// TODO: handle different outputs here
return c.Status(http.StatusOK).JSON(tr)
}
}

View file

@ -1,24 +0,0 @@
package mqtt
import (
"github.com/go-skynet/LocalAI/core/services"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/schema"
)
// PLACEHOLDER DURING PART 1 OF THE REFACTOR
type MQTTManager struct {
configLoader *services.ConfigLoader
modelLoader *model.ModelLoader
startupOptions *schema.StartupOptions
}
func NewMQTTManager(cl *services.ConfigLoader, ml *model.ModelLoader, options *schema.StartupOptions) (*MQTTManager, error) {
return &MQTTManager{
configLoader: cl,
modelLoader: ml,
startupOptions: options,
}, nil
}

View file

@ -1,138 +0,0 @@
package services
import (
"context"
"fmt"
"strings"
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/schema"
"github.com/rs/zerolog/log"
gopsutil "github.com/shirou/gopsutil/v3/process"
)
type BackendMonitor struct {
configLoader *ConfigLoader
modelLoader *model.ModelLoader
options *schema.StartupOptions // Taking options in case we need to inspect ExternalGRPCBackends, though that's out of scope for now, hence the name.
}
func NewBackendMonitor(configLoader *ConfigLoader, modelLoader *model.ModelLoader, options *schema.StartupOptions) *BackendMonitor {
return &BackendMonitor{
configLoader: configLoader,
modelLoader: modelLoader,
options: options,
}
}
func (bm *BackendMonitor) SampleLocalBackendProcess(model string) (*schema.BackendMonitorResponse, error) {
config, exists := bm.configLoader.GetConfig(model)
var backend string
if exists {
backend = config.Model
} else {
// Last ditch effort: use it raw, see if a backend happens to match.
backend = model
}
if !strings.HasSuffix(backend, ".bin") {
backend = fmt.Sprintf("%s.bin", backend)
}
pid, err := bm.modelLoader.GetGRPCPID(backend)
if err != nil {
log.Error().Msgf("model %s : failed to find pid %+v", model, err)
return nil, err
}
// Name is slightly frightening but this does _not_ create a new process, rather it looks up an existing process by PID.
backendProcess, err := gopsutil.NewProcess(int32(pid))
if err != nil {
log.Error().Msgf("model %s [PID %d] : error getting process info %+v", model, pid, err)
return nil, err
}
memInfo, err := backendProcess.MemoryInfo()
if err != nil {
log.Error().Msgf("model %s [PID %d] : error getting memory info %+v", model, pid, err)
return nil, err
}
memPercent, err := backendProcess.MemoryPercent()
if err != nil {
log.Error().Msgf("model %s [PID %d] : error getting memory percent %+v", model, pid, err)
return nil, err
}
cpuPercent, err := backendProcess.CPUPercent()
if err != nil {
log.Error().Msgf("model %s [PID %d] : error getting cpu percent %+v", model, pid, err)
return nil, err
}
return &schema.BackendMonitorResponse{
MemoryInfo: memInfo,
MemoryPercent: memPercent,
CPUPercent: cpuPercent,
}, nil
}
func (bm BackendMonitor) getModelLoaderIDFromModelName(modelName string) (string, error) {
config, exists := bm.configLoader.GetConfig(modelName)
var backendId string
if exists {
backendId = config.Model
} else {
// Last ditch effort: use it raw, see if a backend happens to match.
backendId = modelName
}
if !strings.HasSuffix(backendId, ".bin") {
backendId = fmt.Sprintf("%s.bin", backendId)
}
return backendId, nil
}
func (bm BackendMonitor) CheckAndSample(modelName string) (*proto.StatusResponse, error) {
backendId, err := bm.getModelLoaderIDFromModelName(modelName)
if err != nil {
return nil, err
}
modelAddr := bm.modelLoader.CheckIsLoaded(backendId)
if modelAddr == "" {
return nil, fmt.Errorf("backend %s is not currently loaded", backendId)
}
status, rpcErr := modelAddr.GRPC(false, nil).Status(context.TODO())
if rpcErr != nil {
log.Warn().Msgf("backend %s experienced an error retrieving status info: %s", backendId, rpcErr.Error())
val, slbErr := bm.SampleLocalBackendProcess(backendId)
if slbErr != nil {
return nil, fmt.Errorf("backend %s experienced an error retrieving status info via rpc: %s, then failed local node process sample: %s", backendId, rpcErr.Error(), slbErr.Error())
}
return &proto.StatusResponse{
State: proto.StatusResponse_ERROR,
Memory: &proto.MemoryUsageData{
Total: val.MemoryInfo.VMS,
Breakdown: map[string]uint64{
"gopsutil-RSS": val.MemoryInfo.RSS,
},
},
}, nil
}
return status, nil
}
func (bm BackendMonitor) ShutdownModel(modelName string) error {
backendId, err := bm.getModelLoaderIDFromModelName(modelName)
if err != nil {
return err
}
return bm.modelLoader.ShutdownModel(backendId)
}

View file

@ -1,157 +0,0 @@
package services
import (
"errors"
"fmt"
"io/fs"
"os"
"path/filepath"
"strings"
"sync"
"github.com/go-skynet/LocalAI/pkg/schema"
"github.com/go-skynet/LocalAI/pkg/utils"
"github.com/rs/zerolog/log"
)
type ConfigLoader struct {
configs map[string]schema.Config
sync.Mutex
}
func NewConfigLoader() *ConfigLoader {
return &ConfigLoader{
configs: make(map[string]schema.Config),
}
}
// TODO: check this is correct post-merge
func (cm *ConfigLoader) LoadConfig(file string) error {
cm.Lock()
defer cm.Unlock()
c, err := schema.ReadSingleConfigFile(file)
if err != nil {
return fmt.Errorf("cannot read config file: %w", err)
}
cm.configs[c.Name] = *c
return nil
}
func (cm *ConfigLoader) GetConfig(m string) (schema.Config, bool) {
cm.Lock()
defer cm.Unlock()
v, exists := cm.configs[m]
return v, exists
}
func (cm *ConfigLoader) GetAllConfigs() []schema.Config {
cm.Lock()
defer cm.Unlock()
var res []schema.Config
for _, v := range cm.configs {
res = append(res, v)
}
return res
}
func (cm *ConfigLoader) ListConfigs() []string {
cm.Lock()
defer cm.Unlock()
var res []string
for k := range cm.configs {
res = append(res, k)
}
return res
}
func (cm *ConfigLoader) LoadConfigs(path string) error {
cm.Lock()
defer cm.Unlock()
entries, err := os.ReadDir(path)
if err != nil {
return err
}
files := make([]fs.FileInfo, 0, len(entries))
for _, entry := range entries {
info, err := entry.Info()
if err != nil {
return err
}
files = append(files, info)
}
for _, file := range files {
// Skip templates, YAML and .keep files
if !strings.Contains(file.Name(), ".yaml") && !strings.Contains(file.Name(), ".yml") {
continue
}
c, err := schema.ReadSingleConfigFile(filepath.Join(path, file.Name()))
if err == nil {
cm.configs[c.Name] = *c
}
}
return nil
}
// Preload prepare models if they are not local but url or huggingface repositories
func (cm *ConfigLoader) Preload(modelPath string) error {
cm.Lock()
defer cm.Unlock()
status := func(fileName, current, total string, percent float64) {
utils.DisplayDownloadFunction(fileName, current, total, percent)
}
log.Info().Msgf("Preloading models from %s", modelPath)
for _, config := range cm.configs {
// Download files and verify their SHA
for _, file := range config.DownloadFiles {
log.Debug().Msgf("Checking %q exists and matches SHA", file.Filename)
if err := utils.VerifyPath(file.Filename, modelPath); err != nil {
return err
}
// Create file path
filePath := filepath.Join(modelPath, file.Filename)
if err := utils.DownloadFile(file.URI, filePath, file.SHA256, status); err != nil {
return err
}
}
modelURL := config.PredictionOptions.Model
modelURL = utils.ConvertURL(modelURL)
if utils.LooksLikeURL(modelURL) {
// md5 of model name
md5Name := utils.MD5(modelURL)
// check if file exists
if _, err := os.Stat(filepath.Join(modelPath, md5Name)); errors.Is(err, os.ErrNotExist) {
err := utils.DownloadFile(modelURL, filepath.Join(modelPath, md5Name), "", status)
if err != nil {
return err
}
}
}
}
return nil
}
func (cl *ConfigLoader) LoadConfigFile(file string) error {
cl.Lock()
defer cl.Unlock()
c, err := schema.ReadConfigFile(file)
if err != nil {
return fmt.Errorf("cannot load config file: %w", err)
}
for _, cc := range c {
cl.configs[cc.Name] = *cc
}
return nil
}

View file

@ -1,160 +0,0 @@
package services
import (
"context"
"encoding/json"
"os"
"strings"
"sync"
"github.com/go-skynet/LocalAI/pkg/gallery"
"github.com/go-skynet/LocalAI/pkg/utils"
"gopkg.in/yaml.v2"
)
type GalleryApplier struct {
modelPath string
sync.Mutex
C chan gallery.GalleryOp
statuses map[string]*gallery.GalleryOpStatus
}
func NewGalleryApplier(modelPath string) *GalleryApplier {
return &GalleryApplier{
modelPath: modelPath,
C: make(chan gallery.GalleryOp),
statuses: make(map[string]*gallery.GalleryOpStatus),
}
}
func (g *GalleryApplier) UpdateStatus(s string, op *gallery.GalleryOpStatus) {
g.Lock()
defer g.Unlock()
g.statuses[s] = op
}
func (g *GalleryApplier) GetStatus(s string) *gallery.GalleryOpStatus {
g.Lock()
defer g.Unlock()
return g.statuses[s]
}
func (g *GalleryApplier) GetAllStatus() map[string]*gallery.GalleryOpStatus {
g.Lock()
defer g.Unlock()
return g.statuses
}
func (g *GalleryApplier) Start(c context.Context, cm *ConfigLoader) {
go func() {
for {
select {
case <-c.Done():
return
case op := <-g.C:
utils.ResetDownloadTimers()
g.UpdateStatus(op.Id, &gallery.GalleryOpStatus{Message: "processing", Progress: 0})
// updates the status with an error
updateError := func(e error) {
g.UpdateStatus(op.Id, &gallery.GalleryOpStatus{Error: e, Processed: true, Message: "error: " + e.Error()})
}
// displayDownload displays the download progress
progressCallback := func(fileName string, current string, total string, percentage float64) {
g.UpdateStatus(op.Id, &gallery.GalleryOpStatus{Message: "processing", FileName: fileName, Progress: percentage, TotalFileSize: total, DownloadedFileSize: current})
utils.DisplayDownloadFunction(fileName, current, total, percentage)
}
var err error
// if the request contains a gallery name, we apply the gallery from the gallery list
if op.GalleryName != "" {
if strings.Contains(op.GalleryName, "@") {
err = gallery.InstallModelFromGallery(op.Galleries, op.GalleryName, g.modelPath, op.Req, progressCallback)
} else {
err = gallery.InstallModelFromGalleryByName(op.Galleries, op.GalleryName, g.modelPath, op.Req, progressCallback)
}
} else {
err = PrepareModel(g.modelPath, op.Req, cm, progressCallback)
}
if err != nil {
updateError(err)
continue
}
// Reload models
err = cm.LoadConfigs(g.modelPath)
if err != nil {
updateError(err)
continue
}
g.UpdateStatus(op.Id, &gallery.GalleryOpStatus{Processed: true, Message: "completed", Progress: 100})
}
}
}()
}
type galleryModel struct {
gallery.GalleryModel `yaml:",inline"` // https://github.com/go-yaml/yaml/issues/63
ID string `json:"id"`
}
func PrepareModel(modelPath string, req gallery.GalleryModel, cm *ConfigLoader, downloadStatus func(string, string, string, float64)) error {
config, err := gallery.GetInstallableModelFromURL(req.URL)
if err != nil {
return err
}
config.Files = append(config.Files, req.AdditionalFiles...)
return gallery.InstallModel(modelPath, req.Name, &config, req.Overrides, downloadStatus)
}
func processRequests(modelPath, s string, cm *ConfigLoader, galleries []gallery.Gallery, requests []galleryModel) error {
var err error
for _, r := range requests {
utils.ResetDownloadTimers()
if r.ID == "" {
err = PrepareModel(modelPath, r.GalleryModel, cm, utils.DisplayDownloadFunction)
} else {
if strings.Contains(r.ID, "@") {
err = gallery.InstallModelFromGallery(
galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction)
} else {
err = gallery.InstallModelFromGalleryByName(
galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction)
}
}
}
return err
}
func ApplyGalleryFromFile(modelPath, s string, cm *ConfigLoader, galleries []gallery.Gallery) error {
dat, err := os.ReadFile(s)
if err != nil {
return err
}
var requests []galleryModel
if err := yaml.Unmarshal(dat, &requests); err != nil {
return err
}
return processRequests(modelPath, s, cm, galleries, requests)
}
func ApplyGalleryFromString(modelPath, s string, cm *ConfigLoader, galleries []gallery.Gallery) error {
var requests []galleryModel
err := json.Unmarshal([]byte(s), &requests)
if err != nil {
return err
}
return processRequests(modelPath, s, cm, galleries, requests)
}

View file

@ -1,29 +0,0 @@
package services
import (
"github.com/go-skynet/LocalAI/pkg/schema"
"go.opentelemetry.io/otel/exporters/prometheus"
api "go.opentelemetry.io/otel/metric"
"go.opentelemetry.io/otel/sdk/metric"
)
// setupOTelSDK bootstraps the OpenTelemetry pipeline.
// If it does not return an error, make sure to call shutdown for proper cleanup.
func SetupMetrics() (*schema.LocalAIMetrics, error) {
exporter, err := prometheus.New()
if err != nil {
return nil, err
}
provider := metric.NewMeterProvider(metric.WithReader(exporter))
meter := provider.Meter("github.com/go-skynet/LocalAI")
apiTimeMetric, err := meter.Float64Histogram("api_call", api.WithDescription("api calls"))
if err != nil {
return nil, err
}
return &schema.LocalAIMetrics{
Meter: meter,
ApiTimeMetric: apiTimeMetric,
}, nil
}

View file

@ -1,100 +0,0 @@
package startup
import (
"encoding/json"
"fmt"
"os"
"path"
"github.com/fsnotify/fsnotify"
"github.com/go-skynet/LocalAI/pkg/schema"
"github.com/imdario/mergo"
"github.com/rs/zerolog/log"
)
type WatchConfigDirectoryCloser func() error
func ReadApiKeysJson(configDir string, options *schema.StartupOptions) error {
fileContent, err := os.ReadFile(path.Join(configDir, "api_keys.json"))
if err == nil {
// Parse JSON content from the file
var fileKeys []string
err := json.Unmarshal(fileContent, &fileKeys)
if err == nil {
options.ApiKeys = append(options.ApiKeys, fileKeys...)
return nil
}
return err
}
return err
}
func ReadExternalBackendsJson(configDir string, options *schema.StartupOptions) error {
fileContent, err := os.ReadFile(path.Join(configDir, "external_backends.json"))
if err != nil {
return err
}
// Parse JSON content from the file
var fileBackends map[string]string
err = json.Unmarshal(fileContent, &fileBackends)
if err != nil {
return err
}
err = mergo.Merge(&options.ExternalGRPCBackends, fileBackends)
if err != nil {
return err
}
return nil
}
var CONFIG_FILE_UPDATES = map[string]func(configDir string, options *schema.StartupOptions) error{
"api_keys.json": ReadApiKeysJson,
"external_backends.json": ReadExternalBackendsJson,
}
func WatchConfigDirectory(configDir string, options *schema.StartupOptions) (WatchConfigDirectoryCloser, error) {
if len(configDir) == 0 {
return nil, fmt.Errorf("configDir blank")
}
configWatcher, err := fsnotify.NewWatcher()
if err != nil {
log.Fatal().Msgf("Unable to create a watcher for the LocalAI Configuration Directory: %+v", err)
}
ret := func() error {
configWatcher.Close()
return nil
}
// Start listening for events.
go func() {
for {
select {
case event, ok := <-configWatcher.Events:
if !ok {
return
}
if event.Has(fsnotify.Write) {
for targetName, watchFn := range CONFIG_FILE_UPDATES {
if event.Name == targetName {
err := watchFn(configDir, options)
log.Warn().Msgf("WatchConfigDirectory goroutine for %s: failed to update options: %+v", targetName, err)
}
}
}
case _, ok := <-configWatcher.Errors:
if !ok {
return
}
log.Error().Msgf("WatchConfigDirectory goroutine error: %+v", err)
}
}
}()
// Add a path.
err = configWatcher.Add(configDir)
if err != nil {
return ret, fmt.Errorf("unable to establish watch on the LocalAI Configuration Directory: %+v", err)
}
return ret, nil
}

View file

@ -1,93 +0,0 @@
package startup
import (
"github.com/go-skynet/LocalAI/core/services"
"github.com/go-skynet/LocalAI/internal"
"github.com/go-skynet/LocalAI/pkg/assets"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/schema"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
)
func Startup(opts ...schema.AppOption) (*services.ConfigLoader, *model.ModelLoader, *schema.StartupOptions, error) {
options := schema.NewStartupOptions(opts...)
ml := model.NewModelLoader(options.ModelPath)
zerolog.SetGlobalLevel(zerolog.InfoLevel)
if options.Debug {
zerolog.SetGlobalLevel(zerolog.DebugLevel)
}
log.Info().Msgf("Starting LocalAI using %d threads, with models path: %s", options.Threads, options.ModelPath)
log.Info().Msgf("LocalAI version: %s", internal.PrintableVersion())
cl := services.NewConfigLoader()
if err := cl.LoadConfigs(options.ModelPath); err != nil {
log.Error().Msgf("error loading config files: %s", err.Error())
}
if options.ConfigFile != "" {
if err := cl.LoadConfigFile(options.ConfigFile); err != nil {
log.Error().Msgf("error loading config file: %s", err.Error())
}
}
if err := cl.Preload(options.ModelPath); err != nil {
log.Error().Msgf("error downloading models: %s", err.Error())
}
if options.PreloadJSONModels != "" {
if err := services.ApplyGalleryFromString(options.ModelPath, options.PreloadJSONModels, cl, options.Galleries); err != nil {
return nil, nil, nil, err
}
}
if options.PreloadModelsFromPath != "" {
if err := services.ApplyGalleryFromFile(options.ModelPath, options.PreloadModelsFromPath, cl, options.Galleries); err != nil {
return nil, nil, nil, err
}
}
if options.Debug {
for _, v := range cl.ListConfigs() {
cfg, _ := cl.GetConfig(v)
log.Debug().Msgf("Model: %s (config: %+v)", v, cfg)
}
}
if options.AssetsDestination != "" {
// Extract files from the embedded FS
err := assets.ExtractFiles(options.BackendAssets, options.AssetsDestination)
log.Debug().Msgf("Extracting backend assets files to %s", options.AssetsDestination)
if err != nil {
log.Warn().Msgf("Failed extracting backend assets files: %s (might be required for some backends to work properly, like gpt4all)", err)
}
}
// turn off any process that was started by GRPC if the context is canceled
go func() {
<-options.Context.Done()
log.Debug().Msgf("Context canceled, shutting down")
ml.StopAllGRPC()
}()
if options.WatchDog {
wd := model.NewWatchDog(
ml,
options.WatchDogBusyTimeout,
options.WatchDogIdleTimeout,
options.WatchDogBusy,
options.WatchDogIdle)
ml.SetWatchDog(wd)
go wd.Run()
go func() {
<-options.Context.Done()
log.Debug().Msgf("Context canceled, shutting down")
wd.Shutdown()
}()
}
return cl, ml, options, nil
}