mirror of
https://github.com/mudler/LocalAI.git
synced 2025-05-20 02:24:59 +00:00
fix: be consistent in downloading files, check for scanner errors (#3108)
* fix(downloader): be consistent in downloading files This PR puts some order in the downloader such as functions are re-used across several places. This fixes an issue with having uri's inside the model YAML file, it would resolve to MD5 rather then using the filename Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fix(scanner): do raise error only if unsafeFiles are found Fixes: https://github.com/mudler/LocalAI/issues/3114 Signed-off-by: Ettore Di Giacinto <mudler@localai.io> --------- Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
parent
fc50a90f6a
commit
a36b721ca6
13 changed files with 173 additions and 171 deletions
|
@ -2,12 +2,10 @@ package downloader
|
|||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
|
@ -28,13 +26,16 @@ const (
|
|||
HTTPSPrefix = "https://"
|
||||
GithubURI = "github:"
|
||||
GithubURI2 = "github://"
|
||||
LocalPrefix = "file://"
|
||||
)
|
||||
|
||||
func DownloadAndUnmarshal(url string, basePath string, f func(url string, i []byte) error) error {
|
||||
url = ConvertURL(url)
|
||||
type URI string
|
||||
|
||||
if strings.HasPrefix(url, "file://") {
|
||||
rawURL := strings.TrimPrefix(url, "file://")
|
||||
func (uri URI) DownloadAndUnmarshal(basePath string, f func(url string, i []byte) error) error {
|
||||
url := uri.ResolveURL()
|
||||
|
||||
if strings.HasPrefix(url, LocalPrefix) {
|
||||
rawURL := strings.TrimPrefix(url, LocalPrefix)
|
||||
// checks if the file is symbolic, and resolve if so - otherwise, this function returns the path unmodified.
|
||||
resolvedFile, err := filepath.EvalSymlinks(rawURL)
|
||||
if err != nil {
|
||||
|
@ -78,24 +79,54 @@ func DownloadAndUnmarshal(url string, basePath string, f func(url string, i []by
|
|||
return f(url, body)
|
||||
}
|
||||
|
||||
func LooksLikeURL(s string) bool {
|
||||
return strings.HasPrefix(s, HTTPPrefix) ||
|
||||
strings.HasPrefix(s, HTTPSPrefix) ||
|
||||
strings.HasPrefix(s, HuggingFacePrefix) ||
|
||||
strings.HasPrefix(s, GithubURI) ||
|
||||
strings.HasPrefix(s, OllamaPrefix) ||
|
||||
strings.HasPrefix(s, OCIPrefix) ||
|
||||
strings.HasPrefix(s, GithubURI2)
|
||||
func (u URI) FilenameFromUrl() (string, error) {
|
||||
f, err := filenameFromUrl(string(u))
|
||||
if err != nil || f == "" {
|
||||
f = utils.MD5(string(u))
|
||||
if strings.HasSuffix(string(u), ".yaml") || strings.HasSuffix(string(u), ".yml") {
|
||||
f = f + ".yaml"
|
||||
}
|
||||
err = nil
|
||||
}
|
||||
|
||||
return f, err
|
||||
}
|
||||
|
||||
func LooksLikeOCI(s string) bool {
|
||||
return strings.HasPrefix(s, OCIPrefix) || strings.HasPrefix(s, OllamaPrefix)
|
||||
func filenameFromUrl(urlstr string) (string, error) {
|
||||
// strip anything after @
|
||||
if strings.Contains(urlstr, "@") {
|
||||
urlstr = strings.Split(urlstr, "@")[0]
|
||||
}
|
||||
|
||||
u, err := url.Parse(urlstr)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error due to parsing url: %w", err)
|
||||
}
|
||||
x, err := url.QueryUnescape(u.EscapedPath())
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error due to escaping: %w", err)
|
||||
}
|
||||
return filepath.Base(x), nil
|
||||
}
|
||||
|
||||
func ConvertURL(s string) string {
|
||||
func (u URI) LooksLikeURL() bool {
|
||||
return strings.HasPrefix(string(u), HTTPPrefix) ||
|
||||
strings.HasPrefix(string(u), HTTPSPrefix) ||
|
||||
strings.HasPrefix(string(u), HuggingFacePrefix) ||
|
||||
strings.HasPrefix(string(u), GithubURI) ||
|
||||
strings.HasPrefix(string(u), OllamaPrefix) ||
|
||||
strings.HasPrefix(string(u), OCIPrefix) ||
|
||||
strings.HasPrefix(string(u), GithubURI2)
|
||||
}
|
||||
|
||||
func (s URI) LooksLikeOCI() bool {
|
||||
return strings.HasPrefix(string(s), OCIPrefix) || strings.HasPrefix(string(s), OllamaPrefix)
|
||||
}
|
||||
|
||||
func (s URI) ResolveURL() string {
|
||||
switch {
|
||||
case strings.HasPrefix(s, GithubURI2):
|
||||
repository := strings.Replace(s, GithubURI2, "", 1)
|
||||
case strings.HasPrefix(string(s), GithubURI2):
|
||||
repository := strings.Replace(string(s), GithubURI2, "", 1)
|
||||
|
||||
repoParts := strings.Split(repository, "@")
|
||||
branch := "main"
|
||||
|
@ -110,8 +141,8 @@ func ConvertURL(s string) string {
|
|||
projectPath := strings.Join(repoPath[2:], "/")
|
||||
|
||||
return fmt.Sprintf("https://raw.githubusercontent.com/%s/%s/%s/%s", org, project, branch, projectPath)
|
||||
case strings.HasPrefix(s, GithubURI):
|
||||
parts := strings.Split(s, ":")
|
||||
case strings.HasPrefix(string(s), GithubURI):
|
||||
parts := strings.Split(string(s), ":")
|
||||
repoParts := strings.Split(parts[1], "@")
|
||||
branch := "main"
|
||||
|
||||
|
@ -125,8 +156,8 @@ func ConvertURL(s string) string {
|
|||
projectPath := strings.Join(repoPath[2:], "/")
|
||||
|
||||
return fmt.Sprintf("https://raw.githubusercontent.com/%s/%s/%s/%s", org, project, branch, projectPath)
|
||||
case strings.HasPrefix(s, HuggingFacePrefix):
|
||||
repository := strings.Replace(s, HuggingFacePrefix, "", 1)
|
||||
case strings.HasPrefix(string(s), HuggingFacePrefix):
|
||||
repository := strings.Replace(string(s), HuggingFacePrefix, "", 1)
|
||||
// convert repository to a full URL.
|
||||
// e.g. TheBloke/Mixtral-8x7B-v0.1-GGUF/mixtral-8x7b-v0.1.Q2_K.gguf@main -> https://huggingface.co/TheBloke/Mixtral-8x7B-v0.1-GGUF/resolve/main/mixtral-8x7b-v0.1.Q2_K.gguf
|
||||
owner := strings.Split(repository, "/")[0]
|
||||
|
@ -144,7 +175,7 @@ func ConvertURL(s string) string {
|
|||
return fmt.Sprintf("https://huggingface.co/%s/%s/resolve/%s/%s", owner, repo, branch, filepath)
|
||||
}
|
||||
|
||||
return s
|
||||
return string(s)
|
||||
}
|
||||
|
||||
func removePartialFile(tmpFilePath string) error {
|
||||
|
@ -161,9 +192,9 @@ func removePartialFile(tmpFilePath string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func DownloadFile(url string, filePath, sha string, fileN, total int, downloadStatus func(string, string, string, float64)) error {
|
||||
url = ConvertURL(url)
|
||||
if LooksLikeOCI(url) {
|
||||
func (uri URI) DownloadFile(filePath, sha string, fileN, total int, downloadStatus func(string, string, string, float64)) error {
|
||||
url := uri.ResolveURL()
|
||||
if uri.LooksLikeOCI() {
|
||||
progressStatus := func(desc ocispec.Descriptor) io.Writer {
|
||||
return &progressWriter{
|
||||
fileName: filePath,
|
||||
|
@ -298,37 +329,6 @@ func DownloadFile(url string, filePath, sha string, fileN, total int, downloadSt
|
|||
return nil
|
||||
}
|
||||
|
||||
// this function check if the string is an URL, if it's an URL downloads the image in memory
|
||||
// encodes it in base64 and returns the base64 string
|
||||
func GetBase64Image(s string) (string, error) {
|
||||
if strings.HasPrefix(s, "http") {
|
||||
// download the image
|
||||
resp, err := http.Get(s)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// read the image data into memory
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// encode the image data in base64
|
||||
encoded := base64.StdEncoding.EncodeToString(data)
|
||||
|
||||
// return the base64 string
|
||||
return encoded, nil
|
||||
}
|
||||
|
||||
// if the string instead is prefixed with "data:image/jpeg;base64,", drop it
|
||||
if strings.HasPrefix(s, "data:image/jpeg;base64,") {
|
||||
return strings.ReplaceAll(s, "data:image/jpeg;base64,", ""), nil
|
||||
}
|
||||
return "", fmt.Errorf("not valid string")
|
||||
}
|
||||
|
||||
func formatBytes(bytes int64) string {
|
||||
const unit = 1024
|
||||
if bytes < unit {
|
||||
|
@ -356,42 +356,3 @@ func calculateSHA(filePath string) (string, error) {
|
|||
|
||||
return fmt.Sprintf("%x", hash.Sum(nil)), nil
|
||||
}
|
||||
|
||||
type HuggingFaceScanResult struct {
|
||||
RepositoryId string `json:"repositoryId"`
|
||||
Revision string `json:"revision"`
|
||||
HasUnsafeFiles bool `json:"hasUnsafeFile"`
|
||||
ClamAVInfectedFiles []string `json:"clamAVInfectedFiles"`
|
||||
DangerousPickles []string `json:"dangerousPickles"`
|
||||
ScansDone bool `json:"scansDone"`
|
||||
}
|
||||
|
||||
var ErrNonHuggingFaceFile = errors.New("not a huggingface repo")
|
||||
var ErrUnsafeFilesFound = errors.New("unsafe files found")
|
||||
|
||||
func HuggingFaceScan(uri string) (*HuggingFaceScanResult, error) {
|
||||
cleanParts := strings.Split(ConvertURL(uri), "/")
|
||||
if len(cleanParts) <= 4 || cleanParts[2] != "huggingface.co" {
|
||||
return nil, ErrNonHuggingFaceFile
|
||||
}
|
||||
results, err := http.Get(fmt.Sprintf("https://huggingface.co/api/models/%s/%s/scan", cleanParts[3], cleanParts[4]))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if results.StatusCode != 200 {
|
||||
return nil, fmt.Errorf("unexpected status code during HuggingFaceScan: %d", results.StatusCode)
|
||||
}
|
||||
scanResult := &HuggingFaceScanResult{}
|
||||
bodyBytes, err := io.ReadAll(results.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = json.Unmarshal(bodyBytes, scanResult)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if scanResult.HasUnsafeFiles {
|
||||
return scanResult, ErrUnsafeFilesFound
|
||||
}
|
||||
return scanResult, nil
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue