mirror of
https://github.com/mudler/LocalAI.git
synced 2025-05-30 23:44:59 +00:00
* Revert "fix(fncall): fix regression introduced in #1963 (#2048)" This reverts commit6b06d4e0af
. * Revert "fix: action-tmate back to upstream, dead code removal (#2038)" This reverts commitfdec8a9d00
. * Revert "feat(grpc): return consumed token count and update response accordingly (#2035)" This reverts commite843d7df0e
. * Revert "refactor: backend/service split, channel-based llm flow (#1963)" This reverts commiteed5706994
. * feat(grpc): return consumed token count and update response accordingly Fixes: #1920 Signed-off-by: Ettore Di Giacinto <mudler@localai.io> --------- Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
parent
af8c705ecd
commit
af9e5a2d05
52 changed files with 2295 additions and 3065 deletions
|
@ -1,135 +0,0 @@
|
|||
package concurrency
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
// TODO: closeWhenDone bool parameter ::
|
||||
// It currently is experimental, and therefore exists.
|
||||
// Is there ever a situation to use false?
|
||||
|
||||
// This function is used to merge the results of a slice of channels of a specific result type down to a single result channel of a second type.
|
||||
// mappingFn allows the caller to convert from the input type to the output type
|
||||
// if closeWhenDone is set to true, the output channel will be closed when all individual result channels of the slice have been closed - otherwise it will be left open for future use.
|
||||
// The same WaitGroup used to trigger that optional closing is returned for any other synchronization purposes.
|
||||
func SliceOfChannelsRawMerger[IndividualResultType any, OutputResultType any](individualResultChannels []<-chan IndividualResultType, outputChannel chan<- OutputResultType, mappingFn func(IndividualResultType) (OutputResultType, error), closeWhenDone bool) *sync.WaitGroup {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(len(individualResultChannels))
|
||||
mergingFn := func(c <-chan IndividualResultType) {
|
||||
for r := range c {
|
||||
mr, err := mappingFn(r)
|
||||
if err == nil {
|
||||
outputChannel <- mr
|
||||
}
|
||||
}
|
||||
wg.Done()
|
||||
}
|
||||
for _, irc := range individualResultChannels {
|
||||
go mergingFn(irc)
|
||||
}
|
||||
if closeWhenDone {
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(outputChannel)
|
||||
}()
|
||||
}
|
||||
|
||||
return &wg
|
||||
}
|
||||
|
||||
// This function is used to merge the results of a slice of channels of a specific result type down to a single result channel of THE SAME TYPE.
|
||||
// if closeWhenDone is set to true, the output channel will be closed when all individual result channels of the slice have been closed - otherwise it will be left open for future use.
|
||||
// The same WaitGroup used to trigger that optional closing is returned for any other synchronization purposes.
|
||||
func SliceOfChannelsRawMergerWithoutMapping[ResultType any](individualResultsChannels []<-chan ResultType, outputChannel chan<- ResultType, closeWhenDone bool) *sync.WaitGroup {
|
||||
return SliceOfChannelsRawMerger(individualResultsChannels, outputChannel, func(v ResultType) (ResultType, error) { return v, nil }, closeWhenDone)
|
||||
}
|
||||
|
||||
// This function is used to merge the results of a slice of channels of a specific result type down to a single succcess result channel of a second type, and an error channel
|
||||
// mappingFn allows the caller to convert from the input type to the output type
|
||||
// This variant is designed to be aware of concurrency.ErrorOr[T], splitting successes from failures.
|
||||
// if closeWhenDone is set to true, the output channel will be closed when all individual result channels of the slice have been closed - otherwise it will be left open for future use.
|
||||
// The same WaitGroup used to trigger that optional closing is returned for any other synchronization purposes.
|
||||
func SliceOfChannelsMergerWithErrors[IndividualResultType any, OutputResultType any](individualResultChannels []<-chan ErrorOr[IndividualResultType], successChannel chan<- OutputResultType, errorChannel chan<- error, mappingFn func(IndividualResultType) (OutputResultType, error), closeWhenDone bool) *sync.WaitGroup {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(len(individualResultChannels))
|
||||
mergingFn := func(c <-chan ErrorOr[IndividualResultType]) {
|
||||
for r := range c {
|
||||
if r.Error != nil {
|
||||
errorChannel <- r.Error
|
||||
} else {
|
||||
mv, err := mappingFn(r.Value)
|
||||
if err != nil {
|
||||
errorChannel <- err
|
||||
} else {
|
||||
successChannel <- mv
|
||||
}
|
||||
}
|
||||
}
|
||||
wg.Done()
|
||||
}
|
||||
for _, irc := range individualResultChannels {
|
||||
go mergingFn(irc)
|
||||
}
|
||||
if closeWhenDone {
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(successChannel)
|
||||
close(errorChannel)
|
||||
}()
|
||||
}
|
||||
return &wg
|
||||
}
|
||||
|
||||
// This function is used to reduce down the results of a slice of channels of a specific result type down to a single result value of a second type.
|
||||
// reducerFn allows the caller to convert from the input type to the output type
|
||||
// if closeWhenDone is set to true, the output channel will be closed when all individual result channels of the slice have been closed - otherwise it will be left open for future use.
|
||||
// The same WaitGroup used to trigger that optional closing is returned for any other synchronization purposes.
|
||||
func SliceOfChannelsReducer[InputResultType any, OutputResultType any](individualResultsChannels []<-chan InputResultType, outputChannel chan<- OutputResultType,
|
||||
reducerFn func(iv InputResultType, ov OutputResultType) OutputResultType, initialValue OutputResultType, closeWhenDone bool) (wg *sync.WaitGroup) {
|
||||
wg = &sync.WaitGroup{}
|
||||
wg.Add(len(individualResultsChannels))
|
||||
reduceLock := sync.Mutex{}
|
||||
reducingFn := func(c <-chan InputResultType) {
|
||||
for iv := range c {
|
||||
reduceLock.Lock()
|
||||
initialValue = reducerFn(iv, initialValue)
|
||||
reduceLock.Unlock()
|
||||
}
|
||||
wg.Done()
|
||||
}
|
||||
for _, irc := range individualResultsChannels {
|
||||
go reducingFn(irc)
|
||||
}
|
||||
go func() {
|
||||
wg.Wait()
|
||||
outputChannel <- initialValue
|
||||
if closeWhenDone {
|
||||
close(outputChannel)
|
||||
}
|
||||
}()
|
||||
return wg
|
||||
}
|
||||
|
||||
// This function is primarily designed to be used in combination with the above utility functions.
|
||||
// A slice of input result channels of a specific type is provided, along with a function to map those values to another type
|
||||
// A slice of output result channels is returned, where each value is mapped as it comes in.
|
||||
// The order of the slice will be retained.
|
||||
func SliceOfChannelsTransformer[InputResultType any, OutputResultType any](inputChanels []<-chan InputResultType, mappingFn func(v InputResultType) OutputResultType) (outputChannels []<-chan OutputResultType) {
|
||||
rawOutputChannels := make([]<-chan OutputResultType, len(inputChanels))
|
||||
|
||||
transformingFn := func(ic <-chan InputResultType, oc chan OutputResultType) {
|
||||
for iv := range ic {
|
||||
oc <- mappingFn(iv)
|
||||
}
|
||||
close(oc)
|
||||
}
|
||||
|
||||
for ci, c := range inputChanels {
|
||||
roc := make(chan OutputResultType)
|
||||
go transformingFn(c, roc)
|
||||
rawOutputChannels[ci] = roc
|
||||
}
|
||||
|
||||
outputChannels = rawOutputChannels
|
||||
return
|
||||
}
|
|
@ -1,101 +0,0 @@
|
|||
package concurrency_test
|
||||
|
||||
// TODO: noramlly, these go in utils_tests, right? Why does this cause problems only in pkg/utils?
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
|
||||
. "github.com/go-skynet/LocalAI/pkg/concurrency"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("utils/concurrency tests", func() {
|
||||
It("SliceOfChannelsReducer works", func() {
|
||||
individualResultsChannels := []<-chan int{}
|
||||
initialValue := 0
|
||||
for i := 0; i < 3; i++ {
|
||||
c := make(chan int)
|
||||
go func(i int, c chan int) {
|
||||
for ii := 1; ii < 4; ii++ {
|
||||
c <- (i * ii)
|
||||
}
|
||||
close(c)
|
||||
}(i, c)
|
||||
individualResultsChannels = append(individualResultsChannels, c)
|
||||
}
|
||||
Expect(len(individualResultsChannels)).To(Equal(3))
|
||||
finalResultChannel := make(chan int)
|
||||
wg := SliceOfChannelsReducer[int, int](individualResultsChannels, finalResultChannel, func(input int, val int) int {
|
||||
return val + input
|
||||
}, initialValue, true)
|
||||
|
||||
Expect(wg).ToNot(BeNil())
|
||||
|
||||
result := <-finalResultChannel
|
||||
|
||||
Expect(result).ToNot(Equal(0))
|
||||
Expect(result).To(Equal(18))
|
||||
})
|
||||
|
||||
It("SliceOfChannelsRawMergerWithoutMapping works", func() {
|
||||
individualResultsChannels := []<-chan int{}
|
||||
for i := 0; i < 3; i++ {
|
||||
c := make(chan int)
|
||||
go func(i int, c chan int) {
|
||||
for ii := 1; ii < 4; ii++ {
|
||||
c <- (i * ii)
|
||||
}
|
||||
close(c)
|
||||
}(i, c)
|
||||
individualResultsChannels = append(individualResultsChannels, c)
|
||||
}
|
||||
Expect(len(individualResultsChannels)).To(Equal(3))
|
||||
outputChannel := make(chan int)
|
||||
wg := SliceOfChannelsRawMergerWithoutMapping(individualResultsChannels, outputChannel, true)
|
||||
Expect(wg).ToNot(BeNil())
|
||||
outputSlice := []int{}
|
||||
for v := range outputChannel {
|
||||
outputSlice = append(outputSlice, v)
|
||||
}
|
||||
Expect(len(outputSlice)).To(Equal(9))
|
||||
slices.Sort(outputSlice)
|
||||
Expect(outputSlice[0]).To(BeZero())
|
||||
Expect(outputSlice[3]).To(Equal(1))
|
||||
Expect(outputSlice[8]).To(Equal(6))
|
||||
})
|
||||
|
||||
It("SliceOfChannelsTransformer works", func() {
|
||||
individualResultsChannels := []<-chan int{}
|
||||
for i := 0; i < 3; i++ {
|
||||
c := make(chan int)
|
||||
go func(i int, c chan int) {
|
||||
for ii := 1; ii < 4; ii++ {
|
||||
c <- (i * ii)
|
||||
}
|
||||
close(c)
|
||||
}(i, c)
|
||||
individualResultsChannels = append(individualResultsChannels, c)
|
||||
}
|
||||
Expect(len(individualResultsChannels)).To(Equal(3))
|
||||
mappingFn := func(i int) string {
|
||||
return fmt.Sprintf("$%d", i)
|
||||
}
|
||||
|
||||
outputChannels := SliceOfChannelsTransformer(individualResultsChannels, mappingFn)
|
||||
Expect(len(outputChannels)).To(Equal(3))
|
||||
rSlice := []string{}
|
||||
for ii := 1; ii < 4; ii++ {
|
||||
for i := 0; i < 3; i++ {
|
||||
res := <-outputChannels[i]
|
||||
rSlice = append(rSlice, res)
|
||||
}
|
||||
}
|
||||
slices.Sort(rSlice)
|
||||
Expect(rSlice[0]).To(Equal("$0"))
|
||||
Expect(rSlice[3]).To(Equal("$1"))
|
||||
Expect(rSlice[8]).To(Equal("$6"))
|
||||
})
|
||||
})
|
|
@ -1,6 +0,0 @@
|
|||
package concurrency
|
||||
|
||||
type ErrorOr[T any] struct {
|
||||
Value T
|
||||
Error error
|
||||
}
|
|
@ -41,7 +41,7 @@ type Backend interface {
|
|||
PredictStream(ctx context.Context, in *pb.PredictOptions, f func(s []byte), opts ...grpc.CallOption) error
|
||||
GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...grpc.CallOption) (*pb.Result, error)
|
||||
TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOption) (*pb.Result, error)
|
||||
AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.TranscriptionResult, error)
|
||||
AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.Result, error)
|
||||
TokenizeString(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.TokenizationResponse, error)
|
||||
Status(ctx context.Context) (*pb.StatusResponse, error)
|
||||
|
||||
|
|
|
@ -53,8 +53,8 @@ func (llm *Base) GenerateImage(*pb.GenerateImageRequest) error {
|
|||
return fmt.Errorf("unimplemented")
|
||||
}
|
||||
|
||||
func (llm *Base) AudioTranscription(*pb.TranscriptRequest) (schema.TranscriptionResult, error) {
|
||||
return schema.TranscriptionResult{}, fmt.Errorf("unimplemented")
|
||||
func (llm *Base) AudioTranscription(*pb.TranscriptRequest) (schema.Result, error) {
|
||||
return schema.Result{}, fmt.Errorf("unimplemented")
|
||||
}
|
||||
|
||||
func (llm *Base) TTS(*pb.TTSRequest) error {
|
||||
|
|
|
@ -210,7 +210,7 @@ func (c *Client) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOp
|
|||
return client.TTS(ctx, in, opts...)
|
||||
}
|
||||
|
||||
func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.TranscriptionResult, error) {
|
||||
func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.Result, error) {
|
||||
if !c.parallel {
|
||||
c.opMutex.Lock()
|
||||
defer c.opMutex.Unlock()
|
||||
|
@ -231,7 +231,7 @@ func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptReques
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tresult := &schema.TranscriptionResult{}
|
||||
tresult := &schema.Result{}
|
||||
for _, s := range res.Segments {
|
||||
tks := []int{}
|
||||
for _, t := range s.Tokens {
|
||||
|
|
|
@ -53,12 +53,12 @@ func (e *embedBackend) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.
|
|||
return e.s.TTS(ctx, in)
|
||||
}
|
||||
|
||||
func (e *embedBackend) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.TranscriptionResult, error) {
|
||||
func (e *embedBackend) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.Result, error) {
|
||||
r, err := e.s.AudioTranscription(ctx, in)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tr := &schema.TranscriptionResult{}
|
||||
tr := &schema.Result{}
|
||||
for _, s := range r.Segments {
|
||||
var tks []int
|
||||
for _, t := range s.Tokens {
|
||||
|
|
|
@ -15,7 +15,7 @@ type LLM interface {
|
|||
Load(*pb.ModelOptions) error
|
||||
Embeddings(*pb.PredictOptions) ([]float32, error)
|
||||
GenerateImage(*pb.GenerateImageRequest) error
|
||||
AudioTranscription(*pb.TranscriptRequest) (schema.TranscriptionResult, error)
|
||||
AudioTranscription(*pb.TranscriptRequest) (schema.Result, error)
|
||||
TTS(*pb.TTSRequest) error
|
||||
TokenizeString(*pb.PredictOptions) (pb.TokenizationResponse, error)
|
||||
Status() (pb.StatusResponse, error)
|
||||
|
|
|
@ -81,7 +81,7 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string
|
|||
if _, err := os.Stat(uri); err == nil {
|
||||
serverAddress, err := getFreeAddress()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("%s failed allocating free ports: %s", backend, err.Error())
|
||||
return "", fmt.Errorf("failed allocating free ports: %s", err.Error())
|
||||
}
|
||||
// Make sure the process is executable
|
||||
if err := ml.startProcess(uri, o.model, serverAddress); err != nil {
|
||||
|
@ -134,7 +134,7 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string
|
|||
|
||||
if !ready {
|
||||
log.Debug().Msgf("GRPC Service NOT ready")
|
||||
return "", fmt.Errorf("%s grpc service not ready", backend)
|
||||
return "", fmt.Errorf("grpc service not ready")
|
||||
}
|
||||
|
||||
options := *o.gRPCOptions
|
||||
|
@ -145,10 +145,10 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string
|
|||
|
||||
res, err := client.GRPC(o.parallelRequests, ml.wd).LoadModel(o.context, &options)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("\"%s\" could not load model: %w", backend, err)
|
||||
return "", fmt.Errorf("could not load model: %w", err)
|
||||
}
|
||||
if !res.Success {
|
||||
return "", fmt.Errorf("\"%s\" could not load model (no success): %s", backend, res.Message)
|
||||
return "", fmt.Errorf("could not load model (no success): %s", res.Message)
|
||||
}
|
||||
|
||||
return client, nil
|
||||
|
|
85
pkg/startup/model_preload.go
Normal file
85
pkg/startup/model_preload.go
Normal file
|
@ -0,0 +1,85 @@
|
|||
package startup
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/go-skynet/LocalAI/embedded"
|
||||
"github.com/go-skynet/LocalAI/pkg/downloader"
|
||||
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// PreloadModelsConfigurations will preload models from the given list of URLs
|
||||
// It will download the model if it is not already present in the model path
|
||||
// It will also try to resolve if the model is an embedded model YAML configuration
|
||||
func PreloadModelsConfigurations(modelLibraryURL string, modelPath string, models ...string) {
|
||||
for _, url := range models {
|
||||
|
||||
// As a best effort, try to resolve the model from the remote library
|
||||
// if it's not resolved we try with the other method below
|
||||
if modelLibraryURL != "" {
|
||||
lib, err := embedded.GetRemoteLibraryShorteners(modelLibraryURL)
|
||||
if err == nil {
|
||||
if lib[url] != "" {
|
||||
log.Debug().Msgf("[startup] model configuration is defined remotely: %s (%s)", url, lib[url])
|
||||
url = lib[url]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
url = embedded.ModelShortURL(url)
|
||||
switch {
|
||||
case embedded.ExistsInModelsLibrary(url):
|
||||
modelYAML, err := embedded.ResolveContent(url)
|
||||
// If we resolve something, just save it to disk and continue
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("error resolving model content")
|
||||
continue
|
||||
}
|
||||
|
||||
log.Debug().Msgf("[startup] resolved embedded model: %s", url)
|
||||
md5Name := utils.MD5(url)
|
||||
modelDefinitionFilePath := filepath.Join(modelPath, md5Name) + ".yaml"
|
||||
if err := os.WriteFile(modelDefinitionFilePath, modelYAML, os.ModePerm); err != nil {
|
||||
log.Error().Err(err).Str("filepath", modelDefinitionFilePath).Msg("error writing model definition")
|
||||
}
|
||||
case downloader.LooksLikeURL(url):
|
||||
log.Debug().Msgf("[startup] resolved model to download: %s", url)
|
||||
|
||||
// md5 of model name
|
||||
md5Name := utils.MD5(url)
|
||||
|
||||
// check if file exists
|
||||
if _, err := os.Stat(filepath.Join(modelPath, md5Name)); errors.Is(err, os.ErrNotExist) {
|
||||
modelDefinitionFilePath := filepath.Join(modelPath, md5Name) + ".yaml"
|
||||
err := downloader.DownloadFile(url, modelDefinitionFilePath, "", func(fileName, current, total string, percent float64) {
|
||||
utils.DisplayDownloadFunction(fileName, current, total, percent)
|
||||
})
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("url", url).Str("filepath", modelDefinitionFilePath).Msg("error downloading model")
|
||||
}
|
||||
}
|
||||
default:
|
||||
if _, err := os.Stat(url); err == nil {
|
||||
log.Debug().Msgf("[startup] resolved local model: %s", url)
|
||||
// copy to modelPath
|
||||
md5Name := utils.MD5(url)
|
||||
|
||||
modelYAML, err := os.ReadFile(url)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("filepath", url).Msg("error reading model definition")
|
||||
continue
|
||||
}
|
||||
|
||||
modelDefinitionFilePath := filepath.Join(modelPath, md5Name) + ".yaml"
|
||||
if err := os.WriteFile(modelDefinitionFilePath, modelYAML, os.ModePerm); err != nil {
|
||||
log.Error().Err(err).Str("filepath", modelDefinitionFilePath).Msg("error loading model: %s")
|
||||
}
|
||||
} else {
|
||||
log.Warn().Msgf("[startup] failed resolving model '%s'", url)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
82
pkg/startup/model_preload_test.go
Normal file
82
pkg/startup/model_preload_test.go
Normal file
|
@ -0,0 +1,82 @@
|
|||
package startup_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
. "github.com/go-skynet/LocalAI/pkg/startup"
|
||||
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Preload test", func() {
|
||||
|
||||
Context("Preloading from strings", func() {
|
||||
It("loads from remote url", func() {
|
||||
tmpdir, err := os.MkdirTemp("", "")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
libraryURL := "https://raw.githubusercontent.com/mudler/LocalAI/master/embedded/model_library.yaml"
|
||||
fileName := fmt.Sprintf("%s.yaml", "1701d57f28d47552516c2b6ecc3cc719")
|
||||
|
||||
PreloadModelsConfigurations(libraryURL, tmpdir, "phi-2")
|
||||
|
||||
resultFile := filepath.Join(tmpdir, fileName)
|
||||
|
||||
content, err := os.ReadFile(resultFile)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
Expect(string(content)).To(ContainSubstring("name: phi-2"))
|
||||
})
|
||||
|
||||
It("loads from embedded full-urls", func() {
|
||||
tmpdir, err := os.MkdirTemp("", "")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
url := "https://raw.githubusercontent.com/mudler/LocalAI/master/examples/configurations/phi-2.yaml"
|
||||
fileName := fmt.Sprintf("%s.yaml", utils.MD5(url))
|
||||
|
||||
PreloadModelsConfigurations("", tmpdir, url)
|
||||
|
||||
resultFile := filepath.Join(tmpdir, fileName)
|
||||
|
||||
content, err := os.ReadFile(resultFile)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
Expect(string(content)).To(ContainSubstring("name: phi-2"))
|
||||
})
|
||||
It("loads from embedded short-urls", func() {
|
||||
tmpdir, err := os.MkdirTemp("", "")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
url := "phi-2"
|
||||
|
||||
PreloadModelsConfigurations("", tmpdir, url)
|
||||
|
||||
entry, err := os.ReadDir(tmpdir)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(entry).To(HaveLen(1))
|
||||
resultFile := entry[0].Name()
|
||||
|
||||
content, err := os.ReadFile(filepath.Join(tmpdir, resultFile))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
Expect(string(content)).To(ContainSubstring("name: phi-2"))
|
||||
})
|
||||
It("loads from embedded models", func() {
|
||||
tmpdir, err := os.MkdirTemp("", "")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
url := "mistral-openorca"
|
||||
fileName := fmt.Sprintf("%s.yaml", utils.MD5(url))
|
||||
|
||||
PreloadModelsConfigurations("", tmpdir, url)
|
||||
|
||||
resultFile := filepath.Join(tmpdir, fileName)
|
||||
|
||||
content, err := os.ReadFile(resultFile)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
Expect(string(content)).To(ContainSubstring("name: mistral-openorca"))
|
||||
})
|
||||
})
|
||||
})
|
|
@ -1,50 +0,0 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var base64DownloadClient http.Client = http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
// this function check if the string is an URL, if it's an URL downloads the image in memory
|
||||
// encodes it in base64 and returns the base64 string
|
||||
|
||||
// This may look weird down in pkg/utils while it is currently only used in core/config
|
||||
//
|
||||
// but I believe it may be useful for MQTT as well in the near future, so I'm
|
||||
// extracting it while I'm thinking of it.
|
||||
func GetImageURLAsBase64(s string) (string, error) {
|
||||
if strings.HasPrefix(s, "http") {
|
||||
// download the image
|
||||
resp, err := base64DownloadClient.Get(s)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// read the image data into memory
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// encode the image data in base64
|
||||
encoded := base64.StdEncoding.EncodeToString(data)
|
||||
|
||||
// return the base64 string
|
||||
return encoded, nil
|
||||
}
|
||||
|
||||
// if the string instead is prefixed with "data:image/jpeg;base64,", drop it
|
||||
if strings.HasPrefix(s, "data:image/jpeg;base64,") {
|
||||
return strings.ReplaceAll(s, "data:image/jpeg;base64,", ""), nil
|
||||
}
|
||||
return "", fmt.Errorf("not valid string")
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue