Mirror of https://github.com/mudler/LocalAI.git (synced 2025-05-22 03:24:59 +00:00)
feat: add stream events (#152)
parent 0a00a4b58e
commit 220d6fd59b
3 changed files with 67 additions and 46 deletions
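In short: ModelInference and ComputeChoices gain a `tokenCallback func(string) bool` parameter. The llama backend registers it with `model.SetTokenCallback` so tokens are delivered as they are generated; backends without streaming support fall back to invoking the callback once with the finished result. Only one of the three changed files is shown below.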
```diff
@@ -16,12 +16,12 @@ import (
 var mutexMap sync.Mutex
 var mutexes map[string]*sync.Mutex = make(map[string]*sync.Mutex)
 
-func ModelInference(s string, loader *model.ModelLoader, c Config) (func() (string, error), error) {
+func ModelInference(s string, loader *model.ModelLoader, c Config, tokenCallback func(string) bool) (func() (string, error), error) {
 	var model *llama.LLama
 	var gptModel *gptj.GPTJ
 	var gpt2Model *gpt2.GPT2
 	var stableLMModel *gpt2.StableLM
-
+	supportStreams := false
 	modelFile := c.Model
 
 	// Try to load the model
```
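To illustrate the new parameter: a minimal caller might look like the sketch below. Only the ModelInference signature and the `func(string) bool` callback shape come from this diff; the helper name and setup are assumptions, as is the convention that returning true lets generation continue.

```go
// Hypothetical helper (name and wiring are mine, not from this commit);
// assumes it lives alongside ModelInference and imports "fmt".
func streamToStdout(prompt string, loader *model.ModelLoader, cfg Config) (string, error) {
	predict, err := ModelInference(prompt, loader, cfg, func(token string) bool {
		fmt.Print(token) // handle each token as it arrives
		return true      // assumed convention: true keeps generation going
	})
	if err != nil {
		return "", err
	}
	return predict() // blocks until generation completes; returns the full text
}
```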
```diff
@@ -125,7 +125,13 @@ func ModelInference(s string, loader *model.ModelLoader, c Config) (func() (string, error), error) {
 			)
 		}
 	case model != nil:
+		supportStreams = true
 		fn = func() (string, error) {
+
+			if tokenCallback != nil {
+				model.SetTokenCallback(tokenCallback)
+			}
+
 			// Generate the prediction using the language model
 			predictOptions := []llama.PredictOption{
 				llama.SetTemperature(c.Temperature),
```
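Only the llama backend streams here: `supportStreams` is flipped to true and the callback is wired in through `model.SetTokenCallback` before the prediction runs. Because the callback returns a bool, it can plausibly also cut generation short; a sketch, assuming (as in go-llama.cpp) that returning false halts prediction:

```go
// Hypothetical early-stop callback factory; assumes (as in go-llama.cpp)
// that returning false halts prediction. Requires "strings".
func stopAt(stopWord string) func(string) bool {
	var seen strings.Builder
	return func(token string) bool {
		seen.WriteString(token)
		// keep generating only while the stop word has not appeared
		return !strings.Contains(seen.String(), stopWord)
	}
}
```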
```diff
@@ -185,11 +191,15 @@ func ModelInference(s string, loader *model.ModelLoader, c Config) (func() (string, error), error) {
 		l.Lock()
 		defer l.Unlock()
 
-		return fn()
+		res, err := fn()
+		if tokenCallback != nil && !supportStreams {
+			tokenCallback(res)
+		}
+		return res, err
 	}, nil
 }
 
-func ComputeChoices(predInput string, input *OpenAIRequest, config *Config, loader *model.ModelLoader, cb func(string, *[]Choice)) ([]Choice, error) {
+func ComputeChoices(predInput string, input *OpenAIRequest, config *Config, loader *model.ModelLoader, cb func(string, *[]Choice), tokenCallback func(string) bool) ([]Choice, error) {
 	result := []Choice{}
 
 	n := input.N
```
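Note the design choice in the fallback above: backends that cannot stream still invoke the callback, but only once and only with the complete result, after fn() returns. Callers therefore get one uniform contract: the callback receives chunks of output, whether those chunks are single tokens or the whole completion. The `!supportStreams` guard prevents llama output from being delivered twice.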
```diff
@@ -199,7 +209,7 @@ func ComputeChoices(predInput string, input *OpenAIRequest, config *Config, loader *model.ModelLoader, cb func(string, *[]Choice)) ([]Choice, error) {
 	}
 
 	// get the model function to call for the result
-	predFunc, err := ModelInference(predInput, loader, *config)
+	predFunc, err := ModelInference(predInput, loader, *config, tokenCallback)
 	if err != nil {
 		return result, err
 	}
```
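The commit title promises stream events; the HTTP/SSE plumbing lives in the other changed files, which are not shown here. Still, a plausible bridge from tokenCallback to an event stream is a channel drained by the response writer. A sketch under that assumption (everything except ComputeChoices and its parameters is assumed, not taken from this commit):

```go
// Illustrative bridge from tokenCallback to a channel of stream events;
// an SSE handler can range over the returned channel and write one event
// per token. Requires "log".
func streamChoices(predInput string, input *OpenAIRequest, config *Config,
	loader *model.ModelLoader) <-chan string {
	tokens := make(chan string)
	go func() {
		defer close(tokens)
		_, err := ComputeChoices(predInput, input, config, loader,
			func(s string, c *[]Choice) {}, // per-choice hook, unused here
			func(token string) bool {
				tokens <- token // emit each token as a stream event
				return true
			})
		if err != nil {
			log.Println(err)
		}
	}()
	return tokens
}
```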