mirror of
https://github.com/mudler/LocalAI.git
synced 2025-05-25 21:15:00 +00:00
feat: add dolly/redpajama/bloomz models support (#214)
This commit is contained in:
parent
f02202e1e1
commit
11675932ac
7 changed files with 235 additions and 13 deletions
|
@ -8,6 +8,7 @@ import (
|
|||
|
||||
"github.com/donomii/go-rwkv.cpp"
|
||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/go-skynet/bloomz.cpp"
|
||||
bert "github.com/go-skynet/go-bert.cpp"
|
||||
gpt2 "github.com/go-skynet/go-gpt2.cpp"
|
||||
gptj "github.com/go-skynet/go-gpt4all-j.cpp"
|
||||
|
@ -198,6 +199,50 @@ func ModelInference(s string, loader *model.ModelLoader, c Config, tokenCallback
|
|||
|
||||
return response, nil
|
||||
}
|
||||
case *gpt2.RedPajama:
|
||||
fn = func() (string, error) {
|
||||
// Generate the prediction using the language model
|
||||
predictOptions := []gpt2.PredictOption{
|
||||
gpt2.SetTemperature(c.Temperature),
|
||||
gpt2.SetTopP(c.TopP),
|
||||
gpt2.SetTopK(c.TopK),
|
||||
gpt2.SetTokens(c.Maxtokens),
|
||||
gpt2.SetThreads(c.Threads),
|
||||
}
|
||||
|
||||
if c.Batch != 0 {
|
||||
predictOptions = append(predictOptions, gpt2.SetBatch(c.Batch))
|
||||
}
|
||||
|
||||
if c.Seed != 0 {
|
||||
predictOptions = append(predictOptions, gpt2.SetSeed(c.Seed))
|
||||
}
|
||||
|
||||
return model.Predict(
|
||||
s,
|
||||
predictOptions...,
|
||||
)
|
||||
}
|
||||
case *bloomz.Bloomz:
|
||||
fn = func() (string, error) {
|
||||
// Generate the prediction using the language model
|
||||
predictOptions := []bloomz.PredictOption{
|
||||
bloomz.SetTemperature(c.Temperature),
|
||||
bloomz.SetTopP(c.TopP),
|
||||
bloomz.SetTopK(c.TopK),
|
||||
bloomz.SetTokens(c.Maxtokens),
|
||||
bloomz.SetThreads(c.Threads),
|
||||
}
|
||||
|
||||
if c.Seed != 0 {
|
||||
predictOptions = append(predictOptions, bloomz.SetSeed(c.Seed))
|
||||
}
|
||||
|
||||
return model.Predict(
|
||||
s,
|
||||
predictOptions...,
|
||||
)
|
||||
}
|
||||
case *gpt2.StableLM:
|
||||
fn = func() (string, error) {
|
||||
// Generate the prediction using the language model
|
||||
|
@ -222,6 +267,30 @@ func ModelInference(s string, loader *model.ModelLoader, c Config, tokenCallback
|
|||
predictOptions...,
|
||||
)
|
||||
}
|
||||
case *gpt2.Dolly:
|
||||
fn = func() (string, error) {
|
||||
// Generate the prediction using the language model
|
||||
predictOptions := []gpt2.PredictOption{
|
||||
gpt2.SetTemperature(c.Temperature),
|
||||
gpt2.SetTopP(c.TopP),
|
||||
gpt2.SetTopK(c.TopK),
|
||||
gpt2.SetTokens(c.Maxtokens),
|
||||
gpt2.SetThreads(c.Threads),
|
||||
}
|
||||
|
||||
if c.Batch != 0 {
|
||||
predictOptions = append(predictOptions, gpt2.SetBatch(c.Batch))
|
||||
}
|
||||
|
||||
if c.Seed != 0 {
|
||||
predictOptions = append(predictOptions, gpt2.SetSeed(c.Seed))
|
||||
}
|
||||
|
||||
return model.Predict(
|
||||
s,
|
||||
predictOptions...,
|
||||
)
|
||||
}
|
||||
case *gpt2.GPT2:
|
||||
fn = func() (string, error) {
|
||||
// Generate the prediction using the language model
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue