Merge branch 'master' into fix/stream_tokens_usage
This commit is contained in: commit 9f6be2be12
79 changed files with 2716 additions and 941 deletions
5  .github/labeler.yml (vendored)

@@ -1,6 +1,11 @@
 enhancements:
 - head-branch: ['^feature', 'feature']
 
+dependencies:
+- any:
+  - changed-files:
+    - any-glob-to-any-file: 'Makefile'
+
 kind/documentation:
 - any:
   - changed-files:
17  .github/workflows/bump_deps.yaml (vendored)

@@ -12,23 +12,14 @@ jobs:
       - repository: "ggerganov/llama.cpp"
         variable: "CPPLLAMA_VERSION"
         branch: "master"
-      - repository: "go-skynet/go-ggml-transformers.cpp"
-        variable: "GOGGMLTRANSFORMERS_VERSION"
-        branch: "master"
-      - repository: "donomii/go-rwkv.cpp"
-        variable: "RWKV_VERSION"
-        branch: "main"
       - repository: "ggerganov/whisper.cpp"
         variable: "WHISPER_CPP_VERSION"
         branch: "master"
-      - repository: "go-skynet/go-bert.cpp"
-        variable: "BERT_VERSION"
-        branch: "master"
-      - repository: "go-skynet/bloomz.cpp"
-        variable: "BLOOMZ_VERSION"
+      - repository: "PABannier/bark.cpp"
+        variable: "BARKCPP_VERSION"
         branch: "main"
-      - repository: "mudler/go-ggllm.cpp"
-        variable: "GOGGLLM_VERSION"
+      - repository: "leejet/stable-diffusion.cpp"
+        variable: "STABLEDIFFUSION_GGML_VERSION"
         branch: "master"
       - repository: "mudler/go-stable-diffusion"
         variable: "STABLEDIFFUSION_VERSION"
1  .gitignore (vendored)

@@ -2,6 +2,7 @@
 /sources/
 __pycache__/
 *.a
+*.o
 get-sources
 prepare-sources
 /backend/cpp/llama/grpc-server
72  Makefile

@@ -8,7 +8,7 @@ DETECT_LIBS?=true
 # llama.cpp versions
 GOLLAMA_REPO?=https://github.com/go-skynet/go-llama.cpp
 GOLLAMA_VERSION?=2b57a8ae43e4699d3dc5d1496a1ccd42922993be
-CPPLLAMA_VERSION?=3ad5451f3b75809e3033e4e577b9f60bcaf6676a
+CPPLLAMA_VERSION?=08ea539df211e46bb4d0dd275e541cb591d5ebc8
 
 # whisper.cpp version
 WHISPER_REPO?=https://github.com/ggerganov/whisper.cpp
@@ -26,6 +26,14 @@ STABLEDIFFUSION_VERSION?=4a3cd6aeae6f66ee57eae9a0075f8c58c3a6a38f
 TINYDREAM_REPO?=https://github.com/M0Rf30/go-tiny-dream
 TINYDREAM_VERSION?=c04fa463ace9d9a6464313aa5f9cd0f953b6c057
 
+# bark.cpp
+BARKCPP_REPO?=https://github.com/PABannier/bark.cpp.git
+BARKCPP_VERSION?=v1.0.0
+
+# stablediffusion.cpp (ggml)
+STABLEDIFFUSION_GGML_REPO?=https://github.com/leejet/stable-diffusion.cpp
+STABLEDIFFUSION_GGML_VERSION?=9578fdcc4632dc3de5565f28e2fb16b7c18f8d48
+
 ONNX_VERSION?=1.20.0
 ONNX_ARCH?=x64
 ONNX_OS?=linux
@@ -201,6 +209,14 @@ ALL_GRPC_BACKENDS+=backend-assets/grpc/llama-ggml
 ALL_GRPC_BACKENDS+=backend-assets/grpc/llama-cpp-grpc
 ALL_GRPC_BACKENDS+=backend-assets/util/llama-cpp-rpc-server
 ALL_GRPC_BACKENDS+=backend-assets/grpc/whisper
+
+ifeq ($(ONNX_OS),linux)
+	ifeq ($(ONNX_ARCH),x64)
+		ALL_GRPC_BACKENDS+=backend-assets/grpc/bark-cpp
+		ALL_GRPC_BACKENDS+=backend-assets/grpc/stablediffusion-ggml
+	endif
+endif
+
 ALL_GRPC_BACKENDS+=backend-assets/grpc/local-store
 ALL_GRPC_BACKENDS+=backend-assets/grpc/silero-vad
 ALL_GRPC_BACKENDS+=$(OPTIONAL_GRPC)
@@ -236,6 +252,23 @@ sources/go-llama.cpp:
 sources/go-llama.cpp/libbinding.a: sources/go-llama.cpp
 	$(MAKE) -C sources/go-llama.cpp BUILD_TYPE=$(STABLE_BUILD_TYPE) libbinding.a
 
+## bark.cpp
+sources/bark.cpp:
+	git clone --recursive $(BARKCPP_REPO) sources/bark.cpp && \
+	cd sources/bark.cpp && \
+	git checkout $(BARKCPP_VERSION) && \
+	git submodule update --init --recursive --depth 1 --single-branch
+
+sources/bark.cpp/build/libbark.a: sources/bark.cpp
+	cd sources/bark.cpp && \
+	mkdir -p build && \
+	cd build && \
+	cmake $(CMAKE_ARGS) .. && \
+	cmake --build . --config Release
+
+backend/go/bark/libbark.a: sources/bark.cpp/build/libbark.a
+	$(MAKE) -C backend/go/bark libbark.a
+
 ## go-piper
 sources/go-piper:
 	mkdir -p sources/go-piper
@@ -249,7 +282,7 @@ sources/go-piper:
 sources/go-piper/libpiper_binding.a: sources/go-piper
 	$(MAKE) -C sources/go-piper libpiper_binding.a example/main piper.o
 
-## stable diffusion
+## stable diffusion (onnx)
 sources/go-stable-diffusion:
 	mkdir -p sources/go-stable-diffusion
 	cd sources/go-stable-diffusion && \
@@ -262,6 +295,30 @@ sources/go-stable-diffusion:
 sources/go-stable-diffusion/libstablediffusion.a: sources/go-stable-diffusion
 	CPATH="$(CPATH):/usr/include/opencv4" $(MAKE) -C sources/go-stable-diffusion libstablediffusion.a
 
+## stablediffusion (ggml)
+sources/stablediffusion-ggml.cpp:
+	git clone --recursive $(STABLEDIFFUSION_GGML_REPO) sources/stablediffusion-ggml.cpp && \
+	cd sources/stablediffusion-ggml.cpp && \
+	git checkout $(STABLEDIFFUSION_GGML_VERSION) && \
+	git submodule update --init --recursive --depth 1 --single-branch
+
+sources/stablediffusion-ggml.cpp/build/libstable-diffusion.a: sources/stablediffusion-ggml.cpp
+	cd sources/stablediffusion-ggml.cpp && \
+	mkdir -p build && \
+	cd build && \
+	cmake $(CMAKE_ARGS) .. && \
+	cmake --build . --config Release
+
+backend/go/image/stablediffusion-ggml/libsd.a: sources/stablediffusion-ggml.cpp/build/libstable-diffusion.a
+	$(MAKE) -C backend/go/image/stablediffusion-ggml libsd.a
+
+backend-assets/grpc/stablediffusion-ggml: backend/go/image/stablediffusion-ggml/libsd.a backend-assets/grpc
+	CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(CURDIR)/backend/go/image/stablediffusion-ggml/ LIBRARY_PATH=$(CURDIR)/backend/go/image/stablediffusion-ggml/ \
+	$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/stablediffusion-ggml ./backend/go/image/stablediffusion-ggml/
+ifneq ($(UPX),)
+	$(UPX) backend-assets/grpc/stablediffusion-ggml
+endif
+
 sources/onnxruntime:
 	mkdir -p sources/onnxruntime
 	curl -L https://github.com/microsoft/onnxruntime/releases/download/v$(ONNX_VERSION)/onnxruntime-$(ONNX_OS)-$(ONNX_ARCH)-$(ONNX_VERSION).tgz -o sources/onnxruntime/onnxruntime-$(ONNX_OS)-$(ONNX_ARCH)-$(ONNX_VERSION).tgz
@@ -302,7 +359,7 @@ sources/whisper.cpp:
 sources/whisper.cpp/libwhisper.a: sources/whisper.cpp
 	cd sources/whisper.cpp && $(MAKE) libwhisper.a libggml.a
 
-get-sources: sources/go-llama.cpp sources/go-piper sources/whisper.cpp sources/go-stable-diffusion sources/go-tiny-dream backend/cpp/llama/llama.cpp
+get-sources: sources/go-llama.cpp sources/go-piper sources/stablediffusion-ggml.cpp sources/bark.cpp sources/whisper.cpp sources/go-stable-diffusion sources/go-tiny-dream backend/cpp/llama/llama.cpp
 
 replace:
 	$(GOCMD) mod edit -replace github.com/ggerganov/whisper.cpp=$(CURDIR)/sources/whisper.cpp
@@ -343,7 +400,9 @@ clean: ## Remove build related file
 	rm -rf release/
 	rm -rf backend-assets/*
 	$(MAKE) -C backend/cpp/grpc clean
+	$(MAKE) -C backend/go/bark clean
 	$(MAKE) -C backend/cpp/llama clean
+	$(MAKE) -C backend/go/image/stablediffusion-ggml clean
 	rm -rf backend/cpp/llama-* || true
 	$(MAKE) dropreplace
 	$(MAKE) protogen-clean
@@ -792,6 +851,13 @@ ifneq ($(UPX),)
 	$(UPX) backend-assets/grpc/llama-ggml
 endif
 
+backend-assets/grpc/bark-cpp: backend/go/bark/libbark.a backend-assets/grpc
+	CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(CURDIR)/backend/go/bark/ LIBRARY_PATH=$(CURDIR)/backend/go/bark/ \
+	$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/bark-cpp ./backend/go/bark/
+ifneq ($(UPX),)
+	$(UPX) backend-assets/grpc/bark-cpp
+endif
+
 backend-assets/grpc/piper: sources/go-piper sources/go-piper/libpiper_binding.a backend-assets/grpc backend-assets/espeak-ng-data
 	CGO_CXXFLAGS="$(PIPER_CGO_CXXFLAGS)" CGO_LDFLAGS="$(PIPER_CGO_LDFLAGS)" LIBRARY_PATH=$(CURDIR)/sources/go-piper \
 	$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/piper ./backend/go/tts/
@@ -92,6 +92,8 @@ local-ai run oci://localai/phi-2:latest
 
 ## 📰 Latest project news
 
+- Dec 2024: stablediffusion.cpp backend (ggml) added ( https://github.com/mudler/LocalAI/pull/4289 )
+- Nov 2024: Bark.cpp backend added ( https://github.com/mudler/LocalAI/pull/4287 )
 - Nov 2024: Voice activity detection models (**VAD**) added to the API: https://github.com/mudler/LocalAI/pull/4204
 - Oct 2024: examples moved to [LocalAI-examples](https://github.com/mudler/LocalAI-examples)
 - Aug 2024: 🆕 FLUX-1, [P2P Explorer](https://explorer.localai.io)
@@ -240,6 +240,11 @@ message ModelOptions {
 
   repeated string LoraAdapters = 60;
   repeated float LoraScales = 61;
 
+  repeated string Options = 62;
+
+  string CacheTypeKey = 63;
+  string CacheTypeValue = 64;
 }
 
 message Result {
@@ -681,7 +681,6 @@ struct llama_server_context
         slot->sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat);
         slot->sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau);
         slot->sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta);
-        slot->sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
         slot->params.n_keep = json_value(data, "n_keep", slot->params.n_keep);
         slot->sparams.seed = json_value(data, "seed", default_sparams.seed);
         slot->sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
@@ -1213,13 +1212,12 @@ struct llama_server_context
             {"mirostat", slot.sparams.mirostat},
             {"mirostat_tau", slot.sparams.mirostat_tau},
             {"mirostat_eta", slot.sparams.mirostat_eta},
-            {"penalize_nl", slot.sparams.penalize_nl},
             {"stop", slot.params.antiprompt},
             {"n_predict", slot.params.n_predict},
             {"n_keep", params.n_keep},
             {"ignore_eos", slot.sparams.ignore_eos},
             {"stream", slot.params.stream},
             // {"logit_bias", slot.sparams.logit_bias},
             {"n_probs", slot.sparams.n_probs},
             {"min_keep", slot.sparams.min_keep},
             {"grammar", slot.sparams.grammar},
@@ -2112,7 +2110,6 @@ json parse_options(bool streaming, const backend::PredictOptions* predict, llama
     // slot->sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat);
     // slot->sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau);
     // slot->sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta);
-    // slot->sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
     // slot->params.n_keep = json_value(data, "n_keep", slot->params.n_keep);
     // slot->params.seed = json_value(data, "seed", default_params.seed);
     // slot->sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
@@ -2135,7 +2132,6 @@ json parse_options(bool streaming, const backend::PredictOptions* predict, llama
     data["mirostat"] = predict->mirostat();
     data["mirostat_tau"] = predict->mirostattau();
     data["mirostat_eta"] = predict->mirostateta();
-    data["penalize_nl"] = predict->penalizenl();
     data["n_keep"] = predict->nkeep();
     data["seed"] = predict->seed();
     data["grammar"] = predict->grammar();
@@ -2181,7 +2177,6 @@ json parse_options(bool streaming, const backend::PredictOptions* predict, llama
     // llama.params.sparams.mirostat = predict->mirostat();
     // llama.params.sparams.mirostat_tau = predict->mirostattau();
     // llama.params.sparams.mirostat_eta = predict->mirostateta();
-    // llama.params.sparams.penalize_nl = predict->penalizenl();
     // llama.params.n_keep = predict->nkeep();
     // llama.params.seed = predict->seed();
     // llama.params.sparams.grammar = predict->grammar();
@@ -2228,6 +2223,35 @@ json parse_options(bool streaming, const backend::PredictOptions* predict, llama
     // }
 // }
 
+const std::vector<ggml_type> kv_cache_types = {
+    GGML_TYPE_F32,
+    GGML_TYPE_F16,
+    GGML_TYPE_BF16,
+    GGML_TYPE_Q8_0,
+    GGML_TYPE_Q4_0,
+    GGML_TYPE_Q4_1,
+    GGML_TYPE_IQ4_NL,
+    GGML_TYPE_Q5_0,
+    GGML_TYPE_Q5_1,
+};
+
+static ggml_type kv_cache_type_from_str(const std::string & s) {
+    for (const auto & type : kv_cache_types) {
+        if (ggml_type_name(type) == s) {
+            return type;
+        }
+    }
+    throw std::runtime_error("Unsupported cache type: " + s);
+}
+
+static std::string get_all_kv_cache_types() {
+    std::ostringstream msg;
+    for (const auto & type : kv_cache_types) {
+        msg << ggml_type_name(type) << (&type == &kv_cache_types.back() ? "" : ", ");
+    }
+    return msg.str();
+}
+
 static void params_parse(const backend::ModelOptions* request,
                          common_params & params) {
 
@@ -2241,6 +2265,12 @@ static void params_parse(const backend::ModelOptions* request,
     }
     // params.model_alias ??
     params.model_alias = request->modelfile();
+    if (!request->cachetypekey().empty()) {
+        params.cache_type_k = kv_cache_type_from_str(request->cachetypekey());
+    }
+    if (!request->cachetypevalue().empty()) {
+        params.cache_type_v = kv_cache_type_from_str(request->cachetypevalue());
+    }
     params.n_ctx = request->contextsize();
     //params.memory_f16 = request->f16memory();
     params.cpuparams.n_threads = request->threads();
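Note: taken together with the new ModelOptions fields above (Options, CacheTypeKey, CacheTypeValue), this lets a caller request quantized KV-cache types from the llama.cpp backend; the accepted strings are the ggml type names of the kv_cache_types table (f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1). A minimal sketch of how the fields look from the Go side, with illustrative values only:

package main

import pb "github.com/mudler/LocalAI/pkg/grpc/proto"

// Illustrative only: the cache type strings must match a ggml type name
// accepted by kv_cache_type_from_str in the llama.cpp gRPC server.
func exampleModelOptions() *pb.ModelOptions {
	return &pb.ModelOptions{
		ContextSize:    4096,
		CacheTypeKey:   "q8_0", // quantized K cache
		CacheTypeValue: "q8_0", // quantized V cache
	}
}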
25  backend/go/bark/Makefile (new file)

@@ -0,0 +1,25 @@
+INCLUDE_PATH := $(abspath ./)
+LIBRARY_PATH := $(abspath ./)
+
+AR?=ar
+
+BUILD_TYPE?=
+# keep standard at C11 and C++11
+CXXFLAGS = -I. -I$(INCLUDE_PATH)/../../../sources/bark.cpp/examples -I$(INCLUDE_PATH)/../../../sources/bark.cpp/spm-headers -I$(INCLUDE_PATH)/../../../sources/bark.cpp -O3 -DNDEBUG -std=c++17 -fPIC
+LDFLAGS = -L$(LIBRARY_PATH) -L$(LIBRARY_PATH)/../../../sources/bark.cpp/build/examples -lbark -lstdc++ -lm
+
+# warnings
+CXXFLAGS += -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function
+
+gobark.o:
+	$(CXX) $(CXXFLAGS) gobark.cpp -o gobark.o -c $(LDFLAGS)
+
+libbark.a: gobark.o
+	cp $(INCLUDE_PATH)/../../../sources/bark.cpp/build/libbark.a ./
+	$(AR) rcs libbark.a gobark.o
+	$(AR) rcs libbark.a $(LIBRARY_PATH)/../../../sources/bark.cpp/build/encodec.cpp/ggml/src/CMakeFiles/ggml.dir/ggml.c.o
+	$(AR) rcs libbark.a $(LIBRARY_PATH)/../../../sources/bark.cpp/build/encodec.cpp/ggml/src/CMakeFiles/ggml.dir/ggml-alloc.c.o
+	$(AR) rcs libbark.a $(LIBRARY_PATH)/../../../sources/bark.cpp/build/encodec.cpp/ggml/src/CMakeFiles/ggml.dir/ggml-backend.c.o
+
+clean:
+	rm -f gobark.o libbark.a
85  backend/go/bark/gobark.cpp (new file)

@@ -0,0 +1,85 @@
+#include <iostream>
+#include <tuple>
+
+#include "bark.h"
+#include "gobark.h"
+#include "common.h"
+#include "ggml.h"
+
+struct bark_context *c;
+
+void bark_print_progress_callback(struct bark_context *bctx, enum bark_encoding_step step, int progress, void *user_data) {
+    if (step == bark_encoding_step::SEMANTIC) {
+        printf("\rGenerating semantic tokens... %d%%", progress);
+    } else if (step == bark_encoding_step::COARSE) {
+        printf("\rGenerating coarse tokens... %d%%", progress);
+    } else if (step == bark_encoding_step::FINE) {
+        printf("\rGenerating fine tokens... %d%%", progress);
+    }
+    fflush(stdout);
+}
+
+int load_model(char *model) {
+    // initialize bark context
+    struct bark_context_params ctx_params = bark_context_default_params();
+    bark_params params;
+
+    params.model_path = model;
+
+    // ctx_params.verbosity = verbosity;
+    ctx_params.progress_callback = bark_print_progress_callback;
+    ctx_params.progress_callback_user_data = nullptr;
+
+    struct bark_context *bctx = bark_load_model(params.model_path.c_str(), ctx_params, params.seed);
+    if (!bctx) {
+        fprintf(stderr, "%s: Could not load model\n", __func__);
+        return 1;
+    }
+
+    c = bctx;
+
+    return 0;
+}
+
+int tts(char *text,int threads, char *dst ) {
+
+    ggml_time_init();
+    const int64_t t_main_start_us = ggml_time_us();
+
+    // generate audio
+    if (!bark_generate_audio(c, text, threads)) {
+        fprintf(stderr, "%s: An error occured. If the problem persists, feel free to open an issue to report it.\n", __func__);
+        return 1;
+    }
+
+    const float *audio_data = bark_get_audio_data(c);
+    if (audio_data == NULL) {
+        fprintf(stderr, "%s: Could not get audio data\n", __func__);
+        return 1;
+    }
+
+    const int audio_arr_size = bark_get_audio_data_size(c);
+
+    std::vector<float> audio_arr(audio_data, audio_data + audio_arr_size);
+
+    write_wav_on_disk(audio_arr, dst);
+
+    // report timing
+    {
+        const int64_t t_main_end_us = ggml_time_us();
+        const int64_t t_load_us = bark_get_load_time(c);
+        const int64_t t_eval_us = bark_get_eval_time(c);
+
+        printf("\n\n");
+        printf("%s:     load time = %8.2f ms\n", __func__, t_load_us / 1000.0f);
+        printf("%s:     eval time = %8.2f ms\n", __func__, t_eval_us / 1000.0f);
+        printf("%s:    total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us) / 1000.0f);
+    }
+
+    return 0;
+}
+
+int unload() {
+    bark_free(c);
+}
backend/go/bark/gobark.go
Normal file
52
backend/go/bark/gobark.go
Normal file
|
@ -0,0 +1,52 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
// #cgo CXXFLAGS: -I${SRCDIR}/../../../sources/bark.cpp/ -I${SRCDIR}/../../../sources/bark.cpp/encodec.cpp -I${SRCDIR}/../../../sources/bark.cpp/examples -I${SRCDIR}/../../../sources/bark.cpp/spm-headers
|
||||||
|
// #cgo LDFLAGS: -L${SRCDIR}/ -L${SRCDIR}/../../../sources/bark.cpp/build/examples -L${SRCDIR}/../../../sources/bark.cpp/build/encodec.cpp/ -lbark -lencodec -lcommon
|
||||||
|
// #include <gobark.h>
|
||||||
|
// #include <stdlib.h>
|
||||||
|
import "C"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||||
|
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Bark struct {
|
||||||
|
base.SingleThread
|
||||||
|
threads int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sd *Bark) Load(opts *pb.ModelOptions) error {
|
||||||
|
|
||||||
|
sd.threads = int(opts.Threads)
|
||||||
|
|
||||||
|
modelFile := C.CString(opts.ModelFile)
|
||||||
|
defer C.free(unsafe.Pointer(modelFile))
|
||||||
|
|
||||||
|
ret := C.load_model(modelFile)
|
||||||
|
if ret != 0 {
|
||||||
|
return fmt.Errorf("inference failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sd *Bark) TTS(opts *pb.TTSRequest) error {
|
||||||
|
t := C.CString(opts.Text)
|
||||||
|
defer C.free(unsafe.Pointer(t))
|
||||||
|
|
||||||
|
dst := C.CString(opts.Dst)
|
||||||
|
defer C.free(unsafe.Pointer(dst))
|
||||||
|
|
||||||
|
threads := C.int(sd.threads)
|
||||||
|
|
||||||
|
ret := C.tts(t, threads, dst)
|
||||||
|
if ret != 0 {
|
||||||
|
return fmt.Errorf("inference failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
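Note: a rough sketch of how the Bark backend's two methods are exercised once the bark-cpp gRPC server is built; the model and output paths are hypothetical, and in a real deployment LocalAI issues these calls over gRPC rather than in-process:

package main

import pb "github.com/mudler/LocalAI/pkg/grpc/proto"

// Hypothetical driver for the Bark backend above (paths are made up).
func runBarkOnce() error {
	b := &Bark{}
	if err := b.Load(&pb.ModelOptions{ModelFile: "/models/bark/ggml_weights.bin", Threads: 4}); err != nil {
		return err
	}
	// tts() in gobark.cpp writes a WAV file to the requested destination.
	return b.TTS(&pb.TTSRequest{Text: "Hello from bark.cpp", Dst: "/tmp/out.wav"})
}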
8  backend/go/bark/gobark.h (new file)

@@ -0,0 +1,8 @@
+#ifdef __cplusplus
+extern "C" {
+#endif
+int load_model(char *model);
+int tts(char *text,int threads, char *dst );
+#ifdef __cplusplus
+}
+#endif
20  backend/go/bark/main.go (new file)

@@ -0,0 +1,20 @@
+package main
+
+// Note: this is started internally by LocalAI and a server is allocated for each model
+import (
+	"flag"
+
+	grpc "github.com/mudler/LocalAI/pkg/grpc"
+)
+
+var (
+	addr = flag.String("addr", "localhost:50051", "the address to connect to")
+)
+
+func main() {
+	flag.Parse()
+
+	if err := grpc.StartServer(*addr, &Bark{}); err != nil {
+		panic(err)
+	}
+}
21  backend/go/image/stablediffusion-ggml/Makefile (new file)

@@ -0,0 +1,21 @@
+INCLUDE_PATH := $(abspath ./)
+LIBRARY_PATH := $(abspath ./)
+
+AR?=ar
+
+BUILD_TYPE?=
+# keep standard at C11 and C++11
+CXXFLAGS = -I. -I$(INCLUDE_PATH)/../../../../sources/stablediffusion-ggml.cpp/thirdparty -I$(INCLUDE_PATH)/../../../../sources/stablediffusion-ggml.cpp/ggml/include -I$(INCLUDE_PATH)/../../../../sources/stablediffusion-ggml.cpp -O3 -DNDEBUG -std=c++17 -fPIC
+
+# warnings
+CXXFLAGS += -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function
+
+gosd.o:
+	$(CXX) $(CXXFLAGS) gosd.cpp -o gosd.o -c
+
+libsd.a: gosd.o
+	cp $(INCLUDE_PATH)/../../../../sources/stablediffusion-ggml.cpp/build/libstable-diffusion.a ./libsd.a
+	$(AR) rcs libsd.a gosd.o
+
+clean:
+	rm -f gosd.o libsd.a
228  backend/go/image/stablediffusion-ggml/gosd.cpp (new file)

@@ -0,0 +1,228 @@
+#include <stdio.h>
+#include <string.h>
+#include <time.h>
+#include <iostream>
+#include <random>
+#include <string>
+#include <vector>
+#include "gosd.h"
+
+// #include "preprocessing.hpp"
+#include "flux.hpp"
+#include "stable-diffusion.h"
+
+#define STB_IMAGE_IMPLEMENTATION
+#define STB_IMAGE_STATIC
+#include "stb_image.h"
+
+#define STB_IMAGE_WRITE_IMPLEMENTATION
+#define STB_IMAGE_WRITE_STATIC
+#include "stb_image_write.h"
+
+#define STB_IMAGE_RESIZE_IMPLEMENTATION
+#define STB_IMAGE_RESIZE_STATIC
+#include "stb_image_resize.h"
+
+// Names of the sampler method, same order as enum sample_method in stable-diffusion.h
+const char* sample_method_str[] = {
+    "euler_a",
+    "euler",
+    "heun",
+    "dpm2",
+    "dpm++2s_a",
+    "dpm++2m",
+    "dpm++2mv2",
+    "ipndm",
+    "ipndm_v",
+    "lcm",
+};
+
+// Names of the sigma schedule overrides, same order as sample_schedule in stable-diffusion.h
+const char* schedule_str[] = {
+    "default",
+    "discrete",
+    "karras",
+    "exponential",
+    "ays",
+    "gits",
+};
+
+sd_ctx_t* sd_c;
+
+sample_method_t sample_method;
+
+int load_model(char *model, char* options[], int threads, int diff) {
+    fprintf (stderr, "Loading model!\n");
+
+    char *stableDiffusionModel = "";
+    if (diff == 1 ) {
+        stableDiffusionModel = model;
+        model = "";
+    }
+
+    // decode options. Options are in form optname:optvale, or if booleans only optname.
+    char *clip_l_path = "";
+    char *clip_g_path = "";
+    char *t5xxl_path = "";
+    char *vae_path = "";
+    char *scheduler = "";
+    char *sampler = "";
+
+    // If options is not NULL, parse options
+    for (int i = 0; options[i] != NULL; i++) {
+        char *optname = strtok(options[i], ":");
+        char *optval = strtok(NULL, ":");
+        if (optval == NULL) {
+            optval = "true";
+        }
+
+        if (!strcmp(optname, "clip_l_path")) {
+            clip_l_path = optval;
+        }
+        if (!strcmp(optname, "clip_g_path")) {
+            clip_g_path = optval;
+        }
+        if (!strcmp(optname, "t5xxl_path")) {
+            t5xxl_path = optval;
+        }
+        if (!strcmp(optname, "vae_path")) {
+            vae_path = optval;
+        }
+        if (!strcmp(optname, "scheduler")) {
+            scheduler = optval;
+        }
+        if (!strcmp(optname, "sampler")) {
+            sampler = optval;
+        }
+    }
+
+    int sample_method_found = -1;
+    for (int m = 0; m < N_SAMPLE_METHODS; m++) {
+        if (!strcmp(sampler, sample_method_str[m])) {
+            sample_method_found = m;
+        }
+    }
+    if (sample_method_found == -1) {
+        fprintf(stderr, "Invalid sample method, default to EULER_A!\n");
+        sample_method_found = EULER_A;
+    }
+    sample_method = (sample_method_t)sample_method_found;
+
+    int schedule_found = -1;
+    for (int d = 0; d < N_SCHEDULES; d++) {
+        if (!strcmp(scheduler, schedule_str[d])) {
+            schedule_found = d;
+            fprintf (stderr, "Found scheduler: %s\n", scheduler);
+
+        }
+    }
+
+    if (schedule_found == -1) {
+        fprintf (stderr, "Invalid scheduler! using DEFAULT\n");
+        schedule_found = DEFAULT;
+    }
+
+    schedule_t schedule = (schedule_t)schedule_found;
+
+    fprintf (stderr, "Creating context\n");
+    sd_ctx_t* sd_ctx = new_sd_ctx(model,
+                                  clip_l_path,
+                                  clip_g_path,
+                                  t5xxl_path,
+                                  stableDiffusionModel,
+                                  vae_path,
+                                  "",
+                                  "",
+                                  "",
+                                  "",
+                                  "",
+                                  false,
+                                  false,
+                                  false,
+                                  threads,
+                                  SD_TYPE_COUNT,
+                                  STD_DEFAULT_RNG,
+                                  schedule,
+                                  false,
+                                  false,
+                                  false,
+                                  false);
+
+    if (sd_ctx == NULL) {
+        fprintf (stderr, "failed loading model (generic error)\n");
+        return 1;
+    }
+    fprintf (stderr, "Created context: OK\n");
+
+    sd_c = sd_ctx;
+
+    return 0;
+}
+
+int gen_image(char *text, char *negativeText, int width, int height, int steps, int seed , char *dst, float cfg_scale) {
+
+    sd_image_t* results;
+
+    std::vector<int> skip_layers = {7, 8, 9};
+
+    fprintf (stderr, "Generating image\n");
+
+    results = txt2img(sd_c,
+                      text,
+                      negativeText,
+                      -1, //clip_skip
+                      cfg_scale, // sfg_scale
+                      3.5f,
+                      width,
+                      height,
+                      sample_method,
+                      steps,
+                      seed,
+                      1,
+                      NULL,
+                      0.9f,
+                      20.f,
+                      false,
+                      "",
+                      skip_layers.data(),
+                      skip_layers.size(),
+                      0,
+                      0.01,
+                      0.2);
+
+    if (results == NULL) {
+        fprintf (stderr, "NO results\n");
+        return 1;
+    }
+
+    if (results[0].data == NULL) {
+        fprintf (stderr, "Results with no data\n");
+        return 1;
+    }
+
+    fprintf (stderr, "Writing PNG\n");
+
+    fprintf (stderr, "DST: %s\n", dst);
+    fprintf (stderr, "Width: %d\n", results[0].width);
+    fprintf (stderr, "Height: %d\n", results[0].height);
+    fprintf (stderr, "Channel: %d\n", results[0].channel);
+    fprintf (stderr, "Data: %p\n", results[0].data);
+
+    stbi_write_png(dst, results[0].width, results[0].height, results[0].channel,
+                   results[0].data, 0, NULL);
+    fprintf (stderr, "Saved resulting image to '%s'\n", dst);
+
+    // TODO: free results. Why does it crash?
+
+    free(results[0].data);
+    results[0].data = NULL;
+    free(results);
+    fprintf (stderr, "gen_image is done", dst);
+
+    return 0;
+}
+
+int unload() {
+    free_sd_ctx(sd_c);
+}
96  backend/go/image/stablediffusion-ggml/gosd.go (new file)

@@ -0,0 +1,96 @@
+package main
+
+// #cgo CXXFLAGS: -I${SRCDIR}/../../../../sources/stablediffusion-ggml.cpp/thirdparty -I${SRCDIR}/../../../../sources/stablediffusion-ggml.cpp -I${SRCDIR}/../../../../sources/stablediffusion-ggml.cpp/ggml/include
+// #cgo LDFLAGS: -L${SRCDIR}/ -L${SRCDIR}/../../../../sources/stablediffusion-ggml.cpp/build/ggml/src/ggml-cpu -L${SRCDIR}/../../../../sources/stablediffusion-ggml.cpp/build/ggml/src -lsd -lstdc++ -lm -lggml -lggml-base -lggml-cpu -lgomp
+// #include <gosd.h>
+// #include <stdlib.h>
+import "C"
+
+import (
+	"fmt"
+	"os"
+	"path/filepath"
+	"strings"
+	"unsafe"
+
+	"github.com/mudler/LocalAI/pkg/grpc/base"
+	pb "github.com/mudler/LocalAI/pkg/grpc/proto"
+	"github.com/mudler/LocalAI/pkg/utils"
+)
+
+type SDGGML struct {
+	base.SingleThread
+	threads      int
+	sampleMethod string
+	cfgScale     float32
+}
+
+func (sd *SDGGML) Load(opts *pb.ModelOptions) error {
+
+	sd.threads = int(opts.Threads)
+
+	modelFile := C.CString(opts.ModelFile)
+	defer C.free(unsafe.Pointer(modelFile))
+
+	var options **C.char
+	// prepare the options array to pass to C
+
+	size := C.size_t(unsafe.Sizeof((*C.char)(nil)))
+	length := C.size_t(len(opts.Options))
+	options = (**C.char)(C.malloc(length * size))
+	view := (*[1 << 30]*C.char)(unsafe.Pointer(options))[0:len(opts.Options):len(opts.Options)]
+
+	var diffusionModel int
+
+	var oo []string
+	for _, op := range opts.Options {
+		if op == "diffusion_model" {
+			diffusionModel = 1
+			continue
+		}
+
+		// If it's an option path, we resolve absolute path from the model path
+		if strings.Contains(op, ":") && strings.Contains(op, "path") {
+			data := strings.Split(op, ":")
+			data[1] = filepath.Join(opts.ModelPath, data[1])
+			if err := utils.VerifyPath(data[1], opts.ModelPath); err == nil {
+				oo = append(oo, strings.Join(data, ":"))
+			}
+		} else {
+			oo = append(oo, op)
+		}
+	}
+
+	fmt.Fprintf(os.Stderr, "Options: %+v\n", oo)
+
+	for i, x := range oo {
+		view[i] = C.CString(x)
+	}
+
+	sd.cfgScale = opts.CFGScale
+
+	ret := C.load_model(modelFile, options, C.int(opts.Threads), C.int(diffusionModel))
+	if ret != 0 {
+		return fmt.Errorf("could not load model")
+	}
+
+	return nil
+}
+
+func (sd *SDGGML) GenerateImage(opts *pb.GenerateImageRequest) error {
+	t := C.CString(opts.PositivePrompt)
+	defer C.free(unsafe.Pointer(t))
+
+	dst := C.CString(opts.Dst)
+	defer C.free(unsafe.Pointer(dst))
+
+	negative := C.CString(opts.NegativePrompt)
+	defer C.free(unsafe.Pointer(negative))
+
+	ret := C.gen_image(t, negative, C.int(opts.Width), C.int(opts.Height), C.int(opts.Step), C.int(opts.Seed), dst, C.float(sd.cfgScale))
+	if ret != 0 {
+		return fmt.Errorf("inference failed")
+	}
+
+	return nil
+}
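Note: Load above expects each ModelOptions.Options entry either as a bare flag ("diffusion_model") or as an optname:optvalue pair, with path-like values resolved against the model directory before being handed to the C parser in gosd.cpp. A hedged illustration of that option format (file names are made up):

package main

import pb "github.com/mudler/LocalAI/pkg/grpc/proto"

// Illustrative only: option strings understood by the gosd.cpp parser.
// Path-like values are joined to opts.ModelPath by Load.
func exampleSDOptions() *pb.ModelOptions {
	return &pb.ModelOptions{
		ModelFile: "/models/flux1-dev-Q4_0.gguf",
		Threads:   8,
		CFGScale:  1.0,
		Options: []string{
			"diffusion_model", // bare flag: treat ModelFile as a diffusion model
			"clip_l_path:clip_l.safetensors",
			"t5xxl_path:t5xxl_fp16.safetensors",
			"vae_path:ae.safetensors",
			"sampler:euler",    // must match an entry of sample_method_str
			"scheduler:karras", // must match an entry of schedule_str
		},
	}
}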
8  backend/go/image/stablediffusion-ggml/gosd.h (new file)

@@ -0,0 +1,8 @@
+#ifdef __cplusplus
+extern "C" {
+#endif
+int load_model(char *model, char* options[], int threads, int diffusionModel);
+int gen_image(char *text, char *negativeText, int width, int height, int steps, int seed, char *dst, float cfg_scale);
+#ifdef __cplusplus
+}
+#endif
20  backend/go/image/stablediffusion-ggml/main.go (new file)

@@ -0,0 +1,20 @@
+package main
+
+// Note: this is started internally by LocalAI and a server is allocated for each model
+import (
+	"flag"
+
+	grpc "github.com/mudler/LocalAI/pkg/grpc"
+)
+
+var (
+	addr = flag.String("addr", "localhost:50051", "the address to connect to")
+)
+
+func main() {
+	flag.Parse()
+
+	if err := grpc.StartServer(*addr, &SDGGML{}); err != nil {
+		panic(err)
+	}
+}
@@ -2,4 +2,4 @@
 intel-extension-for-pytorch
 torch
 optimum[openvino]
-setuptools==75.1.0 # https://github.com/mudler/LocalAI/issues/2406
+setuptools

@@ -1,6 +1,6 @@
 accelerate
 auto-gptq==0.7.1
-grpcio==1.68.0
+grpcio==1.68.1
 protobuf
 certifi
 transformers

@@ -3,6 +3,6 @@ intel-extension-for-pytorch
 torch
 torchaudio
 optimum[openvino]
-setuptools==75.1.0 # https://github.com/mudler/LocalAI/issues/2406
+setuptools
 transformers
 accelerate

@@ -1,4 +1,4 @@
 bark==0.1.5
-grpcio==1.68.0
+grpcio==1.68.1
 protobuf
 certifi

@@ -17,6 +17,9 @@
 # LIMIT_TARGETS="cublas12"
 # source $(dirname $0)/../common/libbackend.sh
 #
+
+PYTHON_VERSION="3.10"
+
 function init() {
     # Name of the backend (directory name)
     BACKEND_NAME=${PWD##*/}

@@ -88,7 +91,7 @@ function getBuildProfile() {
 # always result in an activated virtual environment
 function ensureVenv() {
     if [ ! -d "${EDIR}/venv" ]; then
-        uv venv ${EDIR}/venv
+        uv venv --python ${PYTHON_VERSION} ${EDIR}/venv
         echo "virtualenv created"
     fi
 

@@ -1,3 +1,3 @@
-grpcio==1.68.0
+grpcio==1.68.1
 protobuf
 grpcio-tools

@@ -3,7 +3,7 @@ intel-extension-for-pytorch
 torch
 torchaudio
 optimum[openvino]
-setuptools==75.1.0 # https://github.com/mudler/LocalAI/issues/2406
+setuptools
 transformers
 accelerate
 coqui-tts

@@ -1,4 +1,4 @@
-grpcio==1.68.0
+grpcio==1.68.1
 protobuf
 certifi
 packaging==24.1

@@ -3,7 +3,7 @@ intel-extension-for-pytorch
 torch
 torchvision
 optimum[openvino]
-setuptools==75.1.0 # https://github.com/mudler/LocalAI/issues/2406
+setuptools
 diffusers
 opencv-python
 transformers

@@ -1,5 +1,5 @@
 setuptools
-grpcio==1.68.0
+grpcio==1.68.1
 pillow
 protobuf
 certifi

@@ -1,4 +1,4 @@
-grpcio==1.68.0
+grpcio==1.68.1
 protobuf
 certifi
 wheel

@@ -1,3 +1,3 @@
-grpcio==1.68.0
+grpcio==1.68.1
 protobuf
 certifi

@@ -2,7 +2,7 @@
 intel-extension-for-pytorch
 torch
 optimum[openvino]
-grpcio==1.68.0
+grpcio==1.68.1
 protobuf
 librosa==0.9.1
 faster-whisper==0.9.0

@@ -1,4 +1,4 @@
-grpcio==1.68.0
+grpcio==1.68.1
 protobuf
 librosa
 faster-whisper

@@ -18,3 +18,4 @@ jieba==0.42.1
 gradio==3.48.0
 langid==1.1.6
 llvmlite==0.43.0
+setuptools

@@ -3,6 +3,5 @@ intel-extension-for-pytorch
 torch
 torchaudio
 optimum[openvino]
-setuptools==75.1.0 # https://github.com/mudler/LocalAI/issues/2406
 transformers
 accelerate

@@ -1,3 +1,4 @@
-grpcio==1.68.0
+grpcio==1.68.1
 certifi
 llvmlite==0.43.0
+setuptools

@@ -5,4 +5,4 @@ accelerate
 torch
 rerankers[transformers]
 optimum[openvino]
-setuptools==75.1.0 # https://github.com/mudler/LocalAI/issues/2406
+setuptools

@@ -1,3 +1,3 @@
-grpcio==1.68.0
+grpcio==1.68.1
 protobuf
 certifi

@@ -2,7 +2,7 @@
 intel-extension-for-pytorch
 torch
 optimum[openvino]
-setuptools==69.5.1 # https://github.com/mudler/LocalAI/issues/2406
+setuptools
 accelerate
 sentence-transformers==3.3.1
 transformers

@@ -1,4 +1,4 @@
-grpcio==1.68.0
+grpcio==1.68.1
 protobuf
 certifi
 datasets

@@ -4,4 +4,4 @@ transformers
 accelerate
 torch
 optimum[openvino]
-setuptools==75.1.0 # https://github.com/mudler/LocalAI/issues/2406
+setuptools

@@ -1,4 +1,4 @@
-grpcio==1.68.0
+grpcio==1.68.1
 protobuf
 scipy==1.14.0
 certifi

@@ -1,4 +1,4 @@
-grpcio==1.68.0
+grpcio==1.68.1
 protobuf
 certifi
-setuptools==69.5.1 # https://github.com/mudler/LocalAI/issues/2406
+setuptools

@@ -4,4 +4,3 @@ accelerate
 torch
 torchaudio
 optimum[openvino]
-setuptools==75.1.0 # https://github.com/mudler/LocalAI/issues/2406

@@ -1,3 +1,4 @@
-grpcio==1.68.0
+grpcio==1.68.1
 protobuf
 certifi
+setuptools

@@ -22,7 +22,7 @@ if [ "x${BUILD_TYPE}" == "x" ] && [ "x${FROM_SOURCE}" == "xtrue" ]; then
         git clone https://github.com/vllm-project/vllm
     fi
     pushd vllm
-    uv pip install wheel packaging ninja "setuptools>=49.4.0" numpy typing-extensions pillow setuptools-scm grpcio==1.68.0 protobuf bitsandbytes
+    uv pip install wheel packaging ninja "setuptools>=49.4.0" numpy typing-extensions pillow setuptools-scm grpcio==1.68.1 protobuf bitsandbytes
     uv pip install -v -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu
     VLLM_TARGET_DEVICE=cpu python setup.py install
     popd

@@ -4,5 +4,5 @@ accelerate
 torch
 transformers
 optimum[openvino]
-setuptools==75.1.0 # https://github.com/mudler/LocalAI/issues/2406
+setuptools
 bitsandbytes

@@ -1,4 +1,4 @@
-grpcio==1.68.0
+grpcio==1.68.1
 protobuf
 certifi
 setuptools
@@ -1,38 +0,0 @@
-package core
-
-import (
-	"github.com/mudler/LocalAI/core/config"
-	"github.com/mudler/LocalAI/core/services"
-	"github.com/mudler/LocalAI/pkg/model"
-)
-
-// The purpose of this structure is to hold pointers to all initialized services, to make plumbing easy
-// Perhaps a proper DI system is worth it in the future, but for now keep things simple.
-type Application struct {
-
-	// Application-Level Config
-	ApplicationConfig *config.ApplicationConfig
-	// ApplicationState *ApplicationState
-
-	// Core Low-Level Services
-	BackendConfigLoader *config.BackendConfigLoader
-	ModelLoader         *model.ModelLoader
-
-	// Backend Services
-	// EmbeddingsBackendService *backend.EmbeddingsBackendService
-	// ImageGenerationBackendService *backend.ImageGenerationBackendService
-	// LLMBackendService *backend.LLMBackendService
-	// TranscriptionBackendService *backend.TranscriptionBackendService
-	// TextToSpeechBackendService *backend.TextToSpeechBackendService
-
-	// LocalAI System Services
-	BackendMonitorService *services.BackendMonitorService
-	GalleryService        *services.GalleryService
-	LocalAIMetricsService *services.LocalAIMetricsService
-	// OpenAIService *services.OpenAIService
-}
-
-// TODO [NEXT PR?]: Break up ApplicationConfig.
-// Migrate over stuff that is not set via config at all - especially runtime stuff
-type ApplicationState struct {
-}
39  core/application/application.go (new file)

@@ -0,0 +1,39 @@
+package application
+
+import (
+	"github.com/mudler/LocalAI/core/config"
+	"github.com/mudler/LocalAI/pkg/model"
+	"github.com/mudler/LocalAI/pkg/templates"
+)
+
+type Application struct {
+	backendLoader      *config.BackendConfigLoader
+	modelLoader        *model.ModelLoader
+	applicationConfig  *config.ApplicationConfig
+	templatesEvaluator *templates.Evaluator
+}
+
+func newApplication(appConfig *config.ApplicationConfig) *Application {
+	return &Application{
+		backendLoader:      config.NewBackendConfigLoader(appConfig.ModelPath),
+		modelLoader:        model.NewModelLoader(appConfig.ModelPath),
+		applicationConfig:  appConfig,
+		templatesEvaluator: templates.NewEvaluator(appConfig.ModelPath),
+	}
+}
+
+func (a *Application) BackendLoader() *config.BackendConfigLoader {
+	return a.backendLoader
+}
+
+func (a *Application) ModelLoader() *model.ModelLoader {
+	return a.modelLoader
+}
+
+func (a *Application) ApplicationConfig() *config.ApplicationConfig {
+	return a.applicationConfig
+}
+
+func (a *Application) TemplatesEvaluator() *templates.Evaluator {
+	return a.templatesEvaluator
+}
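Note: call sites that previously unpacked the three-value tuple returned by startup.Startup now receive a single *Application and read the services through its accessors; a rough sketch of the new call shape (error handling trimmed, option values omitted):

package main

import (
	"github.com/mudler/LocalAI/core/application"
	"github.com/rs/zerolog/log"
)

// Sketch of the refactored startup path: one *Application instead of the
// (BackendConfigLoader, ModelLoader, ApplicationConfig) tuple.
func start() {
	app, err := application.New( /* config.AppOption values */ )
	if err != nil {
		log.Fatal().Err(err).Msg("failed to start LocalAI")
	}
	_ = app.BackendLoader()     // *config.BackendConfigLoader
	_ = app.ModelLoader()       // *model.ModelLoader
	_ = app.ApplicationConfig() // *config.ApplicationConfig
}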
@@ -1,4 +1,4 @@
-package startup
+package application
 
 import (
 	"encoding/json"
@@ -8,8 +8,8 @@ import (
 	"path/filepath"
 	"time"
 
-	"github.com/fsnotify/fsnotify"
 	"dario.cat/mergo"
+	"github.com/fsnotify/fsnotify"
 	"github.com/mudler/LocalAI/core/config"
 	"github.com/rs/zerolog/log"
 )
@@ -1,15 +1,15 @@
-package startup
+package application
 
 import (
 	"fmt"
 	"os"
 
-	"github.com/mudler/LocalAI/core"
 	"github.com/mudler/LocalAI/core/backend"
 	"github.com/mudler/LocalAI/core/config"
 	"github.com/mudler/LocalAI/core/services"
 	"github.com/mudler/LocalAI/internal"
 	"github.com/mudler/LocalAI/pkg/assets"
+
 	"github.com/mudler/LocalAI/pkg/library"
 	"github.com/mudler/LocalAI/pkg/model"
 	pkgStartup "github.com/mudler/LocalAI/pkg/startup"
@@ -17,8 +17,9 @@ import (
 	"github.com/rs/zerolog/log"
 )
 
-func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.ModelLoader, *config.ApplicationConfig, error) {
+func New(opts ...config.AppOption) (*Application, error) {
 	options := config.NewApplicationConfig(opts...)
+	application := newApplication(options)
 
 	log.Info().Msgf("Starting LocalAI using %d threads, with models path: %s", options.Threads, options.ModelPath)
 	log.Info().Msgf("LocalAI version: %s", internal.PrintableVersion())
@@ -36,28 +37,28 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode
 
 	// Make sure directories exists
 	if options.ModelPath == "" {
-		return nil, nil, nil, fmt.Errorf("options.ModelPath cannot be empty")
+		return nil, fmt.Errorf("options.ModelPath cannot be empty")
 	}
 	err = os.MkdirAll(options.ModelPath, 0750)
 	if err != nil {
-		return nil, nil, nil, fmt.Errorf("unable to create ModelPath: %q", err)
+		return nil, fmt.Errorf("unable to create ModelPath: %q", err)
 	}
 	if options.ImageDir != "" {
 		err := os.MkdirAll(options.ImageDir, 0750)
 		if err != nil {
-			return nil, nil, nil, fmt.Errorf("unable to create ImageDir: %q", err)
+			return nil, fmt.Errorf("unable to create ImageDir: %q", err)
 		}
 	}
 	if options.AudioDir != "" {
 		err := os.MkdirAll(options.AudioDir, 0750)
 		if err != nil {
-			return nil, nil, nil, fmt.Errorf("unable to create AudioDir: %q", err)
+			return nil, fmt.Errorf("unable to create AudioDir: %q", err)
 		}
 	}
 	if options.UploadDir != "" {
 		err := os.MkdirAll(options.UploadDir, 0750)
 		if err != nil {
-			return nil, nil, nil, fmt.Errorf("unable to create UploadDir: %q", err)
+			return nil, fmt.Errorf("unable to create UploadDir: %q", err)
 		}
 	}
 
@@ -65,39 +66,36 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode
 		log.Error().Err(err).Msg("error installing models")
 	}
 
-	cl := config.NewBackendConfigLoader(options.ModelPath)
-	ml := model.NewModelLoader(options.ModelPath)
-
 	configLoaderOpts := options.ToConfigLoaderOptions()
 
-	if err := cl.LoadBackendConfigsFromPath(options.ModelPath, configLoaderOpts...); err != nil {
+	if err := application.BackendLoader().LoadBackendConfigsFromPath(options.ModelPath, configLoaderOpts...); err != nil {
 		log.Error().Err(err).Msg("error loading config files")
 	}
 
 	if options.ConfigFile != "" {
-		if err := cl.LoadMultipleBackendConfigsSingleFile(options.ConfigFile, configLoaderOpts...); err != nil {
+		if err := application.BackendLoader().LoadMultipleBackendConfigsSingleFile(options.ConfigFile, configLoaderOpts...); err != nil {
 			log.Error().Err(err).Msg("error loading config file")
 		}
 	}
 
-	if err := cl.Preload(options.ModelPath); err != nil {
+	if err := application.BackendLoader().Preload(options.ModelPath); err != nil {
 		log.Error().Err(err).Msg("error downloading models")
 	}
 
 	if options.PreloadJSONModels != "" {
 		if err := services.ApplyGalleryFromString(options.ModelPath, options.PreloadJSONModels, options.EnforcePredownloadScans, options.Galleries); err != nil {
-			return nil, nil, nil, err
+			return nil, err
		}
 	}
 
 	if options.PreloadModelsFromPath != "" {
 		if err := services.ApplyGalleryFromFile(options.ModelPath, options.PreloadModelsFromPath, options.EnforcePredownloadScans, options.Galleries); err != nil {
-			return nil, nil, nil, err
+			return nil, err
 		}
 	}
 
 	if options.Debug {
-		for _, v := range cl.GetAllBackendConfigs() {
+		for _, v := range application.BackendLoader().GetAllBackendConfigs() {
 			log.Debug().Msgf("Model: %s (config: %+v)", v.Name, v)
 		}
 	}
@@ -123,7 +121,7 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode
 	go func() {
 		<-options.Context.Done()
 		log.Debug().Msgf("Context canceled, shutting down")
-		err := ml.StopAllGRPC()
+		err := application.ModelLoader().StopAllGRPC()
 		if err != nil {
 			log.Error().Err(err).Msg("error while stopping all grpc backends")
 		}
@@ -131,12 +129,12 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode
 
 	if options.WatchDog {
 		wd := model.NewWatchDog(
-			ml,
+			application.ModelLoader(),
 			options.WatchDogBusyTimeout,
|
||||||
options.WatchDogIdleTimeout,
|
options.WatchDogIdleTimeout,
|
||||||
options.WatchDogBusy,
|
options.WatchDogBusy,
|
||||||
options.WatchDogIdle)
|
options.WatchDogIdle)
|
||||||
ml.SetWatchDog(wd)
|
application.ModelLoader().SetWatchDog(wd)
|
||||||
go wd.Run()
|
go wd.Run()
|
||||||
go func() {
|
go func() {
|
||||||
<-options.Context.Done()
|
<-options.Context.Done()
|
||||||
|
@ -147,7 +145,7 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode
|
||||||
|
|
||||||
if options.LoadToMemory != nil {
|
if options.LoadToMemory != nil {
|
||||||
for _, m := range options.LoadToMemory {
|
for _, m := range options.LoadToMemory {
|
||||||
cfg, err := cl.LoadBackendConfigFileByName(m, options.ModelPath,
|
cfg, err := application.BackendLoader().LoadBackendConfigFileByName(m, options.ModelPath,
|
||||||
config.LoadOptionDebug(options.Debug),
|
config.LoadOptionDebug(options.Debug),
|
||||||
config.LoadOptionThreads(options.Threads),
|
config.LoadOptionThreads(options.Threads),
|
||||||
config.LoadOptionContextSize(options.ContextSize),
|
config.LoadOptionContextSize(options.ContextSize),
|
||||||
|
@ -155,7 +153,7 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode
|
||||||
config.ModelPath(options.ModelPath),
|
config.ModelPath(options.ModelPath),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debug().Msgf("Auto loading model %s into memory from file: %s", m, cfg.Model)
|
log.Debug().Msgf("Auto loading model %s into memory from file: %s", m, cfg.Model)
|
||||||
|
@ -163,9 +161,9 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode
|
||||||
o := backend.ModelOptions(*cfg, options)
|
o := backend.ModelOptions(*cfg, options)
|
||||||
|
|
||||||
var backendErr error
|
var backendErr error
|
||||||
_, backendErr = ml.Load(o...)
|
_, backendErr = application.ModelLoader().Load(o...)
|
||||||
if backendErr != nil {
|
if backendErr != nil {
|
||||||
return nil, nil, nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -174,7 +172,7 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode
|
||||||
startWatcher(options)
|
startWatcher(options)
|
||||||
|
|
||||||
log.Info().Msg("core/startup process completed!")
|
log.Info().Msg("core/startup process completed!")
|
||||||
return cl, ml, options, nil
|
return application, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func startWatcher(options *config.ApplicationConfig) {
|
func startWatcher(options *config.ApplicationConfig) {
|
||||||
|
@ -201,32 +199,3 @@ func startWatcher(options *config.ApplicationConfig) {
|
||||||
log.Error().Err(err).Msg("failed creating watcher")
|
log.Error().Err(err).Msg("failed creating watcher")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// In lieu of a proper DI framework, this function wires up the Application manually.
|
|
||||||
// This is in core/startup rather than core/state.go to keep package references clean!
|
|
||||||
func createApplication(appConfig *config.ApplicationConfig) *core.Application {
|
|
||||||
app := &core.Application{
|
|
||||||
ApplicationConfig: appConfig,
|
|
||||||
BackendConfigLoader: config.NewBackendConfigLoader(appConfig.ModelPath),
|
|
||||||
ModelLoader: model.NewModelLoader(appConfig.ModelPath),
|
|
||||||
}
|
|
||||||
|
|
||||||
var err error
|
|
||||||
|
|
||||||
// app.EmbeddingsBackendService = backend.NewEmbeddingsBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig)
|
|
||||||
// app.ImageGenerationBackendService = backend.NewImageGenerationBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig)
|
|
||||||
// app.LLMBackendService = backend.NewLLMBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig)
|
|
||||||
// app.TranscriptionBackendService = backend.NewTranscriptionBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig)
|
|
||||||
// app.TextToSpeechBackendService = backend.NewTextToSpeechBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig)
|
|
||||||
|
|
||||||
app.BackendMonitorService = services.NewBackendMonitorService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig)
|
|
||||||
app.GalleryService = services.NewGalleryService(app.ApplicationConfig)
|
|
||||||
// app.OpenAIService = services.NewOpenAIService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig, app.LLMBackendService)
|
|
||||||
|
|
||||||
app.LocalAIMetricsService, err = services.NewLocalAIMetricsService()
|
|
||||||
if err != nil {
|
|
||||||
log.Error().Err(err).Msg("encountered an error initializing metrics service, startup will continue but metrics will not be tracked.")
|
|
||||||
}
|
|
||||||
|
|
||||||
return app
|
|
||||||
}
|
|
|
@ -122,7 +122,7 @@ func grpcModelOpts(c config.BackendConfig) *pb.ModelOptions {
|
||||||
CUDA: c.CUDA || c.Diffusers.CUDA,
|
CUDA: c.CUDA || c.Diffusers.CUDA,
|
||||||
SchedulerType: c.Diffusers.SchedulerType,
|
SchedulerType: c.Diffusers.SchedulerType,
|
||||||
PipelineType: c.Diffusers.PipelineType,
|
PipelineType: c.Diffusers.PipelineType,
|
||||||
CFGScale: c.Diffusers.CFGScale,
|
CFGScale: c.CFGScale,
|
||||||
LoraAdapter: c.LoraAdapter,
|
LoraAdapter: c.LoraAdapter,
|
||||||
LoraScale: c.LoraScale,
|
LoraScale: c.LoraScale,
|
||||||
LoraAdapters: c.LoraAdapters,
|
LoraAdapters: c.LoraAdapters,
|
||||||
|
@ -132,6 +132,7 @@ func grpcModelOpts(c config.BackendConfig) *pb.ModelOptions {
|
||||||
IMG2IMG: c.Diffusers.IMG2IMG,
|
IMG2IMG: c.Diffusers.IMG2IMG,
|
||||||
CLIPModel: c.Diffusers.ClipModel,
|
CLIPModel: c.Diffusers.ClipModel,
|
||||||
CLIPSubfolder: c.Diffusers.ClipSubFolder,
|
CLIPSubfolder: c.Diffusers.ClipSubFolder,
|
||||||
|
Options: c.Options,
|
||||||
CLIPSkip: int32(c.Diffusers.ClipSkip),
|
CLIPSkip: int32(c.Diffusers.ClipSkip),
|
||||||
ControlNet: c.Diffusers.ControlNet,
|
ControlNet: c.Diffusers.ControlNet,
|
||||||
ContextSize: int32(ctxSize),
|
ContextSize: int32(ctxSize),
|
||||||
|
@ -150,6 +151,8 @@ func grpcModelOpts(c config.BackendConfig) *pb.ModelOptions {
|
||||||
TensorParallelSize: int32(c.TensorParallelSize),
|
TensorParallelSize: int32(c.TensorParallelSize),
|
||||||
MMProj: c.MMProj,
|
MMProj: c.MMProj,
|
||||||
FlashAttention: c.FlashAttention,
|
FlashAttention: c.FlashAttention,
|
||||||
|
CacheTypeKey: c.CacheTypeK,
|
||||||
|
CacheTypeValue: c.CacheTypeV,
|
||||||
NoKVOffload: c.NoKVOffloading,
|
NoKVOffload: c.NoKVOffloading,
|
||||||
YarnExtFactor: c.YarnExtFactor,
|
YarnExtFactor: c.YarnExtFactor,
|
||||||
YarnAttnFactor: c.YarnAttnFactor,
|
YarnAttnFactor: c.YarnAttnFactor,
|
||||||
|
|
|
@ -6,12 +6,12 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/mudler/LocalAI/core/application"
|
||||||
cli_api "github.com/mudler/LocalAI/core/cli/api"
|
cli_api "github.com/mudler/LocalAI/core/cli/api"
|
||||||
cliContext "github.com/mudler/LocalAI/core/cli/context"
|
cliContext "github.com/mudler/LocalAI/core/cli/context"
|
||||||
"github.com/mudler/LocalAI/core/config"
|
"github.com/mudler/LocalAI/core/config"
|
||||||
"github.com/mudler/LocalAI/core/http"
|
"github.com/mudler/LocalAI/core/http"
|
||||||
"github.com/mudler/LocalAI/core/p2p"
|
"github.com/mudler/LocalAI/core/p2p"
|
||||||
"github.com/mudler/LocalAI/core/startup"
|
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
)
|
)
|
||||||
|
@ -186,16 +186,16 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.PreloadBackendOnly {
|
if r.PreloadBackendOnly {
|
||||||
_, _, _, err := startup.Startup(opts...)
|
_, err := application.New(opts...)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
cl, ml, options, err := startup.Startup(opts...)
|
app, err := application.New(opts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed basic startup tasks with error %s", err.Error())
|
return fmt.Errorf("failed basic startup tasks with error %s", err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
appHTTP, err := http.App(cl, ml, options)
|
appHTTP, err := http.API(app)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Msg("error during HTTP App construction")
|
log.Error().Err(err).Msg("error during HTTP App construction")
|
||||||
return err
|
return err
|
||||||
|
|
|
@ -72,6 +72,8 @@ type BackendConfig struct {
|
||||||
|
|
||||||
Description string `yaml:"description"`
|
Description string `yaml:"description"`
|
||||||
Usage string `yaml:"usage"`
|
Usage string `yaml:"usage"`
|
||||||
|
|
||||||
|
Options []string `yaml:"options"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type File struct {
|
type File struct {
|
||||||
|
@ -97,16 +99,15 @@ type GRPC struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type Diffusers struct {
|
type Diffusers struct {
|
||||||
CUDA bool `yaml:"cuda"`
|
CUDA bool `yaml:"cuda"`
|
||||||
PipelineType string `yaml:"pipeline_type"`
|
PipelineType string `yaml:"pipeline_type"`
|
||||||
SchedulerType string `yaml:"scheduler_type"`
|
SchedulerType string `yaml:"scheduler_type"`
|
||||||
EnableParameters string `yaml:"enable_parameters"` // A list of comma separated parameters to specify
|
EnableParameters string `yaml:"enable_parameters"` // A list of comma separated parameters to specify
|
||||||
CFGScale float32 `yaml:"cfg_scale"` // Classifier-Free Guidance Scale
|
IMG2IMG bool `yaml:"img2img"` // Image to Image Diffuser
|
||||||
IMG2IMG bool `yaml:"img2img"` // Image to Image Diffuser
|
ClipSkip int `yaml:"clip_skip"` // Skip every N frames
|
||||||
ClipSkip int `yaml:"clip_skip"` // Skip every N frames
|
ClipModel string `yaml:"clip_model"` // Clip model to use
|
||||||
ClipModel string `yaml:"clip_model"` // Clip model to use
|
ClipSubFolder string `yaml:"clip_subfolder"` // Subfolder to use for clip model
|
||||||
ClipSubFolder string `yaml:"clip_subfolder"` // Subfolder to use for clip model
|
ControlNet string `yaml:"control_net"`
|
||||||
ControlNet string `yaml:"control_net"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// LLMConfig is a struct that holds the configuration that are
|
// LLMConfig is a struct that holds the configuration that are
|
||||||
|
@ -154,8 +155,10 @@ type LLMConfig struct {
|
||||||
TensorParallelSize int `yaml:"tensor_parallel_size"` // vLLM
|
TensorParallelSize int `yaml:"tensor_parallel_size"` // vLLM
|
||||||
MMProj string `yaml:"mmproj"`
|
MMProj string `yaml:"mmproj"`
|
||||||
|
|
||||||
FlashAttention bool `yaml:"flash_attention"`
|
FlashAttention bool `yaml:"flash_attention"`
|
||||||
NoKVOffloading bool `yaml:"no_kv_offloading"`
|
NoKVOffloading bool `yaml:"no_kv_offloading"`
|
||||||
|
CacheTypeK string `yaml:"cache_type_k"`
|
||||||
|
CacheTypeV string `yaml:"cache_type_v"`
|
||||||
|
|
||||||
RopeScaling string `yaml:"rope_scaling"`
|
RopeScaling string `yaml:"rope_scaling"`
|
||||||
ModelType string `yaml:"type"`
|
ModelType string `yaml:"type"`
|
||||||
|
@ -164,6 +167,8 @@ type LLMConfig struct {
|
||||||
YarnAttnFactor float32 `yaml:"yarn_attn_factor"`
|
YarnAttnFactor float32 `yaml:"yarn_attn_factor"`
|
||||||
YarnBetaFast float32 `yaml:"yarn_beta_fast"`
|
YarnBetaFast float32 `yaml:"yarn_beta_fast"`
|
||||||
YarnBetaSlow float32 `yaml:"yarn_beta_slow"`
|
YarnBetaSlow float32 `yaml:"yarn_beta_slow"`
|
||||||
|
|
||||||
|
CFGScale float32 `yaml:"cfg_scale"` // Classifier-Free Guidance Scale
|
||||||
}
|
}
|
||||||
|
|
||||||
// AutoGPTQ is a struct that holds the configuration specific to the AutoGPTQ backend
|
// AutoGPTQ is a struct that holds the configuration specific to the AutoGPTQ backend
|
||||||
|
@ -201,6 +206,8 @@ type TemplateConfig struct {
|
||||||
JoinChatMessagesByCharacter *string `yaml:"join_chat_messages_by_character"`
|
JoinChatMessagesByCharacter *string `yaml:"join_chat_messages_by_character"`
|
||||||
|
|
||||||
Multimodal string `yaml:"multimodal"`
|
Multimodal string `yaml:"multimodal"`
|
||||||
|
|
||||||
|
JinjaTemplate bool `yaml:"jinja_template"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *BackendConfig) UnmarshalYAML(value *yaml.Node) error {
|
func (c *BackendConfig) UnmarshalYAML(value *yaml.Node) error {
|
||||||
|
|
|
@ -26,14 +26,14 @@ const (
|
||||||
type settingsConfig struct {
|
type settingsConfig struct {
|
||||||
StopWords []string
|
StopWords []string
|
||||||
TemplateConfig TemplateConfig
|
TemplateConfig TemplateConfig
|
||||||
RepeatPenalty float64
|
RepeatPenalty float64
|
||||||
}
|
}
|
||||||
|
|
||||||
// default settings to adopt with a given model family
|
// default settings to adopt with a given model family
|
||||||
var defaultsSettings map[familyType]settingsConfig = map[familyType]settingsConfig{
|
var defaultsSettings map[familyType]settingsConfig = map[familyType]settingsConfig{
|
||||||
Gemma: {
|
Gemma: {
|
||||||
RepeatPenalty: 1.0,
|
RepeatPenalty: 1.0,
|
||||||
StopWords: []string{"<|im_end|>", "<end_of_turn>", "<start_of_turn>"},
|
StopWords: []string{"<|im_end|>", "<end_of_turn>", "<start_of_turn>"},
|
||||||
TemplateConfig: TemplateConfig{
|
TemplateConfig: TemplateConfig{
|
||||||
Chat: "{{.Input }}\n<start_of_turn>model\n",
|
Chat: "{{.Input }}\n<start_of_turn>model\n",
|
||||||
ChatMessage: "<start_of_turn>{{if eq .RoleName \"assistant\" }}model{{else}}{{ .RoleName }}{{end}}\n{{ if .Content -}}\n{{.Content -}}\n{{ end -}}<end_of_turn>",
|
ChatMessage: "<start_of_turn>{{if eq .RoleName \"assistant\" }}model{{else}}{{ .RoleName }}{{end}}\n{{ if .Content -}}\n{{.Content -}}\n{{ end -}}<end_of_turn>",
|
||||||
|
@ -200,6 +200,18 @@ func guessDefaultsFromFile(cfg *BackendConfig, modelPath string) {
|
||||||
} else {
|
} else {
|
||||||
log.Debug().Any("family", family).Msgf("guessDefaultsFromFile: no template found for family")
|
log.Debug().Any("family", family).Msgf("guessDefaultsFromFile: no template found for family")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if cfg.HasTemplate() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
	// identify from well-known templates first, otherwise use the raw jinja template
|
||||||
|
chatTemplate, found := f.Header.MetadataKV.Get("tokenizer.chat_template")
|
||||||
|
if found {
|
||||||
|
// try to use the jinja template
|
||||||
|
cfg.TemplateConfig.JinjaTemplate = true
|
||||||
|
cfg.TemplateConfig.ChatMessage = chatTemplate.ValueString()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func identifyFamily(f *gguf.GGUFFile) familyType {
|
func identifyFamily(f *gguf.GGUFFile) familyType {
|
||||||
|
|
|
@ -14,10 +14,9 @@ import (
|
||||||
"github.com/mudler/LocalAI/core/http/middleware"
|
"github.com/mudler/LocalAI/core/http/middleware"
|
||||||
"github.com/mudler/LocalAI/core/http/routes"
|
"github.com/mudler/LocalAI/core/http/routes"
|
||||||
|
|
||||||
"github.com/mudler/LocalAI/core/config"
|
"github.com/mudler/LocalAI/core/application"
|
||||||
"github.com/mudler/LocalAI/core/schema"
|
"github.com/mudler/LocalAI/core/schema"
|
||||||
"github.com/mudler/LocalAI/core/services"
|
"github.com/mudler/LocalAI/core/services"
|
||||||
"github.com/mudler/LocalAI/pkg/model"
|
|
||||||
|
|
||||||
"github.com/gofiber/contrib/fiberzerolog"
|
"github.com/gofiber/contrib/fiberzerolog"
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
|
@ -49,18 +48,18 @@ var embedDirStatic embed.FS
|
||||||
// @in header
|
// @in header
|
||||||
// @name Authorization
|
// @name Authorization
|
||||||
|
|
||||||
func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) (*fiber.App, error) {
|
func API(application *application.Application) (*fiber.App, error) {
|
||||||
|
|
||||||
fiberCfg := fiber.Config{
|
fiberCfg := fiber.Config{
|
||||||
Views: renderEngine(),
|
Views: renderEngine(),
|
||||||
BodyLimit: appConfig.UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB
|
BodyLimit: application.ApplicationConfig().UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB
|
||||||
// We disable the Fiber startup message as it does not conform to structured logging.
|
// We disable the Fiber startup message as it does not conform to structured logging.
|
||||||
		// We register a startup log line with connection information in the OnListen hook to keep things user-friendly, though
|
		// We register a startup log line with connection information in the OnListen hook to keep things user-friendly, though
|
||||||
DisableStartupMessage: true,
|
DisableStartupMessage: true,
|
||||||
// Override default error handler
|
// Override default error handler
|
||||||
}
|
}
|
||||||
|
|
||||||
if !appConfig.OpaqueErrors {
|
if !application.ApplicationConfig().OpaqueErrors {
|
||||||
// Normally, return errors as JSON responses
|
// Normally, return errors as JSON responses
|
||||||
fiberCfg.ErrorHandler = func(ctx *fiber.Ctx, err error) error {
|
fiberCfg.ErrorHandler = func(ctx *fiber.Ctx, err error) error {
|
||||||
// Status code defaults to 500
|
// Status code defaults to 500
|
||||||
|
@ -86,9 +85,9 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
app := fiber.New(fiberCfg)
|
router := fiber.New(fiberCfg)
|
||||||
|
|
||||||
app.Hooks().OnListen(func(listenData fiber.ListenData) error {
|
router.Hooks().OnListen(func(listenData fiber.ListenData) error {
|
||||||
scheme := "http"
|
scheme := "http"
|
||||||
if listenData.TLS {
|
if listenData.TLS {
|
||||||
scheme = "https"
|
scheme = "https"
|
||||||
|
@ -99,82 +98,82 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi
|
||||||
|
|
||||||
	// Have Fiber use zerolog like the rest of the application rather than its built-in logger
|
	// Have Fiber use zerolog like the rest of the application rather than its built-in logger
|
||||||
logger := log.Logger
|
logger := log.Logger
|
||||||
app.Use(fiberzerolog.New(fiberzerolog.Config{
|
router.Use(fiberzerolog.New(fiberzerolog.Config{
|
||||||
Logger: &logger,
|
Logger: &logger,
|
||||||
}))
|
}))
|
||||||
|
|
||||||
// Default middleware config
|
// Default middleware config
|
||||||
|
|
||||||
if !appConfig.Debug {
|
if !application.ApplicationConfig().Debug {
|
||||||
app.Use(recover.New())
|
router.Use(recover.New())
|
||||||
}
|
}
|
||||||
|
|
||||||
if !appConfig.DisableMetrics {
|
if !application.ApplicationConfig().DisableMetrics {
|
||||||
metricsService, err := services.NewLocalAIMetricsService()
|
metricsService, err := services.NewLocalAIMetricsService()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if metricsService != nil {
|
if metricsService != nil {
|
||||||
app.Use(localai.LocalAIMetricsAPIMiddleware(metricsService))
|
router.Use(localai.LocalAIMetricsAPIMiddleware(metricsService))
|
||||||
app.Hooks().OnShutdown(func() error {
|
router.Hooks().OnShutdown(func() error {
|
||||||
return metricsService.Shutdown()
|
return metricsService.Shutdown()
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
// Health Checks should always be exempt from auth, so register these first
|
// Health Checks should always be exempt from auth, so register these first
|
||||||
routes.HealthRoutes(app)
|
routes.HealthRoutes(router)
|
||||||
|
|
||||||
kaConfig, err := middleware.GetKeyAuthConfig(appConfig)
|
kaConfig, err := middleware.GetKeyAuthConfig(application.ApplicationConfig())
|
||||||
if err != nil || kaConfig == nil {
|
if err != nil || kaConfig == nil {
|
||||||
return nil, fmt.Errorf("failed to create key auth config: %w", err)
|
return nil, fmt.Errorf("failed to create key auth config: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Auth is applied to _all_ endpoints. No exceptions. Filtering out endpoints to bypass is the role of the Filter property of the KeyAuth Configuration
|
// Auth is applied to _all_ endpoints. No exceptions. Filtering out endpoints to bypass is the role of the Filter property of the KeyAuth Configuration
|
||||||
app.Use(v2keyauth.New(*kaConfig))
|
router.Use(v2keyauth.New(*kaConfig))
|
||||||
|
|
||||||
if appConfig.CORS {
|
if application.ApplicationConfig().CORS {
|
||||||
var c func(ctx *fiber.Ctx) error
|
var c func(ctx *fiber.Ctx) error
|
||||||
if appConfig.CORSAllowOrigins == "" {
|
if application.ApplicationConfig().CORSAllowOrigins == "" {
|
||||||
c = cors.New()
|
c = cors.New()
|
||||||
} else {
|
} else {
|
||||||
c = cors.New(cors.Config{AllowOrigins: appConfig.CORSAllowOrigins})
|
c = cors.New(cors.Config{AllowOrigins: application.ApplicationConfig().CORSAllowOrigins})
|
||||||
}
|
}
|
||||||
|
|
||||||
app.Use(c)
|
router.Use(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
if appConfig.CSRF {
|
if application.ApplicationConfig().CSRF {
|
||||||
log.Debug().Msg("Enabling CSRF middleware. Tokens are now required for state-modifying requests")
|
log.Debug().Msg("Enabling CSRF middleware. Tokens are now required for state-modifying requests")
|
||||||
app.Use(csrf.New())
|
router.Use(csrf.New())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load config jsons
|
// Load config jsons
|
||||||
utils.LoadConfig(appConfig.UploadDir, openai.UploadedFilesFile, &openai.UploadedFiles)
|
utils.LoadConfig(application.ApplicationConfig().UploadDir, openai.UploadedFilesFile, &openai.UploadedFiles)
|
||||||
utils.LoadConfig(appConfig.ConfigsDir, openai.AssistantsConfigFile, &openai.Assistants)
|
utils.LoadConfig(application.ApplicationConfig().ConfigsDir, openai.AssistantsConfigFile, &openai.Assistants)
|
||||||
utils.LoadConfig(appConfig.ConfigsDir, openai.AssistantsFileConfigFile, &openai.AssistantFiles)
|
utils.LoadConfig(application.ApplicationConfig().ConfigsDir, openai.AssistantsFileConfigFile, &openai.AssistantFiles)
|
||||||
|
|
||||||
galleryService := services.NewGalleryService(appConfig)
|
galleryService := services.NewGalleryService(application.ApplicationConfig())
|
||||||
galleryService.Start(appConfig.Context, cl)
|
galleryService.Start(application.ApplicationConfig().Context, application.BackendLoader())
|
||||||
|
|
||||||
routes.RegisterElevenLabsRoutes(app, cl, ml, appConfig)
|
routes.RegisterElevenLabsRoutes(router, application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())
|
||||||
routes.RegisterLocalAIRoutes(app, cl, ml, appConfig, galleryService)
|
routes.RegisterLocalAIRoutes(router, application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig(), galleryService)
|
||||||
routes.RegisterOpenAIRoutes(app, cl, ml, appConfig)
|
routes.RegisterOpenAIRoutes(router, application)
|
||||||
if !appConfig.DisableWebUI {
|
if !application.ApplicationConfig().DisableWebUI {
|
||||||
routes.RegisterUIRoutes(app, cl, ml, appConfig, galleryService)
|
routes.RegisterUIRoutes(router, application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig(), galleryService)
|
||||||
}
|
}
|
||||||
routes.RegisterJINARoutes(app, cl, ml, appConfig)
|
routes.RegisterJINARoutes(router, application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())
|
||||||
|
|
||||||
httpFS := http.FS(embedDirStatic)
|
httpFS := http.FS(embedDirStatic)
|
||||||
|
|
||||||
app.Use(favicon.New(favicon.Config{
|
router.Use(favicon.New(favicon.Config{
|
||||||
URL: "/favicon.ico",
|
URL: "/favicon.ico",
|
||||||
FileSystem: httpFS,
|
FileSystem: httpFS,
|
||||||
File: "static/favicon.ico",
|
File: "static/favicon.ico",
|
||||||
}))
|
}))
|
||||||
|
|
||||||
app.Use("/static", filesystem.New(filesystem.Config{
|
router.Use("/static", filesystem.New(filesystem.Config{
|
||||||
Root: httpFS,
|
Root: httpFS,
|
||||||
PathPrefix: "static",
|
PathPrefix: "static",
|
||||||
Browse: true,
|
Browse: true,
|
||||||
|
@ -182,7 +181,7 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi
|
||||||
|
|
||||||
// Define a custom 404 handler
|
// Define a custom 404 handler
|
||||||
// Note: keep this at the bottom!
|
// Note: keep this at the bottom!
|
||||||
app.Use(notFoundHandler)
|
router.Use(notFoundHandler)
|
||||||
|
|
||||||
return app, nil
|
return router, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,24 +5,21 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"embed"
|
"embed"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
|
||||||
|
|
||||||
|
"github.com/mudler/LocalAI/core/application"
|
||||||
"github.com/mudler/LocalAI/core/config"
|
"github.com/mudler/LocalAI/core/config"
|
||||||
. "github.com/mudler/LocalAI/core/http"
|
. "github.com/mudler/LocalAI/core/http"
|
||||||
"github.com/mudler/LocalAI/core/schema"
|
"github.com/mudler/LocalAI/core/schema"
|
||||||
"github.com/mudler/LocalAI/core/startup"
|
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
"github.com/mudler/LocalAI/core/gallery"
|
"github.com/mudler/LocalAI/core/gallery"
|
||||||
"github.com/mudler/LocalAI/pkg/downloader"
|
"github.com/mudler/LocalAI/pkg/downloader"
|
||||||
"github.com/mudler/LocalAI/pkg/model"
|
|
||||||
. "github.com/onsi/ginkgo/v2"
|
. "github.com/onsi/ginkgo/v2"
|
||||||
. "github.com/onsi/gomega"
|
. "github.com/onsi/gomega"
|
||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v3"
|
||||||
|
@ -254,9 +251,6 @@ var _ = Describe("API test", func() {
|
||||||
var cancel context.CancelFunc
|
var cancel context.CancelFunc
|
||||||
var tmpdir string
|
var tmpdir string
|
||||||
var modelDir string
|
var modelDir string
|
||||||
var bcl *config.BackendConfigLoader
|
|
||||||
var ml *model.ModelLoader
|
|
||||||
var applicationConfig *config.ApplicationConfig
|
|
||||||
|
|
||||||
commonOpts := []config.AppOption{
|
commonOpts := []config.AppOption{
|
||||||
config.WithDebug(true),
|
config.WithDebug(true),
|
||||||
|
@ -302,7 +296,7 @@ var _ = Describe("API test", func() {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
bcl, ml, applicationConfig, err = startup.Startup(
|
application, err := application.New(
|
||||||
append(commonOpts,
|
append(commonOpts,
|
||||||
config.WithContext(c),
|
config.WithContext(c),
|
||||||
config.WithGalleries(galleries),
|
config.WithGalleries(galleries),
|
||||||
|
@ -312,7 +306,7 @@ var _ = Describe("API test", func() {
|
||||||
config.WithBackendAssetsOutput(backendAssetsDir))...)
|
config.WithBackendAssetsOutput(backendAssetsDir))...)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
app, err = App(bcl, ml, applicationConfig)
|
app, err = API(application)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
go app.Listen("127.0.0.1:9090")
|
go app.Listen("127.0.0.1:9090")
|
||||||
|
@ -541,7 +535,7 @@ var _ = Describe("API test", func() {
|
||||||
var res map[string]string
|
var res map[string]string
|
||||||
err = json.Unmarshal([]byte(resp2.Choices[0].Message.FunctionCall.Arguments), &res)
|
err = json.Unmarshal([]byte(resp2.Choices[0].Message.FunctionCall.Arguments), &res)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(res["location"]).To(Equal("San Francisco"), fmt.Sprint(res))
|
Expect(res["location"]).To(ContainSubstring("San Francisco"), fmt.Sprint(res))
|
||||||
Expect(res["unit"]).To(Equal("celcius"), fmt.Sprint(res))
|
Expect(res["unit"]).To(Equal("celcius"), fmt.Sprint(res))
|
||||||
Expect(string(resp2.Choices[0].FinishReason)).To(Equal("function_call"), fmt.Sprint(resp2.Choices[0].FinishReason))
|
Expect(string(resp2.Choices[0].FinishReason)).To(Equal("function_call"), fmt.Sprint(resp2.Choices[0].FinishReason))
|
||||||
|
|
||||||
|
@ -643,7 +637,7 @@ var _ = Describe("API test", func() {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
bcl, ml, applicationConfig, err = startup.Startup(
|
application, err := application.New(
|
||||||
append(commonOpts,
|
append(commonOpts,
|
||||||
config.WithContext(c),
|
config.WithContext(c),
|
||||||
config.WithAudioDir(tmpdir),
|
config.WithAudioDir(tmpdir),
|
||||||
|
@ -654,7 +648,7 @@ var _ = Describe("API test", func() {
|
||||||
config.WithBackendAssetsOutput(tmpdir))...,
|
config.WithBackendAssetsOutput(tmpdir))...,
|
||||||
)
|
)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
app, err = App(bcl, ml, applicationConfig)
|
app, err = API(application)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
go app.Listen("127.0.0.1:9090")
|
go app.Listen("127.0.0.1:9090")
|
||||||
|
@ -710,7 +704,7 @@ var _ = Describe("API test", func() {
|
||||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprint(resp))
|
Expect(err).ToNot(HaveOccurred(), fmt.Sprint(resp))
|
||||||
|
|
||||||
Expect(resp.StatusCode).To(Equal(200), fmt.Sprint(string(dat)))
|
Expect(resp.StatusCode).To(Equal(200), fmt.Sprint(string(dat)))
|
||||||
Expect(resp.Header.Get("Content-Type")).To(Equal("audio/x-wav"))
|
Expect(resp.Header.Get("Content-Type")).To(Or(Equal("audio/x-wav"), Equal("audio/vnd.wave")))
|
||||||
})
|
})
|
||||||
It("installs and is capable to generate images", Label("stablediffusion"), func() {
|
It("installs and is capable to generate images", Label("stablediffusion"), func() {
|
||||||
if runtime.GOOS != "linux" {
|
if runtime.GOOS != "linux" {
|
||||||
|
@ -774,14 +768,14 @@ var _ = Describe("API test", func() {
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
bcl, ml, applicationConfig, err = startup.Startup(
|
application, err := application.New(
|
||||||
append(commonOpts,
|
append(commonOpts,
|
||||||
config.WithExternalBackend("huggingface", os.Getenv("HUGGINGFACE_GRPC")),
|
config.WithExternalBackend("huggingface", os.Getenv("HUGGINGFACE_GRPC")),
|
||||||
config.WithContext(c),
|
config.WithContext(c),
|
||||||
config.WithModelPath(modelPath),
|
config.WithModelPath(modelPath),
|
||||||
)...)
|
)...)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
app, err = App(bcl, ml, applicationConfig)
|
app, err = API(application)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
go app.Listen("127.0.0.1:9090")
|
go app.Listen("127.0.0.1:9090")
|
||||||
|
|
||||||
|
@ -913,71 +907,6 @@ var _ = Describe("API test", func() {
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("backends", func() {
|
|
||||||
It("runs rwkv completion", func() {
|
|
||||||
if runtime.GOOS != "linux" {
|
|
||||||
Skip("test supported only on linux")
|
|
||||||
}
|
|
||||||
resp, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "rwkv_test", Prompt: "Count up to five: one, two, three, four,"})
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(len(resp.Choices) > 0).To(BeTrue())
|
|
||||||
Expect(resp.Choices[0].Text).To(ContainSubstring("five"))
|
|
||||||
|
|
||||||
stream, err := client.CreateCompletionStream(context.TODO(), openai.CompletionRequest{
|
|
||||||
Model: "rwkv_test", Prompt: "Count up to five: one, two, three, four,", Stream: true,
|
|
||||||
})
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
defer stream.Close()
|
|
||||||
|
|
||||||
tokens := 0
|
|
||||||
text := ""
|
|
||||||
for {
|
|
||||||
response, err := stream.Recv()
|
|
||||||
if errors.Is(err, io.EOF) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
text += response.Choices[0].Text
|
|
||||||
tokens++
|
|
||||||
}
|
|
||||||
Expect(text).ToNot(BeEmpty())
|
|
||||||
Expect(text).To(ContainSubstring("five"))
|
|
||||||
Expect(tokens).ToNot(Or(Equal(1), Equal(0)))
|
|
||||||
})
|
|
||||||
It("runs rwkv chat completion", func() {
|
|
||||||
if runtime.GOOS != "linux" {
|
|
||||||
Skip("test supported only on linux")
|
|
||||||
}
|
|
||||||
resp, err := client.CreateChatCompletion(context.TODO(),
|
|
||||||
openai.ChatCompletionRequest{Model: "rwkv_test", Messages: []openai.ChatCompletionMessage{{Content: "Can you count up to five?", Role: "user"}}})
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(len(resp.Choices) > 0).To(BeTrue())
|
|
||||||
Expect(strings.ToLower(resp.Choices[0].Message.Content)).To(Or(ContainSubstring("sure"), ContainSubstring("five"), ContainSubstring("5")))
|
|
||||||
|
|
||||||
stream, err := client.CreateChatCompletionStream(context.TODO(), openai.ChatCompletionRequest{Model: "rwkv_test", Messages: []openai.ChatCompletionMessage{{Content: "Can you count up to five?", Role: "user"}}})
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
defer stream.Close()
|
|
||||||
|
|
||||||
tokens := 0
|
|
||||||
text := ""
|
|
||||||
for {
|
|
||||||
response, err := stream.Recv()
|
|
||||||
if errors.Is(err, io.EOF) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
text += response.Choices[0].Delta.Content
|
|
||||||
tokens++
|
|
||||||
}
|
|
||||||
Expect(text).ToNot(BeEmpty())
|
|
||||||
Expect(strings.ToLower(text)).To(Or(ContainSubstring("sure"), ContainSubstring("five")))
|
|
||||||
|
|
||||||
Expect(tokens).ToNot(Or(Equal(1), Equal(0)))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
// See tests/integration/stores_test
|
// See tests/integration/stores_test
|
||||||
Context("Stores", Label("stores"), func() {
|
Context("Stores", Label("stores"), func() {
|
||||||
|
|
||||||
|
@ -1057,14 +986,14 @@ var _ = Describe("API test", func() {
|
||||||
c, cancel = context.WithCancel(context.Background())
|
c, cancel = context.WithCancel(context.Background())
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
bcl, ml, applicationConfig, err = startup.Startup(
|
application, err := application.New(
|
||||||
append(commonOpts,
|
append(commonOpts,
|
||||||
config.WithContext(c),
|
config.WithContext(c),
|
||||||
config.WithModelPath(modelPath),
|
config.WithModelPath(modelPath),
|
||||||
config.WithConfigFile(os.Getenv("CONFIG_FILE")))...,
|
config.WithConfigFile(os.Getenv("CONFIG_FILE")))...,
|
||||||
)
|
)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
app, err = App(bcl, ml, applicationConfig)
|
app, err = API(application)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
go app.Listen("127.0.0.1:9090")
|
go app.Listen("127.0.0.1:9090")
|
||||||
|
|
|
@ -14,6 +14,8 @@ import (
|
||||||
"github.com/mudler/LocalAI/core/config"
|
"github.com/mudler/LocalAI/core/config"
|
||||||
"github.com/mudler/LocalAI/core/schema"
|
"github.com/mudler/LocalAI/core/schema"
|
||||||
"github.com/mudler/LocalAI/pkg/functions"
|
"github.com/mudler/LocalAI/pkg/functions"
|
||||||
|
"github.com/mudler/LocalAI/pkg/templates"
|
||||||
|
|
||||||
model "github.com/mudler/LocalAI/pkg/model"
|
model "github.com/mudler/LocalAI/pkg/model"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/valyala/fasthttp"
|
"github.com/valyala/fasthttp"
|
||||||
|
@ -24,7 +26,7 @@ import (
|
||||||
// @Param request body schema.OpenAIRequest true "query params"
|
// @Param request body schema.OpenAIRequest true "query params"
|
||||||
// @Success 200 {object} schema.OpenAIResponse "Response"
|
// @Success 200 {object} schema.OpenAIResponse "Response"
|
||||||
// @Router /v1/chat/completions [post]
|
// @Router /v1/chat/completions [post]
|
||||||
func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startupOptions *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, startupOptions *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||||
var id, textContentToReturn string
|
var id, textContentToReturn string
|
||||||
var created int
|
var created int
|
||||||
|
|
||||||
|
@ -294,148 +296,10 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
|
||||||
// If we are using the tokenizer template, we don't need to process the messages
|
// If we are using the tokenizer template, we don't need to process the messages
|
||||||
// unless we are processing functions
|
// unless we are processing functions
|
||||||
if !config.TemplateConfig.UseTokenizerTemplate || shouldUseFn {
|
if !config.TemplateConfig.UseTokenizerTemplate || shouldUseFn {
|
||||||
suppressConfigSystemPrompt := false
|
predInput = evaluator.TemplateMessages(input.Messages, config, funcs, shouldUseFn)
|
||||||
mess := []string{}
|
|
||||||
for messageIndex, i := range input.Messages {
|
|
||||||
var content string
|
|
||||||
role := i.Role
|
|
||||||
|
|
||||||
			// if function call, we might want to customize the role so we can better convey that the "assistant called a json action"
|
|
||||||
			// if an "assistant_function_call" role is defined, we use it, otherwise we use the role that is passed in the request
|
|
||||||
if (i.FunctionCall != nil || i.ToolCalls != nil) && i.Role == "assistant" {
|
|
||||||
roleFn := "assistant_function_call"
|
|
||||||
r := config.Roles[roleFn]
|
|
||||||
if r != "" {
|
|
||||||
role = roleFn
|
|
||||||
}
|
|
||||||
}
|
|
||||||
r := config.Roles[role]
|
|
||||||
contentExists := i.Content != nil && i.StringContent != ""
|
|
||||||
|
|
||||||
fcall := i.FunctionCall
|
|
||||||
if len(i.ToolCalls) > 0 {
|
|
||||||
fcall = i.ToolCalls
|
|
||||||
}
|
|
||||||
|
|
||||||
// First attempt to populate content via a chat message specific template
|
|
||||||
if config.TemplateConfig.ChatMessage != "" {
|
|
||||||
chatMessageData := model.ChatMessageTemplateData{
|
|
||||||
SystemPrompt: config.SystemPrompt,
|
|
||||||
Role: r,
|
|
||||||
RoleName: role,
|
|
||||||
Content: i.StringContent,
|
|
||||||
FunctionCall: fcall,
|
|
||||||
FunctionName: i.Name,
|
|
||||||
LastMessage: messageIndex == (len(input.Messages) - 1),
|
|
||||||
Function: config.Grammar != "" && (messageIndex == (len(input.Messages) - 1)),
|
|
||||||
MessageIndex: messageIndex,
|
|
||||||
}
|
|
||||||
templatedChatMessage, err := ml.EvaluateTemplateForChatMessage(config.TemplateConfig.ChatMessage, chatMessageData)
|
|
||||||
if err != nil {
|
|
||||||
log.Error().Err(err).Interface("message", chatMessageData).Str("template", config.TemplateConfig.ChatMessage).Msg("error processing message with template, skipping")
|
|
||||||
} else {
|
|
||||||
if templatedChatMessage == "" {
|
|
||||||
log.Warn().Msgf("template \"%s\" produced blank output for %+v. Skipping!", config.TemplateConfig.ChatMessage, chatMessageData)
|
|
||||||
continue // TODO: This continue is here intentionally to skip over the line `mess = append(mess, content)` below, and to prevent the sprintf
|
|
||||||
}
|
|
||||||
log.Debug().Msgf("templated message for chat: %s", templatedChatMessage)
|
|
||||||
content = templatedChatMessage
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
marshalAnyRole := func(f any) {
|
|
||||||
j, err := json.Marshal(f)
|
|
||||||
if err == nil {
|
|
||||||
if contentExists {
|
|
||||||
content += "\n" + fmt.Sprint(r, " ", string(j))
|
|
||||||
} else {
|
|
||||||
content = fmt.Sprint(r, " ", string(j))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
marshalAny := func(f any) {
|
|
||||||
j, err := json.Marshal(f)
|
|
||||||
if err == nil {
|
|
||||||
if contentExists {
|
|
||||||
content += "\n" + string(j)
|
|
||||||
} else {
|
|
||||||
content = string(j)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// If this model doesn't have such a template, or if that template fails to return a value, template at the message level.
|
|
||||||
if content == "" {
|
|
||||||
if r != "" {
|
|
||||||
if contentExists {
|
|
||||||
content = fmt.Sprint(r, i.StringContent)
|
|
||||||
}
|
|
||||||
|
|
||||||
if i.FunctionCall != nil {
|
|
||||||
marshalAnyRole(i.FunctionCall)
|
|
||||||
}
|
|
||||||
if i.ToolCalls != nil {
|
|
||||||
marshalAnyRole(i.ToolCalls)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if contentExists {
|
|
||||||
content = fmt.Sprint(i.StringContent)
|
|
||||||
}
|
|
||||||
if i.FunctionCall != nil {
|
|
||||||
marshalAny(i.FunctionCall)
|
|
||||||
}
|
|
||||||
if i.ToolCalls != nil {
|
|
||||||
marshalAny(i.ToolCalls)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
				// Special Handling: System. We care if it was printed at all, not the r branch, so check separately
|
|
||||||
if contentExists && role == "system" {
|
|
||||||
suppressConfigSystemPrompt = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
mess = append(mess, content)
|
|
||||||
}
|
|
||||||
|
|
||||||
joinCharacter := "\n"
|
|
||||||
if config.TemplateConfig.JoinChatMessagesByCharacter != nil {
|
|
||||||
joinCharacter = *config.TemplateConfig.JoinChatMessagesByCharacter
|
|
||||||
}
|
|
||||||
|
|
||||||
predInput = strings.Join(mess, joinCharacter)
|
|
||||||
log.Debug().Msgf("Prompt (before templating): %s", predInput)
|
|
||||||
|
|
||||||
templateFile := ""
|
|
||||||
|
|
||||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
|
||||||
if ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
|
|
||||||
templateFile = config.Model
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.TemplateConfig.Chat != "" && !shouldUseFn {
|
|
||||||
templateFile = config.TemplateConfig.Chat
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.TemplateConfig.Functions != "" && shouldUseFn {
|
|
||||||
templateFile = config.TemplateConfig.Functions
|
|
||||||
}
|
|
||||||
|
|
||||||
if templateFile != "" {
|
|
||||||
templatedInput, err := ml.EvaluateTemplateForPrompt(model.ChatPromptTemplate, templateFile, model.PromptTemplateData{
|
|
||||||
SystemPrompt: config.SystemPrompt,
|
|
||||||
SuppressSystemPrompt: suppressConfigSystemPrompt,
|
|
||||||
Input: predInput,
|
|
||||||
Functions: funcs,
|
|
||||||
})
|
|
||||||
if err == nil {
|
|
||||||
predInput = templatedInput
|
|
||||||
log.Debug().Msgf("Template found, input modified to: %s", predInput)
|
|
||||||
} else {
|
|
||||||
log.Debug().Msgf("Template failed loading: %s", err.Error())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debug().Msgf("Prompt (after templating): %s", predInput)
|
log.Debug().Msgf("Prompt (after templating): %s", predInput)
|
||||||
if shouldUseFn && config.Grammar != "" {
|
if config.Grammar != "" {
|
||||||
log.Debug().Msgf("Grammar: %+v", config.Grammar)
|
log.Debug().Msgf("Grammar: %+v", config.Grammar)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,6 +16,7 @@ import (
|
||||||
"github.com/mudler/LocalAI/core/schema"
|
"github.com/mudler/LocalAI/core/schema"
|
||||||
"github.com/mudler/LocalAI/pkg/functions"
|
"github.com/mudler/LocalAI/pkg/functions"
|
||||||
model "github.com/mudler/LocalAI/pkg/model"
|
model "github.com/mudler/LocalAI/pkg/model"
|
||||||
|
"github.com/mudler/LocalAI/pkg/templates"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/valyala/fasthttp"
|
"github.com/valyala/fasthttp"
|
||||||
)
|
)
|
||||||
|
@ -25,7 +26,7 @@ import (
|
||||||
// @Param request body schema.OpenAIRequest true "query params"
|
// @Param request body schema.OpenAIRequest true "query params"
|
||||||
// @Success 200 {object} schema.OpenAIResponse "Response"
|
// @Success 200 {object} schema.OpenAIResponse "Response"
|
||||||
// @Router /v1/completions [post]
|
// @Router /v1/completions [post]
|
||||||
func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||||
id := uuid.New().String()
|
id := uuid.New().String()
|
||||||
created := int(time.Now().Unix())
|
created := int(time.Now().Unix())
|
||||||
|
|
||||||
|
@ -94,17 +95,6 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, a
|
||||||
c.Set("Transfer-Encoding", "chunked")
|
c.Set("Transfer-Encoding", "chunked")
|
||||||
}
|
}
|
||||||
|
|
||||||
templateFile := ""
|
|
||||||
|
|
||||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
|
||||||
if ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
|
|
||||||
templateFile = config.Model
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.TemplateConfig.Completion != "" {
|
|
||||||
templateFile = config.TemplateConfig.Completion
|
|
||||||
}
|
|
||||||
|
|
||||||
if input.Stream {
|
if input.Stream {
|
||||||
if len(config.PromptStrings) > 1 {
|
if len(config.PromptStrings) > 1 {
|
||||||
return errors.New("cannot handle more than 1 `PromptStrings` when Streaming")
|
return errors.New("cannot handle more than 1 `PromptStrings` when Streaming")
|
||||||
|
@ -112,15 +102,13 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, a
|
||||||
|
|
||||||
predInput := config.PromptStrings[0]
|
predInput := config.PromptStrings[0]
|
||||||
|
|
||||||
if templateFile != "" {
|
templatedInput, err := evaluator.EvaluateTemplateForPrompt(templates.CompletionPromptTemplate, *config, templates.PromptTemplateData{
|
||||||
templatedInput, err := ml.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{
|
Input: predInput,
|
||||||
Input: predInput,
|
SystemPrompt: config.SystemPrompt,
|
||||||
SystemPrompt: config.SystemPrompt,
|
})
|
||||||
})
|
if err == nil {
|
||||||
if err == nil {
|
predInput = templatedInput
|
||||||
predInput = templatedInput
|
log.Debug().Msgf("Template found, input modified to: %s", predInput)
|
||||||
log.Debug().Msgf("Template found, input modified to: %s", predInput)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
responses := make(chan schema.OpenAIResponse)
|
responses := make(chan schema.OpenAIResponse)
|
||||||
|
@ -165,16 +153,13 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, a
|
||||||
totalTokenUsage := backend.TokenUsage{}
|
totalTokenUsage := backend.TokenUsage{}
|
||||||
|
|
||||||
for k, i := range config.PromptStrings {
|
for k, i := range config.PromptStrings {
|
||||||
if templateFile != "" {
|
templatedInput, err := evaluator.EvaluateTemplateForPrompt(templates.CompletionPromptTemplate, *config, templates.PromptTemplateData{
|
||||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
SystemPrompt: config.SystemPrompt,
|
||||||
templatedInput, err := ml.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{
|
Input: i,
|
||||||
SystemPrompt: config.SystemPrompt,
|
})
|
||||||
Input: i,
|
if err == nil {
|
||||||
})
|
i = templatedInput
|
||||||
if err == nil {
|
log.Debug().Msgf("Template found, input modified to: %s", i)
|
||||||
i = templatedInput
|
|
||||||
log.Debug().Msgf("Template found, input modified to: %s", i)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
r, tokenUsage, err := ComputeChoices(
|
r, tokenUsage, err := ComputeChoices(
|
||||||
|
|
|
@ -12,6 +12,7 @@ import (
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/mudler/LocalAI/core/schema"
|
"github.com/mudler/LocalAI/core/schema"
|
||||||
model "github.com/mudler/LocalAI/pkg/model"
|
model "github.com/mudler/LocalAI/pkg/model"
|
||||||
|
"github.com/mudler/LocalAI/pkg/templates"
|
||||||
|
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
)
|
)
|
||||||
|
@ -21,7 +22,8 @@ import (
|
||||||
// @Param request body schema.OpenAIRequest true "query params"
|
// @Param request body schema.OpenAIRequest true "query params"
|
||||||
// @Success 200 {object} schema.OpenAIResponse "Response"
|
// @Success 200 {object} schema.OpenAIResponse "Response"
|
||||||
// @Router /v1/edits [post]
|
// @Router /v1/edits [post]
|
||||||
func EditEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
func EditEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||||
|
|
||||||
return func(c *fiber.Ctx) error {
|
return func(c *fiber.Ctx) error {
|
||||||
modelFile, input, err := readRequest(c, cl, ml, appConfig, true)
|
modelFile, input, err := readRequest(c, cl, ml, appConfig, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -35,31 +37,18 @@ func EditEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConf
|
||||||
|
|
||||||
log.Debug().Msgf("Parameter Config: %+v", config)
|
log.Debug().Msgf("Parameter Config: %+v", config)
|
||||||
|
|
||||||
templateFile := ""
|
|
||||||
|
|
||||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
|
||||||
if ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
|
|
||||||
templateFile = config.Model
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.TemplateConfig.Edit != "" {
|
|
||||||
templateFile = config.TemplateConfig.Edit
|
|
||||||
}
|
|
||||||
|
|
||||||
var result []schema.Choice
|
var result []schema.Choice
|
||||||
totalTokenUsage := backend.TokenUsage{}
|
totalTokenUsage := backend.TokenUsage{}
|
||||||
|
|
||||||
for _, i := range config.InputStrings {
|
for _, i := range config.InputStrings {
|
||||||
if templateFile != "" {
|
templatedInput, err := evaluator.EvaluateTemplateForPrompt(templates.EditPromptTemplate, *config, templates.PromptTemplateData{
|
||||||
templatedInput, err := ml.EvaluateTemplateForPrompt(model.EditPromptTemplate, templateFile, model.PromptTemplateData{
|
Input: i,
|
||||||
Input: i,
|
Instruction: input.Instruction,
|
||||||
Instruction: input.Instruction,
|
SystemPrompt: config.SystemPrompt,
|
||||||
SystemPrompt: config.SystemPrompt,
|
})
|
||||||
})
|
if err == nil {
|
||||||
if err == nil {
|
i = templatedInput
|
||||||
i = templatedInput
|
log.Debug().Msgf("Template found, input modified to: %s", i)
|
||||||
log.Debug().Msgf("Template found, input modified to: %s", i)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
r, tokenUsage, err := ComputeChoices(input, i, config, appConfig, ml, func(s string, c *[]schema.Choice) {
|
r, tokenUsage, err := ComputeChoices(input, i, config, appConfig, ml, func(s string, c *[]schema.Choice) {
|
||||||
|
|
|
@ -11,62 +11,62 @@ import (
|
||||||
"github.com/mudler/LocalAI/pkg/model"
|
"github.com/mudler/LocalAI/pkg/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
func RegisterLocalAIRoutes(app *fiber.App,
|
func RegisterLocalAIRoutes(router *fiber.App,
|
||||||
cl *config.BackendConfigLoader,
|
cl *config.BackendConfigLoader,
|
||||||
ml *model.ModelLoader,
|
ml *model.ModelLoader,
|
||||||
appConfig *config.ApplicationConfig,
|
appConfig *config.ApplicationConfig,
|
||||||
galleryService *services.GalleryService) {
|
galleryService *services.GalleryService) {
|
||||||
|
|
||||||
app.Get("/swagger/*", swagger.HandlerDefault) // default
|
router.Get("/swagger/*", swagger.HandlerDefault) // default
|
||||||
|
|
||||||
// LocalAI API endpoints
|
// LocalAI API endpoints
|
||||||
if !appConfig.DisableGalleryEndpoint {
|
if !appConfig.DisableGalleryEndpoint {
|
||||||
modelGalleryEndpointService := localai.CreateModelGalleryEndpointService(appConfig.Galleries, appConfig.ModelPath, galleryService)
|
modelGalleryEndpointService := localai.CreateModelGalleryEndpointService(appConfig.Galleries, appConfig.ModelPath, galleryService)
|
||||||
app.Post("/models/apply", modelGalleryEndpointService.ApplyModelGalleryEndpoint())
|
router.Post("/models/apply", modelGalleryEndpointService.ApplyModelGalleryEndpoint())
|
||||||
app.Post("/models/delete/:name", modelGalleryEndpointService.DeleteModelGalleryEndpoint())
|
router.Post("/models/delete/:name", modelGalleryEndpointService.DeleteModelGalleryEndpoint())
|
||||||
|
|
||||||
app.Get("/models/available", modelGalleryEndpointService.ListModelFromGalleryEndpoint())
|
router.Get("/models/available", modelGalleryEndpointService.ListModelFromGalleryEndpoint())
|
||||||
app.Get("/models/galleries", modelGalleryEndpointService.ListModelGalleriesEndpoint())
|
router.Get("/models/galleries", modelGalleryEndpointService.ListModelGalleriesEndpoint())
|
||||||
app.Post("/models/galleries", modelGalleryEndpointService.AddModelGalleryEndpoint())
|
router.Post("/models/galleries", modelGalleryEndpointService.AddModelGalleryEndpoint())
|
||||||
app.Delete("/models/galleries", modelGalleryEndpointService.RemoveModelGalleryEndpoint())
|
router.Delete("/models/galleries", modelGalleryEndpointService.RemoveModelGalleryEndpoint())
|
||||||
app.Get("/models/jobs/:uuid", modelGalleryEndpointService.GetOpStatusEndpoint())
|
router.Get("/models/jobs/:uuid", modelGalleryEndpointService.GetOpStatusEndpoint())
|
||||||
app.Get("/models/jobs", modelGalleryEndpointService.GetAllStatusEndpoint())
|
router.Get("/models/jobs", modelGalleryEndpointService.GetAllStatusEndpoint())
|
||||||
}
|
}
|
||||||
|
|
||||||
app.Post("/tts", localai.TTSEndpoint(cl, ml, appConfig))
|
router.Post("/tts", localai.TTSEndpoint(cl, ml, appConfig))
|
||||||
app.Post("/vad", localai.VADEndpoint(cl, ml, appConfig))
|
router.Post("/vad", localai.VADEndpoint(cl, ml, appConfig))
|
||||||
|
|
||||||
// Stores
|
// Stores
|
||||||
sl := model.NewModelLoader("")
|
sl := model.NewModelLoader("")
|
||||||
app.Post("/stores/set", localai.StoresSetEndpoint(sl, appConfig))
|
router.Post("/stores/set", localai.StoresSetEndpoint(sl, appConfig))
|
||||||
app.Post("/stores/delete", localai.StoresDeleteEndpoint(sl, appConfig))
|
router.Post("/stores/delete", localai.StoresDeleteEndpoint(sl, appConfig))
|
||||||
app.Post("/stores/get", localai.StoresGetEndpoint(sl, appConfig))
|
router.Post("/stores/get", localai.StoresGetEndpoint(sl, appConfig))
|
||||||
app.Post("/stores/find", localai.StoresFindEndpoint(sl, appConfig))
|
router.Post("/stores/find", localai.StoresFindEndpoint(sl, appConfig))
|
||||||
|
|
||||||
if !appConfig.DisableMetrics {
|
if !appConfig.DisableMetrics {
|
||||||
app.Get("/metrics", localai.LocalAIMetricsEndpoint())
|
router.Get("/metrics", localai.LocalAIMetricsEndpoint())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Experimental Backend Statistics Module
|
// Experimental Backend Statistics Module
|
||||||
backendMonitorService := services.NewBackendMonitorService(ml, cl, appConfig) // Split out for now
|
backendMonitorService := services.NewBackendMonitorService(ml, cl, appConfig) // Split out for now
|
||||||
app.Get("/backend/monitor", localai.BackendMonitorEndpoint(backendMonitorService))
|
router.Get("/backend/monitor", localai.BackendMonitorEndpoint(backendMonitorService))
|
||||||
app.Post("/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitorService))
|
router.Post("/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitorService))
|
||||||
|
|
||||||
// p2p
|
// p2p
|
||||||
if p2p.IsP2PEnabled() {
|
if p2p.IsP2PEnabled() {
|
||||||
app.Get("/api/p2p", localai.ShowP2PNodes(appConfig))
|
router.Get("/api/p2p", localai.ShowP2PNodes(appConfig))
|
||||||
app.Get("/api/p2p/token", localai.ShowP2PToken(appConfig))
|
router.Get("/api/p2p/token", localai.ShowP2PToken(appConfig))
|
||||||
}
|
}
|
||||||
|
|
||||||
app.Get("/version", func(c *fiber.Ctx) error {
|
router.Get("/version", func(c *fiber.Ctx) error {
|
||||||
return c.JSON(struct {
|
return c.JSON(struct {
|
||||||
Version string `json:"version"`
|
Version string `json:"version"`
|
||||||
}{Version: internal.PrintableVersion()})
|
}{Version: internal.PrintableVersion()})
|
||||||
})
|
})
|
||||||
|
|
||||||
app.Get("/system", localai.SystemInformations(ml, appConfig))
|
router.Get("/system", localai.SystemInformations(ml, appConfig))
|
||||||
|
|
||||||
// misc
|
// misc
|
||||||
app.Post("/v1/tokenize", localai.TokenizeEndpoint(cl, ml, appConfig))
|
router.Post("/v1/tokenize", localai.TokenizeEndpoint(cl, ml, appConfig))
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,84 +2,134 @@ package routes
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
"github.com/mudler/LocalAI/core/config"
|
"github.com/mudler/LocalAI/core/application"
|
||||||
"github.com/mudler/LocalAI/core/http/endpoints/localai"
|
"github.com/mudler/LocalAI/core/http/endpoints/localai"
|
||||||
"github.com/mudler/LocalAI/core/http/endpoints/openai"
|
"github.com/mudler/LocalAI/core/http/endpoints/openai"
|
||||||
"github.com/mudler/LocalAI/pkg/model"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func RegisterOpenAIRoutes(app *fiber.App,
|
func RegisterOpenAIRoutes(app *fiber.App,
|
||||||
cl *config.BackendConfigLoader,
|
application *application.Application) {
|
||||||
ml *model.ModelLoader,
|
|
||||||
appConfig *config.ApplicationConfig) {
|
|
||||||
// openAI compatible API endpoint
|
// openAI compatible API endpoint
|
||||||
|
|
||||||
// chat
|
// chat
|
||||||
app.Post("/v1/chat/completions", openai.ChatEndpoint(cl, ml, appConfig))
|
app.Post("/v1/chat/completions",
|
||||||
app.Post("/chat/completions", openai.ChatEndpoint(cl, ml, appConfig))
|
openai.ChatEndpoint(
|
||||||
|
application.BackendLoader(),
|
||||||
|
application.ModelLoader(),
|
||||||
|
application.TemplatesEvaluator(),
|
||||||
|
application.ApplicationConfig(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
app.Post("/chat/completions",
|
||||||
|
openai.ChatEndpoint(
|
||||||
|
application.BackendLoader(),
|
||||||
|
application.ModelLoader(),
|
||||||
|
application.TemplatesEvaluator(),
|
||||||
|
application.ApplicationConfig(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
// edit
|
// edit
|
||||||
app.Post("/v1/edits", openai.EditEndpoint(cl, ml, appConfig))
|
app.Post("/v1/edits",
|
||||||
app.Post("/edits", openai.EditEndpoint(cl, ml, appConfig))
|
openai.EditEndpoint(
|
||||||
|
application.BackendLoader(),
|
||||||
|
application.ModelLoader(),
|
||||||
|
application.TemplatesEvaluator(),
|
||||||
|
application.ApplicationConfig(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
app.Post("/edits",
|
||||||
|
openai.EditEndpoint(
|
||||||
|
application.BackendLoader(),
|
||||||
|
application.ModelLoader(),
|
||||||
|
application.TemplatesEvaluator(),
|
||||||
|
application.ApplicationConfig(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
// assistant
|
// assistant
|
||||||
app.Get("/v1/assistants", openai.ListAssistantsEndpoint(cl, ml, appConfig))
|
app.Get("/v1/assistants", openai.ListAssistantsEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||||
app.Get("/assistants", openai.ListAssistantsEndpoint(cl, ml, appConfig))
|
app.Get("/assistants", openai.ListAssistantsEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||||
app.Post("/v1/assistants", openai.CreateAssistantEndpoint(cl, ml, appConfig))
|
app.Post("/v1/assistants", openai.CreateAssistantEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||||
app.Post("/assistants", openai.CreateAssistantEndpoint(cl, ml, appConfig))
|
app.Post("/assistants", openai.CreateAssistantEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||||
app.Delete("/v1/assistants/:assistant_id", openai.DeleteAssistantEndpoint(cl, ml, appConfig))
|
app.Delete("/v1/assistants/:assistant_id", openai.DeleteAssistantEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||||
app.Delete("/assistants/:assistant_id", openai.DeleteAssistantEndpoint(cl, ml, appConfig))
|
app.Delete("/assistants/:assistant_id", openai.DeleteAssistantEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||||
app.Get("/v1/assistants/:assistant_id", openai.GetAssistantEndpoint(cl, ml, appConfig))
|
app.Get("/v1/assistants/:assistant_id", openai.GetAssistantEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||||
app.Get("/assistants/:assistant_id", openai.GetAssistantEndpoint(cl, ml, appConfig))
|
app.Get("/assistants/:assistant_id", openai.GetAssistantEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||||
app.Post("/v1/assistants/:assistant_id", openai.ModifyAssistantEndpoint(cl, ml, appConfig))
|
app.Post("/v1/assistants/:assistant_id", openai.ModifyAssistantEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||||
app.Post("/assistants/:assistant_id", openai.ModifyAssistantEndpoint(cl, ml, appConfig))
|
app.Post("/assistants/:assistant_id", openai.ModifyAssistantEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||||
app.Get("/v1/assistants/:assistant_id/files", openai.ListAssistantFilesEndpoint(cl, ml, appConfig))
|
app.Get("/v1/assistants/:assistant_id/files", openai.ListAssistantFilesEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||||
app.Get("/assistants/:assistant_id/files", openai.ListAssistantFilesEndpoint(cl, ml, appConfig))
|
app.Get("/assistants/:assistant_id/files", openai.ListAssistantFilesEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||||
app.Post("/v1/assistants/:assistant_id/files", openai.CreateAssistantFileEndpoint(cl, ml, appConfig))
|
app.Post("/v1/assistants/:assistant_id/files", openai.CreateAssistantFileEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||||
app.Post("/assistants/:assistant_id/files", openai.CreateAssistantFileEndpoint(cl, ml, appConfig))
|
app.Post("/assistants/:assistant_id/files", openai.CreateAssistantFileEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||||
app.Delete("/v1/assistants/:assistant_id/files/:file_id", openai.DeleteAssistantFileEndpoint(cl, ml, appConfig))
|
app.Delete("/v1/assistants/:assistant_id/files/:file_id", openai.DeleteAssistantFileEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||||
app.Delete("/assistants/:assistant_id/files/:file_id", openai.DeleteAssistantFileEndpoint(cl, ml, appConfig))
|
app.Delete("/assistants/:assistant_id/files/:file_id", openai.DeleteAssistantFileEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||||
app.Get("/v1/assistants/:assistant_id/files/:file_id", openai.GetAssistantFileEndpoint(cl, ml, appConfig))
|
app.Get("/v1/assistants/:assistant_id/files/:file_id", openai.GetAssistantFileEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||||
app.Get("/assistants/:assistant_id/files/:file_id", openai.GetAssistantFileEndpoint(cl, ml, appConfig))
|
app.Get("/assistants/:assistant_id/files/:file_id", openai.GetAssistantFileEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||||
|
|
||||||
// files
|
// files
|
||||||
app.Post("/v1/files", openai.UploadFilesEndpoint(cl, appConfig))
|
app.Post("/v1/files", openai.UploadFilesEndpoint(application.BackendLoader(), application.ApplicationConfig()))
|
||||||
app.Post("/files", openai.UploadFilesEndpoint(cl, appConfig))
|
app.Post("/files", openai.UploadFilesEndpoint(application.BackendLoader(), application.ApplicationConfig()))
|
||||||
app.Get("/v1/files", openai.ListFilesEndpoint(cl, appConfig))
|
app.Get("/v1/files", openai.ListFilesEndpoint(application.BackendLoader(), application.ApplicationConfig()))
|
||||||
app.Get("/files", openai.ListFilesEndpoint(cl, appConfig))
|
app.Get("/files", openai.ListFilesEndpoint(application.BackendLoader(), application.ApplicationConfig()))
|
||||||
app.Get("/v1/files/:file_id", openai.GetFilesEndpoint(cl, appConfig))
|
app.Get("/v1/files/:file_id", openai.GetFilesEndpoint(application.BackendLoader(), application.ApplicationConfig()))
|
||||||
app.Get("/files/:file_id", openai.GetFilesEndpoint(cl, appConfig))
|
app.Get("/files/:file_id", openai.GetFilesEndpoint(application.BackendLoader(), application.ApplicationConfig()))
|
||||||
app.Delete("/v1/files/:file_id", openai.DeleteFilesEndpoint(cl, appConfig))
|
app.Delete("/v1/files/:file_id", openai.DeleteFilesEndpoint(application.BackendLoader(), application.ApplicationConfig()))
|
||||||
app.Delete("/files/:file_id", openai.DeleteFilesEndpoint(cl, appConfig))
|
app.Delete("/files/:file_id", openai.DeleteFilesEndpoint(application.BackendLoader(), application.ApplicationConfig()))
|
||||||
app.Get("/v1/files/:file_id/content", openai.GetFilesContentsEndpoint(cl, appConfig))
|
app.Get("/v1/files/:file_id/content", openai.GetFilesContentsEndpoint(application.BackendLoader(), application.ApplicationConfig()))
|
||||||
app.Get("/files/:file_id/content", openai.GetFilesContentsEndpoint(cl, appConfig))
|
app.Get("/files/:file_id/content", openai.GetFilesContentsEndpoint(application.BackendLoader(), application.ApplicationConfig()))
|
||||||
|
|
||||||
// completion
|
// completion
|
||||||
app.Post("/v1/completions", openai.CompletionEndpoint(cl, ml, appConfig))
|
app.Post("/v1/completions",
|
||||||
app.Post("/completions", openai.CompletionEndpoint(cl, ml, appConfig))
|
openai.CompletionEndpoint(
|
||||||
app.Post("/v1/engines/:model/completions", openai.CompletionEndpoint(cl, ml, appConfig))
|
application.BackendLoader(),
|
||||||
|
application.ModelLoader(),
|
||||||
|
application.TemplatesEvaluator(),
|
||||||
|
application.ApplicationConfig(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
app.Post("/completions",
|
||||||
|
openai.CompletionEndpoint(
|
||||||
|
application.BackendLoader(),
|
||||||
|
application.ModelLoader(),
|
||||||
|
application.TemplatesEvaluator(),
|
||||||
|
application.ApplicationConfig(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
app.Post("/v1/engines/:model/completions",
|
||||||
|
openai.CompletionEndpoint(
|
||||||
|
application.BackendLoader(),
|
||||||
|
application.ModelLoader(),
|
||||||
|
application.TemplatesEvaluator(),
|
||||||
|
application.ApplicationConfig(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
// embeddings
|
// embeddings
|
||||||
app.Post("/v1/embeddings", openai.EmbeddingsEndpoint(cl, ml, appConfig))
|
app.Post("/v1/embeddings", openai.EmbeddingsEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||||
app.Post("/embeddings", openai.EmbeddingsEndpoint(cl, ml, appConfig))
|
app.Post("/embeddings", openai.EmbeddingsEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||||
app.Post("/v1/engines/:model/embeddings", openai.EmbeddingsEndpoint(cl, ml, appConfig))
|
app.Post("/v1/engines/:model/embeddings", openai.EmbeddingsEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||||
|
|
||||||
// audio
|
// audio
|
||||||
app.Post("/v1/audio/transcriptions", openai.TranscriptEndpoint(cl, ml, appConfig))
|
app.Post("/v1/audio/transcriptions", openai.TranscriptEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||||
app.Post("/v1/audio/speech", localai.TTSEndpoint(cl, ml, appConfig))
|
app.Post("/v1/audio/speech", localai.TTSEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||||
|
|
||||||
// images
|
// images
|
||||||
app.Post("/v1/images/generations", openai.ImageEndpoint(cl, ml, appConfig))
|
app.Post("/v1/images/generations", openai.ImageEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
|
||||||
|
|
||||||
if appConfig.ImageDir != "" {
|
if application.ApplicationConfig().ImageDir != "" {
|
||||||
app.Static("/generated-images", appConfig.ImageDir)
|
app.Static("/generated-images", application.ApplicationConfig().ImageDir)
|
||||||
}
|
}
|
||||||
|
|
||||||
if appConfig.AudioDir != "" {
|
if application.ApplicationConfig().AudioDir != "" {
|
||||||
app.Static("/generated-audio", appConfig.AudioDir)
|
app.Static("/generated-audio", application.ApplicationConfig().AudioDir)
|
||||||
}
|
}
|
||||||
|
|
||||||
// List models
|
// List models
|
||||||
app.Get("/v1/models", openai.ListModelsEndpoint(cl, ml))
|
app.Get("/v1/models", openai.ListModelsEndpoint(application.BackendLoader(), application.ModelLoader()))
|
||||||
app.Get("/models", openai.ListModelsEndpoint(cl, ml))
|
app.Get("/models", openai.ListModelsEndpoint(application.BackendLoader(), application.ModelLoader()))
|
||||||
}
|
}
|
||||||
|
|
|
@ -194,8 +194,9 @@ diffusers:
|
||||||
pipeline_type: StableDiffusionPipeline
|
pipeline_type: StableDiffusionPipeline
|
||||||
enable_parameters: "negative_prompt,num_inference_steps,clip_skip"
|
enable_parameters: "negative_prompt,num_inference_steps,clip_skip"
|
||||||
scheduler_type: "k_dpmpp_sde"
|
scheduler_type: "k_dpmpp_sde"
|
||||||
cfg_scale: 8
|
|
||||||
clip_skip: 11
|
clip_skip: 11
|
||||||
|
|
||||||
|
cfg_scale: 8
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Configuration parameters
|
#### Configuration parameters
|
||||||
|
@ -302,7 +303,8 @@ cuda: true
|
||||||
diffusers:
|
diffusers:
|
||||||
pipeline_type: StableDiffusionDepth2ImgPipeline
|
pipeline_type: StableDiffusionDepth2ImgPipeline
|
||||||
enable_parameters: "negative_prompt,num_inference_steps,image"
|
enable_parameters: "negative_prompt,num_inference_steps,image"
|
||||||
cfg_scale: 6
|
|
||||||
|
cfg_scale: 6
|
||||||
```
|
```
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
|
|
@ -10,13 +10,13 @@ ico = "rocket_launch"
|
||||||
For installing LocalAI in Kubernetes, the deployment file from the `examples` can be used and customized as preferred:
|
For installing LocalAI in Kubernetes, the deployment file from the `examples` can be used and customized as preferred:
|
||||||
|
|
||||||
```
|
```
|
||||||
kubectl apply -f https://raw.githubusercontent.com/mudler/LocalAI/master/examples/kubernetes/deployment.yaml
|
kubectl apply -f https://raw.githubusercontent.com/mudler/LocalAI-examples/refs/heads/main/kubernetes/deployment.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
For Nvidia GPUs:
|
For Nvidia GPUs:
|
||||||
|
|
||||||
```
|
```
|
||||||
kubectl apply -f https://raw.githubusercontent.com/mudler/LocalAI/master/examples/kubernetes/deployment-nvidia.yaml
|
kubectl apply -f https://raw.githubusercontent.com/mudler/LocalAI-examples/refs/heads/main/kubernetes/deployment-nvidia.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
Alternatively, the [helm chart](https://github.com/go-skynet/helm-charts) can be used as well:
|
Alternatively, the [helm chart](https://github.com/go-skynet/helm-charts) can be used as well:
|
||||||
|
|
|
@ -6,7 +6,7 @@ weight = 24
|
||||||
url = "/model-compatibility/"
|
url = "/model-compatibility/"
|
||||||
+++
|
+++
|
||||||
|
|
||||||
Besides llama-based models, LocalAI is also compatible with other architectures. The table below lists all the compatible model families and the associated binding repositories.
|
Besides llama-based models, LocalAI is also compatible with other architectures. The table below lists all the backends, the compatible model families, and the associated repositories.
|
||||||
|
|
||||||
{{% alert note %}}
|
{{% alert note %}}
|
||||||
|
|
||||||
|
@ -16,19 +16,8 @@ LocalAI will attempt to automatically load models which are not explicitly confi
|
||||||
|
|
||||||
| Backend and Bindings | Compatible models | Completion/Chat endpoint | Capability | Embeddings support | Token stream support | Acceleration |
|
| Backend and Bindings | Compatible models | Completion/Chat endpoint | Capability | Embeddings support | Token stream support | Acceleration |
|
||||||
|----------------------------------------------------------------------------------|-----------------------|--------------------------|---------------------------|-----------------------------------|----------------------|--------------|
|
|----------------------------------------------------------------------------------|-----------------------|--------------------------|---------------------------|-----------------------------------|----------------------|--------------|
|
||||||
| [llama.cpp]({{%relref "docs/features/text-generation#llama.cpp" %}}) | Vicuna, Alpaca, LLaMa, Falcon, Starcoder, GPT-2, [and many others](https://github.com/ggerganov/llama.cpp?tab=readme-ov-file#description) | yes | GPT and Functions | yes** | yes | CUDA, openCL, cuBLAS, Metal |
|
| [llama.cpp]({{%relref "docs/features/text-generation#llama.cpp" %}}) | LLama, Mamba, RWKV, Falcon, Starcoder, GPT-2, [and many others](https://github.com/ggerganov/llama.cpp?tab=readme-ov-file#description) | yes | GPT and Functions | yes** | yes | CUDA, openCL, cuBLAS, Metal |
|
||||||
| [gpt4all-llama](https://github.com/nomic-ai/gpt4all) | Vicuna, Alpaca, LLaMa | yes | GPT | no | yes | N/A |
|
| [llama.cpp's ggml model (backward compatibility with old format, before GGUF)](https://github.com/ggerganov/llama.cpp) ([binding](https://github.com/go-skynet/go-llama.cpp)) | LLama, GPT-2, [and many others](https://github.com/ggerganov/llama.cpp?tab=readme-ov-file#description) | yes | GPT and Functions | yes** | yes | CUDA, openCL, cuBLAS, Metal |
|
||||||
| [gpt4all-mpt](https://github.com/nomic-ai/gpt4all) | MPT | yes | GPT | no | yes | N/A |
|
|
||||||
| [gpt4all-j](https://github.com/nomic-ai/gpt4all) | GPT4ALL-J | yes | GPT | no | yes | N/A |
|
|
||||||
| [falcon-ggml](https://github.com/ggerganov/ggml) ([binding](https://github.com/go-skynet/go-ggml-transformers.cpp)) | Falcon (*) | yes | GPT | no | no | N/A |
|
|
||||||
| [dolly](https://github.com/ggerganov/ggml) ([binding](https://github.com/go-skynet/go-ggml-transformers.cpp)) | Dolly | yes | GPT | no | no | N/A |
|
|
||||||
| [gptj](https://github.com/ggerganov/ggml) ([binding](https://github.com/go-skynet/go-ggml-transformers.cpp)) | GPTJ | yes | GPT | no | no | N/A |
|
|
||||||
| [mpt](https://github.com/ggerganov/ggml) ([binding](https://github.com/go-skynet/go-ggml-transformers.cpp)) | MPT | yes | GPT | no | no | N/A |
|
|
||||||
| [replit](https://github.com/ggerganov/ggml) ([binding](https://github.com/go-skynet/go-ggml-transformers.cpp)) | Replit | yes | GPT | no | no | N/A |
|
|
||||||
| [gptneox](https://github.com/ggerganov/ggml) ([binding](https://github.com/go-skynet/go-ggml-transformers.cpp)) | GPT NeoX, RedPajama, StableLM | yes | GPT | no | no | N/A |
|
|
||||||
| [bloomz](https://github.com/NouamaneTazi/bloomz.cpp) ([binding](https://github.com/go-skynet/bloomz.cpp)) | Bloom | yes | GPT | no | no | N/A |
|
|
||||||
| [rwkv](https://github.com/saharNooby/rwkv.cpp) ([binding](https://github.com/donomii/go-rwkv.cpp)) | rwkv | yes | GPT | no | yes | N/A |
|
|
||||||
| [bert](https://github.com/skeskinen/bert.cpp) ([binding](https://github.com/go-skynet/go-bert.cpp)) | bert | no | Embeddings only | yes | no | N/A |
|
|
||||||
| [whisper](https://github.com/ggerganov/whisper.cpp) | whisper | no | Audio | no | no | N/A |
|
| [whisper](https://github.com/ggerganov/whisper.cpp) | whisper | no | Audio | no | no | N/A |
|
||||||
| [stablediffusion](https://github.com/EdVince/Stable-Diffusion-NCNN) ([binding](https://github.com/mudler/go-stable-diffusion)) | stablediffusion | no | Image | no | no | N/A |
|
| [stablediffusion](https://github.com/EdVince/Stable-Diffusion-NCNN) ([binding](https://github.com/mudler/go-stable-diffusion)) | stablediffusion | no | Image | no | no | N/A |
|
||||||
| [langchain-huggingface](https://github.com/tmc/langchaingo) | Any text generators available on HuggingFace through API | yes | GPT | no | no | N/A |
|
| [langchain-huggingface](https://github.com/tmc/langchaingo) | Any text generators available on HuggingFace through API | yes | GPT | no | no | N/A |
|
||||||
|
@ -40,11 +29,18 @@ LocalAI will attempt to automatically load models which are not explicitly confi
|
||||||
| `diffusers` | SD,... | no | Image generation | no | no | N/A |
|
| `diffusers` | SD,... | no | Image generation | no | no | N/A |
|
||||||
| `vall-e-x` | Vall-E | no | Audio generation and Voice cloning | no | no | CPU/CUDA |
|
| `vall-e-x` | Vall-E | no | Audio generation and Voice cloning | no | no | CPU/CUDA |
|
||||||
| `vllm` | Various GPTs and quantization formats | yes | GPT | no | no | CPU/CUDA |
|
| `vllm` | Various GPTs and quantization formats | yes | GPT | no | no | CPU/CUDA |
|
||||||
|
| `mamba` | Mamba models architecture | yes | GPT | no | no | CPU/CUDA |
|
||||||
| `exllama2` | GPTQ | yes | GPT only | no | no | N/A |
|
| `exllama2` | GPTQ | yes | GPT only | no | no | N/A |
|
||||||
| `transformers-musicgen` | | no | Audio generation | no | no | N/A |
|
| `transformers-musicgen` | | no | Audio generation | no | no | N/A |
|
||||||
| [tinydream](https://github.com/symisc/tiny-dream#tiny-dreaman-embedded-header-only-stable-diffusion-inference-c-librarypixlabiotiny-dream) | stablediffusion | no | Image | no | no | N/A |
|
| [tinydream](https://github.com/symisc/tiny-dream#tiny-dreaman-embedded-header-only-stable-diffusion-inference-c-librarypixlabiotiny-dream) | stablediffusion | no | Image | no | no | N/A |
|
||||||
| `coqui` | Coqui | no | Audio generation and Voice cloning | no | no | CPU/CUDA |
|
| `coqui` | Coqui | no | Audio generation and Voice cloning | no | no | CPU/CUDA |
|
||||||
|
| `openvoice` | Open voice | no | Audio generation and Voice cloning | no | no | CPU/CUDA |
|
||||||
|
| `parler-tts` | Open voice | no | Audio generation and Voice cloning | no | no | CPU/CUDA |
|
||||||
|
| [rerankers](https://github.com/AnswerDotAI/rerankers) | Reranking API | no | Reranking | no | no | CPU/CUDA |
|
||||||
| `transformers` | Various GPTs and quantization formats | yes | GPT, embeddings | yes | yes**** | CPU/CUDA/XPU |
|
| `transformers` | Various GPTs and quantization formats | yes | GPT, embeddings | yes | yes**** | CPU/CUDA/XPU |
|
||||||
|
| [bark-cpp](https://github.com/PABannier/bark.cpp) | bark | no | Audio-Only | no | no | yes |
|
||||||
|
| [stablediffusion-cpp](https://github.com/leejet/stable-diffusion.cpp) | stablediffusion-1, stablediffusion-2, stablediffusion-3, flux, PhotoMaker | no | Image | no | no | N/A |
|
||||||
|
| [silero-vad](https://github.com/snakers4/silero-vad) with [Golang bindings](https://github.com/streamer45/silero-vad-go) | Silero VAD | no | Voice Activity Detection | no | no | CPU |
|
||||||
|
|
||||||
Note: any backend name listed above can be used in the `backend` field of the model configuration file (See [the advanced section]({{%relref "docs/advanced" %}})).
|
Note: any backend name listed above can be used in the `backend` field of the model configuration file (See [the advanced section]({{%relref "docs/advanced" %}})).
|
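As a quick illustration, a minimal model configuration file that selects one of the backends above via the `backend` field might look like the sketch below; the model name and weights file are hypothetical, and any backend name from the table can be substituted:

```yaml
# my-model.yaml — hypothetical sketch of an explicit backend selection
name: my-model                # name the model is exposed under by the API
backend: diffusers            # any backend name listed in the table above
parameters:
  model: my-model-weights     # model file or identifier, as expected by the chosen backend
```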
||||||
|
|
||||||
|
|
|
@ -1,3 +1,3 @@
|
||||||
{
|
{
|
||||||
"version": "v2.23.0"
|
"version": "v2.24.2"
|
||||||
}
|
}
|
||||||
|
|
2
docs/themes/hugo-theme-relearn
vendored
2
docs/themes/hugo-theme-relearn
vendored
|
@ -1 +1 @@
|
||||||
Subproject commit 28fce6b04c414523280c53ee02f9f3a94d9d23da
|
Subproject commit bd1f3d3432632c61bb12e7ec0f7673fed0289f19
|
12
gallery/flux-ggml.yaml
Normal file
12
gallery/flux-ggml.yaml
Normal file
|
@ -0,0 +1,12 @@
|
||||||
|
---
|
||||||
|
name: "flux-ggml"
|
||||||
|
|
||||||
|
config_file: |
|
||||||
|
backend: stablediffusion-ggml
|
||||||
|
step: 25
|
||||||
|
options:
|
||||||
|
- "diffusion_model"
|
||||||
|
- "clip_l_path:clip_l.safetensors"
|
||||||
|
- "t5xxl_path:t5xxl_fp16.safetensors"
|
||||||
|
- "vae_path:ae.safetensors"
|
||||||
|
- "sampler:euler"
|
|
@ -11,4 +11,5 @@ config_file: |
|
||||||
cuda: true
|
cuda: true
|
||||||
enable_parameters: num_inference_steps
|
enable_parameters: num_inference_steps
|
||||||
pipeline_type: FluxPipeline
|
pipeline_type: FluxPipeline
|
||||||
cfg_scale: 0
|
|
||||||
|
cfg_scale: 0
|
||||||
|
|
File diff suppressed because it is too large
|
@ -16,6 +16,7 @@ config_file: |
|
||||||
|
|
||||||
stopwords:
|
stopwords:
|
||||||
- 'Assistant:'
|
- 'Assistant:'
|
||||||
|
- '<s>'
|
||||||
|
|
||||||
template:
|
template:
|
||||||
chat: "{{.Input}}\nAssistant: "
|
chat: "{{.Input}}\nAssistant: "
|
||||||
|
|
5
go.mod
5
go.mod
|
@ -76,6 +76,7 @@ require (
|
||||||
cloud.google.com/go/auth v0.4.1 // indirect
|
cloud.google.com/go/auth v0.4.1 // indirect
|
||||||
cloud.google.com/go/auth/oauth2adapt v0.2.2 // indirect
|
cloud.google.com/go/auth/oauth2adapt v0.2.2 // indirect
|
||||||
cloud.google.com/go/compute/metadata v0.3.0 // indirect
|
cloud.google.com/go/compute/metadata v0.3.0 // indirect
|
||||||
|
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||||
github.com/envoyproxy/protoc-gen-validate v1.0.4 // indirect
|
github.com/envoyproxy/protoc-gen-validate v1.0.4 // indirect
|
||||||
github.com/fasthttp/websocket v1.5.3 // indirect
|
github.com/fasthttp/websocket v1.5.3 // indirect
|
||||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||||
|
@ -84,8 +85,12 @@ require (
|
||||||
github.com/google/s2a-go v0.1.7 // indirect
|
github.com/google/s2a-go v0.1.7 // indirect
|
||||||
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
|
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
|
||||||
github.com/googleapis/gax-go/v2 v2.12.4 // indirect
|
github.com/googleapis/gax-go/v2 v2.12.4 // indirect
|
||||||
|
github.com/json-iterator/go v1.1.12 // indirect
|
||||||
github.com/moby/docker-image-spec v1.3.1 // indirect
|
github.com/moby/docker-image-spec v1.3.1 // indirect
|
||||||
|
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||||
|
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
||||||
|
github.com/nikolalohinski/gonja/v2 v2.3.2 // indirect
|
||||||
github.com/pion/datachannel v1.5.8 // indirect
|
github.com/pion/datachannel v1.5.8 // indirect
|
||||||
github.com/pion/dtls/v2 v2.2.12 // indirect
|
github.com/pion/dtls/v2 v2.2.12 // indirect
|
||||||
github.com/pion/ice/v2 v2.3.34 // indirect
|
github.com/pion/ice/v2 v2.3.34 // indirect
|
||||||
|
|
12
go.sum
12
go.sum
|
@ -140,6 +140,8 @@ github.com/dsnet/compress v0.0.2-0.20210315054119-f66993602bf5 h1:iFaUwBSo5Svw6L
|
||||||
github.com/dsnet/compress v0.0.2-0.20210315054119-f66993602bf5/go.mod h1:qssHWj60/X5sZFNxpG4HBPDHVqxNm4DfnCKgrbZOT+s=
|
github.com/dsnet/compress v0.0.2-0.20210315054119-f66993602bf5/go.mod h1:qssHWj60/X5sZFNxpG4HBPDHVqxNm4DfnCKgrbZOT+s=
|
||||||
github.com/dsnet/golib v0.0.0-20171103203638-1ea166775780/go.mod h1:Lj+Z9rebOhdfkVLjJ8T6VcRQv3SXugXy999NBtR9aFY=
|
github.com/dsnet/golib v0.0.0-20171103203638-1ea166775780/go.mod h1:Lj+Z9rebOhdfkVLjJ8T6VcRQv3SXugXy999NBtR9aFY=
|
||||||
github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
|
github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
|
||||||
|
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||||
|
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||||
github.com/elastic/gosigar v0.12.0/go.mod h1:iXRIGg2tLnu7LBdpqzyQfGDEidKCfWcCMS0WKyPWoMs=
|
github.com/elastic/gosigar v0.12.0/go.mod h1:iXRIGg2tLnu7LBdpqzyQfGDEidKCfWcCMS0WKyPWoMs=
|
||||||
github.com/elastic/gosigar v0.14.3 h1:xwkKwPia+hSfg9GqrCUKYdId102m9qTJIIr7egmK/uo=
|
github.com/elastic/gosigar v0.14.3 h1:xwkKwPia+hSfg9GqrCUKYdId102m9qTJIIr7egmK/uo=
|
||||||
github.com/elastic/gosigar v0.14.3/go.mod h1:iXRIGg2tLnu7LBdpqzyQfGDEidKCfWcCMS0WKyPWoMs=
|
github.com/elastic/gosigar v0.14.3/go.mod h1:iXRIGg2tLnu7LBdpqzyQfGDEidKCfWcCMS0WKyPWoMs=
|
||||||
|
@ -268,6 +270,7 @@ github.com/google/go-containerregistry v0.19.2 h1:TannFKE1QSajsP6hPWb5oJNgKe1IKj
|
||||||
github.com/google/go-containerregistry v0.19.2/go.mod h1:YCMFNQeeXeLF+dnhhWkqDItx/JSkH01j1Kis4PsjzFI=
|
github.com/google/go-containerregistry v0.19.2/go.mod h1:YCMFNQeeXeLF+dnhhWkqDItx/JSkH01j1Kis4PsjzFI=
|
||||||
github.com/google/go-github v17.0.0+incompatible/go.mod h1:zLgOLi98H3fifZn+44m+umXrS52loVEgC2AApnigrVQ=
|
github.com/google/go-github v17.0.0+incompatible/go.mod h1:zLgOLi98H3fifZn+44m+umXrS52loVEgC2AApnigrVQ=
|
||||||
github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck=
|
github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck=
|
||||||
|
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||||
github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8=
|
github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8=
|
||||||
github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo=
|
github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo=
|
||||||
github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs=
|
github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs=
|
||||||
|
@ -353,6 +356,8 @@ github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwA
|
||||||
github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
|
github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
|
||||||
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
|
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
|
||||||
github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
|
github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
|
||||||
|
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||||
|
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||||
github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU=
|
github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU=
|
||||||
github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo=
|
github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo=
|
||||||
github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU=
|
github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU=
|
||||||
|
@ -474,8 +479,12 @@ github.com/moby/sys/sequential v0.5.0 h1:OPvI35Lzn9K04PBbCLW0g4LcFAJgHsvXsRyewg5
|
||||||
github.com/moby/sys/sequential v0.5.0/go.mod h1:tH2cOOs5V9MlPiXcQzRC+eEyab644PWKGRYaaV5ZZlo=
|
github.com/moby/sys/sequential v0.5.0/go.mod h1:tH2cOOs5V9MlPiXcQzRC+eEyab644PWKGRYaaV5ZZlo=
|
||||||
github.com/moby/term v0.5.0 h1:xt8Q1nalod/v7BqbG21f8mQPqH+xAaC9C3N3wfWbVP0=
|
github.com/moby/term v0.5.0 h1:xt8Q1nalod/v7BqbG21f8mQPqH+xAaC9C3N3wfWbVP0=
|
||||||
github.com/moby/term v0.5.0/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y=
|
github.com/moby/term v0.5.0/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y=
|
||||||
|
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||||
|
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
|
||||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||||
github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
|
github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
|
||||||
|
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
||||||
|
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||||
github.com/mr-tron/base58 v1.1.2/go.mod h1:BinMc/sQntlIE1frQmRFPUoPA1Zkr8VRgBdjWI2mNwc=
|
github.com/mr-tron/base58 v1.1.2/go.mod h1:BinMc/sQntlIE1frQmRFPUoPA1Zkr8VRgBdjWI2mNwc=
|
||||||
github.com/mr-tron/base58 v1.2.0 h1:T/HDJBh4ZCPbU39/+c3rRvE0uKBQlU27+QI8LJ4t64o=
|
github.com/mr-tron/base58 v1.2.0 h1:T/HDJBh4ZCPbU39/+c3rRvE0uKBQlU27+QI8LJ4t64o=
|
||||||
github.com/mr-tron/base58 v1.2.0/go.mod h1:BinMc/sQntlIE1frQmRFPUoPA1Zkr8VRgBdjWI2mNwc=
|
github.com/mr-tron/base58 v1.2.0/go.mod h1:BinMc/sQntlIE1frQmRFPUoPA1Zkr8VRgBdjWI2mNwc=
|
||||||
|
@ -519,6 +528,9 @@ github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq
|
||||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
|
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
|
||||||
github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo=
|
github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo=
|
||||||
github.com/neelance/sourcemap v0.0.0-20151028013722-8c68805598ab/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM=
|
github.com/neelance/sourcemap v0.0.0-20151028013722-8c68805598ab/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM=
|
||||||
|
github.com/nikolalohinski/gonja v1.5.3 h1:GsA+EEaZDZPGJ8JtpeGN78jidhOlxeJROpqMT9fTj9c=
|
||||||
|
github.com/nikolalohinski/gonja/v2 v2.3.2 h1:UgLFfqi7L9XfX0PEcE4eUpvGojVQL5KhBfJJaBp7ZxY=
|
||||||
|
github.com/nikolalohinski/gonja/v2 v2.3.2/go.mod h1:1Wcc/5huTu6y36e0sOFR1XQoFlylw3c3H3L5WOz0RDg=
|
||||||
github.com/nwaples/rardecode v1.1.0 h1:vSxaY8vQhOcVr4mm5e8XllHWTiM4JF507A0Katqw7MQ=
|
github.com/nwaples/rardecode v1.1.0 h1:vSxaY8vQhOcVr4mm5e8XllHWTiM4JF507A0Katqw7MQ=
|
||||||
github.com/nwaples/rardecode v1.1.0/go.mod h1:5DzqNKiOdpKKBH87u8VlvAnPZMXcGRhxWkRpHbbfGS0=
|
github.com/nwaples/rardecode v1.1.0/go.mod h1:5DzqNKiOdpKKBH87u8VlvAnPZMXcGRhxWkRpHbbfGS0=
|
||||||
github.com/nxadm/tail v1.4.11 h1:8feyoE3OzPrcshW5/MJ4sGESc5cqmGkGCWlco4l0bqY=
|
github.com/nxadm/tail v1.4.11 h1:8feyoE3OzPrcshW5/MJ4sGESc5cqmGkGCWlco4l0bqY=
|
||||||
|
|
|
@ -9,8 +9,6 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/mudler/LocalAI/pkg/templates"
|
|
||||||
|
|
||||||
"github.com/mudler/LocalAI/pkg/utils"
|
"github.com/mudler/LocalAI/pkg/utils"
|
||||||
|
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
|
@ -23,7 +21,6 @@ type ModelLoader struct {
|
||||||
ModelPath string
|
ModelPath string
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
models map[string]*Model
|
models map[string]*Model
|
||||||
templates *templates.TemplateCache
|
|
||||||
wd *WatchDog
|
wd *WatchDog
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -31,7 +28,6 @@ func NewModelLoader(modelPath string) *ModelLoader {
|
||||||
nml := &ModelLoader{
|
nml := &ModelLoader{
|
||||||
ModelPath: modelPath,
|
ModelPath: modelPath,
|
||||||
models: make(map[string]*Model),
|
models: make(map[string]*Model),
|
||||||
templates: templates.NewTemplateCache(modelPath),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nml
|
return nml
|
||||||
|
|
|
@ -1,52 +0,0 @@
|
||||||
package model
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"github.com/mudler/LocalAI/pkg/functions"
|
|
||||||
"github.com/mudler/LocalAI/pkg/templates"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Rather than pass an interface{} to the prompt template:
|
|
||||||
// These are the definitions of all possible variables LocalAI will currently populate for use in a prompt template file
|
|
||||||
// Please note: Not all of these are populated on every endpoint - your template should either be tested for each endpoint you map it to, or tolerant of zero values.
|
|
||||||
type PromptTemplateData struct {
|
|
||||||
SystemPrompt string
|
|
||||||
SuppressSystemPrompt bool // used by chat specifically to indicate that SystemPrompt above should be _ignored_
|
|
||||||
Input string
|
|
||||||
Instruction string
|
|
||||||
Functions []functions.Function
|
|
||||||
MessageIndex int
|
|
||||||
}
|
|
||||||
|
|
||||||
type ChatMessageTemplateData struct {
|
|
||||||
SystemPrompt string
|
|
||||||
Role string
|
|
||||||
RoleName string
|
|
||||||
FunctionName string
|
|
||||||
Content string
|
|
||||||
MessageIndex int
|
|
||||||
Function bool
|
|
||||||
FunctionCall interface{}
|
|
||||||
LastMessage bool
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
ChatPromptTemplate templates.TemplateType = iota
|
|
||||||
ChatMessageTemplate
|
|
||||||
CompletionPromptTemplate
|
|
||||||
EditPromptTemplate
|
|
||||||
FunctionsPromptTemplate
|
|
||||||
)
|
|
||||||
|
|
||||||
func (ml *ModelLoader) EvaluateTemplateForPrompt(templateType templates.TemplateType, templateName string, in PromptTemplateData) (string, error) {
|
|
||||||
// TODO: should this check be improved?
|
|
||||||
if templateType == ChatMessageTemplate {
|
|
||||||
return "", fmt.Errorf("invalid templateType: ChatMessage")
|
|
||||||
}
|
|
||||||
return ml.templates.EvaluateTemplate(templateType, templateName, in)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ml *ModelLoader) EvaluateTemplateForChatMessage(templateName string, messageData ChatMessageTemplateData) (string, error) {
|
|
||||||
return ml.templates.EvaluateTemplate(ChatMessageTemplate, templateName, messageData)
|
|
||||||
}
|
|
|
@ -1,197 +0,0 @@
|
||||||
package model_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
. "github.com/mudler/LocalAI/pkg/model"
|
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo/v2"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
const chatML = `<|im_start|>{{if eq .RoleName "assistant"}}assistant{{else if eq .RoleName "system"}}system{{else if eq .RoleName "tool"}}tool{{else if eq .RoleName "user"}}user{{end}}
|
|
||||||
{{- if .FunctionCall }}
|
|
||||||
<tool_call>
|
|
||||||
{{- else if eq .RoleName "tool" }}
|
|
||||||
<tool_response>
|
|
||||||
{{- end }}
|
|
||||||
{{- if .Content}}
|
|
||||||
{{.Content }}
|
|
||||||
{{- end }}
|
|
||||||
{{- if .FunctionCall}}
|
|
||||||
{{toJson .FunctionCall}}
|
|
||||||
{{- end }}
|
|
||||||
{{- if .FunctionCall }}
|
|
||||||
</tool_call>
|
|
||||||
{{- else if eq .RoleName "tool" }}
|
|
||||||
</tool_response>
|
|
||||||
{{- end }}<|im_end|>`
|
|
||||||
|
|
||||||
const llama3 = `<|start_header_id|>{{if eq .RoleName "assistant"}}assistant{{else if eq .RoleName "system"}}system{{else if eq .RoleName "tool"}}tool{{else if eq .RoleName "user"}}user{{end}}<|end_header_id|>
|
|
||||||
|
|
||||||
{{ if .FunctionCall -}}
|
|
||||||
Function call:
|
|
||||||
{{ else if eq .RoleName "tool" -}}
|
|
||||||
Function response:
|
|
||||||
{{ end -}}
|
|
||||||
{{ if .Content -}}
|
|
||||||
{{.Content -}}
|
|
||||||
{{ else if .FunctionCall -}}
|
|
||||||
{{ toJson .FunctionCall -}}
|
|
||||||
{{ end -}}
|
|
||||||
<|eot_id|>`
|
|
||||||
|
|
||||||
var llama3TestMatch map[string]map[string]interface{} = map[string]map[string]interface{}{
|
|
||||||
"user": {
|
|
||||||
"template": llama3,
|
|
||||||
"expected": "<|start_header_id|>user<|end_header_id|>\n\nA long time ago in a galaxy far, far away...<|eot_id|>",
|
|
||||||
"data": ChatMessageTemplateData{
|
|
||||||
SystemPrompt: "",
|
|
||||||
Role: "user",
|
|
||||||
RoleName: "user",
|
|
||||||
Content: "A long time ago in a galaxy far, far away...",
|
|
||||||
FunctionCall: nil,
|
|
||||||
FunctionName: "",
|
|
||||||
LastMessage: false,
|
|
||||||
Function: false,
|
|
||||||
MessageIndex: 0,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"assistant": {
|
|
||||||
"template": llama3,
|
|
||||||
"expected": "<|start_header_id|>assistant<|end_header_id|>\n\nA long time ago in a galaxy far, far away...<|eot_id|>",
|
|
||||||
"data": ChatMessageTemplateData{
|
|
||||||
SystemPrompt: "",
|
|
||||||
Role: "assistant",
|
|
||||||
RoleName: "assistant",
|
|
||||||
Content: "A long time ago in a galaxy far, far away...",
|
|
||||||
FunctionCall: nil,
|
|
||||||
FunctionName: "",
|
|
||||||
LastMessage: false,
|
|
||||||
Function: false,
|
|
||||||
MessageIndex: 0,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"function_call": {
|
|
||||||
"template": llama3,
|
|
||||||
"expected": "<|start_header_id|>assistant<|end_header_id|>\n\nFunction call:\n{\"function\":\"test\"}<|eot_id|>",
|
|
||||||
"data": ChatMessageTemplateData{
|
|
||||||
SystemPrompt: "",
|
|
||||||
Role: "assistant",
|
|
||||||
RoleName: "assistant",
|
|
||||||
Content: "",
|
|
||||||
FunctionCall: map[string]string{"function": "test"},
|
|
||||||
FunctionName: "",
|
|
||||||
LastMessage: false,
|
|
||||||
Function: false,
|
|
||||||
MessageIndex: 0,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"function_response": {
|
|
||||||
"template": llama3,
|
|
||||||
"expected": "<|start_header_id|>tool<|end_header_id|>\n\nFunction response:\nResponse from tool<|eot_id|>",
|
|
||||||
"data": ChatMessageTemplateData{
|
|
||||||
SystemPrompt: "",
|
|
||||||
Role: "tool",
|
|
||||||
RoleName: "tool",
|
|
||||||
Content: "Response from tool",
|
|
||||||
FunctionCall: nil,
|
|
||||||
FunctionName: "",
|
|
||||||
LastMessage: false,
|
|
||||||
Function: false,
|
|
||||||
MessageIndex: 0,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
var chatMLTestMatch map[string]map[string]interface{} = map[string]map[string]interface{}{
|
|
||||||
"user": {
|
|
||||||
"template": chatML,
|
|
||||||
"expected": "<|im_start|>user\nA long time ago in a galaxy far, far away...<|im_end|>",
|
|
||||||
"data": ChatMessageTemplateData{
|
|
||||||
SystemPrompt: "",
|
|
||||||
Role: "user",
|
|
||||||
RoleName: "user",
|
|
||||||
Content: "A long time ago in a galaxy far, far away...",
|
|
||||||
FunctionCall: nil,
|
|
||||||
FunctionName: "",
|
|
||||||
LastMessage: false,
|
|
||||||
Function: false,
|
|
||||||
MessageIndex: 0,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"assistant": {
|
|
||||||
"template": chatML,
|
|
||||||
"expected": "<|im_start|>assistant\nA long time ago in a galaxy far, far away...<|im_end|>",
|
|
||||||
"data": ChatMessageTemplateData{
|
|
||||||
SystemPrompt: "",
|
|
||||||
Role: "assistant",
|
|
||||||
RoleName: "assistant",
|
|
||||||
Content: "A long time ago in a galaxy far, far away...",
|
|
||||||
FunctionCall: nil,
|
|
||||||
FunctionName: "",
|
|
||||||
LastMessage: false,
|
|
||||||
Function: false,
|
|
||||||
MessageIndex: 0,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"function_call": {
|
|
||||||
"template": chatML,
|
|
||||||
"expected": "<|im_start|>assistant\n<tool_call>\n{\"function\":\"test\"}\n</tool_call><|im_end|>",
|
|
||||||
"data": ChatMessageTemplateData{
|
|
||||||
SystemPrompt: "",
|
|
||||||
Role: "assistant",
|
|
||||||
RoleName: "assistant",
|
|
||||||
Content: "",
|
|
||||||
FunctionCall: map[string]string{"function": "test"},
|
|
||||||
FunctionName: "",
|
|
||||||
LastMessage: false,
|
|
||||||
Function: false,
|
|
||||||
MessageIndex: 0,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"function_response": {
|
|
||||||
"template": chatML,
|
|
||||||
"expected": "<|im_start|>tool\n<tool_response>\nResponse from tool\n</tool_response><|im_end|>",
|
|
||||||
"data": ChatMessageTemplateData{
|
|
||||||
SystemPrompt: "",
|
|
||||||
Role: "tool",
|
|
||||||
RoleName: "tool",
|
|
||||||
Content: "Response from tool",
|
|
||||||
FunctionCall: nil,
|
|
||||||
FunctionName: "",
|
|
||||||
LastMessage: false,
|
|
||||||
Function: false,
|
|
||||||
MessageIndex: 0,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ = Describe("Templates", func() {
|
|
||||||
Context("chat message ChatML", func() {
|
|
||||||
var modelLoader *ModelLoader
|
|
||||||
BeforeEach(func() {
|
|
||||||
modelLoader = NewModelLoader("")
|
|
||||||
})
|
|
||||||
for key := range chatMLTestMatch {
|
|
||||||
foo := chatMLTestMatch[key]
|
|
||||||
It("renders correctly `"+key+"`", func() {
|
|
||||||
templated, err := modelLoader.EvaluateTemplateForChatMessage(foo["template"].(string), foo["data"].(ChatMessageTemplateData))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(templated).To(Equal(foo["expected"]), templated)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
})
|
|
||||||
Context("chat message llama3", func() {
|
|
||||||
var modelLoader *ModelLoader
|
|
||||||
BeforeEach(func() {
|
|
||||||
modelLoader = NewModelLoader("")
|
|
||||||
})
|
|
||||||
for key := range llama3TestMatch {
|
|
||||||
foo := llama3TestMatch[key]
|
|
||||||
It("renders correctly `"+key+"`", func() {
|
|
||||||
templated, err := modelLoader.EvaluateTemplateForChatMessage(foo["template"].(string), foo["data"].(ChatMessageTemplateData))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(templated).To(Equal(foo["expected"]), templated)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
})
|
|
||||||
})
|
|
|
@ -11,59 +11,41 @@ import (
|
||||||
"github.com/mudler/LocalAI/pkg/utils"
|
"github.com/mudler/LocalAI/pkg/utils"
|
||||||
|
|
||||||
"github.com/Masterminds/sprig/v3"
|
"github.com/Masterminds/sprig/v3"
|
||||||
|
|
||||||
|
"github.com/nikolalohinski/gonja/v2"
|
||||||
|
"github.com/nikolalohinski/gonja/v2/exec"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Keep this in sync with config.TemplateConfig. Is there a more idiomatic way to accomplish this in go?
|
// Keep this in sync with config.TemplateConfig. Is there a more idiomatic way to accomplish this in go?
|
||||||
// Technically, order doesn't _really_ matter, but the count must stay in sync, see tests/integration/reflect_test.go
|
// Technically, order doesn't _really_ matter, but the count must stay in sync, see tests/integration/reflect_test.go
|
||||||
type TemplateType int
|
type TemplateType int
|
||||||
|
|
||||||
type TemplateCache struct {
|
type templateCache struct {
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
templatesPath string
|
templatesPath string
|
||||||
templates map[TemplateType]map[string]*template.Template
|
templates map[TemplateType]map[string]*template.Template
|
||||||
|
jinjaTemplates map[TemplateType]map[string]*exec.Template
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTemplateCache(templatesPath string) *TemplateCache {
|
func newTemplateCache(templatesPath string) *templateCache {
|
||||||
tc := &TemplateCache{
|
tc := &templateCache{
|
||||||
templatesPath: templatesPath,
|
templatesPath: templatesPath,
|
||||||
templates: make(map[TemplateType]map[string]*template.Template),
|
templates: make(map[TemplateType]map[string]*template.Template),
|
||||||
|
jinjaTemplates: make(map[TemplateType]map[string]*exec.Template),
|
||||||
}
|
}
|
||||||
return tc
|
return tc
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tc *TemplateCache) initializeTemplateMapKey(tt TemplateType) {
|
func (tc *templateCache) initializeTemplateMapKey(tt TemplateType) {
|
||||||
if _, ok := tc.templates[tt]; !ok {
|
if _, ok := tc.templates[tt]; !ok {
|
||||||
tc.templates[tt] = make(map[string]*template.Template)
|
tc.templates[tt] = make(map[string]*template.Template)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tc *TemplateCache) EvaluateTemplate(templateType TemplateType, templateName string, in interface{}) (string, error) {
|
func (tc *templateCache) existsInModelPath(s string) bool {
|
||||||
tc.mu.Lock()
|
return utils.ExistsInPath(tc.templatesPath, s)
|
||||||
defer tc.mu.Unlock()
|
|
||||||
|
|
||||||
tc.initializeTemplateMapKey(templateType)
|
|
||||||
m, ok := tc.templates[templateType][templateName]
|
|
||||||
if !ok {
|
|
||||||
// return "", fmt.Errorf("template not loaded: %s", templateName)
|
|
||||||
loadErr := tc.loadTemplateIfExists(templateType, templateName)
|
|
||||||
if loadErr != nil {
|
|
||||||
return "", loadErr
|
|
||||||
}
|
|
||||||
m = tc.templates[templateType][templateName] // ok is not important since we check m on the next line, and we already checked
|
|
||||||
}
|
|
||||||
if m == nil {
|
|
||||||
return "", fmt.Errorf("failed loading a template for %s", templateName)
|
|
||||||
}
|
|
||||||
|
|
||||||
var buf bytes.Buffer
|
|
||||||
|
|
||||||
if err := m.Execute(&buf, in); err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
return buf.String(), nil
|
|
||||||
}
|
}
|
||||||
|
func (tc *templateCache) loadTemplateIfExists(templateType TemplateType, templateName string) error {
|
||||||
func (tc *TemplateCache) loadTemplateIfExists(templateType TemplateType, templateName string) error {
|
|
||||||
|
|
||||||
// Check if the template was already loaded
|
// Check if the template was already loaded
|
||||||
if _, ok := tc.templates[templateType][templateName]; ok {
|
if _, ok := tc.templates[templateType][templateName]; ok {
|
||||||
|
@ -82,6 +64,51 @@ func (tc *TemplateCache) loadTemplateIfExists(templateType TemplateType, templat
|
||||||
return fmt.Errorf("template file outside path: %s", file)
|
return fmt.Errorf("template file outside path: %s", file)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// can either be a file in the system or a string with the template
|
||||||
|
if tc.existsInModelPath(modelTemplateFile) {
|
||||||
|
d, err := os.ReadFile(file)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
dat = string(d)
|
||||||
|
} else {
|
||||||
|
dat = templateName
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the template
|
||||||
|
tmpl, err := template.New("prompt").Funcs(sprig.FuncMap()).Parse(dat)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
tc.templates[templateType][templateName] = tmpl
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tc *templateCache) initializeJinjaTemplateMapKey(tt TemplateType) {
|
||||||
|
if _, ok := tc.jinjaTemplates[tt]; !ok {
|
||||||
|
tc.jinjaTemplates[tt] = make(map[string]*exec.Template)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tc *templateCache) loadJinjaTemplateIfExists(templateType TemplateType, templateName string) error {
|
||||||
|
// Check if the template was already loaded
|
||||||
|
if _, ok := tc.jinjaTemplates[templateType][templateName]; ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the model path exists
|
||||||
|
// skip any error here - we run anyway if a template does not exist
|
||||||
|
modelTemplateFile := fmt.Sprintf("%s.tmpl", templateName)
|
||||||
|
|
||||||
|
dat := ""
|
||||||
|
file := filepath.Join(tc.templatesPath, modelTemplateFile)
|
||||||
|
|
||||||
|
// Security check
|
||||||
|
if err := utils.VerifyPath(modelTemplateFile, tc.templatesPath); err != nil {
|
||||||
|
return fmt.Errorf("template file outside path: %s", file)
|
||||||
|
}
|
||||||
|
|
||||||
// can either be a file in the system or a string with the template
|
// can either be a file in the system or a string with the template
|
||||||
if utils.ExistsInPath(tc.templatesPath, modelTemplateFile) {
|
if utils.ExistsInPath(tc.templatesPath, modelTemplateFile) {
|
||||||
d, err := os.ReadFile(file)
|
d, err := os.ReadFile(file)
|
||||||
|
@ -93,12 +120,65 @@ func (tc *TemplateCache) loadTemplateIfExists(templateType TemplateType, templat
|
||||||
dat = templateName
|
dat = templateName
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse the template
|
tmpl, err := gonja.FromString(dat)
|
||||||
tmpl, err := template.New("prompt").Funcs(sprig.FuncMap()).Parse(dat)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
tc.templates[templateType][templateName] = tmpl
|
tc.jinjaTemplates[templateType][templateName] = tmpl
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (tc *templateCache) evaluateJinjaTemplate(templateType TemplateType, templateNameOrContent string, in map[string]interface{}) (string, error) {
|
||||||
|
tc.mu.Lock()
|
||||||
|
defer tc.mu.Unlock()
|
||||||
|
|
||||||
|
tc.initializeJinjaTemplateMapKey(templateType)
|
||||||
|
m, ok := tc.jinjaTemplates[templateType][templateNameOrContent]
|
||||||
|
if !ok {
|
||||||
|
// return "", fmt.Errorf("template not loaded: %s", templateName)
|
||||||
|
loadErr := tc.loadJinjaTemplateIfExists(templateType, templateNameOrContent)
|
||||||
|
if loadErr != nil {
|
||||||
|
return "", loadErr
|
||||||
|
}
|
||||||
|
m = tc.jinjaTemplates[templateType][templateNameOrContent] // ok is not important since we check m on the next line, and we already checked
|
||||||
|
}
|
||||||
|
if m == nil {
|
||||||
|
return "", fmt.Errorf("failed loading a template for %s", templateNameOrContent)
|
||||||
|
}
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
|
||||||
|
data := exec.NewContext(in)
|
||||||
|
|
||||||
|
if err := m.Execute(&buf, data); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return buf.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tc *templateCache) evaluateTemplate(templateType TemplateType, templateNameOrContent string, in interface{}) (string, error) {
|
||||||
|
tc.mu.Lock()
|
||||||
|
defer tc.mu.Unlock()
|
||||||
|
|
||||||
|
tc.initializeTemplateMapKey(templateType)
|
||||||
|
m, ok := tc.templates[templateType][templateNameOrContent]
|
||||||
|
if !ok {
|
||||||
|
// return "", fmt.Errorf("template not loaded: %s", templateName)
|
||||||
|
loadErr := tc.loadTemplateIfExists(templateType, templateNameOrContent)
|
||||||
|
if loadErr != nil {
|
||||||
|
return "", loadErr
|
||||||
|
}
|
||||||
|
m = tc.templates[templateType][templateNameOrContent] // ok is not important since we check m on the next line, and we already checked
|
||||||
|
}
|
||||||
|
if m == nil {
|
||||||
|
return "", fmt.Errorf("failed loading a template for %s", templateNameOrContent)
|
||||||
|
}
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
|
||||||
|
if err := m.Execute(&buf, in); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return buf.String(), nil
|
||||||
|
}
|
||||||
|
|
|
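The Jinja path mirrors the calls shown above (`gonja.FromString`, `exec.NewContext`, `Execute` into a `bytes.Buffer`). A standalone sketch follows; the `github.com/nikolalohinski/gonja/v2` import paths are an assumption, since the import block is outside this hunk, and the template string is a made-up example.

package main

import (
	"bytes"
	"fmt"

	"github.com/nikolalohinski/gonja/v2"      // assumed module path behind the gonja.FromString call above
	"github.com/nikolalohinski/gonja/v2/exec" // assumed path providing exec.NewContext
)

func main() {
	// Parse a Jinja-style template from a string, as the jinja load path does.
	tmpl, err := gonja.FromString("Hello {{ name }}, you sent {{ messages | length }} message(s)")
	if err != nil {
		panic(err)
	}

	// Build the execution context from a plain map, as evaluateJinjaTemplate does.
	ctx := exec.NewContext(map[string]interface{}{
		"name":     "Gopher",
		"messages": []string{"hi"},
	})

	var buf bytes.Buffer
	if err := tmpl.Execute(&buf, ctx); err != nil {
		panic(err)
	}
	fmt.Println(buf.String()) // Hello Gopher, you sent 1 message(s)
}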
@@ -1,73 +0,0 @@
package templates_test

import (
	"os"
	"path/filepath"

	"github.com/mudler/LocalAI/pkg/templates" // Update with your module path

	. "github.com/onsi/ginkgo/v2"
	. "github.com/onsi/gomega"
)

var _ = Describe("TemplateCache", func() {
	var (
		templateCache *templates.TemplateCache
		tempDir       string
	)

	BeforeEach(func() {
		var err error
		tempDir, err = os.MkdirTemp("", "templates")
		Expect(err).NotTo(HaveOccurred())

		// Writing example template files
		err = os.WriteFile(filepath.Join(tempDir, "example.tmpl"), []byte("Hello, {{.Name}}!"), 0600)
		Expect(err).NotTo(HaveOccurred())
		err = os.WriteFile(filepath.Join(tempDir, "empty.tmpl"), []byte(""), 0600)
		Expect(err).NotTo(HaveOccurred())

		templateCache = templates.NewTemplateCache(tempDir)
	})

	AfterEach(func() {
		os.RemoveAll(tempDir) // Clean up
	})

	Describe("EvaluateTemplate", func() {
		Context("when template is loaded successfully", func() {
			It("should evaluate the template correctly", func() {
				result, err := templateCache.EvaluateTemplate(1, "example", map[string]string{"Name": "Gopher"})
				Expect(err).NotTo(HaveOccurred())
				Expect(result).To(Equal("Hello, Gopher!"))
			})
		})

		Context("when template isn't a file", func() {
			It("should parse from string", func() {
				result, err := templateCache.EvaluateTemplate(1, "{{.Name}}", map[string]string{"Name": "Gopher"})
				Expect(err).ToNot(HaveOccurred())
				Expect(result).To(Equal("Gopher"))
			})
		})

		Context("when template is empty", func() {
			It("should return an empty string", func() {
				result, err := templateCache.EvaluateTemplate(1, "empty", nil)
				Expect(err).NotTo(HaveOccurred())
				Expect(result).To(Equal(""))
			})
		})
	})

	Describe("concurrency", func() {
		It("should handle multiple concurrent accesses", func(done Done) {
			go func() {
				_, _ = templateCache.EvaluateTemplate(1, "example", map[string]string{"Name": "Gopher"})
			}()
			go func() {
				_, _ = templateCache.EvaluateTemplate(1, "example", map[string]string{"Name": "Gopher"})
			}()
			close(done)
		}, 0.1) // timeout in seconds
	})
})
295 pkg/templates/evaluator.go Normal file

@@ -0,0 +1,295 @@
package templates

import (
	"encoding/json"
	"fmt"
	"strings"

	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/core/schema"
	"github.com/mudler/LocalAI/pkg/functions"
	"github.com/rs/zerolog/log"
)

// Rather than pass an interface{} to the prompt template:
// These are the definitions of all possible variables LocalAI will currently populate for use in a prompt template file
// Please note: Not all of these are populated on every endpoint - your template should either be tested for each endpoint you map it to, or tolerant of zero values.
type PromptTemplateData struct {
	SystemPrompt         string
	SuppressSystemPrompt bool // used by chat specifically to indicate that SystemPrompt above should be _ignored_
	Input                string
	Instruction          string
	Functions            []functions.Function
	MessageIndex         int
}

type ChatMessageTemplateData struct {
	SystemPrompt string
	Role         string
	RoleName     string
	FunctionName string
	Content      string
	MessageIndex int
	Function     bool
	FunctionCall interface{}
	LastMessage  bool
}

const (
	ChatPromptTemplate TemplateType = iota
	ChatMessageTemplate
	CompletionPromptTemplate
	EditPromptTemplate
	FunctionsPromptTemplate
)

type Evaluator struct {
	cache *templateCache
}

func NewEvaluator(modelPath string) *Evaluator {
	return &Evaluator{
		cache: newTemplateCache(modelPath),
	}
}

func (e *Evaluator) EvaluateTemplateForPrompt(templateType TemplateType, config config.BackendConfig, in PromptTemplateData) (string, error) {
	template := ""

	// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
	if e.cache.existsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
		template = config.Model
	}

	switch templateType {
	case CompletionPromptTemplate:
		if config.TemplateConfig.Completion != "" {
			template = config.TemplateConfig.Completion
		}
	case EditPromptTemplate:
		if config.TemplateConfig.Edit != "" {
			template = config.TemplateConfig.Edit
		}
	case ChatPromptTemplate:
		if config.TemplateConfig.Chat != "" {
			template = config.TemplateConfig.Chat
		}
	case FunctionsPromptTemplate:
		if config.TemplateConfig.Functions != "" {
			template = config.TemplateConfig.Functions
		}
	}

	if template == "" {
		return in.Input, nil
	}

	if config.TemplateConfig.JinjaTemplate {
		return e.evaluateJinjaTemplateForPrompt(templateType, template, in)
	}

	return e.cache.evaluateTemplate(templateType, template, in)
}

func (e *Evaluator) evaluateTemplateForChatMessage(templateName string, messageData ChatMessageTemplateData) (string, error) {
	return e.cache.evaluateTemplate(ChatMessageTemplate, templateName, messageData)
}

func (e *Evaluator) templateJinjaChat(templateName string, messageData []ChatMessageTemplateData, funcs []functions.Function) (string, error) {

	conversation := make(map[string]interface{})
	messages := make([]map[string]interface{}, len(messageData))

	// convert from ChatMessageTemplateData to what the jinja template expects

	for _, message := range messageData {
		// TODO: this seems to cover minimum text templates. Can be expanded to cover more complex interactions
		var data []byte
		data, _ = json.Marshal(message.FunctionCall)
		messages = append(messages, map[string]interface{}{
			"role":      message.RoleName,
			"content":   message.Content,
			"tool_call": string(data),
		})
	}

	conversation["messages"] = messages

	// if tools are detected, add these
	if len(funcs) > 0 {
		conversation["tools"] = funcs
	}

	return e.cache.evaluateJinjaTemplate(ChatMessageTemplate, templateName, conversation)
}

func (e *Evaluator) evaluateJinjaTemplateForPrompt(templateType TemplateType, templateName string, in PromptTemplateData) (string, error) {

	conversation := make(map[string]interface{})

	conversation["system_prompt"] = in.SystemPrompt
	conversation["content"] = in.Input

	return e.cache.evaluateJinjaTemplate(templateType, templateName, conversation)
}

func (e *Evaluator) TemplateMessages(messages []schema.Message, config *config.BackendConfig, funcs []functions.Function, shouldUseFn bool) string {

	if config.TemplateConfig.JinjaTemplate {
		var messageData []ChatMessageTemplateData
		for messageIndex, i := range messages {
			fcall := i.FunctionCall
			if len(i.ToolCalls) > 0 {
				fcall = i.ToolCalls
			}
			messageData = append(messageData, ChatMessageTemplateData{
				SystemPrompt: config.SystemPrompt,
				Role:         config.Roles[i.Role],
				RoleName:     i.Role,
				Content:      i.StringContent,
				FunctionCall: fcall,
				FunctionName: i.Name,
				LastMessage:  messageIndex == (len(messages) - 1),
				Function:     config.Grammar != "" && (messageIndex == (len(messages) - 1)),
				MessageIndex: messageIndex,
			})
		}

		templatedInput, err := e.templateJinjaChat(config.TemplateConfig.ChatMessage, messageData, funcs)
		if err == nil {
			return templatedInput
		}
	}

	var predInput string
	suppressConfigSystemPrompt := false
	mess := []string{}
	for messageIndex, i := range messages {
		var content string
		role := i.Role

		// if function call, we might want to customize the role so we can display better that the "assistant called a json action"
		// if an "assistant_function_call" role is defined, we use it, otherwise we use the role that is passed by in the request
		if (i.FunctionCall != nil || i.ToolCalls != nil) && i.Role == "assistant" {
			roleFn := "assistant_function_call"
			r := config.Roles[roleFn]
			if r != "" {
				role = roleFn
			}
		}
		r := config.Roles[role]
		contentExists := i.Content != nil && i.StringContent != ""

		fcall := i.FunctionCall
		if len(i.ToolCalls) > 0 {
			fcall = i.ToolCalls
		}

		// First attempt to populate content via a chat message specific template
		if config.TemplateConfig.ChatMessage != "" {
			chatMessageData := ChatMessageTemplateData{
				SystemPrompt: config.SystemPrompt,
				Role:         r,
				RoleName:     role,
				Content:      i.StringContent,
				FunctionCall: fcall,
				FunctionName: i.Name,
				LastMessage:  messageIndex == (len(messages) - 1),
				Function:     config.Grammar != "" && (messageIndex == (len(messages) - 1)),
				MessageIndex: messageIndex,
			}
			templatedChatMessage, err := e.evaluateTemplateForChatMessage(config.TemplateConfig.ChatMessage, chatMessageData)
			if err != nil {
				log.Error().Err(err).Interface("message", chatMessageData).Str("template", config.TemplateConfig.ChatMessage).Msg("error processing message with template, skipping")
			} else {
				if templatedChatMessage == "" {
					log.Warn().Msgf("template \"%s\" produced blank output for %+v. Skipping!", config.TemplateConfig.ChatMessage, chatMessageData)
					continue // TODO: This continue is here intentionally to skip over the line `mess = append(mess, content)` below, and to prevent the sprintf
				}
				log.Debug().Msgf("templated message for chat: %s", templatedChatMessage)
				content = templatedChatMessage
			}
		}

		marshalAnyRole := func(f any) {
			j, err := json.Marshal(f)
			if err == nil {
				if contentExists {
					content += "\n" + fmt.Sprint(r, " ", string(j))
				} else {
					content = fmt.Sprint(r, " ", string(j))
				}
			}
		}
		marshalAny := func(f any) {
			j, err := json.Marshal(f)
			if err == nil {
				if contentExists {
					content += "\n" + string(j)
				} else {
					content = string(j)
				}
			}
		}
		// If this model doesn't have such a template, or if that template fails to return a value, template at the message level.
		if content == "" {
			if r != "" {
				if contentExists {
					content = fmt.Sprint(r, i.StringContent)
				}

				if i.FunctionCall != nil {
					marshalAnyRole(i.FunctionCall)
				}
				if i.ToolCalls != nil {
					marshalAnyRole(i.ToolCalls)
				}
			} else {
				if contentExists {
					content = fmt.Sprint(i.StringContent)
				}
				if i.FunctionCall != nil {
					marshalAny(i.FunctionCall)
				}
				if i.ToolCalls != nil {
					marshalAny(i.ToolCalls)
				}
			}
			// Special Handling: System. We care if it was printed at all, not the r branch, so check separately
			if contentExists && role == "system" {
				suppressConfigSystemPrompt = true
			}
		}

		mess = append(mess, content)
	}

	joinCharacter := "\n"
	if config.TemplateConfig.JoinChatMessagesByCharacter != nil {
		joinCharacter = *config.TemplateConfig.JoinChatMessagesByCharacter
	}

	predInput = strings.Join(mess, joinCharacter)
	log.Debug().Msgf("Prompt (before templating): %s", predInput)

	promptTemplate := ChatPromptTemplate

	if config.TemplateConfig.Functions != "" && shouldUseFn {
		promptTemplate = FunctionsPromptTemplate
	}

	templatedInput, err := e.EvaluateTemplateForPrompt(promptTemplate, *config, PromptTemplateData{
		SystemPrompt:         config.SystemPrompt,
		SuppressSystemPrompt: suppressConfigSystemPrompt,
		Input:                predInput,
		Functions:            funcs,
	})
	if err == nil {
		predInput = templatedInput
		log.Debug().Msgf("Template found, input modified to: %s", predInput)
	} else {
		log.Debug().Msgf("Template failed loading: %s", err.Error())
	}

	return predInput
}
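Putting the new Evaluator together: below is a minimal usage sketch based on the API exercised by the test file that follows (`NewEvaluator`, `TemplateMessages`). The inline ChatMessage template is a simplified, hypothetical example rather than one shipped with LocalAI.

package main

import (
	"fmt"

	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/core/schema"
	"github.com/mudler/LocalAI/pkg/functions"
	"github.com/mudler/LocalAI/pkg/templates"
)

func main() {
	// An Evaluator with an empty model path still renders inline template strings.
	e := templates.NewEvaluator("")

	cfg := &config.BackendConfig{
		TemplateConfig: config.TemplateConfig{
			// Hypothetical inline per-message template; real configs usually ship fuller templates.
			ChatMessage: "<|im_start|>{{.RoleName}}\n{{.Content}}<|im_end|>",
		},
	}

	msgs := []schema.Message{
		{Role: "user", StringContent: "A long time ago in a galaxy far, far away..."},
	}

	// Renders each message through ChatMessage, then joins them into the final prompt.
	prompt := e.TemplateMessages(msgs, cfg, []functions.Function{}, false)
	fmt.Println(prompt)
}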
253 pkg/templates/evaluator_test.go Normal file

@@ -0,0 +1,253 @@
package templates_test

import (
	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/core/schema"
	"github.com/mudler/LocalAI/pkg/functions"
	. "github.com/mudler/LocalAI/pkg/templates"

	. "github.com/onsi/ginkgo/v2"
	. "github.com/onsi/gomega"
)

const toolCallJinja = `{{ '<|begin_of_text|>' }}{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{% endif %}{% if system_message is defined %}{{ '<|start_header_id|>system<|end_header_id|>

' + system_message + '<|eot_id|>' }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|start_header_id|>user<|end_header_id|>

' + content + '<|eot_id|><|start_header_id|>assistant<|end_header_id|>

' }}{% elif message['role'] == 'assistant' %}{{ content + '<|eot_id|>' }}{% endif %}{% endfor %}`

const chatML = `<|im_start|>{{if eq .RoleName "assistant"}}assistant{{else if eq .RoleName "system"}}system{{else if eq .RoleName "tool"}}tool{{else if eq .RoleName "user"}}user{{end}}
{{- if .FunctionCall }}
<tool_call>
{{- else if eq .RoleName "tool" }}
<tool_response>
{{- end }}
{{- if .Content}}
{{.Content }}
{{- end }}
{{- if .FunctionCall}}
{{toJson .FunctionCall}}
{{- end }}
{{- if .FunctionCall }}
</tool_call>
{{- else if eq .RoleName "tool" }}
</tool_response>
{{- end }}<|im_end|>`

const llama3 = `<|start_header_id|>{{if eq .RoleName "assistant"}}assistant{{else if eq .RoleName "system"}}system{{else if eq .RoleName "tool"}}tool{{else if eq .RoleName "user"}}user{{end}}<|end_header_id|>

{{ if .FunctionCall -}}
Function call:
{{ else if eq .RoleName "tool" -}}
Function response:
{{ end -}}
{{ if .Content -}}
{{.Content -}}
{{ else if .FunctionCall -}}
{{ toJson .FunctionCall -}}
{{ end -}}
<|eot_id|>`

var llama3TestMatch map[string]map[string]interface{} = map[string]map[string]interface{}{
	"user": {
		"expected": "<|start_header_id|>user<|end_header_id|>\n\nA long time ago in a galaxy far, far away...<|eot_id|>",
		"config": &config.BackendConfig{
			TemplateConfig: config.TemplateConfig{
				ChatMessage: llama3,
			},
		},
		"functions":   []functions.Function{},
		"shouldUseFn": false,
		"messages": []schema.Message{
			{
				Role:          "user",
				StringContent: "A long time ago in a galaxy far, far away...",
			},
		},
	},
	"assistant": {
		"expected": "<|start_header_id|>assistant<|end_header_id|>\n\nA long time ago in a galaxy far, far away...<|eot_id|>",
		"config": &config.BackendConfig{
			TemplateConfig: config.TemplateConfig{
				ChatMessage: llama3,
			},
		},
		"functions": []functions.Function{},
		"messages": []schema.Message{
			{
				Role:          "assistant",
				StringContent: "A long time ago in a galaxy far, far away...",
			},
		},
		"shouldUseFn": false,
	},
	"function_call": {
		"expected": "<|start_header_id|>assistant<|end_header_id|>\n\nFunction call:\n{\"function\":\"test\"}<|eot_id|>",
		"config": &config.BackendConfig{
			TemplateConfig: config.TemplateConfig{
				ChatMessage: llama3,
			},
		},
		"functions": []functions.Function{},
		"messages": []schema.Message{
			{
				Role:         "assistant",
				FunctionCall: map[string]string{"function": "test"},
			},
		},
		"shouldUseFn": false,
	},
	"function_response": {
		"expected": "<|start_header_id|>tool<|end_header_id|>\n\nFunction response:\nResponse from tool<|eot_id|>",
		"config": &config.BackendConfig{
			TemplateConfig: config.TemplateConfig{
				ChatMessage: llama3,
			},
		},
		"functions": []functions.Function{},
		"messages": []schema.Message{
			{
				Role:          "tool",
				StringContent: "Response from tool",
			},
		},
		"shouldUseFn": false,
	},
}

var chatMLTestMatch map[string]map[string]interface{} = map[string]map[string]interface{}{
	"user": {
		"expected": "<|im_start|>user\nA long time ago in a galaxy far, far away...<|im_end|>",
		"config": &config.BackendConfig{
			TemplateConfig: config.TemplateConfig{
				ChatMessage: chatML,
			},
		},
		"functions":   []functions.Function{},
		"shouldUseFn": false,
		"messages": []schema.Message{
			{
				Role:          "user",
				StringContent: "A long time ago in a galaxy far, far away...",
			},
		},
	},
	"assistant": {
		"expected": "<|im_start|>assistant\nA long time ago in a galaxy far, far away...<|im_end|>",
		"config": &config.BackendConfig{
			TemplateConfig: config.TemplateConfig{
				ChatMessage: chatML,
			},
		},
		"functions": []functions.Function{},
		"messages": []schema.Message{
			{
				Role:          "assistant",
				StringContent: "A long time ago in a galaxy far, far away...",
			},
		},
		"shouldUseFn": false,
	},
	"function_call": {
		"expected": "<|im_start|>assistant\n<tool_call>\n{\"function\":\"test\"}\n</tool_call><|im_end|>",
		"config": &config.BackendConfig{
			TemplateConfig: config.TemplateConfig{
				ChatMessage: chatML,
			},
		},
		"functions": []functions.Function{
			{
				Name:        "test",
				Description: "test",
				Parameters:  nil,
			},
		},
		"shouldUseFn": true,
		"messages": []schema.Message{
			{
				Role:         "assistant",
				FunctionCall: map[string]string{"function": "test"},
			},
		},
	},
	"function_response": {
		"expected": "<|im_start|>tool\n<tool_response>\nResponse from tool\n</tool_response><|im_end|>",
		"config": &config.BackendConfig{
			TemplateConfig: config.TemplateConfig{
				ChatMessage: chatML,
			},
		},
		"functions":   []functions.Function{},
		"shouldUseFn": false,
		"messages": []schema.Message{
			{
				Role:          "tool",
				StringContent: "Response from tool",
			},
		},
	},
}

var jinjaTest map[string]map[string]interface{} = map[string]map[string]interface{}{
	"user": {
		"expected": "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nA long time ago in a galaxy far, far away...<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
		"config": &config.BackendConfig{
			TemplateConfig: config.TemplateConfig{
				ChatMessage:   toolCallJinja,
				JinjaTemplate: true,
			},
		},
		"functions":   []functions.Function{},
		"shouldUseFn": false,
		"messages": []schema.Message{
			{
				Role:          "user",
				StringContent: "A long time ago in a galaxy far, far away...",
			},
		},
	},
}

var _ = Describe("Templates", func() {
	Context("chat message ChatML", func() {
		var evaluator *Evaluator
		BeforeEach(func() {
			evaluator = NewEvaluator("")
		})
		for key := range chatMLTestMatch {
			foo := chatMLTestMatch[key]
			It("renders correctly `"+key+"`", func() {
				templated := evaluator.TemplateMessages(foo["messages"].([]schema.Message), foo["config"].(*config.BackendConfig), foo["functions"].([]functions.Function), foo["shouldUseFn"].(bool))
				Expect(templated).To(Equal(foo["expected"]), templated)
			})
		}
	})
	Context("chat message llama3", func() {
		var evaluator *Evaluator
		BeforeEach(func() {
			evaluator = NewEvaluator("")
		})
		for key := range llama3TestMatch {
			foo := llama3TestMatch[key]
			It("renders correctly `"+key+"`", func() {
				templated := evaluator.TemplateMessages(foo["messages"].([]schema.Message), foo["config"].(*config.BackendConfig), foo["functions"].([]functions.Function), foo["shouldUseFn"].(bool))
				Expect(templated).To(Equal(foo["expected"]), templated)
			})
		}
	})
	Context("chat message jinja", func() {
		var evaluator *Evaluator
		BeforeEach(func() {
			evaluator = NewEvaluator("")
		})
		for key := range jinjaTest {
			foo := jinjaTest[key]
			It("renders correctly `"+key+"`", func() {
				templated := evaluator.TemplateMessages(foo["messages"].([]schema.Message), foo["config"].(*config.BackendConfig), foo["functions"].([]functions.Function), foo["shouldUseFn"].(bool))
				Expect(templated).To(Equal(foo["expected"]), templated)
			})
		}
	})
})
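These specs are plain Ginkgo containers, so they rely on a suite entry point somewhere in pkg/templates to be picked up by `go test`. That file is not part of this diff; a typical bootstrap would look like the sketch below (standard Ginkgo v2 boilerplate, names assumed).

package templates_test

import (
	"testing"

	. "github.com/onsi/ginkgo/v2"
	. "github.com/onsi/gomega"
)

// Suite entry point: wires the Describe blocks above into the standard `go test` runner.
func TestTemplates(t *testing.T) {
	RegisterFailHandler(Fail)
	RunSpecs(t, "Templates test suite")
}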
@@ -14,6 +14,7 @@ roles:

stopwords:
- 'Assistant:'
- '<s>'

template:
  chat: |