package gallery import ( "crypto/sha256" "fmt" "hash" "io" "io/ioutil" "net/http" "os" "path/filepath" "strconv" "github.com/go-skynet/LocalAI/pkg/utils" "github.com/imdario/mergo" "github.com/rs/zerolog/log" "gopkg.in/yaml.v2" ) /* description: | foo license: "" urls: - - name: "bar" config_file: | # Note, name will be injected. or generated by the alias wanted by the user threads: 14 files: - filename: "" sha: "" uri: "" prompt_templates: - name: "" content: "" */ type Config struct { Description string `yaml:"description"` License string `yaml:"license"` URLs []string `yaml:"urls"` Name string `yaml:"name"` ConfigFile string `yaml:"config_file"` Files []File `yaml:"files"` PromptTemplates []PromptTemplate `yaml:"prompt_templates"` } type File struct { Filename string `yaml:"filename" json:"filename"` SHA256 string `yaml:"sha256" json:"sha256"` URI string `yaml:"uri" json:"uri"` } type PromptTemplate struct { Name string `yaml:"name"` Content string `yaml:"content"` } func ReadConfigFile(filePath string) (*Config, error) { // Read the YAML file yamlFile, err := os.ReadFile(filePath) if err != nil { return nil, fmt.Errorf("failed to read YAML file: %v", err) } // Unmarshal YAML data into a Config struct var config Config err = yaml.Unmarshal(yamlFile, &config) if err != nil { return nil, fmt.Errorf("failed to unmarshal YAML: %v", err) } return &config, nil } func Apply(basePath, nameOverride string, config *Config, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64)) error { // Create base path if it doesn't exist err := os.MkdirAll(basePath, 0755) if err != nil { return fmt.Errorf("failed to create base path: %v", err) } if len(configOverrides) > 0 { log.Debug().Msgf("Config overrides %+v", configOverrides) } // Download files and verify their SHA for _, file := range config.Files { log.Debug().Msgf("Checking %q exists and matches SHA", file.Filename) if err := utils.VerifyPath(file.Filename, basePath); err != nil { return err } // Create file path filePath := filepath.Join(basePath, file.Filename) // Check if the file already exists _, err := os.Stat(filePath) if err == nil { // File exists, check SHA if file.SHA256 != "" { // Verify SHA calculatedSHA, err := calculateSHA(filePath) if err != nil { return fmt.Errorf("failed to calculate SHA for file %q: %v", file.Filename, err) } if calculatedSHA == file.SHA256 { // SHA matches, skip downloading log.Debug().Msgf("File %q already exists and matches the SHA. Skipping download", file.Filename) continue } // SHA doesn't match, delete the file and download again err = os.Remove(filePath) if err != nil { return fmt.Errorf("failed to remove existing file %q: %v", file.Filename, err) } log.Debug().Msgf("Removed %q (SHA doesn't match)", filePath) } else { // SHA is missing, skip downloading log.Debug().Msgf("File %q already exists. Skipping download", file.Filename) continue } } else if !os.IsNotExist(err) { // Error occurred while checking file existence return fmt.Errorf("failed to check file %q existence: %v", file.Filename, err) } log.Debug().Msgf("Downloading %q", file.URI) // Download file resp, err := http.Get(file.URI) if err != nil { return fmt.Errorf("failed to download file %q: %v", file.Filename, err) } defer resp.Body.Close() // Create parent directory err = os.MkdirAll(filepath.Dir(filePath), 0755) if err != nil { return fmt.Errorf("failed to create parent directory for file %q: %v", file.Filename, err) } // Create and write file content outFile, err := os.Create(filePath) if err != nil { return fmt.Errorf("failed to create file %q: %v", file.Filename, err) } defer outFile.Close() progress := &progressWriter{ fileName: file.Filename, total: resp.ContentLength, hash: sha256.New(), downloadStatus: downloadStatus, } _, err = io.Copy(io.MultiWriter(outFile, progress), resp.Body) if err != nil { return fmt.Errorf("failed to write file %q: %v", file.Filename, err) } if file.SHA256 != "" { // Verify SHA calculatedSHA := fmt.Sprintf("%x", progress.hash.Sum(nil)) if calculatedSHA != file.SHA256 { return fmt.Errorf("SHA mismatch for file %q ( calculated: %s != metadata: %s )", file.Filename, calculatedSHA, file.SHA256) } } else { log.Debug().Msgf("SHA missing for %q. Skipping validation", file.Filename) } log.Debug().Msgf("File %q downloaded and verified", file.Filename) } // Write prompt template contents to separate files for _, template := range config.PromptTemplates { if err := utils.VerifyPath(template.Name+".tmpl", basePath); err != nil { return err } // Create file path filePath := filepath.Join(basePath, template.Name+".tmpl") // Create parent directory err := os.MkdirAll(filepath.Dir(filePath), 0755) if err != nil { return fmt.Errorf("failed to create parent directory for prompt template %q: %v", template.Name, err) } // Create and write file content err = os.WriteFile(filePath, []byte(template.Content), 0644) if err != nil { return fmt.Errorf("failed to write prompt template %q: %v", template.Name, err) } log.Debug().Msgf("Prompt template %q written", template.Name) } name := config.Name if nameOverride != "" { name = nameOverride } if err := utils.VerifyPath(name+".yaml", basePath); err != nil { return err } configFilePath := filepath.Join(basePath, name+".yaml") // Read and update config file as map[string]interface{} configMap := make(map[string]interface{}) err = yaml.Unmarshal([]byte(config.ConfigFile), &configMap) if err != nil { return fmt.Errorf("failed to unmarshal config YAML: %v", err) } configMap["name"] = name if err := mergo.Merge(&configMap, configOverrides, mergo.WithOverride); err != nil { return err } // Write updated config file updatedConfigYAML, err := yaml.Marshal(configMap) if err != nil { return fmt.Errorf("failed to marshal updated config YAML: %v", err) } err = os.WriteFile(configFilePath, updatedConfigYAML, 0644) if err != nil { return fmt.Errorf("failed to write updated config file: %v", err) } log.Debug().Msgf("Written config file %s", configFilePath) return nil } type progressWriter struct { fileName string total int64 written int64 downloadStatus func(string, string, string, float64) hash hash.Hash } func (pw *progressWriter) Write(p []byte) (n int, err error) { n, err = pw.hash.Write(p) pw.written += int64(n) if pw.total > 0 { percentage := float64(pw.written) / float64(pw.total) * 100 //log.Debug().Msgf("Downloading %s: %s/%s (%.2f%%)", pw.fileName, formatBytes(pw.written), formatBytes(pw.total), percentage) pw.downloadStatus(pw.fileName, formatBytes(pw.written), formatBytes(pw.total), percentage) } else { pw.downloadStatus(pw.fileName, formatBytes(pw.written), "", 0) } return } func formatBytes(bytes int64) string { const unit = 1024 if bytes < unit { return strconv.FormatInt(bytes, 10) + " B" } div, exp := int64(unit), 0 for n := bytes / unit; n >= unit; n /= unit { div *= unit exp++ } return fmt.Sprintf("%.1f %ciB", float64(bytes)/float64(div), "KMGTPE"[exp]) } func calculateSHA(filePath string) (string, error) { file, err := os.Open(filePath) if err != nil { return "", err } defer file.Close() hash := sha256.New() if _, err := io.Copy(hash, file); err != nil { return "", err } return fmt.Sprintf("%x", hash.Sum(nil)), nil } type Gallery struct { URL string `json:"url"` Name string `json:"name"` } // Installs a model from the gallery (galleryname@modelname) func ApplyModelFromGallery(galleries []*Gallery, name string, basePath, nameOverride string, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64)) error { models, err := AvailableModels(galleries) if err != nil { return err } applyModel := func(model *GalleryModel) error { url, err := model.DecodeURL() if err != nil { return err } // Send a GET request to the URL response, err := http.Get(url) if err != nil { return err } defer response.Body.Close() // Read the response body body, err := ioutil.ReadAll(response.Body) if err != nil { return err } // Unmarshal YAML data into a Config struct var config Config err = yaml.Unmarshal(body, &config) if err != nil { return err } if nameOverride != "" { model.Name = nameOverride } // TODO model.Overrides could be merged with user overrides (not defined yet) if err := mergo.Merge(&model.Overrides, configOverrides, mergo.WithOverride); err != nil { return err } if err := Apply(basePath, model.Name, &config, model.Overrides, downloadStatus); err != nil { return err } return nil } for _, model := range models { if name == fmt.Sprintf("%s@%s", model.Gallery.Name, model.Name) { return applyModel(model) } } return fmt.Errorf("no model found with name %q", name) } // List available models // Models galleries are a list of json files that are hosted on a remote server (for example github). // Each json file contains a list of models that can be downloaded and optionally overrides to define a new model setting. func AvailableModels(galleries []*Gallery) ([]*GalleryModel, error) { var models []*GalleryModel // Get models from galleries for _, gallery := range galleries { galleryModels, err := getModels(gallery) if err != nil { return nil, err } models = append(models, galleryModels...) } return models, nil } func getModels(gallery *Gallery) ([]*GalleryModel, error) { var models []*GalleryModel // Get list of models resp, err := http.Get(gallery.URL) if err != nil { return nil, fmt.Errorf("failed to get models: %v", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("failed to get models: %s", resp.Status) } err = yaml.NewDecoder(resp.Body).Decode(&models) if err != nil { return nil, fmt.Errorf("failed to decode models: %v", err) } // Add gallery to models for _, model := range models { model.Gallery = gallery } return models, nil }