feat: add starcoder (#236)

This commit is contained in:
Ettore Di Giacinto 2023-05-11 20:20:07 +02:00 committed by GitHub
parent f359e1c6c4
commit 4413defca5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 56 additions and 2 deletions

View file

@ -199,6 +199,30 @@ func ModelInference(s string, loader *model.ModelLoader, c Config, tokenCallback
return response, nil
}
case *gpt2.Starcoder:
fn = func() (string, error) {
// Generate the prediction using the language model
predictOptions := []gpt2.PredictOption{
gpt2.SetTemperature(c.Temperature),
gpt2.SetTopP(c.TopP),
gpt2.SetTopK(c.TopK),
gpt2.SetTokens(c.Maxtokens),
gpt2.SetThreads(c.Threads),
}
if c.Batch != 0 {
predictOptions = append(predictOptions, gpt2.SetBatch(c.Batch))
}
if c.Seed != 0 {
predictOptions = append(predictOptions, gpt2.SetSeed(c.Seed))
}
return model.Predict(
s,
predictOptions...,
)
}
case *gpt2.RedPajama:
fn = func() (string, error) {
// Generate the prediction using the language model