mirror of
https://github.com/mudler/LocalAI.git
synced 2025-05-22 03:24:59 +00:00
feat: add starcoder (#236)
This commit is contained in:
parent
f359e1c6c4
commit
4413defca5
5 changed files with 56 additions and 2 deletions
|
@ -199,6 +199,30 @@ func ModelInference(s string, loader *model.ModelLoader, c Config, tokenCallback
|
|||
|
||||
return response, nil
|
||||
}
|
||||
case *gpt2.Starcoder:
|
||||
fn = func() (string, error) {
|
||||
// Generate the prediction using the language model
|
||||
predictOptions := []gpt2.PredictOption{
|
||||
gpt2.SetTemperature(c.Temperature),
|
||||
gpt2.SetTopP(c.TopP),
|
||||
gpt2.SetTopK(c.TopK),
|
||||
gpt2.SetTokens(c.Maxtokens),
|
||||
gpt2.SetThreads(c.Threads),
|
||||
}
|
||||
|
||||
if c.Batch != 0 {
|
||||
predictOptions = append(predictOptions, gpt2.SetBatch(c.Batch))
|
||||
}
|
||||
|
||||
if c.Seed != 0 {
|
||||
predictOptions = append(predictOptions, gpt2.SetSeed(c.Seed))
|
||||
}
|
||||
|
||||
return model.Predict(
|
||||
s,
|
||||
predictOptions...,
|
||||
)
|
||||
}
|
||||
case *gpt2.RedPajama:
|
||||
fn = func() (string, error) {
|
||||
// Generate the prediction using the language model
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue