mirror of
https://github.com/mudler/LocalAI.git
synced 2025-05-20 10:35:01 +00:00
refactor: backend/service split, channel-based llm flow (#1963)
Refactor: channel based llm flow and services split --------- Signed-off-by: Dave Lee <dave@gray101.com>
This commit is contained in:
parent
1981154f49
commit
eed5706994
52 changed files with 3064 additions and 2279 deletions
|
@ -1,18 +1,252 @@
|
|||
package backend
|
||||
|
||||
import (
|
||||
"github.com/go-skynet/LocalAI/core/config"
|
||||
"bufio"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-skynet/LocalAI/core/config"
|
||||
"github.com/go-skynet/LocalAI/core/schema"
|
||||
"github.com/google/uuid"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/concurrency"
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/go-skynet/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() error, error) {
|
||||
type ImageGenerationBackendService struct {
|
||||
ml *model.ModelLoader
|
||||
bcl *config.BackendConfigLoader
|
||||
appConfig *config.ApplicationConfig
|
||||
BaseUrlForGeneratedImages string
|
||||
}
|
||||
|
||||
func NewImageGenerationBackendService(ml *model.ModelLoader, bcl *config.BackendConfigLoader, appConfig *config.ApplicationConfig) *ImageGenerationBackendService {
|
||||
return &ImageGenerationBackendService{
|
||||
ml: ml,
|
||||
bcl: bcl,
|
||||
appConfig: appConfig,
|
||||
}
|
||||
}
|
||||
|
||||
func (igbs *ImageGenerationBackendService) GenerateImage(request *schema.OpenAIRequest) <-chan concurrency.ErrorOr[*schema.OpenAIResponse] {
|
||||
resultChannel := make(chan concurrency.ErrorOr[*schema.OpenAIResponse])
|
||||
go func(request *schema.OpenAIRequest) {
|
||||
bc, request, err := igbs.bcl.LoadBackendConfigForModelAndOpenAIRequest(request.Model, request, igbs.appConfig)
|
||||
if err != nil {
|
||||
resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err}
|
||||
close(resultChannel)
|
||||
return
|
||||
}
|
||||
|
||||
src := ""
|
||||
if request.File != "" {
|
||||
|
||||
var fileData []byte
|
||||
// check if input.File is an URL, if so download it and save it
|
||||
// to a temporary file
|
||||
if strings.HasPrefix(request.File, "http://") || strings.HasPrefix(request.File, "https://") {
|
||||
out, err := downloadFile(request.File)
|
||||
if err != nil {
|
||||
resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: fmt.Errorf("failed downloading file:%w", err)}
|
||||
close(resultChannel)
|
||||
return
|
||||
}
|
||||
defer os.RemoveAll(out)
|
||||
|
||||
fileData, err = os.ReadFile(out)
|
||||
if err != nil {
|
||||
resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: fmt.Errorf("failed reading file:%w", err)}
|
||||
close(resultChannel)
|
||||
return
|
||||
}
|
||||
|
||||
} else {
|
||||
// base 64 decode the file and write it somewhere
|
||||
// that we will cleanup
|
||||
fileData, err = base64.StdEncoding.DecodeString(request.File)
|
||||
if err != nil {
|
||||
resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err}
|
||||
close(resultChannel)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Create a temporary file
|
||||
outputFile, err := os.CreateTemp(igbs.appConfig.ImageDir, "b64")
|
||||
if err != nil {
|
||||
resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err}
|
||||
close(resultChannel)
|
||||
return
|
||||
}
|
||||
// write the base64 result
|
||||
writer := bufio.NewWriter(outputFile)
|
||||
_, err = writer.Write(fileData)
|
||||
if err != nil {
|
||||
outputFile.Close()
|
||||
resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err}
|
||||
close(resultChannel)
|
||||
return
|
||||
}
|
||||
outputFile.Close()
|
||||
src = outputFile.Name()
|
||||
defer os.RemoveAll(src)
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Parameter Config: %+v", bc)
|
||||
|
||||
switch bc.Backend {
|
||||
case "stablediffusion":
|
||||
bc.Backend = model.StableDiffusionBackend
|
||||
case "tinydream":
|
||||
bc.Backend = model.TinyDreamBackend
|
||||
case "":
|
||||
bc.Backend = model.StableDiffusionBackend
|
||||
if bc.Model == "" {
|
||||
bc.Model = "stablediffusion_assets" // TODO: check?
|
||||
}
|
||||
}
|
||||
|
||||
sizeParts := strings.Split(request.Size, "x")
|
||||
if len(sizeParts) != 2 {
|
||||
resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: fmt.Errorf("invalid value for 'size'")}
|
||||
close(resultChannel)
|
||||
return
|
||||
}
|
||||
width, err := strconv.Atoi(sizeParts[0])
|
||||
if err != nil {
|
||||
resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: fmt.Errorf("invalid value for 'size'")}
|
||||
close(resultChannel)
|
||||
return
|
||||
}
|
||||
height, err := strconv.Atoi(sizeParts[1])
|
||||
if err != nil {
|
||||
resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: fmt.Errorf("invalid value for 'size'")}
|
||||
close(resultChannel)
|
||||
return
|
||||
}
|
||||
|
||||
b64JSON := false
|
||||
if request.ResponseFormat.Type == "b64_json" {
|
||||
b64JSON = true
|
||||
}
|
||||
// src and clip_skip
|
||||
var result []schema.Item
|
||||
for _, i := range bc.PromptStrings {
|
||||
n := request.N
|
||||
if request.N == 0 {
|
||||
n = 1
|
||||
}
|
||||
for j := 0; j < n; j++ {
|
||||
prompts := strings.Split(i, "|")
|
||||
positive_prompt := prompts[0]
|
||||
negative_prompt := ""
|
||||
if len(prompts) > 1 {
|
||||
negative_prompt = prompts[1]
|
||||
}
|
||||
|
||||
mode := 0
|
||||
step := bc.Step
|
||||
if step == 0 {
|
||||
step = 15
|
||||
}
|
||||
|
||||
if request.Mode != 0 {
|
||||
mode = request.Mode
|
||||
}
|
||||
|
||||
if request.Step != 0 {
|
||||
step = request.Step
|
||||
}
|
||||
|
||||
tempDir := ""
|
||||
if !b64JSON {
|
||||
tempDir = igbs.appConfig.ImageDir
|
||||
}
|
||||
// Create a temporary file
|
||||
outputFile, err := os.CreateTemp(tempDir, "b64")
|
||||
if err != nil {
|
||||
resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err}
|
||||
close(resultChannel)
|
||||
return
|
||||
}
|
||||
outputFile.Close()
|
||||
output := outputFile.Name() + ".png"
|
||||
// Rename the temporary file
|
||||
err = os.Rename(outputFile.Name(), output)
|
||||
if err != nil {
|
||||
resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err}
|
||||
close(resultChannel)
|
||||
return
|
||||
}
|
||||
|
||||
if request.Seed == nil {
|
||||
zVal := 0 // Idiomatic way to do this? Actually needed?
|
||||
request.Seed = &zVal
|
||||
}
|
||||
|
||||
fn, err := imageGeneration(height, width, mode, step, *request.Seed, positive_prompt, negative_prompt, src, output, igbs.ml, bc, igbs.appConfig)
|
||||
if err != nil {
|
||||
resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err}
|
||||
close(resultChannel)
|
||||
return
|
||||
}
|
||||
if err := fn(); err != nil {
|
||||
resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err}
|
||||
close(resultChannel)
|
||||
return
|
||||
}
|
||||
|
||||
item := &schema.Item{}
|
||||
|
||||
if b64JSON {
|
||||
defer os.RemoveAll(output)
|
||||
data, err := os.ReadFile(output)
|
||||
if err != nil {
|
||||
resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err}
|
||||
close(resultChannel)
|
||||
return
|
||||
}
|
||||
item.B64JSON = base64.StdEncoding.EncodeToString(data)
|
||||
} else {
|
||||
base := filepath.Base(output)
|
||||
item.URL = igbs.BaseUrlForGeneratedImages + base
|
||||
}
|
||||
|
||||
result = append(result, *item)
|
||||
}
|
||||
}
|
||||
|
||||
id := uuid.New().String()
|
||||
created := int(time.Now().Unix())
|
||||
resp := &schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Data: result,
|
||||
}
|
||||
resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Value: resp}
|
||||
close(resultChannel)
|
||||
}(request)
|
||||
return resultChannel
|
||||
}
|
||||
|
||||
func imageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, backendConfig *config.BackendConfig, appConfig *config.ApplicationConfig) (func() error, error) {
|
||||
|
||||
threads := backendConfig.Threads
|
||||
if *threads == 0 && appConfig.Threads != 0 {
|
||||
threads = &appConfig.Threads
|
||||
}
|
||||
|
||||
gRPCOpts := gRPCModelOpts(backendConfig)
|
||||
|
||||
opts := modelOpts(backendConfig, appConfig, []model.Option{
|
||||
model.WithBackendString(backendConfig.Backend),
|
||||
model.WithAssetDir(appConfig.AssetsDestination),
|
||||
|
@ -50,3 +284,24 @@ func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negat
|
|||
|
||||
return fn, nil
|
||||
}
|
||||
|
||||
// TODO: Replace this function with pkg/downloader - no reason to have a (crappier) bespoke download file fn here, but get things working before that change.
|
||||
// downloadFile fetches url with an HTTP GET and writes the response body to a
// newly created temporary file (name pattern "image" in the default temp
// directory). It returns the path of that file; the caller is responsible for
// removing it.
//
// A response whose status is outside the 2xx range is reported as an error
// instead of being saved as if it were the requested file. On any failure the
// temporary file is removed and the returned path is empty.
func downloadFile(url string) (string, error) {
	// Get the data
	resp, err := http.Get(url)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	// http.Get only fails on transport errors; an error page still arrives
	// with a nil error, so the status has to be checked explicitly.
	if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
		return "", fmt.Errorf("downloading %q: unexpected HTTP status %s", url, resp.Status)
	}

	// Create the file
	out, err := os.CreateTemp("", "image")
	if err != nil {
		return "", err
	}

	// Write the body to file. Close is checked as well because a failed close
	// can mean the data never reached the disk.
	_, err = io.Copy(out, resp.Body)
	if closeErr := out.Close(); err == nil {
		err = closeErr
	}
	if err != nil {
		os.Remove(out.Name()) // best-effort cleanup of the partial download
		return "", err
	}
	return out.Name(), nil
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue