mirror of
https://github.com/mudler/LocalAI.git
synced 2025-05-28 06:25:00 +00:00
refactor: consolidate usage of GetURI (#674)
Signed-off-by: mudler <mudler@localai.io>
This commit is contained in:
parent
d18f85df46
commit
78f3c3da48
10 changed files with 110 additions and 137 deletions
|
@ -23,9 +23,7 @@ func InstallModelFromGallery(galleries []Gallery, name string, basePath string,
|
|||
}
|
||||
|
||||
applyModel := func(model *GalleryModel) error {
|
||||
var config Config
|
||||
|
||||
err := model.Get(&config)
|
||||
config, err := GetGalleryConfigFromURL(model.URL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -79,7 +77,7 @@ func AvailableGalleryModels(galleries []Gallery, basePath string) ([]*GalleryMod
|
|||
func getGalleryModels(gallery Gallery, basePath string) ([]*GalleryModel, error) {
|
||||
var models []*GalleryModel = []*GalleryModel{}
|
||||
|
||||
err := utils.GetURI(gallery.URL, func(d []byte) error {
|
||||
err := utils.GetURI(gallery.URL, func(url string, d []byte) error {
|
||||
return yaml.Unmarshal(d, &models)
|
||||
})
|
||||
if err != nil {
|
||||
|
|
|
@ -65,6 +65,17 @@ type PromptTemplate struct {
|
|||
Content string `yaml:"content"`
|
||||
}
|
||||
|
||||
func GetGalleryConfigFromURL(url string) (Config, error) {
|
||||
var config Config
|
||||
err := utils.GetURI(url, func(url string, d []byte) error {
|
||||
return yaml.Unmarshal(d, &config)
|
||||
})
|
||||
if err != nil {
|
||||
return config, err
|
||||
}
|
||||
return config, nil
|
||||
}
|
||||
|
||||
func ReadConfigFile(filePath string) (*Config, error) {
|
||||
// Read the YAML file
|
||||
yamlFile, err := os.ReadFile(filePath)
|
||||
|
|
|
@ -1,14 +1,5 @@
|
|||
package gallery
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||
"gopkg.in/yaml.v2"
|
||||
)
|
||||
|
||||
// GalleryModel is the struct used to represent a model in the gallery returned by the endpoint.
|
||||
// It is used to install the model by resolving the URL and downloading the files.
|
||||
// The other fields are used to override the configuration of the model.
|
||||
|
@ -34,52 +25,3 @@ type GalleryModel struct {
|
|||
const (
|
||||
githubURI = "github:"
|
||||
)
|
||||
|
||||
func (request GalleryModel) DecodeURL() (string, error) {
|
||||
input := request.URL
|
||||
var rawURL string
|
||||
|
||||
if strings.HasPrefix(input, githubURI) {
|
||||
parts := strings.Split(input, ":")
|
||||
repoParts := strings.Split(parts[1], "@")
|
||||
branch := "main"
|
||||
|
||||
if len(repoParts) > 1 {
|
||||
branch = repoParts[1]
|
||||
}
|
||||
|
||||
repoPath := strings.Split(repoParts[0], "/")
|
||||
org := repoPath[0]
|
||||
project := repoPath[1]
|
||||
projectPath := strings.Join(repoPath[2:], "/")
|
||||
|
||||
rawURL = fmt.Sprintf("https://raw.githubusercontent.com/%s/%s/%s/%s", org, project, branch, projectPath)
|
||||
} else if strings.HasPrefix(input, "http://") || strings.HasPrefix(input, "https://") {
|
||||
// Handle regular URLs
|
||||
u, err := url.Parse(input)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid URL: %w", err)
|
||||
}
|
||||
rawURL = u.String()
|
||||
// check if it's a file path
|
||||
} else if strings.HasPrefix(input, "file://") {
|
||||
return input, nil
|
||||
} else {
|
||||
|
||||
return "", fmt.Errorf("invalid URL format: %s", input)
|
||||
}
|
||||
|
||||
return rawURL, nil
|
||||
}
|
||||
|
||||
// Get fetches a model from a URL and unmarshals it into a struct
|
||||
func (request GalleryModel) Get(i interface{}) error {
|
||||
url, err := request.DecodeURL()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return utils.GetURI(url, func(d []byte) error {
|
||||
return yaml.Unmarshal(d, i)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -6,37 +6,13 @@ import (
|
|||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
type example struct {
|
||||
Name string `yaml:"name"`
|
||||
}
|
||||
|
||||
var _ = Describe("Gallery API tests", func() {
|
||||
|
||||
Context("requests", func() {
|
||||
It("parses github with a branch", func() {
|
||||
req := GalleryModel{URL: "github:go-skynet/model-gallery/gpt4all-j.yaml@main"}
|
||||
var e example
|
||||
err := req.Get(&e)
|
||||
e, err := GetGalleryConfigFromURL(req.URL)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(e.Name).To(Equal("gpt4all-j"))
|
||||
})
|
||||
It("parses github without a branch", func() {
|
||||
req := GalleryModel{URL: "github:go-skynet/model-gallery/gpt4all-j.yaml@main"}
|
||||
str, err := req.DecodeURL()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml"))
|
||||
})
|
||||
It("parses github without a branch", func() {
|
||||
req := GalleryModel{URL: "github:go-skynet/model-gallery/gpt4all-j.yaml"}
|
||||
str, err := req.DecodeURL()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml"))
|
||||
})
|
||||
It("parses URLS", func() {
|
||||
req := GalleryModel{URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml"}
|
||||
str, err := req.DecodeURL()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml"))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
|
@ -1,12 +1,34 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func GetURI(url string, f func(i []byte) error) error {
|
||||
const (
|
||||
githubURI = "github:"
|
||||
)
|
||||
|
||||
func GetURI(url string, f func(url string, i []byte) error) error {
|
||||
if strings.HasPrefix(url, githubURI) {
|
||||
parts := strings.Split(url, ":")
|
||||
repoParts := strings.Split(parts[1], "@")
|
||||
branch := "main"
|
||||
|
||||
if len(repoParts) > 1 {
|
||||
branch = repoParts[1]
|
||||
}
|
||||
|
||||
repoPath := strings.Split(repoParts[0], "/")
|
||||
org := repoPath[0]
|
||||
project := repoPath[1]
|
||||
projectPath := strings.Join(repoPath[2:], "/")
|
||||
|
||||
url = fmt.Sprintf("https://raw.githubusercontent.com/%s/%s/%s/%s", org, project, branch, projectPath)
|
||||
}
|
||||
|
||||
if strings.HasPrefix(url, "file://") {
|
||||
rawURL := strings.TrimPrefix(url, "file://")
|
||||
// Read the response body
|
||||
|
@ -16,7 +38,7 @@ func GetURI(url string, f func(i []byte) error) error {
|
|||
}
|
||||
|
||||
// Unmarshal YAML data into a struct
|
||||
return f(body)
|
||||
return f(url, body)
|
||||
}
|
||||
|
||||
// Send a GET request to the URL
|
||||
|
@ -33,5 +55,5 @@ func GetURI(url string, f func(i []byte) error) error {
|
|||
}
|
||||
|
||||
// Unmarshal YAML data into a struct
|
||||
return f(body)
|
||||
return f(url, body)
|
||||
}
|
||||
|
|
36
pkg/utils/uri_test.go
Normal file
36
pkg/utils/uri_test.go
Normal file
|
@ -0,0 +1,36 @@
|
|||
package utils_test
|
||||
|
||||
import (
|
||||
. "github.com/go-skynet/LocalAI/pkg/utils"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Gallery API tests", func() {
|
||||
Context("URI", func() {
|
||||
It("parses github with a branch", func() {
|
||||
Expect(
|
||||
GetURI("github:go-skynet/model-gallery/gpt4all-j.yaml", func(url string, i []byte) error {
|
||||
Expect(url).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml"))
|
||||
return nil
|
||||
}),
|
||||
).ToNot(HaveOccurred())
|
||||
})
|
||||
It("parses github without a branch", func() {
|
||||
Expect(
|
||||
GetURI("github:go-skynet/model-gallery/gpt4all-j.yaml@main", func(url string, i []byte) error {
|
||||
Expect(url).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml"))
|
||||
return nil
|
||||
}),
|
||||
).ToNot(HaveOccurred())
|
||||
})
|
||||
It("parses github with urls", func() {
|
||||
Expect(
|
||||
GetURI("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml", func(url string, i []byte) error {
|
||||
Expect(url).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml"))
|
||||
return nil
|
||||
}),
|
||||
).ToNot(HaveOccurred())
|
||||
})
|
||||
})
|
||||
})
|
13
pkg/utils/utils_suite_test.go
Normal file
13
pkg/utils/utils_suite_test.go
Normal file
|
@ -0,0 +1,13 @@
|
|||
package utils_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestUtils(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "Utils test suite")
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue