diff --git a/.env b/.env index b05dac68..73e3174d 100644 --- a/.env +++ b/.env @@ -1,5 +1,30 @@ +## Set number of threads. +## Note: prefer the number of physical cores. Overbooking the CPU degrades performance notably. # THREADS=14 + +## Specify a different bind address (defaults to ":8080") +# ADDRESS=127.0.0.1:8080 + +## Default models context size # CONTEXT_SIZE=512 + +## Default path for models MODELS_PATH=/models + +## Enable debug mode # DEBUG=true -# BUILD_TYPE=generic + +## Specify a build type. Available: cublas, openblas. +# BUILD_TYPE=openblas + +## Uncomment and set to false to disable rebuilding from source +# REBUILD=false + +## Enable image generation with stablediffusion (requires REBUILD=true) +# GO_TAGS=stablediffusion + +## Path where to store generated images +# IMAGE_PATH=/tmp + +## Specify a default upload limit in MB (whisper) +# UPLOAD_LIMIT \ No newline at end of file diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md new file mode 100644 index 00000000..a7f77221 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -0,0 +1,31 @@ +--- +name: Bug report +about: Create a report to help us improve +title: '' +labels: bug +assignees: mudler + +--- + + + +**LocalAI version:** + + +**Environment, CPU architecture, OS, and Version:** + + +**Describe the bug** + + +**To Reproduce** + + +**Expected behavior** + + +**Logs** + + +**Additional context** + diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 00000000..acc65c80 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1,8 @@ +blank_issues_enabled: false +contact_links: + - name: Community Support + url: https://github.com/go-skynet/LocalAI/discussions + about: Please ask and answer questions here. + - name: Discord + url: https://discord.gg/uJAeKSAGDy + about: Join our community on Discord! 
diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md new file mode 100644 index 00000000..c184aae9 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -0,0 +1,22 @@ +--- +name: Feature request +about: Suggest an idea for this project +title: '' +labels: enhancement +assignees: mudler + +--- + + + +**Is your feature request related to a problem? Please describe.** + + +**Describe the solution you'd like** + + +**Describe alternatives you've considered** + + +**Additional context** + diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 00000000..2318ad47 --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,23 @@ +**Description** + +This PR fixes # + +**Notes for Reviewers** + + +**[Signed commits](../CONTRIBUTING.md#signing-off-on-commits-developer-certificate-of-origin)** +- [ ] Yes, I signed my commits. + + + \ No newline at end of file diff --git a/.github/stale.yml b/.github/stale.yml new file mode 100644 index 00000000..af48bade --- /dev/null +++ b/.github/stale.yml @@ -0,0 +1,18 @@ +# Number of days of inactivity before an issue becomes stale +daysUntilStale: 45 +# Number of days of inactivity before a stale issue is closed +daysUntilClose: 10 +# Issues with these labels will never be considered stale +exemptLabels: + - issue/willfix +# Label to use when marking an issue as stale +staleLabel: issue/stale +# Comment to post when marking an issue as stale. Set to `false` to disable +markComment: > + This issue has been automatically marked as stale because it has not had + recent activity. It will be closed if no further activity occurs. Thank you + for your contributions. +# Comment to post when closing a stale issue. Set to `false` to disable +closeComment: > + This issue is being automatically closed due to inactivity. + However, you may choose to reopen this issue. 
\ No newline at end of file diff --git a/.github/workflows/image.yml b/.github/workflows/image.yml index d83f58d3..eeada322 100644 --- a/.github/workflows/image.yml +++ b/.github/workflows/image.yml @@ -9,6 +9,10 @@ on: tags: - '*' +concurrency: + group: ci-${{ github.head_ref || github.ref }}-${{ github.repository }} + cancel-in-progress: true + jobs: docker: runs-on: ubuntu-latest diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 17e3c809..48fe2cf8 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -9,6 +9,10 @@ on: tags: - '*' +concurrency: + group: ci-tests-${{ github.head_ref || github.ref }}-${{ github.repository }} + cancel-in-progress: true + jobs: ubuntu-latest: runs-on: ubuntu-latest diff --git a/.vscode/launch.json b/.vscode/launch.json index e8d94825..cf4fb924 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -2,7 +2,20 @@ "version": "0.2.0", "configurations": [ { - "name": "Launch Go", + "name": "Python: Current File", + "type": "python", + "request": "launch", + "program": "${file}", + "console": "integratedTerminal", + "justMyCode": false, + "cwd": "${workspaceFolder}/examples/langchain-chroma", + "env": { + "OPENAI_API_BASE": "http://localhost:8080/v1", + "OPENAI_API_KEY": "abc" + } + }, + { + "name": "Launch LocalAI API", "type": "go", "request": "launch", "mode": "debug", @@ -11,8 +24,8 @@ "api" ], "env": { - "C_INCLUDE_PATH": "/workspace/go-llama:/workspace/go-gpt4all-j:/workspace/go-gpt2", - "LIBRARY_PATH": "/workspace/go-llama:/workspace/go-gpt4all-j:/workspace/go-gpt2", + "C_INCLUDE_PATH": "${workspaceFolder}/go-llama:${workspaceFolder}/go-stable-diffusion/:${workspaceFolder}/gpt4all/gpt4all-bindings/golang/:${workspaceFolder}/go-gpt2:${workspaceFolder}/go-rwkv:${workspaceFolder}/whisper.cpp:${workspaceFolder}/go-bert:${workspaceFolder}/bloomz", + "LIBRARY_PATH": 
"${workspaceFolder}/go-llama:${workspaceFolder}/go-stable-diffusion/:${workspaceFolder}/gpt4all/gpt4all-bindings/golang/:${workspaceFolder}/go-gpt2:${workspaceFolder}/go-rwkv:${workspaceFolder}/whisper.cpp:${workspaceFolder}/go-bert:${workspaceFolder}/bloomz", "DEBUG": "true" } } diff --git a/Dockerfile b/Dockerfile index 52869bb9..27ab3800 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,10 +1,11 @@ ARG GO_VERSION=1.20 ARG BUILD_TYPE= FROM golang:$GO_VERSION +ENV REBUILD=true WORKDIR /build RUN apt-get update && apt-get install -y cmake libgomp1 libopenblas-dev libopenblas-base libopencv-dev libopencv-core-dev libopencv-core4.5 COPY . . RUN ln -s /usr/include/opencv4/opencv2/ /usr/include/opencv2 -RUN make prepare-sources +RUN make build EXPOSE 8080 ENTRYPOINT [ "/build/entrypoint.sh" ] diff --git a/Makefile b/Makefile index 523f1523..393e7ce3 100644 --- a/Makefile +++ b/Makefile @@ -3,20 +3,19 @@ GOTEST=$(GOCMD) test GOVET=$(GOCMD) vet BINARY_NAME=local-ai -GOLLAMA_VERSION?=b7bbefbe0b84262e003387a605842bdd0d099300 +GOLLAMA_VERSION?=ccf23adfb278c0165d388389a5d60f3fe38e4854 GPT4ALL_REPO?=https://github.com/nomic-ai/gpt4all -GPT4ALL_VERSION?=bce2b3025b360af73091da0128b1e91f9bc94f9f +GPT4ALL_VERSION?=914519e772fd78c15691dcd0b8bac60d6af514ec GOGPT2_VERSION?=7bff56f0224502c1c9ed6258d2a17e8084628827 RWKV_REPO?=https://github.com/donomii/go-rwkv.cpp RWKV_VERSION?=07166da10cb2a9e8854395a4f210464dcea76e47 -WHISPER_CPP_VERSION?=95b02d76b04d18e4ce37ed8353a1f0797f1717ea +WHISPER_CPP_VERSION?=041be06d5881d3c759cc4ed45d655804361237cd BERT_VERSION?=cea1ed76a7f48ef386a8e369f6c82c48cdf2d551 BLOOMZ_VERSION?=e9366e82abdfe70565644fbfae9651976714efd1 BUILD_TYPE?= CGO_LDFLAGS?= CUDA_LIBPATH?=/usr/local/cuda/lib64/ STABLEDIFFUSION_VERSION?=c0748eca3642d58bcf9521108bcee46959c647dc - GO_TAGS?= OPTIONAL_TARGETS?= @@ -36,9 +35,9 @@ endif ifeq ($(BUILD_TYPE),cublas) CGO_LDFLAGS+=-lcublas -lcudart -L$(CUDA_LIBPATH) + export LLAMA_CUBLAS=1 endif - ifeq ($(GO_TAGS),stablediffusion) 
OPTIONAL_TARGETS+=go-stable-diffusion/libstablediffusion.a endif @@ -66,6 +65,7 @@ gpt4all: @find ./gpt4all -type f -name "*.cpp" -exec sed -i'' -e 's/json_/json_gptj_/g' {} + @find ./gpt4all -type f -name "*.cpp" -exec sed -i'' -e 's/void replace/void json_gptj_replace/g' {} + @find ./gpt4all -type f -name "*.cpp" -exec sed -i'' -e 's/::replace/::json_gptj_replace/g' {} + + @find ./gpt4all -type f -name "*.cpp" -exec sed -i'' -e 's/regex_escape/gpt4allregex_escape/g' {} + mv ./gpt4all/gpt4all-backend/llama.cpp/llama_util.h ./gpt4all/gpt4all-backend/llama.cpp/gptjllama_util.h ## BERT embeddings @@ -211,11 +211,11 @@ test-models/testmodel: wget https://cdn.openai.com/whisper/draft-20220913a/micro-machines.wav -O test-dir/audio.wav wget https://huggingface.co/imxcstar/rwkv-4-raven-ggml/resolve/main/RWKV-4-Raven-1B5-v11-Eng99%25-Other1%25-20230425-ctx4096-16_Q4_2.bin -O test-models/rwkv wget https://raw.githubusercontent.com/saharNooby/rwkv.cpp/5eb8f09c146ea8124633ab041d9ea0b1f1db4459/rwkv/20B_tokenizer.json -O test-models/rwkv.tokenizer.json - cp tests/fixtures/* test-models + cp tests/models_fixtures/* test-models test: prepare test-models/testmodel - cp tests/fixtures/* test-models - @C_INCLUDE_PATH=${C_INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} TEST_DIR=$(abspath ./)/test-dir/ CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo -v -r ./api + cp tests/models_fixtures/* test-models + C_INCLUDE_PATH=${C_INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo -v -r ./api ./pkg ## Help: help: ## Show this help. 
diff --git a/README.md b/README.md index f07a0796..47996dcd 100644 --- a/README.md +++ b/README.md @@ -9,31 +9,38 @@ [![](https://dcbadge.vercel.app/api/server/uJAeKSAGDy?style=flat-square&theme=default-inverted)](https://discord.gg/uJAeKSAGDy) -**LocalAI** is a drop-in replacement REST API compatible with OpenAI API specifications for local inferencing. It allows to run models locally or on-prem with consumer grade hardware, supporting multiple models families compatible with the `ggml` format. For a list of the supported model families, see [the model compatibility table below](https://github.com/go-skynet/LocalAI#model-compatibility-table). +**LocalAI** is a drop-in replacement REST API that's compatible with OpenAI API specifications for local inferencing. It allows you to run models locally or on-prem with consumer grade hardware, supporting multiple model families that are compatible with the ggml format. -- OpenAI drop-in alternative REST API +For a list of the supported model families, please see [the model compatibility table below](https://github.com/go-skynet/LocalAI#model-compatibility-table). + +In a nutshell: + +- Local, OpenAI drop-in alternative REST API. You own your data. +- NO GPU required. NO Internet access is required either. Optional, GPU Acceleration is available in `llama.cpp`-compatible LLMs. [See building instructions](https://github.com/go-skynet/LocalAI#cublas). - Supports multiple models, Audio transcription, Text generation with GPTs, Image generation with stable diffusion (experimental) - Once loaded the first time, it keep models loaded in memory for faster inference -- Support for prompt templates - Doesn't shell-out, but uses C++ bindings for a faster inference and better performance. LocalAI is a community-driven project, focused on making the AI accessible to anyone. Any contribution, feedback and PR is welcome! 
It was initially created by [mudler](https://github.com/mudler/) at the [SpectroCloud OSS Office](https://github.com/spectrocloud). -LocalAI uses C++ bindings for optimizing speed. It is based on [llama.cpp](https://github.com/ggerganov/llama.cpp), [gpt4all](https://github.com/nomic-ai/gpt4all), [rwkv.cpp](https://github.com/saharNooby/rwkv.cpp), [ggml](https://github.com/ggerganov/ggml), [whisper.cpp](https://github.com/ggerganov/whisper.cpp) for audio transcriptions, and [bert.cpp](https://github.com/skeskinen/bert.cpp) for embedding. - -See [examples on how to integrate LocalAI](https://github.com/go-skynet/LocalAI/tree/master/examples/). - +See the [usage](https://github.com/go-skynet/LocalAI#usage) and [examples](https://github.com/go-skynet/LocalAI/tree/master/examples/) sections to learn how to use LocalAI. ### How does it work?
 - + +LocalAI is an API written in Go that serves as an OpenAI shim, enabling software already developed with OpenAI SDKs to seamlessly integrate with LocalAI. It can be effortlessly implemented as a substitute, even on consumer-grade hardware. This capability is achieved by employing various C++ backends, including [ggml](https://github.com/ggerganov/ggml), to perform inference on LLMs using both CPU and, if desired, GPU. + +LocalAI uses C++ bindings for optimizing speed. It is based on [llama.cpp](https://github.com/ggerganov/llama.cpp), [gpt4all](https://github.com/nomic-ai/gpt4all), [rwkv.cpp](https://github.com/saharNooby/rwkv.cpp), [ggml](https://github.com/ggerganov/ggml), [whisper.cpp](https://github.com/ggerganov/whisper.cpp) for audio transcriptions, [bert.cpp](https://github.com/skeskinen/bert.cpp) for embedding and [Stable-Diffusion-NCNN](https://github.com/EdVince/Stable-Diffusion-NCNN) for image generation. See [the model compatibility table](https://github.com/go-skynet/LocalAI#model-compatibility-table) to learn about all the components of LocalAI. + ![LocalAI](https://github.com/go-skynet/LocalAI/assets/2420543/38de3a9b-3866-48cd-9234-662f9571064a)
 ## News +- 21-05-2023: __v1.14.0__ released. Minor updates to the `/models/apply` endpoint, `llama.cpp` backend updated including https://github.com/ggerganov/llama.cpp/pull/1508 which breaks compatibility with older models. `gpt4all` is still compatible with the old format. +- 19-05-2023: __v1.13.0__ released! 🔥🔥 updates to the `gpt4all` and `llama` backend, consolidated CUDA support ( https://github.com/go-skynet/LocalAI/pull/310 thanks to @bubthegreat and @Thireus ), preliminary support for [installing models via API](https://github.com/go-skynet/LocalAI#advanced-prepare-models-using-the-api). - 17-05-2023: __v1.12.0__ released! 🔥🔥 Minor fixes, plus CUDA (https://github.com/go-skynet/LocalAI/pull/258) support for `llama.cpp`-compatible models and image generation (https://github.com/go-skynet/LocalAI/pull/272). - 16-05-2023: 🔥🔥🔥 Experimental support for CUDA (https://github.com/go-skynet/LocalAI/pull/258) in the `llama.cpp` backend and Stable diffusion CPU image generation (https://github.com/go-skynet/LocalAI/pull/272) in `master`. @@ -116,31 +123,31 @@ Depending on the model you are attempting to run might need more RAM or CPU reso
-| Backend | Compatible models | Completion/Chat endpoint | Audio transcription/Image | Embeddings support | Token stream support | Github | Bindings | -|-----------------|-----------------------|--------------------------|---------------------|-----------------------------------|----------------------|--------------------------------------------|-------------------------------------------| -| llama | Vicuna, Alpaca, LLaMa | yes | no | yes (doesn't seem to be accurate) | yes | https://github.com/ggerganov/llama.cpp | https://github.com/go-skynet/go-llama.cpp | -| gpt4all-llama | Vicuna, Alpaca, LLaMa | yes | no | no | yes | https://github.com/nomic-ai/gpt4all | https://github.com/go-skynet/gpt4all | -| gpt4all-mpt | MPT | yes | no | no | yes | https://github.com/nomic-ai/gpt4all | https://github.com/go-skynet/gpt4all | -| gpt4all-j | GPT4ALL-J | yes | no | no | yes | https://github.com/nomic-ai/gpt4all | https://github.com/go-skynet/gpt4all | -| gpt2 | GPT/NeoX, Cerebras | yes | no | no | no | https://github.com/ggerganov/ggml | https://github.com/go-skynet/go-gpt2.cpp | -| dolly | Dolly | yes | no | no | no | https://github.com/ggerganov/ggml | https://github.com/go-skynet/go-gpt2.cpp | -| redpajama | RedPajama | yes | no | no | no | https://github.com/ggerganov/ggml | https://github.com/go-skynet/go-gpt2.cpp | -| stableLM | StableLM GPT/NeoX | yes | no | no | no | https://github.com/ggerganov/ggml | https://github.com/go-skynet/go-gpt2.cpp | -| replit | Replit | yes | no | no | no | https://github.com/ggerganov/ggml | https://github.com/go-skynet/go-gpt2.cpp | -| gptneox | GPT NeoX | yes | no | no | no | https://github.com/ggerganov/ggml | https://github.com/go-skynet/go-gpt2.cpp | -| starcoder | Starcoder | yes | no | no | no | https://github.com/ggerganov/ggml | https://github.com/go-skynet/go-gpt2.cpp | -| bloomz | Bloom | yes | no | no | no | https://github.com/NouamaneTazi/bloomz.cpp | https://github.com/go-skynet/bloomz.cpp | -| rwkv | RWKV | yes | no | no 
| yes | https://github.com/saharNooby/rwkv.cpp | https://github.com/donomii/go-rwkv.cpp | -| bert-embeddings | bert | no | no | yes | no | https://github.com/skeskinen/bert.cpp | https://github.com/go-skynet/go-bert.cpp | -| whisper | whisper | no | Audio | no | no | https://github.com/ggerganov/whisper.cpp | https://github.com/ggerganov/whisper.cpp | -| stablediffusion | stablediffusion | no | Image | no | no | https://github.com/EdVince/Stable-Diffusion-NCNN | https://github.com/mudler/go-stable-diffusion | +| Backend and Bindings | Compatible models | Completion/Chat endpoint | Audio transcription/Image | Embeddings support | Token stream support | +|----------------------------------------------------------------------------------|-----------------------|--------------------------|---------------------------|-----------------------------------|----------------------| +| [llama](https://github.com/ggerganov/llama.cpp) ([binding](https://github.com/go-skynet/go-llama.cpp)) | Vicuna, Alpaca, LLaMa | yes | no | yes (doesn't seem to be accurate) | yes | +| [gpt4all-llama](https://github.com/nomic-ai/gpt4all) | Vicuna, Alpaca, LLaMa | yes | no | no | yes | +| [gpt4all-mpt](https://github.com/nomic-ai/gpt4all) | MPT | yes | no | no | yes | +| [gpt4all-j](https://github.com/nomic-ai/gpt4all) | GPT4ALL-J | yes | no | no | yes | +| [gpt2](https://github.com/ggerganov/ggml) ([binding](https://github.com/go-skynet/go-gpt2.cpp)) | GPT/NeoX, Cerebras | yes | no | no | no | +| [dolly](https://github.com/ggerganov/ggml) ([binding](https://github.com/go-skynet/go-gpt2.cpp)) | Dolly | yes | no | no | no | +| [redpajama](https://github.com/ggerganov/ggml) ([binding](https://github.com/go-skynet/go-gpt2.cpp)) | RedPajama | yes | no | no | no | +| [stableLM](https://github.com/ggerganov/ggml) ([binding](https://github.com/go-skynet/go-gpt2.cpp)) | StableLM GPT/NeoX | yes | no | no | no | +| [replit](https://github.com/ggerganov/ggml) 
([binding](https://github.com/go-skynet/go-gpt2.cpp)) | Replit | yes | no | no | no | +| [gptneox](https://github.com/ggerganov/ggml) ([binding](https://github.com/go-skynet/go-gpt2.cpp)) | GPT NeoX | yes | no | no | no | +| [starcoder](https://github.com/ggerganov/ggml) ([binding](https://github.com/go-skynet/go-gpt2.cpp)) | Starcoder | yes | no | no | no | +| [bloomz](https://github.com/NouamaneTazi/bloomz.cpp) ([binding](https://github.com/go-skynet/bloomz.cpp)) | Bloom | yes | no | no | no | +| [rwkv](https://github.com/saharNooby/rwkv.cpp) ([binding](https://github.com/donomii/go-rwkv.cpp)) | rwkv | yes | no | no | yes | +| [bert](https://github.com/skeskinen/bert.cpp) ([binding](https://github.com/go-skynet/go-bert.cpp)) | bert | no | no | yes | no | +| [whisper](https://github.com/ggerganov/whisper.cpp) | whisper | no | Audio | no | no | +| [stablediffusion](https://github.com/EdVince/Stable-Diffusion-NCNN) ([binding](https://github.com/mudler/go-stable-diffusion)) | stablediffusion | no | Image | no | no |
 ## Usage > `LocalAI` comes by default as a container image. You can check out all the available images with corresponding tags [here](https://quay.io/repository/go-skynet/local-ai?tab=tags&tag=latest). -The easiest way to run LocalAI is by using `docker-compose`: +The easiest way to run LocalAI is by using `docker-compose` (to build locally, see [building LocalAI](https://github.com/go-skynet/LocalAI/tree/master#setup)): ```bash @@ -213,7 +220,25 @@ curl http://localhost:8080/v1/chat/completions -H "Content-Type: application/jso ``` -To build locally, run `make build` (see below). +### Advanced: prepare models using the API + +Instead of installing models manually, you can use the LocalAI API endpoints and a model definition to install models programmatically at runtime. + +
 + +A curated collection of model files is in the [model-gallery](https://github.com/go-skynet/model-gallery) (work in progress!). + +To install for example `gpt4all-j`, you can send a POST call to the `/models/apply` endpoint with the model definition url (`url`) and the name the model should have in LocalAI (`name`, optional): + +``` +curl http://localhost:8080/models/apply -H "Content-Type: application/json" -d '{ + "url": "https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml", + "name": "gpt4all-j" + }' +``` + +
+ ### Other examples @@ -434,7 +459,7 @@ local-ai --models-path [--address
] [--threads @@ -463,6 +488,8 @@ You should see: └───────────────────────────────────────────────────┘ ``` +Note: the binary inside the image is rebuilt at the start of the container to enable CPU optimizations for the execution environment. You can set the environment variable `REBUILD` to `false` to prevent this behavior. + ### Build locally @@ -567,6 +594,8 @@ Note: CuBLAS support is experimental, and has not been tested on real HW. please make BUILD_TYPE=cublas build ``` +More information is available in the upstream PR: https://github.com/ggerganov/llama.cpp/pull/1412 + ### Windows compatibility @@ -818,6 +847,109 @@ models +## LocalAI API endpoints + +Besides the OpenAI endpoints, there are additional LocalAI-only API endpoints. + +### Applying a model - `/models/apply` + +This endpoint can be used to install a model at runtime. + +
 + +LocalAI will create a batch process that downloads the required files from a model definition and automatically reloads itself to include the new model. + +Input: `url`, `name` (optional), `files` (optional), `overrides` (optional) + +```bash +curl http://localhost:8080/models/apply -H "Content-Type: application/json" -d '{ + "url": "", + "name": "", + "files": [ + { + "uri": "", + "sha256": "", + "filename": "" + } + ], + "overrides": { "backend": "...", "f16": true } + }' +``` + +An optional list of additional files to download can be specified in `files`. The `name` field overrides the model name. Finally, the model config file can be overridden with `overrides`. + +Returns a `uuid` and a `status` URL to follow up on the state of the process: + +```json +{ "uuid":"251475c9-f666-11ed-95e0-9a8a4480ac58", "status":"http://localhost:8080/models/jobs/251475c9-f666-11ed-95e0-9a8a4480ac58"} +``` + +For an example collection of curated model definition files, see the [model-gallery](https://github.com/go-skynet/model-gallery). + +
 + +### Inquire model job state - `/models/jobs/<uuid>` + +This endpoint returns the state of the batch job associated with a model. +
 + +This endpoint can be used with the uuid returned by `/models/apply` to check a job state: + +```bash +curl http://localhost:8080/models/jobs/251475c9-f666-11ed-95e0-9a8a4480ac58 +``` + +Returns a JSON object containing the error (if any), whether the job has been processed, and a status message: + +```json +{"error":null,"processed":true,"message":"completed"} +``` + +
+ +## Clients + +OpenAI clients are already compatible with LocalAI by overriding the basePath, or the target URL. + +## Javascript + +
+ +https://github.com/openai/openai-node/ + +```javascript +import { Configuration, OpenAIApi } from 'openai'; + +const configuration = new Configuration({ + basePath: `http://localhost:8080/v1` +}); +const openai = new OpenAIApi(configuration); +``` + +
+ +## Python + +
 + +https://github.com/openai/openai-python + +Set the `OPENAI_API_BASE` environment variable, or by code: + +```python +import openai + +openai.api_base = "http://localhost:8080/v1" + +# create a chat completion +chat_completion = openai.ChatCompletion.create(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hello world"}]) + +# print the completion +print(chat_completion.choices[0].message.content) +``` + +
+ ## Frequently asked questions Here are answers to some of the most common questions. diff --git a/api/api.go b/api/api.go index ecf56b09..b81a89f5 100644 --- a/api/api.go +++ b/api/api.go @@ -1,6 +1,7 @@ package api import ( + "context" "errors" model "github.com/go-skynet/LocalAI/pkg/model" @@ -12,7 +13,7 @@ import ( "github.com/rs/zerolog/log" ) -func App(configFile string, loader *model.ModelLoader, uploadLimitMB, threads, ctxSize int, f16 bool, debug, disableMessage bool, imageDir string) *fiber.App { +func App(c context.Context, configFile string, loader *model.ModelLoader, uploadLimitMB, threads, ctxSize int, f16 bool, debug, disableMessage bool, imageDir string) *fiber.App { zerolog.SetGlobalLevel(zerolog.InfoLevel) if debug { zerolog.SetGlobalLevel(zerolog.DebugLevel) @@ -48,7 +49,7 @@ func App(configFile string, loader *model.ModelLoader, uploadLimitMB, threads, c })) } - cm := make(ConfigMerger) + cm := NewConfigMerger() if err := cm.LoadConfigs(loader.ModelPath); err != nil { log.Error().Msgf("error loading config files: %s", err.Error()) } @@ -60,39 +61,51 @@ func App(configFile string, loader *model.ModelLoader, uploadLimitMB, threads, c } if debug { - for k, v := range cm { - log.Debug().Msgf("Model: %s (config: %+v)", k, v) + for _, v := range cm.ListConfigs() { + cfg, _ := cm.GetConfig(v) + log.Debug().Msgf("Model: %s (config: %+v)", v, cfg) } } // Default middleware config app.Use(recover.New()) app.Use(cors.New()) + // LocalAI API endpoints + applier := newGalleryApplier(loader.ModelPath) + applier.start(c, cm) + app.Post("/models/apply", applyModelGallery(loader.ModelPath, cm, applier.C)) + app.Get("/models/jobs/:uuid", getOpStatus(applier)) + // openAI compatible API endpoint + + // chat app.Post("/v1/chat/completions", chatEndpoint(cm, debug, loader, threads, ctxSize, f16)) app.Post("/chat/completions", chatEndpoint(cm, debug, loader, threads, ctxSize, f16)) + // edit app.Post("/v1/edits", editEndpoint(cm, debug, loader, threads, ctxSize, 
f16)) app.Post("/edits", editEndpoint(cm, debug, loader, threads, ctxSize, f16)) + // completion app.Post("/v1/completions", completionEndpoint(cm, debug, loader, threads, ctxSize, f16)) app.Post("/completions", completionEndpoint(cm, debug, loader, threads, ctxSize, f16)) + // embeddings app.Post("/v1/embeddings", embeddingsEndpoint(cm, debug, loader, threads, ctxSize, f16)) app.Post("/embeddings", embeddingsEndpoint(cm, debug, loader, threads, ctxSize, f16)) - - // /v1/engines/{engine_id}/embeddings - app.Post("/v1/engines/:model/embeddings", embeddingsEndpoint(cm, debug, loader, threads, ctxSize, f16)) + // audio app.Post("/v1/audio/transcriptions", transcriptEndpoint(cm, debug, loader, threads, ctxSize, f16)) + // images app.Post("/v1/images/generations", imageEndpoint(cm, debug, loader, imageDir)) if imageDir != "" { app.Static("/generated-images", imageDir) } + // models app.Get("/v1/models", listModels(loader, cm)) app.Get("/models", listModels(loader, cm)) diff --git a/api/api_test.go b/api/api_test.go index f2af0388..f061527f 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -1,7 +1,12 @@ package api_test import ( + "bytes" "context" + "encoding/json" + "fmt" + "io/ioutil" + "net/http" "os" "path/filepath" "runtime" @@ -11,21 +16,189 @@ import ( "github.com/gofiber/fiber/v2" . "github.com/onsi/ginkgo/v2" . 
"github.com/onsi/gomega" + "gopkg.in/yaml.v3" openaigo "github.com/otiai10/openaigo" "github.com/sashabaranov/go-openai" ) +type modelApplyRequest struct { + URL string `json:"url"` + Name string `json:"name"` + Overrides map[string]string `json:"overrides"` +} + +func getModelStatus(url string) (response map[string]interface{}) { + // Create the HTTP request + resp, err := http.Get(url) + if err != nil { + fmt.Println("Error creating request:", err) + return + } + defer resp.Body.Close() + + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + fmt.Println("Error reading response body:", err) + return + } + + // Unmarshal the response into a map[string]interface{} + err = json.Unmarshal(body, &response) + if err != nil { + fmt.Println("Error unmarshaling JSON response:", err) + return + } + return +} +func postModelApplyRequest(url string, request modelApplyRequest) (response map[string]interface{}) { + + //url := "http://localhost:AI/models/apply" + + // Create the request payload + + payload, err := json.Marshal(request) + if err != nil { + fmt.Println("Error marshaling JSON:", err) + return + } + + // Create the HTTP request + req, err := http.NewRequest("POST", url, bytes.NewBuffer(payload)) + if err != nil { + fmt.Println("Error creating request:", err) + return + } + req.Header.Set("Content-Type", "application/json") + + // Make the request + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + fmt.Println("Error making request:", err) + return + } + defer resp.Body.Close() + + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + fmt.Println("Error reading response body:", err) + return + } + + // Unmarshal the response into a map[string]interface{} + err = json.Unmarshal(body, &response) + if err != nil { + fmt.Println("Error unmarshaling JSON response:", err) + return + } + return +} + var _ = Describe("API test", func() { var app *fiber.App var modelLoader *model.ModelLoader var client *openai.Client var client2 
*openaigo.Client + var c context.Context + var cancel context.CancelFunc + var tmpdir string + + Context("API with ephemeral models", func() { + BeforeEach(func() { + var err error + tmpdir, err = os.MkdirTemp("", "") + Expect(err).ToNot(HaveOccurred()) + + modelLoader = model.NewModelLoader(tmpdir) + c, cancel = context.WithCancel(context.Background()) + + app = App(c, "", modelLoader, 15, 1, 512, false, true, true, "") + go app.Listen("127.0.0.1:9090") + + defaultConfig := openai.DefaultConfig("") + defaultConfig.BaseURL = "http://127.0.0.1:9090/v1" + + client2 = openaigo.NewClient("") + client2.BaseURL = defaultConfig.BaseURL + + // Wait for API to be ready + client = openai.NewClientWithConfig(defaultConfig) + Eventually(func() error { + _, err := client.ListModels(context.TODO()) + return err + }, "2m").ShouldNot(HaveOccurred()) + }) + + AfterEach(func() { + cancel() + app.Shutdown() + os.RemoveAll(tmpdir) + }) + + Context("Applying models", func() { + It("overrides models", func() { + response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{ + URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml", + Name: "bert", + Overrides: map[string]string{ + "backend": "llama", + }, + }) + + Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response)) + + uuid := response["uuid"].(string) + + Eventually(func() bool { + response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid) + fmt.Println(response) + return response["processed"].(bool) + }, "360s").Should(Equal(true)) + + dat, err := os.ReadFile(filepath.Join(tmpdir, "bert.yaml")) + Expect(err).ToNot(HaveOccurred()) + + content := map[string]interface{}{} + err = yaml.Unmarshal(dat, &content) + Expect(err).ToNot(HaveOccurred()) + Expect(content["backend"]).To(Equal("llama")) + }) + It("apply models without overrides", func() { + response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{ + URL: 
"https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml", + Name: "bert", + Overrides: map[string]string{}, + }) + + Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response)) + + uuid := response["uuid"].(string) + + Eventually(func() bool { + response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid) + fmt.Println(response) + return response["processed"].(bool) + }, "360s").Should(Equal(true)) + + dat, err := os.ReadFile(filepath.Join(tmpdir, "bert.yaml")) + Expect(err).ToNot(HaveOccurred()) + + content := map[string]interface{}{} + err = yaml.Unmarshal(dat, &content) + Expect(err).ToNot(HaveOccurred()) + Expect(content["backend"]).To(Equal("bert-embeddings")) + }) + }) + }) + Context("API query", func() { BeforeEach(func() { modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH")) - app = App("", modelLoader, 15, 1, 512, false, true, true, "") + c, cancel = context.WithCancel(context.Background()) + + app = App(c, "", modelLoader, 15, 1, 512, false, true, true, "") go app.Listen("127.0.0.1:9090") defaultConfig := openai.DefaultConfig("") @@ -42,6 +215,7 @@ var _ = Describe("API test", func() { }, "2m").ShouldNot(HaveOccurred()) }) AfterEach(func() { + cancel() app.Shutdown() }) It("returns the models list", func() { @@ -140,7 +314,9 @@ var _ = Describe("API test", func() { Context("Config file", func() { BeforeEach(func() { modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH")) - app = App(os.Getenv("CONFIG_FILE"), modelLoader, 5, 1, 512, false, true, true, "") + c, cancel = context.WithCancel(context.Background()) + + app = App(c, os.Getenv("CONFIG_FILE"), modelLoader, 5, 1, 512, false, true, true, "") go app.Listen("127.0.0.1:9090") defaultConfig := openai.DefaultConfig("") @@ -155,10 +331,10 @@ var _ = Describe("API test", func() { }, "2m").ShouldNot(HaveOccurred()) }) AfterEach(func() { + cancel() app.Shutdown() }) It("can generate chat completions from config file", func() { - models, err := 
client.ListModels(context.TODO()) Expect(err).ToNot(HaveOccurred()) Expect(len(models.Models)).To(Equal(12)) diff --git a/api/config.go b/api/config.go index 7379978e..7e0d8264 100644 --- a/api/config.go +++ b/api/config.go @@ -7,6 +7,7 @@ import ( "os" "path/filepath" "strings" + "sync" model "github.com/go-skynet/LocalAI/pkg/model" "github.com/gofiber/fiber/v2" @@ -43,8 +44,16 @@ type TemplateConfig struct { Edit string `yaml:"edit"` } -type ConfigMerger map[string]Config +type ConfigMerger struct { + configs map[string]Config + sync.Mutex +} +func NewConfigMerger() *ConfigMerger { + return &ConfigMerger{ + configs: make(map[string]Config), + } +} func ReadConfigFile(file string) ([]*Config, error) { c := &[]*Config{} f, err := os.ReadFile(file) @@ -72,28 +81,51 @@ func ReadConfig(file string) (*Config, error) { } func (cm ConfigMerger) LoadConfigFile(file string) error { + cm.Lock() + defer cm.Unlock() c, err := ReadConfigFile(file) if err != nil { return fmt.Errorf("cannot load config file: %w", err) } for _, cc := range c { - cm[cc.Name] = *cc + cm.configs[cc.Name] = *cc } return nil } func (cm ConfigMerger) LoadConfig(file string) error { + cm.Lock() + defer cm.Unlock() c, err := ReadConfig(file) if err != nil { return fmt.Errorf("cannot read config file: %w", err) } - cm[c.Name] = *c + cm.configs[c.Name] = *c return nil } +func (cm ConfigMerger) GetConfig(m string) (Config, bool) { + cm.Lock() + defer cm.Unlock() + v, exists := cm.configs[m] + return v, exists +} + +func (cm ConfigMerger) ListConfigs() []string { + cm.Lock() + defer cm.Unlock() + var res []string + for k := range cm.configs { + res = append(res, k) + } + return res +} + func (cm ConfigMerger) LoadConfigs(path string) error { + cm.Lock() + defer cm.Unlock() files, err := ioutil.ReadDir(path) if err != nil { return err @@ -106,7 +138,7 @@ func (cm ConfigMerger) LoadConfigs(path string) error { } c, err := ReadConfig(filepath.Join(path, file.Name())) if err == nil { - cm[c.Name] = *c + 
cm.configs[c.Name] = *c } } @@ -253,7 +285,7 @@ func readInput(c *fiber.Ctx, loader *model.ModelLoader, randomModel bool) (strin return modelFile, input, nil } -func readConfig(modelFile string, input *OpenAIRequest, cm ConfigMerger, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*Config, *OpenAIRequest, error) { +func readConfig(modelFile string, input *OpenAIRequest, cm *ConfigMerger, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*Config, *OpenAIRequest, error) { // Load a config file if present after the model name modelConfig := filepath.Join(loader.ModelPath, modelFile+".yaml") if _, err := os.Stat(modelConfig); err == nil { @@ -263,7 +295,7 @@ func readConfig(modelFile string, input *OpenAIRequest, cm ConfigMerger, loader } var config *Config - cfg, exists := cm[modelFile] + cfg, exists := cm.GetConfig(modelFile) if !exists { config = &Config{ OpenAIRequest: defaultRequest(modelFile), diff --git a/api/gallery.go b/api/gallery.go new file mode 100644 index 00000000..591b1b7a --- /dev/null +++ b/api/gallery.go @@ -0,0 +1,196 @@ +package api + +import ( + "context" + "fmt" + "io/ioutil" + "net/http" + "net/url" + "strings" + "sync" + + "github.com/go-skynet/LocalAI/pkg/gallery" + "github.com/gofiber/fiber/v2" + "github.com/google/uuid" + "gopkg.in/yaml.v3" +) + +type galleryOp struct { + req ApplyGalleryModelRequest + id string +} + +type galleryOpStatus struct { + Error error `json:"error"` + Processed bool `json:"processed"` + Message string `json:"message"` +} + +type galleryApplier struct { + modelPath string + sync.Mutex + C chan galleryOp + statuses map[string]*galleryOpStatus +} + +func newGalleryApplier(modelPath string) *galleryApplier { + return &galleryApplier{ + modelPath: modelPath, + C: make(chan galleryOp), + statuses: make(map[string]*galleryOpStatus), + } +} +func (g *galleryApplier) updatestatus(s string, op *galleryOpStatus) { + g.Lock() + defer g.Unlock() + g.statuses[s] = op +} + +func (g 
*galleryApplier) getstatus(s string) *galleryOpStatus { + g.Lock() + defer g.Unlock() + + return g.statuses[s] +} + +func (g *galleryApplier) start(c context.Context, cm *ConfigMerger) { + go func() { + for { + select { + case <-c.Done(): + return + case op := <-g.C: + g.updatestatus(op.id, &galleryOpStatus{Message: "processing"}) + + updateError := func(e error) { + g.updatestatus(op.id, &galleryOpStatus{Error: e, Processed: true}) + } + + url, err := op.req.DecodeURL() + if err != nil { + updateError(err) + continue + } + + // Send a GET request to the URL + response, err := http.Get(url) + if err != nil { + updateError(err) + continue + } + defer response.Body.Close() + + // Read the response body + body, err := ioutil.ReadAll(response.Body) + if err != nil { + updateError(err) + continue + } + + // Unmarshal YAML data into a Config struct + var config gallery.Config + err = yaml.Unmarshal(body, &config) + if err != nil { + updateError(fmt.Errorf("failed to unmarshal YAML: %v", err)) + continue + } + + config.Files = append(config.Files, op.req.AdditionalFiles...) 
+ + if err := gallery.Apply(g.modelPath, op.req.Name, &config, op.req.Overrides); err != nil { + updateError(err) + continue + } + + // Reload models + if err := cm.LoadConfigs(g.modelPath); err != nil { + updateError(err) + continue + } + + g.updatestatus(op.id, &galleryOpStatus{Processed: true, Message: "completed"}) + } + } + }() +} + +// endpoints + +type ApplyGalleryModelRequest struct { + URL string `json:"url"` + Name string `json:"name"` + Overrides map[string]interface{} `json:"overrides"` + AdditionalFiles []gallery.File `json:"files"` +} + +const ( + githubURI = "github:" +) + +func (request ApplyGalleryModelRequest) DecodeURL() (string, error) { + input := request.URL + var rawURL string + + if strings.HasPrefix(input, githubURI) { + parts := strings.Split(input, ":") + repoParts := strings.Split(parts[1], "@") + branch := "main" + + if len(repoParts) > 1 { + branch = repoParts[1] + } + + repoPath := strings.Split(repoParts[0], "/") + org := repoPath[0] + project := repoPath[1] + projectPath := strings.Join(repoPath[2:], "/") + + rawURL = fmt.Sprintf("https://raw.githubusercontent.com/%s/%s/%s/%s", org, project, branch, projectPath) + } else if strings.HasPrefix(input, "http://") || strings.HasPrefix(input, "https://") { + // Handle regular URLs + u, err := url.Parse(input) + if err != nil { + return "", fmt.Errorf("invalid URL: %w", err) + } + rawURL = u.String() + } else { + return "", fmt.Errorf("invalid URL format") + } + + return rawURL, nil +} + +func getOpStatus(g *galleryApplier) func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + + status := g.getstatus(c.Params("uuid")) + if status == nil { + return fmt.Errorf("could not find any status for ID") + } + + return c.JSON(status) + } +} + +func applyModelGallery(modelPath string, cm *ConfigMerger, g chan galleryOp) func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + input := new(ApplyGalleryModelRequest) + // Get input data from the request body + if err := 
c.BodyParser(input); err != nil { + return err + } + + uuid, err := uuid.NewUUID() + if err != nil { + return err + } + g <- galleryOp{ + req: *input, + id: uuid.String(), + } + return c.JSON(struct { + ID string `json:"uuid"` + StatusURL string `json:"status"` + }{ID: uuid.String(), StatusURL: c.BaseURL() + "/models/jobs/" + uuid.String()}) + } +} diff --git a/api/gallery_test.go b/api/gallery_test.go new file mode 100644 index 00000000..1c92c0d5 --- /dev/null +++ b/api/gallery_test.go @@ -0,0 +1,30 @@ +package api_test + +import ( + . "github.com/go-skynet/LocalAI/api" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("Gallery API tests", func() { + Context("requests", func() { + It("parses github with a branch", func() { + req := ApplyGalleryModelRequest{URL: "github:go-skynet/model-gallery/gpt4all-j.yaml@main"} + str, err := req.DecodeURL() + Expect(err).ToNot(HaveOccurred()) + Expect(str).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml")) + }) + It("parses github without a branch", func() { + req := ApplyGalleryModelRequest{URL: "github:go-skynet/model-gallery/gpt4all-j.yaml"} + str, err := req.DecodeURL() + Expect(err).ToNot(HaveOccurred()) + Expect(str).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml")) + }) + It("parses URLS", func() { + req := ApplyGalleryModelRequest{URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml"} + str, err := req.DecodeURL() + Expect(err).ToNot(HaveOccurred()) + Expect(str).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml")) + }) + }) +}) diff --git a/api/openai.go b/api/openai.go index 52d65976..0a85349c 100644 --- a/api/openai.go +++ b/api/openai.go @@ -142,7 +142,7 @@ func defaultRequest(modelFile string) OpenAIRequest { } // https://platform.openai.com/docs/api-reference/completions -func completionEndpoint(cm ConfigMerger, debug bool, loader 
*model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error { +func completionEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { model, input, err := readInput(c, loader, true) @@ -199,7 +199,7 @@ func completionEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, } // https://platform.openai.com/docs/api-reference/embeddings -func embeddingsEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error { +func embeddingsEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { model, input, err := readInput(c, loader, true) if err != nil { @@ -256,7 +256,7 @@ func embeddingsEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, } } -func chatEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error { +func chatEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error { process := func(s string, req *OpenAIRequest, config *Config, loader *model.ModelLoader, responses chan OpenAIResponse) { ComputeChoices(s, req, config, loader, func(s string, c *[]Choice) {}, func(s string) bool { @@ -378,7 +378,7 @@ func chatEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, thread } } -func editEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error { +func editEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { model, input, err := readInput(c, loader, true) if err != nil { @@ -449,7 +449,7 @@ func editEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, thread * */ -func 
imageEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, imageDir string) func(c *fiber.Ctx) error { +func imageEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, imageDir string) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { m, input, err := readInput(c, loader, false) if err != nil { @@ -574,7 +574,7 @@ func imageEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, image } // https://platform.openai.com/docs/api-reference/audio/create -func transcriptEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error { +func transcriptEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { m, input, err := readInput(c, loader, false) if err != nil { @@ -641,7 +641,7 @@ func transcriptEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, } } -func listModels(loader *model.ModelLoader, cm ConfigMerger) func(ctx *fiber.Ctx) error { +func listModels(loader *model.ModelLoader, cm *ConfigMerger) func(ctx *fiber.Ctx) error { return func(c *fiber.Ctx) error { models, err := loader.ListModels() if err != nil { @@ -655,7 +655,7 @@ func listModels(loader *model.ModelLoader, cm ConfigMerger) func(ctx *fiber.Ctx) dataModels = append(dataModels, OpenAIModel{ID: m, Object: "model"}) } - for k := range cm { + for _, k := range cm.ListConfigs() { if _, exists := mm[k]; !exists { dataModels = append(dataModels, OpenAIModel{ID: k, Object: "model"}) } diff --git a/entrypoint.sh b/entrypoint.sh index aab14205..e7390e56 100755 --- a/entrypoint.sh +++ b/entrypoint.sh @@ -2,6 +2,8 @@ cd /build -make build +if [ "$REBUILD" != "false" ]; then + make rebuild +fi ./local-ai "$@" \ No newline at end of file diff --git a/examples/langchain-chroma/.env.example b/examples/langchain-chroma/.env.example new file mode 100644 index 00000000..37cda598 --- /dev/null +++ 
b/examples/langchain-chroma/.env.example @@ -0,0 +1,5 @@ +THREADS=4 +CONTEXT_SIZE=512 +MODELS_PATH=/models +DEBUG=true +# BUILD_TYPE=generic \ No newline at end of file diff --git a/examples/langchain-chroma/.gitignore b/examples/langchain-chroma/.gitignore new file mode 100644 index 00000000..3dc19014 --- /dev/null +++ b/examples/langchain-chroma/.gitignore @@ -0,0 +1,4 @@ +db/ +state_of_the_union.txt +models/bert +models/ggml-gpt4all-j \ No newline at end of file diff --git a/examples/langchain-chroma/README.md b/examples/langchain-chroma/README.md index 70e3f42b..9fd9e312 100644 --- a/examples/langchain-chroma/README.md +++ b/examples/langchain-chroma/README.md @@ -10,13 +10,20 @@ Download the models and start the API: # Clone LocalAI git clone https://github.com/go-skynet/LocalAI -cd LocalAI/examples/query_data +cd LocalAI/examples/langchain-chroma wget https://huggingface.co/skeskinen/ggml/resolve/main/all-MiniLM-L6-v2/ggml-model-q4_0.bin -O models/bert wget https://gpt4all.io/models/ggml-gpt4all-j.bin -O models/ggml-gpt4all-j +# configure your .env +# NOTE: ensure that THREADS does not exceed your machine's CPU cores +mv .env.example .env + # start with docker-compose docker-compose up -d --build + +# tail the logs & wait until the build completes +docker logs -f langchain-chroma-api-1 ``` ### Python requirements @@ -29,6 +36,8 @@ pip install -r requirements.txt In this step we will create a local vector database from our document set, so later we can ask questions on it with the LLM. +Note: **OPENAI_API_KEY** is not required. However the library might fail if no API_KEY is passed by, so an arbitrary string can be used. + ```bash export OPENAI_API_BASE=http://localhost:8080/v1 export OPENAI_API_KEY=sk- @@ -37,7 +46,7 @@ wget https://raw.githubusercontent.com/hwchase17/chat-your-data/master/state_of_ python store.py ``` -After it finishes, a directory "storage" will be created with the vector index database. 
+After it finishes, a directory "db" will be created with the vector index database. ## Query diff --git a/examples/langchain-chroma/docker-compose.yml b/examples/langchain-chroma/docker-compose.yml new file mode 100644 index 00000000..96ef540e --- /dev/null +++ b/examples/langchain-chroma/docker-compose.yml @@ -0,0 +1,15 @@ +version: '3.6' + +services: + api: + image: quay.io/go-skynet/local-ai:latest + build: + context: ../../ + dockerfile: Dockerfile + ports: + - 8080:8080 + env_file: + - ../../.env + volumes: + - ./models:/models:cached + command: ["/usr/bin/local-ai"] diff --git a/examples/langchain-chroma/models/embeddings.yaml b/examples/langchain-chroma/models/embeddings.yaml index 46a08502..536c8de1 100644 --- a/examples/langchain-chroma/models/embeddings.yaml +++ b/examples/langchain-chroma/models/embeddings.yaml @@ -1,5 +1,6 @@ name: text-embedding-ada-002 parameters: model: bert +threads: 4 backend: bert-embeddings embeddings: true diff --git a/examples/langchain-chroma/query.py b/examples/langchain-chroma/query.py index 2f7df507..33848818 100644 --- a/examples/langchain-chroma/query.py +++ b/examples/langchain-chroma/query.py @@ -2,8 +2,9 @@ import os from langchain.vectorstores import Chroma from langchain.embeddings import OpenAIEmbeddings -from langchain.llms import OpenAI -from langchain.chains import VectorDBQA +from langchain.chat_models import ChatOpenAI +from langchain.chains import RetrievalQA +from langchain.vectorstores.base import VectorStoreRetriever base_path = os.environ.get('OPENAI_API_BASE', 'http://localhost:8080/v1') @@ -12,8 +13,10 @@ embedding = OpenAIEmbeddings() persist_directory = 'db' # Now we can load the persisted database from disk, and use it as normal. 
+llm = ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo", openai_api_base=base_path) vectordb = Chroma(persist_directory=persist_directory, embedding_function=embedding) -qa = VectorDBQA.from_chain_type(llm=OpenAI(temperature=0, model_name="gpt-3.5-turbo", openai_api_base=base_path), chain_type="stuff", vectorstore=vectordb) +retriever = VectorStoreRetriever(vectorstore=vectordb) +qa = RetrievalQA.from_llm(llm=llm, retriever=retriever) query = "What the president said about taxes ?" print(qa.run(query)) diff --git a/examples/langchain-chroma/store.py b/examples/langchain-chroma/store.py index 127bb240..b9cbad0e 100755 --- a/examples/langchain-chroma/store.py +++ b/examples/langchain-chroma/store.py @@ -2,9 +2,7 @@ import os from langchain.vectorstores import Chroma from langchain.embeddings import OpenAIEmbeddings -from langchain.text_splitter import RecursiveCharacterTextSplitter,TokenTextSplitter,CharacterTextSplitter -from langchain.llms import OpenAI -from langchain.chains import VectorDBQA +from langchain.text_splitter import CharacterTextSplitter from langchain.document_loaders import TextLoader base_path = os.environ.get('OPENAI_API_BASE', 'http://localhost:8080/v1') @@ -14,7 +12,6 @@ loader = TextLoader('state_of_the_union.txt') documents = loader.load() text_splitter = CharacterTextSplitter(chunk_size=300, chunk_overlap=70) -#text_splitter = TokenTextSplitter() texts = text_splitter.split_documents(documents) # Embed and store the texts diff --git a/examples/langchain-python/README.md b/examples/langchain-python/README.md index a98c48f7..2472aab1 100644 --- a/examples/langchain-python/README.md +++ b/examples/langchain-python/README.md @@ -26,6 +26,7 @@ pip install langchain pip install openai export OPENAI_API_BASE=http://localhost:8080 +# Note: **OPENAI_API_KEY** is not required. However the library might fail if no API_KEY is passed by, so an arbitrary string can be used. 
export OPENAI_API_KEY=sk- python test.py diff --git a/examples/query_data/README.md b/examples/query_data/README.md index f7a4e1fe..c4e384cd 100644 --- a/examples/query_data/README.md +++ b/examples/query_data/README.md @@ -35,6 +35,8 @@ docker-compose up -d --build In this step we will create a local vector database from our document set, so later we can ask questions on it with the LLM. +Note: **OPENAI_API_KEY** is not required. However the library might fail if no API_KEY is passed by, so an arbitrary string can be used. + ```bash export OPENAI_API_BASE=http://localhost:8080/v1 export OPENAI_API_KEY=sk- diff --git a/examples/query_data/docker-compose.yml b/examples/query_data/docker-compose.yml index a59edfc4..cf76eb7f 100644 --- a/examples/query_data/docker-compose.yml +++ b/examples/query_data/docker-compose.yml @@ -4,7 +4,7 @@ services: api: image: quay.io/go-skynet/local-ai:latest build: - context: . + context: ../../ dockerfile: Dockerfile ports: - 8080:8080 diff --git a/examples/rwkv/.gitignore b/examples/rwkv/.gitignore new file mode 100644 index 00000000..ab3629c5 --- /dev/null +++ b/examples/rwkv/.gitignore @@ -0,0 +1,2 @@ +models/rwkv +models/rwkv.tokenizer.json \ No newline at end of file diff --git a/examples/rwkv/Dockerfile.build b/examples/rwkv/Dockerfile.build index c62024de..491f9ccd 100644 --- a/examples/rwkv/Dockerfile.build +++ b/examples/rwkv/Dockerfile.build @@ -1,5 +1,7 @@ FROM python +RUN apt-get update && apt-get -y install cmake + # convert the model (one-off) RUN pip3 install torch numpy diff --git a/go.mod b/go.mod index 1cd33ce8..adb9c45a 100644 --- a/go.mod +++ b/go.mod @@ -4,26 +4,27 @@ go 1.19 require ( github.com/donomii/go-rwkv.cpp v0.0.0-20230515123100-6fdd0c338e56 - github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230515153606-95b02d76b04d + github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230520182345-041be06d5881 github.com/go-audio/wav v1.1.0 github.com/go-skynet/bloomz.cpp v0.0.0-20230510223001-e9366e82abdf 
github.com/go-skynet/go-bert.cpp v0.0.0-20230516063724-cea1ed76a7f4 github.com/go-skynet/go-gpt2.cpp v0.0.0-20230512145559-7bff56f02245 - github.com/go-skynet/go-gpt4all-j.cpp v0.0.0-20230422090028-1f7bff57f66c - github.com/go-skynet/go-llama.cpp v0.0.0-20230516230554-b7bbefbe0b84 - github.com/gofiber/fiber/v2 v2.45.0 + github.com/go-skynet/go-llama.cpp v0.0.0-20230520155239-ccf23adfb278 + github.com/gofiber/fiber/v2 v2.46.0 + github.com/google/uuid v1.3.0 github.com/hashicorp/go-multierror v1.1.1 + github.com/imdario/mergo v0.3.15 github.com/mudler/go-stable-diffusion v0.0.0-20230516152536-c0748eca3642 - github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230516143155-79d6243fe1bc + github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230519014017-914519e772fd github.com/onsi/ginkgo/v2 v2.9.5 - github.com/onsi/gomega v1.27.6 - github.com/otiai10/copy v1.11.0 + github.com/onsi/gomega v1.27.7 github.com/otiai10/openaigo v1.1.0 github.com/rs/zerolog v1.29.1 github.com/sashabaranov/go-openai v1.9.4 github.com/swaggo/swag v1.16.1 github.com/urfave/cli/v2 v2.25.3 github.com/valyala/fasthttp v1.47.0 + gopkg.in/yaml.v2 v2.4.0 gopkg.in/yaml.v3 v3.0.1 ) @@ -43,7 +44,6 @@ require ( github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect github.com/google/go-cmp v0.5.9 // indirect github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 // indirect - github.com/google/uuid v1.3.0 // indirect github.com/hashicorp/errwrap v1.0.0 // indirect github.com/josharian/intern v1.0.0 // indirect github.com/klauspost/compress v1.16.3 // indirect @@ -51,6 +51,7 @@ require ( github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.18 // indirect github.com/mattn/go-runewidth v0.0.14 // indirect + github.com/otiai10/mint v1.5.1 // indirect github.com/philhofer/fwd v1.1.2 // indirect github.com/rivo/uniseg v0.2.0 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect @@ -64,5 +65,4 @@ require ( golang.org/x/sys 
v0.8.0 // indirect golang.org/x/text v0.9.0 // indirect golang.org/x/tools v0.9.1 // indirect - gopkg.in/yaml.v2 v2.4.0 // indirect ) diff --git a/go.sum b/go.sum index 09af76da..20a4b22b 100644 --- a/go.sum +++ b/go.sum @@ -16,18 +16,14 @@ github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ3 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/donomii/go-rwkv.cpp v0.0.0-20230503112711-af62fcc432be h1:3Hic97PY6hcw/SY44RuR7kyONkxd744RFeRrqckzwNQ= -github.com/donomii/go-rwkv.cpp v0.0.0-20230503112711-af62fcc432be/go.mod h1:gWy7FIWioqYmYxkaoFyBnaKApeZVrUkHhv9EV9pz4dM= -github.com/donomii/go-rwkv.cpp v0.0.0-20230510174014-07166da10cb2 h1:YNbUAyIRtaLODitigJU1EM5ubmMu5FmHtYAayJD6Vbg= -github.com/donomii/go-rwkv.cpp v0.0.0-20230510174014-07166da10cb2/go.mod h1:gWy7FIWioqYmYxkaoFyBnaKApeZVrUkHhv9EV9pz4dM= github.com/donomii/go-rwkv.cpp v0.0.0-20230515123100-6fdd0c338e56 h1:s8/MZdicstKi5fn9D9mKGIQ/q6IWCYCk/BM68i8v51w= github.com/donomii/go-rwkv.cpp v0.0.0-20230515123100-6fdd0c338e56/go.mod h1:gWy7FIWioqYmYxkaoFyBnaKApeZVrUkHhv9EV9pz4dM= -github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230508180809-bf2449dfae35 h1:sMg/SgnMPS/HNUO/2kGm72vl8R9TmNIwgLFr2TNwR3g= -github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230508180809-bf2449dfae35/go.mod h1:QIjZ9OktHFG7p+/m3sMvrAJKKdWrr1fZIK0rM6HZlyo= -github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230509153812-1d17cd5bb37a h1:MlyiDLNCM/wjbv8U5Elj18NvaAgl61SGiRUpqQz5dfs= -github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230509153812-1d17cd5bb37a/go.mod h1:QIjZ9OktHFG7p+/m3sMvrAJKKdWrr1fZIK0rM6HZlyo= github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230515153606-95b02d76b04d h1:uxKTbiRnplE2SubchneSf4NChtxLJtOy9VdHnQMT0d0= github.com/ggerganov/whisper.cpp/bindings/go 
v0.0.0-20230515153606-95b02d76b04d/go.mod h1:QIjZ9OktHFG7p+/m3sMvrAJKKdWrr1fZIK0rM6HZlyo= +github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230520170006-429b9785c080 h1:W3itqKpRX9FhheKiAxdmuOBy/mjDfMf2G1vcuFIYqZc= +github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230520170006-429b9785c080/go.mod h1:QIjZ9OktHFG7p+/m3sMvrAJKKdWrr1fZIK0rM6HZlyo= +github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230520182345-041be06d5881 h1:dafqVivljYk51VLFnnpTXJnfWDe637EobWZ1l8PyEf8= +github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230520182345-041be06d5881/go.mod h1:QIjZ9OktHFG7p+/m3sMvrAJKKdWrr1fZIK0rM6HZlyo= github.com/go-audio/audio v1.0.0 h1:zS9vebldgbQqktK4H0lUqWrG8P0NxCJVqcj7ZpNnwd4= github.com/go-audio/audio v1.0.0/go.mod h1:6uAu0+H2lHkwdGsAY+j2wHPNPpPoeg5AaEFh9FlA+Zs= github.com/go-audio/riff v1.0.0 h1:d8iCGbDvox9BfLagY94fBynxSPHO80LmZCaOsmKxokA= @@ -46,31 +42,23 @@ github.com/go-openapi/spec v0.20.4/go.mod h1:faYFR1CvsJZ0mNsmsphTMSoRrNV3TEDoAM7 github.com/go-openapi/swag v0.19.5/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh66Z9tfKk= github.com/go-openapi/swag v0.19.15 h1:D2NRCBzS9/pEY3gP9Nl8aDqGUcPFrwG2p+CNFrLyrCM= github.com/go-openapi/swag v0.19.15/go.mod h1:QYRuS/SOXUCsnplDa677K7+DxSOj6IPNl/eQntq43wQ= -github.com/go-skynet/bloomz.cpp v0.0.0-20230510195113-ad7e89a0885f h1:GW8RQa1RVeDF1dOuAP/y6xWVC+BRtf9tJOuEza6Asbg= -github.com/go-skynet/bloomz.cpp v0.0.0-20230510195113-ad7e89a0885f/go.mod h1:wc0fJ9V04yiYTfgKvE5RUUSRQ5Kzi0Bo4I+U3nNOUuA= github.com/go-skynet/bloomz.cpp v0.0.0-20230510223001-e9366e82abdf h1:VJfSn8hIDE+K5+h38M3iAyFXrxpRExMKRdTk33UDxsw= github.com/go-skynet/bloomz.cpp v0.0.0-20230510223001-e9366e82abdf/go.mod h1:wc0fJ9V04yiYTfgKvE5RUUSRQ5Kzi0Bo4I+U3nNOUuA= -github.com/go-skynet/go-bert.cpp v0.0.0-20230510101404-7bb183b147ea h1:8Isk9D+Auth5OuXVAQPC3MO+5zF/2S7mvs2JZLw6a+8= -github.com/go-skynet/go-bert.cpp v0.0.0-20230510101404-7bb183b147ea/go.mod h1:NHwIVvsg7Jh6p0M4uBLVmSMEaPUia6O6yjXUpLWVJmQ= -github.com/go-skynet/go-bert.cpp 
v0.0.0-20230510124618-ec771ec71557 h1:LD66fKtvP2lmyuuKL8pBat/pVTKUbLs3L5fM/5lyi4w= -github.com/go-skynet/go-bert.cpp v0.0.0-20230510124618-ec771ec71557/go.mod h1:NHwIVvsg7Jh6p0M4uBLVmSMEaPUia6O6yjXUpLWVJmQ= github.com/go-skynet/go-bert.cpp v0.0.0-20230516063724-cea1ed76a7f4 h1:+3KPDf4Wv1VHOkzAfZnlj9qakLSYggTpm80AswhD/FU= github.com/go-skynet/go-bert.cpp v0.0.0-20230516063724-cea1ed76a7f4/go.mod h1:VY0s5KoAI2jRCvQXKuDeEEe8KG7VaWifSNJSk+E1KtY= -github.com/go-skynet/go-gpt2.cpp v0.0.0-20230509180201-d49823284cc6 h1:XshpypO6ekU09CI19vuzke2a1Es1lV5ZaxA7CUehu0E= -github.com/go-skynet/go-gpt2.cpp v0.0.0-20230509180201-d49823284cc6/go.mod h1:1Wj/xbkMfwQSOrhNYK178IzqQHstZbRfhx4s8p1M5VM= github.com/go-skynet/go-gpt2.cpp v0.0.0-20230512145559-7bff56f02245 h1:IcfYY5uH0DdDXEJKJ8bq0WZCd9guPPd3xllaWNy8LOk= github.com/go-skynet/go-gpt2.cpp v0.0.0-20230512145559-7bff56f02245/go.mod h1:1Wj/xbkMfwQSOrhNYK178IzqQHstZbRfhx4s8p1M5VM= -github.com/go-skynet/go-gpt4all-j.cpp v0.0.0-20230422090028-1f7bff57f66c h1:48I7jpLNGiQeBmF0SFVVbREh8vlG0zN13v9LH5ctXis= -github.com/go-skynet/go-gpt4all-j.cpp v0.0.0-20230422090028-1f7bff57f66c/go.mod h1:5VZ9XbcINI0XcHhkcX8GPK8TplFGAzu1Hrg4tNiMCtI= -github.com/go-skynet/go-llama.cpp v0.0.0-20230510072905-70593fccbe4b h1:qqxrjY8fYDXQahmCMTCACahm1tbiqHLPUHALkFLyBfo= -github.com/go-skynet/go-llama.cpp v0.0.0-20230510072905-70593fccbe4b/go.mod h1:DLfsPD7tYYnpksERH83HSf7qVNW3FIwmz7/zfYO0/6I= -github.com/go-skynet/go-llama.cpp v0.0.0-20230516230554-b7bbefbe0b84 h1:f5iYF75bAr73Tl8AdtFD5Urs/2bsHKPh52K++jLbsfk= -github.com/go-skynet/go-llama.cpp v0.0.0-20230516230554-b7bbefbe0b84/go.mod h1:jxyQ26t1aKC5Gn782w9WWh5n1133PxCOfkuc01xM4RQ= +github.com/go-skynet/go-llama.cpp v0.0.0-20230520082618-a298043ef5f1 h1:i0oM2MERUgMIRmjOcv22TDQULxbmY8o9rZKLKKyWXLo= +github.com/go-skynet/go-llama.cpp v0.0.0-20230520082618-a298043ef5f1/go.mod h1:oA0r4BW8ndyjTMGi1tulsNd7sdg3Ql8MaVFuT1zF6ws= +github.com/go-skynet/go-llama.cpp v0.0.0-20230520155239-ccf23adfb278 
h1:st4ow9JKy3UuhkwutrbWof2vMFU/YxwBCLYZ1IxJ2Po= +github.com/go-skynet/go-llama.cpp v0.0.0-20230520155239-ccf23adfb278/go.mod h1:oA0r4BW8ndyjTMGi1tulsNd7sdg3Ql8MaVFuT1zF6ws= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/gofiber/fiber/v2 v2.45.0 h1:p4RpkJT9GAW6parBSbcNFH2ApnAuW3OzaQzbOCoDu+s= github.com/gofiber/fiber/v2 v2.45.0/go.mod h1:DNl0/c37WLe0g92U6lx1VMQuxGUQY5V7EIaVoEsUffc= +github.com/gofiber/fiber/v2 v2.46.0 h1:wkkWotblsGVlLjXj2dpgKQAYHtXumsK/HyFugQM68Ns= +github.com/gofiber/fiber/v2 v2.46.0/go.mod h1:DNl0/c37WLe0g92U6lx1VMQuxGUQY5V7EIaVoEsUffc= github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= @@ -83,6 +71,8 @@ github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brv github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= +github.com/imdario/mergo v0.3.15 h1:M8XP7IuFNsqUx6VPK2P9OSmsYsI/YFaGil0uD21V3dM= +github.com/imdario/mergo v0.3.15/go.mod h1:WBLT9ZmE3lPoWsEzCh9LPo3TiwVN+ZKEjmz+hD27ysY= github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/klauspost/compress v1.16.3 h1:XuJt9zzcnaz6a16/OU53ZjWp/v7/42WcR5t2a0PcNQY= @@ -105,23 +95,18 @@ github.com/mattn/go-isatty v0.0.18 
h1:DOKFKCQ7FNG2L1rbrmstDN4QVRdS89Nkh85u68Uwp9 github.com/mattn/go-isatty v0.0.18/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-runewidth v0.0.14 h1:+xnbZSEeDbOIg5/mE6JF0w6n9duR1l3/WmbinWVwUuU= github.com/mattn/go-runewidth v0.0.14/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= -github.com/mudler/go-stable-diffusion v0.0.0-20230516104333-2f32a16b5b24 h1:XfRD/bZom6u4zji7aB0urIVOsPe43KlkzSRrVhlzaOM= -github.com/mudler/go-stable-diffusion v0.0.0-20230516104333-2f32a16b5b24/go.mod h1:8ufRkpz/S/9ahkaxzZ5i4WMgO9w4InEhuRoT7vK5Rnw= github.com/mudler/go-stable-diffusion v0.0.0-20230516152536-c0748eca3642 h1:KTkh3lOUsGqQyP4v+oa38sPFdrZtNnM4HaxTb3epdYs= github.com/mudler/go-stable-diffusion v0.0.0-20230516152536-c0748eca3642/go.mod h1:8ufRkpz/S/9ahkaxzZ5i4WMgO9w4InEhuRoT7vK5Rnw= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= -github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230516143155-79d6243fe1bc h1:OPavP/SUsVWVYPhSUZKZeX8yDSQzf4G+BmUmwzrLTyI= -github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230516143155-79d6243fe1bc/go.mod h1:4T3CHXyrt+7FQHXaxULZfPjHbD8/99WuDDJa0YVZARI= -github.com/onsi/ginkgo/v2 v2.9.4 h1:xR7vG4IXt5RWx6FfIjyAtsoMAtnc3C/rFXBBd2AjZwE= -github.com/onsi/ginkgo/v2 v2.9.4/go.mod h1:gCQYp2Q+kSoIj7ykSVb9nskRSsR6PUj4AiLywzIhbKM= +github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230519014017-914519e772fd h1:kMnZASxCNc8GsPuAV94tltEsfT6T+esuB+rgzdjwFVM= +github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230519014017-914519e772fd/go.mod h1:4T3CHXyrt+7FQHXaxULZfPjHbD8/99WuDDJa0YVZARI= github.com/onsi/ginkgo/v2 v2.9.5 h1:+6Hr4uxzP4XIUyAkg61dWBw8lb/gc4/X5luuxN/EC+Q= github.com/onsi/ginkgo/v2 v2.9.5/go.mod h1:tvAoo1QUJwNEU2ITftXTpR7R1RbCzoZUOs3RonqW57k= -github.com/onsi/gomega v1.27.6 
h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE= -github.com/onsi/gomega v1.27.6/go.mod h1:PIQNjfQwkP3aQAH7lf7j87O/5FiNr+ZR8+ipb+qQlhg= -github.com/otiai10/copy v1.11.0 h1:OKBD80J/mLBrwnzXqGtFCzprFSGioo30JcmR4APsNwc= -github.com/otiai10/copy v1.11.0/go.mod h1:rSaLseMUsZFFbsFGc7wCJnnkTAvdc5L6VWxPE4308Ww= +github.com/onsi/gomega v1.27.7 h1:fVih9JD6ogIiHUN6ePK7HJidyEDpWGVB5mzM7cWNXoU= +github.com/onsi/gomega v1.27.7/go.mod h1:1p8OOlwo2iUUDsHnOrjE5UKYJ+e3W8eQ3qSlRahPmr4= github.com/otiai10/mint v1.5.1 h1:XaPLeE+9vGbuyEHem1JNk3bYc7KKqyI/na0/mLd/Kks= +github.com/otiai10/mint v1.5.1/go.mod h1:MJm72SBthJjz8qhefc4z1PYEieWmy8Bku7CjcAqyUSM= github.com/otiai10/openaigo v1.1.0 h1:zRvGBqZUW5PCMgdkJNsPVTBd8tOLCMTipXE5wD2pdTg= github.com/otiai10/openaigo v1.1.0/go.mod h1:792bx6AWTS61weDi2EzKpHHnTF4eDMAlJ5GvAk/mgPg= github.com/philhofer/fwd v1.1.1/go.mod h1:gk3iGcWd9+svBvR0sR+KPcfE+RNWozjowpeBVG3ZVNU= @@ -137,8 +122,6 @@ github.com/rs/zerolog v1.29.1 h1:cO+d60CHkknCbvzEWxP0S9K6KqyTjrCNUy1LdQLCGPc= github.com/rs/zerolog v1.29.1/go.mod h1:Le6ESbR7hc+DP6Lt1THiV8CQSdkkNrd3R0XbEgp3ZBU= github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= -github.com/sashabaranov/go-openai v1.9.3 h1:uNak3Rn5pPsKRs9bdT7RqRZEyej/zdZOEI2/8wvrFtM= -github.com/sashabaranov/go-openai v1.9.3/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= github.com/sashabaranov/go-openai v1.9.4 h1:KanoCEoowAI45jVXlenMCckutSRr39qOmSi9MyPBfZM= github.com/sashabaranov/go-openai v1.9.4/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= github.com/savsgio/dictpool v0.0.0-20221023140959-7bf2e61cea94 h1:rmMl4fXJhKMNWl+K+r/fq4FbbKI+Ia2m9hYBLm2h4G4= @@ -182,8 +165,6 @@ golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v golang.org/x/net v0.0.0-20210421230115-4e50805a0758/go.mod h1:72T/g9IO56b78aLF+1Kcs5dz7/ng1VjMUvfKvpfy+jM= golang.org/x/net 
v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.3.0/go.mod h1:MBQ8lrhLObU/6UmLb4fmbmk5OcyYmqtbGd/9yIeKjEE= -golang.org/x/net v0.9.0 h1:aWJ/m6xSmxWBx+V0XRHTlrYrPG56jKsLdTFmsSsCzOM= -golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -221,8 +202,6 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn golang.org/x/tools v0.0.0-20201022035929-9cf592e881e9/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.4.0/go.mod h1:UE5sM2OK9E/d67R0ANs2xJizIymRP5gJU295PvKXxjQ= -golang.org/x/tools v0.8.0 h1:vSDcovVPld282ceKgDimkRSC8kpaH1dgyc9UMzlt84Y= -golang.org/x/tools v0.8.0/go.mod h1:JxBZ99ISMI5ViVkT1tr6tdNmXeTrcpVSD3vZ1RsRdN4= golang.org/x/tools v0.9.1 h1:8WMNJAz3zrtPmnYC7ISf5dEn3MT0gY7jBJfw27yrrLo= golang.org/x/tools v0.9.1/go.mod h1:owI94Op576fPu3cIGQeHs3joujW/2Oc6MtlxbF5dfNc= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/main.go b/main.go index 2490e198..f3ffc033 100644 --- a/main.go +++ b/main.go @@ -1,6 +1,7 @@ package main import ( + "context" "fmt" "os" "path/filepath" @@ -57,9 +58,9 @@ func main() { Value: ":8080", }, &cli.StringFlag{ - Name: "image-dir", + Name: "image-path", DefaultText: "Image directory", - EnvVars: []string{"IMAGE_DIR"}, + EnvVars: []string{"IMAGE_PATH"}, Value: "", }, &cli.IntFlag{ @@ -93,7 +94,7 @@ It uses llama.cpp, ggml and gpt4all as backend with golang c bindings. 
Copyright: "go-skynet authors", Action: func(ctx *cli.Context) error { fmt.Printf("Starting LocalAI using %d threads, with models path: %s\n", ctx.Int("threads"), ctx.String("models-path")) - return api.App(ctx.String("config-file"), model.NewModelLoader(ctx.String("models-path")), ctx.Int("upload-limit"), ctx.Int("threads"), ctx.Int("context-size"), ctx.Bool("f16"), ctx.Bool("debug"), false, ctx.String("image-dir")).Listen(ctx.String("address")) + return api.App(context.Background(), ctx.String("config-file"), model.NewModelLoader(ctx.String("models-path")), ctx.Int("upload-limit"), ctx.Int("threads"), ctx.Int("context-size"), ctx.Bool("f16"), ctx.Bool("debug"), false, ctx.String("image-path")).Listen(ctx.String("address")) }, } diff --git a/pkg/gallery/gallery_suite_test.go b/pkg/gallery/gallery_suite_test.go new file mode 100644 index 00000000..44256bc2 --- /dev/null +++ b/pkg/gallery/gallery_suite_test.go @@ -0,0 +1,13 @@ +package gallery_test + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func TestGallery(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Gallery test suite") +} diff --git a/pkg/gallery/models.go b/pkg/gallery/models.go new file mode 100644 index 00000000..f4f86ae7 --- /dev/null +++ b/pkg/gallery/models.go @@ -0,0 +1,271 @@ +package gallery + +import ( + "crypto/sha256" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + + "github.com/imdario/mergo" + "github.com/rs/zerolog/log" + "gopkg.in/yaml.v2" +) + +/* + +description: | + foo +license: "" + +urls: +- +- + +name: "bar" + +config_file: | + # Note, name will be injected. 
or generated by the alias wanted by the user + threads: 14 + +files: + - filename: "" + sha: "" + uri: "" + +prompt_templates: + - name: "" + content: "" + +*/ + +type Config struct { + Description string `yaml:"description"` + License string `yaml:"license"` + URLs []string `yaml:"urls"` + Name string `yaml:"name"` + ConfigFile string `yaml:"config_file"` + Files []File `yaml:"files"` + PromptTemplates []PromptTemplate `yaml:"prompt_templates"` +} + +type File struct { + Filename string `yaml:"filename" json:"filename"` + SHA256 string `yaml:"sha256" json:"sha256"` + URI string `yaml:"uri" json:"uri"` +} + +type PromptTemplate struct { + Name string `yaml:"name"` + Content string `yaml:"content"` +} + +func ReadConfigFile(filePath string) (*Config, error) { + // Read the YAML file + yamlFile, err := os.ReadFile(filePath) + if err != nil { + return nil, fmt.Errorf("failed to read YAML file: %v", err) + } + + // Unmarshal YAML data into a Config struct + var config Config + err = yaml.Unmarshal(yamlFile, &config) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal YAML: %v", err) + } + + return &config, nil +} + +func inTrustedRoot(path string, trustedRoot string) error { + for path != "/" { + path = filepath.Dir(path) + if path == trustedRoot { + return nil + } + } + return fmt.Errorf("path is outside of trusted root") +} + +func verifyPath(path, basePath string) error { + c := filepath.Clean(filepath.Join(basePath, path)) + return inTrustedRoot(c, basePath) +} + +func Apply(basePath, nameOverride string, config *Config, configOverrides map[string]interface{}) error { + // Create base path if it doesn't exist + err := os.MkdirAll(basePath, 0755) + if err != nil { + return fmt.Errorf("failed to create base path: %v", err) + } + + if len(configOverrides) > 0 { + log.Debug().Msgf("Config overrides %+v", configOverrides) + } + + // Download files and verify their SHA + for _, file := range config.Files { + log.Debug().Msgf("Checking %q exists and matches 
SHA", file.Filename) + + if err := verifyPath(file.Filename, basePath); err != nil { + return err + } + // Create file path + filePath := filepath.Join(basePath, file.Filename) + + // Check if the file already exists + _, err := os.Stat(filePath) + if err == nil { + // File exists, check SHA + if file.SHA256 != "" { + // Verify SHA + calculatedSHA, err := calculateSHA(filePath) + if err != nil { + return fmt.Errorf("failed to calculate SHA for file %q: %v", file.Filename, err) + } + if calculatedSHA == file.SHA256 { + // SHA matches, skip downloading + log.Debug().Msgf("File %q already exists and matches the SHA. Skipping download", file.Filename) + continue + } + // SHA doesn't match, delete the file and download again + err = os.Remove(filePath) + if err != nil { + return fmt.Errorf("failed to remove existing file %q: %v", file.Filename, err) + } + log.Debug().Msgf("Removed %q (SHA doesn't match)", filePath) + + } else { + // SHA is missing, skip downloading + log.Debug().Msgf("File %q already exists. 
Skipping download", file.Filename) + continue + } + } else if !os.IsNotExist(err) { + // Error occurred while checking file existence + return fmt.Errorf("failed to check file %q existence: %v", file.Filename, err) + } + + log.Debug().Msgf("Downloading %q", file.URI) + + // Download file + resp, err := http.Get(file.URI) + if err != nil { + return fmt.Errorf("failed to download file %q: %v", file.Filename, err) + } + defer resp.Body.Close() + + // Create parent directory + err = os.MkdirAll(filepath.Dir(filePath), 0755) + if err != nil { + return fmt.Errorf("failed to create parent directory for file %q: %v", file.Filename, err) + } + + // Create and write file content + outFile, err := os.Create(filePath) + if err != nil { + return fmt.Errorf("failed to create file %q: %v", file.Filename, err) + } + defer outFile.Close() + + if file.SHA256 != "" { + log.Debug().Msgf("Download and verifying %q", file.Filename) + + // Write file content and calculate SHA + hash := sha256.New() + _, err = io.Copy(io.MultiWriter(outFile, hash), resp.Body) + if err != nil { + return fmt.Errorf("failed to write file %q: %v", file.Filename, err) + } + + // Verify SHA + calculatedSHA := fmt.Sprintf("%x", hash.Sum(nil)) + if calculatedSHA != file.SHA256 { + return fmt.Errorf("SHA mismatch for file %q ( calculated: %s != metadata: %s )", file.Filename, calculatedSHA, file.SHA256) + } + } else { + log.Debug().Msgf("SHA missing for %q. 
Skipping validation", file.Filename) + _, err = io.Copy(outFile, resp.Body) + if err != nil { + return fmt.Errorf("failed to write file %q: %v", file.Filename, err) + } + } + + log.Debug().Msgf("File %q downloaded and verified", file.Filename) + } + + // Write prompt template contents to separate files + for _, template := range config.PromptTemplates { + if err := verifyPath(template.Name+".tmpl", basePath); err != nil { + return err + } + // Create file path + filePath := filepath.Join(basePath, template.Name+".tmpl") + + // Create parent directory + err := os.MkdirAll(filepath.Dir(filePath), 0755) + if err != nil { + return fmt.Errorf("failed to create parent directory for prompt template %q: %v", template.Name, err) + } + // Create and write file content + err = os.WriteFile(filePath, []byte(template.Content), 0644) + if err != nil { + return fmt.Errorf("failed to write prompt template %q: %v", template.Name, err) + } + + log.Debug().Msgf("Prompt template %q written", template.Name) + } + + name := config.Name + if nameOverride != "" { + name = nameOverride + } + + if err := verifyPath(name+".yaml", basePath); err != nil { + return err + } + + configFilePath := filepath.Join(basePath, name+".yaml") + + // Read and update config file as map[string]interface{} + configMap := make(map[string]interface{}) + err = yaml.Unmarshal([]byte(config.ConfigFile), &configMap) + if err != nil { + return fmt.Errorf("failed to unmarshal config YAML: %v", err) + } + + configMap["name"] = name + + if err := mergo.Merge(&configMap, configOverrides, mergo.WithOverride); err != nil { + return err + } + + // Write updated config file + updatedConfigYAML, err := yaml.Marshal(configMap) + if err != nil { + return fmt.Errorf("failed to marshal updated config YAML: %v", err) + } + + err = os.WriteFile(configFilePath, updatedConfigYAML, 0644) + if err != nil { + return fmt.Errorf("failed to write updated config file: %v", err) + } + + log.Debug().Msgf("Written config file %s", 
configFilePath) + return nil +} + +func calculateSHA(filePath string) (string, error) { + file, err := os.Open(filePath) + if err != nil { + return "", err + } + defer file.Close() + + hash := sha256.New() + if _, err := io.Copy(hash, file); err != nil { + return "", err + } + + return fmt.Sprintf("%x", hash.Sum(nil)), nil +} diff --git a/pkg/gallery/models_test.go b/pkg/gallery/models_test.go new file mode 100644 index 00000000..f0e580e9 --- /dev/null +++ b/pkg/gallery/models_test.go @@ -0,0 +1,94 @@ +package gallery_test + +import ( + "os" + "path/filepath" + + . "github.com/go-skynet/LocalAI/pkg/gallery" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "gopkg.in/yaml.v3" +) + +var _ = Describe("Model test", func() { + Context("Downloading", func() { + It("applies model correctly", func() { + tempdir, err := os.MkdirTemp("", "test") + Expect(err).ToNot(HaveOccurred()) + defer os.RemoveAll(tempdir) + c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml")) + Expect(err).ToNot(HaveOccurred()) + + err = Apply(tempdir, "", c, map[string]interface{}{}) + Expect(err).ToNot(HaveOccurred()) + + for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "cerebras.yaml"} { + _, err = os.Stat(filepath.Join(tempdir, f)) + Expect(err).ToNot(HaveOccurred()) + } + + content := map[string]interface{}{} + + dat, err := os.ReadFile(filepath.Join(tempdir, "cerebras.yaml")) + Expect(err).ToNot(HaveOccurred()) + + err = yaml.Unmarshal(dat, content) + Expect(err).ToNot(HaveOccurred()) + + Expect(content["context_size"]).To(Equal(1024)) + }) + + It("renames model correctly", func() { + tempdir, err := os.MkdirTemp("", "test") + Expect(err).ToNot(HaveOccurred()) + defer os.RemoveAll(tempdir) + c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml")) + Expect(err).ToNot(HaveOccurred()) + + err = Apply(tempdir, "foo", c, map[string]interface{}{}) + Expect(err).ToNot(HaveOccurred()) + + 
for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} { + _, err = os.Stat(filepath.Join(tempdir, f)) + Expect(err).ToNot(HaveOccurred()) + } + }) + + It("overrides parameters", func() { + tempdir, err := os.MkdirTemp("", "test") + Expect(err).ToNot(HaveOccurred()) + defer os.RemoveAll(tempdir) + c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml")) + Expect(err).ToNot(HaveOccurred()) + + err = Apply(tempdir, "foo", c, map[string]interface{}{"backend": "foo"}) + Expect(err).ToNot(HaveOccurred()) + + for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} { + _, err = os.Stat(filepath.Join(tempdir, f)) + Expect(err).ToNot(HaveOccurred()) + } + + content := map[string]interface{}{} + + dat, err := os.ReadFile(filepath.Join(tempdir, "foo.yaml")) + Expect(err).ToNot(HaveOccurred()) + + err = yaml.Unmarshal(dat, content) + Expect(err).ToNot(HaveOccurred()) + + Expect(content["backend"]).To(Equal("foo")) + }) + + It("catches path traversals", func() { + tempdir, err := os.MkdirTemp("", "test") + Expect(err).ToNot(HaveOccurred()) + defer os.RemoveAll(tempdir) + c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml")) + Expect(err).ToNot(HaveOccurred()) + + err = Apply(tempdir, "../../../foo", c, map[string]interface{}{}) + Expect(err).To(HaveOccurred()) + }) + }) +}) diff --git a/pkg/model/initializers.go b/pkg/model/initializers.go index 74c05f29..b5e43a38 100644 --- a/pkg/model/initializers.go +++ b/pkg/model/initializers.go @@ -164,11 +164,12 @@ func (ml *ModelLoader) BackendLoader(backendString string, modelFile string, lla } func (ml *ModelLoader) GreedyLoader(modelFile string, llamaOpts []llama.ModelOption, threads uint32) (interface{}, error) { - log.Debug().Msgf("Loading models greedly") + log.Debug().Msgf("Loading model '%s' greedly", modelFile) ml.mu.Lock() m, exists := ml.models[modelFile] if exists { + 
log.Debug().Msgf("Model '%s' already loaded", modelFile) ml.mu.Unlock() return m, nil } diff --git a/tests/fixtures/gallery_simple.yaml b/tests/fixtures/gallery_simple.yaml new file mode 100644 index 00000000..058733fe --- /dev/null +++ b/tests/fixtures/gallery_simple.yaml @@ -0,0 +1,40 @@ +name: "cerebras" +description: | + cerebras +license: "Apache 2.0" + +config_file: | + parameters: + model: cerebras + top_k: 80 + temperature: 0.2 + top_p: 0.7 + context_size: 1024 + stopwords: + - "HUMAN:" + - "GPT:" + roles: + user: "" + system: "" + template: + completion: "cerebras-completion" + chat: cerebras-chat + +files: + - filename: "cerebras" + sha256: "c947051ae4dba9530ca55d923a7a484acd65664c8633462c8ccd4bb7848f2c65" + uri: "https://huggingface.co/concedo/cerebras-111M-ggml/resolve/main/cerebras-111m-q4_2.bin" + +prompt_templates: + - name: "cerebras-completion" + content: | + Complete the prompt + ### Prompt: + {{.Input}} + ### Response: + - name: "cerebras-chat" + content: | + The prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response. 
+ ### Prompt: + {{.Input}} + ### Response: \ No newline at end of file diff --git a/tests/fixtures/completion.tmpl b/tests/models_fixtures/completion.tmpl similarity index 100% rename from tests/fixtures/completion.tmpl rename to tests/models_fixtures/completion.tmpl diff --git a/tests/fixtures/config.yaml b/tests/models_fixtures/config.yaml similarity index 100% rename from tests/fixtures/config.yaml rename to tests/models_fixtures/config.yaml diff --git a/tests/fixtures/embeddings.yaml b/tests/models_fixtures/embeddings.yaml similarity index 100% rename from tests/fixtures/embeddings.yaml rename to tests/models_fixtures/embeddings.yaml diff --git a/tests/fixtures/ggml-gpt4all-j.tmpl b/tests/models_fixtures/ggml-gpt4all-j.tmpl similarity index 100% rename from tests/fixtures/ggml-gpt4all-j.tmpl rename to tests/models_fixtures/ggml-gpt4all-j.tmpl diff --git a/tests/fixtures/gpt4.yaml b/tests/models_fixtures/gpt4.yaml similarity index 100% rename from tests/fixtures/gpt4.yaml rename to tests/models_fixtures/gpt4.yaml diff --git a/tests/fixtures/gpt4_2.yaml b/tests/models_fixtures/gpt4_2.yaml similarity index 100% rename from tests/fixtures/gpt4_2.yaml rename to tests/models_fixtures/gpt4_2.yaml diff --git a/tests/fixtures/rwkv.yaml b/tests/models_fixtures/rwkv.yaml similarity index 100% rename from tests/fixtures/rwkv.yaml rename to tests/models_fixtures/rwkv.yaml diff --git a/tests/fixtures/rwkv_chat.tmpl b/tests/models_fixtures/rwkv_chat.tmpl similarity index 100% rename from tests/fixtures/rwkv_chat.tmpl rename to tests/models_fixtures/rwkv_chat.tmpl diff --git a/tests/fixtures/rwkv_completion.tmpl b/tests/models_fixtures/rwkv_completion.tmpl similarity index 100% rename from tests/fixtures/rwkv_completion.tmpl rename to tests/models_fixtures/rwkv_completion.tmpl diff --git a/tests/fixtures/whisper.yaml b/tests/models_fixtures/whisper.yaml similarity index 100% rename from tests/fixtures/whisper.yaml rename to tests/models_fixtures/whisper.yaml