mirror of
https://github.com/mudler/LocalAI.git
synced 2025-05-28 06:25:00 +00:00
refactor: backend/service split, channel-based llm flow (#1963)
Refactor: channel based llm flow and services split --------- Signed-off-by: Dave Lee <dave@gray101.com>
This commit is contained in:
parent
1981154f49
commit
eed5706994
52 changed files with 3064 additions and 2279 deletions
|
@ -15,22 +15,22 @@ import (
|
|||
gopsutil "github.com/shirou/gopsutil/v3/process"
|
||||
)
|
||||
|
||||
type BackendMonitor struct {
|
||||
type BackendMonitorService struct {
|
||||
configLoader *config.BackendConfigLoader
|
||||
modelLoader *model.ModelLoader
|
||||
options *config.ApplicationConfig // Taking options in case we need to inspect ExternalGRPCBackends, though that's out of scope for now, hence the name.
|
||||
}
|
||||
|
||||
func NewBackendMonitor(configLoader *config.BackendConfigLoader, modelLoader *model.ModelLoader, appConfig *config.ApplicationConfig) BackendMonitor {
|
||||
return BackendMonitor{
|
||||
func NewBackendMonitorService(modelLoader *model.ModelLoader, configLoader *config.BackendConfigLoader, appConfig *config.ApplicationConfig) *BackendMonitorService {
|
||||
return &BackendMonitorService{
|
||||
configLoader: configLoader,
|
||||
modelLoader: modelLoader,
|
||||
options: appConfig,
|
||||
}
|
||||
}
|
||||
|
||||
func (bm BackendMonitor) getModelLoaderIDFromModelName(modelName string) (string, error) {
|
||||
config, exists := bm.configLoader.GetBackendConfig(modelName)
|
||||
func (bms BackendMonitorService) getModelLoaderIDFromModelName(modelName string) (string, error) {
|
||||
config, exists := bms.configLoader.GetBackendConfig(modelName)
|
||||
var backendId string
|
||||
if exists {
|
||||
backendId = config.Model
|
||||
|
@ -46,8 +46,8 @@ func (bm BackendMonitor) getModelLoaderIDFromModelName(modelName string) (string
|
|||
return backendId, nil
|
||||
}
|
||||
|
||||
func (bm *BackendMonitor) SampleLocalBackendProcess(model string) (*schema.BackendMonitorResponse, error) {
|
||||
config, exists := bm.configLoader.GetBackendConfig(model)
|
||||
func (bms *BackendMonitorService) SampleLocalBackendProcess(model string) (*schema.BackendMonitorResponse, error) {
|
||||
config, exists := bms.configLoader.GetBackendConfig(model)
|
||||
var backend string
|
||||
if exists {
|
||||
backend = config.Model
|
||||
|
@ -60,7 +60,7 @@ func (bm *BackendMonitor) SampleLocalBackendProcess(model string) (*schema.Backe
|
|||
backend = fmt.Sprintf("%s.bin", backend)
|
||||
}
|
||||
|
||||
pid, err := bm.modelLoader.GetGRPCPID(backend)
|
||||
pid, err := bms.modelLoader.GetGRPCPID(backend)
|
||||
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("model", model).Msg("failed to find GRPC pid")
|
||||
|
@ -101,12 +101,12 @@ func (bm *BackendMonitor) SampleLocalBackendProcess(model string) (*schema.Backe
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (bm BackendMonitor) CheckAndSample(modelName string) (*proto.StatusResponse, error) {
|
||||
backendId, err := bm.getModelLoaderIDFromModelName(modelName)
|
||||
func (bms BackendMonitorService) CheckAndSample(modelName string) (*proto.StatusResponse, error) {
|
||||
backendId, err := bms.getModelLoaderIDFromModelName(modelName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
modelAddr := bm.modelLoader.CheckIsLoaded(backendId)
|
||||
modelAddr := bms.modelLoader.CheckIsLoaded(backendId)
|
||||
if modelAddr == "" {
|
||||
return nil, fmt.Errorf("backend %s is not currently loaded", backendId)
|
||||
}
|
||||
|
@ -114,7 +114,7 @@ func (bm BackendMonitor) CheckAndSample(modelName string) (*proto.StatusResponse
|
|||
status, rpcErr := modelAddr.GRPC(false, nil).Status(context.TODO())
|
||||
if rpcErr != nil {
|
||||
log.Warn().Msgf("backend %s experienced an error retrieving status info: %s", backendId, rpcErr.Error())
|
||||
val, slbErr := bm.SampleLocalBackendProcess(backendId)
|
||||
val, slbErr := bms.SampleLocalBackendProcess(backendId)
|
||||
if slbErr != nil {
|
||||
return nil, fmt.Errorf("backend %s experienced an error retrieving status info via rpc: %s, then failed local node process sample: %s", backendId, rpcErr.Error(), slbErr.Error())
|
||||
}
|
||||
|
@ -131,10 +131,10 @@ func (bm BackendMonitor) CheckAndSample(modelName string) (*proto.StatusResponse
|
|||
return status, nil
|
||||
}
|
||||
|
||||
func (bm BackendMonitor) ShutdownModel(modelName string) error {
|
||||
backendId, err := bm.getModelLoaderIDFromModelName(modelName)
|
||||
func (bms BackendMonitorService) ShutdownModel(modelName string) error {
|
||||
backendId, err := bms.getModelLoaderIDFromModelName(modelName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return bm.modelLoader.ShutdownModel(backendId)
|
||||
return bms.modelLoader.ShutdownModel(backendId)
|
||||
}
|
||||
|
|
|
@ -3,14 +3,18 @@ package services
|
|||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/go-skynet/LocalAI/core/config"
|
||||
"github.com/go-skynet/LocalAI/embedded"
|
||||
"github.com/go-skynet/LocalAI/pkg/downloader"
|
||||
"github.com/go-skynet/LocalAI/pkg/gallery"
|
||||
"github.com/go-skynet/LocalAI/pkg/startup"
|
||||
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||
"github.com/rs/zerolog/log"
|
||||
"gopkg.in/yaml.v2"
|
||||
)
|
||||
|
||||
|
@ -29,18 +33,6 @@ func NewGalleryService(modelPath string) *GalleryService {
|
|||
}
|
||||
}
|
||||
|
||||
func prepareModel(modelPath string, req gallery.GalleryModel, cl *config.BackendConfigLoader, downloadStatus func(string, string, string, float64)) error {
|
||||
|
||||
config, err := gallery.GetGalleryConfigFromURL(req.URL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
config.Files = append(config.Files, req.AdditionalFiles...)
|
||||
|
||||
return gallery.InstallModel(modelPath, req.Name, &config, req.Overrides, downloadStatus)
|
||||
}
|
||||
|
||||
func (g *GalleryService) UpdateStatus(s string, op *gallery.GalleryOpStatus) {
|
||||
g.Lock()
|
||||
defer g.Unlock()
|
||||
|
@ -92,10 +84,10 @@ func (g *GalleryService) Start(c context.Context, cl *config.BackendConfigLoader
|
|||
err = gallery.InstallModelFromGalleryByName(op.Galleries, op.GalleryName, g.modelPath, op.Req, progressCallback)
|
||||
}
|
||||
} else if op.ConfigURL != "" {
|
||||
startup.PreloadModelsConfigurations(op.ConfigURL, g.modelPath, op.ConfigURL)
|
||||
PreloadModelsConfigurations(op.ConfigURL, g.modelPath, op.ConfigURL)
|
||||
err = cl.Preload(g.modelPath)
|
||||
} else {
|
||||
err = prepareModel(g.modelPath, op.Req, cl, progressCallback)
|
||||
err = prepareModel(g.modelPath, op.Req, progressCallback)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
|
@ -127,13 +119,12 @@ type galleryModel struct {
|
|||
ID string `json:"id"`
|
||||
}
|
||||
|
||||
func processRequests(modelPath, s string, cm *config.BackendConfigLoader, galleries []gallery.Gallery, requests []galleryModel) error {
|
||||
func processRequests(modelPath string, galleries []gallery.Gallery, requests []galleryModel) error {
|
||||
var err error
|
||||
for _, r := range requests {
|
||||
utils.ResetDownloadTimers()
|
||||
if r.ID == "" {
|
||||
err = prepareModel(modelPath, r.GalleryModel, cm, utils.DisplayDownloadFunction)
|
||||
|
||||
err = prepareModel(modelPath, r.GalleryModel, utils.DisplayDownloadFunction)
|
||||
} else {
|
||||
if strings.Contains(r.ID, "@") {
|
||||
err = gallery.InstallModelFromGallery(
|
||||
|
@ -158,7 +149,7 @@ func ApplyGalleryFromFile(modelPath, s string, cl *config.BackendConfigLoader, g
|
|||
return err
|
||||
}
|
||||
|
||||
return processRequests(modelPath, s, cl, galleries, requests)
|
||||
return processRequests(modelPath, galleries, requests)
|
||||
}
|
||||
|
||||
func ApplyGalleryFromString(modelPath, s string, cl *config.BackendConfigLoader, galleries []gallery.Gallery) error {
|
||||
|
@ -168,5 +159,90 @@ func ApplyGalleryFromString(modelPath, s string, cl *config.BackendConfigLoader,
|
|||
return err
|
||||
}
|
||||
|
||||
return processRequests(modelPath, s, cl, galleries, requests)
|
||||
return processRequests(modelPath, galleries, requests)
|
||||
}
|
||||
|
||||
// PreloadModelsConfigurations will preload models from the given list of URLs
|
||||
// It will download the model if it is not already present in the model path
|
||||
// It will also try to resolve if the model is an embedded model YAML configuration
|
||||
func PreloadModelsConfigurations(modelLibraryURL string, modelPath string, models ...string) {
|
||||
for _, url := range models {
|
||||
|
||||
// As a best effort, try to resolve the model from the remote library
|
||||
// if it's not resolved we try with the other method below
|
||||
if modelLibraryURL != "" {
|
||||
lib, err := embedded.GetRemoteLibraryShorteners(modelLibraryURL)
|
||||
if err == nil {
|
||||
if lib[url] != "" {
|
||||
log.Debug().Msgf("[startup] model configuration is defined remotely: %s (%s)", url, lib[url])
|
||||
url = lib[url]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
url = embedded.ModelShortURL(url)
|
||||
switch {
|
||||
case embedded.ExistsInModelsLibrary(url):
|
||||
modelYAML, err := embedded.ResolveContent(url)
|
||||
// If we resolve something, just save it to disk and continue
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("error resolving model content")
|
||||
continue
|
||||
}
|
||||
|
||||
log.Debug().Msgf("[startup] resolved embedded model: %s", url)
|
||||
md5Name := utils.MD5(url)
|
||||
modelDefinitionFilePath := filepath.Join(modelPath, md5Name) + ".yaml"
|
||||
if err := os.WriteFile(modelDefinitionFilePath, modelYAML, os.ModePerm); err != nil {
|
||||
log.Error().Err(err).Str("filepath", modelDefinitionFilePath).Msg("error writing model definition")
|
||||
}
|
||||
case downloader.LooksLikeURL(url):
|
||||
log.Debug().Msgf("[startup] resolved model to download: %s", url)
|
||||
|
||||
// md5 of model name
|
||||
md5Name := utils.MD5(url)
|
||||
|
||||
// check if file exists
|
||||
if _, err := os.Stat(filepath.Join(modelPath, md5Name)); errors.Is(err, os.ErrNotExist) {
|
||||
modelDefinitionFilePath := filepath.Join(modelPath, md5Name) + ".yaml"
|
||||
err := downloader.DownloadFile(url, modelDefinitionFilePath, "", func(fileName, current, total string, percent float64) {
|
||||
utils.DisplayDownloadFunction(fileName, current, total, percent)
|
||||
})
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("url", url).Str("filepath", modelDefinitionFilePath).Msg("error downloading model")
|
||||
}
|
||||
}
|
||||
default:
|
||||
if _, err := os.Stat(url); err == nil {
|
||||
log.Debug().Msgf("[startup] resolved local model: %s", url)
|
||||
// copy to modelPath
|
||||
md5Name := utils.MD5(url)
|
||||
|
||||
modelYAML, err := os.ReadFile(url)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("filepath", url).Msg("error reading model definition")
|
||||
continue
|
||||
}
|
||||
|
||||
modelDefinitionFilePath := filepath.Join(modelPath, md5Name) + ".yaml"
|
||||
if err := os.WriteFile(modelDefinitionFilePath, modelYAML, os.ModePerm); err != nil {
|
||||
log.Error().Err(err).Str("filepath", modelDefinitionFilePath).Msg("error loading model: %s")
|
||||
}
|
||||
} else {
|
||||
log.Warn().Msgf("[startup] failed resolving model '%s'", url)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func prepareModel(modelPath string, req gallery.GalleryModel, downloadStatus func(string, string, string, float64)) error {
|
||||
|
||||
config, err := gallery.GetGalleryConfigFromURL(req.URL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
config.Files = append(config.Files, req.AdditionalFiles...)
|
||||
|
||||
return gallery.InstallModel(modelPath, req.Name, &config, req.Overrides, downloadStatus)
|
||||
}
|
||||
|
|
72
core/services/list_models.go
Normal file
72
core/services/list_models.go
Normal file
|
@ -0,0 +1,72 @@
|
|||
package services
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
|
||||
"github.com/go-skynet/LocalAI/core/config"
|
||||
"github.com/go-skynet/LocalAI/core/schema"
|
||||
"github.com/go-skynet/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
type ListModelsService struct {
|
||||
bcl *config.BackendConfigLoader
|
||||
ml *model.ModelLoader
|
||||
appConfig *config.ApplicationConfig
|
||||
}
|
||||
|
||||
func NewListModelsService(ml *model.ModelLoader, bcl *config.BackendConfigLoader, appConfig *config.ApplicationConfig) *ListModelsService {
|
||||
return &ListModelsService{
|
||||
bcl: bcl,
|
||||
ml: ml,
|
||||
appConfig: appConfig,
|
||||
}
|
||||
}
|
||||
|
||||
func (lms *ListModelsService) ListModels(filter string, excludeConfigured bool) ([]schema.OpenAIModel, error) {
|
||||
|
||||
models, err := lms.ml.ListModels()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var mm map[string]interface{} = map[string]interface{}{}
|
||||
|
||||
dataModels := []schema.OpenAIModel{}
|
||||
|
||||
var filterFn func(name string) bool
|
||||
|
||||
// If filter is not specified, do not filter the list by model name
|
||||
if filter == "" {
|
||||
filterFn = func(_ string) bool { return true }
|
||||
} else {
|
||||
// If filter _IS_ specified, we compile it to a regex which is used to create the filterFn
|
||||
rxp, err := regexp.Compile(filter)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
filterFn = func(name string) bool {
|
||||
return rxp.MatchString(name)
|
||||
}
|
||||
}
|
||||
|
||||
// Start with the known configurations
|
||||
for _, c := range lms.bcl.GetAllBackendConfigs() {
|
||||
if excludeConfigured {
|
||||
mm[c.Model] = nil
|
||||
}
|
||||
|
||||
if filterFn(c.Name) {
|
||||
dataModels = append(dataModels, schema.OpenAIModel{ID: c.Name, Object: "model"})
|
||||
}
|
||||
}
|
||||
|
||||
// Then iterate through the loose files:
|
||||
for _, m := range models {
|
||||
// And only adds them if they shouldn't be skipped.
|
||||
if _, exists := mm[m]; !exists && filterFn(m) {
|
||||
dataModels = append(dataModels, schema.OpenAIModel{ID: m, Object: "model"})
|
||||
}
|
||||
}
|
||||
|
||||
return dataModels, nil
|
||||
}
|
83
core/services/model_preload_test.go
Normal file
83
core/services/model_preload_test.go
Normal file
|
@ -0,0 +1,83 @@
|
|||
package services_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||
|
||||
. "github.com/go-skynet/LocalAI/core/services"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Preload test", func() {
|
||||
|
||||
Context("Preloading from strings", func() {
|
||||
It("loads from remote url", func() {
|
||||
tmpdir, err := os.MkdirTemp("", "")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
libraryURL := "https://raw.githubusercontent.com/mudler/LocalAI/master/embedded/model_library.yaml"
|
||||
fileName := fmt.Sprintf("%s.yaml", "1701d57f28d47552516c2b6ecc3cc719")
|
||||
|
||||
PreloadModelsConfigurations(libraryURL, tmpdir, "phi-2")
|
||||
|
||||
resultFile := filepath.Join(tmpdir, fileName)
|
||||
|
||||
content, err := os.ReadFile(resultFile)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
Expect(string(content)).To(ContainSubstring("name: phi-2"))
|
||||
})
|
||||
|
||||
It("loads from embedded full-urls", func() {
|
||||
tmpdir, err := os.MkdirTemp("", "")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
url := "https://raw.githubusercontent.com/mudler/LocalAI/master/examples/configurations/phi-2.yaml"
|
||||
fileName := fmt.Sprintf("%s.yaml", utils.MD5(url))
|
||||
|
||||
PreloadModelsConfigurations("", tmpdir, url)
|
||||
|
||||
resultFile := filepath.Join(tmpdir, fileName)
|
||||
|
||||
content, err := os.ReadFile(resultFile)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
Expect(string(content)).To(ContainSubstring("name: phi-2"))
|
||||
})
|
||||
It("loads from embedded short-urls", func() {
|
||||
tmpdir, err := os.MkdirTemp("", "")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
url := "phi-2"
|
||||
|
||||
PreloadModelsConfigurations("", tmpdir, url)
|
||||
|
||||
entry, err := os.ReadDir(tmpdir)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(entry).To(HaveLen(1))
|
||||
resultFile := entry[0].Name()
|
||||
|
||||
content, err := os.ReadFile(filepath.Join(tmpdir, resultFile))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
Expect(string(content)).To(ContainSubstring("name: phi-2"))
|
||||
})
|
||||
It("loads from embedded models", func() {
|
||||
tmpdir, err := os.MkdirTemp("", "")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
url := "mistral-openorca"
|
||||
fileName := fmt.Sprintf("%s.yaml", utils.MD5(url))
|
||||
|
||||
PreloadModelsConfigurations("", tmpdir, url)
|
||||
|
||||
resultFile := filepath.Join(tmpdir, fileName)
|
||||
|
||||
content, err := os.ReadFile(resultFile)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
Expect(string(content)).To(ContainSubstring("name: mistral-openorca"))
|
||||
})
|
||||
})
|
||||
})
|
805
core/services/openai.go
Normal file
805
core/services/openai.go
Normal file
|
@ -0,0 +1,805 @@
|
|||
package services
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-skynet/LocalAI/core/backend"
|
||||
"github.com/go-skynet/LocalAI/core/config"
|
||||
"github.com/go-skynet/LocalAI/core/schema"
|
||||
"github.com/go-skynet/LocalAI/pkg/concurrency"
|
||||
"github.com/go-skynet/LocalAI/pkg/grammar"
|
||||
"github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||
"github.com/google/uuid"
|
||||
"github.com/imdario/mergo"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
type endpointGenerationConfigurationFn func(bc *config.BackendConfig, request *schema.OpenAIRequest) endpointConfiguration
|
||||
|
||||
type endpointConfiguration struct {
|
||||
SchemaObject string
|
||||
TemplatePath string
|
||||
TemplateData model.PromptTemplateData
|
||||
ResultMappingFn func(resp *backend.LLMResponse, index int) schema.Choice
|
||||
CompletionMappingFn func(resp concurrency.ErrorOr[*backend.LLMResponse]) concurrency.ErrorOr[*schema.OpenAIResponse]
|
||||
TokenMappingFn func(resp concurrency.ErrorOr[*backend.LLMResponse]) concurrency.ErrorOr[*schema.OpenAIResponse]
|
||||
}
|
||||
|
||||
// TODO: This is used for completion and edit. I am pretty sure I forgot parts, but fix it later.
|
||||
func simpleMapper(resp concurrency.ErrorOr[*backend.LLMResponse]) concurrency.ErrorOr[*schema.OpenAIResponse] {
|
||||
if resp.Error != nil || resp.Value == nil {
|
||||
return concurrency.ErrorOr[*schema.OpenAIResponse]{Error: resp.Error}
|
||||
}
|
||||
return concurrency.ErrorOr[*schema.OpenAIResponse]{
|
||||
Value: &schema.OpenAIResponse{
|
||||
Choices: []schema.Choice{
|
||||
{
|
||||
Text: resp.Value.Response,
|
||||
},
|
||||
},
|
||||
Usage: schema.OpenAIUsage{
|
||||
PromptTokens: resp.Value.Usage.Prompt,
|
||||
CompletionTokens: resp.Value.Usage.Completion,
|
||||
TotalTokens: resp.Value.Usage.Prompt + resp.Value.Usage.Completion,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// OpenAIRequestTraceID holds the request-scoped identity (ID plus creation
// timestamp) shared by every response emitted for a single OpenAI request,
// letting consumers correlate streamed chunks with the final result.
// TODO: Consider alternative names for this.
type OpenAIRequestTraceID struct {
	ID      string
	Created int
}
|
||||
|
||||
// This type split out from core/backend/llm.go - I'm still not _totally_ sure about this, but it seems to make sense to keep the generic LLM code from the OpenAI specific higher level functionality
|
||||
type OpenAIService struct {
|
||||
bcl *config.BackendConfigLoader
|
||||
ml *model.ModelLoader
|
||||
appConfig *config.ApplicationConfig
|
||||
llmbs *backend.LLMBackendService
|
||||
}
|
||||
|
||||
func NewOpenAIService(ml *model.ModelLoader, bcl *config.BackendConfigLoader, appConfig *config.ApplicationConfig, llmbs *backend.LLMBackendService) *OpenAIService {
|
||||
return &OpenAIService{
|
||||
bcl: bcl,
|
||||
ml: ml,
|
||||
appConfig: appConfig,
|
||||
llmbs: llmbs,
|
||||
}
|
||||
}
|
||||
|
||||
// Keeping in place as a reminder to POTENTIALLY ADD MORE VALIDATION HERE???
|
||||
func (oais *OpenAIService) getConfig(request *schema.OpenAIRequest) (*config.BackendConfig, *schema.OpenAIRequest, error) {
|
||||
return oais.bcl.LoadBackendConfigForModelAndOpenAIRequest(request.Model, request, oais.appConfig)
|
||||
}
|
||||
|
||||
// TODO: It would be a lot less messy to make a return struct that had references to each of these channels
|
||||
// INTENTIONALLY not doing that quite yet - I believe we need to let the references to unused channels die for the GC to automatically collect -- can we manually free()?
|
||||
// finalResultsChannel is the primary async return path: one result for the entire request.
|
||||
// promptResultsChannels is DUBIOUS. It's expected to be raw fan-out used within the function itself, but I am exposing for testing? One bundle of LLMResponseBundle per PromptString? Gets all N completions for a single prompt.
|
||||
// completionsChannel is a channel that emits one *LLMResponse per generated completion, be that different prompts or N. Seems the most useful other than "entire request" Request is available to attempt tracing???
|
||||
// tokensChannel is a channel that emits one *LLMResponse per generated token. Let's see what happens!
|
||||
func (oais *OpenAIService) Completion(request *schema.OpenAIRequest, notifyOnPromptResult bool, notifyOnToken bool) (
|
||||
traceID *OpenAIRequestTraceID, finalResultChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], promptResultsChannels []<-chan concurrency.ErrorOr[*backend.LLMResponseBundle],
|
||||
completionsChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], tokenChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], err error) {
|
||||
|
||||
return oais.GenerateTextFromRequest(request, func(bc *config.BackendConfig, request *schema.OpenAIRequest) endpointConfiguration {
|
||||
return endpointConfiguration{
|
||||
SchemaObject: "text_completion",
|
||||
TemplatePath: bc.TemplateConfig.Completion,
|
||||
TemplateData: model.PromptTemplateData{
|
||||
SystemPrompt: bc.SystemPrompt,
|
||||
},
|
||||
ResultMappingFn: func(resp *backend.LLMResponse, promptIndex int) schema.Choice {
|
||||
return schema.Choice{
|
||||
Index: promptIndex,
|
||||
FinishReason: "stop",
|
||||
Text: resp.Response,
|
||||
}
|
||||
},
|
||||
CompletionMappingFn: simpleMapper,
|
||||
TokenMappingFn: simpleMapper,
|
||||
}
|
||||
}, notifyOnPromptResult, notifyOnToken, nil)
|
||||
}
|
||||
|
||||
func (oais *OpenAIService) Edit(request *schema.OpenAIRequest, notifyOnPromptResult bool, notifyOnToken bool) (
|
||||
traceID *OpenAIRequestTraceID, finalResultChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], promptResultsChannels []<-chan concurrency.ErrorOr[*backend.LLMResponseBundle],
|
||||
completionsChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], tokenChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], err error) {
|
||||
|
||||
return oais.GenerateTextFromRequest(request, func(bc *config.BackendConfig, request *schema.OpenAIRequest) endpointConfiguration {
|
||||
|
||||
return endpointConfiguration{
|
||||
SchemaObject: "edit",
|
||||
TemplatePath: bc.TemplateConfig.Edit,
|
||||
TemplateData: model.PromptTemplateData{
|
||||
SystemPrompt: bc.SystemPrompt,
|
||||
Instruction: request.Instruction,
|
||||
},
|
||||
ResultMappingFn: func(resp *backend.LLMResponse, promptIndex int) schema.Choice {
|
||||
return schema.Choice{
|
||||
Index: promptIndex,
|
||||
FinishReason: "stop",
|
||||
Text: resp.Response,
|
||||
}
|
||||
},
|
||||
CompletionMappingFn: simpleMapper,
|
||||
TokenMappingFn: simpleMapper,
|
||||
}
|
||||
}, notifyOnPromptResult, notifyOnToken, nil)
|
||||
}
|
||||
|
||||
func (oais *OpenAIService) Chat(request *schema.OpenAIRequest, notifyOnPromptResult bool, notifyOnToken bool) (
|
||||
traceID *OpenAIRequestTraceID, finalResultChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse],
|
||||
completionsChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], tokenChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], err error) {
|
||||
|
||||
return oais.GenerateFromMultipleMessagesChatRequest(request, notifyOnPromptResult, notifyOnToken, nil)
|
||||
}
|
||||
|
||||
func (oais *OpenAIService) GenerateTextFromRequest(request *schema.OpenAIRequest, endpointConfigFn endpointGenerationConfigurationFn, notifyOnPromptResult bool, notifyOnToken bool, initialTraceID *OpenAIRequestTraceID) (
|
||||
traceID *OpenAIRequestTraceID, finalResultChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], promptResultsChannels []<-chan concurrency.ErrorOr[*backend.LLMResponseBundle],
|
||||
completionsChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], tokenChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], err error) {
|
||||
|
||||
if initialTraceID == nil {
|
||||
traceID = &OpenAIRequestTraceID{
|
||||
ID: uuid.New().String(),
|
||||
Created: int(time.Now().Unix()),
|
||||
}
|
||||
} else {
|
||||
traceID = initialTraceID
|
||||
}
|
||||
|
||||
bc, request, err := oais.getConfig(request)
|
||||
if err != nil {
|
||||
log.Error().Msgf("[oais::GenerateTextFromRequest] error getting configuration: %q", err)
|
||||
return
|
||||
}
|
||||
|
||||
if request.ResponseFormat.Type == "json_object" {
|
||||
request.Grammar = grammar.JSONBNF
|
||||
}
|
||||
|
||||
bc.Grammar = request.Grammar
|
||||
|
||||
if request.Stream && len(bc.PromptStrings) > 1 {
|
||||
log.Warn().Msg("potentially cannot handle more than 1 `PromptStrings` when Streaming?")
|
||||
}
|
||||
|
||||
rawFinalResultChannel := make(chan concurrency.ErrorOr[*schema.OpenAIResponse])
|
||||
finalResultChannel = rawFinalResultChannel
|
||||
promptResultsChannels = []<-chan concurrency.ErrorOr[*backend.LLMResponseBundle]{}
|
||||
var rawCompletionsChannel chan concurrency.ErrorOr[*schema.OpenAIResponse]
|
||||
var rawTokenChannel chan concurrency.ErrorOr[*schema.OpenAIResponse]
|
||||
if notifyOnPromptResult {
|
||||
rawCompletionsChannel = make(chan concurrency.ErrorOr[*schema.OpenAIResponse])
|
||||
}
|
||||
if notifyOnToken {
|
||||
rawTokenChannel = make(chan concurrency.ErrorOr[*schema.OpenAIResponse])
|
||||
}
|
||||
|
||||
promptResultsChannelLock := sync.Mutex{}
|
||||
|
||||
endpointConfig := endpointConfigFn(bc, request)
|
||||
|
||||
if len(endpointConfig.TemplatePath) == 0 {
|
||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||
if oais.ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", bc.Model)) {
|
||||
endpointConfig.TemplatePath = bc.Model
|
||||
} else {
|
||||
log.Warn().Msgf("failed to find any template for %+v", request)
|
||||
}
|
||||
}
|
||||
|
||||
setupWG := sync.WaitGroup{}
|
||||
var prompts []string
|
||||
if lPS := len(bc.PromptStrings); lPS > 0 {
|
||||
setupWG.Add(lPS)
|
||||
prompts = bc.PromptStrings
|
||||
} else {
|
||||
setupWG.Add(len(bc.InputStrings))
|
||||
prompts = bc.InputStrings
|
||||
}
|
||||
|
||||
var setupError error = nil
|
||||
|
||||
for pI, p := range prompts {
|
||||
|
||||
go func(promptIndex int, prompt string) {
|
||||
if endpointConfig.TemplatePath != "" {
|
||||
promptTemplateData := model.PromptTemplateData{
|
||||
Input: prompt,
|
||||
}
|
||||
err := mergo.Merge(promptTemplateData, endpointConfig.TemplateData, mergo.WithOverride)
|
||||
if err == nil {
|
||||
templatedInput, err := oais.ml.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, endpointConfig.TemplatePath, promptTemplateData)
|
||||
if err == nil {
|
||||
prompt = templatedInput
|
||||
log.Debug().Msgf("Template found, input modified to: %s", prompt)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug().Msgf("[OAIS GenerateTextFromRequest] Prompt: %q", prompt)
|
||||
promptResultsChannel, completionChannels, tokenChannels, err := oais.llmbs.GenerateText(prompt, request, bc,
|
||||
func(r *backend.LLMResponse) schema.Choice {
|
||||
return endpointConfig.ResultMappingFn(r, promptIndex)
|
||||
}, notifyOnPromptResult, notifyOnToken)
|
||||
if err != nil {
|
||||
log.Error().Msgf("Unable to generate text prompt: %q\nerr: %q", prompt, err)
|
||||
promptResultsChannelLock.Lock()
|
||||
setupError = errors.Join(setupError, err)
|
||||
promptResultsChannelLock.Unlock()
|
||||
setupWG.Done()
|
||||
return
|
||||
}
|
||||
if notifyOnPromptResult {
|
||||
concurrency.SliceOfChannelsRawMergerWithoutMapping(concurrency.SliceOfChannelsTransformer(completionChannels, endpointConfig.CompletionMappingFn), rawCompletionsChannel, true)
|
||||
}
|
||||
if notifyOnToken {
|
||||
concurrency.SliceOfChannelsRawMergerWithoutMapping(concurrency.SliceOfChannelsTransformer(tokenChannels, endpointConfig.TokenMappingFn), rawTokenChannel, true)
|
||||
}
|
||||
promptResultsChannelLock.Lock()
|
||||
promptResultsChannels = append(promptResultsChannels, promptResultsChannel)
|
||||
promptResultsChannelLock.Unlock()
|
||||
setupWG.Done()
|
||||
}(pI, p)
|
||||
|
||||
}
|
||||
setupWG.Wait()
|
||||
|
||||
// If any of the setup goroutines experienced an error, quit early here.
|
||||
if setupError != nil {
|
||||
go func() {
|
||||
log.Error().Msgf("[OAIS GenerateTextFromRequest] caught an error during setup: %q", setupError)
|
||||
rawFinalResultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: setupError}
|
||||
close(rawFinalResultChannel)
|
||||
}()
|
||||
return
|
||||
}
|
||||
|
||||
initialResponse := &schema.OpenAIResponse{
|
||||
ID: traceID.ID,
|
||||
Created: traceID.Created,
|
||||
Model: request.Model,
|
||||
Object: endpointConfig.SchemaObject,
|
||||
Usage: schema.OpenAIUsage{},
|
||||
}
|
||||
|
||||
// utils.SliceOfChannelsRawMerger[[]schema.Choice](promptResultsChannels, rawFinalResultChannel, func(results []schema.Choice) (*schema.OpenAIResponse, error) {
|
||||
concurrency.SliceOfChannelsReducer(
|
||||
promptResultsChannels, rawFinalResultChannel,
|
||||
func(iv concurrency.ErrorOr[*backend.LLMResponseBundle], result concurrency.ErrorOr[*schema.OpenAIResponse]) concurrency.ErrorOr[*schema.OpenAIResponse] {
|
||||
|
||||
if iv.Error != nil {
|
||||
result.Error = iv.Error
|
||||
return result
|
||||
}
|
||||
result.Value.Usage.PromptTokens += iv.Value.Usage.Prompt
|
||||
result.Value.Usage.CompletionTokens += iv.Value.Usage.Completion
|
||||
result.Value.Usage.TotalTokens = result.Value.Usage.PromptTokens + result.Value.Usage.CompletionTokens
|
||||
|
||||
result.Value.Choices = append(result.Value.Choices, iv.Value.Response...)
|
||||
|
||||
return result
|
||||
}, concurrency.ErrorOr[*schema.OpenAIResponse]{Value: initialResponse}, true)
|
||||
|
||||
completionsChannel = rawCompletionsChannel
|
||||
tokenChannel = rawTokenChannel
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: For porting sanity, this is distinct from GenerateTextFromRequest and is _currently_ specific to Chat purposes
|
||||
// this is not a final decision -- just a reality of moving a lot of parts at once
|
||||
// / This has _become_ Chat which wasn't the goal... More cleanup in the future once it's stable?
|
||||
func (oais *OpenAIService) GenerateFromMultipleMessagesChatRequest(request *schema.OpenAIRequest, notifyOnPromptResult bool, notifyOnToken bool, initialTraceID *OpenAIRequestTraceID) (
|
||||
traceID *OpenAIRequestTraceID, finalResultChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse],
|
||||
completionsChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], tokenChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], err error) {
|
||||
|
||||
if initialTraceID == nil {
|
||||
traceID = &OpenAIRequestTraceID{
|
||||
ID: uuid.New().String(),
|
||||
Created: int(time.Now().Unix()),
|
||||
}
|
||||
} else {
|
||||
traceID = initialTraceID
|
||||
}
|
||||
|
||||
bc, request, err := oais.getConfig(request)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Allow the user to set custom actions via config file
|
||||
// to be "embedded" in each model
|
||||
noActionName := "answer"
|
||||
noActionDescription := "use this action to answer without performing any action"
|
||||
|
||||
if bc.FunctionsConfig.NoActionFunctionName != "" {
|
||||
noActionName = bc.FunctionsConfig.NoActionFunctionName
|
||||
}
|
||||
if bc.FunctionsConfig.NoActionDescriptionName != "" {
|
||||
noActionDescription = bc.FunctionsConfig.NoActionDescriptionName
|
||||
}
|
||||
|
||||
if request.ResponseFormat.Type == "json_object" {
|
||||
request.Grammar = grammar.JSONBNF
|
||||
}
|
||||
|
||||
bc.Grammar = request.Grammar
|
||||
|
||||
processFunctions := false
|
||||
funcs := grammar.Functions{}
|
||||
// process functions if we have any defined or if we have a function call string
|
||||
if len(request.Functions) > 0 && bc.ShouldUseFunctions() {
|
||||
log.Debug().Msgf("Response needs to process functions")
|
||||
|
||||
processFunctions = true
|
||||
|
||||
noActionGrammar := grammar.Function{
|
||||
Name: noActionName,
|
||||
Description: noActionDescription,
|
||||
Parameters: map[string]interface{}{
|
||||
"properties": map[string]interface{}{
|
||||
"message": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "The message to reply the user with",
|
||||
}},
|
||||
},
|
||||
}
|
||||
|
||||
// Append the no action function
|
||||
funcs = append(funcs, request.Functions...)
|
||||
if !bc.FunctionsConfig.DisableNoAction {
|
||||
funcs = append(funcs, noActionGrammar)
|
||||
}
|
||||
|
||||
// Force picking one of the functions by the request
|
||||
if bc.FunctionToCall() != "" {
|
||||
funcs = funcs.Select(bc.FunctionToCall())
|
||||
}
|
||||
|
||||
// Update input grammar
|
||||
jsStruct := funcs.ToJSONStructure()
|
||||
bc.Grammar = jsStruct.Grammar("", bc.FunctionsConfig.ParallelCalls)
|
||||
} else if request.JSONFunctionGrammarObject != nil {
|
||||
bc.Grammar = request.JSONFunctionGrammarObject.Grammar("", bc.FunctionsConfig.ParallelCalls)
|
||||
}
|
||||
|
||||
if request.Stream && processFunctions {
|
||||
log.Warn().Msg("Streaming + Functions is highly experimental in this version")
|
||||
}
|
||||
|
||||
var predInput string
|
||||
|
||||
if !bc.TemplateConfig.UseTokenizerTemplate || processFunctions {
|
||||
|
||||
suppressConfigSystemPrompt := false
|
||||
mess := []string{}
|
||||
for messageIndex, i := range request.Messages {
|
||||
var content string
|
||||
role := i.Role
|
||||
|
||||
// if function call, we might want to customize the role so we can display better that the "assistant called a json action"
|
||||
// if an "assistant_function_call" role is defined, we use it, otherwise we use the role that is passed by in the request
|
||||
if (i.FunctionCall != nil || i.ToolCalls != nil) && i.Role == "assistant" {
|
||||
roleFn := "assistant_function_call"
|
||||
r := bc.Roles[roleFn]
|
||||
if r != "" {
|
||||
role = roleFn
|
||||
}
|
||||
}
|
||||
r := bc.Roles[role]
|
||||
contentExists := i.Content != nil && i.StringContent != ""
|
||||
|
||||
fcall := i.FunctionCall
|
||||
if len(i.ToolCalls) > 0 {
|
||||
fcall = i.ToolCalls
|
||||
}
|
||||
|
||||
// First attempt to populate content via a chat message specific template
|
||||
if bc.TemplateConfig.ChatMessage != "" {
|
||||
chatMessageData := model.ChatMessageTemplateData{
|
||||
SystemPrompt: bc.SystemPrompt,
|
||||
Role: r,
|
||||
RoleName: role,
|
||||
Content: i.StringContent,
|
||||
FunctionCall: fcall,
|
||||
FunctionName: i.Name,
|
||||
LastMessage: messageIndex == (len(request.Messages) - 1),
|
||||
Function: bc.Grammar != "" && (messageIndex == (len(request.Messages) - 1)),
|
||||
MessageIndex: messageIndex,
|
||||
}
|
||||
templatedChatMessage, err := oais.ml.EvaluateTemplateForChatMessage(bc.TemplateConfig.ChatMessage, chatMessageData)
|
||||
if err != nil {
|
||||
log.Error().Msgf("error processing message %+v using template \"%s\": %v. Skipping!", chatMessageData, bc.TemplateConfig.ChatMessage, err)
|
||||
} else {
|
||||
if templatedChatMessage == "" {
|
||||
log.Warn().Msgf("template \"%s\" produced blank output for %+v. Skipping!", bc.TemplateConfig.ChatMessage, chatMessageData)
|
||||
continue // TODO: This continue is here intentionally to skip over the line `mess = append(mess, content)` below, and to prevent the sprintf
|
||||
}
|
||||
log.Debug().Msgf("templated message for chat: %s", templatedChatMessage)
|
||||
content = templatedChatMessage
|
||||
}
|
||||
}
|
||||
marshalAnyRole := func(f any) {
|
||||
j, err := json.Marshal(f)
|
||||
if err == nil {
|
||||
if contentExists {
|
||||
content += "\n" + fmt.Sprint(r, " ", string(j))
|
||||
} else {
|
||||
content = fmt.Sprint(r, " ", string(j))
|
||||
}
|
||||
}
|
||||
}
|
||||
marshalAny := func(f any) {
|
||||
j, err := json.Marshal(f)
|
||||
if err == nil {
|
||||
if contentExists {
|
||||
content += "\n" + string(j)
|
||||
} else {
|
||||
content = string(j)
|
||||
}
|
||||
}
|
||||
}
|
||||
// If this model doesn't have such a template, or if that template fails to return a value, template at the message level.
|
||||
if content == "" {
|
||||
if r != "" {
|
||||
if contentExists {
|
||||
content = fmt.Sprint(r, i.StringContent)
|
||||
}
|
||||
|
||||
if i.FunctionCall != nil {
|
||||
marshalAnyRole(i.FunctionCall)
|
||||
}
|
||||
} else {
|
||||
if contentExists {
|
||||
content = fmt.Sprint(i.StringContent)
|
||||
}
|
||||
|
||||
if i.FunctionCall != nil {
|
||||
marshalAny(i.FunctionCall)
|
||||
}
|
||||
|
||||
if i.ToolCalls != nil {
|
||||
marshalAny(i.ToolCalls)
|
||||
}
|
||||
}
|
||||
// Special Handling: System. We care if it was printed at all, not the r branch, so check seperately
|
||||
if contentExists && role == "system" {
|
||||
suppressConfigSystemPrompt = true
|
||||
}
|
||||
}
|
||||
|
||||
mess = append(mess, content)
|
||||
}
|
||||
|
||||
predInput = strings.Join(mess, "\n")
|
||||
|
||||
log.Debug().Msgf("Prompt (before templating): %s", predInput)
|
||||
|
||||
templateFile := ""
|
||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||
if oais.ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", bc.Model)) {
|
||||
templateFile = bc.Model
|
||||
}
|
||||
|
||||
if bc.TemplateConfig.Chat != "" && !processFunctions {
|
||||
templateFile = bc.TemplateConfig.Chat
|
||||
}
|
||||
|
||||
if bc.TemplateConfig.Functions != "" && processFunctions {
|
||||
templateFile = bc.TemplateConfig.Functions
|
||||
}
|
||||
|
||||
if templateFile != "" {
|
||||
templatedInput, err := oais.ml.EvaluateTemplateForPrompt(model.ChatPromptTemplate, templateFile, model.PromptTemplateData{
|
||||
SystemPrompt: bc.SystemPrompt,
|
||||
SuppressSystemPrompt: suppressConfigSystemPrompt,
|
||||
Input: predInput,
|
||||
Functions: funcs,
|
||||
})
|
||||
if err == nil {
|
||||
predInput = templatedInput
|
||||
log.Debug().Msgf("Template found, input modified to: %s", predInput)
|
||||
} else {
|
||||
log.Debug().Msgf("Template failed loading: %s", err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
log.Debug().Msgf("Prompt (after templating): %s", predInput)
|
||||
if processFunctions {
|
||||
log.Debug().Msgf("Grammar: %+v", bc.Grammar)
|
||||
}
|
||||
|
||||
rawFinalResultChannel := make(chan concurrency.ErrorOr[*schema.OpenAIResponse])
|
||||
var rawCompletionsChannel chan concurrency.ErrorOr[*schema.OpenAIResponse]
|
||||
var rawTokenChannel chan concurrency.ErrorOr[*schema.OpenAIResponse]
|
||||
if notifyOnPromptResult {
|
||||
rawCompletionsChannel = make(chan concurrency.ErrorOr[*schema.OpenAIResponse])
|
||||
}
|
||||
if notifyOnToken {
|
||||
rawTokenChannel = make(chan concurrency.ErrorOr[*schema.OpenAIResponse])
|
||||
}
|
||||
|
||||
rawResultChannel, individualCompletionChannels, tokenChannels, err := oais.llmbs.GenerateText(predInput, request, bc, func(resp *backend.LLMResponse) schema.Choice {
|
||||
return schema.Choice{
|
||||
Index: 0, // ???
|
||||
FinishReason: "stop",
|
||||
Message: &schema.Message{
|
||||
Role: "assistant",
|
||||
Content: resp.Response,
|
||||
},
|
||||
}
|
||||
}, notifyOnPromptResult, notifyOnToken)
|
||||
|
||||
chatSimpleMappingFn := func(resp concurrency.ErrorOr[*backend.LLMResponse]) concurrency.ErrorOr[*schema.OpenAIResponse] {
|
||||
if resp.Error != nil || resp.Value == nil {
|
||||
return concurrency.ErrorOr[*schema.OpenAIResponse]{Error: resp.Error}
|
||||
}
|
||||
return concurrency.ErrorOr[*schema.OpenAIResponse]{
|
||||
Value: &schema.OpenAIResponse{
|
||||
ID: traceID.ID,
|
||||
Created: traceID.Created,
|
||||
Model: request.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []schema.Choice{
|
||||
{
|
||||
Delta: &schema.Message{
|
||||
Role: "assistant",
|
||||
Content: resp.Value.Response,
|
||||
},
|
||||
Index: 0,
|
||||
},
|
||||
},
|
||||
Object: "chat.completion.chunk",
|
||||
Usage: schema.OpenAIUsage{
|
||||
PromptTokens: resp.Value.Usage.Prompt,
|
||||
CompletionTokens: resp.Value.Usage.Completion,
|
||||
TotalTokens: resp.Value.Usage.Prompt + resp.Value.Usage.Completion,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
if notifyOnPromptResult {
|
||||
concurrency.SliceOfChannelsRawMergerWithoutMapping(concurrency.SliceOfChannelsTransformer(individualCompletionChannels, chatSimpleMappingFn), rawCompletionsChannel, true)
|
||||
}
|
||||
if notifyOnToken {
|
||||
concurrency.SliceOfChannelsRawMergerWithoutMapping(concurrency.SliceOfChannelsTransformer(tokenChannels, chatSimpleMappingFn), rawTokenChannel, true)
|
||||
}
|
||||
|
||||
go func() {
|
||||
rawResult := <-rawResultChannel
|
||||
if rawResult.Error != nil {
|
||||
log.Warn().Msgf("OpenAIService::processTools GenerateText error [DEBUG THIS?] %q", rawResult.Error)
|
||||
return
|
||||
}
|
||||
llmResponseChoices := rawResult.Value.Response
|
||||
|
||||
if processFunctions && len(llmResponseChoices) > 1 {
|
||||
log.Warn().Msgf("chat functions response with %d choices in response, debug this?", len(llmResponseChoices))
|
||||
log.Debug().Msgf("%+v", llmResponseChoices)
|
||||
}
|
||||
|
||||
for _, result := range rawResult.Value.Response {
|
||||
// If no functions, just return the raw result.
|
||||
if !processFunctions {
|
||||
|
||||
resp := schema.OpenAIResponse{
|
||||
ID: traceID.ID,
|
||||
Created: traceID.Created,
|
||||
Model: request.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []schema.Choice{result},
|
||||
Object: "chat.completion.chunk",
|
||||
Usage: schema.OpenAIUsage{
|
||||
PromptTokens: rawResult.Value.Usage.Prompt,
|
||||
CompletionTokens: rawResult.Value.Usage.Completion,
|
||||
TotalTokens: rawResult.Value.Usage.Prompt + rawResult.Value.Usage.Prompt,
|
||||
},
|
||||
}
|
||||
|
||||
rawFinalResultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Value: &resp}
|
||||
|
||||
continue
|
||||
}
|
||||
// At this point, things are function specific!
|
||||
|
||||
// Oh no this can't be the right way to do this... but it works. Save us, mudler!
|
||||
fString := fmt.Sprintf("%s", result.Message.Content)
|
||||
results := parseFunctionCall(fString, bc.FunctionsConfig.ParallelCalls)
|
||||
noActionToRun := (len(results) > 0 && results[0].name == noActionName)
|
||||
|
||||
if noActionToRun {
|
||||
log.Debug().Msg("-- noActionToRun branch --")
|
||||
initialMessage := schema.OpenAIResponse{
|
||||
ID: traceID.ID,
|
||||
Created: traceID.Created,
|
||||
Model: request.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant", Content: ""}}},
|
||||
Object: "stop",
|
||||
}
|
||||
rawFinalResultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Value: &initialMessage}
|
||||
|
||||
result, err := oais.handleQuestion(bc, request, results[0].arguments, predInput)
|
||||
if err != nil {
|
||||
log.Error().Msgf("error handling question: %s", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
resp := schema.OpenAIResponse{
|
||||
ID: traceID.ID,
|
||||
Created: traceID.Created,
|
||||
Model: request.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []schema.Choice{{Delta: &schema.Message{Content: &result}, Index: 0}},
|
||||
Object: "chat.completion.chunk",
|
||||
Usage: schema.OpenAIUsage{
|
||||
PromptTokens: rawResult.Value.Usage.Prompt,
|
||||
CompletionTokens: rawResult.Value.Usage.Completion,
|
||||
TotalTokens: rawResult.Value.Usage.Prompt + rawResult.Value.Usage.Prompt,
|
||||
},
|
||||
}
|
||||
|
||||
rawFinalResultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Value: &resp}
|
||||
|
||||
} else {
|
||||
log.Debug().Msgf("[GenerateFromMultipleMessagesChatRequest] fnResultsBranch: %+v", results)
|
||||
for i, ss := range results {
|
||||
name, args := ss.name, ss.arguments
|
||||
|
||||
initialMessage := schema.OpenAIResponse{
|
||||
ID: traceID.ID,
|
||||
Created: traceID.Created,
|
||||
Model: request.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []schema.Choice{{
|
||||
FinishReason: "function_call",
|
||||
Message: &schema.Message{
|
||||
Role: "assistant",
|
||||
ToolCalls: []schema.ToolCall{
|
||||
{
|
||||
Index: i,
|
||||
ID: traceID.ID,
|
||||
Type: "function",
|
||||
FunctionCall: schema.FunctionCall{
|
||||
Name: name,
|
||||
Arguments: args,
|
||||
},
|
||||
},
|
||||
},
|
||||
}}},
|
||||
Object: "chat.completion.chunk",
|
||||
}
|
||||
rawFinalResultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Value: &initialMessage}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
close(rawFinalResultChannel)
|
||||
}()
|
||||
|
||||
finalResultChannel = rawFinalResultChannel
|
||||
completionsChannel = rawCompletionsChannel
|
||||
tokenChannel = rawTokenChannel
|
||||
return
|
||||
}
|
||||
|
||||
func (oais *OpenAIService) handleQuestion(config *config.BackendConfig, input *schema.OpenAIRequest, args, prompt string) (string, error) {
|
||||
log.Debug().Msgf("[handleQuestion called] nothing to do, computing a reply")
|
||||
|
||||
// If there is a message that the LLM already sends as part of the JSON reply, use it
|
||||
arguments := map[string]interface{}{}
|
||||
json.Unmarshal([]byte(args), &arguments)
|
||||
m, exists := arguments["message"]
|
||||
if exists {
|
||||
switch message := m.(type) {
|
||||
case string:
|
||||
if message != "" {
|
||||
log.Debug().Msgf("Reply received from LLM: %s", message)
|
||||
message = oais.llmbs.Finetune(*config, prompt, message)
|
||||
log.Debug().Msgf("Reply received from LLM(finetuned): %s", message)
|
||||
|
||||
return message, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug().Msgf("No action received from LLM, without a message, computing a reply")
|
||||
// Otherwise ask the LLM to understand the JSON output and the context, and return a message
|
||||
// Note: This costs (in term of CPU/GPU) another computation
|
||||
config.Grammar = ""
|
||||
images := []string{}
|
||||
for _, m := range input.Messages {
|
||||
images = append(images, m.StringImages...)
|
||||
}
|
||||
|
||||
resultChannel, _, err := oais.llmbs.Inference(input.Context, &backend.LLMRequest{
|
||||
Text: prompt,
|
||||
Images: images,
|
||||
RawMessages: input.Messages, // Experimental
|
||||
}, config, false)
|
||||
|
||||
if err != nil {
|
||||
log.Error().Msgf("inference setup error: %s", err.Error())
|
||||
return "", err
|
||||
}
|
||||
|
||||
raw := <-resultChannel
|
||||
if raw.Error != nil {
|
||||
log.Error().Msgf("inference error: %q", raw.Error.Error())
|
||||
return "", err
|
||||
}
|
||||
if raw.Value == nil {
|
||||
log.Warn().Msgf("nil inference response")
|
||||
return "", nil
|
||||
}
|
||||
return oais.llmbs.Finetune(*config, prompt, raw.Value.Response), nil
|
||||
}
|
||||
|
||||
// funcCallResults holds one function/tool call parsed out of an LLM grammar
// result (see parseFunctionCall).
type funcCallResults struct {
	name      string // name of the function the model chose to call
	arguments string // call arguments re-marshaled as a JSON string
}
|
||||
|
||||
func parseFunctionCall(llmresult string, multipleResults bool) []funcCallResults {
|
||||
|
||||
results := []funcCallResults{}
|
||||
|
||||
// TODO: use generics to avoid this code duplication
|
||||
if multipleResults {
|
||||
ss := []map[string]interface{}{}
|
||||
s := utils.EscapeNewLines(llmresult)
|
||||
json.Unmarshal([]byte(s), &ss)
|
||||
|
||||
for _, s := range ss {
|
||||
func_name, ok := s["function"]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
args, ok := s["arguments"]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
d, _ := json.Marshal(args)
|
||||
funcName, ok := func_name.(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
results = append(results, funcCallResults{name: funcName, arguments: string(d)})
|
||||
}
|
||||
} else {
|
||||
// As we have to change the result before processing, we can't stream the answer token-by-token (yet?)
|
||||
ss := map[string]interface{}{}
|
||||
// This prevent newlines to break JSON parsing for clients
|
||||
// s := utils.EscapeNewLines(llmresult)
|
||||
json.Unmarshal([]byte(llmresult), &ss)
|
||||
|
||||
// The grammar defines the function name as "function", while OpenAI returns "name"
|
||||
func_name, ok := ss["function"]
|
||||
if !ok {
|
||||
log.Debug().Msg("ss[function] is not OK!")
|
||||
return results
|
||||
}
|
||||
// Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object
|
||||
args, ok := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix)
|
||||
if !ok {
|
||||
log.Debug().Msg("ss[arguments] is not OK!")
|
||||
return results
|
||||
}
|
||||
d, _ := json.Marshal(args)
|
||||
funcName, ok := func_name.(string)
|
||||
if !ok {
|
||||
log.Debug().Msgf("unexpected func_name: %+v", func_name)
|
||||
return results
|
||||
}
|
||||
results = append(results, funcCallResults{name: funcName, arguments: string(d)})
|
||||
}
|
||||
return results
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue