mirror of
https://github.com/mudler/LocalAI.git
synced 2025-05-20 02:24:59 +00:00
Update llama-go, allow to set context-size and enable alpaca model by default
This commit is contained in:
parent
973042bb4c
commit
9ba30c9c44
4 changed files with 161 additions and 45 deletions
126
main.go
126
main.go
|
@ -31,6 +31,15 @@ var nonEmptyInput string = `Below is an instruction that describes a task, paire
|
|||
### Response:
|
||||
`
|
||||
|
||||
func llamaFromOptions(ctx *cli.Context) (*llama.LLama, error) {
|
||||
opts := []llama.ModelOption{llama.SetContext(ctx.Int("context-size"))}
|
||||
if ctx.Bool("alpaca") {
|
||||
opts = append(opts, llama.EnableAlpaca)
|
||||
}
|
||||
|
||||
return llama.New(ctx.String("model"), opts...)
|
||||
}
|
||||
|
||||
func templateString(t string, in interface{}) (string, error) {
|
||||
// Parse the template
|
||||
tmpl, err := template.New("prompt").Parse(t)
|
||||
|
@ -46,12 +55,54 @@ func templateString(t string, in interface{}) (string, error) {
|
|||
return buf.String(), nil
|
||||
}
|
||||
|
||||
var modelFlags = []cli.Flag{
|
||||
&cli.StringFlag{
|
||||
Name: "model",
|
||||
EnvVars: []string{"MODEL_PATH"},
|
||||
},
|
||||
&cli.IntFlag{
|
||||
Name: "tokens",
|
||||
EnvVars: []string{"TOKENS"},
|
||||
Value: 128,
|
||||
},
|
||||
&cli.IntFlag{
|
||||
Name: "context-size",
|
||||
EnvVars: []string{"CONTEXT_SIZE"},
|
||||
Value: 512,
|
||||
},
|
||||
&cli.IntFlag{
|
||||
Name: "threads",
|
||||
EnvVars: []string{"THREADS"},
|
||||
Value: runtime.NumCPU(),
|
||||
},
|
||||
&cli.Float64Flag{
|
||||
Name: "temperature",
|
||||
EnvVars: []string{"TEMPERATURE"},
|
||||
Value: 0.95,
|
||||
},
|
||||
&cli.Float64Flag{
|
||||
Name: "topp",
|
||||
EnvVars: []string{"TOP_P"},
|
||||
Value: 0.85,
|
||||
},
|
||||
&cli.IntFlag{
|
||||
Name: "topk",
|
||||
EnvVars: []string{"TOP_K"},
|
||||
Value: 20,
|
||||
},
|
||||
&cli.BoolFlag{
|
||||
Name: "alpaca",
|
||||
EnvVars: []string{"ALPACA"},
|
||||
Value: true,
|
||||
},
|
||||
}
|
||||
|
||||
func main() {
|
||||
app := &cli.App{
|
||||
Name: "llama-cli",
|
||||
Version: "0.1",
|
||||
Usage: "llama-cli --model ... --instruction 'What is an alpaca?'",
|
||||
Flags: []cli.Flag{
|
||||
Flags: append(modelFlags,
|
||||
&cli.StringFlag{
|
||||
Name: "template",
|
||||
EnvVars: []string{"TEMPLATE"},
|
||||
|
@ -63,37 +114,7 @@ func main() {
|
|||
&cli.StringFlag{
|
||||
Name: "input",
|
||||
EnvVars: []string{"INPUT"},
|
||||
},
|
||||
&cli.StringFlag{
|
||||
Name: "model",
|
||||
EnvVars: []string{"MODEL_PATH"},
|
||||
},
|
||||
&cli.IntFlag{
|
||||
Name: "tokens",
|
||||
EnvVars: []string{"TOKENS"},
|
||||
Value: 128,
|
||||
},
|
||||
&cli.IntFlag{
|
||||
Name: "threads",
|
||||
EnvVars: []string{"THREADS"},
|
||||
Value: runtime.NumCPU(),
|
||||
},
|
||||
&cli.Float64Flag{
|
||||
Name: "temperature",
|
||||
EnvVars: []string{"TEMPERATURE"},
|
||||
Value: 0.95,
|
||||
},
|
||||
&cli.Float64Flag{
|
||||
Name: "topp",
|
||||
EnvVars: []string{"TOP_P"},
|
||||
Value: 0.85,
|
||||
},
|
||||
&cli.IntFlag{
|
||||
Name: "topk",
|
||||
EnvVars: []string{"TOP_K"},
|
||||
Value: 20,
|
||||
},
|
||||
},
|
||||
}),
|
||||
Description: `Run llama.cpp inference`,
|
||||
UsageText: `
|
||||
llama-cli --model ~/ggml-alpaca-7b-q4.bin --instruction "What's an alpaca?"
|
||||
|
@ -107,6 +128,25 @@ echo "An Alpaca (Vicugna pacos) is a domesticated species of South American came
|
|||
Copyright: "go-skynet authors",
|
||||
Commands: []*cli.Command{
|
||||
{
|
||||
Flags: modelFlags,
|
||||
Name: "interactive",
|
||||
Action: func(ctx *cli.Context) error {
|
||||
|
||||
l, err := llamaFromOptions(ctx)
|
||||
if err != nil {
|
||||
fmt.Println("Loading the model failed:", err.Error())
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
return startInteractive(l, llama.SetTemperature(ctx.Float64("temperature")),
|
||||
llama.SetTopP(ctx.Float64("topp")),
|
||||
llama.SetTopK(ctx.Int("topk")),
|
||||
llama.SetTokens(ctx.Int("tokens")),
|
||||
llama.SetThreads(ctx.Int("threads")))
|
||||
},
|
||||
},
|
||||
{
|
||||
|
||||
Name: "api",
|
||||
Flags: []cli.Flag{
|
||||
&cli.IntFlag{
|
||||
|
@ -123,9 +163,25 @@ echo "An Alpaca (Vicugna pacos) is a domesticated species of South American came
|
|||
EnvVars: []string{"ADDRESS"},
|
||||
Value: ":8080",
|
||||
},
|
||||
&cli.BoolFlag{
|
||||
Name: "alpaca",
|
||||
EnvVars: []string{"ALPACA"},
|
||||
Value: true,
|
||||
},
|
||||
&cli.IntFlag{
|
||||
Name: "context-size",
|
||||
EnvVars: []string{"CONTEXT_SIZE"},
|
||||
Value: 512,
|
||||
},
|
||||
},
|
||||
Action: func(ctx *cli.Context) error {
|
||||
return api(ctx.String("model"), ctx.String("address"), ctx.Int("threads"))
|
||||
l, err := llamaFromOptions(ctx)
|
||||
if err != nil {
|
||||
fmt.Println("Loading the model failed:", err.Error())
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
return api(l, ctx.String("address"), ctx.Int("threads"))
|
||||
},
|
||||
},
|
||||
},
|
||||
|
@ -179,11 +235,13 @@ echo "An Alpaca (Vicugna pacos) is a domesticated species of South American came
|
|||
fmt.Println("Templating the input failed:", err.Error())
|
||||
os.Exit(1)
|
||||
}
|
||||
l, err := llama.New(ctx.String("model"))
|
||||
|
||||
l, err := llamaFromOptions(ctx)
|
||||
if err != nil {
|
||||
fmt.Println("Loading the model failed:", err.Error())
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
res, err := l.Predict(
|
||||
str,
|
||||
llama.SetTemperature(ctx.Float64("temperature")),
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue