refactor: backend/service split, channel-based llm flow (#1963)

Refactor: channel based llm flow and services split

---------

Signed-off-by: Dave Lee <dave@gray101.com>
This commit is contained in:
Dave 2024-04-13 03:45:34 -04:00 committed by GitHub
parent 1981154f49
commit eed5706994
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
52 changed files with 3064 additions and 2279 deletions

View file

@ -11,17 +11,22 @@ import (
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/rs/zerolog/log"
"github.com/go-skynet/LocalAI/pkg/concurrency"
"github.com/go-skynet/LocalAI/pkg/gallery"
"github.com/go-skynet/LocalAI/pkg/grpc"
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/utils"
)
type LLMResponse struct {
Response string // should this be []byte?
Usage TokenUsage
type LLMRequest struct {
Id int // TODO Remove if not used.
Text string
Images []string
RawMessages []schema.Message
// TODO: Other Modalities?
}
type TokenUsage struct {
@ -29,57 +34,94 @@ type TokenUsage struct {
Completion int
}
func ModelInference(ctx context.Context, s string, messages []schema.Message, images []string, loader *model.ModelLoader, c config.BackendConfig, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) {
modelFile := c.Model
threads := c.Threads
if *threads == 0 && o.Threads != 0 {
threads = &o.Threads
type LLMResponse struct {
Request *LLMRequest
Response string // should this be []byte?
Usage TokenUsage
}
// TODO: Does this belong here or in core/services/openai.go?
type LLMResponseBundle struct {
Request *schema.OpenAIRequest
Response []schema.Choice
Usage TokenUsage
}
type LLMBackendService struct {
bcl *config.BackendConfigLoader
ml *model.ModelLoader
appConfig *config.ApplicationConfig
ftMutex sync.Mutex
cutstrings map[string]*regexp.Regexp
}
func NewLLMBackendService(ml *model.ModelLoader, bcl *config.BackendConfigLoader, appConfig *config.ApplicationConfig) *LLMBackendService {
return &LLMBackendService{
bcl: bcl,
ml: ml,
appConfig: appConfig,
ftMutex: sync.Mutex{},
cutstrings: make(map[string]*regexp.Regexp),
}
grpcOpts := gRPCModelOpts(c)
}
// TODO: Should ctx param be removed and replaced with hardcoded req.Context?
func (llmbs *LLMBackendService) Inference(ctx context.Context, req *LLMRequest, bc *config.BackendConfig, enableTokenChannel bool) (
resultChannel <-chan concurrency.ErrorOr[*LLMResponse], tokenChannel <-chan concurrency.ErrorOr[*LLMResponse], err error) {
threads := bc.Threads
if (threads == nil || *threads == 0) && llmbs.appConfig.Threads != 0 {
threads = &llmbs.appConfig.Threads
}
grpcOpts := gRPCModelOpts(bc)
var inferenceModel grpc.Backend
var err error
opts := modelOpts(c, o, []model.Option{
opts := modelOpts(bc, llmbs.appConfig, []model.Option{
model.WithLoadGRPCLoadModelOpts(grpcOpts),
model.WithThreads(uint32(*threads)), // some models uses this to allocate threads during startup
model.WithAssetDir(o.AssetsDestination),
model.WithModel(modelFile),
model.WithContext(o.Context),
model.WithAssetDir(llmbs.appConfig.AssetsDestination),
model.WithModel(bc.Model),
model.WithContext(llmbs.appConfig.Context),
})
if c.Backend != "" {
opts = append(opts, model.WithBackendString(c.Backend))
if bc.Backend != "" {
opts = append(opts, model.WithBackendString(bc.Backend))
}
// Check if the modelFile exists, if it doesn't try to load it from the gallery
if o.AutoloadGalleries { // experimental
if _, err := os.Stat(modelFile); os.IsNotExist(err) {
// Check if bc.Model exists, if it doesn't try to load it from the gallery
if llmbs.appConfig.AutoloadGalleries { // experimental
if _, err := os.Stat(bc.Model); os.IsNotExist(err) {
utils.ResetDownloadTimers()
// if we failed to load the model, we try to download it
err := gallery.InstallModelFromGalleryByName(o.Galleries, modelFile, loader.ModelPath, gallery.GalleryModel{}, utils.DisplayDownloadFunction)
err := gallery.InstallModelFromGalleryByName(llmbs.appConfig.Galleries, bc.Model, llmbs.appConfig.ModelPath, gallery.GalleryModel{}, utils.DisplayDownloadFunction)
if err != nil {
return nil, err
return nil, nil, err
}
}
}
if c.Backend == "" {
inferenceModel, err = loader.GreedyLoader(opts...)
if bc.Backend == "" {
log.Debug().Msgf("backend not known for %q, falling back to greedy loader to find it", bc.Model)
inferenceModel, err = llmbs.ml.GreedyLoader(opts...)
} else {
inferenceModel, err = loader.BackendLoader(opts...)
inferenceModel, err = llmbs.ml.BackendLoader(opts...)
}
if err != nil {
return nil, err
log.Error().Err(err).Msg("[llmbs.Inference] failed to load a backend")
return
}
var protoMessages []*proto.Message
// if we are using the tokenizer template, we need to convert the messages to proto messages
// unless the prompt has already been tokenized (non-chat endpoints + functions)
if c.TemplateConfig.UseTokenizerTemplate && s == "" {
protoMessages = make([]*proto.Message, len(messages), len(messages))
for i, message := range messages {
grpcPredOpts := gRPCPredictOpts(bc, llmbs.appConfig.ModelPath)
grpcPredOpts.Prompt = req.Text
grpcPredOpts.Images = req.Images
if bc.TemplateConfig.UseTokenizerTemplate && req.Text == "" {
grpcPredOpts.UseTokenizerTemplate = true
protoMessages := make([]*proto.Message, len(req.RawMessages), len(req.RawMessages))
for i, message := range req.RawMessages {
protoMessages[i] = &proto.Message{
Role: message.Role,
}
@ -87,47 +129,32 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
case string:
protoMessages[i].Content = ct
default:
return nil, fmt.Errorf("Unsupported type for schema.Message.Content for inference: %T", ct)
err = fmt.Errorf("unsupported type for schema.Message.Content for inference: %T", ct)
return
}
}
}
// in GRPC, the backend is supposed to answer to 1 single token if stream is not supported
fn := func() (LLMResponse, error) {
opts := gRPCPredictOpts(c, loader.ModelPath)
opts.Prompt = s
opts.Messages = protoMessages
opts.UseTokenizerTemplate = c.TemplateConfig.UseTokenizerTemplate
opts.Images = images
tokenUsage := TokenUsage{}
tokenUsage := TokenUsage{}
promptInfo, pErr := inferenceModel.TokenizeString(ctx, grpcPredOpts)
if pErr == nil && promptInfo.Length > 0 {
tokenUsage.Prompt = int(promptInfo.Length)
}
// check the per-model feature flag for usage, since tokenCallback may have a cost.
// Defaults to off as for now it is still experimental
if c.FeatureFlag.Enabled("usage") {
userTokenCallback := tokenCallback
if userTokenCallback == nil {
userTokenCallback = func(token string, usage TokenUsage) bool {
return true
}
}
rawResultChannel := make(chan concurrency.ErrorOr[*LLMResponse])
// TODO this next line is the biggest argument for taking named return values _back_ out!!!
var rawTokenChannel chan concurrency.ErrorOr[*LLMResponse]
promptInfo, pErr := inferenceModel.TokenizeString(ctx, opts)
if pErr == nil && promptInfo.Length > 0 {
tokenUsage.Prompt = int(promptInfo.Length)
}
if enableTokenChannel {
rawTokenChannel = make(chan concurrency.ErrorOr[*LLMResponse])
tokenCallback = func(token string, usage TokenUsage) bool {
tokenUsage.Completion++
return userTokenCallback(token, tokenUsage)
}
}
if tokenCallback != nil {
ss := ""
// TODO Needs better name
ss := ""
go func() {
var partialRune []byte
err := inferenceModel.PredictStream(ctx, opts, func(chars []byte) {
err := inferenceModel.PredictStream(ctx, grpcPredOpts, func(chars []byte) {
partialRune = append(partialRune, chars...)
for len(partialRune) > 0 {
@ -137,48 +164,120 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
break
}
tokenCallback(string(r), tokenUsage)
tokenUsage.Completion++
rawTokenChannel <- concurrency.ErrorOr[*LLMResponse]{Value: &LLMResponse{
Response: string(r),
Usage: tokenUsage,
}}
ss += string(r)
partialRune = partialRune[size:]
}
})
return LLMResponse{
Response: ss,
Usage: tokenUsage,
}, err
} else {
// TODO: Is the chicken bit the only way to get here? is that acceptable?
reply, err := inferenceModel.Predict(ctx, opts)
close(rawTokenChannel)
if err != nil {
return LLMResponse{}, err
rawResultChannel <- concurrency.ErrorOr[*LLMResponse]{Error: err}
} else {
rawResultChannel <- concurrency.ErrorOr[*LLMResponse]{Value: &LLMResponse{
Response: ss,
Usage: tokenUsage,
}}
}
return LLMResponse{
Response: string(reply.Message),
Usage: tokenUsage,
}, err
}
close(rawResultChannel)
}()
} else {
go func() {
reply, err := inferenceModel.Predict(ctx, grpcPredOpts)
if err != nil {
rawResultChannel <- concurrency.ErrorOr[*LLMResponse]{Error: err}
close(rawResultChannel)
} else {
rawResultChannel <- concurrency.ErrorOr[*LLMResponse]{Value: &LLMResponse{
Response: string(reply.Message),
Usage: tokenUsage,
}}
close(rawResultChannel)
}
}()
}
return fn, nil
resultChannel = rawResultChannel
tokenChannel = rawTokenChannel
return
}
var cutstrings map[string]*regexp.Regexp = make(map[string]*regexp.Regexp)
var mu sync.Mutex = sync.Mutex{}
// TODO: Should predInput be a seperate param still, or should this fn handle extracting it from request??
func (llmbs *LLMBackendService) GenerateText(predInput string, request *schema.OpenAIRequest, bc *config.BackendConfig,
mappingFn func(*LLMResponse) schema.Choice, enableCompletionChannels bool, enableTokenChannels bool) (
// Returns:
resultChannel <-chan concurrency.ErrorOr[*LLMResponseBundle], completionChannels []<-chan concurrency.ErrorOr[*LLMResponse], tokenChannels []<-chan concurrency.ErrorOr[*LLMResponse], err error) {
func Finetune(config config.BackendConfig, input, prediction string) string {
rawChannel := make(chan concurrency.ErrorOr[*LLMResponseBundle])
resultChannel = rawChannel
if request.N == 0 { // number of completions to return
request.N = 1
}
images := []string{}
for _, m := range request.Messages {
images = append(images, m.StringImages...)
}
for i := 0; i < request.N; i++ {
individualResultChannel, tokenChannel, infErr := llmbs.Inference(request.Context, &LLMRequest{
Text: predInput,
Images: images,
RawMessages: request.Messages,
}, bc, enableTokenChannels)
if infErr != nil {
err = infErr // Avoids complaints about redeclaring err but looks dumb
return
}
completionChannels = append(completionChannels, individualResultChannel)
tokenChannels = append(tokenChannels, tokenChannel)
}
go func() {
initialBundle := LLMResponseBundle{
Request: request,
Response: []schema.Choice{},
Usage: TokenUsage{},
}
wg := concurrency.SliceOfChannelsReducer(completionChannels, rawChannel, func(iv concurrency.ErrorOr[*LLMResponse], ov concurrency.ErrorOr[*LLMResponseBundle]) concurrency.ErrorOr[*LLMResponseBundle] {
if iv.Error != nil {
ov.Error = iv.Error
// TODO: Decide if we should wipe partials or not?
return ov
}
ov.Value.Usage.Prompt += iv.Value.Usage.Prompt
ov.Value.Usage.Completion += iv.Value.Usage.Completion
ov.Value.Response = append(ov.Value.Response, mappingFn(iv.Value))
return ov
}, concurrency.ErrorOr[*LLMResponseBundle]{Value: &initialBundle}, true)
wg.Wait()
}()
return
}
func (llmbs *LLMBackendService) Finetune(config config.BackendConfig, input, prediction string) string {
if config.Echo {
prediction = input + prediction
}
for _, c := range config.Cutstrings {
mu.Lock()
reg, ok := cutstrings[c]
llmbs.ftMutex.Lock()
reg, ok := llmbs.cutstrings[c]
if !ok {
cutstrings[c] = regexp.MustCompile(c)
reg = cutstrings[c]
llmbs.cutstrings[c] = regexp.MustCompile(c)
reg = llmbs.cutstrings[c]
}
mu.Unlock()
llmbs.ftMutex.Unlock()
prediction = reg.ReplaceAllString(prediction, "")
}