Update llama-go, allow to set context-size and enable alpaca model by default

This commit is contained in:
mudler 2023-03-21 19:20:23 +01:00
parent 973042bb4c
commit 9ba30c9c44
4 changed files with 161 additions and 45 deletions

126
main.go
View file

@ -31,6 +31,15 @@ var nonEmptyInput string = `Below is an instruction that describes a task, paire
### Response:
`
func llamaFromOptions(ctx *cli.Context) (*llama.LLama, error) {
opts := []llama.ModelOption{llama.SetContext(ctx.Int("context-size"))}
if ctx.Bool("alpaca") {
opts = append(opts, llama.EnableAlpaca)
}
return llama.New(ctx.String("model"), opts...)
}
func templateString(t string, in interface{}) (string, error) {
// Parse the template
tmpl, err := template.New("prompt").Parse(t)
@ -46,12 +55,54 @@ func templateString(t string, in interface{}) (string, error) {
return buf.String(), nil
}
var modelFlags = []cli.Flag{
&cli.StringFlag{
Name: "model",
EnvVars: []string{"MODEL_PATH"},
},
&cli.IntFlag{
Name: "tokens",
EnvVars: []string{"TOKENS"},
Value: 128,
},
&cli.IntFlag{
Name: "context-size",
EnvVars: []string{"CONTEXT_SIZE"},
Value: 512,
},
&cli.IntFlag{
Name: "threads",
EnvVars: []string{"THREADS"},
Value: runtime.NumCPU(),
},
&cli.Float64Flag{
Name: "temperature",
EnvVars: []string{"TEMPERATURE"},
Value: 0.95,
},
&cli.Float64Flag{
Name: "topp",
EnvVars: []string{"TOP_P"},
Value: 0.85,
},
&cli.IntFlag{
Name: "topk",
EnvVars: []string{"TOP_K"},
Value: 20,
},
&cli.BoolFlag{
Name: "alpaca",
EnvVars: []string{"ALPACA"},
Value: true,
},
}
func main() {
app := &cli.App{
Name: "llama-cli",
Version: "0.1",
Usage: "llama-cli --model ... --instruction 'What is an alpaca?'",
Flags: []cli.Flag{
Flags: append(modelFlags,
&cli.StringFlag{
Name: "template",
EnvVars: []string{"TEMPLATE"},
@ -63,37 +114,7 @@ func main() {
&cli.StringFlag{
Name: "input",
EnvVars: []string{"INPUT"},
},
&cli.StringFlag{
Name: "model",
EnvVars: []string{"MODEL_PATH"},
},
&cli.IntFlag{
Name: "tokens",
EnvVars: []string{"TOKENS"},
Value: 128,
},
&cli.IntFlag{
Name: "threads",
EnvVars: []string{"THREADS"},
Value: runtime.NumCPU(),
},
&cli.Float64Flag{
Name: "temperature",
EnvVars: []string{"TEMPERATURE"},
Value: 0.95,
},
&cli.Float64Flag{
Name: "topp",
EnvVars: []string{"TOP_P"},
Value: 0.85,
},
&cli.IntFlag{
Name: "topk",
EnvVars: []string{"TOP_K"},
Value: 20,
},
},
}),
Description: `Run llama.cpp inference`,
UsageText: `
llama-cli --model ~/ggml-alpaca-7b-q4.bin --instruction "What's an alpaca?"
@ -107,6 +128,25 @@ echo "An Alpaca (Vicugna pacos) is a domesticated species of South American came
Copyright: "go-skynet authors",
Commands: []*cli.Command{
{
Flags: modelFlags,
Name: "interactive",
Action: func(ctx *cli.Context) error {
l, err := llamaFromOptions(ctx)
if err != nil {
fmt.Println("Loading the model failed:", err.Error())
os.Exit(1)
}
return startInteractive(l, llama.SetTemperature(ctx.Float64("temperature")),
llama.SetTopP(ctx.Float64("topp")),
llama.SetTopK(ctx.Int("topk")),
llama.SetTokens(ctx.Int("tokens")),
llama.SetThreads(ctx.Int("threads")))
},
},
{
Name: "api",
Flags: []cli.Flag{
&cli.IntFlag{
@ -123,9 +163,25 @@ echo "An Alpaca (Vicugna pacos) is a domesticated species of South American came
EnvVars: []string{"ADDRESS"},
Value: ":8080",
},
&cli.BoolFlag{
Name: "alpaca",
EnvVars: []string{"ALPACA"},
Value: true,
},
&cli.IntFlag{
Name: "context-size",
EnvVars: []string{"CONTEXT_SIZE"},
Value: 512,
},
},
Action: func(ctx *cli.Context) error {
return api(ctx.String("model"), ctx.String("address"), ctx.Int("threads"))
l, err := llamaFromOptions(ctx)
if err != nil {
fmt.Println("Loading the model failed:", err.Error())
os.Exit(1)
}
return api(l, ctx.String("address"), ctx.Int("threads"))
},
},
},
@ -179,11 +235,13 @@ echo "An Alpaca (Vicugna pacos) is a domesticated species of South American came
fmt.Println("Templating the input failed:", err.Error())
os.Exit(1)
}
l, err := llama.New(ctx.String("model"))
l, err := llamaFromOptions(ctx)
if err != nil {
fmt.Println("Loading the model failed:", err.Error())
os.Exit(1)
}
res, err := l.Predict(
str,
llama.SetTemperature(ctx.Float64("temperature")),