diff --git a/.gitignore b/.gitignore index 21f7e298..9b8baad0 100644 --- a/.gitignore +++ b/.gitignore @@ -33,3 +33,4 @@ apiv2/localai.gen.go # Generated during build backend-assets/ +/ggml-metal.metal diff --git a/Makefile b/Makefile index eafb8b80..899117a5 100644 --- a/Makefile +++ b/Makefile @@ -3,14 +3,14 @@ GOTEST=$(GOCMD) test GOVET=$(GOCMD) vet BINARY_NAME=local-ai -GOLLAMA_VERSION?=351aa714672fb09aa84396868d08934e8e477f25 +GOLLAMA_VERSION?=53d9b5735740f37eec8ed10a50268da9442dfe5e GPT4ALL_REPO?=https://github.com/go-skynet/gpt4all GPT4ALL_VERSION?=f7498c9 -GOGGMLTRANSFORMERS_VERSION?=bd765bb6f3b38a63f915f3725e488aad492eedd4 +GOGGMLTRANSFORMERS_VERSION?=01b8436f44294d0e1267430f9eda4460458cec54 RWKV_REPO?=https://github.com/donomii/go-rwkv.cpp RWKV_VERSION?=1e18b2490e7e32f6b00e16f6a9ec0dd3a3d09266 WHISPER_CPP_VERSION?=57543c169e27312e7546d07ed0d8c6eb806ebc36 -BERT_VERSION?=0548994371f7081e45fcf8d472f3941a12f179aa +BERT_VERSION?=6069103f54b9969c02e789d0fb12a23bd614285f BLOOMZ_VERSION?=1834e77b83faafe912ad4092ccf7f77937349e2f OPENAI_OPENAPI_REPO?=https://github.com/openai/openai-openapi.git OPENAI_OPENAPI_VERSION?= @@ -244,6 +244,9 @@ build: prepare ## Build the project $(info ${GREEN}I BUILD_TYPE: ${YELLOW}$(BUILD_TYPE)${RESET}) $(info ${GREEN}I GO_TAGS: ${YELLOW}$(GO_TAGS)${RESET}) CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=${C_INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o $(BINARY_NAME) ./ +ifeq ($(BUILD_TYPE),metal) + cp go-llama/build/bin/ggml-metal.metal . +endif dist: build mkdir -p release diff --git a/api/gallery.go b/api/gallery.go index b5b74b0d..a9a87220 100644 --- a/api/gallery.go +++ b/api/gallery.go @@ -10,10 +10,12 @@ import ( "os" "strings" "sync" + "time" "github.com/go-skynet/LocalAI/pkg/gallery" "github.com/gofiber/fiber/v2" "github.com/google/uuid" + "github.com/rs/zerolog/log" "gopkg.in/yaml.v3" ) @@ -23,9 +25,12 @@ type galleryOp struct { } type galleryOpStatus struct { - Error error `json:"error"` - Processed bool `json:"processed"` - Message string `json:"message"` + Error error `json:"error"` + Processed bool `json:"processed"` + Message string `json:"message"` + Progress float64 `json:"progress"` + TotalFileSize string `json:"file_size"` + DownloadedFileSize string `json:"downloaded_size"` } type galleryApplier struct { @@ -43,7 +48,7 @@ func newGalleryApplier(modelPath string) *galleryApplier { } } -func applyGallery(modelPath string, req ApplyGalleryModelRequest, cm *ConfigMerger) error { +func applyGallery(modelPath string, req ApplyGalleryModelRequest, cm *ConfigMerger, downloadStatus func(string, string, string, float64)) error { url, err := req.DecodeURL() if err != nil { return err @@ -71,7 +76,7 @@ func applyGallery(modelPath string, req ApplyGalleryModelRequest, cm *ConfigMerg config.Files = append(config.Files, req.AdditionalFiles...) - if err := gallery.Apply(modelPath, req.Name, &config, req.Overrides); err != nil { + if err := gallery.Apply(modelPath, req.Name, &config, req.Overrides, downloadStatus); err != nil { return err } @@ -99,23 +104,51 @@ func (g *galleryApplier) start(c context.Context, cm *ConfigMerger) { case <-c.Done(): return case op := <-g.C: - g.updatestatus(op.id, &galleryOpStatus{Message: "processing"}) + g.updatestatus(op.id, &galleryOpStatus{Message: "processing", Progress: 0}) updateError := func(e error) { g.updatestatus(op.id, &galleryOpStatus{Error: e, Processed: true}) } - if err := applyGallery(g.modelPath, op.req, cm); err != nil { + if err := applyGallery(g.modelPath, op.req, cm, func(fileName string, current string, total string, percentage float64) { + g.updatestatus(op.id, &galleryOpStatus{Message: "processing", Progress: percentage, TotalFileSize: total, DownloadedFileSize: current}) + displayDownload(fileName, current, total, percentage) + }); err != nil { updateError(err) continue } - g.updatestatus(op.id, &galleryOpStatus{Processed: true, Message: "completed"}) + g.updatestatus(op.id, &galleryOpStatus{Processed: true, Message: "completed", Progress: 100}) } } }() } +var lastProgress time.Time = time.Now() +var startTime time.Time = time.Now() + +func displayDownload(fileName string, current string, total string, percentage float64) { + currentTime := time.Now() + + if currentTime.Sub(lastProgress) >= 5*time.Second { + + lastProgress = currentTime + + // calculate ETA based on percentage and elapsed time + var eta time.Duration + if percentage > 0 { + elapsed := currentTime.Sub(startTime) + eta = time.Duration(float64(elapsed)*(100/percentage) - float64(elapsed)) + } + + if total != "" { + log.Debug().Msgf("Downloading %s: %s/%s (%.2f%%) ETA: %s", fileName, current, total, percentage, eta) + } else { + log.Debug().Msgf("Downloading: %s", current) + } + } +} + func ApplyGalleryFromFile(modelPath, s string, cm *ConfigMerger) error { dat, err := os.ReadFile(s) if err != nil { @@ -128,13 +161,14 @@ func ApplyGalleryFromFile(modelPath, s string, cm *ConfigMerger) error { } for _, r := range requests { - if err := applyGallery(modelPath, r, cm); err != nil { + if err := applyGallery(modelPath, r, cm, displayDownload); err != nil { return err } } return nil } + func ApplyGalleryFromString(modelPath, s string, cm *ConfigMerger) error { var requests []ApplyGalleryModelRequest err := json.Unmarshal([]byte(s), &requests) @@ -143,7 +177,7 @@ func ApplyGalleryFromString(modelPath, s string, cm *ConfigMerger) error { } for _, r := range requests { - if err := applyGallery(modelPath, r, cm); err != nil { + if err := applyGallery(modelPath, r, cm, displayDownload); err != nil { return err } } diff --git a/examples/README.md b/examples/README.md index 2285ed19..d86cd1f7 100644 --- a/examples/README.md +++ b/examples/README.md @@ -106,6 +106,16 @@ Shows how to integrate with `Langchain` and `Chroma` to enable question answerin [Check it out here](https://github.com/go-skynet/LocalAI/tree/master/examples/langchain-chroma/) +### Telegram bot + +_by [@mudler](https://github.com/mudler) + +![Screenshot from 2023-06-09 00-36-26](https://github.com/go-skynet/LocalAI/assets/2420543/e98b4305-fa2d-41cf-9d2f-1bb2d75ca902) + +Use LocalAI to power a Telegram bot assistant, with Image generation and audio support! + +[Check it out here](https://github.com/go-skynet/LocalAI/tree/master/examples/telegram-bot/) + ### Template for Runpod.io _by [@fHachenberg](https://github.com/fHachenberg)_ diff --git a/examples/telegram-bot/README.md b/examples/telegram-bot/README.md new file mode 100644 index 00000000..d0ab0dfd --- /dev/null +++ b/examples/telegram-bot/README.md @@ -0,0 +1,30 @@ +## Telegram bot + +![Screenshot from 2023-06-09 00-36-26](https://github.com/go-skynet/LocalAI/assets/2420543/e98b4305-fa2d-41cf-9d2f-1bb2d75ca902) + +This example uses a fork of [chatgpt-telegram-bot](https://github.com/karfly/chatgpt_telegram_bot) to deploy a telegram bot with LocalAI instead of OpenAI. + +```bash +# Clone LocalAI +git clone https://github.com/go-skynet/LocalAI + +cd LocalAI/examples/telegram-bot + +git clone https://github.com/mudler/chatgpt_telegram_bot + +cp -rf docker-compose.yml chatgpt_telegram_bot + +cd chatgpt_telegram_bot + +mv config/config.example.yml config/config.yml +mv config/config.example.env config/config.env + +# Edit config/config.yml to set the telegram bot token +vim config/config.yml + +# run the bot +docker-compose --env-file config/config.env up --build +``` + +Note: LocalAI is configured to download `gpt4all-j` in place of `gpt-3.5-turbo` and `stablediffusion` for image generation at the first start. Download size is >6GB, if your network connection is slow, adapt the `docker-compose.yml` file healthcheck section accordingly (replace `20m`, for instance with `1h`, etc.). +To configure models manually, comment the `PRELOAD_MODELS` environment variable in the `docker-compose.yml` file and see for instance the [chatbot-ui-manual example](https://github.com/go-skynet/LocalAI/tree/master/examples/chatbot-ui-manual) `model` directory. \ No newline at end of file diff --git a/examples/telegram-bot/docker-compose.yml b/examples/telegram-bot/docker-compose.yml new file mode 100644 index 00000000..3aea6ebe --- /dev/null +++ b/examples/telegram-bot/docker-compose.yml @@ -0,0 +1,38 @@ +version: "3" + +services: + api: + image: quay.io/go-skynet/local-ai:v1.18.0-ffmpeg + # As initially LocalAI will download the models defined in PRELOAD_MODELS + # you might need to tweak the healthcheck values here according to your network connection. + # Here we give a timespan of 20m to download all the required files. + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:8080/readyz"] + interval: 1m + timeout: 20m + retries: 20 + ports: + - 8080:8080 + environment: + - DEBUG=true + - MODELS_PATH=/models + - IMAGE_PATH=/tmp + # You can preload different models here as well. + # See: https://github.com/go-skynet/model-gallery + - 'PRELOAD_MODELS=[{"url": "github:go-skynet/model-gallery/gpt4all-j.yaml", "name": "gpt-3.5-turbo"}, {"url": "github:go-skynet/model-gallery/stablediffusion.yaml"}, {"url": "github:go-skynet/model-gallery/whisper-base.yaml", "name": "whisper-1"}]' + volumes: + - ./models:/models:cached + command: ["/usr/bin/local-ai" ] + chatgpt_telegram_bot: + container_name: chatgpt_telegram_bot + command: python3 bot/bot.py + restart: always + environment: + - OPENAI_API_KEY=sk---anystringhere + - OPENAI_API_BASE=http://api:8080/v1 + build: + context: "." + dockerfile: Dockerfile + depends_on: + api: + condition: service_healthy diff --git a/go.mod b/go.mod index 66a4012f..d9488013 100644 --- a/go.mod +++ b/go.mod @@ -4,27 +4,26 @@ go 1.19 require ( github.com/deepmap/oapi-codegen v1.12.4 - github.com/donomii/go-rwkv.cpp v0.0.0-20230606181754-d5f48f6d607a + github.com/donomii/go-rwkv.cpp v0.0.0-20230609132458-d2b25a4bb148 github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230606002726-57543c169e27 github.com/go-audio/wav v1.1.0 - github.com/go-chi/chi/v5 v5.0.8 github.com/go-skynet/bloomz.cpp v0.0.0-20230529155654-1834e77b83fa - github.com/go-skynet/go-bert.cpp v0.0.0-20230531070950-0548994371f7 - github.com/go-skynet/go-ggml-transformers.cpp v0.0.0-20230606131358-bd765bb6f3b3 - github.com/go-skynet/go-llama.cpp v0.0.0-20230607123950-351aa714672f + github.com/go-skynet/go-bert.cpp v0.0.0-20230607105116-6069103f54b9 + github.com/go-skynet/go-ggml-transformers.cpp v0.0.0-20230607102637-dabd6cd7b789 + github.com/go-skynet/go-llama.cpp v0.0.0-20230609233637-a12ce511c063 github.com/gofiber/fiber/v2 v2.46.0 github.com/google/uuid v1.3.0 github.com/hashicorp/go-multierror v1.1.1 github.com/imdario/mergo v0.3.16 github.com/mitchellh/mapstructure v1.5.0 github.com/mudler/go-stable-diffusion v0.0.0-20230605122230-d89260f598af - github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230605194130-266f13aee9d8 + github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230610141538-a9c2f473032f github.com/onsi/ginkgo/v2 v2.10.0 github.com/onsi/gomega v1.27.8 github.com/otiai10/openaigo v1.1.0 github.com/rs/zerolog v1.29.1 - github.com/sashabaranov/go-openai v1.10.0 - github.com/tmc/langchaingo v0.0.0-20230605114752-4afed6d7be4a + github.com/sashabaranov/go-openai v1.10.1 + github.com/tmc/langchaingo v0.0.0-20230610024316-06cb7b57ea80 github.com/urfave/cli/v2 v2.25.5 github.com/valyala/fasthttp v1.47.0 github.com/vmware-tanzu/carvel-ytt v0.45.2 diff --git a/go.sum b/go.sum index 36d3169e..1b13cbc3 100644 --- a/go.sum +++ b/go.sum @@ -44,8 +44,6 @@ github.com/go-audio/riff v1.0.0 h1:d8iCGbDvox9BfLagY94fBynxSPHO80LmZCaOsmKxokA= github.com/go-audio/riff v1.0.0/go.mod h1:l3cQwc85y79NQFCRB7TiPoNiaijp6q8Z0Uv38rVG498= github.com/go-audio/wav v1.1.0 h1:jQgLtbqBzY7G+BM8fXF7AHUk1uHUviWS4X39d5rsL2g= github.com/go-audio/wav v1.1.0/go.mod h1:mpe9qfwbScEbkd8uybLuIpTgHyrISw/OTuvjUW2iGtE= -github.com/go-chi/chi/v5 v5.0.8 h1:lD+NLqFcAi1ovnVZpsnObHGW4xb4J8lNmoYVfECH1Y0= -github.com/go-chi/chi/v5 v5.0.8/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ= github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-openapi/jsonpointer v0.19.5 h1:gZr+CIYByUqjcgeLXnQu2gHYQC9o73G2XUeOFYEICuY= @@ -147,8 +145,7 @@ github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 h1:RWengNIwukTxcDr9 github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826/go.mod h1:TaXosZuwdSHYgviHp1DAtfrULt5eUgsSMsZf+YrPgl8= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= -github.com/onsi/ginkgo/v2 v2.9.7 h1:06xGQy5www2oN160RtEZoTvnP2sPhEfePYmCDc2szss= -github.com/onsi/ginkgo/v2 v2.9.7/go.mod h1:cxrmXWykAwTwhQsJOPfdIDiJ+l2RYq7U8hFU+M/1uw0= +github.com/onsi/ginkgo/v2 v2.10.0 h1:sfUl4qgLdvkChZrWCYndY2EAu9BRIw1YphNAzy1VNWs= github.com/onsi/ginkgo/v2 v2.10.0/go.mod h1:UDQOh5wbQUlMnkLfVaIUMtQ1Vus92oM+P2JX1aulgcE= github.com/onsi/gomega v1.27.8 h1:gegWiwZjBsf2DgiSbf5hpokZ98JVDMcWkUiigk6/KXc= github.com/onsi/gomega v1.27.8/go.mod h1:2J8vzI/s+2shY9XHRApDkdgPo1TKT7P2u6fXeJKFnNQ= @@ -174,8 +171,8 @@ github.com/rs/zerolog v1.29.1 h1:cO+d60CHkknCbvzEWxP0S9K6KqyTjrCNUy1LdQLCGPc= github.com/rs/zerolog v1.29.1/go.mod h1:Le6ESbR7hc+DP6Lt1THiV8CQSdkkNrd3R0XbEgp3ZBU= github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= -github.com/sashabaranov/go-openai v1.10.0 h1:uUD3EOKDdGa6geMVbe2Trj9/ckF9sCV5jpQM19f7GM8= -github.com/sashabaranov/go-openai v1.10.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= +github.com/sashabaranov/go-openai v1.10.1 h1:6WyHJaNzF266VaEEuW6R4YW+Ei0wpMnqRYPGK7fhuhQ= +github.com/sashabaranov/go-openai v1.10.1/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= github.com/savsgio/dictpool v0.0.0-20221023140959-7bf2e61cea94 h1:rmMl4fXJhKMNWl+K+r/fq4FbbKI+Ia2m9hYBLm2h4G4= github.com/savsgio/dictpool v0.0.0-20221023140959-7bf2e61cea94/go.mod h1:90zrgN3D/WJsDd1iXHT96alCoN2KJo6/4x1DZC3wZs8= github.com/savsgio/gotils v0.0.0-20220530130905-52f3993e8d6d/go.mod h1:Gy+0tqhJvgGlqnTF8CVGP0AaGRjwBtXs/a5PA0Y3+A4= @@ -201,8 +198,8 @@ github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcU github.com/tinylib/msgp v1.1.6/go.mod h1:75BAfg2hauQhs3qedfdDZmWAPcFMAvJE5b9rGOMufyw= github.com/tinylib/msgp v1.1.8 h1:FCXC1xanKO4I8plpHGH2P7koL/RzZs12l/+r7vakfm0= github.com/tinylib/msgp v1.1.8/go.mod h1:qkpG+2ldGg4xRFmx+jfTvZPxfGFhi64BcnL9vkCm/Tw= -github.com/tmc/langchaingo v0.0.0-20230605114752-4afed6d7be4a h1:YtKJTKbM3qu60+ZxLtyeCl0RvdG7LKbyF8TT7nzV6Gg= -github.com/tmc/langchaingo v0.0.0-20230605114752-4afed6d7be4a/go.mod h1:6l1WoyqVDwkv7cFlY3gfcTv8yVowVyuutKv8PGlQCWI= +github.com/tmc/langchaingo v0.0.0-20230610024316-06cb7b57ea80 h1:Y+a76dNVbdWduw3gznOr2O2OSZkdwDRYPKTDpG/vM9I= +github.com/tmc/langchaingo v0.0.0-20230610024316-06cb7b57ea80/go.mod h1:6l1WoyqVDwkv7cFlY3gfcTv8yVowVyuutKv8PGlQCWI= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/ugorji/go v1.2.7/go.mod h1:nF9osbDWLy6bDVv/Rtoh6QgnvNDpmCalQV5urGCCS6M= diff --git a/pkg/gallery/models.go b/pkg/gallery/models.go index f4f86ae7..14a7d6ac 100644 --- a/pkg/gallery/models.go +++ b/pkg/gallery/models.go @@ -3,10 +3,12 @@ package gallery import ( "crypto/sha256" "fmt" + "hash" "io" "net/http" "os" "path/filepath" + "strconv" "github.com/imdario/mergo" "github.com/rs/zerolog/log" @@ -93,7 +95,7 @@ func verifyPath(path, basePath string) error { return inTrustedRoot(c, basePath) } -func Apply(basePath, nameOverride string, config *Config, configOverrides map[string]interface{}) error { +func Apply(basePath, nameOverride string, config *Config, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64)) error { // Create base path if it doesn't exist err := os.MkdirAll(basePath, 0755) if err != nil { @@ -168,27 +170,25 @@ func Apply(basePath, nameOverride string, config *Config, configOverrides map[st } defer outFile.Close() + progress := &progressWriter{ + fileName: file.Filename, + total: resp.ContentLength, + hash: sha256.New(), + downloadStatus: downloadStatus, + } + _, err = io.Copy(io.MultiWriter(outFile, progress), resp.Body) + if err != nil { + return fmt.Errorf("failed to write file %q: %v", file.Filename, err) + } + if file.SHA256 != "" { - log.Debug().Msgf("Download and verifying %q", file.Filename) - - // Write file content and calculate SHA - hash := sha256.New() - _, err = io.Copy(io.MultiWriter(outFile, hash), resp.Body) - if err != nil { - return fmt.Errorf("failed to write file %q: %v", file.Filename, err) - } - // Verify SHA - calculatedSHA := fmt.Sprintf("%x", hash.Sum(nil)) + calculatedSHA := fmt.Sprintf("%x", progress.hash.Sum(nil)) if calculatedSHA != file.SHA256 { return fmt.Errorf("SHA mismatch for file %q ( calculated: %s != metadata: %s )", file.Filename, calculatedSHA, file.SHA256) } } else { log.Debug().Msgf("SHA missing for %q. Skipping validation", file.Filename) - _, err = io.Copy(outFile, resp.Body) - if err != nil { - return fmt.Errorf("failed to write file %q: %v", file.Filename, err) - } } log.Debug().Msgf("File %q downloaded and verified", file.Filename) @@ -255,6 +255,42 @@ func Apply(basePath, nameOverride string, config *Config, configOverrides map[st return nil } +type progressWriter struct { + fileName string + total int64 + written int64 + downloadStatus func(string, string, string, float64) + hash hash.Hash +} + +func (pw *progressWriter) Write(p []byte) (n int, err error) { + n, err = pw.hash.Write(p) + pw.written += int64(n) + + if pw.total > 0 { + percentage := float64(pw.written) / float64(pw.total) * 100 + //log.Debug().Msgf("Downloading %s: %s/%s (%.2f%%)", pw.fileName, formatBytes(pw.written), formatBytes(pw.total), percentage) + pw.downloadStatus(pw.fileName, formatBytes(pw.written), formatBytes(pw.total), percentage) + } else { + pw.downloadStatus(pw.fileName, formatBytes(pw.written), "", 0) + } + + return +} + +func formatBytes(bytes int64) string { + const unit = 1024 + if bytes < unit { + return strconv.FormatInt(bytes, 10) + " B" + } + div, exp := int64(unit), 0 + for n := bytes / unit; n >= unit; n /= unit { + div *= unit + exp++ + } + return fmt.Sprintf("%.1f %ciB", float64(bytes)/float64(div), "KMGTPE"[exp]) +} + func calculateSHA(filePath string) (string, error) { file, err := os.Open(filePath) if err != nil { diff --git a/pkg/gallery/models_test.go b/pkg/gallery/models_test.go index f0e580e9..343bf6ab 100644 --- a/pkg/gallery/models_test.go +++ b/pkg/gallery/models_test.go @@ -19,7 +19,7 @@ var _ = Describe("Model test", func() { c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml")) Expect(err).ToNot(HaveOccurred()) - err = Apply(tempdir, "", c, map[string]interface{}{}) + err = Apply(tempdir, "", c, map[string]interface{}{}, func(string, string, string, float64) {}) Expect(err).ToNot(HaveOccurred()) for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "cerebras.yaml"} { @@ -45,7 +45,7 @@ var _ = Describe("Model test", func() { c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml")) Expect(err).ToNot(HaveOccurred()) - err = Apply(tempdir, "foo", c, map[string]interface{}{}) + err = Apply(tempdir, "foo", c, map[string]interface{}{}, func(string, string, string, float64) {}) Expect(err).ToNot(HaveOccurred()) for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} { @@ -61,7 +61,7 @@ var _ = Describe("Model test", func() { c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml")) Expect(err).ToNot(HaveOccurred()) - err = Apply(tempdir, "foo", c, map[string]interface{}{"backend": "foo"}) + err = Apply(tempdir, "foo", c, map[string]interface{}{"backend": "foo"}, func(string, string, string, float64) {}) Expect(err).ToNot(HaveOccurred()) for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} { @@ -87,7 +87,7 @@ var _ = Describe("Model test", func() { c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml")) Expect(err).ToNot(HaveOccurred()) - err = Apply(tempdir, "../../../foo", c, map[string]interface{}{}) + err = Apply(tempdir, "../../../foo", c, map[string]interface{}{}, func(string, string, string, float64) {}) Expect(err).To(HaveOccurred()) }) })