diff --git a/.env b/.env
index b05dac68..73e3174d 100644
--- a/.env
+++ b/.env
@@ -1,5 +1,30 @@
+## Set number of threads.
+## Note: prefer the number of physical cores. Overbooking the CPU degrades performance notably.
# THREADS=14
+
+## Specify a different bind address (defaults to ":8080")
+# ADDRESS=127.0.0.1:8080
+
+## Default model context size
# CONTEXT_SIZE=512
+
+## Default path for models
MODELS_PATH=/models
+
+## Enable debug mode
# DEBUG=true
-# BUILD_TYPE=generic
+
+## Specify a build type. Available: cublas, openblas.
+# BUILD_TYPE=openblas
+
+## Uncomment and set to false to disable rebuilding from source
+# REBUILD=false
+
+## Enable image generation with stablediffusion (requires REBUILD=true)
+# GO_TAGS=stablediffusion
+
+## Path where to store generated images
+# IMAGE_PATH=/tmp
+
+## Specify a default upload limit in MB (whisper)
+# UPLOAD_LIMIT
\ No newline at end of file
diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md
new file mode 100644
index 00000000..a7f77221
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/bug_report.md
@@ -0,0 +1,31 @@
+---
+name: Bug report
+about: Create a report to help us improve
+title: ''
+labels: bug
+assignees: mudler
+
+---
+
+
+
+**LocalAI version:**
+
+
+**Environment, CPU architecture, OS, and Version:**
+
+
+**Describe the bug**
+
+
+**To Reproduce**
+
+
+**Expected behavior**
+
+
+**Logs**
+
+
+**Additional context**
+
diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml
new file mode 100644
index 00000000..acc65c80
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/config.yml
@@ -0,0 +1,8 @@
+blank_issues_enabled: false
+contact_links:
+ - name: Community Support
+ url: https://github.com/go-skynet/LocalAI/discussions
+ about: Please ask and answer questions here.
+ - name: Discord
+ url: https://discord.gg/uJAeKSAGDy
+ about: Join our community on Discord!
diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md
new file mode 100644
index 00000000..c184aae9
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/feature_request.md
@@ -0,0 +1,22 @@
+---
+name: Feature request
+about: Suggest an idea for this project
+title: ''
+labels: enhancement
+assignees: mudler
+
+---
+
+
+
+**Is your feature request related to a problem? Please describe.**
+
+
+**Describe the solution you'd like**
+
+
+**Describe alternatives you've considered**
+
+
+**Additional context**
+
diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md
new file mode 100644
index 00000000..2318ad47
--- /dev/null
+++ b/.github/PULL_REQUEST_TEMPLATE.md
@@ -0,0 +1,23 @@
+**Description**
+
+This PR fixes #
+
+**Notes for Reviewers**
+
+
+**[Signed commits](../CONTRIBUTING.md#signing-off-on-commits-developer-certificate-of-origin)**
+- [ ] Yes, I signed my commits.
+
+
+
\ No newline at end of file
diff --git a/.github/stale.yml b/.github/stale.yml
new file mode 100644
index 00000000..af48bade
--- /dev/null
+++ b/.github/stale.yml
@@ -0,0 +1,18 @@
+# Number of days of inactivity before an issue becomes stale
+daysUntilStale: 45
+# Number of days of inactivity before a stale issue is closed
+daysUntilClose: 10
+# Issues with these labels will never be considered stale
+exemptLabels:
+ - issue/willfix
+# Label to use when marking an issue as stale
+staleLabel: issue/stale
+# Comment to post when marking an issue as stale. Set to `false` to disable
+markComment: >
+ This issue has been automatically marked as stale because it has not had
+ recent activity. It will be closed if no further activity occurs. Thank you
+ for your contributions.
+# Comment to post when closing a stale issue. Set to `false` to disable
+closeComment: >
+ This issue is being automatically closed due to inactivity.
+ However, you may choose to reopen this issue.
\ No newline at end of file
diff --git a/.github/workflows/image.yml b/.github/workflows/image.yml
index d83f58d3..eeada322 100644
--- a/.github/workflows/image.yml
+++ b/.github/workflows/image.yml
@@ -9,6 +9,10 @@ on:
tags:
- '*'
+concurrency:
+ group: ci-${{ github.head_ref || github.ref }}-${{ github.repository }}
+ cancel-in-progress: true
+
jobs:
docker:
runs-on: ubuntu-latest
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 17e3c809..48fe2cf8 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -9,6 +9,10 @@ on:
tags:
- '*'
+concurrency:
+ group: ci-tests-${{ github.head_ref || github.ref }}-${{ github.repository }}
+ cancel-in-progress: true
+
jobs:
ubuntu-latest:
runs-on: ubuntu-latest
diff --git a/.vscode/launch.json b/.vscode/launch.json
index e8d94825..cf4fb924 100644
--- a/.vscode/launch.json
+++ b/.vscode/launch.json
@@ -2,7 +2,20 @@
"version": "0.2.0",
"configurations": [
{
- "name": "Launch Go",
+ "name": "Python: Current File",
+ "type": "python",
+ "request": "launch",
+ "program": "${file}",
+ "console": "integratedTerminal",
+ "justMyCode": false,
+ "cwd": "${workspaceFolder}/examples/langchain-chroma",
+ "env": {
+ "OPENAI_API_BASE": "http://localhost:8080/v1",
+ "OPENAI_API_KEY": "abc"
+ }
+ },
+ {
+ "name": "Launch LocalAI API",
"type": "go",
"request": "launch",
"mode": "debug",
@@ -11,8 +24,8 @@
"api"
],
"env": {
- "C_INCLUDE_PATH": "/workspace/go-llama:/workspace/go-gpt4all-j:/workspace/go-gpt2",
- "LIBRARY_PATH": "/workspace/go-llama:/workspace/go-gpt4all-j:/workspace/go-gpt2",
+ "C_INCLUDE_PATH": "${workspaceFolder}/go-llama:${workspaceFolder}/go-stable-diffusion/:${workspaceFolder}/gpt4all/gpt4all-bindings/golang/:${workspaceFolder}/go-gpt2:${workspaceFolder}/go-rwkv:${workspaceFolder}/whisper.cpp:${workspaceFolder}/go-bert:${workspaceFolder}/bloomz",
+        "LIBRARY_PATH": "${workspaceFolder}/go-llama:${workspaceFolder}/go-stable-diffusion/:${workspaceFolder}/gpt4all/gpt4all-bindings/golang/:${workspaceFolder}/go-gpt2:${workspaceFolder}/go-rwkv:${workspaceFolder}/whisper.cpp:${workspaceFolder}/go-bert:${workspaceFolder}/bloomz",
"DEBUG": "true"
}
}
diff --git a/Dockerfile b/Dockerfile
index 52869bb9..27ab3800 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,10 +1,11 @@
ARG GO_VERSION=1.20
ARG BUILD_TYPE=
FROM golang:$GO_VERSION
+ENV REBUILD=true
WORKDIR /build
RUN apt-get update && apt-get install -y cmake libgomp1 libopenblas-dev libopenblas-base libopencv-dev libopencv-core-dev libopencv-core4.5
COPY . .
RUN ln -s /usr/include/opencv4/opencv2/ /usr/include/opencv2
-RUN make prepare-sources
+RUN make build
EXPOSE 8080
ENTRYPOINT [ "/build/entrypoint.sh" ]
diff --git a/Makefile b/Makefile
index 523f1523..393e7ce3 100644
--- a/Makefile
+++ b/Makefile
@@ -3,20 +3,19 @@ GOTEST=$(GOCMD) test
GOVET=$(GOCMD) vet
BINARY_NAME=local-ai
-GOLLAMA_VERSION?=b7bbefbe0b84262e003387a605842bdd0d099300
+GOLLAMA_VERSION?=ccf23adfb278c0165d388389a5d60f3fe38e4854
GPT4ALL_REPO?=https://github.com/nomic-ai/gpt4all
-GPT4ALL_VERSION?=bce2b3025b360af73091da0128b1e91f9bc94f9f
+GPT4ALL_VERSION?=914519e772fd78c15691dcd0b8bac60d6af514ec
GOGPT2_VERSION?=7bff56f0224502c1c9ed6258d2a17e8084628827
RWKV_REPO?=https://github.com/donomii/go-rwkv.cpp
RWKV_VERSION?=07166da10cb2a9e8854395a4f210464dcea76e47
-WHISPER_CPP_VERSION?=95b02d76b04d18e4ce37ed8353a1f0797f1717ea
+WHISPER_CPP_VERSION?=041be06d5881d3c759cc4ed45d655804361237cd
BERT_VERSION?=cea1ed76a7f48ef386a8e369f6c82c48cdf2d551
BLOOMZ_VERSION?=e9366e82abdfe70565644fbfae9651976714efd1
BUILD_TYPE?=
CGO_LDFLAGS?=
CUDA_LIBPATH?=/usr/local/cuda/lib64/
STABLEDIFFUSION_VERSION?=c0748eca3642d58bcf9521108bcee46959c647dc
-
GO_TAGS?=
OPTIONAL_TARGETS?=
@@ -36,9 +35,9 @@ endif
ifeq ($(BUILD_TYPE),cublas)
CGO_LDFLAGS+=-lcublas -lcudart -L$(CUDA_LIBPATH)
+ export LLAMA_CUBLAS=1
endif
-
ifeq ($(GO_TAGS),stablediffusion)
OPTIONAL_TARGETS+=go-stable-diffusion/libstablediffusion.a
endif
@@ -66,6 +65,7 @@ gpt4all:
@find ./gpt4all -type f -name "*.cpp" -exec sed -i'' -e 's/json_/json_gptj_/g' {} +
@find ./gpt4all -type f -name "*.cpp" -exec sed -i'' -e 's/void replace/void json_gptj_replace/g' {} +
@find ./gpt4all -type f -name "*.cpp" -exec sed -i'' -e 's/::replace/::json_gptj_replace/g' {} +
+ @find ./gpt4all -type f -name "*.cpp" -exec sed -i'' -e 's/regex_escape/gpt4allregex_escape/g' {} +
mv ./gpt4all/gpt4all-backend/llama.cpp/llama_util.h ./gpt4all/gpt4all-backend/llama.cpp/gptjllama_util.h
## BERT embeddings
@@ -211,11 +211,11 @@ test-models/testmodel:
wget https://cdn.openai.com/whisper/draft-20220913a/micro-machines.wav -O test-dir/audio.wav
wget https://huggingface.co/imxcstar/rwkv-4-raven-ggml/resolve/main/RWKV-4-Raven-1B5-v11-Eng99%25-Other1%25-20230425-ctx4096-16_Q4_2.bin -O test-models/rwkv
wget https://raw.githubusercontent.com/saharNooby/rwkv.cpp/5eb8f09c146ea8124633ab041d9ea0b1f1db4459/rwkv/20B_tokenizer.json -O test-models/rwkv.tokenizer.json
- cp tests/fixtures/* test-models
+ cp tests/models_fixtures/* test-models
test: prepare test-models/testmodel
- cp tests/fixtures/* test-models
- @C_INCLUDE_PATH=${C_INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} TEST_DIR=$(abspath ./)/test-dir/ CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo -v -r ./api
+ cp tests/models_fixtures/* test-models
+ C_INCLUDE_PATH=${C_INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo -v -r ./api ./pkg
## Help:
help: ## Show this help.
diff --git a/README.md b/README.md
index f07a0796..47996dcd 100644
--- a/README.md
+++ b/README.md
@@ -9,31 +9,38 @@
[](https://discord.gg/uJAeKSAGDy)
-**LocalAI** is a drop-in replacement REST API compatible with OpenAI API specifications for local inferencing. It allows to run models locally or on-prem with consumer grade hardware, supporting multiple models families compatible with the `ggml` format. For a list of the supported model families, see [the model compatibility table below](https://github.com/go-skynet/LocalAI#model-compatibility-table).
+**LocalAI** is a drop-in replacement REST API that's compatible with OpenAI API specifications for local inferencing. It allows you to run models locally or on-prem with consumer-grade hardware, supporting multiple model families that are compatible with the ggml format.
-- OpenAI drop-in alternative REST API
+For a list of the supported model families, please see [the model compatibility table below](https://github.com/go-skynet/LocalAI#model-compatibility-table).
+
+In a nutshell:
+
+- Local, OpenAI drop-in alternative REST API. You own your data.
+- NO GPU required. NO Internet access is required either. Optional GPU acceleration is available for `llama.cpp`-compatible LLMs. [See building instructions](https://github.com/go-skynet/LocalAI#cublas).
- Supports multiple models, Audio transcription, Text generation with GPTs, Image generation with stable diffusion (experimental)
- Once loaded the first time, it keeps models loaded in memory for faster inference
-- Support for prompt templates
- Doesn't shell-out, but uses C++ bindings for faster inference and better performance.
LocalAI is a community-driven project, focused on making AI accessible to anyone. Contributions, feedback and PRs are welcome! It was initially created by [mudler](https://github.com/mudler/) at the [SpectroCloud OSS Office](https://github.com/spectrocloud).
-LocalAI uses C++ bindings for optimizing speed. It is based on [llama.cpp](https://github.com/ggerganov/llama.cpp), [gpt4all](https://github.com/nomic-ai/gpt4all), [rwkv.cpp](https://github.com/saharNooby/rwkv.cpp), [ggml](https://github.com/ggerganov/ggml), [whisper.cpp](https://github.com/ggerganov/whisper.cpp) for audio transcriptions, and [bert.cpp](https://github.com/skeskinen/bert.cpp) for embedding.
-
-See [examples on how to integrate LocalAI](https://github.com/go-skynet/LocalAI/tree/master/examples/).
-
+See the [usage](https://github.com/go-skynet/LocalAI#usage) and [examples](https://github.com/go-skynet/LocalAI/tree/master/examples/) sections to learn how to use LocalAI.
### How does it work?
-
+
+LocalAI is an API written in Go that serves as an OpenAI shim, enabling software already developed with OpenAI SDKs to seamlessly integrate with LocalAI. It can be effortlessly implemented as a substitute, even on consumer-grade hardware. This capability is achieved by employing various C++ backends, including [ggml](https://github.com/ggerganov/ggml), to perform inference on LLMs using both CPU and, if desired, GPU.
+
+LocalAI uses C++ bindings for optimizing speed. It is based on [llama.cpp](https://github.com/ggerganov/llama.cpp), [gpt4all](https://github.com/nomic-ai/gpt4all), [rwkv.cpp](https://github.com/saharNooby/rwkv.cpp), [ggml](https://github.com/ggerganov/ggml), [whisper.cpp](https://github.com/ggerganov/whisper.cpp) for audio transcriptions, [bert.cpp](https://github.com/skeskinen/bert.cpp) for embeddings and [StableDiffusion-NCNN](https://github.com/EdVince/Stable-Diffusion-NCNN) for image generation. See [the model compatibility table](https://github.com/go-skynet/LocalAI#model-compatibility-table) to learn about all the components of LocalAI.
+

## News
+- 21-05-2023: __v1.14.0__ released. Minor updates to the `/models/apply` endpoint; the `llama.cpp` backend was updated, including https://github.com/ggerganov/llama.cpp/pull/1508, which breaks compatibility with older models. `gpt4all` is still compatible with the old format.
+- 19-05-2023: __v1.13.0__ released! 🔥🔥 Updates to the `gpt4all` and `llama` backends, consolidated CUDA support (https://github.com/go-skynet/LocalAI/pull/310, thanks to @bubthegreat and @Thireus), and preliminary support for [installing models via API](https://github.com/go-skynet/LocalAI#advanced-prepare-models-using-the-api).
- 17-05-2023: __v1.12.0__ released! 🔥🔥 Minor fixes, plus CUDA (https://github.com/go-skynet/LocalAI/pull/258) support for `llama.cpp`-compatible models and image generation (https://github.com/go-skynet/LocalAI/pull/272).
- 16-05-2023: 🔥🔥🔥 Experimental support for CUDA (https://github.com/go-skynet/LocalAI/pull/258) in the `llama.cpp` backend and Stable diffusion CPU image generation (https://github.com/go-skynet/LocalAI/pull/272) in `master`.
@@ -116,31 +123,31 @@ Depending on the model you are attempting to run might need more RAM or CPU reso
-| Backend | Compatible models | Completion/Chat endpoint | Audio transcription/Image | Embeddings support | Token stream support | Github | Bindings |
-|-----------------|-----------------------|--------------------------|---------------------|-----------------------------------|----------------------|--------------------------------------------|-------------------------------------------|
-| llama | Vicuna, Alpaca, LLaMa | yes | no | yes (doesn't seem to be accurate) | yes | https://github.com/ggerganov/llama.cpp | https://github.com/go-skynet/go-llama.cpp |
-| gpt4all-llama | Vicuna, Alpaca, LLaMa | yes | no | no | yes | https://github.com/nomic-ai/gpt4all | https://github.com/go-skynet/gpt4all |
-| gpt4all-mpt | MPT | yes | no | no | yes | https://github.com/nomic-ai/gpt4all | https://github.com/go-skynet/gpt4all |
-| gpt4all-j | GPT4ALL-J | yes | no | no | yes | https://github.com/nomic-ai/gpt4all | https://github.com/go-skynet/gpt4all |
-| gpt2 | GPT/NeoX, Cerebras | yes | no | no | no | https://github.com/ggerganov/ggml | https://github.com/go-skynet/go-gpt2.cpp |
-| dolly | Dolly | yes | no | no | no | https://github.com/ggerganov/ggml | https://github.com/go-skynet/go-gpt2.cpp |
-| redpajama | RedPajama | yes | no | no | no | https://github.com/ggerganov/ggml | https://github.com/go-skynet/go-gpt2.cpp |
-| stableLM | StableLM GPT/NeoX | yes | no | no | no | https://github.com/ggerganov/ggml | https://github.com/go-skynet/go-gpt2.cpp |
-| replit | Replit | yes | no | no | no | https://github.com/ggerganov/ggml | https://github.com/go-skynet/go-gpt2.cpp |
-| gptneox | GPT NeoX | yes | no | no | no | https://github.com/ggerganov/ggml | https://github.com/go-skynet/go-gpt2.cpp |
-| starcoder | Starcoder | yes | no | no | no | https://github.com/ggerganov/ggml | https://github.com/go-skynet/go-gpt2.cpp |
-| bloomz | Bloom | yes | no | no | no | https://github.com/NouamaneTazi/bloomz.cpp | https://github.com/go-skynet/bloomz.cpp |
-| rwkv | RWKV | yes | no | no | yes | https://github.com/saharNooby/rwkv.cpp | https://github.com/donomii/go-rwkv.cpp |
-| bert-embeddings | bert | no | no | yes | no | https://github.com/skeskinen/bert.cpp | https://github.com/go-skynet/go-bert.cpp |
-| whisper | whisper | no | Audio | no | no | https://github.com/ggerganov/whisper.cpp | https://github.com/ggerganov/whisper.cpp |
-| stablediffusion | stablediffusion | no | Image | no | no | https://github.com/EdVince/Stable-Diffusion-NCNN | https://github.com/mudler/go-stable-diffusion |
+| Backend and Bindings | Compatible models | Completion/Chat endpoint | Audio transcription/Image | Embeddings support | Token stream support |
+|----------------------------------------------------------------------------------|-----------------------|--------------------------|---------------------------|-----------------------------------|----------------------|
+| [llama](https://github.com/ggerganov/llama.cpp) ([binding](https://github.com/go-skynet/go-llama.cpp)) | Vicuna, Alpaca, LLaMa | yes | no | yes (doesn't seem to be accurate) | yes |
+| [gpt4all-llama](https://github.com/nomic-ai/gpt4all) | Vicuna, Alpaca, LLaMa | yes | no | no | yes |
+| [gpt4all-mpt](https://github.com/nomic-ai/gpt4all) | MPT | yes | no | no | yes |
+| [gpt4all-j](https://github.com/nomic-ai/gpt4all) | GPT4ALL-J | yes | no | no | yes |
+| [gpt2](https://github.com/ggerganov/ggml) ([binding](https://github.com/go-skynet/go-gpt2.cpp)) | GPT/NeoX, Cerebras | yes | no | no | no |
+| [dolly](https://github.com/ggerganov/ggml) ([binding](https://github.com/go-skynet/go-gpt2.cpp)) | Dolly | yes | no | no | no |
+| [redpajama](https://github.com/ggerganov/ggml) ([binding](https://github.com/go-skynet/go-gpt2.cpp)) | RedPajama | yes | no | no | no |
+| [stableLM](https://github.com/ggerganov/ggml) ([binding](https://github.com/go-skynet/go-gpt2.cpp)) | StableLM GPT/NeoX | yes | no | no | no |
+| [replit](https://github.com/ggerganov/ggml) ([binding](https://github.com/go-skynet/go-gpt2.cpp)) | Replit | yes | no | no | no |
+| [gptneox](https://github.com/ggerganov/ggml) ([binding](https://github.com/go-skynet/go-gpt2.cpp)) | GPT NeoX | yes | no | no | no |
+| [starcoder](https://github.com/ggerganov/ggml) ([binding](https://github.com/go-skynet/go-gpt2.cpp)) | Starcoder | yes | no | no | no |
+| [bloomz](https://github.com/NouamaneTazi/bloomz.cpp) ([binding](https://github.com/go-skynet/bloomz.cpp)) | Bloom | yes | no | no | no |
+| [rwkv](https://github.com/saharNooby/rwkv.cpp) ([binding](https://github.com/donomii/go-rwkv.cpp))                                      | RWKV                  | yes                      | no                        | no                                | yes                  |
+| [bert](https://github.com/skeskinen/bert.cpp) ([binding](https://github.com/go-skynet/go-bert.cpp))                                     | bert                  | no                       | no                        | yes                               | no                   |
+| [whisper](https://github.com/ggerganov/whisper.cpp) | whisper | no | Audio | no | no |
+| [stablediffusion](https://github.com/EdVince/Stable-Diffusion-NCNN) ([binding](https://github.com/mudler/go-stable-diffusion)) | stablediffusion | no | Image | no | no |
## Usage
> `LocalAI` comes by default as a container image. You can check out all the available images with corresponding tags [here](https://quay.io/repository/go-skynet/local-ai?tab=tags&tag=latest).
-The easiest way to run LocalAI is by using `docker-compose`:
+The easiest way to run LocalAI is by using `docker-compose` (to build locally, see [building LocalAI](https://github.com/go-skynet/LocalAI/tree/master#setup)):
```bash
@@ -213,7 +220,25 @@ curl http://localhost:8080/v1/chat/completions -H "Content-Type: application/jso
```
-To build locally, run `make build` (see below).
+### Advanced: prepare models using the API
+
+Instead of installing models manually, you can use the LocalAI API endpoints and a model definition to install models programmatically at runtime.
+
+
+
+A curated collection of model files is in the [model-gallery](https://github.com/go-skynet/model-gallery) (work in progress!).
+
+To install, for example, `gpt4all-j`, you can send a POST request to the `/models/apply` endpoint with the model definition URL (`url`) and the name the model should have in LocalAI (`name`, optional):
+
+```bash
+curl http://localhost:8080/models/apply -H "Content-Type: application/json" -d '{
+ "url": "https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml",
+ "name": "gpt4all-j"
+ }'
+```
+
+
+
### Other examples
@@ -434,7 +459,7 @@ local-ai --models-path [--address ] [--threads
@@ -463,6 +488,8 @@ You should see:
└───────────────────────────────────────────────────┘
```
+Note: the binary inside the image is rebuilt at the start of the container to enable CPU optimizations for the execution environment. You can set the environment variable `REBUILD` to `false` to prevent this behavior.
+
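+For example, to start the prebuilt image without rebuilding (a minimal sketch; adjust the image tag and volume paths to your setup):
+
+```bash
+docker run -p 8080:8080 -e REBUILD=false -e MODELS_PATH=/models -v $PWD/models:/models quay.io/go-skynet/local-ai:latest
+```
+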
### Build locally
@@ -567,6 +594,8 @@ Note: CuBLAS support is experimental, and has not been tested on real HW. please
make BUILD_TYPE=cublas build
```
+More information is available in the upstream PR: https://github.com/ggerganov/llama.cpp/pull/1412
+
### Windows compatibility
@@ -818,6 +847,109 @@ models
+## LocalAI API endpoints
+
+Besides the OpenAI endpoints, there are additional LocalAI-only API endpoints.
+
+### Applying a model - `/models/apply`
+
+This endpoint can be used to install a model at runtime.
+
+
+
+LocalAI will create a batch process that downloads the required files from a model definition and automatically reloads itself to include the new model.
+
+Input: `url`, `name` (optional), `files` (optional)
+
+```bash
+curl http://localhost:8080/models/apply -H "Content-Type: application/json" -d '{
+     "url": "",
+     "name": "",
+     "overrides": { "backend": "...", "f16": true },
+     "files": [
+        {
+           "uri": "",
+           "sha256": "",
+           "filename": ""
+        }
+     ]
+   }'
+```
+
+An optional list of additional files to download can be specified within `files`. The `name` field allows you to override the model name. Finally, it is possible to override fields of the model config file with `overrides`.
+
+Returns a `uuid` and a `url` that can be used to follow the state of the process:
+
+```json
+{ "uuid":"251475c9-f666-11ed-95e0-9a8a4480ac58", "status":"http://localhost:8080/models/jobs/251475c9-f666-11ed-95e0-9a8a4480ac58"}
+```
+
+For a collection of curated model definition files, see the [model-gallery](https://github.com/go-skynet/model-gallery).
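+
+The `url` field also accepts a `github:` shorthand of the form `github:<org>/<repo>/<path/to/file.yaml>[@<branch>]` (the branch defaults to `main`), as handled by `DecodeURL` in `api/gallery.go`. For example, the following is equivalent to the raw URL used above:
+
+```bash
+curl http://localhost:8080/models/apply -H "Content-Type: application/json" -d '{
+     "url": "github:go-skynet/model-gallery/gpt4all-j.yaml"
+   }'
+```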
+
+
+
+### Inquiring about a model job state - `/models/jobs/:uuid`
+
+This endpoint returns the state of the batch job associated with a model.
+
+
+This endpoint can be used with the uuid returned by `/models/apply` to check a job state:
+
+```bash
+curl http://localhost:8080/models/jobs/251475c9-f666-11ed-95e0-9a8a4480ac58
+```
+
+Returns a JSON object containing the job error (if any), whether it has been processed, and a status message:
+
+```json
+{"error":null,"processed":true,"message":"completed"}
+```
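+
+To poll until the job completes (a sketch; assumes `jq` is installed and `$UUID` holds the uuid returned by `/models/apply`):
+
+```bash
+while [ "$(curl -s http://localhost:8080/models/jobs/$UUID | jq -r .processed)" != "true" ]; do
+  sleep 1
+done
+```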
+
+
+
+## Clients
+
+OpenAI clients are already compatible with LocalAI by overriding the base path or the target URL.
+
+### Javascript
+
+
+
+https://github.com/openai/openai-node/
+
+```javascript
+import { Configuration, OpenAIApi } from 'openai';
+
+const configuration = new Configuration({
+ basePath: `http://localhost:8080/v1`
+});
+const openai = new OpenAIApi(configuration);
+```
+
+
+
+### Python
+
+
+
+https://github.com/openai/openai-python
+
+Set the `OPENAI_API_BASE` environment variable, or configure the base URL in code:
+
+```python
+import openai
+
+openai.api_base = "http://localhost:8080/v1"
+
+# create a chat completion
+chat_completion = openai.ChatCompletion.create(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hello world"}])
+
+# print the completion
+print(chat_completion.choices[0].message.content)
+```
+
+
+
## Frequently asked questions
Here are answers to some of the most common questions.
diff --git a/api/api.go b/api/api.go
index ecf56b09..b81a89f5 100644
--- a/api/api.go
+++ b/api/api.go
@@ -1,6 +1,7 @@
package api
import (
+ "context"
"errors"
model "github.com/go-skynet/LocalAI/pkg/model"
@@ -12,7 +13,7 @@ import (
"github.com/rs/zerolog/log"
)
-func App(configFile string, loader *model.ModelLoader, uploadLimitMB, threads, ctxSize int, f16 bool, debug, disableMessage bool, imageDir string) *fiber.App {
+func App(c context.Context, configFile string, loader *model.ModelLoader, uploadLimitMB, threads, ctxSize int, f16 bool, debug, disableMessage bool, imageDir string) *fiber.App {
zerolog.SetGlobalLevel(zerolog.InfoLevel)
if debug {
zerolog.SetGlobalLevel(zerolog.DebugLevel)
@@ -48,7 +49,7 @@ func App(configFile string, loader *model.ModelLoader, uploadLimitMB, threads, c
}))
}
- cm := make(ConfigMerger)
+ cm := NewConfigMerger()
if err := cm.LoadConfigs(loader.ModelPath); err != nil {
log.Error().Msgf("error loading config files: %s", err.Error())
}
@@ -60,39 +61,51 @@ func App(configFile string, loader *model.ModelLoader, uploadLimitMB, threads, c
}
if debug {
- for k, v := range cm {
- log.Debug().Msgf("Model: %s (config: %+v)", k, v)
+ for _, v := range cm.ListConfigs() {
+ cfg, _ := cm.GetConfig(v)
+ log.Debug().Msgf("Model: %s (config: %+v)", v, cfg)
}
}
// Default middleware config
app.Use(recover.New())
app.Use(cors.New())
+ // LocalAI API endpoints
+ applier := newGalleryApplier(loader.ModelPath)
+ applier.start(c, cm)
+ app.Post("/models/apply", applyModelGallery(loader.ModelPath, cm, applier.C))
+ app.Get("/models/jobs/:uuid", getOpStatus(applier))
+
// openAI compatible API endpoint
+
+ // chat
app.Post("/v1/chat/completions", chatEndpoint(cm, debug, loader, threads, ctxSize, f16))
app.Post("/chat/completions", chatEndpoint(cm, debug, loader, threads, ctxSize, f16))
+ // edit
app.Post("/v1/edits", editEndpoint(cm, debug, loader, threads, ctxSize, f16))
app.Post("/edits", editEndpoint(cm, debug, loader, threads, ctxSize, f16))
+ // completion
app.Post("/v1/completions", completionEndpoint(cm, debug, loader, threads, ctxSize, f16))
app.Post("/completions", completionEndpoint(cm, debug, loader, threads, ctxSize, f16))
+ // embeddings
app.Post("/v1/embeddings", embeddingsEndpoint(cm, debug, loader, threads, ctxSize, f16))
app.Post("/embeddings", embeddingsEndpoint(cm, debug, loader, threads, ctxSize, f16))
-
- // /v1/engines/{engine_id}/embeddings
-
app.Post("/v1/engines/:model/embeddings", embeddingsEndpoint(cm, debug, loader, threads, ctxSize, f16))
+ // audio
app.Post("/v1/audio/transcriptions", transcriptEndpoint(cm, debug, loader, threads, ctxSize, f16))
+ // images
app.Post("/v1/images/generations", imageEndpoint(cm, debug, loader, imageDir))
if imageDir != "" {
app.Static("/generated-images", imageDir)
}
+ // models
app.Get("/v1/models", listModels(loader, cm))
app.Get("/models", listModels(loader, cm))
diff --git a/api/api_test.go b/api/api_test.go
index f2af0388..f061527f 100644
--- a/api/api_test.go
+++ b/api/api_test.go
@@ -1,7 +1,12 @@
package api_test
import (
+ "bytes"
"context"
+ "encoding/json"
+ "fmt"
+ "io/ioutil"
+ "net/http"
"os"
"path/filepath"
"runtime"
@@ -11,21 +16,189 @@ import (
"github.com/gofiber/fiber/v2"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
+ "gopkg.in/yaml.v3"
openaigo "github.com/otiai10/openaigo"
"github.com/sashabaranov/go-openai"
)
+type modelApplyRequest struct {
+ URL string `json:"url"`
+ Name string `json:"name"`
+ Overrides map[string]string `json:"overrides"`
+}
+
+func getModelStatus(url string) (response map[string]interface{}) {
+ // Create the HTTP request
+ resp, err := http.Get(url)
+ if err != nil {
+ fmt.Println("Error creating request:", err)
+ return
+ }
+ defer resp.Body.Close()
+
+ body, err := ioutil.ReadAll(resp.Body)
+ if err != nil {
+ fmt.Println("Error reading response body:", err)
+ return
+ }
+
+ // Unmarshal the response into a map[string]interface{}
+ err = json.Unmarshal(body, &response)
+ if err != nil {
+ fmt.Println("Error unmarshaling JSON response:", err)
+ return
+ }
+ return
+}
+func postModelApplyRequest(url string, request modelApplyRequest) (response map[string]interface{}) {
+
+ //url := "http://localhost:AI/models/apply"
+
+ // Create the request payload
+
+ payload, err := json.Marshal(request)
+ if err != nil {
+ fmt.Println("Error marshaling JSON:", err)
+ return
+ }
+
+ // Create the HTTP request
+ req, err := http.NewRequest("POST", url, bytes.NewBuffer(payload))
+ if err != nil {
+ fmt.Println("Error creating request:", err)
+ return
+ }
+ req.Header.Set("Content-Type", "application/json")
+
+ // Make the request
+ client := &http.Client{}
+ resp, err := client.Do(req)
+ if err != nil {
+ fmt.Println("Error making request:", err)
+ return
+ }
+ defer resp.Body.Close()
+
+ body, err := ioutil.ReadAll(resp.Body)
+ if err != nil {
+ fmt.Println("Error reading response body:", err)
+ return
+ }
+
+ // Unmarshal the response into a map[string]interface{}
+ err = json.Unmarshal(body, &response)
+ if err != nil {
+ fmt.Println("Error unmarshaling JSON response:", err)
+ return
+ }
+ return
+}
+
var _ = Describe("API test", func() {
var app *fiber.App
var modelLoader *model.ModelLoader
var client *openai.Client
var client2 *openaigo.Client
+ var c context.Context
+ var cancel context.CancelFunc
+ var tmpdir string
+
+ Context("API with ephemeral models", func() {
+ BeforeEach(func() {
+ var err error
+ tmpdir, err = os.MkdirTemp("", "")
+ Expect(err).ToNot(HaveOccurred())
+
+ modelLoader = model.NewModelLoader(tmpdir)
+ c, cancel = context.WithCancel(context.Background())
+
+ app = App(c, "", modelLoader, 15, 1, 512, false, true, true, "")
+ go app.Listen("127.0.0.1:9090")
+
+ defaultConfig := openai.DefaultConfig("")
+ defaultConfig.BaseURL = "http://127.0.0.1:9090/v1"
+
+ client2 = openaigo.NewClient("")
+ client2.BaseURL = defaultConfig.BaseURL
+
+ // Wait for API to be ready
+ client = openai.NewClientWithConfig(defaultConfig)
+ Eventually(func() error {
+ _, err := client.ListModels(context.TODO())
+ return err
+ }, "2m").ShouldNot(HaveOccurred())
+ })
+
+ AfterEach(func() {
+ cancel()
+ app.Shutdown()
+ os.RemoveAll(tmpdir)
+ })
+
+ Context("Applying models", func() {
+ It("overrides models", func() {
+ response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
+ URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml",
+ Name: "bert",
+ Overrides: map[string]string{
+ "backend": "llama",
+ },
+ })
+
+ Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response))
+
+ uuid := response["uuid"].(string)
+
+ Eventually(func() bool {
+ response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
+ fmt.Println(response)
+ return response["processed"].(bool)
+ }, "360s").Should(Equal(true))
+
+ dat, err := os.ReadFile(filepath.Join(tmpdir, "bert.yaml"))
+ Expect(err).ToNot(HaveOccurred())
+
+ content := map[string]interface{}{}
+ err = yaml.Unmarshal(dat, &content)
+ Expect(err).ToNot(HaveOccurred())
+ Expect(content["backend"]).To(Equal("llama"))
+ })
+ It("apply models without overrides", func() {
+ response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
+ URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml",
+ Name: "bert",
+ Overrides: map[string]string{},
+ })
+
+ Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response))
+
+ uuid := response["uuid"].(string)
+
+ Eventually(func() bool {
+ response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
+ fmt.Println(response)
+ return response["processed"].(bool)
+ }, "360s").Should(Equal(true))
+
+ dat, err := os.ReadFile(filepath.Join(tmpdir, "bert.yaml"))
+ Expect(err).ToNot(HaveOccurred())
+
+ content := map[string]interface{}{}
+ err = yaml.Unmarshal(dat, &content)
+ Expect(err).ToNot(HaveOccurred())
+ Expect(content["backend"]).To(Equal("bert-embeddings"))
+ })
+ })
+ })
+
Context("API query", func() {
BeforeEach(func() {
modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH"))
- app = App("", modelLoader, 15, 1, 512, false, true, true, "")
+ c, cancel = context.WithCancel(context.Background())
+
+ app = App(c, "", modelLoader, 15, 1, 512, false, true, true, "")
go app.Listen("127.0.0.1:9090")
defaultConfig := openai.DefaultConfig("")
@@ -42,6 +215,7 @@ var _ = Describe("API test", func() {
}, "2m").ShouldNot(HaveOccurred())
})
AfterEach(func() {
+ cancel()
app.Shutdown()
})
It("returns the models list", func() {
@@ -140,7 +314,9 @@ var _ = Describe("API test", func() {
Context("Config file", func() {
BeforeEach(func() {
modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH"))
- app = App(os.Getenv("CONFIG_FILE"), modelLoader, 5, 1, 512, false, true, true, "")
+ c, cancel = context.WithCancel(context.Background())
+
+ app = App(c, os.Getenv("CONFIG_FILE"), modelLoader, 5, 1, 512, false, true, true, "")
go app.Listen("127.0.0.1:9090")
defaultConfig := openai.DefaultConfig("")
@@ -155,10 +331,10 @@ var _ = Describe("API test", func() {
}, "2m").ShouldNot(HaveOccurred())
})
AfterEach(func() {
+ cancel()
app.Shutdown()
})
It("can generate chat completions from config file", func() {
-
models, err := client.ListModels(context.TODO())
Expect(err).ToNot(HaveOccurred())
Expect(len(models.Models)).To(Equal(12))
diff --git a/api/config.go b/api/config.go
index 7379978e..7e0d8264 100644
--- a/api/config.go
+++ b/api/config.go
@@ -7,6 +7,7 @@ import (
"os"
"path/filepath"
"strings"
+ "sync"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
@@ -43,8 +44,16 @@ type TemplateConfig struct {
Edit string `yaml:"edit"`
}
-type ConfigMerger map[string]Config
+type ConfigMerger struct {
+ configs map[string]Config
+ sync.Mutex
+}
+func NewConfigMerger() *ConfigMerger {
+ return &ConfigMerger{
+ configs: make(map[string]Config),
+ }
+}
func ReadConfigFile(file string) ([]*Config, error) {
c := &[]*Config{}
f, err := os.ReadFile(file)
@@ -72,28 +81,51 @@ func ReadConfig(file string) (*Config, error) {
}
func (cm *ConfigMerger) LoadConfigFile(file string) error {
+ cm.Lock()
+ defer cm.Unlock()
c, err := ReadConfigFile(file)
if err != nil {
return fmt.Errorf("cannot load config file: %w", err)
}
for _, cc := range c {
- cm[cc.Name] = *cc
+ cm.configs[cc.Name] = *cc
}
return nil
}
func (cm *ConfigMerger) LoadConfig(file string) error {
+ cm.Lock()
+ defer cm.Unlock()
c, err := ReadConfig(file)
if err != nil {
return fmt.Errorf("cannot read config file: %w", err)
}
- cm[c.Name] = *c
+ cm.configs[c.Name] = *c
return nil
}
+func (cm *ConfigMerger) GetConfig(m string) (Config, bool) {
+ cm.Lock()
+ defer cm.Unlock()
+ v, exists := cm.configs[m]
+ return v, exists
+}
+
+func (cm *ConfigMerger) ListConfigs() []string {
+ cm.Lock()
+ defer cm.Unlock()
+ var res []string
+ for k := range cm.configs {
+ res = append(res, k)
+ }
+ return res
+}
+
func (cm *ConfigMerger) LoadConfigs(path string) error {
+ cm.Lock()
+ defer cm.Unlock()
files, err := ioutil.ReadDir(path)
if err != nil {
return err
@@ -106,7 +138,7 @@ func (cm ConfigMerger) LoadConfigs(path string) error {
}
c, err := ReadConfig(filepath.Join(path, file.Name()))
if err == nil {
- cm[c.Name] = *c
+ cm.configs[c.Name] = *c
}
}
@@ -253,7 +285,7 @@ func readInput(c *fiber.Ctx, loader *model.ModelLoader, randomModel bool) (strin
return modelFile, input, nil
}
-func readConfig(modelFile string, input *OpenAIRequest, cm ConfigMerger, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*Config, *OpenAIRequest, error) {
+func readConfig(modelFile string, input *OpenAIRequest, cm *ConfigMerger, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*Config, *OpenAIRequest, error) {
// Load a config file if present after the model name
modelConfig := filepath.Join(loader.ModelPath, modelFile+".yaml")
if _, err := os.Stat(modelConfig); err == nil {
@@ -263,7 +295,7 @@ func readConfig(modelFile string, input *OpenAIRequest, cm ConfigMerger, loader
}
var config *Config
- cfg, exists := cm[modelFile]
+ cfg, exists := cm.GetConfig(modelFile)
if !exists {
config = &Config{
OpenAIRequest: defaultRequest(modelFile),
diff --git a/api/gallery.go b/api/gallery.go
new file mode 100644
index 00000000..591b1b7a
--- /dev/null
+++ b/api/gallery.go
@@ -0,0 +1,196 @@
+package api
+
+import (
+ "context"
+ "fmt"
+ "io/ioutil"
+ "net/http"
+ "net/url"
+ "strings"
+ "sync"
+
+ "github.com/go-skynet/LocalAI/pkg/gallery"
+ "github.com/gofiber/fiber/v2"
+ "github.com/google/uuid"
+ "gopkg.in/yaml.v3"
+)
+
+type galleryOp struct {
+ req ApplyGalleryModelRequest
+ id string
+}
+
+type galleryOpStatus struct {
+ Error error `json:"error"`
+ Processed bool `json:"processed"`
+ Message string `json:"message"`
+}
+
+type galleryApplier struct {
+ modelPath string
+ sync.Mutex
+ C chan galleryOp
+ statuses map[string]*galleryOpStatus
+}
+
+func newGalleryApplier(modelPath string) *galleryApplier {
+ return &galleryApplier{
+ modelPath: modelPath,
+ C: make(chan galleryOp),
+ statuses: make(map[string]*galleryOpStatus),
+ }
+}
+func (g *galleryApplier) updatestatus(s string, op *galleryOpStatus) {
+ g.Lock()
+ defer g.Unlock()
+ g.statuses[s] = op
+}
+
+func (g *galleryApplier) getstatus(s string) *galleryOpStatus {
+ g.Lock()
+ defer g.Unlock()
+
+ return g.statuses[s]
+}
+
+func (g *galleryApplier) start(c context.Context, cm *ConfigMerger) {
+ go func() {
+ for {
+ select {
+ case <-c.Done():
+ return
+ case op := <-g.C:
+ g.updatestatus(op.id, &galleryOpStatus{Message: "processing"})
+
+ updateError := func(e error) {
+ g.updatestatus(op.id, &galleryOpStatus{Error: e, Processed: true})
+ }
+
+ url, err := op.req.DecodeURL()
+ if err != nil {
+ updateError(err)
+ continue
+ }
+
+ // Send a GET request to the URL
+ response, err := http.Get(url)
+ if err != nil {
+ updateError(err)
+ continue
+ }
+
+			// Read the response body, closing it eagerly: a defer here would
+			// accumulate until the goroutine exits, leaking connections.
+			body, err := ioutil.ReadAll(response.Body)
+			response.Body.Close()
+ if err != nil {
+ updateError(err)
+ continue
+ }
+
+ // Unmarshal YAML data into a Config struct
+ var config gallery.Config
+ err = yaml.Unmarshal(body, &config)
+ if err != nil {
+ updateError(fmt.Errorf("failed to unmarshal YAML: %v", err))
+ continue
+ }
+
+ config.Files = append(config.Files, op.req.AdditionalFiles...)
+
+ if err := gallery.Apply(g.modelPath, op.req.Name, &config, op.req.Overrides); err != nil {
+ updateError(err)
+ continue
+ }
+
+ // Reload models
+ if err := cm.LoadConfigs(g.modelPath); err != nil {
+ updateError(err)
+ continue
+ }
+
+ g.updatestatus(op.id, &galleryOpStatus{Processed: true, Message: "completed"})
+ }
+ }
+ }()
+}
+
+// endpoints
+
+type ApplyGalleryModelRequest struct {
+ URL string `json:"url"`
+ Name string `json:"name"`
+ Overrides map[string]interface{} `json:"overrides"`
+ AdditionalFiles []gallery.File `json:"files"`
+}
+
+const (
+ githubURI = "github:"
+)
+
+func (request ApplyGalleryModelRequest) DecodeURL() (string, error) {
+ input := request.URL
+ var rawURL string
+
+ if strings.HasPrefix(input, githubURI) {
+ parts := strings.Split(input, ":")
+ repoParts := strings.Split(parts[1], "@")
+ branch := "main"
+
+ if len(repoParts) > 1 {
+ branch = repoParts[1]
+ }
+
+ repoPath := strings.Split(repoParts[0], "/")
+ org := repoPath[0]
+ project := repoPath[1]
+ projectPath := strings.Join(repoPath[2:], "/")
+
+ rawURL = fmt.Sprintf("https://raw.githubusercontent.com/%s/%s/%s/%s", org, project, branch, projectPath)
+ } else if strings.HasPrefix(input, "http://") || strings.HasPrefix(input, "https://") {
+ // Handle regular URLs
+ u, err := url.Parse(input)
+ if err != nil {
+ return "", fmt.Errorf("invalid URL: %w", err)
+ }
+ rawURL = u.String()
+ } else {
+ return "", fmt.Errorf("invalid URL format")
+ }
+
+ return rawURL, nil
+}
+
+func getOpStatus(g *galleryApplier) func(c *fiber.Ctx) error {
+ return func(c *fiber.Ctx) error {
+
+ status := g.getstatus(c.Params("uuid"))
+ if status == nil {
+ return fmt.Errorf("could not find any status for ID")
+ }
+
+ return c.JSON(status)
+ }
+}
+
+func applyModelGallery(modelPath string, cm *ConfigMerger, g chan galleryOp) func(c *fiber.Ctx) error {
+ return func(c *fiber.Ctx) error {
+ input := new(ApplyGalleryModelRequest)
+ // Get input data from the request body
+ if err := c.BodyParser(input); err != nil {
+ return err
+ }
+
+ uuid, err := uuid.NewUUID()
+ if err != nil {
+ return err
+ }
+ g <- galleryOp{
+ req: *input,
+ id: uuid.String(),
+ }
+ return c.JSON(struct {
+ ID string `json:"uuid"`
+ StatusURL string `json:"status"`
+ }{ID: uuid.String(), StatusURL: c.BaseURL() + "/models/jobs/" + uuid.String()})
+ }
+}
diff --git a/api/gallery_test.go b/api/gallery_test.go
new file mode 100644
index 00000000..1c92c0d5
--- /dev/null
+++ b/api/gallery_test.go
@@ -0,0 +1,30 @@
+package api_test
+
+import (
+ . "github.com/go-skynet/LocalAI/api"
+ . "github.com/onsi/ginkgo/v2"
+ . "github.com/onsi/gomega"
+)
+
+var _ = Describe("Gallery API tests", func() {
+ Context("requests", func() {
+ It("parses github with a branch", func() {
+ req := ApplyGalleryModelRequest{URL: "github:go-skynet/model-gallery/gpt4all-j.yaml@main"}
+ str, err := req.DecodeURL()
+ Expect(err).ToNot(HaveOccurred())
+ Expect(str).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml"))
+ })
+ It("parses github without a branch", func() {
+ req := ApplyGalleryModelRequest{URL: "github:go-skynet/model-gallery/gpt4all-j.yaml"}
+ str, err := req.DecodeURL()
+ Expect(err).ToNot(HaveOccurred())
+ Expect(str).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml"))
+ })
+		It("parses URLs", func() {
+ req := ApplyGalleryModelRequest{URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml"}
+ str, err := req.DecodeURL()
+ Expect(err).ToNot(HaveOccurred())
+ Expect(str).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml"))
+ })
+ })
+})
diff --git a/api/openai.go b/api/openai.go
index 52d65976..0a85349c 100644
--- a/api/openai.go
+++ b/api/openai.go
@@ -142,7 +142,7 @@ func defaultRequest(modelFile string) OpenAIRequest {
}
// https://platform.openai.com/docs/api-reference/completions
-func completionEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error {
+func completionEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
model, input, err := readInput(c, loader, true)
@@ -199,7 +199,7 @@ func completionEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader,
}
// https://platform.openai.com/docs/api-reference/embeddings
-func embeddingsEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error {
+func embeddingsEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
model, input, err := readInput(c, loader, true)
if err != nil {
@@ -256,7 +256,7 @@ func embeddingsEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader,
}
}
-func chatEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error {
+func chatEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error {
process := func(s string, req *OpenAIRequest, config *Config, loader *model.ModelLoader, responses chan OpenAIResponse) {
ComputeChoices(s, req, config, loader, func(s string, c *[]Choice) {}, func(s string) bool {
@@ -378,7 +378,7 @@ func chatEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, thread
}
}
-func editEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error {
+func editEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
model, input, err := readInput(c, loader, true)
if err != nil {
@@ -449,7 +449,7 @@ func editEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, thread
*
*/
-func imageEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, imageDir string) func(c *fiber.Ctx) error {
+func imageEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, imageDir string) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
m, input, err := readInput(c, loader, false)
if err != nil {
@@ -574,7 +574,7 @@ func imageEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, image
}
// https://platform.openai.com/docs/api-reference/audio/create
-func transcriptEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error {
+func transcriptEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
m, input, err := readInput(c, loader, false)
if err != nil {
@@ -641,7 +641,7 @@ func transcriptEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader,
}
}
-func listModels(loader *model.ModelLoader, cm ConfigMerger) func(ctx *fiber.Ctx) error {
+func listModels(loader *model.ModelLoader, cm *ConfigMerger) func(ctx *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
models, err := loader.ListModels()
if err != nil {
@@ -655,7 +655,7 @@ func listModels(loader *model.ModelLoader, cm ConfigMerger) func(ctx *fiber.Ctx)
dataModels = append(dataModels, OpenAIModel{ID: m, Object: "model"})
}
- for k := range cm {
+ for _, k := range cm.ListConfigs() {
if _, exists := mm[k]; !exists {
dataModels = append(dataModels, OpenAIModel{ID: k, Object: "model"})
}
diff --git a/entrypoint.sh b/entrypoint.sh
index aab14205..e7390e56 100755
--- a/entrypoint.sh
+++ b/entrypoint.sh
@@ -2,6 +2,8 @@
cd /build
-make build
+if [ "$REBUILD" != "false" ]; then
+ make rebuild
+fi
./local-ai "$@"
\ No newline at end of file
diff --git a/examples/langchain-chroma/.env.example b/examples/langchain-chroma/.env.example
new file mode 100644
index 00000000..37cda598
--- /dev/null
+++ b/examples/langchain-chroma/.env.example
@@ -0,0 +1,5 @@
+THREADS=4
+CONTEXT_SIZE=512
+MODELS_PATH=/models
+DEBUG=true
+# BUILD_TYPE=generic
\ No newline at end of file
diff --git a/examples/langchain-chroma/.gitignore b/examples/langchain-chroma/.gitignore
new file mode 100644
index 00000000..3dc19014
--- /dev/null
+++ b/examples/langchain-chroma/.gitignore
@@ -0,0 +1,4 @@
+db/
+state_of_the_union.txt
+models/bert
+models/ggml-gpt4all-j
\ No newline at end of file
diff --git a/examples/langchain-chroma/README.md b/examples/langchain-chroma/README.md
index 70e3f42b..9fd9e312 100644
--- a/examples/langchain-chroma/README.md
+++ b/examples/langchain-chroma/README.md
@@ -10,13 +10,20 @@ Download the models and start the API:
# Clone LocalAI
git clone https://github.com/go-skynet/LocalAI
-cd LocalAI/examples/query_data
+cd LocalAI/examples/langchain-chroma
wget https://huggingface.co/skeskinen/ggml/resolve/main/all-MiniLM-L6-v2/ggml-model-q4_0.bin -O models/bert
wget https://gpt4all.io/models/ggml-gpt4all-j.bin -O models/ggml-gpt4all-j
+# configure your .env
+# NOTE: ensure that THREADS does not exceed your machine's CPU cores
+mv .env.example .env
+
# start with docker-compose
docker-compose up -d --build
+
+# tail the logs & wait until the build completes
+docker logs -f langchain-chroma-api-1
```
### Python requirements
@@ -29,6 +36,8 @@ pip install -r requirements.txt
In this step we will create a local vector database from our document set, so later we can ask questions on it with the LLM.
+Note: **OPENAI_API_KEY** is not required. However, the library might fail if no API key is passed, so an arbitrary string can be used.
+
```bash
export OPENAI_API_BASE=http://localhost:8080/v1
export OPENAI_API_KEY=sk-
@@ -37,7 +46,7 @@ wget https://raw.githubusercontent.com/hwchase17/chat-your-data/master/state_of_
python store.py
```
-After it finishes, a directory "storage" will be created with the vector index database.
+After it finishes, a directory "db" will be created with the vector index database.
## Query
diff --git a/examples/langchain-chroma/docker-compose.yml b/examples/langchain-chroma/docker-compose.yml
new file mode 100644
index 00000000..96ef540e
--- /dev/null
+++ b/examples/langchain-chroma/docker-compose.yml
@@ -0,0 +1,15 @@
+version: '3.6'
+
+services:
+ api:
+ image: quay.io/go-skynet/local-ai:latest
+ build:
+ context: ../../
+ dockerfile: Dockerfile
+ ports:
+ - 8080:8080
+ env_file:
+ - ../../.env
+ volumes:
+ - ./models:/models:cached
+ command: ["/usr/bin/local-ai"]
diff --git a/examples/langchain-chroma/models/embeddings.yaml b/examples/langchain-chroma/models/embeddings.yaml
index 46a08502..536c8de1 100644
--- a/examples/langchain-chroma/models/embeddings.yaml
+++ b/examples/langchain-chroma/models/embeddings.yaml
@@ -1,5 +1,6 @@
name: text-embedding-ada-002
parameters:
model: bert
+threads: 4
backend: bert-embeddings
embeddings: true
diff --git a/examples/langchain-chroma/query.py b/examples/langchain-chroma/query.py
index 2f7df507..33848818 100644
--- a/examples/langchain-chroma/query.py
+++ b/examples/langchain-chroma/query.py
@@ -2,8 +2,9 @@
import os
from langchain.vectorstores import Chroma
from langchain.embeddings import OpenAIEmbeddings
-from langchain.llms import OpenAI
-from langchain.chains import VectorDBQA
+from langchain.chat_models import ChatOpenAI
+from langchain.chains import RetrievalQA
+from langchain.vectorstores.base import VectorStoreRetriever
base_path = os.environ.get('OPENAI_API_BASE', 'http://localhost:8080/v1')
@@ -12,8 +13,10 @@ embedding = OpenAIEmbeddings()
persist_directory = 'db'
# Now we can load the persisted database from disk, and use it as normal.
+llm = ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo", openai_api_base=base_path)
vectordb = Chroma(persist_directory=persist_directory, embedding_function=embedding)
-qa = VectorDBQA.from_chain_type(llm=OpenAI(temperature=0, model_name="gpt-3.5-turbo", openai_api_base=base_path), chain_type="stuff", vectorstore=vectordb)
+retriever = VectorStoreRetriever(vectorstore=vectordb)
+qa = RetrievalQA.from_llm(llm=llm, retriever=retriever)
query = "What did the president say about taxes?"
print(qa.run(query))
diff --git a/examples/langchain-chroma/store.py b/examples/langchain-chroma/store.py
index 127bb240..b9cbad0e 100755
--- a/examples/langchain-chroma/store.py
+++ b/examples/langchain-chroma/store.py
@@ -2,9 +2,7 @@
import os
from langchain.vectorstores import Chroma
from langchain.embeddings import OpenAIEmbeddings
-from langchain.text_splitter import RecursiveCharacterTextSplitter,TokenTextSplitter,CharacterTextSplitter
-from langchain.llms import OpenAI
-from langchain.chains import VectorDBQA
+from langchain.text_splitter import CharacterTextSplitter
from langchain.document_loaders import TextLoader
base_path = os.environ.get('OPENAI_API_BASE', 'http://localhost:8080/v1')
@@ -14,7 +12,6 @@ loader = TextLoader('state_of_the_union.txt')
documents = loader.load()
text_splitter = CharacterTextSplitter(chunk_size=300, chunk_overlap=70)
-#text_splitter = TokenTextSplitter()
texts = text_splitter.split_documents(documents)
# Embed and store the texts
diff --git a/examples/langchain-python/README.md b/examples/langchain-python/README.md
index a98c48f7..2472aab1 100644
--- a/examples/langchain-python/README.md
+++ b/examples/langchain-python/README.md
@@ -26,6 +26,7 @@ pip install langchain
pip install openai
export OPENAI_API_BASE=http://localhost:8080
+# Note: OPENAI_API_KEY is not required. However, the library might fail if no API key is passed, so an arbitrary string can be used.
export OPENAI_API_KEY=sk-
python test.py
diff --git a/examples/query_data/README.md b/examples/query_data/README.md
index f7a4e1fe..c4e384cd 100644
--- a/examples/query_data/README.md
+++ b/examples/query_data/README.md
@@ -35,6 +35,8 @@ docker-compose up -d --build
In this step we will create a local vector database from our document set, so later we can ask questions on it with the LLM.
+Note: **OPENAI_API_KEY** is not required. However, the library might fail if no API key is passed, so an arbitrary string can be used.
+
```bash
export OPENAI_API_BASE=http://localhost:8080/v1
export OPENAI_API_KEY=sk-
diff --git a/examples/query_data/docker-compose.yml b/examples/query_data/docker-compose.yml
index a59edfc4..cf76eb7f 100644
--- a/examples/query_data/docker-compose.yml
+++ b/examples/query_data/docker-compose.yml
@@ -4,7 +4,7 @@ services:
api:
image: quay.io/go-skynet/local-ai:latest
build:
- context: .
+ context: ../../
dockerfile: Dockerfile
ports:
- 8080:8080
diff --git a/examples/rwkv/.gitignore b/examples/rwkv/.gitignore
new file mode 100644
index 00000000..ab3629c5
--- /dev/null
+++ b/examples/rwkv/.gitignore
@@ -0,0 +1,2 @@
+models/rwkv
+models/rwkv.tokenizer.json
\ No newline at end of file
diff --git a/examples/rwkv/Dockerfile.build b/examples/rwkv/Dockerfile.build
index c62024de..491f9ccd 100644
--- a/examples/rwkv/Dockerfile.build
+++ b/examples/rwkv/Dockerfile.build
@@ -1,5 +1,7 @@
FROM python
+RUN apt-get update && apt-get -y install cmake
+
# convert the model (one-off)
RUN pip3 install torch numpy
diff --git a/go.mod b/go.mod
index 1cd33ce8..adb9c45a 100644
--- a/go.mod
+++ b/go.mod
@@ -4,26 +4,27 @@ go 1.19
require (
github.com/donomii/go-rwkv.cpp v0.0.0-20230515123100-6fdd0c338e56
- github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230515153606-95b02d76b04d
+ github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230520182345-041be06d5881
github.com/go-audio/wav v1.1.0
github.com/go-skynet/bloomz.cpp v0.0.0-20230510223001-e9366e82abdf
github.com/go-skynet/go-bert.cpp v0.0.0-20230516063724-cea1ed76a7f4
github.com/go-skynet/go-gpt2.cpp v0.0.0-20230512145559-7bff56f02245
- github.com/go-skynet/go-gpt4all-j.cpp v0.0.0-20230422090028-1f7bff57f66c
- github.com/go-skynet/go-llama.cpp v0.0.0-20230516230554-b7bbefbe0b84
- github.com/gofiber/fiber/v2 v2.45.0
+ github.com/go-skynet/go-llama.cpp v0.0.0-20230520155239-ccf23adfb278
+ github.com/gofiber/fiber/v2 v2.46.0
+ github.com/google/uuid v1.3.0
github.com/hashicorp/go-multierror v1.1.1
+ github.com/imdario/mergo v0.3.15
github.com/mudler/go-stable-diffusion v0.0.0-20230516152536-c0748eca3642
- github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230516143155-79d6243fe1bc
+ github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230519014017-914519e772fd
github.com/onsi/ginkgo/v2 v2.9.5
- github.com/onsi/gomega v1.27.6
- github.com/otiai10/copy v1.11.0
+ github.com/onsi/gomega v1.27.7
github.com/otiai10/openaigo v1.1.0
github.com/rs/zerolog v1.29.1
github.com/sashabaranov/go-openai v1.9.4
github.com/swaggo/swag v1.16.1
github.com/urfave/cli/v2 v2.25.3
github.com/valyala/fasthttp v1.47.0
+ gopkg.in/yaml.v2 v2.4.0
gopkg.in/yaml.v3 v3.0.1
)
@@ -43,7 +44,6 @@ require (
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
github.com/google/go-cmp v0.5.9 // indirect
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 // indirect
- github.com/google/uuid v1.3.0 // indirect
github.com/hashicorp/errwrap v1.0.0 // indirect
github.com/josharian/intern v1.0.0 // indirect
github.com/klauspost/compress v1.16.3 // indirect
@@ -51,6 +51,7 @@ require (
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.18 // indirect
github.com/mattn/go-runewidth v0.0.14 // indirect
+ github.com/otiai10/mint v1.5.1 // indirect
github.com/philhofer/fwd v1.1.2 // indirect
github.com/rivo/uniseg v0.2.0 // indirect
github.com/russross/blackfriday/v2 v2.1.0 // indirect
@@ -64,5 +65,4 @@ require (
golang.org/x/sys v0.8.0 // indirect
golang.org/x/text v0.9.0 // indirect
golang.org/x/tools v0.9.1 // indirect
- gopkg.in/yaml.v2 v2.4.0 // indirect
)
diff --git a/go.sum b/go.sum
index 09af76da..20a4b22b 100644
--- a/go.sum
+++ b/go.sum
@@ -16,18 +16,14 @@ github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ3
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
-github.com/donomii/go-rwkv.cpp v0.0.0-20230503112711-af62fcc432be h1:3Hic97PY6hcw/SY44RuR7kyONkxd744RFeRrqckzwNQ=
-github.com/donomii/go-rwkv.cpp v0.0.0-20230503112711-af62fcc432be/go.mod h1:gWy7FIWioqYmYxkaoFyBnaKApeZVrUkHhv9EV9pz4dM=
-github.com/donomii/go-rwkv.cpp v0.0.0-20230510174014-07166da10cb2 h1:YNbUAyIRtaLODitigJU1EM5ubmMu5FmHtYAayJD6Vbg=
-github.com/donomii/go-rwkv.cpp v0.0.0-20230510174014-07166da10cb2/go.mod h1:gWy7FIWioqYmYxkaoFyBnaKApeZVrUkHhv9EV9pz4dM=
github.com/donomii/go-rwkv.cpp v0.0.0-20230515123100-6fdd0c338e56 h1:s8/MZdicstKi5fn9D9mKGIQ/q6IWCYCk/BM68i8v51w=
github.com/donomii/go-rwkv.cpp v0.0.0-20230515123100-6fdd0c338e56/go.mod h1:gWy7FIWioqYmYxkaoFyBnaKApeZVrUkHhv9EV9pz4dM=
-github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230508180809-bf2449dfae35 h1:sMg/SgnMPS/HNUO/2kGm72vl8R9TmNIwgLFr2TNwR3g=
-github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230508180809-bf2449dfae35/go.mod h1:QIjZ9OktHFG7p+/m3sMvrAJKKdWrr1fZIK0rM6HZlyo=
-github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230509153812-1d17cd5bb37a h1:MlyiDLNCM/wjbv8U5Elj18NvaAgl61SGiRUpqQz5dfs=
-github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230509153812-1d17cd5bb37a/go.mod h1:QIjZ9OktHFG7p+/m3sMvrAJKKdWrr1fZIK0rM6HZlyo=
github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230515153606-95b02d76b04d h1:uxKTbiRnplE2SubchneSf4NChtxLJtOy9VdHnQMT0d0=
github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230515153606-95b02d76b04d/go.mod h1:QIjZ9OktHFG7p+/m3sMvrAJKKdWrr1fZIK0rM6HZlyo=
+github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230520170006-429b9785c080 h1:W3itqKpRX9FhheKiAxdmuOBy/mjDfMf2G1vcuFIYqZc=
+github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230520170006-429b9785c080/go.mod h1:QIjZ9OktHFG7p+/m3sMvrAJKKdWrr1fZIK0rM6HZlyo=
+github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230520182345-041be06d5881 h1:dafqVivljYk51VLFnnpTXJnfWDe637EobWZ1l8PyEf8=
+github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230520182345-041be06d5881/go.mod h1:QIjZ9OktHFG7p+/m3sMvrAJKKdWrr1fZIK0rM6HZlyo=
github.com/go-audio/audio v1.0.0 h1:zS9vebldgbQqktK4H0lUqWrG8P0NxCJVqcj7ZpNnwd4=
github.com/go-audio/audio v1.0.0/go.mod h1:6uAu0+H2lHkwdGsAY+j2wHPNPpPoeg5AaEFh9FlA+Zs=
github.com/go-audio/riff v1.0.0 h1:d8iCGbDvox9BfLagY94fBynxSPHO80LmZCaOsmKxokA=
@@ -46,31 +42,23 @@ github.com/go-openapi/spec v0.20.4/go.mod h1:faYFR1CvsJZ0mNsmsphTMSoRrNV3TEDoAM7
github.com/go-openapi/swag v0.19.5/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh66Z9tfKk=
github.com/go-openapi/swag v0.19.15 h1:D2NRCBzS9/pEY3gP9Nl8aDqGUcPFrwG2p+CNFrLyrCM=
github.com/go-openapi/swag v0.19.15/go.mod h1:QYRuS/SOXUCsnplDa677K7+DxSOj6IPNl/eQntq43wQ=
-github.com/go-skynet/bloomz.cpp v0.0.0-20230510195113-ad7e89a0885f h1:GW8RQa1RVeDF1dOuAP/y6xWVC+BRtf9tJOuEza6Asbg=
-github.com/go-skynet/bloomz.cpp v0.0.0-20230510195113-ad7e89a0885f/go.mod h1:wc0fJ9V04yiYTfgKvE5RUUSRQ5Kzi0Bo4I+U3nNOUuA=
github.com/go-skynet/bloomz.cpp v0.0.0-20230510223001-e9366e82abdf h1:VJfSn8hIDE+K5+h38M3iAyFXrxpRExMKRdTk33UDxsw=
github.com/go-skynet/bloomz.cpp v0.0.0-20230510223001-e9366e82abdf/go.mod h1:wc0fJ9V04yiYTfgKvE5RUUSRQ5Kzi0Bo4I+U3nNOUuA=
-github.com/go-skynet/go-bert.cpp v0.0.0-20230510101404-7bb183b147ea h1:8Isk9D+Auth5OuXVAQPC3MO+5zF/2S7mvs2JZLw6a+8=
-github.com/go-skynet/go-bert.cpp v0.0.0-20230510101404-7bb183b147ea/go.mod h1:NHwIVvsg7Jh6p0M4uBLVmSMEaPUia6O6yjXUpLWVJmQ=
-github.com/go-skynet/go-bert.cpp v0.0.0-20230510124618-ec771ec71557 h1:LD66fKtvP2lmyuuKL8pBat/pVTKUbLs3L5fM/5lyi4w=
-github.com/go-skynet/go-bert.cpp v0.0.0-20230510124618-ec771ec71557/go.mod h1:NHwIVvsg7Jh6p0M4uBLVmSMEaPUia6O6yjXUpLWVJmQ=
github.com/go-skynet/go-bert.cpp v0.0.0-20230516063724-cea1ed76a7f4 h1:+3KPDf4Wv1VHOkzAfZnlj9qakLSYggTpm80AswhD/FU=
github.com/go-skynet/go-bert.cpp v0.0.0-20230516063724-cea1ed76a7f4/go.mod h1:VY0s5KoAI2jRCvQXKuDeEEe8KG7VaWifSNJSk+E1KtY=
-github.com/go-skynet/go-gpt2.cpp v0.0.0-20230509180201-d49823284cc6 h1:XshpypO6ekU09CI19vuzke2a1Es1lV5ZaxA7CUehu0E=
-github.com/go-skynet/go-gpt2.cpp v0.0.0-20230509180201-d49823284cc6/go.mod h1:1Wj/xbkMfwQSOrhNYK178IzqQHstZbRfhx4s8p1M5VM=
github.com/go-skynet/go-gpt2.cpp v0.0.0-20230512145559-7bff56f02245 h1:IcfYY5uH0DdDXEJKJ8bq0WZCd9guPPd3xllaWNy8LOk=
github.com/go-skynet/go-gpt2.cpp v0.0.0-20230512145559-7bff56f02245/go.mod h1:1Wj/xbkMfwQSOrhNYK178IzqQHstZbRfhx4s8p1M5VM=
-github.com/go-skynet/go-gpt4all-j.cpp v0.0.0-20230422090028-1f7bff57f66c h1:48I7jpLNGiQeBmF0SFVVbREh8vlG0zN13v9LH5ctXis=
-github.com/go-skynet/go-gpt4all-j.cpp v0.0.0-20230422090028-1f7bff57f66c/go.mod h1:5VZ9XbcINI0XcHhkcX8GPK8TplFGAzu1Hrg4tNiMCtI=
-github.com/go-skynet/go-llama.cpp v0.0.0-20230510072905-70593fccbe4b h1:qqxrjY8fYDXQahmCMTCACahm1tbiqHLPUHALkFLyBfo=
-github.com/go-skynet/go-llama.cpp v0.0.0-20230510072905-70593fccbe4b/go.mod h1:DLfsPD7tYYnpksERH83HSf7qVNW3FIwmz7/zfYO0/6I=
-github.com/go-skynet/go-llama.cpp v0.0.0-20230516230554-b7bbefbe0b84 h1:f5iYF75bAr73Tl8AdtFD5Urs/2bsHKPh52K++jLbsfk=
-github.com/go-skynet/go-llama.cpp v0.0.0-20230516230554-b7bbefbe0b84/go.mod h1:jxyQ26t1aKC5Gn782w9WWh5n1133PxCOfkuc01xM4RQ=
+github.com/go-skynet/go-llama.cpp v0.0.0-20230520082618-a298043ef5f1 h1:i0oM2MERUgMIRmjOcv22TDQULxbmY8o9rZKLKKyWXLo=
+github.com/go-skynet/go-llama.cpp v0.0.0-20230520082618-a298043ef5f1/go.mod h1:oA0r4BW8ndyjTMGi1tulsNd7sdg3Ql8MaVFuT1zF6ws=
+github.com/go-skynet/go-llama.cpp v0.0.0-20230520155239-ccf23adfb278 h1:st4ow9JKy3UuhkwutrbWof2vMFU/YxwBCLYZ1IxJ2Po=
+github.com/go-skynet/go-llama.cpp v0.0.0-20230520155239-ccf23adfb278/go.mod h1:oA0r4BW8ndyjTMGi1tulsNd7sdg3Ql8MaVFuT1zF6ws=
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/gofiber/fiber/v2 v2.45.0 h1:p4RpkJT9GAW6parBSbcNFH2ApnAuW3OzaQzbOCoDu+s=
github.com/gofiber/fiber/v2 v2.45.0/go.mod h1:DNl0/c37WLe0g92U6lx1VMQuxGUQY5V7EIaVoEsUffc=
+github.com/gofiber/fiber/v2 v2.46.0 h1:wkkWotblsGVlLjXj2dpgKQAYHtXumsK/HyFugQM68Ns=
+github.com/gofiber/fiber/v2 v2.46.0/go.mod h1:DNl0/c37WLe0g92U6lx1VMQuxGUQY5V7EIaVoEsUffc=
github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
@@ -83,6 +71,8 @@ github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brv
github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo=
github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM=
github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc=
+github.com/imdario/mergo v0.3.15 h1:M8XP7IuFNsqUx6VPK2P9OSmsYsI/YFaGil0uD21V3dM=
+github.com/imdario/mergo v0.3.15/go.mod h1:WBLT9ZmE3lPoWsEzCh9LPo3TiwVN+ZKEjmz+hD27ysY=
github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
github.com/klauspost/compress v1.16.3 h1:XuJt9zzcnaz6a16/OU53ZjWp/v7/42WcR5t2a0PcNQY=
@@ -105,23 +95,18 @@ github.com/mattn/go-isatty v0.0.18 h1:DOKFKCQ7FNG2L1rbrmstDN4QVRdS89Nkh85u68Uwp9
github.com/mattn/go-isatty v0.0.18/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-runewidth v0.0.14 h1:+xnbZSEeDbOIg5/mE6JF0w6n9duR1l3/WmbinWVwUuU=
github.com/mattn/go-runewidth v0.0.14/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
-github.com/mudler/go-stable-diffusion v0.0.0-20230516104333-2f32a16b5b24 h1:XfRD/bZom6u4zji7aB0urIVOsPe43KlkzSRrVhlzaOM=
-github.com/mudler/go-stable-diffusion v0.0.0-20230516104333-2f32a16b5b24/go.mod h1:8ufRkpz/S/9ahkaxzZ5i4WMgO9w4InEhuRoT7vK5Rnw=
github.com/mudler/go-stable-diffusion v0.0.0-20230516152536-c0748eca3642 h1:KTkh3lOUsGqQyP4v+oa38sPFdrZtNnM4HaxTb3epdYs=
github.com/mudler/go-stable-diffusion v0.0.0-20230516152536-c0748eca3642/go.mod h1:8ufRkpz/S/9ahkaxzZ5i4WMgO9w4InEhuRoT7vK5Rnw=
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs=
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno=
-github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230516143155-79d6243fe1bc h1:OPavP/SUsVWVYPhSUZKZeX8yDSQzf4G+BmUmwzrLTyI=
-github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230516143155-79d6243fe1bc/go.mod h1:4T3CHXyrt+7FQHXaxULZfPjHbD8/99WuDDJa0YVZARI=
-github.com/onsi/ginkgo/v2 v2.9.4 h1:xR7vG4IXt5RWx6FfIjyAtsoMAtnc3C/rFXBBd2AjZwE=
-github.com/onsi/ginkgo/v2 v2.9.4/go.mod h1:gCQYp2Q+kSoIj7ykSVb9nskRSsR6PUj4AiLywzIhbKM=
+github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230519014017-914519e772fd h1:kMnZASxCNc8GsPuAV94tltEsfT6T+esuB+rgzdjwFVM=
+github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230519014017-914519e772fd/go.mod h1:4T3CHXyrt+7FQHXaxULZfPjHbD8/99WuDDJa0YVZARI=
github.com/onsi/ginkgo/v2 v2.9.5 h1:+6Hr4uxzP4XIUyAkg61dWBw8lb/gc4/X5luuxN/EC+Q=
github.com/onsi/ginkgo/v2 v2.9.5/go.mod h1:tvAoo1QUJwNEU2ITftXTpR7R1RbCzoZUOs3RonqW57k=
-github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE=
-github.com/onsi/gomega v1.27.6/go.mod h1:PIQNjfQwkP3aQAH7lf7j87O/5FiNr+ZR8+ipb+qQlhg=
-github.com/otiai10/copy v1.11.0 h1:OKBD80J/mLBrwnzXqGtFCzprFSGioo30JcmR4APsNwc=
-github.com/otiai10/copy v1.11.0/go.mod h1:rSaLseMUsZFFbsFGc7wCJnnkTAvdc5L6VWxPE4308Ww=
+github.com/onsi/gomega v1.27.7 h1:fVih9JD6ogIiHUN6ePK7HJidyEDpWGVB5mzM7cWNXoU=
+github.com/onsi/gomega v1.27.7/go.mod h1:1p8OOlwo2iUUDsHnOrjE5UKYJ+e3W8eQ3qSlRahPmr4=
github.com/otiai10/mint v1.5.1 h1:XaPLeE+9vGbuyEHem1JNk3bYc7KKqyI/na0/mLd/Kks=
+github.com/otiai10/mint v1.5.1/go.mod h1:MJm72SBthJjz8qhefc4z1PYEieWmy8Bku7CjcAqyUSM=
github.com/otiai10/openaigo v1.1.0 h1:zRvGBqZUW5PCMgdkJNsPVTBd8tOLCMTipXE5wD2pdTg=
github.com/otiai10/openaigo v1.1.0/go.mod h1:792bx6AWTS61weDi2EzKpHHnTF4eDMAlJ5GvAk/mgPg=
github.com/philhofer/fwd v1.1.1/go.mod h1:gk3iGcWd9+svBvR0sR+KPcfE+RNWozjowpeBVG3ZVNU=
@@ -137,8 +122,6 @@ github.com/rs/zerolog v1.29.1 h1:cO+d60CHkknCbvzEWxP0S9K6KqyTjrCNUy1LdQLCGPc=
github.com/rs/zerolog v1.29.1/go.mod h1:Le6ESbR7hc+DP6Lt1THiV8CQSdkkNrd3R0XbEgp3ZBU=
github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
-github.com/sashabaranov/go-openai v1.9.3 h1:uNak3Rn5pPsKRs9bdT7RqRZEyej/zdZOEI2/8wvrFtM=
-github.com/sashabaranov/go-openai v1.9.3/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/sashabaranov/go-openai v1.9.4 h1:KanoCEoowAI45jVXlenMCckutSRr39qOmSi9MyPBfZM=
github.com/sashabaranov/go-openai v1.9.4/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/savsgio/dictpool v0.0.0-20221023140959-7bf2e61cea94 h1:rmMl4fXJhKMNWl+K+r/fq4FbbKI+Ia2m9hYBLm2h4G4=
@@ -182,8 +165,6 @@ golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v
golang.org/x/net v0.0.0-20210421230115-4e50805a0758/go.mod h1:72T/g9IO56b78aLF+1Kcs5dz7/ng1VjMUvfKvpfy+jM=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.3.0/go.mod h1:MBQ8lrhLObU/6UmLb4fmbmk5OcyYmqtbGd/9yIeKjEE=
-golang.org/x/net v0.9.0 h1:aWJ/m6xSmxWBx+V0XRHTlrYrPG56jKsLdTFmsSsCzOM=
-golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns=
golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -221,8 +202,6 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn
golang.org/x/tools v0.0.0-20201022035929-9cf592e881e9/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.4.0/go.mod h1:UE5sM2OK9E/d67R0ANs2xJizIymRP5gJU295PvKXxjQ=
-golang.org/x/tools v0.8.0 h1:vSDcovVPld282ceKgDimkRSC8kpaH1dgyc9UMzlt84Y=
-golang.org/x/tools v0.8.0/go.mod h1:JxBZ99ISMI5ViVkT1tr6tdNmXeTrcpVSD3vZ1RsRdN4=
golang.org/x/tools v0.9.1 h1:8WMNJAz3zrtPmnYC7ISf5dEn3MT0gY7jBJfw27yrrLo=
golang.org/x/tools v0.9.1/go.mod h1:owI94Op576fPu3cIGQeHs3joujW/2Oc6MtlxbF5dfNc=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
diff --git a/main.go b/main.go
index 2490e198..f3ffc033 100644
--- a/main.go
+++ b/main.go
@@ -1,6 +1,7 @@
package main
import (
+ "context"
"fmt"
"os"
"path/filepath"
@@ -57,9 +58,9 @@ func main() {
Value: ":8080",
},
&cli.StringFlag{
- Name: "image-dir",
+ Name: "image-path",
DefaultText: "Image directory",
- EnvVars: []string{"IMAGE_DIR"},
+ EnvVars: []string{"IMAGE_PATH"},
Value: "",
},
&cli.IntFlag{
@@ -93,7 +94,7 @@ It uses llama.cpp, ggml and gpt4all as backend with golang c bindings.
Copyright: "go-skynet authors",
Action: func(ctx *cli.Context) error {
fmt.Printf("Starting LocalAI using %d threads, with models path: %s\n", ctx.Int("threads"), ctx.String("models-path"))
- return api.App(ctx.String("config-file"), model.NewModelLoader(ctx.String("models-path")), ctx.Int("upload-limit"), ctx.Int("threads"), ctx.Int("context-size"), ctx.Bool("f16"), ctx.Bool("debug"), false, ctx.String("image-dir")).Listen(ctx.String("address"))
+ return api.App(context.Background(), ctx.String("config-file"), model.NewModelLoader(ctx.String("models-path")), ctx.Int("upload-limit"), ctx.Int("threads"), ctx.Int("context-size"), ctx.Bool("f16"), ctx.Bool("debug"), false, ctx.String("image-path")).Listen(ctx.String("address"))
},
}
diff --git a/pkg/gallery/gallery_suite_test.go b/pkg/gallery/gallery_suite_test.go
new file mode 100644
index 00000000..44256bc2
--- /dev/null
+++ b/pkg/gallery/gallery_suite_test.go
@@ -0,0 +1,13 @@
+package gallery_test
+
+import (
+ "testing"
+
+ . "github.com/onsi/ginkgo/v2"
+ . "github.com/onsi/gomega"
+)
+
+func TestGallery(t *testing.T) {
+ RegisterFailHandler(Fail)
+ RunSpecs(t, "Gallery test suite")
+}
diff --git a/pkg/gallery/models.go b/pkg/gallery/models.go
new file mode 100644
index 00000000..f4f86ae7
--- /dev/null
+++ b/pkg/gallery/models.go
@@ -0,0 +1,271 @@
+package gallery
+
+import (
+ "crypto/sha256"
+ "fmt"
+ "io"
+ "net/http"
+ "os"
+ "path/filepath"
+
+ "github.com/imdario/mergo"
+ "github.com/rs/zerolog/log"
+ "gopkg.in/yaml.v2"
+)
+
+/*
+
+description: |
+ foo
+license: ""
+
+urls:
+-
+-
+
+name: "bar"
+
+config_file: |
+  # Note: the name will be injected, or generated from the alias chosen by the user
+ threads: 14
+
+files:
+ - filename: ""
+    sha256: ""
+ uri: ""
+
+prompt_templates:
+ - name: ""
+ content: ""
+
+*/
+
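+// Config describes a gallery model: its metadata, the files to download,
+// the prompt templates to write, and the base model YAML to render.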
+type Config struct {
+ Description string `yaml:"description"`
+ License string `yaml:"license"`
+ URLs []string `yaml:"urls"`
+ Name string `yaml:"name"`
+ ConfigFile string `yaml:"config_file"`
+ Files []File `yaml:"files"`
+ PromptTemplates []PromptTemplate `yaml:"prompt_templates"`
+}
+
+type File struct {
+ Filename string `yaml:"filename" json:"filename"`
+ SHA256 string `yaml:"sha256" json:"sha256"`
+ URI string `yaml:"uri" json:"uri"`
+}
+
+type PromptTemplate struct {
+ Name string `yaml:"name"`
+ Content string `yaml:"content"`
+}
+
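+// ReadConfigFile loads a gallery model definition from a YAML file on disk.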
+func ReadConfigFile(filePath string) (*Config, error) {
+ // Read the YAML file
+ yamlFile, err := os.ReadFile(filePath)
+ if err != nil {
+ return nil, fmt.Errorf("failed to read YAML file: %v", err)
+ }
+
+ // Unmarshal YAML data into a Config struct
+ var config Config
+ err = yaml.Unmarshal(yamlFile, &config)
+ if err != nil {
+ return nil, fmt.Errorf("failed to unmarshal YAML: %v", err)
+ }
+
+ return &config, nil
+}
+
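+// inTrustedRoot reports whether path lies within trustedRoot by walking up
+// its chain of parent directories.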
+func inTrustedRoot(path string, trustedRoot string) error {
+ for path != "/" {
+ path = filepath.Dir(path)
+ if path == trustedRoot {
+ return nil
+ }
+ }
+ return fmt.Errorf("path is outside of trusted root")
+}
+
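+// verifyPath joins path onto basePath, cleans the result, and rejects paths
+// that escape basePath (e.g. via "../" traversal).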
+func verifyPath(path, basePath string) error {
+ c := filepath.Clean(filepath.Join(basePath, path))
+ return inTrustedRoot(c, basePath)
+}
+
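+// Apply downloads the files listed in config into basePath, verifying their
+// SHA256 checksums when provided, writes the prompt templates, and renders
+// the final model YAML, optionally renamed via nameOverride and merged with
+// configOverrides (overrides take precedence).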
+func Apply(basePath, nameOverride string, config *Config, configOverrides map[string]interface{}) error {
+ // Create base path if it doesn't exist
+ err := os.MkdirAll(basePath, 0755)
+ if err != nil {
+ return fmt.Errorf("failed to create base path: %v", err)
+ }
+
+ if len(configOverrides) > 0 {
+ log.Debug().Msgf("Config overrides %+v", configOverrides)
+ }
+
+ // Download files and verify their SHA
+ for _, file := range config.Files {
+ log.Debug().Msgf("Checking %q exists and matches SHA", file.Filename)
+
+ if err := verifyPath(file.Filename, basePath); err != nil {
+ return err
+ }
+ // Create file path
+ filePath := filepath.Join(basePath, file.Filename)
+
+ // Check if the file already exists
+ _, err := os.Stat(filePath)
+ if err == nil {
+ // File exists, check SHA
+ if file.SHA256 != "" {
+ // Verify SHA
+ calculatedSHA, err := calculateSHA(filePath)
+ if err != nil {
+ return fmt.Errorf("failed to calculate SHA for file %q: %v", file.Filename, err)
+ }
+ if calculatedSHA == file.SHA256 {
+ // SHA matches, skip downloading
+ log.Debug().Msgf("File %q already exists and matches the SHA. Skipping download", file.Filename)
+ continue
+ }
+ // SHA doesn't match, delete the file and download again
+ err = os.Remove(filePath)
+ if err != nil {
+ return fmt.Errorf("failed to remove existing file %q: %v", file.Filename, err)
+ }
+ log.Debug().Msgf("Removed %q (SHA doesn't match)", filePath)
+
+ } else {
+ // SHA is missing, skip downloading
+ log.Debug().Msgf("File %q already exists. Skipping download", file.Filename)
+ continue
+ }
+ } else if !os.IsNotExist(err) {
+ // Error occurred while checking file existence
+ return fmt.Errorf("failed to check file %q existence: %v", file.Filename, err)
+ }
+
+ log.Debug().Msgf("Downloading %q", file.URI)
+
+ // Download file
+ resp, err := http.Get(file.URI)
+ if err != nil {
+ return fmt.Errorf("failed to download file %q: %v", file.Filename, err)
+ }
+ defer resp.Body.Close()
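+		// These deferred Close calls run when Apply returns, not at the end of
+		// each loop iteration.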
+
+ // Create parent directory
+ err = os.MkdirAll(filepath.Dir(filePath), 0755)
+ if err != nil {
+ return fmt.Errorf("failed to create parent directory for file %q: %v", file.Filename, err)
+ }
+
+ // Create and write file content
+ outFile, err := os.Create(filePath)
+ if err != nil {
+ return fmt.Errorf("failed to create file %q: %v", file.Filename, err)
+ }
+ defer outFile.Close()
+
+ if file.SHA256 != "" {
+ log.Debug().Msgf("Download and verifying %q", file.Filename)
+
+ // Write file content and calculate SHA
+ hash := sha256.New()
+ _, err = io.Copy(io.MultiWriter(outFile, hash), resp.Body)
+ if err != nil {
+ return fmt.Errorf("failed to write file %q: %v", file.Filename, err)
+ }
+
+ // Verify SHA
+ calculatedSHA := fmt.Sprintf("%x", hash.Sum(nil))
+ if calculatedSHA != file.SHA256 {
+ return fmt.Errorf("SHA mismatch for file %q ( calculated: %s != metadata: %s )", file.Filename, calculatedSHA, file.SHA256)
+ }
+ } else {
+ log.Debug().Msgf("SHA missing for %q. Skipping validation", file.Filename)
+ _, err = io.Copy(outFile, resp.Body)
+ if err != nil {
+ return fmt.Errorf("failed to write file %q: %v", file.Filename, err)
+ }
+ }
+
+ log.Debug().Msgf("File %q downloaded and verified", file.Filename)
+ }
+
+ // Write prompt template contents to separate files
+ for _, template := range config.PromptTemplates {
+ if err := verifyPath(template.Name+".tmpl", basePath); err != nil {
+ return err
+ }
+ // Create file path
+ filePath := filepath.Join(basePath, template.Name+".tmpl")
+
+ // Create parent directory
+ err := os.MkdirAll(filepath.Dir(filePath), 0755)
+ if err != nil {
+ return fmt.Errorf("failed to create parent directory for prompt template %q: %v", template.Name, err)
+ }
+ // Create and write file content
+ err = os.WriteFile(filePath, []byte(template.Content), 0644)
+ if err != nil {
+ return fmt.Errorf("failed to write prompt template %q: %v", template.Name, err)
+ }
+
+ log.Debug().Msgf("Prompt template %q written", template.Name)
+ }
+
+ name := config.Name
+ if nameOverride != "" {
+ name = nameOverride
+ }
+
+ if err := verifyPath(name+".yaml", basePath); err != nil {
+ return err
+ }
+
+ configFilePath := filepath.Join(basePath, name+".yaml")
+
+ // Read and update config file as map[string]interface{}
+ configMap := make(map[string]interface{})
+ err = yaml.Unmarshal([]byte(config.ConfigFile), &configMap)
+ if err != nil {
+ return fmt.Errorf("failed to unmarshal config YAML: %v", err)
+ }
+
+ configMap["name"] = name
+
+ if err := mergo.Merge(&configMap, configOverrides, mergo.WithOverride); err != nil {
+ return err
+ }
+
+ // Write updated config file
+ updatedConfigYAML, err := yaml.Marshal(configMap)
+ if err != nil {
+ return fmt.Errorf("failed to marshal updated config YAML: %v", err)
+ }
+
+ err = os.WriteFile(configFilePath, updatedConfigYAML, 0644)
+ if err != nil {
+ return fmt.Errorf("failed to write updated config file: %v", err)
+ }
+
+ log.Debug().Msgf("Written config file %s", configFilePath)
+ return nil
+}
+
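+// calculateSHA streams the file through SHA-256 and returns the hex-encoded
+// digest.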
+func calculateSHA(filePath string) (string, error) {
+ file, err := os.Open(filePath)
+ if err != nil {
+ return "", err
+ }
+ defer file.Close()
+
+ hash := sha256.New()
+ if _, err := io.Copy(hash, file); err != nil {
+ return "", err
+ }
+
+ return fmt.Sprintf("%x", hash.Sum(nil)), nil
+}
diff --git a/pkg/gallery/models_test.go b/pkg/gallery/models_test.go
new file mode 100644
index 00000000..f0e580e9
--- /dev/null
+++ b/pkg/gallery/models_test.go
@@ -0,0 +1,94 @@
+package gallery_test
+
+import (
+ "os"
+ "path/filepath"
+
+ . "github.com/go-skynet/LocalAI/pkg/gallery"
+ . "github.com/onsi/ginkgo/v2"
+ . "github.com/onsi/gomega"
+ "gopkg.in/yaml.v3"
+)
+
+var _ = Describe("Model test", func() {
+ Context("Downloading", func() {
+ It("applies model correctly", func() {
+ tempdir, err := os.MkdirTemp("", "test")
+ Expect(err).ToNot(HaveOccurred())
+ defer os.RemoveAll(tempdir)
+ c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
+ Expect(err).ToNot(HaveOccurred())
+
+ err = Apply(tempdir, "", c, map[string]interface{}{})
+ Expect(err).ToNot(HaveOccurred())
+
+ for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "cerebras.yaml"} {
+ _, err = os.Stat(filepath.Join(tempdir, f))
+ Expect(err).ToNot(HaveOccurred())
+ }
+
+ content := map[string]interface{}{}
+
+ dat, err := os.ReadFile(filepath.Join(tempdir, "cerebras.yaml"))
+ Expect(err).ToNot(HaveOccurred())
+
+ err = yaml.Unmarshal(dat, content)
+ Expect(err).ToNot(HaveOccurred())
+
+ Expect(content["context_size"]).To(Equal(1024))
+ })
+
+ It("renames model correctly", func() {
+ tempdir, err := os.MkdirTemp("", "test")
+ Expect(err).ToNot(HaveOccurred())
+ defer os.RemoveAll(tempdir)
+ c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
+ Expect(err).ToNot(HaveOccurred())
+
+ err = Apply(tempdir, "foo", c, map[string]interface{}{})
+ Expect(err).ToNot(HaveOccurred())
+
+ for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} {
+ _, err = os.Stat(filepath.Join(tempdir, f))
+ Expect(err).ToNot(HaveOccurred())
+ }
+ })
+
+ It("overrides parameters", func() {
+ tempdir, err := os.MkdirTemp("", "test")
+ Expect(err).ToNot(HaveOccurred())
+ defer os.RemoveAll(tempdir)
+ c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
+ Expect(err).ToNot(HaveOccurred())
+
+ err = Apply(tempdir, "foo", c, map[string]interface{}{"backend": "foo"})
+ Expect(err).ToNot(HaveOccurred())
+
+ for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} {
+ _, err = os.Stat(filepath.Join(tempdir, f))
+ Expect(err).ToNot(HaveOccurred())
+ }
+
+ content := map[string]interface{}{}
+
+ dat, err := os.ReadFile(filepath.Join(tempdir, "foo.yaml"))
+ Expect(err).ToNot(HaveOccurred())
+
+ err = yaml.Unmarshal(dat, content)
+ Expect(err).ToNot(HaveOccurred())
+
+ Expect(content["backend"]).To(Equal("foo"))
+ })
+
+ It("catches path traversals", func() {
+ tempdir, err := os.MkdirTemp("", "test")
+ Expect(err).ToNot(HaveOccurred())
+ defer os.RemoveAll(tempdir)
+ c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
+ Expect(err).ToNot(HaveOccurred())
+
+ err = Apply(tempdir, "../../../foo", c, map[string]interface{}{})
+ Expect(err).To(HaveOccurred())
+ })
+ })
+})
diff --git a/pkg/model/initializers.go b/pkg/model/initializers.go
index 74c05f29..b5e43a38 100644
--- a/pkg/model/initializers.go
+++ b/pkg/model/initializers.go
@@ -164,11 +164,12 @@ func (ml *ModelLoader) BackendLoader(backendString string, modelFile string, lla
}
func (ml *ModelLoader) GreedyLoader(modelFile string, llamaOpts []llama.ModelOption, threads uint32) (interface{}, error) {
- log.Debug().Msgf("Loading models greedly")
+ log.Debug().Msgf("Loading model '%s' greedly", modelFile)
ml.mu.Lock()
m, exists := ml.models[modelFile]
if exists {
+ log.Debug().Msgf("Model '%s' already loaded", modelFile)
ml.mu.Unlock()
return m, nil
}
diff --git a/tests/fixtures/gallery_simple.yaml b/tests/fixtures/gallery_simple.yaml
new file mode 100644
index 00000000..058733fe
--- /dev/null
+++ b/tests/fixtures/gallery_simple.yaml
@@ -0,0 +1,40 @@
+name: "cerebras"
+description: |
+ cerebras
+license: "Apache 2.0"
+
+config_file: |
+ parameters:
+ model: cerebras
+ top_k: 80
+ temperature: 0.2
+ top_p: 0.7
+ context_size: 1024
+ stopwords:
+ - "HUMAN:"
+ - "GPT:"
+ roles:
+ user: ""
+ system: ""
+ template:
+ completion: "cerebras-completion"
+ chat: cerebras-chat
+
+files:
+ - filename: "cerebras"
+ sha256: "c947051ae4dba9530ca55d923a7a484acd65664c8633462c8ccd4bb7848f2c65"
+ uri: "https://huggingface.co/concedo/cerebras-111M-ggml/resolve/main/cerebras-111m-q4_2.bin"
+
+prompt_templates:
+ - name: "cerebras-completion"
+ content: |
+ Complete the prompt
+ ### Prompt:
+ {{.Input}}
+ ### Response:
+ - name: "cerebras-chat"
+ content: |
+ The prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response.
+ ### Prompt:
+ {{.Input}}
+ ### Response:
\ No newline at end of file
diff --git a/tests/fixtures/completion.tmpl b/tests/models_fixtures/completion.tmpl
similarity index 100%
rename from tests/fixtures/completion.tmpl
rename to tests/models_fixtures/completion.tmpl
diff --git a/tests/fixtures/config.yaml b/tests/models_fixtures/config.yaml
similarity index 100%
rename from tests/fixtures/config.yaml
rename to tests/models_fixtures/config.yaml
diff --git a/tests/fixtures/embeddings.yaml b/tests/models_fixtures/embeddings.yaml
similarity index 100%
rename from tests/fixtures/embeddings.yaml
rename to tests/models_fixtures/embeddings.yaml
diff --git a/tests/fixtures/ggml-gpt4all-j.tmpl b/tests/models_fixtures/ggml-gpt4all-j.tmpl
similarity index 100%
rename from tests/fixtures/ggml-gpt4all-j.tmpl
rename to tests/models_fixtures/ggml-gpt4all-j.tmpl
diff --git a/tests/fixtures/gpt4.yaml b/tests/models_fixtures/gpt4.yaml
similarity index 100%
rename from tests/fixtures/gpt4.yaml
rename to tests/models_fixtures/gpt4.yaml
diff --git a/tests/fixtures/gpt4_2.yaml b/tests/models_fixtures/gpt4_2.yaml
similarity index 100%
rename from tests/fixtures/gpt4_2.yaml
rename to tests/models_fixtures/gpt4_2.yaml
diff --git a/tests/fixtures/rwkv.yaml b/tests/models_fixtures/rwkv.yaml
similarity index 100%
rename from tests/fixtures/rwkv.yaml
rename to tests/models_fixtures/rwkv.yaml
diff --git a/tests/fixtures/rwkv_chat.tmpl b/tests/models_fixtures/rwkv_chat.tmpl
similarity index 100%
rename from tests/fixtures/rwkv_chat.tmpl
rename to tests/models_fixtures/rwkv_chat.tmpl
diff --git a/tests/fixtures/rwkv_completion.tmpl b/tests/models_fixtures/rwkv_completion.tmpl
similarity index 100%
rename from tests/fixtures/rwkv_completion.tmpl
rename to tests/models_fixtures/rwkv_completion.tmpl
diff --git a/tests/fixtures/whisper.yaml b/tests/models_fixtures/whisper.yaml
similarity index 100%
rename from tests/fixtures/whisper.yaml
rename to tests/models_fixtures/whisper.yaml