refactor(application): introduce application global state (#2072)

* start breaking up the giant channel refactor now that it's better understood - easier to merge bites

Signed-off-by: Dave Lee <dave@gray101.com>

* add concurrency and base64 back in, along with new base64 tests.

Signed-off-by: Dave Lee <dave@gray101.com>

* Automatic rename of whisper.go's Result to TranscriptResult

Signed-off-by: Dave Lee <dave@gray101.com>

* remove pkg/concurrency - significant changes coming in split 2

Signed-off-by: Dave Lee <dave@gray101.com>

* fix comments

Signed-off-by: Dave Lee <dave@gray101.com>

* add list_model service as another low-risk service to get it out of the way

Signed-off-by: Dave Lee <dave@gray101.com>

* split backend config loader into seperate file from the actual config struct. No changes yet, just reduce cognative load with smaller files of logical blocks

Signed-off-by: Dave Lee <dave@gray101.com>

* rename state.go ==> application.go

Signed-off-by: Dave Lee <dave@gray101.com>

* fix lost import?

Signed-off-by: Dave Lee <dave@gray101.com>

---------

Signed-off-by: Dave Lee <dave@gray101.com>
This commit is contained in:
Dave 2024-04-29 13:42:37 -04:00 committed by GitHub
parent 147440b39b
commit c4f958e11b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
22 changed files with 590 additions and 422 deletions

View file

@ -41,7 +41,7 @@ type Backend interface {
PredictStream(ctx context.Context, in *pb.PredictOptions, f func(s []byte), opts ...grpc.CallOption) error
GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...grpc.CallOption) (*pb.Result, error)
TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOption) (*pb.Result, error)
AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.Result, error)
AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.TranscriptionResult, error)
TokenizeString(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.TokenizationResponse, error)
Status(ctx context.Context) (*pb.StatusResponse, error)

View file

@ -53,8 +53,8 @@ func (llm *Base) GenerateImage(*pb.GenerateImageRequest) error {
return fmt.Errorf("unimplemented")
}
func (llm *Base) AudioTranscription(*pb.TranscriptRequest) (schema.Result, error) {
return schema.Result{}, fmt.Errorf("unimplemented")
func (llm *Base) AudioTranscription(*pb.TranscriptRequest) (schema.TranscriptionResult, error) {
return schema.TranscriptionResult{}, fmt.Errorf("unimplemented")
}
func (llm *Base) TTS(*pb.TTSRequest) error {

View file

@ -210,7 +210,7 @@ func (c *Client) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOp
return client.TTS(ctx, in, opts...)
}
func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.Result, error) {
func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.TranscriptionResult, error) {
if !c.parallel {
c.opMutex.Lock()
defer c.opMutex.Unlock()
@ -231,7 +231,7 @@ func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptReques
if err != nil {
return nil, err
}
tresult := &schema.Result{}
tresult := &schema.TranscriptionResult{}
for _, s := range res.Segments {
tks := []int{}
for _, t := range s.Tokens {

View file

@ -53,12 +53,12 @@ func (e *embedBackend) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.
return e.s.TTS(ctx, in)
}
func (e *embedBackend) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.Result, error) {
func (e *embedBackend) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.TranscriptionResult, error) {
r, err := e.s.AudioTranscription(ctx, in)
if err != nil {
return nil, err
}
tr := &schema.Result{}
tr := &schema.TranscriptionResult{}
for _, s := range r.Segments {
var tks []int
for _, t := range s.Tokens {

View file

@ -15,7 +15,7 @@ type LLM interface {
Load(*pb.ModelOptions) error
Embeddings(*pb.PredictOptions) ([]float32, error)
GenerateImage(*pb.GenerateImageRequest) error
AudioTranscription(*pb.TranscriptRequest) (schema.Result, error)
AudioTranscription(*pb.TranscriptRequest) (schema.TranscriptionResult, error)
TTS(*pb.TTSRequest) error
TokenizeString(*pb.PredictOptions) (pb.TokenizationResponse, error)
Status() (pb.StatusResponse, error)

50
pkg/utils/base64.go Normal file
View file

@ -0,0 +1,50 @@
package utils
import (
"encoding/base64"
"fmt"
"io"
"net/http"
"strings"
"time"
)
var base64DownloadClient http.Client = http.Client{
Timeout: 30 * time.Second,
}
// this function check if the string is an URL, if it's an URL downloads the image in memory
// encodes it in base64 and returns the base64 string
// This may look weird down in pkg/utils while it is currently only used in core/config
//
// but I believe it may be useful for MQTT as well in the near future, so I'm
// extracting it while I'm thinking of it.
func GetImageURLAsBase64(s string) (string, error) {
if strings.HasPrefix(s, "http") {
// download the image
resp, err := base64DownloadClient.Get(s)
if err != nil {
return "", err
}
defer resp.Body.Close()
// read the image data into memory
data, err := io.ReadAll(resp.Body)
if err != nil {
return "", err
}
// encode the image data in base64
encoded := base64.StdEncoding.EncodeToString(data)
// return the base64 string
return encoded, nil
}
// if the string instead is prefixed with "data:image/jpeg;base64,", drop it
if strings.HasPrefix(s, "data:image/jpeg;base64,") {
return strings.ReplaceAll(s, "data:image/jpeg;base64,", ""), nil
}
return "", fmt.Errorf("not valid string")
}

31
pkg/utils/base64_test.go Normal file
View file

@ -0,0 +1,31 @@
package utils_test
import (
. "github.com/go-skynet/LocalAI/pkg/utils"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("utils/base64 tests", func() {
It("GetImageURLAsBase64 can strip data url prefixes", func() {
// This one doesn't actually _care_ that it's base64, so feed "bad" data in this test in order to catch a change in that behavior for informational purposes.
input := "data:image/jpeg;base64,FOO"
b64, err := GetImageURLAsBase64(input)
Expect(err).To(BeNil())
Expect(b64).To(Equal("FOO"))
})
It("GetImageURLAsBase64 returns an error for bogus data", func() {
input := "FOO"
b64, err := GetImageURLAsBase64(input)
Expect(b64).To(Equal(""))
Expect(err).ToNot(BeNil())
Expect(err).To(MatchError("not valid string"))
})
It("GetImageURLAsBase64 can actually download images and calculates something", func() {
// This test doesn't actually _check_ the results at this time, which is bad, but there wasn't a test at all before...
input := "https://upload.wikimedia.org/wikipedia/en/2/29/Wargames.jpg"
b64, err := GetImageURLAsBase64(input)
Expect(err).To(BeNil())
Expect(b64).ToNot(BeNil())
})
})