MQTT Startup Refactoring Part 1: core/ packages part 1 (#1728)

This PR specifically introduces a `core` folder and moves the following packages over, without any other changes:

- `api/backend`
- `api/config`
- `api/options`
- `api/schema`

Once this is merged and we confirm there's no regressions, I can migrate over the remaining changes piece by piece to split up application startup, backend services, http, and mqtt as was the goal of the earlier PRs!
This commit is contained in:
Dave 2024-02-20 20:21:19 -05:00 committed by GitHub
parent 594eb468df
commit 255748bcba
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
38 changed files with 93 additions and 90 deletions

View file

@ -0,0 +1,92 @@
package backend
import (
"fmt"
config "github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/options"
"github.com/go-skynet/LocalAI/pkg/grpc"
model "github.com/go-skynet/LocalAI/pkg/model"
)
func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.Config, o *options.Option) (func() ([]float32, error), error) {
if !c.Embeddings {
return nil, fmt.Errorf("endpoint disabled for this model by API configuration")
}
modelFile := c.Model
grpcOpts := gRPCModelOpts(c)
var inferenceModel interface{}
var err error
opts := modelOpts(c, o, []model.Option{
model.WithLoadGRPCLoadModelOpts(grpcOpts),
model.WithThreads(uint32(c.Threads)),
model.WithAssetDir(o.AssetsDestination),
model.WithModel(modelFile),
model.WithContext(o.Context),
})
if c.Backend == "" {
inferenceModel, err = loader.GreedyLoader(opts...)
} else {
opts = append(opts, model.WithBackendString(c.Backend))
inferenceModel, err = loader.BackendLoader(opts...)
}
if err != nil {
return nil, err
}
var fn func() ([]float32, error)
switch model := inferenceModel.(type) {
case grpc.Backend:
fn = func() ([]float32, error) {
predictOptions := gRPCPredictOpts(c, loader.ModelPath)
if len(tokens) > 0 {
embeds := []int32{}
for _, t := range tokens {
embeds = append(embeds, int32(t))
}
predictOptions.EmbeddingTokens = embeds
res, err := model.Embeddings(o.Context, predictOptions)
if err != nil {
return nil, err
}
return res.Embeddings, nil
}
predictOptions.Embeddings = s
res, err := model.Embeddings(o.Context, predictOptions)
if err != nil {
return nil, err
}
return res.Embeddings, nil
}
default:
fn = func() ([]float32, error) {
return nil, fmt.Errorf("embeddings not supported by the backend")
}
}
return func() ([]float32, error) {
embeds, err := fn()
if err != nil {
return embeds, err
}
// Remove trailing 0s
for i := len(embeds) - 1; i >= 0; i-- {
if embeds[i] == 0.0 {
embeds = embeds[:i]
} else {
break
}
}
return embeds, nil
}, nil
}

61
core/backend/image.go Normal file
View file

@ -0,0 +1,61 @@
package backend
import (
config "github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/options"
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
model "github.com/go-skynet/LocalAI/pkg/model"
)
func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, c config.Config, o *options.Option) (func() error, error) {
opts := modelOpts(c, o, []model.Option{
model.WithBackendString(c.Backend),
model.WithAssetDir(o.AssetsDestination),
model.WithThreads(uint32(c.Threads)),
model.WithContext(o.Context),
model.WithModel(c.Model),
model.WithLoadGRPCLoadModelOpts(&proto.ModelOptions{
CUDA: c.CUDA || c.Diffusers.CUDA,
SchedulerType: c.Diffusers.SchedulerType,
PipelineType: c.Diffusers.PipelineType,
CFGScale: c.Diffusers.CFGScale,
LoraAdapter: c.LoraAdapter,
LoraScale: c.LoraScale,
LoraBase: c.LoraBase,
IMG2IMG: c.Diffusers.IMG2IMG,
CLIPModel: c.Diffusers.ClipModel,
CLIPSubfolder: c.Diffusers.ClipSubFolder,
CLIPSkip: int32(c.Diffusers.ClipSkip),
ControlNet: c.Diffusers.ControlNet,
}),
})
inferenceModel, err := loader.BackendLoader(
opts...,
)
if err != nil {
return nil, err
}
fn := func() error {
_, err := inferenceModel.GenerateImage(
o.Context,
&proto.GenerateImageRequest{
Height: int32(height),
Width: int32(width),
Mode: int32(mode),
Step: int32(step),
Seed: int32(seed),
CLIPSkip: int32(c.Diffusers.ClipSkip),
PositivePrompt: positive_prompt,
NegativePrompt: negative_prompt,
Dst: dst,
Src: src,
EnableParameters: c.Diffusers.EnableParameters,
})
return err
}
return fn, nil
}

167
core/backend/llm.go Normal file
View file

@ -0,0 +1,167 @@
package backend
import (
"context"
"os"
"regexp"
"strings"
"sync"
"unicode/utf8"
config "github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/options"
"github.com/go-skynet/LocalAI/pkg/gallery"
"github.com/go-skynet/LocalAI/pkg/grpc"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/utils"
)
type LLMResponse struct {
Response string // should this be []byte?
Usage TokenUsage
}
type TokenUsage struct {
Prompt int
Completion int
}
func ModelInference(ctx context.Context, s string, images []string, loader *model.ModelLoader, c config.Config, o *options.Option, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) {
modelFile := c.Model
grpcOpts := gRPCModelOpts(c)
var inferenceModel grpc.Backend
var err error
opts := modelOpts(c, o, []model.Option{
model.WithLoadGRPCLoadModelOpts(grpcOpts),
model.WithThreads(uint32(c.Threads)), // some models uses this to allocate threads during startup
model.WithAssetDir(o.AssetsDestination),
model.WithModel(modelFile),
model.WithContext(o.Context),
})
if c.Backend != "" {
opts = append(opts, model.WithBackendString(c.Backend))
}
// Check if the modelFile exists, if it doesn't try to load it from the gallery
if o.AutoloadGalleries { // experimental
if _, err := os.Stat(modelFile); os.IsNotExist(err) {
utils.ResetDownloadTimers()
// if we failed to load the model, we try to download it
err := gallery.InstallModelFromGalleryByName(o.Galleries, modelFile, loader.ModelPath, gallery.GalleryModel{}, utils.DisplayDownloadFunction)
if err != nil {
return nil, err
}
}
}
if c.Backend == "" {
inferenceModel, err = loader.GreedyLoader(opts...)
} else {
inferenceModel, err = loader.BackendLoader(opts...)
}
if err != nil {
return nil, err
}
// in GRPC, the backend is supposed to answer to 1 single token if stream is not supported
fn := func() (LLMResponse, error) {
opts := gRPCPredictOpts(c, loader.ModelPath)
opts.Prompt = s
opts.Images = images
tokenUsage := TokenUsage{}
// check the per-model feature flag for usage, since tokenCallback may have a cost.
// Defaults to off as for now it is still experimental
if c.FeatureFlag.Enabled("usage") {
userTokenCallback := tokenCallback
if userTokenCallback == nil {
userTokenCallback = func(token string, usage TokenUsage) bool {
return true
}
}
promptInfo, pErr := inferenceModel.TokenizeString(ctx, opts)
if pErr == nil && promptInfo.Length > 0 {
tokenUsage.Prompt = int(promptInfo.Length)
}
tokenCallback = func(token string, usage TokenUsage) bool {
tokenUsage.Completion++
return userTokenCallback(token, tokenUsage)
}
}
if tokenCallback != nil {
ss := ""
var partialRune []byte
err := inferenceModel.PredictStream(ctx, opts, func(chars []byte) {
partialRune = append(partialRune, chars...)
for len(partialRune) > 0 {
r, size := utf8.DecodeRune(partialRune)
if r == utf8.RuneError {
// incomplete rune, wait for more bytes
break
}
tokenCallback(string(r), tokenUsage)
ss += string(r)
partialRune = partialRune[size:]
}
})
return LLMResponse{
Response: ss,
Usage: tokenUsage,
}, err
} else {
// TODO: Is the chicken bit the only way to get here? is that acceptable?
reply, err := inferenceModel.Predict(ctx, opts)
if err != nil {
return LLMResponse{}, err
}
return LLMResponse{
Response: string(reply.Message),
Usage: tokenUsage,
}, err
}
}
return fn, nil
}
var cutstrings map[string]*regexp.Regexp = make(map[string]*regexp.Regexp)
var mu sync.Mutex = sync.Mutex{}
func Finetune(config config.Config, input, prediction string) string {
if config.Echo {
prediction = input + prediction
}
for _, c := range config.Cutstrings {
mu.Lock()
reg, ok := cutstrings[c]
if !ok {
cutstrings[c] = regexp.MustCompile(c)
reg = cutstrings[c]
}
mu.Unlock()
prediction = reg.ReplaceAllString(prediction, "")
}
for _, c := range config.TrimSpace {
prediction = strings.TrimSpace(strings.TrimPrefix(prediction, c))
}
for _, c := range config.TrimSuffix {
prediction = strings.TrimSpace(strings.TrimSuffix(prediction, c))
}
return prediction
}

129
core/backend/options.go Normal file
View file

