Revert "[Refactor]: Core/API Split" (#1550)

Revert "[Refactor]: Core/API Split (#1506)"

This reverts commit ab7b4d5ee9.
This commit is contained in:
Ettore Di Giacinto 2024-01-05 12:04:46 -05:00 committed by GitHub
parent ab7b4d5ee9
commit db926896bd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
77 changed files with 3132 additions and 3456 deletions

View file

@ -22,11 +22,11 @@ func InstallModelFromGallery(galleries []Gallery, name string, basePath string,
applyModel := func(model *GalleryModel) error {
name = strings.ReplaceAll(name, string(os.PathSeparator), "__")
var config InstallableModel
var config Config
if len(model.URL) > 0 {
var err error
config, err = GetInstallableModelFromURL(model.URL)
config, err = GetGalleryConfigFromURL(model.URL)
if err != nil {
return err
}
@ -36,7 +36,7 @@ func InstallModelFromGallery(galleries []Gallery, name string, basePath string,
if err != nil {
return err
}
config = InstallableModel{
config = Config{
ConfigFile: string(reYamlConfig),
Description: model.Description,
License: model.License,

View file

@ -1,9 +1,13 @@
package gallery
import (
"crypto/sha256"
"fmt"
"hash"
"io"
"os"
"path/filepath"
"strconv"
"github.com/go-skynet/LocalAI/pkg/utils"
"github.com/imdario/mergo"
@ -37,9 +41,9 @@ prompt_templates:
content: ""
*/
// InstallableModel is the model configuration which contains all the model details
// Config is the model configuration which contains all the model details
// This configuration is read from the gallery endpoint and is used to download and install the model
type InstallableModel struct {
type Config struct {
Description string `yaml:"description"`
License string `yaml:"license"`
URLs []string `yaml:"urls"`
@ -60,8 +64,8 @@ type PromptTemplate struct {
Content string `yaml:"content"`
}
func GetInstallableModelFromURL(url string) (InstallableModel, error) {
var config InstallableModel
func GetGalleryConfigFromURL(url string) (Config, error) {
var config Config
err := utils.GetURI(url, func(url string, d []byte) error {
return yaml.Unmarshal(d, &config)
})
@ -72,7 +76,7 @@ func GetInstallableModelFromURL(url string) (InstallableModel, error) {
return config, nil
}
func ReadInstallableModelFile(filePath string) (*InstallableModel, error) {
func ReadConfigFile(filePath string) (*Config, error) {
// Read the YAML file
yamlFile, err := os.ReadFile(filePath)
if err != nil {
@ -80,7 +84,7 @@ func ReadInstallableModelFile(filePath string) (*InstallableModel, error) {
}
// Unmarshal YAML data into a Config struct
var config InstallableModel
var config Config
err = yaml.Unmarshal(yamlFile, &config)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal YAML: %v", err)
@ -89,7 +93,7 @@ func ReadInstallableModelFile(filePath string) (*InstallableModel, error) {
return &config, nil
}
func InstallModel(basePath, nameOverride string, config *InstallableModel, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64)) error {
func InstallModel(basePath, nameOverride string, config *Config, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64)) error {
// Create base path if it doesn't exist
err := os.MkdirAll(basePath, 0755)
if err != nil {
@ -179,3 +183,54 @@ func InstallModel(basePath, nameOverride string, config *InstallableModel, confi
return nil
}
type progressWriter struct {
fileName string
total int64
written int64
downloadStatus func(string, string, string, float64)
hash hash.Hash
}
func (pw *progressWriter) Write(p []byte) (n int, err error) {
n, err = pw.hash.Write(p)
pw.written += int64(n)
if pw.total > 0 {
percentage := float64(pw.written) / float64(pw.total) * 100
//log.Debug().Msgf("Downloading %s: %s/%s (%.2f%%)", pw.fileName, formatBytes(pw.written), formatBytes(pw.total), percentage)
pw.downloadStatus(pw.fileName, formatBytes(pw.written), formatBytes(pw.total), percentage)
} else {
pw.downloadStatus(pw.fileName, formatBytes(pw.written), "", 0)
}
return
}
func formatBytes(bytes int64) string {
const unit = 1024
if bytes < unit {
return strconv.FormatInt(bytes, 10) + " B"
}
div, exp := int64(unit), 0
for n := bytes / unit; n >= unit; n /= unit {
div *= unit
exp++
}
return fmt.Sprintf("%.1f %ciB", float64(bytes)/float64(div), "KMGTPE"[exp])
}
func calculateSHA(filePath string) (string, error) {
file, err := os.Open(filePath)
if err != nil {
return "", err
}
defer file.Close()
hash := sha256.New()
if _, err := io.Copy(hash, file); err != nil {
return "", err
}
return fmt.Sprintf("%x", hash.Sum(nil)), nil
}

View file

@ -16,7 +16,7 @@ var _ = Describe("Model test", func() {
tempdir, err := os.MkdirTemp("", "test")
Expect(err).ToNot(HaveOccurred())
defer os.RemoveAll(tempdir)
c, err := ReadInstallableModelFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
Expect(err).ToNot(HaveOccurred())
err = InstallModel(tempdir, "", c, map[string]interface{}{}, func(string, string, string, float64) {})
@ -87,7 +87,7 @@ var _ = Describe("Model test", func() {
tempdir, err := os.MkdirTemp("", "test")
Expect(err).ToNot(HaveOccurred())
defer os.RemoveAll(tempdir)
c, err := ReadInstallableModelFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
Expect(err).ToNot(HaveOccurred())
err = InstallModel(tempdir, "foo", c, map[string]interface{}{}, func(string, string, string, float64) {})
@ -103,7 +103,7 @@ var _ = Describe("Model test", func() {
tempdir, err := os.MkdirTemp("", "test")
Expect(err).ToNot(HaveOccurred())
defer os.RemoveAll(tempdir)
c, err := ReadInstallableModelFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
Expect(err).ToNot(HaveOccurred())
err = InstallModel(tempdir, "foo", c, map[string]interface{}{"backend": "foo"}, func(string, string, string, float64) {})
@ -129,7 +129,7 @@ var _ = Describe("Model test", func() {
tempdir, err := os.MkdirTemp("", "test")
Expect(err).ToNot(HaveOccurred())
defer os.RemoveAll(tempdir)
c, err := ReadInstallableModelFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
Expect(err).ToNot(HaveOccurred())
err = InstallModel(tempdir, "../../../foo", c, map[string]interface{}{}, func(string, string, string, float64) {})

View file

@ -1,18 +0,0 @@
package gallery
type GalleryOp struct {
Req GalleryModel
Id string
Galleries []Gallery
GalleryName string
}
type GalleryOpStatus struct {
FileName string `json:"file_name"`
Error error `json:"error"`
Processed bool `json:"processed"`
Message string `json:"message"`
Progress float64 `json:"progress"`
TotalFileSize string `json:"file_size"`
DownloadedFileSize string `json:"downloaded_size"`
}

View file

@ -10,7 +10,7 @@ var _ = Describe("Gallery API tests", func() {
Context("requests", func() {
It("parses github with a branch", func() {
req := GalleryModel{URL: "github:go-skynet/model-gallery/gpt4all-j.yaml@main"}
e, err := GetInstallableModelFromURL(req.URL)
e, err := GetGalleryConfigFromURL(req.URL)
Expect(err).ToNot(HaveOccurred())
Expect(e.Name).To(Equal("gpt4all-j"))
})

View file

@ -6,8 +6,8 @@ import (
"fmt"
"os"
"github.com/go-skynet/LocalAI/api/schema"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
"github.com/go-skynet/LocalAI/pkg/schema"
gopsutil "github.com/shirou/gopsutil/v3/process"
)
@ -53,9 +53,8 @@ func (llm *Base) GenerateImage(*pb.GenerateImageRequest) error {
return fmt.Errorf("unimplemented")
}
// TODO CHECK THIS
func (llm *Base) AudioTranscription(*pb.TranscriptRequest) (schema.WhisperResult, error) {
return schema.WhisperResult{}, fmt.Errorf("unimplemented")
func (llm *Base) AudioTranscription(*pb.TranscriptRequest) (schema.Result, error) {
return schema.Result{}, fmt.Errorf("unimplemented")
}
func (llm *Base) TTS(*pb.TTSRequest) error {

View file

@ -7,8 +7,8 @@ import (
"sync"
"time"
"github.com/go-skynet/LocalAI/api/schema"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
"github.com/go-skynet/LocalAI/pkg/schema"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
)
@ -223,7 +223,7 @@ func (c *Client) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOp
return client.TTS(ctx, in, opts...)
}
func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.WhisperResult, error) {
func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.Result, error) {
if !c.parallel {
c.opMutex.Lock()
defer c.opMutex.Unlock()
@ -244,14 +244,14 @@ func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptReques
if err != nil {
return nil, err
}
tresult := &schema.WhisperResult{}
tresult := &schema.Result{}
for _, s := range res.Segments {
tks := []int{}
for _, t := range s.Tokens {
tks = append(tks, int(t))
}
tresult.Segments = append(tresult.Segments,
schema.WhisperSegment{
schema.Segment{
Text: s.Text,
Id: int(s.Id),
Start: time.Duration(s.Start),

View file

@ -1,8 +1,8 @@
package grpc
import (
"github.com/go-skynet/LocalAI/api/schema"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
"github.com/go-skynet/LocalAI/pkg/schema"
)
type LLM interface {
@ -15,7 +15,7 @@ type LLM interface {
Load(*pb.ModelOptions) error
Embeddings(*pb.PredictOptions) ([]float32, error)
GenerateImage(*pb.GenerateImageRequest) error
AudioTranscription(*pb.TranscriptRequest) (schema.WhisperResult, error)
AudioTranscription(*pb.TranscriptRequest) (schema.Result, error)
TTS(*pb.TTSRequest) error
TokenizeString(*pb.PredictOptions) (pb.TokenizationResponse, error)
Status() (pb.StatusResponse, error)

View file

@ -1,7 +1,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.26.0
// protoc v4.26.0
// protoc-gen-go v1.28.1
// protoc v3.6.1
// source: backend.proto
package proto

View file

@ -1,7 +1,7 @@
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions:
// - protoc-gen-go-grpc v1.3.0
// - protoc v4.26.0
// - protoc-gen-go-grpc v1.2.0
// - protoc v3.6.1
// source: backend.proto
package proto
@ -18,19 +18,6 @@ import (
// Requires gRPC-Go v1.32.0 or later.
const _ = grpc.SupportPackageIsVersion7
const (
Backend_Health_FullMethodName = "/backend.Backend/Health"
Backend_Predict_FullMethodName = "/backend.Backend/Predict"
Backend_LoadModel_FullMethodName = "/backend.Backend/LoadModel"
Backend_PredictStream_FullMethodName = "/backend.Backend/PredictStream"
Backend_Embedding_FullMethodName = "/backend.Backend/Embedding"
Backend_GenerateImage_FullMethodName = "/backend.Backend/GenerateImage"
Backend_AudioTranscription_FullMethodName = "/backend.Backend/AudioTranscription"
Backend_TTS_FullMethodName = "/backend.Backend/TTS"
Backend_TokenizeString_FullMethodName = "/backend.Backend/TokenizeString"
Backend_Status_FullMethodName = "/backend.Backend/Status"
)
// BackendClient is the client API for Backend service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
@ -57,7 +44,7 @@ func NewBackendClient(cc grpc.ClientConnInterface) BackendClient {
func (c *backendClient) Health(ctx context.Context, in *HealthMessage, opts ...grpc.CallOption) (*Reply, error) {
out := new(Reply)
err := c.cc.Invoke(ctx, Backend_Health_FullMethodName, in, out, opts...)
err := c.cc.Invoke(ctx, "/backend.Backend/Health", in, out, opts...)
if err != nil {
return nil, err
}
@ -66,7 +53,7 @@ func (c *backendClient) Health(ctx context.Context, in *HealthMessage, opts ...g
func (c *backendClient) Predict(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*Reply, error) {
out := new(Reply)
err := c.cc.Invoke(ctx, Backend_Predict_FullMethodName, in, out, opts...)
err := c.cc.Invoke(ctx, "/backend.Backend/Predict", in, out, opts...)
if err != nil {
return nil, err
}
@ -75,7 +62,7 @@ func (c *backendClient) Predict(ctx context.Context, in *PredictOptions, opts ..
func (c *backendClient) LoadModel(ctx context.Context, in *ModelOptions, opts ...grpc.CallOption) (*Result, error) {
out := new(Result)
err := c.cc.Invoke(ctx, Backend_LoadModel_FullMethodName, in, out, opts...)
err := c.cc.Invoke(ctx, "/backend.Backend/LoadModel", in, out, opts...)
if err != nil {
return nil, err
}
@ -83,7 +70,7 @@ func (c *backendClient) LoadModel(ctx context.Context, in *ModelOptions, opts ..
}
func (c *backendClient) PredictStream(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (Backend_PredictStreamClient, error) {
stream, err := c.cc.NewStream(ctx, &Backend_ServiceDesc.Streams[0], Backend_PredictStream_FullMethodName, opts...)
stream, err := c.cc.NewStream(ctx, &Backend_ServiceDesc.Streams[0], "/backend.Backend/PredictStream", opts...)
if err != nil {
return nil, err
}
@ -116,7 +103,7 @@ func (x *backendPredictStreamClient) Recv() (*Reply, error) {
func (c *backendClient) Embedding(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*EmbeddingResult, error) {
out := new(EmbeddingResult)
err := c.cc.Invoke(ctx, Backend_Embedding_FullMethodName, in, out, opts...)
err := c.cc.Invoke(ctx, "/backend.Backend/Embedding", in, out, opts...)
if err != nil {
return nil, err
}
@ -125,7 +112,7 @@ func (c *backendClient) Embedding(ctx context.Context, in *PredictOptions, opts
func (c *backendClient) GenerateImage(ctx context.Context, in *GenerateImageRequest, opts ...grpc.CallOption) (*Result, error) {
out := new(Result)
err := c.cc.Invoke(ctx, Backend_GenerateImage_FullMethodName, in, out, opts...)
err := c.cc.Invoke(ctx, "/backend.Backend/GenerateImage", in, out, opts...)
if err != nil {
return nil, err
}
@ -134,7 +121,7 @@ func (c *backendClient) GenerateImage(ctx context.Context, in *GenerateImageRequ
func (c *backendClient) AudioTranscription(ctx context.Context, in *TranscriptRequest, opts ...grpc.CallOption) (*TranscriptResult, error) {
out := new(TranscriptResult)
err := c.cc.Invoke(ctx, Backend_AudioTranscription_FullMethodName, in, out, opts...)
err := c.cc.Invoke(ctx, "/backend.Backend/AudioTranscription", in, out, opts...)
if err != nil {
return nil, err
}
@ -143,7 +130,7 @@ func (c *backendClient) AudioTranscription(ctx context.Context, in *TranscriptRe
func (c *backendClient) TTS(ctx context.Context, in *TTSRequest, opts ...grpc.CallOption) (*Result, error) {
out := new(Result)
err := c.cc.Invoke(ctx, Backend_TTS_FullMethodName, in, out, opts...)
err := c.cc.Invoke(ctx, "/backend.Backend/TTS", in, out, opts...)
if err != nil {
return nil, err
}
@ -152,7 +139,7 @@ func (c *backendClient) TTS(ctx context.Context, in *TTSRequest, opts ...grpc.Ca
func (c *backendClient) TokenizeString(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*TokenizationResponse, error) {
out := new(TokenizationResponse)
err := c.cc.Invoke(ctx, Backend_TokenizeString_FullMethodName, in, out, opts...)
err := c.cc.Invoke(ctx, "/backend.Backend/TokenizeString", in, out, opts...)
if err != nil {
return nil, err
}
@ -161,7 +148,7 @@ func (c *backendClient) TokenizeString(ctx context.Context, in *PredictOptions,
func (c *backendClient) Status(ctx context.Context, in *HealthMessage, opts ...grpc.CallOption) (*StatusResponse, error) {
out := new(StatusResponse)
err := c.cc.Invoke(ctx, Backend_Status_FullMethodName, in, out, opts...)
err := c.cc.Invoke(ctx, "/backend.Backend/Status", in, out, opts...)
if err != nil {
return nil, err
}
@ -242,7 +229,7 @@ func _Backend_Health_Handler(srv interface{}, ctx context.Context, dec func(inte
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: Backend_Health_FullMethodName,
FullMethod: "/backend.Backend/Health",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BackendServer).Health(ctx, req.(*HealthMessage))
@ -260,7 +247,7 @@ func _Backend_Predict_Handler(srv interface{}, ctx context.Context, dec func(int
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: Backend_Predict_FullMethodName,
FullMethod: "/backend.Backend/Predict",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BackendServer).Predict(ctx, req.(*PredictOptions))
@ -278,7 +265,7 @@ func _Backend_LoadModel_Handler(srv interface{}, ctx context.Context, dec func(i
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: Backend_LoadModel_FullMethodName,
FullMethod: "/backend.Backend/LoadModel",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BackendServer).LoadModel(ctx, req.(*ModelOptions))
@ -317,7 +304,7 @@ func _Backend_Embedding_Handler(srv interface{}, ctx context.Context, dec func(i
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: Backend_Embedding_FullMethodName,
FullMethod: "/backend.Backend/Embedding",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BackendServer).Embedding(ctx, req.(*PredictOptions))
@ -335,7 +322,7 @@ func _Backend_GenerateImage_Handler(srv interface{}, ctx context.Context, dec fu
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: Backend_GenerateImage_FullMethodName,
FullMethod: "/backend.Backend/GenerateImage",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BackendServer).GenerateImage(ctx, req.(*GenerateImageRequest))
@ -353,7 +340,7 @@ func _Backend_AudioTranscription_Handler(srv interface{}, ctx context.Context, d
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: Backend_AudioTranscription_FullMethodName,
FullMethod: "/backend.Backend/AudioTranscription",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BackendServer).AudioTranscription(ctx, req.(*TranscriptRequest))
@ -371,7 +358,7 @@ func _Backend_TTS_Handler(srv interface{}, ctx context.Context, dec func(interfa
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: Backend_TTS_FullMethodName,
FullMethod: "/backend.Backend/TTS",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BackendServer).TTS(ctx, req.(*TTSRequest))
@ -389,7 +376,7 @@ func _Backend_TokenizeString_Handler(srv interface{}, ctx context.Context, dec f
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: Backend_TokenizeString_FullMethodName,
FullMethod: "/backend.Backend/TokenizeString",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BackendServer).TokenizeString(ctx, req.(*PredictOptions))
@ -407,7 +394,7 @@ func _Backend_Status_Handler(srv interface{}, ctx context.Context, dec func(inte
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: Backend_Status_FullMethodName,
FullMethod: "/backend.Backend/Status",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BackendServer).Status(ctx, req.(*HealthMessage))

View file

@ -8,7 +8,7 @@ import (
"strings"
"time"
"github.com/go-skynet/LocalAI/pkg/grpc"
grpc "github.com/go-skynet/LocalAI/pkg/grpc"
"github.com/hashicorp/go-multierror"
"github.com/phayes/freeport"
"github.com/rs/zerolog/log"
@ -71,7 +71,7 @@ var AutoLoadBackends []string = []string{
// starts the grpcModelProcess for the backend, and returns a grpc client
// It also loads the model
func (ml *ModelLoader) grpcModel(backend string, o *ModelOptions) func(string, string) (ModelAddress, error) {
func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string) (ModelAddress, error) {
return func(modelName, modelFile string) (ModelAddress, error) {
log.Debug().Msgf("Loading Model %s with gRPC (file: %s) (backend: %s): %+v", modelName, modelFile, backend, *o)

View file

@ -10,7 +10,7 @@ import (
"sync"
"text/template"
"github.com/go-skynet/LocalAI/pkg/grammar"
grammar "github.com/go-skynet/LocalAI/pkg/grammar"
"github.com/go-skynet/LocalAI/pkg/grpc"
process "github.com/mudler/go-processmanager"
"github.com/rs/zerolog/log"

View file

@ -6,7 +6,7 @@ import (
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
)
type ModelOptions struct {
type Options struct {
backendString string
model string
threads uint32
@ -23,14 +23,14 @@ type ModelOptions struct {
parallelRequests bool
}
type Option func(*ModelOptions)
type Option func(*Options)
var EnableParallelRequests = func(o *ModelOptions) {
var EnableParallelRequests = func(o *Options) {
o.parallelRequests = true
}
func WithExternalBackend(name string, uri string) Option {
return func(o *ModelOptions) {
return func(o *Options) {
if o.externalBackends == nil {
o.externalBackends = make(map[string]string)
}
@ -38,81 +38,62 @@ func WithExternalBackend(name string, uri string) Option {
}
}
// Currently, LocalAI isn't ready for backends to be yanked out from under it - so this is a little overcomplicated to allow non-overwriting updates
func WithExternalBackends(backends map[string]string, overwrite bool) Option {
return func(o *ModelOptions) {
if backends == nil {
return
}
if o.externalBackends == nil {
o.externalBackends = backends
return
}
for name, url := range backends {
_, exists := o.externalBackends[name]
if !exists || overwrite {
o.externalBackends[name] = url
}
}
}
}
func WithGRPCAttempts(attempts int) Option {
return func(o *ModelOptions) {
return func(o *Options) {
o.grpcAttempts = attempts
}
}
func WithGRPCAttemptsDelay(delay int) Option {
return func(o *ModelOptions) {
return func(o *Options) {
o.grpcAttemptsDelay = delay
}
}
func WithBackendString(backend string) Option {
return func(o *ModelOptions) {
return func(o *Options) {
o.backendString = backend
}
}
func WithModel(modelFile string) Option {
return func(o *ModelOptions) {
return func(o *Options) {
o.model = modelFile
}
}
func WithLoadGRPCLoadModelOpts(opts *pb.ModelOptions) Option {
return func(o *ModelOptions) {
return func(o *Options) {
o.gRPCOptions = opts
}
}
func WithThreads(threads uint32) Option {
return func(o *ModelOptions) {
return func(o *Options) {
o.threads = threads
}
}
func WithAssetDir(assetDir string) Option {
return func(o *ModelOptions) {
return func(o *Options) {
o.assetDir = assetDir
}
}
func WithContext(ctx context.Context) Option {
return func(o *ModelOptions) {
return func(o *Options) {
o.context = ctx
}
}
func WithSingleActiveBackend() Option {
return func(o *ModelOptions) {
return func(o *Options) {
o.singleActiveBackend = true
}
}
func NewOptions(opts ...Option) *ModelOptions {
o := &ModelOptions{
func NewOptions(opts ...Option) *Options {
o := &Options{
gRPCOptions: &pb.ModelOptions{},
context: context.Background(),
grpcAttempts: 20,

View file

@ -1,400 +0,0 @@
package schema
import (
"encoding/json"
"fmt"
"os"
"github.com/go-skynet/LocalAI/pkg/utils"
"gopkg.in/yaml.v3"
)
type Config struct {
PredictionOptions `yaml:"parameters"`
Name string `yaml:"name"`
F16 bool `yaml:"f16"`
Threads int `yaml:"threads"`
Debug bool `yaml:"debug"`
Roles map[string]string `yaml:"roles"`
Embeddings bool `yaml:"embeddings"`
Backend string `yaml:"backend"`
TemplateConfig TemplateConfig `yaml:"template"`
PromptStrings, InputStrings []string `yaml:"-"`
InputToken [][]int `yaml:"-"`
functionCallString, functionCallNameString string `yaml:"-"`
FunctionsConfig Functions `yaml:"function"`
FeatureFlag FeatureFlag `yaml:"feature_flags"` // Feature Flag registry. We move fast, and features may break on a per model/backend basis. Registry for (usually temporary) flags that indicate aborting something early.
// LLM configs (GPT4ALL, Llama.cpp, ...)
LLMConfig `yaml:",inline"`
// AutoGPTQ specifics
AutoGPTQ AutoGPTQ `yaml:"autogptq"`
// Diffusers
Diffusers Diffusers `yaml:"diffusers"`
Step int `yaml:"step"`
// GRPC Options
GRPC GRPC `yaml:"grpc"`
// Vall-e-x
VallE VallE `yaml:"vall-e"`
// CUDA
// Explicitly enable CUDA or not (some backends might need it)
CUDA bool `yaml:"cuda"`
DownloadFiles []File `yaml:"download_files"`
}
type File struct {
Filename string `yaml:"filename" json:"filename"`
SHA256 string `yaml:"sha256" json:"sha256"`
URI string `yaml:"uri" json:"uri"`
}
type VallE struct {
AudioPath string `yaml:"audio_path"`
}
type FeatureFlag map[string]*bool
func (ff FeatureFlag) Enabled(s string) bool {
v, exist := ff[s]
return exist && v != nil && *v
}
type GRPC struct {
Attempts int `yaml:"attempts"`
AttemptsSleepTime int `yaml:"attempts_sleep_time"`
}
type Diffusers struct {
CUDA bool `yaml:"cuda"`
PipelineType string `yaml:"pipeline_type"`
SchedulerType string `yaml:"scheduler_type"`
EnableParameters string `yaml:"enable_parameters"` // A list of comma separated parameters to specify
CFGScale float32 `yaml:"cfg_scale"` // Classifier-Free Guidance Scale
IMG2IMG bool `yaml:"img2img"` // Image to Image Diffuser
ClipSkip int `yaml:"clip_skip"` // Skip every N frames
ClipModel string `yaml:"clip_model"` // Clip model to use
ClipSubFolder string `yaml:"clip_subfolder"` // Subfolder to use for clip model
ControlNet string `yaml:"control_net"`
}
type LLMConfig struct {
SystemPrompt string `yaml:"system_prompt"`
TensorSplit string `yaml:"tensor_split"`
MainGPU string `yaml:"main_gpu"`
RMSNormEps float32 `yaml:"rms_norm_eps"`
NGQA int32 `yaml:"ngqa"`
PromptCachePath string `yaml:"prompt_cache_path"`
PromptCacheAll bool `yaml:"prompt_cache_all"`
PromptCacheRO bool `yaml:"prompt_cache_ro"`
MirostatETA float64 `yaml:"mirostat_eta"`
MirostatTAU float64 `yaml:"mirostat_tau"`
Mirostat int `yaml:"mirostat"`
NGPULayers int `yaml:"gpu_layers"`
MMap bool `yaml:"mmap"`
MMlock bool `yaml:"mmlock"`
LowVRAM bool `yaml:"low_vram"`
Grammar string `yaml:"grammar"`
StopWords []string `yaml:"stopwords"`
Cutstrings []string `yaml:"cutstrings"`
TrimSpace []string `yaml:"trimspace"`
TrimSuffix []string `yaml:"trimsuffix"`
ContextSize int `yaml:"context_size"`
NUMA bool `yaml:"numa"`
LoraAdapter string `yaml:"lora_adapter"`
LoraBase string `yaml:"lora_base"`
LoraScale float32 `yaml:"lora_scale"`
NoMulMatQ bool `yaml:"no_mulmatq"`
DraftModel string `yaml:"draft_model"`
NDraft int32 `yaml:"n_draft"`
Quantization string `yaml:"quantization"`
MMProj string `yaml:"mmproj"`
RopeScaling string `yaml:"rope_scaling"`
YarnExtFactor float32 `yaml:"yarn_ext_factor"`
YarnAttnFactor float32 `yaml:"yarn_attn_factor"`
YarnBetaFast float32 `yaml:"yarn_beta_fast"`
YarnBetaSlow float32 `yaml:"yarn_beta_slow"`
}
type AutoGPTQ struct {
ModelBaseName string `yaml:"model_base_name"`
Device string `yaml:"device"`
Triton bool `yaml:"triton"`
UseFastTokenizer bool `yaml:"use_fast_tokenizer"`
}
type Functions struct {
DisableNoAction bool `yaml:"disable_no_action"`
NoActionFunctionName string `yaml:"no_action_function_name"`
NoActionDescriptionName string `yaml:"no_action_description_name"`
}
type TemplateConfig struct {
Chat string `yaml:"chat"`
ChatMessage string `yaml:"chat_message"`
Completion string `yaml:"completion"`
Edit string `yaml:"edit"`
Functions string `yaml:"function"`
}
func (c *Config) SetFunctionCallString(s string) {
c.functionCallString = s
}
func (c *Config) SetFunctionCallNameString(s string) {
c.functionCallNameString = s
}
func (c *Config) ShouldUseFunctions() bool {
return ((c.functionCallString != "none" || c.functionCallString == "") || c.ShouldCallSpecificFunction())
}
func (c *Config) ShouldCallSpecificFunction() bool {
return len(c.functionCallNameString) > 0
}
func (c *Config) FunctionToCall() string {
return c.functionCallNameString
}
func defaultPredictOptions(modelFile string) PredictionOptions {
return PredictionOptions{
TopP: 0.7,
TopK: 80,
Maxtokens: 512,
Temperature: 0.9,
Model: modelFile,
}
}
func DefaultConfig(modelFile string) *Config {
return &Config{
PredictionOptions: defaultPredictOptions(modelFile),
}
}
func ReadConfigFile(file string) ([]*Config, error) {
c := &[]*Config{}
f, err := os.ReadFile(file)
if err != nil {
return nil, fmt.Errorf("cannot read config file: %w", err)
}
if err := yaml.Unmarshal(f, c); err != nil {
return nil, fmt.Errorf("cannot unmarshal config file: %w", err)
}
return *c, nil
}
func ReadSingleConfigFile(file string) (*Config, error) {
c := &Config{}
f, err := os.ReadFile(file)
if err != nil {
return nil, fmt.Errorf("cannot read config file: %w", err)
}
if err := yaml.Unmarshal(f, c); err != nil {
return nil, fmt.Errorf("cannot unmarshal config file: %w", err)
}
return c, nil
}
func UpdateConfigFromOpenAIRequest(config *Config, input *OpenAIRequest) {
if input.Echo {
config.Echo = input.Echo
}
if input.TopK != 0 {
config.TopK = input.TopK
}
if input.TopP != 0 {
config.TopP = input.TopP
}
if input.Backend != "" {
config.Backend = input.Backend
}
if input.ClipSkip != 0 {
config.Diffusers.ClipSkip = input.ClipSkip
}
if input.ModelBaseName != "" {
config.AutoGPTQ.ModelBaseName = input.ModelBaseName
}
if input.NegativePromptScale != 0 {
config.NegativePromptScale = input.NegativePromptScale
}
if input.UseFastTokenizer {
config.UseFastTokenizer = input.UseFastTokenizer
}
if input.NegativePrompt != "" {
config.NegativePrompt = input.NegativePrompt
}
if input.RopeFreqBase != 0 {
config.RopeFreqBase = input.RopeFreqBase
}
if input.RopeFreqScale != 0 {
config.RopeFreqScale = input.RopeFreqScale
}
if input.Grammar != "" {
config.Grammar = input.Grammar
}
if input.Temperature != 0 {
config.Temperature = input.Temperature
}
if input.Maxtokens != 0 {
config.Maxtokens = input.Maxtokens
}
if input.RepeatPenalty != 0 {
config.RepeatPenalty = input.RepeatPenalty
}
if input.Keep != 0 {
config.Keep = input.Keep
}
if input.Batch != 0 {
config.Batch = input.Batch
}
if input.F16 {
config.F16 = input.F16
}
if input.IgnoreEOS {
config.IgnoreEOS = input.IgnoreEOS
}
if input.Seed != 0 {
config.Seed = input.Seed
}
if input.Mirostat != 0 {
config.LLMConfig.Mirostat = input.Mirostat
}
if input.MirostatETA != 0 {
config.LLMConfig.MirostatETA = input.MirostatETA
}
if input.MirostatTAU != 0 {
config.LLMConfig.MirostatTAU = input.MirostatTAU
}
if input.TypicalP != 0 {
config.TypicalP = input.TypicalP
}
switch stop := input.Stop.(type) {
case string:
if stop != "" {
config.StopWords = append(config.StopWords, stop)
}
case []interface{}:
for _, pp := range stop {
if s, ok := pp.(string); ok {
config.StopWords = append(config.StopWords, s)
}
}
}
// Decode each request's message content
index := 0
for i, m := range input.Messages {
switch content := m.Content.(type) {
case string:
input.Messages[i].StringContent = content
case []interface{}:
dat, _ := json.Marshal(content)
c := []Content{}
json.Unmarshal(dat, &c)
for _, pp := range c {
if pp.Type == "text" {
input.Messages[i].StringContent = pp.Text
} else if pp.Type == "image_url" {
// Detect if pp.ImageURL is an URL, if it is download the image and encode it in base64:
base64, err := utils.GetBase64Image(pp.ImageURL.URL)
if err == nil {
input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff
// set a placeholder for each image
input.Messages[i].StringContent = fmt.Sprintf("[img-%d]", index) + input.Messages[i].StringContent
index++
} else {
fmt.Print("Failed encoding image", err)
}
}
}
}
}
// TODO: check that this was merged correctly? I _think_ it is?
switch inputs := input.Input.(type) {
case string:
if inputs != "" {
config.InputStrings = append(config.InputStrings, inputs)
}
case []interface{}:
for _, pp := range inputs {
switch i := pp.(type) {
case string:
config.InputStrings = append(config.InputStrings, i)
case []interface{}:
tokens := []int{}
for _, ii := range i {
tokens = append(tokens, int(ii.(float64)))
}
config.InputToken = append(config.InputToken, tokens)
}
}
}
// Can be either a string or an object
switch fnc := input.FunctionCall.(type) {
case string:
if fnc != "" {
config.SetFunctionCallString(fnc)
}
case map[string]interface{}:
var name string
n, exists := fnc["name"]
if exists {
nn, e := n.(string)
if e {
name = nn
}
}
config.SetFunctionCallNameString(name)
}
switch p := input.Prompt.(type) {
case string:
config.PromptStrings = append(config.PromptStrings, p)
case []interface{}:
for _, pp := range p {
if s, ok := pp.(string); ok {
config.PromptStrings = append(config.PromptStrings, s)
}
}
}
}

View file

@ -1,51 +0,0 @@
package schema_test
import (
"os"
"github.com/go-skynet/LocalAI/core/services"
"github.com/go-skynet/LocalAI/pkg/schema"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("Test cases for config related functions", func() {
var (
configFile string
)
Context("Test Read configuration functions", func() {
configFile = os.Getenv("CONFIG_FILE")
It("Test ReadConfigFile", func() {
config, err := schema.ReadConfigFile(configFile)
Expect(err).To(BeNil())
Expect(config).ToNot(BeNil())
// two configs in config.yaml
Expect(config[0].Name).To(Equal("list1"))
Expect(config[1].Name).To(Equal("list2"))
})
It("Test LoadConfigs", func() {
cm := services.NewConfigLoader()
err := cm.LoadConfigs(os.Getenv("MODELS_PATH"))
Expect(err).To(BeNil())
Expect(cm.ListConfigs()).ToNot(BeNil())
// config should includes gpt4all models's api.config
Expect(cm.ListConfigs()).To(ContainElements("gpt4all"))
// config should includes gpt2 models's api.config
Expect(cm.ListConfigs()).To(ContainElements("gpt4all-2"))
// config should includes text-embedding-ada-002 models's api.config
Expect(cm.ListConfigs()).To(ContainElements("text-embedding-ada-002"))
// config should includes rwkv_test models's api.config
Expect(cm.ListConfigs()).To(ContainElements("rwkv_test"))
// config should includes whisper-1 models's api.config
Expect(cm.ListConfigs()).To(ContainElements("whisper-1"))
})
})
})

View file

@ -1,39 +0,0 @@
package schema
import (
"context"
gopsutil "github.com/shirou/gopsutil/v3/process"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/metric"
)
type BackendMonitorRequest struct {
Model string `json:"model" yaml:"model"`
}
type BackendMonitorResponse struct {
MemoryInfo *gopsutil.MemoryInfoStat
MemoryPercent float32
CPUPercent float64
}
type TTSRequest struct {
Model string `json:"model" yaml:"model"`
Input string `json:"input" yaml:"input"`
Backend string `json:"backend" yaml:"backend"`
}
type LocalAIMetrics struct {
Meter metric.Meter
ApiTimeMetric metric.Float64Histogram
}
func (m *LocalAIMetrics) ObserveAPICall(method string, path string, duration float64) {
opts := metric.WithAttributes(
attribute.String("method", method),
attribute.String("path", path),
)
m.ApiTimeMetric.Record(context.Background(), duration, opts)
}

View file

@ -1,133 +0,0 @@
package schema
import (
"context"
"github.com/go-skynet/LocalAI/pkg/grammar"
)
// APIError provides error information returned by the OpenAI API.
type APIError struct {
Code any `json:"code,omitempty"`
Message string `json:"message"`
Param *string `json:"param,omitempty"`
Type string `json:"type"`
}
type ErrorResponse struct {
Error *APIError `json:"error,omitempty"`
}
type OpenAIUsage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}
type Item struct {
Embedding []float32 `json:"embedding"`
Index int `json:"index"`
Object string `json:"object,omitempty"`
// Images
URL string `json:"url,omitempty"`
B64JSON string `json:"b64_json,omitempty"`
}
type OpenAIResponse struct {
Created int `json:"created,omitempty"`
Object string `json:"object,omitempty"`
ID string `json:"id,omitempty"`
Model string `json:"model,omitempty"`
Choices []Choice `json:"choices,omitempty"`
Data []Item `json:"data,omitempty"`
Usage OpenAIUsage `json:"usage"`
}
type Choice struct {
Index int `json:"index"`
FinishReason string `json:"finish_reason,omitempty"`
Message *Message `json:"message,omitempty"`
Delta *Message `json:"delta,omitempty"`
Text string `json:"text,omitempty"`
}
type Content struct {
Type string `json:"type" yaml:"type"`
Text string `json:"text" yaml:"text"`
ImageURL ContentURL `json:"image_url" yaml:"image_url"`
}
type ContentURL struct {
URL string `json:"url" yaml:"url"`
}
type Message struct {
// The message role
Role string `json:"role,omitempty" yaml:"role"`
// The message content
Content interface{} `json:"content" yaml:"content"`
StringContent string `json:"string_content,omitempty" yaml:"string_content,omitempty"`
StringImages []string `json:"string_images,omitempty" yaml:"string_images,omitempty"`
// A result of a function call
FunctionCall interface{} `json:"function_call,omitempty" yaml:"function_call,omitempty"`
}
type OpenAIModel struct {
ID string `json:"id"`
Object string `json:"object"`
}
type ChatCompletionResponseFormatType string
type ChatCompletionResponseFormat struct {
Type ChatCompletionResponseFormatType `json:"type,omitempty"`
}
type OpenAIRequest struct {
PredictionOptions
Context context.Context
Cancel context.CancelFunc
// whisper
File string `json:"file" validate:"required"`
//whisper/image
ResponseFormat ChatCompletionResponseFormat `json:"response_format"`
// image
Size string `json:"size"`
// Prompt is read only by completion/image API calls
Prompt interface{} `json:"prompt" yaml:"prompt"`
// Edit endpoint
Instruction string `json:"instruction" yaml:"instruction"`
Input interface{} `json:"input" yaml:"input"`
Stop interface{} `json:"stop" yaml:"stop"`
// Messages is read only by chat/completion API calls
Messages []Message `json:"messages" yaml:"messages"`
// A list of available functions to call
Functions []grammar.Function `json:"functions" yaml:"functions"`
FunctionCall interface{} `json:"function_call" yaml:"function_call"` // might be a string or an object
Stream bool `json:"stream"`
// Image (not supported by OpenAI)
Mode int `json:"mode"`
Step int `json:"step"`
// A grammar to constrain the LLM output
Grammar string `json:"grammar" yaml:"grammar"`
JSONFunctionGrammarObject *grammar.JSONFunctionStructure `json:"grammar_json_functions" yaml:"grammar_json_functions"`
Backend string `json:"backend" yaml:"backend"`
// AutoGPTQ
ModelBaseName string `json:"model_base_name" yaml:"model_base_name"`
}

View file

@ -1,50 +0,0 @@
package schema
type PredictionOptions struct {
// Also part of the OpenAI official spec
Model string `json:"model" yaml:"model"`
// Also part of the OpenAI official spec
Language string `json:"language"`
// Also part of the OpenAI official spec. use it for returning multiple results
N int `json:"n"`
// Common options between all the API calls, part of the OpenAI spec
TopP float64 `json:"top_p" yaml:"top_p"`
TopK int `json:"top_k" yaml:"top_k"`
Temperature float64 `json:"temperature" yaml:"temperature"`
Maxtokens int `json:"max_tokens" yaml:"max_tokens"`
Echo bool `json:"echo"`
// Custom parameters - not present in the OpenAI API
Batch int `json:"batch" yaml:"batch"`
F16 bool `json:"f16" yaml:"f16"`
IgnoreEOS bool `json:"ignore_eos" yaml:"ignore_eos"`
RepeatPenalty float64 `json:"repeat_penalty" yaml:"repeat_penalty"`
Keep int `json:"n_keep" yaml:"n_keep"`
MirostatETA float64 `json:"mirostat_eta" yaml:"mirostat_eta"`
MirostatTAU float64 `json:"mirostat_tau" yaml:"mirostat_tau"`
Mirostat int `json:"mirostat" yaml:"mirostat"`
FrequencyPenalty float64 `json:"frequency_penalty" yaml:"frequency_penalty"`
TFZ float64 `json:"tfz" yaml:"tfz"`
TypicalP float64 `json:"typical_p" yaml:"typical_p"`
Seed int `json:"seed" yaml:"seed"`
NegativePrompt string `json:"negative_prompt" yaml:"negative_prompt"`
RopeFreqBase float32 `json:"rope_freq_base" yaml:"rope_freq_base"`
RopeFreqScale float32 `json:"rope_freq_scale" yaml:"rope_freq_scale"`
NegativePromptScale float32 `json:"negative_prompt_scale" yaml:"negative_prompt_scale"`
// AutoGPTQ
UseFastTokenizer bool `json:"use_fast_tokenizer" yaml:"use_fast_tokenizer"`
// Diffusers
ClipSkip int `json:"clip_skip" yaml:"clip_skip"`
// RWKV (?)
Tokenizer string `json:"tokenizer" yaml:"tokenizer"`
}

View file

@ -1,260 +0,0 @@
package schema
import (
"context"
"embed"
"encoding/json"
"time"
"github.com/go-skynet/LocalAI/pkg/gallery"
"github.com/rs/zerolog/log"
)
type StartupOptions struct {
Context context.Context
ConfigFile string
ModelPath string
UploadLimitMB, Threads, ContextSize int
F16 bool
Debug, DisableMessage bool
ImageDir string
AudioDir string
CORS bool
PreloadJSONModels string
PreloadModelsFromPath string
CORSAllowOrigins string
ApiKeys []string
Metrics *LocalAIMetrics
Galleries []gallery.Gallery
BackendAssets embed.FS
AssetsDestination string
ExternalGRPCBackends map[string]string
AutoloadGalleries bool
SingleBackend bool
ParallelBackendRequests bool
WatchDogIdle bool
WatchDogBusy bool
WatchDog bool
ModelsURL []string
WatchDogBusyTimeout, WatchDogIdleTimeout time.Duration
LocalAIConfigDir string
}
type AppOption func(*StartupOptions)
func NewStartupOptions(o ...AppOption) *StartupOptions {
opt := &StartupOptions{
Context: context.Background(),
UploadLimitMB: 15,
Threads: 1,
ContextSize: 512,
Debug: true,
DisableMessage: true,
}
for _, oo := range o {
oo(opt)
}
return opt
}
func WithModelsURL(urls ...string) AppOption {
return func(o *StartupOptions) {
o.ModelsURL = urls
}
}
func WithCors(b bool) AppOption {
return func(o *StartupOptions) {
o.CORS = b
}
}
var EnableWatchDog = func(o *StartupOptions) {
o.WatchDog = true
}
var EnableWatchDogIdleCheck = func(o *StartupOptions) {
o.WatchDog = true
o.WatchDogIdle = true
}
var EnableWatchDogBusyCheck = func(o *StartupOptions) {
o.WatchDog = true
o.WatchDogBusy = true
}
func SetWatchDogBusyTimeout(t time.Duration) AppOption {
return func(o *StartupOptions) {
o.WatchDogBusyTimeout = t
}
}
func SetWatchDogIdleTimeout(t time.Duration) AppOption {
return func(o *StartupOptions) {
o.WatchDogIdleTimeout = t
}
}
var EnableSingleBackend = func(o *StartupOptions) {
o.SingleBackend = true
}
var EnableParallelBackendRequests = func(o *StartupOptions) {
o.ParallelBackendRequests = true
}
var EnableGalleriesAutoload = func(o *StartupOptions) {
o.AutoloadGalleries = true
}
func WithExternalBackend(name string, uri string) AppOption {
return func(o *StartupOptions) {
if o.ExternalGRPCBackends == nil {
o.ExternalGRPCBackends = make(map[string]string)
}
o.ExternalGRPCBackends[name] = uri
}
}
func WithCorsAllowOrigins(b string) AppOption {
return func(o *StartupOptions) {
o.CORSAllowOrigins = b
}
}
func WithBackendAssetsOutput(out string) AppOption {
return func(o *StartupOptions) {
o.AssetsDestination = out
}
}
func WithBackendAssets(f embed.FS) AppOption {
return func(o *StartupOptions) {
o.BackendAssets = f
}
}
func WithStringGalleries(galls string) AppOption {
return func(o *StartupOptions) {
if galls == "" {
log.Debug().Msgf("no galleries to load")
o.Galleries = []gallery.Gallery{}
return
}
var galleries []gallery.Gallery
if err := json.Unmarshal([]byte(galls), &galleries); err != nil {
log.Error().Msgf("failed loading galleries: %s", err.Error())
}
o.Galleries = append(o.Galleries, galleries...)
}
}
func WithGalleries(galleries []gallery.Gallery) AppOption {
return func(o *StartupOptions) {
o.Galleries = append(o.Galleries, galleries...)
}
}
func WithContext(ctx context.Context) AppOption {
return func(o *StartupOptions) {
o.Context = ctx
}
}
func WithYAMLConfigPreload(configFile string) AppOption {
return func(o *StartupOptions) {
o.PreloadModelsFromPath = configFile
}
}
func WithJSONStringPreload(configFile string) AppOption {
return func(o *StartupOptions) {
o.PreloadJSONModels = configFile
}
}
func WithConfigFile(configFile string) AppOption {
return func(o *StartupOptions) {
o.ConfigFile = configFile
}
}
func WithModelPath(path string) AppOption {
return func(o *StartupOptions) {
o.ModelPath = path
}
}
func WithUploadLimitMB(limit int) AppOption {
return func(o *StartupOptions) {
o.UploadLimitMB = limit
}
}
func WithThreads(threads int) AppOption {
return func(o *StartupOptions) {
o.Threads = threads
}
}
func WithContextSize(ctxSize int) AppOption {
return func(o *StartupOptions) {
o.ContextSize = ctxSize
}
}
func WithF16(f16 bool) AppOption {
return func(o *StartupOptions) {
o.F16 = f16
}
}
func WithDebug(debug bool) AppOption {
return func(o *StartupOptions) {
o.Debug = debug
}
}
func WithDisableMessage(disableMessage bool) AppOption {
return func(o *StartupOptions) {
o.DisableMessage = disableMessage
}
}
func WithAudioDir(audioDir string) AppOption {
return func(o *StartupOptions) {
o.AudioDir = audioDir
}
}
func WithImageDir(imageDir string) AppOption {
return func(o *StartupOptions) {
o.ImageDir = imageDir
}
}
func WithApiKeys(apiKeys []string) AppOption {
return func(o *StartupOptions) {
o.ApiKeys = apiKeys
}
}
func WithMetrics(metrics *LocalAIMetrics) AppOption {
return func(o *StartupOptions) {
o.Metrics = metrics
}
}
func WithLocalAIConfigDir(configDir string) AppOption {
return func(o *StartupOptions) {
o.LocalAIConfigDir = configDir
}
}

View file

@ -1,16 +0,0 @@
package schema
import "time"
type WhisperSegment struct {
Id int `json:"id"`
Start time.Duration `json:"start"`
End time.Duration `json:"end"`
Text string `json:"text"`
Tokens []int `json:"tokens"`
}
type WhisperResult struct {
Segments []WhisperSegment `json:"segments"`
Text string `json:"text"`
}

View file

@ -1,81 +0,0 @@
package utils
import (
"bufio"
"encoding/base64"
"fmt"
"io"
"mime/multipart"
"net/http"
"os"
"github.com/rs/zerolog/log"
)
func CreateTempFileFromMultipartFile(file *multipart.FileHeader, tempDir string, tempPattern string) (string, error) {
f, err := file.Open()
if err != nil {
return "", err
}
defer f.Close()
// Create a temporary file in the requested directory:
outputFile, err := os.CreateTemp(tempDir, tempPattern)
if err != nil {
return "", err
}
defer outputFile.Close()
if _, err := io.Copy(outputFile, f); err != nil {
log.Debug().Msgf("Audio file copying error %+v - %+v - err %+v", file.Filename, outputFile, err)
return "", err
}
return outputFile.Name(), nil
}
func CreateTempFileFromBase64(base64data string, tempDir string, tempPattern string) (string, error) {
if len(base64data) == 0 {
return "", fmt.Errorf("base64data empty?")
}
//base 64 decode the file and write it somewhere
// that we will cleanup
decoded, err := base64.StdEncoding.DecodeString(base64data)
if err != nil {
return "", err
}
// Create a temporary file in the requested directory:
outputFile, err := os.CreateTemp(tempDir, tempPattern)
if err != nil {
return "", err
}
defer outputFile.Close()
// write the base64 result
writer := bufio.NewWriter(outputFile)
_, err = writer.Write(decoded)
if err != nil {
return "", err
}
return outputFile.Name(), nil
}
func CreateTempFileFromUrl(url string, tempDir string, tempPattern string) (string, error) {
// Get the data
resp, err := http.Get(url)
if err != nil {
return "", err
}
defer resp.Body.Close()
// Create the file
out, err := os.CreateTemp(tempDir, tempPattern)
if err != nil {
return "", err
}
defer out.Close()
// Write the body to file
_, err = io.Copy(out, resp.Body)
return out.Name(), err
}

View file

@ -3,38 +3,18 @@ package utils
import (
"crypto/md5"
"crypto/sha256"
"encoding/base64"
"fmt"
"hash"
"io"
"net/http"
"os"
"path/filepath"
"slices"
"strconv"
"strings"
"github.com/rs/zerolog/log"
)
const (
HuggingFacePrefix = "huggingface://"
HTTPPrefix = "http://"
HTTPSPrefix = "https://"
GithubURI = "github:"
GithubURI2 = "github://"
)
func getRecognizedURIPrefixes() []string {
return []string{
HuggingFacePrefix,
HTTPPrefix,
HTTPSPrefix,
GithubURI,
GithubURI2,
}
}
func GetURI(url string, f func(url string, i []byte) error) error {
url = ConvertURL(url)
@ -72,8 +52,20 @@ func GetURI(url string, f func(url string, i []byte) error) error {
return f(url, body)
}
const (
HuggingFacePrefix = "huggingface://"
HTTPPrefix = "http://"
HTTPSPrefix = "https://"
GithubURI = "github:"
GithubURI2 = "github://"
)
func LooksLikeURL(s string) bool {
return slices.Contains(getRecognizedURIPrefixes(), s)
return strings.HasPrefix(s, HTTPPrefix) ||
strings.HasPrefix(s, HTTPSPrefix) ||
strings.HasPrefix(s, HuggingFacePrefix) ||
strings.HasPrefix(s, GithubURI) ||
strings.HasPrefix(s, GithubURI2)
}
func ConvertURL(s string) string {
@ -249,37 +241,6 @@ func DownloadFile(url string, filePath, sha string, downloadStatus func(string,
return nil
}
// this function check if the string is an URL, if it's an URL downloads the image in memory
// encodes it in base64 and returns the base64 string
func GetBase64Image(s string) (string, error) {
if strings.HasPrefix(s, "http") {
// download the image
resp, err := http.Get(s)
if err != nil {
return "", err
}
defer resp.Body.Close()
// read the image data into memory
data, err := io.ReadAll(resp.Body)
if err != nil {
return "", err
}
// encode the image data in base64
encoded := base64.StdEncoding.EncodeToString(data)
// return the base64 string
return encoded, nil
}
// if the string instead is prefixed with "data:image/jpeg;base64,", drop it
if strings.HasPrefix(s, "data:image/jpeg;base64,") {
return strings.ReplaceAll(s, "data:image/jpeg;base64,", ""), nil
}
return "", fmt.Errorf("not valid string")
}
type progressWriter struct {
fileName string
total int64