mirror of
https://github.com/mudler/LocalAI.git
synced 2025-05-20 10:35:01 +00:00
feat: add /models/apply endpoint to prepare models (#286)
This commit is contained in:
parent
5617e50ebc
commit
cc9aa9eb3f
23 changed files with 556 additions and 33 deletions
13
pkg/gallery/gallery_suite_test.go
Normal file
13
pkg/gallery/gallery_suite_test.go
Normal file
|
@ -0,0 +1,13 @@
|
|||
package gallery_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestGallery(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "Gallery test suite")
|
||||
}
|
237
pkg/gallery/models.go
Normal file
237
pkg/gallery/models.go
Normal file
|
@ -0,0 +1,237 @@
|
|||
package gallery
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"gopkg.in/yaml.v2"
|
||||
)
|
||||
|
||||
/*
|
||||
|
||||
description: |
|
||||
foo
|
||||
license: ""
|
||||
|
||||
urls:
|
||||
-
|
||||
-
|
||||
|
||||
name: "bar"
|
||||
|
||||
config_file: |
|
||||
# Note, name will be injected. or generated by the alias wanted by the user
|
||||
threads: 14
|
||||
|
||||
files:
|
||||
- filename: ""
|
||||
sha: ""
|
||||
uri: ""
|
||||
|
||||
prompt_templates:
|
||||
- name: ""
|
||||
content: ""
|
||||
|
||||
*/
|
||||
|
||||
type Config struct {
|
||||
Description string `yaml:"description"`
|
||||
License string `yaml:"license"`
|
||||
URLs []string `yaml:"urls"`
|
||||
Name string `yaml:"name"`
|
||||
ConfigFile string `yaml:"config_file"`
|
||||
Files []File `yaml:"files"`
|
||||
PromptTemplates []PromptTemplate `yaml:"prompt_templates"`
|
||||
}
|
||||
|
||||
type File struct {
|
||||
Filename string `yaml:"filename"`
|
||||
SHA256 string `yaml:"sha256"`
|
||||
URI string `yaml:"uri"`
|
||||
}
|
||||
|
||||
type PromptTemplate struct {
|
||||
Name string `yaml:"name"`
|
||||
Content string `yaml:"content"`
|
||||
}
|
||||
|
||||
func ReadConfigFile(filePath string) (*Config, error) {
|
||||
// Read the YAML file
|
||||
yamlFile, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read YAML file: %v", err)
|
||||
}
|
||||
|
||||
// Unmarshal YAML data into a Config struct
|
||||
var config Config
|
||||
err = yaml.Unmarshal(yamlFile, &config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal YAML: %v", err)
|
||||
}
|
||||
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
func Apply(basePath, nameOverride string, config *Config) error {
|
||||
// Create base path if it doesn't exist
|
||||
err := os.MkdirAll(basePath, 0755)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create base path: %v", err)
|
||||
}
|
||||
|
||||
// Download files and verify their SHA
|
||||
for _, file := range config.Files {
|
||||
log.Debug().Msgf("Checking %q exists and matches SHA", file.Filename)
|
||||
|
||||
// Create file path
|
||||
filePath := filepath.Join(basePath, file.Filename)
|
||||
|
||||
// Check if the file already exists
|
||||
_, err := os.Stat(filePath)
|
||||
if err == nil {
|
||||
// File exists, check SHA
|
||||
if file.SHA256 != "" {
|
||||
// Verify SHA
|
||||
calculatedSHA, err := calculateSHA(filePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to calculate SHA for file %q: %v", file.Filename, err)
|
||||
}
|
||||
if calculatedSHA == file.SHA256 {
|
||||
// SHA matches, skip downloading
|
||||
log.Debug().Msgf("File %q already exists and matches the SHA. Skipping download", file.Filename)
|
||||
continue
|
||||
}
|
||||
// SHA doesn't match, delete the file and download again
|
||||
err = os.Remove(filePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to remove existing file %q: %v", file.Filename, err)
|
||||
}
|
||||
log.Debug().Msgf("Removed %q (SHA doesn't match)", filePath)
|
||||
|
||||
} else {
|
||||
// SHA is missing, skip downloading
|
||||
log.Debug().Msgf("File %q already exists. Skipping download", file.Filename)
|
||||
continue
|
||||
}
|
||||
} else if !os.IsNotExist(err) {
|
||||
// Error occurred while checking file existence
|
||||
return fmt.Errorf("failed to check file %q existence: %v", file.Filename, err)
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Downloading %q", file.URI)
|
||||
|
||||
// Download file
|
||||
resp, err := http.Get(file.URI)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to download file %q: %v", file.Filename, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Create parent directory
|
||||
err = os.MkdirAll(filepath.Dir(filePath), 0755)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create parent directory for file %q: %v", file.Filename, err)
|
||||
}
|
||||
|
||||
// Create and write file content
|
||||
outFile, err := os.Create(filePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create file %q: %v", file.Filename, err)
|
||||
}
|
||||
defer outFile.Close()
|
||||
|
||||
if file.SHA256 != "" {
|
||||
log.Debug().Msgf("Download and verifying %q", file.Filename)
|
||||
|
||||
// Write file content and calculate SHA
|
||||
hash := sha256.New()
|
||||
_, err = io.Copy(io.MultiWriter(outFile, hash), resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write file %q: %v", file.Filename, err)
|
||||
}
|
||||
|
||||
// Verify SHA
|
||||
calculatedSHA := fmt.Sprintf("%x", hash.Sum(nil))
|
||||
if calculatedSHA != file.SHA256 {
|
||||
return fmt.Errorf("SHA mismatch for file %q ( calculated: %s != metadata: %s )", file.Filename, calculatedSHA, file.SHA256)
|
||||
}
|
||||
} else {
|
||||
log.Debug().Msgf("SHA missing for %q. Skipping validation", file.Filename)
|
||||
_, err = io.Copy(outFile, resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write file %q: %v", file.Filename, err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug().Msgf("File %q downloaded and verified", file.Filename)
|
||||
}
|
||||
|
||||
// Write prompt template contents to separate files
|
||||
for _, template := range config.PromptTemplates {
|
||||
// Create file path
|
||||
filePath := filepath.Join(basePath, template.Name+".tmpl")
|
||||
|
||||
// Create parent directory
|
||||
err := os.MkdirAll(filepath.Dir(filePath), 0755)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create parent directory for prompt template %q: %v", template.Name, err)
|
||||
}
|
||||
// Create and write file content
|
||||
err = os.WriteFile(filePath, []byte(template.Content), 0644)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write prompt template %q: %v", template.Name, err)
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Prompt template %q written", template.Name)
|
||||
}
|
||||
|
||||
name := config.Name
|
||||
if nameOverride != "" {
|
||||
name = nameOverride
|
||||
}
|
||||
|
||||
configFilePath := filepath.Join(basePath, name+".yaml")
|
||||
|
||||
// Read and update config file as map[string]interface{}
|
||||
configMap := make(map[string]interface{})
|
||||
err = yaml.Unmarshal([]byte(config.ConfigFile), &configMap)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to unmarshal config YAML: %v", err)
|
||||
}
|
||||
|
||||
configMap["name"] = name
|
||||
|
||||
// Write updated config file
|
||||
updatedConfigYAML, err := yaml.Marshal(configMap)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal updated config YAML: %v", err)
|
||||
}
|
||||
|
||||
err = os.WriteFile(configFilePath, updatedConfigYAML, 0644)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write updated config file: %v", err)
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Written config file %s", configFilePath)
|
||||
return nil
|
||||
}
|
||||
|
||||
func calculateSHA(filePath string) (string, error) {
|
||||
file, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
hash := sha256.New()
|
||||
if _, err := io.Copy(hash, file); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%x", hash.Sum(nil)), nil
|
||||
}
|
30
pkg/gallery/models_test.go
Normal file
30
pkg/gallery/models_test.go
Normal file
|
@ -0,0 +1,30 @@
|
|||
package gallery_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
. "github.com/go-skynet/LocalAI/pkg/gallery"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Model test", func() {
|
||||
Context("Downloading", func() {
|
||||
It("applies model correctly", func() {
|
||||
tempdir, err := os.MkdirTemp("", "test")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer os.RemoveAll(tempdir)
|
||||
c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
err = Apply(tempdir, "", c)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "cerebras.yaml"} {
|
||||
_, err = os.Stat(filepath.Join(tempdir, f))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
|
@ -164,11 +164,12 @@ func (ml *ModelLoader) BackendLoader(backendString string, modelFile string, lla
|
|||
}
|
||||
|
||||
func (ml *ModelLoader) GreedyLoader(modelFile string, llamaOpts []llama.ModelOption, threads uint32) (interface{}, error) {
|
||||
log.Debug().Msgf("Loading models greedly")
|
||||
log.Debug().Msgf("Loading model '%s' greedly", modelFile)
|
||||
|
||||
ml.mu.Lock()
|
||||
m, exists := ml.models[modelFile]
|
||||
if exists {
|
||||
log.Debug().Msgf("Model '%s' already loaded", modelFile)
|
||||
ml.mu.Unlock()
|
||||
return m, nil
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue