refactor: backend/service split, channel-based llm flow (#1963)

Refactor: channel based llm flow and services split

---------

Signed-off-by: Dave Lee <dave@gray101.com>
This commit is contained in:
Dave 2024-04-13 03:45:34 -04:00 committed by GitHub
parent 1981154f49
commit eed5706994
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
52 changed files with 3064 additions and 2279 deletions

View file

@ -15,22 +15,22 @@ import (
gopsutil "github.com/shirou/gopsutil/v3/process"
)
type BackendMonitor struct {
type BackendMonitorService struct {
configLoader *config.BackendConfigLoader
modelLoader *model.ModelLoader
options *config.ApplicationConfig // Taking options in case we need to inspect ExternalGRPCBackends, though that's out of scope for now, hence the name.
}
func NewBackendMonitor(configLoader *config.BackendConfigLoader, modelLoader *model.ModelLoader, appConfig *config.ApplicationConfig) BackendMonitor {
return BackendMonitor{
func NewBackendMonitorService(modelLoader *model.ModelLoader, configLoader *config.BackendConfigLoader, appConfig *config.ApplicationConfig) *BackendMonitorService {
return &BackendMonitorService{
configLoader: configLoader,
modelLoader: modelLoader,
options: appConfig,
}
}
func (bm BackendMonitor) getModelLoaderIDFromModelName(modelName string) (string, error) {
config, exists := bm.configLoader.GetBackendConfig(modelName)
func (bms BackendMonitorService) getModelLoaderIDFromModelName(modelName string) (string, error) {
config, exists := bms.configLoader.GetBackendConfig(modelName)
var backendId string
if exists {
backendId = config.Model
@ -46,8 +46,8 @@ func (bm BackendMonitor) getModelLoaderIDFromModelName(modelName string) (string
return backendId, nil
}
func (bm *BackendMonitor) SampleLocalBackendProcess(model string) (*schema.BackendMonitorResponse, error) {
config, exists := bm.configLoader.GetBackendConfig(model)
func (bms *BackendMonitorService) SampleLocalBackendProcess(model string) (*schema.BackendMonitorResponse, error) {
config, exists := bms.configLoader.GetBackendConfig(model)
var backend string
if exists {
backend = config.Model
@ -60,7 +60,7 @@ func (bm *BackendMonitor) SampleLocalBackendProcess(model string) (*schema.Backe
backend = fmt.Sprintf("%s.bin", backend)
}
pid, err := bm.modelLoader.GetGRPCPID(backend)
pid, err := bms.modelLoader.GetGRPCPID(backend)
if err != nil {
log.Error().Err(err).Str("model", model).Msg("failed to find GRPC pid")
@ -101,12 +101,12 @@ func (bm *BackendMonitor) SampleLocalBackendProcess(model string) (*schema.Backe
}, nil
}
func (bm BackendMonitor) CheckAndSample(modelName string) (*proto.StatusResponse, error) {
backendId, err := bm.getModelLoaderIDFromModelName(modelName)
func (bms BackendMonitorService) CheckAndSample(modelName string) (*proto.StatusResponse, error) {
backendId, err := bms.getModelLoaderIDFromModelName(modelName)
if err != nil {
return nil, err
}
modelAddr := bm.modelLoader.CheckIsLoaded(backendId)
modelAddr := bms.modelLoader.CheckIsLoaded(backendId)
if modelAddr == "" {
return nil, fmt.Errorf("backend %s is not currently loaded", backendId)
}
@ -114,7 +114,7 @@ func (bm BackendMonitor) CheckAndSample(modelName string) (*proto.StatusResponse
status, rpcErr := modelAddr.GRPC(false, nil).Status(context.TODO())
if rpcErr != nil {
log.Warn().Msgf("backend %s experienced an error retrieving status info: %s", backendId, rpcErr.Error())
val, slbErr := bm.SampleLocalBackendProcess(backendId)
val, slbErr := bms.SampleLocalBackendProcess(backendId)
if slbErr != nil {
return nil, fmt.Errorf("backend %s experienced an error retrieving status info via rpc: %s, then failed local node process sample: %s", backendId, rpcErr.Error(), slbErr.Error())
}
@ -131,10 +131,10 @@ func (bm BackendMonitor) CheckAndSample(modelName string) (*proto.StatusResponse
return status, nil
}
func (bm BackendMonitor) ShutdownModel(modelName string) error {
backendId, err := bm.getModelLoaderIDFromModelName(modelName)
func (bms BackendMonitorService) ShutdownModel(modelName string) error {
backendId, err := bms.getModelLoaderIDFromModelName(modelName)
if err != nil {
return err
}
return bm.modelLoader.ShutdownModel(backendId)
return bms.modelLoader.ShutdownModel(backendId)
}

View file

@ -3,14 +3,18 @@ package services
import (
"context"
"encoding/json"
"errors"
"os"
"path/filepath"
"strings"
"sync"
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/embedded"
"github.com/go-skynet/LocalAI/pkg/downloader"
"github.com/go-skynet/LocalAI/pkg/gallery"
"github.com/go-skynet/LocalAI/pkg/startup"
"github.com/go-skynet/LocalAI/pkg/utils"
"github.com/rs/zerolog/log"
"gopkg.in/yaml.v2"
)
@ -29,18 +33,6 @@ func NewGalleryService(modelPath string) *GalleryService {
}
}
func prepareModel(modelPath string, req gallery.GalleryModel, cl *config.BackendConfigLoader, downloadStatus func(string, string, string, float64)) error {
config, err := gallery.GetGalleryConfigFromURL(req.URL)
if err != nil {
return err
}
config.Files = append(config.Files, req.AdditionalFiles...)
return gallery.InstallModel(modelPath, req.Name, &config, req.Overrides, downloadStatus)
}
func (g *GalleryService) UpdateStatus(s string, op *gallery.GalleryOpStatus) {
g.Lock()
defer g.Unlock()
@ -92,10 +84,10 @@ func (g *GalleryService) Start(c context.Context, cl *config.BackendConfigLoader
err = gallery.InstallModelFromGalleryByName(op.Galleries, op.GalleryName, g.modelPath, op.Req, progressCallback)
}
} else if op.ConfigURL != "" {
startup.PreloadModelsConfigurations(op.ConfigURL, g.modelPath, op.ConfigURL)
PreloadModelsConfigurations(op.ConfigURL, g.modelPath, op.ConfigURL)
err = cl.Preload(g.modelPath)
} else {
err = prepareModel(g.modelPath, op.Req, cl, progressCallback)
err = prepareModel(g.modelPath, op.Req, progressCallback)
}
if err != nil {
@ -127,13 +119,12 @@ type galleryModel struct {
ID string `json:"id"`
}
func processRequests(modelPath, s string, cm *config.BackendConfigLoader, galleries []gallery.Gallery, requests []galleryModel) error {
func processRequests(modelPath string, galleries []gallery.Gallery, requests []galleryModel) error {
var err error
for _, r := range requests {
utils.ResetDownloadTimers()
if r.ID == "" {
err = prepareModel(modelPath, r.GalleryModel, cm, utils.DisplayDownloadFunction)
err = prepareModel(modelPath, r.GalleryModel, utils.DisplayDownloadFunction)
} else {
if strings.Contains(r.ID, "@") {
err = gallery.InstallModelFromGallery(
@ -158,7 +149,7 @@ func ApplyGalleryFromFile(modelPath, s string, cl *config.BackendConfigLoader, g
return err
}
return processRequests(modelPath, s, cl, galleries, requests)
return processRequests(modelPath, galleries, requests)
}
func ApplyGalleryFromString(modelPath, s string, cl *config.BackendConfigLoader, galleries []gallery.Gallery) error {
@ -168,5 +159,90 @@ func ApplyGalleryFromString(modelPath, s string, cl *config.BackendConfigLoader,
return err
}
return processRequests(modelPath, s, cl, galleries, requests)
return processRequests(modelPath, galleries, requests)
}
// PreloadModelsConfigurations will preload models from the given list of URLs
// It will download the model if it is not already present in the model path
// It will also try to resolve if the model is an embedded model YAML configuration
func PreloadModelsConfigurations(modelLibraryURL string, modelPath string, models ...string) {
for _, url := range models {
// As a best effort, try to resolve the model from the remote library
// if it's not resolved we try with the other method below
if modelLibraryURL != "" {
lib, err := embedded.GetRemoteLibraryShorteners(modelLibraryURL)
if err == nil {
if lib[url] != "" {
log.Debug().Msgf("[startup] model configuration is defined remotely: %s (%s)", url, lib[url])
url = lib[url]
}
}
}
url = embedded.ModelShortURL(url)
switch {
case embedded.ExistsInModelsLibrary(url):
modelYAML, err := embedded.ResolveContent(url)
// If we resolve something, just save it to disk and continue
if err != nil {
log.Error().Err(err).Msg("error resolving model content")
continue
}
log.Debug().Msgf("[startup] resolved embedded model: %s", url)
md5Name := utils.MD5(url)
modelDefinitionFilePath := filepath.Join(modelPath, md5Name) + ".yaml"
if err := os.WriteFile(modelDefinitionFilePath, modelYAML, os.ModePerm); err != nil {
log.Error().Err(err).Str("filepath", modelDefinitionFilePath).Msg("error writing model definition")
}
case downloader.LooksLikeURL(url):
log.Debug().Msgf("[startup] resolved model to download: %s", url)
// md5 of model name
md5Name := utils.MD5(url)
// check if file exists
if _, err := os.Stat(filepath.Join(modelPath, md5Name)); errors.Is(err, os.ErrNotExist) {
modelDefinitionFilePath := filepath.Join(modelPath, md5Name) + ".yaml"
err := downloader.DownloadFile(url, modelDefinitionFilePath, "", func(fileName, current, total string, percent float64) {
utils.DisplayDownloadFunction(fileName, current, total, percent)
})
if err != nil {
log.Error().Err(err).Str("url", url).Str("filepath", modelDefinitionFilePath).Msg("error downloading model")
}
}
default:
if _, err := os.Stat(url); err == nil {
log.Debug().Msgf("[startup] resolved local model: %s", url)
// copy to modelPath
md5Name := utils.MD5(url)
modelYAML, err := os.ReadFile(url)
if err != nil {
log.Error().Err(err).Str("filepath", url).Msg("error reading model definition")
continue
}
modelDefinitionFilePath := filepath.Join(modelPath, md5Name) + ".yaml"
if err := os.WriteFile(modelDefinitionFilePath, modelYAML, os.ModePerm); err != nil {
log.Error().Err(err).Str("filepath", modelDefinitionFilePath).Msg("error loading model: %s")
}
} else {
log.Warn().Msgf("[startup] failed resolving model '%s'", url)
}
}
}
}
func prepareModel(modelPath string, req gallery.GalleryModel, downloadStatus func(string, string, string, float64)) error {
config, err := gallery.GetGalleryConfigFromURL(req.URL)
if err != nil {
return err
}
config.Files = append(config.Files, req.AdditionalFiles...)
return gallery.InstallModel(modelPath, req.Name, &config, req.Overrides, downloadStatus)
}

View file

@ -0,0 +1,72 @@
package services
import (
"regexp"
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/go-skynet/LocalAI/pkg/model"
)
type ListModelsService struct {
bcl *config.BackendConfigLoader
ml *model.ModelLoader
appConfig *config.ApplicationConfig
}
func NewListModelsService(ml *model.ModelLoader, bcl *config.BackendConfigLoader, appConfig *config.ApplicationConfig) *ListModelsService {
return &ListModelsService{
bcl: bcl,
ml: ml,
appConfig: appConfig,
}
}
func (lms *ListModelsService) ListModels(filter string, excludeConfigured bool) ([]schema.OpenAIModel, error) {
models, err := lms.ml.ListModels()
if err != nil {
return nil, err
}
var mm map[string]interface{} = map[string]interface{}{}
dataModels := []schema.OpenAIModel{}
var filterFn func(name string) bool
// If filter is not specified, do not filter the list by model name
if filter == "" {
filterFn = func(_ string) bool { return true }
} else {
// If filter _IS_ specified, we compile it to a regex which is used to create the filterFn
rxp, err := regexp.Compile(filter)
if err != nil {
return nil, err
}
filterFn = func(name string) bool {
return rxp.MatchString(name)
}
}
// Start with the known configurations
for _, c := range lms.bcl.GetAllBackendConfigs() {
if excludeConfigured {
mm[c.Model] = nil
}
if filterFn(c.Name) {
dataModels = append(dataModels, schema.OpenAIModel{ID: c.Name, Object: "model"})
}
}
// Then iterate through the loose files:
for _, m := range models {
// And only adds them if they shouldn't be skipped.
if _, exists := mm[m]; !exists && filterFn(m) {
dataModels = append(dataModels, schema.OpenAIModel{ID: m, Object: "model"})
}
}
return dataModels, nil
}

View file

@ -0,0 +1,83 @@
package services_test
import (
"fmt"
"os"
"path/filepath"
"github.com/go-skynet/LocalAI/pkg/utils"
. "github.com/go-skynet/LocalAI/core/services"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("Preload test", func() {
Context("Preloading from strings", func() {
It("loads from remote url", func() {
tmpdir, err := os.MkdirTemp("", "")
Expect(err).ToNot(HaveOccurred())
libraryURL := "https://raw.githubusercontent.com/mudler/LocalAI/master/embedded/model_library.yaml"
fileName := fmt.Sprintf("%s.yaml", "1701d57f28d47552516c2b6ecc3cc719")
PreloadModelsConfigurations(libraryURL, tmpdir, "phi-2")
resultFile := filepath.Join(tmpdir, fileName)
content, err := os.ReadFile(resultFile)
Expect(err).ToNot(HaveOccurred())
Expect(string(content)).To(ContainSubstring("name: phi-2"))
})
It("loads from embedded full-urls", func() {
tmpdir, err := os.MkdirTemp("", "")
Expect(err).ToNot(HaveOccurred())
url := "https://raw.githubusercontent.com/mudler/LocalAI/master/examples/configurations/phi-2.yaml"
fileName := fmt.Sprintf("%s.yaml", utils.MD5(url))
PreloadModelsConfigurations("", tmpdir, url)
resultFile := filepath.Join(tmpdir, fileName)
content, err := os.ReadFile(resultFile)
Expect(err).ToNot(HaveOccurred())
Expect(string(content)).To(ContainSubstring("name: phi-2"))
})
It("loads from embedded short-urls", func() {
tmpdir, err := os.MkdirTemp("", "")
Expect(err).ToNot(HaveOccurred())
url := "phi-2"
PreloadModelsConfigurations("", tmpdir, url)
entry, err := os.ReadDir(tmpdir)
Expect(err).ToNot(HaveOccurred())
Expect(entry).To(HaveLen(1))
resultFile := entry[0].Name()
content, err := os.ReadFile(filepath.Join(tmpdir, resultFile))
Expect(err).ToNot(HaveOccurred())
Expect(string(content)).To(ContainSubstring("name: phi-2"))
})
It("loads from embedded models", func() {
tmpdir, err := os.MkdirTemp("", "")
Expect(err).ToNot(HaveOccurred())
url := "mistral-openorca"
fileName := fmt.Sprintf("%s.yaml", utils.MD5(url))
PreloadModelsConfigurations("", tmpdir, url)
resultFile := filepath.Join(tmpdir, fileName)
content, err := os.ReadFile(resultFile)
Expect(err).ToNot(HaveOccurred())
Expect(string(content)).To(ContainSubstring("name: mistral-openorca"))
})
})
})

805
core/services/openai.go Normal file
View file

@ -0,0 +1,805 @@
package services
import (
"encoding/json"
"errors"
"fmt"
"strings"
"sync"
"time"
"github.com/go-skynet/LocalAI/core/backend"
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/go-skynet/LocalAI/pkg/concurrency"
"github.com/go-skynet/LocalAI/pkg/grammar"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/utils"
"github.com/google/uuid"
"github.com/imdario/mergo"
"github.com/rs/zerolog/log"
)
type endpointGenerationConfigurationFn func(bc *config.BackendConfig, request *schema.OpenAIRequest) endpointConfiguration
type endpointConfiguration struct {
SchemaObject string
TemplatePath string
TemplateData model.PromptTemplateData
ResultMappingFn func(resp *backend.LLMResponse, index int) schema.Choice
CompletionMappingFn func(resp concurrency.ErrorOr[*backend.LLMResponse]) concurrency.ErrorOr[*schema.OpenAIResponse]
TokenMappingFn func(resp concurrency.ErrorOr[*backend.LLMResponse]) concurrency.ErrorOr[*schema.OpenAIResponse]
}
// TODO: This is used for completion and edit. I am pretty sure I forgot parts, but fix it later.
func simpleMapper(resp concurrency.ErrorOr[*backend.LLMResponse]) concurrency.ErrorOr[*schema.OpenAIResponse] {
if resp.Error != nil || resp.Value == nil {
return concurrency.ErrorOr[*schema.OpenAIResponse]{Error: resp.Error}
}
return concurrency.ErrorOr[*schema.OpenAIResponse]{
Value: &schema.OpenAIResponse{
Choices: []schema.Choice{
{
Text: resp.Value.Response,
},
},
Usage: schema.OpenAIUsage{
PromptTokens: resp.Value.Usage.Prompt,
CompletionTokens: resp.Value.Usage.Completion,
TotalTokens: resp.Value.Usage.Prompt + resp.Value.Usage.Completion,
},
},
}
}
// TODO: Consider alternative names for this.
// The purpose of this struct is to hold a reference to the OpenAI request context information
// This keeps things simple within core/services/openai.go and allows consumers to "see" this information if they need it
type OpenAIRequestTraceID struct {
ID string
Created int
}
// This type split out from core/backend/llm.go - I'm still not _totally_ sure about this, but it seems to make sense to keep the generic LLM code from the OpenAI specific higher level functionality
type OpenAIService struct {
bcl *config.BackendConfigLoader
ml *model.ModelLoader
appConfig *config.ApplicationConfig
llmbs *backend.LLMBackendService
}
func NewOpenAIService(ml *model.ModelLoader, bcl *config.BackendConfigLoader, appConfig *config.ApplicationConfig, llmbs *backend.LLMBackendService) *OpenAIService {
return &OpenAIService{
bcl: bcl,
ml: ml,
appConfig: appConfig,
llmbs: llmbs,
}
}
// Keeping in place as a reminder to POTENTIALLY ADD MORE VALIDATION HERE???
func (oais *OpenAIService) getConfig(request *schema.OpenAIRequest) (*config.BackendConfig, *schema.OpenAIRequest, error) {
return oais.bcl.LoadBackendConfigForModelAndOpenAIRequest(request.Model, request, oais.appConfig)
}
// TODO: It would be a lot less messy to make a return struct that had references to each of these channels
// INTENTIONALLY not doing that quite yet - I believe we need to let the references to unused channels die for the GC to automatically collect -- can we manually free()?
// finalResultsChannel is the primary async return path: one result for the entire request.
// promptResultsChannels is DUBIOUS. It's expected to be raw fan-out used within the function itself, but I am exposing for testing? One bundle of LLMResponseBundle per PromptString? Gets all N completions for a single prompt.
// completionsChannel is a channel that emits one *LLMResponse per generated completion, be that different prompts or N. Seems the most useful other than "entire request" Request is available to attempt tracing???
// tokensChannel is a channel that emits one *LLMResponse per generated token. Let's see what happens!
func (oais *OpenAIService) Completion(request *schema.OpenAIRequest, notifyOnPromptResult bool, notifyOnToken bool) (
traceID *OpenAIRequestTraceID, finalResultChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], promptResultsChannels []<-chan concurrency.ErrorOr[*backend.LLMResponseBundle],
completionsChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], tokenChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], err error) {
return oais.GenerateTextFromRequest(request, func(bc *config.BackendConfig, request *schema.OpenAIRequest) endpointConfiguration {
return endpointConfiguration{
SchemaObject: "text_completion",
TemplatePath: bc.TemplateConfig.Completion,
TemplateData: model.PromptTemplateData{
SystemPrompt: bc.SystemPrompt,
},
ResultMappingFn: func(resp *backend.LLMResponse, promptIndex int) schema.Choice {
return schema.Choice{
Index: promptIndex,
FinishReason: "stop",
Text: resp.Response,
}
},
CompletionMappingFn: simpleMapper,
TokenMappingFn: simpleMapper,
}
}, notifyOnPromptResult, notifyOnToken, nil)
}
func (oais *OpenAIService) Edit(request *schema.OpenAIRequest, notifyOnPromptResult bool, notifyOnToken bool) (
traceID *OpenAIRequestTraceID, finalResultChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], promptResultsChannels []<-chan concurrency.ErrorOr[*backend.LLMResponseBundle],
completionsChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], tokenChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], err error) {
return oais.GenerateTextFromRequest(request, func(bc *config.BackendConfig, request *schema.OpenAIRequest) endpointConfiguration {
return endpointConfiguration{
SchemaObject: "edit",
TemplatePath: bc.TemplateConfig.Edit,
TemplateData: model.PromptTemplateData{
SystemPrompt: bc.SystemPrompt,
Instruction: request.Instruction,
},
ResultMappingFn: func(resp *backend.LLMResponse, promptIndex int) schema.Choice {
return schema.Choice{
Index: promptIndex,
FinishReason: "stop",
Text: resp.Response,
}
},
CompletionMappingFn: simpleMapper,
TokenMappingFn: simpleMapper,
}
}, notifyOnPromptResult, notifyOnToken, nil)
}
func (oais *OpenAIService) Chat(request *schema.OpenAIRequest, notifyOnPromptResult bool, notifyOnToken bool) (
traceID *OpenAIRequestTraceID, finalResultChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse],
completionsChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], tokenChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], err error) {
return oais.GenerateFromMultipleMessagesChatRequest(request, notifyOnPromptResult, notifyOnToken, nil)
}
func (oais *OpenAIService) GenerateTextFromRequest(request *schema.OpenAIRequest, endpointConfigFn endpointGenerationConfigurationFn, notifyOnPromptResult bool, notifyOnToken bool, initialTraceID *OpenAIRequestTraceID) (
traceID *OpenAIRequestTraceID, finalResultChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], promptResultsChannels []<-chan concurrency.ErrorOr[*backend.LLMResponseBundle],
completionsChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], tokenChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], err error) {
if initialTraceID == nil {
traceID = &OpenAIRequestTraceID{
ID: uuid.New().String(),
Created: int(time.Now().Unix()),
}
} else {
traceID = initialTraceID
}
bc, request, err := oais.getConfig(request)
if err != nil {
log.Error().Msgf("[oais::GenerateTextFromRequest] error getting configuration: %q", err)
return
}
if request.ResponseFormat.Type == "json_object" {
request.Grammar = grammar.JSONBNF
}
bc.Grammar = request.Grammar
if request.Stream && len(bc.PromptStrings) > 1 {
log.Warn().Msg("potentially cannot handle more than 1 `PromptStrings` when Streaming?")
}
rawFinalResultChannel := make(chan concurrency.ErrorOr[*schema.OpenAIResponse])
finalResultChannel = rawFinalResultChannel
promptResultsChannels = []<-chan concurrency.ErrorOr[*backend.LLMResponseBundle]{}
var rawCompletionsChannel chan concurrency.ErrorOr[*schema.OpenAIResponse]
var rawTokenChannel chan concurrency.ErrorOr[*schema.OpenAIResponse]
if notifyOnPromptResult {
rawCompletionsChannel = make(chan concurrency.ErrorOr[*schema.OpenAIResponse])
}
if notifyOnToken {
rawTokenChannel = make(chan concurrency.ErrorOr[*schema.OpenAIResponse])
}
promptResultsChannelLock := sync.Mutex{}
endpointConfig := endpointConfigFn(bc, request)
if len(endpointConfig.TemplatePath) == 0 {
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
if oais.ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", bc.Model)) {
endpointConfig.TemplatePath = bc.Model
} else {
log.Warn().Msgf("failed to find any template for %+v", request)
}
}
setupWG := sync.WaitGroup{}
var prompts []string
if lPS := len(bc.PromptStrings); lPS > 0 {
setupWG.Add(lPS)
prompts = bc.PromptStrings
} else {
setupWG.Add(len(bc.InputStrings))
prompts = bc.InputStrings
}
var setupError error = nil
for pI, p := range prompts {
go func(promptIndex int, prompt string) {
if endpointConfig.TemplatePath != "" {
promptTemplateData := model.PromptTemplateData{
Input: prompt,
}
err := mergo.Merge(promptTemplateData, endpointConfig.TemplateData, mergo.WithOverride)
if err == nil {
templatedInput, err := oais.ml.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, endpointConfig.TemplatePath, promptTemplateData)
if err == nil {
prompt = templatedInput
log.Debug().Msgf("Template found, input modified to: %s", prompt)
}
}
}
log.Debug().Msgf("[OAIS GenerateTextFromRequest] Prompt: %q", prompt)
promptResultsChannel, completionChannels, tokenChannels, err := oais.llmbs.GenerateText(prompt, request, bc,
func(r *backend.LLMResponse) schema.Choice {
return endpointConfig.ResultMappingFn(r, promptIndex)
}, notifyOnPromptResult, notifyOnToken)
if err != nil {
log.Error().Msgf("Unable to generate text prompt: %q\nerr: %q", prompt, err)
promptResultsChannelLock.Lock()
setupError = errors.Join(setupError, err)
promptResultsChannelLock.Unlock()
setupWG.Done()
return
}
if notifyOnPromptResult {
concurrency.SliceOfChannelsRawMergerWithoutMapping(concurrency.SliceOfChannelsTransformer(completionChannels, endpointConfig.CompletionMappingFn), rawCompletionsChannel, true)
}
if notifyOnToken {
concurrency.SliceOfChannelsRawMergerWithoutMapping(concurrency.SliceOfChannelsTransformer(tokenChannels, endpointConfig.TokenMappingFn), rawTokenChannel, true)
}
promptResultsChannelLock.Lock()
promptResultsChannels = append(promptResultsChannels, promptResultsChannel)
promptResultsChannelLock.Unlock()
setupWG.Done()
}(pI, p)
}
setupWG.Wait()
// If any of the setup goroutines experienced an error, quit early here.
if setupError != nil {
go func() {
log.Error().Msgf("[OAIS GenerateTextFromRequest] caught an error during setup: %q", setupError)
rawFinalResultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: setupError}
close(rawFinalResultChannel)
}()
return
}
initialResponse := &schema.OpenAIResponse{
ID: traceID.ID,
Created: traceID.Created,
Model: request.Model,
Object: endpointConfig.SchemaObject,
Usage: schema.OpenAIUsage{},
}
// utils.SliceOfChannelsRawMerger[[]schema.Choice](promptResultsChannels, rawFinalResultChannel, func(results []schema.Choice) (*schema.OpenAIResponse, error) {
concurrency.SliceOfChannelsReducer(
promptResultsChannels, rawFinalResultChannel,
func(iv concurrency.ErrorOr[*backend.LLMResponseBundle], result concurrency.ErrorOr[*schema.OpenAIResponse]) concurrency.ErrorOr[*schema.OpenAIResponse] {
if iv.Error != nil {
result.Error = iv.Error
return result
}
result.Value.Usage.PromptTokens += iv.Value.Usage.Prompt
result.Value.Usage.CompletionTokens += iv.Value.Usage.Completion
result.Value.Usage.TotalTokens = result.Value.Usage.PromptTokens + result.Value.Usage.CompletionTokens
result.Value.Choices = append(result.Value.Choices, iv.Value.Response...)
return result
}, concurrency.ErrorOr[*schema.OpenAIResponse]{Value: initialResponse}, true)
completionsChannel = rawCompletionsChannel
tokenChannel = rawTokenChannel
return
}
// TODO: For porting sanity, this is distinct from GenerateTextFromRequest and is _currently_ specific to Chat purposes
// this is not a final decision -- just a reality of moving a lot of parts at once
// / This has _become_ Chat which wasn't the goal... More cleanup in the future once it's stable?
func (oais *OpenAIService) GenerateFromMultipleMessagesChatRequest(request *schema.OpenAIRequest, notifyOnPromptResult bool, notifyOnToken bool, initialTraceID *OpenAIRequestTraceID) (
traceID *OpenAIRequestTraceID, finalResultChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse],
completionsChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], tokenChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], err error) {
if initialTraceID == nil {
traceID = &OpenAIRequestTraceID{
ID: uuid.New().String(),
Created: int(time.Now().Unix()),
}
} else {
traceID = initialTraceID
}
bc, request, err := oais.getConfig(request)
if err != nil {
return
}
// Allow the user to set custom actions via config file
// to be "embedded" in each model
noActionName := "answer"
noActionDescription := "use this action to answer without performing any action"
if bc.FunctionsConfig.NoActionFunctionName != "" {
noActionName = bc.FunctionsConfig.NoActionFunctionName
}
if bc.FunctionsConfig.NoActionDescriptionName != "" {
noActionDescription = bc.FunctionsConfig.NoActionDescriptionName
}
if request.ResponseFormat.Type == "json_object" {
request.Grammar = grammar.JSONBNF
}
bc.Grammar = request.Grammar
processFunctions := false
funcs := grammar.Functions{}
// process functions if we have any defined or if we have a function call string
if len(request.Functions) > 0 && bc.ShouldUseFunctions() {
log.Debug().Msgf("Response needs to process functions")
processFunctions = true
noActionGrammar := grammar.Function{
Name: noActionName,
Description: noActionDescription,
Parameters: map[string]interface{}{
"properties": map[string]interface{}{
"message": map[string]interface{}{
"type": "string",
"description": "The message to reply the user with",
}},
},
}
// Append the no action function
funcs = append(funcs, request.Functions...)
if !bc.FunctionsConfig.DisableNoAction {
funcs = append(funcs, noActionGrammar)
}
// Force picking one of the functions by the request
if bc.FunctionToCall() != "" {
funcs = funcs.Select(bc.FunctionToCall())
}
// Update input grammar
jsStruct := funcs.ToJSONStructure()
bc.Grammar = jsStruct.Grammar("", bc.FunctionsConfig.ParallelCalls)
} else if request.JSONFunctionGrammarObject != nil {
bc.Grammar = request.JSONFunctionGrammarObject.Grammar("", bc.FunctionsConfig.ParallelCalls)
}
if request.Stream && processFunctions {
log.Warn().Msg("Streaming + Functions is highly experimental in this version")
}
var predInput string
if !bc.TemplateConfig.UseTokenizerTemplate || processFunctions {
suppressConfigSystemPrompt := false
mess := []string{}
for messageIndex, i := range request.Messages {
var content string
role := i.Role
// if function call, we might want to customize the role so we can display better that the "assistant called a json action"
// if an "assistant_function_call" role is defined, we use it, otherwise we use the role that is passed by in the request
if (i.FunctionCall != nil || i.ToolCalls != nil) && i.Role == "assistant" {
roleFn := "assistant_function_call"
r := bc.Roles[roleFn]
if r != "" {
role = roleFn
}
}
r := bc.Roles[role]
contentExists := i.Content != nil && i.StringContent != ""
fcall := i.FunctionCall
if len(i.ToolCalls) > 0 {
fcall = i.ToolCalls
}
// First attempt to populate content via a chat message specific template
if bc.TemplateConfig.ChatMessage != "" {
chatMessageData := model.ChatMessageTemplateData{
SystemPrompt: bc.SystemPrompt,
Role: r,
RoleName: role,
Content: i.StringContent,
FunctionCall: fcall,
FunctionName: i.Name,
LastMessage: messageIndex == (len(request.Messages) - 1),
Function: bc.Grammar != "" && (messageIndex == (len(request.Messages) - 1)),
MessageIndex: messageIndex,
}
templatedChatMessage, err := oais.ml.EvaluateTemplateForChatMessage(bc.TemplateConfig.ChatMessage, chatMessageData)
if err != nil {
log.Error().Msgf("error processing message %+v using template \"%s\": %v. Skipping!", chatMessageData, bc.TemplateConfig.ChatMessage, err)
} else {
if templatedChatMessage == "" {
log.Warn().Msgf("template \"%s\" produced blank output for %+v. Skipping!", bc.TemplateConfig.ChatMessage, chatMessageData)
continue // TODO: This continue is here intentionally to skip over the line `mess = append(mess, content)` below, and to prevent the sprintf
}
log.Debug().Msgf("templated message for chat: %s", templatedChatMessage)
content = templatedChatMessage
}
}
marshalAnyRole := func(f any) {
j, err := json.Marshal(f)
if err == nil {
if contentExists {
content += "\n" + fmt.Sprint(r, " ", string(j))
} else {
content = fmt.Sprint(r, " ", string(j))
}
}
}
marshalAny := func(f any) {
j, err := json.Marshal(f)
if err == nil {
if contentExists {
content += "\n" + string(j)
} else {
content = string(j)
}
}
}
// If this model doesn't have such a template, or if that template fails to return a value, template at the message level.
if content == "" {
if r != "" {
if contentExists {
content = fmt.Sprint(r, i.StringContent)
}
if i.FunctionCall != nil {
marshalAnyRole(i.FunctionCall)
}
} else {
if contentExists {
content = fmt.Sprint(i.StringContent)
}
if i.FunctionCall != nil {
marshalAny(i.FunctionCall)
}
if i.ToolCalls != nil {
marshalAny(i.ToolCalls)
}
}
// Special Handling: System. We care if it was printed at all, not the r branch, so check seperately
if contentExists && role == "system" {
suppressConfigSystemPrompt = true
}
}
mess = append(mess, content)
}
predInput = strings.Join(mess, "\n")
log.Debug().Msgf("Prompt (before templating): %s", predInput)
templateFile := ""
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
if oais.ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", bc.Model)) {
templateFile = bc.Model
}
if bc.TemplateConfig.Chat != "" && !processFunctions {
templateFile = bc.TemplateConfig.Chat
}
if bc.TemplateConfig.Functions != "" && processFunctions {
templateFile = bc.TemplateConfig.Functions
}
if templateFile != "" {
templatedInput, err := oais.ml.EvaluateTemplateForPrompt(model.ChatPromptTemplate, templateFile, model.PromptTemplateData{
SystemPrompt: bc.SystemPrompt,
SuppressSystemPrompt: suppressConfigSystemPrompt,
Input: predInput,
Functions: funcs,
})
if err == nil {
predInput = templatedInput
log.Debug().Msgf("Template found, input modified to: %s", predInput)
} else {
log.Debug().Msgf("Template failed loading: %s", err.Error())
}
}
}
log.Debug().Msgf("Prompt (after templating): %s", predInput)
if processFunctions {
log.Debug().Msgf("Grammar: %+v", bc.Grammar)
}
rawFinalResultChannel := make(chan concurrency.ErrorOr[*schema.OpenAIResponse])
var rawCompletionsChannel chan concurrency.ErrorOr[*schema.OpenAIResponse]
var rawTokenChannel chan concurrency.ErrorOr[*schema.OpenAIResponse]
if notifyOnPromptResult {
rawCompletionsChannel = make(chan concurrency.ErrorOr[*schema.OpenAIResponse])
}
if notifyOnToken {
rawTokenChannel = make(chan concurrency.ErrorOr[*schema.OpenAIResponse])
}
rawResultChannel, individualCompletionChannels, tokenChannels, err := oais.llmbs.GenerateText(predInput, request, bc, func(resp *backend.LLMResponse) schema.Choice {
return schema.Choice{
Index: 0, // ???
FinishReason: "stop",
Message: &schema.Message{
Role: "assistant",
Content: resp.Response,
},
}
}, notifyOnPromptResult, notifyOnToken)
chatSimpleMappingFn := func(resp concurrency.ErrorOr[*backend.LLMResponse]) concurrency.ErrorOr[*schema.OpenAIResponse] {
if resp.Error != nil || resp.Value == nil {
return concurrency.ErrorOr[*schema.OpenAIResponse]{Error: resp.Error}
}
return concurrency.ErrorOr[*schema.OpenAIResponse]{
Value: &schema.OpenAIResponse{
ID: traceID.ID,
Created: traceID.Created,
Model: request.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{
{
Delta: &schema.Message{
Role: "assistant",
Content: resp.Value.Response,
},
Index: 0,
},
},
Object: "chat.completion.chunk",
Usage: schema.OpenAIUsage{
PromptTokens: resp.Value.Usage.Prompt,
CompletionTokens: resp.Value.Usage.Completion,
TotalTokens: resp.Value.Usage.Prompt + resp.Value.Usage.Completion,
},
},
}
}
if notifyOnPromptResult {
concurrency.SliceOfChannelsRawMergerWithoutMapping(concurrency.SliceOfChannelsTransformer(individualCompletionChannels, chatSimpleMappingFn), rawCompletionsChannel, true)
}
if notifyOnToken {
concurrency.SliceOfChannelsRawMergerWithoutMapping(concurrency.SliceOfChannelsTransformer(tokenChannels, chatSimpleMappingFn), rawTokenChannel, true)
}
go func() {
rawResult := <-rawResultChannel
if rawResult.Error != nil {
log.Warn().Msgf("OpenAIService::processTools GenerateText error [DEBUG THIS?] %q", rawResult.Error)
return
}
llmResponseChoices := rawResult.Value.Response
if processFunctions && len(llmResponseChoices) > 1 {
log.Warn().Msgf("chat functions response with %d choices in response, debug this?", len(llmResponseChoices))
log.Debug().Msgf("%+v", llmResponseChoices)
}
for _, result := range rawResult.Value.Response {
// If no functions, just return the raw result.
if !processFunctions {
resp := schema.OpenAIResponse{
ID: traceID.ID,
Created: traceID.Created,
Model: request.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{result},
Object: "chat.completion.chunk",
Usage: schema.OpenAIUsage{
PromptTokens: rawResult.Value.Usage.Prompt,
CompletionTokens: rawResult.Value.Usage.Completion,
TotalTokens: rawResult.Value.Usage.Prompt + rawResult.Value.Usage.Prompt,
},
}
rawFinalResultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Value: &resp}
continue
}
// At this point, things are function specific!
// Oh no this can't be the right way to do this... but it works. Save us, mudler!
fString := fmt.Sprintf("%s", result.Message.Content)
results := parseFunctionCall(fString, bc.FunctionsConfig.ParallelCalls)
noActionToRun := (len(results) > 0 && results[0].name == noActionName)
if noActionToRun {
log.Debug().Msg("-- noActionToRun branch --")
initialMessage := schema.OpenAIResponse{
ID: traceID.ID,
Created: traceID.Created,
Model: request.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant", Content: ""}}},
Object: "stop",
}
rawFinalResultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Value: &initialMessage}
result, err := oais.handleQuestion(bc, request, results[0].arguments, predInput)
if err != nil {
log.Error().Msgf("error handling question: %s", err.Error())
return
}
resp := schema.OpenAIResponse{
ID: traceID.ID,
Created: traceID.Created,
Model: request.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{{Delta: &schema.Message{Content: &result}, Index: 0}},
Object: "chat.completion.chunk",
Usage: schema.OpenAIUsage{
PromptTokens: rawResult.Value.Usage.Prompt,
CompletionTokens: rawResult.Value.Usage.Completion,
TotalTokens: rawResult.Value.Usage.Prompt + rawResult.Value.Usage.Prompt,
},
}
rawFinalResultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Value: &resp}
} else {
log.Debug().Msgf("[GenerateFromMultipleMessagesChatRequest] fnResultsBranch: %+v", results)
for i, ss := range results {
name, args := ss.name, ss.arguments
initialMessage := schema.OpenAIResponse{
ID: traceID.ID,
Created: traceID.Created,
Model: request.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{{
FinishReason: "function_call",
Message: &schema.Message{
Role: "assistant",
ToolCalls: []schema.ToolCall{
{
Index: i,
ID: traceID.ID,
Type: "function",
FunctionCall: schema.FunctionCall{
Name: name,
Arguments: args,
},
},
},
}}},
Object: "chat.completion.chunk",
}
rawFinalResultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Value: &initialMessage}
}
}
}
close(rawFinalResultChannel)
}()
finalResultChannel = rawFinalResultChannel
completionsChannel = rawCompletionsChannel
tokenChannel = rawTokenChannel
return
}
func (oais *OpenAIService) handleQuestion(config *config.BackendConfig, input *schema.OpenAIRequest, args, prompt string) (string, error) {
log.Debug().Msgf("[handleQuestion called] nothing to do, computing a reply")
// If there is a message that the LLM already sends as part of the JSON reply, use it
arguments := map[string]interface{}{}
json.Unmarshal([]byte(args), &arguments)
m, exists := arguments["message"]
if exists {
switch message := m.(type) {
case string:
if message != "" {
log.Debug().Msgf("Reply received from LLM: %s", message)
message = oais.llmbs.Finetune(*config, prompt, message)
log.Debug().Msgf("Reply received from LLM(finetuned): %s", message)
return message, nil
}
}
}
log.Debug().Msgf("No action received from LLM, without a message, computing a reply")
// Otherwise ask the LLM to understand the JSON output and the context, and return a message
// Note: This costs (in term of CPU/GPU) another computation
config.Grammar = ""
images := []string{}
for _, m := range input.Messages {
images = append(images, m.StringImages...)
}
resultChannel, _, err := oais.llmbs.Inference(input.Context, &backend.LLMRequest{
Text: prompt,
Images: images,
RawMessages: input.Messages, // Experimental
}, config, false)
if err != nil {
log.Error().Msgf("inference setup error: %s", err.Error())
return "", err
}
raw := <-resultChannel
if raw.Error != nil {
log.Error().Msgf("inference error: %q", raw.Error.Error())
return "", err
}
if raw.Value == nil {
log.Warn().Msgf("nil inference response")
return "", nil
}
return oais.llmbs.Finetune(*config, prompt, raw.Value.Response), nil
}
type funcCallResults struct {
name string
arguments string
}
func parseFunctionCall(llmresult string, multipleResults bool) []funcCallResults {
results := []funcCallResults{}
// TODO: use generics to avoid this code duplication
if multipleResults {
ss := []map[string]interface{}{}
s := utils.EscapeNewLines(llmresult)
json.Unmarshal([]byte(s), &ss)
for _, s := range ss {
func_name, ok := s["function"]
if !ok {
continue
}
args, ok := s["arguments"]
if !ok {
continue
}
d, _ := json.Marshal(args)
funcName, ok := func_name.(string)
if !ok {
continue
}
results = append(results, funcCallResults{name: funcName, arguments: string(d)})
}
} else {
// As we have to change the result before processing, we can't stream the answer token-by-token (yet?)
ss := map[string]interface{}{}
// This prevent newlines to break JSON parsing for clients
// s := utils.EscapeNewLines(llmresult)
json.Unmarshal([]byte(llmresult), &ss)
// The grammar defines the function name as "function", while OpenAI returns "name"
func_name, ok := ss["function"]
if !ok {
log.Debug().Msg("ss[function] is not OK!")
return results
}
// Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object
args, ok := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix)
if !ok {
log.Debug().Msg("ss[arguments] is not OK!")
return results
}
d, _ := json.Marshal(args)
funcName, ok := func_name.(string)
if !ok {
log.Debug().Msgf("unexpected func_name: %+v", func_name)
return results
}
results = append(results, funcCallResults{name: funcName, arguments: string(d)})
}
return results
}