feat: allow to override model config (#323)

This commit is contained in:
Ettore Di Giacinto 2023-05-20 17:03:53 +02:00 committed by GitHub
parent 7bc08797f9
commit 05a3d569b0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 269 additions and 64 deletions

View file

@ -8,6 +8,7 @@ import (
"os"
"path/filepath"
"github.com/imdario/mergo"
"github.com/rs/zerolog/log"
"gopkg.in/yaml.v2"
)
@ -92,13 +93,17 @@ func verifyPath(path, basePath string) error {
return inTrustedRoot(c, basePath)
}
func Apply(basePath, nameOverride string, config *Config) error {
func Apply(basePath, nameOverride string, config *Config, configOverrides map[string]interface{}) error {
// Create base path if it doesn't exist
err := os.MkdirAll(basePath, 0755)
if err != nil {
return fmt.Errorf("failed to create base path: %v", err)
}
if len(configOverrides) > 0 {
log.Debug().Msgf("Config overrides %+v", configOverrides)
}
// Download files and verify their SHA
for _, file := range config.Files {
log.Debug().Msgf("Checking %q exists and matches SHA", file.Filename)
@ -231,6 +236,10 @@ func Apply(basePath, nameOverride string, config *Config) error {
configMap["name"] = name
if err := mergo.Merge(&configMap, configOverrides, mergo.WithOverride); err != nil {
return err
}
// Write updated config file
updatedConfigYAML, err := yaml.Marshal(configMap)
if err != nil {

View file

@ -7,6 +7,7 @@ import (
. "github.com/go-skynet/LocalAI/pkg/gallery"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"gopkg.in/yaml.v3"
)
var _ = Describe("Model test", func() {
@ -18,14 +19,25 @@ var _ = Describe("Model test", func() {
c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
Expect(err).ToNot(HaveOccurred())
err = Apply(tempdir, "", c)
err = Apply(tempdir, "", c, map[string]interface{}{})
Expect(err).ToNot(HaveOccurred())
for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "cerebras.yaml"} {
_, err = os.Stat(filepath.Join(tempdir, f))
Expect(err).ToNot(HaveOccurred())
}
content := map[string]interface{}{}
dat, err := os.ReadFile(filepath.Join(tempdir, "cerebras.yaml"))
Expect(err).ToNot(HaveOccurred())
err = yaml.Unmarshal(dat, content)
Expect(err).ToNot(HaveOccurred())
Expect(content["context_size"]).To(Equal(1024))
})
It("renames model correctly", func() {
tempdir, err := os.MkdirTemp("", "test")
Expect(err).ToNot(HaveOccurred())
@ -33,7 +45,7 @@ var _ = Describe("Model test", func() {
c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
Expect(err).ToNot(HaveOccurred())
err = Apply(tempdir, "foo", c)
err = Apply(tempdir, "foo", c, map[string]interface{}{})
Expect(err).ToNot(HaveOccurred())
for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} {
@ -41,6 +53,33 @@ var _ = Describe("Model test", func() {
Expect(err).ToNot(HaveOccurred())
}
})
It("overrides parameters", func() {
tempdir, err := os.MkdirTemp("", "test")
Expect(err).ToNot(HaveOccurred())
defer os.RemoveAll(tempdir)
c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
Expect(err).ToNot(HaveOccurred())
err = Apply(tempdir, "foo", c, map[string]interface{}{"backend": "foo"})
Expect(err).ToNot(HaveOccurred())
for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} {
_, err = os.Stat(filepath.Join(tempdir, f))
Expect(err).ToNot(HaveOccurred())
}
content := map[string]interface{}{}
dat, err := os.ReadFile(filepath.Join(tempdir, "foo.yaml"))
Expect(err).ToNot(HaveOccurred())
err = yaml.Unmarshal(dat, content)
Expect(err).ToNot(HaveOccurred())
Expect(content["backend"]).To(Equal("foo"))
})
It("catches path traversals", func() {
tempdir, err := os.MkdirTemp("", "test")
Expect(err).ToNot(HaveOccurred())
@ -48,7 +87,7 @@ var _ = Describe("Model test", func() {
c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
Expect(err).ToNot(HaveOccurred())
err = Apply(tempdir, "../../../foo", c)
err = Apply(tempdir, "../../../foo", c, map[string]interface{}{})
Expect(err).To(HaveOccurred())
})
})