mirror of
https://github.com/mudler/LocalAI.git
synced 2025-05-20 10:35:01 +00:00
Revert "[Refactor]: Core/API Split" (#1550)
Revert "[Refactor]: Core/API Split (#1506)"
This reverts commit ab7b4d5ee9
.
This commit is contained in:
parent
ab7b4d5ee9
commit
db926896bd
77 changed files with 3132 additions and 3456 deletions
|
@ -22,11 +22,11 @@ func InstallModelFromGallery(galleries []Gallery, name string, basePath string,
|
|||
applyModel := func(model *GalleryModel) error {
|
||||
name = strings.ReplaceAll(name, string(os.PathSeparator), "__")
|
||||
|
||||
var config InstallableModel
|
||||
var config Config
|
||||
|
||||
if len(model.URL) > 0 {
|
||||
var err error
|
||||
config, err = GetInstallableModelFromURL(model.URL)
|
||||
config, err = GetGalleryConfigFromURL(model.URL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -36,7 +36,7 @@ func InstallModelFromGallery(galleries []Gallery, name string, basePath string,
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
config = InstallableModel{
|
||||
config = Config{
|
||||
ConfigFile: string(reYamlConfig),
|
||||
Description: model.Description,
|
||||
License: model.License,
|
||||
|
|
|
@ -1,9 +1,13 @@
|
|||
package gallery
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"hash"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||
"github.com/imdario/mergo"
|
||||
|
@ -37,9 +41,9 @@ prompt_templates:
|
|||
content: ""
|
||||
|
||||
*/
|
||||
// InstallableModel is the model configuration which contains all the model details
|
||||
// Config is the model configuration which contains all the model details
|
||||
// This configuration is read from the gallery endpoint and is used to download and install the model
|
||||
type InstallableModel struct {
|
||||
type Config struct {
|
||||
Description string `yaml:"description"`
|
||||
License string `yaml:"license"`
|
||||
URLs []string `yaml:"urls"`
|
||||
|
@ -60,8 +64,8 @@ type PromptTemplate struct {
|
|||
Content string `yaml:"content"`
|
||||
}
|
||||
|
||||
func GetInstallableModelFromURL(url string) (InstallableModel, error) {
|
||||
var config InstallableModel
|
||||
func GetGalleryConfigFromURL(url string) (Config, error) {
|
||||
var config Config
|
||||
err := utils.GetURI(url, func(url string, d []byte) error {
|
||||
return yaml.Unmarshal(d, &config)
|
||||
})
|
||||
|
@ -72,7 +76,7 @@ func GetInstallableModelFromURL(url string) (InstallableModel, error) {
|
|||
return config, nil
|
||||
}
|
||||
|
||||
func ReadInstallableModelFile(filePath string) (*InstallableModel, error) {
|
||||
func ReadConfigFile(filePath string) (*Config, error) {
|
||||
// Read the YAML file
|
||||
yamlFile, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
|
@ -80,7 +84,7 @@ func ReadInstallableModelFile(filePath string) (*InstallableModel, error) {
|
|||
}
|
||||
|
||||
// Unmarshal YAML data into a Config struct
|
||||
var config InstallableModel
|
||||
var config Config
|
||||
err = yaml.Unmarshal(yamlFile, &config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal YAML: %v", err)
|
||||
|
@ -89,7 +93,7 @@ func ReadInstallableModelFile(filePath string) (*InstallableModel, error) {
|
|||
return &config, nil
|
||||
}
|
||||
|
||||
func InstallModel(basePath, nameOverride string, config *InstallableModel, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64)) error {
|
||||
func InstallModel(basePath, nameOverride string, config *Config, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64)) error {
|
||||
// Create base path if it doesn't exist
|
||||
err := os.MkdirAll(basePath, 0755)
|
||||
if err != nil {
|
||||
|
@ -179,3 +183,54 @@ func InstallModel(basePath, nameOverride string, config *InstallableModel, confi
|
|||
|
||||
return nil
|
||||
}
|
||||
|
||||
type progressWriter struct {
|
||||
fileName string
|
||||
total int64
|
||||
written int64
|
||||
downloadStatus func(string, string, string, float64)
|
||||
hash hash.Hash
|
||||
}
|
||||
|
||||
func (pw *progressWriter) Write(p []byte) (n int, err error) {
|
||||
n, err = pw.hash.Write(p)
|
||||
pw.written += int64(n)
|
||||
|
||||
if pw.total > 0 {
|
||||
percentage := float64(pw.written) / float64(pw.total) * 100
|
||||
//log.Debug().Msgf("Downloading %s: %s/%s (%.2f%%)", pw.fileName, formatBytes(pw.written), formatBytes(pw.total), percentage)
|
||||
pw.downloadStatus(pw.fileName, formatBytes(pw.written), formatBytes(pw.total), percentage)
|
||||
} else {
|
||||
pw.downloadStatus(pw.fileName, formatBytes(pw.written), "", 0)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func formatBytes(bytes int64) string {
|
||||
const unit = 1024
|
||||
if bytes < unit {
|
||||
return strconv.FormatInt(bytes, 10) + " B"
|
||||
}
|
||||
div, exp := int64(unit), 0
|
||||
for n := bytes / unit; n >= unit; n /= unit {
|
||||
div *= unit
|
||||
exp++
|
||||
}
|
||||
return fmt.Sprintf("%.1f %ciB", float64(bytes)/float64(div), "KMGTPE"[exp])
|
||||
}
|
||||
|
||||
func calculateSHA(filePath string) (string, error) {
|
||||
file, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
hash := sha256.New()
|
||||
if _, err := io.Copy(hash, file); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%x", hash.Sum(nil)), nil
|
||||
}
|
||||
|
|
|
@ -16,7 +16,7 @@ var _ = Describe("Model test", func() {
|
|||
tempdir, err := os.MkdirTemp("", "test")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer os.RemoveAll(tempdir)
|
||||
c, err := ReadInstallableModelFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
|
||||
c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
err = InstallModel(tempdir, "", c, map[string]interface{}{}, func(string, string, string, float64) {})
|
||||
|
@ -87,7 +87,7 @@ var _ = Describe("Model test", func() {
|
|||
tempdir, err := os.MkdirTemp("", "test")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer os.RemoveAll(tempdir)
|
||||
c, err := ReadInstallableModelFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
|
||||
c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
err = InstallModel(tempdir, "foo", c, map[string]interface{}{}, func(string, string, string, float64) {})
|
||||
|
@ -103,7 +103,7 @@ var _ = Describe("Model test", func() {
|
|||
tempdir, err := os.MkdirTemp("", "test")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer os.RemoveAll(tempdir)
|
||||
c, err := ReadInstallableModelFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
|
||||
c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
err = InstallModel(tempdir, "foo", c, map[string]interface{}{"backend": "foo"}, func(string, string, string, float64) {})
|
||||
|
@ -129,7 +129,7 @@ var _ = Describe("Model test", func() {
|
|||
tempdir, err := os.MkdirTemp("", "test")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer os.RemoveAll(tempdir)
|
||||
c, err := ReadInstallableModelFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
|
||||
c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
err = InstallModel(tempdir, "../../../foo", c, map[string]interface{}{}, func(string, string, string, float64) {})
|
||||
|
|
|
@ -1,18 +0,0 @@
|
|||
package gallery
|
||||
|
||||
type GalleryOp struct {
|
||||
Req GalleryModel
|
||||
Id string
|
||||
Galleries []Gallery
|
||||
GalleryName string
|
||||
}
|
||||
|
||||
type GalleryOpStatus struct {
|
||||
FileName string `json:"file_name"`
|
||||
Error error `json:"error"`
|
||||
Processed bool `json:"processed"`
|
||||
Message string `json:"message"`
|
||||
Progress float64 `json:"progress"`
|
||||
TotalFileSize string `json:"file_size"`
|
||||
DownloadedFileSize string `json:"downloaded_size"`
|
||||
}
|
|
@ -10,7 +10,7 @@ var _ = Describe("Gallery API tests", func() {
|
|||
Context("requests", func() {
|
||||
It("parses github with a branch", func() {
|
||||
req := GalleryModel{URL: "github:go-skynet/model-gallery/gpt4all-j.yaml@main"}
|
||||
e, err := GetInstallableModelFromURL(req.URL)
|
||||
e, err := GetGalleryConfigFromURL(req.URL)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(e.Name).To(Equal("gpt4all-j"))
|
||||
})
|
||||
|
|
|
@ -6,8 +6,8 @@ import (
|
|||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/go-skynet/LocalAI/api/schema"
|
||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||
"github.com/go-skynet/LocalAI/pkg/schema"
|
||||
gopsutil "github.com/shirou/gopsutil/v3/process"
|
||||
)
|
||||
|
||||
|
@ -53,9 +53,8 @@ func (llm *Base) GenerateImage(*pb.GenerateImageRequest) error {
|
|||
return fmt.Errorf("unimplemented")
|
||||
}
|
||||
|
||||
// TODO CHECK THIS
|
||||
func (llm *Base) AudioTranscription(*pb.TranscriptRequest) (schema.WhisperResult, error) {
|
||||
return schema.WhisperResult{}, fmt.Errorf("unimplemented")
|
||||
func (llm *Base) AudioTranscription(*pb.TranscriptRequest) (schema.Result, error) {
|
||||
return schema.Result{}, fmt.Errorf("unimplemented")
|
||||
}
|
||||
|
||||
func (llm *Base) TTS(*pb.TTSRequest) error {
|
||||
|
|
|
@ -7,8 +7,8 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-skynet/LocalAI/api/schema"
|
||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||
"github.com/go-skynet/LocalAI/pkg/schema"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
)
|
||||
|
@ -223,7 +223,7 @@ func (c *Client) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOp
|
|||
return client.TTS(ctx, in, opts...)
|
||||
}
|
||||
|
||||
func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.WhisperResult, error) {
|
||||
func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.Result, error) {
|
||||
if !c.parallel {
|
||||
c.opMutex.Lock()
|
||||
defer c.opMutex.Unlock()
|
||||
|
@ -244,14 +244,14 @@ func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptReques
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tresult := &schema.WhisperResult{}
|
||||
tresult := &schema.Result{}
|
||||
for _, s := range res.Segments {
|
||||
tks := []int{}
|
||||
for _, t := range s.Tokens {
|
||||
tks = append(tks, int(t))
|
||||
}
|
||||
tresult.Segments = append(tresult.Segments,
|
||||
schema.WhisperSegment{
|
||||
schema.Segment{
|
||||
Text: s.Text,
|
||||
Id: int(s.Id),
|
||||
Start: time.Duration(s.Start),
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
package grpc
|
||||
|
||||
import (
|
||||
"github.com/go-skynet/LocalAI/api/schema"
|
||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||
"github.com/go-skynet/LocalAI/pkg/schema"
|
||||
)
|
||||
|
||||
type LLM interface {
|
||||
|
@ -15,7 +15,7 @@ type LLM interface {
|
|||
Load(*pb.ModelOptions) error
|
||||
Embeddings(*pb.PredictOptions) ([]float32, error)
|
||||
GenerateImage(*pb.GenerateImageRequest) error
|
||||
AudioTranscription(*pb.TranscriptRequest) (schema.WhisperResult, error)
|
||||
AudioTranscription(*pb.TranscriptRequest) (schema.Result, error)
|
||||
TTS(*pb.TTSRequest) error
|
||||
TokenizeString(*pb.PredictOptions) (pb.TokenizationResponse, error)
|
||||
Status() (pb.StatusResponse, error)
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.26.0
|
||||
// protoc v4.26.0
|
||||
// protoc-gen-go v1.28.1
|
||||
// protoc v3.6.1
|
||||
// source: backend.proto
|
||||
|
||||
package proto
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
|
||||
// versions:
|
||||
// - protoc-gen-go-grpc v1.3.0
|
||||
// - protoc v4.26.0
|
||||
// - protoc-gen-go-grpc v1.2.0
|
||||
// - protoc v3.6.1
|
||||
// source: backend.proto
|
||||
|
||||
package proto
|
||||
|
@ -18,19 +18,6 @@ import (
|
|||
// Requires gRPC-Go v1.32.0 or later.
|
||||
const _ = grpc.SupportPackageIsVersion7
|
||||
|
||||
const (
|
||||
Backend_Health_FullMethodName = "/backend.Backend/Health"
|
||||
Backend_Predict_FullMethodName = "/backend.Backend/Predict"
|
||||
Backend_LoadModel_FullMethodName = "/backend.Backend/LoadModel"
|
||||
Backend_PredictStream_FullMethodName = "/backend.Backend/PredictStream"
|
||||
Backend_Embedding_FullMethodName = "/backend.Backend/Embedding"
|
||||
Backend_GenerateImage_FullMethodName = "/backend.Backend/GenerateImage"
|
||||
Backend_AudioTranscription_FullMethodName = "/backend.Backend/AudioTranscription"
|
||||
Backend_TTS_FullMethodName = "/backend.Backend/TTS"
|
||||
Backend_TokenizeString_FullMethodName = "/backend.Backend/TokenizeString"
|
||||
Backend_Status_FullMethodName = "/backend.Backend/Status"
|
||||
)
|
||||
|
||||
// BackendClient is the client API for Backend service.
|
||||
//
|
||||
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
|
||||
|
@ -57,7 +44,7 @@ func NewBackendClient(cc grpc.ClientConnInterface) BackendClient {
|
|||
|
||||
func (c *backendClient) Health(ctx context.Context, in *HealthMessage, opts ...grpc.CallOption) (*Reply, error) {
|
||||
out := new(Reply)
|
||||
err := c.cc.Invoke(ctx, Backend_Health_FullMethodName, in, out, opts...)
|
||||
err := c.cc.Invoke(ctx, "/backend.Backend/Health", in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -66,7 +53,7 @@ func (c *backendClient) Health(ctx context.Context, in *HealthMessage, opts ...g
|
|||
|
||||
func (c *backendClient) Predict(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*Reply, error) {
|
||||
out := new(Reply)
|
||||
err := c.cc.Invoke(ctx, Backend_Predict_FullMethodName, in, out, opts...)
|
||||
err := c.cc.Invoke(ctx, "/backend.Backend/Predict", in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -75,7 +62,7 @@ func (c *backendClient) Predict(ctx context.Context, in *PredictOptions, opts ..
|
|||
|
||||
func (c *backendClient) LoadModel(ctx context.Context, in *ModelOptions, opts ...grpc.CallOption) (*Result, error) {
|
||||
out := new(Result)
|
||||
err := c.cc.Invoke(ctx, Backend_LoadModel_FullMethodName, in, out, opts...)
|
||||
err := c.cc.Invoke(ctx, "/backend.Backend/LoadModel", in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -83,7 +70,7 @@ func (c *backendClient) LoadModel(ctx context.Context, in *ModelOptions, opts ..
|
|||
}
|
||||
|
||||
func (c *backendClient) PredictStream(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (Backend_PredictStreamClient, error) {
|
||||
stream, err := c.cc.NewStream(ctx, &Backend_ServiceDesc.Streams[0], Backend_PredictStream_FullMethodName, opts...)
|
||||
stream, err := c.cc.NewStream(ctx, &Backend_ServiceDesc.Streams[0], "/backend.Backend/PredictStream", opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -116,7 +103,7 @@ func (x *backendPredictStreamClient) Recv() (*Reply, error) {
|
|||
|
||||
func (c *backendClient) Embedding(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*EmbeddingResult, error) {
|
||||
out := new(EmbeddingResult)
|
||||
err := c.cc.Invoke(ctx, Backend_Embedding_FullMethodName, in, out, opts...)
|
||||
err := c.cc.Invoke(ctx, "/backend.Backend/Embedding", in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -125,7 +112,7 @@ func (c *backendClient) Embedding(ctx context.Context, in *PredictOptions, opts
|
|||
|
||||
func (c *backendClient) GenerateImage(ctx context.Context, in *GenerateImageRequest, opts ...grpc.CallOption) (*Result, error) {
|
||||
out := new(Result)
|
||||
err := c.cc.Invoke(ctx, Backend_GenerateImage_FullMethodName, in, out, opts...)
|
||||
err := c.cc.Invoke(ctx, "/backend.Backend/GenerateImage", in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -134,7 +121,7 @@ func (c *backendClient) GenerateImage(ctx context.Context, in *GenerateImageRequ
|
|||
|
||||
func (c *backendClient) AudioTranscription(ctx context.Context, in *TranscriptRequest, opts ...grpc.CallOption) (*TranscriptResult, error) {
|
||||
out := new(TranscriptResult)
|
||||
err := c.cc.Invoke(ctx, Backend_AudioTranscription_FullMethodName, in, out, opts...)
|
||||
err := c.cc.Invoke(ctx, "/backend.Backend/AudioTranscription", in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -143,7 +130,7 @@ func (c *backendClient) AudioTranscription(ctx context.Context, in *TranscriptRe
|
|||
|
||||
func (c *backendClient) TTS(ctx context.Context, in *TTSRequest, opts ...grpc.CallOption) (*Result, error) {
|
||||
out := new(Result)
|
||||
err := c.cc.Invoke(ctx, Backend_TTS_FullMethodName, in, out, opts...)
|
||||
err := c.cc.Invoke(ctx, "/backend.Backend/TTS", in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -152,7 +139,7 @@ func (c *backendClient) TTS(ctx context.Context, in *TTSRequest, opts ...grpc.Ca
|
|||
|
||||
func (c *backendClient) TokenizeString(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*TokenizationResponse, error) {
|
||||
out := new(TokenizationResponse)
|
||||
err := c.cc.Invoke(ctx, Backend_TokenizeString_FullMethodName, in, out, opts...)
|
||||
err := c.cc.Invoke(ctx, "/backend.Backend/TokenizeString", in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -161,7 +148,7 @@ func (c *backendClient) TokenizeString(ctx context.Context, in *PredictOptions,
|
|||
|
||||
func (c *backendClient) Status(ctx context.Context, in *HealthMessage, opts ...grpc.CallOption) (*StatusResponse, error) {
|
||||
out := new(StatusResponse)
|
||||
err := c.cc.Invoke(ctx, Backend_Status_FullMethodName, in, out, opts...)
|
||||
err := c.cc.Invoke(ctx, "/backend.Backend/Status", in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -242,7 +229,7 @@ func _Backend_Health_Handler(srv interface{}, ctx context.Context, dec func(inte
|
|||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: Backend_Health_FullMethodName,
|
||||
FullMethod: "/backend.Backend/Health",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(BackendServer).Health(ctx, req.(*HealthMessage))
|
||||
|
@ -260,7 +247,7 @@ func _Backend_Predict_Handler(srv interface{}, ctx context.Context, dec func(int
|
|||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: Backend_Predict_FullMethodName,
|
||||
FullMethod: "/backend.Backend/Predict",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(BackendServer).Predict(ctx, req.(*PredictOptions))
|
||||
|
@ -278,7 +265,7 @@ func _Backend_LoadModel_Handler(srv interface{}, ctx context.Context, dec func(i
|
|||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: Backend_LoadModel_FullMethodName,
|
||||
FullMethod: "/backend.Backend/LoadModel",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(BackendServer).LoadModel(ctx, req.(*ModelOptions))
|
||||
|
@ -317,7 +304,7 @@ func _Backend_Embedding_Handler(srv interface{}, ctx context.Context, dec func(i
|
|||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: Backend_Embedding_FullMethodName,
|
||||
FullMethod: "/backend.Backend/Embedding",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(BackendServer).Embedding(ctx, req.(*PredictOptions))
|
||||
|
@ -335,7 +322,7 @@ func _Backend_GenerateImage_Handler(srv interface{}, ctx context.Context, dec fu
|
|||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: Backend_GenerateImage_FullMethodName,
|
||||
FullMethod: "/backend.Backend/GenerateImage",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(BackendServer).GenerateImage(ctx, req.(*GenerateImageRequest))
|
||||
|
@ -353,7 +340,7 @@ func _Backend_AudioTranscription_Handler(srv interface{}, ctx context.Context, d
|
|||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: Backend_AudioTranscription_FullMethodName,
|
||||
FullMethod: "/backend.Backend/AudioTranscription",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(BackendServer).AudioTranscription(ctx, req.(*TranscriptRequest))
|
||||
|
@ -371,7 +358,7 @@ func _Backend_TTS_Handler(srv interface{}, ctx context.Context, dec func(interfa
|
|||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: Backend_TTS_FullMethodName,
|
||||
FullMethod: "/backend.Backend/TTS",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(BackendServer).TTS(ctx, req.(*TTSRequest))
|
||||
|
@ -389,7 +376,7 @@ func _Backend_TokenizeString_Handler(srv interface{}, ctx context.Context, dec f
|
|||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: Backend_TokenizeString_FullMethodName,
|
||||
FullMethod: "/backend.Backend/TokenizeString",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(BackendServer).TokenizeString(ctx, req.(*PredictOptions))
|
||||
|
@ -407,7 +394,7 @@ func _Backend_Status_Handler(srv interface{}, ctx context.Context, dec func(inte
|
|||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: Backend_Status_FullMethodName,
|
||||
FullMethod: "/backend.Backend/Status",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(BackendServer).Status(ctx, req.(*HealthMessage))
|
||||
|
|
|
@ -8,7 +8,7 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc"
|
||||
grpc "github.com/go-skynet/LocalAI/pkg/grpc"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/phayes/freeport"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
@ -71,7 +71,7 @@ var AutoLoadBackends []string = []string{
|
|||
|
||||
// starts the grpcModelProcess for the backend, and returns a grpc client
|
||||
// It also loads the model
|
||||
func (ml *ModelLoader) grpcModel(backend string, o *ModelOptions) func(string, string) (ModelAddress, error) {
|
||||
func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string) (ModelAddress, error) {
|
||||
return func(modelName, modelFile string) (ModelAddress, error) {
|
||||
log.Debug().Msgf("Loading Model %s with gRPC (file: %s) (backend: %s): %+v", modelName, modelFile, backend, *o)
|
||||
|
||||
|
|
|
@ -10,7 +10,7 @@ import (
|
|||
"sync"
|
||||
"text/template"
|
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/grammar"
|
||||
grammar "github.com/go-skynet/LocalAI/pkg/grammar"
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc"
|
||||
process "github.com/mudler/go-processmanager"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
|
|
@ -6,7 +6,7 @@ import (
|
|||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||
)
|
||||
|
||||
type ModelOptions struct {
|
||||
type Options struct {
|
||||
backendString string
|
||||
model string
|
||||
threads uint32
|
||||
|
@ -23,14 +23,14 @@ type ModelOptions struct {
|
|||
parallelRequests bool
|
||||
}
|
||||
|
||||
type Option func(*ModelOptions)
|
||||
type Option func(*Options)
|
||||
|
||||
var EnableParallelRequests = func(o *ModelOptions) {
|
||||
var EnableParallelRequests = func(o *Options) {
|
||||
o.parallelRequests = true
|
||||
}
|
||||
|
||||
func WithExternalBackend(name string, uri string) Option {
|
||||
return func(o *ModelOptions) {
|
||||
return func(o *Options) {
|
||||
if o.externalBackends == nil {
|
||||
o.externalBackends = make(map[string]string)
|
||||
}
|
||||
|
@ -38,81 +38,62 @@ func WithExternalBackend(name string, uri string) Option {
|
|||
}
|
||||
}
|
||||
|
||||
// Currently, LocalAI isn't ready for backends to be yanked out from under it - so this is a little overcomplicated to allow non-overwriting updates
|
||||
func WithExternalBackends(backends map[string]string, overwrite bool) Option {
|
||||
return func(o *ModelOptions) {
|
||||
if backends == nil {
|
||||
return
|
||||
}
|
||||
if o.externalBackends == nil {
|
||||
o.externalBackends = backends
|
||||
return
|
||||
}
|
||||
for name, url := range backends {
|
||||
_, exists := o.externalBackends[name]
|
||||
if !exists || overwrite {
|
||||
o.externalBackends[name] = url
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func WithGRPCAttempts(attempts int) Option {
|
||||
return func(o *ModelOptions) {
|
||||
return func(o *Options) {
|
||||
o.grpcAttempts = attempts
|
||||
}
|
||||
}
|
||||
|
||||
func WithGRPCAttemptsDelay(delay int) Option {
|
||||
return func(o *ModelOptions) {
|
||||
return func(o *Options) {
|
||||
o.grpcAttemptsDelay = delay
|
||||
}
|
||||
}
|
||||
|
||||
func WithBackendString(backend string) Option {
|
||||
return func(o *ModelOptions) {
|
||||
return func(o *Options) {
|
||||
o.backendString = backend
|
||||
}
|
||||
}
|
||||
|
||||
func WithModel(modelFile string) Option {
|
||||
return func(o *ModelOptions) {
|
||||
return func(o *Options) {
|
||||
o.model = modelFile
|
||||
}
|
||||
}
|
||||
|
||||
func WithLoadGRPCLoadModelOpts(opts *pb.ModelOptions) Option {
|
||||
return func(o *ModelOptions) {
|
||||
return func(o *Options) {
|
||||
o.gRPCOptions = opts
|
||||
}
|
||||
}
|
||||
|
||||
func WithThreads(threads uint32) Option {
|
||||
return func(o *ModelOptions) {
|
||||
return func(o *Options) {
|
||||
o.threads = threads
|
||||
}
|
||||
}
|
||||
|
||||
func WithAssetDir(assetDir string) Option {
|
||||
return func(o *ModelOptions) {
|
||||
return func(o *Options) {
|
||||
o.assetDir = assetDir
|
||||
}
|
||||
}
|
||||
|
||||
func WithContext(ctx context.Context) Option {
|
||||
return func(o *ModelOptions) {
|
||||
return func(o *Options) {
|
||||
o.context = ctx
|
||||
}
|
||||
}
|
||||
|
||||
func WithSingleActiveBackend() Option {
|
||||
return func(o *ModelOptions) {
|
||||
return func(o *Options) {
|
||||
o.singleActiveBackend = true
|
||||
}
|
||||
}
|
||||
|
||||
func NewOptions(opts ...Option) *ModelOptions {
|
||||
o := &ModelOptions{
|
||||
func NewOptions(opts ...Option) *Options {
|
||||
o := &Options{
|
||||
gRPCOptions: &pb.ModelOptions{},
|
||||
context: context.Background(),
|
||||
grpcAttempts: 20,
|
||||
|
|
|
@ -1,400 +0,0 @@
|
|||
package schema
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
PredictionOptions `yaml:"parameters"`
|
||||
Name string `yaml:"name"`
|
||||
|
||||
F16 bool `yaml:"f16"`
|
||||
Threads int `yaml:"threads"`
|
||||
Debug bool `yaml:"debug"`
|
||||
Roles map[string]string `yaml:"roles"`
|
||||
Embeddings bool `yaml:"embeddings"`
|
||||
Backend string `yaml:"backend"`
|
||||
TemplateConfig TemplateConfig `yaml:"template"`
|
||||
|
||||
PromptStrings, InputStrings []string `yaml:"-"`
|
||||
InputToken [][]int `yaml:"-"`
|
||||
functionCallString, functionCallNameString string `yaml:"-"`
|
||||
|
||||
FunctionsConfig Functions `yaml:"function"`
|
||||
|
||||
FeatureFlag FeatureFlag `yaml:"feature_flags"` // Feature Flag registry. We move fast, and features may break on a per model/backend basis. Registry for (usually temporary) flags that indicate aborting something early.
|
||||
// LLM configs (GPT4ALL, Llama.cpp, ...)
|
||||
LLMConfig `yaml:",inline"`
|
||||
|
||||
// AutoGPTQ specifics
|
||||
AutoGPTQ AutoGPTQ `yaml:"autogptq"`
|
||||
|
||||
// Diffusers
|
||||
Diffusers Diffusers `yaml:"diffusers"`
|
||||
Step int `yaml:"step"`
|
||||
|
||||
// GRPC Options
|
||||
GRPC GRPC `yaml:"grpc"`
|
||||
|
||||
// Vall-e-x
|
||||
VallE VallE `yaml:"vall-e"`
|
||||
|
||||
// CUDA
|
||||
// Explicitly enable CUDA or not (some backends might need it)
|
||||
CUDA bool `yaml:"cuda"`
|
||||
|
||||
DownloadFiles []File `yaml:"download_files"`
|
||||
}
|
||||
|
||||
type File struct {
|
||||
Filename string `yaml:"filename" json:"filename"`
|
||||
SHA256 string `yaml:"sha256" json:"sha256"`
|
||||
URI string `yaml:"uri" json:"uri"`
|
||||
}
|
||||
|
||||
type VallE struct {
|
||||
AudioPath string `yaml:"audio_path"`
|
||||
}
|
||||
|
||||
type FeatureFlag map[string]*bool
|
||||
|
||||
func (ff FeatureFlag) Enabled(s string) bool {
|
||||
v, exist := ff[s]
|
||||
return exist && v != nil && *v
|
||||
}
|
||||
|
||||
type GRPC struct {
|
||||
Attempts int `yaml:"attempts"`
|
||||
AttemptsSleepTime int `yaml:"attempts_sleep_time"`
|
||||
}
|
||||
|
||||
type Diffusers struct {
|
||||
CUDA bool `yaml:"cuda"`
|
||||
PipelineType string `yaml:"pipeline_type"`
|
||||
SchedulerType string `yaml:"scheduler_type"`
|
||||
EnableParameters string `yaml:"enable_parameters"` // A list of comma separated parameters to specify
|
||||
CFGScale float32 `yaml:"cfg_scale"` // Classifier-Free Guidance Scale
|
||||
IMG2IMG bool `yaml:"img2img"` // Image to Image Diffuser
|
||||
ClipSkip int `yaml:"clip_skip"` // Skip every N frames
|
||||
ClipModel string `yaml:"clip_model"` // Clip model to use
|
||||
ClipSubFolder string `yaml:"clip_subfolder"` // Subfolder to use for clip model
|
||||
ControlNet string `yaml:"control_net"`
|
||||
}
|
||||
|
||||
type LLMConfig struct {
|
||||
SystemPrompt string `yaml:"system_prompt"`
|
||||
TensorSplit string `yaml:"tensor_split"`
|
||||
MainGPU string `yaml:"main_gpu"`
|
||||
RMSNormEps float32 `yaml:"rms_norm_eps"`
|
||||
NGQA int32 `yaml:"ngqa"`
|
||||
PromptCachePath string `yaml:"prompt_cache_path"`
|
||||
PromptCacheAll bool `yaml:"prompt_cache_all"`
|
||||
PromptCacheRO bool `yaml:"prompt_cache_ro"`
|
||||
MirostatETA float64 `yaml:"mirostat_eta"`
|
||||
MirostatTAU float64 `yaml:"mirostat_tau"`
|
||||
Mirostat int `yaml:"mirostat"`
|
||||
NGPULayers int `yaml:"gpu_layers"`
|
||||
MMap bool `yaml:"mmap"`
|
||||
MMlock bool `yaml:"mmlock"`
|
||||
LowVRAM bool `yaml:"low_vram"`
|
||||
Grammar string `yaml:"grammar"`
|
||||
StopWords []string `yaml:"stopwords"`
|
||||
Cutstrings []string `yaml:"cutstrings"`
|
||||
TrimSpace []string `yaml:"trimspace"`
|
||||
TrimSuffix []string `yaml:"trimsuffix"`
|
||||
|
||||
ContextSize int `yaml:"context_size"`
|
||||
NUMA bool `yaml:"numa"`
|
||||
LoraAdapter string `yaml:"lora_adapter"`
|
||||
LoraBase string `yaml:"lora_base"`
|
||||
LoraScale float32 `yaml:"lora_scale"`
|
||||
NoMulMatQ bool `yaml:"no_mulmatq"`
|
||||
DraftModel string `yaml:"draft_model"`
|
||||
NDraft int32 `yaml:"n_draft"`
|
||||
Quantization string `yaml:"quantization"`
|
||||
MMProj string `yaml:"mmproj"`
|
||||
|
||||
RopeScaling string `yaml:"rope_scaling"`
|
||||
YarnExtFactor float32 `yaml:"yarn_ext_factor"`
|
||||
YarnAttnFactor float32 `yaml:"yarn_attn_factor"`
|
||||
YarnBetaFast float32 `yaml:"yarn_beta_fast"`
|
||||
YarnBetaSlow float32 `yaml:"yarn_beta_slow"`
|
||||
}
|
||||
|
||||
type AutoGPTQ struct {
|
||||
ModelBaseName string `yaml:"model_base_name"`
|
||||
Device string `yaml:"device"`
|
||||
Triton bool `yaml:"triton"`
|
||||
UseFastTokenizer bool `yaml:"use_fast_tokenizer"`
|
||||
}
|
||||
|
||||
type Functions struct {
|
||||
DisableNoAction bool `yaml:"disable_no_action"`
|
||||
NoActionFunctionName string `yaml:"no_action_function_name"`
|
||||
NoActionDescriptionName string `yaml:"no_action_description_name"`
|
||||
}
|
||||
|
||||
type TemplateConfig struct {
|
||||
Chat string `yaml:"chat"`
|
||||
ChatMessage string `yaml:"chat_message"`
|
||||
Completion string `yaml:"completion"`
|
||||
Edit string `yaml:"edit"`
|
||||
Functions string `yaml:"function"`
|
||||
}
|
||||
|
||||
func (c *Config) SetFunctionCallString(s string) {
|
||||
c.functionCallString = s
|
||||
}
|
||||
|
||||
func (c *Config) SetFunctionCallNameString(s string) {
|
||||
c.functionCallNameString = s
|
||||
}
|
||||
|
||||
func (c *Config) ShouldUseFunctions() bool {
|
||||
return ((c.functionCallString != "none" || c.functionCallString == "") || c.ShouldCallSpecificFunction())
|
||||
}
|
||||
|
||||
func (c *Config) ShouldCallSpecificFunction() bool {
|
||||
return len(c.functionCallNameString) > 0
|
||||
}
|
||||
|
||||
func (c *Config) FunctionToCall() string {
|
||||
return c.functionCallNameString
|
||||
}
|
||||
|
||||
func defaultPredictOptions(modelFile string) PredictionOptions {
|
||||
return PredictionOptions{
|
||||
TopP: 0.7,
|
||||
TopK: 80,
|
||||
Maxtokens: 512,
|
||||
Temperature: 0.9,
|
||||
Model: modelFile,
|
||||
}
|
||||
}
|
||||
|
||||
func DefaultConfig(modelFile string) *Config {
|
||||
return &Config{
|
||||
PredictionOptions: defaultPredictOptions(modelFile),
|
||||
}
|
||||
}
|
||||
|
||||
func ReadConfigFile(file string) ([]*Config, error) {
|
||||
c := &[]*Config{}
|
||||
f, err := os.ReadFile(file)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot read config file: %w", err)
|
||||
}
|
||||
if err := yaml.Unmarshal(f, c); err != nil {
|
||||
return nil, fmt.Errorf("cannot unmarshal config file: %w", err)
|
||||
}
|
||||
|
||||
return *c, nil
|
||||
}
|
||||
|
||||
func ReadSingleConfigFile(file string) (*Config, error) {
|
||||
c := &Config{}
|
||||
f, err := os.ReadFile(file)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot read config file: %w", err)
|
||||
}
|
||||
if err := yaml.Unmarshal(f, c); err != nil {
|
||||
return nil, fmt.Errorf("cannot unmarshal config file: %w", err)
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func UpdateConfigFromOpenAIRequest(config *Config, input *OpenAIRequest) {
|
||||
if input.Echo {
|
||||
config.Echo = input.Echo
|
||||
}
|
||||
if input.TopK != 0 {
|
||||
config.TopK = input.TopK
|
||||
}
|
||||
if input.TopP != 0 {
|
||||
config.TopP = input.TopP
|
||||
}
|
||||
|
||||
if input.Backend != "" {
|
||||
config.Backend = input.Backend
|
||||
}
|
||||
|
||||
if input.ClipSkip != 0 {
|
||||
config.Diffusers.ClipSkip = input.ClipSkip
|
||||
}
|
||||
|
||||
if input.ModelBaseName != "" {
|
||||
config.AutoGPTQ.ModelBaseName = input.ModelBaseName
|
||||
}
|
||||
|
||||
if input.NegativePromptScale != 0 {
|
||||
config.NegativePromptScale = input.NegativePromptScale
|
||||
}
|
||||
|
||||
if input.UseFastTokenizer {
|
||||
config.UseFastTokenizer = input.UseFastTokenizer
|
||||
}
|
||||
|
||||
if input.NegativePrompt != "" {
|
||||
config.NegativePrompt = input.NegativePrompt
|
||||
}
|
||||
|
||||
if input.RopeFreqBase != 0 {
|
||||
config.RopeFreqBase = input.RopeFreqBase
|
||||
}
|
||||
|
||||
if input.RopeFreqScale != 0 {
|
||||
config.RopeFreqScale = input.RopeFreqScale
|
||||
}
|
||||
|
||||
if input.Grammar != "" {
|
||||
config.Grammar = input.Grammar
|
||||
}
|
||||
|
||||
if input.Temperature != 0 {
|
||||
config.Temperature = input.Temperature
|
||||
}
|
||||
|
||||
if input.Maxtokens != 0 {
|
||||
config.Maxtokens = input.Maxtokens
|
||||
}
|
||||
|
||||
if input.RepeatPenalty != 0 {
|
||||
config.RepeatPenalty = input.RepeatPenalty
|
||||
}
|
||||
|
||||
if input.Keep != 0 {
|
||||
config.Keep = input.Keep
|
||||
}
|
||||
|
||||
if input.Batch != 0 {
|
||||
config.Batch = input.Batch
|
||||
}
|
||||
|
||||
if input.F16 {
|
||||
config.F16 = input.F16
|
||||
}
|
||||
|
||||
if input.IgnoreEOS {
|
||||
config.IgnoreEOS = input.IgnoreEOS
|
||||
}
|
||||
|
||||
if input.Seed != 0 {
|
||||
config.Seed = input.Seed
|
||||
}
|
||||
|
||||
if input.Mirostat != 0 {
|
||||
config.LLMConfig.Mirostat = input.Mirostat
|
||||
}
|
||||
|
||||
if input.MirostatETA != 0 {
|
||||
config.LLMConfig.MirostatETA = input.MirostatETA
|
||||
}
|
||||
|
||||
if input.MirostatTAU != 0 {
|
||||
config.LLMConfig.MirostatTAU = input.MirostatTAU
|
||||
}
|
||||
|
||||
if input.TypicalP != 0 {
|
||||
config.TypicalP = input.TypicalP
|
||||
}
|
||||
|
||||
switch stop := input.Stop.(type) {
|
||||
case string:
|
||||
if stop != "" {
|
||||
config.StopWords = append(config.StopWords, stop)
|
||||
}
|
||||
case []interface{}:
|
||||
for _, pp := range stop {
|
||||
if s, ok := pp.(string); ok {
|
||||
config.StopWords = append(config.StopWords, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Decode each request's message content
|
||||
index := 0
|
||||
for i, m := range input.Messages {
|
||||
switch content := m.Content.(type) {
|
||||
case string:
|
||||
input.Messages[i].StringContent = content
|
||||
case []interface{}:
|
||||
dat, _ := json.Marshal(content)
|
||||
c := []Content{}
|
||||
json.Unmarshal(dat, &c)
|
||||
for _, pp := range c {
|
||||
if pp.Type == "text" {
|
||||
input.Messages[i].StringContent = pp.Text
|
||||
} else if pp.Type == "image_url" {
|
||||
// Detect if pp.ImageURL is an URL, if it is download the image and encode it in base64:
|
||||
base64, err := utils.GetBase64Image(pp.ImageURL.URL)
|
||||
if err == nil {
|
||||
input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff
|
||||
// set a placeholder for each image
|
||||
input.Messages[i].StringContent = fmt.Sprintf("[img-%d]", index) + input.Messages[i].StringContent
|
||||
index++
|
||||
} else {
|
||||
fmt.Print("Failed encoding image", err)
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: check that this was merged correctly? I _think_ it is?
|
||||
switch inputs := input.Input.(type) {
|
||||
case string:
|
||||
if inputs != "" {
|
||||
config.InputStrings = append(config.InputStrings, inputs)
|
||||
}
|
||||
case []interface{}:
|
||||
for _, pp := range inputs {
|
||||
switch i := pp.(type) {
|
||||
case string:
|
||||
config.InputStrings = append(config.InputStrings, i)
|
||||
case []interface{}:
|
||||
tokens := []int{}
|
||||
for _, ii := range i {
|
||||
tokens = append(tokens, int(ii.(float64)))
|
||||
}
|
||||
config.InputToken = append(config.InputToken, tokens)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Can be either a string or an object
|
||||
switch fnc := input.FunctionCall.(type) {
|
||||
case string:
|
||||
if fnc != "" {
|
||||
config.SetFunctionCallString(fnc)
|
||||
}
|
||||
case map[string]interface{}:
|
||||
var name string
|
||||
n, exists := fnc["name"]
|
||||
if exists {
|
||||
nn, e := n.(string)
|
||||
if e {
|
||||
name = nn
|
||||
}
|
||||
}
|
||||
config.SetFunctionCallNameString(name)
|
||||
}
|
||||
|
||||
switch p := input.Prompt.(type) {
|
||||
case string:
|
||||
config.PromptStrings = append(config.PromptStrings, p)
|
||||
case []interface{}:
|
||||
for _, pp := range p {
|
||||
if s, ok := pp.(string); ok {
|
||||
config.PromptStrings = append(config.PromptStrings, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
|
@ -1,51 +0,0 @@
|
|||
package schema_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/go-skynet/LocalAI/core/services"
|
||||
"github.com/go-skynet/LocalAI/pkg/schema"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Test cases for config related functions", func() {
|
||||
|
||||
var (
|
||||
configFile string
|
||||
)
|
||||
|
||||
Context("Test Read configuration functions", func() {
|
||||
configFile = os.Getenv("CONFIG_FILE")
|
||||
It("Test ReadConfigFile", func() {
|
||||
config, err := schema.ReadConfigFile(configFile)
|
||||
Expect(err).To(BeNil())
|
||||
Expect(config).ToNot(BeNil())
|
||||
// two configs in config.yaml
|
||||
Expect(config[0].Name).To(Equal("list1"))
|
||||
Expect(config[1].Name).To(Equal("list2"))
|
||||
})
|
||||
|
||||
It("Test LoadConfigs", func() {
|
||||
cm := services.NewConfigLoader()
|
||||
err := cm.LoadConfigs(os.Getenv("MODELS_PATH"))
|
||||
Expect(err).To(BeNil())
|
||||
Expect(cm.ListConfigs()).ToNot(BeNil())
|
||||
|
||||
// config should includes gpt4all models's api.config
|
||||
Expect(cm.ListConfigs()).To(ContainElements("gpt4all"))
|
||||
|
||||
// config should includes gpt2 models's api.config
|
||||
Expect(cm.ListConfigs()).To(ContainElements("gpt4all-2"))
|
||||
|
||||
// config should includes text-embedding-ada-002 models's api.config
|
||||
Expect(cm.ListConfigs()).To(ContainElements("text-embedding-ada-002"))
|
||||
|
||||
// config should includes rwkv_test models's api.config
|
||||
Expect(cm.ListConfigs()).To(ContainElements("rwkv_test"))
|
||||
|
||||
// config should includes whisper-1 models's api.config
|
||||
Expect(cm.ListConfigs()).To(ContainElements("whisper-1"))
|
||||
})
|
||||
})
|
||||
})
|
|
@ -1,39 +0,0 @@
|
|||
package schema
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
gopsutil "github.com/shirou/gopsutil/v3/process"
|
||||
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
)
|
||||
|
||||
type BackendMonitorRequest struct {
|
||||
Model string `json:"model" yaml:"model"`
|
||||
}
|
||||
|
||||
type BackendMonitorResponse struct {
|
||||
MemoryInfo *gopsutil.MemoryInfoStat
|
||||
MemoryPercent float32
|
||||
CPUPercent float64
|
||||
}
|
||||
|
||||
type TTSRequest struct {
|
||||
Model string `json:"model" yaml:"model"`
|
||||
Input string `json:"input" yaml:"input"`
|
||||
Backend string `json:"backend" yaml:"backend"`
|
||||
}
|
||||
|
||||
type LocalAIMetrics struct {
|
||||
Meter metric.Meter
|
||||
ApiTimeMetric metric.Float64Histogram
|
||||
}
|
||||
|
||||
func (m *LocalAIMetrics) ObserveAPICall(method string, path string, duration float64) {
|
||||
opts := metric.WithAttributes(
|
||||
attribute.String("method", method),
|
||||
attribute.String("path", path),
|
||||
)
|
||||
m.ApiTimeMetric.Record(context.Background(), duration, opts)
|
||||
}
|
|
@ -1,133 +0,0 @@
|
|||
package schema
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/grammar"
|
||||
)
|
||||
|
||||
// APIError provides error information returned by the OpenAI API.
|
||||
type APIError struct {
|
||||
Code any `json:"code,omitempty"`
|
||||
Message string `json:"message"`
|
||||
Param *string `json:"param,omitempty"`
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
type ErrorResponse struct {
|
||||
Error *APIError `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type OpenAIUsage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
type Item struct {
|
||||
Embedding []float32 `json:"embedding"`
|
||||
Index int `json:"index"`
|
||||
Object string `json:"object,omitempty"`
|
||||
|
||||
// Images
|
||||
URL string `json:"url,omitempty"`
|
||||
B64JSON string `json:"b64_json,omitempty"`
|
||||
}
|
||||
|
||||
type OpenAIResponse struct {
|
||||
Created int `json:"created,omitempty"`
|
||||
Object string `json:"object,omitempty"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Choices []Choice `json:"choices,omitempty"`
|
||||
Data []Item `json:"data,omitempty"`
|
||||
|
||||
Usage OpenAIUsage `json:"usage"`
|
||||
}
|
||||
|
||||
type Choice struct {
|
||||
Index int `json:"index"`
|
||||
FinishReason string `json:"finish_reason,omitempty"`
|
||||
Message *Message `json:"message,omitempty"`
|
||||
Delta *Message `json:"delta,omitempty"`
|
||||
Text string `json:"text,omitempty"`
|
||||
}
|
||||
|
||||
type Content struct {
|
||||
Type string `json:"type" yaml:"type"`
|
||||
Text string `json:"text" yaml:"text"`
|
||||
ImageURL ContentURL `json:"image_url" yaml:"image_url"`
|
||||
}
|
||||
|
||||
type ContentURL struct {
|
||||
URL string `json:"url" yaml:"url"`
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
// The message role
|
||||
Role string `json:"role,omitempty" yaml:"role"`
|
||||
// The message content
|
||||
Content interface{} `json:"content" yaml:"content"`
|
||||
|
||||
StringContent string `json:"string_content,omitempty" yaml:"string_content,omitempty"`
|
||||
StringImages []string `json:"string_images,omitempty" yaml:"string_images,omitempty"`
|
||||
|
||||
// A result of a function call
|
||||
FunctionCall interface{} `json:"function_call,omitempty" yaml:"function_call,omitempty"`
|
||||
}
|
||||
|
||||
type OpenAIModel struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
}
|
||||
|
||||
type ChatCompletionResponseFormatType string
|
||||
|
||||
type ChatCompletionResponseFormat struct {
|
||||
Type ChatCompletionResponseFormatType `json:"type,omitempty"`
|
||||
}
|
||||
|
||||
type OpenAIRequest struct {
|
||||
PredictionOptions
|
||||
|
||||
Context context.Context
|
||||
Cancel context.CancelFunc
|
||||
|
||||
// whisper
|
||||
File string `json:"file" validate:"required"`
|
||||
//whisper/image
|
||||
ResponseFormat ChatCompletionResponseFormat `json:"response_format"`
|
||||
// image
|
||||
Size string `json:"size"`
|
||||
// Prompt is read only by completion/image API calls
|
||||
Prompt interface{} `json:"prompt" yaml:"prompt"`
|
||||
|
||||
// Edit endpoint
|
||||
Instruction string `json:"instruction" yaml:"instruction"`
|
||||
Input interface{} `json:"input" yaml:"input"`
|
||||
|
||||
Stop interface{} `json:"stop" yaml:"stop"`
|
||||
|
||||
// Messages is read only by chat/completion API calls
|
||||
Messages []Message `json:"messages" yaml:"messages"`
|
||||
|
||||
// A list of available functions to call
|
||||
Functions []grammar.Function `json:"functions" yaml:"functions"`
|
||||
FunctionCall interface{} `json:"function_call" yaml:"function_call"` // might be a string or an object
|
||||
|
||||
Stream bool `json:"stream"`
|
||||
|
||||
// Image (not supported by OpenAI)
|
||||
Mode int `json:"mode"`
|
||||
Step int `json:"step"`
|
||||
|
||||
// A grammar to constrain the LLM output
|
||||
Grammar string `json:"grammar" yaml:"grammar"`
|
||||
|
||||
JSONFunctionGrammarObject *grammar.JSONFunctionStructure `json:"grammar_json_functions" yaml:"grammar_json_functions"`
|
||||
|
||||
Backend string `json:"backend" yaml:"backend"`
|
||||
|
||||
// AutoGPTQ
|
||||
ModelBaseName string `json:"model_base_name" yaml:"model_base_name"`
|
||||
}
|
|
@ -1,50 +0,0 @@
|
|||
package schema
|
||||
|
||||
type PredictionOptions struct {
|
||||
|
||||
// Also part of the OpenAI official spec
|
||||
Model string `json:"model" yaml:"model"`
|
||||
|
||||
// Also part of the OpenAI official spec
|
||||
Language string `json:"language"`
|
||||
|
||||
// Also part of the OpenAI official spec. use it for returning multiple results
|
||||
N int `json:"n"`
|
||||
|
||||
// Common options between all the API calls, part of the OpenAI spec
|
||||
TopP float64 `json:"top_p" yaml:"top_p"`
|
||||
TopK int `json:"top_k" yaml:"top_k"`
|
||||
Temperature float64 `json:"temperature" yaml:"temperature"`
|
||||
Maxtokens int `json:"max_tokens" yaml:"max_tokens"`
|
||||
Echo bool `json:"echo"`
|
||||
|
||||
// Custom parameters - not present in the OpenAI API
|
||||
Batch int `json:"batch" yaml:"batch"`
|
||||
F16 bool `json:"f16" yaml:"f16"`
|
||||
IgnoreEOS bool `json:"ignore_eos" yaml:"ignore_eos"`
|
||||
RepeatPenalty float64 `json:"repeat_penalty" yaml:"repeat_penalty"`
|
||||
Keep int `json:"n_keep" yaml:"n_keep"`
|
||||
|
||||
MirostatETA float64 `json:"mirostat_eta" yaml:"mirostat_eta"`
|
||||
MirostatTAU float64 `json:"mirostat_tau" yaml:"mirostat_tau"`
|
||||
Mirostat int `json:"mirostat" yaml:"mirostat"`
|
||||
|
||||
FrequencyPenalty float64 `json:"frequency_penalty" yaml:"frequency_penalty"`
|
||||
TFZ float64 `json:"tfz" yaml:"tfz"`
|
||||
|
||||
TypicalP float64 `json:"typical_p" yaml:"typical_p"`
|
||||
Seed int `json:"seed" yaml:"seed"`
|
||||
|
||||
NegativePrompt string `json:"negative_prompt" yaml:"negative_prompt"`
|
||||
RopeFreqBase float32 `json:"rope_freq_base" yaml:"rope_freq_base"`
|
||||
RopeFreqScale float32 `json:"rope_freq_scale" yaml:"rope_freq_scale"`
|
||||
NegativePromptScale float32 `json:"negative_prompt_scale" yaml:"negative_prompt_scale"`
|
||||
// AutoGPTQ
|
||||
UseFastTokenizer bool `json:"use_fast_tokenizer" yaml:"use_fast_tokenizer"`
|
||||
|
||||
// Diffusers
|
||||
ClipSkip int `json:"clip_skip" yaml:"clip_skip"`
|
||||
|
||||
// RWKV (?)
|
||||
Tokenizer string `json:"tokenizer" yaml:"tokenizer"`
|
||||
}
|
|
@ -1,260 +0,0 @@
|
|||
package schema
|
||||
|
||||
import (
|
||||
"context"
|
||||
"embed"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/gallery"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
type StartupOptions struct {
|
||||
Context context.Context
|
||||
ConfigFile string
|
||||
ModelPath string
|
||||
UploadLimitMB, Threads, ContextSize int
|
||||
F16 bool
|
||||
Debug, DisableMessage bool
|
||||
ImageDir string
|
||||
AudioDir string
|
||||
CORS bool
|
||||
PreloadJSONModels string
|
||||
PreloadModelsFromPath string
|
||||
CORSAllowOrigins string
|
||||
ApiKeys []string
|
||||
Metrics *LocalAIMetrics
|
||||
|
||||
Galleries []gallery.Gallery
|
||||
|
||||
BackendAssets embed.FS
|
||||
AssetsDestination string
|
||||
|
||||
ExternalGRPCBackends map[string]string
|
||||
|
||||
AutoloadGalleries bool
|
||||
|
||||
SingleBackend bool
|
||||
ParallelBackendRequests bool
|
||||
|
||||
WatchDogIdle bool
|
||||
WatchDogBusy bool
|
||||
WatchDog bool
|
||||
|
||||
ModelsURL []string
|
||||
|
||||
WatchDogBusyTimeout, WatchDogIdleTimeout time.Duration
|
||||
|
||||
LocalAIConfigDir string
|
||||
}
|
||||
|
||||
type AppOption func(*StartupOptions)
|
||||
|
||||
func NewStartupOptions(o ...AppOption) *StartupOptions {
|
||||
opt := &StartupOptions{
|
||||
Context: context.Background(),
|
||||
UploadLimitMB: 15,
|
||||
Threads: 1,
|
||||
ContextSize: 512,
|
||||
Debug: true,
|
||||
DisableMessage: true,
|
||||
}
|
||||
for _, oo := range o {
|
||||
oo(opt)
|
||||
}
|
||||
return opt
|
||||
}
|
||||
|
||||
func WithModelsURL(urls ...string) AppOption {
|
||||
return func(o *StartupOptions) {
|
||||
o.ModelsURL = urls
|
||||
}
|
||||
}
|
||||
|
||||
func WithCors(b bool) AppOption {
|
||||
return func(o *StartupOptions) {
|
||||
o.CORS = b
|
||||
}
|
||||
}
|
||||
|
||||
var EnableWatchDog = func(o *StartupOptions) {
|
||||
o.WatchDog = true
|
||||
}
|
||||
|
||||
var EnableWatchDogIdleCheck = func(o *StartupOptions) {
|
||||
o.WatchDog = true
|
||||
o.WatchDogIdle = true
|
||||
}
|
||||
|
||||
var EnableWatchDogBusyCheck = func(o *StartupOptions) {
|
||||
o.WatchDog = true
|
||||
o.WatchDogBusy = true
|
||||
}
|
||||
|
||||
func SetWatchDogBusyTimeout(t time.Duration) AppOption {
|
||||
return func(o *StartupOptions) {
|
||||
o.WatchDogBusyTimeout = t
|
||||
}
|
||||
}
|
||||
|
||||
func SetWatchDogIdleTimeout(t time.Duration) AppOption {
|
||||
return func(o *StartupOptions) {
|
||||
o.WatchDogIdleTimeout = t
|
||||
}
|
||||
}
|
||||
|
||||
var EnableSingleBackend = func(o *StartupOptions) {
|
||||
o.SingleBackend = true
|
||||
}
|
||||
|
||||
var EnableParallelBackendRequests = func(o *StartupOptions) {
|
||||
o.ParallelBackendRequests = true
|
||||
}
|
||||
|
||||
var EnableGalleriesAutoload = func(o *StartupOptions) {
|
||||
o.AutoloadGalleries = true
|
||||
}
|
||||
|
||||
func WithExternalBackend(name string, uri string) AppOption {
|
||||
return func(o *StartupOptions) {
|
||||
if o.ExternalGRPCBackends == nil {
|
||||
o.ExternalGRPCBackends = make(map[string]string)
|
||||
}
|
||||
o.ExternalGRPCBackends[name] = uri
|
||||
}
|
||||
}
|
||||
|
||||
func WithCorsAllowOrigins(b string) AppOption {
|
||||
return func(o *StartupOptions) {
|
||||
o.CORSAllowOrigins = b
|
||||
}
|
||||
}
|
||||
|
||||
func WithBackendAssetsOutput(out string) AppOption {
|
||||
return func(o *StartupOptions) {
|
||||
o.AssetsDestination = out
|
||||
}
|
||||
}
|
||||
|
||||
func WithBackendAssets(f embed.FS) AppOption {
|
||||
return func(o *StartupOptions) {
|
||||
o.BackendAssets = f
|
||||
}
|
||||
}
|
||||
|
||||
func WithStringGalleries(galls string) AppOption {
|
||||
return func(o *StartupOptions) {
|
||||
if galls == "" {
|
||||
log.Debug().Msgf("no galleries to load")
|
||||
o.Galleries = []gallery.Gallery{}
|
||||
return
|
||||
}
|
||||
var galleries []gallery.Gallery
|
||||
if err := json.Unmarshal([]byte(galls), &galleries); err != nil {
|
||||
log.Error().Msgf("failed loading galleries: %s", err.Error())
|
||||
}
|
||||
o.Galleries = append(o.Galleries, galleries...)
|
||||
}
|
||||
}
|
||||
|
||||
func WithGalleries(galleries []gallery.Gallery) AppOption {
|
||||
return func(o *StartupOptions) {
|
||||
o.Galleries = append(o.Galleries, galleries...)
|
||||
}
|
||||
}
|
||||
|
||||
func WithContext(ctx context.Context) AppOption {
|
||||
return func(o *StartupOptions) {
|
||||
o.Context = ctx
|
||||
}
|
||||
}
|
||||
|
||||
func WithYAMLConfigPreload(configFile string) AppOption {
|
||||
return func(o *StartupOptions) {
|
||||
o.PreloadModelsFromPath = configFile
|
||||
}
|
||||
}
|
||||
|
||||
func WithJSONStringPreload(configFile string) AppOption {
|
||||
return func(o *StartupOptions) {
|
||||
o.PreloadJSONModels = configFile
|
||||
}
|
||||
}
|
||||
func WithConfigFile(configFile string) AppOption {
|
||||
return func(o *StartupOptions) {
|
||||
o.ConfigFile = configFile
|
||||
}
|
||||
}
|
||||
|
||||
func WithModelPath(path string) AppOption {
|
||||
return func(o *StartupOptions) {
|
||||
o.ModelPath = path
|
||||
}
|
||||
}
|
||||
|
||||
func WithUploadLimitMB(limit int) AppOption {
|
||||
return func(o *StartupOptions) {
|
||||
o.UploadLimitMB = limit
|
||||
}
|
||||
}
|
||||
|
||||
func WithThreads(threads int) AppOption {
|
||||
return func(o *StartupOptions) {
|
||||
o.Threads = threads
|
||||
}
|
||||
}
|
||||
|
||||
func WithContextSize(ctxSize int) AppOption {
|
||||
return func(o *StartupOptions) {
|
||||
o.ContextSize = ctxSize
|
||||
}
|
||||
}
|
||||
|
||||
func WithF16(f16 bool) AppOption {
|
||||
return func(o *StartupOptions) {
|
||||
o.F16 = f16
|
||||
}
|
||||
}
|
||||
|
||||
func WithDebug(debug bool) AppOption {
|
||||
return func(o *StartupOptions) {
|
||||
o.Debug = debug
|
||||
}
|
||||
}
|
||||
|
||||
func WithDisableMessage(disableMessage bool) AppOption {
|
||||
return func(o *StartupOptions) {
|
||||
o.DisableMessage = disableMessage
|
||||
}
|
||||
}
|
||||
|
||||
func WithAudioDir(audioDir string) AppOption {
|
||||
return func(o *StartupOptions) {
|
||||
o.AudioDir = audioDir
|
||||
}
|
||||
}
|
||||
|
||||
func WithImageDir(imageDir string) AppOption {
|
||||
return func(o *StartupOptions) {
|
||||
o.ImageDir = imageDir
|
||||
}
|
||||
}
|
||||
|
||||
func WithApiKeys(apiKeys []string) AppOption {
|
||||
return func(o *StartupOptions) {
|
||||
o.ApiKeys = apiKeys
|
||||
}
|
||||
}
|
||||
|
||||
func WithMetrics(metrics *LocalAIMetrics) AppOption {
|
||||
return func(o *StartupOptions) {
|
||||
o.Metrics = metrics
|
||||
}
|
||||
}
|
||||
|
||||
func WithLocalAIConfigDir(configDir string) AppOption {
|
||||
return func(o *StartupOptions) {
|
||||
o.LocalAIConfigDir = configDir
|
||||
}
|
||||
}
|
|
@ -1,16 +0,0 @@
|
|||
package schema
|
||||
|
||||
import "time"
|
||||
|
||||
type WhisperSegment struct {
|
||||
Id int `json:"id"`
|
||||
Start time.Duration `json:"start"`
|
||||
End time.Duration `json:"end"`
|
||||
Text string `json:"text"`
|
||||
Tokens []int `json:"tokens"`
|
||||
}
|
||||
|
||||
type WhisperResult struct {
|
||||
Segments []WhisperSegment `json:"segments"`
|
||||
Text string `json:"text"`
|
||||
}
|
|
@ -1,81 +0,0 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
func CreateTempFileFromMultipartFile(file *multipart.FileHeader, tempDir string, tempPattern string) (string, error) {
|
||||
|
||||
f, err := file.Open()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
// Create a temporary file in the requested directory:
|
||||
outputFile, err := os.CreateTemp(tempDir, tempPattern)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer outputFile.Close()
|
||||
|
||||
if _, err := io.Copy(outputFile, f); err != nil {
|
||||
log.Debug().Msgf("Audio file copying error %+v - %+v - err %+v", file.Filename, outputFile, err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
return outputFile.Name(), nil
|
||||
}
|
||||
|
||||
func CreateTempFileFromBase64(base64data string, tempDir string, tempPattern string) (string, error) {
|
||||
if len(base64data) == 0 {
|
||||
return "", fmt.Errorf("base64data empty?")
|
||||
}
|
||||
//base 64 decode the file and write it somewhere
|
||||
// that we will cleanup
|
||||
decoded, err := base64.StdEncoding.DecodeString(base64data)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
// Create a temporary file in the requested directory:
|
||||
outputFile, err := os.CreateTemp(tempDir, tempPattern)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer outputFile.Close()
|
||||
// write the base64 result
|
||||
writer := bufio.NewWriter(outputFile)
|
||||
_, err = writer.Write(decoded)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return outputFile.Name(), nil
|
||||
}
|
||||
|
||||
func CreateTempFileFromUrl(url string, tempDir string, tempPattern string) (string, error) {
|
||||
// Get the data
|
||||
resp, err := http.Get(url)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Create the file
|
||||
out, err := os.CreateTemp(tempDir, tempPattern)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer out.Close()
|
||||
|
||||
// Write the body to file
|
||||
_, err = io.Copy(out, resp.Body)
|
||||
return out.Name(), err
|
||||
}
|
|
@ -3,38 +3,18 @@ package utils
|
|||
import (
|
||||
"crypto/md5"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"hash"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
const (
|
||||
HuggingFacePrefix = "huggingface://"
|
||||
HTTPPrefix = "http://"
|
||||
HTTPSPrefix = "https://"
|
||||
GithubURI = "github:"
|
||||
GithubURI2 = "github://"
|
||||
)
|
||||
|
||||
func getRecognizedURIPrefixes() []string {
|
||||
return []string{
|
||||
HuggingFacePrefix,
|
||||
HTTPPrefix,
|
||||
HTTPSPrefix,
|
||||
GithubURI,
|
||||
GithubURI2,
|
||||
}
|
||||
}
|
||||
|
||||
func GetURI(url string, f func(url string, i []byte) error) error {
|
||||
url = ConvertURL(url)
|
||||
|
||||
|
@ -72,8 +52,20 @@ func GetURI(url string, f func(url string, i []byte) error) error {
|
|||
return f(url, body)
|
||||
}
|
||||
|
||||
const (
|
||||
HuggingFacePrefix = "huggingface://"
|
||||
HTTPPrefix = "http://"
|
||||
HTTPSPrefix = "https://"
|
||||
GithubURI = "github:"
|
||||
GithubURI2 = "github://"
|
||||
)
|
||||
|
||||
func LooksLikeURL(s string) bool {
|
||||
return slices.Contains(getRecognizedURIPrefixes(), s)
|
||||
return strings.HasPrefix(s, HTTPPrefix) ||
|
||||
strings.HasPrefix(s, HTTPSPrefix) ||
|
||||
strings.HasPrefix(s, HuggingFacePrefix) ||
|
||||
strings.HasPrefix(s, GithubURI) ||
|
||||
strings.HasPrefix(s, GithubURI2)
|
||||
}
|
||||
|
||||
func ConvertURL(s string) string {
|
||||
|
@ -249,37 +241,6 @@ func DownloadFile(url string, filePath, sha string, downloadStatus func(string,
|
|||
return nil
|
||||
}
|
||||
|
||||
// this function check if the string is an URL, if it's an URL downloads the image in memory
|
||||
// encodes it in base64 and returns the base64 string
|
||||
func GetBase64Image(s string) (string, error) {
|
||||
if strings.HasPrefix(s, "http") {
|
||||
// download the image
|
||||
resp, err := http.Get(s)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// read the image data into memory
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// encode the image data in base64
|
||||
encoded := base64.StdEncoding.EncodeToString(data)
|
||||
|
||||
// return the base64 string
|
||||
return encoded, nil
|
||||
}
|
||||
|
||||
// if the string instead is prefixed with "data:image/jpeg;base64,", drop it
|
||||
if strings.HasPrefix(s, "data:image/jpeg;base64,") {
|
||||
return strings.ReplaceAll(s, "data:image/jpeg;base64,", ""), nil
|
||||
}
|
||||
return "", fmt.Errorf("not valid string")
|
||||
}
|
||||
|
||||
type progressWriter struct {
|
||||
fileName string
|
||||
total int64
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue