feat(oci): support OCI images and Ollama models (#2628)

* Support specifying oci:// and ollama:// for model URLs

Fixes: https://github.com/mudler/LocalAI/issues/2527
Fixes: https://github.com/mudler/LocalAI/issues/1028

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Lower watcher warnings

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Allow installing Ollama models from the CLI

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fixup tests

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Do not keep file ownership

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Skip test on darwin

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto 2024-06-22 08:17:41 +02:00 committed by GitHub
parent e265a618d9
commit f569237a50
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 638 additions and 97 deletions

52
pkg/oci/blob.go Normal file
View file

@ -0,0 +1,52 @@
package oci
import (
"context"
"fmt"
"io"
"os"
ocispec "github.com/opencontainers/image-spec/specs-go/v1"
oras "oras.land/oras-go/v2"
"oras.land/oras-go/v2/registry/remote"
)
// FetchImageBlob downloads a single blob from an OCI registry and writes it to
// a local file.
//
// r is the repository to pull from (for example
// "registry.ollama.ai/library/gemma"), reference identifies the blob to fetch
// (typically a digest such as "sha256:..."), and dst is the path of the file
// to create or truncate. If statusReader is non-nil it is called once with the
// descriptor of the fetched blob, and every byte written to dst is mirrored to
// the io.Writer it returns, which lets callers track download progress.
//
// An error is returned if dst cannot be created, the repository cannot be
// set up, the blob cannot be fetched, or writing/closing the file fails.
func FetchImageBlob(r, reference, dst string, statusReader func(ocispec.Descriptor) io.Writer) (err error) {
	// 0. Create the output file.
	fs, err := os.Create(dst)
	if err != nil {
		return err
	}
	defer func() {
		// A failed Close can mean buffered data never reached disk, so it
		// must not be swallowed; an earlier error still takes precedence.
		if cerr := fs.Close(); cerr != nil && err == nil {
			err = fmt.Errorf("failed to close %q: %w", dst, cerr)
		}
	}()

	// 1. Connect to the remote repository.
	ctx := context.Background()
	repo, err := remote.NewRepository(r)
	if err != nil {
		return fmt.Errorf("failed to create repository: %w", err)
	}
	repo.SkipReferrersGC = true

	// 2. Fetch the blob.
	// https://github.com/oras-project/oras/blob/main/cmd/oras/internal/option/remote.go#L364
	// https://github.com/oras-project/oras/blob/main/cmd/oras/root/blob/fetch.go#L136
	desc, reader, err := oras.Fetch(ctx, repo.Blobs(), reference, oras.DefaultFetchOptions)
	if err != nil {
		return fmt.Errorf("failed to fetch image: %w", err)
	}
	// The fetched content is an io.ReadCloser backed by the network response;
	// it must be closed to release the connection.
	defer reader.Close()

	// 3. Stream the blob into the file, mirroring to the status writer if any.
	var out io.Writer = fs
	if statusReader != nil {
		out = io.MultiWriter(fs, statusReader(desc))
	}
	if _, err := io.Copy(out, reader); err != nil {
		return err
	}
	return nil
}

21
pkg/oci/blob_test.go Normal file
View file

@ -0,0 +1,21 @@
package oci_test
import (
"os"
. "github.com/go-skynet/LocalAI/pkg/oci" // Update with your module path
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
// Spec for FetchImageBlob: downloads a known blob from the Ollama registry
// into a temporary file. Requires network access.
var _ = Describe("OCI", func() {
	Context("pulling images", func() {
		It("should fetch blobs correctly", func() {
			const (
				repository = "registry.ollama.ai/library/gemma"
				blobDigest = "sha256:c1864a5eb19305c40519da12cc543519e48a0697ecd30e15d5ac228644957d12"
			)

			tmp, err := os.CreateTemp("", "ollama")
			Expect(err).NotTo(HaveOccurred())
			// Remove the downloaded blob once the spec finishes.
			defer os.RemoveAll(tmp.Name())

			fetchErr := FetchImageBlob(repository, blobDigest, tmp.Name(), nil)
			Expect(fetchErr).NotTo(HaveOccurred())
		})
	})
})

153
pkg/oci/image.go Normal file
View file

@ -0,0 +1,153 @@
package oci
import (
"context"
"errors"
"fmt"
"io"
"net/http"
"runtime"
"strings"
"syscall"
"time"
"github.com/containerd/containerd/archive"
registrytypes "github.com/docker/docker/api/types/registry"
"github.com/google/go-containerregistry/pkg/authn"
"github.com/google/go-containerregistry/pkg/logs"
"github.com/google/go-containerregistry/pkg/name"
v1 "github.com/google/go-containerregistry/pkg/v1"
"github.com/google/go-containerregistry/pkg/v1/mutate"
"github.com/google/go-containerregistry/pkg/v1/remote"
"github.com/google/go-containerregistry/pkg/v1/remote/transport"
)
// ref: https://github.com/mudler/luet/blob/master/pkg/helpers/docker/docker.go#L117

// staticAuth wraps a fixed Docker registry AuthConfig and exposes it through
// an Authorization method returning a go-containerregistry authn.AuthConfig.
type staticAuth struct {
	auth *registrytypes.AuthConfig
}

// Authorization copies the wrapped credentials into an authn.AuthConfig.
// It returns (nil, nil) when no credentials were supplied.
func (s staticAuth) Authorization() (*authn.AuthConfig, error) {
	if creds := s.auth; creds != nil {
		cfg := authn.AuthConfig{
			Username:      creds.Username,
			Password:      creds.Password,
			Auth:          creds.Auth,
			IdentityToken: creds.IdentityToken,
			RegistryToken: creds.RegistryToken,
		}
		return &cfg, nil
	}
	return nil, nil
}
// defaultRetryBackoff is the backoff configuration handed to the retrying
// transport built in GetImage: a one second base duration, a factor of 3,
// 10% jitter and 3 steps.
var defaultRetryBackoff = remote.Backoff{
	Duration: time.Second,
	Factor:   3.0,
	Jitter:   0.1,
	Steps:    3,
}
// defaultRetryPredicate reports whether a failed request should be retried.
// It returns true (and logs a warning) for unexpected/plain EOF, broken pipe,
// connection reset, and any error whose message contains "connection refused".
// A nil error and every other error yield false.
var defaultRetryPredicate = func(err error) bool {
	switch {
	case err == nil:
		return false
	case errors.Is(err, io.ErrUnexpectedEOF),
		errors.Is(err, io.EOF),
		errors.Is(err, syscall.EPIPE),
		errors.Is(err, syscall.ECONNRESET),
		strings.Contains(err.Error(), "connection refused"):
		logs.Warn.Printf("retrying %v", err)
		return true
	default:
		return false
	}
}
// ExtractOCIImage unpacks the flattened filesystem of img into the directory
// targetDestination.
//
// File ownership recorded in the image is not applied to the extracted files
// (archive.WithNoSameOwner). Any error reported while applying the archive is
// returned.
func ExtractOCIImage(img v1.Image, targetDestination string) error {
	reader := mutate.Extract(img)
	// mutate.Extract hands back an io.ReadCloser; close it so the resources
	// used to produce the stream are released even when Apply stops early.
	defer reader.Close()

	_, err := archive.Apply(context.Background(), targetDestination, reader, archive.WithNoSameOwner())
	return err
}
// ParseImageParts splits a model reference of the form "[repository/]name[:tag]"
// into its components.
//
// The tag defaults to "latest" and the repository to "library" when they are
// absent or empty, so:
//
//	"gemma"           -> tag "latest", repository "library", dstimage "gemma"
//	"gemma:2b"        -> tag "2b",     repository "library", dstimage "gemma"
//	"foobar/gemma:2b" -> tag "2b",     repository "foobar",  dstimage "gemma"
//
// Note the return order: tag first, then repository, then the bare image name.
func ParseImageParts(image string) (tag, repository, dstimage string) {
	tag, repository, dstimage = "latest", "library", image

	// Everything after the first ':' is the tag.
	if base, rawTag, found := strings.Cut(dstimage, ":"); found {
		dstimage = base
		if rawTag != "" {
			tag = rawTag
		}
	}

	// Everything before the first '/' is the repository (namespace).
	if repo, rest, found := strings.Cut(dstimage, "/"); found {
		dstimage = rest
		if repo != "" {
			repository = repo
		}
	}

	return tag, repository, dstimage
}
// GetImage resolves targetImage on its remote registry and returns the image
// for the requested platform.
//
// targetPlatform is parsed with v1.ParsePlatform; when it is empty the
// platform of the running binary (runtime.GOOS/runtime.GOARCH) is used.
// If auth is nil the default keychain is consulted
// (https://github.com/google/go-containerregistry/tree/main/pkg/authn#tldr-for-consumers-of-this-package),
// otherwise the supplied credentials are used. If t is nil,
// http.DefaultTransport is used. In both cases the transport is wrapped in a
// retrying transport configured with defaultRetryBackoff and
// defaultRetryPredicate.
func GetImage(targetImage, targetPlatform string, auth *registrytypes.AuthConfig, t http.RoundTripper) (v1.Image, error) {
	// Fall back to the host platform when the caller did not pick one.
	platformSpec := targetPlatform
	if platformSpec == "" {
		platformSpec = fmt.Sprintf("%s/%s", runtime.GOOS, runtime.GOARCH)
	}
	platform, err := v1.ParsePlatform(platformSpec)
	if err != nil {
		return nil, err
	}

	ref, err := name.ParseReference(targetImage)
	if err != nil {
		return nil, err
	}

	base := t
	if base == nil {
		base = http.DefaultTransport
	}
	retrying := transport.NewRetry(base,
		transport.WithRetryBackoff(defaultRetryBackoff),
		transport.WithRetryPredicate(defaultRetryPredicate),
	)

	opts := []remote.Option{
		remote.WithTransport(retrying),
		remote.WithPlatform(*platform),
	}
	if auth == nil {
		opts = append(opts, remote.WithAuthFromKeychain(authn.DefaultKeychain))
	} else {
		opts = append(opts, remote.WithAuth(staticAuth{auth}))
	}

	return remote.Image(ref, opts...)
}
// GetOCIImageSize returns the total size of targetImage, computed as the sum
// of the sizes reported by each of its layers.
//
// The arguments have the same meaning as for GetImage, which is used to
// resolve the image. An error is returned if the image cannot be resolved,
// its layers cannot be listed, or any layer fails to report its size; in that
// case the returned size is 0 rather than a partial sum.
func GetOCIImageSize(targetImage, targetPlatform string, auth *registrytypes.AuthConfig, t http.RoundTripper) (int64, error) {
	img, err := GetImage(targetImage, targetPlatform, auth, t)
	if err != nil {
		return 0, err
	}

	layers, err := img.Layers()
	if err != nil {
		return 0, fmt.Errorf("listing layers of %q: %w", targetImage, err)
	}

	var size int64
	for _, layer := range layers {
		s, err := layer.Size()
		if err != nil {
			return 0, fmt.Errorf("reading layer size of %q: %w", targetImage, err)
		}
		size += s
	}
	return size, nil
}

37
pkg/oci/image_test.go Normal file
View file

@ -0,0 +1,37 @@
package oci_test
import (
"os"
"runtime"
. "github.com/go-skynet/LocalAI/pkg/oci" // Update with your module path
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
// Spec for GetImage, GetOCIImageSize and ExtractOCIImage: resolves the public
// "alpine" image, checks that its size is non-zero and extracts it into a
// temporary directory. Requires network access.
var _ = Describe("OCI", func() {
	Context("pulling and extracting images", func() {
		It("should resolve, size and extract an image", func() {
			// NOTE(review): the reason for skipping on darwin is not visible
			// here — confirm before re-enabling.
			if runtime.GOOS == "darwin" {
				Skip("Skipping test on darwin")
			}

			imageName := "alpine"
			img, err := GetImage(imageName, "", nil, nil)
			Expect(err).NotTo(HaveOccurred())

			size, err := GetOCIImageSize(imageName, "", nil, nil)
			Expect(err).NotTo(HaveOccurred())
			Expect(size).ToNot(Equal(int64(0)))

			// Extract into a throwaway directory that is removed afterwards.
			dir, err := os.MkdirTemp("", "example")
			Expect(err).NotTo(HaveOccurred())
			defer os.RemoveAll(dir)

			err = ExtractOCIImage(img, dir)
			Expect(err).NotTo(HaveOccurred())
		})
	})
})

13
pkg/oci/oci_suite_test.go Normal file
View file

@ -0,0 +1,13 @@
package oci_test
import (
"testing"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
// TestOCI is the `go test` entry point for this package's Ginkgo specs: it
// registers Ginkgo's Fail as the Gomega failure handler and then runs the
// suite under the name "OCI test suite".
func TestOCI(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "OCI test suite")
}

88
pkg/oci/ollama.go Normal file
View file

@ -0,0 +1,88 @@
package oci
import (
"encoding/json"
"fmt"
"io"
"net/http"
ocispec "github.com/opencontainers/image-spec/specs-go/v1"
)
// Manifest is the top-level JSON document decoded by OllamaModelManifest from
// the registry's manifest endpoint. It carries the schema version and media
// type of the manifest, the config descriptor, and the list of layers.
type Manifest struct {
SchemaVersion int `json:"schemaVersion"`
MediaType string `json:"mediaType"`
Config Config `json:"config"`
Layers []LayerDetail `json:"layers"`
}
// Config describes the "config" section of a Manifest: the digest, media type
// and size of the config object.
type Config struct {
Digest string `json:"digest"`
MediaType string `json:"mediaType"`
Size int `json:"size"`
}
// LayerDetail describes one entry of a Manifest's "layers" array: the digest,
// media type and size of that layer. OllamaModelBlob selects the layer whose
// MediaType is "application/vnd.ollama.image.model".
type LayerDetail struct {
Digest string `json:"digest"`
MediaType string `json:"mediaType"`
Size int `json:"size"`
}
// OllamaModelManifest downloads and decodes the manifest of an Ollama model
// from registry.ollama.ai.
//
// image is a model reference such as "gemma:2b" or "foobar/gemma:2b"; it is
// split with ParseImageParts, which supplies the default tag and repository
// when they are omitted. The manifest is requested from
// https://registry.ollama.ai/v2/<repository>/<image>/manifests/<tag>.
//
// An error is returned if the request cannot be built or sent, if the
// registry answers with a status other than 200 OK, or if the body is not
// valid JSON for a Manifest.
func OllamaModelManifest(image string) (*Manifest, error) {
	tag, repository, image := ParseImageParts(image)

	// e.g. https://registry.ollama.ai/v2/library/llama3/manifests/latest
	manifestURL := "https://registry.ollama.ai/v2/" + repository + "/" + image + "/manifests/" + tag
	req, err := http.NewRequest(http.MethodGet, manifestURL, nil)
	if err != nil {
		return nil, err
	}
	req.Header.Set("Accept", "application/vnd.docker.distribution.manifest.v2+json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, err
	}
	// Always release the connection, whatever happens below.
	defer resp.Body.Close()

	// An error response would otherwise decode into an empty manifest and
	// surface later as a confusing "no layers" failure.
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("fetching manifest %s: unexpected status %s", manifestURL, resp.Status)
	}

	var manifest Manifest
	if err := json.NewDecoder(resp.Body).Decode(&manifest); err != nil {
		return nil, err
	}
	return &manifest, nil
}
// OllamaModelBlob returns the digest of the model weights layer of an Ollama
// model, i.e. the first manifest layer whose media type is
// "application/vnd.ollama.image.model".
//
// image is a model reference as accepted by OllamaModelManifest. An error is
// returned if the manifest cannot be retrieved or if it contains no model
// layer, so callers never receive an empty digest alongside a nil error.
func OllamaModelBlob(image string) (string, error) {
	const modelMediaType = "application/vnd.ollama.image.model"

	manifest, err := OllamaModelManifest(image)
	if err != nil {
		return "", err
	}

	for _, layer := range manifest.Layers {
		if layer.MediaType == modelMediaType {
			return layer.Digest, nil
		}
	}
	return "", fmt.Errorf("no %s layer found in manifest for %q", modelMediaType, image)
}
// OllamaFetchModel downloads the model blob of an Ollama model into the file
// at output.
//
// image is a model reference such as "gemma:2b". The blob digest is looked up
// with OllamaModelBlob and then fetched from
// registry.ollama.ai/<repository>/<image> via FetchImageBlob. statusWriter is
// passed through to FetchImageBlob unchanged and may be nil.
func OllamaFetchModel(image string, output string, statusWriter func(ocispec.Descriptor) io.Writer) error {
	_, namespace, model := ParseImageParts(image)

	digest, err := OllamaModelBlob(image)
	if err != nil {
		return err
	}

	source := fmt.Sprintf("registry.ollama.ai/%s/%s", namespace, model)
	return FetchImageBlob(source, digest, output, statusWriter)
}

21
pkg/oci/ollama_test.go Normal file
View file

@ -0,0 +1,21 @@
package oci_test
import (
"os"
. "github.com/go-skynet/LocalAI/pkg/oci" // Update with your module path
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
// Spec for OllamaFetchModel: pulls the "gemma:2b" model blob into a temporary
// file. Requires network access.
var _ = Describe("OCI", func() {
	Context("ollama", func() {
		It("pulls model files", func() {
			const model = "gemma:2b"

			tmp, err := os.CreateTemp("", "ollama")
			Expect(err).NotTo(HaveOccurred())
			// Remove the downloaded model once the spec finishes.
			defer os.RemoveAll(tmp.Name())

			pullErr := OllamaFetchModel(model, tmp.Name(), nil)
			Expect(pullErr).NotTo(HaveOccurred())
		})
	})
})