refactor: backend/service split, channel-based llm flow (#1963)
Refactor: channel-based LLM flow and services split.

Signed-off-by: Dave Lee <dave@gray101.com>
parent 1981154f49
commit eed5706994
52 changed files with 3064 additions and 2279 deletions
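The refactor routes every backend result through a channel of concurrency.ErrorOr values. The wrapper itself lives in pkg/concurrency and is not part of this excerpt; judging only from how it is used in the hunks below (a generic struct read via .Value and .Error), its shape is roughly the following sketch, which is an inference from usage rather than the actual definition:

// Implied shape of the result wrapper used throughout this PR; the real
// definition in pkg/concurrency may carry additional fields or helpers.
package concurrency

type ErrorOr[T any] struct {
	Value T
	Error error
}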
@@ -2,14 +2,100 @@ package backend

import (
"fmt"
"time"

"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/google/uuid"

"github.com/go-skynet/LocalAI/pkg/concurrency"
"github.com/go-skynet/LocalAI/pkg/grpc"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/model"
)

func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() ([]float32, error), error) {
type EmbeddingsBackendService struct {
ml *model.ModelLoader
bcl *config.BackendConfigLoader
appConfig *config.ApplicationConfig
}

func NewEmbeddingsBackendService(ml *model.ModelLoader, bcl *config.BackendConfigLoader, appConfig *config.ApplicationConfig) *EmbeddingsBackendService {
return &EmbeddingsBackendService{
ml: ml,
bcl: bcl,
appConfig: appConfig,
}
}

func (ebs *EmbeddingsBackendService) Embeddings(request *schema.OpenAIRequest) <-chan concurrency.ErrorOr[*schema.OpenAIResponse] {

resultChannel := make(chan concurrency.ErrorOr[*schema.OpenAIResponse])
go func(request *schema.OpenAIRequest) {
if request.Model == "" {
request.Model = model.StableDiffusionBackend
}

bc, request, err := ebs.bcl.LoadBackendConfigForModelAndOpenAIRequest(request.Model, request, ebs.appConfig)
if err != nil {
resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err}
close(resultChannel)
return
}

items := []schema.Item{}

for i, s := range bc.InputToken {
// get the model function to call for the result
embedFn, err := modelEmbedding("", s, ebs.ml, bc, ebs.appConfig)
if err != nil {
resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err}
close(resultChannel)
return
}

embeddings, err := embedFn()
if err != nil {
resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err}
close(resultChannel)
return
}
items = append(items, schema.Item{Embedding: embeddings, Index: i, Object: "embedding"})
}

for i, s := range bc.InputStrings {
// get the model function to call for the result
embedFn, err := modelEmbedding(s, []int{}, ebs.ml, bc, ebs.appConfig)
if err != nil {
resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err}
close(resultChannel)
return
}

embeddings, err := embedFn()
if err != nil {
resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err}
close(resultChannel)
return
}
items = append(items, schema.Item{Embedding: embeddings, Index: i, Object: "embedding"})
}

id := uuid.New().String()
created := int(time.Now().Unix())
resp := &schema.OpenAIResponse{
ID: id,
Created: created,
Model: request.Model, // we have to return what the user sent here, due to OpenAI spec.
Data: items,
Object: "list",
}
resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Value: resp}
close(resultChannel)
}(request)
return resultChannel
}

func modelEmbedding(s string, tokens []int, loader *model.ModelLoader, backendConfig *config.BackendConfig, appConfig *config.ApplicationConfig) (func() ([]float32, error), error) {
modelFile := backendConfig.Model

grpcOpts := gRPCModelOpts(backendConfig)
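With the hunk above, callers no longer invoke ModelEmbedding directly; they construct an EmbeddingsBackendService and read a single result from the returned channel. A minimal consumer sketch, assuming the core/backend import path and an already-initialized service (the real HTTP endpoint wiring lives elsewhere in this PR and may differ):

// Sketch only: drain the one-shot channel returned by Embeddings.
package example

import (
	"fmt"

	"github.com/go-skynet/LocalAI/core/backend" // assumed import path for this package
	"github.com/go-skynet/LocalAI/core/schema"
)

func printEmbeddings(ebs *backend.EmbeddingsBackendService, modelName string) error {
	req := new(schema.OpenAIRequest)
	req.Model = modelName // the inputs themselves are copied into bc.InputStrings / bc.InputToken by the config loader
	res := <-ebs.Embeddings(req) // exactly one ErrorOr value is sent, then the channel is closed
	if res.Error != nil {
		return res.Error
	}
	for _, item := range res.Value.Data {
		fmt.Printf("embedding %d: %d dimensions\n", item.Index, len(item.Embedding))
	}
	return nil
}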
@@ -1,18 +1,252 @@
package backend

import (
"github.com/go-skynet/LocalAI/core/config"
"bufio"
"encoding/base64"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strconv"
"strings"
"time"

"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/google/uuid"
"github.com/rs/zerolog/log"

"github.com/go-skynet/LocalAI/pkg/concurrency"
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/model"
)

func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() error, error) {
type ImageGenerationBackendService struct {
ml *model.ModelLoader
bcl *config.BackendConfigLoader
appConfig *config.ApplicationConfig
BaseUrlForGeneratedImages string
}

func NewImageGenerationBackendService(ml *model.ModelLoader, bcl *config.BackendConfigLoader, appConfig *config.ApplicationConfig) *ImageGenerationBackendService {
return &ImageGenerationBackendService{
ml: ml,
bcl: bcl,
appConfig: appConfig,
}
}

func (igbs *ImageGenerationBackendService) GenerateImage(request *schema.OpenAIRequest) <-chan concurrency.ErrorOr[*schema.OpenAIResponse] {
resultChannel := make(chan concurrency.ErrorOr[*schema.OpenAIResponse])
go func(request *schema.OpenAIRequest) {
bc, request, err := igbs.bcl.LoadBackendConfigForModelAndOpenAIRequest(request.Model, request, igbs.appConfig)
if err != nil {
resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err}
close(resultChannel)
return
}

src := ""
if request.File != "" {

var fileData []byte
// check if input.File is an URL, if so download it and save it
// to a temporary file
if strings.HasPrefix(request.File, "http://") || strings.HasPrefix(request.File, "https://") {
out, err := downloadFile(request.File)
if err != nil {
resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: fmt.Errorf("failed downloading file:%w", err)}
close(resultChannel)
return
}
defer os.RemoveAll(out)

fileData, err = os.ReadFile(out)
if err != nil {
resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: fmt.Errorf("failed reading file:%w", err)}
close(resultChannel)
return
}

} else {
// base 64 decode the file and write it somewhere
// that we will cleanup
fileData, err = base64.StdEncoding.DecodeString(request.File)
if err != nil {
resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err}
close(resultChannel)
return
}
}

// Create a temporary file
outputFile, err := os.CreateTemp(igbs.appConfig.ImageDir, "b64")
if err != nil {
resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err}
close(resultChannel)
return
}
// write the base64 result
writer := bufio.NewWriter(outputFile)
_, err = writer.Write(fileData)
if err != nil {
outputFile.Close()
resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err}
close(resultChannel)
return
}
outputFile.Close()
src = outputFile.Name()
defer os.RemoveAll(src)
}

log.Debug().Msgf("Parameter Config: %+v", bc)

switch bc.Backend {
case "stablediffusion":
bc.Backend = model.StableDiffusionBackend
case "tinydream":
bc.Backend = model.TinyDreamBackend
case "":
bc.Backend = model.StableDiffusionBackend
if bc.Model == "" {
bc.Model = "stablediffusion_assets" // TODO: check?
}
}

sizeParts := strings.Split(request.Size, "x")
if len(sizeParts) != 2 {
resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: fmt.Errorf("invalid value for 'size'")}
close(resultChannel)
return
}
width, err := strconv.Atoi(sizeParts[0])
if err != nil {
resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: fmt.Errorf("invalid value for 'size'")}
close(resultChannel)
return
}
height, err := strconv.Atoi(sizeParts[1])
if err != nil {
resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: fmt.Errorf("invalid value for 'size'")}
close(resultChannel)
return
}

b64JSON := false
if request.ResponseFormat.Type == "b64_json" {
b64JSON = true
}
// src and clip_skip
var result []schema.Item
for _, i := range bc.PromptStrings {
n := request.N
if request.N == 0 {
n = 1
}
for j := 0; j < n; j++ {
prompts := strings.Split(i, "|")
positive_prompt := prompts[0]
negative_prompt := ""
if len(prompts) > 1 {
negative_prompt = prompts[1]
}

mode := 0
step := bc.Step
if step == 0 {
step = 15
}

if request.Mode != 0 {
mode = request.Mode
}

if request.Step != 0 {
step = request.Step
}

tempDir := ""
if !b64JSON {
tempDir = igbs.appConfig.ImageDir
}
// Create a temporary file
outputFile, err := os.CreateTemp(tempDir, "b64")
if err != nil {
resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err}
close(resultChannel)
return
}
outputFile.Close()
output := outputFile.Name() + ".png"
// Rename the temporary file
err = os.Rename(outputFile.Name(), output)
if err != nil {
resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err}
close(resultChannel)
return
}

if request.Seed == nil {
zVal := 0 // Idiomatic way to do this? Actually needed?
request.Seed = &zVal
}

fn, err := imageGeneration(height, width, mode, step, *request.Seed, positive_prompt, negative_prompt, src, output, igbs.ml, bc, igbs.appConfig)
if err != nil {
resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err}
close(resultChannel)
return
}
if err := fn(); err != nil {
resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err}
close(resultChannel)
return
}

item := &schema.Item{}

if b64JSON {
defer os.RemoveAll(output)
data, err := os.ReadFile(output)
if err != nil {
resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err}
close(resultChannel)
return
}
item.B64JSON = base64.StdEncoding.EncodeToString(data)
} else {
base := filepath.Base(output)
item.URL = igbs.BaseUrlForGeneratedImages + base
}

result = append(result, *item)
}
}

id := uuid.New().String()
created := int(time.Now().Unix())
resp := &schema.OpenAIResponse{
ID: id,
Created: created,
Data: result,
}
resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Value: resp}
close(resultChannel)
}(request)
return resultChannel
}

func imageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, backendConfig *config.BackendConfig, appConfig *config.ApplicationConfig) (func() error, error) {

threads := backendConfig.Threads
if *threads == 0 && appConfig.Threads != 0 {
threads = &appConfig.Threads
}

gRPCOpts := gRPCModelOpts(backendConfig)

opts := modelOpts(backendConfig, appConfig, []model.Option{
model.WithBackendString(backendConfig.Backend),
model.WithAssetDir(appConfig.AssetsDestination),

@@ -50,3 +284,24 @@ func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negat

return fn, nil
}

// TODO: Replace this function with pkg/downloader - no reason to have a (crappier) bespoke download file fn here, but get things working before that change.
func downloadFile(url string) (string, error) {
// Get the data
resp, err := http.Get(url)
if err != nil {
return "", err
}
defer resp.Body.Close()

// Create the file
out, err := os.CreateTemp("", "image")
if err != nil {
return "", err
}
defer out.Close()

// Write the body to file
_, err = io.Copy(out, resp.Body)
return out.Name(), err
}
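One behavioural detail in the new ImageGenerationBackendService: the constructor does not set BaseUrlForGeneratedImages, yet the URL branch concatenates it with the generated file name, so a caller that serves images by URL has to fill it in itself. A hedged sketch (the request fields Model and Size and the exported service field are taken from the hunk above; the import path and base URL value are illustrative assumptions):

// Sketch only: generate one image and report either the URL or the size of
// the base64 payload, depending on the requested response format.
package example

import (
	"fmt"

	"github.com/go-skynet/LocalAI/core/backend" // assumed import path for this package
	"github.com/go-skynet/LocalAI/core/schema"
)

func generateOne(igbs *backend.ImageGenerationBackendService) error {
	// The constructor does not populate this field; callers serving images by
	// URL are expected to set it themselves (illustrative value).
	igbs.BaseUrlForGeneratedImages = "http://localhost:8080/generated-images/"

	req := new(schema.OpenAIRequest)
	req.Model = "stablediffusion"
	req.Size = "512x512" // parsed by GenerateImage as "<width>x<height>"
	// The prompt itself travels through the request and ends up in bc.PromptStrings.

	res := <-igbs.GenerateImage(req) // one ErrorOr value, then the channel closes
	if res.Error != nil {
		return res.Error
	}
	for _, item := range res.Value.Data {
		if item.B64JSON != "" {
			fmt.Println("b64 payload length:", len(item.B64JSON))
		} else {
			fmt.Println("image at:", item.URL)
		}
	}
	return nil
}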
@@ -11,17 +11,22 @@ import (

"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/rs/zerolog/log"

"github.com/go-skynet/LocalAI/pkg/concurrency"
"github.com/go-skynet/LocalAI/pkg/gallery"
"github.com/go-skynet/LocalAI/pkg/grpc"
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/utils"
)

type LLMResponse struct {
Response string // should this be []byte?
Usage TokenUsage
type LLMRequest struct {
Id int // TODO Remove if not used.
Text string
Images []string
RawMessages []schema.Message
// TODO: Other Modalities?
}

type TokenUsage struct {
@@ -29,57 +34,94 @@ type TokenUsage struct {
Completion int
}

func ModelInference(ctx context.Context, s string, messages []schema.Message, images []string, loader *model.ModelLoader, c config.BackendConfig, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) {
modelFile := c.Model
threads := c.Threads
if *threads == 0 && o.Threads != 0 {
threads = &o.Threads
type LLMResponse struct {
Request *LLMRequest
Response string // should this be []byte?
Usage TokenUsage
}

// TODO: Does this belong here or in core/services/openai.go?
type LLMResponseBundle struct {
Request *schema.OpenAIRequest
Response []schema.Choice
Usage TokenUsage
}

type LLMBackendService struct {
bcl *config.BackendConfigLoader
ml *model.ModelLoader
appConfig *config.ApplicationConfig
ftMutex sync.Mutex
cutstrings map[string]*regexp.Regexp
}

func NewLLMBackendService(ml *model.ModelLoader, bcl *config.BackendConfigLoader, appConfig *config.ApplicationConfig) *LLMBackendService {
return &LLMBackendService{
bcl: bcl,
ml: ml,
appConfig: appConfig,
ftMutex: sync.Mutex{},
cutstrings: make(map[string]*regexp.Regexp),
}
grpcOpts := gRPCModelOpts(c)
}

// TODO: Should ctx param be removed and replaced with hardcoded req.Context?
func (llmbs *LLMBackendService) Inference(ctx context.Context, req *LLMRequest, bc *config.BackendConfig, enableTokenChannel bool) (
resultChannel <-chan concurrency.ErrorOr[*LLMResponse], tokenChannel <-chan concurrency.ErrorOr[*LLMResponse], err error) {

threads := bc.Threads
if (threads == nil || *threads == 0) && llmbs.appConfig.Threads != 0 {
threads = &llmbs.appConfig.Threads
}

grpcOpts := gRPCModelOpts(bc)

var inferenceModel grpc.Backend
var err error

opts := modelOpts(c, o, []model.Option{
opts := modelOpts(bc, llmbs.appConfig, []model.Option{
model.WithLoadGRPCLoadModelOpts(grpcOpts),
model.WithThreads(uint32(*threads)), // some models uses this to allocate threads during startup
model.WithAssetDir(o.AssetsDestination),
model.WithModel(modelFile),
model.WithContext(o.Context),
model.WithAssetDir(llmbs.appConfig.AssetsDestination),
model.WithModel(bc.Model),
model.WithContext(llmbs.appConfig.Context),
})

if c.Backend != "" {
opts = append(opts, model.WithBackendString(c.Backend))
if bc.Backend != "" {
opts = append(opts, model.WithBackendString(bc.Backend))
}

// Check if the modelFile exists, if it doesn't try to load it from the gallery
if o.AutoloadGalleries { // experimental
if _, err := os.Stat(modelFile); os.IsNotExist(err) {
// Check if bc.Model exists, if it doesn't try to load it from the gallery
if llmbs.appConfig.AutoloadGalleries { // experimental
if _, err := os.Stat(bc.Model); os.IsNotExist(err) {
utils.ResetDownloadTimers()
// if we failed to load the model, we try to download it
err := gallery.InstallModelFromGalleryByName(o.Galleries, modelFile, loader.ModelPath, gallery.GalleryModel{}, utils.DisplayDownloadFunction)
err := gallery.InstallModelFromGalleryByName(llmbs.appConfig.Galleries, bc.Model, llmbs.appConfig.ModelPath, gallery.GalleryModel{}, utils.DisplayDownloadFunction)
if err != nil {
return nil, err
return nil, nil, err
}
}
}

if c.Backend == "" {
inferenceModel, err = loader.GreedyLoader(opts...)
if bc.Backend == "" {
log.Debug().Msgf("backend not known for %q, falling back to greedy loader to find it", bc.Model)
inferenceModel, err = llmbs.ml.GreedyLoader(opts...)
} else {
inferenceModel, err = loader.BackendLoader(opts...)
inferenceModel, err = llmbs.ml.BackendLoader(opts...)
}

if err != nil {
return nil, err
log.Error().Err(err).Msg("[llmbs.Inference] failed to load a backend")
return
}

var protoMessages []*proto.Message
// if we are using the tokenizer template, we need to convert the messages to proto messages
// unless the prompt has already been tokenized (non-chat endpoints + functions)
if c.TemplateConfig.UseTokenizerTemplate && s == "" {
protoMessages = make([]*proto.Message, len(messages), len(messages))
for i, message := range messages {
grpcPredOpts := gRPCPredictOpts(bc, llmbs.appConfig.ModelPath)
grpcPredOpts.Prompt = req.Text
grpcPredOpts.Images = req.Images

if bc.TemplateConfig.UseTokenizerTemplate && req.Text == "" {
grpcPredOpts.UseTokenizerTemplate = true
protoMessages := make([]*proto.Message, len(req.RawMessages), len(req.RawMessages))
for i, message := range req.RawMessages {
protoMessages[i] = &proto.Message{
Role: message.Role,
}

@@ -87,47 +129,32 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
case string:
protoMessages[i].Content = ct
default:
return nil, fmt.Errorf("Unsupported type for schema.Message.Content for inference: %T", ct)
err = fmt.Errorf("unsupported type for schema.Message.Content for inference: %T", ct)
return
}
}
}

// in GRPC, the backend is supposed to answer to 1 single token if stream is not supported
fn := func() (LLMResponse, error) {
opts := gRPCPredictOpts(c, loader.ModelPath)
opts.Prompt = s
opts.Messages = protoMessages
opts.UseTokenizerTemplate = c.TemplateConfig.UseTokenizerTemplate
opts.Images = images
tokenUsage := TokenUsage{}

tokenUsage := TokenUsage{}
promptInfo, pErr := inferenceModel.TokenizeString(ctx, grpcPredOpts)
if pErr == nil && promptInfo.Length > 0 {
tokenUsage.Prompt = int(promptInfo.Length)
}

// check the per-model feature flag for usage, since tokenCallback may have a cost.
// Defaults to off as for now it is still experimental
if c.FeatureFlag.Enabled("usage") {
userTokenCallback := tokenCallback
if userTokenCallback == nil {
userTokenCallback = func(token string, usage TokenUsage) bool {
return true
}
}
rawResultChannel := make(chan concurrency.ErrorOr[*LLMResponse])
// TODO this next line is the biggest argument for taking named return values _back_ out!!!
var rawTokenChannel chan concurrency.ErrorOr[*LLMResponse]

promptInfo, pErr := inferenceModel.TokenizeString(ctx, opts)
if pErr == nil && promptInfo.Length > 0 {
tokenUsage.Prompt = int(promptInfo.Length)
}
if enableTokenChannel {
rawTokenChannel = make(chan concurrency.ErrorOr[*LLMResponse])

tokenCallback = func(token string, usage TokenUsage) bool {
tokenUsage.Completion++
return userTokenCallback(token, tokenUsage)
}
}

if tokenCallback != nil {
ss := ""
// TODO Needs better name
ss := ""

go func() {
var partialRune []byte
err := inferenceModel.PredictStream(ctx, opts, func(chars []byte) {
err := inferenceModel.PredictStream(ctx, grpcPredOpts, func(chars []byte) {
partialRune = append(partialRune, chars...)

for len(partialRune) > 0 {

@@ -137,48 +164,120 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
break
}

tokenCallback(string(r), tokenUsage)
tokenUsage.Completion++
rawTokenChannel <- concurrency.ErrorOr[*LLMResponse]{Value: &LLMResponse{
Response: string(r),
Usage: tokenUsage,
}}

ss += string(r)

partialRune = partialRune[size:]
}
})
return LLMResponse{
Response: ss,
Usage: tokenUsage,
}, err
} else {
// TODO: Is the chicken bit the only way to get here? is that acceptable?
reply, err := inferenceModel.Predict(ctx, opts)
close(rawTokenChannel)
if err != nil {
return LLMResponse{}, err
rawResultChannel <- concurrency.ErrorOr[*LLMResponse]{Error: err}
} else {
rawResultChannel <- concurrency.ErrorOr[*LLMResponse]{Value: &LLMResponse{
Response: ss,
Usage: tokenUsage,
}}
}
return LLMResponse{
Response: string(reply.Message),
Usage: tokenUsage,
}, err
}
close(rawResultChannel)
}()
} else {
go func() {
reply, err := inferenceModel.Predict(ctx, grpcPredOpts)
if err != nil {
rawResultChannel <- concurrency.ErrorOr[*LLMResponse]{Error: err}
close(rawResultChannel)
} else {
rawResultChannel <- concurrency.ErrorOr[*LLMResponse]{Value: &LLMResponse{
Response: string(reply.Message),
Usage: tokenUsage,
}}
close(rawResultChannel)
}
}()
}

return fn, nil
resultChannel = rawResultChannel
tokenChannel = rawTokenChannel
return
}

var cutstrings map[string]*regexp.Regexp = make(map[string]*regexp.Regexp)
var mu sync.Mutex = sync.Mutex{}
// TODO: Should predInput be a seperate param still, or should this fn handle extracting it from request??
func (llmbs *LLMBackendService) GenerateText(predInput string, request *schema.OpenAIRequest, bc *config.BackendConfig,
mappingFn func(*LLMResponse) schema.Choice, enableCompletionChannels bool, enableTokenChannels bool) (
// Returns:
resultChannel <-chan concurrency.ErrorOr[*LLMResponseBundle], completionChannels []<-chan concurrency.ErrorOr[*LLMResponse], tokenChannels []<-chan concurrency.ErrorOr[*LLMResponse], err error) {

func Finetune(config config.BackendConfig, input, prediction string) string {
rawChannel := make(chan concurrency.ErrorOr[*LLMResponseBundle])
resultChannel = rawChannel

if request.N == 0 { // number of completions to return
request.N = 1
}
images := []string{}
for _, m := range request.Messages {
images = append(images, m.StringImages...)
}

for i := 0; i < request.N; i++ {

individualResultChannel, tokenChannel, infErr := llmbs.Inference(request.Context, &LLMRequest{
Text: predInput,
Images: images,
RawMessages: request.Messages,
}, bc, enableTokenChannels)
if infErr != nil {
err = infErr // Avoids complaints about redeclaring err but looks dumb
return
}
completionChannels = append(completionChannels, individualResultChannel)
tokenChannels = append(tokenChannels, tokenChannel)
}

go func() {
initialBundle := LLMResponseBundle{
Request: request,
Response: []schema.Choice{},
Usage: TokenUsage{},
}

wg := concurrency.SliceOfChannelsReducer(completionChannels, rawChannel, func(iv concurrency.ErrorOr[*LLMResponse], ov concurrency.ErrorOr[*LLMResponseBundle]) concurrency.ErrorOr[*LLMResponseBundle] {
if iv.Error != nil {
ov.Error = iv.Error
// TODO: Decide if we should wipe partials or not?
return ov
}
ov.Value.Usage.Prompt += iv.Value.Usage.Prompt
ov.Value.Usage.Completion += iv.Value.Usage.Completion

ov.Value.Response = append(ov.Value.Response, mappingFn(iv.Value))
return ov
}, concurrency.ErrorOr[*LLMResponseBundle]{Value: &initialBundle}, true)
wg.Wait()

}()

return
}

func (llmbs *LLMBackendService) Finetune(config config.BackendConfig, input, prediction string) string {
if config.Echo {
prediction = input + prediction
}

for _, c := range config.Cutstrings {
mu.Lock()
reg, ok := cutstrings[c]
llmbs.ftMutex.Lock()
reg, ok := llmbs.cutstrings[c]
if !ok {
cutstrings[c] = regexp.MustCompile(c)
reg = cutstrings[c]
llmbs.cutstrings[c] = regexp.MustCompile(c)
reg = llmbs.cutstrings[c]
}
mu.Unlock()
llmbs.ftMutex.Unlock()
prediction = reg.ReplaceAllString(prediction, "")
}
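The new Inference signature replaces the callback-style ModelInference with two channels: an optional token channel that streams one decoded rune per message while PredictStream runs, and a result channel that delivers the final concatenated response and usage. A consumer sketch under the assumption that the service, backend config and context are already built (import path assumed); draining the token channel before reading the result is safe because the producer closes it before sending the final value:

// Sketch only: stream tokens, then collect the final LLMResponse.
package example

import (
	"context"
	"fmt"

	"github.com/go-skynet/LocalAI/core/backend" // assumed import path for this package
	"github.com/go-skynet/LocalAI/core/config"
)

func streamCompletion(ctx context.Context, llmbs *backend.LLMBackendService, bc *config.BackendConfig, prompt string) error {
	resultCh, tokenCh, err := llmbs.Inference(ctx, &backend.LLMRequest{Text: prompt}, bc, true)
	if err != nil {
		return err
	}
	for tok := range tokenCh { // closed by the producer once PredictStream finishes
		if tok.Error == nil {
			fmt.Print(tok.Value.Response)
		}
	}
	final := <-resultCh // a single ErrorOr with the full response and token usage
	if final.Error != nil {
		return final.Error
	}
	fmt.Printf("\nprompt tokens: %d, completion tokens: %d\n", final.Value.Usage.Prompt, final.Value.Usage.Completion)
	return nil
}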
@@ -10,7 +10,7 @@ import (
model "github.com/go-skynet/LocalAI/pkg/model"
)

func modelOpts(c config.BackendConfig, so *config.ApplicationConfig, opts []model.Option) []model.Option {
func modelOpts(bc *config.BackendConfig, so *config.ApplicationConfig, opts []model.Option) []model.Option {
if so.SingleBackend {
opts = append(opts, model.WithSingleActiveBackend())
}

@@ -19,12 +19,12 @@ func modelOpts(c config.BackendConfig, so *config.ApplicationConfig, opts []mode
opts = append(opts, model.EnableParallelRequests)
}

if c.GRPC.Attempts != 0 {
opts = append(opts, model.WithGRPCAttempts(c.GRPC.Attempts))
if bc.GRPC.Attempts != 0 {
opts = append(opts, model.WithGRPCAttempts(bc.GRPC.Attempts))
}

if c.GRPC.AttemptsSleepTime != 0 {
opts = append(opts, model.WithGRPCAttemptsDelay(c.GRPC.AttemptsSleepTime))
if bc.GRPC.AttemptsSleepTime != 0 {
opts = append(opts, model.WithGRPCAttemptsDelay(bc.GRPC.AttemptsSleepTime))
}

for k, v := range so.ExternalGRPCBackends {

@@ -34,7 +34,7 @@ func modelOpts(c config.BackendConfig, so *config.ApplicationConfig, opts []mode
return opts
}

func getSeed(c config.BackendConfig) int32 {
func getSeed(c *config.BackendConfig) int32 {
seed := int32(*c.Seed)
if seed == config.RAND_SEED {
seed = rand.Int31()

@@ -43,7 +43,7 @@ func getSeed(c config.BackendConfig) int32 {
return seed
}

func gRPCModelOpts(c config.BackendConfig) *pb.ModelOptions {
func gRPCModelOpts(c *config.BackendConfig) *pb.ModelOptions {
b := 512
if c.Batch != 0 {
b = c.Batch

@@ -104,47 +104,47 @@ func gRPCModelOpts(c config.BackendConfig) *pb.ModelOptions {
}
}

func gRPCPredictOpts(c config.BackendConfig, modelPath string) *pb.PredictOptions {
func gRPCPredictOpts(bc *config.BackendConfig, modelPath string) *pb.PredictOptions {
promptCachePath := ""
if c.PromptCachePath != "" {
p := filepath.Join(modelPath, c.PromptCachePath)
if bc.PromptCachePath != "" {
p := filepath.Join(modelPath, bc.PromptCachePath)
os.MkdirAll(filepath.Dir(p), 0755)
promptCachePath = p
}

return &pb.PredictOptions{
Temperature: float32(*c.Temperature),
TopP: float32(*c.TopP),
NDraft: c.NDraft,
TopK: int32(*c.TopK),
Tokens: int32(*c.Maxtokens),
Threads: int32(*c.Threads),
PromptCacheAll: c.PromptCacheAll,
PromptCacheRO: c.PromptCacheRO,
Temperature: float32(*bc.Temperature),
TopP: float32(*bc.TopP),
NDraft: bc.NDraft,
TopK: int32(*bc.TopK),
Tokens: int32(*bc.Maxtokens),
Threads: int32(*bc.Threads),
PromptCacheAll: bc.PromptCacheAll,
PromptCacheRO: bc.PromptCacheRO,
PromptCachePath: promptCachePath,
F16KV: *c.F16,
DebugMode: *c.Debug,
Grammar: c.Grammar,
NegativePromptScale: c.NegativePromptScale,
RopeFreqBase: c.RopeFreqBase,
RopeFreqScale: c.RopeFreqScale,
NegativePrompt: c.NegativePrompt,
Mirostat: int32(*c.LLMConfig.Mirostat),
MirostatETA: float32(*c.LLMConfig.MirostatETA),
MirostatTAU: float32(*c.LLMConfig.MirostatTAU),
Debug: *c.Debug,
StopPrompts: c.StopWords,
Repeat: int32(c.RepeatPenalty),
NKeep: int32(c.Keep),
Batch: int32(c.Batch),
IgnoreEOS: c.IgnoreEOS,
Seed: getSeed(c),
FrequencyPenalty: float32(c.FrequencyPenalty),
MLock: *c.MMlock,
MMap: *c.MMap,
MainGPU: c.MainGPU,
TensorSplit: c.TensorSplit,
TailFreeSamplingZ: float32(*c.TFZ),
TypicalP: float32(*c.TypicalP),
F16KV: *bc.F16,
DebugMode: *bc.Debug,
Grammar: bc.Grammar,
NegativePromptScale: bc.NegativePromptScale,
RopeFreqBase: bc.RopeFreqBase,
RopeFreqScale: bc.RopeFreqScale,
NegativePrompt: bc.NegativePrompt,
Mirostat: int32(*bc.LLMConfig.Mirostat),
MirostatETA: float32(*bc.LLMConfig.MirostatETA),
MirostatTAU: float32(*bc.LLMConfig.MirostatTAU),
Debug: *bc.Debug,
StopPrompts: bc.StopWords,
Repeat: int32(bc.RepeatPenalty),
NKeep: int32(bc.Keep),
Batch: int32(bc.Batch),
IgnoreEOS: bc.IgnoreEOS,
Seed: getSeed(bc),
FrequencyPenalty: float32(bc.FrequencyPenalty),
MLock: *bc.MMlock,
MMap: *bc.MMap,
MainGPU: bc.MainGPU,
TensorSplit: bc.TensorSplit,
TailFreeSamplingZ: float32(*bc.TFZ),
TypicalP: float32(*bc.TypicalP),
}
}
@@ -7,11 +7,48 @@ import (
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/schema"

"github.com/go-skynet/LocalAI/pkg/concurrency"
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/model"
)

func ModelTranscription(audio, language string, ml *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (*schema.Result, error) {
type TranscriptionBackendService struct {
ml *model.ModelLoader
bcl *config.BackendConfigLoader
appConfig *config.ApplicationConfig
}

func NewTranscriptionBackendService(ml *model.ModelLoader, bcl *config.BackendConfigLoader, appConfig *config.ApplicationConfig) *TranscriptionBackendService {
return &TranscriptionBackendService{
ml: ml,
bcl: bcl,
appConfig: appConfig,
}
}

func (tbs *TranscriptionBackendService) Transcribe(request *schema.OpenAIRequest) <-chan concurrency.ErrorOr[*schema.TranscriptionResult] {
responseChannel := make(chan concurrency.ErrorOr[*schema.TranscriptionResult])
go func(request *schema.OpenAIRequest) {
bc, request, err := tbs.bcl.LoadBackendConfigForModelAndOpenAIRequest(request.Model, request, tbs.appConfig)
if err != nil {
responseChannel <- concurrency.ErrorOr[*schema.TranscriptionResult]{Error: fmt.Errorf("failed reading parameters from request:%w", err)}
close(responseChannel)
return
}

tr, err := modelTranscription(request.File, request.Language, tbs.ml, bc, tbs.appConfig)
if err != nil {
responseChannel <- concurrency.ErrorOr[*schema.TranscriptionResult]{Error: err}
close(responseChannel)
return
}
responseChannel <- concurrency.ErrorOr[*schema.TranscriptionResult]{Value: tr}
close(responseChannel)
}(request)
return responseChannel
}

func modelTranscription(audio, language string, ml *model.ModelLoader, backendConfig *config.BackendConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) {

opts := modelOpts(backendConfig, appConfig, []model.Option{
model.WithBackendString(model.WhisperBackend),
@@ -7,29 +7,60 @@ import (
"path/filepath"

"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/schema"

"github.com/go-skynet/LocalAI/pkg/concurrency"
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/utils"
)

func generateUniqueFileName(dir, baseName, ext string) string {
counter := 1
fileName := baseName + ext
type TextToSpeechBackendService struct {
ml *model.ModelLoader
bcl *config.BackendConfigLoader
appConfig *config.ApplicationConfig
}

for {
filePath := filepath.Join(dir, fileName)
_, err := os.Stat(filePath)
if os.IsNotExist(err) {
return fileName
}

counter++
fileName = fmt.Sprintf("%s_%d%s", baseName, counter, ext)
func NewTextToSpeechBackendService(ml *model.ModelLoader, bcl *config.BackendConfigLoader, appConfig *config.ApplicationConfig) *TextToSpeechBackendService {
return &TextToSpeechBackendService{
ml: ml,
bcl: bcl,
appConfig: appConfig,
}
}

func ModelTTS(backend, text, modelFile, voice string, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig config.BackendConfig) (string, *proto.Result, error) {
func (ttsbs *TextToSpeechBackendService) TextToAudioFile(request *schema.TTSRequest) <-chan concurrency.ErrorOr[*string] {
responseChannel := make(chan concurrency.ErrorOr[*string])
go func(request *schema.TTSRequest) {
cfg, err := ttsbs.bcl.LoadBackendConfigFileByName(request.Model, ttsbs.appConfig.ModelPath,
config.LoadOptionDebug(ttsbs.appConfig.Debug),
config.LoadOptionThreads(ttsbs.appConfig.Threads),
config.LoadOptionContextSize(ttsbs.appConfig.ContextSize),
config.LoadOptionF16(ttsbs.appConfig.F16),
)
if err != nil {
responseChannel <- concurrency.ErrorOr[*string]{Error: err}
close(responseChannel)
return
}

if request.Backend != "" {
cfg.Backend = request.Backend
}

outFile, _, err := modelTTS(cfg.Backend, request.Input, cfg.Model, request.Voice, ttsbs.ml, ttsbs.appConfig, cfg)
if err != nil {
responseChannel <- concurrency.ErrorOr[*string]{Error: err}
close(responseChannel)
return
}
responseChannel <- concurrency.ErrorOr[*string]{Value: &outFile}
close(responseChannel)
}(request)
return responseChannel
}

func modelTTS(backend, text, modelFile string, voice string, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig *config.BackendConfig) (string, *proto.Result, error) {
bb := backend
if bb == "" {
bb = model.PiperBackend

@@ -37,7 +68,7 @@ func ModelTTS(backend, text, modelFile, voice string, loader *model.ModelLoader,

grpcOpts := gRPCModelOpts(backendConfig)

opts := modelOpts(config.BackendConfig{}, appConfig, []model.Option{
opts := modelOpts(&config.BackendConfig{}, appConfig, []model.Option{
model.WithBackendString(bb),
model.WithModel(modelFile),
model.WithContext(appConfig.Context),

@@ -87,3 +118,19 @@ func ModelTTS(backend, text, modelFile, voice string, loader *model.ModelLoader,

return filePath, res, err
}

func generateUniqueFileName(dir, baseName, ext string) string {
counter := 1
fileName := baseName + ext

for {
filePath := filepath.Join(dir, fileName)
_, err := os.Stat(filePath)
if os.IsNotExist(err) {
return fileName
}

counter++
fileName = fmt.Sprintf("%s_%d%s", baseName, counter, ext)
}
}
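As with the other services, TextToAudioFile hands back a single ErrorOr over a channel, here wrapping the path of the rendered audio file. A minimal caller sketch (the field names Model and Input on schema.TTSRequest are taken from their usage above; the import path and model name are illustrative assumptions):

// Sketch only: request a TTS render and wait for the generated file path.
package example

import (
	"fmt"

	"github.com/go-skynet/LocalAI/core/backend" // assumed import path for this package
	"github.com/go-skynet/LocalAI/core/schema"
)

func synthesize(ttsbs *backend.TextToSpeechBackendService, text string) (string, error) {
	res := <-ttsbs.TextToAudioFile(&schema.TTSRequest{
		Model: "en-us-amy-low.onnx", // illustrative model name
		Input: text,
	})
	if res.Error != nil {
		return "", res.Error
	}
	fmt.Println("audio written to", *res.Value)
	return *res.Value, nil
}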