@ -0,0 +1,129 @@
package backend
import (
"os"
"path/filepath"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
model "github.com/go-skynet/LocalAI/pkg/model"
config "github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/options"
)
func modelOpts(c config.Config, o *options.Option, opts []model.Option) []model.Option {
if o.SingleBackend {
opts = append(opts, model.WithSingleActiveBackend())
}
if o.ParallelBackendRequests {
opts = append(opts, model.EnableParallelRequests)
}
if c.GRPC.Attempts != 0 {
opts = append(opts, model.WithGRPCAttempts(c.GRPC.Attempts))
}
if c.GRPC.AttemptsSleepTime != 0 {
opts = append(opts, model.WithGRPCAttemptsDelay(c.GRPC.AttemptsSleepTime))
}
for k, v := range o.ExternalGRPCBackends {
opts = append(opts, model.WithExternalBackend(k, v))
}
return opts
}
func gRPCModelOpts(c config.Config) *pb.ModelOptions {
b := 512
if c.Batch != 0 {
b = c.Batch
}
return &pb.ModelOptions{
ContextSize: int32(c.ContextSize),
Seed: int32(c.Seed),
NBatch: int32(b),
NoMulMatQ: c.NoMulMatQ,
CUDA: c.CUDA, // diffusers, transformers
DraftModel: c.DraftModel,
AudioPath: c.VallE.AudioPath,
Quantization: c.Quantization,
MMProj: c.MMProj,
YarnExtFactor: c.YarnExtFactor,
YarnAttnFactor: c.YarnAttnFactor,
YarnBetaFast: c.YarnBetaFast,
YarnBetaSlow: c.YarnBetaSlow,
LoraAdapter: c.LoraAdapter,
LoraBase: c.LoraBase,
LoraScale: c.LoraScale,
NGQA: c.NGQA,
RMSNormEps: c.RMSNormEps,
F16Memory: c.F16,
MLock: c.MMlock,
RopeFreqBase: c.RopeFreqBase,
RopeScaling: c.RopeScaling,
Type: c.ModelType,
RopeFreqScale: c.RopeFreqScale,
NUMA: c.NUMA,
Embeddings: c.Embeddings,
LowVRAM: c.LowVRAM,
NGPULayers: int32(c.NGPULayers),
MMap: c.MMap,
MainGPU: c.MainGPU,
Threads: int32(c.Threads),
TensorSplit: c.TensorSplit,
// AutoGPTQ
ModelBaseName: c.AutoGPTQ.ModelBaseName,
Device: c.AutoGPTQ.Device,
UseTriton: c.AutoGPTQ.Triton,
UseFastTokenizer: c.AutoGPTQ.UseFastTokenizer,
// RWKV
Tokenizer: c.Tokenizer,
}
}
func gRPCPredictOpts(c config.Config, modelPath string) *pb.PredictOptions {
promptCachePath := ""
if c.PromptCachePath != "" {
p := filepath.Join(modelPath, c.PromptCachePath)
os.MkdirAll(filepath.Dir(p), 0755)
promptCachePath = p
}
return &pb.PredictOptions{
Temperature: float32(c.Temperature),
TopP: float32(c.TopP),
NDraft: c.NDraft,
TopK: int32(c.TopK),
Tokens: int32(c.Maxtokens),
Threads: int32(c.Threads),
PromptCacheAll: c.PromptCacheAll,
PromptCacheRO: c.PromptCacheRO,
PromptCachePath: promptCachePath,
F16KV: c.F16,
DebugMode: c.Debug,
Grammar: c.Grammar,
NegativePromptScale: c.NegativePromptScale,
RopeFreqBase: c.RopeFreqBase,
RopeFreqScale: c.RopeFreqScale,
NegativePrompt: c.NegativePrompt,
Mirostat: int32(c.LLMConfig.Mirostat),
MirostatETA: float32(c.LLMConfig.MirostatETA),
MirostatTAU: float32(c.LLMConfig.MirostatTAU),
Debug: c.Debug,
StopPrompts: c.StopWords,
Repeat: int32(c.RepeatPenalty),
NKeep: int32(c.Keep),
Batch: int32(c.Batch),
IgnoreEOS: c.IgnoreEOS,
Seed: int32(c.Seed),
FrequencyPenalty: float32(c.FrequencyPenalty),
MLock: c.MMlock,
MMap: c.MMap,
MainGPU: c.MainGPU,
TensorSplit: c.TensorSplit,
TailFreeSamplingZ: float32(c.TFZ),
TypicalP: float32(c.TypicalP),
}
}

View file

@ -0,0 +1,39 @@
package backend
import (
"context"
"fmt"
config "github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/go-skynet/LocalAI/core/options"
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
model "github.com/go-skynet/LocalAI/pkg/model"
)
func ModelTranscription(audio, language string, loader *model.ModelLoader, c config.Config, o *options.Option) (*schema.Result, error) {
opts := modelOpts(c, o, []model.Option{
model.WithBackendString(model.WhisperBackend),
model.WithModel(c.Model),
model.WithContext(o.Context),
model.WithThreads(uint32(c.Threads)),
model.WithAssetDir(o.AssetsDestination),
})
whisperModel, err := o.Loader.BackendLoader(opts...)
if err != nil {
return nil, err
}
if whisperModel == nil {
return nil, fmt.Errorf("could not load whisper model")
}
return whisperModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{
Dst: audio,
Language: language,
Threads: uint32(c.Threads),
})
}

83
core/backend/tts.go Normal file
View file

@ -0,0 +1,83 @@
package backend
import (
"context"
"fmt"
"os"
"path/filepath"
config "github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/options"
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/utils"
)
func generateUniqueFileName(dir, baseName, ext string) string {
counter := 1
fileName := baseName + ext
for {
filePath := filepath.Join(dir, fileName)
_, err := os.Stat(filePath)
if os.IsNotExist(err) {
return fileName
}
counter++
fileName = fmt.Sprintf("%s_%d%s", baseName, counter, ext)
}
}
func ModelTTS(backend, text, modelFile string, loader *model.ModelLoader, o *options.Option, c config.Config) (string, *proto.Result, error) {
bb := backend
if bb == "" {
bb = model.PiperBackend
}
grpcOpts := gRPCModelOpts(c)
opts := modelOpts(config.Config{}, o, []model.Option{
model.WithBackendString(bb),
model.WithModel(modelFile),
model.WithContext(o.Context),
model.WithAssetDir(o.AssetsDestination),
model.WithLoadGRPCLoadModelOpts(grpcOpts),
})
piperModel, err := o.Loader.BackendLoader(opts...)
if err != nil {
return "", nil, err
}
if piperModel == nil {
return "", nil, fmt.Errorf("could not load piper model")
}
if err := os.MkdirAll(o.AudioDir, 0755); err != nil {
return "", nil, fmt.Errorf("failed creating audio directory: %s", err)
}
fileName := generateUniqueFileName(o.AudioDir, "piper", ".wav")
filePath := filepath.Join(o.AudioDir, fileName)
// If the model file is not empty, we pass it joined with the model path
modelPath := ""
if modelFile != "" {
if bb != model.TransformersMusicGen {
modelPath = filepath.Join(o.Loader.ModelPath, modelFile)
if err := utils.VerifyPath(modelPath, o.Loader.ModelPath); err != nil {
return "", nil, err
}
} else {
modelPath = modelFile
}
}
res, err := piperModel.TTS(context.Background(), &proto.TTSRequest{
Text: text,
Model: modelPath,
Dst: filePath,
})
return filePath, res, err
}

429
core/config/config.go Normal file
View file

@ -0,0 +1,429 @@
package config
import (
"errors"
"fmt"
"io/fs"
"os"
"path/filepath"
"strings"
"sync"
"github.com/go-skynet/LocalAI/pkg/downloader"
"github.com/go-skynet/LocalAI/pkg/utils"
"github.com/rs/zerolog/log"
"gopkg.in/yaml.v3"
)
type Config struct {
PredictionOptions `yaml:"parameters"`
Name string `yaml:"name"`
F16 bool `yaml:"f16"`
Threads int `yaml:"threads"`
Debug bool `yaml:"debug"`
Roles map[string]string `yaml:"roles"`
Embeddings bool `yaml:"embeddings"`
Backend string `yaml:"backend"`
TemplateConfig TemplateConfig `yaml:"template"`
PromptStrings, InputStrings []string `yaml:"-"`
InputToken [][]int `yaml:"-"`
functionCallString, functionCallNameString string `yaml:"-"`
FunctionsConfig Functions `yaml:"function"`
FeatureFlag FeatureFlag `yaml:"feature_flags"` // Feature Flag registry. We move fast, and features may break on a per model/backend basis. Registry for (usually temporary) flags that indicate aborting something early.
// LLM configs (GPT4ALL, Llama.cpp, ...)
LLMConfig `yaml:",inline"`
// AutoGPTQ specifics
AutoGPTQ AutoGPTQ `yaml:"autogptq"`
// Diffusers
Diffusers Diffusers `yaml:"diffusers"`
Step int `yaml:"step"`
// GRPC Options
GRPC GRPC `yaml:"grpc"`
// Vall-e-x
VallE VallE `yaml:"vall-e"`
// CUDA
// Explicitly enable CUDA or not (some backends might need it)
CUDA bool `yaml:"cuda"`
DownloadFiles []File `yaml:"download_files"`
Description string `yaml:"description"`
Usage string `yaml:"usage"`
}
type File struct {
Filename string `yaml:"filename" json:"filename"`
SHA256 string `yaml:"sha256" json:"sha256"`
URI string `yaml:"uri" json:"uri"`
}
type VallE struct {
AudioPath string `yaml:"audio_path"`
}
type FeatureFlag map[string]*bool
func (ff FeatureFlag) Enabled(s string) bool {
v, exist := ff[s]
return exist && v != nil && *v
}
type GRPC struct {
Attempts int `yaml:"attempts"`
AttemptsSleepTime int `yaml:"attempts_sleep_time"`
}
type Diffusers struct {
CUDA bool `yaml:"cuda"`
PipelineType string `yaml:"pipeline_type"`
SchedulerType string `yaml:"scheduler_type"`
EnableParameters string `yaml:"enable_parameters"` // A list of comma separated parameters to specify
CFGScale float32 `yaml:"cfg_scale"` // Classifier-Free Guidance Scale
IMG2IMG bool `yaml:"img2img"` // Image to Image Diffuser
ClipSkip int `yaml:"clip_skip"` // Skip every N frames
ClipModel string `yaml:"clip_model"` // Clip model to use
ClipSubFolder string `yaml:"clip_subfolder"` // Subfolder to use for clip model
ControlNet string `yaml:"control_net"`
}
type LLMConfig struct {
SystemPrompt string `yaml:"system_prompt"`
TensorSplit string `yaml:"tensor_split"`
MainGPU string `yaml:"main_gpu"`
RMSNormEps float32 `yaml:"rms_norm_eps"`
NGQA int32 `yaml:"ngqa"`
PromptCachePath string `yaml:"prompt_cache_path"`
PromptCacheAll bool `yaml:"prompt_cache_all"`
PromptCacheRO bool `yaml:"prompt_cache_ro"`
MirostatETA float64 `yaml:"mirostat_eta"`
MirostatTAU float64 `yaml:"mirostat_tau"`
Mirostat int `yaml:"mirostat"`
NGPULayers int `yaml:"gpu_layers"`
MMap bool `yaml:"mmap"`
MMlock bool `yaml:"mmlock"`
LowVRAM bool `yaml:"low_vram"`
Grammar string `yaml:"grammar"`
StopWords []string `yaml:"stopwords"`
Cutstrings []string `yaml:"cutstrings"`
TrimSpace []string `yaml:"trimspace"`
TrimSuffix []string `yaml:"trimsuffix"`
ContextSize int `yaml:"context_size"`
NUMA bool `yaml:"numa"`
LoraAdapter string `yaml:"lora_adapter"`
LoraBase string `yaml:"lora_base"`
LoraScale float32 `yaml:"lora_scale"`
NoMulMatQ bool `yaml:"no_mulmatq"`
DraftModel string `yaml:"draft_model"`
NDraft int32 `yaml:"n_draft"`
Quantization string `yaml:"quantization"`
MMProj string `yaml:"mmproj"`
RopeScaling string `yaml:"rope_scaling"`
ModelType string `yaml:"type"`
YarnExtFactor float32 `yaml:"yarn_ext_factor"`
YarnAttnFactor float32 `yaml:"yarn_attn_factor"`
YarnBetaFast float32 `yaml:"yarn_beta_fast"`
YarnBetaSlow float32 `yaml:"yarn_beta_slow"`
}
type AutoGPTQ struct {
ModelBaseName string `yaml:"model_base_name"`
Device string `yaml:"device"`
Triton bool `yaml:"triton"`
UseFastTokenizer bool `yaml:"use_fast_tokenizer"`
}
type Functions struct {
DisableNoAction bool `yaml:"disable_no_action"`
NoActionFunctionName string `yaml:"no_action_function_name"`
NoActionDescriptionName string `yaml:"no_action_description_name"`
ParallelCalls bool `yaml:"parallel_calls"`
}
type TemplateConfig struct {
Chat string `yaml:"chat"`
ChatMessage string `yaml:"chat_message"`
Completion string `yaml:"completion"`
Edit string `yaml:"edit"`
Functions string `yaml:"function"`
}
type ConfigLoader struct {
configs map[string]Config
sync.Mutex
}
func (c *Config) SetFunctionCallString(s string) {
c.functionCallString = s
}
func (c *Config) SetFunctionCallNameString(s string) {
c.functionCallNameString = s
}
func (c *Config) ShouldUseFunctions() bool {
return ((c.functionCallString != "none" || c.functionCallString == "") || c.ShouldCallSpecificFunction())
}
func (c *Config) ShouldCallSpecificFunction() bool {
return len(c.functionCallNameString) > 0
}
func (c *Config) FunctionToCall() string {
return c.functionCallNameString
}
// Load a config file for a model
func Load(modelName, modelPath string, cm *ConfigLoader, debug bool, threads, ctx int, f16 bool) (*Config, error) {
// Load a config file if present after the model name
modelConfig := filepath.Join(modelPath, modelName+".yaml")
var cfg *Config
defaults := func() {
cfg = DefaultConfig(modelName)
cfg.ContextSize = ctx
cfg.Threads = threads
cfg.F16 = f16
cfg.Debug = debug
}
cfgExisting, exists := cm.GetConfig(modelName)
if !exists {
if _, err := os.Stat(modelConfig); err == nil {
if err := cm.LoadConfig(modelConfig); err != nil {
return nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error())
}
cfgExisting, exists = cm.GetConfig(modelName)
if exists {
cfg = &cfgExisting
} else {
defaults()
}
} else {
defaults()
}
} else {
cfg = &cfgExisting
}
// Set the parameters for the language model prediction
//updateConfig(cfg, input)
// Don't allow 0 as setting
if cfg.Threads == 0 {
if threads != 0 {
cfg.Threads = threads
} else {
cfg.Threads = 4
}
}
// Enforce debug flag if passed from CLI
if debug {
cfg.Debug = true
}
return cfg, nil
}
func defaultPredictOptions(modelFile string) PredictionOptions {
return PredictionOptions{
TopP: 0.7,
TopK: 80,
Maxtokens: 512,
Temperature: 0.9,
Model: modelFile,
}
}
func DefaultConfig(modelFile string) *Config {
return &Config{
PredictionOptions: defaultPredictOptions(modelFile),
}
}
func NewConfigLoader() *ConfigLoader {
return &ConfigLoader{
configs: make(map[string]Config),
}
}
func ReadConfigFile(file string) ([]*Config, error) {
c := &[]*Config{}
f, err := os.ReadFile(file)
if err != nil {
return nil, fmt.Errorf("cannot read config file: %w", err)
}
if err := yaml.Unmarshal(f, c); err != nil {
return nil, fmt.Errorf("cannot unmarshal config file: %w", err)
}
return *c, nil
}
func ReadConfig(file string) (*Config, error) {
c := &Config{}
f, err := os.ReadFile(file)
if err != nil {
return nil, fmt.Errorf("cannot read config file: %w", err)
}
if err := yaml.Unmarshal(f, c); err != nil {
return nil, fmt.Errorf("cannot unmarshal config file: %w", err)
}
return c, nil
}
func (cm *ConfigLoader) LoadConfigFile(file string) error {
cm.Lock()
defer cm.Unlock()
c, err := ReadConfigFile(file)
if err != nil {
return fmt.Errorf("cannot load config file: %w", err)
}
for _, cc := range c {
cm.configs[cc.Name] = *cc
}
return nil
}
func (cm *ConfigLoader) LoadConfig(file string) error {
cm.Lock()
defer cm.Unlock()
c, err := ReadConfig(file)
if err != nil {
return fmt.Errorf("cannot read config file: %w", err)
}
cm.configs[c.Name] = *c
return nil
}
func (cm *ConfigLoader) GetConfig(m string) (Config, bool) {
cm.Lock()
defer cm.Unlock()
v, exists := cm.configs[m]
return v, exists
}
func (cm *ConfigLoader) GetAllConfigs() []Config {
cm.Lock()
defer cm.Unlock()
var res []Config
for _, v := range cm.configs {
res = append(res, v)
}
return res
}
func (cm *ConfigLoader) ListConfigs() []string {
cm.Lock()
defer cm.Unlock()
var res []string
for k := range cm.configs {
res = append(res, k)
}
return res
}
// Preload prepare models if they are not local but url or huggingface repositories
func (cm *ConfigLoader) Preload(modelPath string) error {
cm.Lock()
defer cm.Unlock()
status := func(fileName, current, total string, percent float64) {
utils.DisplayDownloadFunction(fileName, current, total, percent)
}
log.Info().Msgf("Preloading models from %s", modelPath)
for i, config := range cm.configs {
// Download files and verify their SHA
for _, file := range config.DownloadFiles {
log.Debug().Msgf("Checking %q exists and matches SHA", file.Filename)
if err := utils.VerifyPath(file.Filename, modelPath); err != nil {
return err
}
// Create file path
filePath := filepath.Join(modelPath, file.Filename)
if err := downloader.DownloadFile(file.URI, filePath, file.SHA256, status); err != nil {
return err
}
}
modelURL := config.PredictionOptions.Model
modelURL = downloader.ConvertURL(modelURL)
if downloader.LooksLikeURL(modelURL) {
// md5 of model name
md5Name := utils.MD5(modelURL)
// check if file exists
if _, err := os.Stat(filepath.Join(modelPath, md5Name)); errors.Is(err, os.ErrNotExist) {
err := downloader.DownloadFile(modelURL, filepath.Join(modelPath, md5Name), "", status)
if err != nil {
return err
}
}
cc := cm.configs[i]
c := &cc
c.PredictionOptions.Model = md5Name
cm.configs[i] = *c
}
if cm.configs[i].Name != "" {
log.Info().Msgf("Model name: %s", cm.configs[i].Name)
}
if cm.configs[i].Description != "" {
log.Info().Msgf("Model description: %s", cm.configs[i].Description)
}
if cm.configs[i].Usage != "" {
log.Info().Msgf("Model usage: \n%s", cm.configs[i].Usage)
}
}
return nil
}
func (cm *ConfigLoader) LoadConfigs(path string) error {
cm.Lock()
defer cm.Unlock()
entries, err := os.ReadDir(path)
if err != nil {
return err
}
files := make([]fs.FileInfo, 0, len(entries))
for _, entry := range entries {
info, err := entry.Info()
if err != nil {
return err
}
files = append(files, info)
}
for _, file := range files {
// Skip templates, YAML and .keep files
if !strings.Contains(file.Name(), ".yaml") && !strings.Contains(file.Name(), ".yml") {
continue
}
c, err := ReadConfig(filepath.Join(path, file.Name()))
if err == nil {
cm.configs[c.Name] = *c
}
}
return nil
}

View file

@ -0,0 +1,56 @@
package config_test
import (
"os"
. "github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/options"
"github.com/go-skynet/LocalAI/pkg/model"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("Test cases for config related functions", func() {
var (
configFile string
)
Context("Test Read configuration functions", func() {
configFile = os.Getenv("CONFIG_FILE")
It("Test ReadConfigFile", func() {
config, err := ReadConfigFile(configFile)
Expect(err).To(BeNil())
Expect(config).ToNot(BeNil())
// two configs in config.yaml
Expect(config[0].Name).To(Equal("list1"))
Expect(config[1].Name).To(Equal("list2"))
})
It("Test LoadConfigs", func() {
cm := NewConfigLoader()
opts := options.NewOptions()
modelLoader := model.NewModelLoader(os.Getenv("MODELS_PATH"))
options.WithModelLoader(modelLoader)(opts)
err := cm.LoadConfigs(opts.Loader.ModelPath)
Expect(err).To(BeNil())
Expect(cm.ListConfigs()).ToNot(BeNil())
// config should includes gpt4all models's api.config
Expect(cm.ListConfigs()).To(ContainElements("gpt4all"))
// config should includes gpt2 models's api.config
Expect(cm.ListConfigs()).To(ContainElements("gpt4all-2"))
// config should includes text-embedding-ada-002 models's api.config
Expect(cm.ListConfigs()).To(ContainElements("text-embedding-ada-002"))
// config should includes rwkv_test models's api.config
Expect(cm.ListConfigs()).To(ContainElements("rwkv_test"))
// config should includes whisper-1 models's api.config
Expect(cm.ListConfigs()).To(ContainElements("whisper-1"))
})
})
})

50
core/config/prediction.go Normal file
View file

@ -0,0 +1,50 @@
package config
type PredictionOptions struct {
// Also part of the OpenAI official spec
Model string `json:"model" yaml:"model"`
// Also part of the OpenAI official spec
Language string `json:"language"`
// Also part of the OpenAI official spec. use it for returning multiple results
N int `json:"n"`
// Common options between all the API calls, part of the OpenAI spec
TopP float64 `json:"top_p" yaml:"top_p"`
TopK int `json:"top_k" yaml:"top_k"`
Temperature float64 `json:"temperature" yaml:"temperature"`
Maxtokens int `json:"max_tokens" yaml:"max_tokens"`
Echo bool `json:"echo"`
// Custom parameters - not present in the OpenAI API
Batch int `json:"batch" yaml:"batch"`
F16 bool `json:"f16" yaml:"f16"`
IgnoreEOS bool `json:"ignore_eos" yaml:"ignore_eos"`
RepeatPenalty float64 `json:"repeat_penalty" yaml:"repeat_penalty"`
Keep int `json:"n_keep" yaml:"n_keep"`
MirostatETA float64 `json:"mirostat_eta" yaml:"mirostat_eta"`
MirostatTAU float64 `json:"mirostat_tau" yaml:"mirostat_tau"`
Mirostat int `json:"mirostat" yaml:"mirostat"`
FrequencyPenalty float64 `json:"frequency_penalty" yaml:"frequency_penalty"`
TFZ float64 `json:"tfz" yaml:"tfz"`
TypicalP float64 `json:"typical_p" yaml:"typical_p"`
Seed int `json:"seed" yaml:"seed"`
NegativePrompt string `json:"negative_prompt" yaml:"negative_prompt"`
RopeFreqBase float32 `json:"rope_freq_base" yaml:"rope_freq_base"`
RopeFreqScale float32 `json:"rope_freq_scale" yaml:"rope_freq_scale"`
NegativePromptScale float32 `json:"negative_prompt_scale" yaml:"negative_prompt_scale"`
// AutoGPTQ
UseFastTokenizer bool `json:"use_fast_tokenizer" yaml:"use_fast_tokenizer"`
// Diffusers
ClipSkip int `json:"clip_skip" yaml:"clip_skip"`
// RWKV (?)
Tokenizer string `json:"tokenizer" yaml:"tokenizer"`
}

308
core/http/api.go Normal file
View file

@ -0,0 +1,308 @@
package http
import (
"encoding/json"
"errors"
"fmt"
"os"
"strings"
"github.com/go-skynet/LocalAI/api/localai"
"github.com/go-skynet/LocalAI/api/openai"
config "github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/options"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/go-skynet/LocalAI/internal"
"github.com/go-skynet/LocalAI/metrics"
"github.com/go-skynet/LocalAI/pkg/assets"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/startup"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/cors"
"github.com/gofiber/fiber/v2/middleware/logger"
"github.com/gofiber/fiber/v2/middleware/recover"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
)
func Startup(opts ...options.AppOption) (*options.Option, *config.ConfigLoader, error) {
options := options.NewOptions(opts...)
zerolog.SetGlobalLevel(zerolog.InfoLevel)
if options.Debug {
zerolog.SetGlobalLevel(zerolog.DebugLevel)
}
log.Info().Msgf("Starting LocalAI using %d threads, with models path: %s", options.Threads, options.Loader.ModelPath)
log.Info().Msgf("LocalAI version: %s", internal.PrintableVersion())
startup.PreloadModelsConfigurations(options.ModelLibraryURL, options.Loader.ModelPath, options.ModelsURL...)
cl := config.NewConfigLoader()
if err := cl.LoadConfigs(options.Loader.ModelPath); err != nil {
log.Error().Msgf("error loading config files: %s", err.Error())
}
if options.ConfigFile != "" {
if err := cl.LoadConfigFile(options.ConfigFile); err != nil {
log.Error().Msgf("error loading config file: %s", err.Error())
}
}
if err := cl.Preload(options.Loader.ModelPath); err != nil {
log.Error().Msgf("error downloading models: %s", err.Error())
}
if options.PreloadJSONModels != "" {
if err := localai.ApplyGalleryFromString(options.Loader.ModelPath, options.PreloadJSONModels, cl, options.Galleries); err != nil {
return nil, nil, err
}
}
if options.PreloadModelsFromPath != "" {
if err := localai.ApplyGalleryFromFile(options.Loader.ModelPath, options.PreloadModelsFromPath, cl, options.Galleries); err != nil {
return nil, nil, err
}
}
if options.Debug {
for _, v := range cl.ListConfigs() {
cfg, _ := cl.GetConfig(v)
log.Debug().Msgf("Model: %s (config: %+v)", v, cfg)
}
}
if options.AssetsDestination != "" {
// Extract files from the embedded FS
err := assets.ExtractFiles(options.BackendAssets, options.AssetsDestination)
log.Debug().Msgf("Extracting backend assets files to %s", options.AssetsDestination)
if err != nil {
log.Warn().Msgf("Failed extracting backend assets files: %s (might be required for some backends to work properly, like gpt4all)", err)
}
}
// turn off any process that was started by GRPC if the context is canceled
go func() {
<-options.Context.Done()
log.Debug().Msgf("Context canceled, shutting down")
options.Loader.StopAllGRPC()
}()
if options.WatchDog {
wd := model.NewWatchDog(
options.Loader,
options.WatchDogBusyTimeout,
options.WatchDogIdleTimeout,
options.WatchDogBusy,
options.WatchDogIdle)
options.Loader.SetWatchDog(wd)
go wd.Run()
go func() {
<-options.Context.Done()
log.Debug().Msgf("Context canceled, shutting down")
wd.Shutdown()
}()
}
return options, cl, nil
}
func App(opts ...options.AppOption) (*fiber.App, error) {
options, cl, err := Startup(opts...)
if err != nil {
return nil, fmt.Errorf("failed basic startup tasks with error %s", err.Error())
}
// Return errors as JSON responses
app := fiber.New(fiber.Config{
BodyLimit: options.UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB
DisableStartupMessage: options.DisableMessage,
// Override default error handler
ErrorHandler: func(ctx *fiber.Ctx, err error) error {
// Status code defaults to 500
code := fiber.StatusInternalServerError
// Retrieve the custom status code if it's a *fiber.Error
var e *fiber.Error
if errors.As(err, &e) {
code = e.Code
}
// Send custom error page
return ctx.Status(code).JSON(
schema.ErrorResponse{
Error: &schema.APIError{Message: err.Error(), Code: code},
},
)
},
})
if options.Debug {
app.Use(logger.New(logger.Config{
Format: "[${ip}]:${port} ${status} - ${method} ${path}\n",
}))
}
// Default middleware config
if !options.Debug {
app.Use(recover.New())
}
if options.Metrics != nil {
app.Use(metrics.APIMiddleware(options.Metrics))
}
// Auth middleware checking if API key is valid. If no API key is set, no auth is required.
auth := func(c *fiber.Ctx) error {
if len(options.ApiKeys) == 0 {
return c.Next()
}
// Check for api_keys.json file
fileContent, err := os.ReadFile("api_keys.json")
if err == nil {
// Parse JSON content from the file
var fileKeys []string
err := json.Unmarshal(fileContent, &fileKeys)
if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"message": "Error parsing api_keys.json"})
}
// Add file keys to options.ApiKeys
options.ApiKeys = append(options.ApiKeys, fileKeys...)
}
if len(options.ApiKeys) == 0 {
return c.Next()
}
authHeader := c.Get("Authorization")
if authHeader == "" {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Authorization header missing"})
}
authHeaderParts := strings.Split(authHeader, " ")
if len(authHeaderParts) != 2 || authHeaderParts[0] != "Bearer" {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid Authorization header format"})
}
apiKey := authHeaderParts[1]
for _, key := range options.ApiKeys {
if apiKey == key {
return c.Next()
}
}
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid API key"})
}
if options.CORS {
var c func(ctx *fiber.Ctx) error
if options.CORSAllowOrigins == "" {
c = cors.New()
} else {
c = cors.New(cors.Config{AllowOrigins: options.CORSAllowOrigins})
}
app.Use(c)
}
// LocalAI API endpoints
galleryService := localai.NewGalleryService(options.Loader.ModelPath)
galleryService.Start(options.Context, cl)
app.Get("/version", auth, func(c *fiber.Ctx) error {
return c.JSON(struct {
Version string `json:"version"`
}{Version: internal.PrintableVersion()})
})
// Make sure directories exists
os.MkdirAll(options.ImageDir, 0755)
os.MkdirAll(options.AudioDir, 0755)
os.MkdirAll(options.UploadDir, 0755)
os.MkdirAll(options.Loader.ModelPath, 0755)
// Load upload json
openai.LoadUploadConfig(options.UploadDir)
modelGalleryService := localai.CreateModelGalleryService(options.Galleries, options.Loader.ModelPath, galleryService)
app.Post("/models/apply", auth, modelGalleryService.ApplyModelGalleryEndpoint())
app.Get("/models/available", auth, modelGalleryService.ListModelFromGalleryEndpoint())
app.Get("/models/galleries", auth, modelGalleryService.ListModelGalleriesEndpoint())
app.Post("/models/galleries", auth, modelGalleryService.AddModelGalleryEndpoint())
app.Delete("/models/galleries", auth, modelGalleryService.RemoveModelGalleryEndpoint())
app.Get("/models/jobs/:uuid", auth, modelGalleryService.GetOpStatusEndpoint())
app.Get("/models/jobs", auth, modelGalleryService.GetAllStatusEndpoint())
// openAI compatible API endpoint
// chat
app.Post("/v1/chat/completions", auth, openai.ChatEndpoint(cl, options))
app.Post("/chat/completions", auth, openai.ChatEndpoint(cl, options))
// edit
app.Post("/v1/edits", auth, openai.EditEndpoint(cl, options))
app.Post("/edits", auth, openai.EditEndpoint(cl, options))
// files
app.Post("/v1/files", auth, openai.UploadFilesEndpoint(cl, options))
app.Post("/files", auth, openai.UploadFilesEndpoint(cl, options))
app.Get("/v1/files", auth, openai.ListFilesEndpoint(cl, options))
app.Get("/files", auth, openai.ListFilesEndpoint(cl, options))
app.Get("/v1/files/:file_id", auth, openai.GetFilesEndpoint(cl, options))
app.Get("/files/:file_id", auth, openai.GetFilesEndpoint(cl, options))
app.Delete("/v1/files/:file_id", auth, openai.DeleteFilesEndpoint(cl, options))
app.Delete("/files/:file_id", auth, openai.DeleteFilesEndpoint(cl, options))
app.Get("/v1/files/:file_id/content", auth, openai.GetFilesContentsEndpoint(cl, options))
app.Get("/files/:file_id/content", auth, openai.GetFilesContentsEndpoint(cl, options))
// completion
app.Post("/v1/completions", auth, openai.CompletionEndpoint(cl, options))
app.Post("/completions", auth, openai.CompletionEndpoint(cl, options))
app.Post("/v1/engines/:model/completions", auth, openai.CompletionEndpoint(cl, options))
// embeddings
app.Post("/v1/embeddings", auth, openai.EmbeddingsEndpoint(cl, options))
app.Post("/embeddings", auth, openai.EmbeddingsEndpoint(cl, options))
app.Post("/v1/engines/:model/embeddings", auth, openai.EmbeddingsEndpoint(cl, options))
// audio
app.Post("/v1/audio/transcriptions", auth, openai.TranscriptEndpoint(cl, options))
app.Post("/tts", auth, localai.TTSEndpoint(cl, options))
// images
app.Post("/v1/images/generations", auth, openai.ImageEndpoint(cl, options))
if options.ImageDir != "" {
app.Static("/generated-images", options.ImageDir)
}
if options.AudioDir != "" {
app.Static("/generated-audio", options.AudioDir)
}
ok := func(c *fiber.Ctx) error {
return c.SendStatus(200)
}
// Kubernetes health checks
app.Get("/healthz", ok)
app.Get("/readyz", ok)
// Experimental Backend Statistics Module
backendMonitor := localai.NewBackendMonitor(cl, options) // Split out for now
app.Get("/backend/monitor", localai.BackendMonitorEndpoint(backendMonitor))
app.Post("/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitor))
// models
app.Get("/v1/models", auth, openai.ListModelsEndpoint(options.Loader, cl))
app.Get("/models", auth, openai.ListModelsEndpoint(options.Loader, cl))
app.Get("/metrics", metrics.MetricsHandler())
return app, nil
}

870
core/http/api_test.go Normal file
View file

@ -0,0 +1,870 @@
package http_test
import (
"bytes"
"context"
"embed"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"runtime"
. "github.com/go-skynet/LocalAI/core/http"
"github.com/go-skynet/LocalAI/core/options"
"github.com/go-skynet/LocalAI/metrics"
"github.com/go-skynet/LocalAI/pkg/downloader"
"github.com/go-skynet/LocalAI/pkg/gallery"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"gopkg.in/yaml.v3"
openaigo "github.com/otiai10/openaigo"
"github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/jsonschema"
)
const testPrompt = `### System:
You are an AI assistant that follows instruction extremely well. Help as much as you can.
### User:
Can you help rephrasing sentences?
### Response:`
type modelApplyRequest struct {
ID string `json:"id"`
URL string `json:"url"`
Name string `json:"name"`
Overrides map[string]interface{} `json:"overrides"`
}
func getModelStatus(url string) (response map[string]interface{}) {
// Create the HTTP request
resp, err := http.Get(url)
if err != nil {
fmt.Println("Error creating request:", err)
return
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
fmt.Println("Error reading response body:", err)
return
}
// Unmarshal the response into a map[string]interface{}
err = json.Unmarshal(body, &response)
if err != nil {
fmt.Println("Error unmarshaling JSON response:", err)
return
}
return
}
func getModels(url string) (response []gallery.GalleryModel) {
downloader.GetURI(url, func(url string, i []byte) error {
// Unmarshal YAML data into a struct
return json.Unmarshal(i, &response)
})
return
}
func postModelApplyRequest(url string, request modelApplyRequest) (response map[string]interface{}) {
//url := "http://localhost:AI/models/apply"
// Create the request payload
payload, err := json.Marshal(request)
if err != nil {
fmt.Println("Error marshaling JSON:", err)
return
}
// Create the HTTP request
req, err := http.NewRequest("POST", url, bytes.NewBuffer(payload))
if err != nil {
fmt.Println("Error creating request:", err)
return
}
req.Header.Set("Content-Type", "application/json")
// Make the request
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
fmt.Println("Error making request:", err)
return
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
fmt.Println("Error reading response body:", err)
return
}
// Unmarshal the response into a map[string]interface{}
err = json.Unmarshal(body, &response)
if err != nil {
fmt.Println("Error unmarshaling JSON response:", err)
return
}
return
}
//go:embed backend-assets/*
var backendAssets embed.FS
var _ = Describe("API test", func() {
var app *fiber.App
var modelLoader *model.ModelLoader
var client *openai.Client
var client2 *openaigo.Client
var c context.Context
var cancel context.CancelFunc
var tmpdir string
commonOpts := []options.AppOption{
options.WithDebug(true),
options.WithDisableMessage(true),
}
Context("API with ephemeral models", func() {
BeforeEach(func() {
var err error
tmpdir, err = os.MkdirTemp("", "")
Expect(err).ToNot(HaveOccurred())
modelLoader = model.NewModelLoader(tmpdir)
c, cancel = context.WithCancel(context.Background())
g := []gallery.GalleryModel{
{
Name: "bert",
URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml",
},
{
Name: "bert2",
URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml",
Overrides: map[string]interface{}{"foo": "bar"},
AdditionalFiles: []gallery.File{{Filename: "foo.yaml", URI: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml"}},
},
}
out, err := yaml.Marshal(g)
Expect(err).ToNot(HaveOccurred())
err = os.WriteFile(filepath.Join(tmpdir, "gallery_simple.yaml"), out, 0644)
Expect(err).ToNot(HaveOccurred())
galleries := []gallery.Gallery{
{
Name: "test",
URL: "file://" + filepath.Join(tmpdir, "gallery_simple.yaml"),
},
}
metricsService, err := metrics.SetupMetrics()
Expect(err).ToNot(HaveOccurred())
app, err = App(
append(commonOpts,
options.WithMetrics(metricsService),
options.WithContext(c),
options.WithGalleries(galleries),
options.WithModelLoader(modelLoader), options.WithBackendAssets(backendAssets), options.WithBackendAssetsOutput(tmpdir))...)
Expect(err).ToNot(HaveOccurred())
go app.Listen("127.0.0.1:9090")
defaultConfig := openai.DefaultConfig("")
defaultConfig.BaseURL = "http://127.0.0.1:9090/v1"
client2 = openaigo.NewClient("")
client2.BaseURL = defaultConfig.BaseURL
// Wait for API to be ready
client = openai.NewClientWithConfig(defaultConfig)
Eventually(func() error {
_, err := client.ListModels(context.TODO())
return err
}, "2m").ShouldNot(HaveOccurred())
})
AfterEach(func() {
cancel()
app.Shutdown()
os.RemoveAll(tmpdir)
})
Context("Applying models", func() {
It("applies models from a gallery", func() {
models := getModels("http://127.0.0.1:9090/models/available")
Expect(len(models)).To(Equal(2), fmt.Sprint(models))
Expect(models[0].Installed).To(BeFalse(), fmt.Sprint(models))
Expect(models[1].Installed).To(BeFalse(), fmt.Sprint(models))
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
ID: "test@bert2",
})
Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response))
uuid := response["uuid"].(string)
resp := map[string]interface{}{}
Eventually(func() bool {
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
fmt.Println(response)
resp = response
return response["processed"].(bool)
}, "360s", "10s").Should(Equal(true))
Expect(resp["message"]).ToNot(ContainSubstring("error"))
dat, err := os.ReadFile(filepath.Join(tmpdir, "bert2.yaml"))
Expect(err).ToNot(HaveOccurred())
_, err = os.ReadFile(filepath.Join(tmpdir, "foo.yaml"))
Expect(err).ToNot(HaveOccurred())
content := map[string]interface{}{}
err = yaml.Unmarshal(dat, &content)
Expect(err).ToNot(HaveOccurred())
Expect(content["backend"]).To(Equal("bert-embeddings"))
Expect(content["foo"]).To(Equal("bar"))
models = getModels("http://127.0.0.1:9090/models/available")
Expect(len(models)).To(Equal(2), fmt.Sprint(models))
Expect(models[0].Name).To(Or(Equal("bert"), Equal("bert2")))
Expect(models[1].Name).To(Or(Equal("bert"), Equal("bert2")))
for _, m := range models {
if m.Name == "bert2" {
Expect(m.Installed).To(BeTrue())
} else {
Expect(m.Installed).To(BeFalse())
}
}
})
It("overrides models", func() {
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml",
Name: "bert",
Overrides: map[string]interface{}{
"backend": "llama",
},
})
Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response))
uuid := response["uuid"].(string)
Eventually(func() bool {
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
return response["processed"].(bool)
}, "360s", "10s").Should(Equal(true))
dat, err := os.ReadFile(filepath.Join(tmpdir, "bert.yaml"))
Expect(err).ToNot(HaveOccurred())
content := map[string]interface{}{}
err = yaml.Unmarshal(dat, &content)
Expect(err).ToNot(HaveOccurred())
Expect(content["backend"]).To(Equal("llama"))
})
It("apply models without overrides", func() {
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml",
Name: "bert",
Overrides: map[string]interface{}{},
})
Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response))
uuid := response["uuid"].(string)
Eventually(func() bool {
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
return response["processed"].(bool)
}, "360s", "10s").Should(Equal(true))
dat, err := os.ReadFile(filepath.Join(tmpdir, "bert.yaml"))
Expect(err).ToNot(HaveOccurred())
content := map[string]interface{}{}
err = yaml.Unmarshal(dat, &content)
Expect(err).ToNot(HaveOccurred())
Expect(content["backend"]).To(Equal("bert-embeddings"))
})
It("runs openllama(llama-ggml backend)", Label("llama"), func() {
if runtime.GOOS != "linux" {
Skip("test supported only on linux")
}
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
URL: "github:go-skynet/model-gallery/openllama_3b.yaml",
Name: "openllama_3b",
Overrides: map[string]interface{}{"backend": "llama-ggml", "mmap": true, "f16": true, "context_size": 128},
})
Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response))
uuid := response["uuid"].(string)
Eventually(func() bool {
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
return response["processed"].(bool)
}, "360s", "10s").Should(Equal(true))
By("testing completion")
resp, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "openllama_3b", Prompt: "Count up to five: one, two, three, four, "})
Expect(err).ToNot(HaveOccurred())
Expect(len(resp.Choices)).To(Equal(1))
Expect(resp.Choices[0].Text).To(ContainSubstring("five"))
By("testing functions")
resp2, err := client.CreateChatCompletion(
context.TODO(),
openai.ChatCompletionRequest{
Model: "openllama_3b",
Messages: []openai.ChatCompletionMessage{
{
Role: "user",
Content: "What is the weather like in San Francisco (celsius)?",
},
},
Functions: []openai.FunctionDefinition{
openai.FunctionDefinition{
Name: "get_current_weather",
Description: "Get the current weather",
Parameters: jsonschema.Definition{
Type: jsonschema.Object,
Properties: map[string]jsonschema.Definition{
"location": {
Type: jsonschema.String,
Description: "The city and state, e.g. San Francisco, CA",
},
"unit": {
Type: jsonschema.String,
Enum: []string{"celcius", "fahrenheit"},
},
},
Required: []string{"location"},
},
},
},
})
Expect(err).ToNot(HaveOccurred())
Expect(len(resp2.Choices)).To(Equal(1))
Expect(resp2.Choices[0].Message.FunctionCall).ToNot(BeNil())
Expect(resp2.Choices[0].Message.FunctionCall.Name).To(Equal("get_current_weather"), resp2.Choices[0].Message.FunctionCall.Name)
var res map[string]string
err = json.Unmarshal([]byte(resp2.Choices[0].Message.FunctionCall.Arguments), &res)
Expect(err).ToNot(HaveOccurred())
Expect(res["location"]).To(Equal("San Francisco, California, United States"), fmt.Sprint(res))
Expect(res["unit"]).To(Equal("celcius"), fmt.Sprint(res))
Expect(string(resp2.Choices[0].FinishReason)).To(Equal("function_call"), fmt.Sprint(resp2.Choices[0].FinishReason))
})
It("runs openllama gguf(llama-cpp)", Label("llama-gguf"), func() {
if runtime.GOOS != "linux" {
Skip("test supported only on linux")
}
modelName := "codellama"
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
URL: "github:go-skynet/model-gallery/codellama-7b-instruct.yaml",
Name: modelName,
Overrides: map[string]interface{}{"backend": "llama", "mmap": true, "f16": true, "context_size": 128},
})
Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response))
uuid := response["uuid"].(string)
Eventually(func() bool {
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
return response["processed"].(bool)
}, "360s", "10s").Should(Equal(true))
By("testing chat")
resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: modelName, Messages: []openai.ChatCompletionMessage{
{
Role: "user",
Content: "How much is 2+2?",
},
}})
Expect(err).ToNot(HaveOccurred())
Expect(len(resp.Choices)).To(Equal(1))
Expect(resp.Choices[0].Message.Content).To(Or(ContainSubstring("4"), ContainSubstring("four")))
By("testing functions")
resp2, err := client.CreateChatCompletion(
context.TODO(),
openai.ChatCompletionRequest{
Model: modelName,
Messages: []openai.ChatCompletionMessage{
{
Role: "user",
Content: "What is the weather like in San Francisco (celsius)?",
},
},
Functions: []openai.FunctionDefinition{
openai.FunctionDefinition{
Name: "get_current_weather",
Description: "Get the current weather",
Parameters: jsonschema.Definition{
Type: jsonschema.Object,
Properties: map[string]jsonschema.Definition{
"location": {
Type: jsonschema.String,
Description: "The city and state, e.g. San Francisco, CA",
},
"unit": {
Type: jsonschema.String,
Enum: []string{"celcius", "fahrenheit"},
},
},
Required: []string{"location"},
},
},
},
})
Expect(err).ToNot(HaveOccurred())
Expect(len(resp2.Choices)).To(Equal(1))
Expect(resp2.Choices[0].Message.FunctionCall).ToNot(BeNil())
Expect(resp2.Choices[0].Message.FunctionCall.Name).To(Equal("get_current_weather"), resp2.Choices[0].Message.FunctionCall.Name)
var res map[string]string
err = json.Unmarshal([]byte(resp2.Choices[0].Message.FunctionCall.Arguments), &res)
Expect(err).ToNot(HaveOccurred())
Expect(res["location"]).To(Equal("San Francisco"), fmt.Sprint(res))
Expect(res["unit"]).To(Equal("celcius"), fmt.Sprint(res))
Expect(string(resp2.Choices[0].FinishReason)).To(Equal("function_call"), fmt.Sprint(resp2.Choices[0].FinishReason))
})
It("runs gpt4all", Label("gpt4all"), func() {
if runtime.GOOS != "linux" {
Skip("test supported only on linux")
}
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
URL: "github:go-skynet/model-gallery/gpt4all-j.yaml",
Name: "gpt4all-j",
})
Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response))
uuid := response["uuid"].(string)
Eventually(func() bool {
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
return response["processed"].(bool)
}, "960s", "10s").Should(Equal(true))
resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "gpt4all-j", Messages: []openai.ChatCompletionMessage{openai.ChatCompletionMessage{Role: "user", Content: "How are you?"}}})
Expect(err).ToNot(HaveOccurred())
Expect(len(resp.Choices)).To(Equal(1))
Expect(resp.Choices[0].Message.Content).To(ContainSubstring("well"))
})
})
})
Context("Model gallery", func() {
BeforeEach(func() {
var err error
tmpdir, err = os.MkdirTemp("", "")
Expect(err).ToNot(HaveOccurred())
modelLoader = model.NewModelLoader(tmpdir)
c, cancel = context.WithCancel(context.Background())
galleries := []gallery.Gallery{
{
Name: "model-gallery",
URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/index.yaml",
},
}
metricsService, err := metrics.SetupMetrics()
Expect(err).ToNot(HaveOccurred())
app, err = App(
append(commonOpts,
options.WithContext(c),
options.WithMetrics(metricsService),
options.WithAudioDir(tmpdir),
options.WithImageDir(tmpdir),
options.WithGalleries(galleries),
options.WithModelLoader(modelLoader),
options.WithBackendAssets(backendAssets),
options.WithBackendAssetsOutput(tmpdir))...,
)
Expect(err).ToNot(HaveOccurred())
go app.Listen("127.0.0.1:9090")
defaultConfig := openai.DefaultConfig("")
defaultConfig.BaseURL = "http://127.0.0.1:9090/v1"
client2 = openaigo.NewClient("")
client2.BaseURL = defaultConfig.BaseURL
// Wait for API to be ready
client = openai.NewClientWithConfig(defaultConfig)
Eventually(func() error {
_, err := client.ListModels(context.TODO())
return err
}, "2m").ShouldNot(HaveOccurred())
})
AfterEach(func() {
cancel()
app.Shutdown()
os.RemoveAll(tmpdir)
})
It("installs and is capable to run tts", Label("tts"), func() {
if runtime.GOOS != "linux" {
Skip("test supported only on linux")
}
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
ID: "model-gallery@voice-en-us-kathleen-low",
})
Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response))
uuid := response["uuid"].(string)
Eventually(func() bool {
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
fmt.Println(response)
return response["processed"].(bool)
}, "360s", "10s").Should(Equal(true))
// An HTTP Post to the /tts endpoint should return a wav audio file
resp, err := http.Post("http://127.0.0.1:9090/tts", "application/json", bytes.NewBuffer([]byte(`{"input": "Hello world", "model": "en-us-kathleen-low.onnx"}`)))
Expect(err).ToNot(HaveOccurred(), fmt.Sprint(resp))
dat, err := io.ReadAll(resp.Body)
Expect(err).ToNot(HaveOccurred(), fmt.Sprint(resp))
Expect(resp.StatusCode).To(Equal(200), fmt.Sprint(string(dat)))
Expect(resp.Header.Get("Content-Type")).To(Equal("audio/x-wav"))
})
It("installs and is capable to generate images", Label("stablediffusion"), func() {
if runtime.GOOS != "linux" {
Skip("test supported only on linux")
}
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
ID: "model-gallery@stablediffusion",
Overrides: map[string]interface{}{
"parameters": map[string]interface{}{"model": "stablediffusion_assets"},
},
})
Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response))
uuid := response["uuid"].(string)
Eventually(func() bool {
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
fmt.Println(response)
return response["processed"].(bool)
}, "360s", "10s").Should(Equal(true))
resp, err := http.Post(
"http://127.0.0.1:9090/v1/images/generations",
"application/json",
bytes.NewBuffer([]byte(`{
"prompt": "floating hair, portrait, ((loli)), ((one girl)), cute face, hidden hands, asymmetrical bangs, beautiful detailed eyes, eye shadow, hair ornament, ribbons, bowties, buttons, pleated skirt, (((masterpiece))), ((best quality)), colorful|((part of the head)), ((((mutated hands and fingers)))), deformed, blurry, bad anatomy, disfigured, poorly drawn face, mutation, mutated, extra limb, ugly, poorly drawn hands, missing limb, blurry, floating limbs, disconnected limbs, malformed hands, blur, out of focus, long neck, long body, Octane renderer, lowres, bad anatomy, bad hands, text",
"mode": 2, "seed":9000,
"size": "256x256", "n":2}`)))
// The response should contain an URL
Expect(err).ToNot(HaveOccurred(), fmt.Sprint(resp))
dat, err := io.ReadAll(resp.Body)
Expect(err).ToNot(HaveOccurred(), string(dat))
Expect(string(dat)).To(ContainSubstring("http://127.0.0.1:9090/"), string(dat))
Expect(string(dat)).To(ContainSubstring(".png"), string(dat))
})
})
Context("API query", func() {
BeforeEach(func() {
modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH"))
c, cancel = context.WithCancel(context.Background())
metricsService, err := metrics.SetupMetrics()
Expect(err).ToNot(HaveOccurred())
app, err = App(
append(commonOpts,
options.WithExternalBackend("huggingface", os.Getenv("HUGGINGFACE_GRPC")),
options.WithContext(c),
options.WithModelLoader(modelLoader),
options.WithMetrics(metricsService),
)...)
Expect(err).ToNot(HaveOccurred())
go app.Listen("127.0.0.1:9090")
defaultConfig := openai.DefaultConfig("")
defaultConfig.BaseURL = "http://127.0.0.1:9090/v1"
client2 = openaigo.NewClient("")
client2.BaseURL = defaultConfig.BaseURL
// Wait for API to be ready
client = openai.NewClientWithConfig(defaultConfig)
Eventually(func() error {
_, err := client.ListModels(context.TODO())
return err
}, "2m").ShouldNot(HaveOccurred())
})
AfterEach(func() {
cancel()
app.Shutdown()
})
It("returns the models list", func() {
models, err := client.ListModels(context.TODO())
Expect(err).ToNot(HaveOccurred())
Expect(len(models.Models)).To(Equal(6)) // If "config.yaml" should be included, this should be 8?
})
It("can generate completions", func() {
resp, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "testmodel", Prompt: testPrompt})
Expect(err).ToNot(HaveOccurred())
Expect(len(resp.Choices)).To(Equal(1))
Expect(resp.Choices[0].Text).ToNot(BeEmpty())
})
It("can generate chat completions ", func() {
resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "testmodel", Messages: []openai.ChatCompletionMessage{openai.ChatCompletionMessage{Role: "user", Content: testPrompt}}})
Expect(err).ToNot(HaveOccurred())
Expect(len(resp.Choices)).To(Equal(1))
Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty())
})
It("can generate completions from model configs", func() {
resp, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "gpt4all", Prompt: testPrompt})
Expect(err).ToNot(HaveOccurred())
Expect(len(resp.Choices)).To(Equal(1))
Expect(resp.Choices[0].Text).ToNot(BeEmpty())
})
It("can generate chat completions from model configs", func() {
resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "gpt4all-2", Messages: []openai.ChatCompletionMessage{openai.ChatCompletionMessage{Role: "user", Content: testPrompt}}})
Expect(err).ToNot(HaveOccurred())
Expect(len(resp.Choices)).To(Equal(1))
Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty())
})
It("returns errors", func() {
backends := len(model.AutoLoadBackends) + 1 // +1 for huggingface
_, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "foomodel", Prompt: testPrompt})
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring(fmt.Sprintf("error, status code: 500, message: could not load model - all backends returned error: %d errors occurred:", backends)))
})
It("transcribes audio", func() {
if runtime.GOOS != "linux" {
Skip("test supported only on linux")
}
resp, err := client.CreateTranscription(
context.Background(),
openai.AudioRequest{
Model: openai.Whisper1,
FilePath: filepath.Join(os.Getenv("TEST_DIR"), "audio.wav"),
},
)
Expect(err).ToNot(HaveOccurred())
Expect(resp.Text).To(ContainSubstring("This is the Micro Machine Man presenting"))
})
It("calculate embeddings", func() {
if runtime.GOOS != "linux" {
Skip("test supported only on linux")
}
resp, err := client.CreateEmbeddings(
context.Background(),
openai.EmbeddingRequest{
Model: openai.AdaEmbeddingV2,
Input: []string{"sun", "cat"},
},
)
Expect(err).ToNot(HaveOccurred(), err)
Expect(len(resp.Data[0].Embedding)).To(BeNumerically("==", 384))
Expect(len(resp.Data[1].Embedding)).To(BeNumerically("==", 384))
sunEmbedding := resp.Data[0].Embedding
resp2, err := client.CreateEmbeddings(
context.Background(),
openai.EmbeddingRequest{
Model: openai.AdaEmbeddingV2,
Input: []string{"sun"},
},
)
Expect(err).ToNot(HaveOccurred())
Expect(resp2.Data[0].Embedding).To(Equal(sunEmbedding))
})
Context("External gRPC calls", func() {
It("calculate embeddings with sentencetransformers", func() {
if runtime.GOOS != "linux" {
Skip("test supported only on linux")
}
resp, err := client.CreateEmbeddings(
context.Background(),
openai.EmbeddingRequest{
Model: openai.AdaCodeSearchCode,
Input: []string{"sun", "cat"},
},
)
Expect(err).ToNot(HaveOccurred())
Expect(len(resp.Data[0].Embedding)).To(BeNumerically("==", 384))
Expect(len(resp.Data[1].Embedding)).To(BeNumerically("==", 384))
sunEmbedding := resp.Data[0].Embedding
resp2, err := client.CreateEmbeddings(
context.Background(),
openai.EmbeddingRequest{
Model: openai.AdaCodeSearchCode,
Input: []string{"sun"},
},
)
Expect(err).ToNot(HaveOccurred())
Expect(resp2.Data[0].Embedding).To(Equal(sunEmbedding))
Expect(resp2.Data[0].Embedding).ToNot(Equal(resp.Data[1].Embedding))
})
})
Context("backends", func() {
It("runs rwkv completion", func() {
if runtime.GOOS != "linux" {
Skip("test supported only on linux")
}
resp, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "rwkv_test", Prompt: "Count up to five: one, two, three, four,"})
Expect(err).ToNot(HaveOccurred())
Expect(len(resp.Choices) > 0).To(BeTrue())
Expect(resp.Choices[0].Text).To(ContainSubstring("five"))
stream, err := client.CreateCompletionStream(context.TODO(), openai.CompletionRequest{
Model: "rwkv_test", Prompt: "Count up to five: one, two, three, four,", Stream: true,
})
Expect(err).ToNot(HaveOccurred())
defer stream.Close()
tokens := 0
text := ""
for {
response, err := stream.Recv()
if errors.Is(err, io.EOF) {
break
}
Expect(err).ToNot(HaveOccurred())
text += response.Choices[0].Text
tokens++
}
Expect(text).ToNot(BeEmpty())
Expect(text).To(ContainSubstring("five"))
Expect(tokens).ToNot(Or(Equal(1), Equal(0)))
})
It("runs rwkv chat completion", func() {
if runtime.GOOS != "linux" {
Skip("test supported only on linux")
}
resp, err := client.CreateChatCompletion(context.TODO(),
openai.ChatCompletionRequest{Model: "rwkv_test", Messages: []openai.ChatCompletionMessage{{Content: "Can you count up to five?", Role: "user"}}})
Expect(err).ToNot(HaveOccurred())
Expect(len(resp.Choices) > 0).To(BeTrue())
Expect(resp.Choices[0].Message.Content).To(Or(ContainSubstring("Sure"), ContainSubstring("five")))
stream, err := client.CreateChatCompletionStream(context.TODO(), openai.ChatCompletionRequest{Model: "rwkv_test", Messages: []openai.ChatCompletionMessage{{Content: "Can you count up to five?", Role: "user"}}})
Expect(err).ToNot(HaveOccurred())
defer stream.Close()
tokens := 0
text := ""
for {
response, err := stream.Recv()
if errors.Is(err, io.EOF) {
break
}
Expect(err).ToNot(HaveOccurred())
text += response.Choices[0].Delta.Content
tokens++
}
Expect(text).ToNot(BeEmpty())
Expect(text).To(Or(ContainSubstring("Sure"), ContainSubstring("five")))
Expect(tokens).ToNot(Or(Equal(1), Equal(0)))
})
})
})
Context("Config file", func() {
BeforeEach(func() {
modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH"))
c, cancel = context.WithCancel(context.Background())
metricsService, err := metrics.SetupMetrics()
Expect(err).ToNot(HaveOccurred())
app, err = App(
append(commonOpts,
options.WithContext(c),
options.WithMetrics(metricsService),
options.WithModelLoader(modelLoader),
options.WithConfigFile(os.Getenv("CONFIG_FILE")))...,
)
Expect(err).ToNot(HaveOccurred())
go app.Listen("127.0.0.1:9090")
defaultConfig := openai.DefaultConfig("")
defaultConfig.BaseURL = "http://127.0.0.1:9090/v1"
client2 = openaigo.NewClient("")
client2.BaseURL = defaultConfig.BaseURL
// Wait for API to be ready
client = openai.NewClientWithConfig(defaultConfig)
Eventually(func() error {
_, err := client.ListModels(context.TODO())
return err
}, "2m").ShouldNot(HaveOccurred())
})
AfterEach(func() {
cancel()
app.Shutdown()
})
It("can generate chat completions from config file (list1)", func() {
resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "list1", Messages: []openai.ChatCompletionMessage{{Role: "user", Content: testPrompt}}})
Expect(err).ToNot(HaveOccurred())
Expect(len(resp.Choices)).To(Equal(1))
Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty())
})
It("can generate chat completions from config file (list2)", func() {
resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "list2", Messages: []openai.ChatCompletionMessage{{Role: "user", Content: testPrompt}}})
Expect(err).ToNot(HaveOccurred())
Expect(len(resp.Choices)).To(Equal(1))
Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty())
})
It("can generate edit completions from config file", func() {
request := openaigo.EditCreateRequestBody{
Model: "list2",
Instruction: "foo",
Input: "bar",
}
resp, err := client2.CreateEdit(context.Background(), request)
Expect(err).ToNot(HaveOccurred())
Expect(len(resp.Choices)).To(Equal(1))
Expect(resp.Choices[0].Text).ToNot(BeEmpty())
})
})
})

View file

@ -0,0 +1,13 @@
package http_test
import (
"testing"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
func TestLocalAI(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "LocalAI test suite")
}

269
core/options/options.go Normal file
View file

@ -0,0 +1,269 @@
package options
import (
"context"
"embed"
"encoding/json"
"time"
"github.com/go-skynet/LocalAI/metrics"
"github.com/go-skynet/LocalAI/pkg/gallery"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/rs/zerolog/log"
)
type Option struct {
Context context.Context
ConfigFile string
Loader *model.ModelLoader
UploadLimitMB, Threads, ContextSize int
F16 bool
Debug, DisableMessage bool
ImageDir string
AudioDir string
UploadDir string
CORS bool
PreloadJSONModels string
PreloadModelsFromPath string
CORSAllowOrigins string
ApiKeys []string
Metrics *metrics.Metrics
ModelLibraryURL string
Galleries []gallery.Gallery
BackendAssets embed.FS
AssetsDestination string
ExternalGRPCBackends map[string]string
AutoloadGalleries bool
SingleBackend bool
ParallelBackendRequests bool
WatchDogIdle bool
WatchDogBusy bool
WatchDog bool
ModelsURL []string
WatchDogBusyTimeout, WatchDogIdleTimeout time.Duration
}
type AppOption func(*Option)
func NewOptions(o ...AppOption) *Option {
opt := &Option{
Context: context.Background(),
UploadLimitMB: 15,
Threads: 1,
ContextSize: 512,
Debug: true,
DisableMessage: true,
}
for _, oo := range o {
oo(opt)
}
return opt
}
func WithModelsURL(urls ...string) AppOption {
return func(o *Option) {
o.ModelsURL = urls
}
}
func WithCors(b bool) AppOption {
return func(o *Option) {
o.CORS = b
}
}
func WithModelLibraryURL(url string) AppOption {
return func(o *Option) {
o.ModelLibraryURL = url
}
}
var EnableWatchDog = func(o *Option) {
o.WatchDog = true
}
var EnableWatchDogIdleCheck = func(o *Option) {
o.WatchDog = true
o.WatchDogIdle = true
}
var EnableWatchDogBusyCheck = func(o *Option) {
o.WatchDog = true
o.WatchDogBusy = true
}
func SetWatchDogBusyTimeout(t time.Duration) AppOption {
return func(o *Option) {
o.WatchDogBusyTimeout = t
}
}
func SetWatchDogIdleTimeout(t time.Duration) AppOption {
return func(o *Option) {
o.WatchDogIdleTimeout = t
}
}
var EnableSingleBackend = func(o *Option) {
o.SingleBackend = true
}
var EnableParallelBackendRequests = func(o *Option) {
o.ParallelBackendRequests = true
}
var EnableGalleriesAutoload = func(o *Option) {
o.AutoloadGalleries = true
}
func WithExternalBackend(name string, uri string) AppOption {
return func(o *Option) {
if o.ExternalGRPCBackends == nil {
o.ExternalGRPCBackends = make(map[string]string)
}
o.ExternalGRPCBackends[name] = uri
}
}
func WithCorsAllowOrigins(b string) AppOption {
return func(o *Option) {
o.CORSAllowOrigins = b
}
}
func WithBackendAssetsOutput(out string) AppOption {
return func(o *Option) {
o.AssetsDestination = out
}
}
func WithBackendAssets(f embed.FS) AppOption {
return func(o *Option) {
o.BackendAssets = f
}
}
func WithStringGalleries(galls string) AppOption {
return func(o *Option) {
if galls == "" {
log.Debug().Msgf("no galleries to load")
o.Galleries = []gallery.Gallery{}
return
}
var galleries []gallery.Gallery
if err := json.Unmarshal([]byte(galls), &galleries); err != nil {
log.Error().Msgf("failed loading galleries: %s", err.Error())
}
o.Galleries = append(o.Galleries, galleries...)
}
}
func WithGalleries(galleries []gallery.Gallery) AppOption {
return func(o *Option) {
o.Galleries = append(o.Galleries, galleries...)
}
}
func WithContext(ctx context.Context) AppOption {
return func(o *Option) {
o.Context = ctx
}
}
func WithYAMLConfigPreload(configFile string) AppOption {
return func(o *Option) {
o.PreloadModelsFromPath = configFile
}
}
func WithJSONStringPreload(configFile string) AppOption {
return func(o *Option) {
o.PreloadJSONModels = configFile
}
}
func WithConfigFile(configFile string) AppOption {
return func(o *Option) {
o.ConfigFile = configFile
}
}
func WithModelLoader(loader *model.ModelLoader) AppOption {
return func(o *Option) {
o.Loader = loader
}
}
func WithUploadLimitMB(limit int) AppOption {
return func(o *Option) {
o.UploadLimitMB = limit
}
}
func WithThreads(threads int) AppOption {
return func(o *Option) {
o.Threads = threads
}
}
func WithContextSize(ctxSize int) AppOption {
return func(o *Option) {
o.ContextSize = ctxSize
}
}
func WithF16(f16 bool) AppOption {
return func(o *Option) {
o.F16 = f16
}
}
func WithDebug(debug bool) AppOption {
return func(o *Option) {
o.Debug = debug
}
}
func WithDisableMessage(disableMessage bool) AppOption {
return func(o *Option) {
o.DisableMessage = disableMessage
}
}
func WithAudioDir(audioDir string) AppOption {
return func(o *Option) {
o.AudioDir = audioDir
}
}
func WithImageDir(imageDir string) AppOption {
return func(o *Option) {
o.ImageDir = imageDir
}
}
func WithUploadDir(uploadDir string) AppOption {
return func(o *Option) {
o.UploadDir = uploadDir
}
}
func WithApiKeys(apiKeys []string) AppOption {
return func(o *Option) {
o.ApiKeys = apiKeys
}
}
func WithMetrics(meter *metrics.Metrics) AppOption {
return func(o *Option) {
o.Metrics = meter
}
}

156
core/schema/openai.go Normal file
View file

@ -0,0 +1,156 @@
package schema
import (
"context"
config "github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/pkg/grammar"
)
// APIError provides error information returned by the OpenAI API.
type APIError struct {
Code any `json:"code,omitempty"`
Message string `json:"message"`
Param *string `json:"param,omitempty"`
Type string `json:"type"`
}
type ErrorResponse struct {
Error *APIError `json:"error,omitempty"`
}
type OpenAIUsage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}
type Item struct {
Embedding []float32 `json:"embedding"`
Index int `json:"index"`
Object string `json:"object,omitempty"`
// Images
URL string `json:"url,omitempty"`
B64JSON string `json:"b64_json,omitempty"`
}
type OpenAIResponse struct {
Created int `json:"created,omitempty"`
Object string `json:"object,omitempty"`
ID string `json:"id,omitempty"`
Model string `json:"model,omitempty"`
Choices []Choice `json:"choices,omitempty"`
Data []Item `json:"data,omitempty"`
Usage OpenAIUsage `json:"usage"`
}
type Choice struct {
Index int `json:"index"`
FinishReason string `json:"finish_reason,omitempty"`
Message *Message `json:"message,omitempty"`
Delta *Message `json:"delta,omitempty"`
Text string `json:"text,omitempty"`
}
type Content struct {
Type string `json:"type" yaml:"type"`
Text string `json:"text" yaml:"text"`
ImageURL ContentURL `json:"image_url" yaml:"image_url"`
}
type ContentURL struct {
URL string `json:"url" yaml:"url"`
}
type Message struct {
// The message role
Role string `json:"role,omitempty" yaml:"role"`
// The message name (used for tools calls)
Name string `json:"name,omitempty" yaml:"name"`
// The message content
Content interface{} `json:"content" yaml:"content"`
StringContent string `json:"string_content,omitempty" yaml:"string_content,omitempty"`
StringImages []string `json:"string_images,omitempty" yaml:"string_images,omitempty"`
// A result of a function call
FunctionCall interface{} `json:"function_call,omitempty" yaml:"function_call,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty" yaml:"tool_call,omitempty"`
}
type ToolCall struct {
Index int `json:"index"`
ID string `json:"id"`
Type string `json:"type"`
FunctionCall FunctionCall `json:"function"`
}
type FunctionCall struct {
Name string `json:"name,omitempty"`
Arguments string `json:"arguments"`
}
type OpenAIModel struct {
ID string `json:"id"`
Object string `json:"object"`
}
type ChatCompletionResponseFormatType string
type ChatCompletionResponseFormat struct {
Type ChatCompletionResponseFormatType `json:"type,omitempty"`
}
type OpenAIRequest struct {
config.PredictionOptions
Context context.Context
Cancel context.CancelFunc
// whisper
File string `json:"file" validate:"required"`
//whisper/image
ResponseFormat ChatCompletionResponseFormat `json:"response_format"`
// image
Size string `json:"size"`
// Prompt is read only by completion/image API calls
Prompt interface{} `json:"prompt" yaml:"prompt"`
// Edit endpoint
Instruction string `json:"instruction" yaml:"instruction"`
Input interface{} `json:"input" yaml:"input"`
Stop interface{} `json:"stop" yaml:"stop"`
// Messages is read only by chat/completion API calls
Messages []Message `json:"messages" yaml:"messages"`
// A list of available functions to call
Functions []grammar.Function `json:"functions" yaml:"functions"`
FunctionCall interface{} `json:"function_call" yaml:"function_call"` // might be a string or an object
Tools []grammar.Tool `json:"tools,omitempty" yaml:"tools"`
ToolsChoice interface{} `json:"tool_choice,omitempty" yaml:"tool_choice"`
Stream bool `json:"stream"`
// Image (not supported by OpenAI)
Mode int `json:"mode"`
Step int `json:"step"`
// A grammar to constrain the LLM output
Grammar string `json:"grammar" yaml:"grammar"`
JSONFunctionGrammarObject *grammar.JSONFunctionStructure `json:"grammar_json_functions" yaml:"grammar_json_functions"`
Backend string `json:"backend" yaml:"backend"`
// AutoGPTQ
ModelBaseName string `json:"model_base_name" yaml:"model_base_name"`
}

16
core/schema/whisper.go Normal file
View file

@ -0,0 +1,16 @@
package schema
import "time"
type Segment struct {
Id int `json:"id"`
Start time.Duration `json:"start"`
End time.Duration `json:"end"`
Text string `json:"text"`
Tokens []int `json:"tokens"`
}
type Result struct {
Segments []Segment `json:"segments"`
Text string `json:"text"`
}