diff --git a/.dockerignore b/.dockerignore index 41478502..e73b1f9d 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1,3 +1,5 @@ +.git +.idea models examples/chatbot-ui/models examples/rwkv/models diff --git a/.github/workflows/image.yml b/.github/workflows/image.yml index eeada322..38eed85a 100644 --- a/.github/workflows/image.yml +++ b/.github/workflows/image.yml @@ -15,34 +15,42 @@ concurrency: jobs: docker: + strategy: + matrix: + include: + - build-type: '' + platforms: 'linux/amd64,linux/arm64' + tag-latest: 'auto' + tag-suffix: '' + - build-type: 'cublas' + cuda-major-version: 11 + cuda-minor-version: 7 + platforms: 'linux/amd64' + tag-latest: 'false' + tag-suffix: '-cublas-cuda11' + - build-type: 'cublas' + cuda-major-version: 12 + cuda-minor-version: 1 + platforms: 'linux/amd64' + tag-latest: 'false' + tag-suffix: '-cublas-cuda12' runs-on: ubuntu-latest steps: - name: Checkout uses: actions/checkout@v3 - - name: Prepare - id: prep - run: | - DOCKER_IMAGE=quay.io/go-skynet/local-ai - VERSION=master - SHORTREF=${GITHUB_SHA::8} - - # If this is git tag, use the tag name as a docker tag - if [[ $GITHUB_REF == refs/tags/* ]]; then - VERSION=${GITHUB_REF#refs/tags/} - fi - TAGS="${DOCKER_IMAGE}:${VERSION},${DOCKER_IMAGE}:${SHORTREF}" - - # If the VERSION looks like a version number, assume that - # this is the most recent version of the image and also - # tag it 'latest'. - if [[ $VERSION =~ ^v[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}$ ]]; then - TAGS="$TAGS,${DOCKER_IMAGE}:latest" - fi - - # Set output parameters. 
- echo ::set-output name=tags::${TAGS} - echo ::set-output name=docker_image::${DOCKER_IMAGE} + - name: Docker meta + id: meta + uses: docker/metadata-action@v4 + with: + images: quay.io/go-skynet/local-ai + tags: | + type=ref,event=branch + type=semver,pattern={{raw}} + type=sha + flavor: | + latest=${{ matrix.tag-latest }} + suffix=${{ matrix.tag-suffix }} - name: Set up QEMU uses: docker/setup-qemu-action@master @@ -60,23 +68,18 @@ jobs: registry: quay.io username: ${{ secrets.LOCALAI_REGISTRY_USERNAME }} password: ${{ secrets.LOCALAI_REGISTRY_PASSWORD }} - - name: Build - if: github.event_name != 'pull_request' + + - name: Build and push uses: docker/build-push-action@v4 with: builder: ${{ steps.buildx.outputs.name }} + build-args: | + BUILD_TYPE=${{ matrix.build-type }} + CUDA_MAJOR_VERSION=${{ matrix.cuda-major-version }} + CUDA_MINOR_VERSION=${{ matrix.cuda-minor-version }} context: . file: ./Dockerfile - platforms: linux/amd64,linux/arm64 - push: true - tags: ${{ steps.prep.outputs.tags }} - - name: Build PRs - if: github.event_name == 'pull_request' - uses: docker/build-push-action@v4 - with: - builder: ${{ steps.buildx.outputs.name }} - context: . 
- file: ./Dockerfile - platforms: linux/amd64 - push: false - tags: ${{ steps.prep.outputs.tags }} \ No newline at end of file + platforms: ${{ matrix.platforms }} + push: ${{ github.event_name != 'pull_request' }} + tags: ${{ steps.meta.outputs.tags }} + labels: ${{ steps.meta.outputs.labels }} diff --git a/.gitignore b/.gitignore index 878047ee..20215af2 100644 --- a/.gitignore +++ b/.gitignore @@ -24,3 +24,4 @@ release/ # just in case .DS_Store +.idea diff --git a/Makefile b/Makefile index 73507e96..8c79251d 100644 --- a/Makefile +++ b/Makefile @@ -3,13 +3,13 @@ GOTEST=$(GOCMD) test GOVET=$(GOCMD) vet BINARY_NAME=local-ai -GOLLAMA_VERSION?=62b6c079a47d6949c982ed8e684b94bdbf48b41c +GOLLAMA_VERSION?=10caf37d8b73386708b4373975b8917e6b212c0e GPT4ALL_REPO?=https://github.com/nomic-ai/gpt4all GPT4ALL_VERSION?=337c7fecacfa4ae6779046513ab090687a5b0ef6 GOGGMLTRANSFORMERS_VERSION?=13ccc22621bb21afecd38675a2b043498e2e756c RWKV_REPO?=https://github.com/donomii/go-rwkv.cpp RWKV_VERSION?=ccb05c3e1c6efd098017d114dcb58ab3262b40b2 -WHISPER_CPP_VERSION?=d7c936b44a80b8070676093fc00622333ba09cd3 +WHISPER_CPP_VERSION?=ce6f7470649f169027626dc92b3a2e39b4eff463 BERT_VERSION?=771b4a08597224b21cff070950ef4f68690e14ad BLOOMZ_VERSION?=1834e77b83faafe912ad4092ccf7f77937349e2f BUILD_TYPE?= diff --git a/README.md b/README.md index ae312d4f..45b3d40b 100644 --- a/README.md +++ b/README.md @@ -28,7 +28,7 @@ LocalAI was created by [Ettore Di Giacinto](https://github.com/mudler/) and is a | ![Screenshot from 2023-04-26 23-59-55](https://user-images.githubusercontent.com/2420543/234715439-98d12e03-d3ce-4f94-ab54-2b256808e05e.png) | ![b6441997879](https://github.com/go-skynet/LocalAI/assets/2420543/d50af51c-51b7-4f39-b6c2-bf04c403894c) | -See the [Getting started](https://localai.io/basics/getting_started/index.html) and [examples](https://github.com/go-skynet/LocalAI/tree/master/examples/) sections to learn how to use LocalAI. 
For a list of curated models check out the [model gallery](https://github.com/go-skynet/model-gallery). +See the [Getting started](https://localai.io/basics/getting_started/index.html) and [examples](https://github.com/go-skynet/LocalAI/tree/master/examples/) sections to learn how to use LocalAI. For a list of curated models check out the [model gallery](https://localai.io/models/). ## News diff --git a/api/prediction.go b/api/prediction.go index 4ae1b69a..8aad4228 100644 --- a/api/prediction.go +++ b/api/prediction.go @@ -9,6 +9,7 @@ import ( "sync" "github.com/donomii/go-rwkv.cpp" + "github.com/go-skynet/LocalAI/pkg/langchain" model "github.com/go-skynet/LocalAI/pkg/model" "github.com/go-skynet/LocalAI/pkg/stablediffusion" "github.com/go-skynet/bloomz.cpp" @@ -494,6 +495,23 @@ func ModelInference(s string, loader *model.ModelLoader, c Config, tokenCallback model.SetTokenCallback(nil) return str, er } + case *langchain.HuggingFace: + fn = func() (string, error) { + + // Generate the prediction using the language model + predictOptions := []langchain.PredictOption{ + langchain.SetModel(c.Model), + langchain.SetMaxTokens(c.Maxtokens), + langchain.SetTemperature(c.Temperature), + langchain.SetStopWords(c.StopWords), + } + + pred, er := model.PredictHuggingFace(s, predictOptions...) + if er != nil { + return "", er + } + return pred.Completion, nil + } } return func() (string, error) { diff --git a/examples/langchain-huggingface/README.md b/examples/langchain-huggingface/README.md new file mode 100644 index 00000000..23fdcd32 --- /dev/null +++ b/examples/langchain-huggingface/README.md @@ -0,0 +1,68 @@ +# Data query example + +Example of integration with HuggingFace Inference API with help of [langchaingo](https://github.com/tmc/langchaingo). 
+
+## Setup
+
+Download LocalAI and start the API:
+
+```bash
+# Clone LocalAI
+git clone https://github.com/go-skynet/LocalAI
+
+cd LocalAI/examples/langchain-huggingface
+
+docker-compose up -d
+```
+
+Note: Ensure you've set the `HUGGINGFACEHUB_API_TOKEN` environment variable; you can generate it
+on the [Settings / Access Tokens](https://huggingface.co/settings/tokens) page of the HuggingFace site.
+
+This is an example `.env` file for LocalAI:
+
+```ini
+MODELS_PATH=/models
+CONTEXT_SIZE=512
+HUGGINGFACEHUB_API_TOKEN=hg_123456
+```
+
+## Using remote models
+
+Now you can use any remote model available via the HuggingFace API; for example, let's enable the use of the
+[gpt2](https://huggingface.co/gpt2) model in the `gpt-3.5-turbo.yaml` config:
+
+```yml
+name: gpt-3.5-turbo
+parameters:
+  model: gpt2
+  top_k: 80
+  temperature: 0.2
+  top_p: 0.7
+context_size: 1024
+backend: "langchain-huggingface"
+stopwords:
+- "HUMAN:"
+- "GPT:"
+roles:
+  user: " "
+  system: " "
+template:
+  completion: completion
+  chat: gpt4all
+```
+
+Here you can see that the field `parameters.model` equals `gpt2` and `backend` equals `langchain-huggingface`.
+ +## How to use + +```shell +# Now API is accessible at localhost:8080 +curl http://localhost:8080/v1/models +# {"object":"list","data":[{"id":"gpt-3.5-turbo","object":"model"}]} + +curl http://localhost:8080/v1/completions -H "Content-Type: application/json" -d '{ + "model": "gpt-3.5-turbo", + "prompt": "A long time ago in a galaxy far, far away", + "temperature": 0.7 +}' +``` \ No newline at end of file diff --git a/examples/langchain-huggingface/docker-compose.yml b/examples/langchain-huggingface/docker-compose.yml new file mode 100644 index 00000000..96ef540e --- /dev/null +++ b/examples/langchain-huggingface/docker-compose.yml @@ -0,0 +1,15 @@ +version: '3.6' + +services: + api: + image: quay.io/go-skynet/local-ai:latest + build: + context: ../../ + dockerfile: Dockerfile + ports: + - 8080:8080 + env_file: + - ../../.env + volumes: + - ./models:/models:cached + command: ["/usr/bin/local-ai"] diff --git a/examples/langchain-huggingface/models/completion.tmpl b/examples/langchain-huggingface/models/completion.tmpl new file mode 100644 index 00000000..1e04a465 --- /dev/null +++ b/examples/langchain-huggingface/models/completion.tmpl @@ -0,0 +1 @@ +{{.Input}} diff --git a/examples/langchain-huggingface/models/gpt-3.5-turbo.yaml b/examples/langchain-huggingface/models/gpt-3.5-turbo.yaml new file mode 100644 index 00000000..76e9ab18 --- /dev/null +++ b/examples/langchain-huggingface/models/gpt-3.5-turbo.yaml @@ -0,0 +1,17 @@ +name: gpt-3.5-turbo +parameters: + model: gpt2 + top_k: 80 + temperature: 0.2 + top_p: 0.7 +context_size: 1024 +backend: "langchain-huggingface" +stopwords: +- "HUMAN:" +- "GPT:" +roles: + user: " " + system: " " +template: + completion: completion + chat: gpt4all diff --git a/examples/langchain-huggingface/models/gpt4all.tmpl b/examples/langchain-huggingface/models/gpt4all.tmpl new file mode 100644 index 00000000..f76b080a --- /dev/null +++ b/examples/langchain-huggingface/models/gpt4all.tmpl @@ -0,0 +1,4 @@ +The prompt below is a question to 
answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response. +### Prompt: +{{.Input}} +### Response: diff --git a/go.mod b/go.mod index d79fdd63..7ab82410 100644 --- a/go.mod +++ b/go.mod @@ -3,20 +3,20 @@ module github.com/go-skynet/LocalAI go 1.19 require ( - github.com/donomii/go-rwkv.cpp v0.0.0-20230529074347-ccb05c3e1c6e - github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230528233858-d7c936b44a80 + github.com/donomii/go-rwkv.cpp v0.0.0-20230531084548-c43cdf5fc5bf + github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230601065548-3f7436e8a096 github.com/go-audio/wav v1.1.0 github.com/go-skynet/bloomz.cpp v0.0.0-20230510223001-e9366e82abdf github.com/go-skynet/go-bert.cpp v0.0.0-20230529074307-771b4a085972 - github.com/go-skynet/go-ggml-transformers.cpp v0.0.0-20230529215936-13ccc22621bb - github.com/go-skynet/go-llama.cpp v0.0.0-20230530191504-62b6c079a47d + github.com/go-skynet/go-ggml-transformers.cpp v0.0.0-20230531065233-17b065584ef8 + github.com/go-skynet/go-llama.cpp v0.0.0-20230531065249-10caf37d8b73 github.com/gofiber/fiber/v2 v2.46.0 github.com/google/uuid v1.3.0 github.com/hashicorp/go-multierror v1.1.1 github.com/imdario/mergo v0.3.16 github.com/mudler/go-stable-diffusion v0.0.0-20230516152536-c0748eca3642 - github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230528235700-9eb81cb54922 - github.com/onsi/ginkgo/v2 v2.9.5 + github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230531011104-5f940208e4f5 + github.com/onsi/ginkgo/v2 v2.9.7 github.com/onsi/gomega v1.27.7 github.com/otiai10/openaigo v1.1.0 github.com/rs/zerolog v1.29.1 @@ -59,6 +59,7 @@ require ( github.com/savsgio/dictpool v0.0.0-20221023140959-7bf2e61cea94 // indirect github.com/savsgio/gotils v0.0.0-20230208104028-c358bd845dee // indirect github.com/tinylib/msgp v1.1.8 // indirect + github.com/tmc/langchaingo v0.0.0-20230530193922-fb062652f841 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect 
github.com/valyala/tcplisten v1.0.0 // indirect github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect diff --git a/go.sum b/go.sum index b5452689..872940b2 100644 --- a/go.sum +++ b/go.sum @@ -20,6 +20,8 @@ github.com/donomii/go-rwkv.cpp v0.0.0-20230515123100-6fdd0c338e56 h1:s8/MZdicstK github.com/donomii/go-rwkv.cpp v0.0.0-20230515123100-6fdd0c338e56/go.mod h1:gWy7FIWioqYmYxkaoFyBnaKApeZVrUkHhv9EV9pz4dM= github.com/donomii/go-rwkv.cpp v0.0.0-20230529074347-ccb05c3e1c6e h1:YbcLoxAwS0r7otEqU/d8bArubmfEJaG7dZPp0Aa52Io= github.com/donomii/go-rwkv.cpp v0.0.0-20230529074347-ccb05c3e1c6e/go.mod h1:gWy7FIWioqYmYxkaoFyBnaKApeZVrUkHhv9EV9pz4dM= +github.com/donomii/go-rwkv.cpp v0.0.0-20230531084548-c43cdf5fc5bf h1:upCz8WYdzMeJg0qywUaVaGndY+niuicj5j6V4pvhNS4= +github.com/donomii/go-rwkv.cpp v0.0.0-20230531084548-c43cdf5fc5bf/go.mod h1:gWy7FIWioqYmYxkaoFyBnaKApeZVrUkHhv9EV9pz4dM= github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230520182345-041be06d5881 h1:dafqVivljYk51VLFnnpTXJnfWDe637EobWZ1l8PyEf8= github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230520182345-041be06d5881/go.mod h1:QIjZ9OktHFG7p+/m3sMvrAJKKdWrr1fZIK0rM6HZlyo= github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230523110439-77eab3fbfe5e h1:4PMorQuoUGAXmIzCtnNOHaasyLokXdgd8jUWwsraFTo= @@ -30,6 +32,10 @@ github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230527074028-9b926844e3ae github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230527074028-9b926844e3ae/go.mod h1:QIjZ9OktHFG7p+/m3sMvrAJKKdWrr1fZIK0rM6HZlyo= github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230528233858-d7c936b44a80 h1:IeeVcNaQHdcG+GPg+meOPFvtonvO8p/HBzTrZGjpWZk= github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230528233858-d7c936b44a80/go.mod h1:QIjZ9OktHFG7p+/m3sMvrAJKKdWrr1fZIK0rM6HZlyo= +github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230531071314-ce6f7470649f h1:oGTI2SlcA7oGPFsmkS1m8psq3uKNnhhJ/MZ2ZWVZDe0= +github.com/ggerganov/whisper.cpp/bindings/go 
v0.0.0-20230531071314-ce6f7470649f/go.mod h1:QIjZ9OktHFG7p+/m3sMvrAJKKdWrr1fZIK0rM6HZlyo= +github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230601065548-3f7436e8a096 h1:TD7v8FnwWCWlOsrkpnumsbxsflyhTI3rSm2HInqqSAI= +github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230601065548-3f7436e8a096/go.mod h1:QIjZ9OktHFG7p+/m3sMvrAJKKdWrr1fZIK0rM6HZlyo= github.com/go-audio/audio v1.0.0 h1:zS9vebldgbQqktK4H0lUqWrG8P0NxCJVqcj7ZpNnwd4= github.com/go-audio/audio v1.0.0/go.mod h1:6uAu0+H2lHkwdGsAY+j2wHPNPpPoeg5AaEFh9FlA+Zs= github.com/go-audio/riff v1.0.0 h1:d8iCGbDvox9BfLagY94fBynxSPHO80LmZCaOsmKxokA= @@ -64,6 +70,8 @@ github.com/go-skynet/go-ggml-transformers.cpp v0.0.0-20230529072326-695f97befe14 github.com/go-skynet/go-ggml-transformers.cpp v0.0.0-20230529072326-695f97befe14/go.mod h1:Rz967+t+aY6S+TBiW/WI8FM/C1WEMM+DamSMtKRxVAM= github.com/go-skynet/go-ggml-transformers.cpp v0.0.0-20230529215936-13ccc22621bb h1:slNlMT8xB6w0QaMroTsqkNzNovUOEkpNpCawB7IjBFY= github.com/go-skynet/go-ggml-transformers.cpp v0.0.0-20230529215936-13ccc22621bb/go.mod h1:SI+oF2+THMydq8Vo4+EzKJaQwtfWOy+lr7yWPP6FR2U= +github.com/go-skynet/go-ggml-transformers.cpp v0.0.0-20230531065233-17b065584ef8 h1:LK1DAgJsNMRUWaPpFOnE8XSF70UBybr3zGOvzP8Pdok= +github.com/go-skynet/go-ggml-transformers.cpp v0.0.0-20230531065233-17b065584ef8/go.mod h1:/JbU8HZU+tUOp+1bQAeXf3AyRXm+p3UwhccoJwCTI9A= github.com/go-skynet/go-gpt2.cpp v0.0.0-20230523153133-3eb3a32c0874 h1:/6QWh2oarU7iPSpXj/3bLlkKptyxjKTRrNtGUrh8vhI= github.com/go-skynet/go-gpt2.cpp v0.0.0-20230523153133-3eb3a32c0874/go.mod h1:1Wj/xbkMfwQSOrhNYK178IzqQHstZbRfhx4s8p1M5VM= github.com/go-skynet/go-llama.cpp v0.0.0-20230520155239-ccf23adfb278 h1:st4ow9JKy3UuhkwutrbWof2vMFU/YxwBCLYZ1IxJ2Po= @@ -78,6 +86,8 @@ github.com/go-skynet/go-llama.cpp v0.0.0-20230529221033-4afcaf28f36f h1:HmXiNF9S github.com/go-skynet/go-llama.cpp v0.0.0-20230529221033-4afcaf28f36f/go.mod h1:oA0r4BW8ndyjTMGi1tulsNd7sdg3Ql8MaVFuT1zF6ws= github.com/go-skynet/go-llama.cpp 
v0.0.0-20230530191504-62b6c079a47d h1:daPcVEptc/6arcS/QV4QDCdYiwMGCiiR5rnzUs63WK0= github.com/go-skynet/go-llama.cpp v0.0.0-20230530191504-62b6c079a47d/go.mod h1:oA0r4BW8ndyjTMGi1tulsNd7sdg3Ql8MaVFuT1zF6ws= +github.com/go-skynet/go-llama.cpp v0.0.0-20230531065249-10caf37d8b73 h1:swwsrYpPYOsyGFrX/0nhaYa93aHH6I61HpSJpQkN1tY= +github.com/go-skynet/go-llama.cpp v0.0.0-20230531065249-10caf37d8b73/go.mod h1:ddYIvPZyj3Vf4XkfZimVRRehZu2isd0JXfK3EemVQPk= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= @@ -139,8 +149,12 @@ github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230526132403-a6f3e9 github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230526132403-a6f3e94458e2/go.mod h1:4T3CHXyrt+7FQHXaxULZfPjHbD8/99WuDDJa0YVZARI= github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230528235700-9eb81cb54922 h1:teYhrXxFY28gyBm6QMcYewA0KvLXqkUsgxJcYelaxbg= github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230528235700-9eb81cb54922/go.mod h1:4T3CHXyrt+7FQHXaxULZfPjHbD8/99WuDDJa0YVZARI= +github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230531011104-5f940208e4f5 h1:99cF+V5wk7IInDAEM9HAlSHdLf/xoJR529Wr8lAG5KQ= +github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230531011104-5f940208e4f5/go.mod h1:4T3CHXyrt+7FQHXaxULZfPjHbD8/99WuDDJa0YVZARI= github.com/onsi/ginkgo/v2 v2.9.5 h1:+6Hr4uxzP4XIUyAkg61dWBw8lb/gc4/X5luuxN/EC+Q= github.com/onsi/ginkgo/v2 v2.9.5/go.mod h1:tvAoo1QUJwNEU2ITftXTpR7R1RbCzoZUOs3RonqW57k= +github.com/onsi/ginkgo/v2 v2.9.7 h1:06xGQy5www2oN160RtEZoTvnP2sPhEfePYmCDc2szss= +github.com/onsi/ginkgo/v2 v2.9.7/go.mod h1:cxrmXWykAwTwhQsJOPfdIDiJ+l2RYq7U8hFU+M/1uw0= github.com/onsi/gomega v1.27.7 h1:fVih9JD6ogIiHUN6ePK7HJidyEDpWGVB5mzM7cWNXoU= 
github.com/onsi/gomega v1.27.7/go.mod h1:1p8OOlwo2iUUDsHnOrjE5UKYJ+e3W8eQ3qSlRahPmr4= github.com/otiai10/mint v1.5.1 h1:XaPLeE+9vGbuyEHem1JNk3bYc7KKqyI/na0/mLd/Kks= @@ -178,6 +192,8 @@ github.com/swaggo/swag v1.16.1/go.mod h1:9/LMvHycG3NFHfR6LwvikHv5iFvmPADQ359cKik github.com/tinylib/msgp v1.1.6/go.mod h1:75BAfg2hauQhs3qedfdDZmWAPcFMAvJE5b9rGOMufyw= github.com/tinylib/msgp v1.1.8 h1:FCXC1xanKO4I8plpHGH2P7koL/RzZs12l/+r7vakfm0= github.com/tinylib/msgp v1.1.8/go.mod h1:qkpG+2ldGg4xRFmx+jfTvZPxfGFhi64BcnL9vkCm/Tw= +github.com/tmc/langchaingo v0.0.0-20230530193922-fb062652f841 h1:IVlfKPZzq3W1G+CkhZgN5VjmHnAeB3YqEvxyNPPCZXY= +github.com/tmc/langchaingo v0.0.0-20230530193922-fb062652f841/go.mod h1:6l1WoyqVDwkv7cFlY3gfcTv8yVowVyuutKv8PGlQCWI= github.com/urfave/cli/v2 v2.25.3 h1:VJkt6wvEBOoSjPFQvOkv6iWIrsJyCrKGtCtxXWwmGeY= github.com/urfave/cli/v2 v2.25.3/go.mod h1:GHupkWPMM0M/sj1a2b4wUrWBPzazNrIjouW6fmdJLxc= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= diff --git a/pkg/langchain/huggingface.go b/pkg/langchain/huggingface.go new file mode 100644 index 00000000..38c55cd5 --- /dev/null +++ b/pkg/langchain/huggingface.go @@ -0,0 +1,47 @@ +package langchain + +import ( + "context" + + "github.com/tmc/langchaingo/llms" + "github.com/tmc/langchaingo/llms/huggingface" +) + +type HuggingFace struct { + modelPath string +} + +func NewHuggingFace(repoId string) (*HuggingFace, error) { + return &HuggingFace{ + modelPath: repoId, + }, nil +} + +func (s *HuggingFace) PredictHuggingFace(text string, opts ...PredictOption) (*Predict, error) { + po := NewPredictOptions(opts...) 
+ + // Init client + llm, err := huggingface.New() + if err != nil { + return nil, err + } + + // Convert from LocalAI to LangChainGo format of options + co := []llms.CallOption{ + llms.WithModel(po.Model), + llms.WithMaxTokens(po.MaxTokens), + llms.WithTemperature(po.Temperature), + llms.WithStopWords(po.StopWords), + } + + // Call Inference API + ctx := context.Background() + completion, err := llm.Call(ctx, text, co...) + if err != nil { + return nil, err + } + + return &Predict{ + Completion: completion, + }, nil +} diff --git a/pkg/langchain/langchain.go b/pkg/langchain/langchain.go new file mode 100644 index 00000000..737bc4b5 --- /dev/null +++ b/pkg/langchain/langchain.go @@ -0,0 +1,57 @@ +package langchain + +type PredictOptions struct { + Model string `json:"model"` + // MaxTokens is the maximum number of tokens to generate. + MaxTokens int `json:"max_tokens"` + // Temperature is the temperature for sampling, between 0 and 1. + Temperature float64 `json:"temperature"` + // StopWords is a list of words to stop on. + StopWords []string `json:"stop_words"` +} + +type PredictOption func(p *PredictOptions) + +var DefaultOptions = PredictOptions{ + Model: "gpt2", + MaxTokens: 200, + Temperature: 0.96, + StopWords: nil, +} + +type Predict struct { + Completion string +} + +func SetModel(model string) PredictOption { + return func(o *PredictOptions) { + o.Model = model + } +} + +func SetTemperature(temperature float64) PredictOption { + return func(o *PredictOptions) { + o.Temperature = temperature + } +} + +func SetMaxTokens(maxTokens int) PredictOption { + return func(o *PredictOptions) { + o.MaxTokens = maxTokens + } +} + +func SetStopWords(stopWords []string) PredictOption { + return func(o *PredictOptions) { + o.StopWords = stopWords + } +} + +// NewPredictOptions Create a new PredictOptions object with the given options. 
+func NewPredictOptions(opts ...PredictOption) PredictOptions { + p := DefaultOptions + for _, opt := range opts { + opt(&p) + } + return p +} diff --git a/pkg/model/initializers.go b/pkg/model/initializers.go index dc593a7c..518e59f1 100644 --- a/pkg/model/initializers.go +++ b/pkg/model/initializers.go @@ -7,6 +7,7 @@ import ( rwkv "github.com/donomii/go-rwkv.cpp" whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" + "github.com/go-skynet/LocalAI/pkg/langchain" "github.com/go-skynet/LocalAI/pkg/stablediffusion" bloomz "github.com/go-skynet/bloomz.cpp" bert "github.com/go-skynet/go-bert.cpp" @@ -36,6 +37,7 @@ const ( RwkvBackend = "rwkv" WhisperBackend = "whisper" StableDiffusionBackend = "stablediffusion" + LCHuggingFaceBackend = "langchain-huggingface" ) var backends []string = []string{ @@ -100,6 +102,10 @@ var whisperModel = func(modelFile string) (interface{}, error) { return whisper.New(modelFile) } +var lcHuggingFace = func(repoId string) (interface{}, error) { + return langchain.NewHuggingFace(repoId) +} + func llamaLM(opts ...llama.ModelOption) func(string) (interface{}, error) { return func(s string) (interface{}, error) { return llama.New(s, opts...) @@ -159,6 +165,8 @@ func (ml *ModelLoader) BackendLoader(backendString string, modelFile string, lla return ml.LoadModel(modelFile, rwkvLM(filepath.Join(ml.ModelPath, modelFile+tokenizerSuffix), threads)) case WhisperBackend: return ml.LoadModel(modelFile, whisperModel) + case LCHuggingFaceBackend: + return ml.LoadModel(modelFile, lcHuggingFace) default: return nil, fmt.Errorf("backend unsupported: %s", backendString) }