diff --git a/api/api.go b/api/api.go index 99d52435..e3409eff 100644 --- a/api/api.go +++ b/api/api.go @@ -47,6 +47,10 @@ func Startup(opts ...options.AppOption) (*options.Option, *config.ConfigLoader, } } + if err := cl.Preload(options.Loader.ModelPath); err != nil { + log.Error().Msgf("error downloading models: %s", err.Error()) + } + if options.Debug { for _, v := range cl.ListConfigs() { cfg, _ := cl.GetConfig(v) diff --git a/api/api_test.go b/api/api_test.go index 6329df34..a71b450a 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -294,7 +294,7 @@ var _ = Describe("API test", func() { Expect(content["backend"]).To(Equal("bert-embeddings")) }) - It("runs openllama", Label("llama"), func() { + It("runs openllama(llama-ggml backend)", Label("llama"), func() { if runtime.GOOS != "linux" { Skip("test supported only on linux") } @@ -362,9 +362,10 @@ var _ = Describe("API test", func() { Expect(res["location"]).To(Equal("San Francisco, California, United States"), fmt.Sprint(res)) Expect(res["unit"]).To(Equal("celcius"), fmt.Sprint(res)) Expect(string(resp2.Choices[0].FinishReason)).To(Equal("function_call"), fmt.Sprint(resp2.Choices[0].FinishReason)) + }) - It("runs openllama gguf", Label("llama-gguf"), func() { + It("runs openllama gguf(llama-cpp)", Label("llama-gguf"), func() { if runtime.GOOS != "linux" { Skip("test supported only on linux") } diff --git a/api/config/config.go b/api/config/config.go index 84c1f784..7ed7061a 100644 --- a/api/config/config.go +++ b/api/config/config.go @@ -8,6 +8,8 @@ import ( "strings" "sync" + "github.com/go-skynet/LocalAI/pkg/utils" + "github.com/rs/zerolog/log" "gopkg.in/yaml.v3" ) @@ -264,6 +266,36 @@ func (cm *ConfigLoader) ListConfigs() []string { return res } +func (cm *ConfigLoader) Preload(modelPath string) error { + cm.Lock() + defer cm.Unlock() + + for i, config := range cm.configs { + modelURL := config.PredictionOptions.Model + modelURL = utils.ConvertURL(modelURL) + if strings.HasPrefix(modelURL, "http://") || strings.HasPrefix(modelURL, "https://") { + // md5 of model name + md5Name := utils.MD5(modelURL) + + // check if file exists + if _, err := os.Stat(filepath.Join(modelPath, md5Name)); err == os.ErrNotExist { + err := utils.DownloadFile(modelURL, filepath.Join(modelPath, md5Name), "", func(fileName, current, total string, percent float64) { + log.Info().Msgf("Downloading %s: %s/%s (%.2f%%)", fileName, current, total, percent) + }) + if err != nil { + return err + } + } + + cc := cm.configs[i] + c := &cc + c.PredictionOptions.Model = md5Name + cm.configs[i] = *c + } + } + return nil +} + func (cm *ConfigLoader) LoadConfigs(path string) error { cm.Lock() defer cm.Unlock() diff --git a/api/openai/chat.go b/api/openai/chat.go index cd0b82dd..02bf6149 100644 --- a/api/openai/chat.go +++ b/api/openai/chat.go @@ -219,7 +219,12 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) c.Set("Transfer-Encoding", "chunked") } - templateFile := config.Model + templateFile := "" + + // A model can have a "file.bin.tmpl" file associated with a prompt template prefix + if o.Loader.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) { + templateFile = config.Model + } if config.TemplateConfig.Chat != "" && !processFunctions { templateFile = config.TemplateConfig.Chat @@ -229,18 +234,19 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) templateFile = config.TemplateConfig.Functions } - // A model can have a "file.bin.tmpl" file associated with a prompt template prefix - templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.ChatPromptTemplate, templateFile, model.PromptTemplateData{ - SystemPrompt: config.SystemPrompt, - SuppressSystemPrompt: suppressConfigSystemPrompt, - Input: predInput, - Functions: funcs, - }) - if err == nil { - predInput = templatedInput - log.Debug().Msgf("Template found, input modified to: %s", predInput) - } else { - log.Debug().Msgf("Template failed loading: %s", err.Error()) + if templateFile != "" { + templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.ChatPromptTemplate, templateFile, model.PromptTemplateData{ + SystemPrompt: config.SystemPrompt, + SuppressSystemPrompt: suppressConfigSystemPrompt, + Input: predInput, + Functions: funcs, + }) + if err == nil { + predInput = templatedInput + log.Debug().Msgf("Template found, input modified to: %s", predInput) + } else { + log.Debug().Msgf("Template failed loading: %s", err.Error()) + } } log.Debug().Msgf("Prompt (after templating): %s", predInput) diff --git a/api/openai/completion.go b/api/openai/completion.go index da28d63c..c0607632 100644 --- a/api/openai/completion.go +++ b/api/openai/completion.go @@ -81,7 +81,12 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe c.Set("Transfer-Encoding", "chunked") } - templateFile := config.Model + templateFile := "" + + // A model can have a "file.bin.tmpl" file associated with a prompt template prefix + if o.Loader.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) { + templateFile = config.Model + } if config.TemplateConfig.Completion != "" { templateFile = config.TemplateConfig.Completion @@ -94,13 +99,14 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe predInput := config.PromptStrings[0] - // A model can have a "file.bin.tmpl" file associated with a prompt template prefix - templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{ - Input: predInput, - }) - if err == nil { - predInput = templatedInput - log.Debug().Msgf("Template found, input modified to: %s", predInput) + if templateFile != "" { + templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{ + Input: predInput, + }) + if err == nil { + predInput = templatedInput + log.Debug().Msgf("Template found, input modified to: %s", predInput) + } } responses := make(chan schema.OpenAIResponse) @@ -145,14 +151,16 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe totalTokenUsage := backend.TokenUsage{} for k, i := range config.PromptStrings { - // A model can have a "file.bin.tmpl" file associated with a prompt template prefix - templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{ - SystemPrompt: config.SystemPrompt, - Input: i, - }) - if err == nil { - i = templatedInput - log.Debug().Msgf("Template found, input modified to: %s", i) + if templateFile != "" { + // A model can have a "file.bin.tmpl" file associated with a prompt template prefix + templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{ + SystemPrompt: config.SystemPrompt, + Input: i, + }) + if err == nil { + i = templatedInput + log.Debug().Msgf("Template found, input modified to: %s", i) + } } r, tokenUsage, err := ComputeChoices( diff --git a/api/openai/edit.go b/api/openai/edit.go index 088f0035..888b9db7 100644 --- a/api/openai/edit.go +++ b/api/openai/edit.go @@ -30,7 +30,12 @@ func EditEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) log.Debug().Msgf("Parameter Config: %+v", config) - templateFile := config.Model + templateFile := "" + + // A model can have a "file.bin.tmpl" file associated with a prompt template prefix + if o.Loader.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) { + templateFile = config.Model + } if config.TemplateConfig.Edit != "" { templateFile = config.TemplateConfig.Edit @@ -40,15 +45,16 @@ func EditEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) totalTokenUsage := backend.TokenUsage{} for _, i := range config.InputStrings { - // A model can have a "file.bin.tmpl" file associated with a prompt template prefix - templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.EditPromptTemplate, templateFile, model.PromptTemplateData{ - Input: i, - Instruction: input.Instruction, - SystemPrompt: config.SystemPrompt, - }) - if err == nil { - i = templatedInput - log.Debug().Msgf("Template found, input modified to: %s", i) + if templateFile != "" { + templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.EditPromptTemplate, templateFile, model.PromptTemplateData{ + Input: i, + Instruction: input.Instruction, + SystemPrompt: config.SystemPrompt, + }) + if err == nil { + i = templatedInput + log.Debug().Msgf("Template found, input modified to: %s", i) + } } r, tokenUsage, err := ComputeChoices(input, i, config, o, o.Loader, func(s string, c *[]schema.Choice) { diff --git a/pkg/gallery/models.go b/pkg/gallery/models.go index db68f525..9a169798 100644 --- a/pkg/gallery/models.go +++ b/pkg/gallery/models.go @@ -5,7 +5,6 @@ import ( "fmt" "hash" "io" - "net/http" "os" "path/filepath" "strconv" @@ -115,89 +114,8 @@ func InstallModel(basePath, nameOverride string, config *Config, configOverrides // Create file path filePath := filepath.Join(basePath, file.Filename) - // Check if the file already exists - _, err := os.Stat(filePath) - if err == nil { - // File exists, check SHA - if file.SHA256 != "" { - // Verify SHA - calculatedSHA, err := calculateSHA(filePath) - if err != nil { - return fmt.Errorf("failed to calculate SHA for file %q: %v", file.Filename, err) - } - if calculatedSHA == file.SHA256 { - // SHA matches, skip downloading - log.Debug().Msgf("File %q already exists and matches the SHA. Skipping download", file.Filename) - continue - } - // SHA doesn't match, delete the file and download again - err = os.Remove(filePath) - if err != nil { - return fmt.Errorf("failed to remove existing file %q: %v", file.Filename, err) - } - log.Debug().Msgf("Removed %q (SHA doesn't match)", filePath) - - } else { - // SHA is missing, skip downloading - log.Debug().Msgf("File %q already exists. Skipping download", file.Filename) - continue - } - } else if !os.IsNotExist(err) { - // Error occurred while checking file existence - return fmt.Errorf("failed to check file %q existence: %v", file.Filename, err) - } - - log.Debug().Msgf("Downloading %q", file.URI) - - // Download file - resp, err := http.Get(file.URI) - if err != nil { - return fmt.Errorf("failed to download file %q: %v", file.Filename, err) - } - defer resp.Body.Close() - - // Create parent directory - err = os.MkdirAll(filepath.Dir(filePath), 0755) - if err != nil { - return fmt.Errorf("failed to create parent directory for file %q: %v", file.Filename, err) - } - - // Create and write file content - outFile, err := os.Create(filePath) - if err != nil { - return fmt.Errorf("failed to create file %q: %v", file.Filename, err) - } - defer outFile.Close() - - progress := &progressWriter{ - fileName: file.Filename, - total: resp.ContentLength, - hash: sha256.New(), - downloadStatus: downloadStatus, - } - _, err = io.Copy(io.MultiWriter(outFile, progress), resp.Body) - if err != nil { - return fmt.Errorf("failed to write file %q: %v", file.Filename, err) - } - - if file.SHA256 != "" { - // Verify SHA - calculatedSHA := fmt.Sprintf("%x", progress.hash.Sum(nil)) - if calculatedSHA != file.SHA256 { - log.Debug().Msgf("SHA mismatch for file %q ( calculated: %s != metadata: %s )", file.Filename, calculatedSHA, file.SHA256) - return fmt.Errorf("SHA mismatch for file %q ( calculated: %s != metadata: %s )", file.Filename, calculatedSHA, file.SHA256) - } - } else { - log.Debug().Msgf("SHA missing for %q. Skipping validation", file.Filename) - } - - log.Debug().Msgf("File %q downloaded and verified", file.Filename) - if utils.IsArchive(filePath) { - log.Debug().Msgf("File %q is an archive, uncompressing to %s", file.Filename, basePath) - if err := utils.ExtractArchive(filePath, basePath); err != nil { - log.Debug().Msgf("Failed decompressing %q: %s", file.Filename, err.Error()) - return err - } + if err := utils.DownloadFile(file.URI, filePath, file.SHA256, downloadStatus); err != nil { + return err } } diff --git a/pkg/model/loader.go b/pkg/model/loader.go index 493ee083..d02f9e84 100644 --- a/pkg/model/loader.go +++ b/pkg/model/loader.go @@ -247,17 +247,19 @@ func (ml *ModelLoader) loadTemplateIfExists(templateType TemplateType, templateN // skip any error here - we run anyway if a template does not exist modelTemplateFile := fmt.Sprintf("%s.tmpl", templateName) - if !ml.ExistsInModelPath(modelTemplateFile) { - return nil - } - - dat, err := os.ReadFile(filepath.Join(ml.ModelPath, modelTemplateFile)) - if err != nil { - return err + dat := "" + if ml.ExistsInModelPath(modelTemplateFile) { + d, err := os.ReadFile(filepath.Join(ml.ModelPath, modelTemplateFile)) + if err != nil { + return err + } + dat = string(d) + } else { + dat = templateName } // Parse the template - tmpl, err := template.New("prompt").Parse(string(dat)) + tmpl, err := template.New("prompt").Parse(dat) if err != nil { return err } diff --git a/pkg/utils/uri.go b/pkg/utils/uri.go index 16f4dbf8..8046b89f 100644 --- a/pkg/utils/uri.go +++ b/pkg/utils/uri.go @@ -1,12 +1,18 @@ package utils import ( + "crypto/md5" + "crypto/sha256" "fmt" + "hash" "io" "net/http" "os" "path/filepath" + "strconv" "strings" + + "github.com/rs/zerolog/log" ) const ( @@ -64,3 +70,173 @@ func GetURI(url string, f func(url string, i []byte) error) error { // Unmarshal YAML data into a struct return f(url, body) } + +func ConvertURL(s string) string { + switch { + case strings.HasPrefix(s, "huggingface://"): + repository := strings.Replace(s, "huggingface://", "", 1) + // convert repository to a full URL. + // e.g. TheBloke/Mixtral-8x7B-v0.1-GGUF/mixtral-8x7b-v0.1.Q2_K.gguf@main -> https://huggingface.co/TheBloke/Mixtral-8x7B-v0.1-GGUF/resolve/main/mixtral-8x7b-v0.1.Q2_K.gguf + owner := strings.Split(repository, "/")[0] + repo := strings.Split(repository, "/")[1] + branch := "main" + if strings.Contains(repo, "@") { + branch = strings.Split(repository, "@")[1] + } + filepath := strings.Split(repository, "/")[2] + if strings.Contains(filepath, "@") { + filepath = strings.Split(filepath, "@")[0] + } + + return fmt.Sprintf("https://huggingface.co/%s/%s/resolve/%s/%s", owner, repo, branch, filepath) + } + + return s +} + +func DownloadFile(url string, filePath, sha string, downloadStatus func(string, string, string, float64)) error { + url = ConvertURL(url) + // Check if the file already exists + _, err := os.Stat(filePath) + if err == nil { + // File exists, check SHA + if sha != "" { + // Verify SHA + calculatedSHA, err := calculateSHA(filePath) + if err != nil { + return fmt.Errorf("failed to calculate SHA for file %q: %v", filePath, err) + } + if calculatedSHA == sha { + // SHA matches, skip downloading + log.Debug().Msgf("File %q already exists and matches the SHA. Skipping download", filePath) + return nil + } + // SHA doesn't match, delete the file and download again + err = os.Remove(filePath) + if err != nil { + return fmt.Errorf("failed to remove existing file %q: %v", filePath, err) + } + log.Debug().Msgf("Removed %q (SHA doesn't match)", filePath) + + } else { + // SHA is missing, skip downloading + log.Debug().Msgf("File %q already exists. Skipping download", filePath) + return nil + } + } else if !os.IsNotExist(err) { + // Error occurred while checking file existence + return fmt.Errorf("failed to check file %q existence: %v", filePath, err) + } + + log.Info().Msgf("Downloading %q", url) + + // Download file + resp, err := http.Get(url) + if err != nil { + return fmt.Errorf("failed to download file %q: %v", filePath, err) + } + defer resp.Body.Close() + + // Create parent directory + err = os.MkdirAll(filepath.Dir(filePath), 0755) + if err != nil { + return fmt.Errorf("failed to create parent directory for file %q: %v", filePath, err) + } + + // Create and write file content + outFile, err := os.Create(filePath) + if err != nil { + return fmt.Errorf("failed to create file %q: %v", filePath, err) + } + defer outFile.Close() + + progress := &progressWriter{ + fileName: filePath, + total: resp.ContentLength, + hash: sha256.New(), + downloadStatus: downloadStatus, + } + _, err = io.Copy(io.MultiWriter(outFile, progress), resp.Body) + if err != nil { + return fmt.Errorf("failed to write file %q: %v", filePath, err) + } + + if sha != "" { + // Verify SHA + calculatedSHA := fmt.Sprintf("%x", progress.hash.Sum(nil)) + if calculatedSHA != sha { + log.Debug().Msgf("SHA mismatch for file %q ( calculated: %s != metadata: %s )", filePath, calculatedSHA, sha) + return fmt.Errorf("SHA mismatch for file %q ( calculated: %s != metadata: %s )", filePath, calculatedSHA, sha) + } + } else { + log.Debug().Msgf("SHA missing for %q. Skipping validation", filePath) + } + + log.Info().Msgf("File %q downloaded and verified", filePath) + if IsArchive(filePath) { + basePath := filepath.Dir(filePath) + log.Info().Msgf("File %q is an archive, uncompressing to %s", filePath, basePath) + if err := ExtractArchive(filePath, basePath); err != nil { + log.Debug().Msgf("Failed decompressing %q: %s", filePath, err.Error()) + return err + } + } + + return nil +} + +type progressWriter struct { + fileName string + total int64 + written int64 + downloadStatus func(string, string, string, float64) + hash hash.Hash +} + +func (pw *progressWriter) Write(p []byte) (n int, err error) { + n, err = pw.hash.Write(p) + pw.written += int64(n) + + if pw.total > 0 { + percentage := float64(pw.written) / float64(pw.total) * 100 + //log.Debug().Msgf("Downloading %s: %s/%s (%.2f%%)", pw.fileName, formatBytes(pw.written), formatBytes(pw.total), percentage) + pw.downloadStatus(pw.fileName, formatBytes(pw.written), formatBytes(pw.total), percentage) + } else { + pw.downloadStatus(pw.fileName, formatBytes(pw.written), "", 0) + } + + return +} + +// MD5 of a string +func MD5(s string) string { + return fmt.Sprintf("%x", md5.Sum([]byte(s))) +} + +func formatBytes(bytes int64) string { + const unit = 1024 + if bytes < unit { + return strconv.FormatInt(bytes, 10) + " B" + } + div, exp := int64(unit), 0 + for n := bytes / unit; n >= unit; n /= unit { + div *= unit + exp++ + } + return fmt.Sprintf("%.1f %ciB", float64(bytes)/float64(div), "KMGTPE"[exp]) +} + +func calculateSHA(filePath string) (string, error) { + file, err := os.Open(filePath) + if err != nil { + return "", err + } + defer file.Close() + + hash := sha256.New() + if _, err := io.Copy(hash, file); err != nil { + return "", err + } + + return fmt.Sprintf("%x", hash.Sum(nil)), nil +}